Skip to main content

heliosdb_proxy/pool/
hardening.rs

//! Connection Pool Hardening
//!
//! Additional safety and reliability features for connection pooling:
//! - Transaction leak detection
//! - Connection health validation
//! - Stale lease cleanup
//! - Pool exhaustion monitoring
8
9use super::lease::ClientId;
10use super::mode::PoolingMode;
11use crate::{ProxyError, Result};
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::time::{Duration, Instant};
15use parking_lot::RwLock;
16use tracing::{warn, info, debug};
17
18/// Transaction leak detector
19///
20/// Tracks active transactions and warns when they exceed expected lifetimes.
21/// Helps identify abandoned transactions that could block connections.
22#[derive(Debug)]
23pub struct TransactionLeakDetector {
24    /// Active transactions: client_id -> (start_time, mode)
25    active_transactions: RwLock<HashMap<ClientId, TransactionInfo>>,
26    /// Warning threshold for transaction duration
27    warning_threshold: Duration,
28    /// Critical threshold - transaction is considered leaked
29    critical_threshold: Duration,
30    /// Number of leak warnings issued
31    warnings_issued: AtomicU64,
32    /// Number of transactions forced-closed
33    force_closed: AtomicU64,
34}
35
36/// Information about an active transaction
37#[derive(Debug, Clone)]
38struct TransactionInfo {
39    /// When the transaction started
40    started_at: Instant,
41    /// Pooling mode of the connection
42    mode: PoolingMode,
43    /// First SQL statement (truncated)
44    first_statement: String,
45    /// Whether warning has been issued
46    warning_issued: bool,
47}
48
49impl Default for TransactionLeakDetector {
50    fn default() -> Self {
51        Self::new(Duration::from_secs(60), Duration::from_secs(300))
52    }
53}
54
55impl TransactionLeakDetector {
56    /// Create a new transaction leak detector
57    ///
58    /// # Arguments
59    /// * `warning_threshold` - Duration after which to issue a warning
60    /// * `critical_threshold` - Duration after which transaction is considered leaked
61    pub fn new(warning_threshold: Duration, critical_threshold: Duration) -> Self {
62        Self {
63            active_transactions: RwLock::new(HashMap::new()),
64            warning_threshold,
65            critical_threshold,
66            warnings_issued: AtomicU64::new(0),
67            force_closed: AtomicU64::new(0),
68        }
69    }
70
71    /// Track the start of a transaction
72    pub fn transaction_started(&self, client_id: ClientId, mode: PoolingMode, first_sql: &str) {
73        let info = TransactionInfo {
74            started_at: Instant::now(),
75            mode,
76            first_statement: truncate_sql(first_sql, 100),
77            warning_issued: false,
78        };
79        self.active_transactions.write().insert(client_id, info);
80    }
81
82    /// Track the end of a transaction
83    pub fn transaction_ended(&self, client_id: &ClientId) {
84        self.active_transactions.write().remove(client_id);
85    }
86
87    /// Check for leaked transactions and issue warnings
88    ///
89    /// Returns list of client IDs that have exceeded the critical threshold
90    pub fn check_for_leaks(&self) -> Vec<ClientId> {
91        let now = Instant::now();
92        let mut leaked = Vec::new();
93        let mut txns = self.active_transactions.write();
94
95        for (client_id, info) in txns.iter_mut() {
96            let duration = now.duration_since(info.started_at);
97
98            // Check critical threshold first
99            if duration >= self.critical_threshold {
100                leaked.push(*client_id);
101                warn!(
102                    "CRITICAL: Transaction leak detected for client {:?}, running for {:?}, mode: {:?}, sql: {}",
103                    client_id, duration, info.mode, info.first_statement
104                );
105                self.force_closed.fetch_add(1, Ordering::Relaxed);
106            }
107            // Then check warning threshold
108            else if duration >= self.warning_threshold && !info.warning_issued {
109                warn!(
110                    "Long-running transaction for client {:?}, running for {:?}, mode: {:?}, sql: {}",
111                    client_id, duration, info.mode, info.first_statement
112                );
113                info.warning_issued = true;
114                self.warnings_issued.fetch_add(1, Ordering::Relaxed);
115            }
116        }
117
118        // Remove leaked transactions from tracking (they'll be force-closed)
119        for client_id in &leaked {
120            txns.remove(client_id);
121        }
122
123        leaked
124    }
125
126    /// Get statistics about transaction tracking
127    pub fn stats(&self) -> TransactionLeakStats {
128        let txns = self.active_transactions.read();
129        TransactionLeakStats {
130            active_transactions: txns.len(),
131            warnings_issued: self.warnings_issued.load(Ordering::Relaxed),
132            force_closed: self.force_closed.load(Ordering::Relaxed),
133            warning_threshold_secs: self.warning_threshold.as_secs(),
134            critical_threshold_secs: self.critical_threshold.as_secs(),
135        }
136    }
137}
138
139/// Transaction leak statistics
140#[derive(Debug, Clone)]
141pub struct TransactionLeakStats {
142    /// Currently active transactions being tracked
143    pub active_transactions: usize,
144    /// Total warnings issued
145    pub warnings_issued: u64,
146    /// Total transactions force-closed
147    pub force_closed: u64,
148    /// Warning threshold in seconds
149    pub warning_threshold_secs: u64,
150    /// Critical threshold in seconds
151    pub critical_threshold_secs: u64,
152}
153
/// Connection health validator
///
/// Holds the validation query/timeout used to check connections before they
/// are returned from the pool, and counts validation outcomes.
#[derive(Debug)]
pub struct ConnectionHealthValidator {
    /// Query to execute for validation
    validation_query: String,
    /// Validation timeout
    timeout: Duration,
    /// Total validations performed
    validations: AtomicU64,
    /// Validation failures
    failures: AtomicU64,
}

impl Default for ConnectionHealthValidator {
    /// `SELECT 1` with a 5 second timeout.
    fn default() -> Self {
        Self::new("SELECT 1", Duration::from_secs(5))
    }
}

impl ConnectionHealthValidator {
    /// Create a new health validator
    pub fn new(validation_query: impl Into<String>, timeout: Duration) -> Self {
        Self {
            validation_query: validation_query.into(),
            timeout,
            validations: AtomicU64::new(0),
            failures: AtomicU64::new(0),
        }
    }

    /// Get the validation query
    pub fn validation_query(&self) -> &str {
        &self.validation_query
    }

    /// Get the timeout
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// Record a validation attempt and whether it succeeded
    pub fn record_validation(&self, success: bool) {
        self.validations.fetch_add(1, Ordering::Relaxed);
        if !success {
            self.failures.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Get validation statistics
    ///
    /// The two counters are read independently, so under concurrent updates
    /// the snapshot may be momentarily inconsistent.
    pub fn stats(&self) -> ValidationStats {
        ValidationStats {
            validations: self.validations.load(Ordering::Relaxed),
            failures: self.failures.load(Ordering::Relaxed),
        }
    }

    /// Calculate success rate in `[0.0, 1.0]`; `1.0` when nothing was recorded
    pub fn success_rate(&self) -> f64 {
        let total = self.validations.load(Ordering::Relaxed);
        let failures = self.failures.load(Ordering::Relaxed);
        if total == 0 {
            return 1.0;
        }
        // The counters are loaded separately with Relaxed ordering, so a
        // concurrent `record_validation` can make `failures` appear larger than
        // `total`; saturate instead of underflowing (panic in debug, wrap in release).
        let successes = total.saturating_sub(failures);
        successes as f64 / total as f64
    }
}

/// Validation statistics
#[derive(Debug, Clone)]
pub struct ValidationStats {
    /// Total validations
    pub validations: u64,
    /// Failed validations
    pub failures: u64,
}
232
233/// Stale lease cleaner
234///
235/// Identifies and cleans up leases that have been held too long without activity.
236#[derive(Debug)]
237pub struct StaleLeaseCleaner {
238    /// Maximum idle time before a lease is considered stale
239    max_idle_time: Duration,
240    /// Tracked lease activity: client_id -> last_activity
241    lease_activity: RwLock<HashMap<ClientId, Instant>>,
242    /// Leases cleaned up
243    cleaned_count: AtomicU64,
244}
245
246impl Default for StaleLeaseCleaner {
247    fn default() -> Self {
248        Self::new(Duration::from_secs(1800)) // 30 minutes default
249    }
250}
251
252impl StaleLeaseCleaner {
253    /// Create a new stale lease cleaner
254    pub fn new(max_idle_time: Duration) -> Self {
255        Self {
256            max_idle_time,
257            lease_activity: RwLock::new(HashMap::new()),
258            cleaned_count: AtomicU64::new(0),
259        }
260    }
261
262    /// Record activity for a lease
263    pub fn record_activity(&self, client_id: ClientId) {
264        self.lease_activity.write().insert(client_id, Instant::now());
265    }
266
267    /// Remove tracking for a lease
268    pub fn lease_released(&self, client_id: &ClientId) {
269        self.lease_activity.write().remove(client_id);
270    }
271
272    /// Find stale leases that should be cleaned up
273    pub fn find_stale_leases(&self) -> Vec<ClientId> {
274        let now = Instant::now();
275        let activity = self.lease_activity.read();
276
277        activity
278            .iter()
279            .filter(|(_, last_activity)| now.duration_since(**last_activity) > self.max_idle_time)
280            .map(|(client_id, _)| *client_id)
281            .collect()
282    }
283
284    /// Clean up stale leases and return their IDs
285    pub fn clean_stale(&self) -> Vec<ClientId> {
286        let stale = self.find_stale_leases();
287        let count = stale.len();
288
289        if count > 0 {
290            let mut activity = self.lease_activity.write();
291            for client_id in &stale {
292                activity.remove(client_id);
293            }
294            self.cleaned_count.fetch_add(count as u64, Ordering::Relaxed);
295
296            info!(
297                "Cleaned {} stale leases (idle > {:?})",
298                count, self.max_idle_time
299            );
300        }
301
302        stale
303    }
304
305    /// Get cleaned count
306    pub fn cleaned_count(&self) -> u64 {
307        self.cleaned_count.load(Ordering::Relaxed)
308    }
309}
310
311/// Pool exhaustion monitor
312///
313/// Tracks pool exhaustion events and can trigger alerts or backpressure.
314#[derive(Debug)]
315pub struct PoolExhaustionMonitor {
316    /// Maximum queue size before rejecting
317    max_queue_size: usize,
318    /// Current queue size
319    current_queue: AtomicU64,
320    /// Total exhaustion events
321    exhaustion_events: AtomicU64,
322    /// Total requests rejected due to queue full
323    rejected_requests: AtomicU64,
324    /// Whether to enable backpressure (reject requests when pool full)
325    enable_backpressure: bool,
326}
327
328impl Default for PoolExhaustionMonitor {
329    fn default() -> Self {
330        Self::new(1000, true)
331    }
332}
333
334impl PoolExhaustionMonitor {
335    /// Create a new exhaustion monitor
336    pub fn new(max_queue_size: usize, enable_backpressure: bool) -> Self {
337        Self {
338            max_queue_size,
339            current_queue: AtomicU64::new(0),
340            exhaustion_events: AtomicU64::new(0),
341            rejected_requests: AtomicU64::new(0),
342            enable_backpressure,
343        }
344    }
345
346    /// Check if a request should be queued or rejected
347    ///
348    /// Returns Ok(()) if request can proceed, Err if should be rejected
349    pub fn check_capacity(&self) -> Result<()> {
350        let queue_size = self.current_queue.load(Ordering::Relaxed);
351
352        if self.enable_backpressure && queue_size >= self.max_queue_size as u64 {
353            self.rejected_requests.fetch_add(1, Ordering::Relaxed);
354            return Err(ProxyError::PoolExhausted(format!(
355                "Pool queue full ({} waiting), request rejected",
356                queue_size
357            )));
358        }
359
360        Ok(())
361    }
362
363    /// Record entering the wait queue
364    pub fn enter_queue(&self) {
365        let prev = self.current_queue.fetch_add(1, Ordering::Relaxed);
366        if prev == 0 {
367            // First waiter means pool is exhausted
368            self.exhaustion_events.fetch_add(1, Ordering::Relaxed);
369            debug!("Pool exhaustion event - requests now queuing");
370        }
371    }
372
373    /// Record leaving the wait queue (got a connection)
374    pub fn leave_queue(&self) {
375        self.current_queue.fetch_sub(1, Ordering::Relaxed);
376    }
377
378    /// Get current queue size
379    pub fn queue_size(&self) -> u64 {
380        self.current_queue.load(Ordering::Relaxed)
381    }
382
383    /// Get exhaustion statistics
384    pub fn stats(&self) -> ExhaustionStats {
385        ExhaustionStats {
386            current_queue: self.current_queue.load(Ordering::Relaxed),
387            max_queue_size: self.max_queue_size as u64,
388            exhaustion_events: self.exhaustion_events.load(Ordering::Relaxed),
389            rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
390            backpressure_enabled: self.enable_backpressure,
391        }
392    }
393}
394
395/// Pool exhaustion statistics
396#[derive(Debug, Clone)]
397pub struct ExhaustionStats {
398    /// Current requests waiting in queue
399    pub current_queue: u64,
400    /// Maximum queue size
401    pub max_queue_size: u64,
402    /// Total exhaustion events
403    pub exhaustion_events: u64,
404    /// Total rejected requests
405    pub rejected_requests: u64,
406    /// Whether backpressure is enabled
407    pub backpressure_enabled: bool,
408}
409
410/// Combined hardening features
411#[derive(Debug)]
412pub struct PoolHardening {
413    /// Transaction leak detector
414    pub leak_detector: TransactionLeakDetector,
415    /// Connection health validator
416    pub health_validator: ConnectionHealthValidator,
417    /// Stale lease cleaner
418    pub stale_cleaner: StaleLeaseCleaner,
419    /// Pool exhaustion monitor
420    pub exhaustion_monitor: PoolExhaustionMonitor,
421}
422
423impl Default for PoolHardening {
424    fn default() -> Self {
425        Self {
426            leak_detector: TransactionLeakDetector::default(),
427            health_validator: ConnectionHealthValidator::default(),
428            stale_cleaner: StaleLeaseCleaner::default(),
429            exhaustion_monitor: PoolExhaustionMonitor::default(),
430        }
431    }
432}
433
434impl PoolHardening {
435    /// Create with custom configuration
436    pub fn new(
437        tx_warning_threshold: Duration,
438        tx_critical_threshold: Duration,
439        validation_query: &str,
440        validation_timeout: Duration,
441        max_lease_idle: Duration,
442        max_queue_size: usize,
443        enable_backpressure: bool,
444    ) -> Self {
445        Self {
446            leak_detector: TransactionLeakDetector::new(tx_warning_threshold, tx_critical_threshold),
447            health_validator: ConnectionHealthValidator::new(validation_query, validation_timeout),
448            stale_cleaner: StaleLeaseCleaner::new(max_lease_idle),
449            exhaustion_monitor: PoolExhaustionMonitor::new(max_queue_size, enable_backpressure),
450        }
451    }
452
453    /// Run periodic maintenance
454    ///
455    /// Returns (leaked_txns, stale_leases) that need to be cleaned up
456    pub fn run_maintenance(&self) -> (Vec<ClientId>, Vec<ClientId>) {
457        let leaked = self.leak_detector.check_for_leaks();
458        let stale = self.stale_cleaner.clean_stale();
459        (leaked, stale)
460    }
461
462    /// Get combined statistics
463    pub fn stats(&self) -> HardeningStats {
464        HardeningStats {
465            leak_stats: self.leak_detector.stats(),
466            validation_stats: self.health_validator.stats(),
467            exhaustion_stats: self.exhaustion_monitor.stats(),
468            stale_cleaned: self.stale_cleaner.cleaned_count(),
469        }
470    }
471}
472
473/// Combined hardening statistics
474#[derive(Debug, Clone)]
475pub struct HardeningStats {
476    /// Transaction leak statistics
477    pub leak_stats: TransactionLeakStats,
478    /// Validation statistics
479    pub validation_stats: ValidationStats,
480    /// Exhaustion statistics
481    pub exhaustion_stats: ExhaustionStats,
482    /// Stale leases cleaned
483    pub stale_cleaned: u64,
484}
485
/// Truncate SQL for logging
///
/// Keeps at most `max_len` bytes of `sql`. The cut is moved back to the
/// nearest UTF-8 character boundary, so non-ASCII SQL never causes a slicing
/// panic. When anything was dropped, `...` is appended, so the result can be
/// up to 3 bytes longer than `max_len`.
fn truncate_sql(sql: &str, max_len: usize) -> String {
    if sql.len() <= max_len {
        return sql.to_string();
    }

    // `max_len < sql.len()` here, and index 0 is always a boundary, so this terminates.
    let mut end = max_len;
    while !sql.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &sql[..end])
}
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transaction_leak_detector() {
        // Keep a wide gap between the warning and critical thresholds so that
        // scheduler jitter on a loaded CI machine cannot push the
        // "warning only" check past the critical threshold.
        let detector = TransactionLeakDetector::new(
            Duration::from_millis(20),
            Duration::from_millis(250),
        );

        let client1 = ClientId::new();
        let client2 = ClientId::new();

        // Start two transactions
        detector.transaction_started(client1, PoolingMode::Transaction, "BEGIN; SELECT * FROM users");
        detector.transaction_started(client2, PoolingMode::Statement, "SELECT 1");

        // No leaks immediately
        assert!(detector.check_for_leaks().is_empty());

        // End one transaction
        detector.transaction_ended(&client2);
        assert_eq!(detector.stats().active_transactions, 1);

        // Past the warning threshold: warned once, but not leaked
        std::thread::sleep(Duration::from_millis(30));
        assert!(detector.check_for_leaks().is_empty());
        assert_eq!(detector.stats().warnings_issued, 1);

        // A repeated check must not repeat the warning
        assert!(detector.check_for_leaks().is_empty());
        assert_eq!(detector.stats().warnings_issued, 1);

        // Past the critical threshold: reported and no longer tracked
        std::thread::sleep(Duration::from_millis(250));
        let leaked = detector.check_for_leaks();
        assert_eq!(leaked.len(), 1);
        assert_eq!(leaked[0], client1);

        let stats = detector.stats();
        assert_eq!(stats.force_closed, 1);
        assert_eq!(stats.active_transactions, 0);
    }

    #[test]
    fn test_connection_health_validator() {
        let validator = ConnectionHealthValidator::default();

        validator.record_validation(true);
        validator.record_validation(true);
        validator.record_validation(false);

        assert_eq!(validator.stats().validations, 3);
        assert_eq!(validator.stats().failures, 1);
        assert!((validator.success_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_stale_lease_cleaner() {
        // Idle limit large enough that the "nothing stale yet" check and the
        // refreshed client1 cannot cross it due to scheduling delays.
        let cleaner = StaleLeaseCleaner::new(Duration::from_millis(100));

        let client1 = ClientId::new();
        let client2 = ClientId::new();

        cleaner.record_activity(client1);
        cleaner.record_activity(client2);

        // No stale leases immediately
        assert!(cleaner.find_stale_leases().is_empty());

        // Wait and only update client1
        std::thread::sleep(Duration::from_millis(120));
        cleaner.record_activity(client1);

        // client2 should be stale
        let stale = cleaner.clean_stale();
        assert_eq!(stale.len(), 1);
        assert_eq!(stale[0], client2);
        assert_eq!(cleaner.cleaned_count(), 1);

        // Already cleaned leases are not reported again
        assert!(cleaner.clean_stale().is_empty());
        assert_eq!(cleaner.cleaned_count(), 1);
    }

    #[test]
    fn test_pool_exhaustion_monitor() {
        let monitor = PoolExhaustionMonitor::new(2, true);

        // First two requests OK
        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();
        assert_eq!(monitor.stats().exhaustion_events, 1);
        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();
        assert_eq!(monitor.queue_size(), 2);

        // Third should be rejected (backpressure)
        assert!(monitor.check_capacity().is_err());
        assert_eq!(monitor.stats().rejected_requests, 1);

        // Leave queue
        monitor.leave_queue();
        assert_eq!(monitor.queue_size(), 1);
        assert!(monitor.check_capacity().is_ok());
    }

    #[test]
    fn test_pool_hardening_combined() {
        let hardening = PoolHardening::default();

        // Run maintenance on empty state
        let (leaked, stale) = hardening.run_maintenance();
        assert!(leaked.is_empty());
        assert!(stale.is_empty());

        // Check stats
        let stats = hardening.stats();
        assert_eq!(stats.leak_stats.active_transactions, 0);
        assert_eq!(stats.stale_cleaned, 0);
    }
}