cc_sdk/
client.rs

1//! Interactive client for bidirectional communication with Claude
2//!
3//! This module provides the `ClaudeSDKClient` for interactive, stateful
4//! conversations with Claude Code CLI.
5
6use crate::{
7    errors::{Result, SdkError},
8    internal_query::Query,
9    token_tracker::BudgetManager,
10    transport::{InputMessage, SubprocessTransport, Transport},
11    types::{ClaudeCodeOptions, ContentBlock, ControlRequest, ControlResponse, Message},
12};
13use futures::stream::{Stream, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::pin::Pin;
17use tokio::sync::{Mutex, RwLock, mpsc};
18use tokio_stream::wrappers::ReceiverStream;
19use tracing::{debug, error, info};
20
/// Connection lifecycle state of a [`ClaudeSDKClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Not connected: the state a new client starts in, and the state set by `disconnect()`.
    Disconnected,
    /// Connected and ready: set by `connect()` once the transport is connected.
    Connected,
    /// Error state: set by the background message receiver when the transport yields an error.
    Error,
}
31
/// Interactive client for bidirectional communication with Claude
///
/// `ClaudeSDKClient` provides a stateful, interactive interface for communicating
/// with Claude Code CLI. Unlike the simple `query` function, this client supports:
///
/// - Bidirectional communication
/// - Multiple sessions
/// - Interrupt capabilities
/// - State management
/// - Follow-up messages based on responses
///
/// # Example
///
/// ```rust,no_run
/// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, Message, Result};
/// use futures::StreamExt;
///
/// #[tokio::main]
/// async fn main() -> Result<()> {
///     let options = ClaudeCodeOptions::builder()
///         .system_prompt("You are a helpful assistant")
///         .model("claude-3-opus-20240229")
///         .build();
///
///     let mut client = ClaudeSDKClient::new(options);
///
///     // Connect with initial prompt
///     client.connect(Some("Hello!".to_string())).await?;
///
///     // Receive initial response
///     let mut messages = client.receive_messages().await;
///     while let Some(msg) = messages.next().await {
///         match msg? {
///             Message::Result { .. } => break,
///             msg => println!("{:?}", msg),
///         }
///     }
///
///     // Send follow-up
///     client.send_request("What's 2 + 2?".to_string(), None).await?;
///
///     // Receive response. The stream from `receive_messages` does not stop
///     // at a `Message::Result` by itself, so break explicitly.
///     let mut messages = client.receive_messages().await;
///     while let Some(msg) = messages.next().await {
///         let msg = msg?;
///         println!("{:?}", msg);
///         if matches!(msg, Message::Result { .. }) {
///             break;
///         }
///     }
///
///     // Disconnect
///     client.disconnect().await?;
///
///     Ok(())
/// }
/// ```
pub struct ClaudeSDKClient {
    /// Configuration options the client was built with
    #[allow(dead_code)]
    options: ClaudeCodeOptions,
    /// Transport layer, shared with the query handler and the background receiver task
    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
    /// Internal query handler; `Some` only when a control protocol feature
    /// (`can_use_tool`, `hooks`, `mcp_servers` or file checkpointing) is enabled
    query_handler: Option<Arc<Mutex<Query>>>,
    /// Current connection state
    state: Arc<RwLock<ClientState>>,
    /// Sessions that have sent at least one message, keyed by session ID
    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
    /// Sender feeding the stream most recently handed out by `receive_messages`
    /// (`None` until the first call)
    message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
    /// Messages that arrived while no receiver could take them, plus `init`
    /// system messages kept for `get_server_info`
    message_buffer: Arc<Mutex<Vec<Message>>>,
    /// Counter used to build unique interrupt request IDs
    request_counter: Arc<Mutex<u64>>,
    /// Budget manager for token tracking
    budget_manager: BudgetManager,
}
106
/// Per-session bookkeeping stored in the client's `sessions` map
// NOTE(review): the fields are written but apparently never read back in the
// visible part of this file, which is presumably why dead_code is allowed.
#[allow(dead_code)]
struct SessionData {
    /// Session ID (the same value is used as the map key)
    id: String,
    /// Number of user messages sent on this session
    message_count: usize,
    /// Time at which the session entry was created
    created_at: std::time::Instant,
}
117
118impl ClaudeSDKClient {
119    /// Create a new client with the given options
120    pub fn new(options: ClaudeCodeOptions) -> Self {
121        // Set environment variable to indicate SDK usage
122        unsafe {
123            std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
124        }
125
126        let transport = match SubprocessTransport::new(options.clone()) {
127            Ok(t) => t,
128            Err(e) => {
129                error!("Failed to create transport: {}", e);
130                // Create with empty path, will fail on connect
131                SubprocessTransport::with_cli_path(options.clone(), "")
132            }
133        };
134
135        // Wrap transport in Arc for sharing
136        let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
137            Arc::new(Mutex::new(Box::new(transport)));
138
139        Self::with_transport_internal(options, transport_arc)
140    }
141
142    /// Create a new client with a custom transport implementation
143    ///
144    /// This allows users to provide their own Transport implementation instead of
145    /// using the default SubprocessTransport. Useful for testing, custom CLI paths,
146    /// or alternative communication mechanisms.
147    ///
148    /// # Arguments
149    ///
150    /// * `options` - Configuration options for the client
151    /// * `transport` - Custom transport implementation
152    ///
153    /// # Example
154    ///
155    /// ```rust,no_run
156    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, SubprocessTransport};
157    /// # fn example() {
158    /// let options = ClaudeCodeOptions::default();
159    /// let transport = SubprocessTransport::with_cli_path(options.clone(), "/custom/path/claude-code");
160    /// let client = ClaudeSDKClient::with_transport(options, Box::new(transport));
161    /// # }
162    /// ```
163    pub fn with_transport(options: ClaudeCodeOptions, transport: Box<dyn Transport + Send>) -> Self {
164        // Set environment variable to indicate SDK usage
165        unsafe {
166            std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
167        }
168
169        // Wrap transport in Arc for sharing
170        let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
171            Arc::new(Mutex::new(transport));
172
173        Self::with_transport_internal(options, transport_arc)
174    }
175
176    /// Internal helper to construct client with pre-wrapped transport
177    fn with_transport_internal(
178        options: ClaudeCodeOptions,
179        transport_arc: Arc<Mutex<Box<dyn Transport + Send>>>,
180    ) -> Self {
181        // Create query handler if control protocol features are enabled
182        let query_handler = if options.can_use_tool.is_some()
183            || options.hooks.is_some()
184            || !options.mcp_servers.is_empty()
185            || options.enable_file_checkpointing {
186            // Extract SDK MCP server instances
187            let sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>> = options.mcp_servers
188                .iter()
189                .filter_map(|(k, v)| {
190                    // Only extract SDK type MCP servers
191                    if let crate::types::McpServerConfig::Sdk { name: _, instance } = v {
192                        Some((k.clone(), instance.clone()))
193                    } else {
194                        None
195                    }
196                })
197                .collect();
198
199            // Enable streaming mode when control protocol is active
200            let is_streaming = options.can_use_tool.is_some()
201                || options.hooks.is_some()
202                || !sdk_mcp_servers.is_empty();
203
204            let query = Query::new(
205                transport_arc.clone(), // Share the same transport
206                is_streaming, // Enable streaming for control protocol
207                options.can_use_tool.clone(),
208                options.hooks.clone(),
209                sdk_mcp_servers,
210            );
211            Some(Arc::new(Mutex::new(query)))
212        } else {
213            None
214        };
215
216        Self {
217            options,
218            transport: transport_arc,
219            query_handler,
220            state: Arc::new(RwLock::new(ClientState::Disconnected)),
221            sessions: Arc::new(RwLock::new(HashMap::new())),
222            message_tx: Arc::new(Mutex::new(None)),
223            message_buffer: Arc::new(Mutex::new(Vec::new())),
224            request_counter: Arc::new(Mutex::new(0)),
225            budget_manager: BudgetManager::new(),
226        }
227    }
228
229    /// Connect to Claude CLI with an optional initial prompt
230    pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
231        // Check if already connected
232        {
233            let state = self.state.read().await;
234            if *state == ClientState::Connected {
235                return Ok(());
236            }
237        }
238
239        // Connect transport
240        {
241            let mut transport = self.transport.lock().await;
242            transport.connect().await?;
243        }
244
245        // Initialize query handler if present
246        if let Some(ref query_handler) = self.query_handler {
247            let mut handler = query_handler.lock().await;
248            handler.start().await?;
249            handler.initialize().await?;
250            info!("Initialized SDK control protocol");
251        }
252
253        // Update state
254        {
255            let mut state = self.state.write().await;
256            *state = ClientState::Connected;
257        }
258
259        info!("Connected to Claude CLI");
260
261        // Start message receiver task (always needed for regular messages)
262        self.start_message_receiver().await;
263
264        // Send initial prompt if provided
265        if let Some(prompt) = initial_prompt {
266            self.send_request(prompt, None).await?;
267        }
268
269        Ok(())
270    }
271
272    /// Send a user message to Claude
273    pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
274        // Check connection
275        {
276            let state = self.state.read().await;
277            if *state != ClientState::Connected {
278                return Err(SdkError::InvalidState {
279                    message: "Not connected".into(),
280                });
281            }
282        }
283
284        // Use default session ID
285        let session_id = "default".to_string();
286
287        // Update session data
288        {
289            let mut sessions = self.sessions.write().await;
290            let session = sessions.entry(session_id.clone()).or_insert_with(|| {
291                debug!("Creating new session: {}", session_id);
292                SessionData {
293                    id: session_id.clone(),
294                    message_count: 0,
295                    created_at: std::time::Instant::now(),
296                }
297            });
298            session.message_count += 1;
299        }
300
301        // Create and send message
302        let message = InputMessage::user(prompt, session_id.clone());
303
304        {
305            let mut transport = self.transport.lock().await;
306            transport.send_message(message).await?;
307        }
308
309        debug!("Sent request to Claude");
310        Ok(())
311    }
312
313    /// Send a request to Claude (alias for send_user_message with optional session_id)
314    pub async fn send_request(
315        &mut self,
316        prompt: String,
317        _session_id: Option<String>,
318    ) -> Result<()> {
319        // For now, ignore session_id and use send_user_message
320        self.send_user_message(prompt).await
321    }
322
323    /// Receive messages from Claude
324    ///
325    /// Returns a stream of messages. The stream will end when a Result message
326    /// is received or the connection is closed.
327    pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
328        // Always use the regular message receiver
329        // (Query handler shares the same transport and receives control messages separately)
330        // Create a new channel for this receiver
331        let (tx, rx) = mpsc::channel(100);
332
333        // Get buffered messages and clear buffer
334        let buffered_messages = {
335            let mut buffer = self.message_buffer.lock().await;
336            std::mem::take(&mut *buffer)
337        };
338
339        // Send buffered messages to the new receiver
340        let tx_clone = tx.clone();
341        tokio::spawn(async move {
342            for msg in buffered_messages {
343                if tx_clone.send(Ok(msg)).await.is_err() {
344                    break;
345                }
346            }
347        });
348
349        // Store the sender for the message receiver task
350        {
351            let mut message_tx = self.message_tx.lock().await;
352            *message_tx = Some(tx);
353        }
354
355        ReceiverStream::new(rx)
356    }
357
358    /// Send an interrupt request
359    pub async fn interrupt(&mut self) -> Result<()> {
360        // Check connection
361        {
362            let state = self.state.read().await;
363            if *state != ClientState::Connected {
364                return Err(SdkError::InvalidState {
365                    message: "Not connected".into(),
366                });
367            }
368        }
369
370        // If we have a query handler, use it
371        if let Some(ref query_handler) = self.query_handler {
372            let mut handler = query_handler.lock().await;
373            return handler.interrupt().await;
374        }
375
376        // Otherwise use regular interrupt
377        // Generate request ID
378        let request_id = {
379            let mut counter = self.request_counter.lock().await;
380            *counter += 1;
381            format!("interrupt_{}", *counter)
382        };
383
384        // Send interrupt request
385        let request = ControlRequest::Interrupt {
386            request_id: request_id.clone(),
387        };
388
389        {
390            let mut transport = self.transport.lock().await;
391            transport.send_control_request(request).await?;
392        }
393
394        info!("Sent interrupt request: {}", request_id);
395
396        // Wait for acknowledgment (with timeout)
397        let transport = self.transport.clone();
398        let ack_task = tokio::spawn(async move {
399            let mut transport = transport.lock().await;
400            match tokio::time::timeout(
401                std::time::Duration::from_secs(5),
402                transport.receive_control_response(),
403            )
404            .await
405            {
406                Ok(Ok(Some(ControlResponse::InterruptAck {
407                    request_id: ack_id,
408                    success,
409                }))) => {
410                    if ack_id == request_id && success {
411                        Ok(())
412                    } else {
413                        Err(SdkError::ControlRequestError(
414                            "Interrupt not acknowledged successfully".into(),
415                        ))
416                    }
417                }
418                Ok(Ok(None)) => Err(SdkError::ControlRequestError(
419                    "No interrupt acknowledgment received".into(),
420                )),
421                Ok(Err(e)) => Err(e),
422                Err(_) => Err(SdkError::timeout(5)),
423            }
424        });
425
426        ack_task
427            .await
428            .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
429    }
430
431    /// Check if the client is connected
432    pub async fn is_connected(&self) -> bool {
433        let state = self.state.read().await;
434        *state == ClientState::Connected
435    }
436
437    /// Get active session IDs
438    pub async fn get_sessions(&self) -> Vec<String> {
439        let sessions = self.sessions.read().await;
440        sessions.keys().cloned().collect()
441    }
442
443    /// Receive messages until and including a ResultMessage
444    ///
445    /// This is a convenience method that collects all messages from a single response.
446    /// It will automatically stop after receiving a ResultMessage.
447    pub async fn receive_response(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
448        let mut messages = self.receive_messages().await;
449        
450        // Create a stream that stops after ResultMessage
451        Box::pin(async_stream::stream! {
452            while let Some(msg_result) = messages.next().await {
453                match &msg_result {
454                    Ok(Message::Result { .. }) => {
455                        yield msg_result;
456                        return;
457                    }
458                    _ => {
459                        yield msg_result;
460                    }
461                }
462            }
463        })
464    }
465
466    /// Get server information
467    ///
468    /// Returns initialization information from the Claude Code server including:
469    /// - Available commands
470    /// - Current and available output styles
471    /// - Server capabilities
472    pub async fn get_server_info(&self) -> Option<serde_json::Value> {
473        // If we have a query handler with control protocol, get from there
474        if let Some(ref query_handler) = self.query_handler {
475            let handler = query_handler.lock().await;
476            if let Some(init_result) = handler.get_initialization_result() {
477                return Some(init_result.clone());
478            }
479        }
480
481        // Otherwise check message buffer for init message
482        let buffer = self.message_buffer.lock().await;
483        for msg in buffer.iter() {
484            if let Message::System { subtype, data } = msg
485                && subtype == "init" {
486                    return Some(data.clone());
487                }
488        }
489        None
490    }
491
492    /// Get account information
493    ///
494    /// This method attempts to retrieve Claude account information through multiple methods:
495    /// 1. From environment variable `ANTHROPIC_USER_EMAIL`
496    /// 2. From Claude CLI config file (if accessible)
497    /// 3. By querying the CLI with `/status` command (interactive mode)
498    ///
499    /// # Returns
500    ///
501    /// A string containing the account information, or an error if unavailable.
502    ///
503    /// # Example
504    ///
505    /// ```rust,no_run
506    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
507    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
508    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
509    /// client.connect(None).await?;
510    ///
511    /// match client.get_account_info().await {
512    ///     Ok(info) => println!("Account: {}", info),
513    ///     Err(_) => println!("Account info not available"),
514    /// }
515    /// # Ok(())
516    /// # }
517    /// ```
518    ///
519    /// # Note
520    ///
521    /// Account information may not always be available in SDK mode.
522    /// Consider setting the `ANTHROPIC_USER_EMAIL` environment variable
523    /// for reliable account identification.
524    pub async fn get_account_info(&mut self) -> Result<String> {
525        // Check connection
526        {
527            let state = self.state.read().await;
528            if *state != ClientState::Connected {
529                return Err(SdkError::InvalidState {
530                    message: "Not connected. Call connect() first.".into(),
531                });
532            }
533        }
534
535        // Method 1: Check environment variable
536        if let Ok(email) = std::env::var("ANTHROPIC_USER_EMAIL") {
537            return Ok(format!("Email: {}", email));
538        }
539
540        // Method 2: Try reading from Claude config
541        if let Some(config_info) = Self::read_claude_config().await {
542            return Ok(config_info);
543        }
544
545        // Method 3: Try /status command (may not work in SDK mode)
546        self.send_user_message("/status".to_string()).await?;
547
548        let mut messages = self.receive_messages().await;
549        let mut account_info = String::new();
550
551        while let Some(msg_result) = messages.next().await {
552            match msg_result? {
553                Message::Assistant { message } => {
554                    for block in message.content {
555                        if let ContentBlock::Text(text) = block {
556                            account_info.push_str(&text.text);
557                            account_info.push('\n');
558                        }
559                    }
560                }
561                Message::Result { .. } => break,
562                _ => {}
563            }
564        }
565
566        let trimmed = account_info.trim();
567
568        // Check if we got actual status info or just a chat response
569        if !trimmed.is_empty() && (
570            trimmed.contains("account") ||
571            trimmed.contains("email") ||
572            trimmed.contains("subscription") ||
573            trimmed.contains("authenticated")
574        ) {
575            return Ok(trimmed.to_string());
576        }
577
578        Err(SdkError::InvalidState {
579            message: "Account information not available. Try setting ANTHROPIC_USER_EMAIL environment variable.".into(),
580        })
581    }
582
583    /// Read Claude config file
584    async fn read_claude_config() -> Option<String> {
585        // Try common config locations
586        let config_paths = vec![
587            dirs::home_dir()?.join(".config").join("claude").join("config.json"),
588            dirs::home_dir()?.join(".claude").join("config.json"),
589        ];
590
591        for path in config_paths {
592            if let Ok(content) = tokio::fs::read_to_string(&path).await {
593                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
594                    if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
595                        return Some(format!("Email: {}", email));
596                    }
597                    if let Some(user) = json.get("user").and_then(|v| v.as_str()) {
598                        return Some(format!("User: {}", user));
599                    }
600                }
601            }
602        }
603
604        None
605    }
606
607    /// Set permission mode dynamically
608    ///
609    /// Changes the permission mode during an active session.
610    /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
611    ///
612    /// # Arguments
613    ///
614    /// * `mode` - Permission mode: "default", "acceptEdits", "plan", or "bypassPermissions"
615    ///
616    /// # Example
617    ///
618    /// ```rust,no_run
619    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
620    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
621    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
622    /// client.connect(None).await?;
623    ///
624    /// // Switch to accept edits mode
625    /// client.set_permission_mode("acceptEdits").await?;
626    /// # Ok(())
627    /// # }
628    /// ```
629    pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
630        if let Some(ref query_handler) = self.query_handler {
631            let mut handler = query_handler.lock().await;
632            handler.set_permission_mode(mode).await
633        } else {
634            Err(SdkError::InvalidState {
635                message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
636            })
637        }
638    }
639
640    /// Set model dynamically
641    ///
642    /// Changes the active model during an active session.
643    /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
644    ///
645    /// # Arguments
646    ///
647    /// * `model` - Model identifier (e.g., "claude-3-5-sonnet-20241022") or None to use default
648    ///
649    /// # Example
650    ///
651    /// ```rust,no_run
652    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
653    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
654    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
655    /// client.connect(None).await?;
656    ///
657    /// // Switch to a different model
658    /// client.set_model(Some("claude-3-5-sonnet-20241022".to_string())).await?;
659    /// # Ok(())
660    /// # }
661    /// ```
662    pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
663        if let Some(ref query_handler) = self.query_handler {
664            let mut handler = query_handler.lock().await;
665            handler.set_model(model).await
666        } else {
667            Err(SdkError::InvalidState {
668                message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
669            })
670        }
671    }
672
673    /// Send a query with optional session ID
674    ///
675    /// This method is similar to Python SDK's query method in ClaudeSDKClient
676    pub async fn query(&mut self, prompt: String, session_id: Option<String>) -> Result<()> {
677        let session_id = session_id.unwrap_or_else(|| "default".to_string());
678
679        // Send the message
680        let message = InputMessage::user(prompt, session_id);
681
682        {
683            let mut transport = self.transport.lock().await;
684            transport.send_message(message).await?;
685        }
686
687        Ok(())
688    }
689
690    /// Rewind tracked files to their state at a specific user message
691    ///
692    /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
693    /// This method allows you to undo file changes made during the session by
694    /// reverting them to their state at any previous user message checkpoint.
695    ///
696    /// # Arguments
697    ///
698    /// * `user_message_id` - UUID of the user message to rewind to. This should be
699    ///   the `uuid` field from a message received during the conversation.
700    ///
701    /// # Example
702    ///
703    /// ```rust,no_run
704    /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
705    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
706    /// let options = ClaudeCodeOptions::builder()
707    ///     .enable_file_checkpointing(true)
708    ///     .build();
709    /// let mut client = ClaudeSDKClient::new(options);
710    /// client.connect(None).await?;
711    ///
712    /// // Ask Claude to make some changes
713    /// client.send_request("Make some changes to my files".to_string(), None).await?;
714    ///
715    /// // ... later, rewind to a checkpoint
716    /// // client.rewind_files("user-message-uuid-here").await?;
717    /// # Ok(())
718    /// # }
719    /// ```
720    ///
721    /// # Errors
722    ///
723    /// Returns an error if:
724    /// - The client is not connected
725    /// - The query handler is not initialized (control protocol required)
726    /// - File checkpointing is not enabled
727    /// - The specified user_message_id is invalid
728    pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
729        // Check connection
730        {
731            let state = self.state.read().await;
732            if *state != ClientState::Connected {
733                return Err(SdkError::InvalidState {
734                    message: "Not connected. Call connect() first.".into(),
735                });
736            }
737        }
738
739        if !self.options.enable_file_checkpointing {
740            return Err(SdkError::InvalidState {
741                message: "File checkpointing is not enabled. Set ClaudeCodeOptions::builder().enable_file_checkpointing(true).".to_string(),
742            });
743        }
744
745        // Require query handler for control protocol
746        if let Some(ref query_handler) = self.query_handler {
747            let mut handler = query_handler.lock().await;
748            handler.rewind_files(user_message_id).await
749        } else {
750            Err(SdkError::InvalidState {
751                message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
752            })
753        }
754    }
755
756    /// Disconnect from Claude CLI
757    pub async fn disconnect(&mut self) -> Result<()> {
758        // Check if already disconnected
759        {
760            let state = self.state.read().await;
761            if *state == ClientState::Disconnected {
762                return Ok(());
763            }
764        }
765
766        // Disconnect transport
767        {
768            let mut transport = self.transport.lock().await;
769            transport.disconnect().await?;
770        }
771
772        // Update state
773        {
774            let mut state = self.state.write().await;
775            *state = ClientState::Disconnected;
776        }
777
778        // Clear sessions
779        {
780            let mut sessions = self.sessions.write().await;
781            sessions.clear();
782        }
783
784        info!("Disconnected from Claude CLI");
785        Ok(())
786    }
787
788    /// Start the message receiver task
789    async fn start_message_receiver(&mut self) {
790        let transport = self.transport.clone();
791        let message_tx = self.message_tx.clone();
792        let message_buffer = self.message_buffer.clone();
793        let state = self.state.clone();
794        let budget_manager = self.budget_manager.clone();
795
796        tokio::spawn(async move {
797            // Subscribe to messages without holding the lock
798            let mut stream = {
799                let mut transport = transport.lock().await;
800                transport.receive_messages()
801            }; // Lock is released here immediately
802
803            while let Some(result) = stream.next().await {
804                match result {
805                    Ok(message) => {
806                        // Update token usage for Result messages
807                        if let Message::Result { .. } = &message
808                            && let Message::Result { usage, total_cost_usd, .. } = &message {
809                                let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
810                                    let input = usage_json.get("input_tokens")
811                                        .and_then(|v| v.as_u64())
812                                        .unwrap_or(0);
813                                    let output = usage_json.get("output_tokens")
814                                        .and_then(|v| v.as_u64())
815                                        .unwrap_or(0);
816                                    (input, output)
817                                } else {
818                                    (0, 0)
819                                };
820                                let cost = total_cost_usd.unwrap_or(0.0);
821                                budget_manager.update_usage(input_tokens, output_tokens, cost).await;
822                            }
823
824                        // Buffer init messages for get_server_info()
825                        if let Message::System { subtype, .. } = &message
826                            && subtype == "init" {
827                                let mut buffer = message_buffer.lock().await;
828                                buffer.push(message.clone());
829                            }
830
831                        // Try to send to current receiver
832                        let sent = {
833                            let mut tx_opt = message_tx.lock().await;
834                            if let Some(tx) = tx_opt.as_mut() {
835                                tx.send(Ok(message.clone())).await.is_ok()
836                            } else {
837                                false
838                            }
839                        };
840
841                        // If no receiver or send failed, buffer the message
842                        if !sent {
843                            let mut buffer = message_buffer.lock().await;
844                            buffer.push(message);
845                        }
846                    }
847                    Err(e) => {
848                        error!("Error receiving message: {}", e);
849
850                        // Send error to receiver if available
851                        let mut tx_opt = message_tx.lock().await;
852                        if let Some(tx) = tx_opt.as_mut() {
853                            let _ = tx.send(Err(e)).await;
854                        }
855
856                        // Update state on error
857                        let mut state = state.write().await;
858                        *state = ClientState::Error;
859                        break;
860                    }
861                }
862            }
863
864            debug!("Message receiver task ended");
865        });
866    }
867
868    /// Get token usage statistics
869    ///
870    /// Returns the current token usage tracker with cumulative statistics
871    /// for all queries executed by this client.
872    pub async fn get_usage_stats(&self) -> crate::token_tracker::TokenUsageTracker {
873        self.budget_manager.get_usage().await
874    }
875
876    /// Set budget limit with optional warning callback
877    ///
878    /// # Arguments
879    ///
880    /// * `limit` - Budget limit configuration (cost and/or token caps)
881    /// * `on_warning` - Optional callback function triggered when usage exceeds warning threshold
882    ///
883    /// # Example
884    ///
885    /// ```rust,no_run
886    /// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
887    /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
888    /// use std::sync::Arc;
889    ///
890    /// # async fn example() {
891    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
892    ///
893    /// // Set budget with callback
894    /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Budget warning: {}", msg));
895    /// client.set_budget_limit(BudgetLimit::with_cost(5.0), Some(cb)).await;
896    /// # }
897    /// ```
898    pub async fn set_budget_limit(
899        &self,
900        limit: crate::token_tracker::BudgetLimit,
901        on_warning: Option<crate::token_tracker::BudgetWarningCallback>,
902    ) {
903        self.budget_manager.set_limit(limit).await;
904        if let Some(callback) = on_warning {
905            self.budget_manager.set_warning_callback(callback).await;
906        }
907    }
908
909    /// Clear budget limit and reset warning state
910    pub async fn clear_budget_limit(&self) {
911        self.budget_manager.clear_limit().await;
912    }
913
914    /// Reset token usage statistics to zero
915    ///
916    /// Clears all accumulated token and cost statistics.
917    /// Budget limits remain in effect.
918    pub async fn reset_usage_stats(&self) {
919        self.budget_manager.reset_usage().await;
920    }
921
922    /// Check if budget has been exceeded
923    ///
924    /// Returns true if current usage exceeds any configured limits
925    pub async fn is_budget_exceeded(&self) -> bool {
926        self.budget_manager.is_exceeded().await
927    }
928
929    // Removed unused helper; usage is updated inline in message receiver
930}
931
932impl Drop for ClaudeSDKClient {
933    fn drop(&mut self) {
934        // Try to disconnect gracefully
935        let transport = self.transport.clone();
936        let state = self.state.clone();
937
938        if let Ok(handle) = tokio::runtime::Handle::try_current() {
939            handle.spawn(async move {
940                let state = state.read().await;
941                if *state == ClientState::Connected {
942                    let mut transport = transport.lock().await;
943                    if let Err(e) = transport.disconnect().await {
944                        debug!("Error disconnecting in drop: {}", e);
945                    }
946                }
947            });
948        }
949    }
950}
951
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client reports no connection and no sessions.
    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let connected = client.is_connected().await;
        assert!(!connected);

        let sessions = client.get_sessions().await;
        assert_eq!(sessions.len(), 0);
    }

    /// The state of a new client starts out as `Disconnected`.
    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        assert_eq!(*client.state.read().await, ClientState::Disconnected);
    }

    /// Enabling file checkpointing must leave the client with a query handler.
    #[test]
    fn test_file_checkpointing_enables_query_handler() {
        let client = ClaudeSDKClient::new(
            ClaudeCodeOptions::builder()
                .enable_file_checkpointing(true)
                .build(),
        );

        let has_handler = client.query_handler.is_some();
        assert!(
            has_handler,
            "enable_file_checkpointing should initialize the query handler for control protocol requests"
        );
    }
}