#[cfg(feature = "rtu")]
use crc::{Crc, CRC_16_MODBUS};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, info};
#[cfg(feature = "rtu")]
use tokio_serial;
use crate::error::{ModbusError, ModbusResult};
use crate::protocol::{ModbusFunction, ModbusRequest, ModbusResponse};
/// Direction of a raw frame reported to a [`PacketCallback`] or packet log.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketDirection {
    /// A frame this side is writing to the peer.
    Send,
    /// A frame this side has read from the peer.
    Receive,
}

impl PacketDirection {
    /// Short, fixed-width upper-case label for this direction (`"SEND"` / `"RECV"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Send => "SEND",
            Self::Receive => "RECV",
        }
    }
}
/// Observer invoked with every raw frame a transport sends or receives.
///
/// Shared via `Arc` so the same callback can be installed on several transports; it must be
/// `Send + Sync` because transports themselves are required to be `Send + Sync`.
pub type PacketCallback = Arc<dyn Fn(PacketDirection, &[u8]) + Send + Sync>;
/// Largest TCP frame: the 6-byte MBAP prefix plus the maximum MBAP length value (254)
/// accepted by the TCP read path.
// Not referenced elsewhere in this file; kept as a documented protocol bound.
#[allow(dead_code)]
const MAX_TCP_FRAME_SIZE: usize = 260;
/// Bytes of the MBAP header before the unit id: transaction id, protocol id, length.
const MBAP_HEADER_SIZE: usize = 6;
/// Upper bound on a received RTU frame (slave id + PDU + CRC); the ASCII reader uses twice this.
#[cfg(feature = "rtu")]
const MAX_RTU_FRAME_SIZE: usize = 256;
/// CRC-16/MODBUS engine used to protect RTU frames.
#[cfg(feature = "rtu")]
const CRC_MODBUS: Crc<u16> = Crc::<u16>::new(&CRC_16_MODBUS);
/// Renders `data` as upper-case two-digit hex bytes separated by single spaces
/// (e.g. `01 AB 00`); an empty slice yields an empty string.
fn format_hex_packet(data: &[u8]) -> String {
    use std::fmt::Write;

    let Some((head, tail)) = data.split_first() else {
        return String::new();
    };
    // Two hex digits per byte plus one separator between neighbours.
    let mut out = String::with_capacity(data.len() * 3 - 1);
    let _ = write!(out, "{:02X}", head);
    for byte in tail {
        let _ = write!(out, " {:02X}", byte);
    }
    out
}
/// Emits one `info`-level trace line describing a raw frame.
///
/// `direction` and `protocol` are free-form labels (callers pass e.g. `"send"` and `"TCP"`);
/// the slave id is included only when one is supplied.
fn log_packet(direction: &str, data: &[u8], protocol: &str, slave_id: Option<u8>) {
    let hex = format_hex_packet(data);
    if let Some(id) = slave_id {
        info!("[MODBUS-{}] {} slave:{} {}", protocol, direction, id, hex);
    } else {
        info!("[MODBUS-{}] {} {}", protocol, direction, hex);
    }
}
/// Common interface implemented by the TCP, RTU and ASCII Modbus transports in this file.
pub trait ModbusTransport: Send + Sync {
    /// Sends `request` and waits for its response.
    ///
    /// The implementations in this file validate the request first and return Modbus
    /// exception responses as `Err`.
    fn request(
        &mut self,
        request: &ModbusRequest,
    ) -> impl std::future::Future<Output = ModbusResult<ModbusResponse>> + Send;
    /// Returns `true` while the transport holds an open connection / port handle.
    fn is_connected(&self) -> bool;
    /// Releases the underlying connection or port handle.
    fn close(&mut self) -> impl std::future::Future<Output = ModbusResult<()>> + Send;
    /// Returns a copy of the transport's running counters.
    fn get_stats(&self) -> TransportStats;
}
/// Running counters kept by each transport and returned by `ModbusTransport::get_stats`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TransportStats {
    /// Requests encoded and handed to the send path (counted before the write is attempted).
    pub requests_sent: u64,
    /// Response frames read to completion (counted before the frame is decoded).
    pub responses_received: u64,
    /// Failed requests, such as timeouts, send/read failures and Modbus exception responses.
    pub errors: u64,
    /// Send/read operations reported as timeouts.
    pub timeouts: u64,
    /// Total bytes of request frames handed to the send path.
    pub bytes_sent: u64,
    /// Total bytes of frames read, including TCP frames later discarded for an id mismatch.
    pub bytes_received: u64,
}
/// Modbus TCP transport: MBAP-framed requests over a single `TcpStream`.
pub struct TcpTransport {
    /// Open connection; set to `None` after a failed send/read or `close`, and reopened
    /// lazily by `request`.
    stream: Option<TcpStream>,
    /// Server address used for the initial connection and for reconnects.
    pub address: SocketAddr,
    /// Timeout applied separately to each send and read step.
    timeout: Duration,
    /// Last MBAP transaction id handed out; advanced before each request is encoded.
    transaction_id: u16,
    /// Running counters returned by `get_stats`.
    stats: TransportStats,
    /// When `true`, every sent/received frame is logged via `log_packet`.
    packet_logging: bool,
    /// Optional observer called with every sent/received frame.
    packet_callback: Option<PacketCallback>,
}
impl TcpTransport {
    /// Connects to `address` and returns a transport with packet logging disabled.
    ///
    /// # Errors
    /// Returns a connection error if the TCP connection cannot be established.
    pub async fn new(address: SocketAddr, timeout: Duration) -> ModbusResult<Self> {
        Self::with_packet_logging(address, timeout, false).await
    }

    /// Connects to `address` with packet logging initially set to `enable_logging`.
    ///
    /// # Errors
    /// Returns a connection error if the TCP connection cannot be established.
    pub async fn with_packet_logging(
        address: SocketAddr,
        timeout: Duration,
        enable_logging: bool,
    ) -> ModbusResult<Self> {
        let stream = TcpStream::connect(address).await.map_err(|e| {
            ModbusError::connection(format!("Failed to connect to {}: {}", address, e))
        })?;
        Ok(Self {
            stream: Some(stream),
            address,
            timeout,
            transaction_id: 1,
            stats: TransportStats::default(),
            packet_logging: enable_logging,
            packet_callback: None,
        })
    }

    /// Enables or disables logging of every sent/received frame.
    pub fn set_packet_logging(&mut self, enabled: bool) {
        self.packet_logging = enabled;
    }

    /// Installs an observer that receives every sent/received frame.
    pub fn set_packet_callback(&mut self, callback: PacketCallback) {
        self.packet_callback = Some(callback);
    }

    /// Removes any installed packet observer.
    pub fn clear_packet_callback(&mut self) {
        self.packet_callback = None;
    }

    /// Drops the current stream (if any) and opens a fresh connection to `self.address`.
    async fn reconnect(&mut self) -> ModbusResult<()> {
        self.stream = None;
        let stream = TcpStream::connect(self.address).await.map_err(|e| {
            ModbusError::connection(format!("Failed to reconnect to {}: {}", self.address, e))
        })?;
        self.stream = Some(stream);
        Ok(())
    }

    /// Advances and returns the MBAP transaction id, skipping 0 on wrap-around.
    fn next_transaction_id(&mut self) -> u16 {
        self.transaction_id = self.transaction_id.wrapping_add(1);
        if self.transaction_id == 0 {
            self.transaction_id = 1;
        }
        self.transaction_id
    }

    /// Encodes `request` as a complete MBAP frame, consuming a new transaction id.
    ///
    /// The MBAP length field is filled in after the PDU has been written, so it always
    /// equals the number of bytes that follow it (unit id + PDU) and cannot drift from the
    /// per-function encoding below.
    fn encode_request(&mut self, request: &ModbusRequest) -> Vec<u8> {
        let transaction_id = self.next_transaction_id();
        let protocol_id = 0u16;
        // Header + unit id + function code + address + quantity + byte count + payload.
        let mut frame = Vec::with_capacity(MBAP_HEADER_SIZE + 7 + request.data.len());
        frame.extend_from_slice(&transaction_id.to_be_bytes());
        frame.extend_from_slice(&protocol_id.to_be_bytes());
        // Length placeholder, patched once the PDU is complete.
        frame.extend_from_slice(&[0, 0]);
        frame.push(request.slave_id);
        frame.push(request.function.to_u8());
        frame.extend_from_slice(&request.address.to_be_bytes());
        match request.function {
            ModbusFunction::ReadCoils
            | ModbusFunction::ReadDiscreteInputs
            | ModbusFunction::ReadHoldingRegisters
            | ModbusFunction::ReadInputRegisters => {
                frame.extend_from_slice(&request.quantity.to_be_bytes());
            }
            ModbusFunction::WriteSingleCoil => {
                // Modbus encodes coil ON as 0xFF00 and OFF as 0x0000.
                let value: u16 = if !request.data.is_empty() && request.data[0] != 0 {
                    0xFF00
                } else {
                    0x0000
                };
                frame.extend_from_slice(&value.to_be_bytes());
            }
            ModbusFunction::WriteSingleRegister => {
                if request.data.len() >= 2 {
                    frame.extend_from_slice(&request.data[0..2]);
                } else {
                    frame.extend_from_slice(&[0, 0]);
                }
            }
            ModbusFunction::WriteMultipleCoils | ModbusFunction::WriteMultipleRegisters => {
                frame.extend_from_slice(&request.quantity.to_be_bytes());
                frame.push(request.data.len() as u8);
                frame.extend_from_slice(&request.data);
            }
        }
        let length = (frame.len() - MBAP_HEADER_SIZE) as u16;
        frame[4..MBAP_HEADER_SIZE].copy_from_slice(&length.to_be_bytes());
        frame
    }

    /// Parses a complete MBAP frame (header + unit id + PDU) into a `ModbusResponse`.
    ///
    /// Function codes with bit 7 set become exception responses; otherwise the response
    /// keeps the whole frame and records the data region that follows the function code.
    fn decode_response(&self, frame: Vec<u8>) -> ModbusResult<ModbusResponse> {
        if frame.len() < MBAP_HEADER_SIZE + 2 {
            return Err(ModbusError::frame("Frame too short"));
        }
        let _protocol_id = u16::from_be_bytes([frame[2], frame[3]]);
        let length = u16::from_be_bytes([frame[4], frame[5]]);
        let slave_id = frame[6];
        if frame.len() < MBAP_HEADER_SIZE + length as usize {
            return Err(ModbusError::frame("Incomplete frame"));
        }
        let function_code = frame[7];
        if function_code & 0x80 != 0 {
            if frame.len() < MBAP_HEADER_SIZE + 3 {
                return Err(ModbusError::frame("Invalid exception response"));
            }
            let original_function = function_code & 0x7F;
            let exception_code = frame[8];
            return Ok(ModbusResponse::new_exception(
                slave_id,
                ModbusFunction::from_u8(original_function)?,
                exception_code,
            ));
        }
        let function = ModbusFunction::from_u8(function_code)?;
        let data_start = MBAP_HEADER_SIZE + 2;
        // The MBAP length counts the unit id and function code, which are not data.
        let data_len = (length as usize).saturating_sub(2);
        Ok(ModbusResponse::new_from_frame(
            frame, slave_id, function, data_start, data_len,
        ))
    }
}
impl ModbusTransport for TcpTransport {
async fn request(&mut self, request: &ModbusRequest) -> ModbusResult<ModbusResponse> {
request.validate()?;
if self.stream.is_none() {
self.reconnect().await?;
}
let frame = self.encode_request(request);
let expected_transaction_id = self.transaction_id;
self.stats.requests_sent += 1;
self.stats.bytes_sent += frame.len() as u64;
if let Some(ref callback) = self.packet_callback {
callback(PacketDirection::Send, &frame);
}
if self.packet_logging {
log_packet("send", &frame, "TCP", Some(request.slave_id));
}
let stream = self.stream.as_mut().unwrap();
let send_result = timeout(self.timeout, stream.write_all(&frame)).await;
if send_result.is_err() || send_result.unwrap().is_err() {
self.stats.timeouts += 1;
self.stats.errors += 1;
self.stream = None; return Err(ModbusError::timeout(
"send request",
self.timeout.as_millis() as u64,
));
}
let response_buf = loop {
let mut header_buf = [0u8; MBAP_HEADER_SIZE + 1];
let read_result = timeout(self.timeout, stream.read_exact(&mut header_buf)).await;
if read_result.is_err() || read_result.unwrap().is_err() {
self.stats.timeouts += 1;
self.stats.errors += 1;
self.stream = None;
return Err(ModbusError::timeout(
"read response header",
self.timeout.as_millis() as u64,
));
}
let length = u16::from_be_bytes([header_buf[4], header_buf[5]]);
if !(2..=254).contains(&length) {
self.stats.errors += 1;
self.stream = None;
return Err(ModbusError::frame(format!(
"Invalid MBAP length: {} (must be 2-254)",
length
)));
}
let protocol_id = u16::from_be_bytes([header_buf[2], header_buf[3]]);
if protocol_id != 0 {
self.stats.errors += 1;
self.stream = None;
return Err(ModbusError::frame(format!(
"Invalid protocol ID: {:04X} (expected 0000)",
protocol_id
)));
}
let remaining_bytes = (length as usize).saturating_sub(1); let mut response_buf = vec![0u8; MBAP_HEADER_SIZE + 1 + remaining_bytes];
response_buf[..MBAP_HEADER_SIZE + 1].copy_from_slice(&header_buf);
if remaining_bytes > 0 {
let read_result = timeout(
self.timeout,
stream.read_exact(&mut response_buf[MBAP_HEADER_SIZE + 1..]),
)
.await;
if read_result.is_err() || read_result.unwrap().is_err() {
self.stats.timeouts += 1;
self.stats.errors += 1;
self.stream = None;
return Err(ModbusError::timeout(
"read response data",
self.timeout.as_millis() as u64,
));
}
}
self.stats.bytes_received += response_buf.len() as u64;
if let Some(ref callback) = self.packet_callback {
callback(PacketDirection::Receive, &response_buf);
}
if self.packet_logging {
log_packet("receive", &response_buf, "TCP", Some(request.slave_id));
}
let actual_tid = u16::from_be_bytes([response_buf[0], response_buf[1]]);
if actual_tid != expected_transaction_id {
debug!(
"Discarding mismatched response: TID={:04X}, expecting {:04X}",
actual_tid,
expected_transaction_id
);
continue;
}
let actual_unit_id = response_buf[6];
if actual_unit_id != request.slave_id {
debug!(
"Discarding mismatched response: Unit ID={}, expecting {}",
actual_unit_id,
request.slave_id
);
continue;
}
break response_buf;
};
self.stats.responses_received += 1;
let response = self.decode_response(response_buf)?;
if let Some(error) = response.get_exception() {
self.stats.errors += 1;
return Err(error);
}
Ok(response)
}
fn is_connected(&self) -> bool {
self.stream.is_some()
}
async fn close(&mut self) -> ModbusResult<()> {
if let Some(mut stream) = self.stream.take() {
let _ = stream.shutdown().await;
}
Ok(())
}
fn get_stats(&self) -> TransportStats {
self.stats
}
}
/// Modbus RTU transport: binary frames protected by CRC-16/MODBUS over a serial port.
#[cfg(feature = "rtu")]
pub struct RtuTransport {
    /// Open serial handle; `None` after `close`, reopened lazily by `request`.
    port: Option<tokio_serial::SerialStream>,
    /// OS name of the serial device, used when (re)opening the port.
    port_name: String,
    /// Line speed used when opening the port.
    baud_rate: u32,
    /// Serial framing settings used when opening the port.
    data_bits: tokio_serial::DataBits,
    stop_bits: tokio_serial::StopBits,
    parity: tokio_serial::Parity,
    /// Overall send/response timeout; also passed to the serial port builder.
    timeout: Duration,
    /// Silent interval waited before sending, and used to detect the end of a response frame.
    frame_gap: Duration,
    /// Running counters returned by `get_stats`.
    stats: TransportStats,
    /// When `true`, every sent/received frame is logged via `log_packet`.
    packet_logging: bool,
    /// Optional observer called with every sent/received frame.
    packet_callback: Option<PacketCallback>,
}
#[cfg(feature = "rtu")]
impl RtuTransport {
    /// Opens `port` at `baud_rate` with 8 data bits, 1 stop bit, no parity and a 1000 ms
    /// timeout.
    ///
    /// # Errors
    /// Fails if `baud_rate` is zero or the serial port cannot be opened.
    pub fn new(port: &str, baud_rate: u32) -> ModbusResult<Self> {
        Self::new_with_config(
            port,
            baud_rate,
            tokio_serial::DataBits::Eight,
            tokio_serial::StopBits::One,
            tokio_serial::Parity::None,
            Duration::from_millis(1000),
        )
    }

    /// Opens `port` with explicit serial settings; packet logging starts disabled.
    ///
    /// # Errors
    /// Fails if `baud_rate` is zero or the serial port cannot be opened.
    pub fn new_with_config(
        port: &str,
        baud_rate: u32,
        data_bits: tokio_serial::DataBits,
        stop_bits: tokio_serial::StopBits,
        parity: tokio_serial::Parity,
        timeout: Duration,
    ) -> ModbusResult<Self> {
        Self::new_with_packet_logging(
            port, baud_rate, data_bits, stop_bits, parity, timeout, false,
        )
    }

    /// Opens `port` with explicit serial settings and packet logging set to `enable_logging`.
    ///
    /// # Errors
    /// Fails if `baud_rate` is zero or the serial port cannot be opened.
    pub fn new_with_packet_logging(
        port: &str,
        baud_rate: u32,
        data_bits: tokio_serial::DataBits,
        stop_bits: tokio_serial::StopBits,
        parity: tokio_serial::Parity,
        timeout: Duration,
        enable_logging: bool,
    ) -> ModbusResult<Self> {
        let frame_gap = Self::frame_gap_for_baud(baud_rate)?;
        let mut transport = Self {
            port: None,
            port_name: port.to_string(),
            baud_rate,
            data_bits,
            stop_bits,
            parity,
            timeout,
            frame_gap,
            stats: TransportStats::default(),
            packet_logging: enable_logging,
            packet_callback: None,
        };
        transport.connect()?;
        Ok(transport)
    }

    /// Enables or disables logging of every sent/received frame.
    pub fn set_packet_logging(&mut self, enabled: bool) {
        self.packet_logging = enabled;
    }

    /// Installs an observer that receives every sent/received frame.
    pub fn set_packet_callback(&mut self, callback: PacketCallback) {
        self.packet_callback = Some(callback);
    }

    /// Removes any installed packet observer.
    pub fn clear_packet_callback(&mut self) {
        self.packet_callback = None;
    }

    /// Modbus RTU inter-frame silence (t3.5) for `baud_rate`.
    ///
    /// Up to 19200 baud this is 3.5 character times, with a character taken as 11 bits.
    /// Above 19200 baud the Modbus serial-line spec prescribes a fixed 1.750 ms, because the
    /// computed value becomes too short to be timed reliably.
    ///
    /// # Errors
    /// A zero baud rate is rejected instead of dividing by zero.
    fn frame_gap_for_baud(baud_rate: u32) -> ModbusResult<Duration> {
        if baud_rate == 0 {
            return Err(ModbusError::connection("Invalid baud rate: 0"));
        }
        if baud_rate > 19_200 {
            return Ok(Duration::from_micros(1_750));
        }
        let char_time_us = u64::from(11_000_000 / baud_rate);
        Ok(Duration::from_micros(char_time_us * 35 / 10))
    }

    /// Opens the serial port with the stored settings and stores the handle.
    fn connect(&mut self) -> ModbusResult<()> {
        let builder = tokio_serial::new(&self.port_name, self.baud_rate)
            .data_bits(self.data_bits)
            .stop_bits(self.stop_bits)
            .parity(self.parity)
            .timeout(self.timeout);
        let port = tokio_serial::SerialStream::open(&builder).map_err(|e| {
            ModbusError::connection(format!(
                "Failed to open serial port {}: {}",
                self.port_name, e
            ))
        })?;
        self.port = Some(port);
        Ok(())
    }

    /// CRC-16/MODBUS checksum of `data`.
    fn calculate_crc(data: &[u8]) -> u16 {
        CRC_MODBUS.checksum(data)
    }

    /// Encodes `request` as slave id + PDU + CRC (CRC low byte first, as RTU requires).
    fn encode_request(&self, request: &ModbusRequest) -> ModbusResult<Vec<u8>> {
        let mut frame = Vec::new();
        frame.push(request.slave_id);
        frame.push(request.function.to_u8());
        frame.extend_from_slice(&request.address.to_be_bytes());
        match request.function {
            ModbusFunction::ReadCoils
            | ModbusFunction::ReadDiscreteInputs
            | ModbusFunction::ReadHoldingRegisters
            | ModbusFunction::ReadInputRegisters => {
                frame.extend_from_slice(&request.quantity.to_be_bytes());
            }
            ModbusFunction::WriteSingleCoil => {
                // Modbus encodes coil ON as 0xFF00 and OFF as 0x0000.
                let value: u16 = if !request.data.is_empty() && request.data[0] != 0 {
                    0xFF00
                } else {
                    0x0000
                };
                frame.extend_from_slice(&value.to_be_bytes());
            }
            ModbusFunction::WriteSingleRegister => {
                if request.data.len() >= 2 {
                    frame.extend_from_slice(&request.data[0..2]);
                } else {
                    frame.extend_from_slice(&[0, 0]);
                }
            }
            ModbusFunction::WriteMultipleCoils | ModbusFunction::WriteMultipleRegisters => {
                frame.extend_from_slice(&request.quantity.to_be_bytes());
                frame.push(request.data.len() as u8);
                frame.extend_from_slice(&request.data);
            }
        }
        let crc = Self::calculate_crc(&frame);
        frame.extend_from_slice(&crc.to_le_bytes());
        Ok(frame)
    }

    /// Verifies the trailing CRC and parses slave id, function code and data.
    ///
    /// Function codes with bit 7 set become exception responses. Otherwise the response keeps
    /// the whole frame, and its data region (after the function code) excludes the CRC.
    fn decode_response(&self, frame: Vec<u8>) -> ModbusResult<ModbusResponse> {
        if frame.len() < 4 {
            return Err(ModbusError::frame("RTU frame too short"));
        }
        let pdu_len = frame.len() - 2;
        let received_crc = u16::from_le_bytes([frame[pdu_len], frame[pdu_len + 1]]);
        let calculated_crc = Self::calculate_crc(&frame[..pdu_len]);
        if received_crc != calculated_crc {
            return Err(ModbusError::frame(format!(
                "CRC mismatch: expected 0x{:04X}, got 0x{:04X}",
                calculated_crc, received_crc
            )));
        }
        let slave_id = frame[0];
        let function_code = frame[1];
        if function_code & 0x80 != 0 {
            if frame.len() < 5 {
                return Err(ModbusError::frame("Invalid exception response"));
            }
            let original_function = function_code & 0x7F;
            let exception_code = frame[2];
            return Ok(ModbusResponse::new_exception(
                slave_id,
                ModbusFunction::from_u8(original_function)?,
                exception_code,
            ));
        }
        let function = ModbusFunction::from_u8(function_code)?;
        let data_start = 2;
        let data_len = pdu_len.saturating_sub(2);
        Ok(ModbusResponse::new_from_frame(
            frame, slave_id, function, data_start, data_len,
        ))
    }

    /// Keeps the line silent for one inter-frame gap.
    async fn wait_frame_gap(&self) {
        tokio::time::sleep(self.frame_gap).await;
    }

    /// Reads one RTU frame byte by byte; the frame ends once the line stays silent for
    /// `frame_gap` after at least one byte has arrived.
    ///
    /// While nothing has arrived this keeps waiting indefinitely, so callers must bound it
    /// with an outer timeout.
    async fn read_frame(&mut self) -> ModbusResult<Vec<u8>> {
        let port = self
            .port
            .as_mut()
            .ok_or_else(|| ModbusError::connection("Serial port not connected"))?;
        let mut frame = Vec::new();
        let mut buffer = [0u8; 1];
        loop {
            match timeout(self.frame_gap, port.read_exact(&mut buffer)).await {
                Ok(Ok(_)) => {
                    frame.push(buffer[0]);
                    if frame.len() > MAX_RTU_FRAME_SIZE {
                        return Err(ModbusError::frame("RTU frame too large"));
                    }
                }
                Ok(Err(e)) => {
                    return Err(ModbusError::io(format!("Serial read error: {}", e)));
                }
                Err(_) => {
                    if !frame.is_empty() {
                        break;
                    }
                }
            }
        }
        // The loop only exits with at least one byte, so `frame` is never empty here.
        Ok(frame)
    }
}
#[cfg(feature = "rtu")]
impl ModbusTransport for RtuTransport {
    /// Sends `request` as an RTU frame and waits for the reply.
    ///
    /// Waits one inter-frame gap before transmitting, then reads until the line goes quiet.
    /// Every failure, including undecodable replies such as CRC mismatches, is counted in
    /// `stats.errors`. A hard write error drops the port handle so `is_connected` reports
    /// the loss and the next request reopens the port.
    async fn request(&mut self, request: &ModbusRequest) -> ModbusResult<ModbusResponse> {
        request.validate()?;
        if self.port.is_none() {
            self.connect()?;
        }
        self.wait_frame_gap().await;
        let frame = self.encode_request(request)?;
        self.stats.requests_sent += 1;
        self.stats.bytes_sent += frame.len() as u64;
        if let Some(ref callback) = self.packet_callback {
            callback(PacketDirection::Send, &frame);
        }
        if self.packet_logging {
            log_packet("send", &frame, "RTU", Some(request.slave_id));
        }
        let port = self
            .port
            .as_mut()
            .ok_or_else(|| ModbusError::connection("Serial port not connected"))?;
        match timeout(self.timeout, port.write_all(&frame)).await {
            Ok(Ok(())) => {
                // Flushing is best-effort: its result is deliberately ignored.
                let _ = timeout(self.timeout, port.flush()).await;
            }
            Ok(Err(e)) => {
                self.stats.errors += 1;
                self.port = None;
                return Err(ModbusError::io(format!("Failed to send RTU frame: {}", e)));
            }
            Err(_) => {
                self.stats.timeouts += 1;
                self.stats.errors += 1;
                return Err(ModbusError::timeout(
                    "send request",
                    self.timeout.as_millis() as u64,
                ));
            }
        }
        let response_frame = match timeout(self.timeout, self.read_frame()).await {
            Ok(Ok(frame)) => frame,
            Ok(Err(e)) => {
                self.stats.errors += 1;
                return Err(e);
            }
            Err(_) => {
                self.stats.timeouts += 1;
                self.stats.errors += 1;
                return Err(ModbusError::timeout(
                    "read response",
                    self.timeout.as_millis() as u64,
                ));
            }
        };
        self.stats.responses_received += 1;
        self.stats.bytes_received += response_frame.len() as u64;
        if let Some(ref callback) = self.packet_callback {
            callback(PacketDirection::Receive, &response_frame);
        }
        if self.packet_logging {
            log_packet("receive", &response_frame, "RTU", Some(request.slave_id));
        }
        let response = match self.decode_response(response_frame) {
            Ok(response) => response,
            Err(e) => {
                self.stats.errors += 1;
                return Err(e);
            }
        };
        if response.slave_id != request.slave_id {
            self.stats.errors += 1;
            return Err(ModbusError::protocol(format!(
                "Response slave ID mismatch: expected {}, got {}",
                request.slave_id, response.slave_id
            )));
        }
        if let Some(error) = response.get_exception() {
            self.stats.errors += 1;
            return Err(error);
        }
        Ok(response)
    }

    /// Returns `true` while a serial port handle is held.
    fn is_connected(&self) -> bool {
        self.port.is_some()
    }

    /// Closes the port by dropping the serial handle.
    async fn close(&mut self) -> ModbusResult<()> {
        self.port = None;
        Ok(())
    }

    /// Returns a copy of the running counters.
    fn get_stats(&self) -> TransportStats {
        self.stats
    }
}
/// Modbus ASCII transport: frames are `:`, the hex-encoded bytes and LRC, then CR LF, sent
/// over a serial port.
#[cfg(feature = "rtu")]
pub struct AsciiTransport {
    /// Open serial handle; `None` after `close`, reopened lazily by `request`.
    port: Option<tokio_serial::SerialStream>,
    /// OS name of the serial device, used when (re)opening the port.
    port_name: String,
    /// Line speed used when opening the port.
    baud_rate: u32,
    /// Serial framing settings used when opening the port.
    data_bits: tokio_serial::DataBits,
    stop_bits: tokio_serial::StopBits,
    parity: tokio_serial::Parity,
    /// Overall send/response timeout; also passed to the serial port builder.
    timeout: Duration,
    /// Longest allowed silence between two characters once a frame has started arriving.
    inter_char_timeout: Duration,
    /// Running counters returned by `get_stats`.
    stats: TransportStats,
}
#[cfg(feature = "rtu")]
impl AsciiTransport {
    /// Opens `port` with 7 data bits, 1 stop bit, even parity, a 1 s timeout and a 1 s
    /// inter-character timeout.
    ///
    /// # Errors
    /// Fails if the serial port cannot be opened.
    pub fn new(port: &str, baud_rate: u32) -> ModbusResult<Self> {
        Self::new_with_config(
            port,
            baud_rate,
            tokio_serial::DataBits::Seven,
            tokio_serial::StopBits::One,
            tokio_serial::Parity::Even,
            Duration::from_secs(1),
            Duration::from_millis(1000),
        )
    }

    /// Opens `port` with explicit serial settings and timeouts.
    ///
    /// # Errors
    /// Fails if the serial port cannot be opened.
    pub fn new_with_config(
        port: &str,
        baud_rate: u32,
        data_bits: tokio_serial::DataBits,
        stop_bits: tokio_serial::StopBits,
        parity: tokio_serial::Parity,
        timeout: Duration,
        inter_char_timeout: Duration,
    ) -> ModbusResult<Self> {
        let mut transport = Self {
            port: None,
            port_name: port.to_string(),
            baud_rate,
            data_bits,
            stop_bits,
            parity,
            timeout,
            inter_char_timeout,
            stats: TransportStats::default(),
        };
        transport.connect()?;
        Ok(transport)
    }

    /// Opens the serial port with the stored settings and stores the handle.
    fn connect(&mut self) -> ModbusResult<()> {
        let builder = tokio_serial::new(&self.port_name, self.baud_rate)
            .data_bits(self.data_bits)
            .stop_bits(self.stop_bits)
            .parity(self.parity)
            .timeout(self.timeout);
        let port = tokio_serial::SerialStream::open(&builder).map_err(|e| {
            ModbusError::connection(format!(
                "Failed to open serial port {}: {}",
                self.port_name, e
            ))
        })?;
        self.port = Some(port);
        Ok(())
    }

    /// Longitudinal redundancy check: the two's complement of the 8-bit sum of `data`.
    ///
    /// Only the low 8 bits of the sum matter, so the sum is accumulated in `u8` with
    /// wrapping arithmetic. No intermediate value can overflow, whatever the frame contents.
    fn calculate_lrc(data: &[u8]) -> u8 {
        data.iter()
            .fold(0u8, |acc, &byte| acc.wrapping_add(byte))
            .wrapping_neg()
    }

    /// Encodes `byte` as two upper-case ASCII hex digits.
    fn byte_to_ascii_hex(byte: u8) -> [u8; 2] {
        let high = (byte >> 4) & 0x0F;
        let low = byte & 0x0F;
        let high_char = if high < 10 {
            b'0' + high
        } else {
            b'A' + (high - 10)
        };
        let low_char = if low < 10 {
            b'0' + low
        } else {
            b'A' + (low - 10)
        };
        [high_char, low_char]
    }

    /// Decodes exactly two ASCII hex digits (either case) into a byte.
    fn ascii_hex_to_byte(ascii: &[u8]) -> ModbusResult<u8> {
        if ascii.len() != 2 {
            return Err(ModbusError::frame("Invalid ASCII hex length"));
        }
        let high = Self::ascii_char_to_hex(ascii[0])?;
        let low = Self::ascii_char_to_hex(ascii[1])?;
        Ok((high << 4) | low)
    }

    /// Decodes one ASCII hex digit (either case) into its 4-bit value.
    fn ascii_char_to_hex(c: u8) -> ModbusResult<u8> {
        match c {
            b'0'..=b'9' => Ok(c - b'0'),
            b'A'..=b'F' => Ok(c - b'A' + 10),
            b'a'..=b'f' => Ok(c - b'a' + 10),
            _ => Err(ModbusError::frame(format!(
                "Invalid ASCII hex character: {}",
                c as char
            ))),
        }
    }

    /// Encodes `request` as `:`, then hex(slave id + PDU + LRC), then CR LF.
    fn encode_request(&self, request: &ModbusRequest) -> ModbusResult<Vec<u8>> {
        let mut raw_data = Vec::new();
        raw_data.push(request.slave_id);
        raw_data.push(request.function.to_u8());
        match request.function {
            ModbusFunction::ReadCoils
            | ModbusFunction::ReadDiscreteInputs
            | ModbusFunction::ReadHoldingRegisters
            | ModbusFunction::ReadInputRegisters => {
                raw_data.extend_from_slice(&request.address.to_be_bytes());
                raw_data.extend_from_slice(&request.quantity.to_be_bytes());
            }
            ModbusFunction::WriteSingleCoil => {
                raw_data.extend_from_slice(&request.address.to_be_bytes());
                // Modbus encodes coil ON as 0xFF00 and OFF as 0x0000.
                let value: u16 = if !request.data.is_empty() && request.data[0] != 0 {
                    0xFF00
                } else {
                    0x0000
                };
                raw_data.extend_from_slice(&value.to_be_bytes());
            }
            ModbusFunction::WriteSingleRegister => {
                raw_data.extend_from_slice(&request.address.to_be_bytes());
                if request.data.len() >= 2 {
                    raw_data.extend_from_slice(&request.data[0..2]);
                } else {
                    raw_data.extend_from_slice(&[0, 0]);
                }
            }
            ModbusFunction::WriteMultipleCoils | ModbusFunction::WriteMultipleRegisters => {
                raw_data.extend_from_slice(&request.address.to_be_bytes());
                raw_data.extend_from_slice(&request.quantity.to_be_bytes());
                raw_data.push(request.data.len() as u8);
                raw_data.extend_from_slice(&request.data);
            }
        }
        let lrc = Self::calculate_lrc(&raw_data);
        let mut frame = Vec::with_capacity(1 + 2 * (raw_data.len() + 1) + 2);
        frame.push(b':');
        for &byte in &raw_data {
            frame.extend_from_slice(&Self::byte_to_ascii_hex(byte));
        }
        frame.extend_from_slice(&Self::byte_to_ascii_hex(lrc));
        frame.push(0x0D);
        frame.push(0x0A);
        Ok(frame)
    }

    /// Validates framing characters and LRC, then parses slave id, function code and data.
    ///
    /// Function codes with bit 7 set become exception responses. Otherwise the response is
    /// built over the decoded binary bytes with the LRC removed.
    fn decode_response(&self, frame: Vec<u8>) -> ModbusResult<ModbusResponse> {
        if frame.len() < 11 {
            return Err(ModbusError::frame("ASCII frame too short"));
        }
        if frame[0] != b':' {
            return Err(ModbusError::frame("Invalid ASCII frame start character"));
        }
        let len = frame.len();
        if frame[len - 2] != 0x0D || frame[len - 1] != 0x0A {
            return Err(ModbusError::frame("Invalid ASCII frame end characters"));
        }
        let ascii_data = &frame[1..len - 2];
        if ascii_data.len() % 2 != 0 {
            return Err(ModbusError::frame("Invalid ASCII frame length"));
        }
        let mut raw_data = Vec::with_capacity(ascii_data.len() / 2);
        for chunk in ascii_data.chunks(2) {
            raw_data.push(Self::ascii_hex_to_byte(chunk)?);
        }
        if raw_data.len() < 3 {
            return Err(ModbusError::frame("ASCII frame too short after decoding"));
        }
        let received_lrc = raw_data.pop().unwrap();
        let calculated_lrc = Self::calculate_lrc(&raw_data);
        if received_lrc != calculated_lrc {
            return Err(ModbusError::frame(format!(
                "LRC mismatch: expected 0x{:02X}, got 0x{:02X}",
                calculated_lrc, received_lrc
            )));
        }
        let slave_id = raw_data[0];
        let function_code = raw_data[1];
        if function_code & 0x80 != 0 {
            if raw_data.len() < 3 {
                return Err(ModbusError::frame("Invalid exception response"));
            }
            let original_function = function_code & 0x7F;
            let exception_code = raw_data[2];
            return Ok(ModbusResponse::new_exception(
                slave_id,
                ModbusFunction::from_u8(original_function)?,
                exception_code,
            ));
        }
        let function = ModbusFunction::from_u8(function_code)?;
        let data_start = 2;
        let data_len = raw_data.len().saturating_sub(2);
        Ok(ModbusResponse::new_from_frame(
            raw_data, slave_id, function, data_start, data_len,
        ))
    }

    /// Reads one ASCII frame, from `:` through CR LF.
    ///
    /// Bytes before the first `:` are discarded as line noise. A `:` arriving mid-frame
    /// restarts the frame, so a truncated or corrupted message does not poison the next one.
    /// Silence longer than `inter_char_timeout` inside a frame is an error. While no frame
    /// has started this keeps waiting, so callers must bound it with an outer timeout.
    async fn read_frame(&mut self) -> ModbusResult<Vec<u8>> {
        let port = self
            .port
            .as_mut()
            .ok_or_else(|| ModbusError::connection("Serial port not connected"))?;
        let mut frame = Vec::new();
        let mut buffer = [0u8; 1];
        loop {
            match timeout(self.inter_char_timeout, port.read_exact(&mut buffer)).await {
                Ok(Ok(_)) => {
                    let byte = buffer[0];
                    if byte == b':' {
                        frame.clear();
                    } else if frame.is_empty() {
                        continue;
                    }
                    frame.push(byte);
                    if frame.ends_with(b"\r\n") {
                        break;
                    }
                    if frame.len() > MAX_RTU_FRAME_SIZE * 2 {
                        return Err(ModbusError::frame("ASCII frame too large"));
                    }
                }
                Ok(Err(e)) => {
                    return Err(ModbusError::io(format!("Serial read error: {}", e)));
                }
                Err(_) => {
                    if frame.is_empty() {
                        continue;
                    }
                    return Err(ModbusError::timeout(
                        "Incomplete ASCII frame",
                        self.inter_char_timeout.as_millis() as u64,
                    ));
                }
            }
        }
        // The loop only exits after a CR LF terminator, so `frame` is never empty here.
        Ok(frame)
    }
}
#[cfg(feature = "rtu")]
impl ModbusTransport for AsciiTransport {
    /// Sends `request` as an ASCII frame and waits for the CR LF-terminated reply.
    ///
    /// Every failure, including undecodable replies such as LRC mismatches, is counted in
    /// `stats.errors`. A hard write error drops the port handle so `is_connected` reports
    /// the loss and the next request reopens the port.
    async fn request(&mut self, request: &ModbusRequest) -> ModbusResult<ModbusResponse> {
        request.validate()?;
        if self.port.is_none() {
            self.connect()?;
        }
        let frame = self.encode_request(request)?;
        self.stats.requests_sent += 1;
        self.stats.bytes_sent += frame.len() as u64;
        let port = self
            .port
            .as_mut()
            .ok_or_else(|| ModbusError::connection("Serial port not connected"))?;
        match timeout(self.timeout, port.write_all(&frame)).await {
            Ok(Ok(())) => {
                // Flushing is best-effort: its result is deliberately ignored.
                let _ = timeout(self.timeout, port.flush()).await;
            }
            Ok(Err(e)) => {
                self.stats.errors += 1;
                self.port = None;
                return Err(ModbusError::io(format!(
                    "Failed to send ASCII frame: {}",
                    e
                )));
            }
            Err(_) => {
                self.stats.timeouts += 1;
                self.stats.errors += 1;
                return Err(ModbusError::timeout(
                    "send request",
                    self.timeout.as_millis() as u64,
                ));
            }
        }
        let response_frame = match timeout(self.timeout, self.read_frame()).await {
            Ok(Ok(frame)) => frame,
            Ok(Err(e)) => {
                self.stats.errors += 1;
                return Err(e);
            }
            Err(_) => {
                self.stats.timeouts += 1;
                self.stats.errors += 1;
                return Err(ModbusError::timeout(
                    "read response",
                    self.timeout.as_millis() as u64,
                ));
            }
        };
        self.stats.responses_received += 1;
        self.stats.bytes_received += response_frame.len() as u64;
        let response = match self.decode_response(response_frame) {
            Ok(response) => response,
            Err(e) => {
                self.stats.errors += 1;
                return Err(e);
            }
        };
        if response.slave_id != request.slave_id {
            self.stats.errors += 1;
            return Err(ModbusError::protocol(format!(
                "Response slave ID mismatch: expected {}, got {}",
                request.slave_id, response.slave_id
            )));
        }
        if let Some(error) = response.get_exception() {
            self.stats.errors += 1;
            return Err(error);
        }
        Ok(response)
    }

    /// Returns `true` while a serial port handle is held.
    fn is_connected(&self) -> bool {
        self.port.is_some()
    }

    /// Closes the port by dropping the serial handle.
    async fn close(&mut self) -> ModbusResult<()> {
        self.port = None;
        Ok(())
    }

    /// Returns a copy of the running counters.
    fn get_stats(&self) -> TransportStats {
        self.stats
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Connects to a loopback listener on an ephemeral port. This exercises
    /// `TcpTransport::new` without depending on an external Modbus server on port 502.
    #[tokio::test]
    async fn test_tcp_transport_creation() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind loopback listener");
        let addr = listener.local_addr().expect("listener address");
        let transport = match TcpTransport::new(addr, Duration::from_secs(5)).await {
            Ok(transport) => transport,
            Err(e) => panic!("TcpTransport::new failed: {}", e),
        };
        assert!(transport.is_connected());
        assert_eq!(transport.address, addr);
    }

    #[test]
    fn test_transaction_id_mismatch_error() {
        let error = ModbusError::transaction_id_mismatch(0x1234, 0x5678);
        assert!(matches!(
            error,
            ModbusError::TransactionIdMismatch {
                expected: 0x1234,
                actual: 0x5678
            }
        ));
        assert!(error.is_recoverable());
        assert!(error.is_protocol_error());
        let error_msg = format!("{}", error);
        assert!(error_msg.contains("1234"));
        assert!(error_msg.contains("5678"));
        assert!(error_msg.contains("Transaction ID mismatch"));
    }

    #[test]
    fn test_tcp_transaction_id_generation() {
        let mut transport = TcpTransport {
            stream: None,
            address: "127.0.0.1:502".parse().unwrap(),
            timeout: Duration::from_secs(5),
            transaction_id: 0,
            stats: TransportStats::default(),
            packet_logging: false,
            packet_callback: None,
        };
        let id1 = transport.next_transaction_id();
        assert_eq!(id1, 1);
        let id2 = transport.next_transaction_id();
        assert_eq!(id2, 2);
        // Wrapping past u16::MAX must skip 0.
        transport.transaction_id = u16::MAX;
        let id_after_wrap = transport.next_transaction_id();
        assert_eq!(id_after_wrap, 1);
    }

    #[test]
    fn test_tcp_encode_request_sets_transaction_id() {
        use crate::protocol::{ModbusFunction, ModbusRequest};
        let mut transport = TcpTransport {
            stream: None,
            address: "127.0.0.1:502".parse().unwrap(),
            timeout: Duration::from_secs(5),
            transaction_id: 0,
            stats: TransportStats::default(),
            packet_logging: false,
            packet_callback: None,
        };
        let request = ModbusRequest::new_read(1, ModbusFunction::ReadHoldingRegisters, 0, 10);
        let frame = transport.encode_request(&request);
        let tid_in_frame = u16::from_be_bytes([frame[0], frame[1]]);
        assert_eq!(tid_in_frame, transport.transaction_id);
        assert_eq!(transport.transaction_id, 1);
        let frame2 = transport.encode_request(&request);
        let tid_in_frame2 = u16::from_be_bytes([frame2[0], frame2[1]]);
        assert_eq!(tid_in_frame2, 2);
    }

    /// The MBAP length must count exactly the unit id and PDU bytes that follow it.
    #[test]
    fn test_tcp_encode_request_mbap_length() {
        use crate::protocol::{ModbusFunction, ModbusRequest};
        let mut transport = TcpTransport {
            stream: None,
            address: "127.0.0.1:502".parse().unwrap(),
            timeout: Duration::from_secs(5),
            transaction_id: 0,
            stats: TransportStats::default(),
            packet_logging: false,
            packet_callback: None,
        };
        let request = ModbusRequest::new_read(1, ModbusFunction::ReadHoldingRegisters, 0, 10);
        let frame = transport.encode_request(&request);
        assert_eq!(u16::from_be_bytes([frame[2], frame[3]]), 0);
        assert_eq!(u16::from_be_bytes([frame[4], frame[5]]), 6);
        assert_eq!(frame.len(), MBAP_HEADER_SIZE + 6);
        assert_eq!(
            &frame[MBAP_HEADER_SIZE..],
            &[1u8, 0x03, 0x00, 0x00, 0x00, 0x0A][..]
        );
    }
}
#[cfg(all(test, feature = "rtu"))]
mod rtu_tests {
    use super::*;
    use crate::protocol::ModbusFunction;

    /// Pins CRC-16/MODBUS against known vectors instead of only checking it is non-zero.
    #[test]
    fn test_crc_calculation() {
        // Catalogue check value for CRC-16/MODBUS.
        assert_eq!(RtuTransport::calculate_crc(b"123456789"), 0x4B37);
        // Read two holding registers from slave 1; transmitted as ... C4 0B.
        let data = [0x01, 0x03, 0x00, 0x00, 0x00, 0x02];
        let crc = RtuTransport::calculate_crc(&data);
        assert_eq!(crc, 0x0BC4);
        assert_eq!(crc.to_le_bytes(), [0xC4, 0x0B]);
        // A frame with its CRC appended (low byte first) has a zero residue.
        let framed = [0x01, 0x03, 0x00, 0x00, 0x00, 0x02, 0xC4, 0x0B];
        assert_eq!(RtuTransport::calculate_crc(&framed), 0);
    }

    #[test]
    fn test_rtu_frame_decoding() {
        let transport = create_mock_rtu_transport();
        let mut frame = vec![0x01u8, 0x03, 0x04, 0x00, 0xAB, 0x00, 0xCD];
        let crc = RtuTransport::calculate_crc(&frame);
        frame.extend_from_slice(&crc.to_le_bytes());
        let response = transport
            .decode_response(frame.clone())
            .expect("valid RTU frame");
        assert_eq!(response.slave_id, 0x01);
        assert_eq!(response.function, ModbusFunction::ReadHoldingRegisters);
        assert!(!response.is_exception());
        let mut corrupted = frame;
        corrupted[3] ^= 0xFF;
        assert!(transport.decode_response(corrupted).is_err());
    }

    #[test]
    fn test_ascii_lrc_calculation() {
        let data = [0x01, 0x03, 0x00, 0x00, 0x00, 0x02];
        let lrc = AsciiTransport::calculate_lrc(&data);
        let sum: u16 = data.iter().map(|&b| b as u16).sum();
        let expected_lrc = (-(sum as i16)) as u8;
        assert_eq!(lrc, expected_lrc);
    }

    #[test]
    fn test_ascii_hex_conversion() {
        let ascii_hex = AsciiTransport::byte_to_ascii_hex(0x1A);
        assert_eq!(ascii_hex, [b'1', b'A']);
        let ascii_hex = AsciiTransport::byte_to_ascii_hex(0x0F);
        assert_eq!(ascii_hex, [b'0', b'F']);
        let byte = AsciiTransport::ascii_hex_to_byte(&[b'1', b'A']).unwrap();
        assert_eq!(byte, 0x1A);
        let byte = AsciiTransport::ascii_hex_to_byte(&[b'0', b'F']).unwrap();
        assert_eq!(byte, 0x0F);
        let byte = AsciiTransport::ascii_hex_to_byte(&[b'a', b'f']).unwrap();
        assert_eq!(byte, 0xAF);
    }

    #[test]
    fn test_ascii_frame_encoding() {
        let transport = create_mock_ascii_transport();
        let request =
            ModbusRequest::new_read(0x01, ModbusFunction::ReadHoldingRegisters, 0x0000, 0x0002);
        let frame = transport.encode_request(&request).unwrap();
        let data = [0x01u8, 0x03, 0x00, 0x00, 0x00, 0x02];
        let sum: u16 = data.iter().map(|&b| b as u16).sum();
        let expected_lrc = (-(sum as i16)) as u8;
        let expected = format!(":010300000002{:02X}\r\n", expected_lrc);
        let expected_bytes = expected.as_bytes();
        assert_eq!(frame, expected_bytes);
        assert_eq!(frame[0], b':');
        assert_eq!(frame[frame.len() - 2], 0x0D);
        assert_eq!(frame[frame.len() - 1], 0x0A);
    }

    #[test]
    fn test_ascii_frame_decoding() {
        let transport = create_mock_ascii_transport();
        let test_data = [0x01u8, 0x03, 0x04, 0x00, 0xAB, 0x00, 0xCD];
        let sum: u16 = test_data.iter().map(|&b| b as u16).sum();
        let lrc = (-(sum as i16)) as u8;
        let frame = format!(":01030400AB00CD{:02X}\r\n", lrc);
        let response = transport
            .decode_response(frame.as_bytes().to_vec())
            .unwrap();
        assert_eq!(response.slave_id, 0x01);
        assert_eq!(response.function, ModbusFunction::ReadHoldingRegisters);
        assert!(!response.is_exception());
        let exc_data = [0x01u8, 0x83, 0x02];
        let exc_sum: u16 = exc_data.iter().map(|&b| b as u16).sum();
        let exc_lrc = (-(exc_sum as i16)) as u8;
        let exception_frame = format!(":018302{:02X}\r\n", exc_lrc);
        let exception_response = transport
            .decode_response(exception_frame.as_bytes().to_vec())
            .unwrap();
        assert_eq!(exception_response.slave_id, 0x01);
        assert!(exception_response.is_exception());
    }

    #[test]
    fn test_ascii_error_handling() {
        let transport = create_mock_ascii_transport();
        let invalid_start = b"X010300000002C5\r\n".to_vec();
        assert!(transport.decode_response(invalid_start).is_err());
        let invalid_end = b":010300000002C5\r\r".to_vec();
        assert!(transport.decode_response(invalid_end).is_err());
        let odd_length = b":01030000002C5\r\n".to_vec();
        assert!(transport.decode_response(odd_length).is_err());
        let wrong_lrc = b":010300000002FF\r\n".to_vec();
        assert!(transport.decode_response(wrong_lrc).is_err());
    }

    /// RTU transport with no open port, for exercising pure encode/decode helpers.
    fn create_mock_rtu_transport() -> RtuTransport {
        RtuTransport {
            port: None,
            port_name: "mock".to_string(),
            baud_rate: 9600,
            data_bits: tokio_serial::DataBits::Eight,
            stop_bits: tokio_serial::StopBits::One,
            parity: tokio_serial::Parity::None,
            timeout: Duration::from_secs(1),
            frame_gap: Duration::from_micros(4_000),
            stats: TransportStats::default(),
            packet_logging: false,
            packet_callback: None,
        }
    }

    /// ASCII transport with no open port, for exercising pure encode/decode helpers.
    fn create_mock_ascii_transport() -> AsciiTransport {
        AsciiTransport {
            port: None,
            port_name: "mock".to_string(),
            baud_rate: 9600,
            data_bits: tokio_serial::DataBits::Seven,
            stop_bits: tokio_serial::StopBits::One,
            parity: tokio_serial::Parity::Even,
            timeout: Duration::from_secs(1),
            inter_char_timeout: Duration::from_millis(100),
            stats: TransportStats::default(),
        }
    }
}