use crate::error::{CacheError, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Minimal async interface a pooled connection must provide.
///
/// Implementors must be shareable across tasks (`Send + Sync`) and printable
/// for diagnostics (`Debug`).
#[async_trait]
pub trait DatabaseOperations: Send + Sync + Debug {
    /// Reports whether the underlying connection is currently usable.
    async fn is_connected(&self) -> bool;

    /// Runs a statement that yields rows.
    ///
    /// NOTE(review): presumably one map per row, keyed by column name — confirm with implementors.
    async fn query(&self, sql: &str) -> Result<Vec<HashMap<String, String>>>;

    /// Runs a statement that yields no rows.
    ///
    /// NOTE(review): the returned count is presumably the number of affected rows — confirm.
    async fn execute(&self, sql: &str) -> Result<u64>;
}
/// Sizing and timeout settings for a connection pool.
///
/// Field order is kept stable because it defines the serialized layout.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct PoolConfig {
    /// Maximum number of connections that may be checked out at the same time.
    pub max_size: u32,
    /// Number of connections opened eagerly when the pool is constructed.
    pub min_idle: u32,
    /// Connection timeout. NOTE(review): units presumably seconds, and not read by the pool itself — confirm.
    pub connection_timeout: u64,
    /// Idle timeout. NOTE(review): units presumably seconds, and not read by the pool itself — confirm.
    pub idle_timeout: u64,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_size: 10,
min_idle: 1,
connection_timeout: 30,
idle_timeout: 600,
}
}
}
/// Point-in-time counters describing a pool's usage.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Connections currently handed out to callers.
    pub active_connections: u32,
    /// Connections sitting in the pool, ready for reuse.
    pub idle_connections: u32,
    /// Callers waiting for a connection.
    pub waiting_requests: u32,
    /// Sum of active and idle connections.
    pub total_connections: u32,
}
impl PoolStats {
pub fn new() -> Self {
Self::default()
}
}
/// A bounded pool of shared database connections.
///
/// Idle connections are kept in `pool`. The number of connections currently
/// handed out is tracked in `stats.active_connections`. New connections are
/// produced on demand by `creator` while fewer than `config.max_size` are
/// checked out.
pub struct ConnectionPool<T: DatabaseOperations> {
    /// Idle connections available for reuse (pushed and popped from the end).
    pool: Arc<Mutex<Vec<Arc<T>>>>,
    /// Pool limits; `max_size` caps how many connections may be checked out.
    config: PoolConfig,
    /// Factory used to open a new connection.
    creator: Arc<Box<dyn Fn() -> Result<Arc<T>> + Send + Sync>>,
    // Initialised but never read: checked-out counts live in `stats` instead.
    // Kept so the struct layout stays unchanged for now.
    #[allow(dead_code)]
    active_count: Arc<Mutex<u32>>,
    /// Usage counters reported by `get_stats`.
    stats: Arc<Mutex<PoolStats>>,
}
impl<T: DatabaseOperations> ConnectionPool<T> {
    /// Builds a pool and eagerly opens `config.min_idle` connections via `creator`.
    ///
    /// `creator` is retained and called again whenever a checkout finds no idle
    /// connection while fewer than `config.max_size` connections are handed out.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `creator` while opening the initial
    /// idle connections.
    pub async fn new<F>(config: PoolConfig, creator: F) -> Result<Self>
    where
        F: Fn() -> Result<Arc<T>> + Send + Sync + 'static,
    {
        // Stops at the first failing creation, like the explicit loop with `?`.
        let connections = (0..config.min_idle)
            .map(|_| creator())
            .collect::<Result<Vec<_>>>()?;
        Ok(Self {
            pool: Arc::new(Mutex::new(connections)),
            config,
            creator: Arc::new(Box::new(creator)),
            active_count: Arc::new(Mutex::new(0)),
            stats: Arc::new(Mutex::new(PoolStats::new())),
        })
    }

    /// Checks out a connection, reusing an idle one when available.
    ///
    /// If no idle connection exists and fewer than `config.max_size`
    /// connections are checked out, a new one is opened with the pool's
    /// creator. Hand it back with `return_connection` so it can be reused.
    ///
    /// # Errors
    ///
    /// Returns `CacheError::DatabaseError("Connection pool exhausted")` when
    /// `max_size` connections are already checked out. Returns the creator's
    /// error if opening a new connection fails; in that case the reserved slot
    /// is released again.
    pub async fn get_connection(&self) -> Result<Arc<T>> {
        {
            // Lock order is always `pool` -> `stats` in every method, so
            // concurrent callers cannot deadlock on each other.
            let mut pool = self.pool.lock().await;
            let mut stats = self.stats.lock().await;
            if let Some(conn) = pool.pop() {
                stats.active_connections = stats.active_connections.saturating_add(1);
                return Ok(conn);
            }
            if stats.active_connections >= self.config.max_size {
                return Err(CacheError::DatabaseError(
                    "Connection pool exhausted".to_string(),
                ));
            }
            // Reserve the slot while both locks are held. The "pool is empty"
            // observation and the capacity check then cannot interleave with
            // a concurrent `return_connection`.
            stats.active_connections = stats.active_connections.saturating_add(1);
        }
        // Both locks are released here. Opening a connection may be slow and
        // should not block other callers from returning or reusing connections.
        match (self.creator)() {
            Ok(conn) => Ok(conn),
            Err(err) => {
                // Give the reserved slot back. Otherwise every failed creation
                // would permanently shrink the pool's usable capacity.
                let mut stats = self.stats.lock().await;
                stats.active_connections = stats.active_connections.saturating_sub(1);
                Err(err)
            }
        }
    }

    /// Returns a checked-out connection to the idle set.
    ///
    /// The connection is pushed back as-is. It is not health-checked (e.g. via
    /// `is_connected`), and it is not verified to have come from this pool.
    pub async fn return_connection(&self, conn: Arc<T>) {
        // Same `pool` -> `stats` lock order as `get_connection`.
        let mut pool = self.pool.lock().await;
        let mut stats = self.stats.lock().await;
        stats.active_connections = stats.active_connections.saturating_sub(1);
        pool.push(conn);
    }

    /// Returns a consistent snapshot of the pool's counters.
    ///
    /// `idle_connections` is the current number of pooled connections.
    /// `waiting_requests` is always 0, because `get_connection` fails instead
    /// of waiting when the pool is exhausted.
    pub async fn get_stats(&self) -> PoolStats {
        // Acquire in the same `pool` -> `stats` order as the other methods.
        // Taking `stats` first could deadlock against `get_connection`.
        let pool = self.pool.lock().await;
        let stats = self.stats.lock().await;
        // Clamp rather than silently truncate on (theoretical) huge pools.
        let idle = pool.len().min(u32::MAX as usize) as u32;
        PoolStats {
            active_connections: stats.active_connections,
            idle_connections: idle,
            waiting_requests: 0,
            total_connections: stats.active_connections.saturating_add(idle),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Fake connection that increments a shared counter when created and
    /// decrements it when dropped, so tests can observe how many are alive.
    #[derive(Debug)]
    struct MockConnection {
        // Never read; it only gives each mock a visible identity in `Debug` output.
        #[allow(dead_code)]
        id: usize,
        connection_count: &'static AtomicUsize,
    }

    impl MockConnection {
        fn new(id: usize, connection_count: &'static AtomicUsize) -> Self {
            connection_count.fetch_add(1, Ordering::SeqCst);
            MockConnection {
                id,
                connection_count,
            }
        }
    }

    #[async_trait::async_trait]
    impl DatabaseOperations for MockConnection {
        async fn is_connected(&self) -> bool {
            true
        }

        async fn query(&self, _sql: &str) -> Result<Vec<HashMap<String, String>>> {
            Ok(Vec::new())
        }

        async fn execute(&self, _sql: &str) -> Result<u64> {
            Ok(0)
        }
    }

    impl Drop for MockConnection {
        fn drop(&mut self) {
            self.connection_count.fetch_sub(1, Ordering::SeqCst);
        }
    }

    /// Number of `MockConnection`s currently alive for `counter`.
    fn live(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Builds a pool (timeouts 30 / 600) whose creator mints mocks tracked by `counter`.
    async fn mock_pool(
        max_size: u32,
        min_idle: u32,
        counter: &'static AtomicUsize,
    ) -> ConnectionPool<MockConnection> {
        let config = PoolConfig {
            max_size,
            min_idle,
            connection_timeout: 30,
            idle_timeout: 600,
        };
        ConnectionPool::<MockConnection>::new(config, move || {
            let next_id = live(counter) + 1;
            Ok(Arc::new(MockConnection::new(next_id, counter)))
        })
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn test_connection_pool_creates_connection_when_empty() {
        static CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0);
        let pool = mock_pool(5, 0, &CONNECTION_COUNT).await;

        let conn = pool.get_connection().await.unwrap();
        let stats = pool.get_stats().await;
        assert_eq!(live(&CONNECTION_COUNT), 1);
        assert_eq!(stats.active_connections, 1);

        pool.return_connection(conn).await;
        let stats = pool.get_stats().await;
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.idle_connections, 1);
    }

    #[tokio::test]
    async fn test_connection_pool_exhaustion() {
        static CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0);
        let pool = mock_pool(2, 0, &CONNECTION_COUNT).await;

        let first = pool.get_connection().await.unwrap();
        assert_eq!(live(&CONNECTION_COUNT), 1);
        let second = pool.get_connection().await.unwrap();
        assert_eq!(live(&CONNECTION_COUNT), 2);

        // A third checkout exceeds `max_size` and must be refused.
        let err = pool
            .get_connection()
            .await
            .expect_err("a third checkout must exceed max_size");
        assert!(err.to_string().contains("Connection pool exhausted"));

        drop(first);
        drop(second);
    }

    #[tokio::test]
    async fn test_connection_pool_return_and_reacquire() {
        static CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0);
        let pool = mock_pool(3, 0, &CONNECTION_COUNT).await;

        let conn = pool.get_connection().await.unwrap();
        assert_eq!(live(&CONNECTION_COUNT), 1);
        pool.return_connection(conn).await;

        // The idle connection must be reused rather than a new one being opened.
        let reused = pool.get_connection().await.unwrap();
        assert_eq!(live(&CONNECTION_COUNT), 1);
        pool.return_connection(reused).await;
    }

    #[tokio::test]
    async fn test_connection_pool_zero_min_idle() {
        static CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0);
        let pool = mock_pool(5, 0, &CONNECTION_COUNT).await;

        let stats = pool.get_stats().await;
        assert_eq!(stats.active_connections, 0);
        assert_eq!(stats.idle_connections, 0);

        let conn = pool.get_connection().await.unwrap();
        let stats = pool.get_stats().await;
        assert_eq!(stats.active_connections, 1);
        assert_eq!(stats.idle_connections, 0);
        pool.return_connection(conn).await;
    }

    #[tokio::test]
    async fn test_connection_pool_stats_accuracy() {
        static CONNECTION_COUNT: AtomicUsize = AtomicUsize::new(0);
        let pool = mock_pool(10, 2, &CONNECTION_COUNT).await;

        let stats = pool.get_stats().await;
        assert_eq!(stats.idle_connections, 2);
        assert_eq!(stats.active_connections, 0);

        let conn = pool.get_connection().await.unwrap();
        let stats = pool.get_stats().await;
        assert_eq!(stats.idle_connections, 1);
        assert_eq!(stats.active_connections, 1);

        pool.return_connection(conn).await;
        let stats = pool.get_stats().await;
        assert_eq!(stats.idle_connections, 2);
        assert_eq!(stats.active_connections, 0);
    }
}