use std::{
fmt,
io::{self, Error, ErrorKind},
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
};
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::relay::socks5;
#[rustfmt::skip]
mod consts {
    /// `VN` byte of a SOCKS4 request (4).
    pub const SOCKS4_VERSION: u8                                  = 0x04;

    /// `CD` byte of a CONNECT request (1).
    pub const SOCKS4_COMMAND_CONNECT: u8                          = 0x01;
    /// `CD` byte of a BIND request (2).
    pub const SOCKS4_COMMAND_BIND: u8                             = 0x02;

    /// Reply code 90: request granted.
    pub const SOCKS4_RESULT_REQUEST_GRANTED: u8                   = 0x5a;
    /// Reply code 91: request rejected or failed.
    pub const SOCKS4_RESULT_REQUEST_REJECTED_OR_FAILED: u8        = 0x5b;
    /// Reply code 92: rejected, server cannot reach identd on the client.
    pub const SOCKS4_RESULT_REQUEST_REJECTED_CANNOT_CONNECT: u8   = 0x5c;
    /// Reply code 93: rejected, client and identd report different user-ids.
    /// (Name keeps its historical spelling because other code refers to it.)
    pub const SOCKS4_RESULT_REQUEST_REJECTED_DFFERENT_USER_ID: u8 = 0x5d;
}
/// Command (`CD` byte) of a SOCKS4 client request.
#[derive(Clone, Debug, Copy)]
pub enum Command {
    /// CONNECT request (`CD = 1`).
    Connect,
    /// BIND request (`CD = 2`).
    Bind,
}

impl Command {
    /// Returns the on-the-wire byte for this command.
    #[inline]
    fn as_u8(self) -> u8 {
        match self {
            Self::Connect => consts::SOCKS4_COMMAND_CONNECT,
            Self::Bind => consts::SOCKS4_COMMAND_BIND,
        }
    }

    /// Maps a wire `CD` byte back to a [`Command`]; unknown codes yield `None`.
    #[inline]
    fn from_u8(code: u8) -> Option<Command> {
        // Derive the reverse mapping from `as_u8` so the two can never disagree.
        [Command::Connect, Command::Bind]
            .iter()
            .copied()
            .find(|cmd| cmd.as_u8() == code)
    }
}
/// Reply code (`CD` byte) of a SOCKS4 server response.
#[derive(Clone, Debug, Copy)]
pub enum ResultCode {
    /// `90`: request granted.
    RequestGranted,
    /// `91`: request rejected or failed.
    RequestRejectedOrFailed,
    /// `92`: rejected because the SOCKS server cannot connect to identd on the client.
    RequestRejectedCannotConnect,
    /// `93`: rejected because the client program and identd report different user-ids.
    RequestRejectedDifferentUserId,
    /// Any other reply code, preserved verbatim.
    Other(u8),
}

impl ResultCode {
    /// Returns the on-the-wire byte for this reply code.
    #[inline]
    fn as_u8(self) -> u8 {
        match self {
            ResultCode::RequestGranted => consts::SOCKS4_RESULT_REQUEST_GRANTED,
            ResultCode::RequestRejectedOrFailed => consts::SOCKS4_RESULT_REQUEST_REJECTED_OR_FAILED,
            ResultCode::RequestRejectedCannotConnect => consts::SOCKS4_RESULT_REQUEST_REJECTED_CANNOT_CONNECT,
            ResultCode::RequestRejectedDifferentUserId => consts::SOCKS4_RESULT_REQUEST_REJECTED_DFFERENT_USER_ID,
            ResultCode::Other(c) => c,
        }
    }

    /// Maps a wire reply byte to a [`ResultCode`]; unrecognised codes become [`ResultCode::Other`].
    #[inline]
    fn from_u8(code: u8) -> ResultCode {
        match code {
            consts::SOCKS4_RESULT_REQUEST_GRANTED => ResultCode::RequestGranted,
            consts::SOCKS4_RESULT_REQUEST_REJECTED_OR_FAILED => ResultCode::RequestRejectedOrFailed,
            consts::SOCKS4_RESULT_REQUEST_REJECTED_CANNOT_CONNECT => ResultCode::RequestRejectedCannotConnect,
            consts::SOCKS4_RESULT_REQUEST_REJECTED_DFFERENT_USER_ID => ResultCode::RequestRejectedDifferentUserId,
            code => ResultCode::Other(code),
        }
    }
}

impl fmt::Display for ResultCode {
    /// Human-readable description of the reply code.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            ResultCode::RequestGranted => f.write_str("request granted"),
            ResultCode::RequestRejectedOrFailed => f.write_str("request rejected or failed"),
            ResultCode::RequestRejectedCannotConnect => {
                f.write_str("request rejected because SOCKS server cannot connect to identd on the client")
            }
            ResultCode::RequestRejectedDifferentUserId => {
                f.write_str("request rejected because the client program and identd report different user-ids")
            }
            ResultCode::Other(code) => write!(f, "other result code {}", code),
        }
    }
}
/// Destination address carried by a SOCKS4 / SOCKS4a request.
#[derive(Clone, PartialEq, Eq, Hash)]
pub enum Address {
    /// Plain SOCKS4 destination: an IPv4 address and port.
    SocketAddress(SocketAddrV4),
    /// SOCKS4a destination: a host name and port.
    DomainNameAddress(String, u16),
}

impl fmt::Debug for Address {
    /// Debug output is identical to the `Display` form (`host:port`).
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl fmt::Display for Address {
    /// Formats the address as `ip:port` or `host:port`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Address::SocketAddress(sock) => write!(f, "{}", sock),
            Address::DomainNameAddress(host, port) => write!(f, "{}:{}", host, port),
        }
    }
}

impl From<SocketAddrV4> for Address {
    fn from(sock: SocketAddrV4) -> Address {
        Self::SocketAddress(sock)
    }
}

impl From<(String, u16)> for Address {
    fn from(pair: (String, u16)) -> Address {
        let (host, port) = pair;
        Self::DomainNameAddress(host, port)
    }
}

impl From<&Address> for Address {
    fn from(other: &Address) -> Address {
        Clone::clone(other)
    }
}
impl From<Address> for socks5::Address {
    /// Converts a SOCKS4 destination into the equivalent SOCKS5 address
    /// (IPv4 socket address or host name + port).
    fn from(addr: Address) -> socks5::Address {
        match addr {
            Address::SocketAddress(v4) => Self::SocketAddress(SocketAddr::from(v4)),
            Address::DomainNameAddress(host, port) => Self::DomainNameAddress(host, port),
        }
    }
}
/// SOCKS4 / SOCKS4a client request.
///
/// Wire layout: `VN(1) CD(1) DSTPORT(2, big-endian) DSTIP(4) USERID NUL`, followed by
/// `HOST NUL` when `DSTIP` is the SOCKS4a marker `0.0.0.x` (`x != 0`).
#[derive(Debug, Clone)]
pub struct HandshakeRequest {
    /// Requested command.
    pub cd: Command,
    /// Destination: an IPv4 socket address (SOCKS4) or a host name (SOCKS4a).
    pub dst: Address,
    /// `USERID` field, without its NUL terminator.
    pub user_id: Vec<u8>,
}

impl HandshakeRequest {
    /// Maximum `USERID` length (excluding NUL) accepted by [`read_from`](Self::read_from).
    ///
    /// The protocol sets no limit. Without one, a peer that never sends the terminator could
    /// make us buffer unbounded data. 1 KiB is far above any realistic user id.
    const MAX_USER_ID_LEN: usize = 1024;

    /// Maximum SOCKS4a host name length (excluding NUL) accepted by [`read_from`](Self::read_from).
    ///
    /// 255 is the longest domain name a SOCKS5 address (one-byte length prefix) can carry, and
    /// this address may be converted into a `socks5::Address`.
    const MAX_HOST_LEN: usize = 255;

    /// Reads and parses a client request from `r`.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if `VN` is not 4, `CD` is not a known command, the SOCKS4a host name is
    ///   not valid UTF-8, or `USERID` / host name exceed their length limits.
    /// - `UnexpectedEof` if the stream ends before the request is complete.
    /// - Any other I/O error returned by `r`.
    pub async fn read_from<R>(r: &mut R) -> io::Result<HandshakeRequest>
    where
        R: AsyncBufRead + Unpin,
    {
        // Fixed-size header: VN, CD, DSTPORT (2 bytes), DSTIP (4 bytes).
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf).await?;

        let vn = buf[0];
        if vn != consts::SOCKS4_VERSION {
            let err = Error::new(ErrorKind::InvalidData, format!("unsupported socks version {:#x}", vn));
            return Err(err);
        }

        let cd = buf[1];
        let command = match Command::from_u8(cd) {
            Some(c) => c,
            None => {
                let err = Error::new(ErrorKind::InvalidData, format!("unsupported command {:#x}", cd));
                return Err(err);
            }
        };

        let port = BigEndian::read_u16(&buf[2..4]);
        let user_id = Self::read_nul_terminated(r, Self::MAX_USER_ID_LEN).await?;

        // SOCKS4a: DSTIP `0.0.0.x` with x != 0 means a NUL-terminated host name follows USERID.
        let dst = if buf[4..7] == [0, 0, 0] && buf[7] != 0 {
            let host = Self::read_nul_terminated(r, Self::MAX_HOST_LEN).await?;
            let host = String::from_utf8(host)
                .map_err(|_| Error::new(ErrorKind::InvalidData, "invalid host encoding"))?;
            Address::DomainNameAddress(host, port)
        } else {
            let ip = Ipv4Addr::new(buf[4], buf[5], buf[6], buf[7]);
            Address::SocketAddress(SocketAddrV4::new(ip, port))
        };

        Ok(HandshakeRequest {
            cd: command,
            dst,
            user_id,
        })
    }

    /// Reads one NUL-terminated field of at most `max_len` bytes (terminator excluded) and
    /// returns it without the terminator.
    ///
    /// Returns `InvalidData` if `max_len` bytes arrive without a NUL, and `UnexpectedEof` if the
    /// stream ends first.
    async fn read_nul_terminated<R>(r: &mut R, max_len: usize) -> io::Result<Vec<u8>>
    where
        R: AsyncBufRead + Unpin,
    {
        let mut field = Vec::new();
        // `take` caps how much `read_until` may consume; the +1 leaves room for the terminator.
        let n = (&mut *r).take(max_len as u64 + 1).read_until(b'\0', &mut field).await?;

        if field.last() == Some(&b'\0') {
            field.pop();
            return Ok(field);
        }
        if n > max_len {
            // The limit was reached without seeing a terminator.
            return Err(Error::new(ErrorKind::InvalidData, "NULL-terminated field too long"));
        }
        Err(ErrorKind::UnexpectedEof.into())
    }

    /// Serializes this request into `buf` in SOCKS4 (IPv4) or SOCKS4a (host name) format.
    ///
    /// Writes exactly [`serialized_len`](Self::serialized_len) bytes.
    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
        debug_assert!(
            !self.user_id.contains(&b'\0'),
            "USERID shouldn't contain any NULL characters"
        );

        buf.put_u8(consts::SOCKS4_VERSION);
        buf.put_u8(self.cd.as_u8());
        match self.dst {
            Address::SocketAddress(ref saddr) => {
                let port = saddr.port();
                buf.put_u16(port);
                buf.put_slice(&saddr.ip().octets());
                buf.put_slice(&self.user_id);
                buf.put_u8(b'\0');
            }
            Address::DomainNameAddress(ref dname, port) => {
                debug_assert!(
                    !dname.as_bytes().contains(&b'\0'),
                    "domain name shouldn't contain any NULL characters"
                );
                buf.put_u16(port);
                // SOCKS4a marker: DSTIP = 0.0.0.x with non-zero x tells the server to read HOST.
                const PLACEHOLDER: [u8; 4] = [0x00, 0x00, 0x00, 0xff];
                buf.put_slice(&PLACEHOLDER);
                buf.put_slice(&self.user_id);
                buf.put_u8(b'\0');
                buf.put_slice(dname.as_bytes());
                buf.put_u8(b'\0');
            }
        }
    }

    /// Number of bytes [`write_to_buf`](Self::write_to_buf) produces for this request.
    #[inline]
    pub fn serialized_len(&self) -> usize {
        // VN + CD + DSTPORT + DSTIP + USERID + NUL
        let mut s = 1 + 1 + 2 + 4 + self.user_id.len() + 1;
        if let Address::DomainNameAddress(ref dname, _) = self.dst {
            // HOST + NUL
            s += dname.len() + 1;
        }
        s
    }
}
/// SOCKS4 server reply.
///
/// Wire layout: `VN(1) CD(1) DSTPORT(2) DSTIP(4)`; only the reply code `CD` is retained.
#[derive(Debug, Clone)]
pub struct HandshakeResponse {
    /// Reply code.
    pub cd: ResultCode,
}

impl HandshakeResponse {
    /// Creates a reply carrying `code`.
    pub fn new(code: ResultCode) -> HandshakeResponse {
        HandshakeResponse { cd: code }
    }

    /// Reads an 8-byte server reply from `r`.
    ///
    /// The SOCKS4 protocol specifies reply `VN` as 0, which is also what
    /// [`write_to_buf`](Self::write_to_buf) emits. Previously only 4 was accepted, which rejected
    /// spec-compliant replies. Both 0 and 4 are now accepted, so peers that worked before still do.
    /// `DSTPORT`/`DSTIP` are read but discarded.
    ///
    /// # Errors
    ///
    /// `InvalidData` for any other `VN`; `UnexpectedEof` or other I/O errors from `r`.
    pub async fn read_from<R>(r: &mut R) -> io::Result<HandshakeResponse>
    where
        R: AsyncRead + Unpin,
    {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf).await?;

        let vn = buf[0];
        if vn != 0x00 && vn != consts::SOCKS4_VERSION {
            let err = Error::new(ErrorKind::InvalidData, format!("unsupported socks version {:#x}", vn));
            return Err(err);
        }

        let result_code = ResultCode::from_u8(buf[1]);
        Ok(HandshakeResponse { cd: result_code })
    }

    /// Serializes this reply and writes it to `w` in a single `write_all`.
    pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
    where
        W: AsyncWrite + Unpin,
    {
        let mut buf = BytesMut::with_capacity(self.serialized_len());
        self.write_to_buf(&mut buf);
        w.write_all(&buf).await
    }

    /// Serializes this reply into `buf`: `VN = 0`, the reply code, and zeroed `DSTPORT`/`DSTIP`.
    pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
        let HandshakeResponse { ref cd } = *self;
        buf.put_slice(&[
            0x00, // VN: reply version
            cd.as_u8(),
            0x00, // DSTPORT
            0x00,
            0x00, // DSTIP
            0x00,
            0x00,
            0x00,
        ]);
    }

    /// Number of bytes [`write_to_buf`](Self::write_to_buf) produces (always 8).
    #[inline]
    pub fn serialized_len(&self) -> usize {
        1 + 1 + 2 + 4
    }
}