// claude_agent_sdk/internal/pool.rs

//! Connection pool for reusing CLI processes
//!
//! This module implements a connection pool that manages reusable CLI worker processes
//! to reduce the overhead of spawning a new process for each query.
//!
//! # Architecture
//!
//! The pool uses a channel-based distribution pattern:
//! 1. Workers are spawned and kept in a pool
//! 2. When a query arrives, an available worker is acquired
//! 3. After the query completes, the worker is returned to the pool
//! 4. Unhealthy workers are recycled and replaced
//!
//! Pooling is disabled by default (see `PoolConfig`), and pooled workers require an
//! explicit `cli_path` in the SDK options.
//!
//! # Performance Targets
//!
//! - Reduce query latency from ~300ms to <100ms by reusing processes
//! - Support concurrent queries with configurable pool size
//! - Automatic worker health monitoring and replacement

use std::sync::Arc;
use std::time::Duration;

use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::{mpsc, Mutex, Semaphore};
use tokio::time::timeout;

use crate::errors::{ClaudeError, ConnectionError, ProcessError, Result};
use crate::types::config::ClaudeAgentOptions;
use crate::version::{ENTRYPOINT, SDK_VERSION};

/// Default minimum pool size.
pub const DEFAULT_MIN_POOL_SIZE: usize = 1;
/// Default maximum pool size.
pub const DEFAULT_MAX_POOL_SIZE: usize = 10;
/// Default idle timeout for workers, in seconds (5 minutes).
pub const DEFAULT_IDLE_TIMEOUT_SECS: u64 = 300;
/// Default health check interval, in seconds.
pub const DEFAULT_HEALTH_CHECK_INTERVAL_SECS: u64 = 60;
/// Worker acquisition timeout, in seconds.
const ACQUIRE_TIMEOUT_SECS: u64 = 30;

/// Configuration for the connection pool.
///
/// Build one with [`PoolConfig::new`] (or `Default`) and the chained setters.
/// No validation is performed here: for example, a `min_size` larger than
/// `max_size` is stored as given.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PoolConfig {
    /// Minimum number of workers to maintain.
    pub min_size: usize,
    /// Maximum number of workers allowed.
    pub max_size: usize,
    /// Idle timeout before recycling a worker.
    pub idle_timeout: Duration,
    /// Interval for health checks.
    pub health_check_interval: Duration,
    /// Enable connection pooling.
    pub enabled: bool,
}

impl Default for PoolConfig {
    /// Values from the `DEFAULT_*` constants, with pooling disabled.
    fn default() -> Self {
        Self {
            min_size: DEFAULT_MIN_POOL_SIZE,
            max_size: DEFAULT_MAX_POOL_SIZE,
            idle_timeout: Duration::from_secs(DEFAULT_IDLE_TIMEOUT_SECS),
            health_check_interval: Duration::from_secs(DEFAULT_HEALTH_CHECK_INTERVAL_SECS),
            enabled: false, // Disabled by default for backward compatibility
        }
    }
}

impl PoolConfig {
    /// Create a new pool configuration with default values (pooling disabled).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable the connection pool.
    #[must_use]
    pub fn enabled(mut self) -> Self {
        self.enabled = true;
        self
    }

    /// Set the minimum pool size.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Set the maximum pool size.
    #[must_use]
    pub fn max_size(mut self, size: usize) -> Self {
        self.max_size = size;
        self
    }

    /// Set the idle timeout after which a worker is recycled instead of reused.
    #[must_use]
    pub fn idle_timeout(mut self, duration: Duration) -> Self {
        self.idle_timeout = duration;
        self
    }

    /// Set the health check interval.
    ///
    /// Previously this field could only be changed by assigning it directly.
    #[must_use]
    pub fn health_check_interval(mut self, duration: Duration) -> Self {
        self.health_check_interval = duration;
        self
    }
}

101/// A pooled worker that wraps a CLI process
102struct PooledWorker {
103    /// Worker ID for tracking
104    id: usize,
105    /// The CLI process
106    process: Child,
107    /// Stdin for writing to the process
108    stdin: ChildStdin,
109    /// Stdout reader (wrapped in Arc<Mutex> for sharing)
110    stdout: Arc<Mutex<BufReader<ChildStdout>>>,
111    /// Last activity timestamp
112    last_activity: std::time::Instant,
113    /// Whether this worker is healthy
114    healthy: bool,
115}
116
117impl PooledWorker {
118    /// Create a new pooled worker
119    async fn new(id: usize, options: &ClaudeAgentOptions) -> Result<Self> {
120        let (process, stdin, stdout) = Self::spawn_process(options).await?;
121
122        Ok(Self {
123            id,
124            process,
125            stdin,
126            stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
127            last_activity: std::time::Instant::now(),
128            healthy: true,
129        })
130    }
131
132    /// Spawn a new CLI process
133    async fn spawn_process(
134        options: &ClaudeAgentOptions,
135    ) -> Result<(Child, ChildStdin, ChildStdout)> {
136        use std::process::Stdio;
137
138        let cli_path = if let Some(ref path) = options.cli_path {
139            path.clone()
140        } else {
141            // Use the existing CLI finding logic
142            return Err(ClaudeError::Connection(ConnectionError::new(
143                "CLI path must be specified for pooled connections".to_string(),
144            )));
145        };
146
147        // Build environment
148        let mut env = options.env.clone();
149        env.insert("CLAUDE_CODE_ENTRYPOINT".to_string(), ENTRYPOINT.to_string());
150        env.insert(
151            "CLAUDE_AGENT_SDK_VERSION".to_string(),
152            SDK_VERSION.to_string(),
153        );
154
155        // Build command for streaming mode
156        let mut cmd = Command::new(&cli_path);
157        cmd.args(["--output-format", "stream-json", "--verbose", "--input-format", "stream-json"])
158            .stdin(Stdio::piped())
159            .stdout(Stdio::piped())
160            .stderr(Stdio::null()) // Suppress stderr for pooled workers
161            .envs(&env);
162
163        if let Some(ref cwd) = options.cwd {
164            cmd.current_dir(cwd);
165        }
166
167        // Spawn process
168        let mut child = cmd.spawn().map_err(|e| {
169            ClaudeError::Process(ProcessError::new(
170                format!("Failed to spawn CLI process for pool: {}", e),
171                None,
172                None,
173            ))
174        })?;
175
176        let stdin = child.stdin.take().ok_or_else(|| {
177            ClaudeError::Connection(ConnectionError::new("Failed to get stdin".to_string()))
178        })?;
179
180        let stdout = child.stdout.take().ok_or_else(|| {
181            ClaudeError::Connection(ConnectionError::new("Failed to get stdout".to_string()))
182        })?;
183
184        Ok((child, stdin, stdout))
185    }
186
187    /// Check if the worker is still healthy
188    fn is_healthy(&self) -> bool {
189        self.healthy && self.process.id().is_some()
190    }
191
192    /// Update last activity timestamp
193    fn touch(&mut self) {
194        self.last_activity = std::time::Instant::now();
195    }
196
197    /// Check if worker has been idle too long
198    fn is_idle_timeout(&self, timeout_dur: Duration) -> bool {
199        self.last_activity.elapsed() > timeout_dur
200    }
201
202    /// Write data to the worker's stdin
203    async fn write(&mut self, data: &str) -> Result<()> {
204        self.stdin
205            .write_all(data.as_bytes())
206            .await
207            .map_err(|e| ClaudeError::Transport(format!("Failed to write to pooled worker: {}", e)))?;
208        self.stdin
209            .write_all(b"\n")
210            .await
211            .map_err(|e| ClaudeError::Transport(format!("Failed to write newline: {}", e)))?;
212        self.stdin
213            .flush()
214            .await
215            .map_err(|e| ClaudeError::Transport(format!("Failed to flush pooled worker: {}", e)))?;
216        self.touch();
217        Ok(())
218    }
219
220    /// Read a line from the worker's stdout
221    async fn read_line(&mut self, line: &mut String) -> Result<usize> {
222        let mut stdout = self.stdout.lock().await;
223        let n = stdout
224            .read_line(line)
225            .await
226            .map_err(|e| ClaudeError::Transport(format!("Failed to read from pooled worker: {}", e)))?;
227        drop(stdout); // Release lock before touching
228        self.touch();
229        Ok(n)
230    }
231}
232
233impl Drop for PooledWorker {
234    fn drop(&mut self) {
235        if let Some(pid) = self.process.id() {
236            tracing::debug!("Dropping pooled worker with PID {}", pid);
237            let _ = self.process.start_kill();
238        }
239    }
240}
241
242/// A guard that returns the worker to the pool when dropped
243pub struct WorkerGuard {
244    worker: Option<PooledWorker>,
245    return_tx: mpsc::Sender<PooledWorker>,
246    _permit: Option<tokio::sync::OwnedSemaphorePermit>,
247}
248
249impl WorkerGuard {
250    /// Write data to the worker
251    pub async fn write(&mut self, data: &str) -> Result<()> {
252        if let Some(ref mut worker) = self.worker {
253            worker.write(data).await
254        } else {
255            Err(ClaudeError::Transport("Worker not available".to_string()))
256        }
257    }
258
259    /// Read a line from the worker
260    pub async fn read_line(&mut self, line: &mut String) -> Result<usize> {
261        if let Some(ref mut worker) = self.worker {
262            worker.read_line(line).await
263        } else {
264            Err(ClaudeError::Transport("Worker not available".to_string()))
265        }
266    }
267
268    /// Get the stdout reader for streaming
269    #[allow(dead_code)]
270    pub fn stdout(&self) -> Option<Arc<Mutex<BufReader<ChildStdout>>>> {
271        self.worker.as_ref().map(|w| Arc::clone(&w.stdout))
272    }
273}
274
275impl Drop for WorkerGuard {
276    fn drop(&mut self) {
277        if let Some(worker) = self.worker.take() {
278            // Try to return the worker to the pool (non-blocking)
279            let _ = self.return_tx.try_send(worker);
280        }
281        // Permit is released when _permit is dropped
282    }
283}
284
285/// The connection pool for managing CLI worker processes
286pub struct ConnectionPool {
287    /// Pool configuration
288    config: PoolConfig,
289    /// SDK options for spawning workers
290    options: ClaudeAgentOptions,
291    /// Channel for returning workers to the pool
292    return_tx: mpsc::Sender<PooledWorker>,
293    /// Channel for receiving returned workers (stored in mutex for interior mutability)
294    return_rx: Mutex<mpsc::Receiver<PooledWorker>>,
295    /// Semaphore for limiting concurrent workers
296    semaphore: Arc<Semaphore>,
297    /// Counter for worker IDs
298    next_worker_id: Mutex<usize>,
299    /// Pool state
300    state: Mutex<PoolState>,
301}
302
303struct PoolState {
304    /// Total workers created
305    total_created: usize,
306    /// Active workers
307    active_count: usize,
308}
309
310impl ConnectionPool {
311    /// Create a new connection pool
312    pub fn new(config: PoolConfig, options: ClaudeAgentOptions) -> Self {
313        let (return_tx, return_rx) = mpsc::channel(config.max_size);
314        let semaphore = Arc::new(Semaphore::new(config.max_size));
315
316        Self {
317            config,
318            options,
319            return_tx,
320            return_rx: Mutex::new(return_rx),
321            semaphore,
322            next_worker_id: Mutex::new(0),
323            state: Mutex::new(PoolState {
324                total_created: 0,
325                active_count: 0,
326            }),
327        }
328    }
329
330    /// Initialize the pool with minimum workers
331    pub async fn initialize(&self) -> Result<()> {
332        for _ in 0..self.config.min_size {
333            let worker = self.create_worker().await?;
334            let _ = self.return_tx.try_send(worker);
335        }
336        Ok(())
337    }
338
339    /// Create a new worker
340    async fn create_worker(&self) -> Result<PooledWorker> {
341        let id = {
342            let mut guard = self.next_worker_id.lock().await;
343            *guard += 1;
344            *guard
345        };
346
347        let worker = PooledWorker::new(id, &self.options).await?;
348
349        let mut state = self.state.lock().await;
350        state.total_created += 1;
351        state.active_count += 1;
352
353        tracing::debug!("Created pooled worker {} (total: {}, active: {})",
354            id, state.total_created, state.active_count);
355
356        Ok(worker)
357    }
358
359    /// Acquire a worker from the pool
360    pub async fn acquire(&self) -> Result<WorkerGuard> {
361        // Try to acquire with timeout
362        let permit = timeout(
363            Duration::from_secs(ACQUIRE_TIMEOUT_SECS),
364            Arc::clone(&self.semaphore).acquire_owned(),
365        )
366        .await
367        .map_err(|_| {
368            ClaudeError::Connection(ConnectionError::new(
369                "Timeout acquiring worker from pool".to_string(),
370            ))
371        })?
372        .map_err(|e| {
373            ClaudeError::Connection(ConnectionError::new(format!(
374                "Failed to acquire semaphore: {}",
375                e
376            )))
377        })?;
378
379        // Try to get a worker from the return channel
380        let worker = {
381            let mut rx = self.return_rx.lock().await;
382            match rx.try_recv() {
383                Ok(worker) => {
384                    if worker.is_healthy() && !worker.is_idle_timeout(self.config.idle_timeout) {
385                        Some(worker)
386                    } else {
387                        // Worker is unhealthy or timed out, create new one
388                        tracing::debug!("Recycling unhealthy/timed-out worker {}", worker.id);
389                        None
390                    }
391                }
392                Err(_) => None,
393            }
394        };
395
396        // If no worker available, create a new one
397        let worker = match worker {
398            Some(w) => w,
399            None => self.create_worker().await?,
400        };
401
402        Ok(WorkerGuard {
403            worker: Some(worker),
404            return_tx: self.return_tx.clone(),
405            _permit: Some(permit),
406        })
407    }
408
409    /// Get pool statistics
410    #[allow(dead_code)]
411    pub async fn stats(&self) -> PoolStats {
412        let state = self.state.lock().await;
413        PoolStats {
414            total_created: state.total_created,
415            active_count: state.active_count,
416            available_permits: self.semaphore.available_permits(),
417        }
418    }
419
420    /// Check if the pool is enabled
421    pub fn is_enabled(&self) -> bool {
422        self.config.enabled
423    }
424}
425
/// Snapshot of pool statistics, as returned by `ConnectionPool::stats`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolStats {
    /// Total workers created over the pool's lifetime.
    pub total_created: usize,
    /// Workers the pool currently counts as active.
    pub active_count: usize,
    /// Semaphore permits currently free, i.e. how many more workers can be
    /// checked out right now.
    pub available_permits: usize,
}

437/// Global connection pool singleton
438static POOL: std::sync::OnceLock<Arc<Mutex<Option<Arc<ConnectionPool>>>>> = std::sync::OnceLock::new();
439
440fn get_pool_singleton() -> &'static Arc<Mutex<Option<Arc<ConnectionPool>>>> {
441    POOL.get_or_init(|| Arc::new(Mutex::new(None)))
442}
443
444/// Initialize the global connection pool
445pub async fn init_global_pool(config: PoolConfig, options: ClaudeAgentOptions) -> Result<()> {
446    let pool = Arc::new(ConnectionPool::new(config, options));
447
448    if pool.is_enabled() {
449        pool.initialize().await?;
450    }
451
452    let global = get_pool_singleton();
453    let mut guard = global.lock().await;
454    *guard = Some(pool);
455
456    Ok(())
457}
458
459/// Get the global connection pool
460pub async fn get_global_pool() -> Option<Arc<ConnectionPool>> {
461    let global = get_pool_singleton();
462    let guard = global.lock().await;
463    guard.clone()
464}
465
466/// Shutdown the global connection pool
467#[allow(dead_code)]
468pub async fn shutdown_global_pool() {
469    let global = get_pool_singleton();
470    let mut guard = global.lock().await;
471    *guard = None;
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config uses the documented constants and leaves pooling off.
    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.min_size, DEFAULT_MIN_POOL_SIZE);
        assert_eq!(config.max_size, DEFAULT_MAX_POOL_SIZE);
        assert_eq!(
            config.idle_timeout,
            Duration::from_secs(DEFAULT_IDLE_TIMEOUT_SECS)
        );
        assert_eq!(
            config.health_check_interval,
            Duration::from_secs(DEFAULT_HEALTH_CHECK_INTERVAL_SECS)
        );
        assert!(!config.enabled);
    }

    /// Chained setters update only the fields they target.
    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::new()
            .enabled()
            .min_size(2)
            .max_size(5)
            .idle_timeout(Duration::from_secs(42));

        assert!(config.enabled);
        assert_eq!(config.min_size, 2);
        assert_eq!(config.max_size, 5);
        assert_eq!(config.idle_timeout, Duration::from_secs(42));
        assert_eq!(
            config.health_check_interval,
            Duration::from_secs(DEFAULT_HEALTH_CHECK_INTERVAL_SECS)
        );
    }
}