1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::{Error, Result};
14use crate::http::{response_from_error, Request, Response};
15
/// A heap-allocated, pinned, `Send` future producing `T` and borrowing for `'a`.
///
/// Used as the common return type of type-erased handlers and middleware so
/// that closures with different concrete future types can share one signature.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
17
18pub type HandlerFn =
19 Arc<dyn Fn(Request) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
20
21pub type MiddlewareFn =
22 Arc<dyn Fn(Request, Next) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
23
24pub struct Next {
25 chain: Vec<MiddlewareFn>,
26 handler: HandlerFn,
27 index: usize,
28}
29
30impl Next {
31 pub fn run(mut self, req: Request) -> BoxFuture<'static, Result<Response>> {
32 Box::pin(async move {
33 if self.index < self.chain.len() {
34 let mw = self.chain[self.index].clone();
35 self.index += 1;
36 mw(req, self).await
37 } else {
38 (self.handler)(req).await
39 }
40 })
41 }
42}
43
44struct Route {
45 method: Method,
46 segments: Vec<Segment>,
47 handler: HandlerFn,
48}
49
/// One `/`-separated piece of a route pattern.
enum Segment {
    /// Literal text that the request segment must equal exactly.
    Static(String),
    /// `:name` placeholder that matches any single segment; holds the name
    /// without the leading `:`.
    Param(String),
}
54
55pub struct Router {
56 routes: Vec<Route>,
57 middleware: Vec<MiddlewareFn>,
58}
59
60impl Default for Router {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66impl Router {
67 pub fn new() -> Self {
68 Self {
69 routes: Vec::new(),
70 middleware: Vec::new(),
71 }
72 }
73
74 pub fn middleware<F, Fut>(mut self, mw: F) -> Self
75 where
76 F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
77 Fut: Future<Output = Result<Response>> + Send + 'static,
78 {
79 let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(mw(req, next)));
80 self.middleware.push(wrapped);
81 self
82 }
83
84 pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
85 where
86 F: Fn(Request) -> Fut + Send + Sync + 'static,
87 Fut: Future<Output = Result<Response>> + Send + 'static,
88 {
89 self.route(Method::GET, path, handler)
90 }
91
92 pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
93 where
94 F: Fn(Request) -> Fut + Send + Sync + 'static,
95 Fut: Future<Output = Result<Response>> + Send + 'static,
96 {
97 self.route(Method::POST, path, handler)
98 }
99
100 pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
101 where
102 F: Fn(Request) -> Fut + Send + Sync + 'static,
103 Fut: Future<Output = Result<Response>> + Send + 'static,
104 {
105 let segments = parse_path(path);
106 let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
107 self.routes.push(Route {
108 method,
109 segments,
110 handler,
111 });
112 self
113 }
114
115 fn find(&self, method: &Method, path: &str) -> MatchResult {
117 let mut path_segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
118 if path_segs.len() > 1 && path_segs.last() == Some(&"") {
122 path_segs.pop();
123 }
124 let mut path_matched = false;
125
126 for route in &self.routes {
127 if !segments_match(&route.segments, &path_segs) {
128 continue;
129 }
130 path_matched = true;
131 if route.method == *method {
132 let params = extract_params(&route.segments, &path_segs);
133 return MatchResult::Ok {
134 handler: route.handler.clone(),
135 params,
136 };
137 }
138 }
139
140 if path_matched {
141 MatchResult::MethodNotAllowed
142 } else {
143 MatchResult::NotFound
144 }
145 }
146
147 pub async fn dispatch(&self, mut req: Request) -> Response {
148 let matched = self.find(req.method(), req.path());
149
150 let outcome = match matched {
151 MatchResult::Ok { handler, params } => {
152 req.set_params(params);
153 let next = Next {
154 chain: self.middleware.clone(),
155 handler,
156 index: 0,
157 };
158 next.run(req).await
159 }
160 MatchResult::NotFound => Err(Error::NotFound(format!("no route for {}", req.path()))),
161 MatchResult::MethodNotAllowed => Err(Error::MethodNotAllowed(format!(
162 "{} not allowed",
163 req.method()
164 ))),
165 };
166
167 match outcome {
168 Ok(resp) => resp,
169 Err(err) => response_from_error(&err),
170 }
171 }
172}
173
174enum MatchResult {
175 Ok {
176 handler: HandlerFn,
177 params: std::collections::HashMap<String, String>,
178 },
179 NotFound,
180 MethodNotAllowed,
181}
182
183fn parse_path(path: &str) -> Vec<Segment> {
184 path.trim_start_matches('/')
185 .split('/')
186 .map(|seg| {
187 if let Some(name) = seg.strip_prefix(':') {
188 Segment::Param(name.to_string())
189 } else {
190 Segment::Static(seg.to_string())
191 }
192 })
193 .collect()
194}
195
196fn segments_match(route: &[Segment], path: &[&str]) -> bool {
197 if route.len() != path.len() {
198 return false;
199 }
200 for (r, p) in route.iter().zip(path.iter()) {
201 match r {
202 Segment::Static(s) if s != p => return false,
203 _ => {}
204 }
205 }
206 true
207}
208
209fn extract_params(route: &[Segment], path: &[&str]) -> std::collections::HashMap<String, String> {
210 let mut out = std::collections::HashMap::new();
211 for (r, p) in route.iter().zip(path.iter()) {
212 if let Segment::Param(name) = r {
213 out.insert(name.clone(), (*p).to_string());
214 }
215 }
216 out
217}
218
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with an empty query string, default headers and no body.
    fn request(method: Method, path: &'static str) -> Request {
        Request::new(
            method,
            path.into(),
            String::new(),
            Default::default(),
            bytes::Bytes::new(),
        )
    }

    #[tokio::test]
    async fn matches_static_path() {
        let router = Router::new().get("/hello", |_req| async { Ok(Response::text("hi")) });

        let resp = router.dispatch(request(Method::GET, "/hello")).await;

        assert_eq!(resp.status.as_u16(), 200);
    }

    #[tokio::test]
    async fn captures_param() {
        let router = Router::new().get("/users/:id", |req| async move {
            let id = req.param("id").unwrap_or("").to_string();
            Ok(Response::text(id))
        });

        let resp = router.dispatch(request(Method::GET, "/users/42")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"42");
    }

    #[tokio::test]
    async fn distinguishes_404_from_405() {
        let router = Router::new().get("/things", |_| async { Ok(Response::text("ok")) });

        let wrong_method = router.dispatch(request(Method::POST, "/things")).await;
        assert_eq!(wrong_method.status.as_u16(), 405);

        let unknown_path = router.dispatch(request(Method::GET, "/nope")).await;
        assert_eq!(unknown_path.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn trailing_slash_is_normalised_for_static_and_param_routes() {
        let router = Router::new()
            .get("/admin/:name", |req| async move {
                let name = req.param("name").unwrap_or("").to_string();
                Ok(Response::text(name))
            })
            .get("/admin/:name/:id/edit", |req| async move {
                Ok(Response::text(format!(
                    "{}/{}",
                    req.param("name").unwrap_or(""),
                    req.param("id").unwrap_or(""),
                )))
            });

        for path in ["/admin/posts", "/admin/posts/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts", "GET {path} body");
        }

        for path in ["/admin/posts/1/edit", "/admin/posts/1/edit/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts/1", "GET {path} body");
        }
    }

    #[tokio::test]
    async fn root_path_still_matches_after_trailing_slash_normalisation() {
        // `/` splits into a single empty segment. Normalisation must keep that
        // segment rather than strip it as a trailing slash.
        let router = Router::new().get("/", |_| async { Ok(Response::text("home")) });

        let resp = router.dispatch(request(Method::GET, "/")).await;

        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"home");
    }
}