cc_sdk/client.rs
1//! Interactive client for bidirectional communication with Claude
2//!
3//! This module provides the `ClaudeSDKClient` for interactive, stateful
4//! conversations with Claude Code CLI.
5
6use crate::{
7 errors::{Result, SdkError},
8 internal_query::Query,
9 token_tracker::BudgetManager,
10 transport::{InputMessage, SubprocessTransport, Transport},
11 types::{ClaudeCodeOptions, ContentBlock, ControlRequest, ControlResponse, Message},
12};
13use futures::stream::{Stream, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::pin::Pin;
17use tokio::sync::{Mutex, RwLock, mpsc};
18use tokio_stream::wrappers::ReceiverStream;
19use tracing::{debug, error, info};
20
/// Connection lifecycle of a [`ClaudeSDKClient`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientState {
    /// No live connection: the initial state, and the state after `disconnect`.
    Disconnected,
    /// The transport is connected and requests may be sent.
    Connected,
    /// The background message receiver reported a transport error.
    Error,
}
31
32/// Interactive client for bidirectional communication with Claude
33///
34/// `ClaudeSDKClient` provides a stateful, interactive interface for communicating
35/// with Claude Code CLI. Unlike the simple `query` function, this client supports:
36///
37/// - Bidirectional communication
38/// - Multiple sessions
39/// - Interrupt capabilities
40/// - State management
41/// - Follow-up messages based on responses
42///
43/// # Example
44///
45/// ```rust,no_run
46/// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, Message, Result};
47/// use futures::StreamExt;
48///
49/// #[tokio::main]
50/// async fn main() -> Result<()> {
51/// let options = ClaudeCodeOptions::builder()
52/// .system_prompt("You are a helpful assistant")
53/// .model("claude-3-opus-20240229")
54/// .build();
55///
56/// let mut client = ClaudeSDKClient::new(options);
57///
58/// // Connect with initial prompt
59/// client.connect(Some("Hello!".to_string())).await?;
60///
61/// // Receive initial response
62/// let mut messages = client.receive_messages().await;
63/// while let Some(msg) = messages.next().await {
64/// match msg? {
65/// Message::Result { .. } => break,
66/// msg => println!("{:?}", msg),
67/// }
68/// }
69///
70/// // Send follow-up
71/// client.send_request("What's 2 + 2?".to_string(), None).await?;
72///
73/// // Receive response
74/// let mut messages = client.receive_messages().await;
75/// while let Some(msg) = messages.next().await {
76/// println!("{:?}", msg?);
77/// }
78///
79/// // Disconnect
80/// client.disconnect().await?;
81///
82/// Ok(())
83/// }
84/// ```
85pub struct ClaudeSDKClient {
86 /// Configuration options
87 #[allow(dead_code)]
88 options: ClaudeCodeOptions,
89 /// Transport layer
90 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
91 /// Internal query handler (when control protocol is enabled)
92 query_handler: Option<Arc<Mutex<Query>>>,
93 /// Client state
94 state: Arc<RwLock<ClientState>>,
95 /// Active sessions
96 sessions: Arc<RwLock<HashMap<String, SessionData>>>,
97 /// Message sender for current receiver
98 message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
99 /// Message buffer for multiple receivers
100 message_buffer: Arc<Mutex<Vec<Message>>>,
101 /// Request counter
102 request_counter: Arc<Mutex<u64>>,
103 /// Budget manager for token tracking
104 budget_manager: BudgetManager,
105}
106
107/// Session data
108#[allow(dead_code)]
109struct SessionData {
110 /// Session ID
111 id: String,
112 /// Number of messages sent
113 message_count: usize,
114 /// Creation time
115 created_at: std::time::Instant,
116}
117
118impl ClaudeSDKClient {
119 /// Create a new client with the given options
120 pub fn new(options: ClaudeCodeOptions) -> Self {
121 // Set environment variable to indicate SDK usage
122 unsafe {
123 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
124 }
125
126 let transport = match SubprocessTransport::new(options.clone()) {
127 Ok(t) => t,
128 Err(e) => {
129 error!("Failed to create transport: {}", e);
130 // Create with empty path, will fail on connect
131 SubprocessTransport::with_cli_path(options.clone(), "")
132 }
133 };
134
135 // Wrap transport in Arc for sharing
136 let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
137 Arc::new(Mutex::new(Box::new(transport)));
138
139 Self::with_transport_internal(options, transport_arc)
140 }
141
142 /// Create a new client with a custom transport implementation
143 ///
144 /// This allows users to provide their own Transport implementation instead of
145 /// using the default SubprocessTransport. Useful for testing, custom CLI paths,
146 /// or alternative communication mechanisms.
147 ///
148 /// # Arguments
149 ///
150 /// * `options` - Configuration options for the client
151 /// * `transport` - Custom transport implementation
152 ///
153 /// # Example
154 ///
155 /// ```rust,no_run
156 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, SubprocessTransport};
157 /// # fn example() {
158 /// let options = ClaudeCodeOptions::default();
159 /// let transport = SubprocessTransport::with_cli_path(options.clone(), "/custom/path/claude-code");
160 /// let client = ClaudeSDKClient::with_transport(options, Box::new(transport));
161 /// # }
162 /// ```
163 pub fn with_transport(options: ClaudeCodeOptions, transport: Box<dyn Transport + Send>) -> Self {
164 // Set environment variable to indicate SDK usage
165 unsafe {
166 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
167 }
168
169 // Wrap transport in Arc for sharing
170 let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
171 Arc::new(Mutex::new(transport));
172
173 Self::with_transport_internal(options, transport_arc)
174 }
175
176 /// Internal helper to construct client with pre-wrapped transport
177 fn with_transport_internal(
178 options: ClaudeCodeOptions,
179 transport_arc: Arc<Mutex<Box<dyn Transport + Send>>>,
180 ) -> Self {
181 // Create query handler if control protocol features are enabled
182 let query_handler = if options.can_use_tool.is_some()
183 || options.hooks.is_some()
184 || !options.mcp_servers.is_empty()
185 || options.enable_file_checkpointing {
186 // Extract SDK MCP server instances
187 let sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>> = options.mcp_servers
188 .iter()
189 .filter_map(|(k, v)| {
190 // Only extract SDK type MCP servers
191 if let crate::types::McpServerConfig::Sdk { name: _, instance } = v {
192 Some((k.clone(), instance.clone()))
193 } else {
194 None
195 }
196 })
197 .collect();
198
199 // Enable streaming mode when control protocol is active
200 let is_streaming = options.can_use_tool.is_some()
201 || options.hooks.is_some()
202 || !sdk_mcp_servers.is_empty();
203
204 let query = Query::new(
205 transport_arc.clone(), // Share the same transport
206 is_streaming, // Enable streaming for control protocol
207 options.can_use_tool.clone(),
208 options.hooks.clone(),
209 sdk_mcp_servers,
210 );
211 Some(Arc::new(Mutex::new(query)))
212 } else {
213 None
214 };
215
216 Self {
217 options,
218 transport: transport_arc,
219 query_handler,
220 state: Arc::new(RwLock::new(ClientState::Disconnected)),
221 sessions: Arc::new(RwLock::new(HashMap::new())),
222 message_tx: Arc::new(Mutex::new(None)),
223 message_buffer: Arc::new(Mutex::new(Vec::new())),
224 request_counter: Arc::new(Mutex::new(0)),
225 budget_manager: BudgetManager::new(),
226 }
227 }
228
229 /// Connect to Claude CLI with an optional initial prompt
230 pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
231 // Check if already connected
232 {
233 let state = self.state.read().await;
234 if *state == ClientState::Connected {
235 return Ok(());
236 }
237 }
238
239 // Connect transport
240 {
241 let mut transport = self.transport.lock().await;
242 transport.connect().await?;
243 }
244
245 // Initialize query handler if present
246 if let Some(ref query_handler) = self.query_handler {
247 let mut handler = query_handler.lock().await;
248 handler.start().await?;
249 handler.initialize().await?;
250 info!("Initialized SDK control protocol");
251 }
252
253 // Update state
254 {
255 let mut state = self.state.write().await;
256 *state = ClientState::Connected;
257 }
258
259 info!("Connected to Claude CLI");
260
261 // Start message receiver task (always needed for regular messages)
262 self.start_message_receiver().await;
263
264 // Send initial prompt if provided
265 if let Some(prompt) = initial_prompt {
266 self.send_request(prompt, None).await?;
267 }
268
269 Ok(())
270 }
271
272 /// Send a user message to Claude
273 pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
274 // Check connection
275 {
276 let state = self.state.read().await;
277 if *state != ClientState::Connected {
278 return Err(SdkError::InvalidState {
279 message: "Not connected".into(),
280 });
281 }
282 }
283
284 // Use default session ID
285 let session_id = "default".to_string();
286
287 // Update session data
288 {
289 let mut sessions = self.sessions.write().await;
290 let session = sessions.entry(session_id.clone()).or_insert_with(|| {
291 debug!("Creating new session: {}", session_id);
292 SessionData {
293 id: session_id.clone(),
294 message_count: 0,
295 created_at: std::time::Instant::now(),
296 }
297 });
298 session.message_count += 1;
299 }
300
301 // Create and send message
302 let message = InputMessage::user(prompt, session_id.clone());
303
304 {
305 let mut transport = self.transport.lock().await;
306 transport.send_message(message).await?;
307 }
308
309 debug!("Sent request to Claude");
310 Ok(())
311 }
312
313 /// Send a request to Claude (alias for send_user_message with optional session_id)
314 pub async fn send_request(
315 &mut self,
316 prompt: String,
317 _session_id: Option<String>,
318 ) -> Result<()> {
319 // For now, ignore session_id and use send_user_message
320 self.send_user_message(prompt).await
321 }
322
323 /// Receive messages from Claude
324 ///
325 /// Returns a stream of messages. The stream will end when a Result message
326 /// is received or the connection is closed.
327 pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
328 // Always use the regular message receiver
329 // (Query handler shares the same transport and receives control messages separately)
330 // Create a new channel for this receiver
331 let (tx, rx) = mpsc::channel(100);
332
333 // Get buffered messages and clear buffer
334 let buffered_messages = {
335 let mut buffer = self.message_buffer.lock().await;
336 std::mem::take(&mut *buffer)
337 };
338
339 // Send buffered messages to the new receiver
340 let tx_clone = tx.clone();
341 tokio::spawn(async move {
342 for msg in buffered_messages {
343 if tx_clone.send(Ok(msg)).await.is_err() {
344 break;
345 }
346 }
347 });
348
349 // Store the sender for the message receiver task
350 {
351 let mut message_tx = self.message_tx.lock().await;
352 *message_tx = Some(tx);
353 }
354
355 ReceiverStream::new(rx)
356 }
357
358 /// Send an interrupt request
359 pub async fn interrupt(&mut self) -> Result<()> {
360 // Check connection
361 {
362 let state = self.state.read().await;
363 if *state != ClientState::Connected {
364 return Err(SdkError::InvalidState {
365 message: "Not connected".into(),
366 });
367 }
368 }
369
370 // If we have a query handler, use it
371 if let Some(ref query_handler) = self.query_handler {
372 let mut handler = query_handler.lock().await;
373 return handler.interrupt().await;
374 }
375
376 // Otherwise use regular interrupt
377 // Generate request ID
378 let request_id = {
379 let mut counter = self.request_counter.lock().await;
380 *counter += 1;
381 format!("interrupt_{}", *counter)
382 };
383
384 // Send interrupt request
385 let request = ControlRequest::Interrupt {
386 request_id: request_id.clone(),
387 };
388
389 {
390 let mut transport = self.transport.lock().await;
391 transport.send_control_request(request).await?;
392 }
393
394 info!("Sent interrupt request: {}", request_id);
395
396 // Wait for acknowledgment (with timeout)
397 let transport = self.transport.clone();
398 let ack_task = tokio::spawn(async move {
399 let mut transport = transport.lock().await;
400 match tokio::time::timeout(
401 std::time::Duration::from_secs(5),
402 transport.receive_control_response(),
403 )
404 .await
405 {
406 Ok(Ok(Some(ControlResponse::InterruptAck {
407 request_id: ack_id,
408 success,
409 }))) => {
410 if ack_id == request_id && success {
411 Ok(())
412 } else {
413 Err(SdkError::ControlRequestError(
414 "Interrupt not acknowledged successfully".into(),
415 ))
416 }
417 }
418 Ok(Ok(None)) => Err(SdkError::ControlRequestError(
419 "No interrupt acknowledgment received".into(),
420 )),
421 Ok(Err(e)) => Err(e),
422 Err(_) => Err(SdkError::timeout(5)),
423 }
424 });
425
426 ack_task
427 .await
428 .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
429 }
430
431 /// Check if the client is connected
432 pub async fn is_connected(&self) -> bool {
433 let state = self.state.read().await;
434 *state == ClientState::Connected
435 }
436
437 /// Get active session IDs
438 pub async fn get_sessions(&self) -> Vec<String> {
439 let sessions = self.sessions.read().await;
440 sessions.keys().cloned().collect()
441 }
442
443 /// Receive messages until and including a ResultMessage
444 ///
445 /// This is a convenience method that collects all messages from a single response.
446 /// It will automatically stop after receiving a ResultMessage.
447 pub async fn receive_response(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
448 let mut messages = self.receive_messages().await;
449
450 // Create a stream that stops after ResultMessage
451 Box::pin(async_stream::stream! {
452 while let Some(msg_result) = messages.next().await {
453 match &msg_result {
454 Ok(Message::Result { .. }) => {
455 yield msg_result;
456 return;
457 }
458 _ => {
459 yield msg_result;
460 }
461 }
462 }
463 })
464 }
465
466 /// Get server information
467 ///
468 /// Returns initialization information from the Claude Code server including:
469 /// - Available commands
470 /// - Current and available output styles
471 /// - Server capabilities
472 pub async fn get_server_info(&self) -> Option<serde_json::Value> {
473 // If we have a query handler with control protocol, get from there
474 if let Some(ref query_handler) = self.query_handler {
475 let handler = query_handler.lock().await;
476 if let Some(init_result) = handler.get_initialization_result() {
477 return Some(init_result.clone());
478 }
479 }
480
481 // Otherwise check message buffer for init message
482 let buffer = self.message_buffer.lock().await;
483 for msg in buffer.iter() {
484 if let Message::System { subtype, data } = msg
485 && subtype == "init" {
486 return Some(data.clone());
487 }
488 }
489 None
490 }
491
492 /// Get account information
493 ///
494 /// This method attempts to retrieve Claude account information through multiple methods:
495 /// 1. From environment variable `ANTHROPIC_USER_EMAIL`
496 /// 2. From Claude CLI config file (if accessible)
497 /// 3. By querying the CLI with `/status` command (interactive mode)
498 ///
499 /// # Returns
500 ///
501 /// A string containing the account information, or an error if unavailable.
502 ///
503 /// # Example
504 ///
505 /// ```rust,no_run
506 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
507 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
508 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
509 /// client.connect(None).await?;
510 ///
511 /// match client.get_account_info().await {
512 /// Ok(info) => println!("Account: {}", info),
513 /// Err(_) => println!("Account info not available"),
514 /// }
515 /// # Ok(())
516 /// # }
517 /// ```
518 ///
519 /// # Note
520 ///
521 /// Account information may not always be available in SDK mode.
522 /// Consider setting the `ANTHROPIC_USER_EMAIL` environment variable
523 /// for reliable account identification.
524 pub async fn get_account_info(&mut self) -> Result<String> {
525 // Check connection
526 {
527 let state = self.state.read().await;
528 if *state != ClientState::Connected {
529 return Err(SdkError::InvalidState {
530 message: "Not connected. Call connect() first.".into(),
531 });
532 }
533 }
534
535 // Method 1: Check environment variable
536 if let Ok(email) = std::env::var("ANTHROPIC_USER_EMAIL") {
537 return Ok(format!("Email: {}", email));
538 }
539
540 // Method 2: Try reading from Claude config
541 if let Some(config_info) = Self::read_claude_config().await {
542 return Ok(config_info);
543 }
544
545 // Method 3: Try /status command (may not work in SDK mode)
546 self.send_user_message("/status".to_string()).await?;
547
548 let mut messages = self.receive_messages().await;
549 let mut account_info = String::new();
550
551 while let Some(msg_result) = messages.next().await {
552 match msg_result? {
553 Message::Assistant { message } => {
554 for block in message.content {
555 if let ContentBlock::Text(text) = block {
556 account_info.push_str(&text.text);
557 account_info.push('\n');
558 }
559 }
560 }
561 Message::Result { .. } => break,
562 _ => {}
563 }
564 }
565
566 let trimmed = account_info.trim();
567
568 // Check if we got actual status info or just a chat response
569 if !trimmed.is_empty() && (
570 trimmed.contains("account") ||
571 trimmed.contains("email") ||
572 trimmed.contains("subscription") ||
573 trimmed.contains("authenticated")
574 ) {
575 return Ok(trimmed.to_string());
576 }
577
578 Err(SdkError::InvalidState {
579 message: "Account information not available. Try setting ANTHROPIC_USER_EMAIL environment variable.".into(),
580 })
581 }
582
583 /// Read Claude config file
584 async fn read_claude_config() -> Option<String> {
585 // Try common config locations
586 let config_paths = vec![
587 dirs::home_dir()?.join(".config").join("claude").join("config.json"),
588 dirs::home_dir()?.join(".claude").join("config.json"),
589 ];
590
591 for path in config_paths {
592 if let Ok(content) = tokio::fs::read_to_string(&path).await {
593 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&content) {
594 if let Some(email) = json.get("email").and_then(|v| v.as_str()) {
595 return Some(format!("Email: {}", email));
596 }
597 if let Some(user) = json.get("user").and_then(|v| v.as_str()) {
598 return Some(format!("User: {}", user));
599 }
600 }
601 }
602 }
603
604 None
605 }
606
607 /// Set permission mode dynamically
608 ///
609 /// Changes the permission mode during an active session.
610 /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
611 ///
612 /// # Arguments
613 ///
614 /// * `mode` - Permission mode: "default", "acceptEdits", "plan", or "bypassPermissions"
615 ///
616 /// # Example
617 ///
618 /// ```rust,no_run
619 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
620 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
621 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
622 /// client.connect(None).await?;
623 ///
624 /// // Switch to accept edits mode
625 /// client.set_permission_mode("acceptEdits").await?;
626 /// # Ok(())
627 /// # }
628 /// ```
629 pub async fn set_permission_mode(&mut self, mode: &str) -> Result<()> {
630 if let Some(ref query_handler) = self.query_handler {
631 let mut handler = query_handler.lock().await;
632 handler.set_permission_mode(mode).await
633 } else {
634 Err(SdkError::InvalidState {
635 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
636 })
637 }
638 }
639
640 /// Set model dynamically
641 ///
642 /// Changes the active model during an active session.
643 /// Requires control protocol to be enabled (via can_use_tool, hooks, mcp_servers, or file checkpointing).
644 ///
645 /// # Arguments
646 ///
647 /// * `model` - Model identifier (e.g., "claude-3-5-sonnet-20241022") or None to use default
648 ///
649 /// # Example
650 ///
651 /// ```rust,no_run
652 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
653 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
654 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
655 /// client.connect(None).await?;
656 ///
657 /// // Switch to a different model
658 /// client.set_model(Some("claude-3-5-sonnet-20241022".to_string())).await?;
659 /// # Ok(())
660 /// # }
661 /// ```
662 pub async fn set_model(&mut self, model: Option<String>) -> Result<()> {
663 if let Some(ref query_handler) = self.query_handler {
664 let mut handler = query_handler.lock().await;
665 handler.set_model(model).await
666 } else {
667 Err(SdkError::InvalidState {
668 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
669 })
670 }
671 }
672
673 /// Send a query with optional session ID
674 ///
675 /// This method is similar to Python SDK's query method in ClaudeSDKClient
676 pub async fn query(&mut self, prompt: String, session_id: Option<String>) -> Result<()> {
677 let session_id = session_id.unwrap_or_else(|| "default".to_string());
678
679 // Send the message
680 let message = InputMessage::user(prompt, session_id);
681
682 {
683 let mut transport = self.transport.lock().await;
684 transport.send_message(message).await?;
685 }
686
687 Ok(())
688 }
689
690 /// Rewind tracked files to their state at a specific user message
691 ///
692 /// Requires `enable_file_checkpointing` to be enabled in `ClaudeCodeOptions`.
693 /// This method allows you to undo file changes made during the session by
694 /// reverting them to their state at any previous user message checkpoint.
695 ///
696 /// # Arguments
697 ///
698 /// * `user_message_id` - UUID of the user message to rewind to. This should be
699 /// the `uuid` field from a message received during the conversation.
700 ///
701 /// # Example
702 ///
703 /// ```rust,no_run
704 /// # use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
705 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
706 /// let options = ClaudeCodeOptions::builder()
707 /// .enable_file_checkpointing(true)
708 /// .build();
709 /// let mut client = ClaudeSDKClient::new(options);
710 /// client.connect(None).await?;
711 ///
712 /// // Ask Claude to make some changes
713 /// client.send_request("Make some changes to my files".to_string(), None).await?;
714 ///
715 /// // ... later, rewind to a checkpoint
716 /// // client.rewind_files("user-message-uuid-here").await?;
717 /// # Ok(())
718 /// # }
719 /// ```
720 ///
721 /// # Errors
722 ///
723 /// Returns an error if:
724 /// - The client is not connected
725 /// - The query handler is not initialized (control protocol required)
726 /// - File checkpointing is not enabled
727 /// - The specified user_message_id is invalid
728 pub async fn rewind_files(&mut self, user_message_id: &str) -> Result<()> {
729 // Check connection
730 {
731 let state = self.state.read().await;
732 if *state != ClientState::Connected {
733 return Err(SdkError::InvalidState {
734 message: "Not connected. Call connect() first.".into(),
735 });
736 }
737 }
738
739 if !self.options.enable_file_checkpointing {
740 return Err(SdkError::InvalidState {
741 message: "File checkpointing is not enabled. Set ClaudeCodeOptions::builder().enable_file_checkpointing(true).".to_string(),
742 });
743 }
744
745 // Require query handler for control protocol
746 if let Some(ref query_handler) = self.query_handler {
747 let mut handler = query_handler.lock().await;
748 handler.rewind_files(user_message_id).await
749 } else {
750 Err(SdkError::InvalidState {
751 message: "Query handler not initialized. Enable control protocol features (can_use_tool, hooks, mcp_servers, or enable_file_checkpointing).".to_string(),
752 })
753 }
754 }
755
756 /// Disconnect from Claude CLI
757 pub async fn disconnect(&mut self) -> Result<()> {
758 // Check if already disconnected
759 {
760 let state = self.state.read().await;
761 if *state == ClientState::Disconnected {
762 return Ok(());
763 }
764 }
765
766 // Disconnect transport
767 {
768 let mut transport = self.transport.lock().await;
769 transport.disconnect().await?;
770 }
771
772 // Update state
773 {
774 let mut state = self.state.write().await;
775 *state = ClientState::Disconnected;
776 }
777
778 // Clear sessions
779 {
780 let mut sessions = self.sessions.write().await;
781 sessions.clear();
782 }
783
784 info!("Disconnected from Claude CLI");
785 Ok(())
786 }
787
788 /// Start the message receiver task
789 async fn start_message_receiver(&mut self) {
790 let transport = self.transport.clone();
791 let message_tx = self.message_tx.clone();
792 let message_buffer = self.message_buffer.clone();
793 let state = self.state.clone();
794 let budget_manager = self.budget_manager.clone();
795
796 tokio::spawn(async move {
797 // Subscribe to messages without holding the lock
798 let mut stream = {
799 let mut transport = transport.lock().await;
800 transport.receive_messages()
801 }; // Lock is released here immediately
802
803 while let Some(result) = stream.next().await {
804 match result {
805 Ok(message) => {
806 // Update token usage for Result messages
807 if let Message::Result { .. } = &message
808 && let Message::Result { usage, total_cost_usd, .. } = &message {
809 let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
810 let input = usage_json.get("input_tokens")
811 .and_then(|v| v.as_u64())
812 .unwrap_or(0);
813 let output = usage_json.get("output_tokens")
814 .and_then(|v| v.as_u64())
815 .unwrap_or(0);
816 (input, output)
817 } else {
818 (0, 0)
819 };
820 let cost = total_cost_usd.unwrap_or(0.0);
821 budget_manager.update_usage(input_tokens, output_tokens, cost).await;
822 }
823
824 // Buffer init messages for get_server_info()
825 if let Message::System { subtype, .. } = &message
826 && subtype == "init" {
827 let mut buffer = message_buffer.lock().await;
828 buffer.push(message.clone());
829 }
830
831 // Try to send to current receiver
832 let sent = {
833 let mut tx_opt = message_tx.lock().await;
834 if let Some(tx) = tx_opt.as_mut() {
835 tx.send(Ok(message.clone())).await.is_ok()
836 } else {
837 false
838 }
839 };
840
841 // If no receiver or send failed, buffer the message
842 if !sent {
843 let mut buffer = message_buffer.lock().await;
844 buffer.push(message);
845 }
846 }
847 Err(e) => {
848 error!("Error receiving message: {}", e);
849
850 // Send error to receiver if available
851 let mut tx_opt = message_tx.lock().await;
852 if let Some(tx) = tx_opt.as_mut() {
853 let _ = tx.send(Err(e)).await;
854 }
855
856 // Update state on error
857 let mut state = state.write().await;
858 *state = ClientState::Error;
859 break;
860 }
861 }
862 }
863
864 debug!("Message receiver task ended");
865 });
866 }
867
868 /// Get token usage statistics
869 ///
870 /// Returns the current token usage tracker with cumulative statistics
871 /// for all queries executed by this client.
872 pub async fn get_usage_stats(&self) -> crate::token_tracker::TokenUsageTracker {
873 self.budget_manager.get_usage().await
874 }
875
876 /// Set budget limit with optional warning callback
877 ///
878 /// # Arguments
879 ///
880 /// * `limit` - Budget limit configuration (cost and/or token caps)
881 /// * `on_warning` - Optional callback function triggered when usage exceeds warning threshold
882 ///
883 /// # Example
884 ///
885 /// ```rust,no_run
886 /// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
887 /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
888 /// use std::sync::Arc;
889 ///
890 /// # async fn example() {
891 /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
892 ///
893 /// // Set budget with callback
894 /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Budget warning: {}", msg));
895 /// client.set_budget_limit(BudgetLimit::with_cost(5.0), Some(cb)).await;
896 /// # }
897 /// ```
898 pub async fn set_budget_limit(
899 &self,
900 limit: crate::token_tracker::BudgetLimit,
901 on_warning: Option<crate::token_tracker::BudgetWarningCallback>,
902 ) {
903 self.budget_manager.set_limit(limit).await;
904 if let Some(callback) = on_warning {
905 self.budget_manager.set_warning_callback(callback).await;
906 }
907 }
908
909 /// Clear budget limit and reset warning state
910 pub async fn clear_budget_limit(&self) {
911 self.budget_manager.clear_limit().await;
912 }
913
914 /// Reset token usage statistics to zero
915 ///
916 /// Clears all accumulated token and cost statistics.
917 /// Budget limits remain in effect.
918 pub async fn reset_usage_stats(&self) {
919 self.budget_manager.reset_usage().await;
920 }
921
922 /// Check if budget has been exceeded
923 ///
924 /// Returns true if current usage exceeds any configured limits
925 pub async fn is_budget_exceeded(&self) -> bool {
926 self.budget_manager.is_exceeded().await
927 }
928
929 // Removed unused helper; usage is updated inline in message receiver
930}
931
932impl Drop for ClaudeSDKClient {
933 fn drop(&mut self) {
934 // Try to disconnect gracefully
935 let transport = self.transport.clone();
936 let state = self.state.clone();
937
938 if let Ok(handle) = tokio::runtime::Handle::try_current() {
939 handle.spawn(async move {
940 let state = state.read().await;
941 if *state == ClientState::Connected {
942 let mut transport = transport.lock().await;
943 if let Err(e) = transport.disconnect().await {
944 debug!("Error disconnecting in drop: {}", e);
945 }
946 }
947 });
948 }
949 }
950}
951
#[cfg(test)]
mod tests {
    use super::*;

    /// Client with default options: no control protocol, never connected.
    fn disconnected_client() -> ClaudeSDKClient {
        ClaudeSDKClient::new(ClaudeCodeOptions::default())
    }

    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = disconnected_client();

        assert!(!client.is_connected().await);
        assert_eq!(client.get_sessions().await.len(), 0);
    }

    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = disconnected_client();

        let state = client.state.read().await;
        assert_eq!(*state, ClientState::Disconnected);
    }

    #[test]
    fn test_file_checkpointing_enables_query_handler() {
        let options = ClaudeCodeOptions::builder()
            .enable_file_checkpointing(true)
            .build();
        let client = ClaudeSDKClient::new(options);

        assert!(
            client.query_handler.is_some(),
            "enable_file_checkpointing should initialize the query handler for control protocol requests"
        );
    }

    #[test]
    fn test_default_options_skip_query_handler() {
        let client = disconnected_client();
        assert!(
            client.query_handler.is_none(),
            "default options should not enable the control protocol"
        );
    }

    #[tokio::test]
    async fn test_send_requires_connection() {
        let mut client = disconnected_client();

        let result = client.send_user_message("hello".to_string()).await;
        assert!(matches!(result, Err(SdkError::InvalidState { .. })));
        // A rejected send must not create a session.
        assert!(client.get_sessions().await.is_empty());
    }

    #[tokio::test]
    async fn test_interrupt_requires_connection() {
        let mut client = disconnected_client();
        assert!(matches!(
            client.interrupt().await,
            Err(SdkError::InvalidState { .. })
        ));
    }

    #[tokio::test]
    async fn test_rewind_files_requires_connection() {
        let options = ClaudeCodeOptions::builder()
            .enable_file_checkpointing(true)
            .build();
        let mut client = ClaudeSDKClient::new(options);

        assert!(matches!(
            client.rewind_files("some-message-id").await,
            Err(SdkError::InvalidState { .. })
        ));
    }

    #[tokio::test]
    async fn test_dynamic_controls_require_query_handler() {
        let mut client = disconnected_client();

        assert!(matches!(
            client.set_model(None).await,
            Err(SdkError::InvalidState { .. })
        ));
        assert!(matches!(
            client.set_permission_mode("plan").await,
            Err(SdkError::InvalidState { .. })
        ));
    }

    #[tokio::test]
    async fn test_disconnect_when_not_connected_is_noop() {
        let mut client = disconnected_client();

        assert!(client.disconnect().await.is_ok());
        assert!(!client.is_connected().await);
    }

    #[tokio::test]
    async fn test_server_info_absent_before_connect() {
        let client = disconnected_client();
        assert!(client.get_server_info().await.is_none());
    }
}