Skip to main content

cc_sdk/
client_working.rs

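//! A working interactive client implementation that talks to the Claude CLI
//! through a subprocess transport and relays incoming messages over a channel.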
2
3use crate::{
4    errors::{Result, SdkError},
5    transport::{InputMessage, SubprocessTransport, Transport},
6    types::{ClaudeCodeOptions, Message},
7};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock, mpsc};
11use tracing::{debug, error, info};
12
/// Connection state of the interactive client.
///
/// The default value is [`ClientState::Disconnected`], which is also the state
/// a freshly constructed client starts in.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum ClientState {
    /// No transport is active; this is the state before `connect` and after
    /// `disconnect`.
    #[default]
    Disconnected,
    /// A transport has been started and messages can be sent and received.
    Connected,
    /// The background reader hit an error while receiving a message.
    Error,
}
20
21/// Working interactive client
22pub struct ClaudeSDKClientWorking {
23    /// Configuration options
24    options: ClaudeCodeOptions,
25    /// Transport wrapped in Arc<Mutex<>> for shared access
26    transport: Arc<Mutex<Option<SubprocessTransport>>>,
27    /// Channel to receive messages
28    message_rx: Arc<Mutex<Option<mpsc::Receiver<Message>>>>,
29    /// Client state
30    state: Arc<RwLock<ClientState>>,
31}
32
33impl ClaudeSDKClientWorking {
34    /// Create a new client
35    pub fn new(options: ClaudeCodeOptions) -> Self {
36        Self {
37            options,
38            transport: Arc::new(Mutex::new(None)),
39            message_rx: Arc::new(Mutex::new(None)),
40            state: Arc::new(RwLock::new(ClientState::Disconnected)),
41        }
42    }
43
44    /// Connect to Claude
45    pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
46        // Check if already connected
47        {
48            let state = self.state.read().await;
49            if *state == ClientState::Connected {
50                return Ok(());
51            }
52        }
53
54        // Create transport
55        let mut new_transport = SubprocessTransport::new(self.options.clone())?;
56        new_transport.connect().await?;
57
58        // Create message channel
59        let (tx, rx) = mpsc::channel::<Message>(100);
60
61        // Store transport
62        {
63            let mut transport = self.transport.lock().await;
64            *transport = Some(new_transport);
65        }
66
67        // Store receiver
68        {
69            let mut message_rx = self.message_rx.lock().await;
70            *message_rx = Some(rx);
71        }
72
73        // Update state
74        {
75            let mut state = self.state.write().await;
76            *state = ClientState::Connected;
77        }
78
79        // Start background task to read messages
80        let transport_clone = self.transport.clone();
81        let state_clone = self.state.clone();
82        let tx_clone = tx.clone();
83
84        tokio::spawn(async move {
85            loop {
86                // Get one message at a time
87                let msg_result = {
88                    let mut transport_guard = transport_clone.lock().await;
89                    if let Some(transport) = transport_guard.as_mut() {
90                        // Get the stream and immediately poll it once
91                        let mut stream = transport.receive_messages();
92                        stream.next().await
93                    } else {
94                        break;
95                    }
96                };
97
98                // Process the message if we got one
99                if let Some(result) = msg_result {
100                    match result {
101                        Ok(msg) => {
102                            debug!("Received message: {:?}", msg);
103                            if tx_clone.send(msg).await.is_err() {
104                                break;
105                            }
106                        }
107                        Err(e) => {
108                            error!("Error receiving message: {}", e);
109                            let mut state = state_clone.write().await;
110                            *state = ClientState::Error;
111                            break;
112                        }
113                    }
114                } else {
115                    // No message available, wait a bit
116                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
117                }
118
119                // Stream ended, check if we should reconnect
120                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
121
122                let should_continue = {
123                    let state = state_clone.read().await;
124                    *state == ClientState::Connected
125                };
126
127                if !should_continue {
128                    break;
129                }
130            }
131
132            debug!("Message reader task ended");
133        });
134
135        info!("Connected to Claude CLI");
136
137        // Send initial prompt if provided
138        if let Some(prompt) = initial_prompt {
139            self.send_user_message(prompt).await?;
140        }
141
142        Ok(())
143    }
144
145    /// Send a user message
146    pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
147        // Check connection
148        {
149            let state = self.state.read().await;
150            if *state != ClientState::Connected {
151                return Err(SdkError::InvalidState {
152                    message: "Not connected".into(),
153                });
154            }
155        }
156
157        // Create message
158        let message = InputMessage::user(prompt, "default".to_string());
159
160        // Send message
161        {
162            let mut transport_guard = self.transport.lock().await;
163            if let Some(transport) = transport_guard.as_mut() {
164                transport.send_message(message).await?;
165                debug!("User message sent");
166            } else {
167                return Err(SdkError::InvalidState {
168                    message: "Transport not available".into(),
169                });
170            }
171        }
172
173        Ok(())
174    }
175
176    /// Receive next message
177    pub async fn receive_message(&mut self) -> Result<Option<Message>> {
178        let mut rx_guard = self.message_rx.lock().await;
179        if let Some(rx) = rx_guard.as_mut() {
180            Ok(rx.recv().await)
181        } else {
182            Err(SdkError::InvalidState {
183                message: "Not connected".into(),
184            })
185        }
186    }
187
188    /// Receive all messages until result
189    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
190        let mut messages = Vec::new();
191
192        while let Some(msg) = self.receive_message().await? {
193            let is_result = matches!(msg, Message::Result { .. });
194            messages.push(msg);
195            if is_result {
196                break;
197            }
198        }
199
200        Ok(messages)
201    }
202
203    /// Disconnect
204    pub async fn disconnect(&mut self) -> Result<()> {
205        // Update state
206        {
207            let mut state = self.state.write().await;
208            if *state == ClientState::Disconnected {
209                return Ok(());
210            }
211            *state = ClientState::Disconnected;
212        }
213
214        // Disconnect transport
215        {
216            let mut transport_guard = self.transport.lock().await;
217            if let Some(mut transport) = transport_guard.take() {
218                transport.disconnect().await?;
219            }
220        }
221
222        // Clear receiver
223        {
224            let mut rx_guard = self.message_rx.lock().await;
225            rx_guard.take();
226        }
227
228        info!("Disconnected from Claude CLI");
229        Ok(())
230    }
231
232    /// Check if connected
233    pub async fn is_connected(&self) -> bool {
234        let state = self.state.read().await;
235        *state == ClientState::Connected
236    }
237}