cc_sdk/
optimized_client.rs

1//! Optimized client implementation with performance improvements
2
3use crate::{
4    errors::{Result, SdkError},
5    transport::{InputMessage, SubprocessTransport, Transport},
6    types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use crate::token_tracker::{BudgetLimit, BudgetManager, BudgetWarningCallback, TokenUsageTracker};
9use futures::stream::StreamExt;
10use std::collections::VecDeque;
11use std::sync::Arc;
12use tokio::sync::{RwLock, Semaphore, mpsc};
13use tokio::time::{Duration, timeout};
14use tracing::{debug, error, info, warn};
15
/// Usage pattern a client is configured for.
///
/// The mode is chosen at construction time. Interactive-session and batch
/// operations check it and refuse to run in any other mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientMode {
    /// One-shot query mode (stateless).
    OneShot,
    /// Interactive mode (stateful conversations).
    Interactive,
    /// Batch processing mode.
    Batch {
        /// Maximum number of concurrent requests.
        max_concurrent: usize,
    },
}
29
30/// Connection pool for reusing subprocess transports
31struct ConnectionPool {
32    /// Available idle connections
33    idle_connections: Arc<RwLock<VecDeque<Box<dyn Transport + Send>>>>,
34    /// Maximum number of connections
35    max_connections: usize,
36    /// Semaphore for limiting concurrent connections
37    connection_semaphore: Arc<Semaphore>,
38    /// Base options for creating new connections
39    base_options: ClaudeCodeOptions,
40}
41
42impl ConnectionPool {
43    fn new(base_options: ClaudeCodeOptions, max_connections: usize) -> Self {
44        Self {
45            idle_connections: Arc::new(RwLock::new(VecDeque::new())),
46            max_connections,
47            connection_semaphore: Arc::new(Semaphore::new(max_connections)),
48            base_options,
49        }
50    }
51
52    async fn acquire(&self) -> Result<Box<dyn Transport + Send>> {
53        // Try to get an idle connection first
54        {
55            let mut idle = self.idle_connections.write().await;
56            if let Some(transport) = idle.pop_front() {
57                // Verify connection is still valid
58                if transport.is_connected() {
59                    debug!("Reusing existing connection from pool");
60                    return Ok(transport);
61                }
62            }
63        }
64
65        // Create new connection if under limit
66        let _permit =
67            self.connection_semaphore
68                .acquire()
69                .await
70                .map_err(|_| SdkError::InvalidState {
71                    message: "Failed to acquire connection permit".into(),
72                })?;
73
74        let mut transport: Box<dyn Transport + Send> =
75            Box::new(SubprocessTransport::new(self.base_options.clone())?);
76        transport.connect().await?;
77        debug!("Created new connection");
78        Ok(transport)
79    }
80
81    async fn release(&self, transport: Box<dyn Transport + Send>) {
82        if transport.is_connected() && self.idle_connections.read().await.len() < self.max_connections {
83            let mut idle = self.idle_connections.write().await;
84            idle.push_back(transport);
85            debug!("Returned connection to pool");
86        } else {
87            // Connection is invalid or pool is full, let it drop
88            debug!("Dropping connection");
89        }
90    }
91}
92
93/// Optimized client with improved performance characteristics
94pub struct OptimizedClient {
95    /// Client mode
96    mode: ClientMode,
97    /// Connection pool
98    pool: Arc<ConnectionPool>,
99    /// Message receiver for interactive mode
100    message_rx: Arc<RwLock<Option<mpsc::Receiver<Message>>>>,
101    /// Current transport for interactive mode
102    current_transport: Arc<RwLock<Option<Box<dyn Transport + Send>>>>,
103    /// Budget manager for token/cost tracking
104    budget_manager: BudgetManager,
105}
106
107impl OptimizedClient {
108    /// Create a new optimized client
109    pub fn new(options: ClaudeCodeOptions, mode: ClientMode) -> Result<Self> {
110        unsafe {
111            std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
112        }
113
114        let max_connections = match mode {
115            ClientMode::Batch { max_concurrent } => max_concurrent,
116            _ => 1,
117        };
118
119        let pool = Arc::new(ConnectionPool::new(options, max_connections));
120
121        Ok(Self {
122            mode,
123            pool,
124            message_rx: Arc::new(RwLock::new(None)),
125            current_transport: Arc::new(RwLock::new(None)),
126            budget_manager: BudgetManager::new(),
127        })
128    }
129
130    /// Execute a one-shot query with automatic retry
131    pub async fn query(&self, prompt: String) -> Result<Vec<Message>> {
132        self.query_with_retry(prompt, 3, Duration::from_millis(100))
133            .await
134    }
135
136    /// Execute a query with custom retry configuration
137    pub async fn query_with_retry(
138        &self,
139        prompt: String,
140        max_retries: u32,
141        initial_delay: Duration,
142    ) -> Result<Vec<Message>> {
143        let mut retries = 0;
144        let mut delay = initial_delay;
145
146        loop {
147            match self.execute_query(&prompt).await {
148                Ok(messages) => return Ok(messages),
149                Err(e) if retries < max_retries => {
150                    warn!("Query failed, retrying in {:?}: {}", delay, e);
151                    tokio::time::sleep(delay).await;
152                    retries += 1;
153                    delay *= 2; // Exponential backoff
154                }
155                Err(e) => return Err(e),
156            }
157        }
158    }
159
160    /// Internal query execution
161    async fn execute_query(&self, prompt: &str) -> Result<Vec<Message>> {
162        let mut transport = self.pool.acquire().await?;
163
164        // Send message
165        let message = InputMessage::user(prompt.to_string(), "default".to_string());
166        transport.send_message(message).await?;
167
168        // Collect response with timeout
169        let timeout_duration = Duration::from_secs(120);
170        let messages = timeout(timeout_duration, self.collect_messages(&mut *transport))
171            .await
172            .map_err(|_| SdkError::Timeout { seconds: 120 })??;
173
174        // Return transport to pool
175        self.pool.release(transport).await;
176
177        Ok(messages)
178    }
179
180    /// Collect messages until Result message
181    async fn collect_messages<T: Transport + Send + ?Sized>(&self, transport: &mut T) -> Result<Vec<Message>> {
182        let mut messages = Vec::new();
183        let mut stream = transport.receive_messages();
184
185        while let Some(result) = stream.next().await {
186            match result {
187                Ok(msg) => {
188                    debug!("Received: {:?}", msg);
189                    let is_result = matches!(msg, Message::Result { .. });
190
191                    // Update budget/usage on result messages
192                    if let Message::Result { usage, total_cost_usd, .. } = &msg {
193                        let (input_tokens, output_tokens) = if let Some(usage_json) = usage {
194                            let input = usage_json
195                                .get("input_tokens")
196                                .and_then(|v| v.as_u64())
197                                .unwrap_or(0);
198                            let output = usage_json
199                                .get("output_tokens")
200                                .and_then(|v| v.as_u64())
201                                .unwrap_or(0);
202                            (input, output)
203                        } else {
204                            (0, 0)
205                        };
206                        let cost = total_cost_usd.unwrap_or(0.0);
207                        self.budget_manager
208                            .update_usage(input_tokens, output_tokens, cost)
209                            .await;
210                    }
211                    messages.push(msg);
212                    if is_result {
213                        break;
214                    }
215                }
216                Err(e) => return Err(e),
217            }
218        }
219
220        Ok(messages)
221    }
222
223    /// Get token/cost usage statistics
224    pub async fn get_usage_stats(&self) -> TokenUsageTracker {
225        self.budget_manager.get_usage().await
226    }
227
228    /// Set budget limit with optional warning callback
229    ///
230    /// Example:
231    /// ```rust,no_run
232    /// use cc_sdk::{OptimizedClient, ClaudeCodeOptions, ClientMode};
233    /// use cc_sdk::token_tracker::{BudgetLimit, BudgetWarningCallback};
234    /// use std::sync::Arc;
235    /// # async fn demo() -> cc_sdk::Result<()> {
236    /// let client = OptimizedClient::new(ClaudeCodeOptions::default(), ClientMode::OneShot)?;
237    /// let cb: BudgetWarningCallback = Arc::new(|msg: &str| println!("Warn: {}", msg));
238    /// client.set_budget_limit(BudgetLimit::with_cost(1.0), Some(cb)).await;
239    /// # Ok(()) }
240    /// ```
241    pub async fn set_budget_limit(
242        &self,
243        limit: BudgetLimit,
244        on_warning: Option<BudgetWarningCallback>,
245    ) {
246        self.budget_manager.set_limit(limit).await;
247        if let Some(cb) = on_warning {
248            self.budget_manager.set_warning_callback(cb).await;
249        }
250    }
251
252    /// Clear budget limit and reset warning state
253    pub async fn clear_budget_limit(&self) {
254        self.budget_manager.clear_limit().await;
255    }
256
257    /// Reset usage statistics to zero
258    pub async fn reset_usage_stats(&self) {
259        self.budget_manager.reset_usage().await;
260    }
261
262    /// Check whether budget is exceeded
263    pub async fn is_budget_exceeded(&self) -> bool {
264        self.budget_manager.is_exceeded().await
265    }
266
267    /// Start an interactive session
268    pub async fn start_interactive_session(&self) -> Result<()> {
269        if !matches!(self.mode, ClientMode::Interactive) {
270            return Err(SdkError::InvalidState {
271                message: "Client not in interactive mode".into(),
272            });
273        }
274
275        // Acquire a transport for the session
276        let transport = self.pool.acquire().await?;
277
278        // Create message channel
279        let (tx, rx) = mpsc::channel::<Message>(100);
280
281        // Store transport and receiver
282        *self.current_transport.write().await = Some(transport);
283        *self.message_rx.write().await = Some(rx);
284
285        // Start background message processor
286        self.start_message_processor(tx).await;
287
288        info!("Interactive session started");
289        Ok(())
290    }
291
292    /// Start background task to process messages
293    async fn start_message_processor(&self, tx: mpsc::Sender<Message>) {
294        let transport_ref = self.current_transport.clone();
295
296        tokio::spawn(async move {
297            loop {
298                // Get message from transport
299                let msg_result = {
300                    let mut transport_guard = transport_ref.write().await;
301                    if let Some(transport) = transport_guard.as_mut() {
302                        let mut stream = transport.receive_messages();
303                        stream.next().await
304                    } else {
305                        break;
306                    }
307                };
308
309                // Process message
310                if let Some(result) = msg_result {
311                    match result {
312                        Ok(msg) => {
313                            if tx.send(msg).await.is_err() {
314                                error!("Failed to send message to channel");
315                                break;
316                            }
317                        }
318                        Err(e) => {
319                            error!("Error receiving message: {}", e);
320                            break;
321                        }
322                    }
323                }
324            }
325        });
326    }
327
328    /// Send a message in interactive mode
329    pub async fn send_interactive(&self, prompt: String) -> Result<()> {
330        let transport_guard = self.current_transport.read().await;
331        if let Some(_transport) = transport_guard.as_ref() {
332            // Need to handle transport mutability properly
333            drop(transport_guard);
334
335            let mut transport_guard = self.current_transport.write().await;
336            if let Some(transport) = transport_guard.as_mut() {
337                let message = InputMessage::user(prompt, "default".to_string());
338                transport.send_message(message).await?;
339            } else {
340                return Err(SdkError::InvalidState {
341                    message: "Transport lost during operation".into(),
342                });
343            }
344            Ok(())
345        } else {
346            Err(SdkError::InvalidState {
347                message: "No active interactive session".into(),
348            })
349        }
350    }
351
352    /// Receive messages in interactive mode
353    pub async fn receive_interactive(&self) -> Result<Vec<Message>> {
354        let mut rx_guard = self.message_rx.write().await;
355        if let Some(rx) = rx_guard.as_mut() {
356            let mut messages = Vec::new();
357
358            // Collect messages until Result
359            while let Some(msg) = rx.recv().await {
360                let is_result = matches!(msg, Message::Result { .. });
361                messages.push(msg);
362                if is_result {
363                    break;
364                }
365            }
366
367            Ok(messages)
368        } else {
369            Err(SdkError::InvalidState {
370                message: "No active interactive session".into(),
371            })
372        }
373    }
374
375    /// Process a batch of queries concurrently
376    pub async fn process_batch(&self, prompts: Vec<String>) -> Result<Vec<Result<Vec<Message>>>> {
377        let max_concurrent = match self.mode {
378            ClientMode::Batch { max_concurrent } => max_concurrent,
379            _ => {
380                return Err(SdkError::InvalidState {
381                    message: "Client not in batch mode".into(),
382                });
383            }
384        };
385
386        let semaphore = Arc::new(Semaphore::new(max_concurrent));
387        let mut handles = Vec::new();
388
389        for prompt in prompts {
390            let permit = semaphore.clone().acquire_owned().await.unwrap();
391            let client = self.clone(); // Assume client is cloneable
392
393            let handle = tokio::spawn(async move {
394                let result = client.query(prompt).await;
395                drop(permit);
396                result
397            });
398
399            handles.push(handle);
400        }
401
402        // Collect results
403        let mut results = Vec::new();
404        for handle in handles {
405            match handle.await {
406                Ok(result) => results.push(result),
407                Err(e) => {
408                    results.push(Err(SdkError::TransportError(format!("Task failed: {e}"))))
409                }
410            }
411        }
412
413        Ok(results)
414    }
415
416    /// Send interrupt signal
417    pub async fn interrupt(&self) -> Result<()> {
418        let transport_guard = self.current_transport.read().await;
419        if let Some(_transport) = transport_guard.as_ref() {
420            drop(transport_guard);
421
422            let mut transport_guard = self.current_transport.write().await;
423            if let Some(transport) = transport_guard.as_mut() {
424                let request = ControlRequest::Interrupt {
425                    request_id: uuid::Uuid::new_v4().to_string(),
426                };
427                transport.send_control_request(request).await?;
428            } else {
429                return Err(SdkError::InvalidState {
430                    message: "Transport lost during operation".into(),
431                });
432            }
433            info!("Interrupt sent");
434            Ok(())
435        } else {
436            Err(SdkError::InvalidState {
437                message: "No active session".into(),
438            })
439        }
440    }
441
442    /// End interactive session
443    pub async fn end_interactive_session(&self) -> Result<()> {
444        // Clear current transport
445        if let Some(transport) = self.current_transport.write().await.take() {
446            self.pool.release(transport).await;
447        }
448
449        // Clear message receiver
450        *self.message_rx.write().await = None;
451
452        info!("Interactive session ended");
453        Ok(())
454    }
455}
456
457// Implement Clone if needed (this is a simplified version)
458impl Clone for OptimizedClient {
459    fn clone(&self) -> Self {
460        Self {
461            mode: self.mode,
462            pool: self.pool.clone(),
463            message_rx: Arc::new(RwLock::new(None)),
464            current_transport: Arc::new(RwLock::new(None)),
465            budget_manager: self.budget_manager.clone(),
466        }
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// Every client mode constructs successfully from default-built options.
    #[test]
    fn test_client_mode_creation() {
        let options = ClaudeCodeOptions::builder().build();
        let modes = [
            ClientMode::OneShot,
            ClientMode::Interactive,
            ClientMode::Batch { max_concurrent: 5 },
        ];

        for mode in modes {
            let client = OptimizedClient::new(options.clone(), mode);
            assert!(client.is_ok());
        }
    }

    /// The pool records the connection limit it was created with.
    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new(ClaudeCodeOptions::builder().build(), 10);

        assert_eq!(pool.max_connections, 10);
    }

    /// Cloning a client keeps its mode.
    #[tokio::test]
    async fn test_client_cloning() {
        let options = ClaudeCodeOptions::builder().build();
        let client = OptimizedClient::new(options, ClientMode::OneShot).unwrap();

        let cloned = client.clone();

        // Verify mode is preserved
        assert!(
            matches!(
                (client.mode, cloned.mode),
                (ClientMode::OneShot, ClientMode::OneShot)
            ),
            "Mode not preserved during cloning"
        );
    }
}