Skip to main content

cc_sdk/
optimized_client.rs

1//! Optimized client implementation with performance improvements
2
3use crate::{
4    errors::{Result, SdkError},
5    transport::{InputMessage, SubprocessTransport, Transport},
6    types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use crate::token_tracker::{BudgetLimit, BudgetManager, BudgetWarningCallback, TokenUsageTracker};
9use futures::stream::StreamExt;
10use std::collections::VecDeque;
11use std::sync::Arc;
12use tokio::sync::{RwLock, Semaphore, mpsc};
13use tokio::time::{Duration, timeout};
14use tracing::{debug, error, info, warn};
15
/// Usage pattern an `OptimizedClient` is built for.
///
/// The mode is fixed at construction and decides which operations the client accepts and
/// how large its connection pool is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ClientMode {
    /// One-shot query mode (stateless); uses a single pooled connection.
    OneShot,
    /// Interactive mode (stateful conversations); required to start an interactive session.
    Interactive,
    /// Batch processing mode; sizes the connection pool and batch concurrency.
    Batch {
        /// Maximum number of concurrent requests
        max_concurrent: usize,
    },
}
29
30/// Connection pool for reusing subprocess transports
31struct ConnectionPool {
32    /// Available idle connections
33    idle_connections: Arc<RwLock<VecDeque<Box<dyn Transport + Send>>>>,
34    /// Maximum number of connections
35    max_connections: usize,
36    /// Semaphore for limiting concurrent connections
37    connection_semaphore: Arc<Semaphore>,
38    /// Base options for creating new connections
39    base_options: ClaudeCodeOptions,
40}
41
42impl ConnectionPool {
43    fn new(base_options: ClaudeCodeOptions, max_connections: usize) -> Self {
44        Self {
45            idle_connections: Arc::new(RwLock::new(VecDeque::new())),
46            max_connections,
47            connection_semaphore: Arc::new(Semaphore::new(max_connections)),
48            base_options,
49        }
50    }
51
52    async fn acquire(&self) -> Result<Box<dyn Transport + Send>> {
53        // Try to get an idle connection first
54        {
55            let mut idle = self.idle_connections.write().await;
56            if let Some(transport) = idle.pop_front() {
57                // Verify connection is still valid
58                if transport.is_connected() {
59                    debug!("Reusing existing connection from pool");
60                    return Ok(transport);
61                }
62            }
63        }
64
65        // Create new connection if under limit
66        let _permit =
67            self.connection_semaphore
68                .acquire()
69                .await
70                .map_err(|_| SdkError::InvalidState {
71                    message: "Failed to acquire connection permit".into(),
72                })?;
73
74        let mut transport: Box<dyn Transport + Send> =
75            Box::new(SubprocessTransport::new(self.base_options.clone())?);
76        transport.connect().await?;
77        debug!("Created new connection");
78        Ok(transport)
79    }
80
81    async fn release(&self, transport: Box<dyn Transport + Send>) {
82        if transport.is_connected() && self.idle_connections.read().await.len() < self.max_connections {
83            let mut idle = self.idle_connections.write().await;
84            idle.push_back(transport);
85            debug!("Returned connection to pool");
86        } else {
87            // Connection is invalid or pool is full, let it drop
88            debug!("Dropping connection");
89        }
90    }
91}
92
93/// Optimized client with improved performance characteristics
94pub struct OptimizedClient {
95    /// Client mode
96    mode: ClientMode,
97    /// Connection pool
98    pool: Arc<ConnectionPool>,
99    /// Message receiver for interactive mode
100    message_rx: Arc<RwLock<Option<mpsc::Receiver<Message>>>>,
101    /// Current transport for interactive mode
102    current_transport: Arc<RwLock<Option<Box<dyn Transport + Send>>>>,
103    /// Budget manager for token/cost tracking
104    budget_manager: BudgetManager,
105}
106
107impl OptimizedClient {
108    /// Create a new optimized client
109    pub fn new(options: ClaudeCodeOptions, mode: ClientMode) -> Result<Self> {
110        let max_connections = match mode {
111            ClientMode::Batch { max_concurrent } => max_concurrent,
112            _ => 1,
113        };
114
115        let pool = Arc::new(ConnectionPool::new(options, max_connections));
116
117        Ok(Self {
118            mode,
119            pool,
120            message_rx: Arc::new(RwLock::new(None)),
121            current_transport: Arc::new(RwLock::new(None)),
122            budget_manager: BudgetManager::new(),
123        })
124    }
125
126    /// Execute a one-shot query with automatic retry
127    pub async fn query(&self, prompt: String) -> Result<Vec<Message>> {
128        self.query_with_retry(prompt, 3, Duration::from_millis(100))
129            .await
130    }
131
132    /// Execute a query with custom retry configuration
133    pub async fn query_with_retry(
134        &self,
135        prompt: String,
136        max_retries: u32,
137        initial_delay: Duration,
138    ) -> Result<Vec<Message>> {
139        let mut retries = 0;
140        let mut delay = initial_delay;
141
142        loop {
143            match self.execute_query(&prompt).await {
144                Ok(messages) => return Ok(messages),
145                Err(e) if retries < max_retries => {
146                    warn!("Query failed, retrying in {:?}: {}", delay, e);
147                    tokio::time::sleep(delay).await;
148                    retries += 1;
149                    delay *= 2; // Exponential backoff
150                }
151                Err(e) => return Err(e),
152            }
153        }
154    }
155
156    /// Internal query execution
157    async fn execute_query(&self, prompt: &str) -> Result<Vec<Message>> {
158        let mut transport = self.pool.acquire().await?;
159
160        // Send message
161        let message = InputMessage::user(prompt.to_string(), "default".to_string());
162        transport.send_message(message).await?;
163
164        // Collect response with timeout
165        let timeout_duration = Duration::from_secs(120);
166        let messages = timeout(timeout_duration, self.collect_messages(&mut *transport))
167            .await
168            .map_err(|_| SdkError::Timeout { seconds: 120 })??;
169
170        // Return transport to pool
171        self.pool.release(transport).await;
172
173        Ok(messages)
174    }
175
176    /// Collect messages until Result message
177    async fn collect_messages<T: Transport + Send + ?Sized>(&self, transport: &mut T) -> Result<Vec<Message>> {
178        let mut messages = Vec::new();
179        let mut stream = transport.receive_messages();
180
181        while let Some(result) = stream.next().await {
182            match result {
183                Ok(msg) => {
184                    debug!("Received: {:?}", msg);
185                    let is_result = matches!(msg, Message::Result { .. });
186
187                    // Update budget/usage on result messages
188                    if let Message::Result { usage, total_cost_usd, .. } = &msg {
189                        let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
190                            let input = usage_json
191                                .get("input_tokens")
192                                .and_then(|v| v.as_u64())
193                                .unwrap_or(0);
194                            let output = usage_json
195                                .get("output_tokens")
196                                .and_then(|v| v.as_u64())
197                                .unwrap_or(0);
198                            (input, output)
199                        } else {
200                            (0, 0)
201                        };
202                        let cost = total_cost_usd.unwrap_or(0.0);
203                        self.budget_manager
204                            .update_usage(input_tokens, output_tokens, cost)
205                            .await;
206                    }
207                    messages.push(msg);
208                    if is_result {
209                        break;
210                    }
211                }
212                Err(e) => return Err(e),
213            }
214        }
215
216        Ok(messages)
217    }
218
219    /// Get token/cost usage statistics
220    pub async fn get_usage_stats(&self) -> TokenUsageTracker {
221        self.budget_manager.get_usage().await
222    }
223
224    /// Set budget limit with optional warning callback
225    ///
226    /// Example:
227    /// ```rust,no_run
228    /// use cc_sdk::{OptimizedClient, ClaudeCodeOptions, ClientMode};
229    /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
230    /// use std::sync::Arc;
231    /// # async fn demo() -> cc_sdk::Result<()> {
232    /// let client = OptimizedClient::new(ClaudeCodeOptions::default(), ClientMode::OneShot)?;
233    /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Warn: {}", msg));
234    /// client.set_budget_limit(BudgetLimit::with_cost(1.0), Some(cb)).await;
235    /// # Ok(()) }
236    /// ```
237    pub async fn set_budget_limit(
238        &self,
239        limit: BudgetLimit,
240        on_warning: Option<BudgetWarningCallback>,
241    ) {
242        self.budget_manager.set_limit(limit).await;
243        if let Some(cb) = on_warning {
244            self.budget_manager.set_warning_callback(cb).await;
245        }
246    }
247
248    /// Clear budget limit and reset warning state
249    pub async fn clear_budget_limit(&self) {
250        self.budget_manager.clear_limit().await;
251    }
252
253    /// Reset usage statistics to zero
254    pub async fn reset_usage_stats(&self) {
255        self.budget_manager.reset_usage().await;
256    }
257
258    /// Check whether budget is exceeded
259    pub async fn is_budget_exceeded(&self) -> bool {
260        self.budget_manager.is_exceeded().await
261    }
262
263    /// Start an interactive session
264    pub async fn start_interactive_session(&self) -> Result<()> {
265        if !matches!(self.mode, ClientMode::Interactive) {
266            return Err(SdkError::InvalidState {
267                message: "Client not in interactive mode".into(),
268            });
269        }
270
271        // Acquire a transport for the session
272        let transport = self.pool.acquire().await?;
273
274        // Create message channel
275        let (tx, rx) = mpsc::channel::<Message>(100);
276
277        // Store transport and receiver
278        *self.current_transport.write().await = Some(transport);
279        *self.message_rx.write().await = Some(rx);
280
281        // Start background message processor
282        self.start_message_processor(tx).await;
283
284        info!("Interactive session started");
285        Ok(())
286    }
287
288    /// Start background task to process messages
289    async fn start_message_processor(&self, tx: mpsc::Sender<Message>) {
290        let transport_ref = self.current_transport.clone();
291
292        tokio::spawn(async move {
293            loop {
294                // Get message from transport
295                let msg_result = {
296                    let mut transport_guard = transport_ref.write().await;
297                    if let Some(transport) = transport_guard.as_mut() {
298                        let mut stream = transport.receive_messages();
299                        stream.next().await
300                    } else {
301                        break;
302                    }
303                };
304
305                // Process message
306                if let Some(result) = msg_result {
307                    match result {
308                        Ok(msg) => {
309                            if tx.send(msg).await.is_err() {
310                                error!("Failed to send message to channel");
311                                break;
312                            }
313                        }
314                        Err(e) => {
315                            error!("Error receiving message: {}", e);
316                            break;
317                        }
318                    }
319                }
320            }
321        });
322    }
323
324    /// Send a message in interactive mode
325    pub async fn send_interactive(&self, prompt: String) -> Result<()> {
326        let transport_guard = self.current_transport.read().await;
327        if let Some(_transport) = transport_guard.as_ref() {
328            // Need to handle transport mutability properly
329            drop(transport_guard);
330
331            let mut transport_guard = self.current_transport.write().await;
332            if let Some(transport) = transport_guard.as_mut() {
333                let message = InputMessage::user(prompt, "default".to_string());
334                transport.send_message(message).await?;
335            } else {
336                return Err(SdkError::InvalidState {
337                    message: "Transport lost during operation".into(),
338                });
339            }
340            Ok(())
341        } else {
342            Err(SdkError::InvalidState {
343                message: "No active interactive session".into(),
344            })
345        }
346    }
347
348    /// Receive messages in interactive mode
349    pub async fn receive_interactive(&self) -> Result<Vec<Message>> {
350        let mut rx_guard = self.message_rx.write().await;
351        if let Some(rx) = rx_guard.as_mut() {
352            let mut messages = Vec::new();
353
354            // Collect messages until Result
355            while let Some(msg) = rx.recv().await {
356                let is_result = matches!(msg, Message::Result { .. });
357                messages.push(msg);
358                if is_result {
359                    break;
360                }
361            }
362
363            Ok(messages)
364        } else {
365            Err(SdkError::InvalidState {
366                message: "No active interactive session".into(),
367            })
368        }
369    }
370
371    /// Process a batch of queries concurrently
372    pub async fn process_batch(&self, prompts: Vec<String>) -> Result<Vec<Result<Vec<Message>>>> {
373        let max_concurrent = match self.mode {
374            ClientMode::Batch { max_concurrent } => max_concurrent,
375            _ => {
376                return Err(SdkError::InvalidState {
377                    message: "Client not in batch mode".into(),
378                });
379            }
380        };
381
382        let semaphore = Arc::new(Semaphore::new(max_concurrent));
383        let mut handles = Vec::new();
384
385        for prompt in prompts {
386            let permit = semaphore.clone().acquire_owned().await.unwrap();
387            let client = self.clone(); // Assume client is cloneable
388
389            let handle = tokio::spawn(async move {
390                let result = client.query(prompt).await;
391                drop(permit);
392                result
393            });
394
395            handles.push(handle);
396        }
397
398        // Collect results
399        let mut results = Vec::new();
400        for handle in handles {
401            match handle.await {
402                Ok(result) => results.push(result),
403                Err(e) => {
404                    results.push(Err(SdkError::TransportError(format!("Task failed: {e}"))))
405                }
406            }
407        }
408
409        Ok(results)
410    }
411
412    /// Send interrupt signal
413    pub async fn interrupt(&self) -> Result<()> {
414        let transport_guard = self.current_transport.read().await;
415        if let Some(_transport) = transport_guard.as_ref() {
416            drop(transport_guard);
417
418            let mut transport_guard = self.current_transport.write().await;
419            if let Some(transport) = transport_guard.as_mut() {
420                let request = ControlRequest::Interrupt {
421                    request_id: uuid::Uuid::new_v4().to_string(),
422                };
423                transport.send_control_request(request).await?;
424            } else {
425                return Err(SdkError::InvalidState {
426                    message: "Transport lost during operation".into(),
427                });
428            }
429            info!("Interrupt sent");
430            Ok(())
431        } else {
432            Err(SdkError::InvalidState {
433                message: "No active session".into(),
434            })
435        }
436    }
437
438    /// End interactive session
439    pub async fn end_interactive_session(&self) -> Result<()> {
440        // Clear current transport
441        if let Some(transport) = self.current_transport.write().await.take() {
442            self.pool.release(transport).await;
443        }
444
445        // Clear message receiver
446        *self.message_rx.write().await = None;
447
448        info!("Interactive session ended");
449        Ok(())
450    }
451}
452
453// Implement Clone if needed (this is a simplified version)
454impl Clone for OptimizedClient {
455    fn clone(&self) -> Self {
456        Self {
457            mode: self.mode,
458            pool: self.pool.clone(),
459            message_rx: Arc::new(RwLock::new(None)),
460            current_transport: Arc::new(RwLock::new(None)),
461            budget_manager: self.budget_manager.clone(),
462        }
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_mode_creation() {
        let options = ClaudeCodeOptions::builder().build();

        // Every mode must construct successfully from default options.
        let modes = [
            ClientMode::OneShot,
            ClientMode::Interactive,
            ClientMode::Batch { max_concurrent: 5 },
        ];
        for mode in modes {
            let created = OptimizedClient::new(options.clone(), mode);
            assert!(created.is_ok(), "construction failed for {mode:?}");
        }
    }

    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new(ClaudeCodeOptions::builder().build(), 10);
        assert_eq!(pool.max_connections, 10);
    }

    #[tokio::test]
    async fn test_client_cloning() {
        let original =
            OptimizedClient::new(ClaudeCodeOptions::builder().build(), ClientMode::OneShot)
                .unwrap();
        let copy = original.clone();

        // The clone must keep the original's mode.
        assert!(
            matches!(
                (original.mode, copy.mode),
                (ClientMode::OneShot, ClientMode::OneShot)
            ),
            "Mode not preserved during cloning"
        );
    }
}