1use std::marker::PhantomData;
2use std::ops::{Deref, DerefMut};
3
4use replace_with::replace_with_or_abort;
5
/// Conversion of a key-like value into a borrowed slice of key elements.
///
/// Every key-taking method of the tree is generic over `T: AsSlice<K>`, so
/// string-like keys (viewed as their UTF-8 bytes), slices, vectors and arrays
/// of arbitrary elements can all be used directly as keys.
pub trait AsSlice<K> {
    /// Borrows the key as a slice of elements.
    fn as_slice(&self) -> &[K];
}

/// String slices are keyed by their UTF-8 bytes.
impl AsSlice<u8> for &str {
    fn as_slice(&self) -> &[u8] {
        self.as_bytes()
    }
}

/// Owned strings are keyed by their UTF-8 bytes.
impl AsSlice<u8> for String {
    fn as_slice(&self) -> &[u8] {
        self.as_bytes()
    }
}

/// Borrowed owned strings, so `tree.get(&name)` works for `name: String`.
impl AsSlice<u8> for &String {
    fn as_slice(&self) -> &[u8] {
        self.as_bytes()
    }
}

impl<K> AsSlice<K> for &[K] {
    fn as_slice(&self) -> &[K] {
        self
    }
}

impl<K> AsSlice<K> for Vec<K> {
    fn as_slice(&self) -> &[K] {
        // Full-range index instead of `self.as_slice()`, which reads like self-recursion.
        &self[..]
    }
}

/// Borrowed vectors, so `tree.get(&key)` works for `key: Vec<K>`.
impl<K> AsSlice<K> for &Vec<K> {
    fn as_slice(&self) -> &[K] {
        &self[..]
    }
}

/// Fixed-size arrays.
impl<K, const N: usize> AsSlice<K> for [K; N] {
    fn as_slice(&self) -> &[K] {
        self
    }
}

/// Borrowed fixed-size arrays, including byte-string literals such as `b"foo"`.
impl<K, const N: usize> AsSlice<K> for &[K; N] {
    fn as_slice(&self) -> &[K] {
        &self[..]
    }
}
33
/// A radix (compressed prefix) tree mapping keys, i.e. sequences of `K`, to values of type `V`.
///
/// Keys are passed as anything implementing [`AsSlice<K>`] (for example `&str`/`String`,
/// viewed as bytes, `&[K]` or `Vec<K>`). The tree keeps a running entry count, so
/// [`RadixTree::len`] is O(1), and it derefs to its root [`RadixTreeNode`] for lookups
/// and iteration.
///
/// Equality is derived structurally: edges are compared in their stored order, and new
/// edges are appended on insert. Two trees holding the same entries may therefore compare
/// unequal if the entries were inserted in a different order.
#[derive(PartialEq, Eq, Debug)]
pub struct RadixTree<K, V>
where
    K: PartialEq,
{
    /// Number of keys currently holding a value; maintained by `insert`, `remove` and `clear`.
    count: usize,
    /// Root node; the value stored under the empty key lives here.
    root: RadixTreeNode<K, V>,
}
54
/// A node of a [`RadixTree`], also usable directly as a view of the subtree below it.
///
/// A node's key is the concatenation of the edge labels on the path leading to it.
/// `value` holds the value stored under exactly that key, if any.
#[derive(PartialEq, Eq, Debug)]
pub struct RadixTreeNode<K, V>
where
    K: PartialEq,
{
    /// Value stored under the key that ends exactly at this node.
    value: Option<V>,
    /// Outgoing edges as `(label, child)` pairs; new edges are appended at the end.
    /// Lookups follow the first edge whose label is a prefix of the remaining key.
    edges: Vec<(Vec<K>, RadixTreeNode<K, V>)>,
}
70
71impl<K, V> RadixTree<K, V>
72where
73 K: PartialEq + Clone,
74{
75 pub fn new() -> RadixTree<K, V> {
77 RadixTree {
78 count: 0,
79 root: RadixTreeNode::new(),
80 }
81 }
82
83 pub fn singleton(value: V) -> RadixTree<K, V> {
85 RadixTree {
86 count: 1,
87 root: RadixTreeNode::singleton(value),
88 }
89 }
90
91 pub fn len(&self) -> usize {
93 self.count
94 }
95
96 pub fn is_empty(&self) -> bool {
98 self.count == 0
99 }
100
101 pub fn insert<T>(&mut self, key: T, value: V) -> Option<V>
104 where
105 T: AsSlice<K>,
106 {
107 let optv = self.root.insert(key, value);
108 if optv.is_none() {
109 self.count += 1;
110 }
111 optv
112 }
113
114 pub fn remove<T>(&mut self, key: T) -> Option<V>
117 where
118 T: AsSlice<K>,
119 {
120 let optv = self.root.remove(key);
121 if optv.is_some() {
122 self.count -= 1;
123 }
124 optv
125 }
126
127 pub fn clear(&mut self) {
129 self.root.clear();
130 self.count = 0;
131 }
132}
133
134impl<K, V> Deref for RadixTree<K, V>
135where
136 K: PartialEq,
137{
138 type Target = RadixTreeNode<K, V>;
139
140 fn deref(&self) -> &Self::Target {
141 &self.root
142 }
143}
144
145impl<K, V> DerefMut for RadixTree<K, V>
146where
147 K: PartialEq,
148{
149 fn deref_mut(&mut self) -> &mut Self::Target {
150 &mut self.root
151 }
152}
153
154impl<K, V> Default for RadixTree<K, V>
155where
156 K: PartialEq + Clone,
157{
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
163impl<K, V> RadixTreeNode<K, V>
164where
165 K: PartialEq + Clone,
166{
167 fn new() -> RadixTreeNode<K, V> {
168 RadixTreeNode {
169 value: None,
170 edges: Vec::new(),
171 }
172 }
173
174 fn singleton(value: V) -> RadixTreeNode<K, V> {
175 RadixTreeNode {
176 value: Some(value),
177 edges: Vec::new(),
178 }
179 }
180
181 pub fn is_leaf(&self) -> bool {
183 self.edges.is_empty()
184 }
185
186 pub fn is_node(&self) -> bool {
188 !self.is_leaf()
189 }
190
191 pub fn value(&self) -> Option<&V> {
193 self.get(&[] as &[K])
194 }
195
196 pub fn len(&self) -> usize {
199 self.values().count()
200 }
201
202 pub fn is_empty(&self) -> bool {
204 self.len() == 0
205 }
206
207 pub fn contains_key<T>(&self, key: T) -> bool
209 where
210 T: AsSlice<K>,
211 {
212 self.get(key).is_some()
213 }
214
215 pub fn iter<'a>(&'a self) -> Iter<'a, K, V> {
218 Iter {
219 node: self,
220 prefix: Vec::new(),
221 parents: Vec::new(),
222 yielded: false,
223 index: 0,
224 }
225 }
226
227 pub fn keys<'a>(&'a self) -> Keys<'a, K, V> {
228 Keys {
229 node: self,
230 parents: Vec::new(),
231 prefix: Vec::new(),
232 yielded: false,
233 index: 0,
234 }
235 }
236
237 pub fn edges<'a>(&'a self) -> Edges<'a, K, V> {
238 Edges {
239 node: self,
240 parents: Vec::new(),
241 prefix: &[],
242 yielded: false,
243 index: 0,
244 }
245 }
246
247 pub fn values<'a>(&'a self) -> Values<'a, K, V> {
248 Values {
249 node: self,
250 parents: Vec::new(),
251 yielded: false,
252 index: 0,
253 }
254 }
255
256 pub fn iter_edges<'a>(&'a self) -> IterEdges<'a, K, V> {
257 IterEdges {
258 node: self,
259 prefix: &[],
260 parents: Vec::new(),
261 yielded: false,
262 index: 0,
263 }
264 }
265
266 pub fn iter_mut<'a>(&'a mut self) -> IterMut<'a, K, V> {
267 IterMut {
268 node: std::ptr::from_mut(self),
269 prefix: Vec::new(),
270 parents: Vec::new(),
271 yielded: false,
272 index: 0,
273 _marker: PhantomData,
274 }
275 }
276
277 pub fn iter_edges_mut<'a>(&'a mut self) -> IterEdgesMut<'a, K, V> {
278 IterEdgesMut {
279 node: std::ptr::from_mut(self),
280 prefix: &[],
281 parents: Vec::new(),
282 yielded: false,
283 index: 0,
284 _marker: PhantomData,
285 }
286 }
287
288 pub fn values_mut<'a>(&'a mut self) -> ValuesMut<'a, K, V> {
289 ValuesMut {
290 node: std::ptr::from_mut(self),
291 parents: Vec::new(),
292 yielded: false,
293 index: 0,
294 _marker: PhantomData,
295 }
296 }
297
298 pub fn at_prefix<T>(&self, key: T) -> Option<&RadixTreeNode<K, V>>
300 where
301 T: AsSlice<K>,
302 {
303 let key = key.as_slice();
304 if key.is_empty() {
305 return Some(self);
306 }
307 for (prefix, child) in &self.edges {
308 if let Some(rest) = key.strip_prefix(prefix.as_slice()) {
309 return child.at_prefix(rest);
310 }
311 }
312 None
313 }
314
315 pub fn at_prefix_mut<T>(&mut self, key: T) -> Option<&mut RadixTreeNode<K, V>>
319 where
320 T: AsSlice<K>,
321 {
322 let key = key.as_slice();
323 if key.is_empty() {
324 return Some(self);
325 }
326 for (prefix, child) in &mut self.edges {
327 if let Some(rest) = key.strip_prefix(prefix.as_slice()) {
328 return child.at_prefix_mut(rest);
329 }
330 }
331 None
332 }
333
334 pub fn get<T>(&self, key: T) -> Option<&V>
336 where
337 T: AsSlice<K>,
338 {
339 let key = key.as_slice();
340 self.at_prefix(key).and_then(|node| node.value.as_ref())
341 }
342
343 pub fn get_mut<T>(&mut self, key: T) -> Option<&mut V>
345 where
346 T: AsSlice<K>,
347 {
348 self.at_prefix_mut(key).and_then(|node| node.value.as_mut())
349 }
350
351 fn insert<T>(&mut self, key: T, value: V) -> Option<V>
352 where
353 T: AsSlice<K>,
354 {
355 let key = key.as_slice();
356 if key.is_empty() {
357 return self.value.replace(value);
358 }
359 for (prefix, child) in &mut self.edges {
360 let common_len = longest_common_prefix(prefix, key);
361 if common_len > 0 {
362 if common_len == prefix.len() {
363 return child.insert(&key[common_len..], value);
364 }
365 let prefix_rest = prefix.drain(common_len..).collect();
367 replace_with_or_abort(child, |node| RadixTreeNode {
368 value: None,
369 edges: vec![
370 (prefix_rest, node),
371 (key[common_len..].to_vec(), RadixTreeNode::singleton(value)),
372 ],
373 });
374 return None;
375 }
376 }
377 self.edges
378 .push((key.to_vec(), RadixTreeNode::singleton(value)));
379 None
380 }
381
382 fn remove<T>(&mut self, key: T) -> Option<V>
383 where
384 T: AsSlice<K>,
385 {
386 let key = key.as_slice();
387 if key.is_empty() {
388 return self.value.take();
389 }
390 let mut cleanup_node = None;
391 for (i, (prefix, child)) in self.edges.iter_mut().enumerate() {
392 let common_len = longest_common_prefix(prefix, key);
393 if common_len > 0 {
394 if common_len == prefix.len() {
395 let removed = child.remove(&key[common_len..]);
396 if removed.is_some() && child.value.is_none() {
398 if child.edges.is_empty() {
399 cleanup_node = Some((i, removed));
400 break;
401 }
402 if child.edges.len() == 1 {
403 let (child_prefix, grandchild) = child.edges.remove(0);
404 prefix.extend(child_prefix);
405 *child = grandchild;
406 }
407 }
408 return removed;
409 }
410 return None;
414 }
415 }
416 if let Some((i, removed)) = cleanup_node {
417 self.edges.remove(i);
418 return removed;
419 }
420 None
421 }
422
423 fn clear(&mut self) {
424 self.value.take();
425 self.edges.clear();
426 }
427}
428
429impl<K, V> Default for RadixTreeNode<K, V>
430where
431 K: PartialEq + Clone,
432{
433 fn default() -> Self {
434 Self::new()
435 }
436}
437
438pub struct Iter<'a, K, V>
439where
440 K: PartialEq,
441{
442 parents: Vec<(&'a RadixTreeNode<K, V>, usize, usize)>,
443 node: &'a RadixTreeNode<K, V>,
444 prefix: Vec<K>,
445 yielded: bool,
446 index: usize,
447}
448
449impl<'a, K, V> Iterator for Iter<'a, K, V>
450where
451 K: PartialEq + Clone,
452{
453 type Item = (Vec<K>, &'a V);
454
455 fn next(&mut self) -> Option<Self::Item> {
456 loop {
457 if !self.yielded
458 && let Some(val) = &self.node.value
459 {
460 self.yielded = true;
461 return Some((self.prefix.clone(), val));
462 }
463 if let Some((prefix, node)) = self.node.edges.get(self.index) {
464 self.parents
465 .push((self.node, self.index + 1, self.prefix.len()));
466 self.node = node;
467 self.prefix.extend(prefix.iter().cloned());
468 self.yielded = false;
469 self.index = 0;
470 } else if let Some((node, index, prefix_len)) = self.parents.pop() {
471 self.prefix.truncate(prefix_len);
472 self.node = node;
473 self.index = index;
474 self.yielded = true;
475 } else {
476 return None;
477 }
478 }
479 }
480}
481
482pub struct Keys<'a, K, V>
483where
484 K: PartialEq,
485{
486 parents: Vec<(&'a RadixTreeNode<K, V>, usize, usize)>,
487 node: &'a RadixTreeNode<K, V>,
488 prefix: Vec<K>,
489 yielded: bool,
490 index: usize,
491}
492
493impl<'a, K, V> Iterator for Keys<'a, K, V>
494where
495 K: PartialEq + Clone,
496{
497 type Item = Vec<K>;
498
499 fn next(&mut self) -> Option<Self::Item> {
500 loop {
501 if !self.yielded && self.node.value.is_some() {
502 self.yielded = true;
503 return Some(self.prefix.clone());
504 }
505 if let Some((prefix, node)) = self.node.edges.get(self.index) {
506 self.parents
507 .push((self.node, self.index + 1, self.prefix.len()));
508 self.node = node;
509 self.prefix.extend(prefix.iter().cloned());
510 self.yielded = false;
511 self.index = 0;
512 } else if let Some((node, index, prefix_len)) = self.parents.pop() {
513 self.prefix.truncate(prefix_len);
514 self.node = node;
515 self.index = index;
516 self.yielded = true;
517 } else {
518 return None;
519 }
520 }
521 }
522}
523
524pub struct Values<'a, K, V>
525where
526 K: PartialEq,
527{
528 parents: Vec<(&'a RadixTreeNode<K, V>, usize)>,
529 node: &'a RadixTreeNode<K, V>,
530 yielded: bool,
531 index: usize,
532}
533
534impl<'a, K, V> Iterator for Values<'a, K, V>
535where
536 K: PartialEq,
537{
538 type Item = &'a V;
539
540 fn next(&mut self) -> Option<Self::Item> {
541 loop {
542 if !self.yielded
543 && let Some(val) = &self.node.value
544 {
545 self.yielded = true;
546 return Some(val);
547 }
548 if let Some((_, node)) = self.node.edges.get(self.index) {
549 self.parents.push((self.node, self.index + 1));
550 self.node = node;
551 self.yielded = false;
552 self.index = 0;
553 } else if let Some((node, index)) = self.parents.pop() {
554 self.node = node;
555 self.index = index;
556 self.yielded = true;
557 } else {
558 return None;
559 }
560 }
561 }
562}
563
564pub struct IterEdges<'a, K, V>
565where
566 K: PartialEq,
567{
568 parents: Vec<(&'a RadixTreeNode<K, V>, usize)>,
569 node: &'a RadixTreeNode<K, V>,
570 prefix: &'a [K],
571 yielded: bool,
572 index: usize,
573}
574
575impl<'a, K, V> Iterator for IterEdges<'a, K, V>
576where
577 K: PartialEq,
578{
579 type Item = (&'a [K], &'a V);
580
581 fn next(&mut self) -> Option<Self::Item> {
582 loop {
583 if !self.yielded
584 && let Some(val) = &self.node.value
585 {
586 self.yielded = true;
587 return Some((self.prefix, val));
588 }
589 if let Some((prefix, node)) = self.node.edges.get(self.index) {
590 self.parents.push((self.node, self.index + 1));
591 self.node = node;
592 self.prefix = prefix.as_slice();
593 self.yielded = false;
594 self.index = 0;
595 } else if let Some((node, index)) = self.parents.pop() {
596 self.node = node;
597 self.index = index;
598 self.yielded = true;
599 } else {
600 return None;
601 }
602 }
603 }
604}
605
606pub struct Edges<'a, K, V>
607where
608 K: PartialEq,
609{
610 parents: Vec<(&'a RadixTreeNode<K, V>, usize)>,
611 node: &'a RadixTreeNode<K, V>,
612 prefix: &'a [K],
613 yielded: bool,
614 index: usize,
615}
616
617impl<'a, K, V> Iterator for Edges<'a, K, V>
618where
619 K: PartialEq,
620{
621 type Item = &'a [K];
622
623 fn next(&mut self) -> Option<Self::Item> {
624 loop {
625 if !self.yielded && self.node.value.is_some() {
626 self.yielded = true;
627 return Some(self.prefix);
628 }
629 if let Some((prefix, node)) = self.node.edges.get(self.index) {
630 self.parents.push((self.node, self.index + 1));
631 self.node = node;
632 self.prefix = prefix.as_slice();
633 self.yielded = false;
634 self.index = 0;
635 } else if let Some((node, index)) = self.parents.pop() {
636 self.node = node;
637 self.index = index;
638 self.yielded = true;
639 } else {
640 return None;
641 }
642 }
643 }
644}
645
646pub struct IterMut<'a, K, V>
647where
648 K: PartialEq,
649{
650 parents: Vec<(*mut RadixTreeNode<K, V>, usize, usize)>,
651 node: *mut RadixTreeNode<K, V>,
652 prefix: Vec<K>,
653 yielded: bool,
654 index: usize,
655 _marker: PhantomData<&'a mut V>,
656}
657
658impl<'a, K, V> Iterator for IterMut<'a, K, V>
659where
660 K: PartialEq + Clone,
661{
662 type Item = (Vec<K>, &'a mut V);
663
664 fn next(&mut self) -> Option<Self::Item> {
665 loop {
666 let node = unsafe { &mut *self.node };
667 if !self.yielded
668 && let Some(val) = &mut node.value
669 {
670 self.yielded = true;
671 return Some((self.prefix.clone(), val));
672 }
673 if let Some((prefix, node)) = node.edges.get_mut(self.index) {
674 self.parents
675 .push((self.node, self.index + 1, self.prefix.len()));
676 self.node = node;
677 self.prefix.extend(prefix.iter().cloned());
678 self.yielded = false;
679 self.index = 0;
680 } else if let Some((node, index, prefix_len)) = self.parents.pop() {
681 self.prefix.truncate(prefix_len);
682 self.node = node;
683 self.index = index;
684 self.yielded = true;
685 } else {
686 return None;
687 }
688 }
689 }
690}
691
692pub struct IterEdgesMut<'a, K, V>
693where
694 K: PartialEq,
695{
696 parents: Vec<(*mut RadixTreeNode<K, V>, usize)>,
697 node: *mut RadixTreeNode<K, V>,
698 prefix: &'a [K],
699 yielded: bool,
700 index: usize,
701 _marker: PhantomData<&'a mut V>,
702}
703
704impl<'a, K, V> Iterator for IterEdgesMut<'a, K, V>
705where
706 K: PartialEq,
707{
708 type Item = (&'a [K], &'a mut V);
709
710 fn next(&mut self) -> Option<Self::Item> {
711 loop {
712 let node = unsafe { &mut *self.node };
713 if !self.yielded
714 && let Some(val) = &mut node.value
715 {
716 self.yielded = true;
717 return Some((self.prefix, val));
718 }
719 if let Some((prefix, node)) = node.edges.get_mut(self.index) {
720 self.parents.push((self.node, self.index + 1));
721 self.node = node;
722 self.prefix = prefix.as_slice();
723 self.yielded = false;
724 self.index = 0;
725 } else if let Some((node, index)) = self.parents.pop() {
726 self.node = node;
727 self.index = index;
728 self.yielded = true;
729 } else {
730 return None;
731 }
732 }
733 }
734}
735
736pub struct ValuesMut<'a, K, V>
737where
738 K: PartialEq,
739{
740 parents: Vec<(*mut RadixTreeNode<K, V>, usize)>,
741 node: *mut RadixTreeNode<K, V>,
742 yielded: bool,
743 index: usize,
744 _marker: PhantomData<&'a mut V>,
745}
746
747impl<'a, K, V> Iterator for ValuesMut<'a, K, V>
748where
749 K: PartialEq,
750{
751 type Item = &'a mut V;
752
753 fn next(&mut self) -> Option<Self::Item> {
754 loop {
755 let node = unsafe { &mut *self.node };
756 if !self.yielded
757 && let Some(val) = &mut node.value
758 {
759 self.yielded = true;
760 return Some(val);
761 }
762 if let Some((_, node)) = node.edges.get_mut(self.index) {
763 self.parents.push((self.node, self.index + 1));
764 self.node = node;
765 self.yielded = false;
766 self.index = 0;
767 } else if let Some((node, index)) = self.parents.pop() {
768 self.node = node;
769 self.index = index;
770 self.yielded = true;
771 } else {
772 return None;
773 }
774 }
775 }
776}
777
778impl<'a, K, V> IntoIterator for &'a RadixTreeNode<K, V>
779where
780 K: PartialEq + Clone,
781{
782 type Item = (Vec<K>, &'a V);
783
784 type IntoIter = Iter<'a, K, V>;
785
786 fn into_iter(self) -> Self::IntoIter {
787 self.iter()
788 }
789}
790
791impl<'a, K, V> IntoIterator for &'a mut RadixTreeNode<K, V>
792where
793 K: PartialEq + Clone,
794{
795 type Item = (Vec<K>, &'a mut V);
796
797 type IntoIter = IterMut<'a, K, V>;
798
799 fn into_iter(self) -> Self::IntoIter {
800 self.iter_mut()
801 }
802}
803
/// Returns the length of the longest common prefix of `s1` and `s2`.
///
/// This is the number of leading positions at which both slices hold equal elements,
/// so it never exceeds the shorter slice's length.
fn longest_common_prefix<T>(s1: &[T], s2: &[T]) -> usize
where
    T: PartialEq,
{
    s1.iter().zip(s2).take_while(|(a, b)| a == b).count()
}
813
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a byte key for printing. Keys relative to a subtree can start in the
    /// middle of a multi-byte character, so decode lossily instead of assuming UTF-8.
    fn show(key: &[u8]) -> String {
        String::from_utf8_lossy(key).into_owned()
    }

    #[test]
    fn longest_common_prefix_works() {
        assert_eq!(longest_common_prefix(b"bar", b"baz"), 2);
        assert_eq!(longest_common_prefix(b"bar", b"barbie"), 3);
        assert_eq!(longest_common_prefix(b"foo", b"bar"), 0);
        assert_eq!(longest_common_prefix(b"foo", b"foo"), 3);
        assert_eq!(longest_common_prefix(b"", b"foo"), 0);
    }

    #[test]
    fn radix_tree_works() {
        let mut tree = RadixTree::new();
        assert!(tree.is_empty());
        assert_eq!(tree.value, None);
        assert_eq!(tree.insert("foo", 42), None);
        assert_eq!(tree.value, None);
        assert_eq!(tree.edges.len(), 1);
        let node = tree.at_prefix("foo");
        assert_eq!(node.and_then(|node| node.value), Some(42));
        assert!(node.is_some_and(|node| node.edges.is_empty()));

        assert!(tree.insert("bar", 13).is_none());
        assert_eq!(tree.get("bar"), Some(&13));
        assert!(tree.insert("baz", 7).is_none());
        assert_eq!(tree.get("baz"), Some(&7));
        // "bar" and "baz" now share a split "ba" edge with children "r" and "z".
        let node = tree.at_prefix("ba");
        assert!(node.is_some_and(|node| node.value.is_none()));
        assert!(node.is_some_and(|node| node.edges.len() == 2));
        assert!(node.is_some_and(|node| {
            let (prefix, child) = &node.edges[0];
            prefix == b"r" && child.value == Some(13)
        }));
        assert!(node.is_some_and(|node| {
            let (prefix, child) = &node.edges[1];
            prefix == b"z" && child.value == Some(7)
        }));

        assert_eq!(tree.insert("ba", 18), None);
        assert_eq!(tree.get("ba"), Some(&18));
        assert_eq!(tree.insert("barbie", 23), None);
        assert_eq!(tree.get("barbie"), Some(&23));
        assert_eq!(tree.get("bag"), None);
        assert_eq!(tree.get("qux"), None);
        assert_eq!(tree.insert("ba", 27), Some(18));
        assert_eq!(tree.get("ba"), Some(&27));
        // Replacing the value of an existing key must not change the entry count.
        assert_eq!(tree.len(), 5);

        println!("Keys matching prefix \"ba\" and their values");
        let subtree = tree.at_prefix("ba").unwrap();
        for (key, value) in subtree.iter() {
            let key = show(&key);
            println!("\"{key}\": {value}");
        }
        let subtree_keys: Vec<String> = subtree.keys().map(|key| show(&key)).collect();
        assert_eq!(subtree_keys, ["", "r", "rbie", "z"]);

        println!("All values");
        for v in tree.values() {
            println!("{v}");
        }
        let values: Vec<i32> = tree.values().copied().collect();
        assert_eq!(values, [42, 27, 13, 23, 7]);

        println!("All keys and values");
        for (key, val) in tree.iter() {
            let key = show(&key);
            println!("\"{key}\": {val}");
        }

        println!("Fully reconstructed keys");
        for key in tree.keys() {
            let key = show(&key);
            println!("\"{key}\"");
        }
        let keys: Vec<String> = tree.keys().map(|key| show(&key)).collect();
        assert_eq!(keys, ["foo", "ba", "bar", "barbie", "baz"]);

        println!("Incrementing all values by 1");
        for (key, val) in tree.iter_mut() {
            let key = show(&key);
            *val += 1;
            println!("\"{key}\": {val}");
        }
        assert_eq!(tree.get("foo"), Some(&43));

        assert_eq!(tree.remove("bar"), Some(14));
        println!("{tree:?}");
        assert_eq!(tree.get("bar"), None);
        // Removing "bar" leaves its descendant "barbie" intact.
        assert_eq!(tree.get("barbie"), Some(&24));
        assert_eq!(tree.remove("baz"), Some(8));
        assert_eq!(tree.remove("baz"), None);
        assert_eq!(tree.len(), 3);
        // The value-less "r" node was merged into its only child edge.
        assert!(
            tree.at_prefix("ba")
                .is_some_and(|node| node.edges.len() == 1 && node.edges[0].0 == b"rbie")
        );
    }
}