use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, Semaphore};
use crate::connection::{Connection, ConnectionConfig, ConnectionFactory};
use crate::error::{Error, Result};
/// Common interface for pools that lend out [`Connection`]s.
#[async_trait]
pub trait ConnectionPool: Send + Sync {
    /// Checks out a connection. It goes back to the pool when the returned
    /// [`PooledConnection`] is dropped.
    async fn get(&self) -> Result<PooledConnection>;

    /// Receives a connection back from a dropped [`PooledConnection`].
    ///
    /// `created_at` is the instant the underlying connection was opened.
    async fn return_connection(&self, conn: Box<dyn Connection>, created_at: Instant);

    /// Total number of connections the pool currently tracks.
    fn size(&self) -> usize;

    /// Number of idle connections.
    fn idle(&self) -> usize;

    /// Connections that are not idle: `size() - idle()`, clamped at zero.
    fn in_use(&self) -> usize {
        let total = self.size();
        let parked = self.idle();
        total.saturating_sub(parked)
    }

    /// Snapshot of the pool's usage counters.
    fn stats(&self) -> PoolStats;

    /// Shuts the pool down.
    async fn close(&self) -> Result<()>;
}
/// A connection checked out of a [`ConnectionPool`].
///
/// Dereferences to the underlying `dyn Connection`. When dropped, the connection is handed
/// back through [`ConnectionPool::return_connection`] on a spawned Tokio task.
pub struct PooledConnection {
    /// The borrowed connection; `None` only once `drop` has taken it.
    conn: Option<Box<dyn Connection>>,
    /// Pool that receives the connection back on drop.
    pool: Arc<dyn ConnectionPool>,
    /// When the underlying connection was opened (carried across reuse by the pool).
    created_at: Instant,
    /// When this handle was constructed, i.e. when the connection was checked out.
    borrowed_at: Instant,
}

impl PooledConnection {
    /// Wraps `conn`, treating it as both opened and borrowed right now.
    pub fn new(conn: Box<dyn Connection>, pool: Arc<dyn ConnectionPool>) -> Self {
        let now = Instant::now();
        Self {
            conn: Some(conn),
            pool,
            created_at: now,
            borrowed_at: now,
        }
    }

    /// Wraps a connection that was opened at `created_at`; the borrow starts now.
    pub fn with_created_at(
        conn: Box<dyn Connection>,
        pool: Arc<dyn ConnectionPool>,
        created_at: Instant,
    ) -> Self {
        Self {
            conn: Some(conn),
            pool,
            created_at,
            borrowed_at: Instant::now(),
        }
    }

    /// When the underlying connection was opened.
    pub fn created_at(&self) -> Instant {
        self.created_at
    }

    /// Time elapsed since the underlying connection was opened.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// `true` when [`age`](Self::age) strictly exceeds `max_lifetime`.
    pub fn is_expired(&self, max_lifetime: Duration) -> bool {
        self.age() > max_lifetime
    }

    /// When this handle was created.
    #[inline]
    pub fn borrowed_at(&self) -> Instant {
        self.borrowed_at
    }

    /// Time elapsed since this handle was created.
    #[inline]
    pub fn time_in_use(&self) -> Duration {
        self.borrowed_at.elapsed()
    }

    /// Borrows the underlying connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been taken. Only `drop` takes it.
    pub fn connection(&self) -> &dyn Connection {
        &**self
    }

    /// Mutably borrows the underlying connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection has already been taken. Only `drop` takes it.
    pub fn connection_mut(&mut self) -> &mut dyn Connection {
        &mut **self
    }
}

impl std::ops::Deref for PooledConnection {
    type Target = dyn Connection;

    fn deref(&self) -> &Self::Target {
        self.conn.as_deref().expect("connection already returned")
    }
}

impl std::ops::DerefMut for PooledConnection {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.conn
            .as_deref_mut()
            .expect("connection already returned")
    }
}

impl Drop for PooledConnection {
    /// Hands the connection back to its pool on a spawned Tokio task.
    fn drop(&mut self) {
        if let Some(conn) = self.conn.take() {
            // `tokio::spawn` panics when no runtime is current. A panic inside `drop` can abort
            // the process if it happens while already unwinding, so probe for a runtime first.
            match tokio::runtime::Handle::try_current() {
                Ok(runtime) => {
                    let pool = Arc::clone(&self.pool);
                    let created_at = self.created_at;
                    runtime.spawn(async move {
                        pool.return_connection(conn, created_at).await;
                    });
                }
                Err(_) => {
                    // No runtime is available to run the async hand-back, so the connection is
                    // simply dropped. NOTE(review): `return_connection` never runs for it, so
                    // the pool's accounting for this connection (e.g. its slot) is not restored.
                    drop(conn);
                }
            }
        }
    }
}
/// Settings for a connection pool.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// How each individual connection is opened.
    pub connection: ConnectionConfig,
    /// Connections opened eagerly when the pool is created.
    pub min_size: usize,
    /// Maximum number of connections checked out at the same time.
    pub max_size: usize,
    /// How long `get` waits for a free slot before failing.
    pub acquire_timeout: Duration,
    /// Connections opened longer ago than this are discarded instead of reused.
    pub max_lifetime: Duration,
    /// Connections left idle longer than this are discarded instead of reused.
    pub idle_timeout: Duration,
    /// Interval intended for periodic health checks.
    /// NOTE(review): `SimpleConnectionPool` does not read this value.
    pub health_check_interval: Duration,
    /// Run `is_valid` on an idle connection before handing it out again.
    pub test_on_borrow: bool,
    /// Run `is_valid` on a connection when it is returned to the pool.
    pub test_on_return: bool,
}

impl Default for PoolConfig {
    /// Defaults:
    /// - 1–10 connections, with a 30 s acquire timeout.
    /// - 30 min max lifetime and 10 min idle timeout.
    /// - 30 s health-check interval.
    /// - Validation on borrow only.
    fn default() -> Self {
        Self {
            connection: ConnectionConfig::default(),
            min_size: 1,
            max_size: 10,
            acquire_timeout: Duration::from_secs(30),
            // 30 minutes.
            max_lifetime: Duration::from_secs(1800),
            // 10 minutes.
            idle_timeout: Duration::from_secs(600),
            health_check_interval: Duration::from_secs(30),
            test_on_borrow: true,
            test_on_return: false,
        }
    }
}

impl PoolConfig {
    /// Default settings for connections to `url`.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            connection: ConnectionConfig::new(url),
            ..Default::default()
        }
    }

    /// Sets [`min_size`](Self::min_size).
    pub fn with_min_size(mut self, size: usize) -> Self {
        self.min_size = size;
        self
    }

    /// Sets [`max_size`](Self::max_size).
    pub fn with_max_size(mut self, size: usize) -> Self {
        self.max_size = size;
        self
    }

    /// Sets [`acquire_timeout`](Self::acquire_timeout).
    pub fn with_acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Sets [`max_lifetime`](Self::max_lifetime).
    pub fn with_max_lifetime(mut self, lifetime: Duration) -> Self {
        self.max_lifetime = lifetime;
        self
    }

    /// Sets [`idle_timeout`](Self::idle_timeout).
    pub fn with_idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Sets [`health_check_interval`](Self::health_check_interval).
    pub fn with_health_check_interval(mut self, interval: Duration) -> Self {
        self.health_check_interval = interval;
        self
    }

    /// Sets [`test_on_borrow`](Self::test_on_borrow).
    pub fn with_test_on_borrow(mut self, test: bool) -> Self {
        self.test_on_borrow = test;
        self
    }

    /// Sets [`test_on_return`](Self::test_on_return).
    pub fn with_test_on_return(mut self, test: bool) -> Self {
        self.test_on_return = test;
        self
    }
}
/// Point-in-time copy of a pool's usage counters, with derived ratios.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Connections opened.
    pub connections_created: u64,
    /// Connections closed or discarded.
    pub connections_closed: u64,
    /// Successful checkouts.
    pub acquisitions: u64,
    /// Checkouts that timed out waiting for a free slot.
    pub exhausted_count: u64,
    /// Sum of checkout wait times, in milliseconds.
    pub total_wait_time_ms: u64,
    /// Connections that failed a validity check.
    pub health_check_failures: u64,
    /// Connections retired for exceeding their maximum lifetime.
    pub lifetime_expired_count: u64,
    /// Connections retired for sitting idle too long.
    pub idle_expired_count: u64,
    /// Checkouts served by an existing idle connection.
    pub reused_count: u64,
    /// Checkouts that had to open a new connection.
    pub fresh_count: u64,
}

impl PoolStats {
    /// `numerator / denominator` as a float, or `0.0` when the denominator is zero.
    fn ratio(numerator: u64, denominator: u64) -> f64 {
        match denominator {
            0 => 0.0,
            total => numerator as f64 / total as f64,
        }
    }

    /// Connections retired for either lifetime or idle expiry.
    #[inline]
    pub fn recycled_total(&self) -> u64 {
        let lifetime = self.lifetime_expired_count;
        let idle = self.idle_expired_count;
        lifetime + idle
    }

    /// Fraction of checkouts served from the idle list (`0.0` with no checkouts).
    #[inline]
    pub fn reuse_rate(&self) -> f64 {
        Self::ratio(self.reused_count, self.acquisitions)
    }

    /// Mean checkout wait in milliseconds (`0.0` with no checkouts).
    #[inline]
    pub fn avg_wait_time_ms(&self) -> f64 {
        Self::ratio(self.total_wait_time_ms, self.acquisitions)
    }

    /// Failed checks divided by `acquisitions + health_check_failures` (`0.0` when both are zero).
    #[inline]
    pub fn health_failure_rate(&self) -> f64 {
        let checks = self.acquisitions + self.health_check_failures;
        Self::ratio(self.health_check_failures, checks)
    }

    /// Opened minus closed connections, clamped at zero.
    #[inline]
    pub fn active_connections(&self) -> u64 {
        let opened = self.connections_created;
        opened.saturating_sub(self.connections_closed)
    }
}
/// Lock-free usage counters. Every update and read uses `Ordering::Relaxed`, so a
/// [`snapshot`](Self::snapshot) is not atomic across counters.
#[derive(Debug, Default)]
pub struct AtomicPoolStats {
    /// Connections opened.
    pub connections_created: AtomicU64,
    /// Connections closed or discarded.
    pub connections_closed: AtomicU64,
    /// Successful checkouts.
    pub acquisitions: AtomicU64,
    /// Checkouts that timed out waiting for a free slot.
    pub exhausted_count: AtomicU64,
    /// Sum of checkout wait times, in milliseconds.
    pub total_wait_time_ms: AtomicU64,
    /// Connections that failed a validity check.
    pub health_check_failures: AtomicU64,
    /// Connections retired for exceeding their maximum lifetime.
    pub lifetime_expired_count: AtomicU64,
    /// Connections retired for sitting idle too long.
    pub idle_expired_count: AtomicU64,
    /// Checkouts served by an existing idle connection.
    pub reused_count: AtomicU64,
    /// Checkouts that had to open a new connection.
    pub fresh_count: AtomicU64,
}

impl AtomicPoolStats {
    /// Creates counters that all start at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds `amount` to `counter` with relaxed ordering.
    fn bump(counter: &AtomicU64, amount: u64) {
        counter.fetch_add(amount, Ordering::Relaxed);
    }

    /// Counts one newly opened connection.
    pub fn record_created(&self) {
        Self::bump(&self.connections_created, 1);
    }

    /// Counts one closed or discarded connection.
    pub fn record_closed(&self) {
        Self::bump(&self.connections_closed, 1);
    }

    /// Counts one successful checkout that waited `wait_time_ms` milliseconds.
    pub fn record_acquisition(&self, wait_time_ms: u64) {
        Self::bump(&self.acquisitions, 1);
        Self::bump(&self.total_wait_time_ms, wait_time_ms);
    }

    /// Counts one checkout that timed out.
    pub fn record_exhausted(&self) {
        Self::bump(&self.exhausted_count, 1);
    }

    /// Counts one failed validity check.
    pub fn record_health_check_failure(&self) {
        Self::bump(&self.health_check_failures, 1);
    }

    /// Counts one connection retired for exceeding its lifetime.
    pub fn record_lifetime_expired(&self) {
        Self::bump(&self.lifetime_expired_count, 1);
    }

    /// Counts one connection retired for being idle too long.
    pub fn record_idle_expired(&self) {
        Self::bump(&self.idle_expired_count, 1);
    }

    /// Counts one checkout served from the idle list.
    #[inline]
    pub fn record_reused(&self) {
        Self::bump(&self.reused_count, 1);
    }

    /// Counts one checkout that opened a new connection.
    #[inline]
    pub fn record_fresh(&self) {
        Self::bump(&self.fresh_count, 1);
    }

    /// Copies every counter into a plain [`PoolStats`].
    pub fn snapshot(&self) -> PoolStats {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        PoolStats {
            connections_created: read(&self.connections_created),
            connections_closed: read(&self.connections_closed),
            acquisitions: read(&self.acquisitions),
            exhausted_count: read(&self.exhausted_count),
            total_wait_time_ms: read(&self.total_wait_time_ms),
            health_check_failures: read(&self.health_check_failures),
            lifetime_expired_count: read(&self.lifetime_expired_count),
            idle_expired_count: read(&self.idle_expired_count),
            reused_count: read(&self.reused_count),
            fresh_count: read(&self.fresh_count),
        }
    }

    /// Mean checkout wait in milliseconds, or `0.0` before the first checkout.
    pub fn avg_wait_time_ms(&self) -> f64 {
        match self.acquisitions.load(Ordering::Relaxed) {
            0 => 0.0,
            count => self.total_wait_time_ms.load(Ordering::Relaxed) as f64 / count as f64,
        }
    }
}
/// Fluent builder that produces a [`PoolConfig`].
pub struct PoolBuilder {
    config: PoolConfig,
}

impl PoolBuilder {
    /// Starts from the default configuration for `url`.
    pub fn new(url: impl Into<String>) -> Self {
        Self {
            config: PoolConfig::new(url),
        }
    }

    /// Sets `PoolConfig::min_size`.
    pub fn min_size(mut self, size: usize) -> Self {
        self.config.min_size = size;
        self
    }

    /// Sets `PoolConfig::max_size`.
    pub fn max_size(mut self, size: usize) -> Self {
        self.config.max_size = size;
        self
    }

    /// Sets `PoolConfig::acquire_timeout`.
    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
        self.config.acquire_timeout = timeout;
        self
    }

    /// Sets `PoolConfig::max_lifetime`.
    pub fn max_lifetime(mut self, lifetime: Duration) -> Self {
        self.config.max_lifetime = lifetime;
        self
    }

    /// Sets `PoolConfig::idle_timeout`.
    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
        self.config.idle_timeout = timeout;
        self
    }

    /// Sets `PoolConfig::health_check_interval`.
    pub fn health_check_interval(mut self, interval: Duration) -> Self {
        self.config.health_check_interval = interval;
        self
    }

    /// Sets `PoolConfig::test_on_borrow`.
    pub fn test_on_borrow(mut self, test: bool) -> Self {
        self.config.test_on_borrow = test;
        self
    }

    /// Sets `PoolConfig::test_on_return`.
    pub fn test_on_return(mut self, test: bool) -> Self {
        self.config.test_on_return = test;
        self
    }

    /// Finishes building and returns the configuration.
    pub fn config(self) -> PoolConfig {
        self.config
    }
}
/// A [`ConnectionPool`] that parks idle connections in a `Vec` and limits concurrent
/// checkouts with a semaphore sized to `max_size`.
pub struct SimpleConnectionPool {
    /// Configuration the pool was created with.
    config: PoolConfig,
    /// Opens new connections when no idle one can be reused.
    factory: Arc<dyn ConnectionFactory>,
    /// Parked connections. `get` pops from the end and `return_connection` pushes onto it.
    idle: Mutex<Vec<PoolEntry>>,
    /// One permit per checkout slot. `get` forgets the permit it acquires, and
    /// `return_connection` adds one back.
    semaphore: Semaphore,
    /// Connections opened and not yet discarded.
    total_connections: AtomicUsize,
    /// Usage counters exposed through `stats()`.
    stats: Arc<AtomicPoolStats>,
    /// Set by `close()`. Afterwards `get` fails and returned connections are closed.
    shutdown: std::sync::atomic::AtomicBool,
    /// Weak handle to the owning `Arc`, set once in `new`, so that `get` can give each
    /// `PooledConnection` a strong reference back to the pool.
    self_ref: tokio::sync::OnceCell<std::sync::Weak<Self>>,
}
/// Why an idle connection was retired instead of being reused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecycleReason {
    /// The connection was opened longer ago than `PoolConfig::max_lifetime`.
    LifetimeExpired,
    /// The connection sat idle for longer than `PoolConfig::idle_timeout`.
    IdleExpired,
}
/// An idle connection together with the timestamps used to decide whether it may be reused.
struct PoolEntry {
    /// The parked connection.
    conn: Box<dyn Connection>,
    /// When the connection was opened; compared against `max_lifetime`.
    created_at: Instant,
    /// When the connection was last placed in the idle list; compared against `idle_timeout`.
    last_used: Instant,
}
impl SimpleConnectionPool {
    /// Creates a pool and eagerly opens up to `min_size` connections.
    ///
    /// Warm-up is best effort. A connection that fails to open is skipped, and the pool
    /// starts smaller and opens connections on demand in `get`. At most `max_size`
    /// connections are opened here, even when `min_size` is larger.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; connection errors during warm-up are ignored.
    pub async fn new(config: PoolConfig, factory: Arc<dyn ConnectionFactory>) -> Result<Arc<Self>> {
        let pool = Arc::new(Self {
            semaphore: Semaphore::new(config.max_size),
            config: config.clone(),
            factory,
            idle: Mutex::new(Vec::with_capacity(config.max_size)),
            total_connections: AtomicUsize::new(0),
            stats: Arc::new(AtomicPoolStats::new()),
            shutdown: std::sync::atomic::AtomicBool::new(false),
            self_ref: tokio::sync::OnceCell::new(),
        });
        // Cannot fail: the cell was created empty just above.
        let _ = pool.self_ref.set(Arc::downgrade(&pool));

        // Pre-opening more than `max_size` connections would push `total_connections` past
        // the configured cap; the surplus could never all be checked out at once.
        let warm_target = config.min_size.min(config.max_size);
        let mut warmed = Vec::with_capacity(warm_target);
        for _ in 0..warm_target {
            if let Ok(conn) = pool.create_connection().await {
                let now = Instant::now();
                warmed.push(PoolEntry {
                    conn,
                    created_at: now,
                    last_used: now,
                });
            }
        }
        pool.idle.lock().await.extend(warmed);
        Ok(pool)
    }

    /// Shorthand for [`PoolBuilder::new`].
    pub fn builder(url: impl Into<String>) -> PoolBuilder {
        PoolBuilder::new(url)
    }

    /// Upgrades the weak self-reference.
    ///
    /// Returns `None` if it was never set or no strong reference remains.
    fn get_self_arc(&self) -> Option<Arc<Self>> {
        self.self_ref.get().and_then(|w| w.upgrade())
    }

    /// Opens a connection through the factory and counts it in `total_connections` and stats.
    ///
    /// # Errors
    ///
    /// Propagates the factory's error; nothing is counted in that case.
    async fn create_connection(&self) -> Result<Box<dyn Connection>> {
        let conn = self.factory.connect(&self.config.connection).await?;
        self.total_connections.fetch_add(1, Ordering::Release);
        self.stats.record_created();
        Ok(conn)
    }

    /// Runs `is_valid` when `test_on_borrow` is enabled; otherwise treats the connection as healthy.
    async fn validate_connection(&self, conn: &dyn Connection) -> bool {
        if self.config.test_on_borrow {
            conn.is_valid().await
        } else {
            true
        }
    }

    /// Decides whether an idle entry must be retired. Lifetime expiry is checked before idle expiry.
    fn should_recycle(&self, entry: &PoolEntry) -> Option<RecycleReason> {
        if entry.created_at.elapsed() > self.config.max_lifetime {
            Some(RecycleReason::LifetimeExpired)
        } else if entry.last_used.elapsed() > self.config.idle_timeout {
            Some(RecycleReason::IdleExpired)
        } else {
            None
        }
    }

    /// The configuration this pool was created with.
    pub fn config(&self) -> &PoolConfig {
        &self.config
    }
}
#[async_trait]
impl ConnectionPool for SimpleConnectionPool {
    /// Checks out a connection, reusing a healthy idle one or opening a fresh one.
    ///
    /// Idle entries past `max_lifetime`/`idle_timeout`, or failing the `test_on_borrow`
    /// check, are closed and discarded along the way.
    ///
    /// # Errors
    ///
    /// Returns `Error::PoolExhausted` in these cases:
    /// - the pool is shut down or has been dropped;
    /// - no slot frees up within `acquire_timeout`;
    /// - the pool is closed while waiting.
    ///
    /// Propagates the factory's error when a fresh connection cannot be opened.
    async fn get(&self) -> Result<PooledConnection> {
        if self.shutdown.load(Ordering::Acquire) {
            return Err(Error::PoolExhausted {
                message: "Pool is shut down".to_string(),
            });
        }
        // Resolve the owning `Arc` before taking a slot. Bailing out after the permit has been
        // forgotten would leak that slot for the lifetime of the pool.
        let pool_arc = self.get_self_arc().ok_or_else(|| Error::PoolExhausted {
            message: "Pool has been dropped".to_string(),
        })?;
        let start = Instant::now();
        let permit = tokio::time::timeout(self.config.acquire_timeout, self.semaphore.acquire())
            .await
            .map_err(|_| {
                self.stats.record_exhausted();
                Error::PoolExhausted {
                    message: format!(
                        "Timeout waiting for connection ({}ms)",
                        self.config.acquire_timeout.as_millis()
                    ),
                }
            })?
            .map_err(|_| Error::PoolExhausted {
                message: "Pool semaphore closed".to_string(),
            })?;
        let (conn, created_at, was_reused) = match self.take_reusable_idle().await {
            Some((conn, created_at)) => {
                self.stats.record_reused();
                (conn, created_at, true)
            }
            None => {
                // On failure `?` drops `permit`, which hands the slot back to the semaphore.
                let conn = self.create_connection().await?;
                self.stats.record_fresh();
                (conn, Instant::now(), false)
            }
        };
        let wait_ms = start.elapsed().as_millis() as u64;
        self.stats.record_acquisition(wait_ms);
        #[cfg(feature = "pool-trace")]
        tracing::trace!(
            reused = was_reused,
            wait_ms = wait_ms,
            age_ms = created_at.elapsed().as_millis() as u64,
            "connection acquired from pool"
        );
        let _ = was_reused;
        // The slot now travels with the connection; `return_connection` adds it back.
        permit.forget();
        Ok(PooledConnection::with_created_at(conn, pool_arc, created_at))
    }

    /// Takes back a connection checked out by `get`, then releases its slot.
    ///
    /// The connection is closed instead of pooled in two cases:
    /// - the pool is shutting down;
    /// - `test_on_return` is enabled and the connection fails `is_valid`.
    ///
    /// Must be called exactly once per checked-out connection, because it re-adds the slot
    /// that `get` consumed.
    async fn return_connection(&self, conn: Box<dyn Connection>, created_at: Instant) {
        if self.shutdown.load(Ordering::Acquire) {
            self.discard_connection(conn).await;
        } else if self.config.test_on_return && !conn.is_valid().await {
            self.stats.record_health_check_failure();
            self.discard_connection(conn).await;
        } else {
            let mut idle = self.idle.lock().await;
            // Re-check under the lock. `close()` sets the flag before draining, so this never
            // parks a connection in a list that has already been drained.
            if self.shutdown.load(Ordering::Acquire) {
                drop(idle);
                self.discard_connection(conn).await;
            } else {
                idle.push(PoolEntry {
                    conn,
                    created_at,
                    last_used: Instant::now(),
                });
            }
        }
        // Release the slot only now. Releasing it first would let a waiter find `idle` empty
        // and open an extra connection beyond `max_size` while this one is still coming back.
        self.semaphore.add_permits(1);
    }

    /// Connections opened and not yet discarded.
    fn size(&self) -> usize {
        self.total_connections.load(Ordering::Acquire)
    }

    /// Number of connections parked in the idle list.
    ///
    /// Reads the list directly when its lock is free. If another task holds the lock, it
    /// estimates instead: total connections minus slots currently checked out.
    fn idle(&self) -> usize {
        match self.idle.try_lock() {
            Ok(idle) => idle.len(),
            Err(_) => {
                let checked_out = self
                    .config
                    .max_size
                    .saturating_sub(self.semaphore.available_permits());
                self.total_connections
                    .load(Ordering::Acquire)
                    .saturating_sub(checked_out)
            }
        }
    }

    /// Snapshot of the pool's usage counters.
    fn stats(&self) -> PoolStats {
        self.stats.snapshot()
    }

    /// Shuts the pool down and closes every idle connection.
    ///
    /// After this call:
    /// - new `get` calls are rejected;
    /// - tasks waiting for a slot are woken with an error;
    /// - connections still checked out are closed when they are returned.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; errors from closing individual connections are ignored.
    async fn close(&self) -> Result<()> {
        self.shutdown.store(true, Ordering::Release);
        // Waiters blocked in `get()` fail immediately instead of waiting out `acquire_timeout`.
        self.semaphore.close();
        // Take the entries out first so the idle lock is not held while each one is closed.
        let drained = std::mem::take(&mut *self.idle.lock().await);
        for entry in drained {
            self.discard_connection(entry.conn).await;
        }
        Ok(())
    }
}

impl SimpleConnectionPool {
    /// Pops idle entries, last pushed first, until one is fit for reuse.
    ///
    /// Expired or unhealthy entries are closed and discarded. Returns `None` once the idle
    /// list is empty.
    async fn take_reusable_idle(&self) -> Option<(Box<dyn Connection>, Instant)> {
        loop {
            // The guard is a temporary, so the lock is released before any health check or
            // close below. Other borrowers are not serialized behind these round-trips.
            let entry = self.idle.lock().await.pop()?;
            if let Some(reason) = self.should_recycle(&entry) {
                match reason {
                    RecycleReason::LifetimeExpired => self.stats.record_lifetime_expired(),
                    RecycleReason::IdleExpired => self.stats.record_idle_expired(),
                }
                self.discard_connection(entry.conn).await;
                continue;
            }
            if !self.validate_connection(&*entry.conn).await {
                self.stats.record_health_check_failure();
                self.discard_connection(entry.conn).await;
                continue;
            }
            return Some((entry.conn, entry.created_at));
        }
    }

    /// Closes `conn` (ignoring close errors) and removes it from the pool's accounting.
    async fn discard_connection(&self, conn: Box<dyn Connection>) {
        let _ = conn.close().await;
        self.total_connections.fetch_sub(1, Ordering::Release);
        self.stats.record_closed();
    }
}
/// Creates a [`SimpleConnectionPool`] for `url` using the default [`PoolConfig`].
///
/// # Errors
///
/// Returns whatever [`SimpleConnectionPool::new`] returns.
pub async fn create_pool(
    url: impl Into<String>,
    factory: Arc<dyn ConnectionFactory>,
) -> Result<Arc<SimpleConnectionPool>> {
    SimpleConnectionPool::new(PoolConfig::new(url), factory).await
}
/// Creates a [`SimpleConnectionPool`] from an explicit configuration.
///
/// # Errors
///
/// Returns whatever [`SimpleConnectionPool::new`] returns.
pub async fn create_pool_with_config(
    config: PoolConfig,
    factory: Arc<dyn ConnectionFactory>,
) -> Result<Arc<SimpleConnectionPool>> {
    let pool = SimpleConnectionPool::new(config, factory).await?;
    Ok(pool)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pool stub that reports fixed counts, used to exercise `ConnectionPool`'s provided methods.
    struct FixedCountPool {
        size: usize,
        idle: usize,
    }

    #[async_trait::async_trait]
    impl ConnectionPool for FixedCountPool {
        async fn get(&self) -> Result<PooledConnection> {
            Err(Error::PoolExhausted {
                message: "test double does not hand out connections".to_string(),
            })
        }

        async fn return_connection(&self, _conn: Box<dyn Connection>, _created_at: Instant) {}

        fn size(&self) -> usize {
            self.size
        }

        fn idle(&self) -> usize {
            self.idle
        }

        fn stats(&self) -> PoolStats {
            PoolStats::default()
        }

        async fn close(&self) -> Result<()> {
            Ok(())
        }
    }

    #[test]
    fn test_in_use_defaults_to_size_minus_idle() {
        assert_eq!(FixedCountPool { size: 5, idle: 2 }.in_use(), 3);
        assert_eq!(FixedCountPool { size: 3, idle: 3 }.in_use(), 0);
        // More idle than total clamps at zero instead of underflowing.
        assert_eq!(FixedCountPool { size: 2, idle: 5 }.in_use(), 0);
    }

    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::new("postgres://localhost/test")
            .with_min_size(5)
            .with_max_size(20)
            .with_acquire_timeout(Duration::from_secs(10))
            .with_test_on_borrow(true);
        assert_eq!(config.min_size, 5);
        assert_eq!(config.max_size, 20);
        assert_eq!(config.acquire_timeout, Duration::from_secs(10));
        assert!(config.test_on_borrow);
    }

    #[test]
    fn test_pool_builder() {
        let config = PoolBuilder::new("mysql://localhost/test")
            .min_size(2)
            .max_size(15)
            .acquire_timeout(Duration::from_secs(5))
            .config();
        assert_eq!(config.min_size, 2);
        assert_eq!(config.max_size, 15);
        assert_eq!(config.acquire_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_atomic_pool_stats() {
        let stats = AtomicPoolStats::new();
        stats.record_created();
        stats.record_created();
        stats.record_acquisition(100);
        stats.record_acquisition(200);
        stats.record_closed();
        stats.record_exhausted();
        stats.record_health_check_failure();
        stats.record_lifetime_expired();
        stats.record_lifetime_expired();
        stats.record_idle_expired();
        stats.record_reused();
        stats.record_reused();
        stats.record_reused();
        stats.record_fresh();
        let snapshot = stats.snapshot();
        assert_eq!(snapshot.connections_created, 2);
        assert_eq!(snapshot.connections_closed, 1);
        assert_eq!(snapshot.acquisitions, 2);
        assert_eq!(snapshot.total_wait_time_ms, 300);
        assert_eq!(snapshot.exhausted_count, 1);
        assert_eq!(snapshot.health_check_failures, 1);
        assert_eq!(snapshot.lifetime_expired_count, 2);
        assert_eq!(snapshot.idle_expired_count, 1);
        assert_eq!(snapshot.reused_count, 3);
        assert_eq!(snapshot.fresh_count, 1);
        assert!((stats.avg_wait_time_ms() - 150.0).abs() < 0.01);
    }

    #[test]
    fn test_pool_stats_helper_methods() {
        let stats = PoolStats {
            connections_created: 100,
            connections_closed: 30,
            acquisitions: 1000,
            exhausted_count: 5,
            total_wait_time_ms: 5000,
            health_check_failures: 10,
            lifetime_expired_count: 15,
            idle_expired_count: 10,
            reused_count: 800,
            fresh_count: 200,
        };
        assert_eq!(stats.recycled_total(), 25);
        assert!((stats.reuse_rate() - 0.8).abs() < 0.001);
        assert!((stats.avg_wait_time_ms() - 5.0).abs() < 0.001);
        // 10 failures out of 1000 + 10 checks; a tolerance of 0.001 would be ~10% of the value.
        assert!((stats.health_failure_rate() - 10.0 / 1010.0).abs() < 1e-12);
        assert_eq!(stats.active_connections(), 70);
    }

    #[test]
    fn test_pool_stats_zero_acquisitions() {
        let stats = PoolStats::default();
        assert!((stats.reuse_rate() - 0.0).abs() < 0.001);
        assert!((stats.avg_wait_time_ms() - 0.0).abs() < 0.001);
        assert!((stats.health_failure_rate() - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_pool_stats_active_connections_saturates() {
        let stats = PoolStats {
            connections_created: 2,
            connections_closed: 5,
            ..PoolStats::default()
        };
        assert_eq!(stats.active_connections(), 0);
    }

    #[test]
    fn test_recycle_reason_enum() {
        let lifetime_reason = RecycleReason::LifetimeExpired;
        let idle_reason = RecycleReason::IdleExpired;
        assert_eq!(lifetime_reason, RecycleReason::LifetimeExpired);
        assert_eq!(idle_reason, RecycleReason::IdleExpired);
        assert_ne!(lifetime_reason, idle_reason);
        assert!(format!("{:?}", lifetime_reason).contains("LifetimeExpired"));
        assert!(format!("{:?}", idle_reason).contains("IdleExpired"));
    }

    #[test]
    fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.connections_created, 0);
        assert_eq!(stats.connections_closed, 0);
        assert_eq!(stats.acquisitions, 0);
        assert_eq!(stats.exhausted_count, 0);
        assert_eq!(stats.total_wait_time_ms, 0);
        assert_eq!(stats.health_check_failures, 0);
        assert_eq!(stats.lifetime_expired_count, 0);
        assert_eq!(stats.idle_expired_count, 0);
        assert_eq!(stats.reused_count, 0);
        assert_eq!(stats.fresh_count, 0);
    }

    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::default();
        assert_eq!(config.min_size, 1);
        assert_eq!(config.max_size, 10);
        assert_eq!(config.acquire_timeout, Duration::from_secs(30));
        assert_eq!(config.max_lifetime, Duration::from_secs(1800));
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.health_check_interval, Duration::from_secs(30));
        assert!(config.test_on_borrow);
        assert!(!config.test_on_return);
    }

    #[test]
    fn test_pooled_connection_lifecycle_methods() {
        // NOTE(review): building a `PooledConnection` needs a `Connection` test double, so this
        // only checks the `Instant` arithmetic that `age`/`is_expired` rely on.
        let max_lifetime = Duration::from_secs(1800);
        let short_lifetime = Duration::from_millis(10);
        let now = Instant::now();
        assert!(now.elapsed() < max_lifetime);
        // Back-date instead of sleeping; skip if the clock cannot represent an earlier instant.
        if let Some(earlier) = now.checked_sub(Duration::from_millis(15)) {
            assert!(earlier.elapsed() > short_lifetime);
        }
    }

    #[test]
    fn test_pooled_connection_borrowed_at() {
        // NOTE(review): this exercises `Instant` only, not `PooledConnection::time_in_use`,
        // because no `Connection` test double is available here.
        if let Some(borrowed_at) = Instant::now().checked_sub(Duration::from_millis(5)) {
            let held = borrowed_at.elapsed();
            assert!(held >= Duration::from_millis(5));
            assert!(held < Duration::from_secs(1));
        }
    }
}