1use std::net::SocketAddr;
4use std::sync::Arc;
5
6use bytes::Bytes;
7use rusty_modbus_frame::frame::{Frame, FrameHeader};
8use rusty_modbus_tcp::config::TcpServerConfig;
9use rusty_modbus_tcp::listener::TcpServerListener;
10use rusty_modbus_tcp::transport::{TransportSink, TransportStream};
11use rusty_modbus_types::{ExceptionCode, MAX_PDU_SIZE, MbapHeader, UnitId};
12use tokio::sync::watch;
13use tracing::{debug, info, trace, warn};
14
15use crate::config::{DeviceIdentification, ServerConfig};
16use crate::error::ServerError;
17use crate::handler;
18use crate::store::DataStore;
19
20pub struct ModbusServer<S: DataStore> {
22 config: ServerConfig,
23 store: Arc<S>,
24 local_addr: SocketAddr,
25 shutdown_tx: watch::Sender<bool>,
26 accept_handle: Option<tokio::task::JoinHandle<()>>,
27}
28
29impl<S: DataStore + 'static> ModbusServer<S> {
30 #[tracing::instrument(level = "debug", skip(config, store), fields(addr = %config.listen_addr, unit_id = config.unit_id.0))]
38 pub async fn start(config: ServerConfig, store: Arc<S>) -> Result<Self, ServerError> {
39 let tcp_config = TcpServerConfig {
40 max_connections: config.max_connections,
41 ..config.tcp_config.clone()
42 };
43
44 let listener = TcpServerListener::bind(config.listen_addr, tcp_config)
45 .await
46 .map_err(|e| match e {
47 rusty_modbus_tcp::TransportError::Io(io) => ServerError::Bind(io),
48 other => ServerError::Transport(other),
49 })?;
50
51 let local_addr = listener.local_addr().map_err(|e| match e {
52 rusty_modbus_tcp::TransportError::Io(io) => ServerError::Bind(io),
53 other => ServerError::Transport(other),
54 })?;
55 info!(addr = %local_addr, unit_id = config.unit_id.0, "Modbus server listening");
56
57 let (shutdown_tx, shutdown_rx) = watch::channel(false);
58
59 let server_unit_id = config.unit_id;
60 let server_store = Arc::clone(&store);
61 let server_device_id = config.device_id.clone();
62
63 let accept_handle = tokio::spawn(async move {
64 accept_loop(
65 listener,
66 server_unit_id,
67 server_store,
68 server_device_id,
69 shutdown_rx,
70 )
71 .await;
72 });
73
74 Ok(Self {
75 config,
76 store,
77 local_addr,
78 shutdown_tx,
79 accept_handle: Some(accept_handle),
80 })
81 }
82
83 pub async fn stop(&self) {
85 info!(addr = %self.local_addr, "stopping Modbus server");
86 let _ = self.shutdown_tx.send(true);
87 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
89 }
90
91 #[must_use]
93 pub fn store(&self) -> &S {
94 self.store.as_ref()
95 }
96
97 #[must_use]
99 pub fn local_addr(&self) -> SocketAddr {
100 self.local_addr
101 }
102}
103
104impl<S: DataStore> Drop for ModbusServer<S> {
105 fn drop(&mut self) {
106 let _ = self.shutdown_tx.send(true);
107 if let Some(h) = self.accept_handle.take() {
108 h.abort();
109 }
110 }
111}
112
113impl<S: DataStore> std::fmt::Debug for ModbusServer<S> {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("ModbusServer")
116 .field("addr", &self.local_addr)
117 .field("unit_id", &self.config.unit_id)
118 .finish_non_exhaustive()
119 }
120}
121
122async fn accept_loop<S: DataStore + 'static>(
123 listener: TcpServerListener,
124 unit_id: UnitId,
125 store: Arc<S>,
126 device_id: DeviceIdentification,
127 mut shutdown_rx: watch::Receiver<bool>,
128) {
129 loop {
130 tokio::select! {
131 result = listener.accept() => {
132 if let Ok((sink, stream, addr, guard)) = result {
133 debug!(peer_addr = %addr, "accepted Modbus server connection");
134 let conn_store = Arc::clone(&store);
135 let conn_device_id = device_id.clone();
136 tokio::spawn(async move {
137 handle_connection(sink, stream, addr, unit_id, conn_store, conn_device_id).await;
138 drop(guard);
139 });
140 } else if let Err(error) = result {
141 warn!(error = %error, "Modbus server accept failed");
142 }
143 }
145 _ = shutdown_rx.changed() => {
146 if *shutdown_rx.borrow() {
147 debug!("Modbus server accept loop received shutdown");
148 break;
149 }
150 }
151 }
152 }
153}
154
155async fn handle_connection<S: DataStore>(
156 mut sink: rusty_modbus_tcp::TcpSink,
157 mut stream: rusty_modbus_tcp::TcpRecvStream,
158 peer_addr: SocketAddr,
159 unit_id: UnitId,
160 store: Arc<S>,
161 device_id: DeviceIdentification,
162) {
163 while let Ok(frame) = stream.recv().await {
164 let request_unit_id = UnitId(frame.unit_id());
165 let pdu_len = frame.pdu.len();
166 trace!(
167 peer_addr = %peer_addr,
168 request_unit_id = request_unit_id.0,
169 pdu_len,
170 "received Modbus server request"
171 );
172
173 if request_unit_id.0 != unit_id.0
175 && !request_unit_id.is_broadcast()
176 && !request_unit_id.is_tcp_device()
177 {
178 debug!(
180 peer_addr = %peer_addr,
181 request_unit_id = request_unit_id.0,
182 server_unit_id = unit_id.0,
183 "discarding request for different unit id"
184 );
185 continue;
186 }
187
188 let txn_id = match frame.header {
189 FrameHeader::Mbap(h) => h.transaction_id.get(),
190 FrameHeader::Rtu { .. } => 0,
191 };
192
193 if let Some(response_pdu) =
195 handler::process_request(&frame.pdu, request_unit_id, store.as_ref(), &device_id).await
196 {
197 let Some(response_frame) = response_frame(txn_id, request_unit_id, response_pdu) else {
198 warn!(peer_addr = %peer_addr, txn_id, "dropping empty Modbus response PDU");
199 break;
200 };
201 if let Err(error) = sink.send(response_frame).await {
202 debug!(peer_addr = %peer_addr, txn_id, error = %error, "failed to send Modbus response");
203 break; }
205 trace!(peer_addr = %peer_addr, txn_id, "sent Modbus server response");
206 }
207 }
209 debug!(peer_addr = %peer_addr, "Modbus server connection closed");
210}
211
212fn response_frame(txn_id: u16, unit_id: UnitId, response_pdu: Vec<u8>) -> Option<Frame> {
213 let pdu = bounded_response_pdu(response_pdu)?;
214 let pdu_len = u16::try_from(pdu.len()).expect("MAX_PDU_SIZE fits in u16");
215 let header = MbapHeader::new(txn_id, unit_id.0, pdu_len);
216 Some(Frame {
217 header: FrameHeader::Mbap(header),
218 pdu: Bytes::from(pdu),
219 })
220}
221
222fn bounded_response_pdu(response_pdu: Vec<u8>) -> Option<Vec<u8>> {
223 let fc = response_pdu.first().copied()?;
224 if response_pdu.len() <= MAX_PDU_SIZE {
225 return Some(response_pdu);
226 }
227
228 warn!(
229 function_code = fc,
230 pdu_len = response_pdu.len(),
231 max_pdu_size = MAX_PDU_SIZE,
232 "server response exceeded Modbus PDU limit"
233 );
234 Some(vec![fc | 0x80, ExceptionCode::ServerDeviceFailure.code()])
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the MBAP header of `frame`, failing the test for any other framing.
    fn mbap_header(frame: &Frame) -> &MbapHeader {
        match &frame.header {
            FrameHeader::Mbap(header) => header,
            FrameHeader::Rtu { .. } => panic!("expected MBAP response"),
        }
    }

    #[test]
    fn response_frame_preserves_valid_pdu() {
        let frame = response_frame(0x1234, UnitId(7), vec![0x03, 0x02, 0xAA, 0xBB])
            .expect("valid response should produce a frame");

        let header = mbap_header(&frame);
        assert_eq!(header.transaction_id.get(), 0x1234);
        assert_eq!(header.unit_id, 7);
        assert_eq!(header.pdu_length(), 4);
        assert_eq!(frame.pdu.as_ref(), &[0x03, 0x02, 0xAA, 0xBB]);
    }

    #[test]
    fn response_frame_turns_oversized_pdu_into_exception() {
        let oversized = vec![0x03; MAX_PDU_SIZE + 1];
        let frame = response_frame(0xBEEF, UnitId(2), oversized)
            .expect("oversized response should become an exception frame");

        let header = mbap_header(&frame);
        assert_eq!(header.transaction_id.get(), 0xBEEF);
        assert_eq!(header.unit_id, 2);
        assert_eq!(header.pdu_length(), 2);
        assert_eq!(
            frame.pdu.as_ref(),
            &[0x83, ExceptionCode::ServerDeviceFailure.code()]
        );
    }

    #[test]
    fn response_frame_drops_empty_pdu() {
        let frame = response_frame(0, UnitId(1), vec![]);
        assert!(frame.is_none());
    }
}