use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::{mpsc, Mutex, Semaphore};
use tokio::time::timeout;
use crate::errors::{ClaudeError, ConnectionError, ProcessError, Result};
use crate::types::config::ClaudeAgentOptions;
use crate::version::{ENTRYPOINT, SDK_VERSION};
/// Default value for `PoolConfig::min_size`.
pub const DEFAULT_MIN_POOL_SIZE: usize = 1;

/// Default value for `PoolConfig::max_size`.
pub const DEFAULT_MAX_POOL_SIZE: usize = 10;

/// Default value for `PoolConfig::idle_timeout`, expressed in seconds (five minutes).
pub const DEFAULT_IDLE_TIMEOUT_SECS: u64 = 300;

/// Default value for `PoolConfig::health_check_interval`, expressed in seconds.
pub const DEFAULT_HEALTH_CHECK_INTERVAL_SECS: u64 = 60;

/// How long, in seconds, `ConnectionPool::acquire` waits for a semaphore permit
/// before giving up with a connection error.
const ACQUIRE_TIMEOUT_SECS: u64 = 30;
/// Settings that control how the worker pool is sized and whether it is used.
///
/// Values are usually produced through `PoolConfig::new()` followed by the
/// chained builder methods, but every field is public and may be set directly.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of workers the pool aims to spawn up front during initialization.
    pub min_size: usize,
    /// Upper bound used for both the acquire semaphore and the return queue.
    pub max_size: usize,
    /// A queued worker whose last activity is older than this is not reused.
    pub idle_timeout: Duration,
    /// Configured health-check interval. Nothing in this file reads it; it is
    /// only stored here.
    pub health_check_interval: Duration,
    /// Whether pooling is switched on. Defaults to `false`.
    pub enabled: bool,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
min_size: DEFAULT_MIN_POOL_SIZE,
max_size: DEFAULT_MAX_POOL_SIZE,
idle_timeout: Duration::from_secs(DEFAULT_IDLE_TIMEOUT_SECS),
health_check_interval: Duration::from_secs(DEFAULT_HEALTH_CHECK_INTERVAL_SECS),
enabled: false, }
}
}
impl PoolConfig {
    /// Creates a configuration with the default values (pooling disabled).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Switches pooling on.
    #[must_use]
    pub fn enabled(mut self) -> Self {
        self.enabled = true;
        self
    }

    /// Sets the number of workers to pre-spawn during pool initialization.
    #[must_use]
    pub fn min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Sets the maximum number of workers that may be checked out at once.
    #[must_use]
    pub fn max_size(mut self, size: usize) -> Self {
        self.max_size = size;
        self
    }

    /// Sets how long a queued worker may sit unused before it is no longer reused.
    #[must_use]
    pub fn idle_timeout(mut self, duration: Duration) -> Self {
        self.idle_timeout = duration;
        self
    }

    /// Sets the health-check interval.
    ///
    /// Every other field already had a builder setter; this one completes the
    /// set so the interval does not have to be assigned through the public
    /// field directly.
    #[must_use]
    pub fn health_check_interval(mut self, duration: Duration) -> Self {
        self.health_check_interval = duration;
        self
    }
}
/// One spawned CLI child process together with its pipes and bookkeeping.
struct PooledWorker {
    /// Identifier assigned by the pool; used in log messages.
    id: usize,
    /// Handle to the child process.
    process: Child,
    /// Write end of the child's standard input.
    stdin: ChildStdin,
    /// Buffered read end of the child's standard output, shareable via `Arc`.
    stdout: Arc<Mutex<BufReader<ChildStdout>>>,
    /// Moment of the most recent successful write or read (or of creation).
    last_activity: std::time::Instant,
    /// Health flag consulted by `is_healthy`; starts out `true`.
    healthy: bool,
}
impl PooledWorker {
    /// Spawns a new CLI process and wraps it as the worker with the given `id`.
    ///
    /// # Errors
    /// Returns whatever [`Self::spawn_process`] returns on failure.
    async fn new(id: usize, options: &ClaudeAgentOptions) -> Result<Self> {
        let (process, stdin, stdout) = Self::spawn_process(options).await?;
        Ok(Self {
            id,
            process,
            stdin,
            stdout: Arc::new(Mutex::new(BufReader::new(stdout))),
            last_activity: std::time::Instant::now(),
            healthy: true,
        })
    }

    /// Launches the CLI in stream-JSON mode with piped stdin/stdout.
    ///
    /// The child inherits `options.env` plus the SDK entrypoint and version
    /// variables (which override same-named entries), runs in `options.cwd`
    /// when set, and has its stderr discarded.
    ///
    /// # Errors
    /// * `ClaudeError::Connection` when `options.cli_path` is `None`, or when
    ///   the stdin/stdout pipe of the spawned child cannot be taken.
    /// * `ClaudeError::Process` when the process fails to spawn.
    async fn spawn_process(
        options: &ClaudeAgentOptions,
    ) -> Result<(Child, ChildStdin, ChildStdout)> {
        use std::process::Stdio;

        // Borrow the path instead of cloning it; it is only needed to build the command.
        let Some(cli_path) = options.cli_path.as_ref() else {
            return Err(ClaudeError::Connection(ConnectionError::new(
                "CLI path must be specified for pooled connections".to_string(),
            )));
        };

        // Inserted after cloning so the SDK-provided values win over caller-supplied ones.
        let mut env = options.env.clone();
        env.insert("CLAUDE_CODE_ENTRYPOINT".to_string(), ENTRYPOINT.to_string());
        env.insert(
            "CLAUDE_AGENT_SDK_VERSION".to_string(),
            SDK_VERSION.to_string(),
        );

        let mut cmd = Command::new(cli_path);
        cmd.args([
            "--output-format",
            "stream-json",
            "--verbose",
            "--input-format",
            "stream-json",
        ])
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::null())
        .envs(&env);

        if let Some(ref cwd) = options.cwd {
            cmd.current_dir(cwd);
        }

        let mut child = cmd.spawn().map_err(|e| {
            ClaudeError::Process(ProcessError::new(
                format!("Failed to spawn CLI process for pool: {}", e),
                None,
                None,
            ))
        })?;

        let stdin = child.stdin.take().ok_or_else(|| {
            ClaudeError::Connection(ConnectionError::new("Failed to get stdin".to_string()))
        })?;
        let stdout = child.stdout.take().ok_or_else(|| {
            ClaudeError::Connection(ConnectionError::new("Failed to get stdout".to_string()))
        })?;

        Ok((child, stdin, stdout))
    }

    /// Reports whether the worker is still believed usable.
    ///
    /// The worker is unhealthy once any write/read has failed or its stdout
    /// reached end-of-file (see [`Self::write`] and [`Self::read_line`]), or
    /// once the process no longer reports a PID.
    ///
    /// NOTE(review): per the tokio documentation `Child::id()` only turns
    /// `None` after the child has been polled to completion, so the PID check
    /// by itself does not detect a process that exited unobserved.
    fn is_healthy(&self) -> bool {
        self.healthy && self.process.id().is_some()
    }

    /// Records "now" as the time of the latest activity.
    fn touch(&mut self) {
        self.last_activity = std::time::Instant::now();
    }

    /// Returns `true` when more than `timeout_dur` has passed since the last activity.
    fn is_idle_timeout(&self, timeout_dur: Duration) -> bool {
        self.last_activity.elapsed() > timeout_dur
    }

    /// Sends `data` followed by a newline to the child's stdin and flushes.
    ///
    /// On success the activity timestamp is refreshed. On failure the worker
    /// is marked unhealthy so that it is not reused once handed back to the pool.
    ///
    /// # Errors
    /// `ClaudeError::Transport` describing which step (payload, newline, flush) failed.
    async fn write(&mut self, data: &str) -> Result<()> {
        let outcome = self.write_frame(data).await;
        if outcome.is_ok() {
            self.touch();
        } else {
            // A broken stdin pipe does not recover; never hand this worker out again.
            self.healthy = false;
        }
        outcome
    }

    /// Performs the raw payload + newline + flush sequence for [`Self::write`].
    async fn write_frame(&mut self, data: &str) -> Result<()> {
        self.stdin
            .write_all(data.as_bytes())
            .await
            .map_err(|e| ClaudeError::Transport(format!("Failed to write to pooled worker: {}", e)))?;
        self.stdin
            .write_all(b"\n")
            .await
            .map_err(|e| ClaudeError::Transport(format!("Failed to write newline: {}", e)))?;
        self.stdin
            .flush()
            .await
            .map_err(|e| ClaudeError::Transport(format!("Failed to flush pooled worker: {}", e)))?;
        Ok(())
    }

    /// Reads one line from the child's stdout, appending it to `line`.
    ///
    /// Returns the number of bytes read; `0` means the child closed its
    /// stdout. Both end-of-file and read errors mark the worker unhealthy.
    ///
    /// # Errors
    /// `ClaudeError::Transport` when the underlying read fails.
    async fn read_line(&mut self, line: &mut String) -> Result<usize> {
        let mut reader = self.stdout.lock().await;
        let outcome = reader.read_line(line).await;
        // Release the stdout lock before touching the rest of `self`.
        drop(reader);

        match outcome {
            Ok(bytes_read) => {
                if bytes_read == 0 {
                    // EOF: the child closed stdout, so it cannot serve further requests.
                    self.healthy = false;
                }
                self.touch();
                Ok(bytes_read)
            }
            Err(e) => {
                self.healthy = false;
                Err(ClaudeError::Transport(format!(
                    "Failed to read from pooled worker: {}",
                    e
                )))
            }
        }
    }
}
impl Drop for PooledWorker {
    /// Requests termination of the child process if it still reports a PID.
    fn drop(&mut self) {
        let Some(pid) = self.process.id() else {
            // No PID reported: nothing to kill.
            return;
        };
        tracing::debug!("Dropping pooled worker with PID {}", pid);
        // Best effort; a failure to signal the process is deliberately ignored.
        let _ = self.process.start_kill();
    }
}
/// Exclusive handle to a pooled worker, returned by `ConnectionPool::acquire`.
///
/// Dropping the guard tries to put the worker back on the pool's return queue.
/// Field order matters: fields are dropped in declaration order, so the permit
/// is released last.
pub struct WorkerGuard {
    /// The checked-out worker; taken (left as `None`) when the guard is dropped.
    worker: Option<PooledWorker>,
    /// Sender side of the pool's return queue.
    return_tx: mpsc::Sender<PooledWorker>,
    /// Semaphore permit held for as long as the guard lives.
    _permit: Option<tokio::sync::OwnedSemaphorePermit>,
}
impl WorkerGuard {
    /// Forwards `data` to the held worker's `write`.
    ///
    /// # Errors
    /// `ClaudeError::Transport("Worker not available")` when the guard holds no
    /// worker; otherwise whatever the worker's `write` returns.
    pub async fn write(&mut self, data: &str) -> Result<()> {
        match self.worker.as_mut() {
            Some(held) => held.write(data).await,
            None => Err(ClaudeError::Transport("Worker not available".to_string())),
        }
    }

    /// Forwards to the held worker's `read_line`, appending to `line`.
    ///
    /// # Errors
    /// `ClaudeError::Transport("Worker not available")` when the guard holds no
    /// worker; otherwise whatever the worker's `read_line` returns.
    pub async fn read_line(&mut self, line: &mut String) -> Result<usize> {
        match self.worker.as_mut() {
            Some(held) => held.read_line(line).await,
            None => Err(ClaudeError::Transport("Worker not available".to_string())),
        }
    }

    /// Returns another handle to the worker's buffered stdout, or `None` when
    /// the guard holds no worker.
    #[allow(dead_code)]
    pub fn stdout(&self) -> Option<Arc<Mutex<BufReader<ChildStdout>>>> {
        let held = self.worker.as_ref()?;
        Some(Arc::clone(&held.stdout))
    }
}
impl Drop for WorkerGuard {
    /// Offers the held worker back to the pool's return queue.
    fn drop(&mut self) {
        let Some(returning) = self.worker.take() else {
            return;
        };
        // Non-blocking and best effort: if the queue rejects the worker, the
        // error value (which owns the worker) is dropped right here.
        let _ = self.return_tx.try_send(returning);
    }
}
/// Pool of reusable CLI worker processes.
///
/// Idle workers wait in a bounded channel; a semaphore limits how many guards
/// can be outstanding at once.
pub struct ConnectionPool {
    /// Sizing and enablement settings supplied at construction.
    config: PoolConfig,
    /// Agent options used whenever a new worker process is spawned.
    options: ClaudeAgentOptions,
    /// Sender for the idle-worker queue; cloned into every `WorkerGuard`.
    return_tx: mpsc::Sender<PooledWorker>,
    /// Receiver for the idle-worker queue, locked while a worker is taken out.
    return_rx: Mutex<mpsc::Receiver<PooledWorker>>,
    /// Limits the number of simultaneously checked-out workers.
    semaphore: Arc<Semaphore>,
    /// Last worker id handed out; incremented before each spawn attempt.
    next_worker_id: Mutex<usize>,
    /// Counters reported through `stats`.
    state: Mutex<PoolState>,
}
/// Mutable counters kept behind the pool's `state` lock.
struct PoolState {
    /// Number of workers successfully spawned over the pool's lifetime.
    total_created: usize,
    /// Running count of workers the pool considers live; bumped on every creation.
    active_count: usize,
}
impl ConnectionPool {
    /// Creates an empty pool; no process is spawned until [`Self::initialize`]
    /// or [`Self::acquire`] is called.
    ///
    /// `config.max_size` is the number of semaphore permits. The idle queue
    /// gets the same capacity, but never less than one: tokio's bounded
    /// channel panics when asked for zero capacity. With `max_size == 0` the
    /// pool is constructed, but every `acquire` times out because no permit
    /// ever exists.
    pub fn new(config: PoolConfig, options: ClaudeAgentOptions) -> Self {
        let queue_capacity = config.max_size.max(1);
        let (return_tx, return_rx) = mpsc::channel(queue_capacity);
        let semaphore = Arc::new(Semaphore::new(config.max_size));
        Self {
            config,
            options,
            return_tx,
            return_rx: Mutex::new(return_rx),
            semaphore,
            next_worker_id: Mutex::new(0),
            state: Mutex::new(PoolState {
                total_created: 0,
                active_count: 0,
            }),
        }
    }

    /// Pre-spawns idle workers and parks them on the return queue.
    ///
    /// At most `min(min_size, max_size)` workers are created, so a
    /// misconfigured `min_size` no longer spawns processes that cannot be
    /// queued. If the queue is already full (for example because this method
    /// was called before), pre-warming stops early.
    ///
    /// # Errors
    /// Returns the first worker-creation error; workers queued before the
    /// failure stay in the pool.
    pub async fn initialize(&self) -> Result<()> {
        let warm_target = self.config.min_size.min(self.config.max_size);
        for _ in 0..warm_target {
            let worker = self.create_worker().await?;
            if self.return_tx.try_send(worker).is_err() {
                // The rejected worker was dropped (and killed) with the send
                // error, so take it back out of the live count.
                let mut state = self.state.lock().await;
                state.active_count = state.active_count.saturating_sub(1);
                break;
            }
        }
        Ok(())
    }

    /// Allocates the next worker id, spawns the worker and updates the counters.
    ///
    /// The id is consumed even when spawning fails.
    async fn create_worker(&self) -> Result<PooledWorker> {
        let id = {
            let mut next_id = self.next_worker_id.lock().await;
            *next_id += 1;
            *next_id
        };
        let worker = PooledWorker::new(id, &self.options).await?;

        let mut state = self.state.lock().await;
        state.total_created += 1;
        state.active_count += 1;
        tracing::debug!(
            "Created pooled worker {} (total: {}, active: {})",
            id,
            state.total_created,
            state.active_count
        );
        Ok(worker)
    }

    /// Checks a worker out of the pool, spawning a new one when no usable
    /// idle worker is queued.
    ///
    /// Waits up to `ACQUIRE_TIMEOUT_SECS` for a permit. Queued workers are
    /// then examined in order until one is found that is still running,
    /// healthy and not idle for longer than `idle_timeout`. Every rejected
    /// worker is dropped and removed from `active_count`.
    ///
    /// # Errors
    /// * `ClaudeError::Connection` on permit timeout or a closed semaphore.
    /// * Any error from spawning a replacement worker.
    pub async fn acquire(&self) -> Result<WorkerGuard> {
        let permit = timeout(
            Duration::from_secs(ACQUIRE_TIMEOUT_SECS),
            Arc::clone(&self.semaphore).acquire_owned(),
        )
        .await
        .map_err(|_| {
            ClaudeError::Connection(ConnectionError::new(
                "Timeout acquiring worker from pool".to_string(),
            ))
        })?
        .map_err(|e| {
            ClaudeError::Connection(ConnectionError::new(format!(
                "Failed to acquire semaphore: {}",
                e
            )))
        })?;

        let mut discarded = 0usize;
        let reusable = {
            let mut rx = self.return_rx.lock().await;
            let mut found = None;
            // Keep draining: one stale worker at the head of the queue should
            // not force a spawn while usable workers sit behind it.
            while let Ok(mut candidate) = rx.try_recv() {
                // `Child::id()` stays `Some` until the exit has been observed,
                // so poll the exit status explicitly. A polling error is
                // treated the same as "not running".
                let still_running = matches!(candidate.process.try_wait(), Ok(None));
                if still_running
                    && candidate.is_healthy()
                    && !candidate.is_idle_timeout(self.config.idle_timeout)
                {
                    found = Some(candidate);
                    break;
                }
                tracing::debug!("Recycling unhealthy/timed-out worker {}", candidate.id);
                discarded += 1;
            }
            found
        };

        if discarded > 0 {
            let mut state = self.state.lock().await;
            state.active_count = state.active_count.saturating_sub(discarded);
        }

        let worker = match reusable {
            Some(existing) => existing,
            None => self.create_worker().await?,
        };

        Ok(WorkerGuard {
            worker: Some(worker),
            return_tx: self.return_tx.clone(),
            _permit: Some(permit),
        })
    }

    /// Returns a snapshot of the pool counters and currently free permits.
    ///
    /// `active_count` is creations minus workers discarded by `acquire` or
    /// rejected during `initialize`; a worker dropped because a guard could
    /// not return it is not subtracted.
    #[allow(dead_code)]
    pub async fn stats(&self) -> PoolStats {
        let state = self.state.lock().await;
        PoolStats {
            total_created: state.total_created,
            active_count: state.active_count,
            available_permits: self.semaphore.available_permits(),
        }
    }

    /// Reports whether pooling was enabled in the configuration.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
}
/// Point-in-time counters returned by `ConnectionPool::stats`.
#[derive(Clone, Debug)]
pub struct PoolStats {
    /// Workers successfully spawned since the pool was created.
    pub total_created: usize,
    /// The pool's running count of live workers at the time of the snapshot.
    pub active_count: usize,
    /// Semaphore permits that were free at the time of the snapshot.
    pub available_permits: usize,
}
/// Shared, lockable slot holding the process-wide pool, or `None` when unset.
type GlobalPoolSlot = Arc<Mutex<Option<Arc<ConnectionPool>>>>;

/// Lazily created storage for the process-wide pool slot.
static POOL: std::sync::OnceLock<GlobalPoolSlot> = std::sync::OnceLock::new();

/// Returns the global slot, creating it empty on first access.
fn get_pool_singleton() -> &'static GlobalPoolSlot {
    POOL.get_or_init(|| {
        let empty_slot = Mutex::new(None);
        Arc::new(empty_slot)
    })
}
/// Builds a pool from `config` and `options` and installs it as the global
/// pool, replacing any pool installed earlier.
///
/// Workers are pre-spawned only when the configuration has pooling enabled;
/// a disabled pool is still stored.
///
/// # Errors
/// Propagates a failure from `ConnectionPool::initialize`; in that case the
/// global slot is left untouched.
pub async fn init_global_pool(config: PoolConfig, options: ClaudeAgentOptions) -> Result<()> {
    let fresh = Arc::new(ConnectionPool::new(config, options));

    // Warm up before taking the global lock so other callers are not blocked
    // while processes are being spawned.
    if fresh.is_enabled() {
        fresh.initialize().await?;
    }

    let mut slot = get_pool_singleton().lock().await;
    *slot = Some(fresh);
    Ok(())
}
/// Returns a handle to the global pool, or `None` when no pool is installed.
pub async fn get_global_pool() -> Option<Arc<ConnectionPool>> {
    // Cloning the `Option<Arc<_>>` only bumps the reference count.
    get_pool_singleton().lock().await.clone()
}
/// Clears the global pool slot, releasing this module's reference to the pool.
///
/// Handles previously obtained from `get_global_pool` keep the pool alive
/// until they are dropped as well.
#[allow(dead_code)]
pub async fn shutdown_global_pool() {
    let mut slot = get_pool_singleton().lock().await;
    // The previous occupant, if any, is dropped while the lock is still held.
    drop(slot.take());
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A default configuration carries the default size bounds and is disabled.
    #[test]
    fn test_pool_config_default() {
        let PoolConfig {
            min_size,
            max_size,
            enabled,
            ..
        } = PoolConfig::default();

        assert_eq!(min_size, DEFAULT_MIN_POOL_SIZE);
        assert_eq!(max_size, DEFAULT_MAX_POOL_SIZE);
        assert!(!enabled);
    }

    /// Chained builder calls enable pooling and override both size bounds.
    #[test]
    fn test_pool_config_builder() {
        let built = PoolConfig::new().enabled().min_size(2).max_size(5);

        assert!(built.enabled);
        assert_eq!(built.min_size, 2);
        assert_eq!(built.max_size, 5);
    }
}