use std::sync::Arc;
use tokio::sync::Semaphore;
/// Limits how many callers may use the database at `db_path` at the same time.
///
/// The code shown here only hands out counted permits carrying the path; it
/// does not open any database connection itself. Clones share one permit
/// budget, because the semaphore is held behind an `Arc`.
#[derive(Clone, Debug)]
pub struct ConnectionPool {
    /// Database path copied into every permit this pool issues.
    pub db_path: std::path::PathBuf,
    /// Counting semaphore whose available permits are the free slots.
    semaphore: Arc<Semaphore>,
    /// Capacity the pool was created with.
    /// NOTE(review): public and mutable; changing it later does not resize `semaphore`.
    pub max_connections: usize,
}
impl ConnectionPool {
pub fn new(db_path: impl AsRef<std::path::Path>, max_connections: usize) -> Self {
Self {
db_path: db_path.as_ref().to_path_buf(),
semaphore: Arc::new(Semaphore::new(max_connections)),
max_connections,
}
}
pub async fn acquire(&self) -> anyhow::Result<ConnectionPermit> {
let permit = self.semaphore.clone().acquire_owned().await?;
Ok(ConnectionPermit {
_permit: permit,
db_path: self.db_path.clone(),
})
}
pub fn available_connections(&self) -> usize {
self.semaphore.available_permits()
}
pub async fn try_acquire(&self) -> Option<ConnectionPermit> {
self.semaphore
.clone()
.try_acquire_owned()
.ok()
.map(|permit| ConnectionPermit {
_permit: permit,
db_path: self.db_path.clone(),
})
}
}
/// One occupied slot of a [`ConnectionPool`].
///
/// The slot is returned to the pool when this value is dropped, so a permit
/// that is discarded immediately releases its slot at once.
#[must_use = "the pool slot is released as soon as the permit is dropped"]
pub struct ConnectionPermit {
    /// Held only for its `Drop`, which gives the slot back to the semaphore.
    _permit: tokio::sync::OwnedSemaphorePermit,
    /// Database path copied from the pool that issued this permit.
    db_path: std::path::PathBuf,
}
impl std::fmt::Debug for ConnectionPermit {
    /// Prints only the database path; the semaphore permit is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { db_path, .. } = self;
        let mut builder = f.debug_struct("ConnectionPermit");
        builder.field("db_path", db_path);
        builder.finish()
    }
}
impl ConnectionPermit {
    /// Returns the path of the database this permit was issued for.
    pub fn db_path(&self) -> &std::path::Path {
        self.db_path.as_path()
    }
}
impl std::fmt::Display for ConnectionPermit {
    /// Renders as `ConnectionPermit(<path>)`, with the path shown via `Path::display`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let shown = self.db_path.display();
        f.write_fmt(format_args!("ConnectionPermit({})", shown))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);
        assert_eq!(pool.max_connections, 5);
        assert_eq!(pool.available_connections(), 5);
    }

    #[tokio::test]
    async fn test_pool_acquire() {
        let pool = ConnectionPool::new("/tmp/test.db", 2);
        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 1);
        let permit2 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);
        drop(permit1);
        assert_eq!(pool.available_connections(), 1);
        drop(permit2);
        assert_eq!(pool.available_connections(), 2);
    }

    #[tokio::test]
    async fn test_pool_try_acquire() {
        let pool = ConnectionPool::new("/tmp/test.db", 1);
        let permit1 = pool.try_acquire().await;
        assert!(permit1.is_some());
        assert_eq!(pool.available_connections(), 0);
        let permit2 = pool.try_acquire().await;
        assert!(permit2.is_none());
        drop(permit1);
        assert_eq!(pool.available_connections(), 1);
    }

    #[tokio::test]
    async fn test_pool_db_path() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);
        assert_eq!(pool.db_path, std::path::PathBuf::from("/tmp/test.db"));
        let permit = pool.acquire().await.unwrap();
        assert_eq!(permit.db_path(), std::path::Path::new("/tmp/test.db"));
    }

    #[tokio::test]
    async fn test_pool_concurrent_acquires() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use tokio::sync::Barrier;

        const LIMIT: usize = 5;
        const TASKS: usize = 10;

        let pool = Arc::new(ConnectionPool::new("/tmp/test.db", LIMIT));
        let barrier = Arc::new(Barrier::new(TASKS));
        // `in_flight` counts tasks currently holding a permit and `peak` records
        // its maximum, so the test checks that the limit is enforced rather than
        // only that every task eventually finishes.
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..TASKS)
            .map(|_| {
                let pool = Arc::clone(&pool);
                let barrier = Arc::clone(&barrier);
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                tokio::spawn(async move {
                    barrier.wait().await;
                    let _permit = pool.acquire().await.unwrap();
                    let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                    peak.fetch_max(now, Ordering::SeqCst);
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    // Decrement while the permit is still held (before `_permit` drops).
                    in_flight.fetch_sub(1, Ordering::SeqCst);
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let peak = peak.load(Ordering::SeqCst);
        assert!(
            (1..=LIMIT).contains(&peak),
            "peak concurrency {} outside 1..={}",
            peak,
            LIMIT
        );
        assert_eq!(pool.available_connections(), LIMIT);
    }

    #[tokio::test]
    async fn test_pool_timeout_behavior() {
        use tokio::time::{timeout, Duration};
        let pool = ConnectionPool::new("/tmp/test.db", 1);
        let _permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);
        let start = std::time::Instant::now();
        let result = timeout(Duration::from_millis(100), pool.acquire()).await;
        let elapsed = start.elapsed();
        assert!(result.is_err());
        // Only a lower bound. A wall-clock upper bound is flaky on a loaded
        // machine, and a timeout that never fired would hang rather than fail here.
        assert!(elapsed >= Duration::from_millis(90));
        // The waiter that timed out must not have taken the slot.
        assert_eq!(pool.available_connections(), 0);
    }

    #[tokio::test]
    async fn test_pool_permit_drop_returns() {
        let pool = ConnectionPool::new("/tmp/test.db", 3);
        assert_eq!(pool.available_connections(), 3);
        let permit = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 2);
        drop(permit);
        assert_eq!(pool.available_connections(), 3);
    }

    #[tokio::test]
    async fn test_pool_stress() {
        let pool = ConnectionPool::new("/tmp/test.db", 10);
        for _ in 0..100 {
            let permit = pool.acquire().await.unwrap();
            assert_eq!(permit.db_path(), std::path::Path::new("/tmp/test.db"));
            drop(permit);
        }
        assert_eq!(pool.available_connections(), 10);
    }

    #[tokio::test]
    async fn test_pool_all_permits_acquired() {
        let pool = ConnectionPool::new("/tmp/test.db", 3);
        let permit1 = pool.acquire().await.unwrap();
        let permit2 = pool.acquire().await.unwrap();
        let permit3 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 0);
        let permit4 = pool.try_acquire().await;
        assert!(permit4.is_none());
        drop(permit1);
        let permit5 = pool.try_acquire().await;
        assert!(permit5.is_some());
        assert_eq!(pool.available_connections(), 0);
        drop(permit2);
        drop(permit3);
        drop(permit5);
        // Every slot must be back once all permits are gone.
        assert_eq!(pool.available_connections(), 3);
    }

    #[tokio::test]
    async fn test_pool_available_count() {
        let pool = ConnectionPool::new("/tmp/test.db", 5);
        assert_eq!(pool.available_connections(), 5);
        let permit1 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 4);
        let permit2 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 3);
        let permit3 = pool.acquire().await.unwrap();
        assert_eq!(pool.available_connections(), 2);
        drop(permit1);
        assert_eq!(pool.available_connections(), 3);
        drop(permit2);
        assert_eq!(pool.available_connections(), 4);
        drop(permit3);
        assert_eq!(pool.available_connections(), 5);
    }
}