1use super::config::PoolModeConfig;
6use super::lease::{ClientId, ConnectionLease, LeaseAction};
7use super::metrics::PoolModeMetrics;
8use super::mode::PoolingMode;
9use crate::connection_pool::{ConnectionPool, PoolConfig};
10use crate::{NodeEndpoint, NodeId, ProxyError, Result};
11use dashmap::DashMap;
12use std::sync::Arc;
13use std::time::Instant;
14
15pub struct ConnectionPoolManager {
20 config: PoolModeConfig,
22 pools: DashMap<NodeId, ConnectionPool>,
24 active_leases: DashMap<ClientId, LeaseInfo>,
26 metrics: Arc<PoolModeMetrics>,
28}
29
30struct LeaseInfo {
32 node_id: NodeId,
34 statements: u64,
36}
37
38#[derive(Debug, Clone)]
40pub struct PoolStats {
41 pub total_connections: usize,
43 pub active_connections: usize,
45 pub idle_connections: usize,
47 pub node_count: usize,
49 pub node_stats: Vec<NodePoolStats>,
51}
52
53#[derive(Debug, Clone)]
55pub struct NodePoolStats {
56 pub node_id: NodeId,
58 pub total: usize,
60 pub active: usize,
62 pub idle: usize,
64}
65
66impl ConnectionPoolManager {
67 pub fn new(config: PoolModeConfig) -> Self {
69 Self {
70 config,
71 pools: DashMap::new(),
72 active_leases: DashMap::new(),
73 metrics: Arc::new(PoolModeMetrics::new()),
74 }
75 }
76
77 pub async fn add_node(&self, node: &NodeEndpoint) {
79 let pool_config = PoolConfig {
80 min_connections: self.config.min_idle as usize,
81 max_connections: self.config.max_pool_size as usize,
82 idle_timeout: self.config.idle_timeout(),
83 max_lifetime: self.config.max_lifetime(),
84 acquire_timeout: self.config.acquire_timeout(),
85 test_on_acquire: self.config.test_on_acquire,
86 };
87
88 let pool = ConnectionPool::new(pool_config);
89 pool.add_node(node.id).await;
90 self.pools.insert(node.id, pool);
91
92 tracing::debug!("Added node {:?} to pool manager", node.id);
93 }
94
95 pub async fn remove_node(&self, node_id: &NodeId) {
97 if let Some((_, pool)) = self.pools.remove(node_id) {
98 let _ = pool.close_all().await;
99 }
100 tracing::debug!("Removed node {:?} from pool manager", node_id);
101 }
102
103 pub async fn acquire(&self, client_id: ClientId, node_id: &NodeId) -> Result<ConnectionLease> {
112 self.acquire_with_mode(client_id, node_id, self.config.default_mode)
113 .await
114 }
115
116 pub async fn acquire_with_mode(
118 &self,
119 client_id: ClientId,
120 node_id: &NodeId,
121 mode: PoolingMode,
122 ) -> Result<ConnectionLease> {
123 if let Some(existing) = self.active_leases.get(&client_id) {
125 if existing.node_id == *node_id {
126 tracing::warn!(
129 "Client {:?} already has active lease for node {:?}",
130 client_id,
131 node_id
132 );
133 }
134 }
135
136 let pool = self
138 .pools
139 .get(node_id)
140 .ok_or_else(|| ProxyError::PoolExhausted(format!("Node {:?} not in pool", node_id)))?;
141
142 let acquire_start = Instant::now();
144 let connection =
145 match tokio::time::timeout(self.config.acquire_timeout(), pool.get_connection(node_id))
146 .await
147 {
148 Ok(Ok(conn)) => conn,
149 Ok(Err(e)) => {
150 self.metrics.record_acquire_failure();
151 return Err(e);
152 }
153 Err(_) => {
154 self.metrics.record_acquire_timeout();
155 return Err(ProxyError::Timeout(format!(
156 "Timeout acquiring connection for node {:?}",
157 node_id
158 )));
159 }
160 };
161
162 let _acquire_duration = acquire_start.elapsed();
163
164 let lease = ConnectionLease::new(connection, mode, client_id);
166
167 self.active_leases.insert(
169 client_id,
170 LeaseInfo {
171 node_id: *node_id,
172 statements: 0,
173 },
174 );
175
176 self.metrics.record_acquire(mode);
178
179 tracing::trace!(
180 "Acquired {:?} lease for client {:?} on node {:?}",
181 mode,
182 client_id,
183 node_id
184 );
185
186 Ok(lease)
187 }
188
189 pub async fn release(&self, lease: ConnectionLease) {
193 let client_id = lease.client_id();
194 let mode = lease.mode();
195 let statements = lease.statements_executed();
196 let duration_ms = lease.lease_duration().as_millis() as u64;
197
198 if let Some((_, info)) = self.active_leases.remove(&client_id) {
200 if let Some(pool) = self.pools.get(&info.node_id) {
202 let mut connection = lease.into_connection();
203
204 if mode != PoolingMode::Session {
210 let reset_query = self.config.reset_query.as_str();
211 match pool.run_reset_query(&mut connection, reset_query).await {
212 Ok(()) => {
213 tracing::trace!(query = reset_query, "reset query executed on release");
214 self.metrics.record_reset(true);
215 }
216 Err(e) => {
217 tracing::warn!(
218 error = %e,
219 "reset query failed; connection will not be returned to pool"
220 );
221 self.metrics.record_reset(false);
222 pool.close_connection(connection).await;
223 return;
224 }
225 }
226 }
227
228 pool.return_connection(connection).await;
230 }
231 }
232
233 self.metrics.record_release(mode, duration_ms, statements);
235
236 tracing::trace!(
237 "Released {:?} lease for client {:?} after {} statements",
238 mode,
239 client_id,
240 statements
241 );
242 }
243
244 pub async fn release_and_close(&self, lease: ConnectionLease) {
246 let client_id = lease.client_id();
247 let mode = lease.mode();
248 let statements = lease.statements_executed();
249 let duration_ms = lease.lease_duration().as_millis() as u64;
250
251 if let Some((_, info)) = self.active_leases.remove(&client_id) {
253 if let Some(pool) = self.pools.get(&info.node_id) {
255 let connection = lease.into_connection();
256 pool.close_connection(connection).await;
257 self.metrics.record_connection_closed();
258 }
259 }
260
261 self.metrics.record_release(mode, duration_ms, statements);
263 }
264
265 pub fn on_statement_complete(&self, lease: &mut ConnectionLease, sql: &str) -> LeaseAction {
274 let action = lease.on_statement_complete(sql);
275
276 if let Some(mut info) = self.active_leases.get_mut(&lease.client_id()) {
278 info.statements += 1;
279 }
280
281 if action == LeaseAction::Reset {
283 self.metrics.record_transaction_complete();
284 }
285
286 action
287 }
288
289 pub async fn get_stats(&self) -> PoolStats {
291 let mut total = 0;
292 let mut active = 0;
293 let mut node_stats = Vec::new();
294
295 for entry in self.pools.iter() {
296 let node_id = *entry.key();
297 let pool = entry.value();
298
299 let pool_total = pool.total_connections().await;
300 let pool_active = pool.active_connections().await;
301 let pool_idle = pool_total.saturating_sub(pool_active);
302
303 total += pool_total;
304 active += pool_active;
305
306 node_stats.push(NodePoolStats {
307 node_id,
308 total: pool_total,
309 active: pool_active,
310 idle: pool_idle,
311 });
312 }
313
314 PoolStats {
315 total_connections: total,
316 active_connections: active,
317 idle_connections: total.saturating_sub(active),
318 node_count: self.pools.len(),
319 node_stats,
320 }
321 }
322
323 pub fn metrics(&self) -> &PoolModeMetrics {
325 &self.metrics
326 }
327
328 pub fn config(&self) -> &PoolModeConfig {
330 &self.config
331 }
332
333 pub fn default_mode(&self) -> PoolingMode {
335 self.config.default_mode
336 }
337
338 pub fn has_active_lease(&self, client_id: &ClientId) -> bool {
340 self.active_leases.contains_key(client_id)
341 }
342
343 pub fn active_lease_count(&self) -> usize {
345 self.active_leases.len()
346 }
347
348 pub async fn close_all(&self) {
350 for entry in self.pools.iter() {
351 let _ = entry.value().close_all().await;
352 }
353 self.active_leases.clear();
354 tracing::info!("Closed all connections in pool manager");
355 }
356
357 pub async fn evict_idle(&self) {
359 for entry in self.pools.iter() {
360 entry.value().evict_idle().await;
361 }
362 }
363}
364
365impl std::fmt::Debug for ConnectionPoolManager {
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 f.debug_struct("ConnectionPoolManager")
368 .field("default_mode", &self.config.default_mode)
369 .field("max_pool_size", &self.config.max_pool_size)
370 .field("active_leases", &self.active_leases.len())
371 .field("nodes", &self.pools.len())
372 .finish()
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager from `config` with a single `localhost:5432` node registered.
    async fn manager_with_local_node(config: PoolModeConfig) -> (ConnectionPoolManager, NodeEndpoint) {
        let manager = ConnectionPoolManager::new(config);
        let node = NodeEndpoint::new("localhost", 5432);
        manager.add_node(&node).await;
        (manager, node)
    }

    #[tokio::test]
    async fn test_manager_creation() {
        let manager = ConnectionPoolManager::new(PoolModeConfig::default());

        assert_eq!(manager.default_mode(), PoolingMode::Session);
        assert_eq!(manager.active_lease_count(), 0);
    }

    #[tokio::test]
    async fn test_add_remove_node() {
        let (manager, node) = manager_with_local_node(PoolModeConfig::default()).await;
        assert_eq!(manager.get_stats().await.node_count, 1);

        manager.remove_node(&node.id).await;
        assert_eq!(manager.get_stats().await.node_count, 0);
    }

    #[tokio::test]
    async fn test_acquire_release() {
        let (manager, node) = manager_with_local_node(PoolModeConfig::transaction_mode()).await;
        let client = ClientId::new();

        let lease = manager.acquire(client, &node.id).await.unwrap();
        assert!(manager.has_active_lease(&client));
        assert_eq!(manager.active_lease_count(), 1);

        manager.release(lease).await;
        assert!(!manager.has_active_lease(&client));
        assert_eq!(manager.active_lease_count(), 0);
    }

    #[tokio::test]
    async fn test_metrics_recording() {
        let (manager, node) = manager_with_local_node(PoolModeConfig::transaction_mode()).await;

        let lease = manager.acquire(ClientId::new(), &node.id).await.unwrap();
        assert_eq!(manager.metrics().snapshot().acquires, 1);

        manager.release(lease).await;
        assert_eq!(manager.metrics().snapshot().releases, 1);
    }
}