use std::{
collections::HashMap,
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::{
Arc,
atomic::{AtomicBool, AtomicU64, Ordering},
},
};
use iroh::{
Endpoint, PublicKey, RelayMode,
endpoint::{Connection, ReadError, VarInt, WriteError},
};
use thiserror::Error;
use tokio::{
net::{TcpListener, UdpSocket},
sync::{RwLock, watch},
task::JoinHandle,
};
use crate::{
ALPN, TunnelProtocol,
common::{Client2HostControlMsg, LocalClientConnection, TunnelCommon, send_packet},
net::{tcp::bi_tcp_stream_client, udp::bi_udp_client},
};
/// Everything that can go wrong while registering, connecting, running or
/// stopping a client tunnel.
#[derive(Error, Debug)]
pub enum ClientError {
    // ---- tunnel bookkeeping / lifecycle state ----
    /// A tunnel for the same local address and protocol is already registered.
    #[error("Tunnel with local address {0}/{1} already exists")]
    TunnelAlreadyExists(SocketAddr, TunnelProtocol),
    /// `connect` was called on a tunnel that is already running.
    #[error("Already connected to tunnel")]
    AlreadyConnected,
    /// `disconnect` was called on a tunnel that is not running.
    #[error("Not connected to tunnel")]
    NotConnected,

    // ---- endpoint / connection setup ----
    /// The iroh endpoint could not be created.
    #[error("Failed to create endpoint: {0}")]
    FailedToCreateEndpoint(String),
    /// The endpoint could not connect to the host.
    #[error("Failed to connect: {0}")]
    FailedToConnect(String),
    /// A stream could not be accepted.
    #[error("Failed to accept stream: {0}")]
    FailedToAcceptStream(String),
    /// An established iroh connection reported an error.
    #[error("Connection error: {0}")]
    ConnectionError(#[from] iroh::endpoint::ConnectionError),

    // ---- data transfer ----
    /// Writing to a tunnel stream failed.
    #[error("Write error: {0}")]
    WriteError(#[from] WriteError),
    /// Reading from a tunnel stream failed.
    #[error("Read error: {0}")]
    ReadError(#[from] ReadError),
    /// A local socket operation failed.
    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),
}
pub struct ClientState {
tunnels: HashMap<(SocketAddr, TunnelProtocol), ClientTunnel>,
}
impl Default for ClientState {
fn default() -> Self {
Self::new()
}
}
impl ClientState {
pub fn new() -> Self {
Self {
tunnels: HashMap::new(),
}
}
pub fn tunnels(&self) -> &HashMap<(SocketAddr, TunnelProtocol), ClientTunnel> {
&self.tunnels
}
pub fn tunnels_mut(&mut self) -> &mut HashMap<(SocketAddr, TunnelProtocol), ClientTunnel> {
&mut self.tunnels
}
pub async fn create_tunnel(
&mut self,
local_addr: SocketAddr,
name: String,
secret: [u8; 32],
protocol: TunnelProtocol,
) -> Result<&mut ClientTunnel, ClientError> {
if self.tunnels.contains_key(&(local_addr, protocol)) {
return Err(ClientError::TunnelAlreadyExists(local_addr, protocol));
}
let tunnel = ClientTunnel {
name,
secret: PublicKey::from_bytes(&secret).unwrap(),
tunnel_addr: local_addr,
protocol,
is_connected: Arc::new(AtomicBool::new(false)),
running_props: None,
};
self.tunnels.insert((local_addr, protocol), tunnel);
Ok(self.tunnels.get_mut(&(local_addr, protocol)).unwrap())
}
pub fn get_tunnel(
&self,
local_addr: SocketAddr,
protocol: TunnelProtocol,
) -> Option<&ClientTunnel> {
self.tunnels.get(&(local_addr, protocol))
}
pub fn get_tunnel_mut(
&mut self,
local_addr: SocketAddr,
protocol: TunnelProtocol,
) -> Option<&mut ClientTunnel> {
self.tunnels.get_mut(&(local_addr, protocol))
}
}
/// A single client-side tunnel: a local TCP or UDP address whose traffic is
/// forwarded to a remote host reached through an iroh `Endpoint`.
// Field declaration order is also drop order; keep it stable.
pub struct ClientTunnel {
    /// Human-readable name of the tunnel (used in log messages).
    pub name: String,
    /// Public key of the host to connect to; passed to `Endpoint::connect`.
    secret: PublicKey,
    /// Local address the tunnel listens on for clients.
    tunnel_addr: SocketAddr,
    /// Transport protocol accepted on `tunnel_addr`.
    protocol: TunnelProtocol,
    /// Shared flag reporting whether the tunnel is currently connected.
    is_connected: Arc<AtomicBool>,
    /// Runtime resources; set by `connect` and taken by `disconnect`.
    running_props: Option<RunningClientTunnelProps>,
}
/// Resources owned by a connected tunnel; `disconnect` signals the stop,
/// closes the connection and closes the endpoint.
// Field declaration order is also drop order; keep it stable.
struct RunningClientTunnelProps {
    /// Local clients currently being forwarded, keyed by their local socket address.
    local_connections: Arc<RwLock<HashMap<SocketAddr, Arc<RwLock<LocalClientConnection>>>>>,
    /// iroh endpoint used to reach the host.
    endpoint: Endpoint,
    /// Connection to the host; per-client streams are opened on it with `open_bi`.
    connection: Connection,
    /// Sending `true` asks the listener task to stop; receiver clones are also
    /// handed to the per-connection forwarding helpers.
    stop_tx: watch::Sender<bool>,
}
impl ClientTunnel {
pub fn client_addr(&self) -> SocketAddr {
self.tunnel_addr
}
pub async fn connect(&mut self) -> Result<JoinHandle<Result<(), ClientError>>, ClientError> {
if self.is_running() {
tracing::warn!(
"Client tunnel \"{}\" is already connected, not starting again",
self.name
);
return Err(ClientError::AlreadyConnected);
}
let endpoint = Endpoint::builder()
.relay_mode(RelayMode::Default)
.discovery_n0()
.bind()
.await
.map_err(|e| ClientError::FailedToCreateEndpoint(e.to_string()))?;
let conn = endpoint
.connect(self.secret, ALPN)
.await
.map_err(|e| ClientError::FailedToConnect(e.to_string()))?;
let local_connections = Arc::new(RwLock::new(HashMap::new()));
let (stop_tx, mut stop_rx) = watch::channel(false);
self.running_props = Some(RunningClientTunnelProps {
local_connections: local_connections.clone(),
connection: conn.clone(),
stop_tx,
endpoint,
});
self.is_connected.store(true, Ordering::SeqCst);
tracing::debug!("Client tunnel \"{}\" connected to host", self.name);
let protocol = self.protocol;
let tunnel_addr = self.tunnel_addr;
let connection_listener = tokio::spawn(async move {
match protocol {
TunnelProtocol::Tcp => {
let sock = TcpListener::bind(tunnel_addr).await?;
tracing::debug!("Client TCP socket listening on {}", tunnel_addr);
loop {
tokio::select! {
biased;
_ = stop_rx.changed() => {
if *stop_rx.borrow() {
tracing::debug!("Stopping TCP stream for client tunnel");
break;
}
}
Ok((stream, client_local_addr)) = sock.accept() => {
let (tunnel_send, tunnel_recv) = conn.open_bi().await?;
send_packet(&conn, Client2HostControlMsg::ConnReq { local_addr: client_local_addr }).await?;
let local_client_connection = LocalClientConnection {
client_local_addr,
client_virtual_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
last_active: Arc::new(AtomicU64::new(0)),
};
let local_client_connection = Arc::new(RwLock::new(local_client_connection));
{
local_connections.write().await.insert(client_local_addr, local_client_connection.clone());
}
bi_tcp_stream_client(stream, local_connections.clone(), local_client_connection.clone(), tunnel_send, tunnel_recv, stop_rx.clone()).await?;
}
}
}
}
TunnelProtocol::Udp => {
let socket = UdpSocket::bind(tunnel_addr).await?;
let socket = Arc::new(socket);
tracing::debug!("Client UDP socket listening on {}", tunnel_addr);
loop {
let mut buf = vec![0; 1024];
let socket = socket.clone();
tokio::select! {
biased;
_ = stop_rx.changed() => {
if *stop_rx.borrow() {
tracing::debug!("Stopping UDP stream for client tunnel");
break;
}
}
Ok((_, client_local_addr)) = socket.peek_from(&mut buf) => {
println!("Received UDP packet from {}", client_local_addr);
{
let local_connections = local_connections.read().await;
if local_connections.contains_key(&client_local_addr) {
continue;
}
}
let (tunnel_send, tunnel_recv) = conn.open_bi().await?;
send_packet(&conn, Client2HostControlMsg::ConnReq { local_addr: client_local_addr }).await?;
let local_client_connection = LocalClientConnection {
client_local_addr,
client_virtual_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)),
last_active: Arc::new(AtomicU64::new(0)),
};
let local_client_connection = Arc::new(RwLock::new(local_client_connection));
{
local_connections.write().await.insert(client_local_addr, local_client_connection.clone());
}
bi_udp_client(socket, local_connections.clone(), local_client_connection.clone(), tunnel_send, tunnel_recv, stop_rx.clone()).await?;
}
}
}
}
};
Ok::<(), ClientError>(())
});
tracing::info!("Client tunnel \"{}\" is connected", self.name);
Ok(connection_listener)
}
pub async fn disconnect_local_connection(
local_connections: Arc<RwLock<HashMap<SocketAddr, Arc<RwLock<LocalClientConnection>>>>>,
connection_to_disconnect: Arc<RwLock<LocalClientConnection>>,
) -> bool {
let mut local_connections = local_connections.write().await;
let connection_to_disconnect = connection_to_disconnect.read().await;
let client_local_addr = connection_to_disconnect.client_local_addr;
if local_connections.remove(&client_local_addr).is_some() {
tracing::debug!(
"Disconnected local connection {}. Connections left: {}",
client_local_addr,
local_connections.len()
);
true
} else {
tracing::warn!("Tried disconnecting a local connection that does not exist");
false
}
}
pub async fn disconnect(&mut self) -> Result<(), ClientError> {
if !self.is_running() {
tracing::warn!(
"Client tunnel \"{}\" is not connected, not stopping",
self.name
);
return Err(ClientError::NotConnected);
}
if let Some(running_props) = self.running_props.take() {
running_props.stop_tx.send(true).unwrap();
running_props.connection.close(VarInt::from_u32(0), &[]);
running_props.endpoint.close().await;
}
self.is_connected.store(false, Ordering::SeqCst);
tracing::info!("Client tunnel \"{}\" disconnected", self.name);
Ok(())
}
}
impl TunnelCommon for ClientTunnel {
    /// Raw 32-byte encoding of the host public key this tunnel targets.
    fn secret(&self) -> [u8; 32] {
        let key_bytes: &[u8; 32] = self.secret.as_bytes();
        *key_bytes
    }

    /// Transport protocol this tunnel listens for locally.
    fn protocol(&self) -> TunnelProtocol {
        self.protocol
    }

    /// Whether the tunnel is currently connected (reads the shared flag).
    fn is_running(&self) -> bool {
        self.is_connected.load(Ordering::SeqCst)
    }

    /// Owned copy of the tunnel's name.
    fn name(&self) -> String {
        self.name.to_owned()
    }

    /// Number of local clients currently registered; `0` when not connected.
    async fn num_active_connections(&self) -> usize {
        match &self.running_props {
            Some(props) => props.local_connections.read().await.len(),
            None => 0,
        }
    }

    /// Local address the tunnel listens on.
    fn addr(&self) -> SocketAddr {
        self.tunnel_addr
    }
}