Skip to main content

claude_code/
query.rs

1//! Core query/session management for Claude Code communication.
2//!
3//! This module provides the [`Query`] struct, which handles the low-level
4//! communication protocol with the Claude Code CLI process, including:
5//!
6//! - Session initialization and handshake
7//! - Control request/response protocol (permissions, hooks, MCP)
8//! - Background message reading and routing via tokio tasks
9//! - Lifecycle management (interrupt, model change, rewind)
10//!
11//! Most users should use [`ClaudeSdkClient`](crate::ClaudeSdkClient) or
12//! [`query()`](crate::query_fn::query) instead of interacting with this module directly.
13
14use std::collections::HashMap;
15use std::future::Future;
16use std::panic::{self, AssertUnwindSafe};
17use std::sync::Arc;
18use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
19use std::time::Duration;
20
21use futures::{FutureExt, Stream, StreamExt};
22use serde_json::{Map, Value, json};
23use tokio::sync::{Mutex, mpsc, oneshot};
24use tokio::task::JoinHandle;
25use tracing::{debug, warn};
26
27use crate::errors::{Error, Result};
28use crate::message_parser::parse_message;
29use crate::sdk_mcp::McpSdkServer;
30use crate::transport::{TransportCloseHandle, TransportReader, TransportWriter};
31use crate::types::{
32    AgentDefinition, CanUseToolCallback, HookCallback, HookMatcher, McpStatusResponse, Message,
33    PermissionResult, ToolPermissionContext,
34};
35
/// Capacity of the bounded mpsc channel carrying parsed SDK messages from the
/// background reader task to the consumer (same value as the Python SDK's buffer=100).
const MESSAGE_CHANNEL_BUFFER: usize = 100;
38
39/// Converts hook callback output keys from Rust-safe names to CLI protocol names.
40///
41/// Specifically maps `async_` → `async` and `continue_` → `continue`, since
42/// those are reserved words in Rust.
43fn convert_hook_output_for_cli(output: Value) -> Value {
44    let Some(obj) = output.as_object() else {
45        return output;
46    };
47
48    let mut converted = Map::new();
49    for (key, value) in obj {
50        match key.as_str() {
51            "async_" => {
52                converted.insert("async".to_string(), value.clone());
53            }
54            "continue_" => {
55                converted.insert("continue".to_string(), value.clone());
56            }
57            _ => {
58                converted.insert(key.clone(), value.clone());
59            }
60        }
61    }
62    Value::Object(converted)
63}
64
/// Renders a panic payload (as produced by `catch_unwind`) as readable text.
///
/// A `&str` payload is checked first, then a `String` payload. Payloads from
/// `panic!` are usually one of these two. Any other payload type yields a
/// fixed placeholder string.
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic payload".to_string())
}
74
75fn callback_panic_error(callback_type: &str, payload: Box<dyn std::any::Any + Send>) -> Error {
76    let panic_message = panic_payload_to_string(payload);
77    warn!(
78        callback_type,
79        panic_message, "Caught panic in callback invocation"
80    );
81    Error::Other(format!(
82        "{callback_type} callback panicked: {panic_message}"
83    ))
84}
85
86async fn await_callback_with_panic_isolation<T, F>(
87    callback_type: &str,
88    callback_future: F,
89) -> Result<T>
90where
91    F: Future<Output = Result<T>>,
92{
93    match AssertUnwindSafe(callback_future).catch_unwind().await {
94        Ok(result) => result,
95        Err(payload) => Err(callback_panic_error(callback_type, payload)),
96    }
97}
98
99/// Tracks pending control request senders and early-arrival response buffers.
100///
101/// Both maps are behind a single mutex to ensure atomicity: a response that
102/// arrives before the sender is registered gets buffered, and a later
103/// `send_control_request` drains the buffer under the same lock.
104struct PendingControlsState {
105    senders: HashMap<String, oneshot::Sender<Result<Value>>>,
106    buffered: HashMap<String, Result<Value>>,
107}
108
109/// Shared state accessible by both the background reader task and the main task.
110struct QuerySharedState {
111    can_use_tool: Option<CanUseToolCallback>,
112    hook_callbacks: Mutex<HashMap<String, HookCallback>>,
113    sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
114    /// Pending control request/response matching.
115    pending_controls: Mutex<PendingControlsState>,
116    /// Shared writer for responding to control requests and sending messages.
117    writer: Arc<Mutex<Box<dyn TransportWriter>>>,
118    /// Whether stdin close is deferred until first result.
119    pending_stdin_close: AtomicBool,
120    /// Timeout for deferred stdin close.
121    stream_close_timeout: Duration,
122    /// Whether the background reader task has terminated.
123    reader_terminated: AtomicBool,
124    /// Reason for reader task termination, if known.
125    reader_termination_reason: Mutex<Option<String>>,
126}
127
128/// Low-level query session handler for Claude Code CLI communication.
129///
130/// Manages the bidirectional JSON stream protocol between the SDK and the CLI.
131/// On startup, a background tokio task
132/// is spawned to continuously read messages from the transport and route them:
133///
134/// - **Control responses** are delivered to the waiting control-request caller via oneshot channels.
135/// - **Control requests** (permissions, hooks, MCP) are handled by the background task.
136/// - **SDK messages** (user, assistant, system, result) are parsed and delivered via an mpsc channel.
137///
138/// This architecture mirrors the Python SDK's task-group model and enables
139/// concurrent send and receive operations.
140pub struct Query {
141    /// Shared state for background task and main task.
142    state: Option<Arc<QuerySharedState>>,
143
144    /// Receiver end of the SDK message channel.
145    message_rx: Option<mpsc::Receiver<Result<Message>>>,
146
147    /// Handle for the background reader task.
148    reader_task: Option<JoinHandle<()>>,
149
150    /// Handle for closing the split transport.
151    close_handle: Option<Box<dyn TransportCloseHandle>>,
152
153    /// Monotonically increasing request ID counter.
154    request_counter: Arc<AtomicUsize>,
155
156    /// Whether the query is in streaming mode.
157    is_streaming_mode: bool,
158
159    /// Agent definitions to register during initialization.
160    agents: Option<HashMap<String, AgentDefinition>>,
161
162    /// Whether initialization has completed.
163    initialized: bool,
164
165    /// The initialization response from the CLI.
166    initialization_result: Option<Value>,
167
168    /// Timeout for the initialization handshake.
169    initialize_timeout: Duration,
170
171    /// Whether hooks or SDK MCP servers are present (for deferred stdin close).
172    has_hooks_or_mcp: bool,
173}
174
175impl Query {
176    /// Creates a new `Query` and starts the background reader task.
177    ///
178    /// This is the primary constructor. It splits the given reader and writer,
179    /// registers callbacks, and spawns the background task.
180    ///
181    /// # Arguments
182    ///
183    /// * `reader` — The transport reader half.
184    /// * `writer` — The transport writer half (wrapped in `Arc<Mutex<>>` for sharing).
185    /// * `close_handle` — Handle for closing the transport.
186    /// * `is_streaming_mode` — Whether to use the streaming protocol.
187    /// * `can_use_tool` — Optional permission callback for tool approval.
188    /// * `hook_callbacks` — Hook callbacks keyed by callback ID.
189    /// * `sdk_mcp_servers` — In-process MCP servers.
190    /// * `agents` — Optional subagent definitions.
191    /// * `initialize_timeout` — Timeout for the initialization handshake.
192    #[allow(clippy::too_many_arguments)]
193    pub(crate) fn start(
194        reader: Box<dyn TransportReader>,
195        writer: Box<dyn TransportWriter>,
196        close_handle: Box<dyn TransportCloseHandle>,
197        is_streaming_mode: bool,
198        can_use_tool: Option<CanUseToolCallback>,
199        hook_callbacks: HashMap<String, HookCallback>,
200        sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
201        agents: Option<HashMap<String, AgentDefinition>>,
202        initialize_timeout: Duration,
203    ) -> Self {
204        let stream_close_timeout_ms: u64 = std::env::var("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT")
205            .ok()
206            .and_then(|v| v.parse().ok())
207            .unwrap_or(60_000);
208        let stream_close_timeout =
209            Duration::from_millis(stream_close_timeout_ms).max(Duration::from_secs(60));
210
211        let has_hooks_or_mcp = !hook_callbacks.is_empty() || !sdk_mcp_servers.is_empty();
212        let writer = Arc::new(Mutex::new(writer));
213
214        let state = Arc::new(QuerySharedState {
215            can_use_tool,
216            hook_callbacks: Mutex::new(hook_callbacks),
217            sdk_mcp_servers,
218            pending_controls: Mutex::new(PendingControlsState {
219                senders: HashMap::new(),
220                buffered: HashMap::new(),
221            }),
222            writer: writer.clone(),
223            pending_stdin_close: AtomicBool::new(false),
224            stream_close_timeout,
225            reader_terminated: AtomicBool::new(false),
226            reader_termination_reason: Mutex::new(None),
227        });
228
229        let (message_tx, message_rx) = mpsc::channel(MESSAGE_CHANNEL_BUFFER);
230
231        let reader_state = state.clone();
232        let reader_task = tokio::spawn(async move {
233            background_reader_task(reader, reader_state, message_tx).await;
234        });
235
236        Self {
237            state: Some(state),
238            message_rx: Some(message_rx),
239            reader_task: Some(reader_task),
240            close_handle: Some(close_handle),
241            request_counter: Arc::new(AtomicUsize::new(0)),
242            is_streaming_mode,
243            agents,
244            initialized: false,
245            initialization_result: None,
246            initialize_timeout,
247            has_hooks_or_mcp,
248        }
249    }
250
251    /// Sends the initialization handshake to the CLI.
252    ///
253    /// Registers hook callbacks and agent definitions with the CLI process,
254    /// and waits for the initialization response.
255    ///
256    /// Returns the initialization response payload, or `None` if not in streaming mode.
257    ///
258    /// # Example
259    ///
260    /// ```rust,ignore
261    /// use claude_code::Query;
262    /// use serde_json::Map;
263    /// use serde_json::Value;
264    ///
265    /// # async fn demo(query: &mut Query) -> claude_code::Result<()> {
266    /// let _ = query.initialize(Map::<String, Value>::new()).await?;
267    /// # Ok(())
268    /// # }
269    /// ```
270    pub async fn initialize(&mut self, hooks_config: Map<String, Value>) -> Result<Option<Value>> {
271        if !self.is_streaming_mode {
272            return Ok(None);
273        }
274
275        let mut request = Map::new();
276        request.insert(
277            "subtype".to_string(),
278            Value::String("initialize".to_string()),
279        );
280        request.insert(
281            "hooks".to_string(),
282            if hooks_config.is_empty() {
283                Value::Null
284            } else {
285                Value::Object(hooks_config)
286            },
287        );
288
289        if let Some(agents) = &self.agents {
290            request.insert(
291                "agents".to_string(),
292                serde_json::to_value(agents).unwrap_or(Value::Null),
293            );
294        }
295
296        let response = self
297            .send_control_request(Value::Object(request), self.initialize_timeout)
298            .await?;
299        self.initialized = true;
300        self.initialization_result = Some(response.clone());
301        Ok(Some(response))
302    }
303
304    /// Returns the initialization result from the CLI handshake.
305    ///
306    /// Returns `None` if [`initialize()`](Self::initialize) has not been called yet.
307    ///
308    /// # Example
309    ///
310    /// ```rust,ignore
311    /// use claude_code::Query;
312    ///
313    /// fn demo(query: &Query) {
314    ///     let _info = query.initialization_result();
315    /// }
316    /// ```
317    pub fn initialization_result(&self) -> Option<Value> {
318        self.initialization_result.clone()
319    }
320
321    /// Sends a control request to the CLI and waits for the matching response.
322    ///
323    /// The request is written via the shared writer. The background reader task
324    /// delivers the matching control response via a oneshot channel.
325    async fn send_control_request(&self, request: Value, timeout: Duration) -> Result<Value> {
326        if !self.is_streaming_mode {
327            return Err(Error::Other(
328                "Control requests require streaming mode".to_string(),
329            ));
330        }
331
332        let state = self
333            .state
334            .as_ref()
335            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
336
337        let request_id = format!(
338            "req_{}",
339            self.request_counter.fetch_add(1, Ordering::SeqCst) + 1
340        );
341
342        // Write the control request first (so it's always observable).
343        let control_request = json!({
344            "type": "control_request",
345            "request_id": request_id,
346            "request": request,
347        });
348        state
349            .writer
350            .lock()
351            .await
352            .write(&(control_request.to_string() + "\n"))
353            .await?;
354
355        // Register a oneshot channel for the response, checking the buffer first.
356        // Both operations are under a single lock to avoid a race where a response
357        // arrives between checking the buffer and registering the sender.
358        let (tx, rx) = oneshot::channel();
359        {
360            let mut controls = state.pending_controls.lock().await;
361            if let Some(result) = controls.buffered.remove(&request_id) {
362                return result;
363            }
364            controls.senders.insert(request_id.clone(), tx);
365        }
366        if state.reader_terminated.load(Ordering::SeqCst) {
367            state
368                .pending_controls
369                .lock()
370                .await
371                .senders
372                .remove(&request_id);
373            let reason = reader_termination_reason(state).await;
374            return Err(Error::Other(format!(
375                "Background reader task terminated: {reason}"
376            )));
377        }
378
379        // Wait for the response with timeout.
380        let result = tokio::time::timeout(timeout, rx).await;
381        match result {
382            Ok(Ok(value)) => value,
383            Ok(Err(_)) => {
384                // Channel closed — background task died
385                Err(Error::Other(
386                    "Background reader task terminated while waiting for control response"
387                        .to_string(),
388                ))
389            }
390            Err(_) => {
391                // Timeout
392                let subtype = request
393                    .get("subtype")
394                    .and_then(Value::as_str)
395                    .unwrap_or("unknown");
396                state
397                    .pending_controls
398                    .lock()
399                    .await
400                    .senders
401                    .remove(&request_id);
402                Err(Error::Other(format!("Control request timeout: {subtype}")))
403            }
404        }
405    }
406
407    /// Sends a user text message to the CLI.
408    ///
409    /// # Example
410    ///
411    /// ```rust,ignore
412    /// use claude_code::Query;
413    ///
414    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
415    /// query.send_user_message("hello", "default").await?;
416    /// # Ok(())
417    /// # }
418    /// ```
419    pub async fn send_user_message(&self, prompt: &str, session_id: &str) -> Result<()> {
420        let message = json!({
421            "type": "user",
422            "message": {"role": "user", "content": prompt},
423            "parent_tool_use_id": Value::Null,
424            "session_id": session_id
425        });
426        self.write_message(&message).await
427    }
428
429    /// Sends a raw JSON message to the CLI without any transformation.
430    ///
431    /// # Example
432    ///
433    /// ```rust,ignore
434    /// use claude_code::Query;
435    /// use serde_json::json;
436    ///
437    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
438    /// query.send_raw_message(json!({"type":"user","message":{"role":"user","content":"hi"}})).await?;
439    /// # Ok(())
440    /// # }
441    /// ```
442    pub async fn send_raw_message(&self, message: Value) -> Result<()> {
443        self.write_message(&message).await
444    }
445
446    /// Writes a JSON message to the shared writer.
447    async fn write_message(&self, message: &Value) -> Result<()> {
448        let state = self
449            .state
450            .as_ref()
451            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
452        state
453            .writer
454            .lock()
455            .await
456            .write(&(message.to_string() + "\n"))
457            .await
458    }
459
460    /// Sends multiple input messages to the CLI without closing stdin.
461    ///
462    /// # Example
463    ///
464    /// ```rust,ignore
465    /// use claude_code::Query;
466    /// use serde_json::json;
467    ///
468    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
469    /// query
470    ///     .send_input_messages(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})])
471    ///     .await?;
472    /// # Ok(())
473    /// # }
474    /// ```
475    pub async fn send_input_messages(&self, messages: Vec<Value>) -> Result<()> {
476        for message in messages {
477            self.send_raw_message(message).await?;
478        }
479        Ok(())
480    }
481
482    /// Sends streamed input messages to the CLI without closing stdin.
483    ///
484    /// # Example
485    ///
486    /// ```rust,ignore
487    /// use claude_code::Query;
488    /// use futures::stream;
489    /// use serde_json::json;
490    ///
491    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
492    /// query
493    ///     .send_input_from_stream(stream::iter(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})]))
494    ///     .await?;
495    /// # Ok(())
496    /// # }
497    /// ```
498    pub async fn send_input_from_stream<S>(&self, mut messages: S) -> Result<()>
499    where
500        S: Stream<Item = Value> + Unpin,
501    {
502        while let Some(message) = messages.next().await {
503            self.send_raw_message(message).await?;
504        }
505        Ok(())
506    }
507
508    /// Spawns a background task that streams input messages to the CLI.
509    ///
510    /// This is useful for long-lived or unbounded input streams where the caller
511    /// should continue processing messages concurrently.
512    ///
513    /// The returned task completes when the input stream ends or a write error
514    /// occurs. It does not close stdin.
515    ///
516    /// # Example
517    ///
518    /// ```rust,ignore
519    /// use claude_code::Query;
520    /// use futures::stream;
521    /// use serde_json::json;
522    ///
523    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
524    /// let handle = query.spawn_input_from_stream(stream::iter(vec![
525    ///     json!({"type":"user","message":{"role":"user","content":"hello"}}),
526    /// ]))?;
527    /// handle.await??;
528    /// # Ok(())
529    /// # }
530    /// ```
531    pub fn spawn_input_from_stream<S>(&self, mut messages: S) -> Result<JoinHandle<Result<()>>>
532    where
533        S: Stream<Item = Value> + Send + Unpin + 'static,
534    {
535        let state = self
536            .state
537            .as_ref()
538            .cloned()
539            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
540
541        Ok(tokio::spawn(async move {
542            while let Some(message) = messages.next().await {
543                state
544                    .writer
545                    .lock()
546                    .await
547                    .write(&(message.to_string() + "\n"))
548                    .await?;
549            }
550            Ok(())
551        }))
552    }
553
554    /// Streams multiple messages to the CLI and closes the input stream.
555    ///
556    /// If SDK MCP servers or hooks are present, stdin close is deferred until
557    /// the first result message is received (or a timeout expires).
558    ///
559    /// # Example
560    ///
561    /// ```rust,ignore
562    /// use claude_code::Query;
563    /// use serde_json::json;
564    ///
565    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
566    /// query
567    ///     .stream_input(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})])
568    ///     .await?;
569    /// # Ok(())
570    /// # }
571    /// ```
572    pub async fn stream_input(&self, messages: Vec<Value>) -> Result<()> {
573        self.send_input_messages(messages).await?;
574        self.finalize_stream_input().await
575    }
576
577    /// Streams messages from an async stream source and closes the input stream.
578    ///
579    /// # Example
580    ///
581    /// ```rust,ignore
582    /// use claude_code::Query;
583    /// use futures::stream;
584    /// use serde_json::json;
585    ///
586    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
587    /// query
588    ///     .stream_input_from_stream(stream::iter(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})]))
589    ///     .await?;
590    /// # Ok(())
591    /// # }
592    /// ```
593    pub async fn stream_input_from_stream<S>(&self, mut messages: S) -> Result<()>
594    where
595        S: Stream<Item = Value> + Unpin,
596    {
597        self.send_input_from_stream(&mut messages).await?;
598        self.finalize_stream_input().await
599    }
600
601    async fn finalize_stream_input(&self) -> Result<()> {
602        let state = self
603            .state
604            .as_ref()
605            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
606
607        if self.has_hooks_or_mcp {
608            debug!(
609                has_hooks_or_mcp = self.has_hooks_or_mcp,
610                "Deferring stdin close until first result"
611            );
612            state.pending_stdin_close.store(true, Ordering::SeqCst);
613        } else {
614            state.writer.lock().await.end_input().await?;
615        }
616        Ok(())
617    }
618
619    /// Closes the input stream without sending any messages.
620    ///
621    /// # Example
622    ///
623    /// ```rust,ignore
624    /// use claude_code::Query;
625    ///
626    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
627    /// query.end_input().await?;
628    /// # Ok(())
629    /// # }
630    /// ```
631    pub async fn end_input(&self) -> Result<()> {
632        let state = self
633            .state
634            .as_ref()
635            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
636        state.writer.lock().await.end_input().await
637    }
638
639    /// Receives the next content message from the CLI.
640    ///
641    /// Messages are delivered by the background reader task via an mpsc channel.
642    /// Control messages are handled transparently by the background task.
643    ///
644    /// Returns `None` when the stream is exhausted (no more messages).
645    ///
646    /// # Example
647    ///
648    /// ```rust,ignore
649    /// use claude_code::Query;
650    ///
651    /// # async fn demo(query: &mut Query) -> claude_code::Result<()> {
652    /// while let Some(message) = query.receive_next_message().await? {
653    ///     println!("{message:?}");
654    /// }
655    /// # Ok(())
656    /// # }
657    /// ```
658    pub async fn receive_next_message(&mut self) -> Result<Option<Message>> {
659        let rx = self
660            .message_rx
661            .as_mut()
662            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
663
664        match rx.recv().await {
665            Some(Ok(message)) => Ok(Some(message)),
666            Some(Err(err)) => Err(err),
667            None => Ok(None),
668        }
669    }
670
671    /// Queries the status of connected MCP servers via the CLI.
672    ///
673    /// # Example
674    ///
675    /// ```rust,ignore
676    /// use claude_code::Query;
677    ///
678    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
679    /// let _status = query.get_mcp_status().await?;
680    /// # Ok(())
681    /// # }
682    /// ```
683    pub async fn get_mcp_status(&self) -> Result<McpStatusResponse> {
684        let raw = self
685            .send_control_request(json!({ "subtype": "mcp_status" }), Duration::from_secs(60))
686            .await?;
687        serde_json::from_value(raw).map_err(|err| {
688            Error::Other(format!("Failed to decode typed MCP status response: {err}"))
689        })
690    }
691
692    /// Sends an interrupt signal to the CLI to stop the current operation.
693    ///
694    /// # Example
695    ///
696    /// ```rust,ignore
697    /// use claude_code::Query;
698    ///
699    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
700    /// query.interrupt().await?;
701    /// # Ok(())
702    /// # }
703    /// ```
704    pub async fn interrupt(&self) -> Result<()> {
705        self.send_control_request(json!({ "subtype": "interrupt" }), Duration::from_secs(60))
706            .await?;
707        Ok(())
708    }
709
710    /// Changes the permission mode via a control request.
711    ///
712    /// # Example
713    ///
714    /// ```rust,ignore
715    /// use claude_code::Query;
716    ///
717    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
718    /// query.set_permission_mode("plan").await?;
719    /// # Ok(())
720    /// # }
721    /// ```
722    pub async fn set_permission_mode(&self, mode: &str) -> Result<()> {
723        self.send_control_request(
724            json!({ "subtype": "set_permission_mode", "mode": mode }),
725            Duration::from_secs(60),
726        )
727        .await?;
728        Ok(())
729    }
730
731    /// Changes the model used by the CLI via a control request.
732    ///
733    /// # Example
734    ///
735    /// ```rust,ignore
736    /// use claude_code::Query;
737    ///
738    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
739    /// query.set_model(Some("sonnet")).await?;
740    /// # Ok(())
741    /// # }
742    /// ```
743    pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
744        self.send_control_request(
745            json!({ "subtype": "set_model", "model": model }),
746            Duration::from_secs(60),
747        )
748        .await?;
749        Ok(())
750    }
751
752    /// Rewinds file changes to a specific user message checkpoint.
753    ///
754    /// # Example
755    ///
756    /// ```rust,ignore
757    /// use claude_code::Query;
758    ///
759    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
760    /// query.rewind_files("user-msg-1").await?;
761    /// # Ok(())
762    /// # }
763    /// ```
764    pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
765        self.send_control_request(
766            json!({ "subtype": "rewind_files", "user_message_id": user_message_id }),
767            Duration::from_secs(60),
768        )
769        .await?;
770        Ok(())
771    }
772
773    /// Reconnects a disconnected or failed MCP server.
774    pub async fn reconnect_mcp_server(&self, server_name: &str) -> Result<()> {
775        self.send_control_request(
776            json!({ "subtype": "mcp_reconnect", "serverName": server_name }),
777            Duration::from_secs(60),
778        )
779        .await?;
780        Ok(())
781    }
782
783    /// Enables or disables an MCP server.
784    pub async fn toggle_mcp_server(&self, server_name: &str, enabled: bool) -> Result<()> {
785        self.send_control_request(
786            json!({ "subtype": "mcp_toggle", "serverName": server_name, "enabled": enabled }),
787            Duration::from_secs(60),
788        )
789        .await?;
790        Ok(())
791    }
792
793    /// Stops a running task.
794    pub async fn stop_task(&self, task_id: &str) -> Result<()> {
795        self.send_control_request(
796            json!({ "subtype": "stop_task", "task_id": task_id }),
797            Duration::from_secs(60),
798        )
799        .await?;
800        Ok(())
801    }
802
803    /// Closes the query session.
804    ///
805    /// # Example
806    ///
807    /// ```rust,ignore
808    /// use claude_code::Query;
809    ///
810    /// # async fn demo(query: Query) -> claude_code::Result<()> {
811    /// query.close().await?;
812    /// # Ok(())
813    /// # }
814    /// ```
815    pub async fn close(mut self) -> Result<()> {
816        self.shutdown().await
817    }
818
819    /// Internal shutdown logic.
820    async fn shutdown(&mut self) -> Result<()> {
821        self.message_rx.take();
822        self.state.take();
823
824        if let Some(task) = self.reader_task.take() {
825            task.abort();
826            let _ = task.await;
827        }
828
829        if let Some(close_handle) = self.close_handle.take() {
830            close_handle.close().await?;
831        }
832
833        Ok(())
834    }
835
836    /// Takes the message receiver for stream construction.
837    pub(crate) fn take_message_receiver(&mut self) -> Option<mpsc::Receiver<Result<Message>>> {
838        self.message_rx.take()
839    }
840}
841
842impl Drop for Query {
843    fn drop(&mut self) {
844        if let Some(task) = self.reader_task.take() {
845            task.abort();
846        }
847
848        if let Some(close_handle) = self.close_handle.take() {
849            // Spawn a detached task to perform async cleanup.
850            // If no runtime is available, fall back to a temporary current-thread
851            // runtime for best-effort synchronous cleanup.
852            if let Ok(handle) = tokio::runtime::Handle::try_current() {
853                handle.spawn(async move {
854                    let _ = close_handle.close().await;
855                });
856            } else if let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
857                .enable_all()
858                .build()
859            {
860                let _ = runtime.block_on(async move { close_handle.close().await });
861            }
862        }
863    }
864}
865
866// ---------------------------------------------------------------------------
867// Background Reader Task
868// ---------------------------------------------------------------------------
869
870/// Background task that continuously reads from the transport reader and routes
871/// messages to their appropriate destinations.
872async fn background_reader_task(
873    mut reader: Box<dyn TransportReader>,
874    state: Arc<QuerySharedState>,
875    message_tx: mpsc::Sender<Result<Message>>,
876) {
877    loop {
878        // Handle deferred stdin close timeout.
879        let read_result = if state.pending_stdin_close.load(Ordering::SeqCst) {
880            let timeout_dur = state.stream_close_timeout;
881            match tokio::time::timeout(timeout_dur, reader.read_next_message()).await {
882                Ok(result) => result,
883                Err(_) => {
884                    debug!("Timed out waiting for first result, closing input stream");
885                    try_close_deferred_stdin(&state).await;
886                    continue;
887                }
888            }
889        } else {
890            reader.read_next_message().await
891        };
892
893        let raw = match read_result {
894            Ok(Some(raw)) => raw,
895            Ok(None) => {
896                try_close_deferred_stdin(&state).await;
897                break;
898            }
899            Err(err) => {
900                mark_reader_terminated(&state, err.to_string()).await;
901                let _ = message_tx.send(Err(err)).await;
902                break;
903            }
904        };
905
906        let msg_type = raw.get("type").and_then(Value::as_str).unwrap_or_default();
907
908        if msg_type == "control_response" {
909            handle_control_response(&state, &raw).await;
910            continue;
911        }
912
913        if msg_type == "control_request" {
914            if let Err(err) = handle_control_request(&state, raw).await {
915                debug!("Error handling control request: {err}");
916            }
917            continue;
918        }
919
920        if msg_type == "control_cancel_request" {
921            continue;
922        }
923
924        // Parse and forward SDK messages.
925        match parse_message(&raw) {
926            Ok(Some(msg)) => {
927                if matches!(msg, Message::Result(_))
928                    && state.pending_stdin_close.load(Ordering::SeqCst)
929                {
930                    debug!("Received first result, closing input stream");
931                    try_close_deferred_stdin(&state).await;
932                }
933
934                if message_tx.send(Ok(msg)).await.is_err() {
935                    break;
936                }
937            }
938            Ok(None) => {}
939            Err(err) => {
940                if message_tx
941                    .send(Err(Error::MessageParse(err)))
942                    .await
943                    .is_err()
944                {
945                    break;
946                }
947            }
948        }
949    }
950}
951
952/// Marks reader termination and fails all pending control requests immediately.
953async fn mark_reader_terminated(state: &QuerySharedState, reason: String) {
954    state.reader_terminated.store(true, Ordering::SeqCst);
955    let stored_reason = {
956        let mut termination_reason = state.reader_termination_reason.lock().await;
957        if termination_reason.is_none() {
958            *termination_reason = Some(reason);
959        }
960        termination_reason
961            .clone()
962            .unwrap_or_else(|| "Unknown reason".to_string())
963    };
964
965    let mut controls = state.pending_controls.lock().await;
966    for (_, sender) in controls.senders.drain() {
967        let _ = sender.send(Err(Error::Other(format!(
968            "Background reader task terminated: {stored_reason}"
969        ))));
970    }
971}
972
973/// Returns the recorded reader termination reason or a generic fallback.
974async fn reader_termination_reason(state: &QuerySharedState) -> String {
975    state
976        .reader_termination_reason
977        .lock()
978        .await
979        .clone()
980        .unwrap_or_else(|| "Unknown reason".to_string())
981}
982
983/// Closes deferred stdin via the shared writer.
984async fn try_close_deferred_stdin(state: &QuerySharedState) {
985    if state
986        .pending_stdin_close
987        .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
988        .is_ok()
989    {
990        if let Err(e) = state.writer.lock().await.end_input().await {
991            debug!("Error closing deferred stdin: {e}");
992        }
993    }
994}
995
996/// Routes a control response to the waiting oneshot sender, or buffers it.
997///
998/// If no sender is registered for this response's `request_id`, the parsed
999/// result is stored in the buffer for later retrieval by `send_control_request`.
1000async fn handle_control_response(state: &QuerySharedState, raw: &Value) {
1001    let Some(response) = raw.get("response").and_then(Value::as_object) else {
1002        return;
1003    };
1004    let response_request_id = response
1005        .get("request_id")
1006        .and_then(Value::as_str)
1007        .unwrap_or_default();
1008
1009    let subtype = response
1010        .get("subtype")
1011        .and_then(Value::as_str)
1012        .unwrap_or_default();
1013
1014    let result: Result<Value> = if subtype == "error" {
1015        let error = response
1016            .get("error")
1017            .and_then(Value::as_str)
1018            .unwrap_or("Unknown error");
1019        Err(Error::Other(error.to_string()))
1020    } else {
1021        Ok(response
1022            .get("response")
1023            .cloned()
1024            .unwrap_or_else(|| json!({})))
1025    };
1026
1027    let mut controls = state.pending_controls.lock().await;
1028    if let Some(sender) = controls.senders.remove(response_request_id) {
1029        let _ = sender.send(result);
1030    } else {
1031        // Response arrived before sender was registered — buffer it.
1032        controls
1033            .buffered
1034            .insert(response_request_id.to_string(), result);
1035    }
1036}
1037
1038async fn handle_can_use_tool_request(
1039    state: &QuerySharedState,
1040    request_data: &Map<String, Value>,
1041) -> Result<Value> {
1042    let callback = state
1043        .can_use_tool
1044        .clone()
1045        .ok_or_else(|| Error::Other("canUseTool callback is not provided".to_string()))?;
1046    let tool_name = request_data
1047        .get("tool_name")
1048        .and_then(Value::as_str)
1049        .unwrap_or_default()
1050        .to_string();
1051    let input = request_data
1052        .get("input")
1053        .cloned()
1054        .unwrap_or_else(|| json!({}));
1055    let suggestions = request_data
1056        .get("permission_suggestions")
1057        .and_then(Value::as_array)
1058        .cloned()
1059        .unwrap_or_default()
1060        .into_iter()
1061        .filter_map(|value| serde_json::from_value(value).ok())
1062        .collect();
1063    let blocked_path = request_data
1064        .get("blocked_path")
1065        .and_then(Value::as_str)
1066        .map(ToString::to_string);
1067    let context = ToolPermissionContext {
1068        suggestions,
1069        blocked_path,
1070        signal: None,
1071    };
1072
1073    let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1074        callback(tool_name, input.clone(), context)
1075    }))
1076    .map_err(|payload| callback_panic_error("can_use_tool", payload))?;
1077    let callback_result =
1078        await_callback_with_panic_isolation("can_use_tool", callback_future).await?;
1079    let output = match callback_result {
1080        PermissionResult::Allow(allow) => {
1081            let mut obj = Map::new();
1082            obj.insert("behavior".to_string(), Value::String("allow".to_string()));
1083            obj.insert(
1084                "updatedInput".to_string(),
1085                allow.updated_input.unwrap_or(input),
1086            );
1087            if let Some(updated_permissions) = allow.updated_permissions {
1088                let permissions_json: Vec<Value> = updated_permissions
1089                    .into_iter()
1090                    .map(|permission| permission.to_cli_dict())
1091                    .collect();
1092                obj.insert(
1093                    "updatedPermissions".to_string(),
1094                    Value::Array(permissions_json),
1095                );
1096            }
1097            Value::Object(obj)
1098        }
1099        PermissionResult::Deny(deny) => {
1100            let mut obj = Map::new();
1101            obj.insert("behavior".to_string(), Value::String("deny".to_string()));
1102            obj.insert("message".to_string(), Value::String(deny.message));
1103            if deny.interrupt {
1104                obj.insert("interrupt".to_string(), Value::Bool(true));
1105            }
1106            Value::Object(obj)
1107        }
1108    };
1109    Ok(output)
1110}
1111
1112async fn handle_hook_callback_request(
1113    state: &QuerySharedState,
1114    request_data: &Map<String, Value>,
1115) -> Result<Value> {
1116    let callback_id = request_data
1117        .get("callback_id")
1118        .and_then(Value::as_str)
1119        .ok_or_else(|| Error::Other("Missing callback_id in hook_callback".to_string()))?;
1120    let callback = state
1121        .hook_callbacks
1122        .lock()
1123        .await
1124        .get(callback_id)
1125        .cloned()
1126        .ok_or_else(|| Error::Other(format!("No hook callback found for ID: {callback_id}")))?;
1127    let input = request_data.get("input").cloned().unwrap_or(Value::Null);
1128    let tool_use_id = request_data
1129        .get("tool_use_id")
1130        .and_then(Value::as_str)
1131        .map(ToString::to_string);
1132    let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1133        callback(input, tool_use_id, Default::default())
1134    }))
1135    .map_err(|payload| callback_panic_error("hook", payload))?;
1136    let output = await_callback_with_panic_isolation("hook", callback_future).await?;
1137    Ok(convert_hook_output_for_cli(output))
1138}
1139
1140async fn handle_mcp_message_request(
1141    state: &QuerySharedState,
1142    request_data: &Map<String, Value>,
1143) -> Result<Value> {
1144    let server_name = request_data
1145        .get("server_name")
1146        .and_then(Value::as_str)
1147        .ok_or_else(|| Error::Other("Missing server_name in mcp_message".to_string()))?;
1148    let message = request_data
1149        .get("message")
1150        .cloned()
1151        .ok_or_else(|| Error::Other("Missing message in mcp_message".to_string()))?;
1152    let response = handle_sdk_mcp_request(&state.sdk_mcp_servers, server_name, &message).await;
1153    Ok(json!({ "mcp_response": response }))
1154}
1155
1156/// Handles an incoming control request from the CLI within the background task.
1157async fn handle_control_request(state: &QuerySharedState, request: Value) -> Result<()> {
1158    let Some(request_obj) = request.as_object() else {
1159        return Err(Error::Other("Invalid control request format".to_string()));
1160    };
1161    let request_id = request_obj
1162        .get("request_id")
1163        .and_then(Value::as_str)
1164        .ok_or_else(|| Error::Other("Missing request_id in control request".to_string()))?
1165        .to_string();
1166    let request_data = request_obj
1167        .get("request")
1168        .and_then(Value::as_object)
1169        .ok_or_else(|| Error::Other("Missing request payload".to_string()))?;
1170    let subtype = request_data
1171        .get("subtype")
1172        .and_then(Value::as_str)
1173        .ok_or_else(|| Error::Other("Missing request subtype".to_string()))?;
1174
1175    let result: Result<Value> = match subtype {
1176        "can_use_tool" => handle_can_use_tool_request(state, request_data).await,
1177        "hook_callback" => handle_hook_callback_request(state, request_data).await,
1178        "mcp_message" => handle_mcp_message_request(state, request_data).await,
1179        _ => Err(Error::Other(format!(
1180            "Unsupported control request subtype: {subtype}"
1181        ))),
1182    };
1183
1184    let response_json = match result {
1185        Ok(payload) => json!({
1186            "type": "control_response",
1187            "response": {
1188                "subtype": "success",
1189                "request_id": request_id,
1190                "response": payload
1191            }
1192        }),
1193        Err(err) => json!({
1194            "type": "control_response",
1195            "response": {
1196                "subtype": "error",
1197                "request_id": request_id,
1198                "error": err.to_string()
1199            }
1200        }),
1201    };
1202
1203    state
1204        .writer
1205        .lock()
1206        .await
1207        .write(&(response_json.to_string() + "\n"))
1208        .await
1209}
1210
1211/// Routes an MCP message to the appropriate in-process SDK MCP server.
1212///
1213/// Implements JSON-RPC message routing for in-process SDK MCP servers.
1214/// Handles `initialize`, `tools/list`, `tools/call`, and `notifications/initialized` methods.
1215///
1216/// # Example
1217///
1218/// ```rust,no_run
1219/// use claude_code::{create_sdk_mcp_server, tool};
1220/// use claude_code::query::handle_sdk_mcp_request;
1221/// use serde_json::{json, Value};
1222/// use std::collections::HashMap;
1223/// use std::sync::Arc;
1224///
1225/// # async fn example() {
1226///     let config = create_sdk_mcp_server(
1227///     "tools",
1228///     "1.0.0",
1229///     vec![tool("echo", "Echo", json!({"type":"object"}), |_args: Value| async move {
1230///         Ok(json!({"content": []}))
1231///     })],
1232///     );
1233///
1234///     let mut servers = HashMap::new();
1235///     servers.insert(config.name.clone(), Arc::clone(&config.instance));
1236///
1237///     let response = handle_sdk_mcp_request(
1238///     &servers,
1239///     "tools",
1240///     &json!({"jsonrpc":"2.0","id":1,"method":"tools/list"}),
1241///     )
1242///     .await;
1243///
1244///     assert_eq!(response["jsonrpc"], "2.0");
1245/// # }
1246/// ```
1247pub async fn handle_sdk_mcp_request(
1248    sdk_mcp_servers: &HashMap<String, Arc<McpSdkServer>>,
1249    server_name: &str,
1250    message: &Value,
1251) -> Value {
1252    let Some(server) = sdk_mcp_servers.get(server_name) else {
1253        return json!({
1254            "jsonrpc": "2.0",
1255            "id": message.get("id").cloned().unwrap_or(Value::Null),
1256            "error": {
1257                "code": -32601,
1258                "message": format!("Server '{server_name}' not found")
1259            }
1260        });
1261    };
1262
1263    let method = message
1264        .get("method")
1265        .and_then(Value::as_str)
1266        .unwrap_or_default();
1267    let id = message.get("id").cloned().unwrap_or(Value::Null);
1268    let params = message.get("params").cloned().unwrap_or_else(|| json!({}));
1269
1270    match method {
1271        "initialize" => json!({
1272            "jsonrpc": "2.0",
1273            "id": id,
1274            "result": {
1275                "protocolVersion": "2024-11-05",
1276                "capabilities": {"tools": {}},
1277                "serverInfo": {
1278                    "name": server.name,
1279                    "version": server.version
1280                }
1281            }
1282        }),
1283        "tools/list" => json!({
1284            "jsonrpc": "2.0",
1285            "id": id,
1286            "result": {
1287                "tools": server.list_tools_json()
1288            }
1289        }),
1290        "tools/call" => {
1291            let tool_name = params
1292                .get("name")
1293                .and_then(Value::as_str)
1294                .unwrap_or_default();
1295            let arguments = params
1296                .get("arguments")
1297                .cloned()
1298                .unwrap_or_else(|| json!({}));
1299            let result = server.call_tool_json(tool_name, arguments).await;
1300            json!({
1301                "jsonrpc": "2.0",
1302                "id": id,
1303                "result": result
1304            })
1305        }
1306        "notifications/initialized" => json!({
1307            "jsonrpc": "2.0",
1308            "result": {}
1309        }),
1310        _ => json!({
1311            "jsonrpc": "2.0",
1312            "id": id,
1313            "error": {
1314                "code": -32601,
1315                "message": format!("Method '{method}' not found")
1316            }
1317        }),
1318    }
1319}
1320
1321// ---------------------------------------------------------------------------
1322// Helper: Build hooks config and callbacks
1323// ---------------------------------------------------------------------------
1324
1325/// Builds the hooks configuration for the initialization handshake and extracts
1326/// hook callbacks for the background task.
1327pub(crate) fn build_hooks_config(
1328    hooks: &HashMap<String, Vec<HookMatcher>>,
1329) -> (Map<String, Value>, HashMap<String, HookCallback>) {
1330    let mut hooks_config = Map::new();
1331    let mut hook_callbacks = HashMap::new();
1332    let mut next_callback_id: usize = 0;
1333
1334    for (event, matchers) in hooks {
1335        if matchers.is_empty() {
1336            continue;
1337        }
1338        let mut event_matchers = Vec::new();
1339        for matcher in matchers {
1340            let mut callback_ids = Vec::new();
1341            for callback in &matcher.hooks {
1342                let callback_id = format!("hook_{}", next_callback_id);
1343                next_callback_id += 1;
1344                hook_callbacks.insert(callback_id.clone(), callback.clone());
1345                callback_ids.push(callback_id);
1346            }
1347
1348            let mut matcher_obj = Map::new();
1349            matcher_obj.insert(
1350                "matcher".to_string(),
1351                matcher
1352                    .matcher
1353                    .as_ref()
1354                    .map(|m| Value::String(m.clone()))
1355                    .unwrap_or(Value::Null),
1356            );
1357            matcher_obj.insert("hookCallbackIds".to_string(), json!(callback_ids));
1358            if let Some(timeout) = matcher.timeout {
1359                matcher_obj.insert("timeout".to_string(), json!(timeout));
1360            }
1361            event_matchers.push(Value::Object(matcher_obj));
1362        }
1363        hooks_config.insert(event.clone(), Value::Array(event_matchers));
1364    }
1365
1366    (hooks_config, hook_callbacks)
1367}