modbus_relay/
stats_manager.rs

1use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::SystemTime};
2
3use tokio::sync::{Mutex, mpsc};
4use tracing::{debug, info, warn};
5
6use crate::{ClientStats, ConnectionStats, config::StatsConfig, connection::StatEvent};
7
8pub struct StatsManager {
9    stats: Arc<Mutex<HashMap<SocketAddr, ClientStats>>>,
10    event_rx: mpsc::Receiver<StatEvent>,
11    config: StatsConfig,
12    total_connections: u64,
13}
14
15impl StatsManager {
16    pub fn new(config: StatsConfig) -> (Self, mpsc::Sender<StatEvent>) {
17        let (tx, rx) = mpsc::channel(config.max_events_per_second as usize);
18
19        let manager = Self {
20            stats: Arc::new(Mutex::new(HashMap::new())),
21            event_rx: rx,
22            config,
23            total_connections: 0,
24        };
25
26        (manager, tx)
27    }
28
29    pub async fn run(&mut self, mut shutdown_rx: tokio::sync::watch::Receiver<bool>) {
30        let mut cleanup_interval = tokio::time::interval(self.config.cleanup_interval);
31
32        loop {
33            tokio::select! {
34                shutdown = shutdown_rx.changed() => {
35                    match shutdown {
36                        Ok(_) => {
37                            info!("Stats manager shutting down");
38                            // Ensure all events are processed before shutting down
39                            while let Ok(event) = self.event_rx.try_recv() {
40                                self.handle_event(event).await;
41                            }
42                            break;
43                        }
44                        Err(e) => {
45                            warn!("Shutdown channel closed: {}", e);
46                            break;
47                        }
48                    }
49                }
50
51                Some(event) = self.event_rx.recv() => {
52                    self.handle_event(event).await;
53                }
54
55                _ = cleanup_interval.tick() => {
56                    self.cleanup_idle_stats().await;
57                }
58            }
59        }
60
61        info!("Stats manager shutdown complete");
62    }
63
64    async fn handle_event(&mut self, event: StatEvent) {
65        let mut stats = self.stats.lock().await;
66
67        match event {
68            StatEvent::ClientConnected(addr) => {
69                let client_stats = stats.entry(addr).or_default();
70                client_stats.active_connections = client_stats.active_connections.saturating_add(1);
71                client_stats.last_active = SystemTime::now();
72                self.total_connections = self.total_connections.saturating_add(1);
73                debug!("Client connected from {}", addr);
74            }
75
76            StatEvent::ClientDisconnected(addr) => {
77                if let Some(client_stats) = stats.get_mut(&addr) {
78                    client_stats.active_connections =
79                        client_stats.active_connections.saturating_sub(1);
80                    client_stats.last_active = SystemTime::now();
81                    debug!("Client disconnected from {}", addr);
82                }
83            }
84
85            StatEvent::RequestProcessed {
86                addr,
87                success,
88                duration_ms,
89            } => {
90                let client_stats = stats.entry(addr).or_default();
91                client_stats.total_requests = client_stats.total_requests.saturating_add(1);
92
93                if !success {
94                    client_stats.total_errors = client_stats.total_errors.saturating_add(1);
95                    client_stats.last_error = Some(SystemTime::now());
96                }
97
98                // Update average response time using exponential moving average
99                const ALPHA: f64 = 0.1; // Smoothing factor
100
101                if client_stats.avg_response_time_ms == 0 {
102                    client_stats.avg_response_time_ms = duration_ms;
103                } else {
104                    let current_avg = client_stats.avg_response_time_ms as f64;
105                    client_stats.avg_response_time_ms =
106                        (current_avg + ALPHA * (duration_ms as f64 - current_avg)) as u64;
107                }
108
109                client_stats.last_active = SystemTime::now();
110            }
111
112            StatEvent::QueryStats { addr, response_tx } => {
113                if let Some(stats) = stats.get(&addr)
114                    && response_tx.send(stats.clone()).is_err()
115                {
116                    warn!("Failed to send stats for {}", addr);
117                }
118            }
119
120            StatEvent::QueryConnectionStats { response_tx } => {
121                let conn_stats = ConnectionStats::from_client_stats(&stats);
122                if response_tx.send(conn_stats).is_err() {
123                    warn!("Failed to send connection stats");
124                }
125            }
126        }
127    }
128
129    async fn cleanup_idle_stats(&self) {
130        let mut stats = self.stats.lock().await;
131        let now = SystemTime::now();
132
133        stats.retain(|addr, client_stats| {
134            // Check if client has been idle for too long
135            let is_idle = now
136                .duration_since(client_stats.last_active)
137                .map(|idle_time| idle_time <= self.config.idle_timeout)
138                .unwrap_or(true);
139
140            // Check if there was an error that's old enough to clean up
141            let has_recent_error = client_stats
142                .last_error
143                .and_then(|last_error| now.duration_since(last_error).ok())
144                .map(|error_time| error_time <= self.config.error_timeout)
145                .unwrap_or(false);
146
147            let should_retain = is_idle || has_recent_error;
148
149            if !should_retain {
150                debug!(
151                    "Cleaning up stats for {}: {} connections, {} requests, {} errors",
152                    addr,
153                    client_stats.active_connections,
154                    client_stats.total_requests,
155                    client_stats.total_errors
156                );
157            }
158
159            should_retain
160        });
161    }
162}
163
164#[cfg(test)]
165mod tests {
166    use std::time::Duration;
167
168    use super::*;
169    use tokio::{sync::oneshot, time::sleep};
170
171    #[tokio::test]
172    async fn test_client_lifecycle() {
173        let config = StatsConfig::default();
174        let (mut manager, tx) = StatsManager::new(config);
175        let addr = "127.0.0.1:8080".parse().unwrap();
176
177        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
178        let manager_handle = tokio::spawn(async move {
179            manager.run(shutdown_rx).await;
180        });
181
182        // Test connection
183        tx.send(StatEvent::ClientConnected(addr)).await.unwrap();
184
185        // Test successful request
186        tx.send(StatEvent::RequestProcessed {
187            addr,
188            success: true,
189            duration_ms: Duration::from_millis(100).as_millis() as u64,
190        })
191        .await
192        .unwrap();
193
194        // Test failed request
195        tx.send(StatEvent::RequestProcessed {
196            addr,
197            success: false,
198            duration_ms: Duration::from_millis(150).as_millis() as u64,
199        })
200        .await
201        .unwrap();
202
203        sleep(Duration::from_millis(100)).await;
204
205        // Query per-client stats
206        let (response_tx, response_rx) = oneshot::channel();
207        tx.send(StatEvent::QueryStats { addr, response_tx })
208            .await
209            .unwrap();
210
211        let stats = response_rx.await.unwrap();
212        assert_eq!(stats.active_connections, 1);
213        assert_eq!(stats.total_requests, 2);
214        assert_eq!(stats.total_errors, 1);
215
216        // Query global stats
217        let (response_tx, response_rx) = oneshot::channel();
218        tx.send(StatEvent::QueryConnectionStats { response_tx })
219            .await
220            .unwrap();
221
222        let conn_stats = response_rx.await.unwrap();
223        assert_eq!(conn_stats.total_requests, 2);
224        assert_eq!(conn_stats.total_errors, 1);
225
226        // Cleanup
227        shutdown_tx.send(true).unwrap();
228        manager_handle.await.unwrap();
229    }
230
231    #[tokio::test]
232    async fn test_cleanup_idle_stats() {
233        let mut config = StatsConfig::default();
234        config.idle_timeout = Duration::from_millis(100);
235        let (mut manager, tx) = StatsManager::new(config);
236        let addr = "127.0.0.1:8080".parse().unwrap();
237
238        let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
239        let manager_handle = tokio::spawn(async move {
240            manager.run(shutdown_rx).await;
241        });
242
243        // Add client and disconnect
244        tx.send(StatEvent::ClientConnected(addr)).await.unwrap();
245        tx.send(StatEvent::ClientDisconnected(addr)).await.unwrap();
246
247        // Wait for idle timeout
248        sleep(Duration::from_millis(200)).await;
249
250        // Query stats - should be cleaned up
251        let (response_tx, response_rx) = oneshot::channel();
252        tx.send(StatEvent::QueryConnectionStats { response_tx })
253            .await
254            .unwrap();
255
256        let conn_stats = response_rx.await.unwrap();
257        assert_eq!(conn_stats.active_connections, 0);
258
259        shutdown_tx.send(true).unwrap();
260        manager_handle.await.unwrap();
261    }
262}