1#![cfg_attr(not(doctest), doc = include_str!("../README.md"))]
16#![deny(
17 missing_docs,
18 missing_debug_implementations,
19 unreachable_pub,
20 rustdoc::broken_intra_doc_links,
21 unsafe_code
22)]
23#![warn(rust_2018_idioms)]
24
25mod u256;
26
27use std::{
28 collections::{BTreeMap, BTreeSet, VecDeque},
29 ops::RangeFrom,
30};
31
32use u256::U256;
33
/// Lookup table holding one bit mask of type `I` for each of the 256 possible
/// byte values.
#[derive(Debug, Clone)]
#[repr(transparent)]
struct MasksByByteSized<I>([I; 256]);

impl<I> Default for MasksByByteSized<I>
where
    I: Default + Copy,
{
    /// Builds a table in which every byte maps to `I::default()`.
    fn default() -> Self {
        let blank = I::default();
        MasksByByteSized([blank; 256])
    }
}
46
/// A per-byte mask table in one of the six supported mask widths.
// The variants differ a lot in size (256 masks each, from 1 to 32 bytes per
// mask); the lint is silenced instead of boxing the payloads.
#[allow(clippy::large_enum_variant)]
enum MasksByByte {
    /// Table whose masks are `u8` values.
    U8(MasksByByteSized<u8>),
    /// Table whose masks are `u16` values.
    U16(MasksByByteSized<u16>),
    /// Table whose masks are `u32` values.
    U32(MasksByByteSized<u32>),
    /// Table whose masks are `u64` values.
    U64(MasksByByteSized<u64>),
    /// Table whose masks are `u128` values.
    U128(MasksByByteSized<u128>),
    /// Table whose masks are `U256` values.
    U256(MasksByByteSized<U256>),
}
56
57impl MasksByByte {
58 fn new(used_bytes: BTreeSet<u8>) -> Self {
59 match used_bytes.len() {
60 ..=8 => MasksByByte::U8(MasksByByteSized::<u8>::new(used_bytes)),
61 9..=16 => {
62 MasksByByte::U16(MasksByByteSized::<u16>::new(used_bytes))
63 }
64 17..=32 => {
65 MasksByByte::U32(MasksByByteSized::<u32>::new(used_bytes))
66 }
67 33..=64 => {
68 MasksByByte::U64(MasksByByteSized::<u64>::new(used_bytes))
69 }
70 65..=128 => {
71 MasksByByte::U128(MasksByByteSized::<u128>::new(used_bytes))
72 }
73 129..=256 => {
74 MasksByByte::U256(MasksByByteSized::<U256>::new(used_bytes))
75 }
76 _ => unreachable!("There are only 256 possible u8s"),
77 }
78 }
79}
80
/// A trie over byte-slice keys whose per-node child masks are stored in the
/// integer type `I`.
///
/// Keys are borrowed for `'a`; `T` is the value stored with each key.
#[derive(Debug, Clone)]
pub struct TrieHardSized<'a, T, I> {
    /// Bit mask assigned to each possible byte value.
    masks: MasksByByteSized<I>,
    /// Every node of the trie. Lookups start at index 0 and follow the
    /// indices computed from each search node's `edge_start`.
    nodes: Vec<TrieState<'a, T, I>>,
}
88
89impl<'a, T, I> Default for TrieHardSized<'a, T, I>
90where
91 I: Default + Copy,
92{
93 fn default() -> Self {
94 Self {
95 masks: MasksByByteSized::default(),
96 nodes: Default::default(),
97 }
98 }
99}
100
/// Description of a trie node that is queued for construction.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct StateSpec<'a> {
    /// Key prefix shared by every entry stored at or below the node.
    prefix: &'a [u8],
    /// Position the node is expected to take in the trie's node vector.
    index: usize,
}
106
/// Branching information of a trie node that has children.
#[derive(Debug, Clone)]
struct SearchNode<I> {
    /// Union of the byte masks of every byte that leads to a child.
    mask: I,
    /// Index in the node vector of this node's first child; the remaining
    /// children follow it contiguously.
    edge_start: usize,
}
112
/// One node of the trie.
#[derive(Debug, Clone)]
enum TrieState<'a, T, I> {
    /// Terminal node holding a complete key and its value. The key may be
    /// longer than the path that leads here, so lookups compare the rest of it.
    Leaf(&'a [u8], T),
    /// Interior node with children only.
    Search(SearchNode<I>),
    /// Interior node that also stores a key/value of its own.
    SearchOrLeaf(&'a [u8], T, SearchNode<I>),
}
119
/// A trie over byte-slice keys that wraps the [`TrieHardSized`] variant whose
/// mask width fits the number of distinct bytes used by its keys.
// Each variant embeds a 256-entry mask table, so their sizes differ widely;
// the lint is silenced instead of boxing the payloads.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
pub enum TrieHard<'a, T> {
    /// Trie using `u8` masks.
    U8(TrieHardSized<'a, T, u8>),
    /// Trie using `u16` masks.
    U16(TrieHardSized<'a, T, u16>),
    /// Trie using `u32` masks.
    U32(TrieHardSized<'a, T, u32>),
    /// Trie using `u64` masks.
    U64(TrieHardSized<'a, T, u64>),
    /// Trie using `u128` masks.
    U128(TrieHardSized<'a, T, u128>),
    /// Trie using `U256` masks.
    U256(TrieHardSized<'a, T, U256>),
}
157
158impl<'a, T> Default for TrieHard<'a, T> {
159 fn default() -> Self {
160 TrieHard::U8(TrieHardSized::default())
161 }
162}
163
164impl<'a, T> TrieHard<'a, T>
165where
166 T: 'a + Copy,
167{
168 pub fn new(values: Vec<(&'a [u8], T)>) -> Self {
190 if values.is_empty() {
191 return Self::default();
192 }
193
194 let used_bytes = values
195 .iter()
196 .flat_map(|(k, _)| k.iter())
197 .cloned()
198 .collect::<BTreeSet<_>>();
199
200 let masks = MasksByByte::new(used_bytes);
201
202 match masks {
203 MasksByByte::U8(masks) => {
204 TrieHard::U8(TrieHardSized::<'_, _, u8>::new(masks, values))
205 }
206 MasksByByte::U16(masks) => {
207 TrieHard::U16(TrieHardSized::<'_, _, u16>::new(masks, values))
208 }
209 MasksByByte::U32(masks) => {
210 TrieHard::U32(TrieHardSized::<'_, _, u32>::new(masks, values))
211 }
212 MasksByByte::U64(masks) => {
213 TrieHard::U64(TrieHardSized::<'_, _, u64>::new(masks, values))
214 }
215 MasksByByte::U128(masks) => {
216 TrieHard::U128(TrieHardSized::<'_, _, u128>::new(masks, values))
217 }
218 MasksByByte::U256(masks) => {
219 TrieHard::U256(TrieHardSized::<'_, _, U256>::new(masks, values))
220 }
221 }
222 }
223
224 pub fn get<K: AsRef<[u8]>>(&self, raw_key: K) -> Option<T> {
238 match self {
239 TrieHard::U8(trie) => trie.get(raw_key),
240 TrieHard::U16(trie) => trie.get(raw_key),
241 TrieHard::U32(trie) => trie.get(raw_key),
242 TrieHard::U64(trie) => trie.get(raw_key),
243 TrieHard::U128(trie) => trie.get(raw_key),
244 TrieHard::U256(trie) => trie.get(raw_key),
245 }
246 }
247
248 pub fn get_from_bytes(&self, key: &[u8]) -> Option<T> {
260 match self {
261 TrieHard::U8(trie) => trie.get_from_bytes(key),
262 TrieHard::U16(trie) => trie.get_from_bytes(key),
263 TrieHard::U32(trie) => trie.get_from_bytes(key),
264 TrieHard::U64(trie) => trie.get_from_bytes(key),
265 TrieHard::U128(trie) => trie.get_from_bytes(key),
266 TrieHard::U256(trie) => trie.get_from_bytes(key),
267 }
268 }
269
270 pub fn iter(&self) -> TrieIter<'_, 'a, T> {
285 match self {
286 TrieHard::U8(trie) => TrieIter::U8(trie.iter()),
287 TrieHard::U16(trie) => TrieIter::U16(trie.iter()),
288 TrieHard::U32(trie) => TrieIter::U32(trie.iter()),
289 TrieHard::U64(trie) => TrieIter::U64(trie.iter()),
290 TrieHard::U128(trie) => TrieIter::U128(trie.iter()),
291 TrieHard::U256(trie) => TrieIter::U256(trie.iter()),
292 }
293 }
294
295 pub fn prefix_search<K: AsRef<[u8]>>(
310 &self,
311 prefix: K,
312 ) -> TrieIter<'_, 'a, T> {
313 match self {
314 TrieHard::U8(trie) => TrieIter::U8(trie.prefix_search(prefix)),
315 TrieHard::U16(trie) => TrieIter::U16(trie.prefix_search(prefix)),
316 TrieHard::U32(trie) => TrieIter::U32(trie.prefix_search(prefix)),
317 TrieHard::U64(trie) => TrieIter::U64(trie.prefix_search(prefix)),
318 TrieHard::U128(trie) => TrieIter::U128(trie.prefix_search(prefix)),
319 TrieHard::U256(trie) => TrieIter::U256(trie.prefix_search(prefix)),
320 }
321 }
322}
323
/// Iterator over the entries of a [`TrieHard`]; each variant wraps the sized
/// iterator matching the trie's mask width.
#[derive(Debug)]
pub enum TrieIter<'b, 'a, T> {
    /// Iterator over a trie using `u8` masks.
    U8(TrieIterSized<'b, 'a, T, u8>),
    /// Iterator over a trie using `u16` masks.
    U16(TrieIterSized<'b, 'a, T, u16>),
    /// Iterator over a trie using `u32` masks.
    U32(TrieIterSized<'b, 'a, T, u32>),
    /// Iterator over a trie using `u64` masks.
    U64(TrieIterSized<'b, 'a, T, u64>),
    /// Iterator over a trie using `u128` masks.
    U128(TrieIterSized<'b, 'a, T, u128>),
    /// Iterator over a trie using `U256` masks.
    U256(TrieIterSized<'b, 'a, T, U256>),
}
340
/// One entry on the traversal stack of a trie iterator. The default value
/// refers to node 0 in the `Inner` stage.
#[derive(Debug, Default)]
struct TrieNodeIter {
    /// Index of the node in the trie's node vector.
    node_index: usize,
    /// How far the traversal of this node has progressed.
    stage: TrieNodeIterStage,
}
346
/// Progress of the iterator through a single node.
#[derive(Debug, Default)]
enum TrieNodeIterStage {
    /// The node itself has not been processed yet.
    #[default]
    Inner,
    /// `Child(i, n)`: child number `i` out of `n` has already been pushed for
    /// traversal; child `i + 1` is next.
    Child(usize, usize),
}
353
/// Iterator over the entries of a [`TrieHardSized`], driven by an explicit
/// stack of partially visited nodes.
#[derive(Debug)]
pub struct TrieIterSized<'b, 'a, T, I> {
    /// Nodes still to be (further) processed; the top of the stack is next.
    stack: Vec<TrieNodeIter>,
    /// The trie being iterated.
    trie: &'b TrieHardSized<'a, T, I>,
}
361
362impl<'b, 'a, T, I> TrieIterSized<'b, 'a, T, I> {
363 fn empty(trie: &'b TrieHardSized<'a, T, I>) -> Self {
364 Self {
365 stack: Default::default(),
366 trie,
367 }
368 }
369
370 fn new(trie: &'b TrieHardSized<'a, T, I>, node_index: usize) -> Self {
371 Self {
372 stack: vec![TrieNodeIter {
373 node_index,
374 stage: Default::default(),
375 }],
376 trie,
377 }
378 }
379}
380
381impl<'b, 'a, T> Iterator for TrieIter<'b, 'a, T>
382where
383 T: Copy,
384{
385 type Item = (&'a [u8], T);
386
387 fn next(&mut self) -> Option<Self::Item> {
388 match self {
389 TrieIter::U8(iter) => iter.next(),
390 TrieIter::U16(iter) => iter.next(),
391 TrieIter::U32(iter) => iter.next(),
392 TrieIter::U64(iter) => iter.next(),
393 TrieIter::U128(iter) => iter.next(),
394 TrieIter::U256(iter) => iter.next(),
395 }
396 }
397}
398
399impl<'a, T> FromIterator<&'a T> for TrieHard<'a, &'a T>
400where
401 T: 'a + AsRef<[u8]> + ?Sized,
402{
403 fn from_iter<I: IntoIterator<Item = &'a T>>(values: I) -> Self {
404 let values = values
405 .into_iter()
406 .map(|v| (v.as_ref(), v))
407 .collect::<Vec<_>>();
408
409 Self::new(values)
410 }
411}
412
413macro_rules! trie_impls {
414 ($($int_type:ty),+) => {
415 $(
416 trie_impls!(_impl $int_type);
417 )+
418 };
419
420 (_impl $int_type:ty) => {
421
422 impl SearchNode<$int_type> {
423 fn evaluate<T>(&self, c: u8, trie: &TrieHardSized<'_, T, $int_type>) -> Option<usize> {
424 let c_mask = trie.masks.0[c as usize];
425 let mask_res = self.mask & c_mask;
426 (mask_res > 0).then(|| {
427 let smaller_bits = mask_res - 1;
428 let smaller_bits_mask = smaller_bits & self.mask;
429 let index_offset = smaller_bits_mask.count_ones() as usize;
430 self.edge_start + index_offset
431 })
432 }
433 }
434
435 impl<'a, T> TrieHardSized<'a, T, $int_type>
436 where
437 T: Copy
438 {
439
440 pub fn get<K: AsRef<[u8]>>(&self, key: K) -> Option<T> {
458 self.get_from_bytes(key.as_ref())
459 }
460
461 pub fn get_from_bytes(&self, key: &[u8]) -> Option<T> {
477 let mut state = self.nodes.get(0)?;
478
479 for (i, c) in key.iter().enumerate() {
480
481 let next_state_opt = match state {
482 TrieState::Leaf(k, value) => {
483 return (
484 k.len() == key.len()
485 && k[i..] == key[i..]
486 ).then_some(*value)
487 }
488 TrieState::Search(search)
489 | TrieState::SearchOrLeaf(_, _, search) => {
490 search.evaluate(*c, self)
491 }
492 };
493
494 if let Some(next_state_index) = next_state_opt {
495 state = &self.nodes[next_state_index];
496 } else {
497 return None;
498 }
499 }
500
501 if let TrieState::Leaf(k, value)
502 | TrieState::SearchOrLeaf(k, value, _) = state
503 {
504 (k.len() == key.len()).then_some(*value)
505 } else {
506 None
507 }
508 }
509
510 pub fn iter(&self) -> TrieIterSized<'_, 'a, T, $int_type> {
529 TrieIterSized {
530 stack: vec![TrieNodeIter::default()],
531 trie: self
532 }
533 }
534
535
536 pub fn prefix_search<K: AsRef<[u8]>>(&self, prefix: K) -> TrieIterSized<'_, 'a, T, $int_type> {
555 let key = prefix.as_ref();
556 let mut node_index = 0;
557 let Some(mut state) = self.nodes.get(node_index) else {
558 return TrieIterSized::empty(self);
559 };
560
561 for (i, c) in key.iter().enumerate() {
562 let next_state_opt = match state {
563 TrieState::Leaf(k, _) => {
564 if k.len() == key.len() && k[i..] == key[i..] {
565 return TrieIterSized::new(self, node_index);
566 } else {
567 return TrieIterSized::empty(self);
568 }
569 }
570 TrieState::Search(search)
571 | TrieState::SearchOrLeaf(_, _, search) => {
572 search.evaluate(*c, self)
573 }
574 };
575
576 if let Some(next_state_index) = next_state_opt {
577 node_index = next_state_index;
578 state = &self.nodes[next_state_index];
579 } else {
580 return TrieIterSized::empty(self);
581 }
582 }
583
584 TrieIterSized::new(self, node_index)
585 }
586 }
587
588 impl<'a, T> TrieHardSized<'a, T, $int_type> where T: 'a + Copy {
589 fn new(masks: MasksByByteSized<$int_type>, values: Vec<(&'a [u8], T)>) -> Self {
590 let values = values.into_iter().collect::<Vec<_>>();
591 let sorted = values
592 .iter()
593 .map(|(k, v)| (*k, *v))
594 .collect::<BTreeMap<_, _>>();
595
596 let mut nodes = Vec::new();
597 let mut next_index = 1;
598
599 let root_state_spec = StateSpec {
600 prefix: &[],
601 index: 0,
602 };
603
604 let mut spec_queue = VecDeque::new();
605 spec_queue.push_back(root_state_spec);
606
607 while let Some(spec) = spec_queue.pop_front() {
608 debug_assert_eq!(spec.index, nodes.len());
609 let (state, next_specs) = TrieState::<'_, _, $int_type>::new(
610 spec,
611 next_index,
612 &masks.0,
613 &sorted,
614 );
615
616 next_index += next_specs.len();
617 spec_queue.extend(next_specs);
618 nodes.push(state);
619 }
620
621 TrieHardSized {
622 nodes,
623 masks,
624 }
625 }
626 }
627
628
        impl <'a, T> TrieState<'a, T, $int_type> where T: 'a + Copy {
            /// Builds the node described by `spec` and returns it together
            /// with the specs of its children.
            ///
            /// * `edge_start` - index the first child will get in the node
            ///   vector; later children get consecutive indices.
            /// * `byte_masks` - bit mask assigned to each byte value.
            /// * `sorted` - all entries of the trie, ordered by key.
            ///
            /// Panics if no key in `sorted` starts with `spec.prefix`.
            fn new(
                spec: StateSpec<'a>,
                edge_start: usize,
                byte_masks: &[$int_type; 256],
                sorted: &BTreeMap<&'a [u8], T>,
            ) -> (Self, Vec<StateSpec<'a>>) {
                let StateSpec { prefix, .. } = spec;

                let prefix_len = prefix.len();
                let next_prefix_len = prefix_len + 1;

                // Filled in by the closure below while it scans the entries.
                let mut prefix_match = None;
                let mut children_seen = 0;
                let mut last_seen = None;

                // Scan every entry whose key starts with `prefix` (they are
                // contiguous in the sorted map) and derive one child spec per
                // distinct next byte. Collecting into a map keyed by that byte
                // removes duplicates and orders the children by byte.
                let next_states_paired = sorted
                    .range(RangeFrom { start: prefix })
                    .take_while(|(key, _)| key.starts_with(prefix))
                    .filter_map(|(key, val)| {
                        children_seen += 1;
                        last_seen = Some((key, *val));

                        if *key == prefix {
                            // The prefix itself is a stored key.
                            prefix_match = Some((key, *val));
                            None
                        } else {
                            // `key` starts with `prefix` and differs from it,
                            // so it has a byte at `prefix_len`.
                            let next_c = key.get(prefix_len).unwrap();
                            let next_prefix = &key[..next_prefix_len];

                            Some((
                                *next_c,
                                StateSpec {
                                    prefix: next_prefix,
                                    // Placeholder; assigned further down.
                                    index: 0,
                                },
                            ))
                        }
                    })
                    .collect::<BTreeMap<_, _>>()
                    .into_iter()
                    .collect::<Vec<_>>();

                // Fails only when no entry matched the prefix at all.
                let (last_k, last_v) = last_seen.unwrap();

                // A single matching entry needs no further branching: store it
                // as a leaf even if its key is longer than `prefix`.
                if children_seen == 1 {
                    return (TrieState::Leaf(last_k, last_v), vec![]);
                }

                if next_states_paired.is_empty() {
                    return (TrieState::Leaf(last_k, last_v), vec![], );
                }

                let mut mask = Default::default();

                // Give each child its final index and record its byte in this
                // node's mask.
                let next_state_specs = next_states_paired
                    .into_iter()
                    .enumerate()
                    .map(|(i, (c, mut next_state))| {
                        let next_node = edge_start + i;
                        next_state.index = next_node;
                        mask |= byte_masks[c as usize];
                        next_state
                    })
                    .collect();

                let search_node = SearchNode { mask, edge_start };
                let state = match prefix_match {
                    Some((key, value)) => {
                        TrieState::SearchOrLeaf(key, value, search_node)
                    }
                    _ => TrieState::Search(search_node),
                };

                (state, next_state_specs)
            }
        }
712
713 impl MasksByByteSized<$int_type> {
714 fn new(used_bytes: BTreeSet<u8>) -> Self {
715 let mut mask = Default::default();
716 mask += 1;
717
718 let mut byte_masks = [Default::default(); 256];
719
720 for c in used_bytes.into_iter() {
721 byte_masks[c as usize] = mask;
722 mask <<= 1;
723
724 }
725
726 Self(byte_masks)
727 }
728 }
729
730 impl <'b, 'a, T> Iterator for TrieIterSized<'b, 'a, T, $int_type>
731 where
732 T: Copy
733 {
734 type Item = (&'a [u8], T);
735
736 fn next(&mut self) -> Option<Self::Item> {
737
738 use TrieState as T;
739 use TrieNodeIterStage as S;
740
741 while let Some((node, node_index, stage)) = self.stack.pop()
742 .and_then(|TrieNodeIter { node_index, stage }| {
743 self.trie.nodes.get(node_index).map(|node| (node, node_index, stage))
744 })
745 {
746 match (node, stage) {
747 (T::Leaf(key, value), S::Inner) => return Some((*key, *value)),
748 (T::SearchOrLeaf(key, value, search), S::Inner) => {
749 self.stack.push(TrieNodeIter {
750 node_index,
751 stage: TrieNodeIterStage::Child(0, search.mask.count_ones() as usize)
752 });
753 self.stack.push(TrieNodeIter {
754 node_index: search.edge_start,
755 stage: Default::default()
756 });
757 return Some((*key, *value));
758 }
759 (T::Search(search), S::Inner) => {
760 self.stack.push(TrieNodeIter {
761 node_index,
762 stage: TrieNodeIterStage::Child(0, search.mask.count_ones() as usize)
763 });
764 self.stack.push(TrieNodeIter {
765 node_index: search.edge_start,
766 stage: Default::default()
767 });
768 }
769 (
770 T::SearchOrLeaf(_, _, search) | T::Search(search),
771 S::Child(mut child, child_count)
772 ) => {
773 child += 1;
774 if child < child_count {
775 self.stack.push(TrieNodeIter {
776 node_index,
777 stage: TrieNodeIterStage::Child(child, child_count)
778 });
779 self.stack.push(TrieNodeIter {
780 node_index: search.edge_start + child,
781 stage: Default::default()
782 });
783 }
784 }
785 _ => unreachable!()
786 }
787 }
788
789 None
790 }
791 }
792 }
793}
794
795trie_impls! {u8, u16, u32, u64, u128, U256}
796
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    // A trie built from no items answers every lookup with `None`.
    #[test]
    fn test_trivial() {
        let empty: Vec<&str> = vec![];
        let empty_trie = empty.iter().collect::<TrieHard<'_, _>>();

        assert_eq!(None, empty_trie.get("anything"));
    }

    // Exact-match lookups: stored keys are found, while prefixes and
    // extensions of stored keys that are not stored themselves are not.
    #[rstest]
    #[case("", Some(""))]
    #[case("a", Some("a"))]
    #[case("ab", Some("ab"))]
    #[case("abc", None)]
    #[case("aac", Some("aac"))]
    #[case("aa", None)]
    #[case("aab", None)]
    #[case("adddd", Some("adddd"))]
    fn test_small_get(#[case] key: &str, #[case] expected: Option<&str>) {
        let trie = ["", "a", "ab", "aac", "adddd", "addde"]
            .into_iter()
            .collect::<TrieHard<'_, _>>();
        assert_eq!(expected, trie.get(key));
    }

    // A key that is both stored and a prefix of a longer stored key.
    #[test]
    fn test_skip_to_leaf() {
        let trie = ["a", "aa", "aaa"].into_iter().collect::<TrieHard<'_, _>>();

        assert_eq!(trie.get("aa"), Some("aa"))
    }

    // Using exactly `bits` distinct bytes selects the variant of that width.
    #[rstest]
    #[case(8)]
    #[case(16)]
    #[case(32)]
    #[case(64)]
    #[case(128)]
    #[case(256)]
    fn test_sizes(#[case] bits: usize) {
        let range = 0..bits;
        let bytes = range.map(|b| [b as u8]).collect::<Vec<_>>();
        let trie = bytes.iter().collect::<TrieHard<'_, _>>();

        use TrieHard as T;

        match (bits, trie) {
            (8, T::U8(_)) => (),
            (16, T::U16(_)) => (),
            (32, T::U32(_)) => (),
            (64, T::U64(_)) => (),
            (128, T::U128(_)) => (),
            (256, T::U256(_)) => (),
            _ => panic!("Mismatched trie sizes"),
        }
    }

    // Every word of a full text is retrievable, and iteration yields the
    // unique words in sorted order.
    #[rstest]
    #[case(include_str!("../data/1984.txt"))]
    #[case(include_str!("../data/sun-rising.txt"))]
    fn test_full_text(#[case] text: &str) {
        let words: Vec<&str> =
            text.split(|c: char| c.is_whitespace()).collect();
        let trie: TrieHard<'_, _> = words.iter().copied().collect();

        let unique_words = words
            .into_iter()
            .collect::<BTreeSet<_>>()
            .into_iter()
            .collect::<Vec<_>>();

        for word in &unique_words {
            assert!(trie.get(word).is_some())
        }

        assert_eq!(
            unique_words,
            trie.iter().map(|(_, v)| v).collect::<Vec<_>>()
        );
    }

    // Keys containing multi-byte UTF-8 characters are matched byte-wise.
    #[test]
    fn test_unicode() {
        let trie: TrieHard<'_, _> = ["bär", "bären"].into_iter().collect();

        assert_eq!(trie.get("bär"), Some("bär"));
        assert_eq!(trie.get("bä"), None);
        assert_eq!(trie.get("bären"), Some("bären"));
        assert_eq!(trie.get("bärën"), None);
    }

    // Iteration returns the stored items in sorted key order regardless of
    // insertion order.
    // NOTE(review): the third and fourth cases are identical - confirm whether
    // a different input ordering was intended for one of them.
    #[rstest]
    #[case(&[], &[])]
    #[case(&[""], &[""])]
    #[case(&["aaa", "a", ""], &["", "a", "aaa"])]
    #[case(&["aaa", "a", ""], &["", "a", "aaa"])]
    #[case(&["", "a", "ab", "aac", "adddd", "addde"], &["", "a", "aac", "ab", "adddd", "addde"])]
    fn test_iter(#[case] input: &[&str], #[case] output: &[&str]) {
        let trie = input.iter().copied().collect::<TrieHard<'_, _>>();
        let emitted = trie.iter().map(|(_, v)| v).collect::<Vec<_>>();
        assert_eq!(emitted, output);
    }

    // Prefix search returns, in sorted order, the items starting with the
    // prefix, and nothing when no item does.
    #[rstest]
    #[case(&[], "", &[])]
    #[case(&[""], "", &[""])]
    #[case(&["aaa", "a", ""], "", &["", "a", "aaa"])]
    #[case(&["aaa", "a", ""], "a", &["a", "aaa"])]
    #[case(&["aaa", "a", ""], "aa", &["aaa"])]
    #[case(&["aaa", "a", ""], "aab", &[])]
    #[case(&["aaa", "a", ""], "aaa", &["aaa"])]
    #[case(&["aaa", "a", ""], "b", &[])]
    #[case(&["dad", "ant", "and", "dot", "do"], "d", &["dad", "do", "dot"])]
    fn test_prefix_search(
        #[case] input: &[&str],
        #[case] prefix: &str,
        #[case] output: &[&str],
    ) {
        let trie = input.iter().copied().collect::<TrieHard<'_, _>>();
        let emitted = trie
            .prefix_search(prefix)
            .map(|(_, v)| v)
            .collect::<Vec<_>>();
        assert_eq!(emitted, output);
    }
}