1use super::location::Location;
2use bytes::{Buf, BufMut};
3use commonware_codec::{varint::UInt, ReadExt};
4use core::{
5 fmt,
6 ops::{Add, AddAssign, Deref, Sub, SubAssign},
7};
8
9pub const MAX_POSITION: Position = Position::new(0x7FFFFFFFFFFFFFFF); #[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Default, Debug)]
20pub struct Position(u64);
21
#[cfg(feature = "arbitrary")]
impl arbitrary::Arbitrary<'_> for Position {
    /// Builds a position from a raw value drawn from `0..=MAX_POSITION`, so every generated
    /// position is within the valid range.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.int_in_range(0..=MAX_POSITION.0).map(Self)
    }
}
29
30impl Position {
31 #[inline]
33 pub const fn new(pos: u64) -> Self {
34 Self(pos)
35 }
36
37 #[inline]
39 pub const fn as_u64(self) -> u64 {
40 self.0
41 }
42
43 #[inline]
46 pub const fn is_valid(self) -> bool {
47 self.0 <= MAX_POSITION.0
48 }
49
50 #[inline]
52 pub const fn checked_add(self, rhs: u64) -> Option<Self> {
53 match self.0.checked_add(rhs) {
54 Some(value) => {
55 if value <= MAX_POSITION.0 {
56 Some(Self(value))
57 } else {
58 None
59 }
60 }
61 None => None,
62 }
63 }
64
65 #[inline]
67 pub const fn checked_sub(self, rhs: u64) -> Option<Self> {
68 match self.0.checked_sub(rhs) {
69 Some(value) => Some(Self(value)),
70 None => None,
71 }
72 }
73
74 #[inline]
76 pub const fn saturating_add(self, rhs: u64) -> Self {
77 let result = self.0.saturating_add(rhs);
78 if result > MAX_POSITION.0 {
79 MAX_POSITION
80 } else {
81 Self(result)
82 }
83 }
84
85 #[inline]
87 pub const fn saturating_sub(self, rhs: u64) -> Self {
88 Self(self.0.saturating_sub(rhs))
89 }
90
91 #[inline]
97 pub const fn is_mmr_size(self) -> bool {
98 if self.0 == 0 {
99 return true;
100 }
101 let leading_zeros = self.0.leading_zeros();
102 if leading_zeros == 0 {
103 return false;
105 }
106 let start = u64::MAX >> leading_zeros;
107 let mut two_h = 1 << start.trailing_ones();
108 let mut node_pos = start.checked_sub(1).expect("start > 0 because size != 0");
109 while two_h > 1 {
110 if node_pos < self.0 {
111 if two_h == 2 {
112 return node_pos == self.0 - 1;
115 }
116 node_pos += two_h - 1;
118 if node_pos < self.0 {
119 return false;
121 }
122 continue;
123 }
124 two_h >>= 1;
126 node_pos -= two_h;
127 }
128 true
129 }
130}
131
132impl fmt::Display for Position {
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 write!(f, "Position({})", self.0)
135 }
136}
137
138impl Deref for Position {
139 type Target = u64;
140 fn deref(&self) -> &Self::Target {
141 &self.0
142 }
143}
144
145impl AsRef<u64> for Position {
146 fn as_ref(&self) -> &u64 {
147 &self.0
148 }
149}
150
151impl From<u64> for Position {
152 #[inline]
153 fn from(value: u64) -> Self {
154 Self::new(value)
155 }
156}
157
158impl From<usize> for Position {
159 #[inline]
160 fn from(value: usize) -> Self {
161 Self::new(value as u64)
162 }
163}
164
165impl From<Position> for u64 {
166 #[inline]
167 fn from(position: Position) -> Self {
168 *position
169 }
170}
171
172impl TryFrom<Location> for Position {
191 type Error = super::Error;
192
193 #[inline]
194 fn try_from(loc: Location) -> Result<Self, Self::Error> {
195 if !loc.is_valid() {
196 return Err(super::Error::LocationOverflow(loc));
197 }
198 let loc_val = *loc;
200 Ok(Self(
201 loc_val
202 .checked_mul(2)
203 .expect("should not overflow for valid leaf index")
204 - loc_val.count_ones() as u64,
205 ))
206 }
207}
208
209impl Add for Position {
215 type Output = Self;
216
217 #[inline]
218 fn add(self, rhs: Self) -> Self::Output {
219 Self(self.0 + rhs.0)
220 }
221}
222
223impl Add<u64> for Position {
229 type Output = Self;
230
231 #[inline]
232 fn add(self, rhs: u64) -> Self::Output {
233 Self(self.0 + rhs)
234 }
235}
236
237impl Sub for Position {
243 type Output = Self;
244
245 #[inline]
246 fn sub(self, rhs: Self) -> Self::Output {
247 Self(self.0 - rhs.0)
248 }
249}
250
251impl Sub<u64> for Position {
257 type Output = Self;
258
259 #[inline]
260 fn sub(self, rhs: u64) -> Self::Output {
261 Self(self.0 - rhs)
262 }
263}
264
265impl PartialEq<u64> for Position {
266 #[inline]
267 fn eq(&self, other: &u64) -> bool {
268 self.0 == *other
269 }
270}
271
272impl PartialOrd<u64> for Position {
273 #[inline]
274 fn partial_cmp(&self, other: &u64) -> Option<core::cmp::Ordering> {
275 self.0.partial_cmp(other)
276 }
277}
278
279impl PartialEq<Position> for u64 {
281 #[inline]
282 fn eq(&self, other: &Position) -> bool {
283 *self == other.0
284 }
285}
286
287impl PartialOrd<Position> for u64 {
288 #[inline]
289 fn partial_cmp(&self, other: &Position) -> Option<core::cmp::Ordering> {
290 self.partial_cmp(&other.0)
291 }
292}
293
294impl AddAssign<u64> for Position {
300 #[inline]
301 fn add_assign(&mut self, rhs: u64) {
302 self.0 += rhs;
303 }
304}
305
306impl SubAssign<u64> for Position {
312 #[inline]
313 fn sub_assign(&mut self, rhs: u64) {
314 self.0 -= rhs;
315 }
316}
317
318impl commonware_codec::Write for Position {
320 #[inline]
321 fn write(&self, buf: &mut impl BufMut) {
322 UInt(self.0).write(buf);
323 }
324}
325
326impl commonware_codec::EncodeSize for Position {
327 #[inline]
328 fn encode_size(&self) -> usize {
329 UInt(self.0).encode_size()
330 }
331}
332
333impl commonware_codec::Read for Position {
334 type Cfg = ();
335
336 #[inline]
337 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, commonware_codec::Error> {
338 let pos = Self(UInt::read(buf)?.into());
339 if pos.is_valid() {
340 Ok(pos)
341 } else {
342 Err(commonware_codec::Error::Invalid(
343 "Position",
344 "value exceeds MAX_POSITION",
345 ))
346 }
347 }
348}
349
#[cfg(test)]
mod tests {
    use super::{Location, Position};
    use crate::mmr::{mem::Mmr, StandardHasher as Standard, MAX_LOCATION, MAX_POSITION};
    use commonware_cryptography::Sha256;

    /// Leaf locations map to the node positions at which those leaves are stored.
    #[test]
    fn test_from_location() {
        // Each pair is (leaf location, position at which that leaf is stored).
        let table = [
            (Location::new(0), Position::new(0)),
            (Location::new(1), Position::new(1)),
            (Location::new(2), Position::new(3)),
            (Location::new(3), Position::new(4)),
            (Location::new(4), Position::new(7)),
            (Location::new(5), Position::new(8)),
            (Location::new(6), Position::new(10)),
            (Location::new(7), Position::new(11)),
            (Location::new(8), Position::new(15)),
            (Location::new(9), Position::new(16)),
            (Location::new(10), Position::new(18)),
            (Location::new(11), Position::new(19)),
            (Location::new(12), Position::new(22)),
            (Location::new(13), Position::new(23)),
            (Location::new(14), Position::new(25)),
            (Location::new(15), Position::new(26)),
        ];
        for (loc, want) in table {
            let got = Position::try_from(loc).unwrap();
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_checked_add() {
        assert_eq!(Position::new(10).checked_add(5).unwrap(), 15);

        // Overflow of the underlying u64.
        assert!(Position::new(u64::MAX).checked_add(1).is_none());

        // Exceeding MAX_POSITION without overflowing u64.
        assert!(MAX_POSITION.checked_add(1).is_none());
        assert!(Position::new(*MAX_POSITION - 5).checked_add(10).is_none());

        // Landing exactly on MAX_POSITION is allowed.
        let just_below = Position::new(*MAX_POSITION - 10);
        assert_eq!(just_below.checked_add(10).unwrap(), MAX_POSITION);
    }

    #[test]
    fn test_checked_sub() {
        let ten = Position::new(10);
        assert_eq!(ten.checked_sub(5).unwrap(), 5);
        assert!(ten.checked_sub(11).is_none());
    }

    #[test]
    fn test_saturating_add() {
        assert_eq!(Position::new(10).saturating_add(5), 15);

        // Results clamp to MAX_POSITION whether or not the u64 itself overflows.
        assert_eq!(Position::new(u64::MAX).saturating_add(1), MAX_POSITION);
        for delta in [1, 1000] {
            assert_eq!(MAX_POSITION.saturating_add(delta), MAX_POSITION);
        }
        assert_eq!(
            Position::new(*MAX_POSITION - 5).saturating_add(10),
            MAX_POSITION
        );
    }

    #[test]
    fn test_saturating_sub() {
        assert_eq!(Position::new(10).saturating_sub(5), 5);
        assert_eq!(Position::new(0).saturating_sub(1), 0);
    }

    #[test]
    fn test_display() {
        assert_eq!(Position::new(42).to_string(), "Position(42)");
    }

    #[test]
    fn test_add() {
        let sum = Position::new(10) + Position::new(5);
        assert_eq!(sum, 15);
    }

    #[test]
    fn test_sub() {
        let difference = Position::new(10) - Position::new(3);
        assert_eq!(difference, 7);
    }

    #[test]
    fn test_comparison_with_u64() {
        let pos = Position::new(42);

        // Equality in both directions.
        assert_eq!(pos, 42u64);
        assert_eq!(42u64, pos);
        assert_ne!(pos, 43u64);
        assert_ne!(43u64, pos);

        // Ordering in both directions.
        assert!(pos < 43u64);
        assert!(43u64 > pos);
        assert!(pos > 41u64);
        assert!(41u64 < pos);
        assert!(pos <= 42u64);
        assert!(42u64 >= pos);
    }

    #[test]
    fn test_assignment_with_u64() {
        let mut pos = Position::new(10);

        pos += 5;
        assert_eq!(pos, 15u64);

        pos -= 3;
        assert_eq!(pos, 12u64);
    }

    #[test]
    fn test_max_position() {
        let max_leaves = 1u64 << 62;

        // An MMR with 2^62 leaves has 2^63 - 1 nodes, which keeps the top bit clear.
        let max_size = 2 * max_leaves - 1;
        assert_eq!(*MAX_POSITION, max_size);
        assert_eq!(*MAX_POSITION, (1u64 << 63) - 1);
        assert_eq!(max_size.leading_zeros(), 1);

        // One more leaf would need the top bit.
        let overflow_size = 2 * (max_leaves + 1) - 1;
        assert_eq!(overflow_size.leading_zeros(), 0);

        let pos = Position::try_from(MAX_LOCATION).unwrap();
        assert_eq!(pos, MAX_POSITION);
        assert!(pos.is_valid());
    }

    #[test]
    fn test_is_mmr_size() {
        let mut hasher = Standard::<Sha256>::new();
        let mut mmr = Mmr::new(&mut hasher);
        let digest = [1u8; 32];
        let mut candidate = Position::new(0);

        for _ in 0..10000 {
            // Every size strictly between two consecutive real MMR sizes must be rejected.
            while candidate != mmr.size() {
                assert!(
                    !candidate.is_mmr_size(),
                    "size_to_check: {} {}",
                    candidate,
                    mmr.size()
                );
                candidate += 1;
            }
            assert!(candidate.is_mmr_size());

            let changeset = {
                let mut batch = mmr.new_batch();
                batch.add(&mut hasher, &digest);
                batch.merkleize(&mut hasher).finalize()
            };
            mmr.apply(changeset).unwrap();
            candidate += 1;
        }

        // Boundary values around the top bit.
        assert!(!Position::new(u64::MAX).is_mmr_size());
        assert!(Position::new(u64::MAX >> 1).is_mmr_size());
        assert!(!Position::new((u64::MAX >> 1) + 1).is_mmr_size());
        assert!(MAX_POSITION.is_mmr_size());
    }

    #[test]
    fn test_read_cfg_valid_values() {
        use commonware_codec::{Encode, ReadExt};

        // Values at zero, in the middle of the range, and exactly at the maximum round-trip.
        for pos in [Position::new(0), Position::new(12345), MAX_POSITION] {
            let encoded = pos.encode();
            let decoded = Position::read(&mut encoded.as_ref()).unwrap();
            assert_eq!(decoded, pos);
        }
    }

    #[test]
    fn test_read_cfg_invalid_values() {
        use commonware_codec::{varint::UInt, Encode, ReadExt};

        // Both a value just past MAX_POSITION and u64::MAX must be rejected.
        for raw in [*MAX_POSITION + 1, u64::MAX] {
            let encoded = UInt(raw).encode();
            let result = Position::read(&mut encoded.as_ref());
            assert!(result.is_err());
            assert!(matches!(
                result,
                Err(commonware_codec::Error::Invalid("Position", _))
            ));
        }
    }
}