merkletree_mintlayer/merkle/tree/
tree_size.rs

1// Copyright (c) 2021-2024 RBB S.r.l
2// opensource@mintlayer.org
3// SPDX-License-Identifier: MIT
4// Licensed under the MIT License;
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// https://github.com/mintlayer/merkletree-mintlayer/blob/master/LICENSE
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{
17    fmt::{Display, Formatter},
18    num::NonZeroU32,
19};
20
21use itertools::Itertools;
22
/// Total node count of a binary Merkle tree.
///
/// Values are only produced through the validating constructors, which accept exactly the
/// numbers of the form `2^k - 1` (never zero) up to the supported maximum.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct TreeSize(u32);
25
26const MAX_TREE_SIZE: u32 = 1 << 31;
27
28#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
29pub enum TreeSizeError {
30    #[error("Zero is invalid size for tree")]
31    ZeroSize,
32    #[error("Tree size must be power of two minus one; this value was found: {0}")]
33    InvalidSize(u32),
34    #[error("Tree with this huge size is not supported: {0}")]
35    HugeTreeUnsupported(u64),
36}
37
38impl TreeSize {
39    pub fn get(&self) -> u32 {
40        self.0
41    }
42
43    pub fn leaf_count(&self) -> NonZeroU32 {
44        ((self.0 + 1) / 2)
45            .try_into()
46            .expect("Guaranteed by construction")
47    }
48
49    pub fn level_count(&self) -> NonZeroU32 {
50        self.0
51            .count_ones()
52            .try_into()
53            .expect("Guaranteed by construction")
54    }
55
56    pub fn from_u32(value: u32) -> Result<Self, TreeSizeError> {
57        Self::try_from(value)
58    }
59
60    pub fn from_usize(value: usize) -> Result<Self, TreeSizeError> {
61        Self::try_from(value)
62    }
63
64    pub fn from_leaf_count(leaf_count: u32) -> Result<Self, TreeSizeError> {
65        if leaf_count == 0 {
66            return Err(TreeSizeError::ZeroSize);
67        }
68        Self::try_from(leaf_count * 2 - 1)
69    }
70
71    /// The absolute index, at which the first node at level `level_from_bottom` starts.
72    pub fn level_start(&self, level_from_bottom: u32) -> Option<u32> {
73        let level_count = self.level_count().get();
74        if level_from_bottom >= level_count {
75            return None;
76        }
77
78        // To help in seeing how these formulas were derived, see this table that represents values in the case tree.len() == 31 == 0b11111:
79        //  level     level_start  level_start in binary    index_in_level_size
80        //  0         0            00000                    16
81        //  1         16           10000                    8
82        //  2         24           11000                    4
83        //  3         28           11100                    2
84        //  4         30           11110                    1
85
86        let level_from_top = level_count - level_from_bottom;
87        // to get leading ones, we shift the tree size, right then left, by the level we need (see the table above)
88        let level_start = (self.0 >> level_from_top) << level_from_top;
89        Some(level_start)
90    }
91
92    /// Creates an iterator that returns the indices of the nodes of the tree, from left to right, as pairs.
93    /// Root isn't included in this iterator
94    pub fn iter_pairs_indices(&self) -> impl Iterator<Item = (u32, u32)> {
95        (0..self.get() - 1).tuple_windows::<(u32, u32)>().step_by(2)
96    }
97}
98
99impl TryFrom<u32> for TreeSize {
100    type Error = TreeSizeError;
101
102    fn try_from(value: u32) -> Result<Self, Self::Error> {
103        if value == 0 {
104            Err(TreeSizeError::ZeroSize)
105        } else if !(value + 1).is_power_of_two() {
106            Err(TreeSizeError::InvalidSize(value))
107        } else if value > MAX_TREE_SIZE {
108            Err(TreeSizeError::HugeTreeUnsupported(value as u64))
109        } else {
110            Ok(Self(value))
111        }
112    }
113}
114
115impl TryFrom<usize> for TreeSize {
116    type Error = TreeSizeError;
117
118    fn try_from(value: usize) -> Result<Self, Self::Error> {
119        if value > MAX_TREE_SIZE as usize {
120            return Err(TreeSizeError::HugeTreeUnsupported(value as u64));
121        }
122        let size: u32 = value
123            .try_into()
124            .expect("Must fit because of last MAX_TREE_SIZE check");
125        Self::try_from(size)
126    }
127}
128
129impl From<TreeSize> for u32 {
130    fn from(tree_size: TreeSize) -> Self {
131        tree_size.0
132    }
133}
134
135impl AsRef<u32> for TreeSize {
136    fn as_ref(&self) -> &u32 {
137        &self.0
138    }
139}
140
141impl Display for TreeSize {
142    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
143        write!(f, "{}", self.0)
144    }
145}
146
#[cfg(test)]
mod tests {
    use crate::rand_tools::{make_seedable_rng, Seed};
    use rand::Rng;
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[trace]
    #[case(Seed::from_entropy())]
    fn construction_from_tree_size(#[case] seed: Seed) {
        let mut rng = make_seedable_rng(seed);

        // select simple values
        assert_eq!(TreeSize::from_u32(0), Err(TreeSizeError::ZeroSize));
        assert_eq!(TreeSize::from_u32(1), Ok(TreeSize(1)));
        assert_eq!(TreeSize::from_u32(2), Err(TreeSizeError::InvalidSize(2)));
        assert_eq!(TreeSize::from_u32(3), Ok(TreeSize(3)));
        assert_eq!(TreeSize::from_u32(4), Err(TreeSizeError::InvalidSize(4)));
        assert_eq!(TreeSize::from_u32(5), Err(TreeSizeError::InvalidSize(5)));
        assert_eq!(TreeSize::from_u32(6), Err(TreeSizeError::InvalidSize(6)));
        assert_eq!(TreeSize::from_u32(7), Ok(TreeSize(7)));
        assert_eq!(TreeSize::from_u32(8), Err(TreeSizeError::InvalidSize(8)));
        assert_eq!(TreeSize::from_u32(9), Err(TreeSizeError::InvalidSize(9)));
        assert_eq!(TreeSize::from_u32(10), Err(TreeSizeError::InvalidSize(10)));
        assert_eq!(TreeSize::from_u32(11), Err(TreeSizeError::InvalidSize(11)));
        assert_eq!(TreeSize::from_u32(12), Err(TreeSizeError::InvalidSize(12)));
        assert_eq!(TreeSize::from_u32(13), Err(TreeSizeError::InvalidSize(13)));
        assert_eq!(TreeSize::from_u32(14), Err(TreeSizeError::InvalidSize(14)));
        assert_eq!(TreeSize::from_u32(15), Ok(TreeSize(15)));
        assert_eq!(TreeSize::from_u32(16), Err(TreeSizeError::InvalidSize(16)));

        // Exhaustive valid: every 2^i - 1, up to and including the largest supported size
        // (2^31 - 1 == MAX_TREE_SIZE - 1).
        for i in 1..=MAX_TREE_SIZE.ilog2() {
            assert_eq!(TreeSize::from_u32((1 << i) - 1), Ok(TreeSize((1 << i) - 1)));
        }

        // Random values below the limit: valid shapes must be accepted, all others rejected.
        let attempts_count: u32 = 1000;
        for _ in 0..attempts_count {
            let sz = rng.gen_range(1..MAX_TREE_SIZE);
            if (sz + 1).is_power_of_two() {
                assert_eq!(TreeSize::try_from(sz), Ok(TreeSize(sz)));
                assert_eq!(TreeSize::from_u32(sz), Ok(TreeSize(sz)));
            } else {
                assert_eq!(TreeSize::try_from(sz), Err(TreeSizeError::InvalidSize(sz)));
                assert_eq!(TreeSize::from_u32(sz), Err(TreeSizeError::InvalidSize(sz)));
            }
        }
    }

    #[test]
    fn construction_from_leaf_count() {
        assert_eq!(TreeSize::from_leaf_count(0), Err(TreeSizeError::ZeroSize));
        assert_eq!(TreeSize::from_leaf_count(1), Ok(TreeSize(1)));
        assert_eq!(TreeSize::from_leaf_count(2), Ok(TreeSize(3)));
        assert_eq!(
            TreeSize::from_leaf_count(3),
            Err(TreeSizeError::InvalidSize(5))
        );
        assert_eq!(TreeSize::from_leaf_count(4), Ok(TreeSize(7)));
        assert_eq!(
            TreeSize::from_leaf_count(5),
            Err(TreeSizeError::InvalidSize(9))
        );
        assert_eq!(
            TreeSize::from_leaf_count(6),
            Err(TreeSizeError::InvalidSize(11))
        );
        assert_eq!(
            TreeSize::from_leaf_count(7),
            Err(TreeSizeError::InvalidSize(13))
        );
        assert_eq!(TreeSize::from_leaf_count(8), Ok(TreeSize(15)));
        assert_eq!(
            TreeSize::from_leaf_count(9),
            Err(TreeSizeError::InvalidSize(17))
        );
        assert_eq!(
            TreeSize::from_leaf_count(10),
            Err(TreeSizeError::InvalidSize(19))
        );
        assert_eq!(
            TreeSize::from_leaf_count(11),
            Err(TreeSizeError::InvalidSize(21))
        );
        assert_eq!(
            TreeSize::from_leaf_count(12),
            Err(TreeSizeError::InvalidSize(23))
        );
        assert_eq!(
            TreeSize::from_leaf_count(13),
            Err(TreeSizeError::InvalidSize(25))
        );
        assert_eq!(
            TreeSize::from_leaf_count(14),
            Err(TreeSizeError::InvalidSize(27))
        );
        assert_eq!(
            TreeSize::from_leaf_count(15),
            Err(TreeSizeError::InvalidSize(29))
        );
        assert_eq!(TreeSize::from_leaf_count(16), Ok(TreeSize(31)));
        assert_eq!(
            TreeSize::from_leaf_count(17),
            Err(TreeSizeError::InvalidSize(33))
        );

        // The largest supported tree: 2^30 leaves -> 2^31 - 1 nodes.
        assert_eq!(
            TreeSize::from_leaf_count(1 << 30),
            Ok(TreeSize(MAX_TREE_SIZE - 1))
        );
    }

    #[test]
    fn calculations() {
        let t1 = TreeSize::from_u32(1).unwrap();
        assert_eq!(t1.get(), 1);
        assert_eq!(t1.leaf_count().get(), 1);
        assert_eq!(t1.level_count().get(), 1);
        assert_eq!(t1.level_start(0).unwrap(), 0);
        for i in 1..1000u32 {
            assert_eq!(t1.level_start(i), None);
        }

        let t3 = TreeSize::from_u32(3).unwrap();
        assert_eq!(t3.get(), 3);
        assert_eq!(t3.leaf_count().get(), 2);
        assert_eq!(t3.level_count().get(), 2);
        assert_eq!(t3.level_start(0).unwrap(), 0);
        assert_eq!(t3.level_start(1).unwrap(), 2);
        for i in 2..1000u32 {
            assert_eq!(t3.level_start(i), None);
        }

        let t7 = TreeSize::from_u32(7).unwrap();
        assert_eq!(t7.get(), 7);
        assert_eq!(t7.leaf_count().get(), 4);
        assert_eq!(t7.level_count().get(), 3);
        assert_eq!(t7.level_start(0).unwrap(), 0);
        assert_eq!(t7.level_start(1).unwrap(), 4);
        assert_eq!(t7.level_start(2).unwrap(), 6);
        for i in 3..1000u32 {
            assert_eq!(t7.level_start(i), None);
        }

        let t15 = TreeSize::from_u32(15).unwrap();
        assert_eq!(t15.get(), 15);
        assert_eq!(t15.leaf_count().get(), 8);
        assert_eq!(t15.level_count().get(), 4);
        assert_eq!(t15.level_start(0).unwrap(), 0);
        assert_eq!(t15.level_start(1).unwrap(), 8);
        assert_eq!(t15.level_start(2).unwrap(), 12);
        assert_eq!(t15.level_start(3).unwrap(), 14);
        for i in 4..1000u32 {
            assert_eq!(t15.level_start(i), None);
        }

        let t31 = TreeSize::from_u32(31).unwrap();
        assert_eq!(t31.get(), 31);
        assert_eq!(t31.leaf_count().get(), 16);
        assert_eq!(t31.level_count().get(), 5);
        assert_eq!(t31.level_start(0).unwrap(), 0);
        assert_eq!(t31.level_start(1).unwrap(), 16);
        assert_eq!(t31.level_start(2).unwrap(), 24);
        assert_eq!(t31.level_start(3).unwrap(), 28);
        assert_eq!(t31.level_start(4).unwrap(), 30);
        for i in 5..1000u32 {
            assert_eq!(t31.level_start(i), None);
        }

        let t63 = TreeSize::from_u32(63).unwrap();
        assert_eq!(t63.get(), 63);
        assert_eq!(t63.leaf_count().get(), 32);
        assert_eq!(t63.level_count().get(), 6);
        assert_eq!(t63.level_start(0).unwrap(), 0);
        assert_eq!(t63.level_start(1).unwrap(), 32);
        assert_eq!(t63.level_start(2).unwrap(), 48);
        assert_eq!(t63.level_start(3).unwrap(), 56);
        assert_eq!(t63.level_start(4).unwrap(), 60);
        assert_eq!(t63.level_start(5).unwrap(), 62);
        for i in 6..1000u32 {
            assert_eq!(t63.level_start(i), None);
        }

        let t127 = TreeSize::from_u32(127).unwrap();
        assert_eq!(t127.get(), 127);
        assert_eq!(t127.leaf_count().get(), 64);
        assert_eq!(t127.level_count().get(), 7);
        assert_eq!(t127.level_start(0).unwrap(), 0);
        assert_eq!(t127.level_start(1).unwrap(), 64);
        assert_eq!(t127.level_start(2).unwrap(), 96);
        assert_eq!(t127.level_start(3).unwrap(), 112);
        assert_eq!(t127.level_start(4).unwrap(), 120);
        assert_eq!(t127.level_start(5).unwrap(), 124);
        assert_eq!(t127.level_start(6).unwrap(), 126);
        for i in 7..1000u32 {
            assert_eq!(t127.level_start(i), None);
        }

        let t255 = TreeSize::from_u32(255).unwrap();
        assert_eq!(t255.get(), 255);
        assert_eq!(t255.leaf_count().get(), 128);
        assert_eq!(t255.level_count().get(), 8);
        assert_eq!(t255.level_start(0).unwrap(), 0);
        assert_eq!(t255.level_start(1).unwrap(), 128);
        assert_eq!(t255.level_start(2).unwrap(), 192);
        assert_eq!(t255.level_start(3).unwrap(), 224);
        assert_eq!(t255.level_start(4).unwrap(), 240);
        assert_eq!(t255.level_start(5).unwrap(), 248);
        assert_eq!(t255.level_start(6).unwrap(), 252);
        assert_eq!(t255.level_start(7).unwrap(), 254);
        for i in 8..1000u32 {
            assert_eq!(t255.level_start(i), None);
        }
    }

    #[test]
    fn huge_tree_sizes() {
        assert!(TreeSize::try_from(MAX_TREE_SIZE - 1).is_ok());

        // Just past the limit via the usize conversion.
        let just_too_big = MAX_TREE_SIZE as usize + 1;
        assert_eq!(
            TreeSize::try_from(just_too_big).unwrap_err(),
            TreeSizeError::HugeTreeUnsupported(just_too_big as u64)
        );

        // Computed in u64 so that MAX_TREE_SIZE * 2 cannot overflow; the result (2^32 - 1)
        // still fits in a usize even on 32-bit targets.
        let huge_tree_size = (((MAX_TREE_SIZE as u64) << 1) - 1) as usize;
        assert_eq!(
            TreeSize::try_from(huge_tree_size).unwrap_err(),
            TreeSizeError::HugeTreeUnsupported(huge_tree_size as u64)
        );
    }

    #[test]
    fn iter_non_root_indices() {
        let t1 = TreeSize::from_u32(1).unwrap();
        assert_eq!(t1.iter_pairs_indices().count(), 0);

        let t3 = TreeSize::from_u32(3).unwrap();
        assert_eq!(t3.iter_pairs_indices().collect::<Vec<_>>(), vec![(0, 1)]);

        let t7 = TreeSize::from_u32(7).unwrap();
        assert_eq!(
            t7.iter_pairs_indices().collect::<Vec<_>>(),
            vec![(0, 1), (2, 3), (4, 5)]
        );

        let t15 = TreeSize::from_u32(15).unwrap();
        assert_eq!(
            t15.iter_pairs_indices().collect::<Vec<_>>(),
            vec![(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13)]
        );

        let t31 = TreeSize::from_u32(31).unwrap();
        assert_eq!(
            t31.iter_pairs_indices().collect::<Vec<_>>(),
            vec![
                (0, 1),
                (2, 3),
                (4, 5),
                (6, 7),
                (8, 9),
                (10, 11),
                (12, 13),
                (14, 15),
                (16, 17),
                (18, 19),
                (20, 21),
                (22, 23),
                (24, 25),
                (26, 27),
                (28, 29),
            ]
        );

        // Exhaustive... without this taking way too long
        for i in 1..10_u32 {
            let tree_size = TreeSize::from_u32((1 << i) - 1).unwrap();
            assert_eq!(
                tree_size.iter_pairs_indices().count() as u32,
                tree_size.get() / 2
            );
            assert_eq!(
                tree_size.iter_pairs_indices().collect::<Vec<_>>(),
                (0..tree_size.get() / 2)
                    .map(|i| (i * 2, i * 2 + 1))
                    .collect::<Vec<_>>()
            );
        }
    }
}