1use crate::config::{ConnectionConfig, PoolConfig};
10use crate::connection::{Connection, PooledConnection};
11use crate::error::ClientError;
12use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
13use std::sync::Arc;
14use tokio::sync::mpsc;
15use tokio::sync::Semaphore;
16use tokio::time::timeout;
17
18pub struct ConnectionPool {
28 config: PoolConfig,
29 connection_config: ConnectionConfig,
30 return_tx: mpsc::UnboundedSender<Arc<Connection>>,
32 return_rx: tokio::sync::Mutex<mpsc::UnboundedReceiver<Arc<Connection>>>,
34 semaphore: Arc<Semaphore>,
36 total_created: AtomicU64,
37 total_acquired: AtomicU64,
38 total_released: Arc<AtomicU64>,
39 current_size: Arc<AtomicUsize>,
40 closed: std::sync::atomic::AtomicBool,
41}
42
43impl ConnectionPool {
44 pub async fn new(config: PoolConfig) -> Result<Self, ClientError> {
46 Self::with_connection_config(config, ConnectionConfig::default()).await
47 }
48
49 pub async fn with_connection_config(
51 config: PoolConfig,
52 connection_config: ConnectionConfig,
53 ) -> Result<Self, ClientError> {
54 let (return_tx, return_rx) = mpsc::unbounded_channel();
55
56 let pool = Self {
57 semaphore: Arc::new(Semaphore::new(config.max_connections)),
58 return_tx,
59 return_rx: tokio::sync::Mutex::new(return_rx),
60 total_created: AtomicU64::new(0),
61 total_acquired: AtomicU64::new(0),
62 total_released: Arc::new(AtomicU64::new(0)),
63 current_size: Arc::new(AtomicUsize::new(0)),
64 closed: std::sync::atomic::AtomicBool::new(false),
65 config,
66 connection_config,
67 };
68
69 pool.initialize().await?;
71
72 Ok(pool)
73 }
74
75 async fn initialize(&self) -> Result<(), ClientError> {
76 for _ in 0..self.config.min_connections {
77 let conn = self.create_connection().await?;
78 let _ = self.return_tx.send(conn);
80 }
81 Ok(())
82 }
83
84 async fn create_connection(&self) -> Result<Arc<Connection>, ClientError> {
85 let conn = Connection::new(self.connection_config.clone()).await?;
86 self.total_created.fetch_add(1, Ordering::SeqCst);
87 self.current_size.fetch_add(1, Ordering::SeqCst);
88 Ok(Arc::new(conn))
89 }
90
91 async fn try_recv_usable(&self) -> Option<Arc<Connection>> {
96 let mut rx = self.return_rx.lock().await;
97 loop {
98 match rx.try_recv() {
99 Ok(conn) => {
100 if conn.is_connected() && conn.idle_time() < self.config.idle_timeout {
101 return Some(conn);
102 }
103 self.current_size.fetch_sub(1, Ordering::SeqCst);
105 }
106 Err(_) => return None,
107 }
108 }
109 }
110
111 pub async fn get(&self) -> Result<PooledConnection, ClientError> {
113 if self.closed.load(Ordering::SeqCst) {
114 return Err(ClientError::ConnectionClosed);
115 }
116
117 let permit_result = timeout(
119 self.config.acquire_timeout,
120 self.semaphore.clone().acquire_owned(),
121 )
122 .await;
123
124 let permit = match permit_result {
125 Ok(Ok(p)) => p,
126 Ok(Err(_)) => return Err(ClientError::PoolExhausted),
127 Err(_) => return Err(ClientError::PoolTimeout),
128 };
129
130 let conn = if let Some(conn) = self.try_recv_usable().await {
132 conn
133 } else {
134 self.create_connection().await?
136 };
137
138 self.total_acquired.fetch_add(1, Ordering::SeqCst);
139
140 let tx = self.return_tx.clone();
142 let released = Arc::clone(&self.total_released);
143 let current_size = Arc::clone(&self.current_size);
144 let closed = self.closed.load(Ordering::SeqCst);
145
146 Ok(PooledConnection::new(conn, move |conn| {
147 drop(permit);
149
150 if !closed && conn.is_connected() {
152 match tx.send(conn) {
153 Ok(_) => {
154 released.fetch_add(1, Ordering::SeqCst);
155 }
156 Err(_) => {
157 current_size.fetch_sub(1, Ordering::SeqCst);
159 }
160 }
161 } else {
162 current_size.fetch_sub(1, Ordering::SeqCst);
164 }
165 }))
166 }
167
168 pub async fn return_connection(&self, conn: Arc<Connection>) {
170 if !self.closed.load(Ordering::SeqCst) && conn.is_connected() {
171 let _ = self.return_tx.send(conn);
172 self.total_released.fetch_add(1, Ordering::SeqCst);
173 } else {
174 self.current_size.fetch_sub(1, Ordering::SeqCst);
175 }
176 }
177
178 pub async fn is_healthy(&self) -> bool {
180 if self.closed.load(Ordering::SeqCst) {
181 return false;
182 }
183
184 self.current_size.load(Ordering::SeqCst) > 0
186 }
187
188 pub fn stats(&self) -> PoolStats {
190 PoolStats {
191 total_created: self.total_created.load(Ordering::SeqCst),
192 total_acquired: self.total_acquired.load(Ordering::SeqCst),
193 total_released: self.total_released.load(Ordering::SeqCst),
194 current_size: self.current_size.load(Ordering::SeqCst),
195 max_size: self.config.max_connections,
196 min_size: self.config.min_connections,
197 available_permits: self.semaphore.available_permits(),
198 }
199 }
200
201 pub async fn close(&self) {
203 self.closed.store(true, Ordering::SeqCst);
204
205 let mut rx = self.return_rx.lock().await;
207 while let Ok(conn) = rx.try_recv() {
208 conn.close().await;
209 self.current_size.fetch_sub(1, Ordering::SeqCst);
210 }
211 }
212
213 pub fn size(&self) -> usize {
215 self.current_size.load(Ordering::SeqCst)
216 }
217
218 pub fn available(&self) -> usize {
220 self.semaphore.available_permits()
221 }
222}
223
/// Point-in-time snapshot of a connection pool's counters and bounds.
#[derive(Debug, Clone)]
pub struct PoolStats {
    /// Connections opened over the pool's lifetime.
    pub total_created: u64,
    /// Successful checkouts over the pool's lifetime.
    pub total_acquired: u64,
    /// Connections re-queued after use.
    pub total_released: u64,
    /// Connections currently tracked by the pool.
    pub current_size: usize,
    /// Configured maximum number of concurrent checkouts.
    pub max_size: usize,
    /// Configured number of eagerly opened connections.
    pub min_size: usize,
    /// Checkout slots free at snapshot time.
    pub available_permits: usize,
}

impl PoolStats {
    /// Percentage (0.0 to 100.0) of checkout slots in use.
    ///
    /// Returns `0.0` when `max_size` is zero. If `available_permits` exceeds
    /// `max_size` (possible for hand-built values, since the fields are
    /// public), usage is treated as zero instead of underflowing.
    pub fn utilization(&self) -> f64 {
        if self.max_size == 0 {
            return 0.0;
        }
        let in_use = self.max_size.saturating_sub(self.available_permits);
        (in_use as f64 / self.max_size as f64) * 100.0
    }
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Connection settings for the live-server tests. The port defaults to
    /// 9090 and can be overridden with the `AEGIS_TEST_PORT` env var.
    fn test_connection_config() -> ConnectionConfig {
        let port = std::env::var("AEGIS_TEST_PORT")
            .ok()
            .and_then(|p| p.parse().ok())
            .unwrap_or(9090);
        ConnectionConfig {
            host: "127.0.0.1".to_string(),
            port,
            ..Default::default()
        }
    }

    /// Builds a pool against the test server; callers skip the test on `Err`.
    async fn create_test_pool(pool_config: PoolConfig) -> Result<ConnectionPool, ClientError> {
        ConnectionPool::with_connection_config(pool_config, test_connection_config()).await
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let config = PoolConfig {
            min_connections: 2,
            max_connections: 5,
            ..Default::default()
        };

        match create_test_pool(config).await {
            Ok(pool) => assert_eq!(pool.size(), 2),
            Err(e) => eprintln!("Skipping test, server not available: {}", e),
        }
    }

    #[tokio::test]
    async fn test_pool_get_connection() {
        let config = PoolConfig::default();

        match create_test_pool(config).await {
            Ok(pool) => {
                let conn = pool.get().await.expect("Should get connection from pool");
                assert!(conn.inner().is_connected());
            }
            Err(e) => eprintln!("Skipping test, server not available: {}", e),
        }
    }

    #[tokio::test]
    async fn test_pool_stats() {
        let config = PoolConfig {
            min_connections: 1,
            max_connections: 5,
            ..Default::default()
        };

        match create_test_pool(config).await {
            Ok(pool) => {
                let stats = pool.stats();
                assert_eq!(stats.min_size, 1);
                assert_eq!(stats.max_size, 5);
                assert!(stats.total_created >= 1);
            }
            Err(e) => eprintln!("Skipping test, server not available: {}", e),
        }
    }

    #[tokio::test]
    async fn test_pool_acquire_multiple() {
        let config = PoolConfig {
            min_connections: 0,
            max_connections: 3,
            ..Default::default()
        };

        match create_test_pool(config).await {
            Ok(pool) => {
                // With min_connections = 0 the first `get` is the first real
                // connect attempt, so treat its failure as "server missing".
                let c1 = match pool.get().await {
                    Ok(c) => c,
                    Err(e) => {
                        eprintln!("Skipping test, server not available: {}", e);
                        return;
                    }
                };
                let c2 = pool
                    .get()
                    .await
                    .expect("Should get second connection from pool");
                let c3 = pool
                    .get()
                    .await
                    .expect("Should get third connection from pool");

                assert!(c1.inner().is_connected());
                assert!(c2.inner().is_connected());
                assert!(c3.inner().is_connected());

                let stats = pool.stats();
                assert_eq!(stats.total_acquired, 3);
            }
            Err(e) => eprintln!("Skipping test, server not available: {}", e),
        }
    }

    #[tokio::test]
    async fn test_pool_close() {
        let config = PoolConfig {
            min_connections: 2,
            ..Default::default()
        };

        match create_test_pool(config).await {
            Ok(pool) => {
                assert!(pool.size() >= 2);
                pool.close().await;
                assert!(!pool.is_healthy().await);
            }
            Err(e) => eprintln!("Skipping test, server not available: {}", e),
        }
    }

    // Pure arithmetic: no async runtime needed.
    #[test]
    fn test_pool_utilization() {
        let stats = PoolStats {
            total_created: 5,
            total_acquired: 10,
            total_released: 8,
            current_size: 5,
            max_size: 10,
            min_size: 2,
            available_permits: 8,
        };

        let util = stats.utilization();
        assert!((util - 20.0).abs() < 0.01);
    }

    #[test]
    fn test_pool_utilization_zero_capacity() {
        let stats = PoolStats {
            total_created: 0,
            total_acquired: 0,
            total_released: 0,
            current_size: 0,
            max_size: 0,
            min_size: 0,
            available_permits: 0,
        };

        assert!(stats.utilization().abs() < f64::EPSILON);
    }
}
393}