1use std::{collections::BTreeMap, fmt, ops::Range, sync::Arc};
4
5#[derive(Clone, PartialEq, Eq)]
12pub struct Trie<K, V>
13where
14 K: Clone + PartialEq + Ord,
15 V: Clone,
16{
17 nodes: Arc<[Node<K>]>,
18 values: Arc<[V]>,
19 n_roots: usize,
20}
21
22impl<K, V> fmt::Debug for Trie<K, V>
23where
24 K: Clone + PartialEq + Ord + fmt::Debug,
25 V: Clone + fmt::Debug,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("Trie")
29 .field("nodes", &self.nodes)
30 .field("values", &self.values)
31 .field("n_roots", &self.n_roots)
32 .finish()
33 }
34}
35
36impl<K, V> Default for Trie<K, V>
37where
38 K: Clone + PartialEq + Ord,
39 V: Clone,
40{
41 fn default() -> Self {
42 Self {
43 nodes: Arc::from(Vec::new()),
44 values: Arc::from(Vec::new()),
45 n_roots: 0,
46 }
47 }
48}
49
50impl<K, V> Trie<K, V>
51where
52 K: Clone + PartialEq + Ord,
53 V: Clone,
54{
55 pub fn try_from_iter(it: impl IntoIterator<Item = (Vec<K>, V)>) -> Result<Self, &'static str> {
60 let mut roots = Vec::new();
61 for (key, value) in it.into_iter() {
62 insert(key, value, &mut roots)?;
63 }
64
65 if roots.is_empty() {
66 return Ok(Self::default());
67 }
68
69 let mut nodes = Vec::new();
70 let mut values = Vec::new();
71 let (_, n_roots) = flatten(roots, &mut nodes, &mut values);
72
73 Ok(Trie {
74 nodes: Arc::from(nodes),
75 values: Arc::from(values),
76 n_roots,
77 })
78 }
79
80 pub fn merge(self, other: Self) -> Result<Self, &'static str> {
86 if self.is_empty() {
87 return Ok(other);
88 } else if other.is_empty() {
89 return Ok(self);
90 }
91
92 let mut pairs = Vec::with_capacity(self.len() + other.len());
93 self.extract_pairs(&mut pairs, Vec::new(), 0..self.n_roots);
94 other.extract_pairs(&mut pairs, Vec::new(), 0..other.n_roots);
95
96 Self::try_from_iter(pairs)
97 }
98
99 pub fn merge_overriding(self, other: Self) -> Result<Self, &'static str> {
105 if self.is_empty() {
106 return Ok(other);
107 } else if other.is_empty() {
108 return Ok(self);
109 }
110
111 let mut pairs = Vec::with_capacity(self.len() + other.len());
112 self.extract_pairs(&mut pairs, Vec::new(), 0..self.n_roots);
113
114 let mut m = BTreeMap::from_iter(pairs.drain(..));
115 other.extract_pairs(&mut pairs, Vec::new(), 0..other.n_roots);
116 m.extend(pairs.drain(..));
117
118 Self::try_from_iter(m)
119 }
120
121 fn extract_pairs(&self, pairs: &mut Vec<(Vec<K>, V)>, key: Vec<K>, indices: Range<usize>) {
122 for i in indices {
123 let node = &self.nodes[i];
124 let mut child_key = key.clone();
125 child_key.push(node.key.clone());
126
127 match node.data {
128 Data::Leaf { i } => pairs.push((child_key, self.values[i].clone())),
129
130 Data::Internal {
131 child_start,
132 n_children,
133 } => self.extract_pairs(pairs, child_key, child_start..child_start + n_children),
134 }
135 }
136 }
137
138 pub fn get<'a>(&'a self, key: &[K]) -> QueryResult<'a, V> {
144 if key.is_empty() {
145 return QueryResult::Missing;
146 }
147
148 let mut indices = 0..self.n_roots;
149 let mut key_index = 0;
150
151 'outer: while key_index < key.len() {
152 let target = &key[key_index];
153
154 for i in indices {
156 let node = &self.nodes[i];
157 if &node.key == target {
158 key_index += 1;
159
160 match node.data {
161 Data::Leaf { i } => {
162 return if key_index == key.len() {
163 QueryResult::Val(&self.values[i])
164 } else {
165 QueryResult::Missing
166 };
167 }
168
169 Data::Internal {
170 child_start,
171 n_children,
172 } => {
173 indices = child_start..child_start + n_children;
174 continue 'outer;
175 }
176 }
177 } else if node.key > *target {
178 return QueryResult::Missing;
180 }
181 }
182
183 return QueryResult::Missing;
184 }
185
186 QueryResult::Partial
187 }
188
189 pub fn get_exact<'a>(&'a self, key: &[K]) -> Option<&'a V> {
193 self.get(key).into()
194 }
195
196 pub fn len(&self) -> usize {
198 self.nodes.iter().filter(|n| n.is_leaf()).count()
199 }
200
201 pub fn is_empty(&self) -> bool {
203 self.nodes.is_empty()
204 }
205}
206
207impl<V> Trie<char, V>
209where
210 V: Clone,
211{
212 pub fn from_str_keys(pairs: Vec<(&str, V)>) -> Result<Self, &'static str> {
214 let char_pairs: Vec<(Vec<char>, V)> = pairs
215 .into_iter()
216 .map(|(k, v)| (k.chars().collect(), v))
217 .collect();
218
219 Self::try_from_iter(char_pairs)
220 }
221
222 pub fn get_str<'a>(&'a self, key: &str) -> QueryResult<'a, V> {
226 self.get(&key.chars().collect::<Vec<_>>())
227 }
228
229 pub fn get_str_exact<'a>(&'a self, key: &str) -> Option<&'a V> {
233 self.get_exact(&key.chars().collect::<Vec<_>>())
234 }
235}
236
237#[derive(Debug, Clone, PartialEq, Eq)]
242struct Node<K>
243where
244 K: Clone + PartialEq + Ord,
245{
246 key: K,
247 data: Data,
248}
249
250impl<K> Node<K>
251where
252 K: Clone + PartialEq + PartialOrd + Ord,
253{
254 fn new_internal(key: K, child_start: usize, n_children: usize) -> Self {
255 Self {
256 key,
257 data: Data::Internal {
258 child_start,
259 n_children,
260 },
261 }
262 }
263
264 fn new_leaf(key: K, i: usize) -> Self {
265 Self {
266 key,
267 data: Data::Leaf { i },
268 }
269 }
270
271 fn is_leaf(&self) -> bool {
272 matches!(self.data, Data::Leaf { .. })
273 }
274}
275
/// What a node of the flattened trie refers to: a run of child nodes or a stored value.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum Data {
    /// A node with children, which occupy `child_start..child_start + n_children` in the
    /// trie's node array.
    Internal {
        /// Index of the first child in the node array.
        child_start: usize,
        /// Number of consecutive children starting at `child_start`.
        n_children: usize,
    },
    /// A node that carries a value.
    Leaf {
        /// Index of the value in the trie's value array.
        i: usize,
    },
}
293
294#[derive(Debug)]
295struct BuildNode<K, V>
296where
297 K: PartialEq + Ord,
298{
299 k: K,
300 data: BuildNodeData<K, V>,
301}
302
303#[derive(Debug)]
304enum BuildNodeData<K, V>
305where
306 K: PartialEq + Ord,
307{
308 Internal(Vec<BuildNode<K, V>>),
309 Leaf(V),
310}
311
312fn insert<K, V>(
313 mut key: Vec<K>,
314 v: V,
315 current: &mut Vec<BuildNode<K, V>>,
316) -> Result<(), &'static str>
317where
318 K: PartialEq + Ord,
319{
320 for n in current.iter_mut() {
321 if key[0] == n.k {
322 if key.len() <= 1 {
323 return Err("duplicate entry for key");
324 }
325
326 key.remove(0);
327 return match &mut n.data {
328 BuildNodeData::Internal(nodes) => insert(key, v, nodes),
329 BuildNodeData::Leaf(_) => Err("attempt to insert into value node"),
330 };
331 }
332 }
333
334 let k = key.remove(0);
335
336 if key.is_empty() {
337 current.push(BuildNode {
338 k,
339 data: BuildNodeData::Leaf(v),
340 });
341 } else {
342 let mut children = vec![];
343 insert(key, v, &mut children)?;
344 current.push(BuildNode {
345 k,
346 data: BuildNodeData::Internal(children),
347 });
348 }
349
350 Ok(())
351}
352
353fn flatten<K, V>(
354 mut roots: Vec<BuildNode<K, V>>,
355 nodes: &mut Vec<Node<K>>,
356 values: &mut Vec<V>,
357) -> (usize, usize)
358where
359 K: Clone + PartialEq + Ord,
360 V: Clone,
361{
362 roots.sort_by(|l, r| l.k.cmp(&r.k));
363
364 let child_start = nodes.len();
365 let n_children = roots.len();
366
367 let mut child_stack = Vec::new();
368
369 for BuildNode { k, data } in roots.into_iter() {
371 match data {
372 BuildNodeData::Internal(children) => {
373 let i = nodes.len();
374 nodes.push(Node::new_internal(k, 0, children.len()));
375 child_stack.push((i, children));
376 }
377
378 BuildNodeData::Leaf(v) => {
379 let i = values.len();
380 values.push(v);
381 nodes.push(Node::new_leaf(k, i))
382 }
383 }
384 }
385
386 for (i, children) in child_stack.into_iter() {
389 let (start, _) = flatten(children, nodes, values);
390 match &mut nodes[i] {
391 Node {
392 data: Data::Internal { child_start, .. },
393 ..
394 } => {
395 *child_start = start;
396 }
397
398 _ => unreachable!(),
399 }
400 }
401
402 (child_start, n_children)
403}
404
/// A plain function pointer that takes a reference to a `K` and optionally produces a `V`.
// NOTE(review): nothing in the visible part of this file uses this alias; presumably it is a
// fallback for keys that are not found in a trie — confirm against its callers.
pub type DefaultMapping<K, V> = fn(&K) -> Option<V>;
410
/// The outcome of looking up a key in a [`Trie`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QueryResult<'a, V> {
    /// The key is stored in the trie; this is a reference to its value.
    Val(&'a V),
    /// The key ended on an internal node: it is a strict prefix of at least one stored key.
    Partial,
    /// The key is empty, or no stored key is equal to it or begins with it.
    Missing,
}
421
422impl<'a, V> From<Option<&'a V>> for QueryResult<'a, V> {
423 fn from(opt: Option<&'a V>) -> Self {
424 match opt {
425 Some(v) => QueryResult::Val(v),
426 None => QueryResult::Missing,
427 }
428 }
429}
430
431impl<'a, V> From<QueryResult<'a, V>> for Option<&'a V> {
432 fn from(q: QueryResult<'a, V>) -> Self {
433 match q {
434 QueryResult::Val(v) => Some(v),
435 _ => None,
436 }
437 }
438}
439
// Unit tests for construction, lookup and merging of `Trie`.
#[cfg(test)]
mod tests {
    use super::*;
    use simple_test_case::test_case;

    // Supplying the same key twice must be rejected at construction time.
    #[test]
    fn duplicate_keys_errors() {
        assert!(Trie::try_from_iter(vec![(vec![42], 1), (vec![42], 2)]).is_err());
    }

    // A key that extends an existing key would need children under a value node.
    #[test]
    fn children_under_a_value_node_errors() {
        assert!(Trie::try_from_iter(vec![(vec![42], 1), (vec![42, 69], 2)]).is_err());
    }

    // Exact keys give `Val`, strict prefixes give `Partial`, everything else is `Missing`.
    #[test_case("foo", QueryResult::Val(&1); "val 1")]
    #[test_case("bar", QueryResult::Val(&2); "val 2")]
    #[test_case("baz", QueryResult::Val(&3); "val 3")]
    #[test_case("ba", QueryResult::Partial; "partial 1")]
    #[test_case("fo", QueryResult::Partial; "partial 2")]
    #[test_case("barf", QueryResult::Missing; "overshoot")]
    #[test_case("have you any wool?", QueryResult::Missing; "fully missing")]
    #[test]
    fn get_works(key: &str, expected: QueryResult<'_, usize>) {
        let t = Trie::from_str_keys(vec![("foo", 1), ("bar", 2), ("baz", 3)]).unwrap();
        assert_eq!(t.get_str(key), expected);
    }

    // `get_exact` collapses both partial matches and misses to `None`.
    #[test_case(&[42], None; "partial should be None")]
    #[test_case(&[144], None; "missing should be None")]
    #[test_case(&[42, 69, 144], None; "overshoot should be None")]
    #[test_case(&[42, 69], Some(1); "exact should be Some")]
    #[test]
    fn get_exact_works(key: &[usize], expected: Option<usize>) {
        let t = Trie::try_from_iter(vec![(vec![42, 69], 1)]).unwrap();
        assert_eq!(t.get_exact(key), expected.as_ref());
    }

    // Same contract as `get_exact`, through the `&str` convenience wrapper.
    #[test_case("fo", None; "partial")]
    #[test_case("bar", None; "missing")]
    #[test_case("fool", None; "overshoot")]
    #[test_case("foo", Some(1); "found")]
    #[test]
    fn get_str_exact_works(key: &str, expected: Option<usize>) {
        let t = Trie::from_str_keys(vec![("foo", 1)]).unwrap();
        assert_eq!(t.get_str_exact(key), expected.as_ref());
    }

    // Merging tries with disjoint keys keeps every entry from both sides.
    #[test]
    fn merge_works() {
        let t1 = Trie::from_str_keys(vec![("foo", 1), ("bar", 2)]).unwrap();
        let t2 = Trie::from_str_keys(vec![("baz", 3), ("qux", 4)]).unwrap();

        let merged = t1.merge(t2).unwrap();

        assert_eq!(merged.get_str_exact("foo"), Some(&1));
        assert_eq!(merged.get_str_exact("bar"), Some(&2));
        assert_eq!(merged.get_str_exact("baz"), Some(&3));
        assert_eq!(merged.get_str_exact("qux"), Some(&4));
        assert_eq!(merged.len(), 4);
    }

    // A key present on both sides is a conflict for the plain `merge`.
    #[test]
    fn merge_conflicts_error() {
        let t1 = Trie::from_str_keys(vec![("foo", 1)]).unwrap();
        let t2 = Trie::from_str_keys(vec![("foo", 2)]).unwrap();

        assert!(t1.merge(t2).is_err());
    }

    // With `merge_overriding` the right-hand side wins for shared keys.
    #[test]
    fn merge_overriding_works() {
        let t1 = Trie::from_str_keys(vec![("foo", 1), ("bar", 2)]).unwrap();
        let t2 = Trie::from_str_keys(vec![("baz", 3), ("foo", 4)]).unwrap();

        let merged = t1.merge_overriding(t2).unwrap();

        assert_eq!(merged.get_str_exact("foo"), Some(&4));
        assert_eq!(merged.get_str_exact("bar"), Some(&2));
        assert_eq!(merged.get_str_exact("baz"), Some(&3));
        assert_eq!(merged.len(), 3);
    }

    // Shared keys do not make `merge_overriding` fail.
    #[test]
    fn merge_overriding_conflicts_are_ok() {
        let t1 = Trie::from_str_keys(vec![("foo", 1)]).unwrap();
        let t2 = Trie::from_str_keys(vec![("foo", 2)]).unwrap();

        assert!(t1.merge_overriding(t2).is_ok());
    }
}