balanced_tree_index/
lib.rs

1// Copyright (c) 2018-2023 The MobileCoin Foundation
2
3#![no_std]
4#![deny(missing_docs)]
5#![deny(unsafe_code)]
6
//! Defines an interface for a type that represents an index into a
//! memory-mapped, complete, balanced binary tree.
//!
//! The operations defined here mostly help with finding parents or common
//! ancestors in the tree structure.
//!
//! The index type is usually `u32` or `u64`, and these operations are usually
//! implemented with bit-twiddling tricks. Coding against this API means that
//! people reading ORAM code don't necessarily have to understand all of the
//! bit-twiddling tricks.
17
18use aligned_cmov::{
19    subtle::{ConstantTimeEq, ConstantTimeLess},
20    CMov,
21};
22use rand_core::RngCore;
23
24/// Trait representing a type that can represent a tree index in a balanced
25/// binary tree, using the numbering where the root is 1, and nodes are labelled
26/// consecutively level by level, using lexicographic order within a level.
27///
28/// All operations here should be constant time, leaking nothing about the input
29/// and &self, unless otherwise stated.
30pub trait TreeIndex: Copy + Eq + PartialEq + CMov {
31    /// The Zero index that is unused and does not actually refer to a node in
32    /// the tree.
33    const NONE: Self;
34
35    /// The index of the root of the tree, logically 1.
36    /// The parent of ROOT is NONE.
37    const ROOT: Self;
38
39    /// Find the i'th parent of a node.
40    fn parent(&self, i: u32) -> Self;
41
42    /// Find the height of a node.
43    /// This returns u32 because rust's count_leading_zeros does.
44    /// It is illegal to call this when self is the NONE value.
45    fn height(&self) -> u32;
46
47    /// For two nodes promised to be "peers" i.e. at the same height,
48    /// compute the distance from (either) to their common ancestor in the tree.
49    /// This is the number of times you have to compute "parent" before they are
50    /// equal. It is illegal to call this if the height of the two arguments
51    /// is not the same. Should not reveal anything else about the
52    /// arguments.
53    fn common_ancestor_distance_of_peers(&self, other: &Self) -> u32;
54
55    /// Compute the height of the common ancestor of any two nodes.
56    /// It is illegal to call this when either of the inputs is the NONE value.
57    fn common_ancestor_height(&self, other: &Self) -> u32 {
58        let ht_self = self.height();
59        let ht_other = other.height();
60
61        // Take the min in constant time of the two heights
62        let ht_min = {
63            let mut ht_min = ht_self;
64            ht_min.cmov(ht_other.ct_lt(&ht_self), &ht_other);
65            ht_min
66        };
67
68        let adjusted_self = self.parent(ht_self.wrapping_sub(ht_min));
69        let adjusted_other = other.parent(ht_other.wrapping_sub(ht_min));
70
71        debug_assert!(adjusted_self.height() == ht_min);
72        debug_assert!(adjusted_other.height() == ht_min);
73
74        let dist = adjusted_self.common_ancestor_distance_of_peers(&adjusted_other);
75        debug_assert!(dist <= ht_min);
76
77        ht_min.wrapping_sub(dist)
78    }
79
80    /// Random child at a given height.
81    /// This height must be the same or less than the height of the given node,
82    /// otherwise the call is illegal.
83    /// It is legal to call this on the NONE value, it will be as if ROOT was
84    /// passed.
85    fn random_child_at_height<R: RngCore>(&self, height: u32, rng: &mut R) -> Self;
86
87    /// Iterate over the parents of this node, including self.
88    /// Access patterns when evaluating this iterator reveal the height of self,
89    /// but not more than that.
90    fn parents(&self) -> ParentsIterator<Self> {
91        ParentsIterator::from(*self)
92    }
93}
94
95/// Iterator type over the sequence of parents of a TreeIndex
96pub struct ParentsIterator<I: TreeIndex> {
97    internal: I,
98}
99
100impl<I: TreeIndex> Iterator for ParentsIterator<I> {
101    type Item = I;
102
103    fn next(&mut self) -> Option<Self::Item> {
104        if self.internal == I::NONE {
105            None
106        } else {
107            let temp = self.internal;
108            self.internal = self.internal.parent(1);
109            Some(temp)
110        }
111    }
112}
113
114impl<I: TreeIndex> From<I> for ParentsIterator<I> {
115    fn from(internal: I) -> Self {
116        Self { internal }
117    }
118}
119
120// Implements TreeIndex for a type like u32 or u64
121// Because we need things like count_leading_ones and ::MAX and there are no
122// traits in the language for this, it is painful to do without macros.
123macro_rules! implement_tree_index_for_primitive {
124    ($uint:ty) => {
125        impl TreeIndex for $uint {
126            const NONE: $uint = 0;
127            const ROOT: $uint = 1;
128            fn parent(&self, i: u32) -> Self {
129                self >> i
130            }
131            fn height(&self) -> u32 {
132                debug_assert!(*self != 0);
133                const DIGITS_MINUS_ONE: u32 = <$uint>::MAX.leading_ones() - 1;
134                // Wrapping sub is used to avoid panics
135                // Note: We assume that leading_zeroes is compiling down to ctlz
136                // and is constant time.
137                DIGITS_MINUS_ONE.wrapping_sub(self.leading_zeros())
138            }
139            fn common_ancestor_distance_of_peers(&self, other: &Self) -> u32 {
140                debug_assert!(self.height() == other.height());
141                const DIGITS: u32 = <$uint>::MAX.leading_ones();
142                // Wrapping sub is used to avoid panics
143                // Note: We assume that leading_zeroes is compiling down to ctlz
144                // and is constant time.
145                DIGITS.wrapping_sub((self ^ other).leading_zeros())
146            }
147            fn random_child_at_height<R: RngCore>(&self, height: u32, rng: &mut R) -> Self {
148                // Make a copy of self that we can conditionally overwrite in case of none
149                let mut myself = *self;
150                myself.cmov(myself.ct_eq(&Self::NONE), &Self::ROOT);
151
152                // Wrapping sub is used to avoid panic, branching, in production
153                debug_assert!(height >= myself.height());
154                let num_bits_needed = height.wrapping_sub(myself.height());
155
156                // Note: Would be nice to use mc_util_from_random here instead of (next_u64 as
157                // $uint) Here we are taking the u64, casting to self, then masking it
158                // with bit mask for low order bits equal to number of random bits
159                // needed.
160                let randomness =
161                    (rng.next_u64() as $uint) & (((1 as $uint) << num_bits_needed) - 1);
162
163                // We shift myself over and xor in the random bits.
164                (myself << num_bits_needed) ^ randomness
165            }
166        }
167    };
168}
169
170implement_tree_index_for_primitive!(u32);
171implement_tree_index_for_primitive!(u64);
172
#[cfg(test)]
mod testing {
    use super::*;
    extern crate alloc;
    use alloc::vec::Vec;
    use core::fmt::Debug;

    // Index types exercised by the generic checks below: they can be built
    // from small literals and printed in assertion failure messages.
    // Running both u32 and u64 through the same generic checks keeps the two
    // suites from drifting apart (previously the u32 random-child test
    // accidentally exercised u64).
    trait TestIndex: TreeIndex + From<u8> + Debug {}
    impl<T: TreeIndex + From<u8> + Debug> TestIndex for T {}

    // Helper that takes a ParentsIterator and returns a Vec
    fn collect_to_vec<I: TreeIndex>(it: ParentsIterator<I>) -> Vec<I> {
        it.collect()
    }

    // Converts small literal node numbers into indices of type `I`.
    fn nodes<I: TestIndex>(values: &[u8]) -> Vec<I> {
        values.iter().map(|&v| I::from(v)).collect()
    }

    // Naive implementation of common_ancestor_distance_of_peers: walk both
    // parent chains in lockstep until they meet.
    fn naive_common_ancestor_distance_of_peers<I: TreeIndex>(lhs: &I, rhs: &I) -> u32 {
        let mut counter = 0u32;
        let mut it1 = lhs.parents();
        let mut it2 = rhs.parents();
        while it1.next().unwrap() != it2.next().unwrap() {
            counter += 1;
        }
        counter
    }

    // Checks `height` on the first five levels of the tree (nodes 1..=16).
    fn check_height<I: TestIndex>() {
        let cases: &[(u8, u32)] = &[
            (1, 0),
            (2, 1),
            (3, 1),
            (4, 2),
            (5, 2),
            (6, 2),
            (7, 2),
            (8, 3),
            (9, 3),
            (10, 3),
            (11, 3),
            (12, 3),
            (13, 3),
            (14, 3),
            (15, 3),
            (16, 4),
        ];
        for &(node, height) in cases.iter() {
            assert_eq!(I::from(node).height(), height, "height of node {}", node);
        }
    }

    // Checks that `random_child_at_height` lands at the requested height and
    // stays beneath the starting node, for all heights below `max_height`.
    fn check_random_child_at_height<I: TestIndex, R: RngCore>(rng: &mut R, max_height: u32) {
        let root = I::from(1u8);
        for ht in 0..max_height {
            for _ in 0..10 {
                let node = root.random_child_at_height(ht, &mut *rng);
                assert_eq!(node.height(), ht);
            }
        }

        let ten = I::from(10u8);
        for ht in 20..max_height {
            for _ in 0..10 {
                let node = ten.random_child_at_height(ht, &mut *rng);
                assert_eq!(node.height(), ht);
                assert!(node.parents().any(|x| x == ten));
            }
        }
    }

    // Asserts that the parents of `node` are exactly `chain`, in order.
    fn assert_parents<I: TestIndex>(node: u8, chain: &[u8]) {
        assert_eq!(
            collect_to_vec(I::from(node).parents()),
            nodes::<I>(chain),
            "parents of node {}",
            node
        );
    }

    // Checks the parents iterator on nodes 1..=19 against hand-written chains.
    fn check_parents_iterator<I: TestIndex>() {
        assert_parents::<I>(1, &[0b1]);
        assert_parents::<I>(2, &[0b10, 0b1]);
        assert_parents::<I>(3, &[0b11, 0b1]);
        assert_parents::<I>(4, &[0b100, 0b10, 0b1]);
        assert_parents::<I>(5, &[0b101, 0b10, 0b1]);
        assert_parents::<I>(6, &[0b110, 0b11, 0b1]);
        assert_parents::<I>(7, &[0b111, 0b11, 0b1]);
        assert_parents::<I>(8, &[0b1000, 0b100, 0b10, 0b1]);
        assert_parents::<I>(9, &[0b1001, 0b100, 0b10, 0b1]);
        assert_parents::<I>(10, &[0b1010, 0b101, 0b10, 0b1]);
        assert_parents::<I>(11, &[0b1011, 0b101, 0b10, 0b1]);
        assert_parents::<I>(12, &[0b1100, 0b110, 0b11, 0b1]);
        assert_parents::<I>(13, &[0b1101, 0b110, 0b11, 0b1]);
        assert_parents::<I>(14, &[0b1110, 0b111, 0b11, 0b1]);
        assert_parents::<I>(15, &[0b1111, 0b111, 0b11, 0b1]);
        assert_parents::<I>(16, &[0b10000, 0b1000, 0b100, 0b10, 0b1]);
        assert_parents::<I>(17, &[0b10001, 0b1000, 0b100, 0b10, 0b1]);
        assert_parents::<I>(18, &[0b10010, 0b1001, 0b100, 0b10, 0b1]);
        assert_parents::<I>(19, &[0b10011, 0b1001, 0b100, 0b10, 0b1]);
    }

    // Checks `common_ancestor_distance_of_peers` on hand-picked peer pairs.
    fn check_common_ancestor<I: TestIndex>() {
        let cases: &[(u8, u8, u32)] = &[
            (1, 1, 0),
            (2, 2, 0),
            (2, 3, 1),
            (3, 3, 0),
            (4, 7, 2),
            (4, 5, 1),
            (4, 6, 2),
            (7, 7, 0),
            (7, 6, 1),
            (7, 5, 2),
            (17, 31, 4),
            (17, 23, 3),
            (17, 19, 2),
        ];
        for &(lhs, rhs, distance) in cases.iter() {
            assert_eq!(
                I::from(lhs).common_ancestor_distance_of_peers(&I::from(rhs)),
                distance,
                "distance between {} and {}",
                lhs,
                rhs
            );
        }
    }

    // Checks that `common_ancestor_distance_of_peers` agrees with the naive
    // implementation on random peers.
    fn check_common_ancestor_distance_conformance<I: TestIndex, R: RngCore>(rng: &mut R) {
        let root = I::from(1u8);
        for ht in 0..30 {
            for _ in 0..10 {
                let node = root.random_child_at_height(ht, &mut *rng);
                let node2 = root.random_child_at_height(ht, &mut *rng);
                assert_eq!(
                    node.common_ancestor_distance_of_peers(&node2),
                    naive_common_ancestor_distance_of_peers(&node, &node2)
                );
            }
        }

        let sixteen = I::from(16u8);
        for ht in 20..30 {
            for _ in 0..10 {
                let node = sixteen.random_child_at_height(ht, &mut *rng);
                let node2 = sixteen.random_child_at_height(ht, &mut *rng);
                assert_eq!(
                    node.common_ancestor_distance_of_peers(&node2),
                    naive_common_ancestor_distance_of_peers(&node, &node2)
                );
            }
        }
    }

    // Checks `common_ancestor_height` for nodes at the same and at different
    // heights.
    fn check_common_ancestor_height<I: TestIndex>() {
        let cases: &[(u8, u8, u32)] = &[
            (1, 1, 0),
            (2, 2, 1),
            (4, 4, 2),
            (8, 8, 3),
            (8, 4, 2),
            (8, 7, 0),
            (8, 3, 0),
            (8, 9, 2),
            (8, 11, 1),
            (8, 13, 0),
            (16, 8, 3),
            (16, 4, 2),
            (16, 7, 0),
            (16, 3, 0),
            (16, 9, 2),
            (16, 11, 1),
            (16, 13, 0),
            (17, 15, 0),
            (17, 19, 2),
            (17, 21, 1),
            (17, 31, 0),
            (17, 63, 0),
            (17, 127, 0),
        ];
        for &(lhs, rhs, height) in cases.iter() {
            assert_eq!(
                I::from(lhs).common_ancestor_height(&I::from(rhs)),
                height,
                "common ancestor height of {} and {}",
                lhs,
                rhs
            );
        }
    }

    // Test height calculations
    #[test]
    fn test_height_u64() {
        check_height::<u64>();
    }

    // Test height calculations
    #[test]
    fn test_height_u32() {
        check_height::<u32>();
    }

    // Test random_child_at_height
    #[test]
    fn test_random_child_at_height_u64() {
        test_helper::run_with_several_seeds(|mut rng| {
            check_random_child_at_height::<u64, _>(&mut rng, 40)
        })
    }

    // Test random_child_at_height
    #[test]
    fn test_random_child_at_height_u32() {
        test_helper::run_with_several_seeds(|mut rng| {
            check_random_child_at_height::<u32, _>(&mut rng, 30)
        })
    }

    // Test that parents iterator is giving expected outputs
    #[test]
    fn test_parents_iterator_u64() {
        check_parents_iterator::<u64>();
    }

    // Test that parents iterator is giving expected outputs
    #[test]
    fn test_parents_iterator_u32() {
        check_parents_iterator::<u32>();
    }

    // Test that common_ancestor_distance_of_peers is giving expected outputs
    #[test]
    fn test_common_ancestor_u64() {
        check_common_ancestor::<u64>();
    }

    // Test that common_ancestor_distance_of_peers is giving expected outputs
    #[test]
    fn test_common_ancestor_u32() {
        check_common_ancestor::<u32>();
    }

    // Test that common_ancestor_distance_of_peers agrees with the naive
    // implementation
    #[test]
    fn common_ancestor_distance_conformance_u64() {
        test_helper::run_with_several_seeds(|mut rng| {
            check_common_ancestor_distance_conformance::<u64, _>(&mut rng)
        })
    }

    // Test that common_ancestor_distance_of_peers agrees with the naive
    // implementation
    #[test]
    fn common_ancestor_distance_conformance_u32() {
        test_helper::run_with_several_seeds(|mut rng| {
            check_common_ancestor_distance_conformance::<u32, _>(&mut rng)
        })
    }

    // Test that common_ancestor_height is giving expected results for nodes
    // at different heights.
    #[test]
    fn common_ancestor_height_u64() {
        check_common_ancestor_height::<u64>();
    }

    // Test that common_ancestor_height is giving expected results for nodes
    // at different heights.
    #[test]
    fn common_ancestor_height_u32() {
        check_common_ancestor_height::<u32>();
    }
}