Skip to main content

cc_sdk/
client.rs

1//! Interactive client for bidirectional communication with Claude
2//!
3//! This module provides the `ClaudeSDKClient` for interactive, stateful
4//! conversations with Claude Code CLI.
5
6use crate::{
7    errors::{Result, SdkError},
8    internal_query::Query,
9    token_tracker::BudgetManager,
10    transport::{InputMessage, SubprocessTransport, Transport},
11    types::{ClaudeCodeOptions, ContentBlock, ControlRequest, ControlResponse, Message},
12};
13use futures::stream::{Stream, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::pin::Pin;
17use tokio::sync::{Mutex, RwLock, mpsc};
18use tokio_stream::wrappers::ReceiverStream;
19use tracing::{debug, error, info};
20
/// Lifecycle state of a `ClaudeSDKClient` connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// No live connection; this is the state a freshly built client starts in.
    Disconnected,
    /// `connect()` completed; requests may be sent.
    Connected,
    /// Error state.
    // NOTE(review): no code in this section assigns `Error` — confirm where it is set.
    Error,
}
31
/// Interactive client for bidirectional communication with Claude
///
/// `ClaudeSDKClient` provides a stateful, interactive interface for communicating
/// with Claude Code CLI. Unlike the simple `query` function, this client supports:
///
/// - Bidirectional communication
/// - Multiple sessions (tracked per session ID)
/// - Interrupt capabilities
/// - State management
/// - Follow-up messages based on responses
///
/// Runtime control operations (permission mode, model switching, file rewind,
/// MCP status/toggling, context usage, task stopping) go through an internal
/// control-protocol handler. That handler is only created when `can_use_tool`,
/// `hooks`, `mcp_servers`, or `enable_file_checkpointing` is configured.
///
/// # Example
///
/// ```rust,no_run
/// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, Message, Result};
/// use futures::StreamExt;
///
/// #[tokio::main]
/// async fn main() -> Result<()> {
///     let options = ClaudeCodeOptions::builder()
///         .system_prompt("You are a helpful assistant")
///         .model("claude-3-opus-20240229")
///         .build();
///
///     let mut client = ClaudeSDKClient::new(options);
///
///     // Connect with initial prompt
///     client.connect(Some("Hello!".to_string())).await?;
///
///     // Receive initial response
///     let mut messages = client.receive_messages().await;
///     while let Some(msg) = messages.next().await {
///         match msg? {
///             Message::Result { .. } => break,
///             msg => println!("{:?}", msg),
///         }
///     }
///
///     // Send follow-up
///     client.send_request("What's 2 + 2?".to_string(), None).await?;
///
///     // Receive the follow-up response. `receive_response` stops after the
///     // `Message::Result`; the block scope releases its borrow of `client`.
///     {
///         let mut messages = client.receive_response().await;
///         while let Some(msg) = messages.next().await {
///             println!("{:?}", msg?);
///         }
///     }
///
///     // Disconnect
///     client.disconnect().await?;
///
///     Ok(())
/// }
/// ```
pub struct ClaudeSDKClient {
    /// Configuration options
    ///
    /// Read by `rewind_files` (checks `enable_file_checkpointing`).
    // NOTE(review): because the field is read, this `dead_code` allowance looks
    // stale — confirm and remove.
    #[allow(dead_code)]
    options: ClaudeCodeOptions,
    /// Transport layer
    ///
    /// The same `Arc` is shared with the control-protocol handler when one exists.
    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
    /// Internal query handler (when control protocol is enabled)
    ///
    /// `Some` only if `can_use_tool`, `hooks`, `mcp_servers`, or
    /// `enable_file_checkpointing` was configured at construction time.
    query_handler: Option<Arc<Mutex<Query>>>,
    /// Client state
    state: Arc<RwLock<ClientState>>,
    /// Active sessions, keyed by session ID
    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
    /// Message sender for current receiver
    ///
    /// Replaced on every `receive_messages` call; the previous sender is dropped.
    message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
    /// Message buffer for multiple receivers
    ///
    /// Drained by `receive_messages` and scanned by `get_server_info`.
    /// Presumably filled by the background receiver task (not visible here).
    message_buffer: Arc<Mutex<Vec<Message>>>,
    /// Request counter
    ///
    /// Used to build `interrupt_<n>` request IDs.
    request_counter: Arc<Mutex<u64>>,
    /// Budget manager for token tracking
    budget_manager: BudgetManager,
}
106
/// Bookkeeping for a single conversation session, stored in the client's
/// session map under the same ID.
// `id` and `created_at` are recorded but not read by the code in this file
// section, hence the lint allowance.
#[allow(dead_code)]
struct SessionData {
    /// Session identifier (mirrors the map key).
    id: String,
    /// Moment the session entry was first created.
    created_at: std::time::Instant,
    /// Count of user messages sent through this session.
    message_count: usize,
}
117
118impl ClaudeSDKClient {
119    /// Create a new client with the given options
120    pub fn new(options: ClaudeCodeOptions) -> Self {
121        let transport = match SubprocessTransport::new(options.clone()) {
122            Ok(t) => t,
123            Err(e) => {
124                error!("Failed to create transport: {}", e);
125                // Create with empty path, will fail on connect
126                SubprocessTransport::with_cli_path(options.clone(), "")
127            }
128        };
129
130        // Wrap transport in Arc for sharing
131        let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
132            Arc::new(Mutex::new(Box::new(transport)));
133
134        Self::with_transport_internal(options, transport_arc)
135    }
136
137    /// Create a new client with a custom transport implementation
138    ///
139    /// This allows users to provide their own Transport implementation instead of
140    /// using the default SubprocessTransport. Useful for testing, custom CLI paths,
141    /// or alternative communication mechanisms.
142    ///
143    /// # Arguments
144    ///
145    /// * `options` - Configuration options for the client
146    /// * `transport` - Custom transport implementation
147    ///
148    /// # Example
149    ///
150    /// ```rust,no_run
151    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, SubprocessTransport};
152    /// # fn example() {
153    /// let options = ClaudeCodeOptions::default();
154    /// let transport = SubprocessTransport::with_cli_path(options.clone(), "/custom/path/claude-code");
155    /// let client = ClaudeSDKClient::with_transport(options, Box::new(transport));
156    /// # }
157    /// ```
158    pub fn with_transport(options: ClaudeCodeOptions, transport: Box<dyn Transport + Send>) -> Self {
159        // Wrap transport in Arc for sharing
160        let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
161            Arc::new(Mutex::new(transport));
162
163        Self::with_transport_internal(options, transport_arc)
164    }
165
166    /// Internal helper to construct client with pre-wrapped transport
167    fn with_transport_internal(
168        mut options: ClaudeCodeOptions,
169        transport_arc: Arc<Mutex<Box<dyn Transport + Send>>>,
170    ) -> Self {
171        // Auto-configure permission_prompt_tool_name when can_use_tool is set
172        // (matching Python SDK behavior: route permission requests via stdio control protocol)
173        if options.can_use_tool.is_some() && options.permission_prompt_tool_name.is_none() {
174            options.permission_prompt_tool_name = Some("stdio".to_string());
175        }
176
177        // Create query handler if control protocol features are enabled
178        let query_handler = if options.can_use_tool.is_some()
179            || options.hooks.is_some()
180            || !options.mcp_servers.is_empty()
181            || options.enable_file_checkpointing {
182            // Extract SDK MCP server instances
183            let sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>> = options.mcp_servers
184                .iter()
185                .filter_map(|(k, v)| {
186                    // Only extract SDK type MCP servers
187                    if let crate::types::McpServerConfig::Sdk { name: _, instance } = v {
188                        Some((k.clone(), instance.clone()))
189                    } else {
190                        None
191                    }
192                })
193                .collect();
194
195            // Enable streaming mode when control protocol is active
196            let is_streaming = options.can_use_tool.is_some()
197                || options.hooks.is_some()
198                || !sdk_mcp_servers.is_empty();
199
200            let query = Query::new(
201                transport_arc.clone(), // Share the same transport
202                is_streaming, // Enable streaming for control protocol
203                options.can_use_tool.clone(),
204                options.hooks.clone(),
205                sdk_mcp_servers,
206            );
207            Some(Arc::new(Mutex::new(query)))
208        } else {
209            None
210        };
211
212        Self {
213            options,
214            transport: transport_arc,
215            query_handler,
216            state: Arc::new(RwLock::new(ClientState::Disconnected)),
217            sessions: Arc::new(RwLock::new(HashMap::new())),
218            message_tx: Arc::new(Mutex::new(None)),
219            message_buffer: Arc::new(Mutex::new(Vec::new())),
220            request_counter: Arc::new(Mutex::new(0)),
221            budget_manager: BudgetManager::new(),
222        }
223    }
224
    /// Connect to Claude CLI with an optional initial prompt
    ///
    /// Performs these steps, in order:
    ///
    /// 1. Connect the transport.
    /// 2. Start and initialize the control-protocol handler, only if one was created.
    /// 3. Mark the client `Connected`.
    /// 4. Start the background message receiver.
    /// 5. Send `initial_prompt`, if given, via `send_request`.
    ///
    /// Returns `Ok(())` immediately, without reconnecting, when the client is
    /// already `Connected`.
    ///
    /// # Errors
    ///
    /// Propagates any error from:
    ///
    /// - the transport connect;
    /// - the handler's `start`/`initialize`;
    /// - sending the initial prompt.
    ///
    /// NOTE(review): if handler `start`/`initialize` fails, the transport is
    /// already connected but the state remains `Disconnected`. A retry then calls
    /// `transport.connect()` again — confirm the transport tolerates that. A
    /// failure while sending `initial_prompt` leaves the client `Connected`.
    pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
        // Check if already connected (read lock released at end of this scope)
        {
            let state = self.state.read().await;
            if *state == ClientState::Connected {
                return Ok(());
            }
        }

        // Connect transport
        {
            let mut transport = self.transport.lock().await;
            transport.connect().await?;
        }

        // Initialize query handler if present. This runs after the transport
        // connect; the handler shares this same transport.
        if let Some(ref query_handler) = self.query_handler {
            let mut handler = query_handler.lock().await;
            handler.start().await?;
            handler.initialize().await?;
            info!("Initialized SDK control protocol");
        }

        // Update state
        {
            let mut state = self.state.write().await;
            *state = ClientState::Connected;
        }

        info!("Connected to Claude CLI");

        // Start message receiver task (always needed for regular messages)
        self.start_message_receiver().await;

        // Send initial prompt if provided (send_request requires the
        // `Connected` state set above)
        if let Some(prompt) = initial_prompt {
            self.send_request(prompt, None).await?;
        }

        Ok(())
    }
267
268    /// Send a user message to Claude
269    pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
270        // Check connection
271        {
272            let state = self.state.read().await;
273            if *state != ClientState::Connected {
274                return Err(SdkError::InvalidState {
275                    message: "Not connected".into(),
276                });
277            }
278        }
279
280        // Use default session ID
281        let session_id = "default".to_string();
282
283        // Update session data
284        {
285            let mut sessions = self.sessions.write().await;
286            let session = sessions.entry(session_id.clone()).or_insert_with(|| {
287                debug!("Creating new session: {}", session_id);
288                SessionData {
289                    id: session_id.clone(),
290                    message_count: 0,
291                    created_at: std::time::Instant::now(),
292                }
293            });
294            session.message_count += 1;
295        }
296
297        // Create and send message
298        let message = InputMessage::user(prompt, session_id.clone());
299
300        {
301            let mut transport = self.transport.lock().await;
302            transport.send_message(message).await?;
303        }
304
305        debug!("Sent request to Claude");
306        Ok(())
307    }
308
309    /// Send a request to Claude (alias for send_user_message with optional session_id)
310    pub async fn send_request(
311        &mut self,
312        prompt: String,
313        _session_id: Option<String>,
314    ) -> Result<()> {
315        // For now, ignore session_id and use send_user_message
316        self.send_user_message(prompt).await
317    }
318
319    /// Receive messages from Claude
320    ///
321    /// Returns a stream of messages. The stream will end when a Result message
322    /// is received or the connection is closed.
323    pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
324        // Always use the regular message receiver
325        // (Query handler shares the same transport and receives control messages separately)
326        // Create a new channel for this receiver
327        let (tx, rx) = mpsc::channel(100);
328
329        // Get buffered messages and clear buffer
330        let buffered_messages = {
331            let mut buffer = self.message_buffer.lock().await;
332            std::mem::take(&mut *buffer)
333        };
334
335        // Send buffered messages to the new receiver
336        let tx_clone = tx.clone();
337        tokio::spawn(async move {
338            for msg in buffered_messages {
339                if tx_clone.send(Ok(msg)).await.is_err() {
340                    break;
341                }
342            }
343        });
344
345        // Store the sender for the message receiver task
346        {
347            let mut message_tx = self.message_tx.lock().await;
348            *message_tx = Some(tx);
349        }
350
351        ReceiverStream::new(rx)
352    }
353
    /// Send an interrupt request
    ///
    /// When a control-protocol handler exists, the interrupt is delegated to it.
    /// Otherwise:
    ///
    /// - An `Interrupt` control request is sent over the transport, with a
    ///   generated ID of the form `interrupt_<n>`.
    /// - This then waits up to 5 seconds for an `InterruptAck` carrying the
    ///   same ID with `success == true`.
    ///
    /// # Errors
    ///
    /// - `SdkError::InvalidState` if the client is not connected.
    /// - `SdkError::ControlRequestError` in any of these cases:
    ///   - no acknowledgment is received;
    ///   - the ack's ID does not match, or it reports failure;
    ///   - the waiting task does not complete normally.
    /// - The error built by `SdkError::timeout(5)` if nothing arrives within 5 seconds.
    /// - Any transport error from sending the request or receiving the response.
    pub async fn interrupt(&mut self) -> Result<()> {
        // Check connection
        {
            let state = self.state.read().await;
            if *state != ClientState::Connected {
                return Err(SdkError::InvalidState {
                    message: "Not connected".into(),
                });
            }
        }

        // If we have a query handler, use it
        if let Some(ref query_handler) = self.query_handler {
            let mut handler = query_handler.lock().await;
            return handler.interrupt().await;
        }

        // Otherwise use regular interrupt
        // Generate request ID
        let request_id = {
            let mut counter = self.request_counter.lock().await;
            *counter += 1;
            format!("interrupt_{}", *counter)
        };

        // Send interrupt request
        let request = ControlRequest::Interrupt {
            request_id: request_id.clone(),
        };

        {
            let mut transport = self.transport.lock().await;
            transport.send_control_request(request).await?;
        }

        info!("Sent interrupt request: {}", request_id);

        // Wait for acknowledgment (with timeout).
        // The spawned task holds the transport lock for the whole wait.
        // NOTE(review): if the background message receiver also locks the
        // transport to read, the two may contend — confirm.
        let transport = self.transport.clone();
        let ack_task = tokio::spawn(async move {
            let mut transport = transport.lock().await;
            match tokio::time::timeout(
                std::time::Duration::from_secs(5),
                transport.receive_control_response(),
            )
            .await
            {
                Ok(Ok(Some(ControlResponse::InterruptAck {
                    request_id: ack_id,
                    success,
                }))) => {
                    // Both the ID and the success flag must match.
                    if ack_id == request_id && success {
                        Ok(())
                    } else {
                        Err(SdkError::ControlRequestError(
                            "Interrupt not acknowledged successfully".into(),
                        ))
                    }
                }
                Ok(Ok(None)) => Err(SdkError::ControlRequestError(
                    "No interrupt acknowledgment received".into(),
                )),
                Ok(Err(e)) => Err(e),
                Err(_) => Err(SdkError::timeout(5)),
            }
        });

        // The task is awaited immediately. Any JoinError (panic or
        // cancellation) is reported as "Interrupt task panicked".
        ack_task
            .await
            .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
    }
426
427    /// Check if the client is connected
428    pub async fn is_connected(&self) -> bool {
429        let state = self.state.read().await;
430        *state == ClientState::Connected
431    }
432
433    /// Get active session IDs
434    pub async fn get_sessions(&self) -> Vec<String> {
435        let sessions = self.sessions.read().await;
436        sessions.keys().cloned().collect()
437    }
438
439    /// Receive messages until and including a ResultMessage
440    ///
441    /// This is a convenience method that collects all messages from a single response.
442    /// It will automatically stop after receiving a ResultMessage.
443    pub async fn receive_response(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
444        let mut messages = self.receive_messages().await;
445        
446        // Create a stream that stops after ResultMessage
447        Box::pin(async_stream::stream! {
448            while let Some(msg_result) = messages.next().await {
449                match &msg_result {
450                    Ok(Message::Result { .. }) => {
451                        yield msg_result;
452                        return;
453                    }
454                    _ => {
455                        yield msg_result;
456                    }
457                }
458            }
459        })
460    }
461
462    /// Get server information
463    ///
464    /// Returns initialization information from the Claude Code server including:
465    /// - Available commands
466    /// - Current and available output styles
467    /// - Server capabilities
468    pub async fn get_server_info(&self) -> Option<serde_json::Value> {
469        // If we have a query handler with control protocol, get from there
470        if let Some(ref query_handler) = self.query_handler {
471            let handler = query_handler.lock().await;
472            if let Some(init_result) = handler.get_initialization_result() {
473                return Some(init_result.clone());
474            }
475        }
476
477        // Otherwise check message buffer for init message
478        let buffer = self.message_buffer.lock().await;
479        for msg in buffer.iter() {
480            if let Message::System { subtype, data } = msg
481                && subtype == "init" {
482                    return Some(data.clone());
483                }
484        }
485        None
486    }
487
488    /// Get account information
489    ///
490    /// This method attempts to retrieve Claude account information through multiple methods:
491    /// 1. From environment variable `ANTHROPIC_USER_EMAIL`
492    /// 2. From Claude CLI config file (if accessible)
493    /// 3. By querying the CLI with `/status` command (interactive mode)
494    ///
495    /// # Returns
496    ///
497    /// A string containing the account information, or an error if unavailable.
498    ///
499    /// # Example
500    ///
501    /// ```rust,no_run
502    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
503    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
504    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
505    /// client.connect(None).await?;
506    ///
507    /// match client.get_account_info().await {
508    ///     Ok(info) => println!("Account: {}", info),
509    ///     Err(_) => println!("Account info not available"),
510    /// }
511    /// # Ok(())
512    /// # }
513    /// ```
514    ///
515    /// # Note
516    ///
517    /// Account information may not always be available in SDK mode.
518    /// Consider setting the `ANTHROPIC_USER_EMAIL` environment variable
519    /// for reliable account identification.
520    pub async fn get_account_info(&mut self) -> Result<String> {
521        // Check connection
522        {
523            let state = self.state.read().await;
524            if *state != ClientState::Connected {
525                return Err(SdkError::InvalidState {
526                    message: "Not connected. Call connect() first.".into(),
527                });
528            }
529        }
530
531        // Method 1: Check environment variable
532        if let Ok(email) = std::env::var("ANTHROPIC_USER_EMAIL") {
533            return Ok(format!("Email: {}", email));
534        }
535
536        // Method 2: Try reading from Claude config
537        if let Some(config_info) = Self::read_claude_config().await {
538            return Ok(config_info);
539        }
540
541        // Method 3: Try /status command (may not work in SDK mode)
542        self.send_user_message("/status".to_string()).await?;
543
544        let mut messages = self.receive_messages().await;
545        let mut account_info = String::new();
546
547        while let Some(msg_result) = messages.next().await {
548            match msg_result? {
549                Message::Assistant { message } => {
550                    for block in message.content {
551                        if let ContentBlock::Text(text) = block {
552                            account_info.push_str(&text.text);
553                            account_info.push('\n');
554                        }
555                    }
556                }
557                Message::Result { .. } => break,
558                _ => {}
559            }
560        }
561
562        let trimmed = account_info.trim();
563
564        // Check if we got actual status info or just a chat response
565        if !trimmed.is_empty() && (
566            trimmed.contains("account") ||
567            trimmed.contains("email") ||
568            trimmed.contains("subscription") ||
569            trimmed.contains("authenticated")
570        ) {
571            return Ok(trimmed.to_string());
572        }
573
574        Err(SdkError::InvalidState {
575            message: "Account information not available. Try setting ANTHROPIC_USER_EMAIL environment variable.".into(),
576        })
577    }
578
579    /// Read Claude config file
580    async fn read_claude_config() -> Option<String> {
581        // Try common config locations
582        let config_paths = vec![
583            dirs::home_dir()?.join(".config").join("claude").join("config.json"),
584            dirs::home_dir()?.join(".claude").join("config.json"),
585        ];
586
587        for path in config_paths {
588            if let Ok(content) = tokio::fs::read_to_string(&path).await {
589                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
590                    if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
591                        return Some(format!("Email: {}", email));
592                    }
593                    if let Some(user) = json.get("user").and_then(|v| v.as_str()) {
594                        return Some(format!("User: {}", user));
595                    }
596                }
597            }
598        }
599
600        None
601    }
602
603    /// Set permission mode dynamically
604    ///
605    /// Changes the permission mode during an active session.
606    /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
607    ///
608    /// # Arguments
609    ///
610    /// * `mode` - Permission mode: "default", "acceptEdits", "plan", or "bypassPermissions"
611    ///
612    /// # Example
613    ///
614    /// ```rust,no_run
615    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
616    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
617    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
618    /// client.connect(None).await?;
619    ///
620    /// // Switch to accept edits mode
621    /// client.set_permission_mode("acceptEdits").await?;
622    /// # Ok(())
623    /// # }
624    /// ```
625    pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
626        if let Some(ref query_handler) = self.query_handler {
627            let mut handler = query_handler.lock().await;
628            handler.set_permission_mode(mode).await
629        } else {
630            Err(SdkError::InvalidState {
631                message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
632            })
633        }
634    }
635
636    /// Set model dynamically
637    ///
638    /// Changes the active model during an active session.
639    /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
640    ///
641    /// # Arguments
642    ///
643    /// * `model` - Model identifier (e.g., "claude-3-5-sonnet-20241022") or None to use default
644    ///
645    /// # Example
646    ///
647    /// ```rust,no_run
648    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
649    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
650    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
651    /// client.connect(None).await?;
652    ///
653    /// // Switch to a different model
654    /// client.set_model(Some("claude-3-5-sonnet-20241022".to_string())).await?;
655    /// # Ok(())
656    /// # }
657    /// ```
658    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
659        if let Some(ref query_handler) = self.query_handler {
660            let mut handler = query_handler.lock().await;
661            handler.set_model(model).await
662        } else {
663            Err(SdkError::InvalidState {
664                message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
665            })
666        }
667    }
668
669    /// Send a query with optional session ID
670    ///
671    /// This method is similar to Python SDK's query method in ClaudeSDKClient
672    pub async fn query(&mut self, prompt: String, session_id: Option<String>) -> Result<()> {
673        let session_id = session_id.unwrap_or_else(|| "default".to_string());
674
675        // Send the message
676        let message = InputMessage::user(prompt, session_id);
677
678        {
679            let mut transport = self.transport.lock().await;
680            transport.send_message(message).await?;
681        }
682
683        Ok(())
684    }
685
686    /// Rewind tracked files to their state at a specific user message
687    ///
688    /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
689    /// This method allows you to undo file changes made during the session by
690    /// reverting them to their state at any previous user message checkpoint.
691    ///
692    /// # Arguments
693    ///
694    /// * `user_message_id` - UUID of the user message to rewind to. This should be
695    ///   the `uuid` field from a message received during the conversation.
696    ///
697    /// # Example
698    ///
699    /// ```rust,no_run
700    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
701    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
702    /// let options = ClaudeCodeOptions::builder()
703    ///     .enable_file_checkpointing(true)
704    ///     .build();
705    /// let mut client = ClaudeSDKClient::new(options);
706    /// client.connect(None).await?;
707    ///
708    /// // Ask Claude to make some changes
709    /// client.send_request("Make some changes to my files".to_string(), None).await?;
710    ///
711    /// // ... later, rewind to a checkpoint
712    /// // client.rewind_files("user-message-uuid-here").await?;
713    /// # Ok(())
714    /// # }
715    /// ```
716    ///
717    /// # Errors
718    ///
719    /// Returns an error if:
720    /// - The client is not connected
721    /// - The query handler is not initialized (control protocol required)
722    /// - File checkpointing is not enabled
723    /// - The specified user_message_id is invalid
724    pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
725        // Check connection
726        {
727            let state = self.state.read().await;
728            if *state != ClientState::Connected {
729                return Err(SdkError::InvalidState {
730                    message: "Not connected. Call connect() first.".into(),
731                });
732            }
733        }
734
735        if !self.options.enable_file_checkpointing {
736            return Err(SdkError::InvalidState {
737                message: "File checkpointing is not enabled. Set ClaudeCodeOptions::builder().enable_file_checkpointing(true).".to_string(),
738            });
739        }
740
741        // Require query handler for control protocol
742        if let Some(ref query_handler) = self.query_handler {
743            let mut handler = query_handler.lock().await;
744            handler.rewind_files(user_message_id).await
745        } else {
746            Err(SdkError::InvalidState {
747                message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
748            })
749        }
750    }
751
752    /// Get context usage information including token distribution and cache stats
753    ///
754    /// Returns detailed context window usage including categories, API usage with
755    /// cache token information, and auto-compact settings.
756    ///
757    /// Requires control protocol to be enabled.
758    pub async fn get_context_usage(&mut self) -> Result<crate::types::ContextUsageResponse> {
759        {
760            let state = self.state.read().await;
761            if *state != ClientState::Connected {
762                return Err(SdkError::InvalidState {
763                    message: "Not connected. Call connect() first.".into(),
764                });
765            }
766        }
767
768        if let Some(ref query_handler) = self.query_handler {
769            let mut handler = query_handler.lock().await;
770            let raw = handler.get_context_usage().await?;
771            serde_json::from_value(raw).map_err(|e| SdkError::ControlRequestError(
772                format!("Failed to parse context usage response: {e}")
773            ))
774        } else {
775            Err(SdkError::InvalidState {
776                message: "Query handler not initialized. Enable control protocol features.".to_string(),
777            })
778        }
779    }
780
781    /// Stop a background task by ID
782    ///
783    /// Requires control protocol to be enabled.
784    pub async fn stop_task(&mut self, task_id: &str) -> Result<()> {
785        {
786            let state = self.state.read().await;
787            if *state != ClientState::Connected {
788                return Err(SdkError::InvalidState {
789                    message: "Not connected. Call connect() first.".into(),
790                });
791            }
792        }
793
794        if let Some(ref query_handler) = self.query_handler {
795            let mut handler = query_handler.lock().await;
796            handler.stop_task(task_id).await
797        } else {
798            Err(SdkError::InvalidState {
799                message: "Query handler not initialized. Enable control protocol features.".to_string(),
800            })
801        }
802    }
803
804    /// Get MCP server status information
805    ///
806    /// Returns the current status of all connected MCP servers.
807    /// Requires control protocol to be enabled.
808    pub async fn get_mcp_status(&mut self) -> Result<serde_json::Value> {
809        {
810            let state = self.state.read().await;
811            if *state != ClientState::Connected {
812                return Err(SdkError::InvalidState {
813                    message: "Not connected. Call connect() first.".into(),
814                });
815            }
816        }
817
818        if let Some(ref query_handler) = self.query_handler {
819            let mut handler = query_handler.lock().await;
820            handler.get_mcp_status().await
821        } else {
822            Err(SdkError::InvalidState {
823                message: "Query handler not initialized. Enable control protocol features.".to_string(),
824            })
825        }
826    }
827
828    /// Reconnect a failed MCP server
829    ///
830    /// Attempts to re-establish connection to the specified MCP server.
831    /// Requires control protocol to be enabled.
832    pub async fn reconnect_mcp_server(&mut self, server_name: &str) -> Result<()> {
833        {
834            let state = self.state.read().await;
835            if *state != ClientState::Connected {
836                return Err(SdkError::InvalidState {
837                    message: "Not connected. Call connect() first.".into(),
838                });
839            }
840        }
841
842        if let Some(ref query_handler) = self.query_handler {
843            let mut handler = query_handler.lock().await;
844            handler.reconnect_mcp_server(server_name).await
845        } else {
846            Err(SdkError::InvalidState {
847                message: "Query handler not initialized. Enable control protocol features.".to_string(),
848            })
849        }
850    }
851
852    /// Toggle an MCP server on/off
853    ///
854    /// Enables or disables the specified MCP server.
855    /// Requires control protocol to be enabled.
856    pub async fn toggle_mcp_server(&mut self, server_name: &str, enabled: bool) -> Result<()> {
857        {
858            let state = self.state.read().await;
859            if *state != ClientState::Connected {
860                return Err(SdkError::InvalidState {
861                    message: "Not connected. Call connect() first.".into(),
862                });
863            }
864        }
865
866        if let Some(ref query_handler) = self.query_handler {
867            let mut handler = query_handler.lock().await;
868            handler.toggle_mcp_server(server_name, enabled).await
869        } else {
870            Err(SdkError::InvalidState {
871                message: "Query handler not initialized. Enable control protocol features.".to_string(),
872            })
873        }
874    }
875
876    /// Disconnect from Claude CLI
877    pub async fn disconnect(&mut self) -> Result<()> {
878        // Check if already disconnected
879        {
880            let state = self.state.read().await;
881            if *state == ClientState::Disconnected {
882                return Ok(());
883            }
884        }
885
886        // Disconnect transport
887        {
888            let mut transport = self.transport.lock().await;
889            transport.disconnect().await?;
890        }
891
892        // Update state
893        {
894            let mut state = self.state.write().await;
895            *state = ClientState::Disconnected;
896        }
897
898        // Clear sessions
899        {
900            let mut sessions = self.sessions.write().await;
901            sessions.clear();
902        }
903
904        info!("Disconnected from Claude CLI");
905        Ok(())
906    }
907
908    /// Start the message receiver task
909    async fn start_message_receiver(&mut self) {
910        let transport = self.transport.clone();
911        let message_tx = self.message_tx.clone();
912        let message_buffer = self.message_buffer.clone();
913        let state = self.state.clone();
914        let budget_manager = self.budget_manager.clone();
915
916        tokio::spawn(async move {
917            // Subscribe to messages without holding the lock
918            let mut stream = {
919                let mut transport = transport.lock().await;
920                transport.receive_messages()
921            }; // Lock is released here immediately
922
923            while let Some(result) = stream.next().await {
924                match result {
925                    Ok(message) => {
926                        // Update token usage for Result messages
927                        if let Message::Result { .. } = &message
928                            && let Message::Result { usage, total_cost_usd, .. } = &message {
929                                let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
930                                    let input = usage_json.get("input_tokens")
931                                        .and_then(|v| v.as_u64())
932                                        .unwrap_or(0);
933                                    let output = usage_json.get("output_tokens")
934                                        .and_then(|v| v.as_u64())
935                                        .unwrap_or(0);
936                                    (input, output)
937                                } else {
938                                    (0, 0)
939                                };
940                                let cost = total_cost_usd.unwrap_or(0.0);
941                                budget_manager.update_usage(input_tokens, output_tokens, cost).await;
942                            }
943
944                        // Buffer init messages for get_server_info()
945                        if let Message::System { subtype, .. } = &message
946                            && subtype == "init" {
947                                let mut buffer = message_buffer.lock().await;
948                                buffer.push(message.clone());
949                            }
950
951                        // Try to send to current receiver
952                        let sent = {
953                            let mut tx_opt = message_tx.lock().await;
954                            if let Some(tx) = tx_opt.as_mut() {
955                                tx.send(Ok(message.clone())).await.is_ok()
956                            } else {
957                                false
958                            }
959                        };
960
961                        // If no receiver or send failed, buffer the message
962                        if !sent {
963                            let mut buffer = message_buffer.lock().await;
964                            buffer.push(message);
965                        }
966                    }
967                    Err(e) => {
968                        error!("Error receiving message: {}", e);
969
970                        // Send error to receiver if available
971                        let mut tx_opt = message_tx.lock().await;
972                        if let Some(tx) = tx_opt.as_mut() {
973                            let _ = tx.send(Err(e)).await;
974                        }
975
976                        // Update state on error
977                        let mut state = state.write().await;
978                        *state = ClientState::Error;
979                        break;
980                    }
981                }
982            }
983
984            debug!("Message receiver task ended");
985        });
986    }
987
988    /// Get token usage statistics
989    ///
990    /// Returns the current token usage tracker with cumulative statistics
991    /// for all queries executed by this client.
992    pub async fn get_usage_stats(&self) -> crate::token_tracker::TokenUsageTracker {
993        self.budget_manager.get_usage().await
994    }
995
996    /// Set budget limit with optional warning callback
997    ///
998    /// # Arguments
999    ///
1000    /// * `limit` - Budget limit configuration (cost and/or token caps)
1001    /// * `on_warning` - Optional callback function triggered when usage exceeds warning threshold
1002    ///
1003    /// # Example
1004    ///
1005    /// ```rust,no_run
1006    /// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
1007    /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
1008    /// use std::sync::Arc;
1009    ///
1010    /// # async fn example() {
1011    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
1012    ///
1013    /// // Set budget with callback
1014    /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Budget warning: {}", msg));
1015    /// client.set_budget_limit(BudgetLimit::with_cost(5.0), Some(cb)).await;
1016    /// # }
1017    /// ```
1018    pub async fn set_budget_limit(
1019        &self,
1020        limit: crate::token_tracker::BudgetLimit,
1021        on_warning: Option<crate::token_tracker::BudgetWarningCallback>,
1022    ) {
1023        self.budget_manager.set_limit(limit).await;
1024        if let Some(callback) = on_warning {
1025            self.budget_manager.set_warning_callback(callback).await;
1026        }
1027    }
1028
1029    /// Clear budget limit and reset warning state
1030    pub async fn clear_budget_limit(&self) {
1031        self.budget_manager.clear_limit().await;
1032    }
1033
1034    /// Reset token usage statistics to zero
1035    ///
1036    /// Clears all accumulated token and cost statistics.
1037    /// Budget limits remain in effect.
1038    pub async fn reset_usage_stats(&self) {
1039        self.budget_manager.reset_usage().await;
1040    }
1041
1042    /// Check if budget has been exceeded
1043    ///
1044    /// Returns true if current usage exceeds any configured limits
1045    pub async fn is_budget_exceeded(&self) -> bool {
1046        self.budget_manager.is_exceeded().await
1047    }
1048
1049    // Removed unused helper; usage is updated inline in message receiver
1050}
1051
1052impl Drop for ClaudeSDKClient {
1053    fn drop(&mut self) {
1054        // Try to disconnect gracefully
1055        let transport = self.transport.clone();
1056        let state = self.state.clone();
1057
1058        if let Ok(handle) = tokio::runtime::Handle::try_current() {
1059            handle.spawn(async move {
1060                let state = state.read().await;
1061                if *state == ClientState::Connected {
1062                    let mut transport = transport.lock().await;
1063                    if let Err(e) = transport.disconnect().await {
1064                        debug!("Error disconnecting in drop: {}", e);
1065                    }
1066                }
1067            });
1068        }
1069    }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fresh client from default options; no connection is made.
    fn default_client() -> ClaudeSDKClient {
        ClaudeSDKClient::new(ClaudeCodeOptions::default())
    }

    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = default_client();

        // A new client starts out disconnected with no sessions.
        assert!(!client.is_connected().await);
        assert_eq!(client.get_sessions().await.len(), 0);
    }

    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = default_client();

        assert_eq!(*client.state.read().await, ClientState::Disconnected);
    }

    #[test]
    fn test_file_checkpointing_enables_query_handler() {
        let client = ClaudeSDKClient::new(
            ClaudeCodeOptions::builder()
                .enable_file_checkpointing(true)
                .build(),
        );

        assert!(
            client.query_handler.is_some(),
            "enable_file_checkpointing should initialize the query handler for control protocol requests"
        );
    }
}