use std::{
convert::From,
fmt::{self, Debug, Display, Formatter},
io::{self, ErrorKind},
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
slice,
str::FromStr,
vec,
};
use bytes::{Buf, BufMut, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
pub use self::consts::{
SOCKS5_AUTH_METHOD_GSSAPI,
SOCKS5_AUTH_METHOD_NONE,
SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE,
SOCKS5_AUTH_METHOD_PASSWORD,
};
#[rustfmt::skip]
mod consts {
pub const SOCKS5_VERSION: u8 = 0x05;
pub const SOCKS5_AUTH_METHOD_NONE: u8 = 0x00;
pub const SOCKS5_AUTH_METHOD_GSSAPI: u8 = 0x01;
pub const SOCKS5_AUTH_METHOD_PASSWORD: u8 = 0x02;
pub const SOCKS5_AUTH_METHOD_NOT_ACCEPTABLE: u8 = 0xff;
pub const SOCKS5_CMD_TCP_CONNECT: u8 = 0x01;
pub const SOCKS5_CMD_TCP_BIND: u8 = 0x02;
pub const SOCKS5_CMD_UDP_ASSOCIATE: u8 = 0x03;
pub const SOCKS5_ADDR_TYPE_IPV4: u8 = 0x01;
pub const SOCKS5_ADDR_TYPE_DOMAIN_NAME: u8 = 0x03;
pub const SOCKS5_ADDR_TYPE_IPV6: u8 = 0x04;
pub const SOCKS5_REPLY_SUCCEEDED: u8 = 0x00;
pub const SOCKS5_REPLY_GENERAL_FAILURE: u8 = 0x01;
pub const SOCKS5_REPLY_CONNECTION_NOT_ALLOWED: u8 = 0x02;
pub const SOCKS5_REPLY_NETWORK_UNREACHABLE: u8 = 0x03;
pub const SOCKS5_REPLY_HOST_UNREACHABLE: u8 = 0x04;
pub const SOCKS5_REPLY_CONNECTION_REFUSED: u8 = 0x05;
pub const SOCKS5_REPLY_TTL_EXPIRED: u8 = 0x06;
pub const SOCKS5_REPLY_COMMAND_NOT_SUPPORTED: u8 = 0x07;
pub const SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED: u8 = 0x08;
}
#[derive(Clone, Debug, Copy)]
pub enum Command {
TcpConnect,
TcpBind,
UdpAssociate,
}
impl Command {
#[inline]
#[rustfmt::skip]
fn as_u8(self) -> u8 {
match self {
Command::TcpConnect => consts::SOCKS5_CMD_TCP_CONNECT,
Command::TcpBind => consts::SOCKS5_CMD_TCP_BIND,
Command::UdpAssociate => consts::SOCKS5_CMD_UDP_ASSOCIATE,
}
}
#[inline]
#[rustfmt::skip]
fn from_u8(code: u8) -> Option<Command> {
match code {
consts::SOCKS5_CMD_TCP_CONNECT => Some(Command::TcpConnect),
consts::SOCKS5_CMD_TCP_BIND => Some(Command::TcpBind),
consts::SOCKS5_CMD_UDP_ASSOCIATE => Some(Command::UdpAssociate),
_ => None,
}
}
}
#[derive(Clone, Debug, Copy)]
pub enum Reply {
Succeeded,
GeneralFailure,
ConnectionNotAllowed,
NetworkUnreachable,
HostUnreachable,
ConnectionRefused,
TtlExpired,
CommandNotSupported,
AddressTypeNotSupported,
OtherReply(u8),
}
impl Reply {
#[inline]
#[rustfmt::skip]
pub fn as_u8(self) -> u8 {
match self {
Reply::Succeeded => consts::SOCKS5_REPLY_SUCCEEDED,
Reply::GeneralFailure => consts::SOCKS5_REPLY_GENERAL_FAILURE,
Reply::ConnectionNotAllowed => consts::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED,
Reply::NetworkUnreachable => consts::SOCKS5_REPLY_NETWORK_UNREACHABLE,
Reply::HostUnreachable => consts::SOCKS5_REPLY_HOST_UNREACHABLE,
Reply::ConnectionRefused => consts::SOCKS5_REPLY_CONNECTION_REFUSED,
Reply::TtlExpired => consts::SOCKS5_REPLY_TTL_EXPIRED,
Reply::CommandNotSupported => consts::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED,
Reply::AddressTypeNotSupported => consts::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED,
Reply::OtherReply(c) => c,
}
}
#[inline]
#[rustfmt::skip]
pub fn from_u8(code: u8) -> Reply {
match code {
consts::SOCKS5_REPLY_SUCCEEDED => Reply::Succeeded,
consts::SOCKS5_REPLY_GENERAL_FAILURE => Reply::GeneralFailure,
consts::SOCKS5_REPLY_CONNECTION_NOT_ALLOWED => Reply::ConnectionNotAllowed,
consts::SOCKS5_REPLY_NETWORK_UNREACHABLE => Reply::NetworkUnreachable,
consts::SOCKS5_REPLY_HOST_UNREACHABLE => Reply::HostUnreachable,
consts::SOCKS5_REPLY_CONNECTION_REFUSED => Reply::ConnectionRefused,
consts::SOCKS5_REPLY_TTL_EXPIRED => Reply::TtlExpired,
consts::SOCKS5_REPLY_COMMAND_NOT_SUPPORTED => Reply::CommandNotSupported,
consts::SOCKS5_REPLY_ADDRESS_TYPE_NOT_SUPPORTED => Reply::AddressTypeNotSupported,
_ => Reply::OtherReply(code),
}
}
}
impl fmt::Display for Reply {
#[rustfmt::skip]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Reply::Succeeded => write!(f, "Succeeded"),
Reply::AddressTypeNotSupported => write!(f, "Address type not supported"),
Reply::CommandNotSupported => write!(f, "Command not supported"),
Reply::ConnectionNotAllowed => write!(f, "Connection not allowed"),
Reply::ConnectionRefused => write!(f, "Connection refused"),
Reply::GeneralFailure => write!(f, "General failure"),
Reply::HostUnreachable => write!(f, "Host unreachable"),
Reply::NetworkUnreachable => write!(f, "Network unreachable"),
Reply::OtherReply(u) => write!(f, "Other reply ({u})"),
Reply::TtlExpired => write!(f, "TTL expired"),
}
}
}
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("{0}")]
IoError(#[from] io::Error),
#[error("address type {0:#x} not supported")]
AddressTypeNotSupported(u8),
#[error("address domain name must be UTF-8 encoding")]
AddressDomainInvalidEncoding,
#[error("unsupported socks version {0:#x}")]
UnsupportedSocksVersion(u8),
#[error("unsupported command {0:#x}")]
UnsupportedCommand(u8),
#[error("unsupported username/password authentication version {0:#x}")]
UnsupportedPasswdAuthVersion(u8),
#[error("username/password authentication invalid request")]
PasswdAuthInvalidRequest,
#[error("{0}")]
Reply(Reply),
}
impl From<Error> for io::Error {
fn from(err: Error) -> io::Error {
match err {
Error::IoError(err) => err,
e => io::Error::new(ErrorKind::Other, e),
}
}
}
impl Error {
pub fn as_reply(&self) -> Reply {
match *self {
Error::IoError(ref err) => match err.kind() {
ErrorKind::ConnectionRefused => Reply::ConnectionRefused,
_ => Reply::GeneralFailure,
},
Error::AddressTypeNotSupported(..) => Reply::AddressTypeNotSupported,
Error::AddressDomainInvalidEncoding => Reply::GeneralFailure,
Error::UnsupportedSocksVersion(..) => Reply::GeneralFailure,
Error::UnsupportedCommand(..) => Reply::CommandNotSupported,
Error::UnsupportedPasswdAuthVersion(..) => Reply::GeneralFailure,
Error::PasswdAuthInvalidRequest => Reply::GeneralFailure,
Error::Reply(r) => r,
}
}
}
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Address {
SocketAddress(SocketAddr),
DomainNameAddress(String, u16),
}
impl Address {
pub fn read_cursor<T: AsRef<[u8]>>(cur: &mut io::Cursor<T>) -> Result<Address, Error> {
if cur.remaining() < 2 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into());
}
let atyp = cur.get_u8();
match atyp {
consts::SOCKS5_ADDR_TYPE_IPV4 => {
if cur.remaining() < 4 + 2 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into());
}
let addr = Ipv4Addr::from(cur.get_u32());
let port = cur.get_u16();
Ok(Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(addr, port))))
}
consts::SOCKS5_ADDR_TYPE_IPV6 => {
if cur.remaining() < 16 + 2 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid buf").into());
}
let addr = Ipv6Addr::from(cur.get_u128());
let port = cur.get_u16();
Ok(Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(
addr, port, 0, 0,
))))
}
consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
let domain_len = cur.get_u8() as usize;
if cur.remaining() < domain_len {
return Err(Error::AddressDomainInvalidEncoding);
}
let mut buf = vec![0u8; domain_len];
cur.copy_to_slice(&mut buf);
let port = cur.get_u16();
let addr = String::from_utf8(buf).map_err(|_| Error::AddressDomainInvalidEncoding)?;
Ok(Address::DomainNameAddress(addr, port))
}
_ => Err(Error::AddressTypeNotSupported(atyp)),
}
}
pub async fn read_from<R>(stream: &mut R) -> Result<Address, Error>
where
R: AsyncRead + Unpin,
{
let mut addr_type_buf = [0u8; 1];
let _ = stream.read_exact(&mut addr_type_buf).await?;
let addr_type = addr_type_buf[0];
match addr_type {
consts::SOCKS5_ADDR_TYPE_IPV4 => {
let mut buf = [0u8; 6];
let _ = stream.read_exact(&mut buf).await?;
let v4addr = Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
let port = u16::from_be_bytes([buf[4], buf[5]]);
Ok(Address::SocketAddress(SocketAddr::V4(SocketAddrV4::new(v4addr, port))))
}
consts::SOCKS5_ADDR_TYPE_IPV6 => {
let mut buf = [0u16; 9];
let bytes_buf = unsafe { slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut _, 18) };
let _ = stream.read_exact(bytes_buf).await?;
let v6addr = Ipv6Addr::new(
u16::from_be(buf[0]),
u16::from_be(buf[1]),
u16::from_be(buf[2]),
u16::from_be(buf[3]),
u16::from_be(buf[4]),
u16::from_be(buf[5]),
u16::from_be(buf[6]),
u16::from_be(buf[7]),
);
let port = u16::from_be(buf[8]);
Ok(Address::SocketAddress(SocketAddr::V6(SocketAddrV6::new(
v6addr, port, 0, 0,
))))
}
consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME => {
let mut length_buf = [0u8; 1];
let _ = stream.read_exact(&mut length_buf).await?;
let length = length_buf[0] as usize;
let buf_length = length + 2;
let mut raw_addr = vec![0u8; buf_length];
let _ = stream.read_exact(&mut raw_addr).await?;
let raw_port = &raw_addr[length..];
let port = u16::from_be_bytes([raw_port[0], raw_port[1]]);
raw_addr.truncate(length);
let addr = match String::from_utf8(raw_addr) {
Ok(addr) => addr,
Err(..) => return Err(Error::AddressDomainInvalidEncoding),
};
Ok(Address::DomainNameAddress(addr, port))
}
_ => {
Err(Error::AddressTypeNotSupported(addr_type))
}
}
}
#[inline]
pub async fn write_to<W>(&self, writer: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
writer.write_all(&buf).await
}
#[inline]
pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
write_address(self, buf)
}
#[inline]
pub fn serialized_len(&self) -> usize {
get_addr_len(self)
}
#[inline]
pub fn max_serialized_len() -> usize {
1 + 1 + u8::MAX as usize + 2 }
pub fn port(&self) -> u16 {
match *self {
Address::SocketAddress(addr) => addr.port(),
Address::DomainNameAddress(.., port) => port,
}
}
pub fn host(&self) -> String {
match *self {
Address::SocketAddress(ref addr) => addr.ip().to_string(),
Address::DomainNameAddress(ref domain, ..) => domain.to_owned(),
}
}
}
impl Debug for Address {
#[inline]
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match *self {
Address::SocketAddress(ref addr) => write!(f, "{addr}"),
Address::DomainNameAddress(ref addr, ref port) => write!(f, "{addr}:{port}"),
}
}
}
impl fmt::Display for Address {
#[inline]
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match *self {
Address::SocketAddress(ref addr) => write!(f, "{addr}"),
Address::DomainNameAddress(ref addr, ref port) => write!(f, "{addr}:{port}"),
}
}
}
impl ToSocketAddrs for Address {
type Iter = vec::IntoIter<SocketAddr>;
fn to_socket_addrs(&self) -> io::Result<vec::IntoIter<SocketAddr>> {
match self.clone() {
Address::SocketAddress(addr) => Ok(vec![addr].into_iter()),
Address::DomainNameAddress(addr, port) => (&addr[..], port).to_socket_addrs(),
}
}
}
impl From<SocketAddr> for Address {
fn from(s: SocketAddr) -> Address {
Address::SocketAddress(s)
}
}
impl From<(String, u16)> for Address {
fn from((dn, port): (String, u16)) -> Address {
Address::DomainNameAddress(dn, port)
}
}
impl From<&Address> for Address {
fn from(addr: &Address) -> Address {
addr.clone()
}
}
#[derive(Debug)]
pub struct AddressError;
impl Display for AddressError {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.write_str("invalid Address")
}
}
impl std::error::Error for AddressError {}
impl FromStr for Address {
type Err = AddressError;
fn from_str(s: &str) -> Result<Address, AddressError> {
match s.parse::<SocketAddr>() {
Ok(addr) => Ok(Address::SocketAddress(addr)),
Err(..) => {
let mut sp = s.split(':');
match (sp.next(), sp.next()) {
(Some(dn), Some(port)) => match port.parse::<u16>() {
Ok(port) => Ok(Address::DomainNameAddress(dn.to_owned(), port)),
Err(..) => Err(AddressError),
},
(Some(dn), None) => {
Ok(Address::DomainNameAddress(dn.to_owned(), 80))
}
_ => Err(AddressError),
}
}
}
}
}
fn write_ipv4_address<B: BufMut>(addr: &SocketAddrV4, buf: &mut B) {
buf.put_u8(consts::SOCKS5_ADDR_TYPE_IPV4); buf.put_slice(&addr.ip().octets()); buf.put_u16(addr.port()); }
fn write_ipv6_address<B: BufMut>(addr: &SocketAddrV6, buf: &mut B) {
buf.put_u8(consts::SOCKS5_ADDR_TYPE_IPV6); for seg in &addr.ip().segments() {
buf.put_u16(*seg); }
buf.put_u16(addr.port()); }
fn write_domain_name_address<B: BufMut>(dnaddr: &str, port: u16, buf: &mut B) {
assert!(dnaddr.len() <= u8::MAX as usize);
buf.put_u8(consts::SOCKS5_ADDR_TYPE_DOMAIN_NAME);
assert!(
dnaddr.len() <= u8::MAX as usize,
"domain name length must be smaller than 256"
);
buf.put_u8(dnaddr.len() as u8);
buf.put_slice(dnaddr[..].as_bytes());
buf.put_u16(port);
}
fn write_socket_address<B: BufMut>(addr: &SocketAddr, buf: &mut B) {
match *addr {
SocketAddr::V4(ref addr) => write_ipv4_address(addr, buf),
SocketAddr::V6(ref addr) => write_ipv6_address(addr, buf),
}
}
fn write_address<B: BufMut>(addr: &Address, buf: &mut B) {
match *addr {
Address::SocketAddress(ref addr) => write_socket_address(addr, buf),
Address::DomainNameAddress(ref dnaddr, ref port) => write_domain_name_address(dnaddr, *port, buf),
}
}
#[inline]
fn get_addr_len(atyp: &Address) -> usize {
match *atyp {
Address::SocketAddress(SocketAddr::V4(..)) => 1 + 4 + 2,
Address::SocketAddress(SocketAddr::V6(..)) => 1 + 8 * 2 + 2,
Address::DomainNameAddress(ref dmname, _) => 1 + 1 + dmname.len() + 2,
}
}
#[derive(Clone, Debug)]
pub struct TcpRequestHeader {
pub command: Command,
pub address: Address,
}
impl TcpRequestHeader {
pub fn new(cmd: Command, addr: Address) -> TcpRequestHeader {
TcpRequestHeader {
command: cmd,
address: addr,
}
}
pub async fn read_from<R>(r: &mut R) -> Result<TcpRequestHeader, Error>
where
R: AsyncRead + Unpin,
{
let mut buf = [0u8; 3];
let _ = r.read_exact(&mut buf).await?;
let ver = buf[0];
if ver != consts::SOCKS5_VERSION {
return Err(Error::UnsupportedSocksVersion(ver));
}
let cmd = buf[1];
let command = match Command::from_u8(cmd) {
Some(c) => c,
None => return Err(Error::UnsupportedCommand(cmd)),
};
let address = Address::read_from(r).await?;
Ok(TcpRequestHeader { command, address })
}
pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
w.write_all(&buf).await
}
pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
let TcpRequestHeader {
ref address,
ref command,
} = *self;
buf.put_slice(&[consts::SOCKS5_VERSION, command.as_u8(), 0x00]);
address.write_to_buf(buf);
}
#[inline]
pub fn serialized_len(&self) -> usize {
self.address.serialized_len() + 3
}
}
#[derive(Clone, Debug)]
pub struct TcpResponseHeader {
pub reply: Reply,
pub address: Address,
}
impl TcpResponseHeader {
pub fn new(reply: Reply, address: Address) -> TcpResponseHeader {
TcpResponseHeader { reply, address }
}
pub async fn read_from<R>(r: &mut R) -> Result<TcpResponseHeader, Error>
where
R: AsyncRead + Unpin,
{
let mut buf = [0u8; 3];
let _ = r.read_exact(&mut buf).await?;
let ver = buf[0];
let reply_code = buf[1];
if ver != consts::SOCKS5_VERSION {
return Err(Error::UnsupportedSocksVersion(ver));
}
let address = Address::read_from(r).await?;
Ok(TcpResponseHeader {
reply: Reply::from_u8(reply_code),
address,
})
}
pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
w.write_all(&buf).await
}
pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
let TcpResponseHeader { ref reply, ref address } = *self;
buf.put_slice(&[consts::SOCKS5_VERSION, reply.as_u8(), 0x00]);
address.write_to_buf(buf);
}
#[inline]
pub fn serialized_len(&self) -> usize {
self.address.serialized_len() + 3
}
}
#[derive(Clone, Debug)]
pub struct HandshakeRequest {
pub methods: Vec<u8>,
}
impl HandshakeRequest {
pub fn new(methods: Vec<u8>) -> HandshakeRequest {
HandshakeRequest { methods }
}
pub async fn read_from<R>(r: &mut R) -> Result<HandshakeRequest, Error>
where
R: AsyncRead + Unpin,
{
let mut buf = [0u8; 2];
let _ = r.read_exact(&mut buf).await?;
let ver = buf[0];
let nmet = buf[1];
if ver != consts::SOCKS5_VERSION {
return Err(Error::UnsupportedSocksVersion(ver));
}
let mut methods = vec![0u8; nmet as usize];
let _ = r.read_exact(&mut methods).await?;
Ok(HandshakeRequest { methods })
}
pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
w.write_all(&buf).await
}
pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
let HandshakeRequest { ref methods } = *self;
buf.put_slice(&[consts::SOCKS5_VERSION, methods.len() as u8]);
buf.put_slice(methods);
}
pub fn serialized_len(&self) -> usize {
2 + self.methods.len()
}
}
#[derive(Clone, Debug, Copy)]
pub struct HandshakeResponse {
pub chosen_method: u8,
}
impl HandshakeResponse {
pub fn new(cm: u8) -> HandshakeResponse {
HandshakeResponse { chosen_method: cm }
}
pub async fn read_from<R>(r: &mut R) -> Result<HandshakeResponse, Error>
where
R: AsyncRead + Unpin,
{
let mut buf = [0u8; 2];
let _ = r.read_exact(&mut buf).await?;
let ver = buf[0];
let met = buf[1];
if ver != consts::SOCKS5_VERSION {
Err(Error::UnsupportedSocksVersion(ver))
} else {
Ok(HandshakeResponse { chosen_method: met })
}
}
pub async fn write_to<W>(self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
w.write_all(&buf).await
}
pub fn write_to_buf<B: BufMut>(self, buf: &mut B) {
buf.put_slice(&[consts::SOCKS5_VERSION, self.chosen_method]);
}
pub fn serialized_len(self) -> usize {
2
}
}
#[derive(Clone, Debug)]
pub struct UdpAssociateHeader {
pub frag: u8,
pub address: Address,
}
impl UdpAssociateHeader {
pub fn new(frag: u8, address: Address) -> UdpAssociateHeader {
UdpAssociateHeader { frag, address }
}
pub async fn read_from<R>(r: &mut R) -> Result<UdpAssociateHeader, Error>
where
R: AsyncRead + Unpin,
{
let mut buf = [0u8; 3];
let _ = r.read_exact(&mut buf).await?;
let frag = buf[2];
let address = Address::read_from(r).await?;
Ok(UdpAssociateHeader::new(frag, address))
}
pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
w.write_all(&buf).await
}
pub fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
let UdpAssociateHeader { ref frag, ref address } = *self;
buf.put_slice(&[0x00, 0x00, *frag]);
address.write_to_buf(buf);
}
#[inline]
pub fn serialized_len(&self) -> usize {
3 + self.address.serialized_len()
}
}
pub struct PasswdAuthRequest {
pub uname: Vec<u8>,
pub passwd: Vec<u8>,
}
impl PasswdAuthRequest {
pub fn new<U, P>(uname: U, passwd: P) -> PasswdAuthRequest
where
U: Into<Vec<u8>>,
P: Into<Vec<u8>>,
{
let uname = uname.into();
let passwd = passwd.into();
assert!(
!uname.is_empty()
&& uname.len() <= u8::MAX as usize
&& !passwd.is_empty()
&& passwd.len() <= u8::MAX as usize
);
PasswdAuthRequest { uname, passwd }
}
pub async fn read_from<R>(r: &mut R) -> Result<PasswdAuthRequest, Error>
where
R: AsyncRead + Unpin,
{
let mut ver_buf = [0u8; 1];
let _ = r.read_exact(&mut ver_buf).await?;
if ver_buf[0] != 0x01 {
return Err(Error::UnsupportedPasswdAuthVersion(ver_buf[0]));
}
let mut ulen_buf = [0u8; 1];
let _ = r.read_exact(&mut ulen_buf).await?;
let ulen = ulen_buf[0] as usize;
if ulen == 0 {
return Err(Error::PasswdAuthInvalidRequest);
}
let mut uname = vec![0u8; ulen];
if ulen > 0 {
let _ = r.read_exact(&mut uname).await?;
}
let mut plen_buf = [0u8; 1];
let _ = r.read_exact(&mut plen_buf).await?;
let plen = plen_buf[0] as usize;
if plen == 0 {
return Err(Error::PasswdAuthInvalidRequest);
}
let mut passwd = vec![0u8; plen];
if plen > 0 {
let _ = r.read_exact(&mut passwd).await?;
}
Ok(PasswdAuthRequest { uname, passwd })
}
pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
w.write_all(&buf).await
}
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
buf.put_u8(0x01);
buf.put_u8(self.uname.len() as u8);
buf.put_slice(&self.uname);
buf.put_u8(self.passwd.len() as u8);
buf.put_slice(&self.passwd);
}
#[inline]
pub fn serialized_len(&self) -> usize {
1 + 1 + self.uname.len() + 1 + self.passwd.len()
}
}
pub struct PasswdAuthResponse {
pub status: u8,
}
impl PasswdAuthResponse {
pub fn new(status: u8) -> PasswdAuthResponse {
PasswdAuthResponse { status }
}
pub async fn read_from<R>(r: &mut R) -> Result<PasswdAuthResponse, Error>
where
R: AsyncRead + Unpin,
{
let mut buf = [0u8; 2];
let _ = r.read_exact(&mut buf).await;
if buf[0] != 0x01 {
return Err(Error::UnsupportedPasswdAuthVersion(buf[0]));
}
Ok(PasswdAuthResponse { status: buf[1] })
}
pub async fn write_to<W>(&self, w: &mut W) -> io::Result<()>
where
W: AsyncWrite + Unpin,
{
let mut buf = BytesMut::with_capacity(self.serialized_len());
self.write_to_buf(&mut buf);
w.write_all(&buf).await
}
fn write_to_buf<B: BufMut>(&self, buf: &mut B) {
buf.put_u8(0x01);
buf.put_u8(self.status);
}
#[inline]
pub fn serialized_len(&self) -> usize {
2
}
}