use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
use crate::error::{ModbusError, ModbusResult};
/// Size in bytes of the MBAP header: transaction id (2), protocol id (2),
/// length (2) and unit id (1).
pub const MBAP_HEADER_SIZE: usize = 7;
/// Largest value accepted in the MBAP `length` field.
///
/// The length field counts the unit identifier byte *plus* the PDU, so its
/// maximum is `MAX_PDU_SIZE + 1` (254). Deriving it from [`MAX_PDU_SIZE`] keeps
/// the decoder's limit in step with what the encoder is allowed to emit.
pub const MAX_MBAP_LENGTH: u16 = MAX_PDU_SIZE as u16 + 1;
/// Smallest PDU: a single byte holding only the function code.
// Not referenced elsewhere in this module; the allow keeps it warning-free if the
// module is not publicly re-exported.
#[allow(dead_code)]
pub const MIN_PDU_SIZE: usize = 1;
/// Largest PDU (function code + data) carried in one frame, in bytes.
pub const MAX_PDU_SIZE: usize = 253;
/// MBAP header that prefixes every Modbus TCP frame.
///
/// `MbapHeader::encode` writes the fields in declaration order, with the `u16`
/// fields big-endian, for a total of `MBAP_HEADER_SIZE` bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MbapHeader {
/// Transaction identifier; `MbapFrame::response` copies it from the request.
pub transaction_id: u16,
/// Protocol identifier; `MbapHeader::parse` rejects any value other than 0.
pub protocol_id: u16,
/// Byte count of the unit id plus the PDU (i.e. PDU length + 1).
pub length: u16,
/// Unit identifier; `MbapFrame::response` copies it from the request.
pub unit_id: u8,
}
impl MbapHeader {
    /// Builds the header for an outgoing frame carrying `pdu_length` PDU bytes.
    ///
    /// `protocol_id` is always 0, and `length` is `pdu_length + 1` because the
    /// MBAP length field also counts the unit identifier byte. A `pdu_length`
    /// too large for the 16-bit field saturates to `u16::MAX`. Previously it
    /// wrapped around to a small, plausible-looking value (e.g. 65535 became 0).
    pub fn new(transaction_id: u16, unit_id: u8, pdu_length: usize) -> Self {
        // Saturate rather than truncate with `as`; `u16::MAX` exceeds
        // `MAX_MBAP_LENGTH`, so such a header is rejected by `parse`.
        let length = pdu_length.saturating_add(1).min(usize::from(u16::MAX)) as u16;
        Self {
            transaction_id,
            protocol_id: 0,
            length,
            unit_id,
        }
    }

    /// Parses and validates the first `MBAP_HEADER_SIZE` bytes of `bytes`.
    ///
    /// Any bytes beyond the header are ignored.
    ///
    /// # Errors
    ///
    /// Returns `ModbusError::InvalidData` if `bytes` is shorter than the header,
    /// the protocol identifier is not 0, or the length field is 0 or exceeds
    /// `MAX_MBAP_LENGTH`.
    pub fn parse(bytes: &[u8]) -> ModbusResult<Self> {
        if bytes.len() < MBAP_HEADER_SIZE {
            return Err(ModbusError::InvalidData(format!(
                "MBAP header too short: {} bytes (expected {})",
                bytes.len(),
                MBAP_HEADER_SIZE
            )));
        }
        let transaction_id = u16::from_be_bytes([bytes[0], bytes[1]]);
        let protocol_id = u16::from_be_bytes([bytes[2], bytes[3]]);
        let length = u16::from_be_bytes([bytes[4], bytes[5]]);
        let unit_id = bytes[6];
        if protocol_id != 0 {
            return Err(ModbusError::InvalidData(format!(
                "Invalid protocol ID: {} (expected 0)",
                protocol_id
            )));
        }
        if length == 0 || length > MAX_MBAP_LENGTH {
            return Err(ModbusError::InvalidData(format!(
                "Invalid MBAP length: {} (expected 1-{})",
                length, MAX_MBAP_LENGTH
            )));
        }
        Ok(Self {
            transaction_id,
            protocol_id,
            length,
            unit_id,
        })
    }

    /// Appends the 7-byte header to `buf` in wire order (`u16` fields big-endian).
    pub fn encode(&self, buf: &mut BytesMut) {
        buf.put_u16(self.transaction_id);
        buf.put_u16(self.protocol_id);
        buf.put_u16(self.length);
        buf.put_u8(self.unit_id);
    }

    /// Number of PDU bytes described by this header: `length` minus the unit-id
    /// byte, or 0 if `length` is 0.
    pub fn pdu_length(&self) -> usize {
        (self.length as usize).saturating_sub(1)
    }

    /// Size of the whole frame on the wire: the header plus the PDU.
    pub fn total_frame_length(&self) -> usize {
        MBAP_HEADER_SIZE + self.pdu_length()
    }
}
/// A complete Modbus TCP frame: an MBAP header followed by the PDU bytes.
///
/// Both fields are public, so nothing here enforces that `header.length`
/// equals `pdu.len() + 1`. `MbapFrame::new` builds a consistent pair.
#[derive(Debug, Clone)]
pub struct MbapFrame {
/// Header that precedes the PDU on the wire.
pub header: MbapHeader,
/// Protocol data unit; `function_code` reads its first byte.
pub pdu: Vec<u8>,
}
impl MbapFrame {
pub fn new(transaction_id: u16, unit_id: u8, pdu: Vec<u8>) -> Self {
let header = MbapHeader::new(transaction_id, unit_id, pdu.len());
Self { header, pdu }
}
pub fn response(request: &MbapFrame, response_pdu: Vec<u8>) -> Self {
Self::new(
request.header.transaction_id,
request.header.unit_id,
response_pdu,
)
}
pub fn function_code(&self) -> Option<u8> {
self.pdu.first().copied()
}
pub fn is_exception(&self) -> bool {
self.pdu.first().map(|fc| fc & 0x80 != 0).unwrap_or(false)
}
pub fn frame_size(&self) -> usize {
MBAP_HEADER_SIZE + self.pdu.len()
}
pub fn encode(&self, buf: &mut BytesMut) {
self.header.encode(buf);
buf.put_slice(&self.pdu);
}
pub fn to_bytes(&self) -> BytesMut {
let mut buf = BytesMut::with_capacity(self.frame_size());
self.encode(&mut buf);
buf
}
}
/// `tokio_util` codec that splits a byte stream into `MbapFrame`s and
/// serialises `MbapFrame`s back into bytes.
#[derive(Debug, Clone, Default)]
pub struct MbapCodec {
/// Progress of the frame currently being decoded.
state: DecodeState,
}
/// Decoder state machine for `MbapCodec`.
#[derive(Debug, Clone, Default)]
enum DecodeState {
/// No header parsed yet; the decoder waits for `MBAP_HEADER_SIZE` bytes.
#[default]
WaitingForHeader,
/// The header has been parsed and validated. Its bytes are still in the buffer,
/// and the decoder waits until the whole frame has been buffered.
WaitingForPdu(MbapHeader),
}
impl MbapCodec {
pub fn new() -> Self {
Self {
state: DecodeState::WaitingForHeader,
}
}
}
impl Decoder for MbapCodec {
    type Item = MbapFrame;
    type Error = ModbusError;

    /// Extracts the next complete frame from `src` once it is fully buffered.
    ///
    /// Returns `Ok(None)` when more bytes are needed. The header is validated as
    /// soon as its `MBAP_HEADER_SIZE` bytes are available, so a corrupt header is
    /// reported without waiting for the PDU. Header and PDU bytes are consumed
    /// from `src` together, only once the whole frame is present.
    ///
    /// # Errors
    ///
    /// Propagates `MbapHeader::parse` errors. Returns `ModbusError::InvalidData`
    /// for a zero-length or oversized PDU.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        loop {
            match self.state {
                DecodeState::WaitingForHeader => {
                    if src.len() < MBAP_HEADER_SIZE {
                        return Ok(None);
                    }
                    let header = MbapHeader::parse(&src[..MBAP_HEADER_SIZE])?;
                    let pdu_len = header.pdu_length();
                    if pdu_len == 0 {
                        return Err(ModbusError::InvalidData("PDU length cannot be zero".into()));
                    }
                    // Defensive: `parse` already bounds `length` by `MAX_MBAP_LENGTH`.
                    if pdu_len > MAX_PDU_SIZE {
                        return Err(ModbusError::InvalidData(format!(
                            "PDU length too large: {} (max {})",
                            pdu_len, MAX_PDU_SIZE
                        )));
                    }
                    self.state = DecodeState::WaitingForPdu(header);
                }
                DecodeState::WaitingForPdu(header) => {
                    let total_len = header.total_frame_length();
                    if src.len() < total_len {
                        // Reserve room for the rest of the frame up front, the
                        // pattern recommended by the tokio-util `Decoder` docs, so
                        // subsequent reads land in already-allocated space.
                        src.reserve(total_len - src.len());
                        return Ok(None);
                    }
                    src.advance(MBAP_HEADER_SIZE);
                    let pdu = src.split_to(header.pdu_length()).to_vec();
                    self.state = DecodeState::WaitingForHeader;
                    return Ok(Some(MbapFrame { header, pdu }));
                }
            }
        }
    }
}
impl Encoder<MbapFrame> for MbapCodec {
type Error = ModbusError;
fn encode(&mut self, item: MbapFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
if item.pdu.is_empty() {
return Err(ModbusError::InvalidData("PDU cannot be empty".into()));
}
if item.pdu.len() > MAX_PDU_SIZE {
return Err(ModbusError::InvalidData(format!(
"PDU too large: {} (max {})",
item.pdu.len(),
MAX_PDU_SIZE
)));
}
dst.reserve(item.frame_size());
item.encode(dst);
Ok(())
}
}
// Unit tests for header parsing/encoding, frame construction and the codec.
// Byte vectors are laid out as: transaction id (2), protocol id (2),
// length (2), unit id (1), then the PDU.
#[cfg(test)]
mod tests {
use super::*;
// A well-formed header: tid=1, proto=0, length=6 (unit id + 5 PDU bytes), unit=1.
#[test]
fn test_mbap_header_parse() {
let bytes = [0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x01];
let header = MbapHeader::parse(&bytes).unwrap();
assert_eq!(header.transaction_id, 1);
assert_eq!(header.protocol_id, 0);
assert_eq!(header.length, 6);
assert_eq!(header.unit_id, 1);
assert_eq!(header.pdu_length(), 5);
}
// Protocol id 1 must be rejected; only 0 is accepted.
#[test]
fn test_mbap_header_parse_invalid_protocol() {
let bytes = [0x00, 0x01, 0x00, 0x01, 0x00, 0x06, 0x01];
let result = MbapHeader::parse(&bytes);
assert!(result.is_err());
}
// `new` adds one to the PDU length for the unit-id byte; tid 42 encodes as 0x002A.
#[test]
fn test_mbap_header_encode() {
let header = MbapHeader::new(42, 1, 5);
let mut buf = BytesMut::new();
header.encode(&mut buf);
assert_eq!(buf.len(), MBAP_HEADER_SIZE);
assert_eq!(&buf[..], &[0x00, 0x2A, 0x00, 0x00, 0x00, 0x06, 0x01]);
}
// A frame built from a 5-byte PDU with function code 0x03.
#[test]
fn test_mbap_frame_creation() {
let pdu = vec![0x03, 0x00, 0x00, 0x00, 0x0A]; let frame = MbapFrame::new(1, 1, pdu.clone());
assert_eq!(frame.header.transaction_id, 1);
assert_eq!(frame.header.unit_id, 1);
assert_eq!(frame.header.length, 6); assert_eq!(frame.pdu, pdu);
assert_eq!(frame.function_code(), Some(0x03));
assert!(!frame.is_exception());
}
// Function code 0x83 has the high bit set, so it is reported as an exception.
#[test]
fn test_mbap_frame_exception() {
let pdu = vec![0x83, 0x02]; let frame = MbapFrame::new(1, 1, pdu);
assert!(frame.is_exception());
assert_eq!(frame.function_code(), Some(0x83));
}
// One complete 12-byte frame decodes in a single call and drains the buffer.
#[test]
fn test_codec_decode() {
let mut codec = MbapCodec::new();
let mut buf = BytesMut::from(
&[
0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x01, 0x03, 0x00, 0x00, 0x00, 0x0A, ][..],
);
let frame = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(frame.header.transaction_id, 1);
assert_eq!(frame.header.unit_id, 1);
assert_eq!(frame.pdu, vec![0x03, 0x00, 0x00, 0x00, 0x0A]);
assert!(buf.is_empty());
}
// Header only: the decoder yields None, then completes once the PDU arrives.
#[test]
fn test_codec_decode_partial() {
let mut codec = MbapCodec::new();
let mut buf = BytesMut::from(&[0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x01][..]);
let result = codec.decode(&mut buf).unwrap();
assert!(result.is_none());
buf.extend_from_slice(&[0x03, 0x00, 0x00, 0x00, 0x0A]);
let frame = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(frame.pdu, vec![0x03, 0x00, 0x00, 0x00, 0x0A]);
}
// Two back-to-back frames (each with length=3, i.e. a 2-byte PDU) decode in order.
#[test]
fn test_codec_decode_multiple() {
let mut codec = MbapCodec::new();
let mut buf = BytesMut::from(
&[
0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x01, 0x03, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x02, 0x04, 0x00,
][..],
);
let frame1 = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(frame1.header.transaction_id, 1);
assert_eq!(frame1.header.unit_id, 1);
let frame2 = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(frame2.header.transaction_id, 2);
assert_eq!(frame2.header.unit_id, 2);
assert!(buf.is_empty());
}
// Encoding produces the same bytes that test_codec_decode consumes.
#[test]
fn test_codec_encode() {
let mut codec = MbapCodec::new();
let frame = MbapFrame::new(1, 1, vec![0x03, 0x00, 0x00, 0x00, 0x0A]);
let mut buf = BytesMut::new();
codec.encode(frame, &mut buf).unwrap();
assert_eq!(
buf,
&[
0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x01, 0x03, 0x00, 0x00, 0x00, 0x0A, ][..]
);
}
// Encode then decode with the same codec preserves ids and PDU bytes.
#[test]
fn test_codec_round_trip() {
let mut codec = MbapCodec::new();
let original = MbapFrame::new(
42,
5,
vec![0x10, 0x00, 0x10, 0x00, 0x02, 0x04, 0x01, 0x02, 0x03, 0x04],
);
let mut buf = BytesMut::new();
codec.encode(original.clone(), &mut buf).unwrap();
let decoded = codec.decode(&mut buf).unwrap().unwrap();
assert_eq!(
decoded.header.transaction_id,
original.header.transaction_id
);
assert_eq!(decoded.header.unit_id, original.header.unit_id);
assert_eq!(decoded.pdu, original.pdu);
}
}