ricecoder_mcp/
connection_pool.rs

//! Connection pool for managing MCP server connections
2
3use crate::error::{Error, Result};
4use std::collections::VecDeque;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tracing::{debug, info, warn};
8
/// A single connection record handed out by [`ConnectionPool`], tagged with
/// the MCP server it was opened for.
#[derive(Debug, Clone)]
pub struct PooledConnection {
    /// Identifier of this connection (the pool generates `conn-N` ids).
    pub id: String,
    /// Identifier of the MCP server this connection targets.
    pub server_id: String,
    /// Cleared by [`PooledConnection::invalidate`]; never set back to `true`
    /// by any method of this type.
    pub is_valid: bool,
    /// Creation time, or the time of the latest [`PooledConnection::mark_used`].
    pub last_used: std::time::Instant,
}

impl PooledConnection {
    /// Builds a fresh, valid connection record stamped with the current time.
    pub fn new(id: String, server_id: String) -> Self {
        let created_at = std::time::Instant::now();
        Self {
            is_valid: true,
            last_used: created_at,
            id,
            server_id,
        }
    }

    /// Refreshes the `last_used` timestamp to "now".
    pub fn mark_used(&mut self) {
        let now = std::time::Instant::now();
        self.last_used = now;
    }

    /// Returns `true` unless [`PooledConnection::invalidate`] has been called.
    pub fn is_still_valid(&self) -> bool {
        let valid = self.is_valid;
        valid
    }

    /// Flags the connection as unusable; the pool will refuse to reuse it.
    pub fn invalidate(&mut self) {
        self.is_valid = false;
    }
}
44
/// Tunable limits for a [`ConnectionPool`].
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Intended lower bound on pooled connections.
    /// NOTE(review): not read by `ConnectionPool` in this file.
    pub min_connections: usize,
    /// Capacity limit checked by `ConnectionPool::acquire`.
    pub max_connections: usize,
    /// Connection timeout, presumably in milliseconds (per the `_ms` suffix).
    /// NOTE(review): not read by `ConnectionPool` in this file.
    pub connection_timeout_ms: u64,
    /// Idle timeout, presumably in milliseconds (per the `_ms` suffix).
    /// NOTE(review): not read by `ConnectionPool` in this file.
    pub idle_timeout_ms: u64,
}

impl Default for PoolConfig {
    /// Defaults: min 1, max 10 connections, 2000 ms connection timeout,
    /// 30000 ms idle timeout.
    fn default() -> Self {
        let (min_connections, max_connections) = (1, 10);
        let (connection_timeout_ms, idle_timeout_ms) = (2_000, 30_000);
        Self {
            min_connections,
            max_connections,
            connection_timeout_ms,
            idle_timeout_ms,
        }
    }
}
64
65/// Connection pool for managing MCP server connections
66#[derive(Debug, Clone)]
67pub struct ConnectionPool {
68    config: PoolConfig,
69    available: Arc<RwLock<VecDeque<PooledConnection>>>,
70    in_use: Arc<RwLock<Vec<PooledConnection>>>,
71    connection_counter: Arc<RwLock<u64>>,
72}
73
74impl ConnectionPool {
75    /// Creates a new connection pool with default configuration
76    pub fn new() -> Self {
77        Self::with_config(PoolConfig::default())
78    }
79
80    /// Creates a new connection pool with custom configuration
81    pub fn with_config(config: PoolConfig) -> Self {
82        Self {
83            config,
84            available: Arc::new(RwLock::new(VecDeque::new())),
85            in_use: Arc::new(RwLock::new(Vec::new())),
86            connection_counter: Arc::new(RwLock::new(0)),
87        }
88    }
89
90    /// Acquires a connection from the pool
91    ///
92    /// # Arguments
93    /// * `server_id` - The server ID for which to acquire a connection
94    ///
95    /// # Returns
96    /// A pooled connection
97    ///
98    /// # Errors
99    /// Returns error if pool is at max capacity or connection creation fails
100    pub async fn acquire(&self, server_id: &str) -> Result<PooledConnection> {
101        debug!("Acquiring connection for server: {}", server_id);
102
103        // Try to get an available connection
104        let mut available = self.available.write().await;
105        if let Some(mut conn) = available.pop_front() {
106            if conn.is_still_valid() {
107                conn.mark_used();
108                let mut in_use = self.in_use.write().await;
109                in_use.push(conn.clone());
110                info!("Reused connection from pool for server: {}", server_id);
111                return Ok(conn);
112            }
113        }
114
115        // Check if we can create a new connection
116        let in_use = self.in_use.read().await;
117        if in_use.len() >= self.config.max_connections {
118            return Err(Error::ConnectionError(
119                "Connection pool at maximum capacity".to_string(),
120            ));
121        }
122        drop(in_use);
123
124        // Create a new connection
125        let mut counter = self.connection_counter.write().await;
126        *counter += 1;
127        let conn_id = format!("conn-{}", counter);
128        drop(counter);
129
130        let conn = PooledConnection::new(conn_id, server_id.to_string());
131        let mut in_use = self.in_use.write().await;
132        in_use.push(conn.clone());
133
134        info!(
135            "Created new connection for server: {} (total: {})",
136            server_id,
137            in_use.len()
138        );
139        Ok(conn)
140    }
141
142    /// Releases a connection back to the pool
143    ///
144    /// # Arguments
145    /// * `connection` - The connection to release
146    ///
147    /// # Returns
148    /// Result indicating success or failure
149    pub async fn release(&self, connection: PooledConnection) -> Result<()> {
150        debug!("Releasing connection: {}", connection.id);
151
152        let mut in_use = self.in_use.write().await;
153        in_use.retain(|c| c.id != connection.id);
154        drop(in_use);
155
156        if connection.is_still_valid() {
157            let mut available = self.available.write().await;
158            available.push_back(connection);
159            info!("Connection returned to pool");
160        } else {
161            info!("Connection invalidated, not returned to pool");
162        }
163
164        Ok(())
165    }
166
167    /// Validates a connection
168    ///
169    /// # Arguments
170    /// * `connection` - The connection to validate
171    ///
172    /// # Returns
173    /// True if connection is valid, false otherwise
174    pub async fn validate(&self, connection: &PooledConnection) -> bool {
175        debug!("Validating connection: {}", connection.id);
176
177        // Check if connection is still in use
178        let in_use = self.in_use.read().await;
179        let is_in_use = in_use.iter().any(|c| c.id == connection.id);
180
181        if !is_in_use {
182            warn!("Connection not in use: {}", connection.id);
183            return false;
184        }
185
186        // Check if connection is still valid
187        if !connection.is_still_valid() {
188            warn!("Connection is invalid: {}", connection.id);
189            return false;
190        }
191
192        true
193    }
194
195    /// Performs health check on all connections
196    ///
197    /// # Returns
198    /// Number of invalid connections removed
199    pub async fn health_check(&self) -> usize {
200        debug!("Performing health check on connection pool");
201
202        let mut available = self.available.write().await;
203        let initial_count = available.len();
204
205        // Remove invalid connections
206        available.retain(|c| c.is_still_valid());
207
208        let removed = initial_count - available.len();
209        if removed > 0 {
210            info!("Removed {} invalid connections from pool", removed);
211        }
212
213        removed
214    }
215
216    /// Gets the current pool statistics
217    pub async fn get_stats(&self) -> PoolStats {
218        let available = self.available.read().await;
219        let in_use = self.in_use.read().await;
220
221        PoolStats {
222            available_connections: available.len(),
223            in_use_connections: in_use.len(),
224            total_connections: available.len() + in_use.len(),
225            max_connections: self.config.max_connections,
226        }
227    }
228
229    /// Clears all connections from the pool
230    pub async fn clear(&self) {
231        debug!("Clearing connection pool");
232
233        let mut available = self.available.write().await;
234        available.clear();
235
236        let mut in_use = self.in_use.write().await;
237        in_use.clear();
238
239        info!("Connection pool cleared");
240    }
241}
242
243impl Default for ConnectionPool {
244    fn default() -> Self {
245        Self::new()
246    }
247}
248
/// Point-in-time snapshot of pool occupancy, as produced by
/// [`ConnectionPool::get_stats`].
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Idle connections waiting to be reused.
    pub available_connections: usize,
    /// Connections currently checked out.
    pub in_use_connections: usize,
    /// `available_connections + in_use_connections` at snapshot time.
    pub total_connections: usize,
    /// The pool's configured `max_connections` limit.
    pub max_connections: usize,
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pool that differs from the default only in `max_connections`.
    fn pool_with_max(max_connections: usize) -> ConnectionPool {
        ConnectionPool::with_config(PoolConfig {
            max_connections,
            ..PoolConfig::default()
        })
    }

    #[tokio::test]
    async fn test_create_pool() {
        let pool = ConnectionPool::new();
        let snapshot = pool.get_stats().await;
        assert_eq!(
            (snapshot.available_connections, snapshot.in_use_connections),
            (0, 0)
        );
    }

    #[tokio::test]
    async fn test_acquire_connection() {
        let pool = ConnectionPool::new();
        let acquired = pool.acquire("server1").await.expect("acquire should succeed");
        assert_eq!(acquired.server_id, "server1");
        assert!(acquired.is_still_valid());
    }

    #[tokio::test]
    async fn test_release_connection() {
        let pool = ConnectionPool::new();
        let acquired = pool.acquire("server1").await.unwrap();
        assert!(pool.release(acquired).await.is_ok());

        let snapshot = pool.get_stats().await;
        assert_eq!(snapshot.available_connections, 1);
        assert_eq!(snapshot.in_use_connections, 0);
    }

    #[tokio::test]
    async fn test_reuse_connection() {
        let pool = ConnectionPool::new();
        let first = pool.acquire("server1").await.unwrap();
        let first_id = first.id.clone();
        pool.release(first).await.unwrap();

        let second = pool.acquire("server1").await.unwrap();
        assert_eq!(second.id, first_id);
    }

    #[tokio::test]
    async fn test_max_connections() {
        let pool = pool_with_max(2);
        let held = vec![
            pool.acquire("server1").await.unwrap(),
            pool.acquire("server1").await.unwrap(),
        ];

        // A third acquire must be refused while both are checked out.
        assert!(pool.acquire("server1").await.is_err());

        for conn in held {
            pool.release(conn).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_validate_connection() {
        let pool = ConnectionPool::new();
        let conn = pool.acquire("server1").await.unwrap();
        assert!(pool.validate(&conn).await);

        pool.release(conn.clone()).await.unwrap();
        assert!(!pool.validate(&conn).await);
    }

    #[tokio::test]
    async fn test_health_check() {
        let pool = ConnectionPool::new();
        let mut stale = pool.acquire("server1").await.unwrap();
        pool.release(stale.clone()).await.unwrap();

        // Inject an invalidated copy straight into the idle queue; the guard
        // is a temporary and is dropped at the end of the statement.
        stale.invalidate();
        pool.available.write().await.push_back(stale);

        assert_eq!(pool.health_check().await, 1);
    }

    #[tokio::test]
    async fn test_clear_pool() {
        let pool = ConnectionPool::new();
        let _first = pool.acquire("server1").await.unwrap();
        let _second = pool.acquire("server1").await.unwrap();

        pool.clear().await;

        assert_eq!(pool.get_stats().await.total_connections, 0);
    }
}