1use super::position::Position;
2use bytes::{Buf, BufMut};
3use commonware_codec::{varint::UInt, Read, ReadExt};
4use core::{
5 convert::TryFrom,
6 fmt,
7 ops::{Add, AddAssign, Deref, Range, Sub, SubAssign},
8};
9use thiserror::Error;
10
/// The largest valid [`Location`], equal to `2^62` (`0x4000_0000_0000_0000`).
///
/// Locations above this bound are reported as invalid by `Location::is_valid` and are
/// rejected when decoding.
pub const MAX_LOCATION: Location = Location(1 << 62);

/// Index of a leaf within an MMR: the number of leaves that precede it (leaf at position 0 is
/// location 0, leaf at position 1 is location 1, leaf at position 3 is location 2, ...).
///
/// The wrapped value is not range-checked on construction; see [`MAX_LOCATION`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Location(u64);
36
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Location {
    /// Draws a location from the inclusive range `0..=MAX_LOCATION`, so generated values are
    /// never above the valid bound.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_LOCATION.as_u64()).map(Location::new)
    }
}
44
45impl Location {
46 #[inline]
48 pub const fn new(loc: u64) -> Self {
49 Self(loc)
50 }
51
52 #[inline]
54 pub const fn as_u64(self) -> u64 {
55 self.0
56 }
57
58 #[inline]
61 pub const fn is_valid(self) -> bool {
62 self.0 <= MAX_LOCATION.0
63 }
64
65 #[inline]
67 pub const fn checked_add(self, rhs: u64) -> Option<Self> {
68 match self.0.checked_add(rhs) {
69 Some(value) => {
70 if value <= MAX_LOCATION.0 {
71 Some(Self(value))
72 } else {
73 None
74 }
75 }
76 None => None,
77 }
78 }
79
80 #[inline]
82 pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
83 match self.0.checked_sub(rhs) {
84 Some(value) => Some(Self(value)),
85 None => None,
86 }
87 }
88
89 #[inline]
91 pub const fn saturating_add(self, rhs: u64) -> Self {
92 let result = self.0.saturating_add(rhs);
93 if result > MAX_LOCATION.0 {
94 MAX_LOCATION
95 } else {
96 Self(result)
97 }
98 }
99
100 #[inline]
102 pub const fn saturating_sub(self, rhs: u64) -> Self {
103 Self(self.0.saturating_sub(rhs))
104 }
105}
106
107impl fmt::Display for Location {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 write!(f, "Location({})", self.0)
110 }
111}
112
113impl From<u64> for Location {
114 #[inline]
115 fn from(value: u64) -> Self {
116 Self::new(value)
117 }
118}
119
120impl From<usize> for Location {
121 #[inline]
122 fn from(value: usize) -> Self {
123 Self::new(value as u64)
124 }
125}
126
127impl Deref for Location {
128 type Target = u64;
129 fn deref(&self) -> &Self::Target {
130 &self.0
131 }
132}
133
134impl From<Location> for u64 {
135 #[inline]
136 fn from(loc: Location) -> Self {
137 *loc
138 }
139}
140
141impl commonware_codec::Write for Location {
143 #[inline]
144 fn write(&self, buf: &mut impl BufMut) {
145 UInt(self.0).write(buf);
146 }
147}
148
149impl commonware_codec::EncodeSize for Location {
150 #[inline]
151 fn encode_size(&self) -> usize {
152 UInt(self.0).encode_size()
153 }
154}
155
156impl Read for Location {
157 type Cfg = ();
158
159 #[inline]
160 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
161 let value: u64 = UInt::read(buf)?.into();
162 let loc = Self::new(value);
163 if loc.is_valid() {
164 Ok(loc)
165 } else {
166 Err(commonware_codec::Error::Invalid(
167 "Location",
168 "value exceeds MAX_LOCATION",
169 ))
170 }
171 }
172}
173
174impl Add for Location {
180 type Output = Self;
181
182 #[inline]
183 fn add(self, rhs: Self) -> Self::Output {
184 Self(self.0 + rhs.0)
185 }
186}
187
188impl Add<u64> for Location {
194 type Output = Self;
195
196 #[inline]
197 fn add(self, rhs: u64) -> Self::Output {
198 Self(self.0 + rhs)
199 }
200}
201
202impl Sub for Location {
208 type Output = Self;
209
210 #[inline]
211 fn sub(self, rhs: Self) -> Self::Output {
212 Self(self.0 - rhs.0)
213 }
214}
215
216impl Sub<u64> for Location {
222 type Output = Self;
223
224 #[inline]
225 fn sub(self, rhs: u64) -> Self::Output {
226 Self(self.0 - rhs)
227 }
228}
229
230impl PartialEq<u64> for Location {
231 #[inline]
232 fn eq(&self, other: &u64) -> bool {
233 self.0 == *other
234 }
235}
236
237impl PartialOrd<u64> for Location {
238 #[inline]
239 fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
240 self.0.partial_cmp(other)
241 }
242}
243
244impl PartialEq<Location> for u64 {
246 #[inline]
247 fn eq(&self, other: &Location) -> bool {
248 *self == other.0
249 }
250}
251
252impl PartialOrd<Location> for u64 {
253 #[inline]
254 fn partial_cmp(&self, other: &Location) -> Option<core::cmp::Ordering> {
255 self.partial_cmp(&other.0)
256 }
257}
258
259impl AddAssign<u64> for Location {
265 #[inline]
266 fn add_assign(&mut self, rhs: u64) {
267 self.0 += rhs;
268 }
269}
270
271impl SubAssign<u64> for Location {
277 #[inline]
278 fn sub_assign(&mut self, rhs: u64) {
279 self.0 -= rhs;
280 }
281}
282
impl TryFrom<Position> for Location {
    type Error = LocationError;

    /// Maps an MMR node position to the location (leaf index) of that leaf.
    ///
    /// # Errors
    ///
    /// - [`LocationError::Overflow`] if `pos` fails `Position::is_valid`.
    /// - [`LocationError::NonLeaf`] if `pos` refers to an internal node.
    #[inline]
    fn try_from(pos: Position) -> Result<Self, Self::Error> {
        if !pos.is_valid() {
            return Err(LocationError::Overflow(pos));
        }
        if *pos == 0 {
            // Position 0 is the first leaf.
            return Ok(Self(0));
        }

        // `start` is the all-ones mask with the same bit length as `pos + 1`, i.e. the
        // smallest value of the form `2^k - 1` that is strictly greater than `pos`. The walk
        // below treats positions `0..start` as one perfect binary tree numbered in
        // post-order, whose root is the last node, `start - 1`.
        let start = u64::MAX >> (pos + 1).leading_zeros();
        // `k`: the number of set bits in `start`.
        let height = start.trailing_ones();
        if height == 0 {
            // Defensive guard: `start` is only zero if `pos + 1` is zero, which the validity
            // check above appears to rule out.
            return Err(LocationError::NonLeaf(pos));
        }
        // Number of leaves under `cur_node` (a perfect tree of `2^k - 1` nodes has `2^(k-1)`
        // leaves).
        let mut two_h = 1 << (height - 1);
        // Root of the subtree currently known to contain `pos`.
        let mut cur_node = start - 1;
        // Count of leaves that come before the current subtree.
        let mut leaf_loc_floor = 0u64;

        // Descend from the root toward `pos`, halving the subtree on each step.
        while two_h > 1 {
            if cur_node == *pos {
                // `pos` is the root of a subtree with more than one leaf, so not a leaf.
                return Err(LocationError::NonLeaf(pos));
            }
            // In post-order the left child lies `two_h` positions before its parent, and the
            // right child lies immediately before it.
            let left_pos = cur_node - two_h;
            two_h >>= 1;
            if *pos > left_pos {
                // `pos` is in the right subtree: every leaf of the left subtree precedes it.
                leaf_loc_floor += two_h;
                cur_node -= 1;
            } else {
                cur_node = left_pos;
            }
        }

        // The remaining subtree is a single leaf (`pos` itself), so the count of preceding
        // leaves is its location.
        Ok(Self(leaf_loc_floor))
    }
}
335
336#[derive(Debug, Clone, Copy, Eq, PartialEq, Error)]
338pub enum LocationError {
339 #[error("{0} is not a leaf")]
340 NonLeaf(Position),
341
342 #[error("{0} > MAX_LOCATION")]
343 Overflow(Position),
344}
345
346pub trait LocationRangeExt {
348 fn to_usize_range(&self) -> Range<usize>;
350}
351
352impl LocationRangeExt for Range<Location> {
353 #[inline]
354 fn to_usize_range(&self) -> Range<usize> {
355 *self.start as usize..*self.end as usize
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::{Location, LocationRangeExt, MAX_LOCATION};
    use crate::mmr::{position::Position, LocationError, MAX_POSITION};

    /// Leaf positions map to consecutive locations.
    #[test]
    fn test_try_from_position() {
        const CASES: &[(Position, Location)] = &[
            (Position::new(0), Location::new(0)),
            (Position::new(1), Location::new(1)),
            (Position::new(3), Location::new(2)),
            (Position::new(4), Location::new(3)),
            (Position::new(7), Location::new(4)),
            (Position::new(8), Location::new(5)),
            (Position::new(10), Location::new(6)),
            (Position::new(11), Location::new(7)),
            (Position::new(15), Location::new(8)),
            (Position::new(16), Location::new(9)),
            (Position::new(18), Location::new(10)),
            (Position::new(19), Location::new(11)),
            (Position::new(22), Location::new(12)),
            (Position::new(23), Location::new(13)),
            (Position::new(25), Location::new(14)),
            (Position::new(26), Location::new(15)),
        ];
        for (pos, expected_loc) in CASES {
            let loc = Location::try_from(*pos).expect("should map to a leaf location");
            assert_eq!(loc, *expected_loc);
        }
    }

    /// Internal-node positions are rejected as non-leaves.
    #[test]
    fn test_try_from_position_error() {
        const CASES: &[Position] = &[
            Position::new(2),
            Position::new(5),
            Position::new(6),
            Position::new(9),
            Position::new(12),
            Position::new(13),
            Position::new(14),
            Position::new(17),
            Position::new(20),
            Position::new(21),
            Position::new(24),
            Position::new(27),
            Position::new(28),
            Position::new(29),
            Position::new(30),
        ];
        for &pos in CASES {
            let err = Location::try_from(pos).expect_err("position is not a leaf");
            assert_eq!(err, LocationError::NonLeaf(pos));
        }
    }

    #[test]
    fn test_try_from_position_error_overflow() {
        let overflow_pos = Position::new(u64::MAX);
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));

        let result = Location::try_from(MAX_POSITION);
        assert_eq!(result, Ok(MAX_LOCATION));

        let overflow_pos = MAX_POSITION + 1;
        let err = Location::try_from(overflow_pos).expect_err("should overflow");
        assert_eq!(err, LocationError::Overflow(overflow_pos));
    }

    #[test]
    fn test_checked_add() {
        let loc = Location::new(10);
        assert_eq!(loc.checked_add(5).unwrap(), 15);

        assert!(Location::new(u64::MAX).checked_add(1).is_none());

        assert!(MAX_LOCATION.checked_add(1).is_none());

        let loc = Location::new(*MAX_LOCATION - 10);
        assert_eq!(loc.checked_add(10).unwrap(), *MAX_LOCATION);
    }

    #[test]
    fn test_checked_sub() {
        let loc = Location::new(10);
        assert_eq!(loc.checked_sub(5).unwrap(), 5);
        assert!(loc.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        let loc = Location::new(10);
        assert_eq!(loc.saturating_add(5), 15);

        assert_eq!(Location::new(u64::MAX).saturating_add(1), MAX_LOCATION);
        assert_eq!(MAX_LOCATION.saturating_add(1), MAX_LOCATION);
        assert_eq!(MAX_LOCATION.saturating_add(1000), MAX_LOCATION);
    }

    /// Additions that land exactly on, or just past, the bound clamp to `MAX_LOCATION`.
    #[test]
    fn test_saturating_add_near_max() {
        let loc = Location::new(*MAX_LOCATION - 1);
        assert_eq!(loc.saturating_add(1), MAX_LOCATION);
        assert_eq!(loc.saturating_add(2), MAX_LOCATION);
    }

    #[test]
    fn test_saturating_sub() {
        let loc = Location::new(10);
        assert_eq!(loc.saturating_sub(5), 5);
        assert_eq!(Location::new(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        let location = Location::new(42);
        assert_eq!(location.to_string(), "Location(42)");
    }

    #[test]
    fn test_add() {
        let loc1 = Location::new(10);
        let loc2 = Location::new(5);
        assert_eq!(loc1 + loc2, 15);
    }

    #[test]
    fn test_sub() {
        let loc1 = Location::new(10);
        let loc2 = Location::new(3);
        assert_eq!(loc1 - loc2, 7);
    }

    #[test]
    fn test_add_sub_with_u64() {
        let loc = Location::new(10);
        assert_eq!(loc + 5u64, 15u64);
        assert_eq!(loc - 3u64, 7u64);
    }

    #[test]
    fn test_comparison_with_u64() {
        let loc = Location::new(42);

        assert_eq!(loc, 42u64);
        assert_eq!(42u64, loc);
        assert_ne!(loc, 43u64);
        assert_ne!(43u64, loc);

        assert!(loc < 43u64);
        assert!(43u64 > loc);
        assert!(loc > 41u64);
        assert!(41u64 < loc);
        assert!(loc <= 42u64);
        assert!(42u64 >= loc);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut loc = Location::new(10);

        loc += 5;
        assert_eq!(loc, 15u64);

        loc -= 3;
        assert_eq!(loc, 12u64);
    }

    /// `From`, `as_u64` and `Deref` all expose the same raw value.
    #[test]
    fn test_conversions() {
        let loc = Location::from(9u64);
        assert_eq!(loc, Location::new(9));
        assert_eq!(Location::from(9usize), loc);
        assert_eq!(u64::from(loc), 9);
        assert_eq!(loc.as_u64(), 9);
        assert_eq!(*loc, 9u64);
    }

    #[test]
    fn test_to_usize_range() {
        let range = Location::new(2)..Location::new(5);
        assert_eq!(range.to_usize_range(), 2..5);

        let empty = Location::new(7)..Location::new(7);
        assert!(empty.to_usize_range().is_empty());
    }

    #[test]
    fn test_is_valid() {
        assert!(Location::new(0).is_valid());
        assert!(Location::new(1000).is_valid());
        assert!(MAX_LOCATION.is_valid());
        assert!(!Location::new(u64::MAX).is_valid());
    }

    /// `MAX_LOCATION` and `MAX_POSITION` round-trip through each other.
    #[test]
    fn test_max_location_boundary() {
        assert!(MAX_LOCATION.is_valid());
        let pos = Position::try_from(MAX_LOCATION).unwrap();
        assert_eq!(pos, MAX_POSITION);
        assert!(pos.is_valid());

        let loc = Location::try_from(pos).unwrap();
        assert_eq!(loc, MAX_LOCATION);
    }

    #[test]
    fn test_overflow_location_returns_error() {
        let over_loc = Location::new(*MAX_LOCATION + 1);
        assert!(!over_loc.is_valid());
        assert!(Position::try_from(over_loc).is_err());

        match Position::try_from(over_loc) {
            Err(crate::mmr::Error::LocationOverflow(loc)) => {
                assert_eq!(loc, over_loc);
            }
            _ => panic!("expected LocationOverflow error"),
        }
    }

    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        let loc = Location::new(0);
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);

        let loc = Location::new(12345);
        let encoded = loc.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, loc);

        let encoded = MAX_LOCATION.encode();
        let decoded = Location::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, MAX_LOCATION);
    }

    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{varint::UInt, Encode, ReadExt};

        let invalid_value = *MAX_LOCATION + 1;
        let encoded = UInt(invalid_value).encode();
        let result = Location::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Location", _))
        ));

        let encoded = UInt(u64::MAX).encode();
        let result = Location::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Location", _))
        ));
    }
}