use std::io::{self, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::time::Duration;
use heapless::Vec;
use mbus_core::data_unit::common::MAX_ADU_FRAME_LEN;
use mbus_core::transport::{ModbusConfig, Transport, TransportError, TransportType};
// Logging shims used throughout this transport.
//
// With the `logging` feature enabled each macro forwards its arguments verbatim
// to the matching `log` crate macro. Without it, the arguments are still fed to
// `core::format_args!` so that format strings are type-checked and captured
// variables count as used, but nothing is written anywhere.

#[cfg(feature = "logging")]
macro_rules! transport_log_error {
    ($($arg:tt)*) => {
        log::error!($($arg)*)
    };
}

#[cfg(feature = "logging")]
macro_rules! transport_log_warn {
    ($($arg:tt)*) => {
        log::warn!($($arg)*)
    };
}

#[cfg(feature = "logging")]
macro_rules! transport_log_debug {
    ($($arg:tt)*) => {
        log::debug!($($arg)*)
    };
}

#[cfg(not(feature = "logging"))]
macro_rules! transport_log_error {
    ($($arg:tt)*) => {{
        let _ = core::format_args!($($arg)*);
    }};
}

#[cfg(not(feature = "logging"))]
macro_rules! transport_log_warn {
    ($($arg:tt)*) => {{
        let _ = core::format_args!($($arg)*);
    }};
}

#[cfg(not(feature = "logging"))]
macro_rules! transport_log_debug {
    ($($arg:tt)*) => {{
        let _ = core::format_args!($($arg)*);
    }};
}
/// Modbus TCP transport backed by a blocking `std::net::TcpStream`.
///
/// The transport starts out (and `Default`s to) having no socket; `connect`
/// installs one and `disconnect` removes it.
#[derive(Default, Debug)]
pub struct StdTcpTransport {
    /// The live socket, or `None` while disconnected.
    stream: Option<TcpStream>,
}
impl StdTcpTransport {
    /// Creates a transport with no open connection; `connect` must be called
    /// before `send`/`recv` can succeed.
    pub fn new() -> Self {
        Self { stream: None }
    }

    /// Translates a `std::io::Error` into a `TransportError`.
    ///
    /// - `ConnectionRefused`, `NotFound` -> `ConnectionFailed`
    /// - `BrokenPipe`, `ConnectionReset`, `ConnectionAborted`, `NotConnected`,
    ///   `UnexpectedEof` -> `ConnectionClosed`
    /// - `WouldBlock`, `TimedOut` -> `Timeout`
    /// - anything else -> `IoError`
    ///
    /// `ConnectionClosed` is the value `send`/`recv` check before discarding the
    /// socket, so every kind meaning "the connection is gone" must map to it.
    fn map_io_error(err: io::Error) -> TransportError {
        match err.kind() {
            io::ErrorKind::ConnectionRefused | io::ErrorKind::NotFound => {
                TransportError::ConnectionFailed
            }
            // `ConnectionAborted` (e.g. Windows' WSAECONNABORTED) and
            // `NotConnected` also mean the socket is unusable. Reporting them as a
            // generic `IoError` would keep a dead stream and leave
            // `is_connected()` returning true.
            io::ErrorKind::BrokenPipe
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::NotConnected
            | io::ErrorKind::UnexpectedEof => TransportError::ConnectionClosed,
            io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut => TransportError::Timeout,
            _ => TransportError::IoError,
        }
    }
}
impl Transport for StdTcpTransport {
type Error = TransportError;
fn connect(&mut self, config: &ModbusConfig) -> Result<(), Self::Error> {
let config = match config {
ModbusConfig::Tcp(c) => c,
_ => return Err(TransportError::Unexpected),
};
let connection_timeout = Duration::from_millis(config.connection_timeout_ms as u64);
let response_timeout = Duration::from_millis(config.response_timeout_ms as u64);
let mut addrs_iter = (config.host.as_str(), config.port)
.to_socket_addrs()
.map_err(|e| {
transport_log_error!("DNS resolution failed: {:?}", e);
TransportError::ConnectionFailed
})?;
let addr = addrs_iter.next().ok_or_else(|| {
transport_log_error!("No valid address found for host:port combination.");
TransportError::ConnectionFailed
})?;
transport_log_debug!("Trying address: {:?}", addr);
match TcpStream::connect_timeout(&addr, connection_timeout) {
Ok(stream) => {
stream
.set_read_timeout(Some(response_timeout))
.unwrap_or_else(|e| transport_log_warn!("Failed to set read timeout: {:?}", e));
stream
.set_write_timeout(Some(response_timeout))
.unwrap_or_else(|e| {
transport_log_warn!("Failed to set write timeout: {:?}", e)
});
stream
.set_nodelay(true)
.unwrap_or_else(|e| transport_log_warn!("Failed to set no-delay: {:?}", e));
self.stream = Some(stream); Ok(()) }
Err(e) => {
transport_log_error!("Connect failed: {:?}", e);
Err(TransportError::ConnectionFailed) }
}
}
fn disconnect(&mut self) -> Result<(), Self::Error> {
if let Some(stream) = self.stream.take() {
drop(stream);
}
Ok(())
}
fn send(&mut self, adu: &[u8]) -> Result<(), Self::Error> {
let stream = self
.stream
.as_mut()
.ok_or(TransportError::ConnectionClosed)?;
let result = stream.write_all(adu).and_then(|()| stream.flush());
if let Err(err) = result {
let transport_error = Self::map_io_error(err);
if transport_error == TransportError::ConnectionClosed {
self.stream = None;
}
return Err(transport_error);
}
Ok(())
}
fn recv(&mut self) -> Result<Vec<u8, MAX_ADU_FRAME_LEN>, Self::Error> {
let stream = self
.stream
.as_mut()
.ok_or(TransportError::ConnectionClosed)?;
let _ = stream.set_nonblocking(true);
let mut temp_buf = [0u8; MAX_ADU_FRAME_LEN];
let read_result = stream.read(&mut temp_buf);
let _ = stream.set_nonblocking(false);
match read_result {
Ok(0) => {
self.stream = None;
Err(TransportError::ConnectionClosed)
}
Ok(n) => {
let mut buffer = Vec::new();
if buffer.extend_from_slice(&temp_buf[..n]).is_err() {
return Err(TransportError::BufferTooSmall);
}
Ok(buffer)
}
Err(e) => {
let err = Self::map_io_error(e);
if err == TransportError::ConnectionClosed {
self.stream = None;
}
Err(err)
}
}
}
fn is_connected(&self) -> bool {
self.stream.is_some()
}
fn transport_type(&self) -> TransportType {
TransportType::StdTcp
}
}
#[cfg(test)]
impl StdTcpTransport {
    /// Test-only access to the underlying socket; `None` while disconnected.
    pub fn stream_mut(&mut self) -> Option<&mut TcpStream> {
        Option::as_mut(&mut self.stream)
    }
}
#[cfg(test)]
mod tests {
    use super::StdTcpTransport;
    use mbus_core::transport::{ModbusConfig, ModbusTcpConfig, Transport, TransportError};
    use std::io::{self, Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// Binds a listener on an OS-assigned loopback port.
    fn create_test_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").expect("Failed to bind to an available port")
    }

    /// Starts a one-shot server on a fresh loopback port. The server accepts a
    /// single connection and hands it to `handler` on a background thread.
    ///
    /// The listener is already bound when this returns, so a client may connect
    /// immediately; the kernel queues the connection until `accept` runs.
    fn spawn_server<F>(handler: F) -> (u16, thread::JoinHandle<()>)
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = create_test_listener();
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            handler(stream);
        });
        (port, handle)
    }

    /// Default TCP configuration targeting `127.0.0.1:port`.
    fn tcp_config(port: u16) -> ModbusTcpConfig {
        ModbusTcpConfig::new("127.0.0.1", port).unwrap()
    }

    /// Returns a transport already connected to `127.0.0.1:port` with default
    /// settings.
    fn connected_transport(port: u16) -> StdTcpTransport {
        let mut transport = StdTcpTransport::new();
        transport
            .connect(&ModbusConfig::Tcp(tcp_config(port)))
            .expect("connect to test server");
        transport
    }

    /// Polls `recv` until it returns an error other than `Timeout`, gives up
    /// after 50 polls, and returns that error. It sleeps briefly after each
    /// `Timeout` and discards any data received.
    fn wait_for_recv_error(transport: &mut StdTcpTransport) -> Option<TransportError> {
        for _ in 0..50 {
            match transport.recv() {
                Ok(_) => {}
                Err(TransportError::Timeout) => thread::sleep(Duration::from_millis(10)),
                Err(e) => return Some(e),
            }
        }
        None
    }

    #[test]
    fn test_new_std_tcp_transport() {
        let transport = StdTcpTransport::new();
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_connect_success() {
        let (port, server) = spawn_server(|_| {});
        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&ModbusConfig::Tcp(tcp_config(port)));
        assert!(result.is_ok());
        assert!(transport.is_connected());
        server.join().unwrap();
    }

    #[test]
    fn test_connect_failure_invalid_addr() {
        let mut transport = StdTcpTransport::new();
        let config = ModbusConfig::Tcp(ModbusTcpConfig::new("invalid-address", 502).unwrap());
        let result = transport.connect(&config);
        assert_eq!(result, Err(TransportError::ConnectionFailed));
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_connect_failure_connection_refused() {
        // Bind and immediately release a port so that nothing is listening on it.
        let port = create_test_listener().local_addr().unwrap().port();
        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&ModbusConfig::Tcp(tcp_config(port)));
        assert_eq!(result, Err(TransportError::ConnectionFailed));
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_disconnect() {
        let (port, server) = spawn_server(|_| {});
        let mut transport = connected_transport(port);
        assert!(transport.is_connected());
        assert!(transport.disconnect().is_ok());
        assert!(!transport.is_connected());
        server.join().unwrap();
    }

    #[test]
    fn test_send_success() {
        let test_data = [0x01, 0x02, 0x03, 0x04];
        let (port, server) = spawn_server(move |mut stream| {
            let mut buf = [0; 4];
            stream.read_exact(&mut buf).unwrap();
            assert_eq!(buf, test_data);
        });
        let mut transport = connected_transport(port);
        assert!(transport.send(&test_data).is_ok());
        server.join().unwrap();
    }

    #[test]
    fn test_send_failure_not_connected() {
        let mut transport = StdTcpTransport::new();
        let result = transport.send(&[0x01, 0x02]);
        assert_eq!(result, Err(TransportError::ConnectionClosed));
    }

    #[test]
    fn test_recv_success_full_adu() {
        let adu_to_send = [0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x01, 0x03, 0x00];
        let (port, server) = spawn_server(move |mut stream| {
            stream.write_all(&adu_to_send).unwrap();
        });
        let mut transport = connected_transport(port);
        // `recv` is a non-blocking poll and may return the ADU in pieces.
        let mut combined_adu = std::vec::Vec::new();
        for _ in 0..50 {
            match transport.recv() {
                Ok(bytes) => {
                    combined_adu.extend_from_slice(&bytes);
                    if combined_adu.len() == adu_to_send.len() {
                        break;
                    }
                }
                Err(TransportError::Timeout) => thread::sleep(Duration::from_millis(10)),
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        }
        assert_eq!(combined_adu.as_slice(), adu_to_send);
        server.join().unwrap();
    }

    #[test]
    fn test_recv_failure_not_connected() {
        let mut transport = StdTcpTransport::new();
        let result = transport.recv();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionClosed);
    }

    #[test]
    fn test_recv_failure_connection_closed_prematurely_header() {
        let partial_adu = [0x00, 0x01, 0x00];
        let (port, server) =
            spawn_server(move |mut stream| stream.write_all(&partial_adu).unwrap());
        let mut transport = connected_transport(port);
        assert_eq!(
            wait_for_recv_error(&mut transport),
            Some(TransportError::ConnectionClosed)
        );
        server.join().unwrap();
    }

    #[test]
    fn test_recv_failure_connection_closed_prematurely_pdu() {
        let partial_adu = [0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x01, 0x03];
        let (port, server) =
            spawn_server(move |mut stream| stream.write_all(&partial_adu).unwrap());
        let mut transport = connected_transport(port);
        assert_eq!(
            wait_for_recv_error(&mut transport),
            Some(TransportError::ConnectionClosed)
        );
        server.join().unwrap();
    }

    #[test]
    fn test_recv_timeout() {
        // The server keeps the connection open and silent until it is released.
        // This replaces a fixed multi-second sleep.
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let (port, server) = spawn_server(move |_stream| {
            let _ = release_rx.recv();
        });
        let mut cfg = tcp_config(port);
        cfg.response_timeout_ms = 100;
        let mut transport = StdTcpTransport::new();
        transport.connect(&ModbusConfig::Tcp(cfg)).unwrap();
        let result = transport.recv();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::Timeout);
        drop(release_tx);
        server.join().unwrap();
    }

    #[test]
    fn test_is_connected() {
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let (port, server) = spawn_server(move |_stream| {
            let _ = release_rx.recv();
        });
        let mut transport = StdTcpTransport::new();
        assert!(!transport.is_connected());
        transport
            .connect(&ModbusConfig::Tcp(tcp_config(port)))
            .unwrap();
        assert!(transport.is_connected());
        transport.disconnect().unwrap();
        assert!(!transport.is_connected());
        drop(release_tx);
        server.join().unwrap();
    }

    #[test]
    fn test_map_io_error() {
        let cases = [
            (io::ErrorKind::ConnectionRefused, TransportError::ConnectionFailed),
            (io::ErrorKind::NotFound, TransportError::ConnectionFailed),
            (io::ErrorKind::BrokenPipe, TransportError::ConnectionClosed),
            (io::ErrorKind::ConnectionReset, TransportError::ConnectionClosed),
            (io::ErrorKind::UnexpectedEof, TransportError::ConnectionClosed),
            (io::ErrorKind::WouldBlock, TransportError::Timeout),
            (io::ErrorKind::TimedOut, TransportError::Timeout),
            (io::ErrorKind::PermissionDenied, TransportError::IoError),
        ];
        for (kind, expected) in cases {
            assert_eq!(
                StdTcpTransport::map_io_error(io::Error::new(kind, "test")),
                expected,
                "unexpected mapping for {:?}",
                kind
            );
        }
    }

    #[test]
    fn test_connect_with_custom_timeout() {
        let (port, server) = spawn_server(|_| {});
        let mut cfg = tcp_config(port);
        cfg.connection_timeout_ms = 500;
        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&ModbusConfig::Tcp(cfg));
        assert!(result.is_ok());
        assert!(transport.is_connected());
        server.join().unwrap();
    }

    #[test]
    fn test_connect_with_no_timeout() {
        // NOTE(review): despite its name this uses a finite 500 ms timeout, which
        // makes it the same as `test_connect_with_custom_timeout`.
        // `TcpStream::connect_timeout` rejects a zero duration, so testing a real
        // "no timeout" config (0 ms) requires `connect` to special-case zero.
        let (port, server) = spawn_server(|_| {});
        let mut cfg = tcp_config(port);
        cfg.connection_timeout_ms = 500;
        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&ModbusConfig::Tcp(cfg));
        assert!(result.is_ok());
        assert!(transport.is_connected());
        server.join().unwrap();
    }

    #[test]
    fn test_send_failure_connection_reset() {
        let test_data = [0x01, 0x02, 0x03, 0x04];
        // The server closes the connection as soon as it accepts it.
        let (port, server) = spawn_server(|_| {});
        let mut transport = connected_transport(port);
        assert!(transport.is_connected());
        assert_eq!(
            wait_for_recv_error(&mut transport),
            Some(TransportError::ConnectionClosed)
        );
        assert!(!transport.is_connected());
        let result = transport.send(&test_data);
        assert_eq!(result, Err(TransportError::ConnectionClosed));
        server.join().unwrap();
    }

    #[test]
    fn test_connect_success_single_addr() {
        let (port, server) = spawn_server(|_| {});
        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&ModbusConfig::Tcp(tcp_config(port)));
        assert!(
            result.is_ok(),
            "Connection should succeed with a single address"
        );
        assert!(transport.is_connected());
        server.join().unwrap();
    }
}