cc_sdk/
client.rs

1//! Interactive client for bidirectional communication with Claude
2//!
3//! This module provides the `ClaudeSDKClient` for interactive, stateful
4//! conversations with Claude Code CLI.
5
6use crate::{
7    errors::{Result, SdkError},
8    transport::{InputMessage, SubprocessTransport, Transport},
9    types::{ClaudeCodeOptions, ControlRequest, ControlResponse, Message},
10};
11use futures::stream::{Stream, StreamExt};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::{Mutex, RwLock, mpsc};
15use tokio_stream::wrappers::ReceiverStream;
16use tracing::{debug, error, info};
17
/// Connection lifecycle of the interactive client.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientState {
    /// No connection is established; a new client starts in this state and
    /// returns to it after a successful disconnect.
    Disconnected,
    /// The transport is connected and requests may be sent.
    Connected,
    /// The background receiver reported a transport error.
    Error,
}
28
29/// Interactive client for bidirectional communication with Claude
30///
31/// `ClaudeSDKClient` provides a stateful, interactive interface for communicating
32/// with Claude Code CLI. Unlike the simple `query` function, this client supports:
33///
34/// - Bidirectional communication
35/// - Multiple sessions
36/// - Interrupt capabilities
37/// - State management
38/// - Follow-up messages based on responses
39///
40/// # Example
41///
42/// ```rust,no_run
43/// use cc_sdk::{ClaudeSDKClient, ClaudeCodeOptions, Message, Result};
44/// use futures::StreamExt;
45///
46/// #[tokio::main]
47/// async fn main() -> Result<()> {
48///     let options = ClaudeCodeOptions::builder()
49///         .system_prompt("You are a helpful assistant")
50///         .model("claude-3-opus-20240229")
51///         .build();
52///
53///     let mut client = ClaudeSDKClient::new(options);
54///
55///     // Connect with initial prompt
56///     client.connect(Some("Hello!".to_string())).await?;
57///
58///     // Receive initial response
59///     let mut messages = client.receive_messages().await;
60///     while let Some(msg) = messages.next().await {
61///         match msg? {
62///             Message::Result { .. } => break,
63///             msg => println!("{:?}", msg),
64///         }
65///     }
66///
67///     // Send follow-up
68///     client.send_request("What's 2 + 2?".to_string(), None).await?;
69///
70///     // Receive response
71///     let mut messages = client.receive_messages().await;
72///     while let Some(msg) = messages.next().await {
73///         println!("{:?}", msg?);
74///     }
75///
76///     // Disconnect
77///     client.disconnect().await?;
78///
79///     Ok(())
80/// }
81/// ```
82pub struct ClaudeSDKClient {
83    /// Configuration options
84    #[allow(dead_code)]
85    options: ClaudeCodeOptions,
86    /// Transport layer
87    transport: Arc<Mutex<SubprocessTransport>>,
88    /// Client state
89    state: Arc<RwLock<ClientState>>,
90    /// Active sessions
91    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
92    /// Message sender for current receiver
93    message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
94    /// Message buffer for multiple receivers
95    message_buffer: Arc<Mutex<Vec<Message>>>,
96    /// Request counter
97    request_counter: Arc<Mutex<u64>>,
98}
99
/// Bookkeeping record kept for a single conversation session.
// The fields are written when a message is sent but never inspected in the
// code visible here, hence the lint exemption.
#[allow(dead_code)]
struct SessionData {
    /// Identifier the session is stored under.
    id: String,
    /// How many user messages have been counted for this session.
    message_count: usize,
    /// Moment the session record was first created.
    created_at: std::time::Instant,
}
110
111impl ClaudeSDKClient {
112    /// Create a new client with the given options
113    pub fn new(options: ClaudeCodeOptions) -> Self {
114        // Set environment variable to indicate SDK usage
115        unsafe {
116            std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
117        }
118
119        let transport = match SubprocessTransport::new(options.clone()) {
120            Ok(t) => t,
121            Err(e) => {
122                error!("Failed to create transport: {}", e);
123                // Create with empty path, will fail on connect
124                SubprocessTransport::with_cli_path(options.clone(), "")
125            }
126        };
127
128        Self {
129            options,
130            transport: Arc::new(Mutex::new(transport)),
131            state: Arc::new(RwLock::new(ClientState::Disconnected)),
132            sessions: Arc::new(RwLock::new(HashMap::new())),
133            message_tx: Arc::new(Mutex::new(None)),
134            message_buffer: Arc::new(Mutex::new(Vec::new())),
135            request_counter: Arc::new(Mutex::new(0)),
136        }
137    }
138
139    /// Connect to Claude CLI with an optional initial prompt
140    pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
141        // Check if already connected
142        {
143            let state = self.state.read().await;
144            if *state == ClientState::Connected {
145                return Ok(());
146            }
147        }
148
149        // Connect transport
150        {
151            let mut transport = self.transport.lock().await;
152            transport.connect().await?;
153        }
154
155        // Update state
156        {
157            let mut state = self.state.write().await;
158            *state = ClientState::Connected;
159        }
160
161        info!("Connected to Claude CLI");
162
163        // Start message receiver task
164        self.start_message_receiver().await;
165
166        // Send initial prompt if provided
167        if let Some(prompt) = initial_prompt {
168            self.send_request(prompt, None).await?;
169        }
170
171        Ok(())
172    }
173
174    /// Send a user message to Claude
175    pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
176        // Check connection
177        {
178            let state = self.state.read().await;
179            if *state != ClientState::Connected {
180                return Err(SdkError::InvalidState {
181                    message: "Not connected".into(),
182                });
183            }
184        }
185
186        // Use default session ID
187        let session_id = "default".to_string();
188
189        // Update session data
190        {
191            let mut sessions = self.sessions.write().await;
192            let session = sessions.entry(session_id.clone()).or_insert_with(|| {
193                debug!("Creating new session: {}", session_id);
194                SessionData {
195                    id: session_id.clone(),
196                    message_count: 0,
197                    created_at: std::time::Instant::now(),
198                }
199            });
200            session.message_count += 1;
201        }
202
203        // Create and send message
204        let message = InputMessage::user(prompt, session_id.clone());
205
206        {
207            let mut transport = self.transport.lock().await;
208            transport.send_message(message).await?;
209        }
210
211        debug!("Sent request to Claude");
212        Ok(())
213    }
214
215    /// Send a request to Claude (alias for send_user_message with optional session_id)
216    pub async fn send_request(
217        &mut self,
218        prompt: String,
219        _session_id: Option<String>,
220    ) -> Result<()> {
221        // For now, ignore session_id and use send_user_message
222        self.send_user_message(prompt).await
223    }
224
225    /// Receive messages from Claude
226    ///
227    /// Returns a stream of messages. The stream will end when a Result message
228    /// is received or the connection is closed.
229    pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
230        // Create a new channel for this receiver
231        let (tx, rx) = mpsc::channel(100);
232
233        // Get buffered messages and clear buffer
234        let buffered_messages = {
235            let mut buffer = self.message_buffer.lock().await;
236            std::mem::take(&mut *buffer)
237        };
238
239        // Send buffered messages to the new receiver
240        let tx_clone = tx.clone();
241        tokio::spawn(async move {
242            for msg in buffered_messages {
243                if tx_clone.send(Ok(msg)).await.is_err() {
244                    break;
245                }
246            }
247        });
248
249        // Store the sender for the message receiver task
250        {
251            let mut message_tx = self.message_tx.lock().await;
252            *message_tx = Some(tx);
253        }
254
255        ReceiverStream::new(rx)
256    }
257
258    /// Send an interrupt request
259    pub async fn interrupt(&mut self) -> Result<()> {
260        // Check connection
261        {
262            let state = self.state.read().await;
263            if *state != ClientState::Connected {
264                return Err(SdkError::InvalidState {
265                    message: "Not connected".into(),
266                });
267            }
268        }
269
270        // Generate request ID
271        let request_id = {
272            let mut counter = self.request_counter.lock().await;
273            *counter += 1;
274            format!("interrupt_{}", *counter)
275        };
276
277        // Send interrupt request
278        let request = ControlRequest::Interrupt {
279            request_id: request_id.clone(),
280        };
281
282        {
283            let mut transport = self.transport.lock().await;
284            transport.send_control_request(request).await?;
285        }
286
287        info!("Sent interrupt request: {}", request_id);
288
289        // Wait for acknowledgment (with timeout)
290        let transport = self.transport.clone();
291        let ack_task = tokio::spawn(async move {
292            let mut transport = transport.lock().await;
293            match tokio::time::timeout(
294                std::time::Duration::from_secs(5),
295                transport.receive_control_response(),
296            )
297            .await
298            {
299                Ok(Ok(Some(ControlResponse::InterruptAck {
300                    request_id: ack_id,
301                    success,
302                }))) => {
303                    if ack_id == request_id && success {
304                        Ok(())
305                    } else {
306                        Err(SdkError::ControlRequestError(
307                            "Interrupt not acknowledged successfully".into(),
308                        ))
309                    }
310                }
311                Ok(Ok(None)) => Err(SdkError::ControlRequestError(
312                    "No interrupt acknowledgment received".into(),
313                )),
314                Ok(Err(e)) => Err(e),
315                Err(_) => Err(SdkError::timeout(5)),
316            }
317        });
318
319        ack_task
320            .await
321            .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
322    }
323
324    /// Check if the client is connected
325    pub async fn is_connected(&self) -> bool {
326        let state = self.state.read().await;
327        *state == ClientState::Connected
328    }
329
330    /// Get active session IDs
331    pub async fn get_sessions(&self) -> Vec<String> {
332        let sessions = self.sessions.read().await;
333        sessions.keys().cloned().collect()
334    }
335
336    /// Disconnect from Claude CLI
337    pub async fn disconnect(&mut self) -> Result<()> {
338        // Check if already disconnected
339        {
340            let state = self.state.read().await;
341            if *state == ClientState::Disconnected {
342                return Ok(());
343            }
344        }
345
346        // Disconnect transport
347        {
348            let mut transport = self.transport.lock().await;
349            transport.disconnect().await?;
350        }
351
352        // Update state
353        {
354            let mut state = self.state.write().await;
355            *state = ClientState::Disconnected;
356        }
357
358        // Clear sessions
359        {
360            let mut sessions = self.sessions.write().await;
361            sessions.clear();
362        }
363
364        info!("Disconnected from Claude CLI");
365        Ok(())
366    }
367
368    /// Start the message receiver task
369    async fn start_message_receiver(&mut self) {
370        let transport = self.transport.clone();
371        let message_tx = self.message_tx.clone();
372        let message_buffer = self.message_buffer.clone();
373        let state = self.state.clone();
374
375        tokio::spawn(async move {
376            let mut transport = transport.lock().await;
377            let mut stream = transport.receive_messages();
378
379            while let Some(result) = stream.next().await {
380                match result {
381                    Ok(message) => {
382                        // Try to send to current receiver
383                        let sent = {
384                            let mut tx_opt = message_tx.lock().await;
385                            if let Some(tx) = tx_opt.as_mut() {
386                                tx.send(Ok(message.clone())).await.is_ok()
387                            } else {
388                                false
389                            }
390                        };
391
392                        // If no receiver or send failed, buffer the message
393                        if !sent {
394                            let mut buffer = message_buffer.lock().await;
395                            buffer.push(message);
396                        }
397                    }
398                    Err(e) => {
399                        error!("Error receiving message: {}", e);
400
401                        // Send error to receiver if available
402                        let mut tx_opt = message_tx.lock().await;
403                        if let Some(tx) = tx_opt.as_mut() {
404                            let _ = tx.send(Err(e)).await;
405                        }
406
407                        // Update state on error
408                        let mut state = state.write().await;
409                        *state = ClientState::Error;
410                        break;
411                    }
412                }
413            }
414
415            debug!("Message receiver task ended");
416        });
417    }
418}
419
420impl Drop for ClaudeSDKClient {
421    fn drop(&mut self) {
422        // Try to disconnect gracefully
423        let transport = self.transport.clone();
424        let state = self.state.clone();
425
426        tokio::spawn(async move {
427            let state = state.read().await;
428            if *state == ClientState::Connected {
429                let mut transport = transport.lock().await;
430                if let Err(e) = transport.disconnect().await {
431                    debug!("Error disconnecting in drop: {}", e);
432                }
433            }
434        });
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed client is not connected and tracks no sessions.
    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let connected = client.is_connected().await;
        assert!(!connected);

        let session_ids = client.get_sessions().await;
        assert_eq!(session_ids.len(), 0);
    }

    /// The internal state of a new client starts out as `Disconnected`.
    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let current = client.state.read().await;
        assert_eq!(*current, ClientState::Disconnected);
    }
}
459}