use crate::io::{self, AsyncRead, AsyncReadExt};
use crate::v5::Address;
#[derive(Debug, Clone, PartialEq)]
pub enum Request {
Bind(Address),
Connect(Address),
Associate(Address),
}
#[rustfmt::skip]
impl Request {
const SOCKS5_CMD_CONNECT: u8 = 0x01;
const SOCKS5_CMD_BIND: u8 = 0x02;
const SOCKS5_CMD_ASSOCIATE: u8 = 0x03;
}
impl Request {
pub async fn from_async_read<R: AsyncRead + Unpin>(reader: &mut R) -> io::Result<Self> {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf).await?;
let command = buf[0];
let request = match command {
Self::SOCKS5_CMD_BIND => Self::Bind(Address::from_async_read(reader).await?),
Self::SOCKS5_CMD_CONNECT => Self::Connect(Address::from_async_read(reader).await?),
Self::SOCKS5_CMD_ASSOCIATE => Self::Associate(Address::from_async_read(reader).await?),
command => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("Invalid request command: {}", command),
));
}
};
Ok(request)
}
}
#[cfg(test)]
mod tests {
use crate::v5::{Address, Request};
use bytes::{BufMut, BytesMut};
use std::io::Cursor;
use tokio::io::BufReader;
#[tokio::test]
async fn test_request_from_async_read_connect_ipv4() {
let mut buffer = BytesMut::new();
buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
buffer.put_u8(0x00);
buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV4);
buffer.put_slice(&[192, 168, 1, 1]); buffer.put_u16(80);
let bytes = buffer.freeze();
let mut cursor = Cursor::new(bytes);
let mut reader = BufReader::new(&mut cursor);
let request = Request::from_async_read(&mut reader).await.unwrap();
match request {
Request::Connect(addr) => match addr {
Address::IPv4(socket_addr) => {
assert_eq!(socket_addr.ip().octets(), [192, 168, 1, 1]);
assert_eq!(socket_addr.port(), 80);
}
_ => panic!("Should be IPv4 address"),
},
_ => panic!("Should be Connect request"),
}
}
#[tokio::test]
async fn test_request_from_async_read_bind_ipv6() {
let mut buffer = BytesMut::new();
buffer.put_u8(Request::SOCKS5_CMD_BIND);
buffer.put_u8(0x00);
buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_IPV6);
buffer.put_slice(&[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); buffer.put_u16(443);
let bytes = buffer.freeze();
let mut cursor = Cursor::new(bytes);
let mut reader = BufReader::new(&mut cursor);
let request = Request::from_async_read(&mut reader).await.unwrap();
match request {
Request::Bind(addr) => match addr {
Address::IPv6(socket_addr) => {
assert_eq!(
socket_addr.ip().octets(),
[0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
);
assert_eq!(socket_addr.port(), 443);
}
_ => panic!("Should be IPv6 address"),
},
_ => panic!("Should be Bind request"),
}
}
#[tokio::test]
async fn test_request_from_async_read_associate_domain() {
let mut buffer = BytesMut::new();
buffer.put_u8(Request::SOCKS5_CMD_ASSOCIATE);
buffer.put_u8(0x00);
buffer.put_u8(Address::SOCKS5_ADDRESS_TYPE_DOMAIN_NAME);
buffer.put_u8(11); buffer.put_slice(b"example.com"); buffer.put_u16(8080);
let bytes = buffer.freeze();
let mut cursor = Cursor::new(bytes);
let mut reader = BufReader::new(&mut cursor);
let request = Request::from_async_read(&mut reader).await.unwrap();
match request {
Request::Associate(addr) => match addr {
Address::Domain(domain, port) => {
assert_eq!(**domain.as_bytes(), *b"example.com");
assert_eq!(port, 8080);
}
_ => panic!("Should be domain address"),
},
_ => panic!("Should be Associate request"),
}
}
#[tokio::test]
async fn test_request_from_async_read_invalid_command() {
let mut buffer = BytesMut::new();
buffer.put_u8(0xFF); buffer.put_u8(0x00);
let bytes = buffer.freeze();
let mut cursor = Cursor::new(bytes);
let mut reader = BufReader::new(&mut cursor);
let result = Request::from_async_read(&mut reader).await;
assert!(result.is_err());
if let Err(e) = result {
assert_eq!(e.kind(), std::io::ErrorKind::InvalidData);
}
}
#[tokio::test]
async fn test_request_from_async_read_incomplete_data() {
let mut buffer = BytesMut::new();
buffer.put_u8(Request::SOCKS5_CMD_CONNECT);
let bytes = buffer.freeze();
let mut cursor = Cursor::new(bytes);
let mut reader = BufReader::new(&mut cursor);
let result = Request::from_async_read(&mut reader).await;
assert!(result.is_err());
}
}