// conduit_middleware/lib.rs
#![warn(rust_2018_idioms)]
use conduit::{BoxError, Handler, RequestExt};

4pub type BeforeResult = Result<(), BoxError>;
5pub type AfterResult = conduit::HandlerResult;
6
7pub trait Middleware: Send + Sync + 'static {
8 fn before(&self, _: &mut dyn RequestExt) -> BeforeResult {
9 Ok(())
10 }
11
12 fn after(&self, _: &mut dyn RequestExt, res: AfterResult) -> AfterResult {
13 res
14 }
15}
16
17pub trait AroundMiddleware: Handler {
18 fn with_handler(&mut self, handler: Box<dyn Handler>);
19}
20
21pub struct MiddlewareBuilder {
22 middlewares: Vec<Box<dyn Middleware>>,
23 handler: Option<Box<dyn Handler>>,
24}
25
26impl MiddlewareBuilder {
27 pub fn new<H: Handler>(handler: H) -> MiddlewareBuilder {
28 MiddlewareBuilder {
29 middlewares: vec![],
30 handler: Some(Box::new(handler) as Box<dyn Handler>),
31 }
32 }
33
34 pub fn add<M: Middleware>(&mut self, middleware: M) {
35 self.middlewares
36 .push(Box::new(middleware) as Box<dyn Middleware>);
37 }
38
39 pub fn around<M: AroundMiddleware>(&mut self, mut middleware: M) {
40 let handler = self.handler.take().unwrap();
41 middleware.with_handler(handler);
42 self.handler = Some(Box::new(middleware) as Box<dyn Handler>);
43 }
44}
45
46impl Handler for MiddlewareBuilder {
47 fn call(&self, req: &mut dyn RequestExt) -> AfterResult {
48 let mut error = None;
49
50 for (i, middleware) in self.middlewares.iter().enumerate() {
51 match middleware.before(req) {
52 Ok(_) => (),
53 Err(err) => {
54 error = Some((err, i));
55 break;
56 }
57 }
58 }
59
60 match error {
61 Some((err, i)) => {
62 let middlewares = &self.middlewares[..i];
63 run_afters(middlewares, req, Err(err))
64 }
65 None => {
66 let res = { self.handler.as_ref().unwrap().call(req) };
67 let middlewares = &self.middlewares;
68
69 run_afters(middlewares, req, res)
70 }
71 }
72 }
73}
74
75fn run_afters(
76 middleware: &[Box<dyn Middleware>],
77 req: &mut dyn RequestExt,
78 res: AfterResult,
79) -> AfterResult {
80 middleware
81 .iter()
82 .rev()
83 .fold(res, |res, m| m.after(req, res))
84}
85
#[cfg(test)]
mod tests {
    use super::{AfterResult, AroundMiddleware, BeforeResult, Middleware, MiddlewareBuilder};

    use std::any::Any;
    use std::io;
    use std::io::prelude::*;
    use std::net::SocketAddr;

    use conduit_test::ResponseExt;

    use conduit::{
        box_error, Body, Extensions, Handler, HeaderMap, Host, Method, RequestExt, Response,
        Scheme, StatusCode, Version,
    };

    /// Minimal in-memory request used to drive `MiddlewareBuilder` in tests.
    ///
    /// Only the method, path and extension map carry real data. Every other
    /// accessor is `unimplemented!()`, so a test that reaches for one panics.
    struct RequestSentinel {
        path: String,
        extensions: Extensions,
        method: Method,
    }

    impl RequestSentinel {
        fn new(method: Method, path: &'static str) -> RequestSentinel {
            RequestSentinel {
                path: path.to_string(),
                extensions: Extensions::new(),
                method,
            }
        }
    }

    impl conduit::RequestExt for RequestSentinel {
        fn http_version(&self) -> Version {
            unimplemented!()
        }
        fn method(&self) -> &Method {
            &self.method
        }
        fn scheme(&self) -> Scheme {
            unimplemented!()
        }
        fn host(&self) -> Host<'_> {
            unimplemented!()
        }
        fn virtual_root(&self) -> Option<&str> {
            unimplemented!()
        }
        fn path(&self) -> &str {
            &self.path
        }
        fn path_mut(&mut self) -> &mut String {
            &mut self.path
        }
        fn query_string(&self) -> Option<&str> {
            unimplemented!()
        }
        fn remote_addr(&self) -> SocketAddr {
            unimplemented!()
        }
        fn content_length(&self) -> Option<u64> {
            unimplemented!()
        }
        fn headers(&self) -> &HeaderMap {
            unimplemented!()
        }
        fn body(&mut self) -> &mut dyn Read {
            unimplemented!()
        }
        fn extensions(&self) -> &Extensions {
            &self.extensions
        }
        fn mut_extensions(&mut self) -> &mut Extensions {
            &mut self.extensions
        }
    }

    /// `before` hook that stores the `String` "hello" in the request extensions.
    struct MyMiddleware;

    impl Middleware for MyMiddleware {
        fn before(&self, req: &mut dyn RequestExt) -> BeforeResult {
            req.mut_extensions().insert("hello".to_string());
            Ok(())
        }
    }

    /// `after` hook that turns any error into a 500 response whose body is the
    /// error's `Display` text; successful results pass through unchanged.
    struct ErrorRecovery;

    impl Middleware for ErrorRecovery {
        fn after(&self, _: &mut dyn RequestExt, res: AfterResult) -> AfterResult {
            res.or_else(|e| {
                let e = e.to_string().into_bytes();
                Response::builder()
                    .status(StatusCode::INTERNAL_SERVER_ERROR)
                    .body(Body::from_vec(e))
                    .map_err(box_error)
            })
        }
    }

    /// `before` hook that always fails with an I/O error.
    struct ProducesError;

    impl Middleware for ProducesError {
        fn before(&self, _: &mut dyn RequestExt) -> BeforeResult {
            Err(Box::new(io::Error::new(io::ErrorKind::Other, "")))
        }
    }

    /// `after` hook that discards the incoming result and returns a fresh,
    /// successful empty response. If it ever ran in `test_error_recovery`, the
    /// error would be replaced and the 500-status assertion would fail.
    struct NotReached;

    impl Middleware for NotReached {
        fn after(&self, _: &mut dyn RequestExt, _: AfterResult) -> AfterResult {
            Response::builder().body(Body::empty()).map_err(box_error)
        }
    }

    /// Marker stored by `MyAroundMiddleware`.
    ///
    /// It has its own type so `middle_handler` can prove the around-middleware
    /// actually ran. Otherwise the handler would just re-read the `String` that
    /// `MyMiddleware` inserted.
    #[derive(Clone)]
    struct AroundMarker(&'static str);

    /// Around-middleware that tags the request with an `AroundMarker` and then
    /// delegates to the wrapped handler.
    struct MyAroundMiddleware {
        handler: Option<Box<dyn Handler>>,
    }

    impl MyAroundMiddleware {
        fn new() -> MyAroundMiddleware {
            MyAroundMiddleware { handler: None }
        }
    }

    impl Middleware for MyAroundMiddleware {}

    impl AroundMiddleware for MyAroundMiddleware {
        fn with_handler(&mut self, handler: Box<dyn Handler>) {
            self.handler = Some(handler)
        }
    }

    impl Handler for MyAroundMiddleware {
        fn call(&self, req: &mut dyn RequestExt) -> AfterResult {
            req.mut_extensions().insert(AroundMarker("middle"));
            self.handler.as_ref().unwrap().call(req)
        }
    }

    /// Fetches the extension of type `T`, panicking if it was never inserted.
    fn get_extension<T: Any + Send + Sync>(req: &dyn RequestExt) -> &T {
        req.extensions().get::<T>().unwrap()
    }

    /// Builds a default response whose body is `string`.
    fn response(string: String) -> Response<Body> {
        Response::builder()
            .body(Body::from_vec(string.into_bytes()))
            .unwrap()
    }

    /// Echoes the `String` extension back as the response body.
    fn handler(req: &mut dyn RequestExt) -> io::Result<Response<Body>> {
        let hello = get_extension::<String>(req);
        Ok(response(hello.clone()))
    }

    /// Always fails with "Error in handler".
    fn error_handler(_: &mut dyn RequestExt) -> io::Result<Response<Body>> {
        Err(io::Error::new(io::ErrorKind::Other, "Error in handler"))
    }

    /// Responds with the `String` extension (set by `MyMiddleware`) followed by
    /// the `AroundMarker` (set by `MyAroundMiddleware`).
    fn middle_handler(req: &mut dyn RequestExt) -> io::Result<Response<Body>> {
        let hello = get_extension::<String>(req);
        let middle = get_extension::<AroundMarker>(req);

        Ok(response(format!("{} {}", hello, middle.0)))
    }

    #[test]
    fn test_simple_middleware() {
        let mut builder = MiddlewareBuilder::new(handler);
        builder.add(MyMiddleware);

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("No response");

        assert_eq!(*res.into_cow(), b"hello"[..]);
    }

    #[test]
    fn test_error_recovery() {
        let mut builder = MiddlewareBuilder::new(handler);
        builder.add(ErrorRecovery);
        builder.add(ProducesError);
        builder.add(NotReached);

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("Error not handled");

        assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_error_recovery_in_handlers() {
        let mut builder = MiddlewareBuilder::new(error_handler);
        builder.add(ErrorRecovery);

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("Error not handled");

        assert_eq!(*res.into_cow(), b"Error in handler"[..]);
    }

    #[test]
    fn test_around_middleware() {
        let mut builder = MiddlewareBuilder::new(middle_handler);
        builder.add(MyMiddleware);
        builder.around(MyAroundMiddleware::new());

        let mut req = RequestSentinel::new(Method::GET, "/");
        let res = builder.call(&mut req).expect("No response");

        // "hello" comes from MyMiddleware's before hook; "middle" can only be
        // present if the around-middleware wrapping the handler actually ran.
        assert_eq!(*res.into_cow(), b"hello middle"[..]);
    }
}