use http::Method;
use crate::extract::PathParams;
/// Names of the standard HTTP methods, in slot order.
///
/// A method's position in this table is the index of its trie slot in
/// `TrieRouter::methods`. `NUM_STANDARD_METHODS` is derived from this table so
/// the slot count and the lookup table cannot drift apart.
const STANDARD_METHOD_NAMES: &[&str] = &[
    "GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH", "CONNECT", "TRACE",
];

/// Number of per-method trie slots in a `TrieRouter`.
const NUM_STANDARD_METHODS: usize = STANDARD_METHOD_NAMES.len();

/// Returns the slot index for a standard HTTP method.
///
/// Returns `None` for extension methods, i.e. anything whose name is not
/// listed in `STANDARD_METHOD_NAMES`.
fn method_index(m: &Method) -> Option<usize> {
    let name = m.as_str();
    STANDARD_METHOD_NAMES
        .iter()
        .position(|&known| known == name)
}
/// One node of the `RadixTrie` arena.
///
/// Static nodes consume a run of bytes (`prefix`). Param nodes are created
/// with an empty prefix and carry the parameter name instead. Child links are
/// arena indices rather than pointers.
#[derive(Default)]
struct Node {
    /// Bytes this node consumes from the request path.
    prefix: Box<[u8]>,
    /// First byte of each static child's prefix; parallel to `children`.
    indices: Vec<u8>,
    /// Arena indices of the static children; parallel to `indices`.
    children: Vec<usize>,
    /// Arena index of the child that captures a `:param` segment, if any.
    param_child: Option<usize>,
    /// Name of the parameter this node captures (set only on param nodes).
    param_name: Option<&'static str>,
    /// Route index registered for a pattern that ends at this node.
    value: Option<usize>,
    /// Number of routes stored in this node's subtree; used to order children.
    priority: u32,
}

impl Node {
    /// An empty node: no prefix, no children, no route, priority zero.
    fn new() -> Self {
        Self::default()
    }

    /// An otherwise-empty static node that consumes exactly `prefix`.
    fn with_prefix(prefix: Vec<u8>) -> Self {
        Self {
            prefix: prefix.into_boxed_slice(),
            ..Self::default()
        }
    }
}
/// One piece of a route pattern, as produced by `split_pattern`.
enum Segment<'a> {
    /// Literal bytes that must appear verbatim in the request path.
    Static(&'a [u8]),
    /// A named parameter (`:name`) that captures text up to the next `/`.
    Param(&'a str),
}

/// Splits a route pattern into static and parameter segments.
///
/// A parameter starts at `:` and extends to the next `/` (or the end of the
/// pattern). Everything in between, including characters such as `.`, becomes
/// part of its name: `/files/:name.txt` yields the param `name.txt`. A static
/// run extends up to the next `:`, so a param may also begin mid-segment
/// (`/user-:id`).
///
/// # Panics
///
/// Panics if a `:` is not immediately followed by a name, e.g. `/users/:` or
/// `/a/:/b`. Such a pattern has nothing to bind the captured value to.
fn split_pattern(pattern: &str) -> Vec<Segment<'_>> {
    /// Index of the first `needle` at or after `from`, or `bytes.len()` if none.
    fn find_from(bytes: &[u8], from: usize, needle: u8) -> usize {
        bytes[from..]
            .iter()
            .position(|&b| b == needle)
            .map_or(bytes.len(), |offset| from + offset)
    }

    let bytes = pattern.as_bytes();
    let mut segments = Vec::new();
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b':' {
            let start = i + 1;
            let end = find_from(bytes, start, b'/');
            // `start` follows an ASCII `:` and `end` is an ASCII `/` or the end
            // of the pattern, so both are char boundaries and this cannot panic.
            let name = &pattern[start..end];
            assert!(
                !name.is_empty(),
                "empty param name in pattern `{}` — `:` must be followed by a name",
                pattern,
            );
            segments.push(Segment::Param(name));
            i = end;
        } else {
            let end = find_from(bytes, i, b':');
            segments.push(Segment::Static(&bytes[i..end]));
            i = end;
        }
    }
    segments
}
/// Byte-oriented radix trie mapping dynamic route patterns to route indices.
///
/// Nodes live in a flat arena and refer to one another by index. Index 0 is
/// the root, which is created with an empty prefix.
struct RadixTrie {
    arena: Vec<Node>,
}

impl RadixTrie {
    /// Creates a trie containing only the empty root node.
    fn new() -> Self {
        Self {
            arena: vec![Node::new()],
        }
    }

    /// Registers `pattern` so that matching paths resolve to `route_index`.
    ///
    /// Static segments are merged into the radix structure byte-wise. Each
    /// `:name` segment descends into (or creates) the single param child of the
    /// current node. The route index is stored on the node where the pattern
    /// ends.
    ///
    /// NOTE(review): inserting an identical pattern twice overwrites the earlier
    /// index with the later one. Confirm that matches the intended precedence.
    ///
    /// # Panics
    ///
    /// Panics if `pattern` uses a different param name than a previously
    /// inserted pattern at the same position.
    fn insert(&mut self, pattern: &str, route_index: usize) {
        let mut current = 0;
        for segment in &split_pattern(pattern) {
            current = match segment {
                Segment::Static(bytes) => self.insert_static(current, bytes),
                Segment::Param(name) => self.param_child_for(current, name, pattern),
            };
        }
        self.arena[current].value = Some(route_index);
    }

    /// Returns the param child of `parent`, creating it (named `name`) if absent.
    ///
    /// # Panics
    ///
    /// Panics if `parent` already has a param child registered under a
    /// different name. `pattern` is used only for the panic message.
    fn param_child_for(&mut self, parent: usize, name: &str, pattern: &str) -> usize {
        if let Some(child_id) = self.arena[parent].param_child {
            if let Some(existing) = self.arena[child_id].param_name {
                assert!(
                    existing == name,
                    "conflicting param names at the same position: \
                     `:{existing}` and `:{name}` in pattern `{pattern}` — \
                     all routes sharing this param position must use the same name",
                );
            }
            return child_id;
        }
        // `Node::param_name` holds a `&'static str`, so one copy of the name is
        // leaked per param node. That is bounded by the number of routes inserted.
        let leaked: &'static str = Box::leak(name.to_string().into_boxed_str());
        let id = self.alloc(Node {
            param_name: Some(leaked),
            ..Node::new()
        });
        self.arena[parent].param_child = Some(id);
        id
    }

    /// Extends the static path below `start` so that it spells `key`, and
    /// returns the node at which `key` ends.
    ///
    /// Nodes whose prefix diverges part-way from `key` are split. New bytes go
    /// into a fresh child, keeping first bytes unique among siblings. In
    /// practice `start` is the root or a param node, both of which have empty
    /// prefixes.
    fn insert_static(&mut self, start: usize, key: &[u8]) -> usize {
        if key.is_empty() {
            return start;
        }
        let mut current = start;
        let mut remaining = key;
        loop {
            let matched = common_prefix_len(&self.arena[current].prefix, remaining);
            if matched < self.arena[current].prefix.len() {
                self.split_node(current, matched);
            }
            remaining = &remaining[matched..];
            if remaining.is_empty() {
                return current;
            }
            let next = remaining[0];
            match self.arena[current].indices.iter().position(|&b| b == next) {
                Some(pos) => current = self.arena[current].children[pos],
                None => {
                    let child = self.alloc(Node::with_prefix(remaining.to_vec()));
                    let node = &mut self.arena[current];
                    node.indices.push(next);
                    node.children.push(child);
                    return child;
                }
            }
        }
    }

    /// Splits `node`'s prefix at byte offset `at`, which must be less than the
    /// prefix length.
    ///
    /// `node` keeps `prefix[..at]`. A new child takes `prefix[at..]` together
    /// with all of `node`'s children, param child, value and priority, so every
    /// path that matched before still matches. `param_name` stays on `node`,
    /// because it describes how `node` itself is reached from its parent, and
    /// that link still points at `node`. Param nodes have empty prefixes and are
    /// therefore never split.
    fn split_node(&mut self, node: usize, at: usize) {
        let original = &mut self.arena[node];
        let suffix: Box<[u8]> = original.prefix[at..].into();
        let first_byte = suffix[0];
        let moved = Node {
            prefix: suffix,
            indices: std::mem::take(&mut original.indices),
            children: std::mem::take(&mut original.children),
            param_child: original.param_child.take(),
            param_name: None,
            value: original.value.take(),
            priority: original.priority,
        };
        original.prefix = original.prefix[..at].into();
        original.indices.push(first_byte);
        let child = self.alloc(moved);
        self.arena[node].children.push(child);
    }

    /// Sets every node's `priority` to the number of routes stored in its
    /// subtree, counting both static and param descendants.
    fn compute_priorities(&mut self) {
        self.compute_priorities_dfs(0);
    }

    /// Post-order helper for `compute_priorities`; returns the subtree count.
    fn compute_priorities_dfs(&mut self, node_id: usize) -> u32 {
        let mut count = u32::from(self.arena[node_id].value.is_some());
        for i in 0..self.arena[node_id].children.len() {
            let child = self.arena[node_id].children[i];
            count += self.compute_priorities_dfs(child);
        }
        if let Some(param) = self.arena[node_id].param_child {
            count += self.compute_priorities_dfs(param);
        }
        self.arena[node_id].priority = count;
        count
    }

    /// Orders each node's static children by descending priority, keeping the
    /// existing order for ties.
    ///
    /// Sibling first bytes are unique, so this only changes how quickly the
    /// linear scan in lookup finds a child, never which child is chosen.
    /// Expects `compute_priorities` to have run first.
    fn reorder_children(&mut self) {
        for id in 0..self.arena.len() {
            let node = &self.arena[id];
            if node.children.len() <= 1 {
                continue;
            }
            let mut order: Vec<(u32, usize, u8)> = node
                .children
                .iter()
                .zip(&node.indices)
                .map(|(&child, &byte)| (self.arena[child].priority, child, byte))
                .collect();
            order.sort_by_key(|&(priority, _, _)| std::cmp::Reverse(priority));
            let node = &mut self.arena[id];
            node.children = order.iter().map(|&(_, child, _)| child).collect();
            node.indices = order.iter().map(|&(_, _, byte)| byte).collect();
        }
    }

    /// Looks up `path` and returns the index of the matching route, if any.
    ///
    /// Captured parameters are pushed onto `params` in path order. Params
    /// pushed while exploring a failed param branch are removed again via
    /// `PathParams::remove`.
    fn lookup(&self, path: &str, params: &mut PathParams) -> Option<usize> {
        self.lookup_recursive(0, path, 0, params)
    }

    /// Matches `path[pos..]` against the subtree rooted at `node_id`.
    ///
    /// Matching works on bytes with an explicit offset instead of re-slicing
    /// `path` as a `&str`. Nodes are split at byte granularity, so a node
    /// prefix can end inside a multi-byte UTF-8 character (e.g. after inserting
    /// both `/é/:x` and `/è/:y`). Slicing `&path[pos..]` at such a point would
    /// panic.
    ///
    /// Static children are tried first. Sibling prefixes start with distinct
    /// bytes, so at most one static child can match, and a failure there falls
    /// straight through to the param child.
    fn lookup_recursive(
        &self,
        node_id: usize,
        path: &str,
        pos: usize,
        params: &mut PathParams,
    ) -> Option<usize> {
        let node = &self.arena[node_id];
        let bytes = path.as_bytes();
        if !bytes[pos..].starts_with(&node.prefix) {
            return None;
        }
        let pos = pos + node.prefix.len();
        let rest = &bytes[pos..];
        if rest.is_empty() {
            return node.value;
        }

        let next = rest[0];
        if let Some(i) = node.indices.iter().position(|&b| b == next) {
            let saved_len = params.len();
            if let Some(found) = self.lookup_recursive(node.children[i], path, pos, params) {
                return Some(found);
            }
            debug_assert_eq!(
                params.len(),
                saved_len,
                "params leaked during static child backtracking"
            );
        }

        let param_id = node.param_child?;
        let end = pos + rest.iter().position(|&b| b == b'/').unwrap_or(rest.len());
        // A param begins right after a complete static segment of its pattern
        // and ends at an ASCII `/` or the end of `path`, so both offsets are
        // char boundaries. `get` keeps this panic-free regardless.
        let value = path.get(pos..end)?;
        let name = self.arena[param_id].param_name;
        if let Some(name) = name {
            params.push(name, value.to_string());
        }
        let found = self.lookup_recursive(param_id, path, end, params);
        if let (None, Some(name)) = (found, name) {
            // NOTE(review): removal is by name. For patterns that repeat a name
            // (e.g. `/a/:id/b/:id`), which entry is removed depends on
            // `PathParams::remove`; confirm it removes the most recent one.
            params.remove(name);
        }
        found
    }

    /// Appends `node` to the arena and returns its index.
    fn alloc(&mut self, node: Node) -> usize {
        self.arena.push(node);
        self.arena.len() - 1
    }
}
/// Returns the length, in bytes, of the longest common prefix of `a` and `b`.
fn common_prefix_len(a: &[u8], b: &[u8]) -> usize {
    let limit = a.len().min(b.len());
    (0..limit).find(|&i| a[i] != b[i]).unwrap_or(limit)
}
/// Per-method radix tries for routes whose patterns contain parameters.
///
/// Slot `i` holds the trie for the method that `method_index` maps to `i`. A
/// slot stays `None` until a dynamic route is registered for that method.
pub(super) struct TrieRouter {
    methods: [Option<RadixTrie>; NUM_STANDARD_METHODS],
}

impl TrieRouter {
    /// Builds the per-method tries from `routes`.
    ///
    /// Only routes for which `super::is_dynamic` returns true are inserted, each
    /// under its position in `routes`. After insertion every trie has its
    /// priorities computed and its children reordered.
    ///
    /// # Panics
    ///
    /// Panics if a dynamic route uses a non-standard (extension) method, or if
    /// inserting a pattern panics (e.g. conflicting param names).
    pub(super) fn build(routes: &[(Method, super::Route)]) -> Self {
        const EMPTY: Option<RadixTrie> = None;
        let mut methods = [EMPTY; NUM_STANDARD_METHODS];

        let dynamic_routes = routes
            .iter()
            .enumerate()
            .filter(|(_, (_, route))| super::is_dynamic(&route.pattern));
        for (idx, (method, route)) in dynamic_routes {
            let slot = match method_index(method) {
                Some(slot) => slot,
                None => panic!(
                    "unsupported HTTP method `{}` for route `{}` — \
                     only standard methods (GET, POST, PUT, DELETE, HEAD, \
                     OPTIONS, PATCH, CONNECT, TRACE) are supported",
                    method, route.pattern,
                ),
            };
            methods[slot]
                .get_or_insert_with(RadixTrie::new)
                .insert(&route.pattern, idx);
        }

        for trie in methods.iter_mut().flatten() {
            trie.compute_priorities();
            trie.reorder_children();
        }

        Self { methods }
    }

    /// Finds the route index matching `method` and `path`, filling `params`.
    ///
    /// Returns `None` for extension methods, for methods with no dynamic routes,
    /// and for paths that match no registered pattern.
    pub(super) fn lookup(
        &self,
        method: &Method,
        path: &str,
        params: &mut PathParams,
    ) -> Option<usize> {
        let trie = self.methods[method_index(method)?].as_ref()?;
        trie.lookup(path, params)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a lookup with a fresh `PathParams` and returns both the result and
    /// the captured params.
    fn lookup_params(trie: &RadixTrie, path: &str) -> (Option<usize>, PathParams) {
        let mut params = PathParams::new();
        let result = trie.lookup(path, &mut params);
        (result, params)
    }

    #[test]
    fn test_single_param_route() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id", 0);
        let (result, params) = lookup_params(&trie, "/users/42");
        assert_eq!(result, Some(0));
        assert_eq!(params.get("id").unwrap(), "42");
    }

    #[test]
    fn test_multiple_param_routes() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id", 0);
        trie.insert("/posts/:id", 1);
        let (result, params) = lookup_params(&trie, "/users/42");
        assert_eq!(result, Some(0));
        assert_eq!(params.get("id").unwrap(), "42");
        let (result, params) = lookup_params(&trie, "/posts/99");
        assert_eq!(result, Some(1));
        assert_eq!(params.get("id").unwrap(), "99");
    }

    #[test]
    fn test_nested_params() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:uid/posts/:pid", 0);
        let (result, params) = lookup_params(&trie, "/users/5/posts/10");
        assert_eq!(result, Some(0));
        assert_eq!(params.get("uid").unwrap(), "5");
        assert_eq!(params.get("pid").unwrap(), "10");
    }

    #[test]
    fn test_param_with_deeper_static() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id", 0);
        trie.insert("/users/:id/posts", 1);
        let (result, _) = lookup_params(&trie, "/users/42");
        assert_eq!(result, Some(0));
        let (result, params) = lookup_params(&trie, "/users/42/posts");
        assert_eq!(result, Some(1));
        assert_eq!(params.get("id").unwrap(), "42");
    }

    #[test]
    fn test_shared_prefix_divergence() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id/posts", 0);
        trie.insert("/users/:id/comments", 1);
        let (result, _) = lookup_params(&trie, "/users/1/posts");
        assert_eq!(result, Some(0));
        let (result, _) = lookup_params(&trie, "/users/1/comments");
        assert_eq!(result, Some(1));
    }

    #[test]
    fn test_no_match() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id", 0);
        let (result, _) = lookup_params(&trie, "/posts/42");
        assert_eq!(result, None);
        let (result, _) = lookup_params(&trie, "/users");
        assert_eq!(result, None);
        let (result, _) = lookup_params(&trie, "/users/42/extra");
        assert_eq!(result, None);
    }

    #[test]
    fn test_different_param_names_same_structure() {
        let mut trie = RadixTrie::new();
        trie.insert("/api/:version/users/:id", 0);
        let (result, params) = lookup_params(&trie, "/api/v2/users/99");
        assert_eq!(result, Some(0));
        assert_eq!(params.get("version").unwrap(), "v2");
        assert_eq!(params.get("id").unwrap(), "99");
    }

    #[test]
    fn test_param_at_root() {
        let mut trie = RadixTrie::new();
        trie.insert("/:slug", 0);
        let (result, params) = lookup_params(&trie, "/hello");
        assert_eq!(result, Some(0));
        assert_eq!(params.get("slug").unwrap(), "hello");
    }

    #[test]
    fn test_priority_ordering() {
        let mut trie = RadixTrie::new();
        trie.insert("/api/v1/users/:id", 0);
        trie.insert("/api/v1/posts/:id", 1);
        trie.insert("/api/v1/comments/:id", 2);
        trie.insert("/api/v2/users/:id", 3);
        trie.compute_priorities();
        trie.reorder_children();
        let v_node = trie.arena.iter().find(|n| n.children.len() == 2).unwrap();
        assert_eq!(
            v_node.indices[0], b'1',
            "v1 subtree should be first (higher priority)"
        );
        assert_eq!(
            v_node.indices[1], b'2',
            "v2 subtree should be second (lower priority)"
        );
        let (result, _) = lookup_params(&trie, "/api/v1/users/1");
        assert_eq!(result, Some(0));
        let (result, _) = lookup_params(&trie, "/api/v1/posts/1");
        assert_eq!(result, Some(1));
        let (result, _) = lookup_params(&trie, "/api/v2/users/1");
        assert_eq!(result, Some(3));
    }

    #[test]
    fn test_trie_router_method_isolation() {
        let router = crate::router::Router::new()
            .route(Method::GET, "/users/:id", |_, _, _| async {
                http::StatusCode::OK
            })
            .route(Method::DELETE, "/users/:id", |_, _, _| async {
                http::StatusCode::NO_CONTENT
            });
        let trie_router = TrieRouter::build(&router.routes);
        let mut params = PathParams::new();
        assert!(
            trie_router
                .lookup(&Method::GET, "/users/1", &mut params)
                .is_some()
        );
        params.clear();
        assert!(
            trie_router
                .lookup(&Method::DELETE, "/users/1", &mut params)
                .is_some()
        );
        params.clear();
        assert!(
            trie_router
                .lookup(&Method::POST, "/users/1", &mut params)
                .is_none()
        );
    }

    #[test]
    fn test_empty_trie() {
        let trie = RadixTrie::new();
        let (result, _) = lookup_params(&trie, "/anything");
        assert_eq!(result, None);
    }

    #[test]
    #[should_panic(expected = "conflicting param names")]
    fn test_conflicting_param_names_panics() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id/posts", 0);
        trie.insert("/users/:name/comments", 1);
    }

    #[test]
    fn test_same_param_name_no_conflict() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id/posts", 0);
        trie.insert("/users/:id/comments", 1);
        let (result, _) = lookup_params(&trie, "/users/1/posts");
        assert_eq!(result, Some(0));
    }

    #[test]
    fn test_split_pattern_invariant() {
        // Params run to the next `/` and static runs to the next `:`, so each
        // of these patterns splits into a known number of segments, and
        // re-joining the pieces must reproduce the pattern exactly.
        let cases = [
            ("/users/:id", 2),
            ("/users/:id/posts/:pid", 4),
            ("/:slug", 2),
            ("/api/:version/users", 3),
        ];
        for &(pattern, expected_len) in &cases {
            let segments = split_pattern(pattern);
            assert_eq!(
                segments.len(),
                expected_len,
                "segment count for `{pattern}`"
            );
            let mut rebuilt = Vec::new();
            for segment in &segments {
                match segment {
                    Segment::Static(bytes) => rebuilt.extend_from_slice(bytes),
                    Segment::Param(name) => {
                        rebuilt.push(b':');
                        rebuilt.extend_from_slice(name.as_bytes());
                    }
                }
            }
            assert_eq!(rebuilt, pattern.as_bytes(), "round-trip of `{pattern}`");
        }
    }

    #[test]
    fn test_backtracking_static_to_param() {
        let mut trie = RadixTrie::new();
        trie.insert("/a/bb/:x", 0);
        trie.insert("/a/:y", 1);
        trie.compute_priorities();
        trie.reorder_children();
        let (result, params) = lookup_params(&trie, "/a/bc");
        assert_eq!(result, Some(1));
        assert_eq!(params.get("y").unwrap(), "bc");
        let (result, params) = lookup_params(&trie, "/a/bb/42");
        assert_eq!(result, Some(0));
        assert_eq!(params.get("x").unwrap(), "42");
        let (result, params) = lookup_params(&trie, "/a/z");
        assert_eq!(result, Some(1));
        assert_eq!(params.get("y").unwrap(), "z");
    }

    #[test]
    fn test_split_pattern_mid_segment_colon_treated_as_param_name() {
        let segments = split_pattern("/files/:name.txt");
        assert_eq!(segments.len(), 2);
        match &segments[1] {
            Segment::Param(name) => assert_eq!(*name, "name.txt"),
            _ => panic!("expected a Param segment"),
        }
    }

    #[test]
    #[should_panic(expected = "unsupported HTTP method")]
    fn test_trie_router_build_panics_on_extension_method_with_dynamic_route() {
        let router = crate::router::Router::new().route(
            Method::from_bytes(b"FOOBAR").unwrap(),
            "/users/:id",
            |_, _, _| async { http::StatusCode::OK },
        );
        TrieRouter::build(&router.routes);
    }

    #[test]
    fn test_params_cleaned_up_after_failed_deep_match() {
        let mut trie = RadixTrie::new();
        trie.insert("/a/:x/y", 0);
        let mut params = PathParams::new();
        let result = trie.lookup("/a/val/z", &mut params);
        assert_eq!(result, None);
        assert!(
            params.is_empty(),
            "param :x must be removed on backtrack failure"
        );
    }

    #[test]
    fn test_param_with_multiple_static_children() {
        let mut trie = RadixTrie::new();
        trie.insert("/users/:id/posts", 0);
        trie.insert("/users/:id/comments", 1);
        trie.insert("/users/:id/likes", 2);
        let (result, _) = lookup_params(&trie, "/users/1/posts");
        assert_eq!(result, Some(0));
        let (result, _) = lookup_params(&trie, "/users/1/comments");
        assert_eq!(result, Some(1));
        let (result, _) = lookup_params(&trie, "/users/1/likes");
        assert_eq!(result, Some(2));
        let (result, _) = lookup_params(&trie, "/users/1/other");
        assert_eq!(result, None);
    }
}