1use super::location::Location;
2use bytes::{Buf, BufMut};
3use commonware_codec::ReadExt;
4use core::{
5 fmt,
6 ops::{Add, AddAssign, Deref, Sub, SubAssign},
7};
8
/// The largest position value accepted by [`Position::checked_add`], [`Position::saturating_add`]
/// and by decoding.
///
/// This is `2^63 - 2`, i.e. one less than `2^63 - 1`. `2^63 - 1` is the largest value
/// [`Position::is_mmr_size`] can accept, because any value with the top bit set is rejected there.
pub const MAX_POSITION: Position = Position::new((1u64 << 63) - 2);

/// A node position within an MMR, stored as a raw `u64`.
///
/// Construction via [`Position::new`] performs no range check. The bound against
/// [`MAX_POSITION`] is only enforced by the checked/saturating helpers (and by decoding).
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Position(u64);

#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Position {
    /// Draws a position in the inclusive range `0..=MAX_POSITION`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_POSITION.0).map(Self)
    }
}

impl Position {
    /// Wraps `pos` as a `Position` without validating it against [`MAX_POSITION`].
    #[inline]
    pub const fn new(pos: u64) -> Self {
        Self(pos)
    }

    /// Returns the raw `u64` value of this position.
    #[inline]
    pub const fn as_u64(self) -> u64 {
        self.0
    }

    /// Adds `rhs` to this position.
    ///
    /// Returns `None` if the sum overflows `u64` or is greater than [`MAX_POSITION`].
    #[inline]
    pub const fn checked_add(self, rhs: u64) -> Option<Self> {
        let sum = match self.0.checked_add(rhs) {
            Some(sum) => sum,
            None => return None,
        };
        if sum > MAX_POSITION.0 {
            None
        } else {
            Some(Self(sum))
        }
    }

    /// Subtracts `rhs` from this position, returning `None` if the result would be negative.
    #[inline]
    pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
        if rhs > self.0 {
            None
        } else {
            Some(Self(self.0 - rhs))
        }
    }

    /// Adds `rhs` to this position, clamping the result to [`MAX_POSITION`].
    #[inline]
    pub const fn saturating_add(self, rhs: u64) -> Self {
        // `checked_add` fails both on `u64` overflow and on exceeding `MAX_POSITION`;
        // in either case the clamped answer is `MAX_POSITION`.
        match self.checked_add(rhs) {
            Some(sum) => sum,
            None => MAX_POSITION,
        }
    }

    /// Subtracts `rhs` from this position, clamping the result at zero.
    #[inline]
    pub const fn saturating_sub(self, rhs: u64) -> Self {
        Self::new(self.0.saturating_sub(rhs))
    }

    /// Returns `true` if this value is a valid MMR size (total node count).
    ///
    /// Zero (the empty MMR) is valid, and any value with its top bit set is rejected.
    /// Otherwise the peaks are walked left to right, starting from the root of the smallest
    /// perfect binary tree that has at least `self` nodes:
    /// - a candidate peak that lies at or beyond the size is replaced by its left child;
    /// - a candidate that fits forces a jump to its right sibling, which must NOT fit
    ///   (otherwise the two trees would have merged);
    /// - the walk ends successfully once it runs out of height, or at a leaf peak that is
    ///   the very last node.
    #[inline]
    pub const fn is_mmr_size(self) -> bool {
        let size = self.0;
        if size == 0 {
            return true;
        }
        let zeros = size.leading_zeros();
        if zeros == 0 {
            // Top bit set: such a size would overflow the position arithmetic below.
            return false;
        }

        // All-ones value with the same bit width as `size`: the node count of the smallest
        // perfect binary tree that is at least as large as `size`.
        let tree_size = u64::MAX >> zeros;

        // Equals 2^(h+1) for the current candidate of height h, i.e. one more than the
        // number of nodes in the candidate's subtree.
        let mut subtree_span: u64 = 1 << tree_size.trailing_ones();

        // Position of the current candidate peak; initially the root of the covering tree.
        let mut peak = tree_size
            .checked_sub(1)
            .expect("start > 0 because size != 0");

        while subtree_span > 1 {
            if peak >= size {
                // Candidate lies beyond the MMR: descend to its left child.
                subtree_span >>= 1;
                peak -= subtree_span;
                continue;
            }
            if subtree_span == 2 {
                // The candidate is a leaf peak, so it must be the final node.
                return peak == size - 1;
            }
            // Candidate fits: jump to the root of its right sibling, which must not fit.
            peak += subtree_span - 1;
            if peak < size {
                return false;
            }
        }
        true
    }
}
121
122impl fmt::Display for Position {
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 write!(f, "Position({})", self.0)
125 }
126}
127
128impl Deref for Position {
129 type Target = u64;
130 fn deref(&self) -> &Self::Target {
131 &self.0
132 }
133}
134
135impl AsRef<u64> for Position {
136 fn as_ref(&self) -> &u64 {
137 &self.0
138 }
139}
140
141impl From<u64> for Position {
142 #[inline]
143 fn from(value: u64) -> Self {
144 Self::new(value)
145 }
146}
147
148impl From<usize> for Position {
149 #[inline]
150 fn from(value: usize) -> Self {
151 Self::new(value as u64)
152 }
153}
154
155impl From<Position> for u64 {
156 #[inline]
157 fn from(position: Position) -> Self {
158 *position
159 }
160}
161
162impl TryFrom<Location> for Position {
180 type Error = super::Error;
181
182 #[inline]
183 fn try_from(loc: Location) -> Result<Self, Self::Error> {
184 if !loc.is_valid() {
185 return Err(super::Error::LocationOverflow(loc));
186 }
187 let loc_val = *loc;
189 Ok(Self(
190 loc_val
191 .checked_mul(2)
192 .expect("should not overflow for valid location")
193 - loc_val.count_ones() as u64,
194 ))
195 }
196}
197
198impl Add for Position {
204 type Output = Self;
205
206 #[inline]
207 fn add(self, rhs: Self) -> Self::Output {
208 Self(self.0 + rhs.0)
209 }
210}
211
212impl Add<u64> for Position {
218 type Output = Self;
219
220 #[inline]
221 fn add(self, rhs: u64) -> Self::Output {
222 Self(self.0 + rhs)
223 }
224}
225
226impl Sub for Position {
232 type Output = Self;
233
234 #[inline]
235 fn sub(self, rhs: Self) -> Self::Output {
236 Self(self.0 - rhs.0)
237 }
238}
239
240impl Sub<u64> for Position {
246 type Output = Self;
247
248 #[inline]
249 fn sub(self, rhs: u64) -> Self::Output {
250 Self(self.0 - rhs)
251 }
252}
253
254impl PartialEq<u64> for Position {
255 #[inline]
256 fn eq(&self, other: &u64) -> bool {
257 self.0 == *other
258 }
259}
260
261impl PartialOrd<u64> for Position {
262 #[inline]
263 fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
264 self.0.partial_cmp(other)
265 }
266}
267
268impl PartialEq<Position> for u64 {
270 #[inline]
271 fn eq(&self, other: &Position) -> bool {
272 *self == other.0
273 }
274}
275
276impl PartialOrd<Position> for u64 {
277 #[inline]
278 fn partial_cmp(&self, other: &Position) -> Option<core::cmp::Ordering> {
279 self.partial_cmp(&other.0)
280 }
281}
282
283impl AddAssign<u64> for Position {
289 #[inline]
290 fn add_assign(&mut self, rhs: u64) {
291 self.0 += rhs;
292 }
293}
294
295impl SubAssign<u64> for Position {
301 #[inline]
302 fn sub_assign(&mut self, rhs: u64) {
303 self.0 -= rhs;
304 }
305}
306
307impl commonware_codec::Write for Position {
309 #[inline]
310 fn write(&self, buf: &mut impl BufMut) {
311 commonware_codec::varint::UInt(self.0).write(buf);
312 }
313}
314
315impl commonware_codec::EncodeSize for Position {
316 #[inline]
317 fn encode_size(&self) -> usize {
318 commonware_codec::varint::UInt(self.0).encode_size()
319 }
320}
321
322impl commonware_codec::Read for Position {
323 type Cfg = ();
324
325 #[inline]
326 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
327 let value: u64 = commonware_codec::varint::UInt::read(buf)?.into();
328 if value <= MAX_POSITION.0 {
329 Ok(Self(value))
330 } else {
331 Err(commonware_codec::Error::Invalid(
332 "Position",
333 "value exceeds MAX_POSITION",
334 ))
335 }
336 }
337}
338
/// Unit tests for `Position`: leaf-location conversion, bounded/unbounded arithmetic,
/// comparisons with raw `u64`, MMR-size validation, and codec round-trips.
#[cfg(test)]
mod tests {
    use super::{Location, Position};
    use crate::mmr::{mem::DirtyMmr, StandardHasher as Standard, MAX_LOCATION, MAX_POSITION};
    use commonware_cryptography::Sha256;

    /// The first 16 leaf locations map to their expected node positions (2n - popcount(n)).
    #[test]
    fn test_from_location() {
        const CASES: &[(Location, Position)] = &[
            (Location::new_unchecked(0), Position::new(0)),
            (Location::new_unchecked(1), Position::new(1)),
            (Location::new_unchecked(2), Position::new(3)),
            (Location::new_unchecked(3), Position::new(4)),
            (Location::new_unchecked(4), Position::new(7)),
            (Location::new_unchecked(5), Position::new(8)),
            (Location::new_unchecked(6), Position::new(10)),
            (Location::new_unchecked(7), Position::new(11)),
            (Location::new_unchecked(8), Position::new(15)),
            (Location::new_unchecked(9), Position::new(16)),
            (Location::new_unchecked(10), Position::new(18)),
            (Location::new_unchecked(11), Position::new(19)),
            (Location::new_unchecked(12), Position::new(22)),
            (Location::new_unchecked(13), Position::new(23)),
            (Location::new_unchecked(14), Position::new(25)),
            (Location::new_unchecked(15), Position::new(26)),
        ];
        for (loc, expected_pos) in CASES {
            let pos = Position::try_from(*loc).unwrap();
            assert_eq!(pos, *expected_pos);
        }
    }

    /// `checked_add` fails on `u64` overflow and on exceeding `MAX_POSITION`, but reaches
    /// `MAX_POSITION` exactly.
    #[test]
    fn test_checked_add() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_add(5).unwrap(), 15);

        // Raw u64 overflow.
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_POSITION without overflowing u64.
        assert!(MAX_POSITION.checked_add(1).is_none());
        assert!(Position::new(*MAX_POSITION - 5).checked_add(10).is_none());

        // Landing exactly on MAX_POSITION is allowed.
        assert_eq!(
            Position::new(*MAX_POSITION - 10).checked_add(10).unwrap(),
            MAX_POSITION
        );
    }

    /// `checked_sub` succeeds within range and returns `None` on underflow.
    #[test]
    fn test_checked_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.checked_sub(5).unwrap(), 5);
        assert!(pos.checked_sub(11).is_none());
    }

    /// `saturating_add` clamps to `MAX_POSITION` on both u64 overflow and range overflow.
    #[test]
    fn test_saturating_add() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_add(5), 15);

        assert_eq!(Position::new(u64::MAX).saturating_add(1), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1), MAX_POSITION);
        assert_eq!(MAX_POSITION.saturating_add(1000), MAX_POSITION);
        assert_eq!(
            Position::new(*MAX_POSITION - 5).saturating_add(10),
            MAX_POSITION
        );
    }

    /// `saturating_sub` clamps at zero.
    #[test]
    fn test_saturating_sub() {
        let pos = Position::new(10);
        assert_eq!(pos.saturating_sub(5), 5);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    /// `Display` renders as `Position(<value>)`.
    #[test]
    fn test_display() {
        let position = Position::new(42);
        assert_eq!(position.to_string(), "Position(42)");
    }

    /// `Position + Position` adds the raw values.
    #[test]
    fn test_add() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(5);
        assert_eq!((pos1 + pos2), 15);
    }

    /// `Position - Position` subtracts the raw values.
    #[test]
    fn test_sub() {
        let pos1 = Position::new(10);
        let pos2 = Position::new(3);
        assert_eq!((pos1 - pos2), 7);
    }

    /// Equality and ordering against raw `u64` work in both operand orders.
    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);

        // Equality in both directions.
        assert_eq!(pos, 42u64);
        assert_eq!(42u64, pos);
        assert_ne!(pos, 43u64);
        assert_ne!(43u64, pos);

        // Ordering in both directions.
        assert!(pos < 43u64);
        assert!(43u64 > pos);
        assert!(pos > 41u64);
        assert!(41u64 < pos);
        assert!(pos <= 42u64);
        assert!(42u64 >= pos);
    }

    /// `+=` and `-=` with a raw `u64` update the position in place.
    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        pos += 5;
        assert_eq!(pos, 15u64);

        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    /// `MAX_POSITION` is the last node position of an MMR with 2^62 leaves, and one more
    /// node would push the size into the top bit.
    #[test]
    fn test_max_position() {
        let max_leaves = 1u64 << 62;

        // A perfect tree over 2^62 leaves has 2^63 - 1 nodes, the largest size whose top bit
        // is still clear.
        let mmr_size_at_max = 2 * max_leaves - 1;
        assert_eq!(mmr_size_at_max, (1u64 << 63) - 1);
        assert_eq!(mmr_size_at_max.leading_zeros(), 1);

        // The last position in that MMR is one less than its size.
        let expected_max_pos = mmr_size_at_max - 1;
        assert_eq!(MAX_POSITION, expected_max_pos);
        assert_eq!(MAX_POSITION, (1u64 << 63) - 2);

        // Any larger size would set the top bit.
        let hypothetical_mmr_size = MAX_POSITION + 2;
        assert_eq!(hypothetical_mmr_size, 1u64 << 63);
        assert_eq!(hypothetical_mmr_size.leading_zeros(), 0);

        // The leaf at MAX_LOCATION converts successfully and lies below MAX_POSITION.
        let max_loc = Location::new_unchecked(MAX_LOCATION);
        let last_leaf_pos = Position::try_from(max_loc).unwrap();
        assert!(*last_leaf_pos < MAX_POSITION);
    }

    /// Cross-checks `is_mmr_size` against a real in-memory MMR: every size the MMR actually
    /// reaches must be accepted, and every value skipped between consecutive sizes rejected.
    #[test]
    fn test_is_mmr_size() {
        let mut size_to_check = Position::new(0);
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = DirtyMmr::new();
        let digest = [1u8; 32];
        for _i in 0..10000 {
            // Values below the MMR's current size that it skipped over are not valid sizes.
            while size_to_check != mmr.size() {
                assert!(
                    !size_to_check.is_mmr_size(),
                    "size_to_check: {} {}",
                    size_to_check,
                    mmr.size()
                );
                size_to_check += 1;
            }
            // The MMR's actual size must be reported as valid.
            assert!(size_to_check.is_mmr_size());
            // Grow the MMR by one element (presumably one leaf plus any merged parents).
            mmr.add(&mut hasher, &digest);
            size_to_check += 1;
        }

        // Boundary values around the top bit.
        assert!(!Position::new(u64::MAX).is_mmr_size());
        assert!(Position::new(u64::MAX >> 1).is_mmr_size());
        assert!(!Position::new((u64::MAX >> 1) + 1).is_mmr_size());
        assert!(!MAX_POSITION.is_mmr_size());
    }

    /// Encoding then decoding round-trips for zero, a typical value, and `MAX_POSITION`.
    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        let pos = Position::new(0);
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);

        let pos = Position::new(12345);
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);

        let pos = MAX_POSITION;
        let encoded = pos.encode();
        let decoded = Position::read(&mut encoded.as_ref()).unwrap();
        assert_eq!(decoded, pos);
    }

    /// Decoding rejects raw varints above `MAX_POSITION` with `Error::Invalid("Position", _)`.
    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Just past the bound.
        let invalid_value = *MAX_POSITION + 1;
        let encoded = commonware_codec::varint::UInt(invalid_value).encode();
        let result = Position::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Position", _))
        ));

        // The largest u64.
        let encoded = commonware_codec::varint::UInt(u64::MAX).encode();
        let result = Position::read(&mut encoded.as_ref());
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(commonware_codec::Error::Invalid("Position", _))
        ));
    }
}