rest/
chain.rs

1use crate::http::HandlerFunc;
2use crate::middleware::{Middleware, apply_middlewares, handler};
3
4/// Chain builder for middlewares.
5#[derive(Clone, Default)]
6pub struct Chain {
7    mws: Vec<Middleware>,
8}
9
10impl Chain {
11    /// Create empty chain.
12    /// Create an empty chain.
13    pub fn new() -> Self {
14        Self { mws: Vec::new() }
15    }
16
17    /// Append middleware to chain.
18    /// Append a middleware.
19    pub fn append(mut self, mw: Middleware) -> Self {
20        self.mws.push(mw);
21        self
22    }
23
24    /// Append middleware mutably.
25    /// Append by mutable reference.
26    pub fn append_mut(&mut self, mw: Middleware) {
27        self.mws.push(mw);
28    }
29
30    /// Wrap handler with chain middlewares.
31    pub fn then(self, handler: HandlerFunc) -> HandlerFunc {
32        apply_middlewares(handler, &self.mws)
33    }
34
35    /// Wrap async fn -> HandlerFunc and return HandlerFunc.
36    pub fn then_func<F, Fut>(self, f: F) -> HandlerFunc
37    where
38        F: Fn(http::Request<hyper::Body>) -> Fut + Send + Sync + 'static,
39        Fut: std::future::Future<Output = http::Response<hyper::Body>> + Send + 'static,
40    {
41        self.then(handler(f))
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};
    use hyper::{Body, Response};
    use tokio::runtime::Runtime;

    /// Builds a tokio runtime used to drive async code from a sync `#[test]`.
    fn rt() -> Runtime {
        Runtime::new().unwrap()
    }

    /// Handler that ignores its request and answers `200 OK` with an empty body.
    fn always_ok() -> HandlerFunc {
        handler(|_req: http::Request<Body>| async {
            Response::builder()
                .status(StatusCode::OK)
                .body(Body::empty())
                .unwrap()
        })
    }

    /// Builds an empty `GET /` request.
    fn get_root() -> http::Request<Body> {
        http::Request::builder()
            .method(Method::GET)
            .uri("/")
            .body(Body::empty())
            .unwrap()
    }

    #[test]
    fn chain_should_apply_middlewares_in_order() {
        rt().block_on(async {
            // Middleware that calls the inner handler, then tags the response.
            let tag_response = crate::middleware::middleware(|req, next| async move {
                let mut resp = next.call(req).await;
                resp.headers_mut().append("X-C", "a".parse().unwrap());
                resp
            });
            let wrapped = Chain::new().append(tag_response).then(always_ok());

            let resp = wrapped.call(get_root()).await;
            let tag = resp.headers().get("X-C").unwrap().to_str().unwrap();
            assert_eq!(tag, "a");
        });
    }
}