claude_code_agent_sdk/client.rs
1//! ClaudeClient for bidirectional streaming interactions with hook support
2
3use futures::stream::Stream;
4use std::pin::Pin;
5use std::sync::Arc;
6use tokio::io::AsyncWriteExt;
7use tokio::sync::Mutex;
8
9use crate::errors::{ClaudeError, Result};
10use crate::internal::message_parser::MessageParser;
11use crate::internal::query_full::QueryFull;
12use crate::internal::transport::subprocess::QueryPrompt;
13use crate::internal::transport::{SubprocessTransport, Transport};
14use crate::types::config::{ClaudeAgentOptions, PermissionMode};
15use crate::types::hooks::HookEvent;
16use crate::types::messages::{Message, UserContentBlock};
17
18/// Client for bidirectional streaming interactions with Claude
19///
20/// This client provides the same functionality as Python's ClaudeSDKClient,
21/// supporting bidirectional communication, streaming responses, and dynamic
22/// control over the Claude session.
23///
24/// # Example
25///
26/// ```no_run
27/// use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
28/// use futures::StreamExt;
29///
30/// #[tokio::main]
31/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
32/// let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
33///
34/// // Connect to Claude
35/// client.connect().await?;
36///
37/// // Send a query
38/// client.query("Hello Claude!").await?;
39///
40/// // Receive response as a stream
41/// {
42/// let mut stream = client.receive_response();
43/// while let Some(message) = stream.next().await {
44/// println!("Received: {:?}", message?);
45/// }
46/// }
47///
48/// // Disconnect
49/// client.disconnect().await?;
50/// Ok(())
51/// }
52/// ```
53pub struct ClaudeClient {
54 options: ClaudeAgentOptions,
55 query: Option<Arc<Mutex<QueryFull>>>,
56 connected: bool,
57}
58
59impl ClaudeClient {
60 /// Create a new ClaudeClient
61 ///
62 /// # Arguments
63 ///
64 /// * `options` - Configuration options for the Claude client
65 ///
66 /// # Example
67 ///
68 /// ```no_run
69 /// use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
70 ///
71 /// let client = ClaudeClient::new(ClaudeAgentOptions::default());
72 /// ```
73 pub fn new(options: ClaudeAgentOptions) -> Self {
74 Self {
75 options,
76 query: None,
77 connected: false,
78 }
79 }
80
81 /// Create a new ClaudeClient with early validation
82 ///
83 /// Unlike `new()`, this validates the configuration eagerly by attempting
84 /// to create the transport. This catches issues like invalid working directory
85 /// or missing CLI before `connect()` is called.
86 ///
87 /// # Arguments
88 ///
89 /// * `options` - Configuration options for the Claude client
90 ///
91 /// # Errors
92 ///
93 /// Returns an error if:
94 /// - The working directory does not exist or is not a directory
95 /// - Claude CLI cannot be found
96 ///
97 /// # Example
98 ///
99 /// ```no_run
100 /// use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
101 ///
102 /// let client = ClaudeClient::try_new(ClaudeAgentOptions::default())?;
103 /// # Ok::<(), claude_agent_sdk_rs::ClaudeError>(())
104 /// ```
105 pub fn try_new(options: ClaudeAgentOptions) -> Result<Self> {
106 // Validate by attempting to create transport (but don't keep it)
107 let prompt = QueryPrompt::Streaming;
108 let _ = SubprocessTransport::new(prompt, options.clone())?;
109
110 Ok(Self {
111 options,
112 query: None,
113 connected: false,
114 })
115 }
116
117 /// Connect to Claude (analogous to Python's __aenter__)
118 ///
119 /// This establishes the connection to the Claude Code CLI and initializes
120 /// the bidirectional communication channel.
121 ///
122 /// # Errors
123 ///
124 /// Returns an error if:
125 /// - Claude CLI cannot be found or started
126 /// - The initialization handshake fails
127 /// - Hook registration fails
128 /// - `can_use_tool` callback is set with incompatible `permission_prompt_tool_name`
129 pub async fn connect(&mut self) -> Result<()> {
130 if self.connected {
131 return Ok(());
132 }
133
134 // Validate can_use_tool configuration (aligned with Python SDK behavior)
135 // When can_use_tool callback is set, permission_prompt_tool_name must be "stdio"
136 // to ensure the control protocol can handle permission requests
137 if self.options.can_use_tool.is_some() {
138 if let Some(ref tool_name) = self.options.permission_prompt_tool_name {
139 if tool_name != "stdio" {
140 return Err(ClaudeError::InvalidConfig(
141 "can_use_tool callback requires permission_prompt_tool_name to be 'stdio' or unset. \
142 Custom permission_prompt_tool_name is incompatible with can_use_tool callback."
143 .to_string(),
144 ));
145 }
146 }
147 }
148
149 // Create transport in streaming mode (no initial prompt)
150 let prompt = QueryPrompt::Streaming;
151 let mut transport = SubprocessTransport::new(prompt, self.options.clone())?;
152
153 // Don't send initial prompt - we'll use query() for that
154 transport.connect().await?;
155
156 // Extract stdin for direct access (avoids transport lock deadlock)
157 let stdin = Arc::clone(&transport.stdin);
158
159 // Create Query with hooks
160 let mut query = QueryFull::new(Box::new(transport));
161 query.set_stdin(stdin);
162
163 // Set control request timeout from options
164 query.set_control_request_timeout(self.options.control_request_timeout);
165
166 // Extract SDK MCP servers from options
167 let sdk_mcp_servers =
168 if let crate::types::mcp::McpServers::Dict(servers_dict) = &self.options.mcp_servers {
169 servers_dict
170 .iter()
171 .filter_map(|(name, config)| {
172 if let crate::types::mcp::McpServerConfig::Sdk(sdk_config) = config {
173 Some((name.clone(), sdk_config.clone()))
174 } else {
175 None
176 }
177 })
178 .collect()
179 } else {
180 std::collections::HashMap::new()
181 };
182 query.set_sdk_mcp_servers(sdk_mcp_servers).await;
183
184 // Set can_use_tool callback if provided
185 if let Some(ref callback) = self.options.can_use_tool {
186 query.set_can_use_tool(Some(Arc::clone(callback))).await;
187 }
188
189 // Convert hooks to internal format
190 let hooks = self.options.hooks.as_ref().map(|hooks_map| {
191 hooks_map
192 .iter()
193 .map(|(event, matchers)| {
194 let event_name = match event {
195 HookEvent::PreToolUse => "PreToolUse",
196 HookEvent::PostToolUse => "PostToolUse",
197 HookEvent::UserPromptSubmit => "UserPromptSubmit",
198 HookEvent::Stop => "Stop",
199 HookEvent::SubagentStop => "SubagentStop",
200 HookEvent::PreCompact => "PreCompact",
201 };
202 (event_name.to_string(), matchers.clone())
203 })
204 .collect()
205 });
206
207 // Start reading messages in background FIRST
208 // This must happen before initialize() because initialize()
209 // sends a control request and waits for response
210 query.start().await?;
211
212 // Initialize with hooks (sends control request)
213 query.initialize(hooks).await?;
214
215 self.query = Some(Arc::new(Mutex::new(query)));
216 self.connected = true;
217
218 Ok(())
219 }
220
221 /// Send a query to Claude
222 ///
223 /// This sends a new user prompt to Claude. Claude will remember the context
224 /// of previous queries within the same session.
225 ///
226 /// # Arguments
227 ///
228 /// * `prompt` - The user prompt to send
229 ///
230 /// # Errors
231 ///
232 /// Returns an error if the client is not connected or if sending fails.
233 ///
234 /// # Example
235 ///
236 /// ```no_run
237 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
238 /// # #[tokio::main]
239 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
240 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
241 /// # client.connect().await?;
242 /// client.query("What is 2 + 2?").await?;
243 /// # Ok(())
244 /// # }
245 /// ```
246 pub async fn query(&mut self, prompt: impl Into<String>) -> Result<()> {
247 self.query_with_session(prompt, "default").await
248 }
249
250 /// Send a query to Claude with a specific session ID
251 ///
252 /// This sends a new user prompt to Claude. Different session IDs maintain
253 /// separate conversation contexts.
254 ///
255 /// # Arguments
256 ///
257 /// * `prompt` - The user prompt to send
258 /// * `session_id` - Session identifier for the conversation
259 ///
260 /// # Errors
261 ///
262 /// Returns an error if the client is not connected or if sending fails.
263 ///
264 /// # Example
265 ///
266 /// ```no_run
267 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
268 /// # #[tokio::main]
269 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
270 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
271 /// # client.connect().await?;
272 /// // Separate conversation contexts
273 /// client.query_with_session("First question", "session-1").await?;
274 /// client.query_with_session("Different question", "session-2").await?;
275 /// # Ok(())
276 /// # }
277 /// ```
278 pub async fn query_with_session(
279 &mut self,
280 prompt: impl Into<String>,
281 session_id: impl Into<String>,
282 ) -> Result<()> {
283 let query = self.query.as_ref().ok_or_else(|| {
284 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
285 })?;
286
287 let prompt_str = prompt.into();
288 let session_id_str = session_id.into();
289
290 // Format as JSON message for stream-json input format
291 let user_message = serde_json::json!({
292 "type": "user",
293 "message": {
294 "role": "user",
295 "content": prompt_str
296 },
297 "session_id": session_id_str
298 });
299
300 let message_str = serde_json::to_string(&user_message).map_err(|e| {
301 ClaudeError::Transport(format!("Failed to serialize user message: {}", e))
302 })?;
303
304 // Write directly to stdin (bypasses transport lock)
305 let query_guard = query.lock().await;
306 let stdin = query_guard.stdin.clone();
307 drop(query_guard);
308
309 if let Some(stdin_arc) = stdin {
310 let mut stdin_guard = stdin_arc.lock().await;
311 if let Some(ref mut stdin_stream) = *stdin_guard {
312 stdin_stream
313 .write_all(message_str.as_bytes())
314 .await
315 .map_err(|e| ClaudeError::Transport(format!("Failed to write query: {}", e)))?;
316 stdin_stream.write_all(b"\n").await.map_err(|e| {
317 ClaudeError::Transport(format!("Failed to write newline: {}", e))
318 })?;
319 stdin_stream
320 .flush()
321 .await
322 .map_err(|e| ClaudeError::Transport(format!("Failed to flush: {}", e)))?;
323 } else {
324 return Err(ClaudeError::Transport("stdin not available".to_string()));
325 }
326 } else {
327 return Err(ClaudeError::Transport("stdin not set".to_string()));
328 }
329
330 Ok(())
331 }
332
333 /// Send a query with structured content blocks (supports images)
334 ///
335 /// This method enables multimodal queries in bidirectional streaming mode.
336 /// Use it to send images alongside text for vision-related tasks.
337 ///
338 /// # Arguments
339 ///
340 /// * `content` - A vector of content blocks (text and/or images)
341 ///
342 /// # Errors
343 ///
344 /// Returns an error if:
345 /// - The content vector is empty (must include at least one text or image block)
346 /// - The client is not connected (call `connect()` first)
347 /// - Sending the message fails
348 ///
349 /// # Example
350 ///
351 /// ```no_run
352 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, UserContentBlock};
353 /// # #[tokio::main]
354 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
355 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
356 /// # client.connect().await?;
357 /// let base64_data = "iVBORw0KGgo..."; // base64 encoded image
358 /// client.query_with_content(vec![
359 /// UserContentBlock::text("What's in this image?"),
360 /// UserContentBlock::image_base64("image/png", base64_data)?,
361 /// ]).await?;
362 /// # Ok(())
363 /// # }
364 /// ```
365 pub async fn query_with_content(
366 &mut self,
367 content: impl Into<Vec<UserContentBlock>>,
368 ) -> Result<()> {
369 self.query_with_content_and_session(content, "default")
370 .await
371 }
372
373 /// Send a query with structured content blocks and a specific session ID
374 ///
375 /// This method enables multimodal queries with session management for
376 /// maintaining separate conversation contexts.
377 ///
378 /// # Arguments
379 ///
380 /// * `content` - A vector of content blocks (text and/or images)
381 /// * `session_id` - Session identifier for the conversation
382 ///
383 /// # Errors
384 ///
385 /// Returns an error if:
386 /// - The content vector is empty (must include at least one text or image block)
387 /// - The client is not connected (call `connect()` first)
388 /// - Sending the message fails
389 ///
390 /// # Example
391 ///
392 /// ```no_run
393 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, UserContentBlock};
394 /// # #[tokio::main]
395 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
396 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
397 /// # client.connect().await?;
398 /// client.query_with_content_and_session(
399 /// vec![
400 /// UserContentBlock::text("Analyze this chart"),
401 /// UserContentBlock::image_url("https://example.com/chart.png"),
402 /// ],
403 /// "analysis-session",
404 /// ).await?;
405 /// # Ok(())
406 /// # }
407 /// ```
408 pub async fn query_with_content_and_session(
409 &mut self,
410 content: impl Into<Vec<UserContentBlock>>,
411 session_id: impl Into<String>,
412 ) -> Result<()> {
413 let query = self.query.as_ref().ok_or_else(|| {
414 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
415 })?;
416
417 let content_blocks: Vec<UserContentBlock> = content.into();
418 UserContentBlock::validate_content(&content_blocks)?;
419
420 let session_id_str = session_id.into();
421
422 // Format as JSON message for stream-json input format
423 // Content is an array of content blocks, not a simple string
424 let user_message = serde_json::json!({
425 "type": "user",
426 "message": {
427 "role": "user",
428 "content": content_blocks
429 },
430 "session_id": session_id_str
431 });
432
433 let message_str = serde_json::to_string(&user_message).map_err(|e| {
434 ClaudeError::Transport(format!("Failed to serialize user message: {}", e))
435 })?;
436
437 // Write directly to stdin (bypasses transport lock)
438 let query_guard = query.lock().await;
439 let stdin = query_guard.stdin.clone();
440 drop(query_guard);
441
442 if let Some(stdin_arc) = stdin {
443 let mut stdin_guard = stdin_arc.lock().await;
444 if let Some(ref mut stdin_stream) = *stdin_guard {
445 stdin_stream
446 .write_all(message_str.as_bytes())
447 .await
448 .map_err(|e| ClaudeError::Transport(format!("Failed to write query: {}", e)))?;
449 stdin_stream.write_all(b"\n").await.map_err(|e| {
450 ClaudeError::Transport(format!("Failed to write newline: {}", e))
451 })?;
452 stdin_stream
453 .flush()
454 .await
455 .map_err(|e| ClaudeError::Transport(format!("Failed to flush: {}", e)))?;
456 } else {
457 return Err(ClaudeError::Transport("stdin not available".to_string()));
458 }
459 } else {
460 return Err(ClaudeError::Transport("stdin not set".to_string()));
461 }
462
463 Ok(())
464 }
465
466 /// Receive all messages as a stream (continuous)
467 ///
468 /// This method returns a stream that yields all messages from Claude
469 /// indefinitely until the stream is closed or an error occurs.
470 ///
471 /// Use this when you want to process all messages, including multiple
472 /// responses and system events.
473 ///
474 /// # Returns
475 ///
476 /// A stream of `Result<Message>` that continues until the connection closes.
477 ///
478 /// # Example
479 ///
480 /// ```no_run
481 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
482 /// # use futures::StreamExt;
483 /// # #[tokio::main]
484 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
485 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
486 /// # client.connect().await?;
487 /// # client.query("Hello").await?;
488 /// let mut stream = client.receive_messages();
489 /// while let Some(message) = stream.next().await {
490 /// println!("Received: {:?}", message?);
491 /// }
492 /// # Ok(())
493 /// # }
494 /// ```
495 pub fn receive_messages(&self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
496 let query = match &self.query {
497 Some(q) => Arc::clone(q),
498 None => {
499 return Box::pin(futures::stream::once(async {
500 Err(ClaudeError::InvalidConfig(
501 "Client not connected. Call connect() first.".to_string(),
502 ))
503 }));
504 }
505 };
506
507 Box::pin(async_stream::stream! {
508 let rx: Arc<Mutex<tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>>> = {
509 let query_guard = query.lock().await;
510 Arc::clone(&query_guard.message_rx)
511 };
512
513 loop {
514 let message = {
515 let mut rx_guard = rx.lock().await;
516 rx_guard.recv().await
517 };
518
519 match message {
520 Some(json) => {
521 match MessageParser::parse(json) {
522 Ok(msg) => yield Ok(msg),
523 Err(e) => {
524 eprintln!("Failed to parse message: {}", e);
525 yield Err(e);
526 }
527 }
528 }
529 None => break,
530 }
531 }
532 })
533 }
534
535 /// Receive messages until a ResultMessage
536 ///
537 /// This method returns a stream that yields messages until it encounters
538 /// a `ResultMessage`, which signals the completion of a Claude response.
539 ///
540 /// This is the most common pattern for handling Claude responses, as it
541 /// processes one complete "turn" of the conversation.
542 ///
543 /// # Returns
544 ///
545 /// A stream of `Result<Message>` that ends when a ResultMessage is received.
546 ///
547 /// # Example
548 ///
549 /// ```no_run
550 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, Message};
551 /// # use futures::StreamExt;
552 /// # #[tokio::main]
553 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
554 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
555 /// # client.connect().await?;
556 /// # client.query("Hello").await?;
557 /// let mut stream = client.receive_response();
558 /// while let Some(message) = stream.next().await {
559 /// match message? {
560 /// Message::Assistant(msg) => println!("Assistant: {:?}", msg),
561 /// Message::Result(result) => {
562 /// println!("Done! Cost: ${:?}", result.total_cost_usd);
563 /// break;
564 /// }
565 /// _ => {}
566 /// }
567 /// }
568 /// # Ok(())
569 /// # }
570 /// ```
571 pub fn receive_response(&self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
572 let query = match &self.query {
573 Some(q) => Arc::clone(q),
574 None => {
575 return Box::pin(futures::stream::once(async {
576 Err(ClaudeError::InvalidConfig(
577 "Client not connected. Call connect() first.".to_string(),
578 ))
579 }));
580 }
581 };
582
583 Box::pin(async_stream::stream! {
584 let rx: Arc<Mutex<tokio::sync::mpsc::UnboundedReceiver<serde_json::Value>>> = {
585 let query_guard = query.lock().await;
586 Arc::clone(&query_guard.message_rx)
587 };
588
589 loop {
590 let message = {
591 let mut rx_guard = rx.lock().await;
592 rx_guard.recv().await
593 };
594
595 match message {
596 Some(json) => {
597 match MessageParser::parse(json) {
598 Ok(msg) => {
599 let is_result = matches!(msg, Message::Result(_));
600 yield Ok(msg);
601 if is_result {
602 break;
603 }
604 }
605 Err(e) => {
606 eprintln!("Failed to parse message: {}", e);
607 yield Err(e);
608 }
609 }
610 }
611 None => break,
612 }
613 }
614 })
615 }
616
617 /// Send an interrupt signal to stop the current Claude operation
618 ///
619 /// This is analogous to Python's `client.interrupt()`.
620 ///
621 /// # Errors
622 ///
623 /// Returns an error if the client is not connected or if sending fails.
624 pub async fn interrupt(&self) -> Result<()> {
625 let query = self.query.as_ref().ok_or_else(|| {
626 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
627 })?;
628
629 let query_guard = query.lock().await;
630 query_guard.interrupt().await
631 }
632
633 /// Change the permission mode dynamically
634 ///
635 /// This is analogous to Python's `client.set_permission_mode()`.
636 ///
637 /// # Arguments
638 ///
639 /// * `mode` - The new permission mode to set
640 ///
641 /// # Errors
642 ///
643 /// Returns an error if the client is not connected or if sending fails.
644 pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
645 let query = self.query.as_ref().ok_or_else(|| {
646 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
647 })?;
648
649 let query_guard = query.lock().await;
650 query_guard.set_permission_mode(mode).await
651 }
652
653 /// Change the AI model dynamically
654 ///
655 /// This is analogous to Python's `client.set_model()`.
656 ///
657 /// # Arguments
658 ///
659 /// * `model` - The new model name, or None to use default
660 ///
661 /// # Errors
662 ///
663 /// Returns an error if the client is not connected or if sending fails.
664 pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
665 let query = self.query.as_ref().ok_or_else(|| {
666 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
667 })?;
668
669 let query_guard = query.lock().await;
670 query_guard.set_model(model).await
671 }
672
673 /// Rewind tracked files to their state at a specific user message.
674 ///
675 /// This is analogous to Python's `client.rewind_files()`.
676 ///
677 /// # Requirements
678 ///
679 /// - `enable_file_checkpointing=true` in options to track file changes
680 /// - `extra_args={"replay-user-messages": None}` to receive UserMessage
681 /// objects with `uuid` in the response stream
682 ///
683 /// # Arguments
684 ///
685 /// * `user_message_id` - UUID of the user message to rewind to. This should be
686 /// the `uuid` field from a `UserMessage` received during the conversation.
687 ///
688 /// # Errors
689 ///
690 /// Returns an error if the client is not connected or if sending fails.
691 ///
692 /// # Example
693 ///
694 /// ```no_run
695 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions, Message};
696 /// # use std::collections::HashMap;
697 /// # #[tokio::main]
698 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
699 /// let options = ClaudeAgentOptions::builder()
700 /// .enable_file_checkpointing(true)
701 /// .extra_args(HashMap::from([("replay-user-messages".to_string(), None)]))
702 /// .build();
703 /// let mut client = ClaudeClient::new(options);
704 /// client.connect().await?;
705 ///
706 /// client.query("Make some changes to my files").await?;
707 /// let mut checkpoint_id = None;
708 /// {
709 /// let mut stream = client.receive_response();
710 /// use futures::StreamExt;
711 /// while let Some(Ok(msg)) = stream.next().await {
712 /// if let Message::User(user_msg) = &msg {
713 /// if let Some(uuid) = &user_msg.uuid {
714 /// checkpoint_id = Some(uuid.clone());
715 /// }
716 /// }
717 /// }
718 /// }
719 ///
720 /// // Later, rewind to that point
721 /// if let Some(id) = checkpoint_id {
722 /// client.rewind_files(&id).await?;
723 /// }
724 /// # Ok(())
725 /// # }
726 /// ```
727 pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
728 let query = self.query.as_ref().ok_or_else(|| {
729 ClaudeError::InvalidConfig("Client not connected. Call connect() first.".to_string())
730 })?;
731
732 let query_guard = query.lock().await;
733 query_guard.rewind_files(user_message_id).await
734 }
735
736 /// Get server initialization info including available commands and output styles
737 ///
738 /// Returns initialization information from the Claude Code server including:
739 /// - Available commands (slash commands, system commands, etc.)
740 /// - Current and available output styles
741 /// - Server capabilities
742 ///
743 /// This is analogous to Python's `client.get_server_info()`.
744 ///
745 /// # Returns
746 ///
747 /// Dictionary with server info, or None if not connected
748 ///
749 /// # Example
750 ///
751 /// ```no_run
752 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
753 /// # #[tokio::main]
754 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
755 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
756 /// # client.connect().await?;
757 /// if let Some(info) = client.get_server_info().await {
758 /// println!("Commands available: {}", info.get("commands").map(|c| c.as_array().map(|a| a.len()).unwrap_or(0)).unwrap_or(0));
759 /// println!("Output style: {:?}", info.get("output_style"));
760 /// }
761 /// # Ok(())
762 /// # }
763 /// ```
764 pub async fn get_server_info(&self) -> Option<serde_json::Value> {
765 let query = self.query.as_ref()?;
766 let query_guard = query.lock().await;
767 query_guard.get_initialization_result().await
768 }
769
770 /// Start a new session by switching to a different session ID
771 ///
772 /// This is a convenience method that creates a new conversation context.
773 /// It's equivalent to calling `query_with_session()` with a new session ID.
774 ///
775 /// To completely clear memory and start fresh, use `ClaudeAgentOptions::builder().fork_session(true).build()`
776 /// when creating a new client.
777 ///
778 /// # Arguments
779 ///
780 /// * `session_id` - The new session ID to use
781 /// * `prompt` - Initial message for the new session
782 ///
783 /// # Errors
784 ///
785 /// Returns an error if the client is not connected or if sending fails.
786 ///
787 /// # Example
788 ///
789 /// ```no_run
790 /// # use claude_agent_sdk_rs::{ClaudeClient, ClaudeAgentOptions};
791 /// # #[tokio::main]
792 /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
793 /// # let mut client = ClaudeClient::new(ClaudeAgentOptions::default());
794 /// # client.connect().await?;
795 /// // First conversation
796 /// client.query("Hello").await?;
797 ///
798 /// // Start new conversation with different context
799 /// client.new_session("session-2", "Tell me about Rust").await?;
800 /// # Ok(())
801 /// # }
802 /// ```
803 pub async fn new_session(
804 &mut self,
805 session_id: impl Into<String>,
806 prompt: impl Into<String>,
807 ) -> Result<()> {
808 self.query_with_session(prompt, session_id).await
809 }
810
811 /// Disconnect from Claude (analogous to Python's __aexit__)
812 ///
813 /// This cleanly shuts down the connection to Claude Code CLI.
814 ///
815 /// # Errors
816 ///
817 /// Returns an error if disconnection fails.
818 pub async fn disconnect(&mut self) -> Result<()> {
819 if !self.connected {
820 return Ok(());
821 }
822
823 if let Some(query) = self.query.take() {
824 // Close stdin first (using direct access) to signal CLI to exit
825 // This will cause the background task to finish and release transport lock
826 let query_guard = query.lock().await;
827 if let Some(ref stdin_arc) = query_guard.stdin {
828 let mut stdin_guard = stdin_arc.lock().await;
829 if let Some(mut stdin_stream) = stdin_guard.take() {
830 let _ = stdin_stream.shutdown().await;
831 }
832 }
833 let transport = Arc::clone(&query_guard.transport);
834 drop(query_guard);
835
836 // Give background task a moment to finish reading and release lock
837 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
838
839 let mut transport_guard = transport.lock().await;
840 transport_guard.close().await?;
841 }
842
843 self.connected = false;
844 Ok(())
845 }
846}
847
848impl Drop for ClaudeClient {
849 fn drop(&mut self) {
850 // Note: We can't run async code in Drop, so we can't guarantee clean shutdown
851 // Users should call disconnect() explicitly
852 if self.connected {
853 eprintln!(
854 "Warning: ClaudeClient dropped without calling disconnect(). Resources may not be cleaned up properly."
855 );
856 }
857 }
858}
859
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::permissions::{CanUseToolCallback, PermissionResult, PermissionResultAllow};
    use std::sync::Arc;

    /// Build a `can_use_tool` callback that allows every tool request.
    fn allow_all_callback() -> CanUseToolCallback {
        Arc::new(|_tool_name, _tool_input, _context| {
            Box::pin(async move { PermissionResult::Allow(PermissionResultAllow::default()) })
        })
    }

    /// Connect `client` and check it did not fail on the `permission_prompt_tool_name` check.
    ///
    /// Without a CLI installed, `connect()` fails later for unrelated reasons, which is
    /// acceptable. If a real CLI is present and `connect()` succeeds, the session is
    /// disconnected so the test does not leave a subprocess behind.
    async fn assert_passes_permission_tool_validation(client: &mut ClaudeClient) {
        match client.connect().await {
            Ok(()) => {
                let _ = client.disconnect().await;
            }
            Err(err) => assert!(
                !err.to_string().contains("permission_prompt_tool_name"),
                "Should not fail on permission_prompt_tool_name validation, got: {}",
                err
            ),
        }
    }

    #[tokio::test]
    async fn test_connect_rejects_can_use_tool_with_custom_permission_tool() {
        let opts = ClaudeAgentOptions::builder()
            .can_use_tool(allow_all_callback())
            .permission_prompt_tool_name("custom_tool") // Not "stdio"
            .build();

        let mut client = ClaudeClient::new(opts);
        let result = client.connect().await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, ClaudeError::InvalidConfig(_)));
        assert!(err.to_string().contains("permission_prompt_tool_name"));
    }

    #[tokio::test]
    async fn test_connect_accepts_can_use_tool_with_stdio() {
        let opts = ClaudeAgentOptions::builder()
            .can_use_tool(allow_all_callback())
            .permission_prompt_tool_name("stdio") // Explicitly "stdio" is OK
            .build();

        let mut client = ClaudeClient::new(opts);
        assert_passes_permission_tool_validation(&mut client).await;
    }

    #[tokio::test]
    async fn test_connect_accepts_can_use_tool_without_permission_tool() {
        let opts = ClaudeAgentOptions::builder()
            .can_use_tool(allow_all_callback())
            // No permission_prompt_tool_name set - defaults to stdio
            .build();

        let mut client = ClaudeClient::new(opts);
        assert_passes_permission_tool_validation(&mut client).await;
    }
}
939}