Skip to main content

rustio_admin/
router.rs

1//! A small, opinionated router.
2//!
3//! - Path segments starting with `:` are captured into `req.param(name)`.
4//! - Middleware is a chain of `async fn(Request, Next) -> Result<Response>`.
5//! - 404 vs 405 is distinguished (path matched but method didn't → 405).
6
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use hyper::Method;
12
13use crate::error::{Error, Result};
14use crate::http::{response_from_error, Request, Response};
15
// public:
/// A pinned, heap-allocated, `Send` future resolving to `T`.
///
/// `'a` bounds whatever the future borrows; the router itself only ever uses
/// `'static` futures so they can be stored and moved across `.await` points.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
18
19// public:
20pub type HandlerFn =
21    Arc<dyn Fn(Request) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
22
23// public:
24pub type MiddlewareFn =
25    Arc<dyn Fn(Request, Next) -> BoxFuture<'static, Result<Response>> + Send + Sync + 'static>;
26
27// public:
28pub struct Next {
29    chain: Vec<MiddlewareFn>,
30    handler: HandlerFn,
31    index: usize,
32}
33
34impl Next {
35    // public:
36    pub fn run(mut self, req: Request) -> BoxFuture<'static, Result<Response>> {
37        Box::pin(async move {
38            if self.index < self.chain.len() {
39                let mw = self.chain[self.index].clone();
40                self.index += 1;
41                mw(req, self).await
42            } else {
43                (self.handler)(req).await
44            }
45        })
46    }
47}
48
49struct Route {
50    method: Method,
51    segments: Vec<Segment>,
52    handler: HandlerFn,
53}
54
/// One `/`-separated piece of a route pattern.
enum Segment {
    /// Literal text that the request segment must equal exactly.
    Static(String),
    /// A `:name` placeholder; holds `name` without the leading `:`.
    Param(String),
}
59
60// public:
61pub struct Router {
62    routes: Vec<Route>,
63    middleware: Vec<MiddlewareFn>,
64}
65
66impl Default for Router {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl Router {
73    // public:
74    pub fn new() -> Self {
75        Self {
76            routes: Vec::new(),
77            middleware: Vec::new(),
78        }
79    }
80
81    // public:
82    pub fn middleware<F, Fut>(mut self, mw: F) -> Self
83    where
84        F: Fn(Request, Next) -> Fut + Send + Sync + 'static,
85        Fut: Future<Output = Result<Response>> + Send + 'static,
86    {
87        let wrapped: MiddlewareFn = Arc::new(move |req, next| Box::pin(mw(req, next)));
88        self.middleware.push(wrapped);
89        self
90    }
91
92    // public:
93    pub fn get<F, Fut>(self, path: &str, handler: F) -> Self
94    where
95        F: Fn(Request) -> Fut + Send + Sync + 'static,
96        Fut: Future<Output = Result<Response>> + Send + 'static,
97    {
98        self.route(Method::GET, path, handler)
99    }
100
101    // public:
102    pub fn post<F, Fut>(self, path: &str, handler: F) -> Self
103    where
104        F: Fn(Request) -> Fut + Send + Sync + 'static,
105        Fut: Future<Output = Result<Response>> + Send + 'static,
106    {
107        self.route(Method::POST, path, handler)
108    }
109
110    // public:
111    pub fn route<F, Fut>(mut self, method: Method, path: &str, handler: F) -> Self
112    where
113        F: Fn(Request) -> Fut + Send + Sync + 'static,
114        Fut: Future<Output = Result<Response>> + Send + 'static,
115    {
116        let segments = parse_path(path);
117        let handler: HandlerFn = Arc::new(move |req| Box::pin(handler(req)));
118        self.routes.push(Route {
119            method,
120            segments,
121            handler,
122        });
123        self
124    }
125
126    /// Look up a handler for the given method+path.
127    fn find(&self, method: &Method, path: &str) -> MatchResult {
128        let mut path_segs: Vec<&str> = path.trim_start_matches('/').split('/').collect();
129        // Normalise trailing slash: `/admin/posts/` and `/admin/posts`
130        // address the same handler. Skip when the only segment is empty
131        // (the root path `/`), so root-route lookups still work.
132        if path_segs.len() > 1 && path_segs.last() == Some(&"") {
133            path_segs.pop();
134        }
135        let mut path_matched = false;
136
137        for route in &self.routes {
138            if !segments_match(&route.segments, &path_segs) {
139                continue;
140            }
141            path_matched = true;
142            if route.method == *method {
143                let params = extract_params(&route.segments, &path_segs);
144                return MatchResult::Ok {
145                    handler: route.handler.clone(),
146                    params,
147                };
148            }
149        }
150
151        if path_matched {
152            MatchResult::MethodNotAllowed
153        } else {
154            MatchResult::NotFound
155        }
156    }
157
158    // public:
159    pub async fn dispatch(&self, mut req: Request) -> Response {
160        let matched = self.find(req.method(), req.path());
161
162        let outcome = match matched {
163            MatchResult::Ok { handler, params } => {
164                req.set_params(params);
165                let next = Next {
166                    chain: self.middleware.clone(),
167                    handler,
168                    index: 0,
169                };
170                next.run(req).await
171            }
172            MatchResult::NotFound => Err(Error::NotFound(format!("no route for {}", req.path()))),
173            MatchResult::MethodNotAllowed => Err(Error::MethodNotAllowed(format!(
174                "{} not allowed",
175                req.method()
176            ))),
177        };
178
179        match outcome {
180            Ok(resp) => resp,
181            Err(err) => response_from_error(&err),
182        }
183    }
184}
185
186enum MatchResult {
187    Ok {
188        handler: HandlerFn,
189        params: std::collections::HashMap<String, String>,
190    },
191    NotFound,
192    MethodNotAllowed,
193}
194
195fn parse_path(path: &str) -> Vec<Segment> {
196    path.trim_start_matches('/')
197        .split('/')
198        .map(|seg| {
199            if let Some(name) = seg.strip_prefix(':') {
200                Segment::Param(name.to_string())
201            } else {
202                Segment::Static(seg.to_string())
203            }
204        })
205        .collect()
206}
207
208fn segments_match(route: &[Segment], path: &[&str]) -> bool {
209    if route.len() != path.len() {
210        return false;
211    }
212    for (r, p) in route.iter().zip(path.iter()) {
213        match r {
214            Segment::Static(s) if s != p => return false,
215            _ => {}
216        }
217    }
218    true
219}
220
221fn extract_params(route: &[Segment], path: &[&str]) -> std::collections::HashMap<String, String> {
222    let mut out = std::collections::HashMap::new();
223    for (r, p) in route.iter().zip(path.iter()) {
224        if let Segment::Param(name) = r {
225            out.insert(name.clone(), (*p).to_string());
226        }
227    }
228    out
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `method` + `path`; every other field is left
    /// empty or defaulted.
    fn request(method: Method, path: &'static str) -> Request {
        Request::new(
            method,
            path.into(),
            String::new(),
            Default::default(),
            bytes::Bytes::new(),
        )
    }

    #[tokio::test]
    async fn matches_static_path() {
        let router = Router::new().get("/hello", |_req| async { Ok(Response::text("hi")) });
        let resp = router.dispatch(request(Method::GET, "/hello")).await;
        assert_eq!(resp.status.as_u16(), 200);
    }

    #[tokio::test]
    async fn captures_param() {
        let router = Router::new().get("/users/:id", |req| async move {
            let id = req.param("id").unwrap_or("").to_string();
            Ok(Response::text(id))
        });
        let resp = router.dispatch(request(Method::GET, "/users/42")).await;
        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"42");
    }

    #[tokio::test]
    async fn distinguishes_404_from_405() {
        let router = Router::new().get("/things", |_| async { Ok(Response::text("ok")) });

        let wrong_method = router.dispatch(request(Method::POST, "/things")).await;
        assert_eq!(wrong_method.status.as_u16(), 405);

        let unknown_path = router.dispatch(request(Method::GET, "/nope")).await;
        assert_eq!(unknown_path.status.as_u16(), 404);
    }

    #[tokio::test]
    async fn trailing_slash_is_normalised_for_static_and_param_routes() {
        let router = Router::new()
            .get("/admin/:name", |req| async move {
                let name = req.param("name").unwrap_or("").to_string();
                Ok(Response::text(name))
            })
            .get("/admin/:name/:id/edit", |req| async move {
                Ok(Response::text(format!(
                    "{}/{}",
                    req.param("name").unwrap_or(""),
                    req.param("id").unwrap_or(""),
                )))
            });

        for path in ["/admin/posts", "/admin/posts/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts", "GET {path} body");
        }

        for path in ["/admin/posts/1/edit", "/admin/posts/1/edit/"] {
            let resp = router.dispatch(request(Method::GET, path)).await;
            assert_eq!(resp.status.as_u16(), 200, "GET {path} should be 200");
            assert_eq!(&resp.body[..], b"posts/1", "GET {path} body");
        }
    }

    #[tokio::test]
    async fn root_path_still_matches_after_trailing_slash_normalisation() {
        // Regression check: stripping the trailing slash must not remove the
        // single empty segment that represents the root path.
        let router = Router::new().get("/", |_| async { Ok(Response::text("home")) });
        let resp = router.dispatch(request(Method::GET, "/")).await;
        assert_eq!(resp.status.as_u16(), 200);
        assert_eq!(&resp.body[..], b"home");
    }
}