Skip to main content

commonware_storage/mmr/
position.rs

1use super::location::Location;
2use bytes::{Buf, BufMut};
3use commonware_codec::{varint::UInt, ReadExt};
4use core::{
5    fmt,
6    ops::{Add, AddAssign, Deref, Sub, SubAssign},
7};
8
9/// Maximum valid [Position] value: the largest node count (size) an MMR can hold.
10///
11/// An MMR with `2^62` leaves has `2^63 - 1` nodes, so `MAX_POSITION = 2^63 - 1`.
12///
13/// Node indices are 0-based, so valid indices satisfy `pos < MAX_POSITION`. Node counts
14/// and MMR sizes satisfy `pos <= MAX_POSITION`.
15pub const MAX_POSITION: Position = Position::new(0x7FFFFFFFFFFFFFFF); // (1 << 63) - 1
16
/// Index of a node within an MMR.
///
/// Contrast with [Location], which indexes only an MMR's _leaves_ rather than all of its nodes.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Position(u64);
21
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Position {
    /// Produces a position drawn from the valid range `0..=MAX_POSITION`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_POSITION.0).map(Position)
    }
}
29
30impl Position {
31    /// Return a new [Position] from a raw `u64`.
32    #[inline]
33    pub const fn new(pos: u64) -> Self {
34        Self(pos)
35    }
36
37    /// Return the underlying `u64` value.
38    #[inline]
39    pub const fn as_u64(self) -> u64 {
40        self.0
41    }
42
43    /// Returns `true` iff this value is within the valid range (`<= MAX_POSITION`).
44    /// This covers both node indices (`< MAX_POSITION`) and node counts (`<= MAX_POSITION`).
45    #[inline]
46    pub const fn is_valid(self) -> bool {
47        self.0 <= MAX_POSITION.0
48    }
49
50    /// Return `self + rhs` returning `None` on overflow or if result exceeds [MAX_POSITION].
51    #[inline]
52    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
53        match self.0.checked_add(rhs) {
54            Some(value) => {
55                if value <= MAX_POSITION.0 {
56                    Some(Self(value))
57                } else {
58                    None
59                }
60            }
61            None => None,
62        }
63    }
64
65    /// Return `self - rhs` returning `None` on underflow.
66    #[inline]
67    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
68        match self.0.checked_sub(rhs) {
69            Some(value) => Some(Self(value)),
70            None => None,
71        }
72    }
73
74    /// Return `self + rhs` saturating at [MAX_POSITION].
75    #[inline]
76    pub const fn saturating_add(self, rhs: u64) -> Self {
77        let result = self.0.saturating_add(rhs);
78        if result > MAX_POSITION.0 {
79            MAX_POSITION
80        } else {
81            Self(result)
82        }
83    }
84
85    /// Return `self - rhs` saturating at zero.
86    #[inline]
87    pub const fn saturating_sub(self, rhs: u64) -> Self {
88        Self(self.0.saturating_sub(rhs))
89    }
90
91    /// Returns whether this is a valid MMR size.
92    ///
93    /// The implementation verifies that (1) the size won't result in overflow and (2) peaks in the
94    /// MMR of the given size have strictly decreasing height, which is a necessary condition for
95    /// MMR validity.
96    #[inline]
97    pub const fn is_mmr_size(self) -> bool {
98        if self.0 == 0 {
99            return true;
100        }
101        let leading_zeros = self.0.leading_zeros();
102        if leading_zeros == 0 {
103            // size overflow
104            return false;
105        }
106        let start = u64::MAX >> leading_zeros;
107        let mut two_h = 1 << start.trailing_ones();
108        let mut node_pos = start.checked_sub(1).expect("start > 0 because size != 0");
109        while two_h > 1 {
110            if node_pos < self.0 {
111                if two_h == 2 {
112                    // If this peak is a leaf yet there are more nodes remaining, then this MMR is
113                    // invalid.
114                    return node_pos == self.0 - 1;
115                }
116                // move to the right sibling
117                node_pos += two_h - 1;
118                if node_pos < self.0 {
119                    // If the right sibling is in the MMR, then it is invalid.
120                    return false;
121                }
122                continue;
123            }
124            // descend to the left child
125            two_h >>= 1;
126            node_pos -= two_h;
127        }
128        true
129    }
130}
131
132impl fmt::Display for Position {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        write!(f, "Position({})", self.0)
135    }
136}
137
138impl Deref for Position {
139    type Target = u64;
140    fn deref(&self) -> &Self::Target {
141        &self.0
142    }
143}
144
145impl AsRef<u64> for Position {
146    fn as_ref(&self) -> &u64 {
147        &self.0
148    }
149}
150
151impl From<u64> for Position {
152    #[inline]
153    fn from(value: u64) -> Self {
154        Self::new(value)
155    }
156}
157
158impl From<usize> for Position {
159    #[inline]
160    fn from(value: usize) -> Self {
161        Self::new(value as u64)
162    }
163}
164
165impl From<Position> for u64 {
166    #[inline]
167    fn from(position: Position) -> Self {
168        *position
169    }
170}
171
172/// Try to convert a leaf [Location] to its node [Position].
173///
174/// Returns [super::Error::LocationOverflow] if `!loc.is_valid()`.
175///
176/// # Examples
177///
178/// ```
179/// use commonware_storage::mmr::{Location, Position, MAX_LOCATION};
180/// use core::convert::TryFrom;
181///
182/// let loc = Location::new(5);
183/// let pos = Position::try_from(loc).unwrap();
184/// assert_eq!(pos, Position::new(8));
185///
186/// // MAX_LOCATION converts successfully (it is the leaf count for 2^62 leaves)
187/// let pos = Position::try_from(MAX_LOCATION).unwrap();
188/// assert!(pos.is_valid());
189/// ```
190impl TryFrom<Location> for Position {
191    type Error = super::Error;
192
193    #[inline]
194    fn try_from(loc: Location) -> Result<Self, Self::Error> {
195        if !loc.is_valid() {
196            return Err(super::Error::LocationOverflow(loc));
197        }
198        // This will never underflow since 2*n >= count_ones(n).
199        let loc_val = *loc;
200        Ok(Self(
201            loc_val
202                .checked_mul(2)
203                .expect("should not overflow for valid leaf index")
204                - loc_val.count_ones() as u64,
205        ))
206    }
207}
208
209/// Add two positions together.
210///
211/// # Panics
212///
213/// Panics if the result overflows.
214impl Add for Position {
215    type Output = Self;
216
217    #[inline]
218    fn add(self, rhs: Self) -> Self::Output {
219        Self(self.0 + rhs.0)
220    }
221}
222
223/// Add a position and a `u64`.
224///
225/// # Panics
226///
227/// Panics if the result overflows.
228impl Add<u64> for Position {
229    type Output = Self;
230
231    #[inline]
232    fn add(self, rhs: u64) -> Self::Output {
233        Self(self.0 + rhs)
234    }
235}
236
237/// Subtract two positions.
238///
239/// # Panics
240///
241/// Panics if the result underflows.
242impl Sub for Position {
243    type Output = Self;
244
245    #[inline]
246    fn sub(self, rhs: Self) -> Self::Output {
247        Self(self.0 - rhs.0)
248    }
249}
250
251/// Subtract a `u64` from a position.
252///
253/// # Panics
254///
255/// Panics if the result underflows.
256impl Sub<u64> for Position {
257    type Output = Self;
258
259    #[inline]
260    fn sub(self, rhs: u64) -> Self::Output {
261        Self(self.0 - rhs)
262    }
263}
264
265impl PartialEq<u64> for Position {
266    #[inline]
267    fn eq(&self, other: &u64) -> bool {
268        self.0 == *other
269    }
270}
271
272impl PartialOrd<u64> for Position {
273    #[inline]
274    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
275        self.0.partial_cmp(other)
276    }
277}
278
279// Allow u64 to be compared with Position too
280impl PartialEq<Position> for u64 {
281    #[inline]
282    fn eq(&self, other: &Position) -> bool {
283        *self == other.0
284    }
285}
286
287impl PartialOrd<Position> for u64 {
288    #[inline]
289    fn partial_cmp(&self, other: &Position) -> Option<core::cmp::Ordering> {
290        self.partial_cmp(&other.0)
291    }
292}
293
294/// Add a `u64` to a position.
295///
296/// # Panics
297///
298/// Panics if the result overflows.
299impl AddAssign<u64> for Position {
300    #[inline]
301    fn add_assign(&mut self, rhs: u64) {
302        self.0 += rhs;
303    }
304}
305
306/// Subtract a `u64` from a position.
307///
308/// # Panics
309///
310/// Panics if the result underflows.
311impl SubAssign<u64> for Position {
312    #[inline]
313    fn sub_assign(&mut self, rhs: u64) {
314        self.0 -= rhs;
315    }
316}
317
318// Codec implementations using varint encoding for efficient storage
319impl commonware_codec::Write for Position {
320    #[inline]
321    fn write(&self, buf: &mut impl BufMut) {
322        UInt(self.0).write(buf);
323    }
324}
325
326impl commonware_codec::EncodeSize for Position {
327    #[inline]
328    fn encode_size(&self) -> usize {
329        UInt(self.0).encode_size()
330    }
331}
332
333impl commonware_codec::Read for Position {
334    type Cfg = ();
335
336    #[inline]
337    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
338        let pos = Self(UInt::read(buf)?.into());
339        if pos.is_valid() {
340            Ok(pos)
341        } else {
342            Err(commonware_codec::Error::Invalid(
343                "Position",
344                "value exceeds MAX_POSITION",
345            ))
346        }
347    }
348}
349
#[cfg(test)]
mod tests {
    use super::{Location, Position};
    use crate::mmr::{mem::Mmr, StandardHasher as Standard, MAX_LOCATION, MAX_POSITION};
    use commonware_cryptography::Sha256;

    // `Position::try_from` maps each leaf location to the index of that leaf's node.
    #[test]
    fn test_from_location() {
        const CASES: &[(Location, Position)] = &[
            (Location::new(0), Position::new(0)),
            (Location::new(1), Position::new(1)),
            (Location::new(2), Position::new(3)),
            (Location::new(3), Position::new(4)),
            (Location::new(4), Position::new(7)),
            (Location::new(5), Position::new(8)),
            (Location::new(6), Position::new(10)),
            (Location::new(7), Position::new(11)),
            (Location::new(8), Position::new(15)),
            (Location::new(9), Position::new(16)),
            (Location::new(10), Position::new(18)),
            (Location::new(11), Position::new(19)),
            (Location::new(12), Position::new(22)),
            (Location::new(13), Position::new(23)),
            (Location::new(14), Position::new(25)),
            (Location::new(15), Position::new(26)),
        ];
        for &(loc, expected) in CASES {
            assert_eq!(Position::try_from(loc).unwrap(), expected);
        }
    }

    #[test]
    fn test_checked_add() {
        assert_eq!(Position::new(10).checked_add(5).unwrap(), 15);

        // Raw u64 overflow yields None.
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Any sum above MAX_POSITION yields None.
        for (start, delta) in [(*MAX_POSITION, 1), (*MAX_POSITION - 5, 10)] {
            assert!(Position::new(start).checked_add(delta).is_none());
        }

        // Landing exactly on MAX_POSITION is allowed.
        let landed = Position::new(*MAX_POSITION - 10).checked_add(10);
        assert_eq!(landed, Some(MAX_POSITION));
    }

    #[test]
    fn test_checked_sub() {
        let ten = Position::new(10);
        assert_eq!(ten.checked_sub(5), Some(Position::new(5)));
        assert!(ten.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        assert_eq!(Position::new(10).saturating_add(5), 15);

        // Saturation clamps to MAX_POSITION rather than u64::MAX.
        let clamped_cases = [
            (u64::MAX, 1),
            (*MAX_POSITION, 1),
            (*MAX_POSITION, 1000),
            (*MAX_POSITION - 5, 10),
        ];
        for (start, delta) in clamped_cases {
            assert_eq!(Position::new(start).saturating_add(delta), MAX_POSITION);
        }
    }

    #[test]
    fn test_saturating_sub() {
        assert_eq!(Position::new(10).saturating_sub(5), 5);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        assert_eq!(Position::new(42).to_string(), "Position(42)");
    }

    #[test]
    fn test_add() {
        let sum = Position::new(10) + Position::new(5);
        assert_eq!(sum, 15);
    }

    #[test]
    fn test_sub() {
        let difference = Position::new(10) - Position::new(3);
        assert_eq!(difference, 7);
    }

    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);
        let (below, same, above) = (41u64, 42u64, 43u64);

        // Equality in both directions.
        assert_eq!(pos, same);
        assert_eq!(same, pos);
        assert_ne!(pos, above);
        assert_ne!(above, pos);

        // Ordering in both directions.
        assert!(pos < above);
        assert!(above > pos);
        assert!(pos > below);
        assert!(below < pos);
        assert!(pos <= same);
        assert!(same >= pos);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        pos += 5;
        assert_eq!(pos, 15u64);

        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    #[test]
    fn test_max_position() {
        // 2^62 leaves produce 2^63 - 1 nodes, which is MAX_POSITION.
        let max_leaves: u64 = 1 << 62;
        let max_size = 2 * max_leaves - 1;
        assert_eq!(*MAX_POSITION, max_size);
        assert_eq!(*MAX_POSITION, (1u64 << 63) - 1);
        assert_eq!(max_size.leading_zeros(), 1); // top bit clear

        // One more leaf pushes the size past 2^63, setting the top bit.
        let overflow_size = 2 * (max_leaves + 1) - 1;
        assert_eq!(overflow_size.leading_zeros(), 0);

        // MAX_LOCATION converts to exactly MAX_POSITION.
        let pos = Position::try_from(MAX_LOCATION).unwrap();
        assert_eq!(pos, MAX_POSITION);
        assert!(pos.is_valid());
    }

    #[test]
    fn test_is_mmr_size() {
        // Grow an MMR one leaf at a time. Every size strictly between two consecutive MMR sizes
        // must be rejected, and every actual MMR size must be accepted.
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&mut hasher);
        let digest = [1u8; 32];
        let mut candidate = Position::new(0);
        for _ in 0..10000 {
            while candidate != mmr.size() {
                assert!(
                    !candidate.is_mmr_size(),
                    "size_to_check: {} {}",
                    candidate,
                    mmr.size()
                );
                candidate += 1;
            }
            assert!(candidate.is_mmr_size());
            let changeset = {
                let mut batch = mmr.new_batch();
                batch.add(&mut hasher, &digest);
                batch.merkleize(&mut hasher).finalize()
            };
            mmr.apply(changeset).unwrap();
            candidate += 1;
        }

        // Overflow boundaries around the top bit.
        assert!(!Position::new(u64::MAX).is_mmr_size());
        assert!(Position::new(u64::MAX >> 1).is_mmr_size()); // 2^63 - 1 = MAX_POSITION
        assert!(!Position::new((u64::MAX >> 1) + 1).is_mmr_size());
        assert!(MAX_POSITION.is_mmr_size()); // largest valid MMR size
    }

    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Zero, a middle value, and the MAX_POSITION boundary all round-trip.
        for pos in [Position::new(0), Position::new(12345), MAX_POSITION] {
            let encoded = pos.encode();
            let decoded = Position::read(&mut encoded.as_ref()).unwrap();
            assert_eq!(decoded, pos);
        }
    }

    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{varint::UInt, Encode, ReadExt};

        // Raw varints just above MAX_POSITION, and at u64::MAX, must be rejected on decode.
        for raw in [*MAX_POSITION + 1, u64::MAX] {
            let encoded = UInt(raw).encode();
            let result = Position::read(&mut encoded.as_ref());
            assert!(result.is_err());
            assert!(matches!(
                result,
                Err(commonware_codec::Error::Invalid("Position", _))
            ));
        }
    }
}