use std::{
fmt,
net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
str::FromStr,
};
use async_std::io::{self, prelude::WriteExt, ReadExt};
/// A domain name as carried in a SOCKS5 address field.
///
/// Stored as raw bytes (no UTF-8 or hostname-syntax validation). The length is
/// capped at 255 bytes by [`DomainName::new`] so it always fits the one-byte
/// length prefix used on the wire.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct DomainName {
    bytes: Vec<u8>,
}

impl AsRef<[u8]> for DomainName {
    /// Returns the raw bytes of the name.
    fn as_ref(&self) -> &[u8] {
        &self.bytes
    }
}

/// Error returned when a [`DomainName`] cannot be built from the given input.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ParseDomainNameError {
    /// The input was `len` bytes long, but at most `max` bytes are allowed.
    TooLong { max: usize, len: usize },
}

impl fmt::Display for ParseDomainNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::TooLong { max, len } => write!(f, "length must be <={}, but {}", max, len),
        }
    }
}

impl std::error::Error for ParseDomainNameError {}

impl DomainName {
    /// Largest accepted name length: the maximum value of a one-byte length prefix.
    const MAX_LEN: usize = 0xff;

    /// Creates a domain name from raw bytes.
    ///
    /// Only the length is validated; the bytes are not checked for UTF-8 or
    /// hostname syntax.
    ///
    /// # Errors
    ///
    /// Returns [`ParseDomainNameError::TooLong`] if `bytes` is longer than 255 bytes.
    pub fn new(bytes: Vec<u8>) -> std::result::Result<Self, ParseDomainNameError> {
        if bytes.len() > Self::MAX_LEN {
            return Err(ParseDomainNameError::TooLong {
                max: Self::MAX_LEN,
                len: bytes.len(),
            });
        }
        Ok(Self { bytes })
    }
}

impl fmt::Display for DomainName {
    /// Formats the name as text; invalid UTF-8 sequences become U+FFFD.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&String::from_utf8_lossy(&self.bytes), f)
    }
}

impl FromStr for DomainName {
    type Err = ParseDomainNameError;

    /// Builds a domain name from the UTF-8 bytes of `s`.
    ///
    /// Fails with [`ParseDomainNameError::TooLong`] when `s` exceeds 255 bytes.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // `new` already returns the right Result type; no `?`/`Ok` rewrap needed.
        Self::new(s.as_bytes().to_vec())
    }
}
/// A host address without a port: an IPv4 address, a domain name, or an IPv6 address.
///
/// Variant order is significant: it defines the derived `Ord` ordering.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Addr {
    /// An IPv4 address.
    Ipv4(Ipv4Addr),
    /// A domain name.
    DomainName(DomainName),
    /// An IPv6 address.
    Ipv6(Ipv6Addr),
}

impl fmt::Display for Addr {
    /// Formats exactly like the wrapped address, passing the formatter through unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &dyn fmt::Display = match self {
            Self::Ipv4(ip) => ip,
            Self::DomainName(name) => name,
            Self::Ipv6(ip) => ip,
        };
        fmt::Display::fmt(inner, f)
    }
}
/// Port number paired with an address.
type Port = u16;

/// A domain name together with a port; displayed as `name:port`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SocketDomainName {
    domain_name: DomainName,
    port: Port,
}

impl SocketDomainName {
    /// Bundles `domain_name` with `port`.
    pub fn new(domain_name: DomainName, port: Port) -> Self {
        SocketDomainName { port, domain_name }
    }

    /// The domain-name part of this socket address.
    pub fn domain_name(&self) -> &DomainName {
        &self.domain_name
    }

    /// The port part of this socket address.
    pub fn port(&self) -> Port {
        self.port
    }
}

impl fmt::Display for SocketDomainName {
    /// Writes `name:port`; the caller's width/fill options are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { domain_name, port } = self;
        write!(f, "{}:{}", domain_name, port)
    }
}
/// A socket address in a SOCKS5 exchange (destination or bound address):
/// an IPv4, domain-name or IPv6 host paired with a port.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SocketAddr {
    Ipv4(SocketAddrV4),
    DomainName(SocketDomainName),
    Ipv6(SocketAddrV6),
}

impl SocketAddr {
    /// Attaches `port` to `addr`, keeping the address family.
    ///
    /// IPv6 socket addresses are created with `flowinfo` and `scope_id` set to 0.
    pub fn new(addr: Addr, port: Port) -> Self {
        match addr {
            Addr::Ipv4(ip) => SocketAddr::Ipv4(SocketAddrV4::new(ip, port)),
            Addr::DomainName(name) => SocketAddr::DomainName(SocketDomainName::new(name, port)),
            Addr::Ipv6(ip) => {
                let (flowinfo, scope_id) = (0, 0);
                SocketAddr::Ipv6(SocketAddrV6::new(ip, port, flowinfo, scope_id))
            }
        }
    }
}

impl fmt::Display for SocketAddr {
    /// Formats like the wrapped socket address (e.g. `1.2.3.4:80`, `name:80`, `[::1]:80`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SocketAddr::Ipv4(v4) => fmt::Display::fmt(v4, f),
            SocketAddr::DomainName(named) => fmt::Display::fmt(named, f),
            SocketAddr::Ipv6(v6) => fmt::Display::fmt(v6, f),
        }
    }
}
// Wire-level constants of the SOCKS protocol version 5 (RFC 1928).

/// `VER` byte sent in every client message and expected in every server reply.
const SOCKS_VERSION_5: u8 = 0x05;
/// Authentication method "no authentication required"; the only method this client offers.
const SOCKS_NO_AUTHENTICATION_REQUIRED: u8 = 0x00;
/// `CMD` value for a CONNECT request.
const SOCKS_COMMAND_CONNECT: u8 = 0x01;
/// Value of the reserved `RSV` byte in a request.
const SOCKS_RESERVED: u8 = 0x00;
/// `ATYP` value: 4-byte IPv4 address.
const SOCKS_ADDRESS_IPV4: u8 = 0x01;
/// `ATYP` value: domain name preceded by a one-byte length.
const SOCKS_ADDRESS_DOMAIN_NAME: u8 = 0x03;
/// `ATYP` value: 16-byte IPv6 address.
const SOCKS_ADDRESS_IPV6: u8 = 0x04;
// `REP` values in the server's reply. RFC 1928 also defines 0x02 (connection not
// allowed by ruleset) and 0x03 (network unreachable); they have no constant here,
// so such replies are reported as `ConnectError::UnknownFailure`.
/// `REP`: request succeeded.
const SOCKS_REPLY_SUCCEEDED: u8 = 0x00;
/// `REP`: general SOCKS server failure.
const SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE: u8 = 0x01;
/// `REP`: host unreachable.
const SOCKS_REPLY_HOST_UNREACHABLE: u8 = 0x04;
/// `REP`: connection refused.
const SOCKS_REPLY_CONNECTION_REFUSED: u8 = 0x05;
/// `REP`: TTL expired.
const SOCKS_REPLY_TTL_EXPIRED: u8 = 0x06;
/// `REP`: command not supported.
const SOCKS_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
/// `REP`: address type not supported.
const SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
/// Errors returned by the SOCKS5 `connect` handshake.
#[derive(Debug)]
pub enum ConnectError {
    /// The server answered with a protocol version byte other than SOCKS5.
    UnsupportedVersion(u8),
    /// The server selected an authentication method other than "no authentication required".
    UnsupportedMethod(u8),
    /// Reply code: general SOCKS server failure.
    GeneralServerFailure,
    /// Reply code: host unreachable.
    HostUnreachable,
    /// Reply code: connection refused.
    ConnectionRefused,
    /// Reply code: TTL expired.
    TtlExpired,
    /// Reply code: command not supported.
    CommandNotSupported,
    /// Reply code: address type not supported.
    AddressTypeNotSupported,
    /// A failure reply code that has no dedicated variant.
    UnknownFailure(u8),
    /// The reply carried an address type this client cannot decode.
    UnsupportedAddressType(u8),
    /// Reading from or writing to the underlying stream failed.
    IoError(io::Error),
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `{:#04x}`: the width counts the `0x` prefix, so this always prints two
        // hex digits (`0x05`). The previous `{:#02x}` printed `0x5`.
        match self {
            Self::UnsupportedVersion(ver) => write!(f, "Unsupported SOCKS version: {:#04x}", ver),
            Self::UnsupportedMethod(method) => {
                write!(f, "Unsupported SOCKS method: {:#04x}", method)
            }
            Self::GeneralServerFailure => "General SOCKS server failure".fmt(f),
            Self::HostUnreachable => "Host unreachable".fmt(f),
            Self::ConnectionRefused => "Connection refused".fmt(f),
            Self::TtlExpired => "TTL expired".fmt(f),
            Self::CommandNotSupported => "Command not supported".fmt(f),
            Self::AddressTypeNotSupported => "Address type not supported".fmt(f),
            Self::UnknownFailure(rep) => write!(f, "Unknown SOCKS failure: {:#04x}", rep),
            Self::UnsupportedAddressType(atyp) => {
                write!(f, "Unsupported address type: {:#04x}", atyp)
            }
            Self::IoError(err) => err.fmt(f),
        }
    }
}

impl std::error::Error for ConnectError {}

/// Result alias for SOCKS5 client operations, with [`ConnectError`] as the error type.
pub type Result<T> = std::result::Result<T, ConnectError>;
/// Writes all of `v` to `s`, wrapping any I/O failure in [`ConnectError::IoError`].
async fn write_all<W>(s: &mut W, v: &[u8]) -> Result<()>
where
    W: WriteExt + Unpin,
{
    s.write_all(v).await.map_err(ConnectError::IoError)
}

/// Flushes `s`, wrapping any I/O failure in [`ConnectError::IoError`].
async fn flush<W>(s: &mut W) -> Result<()>
where
    W: WriteExt + Unpin,
{
    s.flush().await.map_err(ConnectError::IoError)
}

/// Fills `v` completely from `s`, wrapping any I/O failure reported by
/// `read_exact` in [`ConnectError::IoError`].
async fn read_exact<R>(s: &mut R, v: &mut [u8]) -> Result<()>
where
    R: ReadExt + Unpin,
{
    s.read_exact(v).await.map_err(ConnectError::IoError)
}

/// Reads a single byte from `s`.
async fn read_u8<R>(s: &mut R) -> Result<u8>
where
    R: ReadExt + Unpin,
{
    let mut bytes = [0; 1];
    read_exact(s, &mut bytes).await?;
    Ok(bytes[0])
}

/// Reads a big-endian (network byte order) `u16` from `s`.
async fn read_u16<R>(s: &mut R) -> Result<u16>
where
    R: ReadExt + Unpin,
{
    let mut bytes = [0; 2];
    read_exact(s, &mut bytes).await?;
    Ok(u16::from_be_bytes(bytes))
}

/// Maps the `REP` byte of a SOCKS5 reply to `Ok(())` on success, or to the
/// matching [`ConnectError`] otherwise.
fn check_reply(rep: u8) -> Result<()> {
    let err = match rep {
        SOCKS_REPLY_SUCCEEDED => return Ok(()),
        SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE => ConnectError::GeneralServerFailure,
        SOCKS_REPLY_HOST_UNREACHABLE => ConnectError::HostUnreachable,
        SOCKS_REPLY_CONNECTION_REFUSED => ConnectError::ConnectionRefused,
        SOCKS_REPLY_TTL_EXPIRED => ConnectError::TtlExpired,
        SOCKS_REPLY_COMMAND_NOT_SUPPORTED => ConnectError::CommandNotSupported,
        SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => ConnectError::AddressTypeNotSupported,
        // Codes without a dedicated variant are reported verbatim.
        _ => ConnectError::UnknownFailure(rep),
    };
    Err(err)
}

/// Encodes a complete SOCKS5 CONNECT request for `destination` into one buffer:
/// `VER CMD RSV ATYP DST.ADDR DST.PORT`, with the port in big-endian order.
fn encode_connect_request(destination: &SocketAddr) -> Vec<u8> {
    // Header (4) + longest address (1-byte length + 255-byte name) + port (2).
    let mut request = Vec::with_capacity(4 + 1 + 255 + 2);
    request.extend_from_slice(&[SOCKS_VERSION_5, SOCKS_COMMAND_CONNECT, SOCKS_RESERVED]);
    let port = match destination {
        SocketAddr::Ipv4(addr) => {
            request.push(SOCKS_ADDRESS_IPV4);
            request.extend_from_slice(&addr.ip().octets());
            addr.port()
        }
        SocketAddr::DomainName(addr) => {
            let name: &[u8] = addr.domain_name().as_ref();
            // `DomainName::new` rejects names longer than 255 bytes, so the length
            // always fits the one-byte prefix and this cast cannot truncate.
            debug_assert!(name.len() <= usize::from(u8::MAX));
            request.push(SOCKS_ADDRESS_DOMAIN_NAME);
            request.push(name.len() as u8);
            request.extend_from_slice(name);
            addr.port()
        }
        SocketAddr::Ipv6(addr) => {
            request.push(SOCKS_ADDRESS_IPV6);
            request.extend_from_slice(&addr.ip().octets());
            addr.port()
        }
    };
    request.extend_from_slice(&port.to_be_bytes());
    request
}

/// Reads the `ATYP BND.ADDR BND.PORT` tail of a successful SOCKS5 reply.
///
/// # Errors
///
/// [`ConnectError::UnsupportedAddressType`] for an unknown `ATYP`, or
/// [`ConnectError::IoError`] if the stream fails.
async fn read_bound_address<R>(server: &mut R) -> Result<SocketAddr>
where
    R: ReadExt + Unpin,
{
    let atyp = read_u8(server).await?;
    match atyp {
        SOCKS_ADDRESS_IPV4 => {
            let mut octets = [0; 4];
            read_exact(server, &mut octets).await?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Ipv4(SocketAddrV4::new(octets.into(), port)))
        }
        SOCKS_ADDRESS_DOMAIN_NAME => {
            let len = read_u8(server).await?;
            // Read exactly the announced number of bytes into a buffer of that size.
            let mut name = vec![0u8; usize::from(len)];
            read_exact(server, &mut name).await?;
            let domain_name = DomainName::new(name)
                .expect("a u8 length prefix never exceeds DomainName's 255-byte limit");
            let port = read_u16(server).await?;
            Ok(SocketAddr::DomainName(SocketDomainName::new(
                domain_name,
                port,
            )))
        }
        SOCKS_ADDRESS_IPV6 => {
            let mut octets = [0; 16];
            read_exact(server, &mut octets).await?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Ipv6(SocketAddrV6::new(octets.into(), port, 0, 0)))
        }
        _ => Err(ConnectError::UnsupportedAddressType(atyp)),
    }
}

/// Performs a SOCKS5 handshake over `server`, asking it to CONNECT to `destination`.
///
/// Only the "no authentication required" method is offered. On success, returns
/// the address the server reports in its reply (`BND.ADDR` / `BND.PORT`).
///
/// # Errors
///
/// - [`ConnectError::UnsupportedVersion`] if a server message does not carry
///   version 5.
/// - [`ConnectError::UnsupportedMethod`] if the server does not select
///   "no authentication required".
/// - A reply-specific variant (e.g. [`ConnectError::ConnectionRefused`]) or
///   [`ConnectError::UnknownFailure`] if the server rejects the request.
/// - [`ConnectError::UnsupportedAddressType`] for an unknown bound-address type.
/// - [`ConnectError::IoError`] if reading from or writing to `server` fails.
pub async fn connect<S>(server: &mut S, destination: SocketAddr) -> Result<SocketAddr>
where
    S: ReadExt + WriteExt + Unpin,
{
    // Method negotiation: VER, NMETHODS = 1, METHODS = [no authentication].
    // Each message is written with a single call instead of one write per byte,
    // so an unbuffered stream does not emit it as many tiny writes.
    write_all(server, &[SOCKS_VERSION_5, 1, SOCKS_NO_AUTHENTICATION_REQUIRED]).await?;
    flush(server).await?;
    let ver = read_u8(server).await?;
    let method = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    if method != SOCKS_NO_AUTHENTICATION_REQUIRED {
        return Err(ConnectError::UnsupportedMethod(method));
    }

    let request = encode_connect_request(&destination);
    write_all(server, &request).await?;
    flush(server).await?;

    // Reply header: VER REP RSV, followed by the bound address.
    let ver = read_u8(server).await?;
    let rep = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    check_reply(rep)?;
    let _rsv = read_u8(server).await?;
    read_bound_address(server).await
}