healm/
lib.rs

1//! **He**ap **al**located **me**rkle tree.
2#![deny(missing_docs)]
3#![deny(clippy::pedantic)]
4#![deny(rustdoc::broken_intra_doc_links)]
5#![feature(allocator_api)]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7#![cfg_attr(feature = "bench", feature(test))]
8#![no_std]
9
10extern crate alloc;
11extern crate core;
12
13use alloc::alloc::{Allocator, Global, Layout};
14
15use core::{
16    mem,
17    ptr::{self, NonNull},
18    slice,
19};
20
21#[cfg(feature = "blake3")]
22#[cfg_attr(docsrs, doc(cfg(feature = "blake3")))]
23mod blake3;
24
/// Types that can be nodes in a `HamTree`. `Aggregate`s can be aggregated to
/// produce a new node.
///
/// `Aggregate` types must be `Eq`, since comparison against the all-zero value
/// is used to determine whether a `HamTree` node is empty or not.
///
/// # Safety
/// The implementer must ensure that the type is safely zeroable. If
/// [`mem::zeroed`] is safe to call on the type, then it is also safe to
/// implement `Aggregate`.
///
/// The zeroed value doubles as the tree's "empty" node, so a leaf equal to it
/// cannot be told apart from an absent leaf.
pub unsafe trait Aggregate: Eq + Sized {
    /// Aggregate the given nodes into a new node.
    ///
    /// When inserting, `HamTree` calls this with the `A` children of a parent
    /// whenever at least one of them is non-empty. A parent whose children are
    /// all empty is set to the empty node without calling this.
    fn aggregate(nodes: &[Self]) -> Self;
}
39
40fn empty_node<T: Aggregate>() -> T {
41    unsafe { mem::zeroed() }
42}
43
44unsafe impl<T, const H: usize, const A: usize, Alloc: Allocator> Send
45    for HamTree<T, H, A, Alloc>
46{
47}
48unsafe impl<T, const H: usize, const A: usize, Alloc: Allocator> Sync
49    for HamTree<T, H, A, Alloc>
50{
51}
52
53/// A heap allocated Merkle tree.
54pub struct HamTree<T, const H: usize, const A: usize, Alloc: Allocator = Global>
55{
56    base: *mut T,
57    alloc: Alloc,
58}
59
60impl<T, const H: usize, const A: usize, Alloc: Allocator>
61    HamTree<T, H, A, Alloc>
62{
63    /// The maximum number of leaves a tree can hold.
64    pub const N_LEAVES: usize = n_tree_leaves(H, A);
65
66    /// Layout of the tree in memory.
67    const LAYOUT: Layout = tree_layout::<T>(H, A);
68}
69
70impl<T, const H: usize, const A: usize> HamTree<T, H, A>
71where
72    T: Aggregate,
73{
74    /// Construct a new, empty `HamTree`.
75    ///
76    /// The tree will not allocate until leaves are inserted.
77    #[must_use]
78    pub const fn new() -> Self {
79        Self {
80            base: ptr::null_mut(),
81            alloc: Global,
82        }
83    }
84}
85
86impl<T, const H: usize, const A: usize, Alloc> HamTree<T, H, A, Alloc>
87where
88    T: Aggregate,
89    Alloc: Allocator,
90{
91    /// Construct a new, empty `HamTree`, that will allocate using the given
92    /// `alloc`ator.
93    ///
94    /// The tree will not allocate until leaves are inserted.
95    pub const fn new_in(alloc: Alloc) -> Self {
96        Self {
97            base: ptr::null_mut(),
98            alloc,
99        }
100    }
101
102    /// Inserts a leaf at position `index` in the tree, ejecting the last
103    /// element occupying the position, if any.
104    ///
105    /// # Panics
106    /// Panics if `index >= capacity`, or the underlying allocator fails if it
107    /// is the first insertion.
108    pub fn insert(&mut self, index: usize, leaf: T) -> Option<T> {
109        assert!(index < Self::N_LEAVES, "Index out of bounds");
110
111        self.ensure_allocated();
112
113        // safety: the memory was just allocated, and we ensure in the layout
114        // that our calculations never leave the bounds of the allocated object
115        //
116        // # See docs/layout.svg
117        // # https://doc.rust-lang.org/core/ptr/index.html#allocated-object
118        unsafe {
119            let mut level_ptr = self.base;
120            let mut index = index;
121
122            // Modify the leaf node
123            let mut leaf = leaf;
124            let leaf_ptr = level_ptr.add(index);
125            ptr::swap(leaf_ptr, &mut leaf);
126
127            let empty_children: [T; A] = mem::zeroed();
128
129            // Propagate changes towards the root
130            let mut n_nodes = Self::N_LEAVES;
131            for _ in 0..H {
132                let next_level_ptr = level_ptr.add(n_nodes);
133
134                let next_n_nodes = n_nodes / A;
135                let next_index = index / A;
136
137                let children_index = index - (index % A);
138                let children_ptr = level_ptr.add(children_index);
139                let children: *const [T; A] = children_ptr.cast();
140
141                let parent_ptr = next_level_ptr.add(next_index);
142                // The new parent will be empty if all children are empty nodes,
143                // otherwise it will be the aggregate of the children.
144                let parent = if *children == empty_children {
145                    empty_node()
146                } else {
147                    T::aggregate(&*children)
148                };
149                *parent_ptr = parent;
150
151                index = next_index;
152                n_nodes = next_n_nodes;
153
154                level_ptr = next_level_ptr;
155            }
156
157            if leaf == empty_node() {
158                None
159            } else {
160                Some(leaf)
161            }
162        }
163    }
164
165    /// Removes the leaf at the given index, returning it if present.
166    pub fn remove(&mut self, index: usize) -> Option<T> {
167        if self.is_unallocated() {
168            return None;
169        }
170
171        self.insert(index, empty_node())
172    }
173
174    fn empty(node: &T) -> Option<&T> {
175        if *node == empty_node::<T>() {
176            None
177        } else {
178            Some(node)
179        }
180    }
181
182    /// Returns the leaf at the given index, if any.
183    pub fn leaf(&self, index: usize) -> Option<&T> {
184        if self.is_unallocated() {
185            return None;
186        }
187
188        // safety: we check that the tree is allocated above, so de-referencing
189        // is safe.
190        unsafe {
191            let leaf_ptr = self.base.add(index);
192            let leaf = &*leaf_ptr.cast::<T>();
193
194            Self::empty(leaf)
195        }
196    }
197
198    /// Returns an iterator over the leaves of the tree.
199    pub fn leaves(&self) -> impl Iterator<Item = &T> {
200        if self.is_unallocated() {
201            return [].iter().filter_map(Self::empty);
202        }
203
204        // safety: we check that the tree is allocated above, so de-referencing
205        // is safe.
206        unsafe {
207            slice::from_raw_parts(self.base, Self::N_LEAVES)
208                .iter()
209                .filter_map(Self::empty)
210        }
211    }
212
213    /// Returns the branch at the given index, if any.
214    pub fn branch(&self, index: usize) -> Option<HamBranch<T, H, A>> {
215        if self.is_unallocated() {
216            return None;
217        }
218
219        // safety: we check that the tree is allocated above, so de-referencing
220        // is safe.
221        unsafe {
222            let mut level_ptr = self.base;
223            let mut index = index;
224
225            // If the leaf is empty, the branch doesn't exist
226            let leaf_ptr = level_ptr.add(index);
227            if *leaf_ptr == empty_node() {
228                return None;
229            }
230
231            let mut offsets = [0; H];
232            let mut levels: [[T; A]; H] = mem::zeroed();
233
234            // Propagate changes towards the root
235            let mut n_nodes = Self::N_LEAVES;
236            for h in 0..H {
237                let next_level_ptr = level_ptr.add(n_nodes);
238
239                let next_n_nodes = n_nodes / A;
240                let next_index = index / A;
241
242                let children_index = index - (index % A);
243                let children_ptr = level_ptr.add(children_index);
244                let children: *const [T; A] = children_ptr.cast();
245
246                offsets[h] = index - children_index;
247                levels[h] = ptr::read(children);
248
249                index = next_index;
250                n_nodes = next_n_nodes;
251
252                level_ptr = next_level_ptr;
253            }
254
255            Some(HamBranch {
256                root: ptr::read(level_ptr),
257                levels,
258                offsets,
259            })
260        }
261    }
262
263    /// Returns the root of the tree.
264    ///
265    /// If no leaves have been inserted, it returns `None`.
266    pub fn root(&self) -> Option<&T> {
267        if self.is_unallocated() {
268            return None;
269        }
270
271        // safety: we check that the tree is allocated above, so de-referencing
272        // the root is safe.
273        unsafe {
274            let root_ptr = self.base.add(n_tree_nodes(H, A) - 1);
275            let root = &*root_ptr.cast::<T>();
276
277            if *root == empty_node() {
278                None
279            } else {
280                Some(root)
281            }
282        }
283    }
284
285    /// The maximum number of leaves the tree can hold.
286    ///
287    /// This number is the same as [`N_LEAVES`].
288    ///
289    /// [`N_LEAVES`]: Self::N_LEAVES
290    pub const fn capacity(&self) -> usize {
291        Self::N_LEAVES
292    }
293
294    /// Ensures that the tree is allocated.
295    ///
296    /// # Panics
297    /// Panics if the underlying allocator fails.
298    fn ensure_allocated(&mut self) {
299        if self.is_unallocated() {
300            match self.alloc.allocate_zeroed(Self::LAYOUT) {
301                Ok(ptr) => self.base = ptr.as_ptr().cast(),
302                Err(err) => panic!("HamTree failed allocation: {err}"),
303            }
304        }
305    }
306
307    fn is_unallocated(&self) -> bool {
308        self.base.is_null()
309    }
310}
311
312impl<T, const H: usize, const A: usize, Alloc> Drop for HamTree<T, H, A, Alloc>
313where
314    Alloc: Allocator,
315{
316    fn drop(&mut self) {
317        // safety: we check if the tree is allocated using `NonNull::new` so
318        // de-allocating is safe.
319        unsafe {
320            if let Some(ptr) = NonNull::new(self.base) {
321                self.alloc.deallocate(ptr.cast(), Self::LAYOUT);
322            }
323        }
324    }
325}
326
327const fn tree_layout<T>(height: usize, arity: usize) -> Layout {
328    let node_size = mem::size_of::<T>();
329    let node_align = mem::align_of::<T>();
330
331    let size = n_tree_nodes(height, arity) * node_size;
332    let align = node_align;
333
334    unsafe { Layout::from_size_align_unchecked(size, align) }
335}
336
/// Number of leaves in a tree with the given height and arity, i.e.
/// `arity.pow(height)`.
///
/// # Panics
/// Panics if `height` does not fit in a `u32`, or if the number of leaves does
/// not fit in a `usize`. When evaluated in a `const` context this surfaces as
/// a compile-time error.
const fn n_tree_leaves(height: usize, arity: usize) -> usize {
    // `checked_pow` takes a `u32` exponent; reject heights the cast would
    // truncate instead of silently computing a smaller tree.
    assert!(
        height <= u32::MAX as usize,
        "HamTree height does not fit in a u32"
    );
    #[allow(clippy::cast_possible_truncation)] // guarded by the assertion above
    let exponent = height as u32;

    match arity.checked_pow(exponent) {
        Some(n_leaves) => n_leaves,
        None => panic!("HamTree has more leaves than fit in a usize"),
    }
}
344
/// Total number of nodes in a tree with the given height and arity: the sum of
/// `arity.pow(h)` over every level `h` in `0..=height`.
///
/// # Panics
/// Panics if the count does not fit in a `usize`, instead of wrapping in
/// release builds.
const fn n_tree_nodes(height: usize, arity: usize) -> usize {
    let mut n_nodes: usize = 0;
    // Number of nodes on the level currently being counted, root first.
    let mut level_width: usize = 1;
    let mut h = 0;

    loop {
        n_nodes = match n_nodes.checked_add(level_width) {
            Some(n) => n,
            None => panic!("HamTree has more nodes than fit in a usize"),
        };

        if h == height {
            return n_nodes;
        }

        level_width = match level_width.checked_mul(arity) {
            Some(width) => width,
            None => panic!("HamTree has more nodes than fit in a usize"),
        };
        h += 1;
    }
}
357
/// A branch of a [`HamTree`].
///
/// Holds, for every level from the leaves up, the group of `A` sibling nodes
/// on the path to the root and the path node's offset within that group,
/// together with the root of the tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HamBranch<T, const H: usize, const A: usize> {
    /// Root of the tree the branch was taken from.
    root: T,
    /// Sibling groups on the path, from the leaf level up.
    levels: [[T; A]; H],
    /// Position of the path node within each entry of `levels`.
    offsets: [usize; H],
}
364
365impl<T, const H: usize, const A: usize> HamBranch<T, H, A> {
366    /// Root of the branch.
367    pub fn root(&self) -> &T {
368        &self.root
369    }
370
371    /// Returns the nodes of the branch, from the bottom up.
372    pub fn levels(&self) -> &[[T; A]; H] {
373        &self.levels
374    }
375
376    /// Returns the offsets of the branch, from the bottom up.
377    pub fn offsets(&self) -> &[usize; H] {
378        &self.offsets
379    }
380}
381
382impl<T, const H: usize, const A: usize> HamBranch<T, H, A>
383where
384    T: Aggregate,
385{
386    /// Returns whether the given item is the leaf of the branch, and that the
387    /// branch is correct.
388    pub fn verify(&self, node: T) -> bool {
389        let mut node = node;
390
391        for h in 0..H {
392            let level = &self.levels[h];
393            let offset = self.offsets[h];
394
395            if node != level[offset] {
396                return false;
397            }
398
399            node = T::aggregate(level);
400        }
401
402        node == self.root
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::collections::{BTreeMap, BTreeSet};

    use paste::paste;
    use rand::{rngs::StdRng, RngCore, SeedableRng};

    /// Test node whose aggregate is the sum of its children, so the root of a
    /// tree of `Count(1)` leaves equals the number of occupied leaves.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
    struct Count(usize);

    // SAFETY: `Count` wraps a `usize`, for which the all-zero pattern (0) is valid.
    unsafe impl Aggregate for Count {
        fn aggregate(nodes: &[Self]) -> Self {
            Self(nodes.iter().map(|node| node.0).sum())
        }
    }

    // Generates one module of test cases per arity, for the given height.
    macro_rules! tree_tests {
        (H=$height:literal; A = $($arity:literal),+) => {
            $(
            paste! {
                mod [<tree _ h $height _ a $arity>] {
                    use super::*;

                    type Tree = HamTree<Count, $height, $arity>;

                    const N_INSERTIONS: usize = 100;

                    /// Seed shared by every test so runs are reproducible.
                    const SEED: u64 = 0xBAAD_F00D;

                    /// Draws a pseudo-random value in `0..Tree::N_LEAVES`.
                    fn random_below_capacity(rng: &mut StdRng) -> usize {
                        (rng.next_u64() % Tree::N_LEAVES as u64) as usize
                    }

                    /// Inserts `Count(1)` at `N_INSERTIONS` random indices and
                    /// returns the tree with the set of distinct indices written.
                    fn random_tree() -> (Tree, BTreeSet<usize>) {
                        let mut rng = StdRng::seed_from_u64(SEED);
                        let mut tree = Tree::new();
                        let mut occupied = BTreeSet::new();

                        for _ in 0..N_INSERTIONS {
                            let index = random_below_capacity(&mut rng);
                            occupied.insert(index);
                            tree.insert(index, Count(1));
                        }

                        (tree, occupied)
                    }

                    #[test]
                    fn insert() {
                        let (tree, occupied) = random_tree();

                        let expected = Count(occupied.len());
                        assert!(matches!(tree.root(), Some(root) if *root == expected));
                    }

                    #[test]
                    fn remove() {
                        let (mut tree, occupied) = random_tree();

                        for index in occupied {
                            tree.remove(index);
                            assert!(tree.leaf(index).is_none());
                        }

                        assert!(tree.root().is_none());
                    }

                    #[test]
                    fn leaf() {
                        let (tree, occupied) = random_tree();

                        for index in occupied {
                            assert_eq!(tree.leaf(index), Some(&Count(1)));
                        }
                    }

                    #[test]
                    fn empty_leaf() {
                        let tree = Tree::new();
                        assert_eq!(tree.leaf(0), None);
                    }

                    #[test]
                    fn leaves() {
                        let (tree, occupied) = random_tree();

                        let mut n_leaves = 0;
                        for leaf in tree.leaves() {
                            assert_eq!(*leaf, Count(1));
                            n_leaves += 1;
                        }

                        assert_eq!(n_leaves, occupied.len());
                    }

                    #[test]
                    fn empty_leaves() {
                        let tree = Tree::new();
                        assert_eq!(tree.leaves().count(), 0);
                    }

                    #[test]
                    fn branch() {
                        let mut rng = StdRng::seed_from_u64(SEED);

                        let mut tree = Tree::new();
                        let mut expected = BTreeMap::new();

                        for _ in 0..N_INSERTIONS {
                            let index = random_below_capacity(&mut rng);
                            let value = random_below_capacity(&mut rng);
                            expected.insert(index, value);
                            tree.insert(index, Count(value));
                        }

                        for (index, value) in expected {
                            let branch = tree.branch(index);
                            // Inserting the zeroed (empty) node clears the leaf,
                            // so no branch exists for it.
                            if Count(value) == empty_node() {
                                assert!(branch.is_none());
                            } else {
                                assert!(matches!(branch, Some(b) if b.verify(Count(value))));
                            }
                        }
                    }

                    #[test]
                    fn empty_branch() {
                        let tree = Tree::new();
                        assert!(tree.branch(0).is_none());
                    }
                }
            }
            )+
        };
    }

    tree_tests!(H = 0; A = 2, 3, 4, 5);
    tree_tests!(H = 1; A = 2, 3, 4, 5);
    tree_tests!(H = 2; A = 2, 3, 4, 5);
    tree_tests!(H = 3; A = 2, 3, 4, 5);
    tree_tests!(H = 4; A = 2, 3, 4, 5);
    tree_tests!(H = 5; A = 2, 3, 4, 5);
    tree_tests!(H = 6; A = 2, 3, 4, 5);
    tree_tests!(H = 7; A = 2, 3, 4, 5);
    tree_tests!(H = 8; A = 2, 3, 4, 5);
}