Skip to main content

claude_cli_sdk/
client.rs

1//! The core `Client` struct — multi-turn, bidirectional Claude Code sessions.
2//!
3//! # Architecture
4//!
5//! On [`connect()`](Client::connect), the client spawns a background task that:
6//! 1. Reads JSON values from the transport
7//! 2. Routes permission requests to the configured callback
8//! 3. Routes hook requests to registered hook matchers
9//! 4. Applies the message callback (if any)
10//! 5. Forwards resulting messages through a `flume` channel
11//!
12//! Callers consume messages via [`send()`](Client::send) (which returns a
13//! stream), or via [`receive_messages()`](Client::receive_messages).
14
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23
24use crate::callback::apply_callback;
25use crate::config::{ClientConfig, PermissionMode};
26use crate::errors::{Error, Result};
27use crate::transport::{CliTransport, Transport};
28use crate::types::content::UserContent;
29use crate::types::messages::{Message, SessionInfo};
30
31// ── Timeout helper ───────────────────────────────────────────────────────────
32
33/// Receive a message from a flume channel with an optional timeout.
34///
35/// Returns `Error::Timeout` if the deadline expires, or `Error::Transport` if
36/// the channel is closed.
37pub(crate) async fn recv_with_timeout(
38    rx: &flume::Receiver<Result<Message>>,
39    timeout: Option<Duration>,
40) -> Result<Message> {
41    let recv_fut = rx.recv_async();
42    match timeout {
43        Some(d) => tokio::time::timeout(d, recv_fut)
44            .await
45            .map_err(|_| Error::Timeout(format!("read timed out after {}s", d.as_secs_f64())))?
46            .map_err(|_| Error::Transport("message channel closed".into()))?,
47        None => recv_fut
48            .await
49            .map_err(|_| Error::Transport("message channel closed".into()))?,
50    }
51}
52
53// ── Shared turn stream helper ─────────────────────────────────────────────────
54
55/// Read messages from the receiver until a `Result` message or error,
56/// then clear the turn flag.
57fn read_turn_stream<'a>(
58    rx: &'a flume::Receiver<Result<Message>>,
59    read_timeout: Option<Duration>,
60    turn_flag: Arc<AtomicBool>,
61) -> impl Stream<Item = Result<Message>> + 'a {
62    async_stream::stream! {
63        loop {
64            match recv_with_timeout(rx, read_timeout).await {
65                Ok(msg) => {
66                    let is_result = matches!(&msg, Message::Result(_));
67                    yield Ok(msg);
68                    if is_result {
69                        break;
70                    }
71                }
72                Err(e) => {
73                    yield Err(e);
74                    break;
75                }
76            }
77        }
78        turn_flag.store(false, Ordering::Release);
79    }
80}
81
82// ── Client ───────────────────────────────────────────────────────────────────
83
/// A stateful Claude Code client that manages a persistent session.
///
/// # Lifecycle
///
/// 1. Create with [`Client::new(config)`](Client::new) or [`Client::with_transport(config, transport)`](Client::with_transport)
/// 2. Call [`connect()`](Client::connect) to spawn the CLI and read the init message
/// 3. Use [`send()`](Client::send) (or [`send_content()`](Client::send_content) for
///    multi-modal input) to send prompts and stream responses
/// 4. Call [`close()`](Client::close) to shut down cleanly
pub struct Client {
    /// Client configuration. `connect()` moves `hooks` out of it into the
    /// background reader task.
    config: ClientConfig,
    /// Transport to the CLI; also shared (via `Arc` clones) with the background
    /// reader task spawned by `connect()`.
    transport: Arc<dyn Transport>,
    /// Session ID taken from the `system/init` message; `None` until `connect()` succeeds.
    session_id: Option<String>,
    /// Receiving end of the channel fed by the background reader task. `None`
    /// before `connect()`, after `close()`, or after `take_message_rx()`.
    message_rx: Option<flume::Receiver<Result<Message>>>,
    /// Stops the background reader task when sent on (or dropped); taken by `close()`.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// `true` while a `send`/`send_content` turn is in flight; used to reject
    /// concurrent turns.
    turn_active: Arc<AtomicBool>,
    /// Pending outbound control requests awaiting responses.
    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
    /// Counter for generating unique request IDs.
    request_counter: Arc<AtomicU64>,
}
104
105impl std::fmt::Debug for Client {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        f.debug_struct("Client")
108            .field("session_id", &self.session_id)
109            .field("connected", &self.is_connected())
110            .finish_non_exhaustive()
111    }
112}
113
114impl Client {
115    /// Create a new client with the given configuration.
116    ///
117    /// This does NOT start the CLI — call [`connect()`](Client::connect) next.
118    /// Validates the config (e.g., `cwd` existence) and discovers the CLI binary.
119    ///
120    /// # Errors
121    ///
122    /// Returns [`Error::CliNotFound`] if the CLI binary cannot be discovered,
123    /// or [`Error::Config`] if the configuration is invalid.
124    pub fn new(config: ClientConfig) -> Result<Self> {
125        config.validate()?;
126        let transport = Arc::new(CliTransport::from_config(&config)?);
127        Ok(Self {
128            config,
129            transport,
130            session_id: None,
131            message_rx: None,
132            shutdown_tx: None,
133            turn_active: Arc::new(AtomicBool::new(false)),
134            pending_control: Arc::new(DashMap::new()),
135            request_counter: Arc::new(AtomicU64::new(0)),
136        })
137    }
138
139    /// Create a client with a custom transport (useful for testing).
140    pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
141        Ok(Self {
142            config,
143            transport,
144            session_id: None,
145            message_rx: None,
146            shutdown_tx: None,
147            turn_active: Arc::new(AtomicBool::new(false)),
148            pending_control: Arc::new(DashMap::new()),
149            request_counter: Arc::new(AtomicU64::new(0)),
150        })
151    }
152
153    /// Connect to the CLI and return the session info from the init message.
154    ///
155    /// This spawns the CLI process (or connects to the mock transport) and
156    /// starts the background reader task. The entire connect sequence
157    /// (transport connect + init message read) is subject to `connect_timeout`.
158    pub async fn connect(&mut self) -> Result<SessionInfo> {
159        let timeout = self.config.connect_timeout;
160        match timeout {
161            Some(d) => tokio::time::timeout(d, self.connect_inner())
162                .await
163                .map_err(|_| {
164                    Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
165                })?,
166            None => self.connect_inner().await,
167        }
168    }
169
170    async fn connect_inner(&mut self) -> Result<SessionInfo> {
171        self.transport.connect().await?;
172
173        // Set up the message routing pipeline.
174        let (msg_tx, msg_rx) = flume::unbounded();
175        let (shutdown_tx, shutdown_rx) = oneshot::channel();
176
177        let transport = Arc::clone(&self.transport);
178        let message_callback = self.config.message_callback.clone();
179        let pending_control = Arc::clone(&self.pending_control);
180
181        // Move hooks into the background task for hook dispatch.
182        let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
183        let default_hook_timeout = self.config.default_hook_timeout;
184        let hook_transport = Arc::clone(&self.transport);
185
186        // Capture permission callback for the background task.
187        let can_use_tool = self.config.can_use_tool.clone();
188        let perm_transport = Arc::clone(&self.transport);
189
190        // Shared session_id for the background task.
191        // Updated after the init message is parsed so subsequent hook dispatches
192        // carry the real session ID rather than None.
193        let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
194            Arc::new(std::sync::Mutex::new(None));
195        let hook_session_id = Arc::clone(&shared_session_id);
196
197        // Spawn background reader task.
198        tokio::spawn(async move {
199            let mut stream = transport.read_messages();
200            let mut shutdown = shutdown_rx;
201
202            loop {
203                tokio::select! {
204                    biased;
205                    _ = &mut shutdown => break,
206                    item = stream.next() => {
207                        match item {
208                            Some(Ok(value)) => {
209                                // Route control_response messages to pending senders.
210                                if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
211                                    if let Some(req_id) = value.get("request_id").and_then(|v| v.as_str()) {
212                                        if let Some((_, tx)) = pending_control.remove(req_id) {
213                                            let _ = tx.send(value);
214                                        }
215                                    }
216                                    continue;
217                                }
218
219                                // Route hook_request messages to registered hooks.
220                                if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
221                                    if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
222                                        let sid = hook_session_id
223                                            .lock()
224                                            .expect("session_id lock")
225                                            .clone();
226                                        let output = crate::hooks::dispatch_hook(
227                                            &req,
228                                            &hooks,
229                                            default_hook_timeout,
230                                            sid,
231                                        ).await;
232                                        let response = crate::hooks::HookResponse::from_output(
233                                            req.request_id,
234                                            output,
235                                        );
236                                        if let Ok(json) = serde_json::to_string(&response) {
237                                            let _ = hook_transport.write(&json).await;
238                                        }
239                                    }
240                                    continue;
241                                }
242
243                                // Route permission_request messages to the can_use_tool callback.
244                                if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
245                                    if let Some(ref callback) = can_use_tool {
246                                        if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
247                                            let crate::permissions::ControlRequestData::PermissionRequest {
248                                                ref tool_name,
249                                                ref tool_input,
250                                                ref tool_use_id,
251                                                ref suggestions,
252                                            } = req.request;
253                                            let sid = hook_session_id
254                                                .lock()
255                                                .expect("session_id lock")
256                                                .clone()
257                                                .unwrap_or_default();
258                                            let ctx = crate::permissions::PermissionContext {
259                                                tool_use_id: tool_use_id.clone(),
260                                                session_id: sid,
261                                                request_id: req.request_id.clone(),
262                                                suggestions: suggestions.clone(),
263                                            };
264                                            let decision = callback(tool_name, tool_input, ctx).await;
265                                            let response = crate::permissions::ControlResponse {
266                                                kind: "permission_response".into(),
267                                                request_id: req.request_id,
268                                                result: crate::permissions::ControlResponseResult::from(decision),
269                                            };
270                                            if let Ok(json) = serde_json::to_string(&response) {
271                                                let _ = perm_transport.write(&json).await;
272                                            }
273                                        }
274                                    } else {
275                                        // No can_use_tool callback configured. The CLI is waiting
276                                        // for a permission_response that will never arrive, which
277                                        // would cause a silent hang. Surface an error so the
278                                        // caller knows they need to set `can_use_tool` or use a
279                                        // non-default PermissionMode.
280                                        let _ = msg_tx.send(Err(Error::ControlProtocol(
281                                            "received permission_request but no can_use_tool \
282                                             callback is configured — set can_use_tool on \
283                                             ClientConfig or use a PermissionMode that does not \
284                                             require interactive approval"
285                                                .into(),
286                                        )));
287                                    }
288                                    continue;
289                                }
290
291                                // Parse the JSON value into a Message.
292                                let msg: Message = match serde_json::from_value(value) {
293                                    Ok(m) => m,
294                                    Err(e) => {
295                                        let _ = msg_tx.send(Err(Error::Json(e)));
296                                        continue;
297                                    }
298                                };
299
300                                // Apply the message callback.
301                                let msg = match apply_callback(msg, message_callback.as_ref()) {
302                                    Some(m) => m,
303                                    None => continue, // Filtered out
304                                };
305
306                                if msg_tx.send(Ok(msg)).is_err() {
307                                    break; // Receiver dropped
308                                }
309                            }
310                            Some(Err(e)) => {
311                                let _ = msg_tx.send(Err(e));
312                            }
313                            None => break, // Stream ended
314                        }
315                    }
316                }
317            }
318        });
319
320        self.message_rx = Some(msg_rx);
321        self.shutdown_tx = Some(shutdown_tx);
322
323        // Wait for the system/init message.
324        // The overall connect_timeout is enforced by the caller, so we wait
325        // indefinitely here (the outer timeout will cancel us if needed).
326        let init_msg = self
327            .message_rx
328            .as_ref()
329            .unwrap()
330            .recv_async()
331            .await
332            .map_err(|_| Error::Transport("connection closed before init message".into()))?
333            .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
334
335        if let Message::System(ref sys) = init_msg {
336            let info = SessionInfo::try_from(sys)?;
337            self.session_id = Some(info.session_id.clone());
338            // Propagate the session ID to the background task so hook
339            // dispatches after this point carry the real session ID.
340            *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
341            Ok(info)
342        } else {
343            Err(Error::ControlProtocol(format!(
344                "expected system/init as first message, got: {init_msg:?}"
345            )))
346        }
347    }
348
349    /// Send a text prompt and return a stream of response messages.
350    ///
351    /// The stream yields messages until a `Result` message is received
352    /// (which terminates the turn).
353    pub fn send(
354        &self,
355        prompt: impl Into<String>,
356    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
357        let prompt = prompt.into();
358        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
359        let transport = Arc::clone(&self.transport);
360
361        // Guard against concurrent turns.
362        if self
363            .turn_active
364            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
365            .is_err()
366        {
367            return Err(Error::ControlProtocol("turn already in progress".into()));
368        }
369        let turn_flag = Arc::clone(&self.turn_active);
370        let read_timeout = self.config.read_timeout;
371
372        Ok(async_stream::stream! {
373            // Write the prompt to stdin.
374            if let Err(e) = transport.write(&prompt).await {
375                turn_flag.store(false, Ordering::Release);
376                yield Err(e);
377                return;
378            }
379
380            let inner = read_turn_stream(rx, read_timeout, turn_flag);
381            tokio::pin!(inner);
382            while let Some(item) = inner.next().await {
383                yield item;
384            }
385        })
386    }
387
388    /// Send structured content blocks (text + images) and return a stream of
389    /// response messages.
390    ///
391    /// This is the multi-modal equivalent of [`send()`](Client::send). Content
392    /// is serialised as a JSON user message and written to the CLI's stdin.
393    ///
394    /// # Errors
395    ///
396    /// Returns [`Error::Config`] if `content` is empty, or [`Error::NotConnected`]
397    /// if the client is not connected.
398    pub fn send_content(
399        &self,
400        content: Vec<UserContent>,
401    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
402        if content.is_empty() {
403            return Err(Error::Config("content must not be empty".into()));
404        }
405
406        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
407        let transport = Arc::clone(&self.transport);
408
409        // Guard against concurrent turns.
410        if self
411            .turn_active
412            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
413            .is_err()
414        {
415            return Err(Error::ControlProtocol("turn already in progress".into()));
416        }
417        let turn_flag = Arc::clone(&self.turn_active);
418        let read_timeout = self.config.read_timeout;
419
420        Ok(async_stream::stream! {
421            // Serialize content blocks as a JSON user message.
422            let user_message = serde_json::json!({
423                "type": "user",
424                "message": {
425                    "role": "user",
426                    "content": content
427                }
428            });
429            let json = match serde_json::to_string(&user_message) {
430                Ok(j) => j,
431                Err(e) => {
432                    turn_flag.store(false, Ordering::Release);
433                    yield Err(Error::Json(e));
434                    return;
435                }
436            };
437
438            if let Err(e) = transport.write(&json).await {
439                turn_flag.store(false, Ordering::Release);
440                yield Err(e);
441                return;
442            }
443
444            let inner = read_turn_stream(rx, read_timeout, turn_flag);
445            tokio::pin!(inner);
446            while let Some(item) = inner.next().await {
447                yield item;
448            }
449        })
450    }
451
452    /// Return a stream of all incoming messages (without sending a prompt).
453    ///
454    /// Useful for consuming messages from a resumed session.
455    pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
456        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
457        let read_timeout = self.config.read_timeout;
458
459        Ok(async_stream::stream! {
460            loop {
461                match recv_with_timeout(rx, read_timeout).await {
462                    Ok(msg) => yield Ok(msg),
463                    Err(e) if matches!(e, Error::Transport(_)) => break, // Channel closed
464                    Err(e) => {
465                        yield Err(e);
466                        break;
467                    }
468                }
469            }
470        })
471    }
472
473    /// Send an interrupt signal to the CLI (SIGINT).
474    pub async fn interrupt(&self) -> Result<()> {
475        self.transport.interrupt().await
476    }
477
478    /// Respond to a permission request from the CLI.
479    ///
480    /// When the CLI asks for permission to use a tool, this method sends
481    /// the decision back via the control protocol.
482    pub async fn respond_to_permission(
483        &self,
484        request_id: &str,
485        decision: crate::permissions::PermissionDecision,
486    ) -> Result<()> {
487        use crate::permissions::{ControlResponse, ControlResponseResult};
488
489        let response = ControlResponse {
490            kind: "permission_response".into(),
491            request_id: request_id.to_string(),
492            result: ControlResponseResult::from(decision),
493        };
494        let json = serde_json::to_string(&response).map_err(Error::Json)?;
495        self.transport.write(&json).await
496    }
497
498    // ── Dynamic control ─────────────────────────────────────────────────
499
500    /// Send a control request to the CLI and wait for the response.
501    ///
502    /// This is the low-level mechanism for dynamic mid-session control.
503    /// The request is wrapped in a `{"type": "control_request", ...}` envelope
504    /// and written to stdin. The background reader routes the matching
505    /// `control_response` back via a `oneshot` channel.
506    async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
507        let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
508        let request_id = format!("sdk_req_{counter}");
509
510        let (tx, rx) = oneshot::channel();
511        self.pending_control.insert(request_id.clone(), tx);
512
513        let envelope = serde_json::json!({
514            "type": "control_request",
515            "request_id": request_id,
516            "request": request
517        });
518        let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
519        self.transport.write(&json).await?;
520
521        rx.await
522            .map_err(|_| Error::ControlProtocol("control response channel closed".into()))
523    }
524
525    /// Dynamically change the model used for subsequent turns.
526    ///
527    /// Pass `None` to revert to the session's default model.
528    ///
529    /// # Errors
530    ///
531    /// Returns an error if the CLI rejects the model change or the control
532    /// protocol fails.
533    pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
534        self.send_control_request(serde_json::json!({
535            "subtype": "set_model",
536            "model": model
537        }))
538        .await?;
539        Ok(())
540    }
541
542    /// Dynamically change the permission mode for the current session.
543    ///
544    /// # Errors
545    ///
546    /// Returns an error if the CLI rejects the mode change or the control
547    /// protocol fails.
548    pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
549        self.send_control_request(serde_json::json!({
550            "subtype": "set_permission_mode",
551            "mode": mode.as_cli_flag()
552        }))
553        .await?;
554        Ok(())
555    }
556
557    /// Write raw data to the transport's stdin.
558    ///
559    /// This is a low-level method used by free functions like
560    /// [`query_stream_with_content`](crate::query_stream_with_content).
561    pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
562        self.transport.write(data).await
563    }
564
565    /// Take ownership of the message receiver (for use in `query_stream`).
566    ///
567    /// After calling this, `receive_messages()` and `send()` will return
568    /// `NotConnected`.
569    pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
570        self.message_rx.take()
571    }
572
573    /// Returns the configured read timeout.
574    #[must_use]
575    pub fn read_timeout(&self) -> Option<Duration> {
576        self.config.read_timeout
577    }
578
579    /// Returns the session ID if connected.
580    #[must_use]
581    pub fn session_id(&self) -> Option<&str> {
582        self.session_id.as_deref()
583    }
584
585    /// Returns `true` if the client is connected.
586    #[must_use]
587    pub fn is_connected(&self) -> bool {
588        self.transport.is_ready()
589    }
590
591    /// Close the client and shut down the CLI process.
592    ///
593    /// After calling this, the `Drop` warning will not fire.
594    pub async fn close(&mut self) -> Result<()> {
595        // Signal the background task to stop.
596        if let Some(tx) = self.shutdown_tx.take() {
597            let _ = tx.send(());
598        }
599        // Drop the message receiver so the Drop impl knows we cleaned up.
600        self.message_rx.take();
601        self.transport.close().await?;
602        Ok(())
603    }
604}
605
606impl Drop for Client {
607    fn drop(&mut self) {
608        if self.shutdown_tx.is_some() || self.message_rx.is_some() {
609            tracing::warn!(
610                "claude_cli_sdk::Client dropped without calling close(). \
611                 Resources may not be cleaned up properly."
612            );
613        }
614    }
615}
616
617// ── Tests ────────────────────────────────────────────────────────────────────
618
619#[cfg(test)]
620mod tests {
621    use super::*;
622    use crate::config::ClientConfig;
623
624    #[cfg(feature = "testing")]
625    use crate::testing::{ScenarioBuilder, assistant_text};
626
627    fn test_config() -> ClientConfig {
628        ClientConfig::builder().prompt("test").build()
629    }
630
631    #[cfg(feature = "testing")]
632    #[tokio::test]
633    async fn client_connect_and_receive_init() {
634        let transport = ScenarioBuilder::new("test-session")
635            .exchange(vec![assistant_text("Hello!")])
636            .build();
637        let transport = Arc::new(transport);
638
639        let mut client = Client::with_transport(test_config(), transport).unwrap();
640        let info = client.connect().await.unwrap();
641
642        assert_eq!(info.session_id, "test-session");
643        assert!(client.is_connected());
644        assert_eq!(client.session_id(), Some("test-session"));
645    }
646
647    #[cfg(feature = "testing")]
648    #[tokio::test]
649    async fn client_send_yields_messages() {
650        let transport = ScenarioBuilder::new("s1")
651            .exchange(vec![assistant_text("response")])
652            .build();
653        let transport = Arc::new(transport);
654
655        let mut client = Client::with_transport(test_config(), transport).unwrap();
656        client.connect().await.unwrap();
657
658        let stream = client.send("hello").unwrap();
659        tokio::pin!(stream);
660
661        let mut messages = Vec::new();
662        while let Some(msg) = stream.next().await {
663            messages.push(msg.unwrap());
664        }
665
666        // Should get assistant + result
667        assert_eq!(messages.len(), 2);
668        assert!(matches!(&messages[0], Message::Assistant(_)));
669        assert!(matches!(&messages[1], Message::Result(_)));
670    }
671
672    #[cfg(feature = "testing")]
673    #[tokio::test]
674    async fn client_close_succeeds() {
675        let transport = ScenarioBuilder::new("s1").build();
676        let transport = Arc::new(transport);
677
678        let mut client = Client::with_transport(test_config(), transport).unwrap();
679        client.connect().await.unwrap();
680        assert!(client.close().await.is_ok());
681    }
682
683    #[cfg(feature = "testing")]
684    #[tokio::test]
685    async fn client_message_callback_filters() {
686        use crate::callback::MessageCallback;
687
688        // Filter out assistant messages.
689        let callback: MessageCallback = Arc::new(|msg| match &msg {
690            Message::Assistant(_) => None,
691            _ => Some(msg),
692        });
693
694        let config = ClientConfig::builder()
695            .prompt("test")
696            .message_callback(callback)
697            .build();
698
699        let transport = ScenarioBuilder::new("s1")
700            .exchange(vec![assistant_text("filtered")])
701            .build();
702        let transport = Arc::new(transport);
703
704        let mut client = Client::with_transport(config, transport).unwrap();
705        client.connect().await.unwrap();
706
707        let stream = client.send("hello").unwrap();
708        tokio::pin!(stream);
709
710        let mut messages = Vec::new();
711        while let Some(msg) = stream.next().await {
712            messages.push(msg.unwrap());
713        }
714
715        // Only result (assistant was filtered).
716        assert_eq!(messages.len(), 1);
717        assert!(matches!(&messages[0], Message::Result(_)));
718    }
719
720    #[cfg(feature = "testing")]
721    #[test]
722    fn client_debug_before_connect() {
723        let transport = Arc::new(crate::testing::MockTransport::new());
724        let client = Client::with_transport(test_config(), transport).unwrap();
725        let debug = format!("{client:?}");
726        assert!(debug.contains("Client"));
727    }
728
729    // ── Timeout tests ────────────────────────────────────────────────────
730
731    #[cfg(feature = "testing")]
732    #[tokio::test]
733    async fn client_connect_timeout_fires() {
734        use crate::testing::MockTransport;
735
736        let transport = MockTransport::new();
737        // Set connect delay longer than timeout.
738        transport.set_connect_delay(Duration::from_secs(5));
739        let transport = Arc::new(transport);
740
741        let config = ClientConfig::builder()
742            .prompt("test")
743            .connect_timeout(Some(Duration::from_millis(50)))
744            .build();
745
746        let mut client = Client::with_transport(config, transport).unwrap();
747        let result = client.connect().await;
748        assert!(result.is_err());
749        assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
750    }
751
752    #[cfg(feature = "testing")]
753    #[tokio::test]
754    async fn client_read_timeout_fires() {
755        // Build a scenario with init + assistant, but add a large recv_delay
756        // so the assistant message arrives after the read timeout.
757        let transport = ScenarioBuilder::new("s1")
758            .exchange(vec![assistant_text("delayed")])
759            .build();
760        // Delay each message by 5 seconds — way longer than our 50ms timeout.
761        // The init message also gets this delay, but the connect has no timeout
762        // wrapping for the recv path since connect_timeout is set to None here
763        // and connect_inner waits indefinitely for init.
764        // Actually, we need init to arrive fast but subsequent messages slow.
765        // The MockTransport applies recv_delay to ALL messages including init.
766        // So set connect_timeout high enough and read_timeout low.
767        transport.set_recv_delay(Duration::from_millis(200));
768        let transport = Arc::new(transport);
769
770        let config = ClientConfig::builder()
771            .prompt("test")
772            .connect_timeout(Some(Duration::from_secs(5)))
773            .read_timeout(Some(Duration::from_millis(50)))
774            .build();
775
776        let mut client = Client::with_transport(config, transport).unwrap();
777        client.connect().await.unwrap();
778
779        let stream = client.send("hello").unwrap();
780        tokio::pin!(stream);
781
782        let mut got_timeout = false;
783        while let Some(msg) = stream.next().await {
784            if let Err(Error::Timeout(_)) = msg {
785                got_timeout = true;
786                break;
787            }
788        }
789        assert!(got_timeout, "expected a timeout error");
790    }
791
792    #[cfg(feature = "testing")]
793    #[tokio::test]
794    async fn client_permission_callback_invoked_and_responds() {
795        use crate::permissions::{CanUseToolCallback, PermissionDecision};
796        use crate::testing::MockTransport;
797        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
798
799        let invoked = Arc::new(AtomicBool::new(false));
800        let invoked_clone = Arc::clone(&invoked);
801
802        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
803            let invoked = Arc::clone(&invoked_clone);
804            let tool = tool_name.to_owned();
805            Box::pin(async move {
806                invoked.store(true, AtomicOrdering::Release);
807                assert_eq!(tool, "Bash");
808                PermissionDecision::allow()
809            })
810        });
811
812        let config = ClientConfig::builder()
813            .prompt("test")
814            .can_use_tool(callback)
815            .build();
816
817        let transport = MockTransport::new();
818        // Enqueue: init, permission_request, assistant, result
819        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
820        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
821        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
822        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
823        let transport = Arc::new(transport);
824
825        let mut client = Client::with_transport(config, transport.clone()).unwrap();
826        client.connect().await.unwrap();
827
828        let stream = client.send("hello").unwrap();
829        tokio::pin!(stream);
830        let mut messages = Vec::new();
831        while let Some(msg) = stream.next().await {
832            messages.push(msg.unwrap());
833        }
834
835        // Callback should have been invoked.
836        assert!(
837            invoked.load(AtomicOrdering::Acquire),
838            "permission callback was not invoked"
839        );
840
841        // Permission request should NOT leak as a Message — only assistant + result.
842        assert_eq!(messages.len(), 2);
843        assert!(matches!(&messages[0], Message::Assistant(_)));
844        assert!(matches!(&messages[1], Message::Result(_)));
845
846        // Verify a permission_response was written back.
847        let written = transport.written_lines();
848        let perm_responses: Vec<_> = written
849            .iter()
850            .filter(|line| line.contains("permission_response"))
851            .collect();
852        assert_eq!(
853            perm_responses.len(),
854            1,
855            "expected exactly one permission_response written"
856        );
857        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
858        assert_eq!(resp["kind"], "permission_response");
859        assert_eq!(resp["request_id"], "perm-1");
860        assert_eq!(resp["result"]["type"], "allow");
861    }
862
863    #[cfg(feature = "testing")]
864    #[tokio::test]
865    async fn client_permission_callback_deny_writes_deny_response() {
866        use crate::permissions::{CanUseToolCallback, PermissionDecision};
867        use crate::testing::MockTransport;
868
869        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
870            Box::pin(async { PermissionDecision::deny("not allowed") })
871        });
872
873        let config = ClientConfig::builder()
874            .prompt("test")
875            .can_use_tool(callback)
876            .build();
877
878        let transport = MockTransport::new();
879        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
880        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
881        transport
882            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
883        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
884        let transport = Arc::new(transport);
885
886        let mut client = Client::with_transport(config, transport.clone()).unwrap();
887        client.connect().await.unwrap();
888
889        let stream = client.send("hello").unwrap();
890        tokio::pin!(stream);
891        let mut messages = Vec::new();
892        while let Some(msg) = stream.next().await {
893            messages.push(msg.unwrap());
894        }
895
896        // Verify deny response was written.
897        let written = transport.written_lines();
898        let perm_responses: Vec<_> = written
899            .iter()
900            .filter(|line| line.contains("permission_response"))
901            .collect();
902        assert_eq!(perm_responses.len(), 1);
903        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
904        assert_eq!(resp["kind"], "permission_response");
905        assert_eq!(resp["request_id"], "perm-2");
906        assert_eq!(resp["result"]["type"], "deny");
907        assert_eq!(resp["result"]["message"], "not allowed");
908    }
909
910    #[cfg(feature = "testing")]
911    #[tokio::test]
912    async fn client_permission_request_without_callback_yields_error() {
913        use crate::testing::MockTransport;
914
915        // No can_use_tool callback configured.
916        let config = ClientConfig::builder().prompt("test").build();
917
918        let transport = MockTransport::new();
919        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
920        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
921        transport
922            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
923        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
924        let transport = Arc::new(transport);
925
926        let mut client = Client::with_transport(config, transport).unwrap();
927        client.connect().await.unwrap();
928
929        let stream = client.send("hello").unwrap();
930        tokio::pin!(stream);
931
932        let mut got_error = false;
933        let mut messages = Vec::new();
934        while let Some(result) = stream.next().await {
935            match result {
936                Ok(msg) => messages.push(msg),
937                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
938                    got_error = true;
939                }
940                Err(e) => panic!("unexpected error: {e}"),
941            }
942        }
943
944        assert!(
945            got_error,
946            "should have received a ControlProtocol error for missing callback"
947        );
948    }
949
950    #[cfg(feature = "testing")]
951    #[tokio::test]
952    async fn client_read_timeout_none_waits() {
953        // MockTransport with recv_delay < reasonable wait, read_timeout None.
954        let transport = ScenarioBuilder::new("s1")
955            .exchange(vec![assistant_text("delayed")])
956            .build();
957        transport.set_recv_delay(Duration::from_millis(50));
958        let transport = Arc::new(transport);
959
960        let config = ClientConfig::builder()
961            .prompt("test")
962            .read_timeout(None)
963            .build();
964
965        let mut client = Client::with_transport(config, transport).unwrap();
966        client.connect().await.unwrap();
967
968        let stream = client.send("hello").unwrap();
969        tokio::pin!(stream);
970
971        let mut messages = Vec::new();
972        while let Some(msg) = stream.next().await {
973            messages.push(msg.unwrap());
974        }
975
976        // Should get assistant + result even with delay since no timeout.
977        assert_eq!(messages.len(), 2);
978    }
979}