1use crate::llm::client::{LLMClient, Provider};
38use crate::types::{AppError, Result};
39use parking_lot::{Mutex, RwLock};
40use std::collections::HashMap;
41use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
42use std::sync::Arc;
43use std::time::{Duration, Instant};
44use tokio::sync::{OwnedSemaphorePermit, Semaphore};
45
/// Limits and timeouts applied to every per-provider pool.
///
/// One copy of this configuration is cloned into each provider pool when the
/// provider is registered.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Maximum number of clients that may be checked out at once for a single
    /// provider; also bounds how many idle clients are retained.
    pub max_connections_per_provider: usize,

    /// Desired number of idle clients to keep around.
    /// NOTE(review): not read anywhere in the visible part of this file — confirm usage.
    pub min_idle_connections: usize,

    /// An idle client unused for longer than this is considered stale.
    pub idle_timeout: Duration,

    /// A pooled entry older than this is considered stale.
    pub max_lifetime: Duration,

    /// Period of the background cleanup task.
    pub health_check_interval: Duration,

    /// How long `acquire` waits for a free slot before giving up.
    pub acquire_timeout: Duration,

    /// Health-check toggle.
    /// NOTE(review): not read anywhere in the visible part of this file — confirm usage.
    pub enable_health_check: bool,
}
70
71impl Default for PoolConfig {
72 fn default() -> Self {
73 Self {
74 max_connections_per_provider: 10,
75 min_idle_connections: 2,
76 idle_timeout: Duration::from_secs(300), max_lifetime: Duration::from_secs(1800), health_check_interval: Duration::from_secs(60),
79 acquire_timeout: Duration::from_secs(30),
80 enable_health_check: true,
81 }
82 }
83}
84
85impl PoolConfig {
86 pub fn with_max_connections(mut self, max: usize) -> Self {
88 self.max_connections_per_provider = max;
89 self
90 }
91
92 pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
94 self.idle_timeout = timeout;
95 self
96 }
97
98 pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
100 self.max_lifetime = lifetime;
101 self
102 }
103
104 pub fn without_health_check(mut self) -> Self {
106 self.enable_health_check = false;
107 self
108 }
109}
110
/// Bookkeeping attached to a client while it sits in a pool.
#[derive(Debug)]
struct PooledClientMeta {
    /// When this metadata record was created.
    created_at: Instant,
    /// When the entry was last marked as used (initially equal to `created_at`).
    last_used: Instant,
    /// Number of times the entry was marked as used. Only written, never read
    /// outside of `Debug` output, hence the lint exemption.
    #[allow(dead_code)]
    use_count: AtomicU64,
}
122
123impl PooledClientMeta {
124 fn new() -> Self {
125 let now = Instant::now();
126 Self {
127 created_at: now,
128 last_used: now,
129 use_count: AtomicU64::new(0),
130 }
131 }
132
133 fn mark_used(&mut self) {
134 self.last_used = Instant::now();
135 self.use_count.fetch_add(1, Ordering::Relaxed);
136 }
137
138 fn is_stale(&self, config: &PoolConfig) -> bool {
139 let now = Instant::now();
140 let idle_duration = now.duration_since(self.last_used);
141 let lifetime = now.duration_since(self.created_at);
142
143 idle_duration > config.idle_timeout || lifetime > config.max_lifetime
144 }
145}
146
147struct PooledClient {
149 client: Box<dyn LLMClient>,
150 meta: PooledClientMeta,
151}
152
153impl std::fmt::Debug for PooledClient {
154 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
155 f.debug_struct("PooledClient")
156 .field("meta", &self.meta)
157 .finish()
158 }
159}
160
161#[derive(Debug)]
163struct ProviderPool {
164 provider: Provider,
166 clients: Mutex<Vec<PooledClient>>,
168 semaphore: Arc<Semaphore>,
170 in_use_count: AtomicUsize,
172 total_created: AtomicU64,
174 config: PoolConfig,
176}
177
178impl ProviderPool {
179 fn new(provider: Provider, config: PoolConfig) -> Self {
180 let semaphore = Arc::new(Semaphore::new(config.max_connections_per_provider));
181 Self {
182 provider,
183 clients: Mutex::new(Vec::with_capacity(config.max_connections_per_provider)),
184 semaphore,
185 in_use_count: AtomicUsize::new(0),
186 total_created: AtomicU64::new(0),
187 config,
188 }
189 }
190
191 async fn acquire(&self) -> Result<(Box<dyn LLMClient>, OwnedSemaphorePermit)> {
193 let permit = tokio::time::timeout(
195 self.config.acquire_timeout,
196 self.semaphore.clone().acquire_owned(),
197 )
198 .await
199 .map_err(|_| AppError::LLM("Timeout waiting for available client in pool".to_string()))?
200 .map_err(|_| AppError::LLM("Pool semaphore closed".to_string()))?;
201
202 let maybe_client = {
204 let mut clients = self.clients.lock();
205 let mut found_idx = None;
207 for (idx, pooled) in clients.iter().enumerate() {
208 if !pooled.meta.is_stale(&self.config) {
209 found_idx = Some(idx);
210 break;
211 }
212 }
213
214 if let Some(idx) = found_idx {
215 Some(clients.swap_remove(idx))
216 } else {
217 clients.retain(|c| !c.meta.is_stale(&self.config));
219 None
220 }
221 };
222
223 let client = if let Some(mut pooled) = maybe_client {
224 pooled.meta.mark_used();
225 pooled.client
226 } else {
227 self.total_created.fetch_add(1, Ordering::Relaxed);
229 self.provider.create_client().await?
230 };
231
232 self.in_use_count.fetch_add(1, Ordering::Relaxed);
233 Ok((client, permit))
234 }
235
236 fn release(&self, client: Box<dyn LLMClient>) {
238 self.in_use_count.fetch_sub(1, Ordering::Relaxed);
239
240 let mut clients = self.clients.lock();
241
242 if clients.len() < self.config.max_connections_per_provider {
244 clients.push(PooledClient {
245 client,
246 meta: PooledClientMeta::new(),
247 });
248 }
249 }
251
252 fn cleanup_stale(&self) -> usize {
254 let mut clients = self.clients.lock();
255 let before = clients.len();
256 clients.retain(|c| !c.meta.is_stale(&self.config));
257 before - clients.len()
258 }
259
260 fn stats(&self) -> ProviderPoolStats {
262 let clients = self.clients.lock();
263 ProviderPoolStats {
264 available: clients.len(),
265 in_use: self.in_use_count.load(Ordering::Relaxed),
266 total_created: self.total_created.load(Ordering::Relaxed),
267 max_size: self.config.max_connections_per_provider,
268 }
269 }
270
271 fn drain(&self) {
273 let mut clients = self.clients.lock();
274 clients.clear();
275 }
276}
277
/// Point-in-time counters for a single provider pool.
#[derive(Debug, Clone)]
pub struct ProviderPoolStats {
    /// Idle clients currently held by the pool.
    pub available: usize,
    /// Clients currently checked out.
    pub in_use: usize,
    /// Value of the pool's creation counter.
    pub total_created: u64,
    /// Configured `max_connections_per_provider`.
    pub max_size: usize,
}
290
291#[derive(Debug, Clone)]
293pub struct PoolStats {
294 pub providers: HashMap<String, ProviderPoolStats>,
296 pub total_available: usize,
298 pub total_in_use: usize,
300}
301
302pub struct PooledClientGuard {
304 client: Option<Box<dyn LLMClient>>,
305 pool: Arc<ProviderPool>,
306 _permit: OwnedSemaphorePermit,
307}
308
309impl std::fmt::Debug for PooledClientGuard {
310 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
311 f.debug_struct("PooledClientGuard")
312 .field("has_client", &self.client.is_some())
313 .field("pool", &self.pool)
314 .finish()
315 }
316}
317
318impl PooledClientGuard {
319 pub fn client(&self) -> &dyn LLMClient {
321 self.client.as_ref().expect("Client already taken").as_ref()
322 }
323
324 pub fn client_mut(&mut self) -> &mut dyn LLMClient {
326 self.client.as_mut().expect("Client already taken").as_mut()
327 }
328
329 pub fn take(mut self) -> Box<dyn LLMClient> {
334 self.client.take().expect("Client already taken")
335 }
336}
337
338impl Drop for PooledClientGuard {
339 fn drop(&mut self) {
340 if let Some(client) = self.client.take() {
341 self.pool.release(client);
342 }
343 }
344}
345
346impl std::ops::Deref for PooledClientGuard {
347 type Target = Box<dyn LLMClient>;
348
349 fn deref(&self) -> &Self::Target {
350 self.client.as_ref().expect("Client already taken")
351 }
352}
353
354pub struct ClientPool {
359 config: PoolConfig,
360 providers: RwLock<HashMap<String, Arc<ProviderPool>>>,
361 shutdown: std::sync::atomic::AtomicBool,
362}
363
364impl ClientPool {
365 pub fn new(config: PoolConfig) -> Self {
367 Self {
368 config,
369 providers: RwLock::new(HashMap::new()),
370 shutdown: std::sync::atomic::AtomicBool::new(false),
371 }
372 }
373
374 pub fn with_defaults() -> Self {
376 Self::new(PoolConfig::default())
377 }
378
379 #[allow(unreachable_code, unused_variables)]
384 pub fn register_provider(&self, name: &str, provider: Provider) {
385 let pool = Arc::new(ProviderPool::new(provider, self.config.clone()));
386 let mut providers = self.providers.write();
387 providers.insert(name.to_string(), pool);
388 }
389
390 pub fn has_provider(&self, name: &str) -> bool {
392 self.providers.read().contains_key(name)
393 }
394
395 pub fn provider_names(&self) -> Vec<String> {
397 self.providers.read().keys().cloned().collect()
398 }
399
400 pub async fn get(&self, provider_name: &str) -> Result<PooledClientGuard> {
405 if self.shutdown.load(Ordering::Relaxed) {
406 return Err(AppError::LLM("Pool is shutting down".to_string()));
407 }
408
409 let pool = {
410 let providers = self.providers.read();
411 providers.get(provider_name).cloned().ok_or_else(|| {
412 AppError::Configuration(format!(
413 "Provider '{}' not registered in pool",
414 provider_name
415 ))
416 })?
417 };
418
419 let (client, permit) = pool.acquire().await?;
420
421 Ok(PooledClientGuard {
422 client: Some(client),
423 pool,
424 _permit: permit,
425 })
426 }
427
428 pub fn stats(&self) -> PoolStats {
430 let providers = self.providers.read();
431 let mut stats = PoolStats {
432 providers: HashMap::new(),
433 total_available: 0,
434 total_in_use: 0,
435 };
436
437 for (name, pool) in providers.iter() {
438 let provider_stats = pool.stats();
439 stats.total_available += provider_stats.available;
440 stats.total_in_use += provider_stats.in_use;
441 stats.providers.insert(name.clone(), provider_stats);
442 }
443
444 stats
445 }
446
447 pub fn cleanup_stale(&self) -> usize {
451 let providers = self.providers.read();
452 providers.values().map(|p| p.cleanup_stale()).sum()
453 }
454
455 pub fn start_cleanup_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
459 let pool = Arc::clone(self);
460 let interval = pool.config.health_check_interval;
461
462 tokio::spawn(async move {
463 let mut interval_timer = tokio::time::interval(interval);
464 loop {
465 interval_timer.tick().await;
466
467 if pool.shutdown.load(Ordering::Relaxed) {
468 break;
469 }
470
471 let removed = pool.cleanup_stale();
472 if removed > 0 {
473 tracing::debug!("Pool cleanup: removed {} stale connections", removed);
474 }
475 }
476 })
477 }
478
479 pub fn shutdown(&self) {
484 self.shutdown.store(true, Ordering::Relaxed);
485
486 let providers = self.providers.read();
487 for pool in providers.values() {
488 pool.drain();
489 }
490 }
491
492 pub fn is_shutdown(&self) -> bool {
494 self.shutdown.load(Ordering::Relaxed)
495 }
496}
497
498impl Default for ClientPool {
499 fn default() -> Self {
500 Self::with_defaults()
501 }
502}
503
504pub struct ClientPoolBuilder {
506 config: PoolConfig,
507 providers: Vec<(String, Provider)>,
508}
509
510impl ClientPoolBuilder {
511 pub fn new() -> Self {
513 Self {
514 config: PoolConfig::default(),
515 providers: Vec::new(),
516 }
517 }
518
519 pub fn config(mut self, config: PoolConfig) -> Self {
521 self.config = config;
522 self
523 }
524
525 pub fn provider(mut self, name: impl Into<String>, provider: Provider) -> Self {
527 self.providers.push((name.into(), provider));
528 self
529 }
530
531 pub fn build(self) -> ClientPool {
533 let pool = ClientPool::new(self.config);
534 for (name, provider) in self.providers {
535 pool.register_provider(&name, provider);
536 }
537 pool
538 }
539
540 pub fn build_arc(self) -> Arc<ClientPool> {
542 Arc::new(self.build())
543 }
544}
545
546impl Default for ClientPoolBuilder {
547 fn default() -> Self {
548 Self::new()
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented default values.
    #[test]
    fn test_pool_config_defaults() {
        let PoolConfig {
            max_connections_per_provider,
            min_idle_connections,
            idle_timeout,
            max_lifetime,
            enable_health_check,
            ..
        } = PoolConfig::default();

        assert_eq!(max_connections_per_provider, 10);
        assert_eq!(min_idle_connections, 2);
        assert_eq!(idle_timeout, Duration::from_secs(300));
        assert_eq!(max_lifetime, Duration::from_secs(1800));
        assert!(enable_health_check);
    }

    /// Builder-style setters override only the targeted fields.
    #[test]
    fn test_pool_config_builder() {
        let base = PoolConfig::default();
        let tuned = base
            .with_max_connections(20)
            .with_idle_timeout(Duration::from_secs(60))
            .without_health_check();

        assert_eq!(tuned.max_connections_per_provider, 20);
        assert_eq!(tuned.idle_timeout, Duration::from_secs(60));
        assert!(!tuned.enable_health_check);
    }

    /// A fresh entry is not stale; it becomes stale once the idle timeout elapses.
    #[test]
    fn test_pooled_client_meta_stale_detection() {
        let idle_limit = Duration::from_millis(10);
        let lifetime_limit = Duration::from_millis(50);
        let config = PoolConfig::default()
            .with_idle_timeout(idle_limit)
            .with_max_lifetime(lifetime_limit);

        let meta = PooledClientMeta::new();
        assert!(!meta.is_stale(&config));

        // Sleep past the idle timeout but short of the maximum lifetime.
        std::thread::sleep(Duration::from_millis(15));
        assert!(meta.is_stale(&config));
    }

    /// An empty pool reports zeroed statistics.
    #[test]
    fn test_pool_stats() {
        let stats = ClientPool::with_defaults().stats();

        assert_eq!(stats.total_available, 0);
        assert_eq!(stats.total_in_use, 0);
        assert!(stats.providers.is_empty());
    }

    /// `shutdown` flips the shutdown flag.
    #[test]
    fn test_pool_shutdown() {
        let pool = ClientPool::with_defaults();
        assert!(!pool.is_shutdown());

        pool.shutdown();
        assert!(pool.is_shutdown());
    }

    /// A registered provider is visible through the lookup helpers.
    #[cfg(feature = "ollama")]
    #[test]
    fn test_provider_registration() {
        use crate::llm::client::ModelParams;

        let pool = ClientPool::with_defaults();
        pool.register_provider(
            "ollama",
            Provider::Ollama {
                base_url: "http://localhost:11434".to_string(),
                model: "test".to_string(),
                params: ModelParams::default(),
            },
        );

        assert!(pool.has_provider("ollama"));
        assert!(!pool.has_provider("openai"));
        assert_eq!(pool.provider_names(), vec!["ollama"]);
    }

    /// The builder registers the providers it was given.
    #[cfg(feature = "ollama")]
    #[test]
    fn test_builder_pattern() {
        use crate::llm::client::ModelParams;

        let config = PoolConfig::default().with_max_connections(5);
        let provider = Provider::Ollama {
            base_url: "http://localhost:11434".to_string(),
            model: "test".to_string(),
            params: ModelParams::default(),
        };

        let pool = ClientPoolBuilder::new()
            .config(config)
            .provider("ollama", provider)
            .build();

        assert!(pool.has_provider("ollama"));
    }

    /// Requesting an unknown provider yields a configuration error.
    #[cfg(feature = "ollama")]
    #[tokio::test]
    async fn test_get_unregistered_provider_error() {
        let pool = ClientPool::with_defaults();

        let outcome = pool.get("nonexistent").await;
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        assert!(matches!(err, AppError::Configuration(_)));
    }

    /// After shutdown, `get` fails with an LLM error mentioning the shutdown.
    #[tokio::test]
    async fn test_get_after_shutdown() {
        let pool = ClientPool::with_defaults();
        pool.shutdown();

        let outcome = pool.get("anything").await;
        assert!(outcome.is_err());

        match outcome {
            Err(AppError::LLM(msg)) => assert!(msg.contains("shutting down")),
            _ => panic!("Expected LLM error"),
        }
    }
}