Skip to main content

commonware_storage/mmr/
position.rs

1use super::location::Location;
2use bytes::{Buf, BufMut};
3use commonware_codec::ReadExt;
4use core::{
5    fmt,
6    ops::{Add, AddAssign, Deref, Sub, SubAssign},
7};
8
9/// Maximum valid [Position] value that can exist in a valid MMR.
10///
11/// This value corresponds to the last node in an MMR with the maximum number of leaves.
12pub const MAX_POSITION: Position = Position::new(0x7FFFFFFFFFFFFFFE); // (1 << 63) - 2
13
/// Index of a node within an MMR.
///
/// A [Position] counts an MMR's _nodes_. Compare [Location], which counts only the MMR's
/// _leaves_.
#[derive(Clone, Copy, Debug, Default, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub struct Position(u64);
18
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Position {
    /// Draw a position uniformly from the inclusive range `0..=MAX_POSITION`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_POSITION.0).map(Self)
    }
}
26
27impl Position {
28    /// Return a new [Position] from a raw `u64`.
29    #[inline]
30    pub const fn new(pos: u64) -> Self {
31        Self(pos)
32    }
33
34    /// Return the underlying `u64` value.
35    #[inline]
36    pub const fn as_u64(self) -> u64 {
37        self.0
38    }
39
40    /// Return `self + rhs` returning `None` on overflow or if result exceeds [MAX_POSITION].
41    #[inline]
42    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
43        match self.0.checked_add(rhs) {
44            Some(value) => {
45                if value <= MAX_POSITION.0 {
46                    Some(Self(value))
47                } else {
48                    None
49                }
50            }
51            None => None,
52        }
53    }
54
55    /// Return `self - rhs` returning `None` on underflow.
56    #[inline]
57    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
58        match self.0.checked_sub(rhs) {
59            Some(value) => Some(Self(value)),
60            None => None,
61        }
62    }
63
64    /// Return `self + rhs` saturating at [MAX_POSITION].
65    #[inline]
66    pub const fn saturating_add(self, rhs: u64) -> Self {
67        let result = self.0.saturating_add(rhs);
68        if result > MAX_POSITION.0 {
69            MAX_POSITION
70        } else {
71            Self(result)
72        }
73    }
74
75    /// Return `self - rhs` saturating at zero.
76    #[inline]
77    pub const fn saturating_sub(self, rhs: u64) -> Self {
78        Self(self.0.saturating_sub(rhs))
79    }
80
81    /// Returns whether this is a valid MMR size.
82    ///
83    /// The implementation verifies that (1) the size won't result in overflow and (2) peaks in the
84    /// MMR of the given size have strictly decreasing height, which is a necessary condition for
85    /// MMR validity.
86    #[inline]
87    pub const fn is_mmr_size(self) -> bool {
88        if self.0 == 0 {
89            return true;
90        }
91        let leading_zeros = self.0.leading_zeros();
92        if leading_zeros == 0 {
93            // size overflow
94            return false;
95        }
96        let start = u64::MAX >> leading_zeros;
97        let mut two_h = 1 << start.trailing_ones();
98        let mut node_pos = start.checked_sub(1).expect("start > 0 because size != 0");
99        while two_h > 1 {
100            if node_pos < self.0 {
101                if two_h == 2 {
102                    // If this peak is a leaf yet there are more nodes remaining, then this MMR is
103                    // invalid.
104                    return node_pos == self.0 - 1;
105                }
106                // move to the right sibling
107                node_pos += two_h - 1;
108                if node_pos < self.0 {
109                    // If the right sibling is in the MMR, then it is invalid.
110                    return false;
111                }
112                continue;
113            }
114            // descend to the left child
115            two_h >>= 1;
116            node_pos -= two_h;
117        }
118        true
119    }
120}
121
122impl fmt::Display for Position {
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        write!(f, "Position({})", self.0)
125    }
126}
127
128impl Deref for Position {
129    type Target = u64;
130    fn deref(&self) -> &Self::Target {
131        &self.0
132    }
133}
134
135impl AsRef<u64> for Position {
136    fn as_ref(&self) -> &u64 {
137        &self.0
138    }
139}
140
141impl From<u64> for Position {
142    #[inline]
143    fn from(value: u64) -> Self {
144        Self::new(value)
145    }
146}
147
148impl From<usize> for Position {
149    #[inline]
150    fn from(value: usize) -> Self {
151        Self::new(value as u64)
152    }
153}
154
155impl From<Position> for u64 {
156    #[inline]
157    fn from(position: Position) -> Self {
158        *position
159    }
160}
161
162/// Try to convert a [Location] to a [Position].
163///
164/// Returns an error if `loc` > [super::MAX_LOCATION].
165///
166/// # Examples
167///
168/// ```
169/// use commonware_storage::mmr::{Location, Position, MAX_LOCATION};
170/// use core::convert::TryFrom;
171///
172/// let loc = Location::new(5).unwrap();
173/// let pos = Position::try_from(loc).unwrap();
174/// assert_eq!(pos, Position::new(8));
175///
176/// // Invalid locations return error
177/// assert!(Location::new(MAX_LOCATION + 1).is_none());
178/// ```
179impl TryFrom<Location> for Position {
180    type Error = super::Error;
181
182    #[inline]
183    fn try_from(loc: Location) -> Result<Self, Self::Error> {
184        if !loc.is_valid() {
185            return Err(super::Error::LocationOverflow(loc));
186        }
187        // This will never underflow since 2*n >= count_ones(n).
188        let loc_val = *loc;
189        Ok(Self(
190            loc_val
191                .checked_mul(2)
192                .expect("should not overflow for valid location")
193                - loc_val.count_ones() as u64,
194        ))
195    }
196}
197
198/// Add two positions together.
199///
200/// # Panics
201///
202/// Panics if the result overflows.
203impl Add for Position {
204    type Output = Self;
205
206    #[inline]
207    fn add(self, rhs: Self) -> Self::Output {
208        Self(self.0 + rhs.0)
209    }
210}
211
212/// Add a position and a `u64`.
213///
214/// # Panics
215///
216/// Panics if the result overflows.
217impl Add<u64> for Position {
218    type Output = Self;
219
220    #[inline]
221    fn add(self, rhs: u64) -> Self::Output {
222        Self(self.0 + rhs)
223    }
224}
225
226/// Subtract two positions.
227///
228/// # Panics
229///
230/// Panics if the result underflows.
231impl Sub for Position {
232    type Output = Self;
233
234    #[inline]
235    fn sub(self, rhs: Self) -> Self::Output {
236        Self(self.0 - rhs.0)
237    }
238}
239
240/// Subtract a `u64` from a position.
241///
242/// # Panics
243///
244/// Panics if the result underflows.
245impl Sub<u64> for Position {
246    type Output = Self;
247
248    #[inline]
249    fn sub(self, rhs: u64) -> Self::Output {
250        Self(self.0 - rhs)
251    }
252}
253
254impl PartialEq<u64> for Position {
255    #[inline]
256    fn eq(&self, other: &u64) -> bool {
257        self.0 == *other
258    }
259}
260
261impl PartialOrd<u64> for Position {
262    #[inline]
263    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
264        self.0.partial_cmp(other)
265    }
266}
267
268// Allow u64 to be compared with Position too
269impl PartialEq<Position> for u64 {
270    #[inline]
271    fn eq(&self, other: &Position) -> bool {
272        *self == other.0
273    }
274}
275
276impl PartialOrd<Position> for u64 {
277    #[inline]
278    fn partial_cmp(&self, other: &Position) -> Option<core::cmp::Ordering> {
279        self.partial_cmp(&other.0)
280    }
281}
282
283/// Add a `u64` to a position.
284///
285/// # Panics
286///
287/// Panics if the result overflows.
288impl AddAssign<u64> for Position {
289    #[inline]
290    fn add_assign(&mut self, rhs: u64) {
291        self.0 += rhs;
292    }
293}
294
295/// Subtract a `u64` from a position.
296///
297/// # Panics
298///
299/// Panics if the result underflows.
300impl SubAssign<u64> for Position {
301    #[inline]
302    fn sub_assign(&mut self, rhs: u64) {
303        self.0 -= rhs;
304    }
305}
306
307// Codec implementations using varint encoding for efficient storage
308impl commonware_codec::Write for Position {
309    #[inline]
310    fn write(&self, buf: &mut impl BufMut) {
311        commonware_codec::varint::UInt(self.0).write(buf);
312    }
313}
314
315impl commonware_codec::EncodeSize for Position {
316    #[inline]
317    fn encode_size(&self) -> usize {
318        commonware_codec::varint::UInt(self.0).encode_size()
319    }
320}
321
322impl commonware_codec::Read for Position {
323    type Cfg = ();
324
325    #[inline]
326    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
327        let value: u64 = commonware_codec::varint::UInt::read(buf)?.into();
328        if value <= MAX_POSITION.0 {
329            Ok(Self(value))
330        } else {
331            Err(commonware_codec::Error::Invalid(
332                "Position",
333                "value exceeds MAX_POSITION",
334            ))
335        }
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::{Location, Position};
    use crate::mmr::{mem::DirtyMmr, StandardHasher as Standard, MAX_LOCATION, MAX_POSITION};
    use commonware_cryptography::Sha256;

    // Test that [Position::try_from] returns the correct position for leaf locations. Each
    // expected value equals 2 * loc - popcount(loc).
    #[test]
    fn test_from_location() {
        const CASES: &[(Location, Position)] = &[
            (Location::new_unchecked(0), Position::new(0)),
            (Location::new_unchecked(1), Position::new(1)),
            (Location::new_unchecked(2), Position::new(3)),
            (Location::new_unchecked(3), Position::new(4)),
            (Location::new_unchecked(4), Position::new(7)),
            (Location::new_unchecked(5), Position::new(8)),
            (Location::new_unchecked(6), Position::new(10)),
            (Location::new_unchecked(7), Position::new(11)),
            (Location::new_unchecked(8), Position::new(15)),
            (Location::new_unchecked(9), Position::new(16)),
            (Location::new_unchecked(10), Position::new(18)),
            (Location::new_unchecked(11), Position::new(19)),
            (Location::new_unchecked(12), Position::new(22)),
            (Location::new_unchecked(13), Position::new(23)),
            (Location::new_unchecked(14), Position::new(25)),
            (Location::new_unchecked(15), Position::new(26)),
        ];
        for (loc, expected_pos) in CASES {
            let pos = Position::try_from(*loc).unwrap();
            assert_eq!(pos, *expected_pos);
        }
    }

    // [Position::checked_add] rejects both u64 overflow and results above MAX_POSITION.
    #[test]
    fn test_checked_add() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_add(5).unwrap(), 15);

        // Overflow returns None
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_POSITION returns None
        assert!(MAX_POSITION.checked_add(1).is_none());
        assert!(Position::new(*MAX_POSITION - 5).checked_add(10).is_none());

        // At MAX_POSITION is OK
        assert_eq!(
            Position::new(*MAX_POSITION - 10).checked_add(10).unwrap(),
            MAX_POSITION
        );
    }

    // [Position::checked_sub] returns None on underflow.
    #[test]
    fn test_checked_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_sub(5).unwrap(), 5);
        assert!(pos.checked_sub(11).is_none());
    }

    // [Position::saturating_add] clamps to MAX_POSITION rather than to u64::MAX.
    #[test]
    fn test_saturating_add() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_add(5), 15);

        // Saturates at MAX_POSITION, not u64::MAX
        assert_eq!(Position::new(u64::MAX).saturating_add(1), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1000), MAX_POSITION);
        assert_eq!(
            Position::new(*MAX_POSITION - 5).saturating_add(10),
            MAX_POSITION
        );
    }

    // [Position::saturating_sub] clamps at zero.
    #[test]
    fn test_saturating_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_sub(5), 5);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    // Display renders as `Position(<n>)`.
    #[test]
    fn test_display() {
        let position = Position::new(42);
        assert_eq!(position.to_string(), "Position(42)");
    }

    // Position + Position, compared against a raw u64.
    #[test]
    fn test_add() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(5);
        assert_eq!((pos1 + pos2), 15);
    }

    // Position - Position, compared against a raw u64.
    #[test]
    fn test_sub() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(3);
        assert_eq!((pos1 - pos2), 7);
    }

    // Position and u64 compare in both directions (Position op u64 and u64 op Position).
    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);

        // Test equality
        assert_eq!(pos, 42u64);
        assert_eq!(42u64, pos);
        assert_ne!(pos, 43u64);
        assert_ne!(43u64, pos);

        // Test ordering
        assert!(pos < 43u64);
        assert!(43u64 > pos);
        assert!(pos > 41u64);
        assert!(41u64 < pos);
        assert!(pos <= 42u64);
        assert!(42u64 >= pos);
    }

    // `+=` and `-=` with a u64 right-hand side.
    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        // Test add assignment
        pos += 5;
        assert_eq!(pos, 15u64);

        // Test sub assignment
        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    // Derive MAX_POSITION from the "MMR size must keep its top bit clear" constraint, and
    // cross-check it against MAX_LOCATION.
    #[test]
    fn test_max_position() {
        // The constraint is: MMR_size must have top bit clear (< 2^63)
        // For N leaves: MMR_size = 2*N - popcount(N)
        // Worst case (maximum size) is when N is a power of 2: MMR_size = 2*N - 1

        // Maximum N where 2*N - 1 < 2^63:
        //   2*N - 1 < 2^63
        //   2*N < 2^63 + 1
        //   N <= 2^62
        let max_leaves = 1u64 << 62;

        // For N = 2^62 leaves:
        // MMR_size = 2 * 2^62 - 1 = 2^63 - 1
        let mmr_size_at_max = 2 * max_leaves - 1;
        assert_eq!(mmr_size_at_max, (1u64 << 63) - 1);
        assert_eq!(mmr_size_at_max.leading_zeros(), 1); // Top bit clear ✓

        // Last position (0-indexed) = MMR_size - 1 = 2^63 - 2
        let expected_max_pos = mmr_size_at_max - 1;
        assert_eq!(MAX_POSITION, expected_max_pos);
        assert_eq!(MAX_POSITION, (1u64 << 63) - 2);

        // Verify the constraint: a position at MAX_POSITION + 1 would require
        // an MMR_size >= 2^63, which violates the "top bit clear" requirement
        let hypothetical_mmr_size = MAX_POSITION + 2; // Would need this many nodes
        assert_eq!(hypothetical_mmr_size, 1u64 << 63);
        assert_eq!(hypothetical_mmr_size.leading_zeros(), 0); // Top bit NOT clear ✗

        // Verify relationship with MAX_LOCATION
        // Converting MAX_LOCATION to position should give a value < MAX_POSITION
        let max_loc = Location::new_unchecked(MAX_LOCATION);
        let last_leaf_pos = Position::try_from(max_loc).unwrap();
        assert!(*last_leaf_pos < MAX_POSITION);
    }

    // [Position::is_mmr_size] must accept exactly the sizes a real MMR reaches, and reject
    // sizes with the top bit set.
    #[test]
    fn test_is_mmr_size() {
        // Build an MMR one node at a time and check that the validity check is correct for all
        // sizes up to the current size.
        let mut size_to_check = Position::new(0);
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = DirtyMmr::new();
        let digest = [1u8; 32];
        for _i in 0..10000 {
            while size_to_check != mmr.size() {
                assert!(
                    !size_to_check.is_mmr_size(),
                    "size_to_check: {} {}",
                    size_to_check,
                    mmr.size()
                );
                size_to_check += 1;
            }
            assert!(size_to_check.is_mmr_size());
            mmr.add(&mut hasher, &digest);
            size_to_check += 1;
        }

        // Test overflow boundaries.
        assert!(!Position::new(u64::MAX).is_mmr_size());
        assert!(Position::new(u64::MAX >> 1).is_mmr_size());
        assert!(!Position::new((u64::MAX >> 1) + 1).is_mmr_size());
        assert!(!MAX_POSITION.is_mmr_size());
    }

    // Values up to and including MAX_POSITION round-trip through the codec.
    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Test zero
        let pos = Position::new(0);
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);

        // Test middle value
        let pos = Position::new(12345);
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);

        // Test MAX_POSITION (boundary)
        let pos = MAX_POSITION;
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);
    }

    // Decoding a varint above MAX_POSITION fails with `Error::Invalid("Position", _)`.
    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Encode MAX_POSITION + 1 as a raw varint, then try to decode as Position
        let invalid_value = *MAX_POSITION + 1;
        let encoded = commonware_codec::varint::UInt(invalid_value).encode();
        let result = Position::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Position", _))
        ));

        // Encode u64::MAX as a raw varint
        let encoded = commonware_codec::varint::UInt(u64::MAX).encode();
        let result = Position::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Position", _))
        ));
    }
}