ricecoder_mcp/
connection_pool.rs1use crate::error::{Error, Result};
4use std::collections::VecDeque;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tracing::{debug, info, warn};
8
/// Bookkeeping record for one connection handed out by a [`ConnectionPool`].
#[derive(Clone, Debug)]
pub struct PooledConnection {
    /// Identifier of this connection (the pool generates ids of the form `conn-N`).
    pub id: String,
    /// Identifier of the server this connection was acquired for.
    pub server_id: String,
    /// Cleared by `invalidate`; an invalid connection is never returned to the idle set.
    pub is_valid: bool,
    /// Time of creation or of the most recent `mark_used` call.
    pub last_used: std::time::Instant,
}
17
18impl PooledConnection {
19 pub fn new(id: String, server_id: String) -> Self {
21 Self {
22 id,
23 server_id,
24 is_valid: true,
25 last_used: std::time::Instant::now(),
26 }
27 }
28
29 pub fn mark_used(&mut self) {
31 self.last_used = std::time::Instant::now();
32 }
33
34 pub fn is_still_valid(&self) -> bool {
36 self.is_valid
37 }
38
39 pub fn invalidate(&mut self) {
41 self.is_valid = false;
42 }
43}
44
/// Sizing and timing settings for a [`ConnectionPool`].
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Desired lower bound on pooled connections.
    pub min_connections: usize,
    /// Upper bound on connections the pool will hand out.
    pub max_connections: usize,
    /// Connection timeout, in milliseconds.
    pub connection_timeout_ms: u64,
    /// Idle timeout, in milliseconds.
    pub idle_timeout_ms: u64,
}
53
54impl Default for PoolConfig {
55 fn default() -> Self {
56 Self {
57 min_connections: 1,
58 max_connections: 10,
59 connection_timeout_ms: 2000,
60 idle_timeout_ms: 30000,
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct ConnectionPool {
68 config: PoolConfig,
69 available: Arc<RwLock<VecDeque<PooledConnection>>>,
70 in_use: Arc<RwLock<Vec<PooledConnection>>>,
71 connection_counter: Arc<RwLock<u64>>,
72}
73
74impl ConnectionPool {
75 pub fn new() -> Self {
77 Self::with_config(PoolConfig::default())
78 }
79
80 pub fn with_config(config: PoolConfig) -> Self {
82 Self {
83 config,
84 available: Arc::new(RwLock::new(VecDeque::new())),
85 in_use: Arc::new(RwLock::new(Vec::new())),
86 connection_counter: Arc::new(RwLock::new(0)),
87 }
88 }
89
90 pub async fn acquire(&self, server_id: &str) -> Result<PooledConnection> {
101 debug!("Acquiring connection for server: {}", server_id);
102
103 let mut available = self.available.write().await;
105 if let Some(mut conn) = available.pop_front() {
106 if conn.is_still_valid() {
107 conn.mark_used();
108 let mut in_use = self.in_use.write().await;
109 in_use.push(conn.clone());
110 info!("Reused connection from pool for server: {}", server_id);
111 return Ok(conn);
112 }
113 }
114
115 let in_use = self.in_use.read().await;
117 if in_use.len() >= self.config.max_connections {
118 return Err(Error::ConnectionError(
119 "Connection pool at maximum capacity".to_string(),
120 ));
121 }
122 drop(in_use);
123
124 let mut counter = self.connection_counter.write().await;
126 *counter += 1;
127 let conn_id = format!("conn-{}", counter);
128 drop(counter);
129
130 let conn = PooledConnection::new(conn_id, server_id.to_string());
131 let mut in_use = self.in_use.write().await;
132 in_use.push(conn.clone());
133
134 info!(
135 "Created new connection for server: {} (total: {})",
136 server_id,
137 in_use.len()
138 );
139 Ok(conn)
140 }
141
142 pub async fn release(&self, connection: PooledConnection) -> Result<()> {
150 debug!("Releasing connection: {}", connection.id);
151
152 let mut in_use = self.in_use.write().await;
153 in_use.retain(|c| c.id != connection.id);
154 drop(in_use);
155
156 if connection.is_still_valid() {
157 let mut available = self.available.write().await;
158 available.push_back(connection);
159 info!("Connection returned to pool");
160 } else {
161 info!("Connection invalidated, not returned to pool");
162 }
163
164 Ok(())
165 }
166
167 pub async fn validate(&self, connection: &PooledConnection) -> bool {
175 debug!("Validating connection: {}", connection.id);
176
177 let in_use = self.in_use.read().await;
179 let is_in_use = in_use.iter().any(|c| c.id == connection.id);
180
181 if !is_in_use {
182 warn!("Connection not in use: {}", connection.id);
183 return false;
184 }
185
186 if !connection.is_still_valid() {
188 warn!("Connection is invalid: {}", connection.id);
189 return false;
190 }
191
192 true
193 }
194
195 pub async fn health_check(&self) -> usize {
200 debug!("Performing health check on connection pool");
201
202 let mut available = self.available.write().await;
203 let initial_count = available.len();
204
205 available.retain(|c| c.is_still_valid());
207
208 let removed = initial_count - available.len();
209 if removed > 0 {
210 info!("Removed {} invalid connections from pool", removed);
211 }
212
213 removed
214 }
215
216 pub async fn get_stats(&self) -> PoolStats {
218 let available = self.available.read().await;
219 let in_use = self.in_use.read().await;
220
221 PoolStats {
222 available_connections: available.len(),
223 in_use_connections: in_use.len(),
224 total_connections: available.len() + in_use.len(),
225 max_connections: self.config.max_connections,
226 }
227 }
228
229 pub async fn clear(&self) {
231 debug!("Clearing connection pool");
232
233 let mut available = self.available.write().await;
234 available.clear();
235
236 let mut in_use = self.in_use.write().await;
237 in_use.clear();
238
239 info!("Connection pool cleared");
240 }
241}
242
243impl Default for ConnectionPool {
244 fn default() -> Self {
245 Self::new()
246 }
247}
248
/// Point-in-time connection counts reported by `ConnectionPool::get_stats`.
#[derive(Clone, Debug)]
pub struct PoolStats {
    /// Idle connections waiting to be reused.
    pub available_connections: usize,
    /// Connections currently checked out.
    pub in_use_connections: usize,
    /// Sum of idle and checked-out connections.
    pub total_connections: usize,
    /// Configured `max_connections` of the pool.
    pub max_connections: usize,
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Settings capping the pool at two connections, otherwise matching the defaults.
    fn two_connection_config() -> PoolConfig {
        PoolConfig {
            min_connections: 1,
            max_connections: 2,
            connection_timeout_ms: 2000,
            idle_timeout_ms: 30000,
        }
    }

    #[tokio::test]
    async fn test_create_pool() {
        let stats = ConnectionPool::new().get_stats().await;
        assert_eq!(
            (stats.available_connections, stats.in_use_connections),
            (0, 0)
        );
    }

    #[tokio::test]
    async fn test_acquire_connection() {
        let pool = ConnectionPool::new();
        let connection = pool.acquire("server1").await.unwrap();
        assert_eq!(connection.server_id, "server1");
        assert!(connection.is_still_valid());
    }

    #[tokio::test]
    async fn test_release_connection() {
        let pool = ConnectionPool::new();
        let connection = pool.acquire("server1").await.unwrap();
        assert!(pool.release(connection).await.is_ok());

        let stats = pool.get_stats().await;
        assert_eq!(
            (stats.available_connections, stats.in_use_connections),
            (1, 0)
        );
    }

    #[tokio::test]
    async fn test_reuse_connection() {
        let pool = ConnectionPool::new();

        let first_id = {
            let first = pool.acquire("server1").await.unwrap();
            let id = first.id.clone();
            pool.release(first).await.unwrap();
            id
        };

        let second = pool.acquire("server1").await.unwrap();
        assert_eq!(second.id, first_id);
    }

    #[tokio::test]
    async fn test_max_connections() {
        let pool = ConnectionPool::with_config(two_connection_config());

        let held = vec![
            pool.acquire("server1").await.unwrap(),
            pool.acquire("server1").await.unwrap(),
        ];

        assert!(pool.acquire("server1").await.is_err());

        for connection in held {
            pool.release(connection).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_validate_connection() {
        let pool = ConnectionPool::new();
        let connection = pool.acquire("server1").await.unwrap();

        assert!(pool.validate(&connection).await);

        pool.release(connection.clone()).await.unwrap();

        assert!(!pool.validate(&connection).await);
    }

    #[tokio::test]
    async fn test_health_check() {
        let pool = ConnectionPool::new();
        let mut stale = pool.acquire("server1").await.unwrap();
        pool.release(stale.clone()).await.unwrap();

        // Push an invalidated copy straight into the idle queue, bypassing `release`.
        stale.invalidate();
        pool.available.write().await.push_back(stale);

        assert_eq!(pool.health_check().await, 1);
    }

    #[tokio::test]
    async fn test_clear_pool() {
        let pool = ConnectionPool::new();
        let _first = pool.acquire("server1").await.unwrap();
        let _second = pool.acquire("server1").await.unwrap();

        pool.clear().await;

        assert_eq!(pool.get_stats().await.total_connections, 0);
    }
}