// commonware_storage/mmr/position.rs

1use super::location::Location;
2use core::{
3    fmt,
4    ops::{Add, AddAssign, Deref, Sub, SubAssign},
5};
6
/// Largest [Position] that can occur in a valid MMR.
///
/// An MMR holding the maximum number of leaves (`2^62`) contains `2^63 - 1` nodes, so its final
/// node sits at index `2^63 - 2`.
pub const MAX_POSITION: Position = Position::new((1u64 << 63) - 2);

/// Index of a node within an MMR.
///
/// Contrast with [Location], which indexes only the MMR's _leaves_.
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)]
pub struct Position(u64);

impl Position {
    /// Wraps a raw `u64` node index as a [Position]. No validation is performed.
    #[inline]
    pub const fn new(pos: u64) -> Self {
        Position(pos)
    }

    /// Unwraps this [Position] into its raw `u64` node index.
    #[inline]
    pub const fn as_u64(self) -> u64 {
        let Position(raw) = self;
        raw
    }

    /// Computes `self + rhs`.
    ///
    /// Yields `None` when the `u64` addition overflows or when the sum lies beyond [MAX_POSITION].
    #[inline]
    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
        match self.0.checked_add(rhs) {
            Some(sum) if sum <= MAX_POSITION.0 => Some(Position(sum)),
            _ => None,
        }
    }

    /// Computes `self - rhs`, yielding `None` when `rhs` exceeds `self`.
    #[inline]
    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
        if rhs > self.0 {
            None
        } else {
            Some(Position(self.0 - rhs))
        }
    }

    /// Computes `self + rhs`, clamping the result to [MAX_POSITION].
    ///
    /// A position that already lies beyond [MAX_POSITION] is also clamped, even when `rhs == 0`.
    #[inline]
    pub const fn saturating_add(self, rhs: u64) -> Self {
        match self.checked_add(rhs) {
            Some(sum) => sum,
            None => MAX_POSITION,
        }
    }

    /// Computes `self - rhs`, clamping the result to zero.
    #[inline]
    pub const fn saturating_sub(self, rhs: u64) -> Self {
        if rhs >= self.0 {
            Position(0)
        } else {
            Position(self.0 - rhs)
        }
    }

    /// Reports whether `self`, read as a node count, is the size of some MMR.
    ///
    /// An MMR is a run of perfect binary trees (its peaks) of strictly decreasing height. A perfect
    /// tree of height `h` holds `2^(h+1) - 1` nodes. A count is therefore accepted exactly when it
    /// is a sum of *distinct* values of the form `2^k - 1` (`k >= 1`). Counts of `2^63` or more
    /// are rejected outright as a size overflow.
    #[inline]
    pub const fn is_mmr_size(self) -> bool {
        let size = self.0;
        if size >> 63 != 0 {
            // Top bit set: size overflow.
            return false;
        }

        // Greedily peel off the largest perfect tree that still fits. All perfect trees smaller
        // than `2^k - 1` nodes together hold only `2^k - k - 1` nodes. So whenever a tree of
        // `2^k - 1` nodes fits, it must be used, and each height can be used at most once.
        let mut remaining = size;
        let mut k = u64::BITS - size.leading_zeros();
        while k > 0 && remaining > 0 {
            let tree = (1u64 << k) - 1;
            if remaining >= tree {
                remaining -= tree;
            }
            k -= 1;
        }
        remaining == 0
    }
}
111
112impl fmt::Display for Position {
113    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114        write!(f, "Position({})", self.0)
115    }
116}
117
118impl Deref for Position {
119    type Target = u64;
120    fn deref(&self) -> &Self::Target {
121        &self.0
122    }
123}
124
125impl AsRef<u64> for Position {
126    fn as_ref(&self) -> &u64 {
127        &self.0
128    }
129}
130
131impl From<u64> for Position {
132    #[inline]
133    fn from(value: u64) -> Self {
134        Self::new(value)
135    }
136}
137
138impl From<usize> for Position {
139    #[inline]
140    fn from(value: usize) -> Self {
141        Self::new(value as u64)
142    }
143}
144
145impl From<Position> for u64 {
146    #[inline]
147    fn from(position: Position) -> Self {
148        *position
149    }
150}
151
152/// Try to convert a [Location] to a [Position].
153///
154/// Returns an error if `loc` > [super::MAX_LOCATION].
155///
156/// # Examples
157///
158/// ```
159/// use commonware_storage::mmr::{Location, Position, MAX_LOCATION};
160/// use core::convert::TryFrom;
161///
162/// let loc = Location::new(5).unwrap();
163/// let pos = Position::try_from(loc).unwrap();
164/// assert_eq!(pos, Position::new(8));
165///
166/// // Invalid locations return error  
167/// assert!(Location::new(MAX_LOCATION + 1).is_none());
168/// ```
169impl TryFrom<Location> for Position {
170    type Error = super::Error;
171
172    #[inline]
173    fn try_from(loc: Location) -> Result<Self, Self::Error> {
174        if !loc.is_valid() {
175            return Err(super::Error::LocationOverflow(loc));
176        }
177        // This will never underflow since 2*n >= count_ones(n).
178        let loc_val = *loc;
179        Ok(Self(
180            loc_val
181                .checked_mul(2)
182                .expect("should not overflow for valid location")
183                - loc_val.count_ones() as u64,
184        ))
185    }
186}
187
188/// Add two positions together.
189///
190/// # Panics
191///
192/// Panics if the result overflows.
193impl Add for Position {
194    type Output = Self;
195
196    #[inline]
197    fn add(self, rhs: Self) -> Self::Output {
198        Self(self.0 + rhs.0)
199    }
200}
201
202/// Add a position and a `u64`.
203///
204/// # Panics
205///
206/// Panics if the result overflows.
207impl Add<u64> for Position {
208    type Output = Self;
209
210    #[inline]
211    fn add(self, rhs: u64) -> Self::Output {
212        Self(self.0 + rhs)
213    }
214}
215
216/// Subtract two positions.
217///
218/// # Panics
219///
220/// Panics if the result underflows.
221impl Sub for Position {
222    type Output = Self;
223
224    #[inline]
225    fn sub(self, rhs: Self) -> Self::Output {
226        Self(self.0 - rhs.0)
227    }
228}
229
230/// Subtract a `u64` from a position.
231///
232/// # Panics
233///
234/// Panics if the result underflows.
235impl Sub<u64> for Position {
236    type Output = Self;
237
238    #[inline]
239    fn sub(self, rhs: u64) -> Self::Output {
240        Self(self.0 - rhs)
241    }
242}
243
244impl PartialEq<u64> for Position {
245    #[inline]
246    fn eq(&self, other: &u64) -> bool {
247        self.0 == *other
248    }
249}
250
251impl PartialOrd<u64> for Position {
252    #[inline]
253    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
254        self.0.partial_cmp(other)
255    }
256}
257
258// Allow u64 to be compared with Position too
259impl PartialEq<Position> for u64 {
260    #[inline]
261    fn eq(&self, other: &Position) -> bool {
262        *self == other.0
263    }
264}
265
266impl PartialOrd<Position> for u64 {
267    #[inline]
268    fn partial_cmp(&self, other: &Position) -> Option<core::cmp::Ordering> {
269        self.partial_cmp(&other.0)
270    }
271}
272
273/// Add a `u64` to a position.
274///
275/// # Panics
276///
277/// Panics if the result overflows.
278impl AddAssign<u64> for Position {
279    #[inline]
280    fn add_assign(&mut self, rhs: u64) {
281        self.0 += rhs;
282    }
283}
284
285/// Subtract a `u64` from a position.
286///
287/// # Panics
288///
289/// Panics if the result underflows.
290impl SubAssign<u64> for Position {
291    #[inline]
292    fn sub_assign(&mut self, rhs: u64) {
293        self.0 -= rhs;
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::{Location, Position};
    use crate::mmr::{mem::Mmr, StandardHasher as Standard, MAX_LOCATION, MAX_POSITION};
    use commonware_cryptography::Sha256;

    // Test that [Position::try_from] returns the correct position for leaf locations.
    #[test]
    fn test_from_location() {
        const CASES: &[(Location, Position)] = &[
            (Location::new_unchecked(0), Position::new(0)),
            (Location::new_unchecked(1), Position::new(1)),
            (Location::new_unchecked(2), Position::new(3)),
            (Location::new_unchecked(3), Position::new(4)),
            (Location::new_unchecked(4), Position::new(7)),
            (Location::new_unchecked(5), Position::new(8)),
            (Location::new_unchecked(6), Position::new(10)),
            (Location::new_unchecked(7), Position::new(11)),
            (Location::new_unchecked(8), Position::new(15)),
            (Location::new_unchecked(9), Position::new(16)),
            (Location::new_unchecked(10), Position::new(18)),
            (Location::new_unchecked(11), Position::new(19)),
            (Location::new_unchecked(12), Position::new(22)),
            (Location::new_unchecked(13), Position::new(23)),
            (Location::new_unchecked(14), Position::new(25)),
            (Location::new_unchecked(15), Position::new(26)),
        ];
        for (loc, expected_pos) in CASES {
            let pos = Position::try_from(*loc).unwrap();
            assert_eq!(pos, *expected_pos);
        }
    }

    #[test]
    fn test_checked_add() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_add(5).unwrap(), 15);

        // Overflow returns None
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_POSITION returns None
        assert!(MAX_POSITION.checked_add(1).is_none());
        assert!(Position::new(*MAX_POSITION - 5).checked_add(10).is_none());

        // A position already beyond MAX_POSITION is rejected even when adding zero
        assert!(Position::new(u64::MAX).checked_add(0).is_none());

        // At MAX_POSITION is OK
        assert_eq!(
            Position::new(*MAX_POSITION - 10).checked_add(10).unwrap(),
            MAX_POSITION
        );
    }

    #[test]
    fn test_checked_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_sub(5).unwrap(), 5);
        assert_eq!(pos.checked_sub(10).unwrap(), 0);
        assert!(pos.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_add(5), 15);

        // Saturates at MAX_POSITION, not u64::MAX
        assert_eq!(Position::new(u64::MAX).saturating_add(1), MAX_POSITION);
        assert_eq!(Position::new(u64::MAX).saturating_add(0), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1000), MAX_POSITION);
        assert_eq!(
            Position::new(*MAX_POSITION - 5).saturating_add(10),
            MAX_POSITION
        );
    }

    #[test]
    fn test_saturating_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_sub(5), 5);
        assert_eq!(pos.saturating_sub(10), 0);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        let position = Position::new(42);
        assert_eq!(position.to_string(), "Position(42)");
    }

    #[test]
    fn test_add() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(5);
        assert_eq!((pos1 + pos2), 15);
        assert_eq!((pos1 + 7u64), 17);
    }

    #[test]
    fn test_sub() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(3);
        assert_eq!((pos1 - pos2), 7);
        assert_eq!((pos1 - 4u64), 6);
    }

    #[test]
    fn test_conversions() {
        let pos = Position::from(42u64);
        assert_eq!(pos, Position::new(42));
        assert_eq!(Position::from(42usize), pos);
        assert_eq!(u64::from(pos), 42);
        assert_eq!(pos.as_u64(), 42);
        let raw: &u64 = pos.as_ref();
        assert_eq!(*raw, 42);
        assert_eq!(*pos, 42);
    }

    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);

        // Test equality
        assert_eq!(pos, 42u64);
        assert_eq!(42u64, pos);
        assert_ne!(pos, 43u64);
        assert_ne!(43u64, pos);

        // Test ordering
        assert!(pos < 43u64);
        assert!(43u64 > pos);
        assert!(pos > 41u64);
        assert!(41u64 < pos);
        assert!(pos <= 42u64);
        assert!(42u64 >= pos);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        // Test add assignment
        pos += 5;
        assert_eq!(pos, 15u64);

        // Test sub assignment
        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    #[test]
    fn test_max_position() {
        // The constraint is: MMR_size must have top bit clear (< 2^63)
        // For N leaves: MMR_size = 2*N - popcount(N)
        // Worst case (maximum size) is when N is a power of 2: MMR_size = 2*N - 1

        // Maximum N where 2*N - 1 < 2^63:
        //   2*N - 1 < 2^63
        //   2*N < 2^63 + 1
        //   N <= 2^62
        let max_leaves = 1u64 << 62;

        // For N = 2^62 leaves:
        // MMR_size = 2 * 2^62 - 1 = 2^63 - 1
        let mmr_size_at_max = 2 * max_leaves - 1;
        assert_eq!(mmr_size_at_max, (1u64 << 63) - 1);
        assert_eq!(mmr_size_at_max.leading_zeros(), 1); // Top bit clear

        // Last position (0-indexed) = MMR_size - 1 = 2^63 - 2
        let expected_max_pos = mmr_size_at_max - 1;
        assert_eq!(MAX_POSITION, expected_max_pos);
        assert_eq!(MAX_POSITION, (1u64 << 63) - 2);

        // Verify the constraint: a position at MAX_POSITION + 1 would require
        // an MMR_size >= 2^63, which violates the "top bit clear" requirement
        let hypothetical_mmr_size = MAX_POSITION + 2; // Would need this many nodes
        assert_eq!(hypothetical_mmr_size, 1u64 << 63);
        assert_eq!(hypothetical_mmr_size.leading_zeros(), 0); // Top bit NOT clear

        // Verify relationship with MAX_LOCATION
        // Converting MAX_LOCATION to position should give a value < MAX_POSITION
        let max_loc = Location::new_unchecked(MAX_LOCATION);
        let last_leaf_pos = Position::try_from(max_loc).unwrap();
        assert!(*last_leaf_pos < MAX_POSITION);
    }

    #[test]
    fn test_is_mmr_size() {
        // Build an MMR one node at a time and check that the validity check is correct for all
        // sizes up to the current size.
        let mut mmr = Mmr::new();
        let mut size_to_check = Position::new(0);
        let mut hasher = Standard::<Sha256>::new();
        let digest = [1u8; 32];
        for _ in 0..10000 {
            while size_to_check != mmr.size() {
                assert!(
                    !size_to_check.is_mmr_size(),
                    "size_to_check: {} {}",
                    size_to_check,
                    mmr.size()
                );
                size_to_check += 1;
            }
            assert!(size_to_check.is_mmr_size());
            mmr.add(&mut hasher, &digest);
            size_to_check += 1;
        }

        // Sizes derived from leaf counts near the maximum (2*N - popcount(N)) are valid.
        for leaves in [
            (1u64 << 62) - 1,
            (1u64 << 62) - 2,
            1u64 << 61,
            (1u64 << 61) + 12345,
        ] {
            let size = 2 * leaves - u64::from(leaves.count_ones());
            assert!(Position::new(size).is_mmr_size(), "leaves: {}", leaves);
        }
        // Strictly between the sizes for 2^62 - 1 leaves (2^63 - 64) and 2^62 leaves (2^63 - 1).
        assert!(!Position::new((1u64 << 63) - 63).is_mmr_size());

        // Test overflow boundaries.
        assert!(!Position::new(u64::MAX).is_mmr_size());
        assert!(Position::new(u64::MAX >> 1).is_mmr_size());
        assert!(!Position::new((u64::MAX >> 1) + 1).is_mmr_size());
        assert!(!MAX_POSITION.is_mmr_size());
    }
}