oxify_storage/
connection_leak_detector.rs

1//! Connection leak detection and monitoring
2//!
3//! Provides utilities to detect and monitor potential connection leaks in the database pool.
4//! This is critical for production debugging when connections are not being properly returned.
5//!
6//! # Features
7//!
8//! - Track connection acquisition and release
9//! - Detect long-lived connections that may indicate leaks
10//! - Report connection usage statistics
11//! - Integration with tracing for observability
12//!
13//! # Example
14//!
15//! ```ignore
16//! use oxify_storage::connection_leak_detector::{LeakDetector, LeakDetectorConfig};
17//!
18//! let config = LeakDetectorConfig {
19//!     leak_threshold_seconds: 300, // 5 minutes
20//!     check_interval_seconds: 60,  // Check every minute
21//!     max_tracked_connections: 1000,
22//! };
23//!
24//! let detector = LeakDetector::new(config);
25//!
26//! // Track connection acquisition
27//! let token = detector.track_acquisition("user_query");
28//!
29//! // Perform database operations...
30//!
31//! // Release tracking when done
32//! detector.track_release(token);
33//!
34//! // Get leak report
35//! let report = detector.get_leak_report();
36//! if !report.suspected_leaks.is_empty() {
//!     tracing::warn!("Detected {} potential connection leaks", report.suspected_leaks.len());
38//! }
39//! ```
40
41use std::collections::HashMap;
42use std::sync::{Arc, RwLock};
43use std::time::{Duration, Instant};
44
/// Tunable settings that control how the leak detector behaves.
#[derive(Debug, Clone)]
pub struct LeakDetectorConfig {
    /// Number of seconds a connection may be held before it is reported as a
    /// suspected leak.
    pub leak_threshold_seconds: u64,
    /// Number of seconds intended to elapse between successive leak checks.
    pub check_interval_seconds: u64,
    /// Upper bound on simultaneously tracked connections; acquisitions beyond
    /// this limit are not tracked, which keeps memory usage bounded.
    pub max_tracked_connections: usize,
}

impl Default for LeakDetectorConfig {
    /// Builds the stock configuration: a five-minute leak threshold, a
    /// one-minute check interval, and room for 1000 tracked connections.
    fn default() -> Self {
        let leak_threshold_seconds = 5 * 60;
        let check_interval_seconds = 60;
        let max_tracked_connections = 1000;

        LeakDetectorConfig {
            leak_threshold_seconds,
            check_interval_seconds,
            max_tracked_connections,
        }
    }
}
65
/// Opaque handle identifying one tracked connection acquisition.
///
/// The detector hands out tokens starting at 1. The value `0` is used as an
/// "untracked" marker: it is returned when the detector is at capacity and is
/// ignored when passed back on release.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ConnectionToken(u64);
69
/// Bookkeeping record kept for every connection that is currently tracked.
#[derive(Debug, Clone)]
struct ConnectionInfo {
    /// Monotonic timestamp taken at the moment the connection was acquired.
    acquired_at: Instant,
    /// Caller-supplied label describing the acquisition site
    /// (for example a function name or query type).
    context: String,
    /// Set once the connection has been reported as a suspected leak, so the
    /// same connection is only counted as a new leak one time.
    flagged_as_leak: bool,
}
80
81/// Connection leak detector
82#[derive(Clone)]
83pub struct LeakDetector {
84    config: LeakDetectorConfig,
85    state: Arc<RwLock<DetectorState>>,
86}
87
88struct DetectorState {
89    /// Currently tracked connections
90    tracked: HashMap<ConnectionToken, ConnectionInfo>,
91    /// Next token ID to assign
92    next_token: u64,
93    /// Statistics
94    stats: LeakStats,
95    /// Last time a leak check was performed
96    last_check: Instant,
97}
98
/// Counters describing how connections have been used and how many suspected
/// leaks have been seen.
#[derive(Debug, Clone, Default)]
pub struct LeakStats {
    /// Number of acquisitions that were tracked since the detector was created.
    pub total_tracked: u64,
    /// Number of tracked connections that were subsequently released.
    pub total_released: u64,
    /// Number of distinct connections that were flagged as suspected leaks.
    pub total_suspected_leaks: u64,
    /// Number of connections being tracked right now.
    pub active_connections: usize,
    /// Longest hold time recorded so far, in whole seconds.
    pub longest_connection_duration_secs: u64,
}
113
114/// Information about a suspected connection leak
115#[derive(Debug, Clone)]
116pub struct SuspectedLeak {
117    /// Token identifying the connection
118    pub token: ConnectionToken,
119    /// How long the connection has been held (in seconds)
120    pub duration_secs: u64,
121    /// Context where the connection was acquired
122    pub context: String,
123    /// When the connection was acquired
124    pub acquired_at: Instant,
125}
126
127/// Report of detected leaks and statistics
128#[derive(Debug, Clone)]
129pub struct LeakReport {
130    /// List of suspected leaks
131    pub suspected_leaks: Vec<SuspectedLeak>,
132    /// Overall statistics
133    pub stats: LeakStats,
134    /// When the report was generated
135    pub generated_at: Instant,
136}
137
138impl LeakDetector {
139    /// Create a new leak detector with the given configuration
140    pub fn new(config: LeakDetectorConfig) -> Self {
141        Self {
142            config,
143            state: Arc::new(RwLock::new(DetectorState {
144                tracked: HashMap::new(),
145                next_token: 1,
146                stats: LeakStats::default(),
147                last_check: Instant::now(),
148            })),
149        }
150    }
151
152    /// Create a new leak detector with default configuration
153    pub fn with_defaults() -> Self {
154        Self::new(LeakDetectorConfig::default())
155    }
156
157    /// Track a connection acquisition
158    ///
159    /// Returns a token that should be passed to `track_release()` when the connection is returned.
160    ///
161    /// # Arguments
162    ///
163    /// * `context` - Description of where/why the connection was acquired (e.g., "user_query", "workflow_update")
164    pub fn track_acquisition(&self, context: impl Into<String>) -> ConnectionToken {
165        let mut state = self
166            .state
167            .write()
168            .expect("Leak detector state lock poisoned");
169
170        // Check if we've exceeded max tracked connections
171        if state.tracked.len() >= self.config.max_tracked_connections {
172            tracing::warn!(
173                "Leak detector at maximum capacity ({}), not tracking new connection",
174                self.config.max_tracked_connections
175            );
176            return ConnectionToken(0); // Invalid token
177        }
178
179        let token = ConnectionToken(state.next_token);
180        state.next_token += 1;
181
182        let info = ConnectionInfo {
183            acquired_at: Instant::now(),
184            context: context.into(),
185            flagged_as_leak: false,
186        };
187
188        state.tracked.insert(token, info);
189        state.stats.total_tracked += 1;
190        state.stats.active_connections = state.tracked.len();
191
192        token
193    }
194
195    /// Track a connection release
196    ///
197    /// Should be called when a connection is returned to the pool.
198    ///
199    /// # Arguments
200    ///
201    /// * `token` - The token returned from `track_acquisition()`
202    pub fn track_release(&self, token: ConnectionToken) {
203        if token.0 == 0 {
204            return; // Invalid token
205        }
206
207        let mut state = self
208            .state
209            .write()
210            .expect("Leak detector state lock poisoned");
211
212        if let Some(info) = state.tracked.remove(&token) {
213            let duration = info.acquired_at.elapsed();
214            let duration_secs = duration.as_secs();
215
216            state.stats.total_released += 1;
217            state.stats.active_connections = state.tracked.len();
218
219            if duration_secs > state.stats.longest_connection_duration_secs {
220                state.stats.longest_connection_duration_secs = duration_secs;
221            }
222
223            if duration_secs > self.config.leak_threshold_seconds {
224                tracing::info!(
225                    context = %info.context,
226                    duration_secs,
227                    "Long-lived connection released (exceeded threshold)"
228                );
229            }
230        } else {
231            tracing::warn!(
232                token = token.0,
233                "Attempted to release connection that was not tracked"
234            );
235        }
236    }
237
238    /// Perform a leak check and return a report
239    ///
240    /// This scans all tracked connections and identifies those that have been held
241    /// longer than the configured threshold.
242    pub fn get_leak_report(&self) -> LeakReport {
243        let mut state = self
244            .state
245            .write()
246            .expect("Leak detector state lock poisoned");
247
248        let now = Instant::now();
249        let threshold = Duration::from_secs(self.config.leak_threshold_seconds);
250
251        let mut suspected_leaks = Vec::new();
252        let mut new_leaks_count = 0;
253
254        for (token, info) in &mut state.tracked {
255            let duration = now.duration_since(info.acquired_at);
256
257            if duration > threshold {
258                if !info.flagged_as_leak {
259                    info.flagged_as_leak = true;
260                    new_leaks_count += 1;
261
262                    tracing::warn!(
263                        token = token.0,
264                        context = %info.context,
265                        duration_secs = duration.as_secs(),
266                        "Suspected connection leak detected"
267                    );
268                }
269
270                suspected_leaks.push(SuspectedLeak {
271                    token: *token,
272                    duration_secs: duration.as_secs(),
273                    context: info.context.clone(),
274                    acquired_at: info.acquired_at,
275                });
276            }
277        }
278
279        state.stats.total_suspected_leaks += new_leaks_count;
280        state.last_check = now;
281
282        // Sort by duration (longest first)
283        suspected_leaks.sort_by(|a, b| b.duration_secs.cmp(&a.duration_secs));
284
285        LeakReport {
286            suspected_leaks,
287            stats: state.stats.clone(),
288            generated_at: now,
289        }
290    }
291
292    /// Check for leaks and log warnings if any are found
293    ///
294    /// This is a convenience method that gets a leak report and logs warnings
295    /// for each suspected leak.
296    pub fn check_and_log_leaks(&self) {
297        let report = self.get_leak_report();
298
299        if !report.suspected_leaks.is_empty() {
300            tracing::warn!(
301                "Detected {} suspected connection leaks",
302                report.suspected_leaks.len()
303            );
304
305            for leak in &report.suspected_leaks {
306                tracing::warn!(
307                    token = leak.token.0,
308                    context = %leak.context,
309                    duration_secs = leak.duration_secs,
310                    "Connection held for {} seconds",
311                    leak.duration_secs
312                );
313            }
314        }
315    }
316
317    /// Get current statistics
318    pub fn get_stats(&self) -> LeakStats {
319        let state = self
320            .state
321            .read()
322            .expect("Leak detector state lock poisoned");
323        state.stats.clone()
324    }
325
326    /// Clear all tracked connections
327    ///
328    /// This is useful for testing or resetting the detector state.
329    /// Use with caution in production.
330    pub fn clear(&self) {
331        let mut state = self
332            .state
333            .write()
334            .expect("Leak detector state lock poisoned");
335        state.tracked.clear();
336        state.stats.active_connections = 0;
337    }
338
339    /// Get the number of currently tracked connections
340    pub fn active_count(&self) -> usize {
341        let state = self
342            .state
343            .read()
344            .expect("Leak detector state lock poisoned");
345        state.tracked.len()
346    }
347}
348
349impl LeakReport {
350    /// Check if there are any suspected leaks
351    pub fn has_leaks(&self) -> bool {
352        !self.suspected_leaks.is_empty()
353    }
354
355    /// Get the number of suspected leaks
356    pub fn leak_count(&self) -> usize {
357        self.suspected_leaks.len()
358    }
359
360    /// Get contexts where leaks occurred (unique)
361    pub fn leak_contexts(&self) -> Vec<String> {
362        let mut contexts: Vec<String> = self
363            .suspected_leaks
364            .iter()
365            .map(|leak| leak.context.clone())
366            .collect();
367        contexts.sort();
368        contexts.dedup();
369        contexts
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// One acquire/release round trip updates the active count and counters.
    #[test]
    fn test_track_acquisition_and_release() {
        let leak_detector = LeakDetector::with_defaults();

        let handle = leak_detector.track_acquisition("test_query");
        assert_eq!(leak_detector.active_count(), 1);

        leak_detector.track_release(handle);
        assert_eq!(leak_detector.active_count(), 0);

        let LeakStats {
            total_tracked,
            total_released,
            ..
        } = leak_detector.get_stats();
        assert_eq!(total_tracked, 1);
        assert_eq!(total_released, 1);
    }

    /// Several connections can be tracked at once and released in any order.
    #[test]
    fn test_multiple_connections() {
        let leak_detector = LeakDetector::with_defaults();

        let first = leak_detector.track_acquisition("query1");
        let second = leak_detector.track_acquisition("query2");
        let third = leak_detector.track_acquisition("query3");
        assert_eq!(leak_detector.active_count(), 3);

        // Release out of acquisition order on purpose.
        leak_detector.track_release(second);
        assert_eq!(leak_detector.active_count(), 2);

        leak_detector.track_release(first);
        leak_detector.track_release(third);
        assert_eq!(leak_detector.active_count(), 0);
    }

    /// With a zero threshold, any connection held for a measurable time is
    /// reported as a suspected leak.
    #[test]
    fn test_leak_detection() {
        let leak_detector = LeakDetector::new(LeakDetectorConfig {
            leak_threshold_seconds: 0, // Everything held at all counts as a leak
            check_interval_seconds: 1,
            max_tracked_connections: 100,
        });

        let _held = leak_detector.track_acquisition("slow_query");

        // Let some time pass so the hold time is strictly above the threshold.
        thread::sleep(Duration::from_millis(100));

        let leak_report = leak_detector.get_leak_report();
        assert!(leak_report.has_leaks());
        assert_eq!(leak_report.leak_count(), 1);
        assert_eq!(
            leak_report.leak_contexts(),
            vec!["slow_query".to_string()]
        );
    }

    /// Acquisitions beyond the configured cap yield the reserved token 0 and
    /// are not tracked.
    #[test]
    fn test_max_tracked_connections() {
        let leak_detector = LeakDetector::new(LeakDetectorConfig {
            leak_threshold_seconds: 300,
            check_interval_seconds: 60,
            max_tracked_connections: 3,
        });

        let first = leak_detector.track_acquisition("query1");
        let second = leak_detector.track_acquisition("query2");
        let third = leak_detector.track_acquisition("query3");
        let overflow = leak_detector.track_acquisition("query4"); // Over the cap

        assert_eq!(leak_detector.active_count(), 3);
        assert_eq!(overflow.0, 0); // Reserved "not tracked" token

        leak_detector.track_release(first);
        leak_detector.track_release(second);
        leak_detector.track_release(third);
        leak_detector.track_release(overflow); // Must be a no-op

        assert_eq!(leak_detector.active_count(), 0);
    }

    /// `clear` drops every tracked connection.
    #[test]
    fn test_clear() {
        let leak_detector = LeakDetector::with_defaults();

        leak_detector.track_acquisition("query1");
        leak_detector.track_acquisition("query2");
        assert_eq!(leak_detector.active_count(), 2);

        leak_detector.clear();
        assert_eq!(leak_detector.active_count(), 0);
    }

    /// Releasing an unknown token must not panic; it only logs a warning.
    #[test]
    fn test_release_untracked_connection() {
        let leak_detector = LeakDetector::with_defaults();

        // This token was never issued by the detector.
        leak_detector.track_release(ConnectionToken(9999));

        assert_eq!(leak_detector.active_count(), 0);
    }

    /// Counters follow each acquisition and release step by step.
    #[test]
    fn test_stats_tracking() {
        let leak_detector = LeakDetector::with_defaults();

        let first = leak_detector.track_acquisition("query1");
        let second = leak_detector.track_acquisition("query2");

        let after_acquire = leak_detector.get_stats();
        assert_eq!(after_acquire.total_tracked, 2);
        assert_eq!(after_acquire.active_connections, 2);
        assert_eq!(after_acquire.total_released, 0);

        leak_detector.track_release(first);

        let after_first_release = leak_detector.get_stats();
        assert_eq!(after_first_release.total_released, 1);
        assert_eq!(after_first_release.active_connections, 1);

        leak_detector.track_release(second);

        let after_second_release = leak_detector.get_stats();
        assert_eq!(after_second_release.total_released, 2);
        assert_eq!(after_second_release.active_connections, 0);
    }
}
505}