Skip to main content

rns_net/interface/
tcp.rs

1//! TCP client interface with HDLC framing.
2//!
3//! Matches Python `TCPClientInterface` from `TCPInterface.py`.
4
5use std::io::{self, Read, Write};
6use std::net::{TcpStream, ToSocketAddrs};
7use std::os::unix::io::AsRawFd;
8use std::thread;
9use std::time::Duration;
10
11use rns_core::transport::types::InterfaceId;
12
13use crate::event::{Event, EventSender};
14use crate::hdlc;
15use crate::interface::Writer;
16
17/// Configuration for a TCP client interface.
18#[derive(Debug, Clone)]
19pub struct TcpClientConfig {
20    pub name: String,
21    pub target_host: String,
22    pub target_port: u16,
23    pub interface_id: InterfaceId,
24    pub reconnect_wait: Duration,
25    pub max_reconnect_tries: Option<u32>,
26    pub connect_timeout: Duration,
27}
28
29impl Default for TcpClientConfig {
30    fn default() -> Self {
31        TcpClientConfig {
32            name: String::new(),
33            target_host: "127.0.0.1".into(),
34            target_port: 4242,
35            interface_id: InterfaceId(0),
36            reconnect_wait: Duration::from_secs(5),
37            max_reconnect_tries: None,
38            connect_timeout: Duration::from_secs(5),
39        }
40    }
41}
42
/// [`Writer`] that HDLC-frames each outgoing packet and writes it to a TCP socket.
///
/// The stream held here is a `try_clone` duplicate of the connected socket;
/// the reader thread owns another handle to the same connection.
struct TcpWriter {
    /// Write-side handle of the connection.
    stream: TcpStream,
}
47
48impl Writer for TcpWriter {
49    fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
50        self.stream.write_all(&hdlc::frame(data))
51    }
52}
53
54/// Set TCP keepalive and timeout socket options (Linux).
55fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
56    let fd = stream.as_raw_fd();
57    unsafe {
58        // TCP_NODELAY = 1
59        let val: libc::c_int = 1;
60        if libc::setsockopt(
61            fd,
62            libc::IPPROTO_TCP,
63            libc::TCP_NODELAY,
64            &val as *const _ as *const libc::c_void,
65            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
66        ) != 0
67        {
68            return Err(io::Error::last_os_error());
69        }
70
71        // SO_KEEPALIVE = 1
72        if libc::setsockopt(
73            fd,
74            libc::SOL_SOCKET,
75            libc::SO_KEEPALIVE,
76            &val as *const _ as *const libc::c_void,
77            std::mem::size_of::<libc::c_int>() as libc::socklen_t,
78        ) != 0
79        {
80            return Err(io::Error::last_os_error());
81        }
82
83        // Linux-specific keepalive tuning and user timeout
84        #[cfg(target_os = "linux")]
85        {
86            // TCP_KEEPIDLE = 5
87            let idle: libc::c_int = 5;
88            if libc::setsockopt(
89                fd,
90                libc::IPPROTO_TCP,
91                libc::TCP_KEEPIDLE,
92                &idle as *const _ as *const libc::c_void,
93                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
94            ) != 0
95            {
96                return Err(io::Error::last_os_error());
97            }
98
99            // TCP_KEEPINTVL = 2
100            let intvl: libc::c_int = 2;
101            if libc::setsockopt(
102                fd,
103                libc::IPPROTO_TCP,
104                libc::TCP_KEEPINTVL,
105                &intvl as *const _ as *const libc::c_void,
106                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
107            ) != 0
108            {
109                return Err(io::Error::last_os_error());
110            }
111
112            // TCP_KEEPCNT = 12
113            let cnt: libc::c_int = 12;
114            if libc::setsockopt(
115                fd,
116                libc::IPPROTO_TCP,
117                libc::TCP_KEEPCNT,
118                &cnt as *const _ as *const libc::c_void,
119                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
120            ) != 0
121            {
122                return Err(io::Error::last_os_error());
123            }
124
125            // TCP_USER_TIMEOUT = 24000 ms
126            let timeout: libc::c_int = 24_000;
127            if libc::setsockopt(
128                fd,
129                libc::IPPROTO_TCP,
130                libc::TCP_USER_TIMEOUT,
131                &timeout as *const _ as *const libc::c_void,
132                std::mem::size_of::<libc::c_int>() as libc::socklen_t,
133            ) != 0
134            {
135                return Err(io::Error::last_os_error());
136            }
137        }
138    }
139    Ok(())
140}
141
142/// Try to connect to the target host:port with timeout.
143fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
144    let addr_str = format!("{}:{}", config.target_host, config.target_port);
145    let addr = addr_str
146        .to_socket_addrs()?
147        .next()
148        .ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
149
150    let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
151    set_socket_options(&stream)?;
152    Ok(stream)
153}
154
155/// Connect and start the reader thread. Returns the writer for the driver.
156pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
157    let stream = try_connect(&config)?;
158    let reader_stream = stream.try_clone()?;
159    let writer_stream = stream.try_clone()?;
160
161    let id = config.interface_id;
162    // Initial connect: writer is None because it's returned directly to the caller
163    let _ = tx.send(Event::InterfaceUp(id, None, None));
164
165    // Spawn reader thread
166    let reader_config = config;
167    let reader_tx = tx;
168    thread::Builder::new()
169        .name(format!("tcp-reader-{}", id.0))
170        .spawn(move || {
171            reader_loop(reader_stream, reader_config, reader_tx);
172        })?;
173
174    Ok(Box::new(TcpWriter { stream: writer_stream }))
175}
176
177/// Reader thread: reads from socket, HDLC-decodes, sends frames to driver.
178/// On disconnect, attempts reconnection.
179fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
180    let id = config.interface_id;
181    let mut decoder = hdlc::Decoder::new();
182    let mut buf = [0u8; 4096];
183
184    loop {
185        match stream.read(&mut buf) {
186            Ok(0) => {
187                // Connection closed by peer
188                log::warn!("[{}] connection closed", config.name);
189                let _ = tx.send(Event::InterfaceDown(id));
190                match reconnect(&config, &tx) {
191                    Some(new_stream) => {
192                        stream = new_stream;
193                        decoder = hdlc::Decoder::new();
194                        continue;
195                    }
196                    None => {
197                        log::error!("[{}] reconnection failed, giving up", config.name);
198                        return;
199                    }
200                }
201            }
202            Ok(n) => {
203                for frame in decoder.feed(&buf[..n]) {
204                    if tx.send(Event::Frame { interface_id: id, data: frame }).is_err() {
205                        // Driver shut down
206                        return;
207                    }
208                }
209            }
210            Err(e) => {
211                log::warn!("[{}] read error: {}", config.name, e);
212                let _ = tx.send(Event::InterfaceDown(id));
213                match reconnect(&config, &tx) {
214                    Some(new_stream) => {
215                        stream = new_stream;
216                        decoder = hdlc::Decoder::new();
217                        continue;
218                    }
219                    None => {
220                        log::error!("[{}] reconnection failed, giving up", config.name);
221                        return;
222                    }
223                }
224            }
225        }
226    }
227}
228
229/// Attempt to reconnect with retry logic. Returns the new reader stream on success.
230/// Sends the new writer to the driver via InterfaceUp event.
231fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
232    let mut attempts = 0u32;
233    loop {
234        thread::sleep(config.reconnect_wait);
235        attempts += 1;
236
237        if let Some(max) = config.max_reconnect_tries {
238            if attempts > max {
239                let _ = tx.send(Event::InterfaceDown(config.interface_id));
240                return None;
241            }
242        }
243
244        log::info!(
245            "[{}] reconnect attempt {} ...",
246            config.name,
247            attempts
248        );
249
250        match try_connect(config) {
251            Ok(new_stream) => {
252                // Clone the stream: one for the reader, one for the writer
253                let writer_stream = match new_stream.try_clone() {
254                    Ok(s) => s,
255                    Err(e) => {
256                        log::warn!("[{}] failed to clone stream: {}", config.name, e);
257                        continue;
258                    }
259                };
260                log::info!("[{}] reconnected", config.name);
261                // Send new writer to the driver so it can replace the stale one
262                let new_writer: Box<dyn Writer> = Box::new(TcpWriter { stream: writer_stream });
263                let _ = tx.send(Event::InterfaceUp(config.interface_id, Some(new_writer), None));
264                return Some(new_stream);
265            }
266            Err(e) => {
267                log::warn!("[{}] reconnect failed: {}", config.name, e);
268            }
269        }
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Bind a loopback listener on an OS-assigned port; return it with that port.
    ///
    /// The listener stays bound for the whole test. This closes the race in the
    /// old bind/drop/rebind approach, where another process could grab the port
    /// between choosing it and listening on it.
    fn bind_local() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Test configuration: loopback target, short reconnect wait, two retries.
    fn make_config(port: u16) -> TcpClientConfig {
        TcpClientConfig {
            name: format!("test-tcp-{}", port),
            target_host: "127.0.0.1".into(),
            target_port: port,
            interface_id: InterfaceId(1),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(2),
            connect_timeout: Duration::from_secs(2),
        }
    }

    /// Read a `c_int`-valued socket option, panicking if `getsockopt` fails.
    fn get_int_sockopt(stream: &TcpStream, level: libc::c_int, name: libc::c_int) -> libc::c_int {
        let mut val: libc::c_int = 0;
        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: the fd belongs to `stream`, which is alive for this call, and
        // `val`/`len` describe a writable `c_int` buffer of the advertised size.
        let rc = unsafe {
            libc::getsockopt(
                stream.as_raw_fd(),
                level,
                name,
                &mut val as *mut libc::c_int as *mut libc::c_void,
                &mut len,
            )
        };
        assert_eq!(rc, 0, "getsockopt failed: {}", io::Error::last_os_error());
        val
    }

    #[test]
    fn connect_to_listener() {
        let (listener, port) = bind_local();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        // Accept the connection
        let _server_stream = listener.accept().unwrap();

        // Should receive InterfaceUp event
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn receive_frame() {
        let (listener, port) = bind_local();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain the InterfaceUp event
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send an HDLC frame from server (>= 19 bytes payload)
        let payload: Vec<u8> = (0..32).collect();
        server_stream.write_all(&hdlc::frame(&payload)).unwrap();

        // Should receive Frame event
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(1));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn send_frame() {
        let (listener, port) = bind_local();
        let (tx, _rx) = mpsc::channel();

        let mut writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();
        server_stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();

        // Send a frame via writer
        let payload: Vec<u8> = (0..24).collect();
        writer.send_frame(&payload).unwrap();

        // Read from server side
        let mut buf = [0u8; 256];
        let n = server_stream.read(&mut buf).unwrap();
        let expected = hdlc::frame(&payload);
        assert_eq!(&buf[..n], &expected[..]);
    }

    #[test]
    fn multiple_frames() {
        let (listener, port) = bind_local();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send multiple frames
        let payloads: Vec<Vec<u8>> = (0..3)
            .map(|i| (0..24).map(|j| j + i * 50).collect())
            .collect();
        for p in &payloads {
            server_stream.write_all(&hdlc::frame(p)).unwrap();
        }

        // Should receive all frames, in order
        for expected in &payloads {
            let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
            match event {
                Event::Frame { data, .. } => assert_eq!(&data, expected),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    #[test]
    fn split_across_reads() {
        let (listener, port) = bind_local();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Send frame in two parts
        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        let mid = framed.len() / 2;

        server_stream.write_all(&framed[..mid]).unwrap();
        server_stream.flush().unwrap();
        thread::sleep(Duration::from_millis(50));
        server_stream.write_all(&framed[mid..]).unwrap();
        server_stream.flush().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { data, .. } => assert_eq!(data, payload),
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn reconnect_on_close() {
        let (listener, port) = bind_local();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();

        // Accept first connection
        let (server_stream, _) = listener.accept().unwrap();

        // Drain InterfaceUp
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Close it from the server side
        drop(server_stream);

        // Should get InterfaceDown
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));

        // Accept the reconnection (the listener is still bound)
        let _server_stream2 = listener.accept().unwrap();

        // Should get InterfaceUp again
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn socket_options() {
        let (listener, port) = bind_local();

        let stream = try_connect(&make_config(port)).unwrap();
        let _server = listener.accept().unwrap();

        // BSD-derived kernels report the flag bit rather than 1, so only
        // require the options to be enabled (non-zero).
        let nodelay = get_int_sockopt(&stream, libc::IPPROTO_TCP, libc::TCP_NODELAY);
        assert_ne!(nodelay, 0, "TCP_NODELAY should be enabled");

        let keepalive = get_int_sockopt(&stream, libc::SOL_SOCKET, libc::SO_KEEPALIVE);
        assert_ne!(keepalive, 0, "SO_KEEPALIVE should be enabled");
    }

    #[test]
    fn connect_timeout() {
        // Use a non-routable address to trigger timeout
        let config = TcpClientConfig {
            name: "timeout-test".into(),
            target_host: "192.0.2.1".into(), // TEST-NET, non-routable
            target_port: 12345,
            interface_id: InterfaceId(99),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(0),
            connect_timeout: Duration::from_millis(500),
        };

        let start_time = std::time::Instant::now();
        let result = try_connect(&config);
        let elapsed = start_time.elapsed();

        assert!(result.is_err());
        // Should time out roughly around 500ms, definitely under 5s
        assert!(elapsed < Duration::from_secs(5));
    }
}