1use std::io::{self, Read, Write};
6use std::net::{TcpStream, ToSocketAddrs};
7use std::os::unix::io::AsRawFd;
8use std::sync::{Arc, Mutex};
9use std::thread;
10use std::time::Duration;
11
12use rns_core::transport::types::InterfaceId;
13
14use crate::event::{Event, EventSender};
15use crate::hdlc;
16use crate::interface::Writer;
17
/// Configuration of a single TCP client interface.
///
/// The connection code takes the target address from `target_host` /
/// `target_port` below, but reads `connect_timeout`, `reconnect_wait` and
/// `max_reconnect_tries` from the shared `runtime` snapshot rather than from
/// the same-named fields here.
#[derive(Debug, Clone)]
pub struct TcpClientConfig {
    /// Interface name; used as the `[name]` prefix of log messages.
    pub name: String,
    /// Host name or IP address to connect to.
    pub target_host: String,
    /// TCP port to connect to.
    pub target_port: u16,
    /// Identifier attached to every event this interface emits.
    pub interface_id: InterfaceId,
    /// Initial delay before each reconnect attempt (the live value is in `runtime`).
    pub reconnect_wait: Duration,
    /// Initial reconnect attempt limit; `None` retries forever (the live value is in `runtime`).
    pub max_reconnect_tries: Option<u32>,
    /// Initial connect timeout (the live value is in `runtime`).
    pub connect_timeout: Duration,
    /// Network device to bind the socket to. Only honoured on Linux; ignored elsewhere.
    pub device: Option<String>,
    /// Shared, lock-protected copy of the tunable settings read by the connect/reconnect code.
    pub runtime: Arc<Mutex<TcpClientRuntime>>,
}
32
/// Tunable connection settings of a TCP client, shared behind `Arc<Mutex<_>>`.
///
/// `reconnect` re-reads this struct before every attempt, so values stored
/// through the mutex affect subsequent attempts.
#[derive(Debug, Clone)]
pub struct TcpClientRuntime {
    /// NOTE(review): `try_connect` takes host/port from `TcpClientConfig`, not
    /// from here; confirm whether these two fields are meant to be live-updatable.
    pub target_host: String,
    pub target_port: u16,
    /// Delay slept before each reconnect attempt.
    pub reconnect_wait: Duration,
    /// Maximum number of reconnect attempts; `None` retries forever.
    pub max_reconnect_tries: Option<u32>,
    /// Timeout passed to each connection attempt.
    pub connect_timeout: Duration,
}
41
42impl TcpClientRuntime {
43 pub fn from_config(config: &TcpClientConfig) -> Self {
44 Self {
45 target_host: config.target_host.clone(),
46 target_port: config.target_port,
47 reconnect_wait: config.reconnect_wait,
48 max_reconnect_tries: config.max_reconnect_tries,
49 connect_timeout: config.connect_timeout,
50 }
51 }
52}
53
/// Handle pairing a TCP client's live runtime settings with the values it started with.
#[derive(Debug, Clone)]
pub struct TcpClientRuntimeConfigHandle {
    /// Name of the interface the handle belongs to.
    pub interface_name: String,
    /// The same shared runtime the interface's connection code reads.
    pub runtime: Arc<Mutex<TcpClientRuntime>>,
    /// Settings derived from the config's own fields when the handle was created.
    pub startup: TcpClientRuntime,
}
60
61impl Default for TcpClientConfig {
62 fn default() -> Self {
63 let mut config = TcpClientConfig {
64 name: String::new(),
65 target_host: "127.0.0.1".into(),
66 target_port: 4242,
67 interface_id: InterfaceId(0),
68 reconnect_wait: Duration::from_secs(5),
69 max_reconnect_tries: None,
70 connect_timeout: Duration::from_secs(5),
71 device: None,
72 runtime: Arc::new(Mutex::new(TcpClientRuntime {
73 target_host: "127.0.0.1".into(),
74 target_port: 4242,
75 reconnect_wait: Duration::from_secs(5),
76 max_reconnect_tries: None,
77 connect_timeout: Duration::from_secs(5),
78 })),
79 };
80 let startup = TcpClientRuntime::from_config(&config);
81 config.runtime = Arc::new(Mutex::new(startup));
82 config
83 }
84}
85
/// [`Writer`] that sends HDLC-framed packets over a TCP stream.
struct TcpWriter {
    /// Socket handle the frames are written to.
    stream: TcpStream,
}
90
91impl Writer for TcpWriter {
92 fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
93 self.stream.write_all(&hdlc::frame(data))
94 }
95}
96
97fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
99 let fd = stream.as_raw_fd();
100 unsafe {
101 let val: libc::c_int = 1;
103 if libc::setsockopt(
104 fd,
105 libc::IPPROTO_TCP,
106 libc::TCP_NODELAY,
107 &val as *const _ as *const libc::c_void,
108 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
109 ) != 0
110 {
111 return Err(io::Error::last_os_error());
112 }
113
114 if libc::setsockopt(
116 fd,
117 libc::SOL_SOCKET,
118 libc::SO_KEEPALIVE,
119 &val as *const _ as *const libc::c_void,
120 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
121 ) != 0
122 {
123 return Err(io::Error::last_os_error());
124 }
125
126 #[cfg(target_os = "linux")]
128 {
129 let idle: libc::c_int = 5;
131 if libc::setsockopt(
132 fd,
133 libc::IPPROTO_TCP,
134 libc::TCP_KEEPIDLE,
135 &idle as *const _ as *const libc::c_void,
136 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
137 ) != 0
138 {
139 return Err(io::Error::last_os_error());
140 }
141
142 let intvl: libc::c_int = 2;
144 if libc::setsockopt(
145 fd,
146 libc::IPPROTO_TCP,
147 libc::TCP_KEEPINTVL,
148 &intvl as *const _ as *const libc::c_void,
149 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
150 ) != 0
151 {
152 return Err(io::Error::last_os_error());
153 }
154
155 let cnt: libc::c_int = 12;
157 if libc::setsockopt(
158 fd,
159 libc::IPPROTO_TCP,
160 libc::TCP_KEEPCNT,
161 &cnt as *const _ as *const libc::c_void,
162 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
163 ) != 0
164 {
165 return Err(io::Error::last_os_error());
166 }
167
168 let timeout: libc::c_int = 24_000;
170 if libc::setsockopt(
171 fd,
172 libc::IPPROTO_TCP,
173 libc::TCP_USER_TIMEOUT,
174 &timeout as *const _ as *const libc::c_void,
175 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
176 ) != 0
177 {
178 return Err(io::Error::last_os_error());
179 }
180 }
181 }
182 Ok(())
183}
184
185fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
187 let runtime = config.runtime.lock().unwrap().clone();
188 let addr_str = format!("{}:{}", config.target_host, config.target_port);
189 let addr = addr_str
190 .to_socket_addrs()?
191 .next()
192 .ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
193
194 #[cfg(target_os = "linux")]
195 let stream = if let Some(ref device) = config.device {
196 connect_with_device(&addr, device, runtime.connect_timeout)?
197 } else {
198 TcpStream::connect_timeout(&addr, runtime.connect_timeout)?
199 };
200 #[cfg(not(target_os = "linux"))]
201 let stream = TcpStream::connect_timeout(&addr, runtime.connect_timeout)?;
202 set_socket_options(&stream)?;
203 Ok(stream)
204}
205
206#[cfg(target_os = "linux")]
208fn connect_with_device(
209 addr: &std::net::SocketAddr,
210 device: &str,
211 timeout: Duration,
212) -> io::Result<TcpStream> {
213 use std::os::unix::io::{FromRawFd, RawFd};
214
215 let domain = if addr.is_ipv4() {
216 libc::AF_INET
217 } else {
218 libc::AF_INET6
219 };
220 let fd: RawFd = unsafe { libc::socket(domain, libc::SOCK_STREAM, 0) };
221 if fd < 0 {
222 return Err(io::Error::last_os_error());
223 }
224
225 let stream = unsafe { TcpStream::from_raw_fd(fd) };
227
228 super::bind_to_device(stream.as_raw_fd(), device)?;
229
230 stream.set_nonblocking(true)?;
232
233 let (sockaddr, socklen) = socket_addr_to_raw(addr);
234 let ret = unsafe {
235 libc::connect(
236 stream.as_raw_fd(),
237 &sockaddr as *const libc::sockaddr_storage as *const libc::sockaddr,
238 socklen,
239 )
240 };
241
242 if ret != 0 {
243 let err = io::Error::last_os_error();
244 if err.raw_os_error() != Some(libc::EINPROGRESS) {
245 return Err(err);
246 }
247 }
248
249 let mut pollfd = libc::pollfd {
251 fd: stream.as_raw_fd(),
252 events: libc::POLLOUT,
253 revents: 0,
254 };
255 let timeout_ms = timeout.as_millis() as libc::c_int;
256 let poll_ret = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
257
258 if poll_ret == 0 {
259 return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
260 }
261 if poll_ret < 0 {
262 return Err(io::Error::last_os_error());
263 }
264
265 let mut err_val: libc::c_int = 0;
267 let mut err_len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
268 let ret = unsafe {
269 libc::getsockopt(
270 stream.as_raw_fd(),
271 libc::SOL_SOCKET,
272 libc::SO_ERROR,
273 &mut err_val as *mut _ as *mut libc::c_void,
274 &mut err_len,
275 )
276 };
277 if ret != 0 {
278 return Err(io::Error::last_os_error());
279 }
280 if err_val != 0 {
281 return Err(io::Error::from_raw_os_error(err_val));
282 }
283
284 stream.set_nonblocking(false)?;
286
287 Ok(stream)
288}
289
290#[cfg(target_os = "linux")]
292fn socket_addr_to_raw(addr: &std::net::SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
293 let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
294 match addr {
295 std::net::SocketAddr::V4(v4) => {
296 let sin: &mut libc::sockaddr_in = unsafe {
297 &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in)
298 };
299 sin.sin_family = libc::AF_INET as libc::sa_family_t;
300 sin.sin_port = v4.port().to_be();
301 sin.sin_addr = libc::in_addr {
302 s_addr: u32::from_ne_bytes(v4.ip().octets()),
303 };
304 (
305 storage,
306 std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
307 )
308 }
309 std::net::SocketAddr::V6(v6) => {
310 let sin6: &mut libc::sockaddr_in6 = unsafe {
311 &mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6)
312 };
313 sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
314 sin6.sin6_port = v6.port().to_be();
315 sin6.sin6_addr = libc::in6_addr {
316 s6_addr: v6.ip().octets(),
317 };
318 sin6.sin6_flowinfo = v6.flowinfo();
319 sin6.sin6_scope_id = v6.scope_id();
320 (
321 storage,
322 std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
323 )
324 }
325 }
326}
327
328pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
330 let stream = try_connect(&config)?;
331 let reader_stream = stream.try_clone()?;
332 let writer_stream = stream.try_clone()?;
333
334 let id = config.interface_id;
335 let _ = tx.send(Event::InterfaceUp(id, None, None));
337
338 let reader_config = config;
340 let reader_tx = tx;
341 thread::Builder::new()
342 .name(format!("tcp-reader-{}", id.0))
343 .spawn(move || {
344 reader_loop(reader_stream, reader_config, reader_tx);
345 })?;
346
347 Ok(Box::new(TcpWriter {
348 stream: writer_stream,
349 }))
350}
351
352fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
355 let id = config.interface_id;
356 let mut decoder = hdlc::Decoder::new();
357 let mut buf = [0u8; 4096];
358
359 loop {
360 match stream.read(&mut buf) {
361 Ok(0) => {
362 log::warn!("[{}] connection closed", config.name);
364 let _ = tx.send(Event::InterfaceDown(id));
365 match reconnect(&config, &tx) {
366 Some(new_stream) => {
367 stream = new_stream;
368 decoder = hdlc::Decoder::new();
369 continue;
370 }
371 None => {
372 log::error!("[{}] reconnection failed, giving up", config.name);
373 return;
374 }
375 }
376 }
377 Ok(n) => {
378 for frame in decoder.feed(&buf[..n]) {
379 if tx
380 .send(Event::Frame {
381 interface_id: id,
382 data: frame,
383 })
384 .is_err()
385 {
386 return;
388 }
389 }
390 }
391 Err(e) => {
392 log::warn!("[{}] read error: {}", config.name, e);
393 let _ = tx.send(Event::InterfaceDown(id));
394 match reconnect(&config, &tx) {
395 Some(new_stream) => {
396 stream = new_stream;
397 decoder = hdlc::Decoder::new();
398 continue;
399 }
400 None => {
401 log::error!("[{}] reconnection failed, giving up", config.name);
402 return;
403 }
404 }
405 }
406 }
407 }
408}
409
410fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
413 let mut attempts = 0u32;
414 loop {
415 let runtime = config.runtime.lock().unwrap().clone();
416 thread::sleep(runtime.reconnect_wait);
417 attempts += 1;
418
419 if let Some(max) = runtime.max_reconnect_tries {
420 if attempts > max {
421 let _ = tx.send(Event::InterfaceDown(config.interface_id));
422 return None;
423 }
424 }
425
426 log::info!("[{}] reconnect attempt {} ...", config.name, attempts);
427
428 match try_connect(config) {
429 Ok(new_stream) => {
430 let writer_stream = match new_stream.try_clone() {
432 Ok(s) => s,
433 Err(e) => {
434 log::warn!("[{}] failed to clone stream: {}", config.name, e);
435 continue;
436 }
437 };
438 log::info!("[{}] reconnected", config.name);
439 let new_writer: Box<dyn Writer> = Box::new(TcpWriter {
441 stream: writer_stream,
442 });
443 let _ = tx.send(Event::InterfaceUp(
444 config.interface_id,
445 Some(new_writer),
446 None,
447 ));
448 return Some(new_stream);
449 }
450 Err(e) => {
451 log::warn!("[{}] reconnect failed: {}", config.name, e);
452 }
453 }
454 }
455}
456
457use super::{InterfaceConfigData, InterfaceFactory, StartContext, StartResult};
460use rns_core::transport::types::InterfaceInfo;
461use std::collections::HashMap;
462
463pub struct TcpClientFactory;
465
466impl InterfaceFactory for TcpClientFactory {
467 fn type_name(&self) -> &str {
468 "TCPClientInterface"
469 }
470
471 fn parse_config(
472 &self,
473 name: &str,
474 id: InterfaceId,
475 params: &HashMap<String, String>,
476 ) -> Result<Box<dyn InterfaceConfigData>, String> {
477 let target_host = params
478 .get("target_host")
479 .cloned()
480 .unwrap_or_else(|| "127.0.0.1".into());
481 let target_port = params
482 .get("target_port")
483 .and_then(|v| v.parse().ok())
484 .unwrap_or(4242);
485
486 Ok(Box::new(TcpClientConfig {
487 name: name.to_string(),
488 target_host,
489 target_port,
490 interface_id: id,
491 device: params.get("device").cloned(),
492 ..TcpClientConfig::default()
493 }))
494 }
495
496 fn start(
497 &self,
498 config: Box<dyn InterfaceConfigData>,
499 ctx: StartContext,
500 ) -> io::Result<StartResult> {
501 let tcp_config = *config
502 .into_any()
503 .downcast::<TcpClientConfig>()
504 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "wrong config type"))?;
505
506 let id = tcp_config.interface_id;
507 let name = tcp_config.name.clone();
508 let info = InterfaceInfo {
509 id,
510 name,
511 mode: ctx.mode,
512 out_capable: true,
513 in_capable: true,
514 bitrate: None,
515 announce_rate_target: None,
516 announce_rate_grace: 0,
517 announce_rate_penalty: 0.0,
518 announce_cap: rns_core::constants::ANNOUNCE_CAP,
519 is_local_client: false,
520 wants_tunnel: false,
521 tunnel_id: None,
522 mtu: 65535,
523 ingress_control: true,
524 ia_freq: 0.0,
525 started: crate::time::now(),
526 };
527
528 let writer = start(tcp_config, ctx.tx)?;
529
530 Ok(StartResult::Simple {
531 id,
532 info,
533 writer,
534 interface_type_name: "TCPClientInterface".to_string(),
535 })
536 }
537}
538
539pub(crate) fn tcp_client_runtime_handle_from_config(
540 config: &TcpClientConfig,
541) -> TcpClientRuntimeConfigHandle {
542 TcpClientRuntimeConfigHandle {
543 interface_name: config.name.clone(),
544 runtime: Arc::clone(&config.runtime),
545 startup: TcpClientRuntime::from_config(config),
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::time::Duration;

    /// Bind a loopback listener on an OS-assigned port and return it with that port.
    ///
    /// The listener stays bound for the whole test. This avoids the race of
    /// probing for a free port, releasing it, and binding it again later.
    fn bind_listener() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Rebuild `config.runtime` from the config's own fields so both agree.
    fn sync_runtime(mut config: TcpClientConfig) -> TcpClientConfig {
        config.runtime = Arc::new(Mutex::new(TcpClientRuntime::from_config(&config)));
        config
    }

    /// Client config for `127.0.0.1:port` with short waits and at most two reconnect attempts.
    fn make_config(port: u16) -> TcpClientConfig {
        sync_runtime(TcpClientConfig {
            name: format!("test-tcp-{}", port),
            target_host: "127.0.0.1".into(),
            target_port: port,
            interface_id: InterfaceId(1),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(2),
            connect_timeout: Duration::from_secs(2),
            device: None,
            ..TcpClientConfig::default()
        })
    }

    /// Read an integer socket option from `stream`, failing the test if `getsockopt` fails.
    fn get_int_sockopt(stream: &TcpStream, level: libc::c_int, name: libc::c_int) -> libc::c_int {
        let mut val: libc::c_int = 0;
        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: `val` and `len` are valid, correctly sized out-parameters for an int option.
        let ret = unsafe {
            libc::getsockopt(
                stream.as_raw_fd(),
                level,
                name,
                &mut val as *mut libc::c_int as *mut libc::c_void,
                &mut len,
            )
        };
        assert_eq!(ret, 0, "getsockopt failed: {}", io::Error::last_os_error());
        val
    }

    #[test]
    fn connect_to_listener() {
        let (listener, port) = bind_listener();
        let (tx, rx) = crate::event::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let _server_stream = listener.accept().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn receive_frame() {
        let (listener, port) = bind_listener();
        let (tx, rx) = crate::event::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        // Consume the InterfaceUp event sent on connect.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        server_stream.write_all(&hdlc::frame(&payload)).unwrap();

        match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(1));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn send_frame() {
        let (listener, port) = bind_listener();
        // Keep the receiver alive so the reader thread's sends do not fail.
        let (tx, _rx) = crate::event::channel();

        let mut writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();
        server_stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();

        let payload: Vec<u8> = (0..24).collect();
        writer.send_frame(&payload).unwrap();

        // read_exact: a single read() may legally return only part of the frame.
        let expected = hdlc::frame(&payload);
        let mut received = vec![0u8; expected.len()];
        server_stream.read_exact(&mut received).unwrap();
        assert_eq!(&received[..], &expected[..]);
    }

    #[test]
    fn multiple_frames() {
        let (listener, port) = bind_listener();
        let (tx, rx) = crate::event::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payloads: Vec<Vec<u8>> = (0..3)
            .map(|i| (0..24).map(|j| j + i * 50).collect())
            .collect();
        for p in &payloads {
            server_stream.write_all(&hdlc::frame(p)).unwrap();
        }

        for expected in &payloads {
            match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
                Event::Frame { data, .. } => assert_eq!(&data, expected),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    #[test]
    fn split_across_reads() {
        let (listener, port) = bind_listener();
        let (tx, rx) = crate::event::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        let mid = framed.len() / 2;

        // Deliver the frame in two halves with a pause, so the decoder must buffer.
        server_stream.write_all(&framed[..mid]).unwrap();
        server_stream.flush().unwrap();
        thread::sleep(Duration::from_millis(50));
        server_stream.write_all(&framed[mid..]).unwrap();
        server_stream.flush().unwrap();

        match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
            Event::Frame { data, .. } => assert_eq!(data, payload),
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn reconnect_on_close() {
        let (listener, port) = bind_listener();
        let (tx, rx) = crate::event::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (server_stream, _) = listener.accept().unwrap();

        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        // Closing the server side ends the client's connection.
        drop(server_stream);
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));

        // The client reconnects to the still-bound listener after `reconnect_wait`.
        let _server_stream2 = listener.accept().unwrap();
        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn socket_options() {
        let (listener, port) = bind_listener();

        let stream = try_connect(&make_config(port)).unwrap();
        let _server = listener.accept().unwrap();

        // Boolean options are only guaranteed to read back as non-zero, not as 1.
        assert_ne!(
            get_int_sockopt(&stream, libc::IPPROTO_TCP, libc::TCP_NODELAY),
            0,
            "TCP_NODELAY should be enabled"
        );
        assert_ne!(
            get_int_sockopt(&stream, libc::SOL_SOCKET, libc::SO_KEEPALIVE),
            0,
            "SO_KEEPALIVE should be enabled"
        );

        #[cfg(target_os = "linux")]
        {
            assert_eq!(get_int_sockopt(&stream, libc::IPPROTO_TCP, libc::TCP_KEEPIDLE), 5);
            assert_eq!(get_int_sockopt(&stream, libc::IPPROTO_TCP, libc::TCP_KEEPINTVL), 2);
            assert_eq!(get_int_sockopt(&stream, libc::IPPROTO_TCP, libc::TCP_KEEPCNT), 12);
            assert_eq!(
                get_int_sockopt(&stream, libc::IPPROTO_TCP, libc::TCP_USER_TIMEOUT),
                24_000
            );
        }
    }

    #[test]
    fn connect_timeout() {
        // 192.0.2.1 is in TEST-NET-1 (RFC 5737) and is not expected to answer, so
        // the attempt should end by timeout or an immediate network error.
        let config = sync_runtime(TcpClientConfig {
            name: "timeout-test".into(),
            target_host: "192.0.2.1".into(),
            target_port: 12345,
            interface_id: InterfaceId(99),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(0),
            connect_timeout: Duration::from_millis(500),
            device: None,
            ..TcpClientConfig::default()
        });

        let start_time = std::time::Instant::now();
        let result = try_connect(&config);
        let elapsed = start_time.elapsed();

        assert!(result.is_err());
        assert!(elapsed < Duration::from_secs(5));
    }
}