use crate::general::ConnectionTimeouts;
use crate::proxy::ProxyConstructor;
use crate::clients::socks4::{ErrorKind, Command};
use byteorder::{ByteOrder, BigEndian};
use tokio::net::TcpStream;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout;
use std::pin::Pin;
use core::task::{Poll, Context};
use std::net::SocketAddrV4;
use std::str::FromStr;
use std::borrow::Cow;
use std::io;
/// Settings used to perform a SOCKS4 handshake: the destination to request,
/// an identification string, and the I/O time limits applied in `connect`.
pub struct Socks4General {
/// Destination IPv4 address and port that `connect` encodes into the request.
dest_addr: SocketAddrV4,
/// Identification string supplied by the caller (presumably the SOCKS4
/// user id — confirm against `connect`).
ident: Cow<'static, str>,
/// Time limits; `connect` reads `write_timeout` and `read_timeout` from it.
timeouts: ConnectionTimeouts
}
/// Reasons why parsing a `Socks4General` from a string can fail.
///
/// Implements `Display` and `std::error::Error` so it can be propagated with
/// `?` into `Box<dyn Error>` and shown to users.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StrParsingError {
    /// One of the space-separated fields is missing.
    SyntaxError,
    /// The first field is not a valid `<ipv4>:<port>` socket address.
    InvalidAddr,
    /// The third field could not be parsed as connection timeouts.
    InvalidTimeouts,
}

impl std::fmt::Display for StrParsingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            StrParsingError::SyntaxError => {
                "missing field, expected `<address> <ident> <timeouts>`"
            }
            StrParsingError::InvalidAddr => {
                "invalid destination address, expected `<ipv4>:<port>`"
            }
            StrParsingError::InvalidTimeouts => "invalid connection timeouts",
        };
        f.write_str(message)
    }
}

impl std::error::Error for StrParsingError {}
/// Stream handed back by `Socks4General::connect` once the proxy has accepted
/// the request. Its `AsyncRead`/`AsyncWrite` impls delegate to the wrapped
/// `TcpStream`.
pub struct S4GeneralStream {
/// The TCP connection the handshake was performed on.
wrapped_stream: TcpStream
}
impl Socks4General {
pub fn new(dest_addr: SocketAddrV4, ident: Cow<'static, str>,
timeouts: ConnectionTimeouts)
-> Socks4General
{
Socks4General { dest_addr, ident, timeouts }
}
}
impl FromStr for Socks4General {
    type Err = StrParsingError;

    /// Parses a string of the form `<address> <ident> <timeouts>`, with the
    /// fields separated by a single space.
    ///
    /// Fields are consumed left to right, so the first problem encountered
    /// decides the error:
    /// * a missing field yields `StrParsingError::SyntaxError`;
    /// * an address that is not a valid `SocketAddrV4` yields
    ///   `StrParsingError::InvalidAddr`;
    /// * timeouts rejected by `ConnectionTimeouts`' own parser yield
    ///   `StrParsingError::InvalidTimeouts`.
    ///
    /// Anything after the third field is ignored.
    fn from_str(s: &str) -> Result<Socks4General, Self::Err> {
        let mut fields = s.split(' ');

        let raw_addr = fields.next().ok_or(StrParsingError::SyntaxError)?;
        let dest_addr: SocketAddrV4 = raw_addr
            .parse()
            .map_err(|_| StrParsingError::InvalidAddr)?;

        let ident = fields.next().ok_or(StrParsingError::SyntaxError)?;

        let raw_timeouts = fields.next().ok_or(StrParsingError::SyntaxError)?;
        let timeouts: ConnectionTimeouts = raw_timeouts
            .parse()
            .map_err(|_| StrParsingError::InvalidTimeouts)?;

        // The ident is copied so the resulting value does not borrow from `s`.
        let ident = Cow::Owned(ident.to_owned());
        Ok(Socks4General::new(dest_addr, ident, timeouts))
    }
}
#[async_trait::async_trait]
impl ProxyConstructor for Socks4General {
type ProxyStream = S4GeneralStream;
type Stream = TcpStream;
type ErrorKind = ErrorKind;
async fn connect(&mut self, mut stream: Self::Stream)
-> Result<Self::ProxyStream, Self::ErrorKind>
{
let buf_len = 1 + 1 + 2 + 4 + self.ident.len() + 1;
let mut buf = Vec::with_capacity(buf_len);
buf.push(4);
buf.push(Command::TcpConnectionEstablishment as u8);
buf.push(0);
buf.push(0);
BigEndian::write_u16(&mut buf[2..4], self.dest_addr.port());
buf.push(0);
buf.push(0);
buf.push(0);
buf.push(0);
BigEndian::write_u32(&mut buf[4..8], (*self.dest_addr.ip()).into());
buf.push(0);
let future = stream.write_all(&buf);
let future = timeout(self.timeouts.write_timeout, future);
let _ = future.await.map_err(|_| ErrorKind::OperationTimeoutReached)?
.map_err(|e| ErrorKind::IOError(e))?;
let future = stream.read(&mut buf);
let future = timeout(self.timeouts.read_timeout, future);
let read_bytes = future.await.map_err(|_| ErrorKind::OperationTimeoutReached)?
.map_err(|e| ErrorKind::IOError(e))?;
if read_bytes != 8 {
return Err(ErrorKind::BadBuffer)
}
match buf[1] {
0x5a => Ok(S4GeneralStream { wrapped_stream: stream }),
0x5b => Err(ErrorKind::RequestDenied),
0x5c => Err(ErrorKind::IdentIsUnavailable),
0x5d => Err(ErrorKind::BadIdent),
_ => Err(ErrorKind::BadBuffer)
}
}
}
impl AsyncRead for S4GeneralStream {
    /// Delegates the read to the wrapped `TcpStream`.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
        -> Poll<io::Result<usize>>
    {
        // Re-pin the inner stream so its own `poll_read` can be called.
        let this = self.get_mut();
        Pin::new(&mut this.wrapped_stream).poll_read(cx, buf)
    }
}
impl AsyncWrite for S4GeneralStream {
    /// Delegates the write to the wrapped `TcpStream`.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
        -> Poll<io::Result<usize>>
    {
        let this = self.get_mut();
        Pin::new(&mut this.wrapped_stream).poll_write(cx, buf)
    }

    /// Delegates the flush to the wrapped `TcpStream`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>)
        -> Poll<io::Result<()>>
    {
        let this = self.get_mut();
        Pin::new(&mut this.wrapped_stream).poll_flush(cx)
    }

    /// Delegates the shutdown to the wrapped `TcpStream`.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>)
        -> Poll<io::Result<()>>
    {
        let this = self.get_mut();
        Pin::new(&mut this.wrapped_stream).poll_shutdown(cx)
    }
}
/// Unwraps the proxied stream back into the underlying `TcpStream`.
///
/// Implemented as `From` rather than `Into`: the standard library's blanket
/// impl still provides `Into<TcpStream> for S4GeneralStream`, so existing
/// `.into()` callers keep working, and `TcpStream::from(stream)` becomes
/// available as well.
impl From<S4GeneralStream> for TcpStream {
    fn from(stream: S4GeneralStream) -> TcpStream {
        stream.wrapped_stream
    }
}