use crate::{Error, Result};
use caret::caret_int;
use std::fmt;
use std::net::IpAddr;
#[cfg(feature = "arbitrary")]
use std::net::Ipv6Addr;
use tor_error::bad_api_usage;
#[cfg(feature = "arbitrary")]
use arbitrary::{Arbitrary, Result as ArbitraryResult, Unstructured};
/// A SOCKS protocol version.
///
/// Parsed from a version byte with `TryFrom<u8>`, which accepts only `4`
/// and `5`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[non_exhaustive]
pub enum SocksVersion {
/// SOCKS version 4 (version byte `4`).
V4,
/// SOCKS version 5 (version byte `5`).
V5,
}
impl TryFrom<u8> for SocksVersion {
type Error = Error;
fn try_from(v: u8) -> Result<SocksVersion> {
match v {
4 => Ok(SocksVersion::V4),
5 => Ok(SocksVersion::V5),
_ => Err(Error::BadProtocol(v)),
}
}
}
/// A SOCKS request from a client.
///
/// The fields are private. `SocksRequest::new` checks that the command is
/// one this crate accepts, that a nonzero port is given when the command
/// needs one, and that the authentication is compatible with the version.
#[derive(Clone, Debug)]
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct SocksRequest {
/// Protocol version the request uses.
version: SocksVersion,
/// Command the client asked for.
cmd: SocksCmd,
/// Target address: a hostname or an IP address.
addr: SocksAddr,
/// Target port (`new` rejects 0 for commands that require a port).
port: u16,
/// Authentication supplied with the request.
auth: SocksAuth,
}
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for SocksRequest {
    /// Generate each component in turn (version, command, address, port,
    /// authentication), then run them through [`SocksRequest::new`].
    ///
    /// Combinations that `new` rejects are reported as
    /// `arbitrary::Error::IncorrectFormat`.
    fn arbitrary(u: &mut Unstructured<'a>) -> ArbitraryResult<Self> {
        // The order of these draws fixes how input bytes are consumed.
        let version: SocksVersion = u.arbitrary()?;
        let cmd: SocksCmd = u.arbitrary()?;
        let addr: SocksAddr = u.arbitrary()?;
        let port: u16 = u.arbitrary()?;
        let auth: SocksAuth = u.arbitrary()?;
        Self::new(version, cmd, addr, port, auth).map_err(|_| arbitrary::Error::IncorrectFormat)
    }
}
/// The address part of a SOCKS request or reply.
#[derive(Clone, Debug, PartialEq, Eq)]
// NOTE(review): presumably left exhaustive because an address is always
// either a hostname or an IP address; confirm before adding variants.
#[allow(clippy::exhaustive_enums)]
pub enum SocksAddr {
/// A hostname.
Hostname(SocksHostname),
/// An IPv4 or IPv6 address.
Ip(IpAddr),
}
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for SocksAddr {
    /// Read one selector byte and reduce it modulo 3. A result of 0 yields
    /// a hostname, 1 an IPv4 address, and 2 an IPv6 address. The chosen
    /// variant's payload is then generated.
    fn arbitrary(u: &mut Unstructured<'a>) -> ArbitraryResult<Self> {
        let selector = u8::arbitrary(u)? % 3;
        let addr = if selector == 0 {
            SocksAddr::Hostname(SocksHostname::arbitrary(u)?)
        } else if selector == 1 {
            SocksAddr::Ip(IpAddr::V4(std::net::Ipv4Addr::arbitrary(u)?))
        } else {
            SocksAddr::Ip(IpAddr::V6(Ipv6Addr::arbitrary(u)?))
        };
        Ok(addr)
    }

    /// Return the size hint: one selector byte at minimum, up to 256 bytes.
    fn size_hint(_depth: usize) -> (usize, Option<usize>) {
        (1, Some(256))
    }
}
/// A hostname for use in a SOCKS request or reply.
///
/// Built with `TryFrom<String>`, which rejects strings longer than 255
/// bytes or containing a NUL byte. Read it back through `AsRef<str>`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SocksHostname(String);
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for SocksHostname {
    /// Generate an arbitrary `String` and validate it as a hostname.
    ///
    /// Strings that fail validation (too long, or containing NUL) are
    /// reported as `arbitrary::Error::IncorrectFormat`.
    fn arbitrary(u: &mut Unstructured<'a>) -> ArbitraryResult<Self> {
        let candidate: String = u.arbitrary()?;
        SocksHostname::try_from(candidate).map_err(|_| arbitrary::Error::IncorrectFormat)
    }

    /// Return the size hint: anywhere from 0 to 255 bytes.
    fn size_hint(_depth: usize) -> (usize, Option<usize>) {
        (0, Some(255))
    }
}
/// Authentication information supplied with a SOCKS request.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[non_exhaustive]
pub enum SocksAuth {
/// No authentication.
NoAuth,
/// SOCKS4 user-ID bytes. Only accepted with SOCKS4, and must not contain
/// a NUL byte.
Socks4(Vec<u8>),
/// Username and password bytes, in that order. Only accepted with SOCKS5,
/// and each must be at most 255 bytes long.
Username(Vec<u8>, Vec<u8>),
}
caret_int! {
/// A command byte from a SOCKS request, telling the proxy what to do.
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct SocksCmd(u8) {
/// Connect to the requested address and port.
CONNECT = 1,
/// BIND. `SocksRequest::new` rejects it as not implemented.
BIND = 2,
/// UDP ASSOCIATE. `SocksRequest::new` rejects it as not implemented.
UDP_ASSOCIATE = 3,
/// Resolve a hostname (does not require a port).
RESOLVE = 0xF0,
/// Reverse-resolve an IP address (does not require a port).
RESOLVE_PTR = 0xF1,
}
}
caret_int! {
/// A reply status code in a SOCKS reply.
///
/// SOCKS4 replies can only express success or failure; see the SOCKS4
/// conversion helpers on this type.
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct SocksStatus(u8) {
/// The request succeeded.
SUCCEEDED = 0x00,
/// Unspecified failure.
GENERAL_FAILURE = 0x01,
/// The request was not allowed.
NOT_ALLOWED = 0x02,
/// The network was unreachable.
NETWORK_UNREACHABLE = 0x03,
/// The host was unreachable.
HOST_UNREACHABLE = 0x04,
/// The connection was refused.
CONNECTION_REFUSED = 0x05,
/// TTL expired.
TTL_EXPIRED = 0x06,
/// The command is not supported.
COMMAND_NOT_SUPPORTED = 0x07,
/// The address type is not supported.
ADDRTYPE_NOT_SUPPORTED = 0x08,
// NOTE(review): the 0xF0.. codes below appear to be onion-service ("HS")
// extension codes; meanings are inferred from their names — confirm.
/// Onion-service descriptor not found.
HS_DESC_NOT_FOUND = 0xF0,
/// Onion-service descriptor invalid.
HS_DESC_INVALID = 0xF1,
/// Onion-service introduction failed.
HS_INTRO_FAILED = 0xF2,
/// Onion-service rendezvous failed.
HS_REND_FAILED = 0xF3,
/// Onion-service client authorization missing.
HS_MISSING_CLIENT_AUTH = 0xF4,
/// Onion-service client authorization wrong.
HS_WRONG_CLIENT_AUTH = 0xF5,
/// Onion-service address invalid.
HS_BAD_ADDRESS = 0xF6,
/// Onion-service introduction timed out.
HS_INTRO_TIMEOUT = 0xF7
}
}
impl SocksCmd {
    /// Return true if `SocksRequest::new` accepts requests using this
    /// command.
    ///
    /// Only CONNECT, RESOLVE, and RESOLVE_PTR are accepted.
    fn recognized(self) -> bool {
        [SocksCmd::CONNECT, SocksCmd::RESOLVE, SocksCmd::RESOLVE_PTR].contains(&self)
    }

    /// Return true if a request with this command must carry a nonzero
    /// port.
    fn requires_port(self) -> bool {
        [SocksCmd::CONNECT, SocksCmd::BIND, SocksCmd::UDP_ASSOCIATE].contains(&self)
    }
}
impl SocksStatus {
    /// Convert to a SOCKS4 reply code.
    ///
    /// SOCKS4 only distinguishes success from failure, so `SUCCEEDED`
    /// becomes `0x5A` and every other status becomes `0x5B`.
    #[cfg(feature = "proxy-handshake")]
    pub(crate) fn into_socks4_status(self) -> u8 {
        if self == SocksStatus::SUCCEEDED {
            0x5A
        } else {
            0x5B
        }
    }

    /// Convert a SOCKS4 reply code into a status.
    ///
    /// `0x5A` is success, and `0x5C`/`0x5D` are reported as `NOT_ALLOWED`.
    /// `0x5B` and every unrecognized code become `GENERAL_FAILURE`.
    #[cfg(feature = "client-handshake")]
    pub(crate) fn from_socks4_status(status: u8) -> Self {
        match status {
            0x5A => SocksStatus::SUCCEEDED,
            0x5C | 0x5D => SocksStatus::NOT_ALLOWED,
            // 0x5B and anything unknown both collapse to a generic failure.
            _ => SocksStatus::GENERAL_FAILURE,
        }
    }
}
impl TryFrom<String> for SocksHostname {
type Error = Error;
fn try_from(s: String) -> Result<SocksHostname> {
if s.len() > 255 {
Err(bad_api_usage!("hostname too long").into())
} else if contains_zeros(s.as_bytes()) {
Err(Error::Syntax)
} else {
Ok(SocksHostname(s))
}
}
}
impl AsRef<str> for SocksHostname {
    /// Borrow the hostname as a string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl SocksAuth {
fn validate(&self, version: SocksVersion) -> Result<()> {
match self {
SocksAuth::NoAuth => {}
SocksAuth::Socks4(data) => {
if version != SocksVersion::V4 || contains_zeros(data) {
return Err(Error::Syntax);
}
}
SocksAuth::Username(user, pass) => {
if version != SocksVersion::V5
|| user.len() > u8::MAX as usize
|| pass.len() > u8::MAX as usize
{
return Err(Error::Syntax);
}
}
}
Ok(())
}
}
fn contains_zeros(b: &[u8]) -> bool {
use subtle::{Choice, ConstantTimeEq};
let c: Choice = b
.iter()
.fold(Choice::from(0), |seen_any, byte| seen_any | byte.ct_eq(&0));
c.unwrap_u8() != 0
}
impl SocksRequest {
pub fn new(
version: SocksVersion,
cmd: SocksCmd,
addr: SocksAddr,
port: u16,
auth: SocksAuth,
) -> Result<Self> {
if !cmd.recognized() {
return Err(Error::NotImplemented(
format!("SOCKS command {}", cmd).into(),
));
}
if port == 0 && cmd.requires_port() {
return Err(Error::Syntax);
}
auth.validate(version)?;
Ok(SocksRequest {
version,
cmd,
addr,
port,
auth,
})
}
pub fn version(&self) -> SocksVersion {
self.version
}
pub fn command(&self) -> SocksCmd {
self.cmd
}
pub fn auth(&self) -> &SocksAuth {
&self.auth
}
pub fn port(&self) -> u16 {
self.port
}
pub fn addr(&self) -> &SocksAddr {
&self.addr
}
}
impl fmt::Display for SocksAddr {
    /// Format the address as a bare IP address or hostname, with no port.
    ///
    /// Formatting is delegated to the inner value's `Display` impl using
    /// the caller's `Formatter`. This means width, fill, alignment, and
    /// precision flags (for example `{:>20}`) are honoured.
    /// `write!(f, "{}", inner)` would start a fresh format with no flags
    /// and silently drop them. Plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SocksAddr::Ip(a) => fmt::Display::fmt(a, f),
            SocksAddr::Hostname(h) => fmt::Display::fmt(h.0.as_str(), f),
        }
    }
}
/// A SOCKS reply: a status code together with an address and a port.
#[derive(Debug, Clone)]
pub struct SocksReply {
/// Status code carried by the reply.
status: SocksStatus,
/// Address carried by the reply.
addr: SocksAddr,
/// Port carried by the reply.
port: u16,
}
impl SocksReply {
    /// Assemble a reply from its parts.
    ///
    /// Only compiled with the `client-handshake` feature.
    #[cfg(feature = "client-handshake")]
    pub(crate) fn new(status: SocksStatus, addr: SocksAddr, port: u16) -> Self {
        SocksReply { addr, port, status }
    }

    /// Return the status code of this reply.
    pub fn status(&self) -> SocksStatus {
        self.status
    }

    /// Return the address carried by this reply.
    pub fn addr(&self) -> &SocksAddr {
        &self.addr
    }

    /// Return the port carried by this reply.
    pub fn port(&self) -> u16 {
        self.port
    }
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;

    #[test]
    fn display_sa() {
        let a = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));
        assert_eq!(a.to_string(), "127.0.0.1");
        let a = SocksAddr::Ip(IpAddr::V6("f00::9999".parse().unwrap()));
        assert_eq!(a.to_string(), "f00::9999");
        let a = SocksAddr::Hostname("www.torproject.org".to_string().try_into().unwrap());
        assert_eq!(a.to_string(), "www.torproject.org");
    }

    #[test]
    fn version_from_u8() {
        assert_eq!(SocksVersion::try_from(4_u8).unwrap(), SocksVersion::V4);
        assert_eq!(SocksVersion::try_from(5_u8).unwrap(), SocksVersion::V5);
        assert!(matches!(
            SocksVersion::try_from(6_u8),
            Err(Error::BadProtocol(6))
        ));
    }

    #[test]
    fn hostname_limits() {
        // Exactly 255 bytes is the longest accepted hostname.
        let longest = "a".repeat(255);
        let h = SocksHostname::try_from(longest.clone()).unwrap();
        assert_eq!(h.as_ref(), longest.as_str());
        // One byte more is rejected.
        assert!(SocksHostname::try_from("a".repeat(256)).is_err());
        // Embedded NUL bytes are rejected as a syntax error.
        assert!(matches!(
            SocksHostname::try_from("bad\0host".to_string()),
            Err(Error::Syntax)
        ));
    }

    #[test]
    fn ok_request() {
        let localhost_v4 = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));
        let r = SocksRequest::new(
            SocksVersion::V4,
            SocksCmd::CONNECT,
            localhost_v4.clone(),
            1024,
            SocksAuth::NoAuth,
        )
        .unwrap();
        assert_eq!(r.version(), SocksVersion::V4);
        assert_eq!(r.command(), SocksCmd::CONNECT);
        assert_eq!(r.addr(), &localhost_v4);
        assert_eq!(r.port(), 1024);
        assert_eq!(r.auth(), &SocksAuth::NoAuth);
    }

    #[test]
    fn resolve_allows_port_zero() {
        let host = SocksAddr::Hostname("www.torproject.org".to_string().try_into().unwrap());
        let r = SocksRequest::new(
            SocksVersion::V5,
            SocksCmd::RESOLVE,
            host.clone(),
            0,
            SocksAuth::NoAuth,
        )
        .unwrap();
        assert_eq!(r.command(), SocksCmd::RESOLVE);
        assert_eq!(r.addr(), &host);
        assert_eq!(r.port(), 0);
    }

    #[test]
    fn bad_request() {
        let localhost_v4 = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));
        let e = SocksRequest::new(
            SocksVersion::V4,
            SocksCmd::BIND,
            localhost_v4.clone(),
            1024,
            SocksAuth::NoAuth,
        );
        assert!(matches!(e, Err(Error::NotImplemented(_))));
        let e = SocksRequest::new(
            SocksVersion::V4,
            SocksCmd::CONNECT,
            localhost_v4,
            0,
            SocksAuth::NoAuth,
        );
        assert!(matches!(e, Err(Error::Syntax)));
    }

    #[test]
    fn auth_validation() {
        let addr = SocksAddr::Ip(IpAddr::V4("127.0.0.1".parse().unwrap()));
        let mk = |version: SocksVersion, auth: SocksAuth| {
            SocksRequest::new(version, SocksCmd::CONNECT, addr.clone(), 80, auth)
        };

        // SOCKS4 user IDs: V4 only, no NUL bytes.
        assert!(mk(SocksVersion::V4, SocksAuth::Socks4(b"user".to_vec())).is_ok());
        assert!(matches!(
            mk(SocksVersion::V5, SocksAuth::Socks4(b"user".to_vec())),
            Err(Error::Syntax)
        ));
        assert!(matches!(
            mk(SocksVersion::V4, SocksAuth::Socks4(b"us\0er".to_vec())),
            Err(Error::Syntax)
        ));

        // Username/password: V5 only, each part at most 255 bytes.
        assert!(mk(
            SocksVersion::V5,
            SocksAuth::Username(b"u".to_vec(), b"p".to_vec())
        )
        .is_ok());
        assert!(matches!(
            mk(
                SocksVersion::V4,
                SocksAuth::Username(b"u".to_vec(), b"p".to_vec())
            ),
            Err(Error::Syntax)
        ));
        let max = vec![b'x'; 255];
        assert!(mk(SocksVersion::V5, SocksAuth::Username(max.clone(), max)).is_ok());
        let too_long = vec![b'x'; 256];
        assert!(matches!(
            mk(
                SocksVersion::V5,
                SocksAuth::Username(too_long.clone(), b"p".to_vec())
            ),
            Err(Error::Syntax)
        ));
        assert!(matches!(
            mk(
                SocksVersion::V5,
                SocksAuth::Username(b"u".to_vec(), too_long)
            ),
            Err(Error::Syntax)
        ));
    }

    #[test]
    fn test_contains_zeros() {
        assert!(contains_zeros(b"Hello\0world"));
        assert!(!contains_zeros(b"Hello world"));
        assert!(contains_zeros(b"\0"));
        assert!(!contains_zeros(b""));
    }
}