use super::Error;
use bytes::Bytes;
use std::net::Ipv4Addr;
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::trace;
#[inline]
pub async fn read_request<R>(reader: &mut R) -> Result<(u8, Bytes, u16), Error>
where
R: AsyncBufRead + Unpin,
{
let command = reader
.read_u8()
.await
.map_err(|e| Error::ProcessSocksRequest("read command", e))?;
let rport = reader
.read_u16()
.await
.map_err(|e| Error::ProcessSocksRequest("read port", e))?;
let ip = reader
.read_u32()
.await
.map_err(|e| Error::ProcessSocksRequest("read ip", e))?;
let mut user_id = Vec::new();
reader
.read_until(0, &mut user_id)
.await
.map_err(|e| Error::ProcessSocksRequest("read user id", e))?;
user_id.pop();
trace!("User ID: {user_id:?}");
let rhost = if ip >> 24 == 0 {
let mut domain = Vec::new();
reader
.read_until(0, &mut domain)
.await
.map_err(|e| Error::ProcessSocksRequest("read domain", e))?;
domain.pop();
Bytes::from(domain)
} else {
Ipv4Addr::from(ip).to_string().into()
};
Ok((command, rhost, rport))
}
#[inline]
pub async fn write_response<W>(writer: &mut W, response: u8) -> Result<(), Error>
where
W: AsyncWrite + Unpin,
{
let buf = [0x00, response, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
writer
.write_all(&buf)
.await
.map_err(|e| Error::ProcessSocksRequest("write response", e))?;
writer
.flush()
.await
.map_err(|e| Error::ProcessSocksRequest("flush", e))?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// A request with a literal IPv4 address yields its dotted-quad text.
    #[tokio::test]
    async fn test_read_request_ip() {
        crate::tests::setup_logging();
        // command 0x01, port 0x0050, address 127.0.0.1, user ID "a" + NUL
        let raw: [u8; 9] = [0x01, 0x00, 0x50, 0x7f, 0x00, 0x00, 0x01, 0x61, 0x00];
        let mut input = Cursor::new(raw);
        let (cmd, host, port) = read_request(&mut input).await.unwrap();
        assert_eq!(cmd, 0x01);
        assert_eq!(host, "127.0.0.1");
        assert_eq!(port, 0x50);
    }

    /// A request whose address has a zero leading octet takes the host from
    /// the domain name that follows the user ID.
    #[tokio::test]
    async fn test_read_request_domain() {
        crate::tests::setup_logging();
        // command 0x01, port 0x0050, address 0.0.0.1, user ID "a" + NUL
        let mut raw: Vec<u8> = vec![0x01, 0x00, 0x50, 0x00, 0x00, 0x00, 0x01, 0x61, 0x00];
        // followed by the NUL-terminated domain name
        raw.extend_from_slice(b"www.example.com");
        raw.push(0x00);
        let mut input = Cursor::new(raw);
        let (cmd, host, port) = read_request(&mut input).await.unwrap();
        assert_eq!(cmd, 0x01);
        assert_eq!(host, "www.example.com");
        assert_eq!(port, 0x50);
    }

    /// The reply is eight bytes, with the response code in the second one.
    #[tokio::test]
    async fn test_write_response() {
        crate::tests::setup_logging();
        let mut output = Cursor::new(Vec::new());
        write_response(&mut output, 0x5a).await.unwrap();
        let expected: [u8; 8] = [0x00, 0x5a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        assert_eq!(output.get_ref().as_slice(), &expected[..]);
    }
}