use bytes::{BufMut, Bytes, BytesMut};
use rusty_modbus_types::{MAX_PDU_SIZE, MBAP_HEADER_LEN, MODBUS_PROTOCOL_ID, MbapHeader};
use tokio_util::codec::{Decoder, Encoder};
use zerocopy::{FromBytes, IntoBytes};
use crate::error::FrameError;
use crate::frame::{Frame, FrameHeader};
#[allow(clippy::cast_possible_truncation)]
const MAX_MBAP_LENGTH: u16 = MAX_PDU_SIZE as u16 + 1;
const MIN_MBAP_LENGTH: u16 = 2;
/// A `tokio_util` codec that frames Modbus PDUs behind an MBAP header.
///
/// The codec carries no state, so it is `Copy`: the same value can be handed
/// to as many readers/writers as needed.
#[derive(Debug, Default, Clone, Copy)]
pub struct MbapCodec;
impl Decoder for MbapCodec {
type Item = Frame;
type Error = FrameError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if src.len() < MBAP_HEADER_LEN {
return Ok(None);
}
let header = *MbapHeader::ref_from_bytes(&src[..MBAP_HEADER_LEN])
.map_err(|_| FrameError::Truncated)?;
let proto = header.protocol_id.get();
if proto != MODBUS_PROTOCOL_ID {
return Err(FrameError::InvalidProtocolId(proto));
}
let length = header.length.get();
if length < MIN_MBAP_LENGTH {
return Err(FrameError::InvalidLength {
declared: length,
minimum: MIN_MBAP_LENGTH,
});
}
if length > MAX_MBAP_LENGTH {
return Err(FrameError::LengthOverflow(length));
}
let total = (MBAP_HEADER_LEN - 1) + length as usize;
if src.len() < total {
src.reserve(total - src.len());
return Ok(None);
}
let adu = src.split_to(total).freeze();
let pdu = adu.slice(MBAP_HEADER_LEN..);
Ok(Some(Frame {
header: FrameHeader::Mbap(header),
pdu,
}))
}
}
impl Encoder<Frame> for MbapCodec {
type Error = FrameError;
fn encode(&mut self, frame: Frame, dst: &mut BytesMut) -> Result<(), Self::Error> {
let header = match frame.header {
FrameHeader::Mbap(h) => h,
FrameHeader::Rtu { .. } => {
return Err(FrameError::InvalidProtocolId(0xFFFF));
}
};
validate_outgoing_header(header, frame.pdu.len())?;
dst.reserve(MBAP_HEADER_LEN + frame.pdu.len());
dst.put_slice(header.as_bytes());
dst.put_slice(&frame.pdu);
Ok(())
}
}
impl Encoder<(MbapHeader, Bytes)> for MbapCodec {
    type Error = FrameError;

    /// Writes a raw `(header, pdu)` pair to `dst` as one MBAP ADU.
    ///
    /// The header is validated against the PDU before anything is written, so
    /// `dst` is left untouched when an error is returned.
    ///
    /// # Errors
    ///
    /// Propagates the validation failure in any of these cases:
    /// - the protocol id is wrong;
    /// - the PDU length or the declared length is out of bounds;
    /// - the declared length disagrees with the PDU.
    fn encode(&mut self, item: (MbapHeader, Bytes), dst: &mut BytesMut) -> Result<(), Self::Error> {
        let (mbap, payload) = item;
        let payload_len = payload.len();
        validate_outgoing_header(mbap, payload_len)?;

        // Reserve once for the whole ADU, then append header and PDU in order.
        dst.reserve(MBAP_HEADER_LEN + payload_len);
        dst.put_slice(mbap.as_bytes());
        dst.put_slice(&payload);
        Ok(())
    }
}
/// Checks that an outgoing MBAP `header` is consistent with a PDU of
/// `pdu_len` bytes, before anything is written.
///
/// The MBAP `length` field counts the unit-identifier byte plus the PDU, so
/// the expected value is `pdu_len + 1`.
///
/// # Errors
///
/// - [`FrameError::InvalidProtocolId`] if the header's protocol identifier is
///   not [`MODBUS_PROTOCOL_ID`].
/// - [`FrameError::LengthOverflow`] in two cases:
///   - the PDU is longer than [`MAX_PDU_SIZE`]. The error carries the MBAP
///     length such a PDU would need, saturated to `u16::MAX`.
///   - the declared length exceeds the MBAP maximum. The error carries the
///     declared value.
/// - [`FrameError::InvalidLength`] if the PDU is empty or the declared length
///   is below the minimum.
/// - [`FrameError::LengthMismatch`] if the declared length differs from
///   `pdu_len + 1`.
fn validate_outgoing_header(header: MbapHeader, pdu_len: usize) -> Result<(), FrameError> {
    let proto = header.protocol_id.get();
    if proto != MODBUS_PROTOCOL_ID {
        return Err(FrameError::InvalidProtocolId(proto));
    }

    // The MBAP length this PDU requires: unit-id byte + PDU. `LengthOverflow`
    // elsewhere (including the decoder) carries an MBAP length-field value,
    // so report the same unit here rather than the bare PDU length.
    let required = pdu_len.saturating_add(1);
    if pdu_len > MAX_PDU_SIZE {
        return Err(FrameError::LengthOverflow(
            u16::try_from(required).unwrap_or(u16::MAX),
        ));
    }
    let actual = u16::try_from(required).expect("MAX_PDU_SIZE guarantees u16 length");
    if actual < MIN_MBAP_LENGTH {
        // With MIN_MBAP_LENGTH == 2 this is the empty-PDU case (no function code).
        return Err(FrameError::InvalidLength {
            declared: actual,
            minimum: MIN_MBAP_LENGTH,
        });
    }

    let declared = header.length.get();
    if declared < MIN_MBAP_LENGTH {
        return Err(FrameError::InvalidLength {
            declared,
            minimum: MIN_MBAP_LENGTH,
        });
    }
    if declared > MAX_MBAP_LENGTH {
        return Err(FrameError::LengthOverflow(declared));
    }
    if declared != actual {
        return Err(FrameError::LengthMismatch { declared, actual });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use zerocopy::network_endian::U16;

    /// Serialises the header produced by `MbapHeader::new` for `pdu.len()`,
    /// followed by `pdu` itself.
    fn build_frame(txn_id: u16, unit_id: u8, pdu: &[u8]) -> Vec<u8> {
        let header = MbapHeader::new(txn_id, unit_id, u16::try_from(pdu.len()).unwrap());
        let mut buf = Vec::with_capacity(MBAP_HEADER_LEN + pdu.len());
        buf.extend_from_slice(header.as_bytes());
        buf.extend_from_slice(pdu);
        buf
    }

    #[test]
    fn decode_valid_frame() {
        let pdu = [0x03, 0x00, 0x00, 0x00, 0x0A];
        let raw = build_frame(1, 0xFF, &pdu);
        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;
        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("should decode a frame");
        assert_eq!(frame.unit_id(), 0xFF);
        assert_eq!(frame.pdu.as_ref(), &pdu);
        assert!(buf.is_empty(), "buffer should be fully consumed");
    }

    #[test]
    fn decode_returns_none_on_partial_header() {
        let mut buf = BytesMut::from(&[0x00, 0x01, 0x00, 0x00][..]);
        let mut codec = MbapCodec;
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn decode_returns_none_on_partial_body() {
        let pdu = [0x03, 0x00, 0x00, 0x00, 0x0A];
        let raw = build_frame(1, 1, &pdu);
        let mut buf = BytesMut::from(&raw[..raw.len() - 2]);
        let mut codec = MbapCodec;
        assert!(codec.decode(&mut buf).unwrap().is_none());
    }

    #[test]
    fn decode_frame_delivered_byte_by_byte() {
        let pdu = [0x03, 0x00, 0x00, 0x00, 0x0A];
        let raw = build_frame(9, 3, &pdu);
        let (last, init) = raw.split_last().expect("frame is non-empty");
        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        for &byte in init {
            buf.extend_from_slice(&[byte]);
            assert!(codec.decode(&mut buf).unwrap().is_none());
        }
        buf.extend_from_slice(&[*last]);
        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("complete frame");
        assert_eq!(frame.unit_id(), 3);
        assert_eq!(frame.pdu.as_ref(), &pdu);
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_accepts_max_size_pdu() {
        let pdu = vec![0x03_u8; MAX_PDU_SIZE];
        let raw = build_frame(1, 1, &pdu);
        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;
        let frame = codec
            .decode(&mut buf)
            .unwrap()
            .expect("max-size frame");
        assert_eq!(frame.pdu.len(), MAX_PDU_SIZE);
        assert!(buf.is_empty());
    }

    #[test]
    fn decode_invalid_protocol_id() {
        let mut raw = build_frame(1, 1, &[0x03]);
        // Protocol identifier occupies bytes 2..4, big-endian.
        raw[2..4].copy_from_slice(&1_u16.to_be_bytes());
        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidProtocolId(1)));
    }

    #[test]
    fn decode_length_overflow() {
        let mut raw = build_frame(1, 1, &[0x03]);
        let overflow_len = u16::try_from(MAX_PDU_SIZE).unwrap() + 2;
        // Length field occupies bytes 4..6, big-endian.
        raw[4..6].copy_from_slice(&overflow_len.to_be_bytes());
        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::LengthOverflow(len) if len == overflow_len));
    }

    #[test]
    fn decode_multiple_frames() {
        let pdu1 = [0x03, 0x01];
        let pdu2 = [0x06, 0x02, 0x03];
        let mut raw = build_frame(1, 1, &pdu1);
        raw.extend_from_slice(&build_frame(2, 2, &pdu2));
        let mut buf = BytesMut::from(&raw[..]);
        let mut codec = MbapCodec;
        let f1 = codec.decode(&mut buf).unwrap().expect("frame 1");
        assert_eq!(f1.unit_id(), 1);
        assert_eq!(f1.pdu.as_ref(), &pdu1);
        let f2 = codec.decode(&mut buf).unwrap().expect("frame 2");
        assert_eq!(f2.unit_id(), 2);
        assert_eq!(f2.pdu.as_ref(), &pdu2);
        assert!(buf.is_empty());
    }

    #[test]
    fn encode_roundtrip() {
        let pdu = Bytes::from_static(&[0x03, 0x00, 0x00, 0x00, 0x0A]);
        let header = MbapHeader::new(42, 0xFF, u16::try_from(pdu.len()).unwrap());
        let frame = Frame {
            header: FrameHeader::Mbap(header),
            pdu: pdu.clone(),
        };
        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        codec.encode(frame, &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().expect("should decode");
        assert_eq!(decoded.unit_id(), 0xFF);
        assert_eq!(decoded.pdu.as_ref(), &pdu[..]);
        match decoded.header {
            FrameHeader::Mbap(h) => {
                assert_eq!(h.transaction_id.get(), 42);
            }
            FrameHeader::Rtu { .. } => panic!("expected MBAP header"),
        }
    }

    #[test]
    fn encode_tuple_form() {
        let pdu = Bytes::from_static(&[0x01, 0x00, 0x0A, 0x00, 0x0D]);
        let header = MbapHeader::new(7, 1, u16::try_from(pdu.len()).unwrap());
        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        codec.encode((header, pdu.clone()), &mut buf).unwrap();
        let decoded = codec.decode(&mut buf).unwrap().expect("should decode");
        assert_eq!(decoded.unit_id(), 1);
        assert_eq!(decoded.pdu.as_ref(), &pdu[..]);
    }

    #[test]
    fn encode_rtu_frame_errors() {
        let frame = Frame {
            header: FrameHeader::Rtu { unit_id: 1 },
            pdu: Bytes::from_static(&[0x03]),
        };
        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        assert!(codec.encode(frame, &mut buf).is_err());
    }

    #[test]
    fn encode_rejects_oversized_pdu() {
        let pdu = Bytes::from(vec![0x03_u8; MAX_PDU_SIZE + 1]);
        let header = MbapHeader {
            transaction_id: U16::new(1),
            protocol_id: U16::new(MODBUS_PROTOCOL_ID),
            length: U16::new(u16::try_from(MAX_PDU_SIZE + 2).unwrap()),
            unit_id: 1,
        };
        let frame = Frame {
            header: FrameHeader::Mbap(header),
            pdu,
        };
        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert!(matches!(err, FrameError::LengthOverflow(_)));
        assert!(buf.is_empty(), "nothing should be written on error");
    }

    #[test]
    fn decode_rejects_zero_length_pdu() {
        let header = MbapHeader {
            transaction_id: U16::new(0),
            protocol_id: U16::new(0),
            length: U16::new(1),
            unit_id: 1,
        };
        let mut buf = BytesMut::from(header.as_bytes());
        let mut codec = MbapCodec;
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn decode_length_zero_returns_error() {
        let header = MbapHeader {
            transaction_id: U16::new(0),
            protocol_id: U16::new(0),
            length: U16::new(0),
            unit_id: 1,
        };
        let mut buf = BytesMut::from(header.as_bytes());
        let mut codec = MbapCodec;
        let err = codec.decode(&mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn encode_rejects_zero_length_pdu() {
        let frame = Frame {
            header: FrameHeader::Mbap(MbapHeader::new(1, 1, 0)),
            pdu: Bytes::new(),
        };
        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert!(matches!(err, FrameError::InvalidLength { .. }));
    }

    #[test]
    fn encode_rejects_length_mismatch() {
        let pdu = Bytes::from_static(&[0x03, 0x00, 0x00, 0x00, 0x01]);
        let header = MbapHeader::new(1, 1, 3);
        let frame = Frame {
            header: FrameHeader::Mbap(header),
            pdu,
        };
        let mut buf = BytesMut::new();
        let mut codec = MbapCodec;
        let err = codec.encode(frame, &mut buf).unwrap_err();
        assert!(matches!(
            err,
            FrameError::LengthMismatch {
                declared: 4,
                actual: 6
            }
        ));
    }
}