conduit_middleware/
lib.rs

1#![warn(rust_2018_idioms)]
2use conduit::{BoxError, Handler, RequestExt};
3
/// Result of a [`Middleware::before`] hook: `Ok(())` lets request processing continue, an error aborts it.
pub type BeforeResult = Result<(), BoxError>;
/// Result produced by a handler call and threaded through every [`Middleware::after`] hook.
pub type AfterResult = conduit::HandlerResult;
6
7pub trait Middleware: Send + Sync + 'static {
8    fn before(&self, _: &mut dyn RequestExt) -> BeforeResult {
9        Ok(())
10    }
11
12    fn after(&self, _: &mut dyn RequestExt, res: AfterResult) -> AfterResult {
13        res
14    }
15}
16
/// A [`Handler`] that wraps another handler.
///
/// `MiddlewareBuilder::around` passes the builder's current handler to
/// `with_handler` and then installs the implementor as the new handler.
pub trait AroundMiddleware: Handler {
    /// Receives the handler this middleware wraps.
    fn with_handler(&mut self, handler: Box<dyn Handler>);
}
20
/// A handler combined with an ordered list of [`Middleware`] hooks.
pub struct MiddlewareBuilder {
    // Hooks in the order they were registered through `add`.
    middlewares: Vec<Box<dyn Middleware>>,
    // The handler to call once all `before` hooks succeed. `new` sets it to
    // `Some`; `around` takes it out and puts a wrapping handler back.
    handler: Option<Box<dyn Handler>>,
}
25
26impl MiddlewareBuilder {
27    pub fn new<H: Handler>(handler: H) -> MiddlewareBuilder {
28        MiddlewareBuilder {
29            middlewares: vec![],
30            handler: Some(Box::new(handler) as Box<dyn Handler>),
31        }
32    }
33
34    pub fn add<M: Middleware>(&mut self, middleware: M) {
35        self.middlewares
36            .push(Box::new(middleware) as Box<dyn Middleware>);
37    }
38
39    pub fn around<M: AroundMiddleware>(&mut self, mut middleware: M) {
40        let handler = self.handler.take().unwrap();
41        middleware.with_handler(handler);
42        self.handler = Some(Box::new(middleware) as Box<dyn Handler>);
43    }
44}
45
46impl Handler for MiddlewareBuilder {
47    fn call(&self, req: &mut dyn RequestExt) -> AfterResult {
48        let mut error = None;
49
50        for (i, middleware) in self.middlewares.iter().enumerate() {
51            match middleware.before(req) {
52                Ok(_) => (),
53                Err(err) => {
54                    error = Some((err, i));
55                    break;
56                }
57            }
58        }
59
60        match error {
61            Some((err, i)) => {
62                let middlewares = &self.middlewares[..i];
63                run_afters(middlewares, req, Err(err))
64            }
65            None => {
66                let res = { self.handler.as_ref().unwrap().call(req) };
67                let middlewares = &self.middlewares;
68
69                run_afters(middlewares, req, res)
70            }
71        }
72    }
73}
74
75fn run_afters(
76    middleware: &[Box<dyn Middleware>],
77    req: &mut dyn RequestExt,
78    res: AfterResult,
79) -> AfterResult {
80    middleware
81        .iter()
82        .rev()
83        .fold(res, |res, m| m.after(req, res))
84}
85
#[cfg(test)]
mod tests {
    use super::{AfterResult, AroundMiddleware, BeforeResult, Middleware, MiddlewareBuilder};

    use std::any::Any;
    use std::io;
    use std::io::prelude::*;
    use std::net::SocketAddr;

    use conduit_test::ResponseExt;

    use conduit::{
        box_error, Body, Extensions, Handler, HeaderMap, Host, Method, RequestExt, Response,
        Scheme, StatusCode, Version,
    };

    /// Request stub for the tests. Only the method, path and extensions are
    /// backed by data; every other accessor is `unimplemented!()`.
    struct RequestSentinel {
        path: String,
        extensions: Extensions,
        method: Method,
    }

    impl RequestSentinel {
        fn new(method: Method, path: &'static str) -> RequestSentinel {
            RequestSentinel {
                path: path.to_string(),
                extensions: Extensions::new(),
                method,
            }
        }
    }

    impl conduit::RequestExt for RequestSentinel {
        fn http_version(&self) -> Version {
            unimplemented!()
        }
        fn method(&self) -> &Method {
            &self.method
        }
        fn scheme(&self) -> Scheme {
            unimplemented!()
        }
        fn host(&self) -> Host<'_> {
            unimplemented!()
        }
        fn virtual_root(&self) -> Option<&str> {
            unimplemented!()
        }
        fn path(&self) -> &str {
            &self.path
        }
        fn path_mut(&mut self) -> &mut String {
            &mut self.path
        }
        fn query_string(&self) -> Option<&str> {
            unimplemented!()
        }
        fn remote_addr(&self) -> SocketAddr {
            unimplemented!()
        }
        fn content_length(&self) -> Option<u64> {
            unimplemented!()
        }
        fn headers(&self) -> &HeaderMap {
            unimplemented!()
        }
        fn body(&mut self) -> &mut dyn Read {
            unimplemented!()
        }
        fn extensions(&self) -> &Extensions {
            &self.extensions
        }
        fn mut_extensions(&mut self) -> &mut Extensions {
            &mut self.extensions
        }
    }

    /// Value stored by `MyAroundMiddleware`.
    ///
    /// It has its own type so it occupies a different extension slot than the
    /// `String` stored by `MyMiddleware`; a handler can then tell which of the
    /// two layers actually ran.
    #[derive(Clone)]
    struct Middle(String);

    /// Stores the `String` "hello" in the request extensions before the
    /// handler runs.
    struct MyMiddleware;

    impl Middleware for MyMiddleware {
        fn before(&self, req: &mut dyn RequestExt) -> BeforeResult {
            req.mut_extensions().insert("hello".to_string());
            Ok(())
        }
    }

    /// Turns an error result into a 500 response whose body is the error text.
    struct ErrorRecovery;

    impl Middleware for ErrorRecovery {
        fn after(&self, _: &mut dyn RequestExt, res: AfterResult) -> AfterResult {
            res.or_else(|e| {
                let e = e.to_string().into_bytes();
                Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(Body::from_vec(e))
                    .map_err(box_error)
            })
        }
    }

    /// Always fails in its `before` hook.
    struct ProducesError;

    impl Middleware for ProducesError {
        fn before(&self, _: &mut dyn RequestExt) -> BeforeResult {
            Err(Box::new(io::Error::new(io::ErrorKind::Other, "")))
        }
    }

    /// Replaces whatever result it receives with an empty response. Tests
    /// register it after a failing middleware to prove its hook is skipped.
    struct NotReached;

    impl Middleware for NotReached {
        fn after(&self, _: &mut dyn RequestExt, _: AfterResult) -> AfterResult {
            Response::builder().body(Body::empty()).map_err(box_error)
        }
    }

    /// Around-middleware that stores a `Middle` marker and then delegates to
    /// the wrapped handler.
    struct MyAroundMiddleware {
        handler: Option<Box<dyn Handler>>,
    }

    impl MyAroundMiddleware {
        fn new() -> MyAroundMiddleware {
            MyAroundMiddleware { handler: None }
        }
    }

    impl Middleware for MyAroundMiddleware {}

    impl AroundMiddleware for MyAroundMiddleware {
        fn with_handler(&mut self, handler: Box<dyn Handler>) {
            self.handler = Some(handler)
        }
    }

    impl Handler for MyAroundMiddleware {
        fn call(&self, req: &mut dyn RequestExt) -> AfterResult {
            req.mut_extensions().insert(Middle("middle".to_string()));
            self.handler.as_ref().unwrap().call(req)
        }
    }

    /// Returns the extension of type `T`; panics if the request has none.
    fn get_extension<T: Any + Send + Sync>(req: &dyn RequestExt) -> &T {
        req.extensions().get::<T>().unwrap()
    }

    /// Builds a response whose body is `string`.
    fn response(string: String) -> Response<Body> {
        Response::builder()
            .body(Body::from_vec(string.into_bytes()))
            .unwrap()
    }

    /// Echoes the `String` extension as the response body.
    fn handler(req: &mut dyn RequestExt) -> io::Result<Response<Body>> {
        let hello = get_extension::<String>(req);
        Ok(response(hello.clone()))
    }

    /// Always fails with "Error in handler".
    fn error_handler(_: &mut dyn RequestExt) -> io::Result<Response<Body>> {
        Err(io::Error::new(io::ErrorKind::Other, "Error in handler"))
    }

    /// Responds with the `String` extension and the `Middle` extension joined
    /// by a space, so the body shows that both layers ran.
    fn middle_handler(req: &mut dyn RequestExt) -> io::Result<Response<Body>> {
        let hello = get_extension::<String>(req);
        let middle = get_extension::<Middle>(req);

        Ok(response(format!("{} {}", hello, middle.0)))
    }

    #[test]
    fn test_simple_middleware() {
        let mut builder = MiddlewareBuilder::new(handler);
        builder.add(MyMiddleware);

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("No response");

        assert_eq!(*res.into_cow(), b"hello"[..]);
    }

    #[test]
    fn test_error_recovery() {
        let mut builder = MiddlewareBuilder::new(handler);
        builder.add(ErrorRecovery);
        builder.add(ProducesError);
        // The error raised by ProducesError is handed only to middleware
        // registered before it, so NotReached must never see it.
        builder.add(NotReached);

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("Error not handled");

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_error_recovery_in_handlers() {
        let mut builder = MiddlewareBuilder::new(error_handler);
        builder.add(ErrorRecovery);

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("Error not handled");

        assert_eq!(*res.into_cow(), b"Error in handler"[..]);
    }

    #[test]
    fn test_around_middleware() {
        let mut builder = MiddlewareBuilder::new(middle_handler);
        builder.add(MyMiddleware);
        builder.around(MyAroundMiddleware::new());

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("No response");

        // "hello" comes from MyMiddleware::before, "middle" from the around
        // layer; the handler panics if either extension is missing.
        assert_eq!(*res.into_cow(), b"hello middle"[..]);
    }
}