1use crate::dep::DepEnv;
7use crate::error::{Error, Result};
8use crate::handler::{BoxHandlerFn, Handler};
9use crate::middleware::Middleware;
10use http::Method;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14pub struct MethodRouter {
16 pub(crate) handlers: Vec<(Method, BoxHandlerFn)>,
17 pub(crate) body_limit: Option<usize>,
20 pub(crate) stream_body: bool,
23}
24
25pub fn get<H: Handler<A>, A>(h: H) -> MethodRouter {
26 MethodRouter::new().on(Method::GET, h)
27}
28pub fn post<H: Handler<A>, A>(h: H) -> MethodRouter {
29 MethodRouter::new().on(Method::POST, h)
30}
31pub fn put<H: Handler<A>, A>(h: H) -> MethodRouter {
32 MethodRouter::new().on(Method::PUT, h)
33}
34pub fn patch<H: Handler<A>, A>(h: H) -> MethodRouter {
35 MethodRouter::new().on(Method::PATCH, h)
36}
37pub fn delete<H: Handler<A>, A>(h: H) -> MethodRouter {
38 MethodRouter::new().on(Method::DELETE, h)
39}
40
41impl MethodRouter {
42 fn new() -> Self {
43 Self {
44 handlers: Vec::new(),
45 body_limit: None,
46 stream_body: false,
47 }
48 }
49
50 pub fn on<H: Handler<A>, A>(mut self, method: Method, h: H) -> Self {
51 self.handlers.push((method, h.into_handler_fn()));
52 self
53 }
54
55 pub fn body_limit(mut self, bytes: usize) -> Self {
60 self.body_limit = Some(bytes);
61 self
62 }
63
64 pub fn stream_body(mut self) -> Self {
68 self.stream_body = true;
69 self
70 }
71 pub fn get<H: Handler<A>, A>(self, h: H) -> Self {
72 self.on(Method::GET, h)
73 }
74 pub fn post<H: Handler<A>, A>(self, h: H) -> Self {
75 self.on(Method::POST, h)
76 }
77 pub fn put<H: Handler<A>, A>(self, h: H) -> Self {
78 self.on(Method::PUT, h)
79 }
80 pub fn patch<H: Handler<A>, A>(self, h: H) -> Self {
81 self.on(Method::PATCH, h)
82 }
83 pub fn delete<H: Handler<A>, A>(self, h: H) -> Self {
84 self.on(Method::DELETE, h)
85 }
86}
87
/// A route stored in the [`Trie`]: the handler for each method registered at one path,
/// plus settings that apply to all of those handlers.
pub(crate) struct Endpoint {
    /// Handler to run for each HTTP method registered at this path.
    pub(crate) methods: HashMap<Method, BoxHandlerFn>,
    /// Dependency environment for this route, shared through an `Arc`.
    pub(crate) env: Arc<DepEnv>,
    /// Middleware for this route.
    /// NOTE(review): presumably applied in slice order — confirm at the call site.
    pub(crate) middleware: Arc<[Arc<dyn Middleware>]>,
    /// Body limit in bytes, or `None` when not set.
    /// NOTE(review): presumably copied from `MethodRouter::body_limit` — confirm.
    pub(crate) body_limit: Option<usize>,
    /// Streaming-body flag.
    /// NOTE(review): presumably copied from `MethodRouter::stream_body` — confirm.
    pub(crate) stream_body: bool,
}
101
/// Route table keyed by `/`-separated path segments; each registered pattern ends at a
/// node holding its [`Endpoint`].
#[derive(Default)]
pub(crate) struct Trie {
    /// Node for the path with no segments (`/`); every route descends from it.
    root: Node,
}
106
/// One segment position in the [`Trie`].
#[derive(Default)]
struct Node {
    /// Children reached by an exact match on the (decoded) segment text.
    statics: HashMap<String, Node>,
    /// At most one parameter child: the `{name}` it captures and the subtree after it.
    /// `Trie::insert` rejects a second, differently named parameter at this position.
    param: Option<(String, Box<Node>)>,
    /// Present when a registered route ends exactly at this node.
    endpoint: Option<Endpoint>,
}
113
/// Outcome of [`Trie::find`] for a request path and method.
pub(crate) enum RouteMatch<'a> {
    /// A route matched the path and has a handler for the requested method.
    Found {
        /// The matched endpoint.
        endpoint: &'a Endpoint,
        /// `(parameter name, decoded segment)` pairs, in the order they occur in the path.
        params: Vec<(String, String)>,
    },
    /// A route matched the path but has no handler for the requested method.
    MethodMissing,
    /// The path held a malformed percent-escape, or decoded to invalid UTF-8.
    Malformed,
    /// No registered route matches the path.
    NotFound,
}
123
/// Splits a URL path on `/`, skipping empty pieces.
///
/// Leading, trailing and repeated slashes therefore never produce segments:
/// `"/a//b/"` yields `"a"`, `"b"`, while `"/"` and `""` yield nothing.
fn segments(path: &str) -> impl Iterator<Item = &str> {
    fn is_segment(piece: &&str) -> bool {
        !piece.is_empty()
    }
    path.split('/').filter(is_segment)
}
127
/// Percent-decodes one path segment (`%XX` escapes) into an owned `String`.
///
/// Returns `None` when an escape is truncated or has a non-hex digit, or when the decoded
/// bytes are not valid UTF-8. Decoding is per segment, so `%2F` becomes a literal `/`
/// inside the result rather than a segment boundary. Callers that turn decoded values
/// into filesystem paths must validate them (`/`, `..`, NUL can all appear).
fn decode_segment(seg: &str) -> Option<String> {
    // Fast path: nothing to decode.
    if !seg.contains('%') {
        return Some(seg.to_string());
    }
    // `to_digit(16)` accepts exactly `0-9`, `a-f`, `A-F`; every other byte (including
    // bytes >= 0x80, which map to non-hex chars) is rejected.
    let nibble = |byte: u8| char::from(byte).to_digit(16).map(|d| d as u8);

    let mut decoded = Vec::with_capacity(seg.len());
    let mut rest = seg.as_bytes();
    while let Some((&byte, tail)) = rest.split_first() {
        if byte != b'%' {
            decoded.push(byte);
            rest = tail;
            continue;
        }
        // An escape needs two more bytes; anything shorter is malformed.
        let [hi, lo, after @ ..] = tail else {
            return None;
        };
        decoded.push((nibble(*hi)? << 4) | nibble(*lo)?);
        rest = after;
    }
    String::from_utf8(decoded).ok()
}
160
161impl Trie {
162 pub(crate) fn insert(&mut self, path: &str, endpoint: Endpoint) -> Result<()> {
163 let mut node = &mut self.root;
164 for seg in segments(path) {
165 if let Some(name) = seg.strip_prefix('{').and_then(|s| s.strip_suffix('}')) {
166 if node.param.is_none() {
167 node.param = Some((name.to_string(), Box::default()));
168 }
169 let (existing, child) = node.param.as_mut().expect("just ensured");
170 if existing != name {
171 return Err(Error::internal(format!(
172 "conflicting path parameters `{{{existing}}}` vs `{{{name}}}` in `{path}`"
173 )));
174 }
175 node = child;
176 } else {
177 node = node.statics.entry(seg.to_string()).or_default();
178 }
179 }
180 if node.endpoint.is_some() {
181 return Err(Error::internal(format!(
182 "duplicate route registration for `{path}`"
183 )));
184 }
185 node.endpoint = Some(endpoint);
186 Ok(())
187 }
188
189 pub(crate) fn find<'a>(&'a self, path: &str, method: &Method) -> RouteMatch<'a> {
190 if !path.contains('%') {
191 let segs: Vec<&str> = segments(path).collect();
192 return self.find_in(&segs, method);
193 }
194 let mut decoded: Vec<String> = Vec::new();
195 for raw in segments(path) {
196 match decode_segment(raw) {
197 Some(d) => decoded.push(d),
198 None => return RouteMatch::Malformed,
199 }
200 }
201 let segs: Vec<&str> = decoded.iter().map(String::as_str).collect();
202 self.find_in(&segs, method)
203 }
204
205 pub(crate) fn methods_for(&self, path: &str) -> Option<Vec<Method>> {
212 let mut params: Vec<(String, String)> = Vec::new();
213 let node = if path.contains('%') {
214 let mut decoded: Vec<String> = Vec::new();
215 for raw in segments(path) {
216 decoded.push(decode_segment(raw)?);
217 }
218 let segs: Vec<&str> = decoded.iter().map(String::as_str).collect();
219 find_node(&self.root, &segs, &mut params)
220 } else {
221 let segs: Vec<&str> = segments(path).collect();
222 find_node(&self.root, &segs, &mut params)
223 }?;
224 let ep = node
225 .endpoint
226 .as_ref()
227 .expect("find_node only returns endpoint nodes");
228 let mut methods: Vec<Method> = ep.methods.keys().cloned().collect();
229 methods.sort_by(|a, b| a.as_str().cmp(b.as_str()));
230 Some(methods)
231 }
232
233 fn find_in<'a>(&'a self, segs: &[&str], method: &Method) -> RouteMatch<'a> {
234 let mut params: Vec<(String, String)> = Vec::new();
235 match find_node(&self.root, segs, &mut params) {
236 Some(node) => {
237 let ep = node
238 .endpoint
239 .as_ref()
240 .expect("find_node only returns endpoint nodes");
241 if ep.methods.contains_key(method) {
242 RouteMatch::Found {
243 endpoint: ep,
244 params,
245 }
246 } else {
247 RouteMatch::MethodMissing
248 }
249 }
250 None => RouteMatch::NotFound,
251 }
252 }
253}
254
255fn find_node<'a>(
259 node: &'a Node,
260 segs: &[&str],
261 params: &mut Vec<(String, String)>,
262) -> Option<&'a Node> {
263 let Some((head, rest)) = segs.split_first() else {
264 return node.endpoint.is_some().then_some(node);
265 };
266 if let Some(child) = node.statics.get(*head)
267 && let Some(found) = find_node(child, rest, params)
268 {
269 return Some(found);
270 }
271 if let Some((name, child)) = &node.param {
272 params.push((name.clone(), (*head).to_string()));
273 if let Some(found) = find_node(child, rest, params) {
274 return Some(found);
275 }
276 params.pop();
277 }
278 None
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::IntoResponse;

    /// Handler that ignores the request and always responds with `"ok"`.
    fn dummy_handler() -> BoxHandlerFn {
        Arc::new(move |_ctx: &mut crate::RequestCtx| Box::pin(async move { "ok".into_response() }))
    }

    /// Builds an endpoint with a `dummy_handler` for each of `methods` and default
    /// settings everywhere else.
    fn endpoint(methods: &[Method]) -> Endpoint {
        let mut map = HashMap::new();
        for m in methods {
            map.insert(m.clone(), dummy_handler());
        }
        Endpoint {
            methods: map,
            env: Arc::new(DepEnv::default()),
            middleware: Arc::from(vec![]),
            body_limit: None,
            stream_body: false,
        }
    }

    /// Names the variant of `m` so panic messages are informative without requiring
    /// `RouteMatch` to implement `Debug`.
    fn describe(m: &RouteMatch<'_>) -> &'static str {
        match m {
            RouteMatch::Found { .. } => "Found",
            RouteMatch::MethodMissing => "MethodMissing",
            RouteMatch::Malformed => "Malformed",
            RouteMatch::NotFound => "NotFound",
        }
    }

    #[test]
    fn static_and_param_segments_match() {
        let mut t = Trie::default();
        t.insert("/todos", endpoint(&[Method::GET])).unwrap();
        t.insert("/todos/{id}", endpoint(&[Method::GET, Method::DELETE])).unwrap();
        t.insert("/todos/{id}/comments", endpoint(&[Method::GET])).unwrap();

        match t.find("/todos/42/comments", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert_eq!(params, vec![("id".to_string(), "42".to_string())])
            }
            other => panic!("expected match, got {}", describe(&other)),
        }
        assert!(matches!(t.find("/todos/42", &Method::DELETE), RouteMatch::Found { .. }));
    }

    #[test]
    fn unknown_path_is_not_found_and_wrong_method_is_method_missing() {
        let mut t = Trie::default();
        t.insert("/todos", endpoint(&[Method::GET])).unwrap();
        assert!(matches!(t.find("/nope", &Method::GET), RouteMatch::NotFound));
        assert!(matches!(t.find("/todos", &Method::POST), RouteMatch::MethodMissing));
    }

    #[test]
    fn duplicate_path_registration_is_a_build_error() {
        let mut t = Trie::default();
        t.insert("/todos", endpoint(&[Method::GET])).unwrap();
        let err = t.insert("/todos", endpoint(&[Method::POST])).unwrap_err();
        assert!(err.message().contains("/todos"));
    }

    #[test]
    fn conflicting_param_names_are_a_build_error() {
        let mut t = Trie::default();
        t.insert("/todos/{id}", endpoint(&[Method::GET])).unwrap();
        let err = t
            .insert("/todos/{todo_id}", endpoint(&[Method::DELETE]))
            .unwrap_err();
        assert!(err.message().contains("id"));
    }

    #[test]
    fn static_dead_end_backtracks_to_param_branch() {
        let mut t = Trie::default();
        t.insert("/a/b/c", endpoint(&[Method::GET])).unwrap();
        t.insert("/a/{x}/d", endpoint(&[Method::GET])).unwrap();
        match t.find("/a/b/d", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert_eq!(params, vec![("x".to_string(), "b".to_string())]);
            }
            other => panic!(
                "expected /a/{{x}}/d to match /a/b/d via backtracking, got {}",
                describe(&other)
            ),
        }
        assert!(matches!(t.find("/a/b/c", &Method::GET), RouteMatch::Found { .. }));
    }

    #[test]
    fn static_wins_over_param_when_both_match() {
        let mut t = Trie::default();
        t.insert("/users/me", endpoint(&[Method::GET])).unwrap();
        t.insert("/users/{id}", endpoint(&[Method::GET])).unwrap();
        match t.find("/users/me", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert!(params.is_empty(), "static match captures nothing")
            }
            other => panic!("expected static /users/me, got {}", describe(&other)),
        }
        match t.find("/users/42", &Method::GET) {
            RouteMatch::Found { params, .. } => {
                assert_eq!(params, vec![("id".to_string(), "42".to_string())])
            }
            other => panic!("expected param /users/{{id}}, got {}", describe(&other)),
        }
    }

    #[test]
    fn static_match_shadows_param_route_for_other_methods() {
        // The method is checked only against the first endpoint the path reaches.
        let mut t = Trie::default();
        t.insert("/users/me", endpoint(&[Method::GET])).unwrap();
        t.insert("/users/{id}", endpoint(&[Method::DELETE])).unwrap();
        assert!(matches!(t.find("/users/me", &Method::DELETE), RouteMatch::MethodMissing));
        assert!(matches!(t.find("/users/42", &Method::DELETE), RouteMatch::Found { .. }));
    }

    #[test]
    fn method_router_builder_collects_methods() {
        let mr = get(|| async { "a" }).post(|| async { "b" });
        let methods: Vec<_> = mr.handlers.iter().map(|(m, _)| m.clone()).collect();
        assert_eq!(methods, vec![Method::GET, Method::POST]);
    }

    #[test]
    fn percent_encoded_segments_decode_for_statics_and_params() {
        let mut t = Trie::default();
        t.insert("/caf\u{e9}/menu", endpoint(&[Method::GET])).unwrap();
        t.insert("/todos/{id}", endpoint(&[Method::GET])).unwrap();

        // UTF-8 escapes are decoded before static segments are compared.
        assert!(matches!(t.find("/caf%C3%A9/menu", &Method::GET), RouteMatch::Found { .. }));

        // `%2F` decodes to `/` inside one segment; it does not split the path.
        match t.find("/todos/a%2Fb", &Method::GET) {
            RouteMatch::Found { params, .. } => assert_eq!(params[0].1, "a/b"),
            other => panic!("expected param capture, got {}", describe(&other)),
        }

        // Captured parameter values are decoded too.
        match t.find("/todos/hello%20world", &Method::GET) {
            RouteMatch::Found { params, .. } => assert_eq!(params[0].1, "hello world"),
            other => panic!("expected match, got {}", describe(&other)),
        }
    }

    #[test]
    fn malformed_percent_encodings_are_flagged_not_matched() {
        let mut t = Trie::default();
        t.insert("/todos/{id}", endpoint(&[Method::GET])).unwrap();
        // Non-hex digits after `%`.
        assert!(matches!(t.find("/todos/%zz", &Method::GET), RouteMatch::Malformed));
        // Truncated escape.
        assert!(matches!(t.find("/todos/%2", &Method::GET), RouteMatch::Malformed));
        // Well-formed escape whose decoded byte is not valid UTF-8.
        assert!(matches!(t.find("/todos/%FF", &Method::GET), RouteMatch::Malformed));
    }

    #[test]
    fn methods_for_lists_sorted_methods_of_matching_route() {
        let mut t = Trie::default();
        t.insert("/todos/{id}", endpoint(&[Method::PUT, Method::GET, Method::DELETE]))
            .unwrap();
        assert_eq!(
            t.methods_for("/todos/42"),
            Some(vec![Method::DELETE, Method::GET, Method::PUT])
        );
        // Same decoding rules as `find`.
        assert_eq!(t.methods_for("/todos/hello%20world").map(|m| m.len()), Some(3));
        assert_eq!(t.methods_for("/nope"), None);
        assert_eq!(t.methods_for("/todos/%zz"), None);
    }
}