pub mod associate;
pub mod bind;
pub mod connect;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::{io::AsyncWriteExt, net::TcpStream};
use self::{associate::UdpAssociate, bind::Bind, connect::Connect};
use super::{
auth::{Auth, AuthAdaptor},
error::Error,
proto::{self, Address, AsyncStreamOperation, Command, Method, handshake},
};
/// An inbound TCP connection that has not yet gone through the SOCKS5
/// method-selection handshake.
///
/// Call [`IncomingConnection::authenticate`] to negotiate the method and run
/// authentication, which yields an [`AuthenticatedStream`].
pub struct IncomingConnection {
/// The raw client stream.
stream: TcpStream,
/// Authentication adaptor: its `method()` is the only method offered to the
/// client, and its `execute` is run on the stream once that method is accepted.
auth: Arc<AuthAdaptor>,
}
impl IncomingConnection {
    /// Wraps an accepted TCP stream with the authentication adaptor that
    /// [`IncomingConnection::authenticate`] will negotiate and run.
    #[inline]
    pub(crate) fn new(stream: TcpStream, auth: Arc<AuthAdaptor>) -> Self {
        IncomingConnection { stream, auth }
    }

    /// Returns the local socket address of the underlying stream.
    #[inline]
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.stream.local_addr()
    }

    /// Returns the socket address of the connected client.
    #[inline]
    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
        self.stream.peer_addr()
    }

    /// Shuts down the underlying stream via `AsyncWriteExt::shutdown`.
    #[inline]
    pub async fn shutdown(&mut self) -> std::io::Result<()> {
        self.stream.shutdown().await
    }

    /// Returns the stream's current linger setting.
    #[inline]
    pub fn linger(&self) -> std::io::Result<Option<Duration>> {
        self.stream.linger()
    }

    /// Sets the stream's linger duration.
    #[inline]
    pub fn set_linger(&self, dur: Option<Duration>) -> std::io::Result<()> {
        self.stream.set_linger(dur)
    }

    /// Returns the stream's `nodelay` setting.
    #[inline]
    pub fn nodelay(&self) -> std::io::Result<bool> {
        self.stream.nodelay()
    }

    /// Sets the stream's `nodelay` option.
    #[inline]
    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
        self.stream.set_nodelay(nodelay)
    }

    /// Returns the IP time-to-live value of the stream.
    #[inline]
    pub fn ttl(&self) -> std::io::Result<u32> {
        self.stream.ttl()
    }

    /// Sets the IP time-to-live value of the stream.
    #[inline]
    pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
        self.stream.set_ttl(ttl)
    }

    /// Performs the SOCKS5 method-selection handshake, then runs authentication.
    ///
    /// The method works in two steps:
    ///
    /// 1. It reads the client's handshake request.
    /// 2. If the request offers the adaptor's method, it replies with that method
    ///    and runs the adaptor's `execute` on the stream. Otherwise it replies with
    ///    [`Method::NoAcceptableMethods`] and fails.
    ///
    /// The adaptor's output is returned unmodified as the second tuple element.
    /// This method does not inspect it, so the caller is responsible for checking
    /// the authentication outcome before serving the client.
    ///
    /// # Errors
    ///
    /// - Any I/O error from reading the request or writing the reply.
    /// - An error of kind [`std::io::ErrorKind::Unsupported`] when the client
    ///   offered no acceptable method.
    ///
    /// In both cases `self` (and thus the stream) is dropped.
    pub async fn authenticate(
        mut self,
    ) -> std::io::Result<(AuthenticatedStream, <AuthAdaptor as Auth>::Output)> {
        let request = handshake::Request::retrieve_from_async_stream(&mut self.stream).await?;
        if let Some(method) = self.evaluate_request(&request) {
            let response = handshake::Response::new(method);
            response.write_to_async_stream(&mut self.stream).await?;
            let output = self.auth.execute(&mut self.stream).await;
            Ok((AuthenticatedStream::new(self.stream), output))
        } else {
            // Tell the client none of its offered methods is acceptable before failing.
            let response = handshake::Response::new(Method::NoAcceptableMethods);
            response.write_to_async_stream(&mut self.stream).await?;
            let err = "No available handshake method provided by client";
            Err(std::io::Error::new(std::io::ErrorKind::Unsupported, err))
        }
    }

    /// Returns the adaptor's method if the client's handshake request offers it,
    /// otherwise `None`.
    fn evaluate_request(&self, req: &handshake::Request) -> Option<Method> {
        let method = self.auth.method();
        req.evaluate_method(method).then_some(method)
    }
}
impl std::fmt::Debug for IncomingConnection {
    /// Formats the connection, showing its stream.
    ///
    /// The `auth` adaptor is not printed. `finish_non_exhaustive` appends `..` so
    /// the output makes clear that a field was elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IncomingConnection")
            .field("stream", &self.stream)
            .finish_non_exhaustive()
    }
}
impl From<IncomingConnection> for TcpStream {
#[inline]
fn from(conn: IncomingConnection) -> Self {
conn.stream
}
}
/// A client connection that has gone through the handshake in
/// [`IncomingConnection::authenticate`].
///
/// The authentication outcome is returned separately by that method and is not
/// stored here. Call [`AuthenticatedStream::wait_request`] to read the client's
/// command.
#[derive(Debug)]
pub struct AuthenticatedStream(TcpStream);
impl AuthenticatedStream {
    /// Wraps a stream that has finished the method-selection handshake.
    #[inline]
    fn new(stream: TcpStream) -> Self {
        Self(stream)
    }

    /// Reads the client's SOCKS5 request and hands the stream off to the
    /// connection type matching the requested command.
    ///
    /// The returned [`ClientConnection`] variant carries that connection together
    /// with the target address from the request.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading or parsing the request.
    pub async fn wait_request(mut self) -> Result<ClientConnection, Error> {
        let request = proto::Request::retrieve_from_async_stream(&mut self.0).await?;
        let stream = self.0;
        let address = request.address;
        let connection = match request.command {
            Command::Connect => {
                ClientConnection::Connect(Connect::<connect::NeedReply>::new(stream), address)
            }
            Command::Bind => {
                ClientConnection::Bind(Bind::<bind::NeedFirstReply>::new(stream), address)
            }
            Command::UdpAssociate => ClientConnection::UdpAssociate(
                UdpAssociate::<associate::NeedReply>::new(stream),
                address,
            ),
        };
        Ok(connection)
    }

    /// Returns the local socket address of the underlying stream.
    #[inline]
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        self.0.local_addr()
    }

    /// Returns the socket address of the connected client.
    #[inline]
    pub fn peer_addr(&self) -> std::io::Result<SocketAddr> {
        self.0.peer_addr()
    }

    /// Returns the stream's `nodelay` setting.
    #[inline]
    pub fn nodelay(&self) -> std::io::Result<bool> {
        self.0.nodelay()
    }

    /// Sets the stream's `nodelay` option.
    #[inline]
    pub fn set_nodelay(&self, nodelay: bool) -> std::io::Result<()> {
        self.0.set_nodelay(nodelay)
    }

    /// Returns the stream's current linger setting.
    #[inline]
    pub fn linger(&self) -> std::io::Result<Option<Duration>> {
        self.0.linger()
    }

    /// Sets the stream's linger duration.
    #[inline]
    pub fn set_linger(&self, dur: Option<Duration>) -> std::io::Result<()> {
        self.0.set_linger(dur)
    }

    /// Returns the IP time-to-live value of the stream.
    #[inline]
    pub fn ttl(&self) -> std::io::Result<u32> {
        self.0.ttl()
    }

    /// Sets the IP time-to-live value of the stream.
    #[inline]
    pub fn set_ttl(&self, ttl: u32) -> std::io::Result<()> {
        self.0.set_ttl(ttl)
    }

    /// Shuts down the underlying stream via `AsyncWriteExt::shutdown`.
    #[inline]
    pub async fn shutdown(&mut self) -> std::io::Result<()> {
        self.0.shutdown().await
    }
}
impl From<AuthenticatedStream> for TcpStream {
#[inline]
fn from(conn: AuthenticatedStream) -> Self {
conn.0
}
}
/// A client request read by [`AuthenticatedStream::wait_request`], dispatched on
/// its command.
///
/// Each variant pairs the connection, in its initial type state, with the target
/// [`Address`] taken from the request.
#[derive(Debug)]
pub enum ClientConnection {
/// `UdpAssociate` command; connection in the `associate::NeedReply` state.
UdpAssociate(UdpAssociate<associate::NeedReply>, Address),
/// `Bind` command; connection in the `bind::NeedFirstReply` state.
Bind(Bind<bind::NeedFirstReply>, Address),
/// `Connect` command; connection in the `connect::NeedReply` state.
Connect(Connect<connect::NeedReply>, Address),
}