1use crate::hashtree::{
2 Hash,
3 HashTree::{self, Empty, Leaf, Pruned},
4 fork, fork_hash, labeled, labeled_hash, leaf_hash,
5};
6use std::borrow::Cow;
7use std::cmp::Ordering::{self, Equal, Greater, Less};
8use std::fmt;
9
/// Color of a node in the left-leaning red-black tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Color {
    Red,
    Black,
}

impl Color {
    /// Overwrites `self` with the opposite color.
    fn flip_assign(&mut self) {
        let flipped = self.flip();
        *self = flipped;
    }

    /// Returns the opposite color: `Red` becomes `Black` and vice versa.
    fn flip(self) -> Self {
        if self == Self::Red {
            Self::Black
        } else {
            Self::Red
        }
    }
}
28
29pub trait AsHashTree {
31 fn root_hash(&self) -> Hash;
34
35 fn as_hash_tree(&self) -> HashTree<'_>;
37}
38
39impl AsHashTree for Vec<u8> {
40 fn root_hash(&self) -> Hash {
41 leaf_hash(&self[..])
42 }
43
44 fn as_hash_tree(&self) -> HashTree<'_> {
45 Leaf(Cow::from(&self[..]))
46 }
47}
48
49impl AsHashTree for Hash {
50 fn root_hash(&self) -> Hash {
51 leaf_hash(&self[..])
52 }
53
54 fn as_hash_tree(&self) -> HashTree<'_> {
55 Leaf(Cow::from(&self[..]))
56 }
57}
58
59impl<K: AsRef<[u8]>, V: AsHashTree> AsHashTree for RbTree<K, V> {
60 fn root_hash(&self) -> Hash {
61 match self.root.as_ref() {
62 None => Empty.reconstruct(),
63 Some(n) => n.subtree_hash,
64 }
65 }
66
67 fn as_hash_tree(&self) -> HashTree<'_> {
68 Node::full_witness_tree(&self.root, Node::data_tree)
69 }
70}
71
/// One end of a key range used when building witnesses.
///
/// `Exact` is a key that is itself the requested boundary and is present in the tree.
/// `Neighbor` is a key present in the tree that lies just outside the requested range;
/// witnesses reveal it without its value.
#[derive(PartialEq, Debug, Clone, Copy)]
enum KeyBound<'a> {
    Exact(&'a [u8]),
    Neighbor(&'a [u8]),
}

impl<'a> AsRef<[u8]> for KeyBound<'a> {
    /// Returns the key bytes regardless of the bound kind.
    fn as_ref(&self) -> &'a [u8] {
        let (KeyBound::Exact(key) | KeyBound::Neighbor(key)) = *self;
        key
    }
}
86
87type NodeRef<K, V> = Option<Box<Node<K, V>>>;
88
89#[derive(Clone, Debug)]
94struct Node<K, V> {
95 key: K,
96 value: V,
97 left: NodeRef<K, V>,
98 right: NodeRef<K, V>,
99 color: Color,
100
101 subtree_hash: Hash,
104}
105
106impl<K: AsRef<[u8]>, V: AsHashTree> Node<K, V> {
107 fn new(key: K, value: V) -> Box<Node<K, V>> {
108 let value_hash = value.root_hash();
109 let data_hash = labeled_hash(key.as_ref(), &value_hash);
110 Box::new(Self {
111 key,
112 value,
113 left: None,
114 right: None,
115 color: Color::Red,
116 subtree_hash: data_hash,
117 })
118 }
119
120 fn data_hash(&self) -> Hash {
121 labeled_hash(self.key.as_ref(), &self.value.root_hash())
122 }
123
124 fn left_hash_tree(&self) -> HashTree<'_> {
125 match self.left.as_ref() {
126 None => Empty,
127 Some(l) => Pruned(l.subtree_hash),
128 }
129 }
130
131 fn right_hash_tree(&self) -> HashTree<'_> {
132 match self.right.as_ref() {
133 None => Empty,
134 Some(r) => Pruned(r.subtree_hash),
135 }
136 }
137
138 fn visit<'a, F>(n: &'a NodeRef<K, V>, f: &mut F)
139 where
140 F: 'a + FnMut(&'a [u8], &'a V),
141 {
142 if let Some(n) = n {
143 Self::visit(&n.left, f);
144 (*f)(n.key.as_ref(), &n.value);
145 Self::visit(&n.right, f);
146 }
147 }
148
149 fn data_tree(&self) -> HashTree<'_> {
150 labeled(self.key.as_ref(), self.value.as_hash_tree())
151 }
152
153 fn subtree_with<'a>(&'a self, f: impl FnOnce(&'a V) -> HashTree<'a>) -> HashTree<'a> {
154 labeled(self.key.as_ref(), f(&self.value))
155 }
156
157 fn witness_tree(&self) -> HashTree<'_> {
158 labeled(self.key.as_ref(), Pruned(self.value.root_hash()))
159 }
160
161 fn full_witness_tree<'a>(
162 n: &'a NodeRef<K, V>,
163 f: fn(&'a Node<K, V>) -> HashTree<'a>,
164 ) -> HashTree<'a> {
165 match n {
166 None => Empty,
167 Some(n) => three_way_fork(
168 Self::full_witness_tree(&n.left, f),
169 f(n),
170 Self::full_witness_tree(&n.right, f),
171 ),
172 }
173 }
174
175 fn update_subtree_hash(&mut self) {
176 self.subtree_hash = self.compute_subtree_hash();
177 }
178
179 fn compute_subtree_hash(&self) -> Hash {
180 let h = self.data_hash();
181
182 match (self.left.as_ref(), self.right.as_ref()) {
183 (None, None) => h,
184 (Some(l), None) => fork_hash(&l.subtree_hash, &h),
185 (None, Some(r)) => fork_hash(&h, &r.subtree_hash),
186 (Some(l), Some(r)) => fork_hash(&l.subtree_hash, &fork_hash(&h, &r.subtree_hash)),
187 }
188 }
189}
190
191#[derive(PartialEq, Debug)]
192enum Visit {
193 Pre,
194 In,
195 Post,
196}
197
198#[derive(Debug)]
200pub struct Iter<'a, K, V> {
201 visit: Visit,
206 parents: Vec<&'a Node<K, V>>,
207}
208
209impl<K, V> Iter<'_, K, V> {
210 fn step(&mut self) -> bool {
219 match self.parents.last() {
220 Some(tip) => {
221 match self.visit {
222 Visit::Pre => {
223 if let Some(l) = &tip.left {
224 self.parents.push(l);
225 } else {
226 self.visit = Visit::In;
227 }
228 }
229 Visit::In => {
230 if let Some(r) = &tip.right {
231 self.parents.push(r);
232 self.visit = Visit::Pre;
233 } else {
234 self.visit = Visit::Post;
235 }
236 }
237 Visit::Post => {
238 let tip = self.parents.pop().unwrap();
239 if let Some(parent) = self.parents.last() {
240 if parent
241 .left
242 .as_ref()
243 .is_some_and(|l| std::ptr::eq(l.as_ref(), tip))
244 {
245 self.visit = Visit::In;
246 }
247 }
248 }
249 }
250 true
251 }
252 None => false,
253 }
254 }
255}
256
257impl<'a, K, V> std::iter::Iterator for Iter<'a, K, V> {
258 type Item = (&'a K, &'a V);
259
260 fn next(&mut self) -> Option<Self::Item> {
261 while self.step() {
262 if self.visit == Visit::In {
263 return self.parents.last().map(|n| (&n.key, &n.value));
264 }
265 }
266 None
267 }
268}
269
270#[derive(Default, Clone)]
273pub struct RbTree<K, V> {
274 root: NodeRef<K, V>,
275}
276
277impl<K, V> PartialEq for RbTree<K, V>
278where
279 K: AsRef<[u8]> + PartialEq,
280 V: AsHashTree + PartialEq,
281{
282 fn eq(&self, other: &Self) -> bool {
283 self.iter().eq(other.iter())
284 }
285}
286
287impl<K, V> Eq for RbTree<K, V>
288where
289 K: AsRef<[u8]> + Eq,
290 V: AsHashTree + Eq,
291{
292}
293
294impl<K, V> PartialOrd for RbTree<K, V>
295where
296 K: AsRef<[u8]> + PartialOrd,
297 V: AsHashTree + PartialOrd,
298{
299 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
300 self.iter().partial_cmp(other.iter())
301 }
302}
303
304impl<K, V> Ord for RbTree<K, V>
305where
306 K: AsRef<[u8]> + Ord,
307 V: AsHashTree + Ord,
308{
309 fn cmp(&self, other: &Self) -> Ordering {
310 self.iter().cmp(other.iter())
311 }
312}
313
314impl<K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
315where
316 K: AsRef<[u8]>,
317 V: AsHashTree,
318{
319 fn from_iter<T>(iter: T) -> Self
320 where
321 T: IntoIterator<Item = (K, V)>,
322 {
323 let mut t = RbTree::<K, V>::new();
324 for (k, v) in iter {
325 t.insert(k, v);
326 }
327 t
328 }
329}
330
331impl<K, V> std::fmt::Debug for RbTree<K, V>
332where
333 K: AsRef<[u8]> + std::fmt::Debug,
334 V: AsHashTree + std::fmt::Debug,
335{
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 write!(f, "[")?;
338 let mut first = true;
339 for (k, v) in self {
340 if !first {
341 write!(f, ", ")?;
342 }
343 first = false;
344 write!(f, "({k:?}, {v:?})")?;
345 }
346 write!(f, "]")
347 }
348}
349
350impl<K, V> RbTree<K, V> {
351 pub const fn new() -> Self {
353 Self { root: None }
354 }
355
356 pub const fn is_empty(&self) -> bool {
358 self.root.is_none()
359 }
360}
361
362impl<K: AsRef<[u8]>, V: AsHashTree> RbTree<K, V> {
363 pub fn get(&self, key: &[u8]) -> Option<&V> {
365 let mut root = self.root.as_ref();
366 while let Some(n) = root {
367 match key.cmp(n.key.as_ref()) {
368 Equal => return Some(&n.value),
369 Less => root = n.left.as_ref(),
370 Greater => root = n.right.as_ref(),
371 }
372 }
373 None
374 }
375
376 pub fn modify(&mut self, key: &[u8], f: impl FnOnce(&mut V)) {
378 fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
379 h: &mut NodeRef<K, V>,
380 k: &[u8],
381 f: impl FnOnce(&mut V),
382 ) {
383 if let Some(h) = h {
384 match k.as_ref().cmp(h.key.as_ref()) {
385 Equal => {
386 f(&mut h.value);
387 h.update_subtree_hash();
388 }
389 Less => {
390 go(&mut h.left, k, f);
391 h.update_subtree_hash();
392 }
393 Greater => {
394 go(&mut h.right, k, f);
395 h.update_subtree_hash();
396 }
397 }
398 }
399 }
400 go(&mut self.root, key, f);
401 }
402
403 fn range_witness<'a>(
404 &'a self,
405 left: Option<KeyBound<'a>>,
406 right: Option<KeyBound<'a>>,
407 f: fn(&'a Node<K, V>) -> HashTree<'a>,
408 ) -> HashTree<'a> {
409 match (left, right) {
410 (None, None) => Node::full_witness_tree(&self.root, f),
411 (Some(l), None) => self.witness_range_above(l, f),
412 (None, Some(r)) => self.witness_range_below(r, f),
413 (Some(l), Some(r)) => self.witness_range_between(l, r, f),
414 }
415 }
416
417 pub fn witness<'a>(&'a self, key: &[u8]) -> HashTree<'a> {
423 self.nested_witness(key, |v| v.as_hash_tree())
424 }
425
426 pub fn nested_witness<'a>(
430 &'a self,
431 key: &[u8],
432 f: impl FnOnce(&'a V) -> HashTree<'a>,
433 ) -> HashTree<'a> {
434 if let Some(t) = self.lookup_and_build_witness(key, f) {
435 return t;
436 }
437 self.range_witness(
438 self.lower_bound(key),
439 self.upper_bound(key),
440 Node::witness_tree,
441 )
442 }
443
444 pub fn keys(&self) -> HashTree<'_> {
448 Node::full_witness_tree(&self.root, Node::witness_tree)
449 }
450
451 pub fn key_range(&self, first: &[u8], last: &[u8]) -> HashTree<'_> {
455 self.range_witness(
456 self.lower_bound(first),
457 self.upper_bound(last),
458 Node::witness_tree,
459 )
460 }
461
462 pub fn value_range(&self, first: &[u8], last: &[u8]) -> HashTree<'_> {
465 self.range_witness(
466 self.lower_bound(first),
467 self.upper_bound(last),
468 Node::data_tree,
469 )
470 }
471
472 pub fn keys_with_prefix(&self, prefix: &[u8]) -> HashTree<'_> {
475 self.range_witness(
476 self.lower_bound(prefix),
477 self.right_prefix_neighbor(prefix),
478 Node::witness_tree,
479 )
480 }
481
482 pub fn iter(&self) -> Iter<'_, K, V> {
484 match &self.root {
485 None => Iter {
486 visit: Visit::Pre,
487 parents: vec![],
488 },
489 Some(n) => Iter {
490 visit: Visit::Pre,
491 parents: vec![n],
492 },
493 }
494 }
495
496 pub fn for_each<'a, F>(&'a self, mut f: F)
498 where
499 F: 'a + FnMut(&'a [u8], &'a V),
500 {
501 Node::visit(&self.root, &mut f);
502 }
503
504 fn witness_range_above<'a>(
505 &'a self,
506 lo: KeyBound<'a>,
507 f: fn(&'a Node<K, V>) -> HashTree<'a>,
508 ) -> HashTree<'a> {
509 fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
510 n: &'a NodeRef<K, V>,
511 lo: KeyBound<'a>,
512 f: fn(&'a Node<K, V>) -> HashTree<'a>,
513 ) -> HashTree<'a> {
514 match n {
515 None => Empty,
516 Some(n) => match n.key.as_ref().cmp(lo.as_ref()) {
517 Equal => three_way_fork(
518 n.left_hash_tree(),
519 match lo {
520 KeyBound::Exact(_) => f(n),
521 KeyBound::Neighbor(_) => n.witness_tree(),
522 },
523 Node::full_witness_tree(&n.right, f),
524 ),
525 Less => three_way_fork(
526 n.left_hash_tree(),
527 Pruned(n.data_hash()),
528 go(&n.right, lo, f),
529 ),
530 Greater => three_way_fork(
531 go(&n.left, lo, f),
532 f(n),
533 Node::full_witness_tree(&n.right, f),
534 ),
535 },
536 }
537 }
538 go(&self.root, lo, f)
539 }
540
541 fn witness_range_below<'a>(
542 &'a self,
543 hi: KeyBound<'a>,
544 f: fn(&'a Node<K, V>) -> HashTree<'a>,
545 ) -> HashTree<'a> {
546 fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
547 n: &'a NodeRef<K, V>,
548 hi: KeyBound<'a>,
549 f: fn(&'a Node<K, V>) -> HashTree<'a>,
550 ) -> HashTree<'a> {
551 match n {
552 None => Empty,
553 Some(n) => match n.key.as_ref().cmp(hi.as_ref()) {
554 Equal => three_way_fork(
555 Node::full_witness_tree(&n.left, f),
556 match hi {
557 KeyBound::Exact(_) => f(n),
558 KeyBound::Neighbor(_) => n.witness_tree(),
559 },
560 n.right_hash_tree(),
561 ),
562 Greater => three_way_fork(
563 go(&n.left, hi, f),
564 Pruned(n.data_hash()),
565 n.right_hash_tree(),
566 ),
567 Less => three_way_fork(
568 Node::full_witness_tree(&n.left, f),
569 f(n),
570 go(&n.right, hi, f),
571 ),
572 },
573 }
574 }
575 go(&self.root, hi, f)
576 }
577
578 fn witness_range_between<'a>(
579 &'a self,
580 lo: KeyBound<'a>,
581 hi: KeyBound<'a>,
582 f: fn(&'a Node<K, V>) -> HashTree<'a>,
583 ) -> HashTree<'a> {
584 debug_assert!(
585 lo.as_ref() <= hi.as_ref(),
586 "lo = {:?} > hi = {:?}",
587 lo.as_ref(),
588 hi.as_ref()
589 );
590 fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
591 n: &'a NodeRef<K, V>,
592 lo: KeyBound<'a>,
593 hi: KeyBound<'a>,
594 f: fn(&'a Node<K, V>) -> HashTree<'a>,
595 ) -> HashTree<'a> {
596 match n {
597 None => Empty,
598 Some(n) => {
599 let k = n.key.as_ref();
600 match (lo.as_ref().cmp(k), k.cmp(hi.as_ref())) {
601 (Less, Less) => {
602 three_way_fork(go(&n.left, lo, hi, f), f(n), go(&n.right, lo, hi, f))
603 }
604 (Equal, Equal) => three_way_fork(
605 n.left_hash_tree(),
606 match (lo, hi) {
607 (KeyBound::Exact(_), _) => f(n),
608 (_, KeyBound::Exact(_)) => f(n),
609 _ => n.witness_tree(),
610 },
611 n.right_hash_tree(),
612 ),
613 (_, Equal) => three_way_fork(
614 go(&n.left, lo, hi, f),
615 match hi {
616 KeyBound::Exact(_) => f(n),
617 KeyBound::Neighbor(_) => n.witness_tree(),
618 },
619 n.right_hash_tree(),
620 ),
621 (Equal, _) => three_way_fork(
622 n.left_hash_tree(),
623 match lo {
624 KeyBound::Exact(_) => f(n),
625 KeyBound::Neighbor(_) => n.witness_tree(),
626 },
627 go(&n.right, lo, hi, f),
628 ),
629 (Less, Greater) => three_way_fork(
630 go(&n.left, lo, hi, f),
631 Pruned(n.data_hash()),
632 n.right_hash_tree(),
633 ),
634 (Greater, Less) => three_way_fork(
635 n.left_hash_tree(),
636 Pruned(n.data_hash()),
637 go(&n.right, lo, hi, f),
638 ),
639 _ => Pruned(n.subtree_hash),
640 }
641 }
642 }
643 }
644 go(&self.root, lo, hi, f)
645 }
646
647 fn lower_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
648 fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
649 n: &'a NodeRef<K, V>,
650 key: &[u8],
651 ) -> Option<KeyBound<'a>> {
652 n.as_ref().and_then(|n| {
653 let node_key = n.key.as_ref();
654 match node_key.cmp(key) {
655 Less => go(&n.right, key).or(Some(KeyBound::Neighbor(node_key))),
656 Equal => Some(KeyBound::Exact(node_key)),
657 Greater => go(&n.left, key),
658 }
659 })
660 }
661 go(&self.root, key)
662 }
663
664 fn upper_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
665 fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
666 n: &'a NodeRef<K, V>,
667 key: &[u8],
668 ) -> Option<KeyBound<'a>> {
669 n.as_ref().and_then(|n| {
670 let node_key = n.key.as_ref();
671 match node_key.cmp(key) {
672 Less => go(&n.right, key),
673 Equal => Some(KeyBound::Exact(node_key)),
674 Greater => go(&n.left, key).or(Some(KeyBound::Neighbor(node_key))),
675 }
676 })
677 }
678 go(&self.root, key)
679 }
680
681 fn right_prefix_neighbor(&self, prefix: &[u8]) -> Option<KeyBound<'_>> {
682 fn is_prefix_of(p: &[u8], x: &[u8]) -> bool {
683 if p.len() > x.len() {
684 return false;
685 }
686 &x[0..p.len()] == p
687 }
688 fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
689 n: &'a NodeRef<K, V>,
690 prefix: &[u8],
691 ) -> Option<KeyBound<'a>> {
692 n.as_ref().and_then(|n| {
693 let node_key = n.key.as_ref();
694 match node_key.cmp(prefix) {
695 Greater if is_prefix_of(prefix, node_key) => go(&n.right, prefix),
696 Greater => go(&n.left, prefix).or(Some(KeyBound::Neighbor(node_key))),
697 Less | Equal => go(&n.right, prefix),
698 }
699 })
700 }
701 go(&self.root, prefix)
702 }
703
704 fn lookup_and_build_witness<'a>(
705 &'a self,
706 key: &[u8],
707 f: impl FnOnce(&'a V) -> HashTree<'a>,
708 ) -> Option<HashTree<'a>> {
709 fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
710 n: &'a NodeRef<K, V>,
711 key: &[u8],
712 f: impl FnOnce(&'a V) -> HashTree<'a>,
713 ) -> Option<HashTree<'a>> {
714 n.as_ref().and_then(|n| match key.cmp(n.key.as_ref()) {
715 Equal => Some(three_way_fork(
716 n.left_hash_tree(),
717 n.subtree_with(f),
718 n.right_hash_tree(),
719 )),
720 Less => {
721 let subtree = go(&n.left, key, f)?;
722 Some(three_way_fork(
723 subtree,
724 Pruned(n.data_hash()),
725 n.right_hash_tree(),
726 ))
727 }
728 Greater => {
729 let subtree = go(&n.right, key, f)?;
730 Some(three_way_fork(
731 n.left_hash_tree(),
732 Pruned(n.data_hash()),
733 subtree,
734 ))
735 }
736 })
737 }
738 go(&self.root, key, f)
739 }
740
741 pub fn insert(&mut self, key: K, value: V) {
743 fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
744 h: NodeRef<K, V>,
745 k: K,
746 v: V,
747 ) -> Box<Node<K, V>> {
748 match h {
749 None => Node::new(k, v),
750 Some(mut h) => {
751 match k.as_ref().cmp(h.key.as_ref()) {
752 Equal => {
753 h.value = v;
754 }
755 Less => {
756 h.left = Some(go(h.left, k, v));
757 }
758 Greater => {
759 h.right = Some(go(h.right, k, v));
760 }
761 }
762 h.update_subtree_hash();
763 balance(h)
764 }
765 }
766 }
767 let mut root = go(self.root.take(), key, value);
768 root.color = Color::Black;
769 self.root = Some(root);
770
771 #[cfg(test)]
772 debug_assert!(
773 is_balanced(&self.root),
774 "the tree is not balanced:\n{:?}",
775 DebugView(&self.root)
776 );
777 }
778
779 pub fn delete(&mut self, key: &[u8]) {
781 fn move_red_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
782 mut h: Box<Node<K, V>>,
783 ) -> Box<Node<K, V>> {
784 flip_colors(&mut h);
785 if is_red(&h.right.as_ref().unwrap().left) {
786 h.right = Some(rotate_right(h.right.take().unwrap()));
787 h = rotate_left(h);
788 flip_colors(&mut h);
789 }
790 h
791 }
792
793 fn move_red_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
794 mut h: Box<Node<K, V>>,
795 ) -> Box<Node<K, V>> {
796 flip_colors(&mut h);
797 if is_red(&h.left.as_ref().unwrap().left) {
798 h = rotate_right(h);
799 flip_colors(&mut h);
800 }
801 h
802 }
803
804 #[inline]
805 fn min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
806 mut h: &mut Box<Node<K, V>>,
807 ) -> &mut Box<Node<K, V>> {
808 while h.left.is_some() {
809 h = h.left.as_mut().unwrap();
810 }
811 h
812 }
813
814 fn delete_min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
815 mut h: Box<Node<K, V>>,
816 ) -> NodeRef<K, V> {
817 if h.left.is_none() {
818 debug_assert!(h.right.is_none());
819 drop(h);
820 return None;
821 }
822 if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
823 h = move_red_left(h);
824 }
825 h.left = delete_min(h.left.unwrap());
826 h.update_subtree_hash();
827 Some(balance(h))
828 }
829
830 fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
831 mut h: Box<Node<K, V>>,
832 key: &[u8],
833 ) -> NodeRef<K, V> {
834 if key < h.key.as_ref() {
835 debug_assert!(h.left.is_some(), "the key must be present in the tree");
836 if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
837 h = move_red_left(h);
838 }
839 h.left = go(h.left.take().unwrap(), key);
840 } else {
841 if is_red(&h.left) {
842 h = rotate_right(h);
843 }
844 if key == h.key.as_ref() && h.right.is_none() {
845 debug_assert!(h.left.is_none());
846 drop(h);
847 return None;
848 }
849
850 if !is_red(&h.right) && !is_red(&h.right.as_ref().unwrap().left) {
851 h = move_red_right(h);
852 }
853
854 if key == h.key.as_ref() {
855 let m = min(h.right.as_mut().unwrap());
856 std::mem::swap(&mut h.key, &mut m.key);
857 std::mem::swap(&mut h.value, &mut m.value);
858 h.right = delete_min(h.right.take().unwrap());
859 } else {
860 h.right = go(h.right.take().unwrap(), key);
861 }
862 }
863 h.update_subtree_hash();
864 Some(balance(h))
865 }
866
867 if self.get(key).is_none() {
868 return;
869 }
870
871 if !is_red(&self.root.as_ref().unwrap().left) && !is_red(&self.root.as_ref().unwrap().right)
872 {
873 self.root.as_mut().unwrap().color = Color::Red;
874 }
875 self.root = go(self.root.take().unwrap(), key);
876 if let Some(n) = self.root.as_mut() {
877 n.color = Color::Black;
878 }
879
880 #[cfg(test)]
881 debug_assert!(
882 is_balanced(&self.root),
883 "unbalanced map: {:?}",
884 DebugView(&self.root)
885 );
886
887 debug_assert!(self.get(key).is_none());
888 }
889}
890
891impl<'a, K: AsRef<[u8]>, V: AsHashTree> IntoIterator for &'a RbTree<K, V> {
892 type Item = (&'a K, &'a V);
893 type IntoIter = Iter<'a, K, V>;
894
895 fn into_iter(self) -> Self::IntoIter {
896 self.iter()
897 }
898}
899
900use candid::CandidType;
901
902impl<K, V> CandidType for RbTree<K, V>
903where
904 K: CandidType + AsRef<[u8]>,
905 V: CandidType + AsHashTree,
906{
907 fn _ty() -> candid::types::internal::Type {
908 <Vec<(&K, &V)> as CandidType>::_ty()
909 }
910 fn idl_serialize<S: candid::types::Serializer>(&self, serializer: S) -> Result<(), S::Error> {
911 let collect_as_vec = self.iter().collect::<Vec<(&K, &V)>>();
912 <Vec<(&K, &V)> as CandidType>::idl_serialize(&collect_as_vec, serializer)
913 }
914}
915
916use serde::{
917 de::{Deserialize, Deserializer, MapAccess, Visitor},
918 ser::{Serialize, SerializeMap, Serializer},
919};
920use std::marker::PhantomData;
921
922impl<K, V> Serialize for RbTree<K, V>
923where
924 K: Serialize + AsRef<[u8]>,
925 V: Serialize + AsHashTree,
926{
927 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
928 where
929 S: Serializer,
930 {
931 let mut map = serializer.serialize_map(Some(self.iter().count()))?;
932 for (k, v) in self {
933 map.serialize_entry(k, v)?;
934 }
935 map.end()
936 }
937}
938
939struct RbTreeSerdeVisitor<K, V> {
941 marker: PhantomData<fn() -> RbTree<K, V>>,
942}
943
944impl<K, V> RbTreeSerdeVisitor<K, V> {
945 fn new() -> Self {
946 RbTreeSerdeVisitor {
947 marker: PhantomData,
948 }
949 }
950}
951
952impl<'de, K, V> Visitor<'de> for RbTreeSerdeVisitor<K, V>
953where
954 K: Deserialize<'de> + AsRef<[u8]>,
955 V: Deserialize<'de> + AsHashTree,
956{
957 type Value = RbTree<K, V>;
958
959 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
960 formatter.write_str("a map")
961 }
962
963 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
964 where
965 M: MapAccess<'de>,
966 {
967 let mut t = RbTree::<K, V>::new();
968 while let Some((key, value)) = access.next_entry()? {
969 t.insert(key, value);
970 }
971 Ok(t)
972 }
973}
974
975impl<'de, K, V> Deserialize<'de> for RbTree<K, V>
976where
977 K: Deserialize<'de> + AsRef<[u8]>,
978 V: Deserialize<'de> + AsHashTree,
979{
980 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
981 where
982 D: Deserializer<'de>,
983 {
984 deserializer.deserialize_map(RbTreeSerdeVisitor::new())
985 }
986}
987
988fn three_way_fork<'a>(l: HashTree<'a>, m: HashTree<'a>, r: HashTree<'a>) -> HashTree<'a> {
989 match (l, m, r) {
990 (Empty, m, Empty) => m,
991 (l, m, Empty) => fork(l, m),
992 (Empty, m, r) => fork(m, r),
993 (Pruned(lhash), Pruned(mhash), Pruned(rhash)) => {
994 Pruned(fork_hash(&lhash, &fork_hash(&mhash, &rhash)))
995 }
996 (l, Pruned(mhash), Pruned(rhash)) => fork(l, Pruned(fork_hash(&mhash, &rhash))),
997 (l, m, r) => fork(l, fork(m, r)),
998 }
999}
1000
1001fn is_red<K, V>(x: &NodeRef<K, V>) -> bool {
1003 x.as_ref().is_some_and(|h| h.color == Color::Red)
1004}
1005
/// Restores the left-leaning red-black invariants at `h` on the way back up from an
/// insertion or deletion, applying three local fixes in a fixed order.
///
/// `rotate_left` and `rotate_right` refresh the cached hashes of the nodes they move.
/// `flip_colors` changes only colors, which do not enter `compute_subtree_hash`.
fn balance<'t, K: AsRef<[u8]> + 't, V: AsHashTree + 't>(mut h: Box<Node<K, V>>) -> Box<Node<K, V>> {
    // A red link leaning right is rotated so that it leans left.
    if is_red(&h.right) && !is_red(&h.left) {
        h = rotate_left(h);
    }
    // Two red links in a row on the left are rotated into a node with two red children.
    if is_red(&h.left) && is_red(&h.left.as_ref().unwrap().left) {
        h = rotate_right(h);
    }
    // Both children red: recolor so the red link moves up to `h`.
    if is_red(&h.left) && is_red(&h.right) {
        flip_colors(&mut h);
    }
    h
}
1018
1019fn rotate_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
1021 mut h: Box<Node<K, V>>,
1022) -> Box<Node<K, V>> {
1023 debug_assert!(is_red(&h.left));
1024
1025 let mut x = h.left.take().unwrap();
1026 h.left = x.right.take();
1027 h.update_subtree_hash();
1028
1029 x.right = Some(h);
1030 x.color = x.right.as_ref().unwrap().color;
1031 x.right.as_mut().unwrap().color = Color::Red;
1032 x.update_subtree_hash();
1033
1034 x
1035}
1036
1037fn rotate_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
1038 mut h: Box<Node<K, V>>,
1039) -> Box<Node<K, V>> {
1040 debug_assert!(is_red(&h.right));
1041
1042 let mut x = h.right.take().unwrap();
1043 h.right = x.left.take();
1044 h.update_subtree_hash();
1045
1046 x.left = Some(h);
1047 x.color = x.left.as_ref().unwrap().color;
1048 x.left.as_mut().unwrap().color = Color::Red;
1049 x.update_subtree_hash();
1050
1051 x
1052}
1053
1054fn flip_colors<K, V>(h: &mut Box<Node<K, V>>) {
1055 h.color.flip_assign();
1056 h.left.as_mut().unwrap().color.flip_assign();
1057 h.right.as_mut().unwrap().color.flip_assign();
1058}
1059
/// Test-only check of the red-black invariants of the tree rooted at `root`.
///
/// Returns `false` if some root-to-leaf path has a different number of black nodes than
/// the leftmost path. Panics (via `assert!`) if a red node has a red child.
#[cfg(test)]
fn is_balanced<K, V>(root: &NodeRef<K, V>) -> bool {
    // Checks that every path below `node` contains exactly `blacks_left` black nodes.
    fn check<K, V>(node: &NodeRef<K, V>, blacks_left: usize) -> bool {
        let Some(n) = node else {
            return blacks_left == 0;
        };
        let remaining = if is_red(node) {
            assert!(!is_red(&n.left));
            assert!(!is_red(&n.right));
            blacks_left
        } else {
            debug_assert!(blacks_left > 0);
            blacks_left - 1
        };
        check(&n.left, remaining) && check(&n.right, remaining)
    }

    // Reference black height, measured along the leftmost path.
    let black_height = std::iter::successors(root.as_ref(), |n| n.left.as_ref())
        .filter(|n| n.color != Color::Red)
        .count();
    check(root, black_height)
}
1088
1089#[allow(dead_code)]
1090struct DebugView<'a, K, V>(&'a NodeRef<K, V>);
1091
1092impl<K: AsRef<[u8]>, V> fmt::Debug for DebugView<'_, K, V> {
1093 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1094 fn go<K: AsRef<[u8]>, V>(
1095 f: &mut fmt::Formatter<'_>,
1096 node: &NodeRef<K, V>,
1097 offset: usize,
1098 ) -> fmt::Result {
1099 match node {
1100 None => writeln!(f, "{:width$}[B] <null>", "", width = offset),
1101 Some(h) => {
1102 writeln!(
1103 f,
1104 "{:width$}[{}] {:?}",
1105 "",
1106 if is_red(node) { "R" } else { "B" },
1107 h.key.as_ref(),
1108 width = offset
1109 )?;
1110 go(f, &h.left, offset + 2)?;
1111 go(f, &h.right, offset + 2)
1112 }
1113 }
1114 }
1115 go(f, self.0, 0)
1116 }
1117}
1118
// Unit tests for this module live in the `test` submodule, compiled only for test builds.
#[cfg(test)]
mod test;