cc_sdk/client.rs
1//! Interactive client for bidirectional communication with Claude
2//!
3//! This module provides the `ClaudeSDKClient` for interactive, stateful
4//! conversations with Claude Code CLI.
5
6use crate::{
7 errors::{Result, SdkError},
8 internal_query::Query,
9 token_tracker::BudgetManager,
10 transport::{InputMessage, SubprocessTransport, Transport},
11 types::{ClaudeCodeOptions, ContentBlock, ControlRequest, ControlResponse, Message},
12};
13use futures::stream::{Stream, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::pin::Pin;
17use tokio::sync::{Mutex, RwLock, mpsc};
18use tokio_stream::wrappers::ReceiverStream;
19use tracing::{debug, error, info};
20
/// Connection lifecycle of a [`ClaudeSDKClient`]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientState {
    /// No live connection (the initial state, and the state after `disconnect`)
    Disconnected,
    /// Connected to the CLI and able to exchange messages
    Connected,
    /// The background message receiver hit a transport error
    Error,
}
31
32/// Interactive client for bidirectional communication with Claude
33///
34/// `ClaudeSDKClient` provides a stateful, interactive interface for communicating
35/// with Claude Code CLI. Unlike the simple `query` function, this client supports:
36///
37/// - Bidirectional communication
38/// - Multiple sessions
39/// - Interrupt capabilities
40/// - State management
41/// - Follow-up messages based on responses
42///
43/// # Example
44///
45/// ```rust,no_run
46/// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, Message, Result};
47/// use futures::StreamExt;
48///
49/// #[tokio::main]
50/// async fn main() -> Result<()> {
51/// let options = ClaudeCodeOptions::builder()
52/// .system_prompt("You are a helpful assistant")
53/// .model("claude-3-opus-20240229")
54/// .build();
55///
56/// let mut client = ClaudeSDKClient::new(options);
57///
58/// // Connect with initial prompt
59/// client.connect(Some("Hello!".to_string())).await?;
60///
61/// // Receive initial response
62/// let mut messages = client.receive_messages().await;
63/// while let Some(msg) = messages.next().await {
64/// match msg? {
65/// Message::Result { .. } => break,
66/// msg => println!("{:?}", msg),
67/// }
68/// }
69///
70/// // Send follow-up
71/// client.send_request("What's 2 + 2?".to_string(), None).await?;
72///
73/// // Receive response
74/// let mut messages = client.receive_messages().await;
75/// while let Some(msg) = messages.next().await {
76/// println!("{:?}", msg?);
77/// }
78///
79/// // Disconnect
80/// client.disconnect().await?;
81///
82/// Ok(())
83/// }
84/// ```
85pub struct ClaudeSDKClient {
86 /// Configuration options
87 #[allow(dead_code)]
88 options: ClaudeCodeOptions,
89 /// Transport layer
90 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
91 /// Internal query handler (when control protocol is enabled)
92 query_handler: Option<Arc<Mutex<Query>>>,
93 /// Client state
94 state: Arc<RwLock<ClientState>>,
95 /// Active sessions
96 sessions: Arc<RwLock<HashMap<String, SessionData>>>,
97 /// Message sender for current receiver
98 message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
99 /// Message buffer for multiple receivers
100 message_buffer: Arc<Mutex<Vec<Message>>>,
101 /// Request counter
102 request_counter: Arc<Mutex<u64>>,
103 /// Budget manager for token tracking
104 budget_manager: BudgetManager,
105}
106
107/// Session data
108#[allow(dead_code)]
109struct SessionData {
110 /// Session ID
111 id: String,
112 /// Number of messages sent
113 message_count: usize,
114 /// Creation time
115 created_at: std::time::Instant,
116}
117
118impl ClaudeSDKClient {
119 /// Create a new client with the given options
120 pub fn new(options: ClaudeCodeOptions) -> Self {
121 let transport = match SubprocessTransport::new(options.clone()) {
122 Ok(t) => t,
123 Err(e) => {
124 error!("Failed to create transport: {}", e);
125 // Create with empty path, will fail on connect
126 SubprocessTransport::with_cli_path(options.clone(), "")
127 }
128 };
129
130 // Wrap transport in Arc for sharing
131 let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
132 Arc::new(Mutex::new(Box::new(transport)));
133
134 Self::with_transport_internal(options, transport_arc)
135 }
136
137 /// Create a new client with a custom transport implementation
138 ///
139 /// This allows users to provide their own Transport implementation instead of
140 /// using the default SubprocessTransport. Useful for testing, custom CLI paths,
141 /// or alternative communication mechanisms.
142 ///
143 /// # Arguments
144 ///
145 /// * `options` - Configuration options for the client
146 /// * `transport` - Custom transport implementation
147 ///
148 /// # Example
149 ///
150 /// ```rust,no_run
151 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, SubprocessTransport};
152 /// # fn example() {
153 /// let options = ClaudeCodeOptions::default();
154 /// let transport = SubprocessTransport::with_cli_path(options.clone(), "/custom/path/claude-code");
155 /// let client = ClaudeSDKClient::with_transport(options, Box::new(transport));
156 /// # }
157 /// ```
158 pub fn with_transport(options: ClaudeCodeOptions, transport: Box<dyn Transport + Send>) -> Self {
159 // Wrap transport in Arc for sharing
160 let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
161 Arc::new(Mutex::new(transport));
162
163 Self::with_transport_internal(options, transport_arc)
164 }
165
166 /// Internal helper to construct client with pre-wrapped transport
167 fn with_transport_internal(
168 mut options: ClaudeCodeOptions,
169 transport_arc: Arc<Mutex<Box<dyn Transport + Send>>>,
170 ) -> Self {
171 // Auto-configure permission_prompt_tool_name when can_use_tool is set
172 // (matching Python SDK behavior: route permission requests via stdio control protocol)
173 if options.can_use_tool.is_some() && options.permission_prompt_tool_name.is_none() {
174 options.permission_prompt_tool_name = Some("stdio".to_string());
175 }
176
177 // Create query handler if control protocol features are enabled
178 let query_handler = if options.can_use_tool.is_some()
179 || options.hooks.is_some()
180 || !options.mcp_servers.is_empty()
181 || options.enable_file_checkpointing {
182 // Extract SDK MCP server instances
183 let sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>> = options.mcp_servers
184 .iter()
185 .filter_map(|(k, v)| {
186 // Only extract SDK type MCP servers
187 if let crate::types::McpServerConfig::Sdk { name: _, instance } = v {
188 Some((k.clone(), instance.clone()))
189 } else {
190 None
191 }
192 })
193 .collect();
194
195 // Enable streaming mode when control protocol is active
196 let is_streaming = options.can_use_tool.is_some()
197 || options.hooks.is_some()
198 || !sdk_mcp_servers.is_empty();
199
200 let query = Query::new(
201 transport_arc.clone(), // Share the same transport
202 is_streaming, // Enable streaming for control protocol
203 options.can_use_tool.clone(),
204 options.hooks.clone(),
205 sdk_mcp_servers,
206 );
207 Some(Arc::new(Mutex::new(query)))
208 } else {
209 None
210 };
211
212 Self {
213 options,
214 transport: transport_arc,
215 query_handler,
216 state: Arc::new(RwLock::new(ClientState::Disconnected)),
217 sessions: Arc::new(RwLock::new(HashMap::new())),
218 message_tx: Arc::new(Mutex::new(None)),
219 message_buffer: Arc::new(Mutex::new(Vec::new())),
220 request_counter: Arc::new(Mutex::new(0)),
221 budget_manager: BudgetManager::new(),
222 }
223 }
224
225 /// Connect to Claude CLI with an optional initial prompt
226 pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
227 // Check if already connected
228 {
229 let state = self.state.read().await;
230 if *state == ClientState::Connected {
231 return Ok(());
232 }
233 }
234
235 // Connect transport
236 {
237 let mut transport = self.transport.lock().await;
238 transport.connect().await?;
239 }
240
241 // Initialize query handler if present
242 if let Some(ref query_handler) = self.query_handler {
243 let mut handler = query_handler.lock().await;
244 handler.start().await?;
245 handler.initialize().await?;
246 info!("Initialized SDK control protocol");
247 }
248
249 // Update state
250 {
251 let mut state = self.state.write().await;
252 *state = ClientState::Connected;
253 }
254
255 info!("Connected to Claude CLI");
256
257 // Start message receiver task (always needed for regular messages)
258 self.start_message_receiver().await;
259
260 // Send initial prompt if provided
261 if let Some(prompt) = initial_prompt {
262 self.send_request(prompt, None).await?;
263 }
264
265 Ok(())
266 }
267
268 /// Send a user message to Claude
269 ///
270 /// Uses the session ID from `ClaudeCodeOptions::session_id` if set,
271 /// otherwise falls back to "default".
272 pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
273 let session_id = self
274 .options
275 .session_id
276 .clone()
277 .unwrap_or_else(|| "default".to_string());
278 self.send_user_message_with_session(prompt, session_id).await
279 }
280
281 async fn send_user_message_with_session(
282 &mut self,
283 prompt: String,
284 session_id: String,
285 ) -> Result<()> {
286 // Check connection
287 {
288 let state = self.state.read().await;
289 if *state != ClientState::Connected {
290 return Err(SdkError::InvalidState {
291 message: "Not connected".into(),
292 });
293 }
294 }
295
296 // Update session data
297 {
298 let mut sessions = self.sessions.write().await;
299 let session = sessions.entry(session_id.clone()).or_insert_with(|| {
300 debug!("Creating new session: {}", session_id);
301 SessionData {
302 id: session_id.clone(),
303 message_count: 0,
304 created_at: std::time::Instant::now(),
305 }
306 });
307 session.message_count += 1;
308 }
309
310 // Create and send message
311 let message = InputMessage::user(prompt, session_id.clone());
312
313 {
314 let mut transport = self.transport.lock().await;
315 transport.send_message(message).await?;
316 }
317
318 debug!("Sent request to Claude");
319 Ok(())
320 }
321
322 /// Send a request to Claude with an optional explicit session ID
323 ///
324 /// Priority: explicit `session_id` argument > `ClaudeCodeOptions::session_id` > "default"
325 pub async fn send_request(
326 &mut self,
327 prompt: String,
328 session_id: Option<String>,
329 ) -> Result<()> {
330 let effective_session_id = session_id
331 .or_else(|| self.options.session_id.clone())
332 .unwrap_or_else(|| "default".to_string());
333 self.send_user_message_with_session(prompt, effective_session_id).await
334 }
335
336 /// Receive messages from Claude
337 ///
338 /// Returns a stream of messages. The stream will end when a Result message
339 /// is received or the connection is closed.
340 pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
341 // Always use the regular message receiver
342 // (Query handler shares the same transport and receives control messages separately)
343 // Create a new channel for this receiver
344 let (tx, rx) = mpsc::channel(100);
345
346 // Get buffered messages and clear buffer
347 let buffered_messages = {
348 let mut buffer = self.message_buffer.lock().await;
349 std::mem::take(&mut *buffer)
350 };
351
352 // Send buffered messages to the new receiver
353 let tx_clone = tx.clone();
354 tokio::spawn(async move {
355 for msg in buffered_messages {
356 if tx_clone.send(Ok(msg)).await.is_err() {
357 break;
358 }
359 }
360 });
361
362 // Store the sender for the message receiver task
363 {
364 let mut message_tx = self.message_tx.lock().await;
365 *message_tx = Some(tx);
366 }
367
368 ReceiverStream::new(rx)
369 }
370
371 /// Send an interrupt request
372 pub async fn interrupt(&mut self) -> Result<()> {
373 // Check connection
374 {
375 let state = self.state.read().await;
376 if *state != ClientState::Connected {
377 return Err(SdkError::InvalidState {
378 message: "Not connected".into(),
379 });
380 }
381 }
382
383 // If we have a query handler, use it
384 if let Some(ref query_handler) = self.query_handler {
385 let mut handler = query_handler.lock().await;
386 return handler.interrupt().await;
387 }
388
389 // Otherwise use regular interrupt
390 // Generate request ID
391 let request_id = {
392 let mut counter = self.request_counter.lock().await;
393 *counter += 1;
394 format!("interrupt_{}", *counter)
395 };
396
397 // Send interrupt request
398 let request = ControlRequest::Interrupt {
399 request_id: request_id.clone(),
400 };
401
402 {
403 let mut transport = self.transport.lock().await;
404 transport.send_control_request(request).await?;
405 }
406
407 info!("Sent interrupt request: {}", request_id);
408
409 // Wait for acknowledgment (with timeout)
410 let transport = self.transport.clone();
411 let ack_task = tokio::spawn(async move {
412 let mut transport = transport.lock().await;
413 match tokio::time::timeout(
414 std::time::Duration::from_secs(5),
415 transport.receive_control_response(),
416 )
417 .await
418 {
419 Ok(Ok(Some(ControlResponse::InterruptAck {
420 request_id: ack_id,
421 success,
422 }))) => {
423 if ack_id == request_id && success {
424 Ok(())
425 } else {
426 Err(SdkError::ControlRequestError(
427 "Interrupt not acknowledged successfully".into(),
428 ))
429 }
430 }
431 Ok(Ok(None)) => Err(SdkError::ControlRequestError(
432 "No interrupt acknowledgment received".into(),
433 )),
434 Ok(Err(e)) => Err(e),
435 Err(_) => Err(SdkError::timeout(5)),
436 }
437 });
438
439 ack_task
440 .await
441 .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
442 }
443
444 /// Check if the client is connected
445 pub async fn is_connected(&self) -> bool {
446 let state = self.state.read().await;
447 *state == ClientState::Connected
448 }
449
450 /// Get active session IDs
451 pub async fn get_sessions(&self) -> Vec<String> {
452 let sessions = self.sessions.read().await;
453 sessions.keys().cloned().collect()
454 }
455
456 /// Receive messages until and including a ResultMessage
457 ///
458 /// This is a convenience method that collects all messages from a single response.
459 /// It will automatically stop after receiving a ResultMessage.
460 pub async fn receive_response(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
461 let mut messages = self.receive_messages().await;
462
463 // Create a stream that stops after ResultMessage
464 Box::pin(async_stream::stream! {
465 while let Some(msg_result) = messages.next().await {
466 match &msg_result {
467 Ok(Message::Result { .. }) => {
468 yield msg_result;
469 return;
470 }
471 _ => {
472 yield msg_result;
473 }
474 }
475 }
476 })
477 }
478
479 /// Get server information
480 ///
481 /// Returns initialization information from the Claude Code server including:
482 /// - Available commands
483 /// - Current and available output styles
484 /// - Server capabilities
485 pub async fn get_server_info(&self) -> Option<serde_json::Value> {
486 // If we have a query handler with control protocol, get from there
487 if let Some(ref query_handler) = self.query_handler {
488 let handler = query_handler.lock().await;
489 if let Some(init_result) = handler.get_initialization_result() {
490 return Some(init_result.clone());
491 }
492 }
493
494 // Otherwise check message buffer for init message
495 let buffer = self.message_buffer.lock().await;
496 for msg in buffer.iter() {
497 if let Message::System { subtype, data } = msg
498 && subtype == "init" {
499 return Some(data.clone());
500 }
501 }
502 None
503 }
504
505 /// Get account information
506 ///
507 /// This method attempts to retrieve Claude account information through multiple methods:
508 /// 1. From environment variable `ANTHROPIC_USER_EMAIL`
509 /// 2. From Claude CLI config file (if accessible)
510 /// 3. By querying the CLI with `/status` command (interactive mode)
511 ///
512 /// # Returns
513 ///
514 /// A string containing the account information, or an error if unavailable.
515 ///
516 /// # Example
517 ///
518 /// ```rust,no_run
519 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
520 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
521 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
522 /// client.connect(None).await?;
523 ///
524 /// match client.get_account_info().await {
525 /// Ok(info) => println!("Account: {}", info),
526 /// Err(_) => println!("Account info not available"),
527 /// }
528 /// # Ok(())
529 /// # }
530 /// ```
531 ///
532 /// # Note
533 ///
534 /// Account information may not always be available in SDK mode.
535 /// Consider setting the `ANTHROPIC_USER_EMAIL` environment variable
536 /// for reliable account identification.
537 pub async fn get_account_info(&mut self) -> Result<String> {
538 // Check connection
539 {
540 let state = self.state.read().await;
541 if *state != ClientState::Connected {
542 return Err(SdkError::InvalidState {
543 message: "Not connected. Call connect() first.".into(),
544 });
545 }
546 }
547
548 // Method 1: Check environment variable
549 if let Ok(email) = std::env::var("ANTHROPIC_USER_EMAIL") {
550 return Ok(format!("Email: {}", email));
551 }
552
553 // Method 2: Try reading from Claude config
554 if let Some(config_info) = Self::read_claude_config().await {
555 return Ok(config_info);
556 }
557
558 // Method 3: Try /status command (may not work in SDK mode)
559 self.send_user_message("/status".to_string()).await?;
560
561 let mut messages = self.receive_messages().await;
562 let mut account_info = String::new();
563
564 while let Some(msg_result) = messages.next().await {
565 match msg_result? {
566 Message::Assistant { message } => {
567 for block in message.content {
568 if let ContentBlock::Text(text) = block {
569 account_info.push_str(&text.text);
570 account_info.push('\n');
571 }
572 }
573 }
574 Message::Result { .. } => break,
575 _ => {}
576 }
577 }
578
579 let trimmed = account_info.trim();
580
581 // Check if we got actual status info or just a chat response
582 if !trimmed.is_empty() && (
583 trimmed.contains("account") ||
584 trimmed.contains("email") ||
585 trimmed.contains("subscription") ||
586 trimmed.contains("authenticated")
587 ) {
588 return Ok(trimmed.to_string());
589 }
590
591 Err(SdkError::InvalidState {
592 message: "Account information not available. Try setting ANTHROPIC_USER_EMAIL environment variable.".into(),
593 })
594 }
595
596 /// Read Claude config file
597 async fn read_claude_config() -> Option<String> {
598 // Try common config locations
599 let config_paths = vec![
600 dirs::home_dir()?.join(".config").join("claude").join("config.json"),
601 dirs::home_dir()?.join(".claude").join("config.json"),
602 ];
603
604 for path in config_paths {
605 if let Ok(content) = tokio::fs::read_to_string(&path).await {
606 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
607 if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
608 return Some(format!("Email: {}", email));
609 }
610 if let Some(user) = json.get("user").and_then(|v| v.as_str()) {
611 return Some(format!("User: {}", user));
612 }
613 }
614 }
615 }
616
617 None
618 }
619
620 /// Set permission mode dynamically
621 ///
622 /// Changes the permission mode during an active session.
623 /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
624 ///
625 /// # Arguments
626 ///
627 /// * `mode` - Permission mode: "default", "acceptEdits", "plan", or "bypassPermissions"
628 ///
629 /// # Example
630 ///
631 /// ```rust,no_run
632 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
633 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
634 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
635 /// client.connect(None).await?;
636 ///
637 /// // Switch to accept edits mode
638 /// client.set_permission_mode("acceptEdits").await?;
639 /// # Ok(())
640 /// # }
641 /// ```
642 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
643 if let Some(ref query_handler) = self.query_handler {
644 let mut handler = query_handler.lock().await;
645 handler.set_permission_mode(mode).await
646 } else {
647 Err(SdkError::InvalidState {
648 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
649 })
650 }
651 }
652
653 /// Set model dynamically
654 ///
655 /// Changes the active model during an active session.
656 /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
657 ///
658 /// # Arguments
659 ///
660 /// * `model` - Model identifier (e.g., "claude-3-5-sonnet-20241022") or None to use default
661 ///
662 /// # Example
663 ///
664 /// ```rust,no_run
665 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
666 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
667 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
668 /// client.connect(None).await?;
669 ///
670 /// // Switch to a different model
671 /// client.set_model(Some("claude-3-5-sonnet-20241022".to_string())).await?;
672 /// # Ok(())
673 /// # }
674 /// ```
675 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
676 if let Some(ref query_handler) = self.query_handler {
677 let mut handler = query_handler.lock().await;
678 handler.set_model(model).await
679 } else {
680 Err(SdkError::InvalidState {
681 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
682 })
683 }
684 }
685
686 /// Send a query with optional session ID
687 ///
688 /// This method is similar to Python SDK's query method in ClaudeSDKClient
689 pub async fn query(&mut self, prompt: String, session_id: Option<String>) -> Result<()> {
690 let session_id = session_id.unwrap_or_else(|| "default".to_string());
691
692 // Send the message
693 let message = InputMessage::user(prompt, session_id);
694
695 {
696 let mut transport = self.transport.lock().await;
697 transport.send_message(message).await?;
698 }
699
700 Ok(())
701 }
702
703 /// Rewind tracked files to their state at a specific user message
704 ///
705 /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
706 /// This method allows you to undo file changes made during the session by
707 /// reverting them to their state at any previous user message checkpoint.
708 ///
709 /// # Arguments
710 ///
711 /// * `user_message_id` - UUID of the user message to rewind to. This should be
712 /// the `uuid` field from a message received during the conversation.
713 ///
714 /// # Example
715 ///
716 /// ```rust,no_run
717 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
718 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
719 /// let options = ClaudeCodeOptions::builder()
720 /// .enable_file_checkpointing(true)
721 /// .build();
722 /// let mut client = ClaudeSDKClient::new(options);
723 /// client.connect(None).await?;
724 ///
725 /// // Ask Claude to make some changes
726 /// client.send_request("Make some changes to my files".to_string(), None).await?;
727 ///
728 /// // ... later, rewind to a checkpoint
729 /// // client.rewind_files("user-message-uuid-here").await?;
730 /// # Ok(())
731 /// # }
732 /// ```
733 ///
734 /// # Errors
735 ///
736 /// Returns an error if:
737 /// - The client is not connected
738 /// - The query handler is not initialized (control protocol required)
739 /// - File checkpointing is not enabled
740 /// - The specified user_message_id is invalid
741 pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
742 // Check connection
743 {
744 let state = self.state.read().await;
745 if *state != ClientState::Connected {
746 return Err(SdkError::InvalidState {
747 message: "Not connected. Call connect() first.".into(),
748 });
749 }
750 }
751
752 if !self.options.enable_file_checkpointing {
753 return Err(SdkError::InvalidState {
754 message: "File checkpointing is not enabled. Set ClaudeCodeOptions::builder().enable_file_checkpointing(true).".to_string(),
755 });
756 }
757
758 // Require query handler for control protocol
759 if let Some(ref query_handler) = self.query_handler {
760 let mut handler = query_handler.lock().await;
761 handler.rewind_files(user_message_id).await
762 } else {
763 Err(SdkError::InvalidState {
764 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
765 })
766 }
767 }
768
769 /// Get context usage information including token distribution and cache stats
770 ///
771 /// Returns detailed context window usage including categories, API usage with
772 /// cache token information, and auto-compact settings.
773 ///
774 /// Requires control protocol to be enabled.
775 pub async fn get_context_usage(&mut self) -> Result<crate::types::ContextUsageResponse> {
776 {
777 let state = self.state.read().await;
778 if *state != ClientState::Connected {
779 return Err(SdkError::InvalidState {
780 message: "Not connected. Call connect() first.".into(),
781 });
782 }
783 }
784
785 if let Some(ref query_handler) = self.query_handler {
786 let mut handler = query_handler.lock().await;
787 let raw = handler.get_context_usage().await?;
788 serde_json::from_value(raw).map_err(|e| SdkError::ControlRequestError(
789 format!("Failed to parse context usage response: {e}")
790 ))
791 } else {
792 Err(SdkError::InvalidState {
793 message: "Query handler not initialized. Enable control protocol features.".to_string(),
794 })
795 }
796 }
797
798 /// Stop a background task by ID
799 ///
800 /// Requires control protocol to be enabled.
801 pub async fn stop_task(&mut self, task_id: &str) -> Result<()> {
802 {
803 let state = self.state.read().await;
804 if *state != ClientState::Connected {
805 return Err(SdkError::InvalidState {
806 message: "Not connected. Call connect() first.".into(),
807 });
808 }
809 }
810
811 if let Some(ref query_handler) = self.query_handler {
812 let mut handler = query_handler.lock().await;
813 handler.stop_task(task_id).await
814 } else {
815 Err(SdkError::InvalidState {
816 message: "Query handler not initialized. Enable control protocol features.".to_string(),
817 })
818 }
819 }
820
821 /// Get MCP server status information
822 ///
823 /// Returns the current status of all connected MCP servers.
824 /// Requires control protocol to be enabled.
825 pub async fn get_mcp_status(&mut self) -> Result<serde_json::Value> {
826 {
827 let state = self.state.read().await;
828 if *state != ClientState::Connected {
829 return Err(SdkError::InvalidState {
830 message: "Not connected. Call connect() first.".into(),
831 });
832 }
833 }
834
835 if let Some(ref query_handler) = self.query_handler {
836 let mut handler = query_handler.lock().await;
837 handler.get_mcp_status().await
838 } else {
839 Err(SdkError::InvalidState {
840 message: "Query handler not initialized. Enable control protocol features.".to_string(),
841 })
842 }
843 }
844
845 /// Reconnect a failed MCP server
846 ///
847 /// Attempts to re-establish connection to the specified MCP server.
848 /// Requires control protocol to be enabled.
849 pub async fn reconnect_mcp_server(&mut self, server_name: &str) -> Result<()> {
850 {
851 let state = self.state.read().await;
852 if *state != ClientState::Connected {
853 return Err(SdkError::InvalidState {
854 message: "Not connected. Call connect() first.".into(),
855 });
856 }
857 }
858
859 if let Some(ref query_handler) = self.query_handler {
860 let mut handler = query_handler.lock().await;
861 handler.reconnect_mcp_server(server_name).await
862 } else {
863 Err(SdkError::InvalidState {
864 message: "Query handler not initialized. Enable control protocol features.".to_string(),
865 })
866 }
867 }
868
869 /// Toggle an MCP server on/off
870 ///
871 /// Enables or disables the specified MCP server.
872 /// Requires control protocol to be enabled.
873 pub async fn toggle_mcp_server(&mut self, server_name: &str, enabled: bool) -> Result<()> {
874 {
875 let state = self.state.read().await;
876 if *state != ClientState::Connected {
877 return Err(SdkError::InvalidState {
878 message: "Not connected. Call connect() first.".into(),
879 });
880 }
881 }
882
883 if let Some(ref query_handler) = self.query_handler {
884 let mut handler = query_handler.lock().await;
885 handler.toggle_mcp_server(server_name, enabled).await
886 } else {
887 Err(SdkError::InvalidState {
888 message: "Query handler not initialized. Enable control protocol features.".to_string(),
889 })
890 }
891 }
892
893 /// Disconnect from Claude CLI
894 pub async fn disconnect(&mut self) -> Result<()> {
895 // Check if already disconnected
896 {
897 let state = self.state.read().await;
898 if *state == ClientState::Disconnected {
899 return Ok(());
900 }
901 }
902
903 // Disconnect transport
904 {
905 let mut transport = self.transport.lock().await;
906 transport.disconnect().await?;
907 }
908
909 // Update state
910 {
911 let mut state = self.state.write().await;
912 *state = ClientState::Disconnected;
913 }
914
915 // Clear sessions
916 {
917 let mut sessions = self.sessions.write().await;
918 sessions.clear();
919 }
920
921 info!("Disconnected from Claude CLI");
922 Ok(())
923 }
924
925 /// Start the message receiver task
926 async fn start_message_receiver(&mut self) {
927 let transport = self.transport.clone();
928 let message_tx = self.message_tx.clone();
929 let message_buffer = self.message_buffer.clone();
930 let state = self.state.clone();
931 let budget_manager = self.budget_manager.clone();
932
933 tokio::spawn(async move {
934 // Subscribe to messages without holding the lock
935 let mut stream = {
936 let mut transport = transport.lock().await;
937 transport.receive_messages()
938 }; // Lock is released here immediately
939
940 while let Some(result) = stream.next().await {
941 match result {
942 Ok(message) => {
943 // Update token usage for Result messages
944 if let Message::Result { .. } = &message
945 && let Message::Result { usage, total_cost_usd, .. } = &message {
946 let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
947 let input = usage_json.get("input_tokens")
948 .and_then(|v| v.as_u64())
949 .unwrap_or(0);
950 let output = usage_json.get("output_tokens")
951 .and_then(|v| v.as_u64())
952 .unwrap_or(0);
953 (input, output)
954 } else {
955 (0, 0)
956 };
957 let cost = total_cost_usd.unwrap_or(0.0);
958 budget_manager.update_usage(input_tokens, output_tokens, cost).await;
959 }
960
961 // Buffer init messages for get_server_info()
962 if let Message::System { subtype, .. } = &message
963 && subtype == "init" {
964 let mut buffer = message_buffer.lock().await;
965 buffer.push(message.clone());
966 }
967
968 // Try to send to current receiver
969 let sent = {
970 let mut tx_opt = message_tx.lock().await;
971 if let Some(tx) = tx_opt.as_mut() {
972 tx.send(Ok(message.clone())).await.is_ok()
973 } else {
974 false
975 }
976 };
977
978 // If no receiver or send failed, buffer the message
979 if !sent {
980 let mut buffer = message_buffer.lock().await;
981 buffer.push(message);
982 }
983 }
984 Err(e) => {
985 error!("Error receiving message: {}", e);
986
987 // Send error to receiver if available
988 let mut tx_opt = message_tx.lock().await;
989 if let Some(tx) = tx_opt.as_mut() {
990 let _ = tx.send(Err(e)).await;
991 }
992
993 // Update state on error
994 let mut state = state.write().await;
995 *state = ClientState::Error;
996 break;
997 }
998 }
999 }
1000
1001 debug!("Message receiver task ended");
1002 });
1003 }
1004
1005 /// Get token usage statistics
1006 ///
1007 /// Returns the current token usage tracker with cumulative statistics
1008 /// for all queries executed by this client.
1009 pub async fn get_usage_stats(&self) -> crate::token_tracker::TokenUsageTracker {
1010 self.budget_manager.get_usage().await
1011 }
1012
1013 /// Set budget limit with optional warning callback
1014 ///
1015 /// # Arguments
1016 ///
1017 /// * `limit` - Budget limit configuration (cost and/or token caps)
1018 /// * `on_warning` - Optional callback function triggered when usage exceeds warning threshold
1019 ///
1020 /// # Example
1021 ///
1022 /// ```rust,no_run
1023 /// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
1024 /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
1025 /// use std::sync::Arc;
1026 ///
1027 /// # async fn example() {
1028 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
1029 ///
1030 /// // Set budget with callback
1031 /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Budget warning: {}", msg));
1032 /// client.set_budget_limit(BudgetLimit::with_cost(5.0), Some(cb)).await;
1033 /// # }
1034 /// ```
1035 pub async fn set_budget_limit(
1036 &self,
1037 limit: crate::token_tracker::BudgetLimit,
1038 on_warning: Option<crate::token_tracker::BudgetWarningCallback>,
1039 ) {
1040 self.budget_manager.set_limit(limit).await;
1041 if let Some(callback) = on_warning {
1042 self.budget_manager.set_warning_callback(callback).await;
1043 }
1044 }
1045
1046 /// Clear budget limit and reset warning state
1047 pub async fn clear_budget_limit(&self) {
1048 self.budget_manager.clear_limit().await;
1049 }
1050
1051 /// Reset token usage statistics to zero
1052 ///
1053 /// Clears all accumulated token and cost statistics.
1054 /// Budget limits remain in effect.
1055 pub async fn reset_usage_stats(&self) {
1056 self.budget_manager.reset_usage().await;
1057 }
1058
1059 /// Check if budget has been exceeded
1060 ///
1061 /// Returns true if current usage exceeds any configured limits
1062 pub async fn is_budget_exceeded(&self) -> bool {
1063 self.budget_manager.is_exceeded().await
1064 }
1065
1066 // Removed unused helper; usage is updated inline in message receiver
1067}
1068
1069impl Drop for ClaudeSDKClient {
1070 fn drop(&mut self) {
1071 // Try to disconnect gracefully
1072 let transport = self.transport.clone();
1073 let state = self.state.clone();
1074
1075 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1076 handle.spawn(async move {
1077 let state = state.read().await;
1078 if *state == ClientState::Connected {
1079 let mut transport = transport.lock().await;
1080 if let Err(e) = transport.disconnect().await {
1081 debug!("Error disconnecting in drop: {}", e);
1082 }
1083 }
1084 });
1085 }
1086 }
1087}
1088
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client is idle: not connected and without any session.
    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        assert!(!client.is_connected().await);
        assert!(client.get_sessions().await.is_empty());
    }

    /// Construction leaves the client in the `Disconnected` state.
    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let current = *client.state.read().await;
        assert_eq!(current, ClientState::Disconnected);
    }

    /// File checkpointing alone must still create the control-protocol query handler.
    #[test]
    fn test_file_checkpointing_enables_query_handler() {
        let client = ClaudeSDKClient::new(
            ClaudeCodeOptions::builder()
                .enable_file_checkpointing(true)
                .build(),
        );

        assert!(
            client.query_handler.is_some(),
            "enable_file_checkpointing should initialize the query handler for control protocol requests"
        );
    }

    /// No session ID is configured unless one is set explicitly.
    #[test]
    fn test_session_id_option_defaults_to_none() {
        assert!(ClaudeCodeOptions::default().session_id.is_none());
    }

    /// The builder stores the provided session ID verbatim.
    #[test]
    fn test_session_id_builder() {
        let session_id = ClaudeCodeOptions::builder()
            .session_id("my-session-123")
            .build()
            .session_id;

        assert_eq!(session_id.as_deref(), Some("my-session-123"));
    }
}