cc_sdk/
optimized_client.rs

1//! Optimized client implementation with performance improvements
2
3use crate::{
4    errors::{Result, SdkError},
5    transport::{InputMessage, SubprocessTransport, Transport},
6    types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use futures::stream::StreamExt;
9use std::collections::VecDeque;
10use std::sync::Arc;
11use tokio::sync::{RwLock, Semaphore, mpsc};
12use tokio::time::{Duration, timeout};
13use tracing::{debug, error, info, warn};
14
/// Usage pattern for an `OptimizedClient`.
///
/// The mode is fixed when the client is constructed and decides which operations are
/// allowed (for example, `process_batch` requires [`ClientMode::Batch`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientMode {
    /// One-shot query mode (stateless).
    OneShot,
    /// Interactive mode (stateful conversations).
    Interactive,
    /// Batch processing mode.
    Batch {
        /// Maximum number of concurrent requests; must be at least 1, otherwise no
        /// request could ever obtain a concurrency permit.
        max_concurrent: usize,
    },
}
28
29/// Connection pool for reusing subprocess transports
30struct ConnectionPool {
31    /// Available idle connections
32    idle_connections: Arc<RwLock<VecDeque<SubprocessTransport>>>,
33    /// Maximum number of connections
34    max_connections: usize,
35    /// Semaphore for limiting concurrent connections
36    connection_semaphore: Arc<Semaphore>,
37    /// Base options for creating new connections
38    base_options: ClaudeCodeOptions,
39}
40
41impl ConnectionPool {
42    fn new(base_options: ClaudeCodeOptions, max_connections: usize) -> Self {
43        Self {
44            idle_connections: Arc::new(RwLock::new(VecDeque::new())),
45            max_connections,
46            connection_semaphore: Arc::new(Semaphore::new(max_connections)),
47            base_options,
48        }
49    }
50
51    async fn acquire(&self) -> Result<SubprocessTransport> {
52        // Try to get an idle connection first
53        {
54            let mut idle = self.idle_connections.write().await;
55            if let Some(transport) = idle.pop_front() {
56                // Verify connection is still valid
57                if transport.is_connected() {
58                    debug!("Reusing existing connection from pool");
59                    return Ok(transport);
60                }
61            }
62        }
63
64        // Create new connection if under limit
65        let _permit =
66            self.connection_semaphore
67                .acquire()
68                .await
69                .map_err(|_| SdkError::InvalidState {
70                    message: "Failed to acquire connection permit".into(),
71                })?;
72
73        let mut transport = SubprocessTransport::new(self.base_options.clone())?;
74        transport.connect().await?;
75        debug!("Created new connection");
76        Ok(transport)
77    }
78
79    async fn release(&self, transport: SubprocessTransport) {
80        if transport.is_connected()
81            && self.idle_connections.read().await.len() < self.max_connections
82        {
83            let mut idle = self.idle_connections.write().await;
84            idle.push_back(transport);
85            debug!("Returned connection to pool");
86        } else {
87            // Connection is invalid or pool is full, let it drop
88            debug!("Dropping connection");
89        }
90    }
91}
92
93/// Optimized client with improved performance characteristics
94pub struct OptimizedClient {
95    /// Client mode
96    mode: ClientMode,
97    /// Connection pool
98    pool: Arc<ConnectionPool>,
99    /// Message receiver for interactive mode
100    message_rx: Arc<RwLock<Option<mpsc::Receiver<Message>>>>,
101    /// Current transport for interactive mode
102    current_transport: Arc<RwLock<Option<SubprocessTransport>>>,
103}
104
105impl OptimizedClient {
106    /// Create a new optimized client
107    pub fn new(options: ClaudeCodeOptions, mode: ClientMode) -> Result<Self> {
108        unsafe {
109            std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
110        }
111
112        let max_connections = match mode {
113            ClientMode::Batch { max_concurrent } => max_concurrent,
114            _ => 1,
115        };
116
117        let pool = Arc::new(ConnectionPool::new(options, max_connections));
118
119        Ok(Self {
120            mode,
121            pool,
122            message_rx: Arc::new(RwLock::new(None)),
123            current_transport: Arc::new(RwLock::new(None)),
124        })
125    }
126
127    /// Execute a one-shot query with automatic retry
128    pub async fn query(&self, prompt: String) -> Result<Vec<Message>> {
129        self.query_with_retry(prompt, 3, Duration::from_millis(100))
130            .await
131    }
132
133    /// Execute a query with custom retry configuration
134    pub async fn query_with_retry(
135        &self,
136        prompt: String,
137        max_retries: u32,
138        initial_delay: Duration,
139    ) -> Result<Vec<Message>> {
140        let mut retries = 0;
141        let mut delay = initial_delay;
142
143        loop {
144            match self.execute_query(&prompt).await {
145                Ok(messages) => return Ok(messages),
146                Err(e) if retries < max_retries => {
147                    warn!("Query failed, retrying in {:?}: {}", delay, e);
148                    tokio::time::sleep(delay).await;
149                    retries += 1;
150                    delay *= 2; // Exponential backoff
151                }
152                Err(e) => return Err(e),
153            }
154        }
155    }
156
157    /// Internal query execution
158    async fn execute_query(&self, prompt: &str) -> Result<Vec<Message>> {
159        let mut transport = self.pool.acquire().await?;
160
161        // Send message
162        let message = InputMessage::user(prompt.to_string(), "default".to_string());
163        transport.send_message(message).await?;
164
165        // Collect response with timeout
166        let timeout_duration = Duration::from_secs(120);
167        let messages = timeout(timeout_duration, self.collect_messages(&mut transport))
168            .await
169            .map_err(|_| SdkError::Timeout { seconds: 120 })??;
170
171        // Return transport to pool
172        self.pool.release(transport).await;
173
174        Ok(messages)
175    }
176
177    /// Collect messages until Result message
178    async fn collect_messages(&self, transport: &mut SubprocessTransport) -> Result<Vec<Message>> {
179        let mut messages = Vec::new();
180        let mut stream = transport.receive_messages();
181
182        while let Some(result) = stream.next().await {
183            match result {
184                Ok(msg) => {
185                    debug!("Received: {:?}", msg);
186                    let is_result = matches!(msg, Message::Result { .. });
187                    messages.push(msg);
188                    if is_result {
189                        break;
190                    }
191                }
192                Err(e) => return Err(e),
193            }
194        }
195
196        Ok(messages)
197    }
198
199    /// Start an interactive session
200    pub async fn start_interactive_session(&self) -> Result<()> {
201        if !matches!(self.mode, ClientMode::Interactive) {
202            return Err(SdkError::InvalidState {
203                message: "Client not in interactive mode".into(),
204            });
205        }
206
207        // Acquire a transport for the session
208        let transport = self.pool.acquire().await?;
209
210        // Create message channel
211        let (tx, rx) = mpsc::channel::<Message>(100);
212
213        // Store transport and receiver
214        *self.current_transport.write().await = Some(transport);
215        *self.message_rx.write().await = Some(rx);
216
217        // Start background message processor
218        self.start_message_processor(tx).await;
219
220        info!("Interactive session started");
221        Ok(())
222    }
223
224    /// Start background task to process messages
225    async fn start_message_processor(&self, tx: mpsc::Sender<Message>) {
226        let transport_ref = self.current_transport.clone();
227
228        tokio::spawn(async move {
229            loop {
230                // Get message from transport
231                let msg_result = {
232                    let mut transport_guard = transport_ref.write().await;
233                    if let Some(transport) = transport_guard.as_mut() {
234                        let mut stream = transport.receive_messages();
235                        stream.next().await
236                    } else {
237                        break;
238                    }
239                };
240
241                // Process message
242                if let Some(result) = msg_result {
243                    match result {
244                        Ok(msg) => {
245                            if tx.send(msg).await.is_err() {
246                                error!("Failed to send message to channel");
247                                break;
248                            }
249                        }
250                        Err(e) => {
251                            error!("Error receiving message: {}", e);
252                            break;
253                        }
254                    }
255                }
256            }
257        });
258    }
259
260    /// Send a message in interactive mode
261    pub async fn send_interactive(&self, prompt: String) -> Result<()> {
262        let transport_guard = self.current_transport.read().await;
263        if let Some(_transport) = transport_guard.as_ref() {
264            // Need to handle transport mutability properly
265            drop(transport_guard);
266
267            let mut transport_guard = self.current_transport.write().await;
268            if let Some(transport) = transport_guard.as_mut() {
269                let message = InputMessage::user(prompt, "default".to_string());
270                transport.send_message(message).await?;
271            } else {
272                return Err(SdkError::InvalidState {
273                    message: "Transport lost during operation".into(),
274                });
275            }
276            Ok(())
277        } else {
278            Err(SdkError::InvalidState {
279                message: "No active interactive session".into(),
280            })
281        }
282    }
283
284    /// Receive messages in interactive mode
285    pub async fn receive_interactive(&self) -> Result<Vec<Message>> {
286        let mut rx_guard = self.message_rx.write().await;
287        if let Some(rx) = rx_guard.as_mut() {
288            let mut messages = Vec::new();
289
290            // Collect messages until Result
291            while let Some(msg) = rx.recv().await {
292                let is_result = matches!(msg, Message::Result { .. });
293                messages.push(msg);
294                if is_result {
295                    break;
296                }
297            }
298
299            Ok(messages)
300        } else {
301            Err(SdkError::InvalidState {
302                message: "No active interactive session".into(),
303            })
304        }
305    }
306
307    /// Process a batch of queries concurrently
308    pub async fn process_batch(&self, prompts: Vec<String>) -> Result<Vec<Result<Vec<Message>>>> {
309        let max_concurrent = match self.mode {
310            ClientMode::Batch { max_concurrent } => max_concurrent,
311            _ => {
312                return Err(SdkError::InvalidState {
313                    message: "Client not in batch mode".into(),
314                });
315            }
316        };
317
318        let semaphore = Arc::new(Semaphore::new(max_concurrent));
319        let mut handles = Vec::new();
320
321        for prompt in prompts {
322            let permit = semaphore.clone().acquire_owned().await.unwrap();
323            let client = self.clone(); // Assume client is cloneable
324
325            let handle = tokio::spawn(async move {
326                let result = client.query(prompt).await;
327                drop(permit);
328                result
329            });
330
331            handles.push(handle);
332        }
333
334        // Collect results
335        let mut results = Vec::new();
336        for handle in handles {
337            match handle.await {
338                Ok(result) => results.push(result),
339                Err(e) => {
340                    results.push(Err(SdkError::TransportError(format!("Task failed: {e}"))))
341                }
342            }
343        }
344
345        Ok(results)
346    }
347
348    /// Send interrupt signal
349    pub async fn interrupt(&self) -> Result<()> {
350        let transport_guard = self.current_transport.read().await;
351        if let Some(_transport) = transport_guard.as_ref() {
352            drop(transport_guard);
353
354            let mut transport_guard = self.current_transport.write().await;
355            if let Some(transport) = transport_guard.as_mut() {
356                let request = ControlRequest::Interrupt {
357                    request_id: uuid::Uuid::new_v4().to_string(),
358                };
359                transport.send_control_request(request).await?;
360            } else {
361                return Err(SdkError::InvalidState {
362                    message: "Transport lost during operation".into(),
363                });
364            }
365            info!("Interrupt sent");
366            Ok(())
367        } else {
368            Err(SdkError::InvalidState {
369                message: "No active session".into(),
370            })
371        }
372    }
373
374    /// End interactive session
375    pub async fn end_interactive_session(&self) -> Result<()> {
376        // Clear current transport
377        if let Some(transport) = self.current_transport.write().await.take() {
378            self.pool.release(transport).await;
379        }
380
381        // Clear message receiver
382        *self.message_rx.write().await = None;
383
384        info!("Interactive session ended");
385        Ok(())
386    }
387}
388
389// Implement Clone if needed (this is a simplified version)
390impl Clone for OptimizedClient {
391    fn clone(&self) -> Self {
392        Self {
393            mode: self.mode,
394            pool: self.pool.clone(),
395            message_rx: Arc::new(RwLock::new(None)),
396            current_transport: Arc::new(RwLock::new(None)),
397        }
398    }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_mode_creation() {
        let options = ClaudeCodeOptions::builder().build();
        let modes = [
            ClientMode::OneShot,
            ClientMode::Interactive,
            ClientMode::Batch { max_concurrent: 5 },
        ];

        // Every supported mode must construct successfully.
        for mode in modes {
            let created = OptimizedClient::new(options.clone(), mode);
            assert!(created.is_ok());
        }
    }

    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new(ClaudeCodeOptions::builder().build(), 10);
        assert_eq!(pool.max_connections, 10);
    }

    #[tokio::test]
    async fn test_client_cloning() {
        let original =
            OptimizedClient::new(ClaudeCodeOptions::builder().build(), ClientMode::OneShot)
                .unwrap();
        let copy = original.clone();

        // The mode must survive cloning.
        assert!(
            matches!(
                (original.mode, copy.mode),
                (ClientMode::OneShot, ClientMode::OneShot)
            ),
            "Mode not preserved during cloning"
        );
    }
}