use bytes::{BufMut, BytesMut};
use socks5_impl::protocol::{Address, AsyncStreamOperation, StreamOperation};
use tokio::io::{AsyncRead, AsyncReadExt};
/// Domain name of the UDP-over-TCP v2 sentinel destination.
///
/// [`uot_sentinel_destination`] wraps it in an [`Address`] (with port `0`), and
/// [`uot_is_sentinel_destination`] recognizes it via an exact, case-sensitive string comparison
/// that ignores the port.
/// NOTE(review): presumably this must match the magic address used by the peer implementation
/// (e.g. sing-box UoT v2) — confirm before changing it.
pub const V2_MAGIC_ADDRESS: &str = "sp.v2.udp-over-tcp.arpa";
/// Framing mode of a UDP-over-TCP session, sent as the first byte of the request header.
///
/// The explicit `u8` representation is what goes on the wire; see the `TryFrom<u8>` and
/// `From<UotMode> for u8` conversions.
#[repr(u8)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum UotMode {
    /// Every packet is prefixed with its own destination address (wire value `0`).
    /// This is the default mode.
    #[default]
    Datagram = 0,
    /// Packets carry only a length-prefixed payload with no per-packet address (wire value `1`).
    Connected = 1,
}
impl TryFrom<u8> for UotMode {
    type Error = std::io::Error;

    /// Decodes the mode byte of a request header.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::ErrorKind::InvalidData`] error for any byte other than `0` or `1`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let mode = match value {
            0 => UotMode::Datagram,
            1 => UotMode::Connected,
            unknown => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("invalid UOT mode: {unknown}"),
                ));
            }
        };
        Ok(mode)
    }
}
impl From<UotMode> for u8 {
    /// Encodes the mode as its `#[repr(u8)]` discriminant, i.e. its wire value.
    fn from(mode: UotMode) -> Self {
        mode as Self
    }
}
/// Header that opens a UDP-over-TCP stream: a mode byte followed by a destination address.
#[derive(Debug, Clone)]
pub struct UotRequest {
    /// Framing mode used for the packets that follow the header.
    pub mode: UotMode,
    /// Destination address carried in the header.
    pub destination: Address,
}
impl From<UotRequest> for Vec<u8> {
fn from(request: UotRequest) -> Self {
let mut buf = BytesMut::with_capacity(1 + request.destination.len());
buf.put_u8(request.mode.into());
request.destination.write_to_buf(&mut buf);
buf.to_vec()
}
}
impl UotRequest {
pub fn new(mode: UotMode, destination: Address) -> Self {
Self { mode, destination }
}
}
/// Returns the sentinel destination: the domain [`V2_MAGIC_ADDRESS`] with port `0`.
pub fn uot_sentinel_destination() -> Address {
    // The port carries no meaning here; `uot_is_sentinel_destination` only inspects the domain.
    let port = 0;
    Address::DomainAddress(V2_MAGIC_ADDRESS.into(), port)
}
/// Reports whether `address` is a domain address equal to [`V2_MAGIC_ADDRESS`].
///
/// The comparison is exact (case-sensitive), and the port is ignored. Socket (IP) addresses
/// never match.
pub fn uot_is_sentinel_destination(address: &Address) -> bool {
    match address {
        Address::DomainAddress(domain, _port) => &**domain == V2_MAGIC_ADDRESS,
        _ => false,
    }
}
pub async fn uot_get_request_from_stream<R>(reader: &mut R) -> std::io::Result<UotRequest>
where
R: AsyncRead + Unpin + Send + ?Sized,
{
let mode = UotMode::try_from(reader.read_u8().await?)?;
let destination = Address::retrieve_from_async_stream(reader).await?;
Ok(UotRequest::new(mode, destination))
}
pub fn uot_encode_packet(mode: UotMode, destination: Option<&Address>, payload: &[u8]) -> std::io::Result<Vec<u8>> {
use std::io::{Error, ErrorKind::InvalidInput};
if payload.len() > u16::MAX as usize {
return Err(Error::new(InvalidInput, "UOT packet too large"));
}
match mode {
UotMode::Datagram => {
let destination = destination.ok_or_else(|| Error::new(InvalidInput, "Datagram mode requires a destination"))?;
let mut buf = BytesMut::with_capacity(destination.len() + 2 + payload.len());
destination.write_to_buf(&mut buf);
buf.put_u16(payload.len() as u16);
buf.extend_from_slice(payload);
Ok(buf.to_vec())
}
UotMode::Connected => {
if destination.is_some() {
return Err(Error::new(InvalidInput, "Connected mode does not allow a destination"));
}
let mut buf = BytesMut::with_capacity(2 + payload.len());
buf.put_u16(payload.len() as u16);
buf.extend_from_slice(payload);
Ok(buf.to_vec())
}
}
}
/// Reads one framed UDP packet from a UDP-over-TCP stream.
///
/// In [`UotMode::Datagram`] the frame begins with a destination address, which is returned as
/// `Some`. In [`UotMode::Connected`] there is no address and `None` is returned. In both modes a
/// `u16` length (read with `AsyncReadExt::read_u16`) follows, then exactly that many payload
/// bytes.
///
/// # Errors
///
/// Propagates any error from decoding the address or from reading the stream. This includes
/// the stream ending in the middle of a frame.
pub async fn uot_get_packet_from_stream<R>(mode: UotMode, reader: &mut R) -> std::io::Result<(Option<Address>, Vec<u8>)>
where
    R: AsyncRead + Unpin + Send + ?Sized,
{
    // Only datagram-mode frames carry a per-packet address, and it comes before the length.
    let destination = match mode {
        UotMode::Datagram => Some(Address::retrieve_from_async_stream(reader).await?),
        UotMode::Connected => None,
    };
    let len = usize::from(reader.read_u16().await?);
    let mut payload = vec![0u8; len];
    reader.read_exact(&mut payload).await?;
    Ok((destination, payload))
}