Skip to main content

aimdb_mcp/
connection.rs

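//! Connection pool management for AimDB instances.
//!
//! Tracks per-socket connection metadata and keeps persistent drain clients
//! alive, so the server-side drain reader keeps accumulating values across
//! tool calls. Regular connections are still opened fresh on every call. A
//! drain client that fails can be invalidated, and the next request for that
//! socket path then reconnects.
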
6use aimdb_client::connection::AimxClient;
7use aimdb_client::ClientError;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tracing::debug;
12
/// Bookkeeping the pool keeps for each socket path it has handed out.
#[derive(Debug)]
struct ConnectionEntry {
    /// When the pool last handed out a connection for this socket path.
    /// Intended for staleness detection; nothing in this module reads it yet.
    last_used: std::time::Instant,
}

20/// Connection pool for managing AimDB connections
21#[derive(Clone)]
22pub struct ConnectionPool {
23    /// Track which connections we've attempted (for logging/metrics)
24    connections: Arc<Mutex<HashMap<String, ConnectionEntry>>>,
25    /// Persistent drain clients โ€” kept alive so drain readers accumulate values
26    /// Key: socket_path, Value: shared AimxClient
27    drain_clients: Arc<Mutex<HashMap<String, Arc<tokio::sync::Mutex<AimxClient>>>>>,
28}
29
30impl std::fmt::Debug for ConnectionPool {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        f.debug_struct("ConnectionPool")
33            .field("connections", &"<...>")
34            .field("drain_clients", &"<...>")
35            .finish()
36    }
37}
38
39impl ConnectionPool {
40    /// Create a new connection pool
41    pub fn new() -> Self {
42        Self {
43            connections: Arc::new(Mutex::new(HashMap::new())),
44            drain_clients: Arc::new(Mutex::new(HashMap::new())),
45        }
46    }
47
48    /// Get or create a connection to an AimDB instance
49    ///
50    /// Note: Since AimxClient doesn't implement Clone, we create a fresh
51    /// connection each time. The pool tracks connection metadata for
52    /// monitoring and future optimization (e.g., persistent connections
53    /// via Arc<Mutex<AimxClient>> if AimxClient becomes Sync).
54    pub async fn get_connection(&self, socket_path: &str) -> Result<AimxClient, ClientError> {
55        let mut pool = self.connections.lock().await;
56
57        // Update or insert connection metadata
58        let now = std::time::Instant::now();
59
60        if let Some(entry) = pool.get_mut(socket_path) {
61            debug!(
62                "โ™ป๏ธ  Connection metadata exists for {}, reconnecting",
63                socket_path
64            );
65            entry.last_used = now;
66        } else {
67            debug!("๐Ÿ”Œ First connection to {}", socket_path);
68            pool.insert(socket_path.to_string(), ConnectionEntry { last_used: now });
69        }
70
71        // Always create a new connection (until AimxClient supports cloning/sharing)
72        drop(pool); // Release lock before async operation
73        AimxClient::connect(socket_path).await
74    }
75
76    /// Remove a connection from the pool (called when operations fail)
77    pub async fn invalidate_connection(&self, socket_path: &str) {
78        let mut pool = self.connections.lock().await;
79        if pool.remove(socket_path).is_some() {
80            debug!("โŒ Invalidated connection metadata for {}", socket_path);
81        }
82    }
83
84    /// Get or create a persistent drain client for a socket path.
85    ///
86    /// Drain clients are kept alive across calls so the server-side drain
87    /// reader accumulates values between invocations. The first drain call
88    /// on a new connection is a cold start (returns empty); subsequent calls
89    /// return all values accumulated since the previous drain.
90    pub async fn get_drain_client(
91        &self,
92        socket_path: &str,
93    ) -> Result<Arc<tokio::sync::Mutex<AimxClient>>, ClientError> {
94        let drain_map = self.drain_clients.lock().await;
95
96        if let Some(client) = drain_map.get(socket_path) {
97            debug!("โ™ป๏ธ  Reusing persistent drain client for {}", socket_path);
98            return Ok(Arc::clone(client));
99        }
100
101        debug!("๐Ÿ”Œ Creating persistent drain client for {}", socket_path);
102
103        // Drop lock before async connect
104        drop(drain_map);
105
106        let client = AimxClient::connect(socket_path).await?;
107        let shared = Arc::new(tokio::sync::Mutex::new(client));
108
109        let mut drain_map = self.drain_clients.lock().await;
110        // Double-check: another task may have inserted while we were connecting
111        if let Some(existing) = drain_map.get(socket_path) {
112            return Ok(Arc::clone(existing));
113        }
114        drain_map.insert(socket_path.to_string(), Arc::clone(&shared));
115        Ok(shared)
116    }
117
118    /// Invalidate (remove) a persistent drain client, e.g. after connection error
119    pub async fn invalidate_drain_client(&self, socket_path: &str) {
120        let mut drain_map = self.drain_clients.lock().await;
121        if drain_map.remove(socket_path).is_some() {
122            debug!("โŒ Invalidated drain client for {}", socket_path);
123        }
124    }
125
126    /// Clear all connections in the pool
127    pub async fn clear(&self) {
128        let mut pool = self.connections.lock().await;
129        pool.clear();
130        let mut drain_map = self.drain_clients.lock().await;
131        drain_map.clear();
132        debug!("๐Ÿงน Cleared connection pool");
133    }
134
135    /// Get the number of tracked connections
136    pub async fn connection_count(&self) -> usize {
137        let pool = self.connections.lock().await;
138        pool.len()
139    }
140}
141
142impl Default for ConnectionPool {
143    fn default() -> Self {
144        Self::new()
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `pool` is not tracking any connection metadata.
    async fn assert_no_tracked_connections(pool: &ConnectionPool) {
        assert_eq!(pool.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_pool_creation() {
        // A freshly constructed pool tracks nothing.
        assert_no_tracked_connections(&ConnectionPool::new()).await;
    }

    #[tokio::test]
    async fn test_pool_clear() {
        let fresh = ConnectionPool::new();
        // Clearing an already-empty pool is a harmless no-op.
        fresh.clear().await;
        assert_no_tracked_connections(&fresh).await;
    }

    #[tokio::test]
    async fn test_invalidate_nonexistent_connection() {
        let fresh = ConnectionPool::new();
        // Invalidating a socket path that was never tracked must not panic.
        fresh.invalidate_connection("/tmp/nonexistent.sock").await;
        assert_no_tracked_connections(&fresh).await;
    }
}