use anyhow::{Context, Result};
use futures::FutureExt;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::net::{TcpStream, UdpSocket};
/// How long a UDP session's outbound socket may wait for a datagram from the destination
/// before the session is reaped.
const UDP_SESSION_IDLE: Duration = Duration::from_secs(60);
/// Size of every UDP receive buffer: 64 KiB, larger than any non-jumbogram UDP payload.
const UDP_BUF_BYTES: usize = 64 * 1024;
/// Relays one accepted TCP connection to `proxy_addr` until both directions are finished.
///
/// Opens a fresh connection to `proxy_addr`, then copies bytes in both directions
/// concurrently. When one direction reaches EOF, the write half it was feeding is shut down
/// so that peer also sees EOF, while the opposite direction keeps flowing.
///
/// # Errors
///
/// Returns an error (annotated with `proxy_addr`) if the outbound connect fails, or if any
/// copy or shutdown fails; the first failing direction cancels the other.
pub async fn transfer(mut inbound: TcpStream, proxy_addr: SocketAddr) -> Result<()> {
    let mut outbound = TcpStream::connect(proxy_addr)
        .await
        .with_context(|| format!("connecting TCP outbound to {proxy_addr}"))?;
    let (mut ri, mut wi) = inbound.split();
    let (mut ro, mut wo) = outbound.split();
    // Each direction half-closes its destination once its source is drained, so EOF is
    // propagated without tearing down the still-active opposite direction.
    let client_to_server = async {
        tokio::io::copy(&mut ri, &mut wo).await?;
        wo.shutdown().await
    };
    let server_to_client = async {
        tokio::io::copy(&mut ro, &mut wi).await?;
        wi.shutdown().await
    };
    tokio::try_join!(client_to_server, server_to_client)
        .with_context(|| format!("relaying TCP stream to {proxy_addr}"))?;
    Ok(())
}
pub async fn proxy_handler(listen_addr: SocketAddr, dest_addr: SocketAddr) -> Result<()> {
tracing::debug!("Proxying TCP traffic: {} -> {}", listen_addr, dest_addr);
let listener = tokio::net::TcpListener::bind(&listen_addr).await?;
while let Ok((inbound, _)) = listener.accept().await {
let transfer = transfer(inbound, dest_addr).map(|r| {
if let Err(e) = r {
tracing::warn!("Proxy connection dropped, creating new handler: {}", e);
}
});
tokio::spawn(transfer);
}
Ok(())
}
type UdpSessions = Arc<Mutex<HashMap<SocketAddr, Arc<UdpSocket>>>>;
pub async fn udp_handler(listen_addr: SocketAddr, dest_addr: SocketAddr) -> Result<()> {
tracing::debug!("Proxying UDP traffic: {} -> {}", listen_addr, dest_addr);
let listener = Arc::new(
UdpSocket::bind(listen_addr)
.await
.with_context(|| format!("binding UDP listener {listen_addr}"))?,
);
let sessions: UdpSessions = Arc::new(Mutex::new(HashMap::new()));
let mut buf = vec![0u8; UDP_BUF_BYTES];
loop {
let (len, src) = match listener.recv_from(&mut buf).await {
Ok(t) => t,
Err(e) => {
tracing::warn!("UDP listener recv on {listen_addr}: {e}");
continue;
}
};
let cached = sessions.lock().unwrap().get(&src).cloned();
let outbound = match cached {
Some(s) => s,
None => {
let s = UdpSocket::bind("0.0.0.0:0")
.await
.context("binding ephemeral UDP outbound")?;
s.connect(dest_addr)
.await
.with_context(|| format!("connecting UDP outbound to {dest_addr}"))?;
let s = Arc::new(s);
sessions.lock().unwrap().insert(src, s.clone());
spawn_udp_reply_pump(s.clone(), listener.clone(), src, sessions.clone());
s
}
};
if let Err(e) = outbound.send(&buf[..len]).await {
tracing::warn!("UDP forward {src} -> {dest_addr}: {e}");
let mut guard = sessions.lock().unwrap();
if let Some(current) = guard.get(&src) {
if Arc::ptr_eq(current, &outbound) {
guard.remove(&src);
}
}
}
}
}
/// Spawns a detached task that relays datagrams arriving on `outbound` back to
/// `client_src` through the shared `listener` socket.
///
/// The task stops when any of the following happens:
/// - nothing arrives on `outbound` for `UDP_SESSION_IDLE`;
/// - receiving on `outbound` fails;
/// - sending the reply to the client fails.
///
/// On exit it removes the `client_src` entry from `sessions`, but only if that entry is
/// still this task's socket. A newer session for the same client is left alone.
fn spawn_udp_reply_pump(
    outbound: Arc<UdpSocket>,
    listener: Arc<UdpSocket>,
    client_src: SocketAddr,
    sessions: UdpSessions,
) {
    tokio::spawn(async move {
        let mut reply = vec![0u8; UDP_BUF_BYTES];
        loop {
            let n = match tokio::time::timeout(UDP_SESSION_IDLE, outbound.recv(&mut reply)).await
            {
                Err(_elapsed) => {
                    tracing::trace!("UDP session {client_src} idle, reaping");
                    break;
                }
                Ok(Err(e)) => {
                    tracing::warn!("UDP outbound recv from session {client_src}: {e}");
                    break;
                }
                Ok(Ok(n)) => n,
            };
            if let Err(e) = listener.send_to(&reply[..n], client_src).await {
                tracing::warn!("UDP reply to {client_src}: {e}");
                break;
            }
        }
        // Pointer identity guards against removing a replacement session for this client.
        let mut table = sessions.lock().unwrap();
        if matches!(table.get(&client_src), Some(live) if Arc::ptr_eq(live, &outbound)) {
            table.remove(&client_src);
        }
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::timeout;

    /// Returns a loopback UDP address that was free a moment ago.
    ///
    /// The probe socket is dropped on return, so the port could in principle be taken before
    /// the proxy binds it; acceptable for tests.
    async fn free_udp_addr() -> SocketAddr {
        let probe = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        probe.local_addr().unwrap()
    }

    /// Starts `udp_handler` towards `dest` on a fresh loopback port.
    ///
    /// Returns the proxy's listen address and its task handle.
    async fn start_udp_proxy(
        dest: SocketAddr,
    ) -> (SocketAddr, tokio::task::JoinHandle<anyhow::Result<()>>) {
        let listen_addr = free_udp_addr().await;
        let proxy_task = tokio::spawn(udp_handler(listen_addr, dest));
        // The proxy binds inside the spawned task; give it a moment before sending to it.
        tokio::time::sleep(Duration::from_millis(50)).await;
        (listen_addr, proxy_task)
    }

    #[tokio::test]
    async fn udp_handler_round_trips_a_datagram() {
        let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server_addr = server.local_addr().unwrap();
        let server_task = tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            let (n, src) = server.recv_from(&mut buf).await.unwrap();
            server.send_to(&buf[..n], src).await.unwrap();
        });
        let (listen_addr, proxy_task) = start_udp_proxy(server_addr).await;
        let client = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        client.send_to(b"hello-udp", listen_addr).await.unwrap();
        let mut buf = vec![0u8; 1500];
        let (n, from) = timeout(Duration::from_secs(2), client.recv_from(&mut buf))
            .await
            .expect("timed out waiting for UDP echo")
            .unwrap();
        assert_eq!(&buf[..n], b"hello-udp");
        // The reply must come back through the proxy's listener, not straight from the server.
        assert_eq!(from, listen_addr);
        proxy_task.abort();
        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn udp_handler_isolates_sessions_per_client() {
        let server = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let server_addr = server.local_addr().unwrap();
        // Echo server that prefixes each reply with the source port it observed, i.e. the
        // proxy's outbound socket for that client.
        let server_task = tokio::spawn(async move {
            let mut buf = vec![0u8; 1500];
            for _ in 0..2 {
                let (n, src) = server.recv_from(&mut buf).await.unwrap();
                let mut reply = format!("from-{}: ", src.port()).into_bytes();
                reply.extend_from_slice(&buf[..n]);
                server.send_to(&reply, src).await.unwrap();
            }
        });
        let (listen_addr, proxy_task) = start_udp_proxy(server_addr).await;
        let client_a = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let client_b = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        client_a.send_to(b"A", listen_addr).await.unwrap();
        client_b.send_to(b"B", listen_addr).await.unwrap();
        let mut a_buf = vec![0u8; 1500];
        let mut b_buf = vec![0u8; 1500];
        let (an, _) = timeout(Duration::from_secs(2), client_a.recv_from(&mut a_buf))
            .await
            .expect("client A timed out")
            .unwrap();
        let (bn, _) = timeout(Duration::from_secs(2), client_b.recv_from(&mut b_buf))
            .await
            .expect("client B timed out")
            .unwrap();
        let a_reply = &a_buf[..an];
        let b_reply = &b_buf[..bn];
        let a_prefix = a_reply
            .strip_suffix(b"A")
            .unwrap_or_else(|| panic!("A got: {:?}", a_reply));
        let b_prefix = b_reply
            .strip_suffix(b"B")
            .unwrap_or_else(|| panic!("B got: {:?}", b_reply));
        // Different `from-<port>` prefixes prove each client had its own outbound socket.
        assert_ne!(a_prefix, b_prefix, "clients A and B shared one outbound socket");
        proxy_task.abort();
        server_task.await.unwrap();
    }
}