Skip to main content

rns_net/interface/
tcp.rs

1//! TCP client interface with HDLC framing.
2//!
3//! Matches Python `TCPClientInterface` from `TCPInterface.py`.
4
5use std::io::{self, Read, Write};
6use std::net::{TcpStream, ToSocketAddrs};
7use std::os::unix::io::AsRawFd;
8use std::thread;
9use std::time::Duration;
10
11use rns_core::transport::types::InterfaceId;
12
13use crate::event::{Event, EventSender};
14use crate::hdlc;
15use crate::interface::Writer;
16
17/// Configuration for a TCP client interface.
18#[derive(Debug, Clone)]
19pub struct TcpClientConfig {
20    pub name: String,
21    pub target_host: String,
22    pub target_port: u16,
23    pub interface_id: InterfaceId,
24    pub reconnect_wait: Duration,
25    pub max_reconnect_tries: Option<u32>,
26    pub connect_timeout: Duration,
27    /// Linux network interface to bind the socket to (e.g. "usb0").
28    pub device: Option<String>,
29}
30
31impl Default for TcpClientConfig {
32    fn default() -> Self {
33        TcpClientConfig {
34            name: String::new(),
35            target_host: "127.0.0.1".into(),
36            target_port: 4242,
37            interface_id: InterfaceId(0),
38            reconnect_wait: Duration::from_secs(5),
39            max_reconnect_tries: None,
40            connect_timeout: Duration::from_secs(5),
41            device: None,
42        }
43    }
44}
45
46/// Writer that sends HDLC-framed data over a TCP stream.
47struct TcpWriter {
48    stream: TcpStream,
49}
50
51impl Writer for TcpWriter {
52    fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
53        self.stream.write_all(&hdlc::frame(data))
54    }
55}
56
57/// Set TCP keepalive and timeout socket options (Linux).
58fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
59    let fd = stream.as_raw_fd();
60    unsafe {
61        // TCP_NODELAY = 1
62        let val: libc::c_int = 1;
63        if libc::setsockopt(
64            fd,
65            libc::IPPROTO_TCP,
66            libc::TCP_NODELAY,
67            &val as *const _ as *const libc::c_void,
68            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
69        ) != 0
70        {
71            return Err(io::Error::last_os_error());
72        }
73
74        // SO_KEEPALIVE = 1
75        if libc::setsockopt(
76            fd,
77            libc::SOL_SOCKET,
78            libc::SO_KEEPALIVE,
79            &val as *const _ as *const libc::c_void,
80            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
81        ) != 0
82        {
83            return Err(io::Error::last_os_error());
84        }
85
86        // Linux-specific keepalive tuning and user timeout
87        #[cfg(target_os = "linux")]
88        {
89            // TCP_KEEPIDLE = 5
90            let idle: libc::c_int = 5;
91            if libc::setsockopt(
92                fd,
93                libc::IPPROTO_TCP,
94                libc::TCP_KEEPIDLE,
95                &idle as *const _ as *const libc::c_void,
96                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
97            ) != 0
98            {
99                return Err(io::Error::last_os_error());
100            }
101
102            // TCP_KEEPINTVL = 2
103            let intvl: libc::c_int = 2;
104            if libc::setsockopt(
105                fd,
106                libc::IPPROTO_TCP,
107                libc::TCP_KEEPINTVL,
108                &intvl as *const _ as *const libc::c_void,
109                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
110            ) != 0
111            {
112                return Err(io::Error::last_os_error());
113            }
114
115            // TCP_KEEPCNT = 12
116            let cnt: libc::c_int = 12;
117            if libc::setsockopt(
118                fd,
119                libc::IPPROTO_TCP,
120                libc::TCP_KEEPCNT,
121                &cnt as *const _ as *const libc::c_void,
122                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
123            ) != 0
124            {
125                return Err(io::Error::last_os_error());
126            }
127
128            // TCP_USER_TIMEOUT = 24000 ms
129            let timeout: libc::c_int = 24_000;
130            if libc::setsockopt(
131                fd,
132                libc::IPPROTO_TCP,
133                libc::TCP_USER_TIMEOUT,
134                &timeout as *const _ as *const libc::c_void,
135                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
136            ) != 0
137            {
138                return Err(io::Error::last_os_error());
139            }
140        }
141    }
142    Ok(())
143}
144
145/// Try to connect to the target host:port with timeout.
146fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
147    let addr_str = format!("{}:{}", config.target_host, config.target_port);
148    let addr = addr_str
149        .to_socket_addrs()?
150        .next()
151        .ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
152
153    #[cfg(target_os = "linux")]
154    let stream = if let Some(ref device) = config.device {
155        connect_with_device(&addr, device, config.connect_timeout)?
156    } else {
157        TcpStream::connect_timeout(&addr, config.connect_timeout)?
158    };
159    #[cfg(not(target_os = "linux"))]
160    let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
161    set_socket_options(&stream)?;
162    Ok(stream)
163}
164
165/// Create a TCP socket, bind it to a network device, then connect with timeout.
166#[cfg(target_os = "linux")]
167fn connect_with_device(
168    addr: &std::net::SocketAddr,
169    device: &str,
170    timeout: Duration,
171) -> io::Result<TcpStream> {
172    use std::os::unix::io::{FromRawFd, RawFd};
173
174    let domain = if addr.is_ipv4() { libc::AF_INET } else { libc::AF_INET6 };
175    let fd: RawFd = unsafe { libc::socket(domain, libc::SOCK_STREAM, 0) };
176    if fd < 0 {
177        return Err(io::Error::last_os_error());
178    }
179
180    // Ensure the fd is closed on error paths
181    let stream = unsafe { TcpStream::from_raw_fd(fd) };
182
183    super::bind_to_device(stream.as_raw_fd(), device)?;
184
185    // Set non-blocking for connect-with-timeout
186    stream.set_nonblocking(true)?;
187
188    let (sockaddr, socklen) = socket_addr_to_raw(addr);
189    let ret = unsafe {
190        libc::connect(
191            stream.as_raw_fd(),
192            &sockaddr as *const libc::sockaddr_storage as *const libc::sockaddr,
193            socklen,
194        )
195    };
196
197    if ret != 0 {
198        let err = io::Error::last_os_error();
199        if err.raw_os_error() != Some(libc::EINPROGRESS) {
200            return Err(err);
201        }
202    }
203
204    // Poll for connect completion
205    let mut pollfd = libc::pollfd {
206        fd: stream.as_raw_fd(),
207        events: libc::POLLOUT,
208        revents: 0,
209    };
210    let timeout_ms = timeout.as_millis() as libc::c_int;
211    let poll_ret = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
212
213    if poll_ret == 0 {
214        return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
215    }
216    if poll_ret < 0 {
217        return Err(io::Error::last_os_error());
218    }
219
220    // Check SO_ERROR
221    let mut err_val: libc::c_int = 0;
222    let mut err_len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
223    let ret = unsafe {
224        libc::getsockopt(
225            stream.as_raw_fd(),
226            libc::SOL_SOCKET,
227            libc::SO_ERROR,
228            &mut err_val as *mut _ as *mut libc::c_void,
229            &mut err_len,
230        )
231    };
232    if ret != 0 {
233        return Err(io::Error::last_os_error());
234    }
235    if err_val != 0 {
236        return Err(io::Error::from_raw_os_error(err_val));
237    }
238
239    // Set back to blocking
240    stream.set_nonblocking(false)?;
241
242    Ok(stream)
243}
244
245/// Convert a `SocketAddr` to a raw `sockaddr_storage` for `libc::connect`.
246#[cfg(target_os = "linux")]
247fn socket_addr_to_raw(addr: &std::net::SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
248    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
249    match addr {
250        std::net::SocketAddr::V4(v4) => {
251            let sin: &mut libc::sockaddr_in = unsafe {
252                &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in)
253            };
254            sin.sin_family = libc::AF_INET as libc::sa_family_t;
255            sin.sin_port = v4.port().to_be();
256            sin.sin_addr = libc::in_addr {
257                s_addr: u32::from_ne_bytes(v4.ip().octets()),
258            };
259            (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
260        }
261        std::net::SocketAddr::V6(v6) => {
262            let sin6: &mut libc::sockaddr_in6 = unsafe {
263                &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6)
264            };
265            sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
266            sin6.sin6_port = v6.port().to_be();
267            sin6.sin6_addr = libc::in6_addr {
268                s6_addr: v6.ip().octets(),
269            };
270            sin6.sin6_flowinfo = v6.flowinfo();
271            sin6.sin6_scope_id = v6.scope_id();
272            (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
273        }
274    }
275}
276
277/// Connect and start the reader thread. Returns the writer for the driver.
278pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
279    let stream = try_connect(&config)?;
280    let reader_stream = stream.try_clone()?;
281    let writer_stream = stream.try_clone()?;
282
283    let id = config.interface_id;
284    // Initial connect: writer is None because it's returned directly to the caller
285    let _ = tx.send(Event::InterfaceUp(id, None, None));
286
287    // Spawn reader thread
288    let reader_config = config;
289    let reader_tx = tx;
290    thread::Builder::new()
291        .name(format!("tcp-reader-{}", id.0))
292        .spawn(move || {
293            reader_loop(reader_stream, reader_config, reader_tx);
294        })?;
295
296    Ok(Box::new(TcpWriter { stream: writer_stream }))
297}
298
299/// Reader thread: reads from socket, HDLC-decodes, sends frames to driver.
300/// On disconnect, attempts reconnection.
301fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
302    let id = config.interface_id;
303    let mut decoder = hdlc::Decoder::new();
304    let mut buf = [0u8; 4096];
305
306    loop {
307        match stream.read(&mut buf) {
308            Ok(0) => {
309                // Connection closed by peer
310                log::warn!("[{}] connection closed", config.name);
311                let _ = tx.send(Event::InterfaceDown(id));
312                match reconnect(&config, &tx) {
313                    Some(new_stream) => {
314                        stream = new_stream;
315                        decoder = hdlc::Decoder::new();
316                        continue;
317                    }
318                    None => {
319                        log::error!("[{}] reconnection failed, giving up", config.name);
320                        return;
321                    }
322                }
323            }
324            Ok(n) => {
325                for frame in decoder.feed(&buf[..n]) {
326                    if tx.send(Event::Frame { interface_id: id, data: frame }).is_err() {
327                        // Driver shut down
328                        return;
329                    }
330                }
331            }
332            Err(e) => {
333                log::warn!("[{}] read error: {}", config.name, e);
334                let _ = tx.send(Event::InterfaceDown(id));
335                match reconnect(&config, &tx) {
336                    Some(new_stream) => {
337                        stream = new_stream;
338                        decoder = hdlc::Decoder::new();
339                        continue;
340                    }
341                    None => {
342                        log::error!("[{}] reconnection failed, giving up", config.name);
343                        return;
344                    }
345                }
346            }
347        }
348    }
349}
350
351/// Attempt to reconnect with retry logic. Returns the new reader stream on success.
352/// Sends the new writer to the driver via InterfaceUp event.
353fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
354    let mut attempts = 0u32;
355    loop {
356        thread::sleep(config.reconnect_wait);
357        attempts += 1;
358
359        if let Some(max) = config.max_reconnect_tries {
360            if attempts > max {
361                let _ = tx.send(Event::InterfaceDown(config.interface_id));
362                return None;
363            }
364        }
365
366        log::info!(
367            "[{}] reconnect attempt {} ...",
368            config.name,
369            attempts
370        );
371
372        match try_connect(config) {
373            Ok(new_stream) => {
374                // Clone the stream: one for the reader, one for the writer
375                let writer_stream = match new_stream.try_clone() {
376                    Ok(s) => s,
377                    Err(e) => {
378                        log::warn!("[{}] failed to clone stream: {}", config.name, e);
379                        continue;
380                    }
381                };
382                log::info!("[{}] reconnected", config.name);
383                // Send new writer to the driver so it can replace the stale one
384                let new_writer: Box<dyn Writer> = Box::new(TcpWriter { stream: writer_stream });
385                let _ = tx.send(Event::InterfaceUp(config.interface_id, Some(new_writer), None));
386                return Some(new_stream);
387            }
388            Err(e) => {
389                log::warn!("[{}] reconnect failed: {}", config.name, e);
390            }
391        }
392    }
393}
394
395// --- Factory implementation ---
396
397use std::collections::HashMap;
398use rns_core::transport::types::InterfaceInfo;
399use super::{InterfaceFactory, InterfaceConfigData, StartContext, StartResult};
400
401/// Factory for `TCPClientInterface`.
402pub struct TcpClientFactory;
403
404impl InterfaceFactory for TcpClientFactory {
405    fn type_name(&self) -> &str { "TCPClientInterface" }
406
407    fn parse_config(
408        &self,
409        name: &str,
410        id: InterfaceId,
411        params: &HashMap<String, String>,
412    ) -> Result<Box<dyn InterfaceConfigData>, String> {
413        let target_host = params.get("target_host")
414            .cloned()
415            .unwrap_or_else(|| "127.0.0.1".into());
416        let target_port = params.get("target_port")
417            .and_then(|v| v.parse().ok())
418            .unwrap_or(4242);
419
420        Ok(Box::new(TcpClientConfig {
421            name: name.to_string(),
422            target_host,
423            target_port,
424            interface_id: id,
425            device: params.get("device").cloned(),
426            ..TcpClientConfig::default()
427        }))
428    }
429
430    fn start(
431        &self,
432        config: Box<dyn InterfaceConfigData>,
433        ctx: StartContext,
434    ) -> io::Result<StartResult> {
435        let tcp_config = *config.into_any().downcast::<TcpClientConfig>()
436            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "wrong config type"))?;
437
438        let id = tcp_config.interface_id;
439        let name = tcp_config.name.clone();
440        let info = InterfaceInfo {
441            id,
442            name,
443            mode: ctx.mode,
444            out_capable: true,
445            in_capable: true,
446            bitrate: None,
447            announce_rate_target: None,
448            announce_rate_grace: 0,
449            announce_rate_penalty: 0.0,
450            announce_cap: rns_core::constants::ANNOUNCE_CAP,
451            is_local_client: false,
452            wants_tunnel: false,
453            tunnel_id: None,
454            mtu: 65535,
455            ingress_control: true,
456            ia_freq: 0.0,
457            started: crate::time::now(),
458        };
459
460        let writer = start(tcp_config, ctx.tx)?;
461
462        Ok(StartResult::Simple {
463            id,
464            info,
465            writer,
466            interface_type_name: "TCPClientInterface".to_string(),
467        })
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Bind a loopback listener on an OS-assigned port and return it together
    /// with that port. The listener is kept open, so no other test can grab
    /// the port between picking it and listening on it.
    fn listen() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Client configuration pointing at `127.0.0.1:port` with short timings.
    fn make_config(port: u16) -> TcpClientConfig {
        TcpClientConfig {
            name: format!("test-tcp-{}", port),
            target_host: "127.0.0.1".into(),
            target_port: port,
            interface_id: InterfaceId(1),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(2),
            connect_timeout: Duration::from_secs(2),
            device: None,
        }
    }

    #[test]
    fn connect_to_listener() {
        let (listener, port) = listen();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        // Accept the connection
        let _server_stream = listener.accept().unwrap();

        // Should receive InterfaceUp event
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn receive_frame() {
        let (listener, port) = listen();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain the InterfaceUp event
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send an HDLC frame from server (>= 19 bytes payload)
        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        server_stream.write_all(&framed).unwrap();

        // Should receive Frame event
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(1));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn send_frame() {
        let (listener, port) = listen();
        let (tx, _rx) = mpsc::channel();

        let mut writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();
        server_stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();

        // Send a frame via writer
        let payload: Vec<u8> = (0..24).collect();
        writer.send_frame(&payload).unwrap();

        // Read from server side
        let mut buf = [0u8; 256];
        let n = server_stream.read(&mut buf).unwrap();
        let expected = hdlc::frame(&payload);
        assert_eq!(&buf[..n], &expected[..]);
    }

    #[test]
    fn multiple_frames() {
        let (listener, port) = listen();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send multiple frames
        let payloads: Vec<Vec<u8>> = (0..3).map(|i| (0..24).map(|j| j + i * 50).collect()).collect();
        for p in &payloads {
            server_stream.write_all(&hdlc::frame(p)).unwrap();
        }

        // Should receive all frames
        for expected in &payloads {
            let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
            match event {
                Event::Frame { data, .. } => assert_eq!(&data, expected),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    #[test]
    fn split_across_reads() {
        let (listener, port) = listen();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send frame in two parts
        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        let mid = framed.len() / 2;

        server_stream.write_all(&framed[..mid]).unwrap();
        server_stream.flush().unwrap();
        thread::sleep(Duration::from_millis(50));
        server_stream.write_all(&framed[mid..]).unwrap();
        server_stream.flush().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { data, .. } => assert_eq!(data, payload),
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn reconnect_on_close() {
        let (listener, port) = listen();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        // Accept first connection; it is closed below.
        let (server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        drop(server_stream);

        // Should get InterfaceDown
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));

        // Accept the reconnection
        let _server_stream2 = listener.accept().unwrap();

        // Should get InterfaceUp again
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn socket_options() {
        let (listener, port) = listen();

        let stream = try_connect(&make_config(port)).unwrap();
        let _server = listener.accept().unwrap();

        // Verify TCP_NODELAY is set
        let fd = stream.as_raw_fd();
        let mut val: libc::c_int = 0;
        let mut len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: both out-pointers refer to live locals of the advertised size.
        let rc = unsafe {
            libc::getsockopt(
                fd,
                libc::IPPROTO_TCP,
                libc::TCP_NODELAY,
                &mut val as *mut _ as *mut libc::c_void,
                &mut len,
            )
        };
        assert_eq!(rc, 0, "getsockopt(TCP_NODELAY) should succeed");
        assert_eq!(val, 1, "TCP_NODELAY should be 1");
    }

    #[test]
    fn connect_timeout() {
        // Use a non-routable address to trigger timeout
        let config = TcpClientConfig {
            name: "timeout-test".into(),
            target_host: "192.0.2.1".into(), // TEST-NET, non-routable
            target_port: 12345,
            interface_id: InterfaceId(99),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(0),
            connect_timeout: Duration::from_millis(500),
            device: None,
        };

        let start_time = std::time::Instant::now();
        let result = try_connect(&config);
        let elapsed = start_time.elapsed();

        assert!(result.is_err());
        // Should timeout roughly around 500ms, definitely under 5s
        assert!(elapsed < Duration::from_secs(5));
    }
}