use std::task::{Context, Poll};
use crate::ServiceBound;
use async_trait::async_trait;
use futures_util::future::BoxFuture;
use tonic::body::Body;
use tonic::codegen::http::Request;
use tonic::codegen::Service;
use tonic::server::NamedService;
use tonic::Status;
use tower::Layer;
/// Asynchronous hook that inspects (and may rewrite) an incoming request
/// before it is handed to the wrapped service.
///
/// Returning `Ok(request)` lets the (possibly modified) request through;
/// returning `Err(status)` rejects it without calling the wrapped service.
#[async_trait]
pub trait RequestInterceptor {
    /// Examines `request` and either forwards it (`Ok`) or rejects it (`Err`).
    async fn intercept(&self, request: Request<Body>) -> Result<Request<Body>, Status>;
}
/// Service wrapper that runs `interceptor` on every request before handing it
/// to `inner`.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than
/// on the struct itself, so the type can be named and built generically.
#[derive(Clone, Debug)]
pub struct InterceptorFor<S, I> {
    /// The wrapped service that receives requests accepted by the interceptor.
    pub inner: S,
    /// The interceptor consulted before each call to `inner`.
    pub interceptor: I,
}
impl<S, I> InterceptorFor<S, I>
where
    I: RequestInterceptor,
{
    /// Wraps `inner` so that each request is first passed through `interceptor`.
    pub fn new(inner: S, interceptor: I) -> Self {
        Self { interceptor, inner }
    }
}
/// Runs the interceptor before delegating to the wrapped service.
///
/// If the interceptor returns `Err(status)`, the wrapped service is not
/// called. Instead the `Status` is converted with `Status::into_http` and
/// returned as a successful response.
impl<S, I> Service<Request<Body>> for InterceptorFor<S, I>
where
    S: ServiceBound,
    S::Future: Send,
    I: RequestInterceptor + Send + Clone + 'static + Sync,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<Body>) -> Self::Future {
        let interceptor = self.interceptor.clone();
        // `poll_ready` was driven on `self.inner`, not on a fresh clone. Calling
        // a clone would skip readiness for that instance (which matters for
        // services like buffers or concurrency limits). Move the ready instance
        // into the future and leave a fresh clone behind for the next
        // `poll_ready`/`call` cycle.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);
        Box::pin(async move {
            match interceptor.intercept(req).await {
                Ok(req) => inner.call(req).await,
                // Rejection: the ready `inner` is dropped unused, and the status
                // is returned as an HTTP response rather than a service error.
                Err(status) => Ok(status.into_http()),
            }
        })
    }
}
/// Reports the wrapped service's `NAME` unchanged, so the wrapper is
/// identified as the inner service.
impl<S, I> NamedService for InterceptorFor<S, I>
where
    I: RequestInterceptor,
    S: NamedService,
{
    const NAME: &'static str = <S as NamedService>::NAME;
}
/// Layer that wraps services in an [`InterceptorFor`], giving each wrapped
/// service its own clone of the configured interceptor.
#[derive(Clone, Debug)]
pub struct RequestInterceptorLayer<I> {
    /// Template interceptor; cloned into every service produced by the layer.
    interceptor: I,
}
impl<I> RequestInterceptorLayer<I> {
pub fn new(interceptor: I) -> Self {
RequestInterceptorLayer { interceptor }
}
}
impl<S, I> Layer<S> for RequestInterceptorLayer<I>
where
I: RequestInterceptor + Clone,
{
type Service = InterceptorFor<S, I>;
fn layer(&self, inner: S) -> Self::Service {
InterceptorFor::new(inner, self.interceptor.clone())
}
}