1use std::{cmp, fmt};
4
5#[derive(Clone, PartialEq, Eq)]
12pub struct Trie<K, V>
13where
14 K: Clone + PartialEq,
15 V: Clone,
16{
17 parent_key: Option<Vec<K>>,
18 roots: Vec<Node<K, V>>,
19 default: Option<DefaultMapping<K, V>>,
20}
21
22impl<K, V> fmt::Debug for Trie<K, V>
23where
24 K: Clone + PartialEq + fmt::Debug,
25 V: Clone + fmt::Debug,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("Trie")
29 .field("parent_key", &self.parent_key)
30 .field("roots", &self.roots)
31 .field("default", &stringify!(self.default))
32 .finish()
33 }
34}
35
36impl<K, V> Trie<K, V>
37where
38 K: Clone + PartialEq,
39 V: Clone,
40{
41 pub fn from_pairs(pairs: Vec<(Vec<K>, V)>) -> Result<Self, &'static str> {
44 let mut roots = Vec::new();
45
46 for (k, v) in pairs.into_iter() {
47 insert(k, v, &mut roots)?;
48 }
49
50 Ok(Self {
51 parent_key: None,
52 roots,
53 default: None,
54 })
55 }
56
57 pub fn set_default(&mut self, default: DefaultMapping<K, V>) {
71 self.default = Some(default);
72 }
73
74 pub fn get(&self, key: &[K]) -> QueryResult<V> {
80 match get_node(key, &self.roots) {
81 Some(Node {
82 d: Data::Val(v), ..
83 }) => QueryResult::Val(v.clone()),
84
85 Some(_) => QueryResult::Partial,
86
87 None => match self.default {
88 Some(f) if key.len() == 1 => f(&key[0]).into(),
89 _ => QueryResult::Missing,
90 },
91 }
92 }
93
94 pub fn get_exact(&self, key: &[K]) -> Option<V> {
98 match get_node(key, &self.roots) {
99 Some(Node {
100 d: Data::Val(v), ..
101 }) => Some(v.clone()),
102
103 Some(_) => None,
104
105 None => match self.default {
106 Some(f) if key.len() == 1 => f(&key[0]),
107 _ => None,
108 },
109 }
110 }
111
112 pub fn candidates(&self, key: &[K]) -> Vec<Vec<K>> {
115 match get_node(key, &self.roots) {
116 None => vec![],
117 Some(n) => n.resolved_keys(key),
118 }
119 }
120
121 pub fn len(&self) -> usize {
123 self.roots.iter().map(|r| r.len()).sum()
124 }
125
126 pub fn is_empty(&self) -> bool {
128 self.roots.is_empty()
129 }
130
131 pub fn contains_key_or_prefix(&self, key: &[K]) -> bool {
134 !matches!(self.get(key), QueryResult::Missing)
135 }
136}
137
138impl<V> Trie<char, V>
139where
140 V: Clone,
141{
142 pub fn from_str_keys(pairs: Vec<(&str, V)>) -> Result<Self, &'static str> {
144 let mut roots = Vec::new();
145
146 for (k, v) in pairs.into_iter() {
147 insert(k.chars().collect(), v, &mut roots)?;
148 }
149
150 Ok(Self {
151 parent_key: None,
152 roots,
153 default: None,
154 })
155 }
156
157 pub fn get_str(&self, key: &str) -> QueryResult<V> {
161 self.get(&key.chars().collect::<Vec<_>>())
162 }
163
164 pub fn get_str_exact(&self, key: &str) -> Option<V> {
168 self.get_exact(&key.chars().collect::<Vec<_>>())
169 }
170
171 pub fn candidate_strings(&self, key: &str) -> Vec<String> {
173 let raw = self.candidates(&key.chars().collect::<Vec<_>>());
174 let mut strings: Vec<String> = raw.into_iter().map(|v| v.into_iter().collect()).collect();
175 strings.sort();
176
177 strings
178 }
179}
180
/// Fallback handler for a trie: given a single key element, optionally produce a value.
pub type DefaultMapping<K, V> = for<'a> fn(&'a K) -> Option<V>;
186
/// Outcome of looking up a key sequence.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum QueryResult<V> {
    /// The key maps to this value.
    Val(V),
    /// The key matched a node that holds children rather than a value.
    Partial,
    /// Nothing matched the key.
    Missing,
}
197
198impl<V> From<Option<V>> for QueryResult<V> {
199 fn from(opt: Option<V>) -> Self {
200 match opt {
201 Some(v) => QueryResult::Val(v),
202 None => QueryResult::Missing,
203 }
204 }
205}
206
207impl<V> From<QueryResult<V>> for Option<V> {
208 fn from(q: QueryResult<V>) -> Self {
209 match q {
210 QueryResult::Val(v) => Some(v),
211 _ => None,
212 }
213 }
214}
215
216impl<V> QueryResult<V> {
217 pub fn map<F, U>(self, f: F) -> QueryResult<U>
218 where
219 F: Fn(V) -> U,
220 {
221 match self {
222 Self::Val(v) => QueryResult::Val(f(v)),
223 Self::Partial => QueryResult::Partial,
224 Self::Missing => QueryResult::Missing,
225 }
226 }
227}
228
229#[derive(Debug, Clone, PartialEq, Eq)]
230enum Data<K, V>
231where
232 K: Clone + PartialEq,
233{
234 Val(V),
235 Children(Vec<Node<K, V>>),
236}
237
238#[derive(Debug, Clone, PartialEq, Eq)]
239struct Node<K, V>
240where
241 K: Clone + PartialEq,
242{
243 k: K,
244 d: Data<K, V>,
245}
246
247impl<K, V> Node<K, V>
248where
249 K: Clone + PartialEq,
250{
251 fn len(&self) -> usize {
252 match &self.d {
253 Data::Children(nodes) => nodes.iter().map(|n| n.len()).sum(),
254 Data::Val(_) => 1,
255 }
256 }
257
258 fn insert(&mut self, k: Vec<K>, v: V) -> Result<(), &'static str> {
260 match &mut self.d {
261 Data::Children(nodes) => insert(k, v, nodes),
262 Data::Val(_) => Err("attempt to insert into value node"),
263 }
264 }
265
266 fn get_child<'s>(&'s self, k: &[K]) -> Option<&'s Node<K, V>> {
267 match &self.d {
268 Data::Children(nodes) => get_node(k, nodes),
269 Data::Val(_) => None,
270 }
271 }
272
273 fn resolved_keys(&self, prefix: &[K]) -> Vec<Vec<K>> {
274 match &self.d {
275 Data::Val(_) => vec![prefix.to_vec()],
276 Data::Children(nodes) => nodes
277 .iter()
278 .flat_map(|n| {
279 let mut so_far = prefix.to_vec();
280 so_far.push(n.k.clone());
281 n.resolved_keys(&so_far)
282 })
283 .collect(),
284 }
285 }
286}
287
288fn insert<K, V>(mut key: Vec<K>, v: V, current: &mut Vec<Node<K, V>>) -> Result<(), &'static str>
289where
290 K: Clone + PartialEq,
291{
292 for n in current.iter_mut() {
293 if key[0] == n.k {
294 if key.len() > 1 {
295 key.remove(0);
296 n.insert(key, v)?;
297 return Ok(());
298 }
299 return Err("duplicate entry for key");
300 }
301 }
302
303 let k = key.remove(0);
304
305 if key.is_empty() {
307 current.push(Node { k, d: Data::Val(v) });
308 } else {
309 let mut children = vec![];
310 insert(key, v, &mut children)?;
311
312 let d = Data::Children(children);
313 current.push(Node { k, d });
314 }
315
316 Ok(())
317}
318
319fn get_node<'n, K, V>(key: &[K], nodes: &'n [Node<K, V>]) -> Option<&'n Node<K, V>>
320where
321 K: Clone + PartialEq,
322{
323 if key.is_empty() {
324 return None;
325 }
326
327 for n in nodes.iter() {
328 if key[0] == n.k {
329 return if key.len() == 1 {
330 Some(n)
331 } else {
332 n.get_child(&key[1..])
333 };
334 }
335 }
336
337 None
338}
339
/// Predicate over a single key element.
pub type WildcardFn<K> = for<'a> fn(&'a K) -> bool;
342
343#[derive(Debug)]
345pub enum WildCard<K> {
346 Lit(K),
348 Wild(WildcardFn<K>),
350}
351
352impl<K> cmp::PartialEq<K> for WildCard<K>
353where
354 K: PartialEq,
355{
356 fn eq(&self, other: &K) -> bool {
357 match self {
358 WildCard::Lit(k) => k == other,
359 WildCard::Wild(f) => f(other),
360 }
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use simple_test_case::test_case;

    // Inserting the same key twice must be rejected.
    #[test]
    fn duplicate_keys_errors() {
        let pairs = vec![(vec![42], 1), (vec![42], 2)];

        assert!(Trie::from_pairs(pairs).is_err());
    }

    // A key may not continue past a key that already holds a value.
    #[test]
    fn children_under_a_value_node_errors() {
        let pairs = vec![(vec![42], 1), (vec![42, 69], 2)];

        assert!(Trie::from_pairs(pairs).is_err());
    }

    #[test_case(&[42], None; "partial should be None")]
    #[test_case(&[144], None; "missing should be None")]
    #[test_case(&[42, 69, 144], None; "overshoot should be None")]
    #[test_case(&[42, 69], Some(1); "exact should be Some")]
    #[test]
    fn get_exact_works(k: &[usize], expected: Option<usize>) {
        let trie = Trie::from_pairs(vec![(vec![42, 69], 1)]).unwrap();

        let found = trie.get_exact(k);

        assert_eq!(found, expected);
    }

    #[test_case("fo", None; "partial should be None")]
    #[test_case("bar", None; "missing should be None")]
    #[test_case("fooo", None; "overshoot should be None")]
    #[test_case("foo", Some(1); "exact should be Some")]
    #[test]
    fn get_str_exact_works(k: &str, expected: Option<usize>) {
        let trie = Trie::from_str_keys(vec![("foo", 1)]).unwrap();

        let found = trie.get_str_exact(k);

        assert_eq!(found, expected);
    }

    #[test_case("ba", QueryResult::Partial; "partial match")]
    #[test_case("bar", QueryResult::Val(2); "exact match")]
    #[test_case("baz", QueryResult::Val(3); "exact match with shared prefix")]
    #[test_case("barf", QueryResult::Missing; "overshot known key")]
    #[test_case("have you any wool?", QueryResult::Missing; "completely missing")]
    #[test]
    fn get_works(k: &str, expected: QueryResult<usize>) {
        let pairs = vec![("foo", 1), ("bar", 2), ("baz", 3)];
        let trie = Trie::from_str_keys(pairs).unwrap();

        assert_eq!(trie.get_str(k), expected);
    }

    #[test_case("f", &["fold", "food", "fool"]; "first char")]
    #[test_case("fo", &["fold", "food", "fool"]; "shared prefix")]
    #[test_case("foo", &["food", "fool"]; "shared prefix not all match")]
    #[test_case("food", &["food"]; "exact match")]
    #[test_case("foods", &[]; "overshot")]
    #[test_case("q", &[]; "unknown first char")]
    #[test_case("quux", &[]; "unknown full key")]
    #[test_case("", &[]; "empty string")]
    #[test]
    fn candidate_strings_works(k: &str, expected: &[&str]) {
        let expected: Vec<String> = expected.iter().map(|s| String::from(*s)).collect();

        // Each key is stored with its position in the list as the value.
        let pairs: Vec<(&str, usize)> = ["fool", "fold", "food"].into_iter().zip(0..).collect();
        let trie = Trie::from_str_keys(pairs).unwrap();

        assert_eq!(trie.candidate_strings(k), expected);
    }

    // Default handler used below: maps every key to its successor.
    fn usize_default_handler(n: &usize) -> Option<usize> {
        Some(n + 1)
    }

    #[test_case(&[42], QueryResult::Val(1); "exact single should match from the Try")]
    #[test_case(&[12, 13], QueryResult::Val(2); "exact multi should match from the Try")]
    #[test_case(&[69], QueryResult::Val(70); "missing single should be defaulted")]
    #[test_case(&[69, 420], QueryResult::Missing; "missing multi should always be missing")]
    #[test_case(&[12], QueryResult::Partial; "partial should remain partial")]
    #[test]
    fn default_handlers_work(k: &[usize], expected: QueryResult<usize>) {
        let mut trie = Trie::from_pairs(vec![(vec![42], 1), (vec![12, 13], 2)]).unwrap();
        trie.set_default(usize_default_handler);

        assert_eq!(trie.get(k), expected);

        // `get_exact` must agree with `get` once non-values are collapsed to `None`.
        let expected_opt = Option::<usize>::from(expected);
        assert_eq!(trie.get_exact(k), expected_opt);
    }

    #[test]
    fn wildcards_match_correctly() {
        fn is_valid(c: &char) -> bool {
            *c == 'i' || *c == 'a'
        }

        let wildcard = WildCard::Wild(is_valid);

        assert!(wildcard == 'a');
        assert!(wildcard != 'b');
    }
}