1use std::{collections::HashMap, hash::Hash};
2
/// A single node of a [`Trie`]: its outgoing edges keyed by the next key part,
/// plus the value stored for the key that ends exactly at this node (if any).
#[derive(Debug)]
struct TrieNode<K, V> {
    children: HashMap<K, TrieNode<K, V>>,
    value: Option<V>,
}

// Written by hand because `#[derive(Default)]` would add needless
// `K: Default` / `V: Default` bounds.
impl<K, V> Default for TrieNode<K, V> {
    fn default() -> Self {
        TrieNode {
            children: HashMap::new(),
            value: None,
        }
    }
}

/// A prefix tree mapping sequences of key parts (`K`) to values (`V`).
///
/// Keys are passed as any `IntoIterator<Item = K>`, e.g. a `Vec<String>` of
/// path components.
#[derive(Debug)]
pub struct Trie<K, V> {
    root: TrieNode<K, V>,
}

impl<K, V> Default for Trie<K, V> {
    fn default() -> Self {
        Trie {
            root: TrieNode::default(),
        }
    }
}

impl<K, V> Trie<K, V>
where
    K: Hash + Eq,
{
    /// Creates an empty trie.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a clone of the value stored under exactly `key`.
    ///
    /// Returns `None` if the path does not exist or ends at a node that has
    /// no value of its own (an intermediate node of a longer key).
    pub fn get<I>(&self, key: I) -> Option<V>
    where
        I: IntoIterator<Item = K>,
        V: Clone,
    {
        self.lookup(key).cloned()
    }

    /// Returns `true` if a value is stored under exactly `key`.
    pub fn contains<I>(&self, key: I) -> bool
    where
        I: IntoIterator<Item = K>,
    {
        self.lookup(key).is_some()
    }

    /// Finds the longest prefix of `key` that has a value stored under it.
    ///
    /// Returns that prefix together with a clone of its value, or `None` if
    /// no prefix of `key` has a value. A value stored under the empty key is
    /// a prefix of every key and therefore acts as the final fallback
    /// (returned with an empty prefix).
    pub fn best_match<I>(&self, key: I) -> Option<(Vec<K>, V)>
    where
        I: IntoIterator<Item = K>,
        V: Clone,
    {
        let mut node = &self.root;
        let mut path = Vec::new();
        // Length and value of the deepest valued node seen so far; borrowed so
        // that `V` is cloned only once, for the final answer.
        let mut best = node.value.as_ref().map(|value| (0, value));
        for part in key {
            node = match node.children.get(&part) {
                Some(child) => child,
                None => break,
            };
            path.push(part);
            if let Some(value) = node.value.as_ref() {
                best = Some((path.len(), value));
            }
        }
        let (len, value) = best?;
        // The walk may have continued through value-less nodes past the
        // match; drop those parts so the returned key is the one that holds
        // the returned value.
        path.truncate(len);
        Some((path, value.clone()))
    }

    /// Stores `value` under `key`, creating intermediate nodes as needed and
    /// overwriting any value previously stored under exactly `key`.
    pub fn insert<I>(&mut self, key: I, value: V)
    where
        I: IntoIterator<Item = K>,
    {
        let mut cur = &mut self.root;
        for part in key {
            cur = cur.children.entry(part).or_default();
        }
        cur.value = Some(value);
    }

    /// Returns a reference to the value stored under exactly `key`, if any.
    fn lookup<I>(&self, key: I) -> Option<&V>
    where
        I: IntoIterator<Item = K>,
    {
        self.traverse(key).and_then(|node| node.value.as_ref())
    }

    /// Follows `key` from the root and returns the node it ends at.
    ///
    /// Returns `None` as soon as a part has no matching child; an empty key
    /// yields the root itself.
    fn traverse<I>(&self, key: I) -> Option<&TrieNode<K, V>>
    where
        I: IntoIterator<Item = K>,
    {
        key.into_iter()
            .try_fold(&self.root, |node, part| node.children.get(&part))
    }
}
110
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use super::*;

    /// Breaks `path` into one owned `String` per component, converting any
    /// non-UTF-8 component lossily.
    fn path_parts(path: &Path) -> Vec<String> {
        let mut parts = Vec::new();
        for component in path.components() {
            parts.push(component.as_os_str().to_string_lossy().into_owned());
        }
        parts
    }

    #[test]
    fn test_all() {
        let echos = Path::new("/etc/bin/echos");
        let echo = Path::new("/etc/bin/echo");
        let hello = Path::new("/etc/bin/echo/hello.txt");
        let jello = Path::new("/etc/bin/echo/hello.txt/jello");

        let cat = PathBuf::from("usr/cat");
        let tar = PathBuf::from("usr/tar");

        let mut trie: Trie<String, PathBuf> = Trie::new();
        trie.insert(path_parts(echos), cat.clone());
        trie.insert(path_parts(hello), tar.clone());

        // Exact-key queries on both inserted paths.
        assert!(trie.contains(path_parts(echos)));
        assert!(trie.contains(path_parts(hello)));
        assert_eq!(trie.get(path_parts(echos)), Some(cat.clone()));
        assert_eq!(trie.get(path_parts(hello)), Some(tar.clone()));

        // An inserted key is its own best match.
        let expected_exact = Some((path_parts(echos), cat));
        assert_eq!(trie.best_match(path_parts(echos)), expected_exact);

        // A longer key falls back to its longest stored prefix.
        let expected_prefix = Some((path_parts(hello), tar));
        assert_eq!(trie.best_match(path_parts(jello)), expected_prefix);

        // `echo` exists only as an intermediate node on the way to `hello.txt`.
        assert!(!trie.contains(path_parts(echo)));
    }
}