Skip to main content

commonware_storage/merkle/
location.rs

1use super::{position::Position, Family};
2use bytes::{Buf, BufMut};
3use commonware_codec::{varint::UInt, ReadExt};
4use core::{
5    convert::TryFrom,
6    fmt,
7    marker::PhantomData,
8    ops::{Add, AddAssign, Deref, Range, Sub, SubAssign},
9};
10
11/// A [Location] is a leaf index or leaf count in a Merkle structure.
12/// This is in contrast to a [Position], which is a node index or node count.
13///
14/// # Limits
15///
16/// Values up to the family's maximum are valid (see [Location::is_valid]). As a 0-based leaf
17/// index, valid indices are `0..MAX - 1`. As a leaf count or exclusive range-end, the maximum
18/// is `MAX` itself.
19pub struct Location<F: Family>(u64, PhantomData<F>);
20
#[cfg(feature = "arbitrary")]
impl<F: Family> arbitrary::Arbitrary<'_> for Location<F> {
    /// Produce a [Location] whose value lies in `0..=F::MAX_LEAVES`, so every generated value
    /// satisfies [Location::is_valid].
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let upper = F::MAX_LEAVES.as_u64();
        u.int_in_range(0..=upper).map(Self::new)
    }
}
28
29impl<F: Family> Location<F> {
30    /// Return a new [Location] from a raw `u64`.
31    #[inline]
32    pub const fn new(loc: u64) -> Self {
33        Self(loc, PhantomData)
34    }
35
36    /// Return the underlying `u64` value.
37    #[inline]
38    pub const fn as_u64(self) -> u64 {
39        self.0
40    }
41
42    /// Returns `true` iff this value is a valid leaf count (`<= MAX_LEAVES`).
43    #[inline]
44    pub const fn is_valid(self) -> bool {
45        self.0 <= F::MAX_LEAVES.as_u64()
46    }
47
48    /// Returns `true` iff this value is a valid 0-based leaf index (`< MAX_LEAVES`).
49    #[inline]
50    pub const fn is_valid_index(self) -> bool {
51        self.0 < F::MAX_LEAVES.as_u64()
52    }
53
54    /// Return `self + rhs` returning `None` on overflow or if result exceeds the maximum.
55    #[inline]
56    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
57        match self.0.checked_add(rhs) {
58            Some(value) => {
59                if value <= F::MAX_LEAVES.as_u64() {
60                    Some(Self::new(value))
61                } else {
62                    None
63                }
64            }
65            None => None,
66        }
67    }
68
69    /// Return `self - rhs` returning `None` on underflow.
70    #[inline]
71    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
72        match self.0.checked_sub(rhs) {
73            Some(value) => Some(Self::new(value)),
74            None => None,
75        }
76    }
77
78    /// Return `self + rhs` saturating at the maximum.
79    #[inline]
80    pub const fn saturating_add(self, rhs: u64) -> Self {
81        let result = self.0.saturating_add(rhs);
82        if result > F::MAX_LEAVES.as_u64() {
83            F::MAX_LEAVES
84        } else {
85            Self::new(result)
86        }
87    }
88
89    /// Return `self - rhs` saturating at zero.
90    #[inline]
91    pub const fn saturating_sub(self, rhs: u64) -> Self {
92        Self::new(self.0.saturating_sub(rhs))
93    }
94}
95
96// --- Manual trait implementations (to avoid unnecessary bounds on F) ---
97
98impl<F: Family> Copy for Location<F> {}
99
100impl<F: Family> Clone for Location<F> {
101    #[inline]
102    fn clone(&self) -> Self {
103        *self
104    }
105}
106
107impl<F: Family> PartialEq for Location<F> {
108    #[inline]
109    fn eq(&self, other: &Self) -> bool {
110        self.0 == other.0
111    }
112}
113
114impl<F: Family> Eq for Location<F> {}
115
116impl<F: Family> PartialOrd for Location<F> {
117    #[inline]
118    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
119        Some(self.cmp(other))
120    }
121}
122
123impl<F: Family> Ord for Location<F> {
124    #[inline]
125    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
126        self.0.cmp(&other.0)
127    }
128}
129
130impl<F: Family> core::hash::Hash for Location<F> {
131    #[inline]
132    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
133        self.0.hash(state);
134    }
135}
136
137impl<F: Family> Default for Location<F> {
138    #[inline]
139    fn default() -> Self {
140        Self::new(0)
141    }
142}
143
144impl<F: Family> fmt::Debug for Location<F> {
145    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
146        f.debug_tuple("Location").field(&self.0).finish()
147    }
148}
149
150impl<F: Family> fmt::Display for Location<F> {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        write!(f, "Location({})", self.0)
153    }
154}
155
156impl<F: Family> Deref for Location<F> {
157    type Target = u64;
158    fn deref(&self) -> &Self::Target {
159        &self.0
160    }
161}
162
163impl<F: Family> From<u64> for Location<F> {
164    #[inline]
165    fn from(value: u64) -> Self {
166        Self::new(value)
167    }
168}
169
170impl<F: Family> From<usize> for Location<F> {
171    #[inline]
172    fn from(value: usize) -> Self {
173        Self::new(value as u64)
174    }
175}
176
177impl<F: Family> From<Location<F>> for u64 {
178    #[inline]
179    fn from(loc: Location<F>) -> Self {
180        *loc
181    }
182}
183
184// --- Codec implementations using varint encoding ---
185
186impl<F: Family> commonware_codec::Write for Location<F> {
187    #[inline]
188    fn write(&self, buf: &mut impl BufMut) {
189        UInt(self.0).write(buf);
190    }
191}
192
193impl<F: Family> commonware_codec::EncodeSize for Location<F> {
194    #[inline]
195    fn encode_size(&self) -> usize {
196        UInt(self.0).encode_size()
197    }
198}
199
200impl<F: Family> commonware_codec::Read for Location<F> {
201    type Cfg = ();
202
203    #[inline]
204    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
205        let loc = Self::new(UInt::read(buf)?.into());
206        if loc.is_valid() {
207            Ok(loc)
208        } else {
209            Err(commonware_codec::Error::Invalid(
210                "Location",
211                "value exceeds MAX_LEAVES",
212            ))
213        }
214    }
215}
216
217/// Attempt to derive the [Location] of a given node [Position].
218///
219/// Equivalently, convert a total node count (size) to the corresponding leaf count.
220///
221/// Returns an error if `pos` exceeds the valid range or if it is neither a leaf position nor a
222/// valid size.
223impl<F: Family> TryFrom<Position<F>> for Location<F> {
224    type Error = super::Error<F>;
225
226    #[inline]
227    fn try_from(pos: Position<F>) -> Result<Self, Self::Error> {
228        if !pos.is_valid() {
229            return Err(super::Error::PositionOverflow(pos));
230        }
231        F::position_to_location(pos).ok_or(super::Error::NonLeaf(pos))
232    }
233}
234
235// --- Arithmetic operators ---
236
237/// Add two locations together.
238///
239/// # Panics
240///
241/// Panics if the result overflows.
242impl<F: Family> Add for Location<F> {
243    type Output = Self;
244
245    #[inline]
246    fn add(self, rhs: Self) -> Self::Output {
247        Self::new(self.0 + rhs.0)
248    }
249}
250
251/// Add a location and a `u64`.
252///
253/// # Panics
254///
255/// Panics if the result overflows.
256impl<F: Family> Add<u64> for Location<F> {
257    type Output = Self;
258
259    #[inline]
260    fn add(self, rhs: u64) -> Self::Output {
261        Self::new(self.0 + rhs)
262    }
263}
264
265/// Subtract two locations.
266///
267/// # Panics
268///
269/// Panics if the result underflows.
270impl<F: Family> Sub for Location<F> {
271    type Output = Self;
272
273    #[inline]
274    fn sub(self, rhs: Self) -> Self::Output {
275        Self::new(self.0 - rhs.0)
276    }
277}
278
279/// Subtract a `u64` from a location.
280///
281/// # Panics
282///
283/// Panics if the result underflows.
284impl<F: Family> Sub<u64> for Location<F> {
285    type Output = Self;
286
287    #[inline]
288    fn sub(self, rhs: u64) -> Self::Output {
289        Self::new(self.0 - rhs)
290    }
291}
292
293impl<F: Family> PartialEq<u64> for Location<F> {
294    #[inline]
295    fn eq(&self, other: &u64) -> bool {
296        self.0 == *other
297    }
298}
299
300impl<F: Family> PartialOrd<u64> for Location<F> {
301    #[inline]
302    fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
303        self.0.partial_cmp(other)
304    }
305}
306
307impl<F: Family> PartialEq<Location<F>> for u64 {
308    #[inline]
309    fn eq(&self, other: &Location<F>) -> bool {
310        *self == other.0
311    }
312}
313
314impl<F: Family> PartialOrd<Location<F>> for u64 {
315    #[inline]
316    fn partial_cmp(&self, other: &Location<F>) -> Option<core::cmp::Ordering> {
317        self.partial_cmp(&other.0)
318    }
319}
320
321/// Add a `u64` to a location.
322///
323/// # Panics
324///
325/// Panics if the result overflows.
326impl<F: Family> AddAssign<u64> for Location<F> {
327    #[inline]
328    fn add_assign(&mut self, rhs: u64) {
329        self.0 += rhs;
330    }
331}
332
333/// Subtract a `u64` from a location.
334///
335/// # Panics
336///
337/// Panics if the result underflows.
338impl<F: Family> SubAssign<u64> for Location<F> {
339    #[inline]
340    fn sub_assign(&mut self, rhs: u64) {
341        self.0 -= rhs;
342    }
343}
344
/// Helpers for turning a `Range<Location>` into other kinds of ranges.
pub trait LocationRangeExt {
    /// Produce the equivalent `Range<usize>`, e.g. for indexing into a slice.
    fn to_usize_range(&self) -> Range<usize>;
}
350
351impl<F: Family> LocationRangeExt for Range<Location<F>> {
352    #[inline]
353    fn to_usize_range(&self) -> Range<usize> {
354        *self.start as usize..*self.end as usize
355    }
356}