Skip to main content

commonware_storage/mmr/
location.rs

1use super::position::Position;
2use bytes::{Buf, BufMut};
3use commonware_codec::{varint::UInt, Read, ReadExt};
4use core::{
5    convert::TryFrom,
6    fmt,
7    ops::{Add, AddAssign, Deref, Range, Sub, SubAssign},
8};
9use thiserror::Error;
10
11/// Maximum valid [Location] value: the largest leaf count an MMR can hold.
12///
13/// An MMR with N leaves has `2*N - popcount(N)` nodes. We require `size < 2^63` (top bit clear).
14/// The worst case is `N = 2^62` (a power of two, `popcount = 1`):
15///
16/// ```text
17/// 2*N - 1 < 2^63  =>  N <= 2^62
18/// ```
19///
20/// Therefore the maximum leaf count is `2^62` and `MAX_LOCATION = 2^62`.
21///
22/// Leaf indices are 0-based, so valid indices satisfy `loc < MAX_LOCATION` (i.e., `0..=2^62 - 1`).
23/// Leaf counts and exclusive range-ends satisfy `loc <= MAX_LOCATION`.
24pub const MAX_LOCATION: Location = Location(0x4000_0000_0000_0000); // 2^62
25
26/// A [Location] is a leaf index or leaf count in an MMR.
27/// This is in contrast to a [Position], which is a node index or node count.
28///
29/// # Limits
30///
31/// Values up to [MAX_LOCATION] are valid (see [Location::is_valid]). As a 0-based leaf index,
32/// valid indices are `0..=MAX_LOCATION - 1` (i.e., strictly less than [MAX_LOCATION]). As a leaf
33/// count or exclusive range-end, the maximum is `MAX_LOCATION` itself.
34#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)]
35pub struct Location(u64);
36
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Location {
    /// Draw a [Location] whose raw value lies in `0..=MAX_LOCATION`, so the result always
    /// satisfies [Location::is_valid].
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_LOCATION.0).map(Self)
    }
}
44
45impl Location {
46    /// Return a new [Location] from a raw `u64`.
47    #[inline]
48    pub const fn new(loc: u64) -> Self {
49        Self(loc)
50    }
51
52    /// Return the underlying `u64` value.
53    #[inline]
54    pub const fn as_u64(self) -> u64 {
55        self.0
56    }
57
58    /// Returns `true` iff this value is within the valid range (`<= MAX_LOCATION`).
59    /// This covers both leaf indices (`< MAX_LOCATION`) and leaf counts (`<= MAX_LOCATION`).
60    #[inline]
61    pub const fn is_valid(self) -> bool {
62        self.0 <= MAX_LOCATION.0
63    }
64
65    /// Return `self + rhs` returning `None` on overflow or if result exceeds [MAX_LOCATION].
66    #[inline]
67    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
68        match self.0.checked_add(rhs) {
69            Some(value) => {
70                if value <= MAX_LOCATION.0 {
71                    Some(Self(value))
72                } else {
73                    None
74                }
75            }
76            None => None,
77        }
78    }
79
80    /// Return `self - rhs` returning `None` on underflow.
81    #[inline]
82    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
83        match self.0.checked_sub(rhs) {
84            Some(value) => Some(Self(value)),
85            None => None,
86        }
87    }
88
89    /// Return `self + rhs` saturating at [MAX_LOCATION].
90    #[inline]
91    pub const fn saturating_add(self, rhs: u64) -> Self {
92        let result = self.0.saturating_add(rhs);
93        if result > MAX_LOCATION.0 {
94            MAX_LOCATION
95        } else {
96            Self(result)
97        }
98    }
99
100    /// Return `self - rhs` saturating at zero.
101    #[inline]
102    pub const fn saturating_sub(self, rhs: u64) -> Self {
103        Self(self.0.saturating_sub(rhs))
104    }
105}
106
107impl fmt::Display for Location {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        write!(f, "Location({})", self.0)
110    }
111}
112
113impl From<u64> for Location {
114    #[inline]
115    fn from(value: u64) -> Self {
116        Self::new(value)
117    }
118}
119
120impl From<usize> for Location {
121    #[inline]
122    fn from(value: usize) -> Self {
123        Self::new(value as u64)
124    }
125}
126
127impl Deref for Location {
128    type Target = u64;
129    fn deref(&self) -> &Self::Target {
130        &self.0
131    }
132}
133
134impl From<Location> for u64 {
135    #[inline]
136    fn from(loc: Location) -> Self {
137        *loc
138    }
139}
140
141// Codec implementations using varint encoding for efficient storage
142impl commonware_codec::Write for Location {
143    #[inline]
144    fn write(&self, buf: &mut impl BufMut) {
145        UInt(self.0).write(buf);
146    }
147}
148
149impl commonware_codec::EncodeSize for Location {
150    #[inline]
151    fn encode_size(&self) -> usize {
152        UInt(self.0).encode_size()
153    }
154}
155
156impl Read for Location {
157    type Cfg = ();
158
159    #[inline]
160    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
161        let value: u64 = UInt::read(buf)?.into();
162        let loc = Self::new(value);
163        if loc.is_valid() {
164            Ok(loc)
165        } else {
166            Err(commonware_codec::Error::Invalid(
167                "Location",
168                "value exceeds MAX_LOCATION",
169            ))
170        }
171    }
172}
173
174/// Add two locations together.
175///
176/// # Panics
177///
178/// Panics if the result overflows.
179impl Add for Location {
180    type Output = Self;
181
182    #[inline]
183    fn add(self, rhs: Self) -> Self::Output {
184        Self(self.0 + rhs.0)
185    }
186}
187
188/// Add a location and a `u64`.
189///
190/// # Panics
191///
192/// Panics if the result overflows.
193impl Add<u64> for Location {
194    type Output = Self;
195
196    #[inline]
197    fn add(self, rhs: u64) -> Self::Output {
198        Self(self.0 + rhs)
199    }
200}
201
202/// Subtract two locations.
203///
204/// # Panics
205///
206/// Panics if the result underflows.
207impl Sub for Location {
208    type Output = Self;
209
210    #[inline]
211    fn sub(self, rhs: Self) -> Self::Output {
212        Self(self.0 - rhs.0)
213    }
214}
215
216/// Subtract a `u64` from a location.
217///
218/// # Panics
219///
220/// Panics if the result underflows.
221impl Sub<u64> for Location {
222    type Output = Self;
223
224    #[inline]
225    fn sub(self, rhs: u64) -> Self::Output {
226        Self(self.0 - rhs)
227    }
228}
229
230impl PartialEq<u64> for Location {
231    #[inline]
232    fn eq(&self, other: &u64) -> bool {
233        self.0 == *other
234    }
235}
236
237impl PartialOrd<u64> for Location {
238    #[inline]
239    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
240        self.0.partial_cmp(other)
241    }
242}
243
244// Allow u64 to be compared with Location too
245impl PartialEq<Location> for u64 {
246    #[inline]
247    fn eq(&self, other: &Location) -> bool {
248        *self == other.0
249    }
250}
251
252impl PartialOrd<Location> for u64 {
253    #[inline]
254    fn partial_cmp(&self, other: &Location) -> Option<core::cmp::Ordering> {
255        self.partial_cmp(&other.0)
256    }
257}
258
259/// Add a `u64` to a location.
260///
261/// # Panics
262///
263/// Panics if the result overflows.
264impl AddAssign<u64> for Location {
265    #[inline]
266    fn add_assign(&mut self, rhs: u64) {
267        self.0 += rhs;
268    }
269}
270
271/// Subtract a `u64` from a location.
272///
273/// # Panics
274///
275/// Panics if the result underflows.
276impl SubAssign<u64> for Location {
277    #[inline]
278    fn sub_assign(&mut self, rhs: u64) {
279        self.0 -= rhs;
280    }
281}
282
283impl TryFrom<Position> for Location {
284    type Error = LocationError;
285
286    /// Attempt to derive the [Location] of a given node [Position].
287    ///
288    /// Returns [LocationError::Overflow] if `pos` fails its validity check, and
289    /// [LocationError::NonLeaf] if `pos` does not correspond to an MMR leaf.
290    ///
291    /// This computation is O(log2(n)) in the given position.
292    #[inline]
293    fn try_from(pos: Position) -> Result<Self, Self::Error> {
294        // Reject positions beyond the valid range.
295        if !pos.is_valid() {
296            return Err(LocationError::Overflow(pos));
297        }
298        // Position 0 is always the first leaf at location 0.
299        if *pos == 0 {
300            return Ok(Self(0));
301        }
302
303        // Find the height of the smallest perfect binary tree containing this position: `start`
        // is the all-ones value `2^height - 1` (that tree's node count), chosen so `pos < start`.
304        // Safe: pos + 1 cannot overflow since pos <= MAX_POSITION (checked above).
305        let start = u64::MAX >> (pos + 1).leading_zeros();
306        let height = start.trailing_ones();
307        // Defensive guard that also keeps `height - 1` below from underflowing.
        // NOTE(review): since pos >= 1 here, `start >= 3` and so `height >= 2`; this branch
        // looks unreachable.
308        if height == 0 {
309            return Err(LocationError::NonLeaf(pos));
310        }
        // `two_h` is the leaf count of the subtree rooted at `cur_node` (initially the whole
        // tree, whose root is its last position `start - 1`); `leaf_loc_floor` counts the leaves
        // known to precede `pos`.
311        let mut two_h = 1 << (height - 1);
312        let mut cur_node = start - 1;
313        let mut leaf_loc_floor = 0u64;
314
315        while two_h > 1 {
            // Landing on `pos` while `cur_node` still roots a multi-leaf subtree means `pos` is
            // an internal node.
316            if cur_node == *pos {
317                return Err(LocationError::NonLeaf(pos));
318            }
            // Root of the left child subtree; after halving, `two_h` is the leaf count of each
            // child subtree.
319            let left_pos = cur_node - two_h;
320            two_h >>= 1;
321            if *pos > left_pos {
322                // The leaf is in the right subtree, so we must account for the leaves in the left
323                // subtree all of which precede it.
324                leaf_loc_floor += two_h;
325                cur_node -= 1; // move to the right child
326            } else {
327                // The node is in the left subtree, so descend to the left child's root.
328                cur_node = left_pos;
329            }
330        }
331
332        Ok(Self(leaf_loc_floor))
333    }
334}
335
336/// Error returned when attempting to convert a [Position] to a [Location].
337#[derive(Debug, Clone, Copy, Eq, PartialEq, Error)]
338pub enum LocationError {
339    #[error("{0} is not a leaf")]
340    NonLeaf(Position),
341
342    #[error("{0} > MAX_LOCATION")]
343    Overflow(Position),
344}
345
346/// Extension trait for converting `Range<Location>` into other range types.
347pub trait LocationRangeExt {
348    /// Convert a `Range<Location>` to a `Range<usize>` suitable for slice indexing.
    ///
    /// NOTE: the `Range<Location>` implementation in this module converts each bound with an
    /// `as usize` cast, which truncates on targets where `usize` is narrower than 64 bits.
349    fn to_usize_range(&self) -> Range<usize>;
350}
351
352impl LocationRangeExt for Range<Location> {
353    #[inline]
354    fn to_usize_range(&self) -> Range<usize> {
355        *self.start as usize..*self.end as usize
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::{Location, LocationRangeExt, MAX_LOCATION};
    use crate::mmr::{position::Position, LocationError, MAX_POSITION};

    // Test that the [Location::try_from] function returns the correct location for leaf positions.
    #[test]
    fn test_try_from_position() {
        const CASES: &[(Position, Location)] = &[
            (Position::new(0), Location::new(0)),
            (Position::new(1), Location::new(1)),
            (Position::new(3), Location::new(2)),
            (Position::new(4), Location::new(3)),
            (Position::new(7), Location::new(4)),
            (Position::new(8), Location::new(5)),
            (Position::new(10), Location::new(6)),
            (Position::new(11), Location::new(7)),
            (Position::new(15), Location::new(8)),
            (Position::new(16), Location::new(9)),
            (Position::new(18), Location::new(10)),
            (Position::new(19), Location::new(11)),
            (Position::new(22), Location::new(12)),
            (Position::new(23), Location::new(13)),
            (Position::new(25), Location::new(14)),
            (Position::new(26), Location::new(15)),
        ];
        for (pos, expected_loc) in CASES {
            let loc = Location::try_from(*pos).expect("should map to a leaf location");
            assert_eq!(loc, *expected_loc);
        }
    }

    // Test that the [Location::try_from] function returns an error for non-leaf positions.
    #[test]
    fn test_try_from_position_error() {
        const CASES: &[Position] = &[
            Position::new(2),
            Position::new(5),
            Position::new(6),
            Position::new(9),
            Position::new(12),
            Position::new(13),
            Position::new(14),
            Position::new(17),
            Position::new(20),
            Position::new(21),
            Position::new(24),
            Position::new(27),
            Position::new(28),
            Position::new(29),
            Position::new(30),
        ];
        for &pos in CASES {
            let err = Location::try_from(pos).expect_err("position is not a leaf");
            assert_eq!(err, LocationError::NonLeaf(pos));
        }
    }

    // Test that positions beyond MAX_POSITION are rejected while MAX_POSITION itself converts.
    #[test]
    fn test_try_from_position_error_overflow() {
        let overflow_pos = Position::new(u64::MAX);
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));

        // MAX_POSITION is the leaf at MAX_LOCATION
        let result = Location::try_from(MAX_POSITION);
        assert_eq!(result, Ok(MAX_LOCATION));

        let overflow_pos = MAX_POSITION + 1;
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));
    }

    #[test]
    fn test_checked_add() {
        let loc = Location::new(10);
        assert_eq!(loc.checked_add(5).unwrap(), 15);

        // Overflow returns None
        assert!(Location::new(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_LOCATION returns None
        assert!(MAX_LOCATION.checked_add(1).is_none());

        // At MAX_LOCATION is OK
        let loc = Location::new(*MAX_LOCATION - 10);
        assert_eq!(loc.checked_add(10).unwrap(), *MAX_LOCATION);
    }

    #[test]
    fn test_checked_sub() {
        let loc = Location::new(10);
        assert_eq!(loc.checked_sub(5).unwrap(), 5);
        assert!(loc.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        let loc = Location::new(10);
        assert_eq!(loc.saturating_add(5), 15);

        // Saturates at MAX_LOCATION, not u64::MAX
        assert_eq!(Location::new(u64::MAX).saturating_add(1), MAX_LOCATION);
        assert_eq!(MAX_LOCATION.saturating_add(1), MAX_LOCATION);
        assert_eq!(MAX_LOCATION.saturating_add(1000), MAX_LOCATION);
    }

    #[test]
    fn test_saturating_sub() {
        let loc = Location::new(10);
        assert_eq!(loc.saturating_sub(5), 5);
        assert_eq!(Location::new(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        let location = Location::new(42);
        assert_eq!(location.to_string(), "Location(42)");
    }

    #[test]
    fn test_add() {
        let loc1 = Location::new(10);
        let loc2 = Location::new(5);
        assert_eq!((loc1 + loc2), 15);
    }

    #[test]
    fn test_sub() {
        let loc1 = Location::new(10);
        let loc2 = Location::new(3);
        assert_eq!((loc1 - loc2), 7);
    }

    #[test]
    fn test_comparison_with_u64() {
        let loc = Location::new(42);

        // Test equality
        assert_eq!(loc, 42u64);
        assert_eq!(42u64, loc);
        assert_ne!(loc, 43u64);
        assert_ne!(43u64, loc);

        // Test ordering
        assert!(loc < 43u64);
        assert!(43u64 > loc);
        assert!(loc > 41u64);
        assert!(41u64 < loc);
        assert!(loc <= 42u64);
        assert!(42u64 >= loc);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut loc = Location::new(10);

        // Test add assignment
        loc += 5;
        assert_eq!(loc, 15u64);

        // Test sub assignment
        loc -= 3;
        assert_eq!(loc, 12u64);
    }

    // Test the conversions between [Location] and raw integers.
    #[test]
    fn test_conversions() {
        assert_eq!(Location::from(42u64), Location::new(42));
        assert_eq!(Location::from(42usize), Location::new(42));
        assert_eq!(u64::from(Location::new(42)), 42u64);
        assert_eq!(Location::new(42).as_u64(), 42u64);
        assert_eq!(*Location::new(42), 42u64);
    }

    // Test that a `Range<Location>` maps bound-for-bound onto a `Range<usize>`.
    #[test]
    fn test_to_usize_range() {
        let range = Location::new(2)..Location::new(5);
        assert_eq!(range.to_usize_range(), 2usize..5usize);

        // An empty range stays empty.
        let empty = Location::new(7)..Location::new(7);
        assert_eq!(empty.to_usize_range(), 7usize..7usize);
    }

    #[test]
    fn test_is_valid() {
        assert!(Location::new(0).is_valid());
        assert!(Location::new(1000).is_valid());
        assert!(MAX_LOCATION.is_valid());
        assert!(!Location::new(u64::MAX).is_valid());
    }

    #[test]
    fn test_max_location_boundary() {
        // MAX_LOCATION (2^62) is the max leaf count. It should be valid and convert to
        // MAX_POSITION (2^63 - 1).
        assert!(MAX_LOCATION.is_valid());
        let pos = Position::try_from(MAX_LOCATION).unwrap();
        assert_eq!(pos, MAX_POSITION);
        assert!(pos.is_valid());

        // MAX_POSITION converts back to MAX_LOCATION (they are the same leaf).
        let loc = Location::try_from(pos).unwrap();
        assert_eq!(loc, MAX_LOCATION);
    }

    #[test]
    fn test_overflow_location_returns_error() {
        // MAX_LOCATION + 1 exceeds the valid range
        let over_loc = Location::new(*MAX_LOCATION + 1);
        assert!(!over_loc.is_valid());

        // Convert once and require the specific error variant carrying the offending location.
        assert!(
            matches!(
                Position::try_from(over_loc),
                Err(crate::mmr::Error::LocationOverflow(loc)) if loc == over_loc
            ),
            "expected LocationOverflow error"
        );
    }

    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Zero, a middle value, and the MAX_LOCATION boundary must all round-trip.
        for loc in [Location::new(0), Location::new(12345), MAX_LOCATION] {
            let encoded = loc.encode();
            let decoded = Location::read(&mut encoded.as_ref()).unwrap();
            assert_eq!(decoded, loc);
        }
    }

    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{varint::UInt, Encode, ReadExt};

        // Encode out-of-range values (MAX_LOCATION + 1 and u64::MAX) as raw varints, then try
        // to decode each as a Location.
        for invalid_value in [*MAX_LOCATION + 1, u64::MAX] {
            let encoded = UInt(invalid_value).encode();
            let result = Location::read(&mut encoded.as_ref());
            assert!(matches!(
                result,
                Err(commonware_codec::Error::Invalid("Location", _))
            ));
        }
    }
}