mod handle_remote;
mod maybe_retryable;
mod ws_connect;
use self::handle_remote::handle_remote;
use self::maybe_retryable::MaybeRetryableError;
use crate::arg::ClientArgs;
use crate::config;
use crate::tls::MaybeTlsStream;
use bytes::Bytes;
use futures_util::TryFutureExt;
use parking_lot::Mutex;
use penguin_mux::timing::{Backoff, OptionalDuration};
use penguin_mux::{Datagram, Dupe, IntKey, Multiplexor, MuxStream};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
use thiserror::Error;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinSet;
use tokio::time;
use tracing::{debug, error, info, trace, warn};
#[cfg(feature = "nohash")]
use nohash_hasher::IntMap;
#[cfg(not(feature = "nohash"))]
use std::collections::HashMap as IntMap;
/// Errors that can end (or interrupt) a client run.
///
/// Whether a given error is worth a reconnect attempt is decided by the
/// `retryable()` method of the `MaybeRetryableError` trait from the
/// `maybe_retryable` submodule (see its use in `client_main_inner`).
#[derive(Debug, Error)]
pub enum Error {
    /// Reconnecting gave up after exhausting the retry budget; holds the last
    /// error seen. Boxed because the type refers to itself.
    #[error("maximum retry count reached (last error: {0})")]
    MaxRetryCountReached(Box<Self>),
    /// A remote specification could not be parsed.
    #[error("failed to parse remote: {0}")]
    ParseRemote(#[from] crate::parse_remote::Error),
    /// A remote handler task exited with a fatal error.
    #[error("remote handler exited: {0}")]
    RemoteHandlerExited(#[from] handle_remote::FatalError),
    /// A domain name could not be turned into a header string.
    /// NOTE(review): presumably raised while building the handshake request
    /// in `ws_connect` — confirm.
    #[error("given domain name is not encodable in UTF-8")]
    InvalidDomainName(http::header::ToStrError),
    /// WebSocket-level error from `tokio_tungstenite`.
    #[error(transparent)]
    Tungstenite(#[from] tokio_tungstenite::tungstenite::Error),
    /// Opening the underlying TCP connection failed.
    #[error("error making a TCP connection: {0}")]
    TcpConnect(std::io::Error),
    /// TLS setup or handshake error.
    #[error(transparent)]
    Tls(#[from] crate::tls::Error),
    /// Error reported by the multiplexor (e.g. opening a stream channel or a
    /// multiplexor background task failing).
    #[error(transparent)]
    Mux(#[from] penguin_mux::Error),
    /// The initial WebSocket handshake did not complete in time
    /// (presumably raised by `ws_connect`).
    #[error("initial WebSocket handshake timed out")]
    HandshakeTimeout,
    /// The user pressed Ctrl-C; in this module it is returned while waiting
    /// to reconnect.
    #[error("user cancelled initial WebSocket handshake")]
    HandshakeCancelled,
    /// Opening a new stream channel exceeded the configured channel timeout.
    #[error("stream request timed out")]
    StreamRequestTimeout,
    /// Returned by the connection loop when none of its event sources can
    /// make progress any more.
    #[error("server disconnected normally")]
    ServerDisconnected,
}
/// A request to open a new stream to `host:port` through the multiplexor.
///
/// It is consumed by the connection loop (`on_connected`), which sends the
/// opened stream back through `tx`. Presumably it is produced by the remote
/// handlers via `HandlerResources::stream_command_tx`.
#[derive(Debug)]
pub struct StreamCommand {
    /// Where the opened stream is delivered back to the requester.
    tx: oneshot::Sender<MuxStream>,
    /// Target host, as raw bytes.
    host: Bytes,
    /// Target port.
    port: u16,
}
/// Resources shared with every remote handler.
///
/// Holds the senders that feed the connection loop and the map of active UDP
/// clients.
#[derive(Clone, Debug)]
pub struct HandlerResources {
    /// Requests for new streams, consumed by the connection loop.
    stream_command_tx: mpsc::Sender<StreamCommand>,
    /// Datagrams to be forwarded to the server by the connection loop.
    datagram_tx: mpsc::Sender<Datagram>,
    /// UDP client bookkeeping; the connection loop also reads it to route
    /// incoming datagrams back to their clients.
    udp_client_map: Arc<Mutex<ClientIdMaps>>,
}
impl HandlerResources {
    /// Creates a fresh set of handler resources together with the receiving
    /// ends of its two channels.
    ///
    /// Returns `(resources, stream_command_rx, datagram_rx)`. The receivers
    /// are meant for the connection loop; `resources` is what remote handlers
    /// use to send on the matching senders.
    #[must_use]
    pub fn create() -> (
        Self,
        mpsc::Receiver<StreamCommand>,
        mpsc::Receiver<Datagram>,
    ) {
        let (stream_command_tx, stream_command_rx) =
            mpsc::channel(config::STREAM_REQUEST_COMMAND_SIZE);
        let (datagram_tx, datagram_rx) = mpsc::channel(config::INCOMING_DATAGRAM_BUFFER_SIZE);
        // The returned `Self` is the only owner, so the `Arc` is moved in
        // directly rather than duped and then dropped.
        let udp_client_map = Arc::new(Mutex::new(ClientIdMaps::new()));
        (
            Self {
                stream_command_tx,
                datagram_tx,
                udp_client_map,
            },
            stream_command_rx,
            datagram_rx,
        )
    }

    /// Registers (or refreshes) the UDP client at `addr` that talks to us
    /// through `socket`, and returns its client ID.
    ///
    /// Clients are keyed by the pair (peer address, local socket address).
    /// The same peer reaching us through two different sockets therefore
    /// gets two IDs. A known client has its expiry pushed back; a new client
    /// gets the next free ID.
    ///
    /// # Panics
    /// Panics if the local address of `socket` cannot be read, or if the two
    /// internal maps are found to be inconsistent. Both indicate a bug.
    #[must_use = "This function returns the new client ID, which should be used to mark the datagram"]
    pub fn add_udp_client(&self, addr: SocketAddr, socket: Arc<UdpSocket>, socks5: bool) -> u32 {
        let our_addr = socket
            .local_addr()
            .expect("Failed to get local address of UDP socket (this is a bug)");
        let ClientIdMaps {
            client_id_map,
            client_addr_map,
        } = &mut *self.udp_client_map.lock();
        // One `entry` lookup instead of a `get` followed by an `insert`.
        match client_addr_map.entry((addr, our_addr)) {
            std::collections::hash_map::Entry::Occupied(known) => {
                let client_id = *known.get();
                client_id_map
                    .get_mut(&client_id)
                    .expect("`client_id_map` and `client_addr_map` are inconsistent (this is a bug)")
                    .refresh();
                client_id
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                // NOTE(review): `ClientIdMaps::send_datagram_reply` treats
                // client ID 0 as stdio; confirm `next_available_key` never
                // yields 0.
                let client_id = u32::next_available_key(client_id_map);
                client_id_map.insert(
                    client_id,
                    ClientIdMapEntry::new(addr, our_addr, socket, socks5),
                );
                slot.insert(client_id);
                client_id
            }
        }
    }

    /// Removes every UDP client whose expiry is not in the future, keeping
    /// both maps in sync.
    ///
    /// # Panics
    /// Panics if an expired entry has no counterpart in `client_addr_map`
    /// (a bug).
    fn prune_udp_clients(&self) {
        let ClientIdMaps {
            client_id_map,
            client_addr_map,
        } = &mut *self.udp_client_map.lock();
        let now = time::Instant::now();
        client_id_map.retain(|_, entry| {
            if entry.expires > now {
                true
            } else {
                let client_id = client_addr_map
                    .remove(&(entry.peer_addr, entry.our_addr))
                    .expect(
                        "`client_id_map` and `client_addr_map` are inconsistent (this is a bug)",
                    );
                debug!(
                    "pruned inactive UDP client {client_id:08x} for address {}",
                    entry.peer_addr
                );
                false
            }
        });
    }
}
/// Two-way mapping between UDP client IDs and the
/// (peer address, local socket address) pairs they stand for.
///
/// The two maps are kept in sync: each entry in one has a counterpart in the
/// other, and the methods using them panic if that invariant is found broken.
#[derive(Clone, Debug)]
#[expect(clippy::module_name_repetitions)]
pub struct ClientIdMaps {
    /// Client ID -> client details.
    client_id_map: IntMap<u32, ClientIdMapEntry>,
    /// (peer address, local socket address) -> client ID.
    client_addr_map: HashMap<(SocketAddr, SocketAddr), u32>,
}
impl ClientIdMaps {
    /// Creates an empty pair of client maps.
    #[must_use]
    fn new() -> Self {
        Self {
            client_id_map: IntMap::default(),
            client_addr_map: HashMap::new(),
        }
    }

    /// Delivers a datagram payload `data` from the server to the local client
    /// identified by `client_id`.
    ///
    /// Client ID `0` means stdio: the payload is written to stdout and
    /// flushed. For any other ID, the client is looked up in `lock_self`.
    /// If it is found:
    /// - its expiry is refreshed, and
    /// - the payload is sent from its UDP socket. SOCKS5 clients go through
    ///   `handle_remote::socks::send_udp_relay_response`; other clients use a
    ///   plain `send_to`.
    ///
    /// The lock is released before any I/O is awaited.
    ///
    /// Returns `None` if the ID is unknown (for example, already pruned);
    /// otherwise returns the result of the send.
    async fn send_datagram_reply(
        lock_self: &Mutex<Self>,
        client_id: u32,
        data: &[u8],
    ) -> Option<std::io::Result<()>> {
        if client_id == 0 {
            // `tokio::io::stdout()` only hands the bytes to a background
            // blocking task. That task writes into std's line-buffered
            // stdout, so two things can go wrong without an explicit flush:
            // - a payload without a trailing newline can linger in the buffer;
            // - successive replies through short-lived handles may race.
            // `flush` waits for the pending write and pushes it out.
            let mut stdout = tokio::io::stdout();
            let result = match stdout.write_all(data).await {
                Ok(()) => stdout.flush().await,
                Err(e) => Err(e),
            };
            return Some(result);
        }
        // Copy what we need out of the entry so the (non-async) lock is not
        // held across the send below.
        let (socket, peer_addr, socks5) = {
            let Self {
                client_id_map,
                client_addr_map: _,
            } = &mut *lock_self.lock();
            let entry = client_id_map.get_mut(&client_id)?;
            entry.refresh();
            (entry.socket.dupe(), entry.peer_addr, entry.socks5)
        };
        let send_result = if socks5 {
            handle_remote::socks::send_udp_relay_response(&socket, peer_addr, data).await
        } else {
            socket.send_to(data, peer_addr).await
        }
        .map(|_| ());
        Some(send_result)
    }
}
/// Everything needed to send replies back to one local UDP client.
#[derive(Clone, Debug)]
#[expect(clippy::module_name_repetitions)]
pub struct ClientIdMapEntry {
    /// The client's address (where replies are sent).
    pub peer_addr: SocketAddr,
    /// Local address of `socket`; together with `peer_addr` it keys the client.
    pub our_addr: SocketAddr,
    /// Socket the client talked to us on, and that replies are sent from.
    pub socket: Arc<UdpSocket>,
    /// Whether replies go through the SOCKS5 UDP relay path instead of a
    /// plain `send_to`.
    pub socks5: bool,
    /// Once this instant is reached, the entry is eligible for pruning.
    pub expires: time::Instant,
}
impl Dupe for ClientIdMapEntry {
    /// Cheap copy: every field is copied except the socket, whose `Arc` is
    /// duped (reference-count bump).
    fn dupe(&self) -> Self {
        let Self {
            peer_addr,
            our_addr,
            socket,
            socks5,
            expires,
        } = self;
        Self {
            socket: socket.dupe(),
            peer_addr: *peer_addr,
            our_addr: *our_addr,
            socks5: *socks5,
            expires: *expires,
        }
    }
}
impl ClientIdMapEntry {
    /// Creates an entry for a client seen just now. It expires one
    /// `config::UDP_PRUNE_TIMEOUT` from now unless refreshed.
    #[must_use]
    fn new(
        peer_addr: SocketAddr,
        our_addr: SocketAddr,
        socket: Arc<UdpSocket>,
        socks5: bool,
    ) -> Self {
        let expires = Self::deadline_from_now();
        Self {
            peer_addr,
            our_addr,
            socket,
            socks5,
            expires,
        }
    }

    /// Pushes the expiry back to a full `config::UDP_PRUNE_TIMEOUT` from now.
    fn refresh(&mut self) {
        self.expires = Self::deadline_from_now();
    }

    /// The instant one `config::UDP_PRUNE_TIMEOUT` after the current time.
    fn deadline_from_now() -> time::Instant {
        time::Instant::now() + config::UDP_PRUNE_TIMEOUT
    }
}
/// Client entry point.
///
/// Builds the [`HandlerResources`] and installs them in a function-local
/// `static`, so that remote handlers can borrow them for `'static`. Then runs
/// [`client_main_inner`].
///
/// # Errors
/// Whatever [`client_main_inner`] returns.
///
/// # Panics
/// Panics if called more than once in the same process, since the resources
/// can only be installed once.
#[tracing::instrument(level = "trace")]
pub async fn client_main(args: &'static ClientArgs) -> Result<(), Error> {
    static HANDLER_RESOURCES: OnceLock<HandlerResources> = OnceLock::new();
    let (resources, command_rx, dgram_rx) = HandlerResources::create();
    HANDLER_RESOURCES
        .set(resources)
        .expect("HandlerResources should only be set once (this is a bug)");
    let installed: &'static HandlerResources = HANDLER_RESOURCES
        .get()
        .expect("HandlerResources should be set (this is a bug)");
    client_main_inner(args, installed, command_rx, dgram_rx).await
}
/// Runs the client until a fatal error, a normal shutdown, or every listener
/// finishes.
///
/// Spawns one `handle_remote` task per configured remote, then drives three
/// futures concurrently:
/// - a watcher over those tasks that propagates the first fatal error (or
///   yields `Ok(())` once all of them have finished without error);
/// - the periodic UDP-client pruning task, which never returns;
/// - the connect/serve/reconnect loop.
///
/// The reconnect loop backs off between attempts using
/// `Backoff::new(200 ms, args.max_retry_interval ms, 2, args.max_retry_count)`.
/// NOTE(review): presumably these arguments are the initial interval,
/// maximum interval, multiplier and retry limit — confirm against
/// `penguin_mux::timing::Backoff`.
///
/// # Errors
/// Returns one of:
/// - the first fatal remote-handler error;
/// - any connection error that is not retryable;
/// - [`Error::MaxRetryCountReached`] once the backoff is exhausted;
/// - [`Error::HandshakeCancelled`] if Ctrl-C arrives while waiting to
///   reconnect.
pub async fn client_main_inner(
    args: &'static ClientArgs,
    handler_resources: &'static HandlerResources,
    mut stream_command_rx: mpsc::Receiver<StreamCommand>,
    mut datagram_rx: mpsc::Receiver<Datagram>,
) -> Result<(), Error> {
    // A configured proxy is not used; we only warn about it.
    if args.proxy.is_some() {
        warn!("Proxy not implemented yet");
    }
    // One task per configured remote; all of them share `handler_resources`.
    let mut jobs = JoinSet::new();
    for remote in &args.remote {
        jobs.spawn(handle_remote(remote, handler_resources));
    }
    // Resolves with the first fatal handler error, or `Ok(())` once every
    // handler task has finished without error.
    let check_listeners_future = async move {
        while let Some(result) = jobs.join_next().await {
            result.expect("JoinSet panicked (this is a bug)")?;
        }
        Ok::<(), Error>(())
    };
    // Connect, serve until the connection ends, and reconnect on retryable errors.
    let main_future = async move {
        let mut backoff = Backoff::new(
            Duration::from_millis(200),
            Duration::from_millis(args.max_retry_interval),
            2,
            args.max_retry_count,
        );
        // A stream request that failed on a previous connection;
        // `on_connected` replays it after reconnecting.
        let mut failed_stream_request: Option<StreamCommand> = None;
        loop {
            let r = ws_connect::handshake(args)
                .and_then(|ws_stream| {
                    on_connected(
                        args,
                        ws_stream,
                        &mut stream_command_rx,
                        &mut failed_stream_request,
                        &mut datagram_rx,
                        &handler_resources.udp_client_map,
                    )
                    // This closure only runs after a successful handshake, so
                    // an error here is a lost connection rather than a failed
                    // connect: start the backoff over.
                    .inspect_err(|_| backoff.reset())
                })
                .await;
            match r {
                Ok(()) => return Ok(()),
                Err(ref e) if !e.retryable() => return r,
                Err(e) => {
                    warn!("Connection failed: {e}");
                    let Some(current_retry_interval) = backoff.advance() else {
                        warn!("Max retry count reached, giving up");
                        return Err(Error::MaxRetryCountReached(Box::new(e)));
                    };
                    warn!("Reconnecting in {current_retry_interval:?}");
                    // Wait out the retry interval, but stop if Ctrl-C arrives
                    // first.
                    // NOTE(review): `is_ok()` is also true when `ctrl_c()`
                    // itself resolves with an error (signal handler setup
                    // failed) within the window. Confirm that treating this
                    // as a cancellation is intended.
                    if time::timeout(current_retry_interval, tokio::signal::ctrl_c())
                        .await
                        .is_ok()
                    {
                        return Err(Error::HandshakeCancelled);
                    }
                }
            }
        }
    };
    // `biased`: branches are polled in the order written, so a listener
    // outcome takes precedence over the never-ending prune task and the
    // connection loop.
    tokio::select! {
        biased;
        result = check_listeners_future => result,
        () = prune_client_id_map_task(handler_resources) => unreachable!("prune_client_id_map_task should never return"),
        result = main_future => result,
    }
}
/// Drives one established WebSocket connection until it ends.
///
/// Wraps `ws_stream` in a [`Multiplexor`] using the keepalive settings from
/// `args`. It first replays the stream request that failed on the previous
/// connection, if any, then reacts to:
/// - a multiplexor task finishing (an error ends the connection);
/// - a new stream request on `stream_command_rx`;
/// - an outgoing datagram on `datagram_rx` (send failures are only logged);
/// - an incoming datagram from the server, routed by its `flow_id` through
///   [`ClientIdMaps::send_datagram_reply`];
/// - Ctrl-C: drops the multiplexor, waits for its tasks, and returns `Ok(())`.
///
/// # Errors
/// Returns one of:
/// - any multiplexor task error;
/// - a failed or timed-out stream request, which is stored back into
///   `failed_stream_request` for the next connection;
/// - [`Error::ServerDisconnected`] once no `select!` branch can make progress.
#[tracing::instrument(skip_all, level = "debug")]
async fn on_connected(
    args: &ClientArgs,
    ws_stream: tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>,
    stream_command_rx: &mut mpsc::Receiver<StreamCommand>,
    failed_stream_request: &mut Option<StreamCommand>,
    datagram_rx: &mut mpsc::Receiver<Datagram>,
    udp_client_map: &Mutex<ClientIdMaps>,
) -> Result<(), Error> {
    // Handed to `Multiplexor::new`, which presumably spawns its background
    // tasks into it. Their completion is observed in the loop below.
    let mut mux_task_joinset = JoinSet::new();
    let options = penguin_mux::config::Options::new()
        .keepalive_interval(args.keepalive)
        .keepalive_timeout(args.keepalive_timeout);
    let mux = Multiplexor::new(ws_stream, Some(options), Some(&mut mux_task_joinset));
    info!("Connected to server");
    // Replay the request that failed on the previous connection before
    // accepting new ones.
    if let Some(sender) = failed_stream_request.take() {
        get_send_stream_chan(&mux, sender, failed_stream_request, args.channel_timeout).await?;
    }
    loop {
        tokio::select! {
            // A multiplexor task ended; propagate its error, if any.
            Some(mux_task_joinset_result) = mux_task_joinset.join_next() => {
                mux_task_joinset_result.expect("Task panicked (this is a bug)")?;
            }
            // A remote handler asked for a new stream.
            Some(sender) = stream_command_rx.recv() => {
                get_send_stream_chan(&mux, sender, failed_stream_request, args.channel_timeout).await?;
            }
            // Outgoing datagram; a failure here is logged, not fatal.
            Some(datagram) = datagram_rx.recv() => {
                if let Err(e) = mux.send_datagram(datagram).await {
                    error!("{e}");
                }
            }
            // Incoming datagram; its `flow_id` is the local client ID.
            Ok(dgram_frame) = mux.get_datagram() => {
                let client_id = dgram_frame.flow_id;
                let data = dgram_frame.data;
                match ClientIdMaps::send_datagram_reply(udp_client_map, client_id, data.as_ref()).await {
                    Some(Ok(())) => trace!("sent datagram to client {client_id:08x}"),
                    Some(Err(e)) => warn!("Failed to send datagram to client {client_id:08x}: {e}"),
                    None => info!("Received datagram for unknown client ID: {client_id:08x}"),
                }
            }
            Ok(()) = tokio::signal::ctrl_c() => {
                info!("Received Ctrl-C, exiting once all streams are closed");
                // NOTE(review): presumably dropping the multiplexor lets its
                // tasks wind down; we then wait for every one of them.
                drop(mux);
                while let Some(result) = mux_task_joinset.join_next().await {
                    result.expect("Task panicked (this is a bug)")?;
                }
                return Ok(());
            }
            // Every branch above is disabled: its future produced a value
            // that did not match its pattern.
            else => return Err(Error::ServerDisconnected),
        }
    }
}
#[tracing::instrument(skip_all, level = "trace")]
async fn get_send_stream_chan(
mux: &Multiplexor,
stream_command: StreamCommand,
failed_stream_request: &mut Option<StreamCommand>,
channel_timeout: OptionalDuration,
) -> Result<(), Error> {
trace!("requesting a new TCP channel");
match channel_timeout
.timeout(mux.new_stream_channel(&stream_command.host, stream_command.port))
.await
{
Ok(Ok(stream)) => {
trace!("got a new channel");
stream_command.tx.send(stream).ok();
trace!("sent stream to handler (or handler died)");
Ok(())
}
Ok(Err(e)) => {
failed_stream_request.replace(stream_command);
Err(e.into())
}
Err(_) => {
failed_stream_request.replace(stream_command);
Err(Error::StreamRequestTimeout)
}
}
}
/// Periodically drops UDP clients whose expiry has passed.
///
/// Ticks every `config::UDP_PRUNE_TIMEOUT`; tokio's `interval` completes its
/// first tick immediately. Never returns, so it is meant to be raced against
/// other futures.
#[tracing::instrument(skip_all, level = "trace")]
async fn prune_client_id_map_task(handler_resources: &HandlerResources) {
    let mut ticker = time::interval(config::UDP_PRUNE_TIMEOUT);
    // If a tick is missed, shift the schedule instead of bursting to catch up.
    ticker.set_missed_tick_behavior(time::MissedTickBehavior::Delay);
    loop {
        let _ = ticker.tick().await;
        handler_resources.prune_udp_clients();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, SocketAddr};

    /// Builds handler resources backed by 1-slot stub channels and an empty
    /// client map.
    ///
    /// The receivers are returned so the channels stay open for the whole
    /// test.
    fn stub_handler_resources() -> (
        HandlerResources,
        mpsc::Receiver<StreamCommand>,
        mpsc::Receiver<Datagram>,
    ) {
        let (stream_command_tx, stream_command_rx) = mpsc::channel(1);
        let (datagram_tx, datagram_rx) = mpsc::channel(1);
        let resources = HandlerResources {
            stream_command_tx,
            datagram_tx,
            udp_client_map: Arc::new(Mutex::new(ClientIdMaps::new())),
        };
        (resources, stream_command_rx, datagram_rx)
    }

    /// Returns `127.0.0.1:port`.
    fn loopback(port: u16) -> SocketAddr {
        SocketAddr::from((IpAddr::from([127, 0, 0, 1]), port))
    }

    /// Binds a UDP socket to an ephemeral loopback port.
    async fn bind_stub_socket() -> Arc<UdpSocket> {
        let socket = UdpSocket::bind(("127.0.0.1", 0)).await.unwrap();
        Arc::new(socket)
    }

    #[tokio::test]
    async fn test_client_map_add_client() {
        crate::tests::setup_logging();
        let (resources, _stream_rx, _datagram_rx) = stub_handler_resources();
        let socket_a = bind_stub_socket().await;

        let first = resources.add_udp_client(loopback(1234), socket_a.dupe(), false);
        // Same peer through the same socket: the existing ID is reused.
        let again = resources.add_udp_client(loopback(1234), socket_a.dupe(), false);
        assert_eq!(first, again);

        // Same peer through a different local socket: a distinct client.
        let socket_b = bind_stub_socket().await;
        let via_other_socket = resources.add_udp_client(loopback(1234), socket_b, false);
        assert_ne!(first, via_other_socket);

        // Different peer port through the original socket: also distinct.
        let other_port = resources.add_udp_client(loopback(1235), socket_a.dupe(), false);
        assert_ne!(first, other_port);
    }

    #[tokio::test]
    async fn test_client_map_remove_client() {
        crate::tests::setup_logging();
        let (resources, _stream_rx, _datagram_rx) = stub_handler_resources();
        let socket = bind_stub_socket().await;
        let _ = resources.add_udp_client(loopback(1234), socket.dupe(), false);

        // Wait a full timeout so the entry's expiry is no longer in the future.
        tokio::time::sleep(config::UDP_PRUNE_TIMEOUT).await;
        resources.prune_udp_clients();

        let maps = resources.udp_client_map.lock();
        assert!(maps.client_id_map.is_empty());
    }
}