1pub mod auth;
2pub mod gzip;
3pub mod limit;
4pub mod rate;
5pub mod timeout;
6
7pub use crate::http::types::{BoxResponseFuture, HandlerFunc};
8use crate::router::Route;
9pub use gzip::gzip;
10use http::{Request, Response};
11use hyper::Body;
12pub use limit::max_bytes;
13pub use rate::{concurrency_limit, rate_limit};
14use std::future::Future;
15use std::slice;
16use std::sync::Arc;
17pub use timeout::timeout;
18
19pub type Middleware = Arc<dyn Fn(Request<Body>, HandlerFunc) -> BoxResponseFuture + Send + Sync>;
21
22pub trait IntoHandler {
24 fn into_handler(self) -> HandlerFunc;
25}
26
27impl<F, Fut> IntoHandler for F
28where
29 F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
30 Fut: Future<Output = Response<Body>> + Send + 'static,
31{
32 fn into_handler(self) -> HandlerFunc {
33 HandlerFunc::new(self)
34 }
35}
36
37impl IntoHandler for HandlerFunc {
38 fn into_handler(self) -> HandlerFunc {
39 self
40 }
41}
42
43pub fn handler<F, Fut>(f: F) -> HandlerFunc
45where
46 F: Fn(Request<Body>) -> Fut + Send + Sync + 'static,
47 Fut: Future<Output = Response<Body>> + Send + 'static,
48{
49 HandlerFunc::new(f)
50}
51
52pub fn middleware<F, Fut>(f: F) -> Middleware
54where
55 F: Fn(Request<Body>, HandlerFunc) -> Fut + Send + Sync + 'static,
56 Fut: Future<Output = Response<Body>> + Send + 'static,
57{
58 let f = Arc::new(f);
59 Arc::new(move |req, next| {
60 let f = f.clone();
61 let next = next.clone();
62 Box::pin(async move { f(req, next).await })
63 })
64}
65
66pub fn mw<F, Fut>(f: F) -> Middleware
68where
69 F: Fn(Request<Body>, HandlerFunc) -> Fut + Send + Sync + 'static,
70 Fut: Future<Output = Response<Body>> + Send + 'static,
71{
72 middleware(f)
73}
74
75pub fn apply_middlewares(handler: HandlerFunc, middlewares: &[Middleware]) -> HandlerFunc {
77 middlewares.iter().rev().fold(handler, |next, mw| {
78 let mw = mw.clone();
79 HandlerFunc::new(move |req| {
80 let mw = mw.clone();
81 let next = next.clone();
82 Box::pin(async move { mw(req, next).await })
83 })
84 })
85}
86
87pub fn with_middleware<R>(middleware: Middleware, routes: R) -> Vec<Route>
89where
90 R: IntoIterator<Item = Route>,
91{
92 routes
93 .into_iter()
94 .map(|route| {
95 let handler = apply_middlewares(route.handler.clone(), slice::from_ref(&middleware));
96 Route { handler, ..route }
97 })
98 .collect()
99}
100
101pub fn with_middlewares<R, I>(middlewares: I, routes: R) -> Vec<Route>
103where
104 I: IntoIterator<Item = Middleware>,
105 R: IntoIterator<Item = Route>,
106{
107 let collected: Vec<Middleware> = middlewares.into_iter().collect();
108 routes
109 .into_iter()
110 .map(|route| {
111 let handler = apply_middlewares(route.handler.clone(), &collected);
112 Route { handler, ..route }
113 })
114 .collect()
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Method, StatusCode};
    use tokio::runtime::Runtime;

    /// Creates a fresh Tokio runtime for driving one future to completion.
    fn runtime() -> Runtime {
        Runtime::new().unwrap()
    }

    /// Terminal handler that always answers `200 OK` with an empty body.
    fn ok_handler() -> impl IntoHandler {
        |_req: Request<Body>| async {
            let mut resp = Response::new(Body::empty());
            *resp.status_mut() = StatusCode::OK;
            resp
        }
    }

    /// Builds an empty-bodied request for the given method and URI.
    fn empty_request(method: Method, uri: &str) -> Request<Body> {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Runs `handler` against `req` on a fresh runtime and returns the response.
    fn dispatch(handler: HandlerFunc, req: Request<Body>) -> Response<Body> {
        runtime().block_on(handler.call(req))
    }

    /// Collects every value stored under header `name`, in stored order.
    fn header_values(resp: &Response<Body>, name: &str) -> Vec<String> {
        resp.headers()
            .get_all(name)
            .iter()
            .map(|v| v.to_str().unwrap().to_string())
            .collect()
    }

    /// Middleware that calls the next handler and then appends `value` under
    /// header `name` on the response.
    fn tagging(name: &'static str, value: &'static str) -> Middleware {
        middleware(move |req, next| async move {
            let mut resp = next.call(req).await;
            resp.headers_mut().append(name, value.parse().unwrap());
            resp
        })
    }

    #[test]
    fn apply_middlewares_should_follow_order() {
        let chain = [tagging("X-Order", "m1"), tagging("X-Order", "m2")];
        let wrapped = apply_middlewares(IntoHandler::into_handler(ok_handler()), &chain);

        let resp = dispatch(wrapped, empty_request(Method::GET, "/"));

        // The innermost middleware (m2) touches the response first.
        let seen = header_values(&resp, "X-Order");
        assert_eq!(seen[0], "m2");
        assert_eq!(seen[1], "m1");
    }

    #[test]
    fn with_middleware_should_wrap_handlers() {
        let stamp = middleware(|req, next| async move {
            let mut resp = next.call(req).await;
            resp.headers_mut().insert("X-Test", "1".parse().unwrap());
            resp
        });
        let wrapped = with_middleware(stamp, vec![Route::new(Method::GET, "/", ok_handler())]);

        assert_eq!(wrapped.len(), 1);
        let resp = dispatch(wrapped[0].handler.clone(), empty_request(Method::GET, "/"));
        assert_eq!(resp.headers().get("X-Test").unwrap().to_str().unwrap(), "1");
    }

    #[test]
    fn middleware_fn_should_allow_short_circuit() {
        // This middleware never calls `next`, so the OK handler must not run.
        let deny = middleware(|_req, _next| async {
            let mut resp = Response::new(Body::empty());
            *resp.status_mut() = StatusCode::FORBIDDEN;
            resp
        });
        let wrapped = with_middleware(deny, vec![Route::new(Method::GET, "/", ok_handler())]);

        let resp = dispatch(wrapped[0].handler.clone(), empty_request(Method::GET, "/"));
        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn with_middlewares_should_follow_go_zero_order() {
        let wrapped = with_middlewares(
            vec![tagging("X-Seq", "a"), tagging("X-Seq", "b")],
            vec![Route::new(Method::POST, "/test", ok_handler())],
        );

        let resp = dispatch(
            wrapped[0].handler.clone(),
            empty_request(Method::POST, "/test"),
        );

        let seen = header_values(&resp, "X-Seq");
        assert_eq!(seen[0], "b");
        assert_eq!(seen[1], "a");
    }
}