commonware_storage/merkle/
position.rs1use super::{location::Location, Family};
2use bytes::{Buf, BufMut};
3use commonware_codec::{varint::UInt, ReadExt};
4use core::{
5 fmt,
6 marker::PhantomData,
7 ops::{Add, AddAssign, Deref, Sub, SubAssign},
8};
9
10pub struct Position<F: Family>(u64, PhantomData<F>);
20
#[cfg(feature = "arbitrary")]
impl<F: Family> arbitrary::Arbitrary<'_> for Position<F> {
    /// Produces a position whose raw value lies in `0..=F::MAX_NODES` (inclusive).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let upper = F::MAX_NODES.as_u64();
        u.int_in_range(0..=upper).map(Self::new)
    }
}
28
29impl<F: Family> Position<F> {
30 #[inline]
32 pub const fn new(pos: u64) -> Self {
33 Self(pos, PhantomData)
34 }
35
36 #[inline]
38 pub const fn as_u64(self) -> u64 {
39 self.0
40 }
41
42 #[inline]
44 pub const fn is_valid(self) -> bool {
45 self.0 <= F::MAX_NODES.as_u64()
46 }
47
48 #[inline]
50 pub const fn is_valid_index(self) -> bool {
51 self.0 < F::MAX_NODES.as_u64()
52 }
53
54 #[inline]
56 pub const fn checked_add(self, rhs: u64) -> Option<Self> {
57 match self.0.checked_add(rhs) {
58 Some(value) => {
59 if value <= F::MAX_NODES.as_u64() {
60 Some(Self::new(value))
61 } else {
62 None
63 }
64 }
65 None => None,
66 }
67 }
68
69 #[inline]
71 pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
72 match self.0.checked_sub(rhs) {
73 Some(value) => Some(Self::new(value)),
74 None => None,
75 }
76 }
77
78 #[inline]
80 pub const fn saturating_add(self, rhs: u64) -> Self {
81 let result = self.0.saturating_add(rhs);
82 if result > F::MAX_NODES.as_u64() {
83 F::MAX_NODES
84 } else {
85 Self::new(result)
86 }
87 }
88
89 #[inline]
91 pub const fn saturating_sub(self, rhs: u64) -> Self {
92 Self::new(self.0.saturating_sub(rhs))
93 }
94
95 #[inline]
97 pub fn is_valid_size(self) -> bool {
98 F::is_valid_size(self)
99 }
100}
101
102impl<F: Family> Copy for Position<F> {}
105
106impl<F: Family> Clone for Position<F> {
107 #[inline]
108 fn clone(&self) -> Self {
109 *self
110 }
111}
112
113impl<F: Family> PartialEq for Position<F> {
114 #[inline]
115 fn eq(&self, other: &Self) -> bool {
116 self.0 == other.0
117 }
118}
119
120impl<F: Family> Eq for Position<F> {}
121
122impl<F: Family> PartialOrd for Position<F> {
123 #[inline]
124 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
125 Some(self.cmp(other))
126 }
127}
128
129impl<F: Family> Ord for Position<F> {
130 #[inline]
131 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
132 self.0.cmp(&other.0)
133 }
134}
135
136impl<F: Family> core::hash::Hash for Position<F> {
137 #[inline]
138 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
139 self.0.hash(state);
140 }
141}
142
143impl<F: Family> Default for Position<F> {
144 #[inline]
145 fn default() -> Self {
146 Self::new(0)
147 }
148}
149
150impl<F: Family> fmt::Debug for Position<F> {
151 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152 f.debug_tuple("Position").field(&self.0).finish()
153 }
154}
155
156impl<F: Family> fmt::Display for Position<F> {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 write!(f, "Position({})", self.0)
159 }
160}
161
162impl<F: Family> Deref for Position<F> {
163 type Target = u64;
164 fn deref(&self) -> &Self::Target {
165 &self.0
166 }
167}
168
169impl<F: Family> AsRef<u64> for Position<F> {
170 fn as_ref(&self) -> &u64 {
171 &self.0
172 }
173}
174
175impl<F: Family> From<u64> for Position<F> {
176 #[inline]
177 fn from(value: u64) -> Self {
178 Self::new(value)
179 }
180}
181
182impl<F: Family> From<usize> for Position<F> {
183 #[inline]
184 fn from(value: usize) -> Self {
185 Self::new(value as u64)
186 }
187}
188
189impl<F: Family> From<Position<F>> for u64 {
190 #[inline]
191 fn from(position: Position<F>) -> Self {
192 *position
193 }
194}
195
196impl<F: Family> TryFrom<Location<F>> for Position<F> {
202 type Error = super::Error<F>;
203
204 #[inline]
205 fn try_from(loc: Location<F>) -> Result<Self, Self::Error> {
206 if !loc.is_valid() {
207 return Err(super::Error::LocationOverflow(loc));
208 }
209 Ok(F::location_to_position(loc))
210 }
211}
212
213impl<F: Family> Add for Position<F> {
221 type Output = Self;
222
223 #[inline]
224 fn add(self, rhs: Self) -> Self::Output {
225 Self::new(self.0 + rhs.0)
226 }
227}
228
229impl<F: Family> Add<u64> for Position<F> {
235 type Output = Self;
236
237 #[inline]
238 fn add(self, rhs: u64) -> Self::Output {
239 Self::new(self.0 + rhs)
240 }
241}
242
243impl<F: Family> Sub for Position<F> {
249 type Output = Self;
250
251 #[inline]
252 fn sub(self, rhs: Self) -> Self::Output {
253 Self::new(self.0 - rhs.0)
254 }
255}
256
257impl<F: Family> Sub<u64> for Position<F> {
263 type Output = Self;
264
265 #[inline]
266 fn sub(self, rhs: u64) -> Self::Output {
267 Self::new(*self - rhs)
268 }
269}
270
271impl<F: Family> PartialEq<u64> for Position<F> {
272 #[inline]
273 fn eq(&self, other: &u64) -> bool {
274 self.0 == *other
275 }
276}
277
278impl<F: Family> PartialOrd<u64> for Position<F> {
279 #[inline]
280 fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
281 self.0.partial_cmp(other)
282 }
283}
284
285impl<F: Family> PartialEq<Position<F>> for u64 {
286 #[inline]
287 fn eq(&self, other: &Position<F>) -> bool {
288 *self == other.0
289 }
290}
291
292impl<F: Family> PartialOrd<Position<F>> for u64 {
293 #[inline]
294 fn partial_cmp(&self, other: &Position<F>) -> Option<core::cmp::Ordering> {
295 self.partial_cmp(&other.0)
296 }
297}
298
299impl<F: Family> AddAssign<u64> for Position<F> {
305 #[inline]
306 fn add_assign(&mut self, rhs: u64) {
307 self.0 += rhs;
308 }
309}
310
311impl<F: Family> SubAssign<u64> for Position<F> {
317 #[inline]
318 fn sub_assign(&mut self, rhs: u64) {
319 self.0 -= rhs;
320 }
321}
322
323impl<F: Family> commonware_codec::Write for Position<F> {
326 #[inline]
327 fn write(&self, buf: &mut impl BufMut) {
328 UInt(self.0).write(buf);
329 }
330}
331
332impl<F: Family> commonware_codec::EncodeSize for Position<F> {
333 #[inline]
334 fn encode_size(&self) -> usize {
335 UInt(self.0).encode_size()
336 }
337}
338
339impl<F: Family> commonware_codec::Read for Position<F> {
340 type Cfg = ();
341
342 #[inline]
343 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
344 let pos = Self::new(UInt::read(buf)?.into());
345 if pos.is_valid() {
346 Ok(pos)
347 } else {
348 Err(commonware_codec::Error::Invalid(
349 "Position",
350 "value exceeds MAX_NODES",
351 ))
352 }
353 }
354}
#[cfg(test)]
mod tests {
    use super::{Location as GenericLocation, Position as GenericPosition};
    use crate::{
        merkle::Family as _,
        mmr::{self, mem::Mmr, StandardHasher as Standard},
    };
    use commonware_cryptography::Sha256;

    type Location = GenericLocation<mmr::Family>;
    type Position = GenericPosition<mmr::Family>;

    /// Checks the location-to-position mapping for the first sixteen locations.
    #[test]
    fn test_from_location() {
        // (raw location, expected raw position)
        const CASES: &[(u64, u64)] = &[
            (0, 0),
            (1, 1),
            (2, 3),
            (3, 4),
            (4, 7),
            (5, 8),
            (6, 10),
            (7, 11),
            (8, 15),
            (9, 16),
            (10, 18),
            (11, 19),
            (12, 22),
            (13, 23),
            (14, 25),
            (15, 26),
        ];
        for &(raw_loc, raw_pos) in CASES {
            let converted = Position::try_from(Location::new(raw_loc)).unwrap();
            assert_eq!(converted, Position::new(raw_pos));
        }
    }

    /// `checked_add` must reject `u64` overflow and any sum above `MAX_NODES`.
    #[test]
    fn test_checked_add() {
        let max = mmr::Family::MAX_NODES;
        let max_raw = *max;

        assert_eq!(Position::new(10).checked_add(5).unwrap(), 15);

        // Overflow of the underlying u64.
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Sums that exceed MAX_NODES.
        assert!(max.checked_add(1).is_none());
        assert!(Position::new(max_raw - 5).checked_add(10).is_none());

        // Landing exactly on MAX_NODES is allowed.
        assert_eq!(
            Position::new(max_raw - 10).checked_add(10).unwrap(),
            max_raw
        );

        // Landing just below MAX_NODES is allowed.
        assert_eq!(
            Position::new(max_raw - 11).checked_add(10).unwrap(),
            max_raw - 1
        );
    }

    /// `checked_sub` must return `None` when the result would be negative.
    #[test]
    fn test_checked_sub() {
        let start = Position::new(10);
        assert_eq!(start.checked_sub(5).unwrap(), 5);
        assert!(start.checked_sub(11).is_none());
    }

    /// `saturating_add` must clamp to `MAX_NODES`.
    #[test]
    fn test_saturating_add() {
        let max = mmr::Family::MAX_NODES;

        assert_eq!(Position::new(10).saturating_add(5), 15);

        // Every one of these must clamp to MAX_NODES.
        let clamped: [(Position, u64); 4] = [
            (Position::new(u64::MAX), 1),
            (max, 1),
            (max, 1000),
            (Position::new(*max - 5), 10),
        ];
        for (start, delta) in clamped {
            assert_eq!(start.saturating_add(delta), *max);
        }
    }

    /// `saturating_sub` must clamp at zero.
    #[test]
    fn test_saturating_sub() {
        assert_eq!(Position::new(10).saturating_sub(5), 5);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    /// `Display` renders as `Position(<raw>)`.
    #[test]
    fn test_display() {
        assert_eq!(Position::new(42).to_string(), "Position(42)");
    }

    /// Adding two positions adds their raw values.
    #[test]
    fn test_add() {
        let sum = Position::new(10) + Position::new(5);
        assert_eq!(sum, 15);
    }

    /// Subtracting two positions subtracts their raw values.
    #[test]
    fn test_sub() {
        let diff = Position::new(10) - Position::new(3);
        assert_eq!(diff, 7);
    }

    /// Equality and ordering against bare `u64` values work in both directions.
    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);

        // Equality, both operand orders.
        assert_eq!(pos, 42u64);
        assert_eq!(42u64, pos);
        assert_ne!(pos, 43u64);
        assert_ne!(43u64, pos);

        // Ordering, both operand orders.
        assert!(pos < 43u64);
        assert!(43u64 > pos);
        assert!(pos > 41u64);
        assert!(41u64 < pos);
        assert!(pos <= 42u64);
        assert!(42u64 >= pos);
    }

    /// `+=` and `-=` with a `u64` update the raw value in place.
    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        pos += 5;
        assert_eq!(pos, 15u64);

        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    /// Pins the numeric value of `MAX_NODES` and its relation to `MAX_LEAVES`.
    #[test]
    fn test_max_position() {
        let max_leaves = 1u64 << 62;
        let max_size = 2 * max_leaves - 1;
        let max_nodes_raw = *mmr::Family::MAX_NODES;

        assert_eq!(max_nodes_raw, max_size);
        assert_eq!(max_nodes_raw, (1u64 << 63) - 1);
        assert_eq!(max_size.leading_zeros(), 1);

        // One more leaf would need the top bit of the u64.
        let overflow_size = 2 * (max_leaves + 1) - 1;
        assert_eq!(overflow_size.leading_zeros(), 0);

        let pos = Position::try_from(mmr::Family::MAX_LEAVES).unwrap();
        assert_eq!(pos, mmr::Family::MAX_NODES);
    }

    /// Grows an MMR leaf by leaf and checks that exactly the sizes it passes through
    /// are reported as valid.
    #[test]
    fn test_is_valid_size() {
        let mut size_to_check = Position::new(0);
        let hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&hasher);
        let digest = [1u8; 32];

        for _ in 0..10_000 {
            // Every size skipped between two consecutive MMR sizes must be invalid.
            while size_to_check != mmr.size() {
                assert!(
                    !size_to_check.is_valid_size(),
                    "size_to_check: {} {}",
                    size_to_check,
                    mmr.size()
                );
                size_to_check += 1;
            }
            assert!(size_to_check.is_valid_size());

            let batch = mmr
                .new_batch()
                .add(&hasher, &digest)
                .merkleize(&mmr, &hasher);
            mmr.apply_batch(&batch).unwrap();
            size_to_check += 1;
        }

        // Boundary values around the top of the range.
        assert!(!Position::new(u64::MAX).is_valid_size());
        assert!(Position::new(u64::MAX >> 1).is_valid_size());
        assert!(!Position::new((u64::MAX >> 1) + 1).is_valid_size());
        assert!(mmr::Family::MAX_NODES.is_valid_size());
    }

    /// In-range positions survive an encode/decode round trip.
    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        let samples = [
            Position::new(0),
            Position::new(12345),
            mmr::Family::MAX_NODES,
            mmr::Family::MAX_NODES - 1,
        ];
        for pos in samples {
            let encoded = pos.encode();
            let decoded = Position::read(&mut encoded.as_ref()).unwrap();
            assert_eq!(decoded, pos);
        }
    }

    /// Values above `MAX_NODES` are rejected on decode with `Error::Invalid`.
    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{varint::UInt, Encode, ReadExt};

        let too_large = [*mmr::Family::MAX_NODES + 1, u64::MAX];
        for raw in too_large {
            let encoded = UInt(raw).encode();
            let result = Position::read(&mut encoded.as_ref());
            assert!(result.is_err());
            assert!(matches!(
                result,
                Err(commonware_codec::Error::Invalid("Position", _))
            ));
        }
    }
}