use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::{Mutex, Notify, Semaphore};
use tokio::time::timeout;
/// Errors surfaced by the connection pool and its configuration.
#[derive(Debug, Error)]
pub enum PoolError {
/// A new connection could not be created; the payload is a human-readable description.
#[error("Connection creation failed: {0}")]
ConnectionCreationFailed(String),
/// No connection became available in time; the payload is the configured
/// `connection_timeout` expressed in milliseconds.
#[error("Timeout waiting for connection after {0}ms")]
AcquireTimeout(u64),
/// A connection failed its health check.
/// NOTE(review): not constructed anywhere in the visible source.
#[error("Connection health check failed")]
HealthCheckFailed,
/// Resetting a connection failed; the payload describes the failure.
/// NOTE(review): not constructed anywhere in the visible source (reset failures are only logged).
#[error("Connection reset failed: {0}")]
ResetFailed(String),
/// The pool was shut down before the request was made.
#[error("Pool has been shut down")]
PoolShutdown,
/// The supplied `PoolConfig` failed validation; the payload explains why.
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
}
/// Errors a `Connection` implementation can report (for example from `reset`).
#[derive(Debug, Error)]
pub enum ConnectionError {
/// Free-form failure; the payload is included in the rendered message.
#[error("Connection error: {0}")]
Generic(String),
/// The connection is closed.
#[error("Connection is closed")]
Closed,
/// An operation on the connection timed out.
#[error("Operation timed out")]
Timeout,
}
/// Contract a resource must fulfil to be managed by the pool.
///
/// Both methods are synchronous; the pool calls them while it holds the lock on
/// its idle queue.
pub trait Connection: Send + Sync {
/// Reports whether the connection is still usable. The pool consults this when
/// handing out an idle connection (if `test_on_acquire` is set) and during
/// `health_check`.
fn is_healthy(&self) -> bool;
/// Prepares a previously used connection for its next user. The pool calls this
/// before handing out an idle connection; on `Err` the connection is discarded.
fn reset(&mut self) -> Result<(), ConnectionError>;
}
/// Tunable limits and timeouts for a connection pool.
///
/// Build one with [`PoolConfig::new`] (or `Default`) and the chained setters,
/// then check it with [`PoolConfig::validate`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Lower bound on the pool size. Only used by [`PoolConfig::validate`] in the
    /// visible source.
    pub min_connections: usize,
    /// Upper bound on the number of connections that may exist at once.
    pub max_connections: usize,
    /// How long an acquire call may wait before giving up.
    pub connection_timeout: Duration,
    /// Idle connections older than this are discarded instead of reused.
    pub idle_timeout: Duration,
    /// Connections older than this are discarded instead of reused.
    pub max_lifetime: Duration,
    /// Whether idle connections are health-checked before being handed out.
    pub test_on_acquire: bool,
}

impl Default for PoolConfig {
    /// Defaults: 1..=10 connections, 30 s acquire timeout, 10 min idle timeout,
    /// 1 h maximum lifetime, health check on acquire enabled.
    fn default() -> Self {
        Self {
            min_connections: 1,
            max_connections: 10,
            connection_timeout: Duration::from_secs(30),
            idle_timeout: Duration::from_secs(600),
            max_lifetime: Duration::from_secs(3600),
            test_on_acquire: true,
        }
    }
}

impl PoolConfig {
    /// Creates a configuration populated with the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the minimum number of connections.
    #[must_use]
    pub fn min_connections(mut self, min: usize) -> Self {
        self.min_connections = min;
        self
    }

    /// Sets the maximum number of connections.
    #[must_use]
    pub fn max_connections(mut self, max: usize) -> Self {
        self.max_connections = max;
        self
    }

    /// Sets how long an acquire call may wait for a connection.
    #[must_use]
    pub fn connection_timeout(mut self, timeout: Duration) -> Self {
        self.connection_timeout = timeout;
        self
    }

    /// Sets how long a connection may sit idle before it is discarded.
    #[must_use]
    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Sets the maximum age of a connection.
    #[must_use]
    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = lifetime;
        self
    }

    /// Enables or disables the health check performed when reusing a connection.
    #[must_use]
    pub fn test_on_acquire(mut self, test: bool) -> Self {
        self.test_on_acquire = test;
        self
    }

    /// Checks the configuration for internal consistency.
    ///
    /// # Errors
    ///
    /// Returns [`PoolError::InvalidConfig`] when `max_connections` is zero or when
    /// `min_connections` exceeds `max_connections`.
    pub fn validate(&self) -> Result<(), PoolError> {
        // Zero capacity is checked first: otherwise a zero maximum combined with
        // the default minimum of 1 is reported as a min/max ordering problem,
        // which hides the actual cause.
        if self.max_connections == 0 {
            return Err(PoolError::InvalidConfig(
                "max_connections must be greater than 0".to_string(),
            ));
        }
        if self.min_connections > self.max_connections {
            return Err(PoolError::InvalidConfig(format!(
                "min_connections ({}) cannot be greater than max_connections ({})",
                self.min_connections, self.max_connections
            )));
        }
        Ok(())
    }
}
/// Point-in-time snapshot of the pool's counters, as returned by `ConnectionPool::stats`.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
/// Connections currently checked out.
pub active_connections: usize,
/// Derived value: `total_connections - active_connections`, saturating at zero.
pub idle_connections: usize,
/// Connections that currently exist (created and not yet discarded).
pub total_connections: usize,
/// Number of `acquire` calls currently in flight.
pub waiting_requests: usize,
/// Successful acquisitions since the pool was created.
pub acquire_count: u64,
/// Connection guards dropped since the pool was created.
pub release_count: u64,
/// Acquisitions that failed with a timeout since the pool was created.
pub timeout_count: u64,
}
/// An idle connection together with the timestamps used to decide whether it
/// may still be reused.
struct PooledConnectionMeta<C>
where
    C: Connection,
{
    /// The wrapped connection.
    connection: C,
    /// When this record was created.
    created_at: Instant,
    /// When this record was last marked as used (set once, at construction).
    last_used: Instant,
}

impl<C: Connection> PooledConnectionMeta<C> {
    /// Wraps `connection`, stamping both timestamps with the current instant.
    fn new(connection: C) -> Self {
        let created_at = Instant::now();
        Self {
            connection,
            created_at,
            last_used: created_at,
        }
    }

    /// Returns `true` once more than `max_lifetime` has passed since creation.
    fn is_expired(&self, max_lifetime: Duration) -> bool {
        let age = self.created_at.elapsed();
        age > max_lifetime
    }

    /// Returns `true` once more than `idle_timeout` has passed since last use.
    fn is_idle_too_long(&self, idle_timeout: Duration) -> bool {
        let idle_for = self.last_used.elapsed();
        idle_for > idle_timeout
    }
}
/// Shared, type-erased connection factory: every call yields a boxed `Send`
/// future that resolves to a new connection or a `PoolError`.
pub type ConnectionFactory<C> =
Arc<dyn Fn() -> Pin<Box<dyn Future<Output = Result<C, PoolError>> + Send>> + Send + Sync>;
/// Asynchronous pool of reusable connections.
///
/// Capacity is tracked by a semaphore: one permit is consumed (forgotten) for
/// every connection that exists and handed back when a connection is
/// discarded. Idle connections wait in a FIFO queue; tasks that find neither an
/// idle connection nor a free permit sleep on `notify` until something changes
/// or their timeout elapses.
pub struct ConnectionPool<C>
where
    C: Connection + 'static,
{
    /// Configuration the pool was created with (never mutated).
    config: PoolConfig,
    /// Creates new connections on demand.
    factory: ConnectionFactory<C>,
    /// Idle connections, oldest first.
    pool: Mutex<VecDeque<PooledConnectionMeta<C>>>,
    /// Remaining capacity: permits available == connections that may still be created.
    semaphore: Semaphore,
    /// Wakes tasks blocked in `acquire`.
    notify: Notify,
    /// Connections currently checked out.
    active_count: AtomicUsize,
    /// Connections that currently exist (idle + checked out).
    total_created: AtomicUsize,
    /// Successful acquisitions.
    acquire_count: AtomicU64,
    /// Guards dropped.
    release_count: AtomicU64,
    /// Acquisitions that timed out.
    timeout_count: AtomicU64,
    /// `acquire` calls currently in flight.
    waiting_count: AtomicUsize,
    /// Set once `shutdown` has been called.
    shutdown: Mutex<bool>,
    /// Effective capacity; starts at `config.max_connections` and only grows via `resize`.
    max_size: AtomicUsize,
}

impl<C> fmt::Debug for ConnectionPool<C>
where
    C: Connection + 'static,
{
    /// Prints the configuration and a relaxed snapshot of the counters; the
    /// factory, queue and synchronisation primitives are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ConnectionPool")
            .field("config", &self.config)
            .field("max_size", &self.max_size.load(Ordering::Relaxed))
            .field("active_count", &self.active_count.load(Ordering::Relaxed))
            .field("total_created", &self.total_created.load(Ordering::Relaxed))
            .field("acquire_count", &self.acquire_count.load(Ordering::Relaxed))
            .field("release_count", &self.release_count.load(Ordering::Relaxed))
            .field("timeout_count", &self.timeout_count.load(Ordering::Relaxed))
            .finish_non_exhaustive()
    }
}

impl<C> ConnectionPool<C>
where
    C: Connection + 'static,
{
    /// Creates an empty pool; connections are created lazily by `acquire`.
    ///
    /// `factory` is invoked every time a new connection is needed.
    ///
    /// # Panics
    ///
    /// Panics if `config` does not pass [`PoolConfig::validate`].
    #[must_use]
    pub fn new<F, Fut>(config: PoolConfig, factory: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<C, PoolError>> + Send + 'static,
    {
        config.validate().expect("Invalid pool configuration");
        let factory: ConnectionFactory<C> = Arc::new(move || Box::pin(factory()));
        let capacity = config.max_connections;
        Self {
            semaphore: Semaphore::new(capacity),
            max_size: AtomicUsize::new(capacity),
            config,
            factory,
            pool: Mutex::new(VecDeque::new()),
            notify: Notify::new(),
            active_count: AtomicUsize::new(0),
            total_created: AtomicUsize::new(0),
            acquire_count: AtomicU64::new(0),
            release_count: AtomicU64::new(0),
            timeout_count: AtomicU64::new(0),
            waiting_count: AtomicUsize::new(0),
            shutdown: Mutex::new(false),
        }
    }

    /// Checks a connection out of the pool, reusing an idle one when possible and
    /// creating a new one while capacity remains; otherwise waits up to
    /// `connection_timeout` for a connection to be released.
    ///
    /// Dropping the returned future is safe: neither the in-flight counter nor
    /// pool capacity is leaked.
    ///
    /// # Errors
    ///
    /// * [`PoolError::PoolShutdown`] if the pool is, or becomes, shut down.
    /// * [`PoolError::AcquireTimeout`] if nothing became available in time.
    /// * Whatever error the factory returned when creating a connection failed.
    pub async fn acquire(&self) -> Result<PooledConnection<'_, C>, PoolError> {
        /// Decrements the in-flight counter on every exit path, including
        /// cancellation of the surrounding future.
        struct InFlight<'a>(&'a AtomicUsize);

        impl Drop for InFlight<'_> {
            fn drop(&mut self) {
                self.0.fetch_sub(1, Ordering::Relaxed);
            }
        }

        self.waiting_count.fetch_add(1, Ordering::Relaxed);
        let _in_flight = InFlight(&self.waiting_count);
        self.acquire_inner().await
    }

    /// Retry loop behind [`ConnectionPool::acquire`].
    async fn acquire_inner(&self) -> Result<PooledConnection<'_, C>, PoolError> {
        // `None` means the deadline is not representable (timeout too large);
        // in that case every wait simply uses the full configured timeout.
        let deadline = Instant::now().checked_add(self.config.connection_timeout);
        loop {
            // Register for wake-ups *before* inspecting any state so that a
            // `notify_waiters()` issued while we are looking is not missed.
            let notified = self.notify.notified();

            if *self.shutdown.lock().await {
                return Err(PoolError::PoolShutdown);
            }
            if let Some(conn) = self.try_get_pooled_connection().await {
                return Ok(self.check_out(conn));
            }
            if let Ok(permit) = self.semaphore.try_acquire() {
                // Keep the permit alive across the await: if the factory fails
                // or this future is cancelled, dropping the permit returns the
                // capacity automatically.
                let conn = self.create_connection().await?;
                permit.forget();
                return Ok(self.check_out(conn));
            }

            let remaining = deadline.map_or(self.config.connection_timeout, |at| {
                at.saturating_duration_since(Instant::now())
            });
            if remaining.is_zero() || timeout(remaining, notified).await.is_err() {
                return Err(self.timeout_error());
            }
        }
    }

    /// Records a successful acquisition and wraps `connection` in its guard.
    fn check_out(&self, connection: C) -> PooledConnection<'_, C> {
        self.active_count.fetch_add(1, Ordering::Relaxed);
        self.acquire_count.fetch_add(1, Ordering::Relaxed);
        PooledConnection {
            pool: self,
            connection: Some(connection),
        }
    }

    /// Records a timed-out acquisition and builds the matching error.
    fn timeout_error(&self) -> PoolError {
        self.timeout_count.fetch_add(1, Ordering::Relaxed);
        // Saturate instead of truncating if the timeout exceeds u64 milliseconds.
        let millis =
            u64::try_from(self.config.connection_timeout.as_millis()).unwrap_or(u64::MAX);
        PoolError::AcquireTimeout(millis)
    }

    /// Accounts for one discarded connection: one fewer exists, one more may be created.
    fn release_capacity(&self) {
        self.total_created.fetch_sub(1, Ordering::Relaxed);
        self.semaphore.add_permits(1);
    }

    /// Pops idle connections until one is usable, discarding stale, unhealthy
    /// (when `test_on_acquire` is set) and non-resettable ones along the way.
    async fn try_get_pooled_connection(&self) -> Option<C> {
        let mut idle = self.pool.lock().await;
        while let Some(mut meta) = idle.pop_front() {
            let stale = meta.is_expired(self.config.max_lifetime)
                || meta.is_idle_too_long(self.config.idle_timeout);
            // Short-circuit keeps the health probe from running on stale connections.
            if stale || (self.config.test_on_acquire && !meta.connection.is_healthy()) {
                self.release_capacity();
                continue;
            }
            if let Err(e) = meta.connection.reset() {
                self.release_capacity();
                tracing::warn!("Connection reset failed: {e}");
                continue;
            }
            return Some(meta.connection);
        }
        None
    }

    /// Invokes the factory and counts the new connection on success.
    async fn create_connection(&self) -> Result<C, PoolError> {
        let conn = (self.factory)().await?;
        self.total_created.fetch_add(1, Ordering::Relaxed);
        Ok(conn)
    }

    /// Returns a snapshot of the pool's counters.
    ///
    /// The values are read individually with relaxed ordering, so under
    /// concurrent use they may not be mutually consistent.
    #[must_use]
    pub fn stats(&self) -> PoolStats {
        let active = self.active_count.load(Ordering::Relaxed);
        let total = self.total_created.load(Ordering::Relaxed);
        let idle = total.saturating_sub(active);
        PoolStats {
            active_connections: active,
            idle_connections: idle,
            total_connections: total,
            waiting_requests: self.waiting_count.load(Ordering::Relaxed),
            acquire_count: self.acquire_count.load(Ordering::Relaxed),
            release_count: self.release_count.load(Ordering::Relaxed),
            timeout_count: self.timeout_count.load(Ordering::Relaxed),
        }
    }

    /// Grows the pool's capacity to `new_size` connections.
    ///
    /// Growth is measured against the current effective capacity, so calling
    /// this repeatedly with the same value adds capacity only once. Requests to
    /// shrink the pool, or to resize it to zero, are logged and ignored.
    pub fn resize(&self, new_size: usize) {
        if new_size == 0 {
            tracing::warn!("Cannot resize pool to 0, ignoring");
            return;
        }
        // `fetch_max` makes concurrent resizes add each permit exactly once.
        let previous = self.max_size.fetch_max(new_size, Ordering::AcqRel);
        if new_size > previous {
            self.semaphore.add_permits(new_size - previous);
            // Blocked acquirers can now create connections; let them retry.
            self.notify.notify_waiters();
        } else if new_size < previous {
            tracing::warn!(
                "Shrinking the pool from {previous} to {new_size} is not supported, ignoring"
            );
        }
    }

    /// Discards every idle connection that is unhealthy, expired or idle for too long.
    pub async fn health_check(&self) {
        let mut idle = self.pool.lock().await;
        idle.retain(|meta| {
            let keep = meta.connection.is_healthy()
                && !meta.is_expired(self.config.max_lifetime)
                && !meta.is_idle_too_long(self.config.idle_timeout);
            if !keep {
                self.release_capacity();
            }
            keep
        });
    }

    /// Marks the pool as shut down, drops all idle connections and wakes every
    /// blocked `acquire`, which then fails with [`PoolError::PoolShutdown`].
    ///
    /// Connections that are checked out at this point are not touched.
    pub async fn shutdown(&self) {
        *self.shutdown.lock().await = true;
        let drained = {
            let mut idle = self.pool.lock().await;
            let count = idle.len();
            idle.clear();
            count
        };
        self.total_created.fetch_sub(drained, Ordering::Relaxed);
        self.semaphore.add_permits(drained);
        self.notify.notify_waiters();
    }
}
/// Guard around a connection checked out of a `ConnectionPool`.
///
/// Dereferences to the connection. When the guard is dropped the connection is
/// handed back to the pool it borrows from.
pub struct PooledConnection<'a, C>
where
C: Connection + 'static,
{
// Pool the connection was checked out of; updated when the guard is dropped or taken.
pool: &'a ConnectionPool<C>,
// `None` only after the connection has been moved out (by `take` or by `Drop`).
connection: Option<C>,
}
impl<C> fmt::Debug for PooledConnection<'_, C>
where
    C: Connection + fmt::Debug,
{
    /// Formats the guard showing only the wrapped connection; the pool
    /// reference is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = f.debug_struct("PooledConnection");
        repr.field("connection", &self.connection);
        repr.finish()
    }
}
impl<C> PooledConnection<'_, C>
where
    C: Connection,
{
    /// Borrows the wrapped connection, or `None` if it has already been moved out.
    #[must_use]
    pub fn get(&self) -> Option<&C> {
        self.connection.as_ref()
    }

    /// Mutably borrows the wrapped connection, or `None` if it has already been moved out.
    pub fn get_mut(&mut self) -> Option<&mut C> {
        self.connection.as_mut()
    }

    /// Detaches the connection from the pool and hands ownership to the caller.
    ///
    /// The pool stops counting the connection and regains the capacity to create
    /// a replacement. The release counter is not incremented.
    pub fn take(mut self) -> Option<C> {
        let conn = self.connection.take()?;
        let pool = self.pool;
        pool.active_count.fetch_sub(1, Ordering::Relaxed);
        pool.total_created.fetch_sub(1, Ordering::Relaxed);
        pool.semaphore.add_permits(1);
        // A task blocked in `acquire` only re-checks capacity when notified;
        // without this wake-up it would sleep until its timeout even though a
        // permit just became available.
        pool.notify.notify_one();
        Some(conn)
    }
}
impl<C> std::ops::Deref for PooledConnection<'_, C>
where
    C: Connection,
{
    type Target = C;

    /// Gives shared access to the wrapped connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been moved out of the guard.
    fn deref(&self) -> &Self::Target {
        match self.connection.as_ref() {
            Some(inner) => inner,
            None => panic!("Connection has been taken"),
        }
    }
}
impl<C> std::ops::DerefMut for PooledConnection<'_, C>
where
    C: Connection,
{
    /// Gives exclusive access to the wrapped connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been moved out of the guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self.connection.as_mut() {
            Some(inner) => inner,
            None => panic!("Connection has been taken"),
        }
    }
}
impl<C> Drop for PooledConnection<'_, C>
where
    C: Connection + 'static,
{
    /// Hands the connection back to the pool.
    ///
    /// `drop` cannot await, so the idle queue is only locked with `try_lock`.
    /// The connection is discarded (and its capacity freed) instead of being
    /// queued when the lock is contended, when the idle queue is already at
    /// `max_connections`, or when the pool is observed to be shut down. In every
    /// case one blocked `acquire` is woken so it can pick up either the idle
    /// connection or the freed capacity.
    fn drop(&mut self) {
        if let Some(conn) = self.connection.take() {
            let pool = self.pool;
            pool.active_count.fetch_sub(1, Ordering::Relaxed);
            pool.release_count.fetch_add(1, Ordering::Relaxed);

            // Best effort: if the flag's lock is contended, assume the pool is
            // still running and fall back to the normal return path.
            let shut_down = pool.shutdown.try_lock().map_or(false, |flag| *flag);

            let rejected = if shut_down {
                Some(conn)
            } else {
                match pool.pool.try_lock() {
                    Ok(mut idle) if idle.len() < pool.config.max_connections => {
                        // NOTE(review): wrapping in a fresh meta resets
                        // `created_at`, so `max_lifetime` is effectively measured
                        // from the last return rather than from creation.
                        idle.push_back(PooledConnectionMeta::new(conn));
                        None
                    }
                    _ => Some(conn),
                }
            };

            if let Some(conn) = rejected {
                drop(conn);
                pool.total_created.fetch_sub(1, Ordering::Relaxed);
                pool.semaphore.add_permits(1);
            }
            // Wake a waiter on every path: discarding frees a permit, which is
            // just as useful to a blocked acquirer as an idle connection.
            pool.notify.notify_one();
        }
    }
}
/// In-memory `Connection` implementation with a switchable health flag,
/// used by this file's tests.
#[derive(Debug, Clone)]
pub struct MockConnection {
    /// Value reported by `is_healthy`.
    pub healthy: bool,
    /// How many times `reset` has been called.
    pub reset_count: usize,
    /// Identifier chosen by whoever created the connection.
    pub id: u64,
}

impl MockConnection {
    /// Shared constructor: a mock with the given id and health, never reset.
    fn with_health(id: u64, healthy: bool) -> Self {
        Self {
            healthy,
            reset_count: 0,
            id,
        }
    }

    /// Creates a healthy mock connection.
    #[must_use]
    pub fn new(id: u64) -> Self {
        Self::with_health(id, true)
    }

    /// Creates a mock connection that reports itself as unhealthy.
    #[must_use]
    pub fn unhealthy(id: u64) -> Self {
        Self::with_health(id, false)
    }
}

impl Connection for MockConnection {
    /// Returns the current value of the `healthy` flag.
    fn is_healthy(&self) -> bool {
        self.healthy
    }

    /// Counts the call and always succeeds.
    fn reset(&mut self) -> Result<(), ConnectionError> {
        self.reset_count += 1;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Boxed future type produced by the mock factories below.
    type MockFuture = Pin<Box<dyn Future<Output = Result<MockConnection, PoolError>> + Send>>;

    /// Builds a factory whose connections are produced by `make` and numbered
    /// sequentially from 0. Each factory owns its counter, so tests cannot
    /// influence one another.
    fn mock_factory_with(make: fn(u64) -> MockConnection) -> impl Fn() -> MockFuture {
        let next_id = AtomicU64::new(0);
        move || {
            let id = next_id.fetch_add(1, Ordering::Relaxed);
            Box::pin(async move { Ok(make(id)) })
        }
    }

    /// Factory producing healthy mock connections.
    fn create_mock_factory() -> impl Fn() -> MockFuture {
        mock_factory_with(MockConnection::new)
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let stats = pool.stats();
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.active_connections, 0);
    }

    #[tokio::test]
    async fn test_basic_acquire_release() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let conn = pool.acquire().await.unwrap();
        assert!(conn.is_healthy());
        let stats = pool.stats();
        assert_eq!(stats.active_connections, 1);
        assert_eq!(stats.acquire_count, 1);
        drop(conn);
        let stats = pool.stats();
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.release_count, 1);
    }

    #[tokio::test]
    async fn test_connection_reuse() {
        let config = PoolConfig::default().test_on_acquire(false);
        let pool = ConnectionPool::new(config, create_mock_factory());
        let conn1 = pool.acquire().await.unwrap();
        let id1 = conn1.id;
        drop(conn1);
        tokio::task::yield_now().await;
        let conn2 = pool.acquire().await.unwrap();
        let id2 = conn2.id;
        assert_eq!(id1, id2);
    }

    #[tokio::test]
    async fn test_pool_exhaustion_timeout() {
        let config = PoolConfig::default()
            .max_connections(1)
            .connection_timeout(Duration::from_millis(50));
        let pool = ConnectionPool::new(config, create_mock_factory());
        let conn1 = pool.acquire().await.unwrap();
        let result = pool.acquire().await;
        assert!(matches!(result, Err(PoolError::AcquireTimeout(_))));
        let stats = pool.stats();
        assert_eq!(stats.timeout_count, 1);
        drop(conn1);
    }

    #[tokio::test]
    async fn test_connection_returned_on_drop() {
        let config = PoolConfig::default()
            .max_connections(1)
            .connection_timeout(Duration::from_millis(100));
        let pool = Arc::new(ConnectionPool::new(config, create_mock_factory()));
        let pool_clone = Arc::clone(&pool);
        let handle = tokio::spawn(async move {
            let _conn = pool_clone.acquire().await.unwrap();
            sleep(Duration::from_millis(20)).await;
        });
        sleep(Duration::from_millis(50)).await;
        let conn = pool.acquire().await.unwrap();
        assert!(conn.is_healthy());
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_health_check() {
        let config = PoolConfig::default().test_on_acquire(false);
        let pool = ConnectionPool::new(config, mock_factory_with(MockConnection::unhealthy));
        let conn = pool.acquire().await.unwrap();
        drop(conn);
        tokio::task::yield_now().await;
        pool.health_check().await;
        let stats = pool.stats();
        assert_eq!(stats.idle_connections, 0);
        assert_eq!(stats.total_connections, 0);
    }

    #[tokio::test]
    async fn test_test_on_acquire() {
        let config = PoolConfig::default().test_on_acquire(true);
        let created = Arc::new(AtomicU64::new(0));
        let counter = Arc::clone(&created);
        let pool = ConnectionPool::new(config, move || {
            let id = counter.fetch_add(1, Ordering::Relaxed);
            Box::pin(async move { Ok(MockConnection::new(id)) })
        });

        // Return a connection that has gone bad while checked out.
        let mut first = pool.acquire().await.unwrap();
        let first_id = first.id;
        first.healthy = false;
        drop(first);

        // The pool must reject it on reuse and create a replacement.
        let second = pool.acquire().await.unwrap();
        assert!(second.is_healthy());
        assert_ne!(second.id, first_id);
        assert_eq!(created.load(Ordering::Relaxed), 2);
        assert_eq!(pool.stats().total_connections, 1);
    }

    #[tokio::test]
    async fn test_pool_stats() {
        let config = PoolConfig::default().max_connections(5);
        let pool = ConnectionPool::new(config, create_mock_factory());
        let conn1 = pool.acquire().await.unwrap();
        let conn2 = pool.acquire().await.unwrap();
        let conn3 = pool.acquire().await.unwrap();
        let stats = pool.stats();
        assert_eq!(stats.active_connections, 3);
        assert_eq!(stats.acquire_count, 3);
        drop(conn1);
        drop(conn2);
        let stats = pool.stats();
        assert_eq!(stats.active_connections, 1);
        assert_eq!(stats.release_count, 2);
        drop(conn3);
    }

    #[test]
    fn test_config_validation() {
        let config = PoolConfig::default().min_connections(10).max_connections(5);
        assert!(matches!(config.validate(), Err(PoolError::InvalidConfig(_))));
        let config = PoolConfig::default().max_connections(0);
        assert!(matches!(config.validate(), Err(PoolError::InvalidConfig(_))));
    }

    #[tokio::test]
    async fn test_pool_shutdown() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let conn = pool.acquire().await.unwrap();
        drop(conn);
        pool.shutdown().await;
        let result = pool.acquire().await;
        assert!(matches!(result, Err(PoolError::PoolShutdown)));
    }

    #[tokio::test]
    async fn test_resize_pool() {
        let config = PoolConfig::default()
            .max_connections(2)
            .connection_timeout(Duration::from_millis(50));
        let pool = ConnectionPool::new(config, create_mock_factory());
        let conn1 = pool.acquire().await.unwrap();
        let conn2 = pool.acquire().await.unwrap();

        // At capacity: a third acquisition must time out.
        assert!(matches!(
            pool.acquire().await,
            Err(PoolError::AcquireTimeout(_))
        ));

        // After growing the pool the very same request succeeds.
        pool.resize(3);
        let conn3 = pool.acquire().await.unwrap();
        assert_eq!(pool.stats().active_connections, 3);

        // The new limit is enforced as well.
        assert!(matches!(
            pool.acquire().await,
            Err(PoolError::AcquireTimeout(_))
        ));

        drop(conn1);
        drop(conn2);
        drop(conn3);
    }

    #[tokio::test]
    async fn test_connection_lifetime_expiration() {
        let config = PoolConfig::default()
            .max_lifetime(Duration::from_millis(50))
            .test_on_acquire(false);
        let pool = ConnectionPool::new(config, create_mock_factory());
        let conn1 = pool.acquire().await.unwrap();
        let id1 = conn1.id;
        drop(conn1);
        sleep(Duration::from_millis(60)).await;
        let conn2 = pool.acquire().await.unwrap();
        let id2 = conn2.id;
        assert_ne!(id1, id2);
        drop(conn2);
    }

    #[tokio::test]
    async fn test_connection_idle_timeout() {
        let config = PoolConfig::default()
            .idle_timeout(Duration::from_millis(50))
            .test_on_acquire(false);
        let pool = ConnectionPool::new(config, create_mock_factory());
        let conn1 = pool.acquire().await.unwrap();
        let id1 = conn1.id;
        drop(conn1);
        sleep(Duration::from_millis(60)).await;
        let conn2 = pool.acquire().await.unwrap();
        let id2 = conn2.id;
        assert_ne!(id1, id2);
        drop(conn2);
    }

    #[tokio::test]
    async fn test_concurrent_acquire() {
        let config = PoolConfig::default().max_connections(5);
        let pool = Arc::new(ConnectionPool::new(config, create_mock_factory()));
        let mut handles = Vec::new();
        for _ in 0..10 {
            let pool = Arc::clone(&pool);
            handles.push(tokio::spawn(async move {
                let conn = pool.acquire().await.unwrap();
                sleep(Duration::from_millis(10)).await;
                drop(conn);
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
        let stats = pool.stats();
        assert_eq!(stats.acquire_count, 10);
        assert_eq!(stats.release_count, 10);
    }

    #[tokio::test]
    async fn test_pooled_connection_deref() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let conn = pool.acquire().await.unwrap();
        assert!(conn.is_healthy());
        drop(conn);
    }

    #[tokio::test]
    async fn test_pooled_connection_deref_mut() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let mut conn = pool.acquire().await.unwrap();
        conn.healthy = false;
        assert!(!conn.is_healthy());
        drop(conn);
    }

    #[tokio::test]
    async fn test_pooled_connection_take() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let conn = pool.acquire().await.unwrap();
        let taken = conn.take();
        assert!(taken.is_some());
        let stats = pool.stats();
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.total_connections, 0);
    }

    #[test]
    fn test_mock_connection_reset() {
        let mut conn = MockConnection::new(1);
        assert_eq!(conn.reset_count, 0);
        conn.reset().unwrap();
        assert_eq!(conn.reset_count, 1);
        conn.reset().unwrap();
        assert_eq!(conn.reset_count, 2);
    }

    #[test]
    fn test_config_builder() {
        let config = PoolConfig::new()
            .min_connections(2)
            .max_connections(20)
            .connection_timeout(Duration::from_secs(60))
            .idle_timeout(Duration::from_secs(300))
            .max_lifetime(Duration::from_secs(1800))
            .test_on_acquire(false);
        assert_eq!(config.min_connections, 2);
        assert_eq!(config.max_connections, 20);
        assert_eq!(config.connection_timeout, Duration::from_secs(60));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.max_lifetime, Duration::from_secs(1800));
        assert!(!config.test_on_acquire);
    }

    #[test]
    fn test_pool_error_display() {
        let err = PoolError::ConnectionCreationFailed("test".to_string());
        assert!(err.to_string().contains("Connection creation failed"));
        let err = PoolError::AcquireTimeout(1000);
        assert!(err.to_string().contains("Timeout"));
        assert!(err.to_string().contains("1000"));
        let err = PoolError::HealthCheckFailed;
        assert!(err.to_string().contains("health check failed"));
        let err = PoolError::PoolShutdown;
        assert!(err.to_string().contains("shut down"));
    }

    #[test]
    fn test_connection_error_display() {
        let err = ConnectionError::Generic("test".to_string());
        assert!(err.to_string().contains("test"));
        let err = ConnectionError::Closed;
        assert!(err.to_string().contains("closed"));
        let err = ConnectionError::Timeout;
        assert!(err.to_string().contains("timed out"));
    }

    #[tokio::test]
    async fn test_waiting_requests_count() {
        let config = PoolConfig::default()
            .max_connections(1)
            .connection_timeout(Duration::from_millis(200));
        let pool = Arc::new(ConnectionPool::new(config, create_mock_factory()));
        let conn = pool.acquire().await.unwrap();
        let pool_clone = Arc::clone(&pool);
        let handle = tokio::spawn(async move {
            let _ = pool_clone.acquire().await;
        });

        // Give the spawned task time to block inside `acquire`.
        sleep(Duration::from_millis(10)).await;
        assert_eq!(pool.stats().waiting_requests, 1);

        drop(conn);
        handle.await.unwrap();
        assert_eq!(pool.stats().waiting_requests, 0);
    }

    #[tokio::test]
    async fn test_multiple_acquire_release_cycles() {
        let config = PoolConfig::default()
            .max_connections(3)
            .test_on_acquire(false);
        let pool = ConnectionPool::new(config, create_mock_factory());
        for _ in 0..10 {
            let c1 = pool.acquire().await.unwrap();
            let c2 = pool.acquire().await.unwrap();
            drop(c1);
            drop(c2);
        }
        let stats = pool.stats();
        assert_eq!(stats.acquire_count, 20);
        assert_eq!(stats.release_count, 20);
    }

    #[tokio::test]
    async fn test_pool_debug() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let debug_str = format!("{pool:?}");
        assert!(debug_str.contains("ConnectionPool"));
        assert!(debug_str.contains("config"));
    }

    #[tokio::test]
    async fn test_pooled_connection_get() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let mut conn = pool.acquire().await.unwrap();
        let ref_conn = conn.get();
        assert!(ref_conn.is_some());
        assert!(ref_conn.unwrap().is_healthy());
        let ref_mut_conn = conn.get_mut();
        assert!(ref_mut_conn.is_some());
        ref_mut_conn.unwrap().healthy = false;
        assert!(!conn.is_healthy());
        drop(conn);
    }

    #[tokio::test]
    async fn test_factory_error() {
        let pool = ConnectionPool::<MockConnection>::new(PoolConfig::default(), || {
            Box::pin(async {
                Err(PoolError::ConnectionCreationFailed(
                    "Factory error".to_string(),
                ))
            })
        });
        let result = pool.acquire().await;
        assert!(matches!(
            result,
            Err(PoolError::ConnectionCreationFailed(_))
        ));
        // A failed creation must not be counted as an existing connection.
        assert_eq!(pool.stats().total_connections, 0);
    }

    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.idle_connections, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(stats.waiting_requests, 0);
        assert_eq!(stats.acquire_count, 0);
        assert_eq!(stats.release_count, 0);
        assert_eq!(stats.timeout_count, 0);
    }

    #[tokio::test]
    async fn test_pooled_connection_debug() {
        let pool = ConnectionPool::new(PoolConfig::default(), create_mock_factory());
        let conn = pool.acquire().await.unwrap();
        let debug_str = format!("{conn:?}");
        assert!(debug_str.contains("PooledConnection"));
        drop(conn);
    }

    #[tokio::test]
    async fn test_health_check_with_expired() {
        let config = PoolConfig::default()
            .max_lifetime(Duration::from_millis(10))
            .test_on_acquire(false);
        let pool = ConnectionPool::new(config, create_mock_factory());
        let conn = pool.acquire().await.unwrap();
        drop(conn);
        sleep(Duration::from_millis(20)).await;
        pool.health_check().await;
        let stats = pool.stats();
        assert_eq!(stats.idle_connections, 0);
    }

    #[tokio::test]
    async fn test_resize_to_zero() {
        let config = PoolConfig::default().max_connections(2);
        let pool = ConnectionPool::new(config, create_mock_factory());
        // Resizing to zero is ignored, so the pool keeps working.
        pool.resize(0);
        let conn = pool.acquire().await.unwrap();
        assert!(conn.is_healthy());
        drop(conn);
    }
}