use std::io::{self, Read, Write};
use std::net::{TcpStream, ToSocketAddrs};
use std::os::unix::io::AsRawFd;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use rns_core::transport::types::InterfaceId;
use crate::event::{Event, EventSender};
use crate::hdlc;
use crate::interface::Writer;
#[derive(Debug, Clone)]
pub struct TcpClientConfig {
pub name: String,
pub target_host: String,
pub target_port: u16,
pub interface_id: InterfaceId,
pub reconnect_wait: Duration,
pub max_reconnect_tries: Option<u32>,
pub connect_timeout: Duration,
pub device: Option<String>,
pub runtime: Arc<Mutex<TcpClientRuntime>>,
}
#[derive(Debug, Clone)]
pub struct TcpClientRuntime {
pub target_host: String,
pub target_port: u16,
pub reconnect_wait: Duration,
pub max_reconnect_tries: Option<u32>,
pub connect_timeout: Duration,
}
impl TcpClientRuntime {
pub fn from_config(config: &TcpClientConfig) -> Self {
Self {
target_host: config.target_host.clone(),
target_port: config.target_port,
reconnect_wait: config.reconnect_wait,
max_reconnect_tries: config.max_reconnect_tries,
connect_timeout: config.connect_timeout,
}
}
}
#[derive(Debug, Clone)]
pub struct TcpClientRuntimeConfigHandle {
pub interface_name: String,
pub runtime: Arc<Mutex<TcpClientRuntime>>,
pub startup: TcpClientRuntime,
}
impl Default for TcpClientConfig {
fn default() -> Self {
let mut config = TcpClientConfig {
name: String::new(),
target_host: "127.0.0.1".into(),
target_port: 4242,
interface_id: InterfaceId(0),
reconnect_wait: Duration::from_secs(5),
max_reconnect_tries: None,
connect_timeout: Duration::from_secs(5),
device: None,
runtime: Arc::new(Mutex::new(TcpClientRuntime {
target_host: "127.0.0.1".into(),
target_port: 4242,
reconnect_wait: Duration::from_secs(5),
max_reconnect_tries: None,
connect_timeout: Duration::from_secs(5),
})),
};
let startup = TcpClientRuntime::from_config(&config);
config.runtime = Arc::new(Mutex::new(startup));
config
}
}
struct TcpWriter {
stream: TcpStream,
}
impl Writer for TcpWriter {
fn send_frame(&mut self, data: &[u8]) -> io::Result<()> {
self.stream.write_all(&hdlc::frame(data))
}
}
fn set_socket_options(stream: &TcpStream) -> io::Result<()> {
let fd = stream.as_raw_fd();
unsafe {
let val: libc::c_int = 1;
if libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_NODELAY,
&val as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(io::Error::last_os_error());
}
if libc::setsockopt(
fd,
libc::SOL_SOCKET,
libc::SO_KEEPALIVE,
&val as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(io::Error::last_os_error());
}
#[cfg(target_os = "linux")]
{
let idle: libc::c_int = 5;
if libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPIDLE,
&idle as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(io::Error::last_os_error());
}
let intvl: libc::c_int = 2;
if libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPINTVL,
&intvl as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(io::Error::last_os_error());
}
let cnt: libc::c_int = 12;
if libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_KEEPCNT,
&cnt as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(io::Error::last_os_error());
}
let timeout: libc::c_int = 24_000;
if libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_USER_TIMEOUT,
&timeout as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_int>() as libc::socklen_t,
) != 0
{
return Err(io::Error::last_os_error());
}
}
}
Ok(())
}
fn try_connect(config: &TcpClientConfig) -> io::Result<TcpStream> {
let runtime = config.runtime.lock().unwrap().clone();
let addr_str = format!("{}:{}", config.target_host, config.target_port);
let addr = addr_str
.to_socket_addrs()?
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no addresses resolved"))?;
#[cfg(target_os = "linux")]
let stream = if let Some(ref device) = config.device {
connect_with_device(&addr, device, runtime.connect_timeout)?
} else {
TcpStream::connect_timeout(&addr, runtime.connect_timeout)?
};
#[cfg(not(target_os = "linux"))]
let stream = TcpStream::connect_timeout(&addr, runtime.connect_timeout)?;
set_socket_options(&stream)?;
Ok(stream)
}
#[cfg(target_os = "linux")]
fn connect_with_device(
addr: &std::net::SocketAddr,
device: &str,
timeout: Duration,
) -> io::Result<TcpStream> {
use std::os::unix::io::{FromRawFd, RawFd};
let domain = if addr.is_ipv4() {
libc::AF_INET
} else {
libc::AF_INET6
};
let fd: RawFd = unsafe { libc::socket(domain, libc::SOCK_STREAM, 0) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
let stream = unsafe { TcpStream::from_raw_fd(fd) };
super::bind_to_device(stream.as_raw_fd(), device)?;
stream.set_nonblocking(true)?;
let (sockaddr, socklen) = socket_addr_to_raw(addr);
let ret = unsafe {
libc::connect(
stream.as_raw_fd(),
&sockaddr as *const libc::sockaddr_storage as *const libc::sockaddr,
socklen,
)
};
if ret != 0 {
let err = io::Error::last_os_error();
if err.raw_os_error() != Some(libc::EINPROGRESS) {
return Err(err);
}
}
let mut pollfd = libc::pollfd {
fd: stream.as_raw_fd(),
events: libc::POLLOUT,
revents: 0,
};
let timeout_ms = timeout.as_millis() as libc::c_int;
let poll_ret = unsafe { libc::poll(&mut pollfd, 1, timeout_ms) };
if poll_ret == 0 {
return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
}
if poll_ret < 0 {
return Err(io::Error::last_os_error());
}
let mut err_val: libc::c_int = 0;
let mut err_len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
let ret = unsafe {
libc::getsockopt(
stream.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_ERROR,
&mut err_val as *mut _ as *mut libc::c_void,
&mut err_len,
)
};
if ret != 0 {
return Err(io::Error::last_os_error());
}
if err_val != 0 {
return Err(io::Error::from_raw_os_error(err_val));
}
stream.set_nonblocking(false)?;
Ok(stream)
}
#[cfg(target_os = "linux")]
fn socket_addr_to_raw(addr: &std::net::SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
match addr {
std::net::SocketAddr::V4(v4) => {
let sin: &mut libc::sockaddr_in = unsafe {
&mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in)
};
sin.sin_family = libc::AF_INET as libc::sa_family_t;
sin.sin_port = v4.port().to_be();
sin.sin_addr = libc::in_addr {
s_addr: u32::from_ne_bytes(v4.ip().octets()),
};
(
storage,
std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t,
)
}
std::net::SocketAddr::V6(v6) => {
let sin6: &mut libc::sockaddr_in6 = unsafe {
&mut *(&mut storage as *mut libc::sockaddr_storage as *mut libc::sockaddr_in6)
};
sin6.sin6_family = libc::AF_INET6 as libc::sa_family_t;
sin6.sin6_port = v6.port().to_be();
sin6.sin6_addr = libc::in6_addr {
s6_addr: v6.ip().octets(),
};
sin6.sin6_flowinfo = v6.flowinfo();
sin6.sin6_scope_id = v6.scope_id();
(
storage,
std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t,
)
}
}
}
pub fn start(config: TcpClientConfig, tx: EventSender) -> io::Result<Box<dyn Writer>> {
let stream = try_connect(&config)?;
let reader_stream = stream.try_clone()?;
let writer_stream = stream.try_clone()?;
let id = config.interface_id;
let _ = tx.send(Event::InterfaceUp(id, None, None));
let reader_config = config;
let reader_tx = tx;
thread::Builder::new()
.name(format!("tcp-reader-{}", id.0))
.spawn(move || {
reader_loop(reader_stream, reader_config, reader_tx);
})?;
Ok(Box::new(TcpWriter {
stream: writer_stream,
}))
}
fn reader_loop(mut stream: TcpStream, config: TcpClientConfig, tx: EventSender) {
let id = config.interface_id;
let mut decoder = hdlc::Decoder::new();
let mut buf = [0u8; 4096];
loop {
match stream.read(&mut buf) {
Ok(0) => {
log::warn!("[{}] connection closed", config.name);
let _ = tx.send(Event::InterfaceDown(id));
match reconnect(&config, &tx) {
Some(new_stream) => {
stream = new_stream;
decoder = hdlc::Decoder::new();
continue;
}
None => {
log::error!("[{}] reconnection failed, giving up", config.name);
return;
}
}
}
Ok(n) => {
for frame in decoder.feed(&buf[..n]) {
if tx
.send(Event::Frame {
interface_id: id,
data: frame,
})
.is_err()
{
return;
}
}
}
Err(e) => {
log::warn!("[{}] read error: {}", config.name, e);
let _ = tx.send(Event::InterfaceDown(id));
match reconnect(&config, &tx) {
Some(new_stream) => {
stream = new_stream;
decoder = hdlc::Decoder::new();
continue;
}
None => {
log::error!("[{}] reconnection failed, giving up", config.name);
return;
}
}
}
}
}
}
fn reconnect(config: &TcpClientConfig, tx: &EventSender) -> Option<TcpStream> {
let mut attempts = 0u32;
loop {
let runtime = config.runtime.lock().unwrap().clone();
thread::sleep(runtime.reconnect_wait);
attempts += 1;
if let Some(max) = runtime.max_reconnect_tries {
if attempts > max {
let _ = tx.send(Event::InterfaceDown(config.interface_id));
return None;
}
}
log::info!("[{}] reconnect attempt {} ...", config.name, attempts);
match try_connect(config) {
Ok(new_stream) => {
let writer_stream = match new_stream.try_clone() {
Ok(s) => s,
Err(e) => {
log::warn!("[{}] failed to clone stream: {}", config.name, e);
continue;
}
};
log::info!("[{}] reconnected", config.name);
let new_writer: Box<dyn Writer> = Box::new(TcpWriter {
stream: writer_stream,
});
let _ = tx.send(Event::InterfaceUp(
config.interface_id,
Some(new_writer),
None,
));
return Some(new_stream);
}
Err(e) => {
log::warn!("[{}] reconnect failed: {}", config.name, e);
}
}
}
}
use super::{InterfaceConfigData, InterfaceFactory, StartContext, StartResult};
use rns_core::transport::types::InterfaceInfo;
use std::collections::HashMap;
pub struct TcpClientFactory;
impl InterfaceFactory for TcpClientFactory {
fn type_name(&self) -> &str {
"TCPClientInterface"
}
fn parse_config(
&self,
name: &str,
id: InterfaceId,
params: &HashMap<String, String>,
) -> Result<Box<dyn InterfaceConfigData>, String> {
let target_host = params
.get("target_host")
.cloned()
.unwrap_or_else(|| "127.0.0.1".into());
let target_port = params
.get("target_port")
.and_then(|v| v.parse().ok())
.unwrap_or(4242);
Ok(Box::new(TcpClientConfig {
name: name.to_string(),
target_host,
target_port,
interface_id: id,
device: params.get("device").cloned(),
..TcpClientConfig::default()
}))
}
fn start(
&self,
config: Box<dyn InterfaceConfigData>,
ctx: StartContext,
) -> io::Result<StartResult> {
let tcp_config = *config
.into_any()
.downcast::<TcpClientConfig>()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "wrong config type"))?;
let id = tcp_config.interface_id;
let name = tcp_config.name.clone();
let info = InterfaceInfo {
id,
name,
mode: ctx.mode,
out_capable: true,
in_capable: true,
bitrate: None,
announce_rate_target: None,
announce_rate_grace: 0,
announce_rate_penalty: 0.0,
announce_cap: rns_core::constants::ANNOUNCE_CAP,
is_local_client: false,
wants_tunnel: false,
tunnel_id: None,
mtu: 65535,
ingress_control: rns_core::transport::types::IngressControlConfig::enabled(),
ia_freq: 0.0,
started: crate::time::now(),
};
let writer = start(tcp_config, ctx.tx)?;
Ok(StartResult::Simple {
id,
info,
writer,
interface_type_name: "TCPClientInterface".to_string(),
})
}
}
pub(crate) fn tcp_client_runtime_handle_from_config(
config: &TcpClientConfig,
) -> TcpClientRuntimeConfigHandle {
TcpClientRuntimeConfigHandle {
interface_name: config.name.clone(),
runtime: Arc::clone(&config.runtime),
startup: TcpClientRuntime::from_config(config),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::TcpListener;
use std::time::Duration;
fn find_free_port() -> u16 {
TcpListener::bind("127.0.0.1:0")
.unwrap()
.local_addr()
.unwrap()
.port()
}
fn make_config(port: u16) -> TcpClientConfig {
TcpClientConfig {
name: format!("test-tcp-{}", port),
target_host: "127.0.0.1".into(),
target_port: port,
interface_id: InterfaceId(1),
reconnect_wait: Duration::from_millis(100),
max_reconnect_tries: Some(2),
connect_timeout: Duration::from_secs(2),
runtime: Arc::new(Mutex::new(TcpClientRuntime {
target_host: "127.0.0.1".into(),
target_port: port,
reconnect_wait: Duration::from_millis(100),
max_reconnect_tries: Some(2),
connect_timeout: Duration::from_secs(2),
})),
device: None,
}
}
#[test]
fn connect_to_listener() {
let port = find_free_port();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).unwrap();
let (tx, rx) = crate::event::channel();
let config = make_config(port);
let _writer = start(config, tx).unwrap();
let _server_stream = listener.accept().unwrap();
let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
}
#[test]
fn receive_frame() {
let port = find_free_port();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).unwrap();
let (tx, rx) = crate::event::channel();
let config = make_config(port);
let _writer = start(config, tx).unwrap();
let (mut server_stream, _) = listener.accept().unwrap();
let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();
let payload: Vec<u8> = (0..32).collect();
let framed = hdlc::frame(&payload);
server_stream.write_all(&framed).unwrap();
let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
match event {
Event::Frame { interface_id, data } => {
assert_eq!(interface_id, InterfaceId(1));
assert_eq!(data, payload);
}
other => panic!("expected Frame, got {:?}", other),
}
}
#[test]
fn send_frame() {
let port = find_free_port();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).unwrap();
let (tx, _rx) = crate::event::channel();
let config = make_config(port);
let mut writer = start(config, tx).unwrap();
let (mut server_stream, _) = listener.accept().unwrap();
server_stream
.set_read_timeout(Some(Duration::from_secs(2)))
.unwrap();
let payload: Vec<u8> = (0..24).collect();
writer.send_frame(&payload).unwrap();
let mut buf = [0u8; 256];
let n = server_stream.read(&mut buf).unwrap();
let expected = hdlc::frame(&payload);
assert_eq!(&buf[..n], &expected[..]);
}
#[test]
fn multiple_frames() {
let port = find_free_port();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).unwrap();
let (tx, rx) = crate::event::channel();
let config = make_config(port);
let _writer = start(config, tx).unwrap();
let (mut server_stream, _) = listener.accept().unwrap();
let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();
let payloads: Vec<Vec<u8>> = (0..3)
.map(|i| (0..24).map(|j| j + i * 50).collect())
.collect();
for p in &payloads {
server_stream.write_all(&hdlc::frame(p)).unwrap();
}
for expected in &payloads {
let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
match event {
Event::Frame { data, .. } => assert_eq!(&data, expected),
other => panic!("expected Frame, got {:?}", other),
}
}
}
#[test]
fn split_across_reads() {
let port = find_free_port();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).unwrap();
let (tx, rx) = crate::event::channel();
let config = make_config(port);
let _writer = start(config, tx).unwrap();
let (mut server_stream, _) = listener.accept().unwrap();
let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();
let payload: Vec<u8> = (0..32).collect();
let framed = hdlc::frame(&payload);
let mid = framed.len() / 2;
server_stream.write_all(&framed[..mid]).unwrap();
server_stream.flush().unwrap();
thread::sleep(Duration::from_millis(50));
server_stream.write_all(&framed[mid..]).unwrap();
server_stream.flush().unwrap();
let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
match event {
Event::Frame { data, .. } => assert_eq!(data, payload),
other => panic!("expected Frame, got {:?}", other),
}
}
#[test]
fn reconnect_on_close() {
let port = find_free_port();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).unwrap();
listener.set_nonblocking(false).unwrap();
let (tx, rx) = crate::event::channel();
let config = make_config(port);
let _writer = start(config, tx).unwrap();
let (server_stream, _) = listener.accept().unwrap();
let _ = rx.recv_timeout(Duration::from_secs(1)).unwrap();
drop(server_stream);
let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
assert!(matches!(event, Event::InterfaceDown(InterfaceId(1))));
let _server_stream2 = listener.accept().unwrap();
let event = rx.recv_timeout(Duration::from_secs(2)).unwrap();
assert!(matches!(event, Event::InterfaceUp(InterfaceId(1), _, _)));
}
#[test]
fn socket_options() {
let port = find_free_port();
let listener = TcpListener::bind(format!("127.0.0.1:{}", port)).unwrap();
let stream = try_connect(&make_config(port)).unwrap();
let _server = listener.accept().unwrap();
let fd = stream.as_raw_fd();
let mut val: libc::c_int = 0;
let mut len: libc::socklen_t = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
unsafe {
libc::getsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_NODELAY,
&mut val as *mut _ as *mut libc::c_void,
&mut len,
);
}
assert_eq!(val, 1, "TCP_NODELAY should be 1");
}
#[test]
fn connect_timeout() {
let config = TcpClientConfig {
name: "timeout-test".into(),
target_host: "192.0.2.1".into(), target_port: 12345,
interface_id: InterfaceId(99),
reconnect_wait: Duration::from_millis(100),
max_reconnect_tries: Some(0),
connect_timeout: Duration::from_millis(500),
device: None,
runtime: Arc::new(Mutex::new(TcpClientRuntime {
target_host: "192.0.2.1".into(),
target_port: 12345,
reconnect_wait: Duration::from_millis(100),
max_reconnect_tries: Some(0),
connect_timeout: Duration::from_millis(500),
})),
..TcpClientConfig::default()
};
let start_time = std::time::Instant::now();
let result = try_connect(&config);
let elapsed = start_time.elapsed();
assert!(result.is_err());
assert!(elapsed < Duration::from_secs(5));
}
}