use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{RwLock, Semaphore};
use tracing::{debug, warn};
/// Limits and timings that govern a connection pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Number of permits the pool's semaphore is created with; also the
    /// initial capacity reserved for the idle list.
    pub max_connections: usize,
    /// Idle-list size the pool tries to keep: `maintain_idle` tops the list up
    /// to this value, and `release` only discards a stale connection once the
    /// idle list already holds at least this many entries.
    pub min_idle: usize,
    /// How long a connection may go unused before `is_stale` reports it stale.
    pub max_idle_time: Duration,
    /// Upper bound on how long `acquire` waits for a semaphore permit.
    pub connection_timeout: Duration,
    /// When `false`, `health_check` returns immediately without probing.
    pub health_check_enabled: bool,
    /// Interval between health checks. Not read by the pool code in this
    /// file; presumably consumed by whatever schedules `health_check`.
    pub health_check_interval: Duration,
}

impl Default for PoolConfig {
    /// Defaults: 100 connections, 10 idle, 10 minute idle limit, 30 second
    /// acquire timeout, health checks enabled every 60 seconds.
    fn default() -> Self {
        let max_idle_time = Duration::from_secs(600);
        let connection_timeout = Duration::from_secs(30);
        let health_check_interval = Duration::from_secs(60);

        Self {
            max_connections: 100,
            min_idle: 10,
            max_idle_time,
            connection_timeout,
            health_check_enabled: true,
            health_check_interval,
        }
    }
}
/// A connection value of type `T` paired with its creation and last-use times.
pub struct PooledConnection<T> {
    inner: T,
    created_at: Instant,
    last_used: Instant,
}

impl<T> PooledConnection<T> {
    /// Wraps `connection`, stamping both timestamps with the current instant.
    pub fn new(connection: T) -> Self {
        let timestamp = Instant::now();
        Self {
            inner: connection,
            created_at: timestamp,
            last_used: timestamp,
        }
    }

    /// Shared access to the wrapped connection. Does not touch `last_used`.
    pub fn get(&self) -> &T {
        let Self { inner, .. } = self;
        inner
    }

    /// Exclusive access to the wrapped connection; refreshes `last_used`.
    pub fn get_mut(&mut self) -> &mut T {
        let Self {
            inner, last_used, ..
        } = self;
        *last_used = Instant::now();
        inner
    }

    /// `true` when more than `max_idle_time` has passed since the last use.
    pub fn is_stale(&self, max_idle_time: Duration) -> bool {
        let idle_for = self.last_used.elapsed();
        idle_for > max_idle_time
    }

    /// Time elapsed since this wrapper was created.
    pub fn age(&self) -> Duration {
        let lifetime = self.created_at.elapsed();
        lifetime
    }
}
/// Async pool of reusable connections of type `T`.
pub struct ConnectionPool<T> {
    /// Limits and timings supplied at construction.
    config: PoolConfig,
    /// Idle connections waiting to be handed out; `acquire` pops from the end.
    available: Arc<RwLock<Vec<PooledConnection<T>>>>,
    /// Created with `max_connections` permits; `acquire` waits on it.
    semaphore: Arc<Semaphore>,
    /// Counters returned (as a snapshot) by `metrics()`.
    metrics: Arc<RwLock<PoolMetrics>>,
}
/// Counters describing pool activity; every field starts at zero.
#[derive(Debug, Default, Clone)]
pub struct PoolMetrics {
    /// Connections currently handed out: raised by `acquire`, lowered by `release`.
    pub active_connections: usize,
    /// Connections the pool believes are sitting in the idle list.
    pub idle_connections: usize,
    /// Successful `acquire` calls (reused and newly created alike).
    pub total_acquired: u64,
    /// Connections returned to the idle list by `release`.
    pub total_released: u64,
    /// Connections produced by a factory closure.
    pub total_created: u64,
    /// Connections discarded by the pool (stale on release, failed health check).
    pub total_closed: u64,
    /// Intended to count `acquire` attempts that gave up after `connection_timeout`.
    pub acquire_timeouts: u64,
    /// Idle connections rejected by `health_check`.
    pub health_check_failures: u64,
}
impl<T> ConnectionPool<T>
where
    T: Send + 'static,
{
    /// Creates an empty pool.
    ///
    /// The semaphore starts with `config.max_connections` permits. A permit
    /// represents one *checked-out* connection: `acquire` takes one and
    /// `release` gives it back. Idle connections never occupy a permit.
    pub fn new(config: PoolConfig) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(config.max_connections)),
            available: Arc::new(RwLock::new(Vec::with_capacity(config.max_connections))),
            metrics: Arc::new(RwLock::new(PoolMetrics::default())),
            config,
        }
    }

    /// Checks out a connection, reusing an idle one when possible and calling
    /// `create_fn` otherwise.
    ///
    /// Stale idle connections are evicted (and counted as closed) before the
    /// idle list is consulted. Every connection returned here must be handed
    /// back through [`release`](Self::release); the permit taken for it is
    /// only restored there.
    ///
    /// # Errors
    ///
    /// * [`PoolError::Timeout`] if no permit became available within
    ///   `connection_timeout` (also counted in `acquire_timeouts`).
    /// * [`PoolError::Closed`] if the semaphore has been closed.
    /// * Whatever error `create_fn` returns; the permit is released in that case.
    pub async fn acquire<F, Fut>(&self, create_fn: F) -> Result<PooledConnection<T>, PoolError>
    where
        F: FnOnce() -> Fut,
        Fut: std::future::Future<Output = Result<T, PoolError>>,
    {
        let wait = tokio::time::timeout(
            self.config.connection_timeout,
            self.semaphore.clone().acquire_owned(),
        )
        .await;
        let permit = match wait {
            Ok(Ok(permit)) => permit,
            Ok(Err(_)) => return Err(PoolError::Closed),
            Err(_) => {
                debug!("Connection pool acquire timeout");
                self.metrics.write().await.acquire_timeouts += 1;
                return Err(PoolError::Timeout);
            }
        };

        // Evict stale entries and try to take an idle connection under a
        // single lock; the guard is released before the metrics lock is taken
        // so the two locks are never held together.
        let (reused, evicted) = {
            let mut available = self.available.write().await;
            let before = available.len();
            available.retain(|conn| !conn.is_stale(self.config.max_idle_time));
            let evicted = before - available.len();
            (available.pop(), evicted)
        };

        // Record evictions right away so they are counted even when the
        // factory below fails.
        if evicted > 0 {
            let mut metrics = self.metrics.write().await;
            // usize -> u64 cannot truncate on supported targets.
            metrics.total_closed += evicted as u64;
            metrics.idle_connections = metrics.idle_connections.saturating_sub(evicted);
        }

        let connection = match reused {
            Some(mut conn) => {
                conn.last_used = Instant::now();
                let mut metrics = self.metrics.write().await;
                metrics.total_acquired += 1;
                metrics.active_connections += 1;
                metrics.idle_connections = metrics.idle_connections.saturating_sub(1);
                drop(metrics);
                debug!("Reusing pooled connection");
                conn
            }
            None => {
                // On error `?` returns early and dropping `permit` hands it
                // back to the semaphore.
                let inner = create_fn().await?;
                let conn = PooledConnection::new(inner);
                let mut metrics = self.metrics.write().await;
                metrics.total_created += 1;
                metrics.total_acquired += 1;
                metrics.active_connections += 1;
                drop(metrics);
                debug!("Created new pooled connection");
                conn
            }
        };

        // Keep the permit consumed until `release` adds it back. `forget()`
        // still drops the permit's `Arc<Semaphore>`, unlike `mem::forget`,
        // which would leak a strong reference per acquisition.
        permit.forget();
        Ok(connection)
    }

    /// Returns a checked-out connection and restores its permit.
    ///
    /// The connection is discarded (counted in `total_closed`) when it is
    /// stale and the idle list already holds at least `min_idle` entries;
    /// otherwise it is pushed onto the idle list (counted in `total_released`).
    ///
    /// Only pass connections obtained from this pool's `acquire`: each call
    /// unconditionally adds one permit to the semaphore.
    pub async fn release(&self, connection: PooledConnection<T>) {
        let mut available = self.available.write().await;
        let discard = available.len() >= self.config.min_idle
            && connection.is_stale(self.config.max_idle_time);
        if !discard {
            available.push(connection);
        }
        drop(available);

        let mut metrics = self.metrics.write().await;
        metrics.active_connections = metrics.active_connections.saturating_sub(1);
        if discard {
            metrics.total_closed += 1;
        } else {
            metrics.total_released += 1;
            metrics.idle_connections += 1;
        }
        drop(metrics);

        self.semaphore.add_permits(1);
        if discard {
            debug!("Closed stale connection");
        } else {
            debug!("Released connection to pool");
        }
    }

    /// Snapshot of the current counters.
    pub async fn metrics(&self) -> PoolMetrics {
        self.metrics.read().await.clone()
    }

    /// Number of connections currently in the idle list.
    pub async fn size(&self) -> usize {
        self.available.read().await.len()
    }

    /// Probes every idle connection with `check_fn` and drops those for which
    /// it resolves to `false`.
    ///
    /// Does nothing when `health_check_enabled` is `false`. The idle list's
    /// write lock is held for the whole pass, so `acquire` and `release` wait
    /// until the probes finish.
    ///
    /// Idle connections hold no semaphore permit (it was returned by
    /// `release`), so dropping one does not add permits; doing so would let
    /// the pool exceed `max_connections`.
    pub async fn health_check<F, Fut>(&self, check_fn: F)
    where
        F: Fn(&T) -> Fut,
        Fut: std::future::Future<Output = bool>,
    {
        if !self.config.health_check_enabled {
            return;
        }

        let mut available = self.available.write().await;
        let mut healthy = Vec::with_capacity(available.len());
        let mut failures: usize = 0;
        for conn in available.drain(..) {
            if check_fn(conn.get()).await {
                healthy.push(conn);
            } else {
                failures += 1;
                warn!("Connection failed health check");
            }
        }
        // Refill the drained list in place so its reserved capacity is kept.
        available.extend(healthy);
        drop(available);

        if failures > 0 {
            // usize -> u64 cannot truncate on supported targets.
            let failed = failures as u64;
            let mut metrics = self.metrics.write().await;
            metrics.health_check_failures += failed;
            metrics.total_closed += failed;
            metrics.idle_connections = metrics.idle_connections.saturating_sub(failures);
        }
    }

    /// Tops the idle list up towards `min_idle` using `create_fn`.
    ///
    /// A permit is borrowed only while each connection is being created and
    /// is returned afterwards, because idle connections do not occupy
    /// permits. Replenishment stops early when the pool is fully checked out
    /// or when one more idle connection would push idle + checked-out past
    /// `max_connections`. Factory errors are logged and skipped.
    pub async fn maintain_idle<F, Fut>(&self, create_fn: F)
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, PoolError>>,
    {
        let current_idle = self.available.read().await.len();
        let needed = self.config.min_idle.saturating_sub(current_idle);

        for _ in 0..needed {
            let permit = match self.semaphore.clone().try_acquire_owned() {
                Ok(permit) => permit,
                Err(_) => break,
            };

            // With our temporary permit held, `available_permits()` is
            // (max - checked_out - 1). Adding one idle connection keeps
            // idle + checked_out <= max only while idle <= that value.
            let idle_now = self.available.read().await.len();
            if idle_now > self.semaphore.available_permits() {
                break;
            }

            match create_fn().await {
                Ok(conn) => {
                    self.available
                        .write()
                        .await
                        .push(PooledConnection::new(conn));
                    let mut metrics = self.metrics.write().await;
                    metrics.total_created += 1;
                    metrics.idle_connections += 1;
                }
                Err(e) => {
                    warn!("Failed to create idle connection: {:?}", e);
                }
            }

            // Hand the borrowed permit back; keeping it would permanently
            // shrink the pool's capacity.
            drop(permit);
        }
    }
}
/// Failures reported by the connection pool.
#[derive(Debug, thiserror::Error)]
pub enum PoolError {
    /// No semaphore permit became available within `connection_timeout`.
    #[error("Connection pool timeout")]
    Timeout,

    /// The pool's semaphore reported an acquire error.
    #[error("Connection pool closed")]
    Closed,

    /// A connection factory failed; carries the factory's message.
    #[error("Failed to create connection: {0}")]
    CreateError(String),

    /// A connection-level failure; not constructed by the code in this file.
    #[error("Connection error: {0}")]
    ConnectionError(String),
}
/// Connection pool specialised for `reqwest::Client` values.
pub type HttpClientPool = ConnectionPool<reqwest::Client>;

impl HttpClientPool {
    /// Builds an HTTP client pool; equivalent to `ConnectionPool::new(config)`.
    pub fn new_http(config: PoolConfig) -> Self {
        Self::new(config)
    }

    /// Acquires a client from the pool, building a new one when no idle client
    /// is available.
    ///
    /// New clients use a 30 second request timeout and keep at most 10 idle
    /// connections per host. Builder failures are reported as
    /// `PoolError::CreateError` carrying the builder's message.
    pub async fn acquire_client(&self) -> Result<PooledConnection<reqwest::Client>, PoolError> {
        let build_client = || async {
            let builder = reqwest::Client::builder()
                .timeout(Duration::from_secs(30))
                .pool_max_idle_per_host(10);
            builder
                .build()
                .map_err(|e| PoolError::CreateError(e.to_string()))
        };
        self.acquire(build_client).await
    }
}
// Unit tests for the pool configuration, the connection wrapper, the metrics
// and error types, and the async behaviour of `ConnectionPool`.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
// --- PoolConfig ---
// Default configuration carries the documented values.
#[test]
fn test_pool_config_default() {
let config = PoolConfig::default();
assert_eq!(config.max_connections, 100);
assert_eq!(config.min_idle, 10);
assert_eq!(config.max_idle_time, Duration::from_secs(600));
assert_eq!(config.connection_timeout, Duration::from_secs(30));
assert!(config.health_check_enabled);
assert_eq!(config.health_check_interval, Duration::from_secs(60));
}
// Cloning preserves overridden fields.
#[test]
fn test_pool_config_clone() {
let config = PoolConfig {
max_connections: 50,
min_idle: 5,
..Default::default()
};
let cloned = config.clone();
assert_eq!(cloned.max_connections, config.max_connections);
assert_eq!(cloned.min_idle, config.min_idle);
}
// Debug output names the type and its fields.
#[test]
fn test_pool_config_debug() {
let config = PoolConfig::default();
let debug = format!("{:?}", config);
assert!(debug.contains("PoolConfig"));
assert!(debug.contains("max_connections"));
}
// --- PooledConnection ---
// `new` wraps the value unchanged.
#[test]
fn test_pooled_connection_new() {
let conn = PooledConnection::new(42u32);
assert_eq!(*conn.get(), 42);
}
// `get` exposes the wrapped value by shared reference.
#[test]
fn test_pooled_connection_get() {
let conn = PooledConnection::new("test".to_string());
assert_eq!(*conn.get(), "test");
}
// `get_mut` allows mutating the wrapped value in place.
#[test]
fn test_pooled_connection_get_mut() {
let mut conn = PooledConnection::new(vec![1, 2, 3]);
conn.get_mut().push(4);
assert_eq!(*conn.get(), vec![1, 2, 3, 4]);
}
// A fresh connection is not stale for a 1s limit, but is for a zero limit
// (any elapsed time at all exceeds zero).
#[test]
fn test_pooled_connection_is_stale() {
let conn = PooledConnection::new(42u32);
assert!(!conn.is_stale(Duration::from_secs(1)));
assert!(conn.is_stale(Duration::from_nanos(0)));
}
// `age` of a just-created connection is well under a second.
#[test]
fn test_pooled_connection_age() {
let conn = PooledConnection::new(42u32);
let age = conn.age();
assert!(age < Duration::from_secs(1));
}
// --- PoolMetrics ---
// All counters start at zero.
#[test]
fn test_pool_metrics_default() {
let metrics = PoolMetrics::default();
assert_eq!(metrics.active_connections, 0);
assert_eq!(metrics.idle_connections, 0);
assert_eq!(metrics.total_acquired, 0);
assert_eq!(metrics.total_released, 0);
assert_eq!(metrics.total_created, 0);
assert_eq!(metrics.total_closed, 0);
assert_eq!(metrics.acquire_timeouts, 0);
assert_eq!(metrics.health_check_failures, 0);
}
// Cloning preserves counter values.
#[test]
fn test_pool_metrics_clone() {
let metrics = PoolMetrics {
total_acquired: 10,
active_connections: 5,
..Default::default()
};
let cloned = metrics.clone();
assert_eq!(cloned.total_acquired, 10);
assert_eq!(cloned.active_connections, 5);
}
// Debug output names the type and its fields.
#[test]
fn test_pool_metrics_debug() {
let metrics = PoolMetrics::default();
let debug = format!("{:?}", metrics);
assert!(debug.contains("PoolMetrics"));
assert!(debug.contains("active_connections"));
}
// --- PoolError ---
// Display text of `Timeout` mentions "timeout".
#[test]
fn test_pool_error_timeout() {
let error = PoolError::Timeout;
assert!(error.to_string().contains("timeout"));
}
// Display text of `Closed` mentions "closed".
#[test]
fn test_pool_error_closed() {
let error = PoolError::Closed;
assert!(error.to_string().contains("closed"));
}
// `CreateError` includes both the fixed prefix and the payload.
#[test]
fn test_pool_error_create_error() {
let error = PoolError::CreateError("connection failed".to_string());
let msg = error.to_string();
assert!(msg.contains("create connection"));
assert!(msg.contains("connection failed"));
}
// `ConnectionError` includes both the fixed prefix and the payload.
#[test]
fn test_pool_error_connection_error() {
let error = PoolError::ConnectionError("network issue".to_string());
let msg = error.to_string();
assert!(msg.contains("Connection error"));
assert!(msg.contains("network issue"));
}
// Debug output shows the variant name.
#[test]
fn test_pool_error_debug() {
let error = PoolError::Timeout;
let debug = format!("{:?}", error);
assert!(debug.contains("Timeout"));
}
// --- ConnectionPool: acquire / release ---
// One acquire/release round trip updates created, acquired and released.
#[tokio::test]
async fn test_connection_pool() {
let config = PoolConfig {
max_connections: 5,
min_idle: 2,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
let conn1 = pool.acquire(|| async { Ok(42) }).await.unwrap();
assert_eq!(*conn1.get(), 42);
pool.release(conn1).await;
let metrics = pool.metrics().await;
assert_eq!(metrics.total_created, 1);
assert_eq!(metrics.total_acquired, 1);
assert_eq!(metrics.total_released, 1);
}
// A new pool has no idle connections.
#[tokio::test]
async fn test_connection_pool_new() {
let config = PoolConfig {
max_connections: 10,
min_idle: 2,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
assert_eq!(pool.size().await, 0);
}
// With an empty idle list, acquire calls the factory.
#[tokio::test]
async fn test_connection_pool_acquire_creates_connection() {
let config = PoolConfig::default();
let pool = ConnectionPool::<String>::new(config);
let conn = pool.acquire(|| async { Ok("test-connection".to_string()) }).await.unwrap();
assert_eq!(*conn.get(), "test-connection");
let metrics = pool.metrics().await;
assert_eq!(metrics.total_created, 1);
assert_eq!(metrics.total_acquired, 1);
}
// After a release, the next acquire must hand back the pooled value (42)
// without invoking the second factory (which would yield 100).
#[tokio::test]
async fn test_connection_pool_reuses_connection() {
let config = PoolConfig {
max_connections: 5,
min_idle: 1,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
// Counts how many times any factory closure actually ran.
let create_count = Arc::new(AtomicUsize::new(0));
let create_count_clone = create_count.clone();
let conn1 = pool
.acquire(move || {
let count = create_count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Ok(42u32)
}
})
.await
.unwrap();
pool.release(conn1).await;
let create_count_clone = create_count.clone();
let conn2 = pool
.acquire(move || {
let count = create_count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Ok(100u32)
}
})
.await
.unwrap();
assert_eq!(*conn2.get(), 42);
assert_eq!(create_count.load(Ordering::SeqCst), 1);
let metrics = pool.metrics().await;
assert_eq!(metrics.total_created, 1);
assert_eq!(metrics.total_acquired, 2);
}
// Release moves the connection into the idle list and updates metrics.
#[tokio::test]
async fn test_connection_pool_release() {
let config = PoolConfig::default();
let pool = ConnectionPool::<u32>::new(config);
let conn = pool.acquire(|| async { Ok(42) }).await.unwrap();
assert_eq!(pool.size().await, 0);
pool.release(conn).await;
assert_eq!(pool.size().await, 1);
let metrics = pool.metrics().await;
assert_eq!(metrics.total_released, 1);
assert_eq!(metrics.idle_connections, 1);
}
// Metrics start at zero and reflect one acquire/release cycle.
#[tokio::test]
async fn test_connection_pool_metrics() {
let config = PoolConfig::default();
let pool = ConnectionPool::<u32>::new(config);
let metrics = pool.metrics().await;
assert_eq!(metrics.total_created, 0);
let conn = pool.acquire(|| async { Ok(1) }).await.unwrap();
pool.release(conn).await;
let metrics = pool.metrics().await;
assert_eq!(metrics.total_created, 1);
assert_eq!(metrics.total_acquired, 1);
assert_eq!(metrics.total_released, 1);
}
// `size` counts idle connections only: checked-out ones are not included.
#[tokio::test]
async fn test_connection_pool_size() {
let config = PoolConfig::default();
let pool = ConnectionPool::<u32>::new(config);
assert_eq!(pool.size().await, 0);
let conn1 = pool.acquire(|| async { Ok(1) }).await.unwrap();
let conn2 = pool.acquire(|| async { Ok(2) }).await.unwrap();
assert_eq!(pool.size().await, 0);
pool.release(conn1).await;
assert_eq!(pool.size().await, 1);
pool.release(conn2).await;
assert_eq!(pool.size().await, 2);
}
// Five spawned tasks each acquire, hold briefly, and release; totals must
// add up regardless of interleaving.
#[tokio::test]
async fn test_connection_pool_multiple_concurrent_acquires() {
let config = PoolConfig {
max_connections: 10,
..Default::default()
};
let pool = Arc::new(ConnectionPool::<u32>::new(config));
let mut handles = vec![];
for i in 0..5 {
let pool_clone = pool.clone();
let handle = tokio::spawn(async move {
let conn = pool_clone.acquire(move || async move { Ok(i as u32) }).await.unwrap();
tokio::time::sleep(Duration::from_millis(10)).await;
pool_clone.release(conn).await;
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
let metrics = pool.metrics().await;
assert_eq!(metrics.total_acquired, 5);
assert_eq!(metrics.total_released, 5);
}
// A failing factory surfaces its error from `acquire`.
#[tokio::test]
async fn test_connection_pool_acquire_error() {
let config = PoolConfig::default();
let pool = ConnectionPool::<u32>::new(config);
let result = pool
.acquire(|| async { Err(PoolError::CreateError("test error".to_string())) })
.await;
assert!(result.is_err());
// NOTE(review): the message is only checked when the variant is
// `CreateError`; any other error variant would pass this test silently.
if let Err(PoolError::CreateError(msg)) = result {
assert_eq!(msg, "test error");
}
}
// --- ConnectionPool: health checks ---
// An all-pass check keeps both idle connections; an all-fail check removes
// both and records two failures.
#[tokio::test]
async fn test_connection_pool_health_check() {
let config = PoolConfig {
max_connections: 5,
min_idle: 0,
health_check_enabled: true,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
let conn1 = pool.acquire(|| async { Ok(1) }).await.unwrap();
let conn2 = pool.acquire(|| async { Ok(2) }).await.unwrap();
pool.release(conn1).await;
pool.release(conn2).await;
pool.health_check(|_| async { true }).await;
assert_eq!(pool.size().await, 2);
pool.health_check(|_| async { false }).await;
assert_eq!(pool.size().await, 0);
let metrics = pool.metrics().await;
assert_eq!(metrics.health_check_failures, 2);
}
// With health checks disabled, even an always-failing probe removes nothing.
#[tokio::test]
async fn test_connection_pool_health_check_disabled() {
let config = PoolConfig {
health_check_enabled: false,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
let conn = pool.acquire(|| async { Ok(1) }).await.unwrap();
pool.release(conn).await;
pool.health_check(|_| async { false }).await;
assert_eq!(pool.size().await, 1);
}
// --- ConnectionPool: idle maintenance ---
// An empty pool is filled up to `min_idle` (3) connections.
#[tokio::test]
async fn test_connection_pool_maintain_idle() {
let config = PoolConfig {
max_connections: 10,
min_idle: 3,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
assert_eq!(pool.size().await, 0);
pool.maintain_idle(|| async { Ok(42u32) }).await;
assert_eq!(pool.size().await, 3);
let metrics = pool.metrics().await;
assert_eq!(metrics.total_created, 3);
assert_eq!(metrics.idle_connections, 3);
}
// With three idle connections and `min_idle` of 2, nothing new is created.
#[tokio::test]
async fn test_connection_pool_maintain_idle_already_sufficient() {
let config = PoolConfig {
max_connections: 10,
min_idle: 2,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
let conn1 = pool.acquire(|| async { Ok(1) }).await.unwrap();
let conn2 = pool.acquire(|| async { Ok(2) }).await.unwrap();
let conn3 = pool.acquire(|| async { Ok(3) }).await.unwrap();
pool.release(conn1).await;
pool.release(conn2).await;
pool.release(conn3).await;
let initial_created = pool.metrics().await.total_created;
pool.maintain_idle(|| async { Ok(100u32) }).await;
let final_created = pool.metrics().await.total_created;
assert_eq!(initial_created, final_created);
}
// Factory failures during maintenance leave the idle list empty.
#[tokio::test]
async fn test_connection_pool_maintain_idle_error() {
let config = PoolConfig {
max_connections: 10,
min_idle: 3,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
pool.maintain_idle(|| async { Err(PoolError::CreateError("test".to_string())) })
.await;
assert_eq!(pool.size().await, 0);
}
// --- HttpClientPool ---
// A new HTTP client pool starts empty.
#[tokio::test]
async fn test_http_client_pool_new() {
let config = PoolConfig::default();
let pool = HttpClientPool::new_http(config);
assert_eq!(pool.size().await, 0);
}
// `acquire_client` yields a usable `reqwest::Client`.
#[tokio::test]
async fn test_http_client_pool_acquire() {
let config = PoolConfig::default();
let pool = HttpClientPool::new_http(config);
let result = pool.acquire_client().await;
assert!(result.is_ok());
let conn = result.unwrap();
let _client: &reqwest::Client = conn.get();
}
// --- Staleness and miscellaneous ---
// A connection held for 10ms against a 1ms idle limit (and `min_idle` 0) is
// closed on release instead of being pooled.
#[tokio::test]
async fn test_connection_pool_stale_connection_not_returned() {
let config = PoolConfig {
max_connections: 5,
min_idle: 0, max_idle_time: Duration::from_millis(1), ..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
let conn = pool.acquire(|| async { Ok(42) }).await.unwrap();
tokio::time::sleep(Duration::from_millis(10)).await;
pool.release(conn).await;
let metrics = pool.metrics().await;
assert_eq!(metrics.total_closed, 1);
}
// The pool works with a user-defined struct as the connection type.
#[tokio::test]
async fn test_connection_pool_with_complex_type() {
#[derive(Debug, Clone)]
struct ComplexConnection {
id: u32,
data: Vec<String>,
}
let config = PoolConfig::default();
let pool = ConnectionPool::<ComplexConnection>::new(config);
let conn = pool
.acquire(|| async {
Ok(ComplexConnection {
id: 123,
data: vec!["test".to_string()],
})
})
.await
.unwrap();
assert_eq!(conn.get().id, 123);
assert_eq!(conn.get().data, vec!["test".to_string()]);
}
// `get_mut` advances `last_used` past its initial value.
#[tokio::test]
async fn test_pooled_connection_updates_last_used() {
let mut conn = PooledConnection::new(42u32);
let initial_time = conn.last_used;
tokio::time::sleep(Duration::from_millis(1)).await;
let _ = conn.get_mut();
assert!(conn.last_used > initial_time);
}
// With idle values 1, 2, 3 and a probe that passes odd values only, exactly
// one connection (2) is removed.
#[tokio::test]
async fn test_connection_pool_partial_health_check() {
let config = PoolConfig {
max_connections: 10,
min_idle: 0,
health_check_enabled: true,
..Default::default()
};
let pool = ConnectionPool::<u32>::new(config);
let conn1 = pool.acquire(|| async { Ok(1) }).await.unwrap();
let conn2 = pool.acquire(|| async { Ok(2) }).await.unwrap();
let conn3 = pool.acquire(|| async { Ok(3) }).await.unwrap();
pool.release(conn1).await;
pool.release(conn2).await;
pool.release(conn3).await;
pool.health_check(|val| {
let v = *val;
async move { v % 2 != 0 }
})
.await;
assert_eq!(pool.size().await, 2);
let metrics = pool.metrics().await;
assert_eq!(metrics.health_check_failures, 1); }
}