use std::fmt::Debug;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use super::{
diff::{AbDiff, StagedAbDiff},
treemath::{common_direct_path, LeafNodeIndex, ParentNodeIndex, TreeSize, MAX_TREE_SIZE},
};
/// A single node of the tree in its flat (array-order) form.
///
/// The payload of either kind is boxed, so a `TreeNode` is always the size of
/// one pointer plus the discriminant, regardless of how large `L` or `P` are.
#[derive(Clone, Debug)]
pub(crate) enum TreeNode<L: Clone + Debug + Default, P: Clone + Debug + Default> {
    /// A leaf node carrying an `L` payload.
    Leaf(Box<L>),
    /// A parent (internal) node carrying a `P` payload.
    Parent(Box<P>),
}

// Convenience constructors that hide the boxing; only compiled for tests.
#[cfg(test)]
impl<L: Clone + Debug + Default, P: Clone + Debug + Default> TreeNode<L, P> {
    /// Wraps `l` in a [`TreeNode::Leaf`].
    pub(crate) fn leaf(l: L) -> Self {
        let boxed = Box::new(l);
        TreeNode::Leaf(boxed)
    }

    /// Wraps `p` in a [`TreeNode::Parent`].
    pub(crate) fn parent(p: P) -> Self {
        let boxed = Box::new(p);
        TreeNode::Parent(boxed)
    }
}
/// An array-backed binary tree whose leaves and parents live in two separate
/// vectors.
///
/// `default_leaf` / `default_parent` hold `Default::default()` values that the
/// accessors hand out by reference when asked for an index past the end of
/// the corresponding vector.
///
/// NOTE: keep the field declaration order stable. The derived serde impls
/// serialize fields in this order.
#[cfg_attr(any(test, feature = "test-utils"), derive(PartialEq))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub(crate) struct ABinaryTree<L, P>
where
    L: Clone + Debug + Default,
    P: Clone + Debug + Default,
{
    /// Leaf payloads, indexed by leaf index.
    leaf_nodes: Vec<L>,
    /// Parent payloads, indexed by parent index.
    parent_nodes: Vec<P>,
    /// Fallback returned for out-of-range leaf lookups.
    default_leaf: L,
    /// Fallback returned for out-of-range parent lookups.
    default_parent: P,
}
impl<L: Clone + Debug + Default, P: Clone + Debug + Default> ABinaryTree<L, P> {
    /// Builds a tree from a flat list of nodes in array order.
    ///
    /// Nodes at even positions must be leaves and nodes at odd positions must
    /// be parents, so a valid list has an odd length.
    ///
    /// # Errors
    ///
    /// * [`ABinaryTreeError::OutOfRange`] if `nodes` has more than
    ///   `MAX_TREE_SIZE` entries.
    /// * [`ABinaryTreeError::InvalidNumberOfNodes`] if `nodes` has an even
    ///   length (including zero).
    /// * [`ABinaryTreeError::WrongNodeType`] if a leaf appears at an odd
    ///   position or a parent at an even one.
    pub(crate) fn new(nodes: Vec<TreeNode<L, P>>) -> Result<Self, ABinaryTreeError> {
        if nodes.len() > MAX_TREE_SIZE as usize {
            return Err(ABinaryTreeError::OutOfRange);
        }
        if nodes.len() % 2 != 1 {
            return Err(ABinaryTreeError::InvalidNumberOfNodes);
        }

        // An odd total `n` splits into `n / 2 + 1` leaves and `n / 2` parents.
        let mut leaf_nodes = Vec::with_capacity(nodes.len() / 2 + 1);
        let mut parent_nodes = Vec::with_capacity(nodes.len() / 2);
        for (i, node) in nodes.into_iter().enumerate() {
            let at_leaf_position = i % 2 == 0;
            match (node, at_leaf_position) {
                (TreeNode::Leaf(leaf), true) => leaf_nodes.push(*leaf),
                (TreeNode::Parent(parent), false) => parent_nodes.push(*parent),
                _ => return Err(ABinaryTreeError::WrongNodeType),
            }
        }

        Ok(ABinaryTree {
            leaf_nodes,
            parent_nodes,
            default_leaf: L::default(),
            default_parent: P::default(),
        })
    }

    /// Builds a tree from already-separated leaf and parent vectors.
    ///
    /// # Errors
    ///
    /// * [`ABinaryTreeError::OutOfRange`] if the combined node count exceeds
    ///   `MAX_TREE_SIZE`.
    /// * [`ABinaryTreeError::InvalidNumberOfNodes`] unless there is exactly one
    ///   more leaf than there are parents.
    pub(crate) fn from_components(
        leaf_nodes: Vec<L>,
        parent_nodes: Vec<P>,
    ) -> Result<Self, ABinaryTreeError> {
        let total_nodes = leaf_nodes.len() + parent_nodes.len();
        if total_nodes > MAX_TREE_SIZE as usize {
            return Err(ABinaryTreeError::OutOfRange);
        }
        if leaf_nodes.len() != parent_nodes.len() + 1 {
            return Err(ABinaryTreeError::InvalidNumberOfNodes);
        }
        Ok(ABinaryTree {
            leaf_nodes,
            parent_nodes,
            default_leaf: L::default(),
            default_parent: P::default(),
        })
    }

    /// Module-internal alias of [`Self::leaf`].
    pub(in crate::binary_tree) fn leaf_by_index(&self, leaf_index: LeafNodeIndex) -> &L {
        self.leaf(leaf_index)
    }

    /// Alias of [`Self::parent`].
    pub(crate) fn parent_by_index(&self, parent_index: ParentNodeIndex) -> &P {
        self.parent(parent_index)
    }

    /// Total number of nodes (leaves plus parents).
    pub(crate) fn tree_size(&self) -> TreeSize {
        TreeSize::new((self.leaf_nodes.len() + self.parent_nodes.len()) as u32)
    }

    /// Number of stored leaves.
    pub(crate) fn leaf_count(&self) -> u32 {
        self.leaf_nodes.len() as u32
    }

    /// Number of stored parents.
    pub(crate) fn parent_count(&self) -> u32 {
        self.parent_nodes.len() as u32
    }

    /// Iterates over all leaves together with their leaf indices, in index order.
    pub(crate) fn leaves(&self) -> impl Iterator<Item = (LeafNodeIndex, &L)> {
        self.leaf_nodes
            .iter()
            .enumerate()
            .map(|(index, leaf)| (LeafNodeIndex::new(index as u32), leaf))
    }

    /// Iterates over all parents together with their parent indices, in index order.
    pub(crate) fn parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &P)> {
        self.parent_nodes
            .iter()
            .enumerate()
            .map(|(index, parent)| (ParentNodeIndex::new(index as u32), parent))
    }

    /// Creates an empty diff on top of this tree via the
    /// `From<&ABinaryTree<L, P>>` conversion for `AbDiff` (defined elsewhere).
    pub(crate) fn empty_diff(&self) -> AbDiff<'_, L, P> {
        self.into()
    }

    /// Applies a staged diff to this tree in place.
    ///
    /// The leaf and parent vectors are first resized to the diff's tree size:
    /// they are truncated if the tree shrank, and padded with `Default` values
    /// if it grew. Every node recorded in the diff then overwrites the node at
    /// its index.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if the diff contains an index outside the
    /// resized tree. Release builds skip such entries.
    pub(crate) fn merge_diff(&mut self, diff: StagedAbDiff<L, P>) {
        let tree_size = diff.tree_size();
        let (leaf_diff, parent_diff) = diff.into_diffs();

        self.leaf_nodes
            .resize_with(tree_size.leaf_count() as usize, Default::default);
        self.parent_nodes
            .resize_with(tree_size.parent_count() as usize, Default::default);

        for (leaf_index, diff_leaf) in leaf_diff.into_iter() {
            debug_assert!(leaf_index.u32() < self.leaf_count());
            // An out-of-range index is a bug in the diff. The assertion above
            // catches it in debug builds; release builds skip the entry.
            if let Some(slot) = self.leaf_nodes.get_mut(leaf_index.usize()) {
                *slot = diff_leaf;
            }
        }
        for (parent_index, diff_parent) in parent_diff.into_iter() {
            debug_assert!(parent_index.u32() < self.parent_count());
            if let Some(slot) = self.parent_nodes.get_mut(parent_index.usize()) {
                *slot = diff_parent;
            }
        }
    }

    /// Returns the leaf at `leaf_index`, or a reference to a default leaf if
    /// the index is past the end of the tree.
    pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> &L {
        self.leaf_nodes
            .get(leaf_index.usize())
            .unwrap_or(&self.default_leaf)
    }

    /// Returns the parent indices computed by `treemath::common_direct_path`
    /// for the two given leaves in a tree of this tree's current size.
    pub(crate) fn subtree_path(
        &self,
        leaf_index_1: LeafNodeIndex,
        leaf_index_2: LeafNodeIndex,
    ) -> Vec<ParentNodeIndex> {
        common_direct_path(leaf_index_1, leaf_index_2, self.tree_size())
    }

    /// Returns the parent at `parent_index`, or a reference to a default
    /// parent if the index is past the end of the tree.
    pub(crate) fn parent(&self, parent_index: ParentNodeIndex) -> &P {
        self.parent_nodes
            .get(parent_index.usize())
            .unwrap_or(&self.default_parent)
    }
}
#[derive(Error, Debug, PartialEq, Clone)]
pub(crate) enum ABinaryTreeError {
#[error("Adding nodes exceeds the maximum possible size of the tree.")]
OutOfRange,
#[error("Not enough nodes to remove.")]
InvalidNumberOfNodes,
#[error("Wrong node type.")]
WrongNodeType,
}