Skip to main content

claude_code/
query.rs

1//! Core query/session management for Claude Code communication.
2//!
3//! This module provides the [`Query`] struct, which handles the low-level
4//! communication protocol with the Claude Code CLI process, including:
5//!
6//! - Session initialization and handshake
7//! - Control request/response protocol (permissions, hooks, MCP)
8//! - Background message reading and routing via tokio tasks
9//! - Lifecycle management (interrupt, model change, rewind)
10//!
11//! Most users should use [`ClaudeSdkClient`](crate::ClaudeSdkClient) or
12//! [`query()`](crate::query_fn::query) instead of interacting with this module directly.
13
14use std::collections::HashMap;
15use std::future::Future;
16use std::panic::{self, AssertUnwindSafe};
17use std::sync::Arc;
18use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
19use std::time::Duration;
20
21use futures::{FutureExt, Stream, StreamExt};
22use serde_json::{Map, Value, json};
23use tokio::sync::{Mutex, mpsc, oneshot};
24use tokio::task::JoinHandle;
25use tracing::{debug, warn};
26
27use crate::errors::{Error, Result};
28use crate::message_parser::parse_message;
29use crate::sdk_mcp::McpSdkServer;
30use crate::transport::{TransportCloseHandle, TransportReader, TransportWriter};
31use crate::types::{
32    AgentDefinition, CanUseToolCallback, HookCallback, HookMatcher, Message, PermissionResult,
33    ToolPermissionContext,
34};
35
/// Capacity of the bounded mpsc channel that carries parsed SDK messages from
/// the background reader task to the consumer (same value as the Python SDK's
/// `buffer=100`). When full, the reader task waits before forwarding more.
const MESSAGE_CHANNEL_BUFFER: usize = 100;
38
39/// Converts hook callback output keys from Rust-safe names to CLI protocol names.
40///
41/// Specifically maps `async_` → `async` and `continue_` → `continue`, since
42/// those are reserved words in Rust.
43fn convert_hook_output_for_cli(output: Value) -> Value {
44    let Some(obj) = output.as_object() else {
45        return output;
46    };
47
48    let mut converted = Map::new();
49    for (key, value) in obj {
50        match key.as_str() {
51            "async_" => {
52                converted.insert("async".to_string(), value.clone());
53            }
54            "continue_" => {
55                converted.insert("continue".to_string(), value.clone());
56            }
57            _ => {
58                converted.insert(key.clone(), value.clone());
59            }
60        }
61    }
62    Value::Object(converted)
63}
64
/// Extracts a human-readable message from a panic payload.
///
/// `panic!` payloads are usually either a `String` (formatted message) or a
/// `&'static str` (literal message). Any other payload type yields the fixed
/// text `"unknown panic payload"`.
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    match payload.downcast::<String>() {
        // Owned message: move it out of the box instead of cloning it.
        Ok(owned) => *owned,
        Err(payload) => payload
            .downcast_ref::<&str>()
            .map(|literal| (*literal).to_string())
            .unwrap_or_else(|| "unknown panic payload".to_string()),
    }
}
74
75fn callback_panic_error(callback_type: &str, payload: Box<dyn std::any::Any + Send>) -> Error {
76    let panic_message = panic_payload_to_string(payload);
77    warn!(
78        callback_type,
79        panic_message, "Caught panic in callback invocation"
80    );
81    Error::Other(format!(
82        "{callback_type} callback panicked: {panic_message}"
83    ))
84}
85
86async fn await_callback_with_panic_isolation<T, F>(
87    callback_type: &str,
88    callback_future: F,
89) -> Result<T>
90where
91    F: Future<Output = Result<T>>,
92{
93    match AssertUnwindSafe(callback_future).catch_unwind().await {
94        Ok(result) => result,
95        Err(payload) => Err(callback_panic_error(callback_type, payload)),
96    }
97}
98
99/// Tracks pending control request senders and early-arrival response buffers.
100///
101/// Both maps are behind a single mutex to ensure atomicity: a response that
102/// arrives before the sender is registered gets buffered, and a later
103/// `send_control_request` drains the buffer under the same lock.
104struct PendingControlsState {
105    senders: HashMap<String, oneshot::Sender<Result<Value>>>,
106    buffered: HashMap<String, Result<Value>>,
107}
108
109/// Shared state accessible by both the background reader task and the main task.
110struct QuerySharedState {
111    can_use_tool: Option<CanUseToolCallback>,
112    hook_callbacks: Mutex<HashMap<String, HookCallback>>,
113    sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
114    /// Pending control request/response matching.
115    pending_controls: Mutex<PendingControlsState>,
116    /// Shared writer for responding to control requests and sending messages.
117    writer: Arc<Mutex<Box<dyn TransportWriter>>>,
118    /// Whether stdin close is deferred until first result.
119    pending_stdin_close: AtomicBool,
120    /// Timeout for deferred stdin close.
121    stream_close_timeout: Duration,
122    /// Whether the background reader task has terminated.
123    reader_terminated: AtomicBool,
124    /// Reason for reader task termination, if known.
125    reader_termination_reason: Mutex<Option<String>>,
126}
127
128/// Low-level query session handler for Claude Code CLI communication.
129///
130/// Manages the bidirectional JSON stream protocol between the SDK and the CLI.
131/// On startup, a background tokio task
132/// is spawned to continuously read messages from the transport and route them:
133///
134/// - **Control responses** are delivered to the waiting control-request caller via oneshot channels.
135/// - **Control requests** (permissions, hooks, MCP) are handled by the background task.
136/// - **SDK messages** (user, assistant, system, result) are parsed and delivered via an mpsc channel.
137///
138/// This architecture mirrors the Python SDK's task-group model and enables
139/// concurrent send and receive operations.
140pub struct Query {
141    /// Shared state for background task and main task.
142    state: Option<Arc<QuerySharedState>>,
143
144    /// Receiver end of the SDK message channel.
145    message_rx: Option<mpsc::Receiver<Result<Message>>>,
146
147    /// Handle for the background reader task.
148    reader_task: Option<JoinHandle<()>>,
149
150    /// Handle for closing the split transport.
151    close_handle: Option<Box<dyn TransportCloseHandle>>,
152
153    /// Monotonically increasing request ID counter.
154    request_counter: Arc<AtomicUsize>,
155
156    /// Whether the query is in streaming mode.
157    is_streaming_mode: bool,
158
159    /// Agent definitions to register during initialization.
160    agents: Option<HashMap<String, AgentDefinition>>,
161
162    /// Whether initialization has completed.
163    initialized: bool,
164
165    /// The initialization response from the CLI.
166    initialization_result: Option<Value>,
167
168    /// Timeout for the initialization handshake.
169    initialize_timeout: Duration,
170
171    /// Whether hooks or SDK MCP servers are present (for deferred stdin close).
172    has_hooks_or_mcp: bool,
173}
174
175impl Query {
176    /// Creates a new `Query` and starts the background reader task.
177    ///
178    /// This is the primary constructor. It splits the given reader and writer,
179    /// registers callbacks, and spawns the background task.
180    ///
181    /// # Arguments
182    ///
183    /// * `reader` — The transport reader half.
184    /// * `writer` — The transport writer half (wrapped in `Arc<Mutex<>>` for sharing).
185    /// * `close_handle` — Handle for closing the transport.
186    /// * `is_streaming_mode` — Whether to use the streaming protocol.
187    /// * `can_use_tool` — Optional permission callback for tool approval.
188    /// * `hook_callbacks` — Hook callbacks keyed by callback ID.
189    /// * `sdk_mcp_servers` — In-process MCP servers.
190    /// * `agents` — Optional subagent definitions.
191    /// * `initialize_timeout` — Timeout for the initialization handshake.
192    #[allow(clippy::too_many_arguments)]
193    pub(crate) fn start(
194        reader: Box<dyn TransportReader>,
195        writer: Box<dyn TransportWriter>,
196        close_handle: Box<dyn TransportCloseHandle>,
197        is_streaming_mode: bool,
198        can_use_tool: Option<CanUseToolCallback>,
199        hook_callbacks: HashMap<String, HookCallback>,
200        sdk_mcp_servers: HashMap<String, Arc<McpSdkServer>>,
201        agents: Option<HashMap<String, AgentDefinition>>,
202        initialize_timeout: Duration,
203    ) -> Self {
204        let stream_close_timeout_ms: u64 = std::env::var("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT")
205            .ok()
206            .and_then(|v| v.parse().ok())
207            .unwrap_or(60_000);
208        let stream_close_timeout =
209            Duration::from_millis(stream_close_timeout_ms).max(Duration::from_secs(60));
210
211        let has_hooks_or_mcp = !hook_callbacks.is_empty() || !sdk_mcp_servers.is_empty();
212        let writer = Arc::new(Mutex::new(writer));
213
214        let state = Arc::new(QuerySharedState {
215            can_use_tool,
216            hook_callbacks: Mutex::new(hook_callbacks),
217            sdk_mcp_servers,
218            pending_controls: Mutex::new(PendingControlsState {
219                senders: HashMap::new(),
220                buffered: HashMap::new(),
221            }),
222            writer: writer.clone(),
223            pending_stdin_close: AtomicBool::new(false),
224            stream_close_timeout,
225            reader_terminated: AtomicBool::new(false),
226            reader_termination_reason: Mutex::new(None),
227        });
228
229        let (message_tx, message_rx) = mpsc::channel(MESSAGE_CHANNEL_BUFFER);
230
231        let reader_state = state.clone();
232        let reader_task = tokio::spawn(async move {
233            background_reader_task(reader, reader_state, message_tx).await;
234        });
235
236        Self {
237            state: Some(state),
238            message_rx: Some(message_rx),
239            reader_task: Some(reader_task),
240            close_handle: Some(close_handle),
241            request_counter: Arc::new(AtomicUsize::new(0)),
242            is_streaming_mode,
243            agents,
244            initialized: false,
245            initialization_result: None,
246            initialize_timeout,
247            has_hooks_or_mcp,
248        }
249    }
250
251    /// Sends the initialization handshake to the CLI.
252    ///
253    /// Registers hook callbacks and agent definitions with the CLI process,
254    /// and waits for the initialization response.
255    ///
256    /// Returns the initialization response payload, or `None` if not in streaming mode.
257    ///
258    /// # Example
259    ///
260    /// ```rust,ignore
261    /// use claude_code::Query;
262    /// use serde_json::Map;
263    /// use serde_json::Value;
264    ///
265    /// # async fn demo(query: &mut Query) -> claude_code::Result<()> {
266    /// let _ = query.initialize(Map::<String, Value>::new()).await?;
267    /// # Ok(())
268    /// # }
269    /// ```
270    pub async fn initialize(&mut self, hooks_config: Map<String, Value>) -> Result<Option<Value>> {
271        if !self.is_streaming_mode {
272            return Ok(None);
273        }
274
275        let mut request = Map::new();
276        request.insert(
277            "subtype".to_string(),
278            Value::String("initialize".to_string()),
279        );
280        request.insert(
281            "hooks".to_string(),
282            if hooks_config.is_empty() {
283                Value::Null
284            } else {
285                Value::Object(hooks_config)
286            },
287        );
288
289        if let Some(agents) = &self.agents {
290            request.insert(
291                "agents".to_string(),
292                serde_json::to_value(agents).unwrap_or(Value::Null),
293            );
294        }
295
296        let response = self
297            .send_control_request(Value::Object(request), self.initialize_timeout)
298            .await?;
299        self.initialized = true;
300        self.initialization_result = Some(response.clone());
301        Ok(Some(response))
302    }
303
304    /// Returns the initialization result from the CLI handshake.
305    ///
306    /// Returns `None` if [`initialize()`](Self::initialize) has not been called yet.
307    ///
308    /// # Example
309    ///
310    /// ```rust,ignore
311    /// use claude_code::Query;
312    ///
313    /// fn demo(query: &Query) {
314    ///     let _info = query.initialization_result();
315    /// }
316    /// ```
317    pub fn initialization_result(&self) -> Option<Value> {
318        self.initialization_result.clone()
319    }
320
321    /// Sends a control request to the CLI and waits for the matching response.
322    ///
323    /// The request is written via the shared writer. The background reader task
324    /// delivers the matching control response via a oneshot channel.
325    async fn send_control_request(&self, request: Value, timeout: Duration) -> Result<Value> {
326        if !self.is_streaming_mode {
327            return Err(Error::Other(
328                "Control requests require streaming mode".to_string(),
329            ));
330        }
331
332        let state = self
333            .state
334            .as_ref()
335            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
336
337        let request_id = format!(
338            "req_{}",
339            self.request_counter.fetch_add(1, Ordering::SeqCst) + 1
340        );
341
342        // Write the control request first (so it's always observable).
343        let control_request = json!({
344            "type": "control_request",
345            "request_id": request_id,
346            "request": request,
347        });
348        state
349            .writer
350            .lock()
351            .await
352            .write(&(control_request.to_string() + "\n"))
353            .await?;
354
355        // Register a oneshot channel for the response, checking the buffer first.
356        // Both operations are under a single lock to avoid a race where a response
357        // arrives between checking the buffer and registering the sender.
358        let (tx, rx) = oneshot::channel();
359        {
360            let mut controls = state.pending_controls.lock().await;
361            if let Some(result) = controls.buffered.remove(&request_id) {
362                return result;
363            }
364            controls.senders.insert(request_id.clone(), tx);
365        }
366        if state.reader_terminated.load(Ordering::SeqCst) {
367            state
368                .pending_controls
369                .lock()
370                .await
371                .senders
372                .remove(&request_id);
373            let reason = reader_termination_reason(state).await;
374            return Err(Error::Other(format!(
375                "Background reader task terminated: {reason}"
376            )));
377        }
378
379        // Wait for the response with timeout.
380        let result = tokio::time::timeout(timeout, rx).await;
381        match result {
382            Ok(Ok(value)) => value,
383            Ok(Err(_)) => {
384                // Channel closed — background task died
385                Err(Error::Other(
386                    "Background reader task terminated while waiting for control response"
387                        .to_string(),
388                ))
389            }
390            Err(_) => {
391                // Timeout
392                let subtype = request
393                    .get("subtype")
394                    .and_then(Value::as_str)
395                    .unwrap_or("unknown");
396                state
397                    .pending_controls
398                    .lock()
399                    .await
400                    .senders
401                    .remove(&request_id);
402                Err(Error::Other(format!("Control request timeout: {subtype}")))
403            }
404        }
405    }
406
407    /// Sends a user text message to the CLI.
408    ///
409    /// # Example
410    ///
411    /// ```rust,ignore
412    /// use claude_code::Query;
413    ///
414    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
415    /// query.send_user_message("hello", "default").await?;
416    /// # Ok(())
417    /// # }
418    /// ```
419    pub async fn send_user_message(&self, prompt: &str, session_id: &str) -> Result<()> {
420        let message = json!({
421            "type": "user",
422            "message": {"role": "user", "content": prompt},
423            "parent_tool_use_id": Value::Null,
424            "session_id": session_id
425        });
426        self.write_message(&message).await
427    }
428
429    /// Sends a raw JSON message to the CLI without any transformation.
430    ///
431    /// # Example
432    ///
433    /// ```rust,ignore
434    /// use claude_code::Query;
435    /// use serde_json::json;
436    ///
437    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
438    /// query.send_raw_message(json!({"type":"user","message":{"role":"user","content":"hi"}})).await?;
439    /// # Ok(())
440    /// # }
441    /// ```
442    pub async fn send_raw_message(&self, message: Value) -> Result<()> {
443        self.write_message(&message).await
444    }
445
446    /// Writes a JSON message to the shared writer.
447    async fn write_message(&self, message: &Value) -> Result<()> {
448        let state = self
449            .state
450            .as_ref()
451            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
452        state
453            .writer
454            .lock()
455            .await
456            .write(&(message.to_string() + "\n"))
457            .await
458    }
459
460    /// Sends multiple input messages to the CLI without closing stdin.
461    ///
462    /// # Example
463    ///
464    /// ```rust,ignore
465    /// use claude_code::Query;
466    /// use serde_json::json;
467    ///
468    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
469    /// query
470    ///     .send_input_messages(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})])
471    ///     .await?;
472    /// # Ok(())
473    /// # }
474    /// ```
475    pub async fn send_input_messages(&self, messages: Vec<Value>) -> Result<()> {
476        for message in messages {
477            self.send_raw_message(message).await?;
478        }
479        Ok(())
480    }
481
482    /// Sends streamed input messages to the CLI without closing stdin.
483    ///
484    /// # Example
485    ///
486    /// ```rust,ignore
487    /// use claude_code::Query;
488    /// use futures::stream;
489    /// use serde_json::json;
490    ///
491    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
492    /// query
493    ///     .send_input_from_stream(stream::iter(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})]))
494    ///     .await?;
495    /// # Ok(())
496    /// # }
497    /// ```
498    pub async fn send_input_from_stream<S>(&self, mut messages: S) -> Result<()>
499    where
500        S: Stream<Item = Value> + Unpin,
501    {
502        while let Some(message) = messages.next().await {
503            self.send_raw_message(message).await?;
504        }
505        Ok(())
506    }
507
508    /// Spawns a background task that streams input messages to the CLI.
509    ///
510    /// This is useful for long-lived or unbounded input streams where the caller
511    /// should continue processing messages concurrently.
512    ///
513    /// The returned task completes when the input stream ends or a write error
514    /// occurs. It does not close stdin.
515    ///
516    /// # Example
517    ///
518    /// ```rust,ignore
519    /// use claude_code::Query;
520    /// use futures::stream;
521    /// use serde_json::json;
522    ///
523    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
524    /// let handle = query.spawn_input_from_stream(stream::iter(vec![
525    ///     json!({"type":"user","message":{"role":"user","content":"hello"}}),
526    /// ]))?;
527    /// handle.await??;
528    /// # Ok(())
529    /// # }
530    /// ```
531    pub fn spawn_input_from_stream<S>(&self, mut messages: S) -> Result<JoinHandle<Result<()>>>
532    where
533        S: Stream<Item = Value> + Send + Unpin + 'static,
534    {
535        let state = self
536            .state
537            .as_ref()
538            .cloned()
539            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
540
541        Ok(tokio::spawn(async move {
542            while let Some(message) = messages.next().await {
543                state
544                    .writer
545                    .lock()
546                    .await
547                    .write(&(message.to_string() + "\n"))
548                    .await?;
549            }
550            Ok(())
551        }))
552    }
553
554    /// Streams multiple messages to the CLI and closes the input stream.
555    ///
556    /// If SDK MCP servers or hooks are present, stdin close is deferred until
557    /// the first result message is received (or a timeout expires).
558    ///
559    /// # Example
560    ///
561    /// ```rust,ignore
562    /// use claude_code::Query;
563    /// use serde_json::json;
564    ///
565    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
566    /// query
567    ///     .stream_input(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})])
568    ///     .await?;
569    /// # Ok(())
570    /// # }
571    /// ```
572    pub async fn stream_input(&self, messages: Vec<Value>) -> Result<()> {
573        self.send_input_messages(messages).await?;
574        self.finalize_stream_input().await
575    }
576
577    /// Streams messages from an async stream source and closes the input stream.
578    ///
579    /// # Example
580    ///
581    /// ```rust,ignore
582    /// use claude_code::Query;
583    /// use futures::stream;
584    /// use serde_json::json;
585    ///
586    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
587    /// query
588    ///     .stream_input_from_stream(stream::iter(vec![json!({"type":"user","message":{"role":"user","content":"hello"}})]))
589    ///     .await?;
590    /// # Ok(())
591    /// # }
592    /// ```
593    pub async fn stream_input_from_stream<S>(&self, mut messages: S) -> Result<()>
594    where
595        S: Stream<Item = Value> + Unpin,
596    {
597        self.send_input_from_stream(&mut messages).await?;
598        self.finalize_stream_input().await
599    }
600
601    async fn finalize_stream_input(&self) -> Result<()> {
602        let state = self
603            .state
604            .as_ref()
605            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
606
607        if self.has_hooks_or_mcp {
608            debug!(
609                has_hooks_or_mcp = self.has_hooks_or_mcp,
610                "Deferring stdin close until first result"
611            );
612            state.pending_stdin_close.store(true, Ordering::SeqCst);
613        } else {
614            state.writer.lock().await.end_input().await?;
615        }
616        Ok(())
617    }
618
619    /// Closes the input stream without sending any messages.
620    ///
621    /// # Example
622    ///
623    /// ```rust,ignore
624    /// use claude_code::Query;
625    ///
626    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
627    /// query.end_input().await?;
628    /// # Ok(())
629    /// # }
630    /// ```
631    pub async fn end_input(&self) -> Result<()> {
632        let state = self
633            .state
634            .as_ref()
635            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
636        state.writer.lock().await.end_input().await
637    }
638
639    /// Receives the next content message from the CLI.
640    ///
641    /// Messages are delivered by the background reader task via an mpsc channel.
642    /// Control messages are handled transparently by the background task.
643    ///
644    /// Returns `None` when the stream is exhausted (no more messages).
645    ///
646    /// # Example
647    ///
648    /// ```rust,ignore
649    /// use claude_code::Query;
650    ///
651    /// # async fn demo(query: &mut Query) -> claude_code::Result<()> {
652    /// while let Some(message) = query.receive_next_message().await? {
653    ///     println!("{message:?}");
654    /// }
655    /// # Ok(())
656    /// # }
657    /// ```
658    pub async fn receive_next_message(&mut self) -> Result<Option<Message>> {
659        let rx = self
660            .message_rx
661            .as_mut()
662            .ok_or_else(|| Error::Other("Query not started or already closed.".to_string()))?;
663
664        match rx.recv().await {
665            Some(Ok(message)) => Ok(Some(message)),
666            Some(Err(err)) => Err(err),
667            None => Ok(None),
668        }
669    }
670
671    /// Queries the status of connected MCP servers via the CLI.
672    ///
673    /// # Example
674    ///
675    /// ```rust,ignore
676    /// use claude_code::Query;
677    ///
678    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
679    /// let _status = query.get_mcp_status().await?;
680    /// # Ok(())
681    /// # }
682    /// ```
683    pub async fn get_mcp_status(&self) -> Result<Value> {
684        self.send_control_request(json!({ "subtype": "mcp_status" }), Duration::from_secs(60))
685            .await
686    }
687
688    /// Sends an interrupt signal to the CLI to stop the current operation.
689    ///
690    /// # Example
691    ///
692    /// ```rust,ignore
693    /// use claude_code::Query;
694    ///
695    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
696    /// query.interrupt().await?;
697    /// # Ok(())
698    /// # }
699    /// ```
700    pub async fn interrupt(&self) -> Result<()> {
701        self.send_control_request(json!({ "subtype": "interrupt" }), Duration::from_secs(60))
702            .await?;
703        Ok(())
704    }
705
706    /// Changes the permission mode via a control request.
707    ///
708    /// # Example
709    ///
710    /// ```rust,ignore
711    /// use claude_code::Query;
712    ///
713    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
714    /// query.set_permission_mode("plan").await?;
715    /// # Ok(())
716    /// # }
717    /// ```
718    pub async fn set_permission_mode(&self, mode: &str) -> Result<()> {
719        self.send_control_request(
720            json!({ "subtype": "set_permission_mode", "mode": mode }),
721            Duration::from_secs(60),
722        )
723        .await?;
724        Ok(())
725    }
726
727    /// Changes the model used by the CLI via a control request.
728    ///
729    /// # Example
730    ///
731    /// ```rust,ignore
732    /// use claude_code::Query;
733    ///
734    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
735    /// query.set_model(Some("sonnet")).await?;
736    /// # Ok(())
737    /// # }
738    /// ```
739    pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
740        self.send_control_request(
741            json!({ "subtype": "set_model", "model": model }),
742            Duration::from_secs(60),
743        )
744        .await?;
745        Ok(())
746    }
747
748    /// Rewinds file changes to a specific user message checkpoint.
749    ///
750    /// # Example
751    ///
752    /// ```rust,ignore
753    /// use claude_code::Query;
754    ///
755    /// # async fn demo(query: &Query) -> claude_code::Result<()> {
756    /// query.rewind_files("user-msg-1").await?;
757    /// # Ok(())
758    /// # }
759    /// ```
760    pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
761        self.send_control_request(
762            json!({ "subtype": "rewind_files", "user_message_id": user_message_id }),
763            Duration::from_secs(60),
764        )
765        .await?;
766        Ok(())
767    }
768
769    /// Closes the query session.
770    ///
771    /// # Example
772    ///
773    /// ```rust,ignore
774    /// use claude_code::Query;
775    ///
776    /// # async fn demo(query: Query) -> claude_code::Result<()> {
777    /// query.close().await?;
778    /// # Ok(())
779    /// # }
780    /// ```
781    pub async fn close(mut self) -> Result<()> {
782        self.shutdown().await
783    }
784
785    /// Internal shutdown logic.
786    async fn shutdown(&mut self) -> Result<()> {
787        self.message_rx.take();
788        self.state.take();
789
790        if let Some(task) = self.reader_task.take() {
791            task.abort();
792            let _ = task.await;
793        }
794
795        if let Some(close_handle) = self.close_handle.take() {
796            close_handle.close().await?;
797        }
798
799        Ok(())
800    }
801
802    /// Takes the message receiver for stream construction.
803    pub(crate) fn take_message_receiver(&mut self) -> Option<mpsc::Receiver<Result<Message>>> {
804        self.message_rx.take()
805    }
806}
807
808impl Drop for Query {
809    fn drop(&mut self) {
810        if let Some(task) = self.reader_task.take() {
811            task.abort();
812        }
813
814        if let Some(close_handle) = self.close_handle.take() {
815            // Spawn a detached task to perform async cleanup.
816            // If no runtime is available, fall back to a temporary current-thread
817            // runtime for best-effort synchronous cleanup.
818            if let Ok(handle) = tokio::runtime::Handle::try_current() {
819                handle.spawn(async move {
820                    let _ = close_handle.close().await;
821                });
822            } else if let Ok(runtime) = tokio::runtime::Builder::new_current_thread()
823                .enable_all()
824                .build()
825            {
826                let _ = runtime.block_on(async move { close_handle.close().await });
827            }
828        }
829    }
830}
831
832// ---------------------------------------------------------------------------
833// Background Reader Task
834// ---------------------------------------------------------------------------
835
836/// Background task that continuously reads from the transport reader and routes
837/// messages to their appropriate destinations.
838async fn background_reader_task(
839    mut reader: Box<dyn TransportReader>,
840    state: Arc<QuerySharedState>,
841    message_tx: mpsc::Sender<Result<Message>>,
842) {
843    loop {
844        // Handle deferred stdin close timeout.
845        let read_result = if state.pending_stdin_close.load(Ordering::SeqCst) {
846            let timeout_dur = state.stream_close_timeout;
847            match tokio::time::timeout(timeout_dur, reader.read_next_message()).await {
848                Ok(result) => result,
849                Err(_) => {
850                    debug!("Timed out waiting for first result, closing input stream");
851                    try_close_deferred_stdin(&state).await;
852                    continue;
853                }
854            }
855        } else {
856            reader.read_next_message().await
857        };
858
859        let raw = match read_result {
860            Ok(Some(raw)) => raw,
861            Ok(None) => {
862                try_close_deferred_stdin(&state).await;
863                break;
864            }
865            Err(err) => {
866                mark_reader_terminated(&state, err.to_string()).await;
867                let _ = message_tx.send(Err(err)).await;
868                break;
869            }
870        };
871
872        let msg_type = raw.get("type").and_then(Value::as_str).unwrap_or_default();
873
874        if msg_type == "control_response" {
875            handle_control_response(&state, &raw).await;
876            continue;
877        }
878
879        if msg_type == "control_request" {
880            if let Err(err) = handle_control_request(&state, raw).await {
881                debug!("Error handling control request: {err}");
882            }
883            continue;
884        }
885
886        if msg_type == "control_cancel_request" {
887            continue;
888        }
889
890        // Parse and forward SDK messages.
891        match parse_message(&raw) {
892            Ok(Some(msg)) => {
893                if matches!(msg, Message::Result(_))
894                    && state.pending_stdin_close.load(Ordering::SeqCst)
895                {
896                    debug!("Received first result, closing input stream");
897                    try_close_deferred_stdin(&state).await;
898                }
899
900                if message_tx.send(Ok(msg)).await.is_err() {
901                    break;
902                }
903            }
904            Ok(None) => {}
905            Err(err) => {
906                if message_tx
907                    .send(Err(Error::MessageParse(err)))
908                    .await
909                    .is_err()
910                {
911                    break;
912                }
913            }
914        }
915    }
916}
917
918/// Marks reader termination and fails all pending control requests immediately.
919async fn mark_reader_terminated(state: &QuerySharedState, reason: String) {
920    state.reader_terminated.store(true, Ordering::SeqCst);
921    let stored_reason = {
922        let mut termination_reason = state.reader_termination_reason.lock().await;
923        if termination_reason.is_none() {
924            *termination_reason = Some(reason);
925        }
926        termination_reason
927            .clone()
928            .unwrap_or_else(|| "Unknown reason".to_string())
929    };
930
931    let mut controls = state.pending_controls.lock().await;
932    for (_, sender) in controls.senders.drain() {
933        let _ = sender.send(Err(Error::Other(format!(
934            "Background reader task terminated: {stored_reason}"
935        ))));
936    }
937}
938
939/// Returns the recorded reader termination reason or a generic fallback.
940async fn reader_termination_reason(state: &QuerySharedState) -> String {
941    state
942        .reader_termination_reason
943        .lock()
944        .await
945        .clone()
946        .unwrap_or_else(|| "Unknown reason".to_string())
947}
948
949/// Closes deferred stdin via the shared writer.
950async fn try_close_deferred_stdin(state: &QuerySharedState) {
951    if state
952        .pending_stdin_close
953        .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst)
954        .is_ok()
955    {
956        if let Err(e) = state.writer.lock().await.end_input().await {
957            debug!("Error closing deferred stdin: {e}");
958        }
959    }
960}
961
962/// Routes a control response to the waiting oneshot sender, or buffers it.
963///
964/// If no sender is registered for this response's `request_id`, the parsed
965/// result is stored in the buffer for later retrieval by `send_control_request`.
966async fn handle_control_response(state: &QuerySharedState, raw: &Value) {
967    let Some(response) = raw.get("response").and_then(Value::as_object) else {
968        return;
969    };
970    let response_request_id = response
971        .get("request_id")
972        .and_then(Value::as_str)
973        .unwrap_or_default();
974
975    let subtype = response
976        .get("subtype")
977        .and_then(Value::as_str)
978        .unwrap_or_default();
979
980    let result: Result<Value> = if subtype == "error" {
981        let error = response
982            .get("error")
983            .and_then(Value::as_str)
984            .unwrap_or("Unknown error");
985        Err(Error::Other(error.to_string()))
986    } else {
987        Ok(response
988            .get("response")
989            .cloned()
990            .unwrap_or_else(|| json!({})))
991    };
992
993    let mut controls = state.pending_controls.lock().await;
994    if let Some(sender) = controls.senders.remove(response_request_id) {
995        let _ = sender.send(result);
996    } else {
997        // Response arrived before sender was registered — buffer it.
998        controls
999            .buffered
1000            .insert(response_request_id.to_string(), result);
1001    }
1002}
1003
1004async fn handle_can_use_tool_request(
1005    state: &QuerySharedState,
1006    request_data: &Map<String, Value>,
1007) -> Result<Value> {
1008    let callback = state
1009        .can_use_tool
1010        .clone()
1011        .ok_or_else(|| Error::Other("canUseTool callback is not provided".to_string()))?;
1012    let tool_name = request_data
1013        .get("tool_name")
1014        .and_then(Value::as_str)
1015        .unwrap_or_default()
1016        .to_string();
1017    let input = request_data
1018        .get("input")
1019        .cloned()
1020        .unwrap_or_else(|| json!({}));
1021    let suggestions = request_data
1022        .get("permission_suggestions")
1023        .and_then(Value::as_array)
1024        .cloned()
1025        .unwrap_or_default()
1026        .into_iter()
1027        .filter_map(|value| serde_json::from_value(value).ok())
1028        .collect();
1029    let blocked_path = request_data
1030        .get("blocked_path")
1031        .and_then(Value::as_str)
1032        .map(ToString::to_string);
1033    let context = ToolPermissionContext {
1034        suggestions,
1035        blocked_path,
1036        signal: None,
1037    };
1038
1039    let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1040        callback(tool_name, input.clone(), context)
1041    }))
1042    .map_err(|payload| callback_panic_error("can_use_tool", payload))?;
1043    let callback_result =
1044        await_callback_with_panic_isolation("can_use_tool", callback_future).await?;
1045    let output = match callback_result {
1046        PermissionResult::Allow(allow) => {
1047            let mut obj = Map::new();
1048            obj.insert("behavior".to_string(), Value::String("allow".to_string()));
1049            obj.insert(
1050                "updatedInput".to_string(),
1051                allow.updated_input.unwrap_or(input),
1052            );
1053            if let Some(updated_permissions) = allow.updated_permissions {
1054                let permissions_json: Vec<Value> = updated_permissions
1055                    .into_iter()
1056                    .map(|permission| permission.to_cli_dict())
1057                    .collect();
1058                obj.insert(
1059                    "updatedPermissions".to_string(),
1060                    Value::Array(permissions_json),
1061                );
1062            }
1063            Value::Object(obj)
1064        }
1065        PermissionResult::Deny(deny) => {
1066            let mut obj = Map::new();
1067            obj.insert("behavior".to_string(), Value::String("deny".to_string()));
1068            obj.insert("message".to_string(), Value::String(deny.message));
1069            if deny.interrupt {
1070                obj.insert("interrupt".to_string(), Value::Bool(true));
1071            }
1072            Value::Object(obj)
1073        }
1074    };
1075    Ok(output)
1076}
1077
1078async fn handle_hook_callback_request(
1079    state: &QuerySharedState,
1080    request_data: &Map<String, Value>,
1081) -> Result<Value> {
1082    let callback_id = request_data
1083        .get("callback_id")
1084        .and_then(Value::as_str)
1085        .ok_or_else(|| Error::Other("Missing callback_id in hook_callback".to_string()))?;
1086    let callback = state
1087        .hook_callbacks
1088        .lock()
1089        .await
1090        .get(callback_id)
1091        .cloned()
1092        .ok_or_else(|| Error::Other(format!("No hook callback found for ID: {callback_id}")))?;
1093    let input = request_data.get("input").cloned().unwrap_or(Value::Null);
1094    let tool_use_id = request_data
1095        .get("tool_use_id")
1096        .and_then(Value::as_str)
1097        .map(ToString::to_string);
1098    let callback_future = panic::catch_unwind(AssertUnwindSafe(|| {
1099        callback(input, tool_use_id, Default::default())
1100    }))
1101    .map_err(|payload| callback_panic_error("hook", payload))?;
1102    let output = await_callback_with_panic_isolation("hook", callback_future).await?;
1103    Ok(convert_hook_output_for_cli(output))
1104}
1105
1106async fn handle_mcp_message_request(
1107    state: &QuerySharedState,
1108    request_data: &Map<String, Value>,
1109) -> Result<Value> {
1110    let server_name = request_data
1111        .get("server_name")
1112        .and_then(Value::as_str)
1113        .ok_or_else(|| Error::Other("Missing server_name in mcp_message".to_string()))?;
1114    let message = request_data
1115        .get("message")
1116        .cloned()
1117        .ok_or_else(|| Error::Other("Missing message in mcp_message".to_string()))?;
1118    let response = handle_sdk_mcp_request(&state.sdk_mcp_servers, server_name, &message).await;
1119    Ok(json!({ "mcp_response": response }))
1120}
1121
1122/// Handles an incoming control request from the CLI within the background task.
1123async fn handle_control_request(state: &QuerySharedState, request: Value) -> Result<()> {
1124    let Some(request_obj) = request.as_object() else {
1125        return Err(Error::Other("Invalid control request format".to_string()));
1126    };
1127    let request_id = request_obj
1128        .get("request_id")
1129        .and_then(Value::as_str)
1130        .ok_or_else(|| Error::Other("Missing request_id in control request".to_string()))?
1131        .to_string();
1132    let request_data = request_obj
1133        .get("request")
1134        .and_then(Value::as_object)
1135        .ok_or_else(|| Error::Other("Missing request payload".to_string()))?;
1136    let subtype = request_data
1137        .get("subtype")
1138        .and_then(Value::as_str)
1139        .ok_or_else(|| Error::Other("Missing request subtype".to_string()))?;
1140
1141    let result: Result<Value> = match subtype {
1142        "can_use_tool" => handle_can_use_tool_request(state, request_data).await,
1143        "hook_callback" => handle_hook_callback_request(state, request_data).await,
1144        "mcp_message" => handle_mcp_message_request(state, request_data).await,
1145        _ => Err(Error::Other(format!(
1146            "Unsupported control request subtype: {subtype}"
1147        ))),
1148    };
1149
1150    let response_json = match result {
1151        Ok(payload) => json!({
1152            "type": "control_response",
1153            "response": {
1154                "subtype": "success",
1155                "request_id": request_id,
1156                "response": payload
1157            }
1158        }),
1159        Err(err) => json!({
1160            "type": "control_response",
1161            "response": {
1162                "subtype": "error",
1163                "request_id": request_id,
1164                "error": err.to_string()
1165            }
1166        }),
1167    };
1168
1169    state
1170        .writer
1171        .lock()
1172        .await
1173        .write(&(response_json.to_string() + "\n"))
1174        .await
1175}
1176
1177/// Routes an MCP message to the appropriate in-process SDK MCP server.
1178///
1179/// Implements JSON-RPC message routing for in-process SDK MCP servers.
1180/// Handles `initialize`, `tools/list`, `tools/call`, and `notifications/initialized` methods.
1181///
1182/// # Example
1183///
1184/// ```rust,no_run
1185/// use claude_code::{create_sdk_mcp_server, tool};
1186/// use claude_code::query::handle_sdk_mcp_request;
1187/// use serde_json::{json, Value};
1188/// use std::collections::HashMap;
1189/// use std::sync::Arc;
1190///
1191/// # async fn example() {
1192///     let config = create_sdk_mcp_server(
1193///     "tools",
1194///     "1.0.0",
1195///     vec![tool("echo", "Echo", json!({"type":"object"}), |_args: Value| async move {
1196///         Ok(json!({"content": []}))
1197///     })],
1198///     );
1199///
1200///     let mut servers = HashMap::new();
1201///     servers.insert(config.name.clone(), Arc::clone(&config.instance));
1202///
1203///     let response = handle_sdk_mcp_request(
1204///     &servers,
1205///     "tools",
1206///     &json!({"jsonrpc":"2.0","id":1,"method":"tools/list"}),
1207///     )
1208///     .await;
1209///
1210///     assert_eq!(response["jsonrpc"], "2.0");
1211/// # }
1212/// ```
1213pub async fn handle_sdk_mcp_request(
1214    sdk_mcp_servers: &HashMap<String, Arc<McpSdkServer>>,
1215    server_name: &str,
1216    message: &Value,
1217) -> Value {
1218    let Some(server) = sdk_mcp_servers.get(server_name) else {
1219        return json!({
1220            "jsonrpc": "2.0",
1221            "id": message.get("id").cloned().unwrap_or(Value::Null),
1222            "error": {
1223                "code": -32601,
1224                "message": format!("Server '{server_name}' not found")
1225            }
1226        });
1227    };
1228
1229    let method = message
1230        .get("method")
1231        .and_then(Value::as_str)
1232        .unwrap_or_default();
1233    let id = message.get("id").cloned().unwrap_or(Value::Null);
1234    let params = message.get("params").cloned().unwrap_or_else(|| json!({}));
1235
1236    match method {
1237        "initialize" => json!({
1238            "jsonrpc": "2.0",
1239            "id": id,
1240            "result": {
1241                "protocolVersion": "2024-11-05",
1242                "capabilities": {"tools": {}},
1243                "serverInfo": {
1244                    "name": server.name,
1245                    "version": server.version
1246                }
1247            }
1248        }),
1249        "tools/list" => json!({
1250            "jsonrpc": "2.0",
1251            "id": id,
1252            "result": {
1253                "tools": server.list_tools_json()
1254            }
1255        }),
1256        "tools/call" => {
1257            let tool_name = params
1258                .get("name")
1259                .and_then(Value::as_str)
1260                .unwrap_or_default();
1261            let arguments = params
1262                .get("arguments")
1263                .cloned()
1264                .unwrap_or_else(|| json!({}));
1265            let result = server.call_tool_json(tool_name, arguments).await;
1266            json!({
1267                "jsonrpc": "2.0",
1268                "id": id,
1269                "result": result
1270            })
1271        }
1272        "notifications/initialized" => json!({
1273            "jsonrpc": "2.0",
1274            "result": {}
1275        }),
1276        _ => json!({
1277            "jsonrpc": "2.0",
1278            "id": id,
1279            "error": {
1280                "code": -32601,
1281                "message": format!("Method '{method}' not found")
1282            }
1283        }),
1284    }
1285}
1286
1287// ---------------------------------------------------------------------------
1288// Helper: Build hooks config and callbacks
1289// ---------------------------------------------------------------------------
1290
1291/// Builds the hooks configuration for the initialization handshake and extracts
1292/// hook callbacks for the background task.
1293pub(crate) fn build_hooks_config(
1294    hooks: &HashMap<String, Vec<HookMatcher>>,
1295) -> (Map<String, Value>, HashMap<String, HookCallback>) {
1296    let mut hooks_config = Map::new();
1297    let mut hook_callbacks = HashMap::new();
1298    let mut next_callback_id: usize = 0;
1299
1300    for (event, matchers) in hooks {
1301        if matchers.is_empty() {
1302            continue;
1303        }
1304        let mut event_matchers = Vec::new();
1305        for matcher in matchers {
1306            let mut callback_ids = Vec::new();
1307            for callback in &matcher.hooks {
1308                let callback_id = format!("hook_{}", next_callback_id);
1309                next_callback_id += 1;
1310                hook_callbacks.insert(callback_id.clone(), callback.clone());
1311                callback_ids.push(callback_id);
1312            }
1313
1314            let mut matcher_obj = Map::new();
1315            matcher_obj.insert(
1316                "matcher".to_string(),
1317                matcher
1318                    .matcher
1319                    .as_ref()
1320                    .map(|m| Value::String(m.clone()))
1321                    .unwrap_or(Value::Null),
1322            );
1323            matcher_obj.insert("hookCallbackIds".to_string(), json!(callback_ids));
1324            if let Some(timeout) = matcher.timeout {
1325                matcher_obj.insert("timeout".to_string(), json!(timeout));
1326            }
1327            event_matchers.push(Value::Object(matcher_obj));
1328        }
1329        hooks_config.insert(event.clone(), Value::Array(event_matchers));
1330    }
1331
1332    (hooks_config, hook_callbacks)
1333}