Skip to main content

aegis_client/
pool.rs

1//! Aegis Client Connection Pool
2//!
3//! Connection pool management for efficient database access.
4//! Uses a channel-based approach for real connection recycling.
5//!
6//! @version 0.1.0
7//! @author AutomataNexus Development Team
8
9use crate::config::{ConnectionConfig, PoolConfig};
10use crate::connection::{Connection, PooledConnection};
11use crate::error::ClientError;
12use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tokio::sync::Semaphore;
16use tokio::time::timeout;
17
18// =============================================================================
19// Connection Pool
20// =============================================================================
21
/// A pool of database connections.
///
/// Connections are recycled via an unbounded mpsc channel. When a `PooledConnection`
/// is dropped, its inner `Arc<Connection>` is sent back through the channel so
/// subsequent `get()` calls can reuse it instead of opening a new connection.
pub struct ConnectionPool {
    /// Pool sizing and timing settings; the impl reads `min_connections`,
    /// `max_connections`, `acquire_timeout` and `idle_timeout` from it.
    config: PoolConfig,
    /// Settings cloned into `Connection::new` whenever a fresh connection is opened.
    connection_config: ConnectionConfig,
    /// Sender half – cloned into every `PooledConnection` for return-on-drop.
    return_tx: mpsc::UnboundedSender<Arc<Connection>>,
    /// Receiver half – polled on each `get()` to reclaim returned connections
    /// and drained by `close()`. Wrapped in an async mutex because receiving
    /// needs `&mut` access while the pool itself is shared by `&self`.
    return_rx: tokio::sync::Mutex<mpsc::UnboundedReceiver<Arc<Connection>>>,
    /// Limits the total number of connections that may be checked out at once.
    semaphore: Arc<Semaphore>,
    /// Number of connections ever opened by this pool (only ever incremented).
    total_created: AtomicU64,
    /// Number of successful `get()` checkouts (only ever incremented).
    total_acquired: AtomicU64,
    /// Number of connections handed back to the channel for reuse. Shared via
    /// `Arc` so the release closure given to `PooledConnection` can update it.
    total_released: Arc<AtomicU64>,
    /// Connections created and not yet discarded (idle in the channel or
    /// checked out). Shared via `Arc` for the same reason as `total_released`.
    current_size: Arc<AtomicUsize>,
    /// Set to `true` by `close()` and never reset.
    closed: std::sync::atomic::AtomicBool,
}
42
43impl ConnectionPool {
44    /// Create a new connection pool.
45    pub async fn new(config: PoolConfig) -> Result<Self, ClientError> {
46        Self::with_connection_config(config, ConnectionConfig::default()).await
47    }
48
49    /// Create a pool with specific connection configuration.
50    pub async fn with_connection_config(
51        config: PoolConfig,
52        connection_config: ConnectionConfig,
53    ) -> Result<Self, ClientError> {
54        let (return_tx, return_rx) = mpsc::unbounded_channel();
55
56        let pool = Self {
57            semaphore: Arc::new(Semaphore::new(config.max_connections)),
58            return_tx,
59            return_rx: tokio::sync::Mutex::new(return_rx),
60            total_created: AtomicU64::new(0),
61            total_acquired: AtomicU64::new(0),
62            total_released: Arc::new(AtomicU64::new(0)),
63            current_size: Arc::new(AtomicUsize::new(0)),
64            closed: std::sync::atomic::AtomicBool::new(false),
65            config,
66            connection_config,
67        };
68
69        // Pre-create minimum connections
70        pool.initialize().await?;
71
72        Ok(pool)
73    }
74
75    async fn initialize(&self) -> Result<(), ClientError> {
76        for _ in 0..self.config.min_connections {
77            let conn = self.create_connection().await?;
78            // Seed the channel with pre-created connections
79            let _ = self.return_tx.send(conn);
80        }
81        Ok(())
82    }
83
84    async fn create_connection(&self) -> Result<Arc<Connection>, ClientError> {
85        let conn = Connection::new(self.connection_config.clone()).await?;
86        self.total_created.fetch_add(1, Ordering::SeqCst);
87        self.current_size.fetch_add(1, Ordering::SeqCst);
88        Ok(Arc::new(conn))
89    }
90
91    /// Try to pull one usable connection from the return channel.
92    ///
93    /// Drains stale/disconnected connections and returns the first good one,
94    /// or `None` if the channel is empty.
95    async fn try_recv_usable(&self) -> Option<Arc<Connection>> {
96        let mut rx = self.return_rx.lock().await;
97        loop {
98            match rx.try_recv() {
99                Ok(conn) => {
100                    if conn.is_connected() && conn.idle_time() < self.config.idle_timeout {
101                        return Some(conn);
102                    }
103                    // Stale or disconnected – discard it
104                    self.current_size.fetch_sub(1, Ordering::SeqCst);
105                }
106                Err(_) => return None,
107            }
108        }
109    }
110
111    /// Get a connection from the pool.
112    pub async fn get(&self) -> Result<PooledConnection, ClientError> {
113        if self.closed.load(Ordering::SeqCst) {
114            return Err(ClientError::ConnectionClosed);
115        }
116
117        // Wait for a permit (limits total concurrent checkouts)
118        let permit_result = timeout(
119            self.config.acquire_timeout,
120            self.semaphore.clone().acquire_owned(),
121        )
122        .await;
123
124        let permit = match permit_result {
125            Ok(Ok(p)) => p,
126            Ok(Err(_)) => return Err(ClientError::PoolExhausted),
127            Err(_) => return Err(ClientError::PoolTimeout),
128        };
129
130        // Try to reuse a recycled connection from the channel
131        let conn = if let Some(conn) = self.try_recv_usable().await {
132            conn
133        } else {
134            // No idle connections available – create a fresh one
135            self.create_connection().await?
136        };
137
138        self.total_acquired.fetch_add(1, Ordering::SeqCst);
139
140        // Clone the sender so the PooledConnection can return itself on drop
141        let tx = self.return_tx.clone();
142        let released = Arc::clone(&self.total_released);
143        let current_size = Arc::clone(&self.current_size);
144        let closed = self.closed.load(Ordering::SeqCst);
145
146        Ok(PooledConnection::new(conn, move |conn| {
147            // Release the semaphore permit so another caller can proceed
148            drop(permit);
149
150            // If pool is still open and connection is alive, recycle it
151            if !closed && conn.is_connected() {
152                match tx.send(conn) {
153                    Ok(_) => {
154                        released.fetch_add(1, Ordering::SeqCst);
155                    }
156                    Err(_) => {
157                        // Channel closed (pool was dropped) – just let the connection drop
158                        current_size.fetch_sub(1, Ordering::SeqCst);
159                    }
160                }
161            } else {
162                // Connection is dead or pool is closed – discard
163                current_size.fetch_sub(1, Ordering::SeqCst);
164            }
165        }))
166    }
167
168    /// Return a connection to the pool (explicit async path).
169    pub async fn return_connection(&self, conn: Arc<Connection>) {
170        if !self.closed.load(Ordering::SeqCst) && conn.is_connected() {
171            let _ = self.return_tx.send(conn);
172            self.total_released.fetch_add(1, Ordering::SeqCst);
173        } else {
174            self.current_size.fetch_sub(1, Ordering::SeqCst);
175        }
176    }
177
178    /// Check if the pool is healthy.
179    pub async fn is_healthy(&self) -> bool {
180        if self.closed.load(Ordering::SeqCst) {
181            return false;
182        }
183
184        // The pool is healthy if there are connections in existence
185        self.current_size.load(Ordering::SeqCst) > 0
186    }
187
188    /// Get pool statistics.
189    pub fn stats(&self) -> PoolStats {
190        PoolStats {
191            total_created: self.total_created.load(Ordering::SeqCst),
192            total_acquired: self.total_acquired.load(Ordering::SeqCst),
193            total_released: self.total_released.load(Ordering::SeqCst),
194            current_size: self.current_size.load(Ordering::SeqCst),
195            max_size: self.config.max_connections,
196            min_size: self.config.min_connections,
197            available_permits: self.semaphore.available_permits(),
198        }
199    }
200
201    /// Close all connections in the pool.
202    pub async fn close(&self) {
203        self.closed.store(true, Ordering::SeqCst);
204
205        // Drain all returned connections from the channel and close them
206        let mut rx = self.return_rx.lock().await;
207        while let Ok(conn) = rx.try_recv() {
208            conn.close().await;
209            self.current_size.fetch_sub(1, Ordering::SeqCst);
210        }
211    }
212
213    /// Get the current pool size.
214    pub fn size(&self) -> usize {
215        self.current_size.load(Ordering::SeqCst)
216    }
217
218    /// Get available permits.
219    pub fn available(&self) -> usize {
220        self.semaphore.available_permits()
221    }
222}
223
224// =============================================================================
225// Pool Statistics
226// =============================================================================
227
/// Point-in-time snapshot of connection pool counters, as produced by
/// `ConnectionPool::stats()`.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Connections opened over the pool's lifetime.
    pub total_created: u64,
    /// Successful checkouts over the pool's lifetime.
    pub total_acquired: u64,
    /// Connections handed back to the pool for reuse.
    pub total_released: u64,
    /// Connections created and not yet discarded (idle or checked out).
    pub current_size: usize,
    /// Maximum number of concurrent checkouts (`max_connections`).
    pub max_size: usize,
    /// Number of connections pre-opened at startup (`min_connections`).
    pub min_size: usize,
    /// Checkout permits that were free when the snapshot was taken.
    pub available_permits: usize,
}

impl PoolStats {
    /// Get pool utilization as a percentage in the range `0.0..=100.0`.
    ///
    /// Returns `0.0` when `max_size` is zero. The fields are public and the
    /// struct can be built by hand, so an `available_permits` larger than
    /// `max_size` is treated as "nothing in use" rather than underflowing.
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            return 0.0;
        }
        // Plain subtraction would panic in debug builds (and wrap to a huge
        // value in release builds) on an inconsistent snapshot.
        let in_use = self.max_size.saturating_sub(self.available_permits);
        (in_use as f64 / self.max_size as f64) * 100.0
    }
}
250
251// =============================================================================
252// Tests
253// =============================================================================
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection settings pointing at the local test server.
    ///
    /// The port is taken from `AEGIS_TEST_PORT` when that variable is set and
    /// parses into the `port` field's type; in every other case 9090 is used.
    fn test_connection_config() -> ConnectionConfig {
        let port = match std::env::var("AEGIS_TEST_PORT") {
            Ok(raw) => raw.parse().unwrap_or(9090),
            Err(_) => 9090,
        };
        ConnectionConfig {
            host: String::from("127.0.0.1"),
            port,
            ..Default::default()
        }
    }

    /// Build a pool against the server described by `test_connection_config`.
    async fn create_test_pool(pool_config: PoolConfig) -> Result<ConnectionPool, ClientError> {
        let connection_config = test_connection_config();
        ConnectionPool::with_connection_config(pool_config, connection_config).await
    }

    /// Log why a server-dependent test is bailing out early.
    fn skip(err: &ClientError) {
        eprintln!("Skipping test, server not available: {}", err);
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let config = PoolConfig {
            max_connections: 5,
            min_connections: 2,
            ..Default::default()
        };

        let pool = match create_test_pool(config).await {
            Ok(pool) => pool,
            Err(e) => {
                skip(&e);
                return;
            }
        };
        assert_eq!(pool.size(), 2);
    }

    #[tokio::test]
    async fn test_pool_get_connection() {
        let pool = match create_test_pool(PoolConfig::default()).await {
            Ok(pool) => pool,
            Err(e) => {
                skip(&e);
                return;
            }
        };

        let conn = pool.get().await.expect("Should get connection from pool");
        assert!(conn.inner().is_connected());
    }

    #[tokio::test]
    async fn test_pool_stats() {
        let config = PoolConfig {
            max_connections: 5,
            min_connections: 1,
            ..Default::default()
        };

        let pool = match create_test_pool(config).await {
            Ok(pool) => pool,
            Err(e) => {
                skip(&e);
                return;
            }
        };

        let snapshot = pool.stats();
        assert_eq!(snapshot.min_size, 1);
        assert_eq!(snapshot.max_size, 5);
        assert!(snapshot.total_created >= 1);
    }

    #[tokio::test]
    async fn test_pool_acquire_multiple() {
        let config = PoolConfig {
            max_connections: 3,
            min_connections: 0,
            ..Default::default()
        };

        let pool = match create_test_pool(config).await {
            Ok(pool) => pool,
            Err(e) => {
                skip(&e);
                return;
            }
        };

        // With no pre-created connections, the first checkout is what actually
        // reaches the server, so a failure here also means "no server".
        let first = match pool.get().await {
            Ok(conn) => conn,
            Err(e) => {
                skip(&e);
                return;
            }
        };
        let second = pool
            .get()
            .await
            .expect("Should get second connection from pool");
        let third = pool
            .get()
            .await
            .expect("Should get third connection from pool");

        for conn in [&first, &second, &third].iter() {
            assert!(conn.inner().is_connected());
        }

        assert_eq!(pool.stats().total_acquired, 3);
    }

    #[tokio::test]
    async fn test_pool_close() {
        let config = PoolConfig {
            min_connections: 2,
            ..Default::default()
        };

        let pool = match create_test_pool(config).await {
            Ok(pool) => pool,
            Err(e) => {
                skip(&e);
                return;
            }
        };

        assert!(pool.size() >= 2);
        pool.close().await;
        assert!(!pool.is_healthy().await);
    }

    #[tokio::test]
    async fn test_pool_utilization() {
        let snapshot = PoolStats {
            total_created: 5,
            total_acquired: 10,
            total_released: 8,
            current_size: 5,
            max_size: 10,
            min_size: 2,
            available_permits: 8,
        };

        let percent = snapshot.utilization();
        assert!((percent - 20.0).abs() < 0.01);
    }
}