1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23
24use crate::callback::apply_callback;
25use crate::config::{ClientConfig, PermissionMode};
26use crate::errors::{Error, Result};
27use crate::transport::{CliTransport, Transport};
28use crate::types::content::UserContent;
29use crate::types::messages::{Message, SessionInfo};
30
31pub(crate) async fn recv_with_timeout(
38 rx: &flume::Receiver<Result<Message>>,
39 timeout: Option<Duration>,
40) -> Result<Message> {
41 let recv_fut = rx.recv_async();
42 match timeout {
43 Some(d) => tokio::time::timeout(d, recv_fut)
44 .await
45 .map_err(|_| Error::Timeout(format!("read timed out after {}s", d.as_secs_f64())))?
46 .map_err(|_| Error::Transport("message channel closed".into()))?,
47 None => recv_fut
48 .await
49 .map_err(|_| Error::Transport("message channel closed".into()))?,
50 }
51}
52
53fn read_turn_stream<'a>(
58 rx: &'a flume::Receiver<Result<Message>>,
59 read_timeout: Option<Duration>,
60 turn_flag: Arc<AtomicBool>,
61) -> impl Stream<Item = Result<Message>> + 'a {
62 async_stream::stream! {
63 loop {
64 match recv_with_timeout(rx, read_timeout).await {
65 Ok(msg) => {
66 let is_result = matches!(&msg, Message::Result(_));
67 yield Ok(msg);
68 if is_result {
69 break;
70 }
71 }
72 Err(e) => {
73 yield Err(e);
74 break;
75 }
76 }
77 }
78 turn_flag.store(false, Ordering::Release);
79 }
80}
81
/// Client for a single CLI session.
///
/// Call `connect` before sending. Each `send` / `send_content` runs one turn
/// at a time. Call `close` when finished; dropping a client that still holds
/// its reader-task handles logs a warning.
pub struct Client {
    /// Configuration the client was built with (hooks are moved out of it by `connect`).
    config: ClientConfig,
    /// Transport used for reading CLI output and writing prompts/control messages.
    transport: Arc<dyn Transport>,
    /// Session id taken from the system/init message received during `connect`.
    session_id: Option<String>,
    /// Receiving end of the channel fed by the background reader task.
    message_rx: Option<flume::Receiver<Result<Message>>>,
    /// Shutdown signal for the background reader task; sending on it or
    /// dropping it makes the task's `select!` exit its loop.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// `true` while a turn started by `send` / `send_content` is in flight.
    turn_active: Arc<AtomicBool>,
    /// Outstanding control requests keyed by request id; the reader task
    /// removes the entry and forwards the matching `control_response`.
    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
    /// Counter used to build `sdk_req_{n}` control request ids.
    request_counter: Arc<AtomicU64>,
}
104
105impl std::fmt::Debug for Client {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.debug_struct("Client")
108 .field("session_id", &self.session_id)
109 .field("connected", &self.is_connected())
110 .finish_non_exhaustive()
111 }
112}
113
114impl Client {
115 pub fn new(config: ClientConfig) -> Result<Self> {
125 config.validate()?;
126 let transport = Arc::new(CliTransport::from_config(&config)?);
127 Ok(Self {
128 config,
129 transport,
130 session_id: None,
131 message_rx: None,
132 shutdown_tx: None,
133 turn_active: Arc::new(AtomicBool::new(false)),
134 pending_control: Arc::new(DashMap::new()),
135 request_counter: Arc::new(AtomicU64::new(0)),
136 })
137 }
138
139 pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
141 Ok(Self {
142 config,
143 transport,
144 session_id: None,
145 message_rx: None,
146 shutdown_tx: None,
147 turn_active: Arc::new(AtomicBool::new(false)),
148 pending_control: Arc::new(DashMap::new()),
149 request_counter: Arc::new(AtomicU64::new(0)),
150 })
151 }
152
153 pub async fn connect(&mut self) -> Result<SessionInfo> {
159 let timeout = self.config.connect_timeout;
160 match timeout {
161 Some(d) => tokio::time::timeout(d, self.connect_inner())
162 .await
163 .map_err(|_| {
164 Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
165 })?,
166 None => self.connect_inner().await,
167 }
168 }
169
170 async fn connect_inner(&mut self) -> Result<SessionInfo> {
171 self.transport.connect().await?;
172
173 let (msg_tx, msg_rx) = flume::unbounded();
175 let (shutdown_tx, shutdown_rx) = oneshot::channel();
176
177 let transport = Arc::clone(&self.transport);
178 let message_callback = self.config.message_callback.clone();
179 let pending_control = Arc::clone(&self.pending_control);
180
181 let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
183 let default_hook_timeout = self.config.default_hook_timeout;
184 let hook_transport = Arc::clone(&self.transport);
185
186 let can_use_tool = self.config.can_use_tool.clone();
188 let perm_transport = Arc::clone(&self.transport);
189
190 let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
194 Arc::new(std::sync::Mutex::new(None));
195 let hook_session_id = Arc::clone(&shared_session_id);
196
197 tokio::spawn(async move {
199 let mut stream = transport.read_messages();
200 let mut shutdown = shutdown_rx;
201
202 loop {
203 tokio::select! {
204 biased;
205 _ = &mut shutdown => break,
206 item = stream.next() => {
207 match item {
208 Some(Ok(value)) => {
209 if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
211 if let Some(req_id) = value.get("request_id").and_then(|v| v.as_str()) {
212 if let Some((_, tx)) = pending_control.remove(req_id) {
213 let _ = tx.send(value);
214 }
215 }
216 continue;
217 }
218
219 if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
221 if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
222 let sid = hook_session_id
223 .lock()
224 .expect("session_id lock")
225 .clone();
226 let output = crate::hooks::dispatch_hook(
227 &req,
228 &hooks,
229 default_hook_timeout,
230 sid,
231 ).await;
232 let response = crate::hooks::HookResponse::from_output(
233 req.request_id,
234 output,
235 );
236 if let Ok(json) = serde_json::to_string(&response) {
237 let _ = hook_transport.write(&json).await;
238 }
239 }
240 continue;
241 }
242
243 if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
245 if let Some(ref callback) = can_use_tool {
246 if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
247 let crate::permissions::ControlRequestData::PermissionRequest {
248 ref tool_name,
249 ref tool_input,
250 ref tool_use_id,
251 ref suggestions,
252 } = req.request;
253 let sid = hook_session_id
254 .lock()
255 .expect("session_id lock")
256 .clone()
257 .unwrap_or_default();
258 let ctx = crate::permissions::PermissionContext {
259 tool_use_id: tool_use_id.clone(),
260 session_id: sid,
261 request_id: req.request_id.clone(),
262 suggestions: suggestions.clone(),
263 };
264 let decision = callback(tool_name, tool_input, ctx).await;
265 let response = crate::permissions::ControlResponse {
266 kind: "permission_response".into(),
267 request_id: req.request_id,
268 result: crate::permissions::ControlResponseResult::from(decision),
269 };
270 if let Ok(json) = serde_json::to_string(&response) {
271 let _ = perm_transport.write(&json).await;
272 }
273 }
274 } else {
275 let _ = msg_tx.send(Err(Error::ControlProtocol(
281 "received permission_request but no can_use_tool \
282 callback is configured — set can_use_tool on \
283 ClientConfig or use a PermissionMode that does not \
284 require interactive approval"
285 .into(),
286 )));
287 }
288 continue;
289 }
290
291 let msg: Message = match serde_json::from_value(value) {
293 Ok(m) => m,
294 Err(e) => {
295 let _ = msg_tx.send(Err(Error::Json(e)));
296 continue;
297 }
298 };
299
300 let msg = match apply_callback(msg, message_callback.as_ref()) {
302 Some(m) => m,
303 None => continue, };
305
306 if msg_tx.send(Ok(msg)).is_err() {
307 break; }
309 }
310 Some(Err(e)) => {
311 let _ = msg_tx.send(Err(e));
312 }
313 None => break, }
315 }
316 }
317 }
318 });
319
320 self.message_rx = Some(msg_rx);
321 self.shutdown_tx = Some(shutdown_tx);
322
323 let init_msg = self
327 .message_rx
328 .as_ref()
329 .unwrap()
330 .recv_async()
331 .await
332 .map_err(|_| Error::Transport("connection closed before init message".into()))?
333 .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
334
335 if let Message::System(ref sys) = init_msg {
336 let info = SessionInfo::try_from(sys)?;
337 self.session_id = Some(info.session_id.clone());
338 *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
341 Ok(info)
342 } else {
343 Err(Error::ControlProtocol(format!(
344 "expected system/init as first message, got: {init_msg:?}"
345 )))
346 }
347 }
348
349 pub fn send(
354 &self,
355 prompt: impl Into<String>,
356 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
357 let prompt = prompt.into();
358 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
359 let transport = Arc::clone(&self.transport);
360
361 if self
363 .turn_active
364 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
365 .is_err()
366 {
367 return Err(Error::ControlProtocol("turn already in progress".into()));
368 }
369 let turn_flag = Arc::clone(&self.turn_active);
370 let read_timeout = self.config.read_timeout;
371
372 Ok(async_stream::stream! {
373 if let Err(e) = transport.write(&prompt).await {
375 turn_flag.store(false, Ordering::Release);
376 yield Err(e);
377 return;
378 }
379
380 let inner = read_turn_stream(rx, read_timeout, turn_flag);
381 tokio::pin!(inner);
382 while let Some(item) = inner.next().await {
383 yield item;
384 }
385 })
386 }
387
388 pub fn send_content(
399 &self,
400 content: Vec<UserContent>,
401 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
402 if content.is_empty() {
403 return Err(Error::Config("content must not be empty".into()));
404 }
405
406 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
407 let transport = Arc::clone(&self.transport);
408
409 if self
411 .turn_active
412 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
413 .is_err()
414 {
415 return Err(Error::ControlProtocol("turn already in progress".into()));
416 }
417 let turn_flag = Arc::clone(&self.turn_active);
418 let read_timeout = self.config.read_timeout;
419
420 Ok(async_stream::stream! {
421 let user_message = serde_json::json!({
423 "type": "user",
424 "message": {
425 "role": "user",
426 "content": content
427 }
428 });
429 let json = match serde_json::to_string(&user_message) {
430 Ok(j) => j,
431 Err(e) => {
432 turn_flag.store(false, Ordering::Release);
433 yield Err(Error::Json(e));
434 return;
435 }
436 };
437
438 if let Err(e) = transport.write(&json).await {
439 turn_flag.store(false, Ordering::Release);
440 yield Err(e);
441 return;
442 }
443
444 let inner = read_turn_stream(rx, read_timeout, turn_flag);
445 tokio::pin!(inner);
446 while let Some(item) = inner.next().await {
447 yield item;
448 }
449 })
450 }
451
452 pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
456 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
457 let read_timeout = self.config.read_timeout;
458
459 Ok(async_stream::stream! {
460 loop {
461 match recv_with_timeout(rx, read_timeout).await {
462 Ok(msg) => yield Ok(msg),
463 Err(e) if matches!(e, Error::Transport(_)) => break, Err(e) => {
465 yield Err(e);
466 break;
467 }
468 }
469 }
470 })
471 }
472
473 pub async fn interrupt(&self) -> Result<()> {
475 self.transport.interrupt().await
476 }
477
478 pub async fn respond_to_permission(
483 &self,
484 request_id: &str,
485 decision: crate::permissions::PermissionDecision,
486 ) -> Result<()> {
487 use crate::permissions::{ControlResponse, ControlResponseResult};
488
489 let response = ControlResponse {
490 kind: "permission_response".into(),
491 request_id: request_id.to_string(),
492 result: ControlResponseResult::from(decision),
493 };
494 let json = serde_json::to_string(&response).map_err(Error::Json)?;
495 self.transport.write(&json).await
496 }
497
498 async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
507 let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
508 let request_id = format!("sdk_req_{counter}");
509
510 let (tx, rx) = oneshot::channel();
511 self.pending_control.insert(request_id.clone(), tx);
512
513 let envelope = serde_json::json!({
514 "type": "control_request",
515 "request_id": request_id,
516 "request": request
517 });
518 let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
519 self.transport.write(&json).await?;
520
521 rx.await
522 .map_err(|_| Error::ControlProtocol("control response channel closed".into()))
523 }
524
525 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
534 self.send_control_request(serde_json::json!({
535 "subtype": "set_model",
536 "model": model
537 }))
538 .await?;
539 Ok(())
540 }
541
542 pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
549 self.send_control_request(serde_json::json!({
550 "subtype": "set_permission_mode",
551 "mode": mode.as_cli_flag()
552 }))
553 .await?;
554 Ok(())
555 }
556
557 pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
562 self.transport.write(data).await
563 }
564
565 pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
570 self.message_rx.take()
571 }
572
573 #[must_use]
575 pub fn read_timeout(&self) -> Option<Duration> {
576 self.config.read_timeout
577 }
578
579 #[must_use]
581 pub fn session_id(&self) -> Option<&str> {
582 self.session_id.as_deref()
583 }
584
585 #[must_use]
587 pub fn is_connected(&self) -> bool {
588 self.transport.is_ready()
589 }
590
591 pub async fn close(&mut self) -> Result<()> {
595 if let Some(tx) = self.shutdown_tx.take() {
597 let _ = tx.send(());
598 }
599 self.message_rx.take();
601 self.transport.close().await?;
602 Ok(())
603 }
604}
605
606impl Drop for Client {
607 fn drop(&mut self) {
608 if self.shutdown_tx.is_some() || self.message_rx.is_some() {
609 tracing::warn!(
610 "claude_cli_sdk::Client dropped without calling close(). \
611 Resources may not be cleaned up properly."
612 );
613 }
614 }
615}
616
// Unit tests for `Client`. Most require the `testing` feature, which provides
// scripted transports (`ScenarioBuilder`, `MockTransport`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ClientConfig;

    #[cfg(feature = "testing")]
    use crate::testing::{ScenarioBuilder, assistant_text};

    /// Minimal configuration shared by most tests.
    fn test_config() -> ClientConfig {
        ClientConfig::builder().prompt("test").build()
    }

    // `connect` returns the SessionInfo for the scenario's session id and
    // records it on the client.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_connect_and_receive_init() {
        let transport = ScenarioBuilder::new("test-session")
            .exchange(vec![assistant_text("Hello!")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        let info = client.connect().await.unwrap();

        assert_eq!(info.session_id, "test-session");
        assert!(client.is_connected());
        assert_eq!(client.session_id(), Some("test-session"));
    }

    // A turn yields the exchange's assistant message followed by a Result,
    // and then ends.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_send_yields_messages() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("response")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        // Assistant message, then the Result that terminates the turn.
        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], Message::Assistant(_)));
        assert!(matches!(&messages[1], Message::Result(_)));
    }

    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_close_succeeds() {
        let transport = ScenarioBuilder::new("s1").build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(test_config(), transport).unwrap();
        client.connect().await.unwrap();
        assert!(client.close().await.is_ok());
    }

    // A message callback that returns `None` drops the message before it
    // reaches the turn stream.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_message_callback_filters() {
        use crate::callback::MessageCallback;

        // Filter out every assistant message; pass everything else through.
        let callback: MessageCallback = Arc::new(|msg| match &msg {
            Message::Assistant(_) => None,
            _ => Some(msg),
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .message_callback(callback)
            .build();

        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("filtered")])
            .build();
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        // Only the Result survives the filter.
        assert_eq!(messages.len(), 1);
        assert!(matches!(&messages[0], Message::Result(_)));
    }

    // Debug formatting must work before `connect`.
    #[cfg(feature = "testing")]
    #[test]
    fn client_debug_before_connect() {
        let transport = Arc::new(crate::testing::MockTransport::new());
        let client = Client::with_transport(test_config(), transport).unwrap();
        let debug = format!("{client:?}");
        assert!(debug.contains("Client"));
    }

    // A transport connect slower than `connect_timeout` yields Error::Timeout.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_connect_timeout_fires() {
        use crate::testing::MockTransport;

        let transport = MockTransport::new();
        // 5s connect delay vs. a 50ms connect timeout below.
        transport.set_connect_delay(Duration::from_secs(5));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(Some(Duration::from_millis(50)))
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        let result = client.connect().await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
    }

    // A receive slower than `read_timeout` surfaces Error::Timeout in the turn stream.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_read_timeout_fires() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("delayed")])
            .build();
        // 200ms per-receive delay vs. a 50ms read timeout; connect gets a
        // generous 5s so only the turn read times out.
        transport.set_recv_delay(Duration::from_millis(200));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .connect_timeout(Some(Duration::from_secs(5)))
            .read_timeout(Some(Duration::from_millis(50)))
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut got_timeout = false;
        while let Some(msg) = stream.next().await {
            if let Err(Error::Timeout(_)) = msg {
                got_timeout = true;
                break;
            }
        }
        assert!(got_timeout, "expected a timeout error");
    }

    // A permission_request is answered by the configured `can_use_tool`
    // callback and a matching permission_response is written back; the
    // request itself is not surfaced in the turn stream.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_callback_invoked_and_responds() {
        use crate::permissions::{CanUseToolCallback, PermissionDecision};
        use crate::testing::MockTransport;
        use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};

        let invoked = Arc::new(AtomicBool::new(false));
        let invoked_clone = Arc::clone(&invoked);

        let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
            let invoked = Arc::clone(&invoked_clone);
            let tool = tool_name.to_owned();
            Box::pin(async move {
                invoked.store(true, AtomicOrdering::Release);
                assert_eq!(tool, "Bash");
                PermissionDecision::allow()
            })
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .can_use_tool(callback)
            .build();

        // Script: init, a permission_request for Bash, then the turn's
        // assistant message and result.
        let transport = MockTransport::new();
        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
        transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport.clone()).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);
        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        assert!(
            invoked.load(AtomicOrdering::Acquire),
            "permission callback was not invoked"
        );

        // The permission_request is consumed internally; only the turn's
        // messages reach the caller.
        assert_eq!(messages.len(), 2);
        assert!(matches!(&messages[0], Message::Assistant(_)));
        assert!(matches!(&messages[1], Message::Result(_)));

        // Inspect what the client wrote back to the transport.
        let written = transport.written_lines();
        let perm_responses: Vec<_> = written
            .iter()
            .filter(|line| line.contains("permission_response"))
            .collect();
        assert_eq!(
            perm_responses.len(),
            1,
            "expected exactly one permission_response written"
        );
        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
        assert_eq!(resp["kind"], "permission_response");
        assert_eq!(resp["request_id"], "perm-1");
        assert_eq!(resp["result"]["type"], "allow");
    }

    // A deny decision is written back with its message.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_callback_deny_writes_deny_response() {
        use crate::permissions::{CanUseToolCallback, PermissionDecision};
        use crate::testing::MockTransport;

        let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
            Box::pin(async { PermissionDecision::deny("not allowed") })
        });

        let config = ClientConfig::builder()
            .prompt("test")
            .can_use_tool(callback)
            .build();

        let transport = MockTransport::new();
        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
        transport
            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport.clone()).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);
        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        let written = transport.written_lines();
        let perm_responses: Vec<_> = written
            .iter()
            .filter(|line| line.contains("permission_response"))
            .collect();
        assert_eq!(perm_responses.len(), 1);
        let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
        assert_eq!(resp["kind"], "permission_response");
        assert_eq!(resp["request_id"], "perm-2");
        assert_eq!(resp["result"]["type"], "deny");
        assert_eq!(resp["result"]["message"], "not allowed");
    }

    // With no `can_use_tool` callback configured, a permission_request
    // surfaces as a ControlProtocol error that mentions `can_use_tool`.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_permission_request_without_callback_yields_error() {
        use crate::testing::MockTransport;

        let config = ClientConfig::builder().prompt("test").build();

        let transport = MockTransport::new();
        transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
        transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
        transport
            .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
        transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
        let transport = Arc::new(transport);

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut got_error = false;
        let mut messages = Vec::new();
        while let Some(result) = stream.next().await {
            match result {
                Ok(msg) => messages.push(msg),
                Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
                    got_error = true;
                }
                Err(e) => panic!("unexpected error: {e}"),
            }
        }

        assert!(
            got_error,
            "should have received a ControlProtocol error for missing callback"
        );
    }

    // With `read_timeout(None)`, slow receives are awaited rather than timing out.
    #[cfg(feature = "testing")]
    #[tokio::test]
    async fn client_read_timeout_none_waits() {
        let transport = ScenarioBuilder::new("s1")
            .exchange(vec![assistant_text("delayed")])
            .build();
        transport.set_recv_delay(Duration::from_millis(50));
        let transport = Arc::new(transport);

        let config = ClientConfig::builder()
            .prompt("test")
            .read_timeout(None)
            .build();

        let mut client = Client::with_transport(config, transport).unwrap();
        client.connect().await.unwrap();

        let stream = client.send("hello").unwrap();
        tokio::pin!(stream);

        let mut messages = Vec::new();
        while let Some(msg) = stream.next().await {
            messages.push(msg.unwrap());
        }

        // Both the assistant message and the Result arrive despite the delay.
        assert_eq!(messages.len(), 2);
    }
}
979}