use std::io;
use std::net::{IpAddr, SocketAddr, TcpStream, UdpSocket};
use std::time::Duration;
use crate::error::{Error, Result};
use crate::net::socks;
/// Largest SOCKS5 UDP request header this module produces or accepts:
/// RSV (2) + FRAG (1) + ATYP (1) + IPv6 address (16) + port (2).
///
/// Domain-name addresses (ATYP 0x03) are never encoded here and are rejected
/// when decoded, so they do not need to be accounted for.
const MAX_UDP_HEADER: usize = 2 + 1 + 1 + 16 + 2;
/// Minimal datagram-socket interface, so UDP-based protocols can run either
/// over a plain local socket or through a SOCKS5 UDP relay.
///
/// Method semantics follow the same-named methods on [`std::net::UdpSocket`].
pub(crate) trait UdpTransport: Send {
    /// Sends `buf` as one datagram addressed to `peer` and returns the number
    /// of payload bytes sent.
    fn send_to(&self, buf: &[u8], peer: SocketAddr) -> io::Result<usize>;

    /// Receives one datagram into `buf` and returns the number of bytes
    /// written along with the address of the datagram's sender.
    fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)>;

    /// Sets (or clears, with `None`) the timeout for blocking sends.
    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()>;

    /// Sets (or clears, with `None`) the timeout for blocking receives.
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()>;

    /// Returns the local address of the underlying socket.
    // Only the tests in this file call this method, so non-test builds
    // would otherwise emit a dead-code warning for it.
    #[allow(dead_code)]
    fn local_addr(&self) -> io::Result<SocketAddr>;
}
/// How UDP traffic should reach its destination.
#[derive(Clone)]
pub enum UdpProxy {
    /// Send datagrams straight from a local socket, with no proxy.
    Direct,
    /// Tunnel datagrams through a SOCKS5 proxy's UDP relay.
    Socks5 {
        /// Proxy host name or IP literal.
        host: String,
        /// Proxy TCP port.
        port: u16,
        /// Optional `(username, password)` credentials.
        auth: Option<(String, String)>,
    },
    /// A proxy is configured but it cannot carry UDP; `open_udp_transport`
    /// rejects this variant with an error.
    Unsupported,
}

/// Hand-written rather than derived so the SOCKS5 password never reaches
/// logs or panic messages. A derived `Debug` would print it verbatim.
impl std::fmt::Debug for UdpProxy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            UdpProxy::Direct => f.write_str("Direct"),
            UdpProxy::Socks5 { host, port, auth } => f
                .debug_struct("Socks5")
                .field("host", host)
                .field("port", port)
                .field(
                    "auth",
                    &auth.as_ref().map(|(user, _password)| (user, "<redacted>")),
                )
                .finish(),
            UdpProxy::Unsupported => f.write_str("Unsupported"),
        }
    }
}
pub(crate) fn open_udp_transport(
proxy: UdpProxy,
peer: SocketAddr,
) -> Result<Box<dyn UdpTransport>> {
match proxy {
UdpProxy::Direct => Ok(Box::new(DirectUdp::bind_for(peer)?)),
UdpProxy::Socks5 { host, port, auth } => {
let auth_ref = auth.as_ref().map(|(u, p)| (u.as_str(), p.as_str()));
Ok(Box::new(Socks5UdpTransport::connect(
&host, port, auth_ref,
)?))
}
UdpProxy::Unsupported => Err(Error::UnsupportedScheme(
"this proxy cannot tunnel UDP; HTTP/3 and TFTP need a direct \
connection or a SOCKS5 proxy"
.into(),
)),
}
}
/// UDP transport that talks to peers directly through an unconnected local
/// socket, with no proxy involved.
pub(crate) struct DirectUdp {
    /// Wildcard-bound socket whose address family matches the peer given to
    /// [`DirectUdp::bind_for`].
    sock: UdpSocket,
}

impl DirectUdp {
    /// Binds a fresh socket on an OS-chosen port. It uses the wildcard address
    /// of `peer`'s family: `0.0.0.0` for IPv4, `::` for IPv6.
    ///
    /// # Errors
    /// Propagates any error from [`UdpSocket::bind`].
    pub(crate) fn bind_for(peer: SocketAddr) -> io::Result<Self> {
        let wildcard = match peer {
            SocketAddr::V4(_) => "0.0.0.0:0",
            SocketAddr::V6(_) => "[::]:0",
        };
        UdpSocket::bind(wildcard).map(|sock| DirectUdp { sock })
    }
}
impl UdpTransport for DirectUdp {
fn send_to(&self, buf: &[u8], peer: SocketAddr) -> io::Result<usize> {
self.sock.send_to(buf, peer)
}
fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.sock.recv_from(buf)
}
fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.sock.set_read_timeout(dur)
}
fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
self.sock.set_write_timeout(dur)
}
fn local_addr(&self) -> io::Result<SocketAddr> {
self.sock.local_addr()
}
}
pub(crate) struct Socks5UdpTransport {
_control: TcpStream,
relay: UdpSocket,
}
impl Socks5UdpTransport {
pub(crate) fn connect(host: &str, port: u16, auth: Option<(&str, &str)>) -> Result<Self> {
let mut control = TcpStream::connect((host, port))?;
control.set_read_timeout(Some(Duration::from_secs(30)))?;
control.set_write_timeout(Some(Duration::from_secs(30)))?;
socks::socks5_negotiate(&mut control, auth)?;
let mut relay_addr = socks::socks5_request(&mut control, 0x03, "0.0.0.0", 0, false)?;
if relay_addr.ip().is_unspecified() {
let proxy_ip = control.peer_addr()?.ip();
relay_addr = SocketAddr::new(proxy_ip, relay_addr.port());
}
let bind = if relay_addr.is_ipv4() {
"0.0.0.0:0"
} else {
"[::]:0"
};
let relay = UdpSocket::bind(bind)?;
relay.connect(relay_addr)?;
control.set_read_timeout(None)?;
control.set_write_timeout(None)?;
Ok(Socks5UdpTransport {
_control: control,
relay,
})
}
}
/// Datagram I/O over the SOCKS5 relay. Each outgoing datagram gets a SOCKS5
/// UDP request header; each incoming one has its header stripped.
impl UdpTransport for Socks5UdpTransport {
    /// Wraps `buf` in a header addressed to `peer` and sends it to the relay.
    ///
    /// On success it reports `buf.len()`, so callers count payload bytes,
    /// not header bytes.
    fn send_to(&self, buf: &[u8], peer: SocketAddr) -> io::Result<usize> {
        let framed = encode_udp_header(peer, buf);
        self.relay.send(&framed).map(|_| buf.len())
    }

    /// Receives one datagram from the relay, strips its header and copies as
    /// much payload as fits into `buf`.
    ///
    /// The returned address is the original sender named in the header, not
    /// the relay. A datagram with an undecodable header yields an
    /// `InvalidData` error.
    fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        // Leave room for the largest accepted header on top of the caller's
        // buffer.
        let mut wire = vec![0u8; buf.len().saturating_add(MAX_UDP_HEADER)];
        let received = self.relay.recv(&mut wire)?;
        let (source, payload) = match decode_udp_header(&wire[..received]) {
            Ok(decoded) => decoded,
            Err(reason) => return Err(io::Error::new(io::ErrorKind::InvalidData, reason)),
        };
        // Payload that does not fit is discarded, as with a plain UDP recv.
        let copied = payload.len().min(buf.len());
        buf[..copied].copy_from_slice(&payload[..copied]);
        Ok((copied, source))
    }

    /// Applies to the relay UDP socket only.
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.relay.set_read_timeout(dur)
    }

    /// Applies to the relay UDP socket only.
    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        self.relay.set_write_timeout(dur)
    }

    /// Local address of the relay-facing UDP socket.
    fn local_addr(&self) -> io::Result<SocketAddr> {
        self.relay.local_addr()
    }
}
fn encode_udp_header(dst: SocketAddr, data: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(MAX_UDP_HEADER + data.len());
out.extend_from_slice(&[0x00, 0x00, 0x00]); match dst.ip() {
IpAddr::V4(v4) => {
out.push(0x01);
out.extend_from_slice(&v4.octets());
}
IpAddr::V6(v6) => {
out.push(0x04);
out.extend_from_slice(&v6.octets());
}
}
out.extend_from_slice(&dst.port().to_be_bytes());
out.extend_from_slice(data);
out
}
/// Parses a SOCKS5 UDP request header from the front of `buf`.
///
/// Returns the sender address named in the header and the payload that
/// follows it. RSV bytes are not checked.
///
/// # Errors
/// Returns a descriptive message in any of these cases:
/// - the datagram is shorter than the fixed 4-byte prefix;
/// - FRAG is nonzero;
/// - the IPv4 or IPv6 address or the port is truncated;
/// - ATYP is a domain name (0x03) or unknown.
fn decode_udp_header(buf: &[u8]) -> std::result::Result<(SocketAddr, &[u8]), String> {
    // Fixed prefix: RSV (2 bytes, ignored), FRAG (1), ATYP (1).
    let (frag, atyp, body) = match buf {
        [_, _, frag, atyp, body @ ..] => (*frag, *atyp, body),
        _ => return Err("socks5 udp: datagram too short".into()),
    };
    if frag != 0x00 {
        return Err("socks5 udp: fragmentation not supported".into());
    }
    // Each address arm also requires the 2-byte port that follows it.
    let (ip, after_addr): (IpAddr, &[u8]) = match atyp {
        0x01 if body.len() >= 4 + 2 => {
            let mut octets = [0u8; 4];
            octets.copy_from_slice(&body[..4]);
            (IpAddr::from(octets), &body[4..])
        }
        0x01 => return Err("socks5 udp: truncated IPv4 header".into()),
        0x04 if body.len() >= 16 + 2 => {
            let mut octets = [0u8; 16];
            octets.copy_from_slice(&body[..16]);
            (IpAddr::from(octets), &body[16..])
        }
        0x04 => return Err("socks5 udp: truncated IPv6 header".into()),
        0x03 => return Err("socks5 udp: domain address not allowed in reply".into()),
        other => return Err(format!("socks5 udp: unknown ATYP {other:#04x}")),
    };
    // The guards above ensure at least two bytes remain for the port.
    let (port_bytes, payload) = after_addr.split_at(2);
    let port = u16::from_be_bytes([port_bytes[0], port_bytes[1]]);
    Ok((SocketAddr::new(ip, port), payload))
}
#[cfg(test)]
mod tests {
    // Unit tests for the SOCKS5 UDP header codec, plus two loopback
    // end-to-end tests. Each end-to-end test plays the proxy with a TCP
    // listener (control channel) and a UDP socket (relay).
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Encoding an IPv4 destination yields RSV/FRAG zeros plus ATYP 0x01, and
    /// decoding gives back the same address and payload.
    #[test]
    fn header_roundtrip_ipv4() {
        let dst = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(93, 184, 216, 34)), 443);
        let dgram = encode_udp_header(dst, b"quic-bytes");
        assert_eq!(&dgram[0..4], &[0x00, 0x00, 0x00, 0x01]);
        let (src, data) = decode_udp_header(&dgram).unwrap();
        assert_eq!(src, dst);
        assert_eq!(data, b"quic-bytes");
    }

    /// Same round trip for IPv6, which uses ATYP 0x04.
    #[test]
    fn header_roundtrip_ipv6() {
        let dst = SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 8080);
        let dgram = encode_udp_header(dst, b"x");
        assert_eq!(dgram[3], 0x04);
        let (src, data) = decode_udp_header(&dgram).unwrap();
        assert_eq!(src, dst);
        assert_eq!(data, b"x");
    }

    /// A nonzero FRAG byte (index 2) must be rejected.
    #[test]
    fn decode_rejects_fragmentation() {
        let mut dgram =
            encode_udp_header(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 1), b"d");
        dgram[2] = 0x01;
        assert!(decode_udp_header(&dgram).is_err());
    }

    /// Domain-name addresses (ATYP 0x03) are refused in relay replies.
    #[test]
    fn decode_rejects_domain_atyp() {
        let dgram = [0x00, 0x00, 0x00, 0x03, 0x01, b'x', 0x00, 0x50];
        assert!(decode_udp_header(&dgram).is_err());
    }

    /// Inputs shorter than the fixed prefix, or with a cut-off IPv4 address
    /// or port, are errors rather than panics.
    #[test]
    fn decode_rejects_truncated() {
        assert!(decode_udp_header(&[0x00, 0x00]).is_err());
        assert!(decode_udp_header(&[0x00, 0x00, 0x00, 0x01, 1, 2]).is_err());
    }

    /// End to end: UDP ASSOCIATE over a fake proxy, then one request/response
    /// datagram pair through the fake relay.
    #[test]
    fn socks5_udp_associate_roundtrip() {
        use std::io::{Read, Write};
        use std::net::{Ipv4Addr, TcpListener};
        use std::thread;
        let control = TcpListener::bind("127.0.0.1:0").unwrap();
        let control_addr = control.local_addr().unwrap();
        let relay = UdpSocket::bind("127.0.0.1:0").unwrap();
        let relay_addr = relay.local_addr().unwrap();
        // Documentation-range address (203.0.113.0/24); it is never contacted,
        // it only travels inside SOCKS5 headers.
        let server = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 4433);
        // Fake proxy control channel.
        let ctrl = thread::spawn(move || {
            let (mut s, _) = control.accept().unwrap();
            // NOTE(review): assumes socks5_negotiate sends exactly 3 greeting
            // bytes (VER, NMETHODS=1, METHOD) when auth is None. Confirm
            // against the socks module.
            let mut greet = [0u8; 3];
            s.read_exact(&mut greet).unwrap();
            // Select method 0x00 (no authentication).
            s.write_all(&[0x05, 0x00]).unwrap();
            // NOTE(review): assumes the request for "0.0.0.0":0 is encoded as
            // an IPv4 address: VER CMD RSV ATYP + 4 address + 2 port = 10 bytes.
            let mut req = [0u8; 10];
            s.read_exact(&mut req).unwrap();
            assert_eq!(req[1], 0x03, "expected UDP ASSOCIATE");
            // Success reply whose BND.ADDR/BND.PORT is the fake relay socket.
            let mut reply = vec![0x05, 0x00, 0x00, 0x01];
            match relay_addr.ip() {
                IpAddr::V4(v4) => reply.extend_from_slice(&v4.octets()),
                IpAddr::V6(_) => unreachable!(),
            }
            reply.extend_from_slice(&relay_addr.port().to_be_bytes());
            s.write_all(&reply).unwrap();
            // Block until the client drops its control connection, keeping
            // the association alive for the rest of the test.
            let mut sink = [0u8; 1];
            let _ = s.read(&mut sink);
        });
        // Fake relay: check the encapsulated request, then answer with an
        // encapsulated reply that names `server` as the source.
        let relay_thread = thread::spawn(move || {
            let mut buf = [0u8; 2048];
            let (n, client) = relay.recv_from(&mut buf).unwrap();
            let (dst, data) = decode_udp_header(&buf[..n]).unwrap();
            assert_eq!(dst, server);
            assert_eq!(data, b"ping");
            let out = encode_udp_header(server, b"pong");
            relay.send_to(&out, client).unwrap();
        });
        let t = Socks5UdpTransport::connect("127.0.0.1", control_addr.port(), None).unwrap();
        t.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
        t.send_to(b"ping", server).unwrap();
        let mut buf = [0u8; 1024];
        let (n, src) = t.recv_from(&mut buf).unwrap();
        assert_eq!(&buf[..n], b"pong");
        assert_eq!(src, server, "recv_from must report the decapsulated source");
        // Dropping the transport closes the control connection, which
        // unblocks the control thread's final read.
        drop(t);
        relay_thread.join().unwrap();
        ctrl.join().unwrap();
    }

    /// A datagram sent to the transport's local port from a socket other than
    /// the relay must not be delivered. The relay socket is `connect`ed, so
    /// the OS filters it, and `recv_from` is expected to time out.
    #[test]
    fn socks5_udp_rejects_foreign_source() {
        use std::io::{Read, Write};
        use std::net::{Ipv4Addr, TcpListener};
        use std::thread;
        let control = TcpListener::bind("127.0.0.1:0").unwrap();
        let control_addr = control.local_addr().unwrap();
        // The relay socket only supplies an address for the ASSOCIATE reply;
        // it never sends anything in this test.
        let relay = UdpSocket::bind("127.0.0.1:0").unwrap();
        let relay_addr = relay.local_addr().unwrap();
        // Same fake control channel as in socks5_udp_associate_roundtrip.
        let ctrl = thread::spawn(move || {
            let (mut s, _) = control.accept().unwrap();
            let mut greet = [0u8; 3];
            s.read_exact(&mut greet).unwrap();
            s.write_all(&[0x05, 0x00]).unwrap();
            let mut req = [0u8; 10];
            s.read_exact(&mut req).unwrap();
            assert_eq!(req[1], 0x03, "expected UDP ASSOCIATE");
            let mut reply = vec![0x05, 0x00, 0x00, 0x01];
            match relay_addr.ip() {
                IpAddr::V4(v4) => reply.extend_from_slice(&v4.octets()),
                IpAddr::V6(_) => unreachable!(),
            }
            reply.extend_from_slice(&relay_addr.port().to_be_bytes());
            s.write_all(&reply).unwrap();
            let mut sink = [0u8; 1];
            let _ = s.read(&mut sink);
        });
        let t = Socks5UdpTransport::connect("127.0.0.1", control_addr.port(), None).unwrap();
        let victim = t.local_addr().unwrap();
        t.set_read_timeout(Some(Duration::from_millis(300)))
            .unwrap();
        let server = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 9)), 4433);
        // A well-formed SOCKS5 datagram from an unrelated socket ("attacker").
        let attacker = UdpSocket::bind("127.0.0.1:0").unwrap();
        let forged = encode_udp_header(server, b"evil");
        attacker
            .send_to(
                &forged,
                SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), victim.port()),
            )
            .unwrap();
        let mut buf = [0u8; 1024];
        match t.recv_from(&mut buf) {
            // A read timeout surfaces as WouldBlock or TimedOut depending on
            // the platform.
            Err(e) => assert!(
                matches!(
                    e.kind(),
                    io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut
                ),
                "expected a timeout, got {e:?}"
            ),
            Ok((n, src)) => panic!("forged datagram accepted: {n} bytes from {src}"),
        }
        drop(t);
        ctrl.join().unwrap();
    }
}