#[forbid(unsafe_code)]
use crate::read_exact;
use crate::socks4::{consts, ReplyError, Socks4Command};
use crate::util::target_addr::{TargetAddr, ToTargetAddr};
use crate::{Result, SocksError, SocksError::ReplySocks4Error};
use anyhow::Context;
use std::io;
use std::net::SocketAddr;
use std::net::ToSocketAddrs;
use std::pin::Pin;
use std::task::Poll;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
const MAX_ADDR_LEN: usize = 260;
#[derive(Debug)]
pub struct Socks4Stream<S: AsyncRead + AsyncWrite + Unpin> {
socket: S,
target_addr: Option<TargetAddr>,
}
impl<S> Socks4Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
pub fn use_stream(socket: S) -> Result<Self> {
let stream = Socks4Stream {
socket,
target_addr: None,
};
Ok(stream)
}
pub async fn request(
&mut self,
cmd: Socks4Command,
target_addr: TargetAddr,
resolve_locally: bool,
) -> Result<()> {
let resolved = if target_addr.is_domain() && resolve_locally {
target_addr.resolve_dns().await?
} else {
target_addr
};
self.target_addr = Some(resolved);
self.send_command_request(&cmd).await?;
self.read_command_request().await?;
Ok(())
}
async fn send_command_request(&mut self, cmd: &Socks4Command) -> Result<()> {
let mut packet = [0u8; MAX_ADDR_LEN];
packet[0] = consts::SOCKS4_VERSION;
packet[1] = cmd.as_u8();
match &self.target_addr {
Some(TargetAddr::Ip(SocketAddr::V4(addr))) => {
packet[2] = (addr.port() >> 8) as u8;
packet[3] = addr.port() as u8;
packet[4..8].copy_from_slice(&(addr.ip()).octets());
Ok(())
}
Some(TargetAddr::Ip(SocketAddr::V6(addr))) => {
error!("IPv6 are not supported: {:?}", addr);
Err(ReplySocks4Error(ReplyError::AddressTypeNotSupported))
}
Some(TargetAddr::Domain(domain, port)) => {
packet[2] = (port >> 8) as u8;
packet[3] = *port as u8;
packet[4..8].copy_from_slice(&[0, 0, 0, 1]);
let domain_bytes = domain.as_bytes();
let offset = 8 + domain_bytes.len();
packet[8..offset].copy_from_slice(domain_bytes);
Ok(())
}
_ => {
panic!("Unreachable case");
}
}?;
self.socket.write_all(&packet).await?;
Ok(())
}
#[rustfmt::skip]
async fn read_command_request(&mut self) -> Result<()> {
let [_, cd] = read_exact!(self.socket, [0u8; 2])?;
let reply = ReplyError::from_u8(cd);
match reply {
ReplyError::Succeeded => Ok(()),
_ => Err(SocksError::ReplySocks4Error(reply))
}
}
pub fn get_socket(self) -> S {
self.socket
}
pub fn get_socket_ref(&self) -> &S {
&self.socket
}
pub fn get_socket_mut(&mut self) -> &mut S {
&mut self.socket
}
}
impl Socks4Stream<TcpStream> {
pub async fn connect<T>(
socks_server: T,
target_addr: String,
target_port: u16,
resolve_locally: bool,
) -> Result<Self>
where
T: ToSocketAddrs,
{
Self::connect_raw(
Socks4Command::Connect,
socks_server,
target_addr,
target_port,
resolve_locally,
)
.await
}
pub async fn connect_raw<T>(
cmd: Socks4Command,
socks_server: T,
target_addr: String,
target_port: u16,
resolve_locally: bool,
) -> Result<Self>
where
T: ToSocketAddrs,
{
let socket = TcpStream::connect(
socks_server
.to_socket_addrs()?
.next()
.context("unreachable")?,
)
.await?;
info!("Connected @ {}", &socket.peer_addr()?);
let target_addr = (target_addr.as_str(), target_port)
.to_target_addr()
.context("Can't convert address to TargetAddr format")?;
let mut socks_stream = Self::use_stream(socket)?;
socks_stream
.request(cmd, target_addr, resolve_locally)
.await?;
Ok(socks_stream)
}
}
impl<S> AsyncRead for Socks4Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.socket).poll_read(context, buf)
}
}
impl<S> AsyncWrite for Socks4Stream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.socket).poll_write(context, buf)
}
fn poll_flush(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.socket).poll_flush(context)
}
fn poll_shutdown(
mut self: Pin<&mut Self>,
context: &mut std::task::Context,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.socket).poll_shutdown(context)
}
}
#[cfg(test)]
mod tests {
use super::*;
fn get_domain() -> String {
"www.google.com".to_string()
}
async fn get_humans_txt(socks: &mut Socks4Stream<TcpStream>) -> Option<String> {
let headers = format!(
"GET /humans.txt HTTP/1.1\r\n\
Host: {}\r\n\
User-Agent: fast-socks5/0.1.0\r\n\
Accept: */*\r\n\r\n",
get_domain()
);
socks
.write_all(headers.as_bytes())
.await
.expect("should successfully write");
let response = &mut [0u8; 2048];
socks
.read(response)
.await
.expect("should successfully read");
if response[0] == 0 {
response.copy_from_slice(&[0u8; 2048]);
socks
.read(response)
.await
.expect("should successfully read");
}
let response_str = String::from_utf8_lossy(response);
let response_body = response_str
.split("\n")
.into_iter()
.filter(|x| x.starts_with("Google"))
.last()
.map(|x| x.to_string());
response_body
}
fn assert_response_body(response_body: &String) {
let expected =
"Google is built by a large team of engineers, designers, researchers, robots, \
and others in many different sites across the globe. It is updated continuously, \
and built with more tools and technologies than we can shake a stick at. If you'd \
like to help us out, see careers.google.com.";
assert_eq!(expected, response_body);
}
#[tokio::test]
pub async fn test_use_stream() {
let tcp = TcpStream::connect("217.17.56.160:4145")
.await
.expect("should connect to remote");
let mut socks = Socks4Stream::use_stream(tcp).expect("should wrap to socks stream");
socks
.request(
Socks4Command::Connect,
TargetAddr::Domain(get_domain(), 80),
true,
)
.await
.expect("should send connect successfully");
let response_body = get_humans_txt(&mut socks)
.await
.expect("should have response_body");
assert_response_body(&response_body);
}
#[tokio::test]
pub async fn test_use_stream_local_resolve() {
let mut socks = Socks4Stream::connect("217.17.56.160:4145", get_domain(), 80, true)
.await
.expect("should connect successfully to socks4 server");
let response_body = get_humans_txt(&mut socks)
.await
.expect("should have response_body");
assert_response_body(&response_body);
}
}