ant_quic/relay/
statistics.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! Comprehensive relay statistics collection and aggregation.
9
10use super::{
11    AuthenticationStatistics, ConnectionStatistics, ErrorStatistics, RateLimitingStatistics,
12    RelayConnection, RelayStatistics, SessionManager, SessionStatistics,
13};
14use crate::endpoint::RelayStats;
15use std::collections::HashMap;
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant};
18
19/// Comprehensive relay statistics collector that aggregates stats from all relay components
20#[derive(Debug, Clone)]
21pub struct RelayStatisticsCollector {
22    /// Basic relay queue statistics
23    queue_stats: Arc<Mutex<RelayStats>>,
24
25    /// Session managers being tracked
26    session_managers: Arc<Mutex<Vec<Arc<SessionManager>>>>,
27
28    /// Connection tracking
29    connections: Arc<Mutex<HashMap<u32, Arc<RelayConnection>>>>,
30
31    /// Error tracking
32    error_counts: Arc<Mutex<HashMap<String, u64>>>,
33
34    /// Authentication tracking
35    auth_stats: Arc<Mutex<AuthenticationStatistics>>,
36
37    /// Rate limiting tracking
38    rate_limit_stats: Arc<Mutex<RateLimitingStatistics>>,
39
40    /// Collection start time for rate calculations
41    start_time: Instant,
42
43    /// Last statistics snapshot
44    last_snapshot: Arc<Mutex<RelayStatistics>>,
45}
46
47impl RelayStatisticsCollector {
48    /// Create a new statistics collector
49    pub fn new() -> Self {
50        Self {
51            queue_stats: Arc::new(Mutex::new(RelayStats::default())),
52            session_managers: Arc::new(Mutex::new(Vec::new())),
53            connections: Arc::new(Mutex::new(HashMap::new())),
54            error_counts: Arc::new(Mutex::new(HashMap::new())),
55            auth_stats: Arc::new(Mutex::new(AuthenticationStatistics::default())),
56            rate_limit_stats: Arc::new(Mutex::new(RateLimitingStatistics::default())),
57            start_time: Instant::now(),
58            last_snapshot: Arc::new(Mutex::new(RelayStatistics::default())),
59        }
60    }
61
62    /// Register a session manager for statistics collection
63    #[allow(clippy::unwrap_used)]
64    pub fn register_session_manager(&self, session_manager: Arc<SessionManager>) {
65        let mut managers = self.session_managers.lock().unwrap();
66        managers.push(session_manager);
67    }
68
69    /// Register a relay connection for statistics collection
70    #[allow(clippy::unwrap_used)]
71    pub fn register_connection(&self, session_id: u32, connection: Arc<RelayConnection>) {
72        let mut connections = self.connections.lock().unwrap();
73        connections.insert(session_id, connection);
74    }
75
76    /// Unregister a relay connection
77    #[allow(clippy::unwrap_used)]
78    pub fn unregister_connection(&self, session_id: u32) {
79        let mut connections = self.connections.lock().unwrap();
80        connections.remove(&session_id);
81    }
82
83    /// Update queue statistics (called from endpoint)
84    #[allow(clippy::unwrap_used)]
85    pub fn update_queue_stats(&self, stats: &RelayStats) {
86        let mut queue_stats = self.queue_stats.lock().unwrap();
87        *queue_stats = stats.clone();
88    }
89
90    /// Record an authentication attempt
91    #[allow(clippy::unwrap_used)]
92    pub fn record_auth_attempt(&self, success: bool, error: Option<&str>) {
93        let mut auth_stats = self.auth_stats.lock().unwrap();
94        auth_stats.total_auth_attempts += 1;
95
96        if success {
97            auth_stats.successful_auths += 1;
98        } else {
99            auth_stats.failed_auths += 1;
100
101            if let Some(error_msg) = error {
102                if error_msg.contains("replay") {
103                    auth_stats.replay_attacks_blocked += 1;
104                } else if error_msg.contains("signature") {
105                    auth_stats.invalid_signatures += 1;
106                } else if error_msg.contains("unknown") || error_msg.contains("trusted") {
107                    auth_stats.unknown_peer_keys += 1;
108                }
109            }
110        }
111
112        // Update auth rate (auth attempts per second)
113        let elapsed = self.start_time.elapsed().as_secs_f64();
114        if elapsed > 0.0 {
115            auth_stats.auth_rate = auth_stats.total_auth_attempts as f64 / elapsed;
116        }
117    }
118
119    /// Record a rate limiting decision
120    #[allow(clippy::unwrap_used)]
121    pub fn record_rate_limit(&self, allowed: bool) {
122        let mut rate_stats = self.rate_limit_stats.lock().unwrap();
123        rate_stats.total_requests += 1;
124
125        if allowed {
126            rate_stats.requests_allowed += 1;
127        } else {
128            rate_stats.requests_blocked += 1;
129        }
130
131        // Update efficiency percentage
132        if rate_stats.total_requests > 0 {
133            rate_stats.efficiency_percentage =
134                (rate_stats.requests_allowed as f64 / rate_stats.total_requests as f64) * 100.0;
135        }
136    }
137
138    /// Record an error occurrence
139    #[allow(clippy::unwrap_used)]
140    pub fn record_error(&self, error_type: &str) {
141        let mut error_counts = self.error_counts.lock().unwrap();
142        *error_counts.entry(error_type.to_string()).or_insert(0) += 1;
143    }
144
145    /// Collect comprehensive statistics from all sources
146    #[allow(clippy::unwrap_used)]
147    pub fn collect_statistics(&self) -> RelayStatistics {
148        let session_stats = self.collect_session_statistics();
149        let connection_stats = self.collect_connection_statistics();
150        let auth_stats = self.auth_stats.lock().unwrap().clone();
151        let rate_limit_stats = self.rate_limit_stats.lock().unwrap().clone();
152        let error_stats = self.collect_error_statistics();
153
154        let stats = RelayStatistics {
155            session_stats,
156            connection_stats,
157            auth_stats,
158            rate_limit_stats,
159            error_stats,
160        };
161
162        // Update last snapshot
163        {
164            let mut last_snapshot = self.last_snapshot.lock().unwrap();
165            *last_snapshot = stats.clone();
166        }
167
168        stats
169    }
170
171    /// Get the last collected statistics snapshot
172    #[allow(clippy::unwrap_used)]
173    pub fn get_last_snapshot(&self) -> RelayStatistics {
174        self.last_snapshot.lock().unwrap().clone()
175    }
176
177    /// Collect session statistics from all registered session managers
178    #[allow(clippy::unwrap_used)]
179    fn collect_session_statistics(&self) -> SessionStatistics {
180        let managers = self.session_managers.lock().unwrap();
181        let mut total_stats = SessionStatistics::default();
182
183        for manager in managers.iter() {
184            let mgr_stats = manager.get_statistics();
185
186            // Aggregate session counts
187            total_stats.active_sessions += mgr_stats.active_sessions as u32;
188            total_stats.pending_sessions += mgr_stats.pending_sessions as u32;
189            total_stats.total_bytes_forwarded +=
190                mgr_stats.total_bytes_sent + mgr_stats.total_bytes_received;
191
192            // For derived stats, we take the maximum or average as appropriate
193            if mgr_stats.total_sessions > 0 {
194                total_stats.total_sessions_created += mgr_stats.total_sessions as u64;
195            }
196        }
197
198        // Calculate average session duration if we have historical data
199        // This would need to be tracked over time in a real implementation
200        let elapsed = self.start_time.elapsed().as_secs_f64();
201        if total_stats.total_sessions_created > 0 && elapsed > 0.0 {
202            total_stats.avg_session_duration = elapsed / total_stats.total_sessions_created as f64;
203        }
204
205        total_stats
206    }
207
208    /// Collect connection statistics from all registered connections
209    #[allow(clippy::unwrap_used)]
210    fn collect_connection_statistics(&self) -> ConnectionStatistics {
211        let connections = self.connections.lock().unwrap();
212        let mut total_stats = ConnectionStatistics::default();
213
214        total_stats.total_connections = connections.len() as u64;
215
216        for connection in connections.values() {
217            let conn_stats = connection.get_stats();
218
219            if conn_stats.is_active {
220                total_stats.active_connections += 1;
221            }
222
223            total_stats.total_bytes_sent += conn_stats.bytes_sent;
224            total_stats.total_bytes_received += conn_stats.bytes_received;
225        }
226
227        // Calculate average bandwidth usage
228        let elapsed = self.start_time.elapsed().as_secs_f64();
229        if elapsed > 0.0 {
230            let total_bytes = total_stats.total_bytes_sent + total_stats.total_bytes_received;
231            total_stats.avg_bandwidth_usage = total_bytes as f64 / elapsed;
232        }
233
234        // Peak concurrent connections would need to be tracked over time
235        total_stats.peak_concurrent_connections = total_stats.active_connections;
236
237        total_stats
238    }
239
240    /// Collect error statistics
241    #[allow(clippy::unwrap_used)]
242    fn collect_error_statistics(&self) -> ErrorStatistics {
243        let error_counts = self.error_counts.lock().unwrap();
244        let queue_stats = self.queue_stats.lock().unwrap();
245
246        let mut error_stats = ErrorStatistics::default();
247        error_stats.error_breakdown = error_counts.clone();
248
249        // Categorize errors
250        for (error_type, count) in error_counts.iter() {
251            if error_type.contains("protocol") || error_type.contains("frame") {
252                error_stats.protocol_errors += count;
253            } else if error_type.contains("resource") || error_type.contains("exhausted") {
254                error_stats.resource_exhausted += count;
255            } else if error_type.contains("session") {
256                error_stats.session_errors += count;
257            } else if error_type.contains("auth") {
258                error_stats.auth_failures += count;
259            } else if error_type.contains("network") || error_type.contains("connection") {
260                error_stats.network_errors += count;
261            } else {
262                error_stats.internal_errors += count;
263            }
264        }
265
266        // Add queue-related failures
267        error_stats.resource_exhausted += queue_stats.requests_dropped;
268        error_stats.protocol_errors += queue_stats.requests_failed;
269
270        // Calculate error rate
271        let total_errors = error_stats.protocol_errors
272            + error_stats.resource_exhausted
273            + error_stats.session_errors
274            + error_stats.auth_failures
275            + error_stats.network_errors
276            + error_stats.internal_errors;
277
278        let elapsed = self.start_time.elapsed().as_secs_f64();
279        if elapsed > 0.0 {
280            error_stats.error_rate = total_errors as f64 / elapsed;
281        }
282
283        error_stats
284    }
285
286    /// Reset all statistics (useful for testing)
287    #[allow(clippy::unwrap_used)]
288    pub fn reset(&self) {
289        {
290            let mut queue_stats = self.queue_stats.lock().unwrap();
291            *queue_stats = RelayStats::default();
292        }
293        {
294            let mut error_counts = self.error_counts.lock().unwrap();
295            error_counts.clear();
296        }
297        {
298            let mut auth_stats = self.auth_stats.lock().unwrap();
299            *auth_stats = AuthenticationStatistics::default();
300        }
301        {
302            let mut rate_limit_stats = self.rate_limit_stats.lock().unwrap();
303            *rate_limit_stats = RateLimitingStatistics::default();
304        }
305    }
306}
307
308impl Default for RelayStatisticsCollector {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
314#[cfg(test)]
315mod tests {
316    use super::*;
317
318    #[test]
319    fn test_statistics_collector_creation() {
320        let collector = RelayStatisticsCollector::new();
321        let stats = collector.collect_statistics();
322
323        // Should start with empty statistics
324        assert_eq!(stats.session_stats.active_sessions, 0);
325        assert_eq!(stats.connection_stats.total_connections, 0);
326        assert_eq!(stats.auth_stats.total_auth_attempts, 0);
327        assert!(stats.is_healthy());
328    }
329
330    #[test]
331    fn test_auth_tracking() {
332        let collector = RelayStatisticsCollector::new();
333
334        // Record some authentication attempts
335        collector.record_auth_attempt(true, None);
336        collector.record_auth_attempt(false, Some("signature verification failed"));
337        collector.record_auth_attempt(false, Some("replay attack detected"));
338
339        let stats = collector.collect_statistics();
340        assert_eq!(stats.auth_stats.total_auth_attempts, 3);
341        assert_eq!(stats.auth_stats.successful_auths, 1);
342        assert_eq!(stats.auth_stats.failed_auths, 2);
343        assert_eq!(stats.auth_stats.invalid_signatures, 1);
344        assert_eq!(stats.auth_stats.replay_attacks_blocked, 1);
345    }
346
347    #[test]
348    fn test_rate_limiting_tracking() {
349        let collector = RelayStatisticsCollector::new();
350
351        // Record some rate limiting decisions
352        collector.record_rate_limit(true);
353        collector.record_rate_limit(true);
354        collector.record_rate_limit(false);
355        collector.record_rate_limit(true);
356
357        let stats = collector.collect_statistics();
358        assert_eq!(stats.rate_limit_stats.total_requests, 4);
359        assert_eq!(stats.rate_limit_stats.requests_allowed, 3);
360        assert_eq!(stats.rate_limit_stats.requests_blocked, 1);
361        assert_eq!(stats.rate_limit_stats.efficiency_percentage, 75.0);
362    }
363
364    #[test]
365    fn test_error_tracking() {
366        let collector = RelayStatisticsCollector::new();
367
368        // Record various errors
369        collector.record_error("protocol_error");
370        collector.record_error("resource_exhausted");
371        collector.record_error("session_timeout");
372        collector.record_error("auth_failed");
373
374        let stats = collector.collect_statistics();
375        assert_eq!(stats.error_stats.protocol_errors, 1);
376        assert_eq!(stats.error_stats.resource_exhausted, 1);
377        assert_eq!(stats.error_stats.session_errors, 1);
378        assert_eq!(stats.error_stats.auth_failures, 1);
379        assert_eq!(stats.error_stats.error_breakdown.len(), 4);
380    }
381
382    #[test]
383    fn test_success_rate_calculation() {
384        let collector = RelayStatisticsCollector::new();
385
386        // Record more successful operations to ensure > 50% success rate
387        collector.record_auth_attempt(true, None);
388        collector.record_auth_attempt(true, None);
389        collector.record_auth_attempt(true, None);
390        collector.record_auth_attempt(true, None);
391
392        // Note: record_rate_limit doesn't affect the success_rate calculation
393        // as it's not counted in total_ops
394        collector.record_rate_limit(true);
395        collector.record_rate_limit(true);
396
397        // Record some failures (but less than successes)
398        collector.record_auth_attempt(false, None);
399        collector.record_error("protocol_error");
400
401        let stats = collector.collect_statistics();
402
403        // Should have a good success rate but not perfect due to failures
404        let success_rate = stats.success_rate();
405        assert!(success_rate > 0.5);
406        assert!(success_rate < 1.0);
407    }
408
409    #[test]
410    fn test_reset_functionality() {
411        let collector = RelayStatisticsCollector::new();
412
413        // Add some data
414        collector.record_auth_attempt(true, None);
415        collector.record_error("test_error");
416        collector.record_rate_limit(false);
417
418        // Verify data exists
419        let stats_before = collector.collect_statistics();
420        assert!(stats_before.auth_stats.total_auth_attempts > 0);
421
422        // Reset and verify clean state
423        collector.reset();
424        let stats_after = collector.collect_statistics();
425        assert_eq!(stats_after.auth_stats.total_auth_attempts, 0);
426        assert_eq!(stats_after.rate_limit_stats.total_requests, 0);
427    }
428}