Skip to main content

ares/llm/
pool.rs

//! LLM Client Connection Pooling (DIR-44)
//!
//! This module provides connection pooling for LLM clients, enabling connection
//! reuse across requests to reduce latency and resource consumption.
//!
//! # Architecture
//!
//! The pool maintains a set of pre-initialized `LLMClient` instances per provider
//! configuration. Clients are checked out, used, and returned to the pool.
//!
//! # Features
//!
//! - Configurable maximum pool size per provider
//! - Connection health checking with configurable TTL
//! - Automatic stale connection cleanup
//! - Graceful shutdown with connection draining
//! - Fair distribution via round-robin or least-connections
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::llm::pool::{ClientPool, PoolConfig};
//! use ares::llm::Provider;
//!
//! let config = PoolConfig::default();
//! let pool = ClientPool::new(config);
//!
//! // Register a provider (synchronous; returns nothing)
//! pool.register_provider("openai", provider);
//!
//! // Get a pooled client
//! let guard = pool.get("openai").await?;
//! let response = guard.client().generate("Hello!").await?;
//! // Client is automatically returned to pool when guard is dropped
//! ```
36
37use crate::llm::client::{LLMClient, Provider};
38use crate::types::{AppError, Result};
39use parking_lot::{Mutex, RwLock};
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::{OwnedSemaphorePermit, Semaphore};
45
/// Configuration for the client pool.
///
/// Every limit applies per registered provider. Start from
/// [`PoolConfig::default`] and adjust individual fields with the `with_*`
/// builder methods, each of which consumes the config and returns the
/// modified value.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of clients per provider (default: 10)
    pub max_connections_per_provider: usize,

    /// Minimum number of idle clients to maintain per provider (default: 2)
    pub min_idle_connections: usize,

    /// Maximum time a client can be idle before being considered stale (default: 5 minutes)
    pub idle_timeout: Duration,

    /// Maximum lifetime of a client before forced refresh (default: 30 minutes)
    pub max_lifetime: Duration,

    /// How often to run health checks on idle connections (default: 60 seconds)
    pub health_check_interval: Duration,

    /// Timeout for acquiring a client from the pool (default: 30 seconds)
    pub acquire_timeout: Duration,

    /// Whether to enable connection health checking (default: true)
    pub enable_health_check: bool,
}

impl Default for PoolConfig {
    /// Returns the documented defaults for every field.
    fn default() -> Self {
        Self {
            max_connections_per_provider: 10,
            min_idle_connections: 2,
            idle_timeout: Duration::from_secs(300), // 5 minutes
            max_lifetime: Duration::from_secs(1800), // 30 minutes
            health_check_interval: Duration::from_secs(60),
            acquire_timeout: Duration::from_secs(30),
            enable_health_check: true,
        }
    }
}

impl PoolConfig {
    /// Set the maximum number of clients per provider.
    #[must_use]
    pub fn with_max_connections(mut self, max: usize) -> Self {
        self.max_connections_per_provider = max;
        self
    }

    /// Set the minimum number of idle clients to keep per provider.
    #[must_use]
    pub fn with_min_idle_connections(mut self, min: usize) -> Self {
        self.min_idle_connections = min;
        self
    }

    /// Set how long a client may sit idle before it counts as stale.
    #[must_use]
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Set the maximum age of a client before it counts as stale.
    #[must_use]
    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = lifetime;
        self
    }

    /// Set the period between health-check runs.
    #[must_use]
    pub fn with_health_check_interval(mut self, interval: Duration) -> Self {
        self.health_check_interval = interval;
        self
    }

    /// Set how long a caller waits for a free client before giving up.
    #[must_use]
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Disable health checking (useful for testing)
    #[must_use]
    pub fn without_health_check(mut self) -> Self {
        self.enable_health_check = false;
        self
    }
}
110
111/// Metadata for a pooled client
112#[derive(Debug)]
113struct PooledClientMeta {
114    /// When this client was created
115    created_at: Instant,
116    /// When this client was last used
117    last_used: Instant,
118    /// Number of times this client has been used
119    #[allow(dead_code)] // Used for metrics/debugging
120    use_count: AtomicU64,
121}
122
123impl PooledClientMeta {
124    fn new() -> Self {
125        let now = Instant::now();
126        Self {
127            created_at: now,
128            last_used: now,
129            use_count: AtomicU64::new(0),
130        }
131    }
132
133    fn mark_used(&mut self) {
134        self.last_used = Instant::now();
135        self.use_count.fetch_add(1, Ordering::Relaxed);
136    }
137
138    fn is_stale(&self, config: &PoolConfig) -> bool {
139        let now = Instant::now();
140        let idle_duration = now.duration_since(self.last_used);
141        let lifetime = now.duration_since(self.created_at);
142
143        idle_duration > config.idle_timeout || lifetime > config.max_lifetime
144    }
145}
146
147/// A pooled LLM client with its metadata
148struct PooledClient {
149    client: Box<dyn LLMClient>,
150    meta: PooledClientMeta,
151}
152
153impl std::fmt::Debug for PooledClient {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        f.debug_struct("PooledClient")
156            .field("meta", &self.meta)
157            .finish()
158    }
159}
160
161/// Pool of clients for a single provider
162#[derive(Debug)]
163struct ProviderPool {
164    /// The provider configuration for creating new clients
165    provider: Provider,
166    /// Pool of available clients
167    clients: Mutex<Vec<PooledClient>>,
168    /// Semaphore to limit concurrent connections
169    semaphore: Arc<Semaphore>,
170    /// Number of clients currently in use
171    in_use_count: AtomicUsize,
172    /// Total number of clients created (for stats)
173    total_created: AtomicU64,
174    /// Configuration reference
175    config: PoolConfig,
176}
177
178impl ProviderPool {
179    fn new(provider: Provider, config: PoolConfig) -> Self {
180        let semaphore = Arc::new(Semaphore::new(config.max_connections_per_provider));
181        Self {
182            provider,
183            clients: Mutex::new(Vec::with_capacity(config.max_connections_per_provider)),
184            semaphore,
185            in_use_count: AtomicUsize::new(0),
186            total_created: AtomicU64::new(0),
187            config,
188        }
189    }
190
191    /// Get an available client from the pool, or create a new one
192    async fn acquire(&self) -> Result<(Box<dyn LLMClient>, OwnedSemaphorePermit)> {
193        // Acquire a permit (blocks if pool is at capacity)
194        let permit = tokio::time::timeout(
195            self.config.acquire_timeout,
196            self.semaphore.clone().acquire_owned(),
197        )
198        .await
199        .map_err(|_| AppError::LLM("Timeout waiting for available client in pool".to_string()))?
200        .map_err(|_| AppError::LLM("Pool semaphore closed".to_string()))?;
201
202        // Try to get an existing client from the pool
203        let maybe_client = {
204            let mut clients = self.clients.lock();
205            // Find a non-stale client
206            let mut found_idx = None;
207            for (idx, pooled) in clients.iter().enumerate() {
208                if !pooled.meta.is_stale(&self.config) {
209                    found_idx = Some(idx);
210                    break;
211                }
212            }
213
214            if let Some(idx) = found_idx {
215                Some(clients.swap_remove(idx))
216            } else {
217                // Remove all stale clients
218                clients.retain(|c| !c.meta.is_stale(&self.config));
219                None
220            }
221        };
222
223        let client = if let Some(mut pooled) = maybe_client {
224            pooled.meta.mark_used();
225            pooled.client
226        } else {
227            // Create a new client
228            self.total_created.fetch_add(1, Ordering::Relaxed);
229            self.provider.create_client().await?
230        };
231
232        self.in_use_count.fetch_add(1, Ordering::Relaxed);
233        Ok((client, permit))
234    }
235
236    /// Return a client to the pool
237    fn release(&self, client: Box<dyn LLMClient>) {
238        self.in_use_count.fetch_sub(1, Ordering::Relaxed);
239
240        let mut clients = self.clients.lock();
241
242        // Only return to pool if we haven't exceeded max idle
243        if clients.len() < self.config.max_connections_per_provider {
244            clients.push(PooledClient {
245                client,
246                meta: PooledClientMeta::new(),
247            });
248        }
249        // Otherwise, client is dropped
250    }
251
252    /// Remove stale connections from the pool
253    fn cleanup_stale(&self) -> usize {
254        let mut clients = self.clients.lock();
255        let before = clients.len();
256        clients.retain(|c| !c.meta.is_stale(&self.config));
257        before - clients.len()
258    }
259
260    /// Get pool statistics
261    fn stats(&self) -> ProviderPoolStats {
262        let clients = self.clients.lock();
263        ProviderPoolStats {
264            available: clients.len(),
265            in_use: self.in_use_count.load(Ordering::Relaxed),
266            total_created: self.total_created.load(Ordering::Relaxed),
267            max_size: self.config.max_connections_per_provider,
268        }
269    }
270
271    /// Drain all connections (for shutdown)
272    fn drain(&self) {
273        let mut clients = self.clients.lock();
274        clients.clear();
275    }
276}
277
/// Statistics for a provider pool.
///
/// A point-in-time snapshot; the counters are read individually, so they are
/// not guaranteed to be mutually consistent under concurrent use.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderPoolStats {
    /// Number of available (idle) clients
    pub available: usize,
    /// Number of clients currently in use
    pub in_use: usize,
    /// Total number of clients created over the pool's lifetime
    pub total_created: u64,
    /// Maximum pool size
    pub max_size: usize,
}
290
291/// Overall pool statistics
292#[derive(Debug, Clone)]
293pub struct PoolStats {
294    /// Stats per provider
295    pub providers: HashMap<String, ProviderPoolStats>,
296    /// Total available clients across all providers
297    pub total_available: usize,
298    /// Total in-use clients across all providers
299    pub total_in_use: usize,
300}
301
302/// Guard that returns a client to the pool when dropped
303pub struct PooledClientGuard {
304    client: Option<Box<dyn LLMClient>>,
305    pool: Arc<ProviderPool>,
306    _permit: OwnedSemaphorePermit,
307}
308
309impl std::fmt::Debug for PooledClientGuard {
310    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311        f.debug_struct("PooledClientGuard")
312            .field("has_client", &self.client.is_some())
313            .field("pool", &self.pool)
314            .finish()
315    }
316}
317
318impl PooledClientGuard {
319    /// Get a reference to the underlying client
320    pub fn client(&self) -> &dyn LLMClient {
321        self.client.as_ref().expect("Client already taken").as_ref()
322    }
323
324    /// Get a mutable reference to the underlying client
325    pub fn client_mut(&mut self) -> &mut dyn LLMClient {
326        self.client.as_mut().expect("Client already taken").as_mut()
327    }
328
329    /// Take ownership of the client, preventing it from being returned to the pool
330    ///
331    /// This is useful if you need to move the client elsewhere, but be aware that
332    /// it won't be returned to the pool.
333    pub fn take(mut self) -> Box<dyn LLMClient> {
334        self.client.take().expect("Client already taken")
335    }
336}
337
338impl Drop for PooledClientGuard {
339    fn drop(&mut self) {
340        if let Some(client) = self.client.take() {
341            self.pool.release(client);
342        }
343    }
344}
345
346impl std::ops::Deref for PooledClientGuard {
347    type Target = Box<dyn LLMClient>;
348
349    fn deref(&self) -> &Self::Target {
350        self.client.as_ref().expect("Client already taken")
351    }
352}
353
354/// LLM Client Pool for managing reusable client connections
355///
356/// The pool maintains separate sub-pools for each registered provider,
357/// allowing efficient reuse of HTTP connections and client state.
358pub struct ClientPool {
359    config: PoolConfig,
360    providers: RwLock<HashMap<String, Arc<ProviderPool>>>,
361    shutdown: std::sync::atomic::AtomicBool,
362}
363
364impl ClientPool {
365    /// Create a new client pool with the given configuration
366    pub fn new(config: PoolConfig) -> Self {
367        Self {
368            config,
369            providers: RwLock::new(HashMap::new()),
370            shutdown: std::sync::atomic::AtomicBool::new(false),
371        }
372    }
373
374    /// Create a new client pool with default configuration
375    pub fn with_defaults() -> Self {
376        Self::new(PoolConfig::default())
377    }
378
379    /// Register a provider with the pool
380    ///
381    /// This creates a sub-pool for the given provider that will manage
382    /// client instances for that provider.
383    #[allow(unreachable_code, unused_variables)]
384    pub fn register_provider(&self, name: &str, provider: Provider) {
385        let pool = Arc::new(ProviderPool::new(provider, self.config.clone()));
386        let mut providers = self.providers.write();
387        providers.insert(name.to_string(), pool);
388    }
389
390    /// Check if a provider is registered
391    pub fn has_provider(&self, name: &str) -> bool {
392        self.providers.read().contains_key(name)
393    }
394
395    /// List all registered provider names
396    pub fn provider_names(&self) -> Vec<String> {
397        self.providers.read().keys().cloned().collect()
398    }
399
400    /// Get a client from the pool for the specified provider
401    ///
402    /// The returned guard will automatically return the client to the pool
403    /// when dropped.
404    pub async fn get(&self, provider_name: &str) -> Result<PooledClientGuard> {
405        if self.shutdown.load(Ordering::Relaxed) {
406            return Err(AppError::LLM("Pool is shutting down".to_string()));
407        }
408
409        let pool = {
410            let providers = self.providers.read();
411            providers.get(provider_name).cloned().ok_or_else(|| {
412                AppError::Configuration(format!(
413                    "Provider '{}' not registered in pool",
414                    provider_name
415                ))
416            })?
417        };
418
419        let (client, permit) = pool.acquire().await?;
420
421        Ok(PooledClientGuard {
422            client: Some(client),
423            pool,
424            _permit: permit,
425        })
426    }
427
428    /// Get pool statistics
429    pub fn stats(&self) -> PoolStats {
430        let providers = self.providers.read();
431        let mut stats = PoolStats {
432            providers: HashMap::new(),
433            total_available: 0,
434            total_in_use: 0,
435        };
436
437        for (name, pool) in providers.iter() {
438            let provider_stats = pool.stats();
439            stats.total_available += provider_stats.available;
440            stats.total_in_use += provider_stats.in_use;
441            stats.providers.insert(name.clone(), provider_stats);
442        }
443
444        stats
445    }
446
447    /// Clean up stale connections across all providers
448    ///
449    /// Returns the total number of connections removed.
450    pub fn cleanup_stale(&self) -> usize {
451        let providers = self.providers.read();
452        providers.values().map(|p| p.cleanup_stale()).sum()
453    }
454
455    /// Start a background task that periodically cleans up stale connections
456    ///
457    /// The task runs until the pool is shut down.
458    pub fn start_cleanup_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
459        let pool = Arc::clone(self);
460        let interval = pool.config.health_check_interval;
461
462        tokio::spawn(async move {
463            let mut interval_timer = tokio::time::interval(interval);
464            loop {
465                interval_timer.tick().await;
466
467                if pool.shutdown.load(Ordering::Relaxed) {
468                    break;
469                }
470
471                let removed = pool.cleanup_stale();
472                if removed > 0 {
473                    tracing::debug!("Pool cleanup: removed {} stale connections", removed);
474                }
475            }
476        })
477    }
478
479    /// Gracefully shut down the pool
480    ///
481    /// This prevents new clients from being acquired and drains all existing
482    /// connections.
483    pub fn shutdown(&self) {
484        self.shutdown.store(true, Ordering::Relaxed);
485
486        let providers = self.providers.read();
487        for pool in providers.values() {
488            pool.drain();
489        }
490    }
491
492    /// Check if the pool is shut down
493    pub fn is_shutdown(&self) -> bool {
494        self.shutdown.load(Ordering::Relaxed)
495    }
496}
497
498impl Default for ClientPool {
499    fn default() -> Self {
500        Self::with_defaults()
501    }
502}
503
504/// Builder for creating a `ClientPool` with registered providers
505pub struct ClientPoolBuilder {
506    config: PoolConfig,
507    providers: Vec<(String, Provider)>,
508}
509
510impl ClientPoolBuilder {
511    /// Create a new builder with default configuration
512    pub fn new() -> Self {
513        Self {
514            config: PoolConfig::default(),
515            providers: Vec::new(),
516        }
517    }
518
519    /// Set the pool configuration
520    pub fn config(mut self, config: PoolConfig) -> Self {
521        self.config = config;
522        self
523    }
524
525    /// Add a provider to the pool
526    pub fn provider(mut self, name: impl Into<String>, provider: Provider) -> Self {
527        self.providers.push((name.into(), provider));
528        self
529    }
530
531    /// Build the client pool
532    pub fn build(self) -> ClientPool {
533        let pool = ClientPool::new(self.config);
534        for (name, provider) in self.providers {
535            pool.register_provider(&name, provider);
536        }
537        pool
538    }
539
540    /// Build the client pool wrapped in an Arc
541    pub fn build_arc(self) -> Arc<ClientPool> {
542        Arc::new(self.build())
543    }
544}
545
546impl Default for ClientPoolBuilder {
547    fn default() -> Self {
548        Self::new()
549    }
550}
551
// Unit tests for the pool. Tests that need a concrete `Provider` value are
// gated on the `ollama` feature; everything else runs in every configuration.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections_per_provider, 10);
        assert_eq!(config.min_idle_connections, 2);
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.max_lifetime, Duration::from_secs(1800));
        assert!(config.enable_health_check);
    }

    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::default()
            .with_max_connections(20)
            .with_idle_timeout(Duration::from_secs(60))
            .without_health_check();

        assert_eq!(config.max_connections_per_provider, 20);
        assert_eq!(config.idle_timeout, Duration::from_secs(60));
        assert!(!config.enable_health_check);
    }

    #[test]
    fn test_pooled_client_meta_stale_detection() {
        let config = PoolConfig::default()
            .with_idle_timeout(Duration::from_millis(10))
            .with_max_lifetime(Duration::from_millis(50));

        let meta = PooledClientMeta::new();

        // Should not be stale immediately
        assert!(!meta.is_stale(&config));

        // Sleep to trigger idle timeout
        std::thread::sleep(Duration::from_millis(15));
        assert!(meta.is_stale(&config));
    }

    #[test]
    fn test_pooled_client_meta_max_lifetime() {
        let config = PoolConfig::default()
            .with_idle_timeout(Duration::from_secs(60))
            .with_max_lifetime(Duration::from_millis(10));

        let mut meta = PooledClientMeta::new();
        assert!(!meta.is_stale(&config));

        std::thread::sleep(Duration::from_millis(15));

        // A recent use resets the idle clock but must not extend the lifetime
        meta.mark_used();
        assert!(meta.is_stale(&config));
    }

    #[test]
    fn test_pool_stats() {
        let pool = ClientPool::with_defaults();
        let stats = pool.stats();

        assert_eq!(stats.total_available, 0);
        assert_eq!(stats.total_in_use, 0);
        assert!(stats.providers.is_empty());
    }

    #[test]
    fn test_cleanup_stale_without_providers() {
        let pool = ClientPool::with_defaults();
        assert_eq!(pool.cleanup_stale(), 0);
    }

    #[test]
    fn test_builder_without_providers() {
        let pool = ClientPoolBuilder::new().build();
        assert!(pool.provider_names().is_empty());
        assert!(!pool.is_shutdown());
    }

    #[test]
    fn test_pool_shutdown() {
        let pool = ClientPool::with_defaults();
        assert!(!pool.is_shutdown());

        pool.shutdown();
        assert!(pool.is_shutdown());
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_provider_registration() {
        use crate::llm::client::ModelParams;

        let pool = ClientPool::with_defaults();

        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "test".to_string(),
            params: ModelParams::default(),
        };

        pool.register_provider("ollama", provider);

        assert!(pool.has_provider("ollama"));
        assert!(!pool.has_provider("openai"));
        assert_eq!(pool.provider_names(), vec!["ollama"]);
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_builder_pattern() {
        use crate::llm::client::ModelParams;

        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "test".to_string(),
            params: ModelParams::default(),
        };

        let pool = ClientPoolBuilder::new()
            .config(PoolConfig::default().with_max_connections(5))
            .provider("ollama", provider)
            .build();

        assert!(pool.has_provider("ollama"));
    }

    // Needs no provider value, so it is not gated on a provider feature.
    #[tokio::test]
    async fn test_get_unregistered_provider_error() {
        let pool = ClientPool::with_defaults();

        let result = pool.get("nonexistent").await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, AppError::Configuration(_)));
    }

    #[tokio::test]
    async fn test_get_after_shutdown() {
        let pool = ClientPool::with_defaults();
        pool.shutdown();

        // Even if we had a provider, should fail after shutdown
        let result = pool.get("anything").await;
        assert!(result.is_err());

        if let Err(AppError::LLM(msg)) = result {
            assert!(msg.contains("shutting down"));
        } else {
            panic!("Expected LLM error");
        }
    }
}
679}