koibumi-socks 0.0.5

A minimal SOCKS5 client library
Documentation
//! This crate is a minimal SOCKS5 client library.
//!
//! The library is usable in async context.
//! The library is intended to be used with a local Tor SOCKS5 proxy.
//!
//! # Examples
//!
//! Connect to the web server at example.org:80
//! via a local Tor SOCKS5 proxy at 127.0.0.1:9050,
//! issue a GET command,
//! read and print the response:
//!
//! ```no_run
//! # type Error = Box<dyn std::error::Error + Send + Sync>;
//! # type Result<T> = std::result::Result<T, Error>;
//! # async fn test_connect() -> Result<()> {
//! #
//! use async_std::{
//!     io::{prelude::WriteExt, ReadExt},
//!     net::TcpStream,
//! };
//! use koibumi_net::{
//!     domain::{Domain, SocketDomain},
//!     socks::SocketAddr as SocksSocketAddr,
//! };
//! use koibumi_socks as socks;
//!
//! let mut stream = TcpStream::connect("127.0.0.1:9050").await?;
//!
//! let destination = SocksSocketAddr::Domain(
//!     SocketDomain::new(
//!         Domain::new("example.org").unwrap(), 80.into()));
//!
//! let _dest = socks::connect(&mut stream, destination).await?;
//!
//! stream.write_all(b"GET /\n").await?;
//!
//! let mut bytes = Vec::new();
//! stream.read_to_end(&mut bytes).await?;
//! print!("{}", String::from_utf8_lossy(&bytes));
//! #
//! # Ok(())
//! # }
//! # async_std::task::block_on(test_connect());
//! ```

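// See RFC 1928 (SOCKS Protocol Version 5) and, for the username/password
// subnegotiation, RFC 1929 (Username/Password Authentication for SOCKS V5).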

#![deny(unsafe_code)]
#![warn(missing_docs)]

use std::{
    fmt,
    net::{SocketAddrV4, SocketAddrV6},
};

use futures::{
    io::{self, AsyncRead, AsyncWrite},
    prelude::*,
};

use koibumi_net::{
    domain::{Domain, SocketDomain},
    socks::SocketAddr,
};

// ---- Version identifier / method selection ----

/// `VER` byte for SOCKS protocol version 5.
const SOCKS_VERSION_5: u8 = 0x05;
/// `METHOD` byte: no authentication required.
const SOCKS_NO_AUTHENTICATION_REQUIRED: u8 = 0x00;
/// `METHOD` byte: username/password authentication.
const SOCKS_USERNAME_AND_PASSWORD: u8 = 0x02;

// ---- Request header ----

/// `CMD` byte for the CONNECT command.
const SOCKS_COMMAND_CONNECT: u8 = 0x01;
/// Value sent in the reserved `RSV` byte.
const SOCKS_RESERVED: u8 = 0x00;

// ---- Address types (`ATYP`), used in both requests and replies ----

/// IPv4 address: four octets follow.
const SOCKS_ADDRESS_IPV4: u8 = 0x01;
/// Domain name: a one-byte length followed by that many bytes.
const SOCKS_ADDRESS_DOMAIN_NAME: u8 = 0x03;
/// IPv6 address: sixteen octets follow.
const SOCKS_ADDRESS_IPV6: u8 = 0x04;

// ---- Reply codes (`REP`) ----

/// The request succeeded.
const SOCKS_REPLY_SUCCEEDED: u8 = 0x00;
/// General SOCKS server failure.
const SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE: u8 = 0x01;
/// Host unreachable.
const SOCKS_REPLY_HOST_UNREACHABLE: u8 = 0x04;
/// Connection refused.
const SOCKS_REPLY_CONNECTION_REFUSED: u8 = 0x05;
/// TTL expired.
const SOCKS_REPLY_TTL_EXPIRED: u8 = 0x06;
/// Command not supported.
const SOCKS_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
/// Address type not supported.
const SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;

// ---- Username/password subnegotiation ----

/// Version byte of the username/password subnegotiation.
const SOCKS_SUBNEGOTIATION_VERSION: u8 = 0x01;
/// Subnegotiation `STATUS` byte meaning success.
const SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED: u8 = 0x00;

/// An authentication method.
///
/// `Debug` is implemented by hand so that the password never ends up in logs
/// or panic messages; every other derive is unchanged.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Auth {
    /// No authentication required.
    None,
    /// Username/password.
    Password {
        /// Username.
        username: Vec<u8>,
        /// Password.
        password: Vec<u8>,
    },
}

impl fmt::Debug for Auth {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::None => f.write_str("None"),
            // The password bytes are deliberately replaced by a placeholder.
            Self::Password { username, .. } => f
                .debug_struct("Password")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
        }
    }
}

/// An error which can be returned when connecting to a destination host
/// via SOCKS5 proxy server.
///
/// This error is used as the error type for the [`connect`] and
/// [`connect_with_password`] functions.
///
/// [`connect`]: fn.connect.html
/// [`connect_with_password`]: fn.connect_with_password.html
#[derive(Debug)]
pub enum ConnectError {
    /// The server returned a version number that is not supported by this client.
    /// The actual version number received is returned as a payload of this variant.
    UnsupportedVersion(u8),

    /// The server selected a method that is not supported by this client.
    /// The actual method selected is returned as a payload of this variant.
    /// `0xff` means that the server said none of the methods listed by the client
    /// were acceptable.
    UnsupportedMethod(u8),

    /// General SOCKS server failure.
    GeneralServerFailure,

    /// Host unreachable.
    HostUnreachable,

    /// Connection refused.
    ConnectionRefused,

    /// TTL expired.
    TtlExpired,

    /// Command not supported.
    CommandNotSupported,

    /// Address type not supported.
    AddressTypeNotSupported,

    /// A failure reply code not covered by the other variants.
    /// The actual reply value received is returned as a payload of this variant.
    UnknownFailure(u8),

    /// The server returned an address type that is not supported by this client.
    /// The actual address type received is returned as a payload of this variant.
    UnsupportedAddressType(u8),

    /// A standard I/O error was caught during communication with the server.
    /// The actual error caught is returned as a payload of this variant.
    IoError(io::Error),

    /// The length of the username was invalid; it must be between 1 and 255 bytes.
    /// The actual length supplied is returned as a payload of this variant.
    InvalidUsernameLength(usize),

    /// The length of the password was invalid; it must be between 1 and 255 bytes.
    /// The actual length supplied is returned as a payload of this variant.
    InvalidPasswordLength(usize),

    /// The server returned a username/password subnegotiation version number
    /// that is not supported by this client.
    /// The actual version number received is returned as a payload of this variant.
    UnsupportedSubnegotiationVersion(u8),

    /// The server rejected the supplied username/password.
    AuthenticationFailure,
}

impl fmt::Display for ConnectError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Byte values use `{:#04x}`: the width includes the `0x` prefix, so this
        // always prints two hex digits (`0x05`, `0xff`). `{:#02x}` would print `0x5`.
        match self {
            Self::UnsupportedVersion(ver) => {
                write!(f, "Unsupported SOCKS version: {:#04x}", ver)
            }
            Self::UnsupportedMethod(method) => {
                write!(f, "Unsupported SOCKS method: {:#04x}", method)
            }
            Self::GeneralServerFailure => f.pad("General SOCKS server failure"),
            Self::HostUnreachable => f.pad("Host unreachable"),
            Self::ConnectionRefused => f.pad("Connection refused"),
            Self::TtlExpired => f.pad("TTL expired"),
            Self::CommandNotSupported => f.pad("Command not supported"),
            Self::AddressTypeNotSupported => f.pad("Address type not supported"),
            Self::UnknownFailure(rep) => write!(f, "Unknown SOCKS failure: {:#04x}", rep),
            Self::UnsupportedAddressType(atyp) => {
                write!(f, "Unsupported address type: {:#04x}", atyp)
            }

            Self::IoError(err) => fmt::Display::fmt(err, f),

            // Both bounds are inclusive, hence `1..=255`.
            Self::InvalidUsernameLength(len) => {
                write!(f, "username length must be 1..=255, but was {}", len)
            }
            Self::InvalidPasswordLength(len) => {
                write!(f, "password length must be 1..=255, but was {}", len)
            }
            Self::UnsupportedSubnegotiationVersion(ver) => {
                write!(f, "Unsupported SOCKS subnegotiation version: {:#04x}", ver)
            }
            Self::AuthenticationFailure => f.pad("authentication failure"),
        }
    }
}

impl std::error::Error for ConnectError {
    /// Exposes the wrapped I/O error so callers can inspect or downcast it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IoError(err) => Some(err),
            _ => None,
        }
    }
}

impl From<io::Error> for ConnectError {
    fn from(err: io::Error) -> Self {
        ConnectError::IoError(err)
    }
}

/// A specialized Result type for SOCKS5 operations.
pub type Result<T> = std::result::Result<T, ConnectError>;

/// Reads exactly one byte from `s`.
///
/// Fails with the underlying I/O error (e.g. unexpected EOF) if no byte is available.
async fn read_u8<R>(s: &mut R) -> Result<u8>
where
    R: AsyncRead + Unpin,
{
    let mut buf = [0u8; 1];
    s.read_exact(&mut buf).await?;
    let [byte] = buf;
    Ok(byte)
}

/// Reads a two-byte, big-endian (network byte order) unsigned integer from `s`.
///
/// Fails with the underlying I/O error if fewer than two bytes are available.
async fn read_u16<R>(s: &mut R) -> Result<u16>
where
    R: AsyncRead + Unpin,
{
    let mut buf = [0u8; 2];
    s.read_exact(&mut buf).await?;
    // Most significant byte first.
    let [high, low] = buf;
    Ok((u16::from(high) << 8) | u16::from(low))
}

/// Connects to an arbitrary network destination via a SOCKS5 server,
/// offering only the "no authentication required" method.
///
/// `server` must be a stream that is already connected to the SOCKS5 server.
/// After a successful call, the same stream carries the proxied connection
/// to `destination`.
///
/// On success, returns the address reported by the server in its reply.
///
/// # Errors
///
/// Returns a [`ConnectError`] if the server speaks an unsupported version,
/// refuses the offered method, answers the request with a failure reply,
/// or if an I/O error occurs on `server`.
///
/// # Examples
///
/// Connect to the web server at example.org:80
/// via a local Tor SOCKS5 proxy at 127.0.0.1:9050,
/// issue a GET command,
/// read and print the response:
///
/// ```no_run
/// # type Error = Box<dyn std::error::Error + Send + Sync>;
/// # type Result<T> = std::result::Result<T, Error>;
/// # async fn test_connect() -> Result<()> {
/// #
/// use async_std::{
///     io::{prelude::WriteExt, ReadExt},
///     net::TcpStream,
/// };
/// use koibumi_net::{
///     domain::{Domain, SocketDomain},
///     socks::SocketAddr as SocksSocketAddr,
/// };
/// use koibumi_socks as socks;
///
/// let mut stream = TcpStream::connect("127.0.0.1:9050").await?;
///
/// let destination = SocksSocketAddr::Domain(
///     SocketDomain::new(
///         Domain::new("example.org").unwrap(), 80.into()));
///
/// let _dest = socks::connect(&mut stream, destination).await?;
///
/// stream.write_all(b"GET /\n").await?;
///
/// let mut bytes = Vec::new();
/// stream.read_to_end(&mut bytes).await?;
/// print!("{}", String::from_utf8_lossy(&bytes));
/// #
/// # Ok(())
/// # }
/// # async_std::task::block_on(test_connect());
/// ```
pub async fn connect<S>(server: &mut S, destination: SocketAddr) -> Result<SocketAddr>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_auth = Auth::None;
    connect_with_auth(server, no_auth, destination).await
}

/// Connects to an arbitrary network destination via a SOCKS5 server,
/// authenticating with a username and password.
///
/// `server` must be a stream that is already connected to the SOCKS5 server.
/// Only the username/password method is offered; the credentials are copied
/// and sent in the username/password subnegotiation.
///
/// On success, returns the address reported by the server in its reply.
///
/// # Errors
///
/// - [`ConnectError::InvalidUsernameLength`] / [`ConnectError::InvalidPasswordLength`]
///   if either credential is empty or longer than 255 bytes. This is checked
///   before anything is sent.
/// - [`ConnectError::AuthenticationFailure`] if the server rejects the credentials.
/// - Any other [`ConnectError`] for protocol failures or I/O errors on `server`.
pub async fn connect_with_password<S>(
    server: &mut S,
    username: impl AsRef<[u8]>,
    password: impl AsRef<[u8]>,
    destination: SocketAddr,
) -> Result<SocketAddr>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let credentials = Auth::Password {
        username: Vec::from(username.as_ref()),
        password: Vec::from(password.as_ref()),
    };
    connect_with_auth(server, credentials, destination).await
}

/// Connects to an arbitrary network destination via a SOCKS5 server.
///
/// The SOCKS5 server is specified by a TCP socket
/// which is already connected to the SOCKS5 server.
///
/// Authentication method is specified by `Auth` object.
///
/// The exchange is: method selection (offering exactly one method), the
/// username/password subnegotiation when `auth` is `Auth::Password`, a CONNECT
/// request for `destination`, and finally the server's reply. On success,
/// returns the address reported in that reply.
///
/// # Errors
///
/// - `InvalidUsernameLength` / `InvalidPasswordLength` if a credential is empty
///   or longer than 255 bytes. This is checked before any I/O.
/// - `IoError` of kind `InvalidInput` if a destination domain name is longer
///   than 255 bytes, or of kind `InvalidData` if the server replies with a domain
///   name that cannot be parsed. Other I/O errors on `server` are passed through.
/// - The protocol variants of `ConnectError` for an unexpected version, method,
///   or address type, a rejected authentication, or a non-success reply code.
async fn connect_with_auth<S>(
    server: &mut S,
    auth: Auth,
    destination: SocketAddr,
) -> Result<SocketAddr>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    // Check parameters before touching the connection. Each credential is sent
    // with a one-byte length prefix, and the length must be non-zero.
    if let Auth::Password { username, password } = &auth {
        if !(1..=255).contains(&username.len()) {
            return Err(ConnectError::InvalidUsernameLength(username.len()));
        }
        if !(1..=255).contains(&password.len()) {
            return Err(ConnectError::InvalidPasswordLength(password.len()));
        }
    }

    // Version identifier/method selection message: VER, NMETHODS (= 1), METHODS.
    let requested_method = match &auth {
        Auth::None => SOCKS_NO_AUTHENTICATION_REQUIRED,
        Auth::Password { .. } => SOCKS_USERNAME_AND_PASSWORD,
    };
    server
        .write_all(&[SOCKS_VERSION_5, 1, requested_method])
        .await?;
    server.flush().await?;

    let ver = read_u8(server).await?;
    let method = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    // Only one method was offered, so anything else (including 0xff,
    // "no acceptable methods") is a failure.
    if method != requested_method {
        return Err(ConnectError::UnsupportedMethod(method));
    }

    if let Auth::Password { username, password } = auth {
        authenticate_with_password(server, &username, &password).await?;
    }

    // SOCKS request: VER, CMD, RSV, then ATYP, DST.ADDR, DST.PORT.
    let mut packet = vec![SOCKS_VERSION_5, SOCKS_COMMAND_CONNECT, SOCKS_RESERVED];
    push_destination(&mut packet, destination)?;
    server.write_all(&packet).await?;
    server.flush().await?;

    // Reply: VER, REP, RSV, then ATYP, BND.ADDR, BND.PORT.
    let ver = read_u8(server).await?;
    let rep = read_u8(server).await?;
    if ver != SOCKS_VERSION_5 {
        return Err(ConnectError::UnsupportedVersion(ver));
    }
    check_reply(rep)?;
    let _rsv = read_u8(server).await?;
    let atyp = read_u8(server).await?;
    read_bound_address(server, atyp).await
}

/// Runs the username/password subnegotiation.
///
/// Both credential lengths must already have been validated to lie in 1..=255,
/// so the `as u8` length casts below cannot truncate.
async fn authenticate_with_password<S>(
    server: &mut S,
    username: &[u8],
    password: &[u8],
) -> Result<()>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let mut packet = Vec::with_capacity(3 + username.len() + password.len());
    packet.push(SOCKS_SUBNEGOTIATION_VERSION);
    packet.push(username.len() as u8);
    packet.extend_from_slice(username);
    packet.push(password.len() as u8);
    packet.extend_from_slice(password);

    server.write_all(&packet).await?;
    server.flush().await?;

    let ver = read_u8(server).await?;
    let status = read_u8(server).await?;
    if ver != SOCKS_SUBNEGOTIATION_VERSION {
        return Err(ConnectError::UnsupportedSubnegotiationVersion(ver));
    }
    if status != SOCKS_SUBNEGOTIATION_REPLY_SUCCEEDED {
        return Err(ConnectError::AuthenticationFailure);
    }
    Ok(())
}

/// Appends ATYP, DST.ADDR and DST.PORT for `destination` to `packet`.
fn push_destination(packet: &mut Vec<u8>, destination: SocketAddr) -> Result<()> {
    match destination {
        SocketAddr::Ipv4(addr) => {
            packet.push(SOCKS_ADDRESS_IPV4);
            packet.extend_from_slice(&addr.ip().octets());
            packet.extend_from_slice(&addr.port().to_be_bytes());
        }
        SocketAddr::Domain(domain) => {
            let len = domain.domain().as_bytes().len();
            // The name is prefixed by a single length byte; refuse rather than
            // silently truncate the length with an `as u8` cast.
            if len > 255 {
                return Err(ConnectError::IoError(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "destination domain name is longer than 255 bytes",
                )));
            }
            packet.push(SOCKS_ADDRESS_DOMAIN_NAME);
            packet.push(len as u8);
            packet.extend_from_slice(domain.domain().as_bytes());
            packet.extend_from_slice(&domain.port().as_u16().to_be_bytes());
        }
        SocketAddr::Ipv6(addr) => {
            packet.push(SOCKS_ADDRESS_IPV6);
            packet.extend_from_slice(&addr.ip().octets());
            packet.extend_from_slice(&addr.port().to_be_bytes());
        }
    }
    Ok(())
}

/// Maps the REP byte of the server's reply to `Ok(())` on success or the
/// matching `ConnectError` otherwise.
fn check_reply(rep: u8) -> Result<()> {
    match rep {
        SOCKS_REPLY_SUCCEEDED => Ok(()),
        SOCKS_REPLY_GENERAL_SOCKS_SERVER_FAILURE => Err(ConnectError::GeneralServerFailure),
        SOCKS_REPLY_HOST_UNREACHABLE => Err(ConnectError::HostUnreachable),
        SOCKS_REPLY_CONNECTION_REFUSED => Err(ConnectError::ConnectionRefused),
        SOCKS_REPLY_TTL_EXPIRED => Err(ConnectError::TtlExpired),
        SOCKS_REPLY_COMMAND_NOT_SUPPORTED => Err(ConnectError::CommandNotSupported),
        SOCKS_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => Err(ConnectError::AddressTypeNotSupported),
        _ => Err(ConnectError::UnknownFailure(rep)),
    }
}

/// Reads BND.ADDR and BND.PORT of the given address type from the server's reply.
async fn read_bound_address<R>(server: &mut R, atyp: u8) -> Result<SocketAddr>
where
    R: AsyncRead + Unpin,
{
    match atyp {
        SOCKS_ADDRESS_IPV4 => {
            let mut octets = [0; 4];
            server.read_exact(&mut octets).await?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Ipv4(SocketAddrV4::new(octets.into(), port)))
        }
        SOCKS_ADDRESS_DOMAIN_NAME => {
            let len = read_u8(server).await?;
            // `read_exact` reports a short read as an error instead of
            // silently yielding a truncated name.
            let mut bytes = vec![0; usize::from(len)];
            server.read_exact(&mut bytes).await?;
            // The name comes from the network, so a malformed one is reported
            // as an error rather than panicking.
            let domain = Domain::from_bytes(&bytes).map_err(|_| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    "SOCKS server replied with an invalid domain name",
                )
            })?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Domain(SocketDomain::new(domain, port.into())))
        }
        SOCKS_ADDRESS_IPV6 => {
            let mut octets = [0; 16];
            server.read_exact(&mut octets).await?;
            let port = read_u16(server).await?;
            Ok(SocketAddr::Ipv6(SocketAddrV6::new(
                octets.into(),
                port,
                0,
                0,
            )))
        }
        _ => Err(ConnectError::UnsupportedAddressType(atyp)),
    }
}