use std::{
future::Future,
net::{Ipv4Addr, SocketAddr},
pin::Pin,
sync::Arc,
time::Duration,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
sync::Semaphore,
task::AbortHandle,
};
use ts_netstack_smoltcp::{CreateSocket, netcore::Channel};
use crate::{Error, InternalErrorKind};
/// Asynchronous hostname resolver used for SOCKS5 domain-name targets.
///
/// Called with the requested hostname. The returned future yields the IPv4
/// address to dial. `OverlayDialer` turns `Ok(None)` into
/// `InternalErrorKind::BadRequest` and propagates `Err` unchanged.
pub(crate) type Resolver = Arc<
dyn Fn(String) -> Pin<Box<dyn Future<Output = Result<Option<Ipv4Addr>, Error>> + Send>>
+ Send
+ Sync,
>;
/// SOCKS protocol version byte (RFC 1928).
const SOCKS5_VER: u8 = 0x05;
/// Username/password authentication method (RFC 1929); the only method offered.
const METHOD_USER_PASS: u8 = 0x02;
/// Method-selection reply meaning "NO ACCEPTABLE METHODS" (RFC 1928 §3).
/// Despite the name, this is not "no authentication required" (which is 0x00).
const METHOD_NONE: u8 = 0xFF;
/// Version byte of the username/password sub-negotiation (RFC 1929).
const AUTH_VER: u8 = 0x01;
/// `CONNECT` command; the only command this proxy accepts.
const CMD_CONNECT: u8 = 0x01;
/// Address type: IPv4 address (4 bytes).
const ATYP_IPV4: u8 = 0x01;
/// Address type: domain name (1-byte length followed by the name).
const ATYP_DOMAIN: u8 = 0x03;
/// Address type: IPv6 address (16 bytes); recognised but refused.
const ATYP_IPV6: u8 = 0x04;
/// Reply code "Command not supported".
const REP_CMD_NOT_SUPPORTED: u8 = 0x07;
/// Reply code "Address type not supported".
const REP_ATYP_NOT_SUPPORTED: u8 = 0x08;
/// Upper bound on the whole handshake: greeting, auth, request and overlay dial.
const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(30);
/// Fixed username clients must present; the password is the per-listener random credential.
const PROXY_USERNAME: &str = "tsnet";
/// Destination of a SOCKS5 `CONNECT` request, as decoded by [`parse_request`].
#[derive(Debug, Clone, PartialEq, Eq)]
enum Target {
/// Literal IPv4 address and port; dialed directly.
Ipv4(Ipv4Addr, u16),
/// Hostname and port; the name goes through the dialer's [`Resolver`] first.
Domain(String, u16),
}
/// Parses a complete SOCKS5 request (RFC 1928 §4) into a [`Target`].
///
/// `buf` holds `VER CMD RSV ATYP DST.ADDR DST.PORT`. Any bytes after the port
/// are ignored.
///
/// # Errors
///
/// Returns the SOCKS5 reply code to send back to the client:
/// - `0x01` (general SOCKS server failure) when the request is truncated,
///   carries the wrong version byte, or names an empty or non-UTF-8 domain;
/// - [`REP_CMD_NOT_SUPPORTED`] for any command other than `CONNECT`;
/// - [`REP_ATYP_NOT_SUPPORTED`] for IPv6 and unknown address types.
fn parse_request(buf: &[u8]) -> Result<Target, u8> {
    // RFC 1928 §6, X'01'. A malformed request is not "command not supported".
    const REP_GENERAL_FAILURE: u8 = 0x01;

    let [ver, cmd, _rsv, atyp, rest @ ..] = buf else {
        return Err(REP_GENERAL_FAILURE);
    };
    if *ver != SOCKS5_VER {
        return Err(REP_GENERAL_FAILURE);
    }
    if *cmd != CMD_CONNECT {
        return Err(REP_CMD_NOT_SUPPORTED);
    }
    match *atyp {
        ATYP_IPV4 => {
            let &[a, b, c, d, port_hi, port_lo, ..] = rest else {
                return Err(REP_GENERAL_FAILURE);
            };
            let port = u16::from_be_bytes([port_hi, port_lo]);
            Ok(Target::Ipv4(Ipv4Addr::new(a, b, c, d), port))
        }
        ATYP_DOMAIN => {
            let Some((&len, tail)) = rest.split_first() else {
                return Err(REP_GENERAL_FAILURE);
            };
            let len = usize::from(len);
            // An empty hostname can never resolve; reject it here instead of
            // handing "" to the resolver.
            if len == 0 || tail.len() < len + 2 {
                return Err(REP_GENERAL_FAILURE);
            }
            let (name, port) = tail.split_at(len);
            let host = std::str::from_utf8(name).map_err(|_| REP_GENERAL_FAILURE)?;
            let port = u16::from_be_bytes([port[0], port[1]]);
            Ok(Target::Domain(host.to_owned(), port))
        }
        // IPv6 (ATYP_IPV6) and unknown address types: the dialer only speaks IPv4.
        _ => Err(REP_ATYP_NOT_SUPPORTED),
    }
}
/// Opens TCP connections across the overlay network on behalf of SOCKS5 clients.
///
/// The accept loop clones one per accepted connection.
#[derive(Clone)]
pub(crate) struct OverlayDialer {
/// Netstack handle used to open outbound TCP connections.
channel: Channel,
/// Source IPv4 address for outbound connections.
self_ipv4: Ipv4Addr,
/// Resolves hostnames for [`Target::Domain`] requests.
resolve: Resolver,
}
impl OverlayDialer {
async fn dial_ipv4(
&self,
addr: Ipv4Addr,
port: u16,
) -> Result<crate::netstack::TcpStream, Error> {
let ephemeral_port = rand::random_range(49152..=u16::MAX);
self.channel
.tcp_connect((self.self_ipv4, ephemeral_port).into(), (addr, port).into())
.await
.map_err(Into::into)
}
async fn dial_name(&self, name: &str, port: u16) -> Result<crate::netstack::TcpStream, Error> {
let addr = (self.resolve)(name.to_string())
.await?
.ok_or(Error::Internal(InternalErrorKind::BadRequest))?;
self.dial_ipv4(addr, port).await
}
async fn dial(&self, target: &Target) -> Result<crate::netstack::TcpStream, Error> {
match target {
Target::Ipv4(addr, port) => self.dial_ipv4(*addr, *port).await,
Target::Domain(host, port) => self.dial_name(host, *port).await,
}
}
}
/// Owner of a running loopback SOCKS5 proxy, as returned by `start`.
///
/// Dropping the handle, or calling [`LoopbackHandle::shutdown`], aborts the
/// accept loop. Connections that were already accepted run on their own
/// tasks, and this does not abort them.
#[must_use = "dropping the handle stops the loopback SOCKS5 proxy"]
pub struct LoopbackHandle {
/// Abort handle for the task running the accept loop.
accept_task: AbortHandle,
}
impl LoopbackHandle {
    /// Stops the loopback SOCKS5 proxy from accepting new clients.
    ///
    /// Equivalent to dropping the handle: the `Drop` impl aborts the accept
    /// task. This method exists so callers can say so explicitly.
    pub fn shutdown(self) {
        drop(self);
    }
}
impl Drop for LoopbackHandle {
fn drop(&mut self) {
self.accept_task.abort();
}
}
impl OverlayDialer {
    /// Builds a dialer that opens connections over `channel` from `self_ipv4`
    /// and looks hostnames up with `resolve`.
    pub(crate) fn new(channel: Channel, self_ipv4: Ipv4Addr, resolve: Resolver) -> Self {
        Self { resolve, self_ipv4, channel }
    }
}
pub(crate) async fn start(
dialer: OverlayDialer,
) -> Result<(SocketAddr, String, LoopbackHandle), Error> {
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
.await
.map_err(|_| Error::Internal(InternalErrorKind::Io))?;
let local_addr = listener
.local_addr()
.map_err(|_| Error::Internal(InternalErrorKind::Io))?;
let cred = gen_cred();
let accept_cred = cred.clone();
let task = tokio::spawn(async move {
accept_loop(listener, dialer, accept_cred).await;
});
Ok((
local_addr,
cred,
LoopbackHandle {
accept_task: task.abort_handle(),
},
))
}
/// Generates a fresh proxy password: 16 random bytes rendered as 32 lowercase
/// hexadecimal digits.
fn gen_cred() -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

    let bytes: [u8; 16] = rand::random();
    let mut out = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
/// Maximum number of loopback SOCKS5 connections served at once. Once the
/// limit is reached, the accept loop stops accepting until a connection ends.
const MAX_CONCURRENT_CONNS: usize = 256;
async fn accept_loop(listener: TcpListener, dialer: OverlayDialer, cred: String) {
let sem = Arc::new(Semaphore::new(MAX_CONCURRENT_CONNS));
loop {
let permit = match sem.clone().acquire_owned().await {
Ok(permit) => permit,
Err(_) => return,
};
let (sock, _peer) = match listener.accept().await {
Ok(pair) => pair,
Err(e) => {
tracing::warn!(error = %e, "loopback SOCKS5 accept failed; stopping accept loop");
return;
}
};
let dialer = dialer.clone();
let cred = cred.clone();
tokio::spawn(async move {
let _permit = permit;
if let Err(e) = handle_conn(sock, dialer, cred).await {
tracing::debug!(error = %e, "loopback SOCKS5 connection ended");
}
});
}
}
/// Serves one loopback client. It runs the SOCKS5 handshake under
/// [`HANDSHAKE_TIMEOUT`], then splices client and overlay streams until the
/// copy finishes.
///
/// # Errors
///
/// Returns I/O errors raised during the handshake. A handshake timeout and any
/// splice error are logged at debug level and reported as `Ok(())`.
async fn handle_conn(sock: TcpStream, dialer: OverlayDialer, cred: String) -> std::io::Result<()> {
    let handshake = tokio::time::timeout(HANDSHAKE_TIMEOUT, negotiate(sock, dialer, cred)).await;
    let Ok(outcome) = handshake else {
        tracing::debug!("loopback SOCKS5 handshake timed out");
        return Ok(());
    };

    // `None`: negotiate rejected the client, so there is nothing to splice.
    if let Some((mut client, mut overlay)) = outcome? {
        let spliced = tokio::io::copy_bidirectional(&mut client, &mut overlay).await;
        if let Ok((to_overlay, to_host)) = spliced {
            tracing::debug!(to_overlay, to_host, "loopback SOCKS5 splice finished");
        } else if let Err(e) = spliced {
            tracing::debug!(error = %e, "loopback SOCKS5 splice ended");
        }
    }
    Ok(())
}
/// Runs the server side of the SOCKS5 handshake and dials the requested target.
///
/// The handshake has three steps: the RFC 1928 greeting, RFC 1929
/// username/password authentication, and a `CONNECT` request. The target is
/// then dialed over the overlay.
///
/// Returns `Ok(Some((client, overlay)))` after sending the success reply.
/// Returns `Ok(None)` when the client was turned away. In that case a refusal
/// has been sent wherever the protocol defines one; a wrong version byte in
/// the greeting gets no reply.
///
/// # Errors
///
/// Propagates I/O errors on the client socket.
async fn negotiate(
    mut sock: TcpStream,
    dialer: OverlayDialer,
    cred: String,
) -> std::io::Result<Option<(TcpStream, crate::netstack::TcpStream)>> {
    // RFC 1928 §6 reply code X'05': connection refused.
    const REP_CONNECTION_REFUSED: u8 = 0x05;

    if !select_user_pass_method(&mut sock).await? {
        return Ok(None);
    }
    if !authenticate(&mut sock, &cred).await? {
        return Ok(None);
    }
    let Some(req) = read_request(&mut sock).await? else {
        return Ok(None);
    };
    let target = match parse_request(&req) {
        Ok(t) => t,
        Err(rep) => {
            reply_failure(&mut sock, rep).await?;
            return Ok(None);
        }
    };
    let overlay = match dialer.dial(&target).await {
        Ok(s) => s,
        Err(e) => {
            tracing::debug!(?target, error = ?e, "loopback SOCKS5 overlay dial failed");
            reply_failure(&mut sock, REP_CONNECTION_REFUSED).await?;
            return Ok(None);
        }
    };
    // Success reply. BND.ADDR/BND.PORT are reported as 0.0.0.0:0.
    sock.write_all(&[SOCKS5_VER, 0x00, 0x00, ATYP_IPV4, 0, 0, 0, 0, 0, 0])
        .await?;
    Ok(Some((sock, overlay)))
}

/// Reads the client greeting (RFC 1928 §3) and selects username/password auth.
///
/// Returns `Ok(false)` if the client speaks another SOCKS version (no reply is
/// sent). It also returns `Ok(false)` if the client does not offer
/// username/password, after replying "no acceptable methods".
async fn select_user_pass_method(sock: &mut TcpStream) -> std::io::Result<bool> {
    let mut head = [0u8; 2];
    sock.read_exact(&mut head).await?;
    if head[0] != SOCKS5_VER {
        return Ok(false);
    }
    let mut methods = vec![0u8; usize::from(head[1])];
    sock.read_exact(&mut methods).await?;
    if !methods.contains(&METHOD_USER_PASS) {
        sock.write_all(&[SOCKS5_VER, METHOD_NONE]).await?;
        return Ok(false);
    }
    sock.write_all(&[SOCKS5_VER, METHOD_USER_PASS]).await?;
    Ok(true)
}

/// Performs the RFC 1929 username/password sub-negotiation against `cred`.
///
/// Replies with status 0x00 on success or 0x01 on a bad username or password,
/// and returns whether the client is authenticated. Returns `Ok(false)`
/// without replying if the sub-negotiation version byte is wrong.
async fn authenticate(sock: &mut TcpStream, cred: &str) -> std::io::Result<bool> {
    const STATUS_SUCCESS: u8 = 0x00;
    const STATUS_FAILURE: u8 = 0x01;

    let mut head = [0u8; 2];
    sock.read_exact(&mut head).await?;
    if head[0] != AUTH_VER {
        return Ok(false);
    }
    let mut uname = vec![0u8; usize::from(head[1])];
    sock.read_exact(&mut uname).await?;
    let mut plen = [0u8; 1];
    sock.read_exact(&mut plen).await?;
    let mut passwd = vec![0u8; usize::from(plen[0])];
    sock.read_exact(&mut passwd).await?;

    // The password is the only thing keeping other local processes off the
    // proxy. Compare it without an early exit, and evaluate both checks with a
    // non-short-circuiting `&`.
    let ok = (uname.as_slice() == PROXY_USERNAME.as_bytes())
        & constant_time_eq(&passwd, cred.as_bytes());
    let status = if ok { STATUS_SUCCESS } else { STATUS_FAILURE };
    sock.write_all(&[AUTH_VER, status]).await?;
    Ok(ok)
}

/// Reads a full SOCKS5 request (header, address and port) into a buffer for
/// [`parse_request`].
///
/// Returns `Ok(None)` after replying "address type not supported" for IPv6 or
/// unknown address types.
async fn read_request(sock: &mut TcpStream) -> std::io::Result<Option<Vec<u8>>> {
    let mut req = vec![0u8; 4];
    sock.read_exact(&mut req).await?;
    match req[3] {
        ATYP_IPV4 => {
            let mut rest = [0u8; 4 + 2];
            sock.read_exact(&mut rest).await?;
            req.extend_from_slice(&rest);
        }
        ATYP_DOMAIN => {
            let mut len = [0u8; 1];
            sock.read_exact(&mut len).await?;
            let mut rest = vec![0u8; usize::from(len[0]) + 2];
            sock.read_exact(&mut rest).await?;
            req.push(len[0]);
            req.extend_from_slice(&rest);
        }
        ATYP_IPV6 => {
            // Best-effort drain of the address so the refusal is read cleanly.
            // A read error is ignored because we are refusing anyway.
            let mut rest = [0u8; 16 + 2];
            let _ = sock.read_exact(&mut rest).await;
            reply_failure(sock, REP_ATYP_NOT_SUPPORTED).await?;
            return Ok(None);
        }
        _ => {
            reply_failure(sock, REP_ATYP_NOT_SUPPORTED).await?;
            return Ok(None);
        }
    }
    Ok(Some(req))
}

/// Byte-string equality whose running time does not depend on where the
/// first mismatch is. Only the lengths, which are not secret here, short-circuit.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
}
/// Sends a SOCKS5 reply carrying the failure code `rep`. BND.ADDR/BND.PORT are
/// zeroed and typed as IPv4.
async fn reply_failure(sock: &mut TcpStream, rep: u8) -> std::io::Result<()> {
    // Layout: VER REP RSV ATYP BND.ADDR(4) BND.PORT(2).
    let mut reply = [0u8; 10];
    reply[0] = SOCKS5_VER;
    reply[1] = rep;
    reply[3] = ATYP_IPV4;
    sock.write_all(&reply).await
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_request_ipv4() {
        let buf = [0x05, 0x01, 0x00, 0x01, 100, 64, 0, 5, 0x1f, 0x90];
        let t = parse_request(&buf).expect("ipv4 target");
        assert_eq!(t, Target::Ipv4(Ipv4Addr::new(100, 64, 0, 5), 8080));
    }

    #[test]
    fn parse_request_domain() {
        let mut buf = vec![0x05, 0x01, 0x00, 0x03, 0x09];
        buf.extend_from_slice(b"peer.host");
        buf.extend_from_slice(&443u16.to_be_bytes());
        let t = parse_request(&buf).expect("domain target");
        assert_eq!(t, Target::Domain("peer.host".to_string(), 443));
    }

    #[test]
    fn parse_request_ipv6_refused() {
        let mut buf = vec![0x05, 0x01, 0x00, 0x04];
        buf.extend_from_slice(&[0u8; 16]);
        buf.extend_from_slice(&443u16.to_be_bytes());
        let rep = parse_request(&buf).expect_err("ipv6 refused");
        assert_eq!(rep, REP_ATYP_NOT_SUPPORTED);
    }

    #[test]
    fn parse_request_unknown_atyp_refused() {
        let buf = [0x05, 0x01, 0x00, 0x02, 0x00, 0x00];
        let rep = parse_request(&buf).expect_err("unknown atyp refused");
        assert_eq!(rep, REP_ATYP_NOT_SUPPORTED);
    }

    #[test]
    fn parse_request_bad_cmd() {
        let buf = [0x05, 0x03, 0x00, 0x01, 100, 64, 0, 5, 0x1f, 0x90];
        let rep = parse_request(&buf).expect_err("bad cmd refused");
        assert_eq!(rep, REP_CMD_NOT_SUPPORTED);
    }

    #[test]
    fn parse_request_truncated_rejected() {
        // Shorter than VER CMD RSV ATYP.
        assert!(parse_request(&[0x05, 0x01, 0x00]).is_err());
        // IPv4 address and port cut short.
        assert!(parse_request(&[0x05, 0x01, 0x00, 0x01, 100, 64, 0]).is_err());
        // Domain length byte claims more bytes than are present.
        assert!(parse_request(&[0x05, 0x01, 0x00, 0x03, 0x04, b'a', b'b']).is_err());
    }

    #[test]
    fn parse_request_non_utf8_domain_rejected() {
        let buf = [0x05, 0x01, 0x00, 0x03, 0x01, 0xFF, 0x00, 0x50];
        assert!(parse_request(&buf).is_err());
    }

    #[test]
    fn hex_cred_len() {
        let cred = gen_cred();
        assert_eq!(cred.len(), 32);
        assert!(
            cred.chars()
                .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase())
        );
    }

    #[test]
    fn creds_are_fresh() {
        // 128 random bits each; a collision would indicate a broken RNG source.
        assert_ne!(gen_cred(), gen_cred());
    }
}