Skip to main content

cc_sdk/
interactive.rs

1//! Working interactive client implementation
2
3use crate::{
4    errors::{Result, SdkError},
5    transport::{InputMessage, SubprocessTransport, Transport},
6    types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use futures::{Stream, StreamExt};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tokio_stream::wrappers::ReceiverStream;
12use tracing::{debug, info};
13
14/// Interactive client for stateful conversations with Claude
15///
16/// This is the recommended client for interactive use. It provides a clean API
17/// that matches the Python SDK's functionality.
18pub struct InteractiveClient {
19    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
20    connected: bool,
21}
22
23impl InteractiveClient {
24    /// Create a new client
25    pub fn new(options: ClaudeCodeOptions) -> Result<Self> {
26        let transport: Box<dyn Transport + Send> = Box::new(SubprocessTransport::new(options)?);
27        Ok(Self {
28            transport: Arc::new(Mutex::new(transport)),
29            connected: false,
30        })
31    }
32
33    /// Connect to Claude
34    pub async fn connect(&mut self) -> Result<()> {
35        if self.connected {
36            return Ok(());
37        }
38
39        let mut transport = self.transport.lock().await;
40        transport.connect().await?;
41        drop(transport); // Release lock immediately
42
43        self.connected = true;
44        info!("Connected to Claude CLI");
45        Ok(())
46    }
47
48    /// Send a message and receive all messages until Result message
49    pub async fn send_and_receive(&mut self, prompt: String) -> Result<Vec<Message>> {
50        if !self.connected {
51            return Err(SdkError::InvalidState {
52                message: "Not connected".into(),
53            });
54        }
55
56        // Send message
57        {
58            let mut transport = self.transport.lock().await;
59            let message = InputMessage::user(prompt, "default".to_string());
60            transport.send_message(message).await?;
61        } // Lock released here
62
63        debug!("Message sent, waiting for response");
64
65        // Receive messages
66        let mut messages = Vec::new();
67        loop {
68            // Try to get a message
69            let msg_result = {
70                let mut transport = self.transport.lock().await;
71                let mut stream = transport.receive_messages();
72                stream.next().await
73            }; // Lock released here
74
75            // Process the message
76            if let Some(result) = msg_result {
77                match result {
78                    Ok(msg) => {
79                        debug!("Received: {:?}", msg);
80                        let is_result = matches!(msg, Message::Result { .. });
81                        messages.push(msg);
82                        if is_result {
83                            break;
84                        }
85                    }
86                    Err(e) => return Err(e),
87                }
88            } else {
89                // No more messages, wait a bit
90                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
91            }
92        }
93
94        Ok(messages)
95    }
96
97    /// Send a message without waiting for response
98    pub async fn send_message(&mut self, prompt: String) -> Result<()> {
99        if !self.connected {
100            return Err(SdkError::InvalidState {
101                message: "Not connected".into(),
102            });
103        }
104
105        let mut transport = self.transport.lock().await;
106        let message = InputMessage::user(prompt, "default".to_string());
107        transport.send_message(message).await?;
108        drop(transport);
109
110        debug!("Message sent");
111        Ok(())
112    }
113
114    /// Receive messages until Result message (convenience method like Python SDK)
115    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
116        if !self.connected {
117            return Err(SdkError::InvalidState {
118                message: "Not connected".into(),
119            });
120        }
121
122        let mut messages = Vec::new();
123        loop {
124            // Try to get a message
125            let msg_result = {
126            let mut transport = self.transport.lock().await;
127            let mut stream = transport.receive_messages();
128                stream.next().await
129            }; // Lock released here
130
131            // Process the message
132            if let Some(result) = msg_result {
133                match result {
134                    Ok(msg) => {
135                        debug!("Received: {:?}", msg);
136                        let is_result = matches!(msg, Message::Result { .. });
137                        messages.push(msg);
138                        if is_result {
139                            break;
140                        }
141                    }
142                    Err(e) => return Err(e),
143                }
144            } else {
145                // No more messages, wait a bit
146                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
147            }
148        }
149
150        Ok(messages)
151    }
152
153    /// Receive messages as a stream (streaming output support)
154    /// 
155    /// Returns a stream of messages that can be iterated over asynchronously.
156    /// This is similar to Python SDK's `receive_messages()` method.
157    /// 
158    /// # Example
159    /// 
160    /// ```rust,no_run
161    /// use cc_sdk::{InteractiveClient, ClaudeCodeOptions};
162    /// use futures::StreamExt;
163    /// 
164    /// #[tokio::main]
165    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
166    ///     let mut client = InteractiveClient::new(ClaudeCodeOptions::default())?;
167    ///     client.connect().await?;
168    ///     
169    ///     // Send a message
170    ///     client.send_message("Hello!".to_string()).await?;
171    ///     
172    ///     // Receive messages as a stream
173    ///     let mut stream = client.receive_messages_stream().await;
174    ///     while let Some(msg) = stream.next().await {
175    ///         match msg {
176    ///             Ok(message) => println!("Received: {:?}", message),
177    ///             Err(e) => eprintln!("Error: {}", e),
178    ///         }
179    ///     }
180    ///     
181    ///     Ok(())
182    /// }
183    /// ```
184    pub async fn receive_messages_stream(&mut self) -> impl Stream<Item = Result<Message>> + '_ {
185        // Create a channel for messages
186        let (tx, rx) = tokio::sync::mpsc::channel(100);
187        let transport = self.transport.clone();
188        
189        // Spawn a task to receive messages from transport
190        tokio::spawn(async move {
191            let mut transport = transport.lock().await;
192            let mut stream = transport.receive_messages();
193            
194            while let Some(result) = stream.next().await {
195                // Send each message through the channel
196                if tx.send(result).await.is_err() {
197                    // Receiver dropped, stop sending
198                    break;
199                }
200            }
201        });
202        
203        // Return the receiver as a stream
204        ReceiverStream::new(rx)
205    }
206
207    /// Receive messages as an async iterator until a Result message
208    /// 
209    /// This is a convenience method that collects messages until a Result message
210    /// is received, similar to Python SDK's `receive_response()`.
211    pub async fn receive_response_stream(&mut self) -> impl Stream<Item = Result<Message>> + '_ {
212        // Create a stream that stops after Result message
213        async_stream::stream! {
214            let mut stream = self.receive_messages_stream().await;
215            
216            while let Some(result) = stream.next().await {
217                match &result {
218                    Ok(msg) => {
219                        let is_result = matches!(msg, Message::Result { .. });
220                        yield result;
221                        if is_result {
222                            break;
223                        }
224                    }
225                    Err(_) => {
226                        yield result;
227                        break;
228                    }
229                }
230            }
231        }
232    }
233
234    /// Send interrupt signal to cancel current operation
235    pub async fn interrupt(&mut self) -> Result<()> {
236        if !self.connected {
237            return Err(SdkError::InvalidState {
238                message: "Not connected".into(),
239            });
240        }
241
242        let mut transport = self.transport.lock().await;
243        let request = ControlRequest::Interrupt {
244            request_id: uuid::Uuid::new_v4().to_string(),
245        };
246        transport.send_control_request(request).await?;
247        drop(transport);
248
249        info!("Interrupt sent");
250        Ok(())
251    }
252
253    /// Get MCP server status for all configured servers
254    ///
255    /// Note: Requires the CLI to support `mcp_status` SDK control messages.
256    /// Returns an empty list if the CLI doesn't support this feature.
257    pub async fn get_mcp_status(&mut self) -> Result<Vec<crate::types::McpServerStatus>> {
258        if !self.connected {
259            return Err(SdkError::InvalidState {
260                message: "Not connected".into(),
261            });
262        }
263        // MCP status requires SDK control protocol support from the CLI.
264        // The transport currently doesn't expose a bidirectional SDK control channel
265        // for this operation. Return empty for now.
266        Ok(vec![])
267    }
268
269    /// Add an MCP server at runtime via SDK control protocol
270    pub async fn add_mcp_server(
271        &mut self,
272        name: &str,
273        config: crate::types::McpServerConfig,
274    ) -> Result<()> {
275        if !self.connected {
276            return Err(SdkError::InvalidState {
277                message: "Not connected".into(),
278            });
279        }
280
281        let config_json = serde_json::to_value(&config)
282            .map_err(|e| SdkError::TransportError(format!("Failed to serialize MCP config: {e}")))?;
283
284        let mcp_msg = crate::types::SDKControlMcpMessageRequest {
285            subtype: "mcp_message".to_string(),
286            mcp_server_name: name.to_string(),
287            message: serde_json::json!({
288                "action": "add",
289                "config": config_json
290            }),
291        };
292
293        let mut transport = self.transport.lock().await;
294        let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
295        let json = serde_json::to_value(&request)
296            .map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
297        let input = crate::transport::InputMessage {
298            r#type: "sdk_control".to_string(),
299            message: json,
300            parent_tool_use_id: None,
301            session_id: String::new(),
302        };
303        transport.send_message(input).await
304    }
305
306    /// Remove an MCP server at runtime
307    pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
308        if !self.connected {
309            return Err(SdkError::InvalidState {
310                message: "Not connected".into(),
311            });
312        }
313
314        let mcp_msg = crate::types::SDKControlMcpMessageRequest {
315            subtype: "mcp_message".to_string(),
316            mcp_server_name: name.to_string(),
317            message: serde_json::json!({ "action": "remove" }),
318        };
319
320        let mut transport = self.transport.lock().await;
321        let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
322        let json = serde_json::to_value(&request)
323            .map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
324        let input = crate::transport::InputMessage {
325            r#type: "sdk_control".to_string(),
326            message: json,
327            parent_tool_use_id: None,
328            session_id: String::new(),
329        };
330        transport.send_message(input).await
331    }
332
333    /// Reconnect an MCP server
334    pub async fn reconnect_mcp_server(&mut self, name: &str) -> Result<()> {
335        if !self.connected {
336            return Err(SdkError::InvalidState {
337                message: "Not connected".into(),
338            });
339        }
340
341        let mcp_msg = crate::types::SDKControlMcpMessageRequest {
342            subtype: "mcp_message".to_string(),
343            mcp_server_name: name.to_string(),
344            message: serde_json::json!({ "action": "reconnect" }),
345        };
346
347        let mut transport = self.transport.lock().await;
348        let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
349        let json = serde_json::to_value(&request)
350            .map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
351        let input = crate::transport::InputMessage {
352            r#type: "sdk_control".to_string(),
353            message: json,
354            parent_tool_use_id: None,
355            session_id: String::new(),
356        };
357        transport.send_message(input).await
358    }
359
360    /// Toggle an MCP server enabled/disabled
361    pub async fn toggle_mcp_server(&mut self, name: &str, enabled: bool) -> Result<()> {
362        if !self.connected {
363            return Err(SdkError::InvalidState {
364                message: "Not connected".into(),
365            });
366        }
367
368        let mcp_msg = crate::types::SDKControlMcpMessageRequest {
369            subtype: "mcp_message".to_string(),
370            mcp_server_name: name.to_string(),
371            message: serde_json::json!({ "action": "toggle", "enabled": enabled }),
372        };
373
374        let mut transport = self.transport.lock().await;
375        let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
376        let json = serde_json::to_value(&request)
377            .map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
378        let input = crate::transport::InputMessage {
379            r#type: "sdk_control".to_string(),
380            message: json,
381            parent_tool_use_id: None,
382            session_id: String::new(),
383        };
384        transport.send_message(input).await
385    }
386
387    /// List available sessions
388    pub async fn list_sessions(
389        &self,
390        directory: Option<&str>,
391        limit: Option<usize>,
392        include_worktrees: bool,
393    ) -> Result<Vec<crate::sessions::SessionInfo>> {
394        crate::sessions::list_sessions(directory, limit, include_worktrees).await
395    }
396
397    /// Get messages from a specific session
398    pub async fn get_session_messages(
399        &self,
400        session_id: &str,
401        directory: Option<&str>,
402        limit: Option<usize>,
403        offset: usize,
404    ) -> Result<Vec<crate::sessions::SessionMessage>> {
405        crate::sessions::get_session_messages(session_id, directory, limit, offset).await
406    }
407
408    /// Rename a session
409    pub async fn rename_session(&self, session_id: &str, title: &str) -> Result<()> {
410        crate::sessions::rename_session(session_id, title).await
411    }
412
413    /// Tag a session
414    pub async fn tag_session(&self, session_id: &str, tag: Option<&str>) -> Result<()> {
415        crate::sessions::tag_session(session_id, tag).await
416    }
417
418    /// Disconnect
419    pub async fn disconnect(&mut self) -> Result<()> {
420        if !self.connected {
421            return Ok(());
422        }
423
424        let mut transport = self.transport.lock().await;
425        transport.disconnect().await?;
426        drop(transport);
427
428        self.connected = false;
429        info!("Disconnected from Claude CLI");
430        Ok(())
431    }
432}