koibumi-socks 0.0.0

A minimal SOCKS5 client library
Documentation
//! This crate is a minimal SOCKS5 client library.
//! Supports [`async-std`](https://crates.io/crates/async-std) only.
//!
//! Intended to be used with a local Tor SOCKS5 proxy.
//! Only the "no authentication required" method is supported.
//!
//! # Examples
//!
//! Connect to the web server at example.net:80
//! via a local Tor SOCKS5 proxy at 127.0.0.1:9050,
//! issue a GET command,
//! read and print the response:
//!
//! ```no_run
//! # type Error = Box<dyn std::error::Error + Send + Sync>;
//! # type Result<T> = std::result::Result<T, Error>;
//! # async fn test_connect() -> Result<()> {
//! #
//! use async_std::{
//!     io::{prelude::WriteExt, ReadExt},
//!     net::TcpStream,
//! };
//! use koibumi_socks::{self as socks, DomainName, SocketDomainName};
//!
//! let mut stream = TcpStream::connect("127.0.0.1:9050").await?;
//!
//! let destination = socks::SocketAddr::DomainName(
//!     SocketDomainName::new(
//!         DomainName::new(b"example.net".to_vec()).unwrap(), 80));
//!
//! let _dest = socks::connect(&mut stream, destination).await?;
//!
//! stream.write_all(b"GET /\n").await?;
//!
//! let mut bytes = Vec::new();
//! stream.read_to_end(&mut bytes).await?;
//! print!("{}", String::from_utf8_lossy(&bytes));
//! #
//! # Ok(())
//! # }
//! # async_std::task::block_on(test_connect());
//! ```

// See RFC 1928 SOCKS Protocol Version 5

use std::{
    fmt,
    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6},
    str::FromStr,
};

use async_std::io::{self, prelude::WriteExt, ReadExt};

/// This type represents a domain name used by SOCKS5.
///
/// SOCKS5 prefixes a domain name with a single length byte,
/// so the maximum length is 255 bytes.
/// Only the length is checked; the bytes themselves are not validated.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct DomainName {
    bytes: Vec<u8>,
}

impl AsRef<[u8]> for DomainName {
    /// Returns the raw bytes of the domain name.
    fn as_ref(&self) -> &[u8] {
        self.bytes.as_ref()
    }
}

/// An error which can be returned when parsing a domain name.
///
/// This error is used as the error type for the [`DomainName::new()`] method
/// and the `FromStr` implementation for [`DomainName`].
///
/// [`DomainName::new()`]: struct.DomainName.html#method.new
/// [`DomainName`]: struct.DomainName.html
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ParseDomainNameError {
    /// The input was too long to construct a domain name for SOCKS5.
    /// The maximum length allowed and the actual length of the input
    /// are returned as payloads of this variant.
    TooLong { max: usize, len: usize },
}

impl fmt::Display for ParseDomainNameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::TooLong { max, len } => write!(f, "length must be <={}, but {}", max, len),
        }
    }
}

impl std::error::Error for ParseDomainNameError {}

impl DomainName {
    /// The largest length expressible by the one-byte SOCKS5 length prefix.
    const MAX_LEN: usize = 0xff;

    /// Constructs a domain name from a byte string.
    ///
    /// Only the byte length is checked.
    ///
    /// # Errors
    ///
    /// Returns [`ParseDomainNameError::TooLong`](enum.ParseDomainNameError.html)
    /// if `bytes` is longer than 255 bytes.
    pub fn new(bytes: Vec<u8>) -> std::result::Result<Self, ParseDomainNameError> {
        if bytes.len() > Self::MAX_LEN {
            return Err(ParseDomainNameError::TooLong {
                max: Self::MAX_LEN,
                len: bytes.len(),
            });
        }
        Ok(Self { bytes })
    }
}

impl fmt::Display for DomainName {
    /// Displays the name as UTF-8, replacing invalid sequences with U+FFFD.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        String::from_utf8_lossy(&self.bytes).fmt(f)
    }
}

impl FromStr for DomainName {
    type Err = ParseDomainNameError;

    /// Builds a domain name from the UTF-8 bytes of `s`; see [`DomainName::new`](#method.new).
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        // `new` already returns the right error type; no re-wrapping needed.
        Self::new(s.as_bytes().to_vec())
    }
}

/// An address carried in a SOCKS5 message: an IPv4 address,
/// a domain name, or an IPv6 address.
///
/// The IP address variants wrap the `std::net` types, while the domain name
/// variant wraps this crate's [`DomainName`](struct.DomainName.html).
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub enum Addr {
    /// An IPv4 address.
    Ipv4(Ipv4Addr),

    /// A domain name.
    DomainName(DomainName),

    /// An IPv6 address.
    Ipv6(Ipv6Addr),
}

impl fmt::Display for Addr {
    /// Formats the address exactly as the wrapped value formats itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Ipv4(ip) => fmt::Display::fmt(ip, f),
            Self::DomainName(name) => fmt::Display::fmt(name, f),
            Self::Ipv6(ip) => fmt::Display::fmt(ip, f),
        }
    }
}

/// A port number.
type Port = u16;

/// A socket address that names its host by domain name rather than by
/// IP address, i.e. a domain name paired with a port.
///
/// For restrictions on a domain name, see documents for
/// [`DomainName`](struct.DomainName.html).
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct SocketDomainName {
    domain_name: DomainName,
    port: Port,
}

impl SocketDomainName {
    /// Pairs `domain_name` with `port` to form a socket address.
    pub fn new(domain_name: DomainName, port: Port) -> Self {
        SocketDomainName { domain_name, port }
    }

    /// Borrows the domain name half of this socket address.
    pub fn domain_name(&self) -> &DomainName {
        &self.domain_name
    }

    /// Returns the port half of this socket address.
    pub fn port(&self) -> Port {
        self.port
    }
}

impl fmt::Display for SocketDomainName {
    /// Formats as `name:port`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let SocketDomainName { domain_name, port } = self;
        write!(f, "{}:{}", domain_name, port)
    }
}

/// A socket address as used by SOCKS5: an IPv4 address, a domain name,
/// or an IPv6 address, together with a port.
///
/// The IP variants wrap the `std::net` socket address types, while the
/// domain name variant wraps this crate's
/// [`SocketDomainName`](struct.SocketDomainName.html).
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub enum SocketAddr {
    /// A socket address using an IPv4 address.
    Ipv4(SocketAddrV4),

    /// A socket address using a domain name.
    DomainName(SocketDomainName),

    /// A socket address using an IPv6 address.
    Ipv6(SocketAddrV6),
}

impl SocketAddr {
    /// Combines an address and a port into a socket address.
    ///
    /// For an IPv6 address, the flow info and scope ID are both set to zero.
    pub fn new(addr: Addr, port: Port) -> Self {
        match addr {
            Addr::Ipv4(ip) => SocketAddr::Ipv4(SocketAddrV4::new(ip, port)),
            Addr::DomainName(name) => SocketAddr::DomainName(SocketDomainName::new(name, port)),
            Addr::Ipv6(ip) => {
                let (flowinfo, scope_id) = (0, 0);
                SocketAddr::Ipv6(SocketAddrV6::new(ip, port, flowinfo, scope_id))
            }
        }
    }
}

impl fmt::Display for SocketAddr {
    /// Formats the socket address exactly as the wrapped value formats itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SocketAddr::Ipv4(sa) => fmt::Display::fmt(sa, f),
            SocketAddr::DomainName(sa) => fmt::Display::fmt(sa, f),
            SocketAddr::Ipv6(sa) => fmt::Display::fmt(sa, f),
        }
    }
}

// Wire values from RFC 1928, grouped by the message field they appear in.

// VER: protocol version.
const SOCKS_VERSION_5: u8 = 0x05;

// METHOD: authentication method offered/selected during negotiation.
const SOCKS_NO_AUTHENTICATION_REQUIRED: u8 = 0x00;

// CMD and RSV of a request.
const SOCKS_COMMAND_CONNECT: u8 = 0x01;
const SOCKS_RESERVED: u8 = 0x00;

// ATYP: address type of DST.ADDR / BND.ADDR.
const SOCKS_ADDRESS_IPV4: u8 = 0x01;
const SOCKS_ADDRESS_DOMAIN_NAME: u8 = 0x03;
const SOCKS_ADDRESS_IPV6: u8 = 0x04;

// REP: reply codes. RFC 1928 also defines 0x02 (connection not allowed by
// ruleset) and 0x03 (network unreachable), which have no constant here.
const SOCKS_REPLY_SUCCEEDED: u8 = 0x00;
const SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE: u8 = 0x01;
const SOCKS_REPLY_HOST_UNREACHABLE: u8 = 0x04;
const SOCKS_REPLY_CONNECTION_REFUSED: u8 = 0x05;
const SOCKS_REPLY_TTL_EXPIRED: u8 = 0x06;
const SOCKS_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
const SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;

/// An error which can be returned when connecting to a destination host
/// via SOCKS5 proxy server.
///
/// This error is used as the error type for the [`connect`] function.
///
/// [`connect`]: fn.connect.html
#[derive(Debug)]
pub enum ConnectError {
    /// The server returned a version number that is not supported by this client.
    /// The actual version number received is returned as a payload of this variant.
    UnsupportedVersion(u8),

    /// The server selected a method that is not supported by this client.
    /// The actual method selected is returned as a payload of this variant.
    /// `0xff` means that the server said none of the methods listed by the client
    /// were acceptable.
    UnsupportedMethod(u8),

    /// General SOCKS server failure (reply code `0x01`).
    GeneralServerFailure,

    /// Host unreachable (reply code `0x04`).
    HostUnreachable,

    /// Connection refused (reply code `0x05`).
    ConnectionRefused,

    /// TTL expired (reply code `0x06`).
    TtlExpired,

    /// Command not supported (reply code `0x07`).
    CommandNotSupported,

    /// Address type not supported (reply code `0x08`).
    AddressTypeNotSupported,

    /// Any other non-success reply code.
    /// The actual reply value received is returned as a payload of this variant.
    /// This includes the RFC 1928 codes `0x02` (connection not allowed by ruleset)
    /// and `0x03` (network unreachable), which have no dedicated variant.
    UnknownFailure(u8),

    /// The server returned an address type that is not supported by this client.
    /// The actual address type received is returned as a payload of this variant.
    UnsupportedAddressType(u8),

    /// A standard I/O error was caught during communication with the server.
    /// The actual error caught is returned as a payload of this variant.
    IoError(io::Error),
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // With the `#` flag the width counts the `0x` prefix, so a width of 4
        // is needed to always print two hex digits (e.g. `0x05`, not `0x5`).
        match self {
            Self::UnsupportedVersion(ver) => write!(f, "Unsupported SOCKS version: {:#04x}", ver),
            Self::UnsupportedMethod(method) => {
                write!(f, "Unsupported SOCKS method: {:#04x}", method)
            }
            Self::GeneralServerFailure => "General SOCKS server failure".fmt(f),
            Self::HostUnreachable => "Host unreachable".fmt(f),
            Self::ConnectionRefused => "Connection refused".fmt(f),
            Self::TtlExpired => "TTL expired".fmt(f),
            Self::CommandNotSupported => "Command not supported".fmt(f),
            Self::AddressTypeNotSupported => "Address type not supported".fmt(f),
            Self::UnknownFailure(rep) => write!(f, "Unknown SOCKS failure: {:#04x}", rep),
            Self::UnsupportedAddressType(atyp) => {
                write!(f, "Unsupported address type: {:#04x}", atyp)
            }
            Self::IoError(err) => err.fmt(f),
        }
    }
}

impl std::error::Error for ConnectError {}

/// A specialized Result type for SOCKS5 operations.
pub type Result<T> = std::result::Result<T, ConnectError>;

/// Writes the single byte `v` to `s`.
async fn write_u8<W>(s: &mut W, v: u8) -> Result<()>
where
    W: WriteExt + Unpin,
{
    let byte = [v];
    s.write_all(&byte).await.map_err(ConnectError::IoError)
}

/// Reads exactly one byte from `s`.
async fn read_u8<R>(s: &mut R) -> Result<u8>
where
    R: ReadExt + Unpin,
{
    let mut byte = [0; 1];
    s.read_exact(&mut byte).await.map_err(ConnectError::IoError)?;
    Ok(byte[0])
}

/// Writes `v` to `s` in big-endian (network) byte order.
async fn write_u16<W>(s: &mut W, v: u16) -> Result<()>
where
    W: WriteExt + Unpin,
{
    s.write_all(&v.to_be_bytes())
        .await
        .map_err(ConnectError::IoError)
}

/// Reads exactly two bytes from `s` and decodes them as a big-endian `u16`.
async fn read_u16<R>(s: &mut R) -> Result<u16>
where
    R: ReadExt + Unpin,
{
    let mut bytes = [0; 2];
    s.read_exact(&mut bytes).await.map_err(ConnectError::IoError)?;
    Ok(u16::from_be_bytes(bytes))
}

/// Writes all of `v` to `s`.
async fn write_all<W>(s: &mut W, v: &[u8]) -> Result<()>
where
    W: WriteExt + Unpin,
{
    s.write_all(v).await.map_err(ConnectError::IoError)
}

/// Fills `v` completely with bytes read from `s`.
async fn read_exact<R>(s: &mut R, v: &mut [u8]) -> Result<()>
where
    R: ReadExt + Unpin,
{
    s.read_exact(v).await.map_err(ConnectError::IoError)
}

/// Appends everything remaining in `s` to `v`.
async fn read_to_end<R>(s: &mut R, v: &mut Vec<u8>) -> Result<()>
where
    R: ReadExt + Unpin,
{
    // The byte count returned on success is not needed by callers.
    s.read_to_end(v)
        .await
        .map(|_| ())
        .map_err(ConnectError::IoError)
}

/// Flushes any data buffered in `s`.
async fn flush<W>(s: &mut W) -> Result<()>
where
    W: WriteExt + Unpin,
{
    s.flush().await.map_err(ConnectError::IoError)
}

/// Connects to an arbitrary network destination via a SOCKS5 server.
///
/// The SOCKS5 server is specified by a TCP socket
/// which is already connected to the SOCKS5 server.
///
/// # Examples
///
/// Connect to the web server at example.net:80
/// via a local Tor SOCKS5 proxy at 127.0.0.1:9050,
/// issue a GET command,
/// read and print the response:
///
/// ```no_run
/// # type Error = Box<dyn std::error::Error + Send + Sync>;
/// # type Result<T> = std::result::Result<T, Error>;
/// # async fn test_connect() -> Result<()> {
/// #
/// use async_std::{
///     io::{prelude::WriteExt, ReadExt},
///     net::TcpStream,
/// };
/// use koibumi_socks::{self as socks, DomainName, SocketDomainName};
///
/// let mut stream = TcpStream::connect("127.0.0.1:9050").await?;
///
/// let destination = socks::SocketAddr::DomainName(
///     SocketDomainName::new(
///         DomainName::new(b"example.net".to_vec()).unwrap(), 80));
///
/// let _dest = socks::connect(&mut stream, destination).await?;
///
/// stream.write_all(b"GET /\n").await?;
///
/// let mut bytes = Vec::new();
/// stream.read_to_end(&mut bytes).await?;
/// print!("{}", String::from_utf8_lossy(&bytes));
/// #
/// # Ok(())
/// # }
/// # async_std::task::block_on(test_connect());
/// ```
pub async fn connect<S>(server: &mut S, destination: SocketAddr) -> Result<SocketAddr>
where
    S: ReadExt + WriteExt + Unpin,
{
    // Send a version identifier/method selection message

    // VER
    write_u8(server, SOCKS_VERSION_5).await?;
    // NMETHODS
    write_u8(server, 1).await?;
    // METHODS
    write_u8(server, SOCKS_NO_AUTHENTICATION_REQUIRED).await?;

    flush(server).await?;

    // Receive response

    let ver = read_u8(server).await?;
    let method = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    if method != SOCKS_NO_AUTHENTICATION_REQUIRED {
        return Err(ConnectError::UnsupportedMethod(method));
    }

    // Send SOCKS request

    write_u8(server, SOCKS_VERSION_5).await?;
    write_u8(server, SOCKS_COMMAND_CONNECT).await?;
    write_u8(server, SOCKS_RESERVED).await?;
    match destination {
        SocketAddr::Ipv4(addr) => {
            write_u8(server, SOCKS_ADDRESS_IPV4).await?;
            write_all(server, addr.ip().octets().as_ref()).await?;
            write_u16(server, addr.port()).await?;
        }
        SocketAddr::DomainName(addr) => {
            write_u8(server, SOCKS_ADDRESS_DOMAIN_NAME).await?;
            write_u8(server, addr.domain_name().as_ref().len() as u8).await?;
            write_all(server, addr.domain_name().as_ref()).await?;
            write_u16(server, addr.port()).await?;
        }
        SocketAddr::Ipv6(addr) => {
            write_u8(server, SOCKS_ADDRESS_IPV6).await?;
            write_all(server, addr.ip().octets().as_ref()).await?;
            write_u16(server, addr.port()).await?;
        }
    }

    flush(server).await?;

    // Receive response

    let ver = read_u8(server).await?;
    let rep = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    match rep {
        SOCKS_REPLY_SUCCEEDED => {}
        SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE => return Err(ConnectError::GeneralServerFailure),
        SOCKS_REPLY_HOST_UNREACHABLE => return Err(ConnectError::HostUnreachable),
        SOCKS_REPLY_CONNECTION_REFUSED => return Err(ConnectError::ConnectionRefused),
        SOCKS_REPLY_TTL_EXPIRED => return Err(ConnectError::TtlExpired),
        SOCKS_REPLY_COMMAND_NOT_SUPPORTED => return Err(ConnectError::CommandNotSupported),
        SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => {
            return Err(ConnectError::AddressTypeNotSupported)
        }
        _ => return Err(ConnectError::UnknownFailure(rep)),
    }
    let _rsv = read_u8(server).await?;
    let atyp = read_u8(server).await?;
    match atyp {
        SOCKS_ADDRESS_IPV4 => {
            let mut bytes = [0; 4];
            read_exact(server, &mut bytes).await?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Ipv4(SocketAddrV4::new(bytes.into(), port)))
        }
        SOCKS_ADDRESS_DOMAIN_NAME => {
            let len = read_u8(server).await?;
            let mut r = server.take(len as u64);
            let mut bytes = Vec::with_capacity(len as usize);
            read_to_end(&mut r, &mut bytes).await?;
            let domain_name = DomainName::new(bytes).unwrap();
            let port = read_u16(server).await?;
            Ok(SocketAddr::DomainName(SocketDomainName::new(
                domain_name,
                port,
            )))
        }
        SOCKS_ADDRESS_IPV6 => {
            let mut bytes = [0; 16];
            read_exact(server, &mut bytes).await?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Ipv6(SocketAddrV6::new(
                bytes.into(),
                port,
                0,
                0,
            )))
        }
        _ => Err(ConnectError::UnsupportedAddressType(atyp)),
    }
}