1use std::io::{self, Read, Write};
6use std::net::{TcpStream, ToSocketAddrs};
7use std::os::unix::io::AsRawFd;
8use std::thread;
9use std::time::Duration;
10
11use rns_core::transport::types::InterfaceId;
12
13use crate::event::{Event, EventSender};
14use crate::hdlc;
15use crate::interface::Writer;
16
17#[derive(Debug, Clone)]
19pub struct TcpClientConfig {
20 pub name: String,
21 pub target_host: String,
22 pub target_port: u16,
23 pub interface_id: InterfaceId,
24 pub reconnect_wait: Duration,
25 pub max_reconnect_tries: Option<u32>,
26 pub connect_timeout: Duration,
27 pub device: Option<String>,
29}
30
31impl Default for TcpClientConfig {
32 fn default() -> Self {
33 TcpClientConfig {
34 name: String::new(),
35 target_host: "127.0.0.1".into(),
36 target_port: 4242,
37 interface_id: InterfaceId(0),
38 reconnect_wait: Duration::from_secs(5),
39 max_reconnect_tries: None,
40 connect_timeout: Duration::from_secs(5),
41 device: None,
42 }
43 }
44}
45
46struct TcpWriter {
48 stream: TcpStream,
49}
50
51impl Writer for TcpWriter {
52 fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
53 self.stream.write_all(&hdlc::frame(data))
54 }
55}
56
57fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
59 let fd = stream.as_raw_fd();
60 unsafe {
61 let val: libc::c_int = 1;
63 if libc::setsockopt(
64 fd,
65 libc::IPPROTO_TCP,
66 libc::TCP_NODELAY,
67 &val as *const _ as *const libc::c_void,
68 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
69 ) != 0
70 {
71 return Err(io::Error::last_os_error());
72 }
73
74 if libc::setsockopt(
76 fd,
77 libc::SOL_SOCKET,
78 libc::SO_KEEPALIVE,
79 &val as *const _ as *const libc::c_void,
80 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
81 ) != 0
82 {
83 return Err(io::Error::last_os_error());
84 }
85
86 #[cfg(target_os = "linux")]
88 {
89 let idle: libc::c_int = 5;
91 if libc::setsockopt(
92 fd,
93 libc::IPPROTO_TCP,
94 libc::TCP_KEEPIDLE,
95 &idle as *const _ as *const libc::c_void,
96 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
97 ) != 0
98 {
99 return Err(io::Error::last_os_error());
100 }
101
102 let intvl: libc::c_int = 2;
104 if libc::setsockopt(
105 fd,
106 libc::IPPROTO_TCP,
107 libc::TCP_KEEPINTVL,
108 &intvl as *const _ as *const libc::c_void,
109 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
110 ) != 0
111 {
112 return Err(io::Error::last_os_error());
113 }
114
115 let cnt: libc::c_int = 12;
117 if libc::setsockopt(
118 fd,
119 libc::IPPROTO_TCP,
120 libc::TCP_KEEPCNT,
121 &cnt as *const _ as *const libc::c_void,
122 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
123 ) != 0
124 {
125 return Err(io::Error::last_os_error());
126 }
127
128 let timeout: libc::c_int = 24_000;
130 if libc::setsockopt(
131 fd,
132 libc::IPPROTO_TCP,
133 libc::TCP_USER_TIMEOUT,
134 &timeout as *const _ as *const libc::c_void,
135 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
136 ) != 0
137 {
138 return Err(io::Error::last_os_error());
139 }
140 }
141 }
142 Ok(())
143}
144
145fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
147 let addr_str = format!("{}:{}", config.target_host, config.target_port);
148 let addr = addr_str
149 .to_socket_addrs()?
150 .next()
151 .ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
152
153 #[cfg(target_os = "linux")]
154 let stream = if let Some(ref device) = config.device {
155 connect_with_device(&addr, device, config.connect_timeout)?
156 } else {
157 TcpStream::connect_timeout(&addr, config.connect_timeout)?
158 };
159 #[cfg(not(target_os = "linux"))]
160 let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
161 set_socket_options(&stream)?;
162 Ok(stream)
163}
164
165#[cfg(target_os = "linux")]
167fn connect_with_device(
168 addr: &std::net::SocketAddr,
169 device: &str,
170 timeout: Duration,
171) -> io::Result<TcpStream> {
172 use std::os::unix::io::{FromRawFd, RawFd};
173
174 let domain = if addr.is_ipv4() {
175 libc::AF_INET
176 } else {
177 libc::AF_INET6
178 };
179 let fd: RawFd = unsafe { libc::socket(domain, libc::SOCK_STREAM, 0) };
180 if fd < 0 {
181 return Err(io::Error::last_os_error());
182 }
183
184 let stream = unsafe { TcpStream::from_raw_fd(fd) };
186
187 super::bind_to_device(stream.as_raw_fd(), device)?;
188
189 stream.set_nonblocking(true)?;
191
192 let (sockaddr, socklen) = socket_addr_to_raw(addr);
193 let ret = unsafe {
194 libc::connect(
195 stream.as_raw_fd(),
196 &sockaddr as *const libc::sockaddr_storage as *const libc::sockaddr,
197 socklen,
198 )
199 };
200
201 if ret != 0 {
202 let err = io::Error::last_os_error();
203 if err.raw_os_error() != Some(libc::EINPROGRESS) {
204 return Err(err);
205 }
206 }
207
208 let mut pollfd = libc::pollfd {
210 fd: stream.as_raw_fd(),
211 events: libc::POLLOUT,
212 revents: 0,
213 };
214 let timeout_ms = timeout.as_millis() as libc::c_int;
215 let poll_ret = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
216
217 if poll_ret == 0 {
218 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
219 }
220 if poll_ret < 0 {
221 return Err(io::Error::last_os_error());
222 }
223
224 let mut err_val: libc::c_int = 0;
226 let mut err_len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
227 let ret = unsafe {
228 libc::getsockopt(
229 stream.as_raw_fd(),
230 libc::SOL_SOCKET,
231 libc::SO_ERROR,
232 &mut err_val as *mut _ as *mut libc::c_void,
233 &mut err_len,
234 )
235 };
236 if ret != 0 {
237 return Err(io::Error::last_os_error());
238 }
239 if err_val != 0 {
240 return Err(io::Error::from_raw_os_error(err_val));
241 }
242
243 stream.set_nonblocking(false)?;
245
246 Ok(stream)
247}
248
249#[cfg(target_os = "linux")]
251fn socket_addr_to_raw(addr: &std::net::SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
252 let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
253 match addr {
254 std::net::SocketAddr::V4(v4) => {
255 let sin: &mut libc::sockaddr_in = unsafe {
256 &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in)
257 };
258 sin.sin_family = libc::AF_INET as libc::sa_family_t;
259 sin.sin_port = v4.port().to_be();
260 sin.sin_addr = libc::in_addr {
261 s_addr: u32::from_ne_bytes(v4.ip().octets()),
262 };
263 (
264 storage,
265 std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
266 )
267 }
268 std::net::SocketAddr::V6(v6) => {
269 let sin6: &mut libc::sockaddr_in6 = unsafe {
270 &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6)
271 };
272 sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
273 sin6.sin6_port = v6.port().to_be();
274 sin6.sin6_addr = libc::in6_addr {
275 s6_addr: v6.ip().octets(),
276 };
277 sin6.sin6_flowinfo = v6.flowinfo();
278 sin6.sin6_scope_id = v6.scope_id();
279 (
280 storage,
281 std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
282 )
283 }
284 }
285}
286
287pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
289 let stream = try_connect(&config)?;
290 let reader_stream = stream.try_clone()?;
291 let writer_stream = stream.try_clone()?;
292
293 let id = config.interface_id;
294 let _ = tx.send(Event::InterfaceUp(id, None, None));
296
297 let reader_config = config;
299 let reader_tx = tx;
300 thread::Builder::new()
301 .name(format!("tcp-reader-{}", id.0))
302 .spawn(move || {
303 reader_loop(reader_stream, reader_config, reader_tx);
304 })?;
305
306 Ok(Box::new(TcpWriter {
307 stream: writer_stream,
308 }))
309}
310
311fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
314 let id = config.interface_id;
315 let mut decoder = hdlc::Decoder::new();
316 let mut buf = [0u8; 4096];
317
318 loop {
319 match stream.read(&mut buf) {
320 Ok(0) => {
321 log::warn!("[{}] connection closed", config.name);
323 let _ = tx.send(Event::InterfaceDown(id));
324 match reconnect(&config, &tx) {
325 Some(new_stream) => {
326 stream = new_stream;
327 decoder = hdlc::Decoder::new();
328 continue;
329 }
330 None => {
331 log::error!("[{}] reconnection failed, giving up", config.name);
332 return;
333 }
334 }
335 }
336 Ok(n) => {
337 for frame in decoder.feed(&buf[..n]) {
338 if tx
339 .send(Event::Frame {
340 interface_id: id,
341 data: frame,
342 })
343 .is_err()
344 {
345 return;
347 }
348 }
349 }
350 Err(e) => {
351 log::warn!("[{}] read error: {}", config.name, e);
352 let _ = tx.send(Event::InterfaceDown(id));
353 match reconnect(&config, &tx) {
354 Some(new_stream) => {
355 stream = new_stream;
356 decoder = hdlc::Decoder::new();
357 continue;
358 }
359 None => {
360 log::error!("[{}] reconnection failed, giving up", config.name);
361 return;
362 }
363 }
364 }
365 }
366 }
367}
368
369fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
372 let mut attempts = 0u32;
373 loop {
374 thread::sleep(config.reconnect_wait);
375 attempts += 1;
376
377 if let Some(max) = config.max_reconnect_tries {
378 if attempts > max {
379 let _ = tx.send(Event::InterfaceDown(config.interface_id));
380 return None;
381 }
382 }
383
384 log::info!("[{}] reconnect attempt {} ...", config.name, attempts);
385
386 match try_connect(config) {
387 Ok(new_stream) => {
388 let writer_stream = match new_stream.try_clone() {
390 Ok(s) => s,
391 Err(e) => {
392 log::warn!("[{}] failed to clone stream: {}", config.name, e);
393 continue;
394 }
395 };
396 log::info!("[{}] reconnected", config.name);
397 let new_writer: Box<dyn Writer> = Box::new(TcpWriter {
399 stream: writer_stream,
400 });
401 let _ = tx.send(Event::InterfaceUp(
402 config.interface_id,
403 Some(new_writer),
404 None,
405 ));
406 return Some(new_stream);
407 }
408 Err(e) => {
409 log::warn!("[{}] reconnect failed: {}", config.name, e);
410 }
411 }
412 }
413}
414
415use super::{InterfaceConfigData, InterfaceFactory, StartContext, StartResult};
418use rns_core::transport::types::InterfaceInfo;
419use std::collections::HashMap;
420
421pub struct TcpClientFactory;
423
424impl InterfaceFactory for TcpClientFactory {
425 fn type_name(&self) -> &str {
426 "TCPClientInterface"
427 }
428
429 fn parse_config(
430 &self,
431 name: &str,
432 id: InterfaceId,
433 params: &HashMap<String, String>,
434 ) -> Result<Box<dyn InterfaceConfigData>, String> {
435 let target_host = params
436 .get("target_host")
437 .cloned()
438 .unwrap_or_else(|| "127.0.0.1".into());
439 let target_port = params
440 .get("target_port")
441 .and_then(|v| v.parse().ok())
442 .unwrap_or(4242);
443
444 Ok(Box::new(TcpClientConfig {
445 name: name.to_string(),
446 target_host,
447 target_port,
448 interface_id: id,
449 device: params.get("device").cloned(),
450 ..TcpClientConfig::default()
451 }))
452 }
453
454 fn start(
455 &self,
456 config: Box<dyn InterfaceConfigData>,
457 ctx: StartContext,
458 ) -> io::Result<StartResult> {
459 let tcp_config = *config
460 .into_any()
461 .downcast::<TcpClientConfig>()
462 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "wrong config type"))?;
463
464 let id = tcp_config.interface_id;
465 let name = tcp_config.name.clone();
466 let info = InterfaceInfo {
467 id,
468 name,
469 mode: ctx.mode,
470 out_capable: true,
471 in_capable: true,
472 bitrate: None,
473 announce_rate_target: None,
474 announce_rate_grace: 0,
475 announce_rate_penalty: 0.0,
476 announce_cap: rns_core::constants::ANNOUNCE_CAP,
477 is_local_client: false,
478 wants_tunnel: false,
479 tunnel_id: None,
480 mtu: 65535,
481 ingress_control: true,
482 ia_freq: 0.0,
483 started: crate::time::now(),
484 };
485
486 let writer = start(tcp_config, ctx.tx)?;
487
488 Ok(StartResult::Simple {
489 id,
490 info,
491 writer,
492 interface_type_name: "TCPClientInterface".to_string(),
493 })
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Binds a loopback listener on an OS-assigned port and returns it with that port.
    ///
    /// Binding to port 0 directly avoids the race in "probe a free port, close it, bind it
    /// again", where a test running in parallel can take the port in between.
    fn bind_listener() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Client config aimed at `127.0.0.1:port` with fast, bounded reconnection.
    fn make_config(port: u16) -> TcpClientConfig {
        TcpClientConfig {
            name: format!("test-tcp-{}", port),
            target_host: "127.0.0.1".into(),
            target_port: port,
            interface_id: InterfaceId(1),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(2),
            connect_timeout: Duration::from_secs(2),
            device: None,
        }
    }

    #[test]
    fn connect_to_listener() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let _server_stream = listener.accept().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn receive_frame() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        // Skip the InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        server_stream.write_all(&hdlc::frame(&payload)).unwrap();

        match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(1));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn send_frame() {
        let (listener, port) = bind_listener();
        let (tx, _rx) = mpsc::channel();

        let mut writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();
        server_stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();

        let payload: Vec<u8> = (0..24).collect();
        writer.send_frame(&payload).unwrap();

        // `read_exact` rather than one `read`: TCP may deliver the frame in several pieces.
        let expected = hdlc::frame(&payload);
        let mut received = vec![0u8; expected.len()];
        server_stream.read_exact(&mut received).unwrap();
        assert_eq!(&received[..], &expected[..]);
    }

    #[test]
    fn multiple_frames() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        // Skip the InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payloads: Vec<Vec<u8>> = (0..3)
            .map(|i| (0..24).map(|j| j + i * 50).collect())
            .collect();
        for p in &payloads {
            server_stream.write_all(&hdlc::frame(p)).unwrap();
        }

        for expected in &payloads {
            match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
                Event::Frame { data, .. } => assert_eq!(&data, expected),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    #[test]
    fn split_across_reads() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        // Skip the InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        let mid = framed.len() / 2;

        server_stream.write_all(&framed[..mid]).unwrap();
        server_stream.flush().unwrap();
        thread::sleep(Duration::from_millis(50));
        server_stream.write_all(&framed[mid..]).unwrap();
        server_stream.flush().unwrap();

        match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
            Event::Frame { data, .. } => assert_eq!(data, payload),
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn reconnect_on_close() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (server_stream, _) = listener.accept().unwrap();

        // Skip the initial InterfaceUp event.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        drop(server_stream);

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));

        let _server_stream2 = listener.accept().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn socket_options() {
        let (listener, port) = bind_listener();

        let stream = try_connect(&make_config(port)).unwrap();
        let _server = listener.accept().unwrap();

        let mut val: libc::c_int = 0;
        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: `val` and `len` are valid for writes and `len` matches `val`'s size.
        let rc = unsafe {
            libc::getsockopt(
                stream.as_raw_fd(),
                libc::IPPROTO_TCP,
                libc::TCP_NODELAY,
                &mut val as *mut libc::c_int as *mut libc::c_void,
                &mut len,
            )
        };
        assert_eq!(rc, 0, "getsockopt(TCP_NODELAY) failed: {}", io::Error::last_os_error());
        // Some platforms report the enabled flag as a nonzero value other than 1.
        assert_ne!(val, 0, "TCP_NODELAY should be enabled");
    }

    #[test]
    fn connect_timeout() {
        let config = TcpClientConfig {
            name: "timeout-test".into(),
            // TEST-NET-1 (RFC 5737): never routed, so the connect cannot succeed.
            target_host: "192.0.2.1".into(),
            target_port: 12345,
            interface_id: InterfaceId(99),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(0),
            connect_timeout: Duration::from_millis(500),
            device: None,
        };

        let start_time = std::time::Instant::now();
        let result = try_connect(&config);
        let elapsed = start_time.elapsed();

        assert!(result.is_err());
        assert!(elapsed < Duration::from_secs(5));
    }
}