1use crate::backend::{BackendClient, BackendConfig};
7use crate::{NodeId, ProxyError, Result};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::{OwnedSemaphorePermit, RwLock, Semaphore};
13use uuid::Uuid;
14
/// Tuning knobs for a `ConnectionPool`; the same values apply to every node.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Minimum connections per node.
    /// NOTE(review): not read by the pool code visible in this file — confirm intended use.
    pub min_connections: usize,
    /// Size of each node's semaphore, i.e. the cap on connections per node.
    pub max_connections: usize,
    /// Idle connections whose time since `last_used` reaches this are dropped by `evict_idle`.
    pub idle_timeout: Duration,
    /// Idle connections older than this (measured from creation) are discarded
    /// instead of being reused by `get_connection`.
    pub max_lifetime: Duration,
    /// How long `get_connection` waits for a free slot before failing with a timeout.
    pub acquire_timeout: Duration,
    /// NOTE(review): not read by the pool code visible in this file — confirm intended use.
    pub test_on_acquire: bool,
}

impl Default for PoolConfig {
    /// 2–10 connections per node, 5 min idle timeout, 30 min max lifetime,
    /// 30 s acquire timeout, `test_on_acquire` enabled.
    fn default() -> Self {
        const MINUTE_SECS: u64 = 60;
        Self {
            idle_timeout: Duration::from_secs(5 * MINUTE_SECS),
            max_lifetime: Duration::from_secs(30 * MINUTE_SECS),
            acquire_timeout: Duration::from_secs(30),
            min_connections: 2,
            max_connections: 10,
            test_on_acquire: true,
        }
    }
}
44
45pub struct PooledConnection {
47 pub id: Uuid,
49 pub node_id: NodeId,
51 pub created_at: chrono::DateTime<chrono::Utc>,
53 pub last_used: chrono::DateTime<chrono::Utc>,
55 pub state: ConnectionState,
57 pub use_count: u64,
59 pub(crate) permit: Option<OwnedSemaphorePermit>,
64 pub(crate) client: Option<BackendClient>,
69}
70
71impl std::fmt::Debug for PooledConnection {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("PooledConnection")
75 .field("id", &self.id)
76 .field("node_id", &self.node_id)
77 .field("created_at", &self.created_at)
78 .field("last_used", &self.last_used)
79 .field("state", &self.state)
80 .field("use_count", &self.use_count)
81 .field("has_permit", &self.permit.is_some())
82 .field("has_live_client", &self.client.is_some())
83 .finish()
84 }
85}
86
/// Lifecycle state of a `PooledConnection`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// Parked in a node pool and eligible for reuse by `get_connection`.
    Idle,
    /// Checked out to a caller.
    InUse,
    /// NOTE(review): not assigned anywhere in the pool code visible in this file.
    Validating,
    /// Reported invalid by `validate_connection`.
    /// NOTE(review): not assigned by the pool code visible in this file.
    Closed,
}
99
100struct NodePool {
102 connections: Vec<PooledConnection>,
104 semaphore: Arc<Semaphore>,
106 total_created: u64,
108 total_closed: u64,
110 endpoint: Option<(String, u16)>,
114}
115
116impl NodePool {
117 fn new(max_connections: usize) -> Self {
118 Self {
119 connections: Vec::new(),
120 semaphore: Arc::new(Semaphore::new(max_connections)),
121 total_created: 0,
122 total_closed: 0,
123 endpoint: None,
124 }
125 }
126}
127
/// Per-node connection pool. Each node's concurrency is bounded by a
/// semaphore, and connections can optionally carry live backend clients.
pub struct ConnectionPool {
    /// Settings applied to every node.
    config: PoolConfig,
    /// Per-node state keyed by `NodeId`, behind an async `RwLock`.
    pools: Arc<RwLock<HashMap<NodeId, NodePool>>>,
    /// Running count of connections the pool accounts for (idle plus checked out).
    total_connections: AtomicU64,
    /// Connections handed out by `get_connection` and not yet returned or closed.
    active_connections: AtomicU64,
    /// Cumulative counters exposed through `metrics()`.
    metrics: PoolMetricsCounters,
    /// When set, combined with a node's endpoint to open a live `BackendClient`;
    /// otherwise connections are created without a client.
    backend_template: Option<BackendConfig>,
}
146
/// Live atomic counters backing [`PoolMetrics`]. Every access uses
/// `Ordering::Relaxed`; each counter is updated independently.
#[derive(Debug, Default)]
struct PoolMetricsCounters {
    acquires: AtomicU64,
    acquire_failures: AtomicU64,
    connections_created: AtomicU64,
    connections_closed: AtomicU64,
    connections_recycled: AtomicU64,
    validation_failures: AtomicU64,
    acquire_timeouts: AtomicU64,
}

impl PoolMetricsCounters {
    /// Copies the current counter values into a plain [`PoolMetrics`].
    ///
    /// Each counter is loaded separately, so the snapshot is not atomic as a whole.
    fn snapshot(&self) -> PoolMetrics {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        PoolMetrics {
            acquires: read(&self.acquires),
            acquire_failures: read(&self.acquire_failures),
            connections_created: read(&self.connections_created),
            connections_closed: read(&self.connections_closed),
            connections_recycled: read(&self.connections_recycled),
            validation_failures: read(&self.validation_failures),
            acquire_timeouts: read(&self.acquire_timeouts),
        }
    }
}

/// Point-in-time copy of the pool's cumulative counters.
#[derive(Debug, Clone, Default)]
pub struct PoolMetrics {
    /// Calls to `get_connection`.
    pub acquires: u64,
    /// Recorded acquire failures (timeouts are counted separately).
    pub acquire_failures: u64,
    /// New connections created.
    pub connections_created: u64,
    /// Connections recorded as closed.
    pub connections_closed: u64,
    /// Idle connections discarded on acquire for exceeding `max_lifetime`.
    pub connections_recycled: u64,
    /// Connections rejected by `validate_connection`.
    pub validation_failures: u64,
    /// Acquires that timed out waiting for a free slot.
    pub acquire_timeouts: u64,
}
192
193impl ConnectionPool {
194 pub fn new(config: PoolConfig) -> Self {
196 Self {
197 config,
198 pools: Arc::new(RwLock::new(HashMap::new())),
199 total_connections: AtomicU64::new(0),
200 active_connections: AtomicU64::new(0),
201 metrics: PoolMetricsCounters::default(),
202 backend_template: None,
203 }
204 }
205
206 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
212 self.backend_template = Some(template);
213 self
214 }
215
216 pub async fn add_node(&self, node_id: NodeId) {
218 let mut pools = self.pools.write().await;
219 if !pools.contains_key(&node_id) {
220 pools.insert(
221 node_id,
222 NodePool::new(self.config.max_connections),
223 );
224 tracing::debug!("Added node {:?} to connection pool", node_id);
225 }
226 }
227
228 pub async fn add_node_with_endpoint(
232 &self,
233 node_id: NodeId,
234 host: impl Into<String>,
235 port: u16,
236 ) {
237 let mut pools = self.pools.write().await;
238 if !pools.contains_key(&node_id) {
239 let mut np = NodePool::new(self.config.max_connections);
240 np.endpoint = Some((host.into(), port));
241 pools.insert(node_id, np);
242 tracing::debug!("Added node {:?} to connection pool (with endpoint)", node_id);
243 }
244 }
245
246 pub async fn remove_node(&self, node_id: &NodeId) {
248 let mut pools = self.pools.write().await;
249 if let Some(pool) = pools.remove(node_id) {
250 let count = pool.connections.len() as u64;
251 self.total_connections.fetch_sub(count, Ordering::SeqCst);
252 tracing::debug!("Removed node {:?} from connection pool", node_id);
253 }
254 }
255
256 pub async fn get_connection(&self, node_id: &NodeId) -> Result<PooledConnection> {
258 self.metrics.acquires.fetch_add(1, Ordering::Relaxed);
259
260 let (mut maybe_idle, semaphore) = {
264 let mut pools = self.pools.write().await;
265 let pool = pools.get_mut(node_id).ok_or_else(|| {
266 ProxyError::Connection(format!("Node {:?} not found in pool", node_id))
267 })?;
268
269 let semaphore = pool.semaphore.clone();
270 let idle = pool
271 .connections
272 .iter()
273 .position(|c| c.state == ConnectionState::Idle)
274 .map(|idx| pool.connections.swap_remove(idx));
275 (idle, semaphore)
276 };
277
278 if let Some(mut conn) = maybe_idle.take() {
282 let age = chrono::Utc::now()
283 .signed_duration_since(conn.created_at)
284 .to_std()
285 .unwrap_or(Duration::ZERO);
286
287 if age <= self.config.max_lifetime {
288 conn.state = ConnectionState::InUse;
289 conn.last_used = chrono::Utc::now();
290 conn.use_count += 1;
291 self.active_connections.fetch_add(1, Ordering::SeqCst);
292 return Ok(conn);
293 }
294
295 self.metrics
298 .connections_recycled
299 .fetch_add(1, Ordering::Relaxed);
300 self.total_connections.fetch_sub(1, Ordering::SeqCst);
301 drop(conn);
302 }
303
304 let permit = match tokio::time::timeout(
307 self.config.acquire_timeout,
308 semaphore.acquire_owned(),
309 )
310 .await
311 {
312 Ok(Ok(p)) => p,
313 Ok(Err(_)) => {
314 self.metrics
315 .acquire_failures
316 .fetch_add(1, Ordering::Relaxed);
317 return Err(ProxyError::PoolExhausted(format!(
318 "Failed to acquire semaphore for node {:?}",
319 node_id
320 )));
321 }
322 Err(_) => {
323 self.metrics
324 .acquire_timeouts
325 .fetch_add(1, Ordering::Relaxed);
326 return Err(ProxyError::Timeout(format!(
327 "Timeout acquiring connection for node {:?}",
328 node_id
329 )));
330 }
331 };
332
333 let conn = self.create_connection(*node_id, Some(permit)).await?;
334 self.active_connections.fetch_add(1, Ordering::SeqCst);
335 self.total_connections.fetch_add(1, Ordering::SeqCst);
336
337 {
338 let mut pools = self.pools.write().await;
339 if let Some(pool) = pools.get_mut(node_id) {
340 pool.total_created += 1;
341 }
342 }
343
344 Ok(conn)
345 }
346
347 pub async fn return_connection(&self, mut conn: PooledConnection) {
349 self.active_connections.fetch_sub(1, Ordering::SeqCst);
350
351 let mut pools = self.pools.write().await;
352 if let Some(pool) = pools.get_mut(&conn.node_id) {
353 conn.state = ConnectionState::Idle;
354 conn.last_used = chrono::Utc::now();
355 pool.connections.push(conn);
356 }
357 }
358
359 pub async fn close_connection(&self, conn: PooledConnection) {
361 self.active_connections.fetch_sub(1, Ordering::SeqCst);
362 self.total_connections.fetch_sub(1, Ordering::SeqCst);
363 self.metrics
364 .connections_closed
365 .fetch_add(1, Ordering::Relaxed);
366
367 let mut pools = self.pools.write().await;
368 if let Some(pool) = pools.get_mut(&conn.node_id) {
369 pool.total_closed += 1;
370 }
371
372 tracing::debug!("Closed connection {:?}", conn.id);
373 }
374
375 async fn create_connection(
384 &self,
385 node_id: NodeId,
386 permit: Option<OwnedSemaphorePermit>,
387 ) -> Result<PooledConnection> {
388 let endpoint = self
389 .pools
390 .read()
391 .await
392 .get(&node_id)
393 .and_then(|p| p.endpoint.clone());
394
395 let client = match (&self.backend_template, endpoint) {
396 (Some(template), Some((host, port))) => {
397 let mut cfg = template.clone();
398 cfg.host = host;
399 cfg.port = port;
400 match BackendClient::connect(&cfg).await {
401 Ok(c) => Some(c),
402 Err(e) => {
403 return Err(ProxyError::Connection(format!(
404 "backend connect for node {:?} failed: {}",
405 node_id, e
406 )));
407 }
408 }
409 }
410 _ => None,
411 };
412
413 let now = chrono::Utc::now();
414 let conn = PooledConnection {
415 id: Uuid::new_v4(),
416 node_id,
417 created_at: now,
418 last_used: now,
419 state: ConnectionState::InUse,
420 use_count: 1,
421 permit,
422 client,
423 };
424
425 self.metrics
426 .connections_created
427 .fetch_add(1, Ordering::Relaxed);
428
429 tracing::debug!(
430 "Created connection {:?} for node {:?} (live={})",
431 conn.id,
432 node_id,
433 conn.client.is_some()
434 );
435
436 Ok(conn)
437 }
438
439 pub async fn validate_connection(&self, conn: &PooledConnection) -> Result<bool> {
445 if conn.state == ConnectionState::Closed {
446 self.metrics
447 .validation_failures
448 .fetch_add(1, Ordering::Relaxed);
449 return Ok(false);
450 }
451 if let Some(client) = &conn.client {
455 let _ = client; }
463 Ok(true)
464 }
465
466 pub async fn run_reset_query(
470 &self,
471 conn: &mut PooledConnection,
472 query: &str,
473 ) -> Result<()> {
474 if let Some(client) = conn.client.as_mut() {
475 client
476 .execute(query)
477 .await
478 .map_err(|e| ProxyError::Connection(format!("reset query failed: {}", e)))?;
479 }
480 Ok(())
481 }
482
483 pub async fn close_all(&self) -> Result<()> {
485 let mut pools = self.pools.write().await;
486 for (_, pool) in pools.iter_mut() {
487 pool.connections.clear();
488 }
489 self.total_connections.store(0, Ordering::SeqCst);
490 self.active_connections.store(0, Ordering::SeqCst);
491 tracing::info!("Closed all connections");
492 Ok(())
493 }
494
495 pub async fn evict_idle(&self) {
497 let mut pools = self.pools.write().await;
498 let mut evicted = 0;
499
500 for (_, pool) in pools.iter_mut() {
501 let before = pool.connections.len();
502 pool.connections.retain(|conn| {
503 let idle_time = chrono::Utc::now()
504 .signed_duration_since(conn.last_used)
505 .to_std()
506 .unwrap_or(Duration::ZERO);
507
508 idle_time < self.config.idle_timeout
509 });
510 evicted += before - pool.connections.len();
511 }
512
513 if evicted > 0 {
514 self.total_connections
515 .fetch_sub(evicted as u64, Ordering::SeqCst);
516 tracing::debug!("Evicted {} idle connections", evicted);
517 }
518 }
519
520 pub async fn total_connections(&self) -> usize {
522 self.total_connections.load(Ordering::SeqCst) as usize
523 }
524
525 pub async fn active_connections(&self) -> usize {
527 self.active_connections.load(Ordering::SeqCst) as usize
528 }
529
530 pub async fn metrics(&self) -> PoolMetrics {
532 self.metrics.snapshot()
533 }
534
535 pub async fn node_stats(&self, node_id: &NodeId) -> Option<NodePoolStats> {
537 let pools = self.pools.read().await;
538 pools.get(node_id).map(|pool| NodePoolStats {
539 idle_connections: pool
540 .connections
541 .iter()
542 .filter(|c| c.state == ConnectionState::Idle)
543 .count(),
544 total_created: pool.total_created,
545 total_closed: pool.total_closed,
546 })
547 }
548}
549
/// Per-node statistics returned by `ConnectionPool::node_stats`.
#[derive(Debug, Clone)]
pub struct NodePoolStats {
    /// Connections currently parked in the node's pool in the `Idle` state.
    pub idle_connections: usize,
    /// Connections created for this node since it was added.
    pub total_created: u64,
    /// Connections recorded as closed for this node.
    pub total_closed: u64,
}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for the capacity tests: no minimum, `max` connections per node
    /// and a 50 ms acquire timeout so exhaustion is detected quickly.
    fn bounded_config(max: usize) -> PoolConfig {
        PoolConfig {
            min_connections: 0,
            max_connections: max,
            acquire_timeout: Duration::from_millis(50),
            ..Default::default()
        }
    }

    /// Creates a pool with `config` and registers one fresh node without an endpoint.
    async fn pool_with_node(config: PoolConfig) -> (ConnectionPool, NodeId) {
        let pool = ConnectionPool::new(config);
        let node = NodeId::new();
        pool.add_node(node).await;
        (pool, node)
    }

    #[test]
    fn test_pool_config_default() {
        let PoolConfig {
            min_connections,
            max_connections,
            test_on_acquire,
            ..
        } = PoolConfig::default();
        assert_eq!(min_connections, 2);
        assert_eq!(max_connections, 10);
        assert!(test_on_acquire);
    }

    #[tokio::test]
    async fn test_add_remove_node() {
        let (pool, node) = pool_with_node(PoolConfig::default()).await;
        assert!(pool.node_stats(&node).await.is_some());

        pool.remove_node(&node).await;
        assert!(pool.node_stats(&node).await.is_none());
    }

    #[tokio::test]
    async fn test_get_return_connection() {
        let (pool, node) = pool_with_node(PoolConfig::default()).await;

        let checked_out = pool.get_connection(&node).await.expect("get failed");
        assert_eq!(checked_out.node_id, node);
        assert_eq!(checked_out.state, ConnectionState::InUse);
        assert_eq!(pool.active_connections().await, 1);

        pool.return_connection(checked_out).await;
        assert_eq!(pool.active_connections().await, 0);
    }

    #[tokio::test]
    async fn test_metrics() {
        let (pool, node) = pool_with_node(PoolConfig::default()).await;

        let checked_out = pool.get_connection(&node).await.expect("get failed");
        pool.return_connection(checked_out).await;

        let snapshot = pool.metrics().await;
        assert_eq!(snapshot.acquires, 1);
        assert_eq!(snapshot.connections_created, 1);
    }

    /// Two held connections exhaust a `max_connections = 2` node; dropping one frees a slot.
    #[tokio::test]
    async fn test_max_connections_enforced_while_in_use() {
        let (pool, node) = pool_with_node(bounded_config(2)).await;

        let first = pool.get_connection(&node).await.expect("first acquire");
        let second = pool.get_connection(&node).await.expect("second acquire");

        let err = pool
            .get_connection(&node)
            .await
            .expect_err("third acquire should time out while c1/c2 held");
        assert!(
            matches!(err, ProxyError::Timeout(_)),
            "expected Timeout, got {err:?}"
        );

        // Dropping (not returning) a connection releases its permit.
        drop(first);
        let _third = pool
            .get_connection(&node)
            .await
            .expect("acquire should succeed after c1 dropped");

        drop(second);
    }

    /// With a template and an unreachable endpoint, acquiring surfaces the backend connect error.
    #[tokio::test]
    async fn test_backend_template_with_unreachable_endpoint_errors() {
        use crate::backend::{tls::default_client_config, TlsMode};

        let template = BackendConfig {
            host: "placeholder".into(),
            port: 0,
            user: "postgres".into(),
            password: None,
            database: None,
            application_name: Some("helios-pool".into()),
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_millis(200),
            query_timeout: Duration::from_millis(200),
            tls_config: default_client_config(),
        };
        let pool = ConnectionPool::new(PoolConfig {
            max_connections: 2,
            acquire_timeout: Duration::from_millis(300),
            ..Default::default()
        })
        .with_backend_template(template);

        let node = NodeId::new();
        pool.add_node_with_endpoint(node, "127.0.0.1", 1).await;

        let err = pool
            .get_connection(&node)
            .await
            .expect_err("acquire must fail when backend is unreachable");
        match &err {
            ProxyError::Connection(msg) => assert!(
                msg.contains("backend connect"),
                "expected backend-connect error, got {}",
                msg
            ),
            other => panic!("expected Connection error, got {:?}", other),
        }
    }

    /// An endpoint without a backend template yields connections with no live client.
    #[tokio::test]
    async fn test_add_node_with_endpoint_but_no_template_returns_skeleton_client() {
        let pool = ConnectionPool::new(PoolConfig::default());
        let node = NodeId::new();
        pool.add_node_with_endpoint(node, "127.0.0.1", 5432).await;

        let conn = pool.get_connection(&node).await.expect("acquire");
        assert!(conn.client.is_none(), "no template → no live client");
    }

    /// A returned connection is reused, permit included, without creating a new one.
    #[tokio::test]
    async fn test_return_then_reacquire_reuses_permit() {
        let (pool, node) = pool_with_node(bounded_config(1)).await;

        let returned = pool.get_connection(&node).await.expect("first acquire");
        pool.return_connection(returned).await;

        // With a single slot held by the idle connection, this only succeeds by reusing it.
        let reused = pool.get_connection(&node).await.expect("reacquire");
        assert!(reused.permit.is_some(), "reused connection must carry its permit");

        assert_eq!(
            pool.metrics().await.connections_created,
            1,
            "reuse must not create a second connection"
        );
    }
}
747}