1use aimdb_client::connection::AimxClient;
7use aimdb_client::ClientError;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tracing::debug;
12
/// Metadata the pool keeps for each socket path handed to `get_connection`.
#[derive(Debug)]
struct ConnectionEntry {
    /// Time of the most recent `get_connection` call that touched this entry.
    // Written on every request but never read in this module; a derived `Debug`
    // does not count as a read for the `dead_code` lint, so without this allow
    // the field triggers "field is never read" and fails `-D warnings` builds.
    #[allow(dead_code)]
    last_used: std::time::Instant,
}
19
20#[derive(Clone)]
22pub struct ConnectionPool {
23 connections: Arc<Mutex<HashMap<String, ConnectionEntry>>>,
25 drain_clients: Arc<Mutex<HashMap<String, Arc<tokio::sync::Mutex<AimxClient>>>>>,
28}
29
30impl std::fmt::Debug for ConnectionPool {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 f.debug_struct("ConnectionPool")
33 .field("connections", &"<...>")
34 .field("drain_clients", &"<...>")
35 .finish()
36 }
37}
38
39impl ConnectionPool {
40 pub fn new() -> Self {
42 Self {
43 connections: Arc::new(Mutex::new(HashMap::new())),
44 drain_clients: Arc::new(Mutex::new(HashMap::new())),
45 }
46 }
47
48 pub async fn get_connection(&self, socket_path: &str) -> Result<AimxClient, ClientError> {
55 let mut pool = self.connections.lock().await;
56
57 let now = std::time::Instant::now();
59
60 if let Some(entry) = pool.get_mut(socket_path) {
61 debug!(
62 "โป๏ธ Connection metadata exists for {}, reconnecting",
63 socket_path
64 );
65 entry.last_used = now;
66 } else {
67 debug!("๐ First connection to {}", socket_path);
68 pool.insert(socket_path.to_string(), ConnectionEntry { last_used: now });
69 }
70
71 drop(pool); AimxClient::connect(socket_path).await
74 }
75
76 pub async fn invalidate_connection(&self, socket_path: &str) {
78 let mut pool = self.connections.lock().await;
79 if pool.remove(socket_path).is_some() {
80 debug!("โ Invalidated connection metadata for {}", socket_path);
81 }
82 }
83
84 pub async fn get_drain_client(
91 &self,
92 socket_path: &str,
93 ) -> Result<Arc<tokio::sync::Mutex<AimxClient>>, ClientError> {
94 let drain_map = self.drain_clients.lock().await;
95
96 if let Some(client) = drain_map.get(socket_path) {
97 debug!("โป๏ธ Reusing persistent drain client for {}", socket_path);
98 return Ok(Arc::clone(client));
99 }
100
101 debug!("๐ Creating persistent drain client for {}", socket_path);
102
103 drop(drain_map);
105
106 let client = AimxClient::connect(socket_path).await?;
107 let shared = Arc::new(tokio::sync::Mutex::new(client));
108
109 let mut drain_map = self.drain_clients.lock().await;
110 if let Some(existing) = drain_map.get(socket_path) {
112 return Ok(Arc::clone(existing));
113 }
114 drain_map.insert(socket_path.to_string(), Arc::clone(&shared));
115 Ok(shared)
116 }
117
118 pub async fn invalidate_drain_client(&self, socket_path: &str) {
120 let mut drain_map = self.drain_clients.lock().await;
121 if drain_map.remove(socket_path).is_some() {
122 debug!("โ Invalidated drain client for {}", socket_path);
123 }
124 }
125
126 pub async fn clear(&self) {
128 let mut pool = self.connections.lock().await;
129 pool.clear();
130 let mut drain_map = self.drain_clients.lock().await;
131 drain_map.clear();
132 debug!("๐งน Cleared connection pool");
133 }
134
135 pub async fn connection_count(&self) -> usize {
137 let pool = self.connections.lock().await;
138 pool.len()
139 }
140}
141
142impl Default for ConnectionPool {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed pool tracks no endpoints.
    #[tokio::test]
    async fn test_pool_creation() {
        let fresh = ConnectionPool::new();
        let tracked = fresh.connection_count().await;
        assert_eq!(tracked, 0);
    }

    /// Clearing an already-empty pool leaves it empty.
    #[tokio::test]
    async fn test_pool_clear() {
        let empty = ConnectionPool::new();
        empty.clear().await;
        let tracked = empty.connection_count().await;
        assert_eq!(tracked, 0);
    }

    /// Invalidating a path the pool never saw is a harmless no-op.
    #[tokio::test]
    async fn test_invalidate_nonexistent_connection() {
        let subject = ConnectionPool::new();
        subject
            .invalidate_connection("/tmp/nonexistent.sock")
            .await;
        let tracked = subject.connection_count().await;
        assert_eq!(tracked, 0);
    }
}