qudag_network/
connection_pool.rs

1#![deny(unsafe_code)]
2
3use crate::connection::{ConnectionInfo, PooledConnection, WarmingState};
4use crate::types::{ConnectionStatus, NetworkError, PeerId};
5use dashmap::DashMap;
6use parking_lot::RwLock;
7use std::collections::{HashMap, VecDeque};
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::{Notify, Semaphore};
12use tokio::time::{interval, sleep};
13use tracing::{debug, warn};
14
/// Tunable limits and behaviours for a connection pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Upper bound on the number of pooled connections.
    pub max_size: usize,
    /// Number of connections the maintenance pass tries to keep alive.
    pub min_size: usize,
    /// How long a connection may sit unused before it is considered stale.
    pub idle_timeout: Duration,
    /// Maximum age of a connection, measured from its creation.
    pub max_lifetime: Duration,
    /// Period of the background maintenance / health-check pass.
    pub health_check_interval: Duration,
    /// How long a caller may wait when acquiring a connection.
    pub acquire_timeout: Duration,
    /// Whether freshly created connections are warmed before use.
    pub enable_warming: bool,
    /// Whether connections are re-validated when they are checked out.
    pub validate_on_checkout: bool,
    /// Number of checkouts after which a connection is retired.
    pub max_reuse_count: u64,
}

impl Default for PoolConfig {
    /// Returns the stock configuration: 10–100 connections, 5 minute idle
    /// timeout, 1 hour lifetime, 30 second health checks, 10 second acquire
    /// timeout, warming and checkout validation enabled, 1000 reuses.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;

        PoolConfig {
            max_size: 100,
            min_size: 10,
            idle_timeout: Duration::from_secs(5 * SECS_PER_MINUTE),
            max_lifetime: Duration::from_secs(60 * SECS_PER_MINUTE),
            health_check_interval: Duration::from_secs(30),
            acquire_timeout: Duration::from_secs(10),
            enable_warming: true,
            validate_on_checkout: true,
            max_reuse_count: 1000,
        }
    }
}
53
/// Snapshot of counters and gauges describing a connection pool.
///
/// All fields start at zero (`Default`).
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Number of connections created so far.
    pub total_created: u64,
    /// Number of connections destroyed so far.
    pub total_destroyed: u64,
    /// Gauge: connections currently tracked by the pool.
    pub current_size: usize,
    /// Gauge: connections currently idle and ready for checkout.
    pub available: usize,
    /// Gauge: connections currently checked out.
    pub active: usize,
    /// Number of successful acquisitions.
    pub acquisitions: u64,
    /// Number of connections handed back to the pool.
    pub releases: u64,
    /// Number of acquisition attempts that failed.
    pub failed_acquisitions: u64,
    /// Number of acquisition attempts that timed out.
    pub timeouts: u64,
    /// Moving average of the time callers waited for a connection.
    pub avg_wait_time: Duration,
    /// Pool hit rate.
    pub hit_rate: f64,
}
80
/// Connection pool for efficient connection management.
///
/// Idle and checked-out connections are tracked per peer in concurrent maps
/// that are shared (via `Arc`) between clones of the pool. The `shutdown`
/// flag and the id counter are plain atomics and are therefore *not* shared
/// between clones.
pub struct ConnectionPool {
    /// Pool configuration
    config: PoolConfig,
    /// Available (idle) connections, queued per peer
    available: Arc<DashMap<PeerId, VecDeque<PooledConnection>>>,
    /// Active connections (checked out), keyed per peer by a pool-assigned id
    active: Arc<DashMap<PeerId, HashMap<u64, PooledConnection>>>,
    /// Connection semaphores per peer; each is created with `max_size` permits
    semaphores: Arc<DashMap<PeerId, Arc<Semaphore>>>,
    /// Pool statistics
    stats: Arc<RwLock<PoolStats>>,
    /// Connection ID counter used to key entries in `active`
    connection_counter: AtomicUsize,
    /// Pool shutdown flag
    shutdown: AtomicBool,
    /// Connection waiters: per-peer notifier signalled when a connection is released
    waiters: Arc<DashMap<PeerId, Arc<Notify>>>,
    /// Maintenance task handle; `None` on clones, which do not own the task
    #[allow(dead_code)]
    maintenance_handle: Option<tokio::task::JoinHandle<()>>,
}
103
104impl ConnectionPool {
105    /// Create a new connection pool
106    pub fn new(config: PoolConfig) -> Self {
107        let pool = Self {
108            config: config.clone(),
109            available: Arc::new(DashMap::new()),
110            active: Arc::new(DashMap::new()),
111            semaphores: Arc::new(DashMap::new()),
112            stats: Arc::new(RwLock::new(PoolStats::default())),
113            connection_counter: AtomicUsize::new(0),
114            shutdown: AtomicBool::new(false),
115            waiters: Arc::new(DashMap::new()),
116            maintenance_handle: None,
117        };
118
119        // Start maintenance task
120        let maintenance_pool = pool.clone();
121        let handle = tokio::spawn(async move {
122            maintenance_pool.run_maintenance().await;
123        });
124
125        Self {
126            maintenance_handle: Some(handle),
127            ..pool
128        }
129    }
130
131    /// Acquire a connection from the pool
132    pub async fn acquire(&self, peer_id: PeerId) -> Result<PooledConnection, NetworkError> {
133        if self.shutdown.load(Ordering::Acquire) {
134            return Err(NetworkError::ConnectionError(
135                "Pool is shutting down".into(),
136            ));
137        }
138
139        let start_time = Instant::now();
140
141        // Get or create semaphore for this peer
142        let semaphore = self
143            .semaphores
144            .entry(peer_id)
145            .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_size)))
146            .clone();
147
148        // Try to acquire permit with timeout
149        let permit = tokio::select! {
150            result = semaphore.acquire() => {
151                result.map_err(|_| NetworkError::ConnectionError("Semaphore closed".into()))?
152            }
153            _ = sleep(self.config.acquire_timeout) => {
154                self.increment_timeouts();
155                return Err(NetworkError::ConnectionError("Connection acquisition timeout".into()));
156            }
157        };
158
159        // Check available connections
160        if let Some(mut available_queue) = self.available.get_mut(&peer_id) {
161            while let Some(mut conn) = available_queue.pop_front() {
162                // Validate connection
163                if self.is_connection_valid(&conn) {
164                    if self.config.validate_on_checkout {
165                        // Perform additional validation if needed
166                        if !self.validate_connection(&conn).await {
167                            continue;
168                        }
169                    }
170
171                    // Update connection state
172                    conn.last_used = Instant::now();
173                    conn.usage_count += 1;
174
175                    // Move to active connections
176                    let conn_id = self.connection_counter.fetch_add(1, Ordering::Relaxed) as u64;
177                    self.active
178                        .entry(peer_id)
179                        .or_insert_with(HashMap::new)
180                        .insert(conn_id, conn.clone());
181
182                    // Update statistics
183                    self.update_acquisition_stats(start_time.elapsed());
184
185                    // Forget the permit (keep it alive)
186                    std::mem::forget(permit);
187
188                    return Ok(conn);
189                }
190            }
191        }
192
193        // No available connection, create new one if under limit
194        if self.get_peer_connection_count(peer_id) < self.config.max_size {
195            let conn = self.create_connection(peer_id).await?;
196
197            // Move to active connections
198            let conn_id = self.connection_counter.fetch_add(1, Ordering::Relaxed) as u64;
199            self.active
200                .entry(peer_id)
201                .or_insert_with(HashMap::new)
202                .insert(conn_id, conn.clone());
203
204            // Update statistics
205            self.update_acquisition_stats(start_time.elapsed());
206            self.increment_created();
207
208            // Forget the permit (keep it alive)
209            std::mem::forget(permit);
210
211            Ok(conn)
212        } else {
213            // Wait for a connection to become available
214            let waiter = self
215                .waiters
216                .entry(peer_id)
217                .or_insert_with(|| Arc::new(Notify::new()))
218                .clone();
219
220            drop(permit); // Release permit while waiting
221
222            tokio::select! {
223                _ = waiter.notified() => {
224                    // Retry acquisition
225                    Box::pin(self.acquire(peer_id)).await
226                }
227                _ = sleep(self.config.acquire_timeout) => {
228                    self.increment_failed_acquisitions();
229                    Err(NetworkError::ConnectionError("No available connections".into()))
230                }
231            }
232        }
233    }
234
235    /// Release a connection back to the pool
236    pub fn release(&self, peer_id: PeerId, mut connection: PooledConnection) {
237        if self.shutdown.load(Ordering::Acquire) {
238            return;
239        }
240
241        // Update connection state
242        connection.last_used = Instant::now();
243
244        // Check if connection should be kept
245        if !self.should_keep_connection(&connection) {
246            self.destroy_connection(peer_id, connection);
247            return;
248        }
249
250        // Return to available pool
251        self.available
252            .entry(peer_id)
253            .or_insert_with(VecDeque::new)
254            .push_back(connection);
255
256        // Notify waiters
257        if let Some(waiter) = self.waiters.get(&peer_id) {
258            waiter.notify_one();
259        }
260
261        // Update statistics
262        self.increment_releases();
263    }
264
265    /// Validate a connection
266    async fn validate_connection(&self, conn: &PooledConnection) -> bool {
267        // Basic validation - check if connection is healthy
268        if !conn.info.is_healthy() {
269            return false;
270        }
271
272        // Additional validation could include:
273        // - Ping test
274        // - Resource usage check
275        // - Performance metrics validation
276
277        true
278    }
279
280    /// Check if connection is valid for use
281    fn is_connection_valid(&self, conn: &PooledConnection) -> bool {
282        // Check lifetime
283        if conn.created_at.elapsed() > self.config.max_lifetime {
284            return false;
285        }
286
287        // Check idle time
288        if conn.last_used.elapsed() > self.config.idle_timeout {
289            return false;
290        }
291
292        // Check reuse count
293        if conn.usage_count >= self.config.max_reuse_count {
294            return false;
295        }
296
297        // Check health
298        conn.info.is_healthy()
299    }
300
301    /// Check if connection should be kept in pool
302    fn should_keep_connection(&self, conn: &PooledConnection) -> bool {
303        self.is_connection_valid(conn) && self.get_total_connection_count() < self.config.max_size
304    }
305
306    /// Create a new connection
307    async fn create_connection(&self, _peer_id: PeerId) -> Result<PooledConnection, NetworkError> {
308        // Simulate connection creation (in real implementation, this would establish actual connection)
309        let info = ConnectionInfo::new(ConnectionStatus::Connected);
310
311        let mut conn = PooledConnection {
312            info,
313            created_at: Instant::now(),
314            last_used: Instant::now(),
315            usage_count: 0,
316            weight: 1.0,
317            max_streams: 100,
318            active_streams: 0,
319            warming_state: WarmingState::Cold,
320            affinity_group: None,
321        };
322
323        // Warm connection if enabled
324        if self.config.enable_warming {
325            self.warm_connection(&mut conn).await?;
326        }
327
328        Ok(conn)
329    }
330
331    /// Warm a connection
332    async fn warm_connection(&self, conn: &mut PooledConnection) -> Result<(), NetworkError> {
333        conn.warming_state = WarmingState::Warming;
334
335        // Simulate warming process
336        sleep(Duration::from_millis(50)).await;
337
338        // In real implementation, this would:
339        // - Establish TLS handshake
340        // - Perform protocol negotiation
341        // - Prime any caches
342        // - Run initial health checks
343
344        conn.warming_state = WarmingState::Warm;
345        Ok(())
346    }
347
348    /// Destroy a connection
349    fn destroy_connection(&self, _peer_id: PeerId, _conn: PooledConnection) {
350        // In real implementation, this would close the actual connection
351        self.increment_destroyed();
352    }
353
354    /// Get connection count for a peer
355    fn get_peer_connection_count(&self, peer_id: PeerId) -> usize {
356        let available_count = self
357            .available
358            .get(&peer_id)
359            .map(|queue| queue.len())
360            .unwrap_or(0);
361
362        let active_count = self.active.get(&peer_id).map(|map| map.len()).unwrap_or(0);
363
364        available_count + active_count
365    }
366
367    /// Get total connection count
368    fn get_total_connection_count(&self) -> usize {
369        let available_count: usize = self.available.iter().map(|entry| entry.value().len()).sum();
370
371        let active_count: usize = self.active.iter().map(|entry| entry.value().len()).sum();
372
373        available_count + active_count
374    }
375
376    /// Run maintenance tasks
377    async fn run_maintenance(&self) {
378        let mut interval = interval(self.config.health_check_interval);
379
380        while !self.shutdown.load(Ordering::Acquire) {
381            interval.tick().await;
382
383            // Clean up expired connections
384            self.cleanup_expired_connections();
385
386            // Maintain minimum pool size
387            self.maintain_minimum_size().await;
388
389            // Update pool statistics
390            self.update_pool_stats();
391        }
392    }
393
394    /// Clean up expired connections
395    fn cleanup_expired_connections(&self) {
396        for mut entry in self.available.iter_mut() {
397            let peer_id = *entry.key();
398            let queue = entry.value_mut();
399
400            // Remove invalid connections
401            queue.retain(|conn| {
402                if self.is_connection_valid(conn) {
403                    true
404                } else {
405                    self.destroy_connection(peer_id, conn.clone());
406                    false
407                }
408            });
409        }
410    }
411
412    /// Maintain minimum pool size
413    async fn maintain_minimum_size(&self) {
414        // This is a simplified version - in production, you'd want more sophisticated logic
415        let total_count = self.get_total_connection_count();
416
417        if total_count < self.config.min_size {
418            let needed = self.config.min_size - total_count;
419            debug!("Pool below minimum size, creating {} connections", needed);
420
421            // Create connections for known peers
422            for entry in self.available.iter() {
423                let peer_id = *entry.key();
424                for _ in 0..needed {
425                    match self.create_connection(peer_id).await {
426                        Ok(conn) => {
427                            self.available
428                                .entry(peer_id)
429                                .or_insert_with(VecDeque::new)
430                                .push_back(conn);
431                            self.increment_created();
432                        }
433                        Err(e) => {
434                            warn!("Failed to create connection during maintenance: {}", e);
435                        }
436                    }
437                }
438            }
439        }
440    }
441
442    /// Update pool statistics
443    fn update_pool_stats(&self) {
444        let mut stats = self.stats.write();
445
446        stats.current_size = self.get_total_connection_count();
447        stats.available = self.available.iter().map(|entry| entry.value().len()).sum();
448        stats.active = self.active.iter().map(|entry| entry.value().len()).sum();
449
450        // Calculate hit rate
451        if stats.acquisitions > 0 {
452            stats.hit_rate = 1.0 - (stats.failed_acquisitions as f64 / stats.acquisitions as f64);
453        }
454    }
455
456    /// Shutdown the pool
457    pub async fn shutdown(&mut self) {
458        self.shutdown.store(true, Ordering::Release);
459
460        // Stop maintenance task
461        if let Some(handle) = self.maintenance_handle.take() {
462            handle.abort();
463        }
464
465        // Close all connections
466        for entry in self.available.iter() {
467            let peer_id = *entry.key();
468            for conn in entry.value().iter() {
469                self.destroy_connection(peer_id, conn.clone());
470            }
471        }
472
473        for entry in self.active.iter() {
474            let peer_id = *entry.key();
475            for (_, conn) in entry.value().iter() {
476                self.destroy_connection(peer_id, conn.clone());
477            }
478        }
479
480        // Clear all data
481        self.available.clear();
482        self.active.clear();
483        self.semaphores.clear();
484        self.waiters.clear();
485    }
486
487    /// Get pool statistics
488    pub fn get_stats(&self) -> PoolStats {
489        self.stats.read().clone()
490    }
491
492    // Statistics update methods
493    fn increment_created(&self) {
494        self.stats.write().total_created += 1;
495    }
496
497    fn increment_destroyed(&self) {
498        self.stats.write().total_destroyed += 1;
499    }
500
501    fn increment_releases(&self) {
502        self.stats.write().releases += 1;
503    }
504
505    fn increment_timeouts(&self) {
506        self.stats.write().timeouts += 1;
507    }
508
509    fn increment_failed_acquisitions(&self) {
510        self.stats.write().failed_acquisitions += 1;
511    }
512
513    fn update_acquisition_stats(&self, wait_time: Duration) {
514        let mut stats = self.stats.write();
515        stats.acquisitions += 1;
516
517        // Update average wait time (exponential moving average)
518        let alpha = 0.1;
519        let current_avg = stats.avg_wait_time.as_millis() as f64;
520        let new_wait = wait_time.as_millis() as f64;
521        let updated_avg = alpha * new_wait + (1.0 - alpha) * current_avg;
522        stats.avg_wait_time = Duration::from_millis(updated_avg as u64);
523    }
524}
525
526impl Clone for ConnectionPool {
527    fn clone(&self) -> Self {
528        Self {
529            config: self.config.clone(),
530            available: self.available.clone(),
531            active: self.active.clone(),
532            semaphores: self.semaphores.clone(),
533            stats: self.stats.clone(),
534            connection_counter: AtomicUsize::new(self.connection_counter.load(Ordering::Relaxed)),
535            shutdown: AtomicBool::new(self.shutdown.load(Ordering::Relaxed)),
536            waiters: self.waiters.clone(),
537            maintenance_handle: None, // Don't clone the maintenance task
538        }
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created pool reports no connections.
    #[tokio::test]
    async fn test_pool_creation() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);

        let stats = pool.get_stats();
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.available, 0);
        assert_eq!(stats.active, 0);
    }

    /// The first acquisition for a peer creates a connection and counts as
    /// its first use.
    #[tokio::test]
    async fn test_connection_acquisition() {
        let config = PoolConfig {
            max_size: 10,
            min_size: 0,
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        // Acquire connection
        let conn = pool.acquire(peer_id).await.unwrap();
        assert_eq!(conn.usage_count, 1);

        let stats = pool.get_stats();
        assert_eq!(stats.acquisitions, 1);
        assert_eq!(stats.total_created, 1);
    }

    /// Releasing a connection makes it available again.
    #[tokio::test]
    async fn test_connection_release() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        // Acquire and release connection
        let conn = pool.acquire(peer_id).await.unwrap();
        pool.release(peer_id, conn);

        let stats = pool.get_stats();
        assert_eq!(stats.releases, 1);
        assert_eq!(stats.available, 1);
    }

    /// A released connection is handed out again instead of creating a new one.
    #[tokio::test]
    async fn test_connection_reuse() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        // First acquisition
        let conn1 = pool.acquire(peer_id).await.unwrap();
        let created_at = conn1.created_at;
        pool.release(peer_id, conn1);

        // Second acquisition should reuse
        let conn2 = pool.acquire(peer_id).await.unwrap();
        assert_eq!(conn2.created_at, created_at);
        assert_eq!(conn2.usage_count, 2);

        let stats = pool.get_stats();
        assert_eq!(stats.total_created, 1);
        assert_eq!(stats.acquisitions, 2);
    }

    /// Acquisitions beyond `max_size` fail until a connection is released.
    #[tokio::test]
    async fn test_pool_limits() {
        let config = PoolConfig {
            max_size: 2,
            acquire_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        // Acquire max connections
        let conn1 = pool.acquire(peer_id).await.unwrap();
        let conn2 = pool.acquire(peer_id).await.unwrap();

        // Third acquisition should timeout
        let result = pool.acquire(peer_id).await;
        assert!(result.is_err());

        // Release one and try again
        pool.release(peer_id, conn1);
        let conn3 = pool.acquire(peer_id).await;
        assert!(conn3.is_ok());

        // Cleanup
        pool.release(peer_id, conn2);
        pool.release(peer_id, conn3.unwrap());
    }

    /// Idle connections are removed by maintenance once `idle_timeout` passes.
    #[tokio::test]
    async fn test_connection_expiration() {
        let config = PoolConfig {
            // Keep maintenance from topping the pool back up to the default
            // minimum, which would leave fresh idle connections behind.
            min_size: 0,
            idle_timeout: Duration::from_millis(100),
            health_check_interval: Duration::from_millis(50),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        // Create and release connection
        let conn = pool.acquire(peer_id).await.unwrap();
        pool.release(peer_id, conn);

        // Wait for expiration
        sleep(Duration::from_millis(200)).await;

        // Connection should be cleaned up
        let stats = pool.get_stats();
        assert_eq!(stats.available, 0);
    }
}
661}