use std::{future::Future, io::Error};
use password::{Request, Response, Status::*};
use tokio::net::TcpStream;
use crate::{
ext::Extension,
server::socks::proto::{AsyncStreamOperation, Method, UsernamePassword, handshake::password},
};
/// An authenticator for one SOCKS authentication method.
///
/// Implementors must be `Send` (supertrait bound). The future returned by
/// [`Auth::execute`] is also required to be `Send`, so it can be awaited from
/// a multi-threaded runtime task.
pub trait Auth: Send {
    /// The value produced once the authentication exchange finishes.
    type Output;
    /// The authentication [`Method`] this implementor handles.
    fn method(&self) -> Method;
    /// Runs the method-specific exchange with the client over `stream`.
    fn execute(&self, stream: &mut TcpStream) -> impl Future<Output = Self::Output> + Send;
}
/// Runtime-selectable authenticator that wraps one of the concrete
/// implementations in this module and forwards [`Auth`] calls to it.
///
/// `#[non_exhaustive]` lets new methods be added later without breaking
/// downstream exhaustive `match`es.
#[non_exhaustive]
pub enum AuthAdaptor {
    /// Accept clients without credentials.
    NoAuth(NoAuth),
    /// Require a username/password exchange.
    Password(Password),
}
impl AuthAdaptor {
    /// Builds an adaptor that performs no authentication.
    #[inline]
    pub fn no() -> Self {
        let auth = NoAuth;
        AuthAdaptor::NoAuth(auth)
    }

    /// Builds an adaptor that requires the given username and password.
    ///
    /// Both values must be of the same type `S`; they are forwarded to
    /// [`Password::new`].
    #[inline]
    pub fn password<S: Into<String>>(username: S, password: S) -> Self {
        let auth = Password::new(username, password);
        Self::Password(auth)
    }
}
impl Auth for AuthAdaptor {
type Output = std::io::Result<Extension>;
#[inline]
fn method(&self) -> Method {
match self {
Self::NoAuth(auth) => auth.method(),
Self::Password(auth) => auth.method(),
}
}
#[inline]
async fn execute(&self, stream: &mut TcpStream) -> Self::Output {
match self {
Self::NoAuth(auth) => auth.execute(stream).await,
Self::Password(auth) => auth.execute(stream).await,
}
}
}
/// Authenticator that accepts every client without exchanging any data.
pub struct NoAuth;
impl Auth for NoAuth {
type Output = std::io::Result<Extension>;
#[inline]
fn method(&self) -> Method {
Method::NoAuth
}
#[inline]
async fn execute(&self, _stream: &mut TcpStream) -> Self::Output {
Ok(Extension::None)
}
}
/// Username/password authenticator holding the credentials clients must present.
pub struct Password {
    // The configured credentials that incoming requests are checked against.
    inner: UsernamePassword,
}
impl Password {
pub fn new<S: Into<String>>(username: S, password: S) -> Self {
Self {
inner: UsernamePassword::new(username, password),
}
}
}
impl Auth for Password {
    type Output = std::io::Result<Extension>;

    /// Always [`Method::Password`].
    #[inline]
    fn method(&self) -> Method {
        Method::Password
    }

    /// Reads the client's username/password request, replies with a
    /// success/failure status, and on success builds the [`Extension`].
    ///
    /// # Errors
    /// - I/O errors while reading the request or writing the response.
    /// - `"username or password is incorrect"` when the credentials do not match.
    ///   The failure response has already been written at that point.
    /// - `"failed to parse extension"` when `Extension::try_from` fails.
    async fn execute(&self, stream: &mut TcpStream) -> Self::Output {
        // Compares two byte strings without stopping at the first mismatching
        // byte, so response timing does not reveal how many leading bytes of
        // the password a client guessed correctly. Only the length can leak.
        fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
            if a.len() != b.len() {
                return false;
            }
            let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
            // black_box forces the exact accumulated value to be materialised,
            // discouraging the optimiser from re-introducing an early exit.
            // This is best-effort, not a hard guarantee.
            std::hint::black_box(diff) == 0
        }

        let req = Request::retrieve_from_async_stream(stream).await?;

        // Prefix match on the username, not equality. The full client
        // username is later passed to `Extension::try_from`, presumably so
        // extension data can follow the configured username. Confirm before
        // tightening this check.
        let username_ok = req.user_pass.username.starts_with(&self.inner.username);
        let password_ok = constant_time_eq(
            req.user_pass.password.as_bytes(),
            self.inner.password.as_bytes(),
        );
        // Non-short-circuiting `&`: the password comparison always runs, so
        // timing does not separate "wrong username" from "wrong password".
        let is_equal = username_ok & password_ok;

        let resp = Response::new(if is_equal { Succeeded } else { Failed });
        resp.write_to_async_stream(stream).await?;

        if is_equal {
            let extension = Extension::try_from(&self.inner.username, req.user_pass.username)
                .await
                .map_err(|_| Error::other("failed to parse extension"))?;
            Ok(extension)
        } else {
            Err(Error::other("username or password is incorrect"))
        }
    }
}