1use std::sync::Arc;
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::time::Duration;
18
19use dashmap::DashMap;
20use futures_core::Stream;
21use tokio::sync::oneshot;
22use tokio_stream::StreamExt;
23use tokio_util::sync::CancellationToken;
24
25use crate::callback::apply_callback;
26use crate::config::{ClientConfig, PermissionMode};
27use crate::errors::{Error, Result};
28use crate::transport::{CliTransport, Transport};
29use crate::types::content::UserContent;
30use crate::types::messages::{Message, SessionInfo};
31
32pub(crate) async fn cancelled_or_pending(token: Option<&CancellationToken>) {
39 match token {
40 Some(t) => t.cancelled().await,
41 None => std::future::pending().await,
42 }
43}
44
45pub(crate) async fn recv_with_timeout(
53 rx: &flume::Receiver<Result<Message>>,
54 timeout: Option<Duration>,
55 cancel: Option<&CancellationToken>,
56) -> Result<Message> {
57 let recv_fut = rx.recv_async();
58 tokio::select! {
59 biased;
60 _ = cancelled_or_pending(cancel) => {
61 Err(Error::Cancelled)
62 }
63 result = async {
64 match timeout {
65 Some(d) => match tokio::time::timeout(d, recv_fut).await {
66 Ok(Ok(msg)) => msg,
67 Ok(Err(_)) => Err(Error::Transport("message channel closed".into())),
68 Err(_) => Err(Error::Timeout(format!("read timed out after {}s", d.as_secs_f64()))),
69 },
70 None => match recv_fut.await {
71 Ok(msg) => msg,
72 Err(_) => Err(Error::Transport("message channel closed".into())),
73 },
74 }
75 } => result,
76 }
77}
78
79fn build_control_response_success(
82 request_id: &str,
83 response: serde_json::Value,
84) -> serde_json::Value {
85 serde_json::json!({
86 "type": "control_response",
87 "response": {
88 "subtype": "success",
89 "request_id": request_id,
90 "response": response
91 }
92 })
93}
94
95fn build_control_response_error(request_id: &str, error: impl Into<String>) -> serde_json::Value {
96 serde_json::json!({
97 "type": "control_response",
98 "response": {
99 "subtype": "error",
100 "request_id": request_id,
101 "error": error.into()
102 }
103 })
104}
105
106async fn write_json_line(transport: &dyn Transport, value: serde_json::Value) {
107 if let Ok(json) = serde_json::to_string(&value) {
108 let _ = transport.write(&json).await;
109 }
110}
111
112fn read_turn_stream<'a>(
117 rx: &'a flume::Receiver<Result<Message>>,
118 read_timeout: Option<Duration>,
119 turn_flag: Arc<AtomicBool>,
120 cancel: Option<CancellationToken>,
121) -> impl Stream<Item = Result<Message>> + 'a {
122 async_stream::stream! {
123 loop {
124 match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
125 Ok(msg) => {
126 let is_result = matches!(&msg, Message::Result(_));
127 yield Ok(msg);
128 if is_result {
129 break;
130 }
131 }
132 Err(e) => {
133 yield Err(e);
134 break;
135 }
136 }
137 }
138 turn_flag.store(false, Ordering::Release);
139 }
140}
141
/// Client for a CLI session.
///
/// It owns three things:
/// - the transport;
/// - the receiving side of the channel that the background reader task
///   (spawned in `connect`) fills;
/// - bookkeeping for in-flight turns and SDK-initiated control requests.
// NOTE: field order is also drop order; keep it stable.
pub struct Client {
    // Configuration passed at construction. `connect` moves `hooks` out of it.
    config: ClientConfig,
    // Transport used for all reads and writes; shared with the reader task.
    transport: Arc<dyn Transport>,
    // Session id from the init `system` message; `None` until `connect` succeeds.
    session_id: Option<String>,
    // Channel fed by the reader task; `Some` after `connect`, cleared by `close`.
    message_rx: Option<flume::Receiver<Result<Message>>>,
    // Signals the reader task to stop; `Some` after `connect`, cleared by `close`.
    shutdown_tx: Option<oneshot::Sender<()>>,
    // Set while a `send`/`send_content` turn is in progress.
    turn_active: Arc<AtomicBool>,
    // Control requests awaiting a `control_response`, keyed by request id.
    pending_control: Arc<DashMap<String, oneshot::Sender<serde_json::Value>>>,
    // Counter used to mint `sdk_req_N` request ids.
    request_counter: Arc<AtomicU64>,
}
164
165impl std::fmt::Debug for Client {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 f.debug_struct("Client")
168 .field("session_id", &self.session_id)
169 .field("connected", &self.is_connected())
170 .finish_non_exhaustive()
171 }
172}
173
174impl Client {
175 pub fn new(config: ClientConfig) -> Result<Self> {
185 config.validate()?;
186 let transport = Arc::new(CliTransport::from_config(&config)?);
187 Ok(Self {
188 config,
189 transport,
190 session_id: None,
191 message_rx: None,
192 shutdown_tx: None,
193 turn_active: Arc::new(AtomicBool::new(false)),
194 pending_control: Arc::new(DashMap::new()),
195 request_counter: Arc::new(AtomicU64::new(0)),
196 })
197 }
198
199 pub fn with_transport(config: ClientConfig, transport: Arc<dyn Transport>) -> Result<Self> {
201 config.validate()?;
202 Ok(Self {
203 config,
204 transport,
205 session_id: None,
206 message_rx: None,
207 shutdown_tx: None,
208 turn_active: Arc::new(AtomicBool::new(false)),
209 pending_control: Arc::new(DashMap::new()),
210 request_counter: Arc::new(AtomicU64::new(0)),
211 })
212 }
213
214 pub async fn connect(&mut self) -> Result<SessionInfo> {
220 let timeout = self.config.connect_timeout;
221 let result = match timeout {
222 Some(d) => tokio::time::timeout(d, self.connect_inner())
223 .await
224 .map_err(|_| {
225 Error::Timeout(format!("connect timed out after {}s", d.as_secs_f64()))
226 })?,
227 None => self.connect_inner().await,
228 };
229 if result.is_err() {
230 if let Some(tx) = self.shutdown_tx.take() {
232 let _ = tx.send(());
233 }
234 self.message_rx.take();
235 let _ = self.transport.close().await;
236 }
237 result
238 }
239
240 async fn connect_inner(&mut self) -> Result<SessionInfo> {
241 self.transport.connect().await?;
242
243 let (msg_tx, msg_rx) = flume::bounded(1024);
245 let (shutdown_tx, shutdown_rx) = oneshot::channel();
246
247 let transport = Arc::clone(&self.transport);
248 let message_callback = self.config.message_callback.clone();
249 let pending_control = Arc::clone(&self.pending_control);
250
251 let hooks: Vec<crate::hooks::HookMatcher> = std::mem::take(&mut self.config.hooks);
253 let default_hook_timeout = self.config.default_hook_timeout;
254 let hook_transport = Arc::clone(&self.transport);
255
256 let can_use_tool = self.config.can_use_tool.clone();
258 let perm_transport = Arc::clone(&self.transport);
259
260 let cancel_token = self.config.cancellation_token.clone();
262
263 let shared_session_id: Arc<std::sync::Mutex<Option<String>>> =
267 Arc::new(std::sync::Mutex::new(None));
268 let hook_session_id = Arc::clone(&shared_session_id);
269
270 tokio::spawn(async move {
272 let mut stream = transport.read_messages();
273 let mut shutdown = shutdown_rx;
274
275 loop {
276 tokio::select! {
277 biased;
278 _ = &mut shutdown => break,
279 _ = cancelled_or_pending(cancel_token.as_ref()) => break,
280 item = stream.next() => {
281 match item {
282 Some(Ok(value)) => {
283 if value.get("type").and_then(|v| v.as_str()) == Some("control_response") {
287 let req_id = value.get("request_id")
288 .and_then(|v| v.as_str())
289 .or_else(|| value.pointer("/response/request_id")
290 .and_then(|v| v.as_str()));
291 if let Some(rid) = req_id {
292 if let Some((_, tx)) = pending_control.remove(rid) {
293 let _ = tx.send(value);
294 }
295 }
296 continue;
297 }
298
299 if value.get("type").and_then(|v| v.as_str()) == Some("hook_request") {
301 if let Ok(req) = serde_json::from_value::<crate::hooks::HookRequest>(value) {
302 let sid = hook_session_id
303 .lock()
304 .expect("session_id lock")
305 .clone();
306 let output = crate::hooks::dispatch_hook(
307 &req,
308 &hooks,
309 default_hook_timeout,
310 sid,
311 ).await;
312 let response = crate::hooks::HookResponse::from_output(
313 req.request_id,
314 output,
315 );
316 if let Ok(json) = serde_json::to_string(&response) {
317 let _ = hook_transport.write(&json).await;
318 }
319 }
320 continue;
321 }
322
323 if value.get("type").and_then(|v| v.as_str()) == Some("control_request") {
330 let request_id = value.get("request_id")
331 .and_then(|v| v.as_str())
332 .unwrap_or("")
333 .to_string();
334 if request_id.is_empty() {
335 let _ = msg_tx.send(Err(Error::ControlProtocol(
336 "received control_request without request_id".into(),
337 )));
338 continue;
339 }
340 let subtype = value.pointer("/request/subtype")
341 .and_then(|v| v.as_str())
342 .unwrap_or("");
343
344 match subtype {
345 "can_use_tool" => {
346 let request = value.get("request").cloned()
347 .unwrap_or_default();
348 let tool_name = request.get("tool_name")
349 .and_then(|v| v.as_str())
350 .unwrap_or("")
351 .to_string();
352 let tool_input = request.get("input")
353 .cloned()
354 .unwrap_or(serde_json::Value::Null);
355 let tool_use_id = request.get("tool_use_id")
356 .and_then(|v| v.as_str())
357 .unwrap_or("")
358 .to_string();
359 let suggestions: Vec<String> = request
360 .get("permission_suggestions")
361 .and_then(|v| v.as_array())
362 .map(|arr| arr.iter()
363 .filter_map(|v| v.as_str().map(String::from))
364 .collect())
365 .unwrap_or_default();
366
367 if let Some(ref callback) = can_use_tool {
368 let sid = hook_session_id
369 .lock()
370 .expect("session_id lock")
371 .clone()
372 .unwrap_or_default();
373 let ctx = crate::permissions::PermissionContext {
374 tool_use_id,
375 session_id: sid,
376 request_id: request_id.clone(),
377 suggestions,
378 };
379 let decision = callback(&tool_name, &tool_input, ctx).await;
380
381 let response_data = match decision {
383 crate::permissions::PermissionDecision::Allow { updated_input } => {
384 let input = updated_input.unwrap_or(tool_input);
385 serde_json::json!({
386 "behavior": "allow",
387 "updatedInput": input
388 })
389 }
390 crate::permissions::PermissionDecision::Deny { message, interrupt } => {
391 let mut d = serde_json::json!({
392 "behavior": "deny",
393 "message": message
394 });
395 if interrupt {
396 d["interrupt"] = serde_json::json!(true);
397 }
398 d
399 }
400 };
401 let response =
402 build_control_response_success(&request_id, response_data);
403 write_json_line(&*perm_transport, response).await;
404 } else {
405 let response = build_control_response_error(
406 &request_id,
407 "no can_use_tool callback configured",
408 );
409 write_json_line(&*perm_transport, response).await;
410 let _ = msg_tx.send(Err(Error::ControlProtocol(
411 "received can_use_tool control_request but no \
412 callback is configured"
413 .into(),
414 )));
415 }
416 }
417 "hook_callback" => {
418 let request = value.get("request").cloned()
419 .unwrap_or_default();
420 let hook_event_str = request.get("hook_event_name")
421 .and_then(|v| v.as_str())
422 .unwrap_or("");
423 let hook_tool_name = request.get("tool_name")
424 .and_then(|v| v.as_str())
425 .map(String::from);
426 let hook_tool_input = request.get("tool_input").cloned();
427 let hook_tool_result = request.get("tool_result").cloned();
428 let hook_tool_use_id = request.get("tool_use_id")
429 .and_then(|v| v.as_str())
430 .map(String::from);
431
432 let hook_event = match hook_event_str {
434 "PreToolUse" => Some(crate::hooks::HookEvent::PreToolUse),
435 "PostToolUse" => Some(crate::hooks::HookEvent::PostToolUse),
436 "PostToolUseFailure" => Some(crate::hooks::HookEvent::PostToolUseFailure),
437 "UserPromptSubmit" => Some(crate::hooks::HookEvent::UserPromptSubmit),
438 "Stop" => Some(crate::hooks::HookEvent::Stop),
439 "SubagentStop" => Some(crate::hooks::HookEvent::SubagentStop),
440 "PreCompact" => Some(crate::hooks::HookEvent::PreCompact),
441 "Notification" => Some(crate::hooks::HookEvent::Notification),
442 _ => None,
443 };
444
445 if let Some(event) = hook_event {
446 let req = crate::hooks::HookRequest {
447 request_id: request_id.clone(),
448 hook_event: event,
449 tool_name: hook_tool_name,
450 tool_input: hook_tool_input,
451 tool_result: hook_tool_result,
452 tool_use_id: hook_tool_use_id,
453 };
454 let sid = hook_session_id
455 .lock()
456 .expect("session_id lock")
457 .clone();
458 let output = crate::hooks::dispatch_hook(
459 &req,
460 &hooks,
461 default_hook_timeout,
462 sid,
463 ).await;
464
465 let response_data = match output.decision {
467 crate::hooks::HookDecision::Allow => {
468 serde_json::json!({"continue_": true})
469 }
470 crate::hooks::HookDecision::Block => {
471 serde_json::json!({
472 "continue_": false,
473 "reason": output.reason.unwrap_or_default()
474 })
475 }
476 crate::hooks::HookDecision::Modify => {
477 let mut d = serde_json::json!({"continue_": true});
478 if let Some(input) = output.updated_input {
479 d["updatedInput"] = input;
480 }
481 d
482 }
483 crate::hooks::HookDecision::Abort => {
484 serde_json::json!({
485 "continue_": false,
486 "reason": output.reason.unwrap_or_default()
487 })
488 }
489 };
490 let response =
491 build_control_response_success(&request_id, response_data);
492 write_json_line(&*hook_transport, response).await;
493 } else {
494 let response = build_control_response_success(
497 &request_id,
498 serde_json::json!({"continue_": true}),
499 );
500 write_json_line(&*hook_transport, response).await;
501 let _ = msg_tx.send(Err(Error::ControlProtocol(
502 format!(
503 "received unknown hook_event_name: {hook_event_str}"
504 ),
505 )));
506 }
507 }
508 _ => {
509 let response = build_control_response_error(
511 &request_id,
512 format!("unknown control_request subtype: {subtype}"),
513 );
514 write_json_line(&*perm_transport, response).await;
515 }
516 }
517 continue;
518 }
519
520 if value.get("type").and_then(|v| v.as_str()) == Some("permission_request") {
522 if let Some(ref callback) = can_use_tool {
523 if let Ok(req) = serde_json::from_value::<crate::permissions::ControlRequest>(value) {
524 let crate::permissions::ControlRequestData::PermissionRequest {
525 ref tool_name,
526 ref tool_input,
527 ref tool_use_id,
528 ref suggestions,
529 } = req.request;
530 let sid = hook_session_id
531 .lock()
532 .expect("session_id lock")
533 .clone()
534 .unwrap_or_default();
535 let ctx = crate::permissions::PermissionContext {
536 tool_use_id: tool_use_id.clone(),
537 session_id: sid,
538 request_id: req.request_id.clone(),
539 suggestions: suggestions.clone(),
540 };
541 let decision = callback(tool_name, tool_input, ctx).await;
542 let response = crate::permissions::ControlResponse {
543 kind: "permission_response".into(),
544 request_id: req.request_id,
545 result: crate::permissions::ControlResponseResult::from(decision),
546 };
547 if let Ok(json) = serde_json::to_string(&response) {
548 let _ = perm_transport.write(&json).await;
549 }
550 }
551 } else {
552 let deny_response = serde_json::json!({
553 "kind": "permission_response",
554 "request_id": value.get("request_id")
555 .and_then(|v| v.as_str())
556 .unwrap_or(""),
557 "result": {
558 "type": "deny",
559 "message": "no permission callback configured"
560 }
561 });
562 if let Ok(json) = serde_json::to_string(&deny_response) {
563 let _ = perm_transport.write(&json).await;
564 }
565 let _ = msg_tx.send(Err(Error::ControlProtocol(
566 "received permission_request but no can_use_tool \
567 callback is configured"
568 .into(),
569 )));
570 }
571 continue;
572 }
573
574 let msg: Message = match serde_json::from_value(value) {
576 Ok(m) => m,
577 Err(e) => {
578 let _ = msg_tx.send(Err(Error::Json(e)));
579 continue;
580 }
581 };
582
583 let msg = match apply_callback(msg, message_callback.as_ref()) {
585 Some(m) => m,
586 None => continue, };
588
589 if msg_tx.send(Ok(msg)).is_err() {
590 break; }
592 }
593 Some(Err(e)) => {
594 let _ = msg_tx.send(Err(e));
595 }
596 None => break, }
598 }
599 }
600 }
601 });
602
603 self.message_rx = Some(msg_rx);
604 self.shutdown_tx = Some(shutdown_tx);
605
606 if let Some(ref msg) = self.config.init_stdin_message {
610 self.transport.write(msg).await?;
611 }
612
613 let init_msg = self
617 .message_rx
618 .as_ref()
619 .unwrap()
620 .recv_async()
621 .await
622 .map_err(|_| Error::Transport("connection closed before init message".into()))?
623 .map_err(|e| Error::Transport(format!("error reading init message: {e}")))?;
624
625 if let Message::System(ref sys) = init_msg {
626 let info = SessionInfo::try_from(sys)?;
627 self.session_id = Some(info.session_id.clone());
628 *shared_session_id.lock().expect("session_id lock") = Some(info.session_id.clone());
631 Ok(info)
632 } else {
633 Err(Error::ControlProtocol(format!(
634 "expected system/init as first message, got: {init_msg:?}"
635 )))
636 }
637 }
638
639 pub fn send(
644 &self,
645 prompt: impl Into<String>,
646 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
647 let prompt = prompt.into();
648 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
649 let transport = Arc::clone(&self.transport);
650
651 if self
653 .turn_active
654 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
655 .is_err()
656 {
657 return Err(Error::ControlProtocol("turn already in progress".into()));
658 }
659 let turn_flag = Arc::clone(&self.turn_active);
660 let read_timeout = self.config.read_timeout;
661 let cancel = self.config.cancellation_token.clone();
662
663 Ok(async_stream::stream! {
664 if let Err(e) = transport.write(&prompt).await {
666 turn_flag.store(false, Ordering::Release);
667 yield Err(e);
668 return;
669 }
670
671 let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
672 tokio::pin!(inner);
673 while let Some(item) = inner.next().await {
674 yield item;
675 }
676 })
677 }
678
679 pub fn send_content(
690 &self,
691 content: Vec<UserContent>,
692 ) -> Result<impl Stream<Item = Result<Message>> + '_> {
693 if content.is_empty() {
694 return Err(Error::Config("content must not be empty".into()));
695 }
696
697 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
698 let transport = Arc::clone(&self.transport);
699
700 if self
702 .turn_active
703 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
704 .is_err()
705 {
706 return Err(Error::ControlProtocol("turn already in progress".into()));
707 }
708 let turn_flag = Arc::clone(&self.turn_active);
709 let read_timeout = self.config.read_timeout;
710 let cancel = self.config.cancellation_token.clone();
711
712 Ok(async_stream::stream! {
713 let user_message = serde_json::json!({
715 "type": "user",
716 "message": {
717 "role": "user",
718 "content": content
719 }
720 });
721 let json = match serde_json::to_string(&user_message) {
722 Ok(j) => j,
723 Err(e) => {
724 turn_flag.store(false, Ordering::Release);
725 yield Err(Error::Json(e));
726 return;
727 }
728 };
729
730 if let Err(e) = transport.write(&json).await {
731 turn_flag.store(false, Ordering::Release);
732 yield Err(e);
733 return;
734 }
735
736 let inner = read_turn_stream(rx, read_timeout, turn_flag, cancel);
737 tokio::pin!(inner);
738 while let Some(item) = inner.next().await {
739 yield item;
740 }
741 })
742 }
743
744 pub fn receive_messages(&self) -> Result<impl Stream<Item = Result<Message>> + '_> {
748 let rx = self.message_rx.as_ref().ok_or(Error::NotConnected)?;
749 let read_timeout = self.config.read_timeout;
750 let cancel = self.config.cancellation_token.clone();
751
752 Ok(async_stream::stream! {
753 loop {
754 match recv_with_timeout(rx, read_timeout, cancel.as_ref()).await {
755 Ok(msg) => yield Ok(msg),
756 Err(e) if matches!(e, Error::Transport(_)) => break, Err(e) => {
758 yield Err(e);
759 break;
760 }
761 }
762 }
763 })
764 }
765
766 pub async fn write_to_stdin(&self, text: &str) -> Result<()> {
773 debug_assert!(
774 !self.turn_active.load(Ordering::Relaxed),
775 "write_to_stdin called while a send() turn is active"
776 );
777 self.transport.write(text).await
778 }
779
780 pub async fn interrupt(&self) -> Result<()> {
782 self.transport.interrupt().await
783 }
784
785 pub async fn respond_to_permission(
790 &self,
791 request_id: &str,
792 decision: crate::permissions::PermissionDecision,
793 ) -> Result<()> {
794 use crate::permissions::{ControlResponse, ControlResponseResult};
795
796 let response = ControlResponse {
797 kind: "permission_response".into(),
798 request_id: request_id.to_string(),
799 result: ControlResponseResult::from(decision),
800 };
801 let json = serde_json::to_string(&response).map_err(Error::Json)?;
802 self.transport.write(&json).await
803 }
804
805 async fn send_control_request(&self, request: serde_json::Value) -> Result<serde_json::Value> {
814 let counter = self.request_counter.fetch_add(1, Ordering::Relaxed);
815 let request_id = format!("sdk_req_{counter}");
816
817 let (tx, rx) = oneshot::channel();
818 self.pending_control.insert(request_id.clone(), tx);
819
820 let envelope = serde_json::json!({
821 "type": "control_request",
822 "request_id": request_id,
823 "request": request
824 });
825 let json = serde_json::to_string(&envelope).map_err(Error::Json)?;
826 self.transport.write(&json).await?;
827
828 let timeout = self.config.control_request_timeout;
829 match tokio::time::timeout(timeout, rx).await {
830 Ok(Ok(value)) => Ok(value),
831 Ok(Err(_)) => {
832 self.pending_control.remove(&request_id);
833 Err(Error::ControlProtocol(
834 "control response channel closed".into(),
835 ))
836 }
837 Err(_) => {
838 self.pending_control.remove(&request_id);
839 Err(Error::Timeout(format!(
840 "control request timed out after {}s",
841 timeout.as_secs_f64()
842 )))
843 }
844 }
845 }
846
847 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
856 self.send_control_request(serde_json::json!({
857 "subtype": "set_model",
858 "model": model
859 }))
860 .await?;
861 Ok(())
862 }
863
864 pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
871 self.send_control_request(serde_json::json!({
872 "subtype": "set_permission_mode",
873 "mode": mode.as_cli_flag()
874 }))
875 .await?;
876 Ok(())
877 }
878
879 pub(crate) async fn transport_write(&self, data: &str) -> Result<()> {
884 self.transport.write(data).await
885 }
886
887 pub(crate) fn take_message_rx(&mut self) -> Option<flume::Receiver<Result<Message>>> {
892 self.message_rx.take()
893 }
894
895 #[must_use]
897 pub fn read_timeout(&self) -> Option<Duration> {
898 self.config.read_timeout
899 }
900
901 #[must_use]
903 pub fn session_id(&self) -> Option<&str> {
904 self.session_id.as_deref()
905 }
906
907 #[must_use]
909 pub fn is_connected(&self) -> bool {
910 self.transport.is_ready()
911 }
912
913 pub async fn close(&mut self) -> Result<Option<i32>> {
918 if let Some(tx) = self.shutdown_tx.take() {
920 let _ = tx.send(());
921 }
922 self.message_rx.take();
924 self.transport.close().await
925 }
926}
927
928impl Drop for Client {
929 fn drop(&mut self) {
930 if self.shutdown_tx.is_some() || self.message_rx.is_some() {
931 tracing::warn!(
932 "claude_cli_sdk::Client dropped without calling close(). \
933 Resources may not be cleaned up properly."
934 );
935 }
936 }
937}
938
939#[cfg(test)]
942mod tests {
943 use super::*;
944 use crate::config::ClientConfig;
945
946 #[cfg(feature = "testing")]
947 use crate::testing::{ScenarioBuilder, assistant_text};
948
949 fn test_config() -> ClientConfig {
950 ClientConfig::builder().prompt("test").build()
951 }
952
953 #[cfg(feature = "testing")]
954 #[tokio::test]
955 async fn client_connect_and_receive_init() {
956 let transport = ScenarioBuilder::new("test-session")
957 .exchange(vec![assistant_text("Hello!")])
958 .build();
959 let transport = Arc::new(transport);
960
961 let mut client = Client::with_transport(test_config(), transport).unwrap();
962 let info = client.connect().await.unwrap();
963
964 assert_eq!(info.session_id, "test-session");
965 assert!(client.is_connected());
966 assert_eq!(client.session_id(), Some("test-session"));
967 }
968
969 #[cfg(feature = "testing")]
970 #[tokio::test]
971 async fn client_send_yields_messages() {
972 let transport = ScenarioBuilder::new("s1")
973 .exchange(vec![assistant_text("response")])
974 .build();
975 let transport = Arc::new(transport);
976
977 let mut client = Client::with_transport(test_config(), transport).unwrap();
978 client.connect().await.unwrap();
979
980 let stream = client.send("hello").unwrap();
981 tokio::pin!(stream);
982
983 let mut messages = Vec::new();
984 while let Some(msg) = stream.next().await {
985 messages.push(msg.unwrap());
986 }
987
988 assert_eq!(messages.len(), 2);
990 assert!(matches!(&messages[0], Message::Assistant(_)));
991 assert!(matches!(&messages[1], Message::Result(_)));
992 }
993
994 #[cfg(feature = "testing")]
995 #[tokio::test]
996 async fn client_close_succeeds() {
997 let transport = ScenarioBuilder::new("s1").build();
998 let transport = Arc::new(transport);
999
1000 let mut client = Client::with_transport(test_config(), transport).unwrap();
1001 client.connect().await.unwrap();
1002 assert!(client.close().await.is_ok());
1003 }
1004
1005 #[cfg(feature = "testing")]
1006 #[tokio::test]
1007 async fn client_message_callback_filters() {
1008 use crate::callback::MessageCallback;
1009
1010 let callback: MessageCallback = Arc::new(|msg| match &msg {
1012 Message::Assistant(_) => None,
1013 _ => Some(msg),
1014 });
1015
1016 let config = ClientConfig::builder()
1017 .prompt("test")
1018 .message_callback(callback)
1019 .build();
1020
1021 let transport = ScenarioBuilder::new("s1")
1022 .exchange(vec![assistant_text("filtered")])
1023 .build();
1024 let transport = Arc::new(transport);
1025
1026 let mut client = Client::with_transport(config, transport).unwrap();
1027 client.connect().await.unwrap();
1028
1029 let stream = client.send("hello").unwrap();
1030 tokio::pin!(stream);
1031
1032 let mut messages = Vec::new();
1033 while let Some(msg) = stream.next().await {
1034 messages.push(msg.unwrap());
1035 }
1036
1037 assert_eq!(messages.len(), 1);
1039 assert!(matches!(&messages[0], Message::Result(_)));
1040 }
1041
1042 #[cfg(feature = "testing")]
1043 #[test]
1044 fn client_debug_before_connect() {
1045 let transport = Arc::new(crate::testing::MockTransport::new());
1046 let client = Client::with_transport(test_config(), transport).unwrap();
1047 let debug = format!("{client:?}");
1048 assert!(debug.contains("Client"));
1049 }
1050
1051 #[cfg(feature = "testing")]
1054 #[tokio::test]
1055 async fn client_connect_timeout_fires() {
1056 use crate::testing::MockTransport;
1057
1058 let transport = MockTransport::new();
1059 transport.set_connect_delay(Duration::from_secs(5));
1061 let transport = Arc::new(transport);
1062
1063 let config = ClientConfig::builder()
1064 .prompt("test")
1065 .connect_timeout(Some(Duration::from_millis(50)))
1066 .build();
1067
1068 let mut client = Client::with_transport(config, transport).unwrap();
1069 let result = client.connect().await;
1070 assert!(result.is_err());
1071 assert!(matches!(result.unwrap_err(), Error::Timeout(_)));
1072 }
1073
1074 #[cfg(feature = "testing")]
1075 #[tokio::test]
1076 async fn client_read_timeout_fires() {
1077 let transport = ScenarioBuilder::new("s1")
1080 .exchange(vec![assistant_text("delayed")])
1081 .build();
1082 transport.set_recv_delay(Duration::from_millis(200));
1090 let transport = Arc::new(transport);
1091
1092 let config = ClientConfig::builder()
1093 .prompt("test")
1094 .connect_timeout(Some(Duration::from_secs(5)))
1095 .read_timeout(Some(Duration::from_millis(50)))
1096 .build();
1097
1098 let mut client = Client::with_transport(config, transport).unwrap();
1099 client.connect().await.unwrap();
1100
1101 let stream = client.send("hello").unwrap();
1102 tokio::pin!(stream);
1103
1104 let mut got_timeout = false;
1105 while let Some(msg) = stream.next().await {
1106 if let Err(Error::Timeout(_)) = msg {
1107 got_timeout = true;
1108 break;
1109 }
1110 }
1111 assert!(got_timeout, "expected a timeout error");
1112 }
1113
1114 #[cfg(feature = "testing")]
1115 #[tokio::test]
1116 async fn client_permission_callback_invoked_and_responds() {
1117 use crate::permissions::{CanUseToolCallback, PermissionDecision};
1118 use crate::testing::MockTransport;
1119 use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
1120
1121 let invoked = Arc::new(AtomicBool::new(false));
1122 let invoked_clone = Arc::clone(&invoked);
1123
1124 let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
1125 let invoked = Arc::clone(&invoked_clone);
1126 let tool = tool_name.to_owned();
1127 Box::pin(async move {
1128 invoked.store(true, AtomicOrdering::Release);
1129 assert_eq!(tool, "Bash");
1130 PermissionDecision::allow()
1131 })
1132 });
1133
1134 let config = ClientConfig::builder()
1135 .prompt("test")
1136 .can_use_tool(callback)
1137 .build();
1138
1139 let transport = MockTransport::new();
1140 transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1142 transport.enqueue(r#"{"type":"permission_request","request_id":"perm-1","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-1","suggestions":["allow_once"]}}"#);
1143 transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
1144 transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1145 let transport = Arc::new(transport);
1146
1147 let mut client = Client::with_transport(config, transport.clone()).unwrap();
1148 client.connect().await.unwrap();
1149
1150 let stream = client.send("hello").unwrap();
1151 tokio::pin!(stream);
1152 let mut messages = Vec::new();
1153 while let Some(msg) = stream.next().await {
1154 messages.push(msg.unwrap());
1155 }
1156
1157 assert!(
1159 invoked.load(AtomicOrdering::Acquire),
1160 "permission callback was not invoked"
1161 );
1162
1163 assert_eq!(messages.len(), 2);
1165 assert!(matches!(&messages[0], Message::Assistant(_)));
1166 assert!(matches!(&messages[1], Message::Result(_)));
1167
1168 let written = transport.written_lines();
1170 let perm_responses: Vec<_> = written
1171 .iter()
1172 .filter(|line| line.contains("permission_response"))
1173 .collect();
1174 assert_eq!(
1175 perm_responses.len(),
1176 1,
1177 "expected exactly one permission_response written"
1178 );
1179 let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
1180 assert_eq!(resp["kind"], "permission_response");
1181 assert_eq!(resp["request_id"], "perm-1");
1182 assert_eq!(resp["result"]["type"], "allow");
1183 }
1184
1185 #[cfg(feature = "testing")]
1186 #[tokio::test]
1187 async fn client_permission_callback_deny_writes_deny_response() {
1188 use crate::permissions::{CanUseToolCallback, PermissionDecision};
1189 use crate::testing::MockTransport;
1190
1191 let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
1192 Box::pin(async { PermissionDecision::deny("not allowed") })
1193 });
1194
1195 let config = ClientConfig::builder()
1196 .prompt("test")
1197 .can_use_tool(callback)
1198 .build();
1199
1200 let transport = MockTransport::new();
1201 transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1202 transport.enqueue(r#"{"type":"permission_request","request_id":"perm-2","request":{"type":"permission_request","tool_name":"Write","tool_input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","suggestions":[]}}"#);
1203 transport
1204 .enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
1205 transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1206 let transport = Arc::new(transport);
1207
1208 let mut client = Client::with_transport(config, transport.clone()).unwrap();
1209 client.connect().await.unwrap();
1210
1211 let stream = client.send("hello").unwrap();
1212 tokio::pin!(stream);
1213 let mut messages = Vec::new();
1214 while let Some(msg) = stream.next().await {
1215 messages.push(msg.unwrap());
1216 }
1217
1218 let written = transport.written_lines();
1220 let perm_responses: Vec<_> = written
1221 .iter()
1222 .filter(|line| line.contains("permission_response"))
1223 .collect();
1224 assert_eq!(perm_responses.len(), 1);
1225 let resp: serde_json::Value = serde_json::from_str(perm_responses[0]).unwrap();
1226 assert_eq!(resp["kind"], "permission_response");
1227 assert_eq!(resp["request_id"], "perm-2");
1228 assert_eq!(resp["result"]["type"], "deny");
1229 assert_eq!(resp["result"]["message"], "not allowed");
1230 }
1231
1232 #[cfg(feature = "testing")]
1233 #[tokio::test]
1234 async fn client_permission_request_without_callback_yields_error() {
1235 use crate::testing::MockTransport;
1236
1237 let config = ClientConfig::builder().prompt("test").build();
1239
1240 let transport = MockTransport::new();
1241 transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1242 transport.enqueue(r#"{"type":"permission_request","request_id":"perm-3","request":{"type":"permission_request","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-3","suggestions":[]}}"#);
1243 transport
1244 .enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
1245 transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1246 let transport = Arc::new(transport);
1247
1248 let mut client = Client::with_transport(config, transport).unwrap();
1249 client.connect().await.unwrap();
1250
1251 let stream = client.send("hello").unwrap();
1252 tokio::pin!(stream);
1253
1254 let mut got_error = false;
1255 let mut messages = Vec::new();
1256 while let Some(result) = stream.next().await {
1257 match result {
1258 Ok(msg) => messages.push(msg),
1259 Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
1260 got_error = true;
1261 }
1262 Err(e) => panic!("unexpected error: {e}"),
1263 }
1264 }
1265
1266 assert!(
1267 got_error,
1268 "should have received a ControlProtocol error for missing callback"
1269 );
1270 }
1271
1272 #[tokio::test]
1273 async fn recv_with_timeout_respects_cancellation_token() {
1274 let (tx, rx) = flume::unbounded::<Result<Message>>();
1275 let token = CancellationToken::new();
1276
1277 token.cancel();
1279
1280 let result = recv_with_timeout(&rx, None, Some(&token)).await;
1281 assert!(result.is_err());
1282 assert!(result.unwrap_err().is_cancelled());
1283
1284 drop(tx);
1286 }
1287
1288 #[tokio::test]
1289 async fn recv_with_timeout_none_cancel_still_works() {
1290 let (_tx, rx) = flume::unbounded::<Result<Message>>();
1291
1292 let result = recv_with_timeout(&rx, Some(Duration::from_millis(10)), None).await;
1294 assert!(matches!(result, Err(Error::Timeout(_))));
1295 }
1296
1297 #[cfg(feature = "testing")]
1300 #[tokio::test]
1301 async fn client_control_request_can_use_tool_allow() {
1302 use crate::permissions::{CanUseToolCallback, PermissionDecision};
1303 use crate::testing::MockTransport;
1304 use std::sync::atomic::{AtomicBool, Ordering as AtomicOrdering};
1305
1306 let invoked = Arc::new(AtomicBool::new(false));
1307 let invoked_clone = Arc::clone(&invoked);
1308
1309 let callback: CanUseToolCallback = Arc::new(move |tool_name: &str, _input, _ctx| {
1310 let invoked = Arc::clone(&invoked_clone);
1311 let tool = tool_name.to_owned();
1312 Box::pin(async move {
1313 invoked.store(true, AtomicOrdering::Release);
1314 assert_eq!(tool, "Bash");
1315 PermissionDecision::allow()
1316 })
1317 });
1318
1319 let config = ClientConfig::builder()
1320 .prompt("test")
1321 .can_use_tool(callback)
1322 .build();
1323
1324 let transport = MockTransport::new();
1325 transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1326 transport.enqueue(r#"{"type":"control_request","request_id":"cr-1","request":{"subtype":"can_use_tool","tool_name":"Bash","input":{"command":"ls"},"tool_use_id":"tu-1","permission_suggestions":["allow_once"]}}"#);
1328 transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("done")).unwrap());
1329 transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1330 let transport = Arc::new(transport);
1331
1332 let mut client = Client::with_transport(config, transport.clone()).unwrap();
1333 client.connect().await.unwrap();
1334
1335 let stream = client.send("hello").unwrap();
1336 tokio::pin!(stream);
1337 let mut messages = Vec::new();
1338 while let Some(msg) = stream.next().await {
1339 messages.push(msg.unwrap());
1340 }
1341
1342 assert!(
1343 invoked.load(AtomicOrdering::Acquire),
1344 "permission callback was not invoked"
1345 );
1346
1347 assert_eq!(messages.len(), 2);
1349 assert!(matches!(&messages[0], Message::Assistant(_)));
1350 assert!(matches!(&messages[1], Message::Result(_)));
1351
1352 let written = transport.written_lines();
1354 let responses: Vec<_> = written
1355 .iter()
1356 .filter(|line| line.contains("control_response"))
1357 .collect();
1358 assert_eq!(responses.len(), 1, "expected exactly one control_response");
1359 let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1360 assert_eq!(resp["type"], "control_response");
1361 assert_eq!(resp["response"]["subtype"], "success");
1362 assert_eq!(resp["response"]["request_id"], "cr-1");
1363 assert_eq!(resp["response"]["response"]["behavior"], "allow");
1364 }
1365
1366 #[cfg(feature = "testing")]
1367 #[tokio::test]
1368 async fn client_control_request_can_use_tool_deny() {
1369 use crate::permissions::{CanUseToolCallback, PermissionDecision};
1370 use crate::testing::MockTransport;
1371
1372 let callback: CanUseToolCallback = Arc::new(|_tool_name, _input, _ctx| {
1373 Box::pin(async { PermissionDecision::deny("forbidden") })
1374 });
1375
1376 let config = ClientConfig::builder()
1377 .prompt("test")
1378 .can_use_tool(callback)
1379 .build();
1380
1381 let transport = MockTransport::new();
1382 transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1383 transport.enqueue(r#"{"type":"control_request","request_id":"cr-2","request":{"subtype":"can_use_tool","tool_name":"Write","input":{"path":"/etc/shadow"},"tool_use_id":"tu-2","permission_suggestions":[]}}"#);
1384 transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("denied")).unwrap());
1385 transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1386 let transport = Arc::new(transport);
1387
1388 let mut client = Client::with_transport(config, transport.clone()).unwrap();
1389 client.connect().await.unwrap();
1390
1391 let stream = client.send("hello").unwrap();
1392 tokio::pin!(stream);
1393 let mut messages = Vec::new();
1394 while let Some(msg) = stream.next().await {
1395 messages.push(msg.unwrap());
1396 }
1397
1398 let written = transport.written_lines();
1399 let responses: Vec<_> = written
1400 .iter()
1401 .filter(|line| line.contains("control_response"))
1402 .collect();
1403 assert_eq!(responses.len(), 1);
1404 let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1405 assert_eq!(resp["type"], "control_response");
1406 assert_eq!(resp["response"]["response"]["behavior"], "deny");
1407 assert_eq!(resp["response"]["response"]["message"], "forbidden");
1408 }
1409
1410 #[cfg(feature = "testing")]
1411 #[tokio::test]
1412 async fn client_control_request_no_callback_yields_error() {
1413 use crate::testing::MockTransport;
1414
1415 let config = ClientConfig::builder().prompt("test").build();
1417
1418 let transport = MockTransport::new();
1419 transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1420 transport.enqueue(r#"{"type":"control_request","request_id":"cr-3","request":{"subtype":"can_use_tool","tool_name":"Bash","input":{"command":"ls"},"tool_use_id":"tu-3","permission_suggestions":[]}}"#);
1421 transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("after")).unwrap());
1422 transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1423 let transport = Arc::new(transport);
1424
1425 let mut client = Client::with_transport(config, transport.clone()).unwrap();
1426 client.connect().await.unwrap();
1427
1428 let stream = client.send("hello").unwrap();
1429 tokio::pin!(stream);
1430
1431 let mut got_error = false;
1432 let mut messages = Vec::new();
1433 while let Some(result) = stream.next().await {
1434 match result {
1435 Ok(msg) => messages.push(msg),
1436 Err(Error::ControlProtocol(ref msg)) if msg.contains("can_use_tool") => {
1437 got_error = true;
1438 }
1439 Err(e) => panic!("unexpected error: {e}"),
1440 }
1441 }
1442
1443 assert!(
1444 got_error,
1445 "should have received a ControlProtocol error for missing callback"
1446 );
1447
1448 let written = transport.written_lines();
1450 let responses: Vec<_> = written
1451 .iter()
1452 .filter(|line| line.contains("control_response"))
1453 .collect();
1454 assert_eq!(responses.len(), 1);
1455 let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1456 assert_eq!(resp["response"]["subtype"], "error");
1457 }
1458
1459 #[cfg(feature = "testing")]
1460 #[tokio::test]
1461 async fn client_control_request_hook_callback() {
1462 use crate::hooks::{HookCallback, HookEvent, HookMatcher, HookOutput};
1463 use crate::testing::MockTransport;
1464
1465 let callback: HookCallback =
1467 Arc::new(|_, _, _| Box::pin(async { HookOutput::allow() }));
1468
1469 let config = ClientConfig::builder()
1470 .prompt("test")
1471 .hooks(vec![HookMatcher::new(HookEvent::PreToolUse, callback)])
1472 .build();
1473
1474 let transport = MockTransport::new();
1475 transport.enqueue(r#"{"type":"system","subtype":"init","session_id":"s1","cwd":"/","tools":[],"mcp_servers":[],"model":"m"}"#);
1476 transport.enqueue(r#"{"type":"control_request","request_id":"hook-1","request":{"subtype":"hook_callback","hook_event_name":"PreToolUse","tool_name":"Bash","tool_input":{"command":"ls"},"tool_use_id":"tu-h1"}}"#);
1477 transport.enqueue(&serde_json::to_string(&crate::testing::assistant_text("hooked")).unwrap());
1478 transport.enqueue(r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"num_turns":1,"usage":{}}"#);
1479 let transport = Arc::new(transport);
1480
1481 let mut client = Client::with_transport(config, transport.clone()).unwrap();
1482 client.connect().await.unwrap();
1483
1484 let stream = client.send("hello").unwrap();
1485 tokio::pin!(stream);
1486 let mut messages = Vec::new();
1487 while let Some(msg) = stream.next().await {
1488 messages.push(msg.unwrap());
1489 }
1490
1491 assert_eq!(messages.len(), 2);
1493 assert!(matches!(&messages[0], Message::Assistant(_)));
1494 assert!(matches!(&messages[1], Message::Result(_)));
1495
1496 let written = transport.written_lines();
1498 let responses: Vec<_> = written
1499 .iter()
1500 .filter(|line| line.contains("control_response"))
1501 .collect();
1502 assert_eq!(responses.len(), 1);
1503 let resp: serde_json::Value = serde_json::from_str(responses[0]).unwrap();
1504 assert_eq!(resp["type"], "control_response");
1505 assert_eq!(resp["response"]["subtype"], "success");
1506 assert_eq!(resp["response"]["request_id"], "hook-1");
1507 assert_eq!(resp["response"]["response"]["continue_"], true);
1509 }
1510
1511 #[cfg(feature = "testing")]
1512 #[tokio::test]
1513 async fn client_read_timeout_none_waits() {
1514 let transport = ScenarioBuilder::new("s1")
1516 .exchange(vec![assistant_text("delayed")])
1517 .build();
1518 transport.set_recv_delay(Duration::from_millis(50));
1519 let transport = Arc::new(transport);
1520
1521 let config = ClientConfig::builder()
1522 .prompt("test")
1523 .read_timeout(None)
1524 .build();
1525
1526 let mut client = Client::with_transport(config, transport).unwrap();
1527 client.connect().await.unwrap();
1528
1529 let stream = client.send("hello").unwrap();
1530 tokio::pin!(stream);
1531
1532 let mut messages = Vec::new();
1533 while let Some(msg) = stream.next().await {
1534 messages.push(msg.unwrap());
1535 }
1536
1537 assert_eq!(messages.len(), 2);
1539 }
1540}