1use crate::{LookupResult, Node};
5use ddk_dlc::Error;
6
7#[derive(Clone)]
9pub struct DigitTrie<T> {
10 store: Vec<Node<DigitLeaf<T>, DigitNode<T>>>,
13 root: Option<usize>,
14 pub(crate) base: usize,
15}
16
17pub struct DigitTrieDump<T>
19where
20 T: Clone,
21{
22 pub node_data: Vec<DigitNodeData<T>>,
24 pub root: Option<usize>,
26 pub base: usize,
28}
29
30impl<T> DigitTrie<T>
31where
32 T: Clone,
33{
34 pub fn dump(&self) -> DigitTrieDump<T> {
36 let node_data = self.store.iter().map(|x| x.get_data()).collect();
37 DigitTrieDump {
38 root: self.root,
39 base: self.base,
40 node_data,
41 }
42 }
43
44 pub fn from_dump(dump: DigitTrieDump<T>) -> DigitTrie<T> {
46 let DigitTrieDump {
47 root,
48 base,
49 node_data,
50 } = dump;
51 let store = node_data.into_iter().map(|x| Node::from_data(x)).collect();
52 DigitTrie { store, root, base }
53 }
54}
55
/// Flat representation of one trie node, as stored in a [`DigitTrieDump`].
pub struct DigitNodeData<T> {
    /// Digits on the edge leading to this node.
    pub prefix: Vec<usize>,
    /// Value stored at this node; always `Some` for leaves.
    pub data: Option<T>,
    /// Child slots (one per digit) for inner nodes, `None` for leaves.
    pub children: Option<Vec<Option<usize>>>,
}
65
66impl<T> Node<DigitLeaf<T>, DigitNode<T>>
67where
68 T: Clone,
69{
70 fn get_data(&self) -> DigitNodeData<T> {
71 match self {
72 Node::Leaf(l) => DigitNodeData {
73 data: Some(l.data.clone()),
74 prefix: l.prefix.clone(),
75 children: None,
76 },
77 Node::Node(n) => DigitNodeData {
78 data: n.data.clone(),
79 prefix: n.prefix.clone(),
80 children: Some(n.children.clone()),
81 },
82 Node::None => unreachable!(),
83 }
84 }
85
86 fn from_data(data: DigitNodeData<T>) -> Node<DigitLeaf<T>, DigitNode<T>> {
87 match data.children {
88 Some(c) => Node::Node(DigitNode {
89 children: c,
90 prefix: data.prefix,
91 data: data.data,
92 }),
93 None => Node::Leaf(DigitLeaf {
94 prefix: data.prefix,
95 data: data.data.unwrap(),
96 }),
97 }
98 }
99}
100
101pub struct DigitTrieIter<'a, T> {
104 trie: &'a DigitTrie<T>,
105 index_stack: Vec<(Option<usize>, isize)>,
110 cur_prefix: Vec<Vec<usize>>,
111}
112
113impl<'a, T> DigitTrieIter<'a, T> {
114 pub fn new(trie: &'a DigitTrie<T>) -> DigitTrieIter<'a, T> {
116 DigitTrieIter {
117 index_stack: vec![(trie.root, -1)],
118 trie,
119 cur_prefix: Vec::new(),
120 }
121 }
122
123 fn cur_prefix_append(&mut self, to_append: &[usize]) {
124 self.cur_prefix.push(to_append.to_vec());
125 }
126 fn cur_prefix_drop(&mut self) {
127 self.cur_prefix.pop();
128 }
129}
130
/// Leaf node: the end of an inserted path, always carrying a value.
#[derive(Clone)]
struct DigitLeaf<T> {
    /// Digits on the edge leading to this leaf.
    prefix: Vec<usize>,
    /// Value stored for the path ending here.
    data: T,
}
136
/// Inner node: a shared prefix with one child slot per digit and an optional
/// value for a path that ends exactly here.
#[derive(Clone)]
struct DigitNode<T> {
    /// Digits on the edge leading to this node.
    prefix: Vec<usize>,
    /// Value for a path ending at this node, if one was inserted.
    data: Option<T>,
    /// Child slots indexed by the next digit; entries are store indices.
    children: Vec<Option<usize>>,
}
143
/// Uniform access to the edge prefix carried by both leaves and inner nodes.
trait NodePrefix {
    /// Returns the node's prefix as an owned vector.
    fn get_node_prefix(&self) -> Vec<usize>;

    /// Replaces the node's prefix.
    fn set_node_prefix(&mut self, prefix: Vec<usize>);
}
148
149impl<T> NodePrefix for Node<DigitLeaf<T>, DigitNode<T>> {
150 fn get_node_prefix(&self) -> Vec<usize> {
151 match self {
152 Node::None => unreachable!(),
153 Node::Leaf(digit_leaf) => digit_leaf.prefix.clone(),
154 Node::Node(digit_node) => digit_node.prefix.clone(),
155 }
156 }
157
158 fn set_node_prefix(&mut self, prefix: Vec<usize>) {
159 let pref = match self {
160 Node::None => unreachable!(),
161 Node::Leaf(digit_leaf) => &mut digit_leaf.prefix,
162 Node::Node(digit_node) => &mut digit_node.prefix,
163 };
164
165 *pref = prefix;
166 }
167}
168
/// Returns the longest run of leading digits shared by `a` and `b`.
fn get_common_prefix(a: &[usize], b: &[usize]) -> Vec<usize> {
    let shared = a.iter().zip(b).take_while(|(x, y)| x == y).count();
    a[..shared].to_vec()
}
177
178fn insert_new_leaf<T>(trie: &mut DigitTrie<T>, path: &[usize], data: T) -> usize {
179 trie.store.push(Node::Leaf(DigitLeaf {
180 prefix: path.to_vec(),
181 data,
182 }));
183 trie.store.len() - 1
184}
185
/// Returns `true` when `prefix` is a (not necessarily proper) prefix of
/// `value`. An empty `prefix` is a prefix of everything.
fn is_prefix_of(prefix: &[usize], value: &[usize]) -> bool {
    // `starts_with` performs the length check and element comparison that
    // were previously hand-rolled with an index loop.
    value.starts_with(prefix)
}
198
impl<'a, T> Iterator for DigitTrieIter<'a, T> {
    type Item = LookupResult<'a, T, usize>;

    /// Yields the next stored value with its full path, walking the trie
    /// depth-first: a node's own value comes before the values below it, and
    /// child slots are visited in increasing digit order.
    ///
    /// Each `index_stack` entry is `(node index or empty slot, next step)`:
    /// step `-1` means the node's own value has not been considered yet, and
    /// `0..base` is the next child slot to visit.
    // NOTE(review): empty slots and finished nodes are handled by recursive
    // calls to `next`, so one call can recurse once per slot scanned between
    // two results.
    fn next(&mut self) -> Option<Self::Item> {
        let popped = self.index_stack.pop();
        let (cur_index, mut cur_child) = match popped {
            // Nothing left to visit.
            None => return None,
            Some((cur_index, cur_child)) => match cur_index {
                // Empty child slot: skip it and continue with the next entry.
                None => {
                    return self.next();
                }
                Some(cur_index) => (cur_index, cur_child),
            },
        };

        match &self.trie.store[cur_index] {
            Node::None => unreachable!(),
            // A leaf is yielded once and not pushed back; its full path is the
            // recorded ancestor prefixes followed by its own prefix.
            Node::Leaf(digit_leaf) => Some(LookupResult {
                value: &digit_leaf.data,
                path: self
                    .cur_prefix
                    .iter()
                    .filter(|x| !x.is_empty())
                    .flatten()
                    .chain(digit_leaf.prefix.iter())
                    .cloned()
                    .collect(),
            }),
            Node::Node(digit_node) => {
                let node_prefix = digit_node.prefix.clone();

                if cur_child >= (self.trie.base as isize) {
                    // All child slots visited: leave this node's prefix and
                    // continue with whatever is below it on the stack.
                    self.cur_prefix_drop();
                    self.next()
                } else {
                    // NOTE(review): this clone may be avoidable since
                    // `digit_node` borrows from the trie rather than `self` —
                    // confirm before removing.
                    let cur_children = digit_node.children.clone();
                    if cur_child == -1 {
                        match &digit_node.data {
                            // Yield the node's own value first and resume at
                            // child slot 0 on the next call. The prefix is not
                            // yet recorded, so the node's prefix is appended
                            // explicitly.
                            Some(data) => {
                                self.index_stack.push((Some(cur_index), cur_child + 1));
                                return Some(LookupResult {
                                    value: data,
                                    path: self
                                        .cur_prefix
                                        .iter()
                                        .filter(|x| !x.is_empty())
                                        .flatten()
                                        .chain(digit_node.prefix.iter())
                                        .cloned()
                                        .collect(),
                                });
                            }
                            // No value here: go straight to the children.
                            _ => {
                                cur_child += 1;
                            }
                        }
                    }
                    // Entering the children: record this node's prefix so
                    // their paths include it.
                    if cur_child == 0 {
                        self.cur_prefix_append(&node_prefix);
                    }
                    while cur_child < (self.trie.base as isize) {
                        // Schedule the rest of this node (next slot) underneath
                        // the child, then recurse into the child.
                        self.index_stack.push((Some(cur_index), cur_child + 1));
                        self.index_stack
                            .push((cur_children[cur_child as usize], -1));
                        match self.next() {
                            // NOTE(review): a `None` here appears to mean the
                            // whole stack was consumed, making this pop a
                            // no-op — confirm.
                            None => {
                                self.index_stack.pop();
                                cur_child += 1;
                            }
                            Some(res) => {
                                return Some(res);
                            }
                        };
                    }
                    // Every child slot exhausted without producing a result.
                    self.cur_prefix_drop();
                    self.index_stack.pop();
                    None
                }
            }
        }
    }
}
281
282impl<T> DigitTrie<T> {
283 pub fn new(base: usize) -> DigitTrie<T> {
285 DigitTrie {
286 store: Vec::new(),
287 root: None,
288 base,
289 }
290 }
291
292 pub fn insert<F>(&mut self, path: &[usize], get_data: &mut F) -> Result<(), Error>
294 where
295 F: FnMut(Option<T>) -> Result<T, Error>,
296 {
297 if path.is_empty() || path.iter().any(|x| x > &self.base) {
298 panic!("Invalid path");
299 }
300
301 self.root = Some(self.insert_internal(self.root, path, get_data)?);
302 Ok(())
303 }
304
305 fn insert_internal<F>(
306 &mut self,
307 cur_index: Option<usize>,
308 path: &[usize],
309 get_data: &mut F,
310 ) -> Result<usize, Error>
311 where
312 F: FnMut(Option<T>) -> Result<T, Error>,
313 {
314 match cur_index {
315 None => Ok(insert_new_leaf(self, path, get_data(None)?)),
316 Some(cur_index) => {
317 self.store.push(Node::None);
318 let mut cur_node = self.store.swap_remove(cur_index);
319 let prefix = cur_node.get_node_prefix();
320 if prefix == path {
321 match cur_node {
322 Node::Leaf(digit_leaf) => {
323 self.store[cur_index] = Node::Leaf(DigitLeaf {
324 data: get_data(Some(digit_leaf.data))?,
325 prefix: digit_leaf.prefix.to_vec(),
326 });
327 Ok(cur_index)
328 }
329 Node::Node(mut node) => {
330 node.data = Some(get_data(node.data)?);
331 self.store[cur_index] = Node::Node(node);
332 Ok(cur_index)
333 }
334 Node::None => unreachable!(),
335 }
336 } else {
337 let common_prefix = get_common_prefix(&prefix, path);
338 let suffix: Vec<_> = path.iter().skip(common_prefix.len()).cloned().collect();
339 if prefix == common_prefix {
340 match cur_node {
341 Node::Node(mut digit_node) => {
342 digit_node.children[suffix[0]] = Some(self.insert_internal(
343 digit_node.children[suffix[0]],
344 &suffix,
345 get_data,
346 )?);
347 self.store[cur_index] = Node::Node(DigitNode {
348 children: digit_node.children,
349 prefix: digit_node.prefix,
350 data: digit_node.data,
351 });
352 return Ok(cur_index);
353 }
354 Node::None => unreachable!(),
355 Node::Leaf(digit_leaf) => {
356 let mut new_children = Vec::new();
357 new_children.resize_with(self.base, || None);
358 new_children[suffix[0]] =
359 Some(insert_new_leaf(self, &suffix, get_data(None)?));
360 self.store[cur_index] = Node::Node(DigitNode {
361 prefix: digit_leaf.prefix,
362 children: new_children,
363 data: Some(digit_leaf.data),
364 });
365 return Ok(cur_index);
366 }
367 }
368 }
369
370 let mut new_children = Vec::new();
371 new_children.resize_with(self.base, || None);
372
373 let data = if path == common_prefix {
374 Some(get_data(None)?)
375 } else {
376 new_children[path[common_prefix.len()]] =
377 Some(insert_new_leaf(self, &suffix, get_data(None)?));
378 None
379 };
380
381 new_children[prefix[common_prefix.len()]] = Some(cur_index);
382 cur_node.set_node_prefix(
383 prefix.iter().skip(common_prefix.len()).cloned().collect(),
384 );
385 self.store.push(Node::Node(DigitNode {
386 children: new_children,
387 prefix: common_prefix,
388 data,
389 }));
390 self.store[cur_index] = cur_node;
391 Ok(self.store.len() - 1)
392 }
393 }
394 }
395 }
396
397 pub fn look_up(&'_ self, path: &[usize]) -> Option<Vec<LookupResult<'_, T, usize>>> {
399 self.look_up_internal(self.root, path)
400 }
401
402 fn look_up_internal(
403 &'_ self,
404 cur_index: Option<usize>,
405 path: &[usize],
406 ) -> Option<Vec<LookupResult<'_, T, usize>>> {
407 match cur_index {
408 None => None,
409 Some(cur_index) => match &self.store[cur_index] {
410 Node::None => unreachable!(),
411 Node::Leaf(digit_leaf) => {
412 let common_prefix = get_common_prefix(&digit_leaf.prefix, path);
413 if digit_leaf.prefix == common_prefix {
414 Some(vec![LookupResult {
415 path: digit_leaf.prefix.to_vec(),
416 value: &digit_leaf.data,
417 }])
418 } else {
419 None
420 }
421 }
422 Node::Node(digit_node) => {
423 if digit_node.prefix.len() > path.len()
424 || !is_prefix_of(&digit_node.prefix, path)
425 {
426 return None;
427 }
428
429 if digit_node.prefix.len() == path.len() {
430 return digit_node.data.as_ref().map(|data| {
431 vec![LookupResult {
432 value: data,
433 path: digit_node.prefix.clone(),
434 }]
435 });
436 }
437
438 let prefix = path[digit_node.prefix.len()];
439 let suffix: Vec<_> =
440 path.iter().skip(digit_node.prefix.len()).cloned().collect();
441 let res = self.look_up_internal(digit_node.children[prefix], &suffix);
442 match res {
443 None => digit_node.data.as_ref().map(|data| {
444 vec![LookupResult {
445 value: data,
446 path: digit_node.prefix.clone(),
447 }]
448 }),
449 Some(l_res) => match &digit_node.data {
450 None => Some(extend_lookup_res_paths(l_res, &digit_node.prefix)),
451 Some(data) => {
452 let mut up_res = extend_lookup_res_paths(l_res, &digit_node.prefix);
453 let mut final_res = vec![LookupResult {
454 value: data,
455 path: digit_node.prefix.clone(),
456 }];
457 final_res.append(&mut up_res);
458 Some(final_res)
459 }
460 },
461 }
462 }
463 },
464 }
465 }
466}
467
468fn extend_lookup_res_paths<'a, T>(
469 l_res: Vec<LookupResult<'a, T, usize>>,
470 path: &[usize],
471) -> Vec<LookupResult<'a, T, usize>> {
472 l_res
473 .into_iter()
474 .map(|x| LookupResult {
475 value: x.value,
476 path: path.iter().chain(x.path.iter()).cloned().collect(),
477 })
478 .collect()
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Path sets used by several tests; each inner `Vec` is inserted, in
    /// order, into a fresh base-16 trie.
    fn digit_trie_test_cases() -> Vec<Vec<Vec<usize>>> {
        vec![
            vec![
                vec![10, 11],
                vec![10, 12],
                vec![10, 13],
                vec![10, 14],
                vec![10, 15],
                vec![11],
                vec![12],
                vec![13, 0],
                vec![13, 1],
                vec![13, 2],
            ],
            vec![
                vec![0, 1, 2, 0, 10, 11],
                vec![0, 1, 2, 0, 10, 12],
                vec![0, 1, 2, 0, 10, 13],
                vec![0, 1, 2, 0, 10, 14],
                vec![0, 1, 2, 0, 10, 15],
                vec![0, 1, 2, 0, 11],
                vec![0, 1, 2, 0, 12],
                vec![0, 1, 2, 0, 13, 0],
                vec![0, 1, 2, 0, 13, 1],
                vec![0, 1, 2, 0, 13, 2],
            ],
        ]
    }

    #[test]
    fn digit_trie_returns_inserted_elements() {
        for test_case in digit_trie_test_cases() {
            let mut digit_trie = DigitTrie::<usize>::new(16);
            for (i, path) in test_case.iter().enumerate() {
                digit_trie.insert(path, &mut |_| Ok(i)).unwrap();
            }

            for (i, path) in test_case.iter().enumerate() {
                let actual = digit_trie.look_up(path);
                match actual {
                    None => panic!(),
                    Some(l_res) => {
                        assert_eq!(1, l_res.len());
                        assert_eq!(path, &l_res[0].path);
                        assert_eq!(i, *l_res[0].value);
                    }
                }
            }
        }
    }

    #[test]
    fn digit_trie_return_value_with_longer_path_query() {
        let mut digit_trie = DigitTrie::new(5);
        let expected_path = &[0, 1];
        let expected_value = 1;
        digit_trie
            .insert(expected_path, &mut |_| Ok(expected_value))
            .unwrap();
        let actual = digit_trie.look_up(&[0, 1, 2]);
        match actual {
            None => panic!(),
            Some(l_res) => {
                assert_eq!(1, l_res.len());
                assert_eq!(l_res[0].path, &[0, 1]);
                assert_eq!(*l_res[0].value, expected_value);
            }
        }
    }

    #[test]
    fn digit_trie_insert_on_common_prefix_query_longest_returns_both() {
        let mut digit_trie = DigitTrie::new(5);
        digit_trie.insert(&[0, 1, 2, 3], &mut |_| Ok(1)).unwrap();
        digit_trie.insert(&[0, 1, 2], &mut |_| Ok(2)).unwrap();
        let res = digit_trie.look_up(&[0, 1, 2, 3]).unwrap();

        assert_eq!(res.len(), 2);
        assert_eq!(vec![0, 1, 2], res[0].path);
        assert_eq!(vec![0, 1, 2, 3], res[1].path);
    }

    #[test]
    fn digit_trie_insert_on_common_prefix_query_shortest_returns_single() {
        let mut digit_trie = DigitTrie::new(5);
        digit_trie.insert(&[0, 1, 2, 3], &mut |_| Ok(1)).unwrap();
        digit_trie.insert(&[0, 1, 2], &mut |_| Ok(2)).unwrap();
        let res = digit_trie.look_up(&[0, 1, 2]).unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(vec![0, 1, 2], res[0].path);
    }

    #[test]
    fn digit_trie_insert_on_common_prefix_query_longer_non_existing_returns_single() {
        let mut digit_trie = DigitTrie::new(5);
        digit_trie.insert(&[0, 1, 2, 3], &mut |_| Ok(1)).unwrap();
        digit_trie.insert(&[0, 1, 2], &mut |_| Ok(2)).unwrap();
        let res = digit_trie.look_up(&[0, 1, 2, 4]).unwrap();

        assert_eq!(res.len(), 1);
        assert_eq!(vec![0, 1, 2], res[0].path);
    }

    #[test]
    fn digit_trie_insert_on_leaf_returns_both() {
        let mut digit_trie = DigitTrie::new(5);
        digit_trie.insert(&[0, 1, 2], &mut |_| Ok(1)).unwrap();
        digit_trie.insert(&[0, 1, 2, 3], &mut |_| Ok(2)).unwrap();
        let res = digit_trie.look_up(&[0, 1, 2, 3]).unwrap();

        assert_eq!(res.len(), 2);
        assert_eq!(vec![0, 1, 2], res[0].path);
        assert_eq!(vec![0, 1, 2, 3], res[1].path);
    }

    #[test]
    fn digit_trie_query_non_inserted_returns_not_found() {
        let mut digit_trie = DigitTrie::new(5);
        digit_trie.insert(&[0, 1, 2], &mut |_| Ok(1)).unwrap();
        digit_trie.insert(&[1, 2, 3], &mut |_| Ok(2)).unwrap();
        assert!(digit_trie.look_up(&[0, 1, 3]).is_none());
        assert!(digit_trie.look_up(&[1, 2, 5]).is_none());
        assert!(digit_trie.look_up(&[0, 0, 0]).is_none());
    }

    #[test]
    fn digit_trie_replace_data_when_insert_on_existing_path() {
        let mut digit_trie = DigitTrie::new(5);
        let path = &[0, 1, 2, 3];
        digit_trie.insert(path, &mut |_| Ok(1)).unwrap();
        digit_trie.insert(path, &mut |_| Ok(2)).unwrap();
        let res = digit_trie.look_up(path);
        match res {
            None => panic!(),
            Some(l_res) => assert_eq!(*l_res[0].value, 2),
        }
    }

    #[test]
    fn digit_trie_insert_on_mid_node_returns_all() {
        let mut digit_trie = DigitTrie::new(5);
        digit_trie.insert(&[0, 1, 2, 3], &mut |_| Ok(1)).unwrap();
        digit_trie.insert(&[0, 1, 2, 4], &mut |_| Ok(2)).unwrap();
        digit_trie.insert(&[0, 1, 2], &mut |_| Ok(3)).unwrap();

        let res = digit_trie.look_up(&[0, 1, 2, 3]).unwrap();

        assert_eq!(2, res.len());
        assert_eq!(*res[0].value, 3);
        assert_eq!(*res[1].value, 1);

        let res = digit_trie.look_up(&[0, 1, 2, 4]).unwrap();

        assert_eq!(2, res.len());
        assert_eq!(*res[0].value, 3);
        assert_eq!(*res[1].value, 2);
    }

    /// Fails the test if the lookup produced any result.
    fn assert_not_found<T>(res: Option<Vec<LookupResult<T, usize>>>) {
        assert!(res.is_none(), "expected no lookup result");
    }

    #[test]
    fn digit_trie_return_not_found_if_not_inserted() {
        let mut digit_trie = DigitTrie::new(5);
        digit_trie.insert(&[0, 1, 2], &mut |_| Ok(1)).unwrap();
        digit_trie.insert(&[0, 1, 3], &mut |_| Ok(2)).unwrap();
        digit_trie.insert(&[4, 1, 2], &mut |_| Ok(3)).unwrap();

        assert_not_found(digit_trie.look_up(&[1, 2, 5]));
        assert_not_found(digit_trie.look_up(&[2]));
        assert_not_found(digit_trie.look_up(&[1]));
        assert_not_found(digit_trie.look_up(&[1, 3]));
        assert_not_found(digit_trie.look_up(&[2, 1, 3]));
    }

    #[test]
    fn digit_trie_returns_inserted_values_when_iterating() {
        for test_case in digit_trie_test_cases() {
            let mut digit_trie = DigitTrie::<usize>::new(16);
            for (i, path) in test_case.iter().enumerate() {
                digit_trie.insert(path, &mut |_| Ok(i)).unwrap();
            }

            let digit_trie_iter = DigitTrieIter::new(&digit_trie);

            let mut count = 0;
            for (i, res) in digit_trie_iter.enumerate() {
                assert_eq!(test_case[i], res.path);
                assert_eq!(i, *res.value);
                count += 1;
            }

            assert_eq!(test_case.len(), count);
        }
    }

    #[test]
    fn digit_trie_returns_node_values_when_iterating() {
        // Each case lists the same kind of paths in a different insertion
        // order; the order determines which splits/extensions build the trie.
        let test_cases = vec![
            vec![vec![0, 1, 2, 3], vec![0, 1, 2, 4], vec![0, 1, 2]],
            vec![vec![0, 1, 2], vec![0, 1, 2, 3], vec![0, 1, 2, 4]],
            vec![
                vec![0, 1],
                vec![0, 1, 2],
                vec![0, 1, 2, 3],
                vec![0, 1, 2, 4],
            ],
            vec![
                vec![0, 1, 2],
                vec![0, 1, 2, 3],
                vec![0, 1],
                vec![0, 1, 2, 4],
            ],
            vec![
                vec![0, 1, 2, 3],
                vec![0, 1, 2, 4],
                vec![0, 1],
                vec![0, 1, 2],
            ],
        ];
        for test_case in test_cases {
            // A fresh trie per case, otherwise later cases only overwrite
            // values in a trie whose shape was fixed by earlier cases.
            let mut digit_trie = DigitTrie::new(5);
            for (i, test_path) in test_case.iter().enumerate() {
                digit_trie.insert(test_path, &mut |_| Ok(i)).unwrap();
            }

            let digit_trie_iter = DigitTrieIter::new(&digit_trie);

            let mut count = 0;
            for res in digit_trie_iter {
                assert_eq!(
                    *res.value,
                    test_case.iter().position(|x| x == &res.path).unwrap()
                );
                count += 1;
            }

            assert_eq!(test_case.len(), count);
        }
    }

    #[test]
    fn digit_trie_iterate_gets_all_inserted_values() {
        let mut digit_trie = DigitTrie::new(2);
        let paths = vec![vec![0, 0], vec![0, 1], vec![1, 0, 0], vec![0, 1, 0]];
        let mut counter = 0;
        let mut get_value = |_: Option<usize>| -> Result<usize, Error> {
            let res = counter;
            counter += 1;
            Ok(res)
        };

        for path in &paths {
            digit_trie.insert(path, &mut get_value).unwrap();
        }

        let iter = DigitTrieIter::new(&digit_trie);

        let mut unordered = iter.map(|x| *x.value).collect::<Vec<_>>();

        assert_eq!(paths.len(), unordered.len());

        unordered.sort_unstable();

        for (prev_index, i) in unordered.iter().skip(1).enumerate() {
            assert_eq!(*i, prev_index + 1);
        }
    }
}