1use super::lease::ClientId;
10use super::mode::PoolingMode;
11use crate::{ProxyError, Result};
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::time::{Duration, Instant};
15use parking_lot::RwLock;
16use tracing::{warn, info, debug};
17
18#[derive(Debug)]
23pub struct TransactionLeakDetector {
24 active_transactions: RwLock<HashMap<ClientId, TransactionInfo>>,
26 warning_threshold: Duration,
28 critical_threshold: Duration,
30 warnings_issued: AtomicU64,
32 force_closed: AtomicU64,
34}
35
36#[derive(Debug, Clone)]
38struct TransactionInfo {
39 started_at: Instant,
41 mode: PoolingMode,
43 first_statement: String,
45 warning_issued: bool,
47}
48
49impl Default for TransactionLeakDetector {
50 fn default() -> Self {
51 Self::new(Duration::from_secs(60), Duration::from_secs(300))
52 }
53}
54
55impl TransactionLeakDetector {
56 pub fn new(warning_threshold: Duration, critical_threshold: Duration) -> Self {
62 Self {
63 active_transactions: RwLock::new(HashMap::new()),
64 warning_threshold,
65 critical_threshold,
66 warnings_issued: AtomicU64::new(0),
67 force_closed: AtomicU64::new(0),
68 }
69 }
70
71 pub fn transaction_started(&self, client_id: ClientId, mode: PoolingMode, first_sql: &str) {
73 let info = TransactionInfo {
74 started_at: Instant::now(),
75 mode,
76 first_statement: truncate_sql(first_sql, 100),
77 warning_issued: false,
78 };
79 self.active_transactions.write().insert(client_id, info);
80 }
81
82 pub fn transaction_ended(&self, client_id: &ClientId) {
84 self.active_transactions.write().remove(client_id);
85 }
86
87 pub fn check_for_leaks(&self) -> Vec<ClientId> {
91 let now = Instant::now();
92 let mut leaked = Vec::new();
93 let mut txns = self.active_transactions.write();
94
95 for (client_id, info) in txns.iter_mut() {
96 let duration = now.duration_since(info.started_at);
97
98 if duration >= self.critical_threshold {
100 leaked.push(*client_id);
101 warn!(
102 "CRITICAL: Transaction leak detected for client {:?}, running for {:?}, mode: {:?}, sql: {}",
103 client_id, duration, info.mode, info.first_statement
104 );
105 self.force_closed.fetch_add(1, Ordering::Relaxed);
106 }
107 else if duration >= self.warning_threshold && !info.warning_issued {
109 warn!(
110 "Long-running transaction for client {:?}, running for {:?}, mode: {:?}, sql: {}",
111 client_id, duration, info.mode, info.first_statement
112 );
113 info.warning_issued = true;
114 self.warnings_issued.fetch_add(1, Ordering::Relaxed);
115 }
116 }
117
118 for client_id in &leaked {
120 txns.remove(client_id);
121 }
122
123 leaked
124 }
125
126 pub fn stats(&self) -> TransactionLeakStats {
128 let txns = self.active_transactions.read();
129 TransactionLeakStats {
130 active_transactions: txns.len(),
131 warnings_issued: self.warnings_issued.load(Ordering::Relaxed),
132 force_closed: self.force_closed.load(Ordering::Relaxed),
133 warning_threshold_secs: self.warning_threshold.as_secs(),
134 critical_threshold_secs: self.critical_threshold.as_secs(),
135 }
136 }
137}
138
/// Point-in-time snapshot produced by `TransactionLeakDetector::stats`.
#[derive(Clone, Debug)]
pub struct TransactionLeakStats {
    /// Number of transactions being tracked when the snapshot was taken.
    pub active_transactions: usize,
    /// Total long-running-transaction warnings logged.
    pub warnings_issued: u64,
    /// Total transactions reported as leaked.
    pub force_closed: u64,
    /// Configured warning threshold, in whole seconds.
    pub warning_threshold_secs: u64,
    /// Configured critical threshold, in whole seconds.
    pub critical_threshold_secs: u64,
}
153
/// Holds the settings for connection health probes and counts their outcomes.
///
/// The type stores the probe query and timeout and tallies results reported
/// through `record_validation`.
#[derive(Debug)]
pub struct ConnectionHealthValidator {
    /// SQL text used as the health probe.
    validation_query: String,
    /// Time budget for a single probe.
    timeout: Duration,
    /// Total number of probe outcomes recorded.
    validations: AtomicU64,
    /// Number of recorded probe outcomes that were failures.
    failures: AtomicU64,
}
168
169impl Default for ConnectionHealthValidator {
170 fn default() -> Self {
171 Self::new("SELECT 1", Duration::from_secs(5))
172 }
173}
174
175impl ConnectionHealthValidator {
176 pub fn new(validation_query: impl Into<String>, timeout: Duration) -> Self {
178 Self {
179 validation_query: validation_query.into(),
180 timeout,
181 validations: AtomicU64::new(0),
182 failures: AtomicU64::new(0),
183 }
184 }
185
186 pub fn validation_query(&self) -> &str {
188 &self.validation_query
189 }
190
191 pub fn timeout(&self) -> Duration {
193 self.timeout
194 }
195
196 pub fn record_validation(&self, success: bool) {
198 self.validations.fetch_add(1, Ordering::Relaxed);
199 if !success {
200 self.failures.fetch_add(1, Ordering::Relaxed);
201 }
202 }
203
204 pub fn stats(&self) -> ValidationStats {
206 ValidationStats {
207 validations: self.validations.load(Ordering::Relaxed),
208 failures: self.failures.load(Ordering::Relaxed),
209 }
210 }
211
212 pub fn success_rate(&self) -> f64 {
214 let total = self.validations.load(Ordering::Relaxed);
215 let failures = self.failures.load(Ordering::Relaxed);
216 if total == 0 {
217 1.0
218 } else {
219 (total - failures) as f64 / total as f64
220 }
221 }
222}
223
/// Point-in-time totals produced by `ConnectionHealthValidator::stats`.
#[derive(Clone, Debug)]
pub struct ValidationStats {
    /// Total number of probe outcomes recorded.
    pub validations: u64,
    /// Number of recorded probe outcomes that were failures.
    pub failures: u64,
}
232
233#[derive(Debug)]
237pub struct StaleLeaseCleaner {
238 max_idle_time: Duration,
240 lease_activity: RwLock<HashMap<ClientId, Instant>>,
242 cleaned_count: AtomicU64,
244}
245
246impl Default for StaleLeaseCleaner {
247 fn default() -> Self {
248 Self::new(Duration::from_secs(1800)) }
250}
251
252impl StaleLeaseCleaner {
253 pub fn new(max_idle_time: Duration) -> Self {
255 Self {
256 max_idle_time,
257 lease_activity: RwLock::new(HashMap::new()),
258 cleaned_count: AtomicU64::new(0),
259 }
260 }
261
262 pub fn record_activity(&self, client_id: ClientId) {
264 self.lease_activity.write().insert(client_id, Instant::now());
265 }
266
267 pub fn lease_released(&self, client_id: &ClientId) {
269 self.lease_activity.write().remove(client_id);
270 }
271
272 pub fn find_stale_leases(&self) -> Vec<ClientId> {
274 let now = Instant::now();
275 let activity = self.lease_activity.read();
276
277 activity
278 .iter()
279 .filter(|(_, last_activity)| now.duration_since(**last_activity) > self.max_idle_time)
280 .map(|(client_id, _)| *client_id)
281 .collect()
282 }
283
284 pub fn clean_stale(&self) -> Vec<ClientId> {
286 let stale = self.find_stale_leases();
287 let count = stale.len();
288
289 if count > 0 {
290 let mut activity = self.lease_activity.write();
291 for client_id in &stale {
292 activity.remove(client_id);
293 }
294 self.cleaned_count.fetch_add(count as u64, Ordering::Relaxed);
295
296 info!(
297 "Cleaned {} stale leases (idle > {:?})",
298 count, self.max_idle_time
299 );
300 }
301
302 stale
303 }
304
305 pub fn cleaned_count(&self) -> u64 {
307 self.cleaned_count.load(Ordering::Relaxed)
308 }
309}
310
/// Counts requests waiting for a pooled connection and, when backpressure is
/// enabled, rejects new requests once the wait queue is full.
#[derive(Debug)]
pub struct PoolExhaustionMonitor {
    /// Queue length at which `check_capacity` starts rejecting requests.
    max_queue_size: usize,
    /// Number of requests currently waiting.
    current_queue: AtomicU64,
    /// Number of times the queue went from empty to non-empty.
    exhaustion_events: AtomicU64,
    /// Number of requests rejected by `check_capacity`.
    rejected_requests: AtomicU64,
    /// Whether a full queue causes rejections at all.
    enable_backpressure: bool,
}
327
328impl Default for PoolExhaustionMonitor {
329 fn default() -> Self {
330 Self::new(1000, true)
331 }
332}
333
334impl PoolExhaustionMonitor {
335 pub fn new(max_queue_size: usize, enable_backpressure: bool) -> Self {
337 Self {
338 max_queue_size,
339 current_queue: AtomicU64::new(0),
340 exhaustion_events: AtomicU64::new(0),
341 rejected_requests: AtomicU64::new(0),
342 enable_backpressure,
343 }
344 }
345
346 pub fn check_capacity(&self) -> Result<()> {
350 let queue_size = self.current_queue.load(Ordering::Relaxed);
351
352 if self.enable_backpressure && queue_size >= self.max_queue_size as u64 {
353 self.rejected_requests.fetch_add(1, Ordering::Relaxed);
354 return Err(ProxyError::PoolExhausted(format!(
355 "Pool queue full ({} waiting), request rejected",
356 queue_size
357 )));
358 }
359
360 Ok(())
361 }
362
363 pub fn enter_queue(&self) {
365 let prev = self.current_queue.fetch_add(1, Ordering::Relaxed);
366 if prev == 0 {
367 self.exhaustion_events.fetch_add(1, Ordering::Relaxed);
369 debug!("Pool exhaustion event - requests now queuing");
370 }
371 }
372
373 pub fn leave_queue(&self) {
375 self.current_queue.fetch_sub(1, Ordering::Relaxed);
376 }
377
378 pub fn queue_size(&self) -> u64 {
380 self.current_queue.load(Ordering::Relaxed)
381 }
382
383 pub fn stats(&self) -> ExhaustionStats {
385 ExhaustionStats {
386 current_queue: self.current_queue.load(Ordering::Relaxed),
387 max_queue_size: self.max_queue_size as u64,
388 exhaustion_events: self.exhaustion_events.load(Ordering::Relaxed),
389 rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
390 backpressure_enabled: self.enable_backpressure,
391 }
392 }
393}
394
/// Point-in-time snapshot produced by `PoolExhaustionMonitor::stats`.
#[derive(Clone, Debug)]
pub struct ExhaustionStats {
    /// Number of requests waiting when the snapshot was taken.
    pub current_queue: u64,
    /// Configured queue limit.
    pub max_queue_size: u64,
    /// Number of times the queue went from empty to non-empty.
    pub exhaustion_events: u64,
    /// Number of requests rejected because the queue was full.
    pub rejected_requests: u64,
    /// Whether backpressure is enabled.
    pub backpressure_enabled: bool,
}
409
410#[derive(Debug)]
412pub struct PoolHardening {
413 pub leak_detector: TransactionLeakDetector,
415 pub health_validator: ConnectionHealthValidator,
417 pub stale_cleaner: StaleLeaseCleaner,
419 pub exhaustion_monitor: PoolExhaustionMonitor,
421}
422
423impl Default for PoolHardening {
424 fn default() -> Self {
425 Self {
426 leak_detector: TransactionLeakDetector::default(),
427 health_validator: ConnectionHealthValidator::default(),
428 stale_cleaner: StaleLeaseCleaner::default(),
429 exhaustion_monitor: PoolExhaustionMonitor::default(),
430 }
431 }
432}
433
434impl PoolHardening {
435 pub fn new(
437 tx_warning_threshold: Duration,
438 tx_critical_threshold: Duration,
439 validation_query: &str,
440 validation_timeout: Duration,
441 max_lease_idle: Duration,
442 max_queue_size: usize,
443 enable_backpressure: bool,
444 ) -> Self {
445 Self {
446 leak_detector: TransactionLeakDetector::new(tx_warning_threshold, tx_critical_threshold),
447 health_validator: ConnectionHealthValidator::new(validation_query, validation_timeout),
448 stale_cleaner: StaleLeaseCleaner::new(max_lease_idle),
449 exhaustion_monitor: PoolExhaustionMonitor::new(max_queue_size, enable_backpressure),
450 }
451 }
452
453 pub fn run_maintenance(&self) -> (Vec<ClientId>, Vec<ClientId>) {
457 let leaked = self.leak_detector.check_for_leaks();
458 let stale = self.stale_cleaner.clean_stale();
459 (leaked, stale)
460 }
461
462 pub fn stats(&self) -> HardeningStats {
464 HardeningStats {
465 leak_stats: self.leak_detector.stats(),
466 validation_stats: self.health_validator.stats(),
467 exhaustion_stats: self.exhaustion_monitor.stats(),
468 stale_cleaned: self.stale_cleaner.cleaned_count(),
469 }
470 }
471}
472
473#[derive(Debug, Clone)]
475pub struct HardeningStats {
476 pub leak_stats: TransactionLeakStats,
478 pub validation_stats: ValidationStats,
480 pub exhaustion_stats: ExhaustionStats,
482 pub stale_cleaned: u64,
484}
485
/// Shortens `sql` to at most `max_len` bytes for log output.
///
/// Text that already fits is returned unchanged. Longer text is cut and `...`
/// is appended. The cut is moved back to the nearest UTF-8 character boundary,
/// so a multi-byte character at the limit is dropped rather than split.
fn truncate_sql(sql: &str, max_len: usize) -> String {
    if sql.len() <= max_len {
        return sql.to_string();
    }

    // `max_len < sql.len()` here, so every probed index is in range; index 0 is
    // always a boundary, which bounds the loop.
    let mut cut = max_len;
    while !sql.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &sql[..cut])
}
494
// Unit tests for the pool-hardening components. Several of them rely on real
// sleeps with millisecond thresholds, so the order of statements matters.
#[cfg(test)]
mod tests {
    use super::*;

    // Walks one transaction through the "no warning", "warning" and "leaked"
    // stages using 10 ms / 50 ms thresholds.
    #[test]
    fn test_transaction_leak_detector() {
        let detector = TransactionLeakDetector::new(
            Duration::from_millis(10),
            Duration::from_millis(50),
        );

        let client1 = ClientId::new();
        let client2 = ClientId::new();

        detector.transaction_started(client1, PoolingMode::Transaction, "BEGIN; SELECT * FROM users");
        detector.transaction_started(client2, PoolingMode::Statement, "SELECT 1");

        // Nothing is old enough to be reported yet.
        assert!(detector.check_for_leaks().is_empty());

        // client2 finishes normally and is no longer tracked.
        detector.transaction_ended(&client2);

        // Past the warning threshold but below the critical one: nothing is
        // returned as leaked.
        std::thread::sleep(Duration::from_millis(15));
        let leaked = detector.check_for_leaks();
        assert!(leaked.is_empty());
        // Past the critical threshold: client1 is reported as leaked.
        std::thread::sleep(Duration::from_millis(40));
        let leaked = detector.check_for_leaks();
        assert_eq!(leaked.len(), 1);
        assert_eq!(leaked[0], client1);
    }

    // Two successes and one failure give three validations, one failure and a
    // success rate of about two thirds.
    #[test]
    fn test_connection_health_validator() {
        let validator = ConnectionHealthValidator::default();

        validator.record_validation(true);
        validator.record_validation(true);
        validator.record_validation(false);

        assert_eq!(validator.stats().validations, 3);
        assert_eq!(validator.stats().failures, 1);
        assert!((validator.success_rate() - 0.666).abs() < 0.01);
    }

    // With a 20 ms idle limit, only the client that was not refreshed after the
    // sleep is evicted.
    #[test]
    fn test_stale_lease_cleaner() {
        let cleaner = StaleLeaseCleaner::new(Duration::from_millis(20));

        let client1 = ClientId::new();
        let client2 = ClientId::new();

        cleaner.record_activity(client1);
        cleaner.record_activity(client2);

        // Both leases were just recorded, so none is stale.
        assert!(cleaner.find_stale_leases().is_empty());

        // Let both leases exceed the idle limit, then refresh only client1.
        std::thread::sleep(Duration::from_millis(25));
        cleaner.record_activity(client1);

        let stale = cleaner.clean_stale();
        // Only client2 is still past the limit.
        assert_eq!(stale.len(), 1);
        assert_eq!(stale[0], client2);
        assert_eq!(cleaner.cleaned_count(), 1);
    }

    // With a queue limit of 2 and backpressure on, the third request is
    // rejected until a slot frees up.
    #[test]
    fn test_pool_exhaustion_monitor() {
        let monitor = PoolExhaustionMonitor::new(2, true);

        // Empty queue: there is capacity.
        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();
        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();

        // Queue is now at the limit: the request is rejected and counted.
        assert!(monitor.check_capacity().is_err());
        assert_eq!(monitor.stats().rejected_requests, 1);

        // One waiter leaves, so there is capacity again.
        monitor.leave_queue();
        assert!(monitor.check_capacity().is_ok());
    }

    // A freshly built default bundle has nothing to report or clean.
    #[test]
    fn test_pool_hardening_combined() {
        let hardening = PoolHardening::default();

        let (leaked, stale) = hardening.run_maintenance();
        // Nothing is tracked, so maintenance finds nothing.
        assert!(leaked.is_empty());
        assert!(stale.is_empty());

        let stats = hardening.stats();
        // Counters are still at their initial values.
        assert_eq!(stats.leak_stats.active_transactions, 0);
        assert_eq!(stats.stale_cleaned, 0);
    }
}