use do_memory_core::{Error, Result};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{debug, info, warn};
pub mod config;
pub mod connection;
mod monitoring;
#[cfg(test)]
mod tests;
pub use config::{KeepAliveConfig, KeepAliveStatistics};
pub use connection::KeepAliveConnection;
use crate::pool::{ConnectionPool, PoolStatistics, PooledConnection};
/// Connection-pool wrapper that tracks per-connection usage timestamps and keep-alive statistics.
///
/// Every [`KeepAlivePool::get`] call acquires a connection from the underlying
/// [`ConnectionPool`], assigns it a fresh numeric id, records when it was handed out, and
/// updates the shared [`KeepAliveStatistics`].
pub struct KeepAlivePool {
    /// Underlying pool that connections are acquired from.
    pool: Arc<ConnectionPool>,
    /// Keep-alive tuning (stale threshold, proactive ping flag, ping timeout, interval).
    config: KeepAliveConfig,
    /// Last time each connection id was handed out or refreshed.
    ///
    /// NOTE(review): nothing in this impl removes entries; unless code elsewhere (e.g. the
    /// `monitoring` module) prunes this map, it grows by one entry per `get()` call — confirm.
    last_used: RwLock<HashMap<usize, Instant>>,
    /// Counters shared with every `KeepAliveConnection` handed out by this pool.
    stats: Arc<RwLock<KeepAliveStatistics>>,
    /// Next id to assign in `get()`; this impl only ever increments it.
    next_conn_id: RwLock<usize>,
    /// Handle of the task spawned in `new()`. That task's body is currently empty, so it
    /// completes immediately.
    _cleanup_handle: tokio::task::JoinHandle<()>,
}

impl KeepAlivePool {
    /// Creates a keep-alive wrapper around `pool`.
    ///
    /// `config` falls back to `KeepAliveConfig::default()` when `None`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Storage`] if a connection cannot be acquired from `pool`. One
    /// connection is checked out and immediately dropped to validate the pool.
    pub async fn new(pool: Arc<ConnectionPool>, config: Option<KeepAliveConfig>) -> Result<Self> {
        let config = config.unwrap_or_default();
        info!(
            "Creating keep-alive pool with interval={:?}, stale_threshold={:?}",
            config.keep_alive_interval, config.stale_threshold
        );

        // Validate before constructing `Self` so a failed validation does not leave a
        // detached background task behind.
        let _ = pool
            .get()
            .await
            .map_err(|e| Error::Storage(format!("Failed to validate connection pool: {}", e)))?;

        let pool_instance = Self {
            pool,
            config,
            last_used: RwLock::new(HashMap::new()),
            stats: Arc::new(RwLock::new(KeepAliveStatistics::default())),
            next_conn_id: RwLock::new(0),
            // NOTE(review): the spawned future is empty and finishes immediately; no
            // periodic cleanup currently runs here.
            _cleanup_handle: tokio::spawn(async move {}),
        };

        info!("Keep-alive pool created successfully");
        Ok(pool_instance)
    }

    /// Same as [`KeepAlivePool::new`] with an explicit, non-optional configuration.
    ///
    /// # Errors
    ///
    /// See [`KeepAlivePool::new`].
    pub async fn with_config(pool: Arc<ConnectionPool>, config: KeepAliveConfig) -> Result<Self> {
        Self::new(pool, Some(config)).await
    }

    /// Acquires a connection from the underlying pool and wraps it in a [`KeepAliveConnection`].
    ///
    /// Each call assigns a new id, records the acquisition time in `last_used`, and bumps the
    /// `total_connections_created` / `active_connections` counters.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying pool's `get()`, or from refreshing a stale
    /// connection (`refresh_connection` currently never fails).
    pub async fn get(&self) -> Result<KeepAliveConnection> {
        let start = Instant::now();
        let pooled = self.pool.get().await?;

        let conn_id = {
            let mut next_id = self.next_conn_id.write();
            let id = *next_id;
            *next_id += 1;
            id
        };

        let now = Instant::now();
        // NOTE(review): `conn_id` was allocated just above and has never been inserted into
        // `last_used`, so this lookup always misses and `was_stale` is always `false`. Real
        // stale detection needs a stable id for the underlying connection — confirm intent.
        let was_stale = {
            let last_used_map = self.last_used.read();
            match last_used_map.get(&conn_id) {
                Some(last) => now.duration_since(*last) > self.config.stale_threshold,
                None => false,
            }
        };

        self.last_used.write().insert(conn_id, now);

        // Refresh before touching the counters so a failed refresh (the `?` below) does not
        // leave `active_connections` incremented for a connection that was never handed out.
        if was_stale {
            self.refresh_connection(conn_id, &pooled).await?;
        }

        {
            let mut stats = self.stats.write();
            stats.total_connections_created += 1;
            stats.active_connections += 1;
            stats.update_activity();
        }

        let elapsed = start.elapsed();
        debug!(
            "Keep-alive connection acquired (id={}, stale={}, elapsed={:?})",
            conn_id, was_stale, elapsed
        );

        Ok(KeepAliveConnection::new(
            pooled,
            conn_id,
            now,
            Arc::clone(&self.stats),
        ))
    }

    /// Returns `true` if `conn_id` is untracked or was last used more than
    /// `config.stale_threshold` ago.
    pub fn is_stale(&self, conn_id: usize) -> bool {
        let last_used_map = self.last_used.read();
        match last_used_map.get(&conn_id) {
            Some(last) => Instant::now().duration_since(*last) > self.config.stale_threshold,
            None => true,
        }
    }

    /// Records a refresh of `conn_id` and, if `enable_proactive_ping` is set, pings it.
    ///
    /// A failed ping is logged and counted in `total_ping_failures` but does not fail the
    /// call. The connection's `last_used` timestamp is reset either way. Currently always
    /// returns `Ok(())`.
    async fn refresh_connection(&self, conn_id: usize, pooled: &PooledConnection) -> Result<()> {
        debug!("Refreshing stale connection {}", conn_id);
        {
            let mut stats = self.stats.write();
            stats.total_connections_refreshed += 1;
            stats.total_stale_detected += 1;
        }

        if self.config.enable_proactive_ping {
            // Await first, then take the lock: the guard must not be held across `.await`.
            let ping_result = self.ping_connection(pooled).await;
            let mut stats = self.stats.write();
            match ping_result {
                Ok(()) => stats.total_proactive_pings += 1,
                Err(e) => {
                    stats.total_ping_failures += 1;
                    warn!(
                        "Ping failed for connection {}, may need refresh: {}",
                        conn_id, e
                    );
                }
            }
        }

        self.last_used.write().insert(conn_id, Instant::now());
        Ok(())
    }

    /// Sends `SELECT 1` on `pooled`, bounded by `config.ping_timeout`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Storage`] if the wrapped connection is unavailable, the query times
    /// out, or the query itself fails.
    async fn ping_connection(&self, pooled: &PooledConnection) -> Result<()> {
        if let Some(conn) = pooled.connection() {
            tokio::time::timeout(self.config.ping_timeout, conn.query("SELECT 1", ()))
                .await
                .map_err(|_| Error::Storage("Ping timeout".to_string()))?
                .map_err(|e| Error::Storage(format!("Ping failed: {}", e)))?;
            Ok(())
        } else {
            Err(Error::Storage("Connection not available".to_string()))
        }
    }

    /// Returns a snapshot (clone) of the current keep-alive statistics.
    pub fn statistics(&self) -> KeepAliveStatistics {
        self.stats.read().clone()
    }

    /// Returns the underlying pool's statistics.
    pub async fn pool_statistics(&self) -> PoolStatistics {
        self.pool.statistics().await
    }

    /// Returns the keep-alive configuration in use.
    pub fn config(&self) -> &KeepAliveConfig {
        &self.config
    }

    /// Returns the current `active_connections` counter.
    pub fn active_connections(&self) -> usize {
        self.stats.read().active_connections
    }

    /// Returns how many connection ids currently have a `last_used` entry.
    pub fn tracked_connections(&self) -> usize {
        self.last_used.read().len()
    }

    /// Waits for `active_connections` to reach zero, polling every 100 ms for at most 30 s.
    ///
    /// This neither closes connections nor stops the task spawned in `new()`. It only waits,
    /// then logs whether the pool drained in time.
    pub async fn shutdown(&self) {
        const DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
        const POLL_INTERVAL: Duration = Duration::from_millis(100);

        info!("Shutting down keep-alive pool");
        let start = Instant::now();
        // `active_connections()` releases its read guard before we sleep.
        while start.elapsed() < DRAIN_TIMEOUT && self.active_connections() > 0 {
            tokio::time::sleep(POLL_INTERVAL).await;
        }

        let remaining = self.active_connections();
        if remaining > 0 {
            warn!(
                "Keep-alive pool shutdown with {} active connections",
                remaining
            );
        } else {
            info!("Keep-alive pool shutdown complete");
        }
    }
}