Skip to main content

claude_cli_sdk/
client.rs

1//! The core `Client` struct — multi-turn, bidirectional Claude Code sessions.
2//!
3//! # Architecture
4//!
5//! On [`connect()`](Client::connect), the client spawns a background task that:
6//! 1. Reads JSON values from the transport
7//! 2. Routes permission requests to the configured callback
8//! 3. Routes hook requests to registered hook matchers
9//! 4. Applies the message callback (if any)
10//! 5. Forwards resulting messages through a `flume` channel
11//!
12//! Callers consume messages via [`send()`](Client::send) (which returns a
13//! stream), or via [`receive_messages()`](Client::receive_messages).
14
15use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23use tokio_util::sync::CancellationToken;
24
25use crate::callback::apply_callback;
26use crate::config::{ClientConfig, PermissionMode};
27use crate::errors::{Error, Result};
28use crate::transport::{CliTransport, Transport};
29use crate::types::content::UserContent;
30use crate::types::messages::{Message, SessionInfo};
31
32// ── Cancellation helper ──────────────────────────────────────────────────────
33
34/// Wait for the token to fire, or pend forever if `None`.
35///
36/// Useful as a `tokio::select!` branch that compiles away when no token is
37/// provided.
38pub(crate) async fn cancelled_or_pending(token: Option<&CancellationToken>) {
39    match token {
40        Some(t) => t.cancelled().await,
41        None => std::future::pending().await,
42    }
43}
44
45// ── Timeout helper ───────────────────────────────────────────────────────────
46
47/// Receive a message from a flume channel with an optional timeout and
48/// optional cancellation token.
49///
50/// Returns `Error::Timeout` if the deadline expires, `Error::Cancelled` if
51/// the token fires, or `Error::Transport` if the channel is closed.
52pub(crate) async fn recv_with_timeout(
53    rx: &flume::Receiver<Result<Message>>,
54    timeout: Option<Duration>,
55    cancel: Option<&CancellationToken>,
56) -> Result<Message> {
57    let recv_fut = rx.recv_async();
58    tokio::select! {
59        biased;
60        _ = cancelled_or_pending(cancel) => {
61            Err(Error::Cancelled)
62        }
63        result = async {
64            match timeout {
65                Some(d) => match tokio::time::timeout(d, recv_fut).await {
66                    Ok(Ok(msg)) => msg,
67                    Ok(Err(_)) => Err(Error::Transport("message channel closed".into())),
68                    Err(_) => Err(Error::Timeout(format!("read timed out after {}s", d.as_secs_f64()))),
69                },
70                None => match recv_fut.await {
71                    Ok(msg) => msg,
72                    Err(_) => Err(Error::Transport("message channel closed".into())),
73                },
74            }
75        } => result,
76    }
77}
78
79// ── Shared turn stream helper ─────────────────────────────────────────────────
80
81/// Read messages from the receiver until a `Result` message or error,
82/// then clear the turn flag.
83fn read_turn_stream<'a>(
84    rx: &'a flume::Receiver<Result<Message>>,
85    read_timeout: Option<Duration>,
86    turn_flag: Arc<AtomicBool>,
87    cancel: Option<CancellationToken>,
88) -> impl Stream<Item = Result<Message>> + 'a {
89    async_stream::stream! {
90        loop {
91            match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
92                Ok(msg) => {
93                    let is_result = matches!(&msg, Message::Result(_));
94                    yield Ok(msg);
95                    if is_result {
96                        break;
97                    }
98                }
99                Err(e) => {
100                    yield Err(e);
101                    break;
102                }
103            }
104        }
105        turn_flag.store(false, Ordering::Release);
106    }
107}
108
109// ── Client ───────────────────────────────────────────────────────────────────
110
111/// A stateful Claude Code client that manages a persistent session.
112///
113/// # Lifecycle
114///
115/// 1. Create with [`Client::new(config)`](Client::new) or [`Client::with_transport(config, transport)`](Client::with_transport)
116/// 2. Call [`connect()`](Client::connect) to spawn the CLI and read the init message
117/// 3. Use [`send()`](Client::send) to send prompts and stream responses
118/// 4. Call [`close()`](Client::close) to shut down cleanly
119pub struct Client {
120    config: ClientConfig,
121    transport: Arc<dyn Transport>,
122    session_id: Option<String>,
123    message_rx: Option<flume::Receiver<Result<Message>>>,
124    shutdown_tx: Option<oneshot::Sender<()>>,
125    turn_active: Arc<AtomicBool>,
126    /// Pending outbound control requests awaiting responses.
127    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
128    /// Counter for generating unique request IDs.
129    request_counter: Arc<AtomicU64>,
130}
131
132impl std::fmt::Debug for Client {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        f.debug_struct("Client")
135            .field("session_id", &self.session_id)
136            .field("connected", &self.is_connected())
137            .finish_non_exhaustive()
138    }
139}
140
141impl Client {
142    /// Create a new client with the given configuration.
143    ///
144    /// This does NOT start the CLI — call [`connect()`](Client::connect) next.
145    /// Validates the config (e.g., `cwd` existence) and discovers the CLI binary.
146    ///
147    /// # Errors
148    ///
149    /// Returns [`Error::CliNotFound`] if the CLI binary cannot be discovered,
150    /// or [`Error::Config`] if the configuration is invalid.
151    pub fn new(config: ClientConfig) -> Result<Self> {
152        config.validate()?;
153        let transport = Arc::new(CliTransport::from_config(&config)?);
154        Ok(Self {
155            config,
156            transport,
157            session_id: None,
158            message_rx: None,
159            shutdown_tx: None,
160            turn_active: Arc::new(AtomicBool::new(false)),
161            pending_control: Arc::new(DashMap::new()),
162            request_counter: Arc::new(AtomicU64::new(0)),
163        })
164    }
165
166    /// Create a client with a custom transport (useful for testing).
167    pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
168        config.validate()?;
169        Ok(Self {
170            config,
171            transport,
172            session_id: None,
173            message_rx: None,
174            shutdown_tx: None,
175            turn_active: Arc::new(AtomicBool::new(false)),
176            pending_control: Arc::new(DashMap::new()),
177            request_counter: Arc::new(AtomicU64::new(0)),
178        })
179    }
180
181    /// Connect to the CLI and return the session info from the init message.
182    ///
183    /// This spawns the CLI process (or connects to the mock transport) and
184    /// starts the background reader task. The entire connect sequence
185    /// (transport connect + init message read) is subject to `connect_timeout`.
186    pub async fn connect(&mut self) -> Result<SessionInfo> {
187        let timeout = self.config.connect_timeout;
188        let result = match timeout {
189            Some(d) => tokio::time::timeout(d, self.connect_inner())
190                .await
191                .map_err(|_| {
192                    Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
193                })?,
194            None => self.connect_inner().await,
195        };
196        if result.is_err() {
197            // Clean up: stop background task + kill CLI process.
198            if let Some(tx) = self.shutdown_tx.take() {
199                let _ = tx.send(());
200            }
201            self.message_rx.take();
202            let _ = self.transport.close().await;
203        }
204        result
205    }
206
207    async fn connect_inner(&mut self) -> Result<SessionInfo> {
208        self.transport.connect().await?;
209
210        // Set up the message routing pipeline.
211        let (msg_tx, msg_rx) = flume::bounded(1024);
212        let (shutdown_tx, shutdown_rx) = oneshot::channel();
213
214        let transport = Arc::clone(&self.transport);
215        let message_callback = self.config.message_callback.clone();
216        let pending_control = Arc::clone(&self.pending_control);
217
218        // Move hooks into the background task for hook dispatch.
219        let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
220        let default_hook_timeout = self.config.default_hook_timeout;
221        let hook_transport = Arc::clone(&self.transport);
222
223        // Capture permission callback for the background task.
224        let can_use_tool = self.config.can_use_tool.clone();
225        let perm_transport = Arc::clone(&self.transport);
226
227        // Cancellation token for cooperative consumer-side abort.
228        let cancel_token = self.config.cancellation_token.clone();
229
230        // Shared session_id for the background task.
231        // Updated after the init message is parsed so subsequent hook dispatches
232        // carry the real session ID rather than None.
233        let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
234            Arc::new(std::sync::Mutex::new(None));
235        let hook_session_id = Arc::clone(&shared_session_id);
236
237        // Spawn background reader task.
238        tokio::spawn(async move {
239            let mut stream = transport.read_messages();
240            let mut shutdown = shutdown_rx;
241
242            loop {
243                tokio::select! {
244                    biased;
245                    _ = &mut shutdown => break,
246                    _ = cancelled_or_pending(cancel_token.as_ref()) => break,
247                    item = stream.next() => {
248                        match item {
249                            Some(Ok(value)) => {
250                                // Route control_response messages to pending senders.
251                                if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
252                                    if let Some(req_id) = value.get("request_id").and_then(|v| v.as_str()) {
253                                        if let Some((_, tx)) = pending_control.remove(req_id) {
254                                            let _ = tx.send(value);
255                                        }
256                                    }
257                                    continue;
258                                }
259
260                                // Route hook_request messages to registered hooks.
261                                if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
262                                    if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
263                                        let sid = hook_session_id
264                                            .lock()
265                                            .expect("session_id lock")
266                                            .clone();
267                                        let output = crate::hooks::dispatch_hook(
268                                            &req,
269                                            &hooks,
270                                            default_hook_timeout,
271                                            sid,
272                                        ).await;
273                                        let response = crate::hooks::HookResponse::from_output(
274                                            req.request_id,
275                                            output,
276                                        );
277                                        if let Ok(json) = serde_json::to_string(&response) {
278                                            let _ = hook_transport.write(&json).await;
279                                        }
280                                    }
281                                    continue;
282                                }
283
284                                // Route permission_request messages to the can_use_tool callback.
285                                if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
286                                    if let Some(ref callback) = can_use_tool {
287                                        if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
288                                            let crate::permissions::ControlRequestData::PermissionRequest {
289                                                ref tool_name,
290                                                ref tool_input,
291                                                ref tool_use_id,
292                                                ref suggestions,
293                                            } = req.request;
294                                            let sid = hook_session_id
295                                                .lock()
296                                                .expect("session_id lock")
297                                                .clone()
298                                                .unwrap_or_default();
299                                            let ctx = crate::permissions::PermissionContext {
300                                                tool_use_id: tool_use_id.clone(),
301                                                session_id: sid,
302                                                request_id: req.request_id.clone(),
303                                                suggestions: suggestions.clone(),
304                                            };
305                                            let decision = callback(tool_name, tool_input, ctx).await;
306                                            let response = crate::permissions::ControlResponse {
307                                                kind: "permission_response".into(),
308                                                request_id: req.request_id,
309                                                result: crate::permissions::ControlResponseResult::from(decision),
310                                            };
311                                            if let Ok(json) = serde_json::to_string(&response) {
312                                                let _ = perm_transport.write(&json).await;
313                                            }
314                                        }
315                                    } else {
316                                        // No can_use_tool callback configured. Send a deny
317                                        // response so the CLI doesn't hang waiting forever.
318                                        let deny_response = serde_json::json!({
319                                            "kind": "permission_response",
320                                            "request_id": value.get("request_id")
321                                                .and_then(|v| v.as_str())
322                                                .unwrap_or(""),
323                                            "result": {
324                                                "type": "deny",
325                                                "message": "no permission callback configured"
326                                            }
327                                        });
328                                        if let Ok(json) = serde_json::to_string(&deny_response) {
329                                            let _ = perm_transport.write(&json).await;
330                                        }
331                                        let _ = msg_tx.send(Err(Error::ControlProtocol(
332                                            "received permission_request but no can_use_tool \
333                                             callback is configured — set can_use_tool on \
334                                             ClientConfig or use a PermissionMode that does not \
335                                             require interactive approval"
336                                                .into(),
337                                        )));
338                                    }
339                                    continue;
340                                }
341
342                                // Parse the JSON value into a Message.
343                                let msg: Message = match serde_json::from_value(value) {
344                                    Ok(m) => m,
345                                    Err(e) => {
346                                        let _ = msg_tx.send(Err(Error::Json(e)));
347                                        continue;
348                                    }
349                                };
350
351                                // Apply the message callback.
352                                let msg = match apply_callback(msg, message_callback.as_ref()) {
353                                    Some(m) => m,
354                                    None => continue, // Filtered out
355                                };
356
357                                if msg_tx.send(Ok(msg)).is_err() {
358                                    break; // Receiver dropped
359                                }
360                            }
361                            Some(Err(e)) => {
362                                let _ = msg_tx.send(Err(e));
363                            }
364                            None => break, // Stream ended
365                        }
366                    }
367                }
368            }
369        });
370
371        self.message_rx = Some(msg_rx);
372        self.shutdown_tx = Some(shutdown_tx);
373
374        // If an init trigger message is configured (e.g., for --input-format
375        // stream-json mode), write it to stdin now. The CLI won't emit the
376        // system/init message until it receives stdin input in this mode.
377        if let Some(ref msg) = self.config.init_stdin_message {
378            self.transport.write(msg).await?;
379        }
380
381        // Wait for the system/init message.
382        // The overall connect_timeout is enforced by the caller, so we wait
383        // indefinitely here (the outer timeout will cancel us if needed).
384        let init_msg = self
385            .message_rx
386            .as_ref()
387            .unwrap()
388            .recv_async()
389            .await
390            .map_err(|_| Error::Transport("connection closed before init message".into()))?
391            .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
392
393        if let Message::System(ref sys) = init_msg {
394            let info = SessionInfo::try_from(sys)?;
395            self.session_id = Some(info.session_id.clone());
396            // Propagate the session ID to the background task so hook
397            // dispatches after this point carry the real session ID.
398            *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
399            Ok(info)
400        } else {
401            Err(Error::ControlProtocol(format!(
402                "expected system/init as first message, got: {init_msg:?}"
403            )))
404        }
405    }
406
407    /// Send a text prompt and return a stream of response messages.
408    ///
409    /// The stream yields messages until a `Result` message is received
410    /// (which terminates the turn).
411    pub fn send(
412        &self,
413        prompt: impl Into<String>,
414    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
415        let prompt = prompt.into();
416        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
417        let transport = Arc::clone(&self.transport);
418
419        // Guard against concurrent turns.
420        if self
421            .turn_active
422            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
423            .is_err()
424        {
425            return Err(Error::ControlProtocol("turn already in progress".into()));
426        }
427        let turn_flag = Arc::clone(&self.turn_active);
428        let read_timeout = self.config.read_timeout;
429        let cancel = self.config.cancellation_token.clone();
430
431        Ok(async_stream::stream! {
432            // Write the prompt to stdin.
433            if let Err(e) = transport.write(&prompt).await {
434                turn_flag.store(false, Ordering::Release);
435                yield Err(e);
436                return;
437            }
438
439            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
440            tokio::pin!(inner);
441            while let Some(item) = inner.next().await {
442                yield item;
443            }
444        })
445    }
446
447    /// Send structured content blocks (text + images) and return a stream of
448    /// response messages.
449    ///
450    /// This is the multi-modal equivalent of [`send()`](Client::send). Content
451    /// is serialised as a JSON user message and written to the CLI's stdin.
452    ///
453    /// # Errors
454    ///
455    /// Returns [`Error::Config`] if `content` is empty, or [`Error::NotConnected`]
456    /// if the client is not connected.
457    pub fn send_content(
458        &self,
459        content: Vec<UserContent>,
460    ) -> Result<impl Stream<Item = Result<Message>> + '_> {
461        if content.is_empty() {
462            return Err(Error::Config("content must not be empty".into()));
463        }
464
465        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
466        let transport = Arc::clone(&self.transport);
467
468        // Guard against concurrent turns.
469        if self
470            .turn_active
471            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
472            .is_err()
473        {
474            return Err(Error::ControlProtocol("turn already in progress".into()));
475        }
476        let turn_flag = Arc::clone(&self.turn_active);
477        let read_timeout = self.config.read_timeout;
478        let cancel = self.config.cancellation_token.clone();
479
480        Ok(async_stream::stream! {
481            // Serialize content blocks as a JSON user message.
482            let user_message = serde_json::json!({
483                "type": "user",
484                "message": {
485                    "role": "user",
486                    "content": content
487                }
488            });
489            let json = match serde_json::to_string(&user_message) {
490                Ok(j) => j,
491                Err(e) => {
492                    turn_flag.store(false, Ordering::Release);
493                    yield Err(Error::Json(e));
494                    return;
495                }
496            };
497
498            if let Err(e) = transport.write(&json).await {
499                turn_flag.store(false, Ordering::Release);
500                yield Err(e);
501                return;
502            }
503
504            let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
505            tokio::pin!(inner);
506            while let Some(item) = inner.next().await {
507                yield item;
508            }
509        })
510    }
511
512    /// Return a stream of all incoming messages (without sending a prompt).
513    ///
514    /// Useful for consuming messages from a resumed session.
515    pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
516        let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
517        let read_timeout = self.config.read_timeout;
518        let cancel = self.config.cancellation_token.clone();
519
520        Ok(async_stream::stream! {
521            loop {
522                match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
523                    Ok(msg) => yield Ok(msg),
524                    Err(e) if matches!(e, Error::Transport(_)) => break, // Channel closed
525                    Err(e) => {
526                        yield Err(e);
527                        break;
528                    }
529                }
530            }
531        })
532    }
533
534    /// Write raw text to the CLI's stdin without creating a response stream.
535    ///
536    /// Use this only when [`receive_messages()`] is already consuming responses.
537    /// **Do not call this while a [`send()`] turn is in progress** — doing so
538    /// interleaves writes on the same stdin handle and produces undefined
539    /// protocol behaviour.
540    pub async fn write_to_stdin(&self, text: &str) -> Result<()> {
541        debug_assert!(
542            !self.turn_active.load(Ordering::Relaxed),
543            "write_to_stdin called while a send() turn is active"
544        );
545        self.transport.write(text).await
546    }
547
548    /// Send an interrupt signal to the CLI (SIGINT).
549    pub async fn interrupt(&self) -> Result<()> {
550        self.transport.interrupt().await
551    }
552
553    /// Respond to a permission request from the CLI.
554    ///
555    /// When the CLI asks for permission to use a tool, this method sends
556    /// the decision back via the control protocol.
557    pub async fn respond_to_permission(
558        &self,
559        request_id: &str,
560        decision: crate::permissions::PermissionDecision,
561    ) -> Result<()> {
562        use crate::permissions::{ControlResponse, ControlResponseResult};
563
564        let response = ControlResponse {
565            kind: "permission_response".into(),
566            request_id: request_id.to_string(),
567            result: ControlResponseResult::from(decision),
568        };
569        let json = serde_json::to_string(&response).map_err(Error::Json)?;
570        self.transport.write(&json).await
571    }
572
573    // ── Dynamic control ─────────────────────────────────────────────────
574
575    /// Send a control request to the CLI and wait for the response.
576    ///
577    /// This is the low-level mechanism for dynamic mid-session control.
578    /// The request is wrapped in a `{"type": "control_request", ...}` envelope
579    /// and written to stdin. The background reader routes the matching
580    /// `control_response` back via a `oneshot` channel.
581    async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
582        let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
583        let request_id = format!("sdk_req_{counter}");
584
585        let (tx, rx) = oneshot::channel();
586        self.pending_control.insert(request_id.clone(), tx);
587
588        let envelope = serde_json::json!({
589            "type": "control_request",
590            "request_id": request_id,
591            "request": request
592        });
593        let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
594        self.transport.write(&json).await?;
595
596        let timeout = self.config.control_request_timeout;
597        match tokio::time::timeout(timeout, rx).await {
598            Ok(Ok(value)) => Ok(value),
599            Ok(Err(_)) => {
600                self.pending_control.remove(&request_id);
601                Err(Error::ControlProtocol(
602                    "control response channel closed".into(),
603                ))
604            }
605            Err(_) => {
606                self.pending_control.remove(&request_id);
607                Err(Error::Timeout(format!(
608                    "control request timed out after {}s",
609                    timeout.as_secs_f64()
610                )))
611            }
612        }
613    }
614
615    /// Dynamically change the model used for subsequent turns.
616    ///
617    /// Pass `None` to revert to the session's default model.
618    ///
619    /// # Errors
620    ///
621    /// Returns an error if the CLI rejects the model change or the control
622    /// protocol fails.
623    pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
624        self.send_control_request(serde_json::json!({
625            "subtype": "set_model",
626            "model": model
627        }))
628        .await?;
629        Ok(())
630    }
631
632    /// Dynamically change the permission mode for the current session.
633    ///
634    /// # Errors
635    ///
636    /// Returns an error if the CLI rejects the mode change or the control
637    /// protocol fails.
638    pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
639        self.send_control_request(serde_json::json!({
640            "subtype": "set_permission_mode",
641            "mode": mode.as_cli_flag()
642        }))
643        .await?;
644        Ok(())
645    }
646
647    /// Write raw data to the transport's stdin.
648    ///
649    /// This is a low-level method used by free functions like
650    /// [`query_stream_with_content`](crate::query_stream_with_content).
651    pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
652        self.transport.write(data).await
653    }
654
655    /// Take ownership of the message receiver (for use in `query_stream`).
656    ///
657    /// After calling this, `receive_messages()` and `send()` will return
658    /// `NotConnected`.
659    pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
660        self.message_rx.take()
661    }
662
663    /// Returns the configured read timeout.
664    #[must_use]
665    pub fn read_timeout(&self) -> Option<Duration> {
666        self.config.read_timeout
667    }
668
669    /// Returns the session ID if connected.
670    #[must_use]
671    pub fn session_id(&self) -> Option<&str> {
672        self.session_id.as_deref()
673    }
674
675    /// Returns `true` if the client is connected.
676    #[must_use]
677    pub fn is_connected(&self) -> bool {
678        self.transport.is_ready()
679    }
680
681    /// Close the client and shut down the CLI process.
682    ///
683    /// Returns the CLI process exit code if available. After calling this,
684    /// the `Drop` warning will not fire.
685    pub async fn close(&mut self) -> Result<Option<i32>> {
686        // Signal the background task to stop.
687        if let Some(tx) = self.shutdown_tx.take() {
688            let _ = tx.send(());
689        }
690        // Drop the message receiver so the Drop impl knows we cleaned up.
691        self.message_rx.take();
692        self.transport.close().await
693    }
694}
695
696impl Drop for Client {
697    fn drop(&mut self) {
698        if self.shutdown_tx.is_some() || self.message_rx.is_some() {
699            tracing::warn!(
700                "claude_cli_sdk::Client dropped without calling close(). \
701                 Resources may not be cleaned up properly."
702            );
703        }
704    }
705}
706
707// ── Tests ────────────────────────────────────────────────────────────────────
708
709#[cfg(test)]
710mod tests {
711    use super::*;
712    use crate::config::ClientConfig;
713
714    #[cfg(feature = "testing")]
715    use crate::testing::{ScenarioBuilder, assistant_text};
716
717    fn test_config() -> ClientConfig {
718        ClientConfig::builder().prompt("test").build()
719    }
720
721    #[cfg(feature = "testing")]
722    #[tokio::test]
723    async fn client_connect_and_receive_init() {
724        let transport = ScenarioBuilder::new("test-session")
725            .exchange(vec![assistant_text("Hello!")])
726            .build();
727        let transport = Arc::new(transport);
728
729        let mut client = Client::with_transport(test_config(), transport).unwrap();
730        let info = client.connect().await.unwrap();
731
732        assert_eq!(info.session_id, "test-session");
733        assert!(client.is_connected());
734        assert_eq!(client.session_id(), Some("test-session"));
735    }
736
737    #[cfg(feature = "testing")]
738    #[tokio::test]
739    async fn client_send_yields_messages() {
740        let transport = ScenarioBuilder::new("s1")
741            .exchange(vec![assistant_text("response")])
742            .build();
743        let transport = Arc::new(transport);
744
745        let mut client = Client::with_transport(test_config(), transport).unwrap();
746        client.connect().await.unwrap();
747
748        let stream = client.send("hello").unwrap();
749        tokio::pin!(stream);
750
751        let mut messages = Vec::new();
752        while let Some(msg) = stream.next().await {
753            messages.push(msg.unwrap());
754        }
755
756        // Should get assistant + result
757        assert_eq!(messages.len(), 2);
758        assert!(matches!(&messages[0], Message::Assistant(_)));
759        assert!(matches!(&messages[1], Message::Result(_)));
760    }
761
762    #[cfg(feature = "testing")]
763    #[tokio::test]
764    async fn client_close_succeeds() {
765        let transport = ScenarioBuilder::new("s1").build();
766        let transport = Arc::new(transport);
767
768        let mut client = Client::with_transport(test_config(), transport).unwrap();
769        client.connect().await.unwrap();
770        assert!(client.close().await.is_ok());
771    }
772
773    #[cfg(feature = "testing")]
774    #[tokio::test]
775    async fn client_message_callback_filters() {
776        use crate::callback::MessageCallback;
777
778        // Filter out assistant messages.
779        let callback: MessageCallback = Arc::new(|msg| match &msg {
780            Message::Assistant(_) => None,
781            _ => Some(msg),
782        });
783
784        let config = ClientConfig::builder()
785            .prompt("test")
786            .message_callback(callback)
787            .build();
788
789        let transport = ScenarioBuilder::new("s1")
790            .exchange(vec![assistant_text("filtered")])
791            .build();
792        let transport = Arc::new(transport);
793
794        let mut client = Client::with_transport(config, transport).unwrap();
795        client.connect().await.unwrap();
796
797        let stream = client.send("hello").unwrap();
798        tokio::pin!(stream);
799
800        let mut messages = Vec::new();
801        while let Some(msg) = stream.next().await {
802            messages.push(msg.unwrap());
803        }
804
805        // Only result (assistant was filtered).
806        assert_eq!(messages.len(), 1);
807        assert!(matches!(&messages[0], Message::Result(_)));
808    }
809
810    #[cfg(feature = "testing")]
811    #[test]
812    fn client_debug_before_connect() {
813        let transport = Arc::new(crate::testing::MockTransport::new());
814        let client = Client::with_transport(test_config(), transport).unwrap();
815        let debug = format!("{client:?}");
816        assert!(debug.contains("Client"));
817    }
818
819    // ── Timeout tests ────────────────────────────────────────────────────
820
821    #[cfg(feature = "testing")]
822    #[tokio::test]
823    async fn client_connect_timeout_fires() {
824        use crate::testing::MockTransport;
825
826        let transport = MockTransport::new();
827        // Set connect delay longer than timeout.
828        transport.set_connect_delay(Duration::from_secs(5));
829        let transport = Arc::new(transport);
830
831        let config = ClientConfig::builder()
832            .prompt("test")
833            .connect_timeout(Some(Duration::from_millis(50)))
834            .build();
835
836        let mut client = Client::with_transport(config, transport).unwrap();
837        let result = client.connect().await;
838        assert!(result.is_err());
839        assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
840    }
841
842    #[cfg(feature = "testing")]
843    #[tokio::test]
844    async fn client_read_timeout_fires() {
845        // Build a scenario with init + assistant, but add a large recv_delay
846        // so the assistant message arrives after the read timeout.
847        let transport = ScenarioBuilder::new("s1")
848            .exchange(vec![assistant_text("delayed")])
849            .build();
850        // Delay each message by 5 seconds — way longer than our 50ms timeout.
851        // The init message also gets this delay, but the connect has no timeout
852        // wrapping for the recv path since connect_timeout is set to None here
853        // and connect_inner waits indefinitely for init.
854        // Actually, we need init to arrive fast but subsequent messages slow.
855        // The MockTransport applies recv_delay to ALL messages including init.
856        // So set connect_timeout high enough and read_timeout low.
857        transport.set_recv_delay(Duration::from_millis(200));
858        let transport = Arc::new(transport);
859
860        let config = ClientConfig::builder()
861            .prompt("test")
862            .connect_timeout(Some(Duration::from_secs(5)))
863            .read_timeout(Some(Duration::from_millis(50)))
864            .build();
865
866        let mut client = Client::with_transport(config, transport).unwrap();
867        client.connect().await.unwrap();
868
869        let stream = client.send("hello").unwrap();
870        tokio::pin!(stream);
871
872        let mut got_timeout = false;
873        while let Some(msg) = stream.next().await {
874            if let Err(Error::Timeout(_)) = msg {
875                got_timeout = true;
876                break;
877            }
878        }
879        assert!(got_timeout, "expected a timeout error");
880    }
881
882    #[cfg(feature = "testing")]
883    #[tokio::test]
884    async fn client_permission_callback_invoked_and_responds() {
885        use crate::permissions::{CanUseToolCallback, PermissionDecision};
886        use crate::testing::MockTransport;
887        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
888
889        let invoked = Arc::new(AtomicBool::new(false));
890        let invoked_clone = Arc::clone(&invoked);
891
892        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
893            let invoked = Arc::clone(&invoked_clone);
894            let tool = tool_name.to_owned();
895            Box::pin(async move {
896                invoked.store(true, AtomicOrdering::Release);
897                assert_eq!(tool, "Bash");
898                PermissionDecision::allow()
899            })
900        });
901
902        let config = ClientConfig::builder()
903            .prompt("test")
904            .can_use_tool(callback)
905            .build();
906
907        let transport = MockTransport::new();
908        // Enqueue: init, permission_request, assistant, result
909        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
910        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
911        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
912        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
913        let transport = Arc::new(transport);
914
915        let mut client = Client::with_transport(config, transport.clone()).unwrap();
916        client.connect().await.unwrap();
917
918        let stream = client.send("hello").unwrap();
919        tokio::pin!(stream);
920        let mut messages = Vec::new();
921        while let Some(msg) = stream.next().await {
922            messages.push(msg.unwrap());
923        }
924
925        // Callback should have been invoked.
926        assert!(
927            invoked.load(AtomicOrdering::Acquire),
928            "permission callback was not invoked"
929        );
930
931        // Permission request should NOT leak as a Message — only assistant + result.
932        assert_eq!(messages.len(), 2);
933        assert!(matches!(&messages[0], Message::Assistant(_)));
934        assert!(matches!(&messages[1], Message::Result(_)));
935
936        // Verify a permission_response was written back.
937        let written = transport.written_lines();
938        let perm_responses: Vec<_> = written
939            .iter()
940            .filter(|line| line.contains("permission_response"))
941            .collect();
942        assert_eq!(
943            perm_responses.len(),
944            1,
945            "expected exactly one permission_response written"
946        );
947        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
948        assert_eq!(resp["kind"], "permission_response");
949        assert_eq!(resp["request_id"], "perm-1");
950        assert_eq!(resp["result"]["type"], "allow");
951    }
952
953    #[cfg(feature = "testing")]
954    #[tokio::test]
955    async fn client_permission_callback_deny_writes_deny_response() {
956        use crate::permissions::{CanUseToolCallback, PermissionDecision};
957        use crate::testing::MockTransport;
958
959        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
960            Box::pin(async { PermissionDecision::deny("not allowed") })
961        });
962
963        let config = ClientConfig::builder()
964            .prompt("test")
965            .can_use_tool(callback)
966            .build();
967
968        let transport = MockTransport::new();
969        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
970        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
971        transport
972            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
973        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
974        let transport = Arc::new(transport);
975
976        let mut client = Client::with_transport(config, transport.clone()).unwrap();
977        client.connect().await.unwrap();
978
979        let stream = client.send("hello").unwrap();
980        tokio::pin!(stream);
981        let mut messages = Vec::new();
982        while let Some(msg) = stream.next().await {
983            messages.push(msg.unwrap());
984        }
985
986        // Verify deny response was written.
987        let written = transport.written_lines();
988        let perm_responses: Vec<_> = written
989            .iter()
990            .filter(|line| line.contains("permission_response"))
991            .collect();
992        assert_eq!(perm_responses.len(), 1);
993        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
994        assert_eq!(resp["kind"], "permission_response");
995        assert_eq!(resp["request_id"], "perm-2");
996        assert_eq!(resp["result"]["type"], "deny");
997        assert_eq!(resp["result"]["message"], "not allowed");
998    }
999
1000    #[cfg(feature = "testing")]
1001    #[tokio::test]
1002    async fn client_permission_request_without_callback_yields_error() {
1003        use crate::testing::MockTransport;
1004
1005        // No can_use_tool callback configured.
1006        let config = ClientConfig::builder().prompt("test").build();
1007
1008        let transport = MockTransport::new();
1009        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1010        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
1011        transport
1012            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
1013        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1014        let transport = Arc::new(transport);
1015
1016        let mut client = Client::with_transport(config, transport).unwrap();
1017        client.connect().await.unwrap();
1018
1019        let stream = client.send("hello").unwrap();
1020        tokio::pin!(stream);
1021
1022        let mut got_error = false;
1023        let mut messages = Vec::new();
1024        while let Some(result) = stream.next().await {
1025            match result {
1026                Ok(msg) => messages.push(msg),
1027                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
1028                    got_error = true;
1029                }
1030                Err(e) => panic!("unexpected error: {e}"),
1031            }
1032        }
1033
1034        assert!(
1035            got_error,
1036            "should have received a ControlProtocol error for missing callback"
1037        );
1038    }
1039
1040    #[tokio::test]
1041    async fn recv_with_timeout_respects_cancellation_token() {
1042        let (tx, rx) = flume::unbounded::<Result<Message>>();
1043        let token = CancellationToken::new();
1044
1045        // Cancel immediately.
1046        token.cancel();
1047
1048        let result = recv_with_timeout(&rx, None, Some(&token)).await;
1049        assert!(result.is_err());
1050        assert!(result.unwrap_err().is_cancelled());
1051
1052        // Sender is still alive — we didn't get a transport error.
1053        drop(tx);
1054    }
1055
1056    #[tokio::test]
1057    async fn recv_with_timeout_none_cancel_still_works() {
1058        let (_tx, rx) = flume::unbounded::<Result<Message>>();
1059
1060        // With no cancel token and a short timeout, we should get a timeout error.
1061        let result = recv_with_timeout(&rx, Some(Duration::from_millis(10)), None).await;
1062        assert!(matches!(result, Err(Error::Timeout(_))));
1063    }
1064
1065    #[cfg(feature = "testing")]
1066    #[tokio::test]
1067    async fn client_read_timeout_none_waits() {
1068        // MockTransport with recv_delay < reasonable wait, read_timeout None.
1069        let transport = ScenarioBuilder::new("s1")
1070            .exchange(vec![assistant_text("delayed")])
1071            .build();
1072        transport.set_recv_delay(Duration::from_millis(50));
1073        let transport = Arc::new(transport);
1074
1075        let config = ClientConfig::builder()
1076            .prompt("test")
1077            .read_timeout(None)
1078            .build();
1079
1080        let mut client = Client::with_transport(config, transport).unwrap();
1081        client.connect().await.unwrap();
1082
1083        let stream = client.send("hello").unwrap();
1084        tokio::pin!(stream);
1085
1086        let mut messages = Vec::new();
1087        while let Some(msg) = stream.next().await {
1088            messages.push(msg.unwrap());
1089        }
1090
1091        // Should get assistant + result even with delay since no timeout.
1092        assert_eq!(messages.len(), 2);
1093    }
1094}