cc_sdk/client.rs
1//! Interactive client for bidirectional communication with Claude
2//!
3//! This module provides the `ClaudeSDKClient` for interactive, stateful
4//! conversations with Claude Code CLI.
5
6use crate::{
7 errors::{Result, SdkError},
8 internal_query::Query,
9 token_tracker::BudgetManager,
10 transport::{InputMessage, SubprocessTransport, Transport},
11 types::{ClaudeCodeOptions, ContentBlock, ControlRequest, ControlResponse, Message},
12};
13use futures::stream::{Stream, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::pin::Pin;
17use tokio::sync::{Mutex, RwLock, mpsc};
18use tokio_stream::wrappers::ReceiverStream;
19use tracing::{debug, error, info};
20
/// Connection state of a [`ClaudeSDKClient`].
///
/// A freshly constructed client is [`Disconnected`](ClientState::Disconnected),
/// which is also this type's [`Default`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ClientState {
    /// No live connection (initial state, and the state after a disconnect)
    #[default]
    Disconnected,
    /// Connected and ready to send requests
    Connected,
    /// The connection hit an error while receiving messages
    Error,
}
31
32/// Interactive client for bidirectional communication with Claude
33///
34/// `ClaudeSDKClient` provides a stateful, interactive interface for communicating
35/// with Claude Code CLI. Unlike the simple `query` function, this client supports:
36///
37/// - Bidirectional communication
38/// - Multiple sessions
39/// - Interrupt capabilities
40/// - State management
41/// - Follow-up messages based on responses
42///
43/// # Example
44///
45/// ```rust,no_run
46/// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, Message, Result};
47/// use futures::StreamExt;
48///
49/// #[tokio::main]
50/// async fn main() -> Result<()> {
51/// let options = ClaudeCodeOptions::builder()
52/// .system_prompt("You are a helpful assistant")
53/// .model("claude-3-opus-20240229")
54/// .build();
55///
56/// let mut client = ClaudeSDKClient::new(options);
57///
58/// // Connect with initial prompt
59/// client.connect(Some("Hello!".to_string())).await?;
60///
61/// // Receive initial response
62/// let mut messages = client.receive_messages().await;
63/// while let Some(msg) = messages.next().await {
64/// match msg? {
65/// Message::Result { .. } => break,
66/// msg => println!("{:?}", msg),
67/// }
68/// }
69///
70/// // Send follow-up
71/// client.send_request("What's 2 + 2?".to_string(), None).await?;
72///
73/// // Receive response
74/// let mut messages = client.receive_messages().await;
75/// while let Some(msg) = messages.next().await {
76/// println!("{:?}", msg?);
77/// }
78///
79/// // Disconnect
80/// client.disconnect().await?;
81///
82/// Ok(())
83/// }
84/// ```
85pub struct ClaudeSDKClient {
86 /// Configuration options
87 #[allow(dead_code)]
88 options: ClaudeCodeOptions,
89 /// Transport layer
90 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
91 /// Internal query handler (when control protocol is enabled)
92 query_handler: Option<Arc<Mutex<Query>>>,
93 /// Client state
94 state: Arc<RwLock<ClientState>>,
95 /// Active sessions
96 sessions: Arc<RwLock<HashMap<String, SessionData>>>,
97 /// Message sender for current receiver
98 message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
99 /// Message buffer for multiple receivers
100 message_buffer: Arc<Mutex<Vec<Message>>>,
101 /// Request counter
102 request_counter: Arc<Mutex<u64>>,
103 /// Budget manager for token tracking
104 budget_manager: BudgetManager,
105}
106
107/// Session data
108#[allow(dead_code)]
109struct SessionData {
110 /// Session ID
111 id: String,
112 /// Number of messages sent
113 message_count: usize,
114 /// Creation time
115 created_at: std::time::Instant,
116}
117
118impl ClaudeSDKClient {
119 /// Create a new client with the given options
120 pub fn new(options: ClaudeCodeOptions) -> Self {
121 let transport = match SubprocessTransport::new(options.clone()) {
122 Ok(t) => t,
123 Err(e) => {
124 error!("Failed to create transport: {}", e);
125 // Create with empty path, will fail on connect
126 SubprocessTransport::with_cli_path(options.clone(), "")
127 }
128 };
129
130 // Wrap transport in Arc for sharing
131 let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
132 Arc::new(Mutex::new(Box::new(transport)));
133
134 Self::with_transport_internal(options, transport_arc)
135 }
136
137 /// Create a new client with a custom transport implementation
138 ///
139 /// This allows users to provide their own Transport implementation instead of
140 /// using the default SubprocessTransport. Useful for testing, custom CLI paths,
141 /// or alternative communication mechanisms.
142 ///
143 /// # Arguments
144 ///
145 /// * `options` - Configuration options for the client
146 /// * `transport` - Custom transport implementation
147 ///
148 /// # Example
149 ///
150 /// ```rust,no_run
151 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, SubprocessTransport};
152 /// # fn example() {
153 /// let options = ClaudeCodeOptions::default();
154 /// let transport = SubprocessTransport::with_cli_path(options.clone(), "/custom/path/claude-code");
155 /// let client = ClaudeSDKClient::with_transport(options, Box::new(transport));
156 /// # }
157 /// ```
158 pub fn with_transport(options: ClaudeCodeOptions, transport: Box<dyn Transport + Send>) -> Self {
159 // Wrap transport in Arc for sharing
160 let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
161 Arc::new(Mutex::new(transport));
162
163 Self::with_transport_internal(options, transport_arc)
164 }
165
166 /// Internal helper to construct client with pre-wrapped transport
167 fn with_transport_internal(
168 mut options: ClaudeCodeOptions,
169 transport_arc: Arc<Mutex<Box<dyn Transport + Send>>>,
170 ) -> Self {
171 // Auto-configure permission_prompt_tool_name when can_use_tool is set
172 // (matching Python SDK behavior: route permission requests via stdio control protocol)
173 if options.can_use_tool.is_some() && options.permission_prompt_tool_name.is_none() {
174 options.permission_prompt_tool_name = Some("stdio".to_string());
175 }
176
177 // Create query handler if control protocol features are enabled
178 let query_handler = if options.can_use_tool.is_some()
179 || options.hooks.is_some()
180 || !options.mcp_servers.is_empty()
181 || options.enable_file_checkpointing {
182 // Extract SDK MCP server instances
183 let sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>> = options.mcp_servers
184 .iter()
185 .filter_map(|(k, v)| {
186 // Only extract SDK type MCP servers
187 if let crate::types::McpServerConfig::Sdk { name: _, instance } = v {
188 Some((k.clone(), instance.clone()))
189 } else {
190 None
191 }
192 })
193 .collect();
194
195 // Enable streaming mode when control protocol is active
196 let is_streaming = options.can_use_tool.is_some()
197 || options.hooks.is_some()
198 || !sdk_mcp_servers.is_empty();
199
200 let query = Query::new(
201 transport_arc.clone(), // Share the same transport
202 is_streaming, // Enable streaming for control protocol
203 options.can_use_tool.clone(),
204 options.hooks.clone(),
205 sdk_mcp_servers,
206 );
207 Some(Arc::new(Mutex::new(query)))
208 } else {
209 None
210 };
211
212 Self {
213 options,
214 transport: transport_arc,
215 query_handler,
216 state: Arc::new(RwLock::new(ClientState::Disconnected)),
217 sessions: Arc::new(RwLock::new(HashMap::new())),
218 message_tx: Arc::new(Mutex::new(None)),
219 message_buffer: Arc::new(Mutex::new(Vec::new())),
220 request_counter: Arc::new(Mutex::new(0)),
221 budget_manager: BudgetManager::new(),
222 }
223 }
224
225 /// Connect to Claude CLI with an optional initial prompt
226 pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
227 // Check if already connected
228 {
229 let state = self.state.read().await;
230 if *state == ClientState::Connected {
231 return Ok(());
232 }
233 }
234
235 // Connect transport
236 {
237 let mut transport = self.transport.lock().await;
238 transport.connect().await?;
239 }
240
241 // Initialize query handler if present
242 if let Some(ref query_handler) = self.query_handler {
243 let mut handler = query_handler.lock().await;
244 handler.start().await?;
245 handler.initialize().await?;
246 info!("Initialized SDK control protocol");
247 }
248
249 // Update state
250 {
251 let mut state = self.state.write().await;
252 *state = ClientState::Connected;
253 }
254
255 info!("Connected to Claude CLI");
256
257 // Start message receiver task (always needed for regular messages)
258 self.start_message_receiver().await;
259
260 // Send initial prompt if provided
261 if let Some(prompt) = initial_prompt {
262 self.send_request(prompt, None).await?;
263 }
264
265 Ok(())
266 }
267
268 /// Send a user message to Claude
269 pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
270 // Check connection
271 {
272 let state = self.state.read().await;
273 if *state != ClientState::Connected {
274 return Err(SdkError::InvalidState {
275 message: "Not connected".into(),
276 });
277 }
278 }
279
280 // Use default session ID
281 let session_id = "default".to_string();
282
283 // Update session data
284 {
285 let mut sessions = self.sessions.write().await;
286 let session = sessions.entry(session_id.clone()).or_insert_with(|| {
287 debug!("Creating new session: {}", session_id);
288 SessionData {
289 id: session_id.clone(),
290 message_count: 0,
291 created_at: std::time::Instant::now(),
292 }
293 });
294 session.message_count += 1;
295 }
296
297 // Create and send message
298 let message = InputMessage::user(prompt, session_id.clone());
299
300 {
301 let mut transport = self.transport.lock().await;
302 transport.send_message(message).await?;
303 }
304
305 debug!("Sent request to Claude");
306 Ok(())
307 }
308
309 /// Send a request to Claude (alias for send_user_message with optional session_id)
310 pub async fn send_request(
311 &mut self,
312 prompt: String,
313 _session_id: Option<String>,
314 ) -> Result<()> {
315 // For now, ignore session_id and use send_user_message
316 self.send_user_message(prompt).await
317 }
318
319 /// Receive messages from Claude
320 ///
321 /// Returns a stream of messages. The stream will end when a Result message
322 /// is received or the connection is closed.
323 pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
324 // Always use the regular message receiver
325 // (Query handler shares the same transport and receives control messages separately)
326 // Create a new channel for this receiver
327 let (tx, rx) = mpsc::channel(100);
328
329 // Get buffered messages and clear buffer
330 let buffered_messages = {
331 let mut buffer = self.message_buffer.lock().await;
332 std::mem::take(&mut *buffer)
333 };
334
335 // Send buffered messages to the new receiver
336 let tx_clone = tx.clone();
337 tokio::spawn(async move {
338 for msg in buffered_messages {
339 if tx_clone.send(Ok(msg)).await.is_err() {
340 break;
341 }
342 }
343 });
344
345 // Store the sender for the message receiver task
346 {
347 let mut message_tx = self.message_tx.lock().await;
348 *message_tx = Some(tx);
349 }
350
351 ReceiverStream::new(rx)
352 }
353
354 /// Send an interrupt request
355 pub async fn interrupt(&mut self) -> Result<()> {
356 // Check connection
357 {
358 let state = self.state.read().await;
359 if *state != ClientState::Connected {
360 return Err(SdkError::InvalidState {
361 message: "Not connected".into(),
362 });
363 }
364 }
365
366 // If we have a query handler, use it
367 if let Some(ref query_handler) = self.query_handler {
368 let mut handler = query_handler.lock().await;
369 return handler.interrupt().await;
370 }
371
372 // Otherwise use regular interrupt
373 // Generate request ID
374 let request_id = {
375 let mut counter = self.request_counter.lock().await;
376 *counter += 1;
377 format!("interrupt_{}", *counter)
378 };
379
380 // Send interrupt request
381 let request = ControlRequest::Interrupt {
382 request_id: request_id.clone(),
383 };
384
385 {
386 let mut transport = self.transport.lock().await;
387 transport.send_control_request(request).await?;
388 }
389
390 info!("Sent interrupt request: {}", request_id);
391
392 // Wait for acknowledgment (with timeout)
393 let transport = self.transport.clone();
394 let ack_task = tokio::spawn(async move {
395 let mut transport = transport.lock().await;
396 match tokio::time::timeout(
397 std::time::Duration::from_secs(5),
398 transport.receive_control_response(),
399 )
400 .await
401 {
402 Ok(Ok(Some(ControlResponse::InterruptAck {
403 request_id: ack_id,
404 success,
405 }))) => {
406 if ack_id == request_id && success {
407 Ok(())
408 } else {
409 Err(SdkError::ControlRequestError(
410 "Interrupt not acknowledged successfully".into(),
411 ))
412 }
413 }
414 Ok(Ok(None)) => Err(SdkError::ControlRequestError(
415 "No interrupt acknowledgment received".into(),
416 )),
417 Ok(Err(e)) => Err(e),
418 Err(_) => Err(SdkError::timeout(5)),
419 }
420 });
421
422 ack_task
423 .await
424 .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
425 }
426
427 /// Check if the client is connected
428 pub async fn is_connected(&self) -> bool {
429 let state = self.state.read().await;
430 *state == ClientState::Connected
431 }
432
433 /// Get active session IDs
434 pub async fn get_sessions(&self) -> Vec<String> {
435 let sessions = self.sessions.read().await;
436 sessions.keys().cloned().collect()
437 }
438
439 /// Receive messages until and including a ResultMessage
440 ///
441 /// This is a convenience method that collects all messages from a single response.
442 /// It will automatically stop after receiving a ResultMessage.
443 pub async fn receive_response(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
444 let mut messages = self.receive_messages().await;
445
446 // Create a stream that stops after ResultMessage
447 Box::pin(async_stream::stream! {
448 while let Some(msg_result) = messages.next().await {
449 match &msg_result {
450 Ok(Message::Result { .. }) => {
451 yield msg_result;
452 return;
453 }
454 _ => {
455 yield msg_result;
456 }
457 }
458 }
459 })
460 }
461
462 /// Get server information
463 ///
464 /// Returns initialization information from the Claude Code server including:
465 /// - Available commands
466 /// - Current and available output styles
467 /// - Server capabilities
468 pub async fn get_server_info(&self) -> Option<serde_json::Value> {
469 // If we have a query handler with control protocol, get from there
470 if let Some(ref query_handler) = self.query_handler {
471 let handler = query_handler.lock().await;
472 if let Some(init_result) = handler.get_initialization_result() {
473 return Some(init_result.clone());
474 }
475 }
476
477 // Otherwise check message buffer for init message
478 let buffer = self.message_buffer.lock().await;
479 for msg in buffer.iter() {
480 if let Message::System { subtype, data } = msg
481 && subtype == "init" {
482 return Some(data.clone());
483 }
484 }
485 None
486 }
487
488 /// Get account information
489 ///
490 /// This method attempts to retrieve Claude account information through multiple methods:
491 /// 1. From environment variable `ANTHROPIC_USER_EMAIL`
492 /// 2. From Claude CLI config file (if accessible)
493 /// 3. By querying the CLI with `/status` command (interactive mode)
494 ///
495 /// # Returns
496 ///
497 /// A string containing the account information, or an error if unavailable.
498 ///
499 /// # Example
500 ///
501 /// ```rust,no_run
502 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
503 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
504 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
505 /// client.connect(None).await?;
506 ///
507 /// match client.get_account_info().await {
508 /// Ok(info) => println!("Account: {}", info),
509 /// Err(_) => println!("Account info not available"),
510 /// }
511 /// # Ok(())
512 /// # }
513 /// ```
514 ///
515 /// # Note
516 ///
517 /// Account information may not always be available in SDK mode.
518 /// Consider setting the `ANTHROPIC_USER_EMAIL` environment variable
519 /// for reliable account identification.
520 pub async fn get_account_info(&mut self) -> Result<String> {
521 // Check connection
522 {
523 let state = self.state.read().await;
524 if *state != ClientState::Connected {
525 return Err(SdkError::InvalidState {
526 message: "Not connected. Call connect() first.".into(),
527 });
528 }
529 }
530
531 // Method 1: Check environment variable
532 if let Ok(email) = std::env::var("ANTHROPIC_USER_EMAIL") {
533 return Ok(format!("Email: {}", email));
534 }
535
536 // Method 2: Try reading from Claude config
537 if let Some(config_info) = Self::read_claude_config().await {
538 return Ok(config_info);
539 }
540
541 // Method 3: Try /status command (may not work in SDK mode)
542 self.send_user_message("/status".to_string()).await?;
543
544 let mut messages = self.receive_messages().await;
545 let mut account_info = String::new();
546
547 while let Some(msg_result) = messages.next().await {
548 match msg_result? {
549 Message::Assistant { message } => {
550 for block in message.content {
551 if let ContentBlock::Text(text) = block {
552 account_info.push_str(&text.text);
553 account_info.push('\n');
554 }
555 }
556 }
557 Message::Result { .. } => break,
558 _ => {}
559 }
560 }
561
562 let trimmed = account_info.trim();
563
564 // Check if we got actual status info or just a chat response
565 if !trimmed.is_empty() && (
566 trimmed.contains("account") ||
567 trimmed.contains("email") ||
568 trimmed.contains("subscription") ||
569 trimmed.contains("authenticated")
570 ) {
571 return Ok(trimmed.to_string());
572 }
573
574 Err(SdkError::InvalidState {
575 message: "Account information not available. Try setting ANTHROPIC_USER_EMAIL environment variable.".into(),
576 })
577 }
578
579 /// Read Claude config file
580 async fn read_claude_config() -> Option<String> {
581 // Try common config locations
582 let config_paths = vec![
583 dirs::home_dir()?.join(".config").join("claude").join("config.json"),
584 dirs::home_dir()?.join(".claude").join("config.json"),
585 ];
586
587 for path in config_paths {
588 if let Ok(content) = tokio::fs::read_to_string(&path).await {
589 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
590 if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
591 return Some(format!("Email: {}", email));
592 }
593 if let Some(user) = json.get("user").and_then(|v| v.as_str()) {
594 return Some(format!("User: {}", user));
595 }
596 }
597 }
598 }
599
600 None
601 }
602
603 /// Set permission mode dynamically
604 ///
605 /// Changes the permission mode during an active session.
606 /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
607 ///
608 /// # Arguments
609 ///
610 /// * `mode` - Permission mode: "default", "acceptEdits", "plan", or "bypassPermissions"
611 ///
612 /// # Example
613 ///
614 /// ```rust,no_run
615 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
616 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
617 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
618 /// client.connect(None).await?;
619 ///
620 /// // Switch to accept edits mode
621 /// client.set_permission_mode("acceptEdits").await?;
622 /// # Ok(())
623 /// # }
624 /// ```
625 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
626 if let Some(ref query_handler) = self.query_handler {
627 let mut handler = query_handler.lock().await;
628 handler.set_permission_mode(mode).await
629 } else {
630 Err(SdkError::InvalidState {
631 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
632 })
633 }
634 }
635
636 /// Set model dynamically
637 ///
638 /// Changes the active model during an active session.
639 /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
640 ///
641 /// # Arguments
642 ///
643 /// * `model` - Model identifier (e.g., "claude-3-5-sonnet-20241022") or None to use default
644 ///
645 /// # Example
646 ///
647 /// ```rust,no_run
648 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
649 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
650 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
651 /// client.connect(None).await?;
652 ///
653 /// // Switch to a different model
654 /// client.set_model(Some("claude-3-5-sonnet-20241022".to_string())).await?;
655 /// # Ok(())
656 /// # }
657 /// ```
658 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
659 if let Some(ref query_handler) = self.query_handler {
660 let mut handler = query_handler.lock().await;
661 handler.set_model(model).await
662 } else {
663 Err(SdkError::InvalidState {
664 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
665 })
666 }
667 }
668
669 /// Send a query with optional session ID
670 ///
671 /// This method is similar to Python SDK's query method in ClaudeSDKClient
672 pub async fn query(&mut self, prompt: String, session_id: Option<String>) -> Result<()> {
673 let session_id = session_id.unwrap_or_else(|| "default".to_string());
674
675 // Send the message
676 let message = InputMessage::user(prompt, session_id);
677
678 {
679 let mut transport = self.transport.lock().await;
680 transport.send_message(message).await?;
681 }
682
683 Ok(())
684 }
685
686 /// Rewind tracked files to their state at a specific user message
687 ///
688 /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
689 /// This method allows you to undo file changes made during the session by
690 /// reverting them to their state at any previous user message checkpoint.
691 ///
692 /// # Arguments
693 ///
694 /// * `user_message_id` - UUID of the user message to rewind to. This should be
695 /// the `uuid` field from a message received during the conversation.
696 ///
697 /// # Example
698 ///
699 /// ```rust,no_run
700 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
701 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
702 /// let options = ClaudeCodeOptions::builder()
703 /// .enable_file_checkpointing(true)
704 /// .build();
705 /// let mut client = ClaudeSDKClient::new(options);
706 /// client.connect(None).await?;
707 ///
708 /// // Ask Claude to make some changes
709 /// client.send_request("Make some changes to my files".to_string(), None).await?;
710 ///
711 /// // ... later, rewind to a checkpoint
712 /// // client.rewind_files("user-message-uuid-here").await?;
713 /// # Ok(())
714 /// # }
715 /// ```
716 ///
717 /// # Errors
718 ///
719 /// Returns an error if:
720 /// - The client is not connected
721 /// - The query handler is not initialized (control protocol required)
722 /// - File checkpointing is not enabled
723 /// - The specified user_message_id is invalid
724 pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
725 // Check connection
726 {
727 let state = self.state.read().await;
728 if *state != ClientState::Connected {
729 return Err(SdkError::InvalidState {
730 message: "Not connected. Call connect() first.".into(),
731 });
732 }
733 }
734
735 if !self.options.enable_file_checkpointing {
736 return Err(SdkError::InvalidState {
737 message: "File checkpointing is not enabled. Set ClaudeCodeOptions::builder().enable_file_checkpointing(true).".to_string(),
738 });
739 }
740
741 // Require query handler for control protocol
742 if let Some(ref query_handler) = self.query_handler {
743 let mut handler = query_handler.lock().await;
744 handler.rewind_files(user_message_id).await
745 } else {
746 Err(SdkError::InvalidState {
747 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
748 })
749 }
750 }
751
752 /// Get context usage information including token distribution and cache stats
753 ///
754 /// Returns detailed context window usage including categories, API usage with
755 /// cache token information, and auto-compact settings.
756 ///
757 /// Requires control protocol to be enabled.
758 pub async fn get_context_usage(&mut self) -> Result<crate::types::ContextUsageResponse> {
759 {
760 let state = self.state.read().await;
761 if *state != ClientState::Connected {
762 return Err(SdkError::InvalidState {
763 message: "Not connected. Call connect() first.".into(),
764 });
765 }
766 }
767
768 if let Some(ref query_handler) = self.query_handler {
769 let mut handler = query_handler.lock().await;
770 let raw = handler.get_context_usage().await?;
771 serde_json::from_value(raw).map_err(|e| SdkError::ControlRequestError(
772 format!("Failed to parse context usage response: {e}")
773 ))
774 } else {
775 Err(SdkError::InvalidState {
776 message: "Query handler not initialized. Enable control protocol features.".to_string(),
777 })
778 }
779 }
780
781 /// Stop a background task by ID
782 ///
783 /// Requires control protocol to be enabled.
784 pub async fn stop_task(&mut self, task_id: &str) -> Result<()> {
785 {
786 let state = self.state.read().await;
787 if *state != ClientState::Connected {
788 return Err(SdkError::InvalidState {
789 message: "Not connected. Call connect() first.".into(),
790 });
791 }
792 }
793
794 if let Some(ref query_handler) = self.query_handler {
795 let mut handler = query_handler.lock().await;
796 handler.stop_task(task_id).await
797 } else {
798 Err(SdkError::InvalidState {
799 message: "Query handler not initialized. Enable control protocol features.".to_string(),
800 })
801 }
802 }
803
804 /// Get MCP server status information
805 ///
806 /// Returns the current status of all connected MCP servers.
807 /// Requires control protocol to be enabled.
808 pub async fn get_mcp_status(&mut self) -> Result<serde_json::Value> {
809 {
810 let state = self.state.read().await;
811 if *state != ClientState::Connected {
812 return Err(SdkError::InvalidState {
813 message: "Not connected. Call connect() first.".into(),
814 });
815 }
816 }
817
818 if let Some(ref query_handler) = self.query_handler {
819 let mut handler = query_handler.lock().await;
820 handler.get_mcp_status().await
821 } else {
822 Err(SdkError::InvalidState {
823 message: "Query handler not initialized. Enable control protocol features.".to_string(),
824 })
825 }
826 }
827
828 /// Reconnect a failed MCP server
829 ///
830 /// Attempts to re-establish connection to the specified MCP server.
831 /// Requires control protocol to be enabled.
832 pub async fn reconnect_mcp_server(&mut self, server_name: &str) -> Result<()> {
833 {
834 let state = self.state.read().await;
835 if *state != ClientState::Connected {
836 return Err(SdkError::InvalidState {
837 message: "Not connected. Call connect() first.".into(),
838 });
839 }
840 }
841
842 if let Some(ref query_handler) = self.query_handler {
843 let mut handler = query_handler.lock().await;
844 handler.reconnect_mcp_server(server_name).await
845 } else {
846 Err(SdkError::InvalidState {
847 message: "Query handler not initialized. Enable control protocol features.".to_string(),
848 })
849 }
850 }
851
852 /// Toggle an MCP server on/off
853 ///
854 /// Enables or disables the specified MCP server.
855 /// Requires control protocol to be enabled.
856 pub async fn toggle_mcp_server(&mut self, server_name: &str, enabled: bool) -> Result<()> {
857 {
858 let state = self.state.read().await;
859 if *state != ClientState::Connected {
860 return Err(SdkError::InvalidState {
861 message: "Not connected. Call connect() first.".into(),
862 });
863 }
864 }
865
866 if let Some(ref query_handler) = self.query_handler {
867 let mut handler = query_handler.lock().await;
868 handler.toggle_mcp_server(server_name, enabled).await
869 } else {
870 Err(SdkError::InvalidState {
871 message: "Query handler not initialized. Enable control protocol features.".to_string(),
872 })
873 }
874 }
875
876 /// Disconnect from Claude CLI
877 pub async fn disconnect(&mut self) -> Result<()> {
878 // Check if already disconnected
879 {
880 let state = self.state.read().await;
881 if *state == ClientState::Disconnected {
882 return Ok(());
883 }
884 }
885
886 // Disconnect transport
887 {
888 let mut transport = self.transport.lock().await;
889 transport.disconnect().await?;
890 }
891
892 // Update state
893 {
894 let mut state = self.state.write().await;
895 *state = ClientState::Disconnected;
896 }
897
898 // Clear sessions
899 {
900 let mut sessions = self.sessions.write().await;
901 sessions.clear();
902 }
903
904 info!("Disconnected from Claude CLI");
905 Ok(())
906 }
907
908 /// Start the message receiver task
909 async fn start_message_receiver(&mut self) {
910 let transport = self.transport.clone();
911 let message_tx = self.message_tx.clone();
912 let message_buffer = self.message_buffer.clone();
913 let state = self.state.clone();
914 let budget_manager = self.budget_manager.clone();
915
916 tokio::spawn(async move {
917 // Subscribe to messages without holding the lock
918 let mut stream = {
919 let mut transport = transport.lock().await;
920 transport.receive_messages()
921 }; // Lock is released here immediately
922
923 while let Some(result) = stream.next().await {
924 match result {
925 Ok(message) => {
926 // Update token usage for Result messages
927 if let Message::Result { .. } = &message
928 && let Message::Result { usage, total_cost_usd, .. } = &message {
929 let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
930 let input = usage_json.get("input_tokens")
931 .and_then(|v| v.as_u64())
932 .unwrap_or(0);
933 let output = usage_json.get("output_tokens")
934 .and_then(|v| v.as_u64())
935 .unwrap_or(0);
936 (input, output)
937 } else {
938 (0, 0)
939 };
940 let cost = total_cost_usd.unwrap_or(0.0);
941 budget_manager.update_usage(input_tokens, output_tokens, cost).await;
942 }
943
944 // Buffer init messages for get_server_info()
945 if let Message::System { subtype, .. } = &message
946 && subtype == "init" {
947 let mut buffer = message_buffer.lock().await;
948 buffer.push(message.clone());
949 }
950
951 // Try to send to current receiver
952 let sent = {
953 let mut tx_opt = message_tx.lock().await;
954 if let Some(tx) = tx_opt.as_mut() {
955 tx.send(Ok(message.clone())).await.is_ok()
956 } else {
957 false
958 }
959 };
960
961 // If no receiver or send failed, buffer the message
962 if !sent {
963 let mut buffer = message_buffer.lock().await;
964 buffer.push(message);
965 }
966 }
967 Err(e) => {
968 error!("Error receiving message: {}", e);
969
970 // Send error to receiver if available
971 let mut tx_opt = message_tx.lock().await;
972 if let Some(tx) = tx_opt.as_mut() {
973 let _ = tx.send(Err(e)).await;
974 }
975
976 // Update state on error
977 let mut state = state.write().await;
978 *state = ClientState::Error;
979 break;
980 }
981 }
982 }
983
984 debug!("Message receiver task ended");
985 });
986 }
987
988 /// Get token usage statistics
989 ///
990 /// Returns the current token usage tracker with cumulative statistics
991 /// for all queries executed by this client.
992 pub async fn get_usage_stats(&self) -> crate::token_tracker::TokenUsageTracker {
993 self.budget_manager.get_usage().await
994 }
995
996 /// Set budget limit with optional warning callback
997 ///
998 /// # Arguments
999 ///
1000 /// * `limit` - Budget limit configuration (cost and/or token caps)
1001 /// * `on_warning` - Optional callback function triggered when usage exceeds warning threshold
1002 ///
1003 /// # Example
1004 ///
1005 /// ```rust,no_run
1006 /// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
1007 /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
1008 /// use std::sync::Arc;
1009 ///
1010 /// # async fn example() {
1011 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
1012 ///
1013 /// // Set budget with callback
1014 /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Budget warning: {}", msg));
1015 /// client.set_budget_limit(BudgetLimit::with_cost(5.0), Some(cb)).await;
1016 /// # }
1017 /// ```
1018 pub async fn set_budget_limit(
1019 &self,
1020 limit: crate::token_tracker::BudgetLimit,
1021 on_warning: Option<crate::token_tracker::BudgetWarningCallback>,
1022 ) {
1023 self.budget_manager.set_limit(limit).await;
1024 if let Some(callback) = on_warning {
1025 self.budget_manager.set_warning_callback(callback).await;
1026 }
1027 }
1028
1029 /// Clear budget limit and reset warning state
1030 pub async fn clear_budget_limit(&self) {
1031 self.budget_manager.clear_limit().await;
1032 }
1033
1034 /// Reset token usage statistics to zero
1035 ///
1036 /// Clears all accumulated token and cost statistics.
1037 /// Budget limits remain in effect.
1038 pub async fn reset_usage_stats(&self) {
1039 self.budget_manager.reset_usage().await;
1040 }
1041
1042 /// Check if budget has been exceeded
1043 ///
1044 /// Returns true if current usage exceeds any configured limits
1045 pub async fn is_budget_exceeded(&self) -> bool {
1046 self.budget_manager.is_exceeded().await
1047 }
1048
1049 // Removed unused helper; usage is updated inline in message receiver
1050}
1051
1052impl Drop for ClaudeSDKClient {
1053 fn drop(&mut self) {
1054 // Try to disconnect gracefully
1055 let transport = self.transport.clone();
1056 let state = self.state.clone();
1057
1058 if let Ok(handle) = tokio::runtime::Handle::try_current() {
1059 handle.spawn(async move {
1060 let state = state.read().await;
1061 if *state == ClientState::Connected {
1062 let mut transport = transport.lock().await;
1063 if let Err(e) = transport.disconnect().await {
1064 debug!("Error disconnecting in drop: {}", e);
1065 }
1066 }
1067 });
1068 }
1069 }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        assert!(!client.is_connected().await);
        assert!(client.get_sessions().await.is_empty());
    }

    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let state = client.state.read().await;
        assert_eq!(*state, ClientState::Disconnected);
    }

    #[test]
    fn test_file_checkpointing_enables_query_handler() {
        let options = ClaudeCodeOptions::builder()
            .enable_file_checkpointing(true)
            .build();
        let client = ClaudeSDKClient::new(options);

        assert!(
            client.query_handler.is_some(),
            "enable_file_checkpointing should initialize the query handler for control protocol requests"
        );
    }

    /// Operations that check the connection state must reject a fresh client
    /// before doing any work.
    #[tokio::test]
    async fn test_operations_require_connection() {
        let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let sent = client.send_user_message("hello".to_string()).await;
        assert!(matches!(sent, Err(SdkError::InvalidState { .. })));

        let interrupted = client.interrupt().await;
        assert!(matches!(interrupted, Err(SdkError::InvalidState { .. })));

        let rewound = client.rewind_files("some-user-message-uuid").await;
        assert!(matches!(rewound, Err(SdkError::InvalidState { .. })));

        let account = client.get_account_info().await;
        assert!(matches!(account, Err(SdkError::InvalidState { .. })));

        // A rejected send must not create session bookkeeping.
        assert!(client.get_sessions().await.is_empty());
    }

    #[tokio::test]
    async fn test_disconnect_when_not_connected_is_noop() {
        let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        assert!(client.disconnect().await.is_ok());
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_server_info_absent_before_connect() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        assert!(client.get_server_info().await.is_none());
    }
}