#![deny(unsafe_code)]
use crate::connection::{ConnectionInfo, PooledConnection, WarmingState};
use crate::types::{ConnectionStatus, NetworkError, PeerId};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::{HashMap, VecDeque};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Notify, Semaphore};
use tokio::time::{interval, sleep};
use tracing::{debug, warn};
/// Tuning knobs for a connection pool: size limits, timeouts and
/// per-connection retirement rules.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Connection limit; sizes each peer's checkout semaphore and caps how
    /// many connections the pool keeps around.
    pub max_size: usize,
    /// Number of pooled connections the background maintenance pass tries to
    /// keep in existence.
    pub min_size: usize,
    /// A connection unused for longer than this is treated as invalid.
    pub idle_timeout: Duration,
    /// A connection older than this is treated as invalid regardless of use.
    pub max_lifetime: Duration,
    /// Period between background maintenance passes.
    pub health_check_interval: Duration,
    /// How long an `acquire` call may wait before giving up.
    pub acquire_timeout: Duration,
    /// Whether freshly created connections are warmed before first use.
    pub enable_warming: bool,
    /// Whether idle connections are re-validated before being handed out.
    pub validate_on_checkout: bool,
    /// A connection used this many times is retired instead of reused.
    pub max_reuse_count: u64,
}

impl Default for PoolConfig {
    /// Between 10 and 100 connections, a 5-minute idle timeout, a 1-hour
    /// lifetime, 30-second health checks, a 10-second acquire timeout,
    /// warming and checkout validation switched on, and retirement after
    /// 1000 uses.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        let idle_timeout = Duration::from_secs(5 * MINUTE);
        let max_lifetime = Duration::from_secs(60 * MINUTE);
        let health_check_interval = Duration::from_secs(30);
        let acquire_timeout = Duration::from_secs(10);

        PoolConfig {
            min_size: 10,
            max_size: 100,
            idle_timeout,
            max_lifetime,
            health_check_interval,
            acquire_timeout,
            enable_warming: true,
            validate_on_checkout: true,
            max_reuse_count: 1_000,
        }
    }
}
/// Counters and gauges describing a connection pool, returned as a copy by
/// the pool's `get_stats`.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
/// Connections opened by the pool (checkout misses plus maintenance top-ups).
pub total_created: u64,
/// Connections discarded by the pool (expired, invalid, or shut down).
pub total_destroyed: u64,
/// Idle plus checked-out connections as of the last refresh.
pub current_size: usize,
/// Idle connections as of the last refresh.
pub available: usize,
/// Checked-out connections as of the last refresh.
pub active: usize,
/// Successful `acquire` calls.
pub acquisitions: u64,
/// `release` calls recorded by the pool.
pub releases: u64,
/// `acquire` calls that gave up after obtaining a per-peer slot.
pub failed_acquisitions: u64,
/// `acquire` calls that timed out waiting for a per-peer slot.
pub timeouts: u64,
/// Exponentially weighted moving average (alpha = 0.1) of `acquire` latency.
pub avg_wait_time: Duration,
/// Success ratio derived from the acquisition counters.
pub hit_rate: f64,
}
/// A pool of reusable [`PooledConnection`]s, partitioned by peer.
///
/// Each peer gets its own semaphore with `config.max_size` permits. A permit is
/// held for as long as a connection is checked out and is handed back by
/// [`ConnectionPool::release`]. A connection that is dropped instead of being
/// released keeps its permit, so callers must always release what they acquire.
///
/// Cloning is cheap and yields another handle onto the *same* pool: all state,
/// including the shutdown flag and the checkout-id counter, is shared through
/// `Arc`s. Only the handle returned by [`ConnectionPool::new`] owns the
/// background maintenance task.
pub struct ConnectionPool {
    /// Limits and timeouts applied by every handle.
    config: PoolConfig,
    /// Idle connections per peer, in release order (oldest first).
    available: Arc<DashMap<PeerId, VecDeque<PooledConnection>>>,
    /// Checked-out connections per peer, keyed by a pool-assigned checkout id.
    active: Arc<DashMap<PeerId, HashMap<u64, PooledConnection>>>,
    /// Per-peer checkout limiters; one permit per checked-out connection.
    semaphores: Arc<DashMap<PeerId, Arc<Semaphore>>>,
    /// Counters and gauges reported by [`ConnectionPool::get_stats`].
    stats: Arc<RwLock<PoolStats>>,
    /// Source of checkout ids; shared so clones never hand out the same id.
    connection_counter: Arc<AtomicUsize>,
    /// Set by `shutdown`; shared so every handle (and the maintenance task) sees it.
    shutdown: Arc<AtomicBool>,
    /// Per-peer wake-ups, signalled when an idle connection or a slot frees up.
    waiters: Arc<DashMap<PeerId, Arc<Notify>>>,
    /// Background maintenance task; `Some` only on the handle created by `new`.
    maintenance_handle: Option<tokio::task::JoinHandle<()>>,
}

impl ConnectionPool {
    /// Creates an empty pool and spawns its background maintenance task.
    ///
    /// The task runs every `config.health_check_interval`. It discards expired
    /// idle connections and tops the pool up towards `config.min_size`. It keeps
    /// running until [`ConnectionPool::shutdown`] is called; merely dropping the
    /// pool does not stop it, because the task holds its own handle.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime (`tokio::spawn` requires one).
    pub fn new(config: PoolConfig) -> Self {
        let pool = Self {
            config,
            available: Arc::new(DashMap::new()),
            active: Arc::new(DashMap::new()),
            semaphores: Arc::new(DashMap::new()),
            stats: Arc::new(RwLock::new(PoolStats::default())),
            connection_counter: Arc::new(AtomicUsize::new(0)),
            shutdown: Arc::new(AtomicBool::new(false)),
            waiters: Arc::new(DashMap::new()),
            maintenance_handle: None,
        };
        let maintenance_pool = pool.clone();
        let handle = tokio::spawn(async move {
            maintenance_pool.run_maintenance().await;
        });
        Self {
            maintenance_handle: Some(handle),
            ..pool
        }
    }

    /// Checks out a connection to `peer_id`.
    ///
    /// A valid idle connection is reused when available; otherwise a new one is
    /// created. Either way its `usage_count` is incremented and its `last_used`
    /// is refreshed. The whole call, including waiting for a per-peer slot, is
    /// bounded by a single `config.acquire_timeout` deadline.
    ///
    /// # Errors
    ///
    /// Returns `NetworkError::ConnectionError` in these cases:
    /// - the pool is shut down;
    /// - no per-peer slot frees up before the deadline (counted in `timeouts`);
    /// - the peer's semaphore is closed by a concurrent shutdown;
    /// - the peer stays at its limit until the deadline (counted in `failed_acquisitions`).
    ///
    /// Errors from creating a new connection are also propagated (counted in
    /// `failed_acquisitions`).
    pub async fn acquire(&self, peer_id: PeerId) -> Result<PooledConnection, NetworkError> {
        if self.is_shut_down() {
            return Err(NetworkError::ConnectionError(
                "Pool is shutting down".into(),
            ));
        }
        let start_time = Instant::now();
        // One deadline for the whole call, so retries cannot extend the wait.
        let deadline = start_time + self.config.acquire_timeout;
        let semaphore = self.peer_semaphore(peer_id);
        let permit = tokio::select! {
            result = semaphore.acquire() => {
                result.map_err(|_| NetworkError::ConnectionError("Semaphore closed".into()))?
            }
            _ = sleep(deadline.saturating_duration_since(Instant::now())) => {
                self.increment_timeouts();
                return Err(NetworkError::ConnectionError("Connection acquisition timeout".into()));
            }
        };
        loop {
            if self.is_shut_down() {
                return Err(NetworkError::ConnectionError(
                    "Pool is shutting down".into(),
                ));
            }
            if let Some(conn) = self.take_idle_connection(peer_id).await {
                // The slot now travels with the connection; `release` hands it back.
                permit.forget();
                return Ok(self.check_out(peer_id, conn, start_time));
            }
            if self.get_peer_connection_count(peer_id) < self.config.max_size {
                let conn = match self.create_connection(peer_id).await {
                    Ok(conn) => conn,
                    Err(err) => {
                        // `permit` is dropped here, returning the slot.
                        self.increment_failed_acquisitions();
                        return Err(err);
                    }
                };
                self.increment_created();
                permit.forget();
                return Ok(self.check_out(peer_id, conn, start_time));
            }
            // Only reachable when idle connections raced in after the queue was
            // drained; wait for the next release/top-up, then retry.
            let waiter = self.peer_waiter(peer_id);
            tokio::select! {
                _ = waiter.notified() => {}
                _ = sleep(deadline.saturating_duration_since(Instant::now())) => {
                    self.increment_failed_acquisitions();
                    return Err(NetworkError::ConnectionError("No available connections".into()));
                }
            }
        }
    }

    /// Returns a connection obtained from [`ConnectionPool::acquire`] to the pool.
    ///
    /// The connection is removed from the checked-out set. It is then kept for
    /// reuse if it is still valid and the peer has room; otherwise it is
    /// destroyed. In both cases the peer's checkout slot is freed. After
    /// shutdown this is a no-op (the connection is simply dropped).
    pub fn release(&self, peer_id: PeerId, mut connection: PooledConnection) {
        if self.is_shut_down() {
            return;
        }
        let was_checked_out = self.remove_active(peer_id, &connection);
        connection.last_used = Instant::now();
        if self.should_keep_connection(peer_id, &connection) {
            self.available
                .entry(peer_id)
                .or_insert_with(VecDeque::new)
                .push_back(connection);
        } else {
            self.destroy_connection(peer_id, connection);
        }
        // Free the slot only after the idle queue is updated, so a woken
        // acquirer can find the connection.
        if was_checked_out {
            self.return_permit(peer_id);
        }
        self.notify_waiter(peer_id);
        self.increment_releases();
    }

    /// Whether `shutdown` has been requested through any handle.
    fn is_shut_down(&self) -> bool {
        self.shutdown.load(Ordering::Acquire)
    }

    /// The checkout semaphore for `peer_id`, created with `max_size` permits on first use.
    fn peer_semaphore(&self, peer_id: PeerId) -> Arc<Semaphore> {
        self.semaphores
            .entry(peer_id)
            .or_insert_with(|| Arc::new(Semaphore::new(self.config.max_size)))
            .clone()
    }

    /// The wake-up handle for `peer_id`, created on first use.
    fn peer_waiter(&self, peer_id: PeerId) -> Arc<Notify> {
        self.waiters
            .entry(peer_id)
            .or_insert_with(|| Arc::new(Notify::new()))
            .clone()
    }

    /// Wakes one task waiting on `peer_id`, if anyone ever registered to wait.
    fn notify_waiter(&self, peer_id: PeerId) {
        if let Some(waiter) = self.waiters.get(&peer_id) {
            waiter.notify_one();
        }
    }

    /// Pops idle connections for `peer_id` until one is usable.
    ///
    /// Connections that have expired, hit their reuse limit, or fail checkout
    /// validation are destroyed along the way. The map guard is released
    /// before any `.await`, so validation never blocks other users of the map.
    async fn take_idle_connection(&self, peer_id: PeerId) -> Option<PooledConnection> {
        loop {
            let conn = self
                .available
                .get_mut(&peer_id)
                .and_then(|mut queue| queue.pop_front())?;
            if !self.is_connection_valid(&conn) {
                self.destroy_connection(peer_id, conn);
                continue;
            }
            if self.config.validate_on_checkout && !self.validate_connection(&conn).await {
                self.destroy_connection(peer_id, conn);
                continue;
            }
            return Some(conn);
        }
    }

    /// Marks `conn` as checked out: bumps its usage and records it in `active`.
    fn check_out(
        &self,
        peer_id: PeerId,
        mut conn: PooledConnection,
        start_time: Instant,
    ) -> PooledConnection {
        conn.last_used = Instant::now();
        conn.usage_count += 1;
        let conn_id = self.connection_counter.fetch_add(1, Ordering::Relaxed) as u64;
        self.active
            .entry(peer_id)
            .or_insert_with(HashMap::new)
            .insert(conn_id, conn.clone());
        self.update_acquisition_stats(start_time.elapsed());
        conn
    }

    /// Removes the checked-out record matching `conn` for `peer_id`.
    ///
    /// Connections carry no checkout id, so the record is matched on
    /// `created_at`. Returns `true` if a record was removed, i.e. the
    /// connection really was checked out and owns a slot.
    fn remove_active(&self, peer_id: PeerId, conn: &PooledConnection) -> bool {
        match self.active.get_mut(&peer_id) {
            Some(mut checked_out) => {
                let matching_id = checked_out
                    .iter()
                    .find(|(_, tracked)| tracked.created_at == conn.created_at)
                    .map(|(conn_id, _)| *conn_id);
                match matching_id {
                    Some(conn_id) => checked_out.remove(&conn_id).is_some(),
                    None => false,
                }
            }
            None => false,
        }
    }

    /// Gives one checkout slot back to `peer_id`'s semaphore.
    fn return_permit(&self, peer_id: PeerId) {
        if let Some(semaphore) = self.semaphores.get(&peer_id) {
            semaphore.add_permits(1);
        }
    }

    /// Checkout-time validation hook; currently only checks the health flag.
    async fn validate_connection(&self, conn: &PooledConnection) -> bool {
        conn.info.is_healthy()
    }

    /// Whether `conn` is within its lifetime, idle timeout and reuse budget, and healthy.
    fn is_connection_valid(&self, conn: &PooledConnection) -> bool {
        if conn.created_at.elapsed() > self.config.max_lifetime {
            return false;
        }
        if conn.last_used.elapsed() > self.config.idle_timeout {
            return false;
        }
        if conn.usage_count >= self.config.max_reuse_count {
            return false;
        }
        conn.info.is_healthy()
    }

    /// Whether a released `conn` should go back to the idle queue.
    ///
    /// The limit is per peer, matching the per-peer semaphore and the check in `acquire`.
    fn should_keep_connection(&self, peer_id: PeerId, conn: &PooledConnection) -> bool {
        self.is_connection_valid(conn) && self.get_peer_connection_count(peer_id) < self.config.max_size
    }

    /// Opens a new connection (optionally warmed) with zero recorded uses.
    async fn create_connection(&self, _peer_id: PeerId) -> Result<PooledConnection, NetworkError> {
        let info = ConnectionInfo::new(ConnectionStatus::Connected);
        let mut conn = PooledConnection {
            info,
            created_at: Instant::now(),
            last_used: Instant::now(),
            usage_count: 0,
            weight: 1.0,
            max_streams: 100,
            active_streams: 0,
            warming_state: WarmingState::Cold,
            affinity_group: None,
        };
        if self.config.enable_warming {
            self.warm_connection(&mut conn).await?;
        }
        Ok(conn)
    }

    /// Moves `conn` through the warming states; the warm-up itself is a fixed 50 ms delay.
    async fn warm_connection(&self, conn: &mut PooledConnection) -> Result<(), NetworkError> {
        conn.warming_state = WarmingState::Warming;
        sleep(Duration::from_millis(50)).await;
        conn.warming_state = WarmingState::Warm;
        Ok(())
    }

    /// Discards a connection and records it in `total_destroyed`.
    fn destroy_connection(&self, _peer_id: PeerId, _conn: PooledConnection) {
        self.increment_destroyed();
    }

    /// Idle plus checked-out connections for one peer.
    fn get_peer_connection_count(&self, peer_id: PeerId) -> usize {
        let available_count = self
            .available
            .get(&peer_id)
            .map(|queue| queue.len())
            .unwrap_or(0);
        let active_count = self.active.get(&peer_id).map(|map| map.len()).unwrap_or(0);
        available_count + active_count
    }

    /// Idle plus checked-out connections across all peers.
    fn get_total_connection_count(&self) -> usize {
        let available_count: usize = self.available.iter().map(|entry| entry.value().len()).sum();
        let active_count: usize = self.active.iter().map(|entry| entry.value().len()).sum();
        available_count + active_count
    }

    /// Background loop: expire idle connections, top up to `min_size`, refresh stats.
    async fn run_maintenance(&self) {
        // `tokio::time::interval` panics on a zero period.
        let period = self.config.health_check_interval.max(Duration::from_millis(1));
        let mut ticker = interval(period);
        while !self.is_shut_down() {
            ticker.tick().await;
            if self.is_shut_down() {
                break;
            }
            self.cleanup_expired_connections();
            self.maintain_minimum_size().await;
            self.update_pool_stats();
        }
    }

    /// Destroys idle connections that are no longer valid, preserving queue order.
    fn cleanup_expired_connections(&self) {
        for mut entry in self.available.iter_mut() {
            let peer_id = *entry.key();
            let queue = entry.value_mut();
            let (keep, expired): (VecDeque<PooledConnection>, Vec<PooledConnection>) =
                queue.drain(..).partition(|conn| self.is_connection_valid(conn));
            *queue = keep;
            for conn in expired {
                self.destroy_connection(peer_id, conn);
            }
        }
    }

    /// Creates idle connections for already-known peers until the pool holds
    /// `min_size` connections in total, never exceeding a peer's `max_size`.
    async fn maintain_minimum_size(&self) {
        let total_count = self.get_total_connection_count();
        if total_count >= self.config.min_size {
            return;
        }
        let mut needed = self.config.min_size - total_count;
        debug!("Pool below minimum size, creating {} connections", needed);
        // Snapshot the peers so no map guard is held across `.await` or while
        // inserting into the same map (which would deadlock on the shard lock).
        let peers: Vec<PeerId> = self.available.iter().map(|entry| *entry.key()).collect();
        for peer_id in peers {
            while needed > 0 && self.get_peer_connection_count(peer_id) < self.config.max_size {
                let created = self.create_connection(peer_id).await;
                if self.is_shut_down() {
                    return;
                }
                match created {
                    Ok(conn) => {
                        self.available
                            .entry(peer_id)
                            .or_insert_with(VecDeque::new)
                            .push_back(conn);
                        self.increment_created();
                        self.notify_waiter(peer_id);
                        needed -= 1;
                    }
                    Err(e) => {
                        warn!("Failed to create connection during maintenance: {}", e);
                        // Don't hammer a peer that is failing; try the next one.
                        break;
                    }
                }
            }
            if needed == 0 {
                break;
            }
        }
    }

    /// Refreshes the size gauges and `hit_rate` in the shared stats.
    fn update_pool_stats(&self) {
        // Count before taking the stats lock: other paths take map locks and
        // then the stats lock, so holding them in the reverse order could deadlock.
        let available: usize = self.available.iter().map(|entry| entry.value().len()).sum();
        let active: usize = self.active.iter().map(|entry| entry.value().len()).sum();
        let mut stats = self.stats.write();
        stats.available = available;
        stats.active = active;
        stats.current_size = available + active;
        // `acquisitions` only counts successes, so divide by all attempts to
        // keep the ratio within [0, 1].
        let attempts = stats.acquisitions + stats.failed_acquisitions + stats.timeouts;
        if attempts > 0 {
            stats.hit_rate = stats.acquisitions as f64 / attempts as f64;
        }
    }

    /// Shuts the pool down for every handle.
    ///
    /// Stops the maintenance task (when called on the handle that owns it) and
    /// fails pending and future `acquire` calls. It destroys all idle and
    /// checked-out connections and clears the per-peer bookkeeping.
    pub async fn shutdown(&mut self) {
        self.shutdown.store(true, Ordering::Release);
        if let Some(handle) = self.maintenance_handle.take() {
            handle.abort();
        }
        // Wake acquirers blocked on a slot or a notification so they fail fast.
        for entry in self.semaphores.iter() {
            entry.value().close();
        }
        for entry in self.waiters.iter() {
            entry.value().notify_waiters();
        }
        self.available.retain(|peer_id, queue| {
            for conn in queue.drain(..) {
                self.destroy_connection(*peer_id, conn);
            }
            false
        });
        self.active.retain(|peer_id, checked_out| {
            for (_, conn) in checked_out.drain() {
                self.destroy_connection(*peer_id, conn);
            }
            false
        });
        self.semaphores.clear();
        self.waiters.clear();
    }

    /// Returns a snapshot of the pool statistics with freshly computed gauges.
    pub fn get_stats(&self) -> PoolStats {
        self.update_pool_stats();
        self.stats.read().clone()
    }

    fn increment_created(&self) {
        self.stats.write().total_created += 1;
    }

    fn increment_destroyed(&self) {
        self.stats.write().total_destroyed += 1;
    }

    fn increment_releases(&self) {
        self.stats.write().releases += 1;
    }

    fn increment_timeouts(&self) {
        self.stats.write().timeouts += 1;
    }

    fn increment_failed_acquisitions(&self) {
        self.stats.write().failed_acquisitions += 1;
    }

    /// Counts a successful acquisition and folds its wait time into the EWMA.
    fn update_acquisition_stats(&self, wait_time: Duration) {
        const ALPHA: f64 = 0.1;
        let mut stats = self.stats.write();
        stats.acquisitions += 1;
        // Work in fractional seconds so sub-millisecond waits are not truncated to zero.
        let sample = wait_time.as_secs_f64();
        let updated_avg = if stats.acquisitions == 1 {
            // Seed with the first sample instead of biasing towards zero.
            sample
        } else {
            ALPHA * sample + (1.0 - ALPHA) * stats.avg_wait_time.as_secs_f64()
        };
        stats.avg_wait_time = Duration::from_secs_f64(updated_avg);
    }
}

impl Clone for ConnectionPool {
    /// Returns another handle onto the same pool state; the clone does not own
    /// the maintenance task.
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            available: Arc::clone(&self.available),
            active: Arc::clone(&self.active),
            semaphores: Arc::clone(&self.semaphores),
            stats: Arc::clone(&self.stats),
            connection_counter: Arc::clone(&self.connection_counter),
            shutdown: Arc::clone(&self.shutdown),
            waiters: Arc::clone(&self.waiters),
            maintenance_handle: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);
        let stats = pool.get_stats();
        assert_eq!(stats.current_size, 0);
        assert_eq!(stats.available, 0);
        assert_eq!(stats.active, 0);
    }

    #[tokio::test]
    async fn test_connection_acquisition() {
        let config = PoolConfig {
            max_size: 10,
            min_size: 0,
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();
        let conn = pool.acquire(peer_id).await.unwrap();
        assert_eq!(conn.usage_count, 1);
        let stats = pool.get_stats();
        assert_eq!(stats.acquisitions, 1);
        assert_eq!(stats.total_created, 1);
    }

    #[tokio::test]
    async fn test_connection_release() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();
        let conn = pool.acquire(peer_id).await.unwrap();
        pool.release(peer_id, conn);
        let stats = pool.get_stats();
        assert_eq!(stats.releases, 1);
        assert_eq!(stats.available, 1);
    }

    #[tokio::test]
    async fn test_connection_reuse() {
        let config = PoolConfig::default();
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();
        let conn1 = pool.acquire(peer_id).await.unwrap();
        let created_at = conn1.created_at;
        pool.release(peer_id, conn1);
        let conn2 = pool.acquire(peer_id).await.unwrap();
        assert_eq!(conn2.created_at, created_at);
        assert_eq!(conn2.usage_count, 2);
        let stats = pool.get_stats();
        assert_eq!(stats.total_created, 1);
        assert_eq!(stats.acquisitions, 2);
    }

    #[tokio::test]
    async fn test_pool_limits() {
        let config = PoolConfig {
            max_size: 2,
            // Keep the minimum at or below the maximum so the config is coherent.
            min_size: 0,
            acquire_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();
        let conn1 = pool.acquire(peer_id).await.unwrap();
        let conn2 = pool.acquire(peer_id).await.unwrap();
        let result = pool.acquire(peer_id).await;
        assert!(result.is_err());
        pool.release(peer_id, conn1);
        let conn3 = pool.acquire(peer_id).await;
        assert!(conn3.is_ok());
        pool.release(peer_id, conn2);
        pool.release(peer_id, conn3.unwrap());
    }

    #[tokio::test]
    async fn test_connection_expiration() {
        let config = PoolConfig {
            // A non-zero minimum would make maintenance refill the pool right
            // after expiring the idle connection, contradicting the assertion.
            min_size: 0,
            idle_timeout: Duration::from_millis(100),
            health_check_interval: Duration::from_millis(50),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();
        let conn = pool.acquire(peer_id).await.unwrap();
        pool.release(peer_id, conn);
        sleep(Duration::from_millis(200)).await;
        let stats = pool.get_stats();
        assert_eq!(stats.available, 0);
    }

    /// Releasing must hand the per-peer slot back, so a peer with a single
    /// slot can be acquired repeatedly.
    #[tokio::test]
    async fn test_release_returns_capacity() {
        let config = PoolConfig {
            max_size: 1,
            min_size: 0,
            acquire_timeout: Duration::from_millis(100),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config);
        let peer_id = PeerId::random();
        for _ in 0..3 {
            let conn = pool.acquire(peer_id).await.unwrap();
            pool.release(peer_id, conn);
        }
        let stats = pool.get_stats();
        assert_eq!(stats.active, 0);
        assert_eq!(stats.available, 1);
        assert_eq!(stats.total_created, 1);
        assert_eq!(stats.acquisitions, 3);
    }

    /// A shutdown through one handle must be observed by its clones.
    #[tokio::test]
    async fn test_shutdown_is_shared_with_clones() {
        let mut pool = ConnectionPool::new(PoolConfig {
            min_size: 0,
            ..Default::default()
        });
        let handle = pool.clone();
        pool.shutdown().await;
        assert!(handle.acquire(PeerId::random()).await.is_err());
    }
}