Skip to main content

claude_code_rs/
query.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::{mpsc, oneshot, Mutex};
7use tokio_util::sync::CancellationToken;
8
9use crate::error::{Error, Result};
10use crate::message_parser::parse_message;
11use crate::types::control::{SDKCapabilities, SDKControlCommand, SDKInitMessage};
12use crate::types::hooks::{
13    HookDecision, HookDefinition, HookEvent, HookInput, NotificationInput, PostToolUseInput,
14    PreToolUseInput, StopInput,
15};
16use crate::types::messages::Message;
17use crate::types::permissions::{CanUseToolCallback, CanUseToolInput};
18use crate::transport::{Transport, TransportWriter};
19
/// Time to wait for a control response when the caller configures no timeout (30 s).
const DEFAULT_CONTROL_TIMEOUT: Duration = Duration::from_millis(30_000);
21
22/// Handler for MCP messages routed through the control protocol.
23pub type McpMessageHandler = Arc<
24    dyn Fn(String, Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Value> + Send>>
25        + Send
26        + Sync,
27>;
28
29/// Query manages the bidirectional control protocol over a Transport connection.
30///
31/// Routes incoming messages: control requests are handled internally,
32/// regular messages are forwarded to the consumer channel.
33pub struct Query {
34    transport: Box<dyn Transport>,
35    writer: Option<TransportWriter>,
36    hooks: Vec<HookDefinition>,
37    can_use_tool: Option<CanUseToolCallback>,
38    mcp_handler: Option<McpMessageHandler>,
39    pending_responses: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
40    cancel: CancellationToken,
41    control_timeout: Duration,
42    server_info: Arc<Mutex<Option<Value>>>,
43}
44
45impl Query {
46    pub fn new(
47        transport: Box<dyn Transport>,
48        hooks: Vec<HookDefinition>,
49        can_use_tool: Option<CanUseToolCallback>,
50        mcp_handler: Option<McpMessageHandler>,
51        control_timeout: Option<Duration>,
52    ) -> Self {
53        Self {
54            transport,
55            writer: None,
56            hooks,
57            can_use_tool,
58            mcp_handler,
59            pending_responses: Arc::new(Mutex::new(HashMap::new())),
60            cancel: CancellationToken::new(),
61            control_timeout: control_timeout.unwrap_or(DEFAULT_CONTROL_TIMEOUT),
62            server_info: Arc::new(Mutex::new(None)),
63        }
64    }
65
66    /// Connect to the CLI and perform the initialization handshake.
67    pub async fn connect(&mut self) -> Result<mpsc::Receiver<Result<Message>>> {
68        let (raw_rx, writer) = self.transport.connect().await?;
69        self.writer = Some(writer.clone());
70
71        let (consumer_tx, consumer_rx) = mpsc::channel::<Result<Message>>(256);
72
73        // Start the message router task.
74        self.spawn_router(raw_rx, consumer_tx, writer.clone());
75
76        // Perform init handshake.
77        self.initialize().await?;
78
79        Ok(consumer_rx)
80    }
81
82    /// Send a user message to the CLI.
83    pub async fn send_message(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
84        let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
85        let msg = serde_json::json!({
86            "type": "user",
87            "message": {
88                "role": "user",
89                "content": prompt
90            },
91            "session_id": session_id.unwrap_or(""),
92            "parent_tool_use_id": null
93        });
94        writer.write(msg).await
95    }
96
97    /// Send a control command and wait for the response.
98    pub async fn send_control_command(&self, command: SDKControlCommand) -> Result<Value> {
99        let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
100        let request_id = generate_request_id();
101
102        let mut request = serde_json::json!({
103            "type": "control_request",
104            "request_id": request_id,
105            "request": {
106                "subtype": command.command_type,
107            }
108        });
109
110        if let Value::Object(params) = command.params {
111            if let Value::Object(ref mut req) = request["request"] {
112                for (k, v) in params {
113                    req.insert(k, v);
114                }
115            }
116        }
117
118        let (tx, rx) = oneshot::channel();
119        {
120            let mut pending = self.pending_responses.lock().await;
121            pending.insert(request_id.clone(), tx);
122        }
123
124        writer.write(request).await?;
125
126        let response = tokio::time::timeout(self.control_timeout, rx)
127            .await
128            .map_err(|_| Error::ControlTimeout(self.control_timeout))?
129            .map_err(|_| Error::ControlProtocol("response channel dropped".into()))?;
130
131        Ok(response)
132    }
133
134    pub async fn interrupt(&self) -> Result<Value> {
135        self.send_control_command(SDKControlCommand::interrupt())
136            .await
137    }
138
139    pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
140        self.send_control_command(SDKControlCommand::set_permission_mode(mode))
141            .await
142    }
143
144    pub async fn set_model(&self, model: &str) -> Result<Value> {
145        self.send_control_command(SDKControlCommand::set_model(model))
146            .await
147    }
148
149    pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
150        self.send_control_command(SDKControlCommand::rewind_files(user_message_id))
151            .await
152    }
153
154    pub async fn get_mcp_status(&self) -> Result<Value> {
155        self.send_control_command(SDKControlCommand::get_mcp_status())
156            .await
157    }
158
159    pub async fn get_server_info(&self) -> Option<Value> {
160        self.server_info.lock().await.clone()
161    }
162
163    pub async fn end_input(&self) -> Result<()> {
164        self.transport.end_input().await
165    }
166
167    pub async fn close(&mut self) -> Result<()> {
168        self.cancel.cancel();
169        self.writer = None;
170        self.transport.close().await
171    }
172
173    async fn initialize(&self) -> Result<()> {
174        let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
175
176        let capabilities = SDKCapabilities {
177            hooks: !self.hooks.is_empty(),
178            permissions: self.can_use_tool.is_some(),
179            mcp: self.mcp_handler.is_some(),
180            agent_definitions: vec![],
181            mcp_servers: vec![],
182        };
183
184        let init_msg = SDKInitMessage::new(capabilities);
185        let init_value = serde_json::to_value(&init_msg)?;
186
187        let request_id = generate_request_id();
188        let request = serde_json::json!({
189            "type": "control_request",
190            "request_id": request_id,
191            "request": {
192                "subtype": "initialize",
193                "protocol_version": "1",
194                "capabilities": init_value.get("capabilities"),
195            }
196        });
197
198        let (tx, rx) = oneshot::channel();
199        {
200            let mut pending = self.pending_responses.lock().await;
201            pending.insert(request_id.clone(), tx);
202        }
203
204        writer.write(request).await?;
205
206        let response = tokio::time::timeout(self.control_timeout, rx)
207            .await
208            .map_err(|_| Error::ControlTimeout(self.control_timeout))?
209            .map_err(|_| Error::ControlProtocol("init response channel dropped".into()))?;
210
211        {
212            let mut info = self.server_info.lock().await;
213            *info = Some(response);
214        }
215
216        Ok(())
217    }
218
219    fn spawn_router(
220        &self,
221        mut raw_rx: mpsc::Receiver<Result<Value>>,
222        consumer_tx: mpsc::Sender<Result<Message>>,
223        writer: TransportWriter,
224    ) {
225        let pending = self.pending_responses.clone();
226        let hooks = self.hooks.clone();
227        let can_use_tool = self.can_use_tool.clone();
228        let mcp_handler = self.mcp_handler.clone();
229        let cancel = self.cancel.clone();
230
231        tokio::spawn(async move {
232            loop {
233                tokio::select! {
234                    _ = cancel.cancelled() => break,
235                    msg = raw_rx.recv() => {
236                        match msg {
237                            Some(Ok(value)) => {
238                                let msg_type = value.get("type")
239                                    .and_then(|v| v.as_str())
240                                    .unwrap_or("");
241
242                                match msg_type {
243                                    "control_response" => {
244                                        route_control_response(&pending, &value).await;
245                                    }
246                                    "control_request" => {
247                                        dispatch_control_request(
248                                            &value,
249                                            &hooks,
250                                            &can_use_tool,
251                                            &mcp_handler,
252                                            &writer,
253                                        ).await;
254                                    }
255                                    _ => {
256                                        let parsed = parse_message(value);
257                                        if consumer_tx.send(parsed).await.is_err() {
258                                            break;
259                                        }
260                                    }
261                                }
262                            }
263                            Some(Err(e)) => {
264                                let _ = consumer_tx.send(Err(e)).await;
265                                break;
266                            }
267                            None => break,
268                        }
269                    }
270                }
271            }
272        });
273    }
274}
275
276async fn route_control_response(
277    pending: &Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
278    value: &Value,
279) {
280    let response = value.get("response").cloned().unwrap_or(value.clone());
281    let request_id = response
282        .get("request_id")
283        .and_then(|v| v.as_str())
284        .unwrap_or("");
285
286    let mut pending = pending.lock().await;
287    if let Some(tx) = pending.remove(request_id) {
288        let _ = tx.send(response);
289    } else {
290        tracing::warn!(request_id, "control response for unknown request");
291    }
292}
293
294async fn dispatch_control_request(
295    value: &Value,
296    hooks: &[HookDefinition],
297    can_use_tool: &Option<CanUseToolCallback>,
298    mcp_handler: &Option<McpMessageHandler>,
299    writer: &TransportWriter,
300) {
301    let request_id = value
302        .get("request_id")
303        .and_then(|v| v.as_str())
304        .unwrap_or("")
305        .to_string();
306
307    let request = match value.get("request") {
308        Some(r) => r,
309        None => {
310            tracing::warn!("control request missing 'request' field");
311            return;
312        }
313    };
314
315    let subtype = request
316        .get("subtype")
317        .and_then(|v| v.as_str())
318        .unwrap_or("");
319
320    let response_body = match subtype {
321        "can_use_tool" => handle_can_use_tool(request, can_use_tool).await,
322        "hook_callback" => handle_hook_callback(request, hooks).await,
323        "mcp_message" => handle_mcp_message(request, mcp_handler).await,
324        other => {
325            tracing::warn!(subtype = other, "unknown control request subtype");
326            serde_json::json!({"error": format!("unknown subtype: {other}")})
327        }
328    };
329
330    let control_response = serde_json::json!({
331        "type": "control_response",
332        "response": {
333            "subtype": "success",
334            "request_id": request_id,
335            "response": response_body,
336        }
337    });
338
339    if let Err(e) = writer.write(control_response).await {
340        tracing::error!("failed to send control response: {e}");
341    }
342}
343
344async fn handle_can_use_tool(request: &Value, callback: &Option<CanUseToolCallback>) -> Value {
345    let tool_name = request
346        .get("tool_name")
347        .and_then(|v| v.as_str())
348        .unwrap_or("")
349        .to_string();
350    let input = request.get("input").cloned().unwrap_or(Value::Null);
351
352    if let Some(cb) = callback {
353        let result = cb(CanUseToolInput { tool_name, input }).await;
354        if result.allowed {
355            serde_json::json!({"behavior": "allow"})
356        } else {
357            serde_json::json!({
358                "behavior": "deny",
359                "message": result.reason.unwrap_or_default()
360            })
361        }
362    } else {
363        serde_json::json!({"behavior": "allow"})
364    }
365}
366
367async fn handle_hook_callback(request: &Value, hooks: &[HookDefinition]) -> Value {
368    let callback_id = request
369        .get("callback_id")
370        .and_then(|v| v.as_str())
371        .unwrap_or("");
372    let hook_input = request.get("input").cloned().unwrap_or(Value::Null);
373
374    let hook_index: Option<usize> = callback_id
375        .strip_prefix("hook_")
376        .and_then(|s| s.parse().ok());
377
378    let hook = hook_index.and_then(|i| hooks.get(i));
379
380    if let Some(hook) = hook {
381        let typed_input = match hook.event {
382            HookEvent::PreToolUse => {
383                let pre: PreToolUseInput =
384                    serde_json::from_value(hook_input).unwrap_or(PreToolUseInput {
385                        tool_name: String::new(),
386                        tool_input: Value::Null,
387                    });
388                HookInput::PreToolUse(pre)
389            }
390            HookEvent::PostToolUse => {
391                let post: PostToolUseInput =
392                    serde_json::from_value(hook_input).unwrap_or(PostToolUseInput {
393                        tool_name: String::new(),
394                        tool_input: Value::Null,
395                        tool_output: Value::Null,
396                    });
397                HookInput::PostToolUse(post)
398            }
399            HookEvent::Notification => {
400                let notif: NotificationInput =
401                    serde_json::from_value(hook_input).unwrap_or(NotificationInput {
402                        title: String::new(),
403                        message: None,
404                    });
405                HookInput::Notification(notif)
406            }
407            HookEvent::Stop | HookEvent::SubagentStop => {
408                let stop: StopInput =
409                    serde_json::from_value(hook_input).unwrap_or(StopInput { reason: None });
410                HookInput::Stop(stop)
411            }
412        };
413
414        let output = (hook.callback)(typed_input).await;
415        let mut result = serde_json::json!({"continue": true});
416        if let Some(decision) = &output.decision {
417            let hook_specific = serde_json::json!({
418                "hookEventName": match hook.event {
419                    HookEvent::PreToolUse => "PreToolUse",
420                    HookEvent::PostToolUse => "PostToolUse",
421                    HookEvent::Notification => "Notification",
422                    HookEvent::Stop => "Stop",
423                    HookEvent::SubagentStop => "SubagentStop",
424                },
425                "permissionDecision": match decision {
426                    HookDecision::Approve => "approve",
427                    HookDecision::Block => "deny",
428                    HookDecision::Ignore => "ignore",
429                },
430                "permissionDecisionReason": output.reason.as_deref().unwrap_or(""),
431            });
432            result["hookSpecificOutput"] = hook_specific;
433
434            if *decision == HookDecision::Block {
435                result["continue"] = Value::Bool(false);
436            }
437        }
438        result
439    } else {
440        tracing::warn!(callback_id, "hook callback not found");
441        serde_json::json!({"continue": true})
442    }
443}
444
445async fn handle_mcp_message(request: &Value, handler: &Option<McpMessageHandler>) -> Value {
446    let server_name = request
447        .get("server_name")
448        .and_then(|v| v.as_str())
449        .unwrap_or("")
450        .to_string();
451    let message = request.get("message").cloned().unwrap_or(Value::Null);
452
453    if let Some(handler) = handler {
454        handler(server_name, message).await
455    } else {
456        serde_json::json!({"error": "no MCP handler registered"})
457    }
458}
459
460fn generate_request_id() -> String {
461    use rand::Rng;
462    let mut rng = rand::rng();
463    let suffix: u64 = rng.random();
464    format!("req_{suffix:016x}")
465}
466
467impl Drop for Query {
468    fn drop(&mut self) {
469        self.cancel.cancel();
470    }
471}