use hashbrown::HashMap;
use std::fmt::Debug;
/// A single node of the path trie.
///
/// Each node owns its children, keyed by one path segment, and may carry a
/// value for the path that ends at this node.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TrieNode<V>
where
    V: Debug,
{
    /// Child nodes, keyed by the path segment that leads to them.
    pub children: HashMap<Box<str>, TrieNode<V>>,
    /// Value stored for the path terminating at this node, if any.
    pub value: Option<V>,
}
impl<V> TrieNode<V>
where
    V: Debug,
{
    /// Creates a node with no children and no stored value.
    pub fn new() -> Self {
        Self {
            value: None,
            children: HashMap::new(),
        }
    }
}
impl<V: Debug> Default for TrieNode<V> {
fn default() -> Self {
Self::new()
}
}
/// A trie keyed by `/`-separated path segments.
///
/// Lookups and insertions operate on the path portion of their input; a
/// leading `scheme://host` prefix is skipped before the segments are walked.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Trie<V>
where
    V: Debug,
{
    /// Root node; holds the value stored for paths that contain no usable
    /// segments (for example `"/"` or `""`).
    pub root: TrieNode<V>,
    /// Set once the exact path `"/"` has been inserted. While set, `search`
    /// skips segments that have no matching child instead of returning `None`.
    pub match_all: bool,
}
impl<V: Debug> Default for Trie<V> {
fn default() -> Self {
Self::new()
}
}
impl<V: Debug> Trie<V> {
    /// Creates an empty trie with `match_all` disabled.
    pub fn new() -> Self {
        Self {
            root: TrieNode::new(),
            match_all: false,
        }
    }

    /// Returns the byte offset at which the path component of `path` begins.
    ///
    /// * Input without a URL prefix yields `0` (the whole input is the path).
    /// * Input of the form `scheme://host/rest` yields the offset of the first
    ///   `/` after the host.
    /// * Input of the form `scheme://host` or a bare `scheme://` yields
    ///   `path.len()`, i.e. an empty path.
    ///
    /// The returned offset is always `0`, `path.len()`, or the index of an
    /// ASCII `/`, so it is a valid `str` slicing boundary.
    #[inline]
    fn path_start(path: &str) -> usize {
        let bytes = path.as_bytes();
        let after_scheme = if path.starts_with("https://") {
            8
        } else if path.starts_with("http://") {
            7
        } else {
            match path.find("://") {
                // A URL scheme never contains '/', '?' or '#'. If one of them
                // precedes the marker, the "://" belongs to the path or query
                // (e.g. "/redirect?to=http://host/x") and must not be treated
                // as a prefix to strip.
                Some(pos)
                    if !bytes[..pos]
                        .iter()
                        .any(|&b| matches!(b, b'/' | b'?' | b'#')) =>
                {
                    pos + 3
                }
                _ => return 0,
            }
        };
        // `after_scheme <= len` because the matched prefix is part of `path`.
        // When nothing follows the scheme the slice is empty, no '/' is found
        // and the path is reported as empty instead of restarting at offset 0
        // (which would turn the scheme itself into a path segment).
        memchr::memchr(b'/', bytes.get(after_scheme..).unwrap_or_default())
            .map_or(bytes.len(), |p| after_scheme + p)
    }

    /// Iterates over the trie-relevant segments of `path`.
    ///
    /// Scanning starts at [`Self::path_start`]. Runs of `/` are skipped, so
    /// empty segments are never produced, and any segment containing a `.` is
    /// dropped. `insert` and `search` share this iterator so both always walk
    /// exactly the same keys.
    #[inline]
    fn segments(path: &str) -> impl Iterator<Item = &str> {
        let bytes = path.as_bytes();
        let len = bytes.len();
        let mut cursor = Self::path_start(path);
        std::iter::from_fn(move || {
            while cursor < len {
                if bytes[cursor] == b'/' {
                    cursor += 1;
                    continue;
                }
                let end = memchr::memchr(b'/', &bytes[cursor..]).map_or(len, |p| cursor + p);
                // Both bounds sit next to an ASCII '/' or at the string
                // start/end, so slicing cannot split a multi-byte character.
                let segment = &path[cursor..end];
                cursor = end;
                if memchr::memchr(b'.', segment.as_bytes()).is_none() {
                    return Some(segment);
                }
            }
            None
        })
    }

    /// Stores `value` at the node addressed by `path`, creating intermediate
    /// nodes as needed and replacing any value already stored there.
    ///
    /// A leading `scheme://host` prefix and segments containing `.` are
    /// ignored, so entries are independent of the host they were inserted
    /// with. Inserting the exact path `"/"` additionally enables `match_all`.
    #[cfg_attr(feature = "inline-more", inline)]
    pub fn insert(&mut self, path: &str, value: V) {
        let mut node = &mut self.root;
        for segment in Self::segments(path) {
            node = node.children.entry_ref(segment).or_default();
        }
        if path == "/" {
            self.match_all = true;
        }
        node.value = Some(value);
    }

    /// Looks up the value stored for `input`.
    ///
    /// The input is reduced to segments exactly as in [`Self::insert`].
    /// Returns `None` when the trie is empty, when a segment has no matching
    /// child and `match_all` is off, or when the node reached holds no value.
    /// With `match_all` on, unmatched segments are skipped and the walk
    /// continues from the current node.
    #[inline]
    pub fn search(&self, input: &str) -> Option<&V> {
        let mut node = &self.root;
        // Nothing has been inserted yet: skip parsing the input altogether.
        if node.children.is_empty() && node.value.is_none() {
            return None;
        }
        for segment in Self::segments(input) {
            match node.children.get(segment) {
                Some(child) => node = child,
                None if self.match_all => {}
                None => return None,
            }
        }
        node.value.as_ref()
    }
}
// Unit tests for `TrieNode` and `Trie`: construction, insertion, lookup,
// overwrite semantics, and host-agnostic URL handling.
#[cfg(test)]
mod tests {
use super::*;
// A freshly created node has neither children nor a value.
#[test]
fn test_trie_node_new() {
let node: TrieNode<usize> = TrieNode::new();
assert!(node.children.is_empty());
assert!(node.value.is_none());
}
// A freshly created trie has an empty root.
#[test]
fn test_trie_new() {
let trie: Trie<usize> = Trie::new();
assert!(trie.root.children.is_empty());
assert!(trie.root.value.is_none());
}
// Scheme and host are skipped on insert, so the relative path and the full
// URL address the same node and the second insert overwrites the first.
// Intermediate nodes carry no value, and an unknown trailing segment fails
// the lookup while `match_all` is off.
#[test]
fn test_insert_and_search() {
let mut trie: Trie<usize> = Trie::new();
trie.insert("/path/to/node", 42);
trie.insert("https://mywebsite/path/to/node", 22);
assert_eq!(trie.search("https://mywebsite/path/to/node"), Some(&22));
assert_eq!(trie.search("/path/to/node"), Some(&22));
assert_eq!(trie.search("/path"), None);
assert_eq!(trie.search("/path/to"), None);
assert_eq!(trie.search("/path/to/node/extra"), None);
// Inserting "/" stores the value on the root and turns on `match_all`, so
// an unknown top-level segment is skipped and the root value is returned.
trie.insert("/", 11);
assert_eq!(trie.search("/random"), Some(&11));
}
// Sibling leaves under a shared prefix keep independent values.
#[test]
fn test_insert_multiple_nodes() {
let mut trie: Trie<usize> = Trie::new();
trie.insert("/path/to/node1", 1);
trie.insert("/path/to/node2", 2);
trie.insert("/path/to/node3", 3);
assert_eq!(trie.search("/path/to/node1"), Some(&1));
assert_eq!(trie.search("/path/to/node2"), Some(&2));
assert_eq!(trie.search("/path/to/node3"), Some(&3));
}
// Re-inserting the same path replaces the stored value.
#[test]
fn test_insert_overwrite() {
let mut trie: Trie<usize> = Trie::new();
trie.insert("/path/to/node", 42);
trie.insert("/path/to/node", 84);
assert_eq!(trie.search("/path/to/node"), Some(&84));
}
// Paths that diverge from every inserted path yield no result.
#[test]
fn test_search_nonexistent_path() {
let mut trie: Trie<usize> = Trie::new();
trie.insert("/path/to/node", 42);
assert!(trie.search("/nonexistent").is_none());
assert!(trie.search("/path/to/wrongnode").is_none());
}
// An empty path has no segments, so the value is stored on the root.
// NOTE(review): the first disjunct alone appears to hold (the second lookup
// fails because `match_all` is off); consider tightening this assertion.
#[test]
fn test_trie_empty_path() {
let mut trie: Trie<usize> = Trie::new();
trie.insert("", 1);
assert!(trie.search("").is_some() || trie.search("/anything").is_some());
}
// Segments are cut at ASCII '/' bytes, so multi-byte UTF-8 segments survive
// intact and can be looked up again.
#[test]
fn test_trie_unicode_paths() {
let mut trie: Trie<&str> = Trie::new();
trie.insert("/café/menü", "unicode");
assert_eq!(trie.search("/café/menü"), Some(&"unicode"));
}
// Many siblings under one parent: present keys resolve, an absent one does not.
#[test]
fn test_trie_many_entries() {
let mut trie: Trie<usize> = Trie::new();
for i in 0..1000 {
trie.insert(&format!("/path/{}", i), i);
}
assert_eq!(trie.search("/path/0"), Some(&0));
assert_eq!(trie.search("/path/999"), Some(&999));
assert!(trie.search("/path/1000").is_none());
}
// `Default` yields an empty trie with `match_all` off.
#[test]
fn test_trie_default() {
let trie: Trie<usize> = Trie::default();
assert!(trie.root.children.is_empty());
assert!(!trie.match_all);
}
// Deep shared prefix: every leaf resolves, while the prefix nodes themselves
// hold no value and therefore return `None`.
#[test]
fn test_trie_shared_prefix_insert() {
let mut trie: Trie<usize> = Trie::new();
for i in 0..100 {
trie.insert(&format!("/api/v1/resource/{}", i), i);
}
for i in 0..100 {
assert_eq!(
trie.search(&format!("/api/v1/resource/{}", i)),
Some(&i),
"shared prefix path {} not found",
i
);
}
assert!(trie.search("/api").is_none());
assert!(trie.search("/api/v1").is_none());
assert!(trie.search("/api/v1/resource").is_none());
}
// Overwriting one leaf leaves its siblings untouched.
#[test]
fn test_trie_overwrite_preserves_others() {
let mut trie: Trie<usize> = Trie::new();
trie.insert("/a/b/c", 1);
trie.insert("/a/b/d", 2);
trie.insert("/a/b/e", 3);
trie.insert("/a/b/c", 99);
assert_eq!(trie.search("/a/b/c"), Some(&99));
assert_eq!(trie.search("/a/b/d"), Some(&2));
assert_eq!(trie.search("/a/b/e"), Some(&3));
}
// Matching is host-agnostic: entries inserted with one host (or none) are
// found when searched with a different host or as a bare path.
#[test]
fn test_trie_insert_search_full_urls() {
let mut trie: Trie<&str> = Trie::new();
trie.insert("https://example.com/users/profile", "profile");
trie.insert("/users/settings", "settings");
trie.insert("http://other.com/api/data", "data");
assert_eq!(
trie.search("https://example.com/users/profile"),
Some(&"profile")
);
assert_eq!(trie.search("/users/profile"), Some(&"profile"));
assert_eq!(
trie.search("https://any.com/users/settings"),
Some(&"settings")
);
assert_eq!(trie.search("/api/data"), Some(&"data"));
assert_eq!(
trie.search("http://cdn.example.com/api/data"),
Some(&"data")
);
assert!(trie.search("/users/unknown").is_none());
}
}