1use crate::backend::{BackendClient, BackendConfig};
7use crate::{NodeId, ProxyError, Result};
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::Duration;
12use tokio::sync::{OwnedSemaphorePermit, RwLock, Semaphore};
13use uuid::Uuid;
14
/// Limits and timeouts that govern how [`ConnectionPool`] hands out connections.
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Intended lower bound on pooled connections.
    ///
    /// NOTE(review): no code visible in this file reads this value — confirm
    /// where (or whether) it is enforced.
    pub min_connections: usize,
    /// Number of permits given to each node's semaphore, i.e. the per-node cap
    /// on simultaneously reserved connection slots.
    pub max_connections: usize,
    /// Connections whose `last_used` timestamp is at least this old are dropped
    /// by `ConnectionPool::evict_idle`.
    pub idle_timeout: Duration,
    /// An idle connection older than this (measured from `created_at`) is
    /// discarded instead of being reused by `ConnectionPool::get_connection`.
    pub max_lifetime: Duration,
    /// Upper bound on how long `ConnectionPool::get_connection` waits for a
    /// free slot before failing with a timeout error.
    pub acquire_timeout: Duration,
    /// Whether connections should be validated when they are acquired.
    ///
    /// NOTE(review): presumably consulted by callers; the acquire path visible
    /// in this file does not read it — TODO confirm.
    pub test_on_acquire: bool,
}
31
32impl Default for PoolConfig {
33 fn default() -> Self {
34 Self {
35 min_connections: 2,
36 max_connections: 10,
37 idle_timeout: Duration::from_secs(300),
38 max_lifetime: Duration::from_secs(1800),
39 acquire_timeout: Duration::from_secs(30),
40 test_on_acquire: true,
41 }
42 }
43}
44
45pub struct PooledConnection {
47 pub id: Uuid,
49 pub node_id: NodeId,
51 pub created_at: chrono::DateTime<chrono::Utc>,
53 pub last_used: chrono::DateTime<chrono::Utc>,
55 pub state: ConnectionState,
57 pub use_count: u64,
59 pub(crate) permit: Option<OwnedSemaphorePermit>,
64 pub(crate) client: Option<BackendClient>,
69}
70
71impl std::fmt::Debug for PooledConnection {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 f.debug_struct("PooledConnection")
75 .field("id", &self.id)
76 .field("node_id", &self.node_id)
77 .field("created_at", &self.created_at)
78 .field("last_used", &self.last_used)
79 .field("state", &self.state)
80 .field("use_count", &self.use_count)
81 .field("has_permit", &self.permit.is_some())
82 .field("has_live_client", &self.client.is_some())
83 .finish()
84 }
85}
86
/// Lifecycle state of a [`PooledConnection`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionState {
    /// Parked in the pool and eligible for reuse.
    Idle,
    /// Checked out by a caller.
    InUse,
    /// NOTE(review): never assigned by the code visible in this file;
    /// presumably marks a connection undergoing a health check — TODO confirm.
    Validating,
    /// No longer usable; `ConnectionPool::validate_connection` reports such a
    /// connection as invalid.
    Closed,
}
99
100struct NodePool {
102 connections: Vec<PooledConnection>,
104 semaphore: Arc<Semaphore>,
106 total_created: u64,
108 total_closed: u64,
110 endpoint: Option<(String, u16)>,
114}
115
116impl NodePool {
117 fn new(max_connections: usize) -> Self {
118 Self {
119 connections: Vec::new(),
120 semaphore: Arc::new(Semaphore::new(max_connections)),
121 total_created: 0,
122 total_closed: 0,
123 endpoint: None,
124 }
125 }
126}
127
128pub struct ConnectionPool {
130 config: PoolConfig,
132 pools: Arc<RwLock<HashMap<NodeId, NodePool>>>,
134 total_connections: AtomicU64,
136 active_connections: AtomicU64,
138 metrics: PoolMetricsCounters,
140 backend_template: Option<BackendConfig>,
145}
146
/// Atomic counters updated by the pool; read out via `snapshot` into a plain
/// [`PoolMetrics`] value.
#[derive(Default, Debug)]
struct PoolMetricsCounters {
    /// Calls to `get_connection`.
    acquires: AtomicU64,
    /// Acquire attempts that failed for a reason other than a timeout.
    acquire_failures: AtomicU64,
    /// Connections created.
    connections_created: AtomicU64,
    /// Connections closed.
    connections_closed: AtomicU64,
    /// Idle connections discarded for exceeding the maximum lifetime.
    connections_recycled: AtomicU64,
    /// Validation checks that reported a connection as unusable.
    validation_failures: AtomicU64,
    /// Acquire attempts that timed out waiting for a free slot.
    acquire_timeouts: AtomicU64,
}
159
160impl PoolMetricsCounters {
161 fn snapshot(&self) -> PoolMetrics {
162 PoolMetrics {
163 acquires: self.acquires.load(Ordering::Relaxed),
164 acquire_failures: self.acquire_failures.load(Ordering::Relaxed),
165 connections_created: self.connections_created.load(Ordering::Relaxed),
166 connections_closed: self.connections_closed.load(Ordering::Relaxed),
167 connections_recycled: self.connections_recycled.load(Ordering::Relaxed),
168 validation_failures: self.validation_failures.load(Ordering::Relaxed),
169 acquire_timeouts: self.acquire_timeouts.load(Ordering::Relaxed),
170 }
171 }
172}
173
/// Point-in-time copy of the pool's counters, as returned by
/// `ConnectionPool::metrics`.
#[derive(Clone, Debug, Default)]
pub struct PoolMetrics {
    /// Calls to `get_connection`.
    pub acquires: u64,
    /// Acquire attempts that failed for a reason other than a timeout.
    pub acquire_failures: u64,
    /// Connections created.
    pub connections_created: u64,
    /// Connections closed.
    pub connections_closed: u64,
    /// Idle connections discarded for exceeding the maximum lifetime.
    pub connections_recycled: u64,
    /// Validation checks that reported a connection as unusable.
    pub validation_failures: u64,
    /// Acquire attempts that timed out waiting for a free slot.
    pub acquire_timeouts: u64,
}
192
193impl ConnectionPool {
194 pub fn new(config: PoolConfig) -> Self {
196 Self {
197 config,
198 pools: Arc::new(RwLock::new(HashMap::new())),
199 total_connections: AtomicU64::new(0),
200 active_connections: AtomicU64::new(0),
201 metrics: PoolMetricsCounters::default(),
202 backend_template: None,
203 }
204 }
205
206 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
212 self.backend_template = Some(template);
213 self
214 }
215
216 pub async fn add_node(&self, node_id: NodeId) {
218 let mut pools = self.pools.write().await;
219 if let std::collections::hash_map::Entry::Vacant(e) = pools.entry(node_id) {
220 e.insert(NodePool::new(self.config.max_connections));
221 tracing::debug!("Added node {:?} to connection pool", node_id);
222 }
223 }
224
225 pub async fn add_node_with_endpoint(
229 &self,
230 node_id: NodeId,
231 host: impl Into<String>,
232 port: u16,
233 ) {
234 let mut pools = self.pools.write().await;
235 if let std::collections::hash_map::Entry::Vacant(e) = pools.entry(node_id) {
236 let mut np = NodePool::new(self.config.max_connections);
237 np.endpoint = Some((host.into(), port));
238 e.insert(np);
239 tracing::debug!(
240 "Added node {:?} to connection pool (with endpoint)",
241 node_id
242 );
243 }
244 }
245
246 pub async fn remove_node(&self, node_id: &NodeId) {
248 let mut pools = self.pools.write().await;
249 if let Some(pool) = pools.remove(node_id) {
250 let count = pool.connections.len() as u64;
251 self.total_connections.fetch_sub(count, Ordering::SeqCst);
252 tracing::debug!("Removed node {:?} from connection pool", node_id);
253 }
254 }
255
256 pub async fn get_connection(&self, node_id: &NodeId) -> Result<PooledConnection> {
258 self.metrics.acquires.fetch_add(1, Ordering::Relaxed);
259
260 let (mut maybe_idle, semaphore) = {
264 let mut pools = self.pools.write().await;
265 let pool = pools.get_mut(node_id).ok_or_else(|| {
266 ProxyError::Connection(format!("Node {:?} not found in pool", node_id))
267 })?;
268
269 let semaphore = pool.semaphore.clone();
270 let idle = pool
271 .connections
272 .iter()
273 .position(|c| c.state == ConnectionState::Idle)
274 .map(|idx| pool.connections.swap_remove(idx));
275 (idle, semaphore)
276 };
277
278 if let Some(mut conn) = maybe_idle.take() {
282 let age = chrono::Utc::now()
283 .signed_duration_since(conn.created_at)
284 .to_std()
285 .unwrap_or(Duration::ZERO);
286
287 if age <= self.config.max_lifetime {
288 conn.state = ConnectionState::InUse;
289 conn.last_used = chrono::Utc::now();
290 conn.use_count += 1;
291 self.active_connections.fetch_add(1, Ordering::SeqCst);
292 return Ok(conn);
293 }
294
295 self.metrics
298 .connections_recycled
299 .fetch_add(1, Ordering::Relaxed);
300 self.total_connections.fetch_sub(1, Ordering::SeqCst);
301 drop(conn);
302 }
303
304 let permit = match tokio::time::timeout(
307 self.config.acquire_timeout,
308 semaphore.acquire_owned(),
309 )
310 .await
311 {
312 Ok(Ok(p)) => p,
313 Ok(Err(_)) => {
314 self.metrics
315 .acquire_failures
316 .fetch_add(1, Ordering::Relaxed);
317 return Err(ProxyError::PoolExhausted(format!(
318 "Failed to acquire semaphore for node {:?}",
319 node_id
320 )));
321 }
322 Err(_) => {
323 self.metrics
324 .acquire_timeouts
325 .fetch_add(1, Ordering::Relaxed);
326 return Err(ProxyError::Timeout(format!(
327 "Timeout acquiring connection for node {:?}",
328 node_id
329 )));
330 }
331 };
332
333 let conn = self.create_connection(*node_id, Some(permit)).await?;
334 self.active_connections.fetch_add(1, Ordering::SeqCst);
335 self.total_connections.fetch_add(1, Ordering::SeqCst);
336
337 {
338 let mut pools = self.pools.write().await;
339 if let Some(pool) = pools.get_mut(node_id) {
340 pool.total_created += 1;
341 }
342 }
343
344 Ok(conn)
345 }
346
347 pub async fn return_connection(&self, mut conn: PooledConnection) {
349 self.active_connections.fetch_sub(1, Ordering::SeqCst);
350
351 let mut pools = self.pools.write().await;
352 if let Some(pool) = pools.get_mut(&conn.node_id) {
353 conn.state = ConnectionState::Idle;
354 conn.last_used = chrono::Utc::now();
355 pool.connections.push(conn);
356 }
357 }
358
359 pub async fn close_connection(&self, conn: PooledConnection) {
361 self.active_connections.fetch_sub(1, Ordering::SeqCst);
362 self.total_connections.fetch_sub(1, Ordering::SeqCst);
363 self.metrics
364 .connections_closed
365 .fetch_add(1, Ordering::Relaxed);
366
367 let mut pools = self.pools.write().await;
368 if let Some(pool) = pools.get_mut(&conn.node_id) {
369 pool.total_closed += 1;
370 }
371
372 tracing::debug!("Closed connection {:?}", conn.id);
373 }
374
375 async fn create_connection(
384 &self,
385 node_id: NodeId,
386 permit: Option<OwnedSemaphorePermit>,
387 ) -> Result<PooledConnection> {
388 let endpoint = self
389 .pools
390 .read()
391 .await
392 .get(&node_id)
393 .and_then(|p| p.endpoint.clone());
394
395 let client = match (&self.backend_template, endpoint) {
396 (Some(template), Some((host, port))) => {
397 let mut cfg = template.clone();
398 cfg.host = host;
399 cfg.port = port;
400 match BackendClient::connect(&cfg).await {
401 Ok(c) => Some(c),
402 Err(e) => {
403 return Err(ProxyError::Connection(format!(
404 "backend connect for node {:?} failed: {}",
405 node_id, e
406 )));
407 }
408 }
409 }
410 _ => None,
411 };
412
413 let now = chrono::Utc::now();
414 let conn = PooledConnection {
415 id: Uuid::new_v4(),
416 node_id,
417 created_at: now,
418 last_used: now,
419 state: ConnectionState::InUse,
420 use_count: 1,
421 permit,
422 client,
423 };
424
425 self.metrics
426 .connections_created
427 .fetch_add(1, Ordering::Relaxed);
428
429 tracing::debug!(
430 "Created connection {:?} for node {:?} (live={})",
431 conn.id,
432 node_id,
433 conn.client.is_some()
434 );
435
436 Ok(conn)
437 }
438
439 pub async fn validate_connection(&self, conn: &PooledConnection) -> Result<bool> {
445 if conn.state == ConnectionState::Closed {
446 self.metrics
447 .validation_failures
448 .fetch_add(1, Ordering::Relaxed);
449 return Ok(false);
450 }
451 if let Some(client) = &conn.client {
455 let _ = client; }
463 Ok(true)
464 }
465
466 pub async fn run_reset_query(&self, conn: &mut PooledConnection, query: &str) -> Result<()> {
470 if let Some(client) = conn.client.as_mut() {
471 client
472 .execute(query)
473 .await
474 .map_err(|e| ProxyError::Connection(format!("reset query failed: {}", e)))?;
475 }
476 Ok(())
477 }
478
479 pub async fn close_all(&self) -> Result<()> {
481 let mut pools = self.pools.write().await;
482 for (_, pool) in pools.iter_mut() {
483 pool.connections.clear();
484 }
485 self.total_connections.store(0, Ordering::SeqCst);
486 self.active_connections.store(0, Ordering::SeqCst);
487 tracing::info!("Closed all connections");
488 Ok(())
489 }
490
491 pub async fn evict_idle(&self) {
493 let mut pools = self.pools.write().await;
494 let mut evicted = 0;
495
496 for (_, pool) in pools.iter_mut() {
497 let before = pool.connections.len();
498 pool.connections.retain(|conn| {
499 let idle_time = chrono::Utc::now()
500 .signed_duration_since(conn.last_used)
501 .to_std()
502 .unwrap_or(Duration::ZERO);
503
504 idle_time < self.config.idle_timeout
505 });
506 evicted += before - pool.connections.len();
507 }
508
509 if evicted > 0 {
510 self.total_connections
511 .fetch_sub(evicted as u64, Ordering::SeqCst);
512 tracing::debug!("Evicted {} idle connections", evicted);
513 }
514 }
515
516 pub async fn total_connections(&self) -> usize {
518 self.total_connections.load(Ordering::SeqCst) as usize
519 }
520
521 pub async fn active_connections(&self) -> usize {
523 self.active_connections.load(Ordering::SeqCst) as usize
524 }
525
526 pub async fn metrics(&self) -> PoolMetrics {
528 self.metrics.snapshot()
529 }
530
531 pub async fn node_stats(&self, node_id: &NodeId) -> Option<NodePoolStats> {
533 let pools = self.pools.read().await;
534 pools.get(node_id).map(|pool| NodePoolStats {
535 idle_connections: pool
536 .connections
537 .iter()
538 .filter(|c| c.state == ConnectionState::Idle)
539 .count(),
540 total_created: pool.total_created,
541 total_closed: pool.total_closed,
542 })
543 }
544}
545
/// Per-node statistics reported by `ConnectionPool::node_stats`.
#[derive(Clone, Debug)]
pub struct NodePoolStats {
    /// Number of connections parked in the node's pool in the `Idle` state.
    pub idle_connections: usize,
    /// Connections created for the node so far.
    pub total_created: u64,
    /// Connections closed for the node so far.
    pub total_closed: u64,
}
556
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pool from `config` with one freshly registered node
    /// (no endpoint).
    async fn pool_with_node(config: PoolConfig) -> (ConnectionPool, NodeId) {
        let pool = ConnectionPool::new(config);
        let node_id = NodeId::new();
        pool.add_node(node_id).await;
        (pool, node_id)
    }

    /// The default configuration exposes the documented limits.
    #[test]
    fn test_pool_config_default() {
        let defaults = PoolConfig::default();
        assert_eq!(defaults.min_connections, 2);
        assert_eq!(defaults.max_connections, 10);
        assert!(defaults.test_on_acquire);
    }

    /// A node is visible through `node_stats` exactly while it is registered.
    #[tokio::test]
    async fn test_add_remove_node() {
        let (pool, node_id) = pool_with_node(PoolConfig::default()).await;
        assert!(pool.node_stats(&node_id).await.is_some());

        pool.remove_node(&node_id).await;
        assert!(pool.node_stats(&node_id).await.is_none());
    }

    /// Checking a connection out and back in moves the active counter 0→1→0.
    #[tokio::test]
    async fn test_get_return_connection() {
        let (pool, node_id) = pool_with_node(PoolConfig::default()).await;

        let checked_out = pool.get_connection(&node_id).await.expect("get failed");
        assert_eq!(checked_out.node_id, node_id);
        assert_eq!(checked_out.state, ConnectionState::InUse);
        assert_eq!(pool.active_connections().await, 1);

        pool.return_connection(checked_out).await;
        assert_eq!(pool.active_connections().await, 0);
    }

    /// One acquire/return cycle is reflected in the metrics snapshot.
    #[tokio::test]
    async fn test_metrics() {
        let (pool, node_id) = pool_with_node(PoolConfig::default()).await;

        let checked_out = pool.get_connection(&node_id).await.expect("get failed");
        pool.return_connection(checked_out).await;

        let snapshot = pool.metrics().await;
        assert_eq!(snapshot.acquires, 1);
        assert_eq!(snapshot.connections_created, 1);
    }

    /// With every slot held, a further acquire times out; dropping a held
    /// connection frees a slot again.
    #[tokio::test]
    async fn test_max_connections_enforced_while_in_use() {
        let config = PoolConfig {
            min_connections: 0,
            max_connections: 2,
            acquire_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let (pool, node_id) = pool_with_node(config).await;

        let first = pool.get_connection(&node_id).await.expect("first acquire");
        let second = pool.get_connection(&node_id).await.expect("second acquire");

        let err = pool
            .get_connection(&node_id)
            .await
            .expect_err("third acquire should time out while c1/c2 held");
        assert!(
            matches!(err, ProxyError::Timeout(_)),
            "expected Timeout, got {err:?}"
        );

        drop(first);
        let _third = pool
            .get_connection(&node_id)
            .await
            .expect("acquire should succeed after c1 dropped");

        drop(second);
    }

    /// A template plus an unreachable endpoint surfaces a backend-connect
    /// error from `get_connection`.
    #[tokio::test]
    async fn test_backend_template_with_unreachable_endpoint_errors() {
        use crate::backend::{tls::default_client_config, TlsMode};

        let template = BackendConfig {
            host: "placeholder".into(),
            port: 0,
            user: "postgres".into(),
            password: None,
            database: None,
            application_name: Some("helios-pool".into()),
            tls_mode: TlsMode::Disable,
            connect_timeout: Duration::from_millis(200),
            query_timeout: Duration::from_millis(200),
            tls_config: default_client_config(),
        };

        let config = PoolConfig {
            max_connections: 2,
            acquire_timeout: Duration::from_millis(300),
            ..Default::default()
        };
        let pool = ConnectionPool::new(config).with_backend_template(template);

        let node_id = NodeId::new();
        pool.add_node_with_endpoint(node_id, "127.0.0.1", 1).await;

        let failure = pool
            .get_connection(&node_id)
            .await
            .expect_err("acquire must fail when backend is unreachable");
        if let ProxyError::Connection(msg) = &failure {
            assert!(
                msg.contains("backend connect"),
                "expected backend-connect error, got {}",
                msg
            );
        } else {
            panic!("expected Connection error, got {:?}", failure);
        }
    }

    /// Without a backend template, an endpoint alone does not produce a live
    /// client.
    #[tokio::test]
    async fn test_add_node_with_endpoint_but_no_template_returns_skeleton_client() {
        let pool = ConnectionPool::new(PoolConfig::default());
        let node_id = NodeId::new();
        pool.add_node_with_endpoint(node_id, "127.0.0.1", 5432)
            .await;

        let checked_out = pool.get_connection(&node_id).await.expect("acquire");
        assert!(checked_out.client.is_none(), "no template → no live client");
    }

    /// Returning and re-acquiring with a single slot reuses the connection and
    /// hands it out with a permit.
    #[tokio::test]
    async fn test_return_then_reacquire_reuses_permit() {
        let config = PoolConfig {
            min_connections: 0,
            max_connections: 1,
            acquire_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let (pool, node_id) = pool_with_node(config).await;

        let first = pool.get_connection(&node_id).await.expect("first acquire");
        pool.return_connection(first).await;

        let reused = pool.get_connection(&node_id).await.expect("reacquire");
        assert!(
            reused.permit.is_some(),
            "reused connection must carry its permit"
        );

        let snapshot = pool.metrics().await;
        assert_eq!(
            snapshot.connections_created, 1,
            "reuse must not create a second connection"
        );
    }
}