use super::connection_wrapper::PooledConnection;
use do_memory_core::{Error, Result};
use libsql::Database;
use parking_lot::Mutex;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Semaphore;
use tracing::{debug, info};
/// Tuning knobs for a `CachingPool`.
#[derive(Debug, Clone)]
pub struct CachingPoolConfig {
    /// Number of permits in the pool's checkout semaphore, i.e. the most
    /// connections that can be checked out at the same time.
    pub max_connections: usize,
    /// Number of idle connections opened up front when the pool is created.
    pub min_connections: usize,
    /// How long `CachingPool::get` waits for a free permit before failing.
    pub connection_timeout: Duration,
    /// Idle connections that have been idle for at least this long are evicted
    /// by `CachingPool::cleanup_idle_connections`.
    pub max_idle_time: Duration,
    /// Idle connections at least this old are evicted by
    /// `CachingPool::cleanup_idle_connections`.
    pub max_connection_age: Duration,
    /// When `true`, every newly opened connection is validated before use.
    pub enable_health_check: bool,
}

impl Default for CachingPoolConfig {
    /// 2 to 10 connections, a 5 second checkout timeout, a 5 minute idle limit,
    /// a 1 hour age limit, and health checks enabled.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        let idle_limit = Duration::from_secs(5 * SECS_PER_MINUTE);
        let age_limit = Duration::from_secs(60 * SECS_PER_MINUTE);
        Self {
            min_connections: 2,
            max_connections: 10,
            connection_timeout: Duration::from_secs(5),
            max_idle_time: idle_limit,
            max_connection_age: age_limit,
            enable_health_check: true,
        }
    }
}
/// Snapshot of a `CachingPool`'s counters, as returned by `CachingPool::stats`.
#[derive(Debug, Default, Clone)]
pub struct CachingPoolStats {
    /// Connections successfully opened (and health-checked, when enabled) by the
    /// pool. The one-off validation connection made while constructing the pool
    /// is not counted.
    pub total_created: u64,
    /// Successful calls to `CachingPool::get`.
    pub total_checkouts: u64,
    /// Connections handed back to the idle list.
    pub total_returns: u64,
    /// Checkouts that reused an idle connection.
    pub cache_hits: u64,
    /// Cache-miss counter; combined with `cache_hits` by
    /// `CachingPool::cache_hit_rate`.
    pub cache_misses: u64,
    /// Checked-out connections as of the last stats update (not live).
    pub active_connections: usize,
    /// Idle connections as of the last stats update (not live).
    pub idle_connections: usize,
    /// Connections discarded by the pool (idle cleanup or explicit destruction).
    pub evictions: u64,
}
/// A connection pool that caches idle `libsql` connections for reuse.
///
/// Concurrency is bounded by a semaphore with `max_connections` permits. Each
/// checkout holds one permit until its [`ConnectionGuard`] is dropped. The idle
/// list, the set of checked-out connection ids and the statistics live behind
/// `Arc`s, so a guard can return its connection without borrowing the pool.
pub struct CachingPool {
    db: Arc<Database>,
    config: CachingPoolConfig,
    /// Connections available for reuse; `get` pops from the end.
    idle_connections: Arc<Mutex<Vec<PooledConnection>>>,
    /// Ids of connections that are currently checked out.
    active_connection_ids: Arc<Mutex<std::collections::HashSet<u64>>>,
    /// Bounds the number of simultaneous checkouts to `max_connections`.
    semaphore: Arc<Semaphore>,
    stats: Arc<Mutex<CachingPoolStats>>,
    /// Optional hook called with a connection's id when that connection is evicted.
    cleanup_callback: Mutex<Option<Arc<dyn Fn(u64) + Send + Sync>>>,
}

/// The shared pieces of a [`CachingPool`] that a [`ConnectionGuard`] needs in
/// order to hand its connection back.
///
/// Holding these by `Arc`, rather than by the pool's address, keeps the guard
/// valid even if the pool is moved or dropped while the guard is still alive.
struct PoolHandle {
    idle_connections: Arc<Mutex<Vec<PooledConnection>>>,
    active_connection_ids: Arc<Mutex<std::collections::HashSet<u64>>>,
    stats: Arc<Mutex<CachingPoolStats>>,
}

impl PoolHandle {
    /// Removes `connection` from the checked-out set and pushes it back onto the
    /// idle list, after calling `touch()` on it. `touch()` presumably resets the
    /// idle timer read by `idle_time()`; confirm in `PooledConnection`.
    ///
    /// Each lock is taken and released on its own, so no two pool locks are
    /// ever held at once.
    fn return_connection(&self, mut connection: PooledConnection) {
        let conn_id = connection.id();
        debug!("Returning connection {} to pool", conn_id);
        let active = {
            let mut ids = self.active_connection_ids.lock();
            ids.remove(&conn_id);
            ids.len()
        };
        connection.touch();
        let idle = {
            let mut idle = self.idle_connections.lock();
            idle.push(connection);
            idle.len()
        };
        let mut stats = self.stats.lock();
        stats.total_returns += 1;
        stats.active_connections = active;
        stats.idle_connections = idle;
    }
}

impl CachingPool {
    /// Creates a pool for `db`, checks that the database answers `SELECT 1`, and
    /// pre-opens `config.min_connections` idle connections.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Storage`] if connecting, running the validation query, or
    /// pre-creating a connection fails.
    pub async fn new(db: Arc<Database>, config: CachingPoolConfig) -> Result<Self> {
        info!(
            "Creating caching pool: min={}, max={}",
            config.min_connections, config.max_connections
        );
        // One-off connection used only to verify that the database is reachable.
        let conn = db
            .connect()
            .map_err(|e| Error::Storage(format!("Failed to connect: {}", e)))?;
        conn.query("SELECT 1", ())
            .await
            .map_err(|e| Error::Storage(format!("Database validation failed: {}", e)))?;
        let semaphore = Arc::new(Semaphore::new(config.max_connections));
        let pool = Self {
            db,
            config,
            idle_connections: Arc::new(Mutex::new(Vec::new())),
            active_connection_ids: Arc::new(Mutex::new(std::collections::HashSet::new())),
            semaphore,
            stats: Arc::new(Mutex::new(CachingPoolStats::default())),
            cleanup_callback: Mutex::new(None),
        };
        pool.pre_create_connections().await?;
        info!("Caching pool created successfully");
        Ok(pool)
    }

    /// Registers a hook that is called with a connection's id whenever that
    /// connection is evicted. Replaces any previously registered hook.
    pub fn set_cleanup_callback<F>(&self, callback: F)
    where
        F: Fn(u64) + Send + Sync + 'static,
    {
        *self.cleanup_callback.lock() = Some(Arc::new(callback));
    }

    /// Returns the current eviction hook, cloned out of its lock.
    ///
    /// The hook is then invoked with no pool lock held. `parking_lot` locks are
    /// not reentrant, so a hook that calls back into the pool would otherwise
    /// deadlock.
    fn current_cleanup_callback(&self) -> Option<Arc<dyn Fn(u64) + Send + Sync>> {
        self.cleanup_callback.lock().clone()
    }

    /// Builds the `Arc`-backed handle that a [`ConnectionGuard`] uses to return
    /// its connection.
    fn handle(&self) -> PoolHandle {
        PoolHandle {
            idle_connections: Arc::clone(&self.idle_connections),
            active_connection_ids: Arc::clone(&self.active_connection_ids),
            stats: Arc::clone(&self.stats),
        }
    }

    /// Opens idle connections until at least `min_connections` are available.
    async fn pre_create_connections(&self) -> Result<()> {
        let current_count = self.idle_connections.lock().len();
        let needed = self.config.min_connections.saturating_sub(current_count);
        for _ in 0..needed {
            let conn = self.create_connection().await?;
            self.idle_connections.lock().push(conn);
        }
        // Keep the idle snapshot accurate from the start; read it before locking stats.
        let idle = self.idle_connections.lock().len();
        self.stats.lock().idle_connections = idle;
        debug!("Pre-created {} connections", needed);
        Ok(())
    }

    /// Opens a new connection, health-checks it when enabled, and counts it in
    /// `total_created`.
    ///
    /// Cache hits and misses are recorded by [`Self::get`], not here, so that
    /// pre-warming the pool does not register as cache misses.
    async fn create_connection(&self) -> Result<PooledConnection> {
        let conn = self
            .db
            .connect()
            .map_err(|e| Error::Storage(format!("Failed to create connection: {}", e)))?;
        let pooled_conn = PooledConnection::new(conn);
        if self.config.enable_health_check {
            pooled_conn
                .validate()
                .await
                .map_err(|e| Error::Storage(format!("Connection health check failed: {}", e)))?;
        }
        self.stats.lock().total_created += 1;
        Ok(pooled_conn)
    }

    /// Checks out a connection, reusing an idle one when available.
    ///
    /// Waits up to `connection_timeout` for one of the `max_connections`
    /// permits. The returned guard owns both the permit and the connection, and
    /// hands the connection back to the pool when dropped. It does not borrow
    /// the pool.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Storage`] in three cases: no permit becomes available
    /// within the timeout, the semaphore is closed, or a new connection cannot
    /// be created.
    pub async fn get(&self) -> Result<ConnectionGuard> {
        let permit = tokio::time::timeout(
            self.config.connection_timeout,
            Arc::clone(&self.semaphore).acquire_owned(),
        )
        .await
        .map_err(|_| {
            Error::Storage(format!(
                "Connection pool timeout after {:?}",
                self.config.connection_timeout
            ))
        })?
        .map_err(|e| Error::Storage(format!("Failed to acquire permit: {}", e)))?;

        // Pop in its own statement so the idle lock is released before the
        // `.await` in the miss branch below.
        let reusable = self.idle_connections.lock().pop();
        let connection = match reusable {
            Some(conn) => {
                debug!("Reusing connection {}", conn.id());
                self.stats.lock().cache_hits += 1;
                conn
            }
            None => {
                let conn = self.create_connection().await?;
                self.stats.lock().cache_misses += 1;
                conn
            }
        };

        let conn_id = connection.id();
        let active = {
            let mut ids = self.active_connection_ids.lock();
            ids.insert(conn_id);
            ids.len()
        };
        let idle = self.idle_connections.lock().len();
        {
            let mut stats = self.stats.lock();
            stats.total_checkouts += 1;
            stats.active_connections = active;
            stats.idle_connections = idle;
        }

        Ok(ConnectionGuard {
            pool: self.handle(),
            connection: Some(connection),
            _permit: Some(permit),
        })
    }

    /// Returns `connection` to the idle list; equivalent to dropping a guard
    /// that holds it.
    #[allow(dead_code)] // guards return through `PoolHandle`; kept as a private convenience
    fn return_connection(&self, connection: PooledConnection) {
        self.handle().return_connection(connection);
    }

    /// Discards `connection` instead of returning it to the pool. Counts an
    /// eviction and notifies the eviction hook before the connection is dropped.
    #[allow(dead_code)] // no caller in this file
    fn destroy_connection(&self, connection: PooledConnection) {
        let conn_id = connection.id();
        debug!("Destroying connection {}", conn_id);
        let active = {
            let mut ids = self.active_connection_ids.lock();
            ids.remove(&conn_id);
            ids.len()
        };
        {
            let mut stats = self.stats.lock();
            stats.evictions += 1;
            stats.active_connections = active;
        }
        if let Some(callback) = self.current_cleanup_callback() {
            callback(conn_id);
        }
        // `connection` is dropped here, after the hook has run.
    }

    /// Evicts idle connections that are at least `max_connection_age` old or
    /// have been idle for at least `max_idle_time`. Returns how many were evicted.
    ///
    /// The eviction hook, if any, runs for each evicted connection after the
    /// idle lock has been released and before that connection is dropped.
    pub fn cleanup_idle_connections(&self) -> usize {
        let max_age = self.config.max_connection_age;
        let max_idle = self.config.max_idle_time;

        let (evicted, remaining) = {
            let mut idle = self.idle_connections.lock();
            let (kept, evicted): (Vec<_>, Vec<_>) = std::mem::take(&mut *idle)
                .into_iter()
                .partition(|conn| conn.age() < max_age && conn.idle_time() < max_idle);
            *idle = kept;
            (evicted, idle.len())
        };

        let evicted_count = evicted.len();
        if evicted_count > 0 {
            if let Some(callback) = self.current_cleanup_callback() {
                for conn in &evicted {
                    callback(conn.id());
                }
            }
            info!(
                "Cleaned up {} idle connections (remaining: {})",
                evicted_count, remaining
            );
        }
        {
            let mut stats = self.stats.lock();
            // Lossless: `usize` is at most 64 bits on every supported target.
            stats.evictions += evicted_count as u64;
            stats.idle_connections = remaining;
        }
        evicted_count
    }

    /// Returns a copy of the pool's counters.
    pub fn stats(&self) -> CachingPoolStats {
        self.stats.lock().clone()
    }

    /// Returns `cache_hits / (cache_hits + cache_misses)`, or `0.0` when neither
    /// has been recorded yet.
    pub fn cache_hit_rate(&self) -> f64 {
        let stats = self.stats.lock();
        let total = stats.cache_hits + stats.cache_misses;
        if total == 0 {
            0.0
        } else {
            stats.cache_hits as f64 / total as f64
        }
    }

    /// Number of idle connections currently in the pool.
    pub fn available_connections(&self) -> usize {
        self.idle_connections.lock().len()
    }

    /// Number of connections currently checked out.
    pub fn active_connections(&self) -> usize {
        self.active_connection_ids.lock().len()
    }
}

/// RAII guard for a connection checked out of a [`CachingPool`].
///
/// Owns the semaphore permit and the connection. On drop, the connection is
/// pushed back onto the pool's idle list, and only then is the permit released.
/// The guard shares the pool's state through `Arc`s and does not borrow the
/// pool, so it may safely outlive the `&CachingPool` it came from.
pub struct ConnectionGuard {
    /// Shared pool state used to return the connection on drop.
    pool: PoolHandle,
    connection: Option<PooledConnection>,
    _permit: Option<tokio::sync::OwnedSemaphorePermit>,
}

impl ConnectionGuard {
    /// Returns the id of the held connection.
    ///
    /// # Errors
    ///
    /// Returns `Error::Storage` if the guard holds no connection. The connection
    /// is only taken out while the guard is being dropped.
    pub fn id(&self) -> do_memory_core::Result<u64> {
        self.connection.as_ref().map(|c| c.id()).ok_or_else(|| {
            do_memory_core::Error::Storage(
                "ConnectionGuard::id() called on guard without connection".to_string(),
            )
        })
    }

    /// Borrows the underlying `libsql` connection.
    ///
    /// # Errors
    ///
    /// Returns `Error::Storage` if the guard holds no connection.
    pub fn connection(&self) -> do_memory_core::Result<&libsql::Connection> {
        self.connection
            .as_ref()
            .map(|c| c.connection())
            .ok_or_else(|| {
                do_memory_core::Error::Storage(
                    "ConnectionGuard::connection() called on guard without connection".to_string(),
                )
            })
    }

    /// Borrows the pooled wrapper around the connection, if still held.
    pub fn pooled(&self) -> Option<&PooledConnection> {
        self.connection.as_ref()
    }
}

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        if let (Some(permit), Some(connection)) = (self._permit.take(), self.connection.take()) {
            self.pool.return_connection(connection);
            // Release the permit only after the connection is idle again, so a
            // task woken by this permit reuses it instead of opening a new one.
            drop(permit);
        }
    }
}

// SAFETY: NOTE(review): this predates the `Arc`-based design. Every field is now an
// `Arc<parking_lot::Mutex<_>>`, an `Option<PooledConnection>` or an
// `Option<OwnedSemaphorePermit>`. The impl is therefore sound exactly when
// `PooledConnection: Send`, and in that case the compiler already derives `Send`,
// making this line redundant. Confirm `PooledConnection: Send` and remove it.
unsafe impl Send for ConnectionGuard {}
// Unit tests are loaded from `caching_pool_tests.rs` via `#[path]` and compiled
// only for test builds.
#[cfg(test)]
#[path = "caching_pool_tests.rs"]
mod tests;