modbus_relay/connection/
manager.rs1use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration};
2
3use tokio::sync::{Mutex, Semaphore, mpsc, oneshot};
4use tracing::error;
5
6use crate::{ConnectionError, RelayError, config::ConnectionConfig};
7
8use super::{ConnectionGuard, ConnectionStats, StatEvent};
9
10#[derive(Debug)]
12pub struct Manager {
13 per_ip_semaphores: Arc<Mutex<HashMap<SocketAddr, Arc<Semaphore>>>>,
15 global_semaphore: Arc<Semaphore>,
17 active_connections: Arc<Mutex<HashMap<SocketAddr, usize>>>,
19 config: ConnectionConfig,
21 stats_tx: mpsc::Sender<StatEvent>,
23}
24
25impl Manager {
26 pub fn new(config: ConnectionConfig, stats_tx: mpsc::Sender<StatEvent>) -> Self {
27 Self {
28 per_ip_semaphores: Arc::new(Mutex::new(HashMap::new())),
29 global_semaphore: Arc::new(Semaphore::new(config.max_connections as usize)),
30 active_connections: Arc::new(Mutex::new(HashMap::new())),
31 config,
32 stats_tx,
33 }
34 }
35
36 pub async fn accept_connection(
38 self: &Arc<Self>,
39 addr: SocketAddr,
40 ) -> Result<ConnectionGuard, RelayError> {
41 let per_ip_permit = if let Some(per_ip_limit) = self.config.per_ip_limits {
43 let mut semaphores = self.per_ip_semaphores.lock().await;
44
45 let semaphore = semaphores
46 .entry(addr)
47 .or_insert_with(|| Arc::new(Semaphore::new(per_ip_limit as usize)));
48
49 Some(semaphore.clone().try_acquire_owned().map_err(|_| {
50 RelayError::Connection(ConnectionError::limit_exceeded(format!(
51 "Per-IP limit ({}) reached for {}",
52 per_ip_limit, addr
53 )))
54 })?)
55 } else {
56 None
57 };
58
59 let global_permit = self
61 .global_semaphore
62 .clone()
63 .try_acquire_owned()
64 .map_err(|_| {
65 RelayError::Connection(ConnectionError::limit_exceeded(
66 "Global connection limit reached",
67 ))
68 })?;
69
70 {
72 let mut active_conns = self.active_connections.lock().await;
73 let conn_count = active_conns.entry(addr).or_default();
74 *conn_count = conn_count.saturating_add(1);
75 }
76
77 if let Err(e) = self.stats_tx.send(StatEvent::ClientConnected(addr)).await {
79 error!("Failed to send connection event to stats manager: {}", e);
80 }
81
82 Ok(ConnectionGuard {
83 manager: Arc::clone(self),
84 addr,
85 _global_permit: global_permit,
86 _per_ip_permit: per_ip_permit,
87 })
88 }
89
90 pub async fn get_connection_count(&self, addr: &SocketAddr) -> usize {
91 self.active_connections
92 .lock()
93 .await
94 .get(addr)
95 .copied()
96 .unwrap_or(0)
97 }
98
99 pub async fn get_total_connections(&self) -> usize {
100 self.active_connections.lock().await.values().sum()
101 }
102
103 pub async fn record_request(&self, addr: SocketAddr, success: bool, duration: Duration) {
105 if let Err(e) = self
106 .stats_tx
107 .send(StatEvent::RequestProcessed {
108 addr,
109 success,
110 duration_ms: duration.as_millis() as u64,
111 })
112 .await
113 {
114 error!("Failed to send request stats: {}", e);
115 }
116 }
117
118 pub async fn get_stats(&self) -> Result<ConnectionStats, RelayError> {
120 let (tx, rx) = oneshot::channel();
121
122 self.stats_tx
123 .send(StatEvent::QueryConnectionStats { response_tx: tx })
124 .await
125 .map_err(|_| {
126 RelayError::Connection(ConnectionError::invalid_state(
127 "Failed to query connection stats",
128 ))
129 })?;
130
131 rx.await.map_err(|_| {
132 RelayError::Connection(ConnectionError::invalid_state(
133 "Failed to receive connection stats",
134 ))
135 })
136 }
137
138 pub(crate) async fn cleanup_idle_connections(&self) -> Result<(), RelayError> {
140 let (tx, rx) = oneshot::channel();
142
143 self.stats_tx
144 .send(StatEvent::QueryConnectionStats { response_tx: tx })
145 .await
146 .map_err(|_| {
147 RelayError::Connection(ConnectionError::invalid_state(
148 "Failed to query stats for cleanup",
149 ))
150 })?;
151
152 let stats = rx.await.map_err(|_| {
153 RelayError::Connection(ConnectionError::invalid_state(
154 "Failed to receive stats for cleanup",
155 ))
156 })?;
157
158 let mut active_conns = self.active_connections.lock().await;
159 active_conns.retain(|addr, count| {
160 if let Some(ip_stats) = stats.per_ip_stats.get(addr) {
161 ip_stats.active_connections > 0
162 } else {
163 *count == 0
165 }
166 });
167
168 Ok(())
169 }
170
171 pub(crate) fn decrease_connection_count(&self, addr: SocketAddr) {
172 let mut active_conns = self
173 .active_connections
174 .try_lock()
175 .expect("Failed to lock active_connections in guard drop");
176
177 if let Some(count) = active_conns.get_mut(&addr) {
178 *count = count.saturating_sub(1);
179 if *count == 0 {
180 active_conns.remove(&addr);
181 }
182 }
183 }
184
185 pub fn stats_tx(&self) -> mpsc::Sender<StatEvent> {
186 self.stats_tx.clone()
187 }
188}