// tower-livereload 0.10.3
//
// A LiveReload middleware built on top of tower.
// Documentation
use std::{convert::Infallible, future::Future, sync::Arc, task::Poll};

use bytes::Buf;
use http::{request::Parts, Request, Response};
use http_body::{Body, Frame};
use tower::Service;

/// Type-erased callback that inspects the head of a request and may return a
/// ready-made response (or error) with body type `B` and error type `E`.
type AlternativeFn<B, E> = dyn Fn(&Parts) -> Option<Result<Response<B>, E>> + Send + Sync;

/// A service wrapper that pairs an inner service `S` with an "alternative"
/// callback.
///
/// * `B` – body type of the responses produced by the callback.
/// * `E` – error type the callback may return.
/// * `S` – the wrapped service.
pub struct OverlayService<B, E, S> {
    // Kept behind an `Arc` so the wrapper can be cloned without requiring the
    // stored closure to be `Clone`.
    alternative: Arc<AlternativeFn<B, E>>,
    service: S,
}

impl<B, E, S> OverlayService<B, E, S> {
    pub fn new(
        service: S,
        alternative_fn: impl Fn(&Parts) -> Option<Result<Response<B>, E>> + Send + Sync + 'static,
    ) -> Self {
        Self {
            alternative: Arc::new(alternative_fn),
            service,
        }
    }
}

impl<B, E, S: Clone> Clone for OverlayService<B, E, S> {
    fn clone(&self) -> Self {
        OverlayService {
            alternative: self.alternative.clone(),
            service: self.service.clone(),
        }
    }
}

impl<B, E, S: std::fmt::Debug> std::fmt::Debug for OverlayService<B, E, S> {
    /// Formats the wrapper as a struct with its two fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OverlayService");
        // The stored closure is a trait object without a `Debug` impl, so a
        // fixed placeholder is printed in its place.
        out.field("alternative", &"...");
        out.field("service", &self.service);
        out.finish()
    }
}

/// Routes each request either to the alternative callback or to the wrapped
/// service.
///
/// Responses from both sources are unified as [`OverlayBody`] and errors as
/// [`OverlayError`]: the callback maps to the `Left` variants, the wrapped
/// service to the `Right` variants.
impl<S, E, ReqBody, ResBody, ResBodyNew> Service<Request<ReqBody>>
    for OverlayService<ResBodyNew, E, S>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
{
    type Response = Response<OverlayBody<ResBodyNew, ResBody>>;
    type Error = OverlayError<E, S::Error>;
    type Future = OverlayFuture<ResBodyNew, E, S::Future>;

    /// Delegates readiness to the wrapped service; its error is reported as
    /// [`OverlayError::Right`].
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.service.poll_ready(cx).map_err(OverlayError::Right)
    }

    /// Asks the alternative callback first; only if it declines (`None`) is
    /// the request reassembled and handed to the wrapped service.
    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
        let (parts, body) = req.into_parts();

        // Call through the existing `Arc`: cloning it just to invoke the
        // callback would cost an atomic increment/decrement per request.
        let alternative = (self.alternative)(&parts);

        match alternative {
            Some(result) => OverlayFuture::Alternative {
                alternative: Some(result),
            },
            None => OverlayFuture::Inner {
                inner: self.service.call(Request::from_parts(parts, body)),
            },
        }
    }
}

pin_project_lite::pin_project! {
    #[project = OverlayFutureProj]
    /// Future returned by `OverlayService::call`: either the wrapped
    /// service's future or an already computed alternative result.
    pub enum OverlayFuture<B, E, F> {
        // The request was forwarded; `inner` is the wrapped service's future
        // and is structurally pinned so it can be polled in place.
        Inner {
            #[pin]
            inner: F
        },
        // The alternative callback produced a result. It is stored in an
        // `Option` so that `poll` can move it out with `take()`.
        Alternative {
            alternative: Option<Result<Response<B>, E>>
        },
    }
}

/// Resolves to the response of whichever side produced it, with the body
/// wrapped in [`OverlayBody`] and the error wrapped in [`OverlayError`].
///
/// # Panics
///
/// The `Alternative` variant panics if it is polled again after it has
/// already returned `Poll::Ready`.
impl<B, E, PB, PE, F> Future for OverlayFuture<B, E, F>
where
    F: Future<Output = Result<Response<PB>, PE>>,
{
    type Output = Result<Response<OverlayBody<B, PB>>, OverlayError<E, PE>>;

    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        match self.project() {
            // Drive the wrapped service's future and tag its output as `Right`.
            OverlayFutureProj::Inner { inner } => inner
                .poll(cx)
                .map_ok(|resp| resp.map(|right| OverlayBody::Right { right }))
                .map_err(OverlayError::Right),
            // The result already exists, so it is ready on the first poll.
            OverlayFutureProj::Alternative { alternative } => {
                // A second poll finds `None`. That is a caller contract
                // violation rather than an impossible state, so report it
                // with an explicit message.
                let result = alternative
                    .take()
                    .expect("`OverlayFuture` polled after completion");

                Poll::Ready(match result {
                    Ok(resp) => Ok(resp.map(|left| OverlayBody::Left { left })),
                    Err(err) => Err(OverlayError::Left(err)),
                })
            }
        }
    }
}

pin_project_lite::pin_project! {
    #[project = OverlayBodyProj]
    /// Response body that is one of two body types: `Left` carries the body
    /// of an alternative response, `Right` the body of the wrapped service's
    /// response.
    pub enum OverlayBody<L, R> {
        // Both payloads are structurally pinned so the contained body can be
        // polled through a pinned projection.
        Left {
            #[pin]
            left: L
        },
        Right{
            #[pin]
            right: R
        },
    }
}

impl<Data: Buf, A: Body<Data = Data>, B: Body<Data = Data>> Body for OverlayBody<A, B> {
    type Data = Data;
    type Error = OverlayError<A::Error, B::Error>;

    fn poll_frame(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        match self.project() {
            OverlayBodyProj::Left { left } => left.poll_frame(cx).map_err(OverlayError::Left),
            OverlayBodyProj::Right { right } => right.poll_frame(cx).map_err(OverlayError::Right),
        }
    }
}

/// Error that originates from one of two sides.
///
/// The alternative response path and the `Left` body report through
/// [`OverlayError::Left`]; the wrapped service and the `Right` body report
/// through [`OverlayError::Right`].
#[derive(Clone, Debug)]
pub enum OverlayError<L, R> {
    /// Error produced by the left-hand side.
    Left(L),
    /// Error produced by the right-hand side.
    Right(R),
}

impl<L: std::error::Error, R: std::error::Error> std::error::Error for OverlayError<L, R> {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            OverlayError::Left(left) => left.source(),
            OverlayError::Right(right) => right.source(),
        }
    }
}

impl<L: std::fmt::Display, R: std::fmt::Display> std::fmt::Display for OverlayError<L, R> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            OverlayError::Left(left) => left.fmt(f),
            OverlayError::Right(right) => right.fmt(f),
        }
    }
}

/// When both sides convert into [`Infallible`], so does the combined error.
impl<L, R> From<OverlayError<L, R>> for Infallible
where
    L: Into<Infallible>,
    R: Into<Infallible>,
{
    // Lint allowance retained from the original code; presumably it silences
    // unreachable-code warnings for uninhabited instantiations. TODO confirm
    // whether it is still required.
    #[allow(unreachable_code)]
    fn from(value: OverlayError<L, R>) -> Self {
        match value {
            OverlayError::Left(never) => never.into(),
            OverlayError::Right(never) => never.into(),
        }
    }
}