1use super::lease::ClientId;
10use super::mode::PoolingMode;
11use crate::{ProxyError, Result};
12use parking_lot::RwLock;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::{Duration, Instant};
16use tracing::{debug, info, warn};
17
18#[derive(Debug)]
23pub struct TransactionLeakDetector {
24 active_transactions: RwLock<HashMap<ClientId, TransactionInfo>>,
26 warning_threshold: Duration,
28 critical_threshold: Duration,
30 warnings_issued: AtomicU64,
32 force_closed: AtomicU64,
34}
35
36#[derive(Debug, Clone)]
38struct TransactionInfo {
39 started_at: Instant,
41 mode: PoolingMode,
43 first_statement: String,
45 warning_issued: bool,
47}
48
49impl Default for TransactionLeakDetector {
50 fn default() -> Self {
51 Self::new(Duration::from_secs(60), Duration::from_secs(300))
52 }
53}
54
55impl TransactionLeakDetector {
56 pub fn new(warning_threshold: Duration, critical_threshold: Duration) -> Self {
62 Self {
63 active_transactions: RwLock::new(HashMap::new()),
64 warning_threshold,
65 critical_threshold,
66 warnings_issued: AtomicU64::new(0),
67 force_closed: AtomicU64::new(0),
68 }
69 }
70
71 pub fn transaction_started(&self, client_id: ClientId, mode: PoolingMode, first_sql: &str) {
73 let info = TransactionInfo {
74 started_at: Instant::now(),
75 mode,
76 first_statement: truncate_sql(first_sql, 100),
77 warning_issued: false,
78 };
79 self.active_transactions.write().insert(client_id, info);
80 }
81
82 pub fn transaction_ended(&self, client_id: &ClientId) {
84 self.active_transactions.write().remove(client_id);
85 }
86
87 pub fn check_for_leaks(&self) -> Vec<ClientId> {
91 let now = Instant::now();
92 let mut leaked = Vec::new();
93 let mut txns = self.active_transactions.write();
94
95 for (client_id, info) in txns.iter_mut() {
96 let duration = now.duration_since(info.started_at);
97
98 if duration >= self.critical_threshold {
100 leaked.push(*client_id);
101 warn!(
102 "CRITICAL: Transaction leak detected for client {:?}, running for {:?}, mode: {:?}, sql: {}",
103 client_id, duration, info.mode, info.first_statement
104 );
105 self.force_closed.fetch_add(1, Ordering::Relaxed);
106 }
107 else if duration >= self.warning_threshold && !info.warning_issued {
109 warn!(
110 "Long-running transaction for client {:?}, running for {:?}, mode: {:?}, sql: {}",
111 client_id, duration, info.mode, info.first_statement
112 );
113 info.warning_issued = true;
114 self.warnings_issued.fetch_add(1, Ordering::Relaxed);
115 }
116 }
117
118 for client_id in &leaked {
120 txns.remove(client_id);
121 }
122
123 leaked
124 }
125
126 pub fn stats(&self) -> TransactionLeakStats {
128 let txns = self.active_transactions.read();
129 TransactionLeakStats {
130 active_transactions: txns.len(),
131 warnings_issued: self.warnings_issued.load(Ordering::Relaxed),
132 force_closed: self.force_closed.load(Ordering::Relaxed),
133 warning_threshold_secs: self.warning_threshold.as_secs(),
134 critical_threshold_secs: self.critical_threshold.as_secs(),
135 }
136 }
137}
138
/// Point-in-time snapshot produced by `TransactionLeakDetector::stats`.
#[derive(Debug, Clone)]
pub struct TransactionLeakStats {
    /// Number of transactions being tracked when the snapshot was taken.
    pub active_transactions: usize,
    /// Total long-running warnings logged.
    pub warnings_issued: u64,
    /// Total transactions reported as leaked.
    pub force_closed: u64,
    /// Configured warning threshold, in whole seconds.
    pub warning_threshold_secs: u64,
    /// Configured critical threshold, in whole seconds.
    pub critical_threshold_secs: u64,
}
153
/// Holds the configuration for connection health checks and counts their
/// outcomes. It does not run any query itself.
#[derive(Debug)]
pub struct ConnectionHealthValidator {
    /// Query text handed back by `validation_query()`.
    validation_query: String,
    /// Time limit handed back by `timeout()`.
    timeout: Duration,
    /// Number of validation outcomes recorded.
    validations: AtomicU64,
    /// Number of recorded outcomes that were failures.
    failures: AtomicU64,
}
168
169impl Default for ConnectionHealthValidator {
170 fn default() -> Self {
171 Self::new("SELECT 1", Duration::from_secs(5))
172 }
173}
174
175impl ConnectionHealthValidator {
176 pub fn new(validation_query: impl Into<String>, timeout: Duration) -> Self {
178 Self {
179 validation_query: validation_query.into(),
180 timeout,
181 validations: AtomicU64::new(0),
182 failures: AtomicU64::new(0),
183 }
184 }
185
186 pub fn validation_query(&self) -> &str {
188 &self.validation_query
189 }
190
191 pub fn timeout(&self) -> Duration {
193 self.timeout
194 }
195
196 pub fn record_validation(&self, success: bool) {
198 self.validations.fetch_add(1, Ordering::Relaxed);
199 if !success {
200 self.failures.fetch_add(1, Ordering::Relaxed);
201 }
202 }
203
204 pub fn stats(&self) -> ValidationStats {
206 ValidationStats {
207 validations: self.validations.load(Ordering::Relaxed),
208 failures: self.failures.load(Ordering::Relaxed),
209 }
210 }
211
212 pub fn success_rate(&self) -> f64 {
214 let total = self.validations.load(Ordering::Relaxed);
215 let failures = self.failures.load(Ordering::Relaxed);
216 if total == 0 {
217 1.0
218 } else {
219 (total - failures) as f64 / total as f64
220 }
221 }
222}
223
/// Point-in-time snapshot produced by `ConnectionHealthValidator::stats`.
#[derive(Debug, Clone)]
pub struct ValidationStats {
    /// Total validation outcomes recorded.
    pub validations: u64,
    /// How many of those outcomes were failures.
    pub failures: u64,
}
232
233#[derive(Debug)]
237pub struct StaleLeaseCleaner {
238 max_idle_time: Duration,
240 lease_activity: RwLock<HashMap<ClientId, Instant>>,
242 cleaned_count: AtomicU64,
244}
245
246impl Default for StaleLeaseCleaner {
247 fn default() -> Self {
248 Self::new(Duration::from_secs(1800)) }
250}
251
252impl StaleLeaseCleaner {
253 pub fn new(max_idle_time: Duration) -> Self {
255 Self {
256 max_idle_time,
257 lease_activity: RwLock::new(HashMap::new()),
258 cleaned_count: AtomicU64::new(0),
259 }
260 }
261
262 pub fn record_activity(&self, client_id: ClientId) {
264 self.lease_activity
265 .write()
266 .insert(client_id, Instant::now());
267 }
268
269 pub fn lease_released(&self, client_id: &ClientId) {
271 self.lease_activity.write().remove(client_id);
272 }
273
274 pub fn find_stale_leases(&self) -> Vec<ClientId> {
276 let now = Instant::now();
277 let activity = self.lease_activity.read();
278
279 activity
280 .iter()
281 .filter(|(_, last_activity)| now.duration_since(**last_activity) > self.max_idle_time)
282 .map(|(client_id, _)| *client_id)
283 .collect()
284 }
285
286 pub fn clean_stale(&self) -> Vec<ClientId> {
288 let stale = self.find_stale_leases();
289 let count = stale.len();
290
291 if count > 0 {
292 let mut activity = self.lease_activity.write();
293 for client_id in &stale {
294 activity.remove(client_id);
295 }
296 self.cleaned_count
297 .fetch_add(count as u64, Ordering::Relaxed);
298
299 info!(
300 "Cleaned {} stale leases (idle > {:?})",
301 count, self.max_idle_time
302 );
303 }
304
305 stale
306 }
307
308 pub fn cleaned_count(&self) -> u64 {
310 self.cleaned_count.load(Ordering::Relaxed)
311 }
312}
313
/// Tracks how many requests are waiting for a connection and, when
/// backpressure is enabled, rejects new ones once the queue is full.
#[derive(Debug)]
pub struct PoolExhaustionMonitor {
    /// Queue length at which `check_capacity` starts rejecting requests.
    max_queue_size: usize,
    /// Number of requests currently waiting.
    current_queue: AtomicU64,
    /// Number of times the queue went from empty to non-empty.
    exhaustion_events: AtomicU64,
    /// Number of requests rejected by `check_capacity`.
    rejected_requests: AtomicU64,
    /// Whether `check_capacity` may reject requests at all.
    enable_backpressure: bool,
}
330
331impl Default for PoolExhaustionMonitor {
332 fn default() -> Self {
333 Self::new(1000, true)
334 }
335}
336
337impl PoolExhaustionMonitor {
338 pub fn new(max_queue_size: usize, enable_backpressure: bool) -> Self {
340 Self {
341 max_queue_size,
342 current_queue: AtomicU64::new(0),
343 exhaustion_events: AtomicU64::new(0),
344 rejected_requests: AtomicU64::new(0),
345 enable_backpressure,
346 }
347 }
348
349 pub fn check_capacity(&self) -> Result<()> {
353 let queue_size = self.current_queue.load(Ordering::Relaxed);
354
355 if self.enable_backpressure && queue_size >= self.max_queue_size as u64 {
356 self.rejected_requests.fetch_add(1, Ordering::Relaxed);
357 return Err(ProxyError::PoolExhausted(format!(
358 "Pool queue full ({} waiting), request rejected",
359 queue_size
360 )));
361 }
362
363 Ok(())
364 }
365
366 pub fn enter_queue(&self) {
368 let prev = self.current_queue.fetch_add(1, Ordering::Relaxed);
369 if prev == 0 {
370 self.exhaustion_events.fetch_add(1, Ordering::Relaxed);
372 debug!("Pool exhaustion event - requests now queuing");
373 }
374 }
375
376 pub fn leave_queue(&self) {
378 self.current_queue.fetch_sub(1, Ordering::Relaxed);
379 }
380
381 pub fn queue_size(&self) -> u64 {
383 self.current_queue.load(Ordering::Relaxed)
384 }
385
386 pub fn stats(&self) -> ExhaustionStats {
388 ExhaustionStats {
389 current_queue: self.current_queue.load(Ordering::Relaxed),
390 max_queue_size: self.max_queue_size as u64,
391 exhaustion_events: self.exhaustion_events.load(Ordering::Relaxed),
392 rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
393 backpressure_enabled: self.enable_backpressure,
394 }
395 }
396}
397
/// Point-in-time snapshot produced by `PoolExhaustionMonitor::stats`.
#[derive(Debug, Clone)]
pub struct ExhaustionStats {
    /// Requests waiting when the snapshot was taken.
    pub current_queue: u64,
    /// Configured queue limit.
    pub max_queue_size: u64,
    /// Number of times the queue went from empty to non-empty.
    pub exhaustion_events: u64,
    /// Number of requests rejected because the queue was full.
    pub rejected_requests: u64,
    /// Whether rejection of requests is enabled.
    pub backpressure_enabled: bool,
}
412
413#[derive(Debug, Default)]
415pub struct PoolHardening {
416 pub leak_detector: TransactionLeakDetector,
418 pub health_validator: ConnectionHealthValidator,
420 pub stale_cleaner: StaleLeaseCleaner,
422 pub exhaustion_monitor: PoolExhaustionMonitor,
424}
425
426impl PoolHardening {
427 pub fn new(
429 tx_warning_threshold: Duration,
430 tx_critical_threshold: Duration,
431 validation_query: &str,
432 validation_timeout: Duration,
433 max_lease_idle: Duration,
434 max_queue_size: usize,
435 enable_backpressure: bool,
436 ) -> Self {
437 Self {
438 leak_detector: TransactionLeakDetector::new(
439 tx_warning_threshold,
440 tx_critical_threshold,
441 ),
442 health_validator: ConnectionHealthValidator::new(validation_query, validation_timeout),
443 stale_cleaner: StaleLeaseCleaner::new(max_lease_idle),
444 exhaustion_monitor: PoolExhaustionMonitor::new(max_queue_size, enable_backpressure),
445 }
446 }
447
448 pub fn run_maintenance(&self) -> (Vec<ClientId>, Vec<ClientId>) {
452 let leaked = self.leak_detector.check_for_leaks();
453 let stale = self.stale_cleaner.clean_stale();
454 (leaked, stale)
455 }
456
457 pub fn stats(&self) -> HardeningStats {
459 HardeningStats {
460 leak_stats: self.leak_detector.stats(),
461 validation_stats: self.health_validator.stats(),
462 exhaustion_stats: self.exhaustion_monitor.stats(),
463 stale_cleaned: self.stale_cleaner.cleaned_count(),
464 }
465 }
466}
467
468#[derive(Debug, Clone)]
470pub struct HardeningStats {
471 pub leak_stats: TransactionLeakStats,
473 pub validation_stats: ValidationStats,
475 pub exhaustion_stats: ExhaustionStats,
477 pub stale_cleaned: u64,
479}
480
/// Shortens `sql` to at most `max_len` bytes for log output.
///
/// Text that already fits is returned unchanged. Longer text is cut and
/// suffixed with `...`. The cut is moved back to the nearest UTF-8 character
/// boundary at or before `max_len`, so multi-byte characters are never split
/// and the function cannot panic on non-ASCII SQL.
fn truncate_sql(sql: &str, max_len: usize) -> String {
    if sql.len() <= max_len {
        return sql.to_string();
    }

    // Slicing a `str` in the middle of a character panics, so back up until
    // the cut lands on a boundary. Index 0 is always a boundary, which
    // guarantees termination.
    let mut cut = max_len;
    while !sql.is_char_boundary(cut) {
        cut -= 1;
    }

    format!("{}...", &sql[..cut])
}
489
#[cfg(test)]
mod tests {
    use super::*;

    /// A transaction past the warning threshold is only logged; once it
    /// crosses the critical threshold it is reported as leaked.
    #[test]
    fn test_transaction_leak_detector() {
        let warn_after = Duration::from_millis(10);
        let leak_after = Duration::from_millis(50);
        let detector = TransactionLeakDetector::new(warn_after, leak_after);

        let long_runner = ClientId::new();
        let short_lived = ClientId::new();

        detector.transaction_started(
            long_runner,
            PoolingMode::Transaction,
            "BEGIN; SELECT * FROM users",
        );
        detector.transaction_started(short_lived, PoolingMode::Statement, "SELECT 1");

        // Both transactions are brand new, so nothing is reported.
        assert!(detector.check_for_leaks().is_empty());

        detector.transaction_ended(&short_lived);

        // Past the warning threshold but still below the critical one.
        std::thread::sleep(Duration::from_millis(15));
        let after_warning = detector.check_for_leaks();
        assert!(after_warning.is_empty());

        // Now past the critical threshold as well.
        std::thread::sleep(Duration::from_millis(40));
        let after_critical = detector.check_for_leaks();
        assert_eq!(after_critical.len(), 1);
        assert_eq!(after_critical[0], long_runner);
    }

    /// Two successes and one failure give a success rate of roughly 2/3.
    #[test]
    fn test_connection_health_validator() {
        let validator = ConnectionHealthValidator::default();

        for outcome in [true, true, false] {
            validator.record_validation(outcome);
        }

        assert_eq!(validator.stats().validations, 3);
        assert_eq!(validator.stats().failures, 1);

        let rate = validator.success_rate();
        assert!((rate - 0.666).abs() < 0.01);
    }

    /// Only the lease that was not refreshed within the idle limit is removed.
    #[test]
    fn test_stale_lease_cleaner() {
        let cleaner = StaleLeaseCleaner::new(Duration::from_millis(20));

        let refreshed = ClientId::new();
        let abandoned = ClientId::new();

        cleaner.record_activity(refreshed);
        cleaner.record_activity(abandoned);

        assert!(cleaner.find_stale_leases().is_empty());

        // Let both leases exceed the idle limit, then refresh one of them.
        std::thread::sleep(Duration::from_millis(25));
        cleaner.record_activity(refreshed);

        let removed = cleaner.clean_stale();
        assert_eq!(removed.len(), 1);
        assert_eq!(removed[0], abandoned);
        assert_eq!(cleaner.cleaned_count(), 1);
    }

    /// Requests are rejected once the queue is full and accepted again after
    /// a waiter leaves.
    #[test]
    fn test_pool_exhaustion_monitor() {
        let monitor = PoolExhaustionMonitor::new(2, true);

        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();
        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();

        // Two waiters with a limit of two: the queue is full.
        assert!(monitor.check_capacity().is_err());
        assert_eq!(monitor.stats().rejected_requests, 1);

        monitor.leave_queue();
        assert!(monitor.check_capacity().is_ok());
    }

    /// A freshly built default bundle has nothing to clean up.
    #[test]
    fn test_pool_hardening_combined() {
        let hardening = PoolHardening::default();

        let (leaked, stale) = hardening.run_maintenance();
        assert!(leaked.is_empty());
        assert!(stale.is_empty());

        let snapshot = hardening.stats();
        assert_eq!(snapshot.leak_stats.active_transactions, 0);
        assert_eq!(snapshot.stale_cleaned, 0);
    }
}