1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23use tokio_util::sync::CancellationToken;
24
25use crate::callback::apply_callback;
26use crate::config::{ClientConfig, PermissionMode};
27use crate::errors::{Error, Result};
28use crate::transport::{CliTransport, Transport};
29use crate::types::content::UserContent;
30use crate::types::messages::{Message, SessionInfo};
31
32pub(crate) async fn cancelled_or_pending(token: Option<&CancellationToken>) {
39 match token {
40 Some(t) => t.cancelled().await,
41 None => std::future::pending().await,
42 }
43}
44
45pub(crate) async fn recv_with_timeout(
53 rx: &flume::Receiver<Result<Message>>,
54 timeout: Option<Duration>,
55 cancel: Option<&CancellationToken>,
56) -> Result<Message> {
57 let recv_fut = rx.recv_async();
58 tokio::select! {
59 biased;
60 _ = cancelled_or_pending(cancel) => {
61 Err(Error::Cancelled)
62 }
63 result = async {
64 match timeout {
65 Some(d) => match tokio::time::timeout(d, recv_fut).await {
66 Ok(Ok(msg)) => msg,
67 Ok(Err(_)) => Err(Error::Transport("message channel closed".into())),
68 Err(_) => Err(Error::Timeout(format!("read timed out after {}s", d.as_secs_f64()))),
69 },
70 None => match recv_fut.await {
71 Ok(msg) => msg,
72 Err(_) => Err(Error::Transport("message channel closed".into())),
73 },
74 }
75 } => result,
76 }
77}
78
79fn read_turn_stream<'a>(
84 rx: &'a flume::Receiver<Result<Message>>,
85 read_timeout: Option<Duration>,
86 turn_flag: Arc<AtomicBool>,
87 cancel: Option<CancellationToken>,
88) -> impl Stream<Item = Result<Message>> + 'a {
89 async_stream::stream! {
90 loop {
91 match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
92 Ok(msg) => {
93 let is_result = matches!(&msg, Message::Result(_));
94 yield Ok(msg);
95 if is_result {
96 break;
97 }
98 }
99 Err(e) => {
100 yield Err(e);
101 break;
102 }
103 }
104 }
105 turn_flag.store(false, Ordering::Release);
106 }
107}
108
/// Client for a session driven through a [`Transport`].
///
/// The lifecycle is:
/// 1. build with `new` / `with_transport`;
/// 2. `connect`, which starts a background reader task and waits for the
///    init message;
/// 3. run turns with `send` / `send_content`;
/// 4. finish with `close`.
///
/// Only one turn can be active at a time (see `turn_active`).
pub struct Client {
    /// Configuration the client was built with. `connect` moves the hooks
    /// out of it (`std::mem::take`).
    config: ClientConfig,
    /// Transport used for every read and write.
    transport: Arc<dyn Transport>,
    /// Session id taken from the init message; `None` until `connect` succeeds.
    session_id: Option<String>,
    /// Receiving end of the channel filled by the background reader task.
    /// `None` before `connect`, and after `close` or `take_message_rx`.
    message_rx: Option<flume::Receiver<Result<Message>>>,
    /// Tells the background reader task to stop; taken by `close`.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Set by `send` / `send_content` while a turn is in progress, so that
    /// overlapping turns are rejected.
    turn_active: Arc<AtomicBool>,
    /// Outstanding control requests keyed by request id. The reader task
    /// removes an entry and completes it when the matching
    /// `control_response` arrives.
    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
    /// Counter used to build unique `sdk_req_N` control request ids.
    request_counter: Arc<AtomicU64>,
}
131
132impl std::fmt::Debug for Client {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 f.debug_struct("Client")
135 .field("session_id", &self.session_id)
136 .field("connected", &self.is_connected())
137 .finish_non_exhaustive()
138 }
139}
140
141impl Client {
142 pub fn new(config: ClientConfig) -> Result<Self> {
152 config.validate()?;
153 let transport = Arc::new(CliTransport::from_config(&config)?);
154 Ok(Self {
155 config,
156 transport,
157 session_id: None,
158 message_rx: None,
159 shutdown_tx: None,
160 turn_active: Arc::new(AtomicBool::new(false)),
161 pending_control: Arc::new(DashMap::new()),
162 request_counter: Arc::new(AtomicU64::new(0)),
163 })
164 }
165
166 pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
168 Ok(Self {
169 config,
170 transport,
171 session_id: None,
172 message_rx: None,
173 shutdown_tx: None,
174 turn_active: Arc::new(AtomicBool::new(false)),
175 pending_control: Arc::new(DashMap::new()),
176 request_counter: Arc::new(AtomicU64::new(0)),
177 })
178 }
179
180 pub async fn connect(&mut self) -> Result<SessionInfo> {
186 let timeout = self.config.connect_timeout;
187 match timeout {
188 Some(d) => tokio::time::timeout(d, self.connect_inner())
189 .await
190 .map_err(|_| {
191 Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
192 })?,
193 None => self.connect_inner().await,
194 }
195 }
196
197 async fn connect_inner(&mut self) -> Result<SessionInfo> {
198 self.transport.connect().await?;
199
200 let (msg_tx, msg_rx) = flume::unbounded();
202 let (shutdown_tx, shutdown_rx) = oneshot::channel();
203
204 let transport = Arc::clone(&self.transport);
205 let message_callback = self.config.message_callback.clone();
206 let pending_control = Arc::clone(&self.pending_control);
207
208 let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
210 let default_hook_timeout = self.config.default_hook_timeout;
211 let hook_transport = Arc::clone(&self.transport);
212
213 let can_use_tool = self.config.can_use_tool.clone();
215 let perm_transport = Arc::clone(&self.transport);
216
217 let cancel_token = self.config.cancellation_token.clone();
219
220 let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
224 Arc::new(std::sync::Mutex::new(None));
225 let hook_session_id = Arc::clone(&shared_session_id);
226
227 tokio::spawn(async move {
229 let mut stream = transport.read_messages();
230 let mut shutdown = shutdown_rx;
231
232 loop {
233 tokio::select! {
234 biased;
235 _ = &mut shutdown => break,
236 _ = cancelled_or_pending(cancel_token.as_ref()) => break,
237 item = stream.next() => {
238 match item {
239 Some(Ok(value)) => {
240 if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
242 if let Some(req_id) = value.get("request_id").and_then(|v| v.as_str()) {
243 if let Some((_, tx)) = pending_control.remove(req_id) {
244 let _ = tx.send(value);
245 }
246 }
247 continue;
248 }
249
250 if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
252 if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
253 let sid = hook_session_id
254 .lock()
255 .expect("session_id lock")
256 .clone();
257 let output = crate::hooks::dispatch_hook(
258 &req,
259 &hooks,
260 default_hook_timeout,
261 sid,
262 ).await;
263 let response = crate::hooks::HookResponse::from_output(
264 req.request_id,
265 output,
266 );
267 if let Ok(json) = serde_json::to_string(&response) {
268 let _ = hook_transport.write(&json).await;
269 }
270 }
271 continue;
272 }
273
274 if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
276 if let Some(ref callback) = can_use_tool {
277 if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
278 let crate::permissions::ControlRequestData::PermissionRequest {
279 ref tool_name,
280 ref tool_input,
281 ref tool_use_id,
282 ref suggestions,
283 } = req.request;
284 let sid = hook_session_id
285 .lock()
286 .expect("session_id lock")
287 .clone()
288 .unwrap_or_default();
289 let ctx = crate::permissions::PermissionContext {
290 tool_use_id: tool_use_id.clone(),
291 session_id: sid,
292 request_id: req.request_id.clone(),
293 suggestions: suggestions.clone(),
294 };
295 let decision = callback(tool_name, tool_input, ctx).await;
296 let response = crate::permissions::ControlResponse {
297 kind: "permission_response".into(),
298 request_id: req.request_id,
299 result: crate::permissions::ControlResponseResult::from(decision),
300 };
301 if let Ok(json) = serde_json::to_string(&response) {
302 let _ = perm_transport.write(&json).await;
303 }
304 }
305 } else {
306 let _ = msg_tx.send(Err(Error::ControlProtocol(
312 "received permission_request but no can_use_tool \
313 callback is configured — set can_use_tool on \
314 ClientConfig or use a PermissionMode that does not \
315 require interactive approval"
316 .into(),
317 )));
318 }
319 continue;
320 }
321
322 let msg: Message = match serde_json::from_value(value) {
324 Ok(m) => m,
325 Err(e) => {
326 let _ = msg_tx.send(Err(Error::Json(e)));
327 continue;
328 }
329 };
330
331 let msg = match apply_callback(msg, message_callback.as_ref()) {
333 Some(m) => m,
334 None => continue, };
336
337 if msg_tx.send(Ok(msg)).is_err() {
338 break; }
340 }
341 Some(Err(e)) => {
342 let _ = msg_tx.send(Err(e));
343 }
344 None => break, }
346 }
347 }
348 }
349 });
350
351 self.message_rx = Some(msg_rx);
352 self.shutdown_tx = Some(shutdown_tx);
353
354 let init_msg = self
358 .message_rx
359 .as_ref()
360 .unwrap()
361 .recv_async()
362 .await
363 .map_err(|_| Error::Transport("connection closed before init message".into()))?
364 .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
365
366 if let Message::System(ref sys) = init_msg {
367 let info = SessionInfo::try_from(sys)?;
368 self.session_id = Some(info.session_id.clone());
369 *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
372 Ok(info)
373 } else {
374 Err(Error::ControlProtocol(format!(
375 "expected system/init as first message, got: {init_msg:?}"
376 )))
377 }
378 }
379
380 pub fn send(
385 &self,
386 prompt: impl Into<String>,
387 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
388 let prompt = prompt.into();
389 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
390 let transport = Arc::clone(&self.transport);
391
392 if self
394 .turn_active
395 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
396 .is_err()
397 {
398 return Err(Error::ControlProtocol("turn already in progress".into()));
399 }
400 let turn_flag = Arc::clone(&self.turn_active);
401 let read_timeout = self.config.read_timeout;
402 let cancel = self.config.cancellation_token.clone();
403
404 Ok(async_stream::stream! {
405 if let Err(e) = transport.write(&prompt).await {
407 turn_flag.store(false, Ordering::Release);
408 yield Err(e);
409 return;
410 }
411
412 let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
413 tokio::pin!(inner);
414 while let Some(item) = inner.next().await {
415 yield item;
416 }
417 })
418 }
419
420 pub fn send_content(
431 &self,
432 content: Vec<UserContent>,
433 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
434 if content.is_empty() {
435 return Err(Error::Config("content must not be empty".into()));
436 }
437
438 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
439 let transport = Arc::clone(&self.transport);
440
441 if self
443 .turn_active
444 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
445 .is_err()
446 {
447 return Err(Error::ControlProtocol("turn already in progress".into()));
448 }
449 let turn_flag = Arc::clone(&self.turn_active);
450 let read_timeout = self.config.read_timeout;
451 let cancel = self.config.cancellation_token.clone();
452
453 Ok(async_stream::stream! {
454 let user_message = serde_json::json!({
456 "type": "user",
457 "message": {
458 "role": "user",
459 "content": content
460 }
461 });
462 let json = match serde_json::to_string(&user_message) {
463 Ok(j) => j,
464 Err(e) => {
465 turn_flag.store(false, Ordering::Release);
466 yield Err(Error::Json(e));
467 return;
468 }
469 };
470
471 if let Err(e) = transport.write(&json).await {
472 turn_flag.store(false, Ordering::Release);
473 yield Err(e);
474 return;
475 }
476
477 let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
478 tokio::pin!(inner);
479 while let Some(item) = inner.next().await {
480 yield item;
481 }
482 })
483 }
484
485 pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
489 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
490 let read_timeout = self.config.read_timeout;
491 let cancel = self.config.cancellation_token.clone();
492
493 Ok(async_stream::stream! {
494 loop {
495 match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
496 Ok(msg) => yield Ok(msg),
497 Err(e) if matches!(e, Error::Transport(_)) => break, Err(e) => {
499 yield Err(e);
500 break;
501 }
502 }
503 }
504 })
505 }
506
507 pub async fn interrupt(&self) -> Result<()> {
509 self.transport.interrupt().await
510 }
511
512 pub async fn respond_to_permission(
517 &self,
518 request_id: &str,
519 decision: crate::permissions::PermissionDecision,
520 ) -> Result<()> {
521 use crate::permissions::{ControlResponse, ControlResponseResult};
522
523 let response = ControlResponse {
524 kind: "permission_response".into(),
525 request_id: request_id.to_string(),
526 result: ControlResponseResult::from(decision),
527 };
528 let json = serde_json::to_string(&response).map_err(Error::Json)?;
529 self.transport.write(&json).await
530 }
531
532 async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
541 let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
542 let request_id = format!("sdk_req_{counter}");
543
544 let (tx, rx) = oneshot::channel();
545 self.pending_control.insert(request_id.clone(), tx);
546
547 let envelope = serde_json::json!({
548 "type": "control_request",
549 "request_id": request_id,
550 "request": request
551 });
552 let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
553 self.transport.write(&json).await?;
554
555 rx.await
556 .map_err(|_| Error::ControlProtocol("control response channel closed".into()))
557 }
558
559 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
568 self.send_control_request(serde_json::json!({
569 "subtype": "set_model",
570 "model": model
571 }))
572 .await?;
573 Ok(())
574 }
575
576 pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
583 self.send_control_request(serde_json::json!({
584 "subtype": "set_permission_mode",
585 "mode": mode.as_cli_flag()
586 }))
587 .await?;
588 Ok(())
589 }
590
591 pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
596 self.transport.write(data).await
597 }
598
599 pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
604 self.message_rx.take()
605 }
606
607 #[must_use]
609 pub fn read_timeout(&self) -> Option<Duration> {
610 self.config.read_timeout
611 }
612
613 #[must_use]
615 pub fn session_id(&self) -> Option<&str> {
616 self.session_id.as_deref()
617 }
618
619 #[must_use]
621 pub fn is_connected(&self) -> bool {
622 self.transport.is_ready()
623 }
624
625 pub async fn close(&mut self) -> Result<()> {
629 if let Some(tx) = self.shutdown_tx.take() {
631 let _ = tx.send(());
632 }
633 self.message_rx.take();
635 self.transport.close().await?;
636 Ok(())
637 }
638}
639
640impl Drop for Client {
641 fn drop(&mut self) {
642 if self.shutdown_tx.is_some() || self.message_rx.is_some() {
643 tracing::warn!(
644 "claude_cli_sdk::Client dropped without calling close(). \
645 Resources may not be cleaned up properly."
646 );
647 }
648 }
649}
650
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ClientConfig;

    #[cfg(feature = "testing")]
    use crate::testing::{ScenarioBuilder, assistant_text};

    /// Minimal config for tests that need no extra options.
    fn test_config() -> ClientConfig {
        ClientConfig::builder().prompt("test").build()
    }

    /// `connect` returns the scripted session info and records the session
    /// id on the client.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_connect_and_receive_init() {
        let transport = ScenarioBuilder::new("test-session")
            .exchange(vec![assistant_text("Hello!")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        let info = client.connect().await.unwrap();

        assert_eq!(info.session_id, "test-session");
        assert!(client.is_connected());
        assert_eq!(client.session_id(), Some("test-session"));
    }

    /// A turn yields the scripted assistant message followed by the result.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_send_yields_messages() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("response")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], Message::Assistant(_)));
        assert!(matches!(&messages[1], Message::Result(_)));
    }

    /// `close` succeeds on a connected client.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_close_succeeds() {
        let transport = ScenarioBuilder::new("s1").build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        client.connect().await.unwrap();
        assert!(client.close().await.is_ok());
    }

    /// A message callback returning `None` removes that message from the
    /// turn stream (here: every assistant message).
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_message_callback_filters() {
        use crate::callback::MessageCallback;

        let callback: MessageCallback = Arc::new(|msg| match &msg {
            Message::Assistant(_) => None,
            _ => Some(msg),
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .message_callback(callback)
            .build();

        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("filtered")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], Message::Result(_)));
    }

    /// `Debug` formatting works on a client that has never connected.
    #[cfg(feature = "testing")]
    #[test]
    fn client_debug_before_connect() {
        let transport = Arc::new(crate::testing::MockTransport::new());
        let client = Client::with_transport(test_config(), transport).unwrap();
        let debug = format!("{client:?}");
        assert!(debug.contains("Client"));
    }

    /// A transport connect delay (5s) longer than `connect_timeout` (50ms)
    /// yields `Error::Timeout`.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_connect_timeout_fires() {
        use crate::testing::MockTransport;

        let transport = MockTransport::new();
        transport.set_connect_delay(Duration::from_secs(5));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(Some(Duration::from_millis(50)))
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        let result = client.connect().await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
    }

    /// With the mock's receive delay (200ms) exceeding `read_timeout` (50ms),
    /// the turn stream reports `Error::Timeout`.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_read_timeout_fires() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("delayed")])
            .build();
        transport.set_recv_delay(Duration::from_millis(200));
        let transport = Arc::new(transport);

        // A generous connect timeout keeps the handshake from tripping first.
        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(Some(Duration::from_secs(5)))
            .read_timeout(Some(Duration::from_millis(50)))
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut got_timeout = false;
        while let Some(msg) = stream.next().await {
            if let Err(Error::Timeout(_)) = msg {
                got_timeout = true;
                break;
            }
        }
        assert!(got_timeout, "expected a timeout error");
    }

    /// A `permission_request` is passed to `can_use_tool`. It is not surfaced
    /// as a stream message, and exactly one `allow` `permission_response` is
    /// written back for it.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_callback_invoked_and_responds() {
        use crate::permissions::{CanUseToolCallback, PermissionDecision};
        use crate::testing::MockTransport;
        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};

        let invoked = Arc::new(AtomicBool::new(false));
        let invoked_clone = Arc::clone(&invoked);

        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
            let invoked = Arc::clone(&invoked_clone);
            let tool = tool_name.to_owned();
            Box::pin(async move {
                invoked.store(true, AtomicOrdering::Release);
                assert_eq!(tool, "Bash");
                PermissionDecision::allow()
            })
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .can_use_tool(callback)
            .build();

        // Script: init, permission request, assistant message, result.
        let transport = MockTransport::new();
        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport.clone()).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);
        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        assert!(
            invoked.load(AtomicOrdering::Acquire),
            "permission callback was not invoked"
        );

        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], Message::Assistant(_)));
        assert!(matches!(&messages[1], Message::Result(_)));

        // Inspect what the client wrote back to the transport.
        let written = transport.written_lines();
        let perm_responses: Vec<_> = written
            .iter()
            .filter(|line| line.contains("permission_response"))
            .collect();
        assert_eq!(
            perm_responses.len(),
            1,
            "expected exactly one permission_response written"
        );
        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
        assert_eq!(resp["kind"], "permission_response");
        assert_eq!(resp["request_id"], "perm-1");
        assert_eq!(resp["result"]["type"], "allow");
    }

    /// A denying callback results in a `deny` response that carries the
    /// callback's message.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_callback_deny_writes_deny_response() {
        use crate::permissions::{CanUseToolCallback, PermissionDecision};
        use crate::testing::MockTransport;

        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::deny("not allowed") })
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .can_use_tool(callback)
            .build();

        let transport = MockTransport::new();
        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
        transport
            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport.clone()).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);
        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        let written = transport.written_lines();
        let perm_responses: Vec<_> = written
            .iter()
            .filter(|line| line.contains("permission_response"))
            .collect();
        assert_eq!(perm_responses.len(), 1);
        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
        assert_eq!(resp["kind"], "permission_response");
        assert_eq!(resp["request_id"], "perm-2");
        assert_eq!(resp["result"]["type"], "deny");
        assert_eq!(resp["result"]["message"], "not allowed");
    }

    /// Without `can_use_tool`, a `permission_request` surfaces as a
    /// `ControlProtocol` error that mentions the missing callback.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_request_without_callback_yields_error() {
        use crate::testing::MockTransport;

        let config = ClientConfig::builder().prompt("test").build();

        let transport = MockTransport::new();
        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
        transport
            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut got_error = false;
        let mut messages = Vec::new();
        while let Some(result) = stream.next().await {
            match result {
                Ok(msg) => messages.push(msg),
                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
                    got_error = true;
                }
                Err(e) => panic!("unexpected error: {e}"),
            }
        }

        assert!(
            got_error,
            "should have received a ControlProtocol error for missing callback"
        );
    }

    /// An already-cancelled token makes `recv_with_timeout` return a
    /// cancellation error.
    #[tokio::test]
    async fn recv_with_timeout_respects_cancellation_token() {
        let (tx, rx) = flume::unbounded::<Result<Message>>();
        let token = CancellationToken::new();

        token.cancel();

        let result = recv_with_timeout(&rx, None, Some(&token)).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_cancelled());

        // The sender is kept alive until here, so the result above cannot be
        // the channel-closed error.
        drop(tx);
    }

    /// With no token, an empty but open channel still hits the read timeout.
    #[tokio::test]
    async fn recv_with_timeout_none_cancel_still_works() {
        let (_tx, rx) = flume::unbounded::<Result<Message>>();

        let result = recv_with_timeout(&rx, Some(Duration::from_millis(10)), None).await;
        assert!(matches!(result, Err(Error::Timeout(_))));
    }

    /// With `read_timeout(None)`, delayed messages are still delivered and
    /// the turn completes.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_read_timeout_none_waits() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("delayed")])
            .build();
        transport.set_recv_delay(Duration::from_millis(50));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .read_timeout(None)
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        assert_eq!(messages.len(), 2);
    }
}