Skip to main content

claude_cli_sdk/
client.rs

//! The core `Client` struct — multi-turn, bidirectional Claude Code sessions.
//!
//! # Architecture
//!
//! On [`connect()`](Client::connect), the client spawns a background task that:
//! 1. Reads JSON values from the transport
//! 2. Routes `control_response` messages to the matching pending control request
//! 3. Routes hook requests to registered hook matchers
//! 4. Routes permission requests to the configured callback
//! 5. Applies the message callback (if any) to every other message
//! 6. Forwards resulting messages through a `flume` channel
//!
//! Callers consume messages via [`send()`](Client::send) (which returns a
//! stream), or via [`receive_messages()`](Client::receive_messages).
14
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23use tokio_util::sync::CancellationToken;
24
25use crate::callback::apply_callback;
26use crate::config::{ClientConfig, PermissionMode};
27use crate::errors::{Error, Result};
28use crate::transport::{CliTransport, Transport};
29use crate::types::content::UserContent;
30use crate::types::messages::{Message, SessionInfo};
31
32// ── Cancellation helper ──────────────────────────────────────────────────────
33
34/// Wait for the token to fire, or pend forever if `None`.
35///
36/// Useful as a `tokio::select!` branch that compiles away when no token is
37/// provided.
38pub(crate) async fn cancelled_or_pending(token: Option<&CancellationToken>) {
39    match token {
40        Some(t) => t.cancelled().await,
41        None => std::future::pending().await,
42    }
43}
44
45// ── Timeout helper ───────────────────────────────────────────────────────────
46
47/// Receive a message from a flume channel with an optional timeout and
48/// optional cancellation token.
49///
50/// Returns `Error::Timeout` if the deadline expires, `Error::Cancelled` if
51/// the token fires, or `Error::Transport` if the channel is closed.
52pub(crate) async fn recv_with_timeout(
53    rx: &flume::Receiver<Result<Message>>,
54    timeout: Option<Duration>,
55    cancel: Option<&CancellationToken>,
56) -> Result<Message> {
57    let recv_fut = rx.recv_async();
58    tokio::select! {
59        biased;
60        _ = cancelled_or_pending(cancel) => {
61            Err(Error::Cancelled)
62        }
63        result = async {
64            match timeout {
65                Some(d) => match tokio::time::timeout(d, recv_fut).await {
66                    Ok(Ok(msg)) => msg,
67                    Ok(Err(_)) => Err(Error::Transport("message channel closed".into())),
68                    Err(_) => Err(Error::Timeout(format!("read timed out after {}s", d.as_secs_f64()))),
69                },
70                None => match recv_fut.await {
71                    Ok(msg) => msg,
72                    Err(_) => Err(Error::Transport("message channel closed".into())),
73                },
74            }
75        } => result,
76    }
77}
78
79// ── Shared turn stream helper ─────────────────────────────────────────────────
80
81/// Read messages from the receiver until a `Result` message or error,
82/// then clear the turn flag.
83fn read_turn_stream<'a>(
84    rx: &'a flume::Receiver<Result<Message>>,
85    read_timeout: Option<Duration>,
86    turn_flag: Arc<AtomicBool>,
87    cancel: Option<CancellationToken>,
88) -> impl Stream<Item = Result<Message>> + 'a {
89    async_stream::stream! {
90        loop {
91            match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
92                Ok(msg) => {
93                    let is_result = matches!(&msg, Message::Result(_));
94                    yield Ok(msg);
95                    if is_result {
96                        break;
97                    }
98                }
99                Err(e) => {
100                    yield Err(e);
101                    break;
102                }
103            }
104        }
105        turn_flag.store(false, Ordering::Release);
106    }
107}
108
109// ── Client ───────────────────────────────────────────────────────────────────
110
111/// A stateful Claude Code client that manages a persistent session.
112///
113/// # Lifecycle
114///
115/// 1. Create with [`Client::new(config)`](Client::new) or [`Client::with_transport(config, transport)`](Client::with_transport)
116/// 2. Call [`connect()`](Client::connect) to spawn the CLI and read the init message
117/// 3. Use [`send()`](Client::send) to send prompts and stream responses
118/// 4. Call [`close()`](Client::close) to shut down cleanly
119pub struct Client {
120    config: ClientConfig,
121    transport: Arc<dyn Transport>,
122    session_id: Option<String>,
123    message_rx: Option<flume::Receiver<Result<Message>>>,
124    shutdown_tx: Option<oneshot::Sender<()>>,
125    turn_active: Arc<AtomicBool>,
126    /// Pending outbound control requests awaiting responses.
127    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
128    /// Counter for generating unique request IDs.
129    request_counter: Arc<AtomicU64>,
130}
131
132impl std::fmt::Debug for Client {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        f.debug_struct("Client")
135            .field("session_id", &self.session_id)
136            .field("connected", &self.is_connected())
137            .finish_non_exhaustive()
138    }
139}
140
141impl Client {
142    /// Create a new client with the given configuration.
143    ///
144    /// This does NOT start the CLI — call [`connect()`](Client::connect) next.
145    /// Validates the config (e.g., `cwd` existence) and discovers the CLI binary.
146    ///
147    /// # Errors
148    ///
149    /// Returns [`Error::CliNotFound`] if the CLI binary cannot be discovered,
150    /// or [`Error::Config`] if the configuration is invalid.
151    pub fn new(config: ClientConfig) -> Result<Self> {
152        config.validate()?;
153        let transport = Arc::new(CliTransport::from_config(&config)?);
154        Ok(Self {
155            config,
156            transport,
157            session_id: None,
158            message_rx: None,
159            shutdown_tx: None,
160            turn_active: Arc::new(AtomicBool::new(false)),
161            pending_control: Arc::new(DashMap::new()),
162            request_counter: Arc::new(AtomicU64::new(0)),
163        })
164    }
165
166    /// Create a client with a custom transport (useful for testing).
167    pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
168        Ok(Self {
169            config,
170            transport,
171            session_id: None,
172            message_rx: None,
173            shutdown_tx: None,
174            turn_active: Arc::new(AtomicBool::new(false)),
175            pending_control: Arc::new(DashMap::new()),
176            request_counter: Arc::new(AtomicU64::new(0)),
177        })
178    }
179
180    /// Connect to the CLI and return the session info from the init message.
181    ///
182    /// This spawns the CLI process (or connects to the mock transport) and
183    /// starts the background reader task. The entire connect sequence
184    /// (transport connect + init message read) is subject to `connect_timeout`.
185    pub async fn connect(&mut self) -> Result<SessionInfo> {
186        let timeout = self.config.connect_timeout;
187        match timeout {
188            Some(d) => tokio::time::timeout(d, self.connect_inner())
189                .await
190                .map_err(|_| {
191                    Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
192                })?,
193            None => self.connect_inner().await,
194        }
195    }
196
197    async fn connect_inner(&mut self) -> Result<SessionInfo> {
198        self.transport.connect().await?;
199
200        // Set up the message routing pipeline.
201        let (msg_tx, msg_rx) = flume::unbounded();
202        let (shutdown_tx, shutdown_rx) = oneshot::channel();
203
204        let transport = Arc::clone(&self.transport);
205        let message_callback = self.config.message_callback.clone();
206        let pending_control = Arc::clone(&self.pending_control);
207
208        // Move hooks into the background task for hook dispatch.
209        let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
210        let default_hook_timeout = self.config.default_hook_timeout;
211        let hook_transport = Arc::clone(&self.transport);
212
213        // Capture permission callback for the background task.
214        let can_use_tool = self.config.can_use_tool.clone();
215        let perm_transport = Arc::clone(&self.transport);
216
217        // Cancellation token for cooperative consumer-side abort.
218        let cancel_token = self.config.cancellation_token.clone();
219
220        // Shared session_id for the background task.
221        // Updated after the init message is parsed so subsequent hook dispatches
222        // carry the real session ID rather than None.
223        let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
224            Arc::new(std::sync::Mutex::new(None));
225        let hook_session_id = Arc::clone(&shared_session_id);
226
227        // Spawn background reader task.
228        tokio::spawn(async move {
229            let mut stream = transport.read_messages();
230            let mut shutdown = shutdown_rx;
231
232            loop {
233                tokio::select! {
234                    biased;
235                    _ = &mut shutdown => break,
236                    _ = cancelled_or_pending(cancel_token.as_ref()) => break,
237                    item = stream.next() => {
238                        match item {
239                            Some(Ok(value)) => {
240                                // Route control_response messages to pending senders.
241                                if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
242                                    if let Some(req_id) = value.get("request_id").and_then(|v| v.as_str()) {
243                                        if let Some((_, tx)) = pending_control.remove(req_id) {
244                                            let _ = tx.send(value);
245                                        }
246                                    }
247                                    continue;
248                                }
249
250                                // Route hook_request messages to registered hooks.
251                                if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
252                                    if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
253                                        let sid = hook_session_id
254                                            .lock()
255                                            .expect("session_id lock")
256                                            .clone();
257                                        let output = crate::hooks::dispatch_hook(
258                                            &req,
259                                            &hooks,
260                                            default_hook_timeout,
261                                            sid,
262                                        ).await;
263                                        let response = crate::hooks::HookResponse::from_output(
264                                            req.request_id,
265                                            output,
266                                        );
267                                        if let Ok(json) = serde_json::to_string(&response) {
268                                            let _ = hook_transport.write(&json).await;
269                                        }
270                                    }
271                                    continue;
272                                }
273
274                                // Route permission_request messages to the can_use_tool callback.
275                                if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
276                                    if let Some(ref callback) = can_use_tool {
277                                        if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
278                                            let crate::permissions::ControlRequestData::PermissionRequest {
279                                                ref tool_name,
280                                                ref tool_input,
281                                                ref tool_use_id,
282                                                ref suggestions,
283                                            } = req.request;
284                                            let sid = hook_session_id
285                                                .lock()
286                                                .expect("session_id lock")
287                                                .clone()
288                                                .unwrap_or_default();
289                                            let ctx = crate::permissions::PermissionContext {
290                                                tool_use_id: tool_use_id.clone(),
291                                                session_id: sid,
292                                                request_id: req.request_id.clone(),
293                                                suggestions: suggestions.clone(),
294                                            };
295                                            let decision = callback(tool_name, tool_input, ctx).await;
296                                            let response = crate::permissions::ControlResponse {
297                                                kind: "permission_response".into(),
298                                                request_id: req.request_id,
299                                                result: crate::permissions::ControlResponseResult::from(decision),
300                                            };
301                                            if let Ok(json) = serde_json::to_string(&response) {
302                                                let _ = perm_transport.write(&json).await;
303                                            }
304                                        }
305                                    } else {
306                                        // No can_use_tool callback configured. The CLI is waiting
307                                        // for a permission_response that will never arrive, which
308                                        // would cause a silent hang. Surface an error so the
309                                        // caller knows they need to set `can_use_tool` or use a
310                                        // non-default PermissionMode.
311                                        let _ = msg_tx.send(Err(Error::ControlProtocol(
312                                            "received permission_request but no can_use_tool \
313                                             callback is configured — set can_use_tool on \
314                                             ClientConfig or use a PermissionMode that does not \
315                                             require interactive approval"
316                                                .into(),
317                                        )));
318                                    }
319                                    continue;
320                                }
321
322                                // Parse the JSON value into a Message.
323                                let msg: Message = match serde_json::from_value(value) {
324                                    Ok(m) => m,
325                                    Err(e) => {
326                                        let _ = msg_tx.send(Err(Error::Json(e)));
327                                        continue;
328                                    }
329                                };
330
331                                // Apply the message callback.
332                                let msg = match apply_callback(msg, message_callback.as_ref()) {
333                                    Some(m) => m,
334                                    None => continue, // Filtered out
335                                };
336
337                                if msg_tx.send(Ok(msg)).is_err() {
338                                    break; // Receiver dropped
339                                }
340                            }
341                            Some(Err(e)) => {
342                                let _ = msg_tx.send(Err(e));
343                            }
344                            None => break, // Stream ended
345                        }
346                    }
347                }
348            }
349        });
350
351        self.message_rx = Some(msg_rx);
352        self.shutdown_tx = Some(shutdown_tx);
353
354        // Wait for the system/init message.
355        // The overall connect_timeout is enforced by the caller, so we wait
356        // indefinitely here (the outer timeout will cancel us if needed).
357        let init_msg = self
358            .message_rx
359            .as_ref()
360            .unwrap()
361            .recv_async()
362            .await
363            .map_err(|_| Error::Transport("connection closed before init message".into()))?
364            .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
365
366        if let Message::System(ref sys) = init_msg {
367            let info = SessionInfo::try_from(sys)?;
368            self.session_id = Some(info.session_id.clone());
369            // Propagate the session ID to the background task so hook
370            // dispatches after this point carry the real session ID.
371            *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
372            Ok(info)
373        } else {
374            Err(Error::ControlProtocol(format!(
375                "expected system/init as first message, got: {init_msg:?}"
376            )))
377        }
378    }
379
380    /// Send a text prompt and return a stream of response messages.
381    ///
382    /// The stream yields messages until a `Result` message is received
383    /// (which terminates the turn).
384    pub fn send(
385        &self,
386        prompt: impl Into<String>,
387    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
388        let prompt = prompt.into();
389        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
390        let transport = Arc::clone(&self.transport);
391
392        // Guard against concurrent turns.
393        if self
394            .turn_active
395            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
396            .is_err()
397        {
398            return Err(Error::ControlProtocol("turn already in progress".into()));
399        }
400        let turn_flag = Arc::clone(&self.turn_active);
401        let read_timeout = self.config.read_timeout;
402        let cancel = self.config.cancellation_token.clone();
403
404        Ok(async_stream::stream! {
405            // Write the prompt to stdin.
406            if let Err(e) = transport.write(&prompt).await {
407                turn_flag.store(false, Ordering::Release);
408                yield Err(e);
409                return;
410            }
411
412            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
413            tokio::pin!(inner);
414            while let Some(item) = inner.next().await {
415                yield item;
416            }
417        })
418    }
419
420    /// Send structured content blocks (text + images) and return a stream of
421    /// response messages.
422    ///
423    /// This is the multi-modal equivalent of [`send()`](Client::send). Content
424    /// is serialised as a JSON user message and written to the CLI's stdin.
425    ///
426    /// # Errors
427    ///
428    /// Returns [`Error::Config`] if `content` is empty, or [`Error::NotConnected`]
429    /// if the client is not connected.
430    pub fn send_content(
431        &self,
432        content: Vec<UserContent>,
433    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
434        if content.is_empty() {
435            return Err(Error::Config("content must not be empty".into()));
436        }
437
438        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
439        let transport = Arc::clone(&self.transport);
440
441        // Guard against concurrent turns.
442        if self
443            .turn_active
444            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
445            .is_err()
446        {
447            return Err(Error::ControlProtocol("turn already in progress".into()));
448        }
449        let turn_flag = Arc::clone(&self.turn_active);
450        let read_timeout = self.config.read_timeout;
451        let cancel = self.config.cancellation_token.clone();
452
453        Ok(async_stream::stream! {
454            // Serialize content blocks as a JSON user message.
455            let user_message = serde_json::json!({
456                "type": "user",
457                "message": {
458                    "role": "user",
459                    "content": content
460                }
461            });
462            let json = match serde_json::to_string(&user_message) {
463                Ok(j) => j,
464                Err(e) => {
465                    turn_flag.store(false, Ordering::Release);
466                    yield Err(Error::Json(e));
467                    return;
468                }
469            };
470
471            if let Err(e) = transport.write(&json).await {
472                turn_flag.store(false, Ordering::Release);
473                yield Err(e);
474                return;
475            }
476
477            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
478            tokio::pin!(inner);
479            while let Some(item) = inner.next().await {
480                yield item;
481            }
482        })
483    }
484
485    /// Return a stream of all incoming messages (without sending a prompt).
486    ///
487    /// Useful for consuming messages from a resumed session.
488    pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
489        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
490        let read_timeout = self.config.read_timeout;
491        let cancel = self.config.cancellation_token.clone();
492
493        Ok(async_stream::stream! {
494            loop {
495                match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
496                    Ok(msg) => yield Ok(msg),
497                    Err(e) if matches!(e, Error::Transport(_)) => break, // Channel closed
498                    Err(e) => {
499                        yield Err(e);
500                        break;
501                    }
502                }
503            }
504        })
505    }
506
507    /// Send an interrupt signal to the CLI (SIGINT).
508    pub async fn interrupt(&self) -> Result<()> {
509        self.transport.interrupt().await
510    }
511
512    /// Respond to a permission request from the CLI.
513    ///
514    /// When the CLI asks for permission to use a tool, this method sends
515    /// the decision back via the control protocol.
516    pub async fn respond_to_permission(
517        &self,
518        request_id: &str,
519        decision: crate::permissions::PermissionDecision,
520    ) -> Result<()> {
521        use crate::permissions::{ControlResponse, ControlResponseResult};
522
523        let response = ControlResponse {
524            kind: "permission_response".into(),
525            request_id: request_id.to_string(),
526            result: ControlResponseResult::from(decision),
527        };
528        let json = serde_json::to_string(&response).map_err(Error::Json)?;
529        self.transport.write(&json).await
530    }
531
532    // ── Dynamic control ─────────────────────────────────────────────────
533
534    /// Send a control request to the CLI and wait for the response.
535    ///
536    /// This is the low-level mechanism for dynamic mid-session control.
537    /// The request is wrapped in a `{"type": "control_request", ...}` envelope
538    /// and written to stdin. The background reader routes the matching
539    /// `control_response` back via a `oneshot` channel.
540    async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
541        let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
542        let request_id = format!("sdk_req_{counter}");
543
544        let (tx, rx) = oneshot::channel();
545        self.pending_control.insert(request_id.clone(), tx);
546
547        let envelope = serde_json::json!({
548            "type": "control_request",
549            "request_id": request_id,
550            "request": request
551        });
552        let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
553        self.transport.write(&json).await?;
554
555        rx.await
556            .map_err(|_| Error::ControlProtocol("control response channel closed".into()))
557    }
558
559    /// Dynamically change the model used for subsequent turns.
560    ///
561    /// Pass `None` to revert to the session's default model.
562    ///
563    /// # Errors
564    ///
565    /// Returns an error if the CLI rejects the model change or the control
566    /// protocol fails.
567    pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
568        self.send_control_request(serde_json::json!({
569            "subtype": "set_model",
570            "model": model
571        }))
572        .await?;
573        Ok(())
574    }
575
576    /// Dynamically change the permission mode for the current session.
577    ///
578    /// # Errors
579    ///
580    /// Returns an error if the CLI rejects the mode change or the control
581    /// protocol fails.
582    pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
583        self.send_control_request(serde_json::json!({
584            "subtype": "set_permission_mode",
585            "mode": mode.as_cli_flag()
586        }))
587        .await?;
588        Ok(())
589    }
590
591    /// Write raw data to the transport's stdin.
592    ///
593    /// This is a low-level method used by free functions like
594    /// [`query_stream_with_content`](crate::query_stream_with_content).
595    pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
596        self.transport.write(data).await
597    }
598
599    /// Take ownership of the message receiver (for use in `query_stream`).
600    ///
601    /// After calling this, `receive_messages()` and `send()` will return
602    /// `NotConnected`.
603    pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
604        self.message_rx.take()
605    }
606
607    /// Returns the configured read timeout.
608    #[must_use]
609    pub fn read_timeout(&self) -> Option<Duration> {
610        self.config.read_timeout
611    }
612
613    /// Returns the session ID if connected.
614    #[must_use]
615    pub fn session_id(&self) -> Option<&str> {
616        self.session_id.as_deref()
617    }
618
619    /// Returns `true` if the client is connected.
620    #[must_use]
621    pub fn is_connected(&self) -> bool {
622        self.transport.is_ready()
623    }
624
625    /// Close the client and shut down the CLI process.
626    ///
627    /// After calling this, the `Drop` warning will not fire.
628    pub async fn close(&mut self) -> Result<()> {
629        // Signal the background task to stop.
630        if let Some(tx) = self.shutdown_tx.take() {
631            let _ = tx.send(());
632        }
633        // Drop the message receiver so the Drop impl knows we cleaned up.
634        self.message_rx.take();
635        self.transport.close().await?;
636        Ok(())
637    }
638}
639
640impl Drop for Client {
641    fn drop(&mut self) {
642        if self.shutdown_tx.is_some() || self.message_rx.is_some() {
643            tracing::warn!(
644                "claude_cli_sdk::Client dropped without calling close(). \
645                 Resources may not be cleaned up properly."
646            );
647        }
648    }
649}
650
651// ── Tests ────────────────────────────────────────────────────────────────────
652
653#[cfg(test)]
654mod tests {
655    use super::*;
656    use crate::config::ClientConfig;
657
658    #[cfg(feature = "testing")]
659    use crate::testing::{ScenarioBuilder, assistant_text};
660
661    fn test_config() -> ClientConfig {
662        ClientConfig::builder().prompt("test").build()
663    }
664
665    #[cfg(feature = "testing")]
666    #[tokio::test]
667    async fn client_connect_and_receive_init() {
668        let transport = ScenarioBuilder::new("test-session")
669            .exchange(vec![assistant_text("Hello!")])
670            .build();
671        let transport = Arc::new(transport);
672
673        let mut client = Client::with_transport(test_config(), transport).unwrap();
674        let info = client.connect().await.unwrap();
675
676        assert_eq!(info.session_id, "test-session");
677        assert!(client.is_connected());
678        assert_eq!(client.session_id(), Some("test-session"));
679    }
680
681    #[cfg(feature = "testing")]
682    #[tokio::test]
683    async fn client_send_yields_messages() {
684        let transport = ScenarioBuilder::new("s1")
685            .exchange(vec![assistant_text("response")])
686            .build();
687        let transport = Arc::new(transport);
688
689        let mut client = Client::with_transport(test_config(), transport).unwrap();
690        client.connect().await.unwrap();
691
692        let stream = client.send("hello").unwrap();
693        tokio::pin!(stream);
694
695        let mut messages = Vec::new();
696        while let Some(msg) = stream.next().await {
697            messages.push(msg.unwrap());
698        }
699
700        // Should get assistant + result
701        assert_eq!(messages.len(), 2);
702        assert!(matches!(&messages[0], Message::Assistant(_)));
703        assert!(matches!(&messages[1], Message::Result(_)));
704    }
705
706    #[cfg(feature = "testing")]
707    #[tokio::test]
708    async fn client_close_succeeds() {
709        let transport = ScenarioBuilder::new("s1").build();
710        let transport = Arc::new(transport);
711
712        let mut client = Client::with_transport(test_config(), transport).unwrap();
713        client.connect().await.unwrap();
714        assert!(client.close().await.is_ok());
715    }
716
717    #[cfg(feature = "testing")]
718    #[tokio::test]
719    async fn client_message_callback_filters() {
720        use crate::callback::MessageCallback;
721
722        // Filter out assistant messages.
723        let callback: MessageCallback = Arc::new(|msg| match &msg {
724            Message::Assistant(_) => None,
725            _ => Some(msg),
726        });
727
728        let config = ClientConfig::builder()
729            .prompt("test")
730            .message_callback(callback)
731            .build();
732
733        let transport = ScenarioBuilder::new("s1")
734            .exchange(vec![assistant_text("filtered")])
735            .build();
736        let transport = Arc::new(transport);
737
738        let mut client = Client::with_transport(config, transport).unwrap();
739        client.connect().await.unwrap();
740
741        let stream = client.send("hello").unwrap();
742        tokio::pin!(stream);
743
744        let mut messages = Vec::new();
745        while let Some(msg) = stream.next().await {
746            messages.push(msg.unwrap());
747        }
748
749        // Only result (assistant was filtered).
750        assert_eq!(messages.len(), 1);
751        assert!(matches!(&messages[0], Message::Result(_)));
752    }
753
754    #[cfg(feature = "testing")]
755    #[test]
756    fn client_debug_before_connect() {
757        let transport = Arc::new(crate::testing::MockTransport::new());
758        let client = Client::with_transport(test_config(), transport).unwrap();
759        let debug = format!("{client:?}");
760        assert!(debug.contains("Client"));
761    }
762
763    // ── Timeout tests ────────────────────────────────────────────────────
764
765    #[cfg(feature = "testing")]
766    #[tokio::test]
767    async fn client_connect_timeout_fires() {
768        use crate::testing::MockTransport;
769
770        let transport = MockTransport::new();
771        // Set connect delay longer than timeout.
772        transport.set_connect_delay(Duration::from_secs(5));
773        let transport = Arc::new(transport);
774
775        let config = ClientConfig::builder()
776            .prompt("test")
777            .connect_timeout(Some(Duration::from_millis(50)))
778            .build();
779
780        let mut client = Client::with_transport(config, transport).unwrap();
781        let result = client.connect().await;
782        assert!(result.is_err());
783        assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
784    }
785
786    #[cfg(feature = "testing")]
787    #[tokio::test]
788    async fn client_read_timeout_fires() {
789        // Build a scenario with init + assistant, but add a large recv_delay
790        // so the assistant message arrives after the read timeout.
791        let transport = ScenarioBuilder::new("s1")
792            .exchange(vec![assistant_text("delayed")])
793            .build();
794        // Delay each message by 5 seconds — way longer than our 50ms timeout.
795        // The init message also gets this delay, but the connect has no timeout
796        // wrapping for the recv path since connect_timeout is set to None here
797        // and connect_inner waits indefinitely for init.
798        // Actually, we need init to arrive fast but subsequent messages slow.
799        // The MockTransport applies recv_delay to ALL messages including init.
800        // So set connect_timeout high enough and read_timeout low.
801        transport.set_recv_delay(Duration::from_millis(200));
802        let transport = Arc::new(transport);
803
804        let config = ClientConfig::builder()
805            .prompt("test")
806            .connect_timeout(Some(Duration::from_secs(5)))
807            .read_timeout(Some(Duration::from_millis(50)))
808            .build();
809
810        let mut client = Client::with_transport(config, transport).unwrap();
811        client.connect().await.unwrap();
812
813        let stream = client.send("hello").unwrap();
814        tokio::pin!(stream);
815
816        let mut got_timeout = false;
817        while let Some(msg) = stream.next().await {
818            if let Err(Error::Timeout(_)) = msg {
819                got_timeout = true;
820                break;
821            }
822        }
823        assert!(got_timeout, "expected a timeout error");
824    }
825
826    #[cfg(feature = "testing")]
827    #[tokio::test]
828    async fn client_permission_callback_invoked_and_responds() {
829        use crate::permissions::{CanUseToolCallback, PermissionDecision};
830        use crate::testing::MockTransport;
831        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
832
833        let invoked = Arc::new(AtomicBool::new(false));
834        let invoked_clone = Arc::clone(&invoked);
835
836        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
837            let invoked = Arc::clone(&invoked_clone);
838            let tool = tool_name.to_owned();
839            Box::pin(async move {
840                invoked.store(true, AtomicOrdering::Release);
841                assert_eq!(tool, "Bash");
842                PermissionDecision::allow()
843            })
844        });
845
846        let config = ClientConfig::builder()
847            .prompt("test")
848            .can_use_tool(callback)
849            .build();
850
851        let transport = MockTransport::new();
852        // Enqueue: init, permission_request, assistant, result
853        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
854        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
855        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
856        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
857        let transport = Arc::new(transport);
858
859        let mut client = Client::with_transport(config, transport.clone()).unwrap();
860        client.connect().await.unwrap();
861
862        let stream = client.send("hello").unwrap();
863        tokio::pin!(stream);
864        let mut messages = Vec::new();
865        while let Some(msg) = stream.next().await {
866            messages.push(msg.unwrap());
867        }
868
869        // Callback should have been invoked.
870        assert!(
871            invoked.load(AtomicOrdering::Acquire),
872            "permission callback was not invoked"
873        );
874
875        // Permission request should NOT leak as a Message — only assistant + result.
876        assert_eq!(messages.len(), 2);
877        assert!(matches!(&messages[0], Message::Assistant(_)));
878        assert!(matches!(&messages[1], Message::Result(_)));
879
880        // Verify a permission_response was written back.
881        let written = transport.written_lines();
882        let perm_responses: Vec<_> = written
883            .iter()
884            .filter(|line| line.contains("permission_response"))
885            .collect();
886        assert_eq!(
887            perm_responses.len(),
888            1,
889            "expected exactly one permission_response written"
890        );
891        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
892        assert_eq!(resp["kind"], "permission_response");
893        assert_eq!(resp["request_id"], "perm-1");
894        assert_eq!(resp["result"]["type"], "allow");
895    }
896
897    #[cfg(feature = "testing")]
898    #[tokio::test]
899    async fn client_permission_callback_deny_writes_deny_response() {
900        use crate::permissions::{CanUseToolCallback, PermissionDecision};
901        use crate::testing::MockTransport;
902
903        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
904            Box::pin(async { PermissionDecision::deny("not allowed") })
905        });
906
907        let config = ClientConfig::builder()
908            .prompt("test")
909            .can_use_tool(callback)
910            .build();
911
912        let transport = MockTransport::new();
913        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
914        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
915        transport
916            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
917        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
918        let transport = Arc::new(transport);
919
920        let mut client = Client::with_transport(config, transport.clone()).unwrap();
921        client.connect().await.unwrap();
922
923        let stream = client.send("hello").unwrap();
924        tokio::pin!(stream);
925        let mut messages = Vec::new();
926        while let Some(msg) = stream.next().await {
927            messages.push(msg.unwrap());
928        }
929
930        // Verify deny response was written.
931        let written = transport.written_lines();
932        let perm_responses: Vec<_> = written
933            .iter()
934            .filter(|line| line.contains("permission_response"))
935            .collect();
936        assert_eq!(perm_responses.len(), 1);
937        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
938        assert_eq!(resp["kind"], "permission_response");
939        assert_eq!(resp["request_id"], "perm-2");
940        assert_eq!(resp["result"]["type"], "deny");
941        assert_eq!(resp["result"]["message"], "not allowed");
942    }
943
944    #[cfg(feature = "testing")]
945    #[tokio::test]
946    async fn client_permission_request_without_callback_yields_error() {
947        use crate::testing::MockTransport;
948
949        // No can_use_tool callback configured.
950        let config = ClientConfig::builder().prompt("test").build();
951
952        let transport = MockTransport::new();
953        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
954        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
955        transport
956            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
957        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
958        let transport = Arc::new(transport);
959
960        let mut client = Client::with_transport(config, transport).unwrap();
961        client.connect().await.unwrap();
962
963        let stream = client.send("hello").unwrap();
964        tokio::pin!(stream);
965
966        let mut got_error = false;
967        let mut messages = Vec::new();
968        while let Some(result) = stream.next().await {
969            match result {
970                Ok(msg) => messages.push(msg),
971                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
972                    got_error = true;
973                }
974                Err(e) => panic!("unexpected error: {e}"),
975            }
976        }
977
978        assert!(
979            got_error,
980            "should have received a ControlProtocol error for missing callback"
981        );
982    }
983
984    #[tokio::test]
985    async fn recv_with_timeout_respects_cancellation_token() {
986        let (tx, rx) = flume::unbounded::<Result<Message>>();
987        let token = CancellationToken::new();
988
989        // Cancel immediately.
990        token.cancel();
991
992        let result = recv_with_timeout(&rx, None, Some(&token)).await;
993        assert!(result.is_err());
994        assert!(result.unwrap_err().is_cancelled());
995
996        // Sender is still alive — we didn't get a transport error.
997        drop(tx);
998    }
999
1000    #[tokio::test]
1001    async fn recv_with_timeout_none_cancel_still_works() {
1002        let (_tx, rx) = flume::unbounded::<Result<Message>>();
1003
1004        // With no cancel token and a short timeout, we should get a timeout error.
1005        let result = recv_with_timeout(&rx, Some(Duration::from_millis(10)), None).await;
1006        assert!(matches!(result, Err(Error::Timeout(_))));
1007    }
1008
1009    #[cfg(feature = "testing")]
1010    #[tokio::test]
1011    async fn client_read_timeout_none_waits() {
1012        // MockTransport with recv_delay < reasonable wait, read_timeout None.
1013        let transport = ScenarioBuilder::new("s1")
1014            .exchange(vec![assistant_text("delayed")])
1015            .build();
1016        transport.set_recv_delay(Duration::from_millis(50));
1017        let transport = Arc::new(transport);
1018
1019        let config = ClientConfig::builder()
1020            .prompt("test")
1021            .read_timeout(None)
1022            .build();
1023
1024        let mut client = Client::with_transport(config, transport).unwrap();
1025        client.connect().await.unwrap();
1026
1027        let stream = client.send("hello").unwrap();
1028        tokio::pin!(stream);
1029
1030        let mut messages = Vec::new();
1031        while let Some(msg) = stream.next().await {
1032            messages.push(msg.unwrap());
1033        }
1034
1035        // Should get assistant + result even with delay since no timeout.
1036        assert_eq!(messages.len(), 2);
1037    }
1038}