mod v4;
mod v5;
use super::HandlerResources;
use super::tcp::{open_tcp_listener, request_tcp_channel};
use crate::client::StreamCommand;
use crate::config;
use bytes::{Buf, Bytes};
use penguin_mux::{Datagram, Dupe};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::sync::Arc;
use thiserror::Error;
use tokio::io::{self as tio, AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, BufReader};
use tokio::net::UdpSocket;
use tokio::sync::mpsc;
use tokio::task::JoinSet;
use tracing::{debug, info, trace, warn};
/// Errors produced while serving a single SOCKS client.
///
/// Every variant except [`Error::Fatal`] is scoped to one connection and is
/// only logged by the accept loops; `Fatal` is propagated to their callers.
#[derive(Debug, Error)]
pub enum Error {
    /// I/O failure while moving payload data (TCP copy or UDP receive).
    #[error("data transfer error: {0}")]
    DataTransfer(std::io::Error),

    /// The first byte from the client was neither 4 nor 5.
    #[error("client with version={0} is not SOCKSv4 or SOCKSv5")]
    SocksVersion(u8),

    /// The client requested a command this handler does not implement.
    #[error("unsupported SOCKS command: {0}")]
    InvalidCommand(u8),

    /// The request carried an address type that could not be handled.
    #[error("invalid SOCKS address type: {0}")]
    AddressType(u8),

    /// An I/O step of the SOCKS handshake failed; the `&'static str` names the step.
    #[error("cannot {0} in SOCKS request: {1}")]
    ProcessSocksRequest(&'static str, std::io::Error),

    /// A UDP ASSOCIATE datagram was too short for its declared header.
    #[error("cannot parse SOCKS associate datagram")]
    ParseAssociate,

    /// The client did not offer the NO AUTHENTICATION (0x00) method.
    #[error("client does not support NOAUTH")]
    OtherAuth,

    /// An error that must tear down the whole handler, not just this client.
    #[error(transparent)]
    Fatal(#[from] super::FatalError),
}
/// Listens on `lhost:lport` and serves each accepted connection as a SOCKS
/// client in its own task.
///
/// Per-connection errors are logged at `info` level. The first
/// [`Error::Fatal`] coming out of a connection task, or any listener error,
/// ends the loop and is returned.
///
/// # Panics
/// Panics if a connection task panics.
pub(super) async fn handle_socks(
    lhost: &'static str,
    lport: u16,
    handler_resources: &'static HandlerResources,
) -> Result<(), super::FatalError> {
    let listener = open_tcp_listener(lhost, lport)
        .await
        .map_err(super::FatalError::ClientIo)?;
    let mut jobs = JoinSet::new();
    loop {
        tokio::select! {
            // Reap finished jobs before accepting more, so a fatal error
            // from a connection is noticed as early as possible.
            biased;
            Some(joined) = jobs.join_next() => {
                match joined.expect("SOCKS job panicked (this is a bug)") {
                    Ok(()) => {}
                    Err(Error::Fatal(fatal)) => return Err(fatal),
                    Err(e) => info!("{e}"),
                }
            }
            accepted = listener.accept() => {
                let (conn, _peer) = accepted.map_err(super::FatalError::ClientIo)?;
                jobs.spawn(on_socks_accept(conn, lhost, handler_resources));
            }
        }
    }
}
pub(super) async fn handle_socks_stdio(
handler_resources: &'static HandlerResources,
) -> Result<(), super::FatalError> {
let stdio = tio::join(tio::stdin(), tio::stdout());
if let Err(e) = on_socks_accept(stdio, "localhost", handler_resources).await {
if let Error::Fatal(e) = e {
return Err(e);
}
info!("{e}");
}
Ok(())
}
#[tracing::instrument(skip(stream, handler_resources), level = "trace")]
pub(super) async fn on_socks_accept<RW>(
stream: RW,
local_addr: &str,
handler_resources: &'static HandlerResources,
) -> Result<(), Error>
where
RW: AsyncRead + AsyncWrite + Unpin,
{
let mut bufreader = BufReader::new(stream);
let version = bufreader
.read_u8()
.await
.map_err(|e| Error::ProcessSocksRequest("read version", e))?;
match version {
4 => socks4(&mut bufreader, handler_resources).await,
5 => socks5(&mut bufreader, local_addr, handler_resources).await,
version => Err(Error::SocksVersion(version)),
}
}
#[tracing::instrument(skip_all, fields(host, port, cmd))]
async fn socks4<RW>(stream: &mut RW, handler_resources: &HandlerResources) -> Result<(), Error>
where
RW: AsyncBufRead + AsyncWrite + Unpin,
{
let (command, rhost, rport) = v4::read_request(stream).await?;
tracing::Span::current().record("host", format_args!("{}", String::from_utf8_lossy(&rhost)));
tracing::Span::current().record("port", rport);
tracing::Span::current().record("cmd", command);
debug!("SOCKSv4 request");
if command == 0x01 {
let stream_command_tx_permit = handler_resources
.stream_command_tx
.reserve()
.await
.or(Err(super::FatalError::RequestStream))?;
handle_connect(stream, rhost, rport, stream_command_tx_permit, false).await
} else {
v4::write_response(stream, 0x5b).await?;
Err(Error::InvalidCommand(command))
}
}
#[tracing::instrument(skip_all, fields(host, port, cmd, local = %local_addr))]
async fn socks5<RW>(
stream: &mut RW,
local_addr: &str,
handler_resources: &'static HandlerResources,
) -> Result<(), Error>
where
RW: AsyncBufRead + AsyncWrite + Unpin,
{
let methods = v5::read_auth_methods(stream).await?;
if !methods.contains(&0x00) {
v5::write_auth_method(stream, 0xff).await?;
return Err(Error::OtherAuth);
}
v5::write_auth_method(stream, 0x00).await?;
let (command, rhost, rport) = v5::read_request(stream).await?;
tracing::Span::current().record("host", format_args!("{}", String::from_utf8_lossy(&rhost)));
tracing::Span::current().record("port", rport);
tracing::Span::current().record("cmd", command);
debug!("SOCKSv5 request");
match command {
0x01 => {
let stream_command_tx_permit = handler_resources
.stream_command_tx
.reserve()
.await
.or(Err(super::FatalError::RequestStream))?;
handle_connect(stream, rhost, rport, stream_command_tx_permit, true).await
}
0x03 => handle_associate(stream, local_addr, handler_resources).await,
_ => {
v5::write_response_unspecified(stream, 0x07).await?;
Err(Error::InvalidCommand(command))
}
}
}
#[tracing::instrument(skip_all, level = "trace")]
async fn handle_connect<RW>(
stream: &mut RW,
rhost: Bytes,
rport: u16,
stream_command_tx_permit: mpsc::Permit<'_, StreamCommand>,
version_is_5: bool,
) -> Result<(), Error>
where
RW: AsyncBufRead + AsyncWrite + Unpin,
{
let channel = request_tcp_channel(stream_command_tx_permit, rhost, rport)
.await
.or(Err(super::FatalError::MainLoopExitWithoutSendingStream))?;
if version_is_5 {
v5::write_response_unspecified(stream, 0x00).await?;
} else {
v4::write_response(stream, 0x5a).await?;
}
trace!("SOCKS starting copy");
channel
.into_copy_bidirectional_with_buf(stream)
.await
.map_err(Error::DataTransfer)?;
Ok(())
}
#[tracing::instrument(skip_all, level = "trace")]
async fn handle_associate<RW>(
stream: &mut RW,
local_addr: &str,
handler_resources: &'static HandlerResources,
) -> Result<(), Error>
where
RW: AsyncRead + AsyncWrite + Unpin,
{
let socket = match UdpSocket::bind((local_addr, 0)).await {
Ok(s) => s,
Err(e) => {
v5::write_response_unspecified(stream, 0x01).await?;
return Err(Error::ProcessSocksRequest("bind udp socket", e));
}
};
let sock_local_addr = match socket.local_addr() {
Ok(a) => a,
Err(e) => {
v5::write_response_unspecified(stream, 0x01).await?;
return Err(Error::ProcessSocksRequest("get udp socket local addr", e));
}
};
trace!("SOCKS relaying at {sock_local_addr}");
let relay_task = tokio::spawn(udp_relay(handler_resources, socket));
v5::write_response(stream, 0x00, sock_local_addr).await?;
stream.read_exact(&mut [0; 1]).await.ok();
relay_task.abort();
Ok(())
}
#[tracing::instrument(skip_all, level = "trace")]
async fn udp_relay(handler_resources: &HandlerResources, socket: UdpSocket) -> Result<(), Error> {
let socket = Arc::new(socket);
loop {
let Some((target_host, target_port, data, src, sport)) =
handle_udp_relay_header(&socket).await?
else {
continue;
};
let client_id = handler_resources.add_udp_client((src, sport).into(), socket.dupe(), true);
let datagram_frame = Datagram {
target_host,
target_port,
flow_id: client_id,
data,
};
handler_resources
.datagram_tx
.send(datagram_frame)
.await
.or(Err(super::FatalError::SendDatagram))?;
}
}
/// Receives one datagram on `socket` and parses its SOCKSv5 UDP request
/// header (`RSV(2) FRAG(1) ATYP(1) DST.ADDR DST.PORT DATA`, RFC 1928 §7).
///
/// On success, returns `Some((dst_host, dst_port, payload, src_ip, src_port))`.
/// IPv4/IPv6 destinations are rendered as text, and domain names are passed
/// through unchanged. Returns `Ok(None)` for datagrams that are dropped on
/// purpose: a non-zero FRAG field or an unknown ATYP.
///
/// # Errors
/// - [`Error::DataTransfer`] if `recv_from` fails.
/// - [`Error::ParseAssociate`] if the datagram is shorter than its header requires.
async fn handle_udp_relay_header(
    socket: &UdpSocket,
) -> Result<Option<(Bytes, u16, Bytes, IpAddr, u16)>, Error> {
    let mut raw = vec![0; config::MAX_UDP_PACKET_SIZE];
    let (len, addr) = socket
        .recv_from(&mut raw)
        .await
        .map_err(Error::DataTransfer)?;
    trace!("received {len} bytes from {addr}");
    raw.truncate(len);
    let mut packet = Bytes::from(raw);

    // Fixed prefix: RSV (2) + FRAG (1) + ATYP (1).
    if packet.remaining() < 4 {
        return Err(Error::ParseAssociate);
    }
    packet.advance(2);
    if packet.get_u8() != 0 {
        warn!("Fragmented UDP packets are not implemented");
        return Ok(None);
    }
    let atyp = packet.get_u8();

    let (dst, port) = match atyp {
        // IPv4: 4 address bytes + 2 port bytes.
        0x01 => {
            if packet.remaining() < 4 + 2 {
                return Err(Error::ParseAssociate);
            }
            let ip = Ipv4Addr::from(packet.get_u32());
            (Bytes::from(ip.to_string()), packet.get_u16())
        }
        // Domain name: 1 length byte, the name itself, then 2 port bytes.
        0x03 => {
            if !packet.has_remaining() {
                return Err(Error::ParseAssociate);
            }
            let name_len = usize::from(packet.get_u8());
            if packet.remaining() < name_len + 2 {
                return Err(Error::ParseAssociate);
            }
            let name = packet.split_to(name_len);
            (name, packet.get_u16())
        }
        // IPv6: 16 address bytes + 2 port bytes.
        0x04 => {
            if packet.remaining() < 16 + 2 {
                return Err(Error::ParseAssociate);
            }
            let ip = Ipv6Addr::from(packet.get_u128());
            (Bytes::from(ip.to_string()), packet.get_u16())
        }
        _ => {
            warn!("Dropping datagram with invalid address type {atyp}");
            return Ok(None);
        }
    };
    trace!("Parsed packet: dst {dst:?} port {port}");
    // Whatever remains after the header is the payload.
    Ok(Some((dst, port, packet, addr.ip(), addr.port())))
}
/// Sends `data` to the SOCKSv5 UDP client at `target`, prefixed with the
/// RFC 1928 §7 UDP header `RSV(2) FRAG(1) ATYP(1) DST.ADDR DST.PORT`.
///
/// RSV and FRAG are always zero (no fragmentation). The address and port in
/// the header are those of `target`.
/// NOTE(review): RFC 1928 expects a reply's header to carry the address of
/// the remote peer the data came from; confirm callers are fine with `target`.
///
/// Returns the number of bytes sent (header + payload), as reported by
/// `send_to`.
///
/// # Errors
/// Returns any I/O error from `send_to`.
pub async fn send_udp_relay_response(
    socket: &UdpSocket,
    target: SocketAddr,
    data: &[u8],
) -> std::io::Result<usize> {
    // Largest header: 3 (RSV+FRAG) + 1 (ATYP) + 16 (IPv6) + 2 (port).
    let mut content = Vec::with_capacity(3 + 1 + 16 + 2 + data.len());
    // RSV (2 bytes) + FRAG (1 byte).
    content.extend_from_slice(&[0, 0, 0]);
    // ATYP must precede DST.ADDR.
    match target.ip() {
        IpAddr::V4(ip) => {
            content.push(0x01);
            content.extend_from_slice(&ip.octets());
        }
        IpAddr::V6(ip) => {
            content.push(0x04);
            content.extend_from_slice(&ip.octets());
        }
    }
    content.extend_from_slice(&target.port().to_be_bytes());
    content.extend_from_slice(data);
    socket.send_to(&content, target).await
}