1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::{mpsc, oneshot, Mutex};
7use tokio_util::sync::CancellationToken;
8
9use crate::error::{Error, Result};
10use crate::message_parser::parse_message;
11use crate::types::control::{SDKCapabilities, SDKControlCommand};
12use crate::types::hooks::{HookDecision, HookDefinition, HookEvent, HookInput};
13use crate::types::messages::Message;
14use crate::types::permissions::{CanUseToolCallback, CanUseToolInput};
15use crate::transport::{Transport, TransportWriter};
16
/// Timeout applied to outgoing control requests when `Query::new` is given `None`.
const DEFAULT_CONTROL_TIMEOUT: Duration = Duration::from_millis(30_000);
18
19pub type McpMessageHandler = Arc<
21 dyn Fn(String, Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Value> + Send>>
22 + Send
23 + Sync,
24>;
25
/// A control-protocol session over a [`Transport`].
///
/// After `connect`, a background router task splits inbound traffic into:
/// - control responses, which complete `pending_responses`;
/// - control requests, which are answered using `hooks`, `can_use_tool` and `mcp_handler`;
/// - ordinary messages, which go to the consumer channel.
pub struct Query {
    /// Underlying transport; connected by `connect`, closed by `close`.
    transport: Box<dyn Transport>,
    /// Write half returned by `connect`. It is `None` before `connect` and after
    /// `close`; sends then fail with `Error::NotConnected`.
    writer: Option<TransportWriter>,
    /// Hook definitions. Inbound `hook_callback` ids of the form `hook_<n>` index into this list.
    hooks: Vec<HookDefinition>,
    /// Permission callback for inbound `can_use_tool` requests. When `None`, those
    /// requests are answered with `allow`.
    can_use_tool: Option<CanUseToolCallback>,
    /// Handler for inbound `mcp_message` requests.
    mcp_handler: Option<McpMessageHandler>,
    /// Outstanding outgoing control requests, keyed by request id. The router
    /// removes an entry and completes it when the matching response arrives.
    pending_responses: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
    /// Cancelled by `close`, on drop, and when the router task stops.
    cancel: CancellationToken,
    /// How long an outgoing control request waits for its response.
    control_timeout: Duration,
    /// Response payload received for the `initialize` handshake.
    server_info: Arc<Mutex<Option<Value>>>,
}
41
42impl Query {
43 pub fn new(
44 transport: Box<dyn Transport>,
45 hooks: Vec<HookDefinition>,
46 can_use_tool: Option<CanUseToolCallback>,
47 mcp_handler: Option<McpMessageHandler>,
48 control_timeout: Option<Duration>,
49 ) -> Self {
50 Self {
51 transport,
52 writer: None,
53 hooks,
54 can_use_tool,
55 mcp_handler,
56 pending_responses: Arc::new(Mutex::new(HashMap::new())),
57 cancel: CancellationToken::new(),
58 control_timeout: control_timeout.unwrap_or(DEFAULT_CONTROL_TIMEOUT),
59 server_info: Arc::new(Mutex::new(None)),
60 }
61 }
62
63 pub async fn connect(&mut self) -> Result<mpsc::Receiver<Result<Message>>> {
65 let (raw_rx, writer) = self.transport.connect().await?;
66 self.writer = Some(writer.clone());
67
68 let (consumer_tx, consumer_rx) = mpsc::channel::<Result<Message>>(256);
69
70 self.spawn_router(raw_rx, consumer_tx, writer);
72
73 self.initialize().await?;
75
76 Ok(consumer_rx)
77 }
78
79 pub async fn send_message(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
81 let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
82 let msg = serde_json::json!({
83 "type": "user",
84 "message": {
85 "role": "user",
86 "content": prompt
87 },
88 "session_id": session_id,
89 "parent_tool_use_id": null
90 });
91 writer.write(msg).await
92 }
93
94 pub async fn send_control_command(&self, command: SDKControlCommand) -> Result<Value> {
96 self.send_raw_control_request(command.to_request_body()).await
97 }
98
99 pub async fn interrupt(&self) -> Result<Value> {
100 self.send_control_command(SDKControlCommand::interrupt())
101 .await
102 }
103
104 pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
105 self.send_control_command(SDKControlCommand::set_permission_mode(mode))
106 .await
107 }
108
109 pub async fn set_model(&self, model: &str) -> Result<Value> {
110 self.send_control_command(SDKControlCommand::set_model(model))
111 .await
112 }
113
114 pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
115 self.send_control_command(SDKControlCommand::rewind_files(user_message_id))
116 .await
117 }
118
119 pub async fn get_mcp_status(&self) -> Result<Value> {
120 self.send_control_command(SDKControlCommand::get_mcp_status())
121 .await
122 }
123
124 pub async fn get_server_info(&self) -> Option<Value> {
125 self.server_info.lock().await.clone()
126 }
127
128 #[allow(dead_code)]
129 pub async fn end_input(&self) -> Result<()> {
130 self.transport.end_input().await
131 }
132
133 pub async fn closed(&self) {
135 self.cancel.cancelled().await;
136 }
137
138 pub async fn close(&mut self) -> Result<()> {
139 self.cancel.cancel();
140 self.writer = None;
141 self.transport.close().await
142 }
143
144 async fn send_raw_control_request(&self, request_body: Value) -> Result<Value> {
146 let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
147 let request_id = generate_request_id();
148
149 let request = serde_json::json!({
150 "type": "control_request",
151 "request_id": request_id,
152 "request": request_body,
153 });
154
155 let (tx, rx) = oneshot::channel();
156 {
157 let mut pending = self.pending_responses.lock().await;
158 pending.insert(request_id.clone(), tx);
159 }
160
161 writer.write(request).await?;
162
163 tokio::time::timeout(self.control_timeout, rx)
164 .await
165 .map_err(|_| Error::ControlTimeout(self.control_timeout))?
166 .map_err(|_| Error::ControlProtocol("response channel dropped".into()))
167 }
168
169 async fn initialize(&self) -> Result<()> {
170 let capabilities = SDKCapabilities {
171 hooks: !self.hooks.is_empty(),
172 permissions: self.can_use_tool.is_some(),
173 mcp: self.mcp_handler.is_some(),
174 agent_definitions: vec![],
175 mcp_servers: vec![],
176 };
177
178 let response = self.send_raw_control_request(serde_json::json!({
179 "subtype": "initialize",
180 "protocol_version": "1",
181 "capabilities": capabilities,
182 })).await?;
183
184 {
185 let mut info = self.server_info.lock().await;
186 *info = Some(response);
187 }
188
189 Ok(())
190 }
191
192 fn spawn_router(
193 &self,
194 mut raw_rx: mpsc::Receiver<Result<Value>>,
195 consumer_tx: mpsc::Sender<Result<Message>>,
196 writer: TransportWriter,
197 ) {
198 let pending = self.pending_responses.clone();
199 let hooks = self.hooks.clone();
200 let can_use_tool = self.can_use_tool.clone();
201 let mcp_handler = self.mcp_handler.clone();
202 let cancel = self.cancel.clone();
203
204 tokio::spawn(async move {
205 loop {
206 tokio::select! {
207 _ = cancel.cancelled() => break,
208 msg = raw_rx.recv() => {
209 match msg {
210 Some(Ok(value)) => {
211 let msg_type = value.get("type")
212 .and_then(|v| v.as_str())
213 .unwrap_or("");
214
215 match msg_type {
216 "control_response" => {
217 route_control_response(&pending, &value).await;
218 }
219 "control_request" => {
220 dispatch_control_request(
221 &value,
222 &hooks,
223 &can_use_tool,
224 &mcp_handler,
225 &writer,
226 ).await;
227 }
228 _ => {
229 let parsed = parse_message(value);
230 if consumer_tx.send(parsed).await.is_err() {
231 break;
232 }
233 }
234 }
235 }
236 Some(Err(e)) => {
237 let _ = consumer_tx.send(Err(e)).await;
238 break;
239 }
240 None => break,
241 }
242 }
243 }
244 }
245 cancel.cancel();
248 });
249 }
250}
251
252async fn route_control_response(
253 pending: &Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
254 value: &Value,
255) {
256 let response = value.get("response").cloned().unwrap_or(value.clone());
257 let request_id = response
258 .get("request_id")
259 .and_then(|v| v.as_str())
260 .unwrap_or("");
261
262 let mut pending = pending.lock().await;
263 if let Some(tx) = pending.remove(request_id) {
264 let _ = tx.send(response);
265 } else {
266 tracing::warn!(request_id, "control response for unknown request");
267 }
268}
269
270async fn dispatch_control_request(
271 value: &Value,
272 hooks: &[HookDefinition],
273 can_use_tool: &Option<CanUseToolCallback>,
274 mcp_handler: &Option<McpMessageHandler>,
275 writer: &TransportWriter,
276) {
277 let request_id = value
278 .get("request_id")
279 .and_then(|v| v.as_str())
280 .unwrap_or("")
281 .to_string();
282
283 let request = match value.get("request") {
284 Some(r) => r,
285 None => {
286 tracing::warn!("control request missing 'request' field");
287 return;
288 }
289 };
290
291 let subtype = request
292 .get("subtype")
293 .and_then(|v| v.as_str())
294 .unwrap_or("");
295
296 let response_body = match subtype {
297 "can_use_tool" => handle_can_use_tool(request, can_use_tool).await,
298 "hook_callback" => handle_hook_callback(request, hooks).await,
299 "mcp_message" => handle_mcp_message(request, mcp_handler).await,
300 other => {
301 tracing::warn!(subtype = other, "unknown control request subtype");
302 serde_json::json!({"error": format!("unknown subtype: {other}")})
303 }
304 };
305
306 let control_response = serde_json::json!({
307 "type": "control_response",
308 "response": {
309 "subtype": "success",
310 "request_id": request_id,
311 "response": response_body,
312 }
313 });
314
315 if let Err(e) = writer.write(control_response).await {
316 tracing::error!("failed to send control response: {e}");
317 }
318}
319
320async fn handle_can_use_tool(request: &Value, callback: &Option<CanUseToolCallback>) -> Value {
321 let tool_name = request
322 .get("tool_name")
323 .and_then(|v| v.as_str())
324 .unwrap_or("")
325 .to_string();
326 let input = request.get("input").cloned().unwrap_or(Value::Null);
327
328 if let Some(cb) = callback {
329 let result = cb(CanUseToolInput { tool_name, input }).await;
330 if result.allowed {
331 serde_json::json!({"behavior": "allow"})
332 } else {
333 serde_json::json!({
334 "behavior": "deny",
335 "message": result.reason.unwrap_or_default()
336 })
337 }
338 } else {
339 serde_json::json!({"behavior": "allow"})
340 }
341}
342
343async fn handle_hook_callback(request: &Value, hooks: &[HookDefinition]) -> Value {
344 let callback_id = request
345 .get("callback_id")
346 .and_then(|v| v.as_str())
347 .unwrap_or("");
348 let hook_input = request.get("input").cloned().unwrap_or(Value::Null);
349
350 let hook_index: Option<usize> = callback_id
351 .strip_prefix("hook_")
352 .and_then(|s| s.parse().ok());
353
354 let hook = hook_index.and_then(|i| hooks.get(i));
355
356 if let Some(hook) = hook {
357 let event_name = hook.event.as_str();
358 let typed_input = match hook.event {
359 HookEvent::PreToolUse => HookInput::PreToolUse(
360 serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
361 tracing::warn!(event = event_name, "hook input parse failed: {e}");
362 Default::default()
363 }),
364 ),
365 HookEvent::PostToolUse => HookInput::PostToolUse(
366 serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
367 tracing::warn!(event = event_name, "hook input parse failed: {e}");
368 Default::default()
369 }),
370 ),
371 HookEvent::Notification => HookInput::Notification(
372 serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
373 tracing::warn!(event = event_name, "hook input parse failed: {e}");
374 Default::default()
375 }),
376 ),
377 HookEvent::Stop | HookEvent::SubagentStop => HookInput::Stop(
378 serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
379 tracing::warn!(event = event_name, "hook input parse failed: {e}");
380 Default::default()
381 }),
382 ),
383 };
384
385 let output = (hook.callback)(typed_input).await;
386 let mut result = serde_json::json!({"continue": true});
387 if let Some(decision) = &output.decision {
388 let hook_specific = serde_json::json!({
389 "hookEventName": hook.event.as_str(),
390 "permissionDecision": decision.as_str(),
391 "permissionDecisionReason": output.reason.as_deref().unwrap_or(""),
392 });
393 result["hookSpecificOutput"] = hook_specific;
394
395 if *decision == HookDecision::Block {
396 result["continue"] = Value::Bool(false);
397 }
398 }
399 result
400 } else {
401 tracing::warn!(callback_id, "hook callback not found");
402 serde_json::json!({"continue": true})
403 }
404}
405
406async fn handle_mcp_message(request: &Value, handler: &Option<McpMessageHandler>) -> Value {
407 let server_name = request
408 .get("server_name")
409 .and_then(|v| v.as_str())
410 .unwrap_or("")
411 .to_string();
412 let message = request.get("message").cloned().unwrap_or(Value::Null);
413
414 if let Some(handler) = handler {
415 handler(server_name, message).await
416 } else {
417 serde_json::json!({"error": "no MCP handler registered"})
418 }
419}
420
421fn generate_request_id() -> String {
422 use rand::Rng;
423 let mut rng = rand::rng();
424 let suffix: u64 = rng.random();
425 format!("req_{suffix:016x}")
426}
427
428impl Drop for Query {
429 fn drop(&mut self) {
430 self.cancel.cancel();
431 }
432}