use std::{
borrow::Cow,
convert::TryFrom,
fmt,
fmt::Formatter,
net::{Ipv4Addr, Ipv6Addr},
};
use data_encoding::BASE32;
use multiaddr::{Multiaddr, Protocol};
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use super::error::SocksError;
/// Convenience alias for results whose error type is [`SocksError`].
pub type Result<T> = std::result::Result<T, SocksError>;
/// Authentication method the client offers to the SOCKS5 proxy.
///
/// The `Debug` implementation in this file deliberately omits the password.
#[derive(Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Authentication {
/// No authentication (method id `0x00`); this is the default variant.
#[default]
None,
/// Username/password authentication (method id `0x02`).
Password { username: String, password: String },
}
impl Authentication {
fn id(&self) -> u8 {
match self {
Authentication::Password { .. } => 0x02,
Authentication::None => 0x00,
}
}
}
impl fmt::Debug for Authentication {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
use Authentication::{None, Password};
match self {
None => write!(f, "None"),
Password { username, .. } => write!(f, "Password({username}, ...)"),
}
}
}
/// Command codes placed in the second byte of a SOCKS5 request
/// (see `prepare_send_request`, which writes `command as u8`).
#[repr(u8)]
#[derive(Clone, Debug, Copy)]
#[allow(dead_code)]
enum Command {
/// Used by `Socks5Client::connect`.
Connect = 0x01,
/// Not referenced anywhere in this file (hence `allow(dead_code)`).
Bind = 0x02,
/// Used by `Socks5Client::tor_resolve`; presumably Tor's RESOLVE extension — confirm against the Tor spec.
TorResolve = 0xF0,
/// Used by `Socks5Client::tor_resolve_ptr`; presumably Tor's RESOLVE_PTR extension — confirm against the Tor spec.
TorResolvePtr = 0xF1,
}
/// SOCKS5 client that drives the handshake and requests over a `TSocket`.
pub struct Socks5Client<TSocket> {
// Low-level protocol state: the socket, the configured authentication and the scratch buffer.
protocol: SocksProtocol<TSocket>,
// Set once the authentication handshake has succeeded so it runs only once per client.
is_authenticated: bool,
}
impl<TSocket> Socks5Client<TSocket>
where TSocket: AsyncRead + AsyncWrite + Unpin
{
/// Creates a client around `socket`. No authentication is configured and
/// the handshake has not been performed yet.
pub fn new(socket: TSocket) -> Self {
    let protocol = SocksProtocol::new(socket);
    Self {
        protocol,
        is_authenticated: false,
    }
}
/// Sets the authentication to offer to the proxy.
///
/// # Errors
/// Returns `SocksError::InvalidAuthValues` when the username or password
/// is empty or longer than 255 bytes; in that case nothing is stored.
pub fn with_authentication(&mut self, auth: Authentication) -> Result<&mut Self> {
    // Validate first so invalid credentials never reach the protocol state.
    Self::validate_auth(&auth)?;
    let client = self;
    client.protocol.set_authentication(auth);
    Ok(client)
}
pub async fn connect(mut self, address: &Multiaddr) -> Result<(TSocket, Multiaddr)> {
let address = self.execute_command(Command::Connect, address).await?;
Ok((self.protocol.socket, address))
}
pub async fn tor_resolve(&mut self, address: &Multiaddr) -> Result<Multiaddr> {
let (dns, rest) = multiaddr_split_first(address);
let mut resolved = self.execute_command(Command::TorResolve, &dns.into()).await?;
resolved.pop();
for r in rest {
resolved.push(r);
}
Ok(resolved)
}
/// Sends a `TorResolvePtr` command for `address` and returns the address
/// contained in the proxy's reply.
pub async fn tor_resolve_ptr(&mut self, address: &Multiaddr) -> Result<Multiaddr> {
    let command = Command::TorResolvePtr;
    self.execute_command(command, address).await
}
/// Runs the authentication handshake if it has not succeeded yet, then
/// sends `command` for `address` and returns the address from the reply.
async fn execute_command(&mut self, command: Command, address: &Multiaddr) -> Result<Multiaddr> {
    let needs_handshake = !self.is_authenticated;
    if needs_handshake {
        self.protocol.authenticate().await?;
        // Only mark as authenticated after the handshake succeeded.
        self.is_authenticated = true;
    }
    self.protocol.send_command(command, address).await
}
/// Checks that `auth` can be encoded on the wire: for password
/// authentication both the username and the password must be between 1 and
/// 255 bytes long. `Authentication::None` is always valid.
fn validate_auth(auth: &Authentication) -> Result<()> {
    if let Authentication::Password { username, password } = auth {
        // The username is checked first, so its error wins when both are invalid.
        if username.is_empty() || username.len() > 255 {
            return Err(SocksError::InvalidAuthValues(
                "username length should between 1 to 255".to_string(),
            ));
        }
        if password.is_empty() || password.len() > 255 {
            return Err(SocksError::InvalidAuthValues(
                "password length should between 1 to 255".to_string(),
            ));
        }
    }
    Ok(())
}
}
/// Splits `addr` into its first component and a vector of the remaining ones.
///
/// # Panics
/// Panics when `addr` contains no components.
fn multiaddr_split_first(addr: &Multiaddr) -> (Protocol<'_>, Vec<Protocol<'_>>) {
    let mut components = addr.iter();
    match components.next() {
        Some(first) => (first, components.collect()),
        None => panic!("prepare_multiaddr_for_tor_resolve: received empty `Multiaddr`"),
    }
}
// Size of the scratch buffer. 513 matches the largest message built in this file:
// the password-auth packet of 3 header/length bytes plus a 255-byte username and a 255-byte password.
const SOCKS_BUFFER_LENGTH: usize = 513;
/// Low-level SOCKS5 protocol state: builds outgoing messages in, and reads
/// incoming messages into, a single fixed-size buffer.
struct SocksProtocol<TSocket> {
// Stream to the proxy.
socket: TSocket,
// Authentication offered during the method-selection handshake.
authentication: Authentication,
// Scratch buffer shared by all sends and receives.
buf: Box<[u8; SOCKS_BUFFER_LENGTH]>,
// Start offset into `buf` of the next `read`/`write` window.
ptr: usize,
// End offset (exclusive) into `buf` of the next `read`/`write` window.
len: usize,
}
impl<TSocket> SocksProtocol<TSocket>
where TSocket: AsyncRead + AsyncWrite + Unpin
{
fn new(socket: TSocket) -> Self {
SocksProtocol {
socket,
authentication: Default::default(),
buf: Box::new([0; 513]),
ptr: 0,
len: 0,
}
}
pub async fn authenticate(&mut self) -> Result<()> {
self.prepare_send_auth_method_selection()?;
self.write().await?;
self.prepare_recv_auth_method_selection();
self.read().await?;
if *self.buf.first().ok_or(SocksError::InvalidAmountOfBytesRead)? != 0x05 {
return Err(SocksError::InvalidResponseVersion);
}
match *self.buf.get(1).ok_or(SocksError::InvalidAmountOfBytesRead)? {
0x00 => {
},
0x02 => {
self.password_authentication_protocol().await?;
},
0xff => {
return Err(SocksError::NoAcceptableAuthMethods);
},
m if m != self.authentication.id() => return Err(SocksError::UnknownAuthMethod),
_ => unimplemented!(),
}
Ok(())
}
/// Replaces the authentication offered during the handshake. No validation is done here.
pub fn set_authentication(&mut self, authentication: Authentication) {
self.authentication = authentication;
}
/// Encodes and sends a request with `command` for `address`, then reads the
/// reply and returns the address it contains.
pub async fn send_command(&mut self, command: Command, address: &Multiaddr) -> Result<Multiaddr> {
    self.prepare_send_request(command, address)?;
    self.write().await?;
    let reply_address = self.receive_reply().await?;
    Ok(reply_address)
}
/// Sends the username/password packet and checks the two-byte response:
/// the first byte must be `0x01` and the second (status) byte must be `0x00`.
///
/// # Errors
/// `SocksError::InvalidResponseVersion` on an unexpected first byte and
/// `SocksError::PasswordAuthFailure` carrying the non-zero status byte.
async fn password_authentication_protocol(&mut self) -> Result<()> {
    self.prepare_send_password_auth()?;
    self.write().await?;
    self.prepare_recv_password_auth();
    self.read().await?;

    let version = *self.buf.first().ok_or(SocksError::InvalidAmountOfBytesRead)?;
    if version != 0x01 {
        return Err(SocksError::InvalidResponseVersion);
    }

    match *self.buf.get(1).ok_or(SocksError::InvalidAmountOfBytesRead)? {
        0x00 => Ok(()),
        status => Err(SocksError::PasswordAuthFailure(status)),
    }
}
async fn receive_reply(&mut self) -> Result<Multiaddr> {
self.prepare_recv_reply();
self.ptr += self.read().await?;
if *self.buf.first().ok_or(SocksError::InvalidAmountOfBytesRead)? != 0x05 {
return Err(SocksError::InvalidResponseVersion);
}
if *self.buf.get(2).ok_or(SocksError::InvalidAmountOfBytesRead)? != 0x00 {
return Err(SocksError::InvalidReservedByte);
}
let auth_byte = *self.buf.get(1).ok_or(SocksError::InvalidAmountOfBytesRead)?;
if auth_byte != 0x00 {
return match auth_byte {
0x00 => unreachable!(),
0x01 => Err(SocksError::GeneralSocksServerFailure),
0x02 => Err(SocksError::ConnectionNotAllowedByRuleset),
0x03 => Err(SocksError::NetworkUnreachable),
0x04 => Err(SocksError::HostUnreachable),
0x05 => Err(SocksError::ConnectionRefused),
0x06 => Err(SocksError::TtlExpired),
0x07 => Err(SocksError::CommandNotSupported),
0x08 => Err(SocksError::AddressTypeNotSupported),
_ => Err(SocksError::UnknownAuthMethod),
};
}
match *self.buf.get(3).ok_or(SocksError::InvalidAmountOfBytesRead)? {
0x01 => {
self.len = 10;
},
0x04 => {
self.len = 22;
},
0x03 => {
self.len = 5;
self.ptr += self.read().await?;
self.len += *self.buf.get(4).ok_or(SocksError::InvalidAmountOfBytesRead)? as usize + 2;
},
_ => return Err(SocksError::UnknownAddressType),
}
self.ptr += self.read().await?;
let address = match *self.buf.get(3).ok_or(SocksError::InvalidAmountOfBytesRead)? {
0x01 => {
let mut ip = [0; 4];
ip[..].copy_from_slice(self.buf.get(4..8).ok_or(SocksError::InvalidAmountOfBytesRead)?);
let ip = Ipv4Addr::from(ip);
let port = u16::from_be_bytes([
*self.buf.get(8).ok_or(SocksError::InvalidAmountOfBytesRead)?,
*self.buf.get(9).ok_or(SocksError::InvalidAmountOfBytesRead)?,
]);
let mut addr: Multiaddr = Protocol::Ip4(ip).into();
addr.push(Protocol::Tcp(port));
addr
},
0x04 => {
let mut ip = [0; 16];
ip[..].copy_from_slice(self.buf.get(4..20).ok_or(SocksError::InvalidAmountOfBytesRead)?);
let ip = Ipv6Addr::from(ip);
let port = u16::from_be_bytes([
*self.buf.get(20).ok_or(SocksError::InvalidAmountOfBytesRead)?,
*self.buf.get(21).ok_or(SocksError::InvalidAmountOfBytesRead)?,
]);
let mut addr: Multiaddr = Protocol::Ip6(ip).into();
addr.push(Protocol::Tcp(port));
addr
},
0x03 => {
let domain_bytes = (self
.buf
.get(5..(self.len - 2))
.ok_or(SocksError::InvalidAmountOfBytesRead)?)
.to_vec();
let domain = String::from_utf8(domain_bytes)
.map_err(|_| SocksError::InvalidTargetAddress("domain bytes are not a valid UTF-8 string"))?;
let mut addr: Multiaddr = Protocol::Dns4(Cow::Owned(domain)).into();
let port = u16::from_be_bytes([
*self.buf.get(self.len - 2).ok_or(SocksError::InvalidAmountOfBytesRead)?,
*self.buf.get(self.len - 1).ok_or(SocksError::InvalidAmountOfBytesRead)?,
]);
addr.push(Protocol::Tcp(port));
addr
},
_ => unreachable!(),
};
Ok(address)
}
/// Fills the buffer with the method-selection message: version `0x05`, the
/// number of offered methods, then the method ids. Without credentials only
/// method `0x00` is offered; with credentials both `0x00` and `0x02` are.
fn prepare_send_auth_method_selection(&mut self) -> Result<()> {
    let greeting: &[u8] = match self.authentication {
        Authentication::None => &[0x05, 1, 0x00],
        Authentication::Password { .. } => &[0x05, 2, 0x00, 0x02],
    };
    self.ptr = 0;
    self.buf
        .get_mut(..greeting.len())
        .ok_or(SocksError::InvalidAmountOfBytesRead)?
        .copy_from_slice(greeting);
    self.len = greeting.len();
    Ok(())
}
/// Sets the read window to the first two bytes of the buffer, the size of
/// the method-selection response.
fn prepare_recv_auth_method_selection(&mut self) {
    self.len = 2;
    self.ptr = 0;
}
/// Fills the buffer with the username/password packet:
/// `0x01`, username length, username bytes, password length, password bytes.
///
/// # Errors
/// Returns `SocksError::InvalidAuthValues` when no credentials are
/// configured or when the username or password is longer than 255 bytes,
/// instead of panicking.
fn prepare_send_password_auth(&mut self) -> Result<()> {
    let (username, password) = match &self.authentication {
        Authentication::Password { username, password } => (username.as_bytes(), password.as_bytes()),
        Authentication::None => {
            return Err(SocksError::InvalidAuthValues(
                "password authentication requested but no credentials are configured".to_string(),
            ));
        },
    };
    // Each length is sent as a single byte, so it must fit into a `u8`.
    let username_len = u8::try_from(username.len())
        .map_err(|_| SocksError::InvalidAuthValues("username must not exceed 255 bytes".to_string()))?;
    let password_len = u8::try_from(password.len())
        .map_err(|_| SocksError::InvalidAuthValues("password must not exceed 255 bytes".to_string()))?;

    let password_len_at = 2 + username.len();
    let end = password_len_at + 1 + password.len();

    self.ptr = 0;
    *self.buf.get_mut(0).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0x01;
    *self.buf.get_mut(1).ok_or(SocksError::InvalidAmountOfBytesRead)? = username_len;
    self.buf
        .get_mut(2..password_len_at)
        .ok_or(SocksError::InvalidAmountOfBytesRead)?
        .copy_from_slice(username);
    *self
        .buf
        .get_mut(password_len_at)
        .ok_or(SocksError::InvalidAmountOfBytesRead)? = password_len;
    self.buf
        .get_mut((password_len_at + 1)..end)
        .ok_or(SocksError::InvalidAmountOfBytesRead)?
        .copy_from_slice(password);
    self.len = end;
    Ok(())
}
/// Sets the read window to the first two bytes of the buffer, the size of
/// the password-authentication response.
fn prepare_recv_password_auth(&mut self) {
    self.len = 2;
    self.ptr = 0;
}
fn prepare_send_request(&mut self, command: Command, address: &Multiaddr) -> Result<()> {
self.ptr = 0;
self.buf
.get_mut(..3)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&[0x05, command as u8, 0x00]);
let mut addr_iter = address.iter();
let part1 = addr_iter
.next()
.ok_or(SocksError::InvalidTargetAddress("Address contained no components"))?;
let part2 = addr_iter.next();
match (part1, part2) {
(Protocol::Ip4(ip), Some(Protocol::Tcp(port))) => {
*self.buf.get_mut(3).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0x01;
self.buf
.get_mut(4..8)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&ip.octets());
self.buf
.get_mut(8..10)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&port.to_be_bytes());
self.len = 10;
},
(Protocol::Ip6(ip), Some(Protocol::Tcp(port))) => {
*self.buf.get_mut(3).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0x04;
self.buf
.get_mut(4..20)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&ip.octets());
self.buf
.get_mut(20..22)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&port.to_be_bytes());
self.len = 22;
},
(Protocol::Dns4(domain), Some(Protocol::Tcp(port))) => {
*self.buf.get_mut(3).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0x03;
let domain = domain.as_bytes();
let len = domain.len();
*self.buf.get_mut(4).ok_or(SocksError::InvalidAmountOfBytesRead)? = u8::try_from(len).unwrap();
self.buf
.get_mut(5..5 + len)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(domain);
self.buf
.get_mut((5 + len)..(7 + len))
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&port.to_be_bytes());
self.len = 7 + len;
},
(Protocol::Dns4(domain), None) | (Protocol::Dns(domain), None) => {
*self.buf.get_mut(3).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0x03;
let domain = domain.as_bytes();
let len = domain.len();
*self.buf.get_mut(4).ok_or(SocksError::InvalidAmountOfBytesRead)? = u8::try_from(len).unwrap();
self.buf
.get_mut(5..5 + len)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(domain);
*self.buf.get_mut(5 + len).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0;
*self.buf.get_mut(6 + len).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0;
self.len = 7 + len;
},
(p @ Protocol::Onion(_, _), None) => {
*self.buf.get_mut(3).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0x03;
let (domain, port) = Self::extract_onion_address(&p)?;
let len = domain.len();
*self.buf.get_mut(4).ok_or(SocksError::InvalidAmountOfBytesRead)? = u8::try_from(len).unwrap();
self.buf
.get_mut(5..5 + len)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(domain.as_bytes());
self.buf
.get_mut((5 + len)..(7 + len))
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&port.to_be_bytes());
self.len = 7 + len;
},
(Protocol::Onion3(addr), None) => {
*self.buf.get_mut(3).ok_or(SocksError::InvalidAmountOfBytesRead)? = 0x03;
let port = addr.port();
let domain = format!("{}.onion", BASE32.encode(addr.hash()));
let len = domain.len();
*self.buf.get_mut(4).ok_or(SocksError::InvalidAmountOfBytesRead)? = u8::try_from(len).unwrap();
self.buf
.get_mut(5..5 + len)
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(domain.as_bytes());
self.buf
.get_mut((5 + len)..(7 + len))
.ok_or(SocksError::InvalidAmountOfBytesRead)?
.copy_from_slice(&port.to_be_bytes());
self.len = 7 + len;
},
_ => return Err(SocksError::AddressTypeNotSupported),
}
Ok(())
}
/// Derives the `.onion` domain name and port from the textual form of `p`.
///
/// The third `/`-separated segment of `p.to_string()` is expected to look
/// like `<hash>:<port>`; the result is `("<hash>.onion", port)`.
///
/// # Errors
/// Returns `SocksError::InvalidTargetAddress` when the text does not have
/// that shape or the port is not a valid `u16`, instead of panicking.
fn extract_onion_address(p: &Protocol<'_>) -> Result<(String, u16)> {
    let rendered = p.to_string();
    let (hash, port) = rendered
        .split('/')
        .nth(2)
        .and_then(|host_port| host_port.split_once(':'))
        .ok_or(SocksError::InvalidTargetAddress("Invalid onion address"))?;
    let port = port
        .parse::<u16>()
        .map_err(|_| SocksError::InvalidTargetAddress("Invalid onion address port"))?;
    Ok((format!("{hash}.onion"), port))
}
/// Sets the read window to the first four bytes of the buffer, the fixed
/// header that `receive_reply` reads first.
fn prepare_recv_reply(&mut self) {
    self.len = 4;
    self.ptr = 0;
}
/// Writes the buffer window `ptr..len` to the socket in full.
///
/// # Errors
/// `SocksError::InvalidAmountOfBytesRead` if the window is not a valid
/// range of the buffer, or the converted I/O error from the socket.
async fn write(&mut self) -> Result<()> {
    let pending = self
        .buf
        .get(self.ptr..self.len)
        .ok_or(SocksError::InvalidAmountOfBytesRead)?;
    self.socket.write_all(pending).await?;
    Ok(())
}
/// Reads exactly enough bytes from the socket to fill the buffer window
/// `ptr..len` and returns the number of bytes read (`len - ptr`).
///
/// # Errors
/// `SocksError::InvalidAmountOfBytesRead` if the window is not a valid
/// range of the buffer, or the converted I/O error from the socket.
async fn read(&mut self) -> Result<usize> {
    let (start, end) = (self.ptr, self.len);
    let window = self
        .buf
        .get_mut(start..end)
        .ok_or(SocksError::InvalidAmountOfBytesRead)?;
    self.socket.read_exact(window).await?;
    Ok(end - start)
}
}