1#![deny(unsafe_code)]
2
3use crate::connection::{ConnectionInfo, PooledConnection, WarmingState};
4use crate::types::{ConnectionStatus, NetworkError, PeerId};
5use dashmap::DashMap;
6use parking_lot::RwLock;
7use std::collections::{HashMap, VecDeque};
8use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::{Notify, Semaphore};
12use tokio::time::{interval, sleep};
13use tracing::{debug, warn};
14
/// Tunables for [`ConnectionPool`].
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Connection limit: sizes each peer's checkout semaphore and caps how many
    /// connections the pool opens for a peer.
    pub max_size: usize,
    /// Minimum number of pooled connections the maintenance task tries to keep.
    pub min_size: usize,
    /// Connections unused for longer than this are treated as invalid.
    pub idle_timeout: Duration,
    /// Connections older than this are treated as invalid.
    pub max_lifetime: Duration,
    /// Period of the background maintenance task.
    pub health_check_interval: Duration,
    /// How long `acquire` waits for a checkout slot (and again for a released
    /// connection) before failing.
    pub acquire_timeout: Duration,
    /// Whether newly created connections are warmed before being handed out.
    pub enable_warming: bool,
    /// Whether reused connections are health-checked again on checkout.
    pub validate_on_checkout: bool,
    /// Connections whose `usage_count` reaches this value are treated as invalid.
    pub max_reuse_count: u64,
}
37
38impl Default for PoolConfig {
39 fn default() -> Self {
40 Self {
41 max_size: 100,
42 min_size: 10,
43 idle_timeout: Duration::from_secs(300), max_lifetime: Duration::from_secs(3600), health_check_interval: Duration::from_secs(30),
46 acquire_timeout: Duration::from_secs(10),
47 enable_warming: true,
48 validate_on_checkout: true,
49 max_reuse_count: 1000,
50 }
51 }
52}
53
/// Counters and gauges describing a [`ConnectionPool`].
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Connections created over the pool's lifetime.
    pub total_created: u64,
    /// Connections destroyed over the pool's lifetime.
    pub total_destroyed: u64,
    /// Connections currently held (idle plus checked out).
    pub current_size: usize,
    /// Idle connections ready for checkout.
    pub available: usize,
    /// Connections currently checked out.
    pub active: usize,
    /// Successful acquisitions.
    pub acquisitions: u64,
    /// Releases that put a connection back into the idle queue.
    pub releases: u64,
    /// Acquisitions that gave up waiting for a released connection.
    pub failed_acquisitions: u64,
    /// Acquisitions that timed out waiting for a checkout slot.
    pub timeouts: u64,
    /// Exponentially weighted moving average of acquisition wait time.
    pub avg_wait_time: Duration,
    /// Acquisition success indicator derived from `acquisitions` and
    /// `failed_acquisitions`; stays 0.0 until it is first computed.
    pub hit_rate: f64,
}
80
81pub struct ConnectionPool {
83 config: PoolConfig,
85 available: Arc<DashMap<PeerId, VecDeque<PooledConnection>>>,
87 active: Arc<DashMap<PeerId, HashMap<u64, PooledConnection>>>,
89 semaphores: Arc<DashMap<PeerId, Arc<Semaphore>>>,
91 stats: Arc<RwLock<PoolStats>>,
93 connection_counter: AtomicUsize,
95 shutdown: AtomicBool,
97 waiters: Arc<DashMap<PeerId, Arc<Notify>>>,
99 #[allow(dead_code)]
101 maintenance_handle: Option<tokio::task::JoinHandle<()>>,
102}
103
104impl ConnectionPool {
105 pub fn new(config: PoolConfig) -> Self {
107 let pool = Self {
108 config: config.clone(),
109 available: Arc::new(DashMap::new()),
110 active: Arc::new(DashMap::new()),
111 semaphores: Arc::new(DashMap::new()),
112 stats: Arc::new(RwLock::new(PoolStats::default())),
113 connection_counter: AtomicUsize::new(0),
114 shutdown: AtomicBool::new(false),
115 waiters: Arc::new(DashMap::new()),
116 maintenance_handle: None,
117 };
118
119 let maintenance_pool = pool.clone();
121 let handle = tokio::spawn(async move {
122 maintenance_pool.run_maintenance().await;
123 });
124
125 Self {
126 maintenance_handle: Some(handle),
127 ..pool
128 }
129 }
130
131 pub async fn acquire(&self, peer_id: PeerId) -> Result<PooledConnection, NetworkError> {
133 if self.shutdown.load(Ordering::Acquire) {
134 return Err(NetworkError::ConnectionError(
135 "Pool is shutting down".into(),
136 ));
137 }
138
139 let start_time = Instant::now();
140
141 let semaphore = self
143 .semaphores
144 .entry(peer_id)
145 .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_size)))
146 .clone();
147
148 let permit = tokio::select! {
150 result = semaphore.acquire() => {
151 result.map_err(|_| NetworkError::ConnectionError("Semaphore closed".into()))?
152 }
153 _ = sleep(self.config.acquire_timeout) => {
154 self.increment_timeouts();
155 return Err(NetworkError::ConnectionError("Connection acquisition timeout".into()));
156 }
157 };
158
159 if let Some(mut available_queue) = self.available.get_mut(&peer_id) {
161 while let Some(mut conn) = available_queue.pop_front() {
162 if self.is_connection_valid(&conn) {
164 if self.config.validate_on_checkout {
165 if !self.validate_connection(&conn).await {
167 continue;
168 }
169 }
170
171 conn.last_used = Instant::now();
173 conn.usage_count += 1;
174
175 let conn_id = self.connection_counter.fetch_add(1, Ordering::Relaxed) as u64;
177 self.active
178 .entry(peer_id)
179 .or_insert_with(HashMap::new)
180 .insert(conn_id, conn.clone());
181
182 self.update_acquisition_stats(start_time.elapsed());
184
185 std::mem::forget(permit);
187
188 return Ok(conn);
189 }
190 }
191 }
192
193 if self.get_peer_connection_count(peer_id) < self.config.max_size {
195 let conn = self.create_connection(peer_id).await?;
196
197 let conn_id = self.connection_counter.fetch_add(1, Ordering::Relaxed) as u64;
199 self.active
200 .entry(peer_id)
201 .or_insert_with(HashMap::new)
202 .insert(conn_id, conn.clone());
203
204 self.update_acquisition_stats(start_time.elapsed());
206 self.increment_created();
207
208 std::mem::forget(permit);
210
211 Ok(conn)
212 } else {
213 let waiter = self
215 .waiters
216 .entry(peer_id)
217 .or_insert_with(|| Arc::new(Notify::new()))
218 .clone();
219
220 drop(permit); tokio::select! {
223 _ = waiter.notified() => {
224 Box::pin(self.acquire(peer_id)).await
226 }
227 _ = sleep(self.config.acquire_timeout) => {
228 self.increment_failed_acquisitions();
229 Err(NetworkError::ConnectionError("No available connections".into()))
230 }
231 }
232 }
233 }
234
235 pub fn release(&self, peer_id: PeerId, mut connection: PooledConnection) {
237 if self.shutdown.load(Ordering::Acquire) {
238 return;
239 }
240
241 connection.last_used = Instant::now();
243
244 if !self.should_keep_connection(&connection) {
246 self.destroy_connection(peer_id, connection);
247 return;
248 }
249
250 self.available
252 .entry(peer_id)
253 .or_insert_with(VecDeque::new)
254 .push_back(connection);
255
256 if let Some(waiter) = self.waiters.get(&peer_id) {
258 waiter.notify_one();
259 }
260
261 self.increment_releases();
263 }
264
265 async fn validate_connection(&self, conn: &PooledConnection) -> bool {
267 if !conn.info.is_healthy() {
269 return false;
270 }
271
272 true
278 }
279
280 fn is_connection_valid(&self, conn: &PooledConnection) -> bool {
282 if conn.created_at.elapsed() > self.config.max_lifetime {
284 return false;
285 }
286
287 if conn.last_used.elapsed() > self.config.idle_timeout {
289 return false;
290 }
291
292 if conn.usage_count >= self.config.max_reuse_count {
294 return false;
295 }
296
297 conn.info.is_healthy()
299 }
300
301 fn should_keep_connection(&self, conn: &PooledConnection) -> bool {
303 self.is_connection_valid(conn) && self.get_total_connection_count() < self.config.max_size
304 }
305
306 async fn create_connection(&self, _peer_id: PeerId) -> Result<PooledConnection, NetworkError> {
308 let info = ConnectionInfo::new(ConnectionStatus::Connected);
310
311 let mut conn = PooledConnection {
312 info,
313 created_at: Instant::now(),
314 last_used: Instant::now(),
315 usage_count: 0,
316 weight: 1.0,
317 max_streams: 100,
318 active_streams: 0,
319 warming_state: WarmingState::Cold,
320 affinity_group: None,
321 };
322
323 if self.config.enable_warming {
325 self.warm_connection(&mut conn).await?;
326 }
327
328 Ok(conn)
329 }
330
331 async fn warm_connection(&self, conn: &mut PooledConnection) -> Result<(), NetworkError> {
333 conn.warming_state = WarmingState::Warming;
334
335 sleep(Duration::from_millis(50)).await;
337
338 conn.warming_state = WarmingState::Warm;
345 Ok(())
346 }
347
348 fn destroy_connection(&self, _peer_id: PeerId, _conn: PooledConnection) {
350 self.increment_destroyed();
352 }
353
354 fn get_peer_connection_count(&self, peer_id: PeerId) -> usize {
356 let available_count = self
357 .available
358 .get(&peer_id)
359 .map(|queue| queue.len())
360 .unwrap_or(0);
361
362 let active_count = self.active.get(&peer_id).map(|map| map.len()).unwrap_or(0);
363
364 available_count + active_count
365 }
366
367 fn get_total_connection_count(&self) -> usize {
369 let available_count: usize = self.available.iter().map(|entry| entry.value().len()).sum();
370
371 let active_count: usize = self.active.iter().map(|entry| entry.value().len()).sum();
372
373 available_count + active_count
374 }
375
376 async fn run_maintenance(&self) {
378 let mut interval = interval(self.config.health_check_interval);
379
380 while !self.shutdown.load(Ordering::Acquire) {
381 interval.tick().await;
382
383 self.cleanup_expired_connections();
385
386 self.maintain_minimum_size().await;
388
389 self.update_pool_stats();
391 }
392 }
393
394 fn cleanup_expired_connections(&self) {
396 for mut entry in self.available.iter_mut() {
397 let peer_id = *entry.key();
398 let queue = entry.value_mut();
399
400 queue.retain(|conn| {
402 if self.is_connection_valid(conn) {
403 true
404 } else {
405 self.destroy_connection(peer_id, conn.clone());
406 false
407 }
408 });
409 }
410 }
411
412 async fn maintain_minimum_size(&self) {
414 let total_count = self.get_total_connection_count();
416
417 if total_count < self.config.min_size {
418 let needed = self.config.min_size - total_count;
419 debug!("Pool below minimum size, creating {} connections", needed);
420
421 for entry in self.available.iter() {
423 let peer_id = *entry.key();
424 for _ in 0..needed {
425 match self.create_connection(peer_id).await {
426 Ok(conn) => {
427 self.available
428 .entry(peer_id)
429 .or_insert_with(VecDeque::new)
430 .push_back(conn);
431 self.increment_created();
432 }
433 Err(e) => {
434 warn!("Failed to create connection during maintenance: {}", e);
435 }
436 }
437 }
438 }
439 }
440 }
441
442 fn update_pool_stats(&self) {
444 let mut stats = self.stats.write();
445
446 stats.current_size = self.get_total_connection_count();
447 stats.available = self.available.iter().map(|entry| entry.value().len()).sum();
448 stats.active = self.active.iter().map(|entry| entry.value().len()).sum();
449
450 if stats.acquisitions > 0 {
452 stats.hit_rate = 1.0 - (stats.failed_acquisitions as f64 / stats.acquisitions as f64);
453 }
454 }
455
456 pub async fn shutdown(&mut self) {
458 self.shutdown.store(true, Ordering::Release);
459
460 if let Some(handle) = self.maintenance_handle.take() {
462 handle.abort();
463 }
464
465 for entry in self.available.iter() {
467 let peer_id = *entry.key();
468 for conn in entry.value().iter() {
469 self.destroy_connection(peer_id, conn.clone());
470 }
471 }
472
473 for entry in self.active.iter() {
474 let peer_id = *entry.key();
475 for (_, conn) in entry.value().iter() {
476 self.destroy_connection(peer_id, conn.clone());
477 }
478 }
479
480 self.available.clear();
482 self.active.clear();
483 self.semaphores.clear();
484 self.waiters.clear();
485 }
486
487 pub fn get_stats(&self) -> PoolStats {
489 self.stats.read().clone()
490 }
491
492 fn increment_created(&self) {
494 self.stats.write().total_created += 1;
495 }
496
497 fn increment_destroyed(&self) {
498 self.stats.write().total_destroyed += 1;
499 }
500
501 fn increment_releases(&self) {
502 self.stats.write().releases += 1;
503 }
504
505 fn increment_timeouts(&self) {
506 self.stats.write().timeouts += 1;
507 }
508
509 fn increment_failed_acquisitions(&self) {
510 self.stats.write().failed_acquisitions += 1;
511 }
512
513 fn update_acquisition_stats(&self, wait_time: Duration) {
514 let mut stats = self.stats.write();
515 stats.acquisitions += 1;
516
517 let alpha = 0.1;
519 let current_avg = stats.avg_wait_time.as_millis() as f64;
520 let new_wait = wait_time.as_millis() as f64;
521 let updated_avg = alpha * new_wait + (1.0 - alpha) * current_avg;
522 stats.avg_wait_time = Duration::from_millis(updated_avg as u64);
523 }
524}
525
526impl Clone for ConnectionPool {
527 fn clone(&self) -> Self {
528 Self {
529 config: self.config.clone(),
530 available: self.available.clone(),
531 active: self.active.clone(),
532 semaphores: self.semaphores.clone(),
533 stats: self.stats.clone(),
534 connection_counter: AtomicUsize::new(self.connection_counter.load(Ordering::Relaxed)),
535 shutdown: AtomicBool::new(self.shutdown.load(Ordering::Relaxed)),
536 waiters: self.waiters.clone(),
537 maintenance_handle: None, }
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);

        let stats = pool.get_stats();
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.available, 0);
        assert_eq!(stats.active, 0);
    }

    #[tokio::test]
    async fn test_connection_acquisition() {
        let config = PoolConfig {
            max_size: 10,
            min_size: 0,
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        // A freshly created connection counts its first checkout.
        let conn = pool.acquire(peer_id).await.unwrap();
        assert_eq!(conn.usage_count, 1);

        let stats = pool.get_stats();
        assert_eq!(stats.acquisitions, 1);
        assert_eq!(stats.total_created, 1);
    }

    #[tokio::test]
    async fn test_connection_release() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        let conn = pool.acquire(peer_id).await.unwrap();
        pool.release(peer_id, conn);

        let stats = pool.get_stats();
        assert_eq!(stats.releases, 1);
        assert_eq!(stats.available, 1);
    }

    #[tokio::test]
    async fn test_connection_reuse() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        let conn1 = pool.acquire(peer_id).await.unwrap();
        let created_at = conn1.created_at;
        pool.release(peer_id, conn1);

        // The second checkout must hand back the same underlying connection.
        let conn2 = pool.acquire(peer_id).await.unwrap();
        assert_eq!(conn2.created_at, created_at);
        assert_eq!(conn2.usage_count, 2);

        let stats = pool.get_stats();
        assert_eq!(stats.total_created, 1);
        assert_eq!(stats.acquisitions, 2);
    }

    #[tokio::test]
    async fn test_pool_limits() {
        let config = PoolConfig {
            max_size: 2,
            acquire_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        let conn1 = pool.acquire(peer_id).await.unwrap();
        let conn2 = pool.acquire(peer_id).await.unwrap();

        // Both slots are taken, so a third checkout must fail.
        let result = pool.acquire(peer_id).await;
        assert!(result.is_err());

        // Releasing one connection frees a slot again.
        pool.release(peer_id, conn1);
        let conn3 = pool.acquire(peer_id).await;
        assert!(conn3.is_ok());

        pool.release(peer_id, conn2);
        pool.release(peer_id, conn3.unwrap());
    }

    #[tokio::test]
    async fn test_checkout_slots_are_returned_on_release() {
        let config = PoolConfig {
            max_size: 1,
            min_size: 0,
            acquire_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        // With a single slot, each acquire after the first only succeeds if the
        // previous release handed the slot back.
        for expected_uses in 1..=3 {
            let conn = pool.acquire(peer_id).await.unwrap();
            assert_eq!(conn.usage_count, expected_uses);
            pool.release(peer_id, conn);
        }

        let stats = pool.get_stats();
        assert_eq!(stats.total_created, 1);
        assert_eq!(stats.acquisitions, 3);
        assert_eq!(stats.timeouts, 0);
    }

    #[tokio::test]
    async fn test_connection_expiration() {
        // `min_size: 0` keeps maintenance from refilling the pool, so expiry is
        // the only thing that can empty the idle queue.
        let config = PoolConfig {
            idle_timeout: Duration::from_millis(100),
            health_check_interval: Duration::from_millis(50),
            min_size: 0,
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();

        let conn = pool.acquire(peer_id).await.unwrap();
        pool.release(peer_id, conn);

        // Give maintenance several ticks to notice the idle timeout.
        sleep(Duration::from_millis(200)).await;

        let stats = pool.get_stats();
        assert_eq!(stats.available, 0);
    }
}