use std::{
collections::HashMap,
net::SocketAddr,
sync::{Arc, atomic::Ordering},
time::Duration,
};
use iroh::endpoint::{RecvStream, SendStream};
use tokio::{
net::UdpSocket,
sync::{RwLock, watch},
};
use crate::{
client::{ClientError, ClientTunnel},
common::{CONNECTION_TIMEOUT_SECONDS, LocalClientConnection, TUNNEL_FORWARDER_BUFFER_SIZE},
host::{ConnectedClient, HostError, HostTunnel},
utils::now_secs,
};
pub async fn bi_udp_host(
socket: UdpSocket,
connection: Arc<ConnectedClient>,
local_client_connection: Arc<RwLock<LocalClientConnection>>,
mut tunnel_send: SendStream,
mut tunnel_recv: RecvStream,
mut stop_rx: watch::Receiver<bool>,
) -> Result<(), HostError> {
if socket.peer_addr().is_err() {
panic!("Socket is not connected");
}
tracing::debug!("Starting bidirectional UDP forwarder for the host tunnel");
let socket_arc = Arc::new(socket);
let udp_read = socket_arc.clone();
let udp_write = socket_arc.clone();
let stop_rx_c = stop_rx.clone();
let last_active = local_client_connection.read().await.last_active.clone();
let last_active_c = last_active.clone();
last_active.store(now_secs(), Ordering::SeqCst);
let tunnel_to_udp_forwader = tokio::spawn(async move {
let mut buf = vec![0; TUNNEL_FORWARDER_BUFFER_SIZE];
let mut stop_rx = stop_rx_c;
let last_active = last_active_c;
loop {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
let timeout = if elapsed >= CONNECTION_TIMEOUT_SECONDS {
Duration::from_secs(0)
} else {
Duration::from_secs(CONNECTION_TIMEOUT_SECONDS - elapsed)
};
tokio::select! {
biased;
_ = stop_rx.changed() => {
if *stop_rx.borrow() {
tracing::debug!("Stopping UDP stream for host tunnel");
break;
}
}
Ok(n) = tunnel_recv.read(&mut buf) => {
let Some(n) = n else {
tracing::debug!("Tunnel stream closed");
break;
};
if n == 0 {
tracing::debug!("Tunnel stream closed");
break;
}
last_active.store(now_secs(), Ordering::SeqCst);
udp_write.send(&buf[..n]).await.unwrap();
}
_ = tokio::time::sleep(timeout) => {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
if elapsed >= CONNECTION_TIMEOUT_SECONDS {
tracing::debug!("UDP stream timed out");
break;
}
}
}
}
});
let udp_to_tunnel_forwader = tokio::spawn(async move {
let mut buf = vec![0; TUNNEL_FORWARDER_BUFFER_SIZE];
loop {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
let timeout = if elapsed >= CONNECTION_TIMEOUT_SECONDS {
Duration::from_secs(0)
} else {
Duration::from_secs(CONNECTION_TIMEOUT_SECONDS - elapsed)
};
tokio::select! {
biased;
_ = stop_rx.changed() => {
if *stop_rx.borrow() {
tracing::debug!("Stopping UDP stream for host tunnel");
break;
}
}
Ok(n) = udp_read.recv(&mut buf) => {
if n == 0 {
tracing::debug!("Tunnel stream closed");
break;
}
last_active.store(now_secs(), Ordering::SeqCst);
tunnel_send.write_all(&buf[..n]).await.unwrap();
}
_ = tokio::time::sleep(timeout) => {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
if elapsed >= CONNECTION_TIMEOUT_SECONDS {
tracing::debug!("UDP stream timed out");
break;
}
}
}
}
});
tracing::info!("Started tunnel UDP forwarder for host tunnel");
let _ = tokio::try_join!(tunnel_to_udp_forwader, udp_to_tunnel_forwader);
HostTunnel::disconnect_local_client_connection(connection, local_client_connection).await;
tracing::info!("Stopping tunnel UDP forwarder for host tunnel");
Ok(())
}
pub async fn bi_udp_client(
socket: Arc<UdpSocket>,
local_connections: Arc<RwLock<HashMap<SocketAddr, Arc<RwLock<LocalClientConnection>>>>>,
local_client_connection: Arc<RwLock<LocalClientConnection>>,
mut tunnel_send: SendStream,
mut tunnel_recv: RecvStream,
mut stop_rx: watch::Receiver<bool>,
) -> Result<(), ClientError> {
tracing::debug!("Starting bidirectional UDP forwarder for the client tunnel");
let client_local_addr = { local_client_connection.read().await.client_local_addr };
let udp_read = socket.clone();
let udp_write = socket;
let stop_rx_c = stop_rx.clone();
let last_active = local_client_connection.read().await.last_active.clone();
let last_active_c = last_active.clone();
last_active.store(now_secs(), Ordering::SeqCst);
let tunnel_to_udp_forwader = tokio::spawn(async move {
let mut buf = vec![0; TUNNEL_FORWARDER_BUFFER_SIZE];
let mut stop_rx = stop_rx_c;
let last_active = last_active_c;
loop {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
let timeout = if elapsed >= CONNECTION_TIMEOUT_SECONDS {
Duration::from_secs(0)
} else {
Duration::from_secs(CONNECTION_TIMEOUT_SECONDS - elapsed)
};
tokio::select! {
biased;
_ = stop_rx.changed() => {
if *stop_rx.borrow() {
tracing::debug!("Stopping UDP stream for host tunnel");
break;
}
}
Ok(n) = tunnel_recv.read(&mut buf) => {
let Some(n) = n else {
tracing::debug!("Tunnel stream closed");
break;
};
if n == 0 {
tracing::debug!("Tunnel stream closed");
break;
}
last_active.store(now_secs(), Ordering::SeqCst);
udp_write.send_to(&buf[..n], client_local_addr).await.unwrap();
}
_ = tokio::time::sleep(timeout) => {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
if elapsed >= CONNECTION_TIMEOUT_SECONDS {
tracing::debug!("UDP stream timed out");
break;
}
}
}
}
});
let udp_to_tunnel_forwader = tokio::spawn(async move {
let mut buf = vec![0; TUNNEL_FORWARDER_BUFFER_SIZE];
loop {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
let timeout = if elapsed >= CONNECTION_TIMEOUT_SECONDS {
Duration::from_secs(0)
} else {
Duration::from_secs(CONNECTION_TIMEOUT_SECONDS - elapsed)
};
tokio::select! {
biased;
_ = stop_rx.changed() => {
if *stop_rx.borrow() {
tracing::debug!("Stopping UDP stream for host tunnel");
break;
}
}
Ok((n, addr)) = udp_read.peek_from(&mut buf) => {
if addr != client_local_addr {
continue
};
if udp_read.recv(&mut buf[..n]).await.is_err() {
continue;
};
if n == 0 {
tracing::debug!("Tunnel stream closed");
break;
}
last_active.store(now_secs(), Ordering::SeqCst);
tunnel_send.write_all(&buf[..n]).await.unwrap();
}
_ = tokio::time::sleep(timeout) => {
let now = now_secs();
let last = last_active.load(Ordering::SeqCst);
let elapsed = now.saturating_sub(last);
if elapsed >= CONNECTION_TIMEOUT_SECONDS {
tracing::debug!("UDP stream timed out");
break;
}
}
}
}
});
tracing::info!("Started tunnel UDP forwarder for host tunnel");
let _ = tokio::try_join!(tunnel_to_udp_forwader, udp_to_tunnel_forwader);
ClientTunnel::disconnect_local_connection(local_connections, local_client_connection).await;
Ok(())
}