Skip to main content

claude_code_rs/
query.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::{mpsc, oneshot, Mutex};
7use tokio_util::sync::CancellationToken;
8
9use crate::error::{Error, Result};
10use crate::message_parser::parse_message;
11use crate::types::control::{SDKCapabilities, SDKControlCommand};
12use crate::types::hooks::{HookDecision, HookDefinition, HookEvent, HookInput};
13use crate::types::messages::Message;
14use crate::types::permissions::{CanUseToolCallback, CanUseToolInput};
15use crate::transport::{Transport, TransportWriter};
16
/// Default upper bound on how long an outgoing control request waits for its
/// response (30 seconds) when `Query::new` is given no explicit timeout.
const DEFAULT_CONTROL_TIMEOUT: Duration = Duration::from_millis(30_000);
18
19/// Handler for MCP messages routed through the control protocol.
20pub type McpMessageHandler = Arc<
21    dyn Fn(String, Value) -> std::pin::Pin<Box<dyn std::future::Future<Output = Value> + Send>>
22        + Send
23        + Sync,
24>;
25
26/// Query manages the bidirectional control protocol over a Transport connection.
27///
28/// Routes incoming messages: control requests are handled internally,
29/// regular messages are forwarded to the consumer channel.
30pub struct Query {
31    transport: Box<dyn Transport>,
32    writer: Option<TransportWriter>,
33    hooks: Vec<HookDefinition>,
34    can_use_tool: Option<CanUseToolCallback>,
35    mcp_handler: Option<McpMessageHandler>,
36    pending_responses: Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
37    cancel: CancellationToken,
38    control_timeout: Duration,
39    server_info: Arc<Mutex<Option<Value>>>,
40}
41
42impl Query {
43    pub fn new(
44        transport: Box<dyn Transport>,
45        hooks: Vec<HookDefinition>,
46        can_use_tool: Option<CanUseToolCallback>,
47        mcp_handler: Option<McpMessageHandler>,
48        control_timeout: Option<Duration>,
49    ) -> Self {
50        Self {
51            transport,
52            writer: None,
53            hooks,
54            can_use_tool,
55            mcp_handler,
56            pending_responses: Arc::new(Mutex::new(HashMap::new())),
57            cancel: CancellationToken::new(),
58            control_timeout: control_timeout.unwrap_or(DEFAULT_CONTROL_TIMEOUT),
59            server_info: Arc::new(Mutex::new(None)),
60        }
61    }
62
63    /// Connect to the CLI and perform the initialization handshake.
64    pub async fn connect(&mut self) -> Result<mpsc::Receiver<Result<Message>>> {
65        let (raw_rx, writer) = self.transport.connect().await?;
66        self.writer = Some(writer.clone());
67
68        let (consumer_tx, consumer_rx) = mpsc::channel::<Result<Message>>(256);
69
70        // Start the message router task.
71        self.spawn_router(raw_rx, consumer_tx, writer);
72
73        // Perform init handshake.
74        self.initialize().await?;
75
76        Ok(consumer_rx)
77    }
78
79    /// Send a user message to the CLI.
80    pub async fn send_message(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
81        let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
82        let msg = serde_json::json!({
83            "type": "user",
84            "message": {
85                "role": "user",
86                "content": prompt
87            },
88            "session_id": session_id,
89            "parent_tool_use_id": null
90        });
91        writer.write(msg).await
92    }
93
94    /// Send a control command and wait for the response.
95    pub async fn send_control_command(&self, command: SDKControlCommand) -> Result<Value> {
96        self.send_raw_control_request(command.to_request_body()).await
97    }
98
99    pub async fn interrupt(&self) -> Result<Value> {
100        self.send_control_command(SDKControlCommand::interrupt())
101            .await
102    }
103
104    pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
105        self.send_control_command(SDKControlCommand::set_permission_mode(mode))
106            .await
107    }
108
109    pub async fn set_model(&self, model: &str) -> Result<Value> {
110        self.send_control_command(SDKControlCommand::set_model(model))
111            .await
112    }
113
114    pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
115        self.send_control_command(SDKControlCommand::rewind_files(user_message_id))
116            .await
117    }
118
119    pub async fn get_mcp_status(&self) -> Result<Value> {
120        self.send_control_command(SDKControlCommand::get_mcp_status())
121            .await
122    }
123
124    pub async fn get_server_info(&self) -> Option<Value> {
125        self.server_info.lock().await.clone()
126    }
127
128    #[allow(dead_code)]
129    pub async fn end_input(&self) -> Result<()> {
130        self.transport.end_input().await
131    }
132
133    /// Wait until the cancel token is triggered (transport/router finished).
134    pub async fn closed(&self) {
135        self.cancel.cancelled().await;
136    }
137
138    pub async fn close(&mut self) -> Result<()> {
139        self.cancel.cancel();
140        self.writer = None;
141        self.transport.close().await
142    }
143
144    /// Send a raw control request and wait for the response with timeout.
145    async fn send_raw_control_request(&self, request_body: Value) -> Result<Value> {
146        let writer = self.writer.as_ref().ok_or(Error::NotConnected)?;
147        let request_id = generate_request_id();
148
149        let request = serde_json::json!({
150            "type": "control_request",
151            "request_id": request_id,
152            "request": request_body,
153        });
154
155        let (tx, rx) = oneshot::channel();
156        {
157            let mut pending = self.pending_responses.lock().await;
158            pending.insert(request_id.clone(), tx);
159        }
160
161        writer.write(request).await?;
162
163        tokio::time::timeout(self.control_timeout, rx)
164            .await
165            .map_err(|_| Error::ControlTimeout(self.control_timeout))?
166            .map_err(|_| Error::ControlProtocol("response channel dropped".into()))
167    }
168
169    async fn initialize(&self) -> Result<()> {
170        let capabilities = SDKCapabilities {
171            hooks: !self.hooks.is_empty(),
172            permissions: self.can_use_tool.is_some(),
173            mcp: self.mcp_handler.is_some(),
174            agent_definitions: vec![],
175            mcp_servers: vec![],
176        };
177
178        let response = self.send_raw_control_request(serde_json::json!({
179            "subtype": "initialize",
180            "protocol_version": "1",
181            "capabilities": capabilities,
182        })).await?;
183
184        {
185            let mut info = self.server_info.lock().await;
186            *info = Some(response);
187        }
188
189        Ok(())
190    }
191
192    fn spawn_router(
193        &self,
194        mut raw_rx: mpsc::Receiver<Result<Value>>,
195        consumer_tx: mpsc::Sender<Result<Message>>,
196        writer: TransportWriter,
197    ) {
198        let pending = self.pending_responses.clone();
199        let hooks = self.hooks.clone();
200        let can_use_tool = self.can_use_tool.clone();
201        let mcp_handler = self.mcp_handler.clone();
202        let cancel = self.cancel.clone();
203
204        tokio::spawn(async move {
205            loop {
206                tokio::select! {
207                    _ = cancel.cancelled() => break,
208                    msg = raw_rx.recv() => {
209                        match msg {
210                            Some(Ok(value)) => {
211                                let msg_type = value.get("type")
212                                    .and_then(|v| v.as_str())
213                                    .unwrap_or("");
214
215                                match msg_type {
216                                    "control_response" => {
217                                        route_control_response(&pending, &value).await;
218                                    }
219                                    "control_request" => {
220                                        dispatch_control_request(
221                                            &value,
222                                            &hooks,
223                                            &can_use_tool,
224                                            &mcp_handler,
225                                            &writer,
226                                        ).await;
227                                    }
228                                    _ => {
229                                        let parsed = parse_message(value);
230                                        if consumer_tx.send(parsed).await.is_err() {
231                                            break;
232                                        }
233                                    }
234                                }
235                            }
236                            Some(Err(e)) => {
237                                let _ = consumer_tx.send(Err(e)).await;
238                                break;
239                            }
240                            None => break,
241                        }
242                    }
243                }
244            }
245            // Signal that the router is done so Query::closed() returns
246            // and the owning task can drop the Query cleanly.
247            cancel.cancel();
248        });
249    }
250}
251
252async fn route_control_response(
253    pending: &Arc<Mutex<HashMap<String, oneshot::Sender<Value>>>>,
254    value: &Value,
255) {
256    let response = value.get("response").cloned().unwrap_or(value.clone());
257    let request_id = response
258        .get("request_id")
259        .and_then(|v| v.as_str())
260        .unwrap_or("");
261
262    let mut pending = pending.lock().await;
263    if let Some(tx) = pending.remove(request_id) {
264        let _ = tx.send(response);
265    } else {
266        tracing::warn!(request_id, "control response for unknown request");
267    }
268}
269
270async fn dispatch_control_request(
271    value: &Value,
272    hooks: &[HookDefinition],
273    can_use_tool: &Option<CanUseToolCallback>,
274    mcp_handler: &Option<McpMessageHandler>,
275    writer: &TransportWriter,
276) {
277    let request_id = value
278        .get("request_id")
279        .and_then(|v| v.as_str())
280        .unwrap_or("")
281        .to_string();
282
283    let request = match value.get("request") {
284        Some(r) => r,
285        None => {
286            tracing::warn!("control request missing 'request' field");
287            return;
288        }
289    };
290
291    let subtype = request
292        .get("subtype")
293        .and_then(|v| v.as_str())
294        .unwrap_or("");
295
296    let response_body = match subtype {
297        "can_use_tool" => handle_can_use_tool(request, can_use_tool).await,
298        "hook_callback" => handle_hook_callback(request, hooks).await,
299        "mcp_message" => handle_mcp_message(request, mcp_handler).await,
300        other => {
301            tracing::warn!(subtype = other, "unknown control request subtype");
302            serde_json::json!({"error": format!("unknown subtype: {other}")})
303        }
304    };
305
306    let control_response = serde_json::json!({
307        "type": "control_response",
308        "response": {
309            "subtype": "success",
310            "request_id": request_id,
311            "response": response_body,
312        }
313    });
314
315    if let Err(e) = writer.write(control_response).await {
316        tracing::error!("failed to send control response: {e}");
317    }
318}
319
320async fn handle_can_use_tool(request: &Value, callback: &Option<CanUseToolCallback>) -> Value {
321    let tool_name = request
322        .get("tool_name")
323        .and_then(|v| v.as_str())
324        .unwrap_or("")
325        .to_string();
326    let input = request.get("input").cloned().unwrap_or(Value::Null);
327
328    if let Some(cb) = callback {
329        let result = cb(CanUseToolInput { tool_name, input }).await;
330        if result.allowed {
331            serde_json::json!({"behavior": "allow"})
332        } else {
333            serde_json::json!({
334                "behavior": "deny",
335                "message": result.reason.unwrap_or_default()
336            })
337        }
338    } else {
339        serde_json::json!({"behavior": "allow"})
340    }
341}
342
343async fn handle_hook_callback(request: &Value, hooks: &[HookDefinition]) -> Value {
344    let callback_id = request
345        .get("callback_id")
346        .and_then(|v| v.as_str())
347        .unwrap_or("");
348    let hook_input = request.get("input").cloned().unwrap_or(Value::Null);
349
350    let hook_index: Option<usize> = callback_id
351        .strip_prefix("hook_")
352        .and_then(|s| s.parse().ok());
353
354    let hook = hook_index.and_then(|i| hooks.get(i));
355
356    if let Some(hook) = hook {
357        let event_name = hook.event.as_str();
358        let typed_input = match hook.event {
359            HookEvent::PreToolUse => HookInput::PreToolUse(
360                serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
361                    tracing::warn!(event = event_name, "hook input parse failed: {e}");
362                    Default::default()
363                }),
364            ),
365            HookEvent::PostToolUse => HookInput::PostToolUse(
366                serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
367                    tracing::warn!(event = event_name, "hook input parse failed: {e}");
368                    Default::default()
369                }),
370            ),
371            HookEvent::Notification => HookInput::Notification(
372                serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
373                    tracing::warn!(event = event_name, "hook input parse failed: {e}");
374                    Default::default()
375                }),
376            ),
377            HookEvent::Stop | HookEvent::SubagentStop => HookInput::Stop(
378                serde_json::from_value(hook_input.clone()).unwrap_or_else(|e| {
379                    tracing::warn!(event = event_name, "hook input parse failed: {e}");
380                    Default::default()
381                }),
382            ),
383        };
384
385        let output = (hook.callback)(typed_input).await;
386        let mut result = serde_json::json!({"continue": true});
387        if let Some(decision) = &output.decision {
388            let hook_specific = serde_json::json!({
389                "hookEventName": hook.event.as_str(),
390                "permissionDecision": decision.as_str(),
391                "permissionDecisionReason": output.reason.as_deref().unwrap_or(""),
392            });
393            result["hookSpecificOutput"] = hook_specific;
394
395            if *decision == HookDecision::Block {
396                result["continue"] = Value::Bool(false);
397            }
398        }
399        result
400    } else {
401        tracing::warn!(callback_id, "hook callback not found");
402        serde_json::json!({"continue": true})
403    }
404}
405
406async fn handle_mcp_message(request: &Value, handler: &Option<McpMessageHandler>) -> Value {
407    let server_name = request
408        .get("server_name")
409        .and_then(|v| v.as_str())
410        .unwrap_or("")
411        .to_string();
412    let message = request.get("message").cloned().unwrap_or(Value::Null);
413
414    if let Some(handler) = handler {
415        handler(server_name, message).await
416    } else {
417        serde_json::json!({"error": "no MCP handler registered"})
418    }
419}
420
421fn generate_request_id() -> String {
422    use rand::Rng;
423    let mut rng = rand::rng();
424    let suffix: u64 = rng.random();
425    format!("req_{suffix:016x}")
426}
427
428impl Drop for Query {
429    fn drop(&mut self) {
430        self.cancel.cancel();
431    }
432}