1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23use tokio_util::sync::CancellationToken;
24
25use crate::callback::apply_callback;
26use crate::config::{ClientConfig, PermissionMode};
27use crate::errors::{Error, Result};
28use crate::transport::{CliTransport, Transport};
29use crate::types::content::UserContent;
30use crate::types::messages::{Message, SessionInfo};
31
32pub(crate) async fn cancelled_or_pending(token: Option<&CancellationToken>) {
39 match token {
40 Some(t) => t.cancelled().await,
41 None => std::future::pending().await,
42 }
43}
44
45pub(crate) async fn recv_with_timeout(
53 rx: &flume::Receiver<Result<Message>>,
54 timeout: Option<Duration>,
55 cancel: Option<&CancellationToken>,
56) -> Result<Message> {
57 let recv_fut = rx.recv_async();
58 tokio::select! {
59 biased;
60 _ = cancelled_or_pending(cancel) => {
61 Err(Error::Cancelled)
62 }
63 result = async {
64 match timeout {
65 Some(d) => match tokio::time::timeout(d, recv_fut).await {
66 Ok(Ok(msg)) => msg,
67 Ok(Err(_)) => Err(Error::Transport("message channel closed".into())),
68 Err(_) => Err(Error::Timeout(format!("read timed out after {}s", d.as_secs_f64()))),
69 },
70 None => match recv_fut.await {
71 Ok(msg) => msg,
72 Err(_) => Err(Error::Transport("message channel closed".into())),
73 },
74 }
75 } => result,
76 }
77}
78
79fn read_turn_stream<'a>(
84 rx: &'a flume::Receiver<Result<Message>>,
85 read_timeout: Option<Duration>,
86 turn_flag: Arc<AtomicBool>,
87 cancel: Option<CancellationToken>,
88) -> impl Stream<Item = Result<Message>> + 'a {
89 async_stream::stream! {
90 loop {
91 match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
92 Ok(msg) => {
93 let is_result = matches!(&msg, Message::Result(_));
94 yield Ok(msg);
95 if is_result {
96 break;
97 }
98 }
99 Err(e) => {
100 yield Err(e);
101 break;
102 }
103 }
104 }
105 turn_flag.store(false, Ordering::Release);
106 }
107}
108
/// A stateful client that talks to a CLI process through a [`Transport`].
///
/// After [`Client::connect`] a background task reads from the transport and
/// forwards parsed messages through an internal channel.
pub struct Client {
    /// Configuration this client was created with.
    config: ClientConfig,
    /// Transport used for all reads and writes.
    transport: Arc<dyn Transport>,
    /// Session id taken from the init message; `None` until connected.
    session_id: Option<String>,
    /// Receiving end of the channel fed by the background reader task.
    message_rx: Option<flume::Receiver<Result<Message>>>,
    /// Signals the background reader task to stop.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// `true` while a `send`/`send_content` turn is in progress.
    turn_active: Arc<AtomicBool>,
    /// Outstanding control requests, keyed by request id; the reader task
    /// completes the matching sender when a `control_response` arrives.
    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
    /// Counter used to build unique `sdk_req_<n>` control request ids.
    request_counter: Arc<AtomicU64>,
}
131
132impl std::fmt::Debug for Client {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 f.debug_struct("Client")
135 .field("session_id", &self.session_id)
136 .field("connected", &self.is_connected())
137 .finish_non_exhaustive()
138 }
139}
140
141impl Client {
142 pub fn new(config: ClientConfig) -> Result<Self> {
152 config.validate()?;
153 let transport = Arc::new(CliTransport::from_config(&config)?);
154 Ok(Self {
155 config,
156 transport,
157 session_id: None,
158 message_rx: None,
159 shutdown_tx: None,
160 turn_active: Arc::new(AtomicBool::new(false)),
161 pending_control: Arc::new(DashMap::new()),
162 request_counter: Arc::new(AtomicU64::new(0)),
163 })
164 }
165
166 pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
168 config.validate()?;
169 Ok(Self {
170 config,
171 transport,
172 session_id: None,
173 message_rx: None,
174 shutdown_tx: None,
175 turn_active: Arc::new(AtomicBool::new(false)),
176 pending_control: Arc::new(DashMap::new()),
177 request_counter: Arc::new(AtomicU64::new(0)),
178 })
179 }
180
181 pub async fn connect(&mut self) -> Result<SessionInfo> {
187 let timeout = self.config.connect_timeout;
188 let result = match timeout {
189 Some(d) => tokio::time::timeout(d, self.connect_inner())
190 .await
191 .map_err(|_| {
192 Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
193 })?,
194 None => self.connect_inner().await,
195 };
196 if result.is_err() {
197 if let Some(tx) = self.shutdown_tx.take() {
199 let _ = tx.send(());
200 }
201 self.message_rx.take();
202 let _ = self.transport.close().await;
203 }
204 result
205 }
206
207 async fn connect_inner(&mut self) -> Result<SessionInfo> {
208 self.transport.connect().await?;
209
210 let (msg_tx, msg_rx) = flume::bounded(1024);
212 let (shutdown_tx, shutdown_rx) = oneshot::channel();
213
214 let transport = Arc::clone(&self.transport);
215 let message_callback = self.config.message_callback.clone();
216 let pending_control = Arc::clone(&self.pending_control);
217
218 let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
220 let default_hook_timeout = self.config.default_hook_timeout;
221 let hook_transport = Arc::clone(&self.transport);
222
223 let can_use_tool = self.config.can_use_tool.clone();
225 let perm_transport = Arc::clone(&self.transport);
226
227 let cancel_token = self.config.cancellation_token.clone();
229
230 let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
234 Arc::new(std::sync::Mutex::new(None));
235 let hook_session_id = Arc::clone(&shared_session_id);
236
237 tokio::spawn(async move {
239 let mut stream = transport.read_messages();
240 let mut shutdown = shutdown_rx;
241
242 loop {
243 tokio::select! {
244 biased;
245 _ = &mut shutdown => break,
246 _ = cancelled_or_pending(cancel_token.as_ref()) => break,
247 item = stream.next() => {
248 match item {
249 Some(Ok(value)) => {
250 if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
252 if let Some(req_id) = value.get("request_id").and_then(|v| v.as_str()) {
253 if let Some((_, tx)) = pending_control.remove(req_id) {
254 let _ = tx.send(value);
255 }
256 }
257 continue;
258 }
259
260 if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
262 if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
263 let sid = hook_session_id
264 .lock()
265 .expect("session_id lock")
266 .clone();
267 let output = crate::hooks::dispatch_hook(
268 &req,
269 &hooks,
270 default_hook_timeout,
271 sid,
272 ).await;
273 let response = crate::hooks::HookResponse::from_output(
274 req.request_id,
275 output,
276 );
277 if let Ok(json) = serde_json::to_string(&response) {
278 let _ = hook_transport.write(&json).await;
279 }
280 }
281 continue;
282 }
283
284 if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
286 if let Some(ref callback) = can_use_tool {
287 if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
288 let crate::permissions::ControlRequestData::PermissionRequest {
289 ref tool_name,
290 ref tool_input,
291 ref tool_use_id,
292 ref suggestions,
293 } = req.request;
294 let sid = hook_session_id
295 .lock()
296 .expect("session_id lock")
297 .clone()
298 .unwrap_or_default();
299 let ctx = crate::permissions::PermissionContext {
300 tool_use_id: tool_use_id.clone(),
301 session_id: sid,
302 request_id: req.request_id.clone(),
303 suggestions: suggestions.clone(),
304 };
305 let decision = callback(tool_name, tool_input, ctx).await;
306 let response = crate::permissions::ControlResponse {
307 kind: "permission_response".into(),
308 request_id: req.request_id,
309 result: crate::permissions::ControlResponseResult::from(decision),
310 };
311 if let Ok(json) = serde_json::to_string(&response) {
312 let _ = perm_transport.write(&json).await;
313 }
314 }
315 } else {
316 let deny_response = serde_json::json!({
319 "kind": "permission_response",
320 "request_id": value.get("request_id")
321 .and_then(|v| v.as_str())
322 .unwrap_or(""),
323 "result": {
324 "type": "deny",
325 "message": "no permission callback configured"
326 }
327 });
328 if let Ok(json) = serde_json::to_string(&deny_response) {
329 let _ = perm_transport.write(&json).await;
330 }
331 let _ = msg_tx.send(Err(Error::ControlProtocol(
332 "received permission_request but no can_use_tool \
333 callback is configured — set can_use_tool on \
334 ClientConfig or use a PermissionMode that does not \
335 require interactive approval"
336 .into(),
337 )));
338 }
339 continue;
340 }
341
342 let msg: Message = match serde_json::from_value(value) {
344 Ok(m) => m,
345 Err(e) => {
346 let _ = msg_tx.send(Err(Error::Json(e)));
347 continue;
348 }
349 };
350
351 let msg = match apply_callback(msg, message_callback.as_ref()) {
353 Some(m) => m,
354 None => continue, };
356
357 if msg_tx.send(Ok(msg)).is_err() {
358 break; }
360 }
361 Some(Err(e)) => {
362 let _ = msg_tx.send(Err(e));
363 }
364 None => break, }
366 }
367 }
368 }
369 });
370
371 self.message_rx = Some(msg_rx);
372 self.shutdown_tx = Some(shutdown_tx);
373
374 if let Some(ref msg) = self.config.init_stdin_message {
378 self.transport.write(msg).await?;
379 }
380
381 let init_msg = self
385 .message_rx
386 .as_ref()
387 .unwrap()
388 .recv_async()
389 .await
390 .map_err(|_| Error::Transport("connection closed before init message".into()))?
391 .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
392
393 if let Message::System(ref sys) = init_msg {
394 let info = SessionInfo::try_from(sys)?;
395 self.session_id = Some(info.session_id.clone());
396 *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
399 Ok(info)
400 } else {
401 Err(Error::ControlProtocol(format!(
402 "expected system/init as first message, got: {init_msg:?}"
403 )))
404 }
405 }
406
407 pub fn send(
412 &self,
413 prompt: impl Into<String>,
414 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
415 let prompt = prompt.into();
416 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
417 let transport = Arc::clone(&self.transport);
418
419 if self
421 .turn_active
422 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
423 .is_err()
424 {
425 return Err(Error::ControlProtocol("turn already in progress".into()));
426 }
427 let turn_flag = Arc::clone(&self.turn_active);
428 let read_timeout = self.config.read_timeout;
429 let cancel = self.config.cancellation_token.clone();
430
431 Ok(async_stream::stream! {
432 if let Err(e) = transport.write(&prompt).await {
434 turn_flag.store(false, Ordering::Release);
435 yield Err(e);
436 return;
437 }
438
439 let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
440 tokio::pin!(inner);
441 while let Some(item) = inner.next().await {
442 yield item;
443 }
444 })
445 }
446
447 pub fn send_content(
458 &self,
459 content: Vec<UserContent>,
460 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
461 if content.is_empty() {
462 return Err(Error::Config("content must not be empty".into()));
463 }
464
465 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
466 let transport = Arc::clone(&self.transport);
467
468 if self
470 .turn_active
471 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
472 .is_err()
473 {
474 return Err(Error::ControlProtocol("turn already in progress".into()));
475 }
476 let turn_flag = Arc::clone(&self.turn_active);
477 let read_timeout = self.config.read_timeout;
478 let cancel = self.config.cancellation_token.clone();
479
480 Ok(async_stream::stream! {
481 let user_message = serde_json::json!({
483 "type": "user",
484 "message": {
485 "role": "user",
486 "content": content
487 }
488 });
489 let json = match serde_json::to_string(&user_message) {
490 Ok(j) => j,
491 Err(e) => {
492 turn_flag.store(false, Ordering::Release);
493 yield Err(Error::Json(e));
494 return;
495 }
496 };
497
498 if let Err(e) = transport.write(&json).await {
499 turn_flag.store(false, Ordering::Release);
500 yield Err(e);
501 return;
502 }
503
504 let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
505 tokio::pin!(inner);
506 while let Some(item) = inner.next().await {
507 yield item;
508 }
509 })
510 }
511
512 pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
516 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
517 let read_timeout = self.config.read_timeout;
518 let cancel = self.config.cancellation_token.clone();
519
520 Ok(async_stream::stream! {
521 loop {
522 match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
523 Ok(msg) => yield Ok(msg),
524 Err(e) if matches!(e, Error::Transport(_)) => break, Err(e) => {
526 yield Err(e);
527 break;
528 }
529 }
530 }
531 })
532 }
533
/// Writes `text` directly to the transport.
///
/// In debug builds this asserts that no `send()` turn is currently active.
///
/// # Errors
///
/// Returns whatever error the transport's `write` reports.
pub async fn write_to_stdin(&self, text: &str) -> Result<()> {
    debug_assert!(
        !self.turn_active.load(Ordering::Relaxed),
        "write_to_stdin called while a send() turn is active"
    );
    self.transport.write(text).await
}
547
/// Forwards an interrupt request to the transport.
///
/// # Errors
///
/// Returns whatever error the transport's `interrupt` reports.
pub async fn interrupt(&self) -> Result<()> {
    self.transport.interrupt().await
}
552
553 pub async fn respond_to_permission(
558 &self,
559 request_id: &str,
560 decision: crate::permissions::PermissionDecision,
561 ) -> Result<()> {
562 use crate::permissions::{ControlResponse, ControlResponseResult};
563
564 let response = ControlResponse {
565 kind: "permission_response".into(),
566 request_id: request_id.to_string(),
567 result: ControlResponseResult::from(decision),
568 };
569 let json = serde_json::to_string(&response).map_err(Error::Json)?;
570 self.transport.write(&json).await
571 }
572
573 async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
582 let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
583 let request_id = format!("sdk_req_{counter}");
584
585 let (tx, rx) = oneshot::channel();
586 self.pending_control.insert(request_id.clone(), tx);
587
588 let envelope = serde_json::json!({
589 "type": "control_request",
590 "request_id": request_id,
591 "request": request
592 });
593 let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
594 self.transport.write(&json).await?;
595
596 let timeout = self.config.control_request_timeout;
597 match tokio::time::timeout(timeout, rx).await {
598 Ok(Ok(value)) => Ok(value),
599 Ok(Err(_)) => {
600 self.pending_control.remove(&request_id);
601 Err(Error::ControlProtocol(
602 "control response channel closed".into(),
603 ))
604 }
605 Err(_) => {
606 self.pending_control.remove(&request_id);
607 Err(Error::Timeout(format!(
608 "control request timed out after {}s",
609 timeout.as_secs_f64()
610 )))
611 }
612 }
613 }
614
615 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
624 self.send_control_request(serde_json::json!({
625 "subtype": "set_model",
626 "model": model
627 }))
628 .await?;
629 Ok(())
630 }
631
632 pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
639 self.send_control_request(serde_json::json!({
640 "subtype": "set_permission_mode",
641 "mode": mode.as_cli_flag()
642 }))
643 .await?;
644 Ok(())
645 }
646
/// Crate-internal helper that writes `data` straight to the transport.
pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
    self.transport.write(data).await
}
654
/// Takes the message receiver out of the client, leaving `None` behind.
///
/// After this call, `send`, `send_content` and `receive_messages` return
/// [`Error::NotConnected`] because they need the receiver.
pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
    self.message_rx.take()
}
662
/// Returns the configured read timeout, if any.
#[must_use]
pub fn read_timeout(&self) -> Option<Duration> {
    self.config.read_timeout
}
668
/// Returns the session id from the init message, or `None` if the client
/// has not connected successfully yet.
#[must_use]
pub fn session_id(&self) -> Option<&str> {
    self.session_id.as_deref()
}
674
/// Reports whether the underlying transport says it is ready.
#[must_use]
pub fn is_connected(&self) -> bool {
    self.transport.is_ready()
}
680
681 pub async fn close(&mut self) -> Result<Option<i32>> {
686 if let Some(tx) = self.shutdown_tx.take() {
688 let _ = tx.send(());
689 }
690 self.message_rx.take();
692 self.transport.close().await
693 }
694}
695
696impl Drop for Client {
697 fn drop(&mut self) {
698 if self.shutdown_tx.is_some() || self.message_rx.is_some() {
699 tracing::warn!(
700 "claude_cli_sdk::Client dropped without calling close(). \
701 Resources may not be cleaned up properly."
702 );
703 }
704 }
705}
706
707#[cfg(test)]
710mod tests {
711 use super::*;
712 use crate::config::ClientConfig;
713
714 #[cfg(feature = "testing")]
715 use crate::testing::{ScenarioBuilder, assistant_text};
716
/// Minimal configuration shared by the tests: only a prompt is set.
fn test_config() -> ClientConfig {
    ClientConfig::builder().prompt("test").build()
}
720
/// Connecting against a scripted transport yields the scripted session id
/// and leaves the client in the connected state.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_connect_and_receive_init() {
    let transport = ScenarioBuilder::new("test-session")
        .exchange(vec![assistant_text("Hello!")])
        .build();
    let transport = Arc::new(transport);

    let mut client = Client::with_transport(test_config(), transport).unwrap();
    let info = client.connect().await.unwrap();

    assert_eq!(info.session_id, "test-session");
    assert!(client.is_connected());
    assert_eq!(client.session_id(), Some("test-session"));
}
736
/// A single scripted exchange produces the assistant message followed by a
/// result message, after which the stream ends.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_send_yields_messages() {
    let transport = ScenarioBuilder::new("s1")
        .exchange(vec![assistant_text("response")])
        .build();
    let transport = Arc::new(transport);

    let mut client = Client::with_transport(test_config(), transport).unwrap();
    client.connect().await.unwrap();

    let stream = client.send("hello").unwrap();
    tokio::pin!(stream);

    let mut messages = Vec::new();
    while let Some(msg) = stream.next().await {
        messages.push(msg.unwrap());
    }

    // Assistant message plus the terminating result message.
    assert_eq!(messages.len(), 2);
    assert!(matches!(&messages[0], Message::Assistant(_)));
    assert!(matches!(&messages[1], Message::Result(_)));
}
761
/// Closing a freshly connected client succeeds.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_close_succeeds() {
    let transport = ScenarioBuilder::new("s1").build();
    let transport = Arc::new(transport);

    let mut client = Client::with_transport(test_config(), transport).unwrap();
    client.connect().await.unwrap();
    assert!(client.close().await.is_ok());
}
772
/// A message callback that returns `None` for assistant messages removes
/// them from the stream; only the result message remains.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_message_callback_filters() {
    use crate::callback::MessageCallback;

    // Drop every assistant message, pass everything else through.
    let callback: MessageCallback = Arc::new(|msg| match &msg {
        Message::Assistant(_) => None,
        _ => Some(msg),
    });

    let config = ClientConfig::builder()
        .prompt("test")
        .message_callback(callback)
        .build();

    let transport = ScenarioBuilder::new("s1")
        .exchange(vec![assistant_text("filtered")])
        .build();
    let transport = Arc::new(transport);

    let mut client = Client::with_transport(config, transport).unwrap();
    client.connect().await.unwrap();

    let stream = client.send("hello").unwrap();
    tokio::pin!(stream);

    let mut messages = Vec::new();
    while let Some(msg) = stream.next().await {
        messages.push(msg.unwrap());
    }

    assert_eq!(messages.len(), 1);
    assert!(matches!(&messages[0], Message::Result(_)));
}
809
/// The `Debug` implementation works on a client that was never connected.
#[cfg(feature = "testing")]
#[test]
fn client_debug_before_connect() {
    let transport = Arc::new(crate::testing::MockTransport::new());
    let client = Client::with_transport(test_config(), transport).unwrap();
    let debug = format!("{client:?}");
    assert!(debug.contains("Client"));
}
818
/// A transport whose connect takes 5 s, combined with a 50 ms connect
/// timeout, makes `connect` fail with `Error::Timeout`.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_connect_timeout_fires() {
    use crate::testing::MockTransport;

    let transport = MockTransport::new();
    transport.set_connect_delay(Duration::from_secs(5));
    let transport = Arc::new(transport);

    let config = ClientConfig::builder()
        .prompt("test")
        .connect_timeout(Some(Duration::from_millis(50)))
        .build();

    let mut client = Client::with_transport(config, transport).unwrap();
    let result = client.connect().await;
    assert!(result.is_err());
    assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
}
841
/// With a 200 ms receive delay on the transport and a 50 ms read timeout,
/// the turn stream yields an `Error::Timeout`.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_read_timeout_fires() {
    let transport = ScenarioBuilder::new("s1")
        .exchange(vec![assistant_text("delayed")])
        .build();
    transport.set_recv_delay(Duration::from_millis(200));
    let transport = Arc::new(transport);

    // Generous connect timeout so that only the read timeout can fire.
    let config = ClientConfig::builder()
        .prompt("test")
        .connect_timeout(Some(Duration::from_secs(5)))
        .read_timeout(Some(Duration::from_millis(50)))
        .build();

    let mut client = Client::with_transport(config, transport).unwrap();
    client.connect().await.unwrap();

    let stream = client.send("hello").unwrap();
    tokio::pin!(stream);

    let mut got_timeout = false;
    while let Some(msg) = stream.next().await {
        if let Err(Error::Timeout(_)) = msg {
            got_timeout = true;
            break;
        }
    }
    assert!(got_timeout, "expected a timeout error");
}
881
/// A `permission_request` from the transport invokes the `can_use_tool`
/// callback, is not surfaced as a message, and results in exactly one
/// `permission_response` with an `allow` result being written back.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_permission_callback_invoked_and_responds() {
    use crate::permissions::{CanUseToolCallback, PermissionDecision};
    use crate::testing::MockTransport;
    use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};

    // Records whether the callback ran.
    let invoked = Arc::new(AtomicBool::new(false));
    let invoked_clone = Arc::clone(&invoked);

    let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
        let invoked = Arc::clone(&invoked_clone);
        let tool = tool_name.to_owned();
        Box::pin(async move {
            invoked.store(true, AtomicOrdering::Release);
            assert_eq!(tool, "Bash");
            PermissionDecision::allow()
        })
    });

    let config = ClientConfig::builder()
        .prompt("test")
        .can_use_tool(callback)
        .build();

    // Scripted input: init, permission request, assistant message, result.
    let transport = MockTransport::new();
    transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
    transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
    transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
    transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
    let transport = Arc::new(transport);

    let mut client = Client::with_transport(config, transport.clone()).unwrap();
    client.connect().await.unwrap();

    let stream = client.send("hello").unwrap();
    tokio::pin!(stream);
    let mut messages = Vec::new();
    while let Some(msg) = stream.next().await {
        messages.push(msg.unwrap());
    }

    assert!(
        invoked.load(AtomicOrdering::Acquire),
        "permission callback was not invoked"
    );

    // The permission request itself is handled internally, not yielded.
    assert_eq!(messages.len(), 2);
    assert!(matches!(&messages[0], Message::Assistant(_)));
    assert!(matches!(&messages[1], Message::Result(_)));

    let written = transport.written_lines();
    let perm_responses: Vec<_> = written
        .iter()
        .filter(|line| line.contains("permission_response"))
        .collect();
    assert_eq!(
        perm_responses.len(),
        1,
        "expected exactly one permission_response written"
    );
    let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
    assert_eq!(resp["kind"], "permission_response");
    assert_eq!(resp["request_id"], "perm-1");
    assert_eq!(resp["result"]["type"], "allow");
}
952
/// When the `can_use_tool` callback denies, the written
/// `permission_response` carries a `deny` result with the callback's
/// message.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_permission_callback_deny_writes_deny_response() {
    use crate::permissions::{CanUseToolCallback, PermissionDecision};
    use crate::testing::MockTransport;

    let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
        Box::pin(async { PermissionDecision::deny("not allowed") })
    });

    let config = ClientConfig::builder()
        .prompt("test")
        .can_use_tool(callback)
        .build();

    // Scripted input: init, permission request, assistant message, result.
    let transport = MockTransport::new();
    transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
    transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
    transport
        .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
    transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
    let transport = Arc::new(transport);

    let mut client = Client::with_transport(config, transport.clone()).unwrap();
    client.connect().await.unwrap();

    let stream = client.send("hello").unwrap();
    tokio::pin!(stream);
    let mut messages = Vec::new();
    while let Some(msg) = stream.next().await {
        messages.push(msg.unwrap());
    }

    let written = transport.written_lines();
    let perm_responses: Vec<_> = written
        .iter()
        .filter(|line| line.contains("permission_response"))
        .collect();
    assert_eq!(perm_responses.len(), 1);
    let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
    assert_eq!(resp["kind"], "permission_response");
    assert_eq!(resp["request_id"], "perm-2");
    assert_eq!(resp["result"]["type"], "deny");
    assert_eq!(resp["result"]["message"], "not allowed");
}
999
/// Without a `can_use_tool` callback, a `permission_request` surfaces as an
/// `Error::ControlProtocol` on the turn stream whose text mentions
/// `can_use_tool`.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_permission_request_without_callback_yields_error() {
    use crate::testing::MockTransport;

    // No permission callback configured.
    let config = ClientConfig::builder().prompt("test").build();

    let transport = MockTransport::new();
    transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
    transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
    transport
        .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
    transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
    let transport = Arc::new(transport);

    let mut client = Client::with_transport(config, transport).unwrap();
    client.connect().await.unwrap();

    let stream = client.send("hello").unwrap();
    tokio::pin!(stream);

    let mut got_error = false;
    let mut messages = Vec::new();
    while let Some(result) = stream.next().await {
        match result {
            Ok(msg) => messages.push(msg),
            Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
                got_error = true;
            }
            // Any other error kind is a test failure.
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    assert!(
        got_error,
        "should have received a ControlProtocol error for missing callback"
    );
}
1039
/// An already-cancelled token makes `recv_with_timeout` return a
/// cancellation error even though the channel is still open.
#[tokio::test]
async fn recv_with_timeout_respects_cancellation_token() {
    let (tx, rx) = flume::unbounded::<Result<Message>>();
    let token = CancellationToken::new();

    token.cancel();

    let result = recv_with_timeout(&rx, None, Some(&token)).await;
    assert!(result.is_err());
    assert!(result.unwrap_err().is_cancelled());

    // Keep the sender alive until here so the channel is not closed earlier.
    drop(tx);
}
1055
/// Without a cancellation token, the timeout still applies: an empty but
/// open channel produces `Error::Timeout`.
#[tokio::test]
async fn recv_with_timeout_none_cancel_still_works() {
    // `_tx` stays alive so the channel is open but empty.
    let (_tx, rx) = flume::unbounded::<Result<Message>>();

    let result = recv_with_timeout(&rx, Some(Duration::from_millis(10)), None).await;
    assert!(matches!(result, Err(Error::Timeout(_))));
}
1064
/// With the read timeout disabled, a 50 ms receive delay does not cause an
/// error: both messages of the turn are delivered.
#[cfg(feature = "testing")]
#[tokio::test]
async fn client_read_timeout_none_waits() {
    let transport = ScenarioBuilder::new("s1")
        .exchange(vec![assistant_text("delayed")])
        .build();
    transport.set_recv_delay(Duration::from_millis(50));
    let transport = Arc::new(transport);

    let config = ClientConfig::builder()
        .prompt("test")
        .read_timeout(None)
        .build();

    let mut client = Client::with_transport(config, transport).unwrap();
    client.connect().await.unwrap();

    let stream = client.send("hello").unwrap();
    tokio::pin!(stream);

    let mut messages = Vec::new();
    while let Some(msg) = stream.next().await {
        messages.push(msg.unwrap());
    }

    assert_eq!(messages.len(), 2);
}
1094}