1use crate::{
4 errors::{Result, SdkError},
5 transport::{InputMessage, SubprocessTransport, Transport},
6 types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use futures::stream::StreamExt;
9use std::collections::VecDeque;
10use std::sync::Arc;
11use tokio::sync::{RwLock, Semaphore, mpsc};
12use tokio::time::{Duration, timeout};
13use tracing::{debug, error, info, warn};
14
/// Execution strategy for an `OptimizedClient`, fixed at construction time.
///
/// The mode decides which mode-gated entry points are usable and how large the
/// client's connection pool is.
#[derive(Debug, Clone, Copy)]
pub enum ClientMode {
    /// Plain request/response use; the connection pool holds a single connection.
    OneShot,
    /// Enables the interactive-session methods; the pool holds a single connection.
    Interactive,
    /// Enables batch processing of many prompts at once.
    Batch {
        /// Maximum number of prompts processed concurrently; also the pool size.
        max_concurrent: usize,
    },
}
28
29struct ConnectionPool {
31 idle_connections: Arc<RwLock<VecDeque<SubprocessTransport>>>,
33 max_connections: usize,
35 connection_semaphore: Arc<Semaphore>,
37 base_options: ClaudeCodeOptions,
39}
40
41impl ConnectionPool {
42 fn new(base_options: ClaudeCodeOptions, max_connections: usize) -> Self {
43 Self {
44 idle_connections: Arc::new(RwLock::new(VecDeque::new())),
45 max_connections,
46 connection_semaphore: Arc::new(Semaphore::new(max_connections)),
47 base_options,
48 }
49 }
50
51 async fn acquire(&self) -> Result<SubprocessTransport> {
52 {
54 let mut idle = self.idle_connections.write().await;
55 if let Some(transport) = idle.pop_front() {
56 if transport.is_connected() {
58 debug!("Reusing existing connection from pool");
59 return Ok(transport);
60 }
61 }
62 }
63
64 let _permit =
66 self.connection_semaphore
67 .acquire()
68 .await
69 .map_err(|_| SdkError::InvalidState {
70 message: "Failed to acquire connection permit".into(),
71 })?;
72
73 let mut transport = SubprocessTransport::new(self.base_options.clone())?;
74 transport.connect().await?;
75 debug!("Created new connection");
76 Ok(transport)
77 }
78
79 async fn release(&self, transport: SubprocessTransport) {
80 if transport.is_connected()
81 && self.idle_connections.read().await.len() < self.max_connections
82 {
83 let mut idle = self.idle_connections.write().await;
84 idle.push_back(transport);
85 debug!("Returned connection to pool");
86 } else {
87 debug!("Dropping connection");
89 }
90 }
91}
92
93pub struct OptimizedClient {
95 mode: ClientMode,
97 pool: Arc<ConnectionPool>,
99 message_rx: Arc<RwLock<Option<mpsc::Receiver<Message>>>>,
101 current_transport: Arc<RwLock<Option<SubprocessTransport>>>,
103}
104
105impl OptimizedClient {
106 pub fn new(options: ClaudeCodeOptions, mode: ClientMode) -> Result<Self> {
108 unsafe {
109 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
110 }
111
112 let max_connections = match mode {
113 ClientMode::Batch { max_concurrent } => max_concurrent,
114 _ => 1,
115 };
116
117 let pool = Arc::new(ConnectionPool::new(options, max_connections));
118
119 Ok(Self {
120 mode,
121 pool,
122 message_rx: Arc::new(RwLock::new(None)),
123 current_transport: Arc::new(RwLock::new(None)),
124 })
125 }
126
127 pub async fn query(&self, prompt: String) -> Result<Vec<Message>> {
129 self.query_with_retry(prompt, 3, Duration::from_millis(100))
130 .await
131 }
132
133 pub async fn query_with_retry(
135 &self,
136 prompt: String,
137 max_retries: u32,
138 initial_delay: Duration,
139 ) -> Result<Vec<Message>> {
140 let mut retries = 0;
141 let mut delay = initial_delay;
142
143 loop {
144 match self.execute_query(&prompt).await {
145 Ok(messages) => return Ok(messages),
146 Err(e) if retries < max_retries => {
147 warn!("Query failed, retrying in {:?}: {}", delay, e);
148 tokio::time::sleep(delay).await;
149 retries += 1;
150 delay *= 2; }
152 Err(e) => return Err(e),
153 }
154 }
155 }
156
157 async fn execute_query(&self, prompt: &str) -> Result<Vec<Message>> {
159 let mut transport = self.pool.acquire().await?;
160
161 let message = InputMessage::user(prompt.to_string(), "default".to_string());
163 transport.send_message(message).await?;
164
165 let timeout_duration = Duration::from_secs(120);
167 let messages = timeout(timeout_duration, self.collect_messages(&mut transport))
168 .await
169 .map_err(|_| SdkError::Timeout { seconds: 120 })??;
170
171 self.pool.release(transport).await;
173
174 Ok(messages)
175 }
176
177 async fn collect_messages(&self, transport: &mut SubprocessTransport) -> Result<Vec<Message>> {
179 let mut messages = Vec::new();
180 let mut stream = transport.receive_messages();
181
182 while let Some(result) = stream.next().await {
183 match result {
184 Ok(msg) => {
185 debug!("Received: {:?}", msg);
186 let is_result = matches!(msg, Message::Result { .. });
187 messages.push(msg);
188 if is_result {
189 break;
190 }
191 }
192 Err(e) => return Err(e),
193 }
194 }
195
196 Ok(messages)
197 }
198
199 pub async fn start_interactive_session(&self) -> Result<()> {
201 if !matches!(self.mode, ClientMode::Interactive) {
202 return Err(SdkError::InvalidState {
203 message: "Client not in interactive mode".into(),
204 });
205 }
206
207 let transport = self.pool.acquire().await?;
209
210 let (tx, rx) = mpsc::channel::<Message>(100);
212
213 *self.current_transport.write().await = Some(transport);
215 *self.message_rx.write().await = Some(rx);
216
217 self.start_message_processor(tx).await;
219
220 info!("Interactive session started");
221 Ok(())
222 }
223
224 async fn start_message_processor(&self, tx: mpsc::Sender<Message>) {
226 let transport_ref = self.current_transport.clone();
227
228 tokio::spawn(async move {
229 loop {
230 let msg_result = {
232 let mut transport_guard = transport_ref.write().await;
233 if let Some(transport) = transport_guard.as_mut() {
234 let mut stream = transport.receive_messages();
235 stream.next().await
236 } else {
237 break;
238 }
239 };
240
241 if let Some(result) = msg_result {
243 match result {
244 Ok(msg) => {
245 if tx.send(msg).await.is_err() {
246 error!("Failed to send message to channel");
247 break;
248 }
249 }
250 Err(e) => {
251 error!("Error receiving message: {}", e);
252 break;
253 }
254 }
255 }
256 }
257 });
258 }
259
260 pub async fn send_interactive(&self, prompt: String) -> Result<()> {
262 let transport_guard = self.current_transport.read().await;
263 if let Some(_transport) = transport_guard.as_ref() {
264 drop(transport_guard);
266
267 let mut transport_guard = self.current_transport.write().await;
268 if let Some(transport) = transport_guard.as_mut() {
269 let message = InputMessage::user(prompt, "default".to_string());
270 transport.send_message(message).await?;
271 } else {
272 return Err(SdkError::InvalidState {
273 message: "Transport lost during operation".into(),
274 });
275 }
276 Ok(())
277 } else {
278 Err(SdkError::InvalidState {
279 message: "No active interactive session".into(),
280 })
281 }
282 }
283
284 pub async fn receive_interactive(&self) -> Result<Vec<Message>> {
286 let mut rx_guard = self.message_rx.write().await;
287 if let Some(rx) = rx_guard.as_mut() {
288 let mut messages = Vec::new();
289
290 while let Some(msg) = rx.recv().await {
292 let is_result = matches!(msg, Message::Result { .. });
293 messages.push(msg);
294 if is_result {
295 break;
296 }
297 }
298
299 Ok(messages)
300 } else {
301 Err(SdkError::InvalidState {
302 message: "No active interactive session".into(),
303 })
304 }
305 }
306
307 pub async fn process_batch(&self, prompts: Vec<String>) -> Result<Vec<Result<Vec<Message>>>> {
309 let max_concurrent = match self.mode {
310 ClientMode::Batch { max_concurrent } => max_concurrent,
311 _ => {
312 return Err(SdkError::InvalidState {
313 message: "Client not in batch mode".into(),
314 });
315 }
316 };
317
318 let semaphore = Arc::new(Semaphore::new(max_concurrent));
319 let mut handles = Vec::new();
320
321 for prompt in prompts {
322 let permit = semaphore.clone().acquire_owned().await.unwrap();
323 let client = self.clone(); let handle = tokio::spawn(async move {
326 let result = client.query(prompt).await;
327 drop(permit);
328 result
329 });
330
331 handles.push(handle);
332 }
333
334 let mut results = Vec::new();
336 for handle in handles {
337 match handle.await {
338 Ok(result) => results.push(result),
339 Err(e) => {
340 results.push(Err(SdkError::TransportError(format!("Task failed: {e}"))))
341 }
342 }
343 }
344
345 Ok(results)
346 }
347
348 pub async fn interrupt(&self) -> Result<()> {
350 let transport_guard = self.current_transport.read().await;
351 if let Some(_transport) = transport_guard.as_ref() {
352 drop(transport_guard);
353
354 let mut transport_guard = self.current_transport.write().await;
355 if let Some(transport) = transport_guard.as_mut() {
356 let request = ControlRequest::Interrupt {
357 request_id: uuid::Uuid::new_v4().to_string(),
358 };
359 transport.send_control_request(request).await?;
360 } else {
361 return Err(SdkError::InvalidState {
362 message: "Transport lost during operation".into(),
363 });
364 }
365 info!("Interrupt sent");
366 Ok(())
367 } else {
368 Err(SdkError::InvalidState {
369 message: "No active session".into(),
370 })
371 }
372 }
373
374 pub async fn end_interactive_session(&self) -> Result<()> {
376 if let Some(transport) = self.current_transport.write().await.take() {
378 self.pool.release(transport).await;
379 }
380
381 *self.message_rx.write().await = None;
383
384 info!("Interactive session ended");
385 Ok(())
386 }
387}
388
389impl Clone for OptimizedClient {
391 fn clone(&self) -> Self {
392 Self {
393 mode: self.mode,
394 pool: self.pool.clone(),
395 message_rx: Arc::new(RwLock::new(None)),
396 current_transport: Arc::new(RwLock::new(None)),
397 }
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// Every mode with valid parameters must construct successfully.
    #[test]
    fn test_client_mode_creation() {
        let options = ClaudeCodeOptions::builder().build();

        let modes = [
            ClientMode::OneShot,
            ClientMode::Interactive,
            ClientMode::Batch { max_concurrent: 5 },
        ];
        for mode in modes {
            let client = OptimizedClient::new(options.clone(), mode);
            assert!(client.is_ok());
        }
    }

    /// The pool records the capacity it was created with.
    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new(ClaudeCodeOptions::builder().build(), 10);

        assert_eq!(pool.max_connections, 10);
    }

    /// A clone keeps the original client's mode.
    #[tokio::test]
    async fn test_client_cloning() {
        let options = ClaudeCodeOptions::builder().build();
        let original = OptimizedClient::new(options, ClientMode::OneShot).unwrap();

        let duplicate = original.clone();

        assert!(
            matches!(
                (original.mode, duplicate.mode),
                (ClientMode::OneShot, ClientMode::OneShot)
            ),
            "Mode not preserved during cloning"
        );
    }
}