use std::time::{Duration, Instant};
use bytes::BytesMut;
use tokio_util::codec::{Decoder, Encoder};
use crate::error::ModbusError;
use super::frame::{verify_crc, RtuFrame, RtuFrameError, RTU_MAX_FRAME_SIZE, RTU_MIN_FRAME_SIZE};
/// Character and frame-gap timing for a Modbus RTU serial line.
///
/// Usually built with [`RtuTiming::from_baud_rate`] or
/// [`RtuTiming::from_baud_rate_with_bits`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RtuTiming {
    /// Time needed to transmit one character on the wire.
    pub char_time: Duration,
    /// Inter-character timeout (t1.5).
    pub inter_char_timeout: Duration,
    /// Silence that delimits frames (t3.5); used by the decoders' timeout checks.
    pub inter_frame_timeout: Duration,
}
impl RtuTiming {
    /// Derives RTU timing for `baud_rate`, assuming 11 bits per character.
    ///
    /// # Panics
    /// Panics if `baud_rate` is zero.
    pub fn from_baud_rate(baud_rate: u32) -> Self {
        Self::from_baud_rate_with_bits(baud_rate, 11)
    }

    /// Derives RTU timing for `baud_rate` with `bits_per_char` bits per character.
    ///
    /// Above 19200 baud the timeouts are fixed at 750 µs (t1.5) and 1750 µs (t3.5);
    /// otherwise they are 1.5 and 3.5 character times.
    ///
    /// # Panics
    /// Panics if `baud_rate` is zero.
    pub fn from_baud_rate_with_bits(baud_rate: u32, bits_per_char: u32) -> Self {
        assert!(baud_rate > 0, "baud rate must be non-zero");
        // Nanosecond resolution: truncating to whole microseconds loses up to ~0.5% per
        // character at high baud rates, which accumulates in `transmission_time`.
        // u32::MAX * 1e9 still fits in a u64, so this cannot overflow.
        let char_time_ns = u64::from(bits_per_char) * 1_000_000_000 / u64::from(baud_rate);
        let char_time = Duration::from_nanos(char_time_ns);
        let (inter_char_timeout, inter_frame_timeout) = if baud_rate > 19200 {
            (Duration::from_micros(750), Duration::from_micros(1750))
        } else {
            // 1.5 and 3.5 character times, computed exactly in integer arithmetic.
            (char_time * 3 / 2, char_time * 7 / 2)
        };
        Self {
            char_time,
            inter_char_timeout,
            inter_frame_timeout,
        }
    }

    /// Time needed to transmit `bytes` characters.
    ///
    /// Character counts beyond `u32::MAX` saturate instead of silently wrapping.
    pub fn transmission_time(&self, bytes: usize) -> Duration {
        let chars = bytes.min(u32::MAX as usize) as u32;
        self.char_time.saturating_mul(chars)
    }
}
impl Default for RtuTiming {
fn default() -> Self {
Self::from_baud_rate(9600)
}
}
/// Receive-side state of [`RtuCodec`].
#[derive(Debug, Clone)]
enum DecodeState {
    /// Not tracking an in-progress frame; the strict-timing timeout check does not apply.
    Idle,
    /// Bytes have been appended; `last_byte_time` is when the latest batch arrived.
    Receiving {
        last_byte_time: Instant,
        // Written by the decoder but never read; it is only visible through `Debug`,
        // hence the lint suppression.
        #[allow(dead_code)]
        expected_length: Option<usize>,
    },
    // Never constructed at present; the allow silences the unused-variant lint.
    #[allow(dead_code)]
    Complete,
}
impl Default for DecodeState {
fn default() -> Self {
Self::Idle
}
}
/// Modbus RTU frame codec for `tokio_util::codec` framing.
///
/// Bytes passed to [`Decoder::decode`] are accumulated in an internal buffer and split
/// into frames using a length inferred from each frame's function code.
#[derive(Debug)]
pub struct RtuCodec {
    /// Character/frame timing used by the strict-timing timeout check.
    timing: RtuTiming,
    /// Progress of the frame currently being received.
    state: DecodeState,
    /// Bytes received but not yet emitted as a frame.
    buffer: BytesMut,
    /// When set, the decoder checks for an inter-frame silence on each call.
    strict_timing: bool,
    /// If set, only frames for these unit ids (and broadcast id 0) are emitted.
    unit_id_filter: Option<Vec<u8>>,
}
impl RtuCodec {
pub fn new() -> Self {
Self::with_timing(RtuTiming::default())
}
pub fn with_timing(timing: RtuTiming) -> Self {
Self {
timing,
state: DecodeState::Idle,
buffer: BytesMut::with_capacity(RTU_MAX_FRAME_SIZE),
strict_timing: false,
unit_id_filter: None,
}
}
pub fn with_baud_rate(baud_rate: u32) -> Self {
Self::with_timing(RtuTiming::from_baud_rate(baud_rate))
}
pub fn strict_timing(mut self, enabled: bool) -> Self {
self.strict_timing = enabled;
self
}
pub fn unit_id_filter(mut self, unit_ids: Vec<u8>) -> Self {
self.unit_id_filter = Some(unit_ids);
self
}
pub fn timing(&self) -> &RtuTiming {
&self.timing
}
pub fn reset(&mut self) {
self.state = DecodeState::Idle;
self.buffer.clear();
}
fn try_parse_frame(&mut self) -> Result<Option<RtuFrame>, ModbusError> {
if self.buffer.len() < RTU_MIN_FRAME_SIZE {
return Ok(None);
}
let expected_len = self.estimate_frame_length();
match expected_len {
Some(len) if self.buffer.len() >= len => {
let frame_data = self.buffer.split_to(len);
match RtuFrame::decode(&frame_data) {
Ok(frame) => {
if let Some(ref filter) = self.unit_id_filter {
if !filter.contains(&frame.unit_id) && frame.unit_id != 0 {
self.state = DecodeState::Idle;
return Ok(None);
}
}
self.state = DecodeState::Idle;
Ok(Some(frame))
}
Err(RtuFrameError::CrcMismatch { expected, received }) => {
self.state = DecodeState::Idle;
Err(ModbusError::InvalidData(format!(
"CRC mismatch: expected 0x{:04X}, got 0x{:04X}",
expected, received
)))
}
Err(e) => {
self.state = DecodeState::Idle;
Err(ModbusError::InvalidData(e.to_string()))
}
}
}
Some(_) => {
Ok(None)
}
None if self.buffer.len() >= RTU_MAX_FRAME_SIZE => {
self.buffer.clear();
self.state = DecodeState::Idle;
Err(ModbusError::InvalidData(
"Unable to determine frame length, buffer overflow".into(),
))
}
None => {
Ok(None)
}
}
}
fn estimate_frame_length(&self) -> Option<usize> {
if self.buffer.len() < 2 {
return None;
}
let function_code = self.buffer[1];
if function_code & 0x80 != 0 {
return Some(5);
}
match function_code {
0x01 | 0x02 | 0x03 | 0x04 | 0x05 | 0x06 => Some(8),
0x16 => Some(10),
0x0F | 0x10 => {
if self.buffer.len() >= 7 {
let byte_count = self.buffer[6] as usize;
Some(7 + byte_count + 2)
} else {
None
}
}
0x07 => Some(4),
0x08 => Some(8),
0x0B | 0x0C => Some(4),
0x11 => {
if self.buffer.len() >= 3 {
let byte_count = self.buffer[2] as usize;
Some(3 + byte_count + 2)
} else {
None
}
}
0x17 => {
if self.buffer.len() >= 11 {
let write_byte_count = self.buffer[10] as usize;
Some(11 + write_byte_count + 2)
} else {
None
}
}
_ => None,
}
}
fn check_frame_timeout(&mut self) -> bool {
if !self.strict_timing {
return false;
}
if let DecodeState::Receiving { last_byte_time, .. } = &self.state {
last_byte_time.elapsed() >= self.timing.inter_frame_timeout
} else {
false
}
}
}
impl Default for RtuCodec {
fn default() -> Self {
Self::new()
}
}
impl Decoder for RtuCodec {
    type Item = RtuFrame;
    type Error = ModbusError;

    /// Moves all of `src` into the internal buffer and returns the next complete frame.
    ///
    /// Because `src` is always fully consumed, frames that arrive together are held
    /// internally. The internal buffer is therefore parsed even when `src` is empty, which
    /// is how the framing layer drains a second frame delivered in the same read.
    ///
    /// In strict-timing mode, an inter-frame silence since the last append closes out the
    /// buffered bytes. A CRC-valid frame of known length is returned; anything else left
    /// over is discarded before new input is processed.
    ///
    /// # Errors
    /// Returns [`ModbusError::InvalidData`] for CRC mismatches, undecodable frames, or when
    /// the frame length cannot be determined before the buffer fills.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        if self.check_frame_timeout() && !self.buffer.is_empty() {
            if self.buffer.len() >= RTU_MIN_FRAME_SIZE && verify_crc(&self.buffer) {
                if let Some(frame) = self.try_parse_frame()? {
                    return Ok(Some(frame));
                }
            }
            // Bytes left after a frame gap can never be completed by later input. Returning
            // here without consuming `src` would make every later call take this same path,
            // so the stale bytes are dropped and new input is processed normally.
            self.buffer.clear();
            self.state = DecodeState::Idle;
        }
        if !src.is_empty() {
            self.buffer.extend_from_slice(src);
            src.clear();
            self.state = DecodeState::Receiving {
                last_byte_time: Instant::now(),
                expected_length: self.estimate_frame_length(),
            };
        }
        // Parse even when `src` was empty: an earlier read may have buffered a whole frame.
        self.try_parse_frame()
    }
}
impl Encoder<RtuFrame> for RtuCodec {
    type Error = ModbusError;

    /// Appends `item` to `dst` as a complete RTU frame.
    ///
    /// # Errors
    /// Returns [`ModbusError::InvalidData`] if the PDU is empty or longer than
    /// `RTU_MAX_PDU_SIZE`; `dst` is not touched in either case.
    fn encode(&mut self, item: RtuFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let max_pdu = super::frame::RTU_MAX_PDU_SIZE;
        match item.pdu.len() {
            0 => return Err(ModbusError::InvalidData("PDU cannot be empty".into())),
            pdu_len if pdu_len > max_pdu => {
                return Err(ModbusError::InvalidData(format!(
                    "PDU too large: {} bytes (max {})",
                    pdu_len, max_pdu
                )));
            }
            _ => {}
        }
        dst.reserve(item.frame_size());
        item.encode_to(dst);
        Ok(())
    }
}
/// Byte-at-a-time Modbus RTU decoder, driven by [`StreamingRtuCodec::process_byte`] and
/// [`StreamingRtuCodec::check_timeout`].
#[derive(Debug)]
pub struct StreamingRtuCodec {
    /// Wrapped codec; its timing supplies the inter-frame timeout.
    inner: RtuCodec,
    /// Bytes of the frame currently being assembled.
    partial_frame: BytesMut,
    /// Arrival time of the most recent byte, or `None` after a reset or timeout flush.
    last_byte_time: Option<Instant>,
}
impl StreamingRtuCodec {
pub fn new(timing: RtuTiming) -> Self {
Self {
inner: RtuCodec::with_timing(timing).strict_timing(true),
partial_frame: BytesMut::with_capacity(RTU_MAX_FRAME_SIZE),
last_byte_time: None,
}
}
pub fn process_byte(&mut self, byte: u8) -> Result<Option<RtuFrame>, ModbusError> {
let now = Instant::now();
if let Some(last_time) = self.last_byte_time {
if now.duration_since(last_time) >= self.inner.timing.inter_frame_timeout {
if !self.partial_frame.is_empty() {
if self.partial_frame.len() >= RTU_MIN_FRAME_SIZE
&& verify_crc(&self.partial_frame)
{
let frame_data = std::mem::replace(
&mut self.partial_frame,
BytesMut::with_capacity(RTU_MAX_FRAME_SIZE),
);
self.last_byte_time = Some(now);
self.partial_frame.extend_from_slice(&[byte]);
return RtuFrame::decode(&frame_data)
.map(Some)
.map_err(|e| ModbusError::InvalidData(e.to_string()));
} else {
self.partial_frame.clear();
}
}
}
}
self.last_byte_time = Some(now);
self.partial_frame.extend_from_slice(&[byte]);
if self.partial_frame.len() >= RTU_MIN_FRAME_SIZE {
if let Some(expected_len) = self.inner.estimate_frame_length() {
if self.partial_frame.len() >= expected_len
&& verify_crc(&self.partial_frame[..expected_len])
{
let frame_data = self.partial_frame.split_to(expected_len);
return RtuFrame::decode(&frame_data)
.map(Some)
.map_err(|e| ModbusError::InvalidData(e.to_string()));
}
}
}
if self.partial_frame.len() >= RTU_MAX_FRAME_SIZE {
self.partial_frame.clear();
return Err(ModbusError::InvalidData("Frame buffer overflow".into()));
}
Ok(None)
}
pub fn check_timeout(&mut self) -> Result<Option<RtuFrame>, ModbusError> {
if let Some(last_time) = self.last_byte_time {
if Instant::now().duration_since(last_time) >= self.inner.timing.inter_frame_timeout {
if self.partial_frame.len() >= RTU_MIN_FRAME_SIZE && verify_crc(&self.partial_frame)
{
let frame_data = std::mem::replace(
&mut self.partial_frame,
BytesMut::with_capacity(RTU_MAX_FRAME_SIZE),
);
self.last_byte_time = None;
return RtuFrame::decode(&frame_data)
.map(Some)
.map_err(|e| ModbusError::InvalidData(e.to_string()));
} else if !self.partial_frame.is_empty() {
self.partial_frame.clear();
self.last_byte_time = None;
}
}
}
Ok(None)
}
pub fn reset(&mut self) {
self.inner.reset();
self.partial_frame.clear();
self.last_byte_time = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_calculation() {
        let timing = RtuTiming::from_baud_rate(9600);
        // One 11-bit character at 9600 baud takes a little over a millisecond.
        let char_us = timing.char_time.as_micros();
        assert!((1101..1200).contains(&char_us));
        // t3.5 lands at about 4 ms.
        let frame_gap_us = timing.inter_frame_timeout.as_micros();
        assert!((3501..4500).contains(&frame_gap_us));
    }

    #[test]
    fn test_high_baud_rate_minimums() {
        let timing = RtuTiming::from_baud_rate(115200);
        assert_eq!(
            (timing.inter_char_timeout, timing.inter_frame_timeout),
            (Duration::from_micros(750), Duration::from_micros(1750))
        );
    }

    #[test]
    fn test_codec_encode_decode() {
        let original = RtuFrame::new(1, vec![0x03, 0x00, 0x00, 0x00, 0x0A]);
        let mut wire = BytesMut::new();
        RtuCodec::new().encode(original.clone(), &mut wire).unwrap();
        assert_eq!(wire.len(), 8);
        // A fresh codec decodes what the encoder produced.
        let decoded = RtuCodec::new().decode(&mut wire).unwrap().unwrap();
        assert_eq!(decoded.unit_id, original.unit_id);
        assert_eq!(decoded.pdu, original.pdu);
    }

    #[test]
    fn test_codec_partial_frame() {
        let mut codec = RtuCodec::new();
        let full = RtuFrame::new(1, vec![0x03, 0x00, 0x00, 0x00, 0x0A]).encode();
        let first = codec.decode(&mut BytesMut::from(&full[..3])).unwrap();
        assert!(first.is_none());
        let second = codec.decode(&mut BytesMut::from(&full[3..])).unwrap();
        assert!(second.is_some());
    }

    #[test]
    fn test_codec_exception_frame() {
        let mut codec = RtuCodec::new();
        let mut wire = BytesMut::new();
        codec
            .encode(RtuFrame::exception(1, 0x03, 0x02), &mut wire)
            .unwrap();
        assert_eq!(wire.len(), 5);
        let decoded = codec.decode(&mut wire).unwrap().unwrap();
        assert!(decoded.is_exception());
    }

    #[test]
    fn test_codec_unit_id_filter() {
        let mut codec = RtuCodec::new().unit_id_filter(vec![1, 2]);

        let mut accepted: BytesMut = RtuFrame::new(1, vec![0x03, 0x00, 0x00, 0x00, 0x0A])
            .encode()
            .into();
        assert!(codec.decode(&mut accepted).unwrap().is_some());

        codec.reset();

        let mut rejected: BytesMut = RtuFrame::new(5, vec![0x03, 0x00, 0x00, 0x00, 0x0A])
            .encode()
            .into();
        assert!(codec.decode(&mut rejected).unwrap().is_none());
    }

    #[test]
    fn test_transmission_time() {
        let millis = RtuTiming::from_baud_rate(9600)
            .transmission_time(8)
            .as_millis();
        assert!((8..=10).contains(&millis));
    }
}