Skip to main content

jerrycan_core/
router.rs

//! Method routing + segment trie with `{param}` captures (spec §4.1).
//! Conflicting routes are detected at build time — fail loud before serving.
//! Path segments are percent-decoded after '/'-splitting; malformed encodings
//! surface as `RouteMatch::Malformed` (a clean 400, never a panic).
5
6use crate::dep::DepEnv;
7use crate::error::{Error, Result};
8use crate::handler::{BoxHandlerFn, Handler};
9use crate::middleware::Middleware;
10use http::Method;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14/// Per-path method table: `get(list).post(create)` (spec §4.1).
15pub struct MethodRouter {
16    pub(crate) handlers: Vec<(Method, BoxHandlerFn)>,
17    /// Per-route request-body cap in bytes. `None` defers to the app default
18    /// (1 MiB, spec §4.4). Applies to ALL methods on the route, not per-method.
19    pub(crate) body_limit: Option<usize>,
20    /// When true, the body is NOT buffered before dispatch — extractors read
21    /// the live stream lane. Applies to ALL methods on the route.
22    pub(crate) stream_body: bool,
23}
24
25pub fn get<H: Handler<A>, A>(h: H) -> MethodRouter {
26    MethodRouter::new().on(Method::GET, h)
27}
28pub fn post<H: Handler<A>, A>(h: H) -> MethodRouter {
29    MethodRouter::new().on(Method::POST, h)
30}
31pub fn put<H: Handler<A>, A>(h: H) -> MethodRouter {
32    MethodRouter::new().on(Method::PUT, h)
33}
34pub fn patch<H: Handler<A>, A>(h: H) -> MethodRouter {
35    MethodRouter::new().on(Method::PATCH, h)
36}
37pub fn delete<H: Handler<A>, A>(h: H) -> MethodRouter {
38    MethodRouter::new().on(Method::DELETE, h)
39}
40
41impl MethodRouter {
42    fn new() -> Self {
43        Self {
44            handlers: Vec::new(),
45            body_limit: None,
46            stream_body: false,
47        }
48    }
49
50    pub fn on<H: Handler<A>, A>(mut self, method: Method, h: H) -> Self {
51        self.handlers.push((method, h.into_handler_fn()));
52        self
53    }
54
55    /// Cap the request body for THIS route at `bytes`, overriding the app's
56    /// 1 MiB default (spec §4.4). The cap is per-route — it applies to every
57    /// method registered here, not per-method. Bodies over the cap are
58    /// rejected with 413 before the handler runs.
59    pub fn body_limit(mut self, bytes: usize) -> Self {
60        self.body_limit = Some(bytes);
61        self
62    }
63
64    /// Marks every method on this route as STREAMING: the body is not buffered
65    /// before dispatch; extractors read it incrementally (Multipart) or drain it
66    /// on demand (Json/RawBody). `body_limit` still caps cumulative bytes.
67    pub fn stream_body(mut self) -> Self {
68        self.stream_body = true;
69        self
70    }
71    pub fn get<H: Handler<A>, A>(self, h: H) -> Self {
72        self.on(Method::GET, h)
73    }
74    pub fn post<H: Handler<A>, A>(self, h: H) -> Self {
75        self.on(Method::POST, h)
76    }
77    pub fn put<H: Handler<A>, A>(self, h: H) -> Self {
78        self.on(Method::PUT, h)
79    }
80    pub fn patch<H: Handler<A>, A>(self, h: H) -> Self {
81        self.on(Method::PATCH, h)
82    }
83    pub fn delete<H: Handler<A>, A>(self, h: H) -> Self {
84        self.on(Method::DELETE, h)
85    }
86}
87
88/// A flattened route: method table + the effective dependency environment and
89/// middleware chain for this path (computed at build time, spec §4.2).
90pub(crate) struct Endpoint {
91    pub(crate) methods: HashMap<Method, BoxHandlerFn>,
92    pub(crate) env: Arc<DepEnv>,
93    pub(crate) middleware: Arc<[Arc<dyn Middleware>]>,
94    /// Per-route body cap (bytes); `None` = the app default. Read pre-dispatch
95    /// by `route_policy` to size the body read for this route (spec §4.4).
96    pub(crate) body_limit: Option<usize>,
97    /// When true, the body is streamed (not collected upfront) — `route_policy`
98    /// reports it so serve hands the live stream lane to dispatch (v2.1).
99    pub(crate) stream_body: bool,
100}
101
102#[derive(Default)]
103pub(crate) struct Trie {
104    root: Node,
105}
106
107#[derive(Default)]
108struct Node {
109    statics: HashMap<String, Node>,
110    param: Option<(String, Box<Node>)>,
111    endpoint: Option<Endpoint>,
112}
113
114pub(crate) enum RouteMatch<'a> {
115    Found {
116        endpoint: &'a Endpoint,
117        params: Vec<(String, String)>,
118    },
119    MethodMissing,
120    Malformed,
121    NotFound,
122}
123
/// Splits `path` on `/` and yields the non-empty pieces, so leading, trailing
/// and repeated slashes contribute no segments.
fn segments(path: &str) -> impl Iterator<Item = &str> {
    path.split('/').filter(|s| !s.is_empty())
}
127
/// Decode %XX sequences in ONE path segment. `None` = malformed (bad hex,
/// truncated escape, or non-UTF-8 result) — the caller answers 400.
/// Runs after '/'-splitting, so an encoded slash cannot create segments.
fn decode_segment(seg: &str) -> Option<String> {
    // Fast path: nothing to decode, so skip the byte-by-byte walk.
    if !seg.contains('%') {
        return Some(seg.to_string());
    }
    // Value of one ASCII hex digit (either case); `None` for anything else.
    fn hex(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }
    let mut out = Vec::with_capacity(seg.len());
    let mut bytes = seg.bytes();
    while let Some(b) = bytes.next() {
        if b == b'%' {
            // `next` yields None on truncated escapes; `hex` on bad digits.
            let high = hex(bytes.next()?)?;
            let low = hex(bytes.next()?)?;
            // Both nibbles are < 16, so this cannot overflow a u8.
            out.push(high * 16 + low);
        } else {
            out.push(b);
        }
    }
    // Decoded bytes may not form valid UTF-8 (e.g. a lone %FF).
    String::from_utf8(out).ok()
}
160
161impl Trie {
162    pub(crate) fn insert(&mut self, path: &str, endpoint: Endpoint) -> Result<()> {
163        let mut node = &mut self.root;
164        for seg in segments(path) {
165            if let Some(name) = seg.strip_prefix('{').and_then(|s| s.strip_suffix('}')) {
166                if node.param.is_none() {
167                    node.param = Some((name.to_string(), Box::default()));
168                }
169                let (existing, child) = node.param.as_mut().expect("just ensured");
170                if existing != name {
171                    return Err(Error::internal(format!(
172                        "conflicting path parameters `{{{existing}}}` vs `{{{name}}}` in `{path}`"
173                    )));
174                }
175                node = child;
176            } else {
177                node = node.statics.entry(seg.to_string()).or_default();
178            }
179        }
180        if node.endpoint.is_some() {
181            return Err(Error::internal(format!(
182                "duplicate route registration for `{path}`"
183            )));
184        }
185        node.endpoint = Some(endpoint);
186        Ok(())
187    }
188
189    pub(crate) fn find<'a>(&'a self, path: &str, method: &Method) -> RouteMatch<'a> {
190        if !path.contains('%') {
191            let segs: Vec<&str> = segments(path).collect();
192            return self.find_in(&segs, method);
193        }
194        let mut decoded: Vec<String> = Vec::new();
195        for raw in segments(path) {
196            match decode_segment(raw) {
197                Some(d) => decoded.push(d),
198                None => return RouteMatch::Malformed,
199            }
200        }
201        let segs: Vec<&str> = decoded.iter().map(String::as_str).collect();
202        self.find_in(&segs, method)
203    }
204
205    /// The HTTP methods registered for `path`, or `None` if the path is unknown.
206    /// Used by CORS preflight to reflect `Access-Control-Allow-Methods`. The walk
207    /// mirrors [`Trie::find`] (percent-decode then resolve via `find_node`) but
208    /// ignores the request method — preflight cares only whether the path exists.
209    /// Methods are sorted so the emitted header is deterministic (a framework
210    /// invariant) regardless of registration order.
211    pub(crate) fn methods_for(&self, path: &str) -> Option<Vec<Method>> {
212        let mut params: Vec<(String, String)> = Vec::new();
213        let node = if path.contains('%') {
214            let mut decoded: Vec<String> = Vec::new();
215            for raw in segments(path) {
216                decoded.push(decode_segment(raw)?);
217            }
218            let segs: Vec<&str> = decoded.iter().map(String::as_str).collect();
219            find_node(&self.root, &segs, &mut params)
220        } else {
221            let segs: Vec<&str> = segments(path).collect();
222            find_node(&self.root, &segs, &mut params)
223        }?;
224        let ep = node
225            .endpoint
226            .as_ref()
227            .expect("find_node only returns endpoint nodes");
228        let mut methods: Vec<Method> = ep.methods.keys().cloned().collect();
229        methods.sort_by(|a, b| a.as_str().cmp(b.as_str()));
230        Some(methods)
231    }
232
233    fn find_in<'a>(&'a self, segs: &[&str], method: &Method) -> RouteMatch<'a> {
234        let mut params: Vec<(String, String)> = Vec::new();
235        match find_node(&self.root, segs, &mut params) {
236            Some(node) => {
237                let ep = node
238                    .endpoint
239                    .as_ref()
240                    .expect("find_node only returns endpoint nodes");
241                if ep.methods.contains_key(method) {
242                    RouteMatch::Found {
243                        endpoint: ep,
244                        params,
245                    }
246                } else {
247                    RouteMatch::MethodMissing
248                }
249            }
250            None => RouteMatch::NotFound,
251        }
252    }
253}
254
255/// Depth-first with backtracking: static child first; if that subtree fails,
256/// retry via the param child (capturing the segment). Only nodes WITH an
257/// endpoint count as matches, so a static dead-end falls back to the param route.
258fn find_node<'a>(
259    node: &'a Node,
260    segs: &[&str],
261    params: &mut Vec<(String, String)>,
262) -> Option<&'a Node> {
263    let Some((head, rest)) = segs.split_first() else {
264        return node.endpoint.is_some().then_some(node);
265    };
266    if let Some(child) = node.statics.get(*head)
267        && let Some(found) = find_node(child, rest, params)
268    {
269        return Some(found);
270    }
271    if let Some((name, child)) = &node.param {
272        params.push((name.clone(), (*head).to_string()));
273        if let Some(found) = find_node(child, rest, params) {
274            return Some(found);
275        }
276        params.pop();
277    }
278    None
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::IntoResponse;

    /// A handler that always answers "ok"; the tests only care about routing.
    fn dummy_handler() -> BoxHandlerFn {
        Arc::new(move |_ctx: &mut crate::RequestCtx| Box::pin(async move { "ok".into_response() }))
    }

    /// An endpoint with a dummy handler per listed method and default policy.
    fn endpoint(methods: &[Method]) -> Endpoint {
        let mut map = HashMap::new();
        for m in methods {
            map.insert(m.clone(), dummy_handler());
        }
        Endpoint {
            methods: map,
            env: Arc::new(DepEnv::default()),
            middleware: Arc::from(vec![]),
            body_limit: None,
            stream_body: false,
        }
    }

    #[test]
    fn static_and_param_segments_match() {
        let mut t = Trie::default();
        t.insert("/todos", endpoint(&[Method::GET])).unwrap();
        t.insert("/todos/{id}", endpoint(&[Method::GET, Method::DELETE]))
            .unwrap();
        t.insert("/todos/{id}/comments", endpoint(&[Method::GET]))
            .unwrap();

        match t.find("/todos/42/comments", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert_eq!(params, vec![("id".to_string(), "42".to_string())])
            }
            _ => panic!("expected match"),
        }
        assert!(matches!(
            t.find("/todos/42", &Method::DELETE),
            RouteMatch::Found { .. }
        ));
    }

    #[test]
    fn unknown_path_is_not_found_and_wrong_method_is_method_missing() {
        let mut t = Trie::default();
        t.insert("/todos", endpoint(&[Method::GET])).unwrap();
        assert!(matches!(
            t.find("/nope", &Method::GET),
            RouteMatch::NotFound
        ));
        assert!(matches!(
            t.find("/todos", &Method::POST),
            RouteMatch::MethodMissing
        ));
    }

    #[test]
    fn duplicate_path_registration_is_a_build_error() {
        let mut t = Trie::default();
        t.insert("/todos", endpoint(&[Method::GET])).unwrap();
        let err = t.insert("/todos", endpoint(&[Method::POST])).unwrap_err();
        assert!(err.message().contains("/todos"));
    }

    #[test]
    fn conflicting_param_names_are_a_build_error() {
        let mut t = Trie::default();
        t.insert("/todos/{id}", endpoint(&[Method::GET])).unwrap();
        let err = t
            .insert("/todos/{todo_id}", endpoint(&[Method::DELETE]))
            .unwrap_err();
        assert!(err.message().contains("id"));
    }

    #[test]
    fn static_dead_end_backtracks_to_param_branch() {
        let mut t = Trie::default();
        t.insert("/a/b/c", endpoint(&[Method::GET])).unwrap();
        t.insert("/a/{x}/d", endpoint(&[Method::GET])).unwrap();
        match t.find("/a/b/d", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert_eq!(params, vec![("x".to_string(), "b".to_string())]);
            }
            _ => panic!("expected /a/{{x}}/d to match /a/b/d via backtracking"),
        }
        assert!(matches!(
            t.find("/a/b/c", &Method::GET),
            RouteMatch::Found { .. }
        ));
    }

    #[test]
    fn static_wins_over_param_when_both_match() {
        let mut t = Trie::default();
        t.insert("/users/me", endpoint(&[Method::GET])).unwrap();
        t.insert("/users/{id}", endpoint(&[Method::GET])).unwrap();
        match t.find("/users/me", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert!(params.is_empty(), "static match captures nothing")
            }
            _ => panic!("expected static /users/me"),
        }
        match t.find("/users/42", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert_eq!(params, vec![("id".to_string(), "42".to_string())])
            }
            _ => panic!("expected param /users/{{id}}"),
        }
    }

    #[test]
    fn method_router_builder_collects_methods() {
        let mr = get(|| async { "a" }).post(|| async { "b" });
        let methods: Vec<_> = mr.handlers.iter().map(|(m, _)| m.clone()).collect();
        assert_eq!(methods, vec![Method::GET, Method::POST]);
    }

    #[test]
    fn percent_encoded_segments_decode_for_statics_and_params() {
        let mut t = Trie::default();
        t.insert("/caf\u{e9}/menu", endpoint(&[Method::GET]))
            .unwrap();
        t.insert("/todos/{id}", endpoint(&[Method::GET])).unwrap();

        // %C3%A9 = é in a STATIC segment
        assert!(matches!(
            t.find("/caf%C3%A9/menu", &Method::GET),
            RouteMatch::Found { .. }
        ));

        // %2F decodes INSIDE the param value without creating a new segment
        match t.find("/todos/a%2Fb", &Method::GET) {
            RouteMatch::Found { params, .. } => assert_eq!(params[0].1, "a/b"),
            // `RouteMatch` has no `Debug`, so report only whether it was NotFound.
            other => panic!(
                "expected param capture, got no match ({})",
                matches!(other, RouteMatch::NotFound)
            ),
        }

        // %20 decodes to a space
        match t.find("/todos/hello%20world", &Method::GET) {
            RouteMatch::Found { params, .. } => assert_eq!(params[0].1, "hello world"),
            _ => panic!("expected match"),
        }
    }

    #[test]
    fn malformed_percent_encodings_are_flagged_not_matched() {
        let mut t = Trie::default();
        t.insert("/todos/{id}", endpoint(&[Method::GET])).unwrap();
        assert!(matches!(
            t.find("/todos/%zz", &Method::GET),
            RouteMatch::Malformed
        ));
        assert!(matches!(
            t.find("/todos/%2", &Method::GET),
            RouteMatch::Malformed
        )); // truncated
        assert!(matches!(
            t.find("/todos/%FF", &Method::GET),
            RouteMatch::Malformed
        )); // invalid UTF-8
    }
}