// greyhound 0.0.1
//
// al3x's personal backend framework
// Documentation
use crate::{
	endpoint::{BoxedEnpoint, Endpoint, Handler, HandlerOutput},
	http,
	request::Request,
	response::Response,
	server::state::State,
};
use async_std::sync::{Arc, RwLock};
use async_trait::async_trait;
use std::{future::Future, pin::Pin};

/// Shared, reference-counted handle to a [`MiddlewareLock`].
pub(crate) type MiddlewareArc<S, Args, O> = Arc<MiddlewareLock<S, Args, O>>;

/// Ordered list of middleware trait objects guarded by an async `RwLock`, so the
/// list may be read by several tasks at once and replaced/extended by a writer.
pub(crate) type MiddlewareLock<S, Args, O> = RwLock<Vec<Arc<dyn Middleware<S, Args, O>>>>;

/// One layer of a processing chain.
///
/// An implementor receives the call's `args` together with a [`Next`] handle for
/// the remainder of the chain. It may transform the arguments, decide whether
/// to delegate to `next` (it owns it, so it may also drop it), and post-process
/// the resulting output.
#[async_trait]
pub trait Middleware<S, Args, O>: Send + Sync + 'static
where
	S: State,
{
	/// Processes `args`, optionally handing off to the rest of the chain via `next`.
	async fn handle(&self, args: Args, next: Next<'_, S, Args, O>) -> O;

	/// Identifier used for diagnostics; defaults to the implementing type's name.
	fn name(&self) -> &str {
		core::any::type_name::<Self>()
	}
}

/// Allows a plain function or closure to act as request middleware, provided it
/// returns a boxed, `Send` future borrowing the supplied [`Next`].
#[async_trait]
impl<S, FN> Middleware<S, Request<S>, HandlerOutput> for FN
where
	S: State,
	FN: for<'a> Fn(
			Request<S>,
			Next<'a, S, Request<S>, HandlerOutput>,
		) -> Pin<Box<dyn Future<Output = HandlerOutput> + 'a + Send>>
		+ Send
		+ Sync
		+ 'static,
{
	async fn handle(
		&self,
		args: Request<S>,
		next: Next<'_, S, Request<S>, HandlerOutput>,
	) -> HandlerOutput {
		let label = self.name();
		log::debug!("inside middleware call named: {}", label);
		// Invoke the wrapped callable, then drive the future it hands back.
		let pending = (self)(args, next);
		pending.await
	}
}

/// Cursor into a middleware chain.
///
/// It holds the middleware that have not run yet, plus the endpoint that
/// terminates the chain once they are exhausted.
pub struct Next<'a, S: State, Args, O> {
	/// Terminal endpoint of the chain.
	pub(crate) handler: &'a Endpoint<S>,
	/// Middleware still pending, in order; the first element runs next.
	pub(crate) next_middleware: &'a [Arc<dyn Middleware<S, Args, O>>],
}

/// Debug output for any chain cursor, regardless of argument/output types.
///
/// It lists the names of the middleware still pending, in execution order. The
/// terminal endpoint is not printed.
impl<S: State, Args, O> std::fmt::Debug for Next<'_, S, Args, O> {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let pending: Vec<&str> = self
			.next_middleware
			.iter()
			.map(|m| m.name())
			.collect();
		f.debug_struct("Next")
			.field("next_middleware", &pending)
			.finish()
	}
}

impl<S: State> Next<'_, S, Request<S>, HandlerOutput> {
	/// # Errors
	///
	/// fails when call to handler fails
	pub async fn run(mut self, req: Request<S>) -> HandlerOutput {
		if let Some((current, next)) = self.next_middleware.split_first() {
			self.next_middleware = next;
			current.handle(req, self).await
		} else {
			self.handler.call(req).await
		}
	}
}

/// An endpoint paired with a shared middleware list.
///
/// Calling it runs the middleware in list order and then the wrapped endpoint.
#[derive(Clone)]
pub(crate) struct MiddlewareHandler<H, S>
where
	S: State,
{
	/// Wrapped endpoint, reached after every middleware has run.
	handler: H,
	/// Shared, lock-protected middleware list applied on each call.
	middleware: MiddlewareArc<S, Request<S>, HandlerOutput>,
}

/// Debug output listing the names of the middleware in the shared list.
///
/// `Debug::fmt` is synchronous, so the lock is probed with `try_read`. The
/// previous `read()` call produced an un-awaited future, so nothing useful was
/// printed. If a writer currently holds the lock, `<locked>` is shown instead of
/// blocking.
impl<E, S: State> std::fmt::Debug for MiddlewareHandler<E, S> {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		let mut out = f.debug_struct("MiddlewareHandler");
		match self.middleware.try_read() {
			Some(stack) => {
				// `dyn Middleware` is not `Debug`; its `name()` is the best label available.
				let names: Vec<&str> = stack.iter().map(|m| m.name()).collect();
				out.field("middleware", &names);
			}
			None => {
				out.field("middleware", &format_args!("<locked>"));
			}
		}
		out.finish()
	}
}

impl<H, S> MiddlewareHandler<H, S>
where
	H: Handler<Request<S>, HandlerOutput>,
	S: State,
{
	/// Boxes `ep` as an endpoint.
	///
	/// `ep` is wrapped in a `MiddlewareHandler` only when `middleware` holds at
	/// least one entry at the time of the call. Otherwise it is boxed directly.
	///
	/// NOTE(review): emptiness is sampled once, here. If middleware can be added
	/// to the shared list afterwards, an endpoint boxed unwrapped will never see
	/// it. Confirm this is intended.
	pub(crate) async fn wrap_with_middleware(
		ep: H,
		middleware: MiddlewareArc<S, Request<S>, HandlerOutput>,
	) -> BoxedEnpoint<S> {
		// The read guard is a temporary of this statement and is released before `middleware` moves.
		let no_middleware = middleware.read().await.is_empty();
		if no_middleware {
			return Box::new(ep);
		}
		Box::new(Self {
			middleware,
			handler: ep,
		})
	}
}

#[async_trait]
impl<H, S> Handler<Request<S>, HandlerOutput> for MiddlewareHandler<H, S>
where
	S: State,
	H: Handler<Request<S>, HandlerOutput>,
{
	/// Runs `args` through the middleware list and then the wrapped endpoint.
	///
	/// NOTE(review): the read guard stays alive until the whole chain finishes,
	/// so a writer on the same lock waits for in-flight calls. Confirm this is
	/// acceptable.
	async fn call(&self, args: Request<S>) -> HandlerOutput {
		let stack = self.middleware.read().await;
		let chain = Next {
			next_middleware: &stack[..],
			handler: &self.handler,
		};
		chain.run(args).await
	}
}

/// Middleware adapter that runs an async function on the incoming request
/// before the rest of the chain.
///
/// The function returns either the (possibly modified) request to continue with,
/// or an `http::Error` that short-circuits the chain and is turned into the
/// response.
#[derive(Debug)]
pub struct Before<F>(pub F);

#[async_trait]
impl<S, FN, FT> Middleware<S, Request<S>, HandlerOutput> for Before<FN>
where
	S: State,
	FN: Fn(Request<S>) -> FT + Send + Sync + 'static,
	FT: Future<Output = Result<Request<S>, http::Error>> + Send + 'static,
{
	async fn handle(
		&self,
		args: Request<S>,
		next: Next<'_, S, Request<S>, HandlerOutput>,
	) -> HandlerOutput {
		let req = match (self.0)(args).await {
			Ok(req) => req,
			// Short-circuit: the remaining chain is skipped and the error is
			// rendered as the response.
			Err(err) => return Ok(Response::from(err).into()),
		};
		next.run(req).await
	}
}

/// Middleware adapter that runs an async function on the response produced by
/// the rest of the chain.
///
/// NOTE(review): when the chain returns `Err`, the error is converted to a
/// response directly and the function is not invoked. Confirm that error
/// responses are meant to bypass this hook.
#[derive(Debug)]
pub struct After<F>(pub F);

#[async_trait]
impl<S, FN, FT> Middleware<S, Request<S>, HandlerOutput> for After<FN>
where
	S: State,
	FN: Fn(Response) -> FT + Send + Sync + 'static,
	FT: Future<Output = HandlerOutput> + Send + 'static,
{
	async fn handle(
		&self,
		args: Request<S>,
		next: Next<'_, S, Request<S>, HandlerOutput>,
	) -> HandlerOutput {
		let outcome = next.run(args).await;
		match outcome {
			Ok(res) => {
				let response: Response = res.into();
				(self.0)(response).await
			}
			// Errors become responses as-is; the post-processing hook is not applied.
			Err(e) => Ok(Response::from(e).into()),
		}
	}
}