1use super::position::Position;
2use crate::mmr::MAX_POSITION;
3use bytes::{Buf, BufMut};
4use commonware_codec::{Read, ReadExt};
5use core::{
6 convert::TryFrom,
7 fmt,
8 ops::{Add, AddAssign, Deref, Range, Sub, SubAssign},
9};
10use thiserror::Error;
11
/// Largest value a valid [Location] may hold: `2^62 - 1` (`0x3FFF_FFFF_FFFF_FFFF`).
pub const MAX_LOCATION: u64 = (1 << 62) - 1;

/// The 0-based ordinal of a leaf in an MMR, counting leaves only.
///
/// Contrast with [Position], which numbers every node (leaves and non-leaves).
/// Values above [MAX_LOCATION] can be constructed (e.g. through `From<u64>`), but
/// [Location::is_valid] reports them as invalid.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Location(u64);
54
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Location {
    /// Draws a value from the inclusive range `0..=MAX_LOCATION`, so every generated
    /// location is valid.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_LOCATION).map(Location)
    }
}
62
63impl Location {
64 #[inline]
69 pub(crate) const fn new_unchecked(loc: u64) -> Self {
70 Self(loc)
71 }
72
73 #[inline]
93 pub const fn new(loc: u64) -> Option<Self> {
94 if loc > MAX_LOCATION {
95 None
96 } else {
97 Some(Self(loc))
98 }
99 }
100
101 #[inline]
103 pub const fn as_u64(self) -> u64 {
104 self.0
105 }
106
107 #[inline]
109 pub const fn is_valid(self) -> bool {
110 self.0 <= MAX_LOCATION
111 }
112
113 #[inline]
115 pub const fn checked_add(self, rhs: u64) -> Option<Self> {
116 match self.0.checked_add(rhs) {
117 Some(value) => {
118 if value <= MAX_LOCATION {
119 Some(Self(value))
120 } else {
121 None
122 }
123 }
124 None => None,
125 }
126 }
127
128 #[inline]
130 pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
131 match self.0.checked_sub(rhs) {
132 Some(value) => Some(Self(value)),
133 None => None,
134 }
135 }
136
137 #[inline]
139 pub const fn saturating_add(self, rhs: u64) -> Self {
140 let result = self.0.saturating_add(rhs);
141 if result > MAX_LOCATION {
142 Self(MAX_LOCATION)
143 } else {
144 Self(result)
145 }
146 }
147
148 #[inline]
150 pub const fn saturating_sub(self, rhs: u64) -> Self {
151 Self(self.0.saturating_sub(rhs))
152 }
153}
154
155impl fmt::Display for Location {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 write!(f, "Location({})", self.0)
158 }
159}
160
161impl From<u64> for Location {
162 #[inline]
163 fn from(value: u64) -> Self {
164 Self::new_unchecked(value)
165 }
166}
167
168impl From<usize> for Location {
169 #[inline]
170 fn from(value: usize) -> Self {
171 Self::new_unchecked(value as u64)
172 }
173}
174
175impl Deref for Location {
176 type Target = u64;
177 fn deref(&self) -> &Self::Target {
178 &self.0
179 }
180}
181
182impl From<Location> for u64 {
183 #[inline]
184 fn from(loc: Location) -> Self {
185 *loc
186 }
187}
188
189impl commonware_codec::Write for Location {
191 #[inline]
192 fn write(&self, buf: &mut impl BufMut) {
193 commonware_codec::varint::UInt(self.0).write(buf);
194 }
195}
196
197impl commonware_codec::EncodeSize for Location {
198 #[inline]
199 fn encode_size(&self) -> usize {
200 commonware_codec::varint::UInt(self.0).encode_size()
201 }
202}
203
204impl Read for Location {
205 type Cfg = ();
206
207 #[inline]
208 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
209 let value: u64 = commonware_codec::varint::UInt::read(buf)?.into();
210 Self::new(value).ok_or(commonware_codec::Error::Invalid(
211 "Location",
212 "value exceeds MAX_LOCATION",
213 ))
214 }
215}
216
217impl Add for Location {
223 type Output = Self;
224
225 #[inline]
226 fn add(self, rhs: Self) -> Self::Output {
227 Self(self.0 + rhs.0)
228 }
229}
230
231impl Add<u64> for Location {
237 type Output = Self;
238
239 #[inline]
240 fn add(self, rhs: u64) -> Self::Output {
241 Self(self.0 + rhs)
242 }
243}
244
245impl Sub for Location {
251 type Output = Self;
252
253 #[inline]
254 fn sub(self, rhs: Self) -> Self::Output {
255 Self(self.0 - rhs.0)
256 }
257}
258
259impl Sub<u64> for Location {
265 type Output = Self;
266
267 #[inline]
268 fn sub(self, rhs: u64) -> Self::Output {
269 Self(self.0 - rhs)
270 }
271}
272
273impl PartialEq<u64> for Location {
274 #[inline]
275 fn eq(&self, other: &u64) -> bool {
276 self.0 == *other
277 }
278}
279
280impl PartialOrd<u64> for Location {
281 #[inline]
282 fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
283 self.0.partial_cmp(other)
284 }
285}
286
287impl PartialEq<Location> for u64 {
289 #[inline]
290 fn eq(&self, other: &Location) -> bool {
291 *self == other.0
292 }
293}
294
295impl PartialOrd<Location> for u64 {
296 #[inline]
297 fn partial_cmp(&self, other: &Location) -> Option<core::cmp::Ordering> {
298 self.partial_cmp(&other.0)
299 }
300}
301
302impl AddAssign<u64> for Location {
308 #[inline]
309 fn add_assign(&mut self, rhs: u64) {
310 self.0 += rhs;
311 }
312}
313
314impl SubAssign<u64> for Location {
320 #[inline]
321 fn sub_assign(&mut self, rhs: u64) {
322 self.0 -= rhs;
323 }
324}
325
/// Converts an MMR node [Position] into the [Location] of the leaf stored there.
///
/// # Errors
///
/// - [LocationError::Overflow] if `pos` is greater than `MAX_POSITION`.
/// - [LocationError::NonLeaf] if `pos` does not correspond to a leaf.
impl TryFrom<Position> for Location {
    type Error = LocationError;

    /// Walks down from the root of a perfect tree that contains `pos`. Every time the
    /// walk moves into a right subtree, it adds the leaves of the skipped left subtree
    /// to the running count.
    #[inline]
    fn try_from(pos: Position) -> Result<Self, Self::Error> {
        if *pos > MAX_POSITION {
            // Out of the supported range. Because MAX_POSITION < u64::MAX, passing this
            // check also guarantees that `pos + 1` below cannot overflow.
            return Err(LocationError::Overflow(pos));
        }
        if *pos == 0 {
            // Position 0 is the first leaf.
            return Ok(Self(0));
        }

        // `start` = 2^k - 1, where k is the bit length of `pos + 1`. This is the smallest
        // all-ones value that is >= pos + 1. Positions 0..start are treated as a single
        // perfect binary tree of `start` nodes with its root at `start - 1`.
        // NOTE(review): this relies on the MMR numbering in which the first 2^k - 1
        // positions always form one perfect tree — confirm against the Position docs.
        let start = u64::MAX >> (pos + 1).leading_zeros();
        // Number of levels in that tree (`start` is all ones, so this equals k).
        let height = start.trailing_ones();
        if height == 0 {
            // Defensive only: pos >= 1 here, so pos + 1 >= 2, start >= 3 and height >= 2.
            return Err(LocationError::NonLeaf(pos));
        }
        // Distance from the current node down to its left child, i.e. the node count
        // of one child subtree plus one (the right child sits at `cur_node - 1`).
        let mut two_h = 1 << (height - 1);
        // Root of the subtree currently being searched.
        let mut cur_node = start - 1;
        // Number of leaves known to come before `pos`.
        let mut leaf_loc_floor = 0u64;

        // Descend until the current subtree is a single node.
        while two_h > 1 {
            if cur_node == *pos {
                // `pos` is the root of a multi-node subtree, so it is not a leaf.
                return Err(LocationError::NonLeaf(pos));
            }
            let left_pos = cur_node - two_h;
            // After halving, `two_h` equals the number of leaves in each child subtree.
            two_h >>= 1;
            if *pos > left_pos {
                // `pos` lies in the right subtree: count the left subtree's leaves and
                // step to the right child, which immediately precedes its parent.
                leaf_loc_floor += two_h;
                cur_node -= 1;
            } else {
                // `pos` lies in the left subtree.
                cur_node = left_pos;
            }
        }

        // The remaining single-node subtree is the leaf at `pos`.
        Ok(Self(leaf_loc_floor))
    }
}
378
/// Error returned when converting a [Position] into a [Location] fails.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Error)]
pub enum LocationError {
    /// The position does not correspond to a leaf.
    #[error("{0} is not a leaf")]
    NonLeaf(Position),

    /// The position is greater than `MAX_POSITION`.
    ///
    /// NOTE(review): the message says `MAX_LOCATION`, but the check in
    /// `TryFrom<Position> for Location` compares against `MAX_POSITION`.
    #[error("{0} > MAX_LOCATION")]
    Overflow(Position),
}
388
/// Extension methods for ranges of [Location]s.
pub trait LocationRangeExt {
    /// Converts both bounds of the range to `usize`.
    fn to_usize_range(&self) -> Range<usize>;
}
394
395impl LocationRangeExt for Range<Location> {
396 #[inline]
397 fn to_usize_range(&self) -> Range<usize> {
398 *self.start as usize..*self.end as usize
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::{Location, MAX_LOCATION};
    use crate::mmr::{position::Position, LocationError, MAX_POSITION};

    // Each leaf position maps to its 0-based leaf location.
    #[test]
    fn test_try_from_position() {
        const CASES: &[(Position, Location)] = &[
            (Position::new(0), Location::new_unchecked(0)),
            (Position::new(1), Location::new_unchecked(1)),
            (Position::new(3), Location::new_unchecked(2)),
            (Position::new(4), Location::new_unchecked(3)),
            (Position::new(7), Location::new_unchecked(4)),
            (Position::new(8), Location::new_unchecked(5)),
            (Position::new(10), Location::new_unchecked(6)),
            (Position::new(11), Location::new_unchecked(7)),
            (Position::new(15), Location::new_unchecked(8)),
            (Position::new(16), Location::new_unchecked(9)),
            (Position::new(18), Location::new_unchecked(10)),
            (Position::new(19), Location::new_unchecked(11)),
            (Position::new(22), Location::new_unchecked(12)),
            (Position::new(23), Location::new_unchecked(13)),
            (Position::new(25), Location::new_unchecked(14)),
            (Position::new(26), Location::new_unchecked(15)),
        ];
        for (pos, expected_loc) in CASES {
            let loc = Location::try_from(*pos).expect("should map to a leaf location");
            assert_eq!(loc, *expected_loc);
        }
    }

    // Positions that are not leaves are rejected with `NonLeaf`.
    #[test]
    fn test_try_from_position_error() {
        const CASES: &[Position] = &[
            Position::new(2),
            Position::new(5),
            Position::new(6),
            Position::new(9),
            Position::new(12),
            Position::new(13),
            Position::new(14),
            Position::new(17),
            Position::new(20),
            Position::new(21),
            Position::new(24),
            Position::new(27),
            Position::new(28),
            Position::new(29),
            Position::new(30),
        ];
        for &pos in CASES {
            let err = Location::try_from(pos).expect_err("position is not a leaf");
            assert_eq!(err, LocationError::NonLeaf(pos));
        }
    }

    #[test]
    fn test_try_from_position_error_overflow() {
        let overflow_pos = Position::new(u64::MAX);
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));

        // MAX_POSITION itself is in range but is not a leaf.
        let result = Location::try_from(MAX_POSITION);
        assert_eq!(result, Err(LocationError::NonLeaf(MAX_POSITION)));

        // One past MAX_POSITION is out of range.
        let overflow_pos = MAX_POSITION + 1;
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));
    }

    #[test]
    fn test_checked_add() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.checked_add(5).unwrap(), 15);

        // Overflowing u64 yields None.
        assert!(Location::new_unchecked(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_LOCATION yields None.
        assert!(Location::new_unchecked(MAX_LOCATION)
            .checked_add(1)
            .is_none());

        // Landing exactly on MAX_LOCATION is allowed.
        let loc = Location::new_unchecked(MAX_LOCATION - 10);
        assert_eq!(loc.checked_add(10).unwrap(), MAX_LOCATION);
    }

    #[test]
    fn test_checked_sub() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.checked_sub(5).unwrap(), 5);
        assert!(loc.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.saturating_add(5), 15);

        // Results are clamped to MAX_LOCATION, even from an out-of-range start.
        assert_eq!(
            Location::new_unchecked(u64::MAX).saturating_add(1),
            MAX_LOCATION
        );
        assert_eq!(
            Location::new_unchecked(MAX_LOCATION).saturating_add(1),
            MAX_LOCATION
        );
        assert_eq!(
            Location::new_unchecked(MAX_LOCATION).saturating_add(1000),
            MAX_LOCATION
        );
    }

    #[test]
    fn test_saturating_sub() {
        let loc = Location::new_unchecked(10);
        assert_eq!(loc.saturating_sub(5), 5);
        assert_eq!(Location::new_unchecked(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        let location = Location::new_unchecked(42);
        assert_eq!(location.to_string(), "Location(42)");
    }

    #[test]
    fn test_add() {
        let loc1 = Location::new_unchecked(10);
        let loc2 = Location::new_unchecked(5);
        assert_eq!((loc1 + loc2), 15);
    }

    #[test]
    fn test_sub() {
        let loc1 = Location::new_unchecked(10);
        let loc2 = Location::new_unchecked(3);
        assert_eq!((loc1 - loc2), 7);
    }

    #[test]
    fn test_comparison_with_u64() {
        let loc = Location::new_unchecked(42);

        // Equality in both directions.
        assert_eq!(loc, 42u64);
        assert_eq!(42u64, loc);
        assert_ne!(loc, 43u64);
        assert_ne!(43u64, loc);

        // Ordering in both directions.
        assert!(loc < 43u64);
        assert!(43u64 > loc);
        assert!(loc > 41u64);
        assert!(41u64 < loc);
        assert!(loc <= 42u64);
        assert!(42u64 >= loc);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut loc = Location::new_unchecked(10);

        loc += 5;
        assert_eq!(loc, 15u64);

        loc -= 3;
        assert_eq!(loc, 12u64);
    }

    #[test]
    fn test_new() {
        // Values up to and including MAX_LOCATION are accepted.
        assert!(Location::new(0).is_some());
        assert!(Location::new(1000).is_some());
        assert!(Location::new(MAX_LOCATION).is_some());

        // Anything above MAX_LOCATION is rejected.
        assert!(Location::new(MAX_LOCATION + 1).is_none());
        assert!(Location::new(u64::MAX).is_none());
    }

    #[test]
    fn test_is_valid() {
        assert!(Location::new_unchecked(0).is_valid());
        assert!(Location::new_unchecked(1000).is_valid());
        assert!(Location::new_unchecked(MAX_LOCATION).is_valid());
        // First invalid value, directly past the boundary.
        assert!(!Location::new_unchecked(MAX_LOCATION + 1).is_valid());
        assert!(!Location::new_unchecked(u64::MAX).is_valid());
    }

    #[test]
    fn test_max_location_boundary() {
        let max_loc = Location::new_unchecked(MAX_LOCATION);
        assert!(max_loc.is_valid());
        let pos = Position::try_from(max_loc).unwrap();
        // 2 * MAX_LOCATION - popcount(MAX_LOCATION) = 2 * (2^62 - 1) - 62 = 2^63 - 64.
        let expected = (1u64 << 63) - 64;
        assert_eq!(*pos, expected);
    }

    #[test]
    fn test_overflow_location_returns_error() {
        let over_loc = Location::new_unchecked(MAX_LOCATION + 1);
        assert!(Position::try_from(over_loc).is_err());

        match Position::try_from(over_loc) {
            Err(crate::mmr::Error::LocationOverflow(loc)) => {
                assert_eq!(loc, over_loc);
            }
            _ => panic!("expected LocationOverflow error"),
        }
    }

    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        let loc = Location::new(0).unwrap();
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);

        let loc = Location::new(12345).unwrap();
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);

        let loc = Location::new(MAX_LOCATION).unwrap();
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);
    }

    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Just above the boundary.
        let invalid_value = MAX_LOCATION + 1;
        let encoded = commonware_codec::varint::UInt(invalid_value).encode();
        let result = Location::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Location", _))
        ));

        // The largest encodable u64.
        let encoded = commonware_codec::varint::UInt(u64::MAX).encode();
        let result = Location::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Location", _))
        ));
    }
}