use super::config::PoolModeConfig;
use super::lease::{ClientId, ConnectionLease, LeaseAction};
use super::metrics::PoolModeMetrics;
use super::mode::PoolingMode;
use crate::connection_pool::{ConnectionPool, PoolConfig};
use crate::{NodeEndpoint, NodeId, ProxyError, Result};
use dashmap::DashMap;
use std::sync::Arc;
use std::time::Instant;
/// Manages one [`ConnectionPool`] per backend node and tracks which client
/// currently holds a leased connection.
///
/// Pools are stored behind an [`Arc`] so async methods can clone a handle out
/// of the map and drop the `DashMap` shard guard *before* awaiting. Holding a
/// `DashMap` guard across an `.await` blocks concurrent writers to the same
/// shard (e.g. `add_node` / `remove_node`) for the whole duration of the awaited
/// I/O, and DashMap documents that this may deadlock.
pub struct ConnectionPoolManager {
    config: PoolModeConfig,
    pools: DashMap<NodeId, Arc<ConnectionPool>>,
    active_leases: DashMap<ClientId, LeaseInfo>,
    metrics: Arc<PoolModeMetrics>,
}

/// Bookkeeping for a lease handed out by [`ConnectionPoolManager`].
struct LeaseInfo {
    /// Node whose pool the leased connection is handed back to on release.
    node_id: NodeId,
    /// Number of statements reported via `on_statement_complete` for this lease.
    statements: u64,
}

/// Aggregate connection counts across all node pools.
#[derive(Debug, Clone)]
pub struct PoolStats {
    pub total_connections: usize,
    pub active_connections: usize,
    /// `total_connections - active_connections`, saturating at zero.
    pub idle_connections: usize,
    pub node_count: usize,
    pub node_stats: Vec<NodePoolStats>,
}

/// Connection counts for a single node's pool.
#[derive(Debug, Clone)]
pub struct NodePoolStats {
    pub node_id: NodeId,
    pub total: usize,
    pub active: usize,
    /// `total - active`, saturating at zero.
    pub idle: usize,
}

impl ConnectionPoolManager {
    /// Creates an empty manager; nodes are registered with [`Self::add_node`].
    pub fn new(config: PoolModeConfig) -> Self {
        Self {
            config,
            pools: DashMap::new(),
            active_leases: DashMap::new(),
            metrics: Arc::new(PoolModeMetrics::new()),
        }
    }

    /// Clones the pool handle for `node_id` out of the map.
    ///
    /// The shard guard is released before this returns, so the returned handle
    /// can be held across `.await` points.
    fn pool_for(&self, node_id: &NodeId) -> Option<Arc<ConnectionPool>> {
        self.pools.get(node_id).map(|entry| Arc::clone(entry.value()))
    }

    /// Snapshots every `(node, pool)` pair so callers can await on each pool
    /// without holding any `DashMap` shard guard.
    fn pool_snapshot(&self) -> Vec<(NodeId, Arc<ConnectionPool>)> {
        self.pools
            .iter()
            .map(|entry| (*entry.key(), Arc::clone(entry.value())))
            .collect()
    }

    /// Creates a pool for `node` from the manager's configuration and registers it.
    ///
    /// If a pool is already registered under the same node id, it is replaced.
    /// The replaced pool is closed, as `remove_node` does, so its connections
    /// are not leaked.
    pub async fn add_node(&self, node: &NodeEndpoint) {
        let pool_config = PoolConfig {
            min_connections: self.config.min_idle as usize,
            max_connections: self.config.max_pool_size as usize,
            idle_timeout: self.config.idle_timeout(),
            max_lifetime: self.config.max_lifetime(),
            acquire_timeout: self.config.acquire_timeout(),
            test_on_acquire: self.config.test_on_acquire,
        };
        let pool = ConnectionPool::new(pool_config);
        pool.add_node(node.id).await;
        if let Some(previous) = self.pools.insert(node.id, Arc::new(pool)) {
            let _ = previous.close_all().await;
        }
        tracing::debug!("Added node {:?} to pool manager", node.id);
    }

    /// Unregisters `node_id` and closes every connection in its pool (best effort).
    pub async fn remove_node(&self, node_id: &NodeId) {
        if let Some((_, pool)) = self.pools.remove(node_id) {
            let _ = pool.close_all().await;
        }
        tracing::debug!("Removed node {:?} from pool manager", node_id);
    }

    /// Leases a connection to `node_id` using the configured default pooling mode.
    ///
    /// # Errors
    /// Same as [`Self::acquire_with_mode`].
    pub async fn acquire(&self, client_id: ClientId, node_id: &NodeId) -> Result<ConnectionLease> {
        self.acquire_with_mode(client_id, node_id, self.config.default_mode)
            .await
    }

    /// Leases a connection to `node_id` for `client_id` under `mode`.
    ///
    /// A client is tracked with at most one lease. If it already holds one, a
    /// warning is logged and the tracking entry is replaced by the new lease.
    ///
    /// # Errors
    /// - `ProxyError::PoolExhausted` if `node_id` is not registered.
    /// - `ProxyError::Timeout` if no connection is obtained within the configured
    ///   acquire timeout.
    /// - Any error returned by the pool's `get_connection`.
    pub async fn acquire_with_mode(
        &self,
        client_id: ClientId,
        node_id: &NodeId,
        mode: PoolingMode,
    ) -> Result<ConnectionLease> {
        // Copy the node id out so the `active_leases` shard guard is dropped at once.
        let previous_node = self
            .active_leases
            .get(&client_id)
            .map(|existing| existing.node_id);
        match previous_node {
            Some(previous) if previous == *node_id => tracing::warn!(
                "Client {:?} already has active lease for node {:?}",
                client_id,
                node_id
            ),
            // NOTE(review): the earlier lease's release will look up the tracking
            // entry written below (this node), not the node it came from.
            Some(previous) => tracing::warn!(
                "Client {:?} already has active lease for node {:?}; replacing its tracking entry with a lease on node {:?}",
                client_id,
                previous,
                node_id
            ),
            None => {}
        }

        let pool = self
            .pool_for(node_id)
            .ok_or_else(|| ProxyError::PoolExhausted(format!("Node {:?} not in pool", node_id)))?;

        let acquire_start = Instant::now();
        let connection = match tokio::time::timeout(
            self.config.acquire_timeout(),
            pool.get_connection(node_id),
        )
        .await
        {
            Ok(Ok(conn)) => conn,
            Ok(Err(e)) => {
                self.metrics.record_acquire_failure();
                return Err(e);
            }
            Err(_) => {
                self.metrics.record_acquire_timeout();
                return Err(ProxyError::Timeout(format!(
                    "Timeout acquiring connection for node {:?}",
                    node_id
                )));
            }
        };
        let acquire_elapsed = acquire_start.elapsed();

        let lease = ConnectionLease::new(connection, mode, client_id);
        self.active_leases.insert(
            client_id,
            LeaseInfo {
                node_id: *node_id,
                statements: 0,
            },
        );
        self.metrics.record_acquire(mode);
        tracing::trace!(
            elapsed = ?acquire_elapsed,
            "Acquired {:?} lease for client {:?} on node {:?}",
            mode,
            client_id,
            node_id
        );
        Ok(lease)
    }

    /// Returns a leased connection to its node's pool.
    ///
    /// For every mode other than `Session`, the configured reset query runs first.
    /// If that reset fails, the connection is closed instead of being reused. The
    /// release is recorded in metrics on every path.
    ///
    /// If the lease is no longer tracked, or its node has since been removed, the
    /// connection is not handed to any pool and is dropped together with `lease`.
    pub async fn release(&self, lease: ConnectionLease) {
        let client_id = lease.client_id();
        let mode = lease.mode();
        let statements = lease.statements_executed();
        // Saturate instead of truncating the u128 millisecond count.
        let duration_ms = u64::try_from(lease.lease_duration().as_millis()).unwrap_or(u64::MAX);

        let pool = self
            .active_leases
            .remove(&client_id)
            .and_then(|(_, info)| self.pool_for(&info.node_id));
        if let Some(pool) = pool {
            let mut connection = lease.into_connection();
            let reusable = if mode == PoolingMode::Session {
                true
            } else {
                let reset_query = self.config.reset_query.as_str();
                match pool.run_reset_query(&mut connection, reset_query).await {
                    Ok(()) => {
                        tracing::trace!(
                            query = reset_query,
                            "reset query executed on release"
                        );
                        self.metrics.record_reset(true);
                        true
                    }
                    Err(e) => {
                        tracing::warn!(
                            error = %e,
                            "reset query failed; connection will not be returned to pool"
                        );
                        self.metrics.record_reset(false);
                        false
                    }
                }
            };
            if reusable {
                pool.return_connection(connection).await;
            } else {
                // Count this close the same way `release_and_close` does.
                pool.close_connection(connection).await;
                self.metrics.record_connection_closed();
            }
        }

        self.metrics.record_release(mode, duration_ms, statements);
        tracing::trace!(
            "Released {:?} lease for client {:?} after {} statements",
            mode,
            client_id,
            statements
        );
    }

    /// Ends a lease and closes its connection instead of returning it to the pool.
    pub async fn release_and_close(&self, lease: ConnectionLease) {
        let client_id = lease.client_id();
        let mode = lease.mode();
        let statements = lease.statements_executed();
        let duration_ms = u64::try_from(lease.lease_duration().as_millis()).unwrap_or(u64::MAX);

        let pool = self
            .active_leases
            .remove(&client_id)
            .and_then(|(_, info)| self.pool_for(&info.node_id));
        if let Some(pool) = pool {
            pool.close_connection(lease.into_connection()).await;
            self.metrics.record_connection_closed();
        }
        self.metrics.record_release(mode, duration_ms, statements);
    }

    /// Forwards a completed statement to the lease and updates per-client bookkeeping.
    ///
    /// Returns the lease's decision. A `LeaseAction::Reset` is counted as a
    /// completed transaction.
    pub fn on_statement_complete(&self, lease: &mut ConnectionLease, sql: &str) -> LeaseAction {
        let action = lease.on_statement_complete(sql);
        if let Some(mut info) = self.active_leases.get_mut(&lease.client_id()) {
            info.statements += 1;
        }
        if action == LeaseAction::Reset {
            self.metrics.record_transaction_complete();
        }
        action
    }

    /// Collects per-node and aggregate connection counts.
    ///
    /// The node set is snapshotted first, so `node_count` always matches
    /// `node_stats.len()`.
    pub async fn get_stats(&self) -> PoolStats {
        let pools = self.pool_snapshot();
        let mut node_stats = Vec::with_capacity(pools.len());
        for (node_id, pool) in &pools {
            let total = pool.total_connections().await;
            let active = pool.active_connections().await;
            node_stats.push(NodePoolStats {
                node_id: *node_id,
                total,
                active,
                idle: total.saturating_sub(active),
            });
        }
        let total_connections: usize = node_stats.iter().map(|s| s.total).sum();
        let active_connections: usize = node_stats.iter().map(|s| s.active).sum();
        PoolStats {
            total_connections,
            active_connections,
            idle_connections: total_connections.saturating_sub(active_connections),
            node_count: pools.len(),
            node_stats,
        }
    }

    /// Shared metrics collector for this manager.
    pub fn metrics(&self) -> &PoolModeMetrics {
        &self.metrics
    }

    /// Configuration this manager was created with.
    pub fn config(&self) -> &PoolModeConfig {
        &self.config
    }

    /// Pooling mode used by [`Self::acquire`].
    pub fn default_mode(&self) -> PoolingMode {
        self.config.default_mode
    }

    /// Whether `client_id` currently has a tracked lease.
    pub fn has_active_lease(&self, client_id: &ClientId) -> bool {
        self.active_leases.contains_key(client_id)
    }

    /// Number of clients with a tracked lease.
    pub fn active_lease_count(&self) -> usize {
        self.active_leases.len()
    }

    /// Closes every pool's connections (best effort) and forgets all tracked leases.
    ///
    /// The pools themselves stay registered.
    pub async fn close_all(&self) {
        for (_, pool) in self.pool_snapshot() {
            let _ = pool.close_all().await;
        }
        self.active_leases.clear();
        tracing::info!("Closed all connections in pool manager");
    }

    /// Asks every registered pool to evict its idle connections.
    pub async fn evict_idle(&self) {
        for (_, pool) in self.pool_snapshot() {
            pool.evict_idle().await;
        }
    }
}
impl std::fmt::Debug for ConnectionPoolManager {
    /// Prints the key configuration values and live counts; pool and lease
    /// contents are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let config = &self.config;
        let lease_count = self.active_leases.len();
        let node_count = self.pools.len();
        f.debug_struct("ConnectionPoolManager")
            .field("default_mode", &config.default_mode)
            .field("max_pool_size", &config.max_pool_size)
            .field("active_leases", &lease_count)
            .field("nodes", &node_count)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh manager uses the default (session) mode and tracks no leases.
    #[tokio::test]
    async fn test_manager_creation() {
        let config = PoolModeConfig::default();
        let manager = ConnectionPoolManager::new(config);
        assert_eq!(manager.default_mode(), PoolingMode::Session);
        assert_eq!(manager.active_lease_count(), 0);
    }

    /// Adding and removing a node is reflected in `node_count`.
    #[tokio::test]
    async fn test_add_remove_node() {
        let config = PoolModeConfig::default();
        let manager = ConnectionPoolManager::new(config);
        let node = NodeEndpoint::new("localhost", 5432);
        manager.add_node(&node).await;
        let stats = manager.get_stats().await;
        assert_eq!(stats.node_count, 1);
        manager.remove_node(&node.id).await;
        let stats = manager.get_stats().await;
        assert_eq!(stats.node_count, 0);
    }

    /// Registering the same node twice replaces its pool rather than adding one.
    #[tokio::test]
    async fn test_readd_node_replaces_pool() {
        let manager = ConnectionPoolManager::new(PoolModeConfig::default());
        let node = NodeEndpoint::new("localhost", 5432);
        manager.add_node(&node).await;
        manager.add_node(&node).await;
        let stats = manager.get_stats().await;
        assert_eq!(stats.node_count, 1);
        assert_eq!(stats.node_stats.len(), 1);
    }

    /// Acquiring from a node that was never registered fails and records no lease.
    #[tokio::test]
    async fn test_acquire_unknown_node_fails() {
        let manager = ConnectionPoolManager::new(PoolModeConfig::transaction_mode());
        let node = NodeEndpoint::new("localhost", 5432);
        let client_id = ClientId::new();
        assert!(manager.acquire(client_id, &node.id).await.is_err());
        assert!(!manager.has_active_lease(&client_id));
        assert_eq!(manager.active_lease_count(), 0);
    }

    /// A lease is tracked from acquire until release.
    #[tokio::test]
    async fn test_acquire_release() {
        let config = PoolModeConfig::transaction_mode();
        let manager = ConnectionPoolManager::new(config);
        let node = NodeEndpoint::new("localhost", 5432);
        manager.add_node(&node).await;
        let client_id = ClientId::new();
        let lease = manager.acquire(client_id, &node.id).await.unwrap();
        assert!(manager.has_active_lease(&client_id));
        assert_eq!(manager.active_lease_count(), 1);
        manager.release(lease).await;
        assert!(!manager.has_active_lease(&client_id));
        assert_eq!(manager.active_lease_count(), 0);
    }

    /// Acquire and release are each counted once in the metrics snapshot.
    #[tokio::test]
    async fn test_metrics_recording() {
        let config = PoolModeConfig::transaction_mode();
        let manager = ConnectionPoolManager::new(config);
        let node = NodeEndpoint::new("localhost", 5432);
        manager.add_node(&node).await;
        let client_id = ClientId::new();
        let lease = manager.acquire(client_id, &node.id).await.unwrap();
        let snapshot = manager.metrics().snapshot();
        assert_eq!(snapshot.acquires, 1);
        manager.release(lease).await;
        let snapshot = manager.metrics().snapshot();
        assert_eq!(snapshot.releases, 1);
    }
}