// network_protocol/service/pool.rs

//! # Connection Pooling
//!
//! Generic connection pool for all transport types.
//!
//! This module provides a thread-safe, async-aware connection pooling mechanism
//! that avoids expensive TLS handshakes for repeated connections to the same
//! endpoint. Essential for database and RPC scenarios where clients make many
//! short-lived requests.
//!
//! ## Features
//! - Generic over any transport type `T`
//! - Configurable pool size (min/max connections)
//! - TTL-based connection expiration
//! - Health checks with automatic eviction
//! - LIFO acquisition (the most recently returned connection is reused first);
//!   connections returned to a full pool are discarded
//! - Thread-safe async operations
17
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Semaphore};
use tracing::{debug, error, warn};

use crate::error::{ProtocolError, Result};
26
/// Configuration for connection pooling.
///
/// Start from [`PoolConfig::default`] and override individual fields with
/// struct-update syntax (`PoolConfig { max_size: 10, ..Default::default() }`).
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Minimum connections to maintain (pre-warmed)
    pub min_size: usize,
    /// Maximum connections in pool
    pub max_size: usize,
    /// Time-to-live for idle connections before eviction
    pub idle_timeout: Duration,
    /// Maximum lifetime of a connection regardless of idle time
    pub max_lifetime: Duration,
    /// Maximum concurrent waiters for connections (backpressure limit)
    pub max_waiters: usize,
    /// Circuit breaker: consecutive failures before opening
    pub circuit_breaker_threshold: usize,
    /// Circuit breaker: time to wait before trying again
    pub circuit_breaker_timeout: Duration,
}

impl Default for PoolConfig {
    /// Defaults: 5..=50 connections, 5 minute idle timeout, 1 hour lifetime,
    /// 1000 waiters, circuit breaker opening after 5 failures for 10 seconds.
    fn default() -> Self {
        Self {
            min_size: 5,
            max_size: 50,
            idle_timeout: Duration::from_secs(300), // 5 minutes
            max_lifetime: Duration::from_secs(3600), // 1 hour
            max_waiters: 1000,                      // Prevent OOM
            circuit_breaker_threshold: 5,
            circuit_breaker_timeout: Duration::from_secs(10),
        }
    }
}
59
60impl PoolConfig {
61    /// Validate configuration parameters
62    pub fn validate(&self) -> Result<()> {
63        let mut errors = Vec::new();
64
65        // Validate pool sizes
66        if self.max_size == 0 {
67            errors.push("Pool max_size must be greater than 0".to_string());
68        }
69
70        if self.min_size > self.max_size {
71            errors.push(format!(
72                "Pool min_size ({}) cannot exceed max_size ({})",
73                self.min_size, self.max_size
74            ));
75        }
76
77        // Validate reasonable limits
78        if self.max_size > 10_000 {
79            errors.push(format!(
80                "Pool max_size ({}) exceeds recommended limit (10,000)",
81                self.max_size
82            ));
83        }
84
85        if self.max_waiters == 0 {
86            errors.push("Pool max_waiters must be greater than 0".to_string());
87        }
88
89        if self.max_waiters > 1_000_000 {
90            errors.push(format!(
91                "Pool max_waiters ({}) exceeds recommended limit (1,000,000)",
92                self.max_waiters
93            ));
94        }
95
96        // Validate timeouts
97        if self.idle_timeout.is_zero() {
98            errors.push("Pool idle_timeout must be greater than 0".to_string());
99        }
100
101        if self.max_lifetime.is_zero() {
102            errors.push("Pool max_lifetime must be greater than 0".to_string());
103        }
104
105        if self.idle_timeout >= self.max_lifetime {
106            errors.push(format!(
107                "Pool idle_timeout ({:?}) should be less than max_lifetime ({:?})",
108                self.idle_timeout, self.max_lifetime
109            ));
110        }
111
112        if self.idle_timeout.as_secs() > 3600 {
113            errors.push(format!(
114                "Pool idle_timeout ({} seconds) is unusually long (recommended: < 1 hour)",
115                self.idle_timeout.as_secs()
116            ));
117        }
118
119        if self.max_lifetime.as_secs() > 86400 {
120            errors.push(format!(
121                "Pool max_lifetime ({} seconds) is unusually long (recommended: < 24 hours)",
122                self.max_lifetime.as_secs()
123            ));
124        }
125
126        // Validate circuit breaker settings
127        if self.circuit_breaker_threshold == 0 {
128            errors.push("Circuit breaker threshold must be greater than 0".to_string());
129        }
130
131        if self.circuit_breaker_threshold > 100 {
132            errors.push(format!(
133                "Circuit breaker threshold ({}) is unusually high (recommended: < 100)",
134                self.circuit_breaker_threshold
135            ));
136        }
137
138        if self.circuit_breaker_timeout.is_zero() {
139            errors.push("Circuit breaker timeout must be greater than 0".to_string());
140        }
141
142        if self.circuit_breaker_timeout.as_secs() > 300 {
143            errors.push(format!(
144                "Circuit breaker timeout ({} seconds) is unusually long (recommended: < 5 minutes)",
145                self.circuit_breaker_timeout.as_secs()
146            ));
147        }
148
149        // Return aggregated errors
150        if errors.is_empty() {
151            Ok(())
152        } else {
153            Err(ProtocolError::ConfigError(format!(
154                "Pool configuration validation failed:\n  - {}",
155                errors.join("\n  - ")
156            )))
157        }
158    }
159}
160
/// Pooled connection wrapper with metadata
struct PooledConnection<T> {
    /// The wrapped transport/connection value.
    connection: T,
    /// When this wrapper was created; compared against `max_lifetime`.
    created_at: Instant,
    /// When this wrapper was last touched; compared against `idle_timeout`.
    last_used_at: Instant,
}
167
168impl<T> PooledConnection<T> {
169    fn new(connection: T) -> Self {
170        let now = Instant::now();
171        Self {
172            connection,
173            created_at: now,
174            last_used_at: now,
175        }
176    }
177
178    fn is_expired(&self, config: &PoolConfig) -> bool {
179        let now = Instant::now();
180        // Check max lifetime exceeded
181        if now.duration_since(self.created_at) > config.max_lifetime {
182            return true;
183        }
184        // Check idle timeout exceeded
185        if now.duration_since(self.last_used_at) > config.idle_timeout {
186            return true;
187        }
188        false
189    }
190
191    fn touch(&mut self) {
192        self.last_used_at = Instant::now();
193    }
194}
195
196/// Factory trait for creating new connections
197pub trait ConnectionFactory<T>: Send + Sync {
198    /// Create a new connection asynchronously
199    fn create(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>;
200
201    /// Check if a connection is still healthy
202    fn is_healthy(&self, _conn: &T) -> bool {
203        true
204    }
205}
206
/// Pool metrics for monitoring and capacity planning
///
/// All counters are atomics so they can be updated through a shared
/// reference. Readers use relaxed loads, so values derived from two counters
/// are approximate while the pool is being used concurrently.
#[derive(Debug, Default)]
pub struct PoolMetrics {
    /// Total connections created
    pub connections_created: AtomicU64,
    /// Total connections reused from pool
    pub connections_reused: AtomicU64,
    /// Total connections evicted (expired/unhealthy)
    pub connections_evicted: AtomicU64,
    /// Total acquisition errors
    pub acquisition_errors: AtomicU64,
    /// Current active (checked out) connections
    pub active_connections: AtomicUsize,
    /// Current idle (in pool) connections
    pub idle_connections: AtomicUsize,
    /// Total wait time in microseconds (for avg calculation)
    pub total_wait_time_us: AtomicU64,
    /// Total successful acquisitions
    pub total_acquisitions: AtomicU64,
}

impl PoolMetrics {
    /// Create a metrics set with every counter at zero.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Mean acquisition wait in microseconds (integer division), or `0` when
    /// no acquisition has been recorded yet.
    #[must_use]
    pub fn average_wait_time_us(&self) -> u64 {
        let total = self.total_acquisitions.load(Ordering::Relaxed);
        if total == 0 {
            return 0;
        }
        self.total_wait_time_us.load(Ordering::Relaxed) / total
    }

    /// Share of known connections that are currently checked out, as a
    /// percentage in `0.0..=100.0`; `0.0` when there are no connections.
    #[must_use]
    pub fn utilization_percent(&self) -> f64 {
        let active = self.active_connections.load(Ordering::Relaxed) as f64;
        let idle = self.idle_connections.load(Ordering::Relaxed) as f64;
        let total = active + idle;
        if total == 0.0 {
            return 0.0;
        }
        (active / total) * 100.0
    }
}
251
252/// Circuit breaker for fail-fast behavior
253#[derive(Debug)]
254struct CircuitBreaker {
255    consecutive_failures: AtomicUsize,
256    threshold: usize,
257    timeout: Duration,
258    opened_at: Mutex<Option<Instant>>,
259}
260
261impl CircuitBreaker {
262    fn new(threshold: usize, timeout: Duration) -> Self {
263        Self {
264            consecutive_failures: AtomicUsize::new(0),
265            threshold,
266            timeout,
267            opened_at: Mutex::new(None),
268        }
269    }
270
271    async fn check(&self) -> Result<()> {
272        let mut opened_at = self.opened_at.lock().await;
273        if let Some(opened_time) = *opened_at {
274            // Circuit is open, check if timeout elapsed
275            if opened_time.elapsed() < self.timeout {
276                return Err(ProtocolError::CircuitBreakerOpen);
277            }
278            // Timeout elapsed, enter half-open state
279            *opened_at = None;
280            self.consecutive_failures.store(0, Ordering::SeqCst);
281            debug!("Circuit breaker entering half-open state");
282        }
283        Ok(())
284    }
285
286    async fn record_success(&self) {
287        self.consecutive_failures.store(0, Ordering::SeqCst);
288        let mut opened_at = self.opened_at.lock().await;
289        if opened_at.is_some() {
290            *opened_at = None;
291            debug!("Circuit breaker closed after successful operation");
292        }
293    }
294
295    async fn record_failure(&self) {
296        let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
297        if failures >= self.threshold {
298            let mut opened_at = self.opened_at.lock().await;
299            *opened_at = Some(Instant::now());
300            error!(
301                "Circuit breaker opened after {} consecutive failures",
302                failures
303            );
304        }
305    }
306}
307
308/// Generic connection pool for any transport type
309pub struct ConnectionPool<T> {
310    config: PoolConfig,
311    factory: Arc<dyn ConnectionFactory<T>>,
312    connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
313    metrics: Arc<PoolMetrics>,
314    circuit_breaker: Arc<CircuitBreaker>,
315    backpressure: Arc<Semaphore>,
316}
317
318impl<T: Send + 'static> ConnectionPool<T> {
319    /// Create a new connection pool
320    pub fn new(factory: Arc<dyn ConnectionFactory<T>>, config: PoolConfig) -> Result<Self> {
321        config.validate()?;
322
323        let metrics = Arc::new(PoolMetrics::new());
324        let circuit_breaker = Arc::new(CircuitBreaker::new(
325            config.circuit_breaker_threshold,
326            config.circuit_breaker_timeout,
327        ));
328
329        let pool = Self {
330            config: config.clone(),
331            factory: factory.clone(),
332            connections: Arc::new(Mutex::new(VecDeque::new())),
333            metrics: metrics.clone(),
334            circuit_breaker,
335            backpressure: Arc::new(Semaphore::new(config.max_waiters)),
336        };
337
338        // Spawn connection warming task
339        if config.min_size > 0 {
340            let factory_clone = factory;
341            let connections_clone = pool.connections.clone();
342            let metrics_clone = metrics;
343            let min_size = config.min_size;
344
345            tokio::spawn(async move {
346                debug!("Warming connection pool with {} connections", min_size);
347                for _ in 0..min_size {
348                    match factory_clone.create().await {
349                        Ok(conn) => {
350                            let mut connections = connections_clone.lock().await;
351                            connections.push_back(PooledConnection::new(conn));
352                            metrics_clone
353                                .connections_created
354                                .fetch_add(1, Ordering::Relaxed);
355                            metrics_clone
356                                .idle_connections
357                                .fetch_add(1, Ordering::Relaxed);
358                        }
359                        Err(e) => {
360                            warn!("Failed to warm connection: {}", e);
361                            break;
362                        }
363                    }
364                }
365                debug!("Connection pool warming complete");
366            });
367        }
368
369        Ok(pool)
370    }
371
372    /// Get a connection from the pool or create a new one
373    pub async fn acquire(&self) -> Result<PooledConnectionGuard<T>> {
374        let start = Instant::now();
375
376        // Enforce backpressure limit
377        let _permit = self
378            .backpressure
379            .acquire()
380            .await
381            .map_err(|_| ProtocolError::PoolExhausted)?;
382
383        // Check circuit breaker
384        self.circuit_breaker.check().await?;
385
386        let mut connections = self.connections.lock().await;
387
388        // Try to find a valid connection in the pool (LRU: take from back)
389        while let Some(mut pooled) = connections.pop_back() {
390            if !pooled.is_expired(&self.config) && self.factory.is_healthy(&pooled.connection) {
391                pooled.touch();
392                self.metrics
393                    .connections_reused
394                    .fetch_add(1, Ordering::Relaxed);
395                self.metrics
396                    .idle_connections
397                    .fetch_sub(1, Ordering::Relaxed);
398                self.metrics
399                    .active_connections
400                    .fetch_add(1, Ordering::Relaxed);
401                self.metrics
402                    .total_acquisitions
403                    .fetch_add(1, Ordering::Relaxed);
404                self.metrics
405                    .total_wait_time_us
406                    .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
407
408                debug!("Reused connection from pool (LRU)");
409                return Ok(PooledConnectionGuard {
410                    connection: Some(pooled.connection),
411                    pool: self.connections.clone(),
412                    metrics: self.metrics.clone(),
413                });
414            }
415            debug!("Evicted expired/unhealthy connection from pool");
416            self.metrics
417                .connections_evicted
418                .fetch_add(1, Ordering::Relaxed);
419            self.metrics
420                .idle_connections
421                .fetch_sub(1, Ordering::Relaxed);
422        }
423
424        // No valid connection found, create new one
425        drop(connections); // Release lock before creating new connection
426
427        match self.factory.create().await {
428            Ok(new_conn) => {
429                self.circuit_breaker.record_success().await;
430                self.metrics
431                    .connections_created
432                    .fetch_add(1, Ordering::Relaxed);
433                self.metrics
434                    .active_connections
435                    .fetch_add(1, Ordering::Relaxed);
436                self.metrics
437                    .total_acquisitions
438                    .fetch_add(1, Ordering::Relaxed);
439                self.metrics
440                    .total_wait_time_us
441                    .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
442
443                debug!("Created new connection for pool");
444
445                Ok(PooledConnectionGuard {
446                    connection: Some(new_conn),
447                    pool: self.connections.clone(),
448                    metrics: self.metrics.clone(),
449                })
450            }
451            Err(e) => {
452                self.circuit_breaker.record_failure().await;
453                self.metrics
454                    .acquisition_errors
455                    .fetch_add(1, Ordering::Relaxed);
456                Err(e)
457            }
458        }
459    }
460
461    /// Get pool metrics
462    pub fn metrics(&self) -> Arc<PoolMetrics> {
463        self.metrics.clone()
464    }
465
466    /// Current number of connections in pool
467    pub async fn size(&self) -> usize {
468        self.connections.lock().await.len()
469    }
470
471    /// Clear all connections from the pool
472    pub async fn clear(&self) {
473        self.connections.lock().await.clear();
474        debug!("Cleared all connections from pool");
475    }
476
477    /// Get pool configuration
478    pub fn config(&self) -> &PoolConfig {
479        &self.config
480    }
481}
482
483/// RAII guard for pooled connections
484///
485/// Returns the connection to the pool on drop.
486pub struct PooledConnectionGuard<T: Send + 'static> {
487    connection: Option<T>,
488    pool: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
489    metrics: Arc<PoolMetrics>,
490}
491
492impl<T: Send + 'static> PooledConnectionGuard<T> {
493    /// Get a reference to the underlying connection
494    pub fn get(&self) -> Option<&T> {
495        self.connection.as_ref()
496    }
497
498    /// Get a mutable reference to the underlying connection
499    pub fn get_mut(&mut self) -> Option<&mut T> {
500        self.connection.as_mut()
501    }
502
503    /// Extract the connection (won't be returned to pool)
504    pub fn into_inner(mut self) -> Option<T> {
505        self.connection.take()
506    }
507}
508
509impl<T: Send + 'static> AsRef<T> for PooledConnectionGuard<T> {
510    #[allow(clippy::expect_used)] // Connection is guaranteed to exist unless into_inner() was called
511    fn as_ref(&self) -> &T {
512        self.connection.as_ref().expect("Connection should exist")
513    }
514}
515
516impl<T: Send + 'static> AsMut<T> for PooledConnectionGuard<T> {
517    #[allow(clippy::expect_used)] // Connection is guaranteed to exist unless into_inner() was called
518    fn as_mut(&mut self) -> &mut T {
519        self.connection.as_mut().expect("Connection should exist")
520    }
521}
522
523impl<T: Send + 'static> Drop for PooledConnectionGuard<T> {
524    fn drop(&mut self) {
525        if let Some(conn) = self.connection.take() {
526            let pool = self.pool.clone();
527            let metrics = self.metrics.clone();
528            let pooled = PooledConnection::new(conn);
529
530            // Update metrics
531            metrics.active_connections.fetch_sub(1, Ordering::Relaxed);
532
533            // Try to return connection to pool (async context may not be available)
534            // This spawns a background task to handle the return
535            tokio::spawn(async move {
536                let mut connections = pool.lock().await;
537                if connections.len() < 100 {
538                    // Reasonable max to prevent memory issues
539                    connections.push_back(pooled);
540                    metrics.idle_connections.fetch_add(1, Ordering::Relaxed);
541                } else {
542                    warn!("Connection pool at capacity, discarding connection");
543                }
544            });
545        }
546    }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal stand-in for a transport; `id` records creation order.
    #[allow(dead_code)]
    struct TestConnection {
        id: usize,
    }

    /// Factory that always succeeds and counts how many connections it made.
    struct TestFactory {
        counter: Arc<AtomicUsize>,
    }

    impl TestFactory {
        fn new() -> Self {
            Self {
                counter: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Number of connections created so far.
        fn count(&self) -> usize {
            self.counter.load(Ordering::SeqCst)
        }
    }

    impl ConnectionFactory<TestConnection> for TestFactory {
        fn create(
            &self,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<TestConnection>> + Send>>
        {
            // Count at call time, not when the returned future is polled.
            let id = self.counter.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Ok(TestConnection { id }) })
        }
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let factory = Arc::new(TestFactory::new());
        let pool = ConnectionPool::new(
            factory.clone(),
            PoolConfig {
                min_size: 2,
                max_size: 10,
                idle_timeout: Duration::from_secs(60),
                max_lifetime: Duration::from_secs(600),
                ..Default::default()
            },
        );

        assert!(pool.is_ok());
    }

    #[tokio::test]
    #[allow(clippy::unwrap_used)] // Test code
    async fn test_pool_acquire_creates_connection() {
        let factory = Arc::new(TestFactory::new());
        // Disable pre-warming: with `min_size > 0` a background task also
        // calls the factory, so the exact count below would depend on task
        // scheduling.
        let pool = ConnectionPool::new(
            factory.clone(),
            PoolConfig {
                min_size: 0,
                ..Default::default()
            },
        )
        .unwrap();

        let guard = pool.acquire().await.unwrap();
        assert!(guard.get().is_some());
        assert_eq!(factory.count(), 1);
    }

    #[tokio::test]
    async fn test_config_validation() {
        let invalid_config = PoolConfig {
            min_size: 100,
            max_size: 10,
            idle_timeout: Duration::from_secs(60),
            max_lifetime: Duration::from_secs(600),
            ..Default::default()
        };

        let factory = Arc::new(TestFactory::new());
        let result = ConnectionPool::new(factory, invalid_config);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_config_validation_zero_max_size() {
        let config = PoolConfig {
            max_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_zero_timeouts() {
        let config = PoolConfig {
            idle_timeout: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config.validate().is_err());

        let config2 = PoolConfig {
            max_lifetime: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config2.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_idle_exceeds_lifetime() {
        let config = PoolConfig {
            idle_timeout: Duration::from_secs(600),
            max_lifetime: Duration::from_secs(300),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_circuit_breaker() {
        let config = PoolConfig {
            circuit_breaker_threshold: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());

        let config2 = PoolConfig {
            circuit_breaker_timeout: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config2.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_valid_config() {
        let config = PoolConfig::default();
        assert!(config.validate().is_ok());
    }
}