use serde::{Deserialize, Serialize};
use std::{collections::BTreeMap, fmt::Debug};
use thiserror::Error;
use crate::error::LibraryError;
use super::{
sorted_iter::sorted_iter,
tree::{ABinaryTree, ABinaryTreeError},
treemath::{
copath, direct_path, left, lowest_common_ancestor, right, root, LeafNodeIndex,
ParentNodeIndex, TreeNodeIndex, TreeSize, MAX_TREE_SIZE, MIN_TREE_SIZE,
},
};
/// A diff that has been detached from the tree it was computed against.
///
/// Only the replaced leaves, the replaced parents and the tree size are kept; reading the
/// complete node set requires passing the original tree back in (see `leaves`/`parents`).
#[derive(Debug, Serialize, Deserialize)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Clone, PartialEq))]
pub(crate) struct StagedAbDiff<L, P>
where
    L: Clone + Debug + Default,
    P: Clone + Debug + Default,
{
    /// Leaves replaced by the diff, keyed by leaf index.
    leaf_diff: BTreeMap<LeafNodeIndex, L>,
    /// Parent nodes replaced by the diff, keyed by parent index.
    parent_diff: BTreeMap<ParentNodeIndex, P>,
    /// Tree size recorded in the diff at the time it was staged.
    size: TreeSize,
}
impl<'a, L: Clone + Debug + Default, P: Clone + Debug + Default> From<AbDiff<'a, L, P>>
for StagedAbDiff<L, P>
{
fn from(diff: AbDiff<'a, L, P>) -> Self {
StagedAbDiff {
leaf_diff: diff.leaf_diff,
parent_diff: diff.parent_diff,
size: diff.size,
}
}
}
impl<L, P> StagedAbDiff<L, P>
where
    L: Clone + Debug + Default,
    P: Clone + Debug + Default,
{
    /// Consumes the staged diff and hands back its `(leaf changes, parent changes)` maps.
    pub(super) fn into_diffs(self) -> (BTreeMap<LeafNodeIndex, L>, BTreeMap<ParentNodeIndex, P>) {
        let Self {
            leaf_diff,
            parent_diff,
            ..
        } = self;
        (leaf_diff, parent_diff)
    }

    /// Size of the tree as recorded in this staged diff.
    pub(super) fn tree_size(&self) -> TreeSize {
        self.size
    }

    /// Iterates over the leaves of `original_tree` with this diff's leaf changes merged in,
    /// bounded by the staged tree's leaf count.
    pub(crate) fn leaves<'a>(
        &'a self,
        original_tree: &'a ABinaryTree<L, P>,
    ) -> impl Iterator<Item = (LeafNodeIndex, &'a L)> {
        let leaf_count = self.size.leaf_count() as usize;
        AbDiff::raw_leaves(original_tree, &self.leaf_diff, leaf_count)
    }

    /// Iterates over the parents of `original_tree` with this diff's parent changes merged
    /// in, bounded by the staged tree's parent count.
    pub(crate) fn parents<'a>(
        &'a self,
        original_tree: &'a ABinaryTree<L, P>,
    ) -> impl Iterator<Item = (ParentNodeIndex, &'a P)> {
        let parent_count = self.size.parent_count() as usize;
        AbDiff::raw_parents(original_tree, &self.parent_diff, parent_count)
    }
}
/// A set of pending changes layered on top of a borrowed `ABinaryTree`.
///
/// Reads consult the diff first and fall back to the original tree; writes only ever touch
/// the diff maps, so the original tree is never modified.
pub(crate) struct AbDiff<'a, L, P>
where
    L: Clone + Debug + Default,
    P: Clone + Debug + Default,
{
    /// The tree this diff was started from.
    original_tree: &'a ABinaryTree<L, P>,
    /// Leaves replaced by the diff, keyed by leaf index.
    leaf_diff: BTreeMap<LeafNodeIndex, L>,
    /// Parent nodes replaced by the diff, keyed by parent index.
    parent_diff: BTreeMap<ParentNodeIndex, P>,
    /// Current size of the tree described by the diff (changed by grow/shrink).
    size: TreeSize,
    /// `L::default()`, kept so a reference to it can be returned for out-of-range leaves.
    default_leaf: L,
    /// `P::default()`, kept so a reference to a default parent can be returned.
    default_parent: P,
}
impl<'a, L: Clone + Debug + Default, P: Clone + Debug + Default> From<&'a ABinaryTree<L, P>>
for AbDiff<'a, L, P>
{
fn from(tree: &'a ABinaryTree<L, P>) -> AbDiff<'a, L, P> {
AbDiff {
original_tree: tree,
leaf_diff: BTreeMap::new(),
parent_diff: BTreeMap::new(),
size: tree.tree_size(),
default_leaf: L::default(),
default_parent: P::default(),
}
}
}
impl<L: Clone + Debug + Default, P: Clone + Debug + Default> AbDiff<'_, L, P> {
    /// Grows the tree described by the diff by one step (`TreeSize::inc`).
    ///
    /// # Errors
    ///
    /// Returns [`ABinaryTreeDiffError::TreeTooLarge`] if the current size is greater than
    /// `MAX_TREE_SIZE / 2`.
    pub(crate) fn grow_tree(&mut self) -> Result<(), ABinaryTreeDiffError> {
        // NOTE(review): the `MAX_TREE_SIZE / 2` bound suggests `inc` doubles the size;
        // confirm against `treemath::TreeSize`.
        if self.size().u32() > MAX_TREE_SIZE / 2 {
            return Err(ABinaryTreeDiffError::TreeTooLarge);
        }
        self.size.inc();
        Ok(())
    }

    /// Shrinks the tree described by the diff by one step (`TreeSize::dec`) and discards
    /// every diff entry that is out of bounds for the new size.
    ///
    /// # Errors
    ///
    /// Returns [`ABinaryTreeDiffError::TreeTooSmall`] if the size is already at or below
    /// `MIN_TREE_SIZE`.
    pub(crate) fn shrink_tree(&mut self) -> Result<(), ABinaryTreeDiffError> {
        if self.size().u32() <= MIN_TREE_SIZE {
            return Err(ABinaryTreeDiffError::TreeTooSmall);
        }
        self.size.dec();
        // Nodes beyond the new bounds no longer exist, so their pending changes must go.
        let leaf_count = self.size.leaf_count();
        let parent_count = self.size.parent_count();
        self.leaf_diff.retain(|&index, _| index.u32() < leaf_count);
        self.parent_diff.retain(|&index, _| index.u32() < parent_count);
        Ok(())
    }

    /// Records `new_leaf` as the leaf at `leaf_index` in the diff.
    ///
    /// The index is only range-checked in debug builds.
    pub(crate) fn replace_leaf(&mut self, leaf_index: LeafNodeIndex, new_leaf: L) {
        debug_assert!(leaf_index.u32() < self.leaf_count());
        self.leaf_diff.insert(leaf_index, new_leaf);
    }

    /// Records `node` as the parent at `parent_index` in the diff.
    ///
    /// The index is only range-checked in debug builds.
    pub(crate) fn replace_parent(&mut self, parent_index: ParentNodeIndex, node: P) {
        debug_assert!(parent_index.u32() < self.parent_count());
        self.parent_diff.insert(parent_index, node);
    }

    /// Merges the diff's parents with the original tree's parents via `sorted_iter`.
    ///
    /// NOTE(review): `parent_count` is forwarded to `sorted_iter`, presumably as the number
    /// of items to yield, and diff entries presumably take precedence over tree entries
    /// with the same index — confirm in `sorted_iter`.
    fn raw_parents<'a>(
        original_tree: &'a ABinaryTree<L, P>,
        parent_diff: &'a BTreeMap<ParentNodeIndex, P>,
        parent_count: usize,
    ) -> impl Iterator<Item = (ParentNodeIndex, &'a P)> {
        let diff_parents = parent_diff.iter().map(|(index, parent)| (*index, parent));
        let a_iter = Box::new(diff_parents) as Box<dyn Iterator<Item = (ParentNodeIndex, &P)>>;
        let b_iter =
            Box::new(original_tree.parents()) as Box<dyn Iterator<Item = (ParentNodeIndex, &P)>>;
        let cmp = |&(x, _): &(ParentNodeIndex, &P)| x;
        sorted_iter(a_iter, b_iter, cmp, parent_count)
    }

    /// Merges the diff's leaves with the original tree's leaves via `sorted_iter`.
    ///
    /// NOTE(review): `leaf_count` is forwarded to `sorted_iter`, presumably as the number of
    /// items to yield, and diff entries presumably take precedence over tree entries with
    /// the same index — confirm in `sorted_iter`.
    fn raw_leaves<'a>(
        original_tree: &'a ABinaryTree<L, P>,
        leaf_diff: &'a BTreeMap<LeafNodeIndex, L>,
        leaf_count: usize,
    ) -> impl Iterator<Item = (LeafNodeIndex, &'a L)> {
        let diff_leaves = leaf_diff.iter().map(|(index, leaf)| (*index, leaf));
        let a_iter = Box::new(diff_leaves) as Box<dyn Iterator<Item = (LeafNodeIndex, &L)>>;
        let b_iter =
            Box::new(original_tree.leaves()) as Box<dyn Iterator<Item = (LeafNodeIndex, &L)>>;
        let cmp = |&(x, _): &(LeafNodeIndex, &L)| x;
        sorted_iter(a_iter, b_iter, cmp, leaf_count)
    }

    /// Iterates over the leaves of the tree as modified by this diff.
    pub(crate) fn leaves(&self) -> impl Iterator<Item = (LeafNodeIndex, &L)> {
        Self::raw_leaves(
            self.original_tree,
            &self.leaf_diff,
            self.leaf_count() as usize,
        )
    }

    /// Iterates over the parents of the tree as modified by this diff.
    pub(crate) fn parents(&self) -> impl Iterator<Item = (ParentNodeIndex, &P)> {
        Self::raw_parents(
            self.original_tree,
            &self.parent_diff,
            self.parent_count() as usize,
        )
    }

    /// Parent indices on the direct path of `leaf_index` for the diff's current size.
    pub(crate) fn direct_path(&self, leaf_index: LeafNodeIndex) -> Vec<ParentNodeIndex> {
        direct_path(leaf_index, self.size())
    }

    /// Replaces every parent on the direct path of `leaf_index` with a clone of `node`.
    pub(crate) fn set_direct_path_to_node(&mut self, leaf_index: LeafNodeIndex, node: &P) {
        // The path is an owned `Vec`, so `self` is free to be mutated inside the loop.
        for node_index in self.direct_path(leaf_index) {
            self.replace_parent(node_index, node.clone());
        }
    }

    /// Node indices on the copath of `leaf_index` for the diff's current size.
    pub(crate) fn copath(&self, leaf_index: LeafNodeIndex) -> Vec<TreeNodeIndex> {
        copath(leaf_index, self.size())
    }

    /// Lowest common ancestor of two distinct, in-range leaves.
    ///
    /// Distinctness and range are only checked in debug builds.
    pub(crate) fn lowest_common_ancestor(
        &self,
        leaf_index_1: LeafNodeIndex,
        leaf_index_2: LeafNodeIndex,
    ) -> ParentNodeIndex {
        debug_assert!(leaf_index_1 != leaf_index_2);
        debug_assert!(leaf_index_1.u32() < self.leaf_count());
        debug_assert!(leaf_index_2.u32() < self.leaf_count());
        lowest_common_ancestor(leaf_index_1, leaf_index_2)
    }

    /// Child of the two leaves' lowest common ancestor on the side of `leaf_index_2`.
    ///
    /// That is the left child if `leaf_index_2 < leaf_index_1`, otherwise the right child.
    pub(crate) fn subtree_root_copath_node(
        &self,
        leaf_index_1: LeafNodeIndex,
        leaf_index_2: LeafNodeIndex,
    ) -> TreeNodeIndex {
        let subtree_root_node_index = self.lowest_common_ancestor(leaf_index_1, leaf_index_2);
        if leaf_index_2 < leaf_index_1 {
            left(subtree_root_node_index)
        } else {
            right(subtree_root_node_index)
        }
    }

    /// Number of leaves for the diff's current size.
    pub(crate) fn leaf_count(&self) -> u32 {
        self.size.leaf_count()
    }

    /// Number of parent nodes for the diff's current size.
    pub(crate) fn parent_count(&self) -> u32 {
        self.size.parent_count()
    }

    /// Current size of the tree described by the diff.
    pub(crate) fn size(&self) -> TreeSize {
        self.size
    }

    /// Index of the root node for the diff's current size.
    pub(crate) fn root(&self) -> TreeNodeIndex {
        root(self.size())
    }

    /// Left child of `node_index`.
    pub(crate) fn left_child(&self, node_index: ParentNodeIndex) -> TreeNodeIndex {
        left(node_index)
    }

    /// Right child of `node_index`.
    pub(crate) fn right_child(&self, node_index: ParentNodeIndex) -> TreeNodeIndex {
        right(node_index)
    }

    /// Returns the leaf at `leaf_index` as seen through the diff.
    ///
    /// Diff entries win. Indices at or beyond the diff's current leaf count yield the
    /// default leaf; otherwise the original tree is consulted.
    pub(crate) fn leaf(&self, leaf_index: LeafNodeIndex) -> &L {
        if let Some(node) = self.leaf_diff.get(&leaf_index) {
            node
        } else if leaf_index.u32() >= self.leaf_count() {
            &self.default_leaf
        } else {
            self.original_tree.leaf_by_index(leaf_index)
        }
    }

    /// Returns the parent at `parent_index` as seen through the diff.
    ///
    /// Diff entries win. Indices at or beyond the diff's current parent count yield the
    /// default parent (mirroring `leaf`), so nodes cut off by `shrink_tree` are not read back
    /// from the original tree; otherwise the original tree is consulted.
    pub(crate) fn parent(&self, parent_index: ParentNodeIndex) -> &P {
        if let Some(node) = self.parent_diff.get(&parent_index) {
            node
        } else if parent_index.u32() >= self.parent_count() {
            &self.default_parent
        } else {
            self.original_tree.parent_by_index(parent_index)
        }
    }

    /// Returns a mutable reference to the parent at `parent_index`.
    ///
    /// If the diff has no entry for the index yet, the original tree's node is cloned into
    /// the diff first, so mutations never touch the original tree. The index is only
    /// range-checked in debug builds.
    pub(crate) fn parent_mut(&mut self, parent_index: ParentNodeIndex) -> &mut P {
        debug_assert!(parent_index.u32() < self.parent_count());
        // Copy the tree reference out of `self` so the closure does not borrow `self`
        // while `parent_diff` is mutably borrowed.
        let original_tree = self.original_tree;
        // A single entry lookup; the reference always points at a real diff entry and never
        // at the shared default node.
        self.parent_diff
            .entry(parent_index)
            .or_insert_with(|| original_tree.parent_by_index(parent_index).clone())
    }

    /// Looks up every index in `parent_index_vec` via `parent`, preserving order.
    #[cfg(test)]
    pub(crate) fn deref_vec(
        &self,
        parent_index_vec: Vec<ParentNodeIndex>,
    ) -> Result<Vec<&P>, ABinaryTreeDiffError> {
        Ok(parent_index_vec
            .into_iter()
            .map(|parent_index| self.parent(parent_index))
            .collect())
    }
}
#[derive(Error, Debug, PartialEq, Clone)]
pub(crate) enum ABinaryTreeDiffError {
#[error(transparent)]
LibraryError(#[from] LibraryError),
#[error("Maximum tree size reached.")]
TreeTooLarge,
#[error("Minimum tree size reached.")]
TreeTooSmall,
#[error(transparent)]
ABinaryTreeError(#[from] ABinaryTreeError),
}