1use std::io::{self, Read, Write};
2use std::net::{TcpStream, ToSocketAddrs};
3use std::time::Duration;
4
5use heapless::Vec;
6use mbus_core::data_unit::common::MAX_ADU_FRAME_LEN;
7use mbus_core::transport::{ModbusConfig, Transport, TransportError, TransportType};
8
// Logging shims used throughout this transport.
//
// With the `logging` feature enabled, each macro forwards its arguments
// unchanged to the matching `log` crate macro. Without the feature, the
// arguments are handed to `core::format_args!` and the result is discarded:
// the format string is still checked at compile time and the argument
// expressions are still referenced (so values used only in log messages do
// not become "unused"), but nothing is emitted.

// Error-level log: forwards to `log::error!` when `logging` is enabled.
#[cfg(feature = "logging")]
macro_rules! transport_log_error {
    ($($arg:tt)*) => {
        log::error!($($arg)*)
    };
}

// No-op error log: type-checks the format arguments, emits nothing.
#[cfg(not(feature = "logging"))]
macro_rules! transport_log_error {
    ($($arg:tt)*) => {{
        let _ = core::format_args!($($arg)*);
    }};
}

// Warning-level log: forwards to `log::warn!` when `logging` is enabled.
#[cfg(feature = "logging")]
macro_rules! transport_log_warn {
    ($($arg:tt)*) => {
        log::warn!($($arg)*)
    };
}

// No-op warning log: type-checks the format arguments, emits nothing.
#[cfg(not(feature = "logging"))]
macro_rules! transport_log_warn {
    ($($arg:tt)*) => {{
        let _ = core::format_args!($($arg)*);
    }};
}

// Debug-level log: forwards to `log::debug!` when `logging` is enabled.
#[cfg(feature = "logging")]
macro_rules! transport_log_debug {
    ($($arg:tt)*) => {
        log::debug!($($arg)*)
    };
}

// No-op debug log: type-checks the format arguments, emits nothing.
#[cfg(not(feature = "logging"))]
macro_rules! transport_log_debug {
    ($($arg:tt)*) => {{
        let _ = core::format_args!($($arg)*);
    }};
}
50
/// Modbus transport over a standard-library TCP socket
/// (`TransportType::StdTcp`).
///
/// Holds at most one connection at a time; while disconnected `stream` is
/// `None`.
#[derive(Debug, Default)]
pub struct StdTcpTransport {
    /// The open connection, or `None` when not connected (also the
    /// `Default` state).
    stream: Option<TcpStream>,
}
58
59impl StdTcpTransport {
60 pub fn new() -> Self {
70 Self { stream: None }
71 }
72
73 fn map_io_error(err: io::Error) -> TransportError {
77 match err.kind() {
78 io::ErrorKind::ConnectionRefused | io::ErrorKind::NotFound => {
79 TransportError::ConnectionFailed
80 }
81 io::ErrorKind::BrokenPipe
82 | io::ErrorKind::ConnectionReset
83 | io::ErrorKind::UnexpectedEof => TransportError::ConnectionClosed,
84 io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut => TransportError::Timeout,
85 _ => TransportError::IoError,
86 }
87 }
88}
89
90impl Transport for StdTcpTransport {
91 type Error = TransportError;
92 const TRANSPORT_TYPE: TransportType = TransportType::StdTcp;
93
94 fn connect(&mut self, config: &ModbusConfig) -> Result<(), Self::Error> {
103 let config = match config {
104 ModbusConfig::Tcp(c) => c,
105 _ => return Err(TransportError::Unexpected),
106 };
107
108 let connection_timeout = Duration::from_millis(config.connection_timeout_ms as u64);
109 let response_timeout = Duration::from_millis(config.response_timeout_ms as u64);
110
111 let mut addrs_iter = (config.host.as_str(), config.port)
114 .to_socket_addrs()
115 .map_err(|e| {
116 transport_log_error!("DNS resolution failed: {:?}", e);
117 TransportError::ConnectionFailed
118 })?;
119
120 let addr = addrs_iter.next().ok_or_else(|| {
122 transport_log_error!("No valid address found for host:port combination.");
123 TransportError::ConnectionFailed
124 })?;
125
126 transport_log_debug!("Trying address: {:?}", addr);
127
128 match TcpStream::connect_timeout(&addr, connection_timeout) {
129 Ok(stream) => {
130 stream
133 .set_read_timeout(Some(response_timeout))
134 .unwrap_or_else(|e| transport_log_warn!("Failed to set read timeout: {:?}", e));
135 stream
136 .set_write_timeout(Some(response_timeout))
137 .unwrap_or_else(|e| {
138 transport_log_warn!("Failed to set write timeout: {:?}", e)
139 });
140 stream
141 .set_nodelay(true)
142 .unwrap_or_else(|e| transport_log_warn!("Failed to set no-delay: {:?}", e));
143
144 self.stream = Some(stream); Ok(()) }
147 Err(e) => {
148 transport_log_error!("Connect failed: {:?}", e);
149 Err(TransportError::ConnectionFailed) }
151 }
152 }
153
154 fn disconnect(&mut self) -> Result<(), Self::Error> {
158 if let Some(stream) = self.stream.take() {
161 drop(stream);
162 }
163 Ok(())
164 }
165
166 fn send(&mut self, adu: &[u8]) -> Result<(), Self::Error> {
174 let stream = self
175 .stream
176 .as_mut()
177 .ok_or(TransportError::ConnectionClosed)?;
178
179 let result = stream.write_all(adu).and_then(|()| stream.flush());
180
181 if let Err(err) = result {
182 let transport_error = Self::map_io_error(err);
183 if transport_error == TransportError::ConnectionClosed {
184 self.stream = None;
185 }
186 return Err(transport_error);
187 }
188
189 Ok(())
190 }
191
192 fn recv(&mut self) -> Result<Vec<u8, MAX_ADU_FRAME_LEN>, Self::Error> {
202 let stream = self
203 .stream
204 .as_mut()
205 .ok_or(TransportError::ConnectionClosed)?;
206
207 let _ = stream.set_nonblocking(true);
210
211 let mut temp_buf = [0u8; MAX_ADU_FRAME_LEN];
212 let read_result = stream.read(&mut temp_buf);
213
214 let _ = stream.set_nonblocking(false);
217
218 match read_result {
219 Ok(0) => {
220 self.stream = None;
222 Err(TransportError::ConnectionClosed)
223 }
224 Ok(n) => {
225 let mut buffer = Vec::new();
226 if buffer.extend_from_slice(&temp_buf[..n]).is_err() {
228 return Err(TransportError::BufferTooSmall);
229 }
230 Ok(buffer)
231 }
232 Err(e) => {
233 let err = Self::map_io_error(e);
234 if err == TransportError::ConnectionClosed {
235 self.stream = None;
236 }
237 Err(err)
239 }
240 }
241 }
242
243 fn is_connected(&self) -> bool {
247 self.stream.is_some()
248 }
249}
250
#[cfg(test)]
impl StdTcpTransport {
    /// Test-only mutable access to the underlying socket; `None` while the
    /// transport is not connected.
    pub fn stream_mut(&mut self) -> Option<&mut TcpStream> {
        Option::as_mut(&mut self.stream)
    }
}
257
#[cfg(test)]
mod tests {
    use super::StdTcpTransport;
    use mbus_core::transport::{ModbusConfig, ModbusTcpConfig, Transport, TransportError};
    use std::io::{self, Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// Binds a listener to an OS-assigned free port on 127.0.0.1.
    fn create_test_listener() -> TcpListener {
        TcpListener::bind("127.0.0.1:0").expect("Failed to bind to an available port")
    }

    /// Starts a one-shot loopback server and returns its port and thread handle.
    ///
    /// The server accepts a single connection and passes it to `handler`. The
    /// listener is already bound and listening when this returns, so a client
    /// may connect immediately: the OS queues the connection until `accept`
    /// runs. A panic inside `handler` surfaces through `JoinHandle::join`.
    fn spawn_server<F>(handler: F) -> (u16, thread::JoinHandle<()>)
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = create_test_listener();
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            handler(stream);
        });
        (port, handle)
    }

    /// Default TCP configuration targeting `127.0.0.1:port`.
    fn loopback_config(port: u16) -> ModbusConfig {
        ModbusConfig::Tcp(ModbusTcpConfig::new("127.0.0.1", port).unwrap())
    }

    /// Calls `recv` up to 50 times, discarding received data and sleeping
    /// 10 ms after each `Timeout`. Returns the first non-timeout error, or
    /// `None` if none was observed.
    fn poll_until_error(transport: &mut StdTcpTransport) -> Option<TransportError> {
        for _ in 0..50 {
            match transport.recv() {
                Ok(_) => {}
                Err(TransportError::Timeout) => thread::sleep(Duration::from_millis(10)),
                Err(e) => return Some(e),
            }
        }
        None
    }

    #[test]
    fn test_new_std_tcp_transport() {
        let transport = StdTcpTransport::new();
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_connect_success() {
        let (port, server) = spawn_server(drop);

        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&loopback_config(port));
        assert!(result.is_ok());
        assert!(transport.is_connected());

        server.join().unwrap();
    }

    #[test]
    fn test_connect_failure_invalid_addr() {
        let mut transport = StdTcpTransport::new();
        let config = ModbusConfig::Tcp(ModbusTcpConfig::new("invalid-address", 502).unwrap());
        let result = transport.connect(&config);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionFailed);
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_connect_failure_connection_refused() {
        // Reserve a port, then release it so nothing is listening there.
        let listener = create_test_listener();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&loopback_config(port));
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionFailed);
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_disconnect() {
        let (port, server) = spawn_server(drop);

        let mut transport = StdTcpTransport::new();
        transport.connect(&loopback_config(port)).unwrap();
        assert!(transport.is_connected());

        let result = transport.disconnect();
        assert!(result.is_ok());
        assert!(!transport.is_connected());

        server.join().unwrap();
    }

    #[test]
    fn test_send_success() {
        let test_data: [u8; 4] = [0x01, 0x02, 0x03, 0x04];
        let (port, server) = spawn_server(move |mut stream| {
            let mut buf = [0u8; 4];
            stream.read_exact(&mut buf).unwrap();
            assert_eq!(buf, test_data);
        });

        let mut transport = StdTcpTransport::new();
        transport.connect(&loopback_config(port)).unwrap();

        let result = transport.send(&test_data);
        assert!(result.is_ok());

        server.join().unwrap();
    }

    #[test]
    fn test_send_failure_not_connected() {
        let mut transport = StdTcpTransport::new();
        let result = transport.send(&[0x01, 0x02]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionClosed);
    }

    #[test]
    fn test_recv_success_full_adu() {
        let adu_to_send: [u8; 9] = [0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x01, 0x03, 0x00];
        let (port, server) = spawn_server(move |mut stream| {
            stream.write_all(&adu_to_send).unwrap();
        });

        let mut transport = StdTcpTransport::new();
        transport.connect(&loopback_config(port)).unwrap();

        // `recv` returns whatever is available right now, so accumulate until
        // the whole ADU has arrived.
        let mut combined_adu = Vec::new();
        for _ in 0..50 {
            match transport.recv() {
                Ok(bytes) => {
                    combined_adu.extend_from_slice(&bytes);
                    if combined_adu.len() == adu_to_send.len() {
                        break;
                    }
                }
                Err(TransportError::Timeout) => thread::sleep(Duration::from_millis(10)),
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        }
        assert_eq!(combined_adu.as_slice(), adu_to_send);

        server.join().unwrap();
    }

    #[test]
    fn test_recv_failure_not_connected() {
        let mut transport = StdTcpTransport::new();
        let result = transport.recv();
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionClosed);
    }

    #[test]
    fn test_recv_failure_connection_closed_prematurely_header() {
        let partial_adu: [u8; 3] = [0x00, 0x01, 0x00];
        let (port, server) = spawn_server(move |mut stream| {
            stream.write_all(&partial_adu).unwrap();
        });

        let mut transport = StdTcpTransport::new();
        transport.connect(&loopback_config(port)).unwrap();

        assert_eq!(
            poll_until_error(&mut transport),
            Some(TransportError::ConnectionClosed)
        );

        server.join().unwrap();
    }

    #[test]
    fn test_recv_failure_connection_closed_prematurely_pdu() {
        let partial_adu: [u8; 8] = [0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x01, 0x03];
        let (port, server) = spawn_server(move |mut stream| {
            stream.write_all(&partial_adu).unwrap();
        });

        let mut transport = StdTcpTransport::new();
        transport.connect(&loopback_config(port)).unwrap();

        assert_eq!(
            poll_until_error(&mut transport),
            Some(TransportError::ConnectionClosed)
        );

        server.join().unwrap();
    }

    #[test]
    fn test_recv_timeout() {
        // The server keeps the connection open (and silent) until released,
        // instead of sleeping for a fixed, long period.
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let (port, server) = spawn_server(move |_stream| {
            let _ = release_rx.recv();
        });

        let mut transport = StdTcpTransport::new();
        let mut tcp_config = ModbusTcpConfig::new("127.0.0.1", port).unwrap();
        tcp_config.response_timeout_ms = 100;
        transport.connect(&ModbusConfig::Tcp(tcp_config)).unwrap();

        let result = transport.recv();
        drop(release_tx);

        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::Timeout);

        server.join().unwrap();
    }

    #[test]
    fn test_is_connected() {
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let (port, server) = spawn_server(move |_stream| {
            let _ = release_rx.recv();
        });

        let mut transport = StdTcpTransport::new();
        assert!(!transport.is_connected());

        transport.connect(&loopback_config(port)).unwrap();
        assert!(transport.is_connected());

        transport.disconnect().unwrap();
        assert!(!transport.is_connected());

        drop(release_tx);
        server.join().unwrap();
    }

    #[test]
    fn test_map_io_error() {
        let cases = [
            (io::ErrorKind::ConnectionRefused, TransportError::ConnectionFailed),
            (io::ErrorKind::NotFound, TransportError::ConnectionFailed),
            (io::ErrorKind::BrokenPipe, TransportError::ConnectionClosed),
            (io::ErrorKind::ConnectionReset, TransportError::ConnectionClosed),
            (io::ErrorKind::UnexpectedEof, TransportError::ConnectionClosed),
            (io::ErrorKind::WouldBlock, TransportError::Timeout),
            (io::ErrorKind::TimedOut, TransportError::Timeout),
            (io::ErrorKind::PermissionDenied, TransportError::IoError),
        ];
        for (kind, expected) in cases {
            assert_eq!(
                StdTcpTransport::map_io_error(io::Error::new(kind, "test")),
                expected,
                "unexpected mapping for {:?}",
                kind
            );
        }
    }

    #[test]
    fn test_connect_with_custom_timeout() {
        let (port, server) = spawn_server(drop);

        let mut transport = StdTcpTransport::new();
        let mut tcp_config = ModbusTcpConfig::new("127.0.0.1", port).unwrap();
        tcp_config.connection_timeout_ms = 500;
        let result = transport.connect(&ModbusConfig::Tcp(tcp_config));
        assert!(result.is_ok());
        assert!(transport.is_connected());

        server.join().unwrap();
    }

    #[test]
    fn test_connect_with_no_timeout() {
        // Connects with the timeout fields left at their defaults (no override).
        let (port, server) = spawn_server(drop);

        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&loopback_config(port));
        assert!(result.is_ok());
        assert!(transport.is_connected());

        server.join().unwrap();
    }

    #[test]
    fn test_send_failure_connection_reset() {
        let test_data: [u8; 4] = [0x01, 0x02, 0x03, 0x04];
        // The server closes the connection immediately after accepting it.
        let (port, server) = spawn_server(drop);

        let mut transport = StdTcpTransport::new();
        transport.connect(&loopback_config(port)).unwrap();
        assert!(transport.is_connected());

        // The peer's close is only observed by reading from the socket.
        assert_eq!(
            poll_until_error(&mut transport),
            Some(TransportError::ConnectionClosed)
        );
        assert!(!transport.is_connected());

        let result = transport.send(&test_data);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), TransportError::ConnectionClosed);

        server.join().unwrap();
    }

    #[test]
    fn test_connect_success_single_addr() {
        let (port, server) = spawn_server(drop);

        let mut transport = StdTcpTransport::new();
        let result = transport.connect(&loopback_config(port));
        assert!(
            result.is_ok(),
            "Connection should succeed with a single address"
        );
        assert!(transport.is_connected());

        server.join().unwrap();
    }
}