use aimdb_client::connection::AimxClient;
use aimdb_client::ClientError;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::debug;
/// Bookkeeping the pool keeps for each socket path passed to `ConnectionPool::get_connection`.
///
/// No client is stored here; the pool only remembers that (and when) a path was requested.
#[derive(Debug)]
struct ConnectionEntry {
    /// Time of the most recent `get_connection` call for this socket path.
    ///
    /// NOTE(review): this field is only written in the visible code, never read —
    /// confirm whether anything elsewhere relies on it (e.g. for idle eviction).
    last_used: std::time::Instant,
}
/// Tracks per-socket connection metadata and caches persistent drain clients.
///
/// Cloning is cheap: both maps live behind `Arc`, so every clone shares the same state.
#[derive(Clone)]
pub struct ConnectionPool {
    /// Per-socket-path metadata written by `get_connection`; keyed by socket path.
    connections: Arc<Mutex<HashMap<String, ConnectionEntry>>>,
    /// Shared, long-lived clients keyed by socket path, created and reused by
    /// `get_drain_client` and removed by `invalidate_drain_client` / `clear`.
    drain_clients: Arc<Mutex<HashMap<String, Arc<tokio::sync::Mutex<AimxClient>>>>>,
}
impl std::fmt::Debug for ConnectionPool {
    /// Formats the pool with both fields shown as opaque placeholders.
    ///
    /// The mutex-guarded maps are never locked or inspected here; only the field
    /// names are emitted, each with the same `"<...>"` placeholder value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const OPAQUE: &str = "<...>";
        let mut builder = f.debug_struct("ConnectionPool");
        builder.field("connections", &OPAQUE);
        builder.field("drain_clients", &OPAQUE);
        builder.finish()
    }
}
impl ConnectionPool {
    /// Creates an empty pool: no connection metadata and no cached drain clients.
    pub fn new() -> Self {
        let connections = Arc::new(Mutex::new(HashMap::new()));
        let drain_clients = Arc::new(Mutex::new(HashMap::new()));
        Self {
            connections,
            drain_clients,
        }
    }

    /// Opens a brand-new [`AimxClient`] for `socket_path` and records when the path was used.
    ///
    /// Every call establishes a fresh connection; the pool stores only per-path metadata
    /// (the last-use timestamp), never the client. The metadata is recorded *before*
    /// connecting, so it is present even if the connect attempt fails.
    ///
    /// # Errors
    ///
    /// Returns whatever error `AimxClient::connect` reports.
    pub async fn get_connection(&self, socket_path: &str) -> Result<AimxClient, ClientError> {
        // NOTE(review): the glyphs at the start of the debug messages in this impl look like
        // mis-decoded emoji; they are kept verbatim because they are runtime text.
        {
            let mut metadata = self.connections.lock().await;
            let timestamp = std::time::Instant::now();
            match metadata.get_mut(socket_path) {
                Some(existing) => {
                    debug!(
                        "โป๏ธ Connection metadata exists for {}, reconnecting",
                        socket_path
                    );
                    existing.last_used = timestamp;
                }
                None => {
                    debug!("๐ First connection to {}", socket_path);
                    metadata.insert(
                        socket_path.to_string(),
                        ConnectionEntry {
                            last_used: timestamp,
                        },
                    );
                }
            }
        }
        // The metadata guard is released at the end of the scope above, so the
        // (potentially slow) connect does not hold up other callers.
        AimxClient::connect(socket_path).await
    }

    /// Forgets the metadata recorded for `socket_path`; a no-op if none exists.
    pub async fn invalidate_connection(&self, socket_path: &str) {
        let was_tracked = self.connections.lock().await.remove(socket_path).is_some();
        if was_tracked {
            debug!("โ Invalidated connection metadata for {}", socket_path);
        }
    }

    /// Returns the shared drain client for `socket_path`, connecting one on first use.
    ///
    /// The map lock is released while connecting. If another task cached a client for the
    /// same path in the meantime, that client is returned and the freshly opened one is
    /// dropped, so at most one client per path ever lives in the map.
    ///
    /// # Errors
    ///
    /// Returns whatever error `AimxClient::connect` reports when no cached client exists.
    pub async fn get_drain_client(
        &self,
        socket_path: &str,
    ) -> Result<Arc<tokio::sync::Mutex<AimxClient>>, ClientError> {
        // Fast path: hand out the cached client if there already is one.
        {
            let clients = self.drain_clients.lock().await;
            if let Some(cached) = clients.get(socket_path) {
                debug!("โป๏ธ Reusing persistent drain client for {}", socket_path);
                return Ok(Arc::clone(cached));
            }
            debug!("๐ Creating persistent drain client for {}", socket_path);
        }

        // Slow path: connect without holding the map lock.
        let fresh = Arc::new(tokio::sync::Mutex::new(
            AimxClient::connect(socket_path).await?,
        ));

        // Re-check under the lock: a concurrent caller may have won the race.
        let mut clients = self.drain_clients.lock().await;
        let slot = clients
            .entry(socket_path.to_string())
            .or_insert_with(|| Arc::clone(&fresh));
        Ok(Arc::clone(slot))
    }

    /// Removes the cached drain client for `socket_path`; a no-op if none is cached.
    ///
    /// Callers already holding a clone of the client keep it alive until they drop it.
    pub async fn invalidate_drain_client(&self, socket_path: &str) {
        let was_cached = self.drain_clients.lock().await.remove(socket_path).is_some();
        if was_cached {
            debug!("โ Invalidated drain client for {}", socket_path);
        }
    }

    /// Drops all connection metadata and all cached drain clients.
    ///
    /// Locks are taken in the same order as everywhere else (`connections`, then
    /// `drain_clients`).
    pub async fn clear(&self) {
        let mut metadata = self.connections.lock().await;
        let mut clients = self.drain_clients.lock().await;
        metadata.clear();
        clients.clear();
        debug!("๐งน Cleared connection pool");
    }

    /// Number of socket paths with recorded connection metadata.
    ///
    /// This counts paths passed to `get_connection` (including failed attempts) that have
    /// not since been invalidated or cleared; cached drain clients are not included.
    pub async fn connection_count(&self) -> usize {
        self.connections.lock().await.len()
    }
}
impl Default for ConnectionPool {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A socket path that is never created; used to exercise no-op invalidation.
    const MISSING_SOCKET: &str = "/tmp/nonexistent.sock";

    /// A freshly constructed pool tracks no connection metadata.
    #[tokio::test]
    async fn test_pool_creation() {
        let fresh = ConnectionPool::new();
        let tracked = fresh.connection_count().await;
        assert_eq!(tracked, 0, "a new pool should track no connections");
    }

    /// Clearing an already-empty pool leaves it empty.
    #[tokio::test]
    async fn test_pool_clear() {
        let subject = ConnectionPool::new();
        subject.clear().await;
        let tracked = subject.connection_count().await;
        assert_eq!(tracked, 0, "clear() should leave no tracked connections");
    }

    /// Invalidating an unknown path is a harmless no-op.
    #[tokio::test]
    async fn test_invalidate_nonexistent_connection() {
        let subject = ConnectionPool::new();
        subject.invalidate_connection(MISSING_SOCKET).await;
        let tracked = subject.connection_count().await;
        assert_eq!(tracked, 0, "invalidating an unknown path should not add entries");
    }
}