Skip to main content

rns_net/interface/
tcp.rs

1//! TCP client interface with HDLC framing.
2//!
3//! Matches Python `TCPClientInterface` from `TCPInterface.py`.
4
5use std::io::{self, Read, Write};
6use std::net::{TcpStream, ToSocketAddrs};
7use std::os::unix::io::AsRawFd;
8use std::thread;
9use std::time::Duration;
10
11use rns_core::transport::types::InterfaceId;
12
13use crate::event::{Event, EventSender};
14use crate::hdlc;
15use crate::interface::Writer;
16
17/// Configuration for a TCP client interface.
18#[derive(Debug, Clone)]
19pub struct TcpClientConfig {
20    pub name: String,
21    pub target_host: String,
22    pub target_port: u16,
23    pub interface_id: InterfaceId,
24    pub reconnect_wait: Duration,
25    pub max_reconnect_tries: Option<u32>,
26    pub connect_timeout: Duration,
27    /// Linux network interface to bind the socket to (e.g. "usb0").
28    pub device: Option<String>,
29}
30
31impl Default for TcpClientConfig {
32    fn default() -> Self {
33        TcpClientConfig {
34            name: String::new(),
35            target_host: "127.0.0.1".into(),
36            target_port: 4242,
37            interface_id: InterfaceId(0),
38            reconnect_wait: Duration::from_secs(5),
39            max_reconnect_tries: None,
40            connect_timeout: Duration::from_secs(5),
41            device: None,
42        }
43    }
44}
45
46/// Writer that sends HDLC-framed data over a TCP stream.
47struct TcpWriter {
48    stream: TcpStream,
49}
50
51impl Writer for TcpWriter {
52    fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
53        self.stream.write_all(&hdlc::frame(data))
54    }
55}
56
57/// Set TCP keepalive and timeout socket options (Linux).
58fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
59    let fd = stream.as_raw_fd();
60    unsafe {
61        // TCP_NODELAY = 1
62        let val: libc::c_int = 1;
63        if libc::setsockopt(
64            fd,
65            libc::IPPROTO_TCP,
66            libc::TCP_NODELAY,
67            &val as *const _ as *const libc::c_void,
68            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
69        ) != 0
70        {
71            return Err(io::Error::last_os_error());
72        }
73
74        // SO_KEEPALIVE = 1
75        if libc::setsockopt(
76            fd,
77            libc::SOL_SOCKET,
78            libc::SO_KEEPALIVE,
79            &val as *const _ as *const libc::c_void,
80            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
81        ) != 0
82        {
83            return Err(io::Error::last_os_error());
84        }
85
86        // Linux-specific keepalive tuning and user timeout
87        #[cfg(target_os = "linux")]
88        {
89            // TCP_KEEPIDLE = 5
90            let idle: libc::c_int = 5;
91            if libc::setsockopt(
92                fd,
93                libc::IPPROTO_TCP,
94                libc::TCP_KEEPIDLE,
95                &idle as *const _ as *const libc::c_void,
96                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
97            ) != 0
98            {
99                return Err(io::Error::last_os_error());
100            }
101
102            // TCP_KEEPINTVL = 2
103            let intvl: libc::c_int = 2;
104            if libc::setsockopt(
105                fd,
106                libc::IPPROTO_TCP,
107                libc::TCP_KEEPINTVL,
108                &intvl as *const _ as *const libc::c_void,
109                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
110            ) != 0
111            {
112                return Err(io::Error::last_os_error());
113            }
114
115            // TCP_KEEPCNT = 12
116            let cnt: libc::c_int = 12;
117            if libc::setsockopt(
118                fd,
119                libc::IPPROTO_TCP,
120                libc::TCP_KEEPCNT,
121                &cnt as *const _ as *const libc::c_void,
122                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
123            ) != 0
124            {
125                return Err(io::Error::last_os_error());
126            }
127
128            // TCP_USER_TIMEOUT = 24000 ms
129            let timeout: libc::c_int = 24_000;
130            if libc::setsockopt(
131                fd,
132                libc::IPPROTO_TCP,
133                libc::TCP_USER_TIMEOUT,
134                &timeout as *const _ as *const libc::c_void,
135                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
136            ) != 0
137            {
138                return Err(io::Error::last_os_error());
139            }
140        }
141    }
142    Ok(())
143}
144
145/// Try to connect to the target host:port with timeout.
146fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
147    let addr_str = format!("{}:{}", config.target_host, config.target_port);
148    let addr = addr_str
149        .to_socket_addrs()?
150        .next()
151        .ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
152
153    #[cfg(target_os = "linux")]
154    let stream = if let Some(ref device) = config.device {
155        connect_with_device(&addr, device, config.connect_timeout)?
156    } else {
157        TcpStream::connect_timeout(&addr, config.connect_timeout)?
158    };
159    #[cfg(not(target_os = "linux"))]
160    let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
161    set_socket_options(&stream)?;
162    Ok(stream)
163}
164
165/// Create a TCP socket, bind it to a network device, then connect with timeout.
166#[cfg(target_os = "linux")]
167fn connect_with_device(
168    addr: &std::net::SocketAddr,
169    device: &str,
170    timeout: Duration,
171) -> io::Result<TcpStream> {
172    use std::os::unix::io::{FromRawFd, RawFd};
173
174    let domain = if addr.is_ipv4() { libc::AF_INET } else { libc::AF_INET6 };
175    let fd: RawFd = unsafe { libc::socket(domain, libc::SOCK_STREAM, 0) };
176    if fd < 0 {
177        return Err(io::Error::last_os_error());
178    }
179
180    // Ensure the fd is closed on error paths
181    let stream = unsafe { TcpStream::from_raw_fd(fd) };
182
183    super::bind_to_device(stream.as_raw_fd(), device)?;
184
185    // Set non-blocking for connect-with-timeout
186    stream.set_nonblocking(true)?;
187
188    let (sockaddr, socklen) = socket_addr_to_raw(addr);
189    let ret = unsafe {
190        libc::connect(
191            stream.as_raw_fd(),
192            &sockaddr as *const libc::sockaddr_storage as *const libc::sockaddr,
193            socklen,
194        )
195    };
196
197    if ret != 0 {
198        let err = io::Error::last_os_error();
199        if err.raw_os_error() != Some(libc::EINPROGRESS) {
200            return Err(err);
201        }
202    }
203
204    // Poll for connect completion
205    let mut pollfd = libc::pollfd {
206        fd: stream.as_raw_fd(),
207        events: libc::POLLOUT,
208        revents: 0,
209    };
210    let timeout_ms = timeout.as_millis() as libc::c_int;
211    let poll_ret = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
212
213    if poll_ret == 0 {
214        return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
215    }
216    if poll_ret < 0 {
217        return Err(io::Error::last_os_error());
218    }
219
220    // Check SO_ERROR
221    let mut err_val: libc::c_int = 0;
222    let mut err_len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
223    let ret = unsafe {
224        libc::getsockopt(
225            stream.as_raw_fd(),
226            libc::SOL_SOCKET,
227            libc::SO_ERROR,
228            &mut err_val as *mut _ as *mut libc::c_void,
229            &mut err_len,
230        )
231    };
232    if ret != 0 {
233        return Err(io::Error::last_os_error());
234    }
235    if err_val != 0 {
236        return Err(io::Error::from_raw_os_error(err_val));
237    }
238
239    // Set back to blocking
240    stream.set_nonblocking(false)?;
241
242    Ok(stream)
243}
244
245/// Convert a `SocketAddr` to a raw `sockaddr_storage` for `libc::connect`.
246#[cfg(target_os = "linux")]
247fn socket_addr_to_raw(addr: &std::net::SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
248    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
249    match addr {
250        std::net::SocketAddr::V4(v4) => {
251            let sin: &mut libc::sockaddr_in = unsafe {
252                &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in)
253            };
254            sin.sin_family = libc::AF_INET as libc::sa_family_t;
255            sin.sin_port = v4.port().to_be();
256            sin.sin_addr = libc::in_addr {
257                s_addr: u32::from_ne_bytes(v4.ip().octets()),
258            };
259            (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
260        }
261        std::net::SocketAddr::V6(v6) => {
262            let sin6: &mut libc::sockaddr_in6 = unsafe {
263                &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6)
264            };
265            sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
266            sin6.sin6_port = v6.port().to_be();
267            sin6.sin6_addr = libc::in6_addr {
268                s6_addr: v6.ip().octets(),
269            };
270            sin6.sin6_flowinfo = v6.flowinfo();
271            sin6.sin6_scope_id = v6.scope_id();
272            (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
273        }
274    }
275}
276
277/// Connect and start the reader thread. Returns the writer for the driver.
278pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
279    let stream = try_connect(&config)?;
280    let reader_stream = stream.try_clone()?;
281    let writer_stream = stream.try_clone()?;
282
283    let id = config.interface_id;
284    // Initial connect: writer is None because it's returned directly to the caller
285    let _ = tx.send(Event::InterfaceUp(id, None, None));
286
287    // Spawn reader thread
288    let reader_config = config;
289    let reader_tx = tx;
290    thread::Builder::new()
291        .name(format!("tcp-reader-{}", id.0))
292        .spawn(move || {
293            reader_loop(reader_stream, reader_config, reader_tx);
294        })?;
295
296    Ok(Box::new(TcpWriter { stream: writer_stream }))
297}
298
299/// Reader thread: reads from socket, HDLC-decodes, sends frames to driver.
300/// On disconnect, attempts reconnection.
301fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
302    let id = config.interface_id;
303    let mut decoder = hdlc::Decoder::new();
304    let mut buf = [0u8; 4096];
305
306    loop {
307        match stream.read(&mut buf) {
308            Ok(0) => {
309                // Connection closed by peer
310                log::warn!("[{}] connection closed", config.name);
311                let _ = tx.send(Event::InterfaceDown(id));
312                match reconnect(&config, &tx) {
313                    Some(new_stream) => {
314                        stream = new_stream;
315                        decoder = hdlc::Decoder::new();
316                        continue;
317                    }
318                    None => {
319                        log::error!("[{}] reconnection failed, giving up", config.name);
320                        return;
321                    }
322                }
323            }
324            Ok(n) => {
325                for frame in decoder.feed(&buf[..n]) {
326                    if tx.send(Event::Frame { interface_id: id, data: frame }).is_err() {
327                        // Driver shut down
328                        return;
329                    }
330                }
331            }
332            Err(e) => {
333                log::warn!("[{}] read error: {}", config.name, e);
334                let _ = tx.send(Event::InterfaceDown(id));
335                match reconnect(&config, &tx) {
336                    Some(new_stream) => {
337                        stream = new_stream;
338                        decoder = hdlc::Decoder::new();
339                        continue;
340                    }
341                    None => {
342                        log::error!("[{}] reconnection failed, giving up", config.name);
343                        return;
344                    }
345                }
346            }
347        }
348    }
349}
350
351/// Attempt to reconnect with retry logic. Returns the new reader stream on success.
352/// Sends the new writer to the driver via InterfaceUp event.
353fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
354    let mut attempts = 0u32;
355    loop {
356        thread::sleep(config.reconnect_wait);
357        attempts += 1;
358
359        if let Some(max) = config.max_reconnect_tries {
360            if attempts > max {
361                let _ = tx.send(Event::InterfaceDown(config.interface_id));
362                return None;
363            }
364        }
365
366        log::info!(
367            "[{}] reconnect attempt {} ...",
368            config.name,
369            attempts
370        );
371
372        match try_connect(config) {
373            Ok(new_stream) => {
374                // Clone the stream: one for the reader, one for the writer
375                let writer_stream = match new_stream.try_clone() {
376                    Ok(s) => s,
377                    Err(e) => {
378                        log::warn!("[{}] failed to clone stream: {}", config.name, e);
379                        continue;
380                    }
381                };
382                log::info!("[{}] reconnected", config.name);
383                // Send new writer to the driver so it can replace the stale one
384                let new_writer: Box<dyn Writer> = Box::new(TcpWriter { stream: writer_stream });
385                let _ = tx.send(Event::InterfaceUp(config.interface_id, Some(new_writer), None));
386                return Some(new_stream);
387            }
388            Err(e) => {
389                log::warn!("[{}] reconnect failed: {}", config.name, e);
390            }
391        }
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Bind a loopback listener on an OS-assigned port and return it with
    /// that port.
    ///
    /// Binding port 0 directly avoids the race in "find a free port, drop the
    /// probe, re-bind": in that window another process can take the port.
    fn listen_on_free_port() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Client config targeting `127.0.0.1:port` with short timings so the
    /// reconnect tests finish quickly.
    fn make_config(port: u16) -> TcpClientConfig {
        TcpClientConfig {
            name: format!("test-tcp-{}", port),
            target_host: "127.0.0.1".into(),
            target_port: port,
            interface_id: InterfaceId(1),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(2),
            connect_timeout: Duration::from_secs(2),
            device: None,
        }
    }

    /// Read an integer socket option, asserting that `getsockopt` succeeds.
    fn get_int_sockopt(stream: &TcpStream, level: libc::c_int, name: libc::c_int) -> libc::c_int {
        let mut val: libc::c_int = 0;
        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: `val`/`len` are valid for writes and `len` matches `val`'s size.
        let ret = unsafe {
            libc::getsockopt(
                stream.as_raw_fd(),
                level,
                name,
                &mut val as *mut _ as *mut libc::c_void,
                &mut len,
            )
        };
        assert_eq!(ret, 0, "getsockopt failed: {}", io::Error::last_os_error());
        val
    }

    #[test]
    fn connect_to_listener() {
        let (listener, port) = listen_on_free_port();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        // Accept the connection
        let _server_stream = listener.accept().unwrap();

        // Should receive InterfaceUp event
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn receive_frame() {
        let (listener, port) = listen_on_free_port();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain the InterfaceUp event
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send an HDLC frame from server (>= 19 bytes payload)
        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        server_stream.write_all(&framed).unwrap();

        // Should receive Frame event
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(1));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn send_frame() {
        let (listener, port) = listen_on_free_port();
        let (tx, _rx) = mpsc::channel();

        let config = make_config(port);
        let mut writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();
        server_stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();

        // Send a frame via writer
        let payload: Vec<u8> = (0..24).collect();
        writer.send_frame(&payload).unwrap();

        // Read from server side
        let mut buf = [0u8; 256];
        let n = server_stream.read(&mut buf).unwrap();
        let expected = hdlc::frame(&payload);
        assert_eq!(&buf[..n], &expected[..]);
    }

    #[test]
    fn multiple_frames() {
        let (listener, port) = listen_on_free_port();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send multiple frames
        let payloads: Vec<Vec<u8>> = (0..3).map(|i| (0..24).map(|j| j + i * 50).collect()).collect();
        for p in &payloads {
            server_stream.write_all(&hdlc::frame(p)).unwrap();
        }

        // Should receive all frames
        for expected in &payloads {
            let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
            match event {
                Event::Frame { data, .. } => assert_eq!(&data, expected),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    #[test]
    fn split_across_reads() {
        let (listener, port) = listen_on_free_port();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send frame in two parts
        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        let mid = framed.len() / 2;

        server_stream.write_all(&framed[..mid]).unwrap();
        server_stream.flush().unwrap();
        thread::sleep(Duration::from_millis(50));
        server_stream.write_all(&framed[mid..]).unwrap();
        server_stream.flush().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { data, .. } => assert_eq!(data, payload),
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn reconnect_on_close() {
        let (listener, port) = listen_on_free_port();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        // Accept first connection and immediately close it
        let (server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        drop(server_stream);

        // Should get InterfaceDown
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));

        // Accept the reconnection
        let _server_stream2 = listener.accept().unwrap();

        // Should get InterfaceUp again
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn socket_options() {
        let (listener, port) = listen_on_free_port();

        let stream = try_connect(&make_config(port)).unwrap();
        let _server = listener.accept().unwrap();

        // Boolean options read back as non-zero (exactly 1 on Linux; some
        // BSD-derived systems return the option's flag bit instead).
        assert_ne!(
            get_int_sockopt(&stream, libc::IPPROTO_TCP, libc::TCP_NODELAY),
            0,
            "TCP_NODELAY should be enabled"
        );
        assert_ne!(
            get_int_sockopt(&stream, libc::SOL_SOCKET, libc::SO_KEEPALIVE),
            0,
            "SO_KEEPALIVE should be enabled"
        );
    }

    #[test]
    fn connect_timeout() {
        // Use a non-routable address to trigger timeout
        let config = TcpClientConfig {
            name: "timeout-test".into(),
            target_host: "192.0.2.1".into(), // TEST-NET, non-routable
            target_port: 12345,
            interface_id: InterfaceId(99),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(0),
            connect_timeout: Duration::from_millis(500),
            device: None,
        };

        let start_time = std::time::Instant::now();
        let result = try_connect(&config);
        let elapsed = start_time.elapsed();

        assert!(result.is_err());
        // Should timeout roughly around 500ms, definitely under 5s
        assert!(elapsed < Duration::from_secs(5));
    }
}