use std::task::{Context, Poll};
use crate::ServiceBound;
use async_trait::async_trait;
use futures_util::future::BoxFuture;
use tonic::body::Body;
use tonic::codegen::http::Request;
use tonic::codegen::http::Response;
use tonic::codegen::Service;
use tonic::server::NamedService;
use tower::Layer;
/// Asynchronous interceptor placed in front of a tonic service `S`.
///
/// For every incoming request an implementor is handed the request together with an
/// owned instance of the wrapped service. It returns the HTTP response to send back.
/// Failures are reported with the wrapped service's own error type, `S::Error`.
#[async_trait]
pub trait Middleware<S: ServiceBound> {
    /// Processes a single request.
    ///
    /// `service` is the wrapped service, passed by value; whether and how it is
    /// invoked is up to the implementor.
    async fn call(&self, req: Request<Body>, service: S) -> Result<Response<Body>, S::Error>;
}
/// A tower [`Service`] that dispatches every request to a [`Middleware`], which is
/// given the wrapped service alongside the request.
#[derive(Clone)]
pub struct MiddlewareFor<S: ServiceBound, M: Middleware<S>> {
    /// The wrapped service; the middleware receives an instance of it on each call.
    pub inner: S,
    /// The middleware invoked for each request.
    pub middleware: M,
}

impl<S: ServiceBound, M: Middleware<S>> MiddlewareFor<S, M> {
    /// Wraps `inner` so that requests are routed through `middleware`.
    pub fn new(inner: S, middleware: M) -> Self {
        Self { middleware, inner }
    }
}
/// Runs each request through the configured middleware, handing it the wrapped service.
///
/// Readiness is delegated to the wrapped service. The resulting future is boxed and
/// `'static`, so both the middleware and the service instance are moved into it.
impl<S, M> Service<Request<Body>> for MiddlewareFor<S, M>
where
    S: ServiceBound,
    S::Future: Send,
    M: Middleware<S> + Send + Clone + 'static + Sync,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    /// Reports readiness of the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Dispatches `req` to the middleware together with the readied inner service.
    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let middleware = self.middleware.clone();
        // `poll_ready` was driven on `self.inner`, not on a fresh clone. Giving the
        // middleware an un-polled clone violates the `Service` contract: services that
        // reserve capacity in `poll_ready` (e.g. buffers or concurrency limiters) may
        // panic or mis-account the reserved slot. So move the readied instance into the
        // future and keep the new clone for the next `poll_ready`/`call` cycle.
        let replacement = self.inner.clone();
        let inner = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move { middleware.call(req, inner).await })
    }
}
/// Reuses the wrapped service's name, so a `MiddlewareFor` can be used wherever tonic
/// requires a `NamedService` in place of the service it wraps.
impl<S: NamedService + ServiceBound, M: Middleware<S>> NamedService for MiddlewareFor<S, M> {
    const NAME: &'static str = <S as NamedService>::NAME;
}
/// Tower layer that wraps services in [`MiddlewareFor`], cloning its middleware value
/// into each service it wraps.
#[derive(Clone)]
pub struct MiddlewareLayer<M> {
    /// Middleware value that is cloned for every wrapped service.
    middleware: M,
}

impl<M> MiddlewareLayer<M> {
    /// Builds a layer around `middleware`.
    pub fn new(middleware: M) -> Self {
        Self { middleware }
    }
}
impl<S, M> Layer<S> for MiddlewareLayer<M>
where
S: ServiceBound,
M: Middleware<S> + Clone,
{
type Service = MiddlewareFor<S, M>;
fn layer(&self, inner: S) -> Self::Service {
MiddlewareFor::new(inner, self.middleware.clone())
}
}