use socks5_proto::{Address, Reply, Response};
use std::{
io::Error,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
net::TcpStream,
};
/// Type-state markers used as the `S` parameter of `Bind<S>`.
///
/// Each marker is a zero-sized unit struct; it carries no data and exists only so the
/// compiler can track which reply a `Bind` connection must send next.
pub mod state {
/// Marker: no reply has been sent yet; `reply` sends the first one.
#[derive(Debug)]
pub struct NeedFirstReply;
/// Marker: the first reply has been sent; `reply` sends the second one.
#[derive(Debug)]
pub struct NeedSecondReply;
/// Marker: both replies have been sent; the connection can be read from and written to.
#[derive(Debug)]
pub struct Ready;
}
/// A connection serving a SOCKS5 BIND request, tracked with a compile-time type state.
///
/// `S` is one of the markers in `state`. Replies advance the state
/// (`NeedFirstReply` -> `NeedSecondReply` -> `Ready`). Only `Bind<state::Ready>`
/// implements `AsyncRead`/`AsyncWrite`, so data cannot be relayed before both replies are sent.
#[derive(Debug)]
pub struct Bind<S> {
// The underlying TCP stream: replies are written to it, and in the `Ready` state all
// reads/writes are forwarded to it.
stream: TcpStream,
// Zero-sized field that only records the type state `S`; it holds nothing at runtime.
_state: PhantomData<S>,
}
impl Bind<state::NeedFirstReply> {
    /// Writes the first reply (`reply` code plus `addr`) to the client stream and moves
    /// the connection into the `NeedSecondReply` state.
    ///
    /// # Errors
    ///
    /// If the response cannot be written, the I/O error is returned together with the
    /// underlying `TcpStream`, so the caller still owns the socket.
    pub async fn reply(
        mut self,
        reply: Reply,
        addr: Address,
    ) -> Result<Bind<state::NeedSecondReply>, (Error, TcpStream)> {
        // Bind the outcome first so the mutable borrow of the stream ends before either
        // arm takes ownership of it.
        let outcome = Response::new(reply, addr).write_to(&mut self.stream).await;
        match outcome {
            Ok(_) => Ok(Bind::new(self.stream)),
            Err(err) => Err((err, self.stream)),
        }
    }
}
impl Bind<state::NeedSecondReply> {
    /// Writes the second reply (`reply` code plus `addr`) to the client stream. On
    /// success, the connection becomes `Bind<state::Ready>` and can carry data.
    ///
    /// # Errors
    ///
    /// If the write fails, the returned tuple carries the I/O error and gives the
    /// `TcpStream` back to the caller.
    pub async fn reply(
        mut self,
        reply: Reply,
        addr: Address,
    ) -> Result<Bind<state::Ready>, (Error, TcpStream)> {
        let written = Response::new(reply, addr).write_to(&mut self.stream).await;
        match written {
            Err(err) => Err((err, self.stream)),
            Ok(_) => Ok(Bind::new(self.stream)),
        }
    }
}
impl<S> Bind<S> {
#[inline]
pub(super) fn new(stream: TcpStream) -> Self {
Self {
stream,
_state: PhantomData,
}
}
#[inline]
pub async fn close(&mut self) -> Result<(), Error> {
self.stream.shutdown().await
}
#[inline]
pub fn local_addr(&self) -> Result<SocketAddr, Error> {
self.stream.local_addr()
}
#[inline]
pub fn peer_addr(&self) -> Result<SocketAddr, Error> {
self.stream.peer_addr()
}
#[inline]
pub fn get_ref(&self) -> &TcpStream {
&self.stream
}
#[inline]
pub fn get_mut(&mut self) -> &mut TcpStream {
&mut self.stream
}
#[inline]
pub fn into_inner(self) -> TcpStream {
self.stream
}
}
impl AsyncRead for Bind<state::Ready> {
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), Error>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
}
}
impl AsyncWrite for Bind<state::Ready> {
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}
#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.stream).poll_flush(cx)
}
#[inline]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}
}