use crate::{Error, Result};
#[cfg(any(test, feature = "proptest"))]
use proptest_derive::Arbitrary;
use std::ops::RangeInclusive;
use zkp_error_utils::require;
#[cfg(feature = "std")]
use std::fmt;
const USIZE_BITS: usize = 0_usize.count_zeros() as usize;
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(
any(test, feature = "proptest"),
proptest(no_params),
derive(Arbitrary)
)]
pub struct Index(usize);
impl Index {
pub const fn max_size() -> usize {
1_usize << (USIZE_BITS - 1)
}
pub const fn size_at_depth(depth: usize) -> usize {
1_usize << depth
}
pub fn depth_for_size(size: usize) -> usize {
size.next_power_of_two().trailing_zeros() as usize
}
pub const fn root() -> Self {
Self(1)
}
pub const fn from_index(index: usize) -> Self {
Self(index + 1)
}
pub fn iter_layer(depth: usize) -> LayerIter {
LayerIter {
size: 1_usize << depth,
offset: 0,
}
}
pub fn layer_range(depth: usize) -> RangeInclusive<usize> {
let start = Self::from_depth_offset(depth, 0).unwrap();
let end = Self::from_depth_offset(depth, Self::size_at_depth(depth) - 1).unwrap();
start.as_index()..=end.as_index()
}
pub fn from_size_offset(size: usize, offset: usize) -> Result<Self> {
require!(size.is_power_of_two(), Error::NumLeavesNotPowerOfTwo);
require!(size <= Self::max_size(), Error::TreeToLarge);
require!(offset < size, Error::IndexOutOfRange);
Ok(Self(size | offset))
}
pub fn from_depth_offset(depth: usize, offset: usize) -> Result<Self> {
Self::from_size_offset(Self::size_at_depth(depth), offset)
}
pub fn as_index(self) -> usize {
self.0 - 1
}
pub fn depth(self) -> usize {
(0_usize.count_zeros() - self.0.leading_zeros() - 1) as usize
}
pub fn offset(self) -> usize {
self.0 - (1_usize << self.depth())
}
pub fn is_root(self) -> bool {
self.0 == 1
}
pub fn is_left(self) -> bool {
self.0 % 2 == 0
}
pub fn is_right(self) -> bool {
self.0 != 1 && self.0 % 2 == 1
}
pub fn is_left_most(self) -> bool {
self.0.is_power_of_two()
}
pub fn is_right_most(self) -> bool {
(self.0 + 1).is_power_of_two()
}
pub fn parent(self) -> Option<Self> {
if self.is_root() {
None
} else {
Some(Self(self.0 >> 1))
}
}
pub fn sibling(self) -> Option<Self> {
if self.is_root() {
None
} else {
Some(Self(self.0 ^ 1))
}
}
pub fn left_neighbor(self) -> Option<Self> {
if self.is_left_most() {
None
} else {
Some(Self(self.0 - 1))
}
}
pub fn right_neighbor(self) -> Option<Self> {
if self.is_right_most() {
None
} else {
Some(Self(self.0 + 1))
}
}
pub fn left_child(self) -> Self {
Self(self.0 << 1)
}
pub fn right_child(self) -> Self {
Self((self.0 << 1) | 1)
}
pub fn last_common_ancestor(self, other: Self) -> Self {
let a = self.0 << self.0.leading_zeros();
let b = other.0 << other.0.leading_zeros();
let prefix_length = (a ^ b).leading_zeros();
let prefix = a >> (0_usize.count_zeros() - prefix_length);
Self(prefix)
}
pub fn concat(self, other: Self) -> Self {
let other_depth = other.depth();
let other_path = other.0 ^ 1_usize << other_depth;
Self(self.0 << other_depth | other_path)
}
}
#[cfg(feature = "std")]
impl fmt::Debug for Index {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Index({:}, {:})", self.depth(), self.offset())
}
}
#[cfg_attr(feature = "std", derive(Debug))]
pub struct LayerIter {
size: usize,
offset: usize,
}
impl Iterator for LayerIter {
type Item = Index;
fn next(&mut self) -> Option<Self::Item> {
if self.offset < self.size {
let item = Index::from_size_offset(self.size, self.offset).unwrap();
self.offset += 1;
Some(item)
} else {
None
}
}
}
#[cfg(test)]
mod test {
use super::*;
use proptest::prelude::*;
proptest!(
#[test]
fn test_depth_offset_roundtrip(depth: usize, offset: usize) {
let depth = depth % (Index::max_size().trailing_zeros() as usize);
let offset = offset % Index::size_at_depth(depth);
let index = Index::from_size_offset(1_usize << depth, offset).unwrap();
prop_assert_eq!(index.depth(), depth);
prop_assert_eq!(index.offset(), offset);
}
#[test]
fn test_children(parent: Index) {
prop_assume!(parent.depth() < (0_usize.count_zeros() - 1) as usize);
let left = parent.left_child();
let right = parent.right_child();
prop_assert!(left.is_left());
prop_assert!(right.is_right());
prop_assert_eq!(left.depth(), right.depth());
prop_assert_eq!(left.depth(), parent.depth() + 1);
prop_assert_eq!(left.offset() + 1, right.offset());
prop_assert_eq!(left.parent().unwrap(), parent);
prop_assert_eq!(right.parent().unwrap(), parent);
prop_assert_eq!(left.right_neighbor().unwrap(), right);
prop_assert_eq!(right.left_neighbor().unwrap(), left);
prop_assert_eq!(left.sibling().unwrap(), right);
prop_assert_eq!(right.sibling().unwrap(), left);
}
);
}