cc_sdk/
client.rs

1//! Interactive client for bidirectional communication with Claude
2//!
3//! This module provides the `ClaudeSDKClient` for interactive, stateful
4//! conversations with Claude Code CLI.
5
6use crate::{
7    errors::{Result, SdkError},
8    internal_query::Query,
9    token_tracker::BudgetManager,
10    transport::{InputMessage, SubprocessTransport, Transport},
11    types::{ClaudeCodeOptions, ControlRequest, ControlResponse, Message},
12};
13use futures::stream::{Stream, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::pin::Pin;
17use tokio::sync::{Mutex, RwLock, mpsc};
18use tokio_stream::wrappers::ReceiverStream;
19use tracing::{debug, error, info};
20
/// Lifecycle state of a [`ClaudeSDKClient`].
///
/// Only `Connected` allows sending messages; any other state makes the send and
/// interrupt operations fail with an invalid-state error.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientState {
    /// No live connection to the CLI; this is the state a freshly created client starts in.
    Disconnected,
    /// The transport is connected and the client can send and receive messages.
    Connected,
    /// The background receiver hit a transport error; reconnect before further use.
    Error,
}
31
32/// Interactive client for bidirectional communication with Claude
33///
34/// `ClaudeSDKClient` provides a stateful, interactive interface for communicating
35/// with Claude Code CLI. Unlike the simple `query` function, this client supports:
36///
37/// - Bidirectional communication
38/// - Multiple sessions
39/// - Interrupt capabilities
40/// - State management
41/// - Follow-up messages based on responses
42///
43/// # Example
44///
45/// ```rust,no_run
46/// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, Message, Result};
47/// use futures::StreamExt;
48///
49/// #[tokio::main]
50/// async fn main() -> Result<()> {
51///     let options = ClaudeCodeOptions::builder()
52///         .system_prompt("You are a helpful assistant")
53///         .model("claude-3-opus-20240229")
54///         .build();
55///
56///     let mut client = ClaudeSDKClient::new(options);
57///
58///     // Connect with initial prompt
59///     client.connect(Some("Hello!".to_string())).await?;
60///
61///     // Receive initial response
62///     let mut messages = client.receive_messages().await;
63///     while let Some(msg) = messages.next().await {
64///         match msg? {
65///             Message::Result { .. } => break,
66///             msg => println!("{:?}", msg),
67///         }
68///     }
69///
70///     // Send follow-up
71///     client.send_request("What's 2 + 2?".to_string(), None).await?;
72///
73///     // Receive response
74///     let mut messages = client.receive_messages().await;
75///     while let Some(msg) = messages.next().await {
76///         println!("{:?}", msg?);
77///     }
78///
79///     // Disconnect
80///     client.disconnect().await?;
81///
82///     Ok(())
83/// }
84/// ```
85pub struct ClaudeSDKClient {
86    /// Configuration options
87    #[allow(dead_code)]
88    options: ClaudeCodeOptions,
89    /// Transport layer
90    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
91    /// Internal query handler (when control protocol is enabled)
92    query_handler: Option<Arc<Mutex<Query>>>,
93    /// Client state
94    state: Arc<RwLock<ClientState>>,
95    /// Active sessions
96    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
97    /// Message sender for current receiver
98    message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
99    /// Message buffer for multiple receivers
100    message_buffer: Arc<Mutex<Vec<Message>>>,
101    /// Request counter
102    request_counter: Arc<Mutex<u64>>,
103    /// Budget manager for token tracking
104    budget_manager: BudgetManager,
105}
106
107/// Session data
108#[allow(dead_code)]
109struct SessionData {
110    /// Session ID
111    id: String,
112    /// Number of messages sent
113    message_count: usize,
114    /// Creation time
115    created_at: std::time::Instant,
116}
117
118impl ClaudeSDKClient {
119    /// Create a new client with the given options
120    pub fn new(options: ClaudeCodeOptions) -> Self {
121        // Set environment variable to indicate SDK usage
122        unsafe {
123            std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
124        }
125
126        let transport = match SubprocessTransport::new(options.clone()) {
127            Ok(t) => t,
128            Err(e) => {
129                error!("Failed to create transport: {}", e);
130                // Create with empty path, will fail on connect
131                SubprocessTransport::with_cli_path(options.clone(), "")
132            }
133        };
134
135        // Wrap transport in Arc for sharing
136        let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
137            Arc::new(Mutex::new(Box::new(transport)));
138        
139        // Create query handler if control protocol features are enabled
140        let query_handler = if options.can_use_tool.is_some()
141            || options.hooks.is_some()
142            || !options.mcp_servers.is_empty() {
143            // Extract SDK MCP server instances
144            let sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>> = options.mcp_servers
145                .iter()
146                .filter_map(|(k, v)| {
147                    // Only extract SDK type MCP servers
148                    if let crate::types::McpServerConfig::Sdk { name: _, instance } = v {
149                        Some((k.clone(), instance.clone()))
150                    } else {
151                        None
152                    }
153                })
154                .collect();
155
156            // Enable streaming mode when control protocol is active
157            let is_streaming = options.can_use_tool.is_some()
158                || options.hooks.is_some()
159                || !sdk_mcp_servers.is_empty();
160
161            let query = Query::new(
162                transport_arc.clone(), // Share the same transport
163                is_streaming, // Enable streaming for control protocol
164                options.can_use_tool.clone(),
165                options.hooks.clone(),
166                sdk_mcp_servers,
167            );
168            Some(Arc::new(Mutex::new(query)))
169        } else {
170            None
171        };
172
173        Self {
174            options,
175            transport: transport_arc,
176            query_handler,
177            state: Arc::new(RwLock::new(ClientState::Disconnected)),
178            sessions: Arc::new(RwLock::new(HashMap::new())),
179            message_tx: Arc::new(Mutex::new(None)),
180            message_buffer: Arc::new(Mutex::new(Vec::new())),
181            request_counter: Arc::new(Mutex::new(0)),
182            budget_manager: BudgetManager::new(),
183        }
184    }
185
186    /// Connect to Claude CLI with an optional initial prompt
187    pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
188        // Check if already connected
189        {
190            let state = self.state.read().await;
191            if *state == ClientState::Connected {
192                return Ok(());
193            }
194        }
195
196        // Connect transport
197        {
198            let mut transport = self.transport.lock().await;
199            transport.connect().await?;
200        }
201
202        // Initialize query handler if present
203        if let Some(ref query_handler) = self.query_handler {
204            let mut handler = query_handler.lock().await;
205            handler.start().await?;
206            handler.initialize().await?;
207            info!("Initialized SDK control protocol");
208        }
209
210        // Update state
211        {
212            let mut state = self.state.write().await;
213            *state = ClientState::Connected;
214        }
215
216        info!("Connected to Claude CLI");
217
218        // Start message receiver task (always needed for regular messages)
219        self.start_message_receiver().await;
220
221        // Send initial prompt if provided
222        if let Some(prompt) = initial_prompt {
223            self.send_request(prompt, None).await?;
224        }
225
226        Ok(())
227    }
228
229    /// Send a user message to Claude
230    pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
231        // Check connection
232        {
233            let state = self.state.read().await;
234            if *state != ClientState::Connected {
235                return Err(SdkError::InvalidState {
236                    message: "Not connected".into(),
237                });
238            }
239        }
240
241        // Use default session ID
242        let session_id = "default".to_string();
243
244        // Update session data
245        {
246            let mut sessions = self.sessions.write().await;
247            let session = sessions.entry(session_id.clone()).or_insert_with(|| {
248                debug!("Creating new session: {}", session_id);
249                SessionData {
250                    id: session_id.clone(),
251                    message_count: 0,
252                    created_at: std::time::Instant::now(),
253                }
254            });
255            session.message_count += 1;
256        }
257
258        // Create and send message
259        let message = InputMessage::user(prompt, session_id.clone());
260
261        {
262            let mut transport = self.transport.lock().await;
263            transport.send_message(message).await?;
264        }
265
266        debug!("Sent request to Claude");
267        Ok(())
268    }
269
270    /// Send a request to Claude (alias for send_user_message with optional session_id)
271    pub async fn send_request(
272        &mut self,
273        prompt: String,
274        _session_id: Option<String>,
275    ) -> Result<()> {
276        // For now, ignore session_id and use send_user_message
277        self.send_user_message(prompt).await
278    }
279
280    /// Receive messages from Claude
281    ///
282    /// Returns a stream of messages. The stream will end when a Result message
283    /// is received or the connection is closed.
284    pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
285        // Always use the regular message receiver
286        // (Query handler shares the same transport and receives control messages separately)
287        // Create a new channel for this receiver
288        let (tx, rx) = mpsc::channel(100);
289
290        // Get buffered messages and clear buffer
291        let buffered_messages = {
292            let mut buffer = self.message_buffer.lock().await;
293            std::mem::take(&mut *buffer)
294        };
295
296        // Send buffered messages to the new receiver
297        let tx_clone = tx.clone();
298        tokio::spawn(async move {
299            for msg in buffered_messages {
300                if tx_clone.send(Ok(msg)).await.is_err() {
301                    break;
302                }
303            }
304        });
305
306        // Store the sender for the message receiver task
307        {
308            let mut message_tx = self.message_tx.lock().await;
309            *message_tx = Some(tx);
310        }
311
312        ReceiverStream::new(rx)
313    }
314
315    /// Send an interrupt request
316    pub async fn interrupt(&mut self) -> Result<()> {
317        // Check connection
318        {
319            let state = self.state.read().await;
320            if *state != ClientState::Connected {
321                return Err(SdkError::InvalidState {
322                    message: "Not connected".into(),
323                });
324            }
325        }
326
327        // If we have a query handler, use it
328        if let Some(ref query_handler) = self.query_handler {
329            let mut handler = query_handler.lock().await;
330            return handler.interrupt().await;
331        }
332
333        // Otherwise use regular interrupt
334        // Generate request ID
335        let request_id = {
336            let mut counter = self.request_counter.lock().await;
337            *counter += 1;
338            format!("interrupt_{}", *counter)
339        };
340
341        // Send interrupt request
342        let request = ControlRequest::Interrupt {
343            request_id: request_id.clone(),
344        };
345
346        {
347            let mut transport = self.transport.lock().await;
348            transport.send_control_request(request).await?;
349        }
350
351        info!("Sent interrupt request: {}", request_id);
352
353        // Wait for acknowledgment (with timeout)
354        let transport = self.transport.clone();
355        let ack_task = tokio::spawn(async move {
356            let mut transport = transport.lock().await;
357            match tokio::time::timeout(
358                std::time::Duration::from_secs(5),
359                transport.receive_control_response(),
360            )
361            .await
362            {
363                Ok(Ok(Some(ControlResponse::InterruptAck {
364                    request_id: ack_id,
365                    success,
366                }))) => {
367                    if ack_id == request_id && success {
368                        Ok(())
369                    } else {
370                        Err(SdkError::ControlRequestError(
371                            "Interrupt not acknowledged successfully".into(),
372                        ))
373                    }
374                }
375                Ok(Ok(None)) => Err(SdkError::ControlRequestError(
376                    "No interrupt acknowledgment received".into(),
377                )),
378                Ok(Err(e)) => Err(e),
379                Err(_) => Err(SdkError::timeout(5)),
380            }
381        });
382
383        ack_task
384            .await
385            .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
386    }
387
388    /// Check if the client is connected
389    pub async fn is_connected(&self) -> bool {
390        let state = self.state.read().await;
391        *state == ClientState::Connected
392    }
393
394    /// Get active session IDs
395    pub async fn get_sessions(&self) -> Vec<String> {
396        let sessions = self.sessions.read().await;
397        sessions.keys().cloned().collect()
398    }
399
400    /// Receive messages until and including a ResultMessage
401    ///
402    /// This is a convenience method that collects all messages from a single response.
403    /// It will automatically stop after receiving a ResultMessage.
404    pub async fn receive_response(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
405        let mut messages = self.receive_messages().await;
406        
407        // Create a stream that stops after ResultMessage
408        Box::pin(async_stream::stream! {
409            while let Some(msg_result) = messages.next().await {
410                match &msg_result {
411                    Ok(Message::Result { .. }) => {
412                        yield msg_result;
413                        return;
414                    }
415                    _ => {
416                        yield msg_result;
417                    }
418                }
419            }
420        })
421    }
422
423    /// Get server information
424    ///
425    /// Returns initialization information from the Claude Code server including:
426    /// - Available commands
427    /// - Current and available output styles
428    /// - Server capabilities
429    pub async fn get_server_info(&self) -> Option<serde_json::Value> {
430        // If we have a query handler with control protocol, get from there
431        if let Some(ref query_handler) = self.query_handler {
432            let handler = query_handler.lock().await;
433            if let Some(init_result) = handler.get_initialization_result() {
434                return Some(init_result.clone());
435            }
436        }
437        
438        // Otherwise check message buffer for init message
439        let buffer = self.message_buffer.lock().await;
440        for msg in buffer.iter() {
441            if let Message::System { subtype, data } = msg {
442                if subtype == "init" {
443                    return Some(data.clone());
444                }
445            }
446        }
447        None
448    }
449
450    /// Send a query with optional session ID
451    ///
452    /// This method is similar to Python SDK's query method in ClaudeSDKClient
453    pub async fn query(&mut self, prompt: String, session_id: Option<String>) -> Result<()> {
454        let session_id = session_id.unwrap_or_else(|| "default".to_string());
455        
456        // Send the message
457        let message = InputMessage::user(prompt, session_id);
458        
459        {
460            let mut transport = self.transport.lock().await;
461            transport.send_message(message).await?;
462        }
463        
464        Ok(())
465    }
466
467    /// Disconnect from Claude CLI
468    pub async fn disconnect(&mut self) -> Result<()> {
469        // Check if already disconnected
470        {
471            let state = self.state.read().await;
472            if *state == ClientState::Disconnected {
473                return Ok(());
474            }
475        }
476
477        // Disconnect transport
478        {
479            let mut transport = self.transport.lock().await;
480            transport.disconnect().await?;
481        }
482
483        // Update state
484        {
485            let mut state = self.state.write().await;
486            *state = ClientState::Disconnected;
487        }
488
489        // Clear sessions
490        {
491            let mut sessions = self.sessions.write().await;
492            sessions.clear();
493        }
494
495        info!("Disconnected from Claude CLI");
496        Ok(())
497    }
498
499    /// Start the message receiver task
500    async fn start_message_receiver(&mut self) {
501        let transport = self.transport.clone();
502        let message_tx = self.message_tx.clone();
503        let message_buffer = self.message_buffer.clone();
504        let state = self.state.clone();
505        let budget_manager = self.budget_manager.clone();
506
507        tokio::spawn(async move {
508            // Subscribe to messages without holding the lock
509            let mut stream = {
510                let mut transport = transport.lock().await;
511                transport.receive_messages()
512            }; // Lock is released here immediately
513
514            while let Some(result) = stream.next().await {
515                match result {
516                    Ok(message) => {
517                        // Update token usage for Result messages
518                        if let Message::Result { .. } = &message {
519                            if let Message::Result { usage, total_cost_usd, .. } = &message {
520                                let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
521                                    let input = usage_json.get("input_tokens")
522                                        .and_then(|v| v.as_u64())
523                                        .unwrap_or(0);
524                                    let output = usage_json.get("output_tokens")
525                                        .and_then(|v| v.as_u64())
526                                        .unwrap_or(0);
527                                    (input, output)
528                                } else {
529                                    (0, 0)
530                                };
531                                let cost = total_cost_usd.unwrap_or(0.0);
532                                budget_manager.update_usage(input_tokens, output_tokens, cost).await;
533                            }
534                        }
535
536                        // Buffer init messages for get_server_info()
537                        if let Message::System { subtype, .. } = &message {
538                            if subtype == "init" {
539                                let mut buffer = message_buffer.lock().await;
540                                buffer.push(message.clone());
541                            }
542                        }
543
544                        // Try to send to current receiver
545                        let sent = {
546                            let mut tx_opt = message_tx.lock().await;
547                            if let Some(tx) = tx_opt.as_mut() {
548                                tx.send(Ok(message.clone())).await.is_ok()
549                            } else {
550                                false
551                            }
552                        };
553
554                        // If no receiver or send failed, buffer the message
555                        if !sent {
556                            let mut buffer = message_buffer.lock().await;
557                            buffer.push(message);
558                        }
559                    }
560                    Err(e) => {
561                        error!("Error receiving message: {}", e);
562
563                        // Send error to receiver if available
564                        let mut tx_opt = message_tx.lock().await;
565                        if let Some(tx) = tx_opt.as_mut() {
566                            let _ = tx.send(Err(e)).await;
567                        }
568
569                        // Update state on error
570                        let mut state = state.write().await;
571                        *state = ClientState::Error;
572                        break;
573                    }
574                }
575            }
576
577            debug!("Message receiver task ended");
578        });
579    }
580
581    /// Get token usage statistics
582    ///
583    /// Returns the current token usage tracker with cumulative statistics
584    /// for all queries executed by this client.
585    pub async fn get_usage_stats(&self) -> crate::token_tracker::TokenUsageTracker {
586        self.budget_manager.get_usage().await
587    }
588
589    /// Set budget limit with optional warning callback
590    ///
591    /// # Arguments
592    ///
593    /// * `limit` - Budget limit configuration (cost and/or token caps)
594    /// * `on_warning` - Optional callback function triggered when usage exceeds warning threshold
595    ///
596    /// # Example
597    ///
598    /// ```rust,no_run
599    /// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions};
600    /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
601    /// use std::sync::Arc;
602    ///
603    /// # async fn example() {
604    /// let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());
605    ///
606    /// // Set budget with callback
607    /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Budget warning: {}", msg));
608    /// client.set_budget_limit(BudgetLimit::with_cost(5.0), Some(cb)).await;
609    /// # }
610    /// ```
611    pub async fn set_budget_limit(
612        &self,
613        limit: crate::token_tracker::BudgetLimit,
614        on_warning: Option<crate::token_tracker::BudgetWarningCallback>,
615    ) {
616        self.budget_manager.set_limit(limit).await;
617        if let Some(callback) = on_warning {
618            self.budget_manager.set_warning_callback(callback).await;
619        }
620    }
621
622    /// Clear budget limit and reset warning state
623    pub async fn clear_budget_limit(&self) {
624        self.budget_manager.clear_limit().await;
625    }
626
627    /// Reset token usage statistics to zero
628    ///
629    /// Clears all accumulated token and cost statistics.
630    /// Budget limits remain in effect.
631    pub async fn reset_usage_stats(&self) {
632        self.budget_manager.reset_usage().await;
633    }
634
635    /// Check if budget has been exceeded
636    ///
637    /// Returns true if current usage exceeds any configured limits
638    pub async fn is_budget_exceeded(&self) -> bool {
639        self.budget_manager.is_exceeded().await
640    }
641
642    // Removed unused helper; usage is updated inline in message receiver
643}
644
645impl Drop for ClaudeSDKClient {
646    fn drop(&mut self) {
647        // Try to disconnect gracefully
648        let transport = self.transport.clone();
649        let state = self.state.clone();
650
651        tokio::spawn(async move {
652            let state = state.read().await;
653            if *state == ClientState::Connected {
654                let mut transport = transport.lock().await;
655                if let Err(e) = transport.disconnect().await {
656                    debug!("Error disconnecting in drop: {}", e);
657                }
658            }
659        });
660    }
661}
662
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_client_lifecycle() {
        let options = ClaudeCodeOptions::default();
        let client = ClaudeSDKClient::new(options);

        assert!(!client.is_connected().await);
        assert_eq!(client.get_sessions().await.len(), 0);
    }

    #[tokio::test]
    async fn test_client_state_transitions() {
        let options = ClaudeCodeOptions::default();
        let client = ClaudeSDKClient::new(options);

        let state = client.state.read().await;
        assert_eq!(*state, ClientState::Disconnected);
    }

    /// Sending before `connect` must fail fast and must not create a session.
    #[tokio::test]
    async fn test_send_requires_connection() {
        let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let err = client
            .send_user_message("hello".to_string())
            .await
            .unwrap_err();
        assert!(matches!(err, SdkError::InvalidState { .. }));
        assert!(client.get_sessions().await.is_empty());
    }

    /// Interrupting before `connect` must fail with an invalid-state error.
    #[tokio::test]
    async fn test_interrupt_requires_connection() {
        let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let err = client.interrupt().await.unwrap_err();
        assert!(matches!(err, SdkError::InvalidState { .. }));
    }

    /// Disconnecting a client that never connected is a harmless no-op.
    #[tokio::test]
    async fn test_disconnect_when_disconnected_is_noop() {
        let mut client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        assert!(client.disconnect().await.is_ok());
        assert!(!client.is_connected().await);
    }
}