amaters_net/
pool.rs

1//! Connection pool implementation for managing reusable connections
2//!
3//! Provides connection pooling with configurable limits, health checks,
4//! and lifecycle management for efficient resource utilization.
5
6use crate::balancer::{BalancingStrategy, EndpointId, LoadBalancer};
7use crate::circuit_breaker::CircuitBreaker;
8use crate::error::{NetError, NetResult};
9use async_trait::async_trait;
10use parking_lot::RwLock;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use tokio::time;
15use tonic::transport::{Channel, Endpoint};
16
17/// Configuration for connection pool
18#[derive(Debug, Clone)]
19pub struct PoolConfig {
20    /// Minimum number of connections to maintain
21    pub min_size: usize,
22    /// Maximum number of connections allowed
23    pub max_size: usize,
24    /// Connection idle timeout (connections idle longer are closed)
25    pub idle_timeout: Duration,
26    /// Connection maximum lifetime (connections older are closed)
27    pub max_lifetime: Duration,
28    /// Connection timeout for establishing new connections
29    pub connect_timeout: Duration,
30    /// Health check interval
31    pub health_check_interval: Duration,
32    /// Load balancing strategy
33    pub balancing_strategy: BalancingStrategy,
34    /// Enable circuit breaker
35    pub enable_circuit_breaker: bool,
36}
37
38impl Default for PoolConfig {
39    fn default() -> Self {
40        Self {
41            min_size: 2,
42            max_size: 10,
43            idle_timeout: Duration::from_secs(300), // 5 minutes
44            max_lifetime: Duration::from_secs(1800), // 30 minutes
45            connect_timeout: Duration::from_secs(10),
46            health_check_interval: Duration::from_secs(30),
47            balancing_strategy: BalancingStrategy::LeastConnections,
48            enable_circuit_breaker: true,
49        }
50    }
51}
52
/// Point-in-time counters describing a connection pool.
///
/// All fields start at zero via [`Default`].
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Sum of checked-out and idle connections.
    pub total_connections: usize,
    /// Connections currently checked out by callers.
    pub active_connections: usize,
    /// Connections sitting in the pool, ready to be handed out.
    pub idle_connections: usize,
    /// How many attempts to establish a connection have failed.
    pub failed_connections: u64,
    /// How many connections have been successfully created so far.
    pub total_created: u64,
    /// How many connections have been closed so far.
    pub total_closed: u64,
    /// How many times a caller found the pool already at its maximum size.
    pub pool_exhausted_count: u64,
    /// Smoothed time, in milliseconds, callers spent waiting for a connection.
    pub avg_wait_time_ms: u64,
}
73
74/// Connection metadata
75#[derive(Debug)]
76struct ConnectionMeta {
77    /// gRPC channel
78    channel: Channel,
79    /// Endpoint ID
80    endpoint_id: EndpointId,
81    /// Time when connection was created
82    created_at: Instant,
83    /// Time when connection was last used
84    last_used: Instant,
85}
86
87impl ConnectionMeta {
88    /// Create new connection metadata
89    fn new(channel: Channel, endpoint_id: EndpointId) -> Self {
90        let now = Instant::now();
91        Self {
92            channel,
93            endpoint_id,
94            created_at: now,
95            last_used: now,
96        }
97    }
98
99    /// Check if connection is expired based on idle timeout
100    fn is_idle_expired(&self, idle_timeout: Duration) -> bool {
101        self.last_used.elapsed() > idle_timeout
102    }
103
104    /// Check if connection exceeded max lifetime
105    fn is_lifetime_expired(&self, max_lifetime: Duration) -> bool {
106        self.created_at.elapsed() > max_lifetime
107    }
108
109    /// Update last used timestamp
110    fn touch(&mut self) {
111        self.last_used = Instant::now();
112    }
113}
114
115/// Pooled connection wrapper
116pub struct PooledConnection {
117    meta: Option<ConnectionMeta>,
118    pool: Arc<ConnectionPoolInner>,
119}
120
121impl PooledConnection {
122    /// Get the underlying gRPC channel
123    pub fn channel(&self) -> &Channel {
124        &self.meta.as_ref().expect("connection should exist").channel
125    }
126
127    /// Get endpoint ID
128    pub fn endpoint_id(&self) -> &str {
129        &self
130            .meta
131            .as_ref()
132            .expect("connection should exist")
133            .endpoint_id
134    }
135}
136
137impl Drop for PooledConnection {
138    fn drop(&mut self) {
139        if let Some(mut meta) = self.meta.take() {
140            meta.touch();
141            self.pool.return_connection(meta);
142        }
143    }
144}
145
146/// Internal connection pool state
147struct ConnectionPoolInner {
148    config: PoolConfig,
149    idle_connections: RwLock<VecDeque<ConnectionMeta>>,
150    active_count: std::sync::Mutex<usize>,
151    stats: RwLock<PoolStats>,
152    load_balancer: LoadBalancer,
153    circuit_breaker: Option<CircuitBreaker>,
154}
155
156impl ConnectionPoolInner {
157    /// Return a connection to the pool
158    fn return_connection(&self, meta: ConnectionMeta) {
159        // Check if connection is expired
160        if meta.is_idle_expired(self.config.idle_timeout)
161            || meta.is_lifetime_expired(self.config.max_lifetime)
162        {
163            // Connection expired, don't return to pool
164            self.stats.write().total_closed += 1;
165            let mut active = self
166                .active_count
167                .lock()
168                .expect("active count lock poisoned");
169            *active = active.saturating_sub(1);
170            return;
171        }
172
173        // Return to pool
174        self.idle_connections.write().push_back(meta);
175        let mut active = self
176            .active_count
177            .lock()
178            .expect("active count lock poisoned");
179        *active = active.saturating_sub(1);
180    }
181
182    /// Get pool statistics
183    fn get_stats(&self) -> PoolStats {
184        let mut stats = self.stats.read().clone();
185        let idle = self.idle_connections.read().len();
186        let active = *self
187            .active_count
188            .lock()
189            .expect("active count lock poisoned");
190        stats.total_connections = idle + active;
191        stats.active_connections = active;
192        stats.idle_connections = idle;
193        stats
194    }
195}
196
197/// Connection pool for managing gRPC connections
198pub struct ConnectionPool {
199    inner: Arc<ConnectionPoolInner>,
200    shutdown_tx: tokio::sync::watch::Sender<bool>,
201}
202
203impl ConnectionPool {
204    /// Create a new connection pool
205    pub fn new(config: PoolConfig) -> Self {
206        let load_balancer = LoadBalancer::new(config.balancing_strategy);
207        let circuit_breaker = if config.enable_circuit_breaker {
208            Some(CircuitBreaker::new())
209        } else {
210            None
211        };
212
213        let inner = Arc::new(ConnectionPoolInner {
214            config: config.clone(),
215            idle_connections: RwLock::new(VecDeque::new()),
216            active_count: std::sync::Mutex::new(0),
217            stats: RwLock::new(PoolStats::default()),
218            load_balancer,
219            circuit_breaker,
220        });
221
222        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
223
224        // Spawn health check task
225        let health_check_inner = Arc::clone(&inner);
226        tokio::spawn(async move {
227            Self::health_check_loop(health_check_inner, shutdown_rx).await;
228        });
229
230        Self { inner, shutdown_tx }
231    }
232
233    /// Add an endpoint to the connection pool
234    pub fn add_endpoint(&self, id: EndpointId, address: String) {
235        self.add_endpoint_with_weight(id, address, 1);
236    }
237
238    /// Add an endpoint with weight
239    pub fn add_endpoint_with_weight(&self, id: EndpointId, address: String, weight: u32) {
240        let endpoint = crate::balancer::Endpoint::with_weight(id, address, weight);
241        self.inner.load_balancer.add_endpoint(endpoint);
242    }
243
244    /// Remove an endpoint from the connection pool
245    pub fn remove_endpoint(&self, endpoint_id: &str) -> bool {
246        // Remove from load balancer
247        let removed = self.inner.load_balancer.remove_endpoint(endpoint_id);
248
249        // Close connections for this endpoint
250        if removed {
251            let mut idle = self.inner.idle_connections.write();
252            idle.retain(|conn| conn.endpoint_id != endpoint_id);
253        }
254
255        removed
256    }
257
258    /// Get a connection from the pool
259    pub async fn get_connection(&self) -> NetResult<PooledConnection> {
260        let start = Instant::now();
261
262        // Check circuit breaker
263        if let Some(ref cb) = self.inner.circuit_breaker {
264            cb.is_request_allowed()?;
265        }
266
267        // Try to get idle connection first
268        if let Some(mut meta) = self.inner.idle_connections.write().pop_front() {
269            meta.touch();
270            *self
271                .inner
272                .active_count
273                .lock()
274                .expect("active count lock poisoned") += 1;
275
276            return Ok(PooledConnection {
277                meta: Some(meta),
278                pool: Arc::clone(&self.inner),
279            });
280        }
281
282        // No idle connections, check if we can create a new one
283        let active = *self
284            .inner
285            .active_count
286            .lock()
287            .expect("active count lock poisoned");
288        let idle = self.inner.idle_connections.read().len();
289
290        if active + idle >= self.inner.config.max_size {
291            // Pool exhausted, wait for available connection
292            self.inner.stats.write().pool_exhausted_count += 1;
293
294            // Wait with timeout
295            let timeout = Duration::from_secs(30);
296            let deadline = Instant::now() + timeout;
297
298            while Instant::now() < deadline {
299                if let Some(mut meta) = self.inner.idle_connections.write().pop_front() {
300                    meta.touch();
301                    *self
302                        .inner
303                        .active_count
304                        .lock()
305                        .expect("active count lock poisoned") += 1;
306
307                    // Update wait time stats
308                    let wait_time = start.elapsed().as_millis() as u64;
309                    let mut stats = self.inner.stats.write();
310                    stats.avg_wait_time_ms = (stats.avg_wait_time_ms + wait_time) / 2;
311
312                    return Ok(PooledConnection {
313                        meta: Some(meta),
314                        pool: Arc::clone(&self.inner),
315                    });
316                }
317
318                // Wait a bit before retrying
319                time::sleep(Duration::from_millis(100)).await;
320            }
321
322            return Err(NetError::ServerOverloaded(
323                "Connection pool exhausted".to_string(),
324            ));
325        }
326
327        // Create new connection
328        let meta = self.create_connection().await?;
329        *self
330            .inner
331            .active_count
332            .lock()
333            .expect("active count lock poisoned") += 1;
334
335        Ok(PooledConnection {
336            meta: Some(meta),
337            pool: Arc::clone(&self.inner),
338        })
339    }
340
341    /// Create a new connection
342    async fn create_connection(&self) -> NetResult<ConnectionMeta> {
343        // Select endpoint using load balancer
344        let endpoint = self.inner.load_balancer.select_endpoint()?;
345
346        // Create gRPC channel
347        let channel = Endpoint::from_shared(format!("http://{}", endpoint.address))
348            .map_err(|e| NetError::InvalidRequest(format!("Invalid endpoint: {}", e)))?
349            .connect_timeout(self.inner.config.connect_timeout)
350            .timeout(Duration::from_secs(30))
351            .connect()
352            .await
353            .map_err(|e| {
354                self.inner.stats.write().failed_connections += 1;
355                if let Some(ref cb) = self.inner.circuit_breaker {
356                    cb.record_failure();
357                }
358                NetError::ConnectionRefused(format!("Failed to connect: {}", e))
359            })?;
360
361        // Record success
362        if let Some(ref cb) = self.inner.circuit_breaker {
363            cb.record_success();
364        }
365
366        self.inner.stats.write().total_created += 1;
367
368        Ok(ConnectionMeta::new(channel, endpoint.id.clone()))
369    }
370
371    /// Health check loop
372    async fn health_check_loop(
373        inner: Arc<ConnectionPoolInner>,
374        mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
375    ) {
376        let mut interval = time::interval(inner.config.health_check_interval);
377
378        loop {
379            tokio::select! {
380                _ = interval.tick() => {
381                    Self::perform_health_check(&inner).await;
382                }
383                _ = shutdown_rx.changed() => {
384                    if *shutdown_rx.borrow() {
385                        break;
386                    }
387                }
388            }
389        }
390    }
391
392    /// Perform health check on idle connections
393    async fn perform_health_check(inner: &Arc<ConnectionPoolInner>) {
394        let needed = {
395            // Scope the lock to ensure it's dropped before any await
396            let mut idle = inner.idle_connections.write();
397            let config = &inner.config;
398
399            // Remove expired connections
400            idle.retain(|conn| {
401                !conn.is_idle_expired(config.idle_timeout)
402                    && !conn.is_lifetime_expired(config.max_lifetime)
403            });
404
405            // Ensure minimum pool size
406            let current_size = idle.len()
407                + *inner
408                    .active_count
409                    .lock()
410                    .expect("active count lock poisoned");
411            config.min_size.saturating_sub(current_size)
412        }; // Lock is dropped here
413
414        // Create needed connections (async operation)
415        for _ in 0..needed {
416            // This is best effort - we don't wait for results
417            // Real implementation would handle this more carefully
418            let _ = async {
419                // Would create connection here
420            }
421            .await;
422        }
423    }
424
425    /// Get pool statistics
426    pub fn stats(&self) -> PoolStats {
427        self.inner.get_stats()
428    }
429
430    /// Get circuit breaker statistics
431    pub fn circuit_breaker_stats(&self) -> Option<crate::circuit_breaker::CircuitBreakerStats> {
432        self.inner.circuit_breaker.as_ref().map(|cb| cb.stats())
433    }
434
435    /// Shutdown the connection pool gracefully
436    pub async fn shutdown(self) -> NetResult<()> {
437        // Signal shutdown to background tasks
438        self.shutdown_tx
439            .send(true)
440            .map_err(|_| NetError::ServerInternal("Failed to signal shutdown".to_string()))?;
441
442        // Wait for a short period to allow tasks to complete
443        time::sleep(Duration::from_millis(500)).await;
444
445        // Close all idle connections
446        let mut idle = self.inner.idle_connections.write();
447        let count = idle.len();
448        idle.clear();
449
450        self.inner.stats.write().total_closed += count as u64;
451
452        Ok(())
453    }
454
455    /// Drain the pool (prepare for graceful shutdown)
456    pub async fn drain(&self) -> NetResult<()> {
457        // Wait for active connections to complete
458        let timeout = Duration::from_secs(30);
459        let deadline = Instant::now() + timeout;
460
461        while Instant::now() < deadline {
462            let active = *self
463                .inner
464                .active_count
465                .lock()
466                .expect("active count lock poisoned");
467            if active == 0 {
468                break;
469            }
470            time::sleep(Duration::from_millis(100)).await;
471        }
472
473        let active = *self
474            .inner
475            .active_count
476            .lock()
477            .expect("active count lock poisoned");
478        if active > 0 {
479            return Err(NetError::Timeout(format!(
480                "Drain timeout: {} active connections remaining",
481                active
482            )));
483        }
484
485        Ok(())
486    }
487}
488
489/// Connection pool builder for fluent configuration
490pub struct ConnectionPoolBuilder {
491    config: PoolConfig,
492    endpoints: Vec<(EndpointId, String, u32)>,
493}
494
495impl ConnectionPoolBuilder {
496    /// Create a new builder
497    pub fn new() -> Self {
498        Self {
499            config: PoolConfig::default(),
500            endpoints: Vec::new(),
501        }
502    }
503
504    /// Set minimum pool size
505    pub fn min_size(mut self, size: usize) -> Self {
506        self.config.min_size = size;
507        self
508    }
509
510    /// Set maximum pool size
511    pub fn max_size(mut self, size: usize) -> Self {
512        self.config.max_size = size;
513        self
514    }
515
516    /// Set idle timeout
517    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
518        self.config.idle_timeout = timeout;
519        self
520    }
521
522    /// Set max lifetime
523    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
524        self.config.max_lifetime = lifetime;
525        self
526    }
527
528    /// Set connect timeout
529    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
530        self.config.connect_timeout = timeout;
531        self
532    }
533
534    /// Set health check interval
535    pub fn health_check_interval(mut self, interval: Duration) -> Self {
536        self.config.health_check_interval = interval;
537        self
538    }
539
540    /// Set balancing strategy
541    pub fn balancing_strategy(mut self, strategy: BalancingStrategy) -> Self {
542        self.config.balancing_strategy = strategy;
543        self
544    }
545
546    /// Enable or disable circuit breaker
547    pub fn circuit_breaker(mut self, enabled: bool) -> Self {
548        self.config.enable_circuit_breaker = enabled;
549        self
550    }
551
552    /// Add an endpoint
553    pub fn add_endpoint(mut self, id: EndpointId, address: String) -> Self {
554        self.endpoints.push((id, address, 1));
555        self
556    }
557
558    /// Add an endpoint with weight
559    pub fn add_endpoint_with_weight(
560        mut self,
561        id: EndpointId,
562        address: String,
563        weight: u32,
564    ) -> Self {
565        self.endpoints.push((id, address, weight));
566        self
567    }
568
569    /// Build the connection pool
570    pub fn build(self) -> ConnectionPool {
571        let pool = ConnectionPool::new(self.config);
572
573        for (id, address, weight) in self.endpoints {
574            pool.add_endpoint_with_weight(id, address, weight);
575        }
576
577        pool
578    }
579}
580
581impl Default for ConnectionPoolBuilder {
582    fn default() -> Self {
583        Self::new()
584    }
585}
586
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented sizes and enables the breaker.
    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.min_size, 2);
        assert_eq!(config.max_size, 10);
        assert!(config.enable_circuit_breaker);
    }

    /// A freshly created connection record is neither idle- nor lifetime-expired.
    #[tokio::test]
    async fn test_connection_meta_expiry() {
        // A lazy channel performs no I/O until first use, so the assertions
        // below always run, whether or not a server is listening.
        let channel = Endpoint::from_static("http://localhost:50051").connect_lazy();
        let meta = ConnectionMeta::new(channel, "ep1".to_string());

        assert!(!meta.is_idle_expired(Duration::from_secs(10)));
        assert!(!meta.is_lifetime_expired(Duration::from_secs(10)));
    }

    /// A pool produced by the builder starts with no connections.
    #[tokio::test]
    async fn test_pool_builder() {
        let pool = ConnectionPoolBuilder::new()
            .min_size(5)
            .max_size(20)
            .idle_timeout(Duration::from_secs(600))
            .balancing_strategy(BalancingStrategy::RoundRobin)
            .add_endpoint("ep1".to_string(), "localhost:50051".to_string())
            .add_endpoint("ep2".to_string(), "localhost:50052".to_string())
            .build();

        let stats = pool.stats();
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.idle_connections, 0);
    }

    /// Removing a known endpoint succeeds; removing an unknown one reports `false`.
    #[tokio::test]
    async fn test_pool_add_remove_endpoint() {
        let pool = ConnectionPool::new(PoolConfig::default());

        pool.add_endpoint("ep1".to_string(), "localhost:50051".to_string());
        pool.add_endpoint("ep2".to_string(), "localhost:50052".to_string());

        assert!(pool.remove_endpoint("ep1"));
        assert!(!pool.remove_endpoint("ep3"));
    }

    /// Registering an endpoint does not by itself open any connection.
    #[tokio::test]
    async fn test_pool_stats() {
        let pool = ConnectionPool::new(PoolConfig::default());
        pool.add_endpoint("ep1".to_string(), "localhost:50051".to_string());

        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.idle_connections, 0);
    }

    /// Shutting down a pool with no connections completes successfully.
    #[tokio::test]
    async fn test_pool_shutdown() {
        let pool = ConnectionPool::new(PoolConfig::default());
        pool.add_endpoint("ep1".to_string(), "localhost:50051".to_string());

        let result = pool.shutdown().await;
        assert!(result.is_ok());
    }
}