1use std::io::{self, Read, Write};
2use std::net::{TcpStream, ToSocketAddrs};
3use std::time::Duration;
4
5use heapless::Vec;
6use mbus_core::data_unit::common::MAX_ADU_FRAME_LEN;
7use mbus_core::transport::{ModbusConfig, Transport, TransportError, TransportType};
8
/// Forwards its arguments to `log::error!` when the `logging` feature is enabled.
#[cfg(feature = "logging")]
macro_rules! transport_log_error {
    ($($fmt_args:tt)*) => {
        log::error!($($fmt_args)*)
    };
}

/// Logging-free variant: the format string and its arguments are still
/// type-checked through `format_args!`, but nothing is emitted.
#[cfg(not(feature = "logging"))]
macro_rules! transport_log_error {
    ($($fmt_args:tt)*) => {{
        let _ = core::format_args!($($fmt_args)*);
    }};
}
22
/// Forwards its arguments to `log::warn!` when the `logging` feature is enabled.
#[cfg(feature = "logging")]
macro_rules! transport_log_warn {
    ($($fmt_args:tt)*) => {
        log::warn!($($fmt_args)*)
    };
}

/// Logging-free variant: the format string and its arguments are still
/// type-checked through `format_args!`, but nothing is emitted.
#[cfg(not(feature = "logging"))]
macro_rules! transport_log_warn {
    ($($fmt_args:tt)*) => {{
        let _ = core::format_args!($($fmt_args)*);
    }};
}
36
/// Forwards its arguments to `log::debug!` when the `logging` feature is enabled.
#[cfg(feature = "logging")]
macro_rules! transport_log_debug {
    ($($fmt_args:tt)*) => {
        log::debug!($($fmt_args)*)
    };
}

/// Logging-free variant: the format string and its arguments are still
/// type-checked through `format_args!`, but nothing is emitted.
#[cfg(not(feature = "logging"))]
macro_rules! transport_log_debug {
    ($($fmt_args:tt)*) => {{
        let _ = core::format_args!($($fmt_args)*);
    }};
}
50
/// Modbus TCP transport backed by a `std::net::TcpStream`.
///
/// The transport starts out disconnected; the `Default` value holds no
/// stream at all.
#[derive(Default, Debug)]
pub struct StdTcpTransport {
    /// The open socket, or `None` while no connection is held.
    stream: Option<TcpStream>,
}
58
59impl StdTcpTransport {
60 pub fn new() -> Self {
70 Self { stream: None }
71 }
72
73 fn map_io_error(err: io::Error) -> TransportError {
77 match err.kind() {
78 io::ErrorKind::ConnectionRefused | io::ErrorKind::NotFound => {
79 TransportError::ConnectionFailed
80 }
81 io::ErrorKind::BrokenPipe
82 | io::ErrorKind::ConnectionReset
83 | io::ErrorKind::UnexpectedEof => TransportError::ConnectionClosed,
84 io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut => TransportError::Timeout,
85 _ => TransportError::IoError,
86 }
87 }
88}
89
90impl Transport for StdTcpTransport {
91 type Error = TransportError;
92
93 fn connect(&mut self, config: &ModbusConfig) -> Result<(), Self::Error> {
102 let config = match config {
103 ModbusConfig::Tcp(c) => c,
104 _ => return Err(TransportError::Unexpected),
105 };
106
107 let connection_timeout = Duration::from_millis(config.connection_timeout_ms as u64);
108 let response_timeout = Duration::from_millis(config.response_timeout_ms as u64);
109
110 let mut addrs_iter = (config.host.as_str(), config.port)
113 .to_socket_addrs()
114 .map_err(|e| {
115 transport_log_error!("DNS resolution failed: {:?}", e);
116 TransportError::ConnectionFailed
117 })?;
118
119 let addr = addrs_iter.next().ok_or_else(|| {
121 transport_log_error!("No valid address found for host:port combination.");
122 TransportError::ConnectionFailed
123 })?;
124
125 transport_log_debug!("Trying address: {:?}", addr);
126
127 match TcpStream::connect_timeout(&addr, connection_timeout) {
128 Ok(stream) => {
129 stream
132 .set_read_timeout(Some(response_timeout))
133 .unwrap_or_else(|e| transport_log_warn!("Failed to set read timeout: {:?}", e));
134 stream
135 .set_write_timeout(Some(response_timeout))
136 .unwrap_or_else(|e| {
137 transport_log_warn!("Failed to set write timeout: {:?}", e)
138 });
139 stream
140 .set_nodelay(true)
141 .unwrap_or_else(|e| transport_log_warn!("Failed to set no-delay: {:?}", e));
142
143 self.stream = Some(stream); Ok(()) }
146 Err(e) => {
147 transport_log_error!("Connect failed: {:?}", e);
148 Err(TransportError::ConnectionFailed) }
150 }
151 }
152
153 fn disconnect(&mut self) -> Result<(), Self::Error> {
157 if let Some(stream) = self.stream.take() {
160 drop(stream);
161 }
162 Ok(())
163 }
164
165 fn send(&mut self, adu: &[u8]) -> Result<(), Self::Error> {
173 let stream = self
174 .stream
175 .as_mut()
176 .ok_or(TransportError::ConnectionClosed)?;
177
178 let result = stream.write_all(adu).and_then(|()| stream.flush());
179
180 if let Err(err) = result {
181 let transport_error = Self::map_io_error(err);
182 if transport_error == TransportError::ConnectionClosed {
183 self.stream = None;
184 }
185 return Err(transport_error);
186 }
187
188 Ok(())
189 }
190
191 fn recv(&mut self) -> Result<Vec<u8, MAX_ADU_FRAME_LEN>, Self::Error> {
201 let stream = self
202 .stream
203 .as_mut()
204 .ok_or(TransportError::ConnectionClosed)?;
205
206 let _ = stream.set_nonblocking(true);
209
210 let mut temp_buf = [0u8; MAX_ADU_FRAME_LEN];
211 let read_result = stream.read(&mut temp_buf);
212
213 let _ = stream.set_nonblocking(false);
216
217 match read_result {
218 Ok(0) => {
219 self.stream = None;
221 Err(TransportError::ConnectionClosed)
222 }
223 Ok(n) => {
224 let mut buffer = Vec::new();
225 if buffer.extend_from_slice(&temp_buf[..n]).is_err() {
227 return Err(TransportError::BufferTooSmall);
228 }
229 Ok(buffer)
230 }
231 Err(e) => {
232 let err = Self::map_io_error(e);
233 if err == TransportError::ConnectionClosed {
234 self.stream = None;
235 }
236 Err(err)
238 }
239 }
240 }
241
242 fn is_connected(&self) -> bool {
246 self.stream.is_some()
247 }
248
249 fn transport_type(&self) -> TransportType {
251 TransportType::StdTcp
252 }
253}
254
#[cfg(test)]
impl StdTcpTransport {
    /// Test-only accessor handing out mutable access to the underlying
    /// stream, or `None` while no stream is held.
    pub fn stream_mut(&mut self) -> Option<&mut TcpStream> {
        let Self { stream } = self;
        stream.as_mut()
    }
}
261
#[cfg(test)]
mod tests {
    // The tests live in a child module of the file that defines the transport,
    // so `super::` is sufficient and does not depend on the file's module name.
    use super::StdTcpTransport;
    use mbus_core::transport::{ModbusConfig, ModbusTcpConfig, Transport, TransportError};
    use std::io::{self, Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// Host used by every loopback test server.
    const LOOPBACK_HOST: &str = "127.0.0.1";
    /// Upper bound on `recv` polls before a test stops waiting.
    const MAX_POLLS: usize = 50;
    /// Pause between polls while `recv` keeps reporting `Timeout`.
    const POLL_INTERVAL: Duration = Duration::from_millis(10);

    /// Binds a listener on an OS-assigned loopback port.
    fn create_test_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").expect("Failed to bind to an available port")
    }

    /// Returns the port the listener was bound to.
    fn listener_port(listener: &TcpListener) -> u16 {
        listener.local_addr().unwrap().port()
    }

    /// Builds a TCP configuration that targets the loopback host on `port`.
    fn loopback_config(port: u16) -> ModbusTcpConfig {
        ModbusTcpConfig::new(LOOPBACK_HOST, port).unwrap()
    }

    /// Creates a transport and connects it with `config`, panicking on failure.
    fn connected_transport(config: ModbusTcpConfig) -> StdTcpTransport {
        let mut transport = StdTcpTransport::new();
        transport.connect(&ModbusConfig::Tcp(config)).unwrap();
        transport
    }

    /// Spawns a server thread that signals readiness, accepts one connection
    /// and hands the accepted stream to `handler`.
    ///
    /// Returns only after the readiness signal has been received, so the
    /// caller can connect right away. The stream is dropped when `handler`
    /// returns.
    fn spawn_server<F>(listener: TcpListener, handler: F) -> thread::JoinHandle<()>
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let (ready_tx, ready_rx) = mpsc::channel();
        let handle = thread::spawn(move || {
            ready_tx
                .send(())
                .expect("Failed to send server ready signal");
            let (stream, _) = listener.accept().unwrap();
            handler(stream);
        });
        ready_rx
            .recv()
            .expect("Failed to receive server ready signal");
        handle
    }

    /// Spawns a server that accepts one connection and keeps it open until
    /// the returned sender is dropped (or used).
    ///
    /// This replaces fixed sleeps in the server thread: the test decides when
    /// the connection may go away, and a panicking test releases the server
    /// automatically because the sender is dropped during unwinding.
    fn spawn_holding_server(listener: TcpListener) -> (mpsc::Sender<()>, thread::JoinHandle<()>) {
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let handle = spawn_server(listener, move |_stream| {
            let _ = release_rx.recv();
        });
        (release_tx, handle)
    }

    /// Polls `recv` until it fails with something other than `Timeout`.
    ///
    /// Received data is discarded. Returns `None` when the poll budget is
    /// exhausted without such a failure.
    fn poll_until_failure(transport: &mut StdTcpTransport) -> Option<TransportError> {
        for _ in 0..=MAX_POLLS {
            match transport.recv() {
                Ok(_) => {}
                Err(TransportError::Timeout) => thread::sleep(POLL_INTERVAL),
                Err(other) => return Some(other),
            }
        }
        None
    }

    /// Connects with `config` against a server that accepts and immediately
    /// drops one connection, asserting that the connect succeeds.
    fn assert_connect_succeeds(listener: TcpListener, config: ModbusTcpConfig, message: &str) {
        let server = spawn_server(listener, |_stream| {});

        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&ModbusConfig::Tcp(config));
        assert!(result.is_ok(), "{}", message);
        assert!(transport.is_connected());

        server.join().unwrap();
    }

    #[test]
    fn test_new_std_tcp_transport() {
        let transport = StdTcpTransport::new();
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_connect_success() {
        let listener = create_test_listener();
        let config = loopback_config(listener_port(&listener));
        assert_connect_succeeds(listener, config, "Connection should succeed");
    }

    #[test]
    fn test_connect_failure_invalid_addr() {
        let mut transport = StdTcpTransport::new();
        let config = ModbusConfig::Tcp(ModbusTcpConfig::new("invalid-address", 502).unwrap());

        let result = transport.connect(&config);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionFailed);
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_connect_failure_connection_refused() {
        // Bind to learn a free port, then release it so nothing is listening.
        let listener = create_test_listener();
        let port = listener_port(&listener);
        drop(listener);

        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&ModbusConfig::Tcp(loopback_config(port)));
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionFailed);
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_disconnect() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        let server = spawn_server(listener, |_stream| {});

        let mut transport = connected_transport(loopback_config(port));
        assert!(transport.is_connected());

        let result = transport.disconnect();
        assert!(result.is_ok());
        assert!(!transport.is_connected());

        server.join().unwrap();
    }

    #[test]
    fn test_send_success() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        let test_data: [u8; 4] = [0x01, 0x02, 0x03, 0x04];

        let server = spawn_server(listener, move |mut stream| {
            let mut buf = [0u8; 4];
            stream.read_exact(&mut buf).unwrap();
            assert_eq!(buf, test_data);
        });

        let mut transport = connected_transport(loopback_config(port));
        let result = transport.send(&test_data);
        assert!(result.is_ok());

        server.join().unwrap();
    }

    #[test]
    fn test_send_failure_not_connected() {
        let mut transport = StdTcpTransport::new();
        let test_data: [u8; 2] = [0x01, 0x02];

        let result = transport.send(&test_data);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionClosed);
    }

    #[test]
    fn test_recv_success_full_adu() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        let adu_to_send: [u8; 9] = [0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x01, 0x03, 0x00];

        let server = spawn_server(listener, move |mut stream| {
            stream.write_all(&adu_to_send).unwrap();
        });

        let mut transport = connected_transport(loopback_config(port));

        // `recv` may deliver the frame in pieces, so accumulate until complete.
        let mut combined_adu: std::vec::Vec<u8> = std::vec::Vec::new();
        for _ in 0..MAX_POLLS {
            match transport.recv() {
                Ok(bytes) => {
                    combined_adu.extend_from_slice(&bytes);
                    if combined_adu.len() == adu_to_send.len() {
                        break;
                    }
                }
                Err(TransportError::Timeout) => thread::sleep(POLL_INTERVAL),
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        }
        assert_eq!(combined_adu.as_slice(), adu_to_send);

        server.join().unwrap();
    }

    #[test]
    fn test_recv_failure_not_connected() {
        let mut transport = StdTcpTransport::new();

        let result = transport.recv();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionClosed);
    }

    #[test]
    fn test_recv_failure_connection_closed_prematurely_header() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        let partial_adu: [u8; 3] = [0x00, 0x01, 0x00];

        // The server sends a truncated header and then closes the connection.
        let server = spawn_server(listener, move |mut stream| {
            stream.write_all(&partial_adu).unwrap();
        });

        let mut transport = connected_transport(loopback_config(port));
        assert_eq!(
            poll_until_failure(&mut transport),
            Some(TransportError::ConnectionClosed)
        );

        server.join().unwrap();
    }

    #[test]
    fn test_recv_failure_connection_closed_prematurely_pdu() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        let partial_adu: [u8; 8] = [0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x01, 0x03];

        // The header announces more PDU bytes than the server actually sends.
        let server = spawn_server(listener, move |mut stream| {
            stream.write_all(&partial_adu).unwrap();
        });

        let mut transport = connected_transport(loopback_config(port));
        assert_eq!(
            poll_until_failure(&mut transport),
            Some(TransportError::ConnectionClosed)
        );

        server.join().unwrap();
    }

    #[test]
    fn test_recv_timeout() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        // The server keeps the connection open but never sends anything.
        let (release_server, server) = spawn_holding_server(listener);

        let mut tcp_config = loopback_config(port);
        tcp_config.response_timeout_ms = 100;
        let mut transport = connected_transport(tcp_config);

        let result = transport.recv();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::Timeout);

        drop(release_server);
        server.join().unwrap();
    }

    #[test]
    fn test_is_connected() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        let (release_server, server) = spawn_holding_server(listener);

        let mut transport = StdTcpTransport::new();
        assert!(!transport.is_connected());

        transport
            .connect(&ModbusConfig::Tcp(loopback_config(port)))
            .unwrap();
        assert!(transport.is_connected());

        transport.disconnect().unwrap();
        assert!(!transport.is_connected());

        drop(release_server);
        server.join().unwrap();
    }

    #[test]
    fn test_map_io_error() {
        let cases = [
            (
                io::ErrorKind::ConnectionRefused,
                TransportError::ConnectionFailed,
            ),
            (io::ErrorKind::NotFound, TransportError::ConnectionFailed),
            (io::ErrorKind::BrokenPipe, TransportError::ConnectionClosed),
            (
                io::ErrorKind::ConnectionReset,
                TransportError::ConnectionClosed,
            ),
            (
                io::ErrorKind::UnexpectedEof,
                TransportError::ConnectionClosed,
            ),
            (io::ErrorKind::WouldBlock, TransportError::Timeout),
            (io::ErrorKind::TimedOut, TransportError::Timeout),
            (io::ErrorKind::PermissionDenied, TransportError::IoError),
        ];

        for (kind, expected) in cases.iter() {
            let mapped = StdTcpTransport::map_io_error(io::Error::new(*kind, "test"));
            assert_eq!(&mapped, expected, "unexpected mapping for {:?}", kind);
        }
    }

    #[test]
    fn test_connect_with_custom_timeout() {
        let listener = create_test_listener();
        let mut tcp_config = loopback_config(listener_port(&listener));
        tcp_config.connection_timeout_ms = 500;

        assert_connect_succeeds(listener, tcp_config, "Connection should succeed");
    }

    // NOTE(review): despite its name this test configures the same 500 ms
    // connection timeout as `test_connect_with_custom_timeout`; it is kept
    // as-is until the intended "no timeout" configuration is clarified.
    #[test]
    fn test_connect_with_no_timeout() {
        let listener = create_test_listener();
        let mut tcp_config = loopback_config(listener_port(&listener));
        tcp_config.connection_timeout_ms = 500;

        assert_connect_succeeds(listener, tcp_config, "Connection should succeed");
    }

    #[test]
    fn test_send_failure_connection_reset() {
        let listener = create_test_listener();
        let port = listener_port(&listener);
        let test_data: [u8; 4] = [0x01, 0x02, 0x03, 0x04];

        // The server closes the connection as soon as it has been accepted.
        let server = spawn_server(listener, |stream| drop(stream));

        let mut transport = connected_transport(loopback_config(port));
        assert!(transport.is_connected());

        // Observing the close through `recv` drops the stream ...
        assert_eq!(
            poll_until_failure(&mut transport),
            Some(TransportError::ConnectionClosed)
        );
        assert!(!transport.is_connected());

        // ... so a subsequent send fails without touching the network.
        let result = transport.send(&test_data);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionClosed);

        server.join().unwrap();
    }

    #[test]
    fn test_connect_success_single_addr() {
        let listener = create_test_listener();
        let config = loopback_config(listener_port(&listener));

        assert_connect_succeeds(
            listener,
            config,
            "Connection should succeed with a single address",
        );
    }
}