1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use std::task::{Context, Poll};

use crate::ServiceBound;
use async_trait::async_trait;
use futures_util::future::BoxFuture;
use tonic::body::BoxBody;
use tonic::codegen::http::Request;
use tonic::codegen::http::Response;
use tonic::codegen::Service;
use tonic::server::NamedService;
use tower::Layer;

/// The `Middleware` trait defines a generic interface for middleware components
/// in a grpc service chain.
/// Implementors of this trait can modify, observe, or otherwise interact with requests and
/// responses in the service pipeline
///
/// If you need just intercept requests, pls can [RequestInterceptor]
///
/// # Type Parameters
///
/// * `S`: A service bound that defines the requirements for the service being wrapped by
/// the middleware.
///
/// See [examples on GitHub](https://github.com/teimuraz/tonic-middleware/tree/main/example)
#[async_trait]
pub trait Middleware<S>
where
    S: ServiceBound,
{
    /// Processes an incoming request and forwards it to the given service.
    ///
    /// Implementations may perform operations before or after forwarding the request,
    /// such as logging, metrics collection, or request modification.
    ///
    /// # Parameters
    ///
    /// * `req`: The incoming request to process.
    /// * `service`: The service to forward the processed request to.
    ///
    /// # Returns
    ///
    /// A `Result` containing the response from the service or an error if one occurred
    /// during processing.
    async fn call(&self, req: Request<BoxBody>, service: S) -> Result<Response<BoxBody>, S::Error>;
}

/// `MiddlewareFor` is a service wrapper that pairs a middleware with its target service.
///
/// # Type Parameters
///
/// * `S`: The service that this middleware is wrapping.
/// * `M`: The middleware that is being applied to the service.
#[derive(Clone)]
pub struct MiddlewareFor<S, M>
where
    S: ServiceBound,
    M: Middleware<S>,
{
    pub inner: S,
    pub middleware: M,
}

impl<S, M> MiddlewareFor<S, M>
where
    S: ServiceBound,
    M: Middleware<S>,
{
    /// Constructs a new `MiddlewareFor` with the given service and middleware.
    ///
    /// # Parameters
    ///
    /// * `inner`: The service that this middleware is wrapping.
    /// * `middleware`: The middleware that is being applied to the service.
    pub fn new(inner: S, middleware: M) -> Self {
        MiddlewareFor { inner, middleware }
    }
}

impl<S, M> Service<Request<BoxBody>> for MiddlewareFor<S, M>
where
    S: ServiceBound,
    S::Future: Send,
    M: Middleware<S> + Send + Clone + 'static + Sync,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: Request<BoxBody>) -> Self::Future {
        let middleware = self.middleware.clone();
        let inner = self.inner.clone();
        Box::pin(async move { middleware.call(req, inner).await })
    }
}

impl<S, M> NamedService for MiddlewareFor<S, M>
where
    S: NamedService + ServiceBound,
    M: Middleware<S>,
{
    const NAME: &'static str = S::NAME;
}

/// `MiddlewareLayer` provides a way to wrap services with a specific middleware using
/// the tower `Layer` trait
#[derive(Clone)]
pub struct MiddlewareLayer<M> {
    middleware: M,
}

impl<M> MiddlewareLayer<M> {
    /// Creates a new `MiddlewareLayer` with the given middleware.
    ///
    /// # Parameters
    ///
    /// * `middleware`: The middleware to apply to services.
    pub fn new(middleware: M) -> Self {
        MiddlewareLayer { middleware }
    }
}

impl<S, M> Layer<S> for MiddlewareLayer<M>
where
    S: ServiceBound,
    M: Middleware<S> + Clone,
{
    type Service = MiddlewareFor<S, M>;

    fn layer(&self, inner: S) -> Self::Service {
        MiddlewareFor::new(inner, self.middleware.clone())
    }
}