1use std::{cmp, fmt};
4
/// A prefix tree mapping key sequences (`Vec<K>`) to values of type `V`.
///
/// The tree is populated up front by the constructors on this type. While it
/// is being built, supplying the same key twice, or one key that is a prefix
/// of another, panics.
#[derive(Clone, PartialEq, Eq)]
pub struct Trie<K, V>
where
    K: Clone + PartialEq,
    V: Clone,
{
    // NOTE(review): only ever initialised to `None` in the code visible here;
    // its intended use is not shown — confirm before relying on it.
    parent_key: Option<Vec<K>>,
    /// Top-level nodes, one per distinct first key element.
    roots: Vec<Node<K, V>>,
    /// Optional fallback consulted for single-element keys that have no node.
    default: Option<DefaultMapping<K, V>>,
}
21
22impl<K, V> fmt::Debug for Trie<K, V>
23where
24 K: Clone + PartialEq + fmt::Debug,
25 V: Clone + fmt::Debug,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("Trie")
29 .field("parent_key", &self.parent_key)
30 .field("roots", &self.roots)
31 .field("default", &stringify!(self.default))
32 .finish()
33 }
34}
35
36impl<K, V> Trie<K, V>
37where
38 K: Clone + PartialEq,
39 V: Clone,
40{
41 pub fn from_pairs(pairs: Vec<(Vec<K>, V)>) -> Self {
44 let mut roots = Vec::new();
45
46 for (k, v) in pairs.into_iter() {
47 insert(k, v, &mut roots)
48 }
49
50 Self {
51 parent_key: None,
52 roots,
53 default: None,
54 }
55 }
56
57 pub fn set_default(&mut self, default: DefaultMapping<K, V>) {
59 self.default = Some(default);
60 }
61
62 pub fn get(&self, key: &[K]) -> QueryResult<V> {
68 match get_node(key, &self.roots) {
69 Some(Node {
70 d: Data::Val(v), ..
71 }) => QueryResult::Val(v.clone()),
72
73 Some(_) => QueryResult::Partial,
74
75 None => match self.default {
76 Some(f) if key.len() == 1 => f(&key[0]).into(),
77 _ => QueryResult::Missing,
78 },
79 }
80 }
81
82 pub fn get_exact(&self, key: &[K]) -> Option<V> {
86 match get_node(key, &self.roots) {
87 Some(Node {
88 d: Data::Val(v), ..
89 }) => Some(v.clone()),
90
91 Some(_) => None,
92
93 None => match self.default {
94 Some(f) if key.len() == 1 => f(&key[0]),
95 _ => None,
96 },
97 }
98 }
99
100 pub fn candidates(&self, key: &[K]) -> Vec<Vec<K>> {
103 match get_node(key, &self.roots) {
104 None => vec![],
105 Some(n) => n.resolved_keys(key),
106 }
107 }
108
109 pub fn len(&self) -> usize {
111 self.roots.iter().map(|r| r.len()).sum()
112 }
113
114 pub fn is_empty(&self) -> bool {
116 self.roots.is_empty()
117 }
118
119 pub fn contains_key_or_prefix(&self, key: &[K]) -> bool {
122 !matches!(self.get(key), QueryResult::Missing)
123 }
124}
125
126impl<V> Trie<char, V>
127where
128 V: Clone,
129{
130 pub fn from_str_keys(pairs: Vec<(&str, V)>) -> Self {
132 let mut roots = Vec::new();
133
134 for (k, v) in pairs.into_iter() {
135 insert(k.chars().collect(), v, &mut roots)
136 }
137
138 Self {
139 parent_key: None,
140 roots,
141 default: None,
142 }
143 }
144
145 pub fn get_str(&self, key: &str) -> QueryResult<V> {
149 self.get(&key.chars().collect::<Vec<_>>())
150 }
151
152 pub fn get_str_exact(&self, key: &str) -> Option<V> {
156 self.get_exact(&key.chars().collect::<Vec<_>>())
157 }
158
159 pub fn candidate_strings(&self, key: &str) -> Vec<String> {
161 let raw = self.candidates(&key.chars().collect::<Vec<_>>());
162 let mut strings: Vec<String> = raw.into_iter().map(|v| v.into_iter().collect()).collect();
163 strings.sort();
164
165 strings
166 }
167}
168
/// Fallback lookup for a [`Trie`]: given a single key element, optionally
/// produce a value.
///
/// `Trie::get` and `Trie::get_exact` call it only when the queried key has
/// exactly one element and no node matches that key.
pub type DefaultMapping<K, V> = fn(&K) -> Option<V>;
174
/// Outcome of looking up a key in a [`Trie`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResult<V> {
    /// The key resolved to a value (stored in the trie or produced by the
    /// default handler).
    Val(V),
    /// The key matched a node that holds children rather than a value, i.e. it
    /// is a strict prefix of one or more stored keys.
    Partial,
    /// Nothing matched the key.
    Missing,
}
185
186impl<V> From<Option<V>> for QueryResult<V> {
187 fn from(opt: Option<V>) -> Self {
188 match opt {
189 Some(v) => QueryResult::Val(v),
190 None => QueryResult::Missing,
191 }
192 }
193}
194
195impl<V> From<QueryResult<V>> for Option<V> {
196 fn from(q: QueryResult<V>) -> Self {
197 match q {
198 QueryResult::Val(v) => Some(v),
199 _ => None,
200 }
201 }
202}
203
204impl<V> QueryResult<V> {
205 pub fn map<F, U>(self, f: F) -> QueryResult<U>
206 where
207 F: Fn(V) -> U,
208 {
209 match self {
210 Self::Val(v) => QueryResult::Val(f(v)),
211 Self::Partial => QueryResult::Partial,
212 Self::Missing => QueryResult::Missing,
213 }
214 }
215}
216
/// Payload of a [`Node`]: either a stored value or a list of child nodes.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Data<K, V>
where
    K: Clone + PartialEq,
{
    /// Terminal node holding the value for a complete key.
    Val(V),
    /// Interior node; each child continues the key by one element.
    Children(Vec<Node<K, V>>),
}
225
/// A single trie node: one key element plus what sits beneath it.
#[derive(Debug, Clone, PartialEq, Eq)]
struct Node<K, V>
where
    K: Clone + PartialEq,
{
    /// The key element this node matches.
    k: K,
    /// Either the value stored here or the child nodes below.
    d: Data<K, V>,
}
234
235impl<K, V> Node<K, V>
236where
237 K: Clone + PartialEq,
238{
239 fn len(&self) -> usize {
240 match &self.d {
241 Data::Children(nodes) => nodes.iter().map(|n| n.len()).sum(),
242 Data::Val(_) => 1,
243 }
244 }
245
246 fn insert(&mut self, k: Vec<K>, v: V) {
248 match &mut self.d {
249 Data::Children(nodes) => insert(k, v, nodes),
250 Data::Val(_) => panic!("attempt to insert into value node"),
251 }
252 }
253
254 fn get_child<'s>(&'s self, k: &[K]) -> Option<&'s Node<K, V>> {
255 match &self.d {
256 Data::Children(nodes) => get_node(k, nodes),
257 Data::Val(_) => None,
258 }
259 }
260
261 fn resolved_keys(&self, prefix: &[K]) -> Vec<Vec<K>> {
262 match &self.d {
263 Data::Val(_) => vec![prefix.to_vec()],
264 Data::Children(nodes) => nodes
265 .iter()
266 .flat_map(|n| {
267 let mut so_far = prefix.to_vec();
268 so_far.push(n.k.clone());
269 n.resolved_keys(&so_far)
270 })
271 .collect(),
272 }
273 }
274}
275
276fn insert<K, V>(mut key: Vec<K>, v: V, current: &mut Vec<Node<K, V>>)
277where
278 K: Clone + PartialEq,
279{
280 for n in current.iter_mut() {
281 if key[0] == n.k {
282 if key.len() > 1 {
283 key.remove(0);
284 n.insert(key, v);
285 return;
286 }
287 panic!("duplicate entry for key")
288 }
289 }
290
291 let k = key.remove(0);
292
293 if key.is_empty() {
295 current.push(Node { k, d: Data::Val(v) });
296 } else {
297 let mut children = vec![];
298 insert(key, v, &mut children);
299
300 let d = Data::Children(children);
301 current.push(Node { k, d });
302 }
303}
304
305fn get_node<'n, K, V>(key: &[K], nodes: &'n [Node<K, V>]) -> Option<&'n Node<K, V>>
306where
307 K: Clone + PartialEq,
308{
309 if key.is_empty() {
310 return None;
311 }
312
313 for n in nodes.iter() {
314 if key[0] == n.k {
315 return if key.len() == 1 {
316 Some(n)
317 } else {
318 n.get_child(&key[1..])
319 };
320 }
321 }
322
323 None
324}
325
/// Predicate used by [`WildCard::Wild`] to decide whether a key element matches.
pub type WildcardFn<K> = fn(&K) -> bool;
328
/// A key element matcher that can be compared against a `K` with `==`.
#[derive(Debug)]
pub enum WildCard<K> {
    /// Matches only elements equal to the wrapped literal.
    Lit(K),
    /// Matches any element for which the wrapped predicate returns `true`.
    Wild(WildcardFn<K>),
}
337
338impl<K> cmp::PartialEq<K> for WildCard<K>
339where
340 K: PartialEq,
341{
342 fn eq(&self, other: &K) -> bool {
343 match self {
344 WildCard::Lit(k) => k == other,
345 WildCard::Wild(f) => f(other),
346 }
347 }
348}
349
// Unit tests covering construction panics, exact / partial / missing lookups,
// candidate enumeration, default handlers and wildcard comparison.
#[cfg(test)]
mod tests {
    use super::*;
    use simple_test_case::test_case;

    // Supplying the same key twice must panic while the trie is being built.
    #[test]
    #[should_panic(expected = "duplicate entry for key")]
    fn duplicate_keys_panic() {
        Trie::from_pairs(vec![(vec![42], 1), (vec![42], 2)]);
    }

    // A key that extends an existing complete key must panic.
    #[test]
    #[should_panic(expected = "attempt to insert into value node")]
    fn children_under_a_value_node_panics() {
        Trie::from_pairs(vec![(vec![42], 1), (vec![42, 69], 2)]);
    }

    // `get_exact` only yields a value for a complete key.
    #[test_case(&[42], None; "partial should be None")]
    #[test_case(&[144], None; "missing should be None")]
    #[test_case(&[42, 69, 144], None; "overshoot should be None")]
    #[test_case(&[42, 69], Some(1); "exact should be Some")]
    #[test]
    fn get_exact_works(k: &[usize], expected: Option<usize>) {
        let t = Trie::from_pairs(vec![(vec![42, 69], 1)]);

        assert_eq!(t.get_exact(k), expected);
    }

    // Same as above, through the `&str` convenience API.
    #[test_case("fo", None; "partial should be None")]
    #[test_case("bar", None; "missing should be None")]
    #[test_case("fooo", None; "overshoot should be None")]
    #[test_case("foo", Some(1); "exact should be Some")]
    #[test]
    fn get_str_exact_works(k: &str, expected: Option<usize>) {
        let t = Trie::from_str_keys(vec![("foo", 1)]);

        assert_eq!(t.get_str_exact(k), expected);
    }

    // `get_str` distinguishes partial, exact and missing keys.
    #[test_case("ba", QueryResult::Partial; "partial match")]
    #[test_case("bar", QueryResult::Val(2); "exact match")]
    #[test_case("baz", QueryResult::Val(3); "exact match with shared prefix")]
    #[test_case("barf", QueryResult::Missing; "overshot known key")]
    #[test_case("have you any wool?", QueryResult::Missing; "completely missing")]
    #[test]
    fn get_works(k: &str, expected: QueryResult<usize>) {
        let t = Trie::from_str_keys(vec![("foo", 1), ("bar", 2), ("baz", 3)]);

        assert_eq!(t.get_str(k), expected);
    }

    // `candidate_strings` returns the sorted complete keys below a prefix.
    #[test_case("f", &["fold", "food", "fool"]; "first char")]
    #[test_case("fo", &["fold", "food", "fool"]; "shared prefix")]
    #[test_case("foo", &["food", "fool"]; "shared prefix not all match")]
    #[test_case("food", &["food"]; "exact match")]
    #[test_case("foods", &[]; "overshot")]
    #[test_case("q", &[]; "unknown first char")]
    #[test_case("quux", &[]; "unknown full key")]
    #[test_case("", &[]; "empty string")]
    #[test]
    fn candidate_strings_works(k: &str, expected: &[&str]) {
        let expected: Vec<String> = expected.iter().map(|s| s.to_string()).collect();
        let t = Trie::from_str_keys(
            ["fool", "fold", "food"]
                .into_iter()
                .enumerate()
                .map(|(i, s)| (s, i))
                .collect(),
        );

        assert_eq!(t.candidate_strings(k), expected);
    }

    // Default handler used below: maps any single key element `n` to `n + 1`.
    fn usize_default_handler(n: &usize) -> Option<usize> {
        Some(n + 1)
    }

    // The default handler only fills in for missing single-element keys.
    #[test_case(&[42], QueryResult::Val(1); "exact single should match from the Try")]
    #[test_case(&[12, 13], QueryResult::Val(2); "exact multi should match from the Try")]
    #[test_case(&[69], QueryResult::Val(70); "missing single should be defaulted")]
    #[test_case(&[69, 420], QueryResult::Missing; "missing multi should always be missing")]
    #[test_case(&[12], QueryResult::Partial; "partial should remain partial")]
    #[test]
    fn default_handlers_work(k: &[usize], expected: QueryResult<usize>) {
        let mut t = Trie::from_pairs(vec![(vec![42], 1), (vec![12, 13], 2)]);
        t.set_default(usize_default_handler);

        assert_eq!(t.get(k), expected);

        // `get_exact` must agree with `get` once Partial/Missing collapse to None.
        let expected_opt: Option<usize> = expected.into();
        assert_eq!(t.get_exact(k), expected_opt);
    }

    // A wildcard compares equal to exactly the elements its predicate accepts.
    #[test]
    fn wildcards_match_correctly() {
        fn is_valid(c: &char) -> bool {
            *c == 'i' || *c == 'a'
        }

        let w = WildCard::Wild(is_valid);

        assert!(w == 'a');
        assert!(w != 'b');
    }
}