rd-std 0.1.0

std for rabbit-digger
Documentation
use std::{convert::TryInto, io};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use super::common::Address;

/// Errors produced while reading or writing the SOCKS5 messages in this module.
#[derive(Debug, Error)]
pub enum Error {
    /// A version byte other than `5` was read; carries the offending byte.
    #[error("Invalid version: {0}")]
    InvalidVersion(u8),
    /// An `AuthRequest` holds more than 255 methods, so its length does not
    /// fit in the one-byte count field.
    #[error("Too many methods")]
    TooManyMethods,
    /// The third byte of a command request/response header was not `0`.
    #[error("Invalid handshake")]
    InvalidHandshake,
    /// The command byte of a request was not `1`, `2` or `3`; carries the byte.
    #[error("Invalid command: {0}")]
    InvalidCommand(u8),
    /// The reply byte of a response was outside `0..=8`; carries the byte.
    #[error("Invalid command reply: {0}")]
    InvalidCommandReply(u8),
    /// A well-formed response was read whose reply was not
    /// `CommandReply::Succeeded`; carries that reply.
    #[error("Command reply with error: {0:?}")]
    CommandReply(CommandReply),
    /// An underlying I/O operation failed. `#[from]` lets `?` convert an
    /// `io::Error` into this variant.
    #[error("IO error: {0:?}")]
    Io(#[from] io::Error),
}
/// `Result` alias whose error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = ::std::result::Result<T, E>;

#[derive(Debug)]
pub enum Version {
    V5,
}
impl Version {
    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<Version> {
        let version = &mut [0u8];
        reader.read_exact(version).await?;
        match version[0] {
            5 => Ok(Version::V5),
            other => Err(Error::InvalidVersion(other)),
        }
    }
    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
        let v = match self {
            Version::V5 => 5u8,
        };
        writer.write_all(&[v]).await?;
        Ok(())
    }
}

/// Authentication method identifier exchanged during method negotiation.
///
/// Each variant maps to a one-byte wire value; bytes without a dedicated
/// variant are preserved in [`AuthMethod::Other`]. Converting a byte to an
/// `AuthMethod` and back always yields the original byte.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum AuthMethod {
    /// Wire value `0x00`.
    Noauth,
    /// Wire value `0x01`.
    Gssapi,
    /// Wire value `0x02`.
    UsernamePassword,
    /// Wire value `0xff`.
    NoAcceptableMethod,
    /// Any other wire value, kept verbatim.
    Other(u8),
}

impl From<u8> for AuthMethod {
    /// Decodes a wire byte; never fails because unknown bytes become `Other`.
    fn from(n: u8) -> Self {
        match n {
            0x00 => AuthMethod::Noauth,
            0x01 => AuthMethod::Gssapi,
            0x02 => AuthMethod::UsernamePassword,
            0xff => AuthMethod::NoAcceptableMethod,
            other => AuthMethod::Other(other),
        }
    }
}

// Implemented as `From` rather than `Into`: the standard blanket impl then
// provides `Into<u8> for AuthMethod` automatically, so existing `.into()`
// callers keep working and `u8::from(method)` becomes available as well.
impl From<AuthMethod> for u8 {
    /// Encodes the method as its wire byte.
    fn from(method: AuthMethod) -> u8 {
        match method {
            AuthMethod::Noauth => 0x00,
            AuthMethod::Gssapi => 0x01,
            AuthMethod::UsernamePassword => 0x02,
            AuthMethod::NoAcceptableMethod => 0xff,
            AuthMethod::Other(other) => other,
        }
    }
}

#[derive(Debug, Eq, PartialEq, Clone)]
pub struct AuthRequest(Vec<AuthMethod>);

impl AuthRequest {
    pub fn new(methods: impl Into<Vec<AuthMethod>>) -> AuthRequest {
        AuthRequest(methods.into())
    }
    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthRequest> {
        let count = &mut [0u8];
        reader.read_exact(count).await?;
        let mut methods = vec![0u8; count[0] as usize];
        reader.read_exact(&mut methods).await?;

        Ok(AuthRequest(methods.into_iter().map(Into::into).collect()))
    }
    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
        let count = self.0.len();
        if count > 255 {
            return Err(Error::TooManyMethods);
        }

        writer.write_all(&[count as u8]).await?;
        writer
            .write_all(
                &self
                    .0
                    .iter()
                    .map(|i| Into::<u8>::into(*i))
                    .collect::<Vec<_>>(),
            )
            .await?;

        Ok(())
    }
    pub fn select_from(&self, auth: &[AuthMethod]) -> AuthMethod {
        self.0
            .iter()
            .enumerate()
            .find(|(_, m)| auth.contains(*m))
            .map(|(v, _)| AuthMethod::from(v as u8))
            .unwrap_or(AuthMethod::NoAcceptableMethod)
    }
}

#[derive(Debug, Eq, PartialEq, Clone)]
pub struct AuthResponse(AuthMethod);

impl AuthResponse {
    pub fn new(method: AuthMethod) -> AuthResponse {
        AuthResponse(method)
    }
    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<AuthResponse> {
        let method = &mut [0u8];
        reader.read_exact(method).await?;
        Ok(AuthResponse(method[0].into()))
    }
    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
        writer.write_all(&[self.0.into()]).await?;
        Ok(())
    }
    pub fn method(&self) -> AuthMethod {
        self.0
    }
}

/// Operation requested in a command request.
///
/// Derives `Clone`, `Copy`, `PartialEq` and `Eq` so callers can compare and
/// copy a command directly instead of pattern-matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Command {
    /// Encoded on the wire as the command byte `1`.
    Connect,
    /// Encoded on the wire as the command byte `2`.
    Bind,
    /// Encoded on the wire as the command byte `3`.
    UdpAssociate,
}
#[derive(Debug)]
pub struct CommandRequest {
    pub command: Command,
    pub address: Address,
}

impl CommandRequest {
    pub fn connect(address: Address) -> CommandRequest {
        CommandRequest {
            command: Command::Connect,
            address,
        }
    }
    pub fn udp_associate(address: Address) -> CommandRequest {
        CommandRequest {
            command: Command::UdpAssociate,
            address,
        }
    }
    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandRequest> {
        let buf = &mut [0u8; 3];
        reader.read_exact(buf).await?;
        if buf[0] != 5 {
            return Err(Error::InvalidVersion(buf[0]));
        }
        if buf[2] != 0 {
            return Err(Error::InvalidHandshake);
        }
        let cmd = match buf[1] {
            1 => Command::Connect,
            2 => Command::Bind,
            3 => Command::UdpAssociate,
            _ => return Err(Error::InvalidCommand(buf[1])),
        };

        let address = Address::read(reader).await?;

        Ok(CommandRequest {
            command: cmd,
            address,
        })
    }
    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
        let cmd = match self.command {
            Command::Connect => 1u8,
            Command::Bind => 2,
            Command::UdpAssociate => 3,
        };
        writer.write_all(&[0x05, cmd, 0x00]).await?;
        self.address.write(writer).await?;
        Ok(())
    }
}

#[derive(Debug, PartialEq, PartialOrd)]
pub enum CommandReply {
    Succeeded,
    GeneralSocksServerFailure,
    ConnectionNotAllowedByRuleset,
    NetworkUnreachable,
    HostUnreachable,
    ConnectionRefused,
    TtlExpired,
    CommandNotSupported,
    AddressTypeNotSupported,
}

impl CommandReply {
    pub fn from_u8(n: u8) -> Result<CommandReply> {
        Ok(match n {
            0 => CommandReply::Succeeded,
            1 => CommandReply::GeneralSocksServerFailure,
            2 => CommandReply::ConnectionNotAllowedByRuleset,
            3 => CommandReply::NetworkUnreachable,
            4 => CommandReply::HostUnreachable,
            5 => CommandReply::ConnectionRefused,
            6 => CommandReply::TtlExpired,
            7 => CommandReply::CommandNotSupported,
            8 => CommandReply::AddressTypeNotSupported,
            _ => return Err(Error::InvalidCommandReply(n)),
        })
    }
    pub fn to_u8(&self) -> u8 {
        match self {
            CommandReply::Succeeded => 0,
            CommandReply::GeneralSocksServerFailure => 1,
            CommandReply::ConnectionNotAllowedByRuleset => 2,
            CommandReply::NetworkUnreachable => 3,
            CommandReply::HostUnreachable => 4,
            CommandReply::ConnectionRefused => 5,
            CommandReply::TtlExpired => 6,
            CommandReply::CommandNotSupported => 7,
            CommandReply::AddressTypeNotSupported => 8,
        }
    }
}

#[derive(Debug)]
pub struct CommandResponse {
    pub reply: CommandReply,
    pub address: Address,
}

impl CommandResponse {
    pub fn success(address: Address) -> CommandResponse {
        CommandResponse {
            reply: CommandReply::Succeeded,
            address,
        }
    }
    pub fn reply_error(reply: CommandReply) -> CommandResponse {
        CommandResponse {
            reply,
            address: Default::default(),
        }
    }
    pub fn error(e: impl TryInto<io::Error>) -> CommandResponse {
        match e.try_into() {
            Ok(v) => {
                use io::ErrorKind;
                let reply = match v.kind() {
                    ErrorKind::ConnectionRefused => CommandReply::ConnectionRefused,
                    _ => CommandReply::GeneralSocksServerFailure,
                };
                CommandResponse {
                    reply,
                    address: Default::default(),
                }
            }
            Err(_) => CommandResponse {
                reply: CommandReply::GeneralSocksServerFailure,
                address: Default::default(),
            },
        }
    }
    pub async fn read(mut reader: impl AsyncRead + Unpin) -> Result<CommandResponse> {
        let buf = &mut [0u8; 3];
        reader.read_exact(buf).await?;
        if buf[0] != 5 {
            return Err(Error::InvalidVersion(buf[0]));
        }
        if buf[2] != 0 {
            return Err(Error::InvalidHandshake);
        }
        let reply = CommandReply::from_u8(buf[1])?;

        let address = Address::read(reader).await?;

        if reply != CommandReply::Succeeded {
            return Err(Error::CommandReply(reply));
        }

        Ok(CommandResponse { reply, address })
    }
    pub async fn write(&self, mut writer: impl AsyncWrite + Unpin) -> Result<()> {
        writer.write_all(&[0x05, self.reply.to_u8(), 0x00]).await?;
        self.address.write(writer).await?;
        Ok(())
    }
}