modbus_relay/connection/manager.rs

1use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
2
3use tokio::sync::{Mutex, Semaphore, mpsc, oneshot};
4use tracing::error;
5
6use crate::{ConnectionError, RelayError, config::ConnectionConfig};
7
8use super::{ConnectionGuard, ConnectionStats, StatEvent};
9
10/// TCP connection management
11#[derive(Debug)]
12pub struct Manager {
13    /// Connection limit per IP
14    per_ip_semaphores: Arc<Mutex<HashMap<SocketAddr, Arc<Semaphore>>>>,
15    /// Global connection limit
16    global_semaphore: Arc<Semaphore>,
17    /// Active connections counter per IP
18    active_connections: Arc<Mutex<HashMap<SocketAddr, usize>>>,
19    /// Configuration
20    config: ConnectionConfig,
21    /// Stats event sender
22    stats_tx: mpsc::Sender<StatEvent>,
23}
24
25impl Manager {
26    pub fn new(config: ConnectionConfig, stats_tx: mpsc::Sender<StatEvent>) -> Self {
27        Self {
28            per_ip_semaphores: Arc::new(Mutex::new(HashMap::new())),
29            global_semaphore: Arc::new(Semaphore::new(config.max_connections as usize)),
30            active_connections: Arc::new(Mutex::new(HashMap::new())),
31            config,
32            stats_tx,
33        }
34    }
35
36    /// Attempt to establish a new connection
37    pub async fn accept_connection(
38        self: &Arc<Self>,
39        addr: SocketAddr,
40    ) -> Result<ConnectionGuard, RelayError> {
41        // Check per IP limit if enabled
42        let per_ip_permit = if let Some(per_ip_limit) = self.config.per_ip_limits {
43            let mut semaphores = self.per_ip_semaphores.lock().await;
44
45            let semaphore = semaphores
46                .entry(addr)
47                .or_insert_with(|| Arc::new(Semaphore::new(per_ip_limit as usize)));
48
49            Some(semaphore.clone().try_acquire_owned().map_err(|_| {
50                RelayError::Connection(ConnectionError::limit_exceeded(format!(
51                    "Per-IP limit ({}) reached for {}",
52                    per_ip_limit, addr
53                )))
54            })?)
55        } else {
56            None
57        };
58
59        // Check if the global limit has been exceeded
60        let global_permit = self
61            .global_semaphore
62            .clone()
63            .try_acquire_owned()
64            .map_err(|_| {
65                RelayError::Connection(ConnectionError::limit_exceeded(
66                    "Global connection limit reached",
67                ))
68            })?;
69
70        // Increment active connections counter
71        {
72            let mut active_conns = self.active_connections.lock().await;
73            let conn_count = active_conns.entry(addr).or_default();
74            *conn_count = conn_count.saturating_add(1);
75        }
76
77        // Notify stats manager about new connection
78        if let Err(e) = self.stats_tx.send(StatEvent::ClientConnected(addr)).await {
79            error!("Failed to send connection event to stats manager: {}", e);
80        }
81
82        Ok(ConnectionGuard {
83            manager: Arc::clone(self),
84            addr,
85            _global_permit: global_permit,
86            _per_ip_permit: per_ip_permit,
87        })
88    }
89
90    pub async fn get_connection_count(&self, addr: &SocketAddr) -> usize {
91        self.active_connections
92            .lock()
93            .await
94            .get(addr)
95            .copied()
96            .unwrap_or(0)
97    }
98
99    pub async fn get_total_connections(&self) -> usize {
100        self.active_connections.lock().await.values().sum()
101    }
102
103    /// Updates statistics for a given request
104    pub async fn record_request(&self, addr: SocketAddr, success: bool, duration: Duration) {
105        if let Err(e) = self
106            .stats_tx
107            .send(StatEvent::RequestProcessed {
108                addr,
109                success,
110                duration_ms: duration.as_millis() as u64,
111            })
112            .await
113        {
114            error!("Failed to send request stats: {}", e);
115        }
116    }
117
118    /// Gets complete connection statistics
119    pub async fn get_stats(&self) -> Result<ConnectionStats, RelayError> {
120        let (tx, rx) = oneshot::channel();
121
122        self.stats_tx
123            .send(StatEvent::QueryConnectionStats { response_tx: tx })
124            .await
125            .map_err(|_| {
126                RelayError::Connection(ConnectionError::invalid_state(
127                    "Failed to query connection stats",
128                ))
129            })?;
130
131        rx.await.map_err(|_| {
132            RelayError::Connection(ConnectionError::invalid_state(
133                "Failed to receive connection stats",
134            ))
135        })
136    }
137
138    /// Cleans up idle connections
139    pub(crate) async fn cleanup_idle_connections(&self) -> Result<(), RelayError> {
140        // Cleanup is now handled by StatsManager, we just need to sync our active connections
141        let (tx, rx) = oneshot::channel();
142
143        self.stats_tx
144            .send(StatEvent::QueryConnectionStats { response_tx: tx })
145            .await
146            .map_err(|_| {
147                RelayError::Connection(ConnectionError::invalid_state(
148                    "Failed to query stats for cleanup",
149                ))
150            })?;
151
152        let stats = rx.await.map_err(|_| {
153            RelayError::Connection(ConnectionError::invalid_state(
154                "Failed to receive stats for cleanup",
155            ))
156        })?;
157
158        let mut active_conns = self.active_connections.lock().await;
159        active_conns.retain(|addr, count| {
160            if let Some(ip_stats) = stats.per_ip_stats.get(addr) {
161                ip_stats.active_connections > 0
162            } else {
163                // If no stats exist, connection is considered inactive
164                *count == 0
165            }
166        });
167
168        Ok(())
169    }
170
171    pub(crate) fn decrease_connection_count(&self, addr: SocketAddr) {
172        let mut active_conns = self
173            .active_connections
174            .try_lock()
175            .expect("Failed to lock active_connections in guard drop");
176
177        if let Some(count) = active_conns.get_mut(&addr) {
178            *count = count.saturating_sub(1);
179            if *count == 0 {
180                active_conns.remove(&addr);
181            }
182        }
183    }
184
185    pub fn stats_tx(&self) -> mpsc::Sender<StatEvent> {
186        self.stats_tx.clone()
187    }
188}