use self::{associate::UdpAssociate, bind::Bind, connect::Connect};
use crate::{
protocol::{self, Address, AsyncStreamOperation, AuthMethod, Command, handshake},
server::AuthAdaptor,
};
use std::{net::SocketAddr, time::Duration};
use tokio::{io::AsyncWriteExt, net::TcpStream};
pub mod associate;
pub mod bind;
pub mod connect;
/// A client TCP connection that has not yet completed the SOCKS5 handshake.
///
/// Call [`IncomingConnection::authenticate`] (or
/// [`IncomingConnection::authenticate_with_timeout`]) to run method negotiation and
/// the configured authenticator, which yields an [`Authenticated`] stream plus the
/// authenticator's output of type `O`.
pub struct IncomingConnection<O> {
// The client's TCP stream; all handshake traffic is read from / written to it.
stream: TcpStream,
// Authenticator consulted for the offered method and executed on `stream`.
auth: AuthAdaptor<O>,
}
impl<O: 'static> IncomingConnection<O> {
#[inline]
pub(crate) fn new(stream: TcpStream, auth: AuthAdaptor<O>) -> Self {
IncomingConnection { stream, auth }
}
#[inline]
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.stream.local_addr()
}
#[inline]
pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
self.stream.peer_addr()
}
#[inline]
pub async fn shutdown(&mut self) -> std::io::Result<()> {
self.stream.shutdown().await
}
#[inline]
pub fn nodelay(&self) -> std::io::Result<bool> {
self.stream.nodelay()
}
pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
self.stream.set_nodelay(nodelay)
}
pub fn ttl(&self) -> std::io::Result<u32> {
self.stream.ttl()
}
pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
self.stream.set_ttl(ttl)
}
pub async fn authenticate_with_timeout(self, timeout: Duration) -> crate::Result<(Authenticated, O)> {
tokio::time::timeout(timeout, self.authenticate())
.await
.map_err(|_| crate::Error::String("handshake timeout".into()))?
}
pub async fn authenticate(mut self) -> crate::Result<(Authenticated, O)> {
let request = handshake::Request::retrieve_from_async_stream(&mut self.stream).await?;
if let Some(method) = self.evaluate_request(&request) {
let response = handshake::Response::new(method);
response.write_to_async_stream(&mut self.stream).await?;
let output = self.auth.execute(&mut self.stream).await;
Ok((Authenticated::new(self.stream), output))
} else {
let response = handshake::Response::new(AuthMethod::NoAcceptableMethods);
response.write_to_async_stream(&mut self.stream).await?;
let err = "No available handshake method provided by client";
Err(crate::Error::Io(std::io::Error::new(std::io::ErrorKind::Unsupported, err)))
}
}
fn evaluate_request(&self, req: &handshake::Request) -> Option<AuthMethod> {
let method = self.auth.auth_method();
if req.evaluate_method(method) { Some(method) } else { None }
}
}
/// Debug output shows the underlying stream only.
///
/// The `auth` field is intentionally left out; `finish_non_exhaustive` appends
/// `..` so readers of the output can see that a field was omitted rather than
/// assuming the struct has just one field.
impl<O> std::fmt::Debug for IncomingConnection<O> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IncomingConnection")
            .field("stream", &self.stream)
            .finish_non_exhaustive()
    }
}
impl<O> From<IncomingConnection<O>> for TcpStream {
#[inline]
fn from(conn: IncomingConnection<O>) -> Self {
conn.stream
}
}
/// A TCP stream whose SOCKS5 handshake and authentication have completed.
///
/// Call [`Authenticated::wait_request`] to read the client's command. The type can
/// also be converted back into the raw `TcpStream` via `From`.
// Derived `Debug` keeps this consistent with `IncomingConnection` and
// `ClientConnection`, the neighbouring states of the connection, which are both `Debug`.
#[derive(Debug)]
pub struct Authenticated(TcpStream);
impl Authenticated {
#[inline]
fn new(stream: TcpStream) -> Self {
Self(stream)
}
pub async fn wait_request(mut self) -> crate::Result<ClientConnection> {
let req = protocol::Request::retrieve_from_async_stream(&mut self.0).await?;
match req.command {
Command::UdpAssociate => Ok(ClientConnection::UdpAssociate(
UdpAssociate::<associate::NeedReply>::new(self.0),
req.address,
)),
Command::Bind => Ok(ClientConnection::Bind(Bind::<bind::NeedFirstReply>::new(self.0), req.address)),
Command::Connect => Ok(ClientConnection::Connect(Connect::<connect::NeedReply>::new(self.0), req.address)),
}
}
#[inline]
pub async fn shutdown(&mut self) -> std::io::Result<()> {
self.0.shutdown().await
}
#[inline]
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.0.local_addr()
}
#[inline]
pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
self.0.peer_addr()
}
#[inline]
pub fn nodelay(&self) -> std::io::Result<bool> {
self.0.nodelay()
}
pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
self.0.set_nodelay(nodelay)
}
pub fn ttl(&self) -> std::io::Result<u32> {
self.0.ttl()
}
pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
self.0.set_ttl(ttl)
}
}
impl From<Authenticated> for TcpStream {
#[inline]
fn from(conn: Authenticated) -> Self {
conn.0
}
}
/// The client's request after authentication, as produced by
/// [`Authenticated::wait_request`].
///
/// Each variant carries the handler for that command (which owns the TCP stream)
/// and the target [`Address`] sent by the client.
#[derive(Debug)]
pub enum ClientConnection {
// UDP ASSOCIATE command; the handler still needs to send its reply.
UdpAssociate(UdpAssociate<associate::NeedReply>, Address),
// BIND command; the handler still needs to send its first reply.
Bind(Bind<bind::NeedFirstReply>, Address),
// CONNECT command; the handler still needs to send its reply.
Connect(Connect<connect::NeedReply>, Address),
}