#![deny(unsafe_code)]
#![warn(missing_docs)]
use std::{
fmt,
net::{SocketAddrV4, SocketAddrV6},
};
use futures::{
io::{self, AsyncRead, AsyncWrite},
prelude::*,
};
use koibumi_net::{
domain::{Domain, SocketDomain},
socks::SocketAddr,
};
// Protocol version byte written at the start of the greeting and the connect request,
// and expected at the start of the corresponding server replies.
const SOCKS_VERSION_5: u8 = 0x05;
// Authentication method codes offered in the greeting and compared against the
// method the server selects.
const SOCKS_NO_AUTHENTICATION_REQUIRED: u8 = 0x00;
const SOCKS_USERNAME_AND_PASSWORD: u8 = 0x02;
// Command byte of the request; CONNECT is the only command this file sends.
const SOCKS_COMMAND_CONNECT: u8 = 0x01;
// Value written into the reserved byte of the request.
const SOCKS_RESERVED: u8 = 0x00;
// Address type tags, used both when encoding the destination and when decoding
// the address carried in the server's reply.
const SOCKS_ADDRESS_IPV4: u8 = 0x01;
const SOCKS_ADDRESS_DOMAIN_NAME: u8 = 0x03;
const SOCKS_ADDRESS_IPV6: u8 = 0x04;
// Reply codes recognised in the server's answer to the connect request; any other
// value is reported as `ConnectError::UnknownFailure`.
const SOCKS_REPLY_SUCCEEDED: u8 = 0x00;
const SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE: u8 = 0x01;
const SOCKS_REPLY_HOST_UNREACHABLE: u8 = 0x04;
const SOCKS_REPLY_CONNECTION_REFUSED: u8 = 0x05;
const SOCKS_REPLY_TTL_EXPIRED: u8 = 0x06;
const SOCKS_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
const SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
// Username/password sub-negotiation: version byte sent and expected back, and the
// status byte that signals successful authentication.
const SOCKS_SUBNEGOTIATION_VERSION: u8 = 0x01;
const SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED: u8 = 0x00;
/// Authentication the client offers to the SOCKS server during the handshake.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
enum Auth {
    /// Offer the "no authentication required" method.
    None,
    /// Offer the username/password method with the given credentials.
    ///
    /// Both fields are raw bytes; the handshake rejects either one if its length
    /// is outside 1 to 255 bytes.
    Password {
        /// Username bytes sent in the sub-negotiation.
        username: Vec<u8>,
        /// Password bytes sent in the sub-negotiation.
        password: Vec<u8>,
    },
}
/// An error that can occur while negotiating a connection through a SOCKS server.
#[derive(Debug)]
pub enum ConnectError {
    /// The server answered with a protocol version byte other than the expected one.
    UnsupportedVersion(u8),
    /// The server selected an authentication method different from the one requested.
    UnsupportedMethod(u8),
    /// The server replied with the "general SOCKS server failure" code.
    GeneralServerFailure,
    /// The server replied with the "host unreachable" code.
    HostUnreachable,
    /// The server replied with the "connection refused" code.
    ConnectionRefused,
    /// The server replied with the "TTL expired" code.
    TtlExpired,
    /// The server replied with the "command not supported" code.
    CommandNotSupported,
    /// The server replied with the "address type not supported" code.
    AddressTypeNotSupported,
    /// The server replied with a failure code this client does not recognise.
    UnknownFailure(u8),
    /// The server's reply carried an address type tag this client cannot decode.
    UnsupportedAddressType(u8),
    /// An I/O error occurred while talking to the server.
    IoError(io::Error),
    /// The supplied username was empty or longer than 255 bytes.
    InvalidUsernameLength(usize),
    /// The supplied password was empty or longer than 255 bytes.
    InvalidPasswordLength(usize),
    /// The server answered the username/password sub-negotiation with an unexpected version.
    UnsupportedSubnegotiationVersion(u8),
    /// The server rejected the supplied username/password.
    AuthenticationFailure,
}

impl fmt::Display for ConnectError {
    /// Writes a human-readable description of the error.
    ///
    /// Single-byte protocol values are rendered as `0x` followed by exactly two
    /// hex digits. The width is 4 because the `#` flag's `0x` prefix counts
    /// toward the field width; a width of 2 would never zero-pad.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnsupportedVersion(ver) => write!(f, "Unsupported SOCKS version: {:#04x}", ver),
            Self::UnsupportedMethod(method) => {
                write!(f, "Unsupported SOCKS method: {:#04x}", method)
            }
            // `Formatter::pad` is what `<str as Display>::fmt` does, so width and
            // alignment flags supplied by the caller are still honoured.
            Self::GeneralServerFailure => f.pad("General SOCKS server failure"),
            Self::HostUnreachable => f.pad("Host unreachable"),
            Self::ConnectionRefused => f.pad("Connection refused"),
            Self::TtlExpired => f.pad("TTL expired"),
            Self::CommandNotSupported => f.pad("Command not supported"),
            Self::AddressTypeNotSupported => f.pad("Address type not supported"),
            Self::UnknownFailure(rep) => write!(f, "Unknown SOCKS failure: {:#04x}", rep),
            Self::UnsupportedAddressType(atyp) => {
                write!(f, "Unsupported address type: {:#04x}", atyp)
            }
            // Fully qualified so the call cannot become ambiguous with `Debug::fmt`.
            Self::IoError(err) => fmt::Display::fmt(err, f),
            Self::InvalidUsernameLength(len) => {
                write!(f, "username length must be 1..255, but {}", len)
            }
            Self::InvalidPasswordLength(len) => {
                write!(f, "password length must be 1..255, but {}", len)
            }
            Self::UnsupportedSubnegotiationVersion(ver) => {
                write!(f, "Unsupported SOCKS subnegotiation version: {:#04x}", ver)
            }
            Self::AuthenticationFailure => f.pad("authentication failure"),
        }
    }
}
// Uses the trait's default method implementations; no `source` override is provided.
impl std::error::Error for ConnectError {}
impl From<io::Error> for ConnectError {
fn from(err: io::Error) -> Self {
ConnectError::IoError(err)
}
}
/// A specialized `Result` type whose error is [`ConnectError`].
pub type Result<T> = std::result::Result<T, ConnectError>;
/// Reads exactly one byte from `s` and returns it.
///
/// Any I/O failure (including end of stream before a byte arrives) is
/// converted into [`ConnectError::IoError`].
async fn read_u8<R>(s: &mut R) -> Result<u8>
where
    R: AsyncRead + Unpin,
{
    let mut buf = [0u8; 1];
    s.read_exact(&mut buf).await?;
    let [byte] = buf;
    Ok(byte)
}
/// Reads exactly two bytes from `s` and interprets them as a big-endian `u16`.
///
/// Any I/O failure (including a short read) is converted into
/// [`ConnectError::IoError`].
async fn read_u16<R>(s: &mut R) -> Result<u16>
where
    R: AsyncRead + Unpin,
{
    let mut raw = [0u8; 2];
    s.read_exact(&mut raw).await?;
    // Network byte order: the first byte on the wire is the most significant one.
    let [high, low] = raw;
    Ok((u16::from(high) << 8) | u16::from(low))
}
/// Asks the SOCKS server behind `server` to connect to `destination`, offering
/// only the "no authentication required" method.
///
/// On success, returns the address carried in the server's reply.
///
/// # Errors
///
/// Returns a [`ConnectError`] if the server answers with an unexpected version
/// or method, reports a failure code, or if an I/O error occurs.
pub async fn connect<S>(server: &mut S, destination: SocketAddr) -> Result<SocketAddr>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_auth = Auth::None;
    let handshake = connect_with_auth(server, no_auth, destination);
    handshake.await
}
/// Asks the SOCKS server behind `server` to connect to `destination`,
/// authenticating with the username/password method.
///
/// `username` and `password` are copied as raw bytes; each must be between 1 and
/// 255 bytes long. On success, returns the address carried in the server's reply.
///
/// # Errors
///
/// Returns [`ConnectError::InvalidUsernameLength`] or
/// [`ConnectError::InvalidPasswordLength`] for out-of-range credentials,
/// [`ConnectError::AuthenticationFailure`] if the server rejects them, and other
/// [`ConnectError`] variants for protocol or I/O failures.
pub async fn connect_with_password<S>(
    server: &mut S,
    username: impl AsRef<[u8]>,
    password: impl AsRef<[u8]>,
    destination: SocketAddr,
) -> Result<SocketAddr>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Take owned copies so the credentials can be moved into the handshake.
    let credentials = Auth::Password {
        username: username.as_ref().to_vec(),
        password: password.as_ref().to_vec(),
    };
    connect_with_auth(server, credentials, destination).await
}
/// Runs the SOCKS5 client handshake on `server` and requests a connection to
/// `destination`.
///
/// The exchange has three phases:
///
/// 1. a greeting offering exactly one authentication method (chosen by `auth`),
/// 2. the username/password sub-negotiation, when `auth` is [`Auth::Password`],
/// 3. the CONNECT request and the server's reply.
///
/// On success, returns the address carried in the server's reply.
///
/// # Errors
///
/// * [`ConnectError::InvalidUsernameLength`] / [`ConnectError::InvalidPasswordLength`]
///   if a credential is empty or longer than 255 bytes (checked before any I/O).
/// * [`ConnectError::UnsupportedVersion`] / [`ConnectError::UnsupportedMethod`] /
///   [`ConnectError::UnsupportedSubnegotiationVersion`] on unexpected header bytes.
/// * [`ConnectError::AuthenticationFailure`] if the server rejects the credentials.
/// * One of the reply-code variants if the server reports a failure.
/// * [`ConnectError::UnsupportedAddressType`] if the reply's address type is unknown.
/// * [`ConnectError::IoError`] on I/O failure, if the destination domain name does
///   not fit the one-byte length field, or if the server's reply carries a domain
///   name that cannot be parsed.
async fn connect_with_auth<S>(
    server: &mut S,
    auth: Auth,
    destination: SocketAddr,
) -> Result<SocketAddr>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Each credential length is encoded in a single byte and must not be empty,
    // so reject anything outside 1..=255 before touching the stream.
    if let Auth::Password { username, password } = &auth {
        if !(1..=255).contains(&username.len()) {
            return Err(ConnectError::InvalidUsernameLength(username.len()));
        }
        if !(1..=255).contains(&password.len()) {
            return Err(ConnectError::InvalidPasswordLength(password.len()));
        }
    }

    // Phase 1: greeting. Exactly one method is offered.
    let requested_method = match &auth {
        Auth::None => SOCKS_NO_AUTHENTICATION_REQUIRED,
        Auth::Password { .. } => SOCKS_USERNAME_AND_PASSWORD,
    };
    server
        .write_all(&[SOCKS_VERSION_5, 1, requested_method])
        .await?;
    server.flush().await?;

    // Both reply bytes are read before validating, so the stream position is the
    // same regardless of which check fails.
    let ver = read_u8(server).await?;
    let method = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    if method != requested_method {
        return Err(ConnectError::UnsupportedMethod(method));
    }

    // Phase 2: username/password sub-negotiation.
    if let Auth::Password { username, password } = auth {
        let mut packet: Vec<u8> = Vec::with_capacity(3 + username.len() + password.len());
        packet.push(SOCKS_SUBNEGOTIATION_VERSION);
        // The casts cannot truncate: both lengths were checked to be <= 255 above.
        packet.push(username.len() as u8);
        packet.extend_from_slice(&username);
        packet.push(password.len() as u8);
        packet.extend_from_slice(&password);
        server.write_all(&packet).await?;
        server.flush().await?;

        let ver = read_u8(server).await?;
        let status = read_u8(server).await?;
        if ver != SOCKS_SUBNEGOTIATION_VERSION {
            return Err(ConnectError::UnsupportedSubnegotiationVersion(ver));
        }
        if status != SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED {
            return Err(ConnectError::AuthenticationFailure);
        }
    }

    // Phase 3: CONNECT request.
    let mut packet: Vec<u8> = vec![SOCKS_VERSION_5, SOCKS_COMMAND_CONNECT, SOCKS_RESERVED];
    match destination {
        SocketAddr::Ipv4(addr) => {
            packet.push(SOCKS_ADDRESS_IPV4);
            packet.extend_from_slice(&addr.ip().octets());
            packet.extend_from_slice(&addr.port().to_be_bytes());
        }
        SocketAddr::Domain(domain) => {
            let host = domain.domain();
            let name = host.as_bytes();
            // The name is prefixed by a one-byte length. Refuse names that do not
            // fit instead of silently truncating the length and corrupting the request.
            if name.len() > 255 {
                return Err(ConnectError::IoError(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "domain name is too long for a SOCKS5 request",
                )));
            }
            packet.push(SOCKS_ADDRESS_DOMAIN_NAME);
            packet.push(name.len() as u8);
            packet.extend_from_slice(name);
            packet.extend_from_slice(&domain.port().as_u16().to_be_bytes());
        }
        SocketAddr::Ipv6(addr) => {
            packet.push(SOCKS_ADDRESS_IPV6);
            packet.extend_from_slice(&addr.ip().octets());
            packet.extend_from_slice(&addr.port().to_be_bytes());
        }
    }
    server.write_all(&packet).await?;
    server.flush().await?;

    let ver = read_u8(server).await?;
    let rep = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    match rep {
        SOCKS_REPLY_SUCCEEDED => {}
        SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE => return Err(ConnectError::GeneralServerFailure),
        SOCKS_REPLY_HOST_UNREACHABLE => return Err(ConnectError::HostUnreachable),
        SOCKS_REPLY_CONNECTION_REFUSED => return Err(ConnectError::ConnectionRefused),
        SOCKS_REPLY_TTL_EXPIRED => return Err(ConnectError::TtlExpired),
        SOCKS_REPLY_COMMAND_NOT_SUPPORTED => return Err(ConnectError::CommandNotSupported),
        SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => {
            return Err(ConnectError::AddressTypeNotSupported)
        }
        _ => return Err(ConnectError::UnknownFailure(rep)),
    }

    // The reserved byte is consumed but its value is not checked.
    let _rsv = read_u8(server).await?;
    let atyp = read_u8(server).await?;
    match atyp {
        SOCKS_ADDRESS_IPV4 => {
            let mut bytes = [0; 4];
            server.read_exact(&mut bytes).await?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Ipv4(SocketAddrV4::new(bytes.into(), port)))
        }
        SOCKS_ADDRESS_DOMAIN_NAME => {
            let len = read_u8(server).await?;
            // `read_exact` fails with an I/O error if the stream ends before `len`
            // bytes arrive, so a truncated name is never accepted.
            let mut bytes = vec![0u8; usize::from(len)];
            server.read_exact(&mut bytes).await?;
            // The bytes come from the remote server; a malformed name must surface
            // as an error rather than a panic.
            let domain = Domain::from_bytes(&bytes).map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    "invalid domain name in SOCKS reply",
                )
            })?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Domain(SocketDomain::new(domain, port.into())))
        }
        SOCKS_ADDRESS_IPV6 => {
            let mut bytes = [0; 16];
            server.read_exact(&mut bytes).await?;
            let port = read_u16(server).await?;
            // Flow info and scope id are not part of the reply; both are set to 0.
            Ok(SocketAddr::Ipv6(SocketAddrV6::new(
                bytes.into(),
                port,
                0,
                0,
            )))
        }
        _ => Err(ConnectError::UnsupportedAddressType(atyp)),
    }
}