Skip to main content

claude_cli_sdk/
client.rs

1//! The core `Client` struct — multi-turn, bidirectional Claude Code sessions.
2//!
3//! # Architecture
4//!
5//! On [`connect()`](Client::connect), the client spawns a background task that:
6//! 1. Reads JSON values from the transport
7//! 2. Routes permission requests to the configured callback
8//! 3. Routes hook requests to registered hook matchers
9//! 4. Applies the message callback (if any)
10//! 5. Forwards resulting messages through a `flume` channel
11//!
12//! Callers consume messages via [`send()`](Client::send) (which returns a
13//! stream), or via [`receive_messages()`](Client::receive_messages).
14
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23use tokio_util::sync::CancellationToken;
24
25use crate::callback::apply_callback;
26use crate::config::{ClientConfig, PermissionMode};
27use crate::errors::{Error, Result};
28use crate::transport::{CliTransport, Transport};
29use crate::types::content::UserContent;
30use crate::types::messages::{Message, SessionInfo};
31
32// ── Cancellation helper ──────────────────────────────────────────────────────
33
34/// Wait for the token to fire, or pend forever if `None`.
35///
36/// Useful as a `tokio::select!` branch that compiles away when no token is
37/// provided.
38pub(crate) async fn cancelled_or_pending(token: Option<&CancellationToken>) {
39    match token {
40        Some(t) => t.cancelled().await,
41        None => std::future::pending().await,
42    }
43}
44
45// ── Timeout helper ───────────────────────────────────────────────────────────
46
47/// Receive a message from a flume channel with an optional timeout and
48/// optional cancellation token.
49///
50/// Returns `Error::Timeout` if the deadline expires, `Error::Cancelled` if
51/// the token fires, or `Error::Transport` if the channel is closed.
52pub(crate) async fn recv_with_timeout(
53    rx: &flume::Receiver<Result<Message>>,
54    timeout: Option<Duration>,
55    cancel: Option<&CancellationToken>,
56) -> Result<Message> {
57    let recv_fut = rx.recv_async();
58    tokio::select! {
59        biased;
60        _ = cancelled_or_pending(cancel) => {
61            Err(Error::Cancelled)
62        }
63        result = async {
64            match timeout {
65                Some(d) => match tokio::time::timeout(d, recv_fut).await {
66                    Ok(Ok(msg)) => msg,
67                    Ok(Err(_)) => Err(Error::Transport("message channel closed".into())),
68                    Err(_) => Err(Error::Timeout(format!("read timed out after {}s", d.as_secs_f64()))),
69                },
70                None => match recv_fut.await {
71                    Ok(msg) => msg,
72                    Err(_) => Err(Error::Transport("message channel closed".into())),
73                },
74            }
75        } => result,
76    }
77}
78
79// ── Shared turn stream helper ─────────────────────────────────────────────────
80
81/// Read messages from the receiver until a `Result` message or error,
82/// then clear the turn flag.
83fn read_turn_stream<'a>(
84    rx: &'a flume::Receiver<Result<Message>>,
85    read_timeout: Option<Duration>,
86    turn_flag: Arc<AtomicBool>,
87    cancel: Option<CancellationToken>,
88) -> impl Stream<Item = Result<Message>> + 'a {
89    async_stream::stream! {
90        loop {
91            match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
92                Ok(msg) => {
93                    let is_result = matches!(&msg, Message::Result(_));
94                    yield Ok(msg);
95                    if is_result {
96                        break;
97                    }
98                }
99                Err(e) => {
100                    yield Err(e);
101                    break;
102                }
103            }
104        }
105        turn_flag.store(false, Ordering::Release);
106    }
107}
108
109// ── Client ───────────────────────────────────────────────────────────────────
110
/// A stateful Claude Code client that manages a persistent session.
///
/// # Lifecycle
///
/// 1. Create with [`Client::new(config)`](Client::new) or [`Client::with_transport(config, transport)`](Client::with_transport)
/// 2. Call [`connect()`](Client::connect) to spawn the CLI and read the init message
/// 3. Use [`send()`](Client::send) to send prompts and stream responses
/// 4. Call [`close()`](Client::close) to shut down cleanly
pub struct Client {
    /// Configuration the client was created with. `connect()` moves the hooks
    /// out of it into the background reader task.
    config: ClientConfig,
    /// Connection to the CLI process (or a test double).
    transport: Arc<dyn Transport>,
    /// Session ID parsed from the init message; `None` until `connect()` succeeds.
    session_id: Option<String>,
    /// Receiving end of the channel fed by the background reader task.
    message_rx: Option<flume::Receiver<Result<Message>>>,
    /// Signals the background reader task to stop.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// `true` while a `send()` / `send_content()` turn is in progress.
    turn_active: Arc<AtomicBool>,
    /// Pending outbound control requests awaiting responses.
    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
    /// Counter for generating unique request IDs.
    request_counter: Arc<AtomicU64>,
}
131
132impl std::fmt::Debug for Client {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        f.debug_struct("Client")
135            .field("session_id", &self.session_id)
136            .field("connected", &self.is_connected())
137            .finish_non_exhaustive()
138    }
139}
140
141impl Client {
142    /// Create a new client with the given configuration.
143    ///
144    /// This does NOT start the CLI — call [`connect()`](Client::connect) next.
145    /// Validates the config (e.g., `cwd` existence) and discovers the CLI binary.
146    ///
147    /// # Errors
148    ///
149    /// Returns [`Error::CliNotFound`] if the CLI binary cannot be discovered,
150    /// or [`Error::Config`] if the configuration is invalid.
151    pub fn new(config: ClientConfig) -> Result<Self> {
152        config.validate()?;
153        let transport = Arc::new(CliTransport::from_config(&config)?);
154        Ok(Self {
155            config,
156            transport,
157            session_id: None,
158            message_rx: None,
159            shutdown_tx: None,
160            turn_active: Arc::new(AtomicBool::new(false)),
161            pending_control: Arc::new(DashMap::new()),
162            request_counter: Arc::new(AtomicU64::new(0)),
163        })
164    }
165
166    /// Create a client with a custom transport (useful for testing).
167    pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
168        config.validate()?;
169        Ok(Self {
170            config,
171            transport,
172            session_id: None,
173            message_rx: None,
174            shutdown_tx: None,
175            turn_active: Arc::new(AtomicBool::new(false)),
176            pending_control: Arc::new(DashMap::new()),
177            request_counter: Arc::new(AtomicU64::new(0)),
178        })
179    }
180
181    /// Connect to the CLI and return the session info from the init message.
182    ///
183    /// This spawns the CLI process (or connects to the mock transport) and
184    /// starts the background reader task. The entire connect sequence
185    /// (transport connect + init message read) is subject to `connect_timeout`.
186    pub async fn connect(&mut self) -> Result<SessionInfo> {
187        let timeout = self.config.connect_timeout;
188        let result = match timeout {
189            Some(d) => tokio::time::timeout(d, self.connect_inner())
190                .await
191                .map_err(|_| {
192                    Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
193                })?,
194            None => self.connect_inner().await,
195        };
196        if result.is_err() {
197            // Clean up: stop background task + kill CLI process.
198            if let Some(tx) = self.shutdown_tx.take() {
199                let _ = tx.send(());
200            }
201            self.message_rx.take();
202            let _ = self.transport.close().await;
203        }
204        result
205    }
206
207    async fn connect_inner(&mut self) -> Result<SessionInfo> {
208        self.transport.connect().await?;
209
210        // Set up the message routing pipeline.
211        let (msg_tx, msg_rx) = flume::bounded(1024);
212        let (shutdown_tx, shutdown_rx) = oneshot::channel();
213
214        let transport = Arc::clone(&self.transport);
215        let message_callback = self.config.message_callback.clone();
216        let pending_control = Arc::clone(&self.pending_control);
217
218        // Move hooks into the background task for hook dispatch.
219        let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
220        let default_hook_timeout = self.config.default_hook_timeout;
221        let hook_transport = Arc::clone(&self.transport);
222
223        // Capture permission callback for the background task.
224        let can_use_tool = self.config.can_use_tool.clone();
225        let perm_transport = Arc::clone(&self.transport);
226
227        // Cancellation token for cooperative consumer-side abort.
228        let cancel_token = self.config.cancellation_token.clone();
229
230        // Shared session_id for the background task.
231        // Updated after the init message is parsed so subsequent hook dispatches
232        // carry the real session ID rather than None.
233        let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
234            Arc::new(std::sync::Mutex::new(None));
235        let hook_session_id = Arc::clone(&shared_session_id);
236
237        // Spawn background reader task.
238        tokio::spawn(async move {
239            let mut stream = transport.read_messages();
240            let mut shutdown = shutdown_rx;
241
242            loop {
243                tokio::select! {
244                    biased;
245                    _ = &mut shutdown => break,
246                    _ = cancelled_or_pending(cancel_token.as_ref()) => break,
247                    item = stream.next() => {
248                        match item {
249                            Some(Ok(value)) => {
250                                // Route control_response messages to pending senders.
251                                if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
252                                    if let Some(req_id) = value.get("request_id").and_then(|v| v.as_str()) {
253                                        if let Some((_, tx)) = pending_control.remove(req_id) {
254                                            let _ = tx.send(value);
255                                        }
256                                    }
257                                    continue;
258                                }
259
260                                // Route hook_request messages to registered hooks.
261                                if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
262                                    if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
263                                        let sid = hook_session_id
264                                            .lock()
265                                            .expect("session_id lock")
266                                            .clone();
267                                        let output = crate::hooks::dispatch_hook(
268                                            &req,
269                                            &hooks,
270                                            default_hook_timeout,
271                                            sid,
272                                        ).await;
273                                        let response = crate::hooks::HookResponse::from_output(
274                                            req.request_id,
275                                            output,
276                                        );
277                                        if let Ok(json) = serde_json::to_string(&response) {
278                                            let _ = hook_transport.write(&json).await;
279                                        }
280                                    }
281                                    continue;
282                                }
283
284                                // Route permission_request messages to the can_use_tool callback.
285                                if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
286                                    if let Some(ref callback) = can_use_tool {
287                                        if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
288                                            let crate::permissions::ControlRequestData::PermissionRequest {
289                                                ref tool_name,
290                                                ref tool_input,
291                                                ref tool_use_id,
292                                                ref suggestions,
293                                            } = req.request;
294                                            let sid = hook_session_id
295                                                .lock()
296                                                .expect("session_id lock")
297                                                .clone()
298                                                .unwrap_or_default();
299                                            let ctx = crate::permissions::PermissionContext {
300                                                tool_use_id: tool_use_id.clone(),
301                                                session_id: sid,
302                                                request_id: req.request_id.clone(),
303                                                suggestions: suggestions.clone(),
304                                            };
305                                            let decision = callback(tool_name, tool_input, ctx).await;
306                                            let response = crate::permissions::ControlResponse {
307                                                kind: "permission_response".into(),
308                                                request_id: req.request_id,
309                                                result: crate::permissions::ControlResponseResult::from(decision),
310                                            };
311                                            if let Ok(json) = serde_json::to_string(&response) {
312                                                let _ = perm_transport.write(&json).await;
313                                            }
314                                        }
315                                    } else {
316                                        // No can_use_tool callback configured. Send a deny
317                                        // response so the CLI doesn't hang waiting forever.
318                                        let deny_response = serde_json::json!({
319                                            "kind": "permission_response",
320                                            "request_id": value.get("request_id")
321                                                .and_then(|v| v.as_str())
322                                                .unwrap_or(""),
323                                            "result": {
324                                                "type": "deny",
325                                                "message": "no permission callback configured"
326                                            }
327                                        });
328                                        if let Ok(json) = serde_json::to_string(&deny_response) {
329                                            let _ = perm_transport.write(&json).await;
330                                        }
331                                        let _ = msg_tx.send(Err(Error::ControlProtocol(
332                                            "received permission_request but no can_use_tool \
333                                             callback is configured — set can_use_tool on \
334                                             ClientConfig or use a PermissionMode that does not \
335                                             require interactive approval"
336                                                .into(),
337                                        )));
338                                    }
339                                    continue;
340                                }
341
342                                // Parse the JSON value into a Message.
343                                let msg: Message = match serde_json::from_value(value) {
344                                    Ok(m) => m,
345                                    Err(e) => {
346                                        let _ = msg_tx.send(Err(Error::Json(e)));
347                                        continue;
348                                    }
349                                };
350
351                                // Apply the message callback.
352                                let msg = match apply_callback(msg, message_callback.as_ref()) {
353                                    Some(m) => m,
354                                    None => continue, // Filtered out
355                                };
356
357                                if msg_tx.send(Ok(msg)).is_err() {
358                                    break; // Receiver dropped
359                                }
360                            }
361                            Some(Err(e)) => {
362                                let _ = msg_tx.send(Err(e));
363                            }
364                            None => break, // Stream ended
365                        }
366                    }
367                }
368            }
369        });
370
371        self.message_rx = Some(msg_rx);
372        self.shutdown_tx = Some(shutdown_tx);
373
374        // Wait for the system/init message.
375        // The overall connect_timeout is enforced by the caller, so we wait
376        // indefinitely here (the outer timeout will cancel us if needed).
377        let init_msg = self
378            .message_rx
379            .as_ref()
380            .unwrap()
381            .recv_async()
382            .await
383            .map_err(|_| Error::Transport("connection closed before init message".into()))?
384            .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
385
386        if let Message::System(ref sys) = init_msg {
387            let info = SessionInfo::try_from(sys)?;
388            self.session_id = Some(info.session_id.clone());
389            // Propagate the session ID to the background task so hook
390            // dispatches after this point carry the real session ID.
391            *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
392            Ok(info)
393        } else {
394            Err(Error::ControlProtocol(format!(
395                "expected system/init as first message, got: {init_msg:?}"
396            )))
397        }
398    }
399
400    /// Send a text prompt and return a stream of response messages.
401    ///
402    /// The stream yields messages until a `Result` message is received
403    /// (which terminates the turn).
404    pub fn send(
405        &self,
406        prompt: impl Into<String>,
407    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
408        let prompt = prompt.into();
409        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
410        let transport = Arc::clone(&self.transport);
411
412        // Guard against concurrent turns.
413        if self
414            .turn_active
415            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
416            .is_err()
417        {
418            return Err(Error::ControlProtocol("turn already in progress".into()));
419        }
420        let turn_flag = Arc::clone(&self.turn_active);
421        let read_timeout = self.config.read_timeout;
422        let cancel = self.config.cancellation_token.clone();
423
424        Ok(async_stream::stream! {
425            // Write the prompt to stdin.
426            if let Err(e) = transport.write(&prompt).await {
427                turn_flag.store(false, Ordering::Release);
428                yield Err(e);
429                return;
430            }
431
432            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
433            tokio::pin!(inner);
434            while let Some(item) = inner.next().await {
435                yield item;
436            }
437        })
438    }
439
440    /// Send structured content blocks (text + images) and return a stream of
441    /// response messages.
442    ///
443    /// This is the multi-modal equivalent of [`send()`](Client::send). Content
444    /// is serialised as a JSON user message and written to the CLI's stdin.
445    ///
446    /// # Errors
447    ///
448    /// Returns [`Error::Config`] if `content` is empty, or [`Error::NotConnected`]
449    /// if the client is not connected.
450    pub fn send_content(
451        &self,
452        content: Vec<UserContent>,
453    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
454        if content.is_empty() {
455            return Err(Error::Config("content must not be empty".into()));
456        }
457
458        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
459        let transport = Arc::clone(&self.transport);
460
461        // Guard against concurrent turns.
462        if self
463            .turn_active
464            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
465            .is_err()
466        {
467            return Err(Error::ControlProtocol("turn already in progress".into()));
468        }
469        let turn_flag = Arc::clone(&self.turn_active);
470        let read_timeout = self.config.read_timeout;
471        let cancel = self.config.cancellation_token.clone();
472
473        Ok(async_stream::stream! {
474            // Serialize content blocks as a JSON user message.
475            let user_message = serde_json::json!({
476                "type": "user",
477                "message": {
478                    "role": "user",
479                    "content": content
480                }
481            });
482            let json = match serde_json::to_string(&user_message) {
483                Ok(j) => j,
484                Err(e) => {
485                    turn_flag.store(false, Ordering::Release);
486                    yield Err(Error::Json(e));
487                    return;
488                }
489            };
490
491            if let Err(e) = transport.write(&json).await {
492                turn_flag.store(false, Ordering::Release);
493                yield Err(e);
494                return;
495            }
496
497            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
498            tokio::pin!(inner);
499            while let Some(item) = inner.next().await {
500                yield item;
501            }
502        })
503    }
504
505    /// Return a stream of all incoming messages (without sending a prompt).
506    ///
507    /// Useful for consuming messages from a resumed session.
508    pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
509        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
510        let read_timeout = self.config.read_timeout;
511        let cancel = self.config.cancellation_token.clone();
512
513        Ok(async_stream::stream! {
514            loop {
515                match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
516                    Ok(msg) => yield Ok(msg),
517                    Err(e) if matches!(e, Error::Transport(_)) => break, // Channel closed
518                    Err(e) => {
519                        yield Err(e);
520                        break;
521                    }
522                }
523            }
524        })
525    }
526
527    /// Send an interrupt signal to the CLI (SIGINT).
528    pub async fn interrupt(&self) -> Result<()> {
529        self.transport.interrupt().await
530    }
531
532    /// Respond to a permission request from the CLI.
533    ///
534    /// When the CLI asks for permission to use a tool, this method sends
535    /// the decision back via the control protocol.
536    pub async fn respond_to_permission(
537        &self,
538        request_id: &str,
539        decision: crate::permissions::PermissionDecision,
540    ) -> Result<()> {
541        use crate::permissions::{ControlResponse, ControlResponseResult};
542
543        let response = ControlResponse {
544            kind: "permission_response".into(),
545            request_id: request_id.to_string(),
546            result: ControlResponseResult::from(decision),
547        };
548        let json = serde_json::to_string(&response).map_err(Error::Json)?;
549        self.transport.write(&json).await
550    }
551
552    // ── Dynamic control ─────────────────────────────────────────────────
553
554    /// Send a control request to the CLI and wait for the response.
555    ///
556    /// This is the low-level mechanism for dynamic mid-session control.
557    /// The request is wrapped in a `{"type": "control_request", ...}` envelope
558    /// and written to stdin. The background reader routes the matching
559    /// `control_response` back via a `oneshot` channel.
560    async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
561        let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
562        let request_id = format!("sdk_req_{counter}");
563
564        let (tx, rx) = oneshot::channel();
565        self.pending_control.insert(request_id.clone(), tx);
566
567        let envelope = serde_json::json!({
568            "type": "control_request",
569            "request_id": request_id,
570            "request": request
571        });
572        let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
573        self.transport.write(&json).await?;
574
575        let timeout = self.config.control_request_timeout;
576        match tokio::time::timeout(timeout, rx).await {
577            Ok(Ok(value)) => Ok(value),
578            Ok(Err(_)) => {
579                self.pending_control.remove(&request_id);
580                Err(Error::ControlProtocol(
581                    "control response channel closed".into(),
582                ))
583            }
584            Err(_) => {
585                self.pending_control.remove(&request_id);
586                Err(Error::Timeout(format!(
587                    "control request timed out after {}s",
588                    timeout.as_secs_f64()
589                )))
590            }
591        }
592    }
593
594    /// Dynamically change the model used for subsequent turns.
595    ///
596    /// Pass `None` to revert to the session's default model.
597    ///
598    /// # Errors
599    ///
600    /// Returns an error if the CLI rejects the model change or the control
601    /// protocol fails.
602    pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
603        self.send_control_request(serde_json::json!({
604            "subtype": "set_model",
605            "model": model
606        }))
607        .await?;
608        Ok(())
609    }
610
611    /// Dynamically change the permission mode for the current session.
612    ///
613    /// # Errors
614    ///
615    /// Returns an error if the CLI rejects the mode change or the control
616    /// protocol fails.
617    pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
618        self.send_control_request(serde_json::json!({
619            "subtype": "set_permission_mode",
620            "mode": mode.as_cli_flag()
621        }))
622        .await?;
623        Ok(())
624    }
625
626    /// Write raw data to the transport's stdin.
627    ///
628    /// This is a low-level method used by free functions like
629    /// [`query_stream_with_content`](crate::query_stream_with_content).
630    pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
631        self.transport.write(data).await
632    }
633
634    /// Take ownership of the message receiver (for use in `query_stream`).
635    ///
636    /// After calling this, `receive_messages()` and `send()` will return
637    /// `NotConnected`.
638    pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
639        self.message_rx.take()
640    }
641
642    /// Returns the configured read timeout.
643    #[must_use]
644    pub fn read_timeout(&self) -> Option<Duration> {
645        self.config.read_timeout
646    }
647
648    /// Returns the session ID if connected.
649    #[must_use]
650    pub fn session_id(&self) -> Option<&str> {
651        self.session_id.as_deref()
652    }
653
654    /// Returns `true` if the client is connected.
655    #[must_use]
656    pub fn is_connected(&self) -> bool {
657        self.transport.is_ready()
658    }
659
660    /// Close the client and shut down the CLI process.
661    ///
662    /// Returns the CLI process exit code if available. After calling this,
663    /// the `Drop` warning will not fire.
664    pub async fn close(&mut self) -> Result<Option<i32>> {
665        // Signal the background task to stop.
666        if let Some(tx) = self.shutdown_tx.take() {
667            let _ = tx.send(());
668        }
669        // Drop the message receiver so the Drop impl knows we cleaned up.
670        self.message_rx.take();
671        self.transport.close().await
672    }
673}
674
675impl Drop for Client {
676    fn drop(&mut self) {
677        if self.shutdown_tx.is_some() || self.message_rx.is_some() {
678            tracing::warn!(
679                "claude_cli_sdk::Client dropped without calling close(). \
680                 Resources may not be cleaned up properly."
681            );
682        }
683    }
684}
685
686// ── Tests ────────────────────────────────────────────────────────────────────
687
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "testing")]
    use crate::testing::{MockTransport, ScenarioBuilder, assistant_text};

    // ── Shared fixtures ─────────────────────────────────────────────────

    /// Raw `system/init` line for session `s1`, used by the hand-scripted
    /// `MockTransport` tests.
    #[cfg(feature = "testing")]
    const INIT_S1: &str = r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#;

    /// Raw successful `result` line for session `s1`.
    #[cfg(feature = "testing")]
    const RESULT_S1: &str = r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#;

    /// Minimal config shared by most tests.
    ///
    /// Gated on `testing` because every caller is; otherwise builds without
    /// the feature report it as dead code.
    #[cfg(feature = "testing")]
    fn test_config() -> ClientConfig {
        ClientConfig::builder().prompt("test").build()
    }

    /// Drain a response stream to completion, panicking on any error item.
    #[cfg(feature = "testing")]
    async fn collect_messages(stream: impl Stream<Item = Result<Message>>) -> Vec<Message> {
        tokio::pin!(stream);
        let mut messages = Vec::new();
        while let Some(item) = stream.next().await {
            messages.push(item.expect("stream yielded an unexpected error"));
        }
        messages
    }

    /// Parse every line written to `transport` that mentions
    /// `permission_response`.
    #[cfg(feature = "testing")]
    fn permission_responses(transport: &MockTransport) -> Vec<serde_json::Value> {
        transport
            .written_lines()
            .iter()
            .filter(|line| line.contains("permission_response"))
            .map(|line| {
                serde_json::from_str::<serde_json::Value>(line)
                    .expect("permission_response line is valid JSON")
            })
            .collect()
    }

    // ── Connection & streaming ──────────────────────────────────────────

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_connect_and_receive_init() {
        let transport = ScenarioBuilder::new("test-session")
            .exchange(vec![assistant_text("Hello!")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        let info = client.connect().await.unwrap();

        assert_eq!(info.session_id, "test-session");
        assert!(client.is_connected());
        assert_eq!(client.session_id(), Some("test-session"));
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_send_yields_messages() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("response")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        client.connect().await.unwrap();

        let messages = collect_messages(client.send("hello").unwrap()).await;

        // One assistant turn, then the terminating result.
        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], Message::Assistant(_)));
        assert!(matches!(&messages[1], Message::Result(_)));
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_close_succeeds() {
        let transport = Arc::new(ScenarioBuilder::new("s1").build());

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        client.connect().await.unwrap();
        assert!(client.close().await.is_ok());
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_message_callback_filters() {
        use crate::callback::MessageCallback;

        // Drop every assistant message; pass everything else through.
        let callback: MessageCallback = Arc::new(|msg| match &msg {
            Message::Assistant(_) => None,
            _ => Some(msg),
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .message_callback(callback)
            .build();

        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("filtered")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let messages = collect_messages(client.send("hello").unwrap()).await;

        // Only the result survives; the assistant message was filtered out.
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], Message::Result(_)));
    }

    #[cfg(feature = "testing")]
    #[test]
    fn client_debug_before_connect() {
        let transport = Arc::new(MockTransport::new());
        let client = Client::with_transport(test_config(), transport).unwrap();
        let debug = format!("{client:?}");
        assert!(debug.contains("Client"));
    }

    // ── Timeout tests ───────────────────────────────────────────────────

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_connect_timeout_fires() {
        let transport = MockTransport::new();
        // The transport's connect stalls far longer than the 50 ms timeout.
        transport.set_connect_delay(Duration::from_secs(5));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(Some(Duration::from_millis(50)))
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        let result = client.connect().await;
        assert!(matches!(result, Err(Error::Timeout(_))));
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_read_timeout_fires() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("delayed")])
            .build();
        // MockTransport applies `recv_delay` to every message, including the
        // init message that `connect()` waits for. The generous 5 s
        // connect_timeout lets the 200 ms-delayed init through. The 50 ms
        // read_timeout is shorter than the 200 ms delay before the assistant
        // message, so the response stream must surface `Error::Timeout`.
        transport.set_recv_delay(Duration::from_millis(200));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(Some(Duration::from_secs(5)))
            .read_timeout(Some(Duration::from_millis(50)))
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut got_timeout = false;
        while let Some(item) = stream.next().await {
            if matches!(item, Err(Error::Timeout(_))) {
                got_timeout = true;
                break;
            }
        }
        assert!(got_timeout, "expected a timeout error");
    }

    // ── Permission callback tests ───────────────────────────────────────

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_callback_invoked_and_responds() {
        use crate::permissions::{CanUseToolCallback, PermissionDecision};
        use std::sync::Mutex;

        // Record the tool name and assert on it from the test body. The
        // callback runs on the client's background task, where a failing
        // assertion would not fail this test.
        let seen_tool: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let seen_tool_cb = Arc::clone(&seen_tool);

        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
            *seen_tool_cb.lock().unwrap() = Some(tool_name.to_owned());
            Box::pin(async { PermissionDecision::allow() })
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .can_use_tool(callback)
            .build();

        let transport = MockTransport::new();
        // Script: init, permission request, assistant, result.
        transport.enqueue(INIT_S1);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
        transport.enqueue(&serde_json::to_string(&assistant_text("done")).unwrap());
        transport.enqueue(RESULT_S1);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, Arc::clone(&transport)).unwrap();
        client.connect().await.unwrap();

        let messages = collect_messages(client.send("hello").unwrap()).await;

        assert_eq!(
            seen_tool.lock().unwrap().as_deref(),
            Some("Bash"),
            "permission callback was not invoked for the Bash tool"
        );

        // The permission request must not leak as a Message: only assistant + result.
        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], Message::Assistant(_)));
        assert!(matches!(&messages[1], Message::Result(_)));

        // Exactly one allow response must have been written back.
        let responses = permission_responses(&transport);
        assert_eq!(
            responses.len(),
            1,
            "expected exactly one permission_response written"
        );
        let resp = &responses[0];
        assert_eq!(resp["kind"], "permission_response");
        assert_eq!(resp["request_id"], "perm-1");
        assert_eq!(resp["result"]["type"], "allow");
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_callback_deny_writes_deny_response() {
        use crate::permissions::{CanUseToolCallback, PermissionDecision};

        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::deny("not allowed") })
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .can_use_tool(callback)
            .build();

        let transport = MockTransport::new();
        transport.enqueue(INIT_S1);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
        transport.enqueue(&serde_json::to_string(&assistant_text("denied")).unwrap());
        transport.enqueue(RESULT_S1);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, Arc::clone(&transport)).unwrap();
        client.connect().await.unwrap();

        // Drive the turn to completion so the deny response gets written.
        collect_messages(client.send("hello").unwrap()).await;

        let responses = permission_responses(&transport);
        assert_eq!(responses.len(), 1);
        let resp = &responses[0];
        assert_eq!(resp["kind"], "permission_response");
        assert_eq!(resp["request_id"], "perm-2");
        assert_eq!(resp["result"]["type"], "deny");
        assert_eq!(resp["result"]["message"], "not allowed");
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_request_without_callback_yields_error() {
        // No can_use_tool callback configured.
        let config = ClientConfig::builder().prompt("test").build();

        let transport = MockTransport::new();
        transport.enqueue(INIT_S1);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
        transport.enqueue(&serde_json::to_string(&assistant_text("after")).unwrap());
        transport.enqueue(RESULT_S1);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut got_error = false;
        while let Some(item) = stream.next().await {
            match item {
                Ok(_) => {}
                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
                    got_error = true;
                }
                Err(e) => panic!("unexpected error: {e}"),
            }
        }

        assert!(
            got_error,
            "should have received a ControlProtocol error for missing callback"
        );
    }

    // ── recv_with_timeout tests ─────────────────────────────────────────

    #[tokio::test]
    async fn recv_with_timeout_respects_cancellation_token() {
        let (tx, rx) = flume::unbounded::<Result<Message>>();
        let token = CancellationToken::new();

        // Fire the token before receiving.
        token.cancel();

        let result = recv_with_timeout(&rx, None, Some(&token)).await;
        assert!(
            matches!(&result, Err(e) if e.is_cancelled()),
            "expected a cancellation error"
        );

        // `tx` stays alive until here so the channel is never disconnected
        // during the receive. The error above can only come from the token.
        drop(tx);
    }

    #[tokio::test]
    async fn recv_with_timeout_none_cancel_still_works() {
        let (_tx, rx) = flume::unbounded::<Result<Message>>();

        // With no cancel token and a short timeout, we should get a timeout error.
        let result = recv_with_timeout(&rx, Some(Duration::from_millis(10)), None).await;
        assert!(matches!(result, Err(Error::Timeout(_))));
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_read_timeout_none_waits() {
        // Every message is delayed by 50 ms. With read_timeout disabled, the
        // client must simply wait for them.
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("delayed")])
            .build();
        transport.set_recv_delay(Duration::from_millis(50));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .read_timeout(None)
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let messages = collect_messages(client.send("hello").unwrap()).await;

        // Assistant + result arrive despite the delay because there is no timeout.
        assert_eq!(messages.len(), 2);
    }
}