use super::lease::ClientId;
use super::mode::PoolingMode;
use crate::{ProxyError, Result};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use tracing::{warn, info, debug};
/// Tracks open transactions per client and flags the ones that stay open too long.
///
/// Entries are added by `transaction_started`, removed by `transaction_ended`,
/// and aged by `check_for_leaks`.
#[derive(Debug)]
pub struct TransactionLeakDetector {
/// Open transactions, keyed by the client that started them.
active_transactions: RwLock<HashMap<ClientId, TransactionInfo>>,
/// Age at which a transaction is logged (once) as long-running.
warning_threshold: Duration,
/// Age at which a transaction is reported as leaked and dropped from tracking.
critical_threshold: Duration,
/// Number of long-running warnings logged so far.
warnings_issued: AtomicU64,
/// Number of transactions reported as leaked so far.
force_closed: AtomicU64,
}
/// Bookkeeping kept for a single tracked transaction.
#[derive(Debug, Clone)]
struct TransactionInfo {
/// When the transaction was registered with the detector.
started_at: Instant,
/// Pooling mode passed to `transaction_started`; only used in log output here.
mode: PoolingMode,
/// First SQL statement of the transaction, truncated before being stored.
first_statement: String,
/// Set once the long-running warning has been logged, so it is not repeated.
warning_issued: bool,
}
impl Default for TransactionLeakDetector {
    /// Builds a detector that warns at 60 seconds and reports a leak at 300 seconds.
    fn default() -> Self {
        let warn_after = Duration::from_secs(60);
        let leak_after = Duration::from_secs(300);
        Self::new(warn_after, leak_after)
    }
}
impl TransactionLeakDetector {
    /// Creates a detector that logs a warning once a transaction has been open for
    /// `warning_threshold` and reports it as leaked once it reaches `critical_threshold`.
    pub fn new(warning_threshold: Duration, critical_threshold: Duration) -> Self {
        let active_transactions = RwLock::new(HashMap::new());
        let warnings_issued = AtomicU64::new(0);
        let force_closed = AtomicU64::new(0);
        Self {
            active_transactions,
            warning_threshold,
            critical_threshold,
            warnings_issued,
            force_closed,
        }
    }

    /// Starts tracking a transaction for `client_id`.
    ///
    /// Any entry already stored for the same client is replaced. Only the first
    /// 100 bytes of `first_sql` are kept (see `truncate_sql`).
    pub fn transaction_started(&self, client_id: ClientId, mode: PoolingMode, first_sql: &str) {
        // Build the entry before taking the lock so the critical section is just the insert.
        let started_at = Instant::now();
        let first_statement = truncate_sql(first_sql, 100);
        let entry = TransactionInfo {
            started_at,
            mode,
            first_statement,
            warning_issued: false,
        };
        let mut tracked = self.active_transactions.write();
        tracked.insert(client_id, entry);
    }

    /// Stops tracking the transaction of `client_id`; a no-op if none is tracked.
    pub fn transaction_ended(&self, client_id: &ClientId) {
        let mut tracked = self.active_transactions.write();
        tracked.remove(client_id);
    }

    /// Scans every tracked transaction and returns the clients whose transaction
    /// has reached the critical threshold.
    ///
    /// Returned clients are removed from tracking and counted in `force_closed`.
    /// Transactions that have only reached the warning threshold are logged once
    /// and stay tracked.
    pub fn check_for_leaks(&self) -> Vec<ClientId> {
        let now = Instant::now();
        let mut leaked = Vec::new();
        // Single pass under the write lock: returning `false` drops the entry.
        self.active_transactions.write().retain(|client_id, info| {
            let age = now.duration_since(info.started_at);
            if age >= self.critical_threshold {
                leaked.push(*client_id);
                warn!(
                    "CRITICAL: Transaction leak detected for client {:?}, running for {:?}, mode: {:?}, sql: {}",
                    client_id, age, info.mode, info.first_statement
                );
                self.force_closed.fetch_add(1, Ordering::Relaxed);
                return false;
            }
            if age >= self.warning_threshold && !info.warning_issued {
                warn!(
                    "Long-running transaction for client {:?}, running for {:?}, mode: {:?}, sql: {}",
                    client_id, age, info.mode, info.first_statement
                );
                info.warning_issued = true;
                self.warnings_issued.fetch_add(1, Ordering::Relaxed);
            }
            true
        });
        leaked
    }

    /// Returns a snapshot of the detector's counters and configured thresholds.
    pub fn stats(&self) -> TransactionLeakStats {
        // Keep the read guard alive while the counters are sampled.
        let tracked = self.active_transactions.read();
        let warnings_issued = self.warnings_issued.load(Ordering::Relaxed);
        let force_closed = self.force_closed.load(Ordering::Relaxed);
        TransactionLeakStats {
            active_transactions: tracked.len(),
            warnings_issued,
            force_closed,
            warning_threshold_secs: self.warning_threshold.as_secs(),
            critical_threshold_secs: self.critical_threshold.as_secs(),
        }
    }
}
/// Point-in-time snapshot produced by `TransactionLeakDetector::stats`.
#[derive(Debug, Clone)]
pub struct TransactionLeakStats {
/// Number of transactions tracked when the snapshot was taken.
pub active_transactions: usize,
/// Total long-running warnings logged.
pub warnings_issued: u64,
/// Total transactions reported as leaked.
pub force_closed: u64,
/// Configured warning threshold, in whole seconds.
pub warning_threshold_secs: u64,
/// Configured critical threshold, in whole seconds.
pub critical_threshold_secs: u64,
}
/// Holds the settings for connection health probes and counts their outcomes.
///
/// Nothing in this type runs the probe: callers read the query and timeout from
/// here and report each result through [`ConnectionHealthValidator::record_validation`].
#[derive(Debug)]
pub struct ConnectionHealthValidator {
    /// SQL text exposed through `validation_query()`.
    validation_query: String,
    /// Timeout exposed through `timeout()`.
    timeout: Duration,
    /// Total number of recorded validations.
    validations: AtomicU64,
    /// Number of recorded validations that failed.
    failures: AtomicU64,
}

impl Default for ConnectionHealthValidator {
    /// Uses `SELECT 1` as the probe query with a 5 second timeout.
    fn default() -> Self {
        Self::new("SELECT 1", Duration::from_secs(5))
    }
}

impl ConnectionHealthValidator {
    /// Creates a validator with the given probe query and timeout; counters start at zero.
    pub fn new(validation_query: impl Into<String>, timeout: Duration) -> Self {
        Self {
            validation_query: validation_query.into(),
            timeout,
            validations: AtomicU64::new(0),
            failures: AtomicU64::new(0),
        }
    }

    /// Returns the configured probe query.
    pub fn validation_query(&self) -> &str {
        &self.validation_query
    }

    /// Returns the configured probe timeout.
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// Records the outcome of one validation attempt.
    pub fn record_validation(&self, success: bool) {
        self.validations.fetch_add(1, Ordering::Relaxed);
        if !success {
            self.failures.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Returns a snapshot of the validation counters.
    pub fn stats(&self) -> ValidationStats {
        ValidationStats {
            validations: self.validations.load(Ordering::Relaxed),
            failures: self.failures.load(Ordering::Relaxed),
        }
    }

    /// Returns the fraction of recorded validations that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `1.0` when nothing has been recorded yet.
    pub fn success_rate(&self) -> f64 {
        let total = self.validations.load(Ordering::Relaxed);
        if total == 0 {
            return 1.0;
        }
        // The two counters are independent relaxed atomics, so a concurrent
        // `record_validation` can make this snapshot show more failures than
        // validations. Clamp so the subtraction can never underflow.
        let failures = self.failures.load(Ordering::Relaxed).min(total);
        (total - failures) as f64 / total as f64
    }
}

/// Snapshot of the counters kept by [`ConnectionHealthValidator`].
#[derive(Debug, Clone)]
pub struct ValidationStats {
    /// Total validations recorded.
    pub validations: u64,
    /// Validations recorded as failed.
    pub failures: u64,
}
/// Tracks the last recorded activity per lease and evicts entries that have been idle too long.
#[derive(Debug)]
pub struct StaleLeaseCleaner {
/// Idle time a lease must exceed to be considered stale.
max_idle_time: Duration,
/// Last recorded activity time, keyed by client.
lease_activity: RwLock<HashMap<ClientId, Instant>>,
/// Total number of leases removed by `clean_stale`.
cleaned_count: AtomicU64,
}
impl Default for StaleLeaseCleaner {
    /// Builds a cleaner that treats leases idle for more than 1800 seconds as stale.
    fn default() -> Self {
        let max_idle = Duration::from_secs(1800);
        Self::new(max_idle)
    }
}
impl StaleLeaseCleaner {
    /// Creates a cleaner that considers a lease stale once it has been idle for
    /// strictly longer than `max_idle_time`.
    pub fn new(max_idle_time: Duration) -> Self {
        Self {
            max_idle_time,
            lease_activity: RwLock::new(HashMap::new()),
            cleaned_count: AtomicU64::new(0),
        }
    }

    /// Records activity for `client_id` now, inserting or refreshing its entry.
    pub fn record_activity(&self, client_id: ClientId) {
        self.lease_activity.write().insert(client_id, Instant::now());
    }

    /// Stops tracking `client_id`; a no-op if it is not tracked.
    pub fn lease_released(&self, client_id: &ClientId) {
        self.lease_activity.write().remove(client_id);
    }

    /// Returns the clients whose lease is currently stale, without removing them.
    pub fn find_stale_leases(&self) -> Vec<ClientId> {
        let now = Instant::now();
        let activity = self.lease_activity.read();
        activity
            .iter()
            .filter(|(_, last_activity)| self.is_stale(now, **last_activity))
            .map(|(client_id, _)| *client_id)
            .collect()
    }

    /// Removes every stale lease from tracking and returns the affected clients.
    ///
    /// Staleness is evaluated and the entry removed under a single write lock.
    /// A lease refreshed by a concurrent `record_activity` therefore cannot be
    /// evicted on the basis of an outdated read-locked scan.
    pub fn clean_stale(&self) -> Vec<ClientId> {
        let mut stale = Vec::new();
        {
            let mut activity = self.lease_activity.write();
            // Sample the clock after the lock is held so no entry can be newer than `now`.
            let now = Instant::now();
            activity.retain(|client_id, last_activity| {
                let expired = self.is_stale(now, *last_activity);
                if expired {
                    stale.push(*client_id);
                }
                !expired
            });
        }
        // Counter update and logging happen after the write lock is released.
        let count = stale.len();
        if count > 0 {
            self.cleaned_count.fetch_add(count as u64, Ordering::Relaxed);
            info!(
                "Cleaned {} stale leases (idle > {:?})",
                count, self.max_idle_time
            );
        }
        stale
    }

    /// Returns the total number of leases removed by `clean_stale` so far.
    pub fn cleaned_count(&self) -> u64 {
        self.cleaned_count.load(Ordering::Relaxed)
    }

    /// Returns whether a lease last active at `last_activity` is stale as of `now`.
    fn is_stale(&self, now: Instant, last_activity: Instant) -> bool {
        now.duration_since(last_activity) > self.max_idle_time
    }
}
/// Counts requests waiting for a pooled connection and optionally rejects new
/// ones once the wait queue is full.
#[derive(Debug)]
pub struct PoolExhaustionMonitor {
/// Queue length at which `check_capacity` starts rejecting (when backpressure is on).
max_queue_size: usize,
/// Number of requests currently counted as waiting.
current_queue: AtomicU64,
/// Number of times the queue went from empty to non-empty.
exhaustion_events: AtomicU64,
/// Number of requests rejected by `check_capacity`.
rejected_requests: AtomicU64,
/// When false, `check_capacity` never rejects.
enable_backpressure: bool,
}
impl Default for PoolExhaustionMonitor {
    /// Builds a monitor with a queue limit of 1000 and backpressure enabled.
    fn default() -> Self {
        let max_queue_size = 1000;
        let enable_backpressure = true;
        Self::new(max_queue_size, enable_backpressure)
    }
}
impl PoolExhaustionMonitor {
    /// Creates a monitor with an empty queue and zeroed counters.
    pub fn new(max_queue_size: usize, enable_backpressure: bool) -> Self {
        Self {
            max_queue_size,
            current_queue: AtomicU64::new(0),
            exhaustion_events: AtomicU64::new(0),
            rejected_requests: AtomicU64::new(0),
            enable_backpressure,
        }
    }

    /// Checks whether another request may start waiting.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::PoolExhausted` when backpressure is enabled and the
    /// queue already holds `max_queue_size` or more requests. Each rejection is
    /// counted in `rejected_requests`.
    pub fn check_capacity(&self) -> Result<()> {
        let queue_size = self.current_queue.load(Ordering::Relaxed);
        if self.enable_backpressure && queue_size >= self.max_queue_size as u64 {
            self.rejected_requests.fetch_add(1, Ordering::Relaxed);
            return Err(ProxyError::PoolExhausted(format!(
                "Pool queue full ({} waiting), request rejected",
                queue_size
            )));
        }
        Ok(())
    }

    /// Registers one more waiting request.
    ///
    /// An exhaustion event is counted when the queue was empty before this call.
    pub fn enter_queue(&self) {
        let prev = self.current_queue.fetch_add(1, Ordering::Relaxed);
        if prev == 0 {
            self.exhaustion_events.fetch_add(1, Ordering::Relaxed);
            debug!("Pool exhaustion event - requests now queuing");
        }
    }

    /// Unregisters one waiting request.
    ///
    /// The counter stops at zero. A plain `fetch_sub` would wrap to `u64::MAX`
    /// on an unbalanced call, after which `check_capacity` would reject every
    /// request.
    pub fn leave_queue(&self) {
        // `checked_sub` yields `None` at zero, so `fetch_update` leaves the value
        // untouched; the resulting `Err` is the expected "already empty" case.
        let _ = self
            .current_queue
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |queued| {
                queued.checked_sub(1)
            });
    }

    /// Returns the number of requests currently counted as waiting.
    pub fn queue_size(&self) -> u64 {
        self.current_queue.load(Ordering::Relaxed)
    }

    /// Returns a snapshot of the monitor's counters and configuration.
    pub fn stats(&self) -> ExhaustionStats {
        ExhaustionStats {
            current_queue: self.current_queue.load(Ordering::Relaxed),
            max_queue_size: self.max_queue_size as u64,
            exhaustion_events: self.exhaustion_events.load(Ordering::Relaxed),
            rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
            backpressure_enabled: self.enable_backpressure,
        }
    }
}
/// Point-in-time snapshot produced by `PoolExhaustionMonitor::stats`.
#[derive(Debug, Clone)]
pub struct ExhaustionStats {
/// Requests counted as waiting when the snapshot was taken.
pub current_queue: u64,
/// Configured queue limit.
pub max_queue_size: u64,
/// Number of times the queue went from empty to non-empty.
pub exhaustion_events: u64,
/// Number of requests rejected because the queue was full.
pub rejected_requests: u64,
/// Whether rejection on a full queue is enabled.
pub backpressure_enabled: bool,
}
/// Bundles the four hardening components so they can be configured and polled together.
#[derive(Debug)]
pub struct PoolHardening {
/// Detects transactions that stay open too long.
pub leak_detector: TransactionLeakDetector,
/// Holds health-probe settings and outcome counters.
pub health_validator: ConnectionHealthValidator,
/// Evicts leases that have been idle too long.
pub stale_cleaner: StaleLeaseCleaner,
/// Tracks the wait queue and applies backpressure.
pub exhaustion_monitor: PoolExhaustionMonitor,
}
impl Default for PoolHardening {
    /// Builds every component from its own `Default` implementation.
    fn default() -> Self {
        let leak_detector = TransactionLeakDetector::default();
        let health_validator = ConnectionHealthValidator::default();
        let stale_cleaner = StaleLeaseCleaner::default();
        let exhaustion_monitor = PoolExhaustionMonitor::default();
        Self {
            leak_detector,
            health_validator,
            stale_cleaner,
            exhaustion_monitor,
        }
    }
}
impl PoolHardening {
    /// Builds all four components from explicit settings.
    ///
    /// * `tx_warning_threshold` / `tx_critical_threshold` configure the leak detector.
    /// * `validation_query` / `validation_timeout` configure the health validator.
    /// * `max_lease_idle` configures the stale lease cleaner.
    /// * `max_queue_size` / `enable_backpressure` configure the exhaustion monitor.
    pub fn new(
        tx_warning_threshold: Duration,
        tx_critical_threshold: Duration,
        validation_query: &str,
        validation_timeout: Duration,
        max_lease_idle: Duration,
        max_queue_size: usize,
        enable_backpressure: bool,
    ) -> Self {
        let leak_detector =
            TransactionLeakDetector::new(tx_warning_threshold, tx_critical_threshold);
        let health_validator =
            ConnectionHealthValidator::new(validation_query, validation_timeout);
        let stale_cleaner = StaleLeaseCleaner::new(max_lease_idle);
        let exhaustion_monitor = PoolExhaustionMonitor::new(max_queue_size, enable_backpressure);
        Self {
            leak_detector,
            health_validator,
            stale_cleaner,
            exhaustion_monitor,
        }
    }

    /// Runs one maintenance pass: the leak check first, then the stale lease cleanup.
    ///
    /// Returns `(leaked, stale)`: the clients reported by the leak detector and
    /// the clients evicted by the stale lease cleaner.
    pub fn run_maintenance(&self) -> (Vec<ClientId>, Vec<ClientId>) {
        // Tuple fields are evaluated left to right, so the leak check still runs first.
        (
            self.leak_detector.check_for_leaks(),
            self.stale_cleaner.clean_stale(),
        )
    }

    /// Collects a snapshot of the statistics of every component.
    pub fn stats(&self) -> HardeningStats {
        let leak_stats = self.leak_detector.stats();
        let validation_stats = self.health_validator.stats();
        let exhaustion_stats = self.exhaustion_monitor.stats();
        let stale_cleaned = self.stale_cleaner.cleaned_count();
        HardeningStats {
            leak_stats,
            validation_stats,
            exhaustion_stats,
            stale_cleaned,
        }
    }
}
/// Combined statistics snapshot produced by `PoolHardening::stats`.
#[derive(Debug, Clone)]
pub struct HardeningStats {
/// Snapshot from the transaction leak detector.
pub leak_stats: TransactionLeakStats,
/// Snapshot from the connection health validator.
pub validation_stats: ValidationStats,
/// Snapshot from the pool exhaustion monitor.
pub exhaustion_stats: ExhaustionStats,
/// Total leases removed by the stale lease cleaner.
pub stale_cleaned: u64,
}
/// Returns `sql` unchanged when it is at most `max_len` bytes long; otherwise
/// returns a prefix of at most `max_len` bytes followed by `"..."`.
///
/// The cut is moved back to the nearest UTF-8 character boundary, so statements
/// containing multi-byte characters never cause a slicing panic.
fn truncate_sql(sql: &str, max_len: usize) -> String {
    if sql.len() <= max_len {
        return sql.to_string();
    }
    // Here `max_len < sql.len()`, so `end` is always a valid index. Offset 0 is
    // always a char boundary, so the loop terminates.
    let mut end = max_len;
    while !sql.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &sql[..end])
}
// Unit tests for the hardening components. Several of them rely on real
// `thread::sleep` delays, so the statement order and the chosen durations matter.
#[cfg(test)]
mod tests {
use super::*;
// Walks one transaction through the "fresh", "warned" and "leaked" stages.
#[test]
fn test_transaction_leak_detector() {
// Warn at 10 ms, report a leak at 50 ms.
let detector = TransactionLeakDetector::new(
Duration::from_millis(10),
Duration::from_millis(50),
);
let client1 = ClientId::new();
let client2 = ClientId::new();
detector.transaction_started(client1, PoolingMode::Transaction, "BEGIN; SELECT * FROM users");
detector.transaction_started(client2, PoolingMode::Statement, "SELECT 1");
// Nothing is old enough yet.
assert!(detector.check_for_leaks().is_empty());
// client2 finishes normally and must never be reported.
detector.transaction_ended(&client2);
// Past the warning threshold but (timing permitting) still under the critical one:
// a warning is logged, nothing is returned.
std::thread::sleep(Duration::from_millis(15));
let leaked = detector.check_for_leaks();
assert!(leaked.is_empty());
// At least 55 ms in total: client1 is now past the critical threshold.
std::thread::sleep(Duration::from_millis(40));
let leaked = detector.check_for_leaks();
assert_eq!(leaked.len(), 1);
assert_eq!(leaked[0], client1);
}
// Two successes and one failure give three validations and a success rate of about 2/3.
#[test]
fn test_connection_health_validator() {
let validator = ConnectionHealthValidator::default();
validator.record_validation(true);
validator.record_validation(true);
validator.record_validation(false);
assert_eq!(validator.stats().validations, 3);
assert_eq!(validator.stats().failures, 1);
assert!((validator.success_rate() - 0.666).abs() < 0.01);
}
// Only the lease that was not refreshed after the idle window is cleaned.
#[test]
fn test_stale_lease_cleaner() {
// Leases idle for more than 20 ms are stale.
let cleaner = StaleLeaseCleaner::new(Duration::from_millis(20));
let client1 = ClientId::new();
let client2 = ClientId::new();
cleaner.record_activity(client1);
cleaner.record_activity(client2);
assert!(cleaner.find_stale_leases().is_empty());
std::thread::sleep(Duration::from_millis(25));
// Refresh client1 so that only client2 has exceeded the idle limit.
cleaner.record_activity(client1);
let stale = cleaner.clean_stale();
assert_eq!(stale.len(), 1);
assert_eq!(stale[0], client2);
assert_eq!(cleaner.cleaned_count(), 1);
}
// With a limit of 2, the check passes at queue sizes 0 and 1 and rejects at 2.
#[test]
fn test_pool_exhaustion_monitor() {
let monitor = PoolExhaustionMonitor::new(2, true);
assert!(monitor.check_capacity().is_ok());
monitor.enter_queue();
assert!(monitor.check_capacity().is_ok());
monitor.enter_queue();
// Queue is at the limit: this check is rejected and counted.
assert!(monitor.check_capacity().is_err());
assert_eq!(monitor.stats().rejected_requests, 1);
// One request leaves, so there is room again.
monitor.leave_queue();
assert!(monitor.check_capacity().is_ok());
}
// A default bundle with no activity reports nothing and has empty statistics.
#[test]
fn test_pool_hardening_combined() {
let hardening = PoolHardening::default();
let (leaked, stale) = hardening.run_maintenance();
assert!(leaked.is_empty());
assert!(stale.is_empty());
let stats = hardening.stats();
assert_eq!(stats.leak_stats.active_transactions, 0);
assert_eq!(stats.stale_cleaned, 0);
}
}