1use std::io::{self, Read, Write};
6use std::net::{TcpStream, ToSocketAddrs};
7use std::os::unix::io::AsRawFd;
8use std::thread;
9use std::time::Duration;
10
11use rns_core::transport::types::InterfaceId;
12
13use crate::event::{Event, EventSender};
14use crate::hdlc;
15use crate::interface::Writer;
16
17#[derive(Debug, Clone)]
19pub struct TcpClientConfig {
20 pub name: String,
21 pub target_host: String,
22 pub target_port: u16,
23 pub interface_id: InterfaceId,
24 pub reconnect_wait: Duration,
25 pub max_reconnect_tries: Option<u32>,
26 pub connect_timeout: Duration,
27}
28
29impl Default for TcpClientConfig {
30 fn default() -> Self {
31 TcpClientConfig {
32 name: String::new(),
33 target_host: "127.0.0.1".into(),
34 target_port: 4242,
35 interface_id: InterfaceId(0),
36 reconnect_wait: Duration::from_secs(5),
37 max_reconnect_tries: None,
38 connect_timeout: Duration::from_secs(5),
39 }
40 }
41}
42
43struct TcpWriter {
45 stream: TcpStream,
46}
47
48impl Writer for TcpWriter {
49 fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
50 self.stream.write_all(&hdlc::frame(data))
51 }
52}
53
54fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
56 let fd = stream.as_raw_fd();
57 unsafe {
58 let val: libc::c_int = 1;
60 if libc::setsockopt(
61 fd,
62 libc::IPPROTO_TCP,
63 libc::TCP_NODELAY,
64 &val as *const _ as *const libc::c_void,
65 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
66 ) != 0
67 {
68 return Err(io::Error::last_os_error());
69 }
70
71 if libc::setsockopt(
73 fd,
74 libc::SOL_SOCKET,
75 libc::SO_KEEPALIVE,
76 &val as *const _ as *const libc::c_void,
77 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
78 ) != 0
79 {
80 return Err(io::Error::last_os_error());
81 }
82
83 #[cfg(target_os = "linux")]
85 {
86 let idle: libc::c_int = 5;
88 if libc::setsockopt(
89 fd,
90 libc::IPPROTO_TCP,
91 libc::TCP_KEEPIDLE,
92 &idle as *const _ as *const libc::c_void,
93 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
94 ) != 0
95 {
96 return Err(io::Error::last_os_error());
97 }
98
99 let intvl: libc::c_int = 2;
101 if libc::setsockopt(
102 fd,
103 libc::IPPROTO_TCP,
104 libc::TCP_KEEPINTVL,
105 &intvl as *const _ as *const libc::c_void,
106 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
107 ) != 0
108 {
109 return Err(io::Error::last_os_error());
110 }
111
112 let cnt: libc::c_int = 12;
114 if libc::setsockopt(
115 fd,
116 libc::IPPROTO_TCP,
117 libc::TCP_KEEPCNT,
118 &cnt as *const _ as *const libc::c_void,
119 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
120 ) != 0
121 {
122 return Err(io::Error::last_os_error());
123 }
124
125 let timeout: libc::c_int = 24_000;
127 if libc::setsockopt(
128 fd,
129 libc::IPPROTO_TCP,
130 libc::TCP_USER_TIMEOUT,
131 &timeout as *const _ as *const libc::c_void,
132 std::mem::size_of::<libc::c_int>() as libc::socklen_t,
133 ) != 0
134 {
135 return Err(io::Error::last_os_error());
136 }
137 }
138 }
139 Ok(())
140}
141
142fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
144 let addr_str = format!("{}:{}", config.target_host, config.target_port);
145 let addr = addr_str
146 .to_socket_addrs()?
147 .next()
148 .ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
149
150 let stream = TcpStream::connect_timeout(&addr, config.connect_timeout)?;
151 set_socket_options(&stream)?;
152 Ok(stream)
153}
154
155pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
157 let stream = try_connect(&config)?;
158 let reader_stream = stream.try_clone()?;
159 let writer_stream = stream.try_clone()?;
160
161 let id = config.interface_id;
162 let _ = tx.send(Event::InterfaceUp(id, None, None));
164
165 let reader_config = config;
167 let reader_tx = tx;
168 thread::Builder::new()
169 .name(format!("tcp-reader-{}", id.0))
170 .spawn(move || {
171 reader_loop(reader_stream, reader_config, reader_tx);
172 })?;
173
174 Ok(Box::new(TcpWriter { stream: writer_stream }))
175}
176
177fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
180 let id = config.interface_id;
181 let mut decoder = hdlc::Decoder::new();
182 let mut buf = [0u8; 4096];
183
184 loop {
185 match stream.read(&mut buf) {
186 Ok(0) => {
187 log::warn!("[{}] connection closed", config.name);
189 let _ = tx.send(Event::InterfaceDown(id));
190 match reconnect(&config, &tx) {
191 Some(new_stream) => {
192 stream = new_stream;
193 decoder = hdlc::Decoder::new();
194 continue;
195 }
196 None => {
197 log::error!("[{}] reconnection failed, giving up", config.name);
198 return;
199 }
200 }
201 }
202 Ok(n) => {
203 for frame in decoder.feed(&buf[..n]) {
204 if tx.send(Event::Frame { interface_id: id, data: frame }).is_err() {
205 return;
207 }
208 }
209 }
210 Err(e) => {
211 log::warn!("[{}] read error: {}", config.name, e);
212 let _ = tx.send(Event::InterfaceDown(id));
213 match reconnect(&config, &tx) {
214 Some(new_stream) => {
215 stream = new_stream;
216 decoder = hdlc::Decoder::new();
217 continue;
218 }
219 None => {
220 log::error!("[{}] reconnection failed, giving up", config.name);
221 return;
222 }
223 }
224 }
225 }
226 }
227}
228
229fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
232 let mut attempts = 0u32;
233 loop {
234 thread::sleep(config.reconnect_wait);
235 attempts += 1;
236
237 if let Some(max) = config.max_reconnect_tries {
238 if attempts > max {
239 let _ = tx.send(Event::InterfaceDown(config.interface_id));
240 return None;
241 }
242 }
243
244 log::info!(
245 "[{}] reconnect attempt {} ...",
246 config.name,
247 attempts
248 );
249
250 match try_connect(config) {
251 Ok(new_stream) => {
252 let writer_stream = match new_stream.try_clone() {
254 Ok(s) => s,
255 Err(e) => {
256 log::warn!("[{}] failed to clone stream: {}", config.name, e);
257 continue;
258 }
259 };
260 log::info!("[{}] reconnected", config.name);
261 let new_writer: Box<dyn Writer> = Box::new(TcpWriter { stream: writer_stream });
263 let _ = tx.send(Event::InterfaceUp(config.interface_id, Some(new_writer), None));
264 return Some(new_stream);
265 }
266 Err(e) => {
267 log::warn!("[{}] reconnect failed: {}", config.name, e);
268 }
269 }
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::sync::mpsc;
    use std::time::Duration;

    /// Binds a loopback listener on an OS-assigned port and returns it with that port.
    ///
    /// Keeping the listener bound avoids the race in "bind :0, read the port,
    /// drop it, bind again", where another process can take the port in between.
    fn bind_listener() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Test configuration pointing at `port`, with fast reconnects and at most 2 retries.
    fn make_config(port: u16) -> TcpClientConfig {
        TcpClientConfig {
            name: format!("test-tcp-{}", port),
            target_host: "127.0.0.1".into(),
            target_port: port,
            interface_id: InterfaceId(1),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(2),
            connect_timeout: Duration::from_secs(2),
        }
    }

    #[test]
    fn connect_to_listener() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let _server_stream = listener.accept().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn receive_frame() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        // Skip the initial InterfaceUp.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        server_stream.write_all(&hdlc::frame(&payload)).unwrap();

        match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
            Event::Frame { interface_id, data } => {
                assert_eq!(interface_id, InterfaceId(1));
                assert_eq!(data, payload);
            }
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn send_frame() {
        let (listener, port) = bind_listener();
        let (tx, _rx) = mpsc::channel();

        let mut writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();
        server_stream
            .set_read_timeout(Some(Duration::from_secs(2)))
            .unwrap();

        let payload: Vec<u8> = (0..24).collect();
        writer.send_frame(&payload).unwrap();

        let mut buf = [0u8; 256];
        let n = server_stream.read(&mut buf).unwrap();
        let expected = hdlc::frame(&payload);
        assert_eq!(&buf[..n], &expected[..]);
    }

    #[test]
    fn multiple_frames() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        // Skip the initial InterfaceUp.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payloads: Vec<Vec<u8>> = (0..3)
            .map(|i| (0..24).map(|j| j + i * 50).collect())
            .collect();
        for p in &payloads {
            server_stream.write_all(&hdlc::frame(p)).unwrap();
        }

        for expected in &payloads {
            match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
                Event::Frame { data, .. } => assert_eq!(&data, expected),
                other => panic!("expected Frame, got {:?}", other),
            }
        }
    }

    #[test]
    fn split_across_reads() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (mut server_stream, _) = listener.accept().unwrap();

        // Skip the initial InterfaceUp.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        let payload: Vec<u8> = (0..32).collect();
        let framed = hdlc::frame(&payload);
        let mid = framed.len() / 2;

        server_stream.write_all(&framed[..mid]).unwrap();
        server_stream.flush().unwrap();
        thread::sleep(Duration::from_millis(50));
        server_stream.write_all(&framed[mid..]).unwrap();
        server_stream.flush().unwrap();

        match rx.recv_timeout(Duration::from_secs(2)).unwrap() {
            Event::Frame { data, .. } => assert_eq!(data, payload),
            other => panic!("expected Frame, got {:?}", other),
        }
    }

    #[test]
    fn reconnect_on_close() {
        let (listener, port) = bind_listener();
        let (tx, rx) = mpsc::channel();

        let _writer = start(make_config(port), tx).unwrap();
        let (server_stream, _) = listener.accept().unwrap();

        // Skip the initial InterfaceUp.
        let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();

        drop(server_stream);

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));

        let _server_stream2 = listener.accept().unwrap();

        let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
        assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
    }

    #[test]
    fn socket_options() {
        let (listener, port) = bind_listener();

        let stream = try_connect(&make_config(port)).unwrap();
        let _server = listener.accept().unwrap();

        let mut val: libc::c_int = 0;
        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: `val` and `len` are live locals sized for one `c_int`, and the
        // fd belongs to `stream`, which outlives the call.
        let rc = unsafe {
            libc::getsockopt(
                stream.as_raw_fd(),
                libc::IPPROTO_TCP,
                libc::TCP_NODELAY,
                &mut val as *mut libc::c_int as *mut libc::c_void,
                &mut len,
            )
        };
        assert_eq!(rc, 0, "getsockopt failed: {}", io::Error::last_os_error());
        assert_eq!(val, 1, "TCP_NODELAY should be 1");
    }

    #[test]
    fn connect_timeout() {
        // 192.0.2.0/24 is TEST-NET-1 (RFC 5737) and should never answer.
        let config = TcpClientConfig {
            name: "timeout-test".into(),
            target_host: "192.0.2.1".into(),
            target_port: 12345,
            interface_id: InterfaceId(99),
            reconnect_wait: Duration::from_millis(100),
            max_reconnect_tries: Some(0),
            connect_timeout: Duration::from_millis(500),
        };

        let start_time = std::time::Instant::now();
        let result = try_connect(&config);
        let elapsed = start_time.elapsed();

        assert!(result.is_err());
        assert!(elapsed < Duration::from_secs(5));
    }
}