Skip to main content

claude_cli_sdk/
client.rs

1//! The core `Client` struct — multi-turn, bidirectional Claude Code sessions.
2//!
3//! # Architecture
4//!
5//! On [`connect()`](Client::connect), the client spawns a background task that:
6//! 1. Reads JSON values from the transport
7//! 2. Routes permission requests to the configured callback
8//! 3. Routes hook requests to registered hook matchers
9//! 4. Applies the message callback (if any)
10//! 5. Forwards resulting messages through a `flume` channel
11//!
12//! Callers consume messages via [`send()`](Client::send) (which returns a
13//! stream), or via [`receive_messages()`](Client::receive_messages).
14
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23use tokio_util::sync::CancellationToken;
24
25use crate::callback::apply_callback;
26use crate::config::{ClientConfig, PermissionMode};
27use crate::errors::{Error, Result};
28use crate::transport::{CliTransport, Transport};
29use crate::types::content::UserContent;
30use crate::types::messages::{Message, SessionInfo};
31
32// ── Cancellation helper ──────────────────────────────────────────────────────
33
34/// Wait for the token to fire, or pend forever if `None`.
35///
36/// Useful as a `tokio::select!` branch that compiles away when no token is
37/// provided.
38pub(crate) async fn cancelled_or_pending(token: Option<&CancellationToken>) {
39    match token {
40        Some(t) => t.cancelled().await,
41        None => std::future::pending().await,
42    }
43}
44
45// ── Timeout helper ───────────────────────────────────────────────────────────
46
47/// Receive a message from a flume channel with an optional timeout and
48/// optional cancellation token.
49///
50/// Returns `Error::Timeout` if the deadline expires, `Error::Cancelled` if
51/// the token fires, or `Error::Transport` if the channel is closed.
52pub(crate) async fn recv_with_timeout(
53    rx: &flume::Receiver<Result<Message>>,
54    timeout: Option<Duration>,
55    cancel: Option<&CancellationToken>,
56) -> Result<Message> {
57    let recv_fut = rx.recv_async();
58    tokio::select! {
59        biased;
60        _ = cancelled_or_pending(cancel) => {
61            Err(Error::Cancelled)
62        }
63        result = async {
64            match timeout {
65                Some(d) => match tokio::time::timeout(d, recv_fut).await {
66                    Ok(Ok(msg)) => msg,
67                    Ok(Err(_)) => Err(Error::Transport("message channel closed".into())),
68                    Err(_) => Err(Error::Timeout(format!("read timed out after {}s", d.as_secs_f64()))),
69                },
70                None => match recv_fut.await {
71                    Ok(msg) => msg,
72                    Err(_) => Err(Error::Transport("message channel closed".into())),
73                },
74            }
75        } => result,
76    }
77}
78
79// ── Control protocol helpers ────────────────────────────────────────────────
80
81fn build_control_response_success(
82    request_id: &str,
83    response: serde_json::Value,
84) -> serde_json::Value {
85    serde_json::json!({
86        "type": "control_response",
87        "response": {
88            "subtype": "success",
89            "request_id": request_id,
90            "response": response
91        }
92    })
93}
94
95fn build_control_response_error(request_id: &str, error: impl Into<String>) -> serde_json::Value {
96    serde_json::json!({
97        "type": "control_response",
98        "response": {
99            "subtype": "error",
100            "request_id": request_id,
101            "error": error.into()
102        }
103    })
104}
105
106async fn write_json_line(transport: &dyn Transport, value: serde_json::Value) {
107    if let Ok(json) = serde_json::to_string(&value) {
108        let _ = transport.write(&json).await;
109    }
110}
111
112// ── Shared turn stream helper ─────────────────────────────────────────────────
113
114/// Read messages from the receiver until a `Result` message or error,
115/// then clear the turn flag.
116fn read_turn_stream<'a>(
117    rx: &'a flume::Receiver<Result<Message>>,
118    read_timeout: Option<Duration>,
119    turn_flag: Arc<AtomicBool>,
120    cancel: Option<CancellationToken>,
121) -> impl Stream<Item = Result<Message>> + 'a {
122    async_stream::stream! {
123        loop {
124            match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
125                Ok(msg) => {
126                    let is_result = matches!(&msg, Message::Result(_));
127                    yield Ok(msg);
128                    if is_result {
129                        break;
130                    }
131                }
132                Err(e) => {
133                    yield Err(e);
134                    break;
135                }
136            }
137        }
138        turn_flag.store(false, Ordering::Release);
139    }
140}
141
142// ── Client ───────────────────────────────────────────────────────────────────
143
144/// A stateful Claude Code client that manages a persistent session.
145///
146/// # Lifecycle
147///
148/// 1. Create with [`Client::new(config)`](Client::new) or [`Client::with_transport(config, transport)`](Client::with_transport)
149/// 2. Call [`connect()`](Client::connect) to spawn the CLI and read the init message
150/// 3. Use [`send()`](Client::send) to send prompts and stream responses
151/// 4. Call [`close()`](Client::close) to shut down cleanly
152pub struct Client {
153    config: ClientConfig,
154    transport: Arc<dyn Transport>,
155    session_id: Option<String>,
156    message_rx: Option<flume::Receiver<Result<Message>>>,
157    shutdown_tx: Option<oneshot::Sender<()>>,
158    turn_active: Arc<AtomicBool>,
159    /// Pending outbound control requests awaiting responses.
160    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
161    /// Counter for generating unique request IDs.
162    request_counter: Arc<AtomicU64>,
163}
164
165impl std::fmt::Debug for Client {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        f.debug_struct("Client")
168            .field("session_id", &self.session_id)
169            .field("connected", &self.is_connected())
170            .finish_non_exhaustive()
171    }
172}
173
174impl Client {
175    /// Create a new client with the given configuration.
176    ///
177    /// This does NOT start the CLI — call [`connect()`](Client::connect) next.
178    /// Validates the config (e.g., `cwd` existence) and discovers the CLI binary.
179    ///
180    /// # Errors
181    ///
182    /// Returns [`Error::CliNotFound`] if the CLI binary cannot be discovered,
183    /// or [`Error::Config`] if the configuration is invalid.
184    pub fn new(config: ClientConfig) -> Result<Self> {
185        config.validate()?;
186        let transport = Arc::new(CliTransport::from_config(&config)?);
187        Ok(Self {
188            config,
189            transport,
190            session_id: None,
191            message_rx: None,
192            shutdown_tx: None,
193            turn_active: Arc::new(AtomicBool::new(false)),
194            pending_control: Arc::new(DashMap::new()),
195            request_counter: Arc::new(AtomicU64::new(0)),
196        })
197    }
198
199    /// Create a client with a custom transport (useful for testing).
200    pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
201        config.validate()?;
202        Ok(Self {
203            config,
204            transport,
205            session_id: None,
206            message_rx: None,
207            shutdown_tx: None,
208            turn_active: Arc::new(AtomicBool::new(false)),
209            pending_control: Arc::new(DashMap::new()),
210            request_counter: Arc::new(AtomicU64::new(0)),
211        })
212    }
213
214    /// Connect to the CLI and return the session info from the init message.
215    ///
216    /// This spawns the CLI process (or connects to the mock transport) and
217    /// starts the background reader task. The entire connect sequence
218    /// (transport connect + init message read) is subject to `connect_timeout`.
219    pub async fn connect(&mut self) -> Result<SessionInfo> {
220        let timeout = self.config.connect_timeout;
221        let result = match timeout {
222            Some(d) => tokio::time::timeout(d, self.connect_inner())
223                .await
224                .map_err(|_| {
225                    Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
226                })?,
227            None => self.connect_inner().await,
228        };
229        if result.is_err() {
230            // Clean up: stop background task + kill CLI process.
231            if let Some(tx) = self.shutdown_tx.take() {
232                let _ = tx.send(());
233            }
234            self.message_rx.take();
235            let _ = self.transport.close().await;
236        }
237        result
238    }
239
240    async fn connect_inner(&mut self) -> Result<SessionInfo> {
241        self.transport.connect().await?;
242
243        // Set up the message routing pipeline.
244        let (msg_tx, msg_rx) = flume::bounded(1024);
245        let (shutdown_tx, shutdown_rx) = oneshot::channel();
246
247        let transport = Arc::clone(&self.transport);
248        let message_callback = self.config.message_callback.clone();
249        let pending_control = Arc::clone(&self.pending_control);
250
251        // Move hooks into the background task for hook dispatch.
252        let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
253        let default_hook_timeout = self.config.default_hook_timeout;
254        let hook_transport = Arc::clone(&self.transport);
255
256        // Capture permission callback for the background task.
257        let can_use_tool = self.config.can_use_tool.clone();
258        let perm_transport = Arc::clone(&self.transport);
259
260        // Cancellation token for cooperative consumer-side abort.
261        let cancel_token = self.config.cancellation_token.clone();
262
263        // Shared session_id for the background task.
264        // Updated after the init message is parsed so subsequent hook dispatches
265        // carry the real session ID rather than None.
266        let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
267            Arc::new(std::sync::Mutex::new(None));
268        let hook_session_id = Arc::clone(&shared_session_id);
269
270        // Spawn background reader task.
271        tokio::spawn(async move {
272            let mut stream = transport.read_messages();
273            let mut shutdown = shutdown_rx;
274
275            loop {
276                tokio::select! {
277                    biased;
278                    _ = &mut shutdown => break,
279                    _ = cancelled_or_pending(cancel_token.as_ref()) => break,
280                    item = stream.next() => {
281                        match item {
282                            Some(Ok(value)) => {
283                                // Route control_response messages to pending senders.
284                                // The request_id may be at the top level or nested
285                                // under response.request_id.
286                                if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
287                                    let req_id = value.get("request_id")
288                                        .and_then(|v| v.as_str())
289                                        .or_else(|| value.pointer("/response/request_id")
290                                            .and_then(|v| v.as_str()));
291                                    if let Some(rid) = req_id {
292                                        if let Some((_, tx)) = pending_control.remove(rid) {
293                                            let _ = tx.send(value);
294                                        }
295                                    }
296                                    continue;
297                                }
298
299                                // Route hook_request messages to registered hooks.
300                                if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
301                                    if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
302                                        let sid = hook_session_id
303                                            .lock()
304                                            .expect("session_id lock")
305                                            .clone();
306                                        let output = crate::hooks::dispatch_hook(
307                                            &req,
308                                            &hooks,
309                                            default_hook_timeout,
310                                            sid,
311                                        ).await;
312                                        let response = crate::hooks::HookResponse::from_output(
313                                            req.request_id,
314                                            output,
315                                        );
316                                        if let Ok(json) = serde_json::to_string(&response) {
317                                            let _ = hook_transport.write(&json).await;
318                                        }
319                                    }
320                                    continue;
321                                }
322
323                                // Route control_request messages (permissions, hooks).
324                                //
325                                // The CLI sends: {"type": "control_request", "request_id": "...",
326                                //   "request": {"subtype": "can_use_tool"|"hook_callback", ...}}
327                                // We respond: {"type": "control_response", "response":
328                                //   {"subtype": "success"|"error", "request_id": "...", "response": {...}}}
329                                if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
330                                    let request_id = value.get("request_id")
331                                        .and_then(|v| v.as_str())
332                                        .unwrap_or("")
333                                        .to_string();
334                                    if request_id.is_empty() {
335                                        let _ = msg_tx.send(Err(Error::ControlProtocol(
336                                            "received control_request without request_id".into(),
337                                        )));
338                                        continue;
339                                    }
340                                    let subtype = value.pointer("/request/subtype")
341                                        .and_then(|v| v.as_str())
342                                        .unwrap_or("");
343
344                                    match subtype {
345                                        "can_use_tool" => {
346                                            let request = value.get("request").cloned()
347                                                .unwrap_or_default();
348                                            let tool_name = request.get("tool_name")
349                                                .and_then(|v| v.as_str())
350                                                .unwrap_or("")
351                                                .to_string();
352                                            let tool_input = request.get("input")
353                                                .cloned()
354                                                .unwrap_or(serde_json::Value::Null);
355                                            let tool_use_id = request.get("tool_use_id")
356                                                .and_then(|v| v.as_str())
357                                                .unwrap_or("")
358                                                .to_string();
359                                            let suggestions: Vec<String> = request
360                                                .get("permission_suggestions")
361                                                .and_then(|v| v.as_array())
362                                                .map(|arr| arr.iter()
363                                                    .filter_map(|v| v.as_str().map(String::from))
364                                                    .collect())
365                                                .unwrap_or_default();
366
367                                            if let Some(ref callback) = can_use_tool {
368                                                let sid = hook_session_id
369                                                    .lock()
370                                                    .expect("session_id lock")
371                                                    .clone()
372                                                    .unwrap_or_default();
373                                                let ctx = crate::permissions::PermissionContext {
374                                                    tool_use_id,
375                                                    session_id: sid,
376                                                    request_id: request_id.clone(),
377                                                    suggestions,
378                                                };
379                                                let decision = callback(&tool_name, &tool_input, ctx).await;
380
381                                                // Build response matching Python SDK format.
382                                                let response_data = match decision {
383                                                    crate::permissions::PermissionDecision::Allow { updated_input } => {
384                                                        let input = updated_input.unwrap_or(tool_input);
385                                                        serde_json::json!({
386                                                            "behavior": "allow",
387                                                            "updatedInput": input
388                                                        })
389                                                    }
390                                                    crate::permissions::PermissionDecision::Deny { message, interrupt } => {
391                                                        let mut d = serde_json::json!({
392                                                            "behavior": "deny",
393                                                            "message": message
394                                                        });
395                                                        if interrupt {
396                                                            d["interrupt"] = serde_json::json!(true);
397                                                        }
398                                                        d
399                                                    }
400                                                };
401                                                let response =
402                                                    build_control_response_success(&request_id, response_data);
403                                                write_json_line(&*perm_transport, response).await;
404                                            } else {
405                                                let response = build_control_response_error(
406                                                    &request_id,
407                                                    "no can_use_tool callback configured",
408                                                );
409                                                write_json_line(&*perm_transport, response).await;
410                                                let _ = msg_tx.send(Err(Error::ControlProtocol(
411                                                    "received can_use_tool control_request but no \
412                                                     callback is configured"
413                                                        .into(),
414                                                )));
415                                            }
416                                        }
417                                        "hook_callback" => {
418                                            let request = value.get("request").cloned()
419                                                .unwrap_or_default();
420                                            let hook_event_str = request.get("hook_event_name")
421                                                .and_then(|v| v.as_str())
422                                                .unwrap_or("");
423                                            let hook_tool_name = request.get("tool_name")
424                                                .and_then(|v| v.as_str())
425                                                .map(String::from);
426                                            let hook_tool_input = request.get("tool_input").cloned();
427                                            let hook_tool_result = request.get("tool_result").cloned();
428                                            let hook_tool_use_id = request.get("tool_use_id")
429                                                .and_then(|v| v.as_str())
430                                                .map(String::from);
431
432                                            // Map hook event name string to HookEvent enum.
433                                            let hook_event = match hook_event_str {
434                                                "PreToolUse" => Some(crate::hooks::HookEvent::PreToolUse),
435                                                "PostToolUse" => Some(crate::hooks::HookEvent::PostToolUse),
436                                                "PostToolUseFailure" => Some(crate::hooks::HookEvent::PostToolUseFailure),
437                                                "UserPromptSubmit" => Some(crate::hooks::HookEvent::UserPromptSubmit),
438                                                "Stop" => Some(crate::hooks::HookEvent::Stop),
439                                                "SubagentStop" => Some(crate::hooks::HookEvent::SubagentStop),
440                                                "PreCompact" => Some(crate::hooks::HookEvent::PreCompact),
441                                                "Notification" => Some(crate::hooks::HookEvent::Notification),
442                                                _ => None,
443                                            };
444
445                                            if let Some(event) = hook_event {
446                                                let req = crate::hooks::HookRequest {
447                                                    request_id: request_id.clone(),
448                                                    hook_event: event,
449                                                    tool_name: hook_tool_name,
450                                                    tool_input: hook_tool_input,
451                                                    tool_result: hook_tool_result,
452                                                    tool_use_id: hook_tool_use_id,
453                                                };
454                                                let sid = hook_session_id
455                                                    .lock()
456                                                    .expect("session_id lock")
457                                                    .clone();
458                                                let output = crate::hooks::dispatch_hook(
459                                                    &req,
460                                                    &hooks,
461                                                    default_hook_timeout,
462                                                    sid,
463                                                ).await;
464
465                                                // Convert HookOutput to control_response format.
466                                                let response_data = match output.decision {
467                                                    crate::hooks::HookDecision::Allow => {
468                                                        serde_json::json!({"continue_": true})
469                                                    }
470                                                    crate::hooks::HookDecision::Block => {
471                                                        serde_json::json!({
472                                                            "continue_": false,
473                                                            "reason": output.reason.unwrap_or_default()
474                                                        })
475                                                    }
476                                                    crate::hooks::HookDecision::Modify => {
477                                                        let mut d = serde_json::json!({"continue_": true});
478                                                        if let Some(input) = output.updated_input {
479                                                            d["updatedInput"] = input;
480                                                        }
481                                                        d
482                                                    }
483                                                    crate::hooks::HookDecision::Abort => {
484                                                        serde_json::json!({
485                                                            "continue_": false,
486                                                            "reason": output.reason.unwrap_or_default()
487                                                        })
488                                                    }
489                                                };
490                                                let response =
491                                                    build_control_response_success(&request_id, response_data);
492                                                write_json_line(&*hook_transport, response).await;
493                                            } else {
494                                                // Unknown hook event — respond with success/continue
495                                                // to avoid blocking the CLI.
496                                                let response = build_control_response_success(
497                                                    &request_id,
498                                                    serde_json::json!({"continue_": true}),
499                                                );
500                                                write_json_line(&*hook_transport, response).await;
501                                                let _ = msg_tx.send(Err(Error::ControlProtocol(
502                                                    format!(
503                                                        "received unknown hook_event_name: {hook_event_str}"
504                                                    ),
505                                                )));
506                                            }
507                                        }
508                                        _ => {
509                                            // Unknown subtype — respond with error.
510                                            let response = build_control_response_error(
511                                                &request_id,
512                                                format!("unknown control_request subtype: {subtype}"),
513                                            );
514                                            write_json_line(&*perm_transport, response).await;
515                                        }
516                                    }
517                                    continue;
518                                }
519
520                                // Legacy: route permission_request messages (pre-control_request format).
521                                if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
522                                    if let Some(ref callback) = can_use_tool {
523                                        if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
524                                            let crate::permissions::ControlRequestData::PermissionRequest {
525                                                ref tool_name,
526                                                ref tool_input,
527                                                ref tool_use_id,
528                                                ref suggestions,
529                                            } = req.request;
530                                            let sid = hook_session_id
531                                                .lock()
532                                                .expect("session_id lock")
533                                                .clone()
534                                                .unwrap_or_default();
535                                            let ctx = crate::permissions::PermissionContext {
536                                                tool_use_id: tool_use_id.clone(),
537                                                session_id: sid,
538                                                request_id: req.request_id.clone(),
539                                                suggestions: suggestions.clone(),
540                                            };
541                                            let decision = callback(tool_name, tool_input, ctx).await;
542                                            let response = crate::permissions::ControlResponse {
543                                                kind: "permission_response".into(),
544                                                request_id: req.request_id,
545                                                result: crate::permissions::ControlResponseResult::from(decision),
546                                            };
547                                            if let Ok(json) = serde_json::to_string(&response) {
548                                                let _ = perm_transport.write(&json).await;
549                                            }
550                                        }
551                                    } else {
552                                        let deny_response = serde_json::json!({
553                                            "kind": "permission_response",
554                                            "request_id": value.get("request_id")
555                                                .and_then(|v| v.as_str())
556                                                .unwrap_or(""),
557                                            "result": {
558                                                "type": "deny",
559                                                "message": "no permission callback configured"
560                                            }
561                                        });
562                                        if let Ok(json) = serde_json::to_string(&deny_response) {
563                                            let _ = perm_transport.write(&json).await;
564                                        }
565                                        let _ = msg_tx.send(Err(Error::ControlProtocol(
566                                            "received permission_request but no can_use_tool \
567                                             callback is configured"
568                                                .into(),
569                                        )));
570                                    }
571                                    continue;
572                                }
573
574                                // Parse the JSON value into a Message.
575                                let msg: Message = match serde_json::from_value(value) {
576                                    Ok(m) => m,
577                                    Err(e) => {
578                                        let _ = msg_tx.send(Err(Error::Json(e)));
579                                        continue;
580                                    }
581                                };
582
583                                // Apply the message callback.
584                                let msg = match apply_callback(msg, message_callback.as_ref()) {
585                                    Some(m) => m,
586                                    None => continue, // Filtered out
587                                };
588
589                                if msg_tx.send(Ok(msg)).is_err() {
590                                    break; // Receiver dropped
591                                }
592                            }
593                            Some(Err(e)) => {
594                                let _ = msg_tx.send(Err(e));
595                            }
596                            None => break, // Stream ended
597                        }
598                    }
599                }
600            }
601        });
602
603        self.message_rx = Some(msg_rx);
604        self.shutdown_tx = Some(shutdown_tx);
605
606        // If an init trigger message is configured (e.g., for --input-format
607        // stream-json mode), write it to stdin now. The CLI won't emit the
608        // system/init message until it receives stdin input in this mode.
609        if let Some(ref msg) = self.config.init_stdin_message {
610            self.transport.write(msg).await?;
611        }
612
613        // Wait for the system/init message.
614        // The overall connect_timeout is enforced by the caller, so we wait
615        // indefinitely here (the outer timeout will cancel us if needed).
616        let init_msg = self
617            .message_rx
618            .as_ref()
619            .unwrap()
620            .recv_async()
621            .await
622            .map_err(|_| Error::Transport("connection closed before init message".into()))?
623            .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
624
625        if let Message::System(ref sys) = init_msg {
626            let info = SessionInfo::try_from(sys)?;
627            self.session_id = Some(info.session_id.clone());
628            // Propagate the session ID to the background task so hook
629            // dispatches after this point carry the real session ID.
630            *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
631            Ok(info)
632        } else {
633            Err(Error::ControlProtocol(format!(
634                "expected system/init as first message, got: {init_msg:?}"
635            )))
636        }
637    }
638
639    /// Send a text prompt and return a stream of response messages.
640    ///
641    /// The stream yields messages until a `Result` message is received
642    /// (which terminates the turn).
643    pub fn send(
644        &self,
645        prompt: impl Into<String>,
646    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
647        let prompt = prompt.into();
648        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
649        let transport = Arc::clone(&self.transport);
650
651        // Guard against concurrent turns.
652        if self
653            .turn_active
654            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
655            .is_err()
656        {
657            return Err(Error::ControlProtocol("turn already in progress".into()));
658        }
659        let turn_flag = Arc::clone(&self.turn_active);
660        let read_timeout = self.config.read_timeout;
661        let cancel = self.config.cancellation_token.clone();
662
663        Ok(async_stream::stream! {
664            // Write the prompt to stdin.
665            if let Err(e) = transport.write(&prompt).await {
666                turn_flag.store(false, Ordering::Release);
667                yield Err(e);
668                return;
669            }
670
671            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
672            tokio::pin!(inner);
673            while let Some(item) = inner.next().await {
674                yield item;
675            }
676        })
677    }
678
679    /// Send structured content blocks (text + images) and return a stream of
680    /// response messages.
681    ///
682    /// This is the multi-modal equivalent of [`send()`](Client::send). Content
683    /// is serialised as a JSON user message and written to the CLI's stdin.
684    ///
685    /// # Errors
686    ///
687    /// Returns [`Error::Config`] if `content` is empty, or [`Error::NotConnected`]
688    /// if the client is not connected.
689    pub fn send_content(
690        &self,
691        content: Vec<UserContent>,
692    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
693        if content.is_empty() {
694            return Err(Error::Config("content must not be empty".into()));
695        }
696
697        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
698        let transport = Arc::clone(&self.transport);
699
700        // Guard against concurrent turns.
701        if self
702            .turn_active
703            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
704            .is_err()
705        {
706            return Err(Error::ControlProtocol("turn already in progress".into()));
707        }
708        let turn_flag = Arc::clone(&self.turn_active);
709        let read_timeout = self.config.read_timeout;
710        let cancel = self.config.cancellation_token.clone();
711
712        Ok(async_stream::stream! {
713            // Serialize content blocks as a JSON user message.
714            let user_message = serde_json::json!({
715                "type": "user",
716                "message": {
717                    "role": "user",
718                    "content": content
719                }
720            });
721            let json = match serde_json::to_string(&user_message) {
722                Ok(j) => j,
723                Err(e) => {
724                    turn_flag.store(false, Ordering::Release);
725                    yield Err(Error::Json(e));
726                    return;
727                }
728            };
729
730            if let Err(e) = transport.write(&json).await {
731                turn_flag.store(false, Ordering::Release);
732                yield Err(e);
733                return;
734            }
735
736            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
737            tokio::pin!(inner);
738            while let Some(item) = inner.next().await {
739                yield item;
740            }
741        })
742    }
743
744    /// Return a stream of all incoming messages (without sending a prompt).
745    ///
746    /// Useful for consuming messages from a resumed session.
747    pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
748        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
749        let read_timeout = self.config.read_timeout;
750        let cancel = self.config.cancellation_token.clone();
751
752        Ok(async_stream::stream! {
753            loop {
754                match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
755                    Ok(msg) => yield Ok(msg),
756                    Err(e) if matches!(e, Error::Transport(_)) => break, // Channel closed
757                    Err(e) => {
758                        yield Err(e);
759                        break;
760                    }
761                }
762            }
763        })
764    }
765
766    /// Write raw text to the CLI's stdin without creating a response stream.
767    ///
768    /// Use this only when [`receive_messages()`] is already consuming responses.
769    /// **Do not call this while a [`send()`] turn is in progress** — doing so
770    /// interleaves writes on the same stdin handle and produces undefined
771    /// protocol behaviour.
772    pub async fn write_to_stdin(&self, text: &str) -> Result<()> {
773        debug_assert!(
774            !self.turn_active.load(Ordering::Relaxed),
775            "write_to_stdin called while a send() turn is active"
776        );
777        self.transport.write(text).await
778    }
779
780    /// Send an interrupt signal to the CLI (SIGINT).
781    pub async fn interrupt(&self) -> Result<()> {
782        self.transport.interrupt().await
783    }
784
785    /// Respond to a permission request from the CLI.
786    ///
787    /// When the CLI asks for permission to use a tool, this method sends
788    /// the decision back via the control protocol.
789    pub async fn respond_to_permission(
790        &self,
791        request_id: &str,
792        decision: crate::permissions::PermissionDecision,
793    ) -> Result<()> {
794        use crate::permissions::{ControlResponse, ControlResponseResult};
795
796        let response = ControlResponse {
797            kind: "permission_response".into(),
798            request_id: request_id.to_string(),
799            result: ControlResponseResult::from(decision),
800        };
801        let json = serde_json::to_string(&response).map_err(Error::Json)?;
802        self.transport.write(&json).await
803    }
804
805    // ── Dynamic control ─────────────────────────────────────────────────
806
807    /// Send a control request to the CLI and wait for the response.
808    ///
809    /// This is the low-level mechanism for dynamic mid-session control.
810    /// The request is wrapped in a `{"type": "control_request", ...}` envelope
811    /// and written to stdin. The background reader routes the matching
812    /// `control_response` back via a `oneshot` channel.
813    async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
814        let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
815        let request_id = format!("sdk_req_{counter}");
816
817        let (tx, rx) = oneshot::channel();
818        self.pending_control.insert(request_id.clone(), tx);
819
820        let envelope = serde_json::json!({
821            "type": "control_request",
822            "request_id": request_id,
823            "request": request
824        });
825        let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
826        self.transport.write(&json).await?;
827
828        let timeout = self.config.control_request_timeout;
829        match tokio::time::timeout(timeout, rx).await {
830            Ok(Ok(value)) => Ok(value),
831            Ok(Err(_)) => {
832                self.pending_control.remove(&request_id);
833                Err(Error::ControlProtocol(
834                    "control response channel closed".into(),
835                ))
836            }
837            Err(_) => {
838                self.pending_control.remove(&request_id);
839                Err(Error::Timeout(format!(
840                    "control request timed out after {}s",
841                    timeout.as_secs_f64()
842                )))
843            }
844        }
845    }
846
847    /// Dynamically change the model used for subsequent turns.
848    ///
849    /// Pass `None` to revert to the session's default model.
850    ///
851    /// # Errors
852    ///
853    /// Returns an error if the CLI rejects the model change or the control
854    /// protocol fails.
855    pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
856        self.send_control_request(serde_json::json!({
857            "subtype": "set_model",
858            "model": model
859        }))
860        .await?;
861        Ok(())
862    }
863
864    /// Dynamically change the permission mode for the current session.
865    ///
866    /// # Errors
867    ///
868    /// Returns an error if the CLI rejects the mode change or the control
869    /// protocol fails.
870    pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
871        self.send_control_request(serde_json::json!({
872            "subtype": "set_permission_mode",
873            "mode": mode.as_cli_flag()
874        }))
875        .await?;
876        Ok(())
877    }
878
879    /// Write raw data to the transport's stdin.
880    ///
881    /// This is a low-level method used by free functions like
882    /// [`query_stream_with_content`](crate::query_stream_with_content).
883    pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
884        self.transport.write(data).await
885    }
886
887    /// Take ownership of the message receiver (for use in `query_stream`).
888    ///
889    /// After calling this, `receive_messages()` and `send()` will return
890    /// `NotConnected`.
891    pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
892        self.message_rx.take()
893    }
894
895    /// Returns the configured read timeout.
896    #[must_use]
897    pub fn read_timeout(&self) -> Option<Duration> {
898        self.config.read_timeout
899    }
900
901    /// Returns the session ID if connected.
902    #[must_use]
903    pub fn session_id(&self) -> Option<&str> {
904        self.session_id.as_deref()
905    }
906
907    /// Returns `true` if the client is connected.
908    #[must_use]
909    pub fn is_connected(&self) -> bool {
910        self.transport.is_ready()
911    }
912
913    /// Close the client and shut down the CLI process.
914    ///
915    /// Returns the CLI process exit code if available. After calling this,
916    /// the `Drop` warning will not fire.
917    pub async fn close(&mut self) -> Result<Option<i32>> {
918        // Signal the background task to stop.
919        if let Some(tx) = self.shutdown_tx.take() {
920            let _ = tx.send(());
921        }
922        // Drop the message receiver so the Drop impl knows we cleaned up.
923        self.message_rx.take();
924        self.transport.close().await
925    }
926}
927
928impl Drop for Client {
929    fn drop(&mut self) {
930        if self.shutdown_tx.is_some() || self.message_rx.is_some() {
931            tracing::warn!(
932                "claude_cli_sdk::Client dropped without calling close(). \
933                 Resources may not be cleaned up properly."
934            );
935        }
936    }
937}
938
939// ── Tests ────────────────────────────────────────────────────────────────────
940
941#[cfg(test)]
942mod tests {
943    use super::*;
944    use crate::config::ClientConfig;
945
946    #[cfg(feature = "testing")]
947    use crate::testing::{ScenarioBuilder, assistant_text};
948
949    fn test_config() -> ClientConfig {
950        ClientConfig::builder().prompt("test").build()
951    }
952
953    #[cfg(feature = "testing")]
954    #[tokio::test]
955    async fn client_connect_and_receive_init() {
956        let transport = ScenarioBuilder::new("test-session")
957            .exchange(vec![assistant_text("Hello!")])
958            .build();
959        let transport = Arc::new(transport);
960
961        let mut client = Client::with_transport(test_config(), transport).unwrap();
962        let info = client.connect().await.unwrap();
963
964        assert_eq!(info.session_id, "test-session");
965        assert!(client.is_connected());
966        assert_eq!(client.session_id(), Some("test-session"));
967    }
968
969    #[cfg(feature = "testing")]
970    #[tokio::test]
971    async fn client_send_yields_messages() {
972        let transport = ScenarioBuilder::new("s1")
973            .exchange(vec![assistant_text("response")])
974            .build();
975        let transport = Arc::new(transport);
976
977        let mut client = Client::with_transport(test_config(), transport).unwrap();
978        client.connect().await.unwrap();
979
980        let stream = client.send("hello").unwrap();
981        tokio::pin!(stream);
982
983        let mut messages = Vec::new();
984        while let Some(msg) = stream.next().await {
985            messages.push(msg.unwrap());
986        }
987
988        // Should get assistant + result
989        assert_eq!(messages.len(), 2);
990        assert!(matches!(&messages[0], Message::Assistant(_)));
991        assert!(matches!(&messages[1], Message::Result(_)));
992    }
993
994    #[cfg(feature = "testing")]
995    #[tokio::test]
996    async fn client_close_succeeds() {
997        let transport = ScenarioBuilder::new("s1").build();
998        let transport = Arc::new(transport);
999
1000        let mut client = Client::with_transport(test_config(), transport).unwrap();
1001        client.connect().await.unwrap();
1002        assert!(client.close().await.is_ok());
1003    }
1004
1005    #[cfg(feature = "testing")]
1006    #[tokio::test]
1007    async fn client_message_callback_filters() {
1008        use crate::callback::MessageCallback;
1009
1010        // Filter out assistant messages.
1011        let callback: MessageCallback = Arc::new(|msg| match &msg {
1012            Message::Assistant(_) => None,
1013            _ => Some(msg),
1014        });
1015
1016        let config = ClientConfig::builder()
1017            .prompt("test")
1018            .message_callback(callback)
1019            .build();
1020
1021        let transport = ScenarioBuilder::new("s1")
1022            .exchange(vec![assistant_text("filtered")])
1023            .build();
1024        let transport = Arc::new(transport);
1025
1026        let mut client = Client::with_transport(config, transport).unwrap();
1027        client.connect().await.unwrap();
1028
1029        let stream = client.send("hello").unwrap();
1030        tokio::pin!(stream);
1031
1032        let mut messages = Vec::new();
1033        while let Some(msg) = stream.next().await {
1034            messages.push(msg.unwrap());
1035        }
1036
1037        // Only result (assistant was filtered).
1038        assert_eq!(messages.len(), 1);
1039        assert!(matches!(&messages[0], Message::Result(_)));
1040    }
1041
1042    #[cfg(feature = "testing")]
1043    #[test]
1044    fn client_debug_before_connect() {
1045        let transport = Arc::new(crate::testing::MockTransport::new());
1046        let client = Client::with_transport(test_config(), transport).unwrap();
1047        let debug = format!("{client:?}");
1048        assert!(debug.contains("Client"));
1049    }
1050
1051    // ── Timeout tests ────────────────────────────────────────────────────
1052
1053    #[cfg(feature = "testing")]
1054    #[tokio::test]
1055    async fn client_connect_timeout_fires() {
1056        use crate::testing::MockTransport;
1057
1058        let transport = MockTransport::new();
1059        // Set connect delay longer than timeout.
1060        transport.set_connect_delay(Duration::from_secs(5));
1061        let transport = Arc::new(transport);
1062
1063        let config = ClientConfig::builder()
1064            .prompt("test")
1065            .connect_timeout(Some(Duration::from_millis(50)))
1066            .build();
1067
1068        let mut client = Client::with_transport(config, transport).unwrap();
1069        let result = client.connect().await;
1070        assert!(result.is_err());
1071        assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
1072    }
1073
1074    #[cfg(feature = "testing")]
1075    #[tokio::test]
1076    async fn client_read_timeout_fires() {
1077        // Build a scenario with init + assistant, but add a large recv_delay
1078        // so the assistant message arrives after the read timeout.
1079        let transport = ScenarioBuilder::new("s1")
1080            .exchange(vec![assistant_text("delayed")])
1081            .build();
1082        // Delay each message by 5 seconds — way longer than our 50ms timeout.
1083        // The init message also gets this delay, but the connect has no timeout
1084        // wrapping for the recv path since connect_timeout is set to None here
1085        // and connect_inner waits indefinitely for init.
1086        // Actually, we need init to arrive fast but subsequent messages slow.
1087        // The MockTransport applies recv_delay to ALL messages including init.
1088        // So set connect_timeout high enough and read_timeout low.
1089        transport.set_recv_delay(Duration::from_millis(200));
1090        let transport = Arc::new(transport);
1091
1092        let config = ClientConfig::builder()
1093            .prompt("test")
1094            .connect_timeout(Some(Duration::from_secs(5)))
1095            .read_timeout(Some(Duration::from_millis(50)))
1096            .build();
1097
1098        let mut client = Client::with_transport(config, transport).unwrap();
1099        client.connect().await.unwrap();
1100
1101        let stream = client.send("hello").unwrap();
1102        tokio::pin!(stream);
1103
1104        let mut got_timeout = false;
1105        while let Some(msg) = stream.next().await {
1106            if let Err(Error::Timeout(_)) = msg {
1107                got_timeout = true;
1108                break;
1109            }
1110        }
1111        assert!(got_timeout, "expected a timeout error");
1112    }
1113
1114    #[cfg(feature = "testing")]
1115    #[tokio::test]
1116    async fn client_permission_callback_invoked_and_responds() {
1117        use crate::permissions::{CanUseToolCallback, PermissionDecision};
1118        use crate::testing::MockTransport;
1119        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
1120
1121        let invoked = Arc::new(AtomicBool::new(false));
1122        let invoked_clone = Arc::clone(&invoked);
1123
1124        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
1125            let invoked = Arc::clone(&invoked_clone);
1126            let tool = tool_name.to_owned();
1127            Box::pin(async move {
1128                invoked.store(true, AtomicOrdering::Release);
1129                assert_eq!(tool, "Bash");
1130                PermissionDecision::allow()
1131            })
1132        });
1133
1134        let config = ClientConfig::builder()
1135            .prompt("test")
1136            .can_use_tool(callback)
1137            .build();
1138
1139        let transport = MockTransport::new();
1140        // Enqueue: init, permission_request, assistant, result
1141        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1142        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
1143        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
1144        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1145        let transport = Arc::new(transport);
1146
1147        let mut client = Client::with_transport(config, transport.clone()).unwrap();
1148        client.connect().await.unwrap();
1149
1150        let stream = client.send("hello").unwrap();
1151        tokio::pin!(stream);
1152        let mut messages = Vec::new();
1153        while let Some(msg) = stream.next().await {
1154            messages.push(msg.unwrap());
1155        }
1156
1157        // Callback should have been invoked.
1158        assert!(
1159            invoked.load(AtomicOrdering::Acquire),
1160            "permission callback was not invoked"
1161        );
1162
1163        // Permission request should NOT leak as a Message — only assistant + result.
1164        assert_eq!(messages.len(), 2);
1165        assert!(matches!(&messages[0], Message::Assistant(_)));
1166        assert!(matches!(&messages[1], Message::Result(_)));
1167
1168        // Verify a permission_response was written back.
1169        let written = transport.written_lines();
1170        let perm_responses: Vec<_> = written
1171            .iter()
1172            .filter(|line| line.contains("permission_response"))
1173            .collect();
1174        assert_eq!(
1175            perm_responses.len(),
1176            1,
1177            "expected exactly one permission_response written"
1178        );
1179        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
1180        assert_eq!(resp["kind"], "permission_response");
1181        assert_eq!(resp["request_id"], "perm-1");
1182        assert_eq!(resp["result"]["type"], "allow");
1183    }
1184
1185    #[cfg(feature = "testing")]
1186    #[tokio::test]
1187    async fn client_permission_callback_deny_writes_deny_response() {
1188        use crate::permissions::{CanUseToolCallback, PermissionDecision};
1189        use crate::testing::MockTransport;
1190
1191        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
1192            Box::pin(async { PermissionDecision::deny("not allowed") })
1193        });
1194
1195        let config = ClientConfig::builder()
1196            .prompt("test")
1197            .can_use_tool(callback)
1198            .build();
1199
1200        let transport = MockTransport::new();
1201        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1202        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
1203        transport
1204            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
1205        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1206        let transport = Arc::new(transport);
1207
1208        let mut client = Client::with_transport(config, transport.clone()).unwrap();
1209        client.connect().await.unwrap();
1210
1211        let stream = client.send("hello").unwrap();
1212        tokio::pin!(stream);
1213        let mut messages = Vec::new();
1214        while let Some(msg) = stream.next().await {
1215            messages.push(msg.unwrap());
1216        }
1217
1218        // Verify deny response was written.
1219        let written = transport.written_lines();
1220        let perm_responses: Vec<_> = written
1221            .iter()
1222            .filter(|line| line.contains("permission_response"))
1223            .collect();
1224        assert_eq!(perm_responses.len(), 1);
1225        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
1226        assert_eq!(resp["kind"], "permission_response");
1227        assert_eq!(resp["request_id"], "perm-2");
1228        assert_eq!(resp["result"]["type"], "deny");
1229        assert_eq!(resp["result"]["message"], "not allowed");
1230    }
1231
1232    #[cfg(feature = "testing")]
1233    #[tokio::test]
1234    async fn client_permission_request_without_callback_yields_error() {
1235        use crate::testing::MockTransport;
1236
1237        // No can_use_tool callback configured.
1238        let config = ClientConfig::builder().prompt("test").build();
1239
1240        let transport = MockTransport::new();
1241        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1242        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
1243        transport
1244            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
1245        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1246        let transport = Arc::new(transport);
1247
1248        let mut client = Client::with_transport(config, transport).unwrap();
1249        client.connect().await.unwrap();
1250
1251        let stream = client.send("hello").unwrap();
1252        tokio::pin!(stream);
1253
1254        let mut got_error = false;
1255        let mut messages = Vec::new();
1256        while let Some(result) = stream.next().await {
1257            match result {
1258                Ok(msg) => messages.push(msg),
1259                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
1260                    got_error = true;
1261                }
1262                Err(e) => panic!("unexpected error: {e}"),
1263            }
1264        }
1265
1266        assert!(
1267            got_error,
1268            "should have received a ControlProtocol error for missing callback"
1269        );
1270    }
1271
1272    #[tokio::test]
1273    async fn recv_with_timeout_respects_cancellation_token() {
1274        let (tx, rx) = flume::unbounded::<Result<Message>>();
1275        let token = CancellationToken::new();
1276
1277        // Cancel immediately.
1278        token.cancel();
1279
1280        let result = recv_with_timeout(&rx, None, Some(&token)).await;
1281        assert!(result.is_err());
1282        assert!(result.unwrap_err().is_cancelled());
1283
1284        // Sender is still alive — we didn't get a transport error.
1285        drop(tx);
1286    }
1287
1288    #[tokio::test]
1289    async fn recv_with_timeout_none_cancel_still_works() {
1290        let (_tx, rx) = flume::unbounded::<Result<Message>>();
1291
1292        // With no cancel token and a short timeout, we should get a timeout error.
1293        let result = recv_with_timeout(&rx, Some(Duration::from_millis(10)), None).await;
1294        assert!(matches!(result, Err(Error::Timeout(_))));
1295    }
1296
1297    // ── control_request (new wire format) tests ──────────────────────
1298
1299    #[cfg(feature = "testing")]
1300    #[tokio::test]
1301    async fn client_control_request_can_use_tool_allow() {
1302        use crate::permissions::{CanUseToolCallback, PermissionDecision};
1303        use crate::testing::MockTransport;
1304        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
1305
1306        let invoked = Arc::new(AtomicBool::new(false));
1307        let invoked_clone = Arc::clone(&invoked);
1308
1309        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
1310            let invoked = Arc::clone(&invoked_clone);
1311            let tool = tool_name.to_owned();
1312            Box::pin(async move {
1313                invoked.store(true, AtomicOrdering::Release);
1314                assert_eq!(tool, "Bash");
1315                PermissionDecision::allow()
1316            })
1317        });
1318
1319        let config = ClientConfig::builder()
1320            .prompt("test")
1321            .can_use_tool(callback)
1322            .build();
1323
1324        let transport = MockTransport::new();
1325        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1326        // control_request with can_use_tool subtype (actual CLI wire format).
1327        transport.enqueue(r#"{"type":"control_request","request_id":"cr-1","request":{"subtype":"can_use_tool","tool_name":"Bash","input":{"command":"ls"},"tool_use_id":"tu-1","permission_suggestions":["allow_once"]}}"#);
1328        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
1329        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1330        let transport = Arc::new(transport);
1331
1332        let mut client = Client::with_transport(config, transport.clone()).unwrap();
1333        client.connect().await.unwrap();
1334
1335        let stream = client.send("hello").unwrap();
1336        tokio::pin!(stream);
1337        let mut messages = Vec::new();
1338        while let Some(msg) = stream.next().await {
1339            messages.push(msg.unwrap());
1340        }
1341
1342        assert!(
1343            invoked.load(AtomicOrdering::Acquire),
1344            "permission callback was not invoked"
1345        );
1346
1347        // control_request should NOT leak as a Message — only assistant + result.
1348        assert_eq!(messages.len(), 2);
1349        assert!(matches!(&messages[0], Message::Assistant(_)));
1350        assert!(matches!(&messages[1], Message::Result(_)));
1351
1352        // Verify a control_response was written back with correct format.
1353        let written = transport.written_lines();
1354        let responses: Vec<_> = written
1355            .iter()
1356            .filter(|line| line.contains("control_response"))
1357            .collect();
1358        assert_eq!(responses.len(), 1, "expected exactly one control_response");
1359        let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1360        assert_eq!(resp["type"], "control_response");
1361        assert_eq!(resp["response"]["subtype"], "success");
1362        assert_eq!(resp["response"]["request_id"], "cr-1");
1363        assert_eq!(resp["response"]["response"]["behavior"], "allow");
1364    }
1365
1366    #[cfg(feature = "testing")]
1367    #[tokio::test]
1368    async fn client_control_request_can_use_tool_deny() {
1369        use crate::permissions::{CanUseToolCallback, PermissionDecision};
1370        use crate::testing::MockTransport;
1371
1372        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
1373            Box::pin(async { PermissionDecision::deny("forbidden") })
1374        });
1375
1376        let config = ClientConfig::builder()
1377            .prompt("test")
1378            .can_use_tool(callback)
1379            .build();
1380
1381        let transport = MockTransport::new();
1382        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1383        transport.enqueue(r#"{"type":"control_request","request_id":"cr-2","request":{"subtype":"can_use_tool","tool_name":"Write","input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","permission_suggestions":[]}}"#);
1384        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
1385        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1386        let transport = Arc::new(transport);
1387
1388        let mut client = Client::with_transport(config, transport.clone()).unwrap();
1389        client.connect().await.unwrap();
1390
1391        let stream = client.send("hello").unwrap();
1392        tokio::pin!(stream);
1393        let mut messages = Vec::new();
1394        while let Some(msg) = stream.next().await {
1395            messages.push(msg.unwrap());
1396        }
1397
1398        let written = transport.written_lines();
1399        let responses: Vec<_> = written
1400            .iter()
1401            .filter(|line| line.contains("control_response"))
1402            .collect();
1403        assert_eq!(responses.len(), 1);
1404        let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1405        assert_eq!(resp["type"], "control_response");
1406        assert_eq!(resp["response"]["response"]["behavior"], "deny");
1407        assert_eq!(resp["response"]["response"]["message"], "forbidden");
1408    }
1409
1410    #[cfg(feature = "testing")]
1411    #[tokio::test]
1412    async fn client_control_request_no_callback_yields_error() {
1413        use crate::testing::MockTransport;
1414
1415        // No can_use_tool callback configured.
1416        let config = ClientConfig::builder().prompt("test").build();
1417
1418        let transport = MockTransport::new();
1419        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1420        transport.enqueue(r#"{"type":"control_request","request_id":"cr-3","request":{"subtype":"can_use_tool","tool_name":"Bash","input":{"command":"ls"},"tool_use_id":"tu-3","permission_suggestions":[]}}"#);
1421        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
1422        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1423        let transport = Arc::new(transport);
1424
1425        let mut client = Client::with_transport(config, transport.clone()).unwrap();
1426        client.connect().await.unwrap();
1427
1428        let stream = client.send("hello").unwrap();
1429        tokio::pin!(stream);
1430
1431        let mut got_error = false;
1432        let mut messages = Vec::new();
1433        while let Some(result) = stream.next().await {
1434            match result {
1435                Ok(msg) => messages.push(msg),
1436                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
1437                    got_error = true;
1438                }
1439                Err(e) => panic!("unexpected error: {e}"),
1440            }
1441        }
1442
1443        assert!(
1444            got_error,
1445            "should have received a ControlProtocol error for missing callback"
1446        );
1447
1448        // Verify error control_response was written.
1449        let written = transport.written_lines();
1450        let responses: Vec<_> = written
1451            .iter()
1452            .filter(|line| line.contains("control_response"))
1453            .collect();
1454        assert_eq!(responses.len(), 1);
1455        let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1456        assert_eq!(resp["response"]["subtype"], "error");
1457    }
1458
1459    #[cfg(feature = "testing")]
1460    #[tokio::test]
1461    async fn client_control_request_hook_callback() {
1462        use crate::hooks::{HookCallback, HookEvent, HookMatcher, HookOutput};
1463        use crate::testing::MockTransport;
1464
1465        // Register a PreToolUse hook that allows.
1466        let callback: HookCallback =
1467            Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
1468
1469        let config = ClientConfig::builder()
1470            .prompt("test")
1471            .hooks(vec![HookMatcher::new(HookEvent::PreToolUse, callback)])
1472            .build();
1473
1474        let transport = MockTransport::new();
1475        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1476        transport.enqueue(r#"{"type":"control_request","request_id":"hook-1","request":{"subtype":"hook_callback","hook_event_name":"PreToolUse","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-h1"}}"#);
1477        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("hooked")).unwrap());
1478        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1479        let transport = Arc::new(transport);
1480
1481        let mut client = Client::with_transport(config, transport.clone()).unwrap();
1482        client.connect().await.unwrap();
1483
1484        let stream = client.send("hello").unwrap();
1485        tokio::pin!(stream);
1486        let mut messages = Vec::new();
1487        while let Some(msg) = stream.next().await {
1488            messages.push(msg.unwrap());
1489        }
1490
1491        // Hook control_request should not leak as a Message.
1492        assert_eq!(messages.len(), 2);
1493        assert!(matches!(&messages[0], Message::Assistant(_)));
1494        assert!(matches!(&messages[1], Message::Result(_)));
1495
1496        // Verify control_response for hook was written.
1497        let written = transport.written_lines();
1498        let responses: Vec<_> = written
1499            .iter()
1500            .filter(|line| line.contains("control_response"))
1501            .collect();
1502        assert_eq!(responses.len(), 1);
1503        let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1504        assert_eq!(resp["type"], "control_response");
1505        assert_eq!(resp["response"]["subtype"], "success");
1506        assert_eq!(resp["response"]["request_id"], "hook-1");
1507        // Hook Allow → continue_: true
1508        assert_eq!(resp["response"]["response"]["continue_"], true);
1509    }
1510
1511    #[cfg(feature = "testing")]
1512    #[tokio::test]
1513    async fn client_read_timeout_none_waits() {
1514        // MockTransport with recv_delay < reasonable wait, read_timeout None.
1515        let transport = ScenarioBuilder::new("s1")
1516            .exchange(vec![assistant_text("delayed")])
1517            .build();
1518        transport.set_recv_delay(Duration::from_millis(50));
1519        let transport = Arc::new(transport);
1520
1521        let config = ClientConfig::builder()
1522            .prompt("test")
1523            .read_timeout(None)
1524            .build();
1525
1526        let mut client = Client::with_transport(config, transport).unwrap();
1527        client.connect().await.unwrap();
1528
1529        let stream = client.send("hello").unwrap();
1530        tokio::pin!(stream);
1531
1532        let mut messages = Vec::new();
1533        while let Some(msg) = stream.next().await {
1534            messages.push(msg.unwrap());
1535        }
1536
1537        // Should get assistant + result even with delay since no timeout.
1538        assert_eq!(messages.len(), 2);
1539    }
1540}