1use crate::{
7 errors::{Result, SdkError},
8 internal_query::Query,
9 token_tracker::BudgetManager,
10 transport::{InputMessage, SubprocessTransport, Transport},
11 types::{ClaudeCodeOptions, ControlRequest, ControlResponse, Message},
12};
13use futures::stream::{Stream, StreamExt};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::pin::Pin;
17use tokio::sync::{Mutex, RwLock, mpsc};
18use tokio_stream::wrappers::ReceiverStream;
19use tracing::{debug, error, info};
20
/// Connection state tracked by [`ClaudeSDKClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Initial state of a new client, and the state after a successful `disconnect`.
    Disconnected,
    /// Set by `connect` once the transport (and query handler, if any) is up.
    Connected,
    /// Set by the background receiver task when the transport stream yields an error.
    Error,
}
31
/// Stateful client that talks to the CLI through a shared [`Transport`].
pub struct ClaudeSDKClient {
    /// Options the client was built with; kept but not read by the code in this file.
    #[allow(dead_code)]
    options: ClaudeCodeOptions,
    /// Transport shared between the client and its background tasks.
    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
    /// Control-protocol handler; only created when a tool-permission callback,
    /// hooks, or MCP servers are configured.
    query_handler: Option<Arc<Mutex<Query>>>,
    /// Current connection state.
    state: Arc<RwLock<ClientState>>,
    /// Per-session bookkeeping, keyed by session id.
    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
    /// Sender feeding the most recently requested receive stream, if any.
    message_tx: Arc<Mutex<Option<mpsc::Sender<Result<Message>>>>>,
    /// Messages that could not be delivered to a stream, plus `init` system messages.
    message_buffer: Arc<Mutex<Vec<Message>>>,
    /// Monotonic counter used to build interrupt request ids.
    request_counter: Arc<Mutex<u64>>,
    /// Tracks token usage and cost reported by result messages.
    budget_manager: BudgetManager,
}
106
/// Bookkeeping kept for each session id that has sent a user message.
#[allow(dead_code)]
struct SessionData {
    /// Session identifier (same value as the map key it is stored under).
    id: String,
    /// Number of user messages sent on this session.
    message_count: usize,
    /// When the session entry was first created.
    created_at: std::time::Instant,
}
117
118impl ClaudeSDKClient {
119 pub fn new(options: ClaudeCodeOptions) -> Self {
121 unsafe {
123 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
124 }
125
126 let transport = match SubprocessTransport::new(options.clone()) {
127 Ok(t) => t,
128 Err(e) => {
129 error!("Failed to create transport: {}", e);
130 SubprocessTransport::with_cli_path(options.clone(), "")
132 }
133 };
134
135 let transport_arc: Arc<Mutex<Box<dyn Transport + Send>>> =
137 Arc::new(Mutex::new(Box::new(transport)));
138
139 let query_handler = if options.can_use_tool.is_some()
141 || options.hooks.is_some()
142 || !options.mcp_servers.is_empty() {
143 let sdk_mcp_servers: HashMap<String, Arc<dyn std::any::Any + Send + Sync>> = options.mcp_servers
145 .iter()
146 .filter_map(|(k, v)| {
147 if let crate::types::McpServerConfig::Sdk { name: _, instance } = v {
149 Some((k.clone(), instance.clone()))
150 } else {
151 None
152 }
153 })
154 .collect();
155
156 let is_streaming = options.can_use_tool.is_some()
158 || options.hooks.is_some()
159 || !sdk_mcp_servers.is_empty();
160
161 let query = Query::new(
162 transport_arc.clone(), is_streaming, options.can_use_tool.clone(),
165 options.hooks.clone(),
166 sdk_mcp_servers,
167 );
168 Some(Arc::new(Mutex::new(query)))
169 } else {
170 None
171 };
172
173 Self {
174 options,
175 transport: transport_arc,
176 query_handler,
177 state: Arc::new(RwLock::new(ClientState::Disconnected)),
178 sessions: Arc::new(RwLock::new(HashMap::new())),
179 message_tx: Arc::new(Mutex::new(None)),
180 message_buffer: Arc::new(Mutex::new(Vec::new())),
181 request_counter: Arc::new(Mutex::new(0)),
182 budget_manager: BudgetManager::new(),
183 }
184 }
185
186 pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
188 {
190 let state = self.state.read().await;
191 if *state == ClientState::Connected {
192 return Ok(());
193 }
194 }
195
196 {
198 let mut transport = self.transport.lock().await;
199 transport.connect().await?;
200 }
201
202 if let Some(ref query_handler) = self.query_handler {
204 let mut handler = query_handler.lock().await;
205 handler.start().await?;
206 handler.initialize().await?;
207 info!("Initialized SDK control protocol");
208 }
209
210 {
212 let mut state = self.state.write().await;
213 *state = ClientState::Connected;
214 }
215
216 info!("Connected to Claude CLI");
217
218 self.start_message_receiver().await;
220
221 if let Some(prompt) = initial_prompt {
223 self.send_request(prompt, None).await?;
224 }
225
226 Ok(())
227 }
228
229 pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
231 {
233 let state = self.state.read().await;
234 if *state != ClientState::Connected {
235 return Err(SdkError::InvalidState {
236 message: "Not connected".into(),
237 });
238 }
239 }
240
241 let session_id = "default".to_string();
243
244 {
246 let mut sessions = self.sessions.write().await;
247 let session = sessions.entry(session_id.clone()).or_insert_with(|| {
248 debug!("Creating new session: {}", session_id);
249 SessionData {
250 id: session_id.clone(),
251 message_count: 0,
252 created_at: std::time::Instant::now(),
253 }
254 });
255 session.message_count += 1;
256 }
257
258 let message = InputMessage::user(prompt, session_id.clone());
260
261 {
262 let mut transport = self.transport.lock().await;
263 transport.send_message(message).await?;
264 }
265
266 debug!("Sent request to Claude");
267 Ok(())
268 }
269
    /// Sends `prompt` as a user message by delegating to `send_user_message`.
    ///
    /// `_session_id` is accepted but not used by this method.
    ///
    /// # Errors
    /// Returns whatever `send_user_message` returns.
    pub async fn send_request(
        &mut self,
        prompt: String,
        _session_id: Option<String>,
    ) -> Result<()> {
        self.send_user_message(prompt).await
    }
279
280 pub async fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + use<> {
285 let (tx, rx) = mpsc::channel(100);
289
290 let buffered_messages = {
292 let mut buffer = self.message_buffer.lock().await;
293 std::mem::take(&mut *buffer)
294 };
295
296 let tx_clone = tx.clone();
298 tokio::spawn(async move {
299 for msg in buffered_messages {
300 if tx_clone.send(Ok(msg)).await.is_err() {
301 break;
302 }
303 }
304 });
305
306 {
308 let mut message_tx = self.message_tx.lock().await;
309 *message_tx = Some(tx);
310 }
311
312 ReceiverStream::new(rx)
313 }
314
315 pub async fn interrupt(&mut self) -> Result<()> {
317 {
319 let state = self.state.read().await;
320 if *state != ClientState::Connected {
321 return Err(SdkError::InvalidState {
322 message: "Not connected".into(),
323 });
324 }
325 }
326
327 if let Some(ref query_handler) = self.query_handler {
329 let mut handler = query_handler.lock().await;
330 return handler.interrupt().await;
331 }
332
333 let request_id = {
336 let mut counter = self.request_counter.lock().await;
337 *counter += 1;
338 format!("interrupt_{}", *counter)
339 };
340
341 let request = ControlRequest::Interrupt {
343 request_id: request_id.clone(),
344 };
345
346 {
347 let mut transport = self.transport.lock().await;
348 transport.send_control_request(request).await?;
349 }
350
351 info!("Sent interrupt request: {}", request_id);
352
353 let transport = self.transport.clone();
355 let ack_task = tokio::spawn(async move {
356 let mut transport = transport.lock().await;
357 match tokio::time::timeout(
358 std::time::Duration::from_secs(5),
359 transport.receive_control_response(),
360 )
361 .await
362 {
363 Ok(Ok(Some(ControlResponse::InterruptAck {
364 request_id: ack_id,
365 success,
366 }))) => {
367 if ack_id == request_id && success {
368 Ok(())
369 } else {
370 Err(SdkError::ControlRequestError(
371 "Interrupt not acknowledged successfully".into(),
372 ))
373 }
374 }
375 Ok(Ok(None)) => Err(SdkError::ControlRequestError(
376 "No interrupt acknowledgment received".into(),
377 )),
378 Ok(Err(e)) => Err(e),
379 Err(_) => Err(SdkError::timeout(5)),
380 }
381 });
382
383 ack_task
384 .await
385 .map_err(|_| SdkError::ControlRequestError("Interrupt task panicked".into()))?
386 }
387
388 pub async fn is_connected(&self) -> bool {
390 let state = self.state.read().await;
391 *state == ClientState::Connected
392 }
393
394 pub async fn get_sessions(&self) -> Vec<String> {
396 let sessions = self.sessions.read().await;
397 sessions.keys().cloned().collect()
398 }
399
400 pub async fn receive_response(&mut self) -> Pin<Box<dyn Stream<Item = Result<Message>> + Send + '_>> {
405 let mut messages = self.receive_messages().await;
406
407 Box::pin(async_stream::stream! {
409 while let Some(msg_result) = messages.next().await {
410 match &msg_result {
411 Ok(Message::Result { .. }) => {
412 yield msg_result;
413 return;
414 }
415 _ => {
416 yield msg_result;
417 }
418 }
419 }
420 })
421 }
422
423 pub async fn get_server_info(&self) -> Option<serde_json::Value> {
430 if let Some(ref query_handler) = self.query_handler {
432 let handler = query_handler.lock().await;
433 if let Some(init_result) = handler.get_initialization_result() {
434 return Some(init_result.clone());
435 }
436 }
437
438 let buffer = self.message_buffer.lock().await;
440 for msg in buffer.iter() {
441 if let Message::System { subtype, data } = msg {
442 if subtype == "init" {
443 return Some(data.clone());
444 }
445 }
446 }
447 None
448 }
449
450 pub async fn query(&mut self, prompt: String, session_id: Option<String>) -> Result<()> {
454 let session_id = session_id.unwrap_or_else(|| "default".to_string());
455
456 let message = InputMessage::user(prompt, session_id);
458
459 {
460 let mut transport = self.transport.lock().await;
461 transport.send_message(message).await?;
462 }
463
464 Ok(())
465 }
466
467 pub async fn disconnect(&mut self) -> Result<()> {
469 {
471 let state = self.state.read().await;
472 if *state == ClientState::Disconnected {
473 return Ok(());
474 }
475 }
476
477 {
479 let mut transport = self.transport.lock().await;
480 transport.disconnect().await?;
481 }
482
483 {
485 let mut state = self.state.write().await;
486 *state = ClientState::Disconnected;
487 }
488
489 {
491 let mut sessions = self.sessions.write().await;
492 sessions.clear();
493 }
494
495 info!("Disconnected from Claude CLI");
496 Ok(())
497 }
498
499 async fn start_message_receiver(&mut self) {
501 let transport = self.transport.clone();
502 let message_tx = self.message_tx.clone();
503 let message_buffer = self.message_buffer.clone();
504 let state = self.state.clone();
505 let budget_manager = self.budget_manager.clone();
506
507 tokio::spawn(async move {
508 let mut stream = {
510 let mut transport = transport.lock().await;
511 transport.receive_messages()
512 }; while let Some(result) = stream.next().await {
515 match result {
516 Ok(message) => {
517 if let Message::Result { .. } = &message {
519 if let Message::Result { usage, total_cost_usd, .. } = &message {
520 let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
521 let input = usage_json.get("input_tokens")
522 .and_then(|v| v.as_u64())
523 .unwrap_or(0);
524 let output = usage_json.get("output_tokens")
525 .and_then(|v| v.as_u64())
526 .unwrap_or(0);
527 (input, output)
528 } else {
529 (0, 0)
530 };
531 let cost = total_cost_usd.unwrap_or(0.0);
532 budget_manager.update_usage(input_tokens, output_tokens, cost).await;
533 }
534 }
535
536 if let Message::System { subtype, .. } = &message {
538 if subtype == "init" {
539 let mut buffer = message_buffer.lock().await;
540 buffer.push(message.clone());
541 }
542 }
543
544 let sent = {
546 let mut tx_opt = message_tx.lock().await;
547 if let Some(tx) = tx_opt.as_mut() {
548 tx.send(Ok(message.clone())).await.is_ok()
549 } else {
550 false
551 }
552 };
553
554 if !sent {
556 let mut buffer = message_buffer.lock().await;
557 buffer.push(message);
558 }
559 }
560 Err(e) => {
561 error!("Error receiving message: {}", e);
562
563 let mut tx_opt = message_tx.lock().await;
565 if let Some(tx) = tx_opt.as_mut() {
566 let _ = tx.send(Err(e)).await;
567 }
568
569 let mut state = state.write().await;
571 *state = ClientState::Error;
572 break;
573 }
574 }
575 }
576
577 debug!("Message receiver task ended");
578 });
579 }
580
    /// Returns the usage tracker reported by the budget manager's `get_usage`.
    pub async fn get_usage_stats(&self) -> crate::token_tracker::TokenUsageTracker {
        self.budget_manager.get_usage().await
    }
588
    /// Sets the budget limit on the budget manager.
    ///
    /// When `on_warning` is `Some`, it is also installed as the manager's
    /// warning callback; when `None`, the existing callback is left as is.
    pub async fn set_budget_limit(
        &self,
        limit: crate::token_tracker::BudgetLimit,
        on_warning: Option<crate::token_tracker::BudgetWarningCallback>,
    ) {
        self.budget_manager.set_limit(limit).await;
        if let Some(callback) = on_warning {
            self.budget_manager.set_warning_callback(callback).await;
        }
    }
621
    /// Clears the budget limit by calling the budget manager's `clear_limit`.
    pub async fn clear_budget_limit(&self) {
        self.budget_manager.clear_limit().await;
    }
626
    /// Resets accumulated usage by calling the budget manager's `reset_usage`.
    pub async fn reset_usage_stats(&self) {
        self.budget_manager.reset_usage().await;
    }
634
    /// Returns the budget manager's `is_exceeded` result.
    pub async fn is_budget_exceeded(&self) -> bool {
        self.budget_manager.is_exceeded().await
    }
641
642 }
644
645impl Drop for ClaudeSDKClient {
646 fn drop(&mut self) {
647 let transport = self.transport.clone();
649 let state = self.state.clone();
650
651 tokio::spawn(async move {
652 let state = state.read().await;
653 if *state == ClientState::Connected {
654 let mut transport = transport.lock().await;
655 if let Err(e) = transport.disconnect().await {
656 debug!("Error disconnecting in drop: {}", e);
657 }
658 }
659 });
660 }
661}
662
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client is not connected and tracks no sessions.
    #[tokio::test]
    async fn test_client_lifecycle() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let connected = client.is_connected().await;
        assert!(!connected);

        let sessions = client.get_sessions().await;
        assert_eq!(sessions.len(), 0);
    }

    /// The internal state starts out as `Disconnected`.
    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = ClaudeSDKClient::new(ClaudeCodeOptions::default());

        let current = *client.state.read().await;
        assert_eq!(current, ClientState::Disconnected);
    }
}