// ares/llm/pool.rs
//! LLM Client Connection Pooling (DIR-44)
//!
//! This module provides connection pooling for LLM clients, enabling connection
//! reuse across requests to reduce latency and resource consumption.
//!
//! # Architecture
//!
//! The pool maintains a set of pre-initialized `LLMClient` instances per provider
//! configuration. Clients are checked out, used, and returned to the pool.
//!
//! # Features
//!
//! - Configurable maximum pool size per provider
//! - TTL-based staleness eviction (idle timeout and maximum lifetime)
//! - Periodic background cleanup of stale connections
//! - Graceful shutdown with connection draining
//! - First-available client selection, bounded by a per-provider semaphore
//!
//! # Example
//!
//! ```rust,ignore
//! use ares::llm::pool::{ClientPool, PoolConfig};
//! use ares::llm::Provider;
//!
//! let config = PoolConfig::default();
//! let pool = ClientPool::new(config);
//!
//! // Register a provider
//! pool.register_provider("openai", provider).await?;
//!
//! // Get a pooled client
//! let guard = pool.get("openai").await?;
//! let response = guard.client().generate("Hello!").await?;
//! // Client is automatically returned to pool when guard is dropped
//! ```
37use crate::llm::client::{LLMClient, Provider};
38use crate::types::{AppError, Result};
39use parking_lot::{Mutex, RwLock};
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::{OwnedSemaphorePermit, Semaphore};
45
46/// Configuration for the client pool
/// Configuration for the client pool
///
/// All limits and timeouts apply independently to each registered provider's
/// sub-pool. Obtain a baseline via `PoolConfig::default()` and adjust with the
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of clients per provider (default: 10)
    pub max_connections_per_provider: usize,

    /// Minimum number of idle clients to maintain per provider (default: 2)
    // NOTE(review): no code in this module consults this field — nothing
    // pre-warms or replenishes idle clients. Confirm whether it is honored
    // elsewhere or is dead configuration.
    pub min_idle_connections: usize,

    /// Maximum time a client can be idle before being considered stale (default: 5 minutes)
    pub idle_timeout: Duration,

    /// Maximum lifetime of a client before forced refresh (default: 30 minutes)
    pub max_lifetime: Duration,

    /// How often to run health checks on idle connections (default: 60 seconds)
    pub health_check_interval: Duration,

    /// Timeout for acquiring a client from the pool (default: 30 seconds)
    pub acquire_timeout: Duration,

    /// Whether to enable connection health checking (default: true)
    pub enable_health_check: bool,
}
70
71impl Default for PoolConfig {
72    fn default() -> Self {
73        Self {
74            max_connections_per_provider: 10,
75            min_idle_connections: 2,
76            idle_timeout: Duration::from_secs(300),       // 5 minutes
77            max_lifetime: Duration::from_secs(1800),      // 30 minutes
78            health_check_interval: Duration::from_secs(60),
79            acquire_timeout: Duration::from_secs(30),
80            enable_health_check: true,
81        }
82    }
83}
84
85impl PoolConfig {
86    /// Create a new pool config with custom max connections
87    pub fn with_max_connections(mut self, max: usize) -> Self {
88        self.max_connections_per_provider = max;
89        self
90    }
91
92    /// Create a new pool config with custom idle timeout
93    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
94        self.idle_timeout = timeout;
95        self
96    }
97
98    /// Create a new pool config with custom max lifetime
99    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
100        self.max_lifetime = lifetime;
101        self
102    }
103
104    /// Disable health checking (useful for testing)
105    pub fn without_health_check(mut self) -> Self {
106        self.enable_health_check = false;
107        self
108    }
109}
110
111/// Metadata for a pooled client
/// Metadata for a pooled client
///
/// Tracks the timestamps used by staleness eviction plus a usage counter.
#[derive(Debug)]
struct PooledClientMeta {
    /// When this client was created
    created_at: Instant,
    /// When this client was last used (refreshed by `mark_used`)
    last_used: Instant,
    /// Number of times this client has been used
    #[allow(dead_code)] // Used for metrics/debugging
    use_count: AtomicU64,
}
122
123impl PooledClientMeta {
124    fn new() -> Self {
125        let now = Instant::now();
126        Self {
127            created_at: now,
128            last_used: now,
129            use_count: AtomicU64::new(0),
130        }
131    }
132
133    fn mark_used(&mut self) {
134        self.last_used = Instant::now();
135        self.use_count.fetch_add(1, Ordering::Relaxed);
136    }
137
138    fn is_stale(&self, config: &PoolConfig) -> bool {
139        let now = Instant::now();
140        let idle_duration = now.duration_since(self.last_used);
141        let lifetime = now.duration_since(self.created_at);
142
143        idle_duration > config.idle_timeout || lifetime > config.max_lifetime
144    }
145}
146
/// A pooled LLM client with its metadata
///
/// Pairs the boxed client trait object with the bookkeeping needed to decide
/// when it should be evicted from the idle pool.
struct PooledClient {
    client: Box<dyn LLMClient>,
    meta: PooledClientMeta,
}
152
153impl std::fmt::Debug for PooledClient {
154    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155        f.debug_struct("PooledClient").field("meta", &self.meta).finish()
156    }
157}
158
159/// Pool of clients for a single provider
/// Pool of clients for a single provider
///
/// Idle clients live in `clients`; the semaphore bounds the total number of
/// clients (idle plus checked-out) at `config.max_connections_per_provider`.
#[derive(Debug)]
struct ProviderPool {
    /// The provider configuration for creating new clients
    provider: Provider,
    /// Pool of available (idle) clients
    clients: Mutex<Vec<PooledClient>>,
    /// Semaphore to limit concurrent connections
    semaphore: Arc<Semaphore>,
    /// Number of clients currently in use (checked out via a guard)
    in_use_count: AtomicUsize,
    /// Total number of clients created (for stats)
    total_created: AtomicU64,
    /// Configuration reference
    config: PoolConfig,
}
175
impl ProviderPool {
    /// Build an empty sub-pool for `provider`, capped at
    /// `config.max_connections_per_provider` concurrent clients.
    fn new(provider: Provider, config: PoolConfig) -> Self {
        let semaphore = Arc::new(Semaphore::new(config.max_connections_per_provider));
        Self {
            provider,
            clients: Mutex::new(Vec::with_capacity(config.max_connections_per_provider)),
            semaphore,
            in_use_count: AtomicUsize::new(0),
            total_created: AtomicU64::new(0),
            config,
        }
    }

    /// Get an available client from the pool, or create a new one
    ///
    /// Acquires a semaphore permit first (bounding total concurrent clients),
    /// then reuses the first non-stale idle client, falling back to creating
    /// a fresh one via `Provider::create_client`.
    ///
    /// # Errors
    /// - `AppError::LLM` if no permit becomes available within
    ///   `config.acquire_timeout`, or if the semaphore has been closed.
    /// - Any error from `create_client` when a new client must be built.
    async fn acquire(&self) -> Result<(Box<dyn LLMClient>, OwnedSemaphorePermit)> {
        // Acquire a permit (blocks if pool is at capacity).
        // Done before taking the `clients` lock so waiters never hold it.
        let permit = tokio::time::timeout(
            self.config.acquire_timeout,
            self.semaphore.clone().acquire_owned(),
        )
        .await
        .map_err(|_| AppError::LLM("Timeout waiting for available client in pool".to_string()))?
        .map_err(|_| AppError::LLM("Pool semaphore closed".to_string()))?;

        // Try to get an existing client from the pool.
        // The lock is confined to this block so it is never held across an await.
        let maybe_client = {
            let mut clients = self.clients.lock();
            // Find a non-stale client (first match in vec order)
            let mut found_idx = None;
            for (idx, pooled) in clients.iter().enumerate() {
                if !pooled.meta.is_stale(&self.config) {
                    found_idx = Some(idx);
                    break;
                }
            }

            if let Some(idx) = found_idx {
                // swap_remove is O(1); pool ordering is not significant
                Some(clients.swap_remove(idx))
            } else {
                // Remove all stale clients
                // NOTE(review): stale entries are only purged on this path
                // (when no fresh client was found); otherwise they linger
                // until the periodic cleanup task runs.
                clients.retain(|c| !c.meta.is_stale(&self.config));
                None
            }
        };

        let client = if let Some(mut pooled) = maybe_client {
            // Reused client: metadata (including `created_at`) is discarded
            // here because only the boxed client travels with the guard.
            pooled.meta.mark_used();
            pooled.client
        } else {
            // Create a new client
            self.total_created.fetch_add(1, Ordering::Relaxed);
            // If this fails, `permit` is dropped by `?`, freeing the slot.
            self.provider.create_client().await?
        };

        self.in_use_count.fetch_add(1, Ordering::Relaxed);
        Ok((client, permit))
    }

    /// Return a client to the pool
    // NOTE(review): the returned client gets brand-new metadata, so its
    // `created_at` resets on every release — `max_lifetime` therefore only
    // bounds a single idle period, not the client's total age. Fixing this
    // would require threading `PooledClientMeta` through the guard; confirm
    // whether the current semantics are intended.
    fn release(&self, client: Box<dyn LLMClient>) {
        self.in_use_count.fetch_sub(1, Ordering::Relaxed);

        let mut clients = self.clients.lock();

        // Only return to pool if we haven't exceeded max idle
        // (cap is the same value as the semaphore's total-connection limit)
        if clients.len() < self.config.max_connections_per_provider {
            clients.push(PooledClient {
                client,
                meta: PooledClientMeta::new(),
            });
        }
        // Otherwise, client is dropped
    }

    /// Remove stale connections from the pool
    ///
    /// Returns how many idle clients were evicted.
    fn cleanup_stale(&self) -> usize {
        let mut clients = self.clients.lock();
        let before = clients.len();
        clients.retain(|c| !c.meta.is_stale(&self.config));
        before - clients.len()
    }

    /// Get pool statistics
    ///
    /// Counters are read with relaxed ordering, so the snapshot is
    /// approximate under concurrent checkout/release.
    fn stats(&self) -> ProviderPoolStats {
        let clients = self.clients.lock();
        ProviderPoolStats {
            available: clients.len(),
            in_use: self.in_use_count.load(Ordering::Relaxed),
            total_created: self.total_created.load(Ordering::Relaxed),
            max_size: self.config.max_connections_per_provider,
        }
    }

    /// Drain all connections (for shutdown)
    ///
    /// Drops every idle client; checked-out clients are dropped (not
    /// re-pooled into a live set) when their guards release them.
    fn drain(&self) {
        let mut clients = self.clients.lock();
        clients.clear();
    }
}
275
276/// Statistics for a provider pool
/// Statistics for a provider pool
///
/// A point-in-time snapshot; values may be slightly out of date under
/// concurrent use since counters are read with relaxed ordering.
#[derive(Debug, Clone)]
pub struct ProviderPoolStats {
    /// Number of available (idle) clients
    pub available: usize,
    /// Number of clients currently in use
    pub in_use: usize,
    /// Total number of clients created over the pool's lifetime
    pub total_created: u64,
    /// Maximum pool size
    pub max_size: usize,
}
288
289/// Overall pool statistics
/// Overall pool statistics
///
/// Aggregates per-provider snapshots; totals are the sums of the
/// corresponding per-provider fields.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Stats per provider
    pub providers: HashMap<String, ProviderPoolStats>,
    /// Total available clients across all providers
    pub total_available: usize,
    /// Total in-use clients across all providers
    pub total_in_use: usize,
}
299
300/// Guard that returns a client to the pool when dropped
/// Guard that returns a client to the pool when dropped
///
/// Holds the semaphore permit for the checkout; dropping the guard both
/// releases the client back to its sub-pool and frees the permit.
pub struct PooledClientGuard {
    // None once `take` has moved the client out (or after Drop ran)
    client: Option<Box<dyn LLMClient>>,
    pool: Arc<ProviderPool>,
    // Held only for its Drop; releasing it frees a pool slot
    _permit: OwnedSemaphorePermit,
}
306
307impl std::fmt::Debug for PooledClientGuard {
308    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309        f.debug_struct("PooledClientGuard")
310            .field("has_client", &self.client.is_some())
311            .field("pool", &self.pool)
312            .finish()
313    }
314}
315
316impl PooledClientGuard {
317    /// Get a reference to the underlying client
318    pub fn client(&self) -> &dyn LLMClient {
319        self.client.as_ref().expect("Client already taken").as_ref()
320    }
321
322    /// Get a mutable reference to the underlying client
323    pub fn client_mut(&mut self) -> &mut dyn LLMClient {
324        self.client.as_mut().expect("Client already taken").as_mut()
325    }
326
327    /// Take ownership of the client, preventing it from being returned to the pool
328    ///
329    /// This is useful if you need to move the client elsewhere, but be aware that
330    /// it won't be returned to the pool.
331    pub fn take(mut self) -> Box<dyn LLMClient> {
332        self.client.take().expect("Client already taken")
333    }
334}
335
336impl Drop for PooledClientGuard {
337    fn drop(&mut self) {
338        if let Some(client) = self.client.take() {
339            self.pool.release(client);
340        }
341    }
342}
343
344impl std::ops::Deref for PooledClientGuard {
345    type Target = Box<dyn LLMClient>;
346
347    fn deref(&self) -> &Self::Target {
348        self.client.as_ref().expect("Client already taken")
349    }
350}
351
352/// LLM Client Pool for managing reusable client connections
353///
354/// The pool maintains separate sub-pools for each registered provider,
355/// allowing efficient reuse of HTTP connections and client state.
/// LLM Client Pool for managing reusable client connections
///
/// The pool maintains separate sub-pools for each registered provider,
/// allowing efficient reuse of HTTP connections and client state.
pub struct ClientPool {
    // Cloned into each sub-pool at registration time
    config: PoolConfig,
    // Per-provider sub-pools, keyed by registration name
    providers: RwLock<HashMap<String, Arc<ProviderPool>>>,
    // Once set, `get` refuses new checkouts
    shutdown: std::sync::atomic::AtomicBool,
}
361
362impl ClientPool {
363    /// Create a new client pool with the given configuration
364    pub fn new(config: PoolConfig) -> Self {
365        Self {
366            config,
367            providers: RwLock::new(HashMap::new()),
368            shutdown: std::sync::atomic::AtomicBool::new(false),
369        }
370    }
371
372    /// Create a new client pool with default configuration
373    pub fn with_defaults() -> Self {
374        Self::new(PoolConfig::default())
375    }
376
377    /// Register a provider with the pool
378    ///
379    /// This creates a sub-pool for the given provider that will manage
380    /// client instances for that provider.
381    #[allow(unreachable_code, unused_variables)]
382    pub fn register_provider(&self, name: &str, provider: Provider) {
383        let pool = Arc::new(ProviderPool::new(provider, self.config.clone()));
384        let mut providers = self.providers.write();
385        providers.insert(name.to_string(), pool);
386    }
387
388    /// Check if a provider is registered
389    pub fn has_provider(&self, name: &str) -> bool {
390        self.providers.read().contains_key(name)
391    }
392
393    /// List all registered provider names
394    pub fn provider_names(&self) -> Vec<String> {
395        self.providers.read().keys().cloned().collect()
396    }
397
398    /// Get a client from the pool for the specified provider
399    ///
400    /// The returned guard will automatically return the client to the pool
401    /// when dropped.
402    pub async fn get(&self, provider_name: &str) -> Result<PooledClientGuard> {
403        if self.shutdown.load(Ordering::Relaxed) {
404            return Err(AppError::LLM("Pool is shutting down".to_string()));
405        }
406
407        let pool = {
408            let providers = self.providers.read();
409            providers.get(provider_name).cloned().ok_or_else(|| {
410                AppError::Configuration(format!("Provider '{}' not registered in pool", provider_name))
411            })?
412        };
413
414        let (client, permit) = pool.acquire().await?;
415
416        Ok(PooledClientGuard {
417            client: Some(client),
418            pool,
419            _permit: permit,
420        })
421    }
422
423    /// Get pool statistics
424    pub fn stats(&self) -> PoolStats {
425        let providers = self.providers.read();
426        let mut stats = PoolStats {
427            providers: HashMap::new(),
428            total_available: 0,
429            total_in_use: 0,
430        };
431
432        for (name, pool) in providers.iter() {
433            let provider_stats = pool.stats();
434            stats.total_available += provider_stats.available;
435            stats.total_in_use += provider_stats.in_use;
436            stats.providers.insert(name.clone(), provider_stats);
437        }
438
439        stats
440    }
441
442    /// Clean up stale connections across all providers
443    ///
444    /// Returns the total number of connections removed.
445    pub fn cleanup_stale(&self) -> usize {
446        let providers = self.providers.read();
447        providers.values().map(|p| p.cleanup_stale()).sum()
448    }
449
450    /// Start a background task that periodically cleans up stale connections
451    ///
452    /// The task runs until the pool is shut down.
453    pub fn start_cleanup_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
454        let pool = Arc::clone(self);
455        let interval = pool.config.health_check_interval;
456
457        tokio::spawn(async move {
458            let mut interval_timer = tokio::time::interval(interval);
459            loop {
460                interval_timer.tick().await;
461
462                if pool.shutdown.load(Ordering::Relaxed) {
463                    break;
464                }
465
466                let removed = pool.cleanup_stale();
467                if removed > 0 {
468                    tracing::debug!("Pool cleanup: removed {} stale connections", removed);
469                }
470            }
471        })
472    }
473
474    /// Gracefully shut down the pool
475    ///
476    /// This prevents new clients from being acquired and drains all existing
477    /// connections.
478    pub fn shutdown(&self) {
479        self.shutdown.store(true, Ordering::Relaxed);
480
481        let providers = self.providers.read();
482        for pool in providers.values() {
483            pool.drain();
484        }
485    }
486
487    /// Check if the pool is shut down
488    pub fn is_shutdown(&self) -> bool {
489        self.shutdown.load(Ordering::Relaxed)
490    }
491}
492
493impl Default for ClientPool {
494    fn default() -> Self {
495        Self::with_defaults()
496    }
497}
498
499/// Builder for creating a `ClientPool` with registered providers
/// Builder for creating a `ClientPool` with registered providers
pub struct ClientPoolBuilder {
    // Pool configuration applied at build time
    config: PoolConfig,
    // (name, provider) pairs registered when `build` is called
    providers: Vec<(String, Provider)>,
}
504
505impl ClientPoolBuilder {
506    /// Create a new builder with default configuration
507    pub fn new() -> Self {
508        Self {
509            config: PoolConfig::default(),
510            providers: Vec::new(),
511        }
512    }
513
514    /// Set the pool configuration
515    pub fn config(mut self, config: PoolConfig) -> Self {
516        self.config = config;
517        self
518    }
519
520    /// Add a provider to the pool
521    pub fn provider(mut self, name: impl Into<String>, provider: Provider) -> Self {
522        self.providers.push((name.into(), provider));
523        self
524    }
525
526    /// Build the client pool
527    pub fn build(self) -> ClientPool {
528        let pool = ClientPool::new(self.config);
529        for (name, provider) in self.providers {
530            pool.register_provider(&name, provider);
531        }
532        pool
533    }
534
535    /// Build the client pool wrapped in an Arc
536    pub fn build_arc(self) -> Arc<ClientPool> {
537        Arc::new(self.build())
538    }
539}
540
541impl Default for ClientPoolBuilder {
542    fn default() -> Self {
543        Self::new()
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections_per_provider, 10);
        assert_eq!(config.min_idle_connections, 2);
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.max_lifetime, Duration::from_secs(1800));
        assert!(config.enable_health_check);
    }

    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::default()
            .with_max_connections(20)
            .with_idle_timeout(Duration::from_secs(60))
            .without_health_check();

        assert_eq!(config.max_connections_per_provider, 20);
        assert_eq!(config.idle_timeout, Duration::from_secs(60));
        assert!(!config.enable_health_check);
    }

    #[test]
    fn test_pooled_client_meta_stale_detection() {
        let config = PoolConfig::default()
            .with_idle_timeout(Duration::from_millis(10))
            .with_max_lifetime(Duration::from_millis(50));

        let meta = PooledClientMeta::new();

        // Should not be stale immediately
        assert!(!meta.is_stale(&config));

        // Sleep to trigger idle timeout
        std::thread::sleep(Duration::from_millis(15));
        assert!(meta.is_stale(&config));
    }

    #[test]
    fn test_pool_stats() {
        let pool = ClientPool::with_defaults();
        let stats = pool.stats();

        assert_eq!(stats.total_available, 0);
        assert_eq!(stats.total_in_use, 0);
        assert!(stats.providers.is_empty());
    }

    #[test]
    fn test_pool_shutdown() {
        let pool = ClientPool::with_defaults();
        assert!(!pool.is_shutdown());

        pool.shutdown();
        assert!(pool.is_shutdown());
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_provider_registration() {
        use crate::llm::client::ModelParams;

        let pool = ClientPool::with_defaults();

        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "test".to_string(),
            params: ModelParams::default(),
        };

        pool.register_provider("ollama", provider);

        assert!(pool.has_provider("ollama"));
        assert!(!pool.has_provider("openai"));
        assert_eq!(pool.provider_names(), vec!["ollama"]);
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_builder_pattern() {
        use crate::llm::client::ModelParams;

        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "test".to_string(),
            params: ModelParams::default(),
        };

        let pool = ClientPoolBuilder::new()
            .config(PoolConfig::default().with_max_connections(5))
            .provider("ollama", provider)
            .build();

        assert!(pool.has_provider("ollama"));
    }

    // Previously gated on `feature = "ollama"`, but it uses no
    // Ollama-specific items — ungated so it always runs.
    #[tokio::test]
    async fn test_get_unregistered_provider_error() {
        let pool = ClientPool::with_defaults();

        let result = pool.get("nonexistent").await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, AppError::Configuration(_)));
    }

    #[tokio::test]
    async fn test_get_after_shutdown() {
        let pool = ClientPool::with_defaults();
        pool.shutdown();

        // Even if we had a provider, should fail after shutdown
        let result = pool.get("anything").await;
        assert!(result.is_err());

        if let Err(AppError::LLM(msg)) = result {
            assert!(msg.contains("shutting down"));
        } else {
            panic!("Expected LLM error");
        }
    }
}