// ant_quic/relay/statistics.rs

//! Comprehensive relay statistics collection and aggregation.

use std::collections::HashMap;
use std::sync::{Arc, Mutex, MutexGuard};
use std::time::{Duration, Instant};

use super::{
    AuthenticationStatistics, ConnectionStatistics, ErrorStatistics, RateLimitingStatistics,
    RelayConnection, RelayStatistics, SessionManager, SessionStatistics,
};
use crate::endpoint::RelayStats;

12/// Comprehensive relay statistics collector that aggregates stats from all relay components
13#[derive(Debug, Clone)]
14pub struct RelayStatisticsCollector {
15    /// Basic relay queue statistics
16    queue_stats: Arc<Mutex<RelayStats>>,
17
18    /// Session managers being tracked
19    session_managers: Arc<Mutex<Vec<Arc<SessionManager>>>>,
20
21    /// Connection tracking
22    connections: Arc<Mutex<HashMap<u32, Arc<RelayConnection>>>>,
23
24    /// Error tracking
25    error_counts: Arc<Mutex<HashMap<String, u64>>>,
26
27    /// Authentication tracking
28    auth_stats: Arc<Mutex<AuthenticationStatistics>>,
29
30    /// Rate limiting tracking
31    rate_limit_stats: Arc<Mutex<RateLimitingStatistics>>,
32
33    /// Collection start time for rate calculations
34    start_time: Instant,
35
36    /// Last statistics snapshot
37    last_snapshot: Arc<Mutex<RelayStatistics>>,
38}
39
40impl RelayStatisticsCollector {
41    /// Create a new statistics collector
42    pub fn new() -> Self {
43        Self {
44            queue_stats: Arc::new(Mutex::new(RelayStats::default())),
45            session_managers: Arc::new(Mutex::new(Vec::new())),
46            connections: Arc::new(Mutex::new(HashMap::new())),
47            error_counts: Arc::new(Mutex::new(HashMap::new())),
48            auth_stats: Arc::new(Mutex::new(AuthenticationStatistics::default())),
49            rate_limit_stats: Arc::new(Mutex::new(RateLimitingStatistics::default())),
50            start_time: Instant::now(),
51            last_snapshot: Arc::new(Mutex::new(RelayStatistics::default())),
52        }
53    }
54
55    /// Register a session manager for statistics collection
56    pub fn register_session_manager(&self, session_manager: Arc<SessionManager>) {
57        let mut managers = self.session_managers.lock().unwrap();
58        managers.push(session_manager);
59    }
60
61    /// Register a relay connection for statistics collection  
62    pub fn register_connection(&self, session_id: u32, connection: Arc<RelayConnection>) {
63        let mut connections = self.connections.lock().unwrap();
64        connections.insert(session_id, connection);
65    }
66
67    /// Unregister a relay connection
68    pub fn unregister_connection(&self, session_id: u32) {
69        let mut connections = self.connections.lock().unwrap();
70        connections.remove(&session_id);
71    }
72
73    /// Update queue statistics (called from endpoint)
74    pub fn update_queue_stats(&self, stats: &RelayStats) {
75        let mut queue_stats = self.queue_stats.lock().unwrap();
76        *queue_stats = stats.clone();
77    }
78
79    /// Record an authentication attempt
80    pub fn record_auth_attempt(&self, success: bool, error: Option<&str>) {
81        let mut auth_stats = self.auth_stats.lock().unwrap();
82        auth_stats.total_auth_attempts += 1;
83
84        if success {
85            auth_stats.successful_auths += 1;
86        } else {
87            auth_stats.failed_auths += 1;
88
89            if let Some(error_msg) = error {
90                if error_msg.contains("replay") {
91                    auth_stats.replay_attacks_blocked += 1;
92                } else if error_msg.contains("signature") {
93                    auth_stats.invalid_signatures += 1;
94                } else if error_msg.contains("unknown") || error_msg.contains("trusted") {
95                    auth_stats.unknown_peer_keys += 1;
96                }
97            }
98        }
99
100        // Update auth rate (auth attempts per second)
101        let elapsed = self.start_time.elapsed().as_secs_f64();
102        if elapsed > 0.0 {
103            auth_stats.auth_rate = auth_stats.total_auth_attempts as f64 / elapsed;
104        }
105    }
106
107    /// Record a rate limiting decision
108    pub fn record_rate_limit(&self, allowed: bool) {
109        let mut rate_stats = self.rate_limit_stats.lock().unwrap();
110        rate_stats.total_requests += 1;
111
112        if allowed {
113            rate_stats.requests_allowed += 1;
114        } else {
115            rate_stats.requests_blocked += 1;
116        }
117
118        // Update efficiency percentage
119        if rate_stats.total_requests > 0 {
120            rate_stats.efficiency_percentage =
121                (rate_stats.requests_allowed as f64 / rate_stats.total_requests as f64) * 100.0;
122        }
123    }
124
125    /// Record an error occurrence
126    pub fn record_error(&self, error_type: &str) {
127        let mut error_counts = self.error_counts.lock().unwrap();
128        *error_counts.entry(error_type.to_string()).or_insert(0) += 1;
129    }
130
131    /// Collect comprehensive statistics from all sources
132    pub fn collect_statistics(&self) -> RelayStatistics {
133        let session_stats = self.collect_session_statistics();
134        let connection_stats = self.collect_connection_statistics();
135        let auth_stats = self.auth_stats.lock().unwrap().clone();
136        let rate_limit_stats = self.rate_limit_stats.lock().unwrap().clone();
137        let error_stats = self.collect_error_statistics();
138
139        let stats = RelayStatistics {
140            session_stats,
141            connection_stats,
142            auth_stats,
143            rate_limit_stats,
144            error_stats,
145        };
146
147        // Update last snapshot
148        {
149            let mut last_snapshot = self.last_snapshot.lock().unwrap();
150            *last_snapshot = stats.clone();
151        }
152
153        stats
154    }
155
156    /// Get the last collected statistics snapshot
157    pub fn get_last_snapshot(&self) -> RelayStatistics {
158        self.last_snapshot.lock().unwrap().clone()
159    }
160
161    /// Collect session statistics from all registered session managers
162    fn collect_session_statistics(&self) -> SessionStatistics {
163        let managers = self.session_managers.lock().unwrap();
164        let mut total_stats = SessionStatistics::default();
165
166        for manager in managers.iter() {
167            let mgr_stats = manager.get_statistics();
168
169            // Aggregate session counts
170            total_stats.active_sessions += mgr_stats.active_sessions as u32;
171            total_stats.pending_sessions += mgr_stats.pending_sessions as u32;
172            total_stats.total_bytes_forwarded +=
173                mgr_stats.total_bytes_sent + mgr_stats.total_bytes_received;
174
175            // For derived stats, we take the maximum or average as appropriate
176            if mgr_stats.total_sessions > 0 {
177                total_stats.total_sessions_created += mgr_stats.total_sessions as u64;
178            }
179        }
180
181        // Calculate average session duration if we have historical data
182        // This would need to be tracked over time in a real implementation
183        let elapsed = self.start_time.elapsed().as_secs_f64();
184        if total_stats.total_sessions_created > 0 && elapsed > 0.0 {
185            total_stats.avg_session_duration = elapsed / total_stats.total_sessions_created as f64;
186        }
187
188        total_stats
189    }
190
191    /// Collect connection statistics from all registered connections
192    fn collect_connection_statistics(&self) -> ConnectionStatistics {
193        let connections = self.connections.lock().unwrap();
194        let mut total_stats = ConnectionStatistics::default();
195
196        total_stats.total_connections = connections.len() as u64;
197
198        for connection in connections.values() {
199            let conn_stats = connection.get_stats();
200
201            if conn_stats.is_active {
202                total_stats.active_connections += 1;
203            }
204
205            total_stats.total_bytes_sent += conn_stats.bytes_sent;
206            total_stats.total_bytes_received += conn_stats.bytes_received;
207        }
208
209        // Calculate average bandwidth usage
210        let elapsed = self.start_time.elapsed().as_secs_f64();
211        if elapsed > 0.0 {
212            let total_bytes = total_stats.total_bytes_sent + total_stats.total_bytes_received;
213            total_stats.avg_bandwidth_usage = total_bytes as f64 / elapsed;
214        }
215
216        // Peak concurrent connections would need to be tracked over time
217        total_stats.peak_concurrent_connections = total_stats.active_connections;
218
219        total_stats
220    }
221
222    /// Collect error statistics
223    fn collect_error_statistics(&self) -> ErrorStatistics {
224        let error_counts = self.error_counts.lock().unwrap();
225        let queue_stats = self.queue_stats.lock().unwrap();
226
227        let mut error_stats = ErrorStatistics::default();
228        error_stats.error_breakdown = error_counts.clone();
229
230        // Categorize errors
231        for (error_type, count) in error_counts.iter() {
232            if error_type.contains("protocol") || error_type.contains("frame") {
233                error_stats.protocol_errors += count;
234            } else if error_type.contains("resource") || error_type.contains("exhausted") {
235                error_stats.resource_exhausted += count;
236            } else if error_type.contains("session") {
237                error_stats.session_errors += count;
238            } else if error_type.contains("auth") {
239                error_stats.auth_failures += count;
240            } else if error_type.contains("network") || error_type.contains("connection") {
241                error_stats.network_errors += count;
242            } else {
243                error_stats.internal_errors += count;
244            }
245        }
246
247        // Add queue-related failures
248        error_stats.resource_exhausted += queue_stats.requests_dropped;
249        error_stats.protocol_errors += queue_stats.requests_failed;
250
251        // Calculate error rate
252        let total_errors = error_stats.protocol_errors
253            + error_stats.resource_exhausted
254            + error_stats.session_errors
255            + error_stats.auth_failures
256            + error_stats.network_errors
257            + error_stats.internal_errors;
258
259        let elapsed = self.start_time.elapsed().as_secs_f64();
260        if elapsed > 0.0 {
261            error_stats.error_rate = total_errors as f64 / elapsed;
262        }
263
264        error_stats
265    }
266
267    /// Reset all statistics (useful for testing)
268    pub fn reset(&self) {
269        {
270            let mut queue_stats = self.queue_stats.lock().unwrap();
271            *queue_stats = RelayStats::default();
272        }
273        {
274            let mut error_counts = self.error_counts.lock().unwrap();
275            error_counts.clear();
276        }
277        {
278            let mut auth_stats = self.auth_stats.lock().unwrap();
279            *auth_stats = AuthenticationStatistics::default();
280        }
281        {
282            let mut rate_limit_stats = self.rate_limit_stats.lock().unwrap();
283            *rate_limit_stats = RateLimitingStatistics::default();
284        }
285    }
286}
287
288impl Default for RelayStatisticsCollector {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_statistics_collector_creation() {
        let collector = RelayStatisticsCollector::new();
        let stats = collector.collect_statistics();

        // Should start with empty statistics
        assert_eq!(stats.session_stats.active_sessions, 0);
        assert_eq!(stats.connection_stats.total_connections, 0);
        assert_eq!(stats.auth_stats.total_auth_attempts, 0);
        assert!(stats.is_healthy());
    }

    #[test]
    fn test_auth_tracking() {
        let collector = RelayStatisticsCollector::new();

        // Record some authentication attempts
        collector.record_auth_attempt(true, None);
        collector.record_auth_attempt(false, Some("signature verification failed"));
        collector.record_auth_attempt(false, Some("replay attack detected"));

        let stats = collector.collect_statistics();
        assert_eq!(stats.auth_stats.total_auth_attempts, 3);
        assert_eq!(stats.auth_stats.successful_auths, 1);
        assert_eq!(stats.auth_stats.failed_auths, 2);
        assert_eq!(stats.auth_stats.invalid_signatures, 1);
        assert_eq!(stats.auth_stats.replay_attacks_blocked, 1);
    }

    #[test]
    fn test_unknown_peer_key_tracking() {
        let collector = RelayStatisticsCollector::new();

        // Both "unknown" and "trusted" map to the unknown-peer-key bucket;
        // unrecognised messages only count as a plain failure.
        collector.record_auth_attempt(false, Some("unknown peer key"));
        collector.record_auth_attempt(false, Some("peer is not trusted"));
        collector.record_auth_attempt(false, Some("malformed request"));

        let stats = collector.collect_statistics();
        assert_eq!(stats.auth_stats.failed_auths, 3);
        assert_eq!(stats.auth_stats.unknown_peer_keys, 2);
        assert_eq!(stats.auth_stats.invalid_signatures, 0);
        assert_eq!(stats.auth_stats.replay_attacks_blocked, 0);
    }

    #[test]
    fn test_rate_limiting_tracking() {
        let collector = RelayStatisticsCollector::new();

        // Record some rate limiting decisions
        collector.record_rate_limit(true);
        collector.record_rate_limit(true);
        collector.record_rate_limit(false);
        collector.record_rate_limit(true);

        let stats = collector.collect_statistics();
        assert_eq!(stats.rate_limit_stats.total_requests, 4);
        assert_eq!(stats.rate_limit_stats.requests_allowed, 3);
        assert_eq!(stats.rate_limit_stats.requests_blocked, 1);
        assert_eq!(stats.rate_limit_stats.efficiency_percentage, 75.0);
    }

    #[test]
    fn test_error_tracking() {
        let collector = RelayStatisticsCollector::new();

        // Record various errors
        collector.record_error("protocol_error");
        collector.record_error("resource_exhausted");
        collector.record_error("session_timeout");
        collector.record_error("auth_failed");

        let stats = collector.collect_statistics();
        assert_eq!(stats.error_stats.protocol_errors, 1);
        assert_eq!(stats.error_stats.resource_exhausted, 1);
        assert_eq!(stats.error_stats.session_errors, 1);
        assert_eq!(stats.error_stats.auth_failures, 1);
        assert_eq!(stats.error_stats.error_breakdown.len(), 4);
    }

    #[test]
    fn test_network_and_internal_error_categories() {
        let collector = RelayStatisticsCollector::new();

        collector.record_error("network_unreachable");
        collector.record_error("connection_reset");
        collector.record_error("unexpected_state");
        collector.record_error("unexpected_state");

        let stats = collector.collect_statistics();
        assert_eq!(stats.error_stats.network_errors, 2);
        // Anything matching no known keyword falls into the internal bucket.
        assert_eq!(stats.error_stats.internal_errors, 2);
        assert_eq!(stats.error_stats.error_breakdown.len(), 3);
        assert_eq!(
            stats.error_stats.error_breakdown.get("unexpected_state"),
            Some(&2)
        );
    }

    #[test]
    fn test_last_snapshot_matches_latest_collection() {
        let collector = RelayStatisticsCollector::new();
        assert_eq!(collector.get_last_snapshot().auth_stats.total_auth_attempts, 0);

        // Recording alone does not refresh the snapshot...
        collector.record_auth_attempt(true, None);
        assert_eq!(collector.get_last_snapshot().auth_stats.total_auth_attempts, 0);

        // ...collecting does.
        let _ = collector.collect_statistics();
        assert_eq!(collector.get_last_snapshot().auth_stats.total_auth_attempts, 1);
    }

    #[test]
    fn test_success_rate_calculation() {
        let collector = RelayStatisticsCollector::new();

        // Record some successful operations
        collector.record_auth_attempt(true, None);
        collector.record_auth_attempt(true, None);
        collector.record_rate_limit(true);
        collector.record_rate_limit(true);

        // Record some failures
        collector.record_auth_attempt(false, None);
        collector.record_error("protocol_error");

        let stats = collector.collect_statistics();

        // Should have a good success rate but not perfect due to failures
        let success_rate = stats.success_rate();
        assert!(success_rate > 0.5);
        assert!(success_rate < 1.0);
    }

    #[test]
    fn test_reset_functionality() {
        let collector = RelayStatisticsCollector::new();

        // Add some data
        collector.record_auth_attempt(true, None);
        collector.record_error("test_error");
        collector.record_rate_limit(false);

        // Verify data exists
        let stats_before = collector.collect_statistics();
        assert!(stats_before.auth_stats.total_auth_attempts > 0);
        assert!(!stats_before.error_stats.error_breakdown.is_empty());

        // Reset and verify clean state
        collector.reset();
        let stats_after = collector.collect_statistics();
        assert_eq!(stats_after.auth_stats.total_auth_attempts, 0);
        assert_eq!(stats_after.rate_limit_stats.total_requests, 0);
        assert_eq!(stats_after.rate_limit_stats.requests_blocked, 0);
        assert!(stats_after.error_stats.error_breakdown.is_empty());
    }
}