use super::FatalError;
use crate::client::HandlerResources;
use crate::config;
use bytes::Bytes;
use penguin_mux::{Datagram, Dupe};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::net::UdpSocket;
use tracing::{info, trace};
/// Forward UDP datagrams received on a local socket to a remote target.
///
/// Binds a UDP socket on `lhost:lport`, then loops: each datagram received
/// is wrapped in a [`Datagram`] addressed to `rhost:rport` and sent through
/// `handler_resources.datagram_tx`. The sender's address is registered with
/// `add_udp_client`, and the id it returns is used as the frame's `flow_id`.
///
/// # Errors
///
/// - [`FatalError::ClientIo`] if binding the socket or receiving from it fails.
/// - [`FatalError::SendDatagram`] if `datagram_tx` rejects the frame.
///
/// The loop has no exit other than an error, so this never returns `Ok`.
#[tracing::instrument(skip(handler_resources), level = "debug")]
pub(super) async fn handle_udp(
    lhost: &'static str,
    lport: u16,
    rhost: &'static str,
    rport: u16,
    handler_resources: &HandlerResources,
) -> Result<(), FatalError> {
    let socket = UdpSocket::bind((lhost, lport))
        .await
        .map_err(FatalError::ClientIo)?;
    let socket = Arc::new(socket);
    let local_addr = socket
        .local_addr()
        .expect("Failed to get local address of UDP socket (this is a bug)");
    info!("Bound on {local_addr}");
    // A single receive buffer is reused for every datagram instead of
    // allocating and zero-filling a fresh `MAX_UDP_PACKET_SIZE` Vec per packet;
    // only the bytes actually received are copied into each frame.
    let mut buf = vec![0u8; config::MAX_UDP_PACKET_SIZE];
    loop {
        let (len, addr) = socket
            .recv_from(&mut buf)
            .await
            .map_err(FatalError::ClientIo)?;
        trace!("received {len} bytes from {addr}");
        // NOTE(review): the meaning of the trailing `false` argument is defined
        // by `add_udp_client`, which is not visible here.
        let client_id = handler_resources.add_udp_client(addr, socket.dupe(), false);
        let frame = Datagram {
            target_host: Bytes::from_static(rhost.as_bytes()),
            target_port: rport,
            flow_id: client_id,
            // Copy out just the payload so `buf` can be reused next iteration.
            data: Bytes::copy_from_slice(&buf[..len]),
        };
        handler_resources
            .datagram_tx
            .send(frame)
            .await
            .map_err(|_| FatalError::SendDatagram)?;
    }
}
#[tracing::instrument(skip(handler_resources), level = "debug")]
pub(super) async fn handle_udp_stdio(
rhost: &'static str,
rport: u16,
handler_resources: &HandlerResources,
) -> Result<(), FatalError> {
let mut stdin = BufReader::new(tokio::io::stdin());
loop {
let mut line = String::new();
stdin
.read_line(&mut line)
.await
.map_err(FatalError::ClientIo)?;
let frame = Datagram {
target_host: Bytes::from_static(rhost.as_bytes()),
target_port: rport,
flow_id: 0,
data: Bytes::from(line),
};
handler_resources
.datagram_tx
.send(frame)
.await
.or(Err(FatalError::SendDatagram))?;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::ClientIdMaps;
    use parking_lot::Mutex;

    /// Sends one datagram to a running `handle_udp` forwarder and checks the
    /// frame that comes out of `datagram_tx`, including its `flow_id` mapping.
    #[tokio::test]
    async fn test_handle_udp() {
        static LHOST: &str = "127.0.0.1";
        static RHOST: &str = "127.0.0.1";
        crate::tests::setup_logging();

        // Wire up the handler resources with a capturing datagram channel.
        let (dgram_tx, mut dgram_rx) = tokio::sync::mpsc::channel(1);
        let (cmd_tx, _) = tokio::sync::mpsc::channel(1);
        let client_maps = Arc::new(Mutex::new(ClientIdMaps::new()));
        let resources = HandlerResources {
            datagram_tx: dgram_tx,
            stream_command_tx: cmd_tx,
            udp_client_map: client_maps.dupe(),
        };
        let forwarder = tokio::spawn(async move {
            handle_udp(LHOST, 14196, RHOST, 255, &resources).await
        });

        // Play the role of a local UDP client.
        let sender = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let sender_addr = sender.local_addr().unwrap();
        sender.connect("127.0.0.1:14196").await.unwrap();
        sender.send(b"hello").await.unwrap();

        // The forwarded frame must carry the target and payload unchanged.
        let received = dgram_rx.recv().await.unwrap();
        assert_eq!(received.target_host, RHOST.as_bytes());
        assert_eq!(received.target_port, 255);
        assert_eq!(*received.data, *b"hello");

        // And its flow id must be the one registered for this sender.
        let expected_id = *client_maps
            .lock()
            .client_addr_map
            .get(&(sender_addr, ([127, 0, 0, 1], 14196).into()))
            .unwrap();
        assert_eq!(received.flow_id, expected_id);

        forwarder.abort();
    }
}