use std::convert::Infallible;
use std::ops::ControlFlow;
use std::task::Context;
use std::task::Poll;
use axum::extract::Request;
use axum::response::IntoResponse;
use axum::response::Response;
use futures_lite::future::Boxed;
use tower::Layer;
use tower::Service;
pub mod catch_unwind;
/// Simplified middleware interface with one hook before and one hook after the inner service.
///
/// Both hooks default to pass-through, so an implementor only overrides the hook it needs.
/// Every `SimpleGalvynMiddleware` is also a [`GalvynMiddleware`] via a blanket implementation.
pub trait SimpleGalvynMiddleware: Clone + Send + Sync + 'static {
    /// Runs before the inner service sees the request.
    ///
    /// Resolve to `ControlFlow::Continue(request)` to forward the (possibly modified) request.
    /// Resolve to `ControlFlow::Break(response)` to answer immediately without forwarding.
    ///
    /// The default forwards the request unchanged.
    fn pre_handler(
        &mut self,
        request: Request,
    ) -> impl Future<Output = ControlFlow<Response, Request>> + Send {
        std::future::ready(ControlFlow::Continue(request))
    }

    /// Runs on the response produced downstream and resolves to the response to send back.
    ///
    /// The default returns the response unchanged.
    fn post_handler(&mut self, response: Response) -> impl Future<Output = Response> + Send {
        std::future::ready(response)
    }
}
/// Full-control middleware: gets the request together with the wrapped service and is
/// responsible for producing the final response.
///
/// Types implementing [`SimpleGalvynMiddleware`] receive this trait through a blanket impl.
pub trait GalvynMiddleware: Clone + Send + Sync + 'static {
    /// Handles a single request.
    ///
    /// `inner` is the wrapped service. Callers should pass an instance whose `poll_ready` has
    /// already returned `Ready`, because implementations may invoke it directly. An
    /// implementation may forward the request to `inner` or answer without calling it.
    /// The returned future is `Send + 'static`, so it can be boxed and driven independently
    /// of the caller.
    fn call<S: AxumService>(
        self,
        inner: S,
        request: Request,
    ) -> impl Future<Output = Result<Response, Infallible>> + Send + 'static;

    /// Wraps this middleware in a [`MiddlewareLayer`] for use as a `tower::Layer`.
    fn into_layer(self) -> MiddlewareLayer<Self>
    where
        Self: Sized,
    {
        MiddlewareLayer::<Self>(self)
    }
}
/// Adapts every [`SimpleGalvynMiddleware`] to the [`GalvynMiddleware`] interface.
///
/// The adapter always runs `pre_handler` first:
/// - If it yields `Break(response)`, that response is returned as-is. Neither `inner` nor
///   `post_handler` runs.
/// - If it yields `Continue(request)`, the request is sent to `inner`. The inner output is
///   converted with [`IntoResponse`] and then handed to `post_handler`.
///
/// `inner.call` is invoked without polling readiness here, so `inner` must already be ready.
impl<T: SimpleGalvynMiddleware> GalvynMiddleware for T {
    async fn call<S: AxumService>(
        mut self,
        mut inner: S,
        request: Request,
    ) -> Result<Response, Infallible> {
        let forwarded = match self.pre_handler(request).await {
            ControlFlow::Break(early) => return Ok(early),
            ControlFlow::Continue(forwarded) => forwarded,
        };
        let produced = inner.call(forwarded).await.into_response();
        Ok(self.post_handler(produced).await)
    }
}
/// `tower` [`Layer`] that wraps a service in a [`MiddlewareService`] running the middleware `M`.
///
/// Can also be obtained with [`GalvynMiddleware::into_layer`].
#[derive(Clone, Copy, Debug)]
pub struct MiddlewareLayer<M>(pub M);

impl<M, S> Layer<S> for MiddlewareLayer<M>
where
    M: GalvynMiddleware,
{
    type Service = MiddlewareService<M, S>;

    /// Wraps `inner`, giving the resulting service its own clone of the middleware.
    fn layer(&self, inner: S) -> Self::Service {
        let middleware = M::clone(&self.0);
        MiddlewareService { middleware, inner }
    }
}
#[derive(Copy, Clone, Debug)]
pub struct MiddlewareService<M, S> {
inner: S,
middleware: M,
}
impl<M, S> Service<Request> for MiddlewareService<M, S>
where
M: GalvynMiddleware,
S: AxumService,
{
type Response = Response;
type Error = Infallible;
type Future = Boxed<Result<Response, Infallible>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request) -> Self::Future {
let not_ready_inner = self.inner.clone();
let ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
let middleware = self.middleware.clone();
Box::pin(middleware.call(ready_inner, request))
}
}
pub trait AxumService:
Service<Request, Error = Infallible, Response: IntoResponse, Future: Send + 'static>
+ Clone
+ Send
+ 'static
{
}
impl<T> AxumService for T where
T: Service<Request, Error = Infallible, Response: IntoResponse, Future: Send + 'static>
+ Clone
+ Send
+ 'static
{
}