cc_sdk/
client_working.rs

1//! A working interactive client implementation
2
3use crate::{
4    errors::{Result, SdkError},
5    transport::{InputMessage, SubprocessTransport, Transport},
6    types::{ClaudeCodeOptions, Message},
7};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock, mpsc};
11use tracing::{debug, error, info};
12
/// Lifecycle state of the interactive client.
///
/// A client starts out `Disconnected`, becomes `Connected` once `connect`
/// succeeds, and is moved to `Error` by the background reader when receiving
/// from the transport fails.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientState {
    /// No connection is active (initial state, and the state after `disconnect`).
    Disconnected,
    /// A transport is connected and messages can be exchanged.
    Connected,
    /// The background reader stopped after a receive error.
    Error,
}
20
21/// Working interactive client
22pub struct ClaudeSDKClientWorking {
23    /// Configuration options
24    options: ClaudeCodeOptions,
25    /// Transport wrapped in Arc<Mutex<>> for shared access
26    transport: Arc<Mutex<Option<SubprocessTransport>>>,
27    /// Channel to receive messages
28    message_rx: Arc<Mutex<Option<mpsc::Receiver<Message>>>>,
29    /// Client state
30    state: Arc<RwLock<ClientState>>,
31}
32
33impl ClaudeSDKClientWorking {
34    /// Create a new client
35    pub fn new(options: ClaudeCodeOptions) -> Self {
36        unsafe {
37            std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
38        }
39
40        Self {
41            options,
42            transport: Arc::new(Mutex::new(None)),
43            message_rx: Arc::new(Mutex::new(None)),
44            state: Arc::new(RwLock::new(ClientState::Disconnected)),
45        }
46    }
47
48    /// Connect to Claude
49    pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
50        // Check if already connected
51        {
52            let state = self.state.read().await;
53            if *state == ClientState::Connected {
54                return Ok(());
55            }
56        }
57
58        // Create transport
59        let mut new_transport = SubprocessTransport::new(self.options.clone())?;
60        new_transport.connect().await?;
61
62        // Create message channel
63        let (tx, rx) = mpsc::channel::<Message>(100);
64
65        // Store transport
66        {
67            let mut transport = self.transport.lock().await;
68            *transport = Some(new_transport);
69        }
70
71        // Store receiver
72        {
73            let mut message_rx = self.message_rx.lock().await;
74            *message_rx = Some(rx);
75        }
76
77        // Update state
78        {
79            let mut state = self.state.write().await;
80            *state = ClientState::Connected;
81        }
82
83        // Start background task to read messages
84        let transport_clone = self.transport.clone();
85        let state_clone = self.state.clone();
86        let tx_clone = tx.clone();
87
88        tokio::spawn(async move {
89            loop {
90                // Get one message at a time
91                let msg_result = {
92                    let mut transport_guard = transport_clone.lock().await;
93                    if let Some(transport) = transport_guard.as_mut() {
94                        // Get the stream and immediately poll it once
95                        let mut stream = transport.receive_messages();
96                        stream.next().await
97                    } else {
98                        break;
99                    }
100                };
101
102                // Process the message if we got one
103                if let Some(result) = msg_result {
104                    match result {
105                        Ok(msg) => {
106                            debug!("Received message: {:?}", msg);
107                            if tx_clone.send(msg).await.is_err() {
108                                break;
109                            }
110                        }
111                        Err(e) => {
112                            error!("Error receiving message: {}", e);
113                            let mut state = state_clone.write().await;
114                            *state = ClientState::Error;
115                            break;
116                        }
117                    }
118                } else {
119                    // No message available, wait a bit
120                    tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
121                }
122
123                // Stream ended, check if we should reconnect
124                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
125
126                let should_continue = {
127                    let state = state_clone.read().await;
128                    *state == ClientState::Connected
129                };
130
131                if !should_continue {
132                    break;
133                }
134            }
135
136            debug!("Message reader task ended");
137        });
138
139        info!("Connected to Claude CLI");
140
141        // Send initial prompt if provided
142        if let Some(prompt) = initial_prompt {
143            self.send_user_message(prompt).await?;
144        }
145
146        Ok(())
147    }
148
149    /// Send a user message
150    pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
151        // Check connection
152        {
153            let state = self.state.read().await;
154            if *state != ClientState::Connected {
155                return Err(SdkError::InvalidState {
156                    message: "Not connected".into(),
157                });
158            }
159        }
160
161        // Create message
162        let message = InputMessage::user(prompt, "default".to_string());
163
164        // Send message
165        {
166            let mut transport_guard = self.transport.lock().await;
167            if let Some(transport) = transport_guard.as_mut() {
168                transport.send_message(message).await?;
169                debug!("User message sent");
170            } else {
171                return Err(SdkError::InvalidState {
172                    message: "Transport not available".into(),
173                });
174            }
175        }
176
177        Ok(())
178    }
179
180    /// Receive next message
181    pub async fn receive_message(&mut self) -> Result<Option<Message>> {
182        let mut rx_guard = self.message_rx.lock().await;
183        if let Some(rx) = rx_guard.as_mut() {
184            Ok(rx.recv().await)
185        } else {
186            Err(SdkError::InvalidState {
187                message: "Not connected".into(),
188            })
189        }
190    }
191
192    /// Receive all messages until result
193    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
194        let mut messages = Vec::new();
195
196        while let Some(msg) = self.receive_message().await? {
197            let is_result = matches!(msg, Message::Result { .. });
198            messages.push(msg);
199            if is_result {
200                break;
201            }
202        }
203
204        Ok(messages)
205    }
206
207    /// Disconnect
208    pub async fn disconnect(&mut self) -> Result<()> {
209        // Update state
210        {
211            let mut state = self.state.write().await;
212            if *state == ClientState::Disconnected {
213                return Ok(());
214            }
215            *state = ClientState::Disconnected;
216        }
217
218        // Disconnect transport
219        {
220            let mut transport_guard = self.transport.lock().await;
221            if let Some(mut transport) = transport_guard.take() {
222                transport.disconnect().await?;
223            }
224        }
225
226        // Clear receiver
227        {
228            let mut rx_guard = self.message_rx.lock().await;
229            rx_guard.take();
230        }
231
232        info!("Disconnected from Claude CLI");
233        Ok(())
234    }
235
236    /// Check if connected
237    pub async fn is_connected(&self) -> bool {
238        let state = self.state.read().await;
239        *state == ClientState::Connected
240    }
241}