use super::Filter;
use anyhow::{Context, Result, bail};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tracing::{debug, warn};
/// Deadline that `handle_one` applies (via `tokio::time::timeout`) to serving a connection.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10);
/// Largest DST.ADDR domain-name length accepted by `read_request`.
/// The length travels in a single octet, so 255 is also the protocol maximum
/// and the upper-bound check can never actually trip.
const MAX_DOMAIN_LEN: usize = 255;
// SOCKS5 reply (REP) field values, RFC 1928 section 6.
/// X'00' succeeded.
const REP_SUCCEEDED: u8 = 0x00;
/// X'01' general SOCKS server failure.
const REP_GENERAL_FAILURE: u8 = 0x01;
/// X'02' connection not allowed by ruleset (used when the filter blocks a host).
const REP_NOT_ALLOWED: u8 = 0x02;
/// X'03' network unreachable.
const REP_NETWORK_UNREACHABLE: u8 = 0x03;
/// X'04' host unreachable.
const REP_HOST_UNREACHABLE: u8 = 0x04;
/// X'05' connection refused.
const REP_CONNECTION_REFUSED: u8 = 0x05;
/// X'07' command not supported (anything other than CONNECT here).
const REP_COMMAND_NOT_SUPPORTED: u8 = 0x07;
/// X'08' address type not supported (IPv4/IPv6 literals here).
const REP_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
#[derive(Debug)]
pub struct Socks5Server {
listener: TcpListener,
filter: Filter,
port: u16,
upstream: crate::proxy::upstream::UpstreamConfig,
}
impl Socks5Server {
pub async fn bind(port: Option<u16>, filter: Filter) -> Result<Self> {
let bind_port = port.unwrap_or(0);
let listener = TcpListener::bind(("127.0.0.1", bind_port))
.await
.with_context(|| format!("bind socks5 listener on 127.0.0.1:{bind_port}"))?;
let port = listener.local_addr()?.port();
let upstream = crate::proxy::upstream::UpstreamConfig::from_env();
Ok(Self {
listener,
filter,
port,
upstream,
})
}
pub fn port(&self) -> u16 {
self.port
}
pub fn with_upstream(mut self, upstream: crate::proxy::upstream::UpstreamConfig) -> Self {
self.upstream = upstream;
self
}
pub async fn serve(self) {
let filter = self.filter;
let upstream = std::sync::Arc::new(self.upstream);
loop {
let (sock, peer) = match self.listener.accept().await {
Ok(t) => t,
Err(e) => {
warn!("socks5 accept failed: {e}");
return;
}
};
let f = filter.clone();
let up = std::sync::Arc::clone(&upstream);
tokio::spawn(async move {
if let Err(e) = handle_one(sock, &f, &up).await {
debug!("socks5 connection from {peer} ended: {e:#}");
}
});
}
}
}
async fn handle_one(
mut client: TcpStream,
filter: &Filter,
upstream: &crate::proxy::upstream::UpstreamConfig,
) -> Result<()> {
tokio::time::timeout(HANDSHAKE_TIMEOUT, async {
greet(&mut client).await?;
let target = read_request(&mut client).await?;
let target = match target {
RequestOutcome::Reject { rep, atyp } => {
send_reply(&mut client, rep, atyp).await?;
return Ok(());
}
RequestOutcome::Connect(t) => t,
};
if !filter.allows(&target.host) {
debug!("socks5: blocked CONNECT {} (not in allowlist)", target.host);
send_reply(&mut client, REP_NOT_ALLOWED, 0x03).await?;
return Ok(());
}
let upstream_addr = format!("{}:{}", target.host, target.port);
let dial = crate::proxy::upstream::connect_upstream(&upstream_addr, upstream).await;
let upstream_sock = match dial {
Ok(s) => s,
Err(e) => {
let rep = e
.downcast_ref::<std::io::Error>()
.map(io_error_to_rep)
.unwrap_or(REP_GENERAL_FAILURE);
warn!("socks5: upstream connect to {upstream_addr} failed: {e:#}");
send_reply(&mut client, rep, 0x03).await?;
return Ok(());
}
};
send_reply(&mut client, REP_SUCCEEDED, 0x03).await?;
let _ = super::relay::relay_with_timeouts(
client,
upstream_sock,
super::relay::DEFAULT_IDLE_TIMEOUT,
super::relay::DEFAULT_TOTAL_TIMEOUT,
)
.await;
Ok::<_, anyhow::Error>(())
})
.await
.context("socks5 handshake timed out")??;
Ok(())
}
/// Performs the RFC 1928 method-selection exchange.
///
/// Reads `VER NMETHODS METHODS...`. When the client offers method 0x00
/// ("no authentication required"), it answers `05 00`. Otherwise it answers
/// `05 FF` ("no acceptable methods") and then fails.
///
/// # Errors
///
/// Fails if VER is not 5, if method 0x00 is not offered, or on socket I/O errors.
async fn greet(client: &mut TcpStream) -> Result<()> {
    const NO_AUTH: u8 = 0x00;

    let mut ver_and_count = [0u8; 2];
    client.read_exact(&mut ver_and_count).await?;
    let [version, method_count] = ver_and_count;
    if version != 0x05 {
        bail!("not a socks5 client (VER={:#x})", version);
    }

    let mut offered = vec![0u8; usize::from(method_count)];
    if !offered.is_empty() {
        client.read_exact(&mut offered).await?;
    }

    let accepted = offered.contains(&NO_AUTH);
    let selected = if accepted { NO_AUTH } else { 0xFF };
    client.write_all(&[0x05, selected]).await?;
    if !accepted {
        bail!("client offered no acceptable auth methods");
    }
    Ok(())
}
/// Destination of an accepted CONNECT request.
#[derive(Debug)]
struct ConnectTarget {
    /// DST.ADDR as UTF-8. `read_request` produces this only for ATYP 0x03
    /// (domain name); IP literals are rejected before reaching here.
    host: String,
    /// DST.PORT, decoded from network (big-endian) byte order.
    port: u16,
}
/// What `read_request` decided about the client's request.
#[derive(Debug)]
enum RequestOutcome {
    /// A CONNECT to a domain-name target that the caller should check and dial.
    Connect(ConnectTarget),
    /// The request was fully consumed but cannot be served.
    Reject {
        /// REP code to send back to the client.
        rep: u8,
        /// ATYP from the request, passed to `send_reply` to shape BND.ADDR.
        atyp: u8,
    },
}
/// Reads the client's request: `VER CMD RSV ATYP DST.ADDR DST.PORT`.
///
/// Only CONNECT (CMD 0x01) to a domain name (ATYP 0x03) yields
/// [`RequestOutcome::Connect`]. The following yield
/// [`RequestOutcome::Reject`] once the rest of the request has been read:
/// other commands, IPv4/IPv6 literals, and empty or non-UTF-8 domain names.
///
/// # Errors
///
/// Fails on socket I/O errors, on a VER other than 5, and on an unknown ATYP
/// in a CONNECT request.
async fn read_request(client: &mut TcpStream) -> Result<RequestOutcome> {
    let mut fixed = [0u8; 4];
    client.read_exact(&mut fixed).await?;
    let [ver, command, _reserved, atyp] = fixed;
    if ver != 0x05 {
        bail!("bad request version {ver:#x}");
    }

    // Every rejection echoes the request's ATYP back to the caller.
    let reject = move |rep: u8| RequestOutcome::Reject { rep, atyp };

    if command != 0x01 {
        drain_addr_port(client, atyp).await?;
        return Ok(reject(REP_COMMAND_NOT_SUPPORTED));
    }

    let host = match atyp {
        0x03 => {
            let mut len_octet = [0u8; 1];
            client.read_exact(&mut len_octet).await?;
            let len = usize::from(len_octet[0]);
            let decoded = if (1..=MAX_DOMAIN_LEN).contains(&len) {
                let mut raw = vec![0u8; len];
                client.read_exact(&mut raw).await?;
                String::from_utf8(raw).ok()
            } else {
                None
            };
            let Some(name) = decoded else {
                // Consume DST.PORT so the whole request is read before replying.
                drain_port(client).await?;
                return Ok(reject(REP_GENERAL_FAILURE));
            };
            name
        }
        0x01 | 0x04 => {
            drain_addr_port(client, atyp).await?;
            return Ok(reject(REP_ADDRESS_TYPE_NOT_SUPPORTED));
        }
        other => bail!("unknown ATYP {other:#x}"),
    };

    let mut port_octets = [0u8; 2];
    client.read_exact(&mut port_octets).await?;
    Ok(RequestOutcome::Connect(ConnectTarget {
        host,
        port: u16::from_be_bytes(port_octets),
    }))
}
/// Reads and discards DST.ADDR and DST.PORT for the given ATYP.
///
/// The sizes are 4 + 2 octets for IPv4 (0x01) and 16 + 2 octets for IPv6
/// (0x04). A domain name (0x03) is one length octet, then that many octets
/// plus 2. An unknown ATYP has no known length, so nothing is read for it.
async fn drain_addr_port(client: &mut TcpStream, atyp: u8) -> Result<()> {
    const PORT_LEN: usize = 2;
    let remaining = match atyp {
        0x01 => 4 + PORT_LEN,
        0x04 => 16 + PORT_LEN,
        0x03 => {
            let mut len_octet = [0u8; 1];
            client.read_exact(&mut len_octet).await?;
            usize::from(len_octet[0]) + PORT_LEN
        }
        _ => return Ok(()),
    };
    let mut discard = vec![0u8; remaining];
    client.read_exact(&mut discard).await?;
    Ok(())
}
/// Reads and discards the two-octet DST.PORT field.
async fn drain_port(client: &mut TcpStream) -> Result<()> {
    client.read_exact(&mut [0u8; 2]).await?;
    Ok(())
}
/// Writes a SOCKS5 reply `VER REP RSV ATYP BND.ADDR BND.PORT` carrying `rep`.
///
/// BND.ADDR and BND.PORT are all zero bytes:
/// - IPv4 (0x01) is echoed with a 4-octet address.
/// - IPv6 (0x04) is echoed with a 16-octet address.
/// - Any other `atyp` is answered as a domain name (0x03) whose single
///   length octet is 0.
async fn send_reply(client: &mut TcpStream, rep: u8, atyp: u8) -> Result<()> {
    let (reply_atyp, addr_len) = match atyp {
        0x01 => (0x01, 4),
        0x04 => (0x04, 16),
        _ => (0x03, 1),
    };
    let total = 4 + addr_len + 2;
    let mut reply = Vec::with_capacity(total);
    reply.extend_from_slice(&[0x05, rep, 0x00, reply_atyp]);
    reply.resize(total, 0x00);
    client.write_all(&reply).await?;
    Ok(())
}
/// Maps an upstream dial failure to the closest SOCKS5 REP code.
///
/// Structured [`std::io::ErrorKind`]s are checked first because OS error
/// texts are platform-specific and may be localized. For example, Windows
/// reports WSAENETUNREACH as "A socket operation was attempted to an
/// unreachable network.", which the substring checks would miss. Those
/// substring checks remain as a fallback for errors that carry only text,
/// such as those built with `io::Error::other`. Anything unrecognized maps
/// to [`REP_GENERAL_FAILURE`].
fn io_error_to_rep(e: &std::io::Error) -> u8 {
    // Qualified paths (rather than a glob import) so a misspelled variant is a
    // compile error instead of silently becoming a catch-all binding.
    use std::io::ErrorKind;
    match e.kind() {
        ErrorKind::ConnectionRefused => REP_CONNECTION_REFUSED,
        ErrorKind::NetworkUnreachable => REP_NETWORK_UNREACHABLE,
        ErrorKind::HostUnreachable => REP_HOST_UNREACHABLE,
        _ => {
            let s = e.to_string().to_ascii_lowercase();
            if s.contains("network is unreachable") {
                REP_NETWORK_UNREACHABLE
            } else if s.contains("no route to host") || s.contains("host is unreachable") {
                REP_HOST_UNREACHABLE
            } else {
                REP_GENERAL_FAILURE
            }
        }
    }
}
#[cfg(test)]
mod tests {
    // These tests drive a real `Socks5Server` over loopback TCP with a direct
    // upstream, writing raw RFC 1928 bytes from the client side.
    use super::*;

    /// Binds a server on an OS-chosen loopback port with a direct upstream and
    /// spawns `serve` on a background task. Returns the port and task handle.
    async fn spawn(filter: Filter) -> (u16, tokio::task::JoinHandle<()>) {
        let server = Socks5Server::bind(None, filter)
            .await
            .unwrap()
            .with_upstream(crate::proxy::upstream::UpstreamConfig::Direct);
        let port = server.port();
        let task = tokio::spawn(server.serve());
        (port, task)
    }

    /// Sends a greeting offering only method 0x00 ("no authentication") and
    /// returns the server's two-byte method-selection reply.
    async fn greet_noauth(sock: &mut TcpStream) -> [u8; 2] {
        sock.write_all(&[0x05, 0x01, 0x00]).await.unwrap();
        let mut reply = [0u8; 2];
        sock.read_exact(&mut reply).await.unwrap();
        reply
    }

    /// Sends a CONNECT request for a domain-name (ATYP 0x03) target and returns
    /// the first four reply bytes (VER, REP, RSV, ATYP). BND.ADDR/BND.PORT are
    /// left unread for the caller.
    async fn connect_domain(sock: &mut TcpStream, host: &str, port: u16) -> [u8; 4] {
        // NOTE: `as u8` truncates; callers must keep `host` at most 255 bytes.
        let mut req = vec![0x05, 0x01, 0x00, 0x03, host.len() as u8];
        req.extend_from_slice(host.as_bytes());
        req.extend_from_slice(&port.to_be_bytes());
        sock.write_all(&req).await.unwrap();
        let mut reply = [0u8; 4];
        sock.read_exact(&mut reply).await.unwrap();
        reply
    }

    #[tokio::test]
    async fn greeting_accepts_noauth() {
        let (port, _task) = spawn(Filter::default()).await;
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        assert_eq!(greet_noauth(&mut sock).await, [0x05, 0x00]);
    }

    #[tokio::test]
    async fn greeting_rejects_no_acceptable_auth() {
        let (port, _task) = spawn(Filter::default()).await;
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        // Offer only method 0x02 (username/password), which the server lacks.
        sock.write_all(&[0x05, 0x01, 0x02]).await.unwrap();
        let mut reply = [0u8; 2];
        sock.read_exact(&mut reply).await.unwrap();
        assert_eq!(reply, [0x05, 0xFF]);
    }

    #[tokio::test]
    async fn connect_domain_blocked_by_filter() {
        let (port, _task) = spawn(Filter::new(["github.com"]).unwrap()).await;
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        let _ = greet_noauth(&mut sock).await;
        let reply = connect_domain(&mut sock, "evil.example.com", 443).await;
        assert_eq!(reply[0], 0x05);
        assert_eq!(reply[1], REP_NOT_ALLOWED);
    }

    #[tokio::test]
    async fn ipv4_literal_rejected_with_atyp_unsupported() {
        let (port, _task) = spawn(Filter::new(["github.com"]).unwrap()).await;
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        let _ = greet_noauth(&mut sock).await;
        // CONNECT to 1.2.3.4:443 using ATYP 0x01 (IPv4 literal).
        sock.write_all(&[0x05, 0x01, 0x00, 0x01, 1, 2, 3, 4, 0x01, 0xBB])
            .await
            .unwrap();
        let mut reply = [0u8; 4];
        sock.read_exact(&mut reply).await.unwrap();
        assert_eq!(reply[0], 0x05);
        assert_eq!(reply[1], REP_ADDRESS_TYPE_NOT_SUPPORTED);
    }

    #[tokio::test]
    async fn ipv6_literal_rejected_with_atyp_unsupported() {
        let (port, _task) = spawn(Filter::new(["github.com"]).unwrap()).await;
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        let _ = greet_noauth(&mut sock).await;
        // CONNECT to [::]:443 using ATYP 0x04 (IPv6 literal).
        let mut req = vec![0x05, 0x01, 0x00, 0x04];
        req.extend_from_slice(&[0u8; 16]);
        req.extend_from_slice(&443u16.to_be_bytes());
        sock.write_all(&req).await.unwrap();
        let mut reply = [0u8; 4];
        sock.read_exact(&mut reply).await.unwrap();
        assert_eq!(reply[1], REP_ADDRESS_TYPE_NOT_SUPPORTED);
    }

    #[tokio::test]
    async fn bind_command_refused() {
        let (port, _task) = spawn(Filter::default()).await;
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        let _ = greet_noauth(&mut sock).await;
        // CMD 0x02 (BIND) for domain "x", port 80.
        sock.write_all(&[0x05, 0x02, 0x00, 0x03, 0x01, b'x', 0x00, 0x50])
            .await
            .unwrap();
        let mut reply = [0u8; 4];
        sock.read_exact(&mut reply).await.unwrap();
        assert_eq!(reply[1], REP_COMMAND_NOT_SUPPORTED);
    }

    #[tokio::test]
    async fn allowed_domain_succeeds_then_relays() {
        // One-shot echo server standing in for the upstream target.
        let upstream = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let upstream_port = upstream.local_addr().unwrap().port();
        tokio::spawn(async move {
            let (mut s, _) = upstream.accept().await.unwrap();
            let mut byte = [0u8; 1];
            s.read_exact(&mut byte).await.unwrap();
            s.write_all(&byte).await.unwrap();
        });
        // Assumes "localhost" resolves to the loopback echo server above.
        let (port, _task) = spawn(Filter::new(["localhost"]).unwrap()).await;
        let mut sock = TcpStream::connect(("127.0.0.1", port)).await.unwrap();
        let _ = greet_noauth(&mut sock).await;
        let reply = connect_domain(&mut sock, "localhost", upstream_port).await;
        assert_eq!(reply[1], REP_SUCCEEDED);
        // Rest of the ATYP 0x03 success reply: zero length octet + 2-byte port.
        let mut tail = [0u8; 3];
        sock.read_exact(&mut tail).await.unwrap();
        sock.write_all(&[0x42]).await.unwrap();
        let mut echo = [0u8; 1];
        sock.read_exact(&mut echo).await.unwrap();
        assert_eq!(echo[0], 0x42);
    }

    #[test]
    fn io_error_mapping_picks_sensible_codes() {
        let refused = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused");
        assert_eq!(io_error_to_rep(&refused), REP_CONNECTION_REFUSED);
        // Message-only error: exercises the substring fallback.
        let other = std::io::Error::other("network is unreachable");
        assert_eq!(io_error_to_rep(&other), REP_NETWORK_UNREACHABLE);
        let mystery = std::io::Error::other("kaboom");
        assert_eq!(io_error_to_rep(&mystery), REP_GENERAL_FAILURE);
    }
}