Skip to main content

rustmod_datalink/
server.rs

1use crate::DataLinkError;
2use rustmod_core::encoding::{Reader, Writer};
3use rustmod_core::frame::{rtu as rtu_frame, tcp};
4use rustmod_core::pdu::{DecodedRequest, ExceptionCode, ExceptionResponse};
5use rustmod_core::{DecodeError, UnitId};
6use std::future::Future;
7use std::sync::Arc;
8use thiserror::Error;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
11use tokio::sync::Semaphore;
12use tracing::{debug, warn};
13
14#[cfg(feature = "metrics")]
15use std::sync::atomic::{AtomicU64, Ordering};
16
/// Default upper bound, in bytes, on a PDU (function code plus payload).
const DEFAULT_MAX_PDU_LEN: usize = 253;
/// Default upper bound on an RTU frame: one address byte, the PDU, and two CRC bytes.
const DEFAULT_MAX_RTU_FRAME_LEN: usize = 1 + DEFAULT_MAX_PDU_LEN + 2;
19
20/// Errors returned by a [`ModbusService`] handler.
21///
22/// The server maps these to Modbus exception responses on the wire.
23#[derive(Debug, Error)]
24#[non_exhaustive]
25pub enum ServiceError {
26    /// A standard Modbus exception (e.g. illegal address, illegal function).
27    #[error("modbus exception: {0:?}")]
28    Exception(ExceptionCode),
29    /// The request was malformed or contained invalid parameters.
30    #[error("invalid request: {0}")]
31    InvalidRequest(&'static str),
32    /// An internal error (maps to Server Device Failure exception on the wire).
33    #[error("internal error: {0}")]
34    Internal(&'static str),
35}
36
37/// Application-level request handler for Modbus servers.
38///
39/// Implement this trait to define how your device responds to Modbus requests.
40/// See [`InMemoryModbusService`](crate::InMemoryModbusService) for a ready-made
41/// in-memory simulator implementation.
42pub trait ModbusService: Send + Sync + 'static {
43    /// Handle a decoded request and write a response PDU into `response_pdu`.
44    ///
45    /// Return the number of bytes written. The response must include function
46    /// code and payload, but not MBAP header bytes.
47    fn handle(
48        &self,
49        unit_id: UnitId,
50        request: DecodedRequest<'_>,
51        response_pdu: &mut [u8],
52    ) -> Result<usize, ServiceError>;
53}
54
55impl<T> ModbusService for Arc<T>
56where
57    T: ModbusService + ?Sized,
58{
59    fn handle(
60        &self,
61        unit_id: UnitId,
62        request: DecodedRequest<'_>,
63        response_pdu: &mut [u8],
64    ) -> Result<usize, ServiceError> {
65        (**self).handle(unit_id, request, response_pdu)
66    }
67}
68
/// Atomic activity counters for a server (only built with the `metrics` feature).
#[cfg(feature = "metrics")]
#[derive(Default, Debug)]
pub struct ServerMetrics {
    /// Requests received, counted before PDU decoding is attempted.
    requests_total: AtomicU64,
    /// Requests for which the handler produced a response of acceptable length.
    responses_ok: AtomicU64,
    /// Exception responses recorded by the connection handlers.
    exceptions_sent: AtomicU64,
    /// Requests rejected before reaching the service handler
    /// (undecodable PDU, trailing bytes, or an unacceptable PDU length).
    decode_errors: AtomicU64,
    /// Handler failures answered with Server Device Failure, including
    /// responses whose reported length was zero or above the PDU limit.
    internal_errors: AtomicU64,
}
79
/// Plain-data copy of the [`ServerMetrics`] counters taken at one moment.
#[cfg(feature = "metrics")]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct ServerMetricsSnapshot {
    /// Value of the received-requests counter.
    pub requests_total: u64,
    /// Value of the successful-responses counter.
    pub responses_ok: u64,
    /// Value of the exception-responses counter.
    pub exceptions_sent: u64,
    /// Value of the decode-failures counter.
    pub decode_errors: u64,
    /// Value of the internal-failures counter.
    pub internal_errors: u64,
}
90
#[cfg(feature = "metrics")]
impl ServerMetrics {
    /// Read every counter and return the values as a [`ServerMetricsSnapshot`].
    ///
    /// This is public so that a handle obtained from a server's
    /// `metrics_handle()` stays useful once `run`/`run_until` has consumed the
    /// server; the counters themselves are private and have no other accessor.
    ///
    /// Each counter is loaded separately with relaxed ordering, so the snapshot
    /// is not a single atomic view: counters updated concurrently may be
    /// observed at slightly different moments.
    #[must_use]
    pub fn snapshot(&self) -> ServerMetricsSnapshot {
        ServerMetricsSnapshot {
            requests_total: self.requests_total.load(Ordering::Relaxed),
            responses_ok: self.responses_ok.load(Ordering::Relaxed),
            exceptions_sent: self.exceptions_sent.load(Ordering::Relaxed),
            decode_errors: self.decode_errors.load(Ordering::Relaxed),
            internal_errors: self.internal_errors.load(Ordering::Relaxed),
        }
    }
}
103
/// Connection limit used by both servers unless `with_max_connections` overrides it.
const DEFAULT_MAX_CONNECTIONS: usize = 256;
105
106/// Modbus TCP server that accepts connections and dispatches requests to a [`ModbusService`].
107///
108/// Supports configurable connection limits, PDU size limits, and optional metrics.
109/// Use [`run`](Self::run) to accept connections indefinitely, or
110/// [`run_until`](Self::run_until) for graceful shutdown.
111pub struct ModbusTcpServer<S> {
112    listener: TcpListener,
113    service: Arc<S>,
114    max_pdu_len: usize,
115    max_connections: usize,
116    #[cfg(feature = "metrics")]
117    metrics: Arc<ServerMetrics>,
118}
119
120impl<S: ModbusService> ModbusTcpServer<S> {
121    /// Bind to a TCP address and create a new server.
122    pub async fn bind<A: ToSocketAddrs>(addr: A, service: S) -> Result<Self, DataLinkError> {
123        let listener = TcpListener::bind(addr).await?;
124        Ok(Self::from_listener(listener, service))
125    }
126
127    /// Create a server from an existing [`TcpListener`].
128    #[must_use]
129    pub fn from_listener(listener: TcpListener, service: S) -> Self {
130        Self {
131            listener,
132            service: Arc::new(service),
133            max_pdu_len: DEFAULT_MAX_PDU_LEN,
134            max_connections: DEFAULT_MAX_CONNECTIONS,
135            #[cfg(feature = "metrics")]
136            metrics: Arc::new(ServerMetrics::default()),
137        }
138    }
139
140    /// Return the local address the server is bound to.
141    pub fn local_addr(&self) -> Result<std::net::SocketAddr, DataLinkError> {
142        Ok(self.listener.local_addr()?)
143    }
144
145    /// Set the maximum PDU length the server will accept (default: 253).
146    #[must_use]
147    pub fn with_max_pdu_len(mut self, max_pdu_len: usize) -> Self {
148        self.max_pdu_len = max_pdu_len;
149        self
150    }
151
152    /// Set the maximum number of concurrent client connections (default: 256).
153    #[must_use]
154    pub fn with_max_connections(mut self, max_connections: usize) -> Self {
155        self.max_connections = max_connections;
156        self
157    }
158
159    /// Get a cloneable handle to the server metrics.
160    #[cfg(feature = "metrics")]
161    pub fn metrics_handle(&self) -> Arc<ServerMetrics> {
162        Arc::clone(&self.metrics)
163    }
164
165    /// Take a snapshot of the current metrics counters.
166    #[cfg(feature = "metrics")]
167    pub fn metrics_snapshot(&self) -> ServerMetricsSnapshot {
168        self.metrics.snapshot()
169    }
170
171    /// Accept connections and serve requests indefinitely.
172    pub async fn run(self) -> Result<(), DataLinkError> {
173        let semaphore = Arc::new(Semaphore::new(self.max_connections));
174        loop {
175            let (socket, peer) = self.listener.accept().await?;
176            let service = Arc::clone(&self.service);
177            let max_pdu_len = self.max_pdu_len;
178            let permit = Arc::clone(&semaphore);
179            #[cfg(feature = "metrics")]
180            let metrics = Arc::clone(&self.metrics);
181
182            tokio::spawn(async move {
183                let _permit = permit.acquire().await;
184                if let Err(err) = handle_connection(
185                    socket,
186                    service,
187                    max_pdu_len,
188                    #[cfg(feature = "metrics")]
189                    metrics,
190                )
191                .await
192                {
193                    warn!(%peer, error = %err, "modbus tcp server connection ended with error");
194                }
195            });
196        }
197    }
198
199    /// Run the server until the given shutdown future completes.
200    pub async fn run_until(self, shutdown: impl Future<Output = ()> + Send) -> Result<(), DataLinkError> {
201        let semaphore = Arc::new(Semaphore::new(self.max_connections));
202        tokio::pin!(shutdown);
203        loop {
204            tokio::select! {
205                result = self.listener.accept() => {
206                    let (socket, peer) = result?;
207                    let service = Arc::clone(&self.service);
208                    let max_pdu_len = self.max_pdu_len;
209                    let permit = Arc::clone(&semaphore);
210                    #[cfg(feature = "metrics")]
211                    let metrics = Arc::clone(&self.metrics);
212
213                    tokio::spawn(async move {
214                        let _permit = permit.acquire().await;
215                        if let Err(err) = handle_connection(
216                            socket,
217                            service,
218                            max_pdu_len,
219                            #[cfg(feature = "metrics")]
220                            metrics,
221                        )
222                        .await
223                        {
224                            warn!(%peer, error = %err, "modbus tcp server connection ended with error");
225                        }
226                    });
227                }
228                () = &mut shutdown => {
229                    return Ok(());
230                }
231            }
232        }
233    }
234}
235
236/// Modbus RTU-over-TCP server that accepts TCP connections carrying RTU-framed requests.
237///
238/// This is useful for RTU gateways or testing environments where RTU framing
239/// is tunnelled over a TCP socket.
240pub struct ModbusRtuOverTcpServer<S> {
241    listener: TcpListener,
242    service: Arc<S>,
243    max_pdu_len: usize,
244    max_frame_len: usize,
245    max_connections: usize,
246    #[cfg(feature = "metrics")]
247    metrics: Arc<ServerMetrics>,
248}
249
250impl<S: ModbusService> ModbusRtuOverTcpServer<S> {
251    /// Bind to a TCP address and create a new RTU-over-TCP server.
252    pub async fn bind<A: ToSocketAddrs>(addr: A, service: S) -> Result<Self, DataLinkError> {
253        let listener = TcpListener::bind(addr).await?;
254        Ok(Self::from_listener(listener, service))
255    }
256
257    /// Create a server from an existing [`TcpListener`].
258    #[must_use]
259    pub fn from_listener(listener: TcpListener, service: S) -> Self {
260        Self {
261            listener,
262            service: Arc::new(service),
263            max_pdu_len: DEFAULT_MAX_PDU_LEN,
264            max_frame_len: DEFAULT_MAX_RTU_FRAME_LEN,
265            max_connections: DEFAULT_MAX_CONNECTIONS,
266            #[cfg(feature = "metrics")]
267            metrics: Arc::new(ServerMetrics::default()),
268        }
269    }
270
271    /// Return the local address the server is bound to.
272    pub fn local_addr(&self) -> Result<std::net::SocketAddr, DataLinkError> {
273        Ok(self.listener.local_addr()?)
274    }
275
276    /// Set the maximum PDU length (default: 253).
277    #[must_use]
278    pub fn with_max_pdu_len(mut self, max_pdu_len: usize) -> Self {
279        self.max_pdu_len = max_pdu_len;
280        self
281    }
282
283    /// Set the maximum RTU frame length including address + CRC (default: 256).
284    #[must_use]
285    pub fn with_max_frame_len(mut self, max_frame_len: usize) -> Self {
286        self.max_frame_len = max_frame_len;
287        self
288    }
289
290    /// Set the maximum number of concurrent client connections (default: 256).
291    #[must_use]
292    pub fn with_max_connections(mut self, max_connections: usize) -> Self {
293        self.max_connections = max_connections;
294        self
295    }
296
297    /// Get a cloneable handle to the server metrics.
298    #[cfg(feature = "metrics")]
299    pub fn metrics_handle(&self) -> Arc<ServerMetrics> {
300        Arc::clone(&self.metrics)
301    }
302
303    /// Take a snapshot of the current metrics counters.
304    #[cfg(feature = "metrics")]
305    pub fn metrics_snapshot(&self) -> ServerMetricsSnapshot {
306        self.metrics.snapshot()
307    }
308
309    /// Accept connections and serve requests indefinitely.
310    pub async fn run(self) -> Result<(), DataLinkError> {
311        let semaphore = Arc::new(Semaphore::new(self.max_connections));
312        loop {
313            let (socket, peer) = self.listener.accept().await?;
314            let service = Arc::clone(&self.service);
315            let max_pdu_len = self.max_pdu_len;
316            let max_frame_len = self.max_frame_len;
317            let permit = Arc::clone(&semaphore);
318            #[cfg(feature = "metrics")]
319            let metrics = Arc::clone(&self.metrics);
320
321            tokio::spawn(async move {
322                let _permit = permit.acquire().await;
323                if let Err(err) = handle_rtu_over_tcp_connection(
324                    socket,
325                    service,
326                    max_pdu_len,
327                    max_frame_len,
328                    #[cfg(feature = "metrics")]
329                    metrics,
330                )
331                .await
332                {
333                    warn!(
334                        %peer,
335                        error = %err,
336                        "modbus rtu-over-tcp server connection ended with error"
337                    );
338                }
339            });
340        }
341    }
342
343    /// Run the server until the given shutdown future completes.
344    pub async fn run_until(self, shutdown: impl Future<Output = ()> + Send) -> Result<(), DataLinkError> {
345        let semaphore = Arc::new(Semaphore::new(self.max_connections));
346        tokio::pin!(shutdown);
347        loop {
348            tokio::select! {
349                result = self.listener.accept() => {
350                    let (socket, peer) = result?;
351                    let service = Arc::clone(&self.service);
352                    let max_pdu_len = self.max_pdu_len;
353                    let max_frame_len = self.max_frame_len;
354                    let permit = Arc::clone(&semaphore);
355                    #[cfg(feature = "metrics")]
356                    let metrics = Arc::clone(&self.metrics);
357
358                    tokio::spawn(async move {
359                        let _permit = permit.acquire().await;
360                        if let Err(err) = handle_rtu_over_tcp_connection(
361                            socket,
362                            service,
363                            max_pdu_len,
364                            max_frame_len,
365                            #[cfg(feature = "metrics")]
366                            metrics,
367                        )
368                        .await
369                        {
370                            warn!(
371                                %peer,
372                                error = %err,
373                                "modbus rtu-over-tcp server connection ended with error"
374                            );
375                        }
376                    });
377                }
378                () = &mut shutdown => {
379                    return Ok(());
380                }
381            }
382        }
383    }
384}
385
386fn is_write_request(request: &DecodedRequest<'_>) -> bool {
387    matches!(
388        request,
389        DecodedRequest::WriteSingleCoil(_)
390            | DecodedRequest::WriteSingleRegister(_)
391            | DecodedRequest::WriteMultipleCoils(_)
392            | DecodedRequest::WriteMultipleRegisters(_)
393            | DecodedRequest::MaskWriteRegister(_)
394            | DecodedRequest::ReadWriteMultipleRegisters(_)
395    )
396}
397
398async fn handle_connection<S: ModbusService>(
399    mut socket: TcpStream,
400    service: Arc<S>,
401    max_pdu_len: usize,
402    #[cfg(feature = "metrics")] metrics: Arc<ServerMetrics>,
403) -> Result<(), DataLinkError> {
404    let mut request_pdu_buf = [0u8; 253];
405    let mut response_pdu = vec![0u8; max_pdu_len];
406
407    loop {
408        let mut mbap = [0u8; tcp::MBAP_HEADER_LEN];
409        if let Err(err) = socket.read_exact(&mut mbap).await {
410            if err.kind() == std::io::ErrorKind::UnexpectedEof {
411                return Ok(());
412            }
413            return Err(DataLinkError::Io(err));
414        }
415
416        let mut mbap_reader = Reader::new(&mbap);
417        let header = tcp::MbapHeader::decode(&mut mbap_reader)?;
418        let pdu_len = usize::from(header.length)
419            .checked_sub(1)
420            .ok_or(DataLinkError::InvalidResponse("invalid mbap length"))?;
421
422        if pdu_len == 0 || pdu_len > max_pdu_len {
423            return Err(DataLinkError::InvalidResponse("invalid request pdu length"));
424        }
425
426        socket.read_exact(&mut request_pdu_buf[..pdu_len]).await?;
427        let request_pdu = &request_pdu_buf[..pdu_len];
428
429        #[cfg(feature = "metrics")]
430        metrics.requests_total.fetch_add(1, Ordering::Relaxed);
431
432        let mut request_reader = Reader::new(request_pdu);
433        let decoded = match DecodedRequest::decode(&mut request_reader) {
434            Ok(req) if request_reader.is_empty() => req,
435            Ok(_) => {
436                #[cfg(feature = "metrics")]
437                {
438                    metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
439                    metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
440                }
441                let function = request_pdu[0] & 0x7F;
442                send_exception(
443                    &mut socket,
444                    header.transaction_id,
445                    header.unit_id,
446                    function,
447                    ExceptionCode::IllegalDataValue,
448                )
449                .await?;
450                continue;
451            }
452            Err(err) => {
453                #[cfg(feature = "metrics")]
454                {
455                    metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
456                    metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
457                }
458                let function = request_pdu.first().copied().unwrap_or(0) & 0x7F;
459                send_exception(
460                    &mut socket,
461                    header.transaction_id,
462                    header.unit_id,
463                    function,
464                    map_decode_error_to_exception(err),
465                )
466                .await?;
467                continue;
468            }
469        };
470
471        debug!(
472            correlation_id = header.transaction_id,
473            unit_id = header.unit_id.as_u8(),
474            function = decoded.function_code().as_u8(),
475            pdu_len,
476            "received modbus tcp request"
477        );
478
479        // Broadcast handling (unit_id == 0)
480        if header.unit_id == UnitId::BROADCAST {
481            if is_write_request(&decoded) {
482                // Process write but don't send response
483                let _ = service.handle(header.unit_id, decoded, &mut response_pdu);
484                continue;
485            } else {
486                // Read on broadcast: send IllegalFunction exception
487                send_exception(
488                    &mut socket,
489                    header.transaction_id,
490                    header.unit_id,
491                    decoded.function_code().as_u8(),
492                    ExceptionCode::IllegalFunction,
493                )
494                .await?;
495                continue;
496            }
497        }
498
499        match service.handle(header.unit_id, decoded, &mut response_pdu) {
500            Ok(response_len) => {
501                if response_len == 0 || response_len > max_pdu_len {
502                    #[cfg(feature = "metrics")]
503                    {
504                        metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
505                        metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
506                    }
507                    send_exception(
508                        &mut socket,
509                        header.transaction_id,
510                        header.unit_id,
511                        decoded.function_code().as_u8(),
512                        ExceptionCode::ServerDeviceFailure,
513                    )
514                    .await?;
515                    continue;
516                }
517
518                #[cfg(feature = "metrics")]
519                metrics.responses_ok.fetch_add(1, Ordering::Relaxed);
520
521                send_pdu(
522                    &mut socket,
523                    header.transaction_id,
524                    header.unit_id,
525                    &response_pdu[..response_len],
526                )
527                .await?;
528            }
529            Err(ServiceError::Exception(code)) => {
530                #[cfg(feature = "metrics")]
531                metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
532
533                send_exception(
534                    &mut socket,
535                    header.transaction_id,
536                    header.unit_id,
537                    decoded.function_code().as_u8(),
538                    code,
539                )
540                .await?;
541            }
542            Err(ServiceError::InvalidRequest(_)) => {
543                #[cfg(feature = "metrics")]
544                metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
545
546                send_exception(
547                    &mut socket,
548                    header.transaction_id,
549                    header.unit_id,
550                    decoded.function_code().as_u8(),
551                    ExceptionCode::IllegalDataValue,
552                )
553                .await?;
554            }
555            Err(_) => {
556                #[cfg(feature = "metrics")]
557                {
558                    metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
559                    metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
560                }
561
562                send_exception(
563                    &mut socket,
564                    header.transaction_id,
565                    header.unit_id,
566                    decoded.function_code().as_u8(),
567                    ExceptionCode::ServerDeviceFailure,
568                )
569                .await?;
570            }
571        }
572    }
573}
574
575fn decode_rtu_suffix_frame(buffer: &[u8]) -> Option<(usize, UnitId, &[u8])> {
576    if buffer.len() < 4 {
577        return None;
578    }
579    for start in 0..=buffer.len() - 4 {
580        if let Ok((unit_id, pdu)) = rtu_frame::decode_frame(&buffer[start..]) {
581            return Some((start, unit_id, pdu));
582        }
583    }
584    None
585}
586
587async fn handle_rtu_over_tcp_connection<S: ModbusService>(
588    mut socket: TcpStream,
589    service: Arc<S>,
590    max_pdu_len: usize,
591    max_frame_len: usize,
592    #[cfg(feature = "metrics")] metrics: Arc<ServerMetrics>,
593) -> Result<(), DataLinkError> {
594    if max_frame_len < 4 {
595        return Err(DataLinkError::InvalidResponse(
596            "rtu frame length must be at least 4 bytes",
597        ));
598    }
599
600    let mut frame = vec![0u8; max_frame_len];
601    let mut len = 0usize;
602    let mut response_pdu = vec![0u8; max_pdu_len];
603
604    loop {
605        if len == max_frame_len {
606            // Drop oldest byte so we can continue scanning for a valid frame boundary.
607            frame.copy_within(1..max_frame_len, 0);
608            len -= 1;
609        }
610
611        let n = socket.read(&mut frame[len..len + 1]).await?;
612        if n == 0 {
613            return Ok(());
614        }
615        len += n;
616
617        let Some((_, unit_id, request_pdu)) = decode_rtu_suffix_frame(&frame[..len]) else {
618            continue;
619        };
620        len = 0;
621
622        #[cfg(feature = "metrics")]
623        metrics.requests_total.fetch_add(1, Ordering::Relaxed);
624
625        if request_pdu.is_empty() || request_pdu.len() > max_pdu_len {
626            #[cfg(feature = "metrics")]
627            {
628                metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
629                metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
630            }
631            send_rtu_exception(&mut socket, unit_id, 0, ExceptionCode::IllegalDataValue).await?;
632            continue;
633        }
634
635        let mut request_reader = Reader::new(request_pdu);
636        let decoded = match DecodedRequest::decode(&mut request_reader) {
637            Ok(req) if request_reader.is_empty() => req,
638            Ok(_) => {
639                #[cfg(feature = "metrics")]
640                {
641                    metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
642                    metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
643                }
644                let function = request_pdu[0] & 0x7F;
645                send_rtu_exception(
646                    &mut socket,
647                    unit_id,
648                    function,
649                    ExceptionCode::IllegalDataValue,
650                )
651                .await?;
652                continue;
653            }
654            Err(err) => {
655                #[cfg(feature = "metrics")]
656                {
657                    metrics.decode_errors.fetch_add(1, Ordering::Relaxed);
658                    metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
659                }
660                let function = request_pdu.first().copied().unwrap_or(0) & 0x7F;
661                send_rtu_exception(
662                    &mut socket,
663                    unit_id,
664                    function,
665                    map_decode_error_to_exception(err),
666                )
667                .await?;
668                continue;
669            }
670        };
671
672        debug!(
673            unit_id = unit_id.as_u8(),
674            function = decoded.function_code().as_u8(),
675            pdu_len = request_pdu.len(),
676            "received modbus rtu-over-tcp request"
677        );
678
679        match service.handle(unit_id, decoded, &mut response_pdu) {
680            Ok(response_len) => {
681                if response_len == 0 || response_len > max_pdu_len {
682                    #[cfg(feature = "metrics")]
683                    {
684                        metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
685                        metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
686                    }
687                    send_rtu_exception(
688                        &mut socket,
689                        unit_id,
690                        decoded.function_code().as_u8(),
691                        ExceptionCode::ServerDeviceFailure,
692                    )
693                    .await?;
694                    continue;
695                }
696
697                #[cfg(feature = "metrics")]
698                metrics.responses_ok.fetch_add(1, Ordering::Relaxed);
699
700                send_rtu_pdu(&mut socket, unit_id, &response_pdu[..response_len]).await?;
701            }
702            Err(ServiceError::Exception(code)) => {
703                #[cfg(feature = "metrics")]
704                metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
705
706                send_rtu_exception(&mut socket, unit_id, decoded.function_code().as_u8(), code)
707                    .await?;
708            }
709            Err(ServiceError::InvalidRequest(_)) => {
710                #[cfg(feature = "metrics")]
711                metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
712
713                send_rtu_exception(
714                    &mut socket,
715                    unit_id,
716                    decoded.function_code().as_u8(),
717                    ExceptionCode::IllegalDataValue,
718                )
719                .await?;
720            }
721            Err(ServiceError::Internal(_)) => {
722                #[cfg(feature = "metrics")]
723                {
724                    metrics.internal_errors.fetch_add(1, Ordering::Relaxed);
725                    metrics.exceptions_sent.fetch_add(1, Ordering::Relaxed);
726                }
727
728                send_rtu_exception(
729                    &mut socket,
730                    unit_id,
731                    decoded.function_code().as_u8(),
732                    ExceptionCode::ServerDeviceFailure,
733                )
734                .await?;
735            }
736        }
737    }
738}
739
740fn map_decode_error_to_exception(err: DecodeError) -> ExceptionCode {
741    match err {
742        DecodeError::InvalidFunctionCode => ExceptionCode::IllegalFunction,
743        DecodeError::InvalidLength | DecodeError::InvalidValue | DecodeError::UnexpectedEof => {
744            ExceptionCode::IllegalDataValue
745        }
746        DecodeError::InvalidCrc | DecodeError::Unsupported | DecodeError::Message(_) => {
747            ExceptionCode::ServerDeviceFailure
748        }
749        _ => ExceptionCode::ServerDeviceFailure,
750    }
751}
752
753async fn send_exception(
754    socket: &mut TcpStream,
755    transaction_id: u16,
756    unit_id: UnitId,
757    function_code: u8,
758    exception_code: ExceptionCode,
759) -> Result<(), DataLinkError> {
760    let mut pdu = [0u8; 2];
761    let mut pdu_writer = Writer::new(&mut pdu);
762    ExceptionResponse {
763        function_code,
764        exception_code,
765    }
766    .encode(&mut pdu_writer)
767    .map_err(DataLinkError::Encode)?;
768
769    send_pdu(socket, transaction_id, unit_id, pdu_writer.as_written()).await
770}
771
772async fn send_pdu(
773    socket: &mut TcpStream,
774    transaction_id: u16,
775    unit_id: UnitId,
776    pdu: &[u8],
777) -> Result<(), DataLinkError> {
778    let mut frame = vec![0u8; tcp::MBAP_HEADER_LEN + pdu.len()];
779    let mut frame_writer = Writer::new(&mut frame);
780    tcp::encode_frame(&mut frame_writer, transaction_id, unit_id, pdu)?;
781
782    debug!(
783        correlation_id = transaction_id,
784        unit_id = unit_id.as_u8(),
785        pdu_len = pdu.len(),
786        "sending modbus tcp server response"
787    );
788    socket.write_all(frame_writer.as_written()).await?;
789    Ok(())
790}
791
792async fn send_rtu_exception(
793    socket: &mut TcpStream,
794    unit_id: UnitId,
795    function_code: u8,
796    exception_code: ExceptionCode,
797) -> Result<(), DataLinkError> {
798    let mut pdu = [0u8; 2];
799    let mut pdu_writer = Writer::new(&mut pdu);
800    ExceptionResponse {
801        function_code,
802        exception_code,
803    }
804    .encode(&mut pdu_writer)
805    .map_err(DataLinkError::Encode)?;
806
807    send_rtu_pdu(socket, unit_id, pdu_writer.as_written()).await
808}
809
810async fn send_rtu_pdu(socket: &mut TcpStream, unit_id: UnitId, pdu: &[u8]) -> Result<(), DataLinkError> {
811    let mut frame = vec![0u8; pdu.len() + 3];
812    let mut writer = Writer::new(&mut frame);
813    rtu_frame::encode_frame(&mut writer, unit_id, pdu)?;
814    socket.write_all(writer.as_written()).await?;
815    Ok(())
816}
817
#[cfg(test)]
mod tests {
    use super::{ModbusRtuOverTcpServer, ModbusService, ModbusTcpServer, ServiceError};
    use crate::{DataLink, ModbusTcpTransport};
    use rustmod_core::encoding::Writer;
    use rustmod_core::frame::rtu as rtu_frame;
    use rustmod_core::pdu::{DecodedRequest, ExceptionCode};
    use rustmod_core::UnitId;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    /// Request PDU (function code 0x03) that `EchoReadService` answers.
    const READ_HOLDING_REQUEST: [u8; 5] = [0x03, 0x00, 0x00, 0x00, 0x01];

    /// Function code 0x10 request PDU that the servers are expected to fail
    /// to decode (presumably because the byte count does not match the
    /// quantity -- confirm against the decoder).
    const UNDECODABLE_WRITE_REQUEST: [u8; 9] = [0x10, 0x00, 0x00, 0x00, 0x02, 0x03, 0x12, 0x34, 0x56];

    /// Fixed response PDU produced by `EchoReadService`.
    const ECHO_REPLY: [u8; 4] = [0x03, 0x02, 0x00, 0x2A];

    /// Exception PDU expected when the undecodable 0x10 request is rejected.
    const DECODE_FAILURE_REPLY: [u8; 2] = [0x90, 0x03];

    /// Service that answers every `ReadHoldingRegisters` request with
    /// `ECHO_REPLY` and rejects everything else with `IllegalFunction`.
    struct EchoReadService;

    impl ModbusService for EchoReadService {
        fn handle(
            &self,
            _unit_id: UnitId,
            request: DecodedRequest<'_>,
            response_pdu: &mut [u8],
        ) -> Result<usize, ServiceError> {
            // Guard clause: anything other than a holding-register read is unsupported.
            if !matches!(request, DecodedRequest::ReadHoldingRegisters(_)) {
                return Err(ServiceError::Exception(ExceptionCode::IllegalFunction));
            }

            let written = ECHO_REPLY.len();
            response_pdu[..written].copy_from_slice(&ECHO_REPLY);
            Ok(written)
        }
    }

    /// Service that rejects every request with `IllegalDataAddress`.
    struct AlwaysExceptionService;

    impl ModbusService for AlwaysExceptionService {
        fn handle(
            &self,
            _unit_id: UnitId,
            _request: DecodedRequest<'_>,
            _response_pdu: &mut [u8],
        ) -> Result<usize, ServiceError> {
            let rejection = ServiceError::Exception(ExceptionCode::IllegalDataAddress);
            Err(rejection)
        }
    }

    /// A supported read over Modbus TCP returns the service's response PDU.
    #[tokio::test]
    async fn tcp_server_handles_basic_read_request() {
        let server = ModbusTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(server.run());

        let client = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut reply = [0u8; 32];
        let reply_len = client
            .exchange(UnitId::new(1), &READ_HOLDING_REQUEST, &mut reply)
            .await
            .unwrap();
        assert_eq!(&reply[..reply_len], &ECHO_REPLY);

        server_task.abort();
        let _ = server_task.await;
    }

    /// A service-level exception is delivered to the client as an exception PDU.
    #[tokio::test]
    async fn tcp_server_sends_exception_response() {
        let server = ModbusTcpServer::bind("127.0.0.1:0", AlwaysExceptionService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(server.run());

        let client = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut reply = [0u8; 32];
        let reply_len = client
            .exchange(UnitId::new(1), &READ_HOLDING_REQUEST, &mut reply)
            .await
            .unwrap();
        assert_eq!(&reply[..reply_len], &[0x83, 0x02]);

        server_task.abort();
        let _ = server_task.await;
    }

    /// A request the TCP server cannot decode is answered with an exception PDU.
    #[tokio::test]
    async fn tcp_server_maps_decode_error_to_exception() {
        let server = ModbusTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(server.run());

        let client = ModbusTcpTransport::connect(addr).await.unwrap();
        let mut reply = [0u8; 32];
        let reply_len = client
            .exchange(UnitId::new(1), &UNDECODABLE_WRITE_REQUEST, &mut reply)
            .await
            .unwrap();
        assert_eq!(&reply[..reply_len], &DECODE_FAILURE_REPLY);

        server_task.abort();
        let _ = server_task.await;
    }

    /// A supported read sent as a raw RTU frame over TCP gets an RTU-framed reply.
    #[tokio::test]
    async fn rtu_over_tcp_server_handles_basic_read_request() {
        let server = ModbusRtuOverTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(server.run());

        let mut client = TcpStream::connect(addr).await.unwrap();
        let mut request_buf = [0u8; 16];
        let mut frame_writer = Writer::new(&mut request_buf);
        rtu_frame::encode_frame(&mut frame_writer, UnitId::new(1), &READ_HOLDING_REQUEST).unwrap();
        client.write_all(frame_writer.as_written()).await.unwrap();

        // The reply frame is read as exactly 7 bytes: the 4-byte PDU plus framing.
        let mut reply_frame = [0u8; 7];
        client.read_exact(&mut reply_frame).await.unwrap();
        let (reply_unit, reply_pdu) = rtu_frame::decode_frame(&reply_frame).unwrap();
        assert_eq!(reply_unit, UnitId::new(1));
        assert_eq!(reply_pdu, &ECHO_REPLY);

        server_task.abort();
        let _ = server_task.await;
    }

    /// A request the RTU-over-TCP server cannot decode is answered with an
    /// RTU-framed exception PDU.
    #[tokio::test]
    async fn rtu_over_tcp_server_maps_decode_error_to_exception() {
        let server = ModbusRtuOverTcpServer::bind("127.0.0.1:0", EchoReadService)
            .await
            .unwrap();
        let addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(server.run());

        let mut client = TcpStream::connect(addr).await.unwrap();
        let mut request_buf = [0u8; 32];
        let mut frame_writer = Writer::new(&mut request_buf);
        rtu_frame::encode_frame(&mut frame_writer, UnitId::new(1), &UNDECODABLE_WRITE_REQUEST).unwrap();
        client.write_all(frame_writer.as_written()).await.unwrap();

        // The reply frame is read as exactly 5 bytes: the 2-byte exception PDU plus framing.
        let mut reply_frame = [0u8; 5];
        client.read_exact(&mut reply_frame).await.unwrap();
        let (reply_unit, reply_pdu) = rtu_frame::decode_frame(&reply_frame).unwrap();
        assert_eq!(reply_unit, UnitId::new(1));
        assert_eq!(reply_pdu, &DECODE_FAILURE_REPLY);

        server_task.abort();
        let _ = server_task.await;
    }
}