Skip to main content

heliosdb_proxy/pool/
hardening.rs

//! Connection Pool Hardening
//!
//! Additional safety and reliability features for connection pooling:
//! - Transaction leak detection
//! - Connection health validation tracking
//! - Stale lease cleanup
//! - Pool exhaustion monitoring and optional backpressure
8
9use super::lease::ClientId;
10use super::mode::PoolingMode;
11use crate::{ProxyError, Result};
12use parking_lot::RwLock;
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::{Duration, Instant};
16use tracing::{debug, info, warn};
17
18/// Transaction leak detector
19///
20/// Tracks active transactions and warns when they exceed expected lifetimes.
21/// Helps identify abandoned transactions that could block connections.
22#[derive(Debug)]
23pub struct TransactionLeakDetector {
24    /// Active transactions: client_id -> (start_time, mode)
25    active_transactions: RwLock<HashMap<ClientId, TransactionInfo>>,
26    /// Warning threshold for transaction duration
27    warning_threshold: Duration,
28    /// Critical threshold - transaction is considered leaked
29    critical_threshold: Duration,
30    /// Number of leak warnings issued
31    warnings_issued: AtomicU64,
32    /// Number of transactions forced-closed
33    force_closed: AtomicU64,
34}
35
36/// Information about an active transaction
37#[derive(Debug, Clone)]
38struct TransactionInfo {
39    /// When the transaction started
40    started_at: Instant,
41    /// Pooling mode of the connection
42    mode: PoolingMode,
43    /// First SQL statement (truncated)
44    first_statement: String,
45    /// Whether warning has been issued
46    warning_issued: bool,
47}
48
49impl Default for TransactionLeakDetector {
50    fn default() -> Self {
51        Self::new(Duration::from_secs(60), Duration::from_secs(300))
52    }
53}
54
55impl TransactionLeakDetector {
56    /// Create a new transaction leak detector
57    ///
58    /// # Arguments
59    /// * `warning_threshold` - Duration after which to issue a warning
60    /// * `critical_threshold` - Duration after which transaction is considered leaked
61    pub fn new(warning_threshold: Duration, critical_threshold: Duration) -> Self {
62        Self {
63            active_transactions: RwLock::new(HashMap::new()),
64            warning_threshold,
65            critical_threshold,
66            warnings_issued: AtomicU64::new(0),
67            force_closed: AtomicU64::new(0),
68        }
69    }
70
71    /// Track the start of a transaction
72    pub fn transaction_started(&self, client_id: ClientId, mode: PoolingMode, first_sql: &str) {
73        let info = TransactionInfo {
74            started_at: Instant::now(),
75            mode,
76            first_statement: truncate_sql(first_sql, 100),
77            warning_issued: false,
78        };
79        self.active_transactions.write().insert(client_id, info);
80    }
81
82    /// Track the end of a transaction
83    pub fn transaction_ended(&self, client_id: &ClientId) {
84        self.active_transactions.write().remove(client_id);
85    }
86
87    /// Check for leaked transactions and issue warnings
88    ///
89    /// Returns list of client IDs that have exceeded the critical threshold
90    pub fn check_for_leaks(&self) -> Vec<ClientId> {
91        let now = Instant::now();
92        let mut leaked = Vec::new();
93        let mut txns = self.active_transactions.write();
94
95        for (client_id, info) in txns.iter_mut() {
96            let duration = now.duration_since(info.started_at);
97
98            // Check critical threshold first
99            if duration >= self.critical_threshold {
100                leaked.push(*client_id);
101                warn!(
102                    "CRITICAL: Transaction leak detected for client {:?}, running for {:?}, mode: {:?}, sql: {}",
103                    client_id, duration, info.mode, info.first_statement
104                );
105                self.force_closed.fetch_add(1, Ordering::Relaxed);
106            }
107            // Then check warning threshold
108            else if duration >= self.warning_threshold && !info.warning_issued {
109                warn!(
110                    "Long-running transaction for client {:?}, running for {:?}, mode: {:?}, sql: {}",
111                    client_id, duration, info.mode, info.first_statement
112                );
113                info.warning_issued = true;
114                self.warnings_issued.fetch_add(1, Ordering::Relaxed);
115            }
116        }
117
118        // Remove leaked transactions from tracking (they'll be force-closed)
119        for client_id in &leaked {
120            txns.remove(client_id);
121        }
122
123        leaked
124    }
125
126    /// Get statistics about transaction tracking
127    pub fn stats(&self) -> TransactionLeakStats {
128        let txns = self.active_transactions.read();
129        TransactionLeakStats {
130            active_transactions: txns.len(),
131            warnings_issued: self.warnings_issued.load(Ordering::Relaxed),
132            force_closed: self.force_closed.load(Ordering::Relaxed),
133            warning_threshold_secs: self.warning_threshold.as_secs(),
134            critical_threshold_secs: self.critical_threshold.as_secs(),
135        }
136    }
137}
138
139/// Transaction leak statistics
140#[derive(Debug, Clone)]
141pub struct TransactionLeakStats {
142    /// Currently active transactions being tracked
143    pub active_transactions: usize,
144    /// Total warnings issued
145    pub warnings_issued: u64,
146    /// Total transactions force-closed
147    pub force_closed: u64,
148    /// Warning threshold in seconds
149    pub warning_threshold_secs: u64,
150    /// Critical threshold in seconds
151    pub critical_threshold_secs: u64,
152}
153
/// Connection health validator
///
/// Holds the query and timeout used to health-check pooled connections, and
/// counts validation attempts and failures. This type does not run the query
/// itself; the caller executes it and reports the outcome through
/// [`ConnectionHealthValidator::record_validation`].
#[derive(Debug)]
pub struct ConnectionHealthValidator {
    /// Query to execute for validation
    validation_query: String,
    /// Validation timeout
    timeout: Duration,
    /// Total validations recorded
    validations: AtomicU64,
    /// Validations recorded as failed (always bumped after `validations`)
    failures: AtomicU64,
}

impl Default for ConnectionHealthValidator {
    /// `SELECT 1` with a 5 second timeout.
    fn default() -> Self {
        Self::new("SELECT 1", Duration::from_secs(5))
    }
}

impl ConnectionHealthValidator {
    /// Create a new health validator
    pub fn new(validation_query: impl Into<String>, timeout: Duration) -> Self {
        Self {
            validation_query: validation_query.into(),
            timeout,
            validations: AtomicU64::new(0),
            failures: AtomicU64::new(0),
        }
    }

    /// Get the validation query
    pub fn validation_query(&self) -> &str {
        &self.validation_query
    }

    /// Get the timeout
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// Record a validation attempt and whether it succeeded.
    pub fn record_validation(&self, success: bool) {
        // The total is bumped before the failure count; readers rely on this
        // ordering by loading `failures` first.
        self.validations.fetch_add(1, Ordering::Relaxed);
        if !success {
            self.failures.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Get validation statistics
    pub fn stats(&self) -> ValidationStats {
        // Load failures before the total so that failures recorded between the
        // two loads cannot make the snapshot report more failures than attempts.
        let failures = self.failures.load(Ordering::Relaxed);
        let validations = self.validations.load(Ordering::Relaxed);
        ValidationStats {
            validations,
            failures,
        }
    }

    /// Fraction of recorded validations that succeeded, in `[0.0, 1.0]`.
    ///
    /// Returns `1.0` when nothing has been recorded yet.
    pub fn success_rate(&self) -> f64 {
        // Reading `failures` first stops failures recorded between the two
        // loads from pushing it past a stale `total`.
        let failures = self.failures.load(Ordering::Relaxed);
        let total = self.validations.load(Ordering::Relaxed);
        if total == 0 {
            return 1.0;
        }
        // Relaxed loads give no cross-counter ordering guarantee, so clamp
        // instead of risking an integer underflow.
        let successes = total.saturating_sub(failures);
        successes as f64 / total as f64
    }
}

/// Validation statistics
#[derive(Debug, Clone)]
pub struct ValidationStats {
    /// Total validations
    pub validations: u64,
    /// Failed validations
    pub failures: u64,
}
232
233/// Stale lease cleaner
234///
235/// Identifies and cleans up leases that have been held too long without activity.
236#[derive(Debug)]
237pub struct StaleLeaseCleaner {
238    /// Maximum idle time before a lease is considered stale
239    max_idle_time: Duration,
240    /// Tracked lease activity: client_id -> last_activity
241    lease_activity: RwLock<HashMap<ClientId, Instant>>,
242    /// Leases cleaned up
243    cleaned_count: AtomicU64,
244}
245
246impl Default for StaleLeaseCleaner {
247    fn default() -> Self {
248        Self::new(Duration::from_secs(1800)) // 30 minutes default
249    }
250}
251
252impl StaleLeaseCleaner {
253    /// Create a new stale lease cleaner
254    pub fn new(max_idle_time: Duration) -> Self {
255        Self {
256            max_idle_time,
257            lease_activity: RwLock::new(HashMap::new()),
258            cleaned_count: AtomicU64::new(0),
259        }
260    }
261
262    /// Record activity for a lease
263    pub fn record_activity(&self, client_id: ClientId) {
264        self.lease_activity
265            .write()
266            .insert(client_id, Instant::now());
267    }
268
269    /// Remove tracking for a lease
270    pub fn lease_released(&self, client_id: &ClientId) {
271        self.lease_activity.write().remove(client_id);
272    }
273
274    /// Find stale leases that should be cleaned up
275    pub fn find_stale_leases(&self) -> Vec<ClientId> {
276        let now = Instant::now();
277        let activity = self.lease_activity.read();
278
279        activity
280            .iter()
281            .filter(|(_, last_activity)| now.duration_since(**last_activity) > self.max_idle_time)
282            .map(|(client_id, _)| *client_id)
283            .collect()
284    }
285
286    /// Clean up stale leases and return their IDs
287    pub fn clean_stale(&self) -> Vec<ClientId> {
288        let stale = self.find_stale_leases();
289        let count = stale.len();
290
291        if count > 0 {
292            let mut activity = self.lease_activity.write();
293            for client_id in &stale {
294                activity.remove(client_id);
295            }
296            self.cleaned_count
297                .fetch_add(count as u64, Ordering::Relaxed);
298
299            info!(
300                "Cleaned {} stale leases (idle > {:?})",
301                count, self.max_idle_time
302            );
303        }
304
305        stale
306    }
307
308    /// Get cleaned count
309    pub fn cleaned_count(&self) -> u64 {
310        self.cleaned_count.load(Ordering::Relaxed)
311    }
312}
313
314/// Pool exhaustion monitor
315///
316/// Tracks pool exhaustion events and can trigger alerts or backpressure.
317#[derive(Debug)]
318pub struct PoolExhaustionMonitor {
319    /// Maximum queue size before rejecting
320    max_queue_size: usize,
321    /// Current queue size
322    current_queue: AtomicU64,
323    /// Total exhaustion events
324    exhaustion_events: AtomicU64,
325    /// Total requests rejected due to queue full
326    rejected_requests: AtomicU64,
327    /// Whether to enable backpressure (reject requests when pool full)
328    enable_backpressure: bool,
329}
330
331impl Default for PoolExhaustionMonitor {
332    fn default() -> Self {
333        Self::new(1000, true)
334    }
335}
336
337impl PoolExhaustionMonitor {
338    /// Create a new exhaustion monitor
339    pub fn new(max_queue_size: usize, enable_backpressure: bool) -> Self {
340        Self {
341            max_queue_size,
342            current_queue: AtomicU64::new(0),
343            exhaustion_events: AtomicU64::new(0),
344            rejected_requests: AtomicU64::new(0),
345            enable_backpressure,
346        }
347    }
348
349    /// Check if a request should be queued or rejected
350    ///
351    /// Returns Ok(()) if request can proceed, Err if should be rejected
352    pub fn check_capacity(&self) -> Result<()> {
353        let queue_size = self.current_queue.load(Ordering::Relaxed);
354
355        if self.enable_backpressure && queue_size >= self.max_queue_size as u64 {
356            self.rejected_requests.fetch_add(1, Ordering::Relaxed);
357            return Err(ProxyError::PoolExhausted(format!(
358                "Pool queue full ({} waiting), request rejected",
359                queue_size
360            )));
361        }
362
363        Ok(())
364    }
365
366    /// Record entering the wait queue
367    pub fn enter_queue(&self) {
368        let prev = self.current_queue.fetch_add(1, Ordering::Relaxed);
369        if prev == 0 {
370            // First waiter means pool is exhausted
371            self.exhaustion_events.fetch_add(1, Ordering::Relaxed);
372            debug!("Pool exhaustion event - requests now queuing");
373        }
374    }
375
376    /// Record leaving the wait queue (got a connection)
377    pub fn leave_queue(&self) {
378        self.current_queue.fetch_sub(1, Ordering::Relaxed);
379    }
380
381    /// Get current queue size
382    pub fn queue_size(&self) -> u64 {
383        self.current_queue.load(Ordering::Relaxed)
384    }
385
386    /// Get exhaustion statistics
387    pub fn stats(&self) -> ExhaustionStats {
388        ExhaustionStats {
389            current_queue: self.current_queue.load(Ordering::Relaxed),
390            max_queue_size: self.max_queue_size as u64,
391            exhaustion_events: self.exhaustion_events.load(Ordering::Relaxed),
392            rejected_requests: self.rejected_requests.load(Ordering::Relaxed),
393            backpressure_enabled: self.enable_backpressure,
394        }
395    }
396}
397
398/// Pool exhaustion statistics
399#[derive(Debug, Clone)]
400pub struct ExhaustionStats {
401    /// Current requests waiting in queue
402    pub current_queue: u64,
403    /// Maximum queue size
404    pub max_queue_size: u64,
405    /// Total exhaustion events
406    pub exhaustion_events: u64,
407    /// Total rejected requests
408    pub rejected_requests: u64,
409    /// Whether backpressure is enabled
410    pub backpressure_enabled: bool,
411}
412
413/// Combined hardening features
414#[derive(Debug, Default)]
415pub struct PoolHardening {
416    /// Transaction leak detector
417    pub leak_detector: TransactionLeakDetector,
418    /// Connection health validator
419    pub health_validator: ConnectionHealthValidator,
420    /// Stale lease cleaner
421    pub stale_cleaner: StaleLeaseCleaner,
422    /// Pool exhaustion monitor
423    pub exhaustion_monitor: PoolExhaustionMonitor,
424}
425
426impl PoolHardening {
427    /// Create with custom configuration
428    pub fn new(
429        tx_warning_threshold: Duration,
430        tx_critical_threshold: Duration,
431        validation_query: &str,
432        validation_timeout: Duration,
433        max_lease_idle: Duration,
434        max_queue_size: usize,
435        enable_backpressure: bool,
436    ) -> Self {
437        Self {
438            leak_detector: TransactionLeakDetector::new(
439                tx_warning_threshold,
440                tx_critical_threshold,
441            ),
442            health_validator: ConnectionHealthValidator::new(validation_query, validation_timeout),
443            stale_cleaner: StaleLeaseCleaner::new(max_lease_idle),
444            exhaustion_monitor: PoolExhaustionMonitor::new(max_queue_size, enable_backpressure),
445        }
446    }
447
448    /// Run periodic maintenance
449    ///
450    /// Returns (leaked_txns, stale_leases) that need to be cleaned up
451    pub fn run_maintenance(&self) -> (Vec<ClientId>, Vec<ClientId>) {
452        let leaked = self.leak_detector.check_for_leaks();
453        let stale = self.stale_cleaner.clean_stale();
454        (leaked, stale)
455    }
456
457    /// Get combined statistics
458    pub fn stats(&self) -> HardeningStats {
459        HardeningStats {
460            leak_stats: self.leak_detector.stats(),
461            validation_stats: self.health_validator.stats(),
462            exhaustion_stats: self.exhaustion_monitor.stats(),
463            stale_cleaned: self.stale_cleaner.cleaned_count(),
464        }
465    }
466}
467
468/// Combined hardening statistics
469#[derive(Debug, Clone)]
470pub struct HardeningStats {
471    /// Transaction leak statistics
472    pub leak_stats: TransactionLeakStats,
473    /// Validation statistics
474    pub validation_stats: ValidationStats,
475    /// Exhaustion statistics
476    pub exhaustion_stats: ExhaustionStats,
477    /// Stale leases cleaned
478    pub stale_cleaned: u64,
479}
480
/// Truncate SQL for logging.
///
/// Returns `sql` unchanged when it is at most `max_len` bytes long. Otherwise
/// keeps the longest prefix of at most `max_len` bytes that ends on a UTF-8
/// character boundary, and appends `"..."`. Slicing at a raw byte offset would
/// panic whenever `max_len` falls inside a multi-byte character.
fn truncate_sql(sql: &str, max_len: usize) -> String {
    if sql.len() <= max_len {
        return sql.to_string();
    }
    // Back off to the nearest char boundary at or below `max_len`. Index 0 is
    // always a boundary, so the loop terminates.
    let mut end = max_len;
    while !sql.is_char_boundary(end) {
        end -= 1;
    }
    format!("{}...", &sql[..end])
}
489
#[cfg(test)]
mod tests {
    use super::*;

    // The timing-based tests leave wide margins between their sleeps and the
    // configured thresholds. Scheduler jitter on a loaded CI machine can
    // easily exceed a few milliseconds, and tight margins make them flaky.

    #[test]
    fn test_transaction_leak_detector() {
        let detector =
            TransactionLeakDetector::new(Duration::from_millis(10), Duration::from_millis(150));

        let client1 = ClientId::new();
        let client2 = ClientId::new();

        // Start two transactions
        detector.transaction_started(
            client1,
            PoolingMode::Transaction,
            "BEGIN; SELECT * FROM users",
        );
        detector.transaction_started(client2, PoolingMode::Statement, "SELECT 1");

        // No leaks immediately
        assert!(detector.check_for_leaks().is_empty());

        // End one transaction
        detector.transaction_ended(&client2);
        assert_eq!(detector.stats().active_transactions, 1);

        // Past the warning threshold, well short of the critical one
        std::thread::sleep(Duration::from_millis(20));
        assert!(detector.check_for_leaks().is_empty());
        assert_eq!(detector.stats().warnings_issued, 1);

        // Past the critical threshold
        std::thread::sleep(Duration::from_millis(150));
        let leaked = detector.check_for_leaks();
        assert_eq!(leaked, vec![client1]);

        let stats = detector.stats();
        assert_eq!(stats.force_closed, 1);
        assert_eq!(stats.active_transactions, 0);
    }

    #[test]
    fn test_connection_health_validator() {
        let validator = ConnectionHealthValidator::default();

        validator.record_validation(true);
        validator.record_validation(true);
        validator.record_validation(false);

        assert_eq!(validator.stats().validations, 3);
        assert_eq!(validator.stats().failures, 1);
        assert!((validator.success_rate() - 2.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_stale_lease_cleaner() {
        let cleaner = StaleLeaseCleaner::new(Duration::from_millis(100));

        let client1 = ClientId::new();
        let client2 = ClientId::new();

        cleaner.record_activity(client1);
        cleaner.record_activity(client2);

        // No stale leases immediately
        assert!(cleaner.find_stale_leases().is_empty());

        // Wait past the idle limit and only refresh client1
        std::thread::sleep(Duration::from_millis(120));
        cleaner.record_activity(client1);

        // client2 should be stale
        let stale = cleaner.clean_stale();
        assert_eq!(stale, vec![client2]);
        assert_eq!(cleaner.cleaned_count(), 1);

        // A second sweep finds nothing new
        assert!(cleaner.clean_stale().is_empty());
        assert_eq!(cleaner.cleaned_count(), 1);
    }

    #[test]
    fn test_pool_exhaustion_monitor() {
        let monitor = PoolExhaustionMonitor::new(2, true);

        // First two requests OK
        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();
        assert!(monitor.check_capacity().is_ok());
        monitor.enter_queue();

        // Only the first waiter marks an exhaustion event
        assert_eq!(monitor.stats().exhaustion_events, 1);

        // Third should be rejected (backpressure)
        assert!(monitor.check_capacity().is_err());
        assert_eq!(monitor.stats().rejected_requests, 1);

        // Leave queue
        monitor.leave_queue();
        assert_eq!(monitor.queue_size(), 1);
        assert!(monitor.check_capacity().is_ok());
    }

    #[test]
    fn test_pool_hardening_combined() {
        let hardening = PoolHardening::default();

        // Run maintenance on empty state
        let (leaked, stale) = hardening.run_maintenance();
        assert!(leaked.is_empty());
        assert!(stale.is_empty());

        // Check stats
        let stats = hardening.stats();
        assert_eq!(stats.leak_stats.active_transactions, 0);
        assert_eq!(stats.stale_cleaned, 0);
    }
}