// commonware_storage/mmr/position.rs

1use super::location::Location;
2use core::{
3    fmt,
4    ops::{Add, AddAssign, Deref, Sub, SubAssign},
5};
6
/// Largest [Position] that can occur in a valid MMR.
///
/// An MMR holding the maximum number of leaves (2^62) contains 2^63 - 1 nodes, so its final node
/// sits at index 2^63 - 2.
pub const MAX_POSITION: Position = Position::new((1u64 << 63) - 2);

/// Index of a node within an MMR.
///
/// Unlike a [Location], which counts only an MMR's _leaves_, a [Position] counts every node.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)]
pub struct Position(u64);

#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Position {
    /// Draws a position in the range `0..=MAX_POSITION`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_POSITION.0).map(Self)
    }
}

impl Position {
    /// Wrap a raw `u64` node index as a [Position] without any validation.
    #[inline]
    pub const fn new(pos: u64) -> Self {
        Self(pos)
    }

    /// Unwrap this [Position] into its raw `u64` node index.
    #[inline]
    pub const fn as_u64(self) -> u64 {
        self.0
    }

    /// Compute `self + rhs`, yielding `None` if the sum overflows `u64` or lands above
    /// [MAX_POSITION].
    #[inline]
    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
        if let Some(sum) = self.0.checked_add(rhs) {
            if sum <= MAX_POSITION.0 {
                return Some(Self(sum));
            }
        }
        None
    }

    /// Compute `self - rhs`, yielding `None` if the difference would be negative.
    #[inline]
    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
        if let Some(diff) = self.0.checked_sub(rhs) {
            Some(Self(diff))
        } else {
            None
        }
    }

    /// Compute `self + rhs`, clamping the result to [MAX_POSITION].
    #[inline]
    pub const fn saturating_add(self, rhs: u64) -> Self {
        let sum = self.0.saturating_add(rhs);
        Self(if sum <= MAX_POSITION.0 { sum } else { MAX_POSITION.0 })
    }

    /// Compute `self - rhs`, clamping the result at zero.
    #[inline]
    pub const fn saturating_sub(self, rhs: u64) -> Self {
        let diff = self.0.saturating_sub(rhs);
        Self(diff)
    }

    /// Reports whether `self`, read as a node count, is the size of a valid MMR.
    ///
    /// Sizes with the top bit set are rejected outright because they would overflow. Otherwise
    /// the peaks of an MMR with this many nodes are walked from the tallest down. The size is
    /// accepted only if each peak is strictly shorter than the one before it, which is a
    /// necessary condition for MMR validity.
    #[inline]
    pub const fn is_mmr_size(self) -> bool {
        let size = self.0;
        if size == 0 {
            return true;
        }
        let zeros = size.leading_zeros();
        if zeros == 0 {
            // size overflow
            return false;
        }
        // `span` is `2^k - 1`, where `k` is the bit length of `size`. This is the node count of
        // a perfect binary tree at least as large as `size`, whose peak is at `span - 1`.
        let span = u64::MAX >> zeros;
        // `width` is `2^(h+1)` for the current candidate node of height `h` (2 means a leaf).
        let mut width: u64 = 1 << span.trailing_ones();
        // Cannot underflow: `size != 0` implies `span >= 1`.
        let mut pos = span - 1;
        while width > 1 {
            if pos >= size {
                // The candidate lies past the end of the MMR: descend to its left child.
                width >>= 1;
                pos -= width;
            } else if width == 2 {
                // A leaf peak must be the very last node; anything after it is invalid.
                return pos == size - 1;
            } else {
                // The candidate is a peak. Jump to where its right sibling would be; if that
                // node also fits, two peaks share a height and the size is invalid.
                pos += width - 1;
                if pos < size {
                    return false;
                }
            }
        }
        true
    }
}
119
120impl fmt::Display for Position {
121    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
122        write!(f, "Position({})", self.0)
123    }
124}
125
126impl Deref for Position {
127    type Target = u64;
128    fn deref(&self) -> &Self::Target {
129        &self.0
130    }
131}
132
133impl AsRef<u64> for Position {
134    fn as_ref(&self) -> &u64 {
135        &self.0
136    }
137}
138
139impl From<u64> for Position {
140    #[inline]
141    fn from(value: u64) -> Self {
142        Self::new(value)
143    }
144}
145
146impl From<usize> for Position {
147    #[inline]
148    fn from(value: usize) -> Self {
149        Self::new(value as u64)
150    }
151}
152
153impl From<Position> for u64 {
154    #[inline]
155    fn from(position: Position) -> Self {
156        *position
157    }
158}
159
160/// Try to convert a [Location] to a [Position].
161///
162/// Returns an error if `loc` > [super::MAX_LOCATION].
163///
164/// # Examples
165///
166/// ```
167/// use commonware_storage::mmr::{Location, Position, MAX_LOCATION};
168/// use core::convert::TryFrom;
169///
170/// let loc = Location::new(5).unwrap();
171/// let pos = Position::try_from(loc).unwrap();
172/// assert_eq!(pos, Position::new(8));
173///
174/// // Invalid locations return error
175/// assert!(Location::new(MAX_LOCATION + 1).is_none());
176/// ```
177impl TryFrom<Location> for Position {
178    type Error = super::Error;
179
180    #[inline]
181    fn try_from(loc: Location) -> Result<Self, Self::Error> {
182        if !loc.is_valid() {
183            return Err(super::Error::LocationOverflow(loc));
184        }
185        // This will never underflow since 2*n >= count_ones(n).
186        let loc_val = *loc;
187        Ok(Self(
188            loc_val
189                .checked_mul(2)
190                .expect("should not overflow for valid location")
191                - loc_val.count_ones() as u64,
192        ))
193    }
194}
195
196/// Add two positions together.
197///
198/// # Panics
199///
200/// Panics if the result overflows.
201impl Add for Position {
202    type Output = Self;
203
204    #[inline]
205    fn add(self, rhs: Self) -> Self::Output {
206        Self(self.0 + rhs.0)
207    }
208}
209
210/// Add a position and a `u64`.
211///
212/// # Panics
213///
214/// Panics if the result overflows.
215impl Add<u64> for Position {
216    type Output = Self;
217
218    #[inline]
219    fn add(self, rhs: u64) -> Self::Output {
220        Self(self.0 + rhs)
221    }
222}
223
224/// Subtract two positions.
225///
226/// # Panics
227///
228/// Panics if the result underflows.
229impl Sub for Position {
230    type Output = Self;
231
232    #[inline]
233    fn sub(self, rhs: Self) -> Self::Output {
234        Self(self.0 - rhs.0)
235    }
236}
237
238/// Subtract a `u64` from a position.
239///
240/// # Panics
241///
242/// Panics if the result underflows.
243impl Sub<u64> for Position {
244    type Output = Self;
245
246    #[inline]
247    fn sub(self, rhs: u64) -> Self::Output {
248        Self(self.0 - rhs)
249    }
250}
251
252impl PartialEq<u64> for Position {
253    #[inline]
254    fn eq(&self, other: &u64) -> bool {
255        self.0 == *other
256    }
257}
258
259impl PartialOrd<u64> for Position {
260    #[inline]
261    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
262        self.0.partial_cmp(other)
263    }
264}
265
266// Allow u64 to be compared with Position too
267impl PartialEq<Position> for u64 {
268    #[inline]
269    fn eq(&self, other: &Position) -> bool {
270        *self == other.0
271    }
272}
273
274impl PartialOrd<Position> for u64 {
275    #[inline]
276    fn partial_cmp(&self, other: &Position) -> Option<core::cmp::Ordering> {
277        self.partial_cmp(&other.0)
278    }
279}
280
281/// Add a `u64` to a position.
282///
283/// # Panics
284///
285/// Panics if the result overflows.
286impl AddAssign<u64> for Position {
287    #[inline]
288    fn add_assign(&mut self, rhs: u64) {
289        self.0 += rhs;
290    }
291}
292
293/// Subtract a `u64` from a position.
294///
295/// # Panics
296///
297/// Panics if the result underflows.
298impl SubAssign<u64> for Position {
299    #[inline]
300    fn sub_assign(&mut self, rhs: u64) {
301        self.0 -= rhs;
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::{Location, Position};
    use crate::mmr::{mem::CleanMmr, StandardHasher as Standard, MAX_LOCATION, MAX_POSITION};
    use commonware_cryptography::Sha256;

    // Test that `Position::try_from` returns the correct position for leaf locations.
    #[test]
    fn test_from_location() {
        const CASES: &[(Location, Position)] = &[
            (Location::new_unchecked(0), Position::new(0)),
            (Location::new_unchecked(1), Position::new(1)),
            (Location::new_unchecked(2), Position::new(3)),
            (Location::new_unchecked(3), Position::new(4)),
            (Location::new_unchecked(4), Position::new(7)),
            (Location::new_unchecked(5), Position::new(8)),
            (Location::new_unchecked(6), Position::new(10)),
            (Location::new_unchecked(7), Position::new(11)),
            (Location::new_unchecked(8), Position::new(15)),
            (Location::new_unchecked(9), Position::new(16)),
            (Location::new_unchecked(10), Position::new(18)),
            (Location::new_unchecked(11), Position::new(19)),
            (Location::new_unchecked(12), Position::new(22)),
            (Location::new_unchecked(13), Position::new(23)),
            (Location::new_unchecked(14), Position::new(25)),
            (Location::new_unchecked(15), Position::new(26)),
        ];
        for (loc, expected_pos) in CASES {
            let pos = Position::try_from(*loc).unwrap();
            assert_eq!(pos, *expected_pos);
        }
    }

    #[test]
    fn test_checked_add() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_add(5).unwrap(), 15);

        // Overflow returns None
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_POSITION returns None
        assert!(MAX_POSITION.checked_add(1).is_none());
        assert!(Position::new(*MAX_POSITION - 5).checked_add(10).is_none());

        // At MAX_POSITION is OK
        assert_eq!(
            Position::new(*MAX_POSITION - 10).checked_add(10).unwrap(),
            MAX_POSITION
        );
    }

    #[test]
    fn test_checked_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_sub(5).unwrap(), 5);
        assert!(pos.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_add(5), 15);

        // Saturates at MAX_POSITION, not u64::MAX
        assert_eq!(Position::new(u64::MAX).saturating_add(1), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1000), MAX_POSITION);
        assert_eq!(
            Position::new(*MAX_POSITION - 5).saturating_add(10),
            MAX_POSITION
        );
    }

    #[test]
    fn test_saturating_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_sub(5), 5);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        let position = Position::new(42);
        assert_eq!(position.to_string(), "Position(42)");
    }

    #[test]
    fn test_add() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(5);
        assert_eq!(pos1 + pos2, 15);
    }

    #[test]
    fn test_sub() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(3);
        assert_eq!(pos1 - pos2, 7);
    }

    // Test the conversions between Position and raw integers, and the raw accessors.
    #[test]
    fn test_conversions() {
        let pos = Position::from(7u64);
        assert_eq!(pos, Position::new(7));
        assert_eq!(Position::from(7usize), pos);
        assert_eq!(u64::from(pos), 7);
        assert_eq!(pos.as_u64(), 7);
        assert_eq!(*pos, 7);
        let raw: &u64 = pos.as_ref();
        assert_eq!(*raw, 7);
    }

    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);

        // Test equality
        assert_eq!(pos, 42u64);
        assert_eq!(42u64, pos);
        assert_ne!(pos, 43u64);
        assert_ne!(43u64, pos);

        // Test ordering
        assert!(pos < 43u64);
        assert!(43u64 > pos);
        assert!(pos > 41u64);
        assert!(41u64 < pos);
        assert!(pos <= 42u64);
        assert!(42u64 >= pos);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        // Test add assignment
        pos += 5;
        assert_eq!(pos, 15u64);

        // Test sub assignment
        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    #[test]
    fn test_max_position() {
        // The constraint is: MMR_size must have top bit clear (< 2^63)
        // For N leaves: MMR_size = 2*N - popcount(N)
        // Worst case (maximum size) is when N is a power of 2: MMR_size = 2*N - 1

        // Maximum N where 2*N - 1 < 2^63:
        //   2*N - 1 < 2^63
        //   2*N < 2^63 + 1
        //   N <= 2^62
        let max_leaves = 1u64 << 62;

        // For N = 2^62 leaves:
        // MMR_size = 2 * 2^62 - 1 = 2^63 - 1
        let mmr_size_at_max = 2 * max_leaves - 1;
        assert_eq!(mmr_size_at_max, (1u64 << 63) - 1);
        assert_eq!(mmr_size_at_max.leading_zeros(), 1); // Top bit clear ✓

        // Last position (0-indexed) = MMR_size - 1 = 2^63 - 2
        let expected_max_pos = mmr_size_at_max - 1;
        assert_eq!(MAX_POSITION, expected_max_pos);
        assert_eq!(MAX_POSITION, (1u64 << 63) - 2);

        // Verify the constraint: a position at MAX_POSITION + 1 would require
        // an MMR_size >= 2^63, which violates the "top bit clear" requirement
        let hypothetical_mmr_size = MAX_POSITION + 2; // Would need this many nodes
        assert_eq!(hypothetical_mmr_size, 1u64 << 63);
        assert_eq!(hypothetical_mmr_size.leading_zeros(), 0); // Top bit NOT clear ✗

        // Verify relationship with MAX_LOCATION
        // Converting MAX_LOCATION to position should give a value < MAX_POSITION
        let max_loc = Location::new_unchecked(MAX_LOCATION);
        let last_leaf_pos = Position::try_from(max_loc).unwrap();
        assert!(*last_leaf_pos < MAX_POSITION);
    }

    #[test]
    fn test_is_mmr_size() {
        // Build an MMR one node at a time and check that the validity check is correct for all
        // sizes up to the current size.
        let mut size_to_check = Position::new(0);
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = CleanMmr::new(&mut hasher);
        let digest = [1u8; 32];
        for _i in 0..10000 {
            while size_to_check != mmr.size() {
                assert!(
                    !size_to_check.is_mmr_size(),
                    "size_to_check: {} {}",
                    size_to_check,
                    mmr.size()
                );
                size_to_check += 1;
            }
            assert!(size_to_check.is_mmr_size());
            mmr.add(&mut hasher, &digest);
            size_to_check += 1;
        }

        // Test overflow boundaries.
        assert!(!Position::new(u64::MAX).is_mmr_size());
        assert!(Position::new(u64::MAX >> 1).is_mmr_size());
        assert!(!Position::new((u64::MAX >> 1) + 1).is_mmr_size());
        assert!(!MAX_POSITION.is_mmr_size());
    }
}