certified_vars/
rbtree.rs

//! This file contains a low-level RbTree implementation. The code is borrowed from the
//! `ic-certified-map` crate by Dfinity.
//!
//! It is not recommended to use [`RbTree`] directly, since it is a low-level data structure
//! that only provides basic functionality. Instead, we advise you to look at the
//! [crate::collections] module.
7
8use std::borrow::{Borrow, Cow};
9use std::cmp::Ordering;
10use std::cmp::Ordering::{Equal, Greater, Less};
11use std::fmt;
12
13use crate::hashtree::{
14    fork, fork_hash, labeled_hash, Hash,
15    HashTree::{self, Empty, Pruned},
16};
17use crate::label::{Label, Prefix};
18use crate::AsHashTree;
19
20#[cfg(test)]
21pub(crate) mod debug_alloc;
22
23pub mod entry;
24pub mod iterator;
25
/// Color tag stored on every tree node.
#[derive(Copy, Clone, Eq, PartialEq)]
enum Color {
    Red,
    Black,
}

impl Color {
    /// Returns the opposite color: red becomes black and black becomes red.
    fn flip(self) -> Self {
        if self == Self::Red {
            Self::Black
        } else {
            Self::Red
        }
    }
}
40
41impl<K: 'static + Label, V: AsHashTree + 'static> AsHashTree for RbTree<K, V> {
42    #[inline]
43    fn root_hash(&self) -> Hash {
44        if self.root.is_null() {
45            Empty.reconstruct()
46        } else {
47            unsafe { (*self.root).subtree_hash }
48        }
49    }
50
51    #[inline]
52    fn as_hash_tree(&self) -> HashTree<'_> {
53        unsafe { Node::full_witness_tree(self.root, Node::data_tree) }
54    }
55}
56
57#[derive(PartialEq, Debug)]
58enum KeyBound<'a, T: Label> {
59    Exact(&'a T),
60    Neighbor(&'a T),
61}
62
63impl<'a, T: Label> Clone for KeyBound<'a, T> {
64    fn clone(&self) -> Self {
65        match self {
66            KeyBound::Exact(k) => KeyBound::Exact(*k),
67            KeyBound::Neighbor(k) => KeyBound::Neighbor(*k),
68        }
69    }
70}
71
72impl<'a, T: Label> Copy for KeyBound<'a, T> {}
73
74impl<'a, T: Label> Eq for KeyBound<'a, T> {}
75
76impl<'a, T: Label> PartialOrd<Self> for KeyBound<'a, T> {
77    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
78        self.as_ref().partial_cmp(other.as_ref())
79    }
80}
81
82impl<'a, T: Label> Ord for KeyBound<'a, T> {
83    fn cmp(&self, other: &Self) -> Ordering {
84        self.as_ref().cmp(other.as_ref())
85    }
86}
87
88impl<'a, T: Label> Label for KeyBound<'a, T> {
89    fn as_label(&self) -> Cow<[u8]> {
90        match self {
91            KeyBound::Exact(key) => key.as_label(),
92            KeyBound::Neighbor(key) => key.as_label(),
93        }
94    }
95}
96
97impl<'a, T: Label> AsRef<T> for KeyBound<'a, T> {
98    fn as_ref(&self) -> &T {
99        match self {
100            KeyBound::Exact(key) => key,
101            KeyBound::Neighbor(key) => key,
102        }
103    }
104}
105
106impl<'a, T: Label + AsRef<[u8]>> AsRef<[u8]> for KeyBound<'a, T> {
107    fn as_ref(&self) -> &[u8] {
108        match self {
109            KeyBound::Exact(key) => key.as_ref(),
110            KeyBound::Neighbor(key) => key.as_ref(),
111        }
112    }
113}
114
115// 1. All leaves are black.
116// 2. Children of a red node are black.
117// 3. Every path from a node goes through the same number of black
118//    nodes.
119struct Node<K, V> {
120    key: K,
121    value: V,
122    left: *mut Node<K, V>,
123    right: *mut Node<K, V>,
124    color: Color,
125
126    /// Hash of the full hash tree built from this node and its
127    /// children. It needs to be recomputed after every rotation.
128    subtree_hash: Hash,
129}
130
131impl<K: 'static + Label, V: AsHashTree + 'static> Node<K, V> {
132    #[allow(clippy::let_and_return)]
133    fn new(key: K, value: V) -> *mut Self {
134        let value_hash = value.root_hash();
135        let data_hash = labeled_hash(&key.as_label(), &value_hash);
136        let node = Box::into_raw(Box::new(Self {
137            key,
138            value,
139            left: Node::null(),
140            right: Node::null(),
141            color: Color::Red,
142            subtree_hash: data_hash,
143        }));
144
145        #[cfg(test)]
146        debug_alloc::mark_pointer_allocated(node);
147
148        node
149    }
150
151    unsafe fn data_hash(n: *mut Self) -> Hash {
152        debug_assert!(!n.is_null());
153        labeled_hash(&(*n).key.as_label(), &(*n).value.root_hash())
154    }
155
156    unsafe fn left_hash_tree<'a>(n: *mut Self) -> HashTree<'a> {
157        debug_assert!(!n.is_null());
158        if (*n).left.is_null() {
159            Empty
160        } else {
161            Pruned((*(*n).left).subtree_hash)
162        }
163    }
164
165    unsafe fn right_hash_tree<'a>(n: *mut Self) -> HashTree<'a> {
166        debug_assert!(!n.is_null());
167        if (*n).right.is_null() {
168            Empty
169        } else {
170            Pruned((*(*n).right).subtree_hash)
171        }
172    }
173
174    fn null() -> *mut Self {
175        std::ptr::null::<Self>() as *mut Node<K, V>
176    }
177
178    unsafe fn data_tree<'a>(n: *mut Self) -> HashTree<'a> {
179        debug_assert!(!n.is_null());
180        HashTree::Labeled((*n).key.as_label(), Box::new((*n).value.as_hash_tree()))
181    }
182
183    unsafe fn subtree_with<'a>(
184        n: *mut Self,
185        f: impl FnOnce(&'a V) -> HashTree<'a>,
186    ) -> HashTree<'a> {
187        debug_assert!(!n.is_null());
188
189        HashTree::Labeled((*n).key.as_label(), Box::new(f(&(*n).value)))
190    }
191
192    unsafe fn witness_tree<'a>(n: *mut Self) -> HashTree<'a> {
193        debug_assert!(!n.is_null());
194        let value_hash = (*n).value.root_hash();
195        HashTree::Labeled((*n).key.as_label(), Box::new(Pruned(value_hash)))
196    }
197
198    unsafe fn full_witness_tree<'a>(
199        n: *mut Self,
200        f: unsafe fn(*mut Self) -> HashTree<'a>,
201    ) -> HashTree<'a> {
202        if n.is_null() {
203            return Empty;
204        }
205        three_way_fork(
206            Self::full_witness_tree((*n).left, f),
207            f(n),
208            Self::full_witness_tree((*n).right, f),
209        )
210    }
211
212    unsafe fn delete(n: *mut Self) -> Option<(K, V)> {
213        if n.is_null() {
214            return None;
215        }
216        Self::delete((*n).left);
217        Self::delete((*n).right);
218        let node = Box::from_raw(n);
219
220        #[cfg(test)]
221        debug_alloc::mark_pointer_deleted(n);
222
223        Some((node.key, node.value))
224    }
225
226    unsafe fn subtree_hash(n: *mut Self) -> Hash {
227        if n.is_null() {
228            return Empty.reconstruct();
229        }
230
231        let h = Node::data_hash(n);
232
233        match ((*n).left.is_null(), (*n).right.is_null()) {
234            (true, true) => h,
235            (false, true) => fork_hash(&(*(*n).left).subtree_hash, &h),
236            (true, false) => fork_hash(&h, &(*(*n).right).subtree_hash),
237            (false, false) => fork_hash(
238                &(*(*n).left).subtree_hash,
239                &fork_hash(&h, &(*(*n).right).subtree_hash),
240            ),
241        }
242    }
243}
244
245/// Implements mutable Leaf-leaning red-black trees as defined in
246/// https://www.cs.princeton.edu/~rs/talks/LLRB/LLRB.pdf
247pub struct RbTree<K: 'static + Label, V: AsHashTree + 'static> {
248    len: usize,
249    root: *mut Node<K, V>,
250}
251
252impl<K: 'static + Label, V: AsHashTree + 'static> Drop for RbTree<K, V> {
253    fn drop(&mut self) {
254        unsafe {
255            Node::delete(self.root);
256        }
257    }
258}
259
260impl<K: 'static + Label, V: AsHashTree + 'static> Default for RbTree<K, V> {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
266impl<K: 'static + Label, V: AsHashTree + 'static> RbTree<K, V> {
267    #[inline]
268    pub fn new() -> Self {
269        Self {
270            len: 0,
271            root: Node::null(),
272        }
273    }
274
275    #[inline]
276    pub fn len(&self) -> usize {
277        self.len
278    }
279
280    #[inline]
281    pub fn is_empty(&self) -> bool {
282        self.root.is_null()
283    }
284
285    pub fn entry(&mut self, key: K) -> entry::Entry<K, V> {
286        let node = unsafe { self.get_node(&key) };
287
288        if node.is_null() {
289            entry::Entry::Vacant(entry::VacantEntry { map: self, key })
290        } else {
291            entry::Entry::Occupied(entry::OccupiedEntry {
292                map: self,
293                key,
294                node,
295            })
296        }
297    }
298
299    #[inline]
300    pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
301    where
302        K: Borrow<Q>,
303        Q: Ord,
304    {
305        unsafe {
306            let mut root = self.root;
307            while !root.is_null() {
308                match key.cmp((*root).key.borrow()) {
309                    Equal => return Some(&(*root).value),
310                    Less => root = (*root).left,
311                    Greater => root = (*root).right,
312                }
313            }
314            None
315        }
316    }
317
318    #[inline]
319    pub fn get_with(&self, cmp: impl Fn(&K) -> Ordering) -> Option<&V> {
320        unsafe {
321            let mut root = self.root;
322            while !root.is_null() {
323                match cmp(&(*root).key) {
324                    Equal => return Some(&(*root).value),
325                    Less => root = (*root).left,
326                    Greater => root = (*root).right,
327                }
328            }
329            None
330        }
331    }
332
333    #[inline]
334    unsafe fn get_node(&self, key: &K) -> *mut Node<K, V> {
335        let mut root = self.root;
336        while !root.is_null() {
337            match key.cmp(&(*root).key) {
338                Equal => return root,
339                Less => root = (*root).left,
340                Greater => root = (*root).right,
341            }
342        }
343        Node::null()
344    }
345
346    /// Updates the value corresponding to the specified key.
347    #[inline]
348    pub fn modify<'a, Q: ?Sized, T>(&mut self, key: &Q, f: impl FnOnce(&'a mut V) -> T) -> Option<T>
349    where
350        K: Borrow<Q>,
351        Q: Ord,
352    {
353        unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static, T, Q: ?Sized>(
354            mut h: *mut Node<K, V>,
355            k: &Q,
356            f: impl FnOnce(&'a mut V) -> T,
357        ) -> Option<T>
358        where
359            K: Borrow<Q>,
360            Q: Ord,
361        {
362            if h.is_null() {
363                return None;
364            }
365
366            match k.cmp((*h).key.borrow()) {
367                Equal => {
368                    let res = f(&mut (*h).value);
369                    (*h).subtree_hash = Node::subtree_hash(h);
370                    Some(res)
371                }
372                Less => {
373                    let res = go((*h).left, k, f);
374                    (*h).subtree_hash = Node::subtree_hash(h);
375                    res
376                }
377                Greater => {
378                    let res = go((*h).right, k, f);
379                    (*h).subtree_hash = Node::subtree_hash(h);
380                    res
381                }
382            }
383        }
384        unsafe { go(self.root, key, f) }
385    }
386
387    /// Modify the maximum node with the given prefix.
388    pub fn modify_max_with_prefix<'a, P: ?Sized, T>(
389        &mut self,
390        prefix: &P,
391        f: impl FnOnce(&'a K, &'a mut V) -> T,
392    ) -> Option<T>
393    where
394        K: Prefix<P>,
395        P: Ord,
396    {
397        unsafe fn go<
398            'a,
399            K: Label + 'static,
400            V: AsHashTree + 'static,
401            P: ?Sized,
402            T,
403            F: FnOnce(&'a K, &'a mut V) -> T,
404        >(
405            mut h: *mut Node<K, V>,
406            prefix: &P,
407            f: F,
408        ) -> (Option<T>, Option<F>)
409        where
410            K: Prefix<P>,
411            P: Ord,
412        {
413            if h.is_null() {
414                return (None, Some(f));
415            }
416
417            let node_key = &(*h).key;
418            let key_prefix = node_key.borrow();
419
420            let res = match key_prefix.cmp(prefix) {
421                Greater | Equal if node_key.is_prefix(prefix) => match go((*h).right, prefix, f) {
422                    (None, Some(f)) => {
423                        let ret = f(node_key, &mut (*h).value);
424                        (Some(ret), None)
425                    }
426                    ret => ret,
427                },
428                Greater => go((*h).left, prefix, f),
429                Less | Equal => go((*h).right, prefix, f),
430            };
431
432            if res.0.is_some() {
433                (*h).subtree_hash = Node::subtree_hash(h);
434            }
435
436            res
437        }
438
439        unsafe { go(self.root, prefix, f).0 }
440    }
441
442    pub fn max_entry_with_prefix<P: ?Sized>(&self, prefix: &P) -> Option<(&K, &V)>
443    where
444        K: Prefix<P>,
445        P: Ord,
446    {
447        unsafe fn go<'a, K: 'static + Label, V, P: ?Sized>(
448            n: *mut Node<K, V>,
449            prefix: &P,
450        ) -> Option<(&'a K, &'a V)>
451        where
452            K: Prefix<P>,
453            P: Ord,
454        {
455            if n.is_null() {
456                return None;
457            }
458
459            let node_key = &(*n).key;
460            let key_prefix = node_key.borrow();
461            match key_prefix.cmp(prefix) {
462                Greater | Equal if node_key.is_prefix(prefix) => {
463                    go((*n).right, prefix).or(Some((node_key, &(*n).value)))
464                }
465                Greater => go((*n).left, prefix),
466                Less | Equal => go((*n).right, prefix),
467            }
468        }
469        unsafe { go(self.root, prefix) }
470    }
471
472    fn range_witness<'a>(
473        &'a self,
474        left: Option<KeyBound<'a, K>>,
475        right: Option<KeyBound<'a, K>>,
476        f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
477    ) -> HashTree<'a> {
478        match (left, right) {
479            (None, None) => unsafe { Node::full_witness_tree(self.root, f) },
480            (Some(l), None) => self.witness_range_above(l, f),
481            (None, Some(r)) => self.witness_range_below(r, f),
482            (Some(l), Some(r)) => self.witness_range_between(l, r, f),
483        }
484    }
485
486    /// Constructs a hash tree that acts as a proof that there is a
487    /// entry with the specified key in this map.  The proof also
488    /// contains the value in question.
489    ///
490    /// If the key is not in the map, returns a proof of absence.
491    #[inline]
492    pub fn witness<Q: ?Sized>(&self, key: &Q) -> HashTree<'_>
493    where
494        K: Borrow<Q>,
495        Q: Ord,
496    {
497        self.nested_witness(key, |v| v.as_hash_tree())
498    }
499
500    /// Like `witness`, but gives the caller more control over the
501    /// construction of the value witness.  This method is useful for
502    /// constructing witnesses for nested certified maps.
503    #[inline]
504    pub fn nested_witness<'a, Q: ?Sized>(
505        &'a self,
506        key: &Q,
507        f: impl FnOnce(&'a V) -> HashTree<'a>,
508    ) -> HashTree<'a>
509    where
510        K: Borrow<Q>,
511        Q: Ord,
512    {
513        if let Some(t) = self.lookup_and_build_witness(key, f) {
514            return t;
515        }
516        self.range_witness(
517            self.lower_bound(key),
518            self.upper_bound(key),
519            Node::witness_tree,
520        )
521    }
522
523    /// Returns a witness enumerating all the keys in this map.  The
524    /// resulting tree doesn't include values, they are replaced with
525    /// "Pruned" nodes.
526    #[inline]
527    pub fn keys(&self) -> HashTree<'_> {
528        unsafe { Node::full_witness_tree(self.root, Node::witness_tree) }
529    }
530
531    /// Returns a witness for the keys in the specified range.  The
532    /// resulting tree doesn't include values, they are replaced with
533    /// "Pruned" nodes.
534    #[inline]
535    pub fn key_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &Q1, last: &Q2) -> HashTree<'_>
536    where
537        K: Borrow<Q1> + Borrow<Q2>,
538        Q1: Ord,
539        Q2: Ord,
540    {
541        self.range_witness(
542            self.lower_bound(first),
543            self.upper_bound(last),
544            Node::witness_tree,
545        )
546    }
547
548    /// Returns a witness for the key-value pairs in the specified range.
549    /// The resulting tree contains both keys and values.
550    #[inline]
551    pub fn value_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &Q1, last: &Q2) -> HashTree<'_>
552    where
553        K: Borrow<Q1> + Borrow<Q2>,
554        Q1: Ord,
555        Q2: Ord,
556    {
557        self.range_witness(
558            self.lower_bound(first),
559            self.upper_bound(last),
560            Node::data_tree,
561        )
562    }
563
564    /// Returns a witness that enumerates all the keys starting with
565    /// the specified prefix.
566    #[inline]
567    pub fn keys_with_prefix<P: ?Sized>(&self, prefix: &P) -> HashTree<'_>
568    where
569        K: Prefix<P>,
570        P: Ord,
571    {
572        self.range_witness(
573            self.lower_bound(prefix),
574            self.right_prefix_neighbor(prefix),
575            Node::witness_tree,
576        )
577    }
578
579    /// Enumerates all the key-value pairs in the tree.
580    #[inline]
581    pub fn for_each<'a, F>(&'a self, mut f: F)
582    where
583        F: 'a + FnMut(&'a K, &'a V),
584    {
585        unsafe fn visit<'a, K, V, F>(n: *mut Node<K, V>, f: &mut F)
586        where
587            F: 'a + FnMut(&'a K, &'a V),
588            K: 'static + Label,
589            V: 'a + AsHashTree,
590        {
591            debug_assert!(!n.is_null());
592            if !(*n).left.is_null() {
593                visit((*n).left, f)
594            }
595            (*f)(&(*n).key, &(*n).value);
596            if !(*n).right.is_null() {
597                visit((*n).right, f)
598            }
599        }
600        if self.root.is_null() {
601            return;
602        }
603        unsafe { visit(self.root, &mut f) }
604    }
605
606    fn witness_range_above<'a>(
607        &'a self,
608        lo: KeyBound<'a, K>,
609        f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
610    ) -> HashTree<'a> {
611        unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static>(
612            n: *mut Node<K, V>,
613            lo: KeyBound<'a, K>,
614            f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
615        ) -> HashTree<'a> {
616            if n.is_null() {
617                return Empty;
618            }
619            match (*n).key.cmp(lo.as_ref()) {
620                Equal => three_way_fork(
621                    Node::left_hash_tree(n),
622                    match lo {
623                        KeyBound::Exact(_) => f(n),
624                        KeyBound::Neighbor(_) => Node::witness_tree(n),
625                    },
626                    Node::full_witness_tree((*n).right, f),
627                ),
628                Less => three_way_fork(
629                    Node::left_hash_tree(n),
630                    Pruned(Node::data_hash(n)),
631                    go((*n).right, lo, f),
632                ),
633                Greater => three_way_fork(
634                    go((*n).left, lo, f),
635                    f(n),
636                    Node::full_witness_tree((*n).right, f),
637                ),
638            }
639        }
640        unsafe { go(self.root, lo, f) }
641    }
642
643    fn witness_range_below<'a>(
644        &'a self,
645        hi: KeyBound<'a, K>,
646        f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
647    ) -> HashTree<'a> {
648        unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static>(
649            n: *mut Node<K, V>,
650            hi: KeyBound<'a, K>,
651            f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
652        ) -> HashTree<'a> {
653            if n.is_null() {
654                return Empty;
655            }
656            match (*n).key.cmp(hi.as_ref()) {
657                Equal => three_way_fork(
658                    Node::full_witness_tree((*n).left, f),
659                    match hi {
660                        KeyBound::Exact(_) => f(n),
661                        KeyBound::Neighbor(_) => Node::witness_tree(n),
662                    },
663                    Node::right_hash_tree(n),
664                ),
665                Greater => three_way_fork(
666                    go((*n).left, hi, f),
667                    Pruned(Node::data_hash(n)),
668                    Node::right_hash_tree(n),
669                ),
670                Less => three_way_fork(
671                    Node::full_witness_tree((*n).left, f),
672                    f(n),
673                    go((*n).right, hi, f),
674                ),
675            }
676        }
677        unsafe { go(self.root, hi, f) }
678    }
679
680    fn witness_range_between<'a>(
681        &'a self,
682        lo: KeyBound<'a, K>,
683        hi: KeyBound<'a, K>,
684        f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
685    ) -> HashTree<'a> {
686        debug_assert!(
687            lo.as_ref() <= hi.as_ref(),
688            "lo = {:?} > hi = {:?}",
689            lo.as_ref().as_label(),
690            hi.as_ref().as_label()
691        );
692        unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static>(
693            n: *mut Node<K, V>,
694            lo: KeyBound<'a, K>,
695            hi: KeyBound<'a, K>,
696            f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
697        ) -> HashTree<'a> {
698            if n.is_null() {
699                return Empty;
700            }
701            let k = &(*n).key;
702            match (lo.as_ref().cmp(k), k.cmp(hi.as_ref())) {
703                (Less, Less) => {
704                    let left = go((*n).left, lo, hi, f);
705                    let right = go((*n).right, lo, hi, f);
706                    three_way_fork(left, f(n), right)
707                }
708                (Equal, Equal) => three_way_fork(
709                    Node::left_hash_tree(n),
710                    match (lo, hi) {
711                        (KeyBound::Exact(_), _) => f(n),
712                        (_, KeyBound::Exact(_)) => f(n),
713                        _ => Node::witness_tree(n),
714                    },
715                    Node::right_hash_tree(n),
716                ),
717                (_, Equal) => three_way_fork(
718                    go((*n).left, lo, hi, f),
719                    match hi {
720                        KeyBound::Exact(_) => f(n),
721                        KeyBound::Neighbor(_) => Node::witness_tree(n),
722                    },
723                    Node::right_hash_tree(n),
724                ),
725                (Equal, _) => three_way_fork(
726                    Node::left_hash_tree(n),
727                    match lo {
728                        KeyBound::Exact(_) => f(n),
729                        KeyBound::Neighbor(_) => Node::witness_tree(n),
730                    },
731                    go((*n).right, lo, hi, f),
732                ),
733                (Less, Greater) => three_way_fork(
734                    go((*n).left, lo, hi, f),
735                    Pruned(Node::data_hash(n)),
736                    Node::right_hash_tree(n),
737                ),
738                (Greater, Less) => three_way_fork(
739                    Node::left_hash_tree(n),
740                    Pruned(Node::data_hash(n)),
741                    go((*n).right, lo, hi, f),
742                ),
743                _ => Pruned((*n).subtree_hash),
744            }
745        }
746        unsafe { go(self.root, lo, hi, f) }
747    }
748
749    fn lower_bound<Q: ?Sized>(&self, key: &Q) -> Option<KeyBound<'_, K>>
750    where
751        K: Borrow<Q>,
752        Q: Ord,
753    {
754        unsafe fn go<'a, K: 'static + Label, V, Q: ?Sized>(
755            n: *mut Node<K, V>,
756            key: &Q,
757        ) -> Option<KeyBound<'a, K>>
758        where
759            K: Borrow<Q>,
760            Q: Ord,
761        {
762            if n.is_null() {
763                return None;
764            }
765            let node_key = &(*n).key;
766            match node_key.borrow().cmp(key) {
767                Less => go((*n).right, key).or(Some(KeyBound::Neighbor(node_key))),
768                Equal => Some(KeyBound::Exact(node_key)),
769                Greater => go((*n).left, key),
770            }
771        }
772        unsafe { go(self.root, key) }
773    }
774
775    fn upper_bound<Q: ?Sized>(&self, key: &Q) -> Option<KeyBound<'_, K>>
776    where
777        K: Borrow<Q>,
778        Q: Ord,
779    {
780        unsafe fn go<'a, K: 'static + Label, V, Q: ?Sized>(
781            n: *mut Node<K, V>,
782            key: &Q,
783        ) -> Option<KeyBound<'a, K>>
784        where
785            K: Borrow<Q>,
786            Q: Ord,
787        {
788            if n.is_null() {
789                return None;
790            }
791            let node_key = &(*n).key;
792            match node_key.borrow().cmp(key) {
793                Less => go((*n).right, key),
794                Equal => Some(KeyBound::Exact(node_key)),
795                Greater => go((*n).left, key).or(Some(KeyBound::Neighbor(node_key))),
796            }
797        }
798        unsafe { go(self.root, key) }
799    }
800
801    fn right_prefix_neighbor<P: ?Sized>(&self, prefix: &P) -> Option<KeyBound<'_, K>>
802    where
803        K: Prefix<P>,
804        P: Ord,
805    {
806        unsafe fn go<'a, K: 'static + Label, V, P: ?Sized>(
807            n: *mut Node<K, V>,
808            prefix: &P,
809        ) -> Option<KeyBound<'a, K>>
810        where
811            K: Prefix<P>,
812            P: Ord,
813        {
814            if n.is_null() {
815                return None;
816            }
817            let node_key = &(*n).key;
818            let key_prefix = node_key.borrow();
819            match key_prefix.cmp(prefix) {
820                Greater if node_key.is_prefix(prefix) => go((*n).right, prefix),
821                Greater => go((*n).left, prefix).or(Some(KeyBound::Neighbor(node_key))),
822                Less | Equal => go((*n).right, prefix),
823            }
824        }
825        unsafe { go(self.root, prefix) }
826    }
827
828    fn lookup_and_build_witness<'a, Q: ?Sized>(
829        &'a self,
830        key: &Q,
831        f: impl FnOnce(&'a V) -> HashTree<'a>,
832    ) -> Option<HashTree<'a>>
833    where
834        K: Borrow<Q>,
835        Q: Ord,
836    {
837        unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static, Q: ?Sized>(
838            n: *mut Node<K, V>,
839            key: &Q,
840            f: impl FnOnce(&'a V) -> HashTree<'a>,
841        ) -> Option<HashTree<'a>>
842        where
843            K: Borrow<Q>,
844            Q: Ord,
845        {
846            if n.is_null() {
847                return None;
848            }
849            match key.cmp((*n).key.borrow()) {
850                Equal => Some(three_way_fork(
851                    Node::left_hash_tree(n),
852                    Node::subtree_with(n, f),
853                    Node::right_hash_tree(n),
854                )),
855                Less => {
856                    let subtree = go((*n).left, key, f)?;
857                    Some(three_way_fork(
858                        subtree,
859                        Pruned(Node::data_hash(n)),
860                        Node::right_hash_tree(n),
861                    ))
862                }
863                Greater => {
864                    let subtree = go((*n).right, key, f)?;
865                    Some(three_way_fork(
866                        Node::left_hash_tree(n),
867                        Pruned(Node::data_hash(n)),
868                        subtree,
869                    ))
870                }
871            }
872        }
873        unsafe { go(self.root, key, f) }
874    }
875
876    /// Inserts a key-value entry into the map.
877    #[inline]
878    pub fn insert(&mut self, key: K, value: V) -> (Option<V>, &mut V) {
879        struct GoResult<'a, K, V> {
880            node: *mut Node<K, V>,
881            old_value: Option<V>,
882            new_value_ref: &'a mut V,
883        }
884
885        unsafe fn go<K: 'static + Label, V: AsHashTree + 'static>(
886            mut h: *mut Node<K, V>,
887            k: K,
888            mut v: V,
889        ) -> GoResult<'static, K, V> {
890            if h.is_null() {
891                let node = Node::new(k, v);
892                return GoResult {
893                    node,
894                    old_value: None,
895                    new_value_ref: &mut (*node).value,
896                };
897            }
898
899            let (old_value, new_value_ref) = match k.cmp(&(*h).key) {
900                Equal => {
901                    std::mem::swap(&mut (*h).value, &mut v);
902                    (*h).subtree_hash = Node::subtree_hash(h);
903                    (Some(v), &mut (*h).value)
904                }
905                Less => {
906                    let res = go((*h).left, k, v);
907                    (*h).left = res.node;
908                    (*h).subtree_hash = Node::subtree_hash(h);
909                    (res.old_value, res.new_value_ref)
910                }
911                Greater => {
912                    let res = go((*h).right, k, v);
913                    (*h).right = res.node;
914                    (*h).subtree_hash = Node::subtree_hash(h);
915                    (res.old_value, res.new_value_ref)
916                }
917            };
918
919            GoResult {
920                node: balance(h),
921                old_value,
922                new_value_ref,
923            }
924        }
925
926        unsafe {
927            let mut result = go(self.root, key, value);
928            (*result.node).color = Color::Black;
929
930            #[cfg(test)]
931            debug_assert!(
932                is_balanced(result.node),
933                "the tree is not balanced:\n{:?}",
934                DebugView(result.node)
935            );
936            #[cfg(test)]
937            debug_assert!(!has_dangling_pointers(result.node));
938
939            if result.old_value.is_none() {
940                self.len += 1;
941            }
942
943            self.root = result.node;
944            (result.old_value, result.new_value_ref)
945        }
946    }
947
948    /// Removes the specified key from the map.
949    #[inline]
950    pub fn delete<Q: ?Sized>(&mut self, key: &Q) -> Option<(K, V)>
951    where
952        K: Borrow<Q>,
953        Q: Ord,
954    {
955        unsafe fn move_red_left<K: 'static + Label, V: AsHashTree + 'static>(
956            mut h: *mut Node<K, V>,
957        ) -> *mut Node<K, V> {
958            flip_colors(h);
959            if is_red((*(*h).right).left) {
960                (*h).right = rotate_right((*h).right);
961                h = rotate_left(h);
962                flip_colors(h);
963            }
964            h
965        }
966
967        unsafe fn move_red_right<K: 'static + Label, V: AsHashTree + 'static>(
968            mut h: *mut Node<K, V>,
969        ) -> *mut Node<K, V> {
970            flip_colors(h);
971            if is_red((*(*h).left).left) {
972                h = rotate_right(h);
973                flip_colors(h);
974            }
975            h
976        }
977
978        #[inline]
979        unsafe fn min<K: 'static + Label, V: AsHashTree + 'static>(
980            mut h: *mut Node<K, V>,
981        ) -> *mut Node<K, V> {
982            while !(*h).left.is_null() {
983                h = (*h).left;
984            }
985            h
986        }
987
988        unsafe fn delete_min<K: 'static + Label, V: AsHashTree + 'static>(
989            mut h: *mut Node<K, V>,
990            result: &mut Option<(K, V)>,
991        ) -> *mut Node<K, V> {
992            if (*h).left.is_null() {
993                debug_assert!((*h).right.is_null());
994                *result = Some(Node::delete(h).unwrap());
995                return Node::null();
996            }
997            if !is_red((*h).left) && !is_red((*(*h).left).left) {
998                h = move_red_left(h);
999            }
1000            (*h).left = delete_min((*h).left, result);
1001            (*h).subtree_hash = Node::subtree_hash(h);
1002            balance(h)
1003        }
1004
1005        unsafe fn go<K: 'static + Label, V: AsHashTree + 'static, Q: ?Sized>(
1006            mut h: *mut Node<K, V>,
1007            result: &mut Option<(K, V)>,
1008            key: &Q,
1009        ) -> *mut Node<K, V>
1010        where
1011            K: Borrow<Q>,
1012            Q: Ord,
1013        {
1014            if key < (*h).key.borrow() {
1015                if !is_red((*h).left) && !is_red((*(*h).left).left) {
1016                    h = move_red_left(h);
1017                }
1018                (*h).left = go((*h).left, result, key);
1019            } else {
1020                if is_red((*h).left) {
1021                    h = rotate_right(h);
1022                }
1023                if key == (*h).key.borrow() && (*h).right.is_null() {
1024                    debug_assert!((*h).left.is_null());
1025                    *result = Some(Node::delete(h).unwrap());
1026                    return Node::null();
1027                }
1028
1029                if !is_red((*h).right) && !is_red((*(*h).right).left) {
1030                    h = move_red_right(h);
1031                }
1032
1033                if key == (*h).key.borrow() {
1034                    let m = min((*h).right);
1035                    std::mem::swap(&mut (*h).key, &mut (*m).key);
1036                    std::mem::swap(&mut (*h).value, &mut (*m).value);
1037                    (*h).right = delete_min((*h).right, result);
1038                } else {
1039                    (*h).right = go((*h).right, result, key);
1040                }
1041            }
1042            (*h).subtree_hash = Node::subtree_hash(h);
1043            balance(h)
1044        }
1045
1046        unsafe {
1047            self.get(key)?;
1048            if !is_red((*self.root).left) && !is_red((*self.root).right) {
1049                (*self.root).color = Color::Red;
1050            }
1051
1052            let mut result = None;
1053            self.root = go(self.root, &mut result, key);
1054            if !self.root.is_null() {
1055                (*self.root).color = Color::Black;
1056            }
1057
1058            #[cfg(test)]
1059            debug_assert!(
1060                is_balanced(self.root),
1061                "unbalanced map: {:?}",
1062                DebugView(self.root)
1063            );
1064
1065            #[cfg(test)]
1066            debug_assert!(result.is_some());
1067            self.len -= 1;
1068
1069            debug_assert!(self.get(key).is_none());
1070            result
1071        }
1072    }
1073}
1074
1075fn three_way_fork<'a>(l: HashTree<'a>, m: HashTree<'a>, r: HashTree<'a>) -> HashTree<'a> {
1076    match (l, m, r) {
1077        (Empty, m, Empty) => m,
1078        (l, m, Empty) => fork(l, m),
1079        (Empty, m, r) => fork(m, r),
1080        (Pruned(lhash), Pruned(mhash), Pruned(rhash)) => {
1081            Pruned(fork_hash(&lhash, &fork_hash(&mhash, &rhash)))
1082        }
1083        (l, Pruned(mhash), Pruned(rhash)) => fork(l, Pruned(fork_hash(&mhash, &rhash))),
1084        (l, m, r) => fork(l, fork(m, r)),
1085    }
1086}
1087
1088// helper functions
1089unsafe fn is_red<K, V>(x: *const Node<K, V>) -> bool {
1090    if x.is_null() {
1091        false
1092    } else {
1093        (*x).color == Color::Red
1094    }
1095}
1096
1097unsafe fn balance<K: Label + 'static, V: AsHashTree + 'static>(
1098    mut h: *mut Node<K, V>,
1099) -> *mut Node<K, V> {
1100    assert!(!h.is_null());
1101
1102    if is_red((*h).right) && !is_red((*h).left) {
1103        h = rotate_left(h);
1104    }
1105    if is_red((*h).left) && is_red((*(*h).left).left) {
1106        h = rotate_right(h);
1107    }
1108    if is_red((*h).left) && is_red((*h).right) {
1109        flip_colors(h)
1110    }
1111    h
1112}
1113
1114/// Make a left-leaning link lean to the right.
1115unsafe fn rotate_right<K: 'static + Label, V: AsHashTree + 'static>(
1116    h: *mut Node<K, V>,
1117) -> *mut Node<K, V> {
1118    debug_assert!(!h.is_null());
1119    debug_assert!(is_red((*h).left));
1120
1121    let mut x = (*h).left;
1122    (*h).left = (*x).right;
1123    (*x).right = h;
1124    (*x).color = (*(*x).right).color;
1125    (*(*x).right).color = Color::Red;
1126
1127    (*h).subtree_hash = Node::subtree_hash(h);
1128    (*x).subtree_hash = Node::subtree_hash(x);
1129
1130    x
1131}
1132
1133unsafe fn rotate_left<K: 'static + Label, V: AsHashTree + 'static>(
1134    h: *mut Node<K, V>,
1135) -> *mut Node<K, V> {
1136    debug_assert!(!h.is_null());
1137    debug_assert!(is_red((*h).right));
1138
1139    let mut x = (*h).right;
1140    (*h).right = (*x).left;
1141    (*x).left = h;
1142    (*x).color = (*(*x).left).color;
1143    (*(*x).left).color = Color::Red;
1144
1145    (*h).subtree_hash = Node::subtree_hash(h);
1146    (*x).subtree_hash = Node::subtree_hash(x);
1147
1148    x
1149}
1150
1151unsafe fn flip_colors<K, V>(h: *mut Node<K, V>) {
1152    (*h).color = (*h).color.flip();
1153    (*(*h).left).color = (*(*h).left).color.flip();
1154    (*(*h).right).color = (*(*h).right).color.flip();
1155}
1156
/// Checks that every path from `root` down to a null link passes through the same
/// number of black nodes as the left-most path does.
///
/// Returns `false` when a path has too few or too many black nodes. Panics if a red
/// node has a red child.
#[cfg(test)]
unsafe fn is_balanced<K, V>(root: *mut Node<K, V>) -> bool {
    // `num_black` is the number of black nodes still expected between `node`
    // (inclusive) and any null link below it.
    unsafe fn go<K, V>(node: *mut Node<K, V>, num_black: usize) -> bool {
        if node.is_null() {
            return num_black == 0;
        }

        let remaining = if is_red(node) {
            assert!(!is_red((*node).left));
            assert!(!is_red((*node).right));
            num_black
        } else {
            // A path with more black nodes than the reference path is an imbalance.
            // Report it as `false` rather than underflowing the counter, so the result
            // is the same with and without debug assertions and the caller can print
            // its own diagnostics.
            match num_black.checked_sub(1) {
                Some(rest) => rest,
                None => return false,
            }
        };

        go((*node).left, remaining) && go((*node).right, remaining)
    }

    // The left-most path provides the reference black height.
    let mut num_black = 0;
    let mut x = root;
    while !x.is_null() {
        if !is_red(x) {
            num_black += 1;
        }
        x = (*x).left;
    }

    go(root, num_black)
}
1183
/// Returns `true` if any node reachable from `root` is not reported as live by
/// `debug_alloc::is_live`. A null `root` has no dangling pointers.
#[cfg(test)]
unsafe fn has_dangling_pointers<K, V>(root: *mut Node<K, V>) -> bool {
    if root.is_null() {
        return false;
    }
    // Only follow the child links once the node itself is known to be live.
    if !debug_alloc::is_live(root) {
        return true;
    }
    has_dangling_pointers((*root).left) || has_dangling_pointers((*root).right)
}
1194
1195struct DebugView<K, V>(*const Node<K, V>);
1196
1197impl<K: Label, V> fmt::Debug for DebugView<K, V> {
1198    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1199        unsafe fn go<K: Label, V>(
1200            f: &mut fmt::Formatter<'_>,
1201            h: *const Node<K, V>,
1202            offset: usize,
1203        ) -> fmt::Result {
1204            if h.is_null() {
1205                writeln!(f, "{:width$}[B] <null>", "", width = offset)
1206            } else {
1207                writeln!(
1208                    f,
1209                    "{:width$}[{}] {:?}",
1210                    "",
1211                    if is_red(h) { "R" } else { "B" },
1212                    (*h).key.as_label(),
1213                    width = offset
1214                )?;
1215                go(f, (*h).left, offset + 2)?;
1216                go(f, (*h).right, offset + 2)
1217            }
1218        }
1219        unsafe { go(f, self.0, 0) }
1220    }
1221}
1222
1223#[cfg(test)]
1224mod test;