use core::{
fmt::{Binary, Display},
ops::{BitAnd, BitOr, BitXor, BitXorAssign},
};
use super::{InOrderIndex, MmrError};
use crate::{
Felt,
utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable},
};
/// Shape of a Merkle Mountain Range, stored as its number of leaves.
///
/// Each set bit `k` of the leaf count corresponds to one perfect binary tree with `2^k` leaves,
/// so a single integer encodes both the total leaf count and the set of trees.
///
/// With the `serde` feature the value is serialized transparently as its leaf count. This matches
/// the hand-written `Deserialize` impl, which reads a bare `usize` and validates it.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize), serde(transparent))]
pub struct Forest(usize);
impl Forest {
    /// Maximum number of leaves a forest can hold.
    ///
    /// This is `u32::MAX` when that is below `usize::MAX / 2` (e.g. on 64-bit targets), and
    /// `usize::MAX / 2` otherwise. A compile-time assertion in this module checks that the value
    /// has the form `2^k - 1`. Consequently:
    /// - bit-wise OR/XOR of two valid forests is again a valid forest, and
    /// - `2 * MAX_LEAVES` (an upper bound on the node count) never overflows `usize`.
    pub const MAX_LEAVES: usize = if (u32::MAX as usize) < (usize::MAX / 2) {
        u32::MAX as usize
    } else {
        usize::MAX / 2
    };

    /// Returns a forest with no leaves (and therefore no trees).
    pub const fn empty() -> Self {
        Self(0)
    }

    /// Creates a forest holding `num_leaves` leaves.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `num_leaves` exceeds
    /// [`Self::MAX_LEAVES`].
    pub fn new(num_leaves: usize) -> Result<Self, DeserializationError> {
        if !Self::is_valid_size(num_leaves) {
            return Err(DeserializationError::InvalidValue(format!(
                "forest size {} exceeds maximum {}",
                num_leaves,
                Self::MAX_LEAVES
            )));
        }
        Ok(Self(num_leaves))
    }

    /// Returns a forest consisting of a single perfect tree with `2^height` leaves.
    ///
    /// # Panics
    /// Panics if `height >= usize::BITS`, or if `2^height` exceeds [`Self::MAX_LEAVES`].
    pub fn with_height(height: usize) -> Self {
        assert!(height < usize::BITS as usize);
        Self::new(1 << height).expect("forest height exceeds maximum")
    }

    /// Returns `true` if a forest may hold `num_leaves` leaves.
    pub const fn is_valid_size(num_leaves: usize) -> bool {
        num_leaves <= Self::MAX_LEAVES
    }

    /// Returns `true` if the forest has no leaves.
    pub fn is_empty(self) -> bool {
        self.0 == 0
    }

    /// Adds a single leaf to the forest.
    ///
    /// # Errors
    /// Returns [`MmrError::ForestSizeExceeded`] if the forest already holds
    /// [`Self::MAX_LEAVES`] leaves. The forest is left unchanged in that case.
    pub fn append_leaf(&mut self) -> Result<(), MmrError> {
        if self.0 >= Self::MAX_LEAVES {
            return Err(MmrError::ForestSizeExceeded {
                requested: self.0.saturating_add(1),
                max: Self::MAX_LEAVES,
            });
        }
        self.0 += 1;
        Ok(())
    }

    /// Returns the number of leaves in the forest.
    pub fn num_leaves(self) -> usize {
        self.0
    }

    /// Returns the total number of nodes (leaves plus internal nodes) across all trees.
    ///
    /// A perfect tree with `n` leaves has `2n - 1` nodes. Summed over all trees, the count is
    /// therefore `2 * num_leaves - num_trees`.
    pub const fn num_nodes(self) -> usize {
        assert!(self.0 <= Self::MAX_LEAVES);
        // `MAX_LEAVES <= usize::MAX / 2`, so doubling a valid leaf count cannot overflow and no
        // widening fallback is needed.
        self.0 * 2 - self.num_trees()
    }

    /// Returns the number of trees in the forest: one per set bit of the leaf count.
    pub const fn num_trees(self) -> usize {
        self.0.count_ones() as usize
    }

    /// Returns the height of the largest tree.
    ///
    /// # Panics
    /// Panics if the forest is empty.
    pub fn largest_tree_height_unchecked(self) -> usize {
        self.0.ilog2() as usize
    }

    /// Returns the height of the largest tree, or `None` if the forest is empty.
    pub fn largest_tree_height(self) -> Option<usize> {
        if self.is_empty() {
            return None;
        }
        Some(self.largest_tree_height_unchecked())
    }

    /// Returns a single-tree forest equal to the largest tree of `self`.
    ///
    /// # Panics
    /// Panics if the forest is empty.
    pub fn largest_tree_unchecked(self) -> Self {
        Self::with_height(self.largest_tree_height_unchecked())
    }

    /// Returns a single-tree forest equal to the largest tree of `self`.
    /// Returns an empty forest if `self` is empty.
    pub fn largest_tree(self) -> Self {
        if self.is_empty() {
            return Self::empty();
        }
        self.largest_tree_unchecked()
    }

    /// Returns the height of the smallest tree.
    ///
    /// For an empty forest this returns `usize::BITS`, which is not a meaningful height.
    pub fn smallest_tree_height_unchecked(self) -> usize {
        self.0.trailing_zeros() as usize
    }

    /// Returns the height of the smallest tree, or `None` if the forest is empty.
    pub fn smallest_tree_height(self) -> Option<usize> {
        if self.is_empty() {
            return None;
        }
        Some(self.smallest_tree_height_unchecked())
    }

    /// Returns a single-tree forest equal to the smallest tree of `self`.
    ///
    /// # Panics
    /// Panics if the forest is empty.
    pub fn smallest_tree_unchecked(self) -> Self {
        Self::with_height(self.smallest_tree_height_unchecked())
    }

    /// Returns a single-tree forest equal to the smallest tree of `self`.
    /// Returns an empty forest if `self` is empty.
    pub fn smallest_tree(self) -> Self {
        if self.is_empty() {
            return Self::empty();
        }
        self.smallest_tree_unchecked()
    }

    /// Returns the forest made of only those trees whose height is strictly greater than
    /// `tree_idx`.
    pub fn trees_larger_than(self, tree_idx: u32) -> Self {
        // `saturating_add` avoids overflow for `tree_idx == u32::MAX`. `high_bitmask` maps any bit
        // index >= `usize::BITS` to an empty mask, which yields an empty forest.
        let mask = high_bitmask(tree_idx.saturating_add(1));
        Self::new(self.0 & mask).expect("forest size exceeds maximum")
    }

    /// For a forest that is a single tree of `2^k` leaves, returns the forest holding one tree of
    /// every smaller height (`2^k - 1` leaves).
    ///
    /// The single-tree precondition is checked only in debug builds.
    ///
    /// # Panics
    /// Panics if the forest is empty.
    pub fn all_smaller_trees_unchecked(self) -> Self {
        debug_assert_eq!(self.num_trees(), 1);
        Self::new(self.0 - 1).expect("forest size exceeds maximum")
    }

    /// Like [`Self::all_smaller_trees_unchecked`], but returns `None` unless `self` consists of
    /// exactly one tree.
    pub fn all_smaller_trees(self) -> Option<Forest> {
        if self.num_trees() != 1 {
            return None;
        }
        Some(self.all_smaller_trees_unchecked())
    }

    /// For a single-tree forest, returns the single tree one level taller (twice the leaves).
    ///
    /// The single-tree precondition is checked only in debug builds.
    ///
    /// # Errors
    /// Returns [`MmrError::ForestSizeExceeded`] if the doubled size exceeds
    /// [`Self::MAX_LEAVES`].
    pub(crate) fn next_larger_tree(self) -> Result<Self, MmrError> {
        debug_assert_eq!(self.num_trees(), 1);
        let value = self.0.saturating_mul(2);
        if value > Self::MAX_LEAVES {
            return Err(MmrError::ForestSizeExceeded { requested: value, max: Self::MAX_LEAVES });
        }
        Ok(Forest(value))
    }

    /// Returns `true` if the forest contains a tree with exactly one leaf (bit 0 is set).
    pub fn has_single_leaf_tree(self) -> bool {
        self.0 & 1 != 0
    }

    /// Returns `self` with the single-leaf tree added (bit 0 set).
    pub fn with_single_leaf(self) -> Self {
        Self(self.0 | 1)
    }

    /// Returns `self` with the single-leaf tree removed (bit 0 cleared).
    pub fn without_single_leaf(self) -> Self {
        Self(self.0 & (usize::MAX - 1))
    }

    /// Returns `self` with every tree that is present in `other` removed.
    pub fn without_trees(self, other: Forest) -> Self {
        Self(self.0 & !other.0)
    }

    /// Returns the 0-based index of the tree containing `leaf_idx`, counting from the largest
    /// tree (index 0) towards the smallest.
    ///
    /// # Panics
    /// Panics if `leaf_idx` is not a leaf of this forest.
    pub fn tree_index(&self, leaf_idx: usize) -> usize {
        let root = self
            .leaf_to_corresponding_tree(leaf_idx)
            .expect("position must be part of the forest");
        // All trees strictly smaller than the one holding `leaf_idx`.
        let smaller_tree_mask =
            Self::new(2_usize.pow(root) - 1).expect("forest size exceeds maximum");
        let num_smaller_trees = (*self & smaller_tree_mask).num_trees();
        // Every tree that is neither smaller nor the target tree comes before it.
        self.num_trees() - num_smaller_trees - 1
    }

    /// Returns the in-order index of the root of the smallest tree.
    ///
    /// # Panics
    /// Panics if the forest is empty.
    pub fn root_in_order_index(&self) -> InOrderIndex {
        // Checked explicitly so the failure mode doesn't depend on whether overflow checks are
        // enabled for `num_trees() - 1`.
        assert!(!self.is_empty(), "an empty forest has no root in-order index");

        // Total size of all trees in the forest.
        let nodes = self.num_nodes();
        // One separator slot between each pair of adjacent trees.
        let open_trees = self.num_trees() - 1;
        // In an in-order walk the root precedes its right subtree, which holds `n - 1` nodes for
        // a tree with `n` leaves.
        let right_subtree_count = self.smallest_tree_unchecked().num_leaves() - 1;
        let idx = nodes + open_trees - right_subtree_count;
        InOrderIndex::new(idx.try_into().unwrap())
    }

    /// Returns the in-order index of the last node of the forest, i.e. the rightmost leaf of the
    /// smallest tree.
    ///
    /// # Panics
    /// Panics if the forest is empty.
    pub fn rightmost_in_order_index(&self) -> InOrderIndex {
        // Without this check, `num_trees() - 1` wraps in builds without overflow checks and a
        // bogus index would be returned silently.
        assert!(!self.is_empty(), "an empty forest has no rightmost in-order index");

        let nodes = self.num_nodes();
        let open_trees = self.num_trees() - 1;
        let idx = nodes + open_trees;
        InOrderIndex::new(idx.try_into().unwrap())
    }

    /// Returns `true` if `idx` addresses a node of one of the forest's trees.
    ///
    /// Returns `false` for index 0, for a separator slot between two trees, and for any position
    /// past the last tree.
    pub fn is_valid_in_order_index(&self, idx: &InOrderIndex) -> bool {
        if idx.inner() == 0 {
            return false;
        }
        if self.is_empty() {
            return false;
        }
        let idx_val = idx.inner();
        // Trees are laid out largest first. A tree with `n` leaves spans `2n - 1` consecutive
        // positions, followed by one separator slot before the next tree.
        let mut offset = 0usize;
        for tree in TreeSizeIterator::new(*self).rev() {
            let tree_nodes = tree.num_nodes();
            let tree_start = offset + 1;
            let tree_end = offset + tree_nodes;
            if idx_val >= tree_start && idx_val <= tree_end {
                return true;
            }
            offset = tree_end + 1;
        }
        false
    }

    /// Returns the height (bit position) of the tree that contains the 0-based leaf `leaf_idx`.
    /// Returns `None` if `leaf_idx` is not a leaf of this forest.
    pub fn leaf_to_corresponding_tree(self, leaf_idx: usize) -> Option<u32> {
        let forest = self.0;
        if leaf_idx >= forest {
            None
        } else {
            // Trees own consecutive leaf ranges from largest to smallest. The high bits that
            // `leaf_idx` shares with the forest are the trees it has already passed. The highest
            // forest bit it does *not* share identifies the tree that owns it.
            let before = forest & leaf_idx;
            let after = forest ^ before;
            let tree_idx = after.ilog2();
            Some(tree_idx)
        }
    }

    /// Returns the position of `leaf_idx` relative to the first leaf of the tree containing it.
    /// Returns `None` if `leaf_idx` is not a leaf of this forest.
    pub(super) fn leaf_relative_position(self, leaf_idx: usize) -> Option<usize> {
        let tree_idx = self.leaf_to_corresponding_tree(leaf_idx)?;
        // Leaves of all larger trees come first; subtract them to get the in-tree offset.
        let mask = high_bitmask(tree_idx + 1);
        Some(leaf_idx - (self.0 & mask))
    }
}
impl Display for Forest {
    /// Formats the number of leaves in decimal.
    ///
    /// Delegates to `usize`'s implementation so that caller-supplied formatting options
    /// (width, fill, alignment) are honored.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        Display::fmt(&self.0, f)
    }
}
impl Binary for Forest {
    /// Formats the number of leaves in binary; each set bit is one tree of the forest.
    ///
    /// Delegates to `usize`'s implementation so that `{:#b}` (the `0b` prefix), width and
    /// zero-padding requested by the caller are honored.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        Binary::fmt(&self.0, f)
    }
}
impl BitAnd<Forest> for Forest {
type Output = Self;
fn bitand(self, rhs: Self) -> Self::Output {
Self::new(self.0 & rhs.0).expect("forest size exceeds maximum")
}
}
// Some forest operations (e.g. the `BitOr`/`BitXor` impls and `with_single_leaf`) build results
// without re-validating them. That is sound only because the leaf bound has all low bits set
// (`2^k - 1`), so combining valid sizes bit-wise can never exceed it.
const _: () = assert!(Forest::MAX_LEAVES != 0 && (Forest::MAX_LEAVES + 1).is_power_of_two());
impl BitOr<Forest> for Forest {
type Output = Self;
fn bitor(self, rhs: Self) -> Self::Output {
Self(self.0 | rhs.0)
}
}
impl BitXor<Forest> for Forest {
type Output = Self;
fn bitxor(self, rhs: Self) -> Self::Output {
Self(self.0 ^ rhs.0)
}
}
impl BitXorAssign<Forest> for Forest {
fn bitxor_assign(&mut self, rhs: Self) {
self.0 ^= rhs.0;
}
}
impl TryFrom<Felt> for Forest {
    type Error = MmrError;

    /// Interprets a field element's canonical value as a leaf count.
    ///
    /// # Errors
    /// Returns [`MmrError::ForestSizeExceeded`] if the value does not fit in `usize` (reported
    /// as a request for `usize::MAX`) or exceeds [`Forest::MAX_LEAVES`].
    fn try_from(value: Felt) -> Result<Self, Self::Error> {
        let max = Self::MAX_LEAVES;
        let Ok(num_leaves) = usize::try_from(value.as_canonical_u64()) else {
            return Err(MmrError::ForestSizeExceeded { requested: usize::MAX, max });
        };
        if Self::is_valid_size(num_leaves) {
            Ok(Self(num_leaves))
        } else {
            Err(MmrError::ForestSizeExceeded { requested: num_leaves, max })
        }
    }
}
/// Returns a single-tree forest sized by the highest set bit of `mask`.
/// Returns an empty forest when `mask` is zero.
///
/// # Panics
/// Panics if that highest bit corresponds to more than `Forest::MAX_LEAVES` leaves.
pub(crate) fn largest_tree_from_mask(mask: usize) -> Forest {
    match mask.checked_ilog2() {
        None => Forest::empty(),
        Some(top_bit) => Forest::new(1usize << top_bit).expect("forest size exceeds maximum"),
    }
}
impl From<Forest> for Felt {
fn from(value: Forest) -> Self {
Felt::new_unchecked(value.0 as u64)
}
}
/// Returns a mask with every bit at position `bit` and above set.
///
/// Any `bit >= usize::BITS` yields `0`, because no such positions exist.
pub(crate) fn high_bitmask(bit: u32) -> usize {
    // `checked_shl` returns `None` exactly when the shift amount is `>= usize::BITS`.
    usize::MAX.checked_shl(bit).unwrap_or(0)
}
impl Serializable for Forest {
    /// Writes the forest using exactly the encoding of its `usize` leaf count.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let num_leaves = self.num_leaves();
        num_leaves.write_into(target);
    }
}
impl Deserializable for Forest {
    /// Reads a `usize` leaf count and validates it against `Forest::MAX_LEAVES`.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        source.read_usize().and_then(Self::new)
    }
}
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Forest {
    /// Reads a bare `usize` leaf count and rejects values above `Forest::MAX_LEAVES`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let num_leaves: usize = serde::Deserialize::deserialize(deserializer)?;
        match Self::new(num_leaves) {
            Ok(forest) => Ok(forest),
            Err(err) => Err(serde::de::Error::custom(err)),
        }
    }
}
/// Iterator over the individual trees of a [`Forest`], each yielded as a single-tree `Forest`.
///
/// Forward iteration yields trees from smallest to largest. Reverse iteration (via
/// [`DoubleEndedIterator`]) yields them from largest to smallest.
#[derive(Debug, Clone)]
pub struct TreeSizeIterator {
    /// Trees that have not been yielded yet.
    inner: Forest,
}
impl TreeSizeIterator {
pub fn new(value: Forest) -> TreeSizeIterator {
TreeSizeIterator { inner: value }
}
}
impl Iterator for TreeSizeIterator {
    type Item = Forest;

    /// Yields the smallest remaining tree and removes it from the iterator's state.
    fn next(&mut self) -> Option<<Self as Iterator>::Item> {
        let tree = self.inner.smallest_tree();
        if tree.is_empty() {
            None
        } else {
            self.inner = self.inner.without_trees(tree);
            Some(tree)
        }
    }

    /// Returns an exact bound: every call to `next`/`next_back` removes exactly one tree, so the
    /// remaining item count is the number of trees left.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.inner.num_trees();
        (remaining, Some(remaining))
    }
}
impl DoubleEndedIterator for TreeSizeIterator {
    /// Yields the largest remaining tree and removes it from the iterator's state.
    fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
        // A non-empty forest always has a non-empty largest tree, so checking emptiness up front
        // is equivalent to checking the extracted tree.
        if self.inner.is_empty() {
            return None;
        }
        let largest = self.inner.largest_tree();
        self.inner = self.inner.without_trees(largest);
        Some(largest)
    }
}