1use bytes::{BufMut, Bytes, BytesMut};
9use rusty_modbus_types::{MAX_PDU_SIZE, MBAP_HEADER_LEN, MODBUS_PROTOCOL_ID, MbapHeader};
10use tokio_util::codec::{Decoder, Encoder};
11use zerocopy::{FromBytes, IntoBytes};
12
13use crate::error::FrameError;
14use crate::frame::{Frame, FrameHeader};
15
16#[allow(clippy::cast_possible_truncation)]
18const MAX_MBAP_LENGTH: u16 = MAX_PDU_SIZE as u16 + 1;
19const MIN_MBAP_LENGTH: u16 = 2;
21
/// Stateless codec that splits a byte stream into MBAP-framed Modbus ADUs
/// (header followed by PDU) and serialises frames back into that form.
///
/// The codec holds no state, so it is cheap to copy and any instance can be
/// used for any connection.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct MbapCodec;
29
30impl Decoder for MbapCodec {
31 type Item = Frame;
32 type Error = FrameError;
33
34 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
35 if src.len() < MBAP_HEADER_LEN {
37 return Ok(None);
38 }
39
40 let header = *MbapHeader::ref_from_bytes(&src[..MBAP_HEADER_LEN])
42 .map_err(|_| FrameError::Truncated)?;
43
44 let proto = header.protocol_id.get();
46 if proto != MODBUS_PROTOCOL_ID {
47 return Err(FrameError::InvalidProtocolId(proto));
48 }
49
50 let length = header.length.get();
56 if length < MIN_MBAP_LENGTH {
57 return Err(FrameError::InvalidLength {
58 declared: length,
59 minimum: MIN_MBAP_LENGTH,
60 });
61 }
62 if length > MAX_MBAP_LENGTH {
63 return Err(FrameError::LengthOverflow(length));
64 }
65
66 let total = (MBAP_HEADER_LEN - 1) + length as usize;
70
71 if src.len() < total {
73 src.reserve(total - src.len());
74 return Ok(None);
75 }
76
77 let adu = src.split_to(total).freeze();
79
80 let pdu = adu.slice(MBAP_HEADER_LEN..);
82
83 Ok(Some(Frame {
85 header: FrameHeader::Mbap(header),
86 pdu,
87 }))
88 }
89}
90
91impl Encoder<Frame> for MbapCodec {
92 type Error = FrameError;
93
94 fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
95 let header = match frame.header {
96 FrameHeader::Mbap(h) => h,
97 FrameHeader::Rtu { .. } => {
98 return Err(FrameError::InvalidProtocolId(0xFFFF));
99 }
100 };
101 validate_outgoing_header(header, frame.pdu.len())?;
102
103 dst.reserve(MBAP_HEADER_LEN + frame.pdu.len());
104 dst.put_slice(header.as_bytes());
105 dst.put_slice(&frame.pdu);
106
107 Ok(())
108 }
109}
110
111impl Encoder<(MbapHeader, Bytes)> for MbapCodec {
115 type Error = FrameError;
116
117 fn encode(&mut self, item: (MbapHeader, Bytes), dst: &mut BytesMut) -> Result<(), Self::Error> {
118 let (header, pdu) = item;
119 validate_outgoing_header(header, pdu.len())?;
120
121 dst.reserve(MBAP_HEADER_LEN + pdu.len());
122 dst.put_slice(header.as_bytes());
123 dst.put_slice(&pdu);
124
125 Ok(())
126 }
127}
128
129fn validate_outgoing_header(header: MbapHeader, pdu_len: usize) -> Result<(), FrameError> {
130 let proto = header.protocol_id.get();
131 if proto != MODBUS_PROTOCOL_ID {
132 return Err(FrameError::InvalidProtocolId(proto));
133 }
134 if pdu_len > MAX_PDU_SIZE {
135 return Err(FrameError::LengthOverflow(
136 u16::try_from(pdu_len).unwrap_or(u16::MAX),
137 ));
138 }
139
140 let actual = u16::try_from(pdu_len + 1).expect("MAX_PDU_SIZE guarantees u16 length");
141 if actual < MIN_MBAP_LENGTH {
142 return Err(FrameError::InvalidLength {
143 declared: actual,
144 minimum: MIN_MBAP_LENGTH,
145 });
146 }
147
148 let declared = header.length.get();
149 if declared < MIN_MBAP_LENGTH {
150 return Err(FrameError::InvalidLength {
151 declared,
152 minimum: MIN_MBAP_LENGTH,
153 });
154 }
155 if declared > MAX_MBAP_LENGTH {
156 return Err(FrameError::LengthOverflow(declared));
157 }
158 if declared != actual {
159 return Err(FrameError::LengthMismatch { declared, actual });
160 }
161
162 Ok(())
163}
164
#[cfg(test)]
mod tests {
    // `super::*` already brings `Bytes`, `BytesMut`, the codec traits and the
    // private constants into scope.
    use super::*;
    use zerocopy::network_endian::U16;

    /// Serialises a complete ADU: the header produced by `MbapHeader::new`
    /// followed by `pdu`.
    fn build_frame(txn_id: u16, unit_id: u8, pdu: &[u8]) -> Vec<u8> {
        let header = MbapHeader::new(txn_id, unit_id, u16::try_from(pdu.len()).unwrap());
        let mut buf = Vec::with_capacity(MBAP_HEADER_LEN + pdu.len());
        buf.extend_from_slice(header.as_bytes());
        buf.extend_from_slice(pdu);
        buf
    }

    #[test]
    fn decode_valid_frame() {
        let pdu = [0x03, 0x00, 0x00, 0x00, 0x0A];
        let raw = build_frame(1, 0xFF, &pdu);

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("should decode a frame");
        assert_eq!(frame.unit_id(), 0xFF);
        assert_eq!(frame.pdu.as_ref(), &pdu);
        assert!(buf.is_empty(), "buffer should be fully consumed");
    }

    #[test]
    fn decode_returns_none_on_partial_header() {
        let mut buf = BytesMut::from(&[0x00, 0x01, 0x00, 0x00][..]);
        let mut codec = MbapCodec;

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), 4, "incomplete input must not be consumed");
    }

    #[test]
    fn decode_returns_none_on_partial_body() {
        let pdu = [0x03, 0x00, 0x00, 0x00, 0x0A];
        let raw = build_frame(1, 1, &pdu);

        // Withhold the last two PDU bytes so the body is incomplete.
        let split = raw.len() - 2;
        let mut buf = BytesMut::from(&raw[..split]);
        let mut codec = MbapCodec;

        assert!(codec.decode(&mut buf).unwrap().is_none());
        assert_eq!(buf.len(), split, "incomplete input must not be consumed");

        // Once the rest of the body arrives the frame decodes normally.
        buf.extend_from_slice(&raw[split..]);
        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("frame completes once the body arrives");
        assert_eq!(frame.pdu.as_ref(), &pdu);
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_invalid_protocol_id() {
        let mut raw = build_frame(1, 1, &[0x03]);
        // Bytes 2..4 hold the big-endian protocol identifier.
        raw[2..4].copy_from_slice(&1_u16.to_be_bytes());

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidProtocolId(1)));
    }

    #[test]
    fn decode_length_overflow() {
        let mut raw = build_frame(1, 1, &[0x03]);
        // Bytes 4..6 hold the big-endian length field; set it one past the limit.
        let overflow_len = MAX_MBAP_LENGTH + 1;
        raw[4..6].copy_from_slice(&overflow_len.to_be_bytes());

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::LengthOverflow(n) if n == overflow_len));
    }

    #[test]
    fn decode_accepts_maximum_length() {
        let pdu = vec![0x03_u8; MAX_PDU_SIZE];
        let raw = build_frame(1, 1, &pdu);

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("a maximum-size frame is valid");
        assert_eq!(frame.pdu.len(), MAX_PDU_SIZE);
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_multiple_frames() {
        let pdu1 = [0x03, 0x01];
        let pdu2 = [0x06, 0x02, 0x03];
        let mut raw = build_frame(1, 1, &pdu1);
        raw.extend_from_slice(&build_frame(2, 2, &pdu2));

        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;

        let f1 = codec.decode(&mut buf).unwrap().expect("frame 1");
        assert_eq!(f1.unit_id(), 1);
        assert_eq!(f1.pdu.as_ref(), &pdu1);

        let f2 = codec.decode(&mut buf).unwrap().expect("frame 2");
        assert_eq!(f2.unit_id(), 2);
        assert_eq!(f2.pdu.as_ref(), &pdu2);

        assert!(buf.is_empty());
    }

    #[test]
    fn encode_roundtrip() {
        let pdu = Bytes::from_static(&[0x03, 0x00, 0x00, 0x00, 0x0A]);
        let header = MbapHeader::new(42, 0xFF, u16::try_from(pdu.len()).unwrap());
        let frame = Frame {
            header: FrameHeader::Mbap(header),
            pdu: pdu.clone(),
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        codec.encode(frame, &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().expect("should decode");
        assert_eq!(decoded.unit_id(), 0xFF);
        assert_eq!(decoded.pdu.as_ref(), &pdu[..]);

        match decoded.header {
            FrameHeader::Mbap(h) => {
                assert_eq!(h.transaction_id.get(), 42);
            }
            FrameHeader::Rtu { .. } => panic!("expected MBAP header"),
        }
    }

    #[test]
    fn encode_tuple_form() {
        let pdu = Bytes::from_static(&[0x01, 0x00, 0x0A, 0x00, 0x0D]);
        let header = MbapHeader::new(7, 1, u16::try_from(pdu.len()).unwrap());

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        codec.encode((header, pdu.clone()), &mut buf).unwrap();

        let decoded = codec.decode(&mut buf).unwrap().expect("should decode");
        assert_eq!(decoded.unit_id(), 1);
        assert_eq!(decoded.pdu.as_ref(), &pdu[..]);
    }

    #[test]
    fn encode_rtu_frame_errors() {
        let frame = Frame {
            header: FrameHeader::Rtu { unit_id: 1 },
            pdu: Bytes::from_static(&[0x03]),
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        assert!(codec.encode(frame, &mut buf).is_err());
        assert!(buf.is_empty(), "nothing may be written for a rejected frame");
    }

    #[test]
    fn decode_rejects_zero_length_pdu() {
        let header = MbapHeader {
            transaction_id: U16::new(0),
            protocol_id: U16::new(0),
            length: U16::new(1),
            unit_id: 1,
        };
        let mut buf = BytesMut::from(header.as_bytes());
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn decode_length_zero_returns_error() {
        let header = MbapHeader {
            transaction_id: U16::new(0),
            protocol_id: U16::new(0),
            length: U16::new(0),
            unit_id: 1,
        };
        let mut buf = BytesMut::from(header.as_bytes());
        let mut codec = MbapCodec;

        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn encode_rejects_zero_length_pdu() {
        let frame = Frame {
            header: FrameHeader::Mbap(MbapHeader::new(1, 1, 0)),
            pdu: Bytes::new(),
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;

        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
        assert!(buf.is_empty(), "nothing may be written for a rejected frame");
    }

    #[test]
    fn encode_rejects_oversized_pdu() {
        let pdu = Bytes::from(vec![0x03_u8; MAX_PDU_SIZE + 1]);
        let header = MbapHeader {
            transaction_id: U16::new(1),
            protocol_id: U16::new(MODBUS_PROTOCOL_ID),
            length: U16::new(MAX_MBAP_LENGTH + 1),
            unit_id: 1,
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;

        let err = codec.encode((header, pdu), &mut buf).unwrap_err();
        assert!(matches!(err, FrameError::LengthOverflow(_)));
        assert!(buf.is_empty(), "nothing may be written for a rejected frame");
    }

    #[test]
    fn encode_rejects_length_mismatch() {
        let pdu = Bytes::from_static(&[0x03, 0x00, 0x00, 0x00, 0x01]);
        let header = MbapHeader::new(1, 1, 3);
        let frame = Frame {
            header: FrameHeader::Mbap(header),
            pdu,
        };

        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;

        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert!(matches!(
            err,
            FrameError::LengthMismatch {
                declared: 4,
                actual: 6
            }
        ));
        assert!(buf.is_empty(), "nothing may be written for a rejected frame");
    }
}