1use std::io::{self, Read, Write};
6use std::net::{TcpStream, ToSocketAddrs};
7use std::os::unix::io::AsRawFd;
8use std::thread;
9use std::time::Duration;
10
11use rns_core::transport::types::InterfaceId;
12
13use crate::event::{Event, EventSender};
14use crate::hdlc;
15use crate::interface::Writer;
16
17#[derive(Debug, Clone)]
19pub struct TcpClientConfig {
20 pub name: String,
21 pub target_host: String,
22 pub target_port: u16,
23 pub interface_id: InterfaceId,
24 pub reconnect_wait: Duration,
25 pub max_reconnect_tries: Option<u32>,
26 pub connect_timeout: Duration,
27 pub device: Option<String>,
29}
30
31impl Default for TcpClientConfig {
32 fn default() -> Self {
33 TcpClientConfig {
34 name: String::new(),
35 target_host: "127.0.0.1".into(),
36 target_port: 4242,
37 interface_id: InterfaceId(0),
38 reconnect_wait: Duration::from_secs(5),
39 max_reconnect_tries: None,
40 connect_timeout: Duration::from_secs(5),
41 device: None,
42 }
43 }
44}
45
/// Sending half of a TCP interface: owns a handle to the connected socket
/// and implements `Writer` on top of it.
struct TcpWriter {
    /// Socket handle that framed data is written to.
    stream: TcpStream,
}
50
51impl Writer for TcpWriter {
52 fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
53 self.stream.write_all(&hdlc::frame(data))
54 }
55}
56
57fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
59 let fd = stream.as_raw_fd();
60 unsafe {
61 let val: libc::c_int = 1;
63 if libc::setsockopt(
64 fd,
65 libc::IPPROTO_TCP,
66 libc::TCP_NODELAY,
67 &val as *const _ as *const libc::c_void,
68 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
69 ) != 0
70 {
71 return Err(io::Error::last_os_error());
72 }
73
74 if libc::setsockopt(
76 fd,
77 libc::SOL_SOCKET,
78 libc::SO_KEEPALIVE,
79 &val as *const _ as *const libc::c_void,
80 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
81 ) != 0
82 {
83 return Err(io::Error::last_os_error());
84 }
85
86 #[cfg(target_os = "linux")]
88 {
89 let idle: libc::c_int = 5;
91 if libc::setsockopt(
92 fd,
93 libc::IPPROTO_TCP,
94 libc::TCP_KEEPIDLE,
95 &idle as *const _ as *const libc::c_void,
96 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
97 ) != 0
98 {
99 return Err(io::Error::last_os_error());
100 }
101
102 let intvl: libc::c_int = 2;
104 if libc::setsockopt(
105 fd,
106 libc::IPPROTO_TCP,
107 libc::TCP_KEEPINTVL,
108 &intvl as *const _ as *const libc::c_void,
109 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
110 ) != 0
111 {
112 return Err(io::Error::last_os_error());
113 }
114
115 let cnt: libc::c_int = 12;
117 if libc::setsockopt(
118 fd,
119 libc::IPPROTO_TCP,
120 libc::TCP_KEEPCNT,
121 &cnt as *const _ as *const libc::c_void,
122 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
123 ) != 0
124 {
125 return Err(io::Error::last_os_error());
126 }
127
128 let timeout: libc::c_int = 24_000;
130 if libc::setsockopt(
131 fd,
132 libc::IPPROTO_TCP,
133 libc::TCP_USER_TIMEOUT,
134 &timeout as *const _ as *const libc::c_void,
135 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
136 ) != 0
137 {
138 return Err(io::Error::last_os_error());
139 }
140 }
141 }
142 Ok(())
143}
144
145fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
147 let addr_str = format!("{}:{}", config.target_host, config.target_port);
148 let addr = addr_str
149 .to_socket_addrs()?
150 .next()
151 .ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
152
153 #[cfg(target_os = "linux")]
154 let stream = if let Some(ref device) = config.device {
155 connect_with_device(&addr, device, config.connect_timeout)?
156 } else {
157 TcpStream::connect_timeout(&addr, config.connect_timeout)?
158 };
159 #[cfg(not(target_os = "linux"))]
160 let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
161 set_socket_options(&stream)?;
162 Ok(stream)
163}
164
165#[cfg(target_os = "linux")]
167fn connect_with_device(
168 addr: &std::net::SocketAddr,
169 device: &str,
170 timeout: Duration,
171) -> io::Result<TcpStream> {
172 use std::os::unix::io::{FromRawFd, RawFd};
173
174 let domain = if addr.is_ipv4() { libc::AF_INET } else { libc::AF_INET6 };
175 let fd: RawFd = unsafe { libc::socket(domain, libc::SOCK_STREAM, 0) };
176 if fd < 0 {
177 return Err(io::Error::last_os_error());
178 }
179
180 let stream = unsafe { TcpStream::from_raw_fd(fd) };
182
183 super::bind_to_device(stream.as_raw_fd(), device)?;
184
185 stream.set_nonblocking(true)?;
187
188 let (sockaddr, socklen) = socket_addr_to_raw(addr);
189 let ret = unsafe {
190 libc::connect(
191 stream.as_raw_fd(),
192 &sockaddr as *const libc::sockaddr_storage as *const libc::sockaddr,
193 socklen,
194 )
195 };
196
197 if ret != 0 {
198 let err = io::Error::last_os_error();
199 if err.raw_os_error() != Some(libc::EINPROGRESS) {
200 return Err(err);
201 }
202 }
203
204 let mut pollfd = libc::pollfd {
206 fd: stream.as_raw_fd(),
207 events: libc::POLLOUT,
208 revents: 0,
209 };
210 let timeout_ms = timeout.as_millis() as libc::c_int;
211 let poll_ret = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
212
213 if poll_ret == 0 {
214 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
215 }
216 if poll_ret < 0 {
217 return Err(io::Error::last_os_error());
218 }
219
220 let mut err_val: libc::c_int = 0;
222 let mut err_len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
223 let ret = unsafe {
224 libc::getsockopt(
225 stream.as_raw_fd(),
226 libc::SOL_SOCKET,
227 libc::SO_ERROR,
228 &mut err_val as *mut _ as *mut libc::c_void,
229 &mut err_len,
230 )
231 };
232 if ret != 0 {
233 return Err(io::Error::last_os_error());
234 }
235 if err_val != 0 {
236 return Err(io::Error::from_raw_os_error(err_val));
237 }
238
239 stream.set_nonblocking(false)?;
241
242 Ok(stream)
243}
244
245#[cfg(target_os = "linux")]
247fn socket_addr_to_raw(addr: &std::net::SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
248 let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
249 match addr {
250 std::net::SocketAddr::V4(v4) => {
251 let sin: &mut libc::sockaddr_in = unsafe {
252 &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in)
253 };
254 sin.sin_family = libc::AF_INET as libc::sa_family_t;
255 sin.sin_port = v4.port().to_be();
256 sin.sin_addr = libc::in_addr {
257 s_addr: u32::from_ne_bytes(v4.ip().octets()),
258 };
259 (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
260 }
261 std::net::SocketAddr::V6(v6) => {
262 let sin6: &mut libc::sockaddr_in6 = unsafe {
263 &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6)
264 };
265 sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
266 sin6.sin6_port = v6.port().to_be();
267 sin6.sin6_addr = libc::in6_addr {
268 s6_addr: v6.ip().octets(),
269 };
270 sin6.sin6_flowinfo = v6.flowinfo();
271 sin6.sin6_scope_id = v6.scope_id();
272 (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
273 }
274 }
275}
276
277pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
279 let stream = try_connect(&config)?;
280 let reader_stream = stream.try_clone()?;
281 let writer_stream = stream.try_clone()?;
282
283 let id = config.interface_id;
284 let _ = tx.send(Event::InterfaceUp(id, None, None));
286
287 let reader_config = config;
289 let reader_tx = tx;
290 thread::Builder::new()
291 .name(format!("tcp-reader-{}", id.0))
292 .spawn(move || {
293 reader_loop(reader_stream, reader_config, reader_tx);
294 })?;
295
296 Ok(Box::new(TcpWriter { stream: writer_stream }))
297}
298
299fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
302 let id = config.interface_id;
303 let mut decoder = hdlc::Decoder::new();
304 let mut buf = [0u8; 4096];
305
306 loop {
307 match stream.read(&mut buf) {
308 Ok(0) => {
309 log::warn!("[{}] connection closed", config.name);
311 let _ = tx.send(Event::InterfaceDown(id));
312 match reconnect(&config, &tx) {
313 Some(new_stream) => {
314 stream = new_stream;
315 decoder = hdlc::Decoder::new();
316 continue;
317 }
318 None => {
319 log::error!("[{}] reconnection failed, giving up", config.name);
320 return;
321 }
322 }
323 }
324 Ok(n) => {
325 for frame in decoder.feed(&buf[..n]) {
326 if tx.send(Event::Frame { interface_id: id, data: frame }).is_err() {
327 return;
329 }
330 }
331 }
332 Err(e) => {
333 log::warn!("[{}] read error: {}", config.name, e);
334 let _ = tx.send(Event::InterfaceDown(id));
335 match reconnect(&config, &tx) {
336 Some(new_stream) => {
337 stream = new_stream;
338 decoder = hdlc::Decoder::new();
339 continue;
340 }
341 None => {
342 log::error!("[{}] reconnection failed, giving up", config.name);
343 return;
344 }
345 }
346 }
347 }
348 }
349}
350
351fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
354 let mut attempts = 0u32;
355 loop {
356 thread::sleep(config.reconnect_wait);
357 attempts += 1;
358
359 if let Some(max) = config.max_reconnect_tries {
360 if attempts > max {
361 let _ = tx.send(Event::InterfaceDown(config.interface_id));
362 return None;
363 }
364 }
365
366 log::info!(
367 "[{}] reconnect attempt {} ...",
368 config.name,
369 attempts
370 );
371
372 match try_connect(config) {
373 Ok(new_stream) => {
374 let writer_stream = match new_stream.try_clone() {
376 Ok(s) => s,
377 Err(e) => {
378 log::warn!("[{}] failed to clone stream: {}", config.name, e);
379 continue;
380 }
381 };
382 log::info!("[{}] reconnected", config.name);
383 let new_writer: Box<dyn Writer> = Box::new(TcpWriter { stream: writer_stream });
385 let _ = tx.send(Event::InterfaceUp(config.interface_id, Some(new_writer), None));
386 return Some(new_stream);
387 }
388 Err(e) => {
389 log::warn!("[{}] reconnect failed: {}", config.name, e);
390 }
391 }
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Binds a loopback listener on an OS-assigned port and returns it along
    /// with that port.
    ///
    /// The listener from the one and only `bind` call is kept, so there is no
    /// window in which another test or process could take the port between
    /// discovering it and listening on it.
    fn bind_listener() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Client configuration pointing at `127.0.0.1:port` with short waits and
    /// at most two reconnect attempts, so tests finish quickly.
    fn make_config(port: u16) -> TcpClientConfig {
        TcpClientConfig {
            name: format!("test-tcp-{}", port),
            target_host: "127.0.0.1".into(),
            target_port: port,
            interface_id: InterfaceId(1),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(2),
            connect_timeout: Duration::from_secs(2),
            device: None,
        }
    }

    /// Starting the client against a live listener reports `InterfaceUp`.
    #[test]
    fn connect_to_listener() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let _server_stream = listener.accept().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    /// A single framed payload written by the server arrives as one `Frame`.
    #[test]
    fn receive_frame() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Discard the initial InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        server_stream.write_all(&framed).unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(1));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    /// The returned writer puts the HDLC-framed payload on the wire.
    #[test]
    fn send_frame() {
        let (listener, port) = bind_listener();
        let (tx, _rx) = mpsc::channel();

        let config = make_config(port);
        let mut writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();
        server_stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();

        let payload: Vec<u8> = (0..24).collect();
        writer.send_frame(&payload).unwrap();

        let mut buf = [0u8; 256];
        let n = server_stream.read(&mut buf).unwrap();
        let expected = hdlc::frame(&payload);
        assert_eq!(&buf[..n], &expected[..]);
    }

    /// Several frames written back to back are delivered in order.
    #[test]
    fn multiple_frames() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Discard the initial InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payloads: Vec<Vec<u8>> = (0..3).map(|i| (0..24).map(|j| j + i * 50).collect()).collect();
        for p in &payloads {
            server_stream.write_all(&hdlc::frame(p)).unwrap();
        }

        for expected in &payloads {
            let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
            match event {
                Event::Frame { data, .. } => assert_eq!(&data, expected),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    /// A frame delivered in two separate writes is reassembled into one event.
    #[test]
    fn split_across_reads() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let (mut server_stream, _) = listener.accept().unwrap();

        // Discard the initial InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        let mid = framed.len() / 2;

        server_stream.write_all(&framed[..mid]).unwrap();
        server_stream.flush().unwrap();
        // Give the client time to read the first half on its own.
        thread::sleep(Duration::from_millis(50));
        server_stream.write_all(&framed[mid..]).unwrap();
        server_stream.flush().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        match event {
            Event::Frame { data, .. } => assert_eq!(data, payload),
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    /// Closing the server side yields `InterfaceDown`, then a new connection
    /// and `InterfaceUp`.
    #[test]
    fn reconnect_on_close() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let config = make_config(port);
        let _writer = start(config, tx).unwrap();

        let (server_stream, _) = listener.accept().unwrap();

        // Discard the initial InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        drop(server_stream);

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));

        let _server_stream2 = listener.accept().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    /// `try_connect` leaves `TCP_NODELAY` enabled on the socket it returns.
    #[test]
    fn socket_options() {
        let (listener, port) = bind_listener();

        let stream = try_connect(&make_config(port)).unwrap();
        let _server = listener.accept().unwrap();

        let fd = stream.as_raw_fd();
        let mut val: libc::c_int = 0;
        let mut len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: `fd` belongs to the live `stream`; both pointers refer to
        // locals and `len` holds the size of `val`.
        let rc = unsafe {
            libc::getsockopt(
                fd,
                libc::IPPROTO_TCP,
                libc::TCP_NODELAY,
                &mut val as *mut _ as *mut libc::c_void,
                &mut len,
            )
        };
        // Without this check a failed query would leave `val` at 0 and show up
        // as a misleading "option not set" failure.
        assert_eq!(rc, 0, "getsockopt(TCP_NODELAY) failed");
        assert_eq!(val, 1, "TCP_NODELAY should be 1");
    }

    /// Connecting to an address that does not answer fails within the
    /// configured timeout rather than hanging.
    #[test]
    fn connect_timeout() {
        let config = TcpClientConfig {
            name: "timeout-test".into(),
            // 192.0.2.1 lies in a block reserved for documentation (RFC 5737).
            target_host: "192.0.2.1".into(),
            target_port: 12345,
            interface_id: InterfaceId(99),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(0),
            connect_timeout: Duration::from_millis(500),
            device: None,
        };

        let start_time = std::time::Instant::now();
        let result = try_connect(&config);
        let elapsed = start_time.elapsed();

        assert!(result.is_err());
        assert!(elapsed < Duration::from_secs(5));
    }
}