rest/middleware.rs

1pub mod auth;
2pub mod gzip;
3pub mod limit;
4pub mod rate;
5pub mod timeout;
6
7pub use crate::http::types::{BoxResponseFuture, HandlerFunc};
8use crate::router::Route;
9pub use gzip::gzip;
10use http::{Request, Response};
11use hyper::Body;
12pub use limit::max_bytes;
13pub use rate::{concurrency_limit, rate_limit};
14use std::future::Future;
15use std::slice;
16use std::sync::Arc;
17pub use timeout::timeout;
18
19/// Middleware: `(req, next) -> Response`.
20pub type Middleware = Arc<dyn Fn(Request<Body>, HandlerFunc) -> BoxResponseFuture + Send + Sync>;
21
22/// Convert closures into `HandlerFunc` (one-time boxing).
23pub trait IntoHandler {
24    fn into_handler(self) -> HandlerFunc;
25}
26
27impl<F, Fut> IntoHandler for F
28where
29    F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
30    Fut: Future<Output = Response<Body>> + Send + 'static,
31{
32    fn into_handler(self) -> HandlerFunc {
33        HandlerFunc::new(self)
34    }
35}
36
37impl IntoHandler for HandlerFunc {
38    fn into_handler(self) -> HandlerFunc {
39        self
40    }
41}
42
43/// Convert `(req) -> async Response` into `HandlerFunc` (one-time boxing).
44pub fn handler<F, Fut>(f: F) -> HandlerFunc
45where
46    F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
47    Fut: Future<Output = Response<Body>> + Send + 'static,
48{
49    HandlerFunc::new(f)
50}
51
52/// Convert `(req, next)` into `Middleware`.
53pub fn middleware<F, Fut>(f: F) -> Middleware
54where
55    F: Fn(Request<Body>, HandlerFunc) -> Fut + Send + Sync + 'static,
56    Fut: Future<Output = Response<Body>> + Send + 'static,
57{
58    let f = Arc::new(f);
59    Arc::new(move |req, next| {
60        let f = f.clone();
61        let next = next.clone();
62        Box::pin(async move { f(req, next).await })
63    })
64}
65
66/// Alias for shorter writing.
67pub fn mw<F, Fut>(f: F) -> Middleware
68where
69    F: Fn(Request<Body>, HandlerFunc) -> Fut + Send + Sync + 'static,
70    Fut: Future<Output = Response<Body>> + Send + 'static,
71{
72    middleware(f)
73}
74
75/// Apply middlewares once at registration time, keeping user-facing order.
76pub fn apply_middlewares(handler: HandlerFunc, middlewares: &[Middleware]) -> HandlerFunc {
77    middlewares.iter().rev().fold(handler, |next, mw| {
78        let mw = mw.clone();
79        HandlerFunc::new(move |req| {
80            let mw = mw.clone();
81            let next = next.clone();
82            Box::pin(async move { mw(req, next).await })
83        })
84    })
85}
86
87/// Apply single middleware to routes.
88pub fn with_middleware<R>(middleware: Middleware, routes: R) -> Vec<Route>
89where
90    R: IntoIterator<Item = Route>,
91{
92    routes
93        .into_iter()
94        .map(|route| {
95            let handler = apply_middlewares(route.handler.clone(), slice::from_ref(&middleware));
96            Route { handler, ..route }
97        })
98        .collect()
99}
100
101/// Apply multiple middlewares to routes.
102pub fn with_middlewares<R, I>(middlewares: I, routes: R) -> Vec<Route>
103where
104    I: IntoIterator<Item = Middleware>,
105    R: IntoIterator<Item = Route>,
106{
107    let collected: Vec<Middleware> = middlewares.into_iter().collect();
108    routes
109        .into_iter()
110        .map(|route| {
111            let handler = apply_middlewares(route.handler.clone(), &collected);
112            Route { handler, ..route }
113        })
114        .collect()
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};
    use tokio::runtime::Runtime;

    fn runtime() -> Runtime {
        Runtime::new().unwrap()
    }

    /// Handler that always answers `200 OK` with an empty body.
    fn ok_handler() -> impl IntoHandler {
        |_req: Request<Body>| async {
            Response::builder()
                .status(StatusCode::OK)
                .body(Body::empty())
                .unwrap()
        }
    }

    /// Build an empty-bodied request for `method` and `uri`.
    fn request(method: Method, uri: &str) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Drive `handler` on `req` to completion and return its response.
    fn call(handler: HandlerFunc, req: Request<Body>) -> Response<Body> {
        runtime().block_on(handler.call(req))
    }

    /// All values of header `name`, in the order they were appended.
    fn header_values(resp: &Response<Body>, name: &str) -> Vec<String> {
        resp.headers()
            .get_all(name)
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect()
    }

    #[test]
    fn apply_middlewares_should_follow_order() {
        let h = IntoHandler::into_handler(ok_handler());
        let m1 = middleware(|req, next| async move {
            let mut resp = next.call(req).await;
            resp.headers_mut().append("X-Order", "m1".parse().unwrap());
            resp
        });
        let m2 = middleware(|req, next| async move {
            let mut resp = next.call(req).await;
            resp.headers_mut().append("X-Order", "m2".parse().unwrap());
            resp
        });

        let wrapped = apply_middlewares(h, &[m1, m2]);
        let resp = call(wrapped, request(Method::GET, "/"));

        // `m1` is the outermost layer, so it appends after `m2` has finished.
        // Comparing the whole list also catches a layer being applied twice.
        assert_eq!(header_values(&resp, "X-Order"), ["m2", "m1"]);
    }

    #[test]
    fn with_middleware_should_wrap_handlers() {
        let routes = vec![Route::new(Method::GET, "/", ok_handler())];
        let wrapped = with_middleware(
            middleware(|req, next| async move {
                let mut resp = next.call(req).await;
                resp.headers_mut().insert("X-Test", "1".parse().unwrap());
                resp
            }),
            routes,
        );

        assert_eq!(wrapped.len(), 1);
        let resp = call(wrapped[0].handler.clone(), request(Method::GET, "/"));
        // The inner handler must still have run.
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(header_values(&resp, "X-Test"), ["1"]);
    }

    #[test]
    fn middleware_fn_should_allow_short_circuit() {
        let routes = vec![Route::new(Method::GET, "/", ok_handler())];
        let wrapped = with_middleware(
            middleware(|_req, _next| async {
                Response::builder()
                    .status(StatusCode::FORBIDDEN)
                    .body(Body::empty())
                    .unwrap()
            }),
            routes,
        );

        assert_eq!(wrapped.len(), 1);
        let resp = call(wrapped[0].handler.clone(), request(Method::GET, "/"));
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn with_middlewares_should_follow_go_zero_order() {
        let routes = vec![Route::new(Method::POST, "/test", ok_handler())];
        let wrapped = with_middlewares(
            vec![
                middleware(|req, next| async move {
                    let mut resp = next.call(req).await;
                    resp.headers_mut().append("X-Seq", "a".parse().unwrap());
                    resp
                }),
                middleware(|req, next| async move {
                    let mut resp = next.call(req).await;
                    resp.headers_mut().append("X-Seq", "b".parse().unwrap());
                    resp
                }),
            ],
            routes,
        );

        assert_eq!(wrapped.len(), 1);
        let resp = call(wrapped[0].handler.clone(), request(Method::POST, "/test"));
        // The first-listed middleware is outermost, so "a" is appended last.
        assert_eq!(header_values(&resp, "X-Seq"), ["b", "a"]);
    }
}