1use std::borrow::{Borrow, Cow};
9use std::cmp::Ordering;
10use std::cmp::Ordering::{Equal, Greater, Less};
11use std::fmt;
12
13use crate::hashtree::{
14 fork, fork_hash, labeled_hash, Hash,
15 HashTree::{self, Empty, Pruned},
16};
17use crate::label::{Label, Prefix};
18use crate::AsHashTree;
19
20#[cfg(test)]
21pub(crate) mod debug_alloc;
22
23pub mod entry;
24pub mod iterator;
25
/// Colour of a node in the red-black tree.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Color {
    Red,
    Black,
}

impl Color {
    /// Returns the opposite colour (`Red` <-> `Black`).
    fn flip(self) -> Self {
        if self == Color::Red {
            Color::Black
        } else {
            Color::Red
        }
    }
}
40
41impl<K: 'static + Label, V: AsHashTree + 'static> AsHashTree for RbTree<K, V> {
42 #[inline]
43 fn root_hash(&self) -> Hash {
44 if self.root.is_null() {
45 Empty.reconstruct()
46 } else {
47 unsafe { (*self.root).subtree_hash }
48 }
49 }
50
51 #[inline]
52 fn as_hash_tree(&self) -> HashTree<'_> {
53 unsafe { Node::full_witness_tree(self.root, Node::data_tree) }
54 }
55}
56
57#[derive(PartialEq, Debug)]
58enum KeyBound<'a, T: Label> {
59 Exact(&'a T),
60 Neighbor(&'a T),
61}
62
63impl<'a, T: Label> Clone for KeyBound<'a, T> {
64 fn clone(&self) -> Self {
65 match self {
66 KeyBound::Exact(k) => KeyBound::Exact(*k),
67 KeyBound::Neighbor(k) => KeyBound::Neighbor(*k),
68 }
69 }
70}
71
72impl<'a, T: Label> Copy for KeyBound<'a, T> {}
73
74impl<'a, T: Label> Eq for KeyBound<'a, T> {}
75
76impl<'a, T: Label> PartialOrd<Self> for KeyBound<'a, T> {
77 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
78 self.as_ref().partial_cmp(other.as_ref())
79 }
80}
81
82impl<'a, T: Label> Ord for KeyBound<'a, T> {
83 fn cmp(&self, other: &Self) -> Ordering {
84 self.as_ref().cmp(other.as_ref())
85 }
86}
87
88impl<'a, T: Label> Label for KeyBound<'a, T> {
89 fn as_label(&self) -> Cow<[u8]> {
90 match self {
91 KeyBound::Exact(key) => key.as_label(),
92 KeyBound::Neighbor(key) => key.as_label(),
93 }
94 }
95}
96
97impl<'a, T: Label> AsRef<T> for KeyBound<'a, T> {
98 fn as_ref(&self) -> &T {
99 match self {
100 KeyBound::Exact(key) => key,
101 KeyBound::Neighbor(key) => key,
102 }
103 }
104}
105
106impl<'a, T: Label + AsRef<[u8]>> AsRef<[u8]> for KeyBound<'a, T> {
107 fn as_ref(&self) -> &[u8] {
108 match self {
109 KeyBound::Exact(key) => key.as_ref(),
110 KeyBound::Neighbor(key) => key.as_ref(),
111 }
112 }
113}
114
/// A single tree node.
///
/// Nodes are heap-allocated by `Node::new` and linked through raw pointers;
/// a null `left`/`right` link denotes an empty subtree.
struct Node<K, V> {
    key: K,
    value: V,
    left: *mut Node<K, V>,
    right: *mut Node<K, V>,
    color: Color,

    /// Cached hash of the subtree rooted at this node, combining the left
    /// subtree's hash, this entry's labeled hash and the right subtree's hash
    /// (computed by `Node::subtree_hash`). Must be refreshed whenever the
    /// entry or either child changes.
    subtree_hash: Hash,
}
130
131impl<K: 'static + Label, V: AsHashTree + 'static> Node<K, V> {
132 #[allow(clippy::let_and_return)]
133 fn new(key: K, value: V) -> *mut Self {
134 let value_hash = value.root_hash();
135 let data_hash = labeled_hash(&key.as_label(), &value_hash);
136 let node = Box::into_raw(Box::new(Self {
137 key,
138 value,
139 left: Node::null(),
140 right: Node::null(),
141 color: Color::Red,
142 subtree_hash: data_hash,
143 }));
144
145 #[cfg(test)]
146 debug_alloc::mark_pointer_allocated(node);
147
148 node
149 }
150
151 unsafe fn data_hash(n: *mut Self) -> Hash {
152 debug_assert!(!n.is_null());
153 labeled_hash(&(*n).key.as_label(), &(*n).value.root_hash())
154 }
155
156 unsafe fn left_hash_tree<'a>(n: *mut Self) -> HashTree<'a> {
157 debug_assert!(!n.is_null());
158 if (*n).left.is_null() {
159 Empty
160 } else {
161 Pruned((*(*n).left).subtree_hash)
162 }
163 }
164
165 unsafe fn right_hash_tree<'a>(n: *mut Self) -> HashTree<'a> {
166 debug_assert!(!n.is_null());
167 if (*n).right.is_null() {
168 Empty
169 } else {
170 Pruned((*(*n).right).subtree_hash)
171 }
172 }
173
174 fn null() -> *mut Self {
175 std::ptr::null::<Self>() as *mut Node<K, V>
176 }
177
178 unsafe fn data_tree<'a>(n: *mut Self) -> HashTree<'a> {
179 debug_assert!(!n.is_null());
180 HashTree::Labeled((*n).key.as_label(), Box::new((*n).value.as_hash_tree()))
181 }
182
183 unsafe fn subtree_with<'a>(
184 n: *mut Self,
185 f: impl FnOnce(&'a V) -> HashTree<'a>,
186 ) -> HashTree<'a> {
187 debug_assert!(!n.is_null());
188
189 HashTree::Labeled((*n).key.as_label(), Box::new(f(&(*n).value)))
190 }
191
192 unsafe fn witness_tree<'a>(n: *mut Self) -> HashTree<'a> {
193 debug_assert!(!n.is_null());
194 let value_hash = (*n).value.root_hash();
195 HashTree::Labeled((*n).key.as_label(), Box::new(Pruned(value_hash)))
196 }
197
198 unsafe fn full_witness_tree<'a>(
199 n: *mut Self,
200 f: unsafe fn(*mut Self) -> HashTree<'a>,
201 ) -> HashTree<'a> {
202 if n.is_null() {
203 return Empty;
204 }
205 three_way_fork(
206 Self::full_witness_tree((*n).left, f),
207 f(n),
208 Self::full_witness_tree((*n).right, f),
209 )
210 }
211
212 unsafe fn delete(n: *mut Self) -> Option<(K, V)> {
213 if n.is_null() {
214 return None;
215 }
216 Self::delete((*n).left);
217 Self::delete((*n).right);
218 let node = Box::from_raw(n);
219
220 #[cfg(test)]
221 debug_alloc::mark_pointer_deleted(n);
222
223 Some((node.key, node.value))
224 }
225
226 unsafe fn subtree_hash(n: *mut Self) -> Hash {
227 if n.is_null() {
228 return Empty.reconstruct();
229 }
230
231 let h = Node::data_hash(n);
232
233 match ((*n).left.is_null(), (*n).right.is_null()) {
234 (true, true) => h,
235 (false, true) => fork_hash(&(*(*n).left).subtree_hash, &h),
236 (true, false) => fork_hash(&h, &(*(*n).right).subtree_hash),
237 (false, false) => fork_hash(
238 &(*(*n).left).subtree_hash,
239 &fork_hash(&h, &(*(*n).right).subtree_hash),
240 ),
241 }
242 }
243}
244
245pub struct RbTree<K: 'static + Label, V: AsHashTree + 'static> {
248 len: usize,
249 root: *mut Node<K, V>,
250}
251
252impl<K: 'static + Label, V: AsHashTree + 'static> Drop for RbTree<K, V> {
253 fn drop(&mut self) {
254 unsafe {
255 Node::delete(self.root);
256 }
257 }
258}
259
260impl<K: 'static + Label, V: AsHashTree + 'static> Default for RbTree<K, V> {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
266impl<K: 'static + Label, V: AsHashTree + 'static> RbTree<K, V> {
267 #[inline]
268 pub fn new() -> Self {
269 Self {
270 len: 0,
271 root: Node::null(),
272 }
273 }
274
275 #[inline]
276 pub fn len(&self) -> usize {
277 self.len
278 }
279
280 #[inline]
281 pub fn is_empty(&self) -> bool {
282 self.root.is_null()
283 }
284
285 pub fn entry(&mut self, key: K) -> entry::Entry<K, V> {
286 let node = unsafe { self.get_node(&key) };
287
288 if node.is_null() {
289 entry::Entry::Vacant(entry::VacantEntry { map: self, key })
290 } else {
291 entry::Entry::Occupied(entry::OccupiedEntry {
292 map: self,
293 key,
294 node,
295 })
296 }
297 }
298
299 #[inline]
300 pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
301 where
302 K: Borrow<Q>,
303 Q: Ord,
304 {
305 unsafe {
306 let mut root = self.root;
307 while !root.is_null() {
308 match key.cmp((*root).key.borrow()) {
309 Equal => return Some(&(*root).value),
310 Less => root = (*root).left,
311 Greater => root = (*root).right,
312 }
313 }
314 None
315 }
316 }
317
318 #[inline]
319 pub fn get_with(&self, cmp: impl Fn(&K) -> Ordering) -> Option<&V> {
320 unsafe {
321 let mut root = self.root;
322 while !root.is_null() {
323 match cmp(&(*root).key) {
324 Equal => return Some(&(*root).value),
325 Less => root = (*root).left,
326 Greater => root = (*root).right,
327 }
328 }
329 None
330 }
331 }
332
333 #[inline]
334 unsafe fn get_node(&self, key: &K) -> *mut Node<K, V> {
335 let mut root = self.root;
336 while !root.is_null() {
337 match key.cmp(&(*root).key) {
338 Equal => return root,
339 Less => root = (*root).left,
340 Greater => root = (*root).right,
341 }
342 }
343 Node::null()
344 }
345
346 #[inline]
348 pub fn modify<'a, Q: ?Sized, T>(&mut self, key: &Q, f: impl FnOnce(&'a mut V) -> T) -> Option<T>
349 where
350 K: Borrow<Q>,
351 Q: Ord,
352 {
353 unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static, T, Q: ?Sized>(
354 mut h: *mut Node<K, V>,
355 k: &Q,
356 f: impl FnOnce(&'a mut V) -> T,
357 ) -> Option<T>
358 where
359 K: Borrow<Q>,
360 Q: Ord,
361 {
362 if h.is_null() {
363 return None;
364 }
365
366 match k.cmp((*h).key.borrow()) {
367 Equal => {
368 let res = f(&mut (*h).value);
369 (*h).subtree_hash = Node::subtree_hash(h);
370 Some(res)
371 }
372 Less => {
373 let res = go((*h).left, k, f);
374 (*h).subtree_hash = Node::subtree_hash(h);
375 res
376 }
377 Greater => {
378 let res = go((*h).right, k, f);
379 (*h).subtree_hash = Node::subtree_hash(h);
380 res
381 }
382 }
383 }
384 unsafe { go(self.root, key, f) }
385 }
386
387 pub fn modify_max_with_prefix<'a, P: ?Sized, T>(
389 &mut self,
390 prefix: &P,
391 f: impl FnOnce(&'a K, &'a mut V) -> T,
392 ) -> Option<T>
393 where
394 K: Prefix<P>,
395 P: Ord,
396 {
397 unsafe fn go<
398 'a,
399 K: Label + 'static,
400 V: AsHashTree + 'static,
401 P: ?Sized,
402 T,
403 F: FnOnce(&'a K, &'a mut V) -> T,
404 >(
405 mut h: *mut Node<K, V>,
406 prefix: &P,
407 f: F,
408 ) -> (Option<T>, Option<F>)
409 where
410 K: Prefix<P>,
411 P: Ord,
412 {
413 if h.is_null() {
414 return (None, Some(f));
415 }
416
417 let node_key = &(*h).key;
418 let key_prefix = node_key.borrow();
419
420 let res = match key_prefix.cmp(prefix) {
421 Greater | Equal if node_key.is_prefix(prefix) => match go((*h).right, prefix, f) {
422 (None, Some(f)) => {
423 let ret = f(node_key, &mut (*h).value);
424 (Some(ret), None)
425 }
426 ret => ret,
427 },
428 Greater => go((*h).left, prefix, f),
429 Less | Equal => go((*h).right, prefix, f),
430 };
431
432 if res.0.is_some() {
433 (*h).subtree_hash = Node::subtree_hash(h);
434 }
435
436 res
437 }
438
439 unsafe { go(self.root, prefix, f).0 }
440 }
441
442 pub fn max_entry_with_prefix<P: ?Sized>(&self, prefix: &P) -> Option<(&K, &V)>
443 where
444 K: Prefix<P>,
445 P: Ord,
446 {
447 unsafe fn go<'a, K: 'static + Label, V, P: ?Sized>(
448 n: *mut Node<K, V>,
449 prefix: &P,
450 ) -> Option<(&'a K, &'a V)>
451 where
452 K: Prefix<P>,
453 P: Ord,
454 {
455 if n.is_null() {
456 return None;
457 }
458
459 let node_key = &(*n).key;
460 let key_prefix = node_key.borrow();
461 match key_prefix.cmp(prefix) {
462 Greater | Equal if node_key.is_prefix(prefix) => {
463 go((*n).right, prefix).or(Some((node_key, &(*n).value)))
464 }
465 Greater => go((*n).left, prefix),
466 Less | Equal => go((*n).right, prefix),
467 }
468 }
469 unsafe { go(self.root, prefix) }
470 }
471
472 fn range_witness<'a>(
473 &'a self,
474 left: Option<KeyBound<'a, K>>,
475 right: Option<KeyBound<'a, K>>,
476 f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
477 ) -> HashTree<'a> {
478 match (left, right) {
479 (None, None) => unsafe { Node::full_witness_tree(self.root, f) },
480 (Some(l), None) => self.witness_range_above(l, f),
481 (None, Some(r)) => self.witness_range_below(r, f),
482 (Some(l), Some(r)) => self.witness_range_between(l, r, f),
483 }
484 }
485
486 #[inline]
492 pub fn witness<Q: ?Sized>(&self, key: &Q) -> HashTree<'_>
493 where
494 K: Borrow<Q>,
495 Q: Ord,
496 {
497 self.nested_witness(key, |v| v.as_hash_tree())
498 }
499
500 #[inline]
504 pub fn nested_witness<'a, Q: ?Sized>(
505 &'a self,
506 key: &Q,
507 f: impl FnOnce(&'a V) -> HashTree<'a>,
508 ) -> HashTree<'a>
509 where
510 K: Borrow<Q>,
511 Q: Ord,
512 {
513 if let Some(t) = self.lookup_and_build_witness(key, f) {
514 return t;
515 }
516 self.range_witness(
517 self.lower_bound(key),
518 self.upper_bound(key),
519 Node::witness_tree,
520 )
521 }
522
523 #[inline]
527 pub fn keys(&self) -> HashTree<'_> {
528 unsafe { Node::full_witness_tree(self.root, Node::witness_tree) }
529 }
530
531 #[inline]
535 pub fn key_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &Q1, last: &Q2) -> HashTree<'_>
536 where
537 K: Borrow<Q1> + Borrow<Q2>,
538 Q1: Ord,
539 Q2: Ord,
540 {
541 self.range_witness(
542 self.lower_bound(first),
543 self.upper_bound(last),
544 Node::witness_tree,
545 )
546 }
547
548 #[inline]
551 pub fn value_range<Q1: ?Sized, Q2: ?Sized>(&self, first: &Q1, last: &Q2) -> HashTree<'_>
552 where
553 K: Borrow<Q1> + Borrow<Q2>,
554 Q1: Ord,
555 Q2: Ord,
556 {
557 self.range_witness(
558 self.lower_bound(first),
559 self.upper_bound(last),
560 Node::data_tree,
561 )
562 }
563
564 #[inline]
567 pub fn keys_with_prefix<P: ?Sized>(&self, prefix: &P) -> HashTree<'_>
568 where
569 K: Prefix<P>,
570 P: Ord,
571 {
572 self.range_witness(
573 self.lower_bound(prefix),
574 self.right_prefix_neighbor(prefix),
575 Node::witness_tree,
576 )
577 }
578
579 #[inline]
581 pub fn for_each<'a, F>(&'a self, mut f: F)
582 where
583 F: 'a + FnMut(&'a K, &'a V),
584 {
585 unsafe fn visit<'a, K, V, F>(n: *mut Node<K, V>, f: &mut F)
586 where
587 F: 'a + FnMut(&'a K, &'a V),
588 K: 'static + Label,
589 V: 'a + AsHashTree,
590 {
591 debug_assert!(!n.is_null());
592 if !(*n).left.is_null() {
593 visit((*n).left, f)
594 }
595 (*f)(&(*n).key, &(*n).value);
596 if !(*n).right.is_null() {
597 visit((*n).right, f)
598 }
599 }
600 if self.root.is_null() {
601 return;
602 }
603 unsafe { visit(self.root, &mut f) }
604 }
605
606 fn witness_range_above<'a>(
607 &'a self,
608 lo: KeyBound<'a, K>,
609 f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
610 ) -> HashTree<'a> {
611 unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static>(
612 n: *mut Node<K, V>,
613 lo: KeyBound<'a, K>,
614 f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
615 ) -> HashTree<'a> {
616 if n.is_null() {
617 return Empty;
618 }
619 match (*n).key.cmp(lo.as_ref()) {
620 Equal => three_way_fork(
621 Node::left_hash_tree(n),
622 match lo {
623 KeyBound::Exact(_) => f(n),
624 KeyBound::Neighbor(_) => Node::witness_tree(n),
625 },
626 Node::full_witness_tree((*n).right, f),
627 ),
628 Less => three_way_fork(
629 Node::left_hash_tree(n),
630 Pruned(Node::data_hash(n)),
631 go((*n).right, lo, f),
632 ),
633 Greater => three_way_fork(
634 go((*n).left, lo, f),
635 f(n),
636 Node::full_witness_tree((*n).right, f),
637 ),
638 }
639 }
640 unsafe { go(self.root, lo, f) }
641 }
642
643 fn witness_range_below<'a>(
644 &'a self,
645 hi: KeyBound<'a, K>,
646 f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
647 ) -> HashTree<'a> {
648 unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static>(
649 n: *mut Node<K, V>,
650 hi: KeyBound<'a, K>,
651 f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
652 ) -> HashTree<'a> {
653 if n.is_null() {
654 return Empty;
655 }
656 match (*n).key.cmp(hi.as_ref()) {
657 Equal => three_way_fork(
658 Node::full_witness_tree((*n).left, f),
659 match hi {
660 KeyBound::Exact(_) => f(n),
661 KeyBound::Neighbor(_) => Node::witness_tree(n),
662 },
663 Node::right_hash_tree(n),
664 ),
665 Greater => three_way_fork(
666 go((*n).left, hi, f),
667 Pruned(Node::data_hash(n)),
668 Node::right_hash_tree(n),
669 ),
670 Less => three_way_fork(
671 Node::full_witness_tree((*n).left, f),
672 f(n),
673 go((*n).right, hi, f),
674 ),
675 }
676 }
677 unsafe { go(self.root, hi, f) }
678 }
679
680 fn witness_range_between<'a>(
681 &'a self,
682 lo: KeyBound<'a, K>,
683 hi: KeyBound<'a, K>,
684 f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
685 ) -> HashTree<'a> {
686 debug_assert!(
687 lo.as_ref() <= hi.as_ref(),
688 "lo = {:?} > hi = {:?}",
689 lo.as_ref().as_label(),
690 hi.as_ref().as_label()
691 );
692 unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static>(
693 n: *mut Node<K, V>,
694 lo: KeyBound<'a, K>,
695 hi: KeyBound<'a, K>,
696 f: unsafe fn(*mut Node<K, V>) -> HashTree<'a>,
697 ) -> HashTree<'a> {
698 if n.is_null() {
699 return Empty;
700 }
701 let k = &(*n).key;
702 match (lo.as_ref().cmp(k), k.cmp(hi.as_ref())) {
703 (Less, Less) => {
704 let left = go((*n).left, lo, hi, f);
705 let right = go((*n).right, lo, hi, f);
706 three_way_fork(left, f(n), right)
707 }
708 (Equal, Equal) => three_way_fork(
709 Node::left_hash_tree(n),
710 match (lo, hi) {
711 (KeyBound::Exact(_), _) => f(n),
712 (_, KeyBound::Exact(_)) => f(n),
713 _ => Node::witness_tree(n),
714 },
715 Node::right_hash_tree(n),
716 ),
717 (_, Equal) => three_way_fork(
718 go((*n).left, lo, hi, f),
719 match hi {
720 KeyBound::Exact(_) => f(n),
721 KeyBound::Neighbor(_) => Node::witness_tree(n),
722 },
723 Node::right_hash_tree(n),
724 ),
725 (Equal, _) => three_way_fork(
726 Node::left_hash_tree(n),
727 match lo {
728 KeyBound::Exact(_) => f(n),
729 KeyBound::Neighbor(_) => Node::witness_tree(n),
730 },
731 go((*n).right, lo, hi, f),
732 ),
733 (Less, Greater) => three_way_fork(
734 go((*n).left, lo, hi, f),
735 Pruned(Node::data_hash(n)),
736 Node::right_hash_tree(n),
737 ),
738 (Greater, Less) => three_way_fork(
739 Node::left_hash_tree(n),
740 Pruned(Node::data_hash(n)),
741 go((*n).right, lo, hi, f),
742 ),
743 _ => Pruned((*n).subtree_hash),
744 }
745 }
746 unsafe { go(self.root, lo, hi, f) }
747 }
748
749 fn lower_bound<Q: ?Sized>(&self, key: &Q) -> Option<KeyBound<'_, K>>
750 where
751 K: Borrow<Q>,
752 Q: Ord,
753 {
754 unsafe fn go<'a, K: 'static + Label, V, Q: ?Sized>(
755 n: *mut Node<K, V>,
756 key: &Q,
757 ) -> Option<KeyBound<'a, K>>
758 where
759 K: Borrow<Q>,
760 Q: Ord,
761 {
762 if n.is_null() {
763 return None;
764 }
765 let node_key = &(*n).key;
766 match node_key.borrow().cmp(key) {
767 Less => go((*n).right, key).or(Some(KeyBound::Neighbor(node_key))),
768 Equal => Some(KeyBound::Exact(node_key)),
769 Greater => go((*n).left, key),
770 }
771 }
772 unsafe { go(self.root, key) }
773 }
774
775 fn upper_bound<Q: ?Sized>(&self, key: &Q) -> Option<KeyBound<'_, K>>
776 where
777 K: Borrow<Q>,
778 Q: Ord,
779 {
780 unsafe fn go<'a, K: 'static + Label, V, Q: ?Sized>(
781 n: *mut Node<K, V>,
782 key: &Q,
783 ) -> Option<KeyBound<'a, K>>
784 where
785 K: Borrow<Q>,
786 Q: Ord,
787 {
788 if n.is_null() {
789 return None;
790 }
791 let node_key = &(*n).key;
792 match node_key.borrow().cmp(key) {
793 Less => go((*n).right, key),
794 Equal => Some(KeyBound::Exact(node_key)),
795 Greater => go((*n).left, key).or(Some(KeyBound::Neighbor(node_key))),
796 }
797 }
798 unsafe { go(self.root, key) }
799 }
800
801 fn right_prefix_neighbor<P: ?Sized>(&self, prefix: &P) -> Option<KeyBound<'_, K>>
802 where
803 K: Prefix<P>,
804 P: Ord,
805 {
806 unsafe fn go<'a, K: 'static + Label, V, P: ?Sized>(
807 n: *mut Node<K, V>,
808 prefix: &P,
809 ) -> Option<KeyBound<'a, K>>
810 where
811 K: Prefix<P>,
812 P: Ord,
813 {
814 if n.is_null() {
815 return None;
816 }
817 let node_key = &(*n).key;
818 let key_prefix = node_key.borrow();
819 match key_prefix.cmp(prefix) {
820 Greater if node_key.is_prefix(prefix) => go((*n).right, prefix),
821 Greater => go((*n).left, prefix).or(Some(KeyBound::Neighbor(node_key))),
822 Less | Equal => go((*n).right, prefix),
823 }
824 }
825 unsafe { go(self.root, prefix) }
826 }
827
828 fn lookup_and_build_witness<'a, Q: ?Sized>(
829 &'a self,
830 key: &Q,
831 f: impl FnOnce(&'a V) -> HashTree<'a>,
832 ) -> Option<HashTree<'a>>
833 where
834 K: Borrow<Q>,
835 Q: Ord,
836 {
837 unsafe fn go<'a, K: 'static + Label, V: AsHashTree + 'static, Q: ?Sized>(
838 n: *mut Node<K, V>,
839 key: &Q,
840 f: impl FnOnce(&'a V) -> HashTree<'a>,
841 ) -> Option<HashTree<'a>>
842 where
843 K: Borrow<Q>,
844 Q: Ord,
845 {
846 if n.is_null() {
847 return None;
848 }
849 match key.cmp((*n).key.borrow()) {
850 Equal => Some(three_way_fork(
851 Node::left_hash_tree(n),
852 Node::subtree_with(n, f),
853 Node::right_hash_tree(n),
854 )),
855 Less => {
856 let subtree = go((*n).left, key, f)?;
857 Some(three_way_fork(
858 subtree,
859 Pruned(Node::data_hash(n)),
860 Node::right_hash_tree(n),
861 ))
862 }
863 Greater => {
864 let subtree = go((*n).right, key, f)?;
865 Some(three_way_fork(
866 Node::left_hash_tree(n),
867 Pruned(Node::data_hash(n)),
868 subtree,
869 ))
870 }
871 }
872 }
873 unsafe { go(self.root, key, f) }
874 }
875
876 #[inline]
878 pub fn insert(&mut self, key: K, value: V) -> (Option<V>, &mut V) {
879 struct GoResult<'a, K, V> {
880 node: *mut Node<K, V>,
881 old_value: Option<V>,
882 new_value_ref: &'a mut V,
883 }
884
885 unsafe fn go<K: 'static + Label, V: AsHashTree + 'static>(
886 mut h: *mut Node<K, V>,
887 k: K,
888 mut v: V,
889 ) -> GoResult<'static, K, V> {
890 if h.is_null() {
891 let node = Node::new(k, v);
892 return GoResult {
893 node,
894 old_value: None,
895 new_value_ref: &mut (*node).value,
896 };
897 }
898
899 let (old_value, new_value_ref) = match k.cmp(&(*h).key) {
900 Equal => {
901 std::mem::swap(&mut (*h).value, &mut v);
902 (*h).subtree_hash = Node::subtree_hash(h);
903 (Some(v), &mut (*h).value)
904 }
905 Less => {
906 let res = go((*h).left, k, v);
907 (*h).left = res.node;
908 (*h).subtree_hash = Node::subtree_hash(h);
909 (res.old_value, res.new_value_ref)
910 }
911 Greater => {
912 let res = go((*h).right, k, v);
913 (*h).right = res.node;
914 (*h).subtree_hash = Node::subtree_hash(h);
915 (res.old_value, res.new_value_ref)
916 }
917 };
918
919 GoResult {
920 node: balance(h),
921 old_value,
922 new_value_ref,
923 }
924 }
925
926 unsafe {
927 let mut result = go(self.root, key, value);
928 (*result.node).color = Color::Black;
929
930 #[cfg(test)]
931 debug_assert!(
932 is_balanced(result.node),
933 "the tree is not balanced:\n{:?}",
934 DebugView(result.node)
935 );
936 #[cfg(test)]
937 debug_assert!(!has_dangling_pointers(result.node));
938
939 if result.old_value.is_none() {
940 self.len += 1;
941 }
942
943 self.root = result.node;
944 (result.old_value, result.new_value_ref)
945 }
946 }
947
948 #[inline]
950 pub fn delete<Q: ?Sized>(&mut self, key: &Q) -> Option<(K, V)>
951 where
952 K: Borrow<Q>,
953 Q: Ord,
954 {
955 unsafe fn move_red_left<K: 'static + Label, V: AsHashTree + 'static>(
956 mut h: *mut Node<K, V>,
957 ) -> *mut Node<K, V> {
958 flip_colors(h);
959 if is_red((*(*h).right).left) {
960 (*h).right = rotate_right((*h).right);
961 h = rotate_left(h);
962 flip_colors(h);
963 }
964 h
965 }
966
967 unsafe fn move_red_right<K: 'static + Label, V: AsHashTree + 'static>(
968 mut h: *mut Node<K, V>,
969 ) -> *mut Node<K, V> {
970 flip_colors(h);
971 if is_red((*(*h).left).left) {
972 h = rotate_right(h);
973 flip_colors(h);
974 }
975 h
976 }
977
978 #[inline]
979 unsafe fn min<K: 'static + Label, V: AsHashTree + 'static>(
980 mut h: *mut Node<K, V>,
981 ) -> *mut Node<K, V> {
982 while !(*h).left.is_null() {
983 h = (*h).left;
984 }
985 h
986 }
987
988 unsafe fn delete_min<K: 'static + Label, V: AsHashTree + 'static>(
989 mut h: *mut Node<K, V>,
990 result: &mut Option<(K, V)>,
991 ) -> *mut Node<K, V> {
992 if (*h).left.is_null() {
993 debug_assert!((*h).right.is_null());
994 *result = Some(Node::delete(h).unwrap());
995 return Node::null();
996 }
997 if !is_red((*h).left) && !is_red((*(*h).left).left) {
998 h = move_red_left(h);
999 }
1000 (*h).left = delete_min((*h).left, result);
1001 (*h).subtree_hash = Node::subtree_hash(h);
1002 balance(h)
1003 }
1004
1005 unsafe fn go<K: 'static + Label, V: AsHashTree + 'static, Q: ?Sized>(
1006 mut h: *mut Node<K, V>,
1007 result: &mut Option<(K, V)>,
1008 key: &Q,
1009 ) -> *mut Node<K, V>
1010 where
1011 K: Borrow<Q>,
1012 Q: Ord,
1013 {
1014 if key < (*h).key.borrow() {
1015 if !is_red((*h).left) && !is_red((*(*h).left).left) {
1016 h = move_red_left(h);
1017 }
1018 (*h).left = go((*h).left, result, key);
1019 } else {
1020 if is_red((*h).left) {
1021 h = rotate_right(h);
1022 }
1023 if key == (*h).key.borrow() && (*h).right.is_null() {
1024 debug_assert!((*h).left.is_null());
1025 *result = Some(Node::delete(h).unwrap());
1026 return Node::null();
1027 }
1028
1029 if !is_red((*h).right) && !is_red((*(*h).right).left) {
1030 h = move_red_right(h);
1031 }
1032
1033 if key == (*h).key.borrow() {
1034 let m = min((*h).right);
1035 std::mem::swap(&mut (*h).key, &mut (*m).key);
1036 std::mem::swap(&mut (*h).value, &mut (*m).value);
1037 (*h).right = delete_min((*h).right, result);
1038 } else {
1039 (*h).right = go((*h).right, result, key);
1040 }
1041 }
1042 (*h).subtree_hash = Node::subtree_hash(h);
1043 balance(h)
1044 }
1045
1046 unsafe {
1047 self.get(key)?;
1048 if !is_red((*self.root).left) && !is_red((*self.root).right) {
1049 (*self.root).color = Color::Red;
1050 }
1051
1052 let mut result = None;
1053 self.root = go(self.root, &mut result, key);
1054 if !self.root.is_null() {
1055 (*self.root).color = Color::Black;
1056 }
1057
1058 #[cfg(test)]
1059 debug_assert!(
1060 is_balanced(self.root),
1061 "unbalanced map: {:?}",
1062 DebugView(self.root)
1063 );
1064
1065 #[cfg(test)]
1066 debug_assert!(result.is_some());
1067 self.len -= 1;
1068
1069 debug_assert!(self.get(key).is_none());
1070 result
1071 }
1072 }
1073}
1074
1075fn three_way_fork<'a>(l: HashTree<'a>, m: HashTree<'a>, r: HashTree<'a>) -> HashTree<'a> {
1076 match (l, m, r) {
1077 (Empty, m, Empty) => m,
1078 (l, m, Empty) => fork(l, m),
1079 (Empty, m, r) => fork(m, r),
1080 (Pruned(lhash), Pruned(mhash), Pruned(rhash)) => {
1081 Pruned(fork_hash(&lhash, &fork_hash(&mhash, &rhash)))
1082 }
1083 (l, Pruned(mhash), Pruned(rhash)) => fork(l, Pruned(fork_hash(&mhash, &rhash))),
1084 (l, m, r) => fork(l, fork(m, r)),
1085 }
1086}
1087
1088unsafe fn is_red<K, V>(x: *const Node<K, V>) -> bool {
1090 if x.is_null() {
1091 false
1092 } else {
1093 (*x).color == Color::Red
1094 }
1095}
1096
1097unsafe fn balance<K: Label + 'static, V: AsHashTree + 'static>(
1098 mut h: *mut Node<K, V>,
1099) -> *mut Node<K, V> {
1100 assert!(!h.is_null());
1101
1102 if is_red((*h).right) && !is_red((*h).left) {
1103 h = rotate_left(h);
1104 }
1105 if is_red((*h).left) && is_red((*(*h).left).left) {
1106 h = rotate_right(h);
1107 }
1108 if is_red((*h).left) && is_red((*h).right) {
1109 flip_colors(h)
1110 }
1111 h
1112}
1113
1114unsafe fn rotate_right<K: 'static + Label, V: AsHashTree + 'static>(
1116 h: *mut Node<K, V>,
1117) -> *mut Node<K, V> {
1118 debug_assert!(!h.is_null());
1119 debug_assert!(is_red((*h).left));
1120
1121 let mut x = (*h).left;
1122 (*h).left = (*x).right;
1123 (*x).right = h;
1124 (*x).color = (*(*x).right).color;
1125 (*(*x).right).color = Color::Red;
1126
1127 (*h).subtree_hash = Node::subtree_hash(h);
1128 (*x).subtree_hash = Node::subtree_hash(x);
1129
1130 x
1131}
1132
1133unsafe fn rotate_left<K: 'static + Label, V: AsHashTree + 'static>(
1134 h: *mut Node<K, V>,
1135) -> *mut Node<K, V> {
1136 debug_assert!(!h.is_null());
1137 debug_assert!(is_red((*h).right));
1138
1139 let mut x = (*h).right;
1140 (*h).right = (*x).left;
1141 (*x).left = h;
1142 (*x).color = (*(*x).left).color;
1143 (*(*x).left).color = Color::Red;
1144
1145 (*h).subtree_hash = Node::subtree_hash(h);
1146 (*x).subtree_hash = Node::subtree_hash(x);
1147
1148 x
1149}
1150
1151unsafe fn flip_colors<K, V>(h: *mut Node<K, V>) {
1152 (*h).color = (*h).color.flip();
1153 (*(*h).left).color = (*(*h).left).color.flip();
1154 (*(*h).right).color = (*(*h).right).color.flip();
1155}
1156
/// Test helper: checks that every root-to-leaf path has the same number of
/// black nodes (the count along the leftmost path), and asserts that no red
/// node has a red child.
#[cfg(test)]
unsafe fn is_balanced<K, V>(root: *mut Node<K, V>) -> bool {
    unsafe fn check<K, V>(node: *mut Node<K, V>, blacks_left: usize) -> bool {
        if node.is_null() {
            return blacks_left == 0;
        }
        let remaining = if is_red(node) {
            assert!(!is_red((*node).left));
            assert!(!is_red((*node).right));
            blacks_left
        } else {
            debug_assert!(blacks_left > 0);
            blacks_left - 1
        };
        check((*node).left, remaining) && check((*node).right, remaining)
    }

    let mut expected_blacks = 0;
    let mut cursor = root;
    while !cursor.is_null() {
        if !is_red(cursor) {
            expected_blacks += 1;
        }
        cursor = (*cursor).left;
    }
    check(root, expected_blacks)
}
1183
/// Test helper: returns `true` if any node reachable from `root` is not a
/// live allocation according to `debug_alloc`.
#[cfg(test)]
unsafe fn has_dangling_pointers<K, V>(root: *mut Node<K, V>) -> bool {
    // An empty subtree cannot dangle; otherwise the node itself and both of
    // its subtrees must be live.
    !root.is_null()
        && (!debug_alloc::is_live(root)
            || has_dangling_pointers((*root).left)
            || has_dangling_pointers((*root).right))
}
1194
1195struct DebugView<K, V>(*const Node<K, V>);
1196
1197impl<K: Label, V> fmt::Debug for DebugView<K, V> {
1198 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1199 unsafe fn go<K: Label, V>(
1200 f: &mut fmt::Formatter<'_>,
1201 h: *const Node<K, V>,
1202 offset: usize,
1203 ) -> fmt::Result {
1204 if h.is_null() {
1205 writeln!(f, "{:width$}[B] <null>", "", width = offset)
1206 } else {
1207 writeln!(
1208 f,
1209 "{:width$}[{}] {:?}",
1210 "",
1211 if is_red(h) { "R" } else { "B" },
1212 (*h).key.as_label(),
1213 width = offset
1214 )?;
1215 go(f, (*h).left, offset + 2)?;
1216 go(f, (*h).right, offset + 2)
1217 }
1218 }
1219 unsafe { go(f, self.0, 0) }
1220 }
1221}
1222
1223#[cfg(test)]
1224mod test;