1use crate::{
4 errors::{Result, SdkError},
5 transport::{InputMessage, SubprocessTransport, Transport},
6 types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use crate::token_tracker::{BudgetLimit, BudgetManager, BudgetWarningCallback, TokenUsageTracker};
9use futures::stream::StreamExt;
10use std::collections::VecDeque;
11use std::sync::Arc;
12use tokio::sync::{RwLock, Semaphore, mpsc};
13use tokio::time::{Duration, timeout};
14use tracing::{debug, error, info, warn};
15
/// Selects how an `OptimizedClient` uses its pooled connections.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers (and tests) can compare modes
/// directly instead of destructuring them with `match`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientMode {
    /// Independent request/response queries; the pool holds a single connection.
    OneShot,
    /// A long-lived session driven by `start_interactive_session`,
    /// `send_interactive` and `receive_interactive`; the pool holds a single
    /// connection.
    Interactive,
    /// Many prompts processed concurrently by `process_batch`.
    Batch {
        /// Maximum number of prompts in flight, which is also the pool size.
        max_concurrent: usize,
    },
}
29
30struct ConnectionPool {
32 idle_connections: Arc<RwLock<VecDeque<Box<dyn Transport + Send>>>>,
34 max_connections: usize,
36 connection_semaphore: Arc<Semaphore>,
38 base_options: ClaudeCodeOptions,
40}
41
42impl ConnectionPool {
43 fn new(base_options: ClaudeCodeOptions, max_connections: usize) -> Self {
44 Self {
45 idle_connections: Arc::new(RwLock::new(VecDeque::new())),
46 max_connections,
47 connection_semaphore: Arc::new(Semaphore::new(max_connections)),
48 base_options,
49 }
50 }
51
52 async fn acquire(&self) -> Result<Box<dyn Transport + Send>> {
53 {
55 let mut idle = self.idle_connections.write().await;
56 if let Some(transport) = idle.pop_front() {
57 if transport.is_connected() {
59 debug!("Reusing existing connection from pool");
60 return Ok(transport);
61 }
62 }
63 }
64
65 let _permit =
67 self.connection_semaphore
68 .acquire()
69 .await
70 .map_err(|_| SdkError::InvalidState {
71 message: "Failed to acquire connection permit".into(),
72 })?;
73
74 let mut transport: Box<dyn Transport + Send> =
75 Box::new(SubprocessTransport::new(self.base_options.clone())?);
76 transport.connect().await?;
77 debug!("Created new connection");
78 Ok(transport)
79 }
80
81 async fn release(&self, transport: Box<dyn Transport + Send>) {
82 if transport.is_connected() && self.idle_connections.read().await.len() < self.max_connections {
83 let mut idle = self.idle_connections.write().await;
84 idle.push_back(transport);
85 debug!("Returned connection to pool");
86 } else {
87 debug!("Dropping connection");
89 }
90 }
91}
92
93pub struct OptimizedClient {
95 mode: ClientMode,
97 pool: Arc<ConnectionPool>,
99 message_rx: Arc<RwLock<Option<mpsc::Receiver<Message>>>>,
101 current_transport: Arc<RwLock<Option<Box<dyn Transport + Send>>>>,
103 budget_manager: BudgetManager,
105}
106
107impl OptimizedClient {
108 pub fn new(options: ClaudeCodeOptions, mode: ClientMode) -> Result<Self> {
110 let max_connections = match mode {
111 ClientMode::Batch { max_concurrent } => max_concurrent,
112 _ => 1,
113 };
114
115 let pool = Arc::new(ConnectionPool::new(options, max_connections));
116
117 Ok(Self {
118 mode,
119 pool,
120 message_rx: Arc::new(RwLock::new(None)),
121 current_transport: Arc::new(RwLock::new(None)),
122 budget_manager: BudgetManager::new(),
123 })
124 }
125
126 pub async fn query(&self, prompt: String) -> Result<Vec<Message>> {
128 self.query_with_retry(prompt, 3, Duration::from_millis(100))
129 .await
130 }
131
132 pub async fn query_with_retry(
134 &self,
135 prompt: String,
136 max_retries: u32,
137 initial_delay: Duration,
138 ) -> Result<Vec<Message>> {
139 let mut retries = 0;
140 let mut delay = initial_delay;
141
142 loop {
143 match self.execute_query(&prompt).await {
144 Ok(messages) => return Ok(messages),
145 Err(e) if retries < max_retries => {
146 warn!("Query failed, retrying in {:?}: {}", delay, e);
147 tokio::time::sleep(delay).await;
148 retries += 1;
149 delay *= 2; }
151 Err(e) => return Err(e),
152 }
153 }
154 }
155
156 async fn execute_query(&self, prompt: &str) -> Result<Vec<Message>> {
158 let mut transport = self.pool.acquire().await?;
159
160 let message = InputMessage::user(prompt.to_string(), "default".to_string());
162 transport.send_message(message).await?;
163
164 let timeout_duration = Duration::from_secs(120);
166 let messages = timeout(timeout_duration, self.collect_messages(&mut *transport))
167 .await
168 .map_err(|_| SdkError::Timeout { seconds: 120 })??;
169
170 self.pool.release(transport).await;
172
173 Ok(messages)
174 }
175
176 async fn collect_messages<T: Transport + Send + ?Sized>(&self, transport: &mut T) -> Result<Vec<Message>> {
178 let mut messages = Vec::new();
179 let mut stream = transport.receive_messages();
180
181 while let Some(result) = stream.next().await {
182 match result {
183 Ok(msg) => {
184 debug!("Received: {:?}", msg);
185 let is_result = matches!(msg, Message::Result { .. });
186
187 if let Message::Result { usage, total_cost_usd, .. } = &msg {
189 let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
190 let input = usage_json
191 .get("input_tokens")
192 .and_then(|v| v.as_u64())
193 .unwrap_or(0);
194 let output = usage_json
195 .get("output_tokens")
196 .and_then(|v| v.as_u64())
197 .unwrap_or(0);
198 (input, output)
199 } else {
200 (0, 0)
201 };
202 let cost = total_cost_usd.unwrap_or(0.0);
203 self.budget_manager
204 .update_usage(input_tokens, output_tokens, cost)
205 .await;
206 }
207 messages.push(msg);
208 if is_result {
209 break;
210 }
211 }
212 Err(e) => return Err(e),
213 }
214 }
215
216 Ok(messages)
217 }
218
219 pub async fn get_usage_stats(&self) -> TokenUsageTracker {
221 self.budget_manager.get_usage().await
222 }
223
224 pub async fn set_budget_limit(
238 &self,
239 limit: BudgetLimit,
240 on_warning: Option<BudgetWarningCallback>,
241 ) {
242 self.budget_manager.set_limit(limit).await;
243 if let Some(cb) = on_warning {
244 self.budget_manager.set_warning_callback(cb).await;
245 }
246 }
247
248 pub async fn clear_budget_limit(&self) {
250 self.budget_manager.clear_limit().await;
251 }
252
253 pub async fn reset_usage_stats(&self) {
255 self.budget_manager.reset_usage().await;
256 }
257
258 pub async fn is_budget_exceeded(&self) -> bool {
260 self.budget_manager.is_exceeded().await
261 }
262
263 pub async fn start_interactive_session(&self) -> Result<()> {
265 if !matches!(self.mode, ClientMode::Interactive) {
266 return Err(SdkError::InvalidState {
267 message: "Client not in interactive mode".into(),
268 });
269 }
270
271 let transport = self.pool.acquire().await?;
273
274 let (tx, rx) = mpsc::channel::<Message>(100);
276
277 *self.current_transport.write().await = Some(transport);
279 *self.message_rx.write().await = Some(rx);
280
281 self.start_message_processor(tx).await;
283
284 info!("Interactive session started");
285 Ok(())
286 }
287
288 async fn start_message_processor(&self, tx: mpsc::Sender<Message>) {
290 let transport_ref = self.current_transport.clone();
291
292 tokio::spawn(async move {
293 loop {
294 let msg_result = {
296 let mut transport_guard = transport_ref.write().await;
297 if let Some(transport) = transport_guard.as_mut() {
298 let mut stream = transport.receive_messages();
299 stream.next().await
300 } else {
301 break;
302 }
303 };
304
305 if let Some(result) = msg_result {
307 match result {
308 Ok(msg) => {
309 if tx.send(msg).await.is_err() {
310 error!("Failed to send message to channel");
311 break;
312 }
313 }
314 Err(e) => {
315 error!("Error receiving message: {}", e);
316 break;
317 }
318 }
319 }
320 }
321 });
322 }
323
324 pub async fn send_interactive(&self, prompt: String) -> Result<()> {
326 let transport_guard = self.current_transport.read().await;
327 if let Some(_transport) = transport_guard.as_ref() {
328 drop(transport_guard);
330
331 let mut transport_guard = self.current_transport.write().await;
332 if let Some(transport) = transport_guard.as_mut() {
333 let message = InputMessage::user(prompt, "default".to_string());
334 transport.send_message(message).await?;
335 } else {
336 return Err(SdkError::InvalidState {
337 message: "Transport lost during operation".into(),
338 });
339 }
340 Ok(())
341 } else {
342 Err(SdkError::InvalidState {
343 message: "No active interactive session".into(),
344 })
345 }
346 }
347
348 pub async fn receive_interactive(&self) -> Result<Vec<Message>> {
350 let mut rx_guard = self.message_rx.write().await;
351 if let Some(rx) = rx_guard.as_mut() {
352 let mut messages = Vec::new();
353
354 while let Some(msg) = rx.recv().await {
356 let is_result = matches!(msg, Message::Result { .. });
357 messages.push(msg);
358 if is_result {
359 break;
360 }
361 }
362
363 Ok(messages)
364 } else {
365 Err(SdkError::InvalidState {
366 message: "No active interactive session".into(),
367 })
368 }
369 }
370
371 pub async fn process_batch(&self, prompts: Vec<String>) -> Result<Vec<Result<Vec<Message>>>> {
373 let max_concurrent = match self.mode {
374 ClientMode::Batch { max_concurrent } => max_concurrent,
375 _ => {
376 return Err(SdkError::InvalidState {
377 message: "Client not in batch mode".into(),
378 });
379 }
380 };
381
382 let semaphore = Arc::new(Semaphore::new(max_concurrent));
383 let mut handles = Vec::new();
384
385 for prompt in prompts {
386 let permit = semaphore.clone().acquire_owned().await.unwrap();
387 let client = self.clone(); let handle = tokio::spawn(async move {
390 let result = client.query(prompt).await;
391 drop(permit);
392 result
393 });
394
395 handles.push(handle);
396 }
397
398 let mut results = Vec::new();
400 for handle in handles {
401 match handle.await {
402 Ok(result) => results.push(result),
403 Err(e) => {
404 results.push(Err(SdkError::TransportError(format!("Task failed: {e}"))))
405 }
406 }
407 }
408
409 Ok(results)
410 }
411
412 pub async fn interrupt(&self) -> Result<()> {
414 let transport_guard = self.current_transport.read().await;
415 if let Some(_transport) = transport_guard.as_ref() {
416 drop(transport_guard);
417
418 let mut transport_guard = self.current_transport.write().await;
419 if let Some(transport) = transport_guard.as_mut() {
420 let request = ControlRequest::Interrupt {
421 request_id: uuid::Uuid::new_v4().to_string(),
422 };
423 transport.send_control_request(request).await?;
424 } else {
425 return Err(SdkError::InvalidState {
426 message: "Transport lost during operation".into(),
427 });
428 }
429 info!("Interrupt sent");
430 Ok(())
431 } else {
432 Err(SdkError::InvalidState {
433 message: "No active session".into(),
434 })
435 }
436 }
437
438 pub async fn end_interactive_session(&self) -> Result<()> {
440 if let Some(transport) = self.current_transport.write().await.take() {
442 self.pool.release(transport).await;
443 }
444
445 *self.message_rx.write().await = None;
447
448 info!("Interactive session ended");
449 Ok(())
450 }
451}
452
453impl Clone for OptimizedClient {
455 fn clone(&self) -> Self {
456 Self {
457 mode: self.mode,
458 pool: self.pool.clone(),
459 message_rx: Arc::new(RwLock::new(None)),
460 current_transport: Arc::new(RwLock::new(None)),
461 budget_manager: self.budget_manager.clone(),
462 }
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported mode should construct successfully.
    #[test]
    fn test_client_mode_creation() {
        let options = ClaudeCodeOptions::builder().build();
        let modes = [
            ClientMode::OneShot,
            ClientMode::Interactive,
            ClientMode::Batch { max_concurrent: 5 },
        ];

        for mode in modes {
            let built = OptimizedClient::new(options.clone(), mode);
            assert!(built.is_ok());
        }
    }

    /// The pool remembers the capacity it was created with.
    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new(ClaudeCodeOptions::builder().build(), 10);
        assert_eq!(pool.max_connections, 10);
    }

    /// Cloning a client keeps its mode.
    #[tokio::test]
    async fn test_client_cloning() {
        let original =
            OptimizedClient::new(ClaudeCodeOptions::builder().build(), ClientMode::OneShot)
                .unwrap();
        let copy = original.clone();

        assert!(
            matches!(
                (original.mode, copy.mode),
                (ClientMode::OneShot, ClientMode::OneShot)
            ),
            "Mode not preserved during cloning"
        );
    }
}