use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Semaphore};
use tracing::{debug, error, warn};
use crate::error::{ProtocolError, Result};
/// Tuning parameters for a [`ConnectionPool`]; checked by [`PoolConfig::validate`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Connections a background task tries to create when the pool is built (0 = no warm-up).
    pub min_size: usize,
    /// Upper size bound; must be non-zero, at least `min_size` and at most 10,000.
    ///
    /// NOTE(review): this does not limit how many connections may be checked out concurrently.
    pub max_size: usize,
    /// Idle connections unused for longer than this are evicted instead of reused.
    pub idle_timeout: Duration,
    /// Pool entries whose recorded creation time is older than this are evicted instead of reused.
    pub max_lifetime: Duration,
    /// Permits of the semaphore bounding concurrent `acquire` calls; extra callers wait.
    pub max_waiters: usize,
    /// Consecutive connection-creation failures that open the circuit breaker.
    pub circuit_breaker_threshold: usize,
    /// How long an open circuit breaker rejects requests before letting them through again.
    pub circuit_breaker_timeout: Duration,
}

impl Default for PoolConfig {
    /// 5–50 connections, 5-minute idle timeout, 1-hour lifetime, 1,000 waiters, and a breaker
    /// that opens after 5 consecutive failures for 10 seconds.
    fn default() -> Self {
        const IDLE: Duration = Duration::from_secs(5 * 60);
        const LIFETIME: Duration = Duration::from_secs(60 * 60);
        const BREAKER_COOLDOWN: Duration = Duration::from_secs(10);
        Self {
            min_size: 5,
            max_size: 50,
            max_waiters: 1_000,
            idle_timeout: IDLE,
            max_lifetime: LIFETIME,
            circuit_breaker_threshold: 5,
            circuit_breaker_timeout: BREAKER_COOLDOWN,
        }
    }
}
impl PoolConfig {
    /// Checks every field and reports *all* violated rules at once rather than the first one.
    ///
    /// Rules are checked in a fixed order: sizing, then timeouts, then the circuit breaker.
    /// "Recommended" limits are enforced as hard errors.
    ///
    /// # Errors
    /// Returns `ProtocolError::ConfigError` whose message is
    /// `"Pool configuration validation failed:"` followed by one `"\n - <problem>"` line per
    /// violated rule.
    pub fn validate(&self) -> Result<()> {
        let mut problems: Vec<String> = Vec::new();
        self.check_sizes(&mut problems);
        self.check_timeouts(&mut problems);
        self.check_circuit_breaker(&mut problems);

        if !problems.is_empty() {
            let details = problems.join("\n - ");
            return Err(ProtocolError::ConfigError(format!(
                "Pool configuration validation failed:\n - {}",
                details
            )));
        }
        Ok(())
    }

    /// Appends problems with `max_size`, `min_size` and `max_waiters`.
    fn check_sizes(&self, problems: &mut Vec<String>) {
        let (min, max, waiters) = (self.min_size, self.max_size, self.max_waiters);
        if max == 0 {
            problems.push(String::from("Pool max_size must be greater than 0"));
        }
        if min > max {
            problems.push(format!(
                "Pool min_size ({}) cannot exceed max_size ({})",
                min, max
            ));
        }
        if max > 10_000 {
            problems.push(format!(
                "Pool max_size ({}) exceeds recommended limit (10,000)",
                max
            ));
        }
        if waiters == 0 {
            problems.push(String::from("Pool max_waiters must be greater than 0"));
        }
        if waiters > 1_000_000 {
            problems.push(format!(
                "Pool max_waiters ({}) exceeds recommended limit (1,000,000)",
                waiters
            ));
        }
    }

    /// Appends problems with `idle_timeout` and `max_lifetime`.
    fn check_timeouts(&self, problems: &mut Vec<String>) {
        let (idle, lifetime) = (self.idle_timeout, self.max_lifetime);
        if idle.is_zero() {
            problems.push(String::from("Pool idle_timeout must be greater than 0"));
        }
        if lifetime.is_zero() {
            problems.push(String::from("Pool max_lifetime must be greater than 0"));
        }
        if idle >= lifetime {
            problems.push(format!(
                "Pool idle_timeout ({:?}) should be less than max_lifetime ({:?})",
                idle, lifetime
            ));
        }
        let idle_secs = idle.as_secs();
        if idle_secs > 3600 {
            problems.push(format!(
                "Pool idle_timeout ({} seconds) is unusually long (recommended: < 1 hour)",
                idle_secs
            ));
        }
        let lifetime_secs = lifetime.as_secs();
        if lifetime_secs > 86400 {
            problems.push(format!(
                "Pool max_lifetime ({} seconds) is unusually long (recommended: < 24 hours)",
                lifetime_secs
            ));
        }
    }

    /// Appends problems with the circuit-breaker threshold and timeout.
    fn check_circuit_breaker(&self, problems: &mut Vec<String>) {
        let threshold = self.circuit_breaker_threshold;
        if threshold == 0 {
            problems.push(String::from("Circuit breaker threshold must be greater than 0"));
        }
        if threshold > 100 {
            problems.push(format!(
                "Circuit breaker threshold ({}) is unusually high (recommended: < 100)",
                threshold
            ));
        }
        let cooldown = self.circuit_breaker_timeout;
        if cooldown.is_zero() {
            problems.push(String::from("Circuit breaker timeout must be greater than 0"));
        }
        if cooldown.as_secs() > 300 {
            problems.push(format!(
                "Circuit breaker timeout ({} seconds) is unusually long (recommended: < 5 minutes)",
                cooldown.as_secs()
            ));
        }
    }
}
/// A pool entry: the connection plus the timestamps used to decide when it must be retired.
struct PooledConnection<T> {
    connection: T,
    created_at: Instant,
    last_used_at: Instant,
}

impl<T> PooledConnection<T> {
    /// Wraps `connection`, stamping both its creation and last-use time with the current instant.
    fn new(connection: T) -> Self {
        let stamp = Instant::now();
        Self {
            connection,
            last_used_at: stamp,
            created_at: stamp,
        }
    }

    /// Returns `true` once the entry is older than `config.max_lifetime` or has been idle longer
    /// than `config.idle_timeout`. Both comparisons are strict (`>`).
    fn is_expired(&self, config: &PoolConfig) -> bool {
        let now = Instant::now();
        let age = now.duration_since(self.created_at);
        let idle_for = now.duration_since(self.last_used_at);
        age > config.max_lifetime || idle_for > config.idle_timeout
    }

    /// Restarts the idle clock by recording "now" as the last-use time.
    fn touch(&mut self) {
        self.last_used_at = Instant::now();
    }
}
/// Source of connections for a [`ConnectionPool`], plus an optional health check for idle ones.
///
/// The pool keeps the factory in an `Arc` and also calls it from a spawned warm-up task, which
/// is why implementors must be `Send + Sync`.
pub trait ConnectionFactory<T>: Send + Sync {
    /// Starts creating a new connection. The boxed future is `Send` so it can be awaited from
    /// spawned tasks; an `Err` is propagated to the `acquire` caller (or logged during warm-up).
    fn create(&self) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send>>;
    /// Called on an idle connection before it is handed out again; returning `false` makes the
    /// pool evict it. The pool invokes this while holding its idle-queue lock, so it should be
    /// cheap and non-blocking. The default accepts every connection.
    fn is_healthy(&self, _conn: &T) -> bool {
        true
    }
}
/// Lock-free counters and gauges describing pool activity.
///
/// Readers load each field independently with `Ordering::Relaxed`, so values read from several
/// fields are not guaranteed to form a consistent snapshot.
#[derive(Debug, Default)]
pub struct PoolMetrics {
    /// Connections produced by the factory, including background warm-up.
    pub connections_created: AtomicU64,
    /// Acquisitions served from the idle queue.
    pub connections_reused: AtomicU64,
    /// Idle connections discarded at acquire time for being expired or unhealthy.
    pub connections_evicted: AtomicU64,
    /// Acquisitions that failed because the factory returned an error.
    pub acquisition_errors: AtomicU64,
    /// Gauge: connections currently checked out.
    pub active_connections: AtomicUsize,
    /// Gauge: connections currently waiting in the idle queue.
    pub idle_connections: AtomicUsize,
    /// Sum of the wait times of successful acquisitions, in microseconds.
    pub total_wait_time_us: AtomicU64,
    /// Number of successful acquisitions.
    pub total_acquisitions: AtomicU64,
}

impl PoolMetrics {
    /// Returns a metrics set with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Mean wait per successful acquisition in whole microseconds (integer division).
    ///
    /// Returns 0 before the first acquisition.
    pub fn average_wait_time_us(&self) -> u64 {
        match self.total_acquisitions.load(Ordering::Relaxed) {
            0 => 0,
            acquisitions => self.total_wait_time_us.load(Ordering::Relaxed) / acquisitions,
        }
    }

    /// Share of tracked connections that are checked out, as a percentage in `0.0..=100.0`.
    ///
    /// Returns 0.0 when no connections are tracked at all.
    pub fn utilization_percent(&self) -> f64 {
        let busy = self.active_connections.load(Ordering::Relaxed) as f64;
        let tracked = busy + self.idle_connections.load(Ordering::Relaxed) as f64;
        if tracked == 0.0 {
            0.0
        } else {
            busy / tracked * 100.0
        }
    }
}
/// Consecutive-failure circuit breaker guarding connection creation.
///
/// - Closed: `opened_at` is `None`; calls pass and failures are counted.
/// - Open: `opened_at` is `Some(t)` and less than `timeout` has elapsed since `t`; `check` fails
///   fast with `ProtocolError::CircuitBreakerOpen`.
/// - Half-open: the timeout has elapsed. Calls pass again, but the failure count is kept at
///   `threshold - 1` or higher, so a single further failure re-opens the breaker. A success
///   resets the count and closes it.
#[derive(Debug)]
struct CircuitBreaker {
    /// Failures recorded since the last success.
    consecutive_failures: AtomicUsize,
    /// Failure count at which the breaker opens.
    threshold: usize,
    /// How long the breaker stays open before moving to half-open.
    timeout: Duration,
    /// `Some(t)` while open, where `t` is the instant it (re-)opened.
    opened_at: Mutex<Option<Instant>>,
}

impl CircuitBreaker {
    /// Creates a closed breaker that opens after `threshold` consecutive failures and stays
    /// open for `timeout`.
    fn new(threshold: usize, timeout: Duration) -> Self {
        Self {
            consecutive_failures: AtomicUsize::new(0),
            threshold,
            timeout,
            opened_at: Mutex::new(None),
        }
    }

    /// Fails with `ProtocolError::CircuitBreakerOpen` while the breaker is open.
    ///
    /// Once `timeout` has elapsed, the breaker moves to half-open and the call is allowed.
    async fn check(&self) -> Result<()> {
        let mut opened_at = self.opened_at.lock().await;
        if let Some(opened_time) = *opened_at {
            if opened_time.elapsed() < self.timeout {
                return Err(ProtocolError::CircuitBreakerOpen);
            }
            *opened_at = None;
            // Half-open: do NOT reset the counter to zero. Otherwise a backend that is still down
            // would need `threshold` more failures before the breaker trips again. `fetch_max`
            // guarantees the next failure reaches `threshold`, even if a racing success/failure
            // left the counter lower while the breaker was open.
            self.consecutive_failures
                .fetch_max(self.threshold.saturating_sub(1), Ordering::SeqCst);
            debug!("Circuit breaker entering half-open state");
        }
        Ok(())
    }

    /// Resets the failure count and closes the breaker if it was open.
    async fn record_success(&self) {
        self.consecutive_failures.store(0, Ordering::SeqCst);
        let mut opened_at = self.opened_at.lock().await;
        if opened_at.is_some() {
            *opened_at = None;
            debug!("Circuit breaker closed after successful operation");
        }
    }

    /// Counts a failure and (re-)opens the breaker once the count reaches `threshold`.
    async fn record_failure(&self) {
        let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
        if failures >= self.threshold {
            let mut opened_at = self.opened_at.lock().await;
            *opened_at = Some(Instant::now());
            error!(
                "Circuit breaker opened after {} consecutive failures",
                failures
            );
        }
    }
}
/// Asynchronous pool of reusable connections produced by a [`ConnectionFactory`].
///
/// [`ConnectionPool::acquire`] reuses the most recently returned idle connection that is still
/// within its lifetime/idle limits and reported healthy, or asks the factory for a new one. The
/// connection is handed out inside a [`PooledConnectionGuard`] that returns it to the pool when
/// dropped.
///
/// NOTE(review): `max_size` only caps how many *idle* connections are retained; nothing limits
/// how many connections are checked out at the same time.
pub struct ConnectionPool<T> {
    /// Validated configuration.
    config: PoolConfig,
    /// Creates new connections and vets idle ones before reuse.
    factory: Arc<dyn ConnectionFactory<T>>,
    /// Idle connections: returns push to the back and reuse pops from the back (LIFO).
    connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
    /// Counters shared with every guard handed out by this pool.
    metrics: Arc<PoolMetrics>,
    /// Makes `acquire` fail fast after repeated connection-creation failures.
    circuit_breaker: Arc<CircuitBreaker>,
    /// `max_waiters` permits. Each `acquire` call holds one until it returns, so at most
    /// `max_waiters` calls proceed concurrently and the rest wait for a permit.
    backpressure: Arc<Semaphore>,
}

impl<T: Send + 'static> ConnectionPool<T> {
    /// Validates `config` and builds a pool.
    ///
    /// If `config.min_size > 0`, a background task is spawned that creates up to `min_size` idle
    /// connections. It stops at the first factory error, which is logged rather than returned.
    ///
    /// # Errors
    /// Returns the error produced by [`PoolConfig::validate`].
    ///
    /// # Panics
    /// Panics when `config.min_size > 0` and no Tokio runtime is running, because the warm-up
    /// task is started with `tokio::spawn`.
    pub fn new(factory: Arc<dyn ConnectionFactory<T>>, config: PoolConfig) -> Result<Self> {
        config.validate()?;
        let metrics = Arc::new(PoolMetrics::new());
        let connections = Arc::new(Mutex::new(VecDeque::new()));
        let circuit_breaker = Arc::new(CircuitBreaker::new(
            config.circuit_breaker_threshold,
            config.circuit_breaker_timeout,
        ));
        let backpressure = Arc::new(Semaphore::new(config.max_waiters));
        if config.min_size > 0 {
            Self::spawn_warmup(
                Arc::clone(&factory),
                Arc::clone(&connections),
                Arc::clone(&metrics),
                config.min_size,
            );
        }
        Ok(Self {
            config,
            factory,
            connections,
            metrics,
            circuit_breaker,
            backpressure,
        })
    }

    /// Spawns the background task that pre-fills the idle queue with up to `count` connections.
    fn spawn_warmup(
        factory: Arc<dyn ConnectionFactory<T>>,
        connections: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
        metrics: Arc<PoolMetrics>,
        count: usize,
    ) {
        tokio::spawn(async move {
            debug!("Warming connection pool with {} connections", count);
            for _ in 0..count {
                match factory.create().await {
                    Ok(conn) => {
                        let mut idle = connections.lock().await;
                        idle.push_back(PooledConnection::new(conn));
                        metrics.connections_created.fetch_add(1, Ordering::Relaxed);
                        metrics.idle_connections.fetch_add(1, Ordering::Relaxed);
                    }
                    Err(e) => {
                        warn!("Failed to warm connection: {}", e);
                        break;
                    }
                }
            }
            debug!("Connection pool warming complete");
        });
    }

    /// Checks out a connection, reusing an idle one when possible.
    ///
    /// Expired or unhealthy idle connections found along the way are evicted. The circuit
    /// breaker is consulted before the idle queue, so an open breaker also blocks reuse.
    ///
    /// # Errors
    /// - `ProtocolError::PoolExhausted` if the backpressure semaphore has been closed.
    /// - `ProtocolError::CircuitBreakerOpen` while the circuit breaker is open.
    /// - Any error returned by the factory when a new connection has to be created.
    pub async fn acquire(&self) -> Result<PooledConnectionGuard<T>> {
        let start = Instant::now();
        // Held until this function returns: bounds concurrent acquirers, not live connections.
        let _permit = self
            .backpressure
            .acquire()
            .await
            .map_err(|_| ProtocolError::PoolExhausted)?;
        self.circuit_breaker.check().await?;

        if let Some(pooled) = self.take_idle().await {
            self.metrics
                .connections_reused
                .fetch_add(1, Ordering::Relaxed);
            debug!("Reused connection from pool (LIFO)");
            return Ok(self.checkout(pooled.connection, pooled.created_at, start));
        }

        match self.factory.create().await {
            Ok(new_conn) => {
                self.circuit_breaker.record_success().await;
                self.metrics
                    .connections_created
                    .fetch_add(1, Ordering::Relaxed);
                debug!("Created new connection for pool");
                Ok(self.checkout(new_conn, Instant::now(), start))
            }
            Err(e) => {
                self.circuit_breaker.record_failure().await;
                self.metrics
                    .acquisition_errors
                    .fetch_add(1, Ordering::Relaxed);
                Err(e)
            }
        }
    }

    /// Pops idle connections (most recently returned first) until one is neither expired nor
    /// unhealthy. Every entry discarded along the way is counted as evicted.
    async fn take_idle(&self) -> Option<PooledConnection<T>> {
        let mut idle = self.connections.lock().await;
        while let Some(mut pooled) = idle.pop_back() {
            self.metrics.idle_connections.fetch_sub(1, Ordering::Relaxed);
            if !pooled.is_expired(&self.config) && self.factory.is_healthy(&pooled.connection) {
                pooled.touch();
                return Some(pooled);
            }
            debug!("Evicted expired/unhealthy connection from pool");
            self.metrics
                .connections_evicted
                .fetch_add(1, Ordering::Relaxed);
        }
        None
    }

    /// Records a successful acquisition and wraps `connection` in a guard bound to this pool.
    ///
    /// `created_at` is carried by the guard so `max_lifetime` still applies after the
    /// connection is returned.
    fn checkout(
        &self,
        connection: T,
        created_at: Instant,
        start: Instant,
    ) -> PooledConnectionGuard<T> {
        let waited_us = u64::try_from(start.elapsed().as_micros()).unwrap_or(u64::MAX);
        self.metrics
            .active_connections
            .fetch_add(1, Ordering::Relaxed);
        self.metrics
            .total_acquisitions
            .fetch_add(1, Ordering::Relaxed);
        self.metrics
            .total_wait_time_us
            .fetch_add(waited_us, Ordering::Relaxed);
        PooledConnectionGuard {
            connection: Some(connection),
            created_at,
            max_idle: self.config.max_size,
            pool: Arc::clone(&self.connections),
            metrics: Arc::clone(&self.metrics),
        }
    }

    /// Returns a shared handle to this pool's metrics.
    pub fn metrics(&self) -> Arc<PoolMetrics> {
        Arc::clone(&self.metrics)
    }

    /// Number of idle connections currently queued (checked-out connections are not counted).
    pub async fn size(&self) -> usize {
        self.connections.lock().await.len()
    }

    /// Drops every idle connection currently queued.
    ///
    /// Checked-out connections are unaffected and are returned to the pool as usual.
    pub async fn clear(&self) {
        let mut idle = self.connections.lock().await;
        let dropped = idle.len();
        idle.clear();
        // Keep the idle gauge in sync; without this it drifts upward after every clear.
        self.metrics
            .idle_connections
            .fetch_sub(dropped, Ordering::Relaxed);
        drop(idle);
        debug!("Cleared all connections from pool");
    }

    /// The validated configuration this pool was built with.
    pub fn config(&self) -> &PoolConfig {
        &self.config
    }
}

/// RAII handle to a checked-out connection.
///
/// Dropping the guard returns the connection to its pool (subject to the idle cap);
/// [`PooledConnectionGuard::into_inner`] detaches it instead.
pub struct PooledConnectionGuard<T: Send + 'static> {
    /// `Some` until the connection is returned or detached.
    connection: Option<T>,
    /// When the connection was originally created; preserved so `max_lifetime` keeps applying.
    created_at: Instant,
    /// Maximum idle connections the pool retains (its `max_size`); beyond this, returns are dropped.
    max_idle: usize,
    pool: Arc<Mutex<VecDeque<PooledConnection<T>>>>,
    metrics: Arc<PoolMetrics>,
}

impl<T: Send + 'static> PooledConnectionGuard<T> {
    /// Borrows the connection (always `Some` while the guard is alive).
    pub fn get(&self) -> Option<&T> {
        self.connection.as_ref()
    }

    /// Mutably borrows the connection (always `Some` while the guard is alive).
    pub fn get_mut(&mut self) -> Option<&mut T> {
        self.connection.as_mut()
    }

    /// Detaches the connection from the pool and returns it; it will not be recycled.
    ///
    /// The connection also stops counting towards `active_connections`.
    pub fn into_inner(mut self) -> Option<T> {
        let connection = self.connection.take();
        if connection.is_some() {
            // `Drop` will see `None` and skip its decrement, so account for it here.
            self.metrics
                .active_connections
                .fetch_sub(1, Ordering::Relaxed);
        }
        connection
    }

    /// Pushes `pooled` onto the idle queue unless the queue already holds `max_idle` entries.
    fn push_idle(
        idle: &mut VecDeque<PooledConnection<T>>,
        pooled: PooledConnection<T>,
        max_idle: usize,
        metrics: &PoolMetrics,
    ) {
        if idle.len() < max_idle {
            idle.push_back(pooled);
            metrics.idle_connections.fetch_add(1, Ordering::Relaxed);
        } else {
            warn!("Connection pool at capacity, discarding connection");
        }
    }
}

impl<T: Send + 'static> AsRef<T> for PooledConnectionGuard<T> {
    // The connection is only `None` after `into_inner`/`drop`, both of which consume the guard.
    #[allow(clippy::expect_used)]
    fn as_ref(&self) -> &T {
        self.connection.as_ref().expect("Connection should exist")
    }
}

impl<T: Send + 'static> AsMut<T> for PooledConnectionGuard<T> {
    // The connection is only `None` after `into_inner`/`drop`, both of which consume the guard.
    #[allow(clippy::expect_used)]
    fn as_mut(&mut self) -> &mut T {
        self.connection.as_mut().expect("Connection should exist")
    }
}

impl<T: Send + 'static> Drop for PooledConnectionGuard<T> {
    /// Returns the connection to the idle queue, keeping its original creation time.
    ///
    /// An uncontended return happens synchronously via `try_lock`, so the connection is visible
    /// to the next `acquire` immediately. If the lock is busy, the return is deferred to a task
    /// on the current Tokio runtime. Without a runtime, the connection is dropped instead of
    /// panicking.
    fn drop(&mut self) {
        if let Some(connection) = self.connection.take() {
            self.metrics
                .active_connections
                .fetch_sub(1, Ordering::Relaxed);
            let pooled = PooledConnection {
                connection,
                created_at: self.created_at,
                last_used_at: Instant::now(),
            };
            let max_idle = self.max_idle;

            if let Ok(mut idle) = self.pool.try_lock() {
                Self::push_idle(&mut idle, pooled, max_idle, &self.metrics);
                return;
            }

            let pool = Arc::clone(&self.pool);
            let metrics = Arc::clone(&self.metrics);
            match tokio::runtime::Handle::try_current() {
                Ok(handle) => {
                    handle.spawn(async move {
                        let mut idle = pool.lock().await;
                        Self::push_idle(&mut idle, pooled, max_idle, &metrics);
                    });
                }
                Err(_) => {
                    warn!("No Tokio runtime available to return connection, discarding connection");
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Dummy connection; `id` records which `create` call produced it.
    #[allow(dead_code)] // `id` is never read by the current tests
    struct TestConnection {
        id: usize,
    }

    /// Factory that creates `TestConnection`s with sequential ids and counts `create` calls.
    struct TestFactory {
        counter: Arc<AtomicUsize>,
    }

    impl TestFactory {
        fn new() -> Self {
            Self {
                counter: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// Number of times `create` has been called so far.
        fn count(&self) -> usize {
            self.counter.load(Ordering::SeqCst)
        }
    }

    impl ConnectionFactory<TestConnection> for TestFactory {
        fn create(
            &self,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<TestConnection>> + Send>>
        {
            // Counted when `create` is called, before the returned future is polled.
            let id = self.counter.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Ok(TestConnection { id }) })
        }
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let factory = Arc::new(TestFactory::new());
        let pool = ConnectionPool::new(
            factory.clone(),
            PoolConfig {
                min_size: 2,
                max_size: 10,
                idle_timeout: Duration::from_secs(60),
                max_lifetime: Duration::from_secs(600),
                ..Default::default()
            },
        );
        assert!(pool.is_ok());
    }

    #[tokio::test]
    #[allow(clippy::unwrap_used)] // a failure here should abort the test
    async fn test_pool_acquire_creates_connection() {
        let factory = Arc::new(TestFactory::new());
        // `min_size: 0` disables the background warm-up task. Its `create` calls would otherwise
        // race with the exact counts asserted below (the default `min_size` is 5).
        let config = PoolConfig {
            min_size: 0,
            ..Default::default()
        };
        let pool = ConnectionPool::new(factory.clone(), config).unwrap();
        let guard = pool.acquire().await.unwrap();
        assert!(guard.get().is_some());
        assert_eq!(factory.count(), 1);

        let metrics = pool.metrics();
        assert_eq!(metrics.connections_created.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.active_connections.load(Ordering::Relaxed), 1);
        assert_eq!(metrics.total_acquisitions.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_config_validation() {
        let invalid_config = PoolConfig {
            min_size: 100,
            max_size: 10,
            idle_timeout: Duration::from_secs(60),
            max_lifetime: Duration::from_secs(600),
            ..Default::default()
        };
        let factory = Arc::new(TestFactory::new());
        let result = ConnectionPool::new(factory, invalid_config);
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_config_validation_zero_max_size() {
        let config = PoolConfig {
            max_size: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_zero_timeouts() {
        let config = PoolConfig {
            idle_timeout: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config.validate().is_err());
        let config2 = PoolConfig {
            max_lifetime: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config2.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_idle_exceeds_lifetime() {
        let config = PoolConfig {
            idle_timeout: Duration::from_secs(600),
            max_lifetime: Duration::from_secs(300),
            ..Default::default()
        };
        assert!(config.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_circuit_breaker() {
        let config = PoolConfig {
            circuit_breaker_threshold: 0,
            ..Default::default()
        };
        assert!(config.validate().is_err());
        let config2 = PoolConfig {
            circuit_breaker_timeout: Duration::from_secs(0),
            ..Default::default()
        };
        assert!(config2.validate().is_err());
    }

    #[tokio::test]
    async fn test_config_validation_valid_config() {
        let config = PoolConfig::default();
        assert!(config.validate().is_ok());
    }
}