claude_code_agent_sdk/client.rs
1//! ClaudeClient for bidirectional streaming interactions with hook support
2
use std::pin::Pin;
use std::sync::Arc;

use futures::stream::Stream;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tracing::{debug, info, instrument, warn};

use crate::errors::{ClaudeError, Result};
use crate::internal::message_parser::MessageParser;
use crate::internal::query_full::QueryFull;
use crate::internal::transport::subprocess::QueryPrompt;
use crate::internal::transport::{SubprocessTransport, Transport};
use crate::types::config::{ClaudeAgentOptions, PermissionMode};
use crate::types::efficiency::{build_efficiency_hooks, merge_hooks};
use crate::types::hooks::HookEvent;
use crate::types::messages::{Message, UserContentBlock};
20
21/// Client for bidirectional streaming interactions with Claude
22///
23/// This client provides the same functionality as Python's ClaudeSDKClient,
24/// supporting bidirectional communication, streaming responses, and dynamic
25/// control over the Claude session.
26///
27/// # Example
28///
29/// ```no_run
30/// use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
31/// use futures::StreamExt;
32///
33/// #[tokio::main]
34/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
35/// let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
36///
37/// // Connect to Claude
38/// client.connect().await?;
39///
40/// // Send a query
41/// client.query("Hello Claude!").await?;
42///
43/// // Receive response as a stream
44/// {
45/// let mut stream = client.receive_response();
46/// while let Some(message) = stream.next().await {
47/// println!("Received: {:?}", message?);
48/// }
49/// }
50///
51/// // Disconnect
52/// client.disconnect().await?;
53/// Ok(())
54/// }
55/// ```
56pub struct ClaudeClient {
57 options: ClaudeAgentOptions,
58 query: Option<Arc<Mutex<QueryFull>>>,
59 connected: bool,
60}
61
62impl ClaudeClient {
63 /// Create a new ClaudeClient
64 ///
65 /// # Arguments
66 ///
67 /// * `options` - Configuration options for the Claude client
68 ///
69 /// # Example
70 ///
71 /// ```no_run
72 /// use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
73 ///
74 /// let client = ClaudeClient::new(ClaudeAgentOptions::default());
75 /// ```
76 pub fn new(options: ClaudeAgentOptions) -> Self {
77 Self {
78 options,
79 query: None,
80 connected: false,
81 }
82 }
83
84 /// Create a new ClaudeClient with early validation
85 ///
86 /// Unlike `new()`, this validates the configuration eagerly by attempting
87 /// to create the transport. This catches issues like invalid working directory
88 /// or missing CLI before `connect()` is called.
89 ///
90 /// # Arguments
91 ///
92 /// * `options` - Configuration options for the Claude client
93 ///
94 /// # Errors
95 ///
96 /// Returns an error if:
97 /// - The working directory does not exist or is not a directory
98 /// - Claude CLI cannot be found
99 ///
100 /// # Example
101 ///
102 /// ```no_run
103 /// use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
104 ///
105 /// let client = ClaudeClient::try_new(ClaudeAgentOptions::default())?;
106 /// # Ok::<(), claude_agent_sdk_rs::ClaudeError>(())
107 /// ```
108 pub fn try_new(options: ClaudeAgentOptions) -> Result<Self> {
109 // Validate by attempting to create transport (but don't keep it)
110 let prompt = QueryPrompt::Streaming;
111 let _ = SubprocessTransport::new(prompt, options.clone())?;
112
113 Ok(Self {
114 options,
115 query: None,
116 connected: false,
117 })
118 }
119
120 /// Connect to Claude (analogous to Python's __aenter__)
121 ///
122 /// This establishes the connection to the Claude Code CLI and initializes
123 /// the bidirectional communication channel.
124 ///
125 /// # Errors
126 ///
127 /// Returns an error if:
128 /// - Claude CLI cannot be found or started
129 /// - The initialization handshake fails
130 /// - Hook registration fails
131 /// - `can_use_tool` callback is set with incompatible `permission_prompt_tool_name`
132 #[instrument(
133 name = "claude.client.connect",
134 skip(self),
135 fields(
136 has_can_use_tool = self.options.can_use_tool.is_some(),
137 has_hooks = self.options.hooks.is_some(),
138 model = %self.options.model.as_deref().unwrap_or("default"),
139 )
140 )]
141 pub async fn connect(&mut self) -> Result<()> {
142 if self.connected {
143 debug!("Client already connected, skipping");
144 return Ok(());
145 }
146
147 info!("Connecting to Claude Code CLI");
148
149 // Automatically set permission_prompt_tool_name to "stdio" when can_use_tool is provided
150 // This matches Python SDK behavior (client.py lines 106-122)
151 // which ensures CLI uses control protocol for permission prompts
152 let mut options = self.options.clone();
153 if options.can_use_tool.is_some() && options.permission_prompt_tool_name.is_none() {
154 info!("can_use_tool callback is set, automatically setting permission_prompt_tool_name to 'stdio'");
155 options.permission_prompt_tool_name = Some("stdio".to_string());
156 }
157
158 // Validate can_use_tool configuration (aligned with Python SDK behavior)
159 // When can_use_tool callback is set, permission_prompt_tool_name must be "stdio"
160 // to ensure the control protocol can handle permission requests
161 if options.can_use_tool.is_some()
162 && let Some(ref tool_name) = options.permission_prompt_tool_name
163 && tool_name != "stdio"
164 {
165 return Err(ClaudeError::InvalidConfig(
166 "can_use_tool callback requires permission_prompt_tool_name to be 'stdio' or unset. \
167 Custom permission_prompt_tool_name is incompatible with can_use_tool callback."
168 .to_string(),
169 ));
170 }
171
172 // Create transport in streaming mode (no initial prompt)
173 let prompt = QueryPrompt::Streaming;
174 let mut transport = SubprocessTransport::new(prompt, options)?;
175
176 // Don't send initial prompt - we'll use query() for that
177 transport.connect().await?;
178
179 // Extract stdin for direct access (avoids transport lock deadlock)
180 let stdin = Arc::clone(&transport.stdin);
181
182 // Create Query with hooks
183 let mut query = QueryFull::new(Box::new(transport));
184 query.set_stdin(stdin);
185
186 // Set control request timeout from options
187 query.set_control_request_timeout(self.options.control_request_timeout);
188
189 // Extract SDK MCP servers from options
190 let sdk_mcp_servers =
191 if let crate::types::mcp::McpServers::Dict(servers_dict) = &self.options.mcp_servers {
192 servers_dict
193 .iter()
194 .filter_map(|(name, config)| {
195 if let crate::types::mcp::McpServerConfig::Sdk(sdk_config) = config {
196 Some((name.clone(), sdk_config.clone()))
197 } else {
198 None
199 }
200 })
201 .collect()
202 } else {
203 std::collections::HashMap::new()
204 };
205 query.set_sdk_mcp_servers(sdk_mcp_servers).await;
206
207 // Set can_use_tool callback if provided
208 if let Some(ref callback) = self.options.can_use_tool {
209 query.set_can_use_tool(Some(Arc::clone(callback))).await;
210 }
211
212 // Build efficiency hooks if configured
213 let efficiency_hooks = self
214 .options
215 .efficiency
216 .as_ref()
217 .map(build_efficiency_hooks)
218 .unwrap_or_default();
219
220 // Merge user hooks with efficiency hooks
221 let merged_hooks = merge_hooks(self.options.hooks.clone(), efficiency_hooks);
222
223 // Convert hooks to internal format
224 let hooks = merged_hooks.as_ref().map(|hooks_map| {
225 hooks_map
226 .iter()
227 .map(|(event, matchers)| {
228 let event_name = match event {
229 HookEvent::PreToolUse => "PreToolUse",
230 HookEvent::PostToolUse => "PostToolUse",
231 HookEvent::UserPromptSubmit => "UserPromptSubmit",
232 HookEvent::Stop => "Stop",
233 HookEvent::SubagentStop => "SubagentStop",
234 HookEvent::PreCompact => "PreCompact",
235 };
236 (event_name.to_string(), matchers.clone())
237 })
238 .collect()
239 });
240
241 // Start reading messages in background FIRST
242 // This must happen before initialize() because initialize()
243 // sends a control request and waits for response
244 query.start().await?;
245
246 // Initialize with hooks (sends control request)
247 query.initialize(hooks).await?;
248
249 self.query = Some(Arc::new(Mutex::new(query)));
250 self.connected = true;
251
252 info!("Successfully connected to Claude Code CLI");
253 Ok(())
254 }
255
256 /// Send a query to Claude
257 ///
258 /// This sends a new user prompt to Claude. Claude will remember the context
259 /// of previous queries within the same session.
260 ///
261 /// # Arguments
262 ///
263 /// * `prompt` - The user prompt to send
264 ///
265 /// # Errors
266 ///
267 /// Returns an error if the client is not connected or if sending fails.
268 ///
269 /// # Example
270 ///
271 /// ```no_run
272 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
273 /// # #[tokio::main]
274 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
275 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
276 /// # client.connect().await?;
277 /// client.query("What is 2 + 2?").await?;
278 /// # Ok(())
279 /// # }
280 /// ```
281 #[instrument(
282 name = "claude.client.query",
283 skip(self, prompt),
284 fields(session_id = "default",)
285 )]
286 pub async fn query(&mut self, prompt: impl Into<String>) -> Result<()> {
287 self.query_with_session(prompt, "default").await
288 }
289
290 /// Send a query to Claude with a specific session ID
291 ///
292 /// This sends a new user prompt to Claude. Different session IDs maintain
293 /// separate conversation contexts.
294 ///
295 /// # Arguments
296 ///
297 /// * `prompt` - The user prompt to send
298 /// * `session_id` - Session identifier for the conversation
299 ///
300 /// # Errors
301 ///
302 /// Returns an error if the client is not connected or if sending fails.
303 ///
304 /// # Example
305 ///
306 /// ```no_run
307 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
308 /// # #[tokio::main]
309 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
310 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
311 /// # client.connect().await?;
312 /// // Separate conversation contexts
313 /// client.query_with_session("First question", "session-1").await?;
314 /// client.query_with_session("Different question", "session-2").await?;
315 /// # Ok(())
316 /// # }
317 /// ```
318 pub async fn query_with_session(
319 &mut self,
320 prompt: impl Into<String>,
321 session_id: impl Into<String>,
322 ) -> Result<()> {
323 let query = self.query.as_ref().ok_or_else(|| {
324 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
325 })?;
326
327 let prompt_str = prompt.into();
328 let session_id_str = session_id.into();
329
330 // Format as JSON message for stream-json input format
331 let user_message = serde_json::json!({
332 "type": "user",
333 "message": {
334 "role": "user",
335 "content": prompt_str
336 },
337 "session_id": session_id_str
338 });
339
340 let message_str = serde_json::to_string(&user_message).map_err(|e| {
341 ClaudeError::Transport(format!("Failed to serialize user message: {}", e))
342 })?;
343
344 // Write directly to stdin (bypasses transport lock)
345 let query_guard = query.lock().await;
346 let stdin = query_guard.stdin.clone();
347 drop(query_guard);
348
349 if let Some(stdin_arc) = stdin {
350 let mut stdin_guard = stdin_arc.lock().await;
351 if let Some(ref mut stdin_stream) = *stdin_guard {
352 stdin_stream
353 .write_all(message_str.as_bytes())
354 .await
355 .map_err(|e| ClaudeError::Transport(format!("Failed to write query: {}", e)))?;
356 stdin_stream.write_all(b"\n").await.map_err(|e| {
357 ClaudeError::Transport(format!("Failed to write newline: {}", e))
358 })?;
359 stdin_stream
360 .flush()
361 .await
362 .map_err(|e| ClaudeError::Transport(format!("Failed to flush: {}", e)))?;
363 } else {
364 return Err(ClaudeError::Transport("stdin not available".to_string()));
365 }
366 } else {
367 return Err(ClaudeError::Transport("stdin not set".to_string()));
368 }
369
370 Ok(())
371 }
372
373 /// Send a query with structured content blocks (supports images)
374 ///
375 /// This method enables multimodal queries in bidirectional streaming mode.
376 /// Use it to send images alongside text for vision-related tasks.
377 ///
378 /// # Arguments
379 ///
380 /// * `content` - A vector of content blocks (text and/or images)
381 ///
382 /// # Errors
383 ///
384 /// Returns an error if:
385 /// - The content vector is empty (must include at least one text or image block)
386 /// - The client is not connected (call `connect()` first)
387 /// - Sending the message fails
388 ///
389 /// # Example
390 ///
391 /// ```no_run
392 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, UserContentBlock};
393 /// # #[tokio::main]
394 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
395 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
396 /// # client.connect().await?;
397 /// let base64_data = "iVBORw0KGgo..."; // base64 encoded image
398 /// client.query_with_content(vec![
399 /// UserContentBlock::text("What's in this image?"),
400 /// UserContentBlock::image_base64("image/png", base64_data)?,
401 /// ]).await?;
402 /// # Ok(())
403 /// # }
404 /// ```
405 pub async fn query_with_content(
406 &mut self,
407 content: impl Into<Vec<UserContentBlock>>,
408 ) -> Result<()> {
409 self.query_with_content_and_session(content, "default")
410 .await
411 }
412
413 /// Send a query with structured content blocks and a specific session ID
414 ///
415 /// This method enables multimodal queries with session management for
416 /// maintaining separate conversation contexts.
417 ///
418 /// # Arguments
419 ///
420 /// * `content` - A vector of content blocks (text and/or images)
421 /// * `session_id` - Session identifier for the conversation
422 ///
423 /// # Errors
424 ///
425 /// Returns an error if:
426 /// - The content vector is empty (must include at least one text or image block)
427 /// - The client is not connected (call `connect()` first)
428 /// - Sending the message fails
429 ///
430 /// # Example
431 ///
432 /// ```no_run
433 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, UserContentBlock};
434 /// # #[tokio::main]
435 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
436 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
437 /// # client.connect().await?;
438 /// client.query_with_content_and_session(
439 /// vec![
440 /// UserContentBlock::text("Analyze this chart"),
441 /// UserContentBlock::image_url("https://example.com/chart.png"),
442 /// ],
443 /// "analysis-session",
444 /// ).await?;
445 /// # Ok(())
446 /// # }
447 /// ```
448 pub async fn query_with_content_and_session(
449 &mut self,
450 content: impl Into<Vec<UserContentBlock>>,
451 session_id: impl Into<String>,
452 ) -> Result<()> {
453 let query = self.query.as_ref().ok_or_else(|| {
454 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
455 })?;
456
457 let content_blocks: Vec<UserContentBlock> = content.into();
458 UserContentBlock::validate_content(&content_blocks)?;
459
460 let session_id_str = session_id.into();
461
462 // Format as JSON message for stream-json input format
463 // Content is an array of content blocks, not a simple string
464 let user_message = serde_json::json!({
465 "type": "user",
466 "message": {
467 "role": "user",
468 "content": content_blocks
469 },
470 "session_id": session_id_str
471 });
472
473 let message_str = serde_json::to_string(&user_message).map_err(|e| {
474 ClaudeError::Transport(format!("Failed to serialize user message: {}", e))
475 })?;
476
477 // Write directly to stdin (bypasses transport lock)
478 let query_guard = query.lock().await;
479 let stdin = query_guard.stdin.clone();
480 drop(query_guard);
481
482 if let Some(stdin_arc) = stdin {
483 let mut stdin_guard = stdin_arc.lock().await;
484 if let Some(ref mut stdin_stream) = *stdin_guard {
485 stdin_stream
486 .write_all(message_str.as_bytes())
487 .await
488 .map_err(|e| ClaudeError::Transport(format!("Failed to write query: {}", e)))?;
489 stdin_stream.write_all(b"\n").await.map_err(|e| {
490 ClaudeError::Transport(format!("Failed to write newline: {}", e))
491 })?;
492 stdin_stream
493 .flush()
494 .await
495 .map_err(|e| ClaudeError::Transport(format!("Failed to flush: {}", e)))?;
496 } else {
497 return Err(ClaudeError::Transport("stdin not available".to_string()));
498 }
499 } else {
500 return Err(ClaudeError::Transport("stdin not set".to_string()));
501 }
502
503 Ok(())
504 }
505
506 /// Receive all messages as a stream (continuous)
507 ///
508 /// This method returns a stream that yields all messages from Claude
509 /// indefinitely until the stream is closed or an error occurs.
510 ///
511 /// Use this when you want to process all messages, including multiple
512 /// responses and system events.
513 ///
514 /// # Returns
515 ///
516 /// A stream of `Result<Message>` that continues until the connection closes.
517 ///
518 /// # Example
519 ///
520 /// ```no_run
521 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
522 /// # use futures::StreamExt;
523 /// # #[tokio::main]
524 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
525 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
526 /// # client.connect().await?;
527 /// # client.query("Hello").await?;
528 /// let mut stream = client.receive_messages();
529 /// while let Some(message) = stream.next().await {
530 /// println!("Received: {:?}", message?);
531 /// }
532 /// # Ok(())
533 /// # }
534 /// ```
535 pub fn receive_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
536 let query = match &self.query {
537 Some(q) => Arc::clone(q),
538 None => {
539 return Box::pin(futures::stream::once(async {
540 Err(ClaudeError::InvalidConfig(
541 "Client not connected. Call connect() first.".to_string(),
542 ))
543 }));
544 }
545 };
546
547 Box::pin(async_stream::stream! {
548 let rx: Arc<Mutex<tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>>> = {
549 let query_guard = query.lock().await;
550 Arc::clone(&query_guard.message_rx)
551 };
552
553 loop {
554 let message = {
555 let mut rx_guard = rx.lock().await;
556 rx_guard.recv().await
557 };
558
559 match message {
560 Some(json) => {
561 match MessageParser::parse(json) {
562 Ok(msg) => yield Ok(msg),
563 Err(e) => {
564 eprintln!("Failed to parse message: {}", e);
565 yield Err(e);
566 }
567 }
568 }
569 None => break,
570 }
571 }
572 })
573 }
574
575 /// Receive messages until a ResultMessage
576 ///
577 /// This method returns a stream that yields messages until it encounters
578 /// a `ResultMessage`, which signals the completion of a Claude response.
579 ///
580 /// This is the most common pattern for handling Claude responses, as it
581 /// processes one complete "turn" of the conversation.
582 ///
583 /// This method uses query-scoped message channels to ensure message isolation,
584 /// preventing late-arriving ResultMessages from being consumed by the wrong prompt.
585 ///
586 /// # Returns
587 ///
588 /// A stream of `Result<Message>` that ends when a ResultMessage is received.
589 ///
590 /// # Example
591 ///
592 /// ```no_run
593 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, Message};
594 /// # use futures::StreamExt;
595 /// # #[tokio::main]
596 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
597 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
598 /// # client.connect().await?;
599 /// # client.query("Hello").await?;
600 /// let mut stream = client.receive_response();
601 /// while let Some(message) = stream.next().await {
602 /// match message? {
603 /// Message::Assistant(msg) => println!("Assistant: {:?}", msg),
604 /// Message::Result(result) => {
605 /// println!("Done! Cost: ${:?}", result.total_cost_usd);
606 /// break;
607 /// }
608 /// _ => {}
609 /// }
610 /// }
611 /// # Ok(())
612 /// # }
613 /// ```
614 pub fn receive_response(&self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
615 let query = match &self.query {
616 Some(q) => Arc::clone(q),
617 None => {
618 return Box::pin(futures::stream::once(async {
619 Err(ClaudeError::InvalidConfig(
620 "Client not connected. Call connect() first.".to_string(),
621 ))
622 }));
623 }
624 };
625
626 Box::pin(async_stream::stream! {
627 // ====================================================================
628 // QUERY-SCOPED MESSAGE CHANNEL
629 // ====================================================================
630 // Create an isolated message channel for this specific query.
631 // This ensures we only receive messages intended for this prompt,
632 // preventing late-arriving ResultMessages from being consumed
633 // by the wrong query.
634 //
635 // Note: Cleanup is handled by the periodic cleanup task in QueryFull,
636 // which removes stale receivers whose senders have been closed.
637
638 let query_id = {
639 let query_guard = query.lock().await;
640 query_guard.generate_query_id()
641 };
642
643 debug!(
644 query_id = %query_id,
645 "Creating query-scoped receiver"
646 );
647
648 let mut rx = {
649 let query_guard = query.lock().await;
650 query_guard.register_query_receiver(query_id.clone()).await
651 };
652
653 loop {
654 let message = rx.recv().await;
655
656 match message {
657 Some(json) => {
658 match MessageParser::parse(json) {
659 Ok(msg) => {
660 let is_result = matches!(msg, Message::Result(_));
661 yield Ok(msg);
662 if is_result {
663 debug!(
664 query_id = %query_id,
665 "Received ResultMessage, ending stream"
666 );
667 // Cleanup will be handled by the periodic cleanup task
668 break;
669 }
670 }
671 Err(e) => {
672 eprintln!("Failed to parse message: {}", e);
673 yield Err(e);
674 }
675 }
676 }
677 None => {
678 debug!(
679 query_id = %query_id,
680 "Query-scoped channel closed"
681 );
682 // Cleanup will be handled by the periodic cleanup task
683 break;
684 }
685 }
686 }
687 })
688 }
689
690 /// Drain any leftover messages from the previous prompt
691 ///
692 /// This method removes any messages remaining in the channel from a previous
693 /// prompt. This should be called before starting a new prompt to ensure
694 /// that the new prompt doesn't receive stale messages.
695 ///
696 /// This is important when prompts are cancelled or end unexpectedly,
697 /// as there may be buffered messages that would otherwise be received
698 /// by the next prompt.
699 ///
700 /// # Returns
701 ///
702 /// The number of messages drained from the channel.
703 ///
704 /// # Example
705 ///
706 /// ```no_run
707 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
708 /// # #[tokio::main]
709 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
710 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
711 /// # client.connect().await?;
712 /// // Before starting a new prompt, drain any leftover messages
713 /// let drained = client.drain_messages().await;
714 /// if drained > 0 {
715 /// eprintln!("Drained {} leftover messages from previous prompt", drained);
716 /// }
717 /// # Ok(())
718 /// # }
719 /// ```
720 pub async fn drain_messages(&self) -> usize {
721 let Some(query) = &self.query else {
722 return 0;
723 };
724
725 let rx: Arc<Mutex<tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>>> = {
726 let query_guard = query.lock().await;
727 Arc::clone(&query_guard.message_rx)
728 };
729
730 let mut count = 0;
731 // Use try_recv to drain all currently available messages without blocking
732 loop {
733 let mut rx_guard = rx.lock().await;
734 match rx_guard.try_recv() {
735 Ok(_) => count += 1,
736 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
737 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
738 }
739 }
740
741 if count > 0 {
742 debug!(count, "Drained leftover messages from previous prompt");
743 }
744
745 count
746 }
747
748 /// Send an interrupt signal to stop the current Claude operation
749 ///
750 /// This is analogous to Python's `client.interrupt()`.
751 ///
752 /// # Errors
753 ///
754 /// Returns an error if the client is not connected or if sending fails.
755 pub async fn interrupt(&self) -> Result<()> {
756 let query = self.query.as_ref().ok_or_else(|| {
757 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
758 })?;
759
760 let query_guard = query.lock().await;
761 query_guard.interrupt().await
762 }
763
764 /// Change the permission mode dynamically
765 ///
766 /// This is analogous to Python's `client.set_permission_mode()`.
767 ///
768 /// # Arguments
769 ///
770 /// * `mode` - The new permission mode to set
771 ///
772 /// # Errors
773 ///
774 /// Returns an error if the client is not connected or if sending fails.
775 pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
776 let query = self.query.as_ref().ok_or_else(|| {
777 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
778 })?;
779
780 let query_guard = query.lock().await;
781 query_guard.set_permission_mode(mode).await
782 }
783
784 /// Change the AI model dynamically
785 ///
786 /// This is analogous to Python's `client.set_model()`.
787 ///
788 /// # Arguments
789 ///
790 /// * `model` - The new model name, or None to use default
791 ///
792 /// # Errors
793 ///
794 /// Returns an error if the client is not connected or if sending fails.
795 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
796 let query = self.query.as_ref().ok_or_else(|| {
797 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
798 })?;
799
800 let query_guard = query.lock().await;
801 query_guard.set_model(model).await
802 }
803
804 /// Rewind tracked files to their state at a specific user message.
805 ///
806 /// This is analogous to Python's `client.rewind_files()`.
807 ///
808 /// # Requirements
809 ///
810 /// - `enable_file_checkpointing=true` in options to track file changes
811 /// - `extra_args={"replay-user-messages": None}` to receive UserMessage
812 /// objects with `uuid` in the response stream
813 ///
814 /// # Arguments
815 ///
816 /// * `user_message_id` - UUID of the user message to rewind to. This should be
817 /// the `uuid` field from a `UserMessage` received during the conversation.
818 ///
819 /// # Errors
820 ///
821 /// Returns an error if the client is not connected or if sending fails.
822 ///
823 /// # Example
824 ///
825 /// ```no_run
826 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, Message};
827 /// # use std::collections::HashMap;
828 /// # #[tokio::main]
829 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
830 /// let options = ClaudeAgentOptions::builder()
831 /// .enable_file_checkpointing(true)
832 /// .extra_args(HashMap::from([("replay-user-messages".to_string(), None)]))
833 /// .build();
834 /// let mut client = ClaudeClient::new(options);
835 /// client.connect().await?;
836 ///
837 /// client.query("Make some changes to my files").await?;
838 /// let mut checkpoint_id = None;
839 /// {
840 /// let mut stream = client.receive_response();
841 /// use futures::StreamExt;
842 /// while let Some(Ok(msg)) = stream.next().await {
843 /// if let Message::User(user_msg) = &msg {
844 /// if let Some(uuid) = &user_msg.uuid {
845 /// checkpoint_id = Some(uuid.clone());
846 /// }
847 /// }
848 /// }
849 /// }
850 ///
851 /// // Later, rewind to that point
852 /// if let Some(id) = checkpoint_id {
853 /// client.rewind_files(&id).await?;
854 /// }
855 /// # Ok(())
856 /// # }
857 /// ```
858 pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
859 let query = self.query.as_ref().ok_or_else(|| {
860 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
861 })?;
862
863 let query_guard = query.lock().await;
864 query_guard.rewind_files(user_message_id).await
865 }
866
867 /// Get server initialization info including available commands and output styles
868 ///
869 /// Returns initialization information from the Claude Code server including:
870 /// - Available commands (slash commands, system commands, etc.)
871 /// - Current and available output styles
872 /// - Server capabilities
873 ///
874 /// This is analogous to Python's `client.get_server_info()`.
875 ///
876 /// # Returns
877 ///
878 /// Dictionary with server info, or None if not connected
879 ///
880 /// # Example
881 ///
882 /// ```no_run
883 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
884 /// # #[tokio::main]
885 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
886 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
887 /// # client.connect().await?;
888 /// if let Some(info) = client.get_server_info().await {
889 /// println!("Commands available: {}", info.get("commands").map(|c| c.as_array().map(|a| a.len()).unwrap_or(0)).unwrap_or(0));
890 /// println!("Output style: {:?}", info.get("output_style"));
891 /// }
892 /// # Ok(())
893 /// # }
894 /// ```
895 pub async fn get_server_info(&self) -> Option<serde_json::Value> {
896 let query = self.query.as_ref()?;
897 let query_guard = query.lock().await;
898 query_guard.get_initialization_result().await
899 }
900
901 /// Start a new session by switching to a different session ID
902 ///
903 /// This is a convenience method that creates a new conversation context.
904 /// It's equivalent to calling `query_with_session()` with a new session ID.
905 ///
906 /// To completely clear memory and start fresh, use `ClaudeAgentOptions::builder().fork_session(true).build()`
907 /// when creating a new client.
908 ///
909 /// # Arguments
910 ///
911 /// * `session_id` - The new session ID to use
912 /// * `prompt` - Initial message for the new session
913 ///
914 /// # Errors
915 ///
916 /// Returns an error if the client is not connected or if sending fails.
917 ///
918 /// # Example
919 ///
920 /// ```no_run
921 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
922 /// # #[tokio::main]
923 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
924 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
925 /// # client.connect().await?;
926 /// // First conversation
927 /// client.query("Hello").await?;
928 ///
929 /// // Start new conversation with different context
930 /// client.new_session("session-2", "Tell me about Rust").await?;
931 /// # Ok(())
932 /// # }
933 /// ```
934 pub async fn new_session(
935 &mut self,
936 session_id: impl Into<String>,
937 prompt: impl Into<String>,
938 ) -> Result<()> {
939 self.query_with_session(prompt, session_id).await
940 }
941
942 /// Disconnect from Claude (analogous to Python's __aexit__)
943 ///
944 /// This cleanly shuts down the connection to Claude Code CLI.
945 ///
946 /// # Errors
947 ///
948 /// Returns an error if disconnection fails.
949 #[instrument(name = "claude.client.disconnect", skip(self))]
950 pub async fn disconnect(&mut self) -> Result<()> {
951 if !self.connected {
952 debug!("Client already disconnected");
953 return Ok(());
954 }
955
956 info!("Disconnecting from Claude Code CLI");
957
958 if let Some(query) = self.query.take() {
959 // Close stdin first (using direct access) to signal CLI to exit
960 // This will cause the background task to finish and release transport lock
961 let query_guard = query.lock().await;
962 if let Some(ref stdin_arc) = query_guard.stdin {
963 let mut stdin_guard = stdin_arc.lock().await;
964 if let Some(mut stdin_stream) = stdin_guard.take() {
965 let _ = stdin_stream.shutdown().await;
966 }
967 }
968 let transport = Arc::clone(&query_guard.transport);
969 drop(query_guard);
970
971 // Give background task a moment to finish reading and release lock
972 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
973
974 let mut transport_guard = transport.lock().await;
975 transport_guard.close().await?;
976 }
977
978 self.connected = false;
979 debug!("Disconnected successfully");
980 Ok(())
981 }
982}
983
984impl Drop for ClaudeClient {
985 fn drop(&mut self) {
986 // Note: We can't run async code in Drop, so we can't guarantee clean shutdown
987 // Users should call disconnect() explicitly
988 if self.connected {
989 eprintln!(
990 "Warning: ClaudeClient dropped without calling disconnect(). Resources may not be cleaned up properly."
991 );
992 }
993 }
994}
995
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::permissions::{CanUseToolCallback, PermissionResult, PermissionResultAllow};
    use std::sync::Arc;

    /// Permission callback that allows every tool use.
    /// These tests only care that a callback is present.
    fn allow_all_callback() -> CanUseToolCallback {
        Arc::new(|_tool_name, _tool_input, _context| {
            Box::pin(async move { PermissionResult::Allow(PermissionResultAllow::default()) })
        })
    }

    /// Connect with `opts` and assert that any failure is NOT the
    /// `permission_prompt_tool_name` validation error.
    ///
    /// Without a CLI installed, connect() fails later for an unrelated reason,
    /// which is fine. If a CLI is installed and the connection succeeds, the
    /// client is disconnected so the test does not leak a subprocess.
    async fn assert_passes_permission_validation(opts: ClaudeAgentOptions) {
        let mut client = ClaudeClient::new(opts);

        match client.connect().await {
            Ok(()) => {
                let _ = client.disconnect().await;
            }
            Err(err) => assert!(
                !err.to_string().contains("permission_prompt_tool_name"),
                "Should not fail on permission_prompt_tool_name validation"
            ),
        }
    }

    #[tokio::test]
    async fn test_connect_rejects_can_use_tool_with_custom_permission_tool() {
        let opts = ClaudeAgentOptions::builder()
            .can_use_tool(allow_all_callback())
            .permission_prompt_tool_name("custom_tool") // Not "stdio"
            .build();

        let mut client = ClaudeClient::new(opts);
        let err = client
            .connect()
            .await
            .expect_err("custom permission_prompt_tool_name must be rejected");

        assert!(matches!(err, ClaudeError::InvalidConfig(_)));
        assert!(err.to_string().contains("permission_prompt_tool_name"));
    }

    #[tokio::test]
    async fn test_connect_accepts_can_use_tool_with_stdio() {
        let opts = ClaudeAgentOptions::builder()
            .can_use_tool(allow_all_callback())
            .permission_prompt_tool_name("stdio") // Explicitly "stdio" is OK
            .build();

        assert_passes_permission_validation(opts).await;
    }

    #[tokio::test]
    async fn test_connect_accepts_can_use_tool_without_permission_tool() {
        let opts = ClaudeAgentOptions::builder()
            .can_use_tool(allow_all_callback())
            // No permission_prompt_tool_name set - defaults to stdio
            .build();

        assert_passes_permission_validation(opts).await;
    }
}