#![warn(missing_debug_implementations, missing_docs, unreachable_pub)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{
fmt::{Display, Formatter},
io::{Error, ErrorKind, Result},
net::{Ipv4Addr, Ipv6Addr},
result,
};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(feature = "http")]
mod http;
#[cfg(feature = "socks5")]
mod socks5;
#[cfg(test)]
#[doc(hidden)]
pub mod test_utils;
#[cfg(feature = "http")]
#[cfg_attr(docsrs, doc(cfg(feature = "http")))]
pub use http::{HttpError, HttpReply, http_accept, http_connect, http_finalize_accept};
#[cfg(feature = "socks5")]
#[cfg_attr(docsrs, doc(cfg(feature = "socks5")))]
pub use socks5::{
Socks5Command,
Socks5Error,
Socks5Reply,
socks5_accept,
socks5_connect,
socks5_finalize_accept,
socks5_read_udp_header,
socks5_write_udp_header,
};
/// A proxy target: a host (IPv4, IPv6, or an unresolved domain name) plus a port.
///
/// The binary form produced by [`Address::encode_to_buf`] /
/// [`Address::encode_to_writer`] starts with a one-byte type tag
/// (`0x01` IPv4, `0x03` domain name, `0x04` IPv6). The text form produced by the
/// `String` conversions is `host:port`, with IPv6 hosts wrapped in brackets.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum Address {
    /// An IPv4 address and port.
    IPv4((Ipv4Addr, u16)),
    /// A domain name, kept as text, and port. The binary encoding limits the name
    /// to 255 bytes.
    DomainName((String, u16)),
    /// An IPv6 address and port.
    IPv6((Ipv6Addr, u16)),
}
impl Address {
    /// Reads one binary-encoded address from `reader`.
    ///
    /// The encoding is a one-byte type tag, then the host, then a big-endian
    /// `u16` port:
    ///
    /// - `0x01`: 4 IPv4 octets;
    /// - `0x03`: a length byte followed by that many bytes of UTF-8 domain name;
    /// - `0x04`: 16 IPv6 octets.
    ///
    /// Returns the decoded address together with the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from `reader` (such as an unexpected end of
    /// stream mid-address). Returns an [`ErrorKind::InvalidData`] error for an
    /// unknown type tag or a domain name that is not valid UTF-8.
    pub async fn decode_from_reader<T>(reader: &mut T) -> Result<(Self, usize)>
    where
        T: AsyncRead + Unpin,
    {
        let addr_type = AddressType::try_from(reader.read_u8().await?)?;
        match addr_type {
            AddressType::IPv4 => {
                let mut ip = [0u8; 4];
                reader.read_exact(&mut ip).await?;
                let port = reader.read_u16().await?;
                Ok((Address::IPv4((Ipv4Addr::from(ip), port)), 1 + 4 + 2))
            }
            AddressType::DomainName => {
                let len = usize::from(reader.read_u8().await?);
                let mut domain = vec![0u8; len];
                reader.read_exact(&mut domain).await?;
                let domain_str =
                    String::from_utf8(domain).map_err(|_| AddrError::InvalidDomainNameEncoding)?;
                let port = reader.read_u16().await?;
                Ok((Address::DomainName((domain_str, port)), 1 + 1 + len + 2))
            }
            AddressType::IPv6 => {
                let mut ip = [0u8; 16];
                reader.read_exact(&mut ip).await?;
                let port = reader.read_u16().await?;
                Ok((Address::IPv6((Ipv6Addr::from(ip), port)), 1 + 16 + 2))
            }
        }
    }

    /// Writes the binary encoding of this address to `writer`.
    ///
    /// See [`Address::decode_from_reader`] for the format. Returns the number of
    /// bytes written.
    ///
    /// # Errors
    ///
    /// Returns an [`ErrorKind::InvalidInput`] error, before anything is written,
    /// if a domain name is longer than 255 bytes. Otherwise propagates I/O errors
    /// from `writer`.
    pub async fn encode_to_writer<T>(&self, writer: &mut T) -> Result<usize>
    where
        T: AsyncWrite + Unpin,
    {
        // Largest possible encoding: tag + length byte + 255-byte domain + port.
        const MAX_ENCODED_LEN: usize = 1 + 1 + 255 + 2;
        // Encode into a stack buffer and emit it with one `write_all`, instead of
        // up to four tiny writes (each possibly a syscall on an unbuffered stream).
        // Reusing `encode_to_buf` also keeps the two encoders from drifting apart.
        let mut buf = [0u8; MAX_ENCODED_LEN];
        let len = self.encode_to_buf(&mut buf)?;
        writer.write_all(&buf[..len]).await?;
        Ok(len)
    }

    /// Decodes one binary-encoded address from the start of `buf`.
    ///
    /// Any bytes after the address are ignored. See
    /// [`Address::decode_from_reader`] for the format. Returns the decoded
    /// address and the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// - [`ErrorKind::UnexpectedEof`] if `buf` ends mid-address.
    /// - [`ErrorKind::InvalidData`] for an unknown type tag or a non-UTF-8
    ///   domain name.
    pub fn decode_from_buf(buf: &[u8]) -> Result<(Self, usize)> {
        let mut cursor = Cursor::new(buf);
        let addr_type = AddressType::try_from(cursor.read_u8()?)?;
        match addr_type {
            AddressType::IPv4 => {
                let mut ip = [0u8; 4];
                cursor.read_slice(&mut ip)?;
                let port = cursor.read_u16()?;
                Ok((Address::IPv4((Ipv4Addr::from(ip), port)), 1 + 4 + 2))
            }
            AddressType::DomainName => {
                let len = usize::from(cursor.read_u8()?);
                let mut domain = vec![0u8; len];
                cursor.read_slice(&mut domain)?;
                let domain_str =
                    String::from_utf8(domain).map_err(|_| AddrError::InvalidDomainNameEncoding)?;
                let port = cursor.read_u16()?;
                Ok((Address::DomainName((domain_str, port)), 1 + 1 + len + 2))
            }
            AddressType::IPv6 => {
                let mut ip = [0u8; 16];
                cursor.read_slice(&mut ip)?;
                let port = cursor.read_u16()?;
                Ok((Address::IPv6((Ipv6Addr::from(ip), port)), 1 + 16 + 2))
            }
        }
    }

    /// Writes the binary encoding of this address into the start of `buf`.
    ///
    /// See [`Address::decode_from_reader`] for the format. Returns the number of
    /// bytes written, which is at most 259.
    ///
    /// # Errors
    ///
    /// - [`ErrorKind::InvalidInput`] if a domain name is longer than 255 bytes;
    ///   nothing is written in that case.
    /// - [`ErrorKind::WriteZero`] if `buf` is too small. Fields before the one
    ///   that did not fit may already have been written.
    pub fn encode_to_buf(&self, buf: &mut [u8]) -> Result<usize> {
        let mut cursor = CursorMut::new(buf);
        match self {
            Address::IPv4((ip, port)) => {
                cursor.write_u8(AddressType::IPv4 as u8)?;
                cursor.write_slice(&ip.octets())?;
                cursor.write_u16(*port)?;
                Ok(1 + 4 + 2)
            }
            Address::DomainName((domain, port)) => {
                let domain_bytes = domain.as_bytes();
                // The length travels in a single byte, so at most 255 bytes fit.
                let len = u8::try_from(domain_bytes.len())
                    .map_err(|_| AddrError::DomainNameTooLong)?;
                cursor.write_u8(AddressType::DomainName as u8)?;
                cursor.write_u8(len)?;
                cursor.write_slice(domain_bytes)?;
                cursor.write_u16(*port)?;
                Ok(1 + 1 + domain_bytes.len() + 2)
            }
            Address::IPv6((ip, port)) => {
                cursor.write_u8(AddressType::IPv6 as u8)?;
                cursor.write_slice(&ip.octets())?;
                cursor.write_u16(*port)?;
                Ok(1 + 16 + 2)
            }
        }
    }
}
impl From<Address> for String {
    /// Renders an owned [`Address`] as text by delegating to the borrowed conversion.
    fn from(value: Address) -> Self {
        Self::from(&value)
    }
}
impl From<&Address> for String {
    /// Renders an [`Address`] as `host:port`.
    ///
    /// IPv6 hosts are wrapped in square brackets (`[::1]:80`) so the port
    /// separator stays unambiguous.
    fn from(address: &Address) -> Self {
        match address {
            Address::IPv4((ip, port)) => format!("{ip}:{port}"),
            Address::IPv6((ip, port)) => format!("[{ip}]:{port}"),
            Address::DomainName((domain, port)) => format!("{domain}:{port}"),
        }
    }
}
impl TryFrom<String> for Address {
type Error = AddrError;
fn try_from(value: String) -> result::Result<Self, Self::Error> {
Address::try_from(value.as_str())
}
}
impl TryFrom<&str> for Address {
    type Error = AddrError;

    /// Parses a textual `host:port` address.
    ///
    /// Accepted forms:
    /// - `[ipv6]:port`: a bracketed IPv6 literal, e.g. `[::1]:8080`;
    /// - `a.b.c.d:port`: an IPv4 literal;
    /// - `name:port`: anything else before the last `:` is kept verbatim as a
    ///   domain name.
    ///
    /// # Errors
    ///
    /// - [`AddrError::InvalidIPv6MissingClosingBracket`]: input starts with `[`
    ///   but has no `]`.
    /// - [`AddrError::InvalidIPv6MissingPortSeparator`]: `]` is not immediately
    ///   followed by `:`.
    /// - [`AddrError::InvalidTargetAddressMissingPortSeparator`]: unbracketed
    ///   input contains no `:`.
    /// - [`AddrError::InvalidPortNumber`]: the port is not a valid `u16`.
    /// - [`AddrError::InvalidIPv6Address`]: the bracketed host is not an IPv6
    ///   literal.
    fn try_from(string: &str) -> result::Result<Self, Self::Error> {
        if let Some(bracketed) = string.strip_prefix('[') {
            let end_bracket_pos = bracketed
                .rfind(']')
                .ok_or(AddrError::InvalidIPv6MissingClosingBracket)?;
            let host = &bracketed[..end_bracket_pos];
            // Check the separator on the remainder rather than slicing a fixed
            // one-byte window: a multi-byte character right after `]` would make
            // such a slice end off a char boundary and panic.
            let port_str = bracketed[end_bracket_pos + 1..]
                .strip_prefix(':')
                .ok_or(AddrError::InvalidIPv6MissingPortSeparator)?;
            let port = port_str
                .parse::<u16>()
                .map_err(|_| AddrError::InvalidPortNumber)?;
            let ipv6 = host
                .parse::<Ipv6Addr>()
                .map_err(|_| AddrError::InvalidIPv6Address)?;
            Ok(Address::IPv6((ipv6, port)))
        } else {
            let (host, port_str) = string
                .rsplit_once(':')
                .ok_or(AddrError::InvalidTargetAddressMissingPortSeparator)?;
            let port = port_str
                .parse::<u16>()
                .map_err(|_| AddrError::InvalidPortNumber)?;
            match host.parse::<Ipv4Addr>() {
                Ok(ipv4) => Ok(Address::IPv4((ipv4, port))),
                Err(_) => Ok(Address::DomainName((host.to_string(), port))),
            }
        }
    }
}
/// Authentication settings passed to this crate's HTTP and SOCKS5 connect/accept
/// functions.
///
/// Defaults to [`AuthMethod::NoAuth`].
#[derive(Clone, Debug, Default, Eq, PartialEq, Hash)]
pub enum AuthMethod {
    /// No authentication.
    #[default]
    NoAuth,
    /// Username/password authentication.
    UserPass {
        /// The user name.
        username: String,
        /// The password.
        password: String,
    },
}
/// Type tag that starts every binary-encoded [`Address`] and selects its variant.
#[derive(Debug, Clone, Copy, PartialEq)]
enum AddressType {
    /// Followed by 4 address octets and a port.
    IPv4 = 0x01,
    /// Followed by a one-byte length, that many domain-name bytes, and a port.
    DomainName = 0x03,
    /// Followed by 16 address octets and a port.
    IPv6 = 0x04,
}
impl TryFrom<u8> for AddressType {
    type Error = Error;

    /// Maps a type-tag byte to its [`AddressType`].
    ///
    /// Unknown tags yield [`AddrError::UnsupportedAddressType`] converted into an
    /// I/O error.
    fn try_from(value: u8) -> Result<Self> {
        let address_type = match value {
            0x01 => Self::IPv4,
            0x03 => Self::DomainName,
            0x04 => Self::IPv6,
            _ => return Err(AddrError::UnsupportedAddressType.into()),
        };
        Ok(address_type)
    }
}
/// Errors produced while parsing, encoding, or decoding an [`Address`].
///
/// Converts into [`std::io::Error`]: [`AddrError::DomainNameTooLong`] becomes
/// [`ErrorKind::InvalidInput`], and every other variant becomes
/// [`ErrorKind::InvalidData`].
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum AddrError {
    /// The type-tag byte was not `0x01`, `0x03`, or `0x04`.
    UnsupportedAddressType,
    /// A domain name longer than 255 bytes cannot be binary-encoded.
    DomainNameTooLong,
    /// Decoded domain-name bytes were not valid UTF-8.
    InvalidDomainNameEncoding,
    /// Text starting with `[` contained no closing `]`.
    InvalidIPv6MissingClosingBracket,
    /// A bracketed IPv6 host was not immediately followed by `:`.
    InvalidIPv6MissingPortSeparator,
    /// Unbracketed text contained no `:` separating host and port.
    InvalidTargetAddressMissingPortSeparator,
    /// The port text did not parse as a `u16`.
    InvalidPortNumber,
    /// The text inside brackets did not parse as an IPv6 address.
    InvalidIPv6Address,
}
impl Display for AddrError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::UnsupportedAddressType => "Unsupported address type",
            Self::DomainNameTooLong => "Domain name too long",
            Self::InvalidDomainNameEncoding => "Invalid domain name encoding",
            Self::InvalidIPv6MissingClosingBracket => {
                "Invalid IPv6 address format: missing closing bracket"
            }
            Self::InvalidIPv6MissingPortSeparator => {
                "Invalid IPv6 address format: missing port separator"
            }
            Self::InvalidTargetAddressMissingPortSeparator => {
                "Invalid target address format: missing port separator"
            }
            Self::InvalidPortNumber => "Invalid port number",
            Self::InvalidIPv6Address => "Invalid IPv6 address",
        };
        f.write_str(message)
    }
}
// `AddrError` implements `Debug` and `Display`, which is all `std::error::Error`
// requires. It wraps no underlying cause, so the default `source()` (None) is kept.
impl std::error::Error for AddrError {}
impl From<AddrError> for Error {
    /// Wraps an [`AddrError`] in a [`std::io::Error`], keeping it as the inner error.
    fn from(e: AddrError) -> Self {
        // `DomainNameTooLong` is raised while encoding an address the caller
        // supplied, so it maps to `InvalidInput`. Every other variant describes
        // malformed data and maps to `InvalidData`.
        let kind = match e {
            AddrError::DomainNameTooLong => ErrorKind::InvalidInput,
            AddrError::UnsupportedAddressType
            | AddrError::InvalidDomainNameEncoding
            | AddrError::InvalidIPv6MissingClosingBracket
            | AddrError::InvalidIPv6MissingPortSeparator
            | AddrError::InvalidTargetAddressMissingPortSeparator
            | AddrError::InvalidPortNumber
            | AddrError::InvalidIPv6Address => ErrorKind::InvalidData,
        };
        Error::new(kind, e)
    }
}
/// Read-only, bounds-checked cursor over a byte slice.
///
/// A read that would run past the end of the buffer fails with
/// [`ErrorKind::UnexpectedEof`] and leaves the position unchanged.
struct Cursor<'a> {
    /// The bytes being read.
    buf: &'a [u8],
    /// Offset of the next unread byte.
    pos: usize,
}

impl<'a> Cursor<'a> {
    /// Creates a cursor positioned at the start of `buf`.
    fn new(buf: &'a [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Builds the error reported when a read runs past the end of the buffer.
    ///
    /// Passed to `ok_or_else` so the (heap-allocating) `io::Error` is only
    /// created on failure, not on every successful read.
    fn underflow() -> Error {
        Error::new(ErrorKind::UnexpectedEof, "buffer underflow")
    }

    /// Returns the next `len` bytes and advances past them.
    ///
    /// If fewer than `len` bytes remain, returns an error without advancing.
    fn take(&mut self, len: usize) -> Result<&'a [u8]> {
        let buf: &'a [u8] = self.buf;
        let end = self.pos.checked_add(len).ok_or_else(Self::underflow)?;
        let bytes = buf.get(self.pos..end).ok_or_else(Self::underflow)?;
        self.pos = end;
        Ok(bytes)
    }

    /// Reads a single byte.
    fn read_u8(&mut self) -> Result<u8> {
        Ok(self.take(1)?[0])
    }

    /// Reads a big-endian `u16`.
    fn read_u16(&mut self) -> Result<u16> {
        let bytes = self.take(2)?;
        Ok(u16::from_be_bytes([bytes[0], bytes[1]]))
    }

    /// Fills `buf` completely with the next `buf.len()` bytes.
    fn read_slice(&mut self, buf: &mut [u8]) -> Result<()> {
        let src = self.take(buf.len())?;
        buf.copy_from_slice(src);
        Ok(())
    }
}
/// Bounds-checked write cursor over a mutable byte slice.
///
/// A write that would run past the end of the buffer fails with
/// [`ErrorKind::WriteZero`], writes nothing, and leaves the position unchanged.
struct CursorMut<'a> {
    /// The destination buffer.
    buf: &'a mut [u8],
    /// Offset of the next byte to write.
    pos: usize,
}

impl<'a> CursorMut<'a> {
    /// Creates a cursor positioned at the start of `buf`.
    fn new(buf: &'a mut [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Builds the error reported when a write runs past the end of the buffer.
    ///
    /// Passed to `ok_or_else` so the (heap-allocating) `io::Error` is only
    /// created on failure, not on every successful write.
    fn overflow() -> Error {
        Error::new(ErrorKind::WriteZero, "buffer overflow")
    }

    /// Reserves the next `len` bytes for writing and advances past them.
    ///
    /// If fewer than `len` bytes remain, returns an error without advancing.
    fn take(&mut self, len: usize) -> Result<&mut [u8]> {
        let end = self.pos.checked_add(len).ok_or_else(Self::overflow)?;
        let dst = self.buf.get_mut(self.pos..end).ok_or_else(Self::overflow)?;
        self.pos = end;
        Ok(dst)
    }

    /// Writes a single byte.
    fn write_u8(&mut self, value: u8) -> Result<()> {
        self.take(1)?[0] = value;
        Ok(())
    }

    /// Writes `value` as a big-endian `u16`.
    fn write_u16(&mut self, value: u16) -> Result<()> {
        self.take(2)?.copy_from_slice(&value.to_be_bytes());
        Ok(())
    }

    /// Writes all of `value`.
    fn write_slice(&mut self, value: &[u8]) -> Result<()> {
        self.take(value.len())?.copy_from_slice(value);
        Ok(())
    }
}
#[cfg(test)]
mod test {
    // Only the handshake tests spawn tasks, and each needs a protocol feature.
    #[cfg(any(feature = "http", feature = "socks5"))]
    use tokio::task;

    use super::*;
    use crate::test_utils::*;

    /// The documentation-range IPv6 address `2001:db8::1`.
    fn sample_ipv6() -> Ipv6Addr {
        Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1)
    }

    /// One address of every variant, shared by all round-trip tests.
    fn sample_addresses() -> [Address; 3] {
        [
            Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
            Address::DomainName(("example.com".to_string(), 443)),
            Address::IPv6((sample_ipv6(), 8080)),
        ]
    }

    /// Runs an HTTP CONNECT handshake over a mock stream for every sample
    /// address and checks that the server receives the address the client sent.
    #[cfg(feature = "http")]
    async fn run_http_handshakes(auth_method: AuthMethod) {
        for target in sample_addresses() {
            let (mut client_stream, mut server_stream) = create_mock_stream();
            let expected = target.clone();
            let auth_s = auth_method.clone();
            let auth_c = auth_method.clone();
            let server_task = task::spawn(async move {
                let received_addr = http_accept(&mut server_stream, &auth_s).await?;
                assert_eq!(received_addr, expected);
                http_finalize_accept(&mut server_stream, &HttpReply::Ok).await?;
                Ok::<_, Error>(())
            });
            let client_task = task::spawn(async move {
                http_connect(&mut client_stream, &target, &auth_c).await?;
                Ok::<_, Error>(())
            });
            let (server_result, client_result) = tokio::join!(server_task, client_task);
            server_result.unwrap().unwrap();
            client_result.unwrap().unwrap();
        }
    }

    /// Runs a SOCKS5 handshake over a mock stream for every sample address and
    /// every command. Checks that both sides agree on the command and address.
    #[cfg(feature = "socks5")]
    async fn run_socks5_handshakes(auth_method: AuthMethod) {
        let commands = [
            Socks5Command::Connect,
            Socks5Command::Bind,
            Socks5Command::UdpAssociate,
        ];
        for target in sample_addresses() {
            for command in commands {
                let (mut client_stream, mut server_stream) = create_mock_stream();
                let target_s = target.clone();
                let target_c = target.clone();
                let auth_s = auth_method.clone();
                let auth_c = auth_method.clone();
                let server_task = task::spawn(async move {
                    let (cmd, received_addr) = socks5_accept(&mut server_stream, &auth_s).await?;
                    assert_eq!(cmd, command);
                    assert_eq!(received_addr, target_s);
                    socks5_finalize_accept(
                        &mut server_stream,
                        &Socks5Reply::Succeeded,
                        &received_addr,
                    )
                    .await?;
                    Ok::<_, Error>(())
                });
                let client_task = task::spawn(async move {
                    let received_addr =
                        socks5_connect(&mut client_stream, &command, &target_c, &[auth_c]).await?;
                    assert_eq!(received_addr, target_c);
                    Ok::<_, Error>(())
                });
                let (server_result, client_result) = tokio::join!(server_task, client_task);
                server_result.unwrap().unwrap();
                client_result.unwrap().unwrap();
            }
        }
    }

    #[cfg(feature = "http")]
    #[tokio::test]
    async fn test_http_connect_accept_finalize_no_auth() {
        run_http_handshakes(AuthMethod::NoAuth).await;
    }

    #[cfg(feature = "http")]
    #[tokio::test]
    async fn test_http_connect_accept_finalize_userpass() {
        run_http_handshakes(AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        })
        .await;
    }

    #[cfg(feature = "socks5")]
    #[tokio::test]
    async fn test_socks5_connect_accept_finalize_no_auth() {
        run_socks5_handshakes(AuthMethod::NoAuth).await;
    }

    #[cfg(feature = "socks5")]
    #[tokio::test]
    async fn test_socks5_connect_accept_finalize_userpass() {
        run_socks5_handshakes(AuthMethod::UserPass {
            username: "user".to_string(),
            password: "pass".to_string(),
        })
        .await;
    }

    #[cfg(feature = "socks5")]
    #[test]
    fn test_socks5_udp_encode_decode() {
        for original_addr in sample_addresses() {
            let mut buffer = vec![0u8; 300];
            let write_len = socks5_write_udp_header(&original_addr, &mut buffer).unwrap();
            assert_eq!(&buffer[0..3], &[0, 0, 0]);
            let (decoded_addr, read_len) = socks5_read_udp_header(&buffer).unwrap();
            assert_eq!(write_len, read_len);
            assert_eq!(original_addr, decoded_addr);
        }
    }

    #[tokio::test]
    async fn test_encode_decode_with_stream() {
        for original_addr in sample_addresses() {
            let (mut stream1, mut stream2) = create_mock_stream();
            let write_len = original_addr.encode_to_writer(&mut stream1).await.unwrap();
            let (decoded_addr, read_len) = Address::decode_from_reader(&mut stream2).await.unwrap();
            assert_eq!(write_len, read_len);
            assert_eq!(original_addr, decoded_addr);
        }
    }

    #[test]
    fn test_encode_decode_with_buffer() {
        for original_addr in sample_addresses() {
            let mut buffer = vec![0u8; 300];
            let write_len = original_addr.encode_to_buf(&mut buffer).unwrap();
            let (decoded_addr, read_len) = Address::decode_from_buf(&buffer).unwrap();
            assert_eq!(write_len, read_len);
            assert_eq!(original_addr, decoded_addr);
        }
    }

    #[test]
    fn test_encode_decode_text() {
        let address_pairs = [
            (
                Address::IPv4((Ipv4Addr::new(192, 168, 1, 1), 8080)),
                "192.168.1.1:8080",
            ),
            (
                Address::DomainName(("example.com".to_string(), 443)),
                "example.com:443",
            ),
            (Address::IPv6((sample_ipv6(), 8080)), "[2001:db8::1]:8080"),
        ];
        for (addr, expected_str) in address_pairs {
            let addr_to_string = String::from(&addr);
            assert_eq!(addr_to_string, expected_str);
            let string_to_addr = Address::try_from(expected_str).unwrap();
            assert_eq!(string_to_addr, addr);
            let round_trip = Address::try_from(String::from(&addr)).unwrap();
            assert_eq!(round_trip, addr);
        }
    }
}