1use crate::{
4 errors::{Result, SdkError},
5 transport::{InputMessage, SubprocessTransport, Transport},
6 types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use crate::token_tracker::{BudgetLimit, BudgetManager, BudgetWarningCallback, TokenUsageTracker};
9use futures::stream::StreamExt;
10use std::collections::VecDeque;
11use std::sync::Arc;
12use tokio::sync::{RwLock, Semaphore, mpsc};
13use tokio::time::{Duration, timeout};
14use tracing::{debug, error, info, warn};
15
/// Operating mode of an [`OptimizedClient`].
///
/// The mode decides which entry points are usable (`start_interactive_session`
/// requires `Interactive`, `process_batch` requires `Batch`) and how large the
/// connection pool is.
#[derive(Debug, Clone, Copy)]
pub enum ClientMode {
    /// Single request/response queries.
    OneShot,
    /// A long-lived session driven through the `*_interactive` methods.
    Interactive,
    /// Several prompts processed concurrently, at most `max_concurrent` at a time.
    Batch { max_concurrent: usize },
}
29
/// Pool of reusable transports shared by all clones of a client.
struct ConnectionPool {
    /// Transports that were handed back via `release` and are waiting for reuse.
    idle_connections: Arc<RwLock<VecDeque<Box<dyn Transport + Send>>>>,
    /// Upper bound on the number of idle transports kept in the queue; also the
    /// number of permits the semaphore is created with.
    max_connections: usize,
    /// A permit is taken from this semaphore while `acquire` builds a new transport.
    connection_semaphore: Arc<Semaphore>,
    /// Options cloned into every newly created `SubprocessTransport`.
    base_options: ClaudeCodeOptions,
}
41
42impl ConnectionPool {
43 fn new(base_options: ClaudeCodeOptions, max_connections: usize) -> Self {
44 Self {
45 idle_connections: Arc::new(RwLock::new(VecDeque::new())),
46 max_connections,
47 connection_semaphore: Arc::new(Semaphore::new(max_connections)),
48 base_options,
49 }
50 }
51
52 async fn acquire(&self) -> Result<Box<dyn Transport + Send>> {
53 {
55 let mut idle = self.idle_connections.write().await;
56 if let Some(transport) = idle.pop_front() {
57 if transport.is_connected() {
59 debug!("Reusing existing connection from pool");
60 return Ok(transport);
61 }
62 }
63 }
64
65 let _permit =
67 self.connection_semaphore
68 .acquire()
69 .await
70 .map_err(|_| SdkError::InvalidState {
71 message: "Failed to acquire connection permit".into(),
72 })?;
73
74 let mut transport: Box<dyn Transport + Send> =
75 Box::new(SubprocessTransport::new(self.base_options.clone())?);
76 transport.connect().await?;
77 debug!("Created new connection");
78 Ok(transport)
79 }
80
81 async fn release(&self, transport: Box<dyn Transport + Send>) {
82 if transport.is_connected() && self.idle_connections.read().await.len() < self.max_connections {
83 let mut idle = self.idle_connections.write().await;
84 idle.push_back(transport);
85 debug!("Returned connection to pool");
86 } else {
87 debug!("Dropping connection");
89 }
90 }
91}
92
/// Client that runs queries over pooled transports in one of the
/// [`ClientMode`] operating modes.
pub struct OptimizedClient {
    /// Mode chosen at construction; gates the interactive and batch entry points.
    mode: ClientMode,
    /// Transport pool, shared between clones of this client.
    pool: Arc<ConnectionPool>,
    /// Receiving end of the interactive message channel; `None` while no
    /// interactive session is active.
    message_rx: Arc<RwLock<Option<mpsc::Receiver<Message>>>>,
    /// Transport used by the active interactive session; `None` otherwise.
    current_transport: Arc<RwLock<Option<Box<dyn Transport + Send>>>>,
    /// Receives the token and cost figures reported by `Message::Result`.
    budget_manager: BudgetManager,
}
106
107impl OptimizedClient {
108 pub fn new(options: ClaudeCodeOptions, mode: ClientMode) -> Result<Self> {
110 unsafe {
111 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
112 }
113
114 let max_connections = match mode {
115 ClientMode::Batch { max_concurrent } => max_concurrent,
116 _ => 1,
117 };
118
119 let pool = Arc::new(ConnectionPool::new(options, max_connections));
120
121 Ok(Self {
122 mode,
123 pool,
124 message_rx: Arc::new(RwLock::new(None)),
125 current_transport: Arc::new(RwLock::new(None)),
126 budget_manager: BudgetManager::new(),
127 })
128 }
129
130 pub async fn query(&self, prompt: String) -> Result<Vec<Message>> {
132 self.query_with_retry(prompt, 3, Duration::from_millis(100))
133 .await
134 }
135
136 pub async fn query_with_retry(
138 &self,
139 prompt: String,
140 max_retries: u32,
141 initial_delay: Duration,
142 ) -> Result<Vec<Message>> {
143 let mut retries = 0;
144 let mut delay = initial_delay;
145
146 loop {
147 match self.execute_query(&prompt).await {
148 Ok(messages) => return Ok(messages),
149 Err(e) if retries < max_retries => {
150 warn!("Query failed, retrying in {:?}: {}", delay, e);
151 tokio::time::sleep(delay).await;
152 retries += 1;
153 delay *= 2; }
155 Err(e) => return Err(e),
156 }
157 }
158 }
159
160 async fn execute_query(&self, prompt: &str) -> Result<Vec<Message>> {
162 let mut transport = self.pool.acquire().await?;
163
164 let message = InputMessage::user(prompt.to_string(), "default".to_string());
166 transport.send_message(message).await?;
167
168 let timeout_duration = Duration::from_secs(120);
170 let messages = timeout(timeout_duration, self.collect_messages(&mut *transport))
171 .await
172 .map_err(|_| SdkError::Timeout { seconds: 120 })??;
173
174 self.pool.release(transport).await;
176
177 Ok(messages)
178 }
179
180 async fn collect_messages<T: Transport + Send + ?Sized>(&self, transport: &mut T) -> Result<Vec<Message>> {
182 let mut messages = Vec::new();
183 let mut stream = transport.receive_messages();
184
185 while let Some(result) = stream.next().await {
186 match result {
187 Ok(msg) => {
188 debug!("Received: {:?}", msg);
189 let is_result = matches!(msg, Message::Result { .. });
190
191 if let Message::Result { usage, total_cost_usd, .. } = &msg {
193 let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
194 let input = usage_json
195 .get("input_tokens")
196 .and_then(|v| v.as_u64())
197 .unwrap_or(0);
198 let output = usage_json
199 .get("output_tokens")
200 .and_then(|v| v.as_u64())
201 .unwrap_or(0);
202 (input, output)
203 } else {
204 (0, 0)
205 };
206 let cost = total_cost_usd.unwrap_or(0.0);
207 self.budget_manager
208 .update_usage(input_tokens, output_tokens, cost)
209 .await;
210 }
211 messages.push(msg);
212 if is_result {
213 break;
214 }
215 }
216 Err(e) => return Err(e),
217 }
218 }
219
220 Ok(messages)
221 }
222
223 pub async fn get_usage_stats(&self) -> TokenUsageTracker {
225 self.budget_manager.get_usage().await
226 }
227
228 pub async fn set_budget_limit(
242 &self,
243 limit: BudgetLimit,
244 on_warning: Option<BudgetWarningCallback>,
245 ) {
246 self.budget_manager.set_limit(limit).await;
247 if let Some(cb) = on_warning {
248 self.budget_manager.set_warning_callback(cb).await;
249 }
250 }
251
252 pub async fn clear_budget_limit(&self) {
254 self.budget_manager.clear_limit().await;
255 }
256
257 pub async fn reset_usage_stats(&self) {
259 self.budget_manager.reset_usage().await;
260 }
261
262 pub async fn is_budget_exceeded(&self) -> bool {
264 self.budget_manager.is_exceeded().await
265 }
266
267 pub async fn start_interactive_session(&self) -> Result<()> {
269 if !matches!(self.mode, ClientMode::Interactive) {
270 return Err(SdkError::InvalidState {
271 message: "Client not in interactive mode".into(),
272 });
273 }
274
275 let transport = self.pool.acquire().await?;
277
278 let (tx, rx) = mpsc::channel::<Message>(100);
280
281 *self.current_transport.write().await = Some(transport);
283 *self.message_rx.write().await = Some(rx);
284
285 self.start_message_processor(tx).await;
287
288 info!("Interactive session started");
289 Ok(())
290 }
291
292 async fn start_message_processor(&self, tx: mpsc::Sender<Message>) {
294 let transport_ref = self.current_transport.clone();
295
296 tokio::spawn(async move {
297 loop {
298 let msg_result = {
300 let mut transport_guard = transport_ref.write().await;
301 if let Some(transport) = transport_guard.as_mut() {
302 let mut stream = transport.receive_messages();
303 stream.next().await
304 } else {
305 break;
306 }
307 };
308
309 if let Some(result) = msg_result {
311 match result {
312 Ok(msg) => {
313 if tx.send(msg).await.is_err() {
314 error!("Failed to send message to channel");
315 break;
316 }
317 }
318 Err(e) => {
319 error!("Error receiving message: {}", e);
320 break;
321 }
322 }
323 }
324 }
325 });
326 }
327
328 pub async fn send_interactive(&self, prompt: String) -> Result<()> {
330 let transport_guard = self.current_transport.read().await;
331 if let Some(_transport) = transport_guard.as_ref() {
332 drop(transport_guard);
334
335 let mut transport_guard = self.current_transport.write().await;
336 if let Some(transport) = transport_guard.as_mut() {
337 let message = InputMessage::user(prompt, "default".to_string());
338 transport.send_message(message).await?;
339 } else {
340 return Err(SdkError::InvalidState {
341 message: "Transport lost during operation".into(),
342 });
343 }
344 Ok(())
345 } else {
346 Err(SdkError::InvalidState {
347 message: "No active interactive session".into(),
348 })
349 }
350 }
351
352 pub async fn receive_interactive(&self) -> Result<Vec<Message>> {
354 let mut rx_guard = self.message_rx.write().await;
355 if let Some(rx) = rx_guard.as_mut() {
356 let mut messages = Vec::new();
357
358 while let Some(msg) = rx.recv().await {
360 let is_result = matches!(msg, Message::Result { .. });
361 messages.push(msg);
362 if is_result {
363 break;
364 }
365 }
366
367 Ok(messages)
368 } else {
369 Err(SdkError::InvalidState {
370 message: "No active interactive session".into(),
371 })
372 }
373 }
374
375 pub async fn process_batch(&self, prompts: Vec<String>) -> Result<Vec<Result<Vec<Message>>>> {
377 let max_concurrent = match self.mode {
378 ClientMode::Batch { max_concurrent } => max_concurrent,
379 _ => {
380 return Err(SdkError::InvalidState {
381 message: "Client not in batch mode".into(),
382 });
383 }
384 };
385
386 let semaphore = Arc::new(Semaphore::new(max_concurrent));
387 let mut handles = Vec::new();
388
389 for prompt in prompts {
390 let permit = semaphore.clone().acquire_owned().await.unwrap();
391 let client = self.clone(); let handle = tokio::spawn(async move {
394 let result = client.query(prompt).await;
395 drop(permit);
396 result
397 });
398
399 handles.push(handle);
400 }
401
402 let mut results = Vec::new();
404 for handle in handles {
405 match handle.await {
406 Ok(result) => results.push(result),
407 Err(e) => {
408 results.push(Err(SdkError::TransportError(format!("Task failed: {e}"))))
409 }
410 }
411 }
412
413 Ok(results)
414 }
415
416 pub async fn interrupt(&self) -> Result<()> {
418 let transport_guard = self.current_transport.read().await;
419 if let Some(_transport) = transport_guard.as_ref() {
420 drop(transport_guard);
421
422 let mut transport_guard = self.current_transport.write().await;
423 if let Some(transport) = transport_guard.as_mut() {
424 let request = ControlRequest::Interrupt {
425 request_id: uuid::Uuid::new_v4().to_string(),
426 };
427 transport.send_control_request(request).await?;
428 } else {
429 return Err(SdkError::InvalidState {
430 message: "Transport lost during operation".into(),
431 });
432 }
433 info!("Interrupt sent");
434 Ok(())
435 } else {
436 Err(SdkError::InvalidState {
437 message: "No active session".into(),
438 })
439 }
440 }
441
442 pub async fn end_interactive_session(&self) -> Result<()> {
444 if let Some(transport) = self.current_transport.write().await.take() {
446 self.pool.release(transport).await;
447 }
448
449 *self.message_rx.write().await = None;
451
452 info!("Interactive session ended");
453 Ok(())
454 }
455}
456
457impl Clone for OptimizedClient {
459 fn clone(&self) -> Self {
460 Self {
461 mode: self.mode,
462 pool: self.pool.clone(),
463 message_rx: Arc::new(RwLock::new(None)),
464 current_transport: Arc::new(RwLock::new(None)),
465 budget_manager: self.budget_manager.clone(),
466 }
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor accepts every client mode.
    #[test]
    fn test_client_mode_creation() {
        let options = ClaudeCodeOptions::builder().build();
        let modes = [
            ClientMode::OneShot,
            ClientMode::Interactive,
            ClientMode::Batch { max_concurrent: 5 },
        ];

        for mode in modes {
            let client = OptimizedClient::new(options.clone(), mode);
            assert!(client.is_ok());
        }
    }

    /// A new pool remembers the connection limit it was given.
    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new(ClaudeCodeOptions::builder().build(), 10);

        assert_eq!(pool.max_connections, 10);
    }

    /// Cloning a client keeps its mode.
    #[tokio::test]
    async fn test_client_cloning() {
        let options = ClaudeCodeOptions::builder().build();
        let client = OptimizedClient::new(options, ClientMode::OneShot).unwrap();

        let cloned = client.clone();

        assert!(
            matches!(
                (client.mode, cloned.mode),
                (ClientMode::OneShot, ClientMode::OneShot)
            ),
            "Mode not preserved during cloning"
        );
    }
}