use std::io;
use bytes::Bytes;
use futures::SinkExt;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
use crate::hooks::{ConnectionHook, Error, HookResult};
use msg_wire::auth;
/// Errors reported by [`ServerHook`] while authenticating an incoming connection.
#[derive(Debug, thiserror::Error)]
pub enum ServerHookError {
/// The validator returned `false` for the token presented by the peer.
#[error("authentication rejected")]
Rejected,
/// The framed stream ended before any message was received.
#[error("connection closed")]
ConnectionClosed,
/// The first message received was not an `Auth` message.
#[error("expected auth message")]
ExpectedAuthMessage,
}
/// Errors reported by [`ClientHook`] while authenticating an outgoing connection.
#[derive(Debug, thiserror::Error)]
pub enum ClientHookError {
/// The peer replied to the auth message with something other than `Ack`.
#[error("authentication denied")]
Denied,
/// The framed stream ended before a reply to the auth message was received.
#[error("connection closed")]
ConnectionClosed,
}
/// Server-side connection hook that authenticates a peer by the token it sends.
///
/// The token carried by the peer's `Auth` message is handed to `validator`; the
/// connection is acknowledged when it returns `true` and rejected otherwise.
pub struct ServerHook<F> {
// Callback deciding whether a received token is acceptable.
validator: F,
}
impl ServerHook<fn(&Bytes) -> bool> {
    /// Builds a hook whose validator returns `true` for every token.
    pub fn accept_all() -> Self {
        /// Validator that ignores the token and always accepts.
        fn always_accept(_token: &Bytes) -> bool {
            true
        }

        // Bind through an explicit fn-pointer type so the function item is
        // coerced to the pointer type this impl is specialised for.
        let validator: fn(&Bytes) -> bool = always_accept;
        Self { validator }
    }
}
impl<F> ServerHook<F>
where
F: Fn(&Bytes) -> bool + Send + Sync + 'static,
{
/// Creates a hook that uses `validator` to decide whether a token is accepted.
///
/// `validator` receives the token sent by the peer; returning `true` accepts
/// the connection and returning `false` rejects it.
pub fn new(validator: F) -> Self {
Self { validator }
}
}
impl<Io, F> ConnectionHook<Io> for ServerHook<F>
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
F: Fn(&Bytes) -> bool + Send + Sync + 'static,
{
type Error = ServerHookError;
async fn on_connection(&self, io: Io) -> HookResult<Io, Self::Error> {
let mut conn = Framed::new(io, auth::Codec::new_server());
let msg = conn
.next()
.await
.ok_or(Error::hook(ServerHookError::ConnectionClosed))?
.map_err(|e| io::Error::other(e.to_string()))?;
let auth::Message::Auth(token) = msg else {
return Err(Error::hook(ServerHookError::ExpectedAuthMessage));
};
if !(self.validator)(&token) {
conn.send(auth::Message::Reject).await?;
return Err(Error::hook(ServerHookError::Rejected));
}
conn.send(auth::Message::Ack).await?;
Ok(conn.into_inner())
}
}
/// Client-side connection hook that authenticates by sending a token and
/// waiting for the peer's reply.
pub struct ClientHook {
// Token sent to the peer inside the `Auth` message.
token: Bytes,
}
impl ClientHook {
/// Creates a hook that presents `token` to the peer when a connection is made.
pub fn new(token: Bytes) -> Self {
Self { token }
}
}
impl<Io> ConnectionHook<Io> for ClientHook
where
Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
type Error = ClientHookError;
async fn on_connection(&self, io: Io) -> HookResult<Io, Self::Error> {
let mut conn = Framed::new(io, auth::Codec::new_client());
conn.send(auth::Message::Auth(self.token.clone())).await?;
conn.flush().await?;
let ack = conn
.next()
.await
.ok_or(Error::hook(ClientHookError::ConnectionClosed))?
.map_err(|e| io::Error::other(e.to_string()))?;
if !matches!(ack, auth::Message::Ack) {
return Err(Error::hook(ClientHookError::Denied));
}
Ok(conn.into_inner())
}
}