Skip to main content

commonware_storage/merkle/
position.rs

1use super::{location::Location, Family};
2use bytes::{Buf, BufMut};
3use commonware_codec::{varint::UInt, ReadExt};
4use core::{
5    fmt,
6    marker::PhantomData,
7    ops::{Add, AddAssign, Deref, Sub, SubAssign},
8};
9
10/// A [Position] is a node index or node count in a Merkle structure.
11/// This is in contrast to a [Location], which is a leaf index or leaf count.
12///
13/// # Limits
14///
15/// Values up to the family's maximum are valid (see [Position::is_valid]). As a 0-based node
16/// index, valid indices are `0..MAX - 1`. As a node count or total size, the maximum is `MAX`
17/// itself. Use [Position::is_valid_size] to ask whether a count is a structurally valid size for
18/// the specific Merkle family.
19pub struct Position<F: Family>(u64, PhantomData<F>);
20
#[cfg(feature = "arbitrary")]
impl<F: Family> arbitrary::Arbitrary<'_> for Position<F> {
    /// Generate a position drawn from the inclusive range `0..=MAX_NODES`, so every generated
    /// value satisfies [Position::is_valid].
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let upper = F::MAX_NODES.as_u64();
        u.int_in_range(0..=upper).map(Self::new)
    }
}
28
29impl<F: Family> Position<F> {
30    /// Return a new [Position] from a raw `u64`.
31    #[inline]
32    pub const fn new(pos: u64) -> Self {
33        Self(pos, PhantomData)
34    }
35
36    /// Return the underlying `u64` value.
37    #[inline]
38    pub const fn as_u64(self) -> u64 {
39        self.0
40    }
41
42    /// Returns `true` iff this value is a valid node count or size (`<= MAX_NODES`).
43    #[inline]
44    pub const fn is_valid(self) -> bool {
45        self.0 <= F::MAX_NODES.as_u64()
46    }
47
48    /// Returns `true` iff this value is a valid 0-based node index (`< MAX_NODES`).
49    #[inline]
50    pub const fn is_valid_index(self) -> bool {
51        self.0 < F::MAX_NODES.as_u64()
52    }
53
54    /// Return `self + rhs` returning `None` on overflow or if result exceeds the maximum.
55    #[inline]
56    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
57        match self.0.checked_add(rhs) {
58            Some(value) => {
59                if value <= F::MAX_NODES.as_u64() {
60                    Some(Self::new(value))
61                } else {
62                    None
63                }
64            }
65            None => None,
66        }
67    }
68
69    /// Return `self - rhs` returning `None` on underflow.
70    #[inline]
71    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
72        match self.0.checked_sub(rhs) {
73            Some(value) => Some(Self::new(value)),
74            None => None,
75        }
76    }
77
78    /// Return `self + rhs` saturating at the maximum.
79    #[inline]
80    pub const fn saturating_add(self, rhs: u64) -> Self {
81        let result = self.0.saturating_add(rhs);
82        if result > F::MAX_NODES.as_u64() {
83            F::MAX_NODES
84        } else {
85            Self::new(result)
86        }
87    }
88
89    /// Return `self - rhs` saturating at zero.
90    #[inline]
91    pub const fn saturating_sub(self, rhs: u64) -> Self {
92        Self::new(self.0.saturating_sub(rhs))
93    }
94
95    /// Returns whether this is a valid size for this Merkle structure.
96    #[inline]
97    pub fn is_valid_size(self) -> bool {
98        F::is_valid_size(self)
99    }
100}
101
102// --- Manual trait implementations (to avoid unnecessary bounds on F) ---
103
104impl<F: Family> Copy for Position<F> {}
105
106impl<F: Family> Clone for Position<F> {
107    #[inline]
108    fn clone(&self) -> Self {
109        *self
110    }
111}
112
113impl<F: Family> PartialEq for Position<F> {
114    #[inline]
115    fn eq(&self, other: &Self) -> bool {
116        self.0 == other.0
117    }
118}
119
120impl<F: Family> Eq for Position<F> {}
121
122impl<F: Family> PartialOrd for Position<F> {
123    #[inline]
124    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
125        Some(self.cmp(other))
126    }
127}
128
129impl<F: Family> Ord for Position<F> {
130    #[inline]
131    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
132        self.0.cmp(&other.0)
133    }
134}
135
136impl<F: Family> core::hash::Hash for Position<F> {
137    #[inline]
138    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
139        self.0.hash(state);
140    }
141}
142
143impl<F: Family> Default for Position<F> {
144    #[inline]
145    fn default() -> Self {
146        Self::new(0)
147    }
148}
149
150impl<F: Family> fmt::Debug for Position<F> {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_tuple("Position").field(&self.0).finish()
153    }
154}
155
156impl<F: Family> fmt::Display for Position<F> {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        write!(f, "Position({})", self.0)
159    }
160}
161
162impl<F: Family> Deref for Position<F> {
163    type Target = u64;
164    fn deref(&self) -> &Self::Target {
165        &self.0
166    }
167}
168
169impl<F: Family> AsRef<u64> for Position<F> {
170    fn as_ref(&self) -> &u64 {
171        &self.0
172    }
173}
174
175impl<F: Family> From<u64> for Position<F> {
176    #[inline]
177    fn from(value: u64) -> Self {
178        Self::new(value)
179    }
180}
181
182impl<F: Family> From<usize> for Position<F> {
183    #[inline]
184    fn from(value: usize) -> Self {
185        Self::new(value as u64)
186    }
187}
188
189impl<F: Family> From<Position<F>> for u64 {
190    #[inline]
191    fn from(position: Position<F>) -> Self {
192        *position
193    }
194}
195
196/// Convert a leaf [Location] to its corresponding node [Position].
197///
198/// Equivalently, convert a leaf count to the corresponding total node count (size).
199///
200/// Returns [`super::Error::LocationOverflow`] if `!loc.is_valid()`.
201impl<F: Family> TryFrom<Location<F>> for Position<F> {
202    type Error = super::Error<F>;
203
204    #[inline]
205    fn try_from(loc: Location<F>) -> Result<Self, Self::Error> {
206        if !loc.is_valid() {
207            return Err(super::Error::LocationOverflow(loc));
208        }
209        Ok(F::location_to_position(loc))
210    }
211}
212
213// --- Arithmetic operators ---
214
215/// Add two positions together.
216///
217/// # Panics
218///
219/// Panics if the result overflows.
220impl<F: Family> Add for Position<F> {
221    type Output = Self;
222
223    #[inline]
224    fn add(self, rhs: Self) -> Self::Output {
225        Self::new(self.0 + rhs.0)
226    }
227}
228
229/// Add a position and a `u64`.
230///
231/// # Panics
232///
233/// Panics if the result overflows.
234impl<F: Family> Add<u64> for Position<F> {
235    type Output = Self;
236
237    #[inline]
238    fn add(self, rhs: u64) -> Self::Output {
239        Self::new(self.0 + rhs)
240    }
241}
242
243/// Subtract two positions.
244///
245/// # Panics
246///
247/// Panics if the result underflows.
248impl<F: Family> Sub for Position<F> {
249    type Output = Self;
250
251    #[inline]
252    fn sub(self, rhs: Self) -> Self::Output {
253        Self::new(self.0 - rhs.0)
254    }
255}
256
257/// Subtract a `u64` from a position.
258///
259/// # Panics
260///
261/// Panics if the result underflows.
262impl<F: Family> Sub<u64> for Position<F> {
263    type Output = Self;
264
265    #[inline]
266    fn sub(self, rhs: u64) -> Self::Output {
267        Self::new(*self - rhs)
268    }
269}
270
271impl<F: Family> PartialEq<u64> for Position<F> {
272    #[inline]
273    fn eq(&self, other: &u64) -> bool {
274        self.0 == *other
275    }
276}
277
278impl<F: Family> PartialOrd<u64> for Position<F> {
279    #[inline]
280    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
281        self.0.partial_cmp(other)
282    }
283}
284
285impl<F: Family> PartialEq<Position<F>> for u64 {
286    #[inline]
287    fn eq(&self, other: &Position<F>) -> bool {
288        *self == other.0
289    }
290}
291
292impl<F: Family> PartialOrd<Position<F>> for u64 {
293    #[inline]
294    fn partial_cmp(&self, other: &Position<F>) -> Option<core::cmp::Ordering> {
295        self.partial_cmp(&other.0)
296    }
297}
298
299/// Add a `u64` to a position.
300///
301/// # Panics
302///
303/// Panics if the result overflows.
304impl<F: Family> AddAssign<u64> for Position<F> {
305    #[inline]
306    fn add_assign(&mut self, rhs: u64) {
307        self.0 += rhs;
308    }
309}
310
311/// Subtract a `u64` from a position.
312///
313/// # Panics
314///
315/// Panics if the result underflows.
316impl<F: Family> SubAssign<u64> for Position<F> {
317    #[inline]
318    fn sub_assign(&mut self, rhs: u64) {
319        self.0 -= rhs;
320    }
321}
322
323// --- Codec implementations using varint encoding ---
324
325impl<F: Family> commonware_codec::Write for Position<F> {
326    #[inline]
327    fn write(&self, buf: &mut impl BufMut) {
328        UInt(self.0).write(buf);
329    }
330}
331
332impl<F: Family> commonware_codec::EncodeSize for Position<F> {
333    #[inline]
334    fn encode_size(&self) -> usize {
335        UInt(self.0).encode_size()
336    }
337}
338
339impl<F: Family> commonware_codec::Read for Position<F> {
340    type Cfg = ();
341
342    #[inline]
343    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
344        let pos = Self::new(UInt::read(buf)?.into());
345        if pos.is_valid() {
346            Ok(pos)
347        } else {
348            Err(commonware_codec::Error::Invalid(
349                "Position",
350                "value exceeds MAX_NODES",
351            ))
352        }
353    }
354}
#[cfg(test)]
mod tests {
    use super::{Location as GenericLocation, Position as GenericPosition};
    use crate::{
        merkle::Family as _,
        mmr::{self, mem::Mmr, StandardHasher as Standard},
    };
    use commonware_cryptography::Sha256;

    type Location = GenericLocation<mmr::Family>;
    type Position = GenericPosition<mmr::Family>;

    // Test that [Position::try_from] returns the correct position for leaf locations.
    #[test]
    fn test_from_location() {
        const CASES: &[(Location, Position)] = &[
            (Location::new(0), Position::new(0)),
            (Location::new(1), Position::new(1)),
            (Location::new(2), Position::new(3)),
            (Location::new(3), Position::new(4)),
            (Location::new(4), Position::new(7)),
            (Location::new(5), Position::new(8)),
            (Location::new(6), Position::new(10)),
            (Location::new(7), Position::new(11)),
            (Location::new(8), Position::new(15)),
            (Location::new(9), Position::new(16)),
            (Location::new(10), Position::new(18)),
            (Location::new(11), Position::new(19)),
            (Location::new(12), Position::new(22)),
            (Location::new(13), Position::new(23)),
            (Location::new(14), Position::new(25)),
            (Location::new(15), Position::new(26)),
        ];
        for (loc, expected_pos) in CASES {
            let pos = Position::try_from(*loc).unwrap();
            assert_eq!(pos, *expected_pos);
        }
    }

    #[test]
    fn test_is_valid_and_is_valid_index() {
        let max = *mmr::Family::MAX_NODES;

        // As a count/size, MAX_NODES itself is valid (inclusive bound).
        assert!(Position::new(0).is_valid());
        assert!(Position::new(max).is_valid());
        assert!(!Position::new(max + 1).is_valid());
        assert!(!Position::new(u64::MAX).is_valid());

        // As a 0-based index, MAX_NODES is one past the last valid index.
        assert!(Position::new(0).is_valid_index());
        assert!(Position::new(max - 1).is_valid_index());
        assert!(!Position::new(max).is_valid_index());
    }

    #[test]
    fn test_checked_add() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_add(5).unwrap(), 15);

        // Overflow returns None
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_NODES returns None, but MAX_NODES itself IS valid (inclusive bound)
        assert!(mmr::Family::MAX_NODES.checked_add(1).is_none());
        assert!(Position::new(*mmr::Family::MAX_NODES - 5)
            .checked_add(10)
            .is_none());
        // MAX_NODES - 10 + 10 = MAX_NODES, which IS valid (inclusive bound)
        assert_eq!(
            Position::new(*mmr::Family::MAX_NODES - 10)
                .checked_add(10)
                .unwrap(),
            *mmr::Family::MAX_NODES
        );

        // MAX_NODES - 11 + 10 = MAX_NODES - 1, also valid
        assert_eq!(
            Position::new(*mmr::Family::MAX_NODES - 11)
                .checked_add(10)
                .unwrap(),
            *mmr::Family::MAX_NODES - 1
        );
    }

    #[test]
    fn test_checked_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_sub(5).unwrap(), 5);
        assert!(pos.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_add(5), 15);

        // Saturates AT MAX_NODES (inclusive bound)
        assert_eq!(
            Position::new(u64::MAX).saturating_add(1),
            *mmr::Family::MAX_NODES
        );
        assert_eq!(
            mmr::Family::MAX_NODES.saturating_add(1),
            *mmr::Family::MAX_NODES
        );
        assert_eq!(
            mmr::Family::MAX_NODES.saturating_add(1000),
            *mmr::Family::MAX_NODES
        );
        assert_eq!(
            Position::new(*mmr::Family::MAX_NODES - 5).saturating_add(10),
            *mmr::Family::MAX_NODES
        );
    }

    #[test]
    fn test_saturating_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_sub(5), 5);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        let position = Position::new(42);
        assert_eq!(position.to_string(), "Position(42)");
    }

    #[test]
    fn test_add() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(5);
        assert_eq!((pos1 + pos2), 15);
    }

    #[test]
    fn test_sub() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(3);
        assert_eq!((pos1 - pos2), 7);
    }

    #[test]
    fn test_add_sub_u64() {
        let pos = Position::new(10);
        assert_eq!(pos + 5u64, 15u64);
        assert_eq!(pos - 4u64, 6u64);
        assert_eq!(pos - 10u64, 0u64);
    }

    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);

        // Test equality
        assert_eq!(pos, 42u64);
        assert_eq!(42u64, pos);
        assert_ne!(pos, 43u64);
        assert_ne!(43u64, pos);

        // Test ordering
        assert!(pos < 43u64);
        assert!(43u64 > pos);
        assert!(pos > 41u64);
        assert!(41u64 < pos);
        assert!(pos <= 42u64);
        assert!(42u64 >= pos);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        // Test add assignment
        pos += 5;
        assert_eq!(pos, 15u64);

        // Test sub assignment
        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    #[test]
    fn test_max_position() {
        // MAX_NODES = max MMR size = 2^63 - 1 (for 2^62 leaves).
        let max_leaves = 1u64 << 62;
        let max_size = 2 * max_leaves - 1; // 2^63 - 1
        assert_eq!(*mmr::Family::MAX_NODES, max_size);
        assert_eq!(*mmr::Family::MAX_NODES, (1u64 << 63) - 1);
        assert_eq!(max_size.leading_zeros(), 1); // top bit clear

        // One more leaf would overflow: size = 2^63, top bit set.
        let overflow_size = 2 * (max_leaves + 1) - 1;
        assert_eq!(overflow_size.leading_zeros(), 0);

        // MAX_LEAVES is a valid location (inclusive bound) and converts to MAX_NODES.
        let pos = Position::try_from(mmr::Family::MAX_LEAVES).unwrap();
        assert_eq!(pos, mmr::Family::MAX_NODES);
    }

    #[test]
    fn test_is_valid_size() {
        // Build an MMR one node at a time and check that the validity check is correct for all
        // sizes up to the current size.
        let mut size_to_check = Position::new(0);
        let hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&hasher);
        let digest = [1u8; 32];
        for _i in 0..10000 {
            while size_to_check != mmr.size() {
                assert!(
                    !size_to_check.is_valid_size(),
                    "size_to_check: {} {}",
                    size_to_check,
                    mmr.size()
                );
                size_to_check += 1;
            }
            assert!(size_to_check.is_valid_size());
            let batch = mmr
                .new_batch()
                .add(&hasher, &digest)
                .merkleize(&mmr, &hasher);
            mmr.apply_batch(&batch).unwrap();
            size_to_check += 1;
        }

        // Test overflow boundaries.
        assert!(!Position::new(u64::MAX).is_valid_size());
        assert!(Position::new(u64::MAX >> 1).is_valid_size()); // 2^63 - 1 = MAX_NODES
        assert!(!Position::new((u64::MAX >> 1) + 1).is_valid_size());
        assert!(mmr::Family::MAX_NODES.is_valid_size()); // MAX_NODES is the largest valid MMR size
    }

    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Test zero
        let pos = Position::new(0);
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);

        // Test middle value
        let pos = Position::new(12345);
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);

        // MAX_NODES is a valid value (inclusive bound), so it should decode successfully
        let pos = mmr::Family::MAX_NODES;
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);

        // MAX_NODES - 1 is also valid
        let pos = mmr::Family::MAX_NODES - 1;
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);
    }

    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{varint::UInt, Encode, ReadExt};

        // Encode MAX_NODES + 1 as a raw varint, then try to decode as Position
        let invalid_value = *mmr::Family::MAX_NODES + 1;
        let encoded = UInt(invalid_value).encode();
        let result = Position::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Position", _))
        ));

        // Encode u64::MAX as a raw varint
        let encoded = UInt(u64::MAX).encode();
        let result = Position::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Position", _))
        ));
    }
}