Skip to main content

commonware_storage/mmr/
location.rs

1use super::position::Position;
2use crate::mmr::MAX_POSITION;
3use bytes::{Buf, BufMut};
4use commonware_codec::{Read, ReadExt};
5use core::{
6    convert::TryFrom,
7    fmt,
8    ops::{Add, AddAssign, Deref, Range, Sub, SubAssign},
9};
10use thiserror::Error;
11
/// Largest [Location] value that can appear in a valid MMR.
///
/// The bound follows from requiring that the total MMR size (its node count) fit in a `u64`
/// whose top bit is clear, so that at least one leading zero bit remains available for validity
/// checking. An MMR with `N` leaves contains
///
/// ```text
/// MMR_size = 2*N - popcount(N)
/// ```
///
/// nodes, where `popcount(N)` (the number of set bits in `N`) is the number of perfect binary
/// trees in the MMR forest.
///
/// When `N` is a power of two, `popcount(N) = 1` and `MMR_size = 2*N - 1`. Requiring
/// `MMR_size < 2^63` (top bit clear) then gives:
///
/// ```text
/// 2*N - 1 < 2^63
/// 2*N     < 2^63 + 1
/// N       ≤ 2^62
/// ```
///
/// An MMR can therefore hold at most `2^62` leaves, so the largest (0-indexed) location is
/// `2^62 - 1`.
///
/// ## Verification
///
/// - `N = 2^62`: `MMR_size = 2 * 2^62 - 1 = 2^63 - 1 = 0x7FFF_FFFF_FFFF_FFFF`, which keeps one
///   leading zero bit ✓
/// - `N = 2^62 + 1`: `popcount(N) = 2`, so `MMR_size = 2^63 + 2 - 2 = 2^63`, which sets the top
///   bit and exceeds the maximum valid MMR size ✗
pub const MAX_LOCATION: u64 = (1 << 62) - 1; // 0x3FFF_FFFF_FFFF_FFFF
44
/// Index of a _leaf_ within an MMR.
///
/// Contrast with [Position], which indexes every _node_ of the MMR (leaves and internal nodes
/// alike).
///
/// # Limits
///
/// The type itself does not restrict the wrapped `u64`. However, only values up to and including
/// [MAX_LOCATION] can be safely converted to a [Position]; anything larger is considered invalid.
#[derive(Debug, Default, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub struct Location(u64);
54
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Location {
    /// Draws a location uniformly from the valid range `0..=MAX_LOCATION`, so fuzzers never
    /// produce a location that cannot be converted to a [Position].
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_LOCATION).map(Location)
    }
}
62
63impl Location {
64    /// Create a new [Location] from a raw `u64` without validation.
65    ///
66    /// This is an internal constructor that assumes the value is valid. For creating
67    /// locations from external or untrusted sources, use [Location::new].
68    #[inline]
69    pub(crate) const fn new_unchecked(loc: u64) -> Self {
70        Self(loc)
71    }
72
73    /// Create a new [Location] from a raw `u64`, validating it does not exceed [MAX_LOCATION].
74    ///
75    /// Returns `None` if `loc > MAX_LOCATION`.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use commonware_storage::mmr::{Location, MAX_LOCATION};
81    ///
82    /// let loc = Location::new(100).unwrap();
83    /// assert_eq!(*loc, 100);
84    ///
85    /// // Values at MAX_LOCATION are valid
86    /// assert!(Location::new(MAX_LOCATION).is_some());
87    ///
88    /// // Values exceeding MAX_LOCATION return None
89    /// assert!(Location::new(MAX_LOCATION + 1).is_none());
90    /// assert!(Location::new(u64::MAX).is_none());
91    /// ```
92    #[inline]
93    pub const fn new(loc: u64) -> Option<Self> {
94        if loc > MAX_LOCATION {
95            None
96        } else {
97            Some(Self(loc))
98        }
99    }
100
101    /// Return the underlying `u64` value.
102    #[inline]
103    pub const fn as_u64(self) -> u64 {
104        self.0
105    }
106
107    /// Returns `true` iff this location can be safely converted to a [Position].
108    #[inline]
109    pub const fn is_valid(self) -> bool {
110        self.0 <= MAX_LOCATION
111    }
112
113    /// Return `self + rhs` returning `None` on overflow or if result exceeds [MAX_LOCATION].
114    #[inline]
115    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
116        match self.0.checked_add(rhs) {
117            Some(value) => {
118                if value <= MAX_LOCATION {
119                    Some(Self(value))
120                } else {
121                    None
122                }
123            }
124            None => None,
125        }
126    }
127
128    /// Return `self - rhs` returning `None` on underflow.
129    #[inline]
130    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
131        match self.0.checked_sub(rhs) {
132            Some(value) => Some(Self(value)),
133            None => None,
134        }
135    }
136
137    /// Return `self + rhs` saturating at [MAX_LOCATION].
138    #[inline]
139    pub const fn saturating_add(self, rhs: u64) -> Self {
140        let result = self.0.saturating_add(rhs);
141        if result > MAX_LOCATION {
142            Self(MAX_LOCATION)
143        } else {
144            Self(result)
145        }
146    }
147
148    /// Return `self - rhs` saturating at zero.
149    #[inline]
150    pub const fn saturating_sub(self, rhs: u64) -> Self {
151        Self(self.0.saturating_sub(rhs))
152    }
153}
154
155impl fmt::Display for Location {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "Location({})", self.0)
158    }
159}
160
161impl From<u64> for Location {
162    #[inline]
163    fn from(value: u64) -> Self {
164        Self::new_unchecked(value)
165    }
166}
167
168impl From<usize> for Location {
169    #[inline]
170    fn from(value: usize) -> Self {
171        Self::new_unchecked(value as u64)
172    }
173}
174
175impl Deref for Location {
176    type Target = u64;
177    fn deref(&self) -> &Self::Target {
178        &self.0
179    }
180}
181
182impl From<Location> for u64 {
183    #[inline]
184    fn from(loc: Location) -> Self {
185        *loc
186    }
187}
188
189// Codec implementations using varint encoding for efficient storage
190impl commonware_codec::Write for Location {
191    #[inline]
192    fn write(&self, buf: &mut impl BufMut) {
193        commonware_codec::varint::UInt(self.0).write(buf);
194    }
195}
196
197impl commonware_codec::EncodeSize for Location {
198    #[inline]
199    fn encode_size(&self) -> usize {
200        commonware_codec::varint::UInt(self.0).encode_size()
201    }
202}
203
204impl Read for Location {
205    type Cfg = ();
206
207    #[inline]
208    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
209        let value: u64 = commonware_codec::varint::UInt::read(buf)?.into();
210        Self::new(value).ok_or(commonware_codec::Error::Invalid(
211            "Location",
212            "value exceeds MAX_LOCATION",
213        ))
214    }
215}
216
217/// Add two locations together.
218///
219/// # Panics
220///
221/// Panics if the result overflows.
222impl Add for Location {
223    type Output = Self;
224
225    #[inline]
226    fn add(self, rhs: Self) -> Self::Output {
227        Self(self.0 + rhs.0)
228    }
229}
230
231/// Add a location and a `u64`.
232///
233/// # Panics
234///
235/// Panics if the result overflows.
236impl Add<u64> for Location {
237    type Output = Self;
238
239    #[inline]
240    fn add(self, rhs: u64) -> Self::Output {
241        Self(self.0 + rhs)
242    }
243}
244
245/// Subtract two locations.
246///
247/// # Panics
248///
249/// Panics if the result underflows.
250impl Sub for Location {
251    type Output = Self;
252
253    #[inline]
254    fn sub(self, rhs: Self) -> Self::Output {
255        Self(self.0 - rhs.0)
256    }
257}
258
259/// Subtract a `u64` from a location.
260///
261/// # Panics
262///
263/// Panics if the result underflows.
264impl Sub<u64> for Location {
265    type Output = Self;
266
267    #[inline]
268    fn sub(self, rhs: u64) -> Self::Output {
269        Self(self.0 - rhs)
270    }
271}
272
273impl PartialEq<u64> for Location {
274    #[inline]
275    fn eq(&self, other: &u64) -> bool {
276        self.0 == *other
277    }
278}
279
280impl PartialOrd<u64> for Location {
281    #[inline]
282    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
283        self.0.partial_cmp(other)
284    }
285}
286
287// Allow u64 to be compared with Location too
288impl PartialEq<Location> for u64 {
289    #[inline]
290    fn eq(&self, other: &Location) -> bool {
291        *self == other.0
292    }
293}
294
295impl PartialOrd<Location> for u64 {
296    #[inline]
297    fn partial_cmp(&self, other: &Location) -> Option<core::cmp::Ordering> {
298        self.partial_cmp(&other.0)
299    }
300}
301
302/// Add a `u64` to a location.
303///
304/// # Panics
305///
306/// Panics if the result overflows.
307impl AddAssign<u64> for Location {
308    #[inline]
309    fn add_assign(&mut self, rhs: u64) {
310        self.0 += rhs;
311    }
312}
313
314/// Subtract a `u64` from a location.
315///
316/// # Panics
317///
318/// Panics if the result underflows.
319impl SubAssign<u64> for Location {
320    #[inline]
321    fn sub_assign(&mut self, rhs: u64) {
322        self.0 -= rhs;
323    }
324}
325
326impl TryFrom<Position> for Location {
327    type Error = LocationError;
328
329    /// Attempt to derive the [Location] of a given node [Position].
330    ///
331    /// Returns an error if the position does not correspond to an MMR leaf or if position
332    /// overflow occurs.
333    ///
334    /// This computation is O(log2(n)) in the given position.
335    #[inline]
336    fn try_from(pos: Position) -> Result<Self, Self::Error> {
337        // Reject positions beyond the valid MMR range. This ensures `pos + 1` won't overflow below.
338        if *pos > MAX_POSITION {
339            return Err(LocationError::Overflow(pos));
340        }
341        // Position 0 is always the first leaf at location 0.
342        if *pos == 0 {
343            return Ok(Self(0));
344        }
345
346        // Find the height of the perfect binary tree containing this position.
347        // Safe: pos + 1 cannot overflow since pos <= MAX_POSITION (checked above).
348        let start = u64::MAX >> (pos + 1).leading_zeros();
349        let height = start.trailing_ones();
350        // Height 0 means this position is a peak (not a leaf in a tree).
351        if height == 0 {
352            return Err(LocationError::NonLeaf(pos));
353        }
354        let mut two_h = 1 << (height - 1);
355        let mut cur_node = start - 1;
356        let mut leaf_loc_floor = 0u64;
357
358        while two_h > 1 {
359            if cur_node == *pos {
360                return Err(LocationError::NonLeaf(pos));
361            }
362            let left_pos = cur_node - two_h;
363            two_h >>= 1;
364            if *pos > left_pos {
365                // The leaf is in the right subtree, so we must account for the leaves in the left
366                // subtree all of which precede it.
367                leaf_loc_floor += two_h;
368                cur_node -= 1; // move to the right child
369            } else {
370                // The node is in the left subtree
371                cur_node = left_pos;
372            }
373        }
374
375        Ok(Self(leaf_loc_floor))
376    }
377}
378
379/// Error returned when attempting to convert a [Position] to a [Location].
380#[derive(Debug, Clone, Copy, Eq, PartialEq, Error)]
381pub enum LocationError {
382    #[error("{0} is not a leaf")]
383    NonLeaf(Position),
384
385    #[error("{0} > MAX_LOCATION")]
386    Overflow(Position),
387}
388
389/// Extension trait for converting `Range<Location>` into other range types.
390pub trait LocationRangeExt {
391    /// Convert a `Range<Location>` to a `Range<usize>` suitable for slice indexing.
392    fn to_usize_range(&self) -> Range<usize>;
393}
394
395impl LocationRangeExt for Range<Location> {
396    #[inline]
397    fn to_usize_range(&self) -> Range<usize> {
398        *self.start as usize..*self.end as usize
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::{Location, LocationRangeExt, MAX_LOCATION};
    use crate::mmr::{position::Position, LocationError, MAX_POSITION};

    // Test that the [Location::try_from] function returns the correct location for leaf positions.
    #[test]
    fn test_try_from_position() {
        const CASES: &[(Position, Location)] = &[
            (Position::new(0), Location::new_unchecked(0)),
            (Position::new(1), Location::new_unchecked(1)),
            (Position::new(3), Location::new_unchecked(2)),
            (Position::new(4), Location::new_unchecked(3)),
            (Position::new(7), Location::new_unchecked(4)),
            (Position::new(8), Location::new_unchecked(5)),
            (Position::new(10), Location::new_unchecked(6)),
            (Position::new(11), Location::new_unchecked(7)),
            (Position::new(15), Location::new_unchecked(8)),
            (Position::new(16), Location::new_unchecked(9)),
            (Position::new(18), Location::new_unchecked(10)),
            (Position::new(19), Location::new_unchecked(11)),
            (Position::new(22), Location::new_unchecked(12)),
            (Position::new(23), Location::new_unchecked(13)),
            (Position::new(25), Location::new_unchecked(14)),
            (Position::new(26), Location::new_unchecked(15)),
        ];
        for (pos, expected_loc) in CASES {
            let loc = Location::try_from(*pos).expect("should map to a leaf location");
            assert_eq!(loc, *expected_loc);
        }
    }

    // Test that the [Location::try_from] function returns an error for non-leaf positions.
    #[test]
    fn test_try_from_position_error() {
        const CASES: &[Position] = &[
            Position::new(2),
            Position::new(5),
            Position::new(6),
            Position::new(9),
            Position::new(12),
            Position::new(13),
            Position::new(14),
            Position::new(17),
            Position::new(20),
            Position::new(21),
            Position::new(24),
            Position::new(27),
            Position::new(28),
            Position::new(29),
            Position::new(30),
        ];
        for &pos in CASES {
            let err = Location::try_from(pos).expect_err("position is not a leaf");
            assert_eq!(err, LocationError::NonLeaf(pos));
        }
    }

    #[test]
    fn test_try_from_position_error_overflow() {
        let overflow_pos = Position::new(u64::MAX);
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));

        // MAX_POSITION doesn't overflow and isn't a leaf
        let result = Location::try_from(MAX_POSITION);
        assert_eq!(result, Err(LocationError::NonLeaf(MAX_POSITION)));

        let overflow_pos = MAX_POSITION + 1;
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));
    }

    #[test]
    fn test_checked_add() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.checked_add(5).unwrap(), 15);

        // Overflow returns None
        assert!(Location::new_unchecked(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_LOCATION returns None
        assert!(Location::new_unchecked(MAX_LOCATION)
            .checked_add(1)
            .is_none());

        // At MAX_LOCATION is OK
        let loc = Location::new_unchecked(MAX_LOCATION - 10);
        assert_eq!(loc.checked_add(10).unwrap(), MAX_LOCATION);
    }

    #[test]
    fn test_checked_sub() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.checked_sub(5).unwrap(), 5);
        assert!(loc.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.saturating_add(5), 15);

        // Saturates at MAX_LOCATION, not u64::MAX
        assert_eq!(
            Location::new_unchecked(u64::MAX).saturating_add(1),
            MAX_LOCATION
        );
        assert_eq!(
            Location::new_unchecked(MAX_LOCATION).saturating_add(1),
            MAX_LOCATION
        );
        assert_eq!(
            Location::new_unchecked(MAX_LOCATION).saturating_add(1000),
            MAX_LOCATION
        );
    }

    #[test]
    fn test_saturating_sub() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.saturating_sub(5), 5);
        assert_eq!(Location::new_unchecked(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        let location = Location::new_unchecked(42);
        assert_eq!(location.to_string(), "Location(42)");
    }

    #[test]
    fn test_add() {
        let loc1 = Location::new_unchecked(10);
        let loc2 = Location::new_unchecked(5);
        assert_eq!(loc1 + loc2, 15);
    }

    #[test]
    fn test_sub() {
        let loc1 = Location::new_unchecked(10);
        let loc2 = Location::new_unchecked(3);
        assert_eq!(loc1 - loc2, 7);
    }

    #[test]
    fn test_comparison_with_u64() {
        let loc = Location::new_unchecked(42);

        // Test equality
        assert_eq!(loc, 42u64);
        assert_eq!(42u64, loc);
        assert_ne!(loc, 43u64);
        assert_ne!(43u64, loc);

        // Test ordering
        assert!(loc < 43u64);
        assert!(43u64 > loc);
        assert!(loc > 41u64);
        assert!(41u64 < loc);
        assert!(loc <= 42u64);
        assert!(42u64 >= loc);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut loc = Location::new_unchecked(10);

        // Test add assignment
        loc += 5;
        assert_eq!(loc, 15u64);

        // Test sub assignment
        loc -= 3;
        assert_eq!(loc, 12u64);
    }

    #[test]
    fn test_new() {
        // Valid locations
        assert!(Location::new(0).is_some());
        assert!(Location::new(1000).is_some());
        assert!(Location::new(MAX_LOCATION).is_some());

        // Invalid locations (too large)
        assert!(Location::new(MAX_LOCATION + 1).is_none());
        assert!(Location::new(u64::MAX).is_none());
    }

    #[test]
    fn test_is_valid() {
        assert!(Location::new_unchecked(0).is_valid());
        assert!(Location::new_unchecked(1000).is_valid());
        assert!(Location::new_unchecked(MAX_LOCATION).is_valid());
        // The first value past the boundary must be rejected.
        assert!(!Location::new_unchecked(MAX_LOCATION + 1).is_valid());
        assert!(!Location::new_unchecked(u64::MAX).is_valid());
    }

    #[test]
    fn test_max_location_boundary() {
        // MAX_LOCATION should convert successfully
        let max_loc = Location::new_unchecked(MAX_LOCATION);
        assert!(max_loc.is_valid());
        let pos = Position::try_from(max_loc).unwrap();
        // Verify the position value
        // For MAX_LOCATION = 2^62 - 1 = 0x3FFFFFFFFFFFFFFF, popcount = 62
        // Position = 2 * (2^62 - 1) - 62 = 2^63 - 2 - 62 = 2^63 - 64
        let expected = (1u64 << 63) - 64;
        assert_eq!(*pos, expected);
    }

    #[test]
    fn test_overflow_location_returns_error() {
        // MAX_LOCATION + 1 should return error
        let over_loc = Location::new_unchecked(MAX_LOCATION + 1);
        assert!(Position::try_from(over_loc).is_err());

        // Verify the error variant
        match Position::try_from(over_loc) {
            Err(crate::mmr::Error::LocationOverflow(loc)) => {
                assert_eq!(loc, over_loc);
            }
            _ => panic!("expected LocationOverflow error"),
        }
    }

    #[test]
    fn test_to_usize_range() {
        let range = Location::new_unchecked(3)..Location::new_unchecked(7);
        assert_eq!(range.to_usize_range(), 3..7);

        let empty = Location::new_unchecked(5)..Location::new_unchecked(5);
        assert!(empty.to_usize_range().is_empty());
    }

    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Test zero
        let loc = Location::new(0).unwrap();
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);

        // Test middle value
        let loc = Location::new(12345).unwrap();
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);

        // Test MAX_LOCATION (boundary)
        let loc = Location::new(MAX_LOCATION).unwrap();
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);
    }

    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Encode MAX_LOCATION + 1 as a raw varint, then try to decode as Location
        let invalid_value = MAX_LOCATION + 1;
        let encoded = commonware_codec::varint::UInt(invalid_value).encode();
        let result = Location::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Location", _))
        ));

        // Encode u64::MAX as a raw varint
        let encoded = commonware_codec::varint::UInt(u64::MAX).encode();
        let result = Location::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Location", _))
        ));
    }
}