1use crate::llm::client::{LLMClient, Provider};
38use crate::types::{AppError, Result};
39use parking_lot::{Mutex, RwLock};
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::{OwnedSemaphorePermit, Semaphore};
45
/// Tunables shared by every per-provider pool inside a `ClientPool`.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Size of each provider's checkout semaphore; also caps how many idle
    /// clients a provider pool keeps around.
    pub max_connections_per_provider: usize,

    /// Desired minimum number of idle clients.
    /// NOTE(review): not consulted by the pooling logic in this module.
    pub min_idle_connections: usize,

    /// An idle client unused for longer than this is considered stale.
    pub idle_timeout: Duration,

    /// A client older than this (measured from creation) is considered stale.
    pub max_lifetime: Duration,

    /// Period of the background cleanup task started by
    /// `ClientPool::start_cleanup_task`.
    pub health_check_interval: Duration,

    /// How long `ClientPool::get` waits for a free slot before failing.
    pub acquire_timeout: Duration,

    /// Health-check toggle.
    /// NOTE(review): not consulted by the pooling logic in this module.
    pub enable_health_check: bool,
}
70
71impl Default for PoolConfig {
72 fn default() -> Self {
73 Self {
74 max_connections_per_provider: 10,
75 min_idle_connections: 2,
76 idle_timeout: Duration::from_secs(300), max_lifetime: Duration::from_secs(1800), health_check_interval: Duration::from_secs(60),
79 acquire_timeout: Duration::from_secs(30),
80 enable_health_check: true,
81 }
82 }
83}
84
85impl PoolConfig {
86 pub fn with_max_connections(mut self, max: usize) -> Self {
88 self.max_connections_per_provider = max;
89 self
90 }
91
92 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
94 self.idle_timeout = timeout;
95 self
96 }
97
98 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
100 self.max_lifetime = lifetime;
101 self
102 }
103
104 pub fn without_health_check(mut self) -> Self {
106 self.enable_health_check = false;
107 self
108 }
109}
110
111#[derive(Debug)]
113struct PooledClientMeta {
114 created_at: Instant,
116 last_used: Instant,
118 #[allow(dead_code)] use_count: AtomicU64,
121}
122
123impl PooledClientMeta {
124 fn new() -> Self {
125 let now = Instant::now();
126 Self {
127 created_at: now,
128 last_used: now,
129 use_count: AtomicU64::new(0),
130 }
131 }
132
133 fn mark_used(&mut self) {
134 self.last_used = Instant::now();
135 self.use_count.fetch_add(1, Ordering::Relaxed);
136 }
137
138 fn is_stale(&self, config: &PoolConfig) -> bool {
139 let now = Instant::now();
140 let idle_duration = now.duration_since(self.last_used);
141 let lifetime = now.duration_since(self.created_at);
142
143 idle_duration > config.idle_timeout || lifetime > config.max_lifetime
144 }
145}
146
147struct PooledClient {
149 client: Box<dyn LLMClient>,
150 meta: PooledClientMeta,
151}
152
153impl std::fmt::Debug for PooledClient {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 f.debug_struct("PooledClient").field("meta", &self.meta).finish()
156 }
157}
158
159#[derive(Debug)]
161struct ProviderPool {
162 provider: Provider,
164 clients: Mutex<Vec<PooledClient>>,
166 semaphore: Arc<Semaphore>,
168 in_use_count: AtomicUsize,
170 total_created: AtomicU64,
172 config: PoolConfig,
174}
175
176impl ProviderPool {
177 fn new(provider: Provider, config: PoolConfig) -> Self {
178 let semaphore = Arc::new(Semaphore::new(config.max_connections_per_provider));
179 Self {
180 provider,
181 clients: Mutex::new(Vec::with_capacity(config.max_connections_per_provider)),
182 semaphore,
183 in_use_count: AtomicUsize::new(0),
184 total_created: AtomicU64::new(0),
185 config,
186 }
187 }
188
189 async fn acquire(&self) -> Result<(Box<dyn LLMClient>, OwnedSemaphorePermit)> {
191 let permit = tokio::time::timeout(
193 self.config.acquire_timeout,
194 self.semaphore.clone().acquire_owned(),
195 )
196 .await
197 .map_err(|_| AppError::LLM("Timeout waiting for available client in pool".to_string()))?
198 .map_err(|_| AppError::LLM("Pool semaphore closed".to_string()))?;
199
200 let maybe_client = {
202 let mut clients = self.clients.lock();
203 let mut found_idx = None;
205 for (idx, pooled) in clients.iter().enumerate() {
206 if !pooled.meta.is_stale(&self.config) {
207 found_idx = Some(idx);
208 break;
209 }
210 }
211
212 if let Some(idx) = found_idx {
213 Some(clients.swap_remove(idx))
214 } else {
215 clients.retain(|c| !c.meta.is_stale(&self.config));
217 None
218 }
219 };
220
221 let client = if let Some(mut pooled) = maybe_client {
222 pooled.meta.mark_used();
223 pooled.client
224 } else {
225 self.total_created.fetch_add(1, Ordering::Relaxed);
227 self.provider.create_client().await?
228 };
229
230 self.in_use_count.fetch_add(1, Ordering::Relaxed);
231 Ok((client, permit))
232 }
233
234 fn release(&self, client: Box<dyn LLMClient>) {
236 self.in_use_count.fetch_sub(1, Ordering::Relaxed);
237
238 let mut clients = self.clients.lock();
239
240 if clients.len() < self.config.max_connections_per_provider {
242 clients.push(PooledClient {
243 client,
244 meta: PooledClientMeta::new(),
245 });
246 }
247 }
249
250 fn cleanup_stale(&self) -> usize {
252 let mut clients = self.clients.lock();
253 let before = clients.len();
254 clients.retain(|c| !c.meta.is_stale(&self.config));
255 before - clients.len()
256 }
257
258 fn stats(&self) -> ProviderPoolStats {
260 let clients = self.clients.lock();
261 ProviderPoolStats {
262 available: clients.len(),
263 in_use: self.in_use_count.load(Ordering::Relaxed),
264 total_created: self.total_created.load(Ordering::Relaxed),
265 max_size: self.config.max_connections_per_provider,
266 }
267 }
268
269 fn drain(&self) {
271 let mut clients = self.clients.lock();
272 clients.clear();
273 }
274}
275
276#[derive(Debug, Clone)]
278pub struct ProviderPoolStats {
279 pub available: usize,
281 pub in_use: usize,
283 pub total_created: u64,
285 pub max_size: usize,
287}
288
289#[derive(Debug, Clone)]
291pub struct PoolStats {
292 pub providers: HashMap<String, ProviderPoolStats>,
294 pub total_available: usize,
296 pub total_in_use: usize,
298}
299
300pub struct PooledClientGuard {
302 client: Option<Box<dyn LLMClient>>,
303 pool: Arc<ProviderPool>,
304 _permit: OwnedSemaphorePermit,
305}
306
307impl std::fmt::Debug for PooledClientGuard {
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 f.debug_struct("PooledClientGuard")
310 .field("has_client", &self.client.is_some())
311 .field("pool", &self.pool)
312 .finish()
313 }
314}
315
316impl PooledClientGuard {
317 pub fn client(&self) -> &dyn LLMClient {
319 self.client.as_ref().expect("Client already taken").as_ref()
320 }
321
322 pub fn client_mut(&mut self) -> &mut dyn LLMClient {
324 self.client.as_mut().expect("Client already taken").as_mut()
325 }
326
327 pub fn take(mut self) -> Box<dyn LLMClient> {
332 self.client.take().expect("Client already taken")
333 }
334}
335
336impl Drop for PooledClientGuard {
337 fn drop(&mut self) {
338 if let Some(client) = self.client.take() {
339 self.pool.release(client);
340 }
341 }
342}
343
344impl std::ops::Deref for PooledClientGuard {
345 type Target = Box<dyn LLMClient>;
346
347 fn deref(&self) -> &Self::Target {
348 self.client.as_ref().expect("Client already taken")
349 }
350}
351
352pub struct ClientPool {
357 config: PoolConfig,
358 providers: RwLock<HashMap<String, Arc<ProviderPool>>>,
359 shutdown: std::sync::atomic::AtomicBool,
360}
361
362impl ClientPool {
363 pub fn new(config: PoolConfig) -> Self {
365 Self {
366 config,
367 providers: RwLock::new(HashMap::new()),
368 shutdown: std::sync::atomic::AtomicBool::new(false),
369 }
370 }
371
372 pub fn with_defaults() -> Self {
374 Self::new(PoolConfig::default())
375 }
376
377 #[allow(unreachable_code, unused_variables)]
382 pub fn register_provider(&self, name: &str, provider: Provider) {
383 let pool = Arc::new(ProviderPool::new(provider, self.config.clone()));
384 let mut providers = self.providers.write();
385 providers.insert(name.to_string(), pool);
386 }
387
388 pub fn has_provider(&self, name: &str) -> bool {
390 self.providers.read().contains_key(name)
391 }
392
393 pub fn provider_names(&self) -> Vec<String> {
395 self.providers.read().keys().cloned().collect()
396 }
397
398 pub async fn get(&self, provider_name: &str) -> Result<PooledClientGuard> {
403 if self.shutdown.load(Ordering::Relaxed) {
404 return Err(AppError::LLM("Pool is shutting down".to_string()));
405 }
406
407 let pool = {
408 let providers = self.providers.read();
409 providers.get(provider_name).cloned().ok_or_else(|| {
410 AppError::Configuration(format!("Provider '{}' not registered in pool", provider_name))
411 })?
412 };
413
414 let (client, permit) = pool.acquire().await?;
415
416 Ok(PooledClientGuard {
417 client: Some(client),
418 pool,
419 _permit: permit,
420 })
421 }
422
423 pub fn stats(&self) -> PoolStats {
425 let providers = self.providers.read();
426 let mut stats = PoolStats {
427 providers: HashMap::new(),
428 total_available: 0,
429 total_in_use: 0,
430 };
431
432 for (name, pool) in providers.iter() {
433 let provider_stats = pool.stats();
434 stats.total_available += provider_stats.available;
435 stats.total_in_use += provider_stats.in_use;
436 stats.providers.insert(name.clone(), provider_stats);
437 }
438
439 stats
440 }
441
442 pub fn cleanup_stale(&self) -> usize {
446 let providers = self.providers.read();
447 providers.values().map(|p| p.cleanup_stale()).sum()
448 }
449
450 pub fn start_cleanup_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
454 let pool = Arc::clone(self);
455 let interval = pool.config.health_check_interval;
456
457 tokio::spawn(async move {
458 let mut interval_timer = tokio::time::interval(interval);
459 loop {
460 interval_timer.tick().await;
461
462 if pool.shutdown.load(Ordering::Relaxed) {
463 break;
464 }
465
466 let removed = pool.cleanup_stale();
467 if removed > 0 {
468 tracing::debug!("Pool cleanup: removed {} stale connections", removed);
469 }
470 }
471 })
472 }
473
474 pub fn shutdown(&self) {
479 self.shutdown.store(true, Ordering::Relaxed);
480
481 let providers = self.providers.read();
482 for pool in providers.values() {
483 pool.drain();
484 }
485 }
486
487 pub fn is_shutdown(&self) -> bool {
489 self.shutdown.load(Ordering::Relaxed)
490 }
491}
492
493impl Default for ClientPool {
494 fn default() -> Self {
495 Self::with_defaults()
496 }
497}
498
499pub struct ClientPoolBuilder {
501 config: PoolConfig,
502 providers: Vec<(String, Provider)>,
503}
504
505impl ClientPoolBuilder {
506 pub fn new() -> Self {
508 Self {
509 config: PoolConfig::default(),
510 providers: Vec::new(),
511 }
512 }
513
514 pub fn config(mut self, config: PoolConfig) -> Self {
516 self.config = config;
517 self
518 }
519
520 pub fn provider(mut self, name: impl Into<String>, provider: Provider) -> Self {
522 self.providers.push((name.into(), provider));
523 self
524 }
525
526 pub fn build(self) -> ClientPool {
528 let pool = ClientPool::new(self.config);
529 for (name, provider) in self.providers {
530 pool.register_provider(&name, provider);
531 }
532 pool
533 }
534
535 pub fn build_arc(self) -> Arc<ClientPool> {
537 Arc::new(self.build())
538 }
539}
540
541impl Default for ClientPoolBuilder {
542 fn default() -> Self {
543 Self::new()
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.max_connections_per_provider, 10);
        assert_eq!(config.min_idle_connections, 2);
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.max_lifetime, Duration::from_secs(1800));
        assert!(config.enable_health_check);
    }

    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::default()
            .with_max_connections(20)
            .with_idle_timeout(Duration::from_secs(60))
            .without_health_check();

        assert_eq!(config.max_connections_per_provider, 20);
        assert_eq!(config.idle_timeout, Duration::from_secs(60));
        assert!(!config.enable_health_check);
    }

    #[test]
    fn test_pooled_client_meta_stale_detection() {
        let config = PoolConfig::default()
            .with_idle_timeout(Duration::from_millis(10))
            .with_max_lifetime(Duration::from_millis(50));

        let meta = PooledClientMeta::new();
        // Freshly created metadata is neither idle nor old.
        assert!(!meta.is_stale(&config));

        // Past the 10 ms idle timeout, still inside the 50 ms lifetime.
        std::thread::sleep(Duration::from_millis(15));
        assert!(meta.is_stale(&config));
    }

    #[test]
    fn test_pooled_client_meta_use_does_not_extend_lifetime() {
        let config = PoolConfig::default()
            .with_idle_timeout(Duration::from_secs(3600))
            .with_max_lifetime(Duration::from_millis(10));

        let mut meta = PooledClientMeta::new();
        std::thread::sleep(Duration::from_millis(15));
        // A recent use refreshes the idle timer but not the creation time.
        meta.mark_used();
        assert!(meta.is_stale(&config));
    }

    #[test]
    fn test_pool_stats() {
        let pool = ClientPool::with_defaults();
        let stats = pool.stats();

        assert_eq!(stats.total_available, 0);
        assert_eq!(stats.total_in_use, 0);
        assert!(stats.providers.is_empty());
    }

    #[test]
    fn test_pool_shutdown() {
        let pool = ClientPool::with_defaults();
        assert!(!pool.is_shutdown());

        pool.shutdown();
        assert!(pool.is_shutdown());
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_provider_registration() {
        use crate::llm::client::ModelParams;

        let pool = ClientPool::with_defaults();

        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "test".to_string(),
            params: ModelParams::default(),
        };

        pool.register_provider("ollama", provider);

        assert!(pool.has_provider("ollama"));
        assert!(!pool.has_provider("openai"));
        assert_eq!(pool.provider_names(), vec!["ollama"]);
    }

    #[cfg(feature = "ollama")]
    #[test]
    fn test_builder_pattern() {
        use crate::llm::client::ModelParams;

        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "test".to_string(),
            params: ModelParams::default(),
        };

        let pool = ClientPoolBuilder::new()
            .config(PoolConfig::default().with_max_connections(5))
            .provider("ollama", provider)
            .build();

        assert!(pool.has_provider("ollama"));
    }

    // Builds no provider, so it needs no provider feature.
    #[tokio::test]
    async fn test_get_unregistered_provider_error() {
        let pool = ClientPool::with_defaults();

        let result = pool.get("nonexistent").await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, AppError::Configuration(_)));
    }

    #[tokio::test]
    async fn test_get_after_shutdown() {
        let pool = ClientPool::with_defaults();
        pool.shutdown();

        let result = pool.get("anything").await;
        // The shutdown check runs before the provider lookup.
        assert!(result.is_err());

        if let Err(AppError::LLM(msg)) = result {
            assert!(msg.contains("shutting down"));
        } else {
            panic!("Expected LLM error");
        }
    }
}