ic_certified_map/
rbtree.rs

1use crate::hashtree::{
2    Hash,
3    HashTree::{self, Empty, Leaf, Pruned},
4    fork, fork_hash, labeled, labeled_hash, leaf_hash,
5};
6use std::borrow::Cow;
7use std::cmp::Ordering::{self, Equal, Greater, Less};
8use std::fmt;
9
/// Node color in the left-leaning red-black tree.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Color {
    Red,
    Black,
}

impl Color {
    /// Toggles this color in place (red <-> black).
    fn flip_assign(&mut self) {
        let flipped = self.flip();
        *self = flipped;
    }

    /// Returns the opposite color.
    fn flip(self) -> Self {
        if self == Self::Black {
            Self::Red
        } else {
            Self::Black
        }
    }
}
28
29/// Types that can be converted into a [`HashTree`].
30pub trait AsHashTree {
31    /// Returns the root hash of the tree without constructing it.
32    /// Must be equivalent to `as_hash_tree().reconstruct()`.
33    fn root_hash(&self) -> Hash;
34
35    /// Constructs a hash tree corresponding to the data.
36    fn as_hash_tree(&self) -> HashTree<'_>;
37}
38
39impl AsHashTree for Vec<u8> {
40    fn root_hash(&self) -> Hash {
41        leaf_hash(&self[..])
42    }
43
44    fn as_hash_tree(&self) -> HashTree<'_> {
45        Leaf(Cow::from(&self[..]))
46    }
47}
48
49impl AsHashTree for Hash {
50    fn root_hash(&self) -> Hash {
51        leaf_hash(&self[..])
52    }
53
54    fn as_hash_tree(&self) -> HashTree<'_> {
55        Leaf(Cow::from(&self[..]))
56    }
57}
58
59impl<K: AsRef<[u8]>, V: AsHashTree> AsHashTree for RbTree<K, V> {
60    fn root_hash(&self) -> Hash {
61        match self.root.as_ref() {
62            None => Empty.reconstruct(),
63            Some(n) => n.subtree_hash,
64        }
65    }
66
67    fn as_hash_tree(&self) -> HashTree<'_> {
68        Node::full_witness_tree(&self.root, Node::data_tree)
69    }
70}
71
/// One end of a key range used when building range witnesses.
#[derive(Clone, Copy, Debug, PartialEq)]
enum KeyBound<'a> {
    /// The bound is a key that is present in the tree.
    Exact(&'a [u8]),
    /// The bound is the nearest key present in the tree, used when the
    /// requested key itself is absent.
    Neighbor(&'a [u8]),
}

impl<'a> AsRef<[u8]> for KeyBound<'a> {
    /// Returns the key bytes regardless of the bound kind.
    fn as_ref(&self) -> &'a [u8] {
        let (KeyBound::Exact(bytes) | KeyBound::Neighbor(bytes)) = *self;
        bytes
    }
}
86
87type NodeRef<K, V> = Option<Box<Node<K, V>>>;
88
89// 1. All leaves are black.
90// 2. Children of a red node are black.
91// 3. Every path from a node goes through the same number of black
92//    nodes.
93#[derive(Clone, Debug)]
94struct Node<K, V> {
95    key: K,
96    value: V,
97    left: NodeRef<K, V>,
98    right: NodeRef<K, V>,
99    color: Color,
100
101    /// Hash of the full hash tree built from this node and its
102    /// children. It needs to be recomputed after every rotation.
103    subtree_hash: Hash,
104}
105
106impl<K: AsRef<[u8]>, V: AsHashTree> Node<K, V> {
107    fn new(key: K, value: V) -> Box<Node<K, V>> {
108        let value_hash = value.root_hash();
109        let data_hash = labeled_hash(key.as_ref(), &value_hash);
110        Box::new(Self {
111            key,
112            value,
113            left: None,
114            right: None,
115            color: Color::Red,
116            subtree_hash: data_hash,
117        })
118    }
119
120    fn data_hash(&self) -> Hash {
121        labeled_hash(self.key.as_ref(), &self.value.root_hash())
122    }
123
124    fn left_hash_tree(&self) -> HashTree<'_> {
125        match self.left.as_ref() {
126            None => Empty,
127            Some(l) => Pruned(l.subtree_hash),
128        }
129    }
130
131    fn right_hash_tree(&self) -> HashTree<'_> {
132        match self.right.as_ref() {
133            None => Empty,
134            Some(r) => Pruned(r.subtree_hash),
135        }
136    }
137
138    fn visit<'a, F>(n: &'a NodeRef<K, V>, f: &mut F)
139    where
140        F: 'a + FnMut(&'a [u8], &'a V),
141    {
142        if let Some(n) = n {
143            Self::visit(&n.left, f);
144            (*f)(n.key.as_ref(), &n.value);
145            Self::visit(&n.right, f);
146        }
147    }
148
149    fn data_tree(&self) -> HashTree<'_> {
150        labeled(self.key.as_ref(), self.value.as_hash_tree())
151    }
152
153    fn subtree_with<'a>(&'a self, f: impl FnOnce(&'a V) -> HashTree<'a>) -> HashTree<'a> {
154        labeled(self.key.as_ref(), f(&self.value))
155    }
156
157    fn witness_tree(&self) -> HashTree<'_> {
158        labeled(self.key.as_ref(), Pruned(self.value.root_hash()))
159    }
160
161    fn full_witness_tree<'a>(
162        n: &'a NodeRef<K, V>,
163        f: fn(&'a Node<K, V>) -> HashTree<'a>,
164    ) -> HashTree<'a> {
165        match n {
166            None => Empty,
167            Some(n) => three_way_fork(
168                Self::full_witness_tree(&n.left, f),
169                f(n),
170                Self::full_witness_tree(&n.right, f),
171            ),
172        }
173    }
174
175    fn update_subtree_hash(&mut self) {
176        self.subtree_hash = self.compute_subtree_hash();
177    }
178
179    fn compute_subtree_hash(&self) -> Hash {
180        let h = self.data_hash();
181
182        match (self.left.as_ref(), self.right.as_ref()) {
183            (None, None) => h,
184            (Some(l), None) => fork_hash(&l.subtree_hash, &h),
185            (None, Some(r)) => fork_hash(&h, &r.subtree_hash),
186            (Some(l), Some(r)) => fork_hash(&l.subtree_hash, &fork_hash(&h, &r.subtree_hash)),
187        }
188    }
189}
190
/// Traversal phase of the node on top of the iterator's stack.
#[derive(Debug, PartialEq)]
enum Visit {
    /// About to descend into the left subtree.
    Pre,
    /// Left subtree finished; the node itself is the current item.
    In,
    /// Node and both subtrees finished; ready to return to the parent.
    Post,
}
197
198/// Iterator over a `RbTree`.
199#[derive(Debug)]
200pub struct Iter<'a, K, V> {
201    /// Invariants:
202    /// 1. visit == Pre: none of the nodes in parents were visited yet.
203    /// 2. visit == In:  the last node in parents and all its left children are visited.
204    /// 3. visit == Post: all the nodes reachable from the last node in parents are visited.
205    visit: Visit,
206    parents: Vec<&'a Node<K, V>>,
207}
208
209impl<K, V> Iter<'_, K, V> {
210    /// This function is an adaptation of the `traverse_step` procedure described in
211    /// section 7.2 "Bidirectional Bifurcate Coordinates" of
212    /// "Elements of Programming" by A. Stepanov and P. McJones, p. 118.
213    /// <http://elementsofprogramming.com/eop.pdf>
214    ///
215    /// The main difference is that our nodes don't have parent links for two reasons:
216    /// 1. They don't play well with safe Rust ownership model.
217    /// 2. Iterating a tree shouldn't be an operation common enough to complicate the code.
218    fn step(&mut self) -> bool {
219        match self.parents.last() {
220            Some(tip) => {
221                match self.visit {
222                    Visit::Pre => {
223                        if let Some(l) = &tip.left {
224                            self.parents.push(l);
225                        } else {
226                            self.visit = Visit::In;
227                        }
228                    }
229                    Visit::In => {
230                        if let Some(r) = &tip.right {
231                            self.parents.push(r);
232                            self.visit = Visit::Pre;
233                        } else {
234                            self.visit = Visit::Post;
235                        }
236                    }
237                    Visit::Post => {
238                        let tip = self.parents.pop().unwrap();
239                        if let Some(parent) = self.parents.last() {
240                            if parent
241                                .left
242                                .as_ref()
243                                .is_some_and(|l| std::ptr::eq(l.as_ref(), tip))
244                            {
245                                self.visit = Visit::In;
246                            }
247                        }
248                    }
249                }
250                true
251            }
252            None => false,
253        }
254    }
255}
256
257impl<'a, K, V> std::iter::Iterator for Iter<'a, K, V> {
258    type Item = (&'a K, &'a V);
259
260    fn next(&mut self) -> Option<Self::Item> {
261        while self.step() {
262            if self.visit == Visit::In {
263                return self.parents.last().map(|n| (&n.key, &n.value));
264            }
265        }
266        None
267    }
268}
269
270/// Implements mutable left-leaning red-black trees as defined in
271/// <https://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf>
272#[derive(Default, Clone)]
273pub struct RbTree<K, V> {
274    root: NodeRef<K, V>,
275}
276
277impl<K, V> PartialEq for RbTree<K, V>
278where
279    K: AsRef<[u8]> + PartialEq,
280    V: AsHashTree + PartialEq,
281{
282    fn eq(&self, other: &Self) -> bool {
283        self.iter().eq(other.iter())
284    }
285}
286
287impl<K, V> Eq for RbTree<K, V>
288where
289    K: AsRef<[u8]> + Eq,
290    V: AsHashTree + Eq,
291{
292}
293
294impl<K, V> PartialOrd for RbTree<K, V>
295where
296    K: AsRef<[u8]> + PartialOrd,
297    V: AsHashTree + PartialOrd,
298{
299    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
300        self.iter().partial_cmp(other.iter())
301    }
302}
303
304impl<K, V> Ord for RbTree<K, V>
305where
306    K: AsRef<[u8]> + Ord,
307    V: AsHashTree + Ord,
308{
309    fn cmp(&self, other: &Self) -> Ordering {
310        self.iter().cmp(other.iter())
311    }
312}
313
314impl<K, V> std::iter::FromIterator<(K, V)> for RbTree<K, V>
315where
316    K: AsRef<[u8]>,
317    V: AsHashTree,
318{
319    fn from_iter<T>(iter: T) -> Self
320    where
321        T: IntoIterator<Item = (K, V)>,
322    {
323        let mut t = RbTree::<K, V>::new();
324        for (k, v) in iter {
325            t.insert(k, v);
326        }
327        t
328    }
329}
330
331impl<K, V> std::fmt::Debug for RbTree<K, V>
332where
333    K: AsRef<[u8]> + std::fmt::Debug,
334    V: AsHashTree + std::fmt::Debug,
335{
336    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337        write!(f, "[")?;
338        let mut first = true;
339        for (k, v) in self {
340            if !first {
341                write!(f, ", ")?;
342            }
343            first = false;
344            write!(f, "({k:?}, {v:?})")?;
345        }
346        write!(f, "]")
347    }
348}
349
350impl<K, V> RbTree<K, V> {
351    /// Constructs a new empty tree.
352    pub const fn new() -> Self {
353        Self { root: None }
354    }
355
356    /// Returns true if the map is empty.
357    pub const fn is_empty(&self) -> bool {
358        self.root.is_none()
359    }
360}
361
362impl<K: AsRef<[u8]>, V: AsHashTree> RbTree<K, V> {
363    /// Looks up the key in the map and returns the associated value, if there is one.
364    pub fn get(&self, key: &[u8]) -> Option<&V> {
365        let mut root = self.root.as_ref();
366        while let Some(n) = root {
367            match key.cmp(n.key.as_ref()) {
368                Equal => return Some(&n.value),
369                Less => root = n.left.as_ref(),
370                Greater => root = n.right.as_ref(),
371            }
372        }
373        None
374    }
375
376    /// Updates the value corresponding to the specified key.
377    pub fn modify(&mut self, key: &[u8], f: impl FnOnce(&mut V)) {
378        fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
379            h: &mut NodeRef<K, V>,
380            k: &[u8],
381            f: impl FnOnce(&mut V),
382        ) {
383            if let Some(h) = h {
384                match k.as_ref().cmp(h.key.as_ref()) {
385                    Equal => {
386                        f(&mut h.value);
387                        h.update_subtree_hash();
388                    }
389                    Less => {
390                        go(&mut h.left, k, f);
391                        h.update_subtree_hash();
392                    }
393                    Greater => {
394                        go(&mut h.right, k, f);
395                        h.update_subtree_hash();
396                    }
397                }
398            }
399        }
400        go(&mut self.root, key, f);
401    }
402
403    fn range_witness<'a>(
404        &'a self,
405        left: Option<KeyBound<'a>>,
406        right: Option<KeyBound<'a>>,
407        f: fn(&'a Node<K, V>) -> HashTree<'a>,
408    ) -> HashTree<'a> {
409        match (left, right) {
410            (None, None) => Node::full_witness_tree(&self.root, f),
411            (Some(l), None) => self.witness_range_above(l, f),
412            (None, Some(r)) => self.witness_range_below(r, f),
413            (Some(l), Some(r)) => self.witness_range_between(l, r, f),
414        }
415    }
416
417    /// Constructs a hash tree that acts as a proof that there is a
418    /// entry with the specified key in this map.  The proof also
419    /// contains the value in question.
420    ///
421    /// If the key is not in the map, returns a proof of absence.
422    pub fn witness<'a>(&'a self, key: &[u8]) -> HashTree<'a> {
423        self.nested_witness(key, |v| v.as_hash_tree())
424    }
425
426    /// Like `witness`, but gives the caller more control over the
427    /// construction of the value witness.  This method is useful for
428    /// constructing witnesses for nested certified maps.
429    pub fn nested_witness<'a>(
430        &'a self,
431        key: &[u8],
432        f: impl FnOnce(&'a V) -> HashTree<'a>,
433    ) -> HashTree<'a> {
434        if let Some(t) = self.lookup_and_build_witness(key, f) {
435            return t;
436        }
437        self.range_witness(
438            self.lower_bound(key),
439            self.upper_bound(key),
440            Node::witness_tree,
441        )
442    }
443
444    /// Returns a witness enumerating all the keys in this map.  The
445    /// resulting tree doesn't include values, they are replaced with
446    /// "Pruned" nodes.
447    pub fn keys(&self) -> HashTree<'_> {
448        Node::full_witness_tree(&self.root, Node::witness_tree)
449    }
450
451    /// Returns a witness for the keys in the specified range.  The
452    /// resulting tree doesn't include values, they are replaced with
453    /// "Pruned" nodes.
454    pub fn key_range(&self, first: &[u8], last: &[u8]) -> HashTree<'_> {
455        self.range_witness(
456            self.lower_bound(first),
457            self.upper_bound(last),
458            Node::witness_tree,
459        )
460    }
461
462    /// Returns a witness for the key-value pairs in the specified range.
463    /// The resulting tree contains both keys and values.
464    pub fn value_range(&self, first: &[u8], last: &[u8]) -> HashTree<'_> {
465        self.range_witness(
466            self.lower_bound(first),
467            self.upper_bound(last),
468            Node::data_tree,
469        )
470    }
471
472    /// Returns a witness that enumerates all the keys starting with
473    /// the specified prefix.
474    pub fn keys_with_prefix(&self, prefix: &[u8]) -> HashTree<'_> {
475        self.range_witness(
476            self.lower_bound(prefix),
477            self.right_prefix_neighbor(prefix),
478            Node::witness_tree,
479        )
480    }
481
482    /// Creates an iterator over the map's keys and values.
483    pub fn iter(&self) -> Iter<'_, K, V> {
484        match &self.root {
485            None => Iter {
486                visit: Visit::Pre,
487                parents: vec![],
488            },
489            Some(n) => Iter {
490                visit: Visit::Pre,
491                parents: vec![n],
492            },
493        }
494    }
495
496    /// Enumerates all the key-value pairs in the tree.
497    pub fn for_each<'a, F>(&'a self, mut f: F)
498    where
499        F: 'a + FnMut(&'a [u8], &'a V),
500    {
501        Node::visit(&self.root, &mut f);
502    }
503
504    fn witness_range_above<'a>(
505        &'a self,
506        lo: KeyBound<'a>,
507        f: fn(&'a Node<K, V>) -> HashTree<'a>,
508    ) -> HashTree<'a> {
509        fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
510            n: &'a NodeRef<K, V>,
511            lo: KeyBound<'a>,
512            f: fn(&'a Node<K, V>) -> HashTree<'a>,
513        ) -> HashTree<'a> {
514            match n {
515                None => Empty,
516                Some(n) => match n.key.as_ref().cmp(lo.as_ref()) {
517                    Equal => three_way_fork(
518                        n.left_hash_tree(),
519                        match lo {
520                            KeyBound::Exact(_) => f(n),
521                            KeyBound::Neighbor(_) => n.witness_tree(),
522                        },
523                        Node::full_witness_tree(&n.right, f),
524                    ),
525                    Less => three_way_fork(
526                        n.left_hash_tree(),
527                        Pruned(n.data_hash()),
528                        go(&n.right, lo, f),
529                    ),
530                    Greater => three_way_fork(
531                        go(&n.left, lo, f),
532                        f(n),
533                        Node::full_witness_tree(&n.right, f),
534                    ),
535                },
536            }
537        }
538        go(&self.root, lo, f)
539    }
540
541    fn witness_range_below<'a>(
542        &'a self,
543        hi: KeyBound<'a>,
544        f: fn(&'a Node<K, V>) -> HashTree<'a>,
545    ) -> HashTree<'a> {
546        fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
547            n: &'a NodeRef<K, V>,
548            hi: KeyBound<'a>,
549            f: fn(&'a Node<K, V>) -> HashTree<'a>,
550        ) -> HashTree<'a> {
551            match n {
552                None => Empty,
553                Some(n) => match n.key.as_ref().cmp(hi.as_ref()) {
554                    Equal => three_way_fork(
555                        Node::full_witness_tree(&n.left, f),
556                        match hi {
557                            KeyBound::Exact(_) => f(n),
558                            KeyBound::Neighbor(_) => n.witness_tree(),
559                        },
560                        n.right_hash_tree(),
561                    ),
562                    Greater => three_way_fork(
563                        go(&n.left, hi, f),
564                        Pruned(n.data_hash()),
565                        n.right_hash_tree(),
566                    ),
567                    Less => three_way_fork(
568                        Node::full_witness_tree(&n.left, f),
569                        f(n),
570                        go(&n.right, hi, f),
571                    ),
572                },
573            }
574        }
575        go(&self.root, hi, f)
576    }
577
578    fn witness_range_between<'a>(
579        &'a self,
580        lo: KeyBound<'a>,
581        hi: KeyBound<'a>,
582        f: fn(&'a Node<K, V>) -> HashTree<'a>,
583    ) -> HashTree<'a> {
584        debug_assert!(
585            lo.as_ref() <= hi.as_ref(),
586            "lo = {:?} > hi = {:?}",
587            lo.as_ref(),
588            hi.as_ref()
589        );
590        fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
591            n: &'a NodeRef<K, V>,
592            lo: KeyBound<'a>,
593            hi: KeyBound<'a>,
594            f: fn(&'a Node<K, V>) -> HashTree<'a>,
595        ) -> HashTree<'a> {
596            match n {
597                None => Empty,
598                Some(n) => {
599                    let k = n.key.as_ref();
600                    match (lo.as_ref().cmp(k), k.cmp(hi.as_ref())) {
601                        (Less, Less) => {
602                            three_way_fork(go(&n.left, lo, hi, f), f(n), go(&n.right, lo, hi, f))
603                        }
604                        (Equal, Equal) => three_way_fork(
605                            n.left_hash_tree(),
606                            match (lo, hi) {
607                                (KeyBound::Exact(_), _) => f(n),
608                                (_, KeyBound::Exact(_)) => f(n),
609                                _ => n.witness_tree(),
610                            },
611                            n.right_hash_tree(),
612                        ),
613                        (_, Equal) => three_way_fork(
614                            go(&n.left, lo, hi, f),
615                            match hi {
616                                KeyBound::Exact(_) => f(n),
617                                KeyBound::Neighbor(_) => n.witness_tree(),
618                            },
619                            n.right_hash_tree(),
620                        ),
621                        (Equal, _) => three_way_fork(
622                            n.left_hash_tree(),
623                            match lo {
624                                KeyBound::Exact(_) => f(n),
625                                KeyBound::Neighbor(_) => n.witness_tree(),
626                            },
627                            go(&n.right, lo, hi, f),
628                        ),
629                        (Less, Greater) => three_way_fork(
630                            go(&n.left, lo, hi, f),
631                            Pruned(n.data_hash()),
632                            n.right_hash_tree(),
633                        ),
634                        (Greater, Less) => three_way_fork(
635                            n.left_hash_tree(),
636                            Pruned(n.data_hash()),
637                            go(&n.right, lo, hi, f),
638                        ),
639                        _ => Pruned(n.subtree_hash),
640                    }
641                }
642            }
643        }
644        go(&self.root, lo, hi, f)
645    }
646
647    fn lower_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
648        fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
649            n: &'a NodeRef<K, V>,
650            key: &[u8],
651        ) -> Option<KeyBound<'a>> {
652            n.as_ref().and_then(|n| {
653                let node_key = n.key.as_ref();
654                match node_key.cmp(key) {
655                    Less => go(&n.right, key).or(Some(KeyBound::Neighbor(node_key))),
656                    Equal => Some(KeyBound::Exact(node_key)),
657                    Greater => go(&n.left, key),
658                }
659            })
660        }
661        go(&self.root, key)
662    }
663
664    fn upper_bound(&self, key: &[u8]) -> Option<KeyBound<'_>> {
665        fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
666            n: &'a NodeRef<K, V>,
667            key: &[u8],
668        ) -> Option<KeyBound<'a>> {
669            n.as_ref().and_then(|n| {
670                let node_key = n.key.as_ref();
671                match node_key.cmp(key) {
672                    Less => go(&n.right, key),
673                    Equal => Some(KeyBound::Exact(node_key)),
674                    Greater => go(&n.left, key).or(Some(KeyBound::Neighbor(node_key))),
675                }
676            })
677        }
678        go(&self.root, key)
679    }
680
681    fn right_prefix_neighbor(&self, prefix: &[u8]) -> Option<KeyBound<'_>> {
682        fn is_prefix_of(p: &[u8], x: &[u8]) -> bool {
683            if p.len() > x.len() {
684                return false;
685            }
686            &x[0..p.len()] == p
687        }
688        fn go<'a, 't, K: 't + AsRef<[u8]>, V>(
689            n: &'a NodeRef<K, V>,
690            prefix: &[u8],
691        ) -> Option<KeyBound<'a>> {
692            n.as_ref().and_then(|n| {
693                let node_key = n.key.as_ref();
694                match node_key.cmp(prefix) {
695                    Greater if is_prefix_of(prefix, node_key) => go(&n.right, prefix),
696                    Greater => go(&n.left, prefix).or(Some(KeyBound::Neighbor(node_key))),
697                    Less | Equal => go(&n.right, prefix),
698                }
699            })
700        }
701        go(&self.root, prefix)
702    }
703
704    fn lookup_and_build_witness<'a>(
705        &'a self,
706        key: &[u8],
707        f: impl FnOnce(&'a V) -> HashTree<'a>,
708    ) -> Option<HashTree<'a>> {
709        fn go<'a, 't, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
710            n: &'a NodeRef<K, V>,
711            key: &[u8],
712            f: impl FnOnce(&'a V) -> HashTree<'a>,
713        ) -> Option<HashTree<'a>> {
714            n.as_ref().and_then(|n| match key.cmp(n.key.as_ref()) {
715                Equal => Some(three_way_fork(
716                    n.left_hash_tree(),
717                    n.subtree_with(f),
718                    n.right_hash_tree(),
719                )),
720                Less => {
721                    let subtree = go(&n.left, key, f)?;
722                    Some(three_way_fork(
723                        subtree,
724                        Pruned(n.data_hash()),
725                        n.right_hash_tree(),
726                    ))
727                }
728                Greater => {
729                    let subtree = go(&n.right, key, f)?;
730                    Some(three_way_fork(
731                        n.left_hash_tree(),
732                        Pruned(n.data_hash()),
733                        subtree,
734                    ))
735                }
736            })
737        }
738        go(&self.root, key, f)
739    }
740
741    /// Inserts a key-value entry into the map.
742    pub fn insert(&mut self, key: K, value: V) {
743        fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
744            h: NodeRef<K, V>,
745            k: K,
746            v: V,
747        ) -> Box<Node<K, V>> {
748            match h {
749                None => Node::new(k, v),
750                Some(mut h) => {
751                    match k.as_ref().cmp(h.key.as_ref()) {
752                        Equal => {
753                            h.value = v;
754                        }
755                        Less => {
756                            h.left = Some(go(h.left, k, v));
757                        }
758                        Greater => {
759                            h.right = Some(go(h.right, k, v));
760                        }
761                    }
762                    h.update_subtree_hash();
763                    balance(h)
764                }
765            }
766        }
767        let mut root = go(self.root.take(), key, value);
768        root.color = Color::Black;
769        self.root = Some(root);
770
771        #[cfg(test)]
772        debug_assert!(
773            is_balanced(&self.root),
774            "the tree is not balanced:\n{:?}",
775            DebugView(&self.root)
776        );
777    }
778
779    /// Removes the specified key from the map.
780    pub fn delete(&mut self, key: &[u8]) {
781        fn move_red_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
782            mut h: Box<Node<K, V>>,
783        ) -> Box<Node<K, V>> {
784            flip_colors(&mut h);
785            if is_red(&h.right.as_ref().unwrap().left) {
786                h.right = Some(rotate_right(h.right.take().unwrap()));
787                h = rotate_left(h);
788                flip_colors(&mut h);
789            }
790            h
791        }
792
793        fn move_red_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
794            mut h: Box<Node<K, V>>,
795        ) -> Box<Node<K, V>> {
796            flip_colors(&mut h);
797            if is_red(&h.left.as_ref().unwrap().left) {
798                h = rotate_right(h);
799                flip_colors(&mut h);
800            }
801            h
802        }
803
804        #[inline]
805        fn min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
806            mut h: &mut Box<Node<K, V>>,
807        ) -> &mut Box<Node<K, V>> {
808            while h.left.is_some() {
809                h = h.left.as_mut().unwrap();
810            }
811            h
812        }
813
814        fn delete_min<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
815            mut h: Box<Node<K, V>>,
816        ) -> NodeRef<K, V> {
817            if h.left.is_none() {
818                debug_assert!(h.right.is_none());
819                drop(h);
820                return None;
821            }
822            if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
823                h = move_red_left(h);
824            }
825            h.left = delete_min(h.left.unwrap());
826            h.update_subtree_hash();
827            Some(balance(h))
828        }
829
830        fn go<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
831            mut h: Box<Node<K, V>>,
832            key: &[u8],
833        ) -> NodeRef<K, V> {
834            if key < h.key.as_ref() {
835                debug_assert!(h.left.is_some(), "the key must be present in the tree");
836                if !is_red(&h.left) && !is_red(&h.left.as_ref().unwrap().left) {
837                    h = move_red_left(h);
838                }
839                h.left = go(h.left.take().unwrap(), key);
840            } else {
841                if is_red(&h.left) {
842                    h = rotate_right(h);
843                }
844                if key == h.key.as_ref() && h.right.is_none() {
845                    debug_assert!(h.left.is_none());
846                    drop(h);
847                    return None;
848                }
849
850                if !is_red(&h.right) && !is_red(&h.right.as_ref().unwrap().left) {
851                    h = move_red_right(h);
852                }
853
854                if key == h.key.as_ref() {
855                    let m = min(h.right.as_mut().unwrap());
856                    std::mem::swap(&mut h.key, &mut m.key);
857                    std::mem::swap(&mut h.value, &mut m.value);
858                    h.right = delete_min(h.right.take().unwrap());
859                } else {
860                    h.right = go(h.right.take().unwrap(), key);
861                }
862            }
863            h.update_subtree_hash();
864            Some(balance(h))
865        }
866
867        if self.get(key).is_none() {
868            return;
869        }
870
871        if !is_red(&self.root.as_ref().unwrap().left) && !is_red(&self.root.as_ref().unwrap().right)
872        {
873            self.root.as_mut().unwrap().color = Color::Red;
874        }
875        self.root = go(self.root.take().unwrap(), key);
876        if let Some(n) = self.root.as_mut() {
877            n.color = Color::Black;
878        }
879
880        #[cfg(test)]
881        debug_assert!(
882            is_balanced(&self.root),
883            "unbalanced map: {:?}",
884            DebugView(&self.root)
885        );
886
887        debug_assert!(self.get(key).is_none());
888    }
889}
890
891impl<'a, K: AsRef<[u8]>, V: AsHashTree> IntoIterator for &'a RbTree<K, V> {
892    type Item = (&'a K, &'a V);
893    type IntoIter = Iter<'a, K, V>;
894
895    fn into_iter(self) -> Self::IntoIter {
896        self.iter()
897    }
898}
899
900use candid::CandidType;
901
902impl<K, V> CandidType for RbTree<K, V>
903where
904    K: CandidType + AsRef<[u8]>,
905    V: CandidType + AsHashTree,
906{
907    fn _ty() -> candid::types::internal::Type {
908        <Vec<(&K, &V)> as CandidType>::_ty()
909    }
910    fn idl_serialize<S: candid::types::Serializer>(&self, serializer: S) -> Result<(), S::Error> {
911        let collect_as_vec = self.iter().collect::<Vec<(&K, &V)>>();
912        <Vec<(&K, &V)> as CandidType>::idl_serialize(&collect_as_vec, serializer)
913    }
914}
915
916use serde::{
917    de::{Deserialize, Deserializer, MapAccess, Visitor},
918    ser::{Serialize, SerializeMap, Serializer},
919};
920use std::marker::PhantomData;
921
922impl<K, V> Serialize for RbTree<K, V>
923where
924    K: Serialize + AsRef<[u8]>,
925    V: Serialize + AsHashTree,
926{
927    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
928    where
929        S: Serializer,
930    {
931        let mut map = serializer.serialize_map(Some(self.iter().count()))?;
932        for (k, v) in self {
933            map.serialize_entry(k, v)?;
934        }
935        map.end()
936    }
937}
938
939// The PhantomData keeps the compiler from complaining about unused generic type parameters.
940struct RbTreeSerdeVisitor<K, V> {
941    marker: PhantomData<fn() -> RbTree<K, V>>,
942}
943
944impl<K, V> RbTreeSerdeVisitor<K, V> {
945    fn new() -> Self {
946        RbTreeSerdeVisitor {
947            marker: PhantomData,
948        }
949    }
950}
951
952impl<'de, K, V> Visitor<'de> for RbTreeSerdeVisitor<K, V>
953where
954    K: Deserialize<'de> + AsRef<[u8]>,
955    V: Deserialize<'de> + AsHashTree,
956{
957    type Value = RbTree<K, V>;
958
959    fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
960        formatter.write_str("a map")
961    }
962
963    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
964    where
965        M: MapAccess<'de>,
966    {
967        let mut t = RbTree::<K, V>::new();
968        while let Some((key, value)) = access.next_entry()? {
969            t.insert(key, value);
970        }
971        Ok(t)
972    }
973}
974
975impl<'de, K, V> Deserialize<'de> for RbTree<K, V>
976where
977    K: Deserialize<'de> + AsRef<[u8]>,
978    V: Deserialize<'de> + AsHashTree,
979{
980    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
981    where
982        D: Deserializer<'de>,
983    {
984        deserializer.deserialize_map(RbTreeSerdeVisitor::new())
985    }
986}
987
/// Joins three subtrees `l`, `m`, `r` into one hash tree of the general shape
/// `fork(l, fork(m, r))`.
///
/// - An `Empty` `l` and/or `r` is dropped rather than forked in; with both
///   empty, `m` is returned unchanged.
/// - When all three are `Pruned`, the result is a single `Pruned` node whose
///   hash is `fork_hash(l, fork_hash(m, r))`.
/// - When only `m` and `r` are `Pruned`, their fork is collapsed into one
///   `Pruned` hash under the outer fork.
///
/// Arm order is significant. The `Empty` arms come before the `Pruned` ones,
/// so an empty side is dropped before any hash-collapsing shortcut applies.
fn three_way_fork<'a>(l: HashTree<'a>, m: HashTree<'a>, r: HashTree<'a>) -> HashTree<'a> {
    match (l, m, r) {
        (Empty, m, Empty) => m,
        (l, m, Empty) => fork(l, m),
        (Empty, m, r) => fork(m, r),
        (Pruned(lhash), Pruned(mhash), Pruned(rhash)) => {
            Pruned(fork_hash(&lhash, &fork_hash(&mhash, &rhash)))
        }
        (l, Pruned(mhash), Pruned(rhash)) => fork(l, Pruned(fork_hash(&mhash, &rhash))),
        (l, m, r) => fork(l, fork(m, r)),
    }
}
1000
1001// helper functions
1002fn is_red<K, V>(x: &NodeRef<K, V>) -> bool {
1003    x.as_ref().is_some_and(|h| h.color == Color::Red)
1004}
1005
1006fn balance<'t, K: AsRef<[u8]> + 't, V: AsHashTree + 't>(mut h: Box<Node<K, V>>) -> Box<Node<K, V>> {
1007    if is_red(&h.right) && !is_red(&h.left) {
1008        h = rotate_left(h);
1009    }
1010    if is_red(&h.left) && is_red(&h.left.as_ref().unwrap().left) {
1011        h = rotate_right(h);
1012    }
1013    if is_red(&h.left) && is_red(&h.right) {
1014        flip_colors(&mut h);
1015    }
1016    h
1017}
1018
1019/// Make a left-leaning link lean to the right.
1020fn rotate_right<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
1021    mut h: Box<Node<K, V>>,
1022) -> Box<Node<K, V>> {
1023    debug_assert!(is_red(&h.left));
1024
1025    let mut x = h.left.take().unwrap();
1026    h.left = x.right.take();
1027    h.update_subtree_hash();
1028
1029    x.right = Some(h);
1030    x.color = x.right.as_ref().unwrap().color;
1031    x.right.as_mut().unwrap().color = Color::Red;
1032    x.update_subtree_hash();
1033
1034    x
1035}
1036
1037fn rotate_left<'t, K: 't + AsRef<[u8]>, V: AsHashTree + 't>(
1038    mut h: Box<Node<K, V>>,
1039) -> Box<Node<K, V>> {
1040    debug_assert!(is_red(&h.right));
1041
1042    let mut x = h.right.take().unwrap();
1043    h.right = x.left.take();
1044    h.update_subtree_hash();
1045
1046    x.left = Some(h);
1047    x.color = x.left.as_ref().unwrap().color;
1048    x.left.as_mut().unwrap().color = Color::Red;
1049    x.update_subtree_hash();
1050
1051    x
1052}
1053
1054fn flip_colors<K, V>(h: &mut Box<Node<K, V>>) {
1055    h.color.flip_assign();
1056    h.left.as_mut().unwrap().color.flip_assign();
1057    h.right.as_mut().unwrap().color.flip_assign();
1058}
1059
#[cfg(test)]
/// Returns `true` if the tree rooted at `root` satisfies the red-black
/// invariants checked here:
/// - every path from the root down to a `None` link contains the same number
///   of black nodes as the leftmost path;
/// - no red node has a red child.
///
/// Every violation is reported by returning `false` rather than by panicking
/// inside the check. A caller that asserts on the result (e.g. with a
/// `DebugView` message) therefore gets to print the offending tree. This also
/// avoids a `usize` underflow when a path has more black nodes than expected.
fn is_balanced<K, V>(root: &NodeRef<K, V>) -> bool {
    /// Checks the subtree at `node`, where `remaining_black` is the number of
    /// black nodes every path from `node` down to `None` must still contain.
    fn go<K, V>(node: &NodeRef<K, V>, remaining_black: usize) -> bool {
        let Some(n) = node else {
            return remaining_black == 0;
        };
        let remaining_black = if is_red(node) {
            // A red node must not have a red child.
            if is_red(&n.left) || is_red(&n.right) {
                return false;
            }
            remaining_black
        } else {
            // More black nodes on this path than on the leftmost one.
            let Some(rest) = remaining_black.checked_sub(1) else {
                return false;
            };
            rest
        };
        go(&n.left, remaining_black) && go(&n.right, remaining_black)
    }

    // Reference black height: black nodes along the leftmost path.
    let black_height = std::iter::successors(root.as_ref(), |n| n.left.as_ref())
        .filter(|n| n.color == Color::Black)
        .count();
    go(root, black_height)
}
1088
/// `Debug` adapter that prints a subtree one node per line, indented by depth.
/// Each line shows the node's color (`[R]`/`[B]`) and its key bytes; missing
/// children are shown as `[B] <null>`.
// NOTE(review): the only construction visible here is inside a
// `#[cfg(test)]` assertion. `dead_code` is presumably allowed so non-test
// builds don't warn — confirm against the rest of the file.
#[allow(dead_code)]
struct DebugView<'a, K, V>(&'a NodeRef<K, V>);
1091
1092impl<K: AsRef<[u8]>, V> fmt::Debug for DebugView<'_, K, V> {
1093    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1094        fn go<K: AsRef<[u8]>, V>(
1095            f: &mut fmt::Formatter<'_>,
1096            node: &NodeRef<K, V>,
1097            offset: usize,
1098        ) -> fmt::Result {
1099            match node {
1100                None => writeln!(f, "{:width$}[B] <null>", "", width = offset),
1101                Some(h) => {
1102                    writeln!(
1103                        f,
1104                        "{:width$}[{}] {:?}",
1105                        "",
1106                        if is_red(node) { "R" } else { "B" },
1107                        h.key.as_ref(),
1108                        width = offset
1109                    )?;
1110                    go(f, &h.left, offset + 2)?;
1111                    go(f, &h.right, offset + 2)
1112                }
1113            }
1114        }
1115        go(f, self.0, 0)
1116    }
1117}
1118
1119#[cfg(test)]
1120mod test;