1use crate::{
7 errors::{Result, SdkError},
8 transport::{InputMessage, SubprocessTransport, Transport},
9 types::{ClaudeCodeOptions, ControlRequest, ControlResponse, Message},
10};
11use futures::stream::{Stream, StreamExt};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::{Mutex, RwLock, mpsc};
15use tokio_stream::wrappers::ReceiverStream;
16use tracing::{debug, error, info};
17
/// Connection state of a `ClaudeSDKClient`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// No CLI connection is established (the initial state, and the state after `disconnect`).
    Disconnected,
    /// Connected to the CLI; messages and control requests may be sent.
    Connected,
    /// The background receiver hit an error while reading from the CLI.
    Error,
}
28
29pub struct ClaudeSDKClient {
83 #[allow(dead_code)]
85 options: ClaudeCodeOptions,
86 transport: Arc<Mutex<SubprocessTransport>>,
88 state: Arc<RwLock<ClientState>>,
90 sessions: Arc<RwLock<HashMap<String, SessionData>>>,
92 message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
94 message_buffer: Arc<Mutex<Vec<Message>>>,
96 request_counter: Arc<Mutex<u64>>,
98}
99
/// Bookkeeping kept for each conversation session the client has used.
// The fields are recorded but not read anywhere in this file, hence the allowance.
#[allow(dead_code)]
struct SessionData {
    /// Session identifier (also the key in the sessions map).
    id: String,
    /// Number of user messages submitted on this session.
    message_count: usize,
    /// When this session entry was first created.
    created_at: std::time::Instant,
}
110
111impl ClaudeSDKClient {
112 pub fn new(options: ClaudeCodeOptions) -> Self {
114 unsafe {
116 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
117 }
118
119 let transport = match SubprocessTransport::new(options.clone()) {
120 Ok(t) => t,
121 Err(e) => {
122 error!("Failed to create transport: {}", e);
123 SubprocessTransport::with_cli_path(options.clone(), "")
125 }
126 };
127
128 Self {
129 options,
130 transport: Arc::new(Mutex::new(transport)),
131 state: Arc::new(RwLock::new(ClientState::Disconnected)),
132 sessions: Arc::new(RwLock::new(HashMap::new())),
133 message_tx: Arc::new(Mutex::new(None)),
134 message_buffer: Arc::new(Mutex::new(Vec::new())),
135 request_counter: Arc::new(Mutex::new(0)),
136 }
137 }
138
139 pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
141 {
143 let state = self.state.read().await;
144 if *state == ClientState::Connected {
145 return Ok(());
146 }
147 }
148
149 {
151 let mut transport = self.transport.lock().await;
152 transport.connect().await?;
153 }
154
155 {
157 let mut state = self.state.write().await;
158 *state = ClientState::Connected;
159 }
160
161 info!("Connected to Claude CLI");
162
163 self.start_message_receiver().await;
165
166 if let Some(prompt) = initial_prompt {
168 self.send_request(prompt, None).await?;
169 }
170
171 Ok(())
172 }
173
174 pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
176 {
178 let state = self.state.read().await;
179 if *state != ClientState::Connected {
180 return Err(SdkError::InvalidState {
181 message: "Not connected".into(),
182 });
183 }
184 }
185
186 let session_id = "default".to_string();
188
189 {
191 let mut sessions = self.sessions.write().await;
192 let session = sessions.entry(session_id.clone()).or_insert_with(|| {
193 debug!("Creating new session: {}", session_id);
194 SessionData {
195 id: session_id.clone(),
196 message_count: 0,
197 created_at: std::time::Instant::now(),
198 }
199 });
200 session.message_count += 1;
201 }
202
203 let message = InputMessage::user(prompt, session_id.clone());
205
206 {
207 let mut transport = self.transport.lock().await;
208 transport.send_message(message).await?;
209 }
210
211 debug!("Sent request to Claude");
212 Ok(())
213 }
214
215 pub async fn send_request(
217 &mut self,
218 prompt: String,
219 _session_id: Option<String>,
220 ) -> Result<()> {
221 self.send_user_message(prompt).await
223 }
224
225 pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
230 let (tx, rx) = mpsc::channel(100);
232
233 let buffered_messages = {
235 let mut buffer = self.message_buffer.lock().await;
236 std::mem::take(&mut *buffer)
237 };
238
239 let tx_clone = tx.clone();
241 tokio::spawn(async move {
242 for msg in buffered_messages {
243 if tx_clone.send(Ok(msg)).await.is_err() {
244 break;
245 }
246 }
247 });
248
249 {
251 let mut message_tx = self.message_tx.lock().await;
252 *message_tx = Some(tx);
253 }
254
255 ReceiverStream::new(rx)
256 }
257
258 pub async fn interrupt(&mut self) -> Result<()> {
260 {
262 let state = self.state.read().await;
263 if *state != ClientState::Connected {
264 return Err(SdkError::InvalidState {
265 message: "Not connected".into(),
266 });
267 }
268 }
269
270 let request_id = {
272 let mut counter = self.request_counter.lock().await;
273 *counter += 1;
274 format!("interrupt_{}", *counter)
275 };
276
277 let request = ControlRequest::Interrupt {
279 request_id: request_id.clone(),
280 };
281
282 {
283 let mut transport = self.transport.lock().await;
284 transport.send_control_request(request).await?;
285 }
286
287 info!("Sent interrupt request: {}", request_id);
288
289 let transport = self.transport.clone();
291 let ack_task = tokio::spawn(async move {
292 let mut transport = transport.lock().await;
293 match tokio::time::timeout(
294 std::time::Duration::from_secs(5),
295 transport.receive_control_response(),
296 )
297 .await
298 {
299 Ok(Ok(Some(ControlResponse::InterruptAck {
300 request_id: ack_id,
301 success,
302 }))) => {
303 if ack_id == request_id && success {
304 Ok(())
305 } else {
306 Err(SdkError::ControlRequestError(
307 "Interrupt not acknowledged successfully".into(),
308 ))
309 }
310 }
311 Ok(Ok(None)) => Err(SdkError::ControlRequestError(
312 "No interrupt acknowledgment received".into(),
313 )),
314 Ok(Err(e)) => Err(e),
315 Err(_) => Err(SdkError::timeout(5)),
316 }
317 });
318
319 ack_task
320 .await
321 .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
322 }
323
324 pub async fn is_connected(&self) -> bool {
326 let state = self.state.read().await;
327 *state == ClientState::Connected
328 }
329
330 pub async fn get_sessions(&self) -> Vec<String> {
332 let sessions = self.sessions.read().await;
333 sessions.keys().cloned().collect()
334 }
335
336 pub async fn disconnect(&mut self) -> Result<()> {
338 {
340 let state = self.state.read().await;
341 if *state == ClientState::Disconnected {
342 return Ok(());
343 }
344 }
345
346 {
348 let mut transport = self.transport.lock().await;
349 transport.disconnect().await?;
350 }
351
352 {
354 let mut state = self.state.write().await;
355 *state = ClientState::Disconnected;
356 }
357
358 {
360 let mut sessions = self.sessions.write().await;
361 sessions.clear();
362 }
363
364 info!("Disconnected from Claude CLI");
365 Ok(())
366 }
367
368 async fn start_message_receiver(&mut self) {
370 let transport = self.transport.clone();
371 let message_tx = self.message_tx.clone();
372 let message_buffer = self.message_buffer.clone();
373 let state = self.state.clone();
374
375 tokio::spawn(async move {
376 let mut transport = transport.lock().await;
377 let mut stream = transport.receive_messages();
378
379 while let Some(result) = stream.next().await {
380 match result {
381 Ok(message) => {
382 let sent = {
384 let mut tx_opt = message_tx.lock().await;
385 if let Some(tx) = tx_opt.as_mut() {
386 tx.send(Ok(message.clone())).await.is_ok()
387 } else {
388 false
389 }
390 };
391
392 if !sent {
394 let mut buffer = message_buffer.lock().await;
395 buffer.push(message);
396 }
397 }
398 Err(e) => {
399 error!("Error receiving message: {}", e);
400
401 let mut tx_opt = message_tx.lock().await;
403 if let Some(tx) = tx_opt.as_mut() {
404 let _ = tx.send(Err(e)).await;
405 }
406
407 let mut state = state.write().await;
409 *state = ClientState::Error;
410 break;
411 }
412 }
413 }
414
415 debug!("Message receiver task ended");
416 });
417 }
418}
419
420impl Drop for ClaudeSDKClient {
421 fn drop(&mut self) {
422 let transport = self.transport.clone();
424 let state = self.state.clone();
425
426 tokio::spawn(async move {
427 let state = state.read().await;
428 if *state == ClientState::Connected {
429 let mut transport = transport.lock().await;
430 if let Err(e) = transport.disconnect().await {
431 debug!("Error disconnecting in drop: {}", e);
432 }
433 }
434 });
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client from default options.
    fn fresh_client() -> ClaudeSDKClient {
        ClaudeSDKClient::new(ClaudeCodeOptions::default())
    }

    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = fresh_client();

        let connected = client.is_connected().await;
        let sessions = client.get_sessions().await;

        assert!(!connected);
        assert_eq!(sessions.len(), 0);
    }

    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = fresh_client();

        // `ClientState` is `Copy`, so read it out and release the lock at once.
        let current = *client.state.read().await;
        assert_eq!(current, ClientState::Disconnected);
    }
}