Skip to main content

rusty_modbus_frame/
mbap.rs

//! MBAP codec for Modbus/TCP framing.
//!
//! Implements [`tokio_util::codec::Decoder`] and [`tokio_util::codec::Encoder`]
//! over the 7-byte MBAP header defined in the Modbus TCP/IP Implementation Guide.
//! Decoding is zero-copy: the PDU is returned as a frozen [`bytes::Bytes`] handle
//! sliced directly from the read buffer.
7
8use bytes::{BufMut, Bytes, BytesMut};
9use rusty_modbus_types::{MAX_PDU_SIZE, MBAP_HEADER_LEN, MODBUS_PROTOCOL_ID, MbapHeader};
10use tokio_util::codec::{Decoder, Encoder};
11use zerocopy::{FromBytes, IntoBytes};
12
13use crate::error::FrameError;
14use crate::frame::{Frame, FrameHeader};
15
16/// Maximum value of the MBAP length field: `MAX_PDU_SIZE` (253) + 1 byte for unit ID.
17#[allow(clippy::cast_possible_truncation)]
18const MAX_MBAP_LENGTH: u16 = MAX_PDU_SIZE as u16 + 1;
19/// Minimum MBAP length field: Unit Identifier (1) + function code (1).
20const MIN_MBAP_LENGTH: u16 = 2;
21
/// MBAP codec for Modbus/TCP framing.
///
/// Handles the 7-byte MBAP header (transaction ID, protocol ID, length, unit ID)
/// followed by the PDU (function code + data). The decoder validates the protocol
/// identifier and length field before yielding a [`Frame`]; the encoders validate an
/// outgoing header against its PDU before writing anything.
///
/// The codec carries no state, so it is cheap to copy and any two instances are
/// interchangeable.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MbapCodec;
29
30impl Decoder for MbapCodec {
31    type Item = Frame;
32    type Error = FrameError;
33
34    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
35        // Step 1: need at least the MBAP header to proceed.
36        if src.len() < MBAP_HEADER_LEN {
37            return Ok(None);
38        }
39
40        // Step 2: peek at the header fields via zero-copy overlay.
41        let header = *MbapHeader::ref_from_bytes(&src[..MBAP_HEADER_LEN])
42            .map_err(|_| FrameError::Truncated)?;
43
44        // Step 3: validate protocol identifier.
45        let proto = header.protocol_id.get();
46        if proto != MODBUS_PROTOCOL_ID {
47            return Err(FrameError::InvalidProtocolId(proto));
48        }
49
50        // Step 4: validate length field bounds.
51        // The length field includes the unit ID byte, so the PDU portion is length - 1.
52        // Minimum: 2 (unit ID + one-byte function code). A value of 0 or 1 would
53        // leave no complete MODBUS PDU, which always starts with a function code.
54        // Maximum allowed: MAX_PDU_SIZE (253) bytes of PDU + 1 byte unit ID = 254.
55        let length = header.length.get();
56        if length < MIN_MBAP_LENGTH {
57            return Err(FrameError::InvalidLength {
58                declared: length,
59                minimum: MIN_MBAP_LENGTH,
60            });
61        }
62        if length > MAX_MBAP_LENGTH {
63            return Err(FrameError::LengthOverflow(length));
64        }
65
66        // Step 5: compute total ADU size.
67        // The length field counts bytes after itself: unit_id(1) + PDU.
68        // Bytes before the length-field payload: txn_id(2) + proto_id(2) + length(2) = 6.
69        let total = (MBAP_HEADER_LEN - 1) + length as usize;
70
71        // Step 6: wait for the complete frame.
72        if src.len() < total {
73            src.reserve(total - src.len());
74            return Ok(None);
75        }
76
77        // Step 7: split the complete ADU from the buffer — O(1), no copy.
78        let adu = src.split_to(total).freeze();
79
80        // Step 8: slice the PDU (function code + data) — zero-copy via Bytes::slice.
81        let pdu = adu.slice(MBAP_HEADER_LEN..);
82
83        // Step 9: return the decoded frame.
84        Ok(Some(Frame {
85            header: FrameHeader::Mbap(header),
86            pdu,
87        }))
88    }
89}
90
91impl Encoder<Frame> for MbapCodec {
92    type Error = FrameError;
93
94    fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
95        let header = match frame.header {
96            FrameHeader::Mbap(h) => h,
97            FrameHeader::Rtu { .. } => {
98                return Err(FrameError::InvalidProtocolId(0xFFFF));
99            }
100        };
101        validate_outgoing_header(header, frame.pdu.len())?;
102
103        dst.reserve(MBAP_HEADER_LEN + frame.pdu.len());
104        dst.put_slice(header.as_bytes());
105        dst.put_slice(&frame.pdu);
106
107        Ok(())
108    }
109}
110
111/// Convenience [`Encoder`] implementation for `(MbapHeader, Bytes)` tuples.
112///
113/// This allows callers to encode a header and PDU without constructing a full [`Frame`].
114impl Encoder<(MbapHeader, Bytes)> for MbapCodec {
115    type Error = FrameError;
116
117    fn encode(&mut self, item: (MbapHeader, Bytes), dst: &mut BytesMut) -> Result<(), Self::Error> {
118        let (header, pdu) = item;
119        validate_outgoing_header(header, pdu.len())?;
120
121        dst.reserve(MBAP_HEADER_LEN + pdu.len());
122        dst.put_slice(header.as_bytes());
123        dst.put_slice(&pdu);
124
125        Ok(())
126    }
127}
128
129fn validate_outgoing_header(header: MbapHeader, pdu_len: usize) -> Result<(), FrameError> {
130    let proto = header.protocol_id.get();
131    if proto != MODBUS_PROTOCOL_ID {
132        return Err(FrameError::InvalidProtocolId(proto));
133    }
134    if pdu_len > MAX_PDU_SIZE {
135        return Err(FrameError::LengthOverflow(
136            u16::try_from(pdu_len).unwrap_or(u16::MAX),
137        ));
138    }
139
140    let actual = u16::try_from(pdu_len + 1).expect("MAX_PDU_SIZE guarantees u16 length");
141    if actual < MIN_MBAP_LENGTH {
142        return Err(FrameError::InvalidLength {
143            declared: actual,
144            minimum: MIN_MBAP_LENGTH,
145        });
146    }
147
148    let declared = header.length.get();
149    if declared < MIN_MBAP_LENGTH {
150        return Err(FrameError::InvalidLength {
151            declared,
152            minimum: MIN_MBAP_LENGTH,
153        });
154    }
155    if declared > MAX_MBAP_LENGTH {
156        return Err(FrameError::LengthOverflow(declared));
157    }
158    if declared != actual {
159        return Err(FrameError::LengthMismatch { declared, actual });
160    }
161
162    Ok(())
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use zerocopy::network_endian::U16;

    /// Build a valid MBAP frame: 7-byte header + PDU bytes.
    fn build_frame(txn_id: u16, unit_id: u8, pdu: &[u8]) -> Vec<u8> {
        let header = MbapHeader::new(txn_id, unit_id, u16::try_from(pdu.len()).unwrap());
        let mut buf = Vec::with_capacity(MBAP_HEADER_LEN + pdu.len());
        buf.extend_from_slice(header.as_bytes());
        buf.extend_from_slice(pdu);
        buf
    }

    /// Overwrite the big-endian MBAP length field (bytes 4..6) of a raw frame.
    fn set_length_field(raw: &mut [u8], length: u16) {
        raw[4..6].copy_from_slice(&length.to_be_bytes());
    }

    #[test]
    fn decode_valid_frame() {
        let pdu = [0x03, 0x00, 0x00, 0x00, 0x0A]; // FC 0x03, start 0, qty 10
        let raw = build_frame(1, 0xFF, &pdu);

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("should decode a frame");
        assert_eq!(frame.unit_id(), 0xFF);
        assert_eq!(frame.pdu.as_ref(), &pdu);
        assert!(buf.is_empty(), "buffer should be fully consumed");
    }

    #[test]
    fn decode_accepts_minimum_length_frame() {
        // length = 2: unit ID plus a bare function code with no data.
        let raw = build_frame(9, 4, &[0x07]);

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("minimal frame should decode");
        assert_eq!(frame.unit_id(), 4);
        assert_eq!(frame.pdu.as_ref(), &[0x07_u8]);
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_accepts_maximum_length_frame() {
        // Largest legal PDU: MAX_PDU_SIZE bytes, so the length field equals MAX_MBAP_LENGTH.
        let mut pdu = vec![0_u8; MAX_PDU_SIZE];
        pdu[0] = 0x10;
        let raw = build_frame(3, 1, &pdu);

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("max-size frame should decode");
        assert_eq!(frame.pdu.len(), MAX_PDU_SIZE);
        assert_eq!(frame.pdu.as_ref(), &pdu[..]);
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_returns_none_on_partial_header() {
        let mut buf = BytesMut::from(&[0x00, 0x01, 0x00, 0x00][..]);
        let mut codec = MbapCodec;

        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn decode_returns_none_on_partial_body() {
        let pdu = [0x03, 0x00, 0x00, 0x00, 0x0A];
        let raw = build_frame(1, 1, &pdu);

        // Truncate 2 bytes from the end.
        let mut buf = BytesMut::from(&raw[..raw.len() - 2]);
        let mut codec = MbapCodec;

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), raw.len() - 2, "partial frame must not be consumed");
    }

    #[test]
    fn decode_invalid_protocol_id() {
        let mut raw = build_frame(1, 1, &[0x03]);
        // Corrupt protocol ID to 0x0001.
        raw[2] = 0x00;
        raw[3] = 0x01;

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidProtocolId(1)));
        assert_eq!(buf.len(), raw.len(), "rejected frame must not be consumed");
    }

    #[test]
    fn decode_length_overflow() {
        let mut raw = build_frame(1, 1, &[0x03]);
        // One past the largest legal length field: MAX_PDU_SIZE + 2 = 255.
        let overflow_len = MAX_MBAP_LENGTH + 1;
        set_length_field(&mut raw, overflow_len);

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::LengthOverflow(len) if len == overflow_len));
    }

    #[test]
    fn decode_multiple_frames() {
        let pdu1 = [0x03, 0x01];
        let pdu2 = [0x06, 0x02, 0x03];
        let mut raw = build_frame(1, 1, &pdu1);
        raw.extend_from_slice(&build_frame(2, 2, &pdu2));

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let f1 = codec.decode(&mut buf).unwrap().expect("frame 1");
        assert_eq!(f1.unit_id(), 1);
        assert_eq!(f1.pdu.as_ref(), &pdu1);

        let f2 = codec.decode(&mut buf).unwrap().expect("frame 2");
        assert_eq!(f2.unit_id(), 2);
        assert_eq!(f2.pdu.as_ref(), &pdu2);

        assert!(buf.is_empty());
    }

    #[test]
    fn encode_roundtrip() {
        let pdu = Bytes::from_static(&[0x03, 0x00, 0x00, 0x00, 0x0A]);
        let header = MbapHeader::new(42, 0xFF, u16::try_from(pdu.len()).unwrap());
        let frame = Frame {
            header: FrameHeader::Mbap(header),
            pdu: pdu.clone(),
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        codec.encode(frame, &mut buf).unwrap();

        // Decode it back.
        let decoded = codec.decode(&mut buf).unwrap().expect("should decode");
        assert_eq!(decoded.unit_id(), 0xFF);
        assert_eq!(decoded.pdu.as_ref(), &pdu[..]);

        match decoded.header {
            FrameHeader::Mbap(h) => {
                assert_eq!(h.transaction_id.get(), 42);
            }
            FrameHeader::Rtu { .. } => panic!("expected MBAP header"),
        }
    }

    #[test]
    fn encode_tuple_form() {
        let pdu = Bytes::from_static(&[0x01, 0x00, 0x0A, 0x00, 0x0D]);
        let header = MbapHeader::new(7, 1, u16::try_from(pdu.len()).unwrap());

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        codec.encode((header, pdu.clone()), &mut buf).unwrap();

        // Decode it back.
        let decoded = codec.decode(&mut buf).unwrap().expect("should decode");
        assert_eq!(decoded.unit_id(), 1);
        assert_eq!(decoded.pdu.as_ref(), &pdu[..]);
    }

    #[test]
    fn encode_rtu_frame_errors() {
        let frame = Frame {
            header: FrameHeader::Rtu { unit_id: 1 },
            pdu: Bytes::from_static(&[0x03]),
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        let result = codec.encode(frame, &mut buf);
        assert!(matches!(result, Err(FrameError::InvalidProtocolId(0xFFFF))));
        assert!(buf.is_empty(), "nothing should be written on error");
    }

    #[test]
    fn encode_rejects_oversized_pdu() {
        let pdu = Bytes::from(vec![0_u8; MAX_PDU_SIZE + 1]);
        let header = MbapHeader {
            transaction_id: U16::new(1),
            protocol_id: U16::new(MODBUS_PROTOCOL_ID),
            length: U16::new(MAX_MBAP_LENGTH + 1),
            unit_id: 1,
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;

        let err = codec.encode((header, pdu), &mut buf).unwrap_err();
        assert!(matches!(err, FrameError::LengthOverflow(_)));
        assert!(buf.is_empty(), "nothing should be written on error");
    }

    #[test]
    fn decode_rejects_zero_length_pdu() {
        // length = 1 means unit_id only, leaving no function-code byte.
        let header = MbapHeader {
            transaction_id: U16::new(0),
            protocol_id: U16::new(0),
            length: U16::new(1),
            unit_id: 1,
        };
        let mut buf = BytesMut::from(header.as_bytes());
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn decode_length_zero_returns_error() {
        // length = 0 is invalid: even the unit_id byte wouldn't fit.
        let header = MbapHeader {
            transaction_id: U16::new(0),
            protocol_id: U16::new(0),
            length: U16::new(0),
            unit_id: 1,
        };
        let mut buf = BytesMut::from(header.as_bytes());
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn encode_rejects_zero_length_pdu() {
        let frame = Frame {
            header: FrameHeader::Mbap(MbapHeader::new(1, 1, 0)),
            pdu: Bytes::new(),
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;

        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn encode_rejects_length_mismatch() {
        let pdu = Bytes::from_static(&[0x03, 0x00, 0x00, 0x00, 0x01]);
        let header = MbapHeader::new(1, 1, 3);
        let frame = Frame {
            header: FrameHeader::Mbap(header),
            pdu,
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;

        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert!(matches!(
            err,
            FrameError::LengthMismatch {
                declared: 4,
                actual: 6
            }
        ));
    }
}