1use super::config::PoolModeConfig;
6use super::lease::{ClientId, ConnectionLease, LeaseAction};
7use super::metrics::PoolModeMetrics;
8use super::mode::PoolingMode;
9use crate::connection_pool::{ConnectionPool, PoolConfig};
10use crate::{NodeEndpoint, NodeId, ProxyError, Result};
11use dashmap::DashMap;
12use std::sync::Arc;
13use std::time::Instant;
14
15pub struct ConnectionPoolManager {
20 config: PoolModeConfig,
22 pools: DashMap<NodeId, ConnectionPool>,
24 active_leases: DashMap<ClientId, LeaseInfo>,
26 metrics: Arc<PoolModeMetrics>,
28}
29
30struct LeaseInfo {
32 node_id: NodeId,
34 statements: u64,
36}
37
38#[derive(Debug, Clone)]
40pub struct PoolStats {
41 pub total_connections: usize,
43 pub active_connections: usize,
45 pub idle_connections: usize,
47 pub node_count: usize,
49 pub node_stats: Vec<NodePoolStats>,
51}
52
53#[derive(Debug, Clone)]
55pub struct NodePoolStats {
56 pub node_id: NodeId,
58 pub total: usize,
60 pub active: usize,
62 pub idle: usize,
64}
65
66impl ConnectionPoolManager {
67 pub fn new(config: PoolModeConfig) -> Self {
69 Self {
70 config,
71 pools: DashMap::new(),
72 active_leases: DashMap::new(),
73 metrics: Arc::new(PoolModeMetrics::new()),
74 }
75 }
76
77 pub async fn add_node(&self, node: &NodeEndpoint) {
79 let pool_config = PoolConfig {
80 min_connections: self.config.min_idle as usize,
81 max_connections: self.config.max_pool_size as usize,
82 idle_timeout: self.config.idle_timeout(),
83 max_lifetime: self.config.max_lifetime(),
84 acquire_timeout: self.config.acquire_timeout(),
85 test_on_acquire: self.config.test_on_acquire,
86 };
87
88 let pool = ConnectionPool::new(pool_config);
89 pool.add_node(node.id).await;
90 self.pools.insert(node.id, pool);
91
92 tracing::debug!("Added node {:?} to pool manager", node.id);
93 }
94
95 pub async fn remove_node(&self, node_id: &NodeId) {
97 if let Some((_, pool)) = self.pools.remove(node_id) {
98 let _ = pool.close_all().await;
99 }
100 tracing::debug!("Removed node {:?} from pool manager", node_id);
101 }
102
103 pub async fn acquire(&self, client_id: ClientId, node_id: &NodeId) -> Result<ConnectionLease> {
112 self.acquire_with_mode(client_id, node_id, self.config.default_mode)
113 .await
114 }
115
116 pub async fn acquire_with_mode(
118 &self,
119 client_id: ClientId,
120 node_id: &NodeId,
121 mode: PoolingMode,
122 ) -> Result<ConnectionLease> {
123 if let Some(existing) = self.active_leases.get(&client_id) {
125 if existing.node_id == *node_id {
126 tracing::warn!(
129 "Client {:?} already has active lease for node {:?}",
130 client_id,
131 node_id
132 );
133 }
134 }
135
136 let pool = self
138 .pools
139 .get(node_id)
140 .ok_or_else(|| ProxyError::PoolExhausted(format!("Node {:?} not in pool", node_id)))?;
141
142 let acquire_start = Instant::now();
144 let connection = match tokio::time::timeout(
145 self.config.acquire_timeout(),
146 pool.get_connection(node_id),
147 )
148 .await
149 {
150 Ok(Ok(conn)) => conn,
151 Ok(Err(e)) => {
152 self.metrics.record_acquire_failure();
153 return Err(e);
154 }
155 Err(_) => {
156 self.metrics.record_acquire_timeout();
157 return Err(ProxyError::Timeout(format!(
158 "Timeout acquiring connection for node {:?}",
159 node_id
160 )));
161 }
162 };
163
164 let _acquire_duration = acquire_start.elapsed();
165
166 let lease = ConnectionLease::new(connection, mode, client_id);
168
169 self.active_leases.insert(
171 client_id,
172 LeaseInfo {
173 node_id: *node_id,
174 statements: 0,
175 },
176 );
177
178 self.metrics.record_acquire(mode);
180
181 tracing::trace!(
182 "Acquired {:?} lease for client {:?} on node {:?}",
183 mode,
184 client_id,
185 node_id
186 );
187
188 Ok(lease)
189 }
190
191 pub async fn release(&self, lease: ConnectionLease) {
195 let client_id = lease.client_id();
196 let mode = lease.mode();
197 let statements = lease.statements_executed();
198 let duration_ms = lease.lease_duration().as_millis() as u64;
199
200 if let Some((_, info)) = self.active_leases.remove(&client_id) {
202 if let Some(pool) = self.pools.get(&info.node_id) {
204 let mut connection = lease.into_connection();
205
206 if mode != PoolingMode::Session {
212 let reset_query = self.config.reset_query.as_str();
213 match pool.run_reset_query(&mut connection, reset_query).await {
214 Ok(()) => {
215 tracing::trace!(
216 query = reset_query,
217 "reset query executed on release"
218 );
219 self.metrics.record_reset(true);
220 }
221 Err(e) => {
222 tracing::warn!(
223 error = %e,
224 "reset query failed; connection will not be returned to pool"
225 );
226 self.metrics.record_reset(false);
227 pool.close_connection(connection).await;
228 return;
229 }
230 }
231 }
232
233 pool.return_connection(connection).await;
235 }
236 }
237
238 self.metrics.record_release(mode, duration_ms, statements);
240
241 tracing::trace!(
242 "Released {:?} lease for client {:?} after {} statements",
243 mode,
244 client_id,
245 statements
246 );
247 }
248
249 pub async fn release_and_close(&self, lease: ConnectionLease) {
251 let client_id = lease.client_id();
252 let mode = lease.mode();
253 let statements = lease.statements_executed();
254 let duration_ms = lease.lease_duration().as_millis() as u64;
255
256 if let Some((_, info)) = self.active_leases.remove(&client_id) {
258 if let Some(pool) = self.pools.get(&info.node_id) {
260 let connection = lease.into_connection();
261 pool.close_connection(connection).await;
262 self.metrics.record_connection_closed();
263 }
264 }
265
266 self.metrics.record_release(mode, duration_ms, statements);
268 }
269
270 pub fn on_statement_complete(&self, lease: &mut ConnectionLease, sql: &str) -> LeaseAction {
279 let action = lease.on_statement_complete(sql);
280
281 if let Some(mut info) = self.active_leases.get_mut(&lease.client_id()) {
283 info.statements += 1;
284 }
285
286 if action == LeaseAction::Reset {
288 self.metrics.record_transaction_complete();
289 }
290
291 action
292 }
293
294 pub async fn get_stats(&self) -> PoolStats {
296 let mut total = 0;
297 let mut active = 0;
298 let mut node_stats = Vec::new();
299
300 for entry in self.pools.iter() {
301 let node_id = *entry.key();
302 let pool = entry.value();
303
304 let pool_total = pool.total_connections().await;
305 let pool_active = pool.active_connections().await;
306 let pool_idle = pool_total.saturating_sub(pool_active);
307
308 total += pool_total;
309 active += pool_active;
310
311 node_stats.push(NodePoolStats {
312 node_id,
313 total: pool_total,
314 active: pool_active,
315 idle: pool_idle,
316 });
317 }
318
319 PoolStats {
320 total_connections: total,
321 active_connections: active,
322 idle_connections: total.saturating_sub(active),
323 node_count: self.pools.len(),
324 node_stats,
325 }
326 }
327
328 pub fn metrics(&self) -> &PoolModeMetrics {
330 &self.metrics
331 }
332
333 pub fn config(&self) -> &PoolModeConfig {
335 &self.config
336 }
337
338 pub fn default_mode(&self) -> PoolingMode {
340 self.config.default_mode
341 }
342
343 pub fn has_active_lease(&self, client_id: &ClientId) -> bool {
345 self.active_leases.contains_key(client_id)
346 }
347
348 pub fn active_lease_count(&self) -> usize {
350 self.active_leases.len()
351 }
352
353 pub async fn close_all(&self) {
355 for entry in self.pools.iter() {
356 let _ = entry.value().close_all().await;
357 }
358 self.active_leases.clear();
359 tracing::info!("Closed all connections in pool manager");
360 }
361
362 pub async fn evict_idle(&self) {
364 for entry in self.pools.iter() {
365 entry.value().evict_idle().await;
366 }
367 }
368}
369
370impl std::fmt::Debug for ConnectionPoolManager {
371 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
372 f.debug_struct("ConnectionPoolManager")
373 .field("default_mode", &self.config.default_mode)
374 .field("max_pool_size", &self.config.max_pool_size)
375 .field("active_leases", &self.active_leases.len())
376 .field("nodes", &self.pools.len())
377 .finish()
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager from `config` with a single `localhost:5432` node
    /// already registered.
    async fn manager_with_node(config: PoolModeConfig) -> (ConnectionPoolManager, NodeEndpoint) {
        let manager = ConnectionPoolManager::new(config);
        let node = NodeEndpoint::new("localhost", 5432);
        manager.add_node(&node).await;
        (manager, node)
    }

    /// A new manager uses session mode by default and tracks no leases.
    #[tokio::test]
    async fn test_manager_creation() {
        let manager = ConnectionPoolManager::new(PoolModeConfig::default());

        assert_eq!(manager.default_mode(), PoolingMode::Session);
        assert_eq!(manager.active_lease_count(), 0);
    }

    /// Adding and then removing a node is reflected in the node count.
    #[tokio::test]
    async fn test_add_remove_node() {
        let (manager, node) = manager_with_node(PoolModeConfig::default()).await;

        let after_add = manager.get_stats().await;
        assert_eq!(after_add.node_count, 1);

        manager.remove_node(&node.id).await;

        let after_remove = manager.get_stats().await;
        assert_eq!(after_remove.node_count, 0);
    }

    /// Lease tracking follows the acquire / release cycle.
    #[tokio::test]
    async fn test_acquire_release() {
        let (manager, node) = manager_with_node(PoolModeConfig::transaction_mode()).await;
        let client_id = ClientId::new();

        let lease = manager.acquire(client_id, &node.id).await.unwrap();
        assert!(manager.has_active_lease(&client_id));
        assert_eq!(manager.active_lease_count(), 1);

        manager.release(lease).await;
        assert!(!manager.has_active_lease(&client_id));
        assert_eq!(manager.active_lease_count(), 0);
    }

    /// Acquire and release each bump their metric exactly once.
    #[tokio::test]
    async fn test_metrics_recording() {
        let (manager, node) = manager_with_node(PoolModeConfig::transaction_mode()).await;
        let client_id = ClientId::new();

        let lease = manager.acquire(client_id, &node.id).await.unwrap();
        let after_acquire = manager.metrics().snapshot();
        assert_eq!(after_acquire.acquires, 1);

        manager.release(lease).await;
        let after_release = manager.metrics().snapshot();
        assert_eq!(after_release.releases, 1);
    }
}