claude_agent_sdk/internal/
pool.rs1use std::sync::Arc;
21use std::time::Duration;
22
23use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
24use tokio::process::{Child, ChildStdin, ChildStdout, Command};
25use tokio::sync::{mpsc, Mutex, Semaphore};
26use tokio::time::timeout;
27
28use crate::errors::{ClaudeError, ConnectionError, ProcessError, Result};
29use crate::types::config::ClaudeAgentOptions;
30use crate::version::{ENTRYPOINT, SDK_VERSION};
31
/// Default value for `PoolConfig::min_size`.
pub const DEFAULT_MIN_POOL_SIZE: usize = 1;

/// Default value for `PoolConfig::max_size`.
pub const DEFAULT_MAX_POOL_SIZE: usize = 10;

/// Default idle timeout, in seconds (five minutes).
pub const DEFAULT_IDLE_TIMEOUT_SECS: u64 = 5 * 60;

/// Default health-check interval, in seconds.
pub const DEFAULT_HEALTH_CHECK_INTERVAL_SECS: u64 = 60;

/// How long `ConnectionPool::acquire` waits for a semaphore permit before failing.
const ACQUIRE_TIMEOUT_SECS: u64 = 30;
42
/// Settings controlling the size, timeouts and activation of a `ConnectionPool`.
///
/// Build one with `PoolConfig::new()` / `PoolConfig::default()` and the
/// builder-style setters.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of workers pre-spawned by `ConnectionPool::initialize`.
    pub min_size: usize,
    /// Capacity of the idle queue and number of concurrent checkouts.
    pub max_size: usize,
    /// How long a worker may sit idle before `acquire` discards it.
    pub idle_timeout: Duration,
    /// Health-check interval setting.
    pub health_check_interval: Duration,
    /// Whether pooling is switched on; `false` by default.
    pub enabled: bool,
}
57
58impl Default for PoolConfig {
59 fn default() -> Self {
60 Self {
61 min_size: DEFAULT_MIN_POOL_SIZE,
62 max_size: DEFAULT_MAX_POOL_SIZE,
63 idle_timeout: Duration::from_secs(DEFAULT_IDLE_TIMEOUT_SECS),
64 health_check_interval: Duration::from_secs(DEFAULT_HEALTH_CHECK_INTERVAL_SECS),
65 enabled: false, }
67 }
68}
69
70impl PoolConfig {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn enabled(mut self) -> Self {
78 self.enabled = true;
79 self
80 }
81
82 pub fn min_size(mut self, size: usize) -> Self {
84 self.min_size = size;
85 self
86 }
87
88 pub fn max_size(mut self, size: usize) -> Self {
90 self.max_size = size;
91 self
92 }
93
94 pub fn idle_timeout(mut self, duration: Duration) -> Self {
96 self.idle_timeout = duration;
97 self
98 }
99}
100
101struct PooledWorker {
103 id: usize,
105 process: Child,
107 stdin: ChildStdin,
109 stdout: Arc<Mutex<BufReader<ChildStdout>>>,
111 last_activity: std::time::Instant,
113 healthy: bool,
115}
116
117impl PooledWorker {
118 async fn new(id: usize, options: &ClaudeAgentOptions) -> Result<Self> {
120 let (process, stdin, stdout) = Self::spawn_process(options).await?;
121
122 Ok(Self {
123 id,
124 process,
125 stdin,
126 stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
127 last_activity: std::time::Instant::now(),
128 healthy: true,
129 })
130 }
131
132 async fn spawn_process(
134 options: &ClaudeAgentOptions,
135 ) -> Result<(Child, ChildStdin, ChildStdout)> {
136 use std::process::Stdio;
137
138 let cli_path = if let Some(ref path) = options.cli_path {
139 path.clone()
140 } else {
141 return Err(ClaudeError::Connection(ConnectionError::new(
143 "CLI path must be specified for pooled connections".to_string(),
144 )));
145 };
146
147 let mut env = options.env.clone();
149 env.insert("CLAUDE_CODE_ENTRYPOINT".to_string(), ENTRYPOINT.to_string());
150 env.insert(
151 "CLAUDE_AGENT_SDK_VERSION".to_string(),
152 SDK_VERSION.to_string(),
153 );
154
155 let mut cmd = Command::new(&cli_path);
157 cmd.args(["--output-format", "stream-json", "--verbose", "--input-format", "stream-json"])
158 .stdin(Stdio::piped())
159 .stdout(Stdio::piped())
160 .stderr(Stdio::null()) .envs(&env);
162
163 if let Some(ref cwd) = options.cwd {
164 cmd.current_dir(cwd);
165 }
166
167 let mut child = cmd.spawn().map_err(|e| {
169 ClaudeError::Process(ProcessError::new(
170 format!("Failed to spawn CLI process for pool: {}", e),
171 None,
172 None,
173 ))
174 })?;
175
176 let stdin = child.stdin.take().ok_or_else(|| {
177 ClaudeError::Connection(ConnectionError::new("Failed to get stdin".to_string()))
178 })?;
179
180 let stdout = child.stdout.take().ok_or_else(|| {
181 ClaudeError::Connection(ConnectionError::new("Failed to get stdout".to_string()))
182 })?;
183
184 Ok((child, stdin, stdout))
185 }
186
187 fn is_healthy(&self) -> bool {
189 self.healthy && self.process.id().is_some()
190 }
191
192 fn touch(&mut self) {
194 self.last_activity = std::time::Instant::now();
195 }
196
197 fn is_idle_timeout(&self, timeout_dur: Duration) -> bool {
199 self.last_activity.elapsed() > timeout_dur
200 }
201
202 async fn write(&mut self, data: &str) -> Result<()> {
204 self.stdin
205 .write_all(data.as_bytes())
206 .await
207 .map_err(|e| ClaudeError::Transport(format!("Failed to write to pooled worker: {}", e)))?;
208 self.stdin
209 .write_all(b"\n")
210 .await
211 .map_err(|e| ClaudeError::Transport(format!("Failed to write newline: {}", e)))?;
212 self.stdin
213 .flush()
214 .await
215 .map_err(|e| ClaudeError::Transport(format!("Failed to flush pooled worker: {}", e)))?;
216 self.touch();
217 Ok(())
218 }
219
220 async fn read_line(&mut self, line: &mut String) -> Result<usize> {
222 let mut stdout = self.stdout.lock().await;
223 let n = stdout
224 .read_line(line)
225 .await
226 .map_err(|e| ClaudeError::Transport(format!("Failed to read from pooled worker: {}", e)))?;
227 drop(stdout); self.touch();
229 Ok(n)
230 }
231}
232
233impl Drop for PooledWorker {
234 fn drop(&mut self) {
235 if let Some(pid) = self.process.id() {
236 tracing::debug!("Dropping pooled worker with PID {}", pid);
237 let _ = self.process.start_kill();
238 }
239 }
240}
241
242pub struct WorkerGuard {
244 worker: Option<PooledWorker>,
245 return_tx: mpsc::Sender<PooledWorker>,
246 _permit: Option<tokio::sync::OwnedSemaphorePermit>,
247}
248
249impl WorkerGuard {
250 pub async fn write(&mut self, data: &str) -> Result<()> {
252 if let Some(ref mut worker) = self.worker {
253 worker.write(data).await
254 } else {
255 Err(ClaudeError::Transport("Worker not available".to_string()))
256 }
257 }
258
259 pub async fn read_line(&mut self, line: &mut String) -> Result<usize> {
261 if let Some(ref mut worker) = self.worker {
262 worker.read_line(line).await
263 } else {
264 Err(ClaudeError::Transport("Worker not available".to_string()))
265 }
266 }
267
268 #[allow(dead_code)]
270 pub fn stdout(&self) -> Option<Arc<Mutex<BufReader<ChildStdout>>>> {
271 self.worker.as_ref().map(|w| Arc::clone(&w.stdout))
272 }
273}
274
275impl Drop for WorkerGuard {
276 fn drop(&mut self) {
277 if let Some(worker) = self.worker.take() {
278 let _ = self.return_tx.try_send(worker);
280 }
281 }
283}
284
285pub struct ConnectionPool {
287 config: PoolConfig,
289 options: ClaudeAgentOptions,
291 return_tx: mpsc::Sender<PooledWorker>,
293 return_rx: Mutex<mpsc::Receiver<PooledWorker>>,
295 semaphore: Arc<Semaphore>,
297 next_worker_id: Mutex<usize>,
299 state: Mutex<PoolState>,
301}
302
/// Mutable counters kept behind `ConnectionPool::state`.
struct PoolState {
    /// Worker count maintained by `ConnectionPool`; incremented whenever a worker is spawned.
    active_count: usize,
    /// Number of workers ever spawned by the pool.
    total_created: usize,
}
309
310impl ConnectionPool {
311 pub fn new(config: PoolConfig, options: ClaudeAgentOptions) -> Self {
313 let (return_tx, return_rx) = mpsc::channel(config.max_size);
314 let semaphore = Arc::new(Semaphore::new(config.max_size));
315
316 Self {
317 config,
318 options,
319 return_tx,
320 return_rx: Mutex::new(return_rx),
321 semaphore,
322 next_worker_id: Mutex::new(0),
323 state: Mutex::new(PoolState {
324 total_created: 0,
325 active_count: 0,
326 }),
327 }
328 }
329
330 pub async fn initialize(&self) -> Result<()> {
332 for _ in 0..self.config.min_size {
333 let worker = self.create_worker().await?;
334 let _ = self.return_tx.try_send(worker);
335 }
336 Ok(())
337 }
338
339 async fn create_worker(&self) -> Result<PooledWorker> {
341 let id = {
342 let mut guard = self.next_worker_id.lock().await;
343 *guard += 1;
344 *guard
345 };
346
347 let worker = PooledWorker::new(id, &self.options).await?;
348
349 let mut state = self.state.lock().await;
350 state.total_created += 1;
351 state.active_count += 1;
352
353 tracing::debug!("Created pooled worker {} (total: {}, active: {})",
354 id, state.total_created, state.active_count);
355
356 Ok(worker)
357 }
358
359 pub async fn acquire(&self) -> Result<WorkerGuard> {
361 let permit = timeout(
363 Duration::from_secs(ACQUIRE_TIMEOUT_SECS),
364 Arc::clone(&self.semaphore).acquire_owned(),
365 )
366 .await
367 .map_err(|_| {
368 ClaudeError::Connection(ConnectionError::new(
369 "Timeout acquiring worker from pool".to_string(),
370 ))
371 })?
372 .map_err(|e| {
373 ClaudeError::Connection(ConnectionError::new(format!(
374 "Failed to acquire semaphore: {}",
375 e
376 )))
377 })?;
378
379 let worker = {
381 let mut rx = self.return_rx.lock().await;
382 match rx.try_recv() {
383 Ok(worker) => {
384 if worker.is_healthy() && !worker.is_idle_timeout(self.config.idle_timeout) {
385 Some(worker)
386 } else {
387 tracing::debug!("Recycling unhealthy/timed-out worker {}", worker.id);
389 None
390 }
391 }
392 Err(_) => None,
393 }
394 };
395
396 let worker = match worker {
398 Some(w) => w,
399 None => self.create_worker().await?,
400 };
401
402 Ok(WorkerGuard {
403 worker: Some(worker),
404 return_tx: self.return_tx.clone(),
405 _permit: Some(permit),
406 })
407 }
408
409 #[allow(dead_code)]
411 pub async fn stats(&self) -> PoolStats {
412 let state = self.state.lock().await;
413 PoolStats {
414 total_created: state.total_created,
415 active_count: state.active_count,
416 available_permits: self.semaphore.available_permits(),
417 }
418 }
419
420 pub fn is_enabled(&self) -> bool {
422 self.config.enabled
423 }
424}
425
/// Point-in-time snapshot returned by `ConnectionPool::stats`.
#[derive(Clone, Debug)]
pub struct PoolStats {
    /// Workers spawned over the pool's lifetime.
    pub total_created: usize,
    /// Value of the pool's active-worker counter when the snapshot was taken.
    pub active_count: usize,
    /// Semaphore permits not currently held, i.e. further checkouts possible right now.
    pub available_permits: usize,
}
436
437static POOL: std::sync::OnceLock<Arc<Mutex<Option<Arc<ConnectionPool>>>>> = std::sync::OnceLock::new();
439
440fn get_pool_singleton() -> &'static Arc<Mutex<Option<Arc<ConnectionPool>>>> {
441 POOL.get_or_init(|| Arc::new(Mutex::new(None)))
442}
443
444pub async fn init_global_pool(config: PoolConfig, options: ClaudeAgentOptions) -> Result<()> {
446 let pool = Arc::new(ConnectionPool::new(config, options));
447
448 if pool.is_enabled() {
449 pool.initialize().await?;
450 }
451
452 let global = get_pool_singleton();
453 let mut guard = global.lock().await;
454 *guard = Some(pool);
455
456 Ok(())
457}
458
459pub async fn get_global_pool() -> Option<Arc<ConnectionPool>> {
461 let global = get_pool_singleton();
462 let guard = global.lock().await;
463 guard.clone()
464}
465
466#[allow(dead_code)]
468pub async fn shutdown_global_pool() {
469 let global = get_pool_singleton();
470 let mut guard = global.lock().await;
471 *guard = None;
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults come from the module constants and leave pooling disabled.
    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.min_size, DEFAULT_MIN_POOL_SIZE);
        assert_eq!(config.max_size, DEFAULT_MAX_POOL_SIZE);
        assert_eq!(config.idle_timeout, Duration::from_secs(DEFAULT_IDLE_TIMEOUT_SECS));
        assert_eq!(
            config.health_check_interval,
            Duration::from_secs(DEFAULT_HEALTH_CHECK_INTERVAL_SECS)
        );
        assert!(!config.enabled);
    }

    /// `PoolConfig::new` must stay equivalent to `PoolConfig::default`.
    #[test]
    fn test_pool_config_new_matches_default() {
        let built = PoolConfig::new();
        let default = PoolConfig::default();
        assert_eq!(built.min_size, default.min_size);
        assert_eq!(built.max_size, default.max_size);
        assert_eq!(built.idle_timeout, default.idle_timeout);
        assert_eq!(built.health_check_interval, default.health_check_interval);
        assert_eq!(built.enabled, default.enabled);
    }

    /// Each builder setter overrides exactly its own field.
    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::new()
            .enabled()
            .min_size(2)
            .max_size(5)
            .idle_timeout(Duration::from_secs(42));

        assert!(config.enabled);
        assert_eq!(config.min_size, 2);
        assert_eq!(config.max_size, 5);
        assert_eq!(config.idle_timeout, Duration::from_secs(42));
        assert_eq!(
            config.health_check_interval,
            Duration::from_secs(DEFAULT_HEALTH_CHECK_INTERVAL_SECS)
        );
    }
}