use core::{fmt, iter};
/// Tree of nodes, each holding a value of type `T`. A node can have any number of children,
/// and the tree can have any number of roots.
pub struct ForkTree<T> {
/// Storage for all the nodes. The key of a node in this slab is the value wrapped by its
/// [`NodeIndex`].
nodes: slab::Slab<Node<T>>,
/// Index of the first root in the list of roots (linked through `next_sibling`), or `None`
/// if there is no root.
first_root: Option<usize>,
}
/// Node of a [`ForkTree`], stored in the tree's slab. Nodes sharing the same parent (or all
/// the roots) form a doubly-linked list through `previous_sibling` and `next_sibling`.
struct Node<T> {
/// Index of the parent node, or `None` if this node is a root.
parent: Option<usize>,
/// Index of the first node of the list of children of this node, or `None` if it has none.
first_child: Option<usize>,
/// Index of the next node in the list of siblings, if any.
next_sibling: Option<usize>,
/// Index of the previous node in the list of siblings, if any.
previous_sibling: Option<usize>,
/// Set to `true` on the target of a pruning operation and on all of its ancestors while that
/// pruning is in progress. The pruning iterator either removes such nodes or resets the flag.
is_prune_target_ancestor: bool,
/// Value provided by the user.
data: T,
}
impl<T> ForkTree<T> {
    /// Initializes a new, empty `ForkTree`.
    pub fn new() -> Self {
        ForkTree {
            nodes: slab::Slab::new(),
            first_root: None,
        }
    }

    /// Initializes a new, empty `ForkTree` with storage pre-allocated for `cap` nodes.
    pub fn with_capacity(cap: usize) -> Self {
        ForkTree {
            nodes: slab::Slab::with_capacity(cap),
            first_root: None,
        }
    }

    /// Reserves storage for `additional` more nodes.
    pub fn reserve(&mut self, additional: usize) {
        self.nodes.reserve(additional);
    }

    /// Removes all the nodes of the tree.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.first_root = None;
    }

    /// Shrinks the storage of the tree as much as possible.
    pub fn shrink_to_fit(&mut self) {
        self.nodes.shrink_to_fit();
    }

    /// Returns `true` if the tree contains no node.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns the number of nodes in the tree.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns all the nodes of the tree, in no particular order.
    pub fn iter_unordered(&self) -> impl Iterator<Item = (NodeIndex, &T)> {
        self.nodes.iter().map(|n| (NodeIndex(n.0), &n.1.data))
    }

    /// Returns all the nodes of the tree in depth-first pre-order. Every node is yielded
    /// before its children, and the whole subtree of a root is yielded before the next root.
    pub fn iter_ancestry_order(&self) -> impl Iterator<Item = (NodeIndex, &T)> {
        iter::successors(self.first_root.map(NodeIndex), move |n| {
            self.ancestry_order_next(*n)
        })
        .map(move |idx| (idx, &self.nodes[idx.0].data))
    }

    /// Returns the node that follows `node_index` in the order of
    /// [`ForkTree::iter_ancestry_order`], or `None` if `node_index` is the last one.
    ///
    /// In debug builds, also checks the consistency of the links it walks through, and that
    /// no pruning is in progress.
    fn ancestry_order_next(&self, node_index: NodeIndex) -> Option<NodeIndex> {
        debug_assert!(!self.nodes[node_index.0].is_prune_target_ancestor);

        // Descend into the first child, if any.
        if let Some(idx) = self.nodes[node_index.0].first_child {
            debug_assert_eq!(self.nodes[idx].parent, Some(node_index.0));
            return Some(NodeIndex(idx));
        }

        // Otherwise, move on to the next sibling.
        if let Some(idx) = self.nodes[node_index.0].next_sibling {
            debug_assert_eq!(self.nodes[idx].previous_sibling, Some(node_index.0));
            debug_assert_eq!(self.nodes[idx].parent, self.nodes[node_index.0].parent);
            return Some(NodeIndex(idx));
        }

        // Otherwise, climb up until an ancestor that has a next sibling is found.
        let mut ancestor = self.nodes[node_index.0].parent;
        while let Some(idx) = ancestor {
            if let Some(next_sibling) = self.nodes[idx].next_sibling {
                debug_assert_eq!(self.nodes[next_sibling].previous_sibling, Some(idx));
                debug_assert_eq!(self.nodes[next_sibling].parent, self.nodes[idx].parent);
                return Some(NodeIndex(next_sibling));
            }
            ancestor = self.nodes[idx].parent;
        }

        // Climbed past the last root: nothing left to visit.
        None
    }

    /// Returns `true` if `index` designates a node of the tree.
    pub fn contains(&self, index: NodeIndex) -> bool {
        self.nodes.contains(index.0)
    }

    /// Returns the user data of the given node, or `None` if it isn't in the tree.
    pub fn get(&self, index: NodeIndex) -> Option<&T> {
        self.nodes.get(index.0).map(|n| &n.data)
    }

    /// Returns the user data of the given node mutably, or `None` if it isn't in the tree.
    pub fn get_mut(&mut self, index: NodeIndex) -> Option<&mut T> {
        self.nodes.get_mut(index.0).map(|n| &mut n.data)
    }

    /// Applies `map` to the user data of every node. The resulting tree has the same
    /// structure and the same node indices.
    pub fn map<U>(self, mut map: impl FnMut(T) -> U) -> ForkTree<U> {
        ForkTree {
            nodes: self
                .nodes
                .into_iter()
                .map(|(index, node)| {
                    let node = Node {
                        parent: node.parent,
                        first_child: node.first_child,
                        next_sibling: node.next_sibling,
                        previous_sibling: node.previous_sibling,
                        is_prune_target_ancestor: node.is_prune_target_ancestor,
                        data: map(node.data),
                    };
                    // Keeping the same key preserves all the links between nodes.
                    (index, node)
                })
                .collect(),
            first_root: self.first_root,
        }
    }

    /// Returns the ancestors of `node`, from its parent up to its root. `node` itself is not
    /// included.
    pub fn ancestors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> {
        iter::successors(Some(node), move |n| self.nodes[n.0].parent.map(NodeIndex)).skip(1)
    }

    /// Returns the parent of `node`, or `None` if it is a root.
    ///
    /// # Panics
    ///
    /// Panics if `node` isn't in the tree.
    pub fn parent(&self, node: NodeIndex) -> Option<NodeIndex> {
        self.nodes[node.0].parent.map(NodeIndex)
    }

    /// Returns the children of `node`, or the roots of the tree if `node` is `None`.
    /// The most recently inserted nodes come first.
    ///
    /// # Panics
    ///
    /// Panics if `node` is `Some` and isn't in the tree.
    pub fn children(&self, node: Option<NodeIndex>) -> impl Iterator<Item = NodeIndex> {
        let first = match node {
            Some(n) => self.nodes[n.0].first_child,
            None => self.first_root,
        };
        iter::successors(first, move |n| self.nodes[*n].next_sibling).map(NodeIndex)
    }

    /// Removes `node_index`, all of its ancestors, and every node that isn't a descendant of
    /// `node_index`. The children of `node_index` become the new roots.
    ///
    /// The removal is performed by the returned iterator, which yields the removed nodes.
    /// Dropping the iterator early still removes all the remaining nodes.
    ///
    /// # Panics
    ///
    /// Panics if `node_index` isn't in the tree.
    pub fn prune_ancestors(&'_ mut self, node_index: NodeIndex) -> PruneAncestorsIter<'_, T> {
        self.prune_ancestors_inner(node_index, false)
    }

    /// Removes every node that is neither `node_index`, nor one of its ancestors, nor one of
    /// its descendants.
    ///
    /// The removal is performed by the returned iterator, which yields the removed nodes.
    /// Dropping the iterator early still removes all the remaining nodes.
    ///
    /// # Panics
    ///
    /// Panics if `node_index` isn't in the tree.
    pub fn prune_uncles(&'_ mut self, node_index: NodeIndex) -> PruneAncestorsIter<'_, T> {
        self.prune_ancestors_inner(node_index, true)
    }

    /// Shared implementation of [`ForkTree::prune_ancestors`] (`uncles_only == false`) and
    /// [`ForkTree::prune_uncles`] (`uncles_only == true`).
    fn prune_ancestors_inner(
        &'_ mut self,
        node_index: NodeIndex,
        uncles_only: bool,
    ) -> PruneAncestorsIter<'_, T> {
        // The traversal starts from the current first root, which is captured before
        // `first_root` is updated below.
        let iter = self
            .first_root
            .expect("pruning requires `node_index` to be in the tree, so it can't be empty");

        // Flag `node_index` and all of its ancestors.
        {
            let mut node = node_index.0;
            loop {
                debug_assert!(!self.nodes[node].is_prune_target_ancestor);
                self.nodes[node].is_prune_target_ancestor = true;

                // When only uncles are pruned, the new (and only) root is the oldest
                // ancestor of `node_index`. It is the last one assigned by this loop.
                if uncles_only {
                    self.first_root = Some(node);
                }

                node = match self.nodes[node].parent {
                    Some(n) => n,
                    None => break,
                }
            }
        }

        // Otherwise, the children of `node_index` become the new roots.
        if !uncles_only {
            self.first_root = self.nodes[node_index.0].first_child;
        }

        PruneAncestorsIter {
            finished: false,
            tree: self,
            uncles_only,
            new_final: node_index,
            iter,
            traversing_up: false,
        }
    }

    /// Returns the deepest node that is both `node1` or one of its ancestors, and `node2` or
    /// one of its ancestors. Returns `None` if the two nodes belong to different roots.
    ///
    /// If one node is an ancestor of the other, that ancestor is returned.
    pub fn common_ancestor(&self, node1: NodeIndex, node2: NodeIndex) -> Option<NodeIndex> {
        let dist_to_root1 = self.node_to_root_path(node1).count();
        let dist_to_root2 = self.node_to_root_path(node2).count();

        // Skip the deepest part of the longest path, so that both iterators have the same
        // length and their `n`-th elements are at the same distance from their root.
        let iter1 = self
            .node_to_root_path(node1)
            .skip(dist_to_root1.saturating_sub(dist_to_root2));
        let iter2 = self
            .node_to_root_path(node2)
            .skip(dist_to_root2.saturating_sub(dist_to_root1));

        iter1.zip(iter2).find(|(a, b)| a == b).map(|(a, _)| a)
    }

    /// Returns `true` if `maybe_ancestor` is an ancestor of `maybe_descendant`. Also returns
    /// `true` if both indices are equal.
    ///
    /// # Panics
    ///
    /// Panics if `maybe_descendant` isn't in the tree.
    pub fn is_ancestor(&self, maybe_ancestor: NodeIndex, maybe_descendant: NodeIndex) -> bool {
        assert!(self.nodes.contains(maybe_descendant.0));

        let mut iter = maybe_descendant.0;
        loop {
            if iter == maybe_ancestor.0 {
                return true;
            }

            iter = match self.nodes[iter].parent {
                Some(p) => p,
                None => return false,
            };
        }
    }

    /// Returns two iterators.
    ///
    /// The first yields the nodes from `node1` (included) up to the common ancestor of
    /// `node1` and `node2` (excluded). The second yields the nodes from that common ancestor
    /// (excluded) down to `node2` (included).
    ///
    /// If the two nodes have no common ancestor, the first iterator goes all the way up to
    /// the root of `node1`, and the second starts from the root of `node2`.
    pub fn ascend_and_descend(
        &self,
        node1: NodeIndex,
        node2: NodeIndex,
    ) -> (
        impl Iterator<Item = NodeIndex> + Clone,
        impl Iterator<Item = NodeIndex> + Clone,
    ) {
        let common_ancestor = self.common_ancestor(node1, node2);

        let iter1 = self
            .node_to_root_path(node1)
            .take_while(move |v| Some(*v) != common_ancestor);

        let iter2 = if let Some(common_ancestor) = common_ancestor {
            either::Left(
                self.root_to_node_path(node2)
                    .skip_while(move |v| *v != common_ancestor)
                    .skip(1),
            )
        } else {
            either::Right(self.root_to_node_path(node2))
        };

        (iter1, iter2)
    }

    /// Returns the path from `node_index` (included) up to its root (included).
    pub fn node_to_root_path(
        &self,
        node_index: NodeIndex,
    ) -> impl Iterator<Item = NodeIndex> + Clone {
        iter::successors(Some(node_index), move |n| {
            self.nodes[n.0].parent.map(NodeIndex)
        })
    }

    /// Returns the path from the root of `node_index` (included) down to `node_index`
    /// (included).
    ///
    /// This doesn't allocate. Instead, each step walks again the path from `node_index`
    /// upwards, so the whole iteration is quadratic in the depth of `node_index`.
    pub fn root_to_node_path(
        &self,
        node_index: NodeIndex,
    ) -> impl Iterator<Item = NodeIndex> + Clone {
        // `usize::MAX` is used as a starting sentinel and must never be a valid node.
        debug_assert!(self.nodes.get(usize::MAX).is_none());

        // Each successor is the last node of `node_to_root_path(node_index)` before
        // `current`, i.e. the child of `current` on the path towards `node_index`. Starting
        // from the sentinel, this yields the root first. The sequence ends once `current`
        // is `node_index` itself, and the sentinel is skipped.
        iter::successors(Some(NodeIndex(usize::MAX)), move |&current| {
            self.node_to_root_path(node_index)
                .take_while(move |n| *n != current)
                .last()
        })
        .skip(1)
    }

    /// Returns the index of a node whose user data satisfies `cond`, or `None` if there is
    /// none. If several nodes match, which one is returned is unspecified.
    pub fn find(&self, mut cond: impl FnMut(&T) -> bool) -> Option<NodeIndex> {
        self.nodes
            .iter()
            .find(|(_, n)| cond(&n.data))
            .map(|(i, _)| NodeIndex(i))
    }

    /// Inserts a new node holding `child`, as a child of `parent` or as a new root if
    /// `parent` is `None`. The new node is placed first in its list of siblings.
    ///
    /// # Panics
    ///
    /// Panics if `parent` is `Some` and isn't in the tree.
    pub fn insert(&mut self, parent: Option<NodeIndex>, child: T) -> NodeIndex {
        if let Some(parent) = parent {
            let next_sibling = self.nodes.get_mut(parent.0).unwrap().first_child;

            let new_node_index = self.nodes.insert(Node {
                parent: Some(parent.0),
                first_child: None,
                next_sibling,
                previous_sibling: None,
                is_prune_target_ancestor: false,
                data: child,
            });

            self.nodes.get_mut(parent.0).unwrap().first_child = Some(new_node_index);

            if let Some(next_sibling) = next_sibling {
                self.nodes.get_mut(next_sibling).unwrap().previous_sibling = Some(new_node_index);
            }

            NodeIndex(new_node_index)
        } else {
            let new_node_index = self.nodes.insert(Node {
                parent: None,
                first_child: None,
                next_sibling: self.first_root,
                previous_sibling: None,
                is_prune_target_ancestor: false,
                data: child,
            });

            if let Some(first_root) = self.first_root {
                self.nodes.get_mut(first_root).unwrap().previous_sibling = Some(new_node_index);
            }

            self.first_root = Some(new_node_index);
            NodeIndex(new_node_index)
        }
    }
}
impl<T> Default for ForkTree<T> {
fn default() -> Self {
Self::new()
}
}
impl<T: fmt::Debug> fmt::Debug for ForkTree<T> {
    /// Formats the user data of every node as a flat list. The entries follow the iteration
    /// order of the underlying storage, not the tree structure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut list = f.debug_list();
        for (_, node) in self.nodes.iter() {
            list.entry(&node.data);
        }
        list.finish()
    }
}
/// Iterator returned by [`ForkTree::prune_ancestors`] and [`ForkTree::prune_uncles`]. It
/// removes nodes from the tree as it advances and yields them. Dropping it before the end
/// still removes all the remaining nodes to prune.
pub struct PruneAncestorsIter<'a, T> {
/// Set once the traversal has gone past the last node. `next` then always returns `None`.
finished: bool,
/// Tree being pruned.
tree: &'a mut ForkTree<T>,
/// Index of the node where the traversal currently is.
iter: usize,
/// `true` if the traversal arrived at `iter` from one of its children, in which case those
/// children must not be descended into again.
traversing_up: bool,
/// Target node passed to `prune_ancestors` or `prune_uncles`.
new_final: NodeIndex,
/// `true` if created by `prune_uncles`, in which case the target and its ancestors are kept.
uncles_only: bool,
}
/// Lazily prunes the tree. Each call to `next` continues a depth-first walk (started from the
/// root that was first when the pruning began) until it finds a node to remove, then removes
/// it and yields it. Children are removed before their parents.
impl<'a, T> Iterator for PruneAncestorsIter<'a, T> {
type Item = PrunedNode<T>;
fn next(&mut self) -> Option<Self::Item> {
loop {
if self.finished {
break None;
}
let iter_node = &mut self.tree.nodes[self.iter];
// A direct child of the target is the root of a subtree that is kept, so it is never
// entered. When the target itself is going to be removed (`!uncles_only`), the child
// becomes a root.
if iter_node.parent == Some(self.new_final.0) {
debug_assert!(!self.traversing_up);
if !self.uncles_only {
iter_node.parent = None;
}
// Move on to the next child of the target or, after the last one, back up to the
// target itself, marked as reached from below so its children aren't visited again.
self.iter = if let Some(next_sibling) = iter_node.next_sibling {
next_sibling
} else {
self.traversing_up = true;
self.new_final.0
};
continue;
}
// Unless we just came back up from this node's children, descend as deep as possible
// first.
if !self.traversing_up {
if let Some(first_child) = iter_node.first_child {
self.iter = first_child;
continue;
}
}
// All the children of the current node have been handled. Save its index before
// `self.iter` moves on.
let maybe_removed_node_index = NodeIndex(self.iter);
// Advance to the next sibling (to be descended into), or back up to the parent (whose
// children have now all been visited), or stop if this was the last root.
if let Some(next_sibling) = iter_node.next_sibling {
self.traversing_up = false;
self.iter = next_sibling;
} else if let Some(parent) = iter_node.parent {
self.traversing_up = true;
self.iter = parent;
} else {
self.finished = true;
};
// When only pruning uncles, the target and its ancestors are kept. Reset their flag and
// detach them from their siblings, which are all removed. If the node had a previous
// sibling, it becomes the first child of its parent instead of that removed sibling.
if self.uncles_only && iter_node.is_prune_target_ancestor {
iter_node.is_prune_target_ancestor = false;
iter_node.next_sibling = None;
if iter_node.previous_sibling.take().is_some() {
if let Some(parent) = iter_node.parent {
debug_assert!(self.tree.nodes[parent].first_child.is_some());
self.tree.nodes[parent].first_child = Some(maybe_removed_node_index.0);
}
}
continue;
}
// The new first root must never be among the removed nodes.
debug_assert!(
self.tree
.first_root
.map_or(true, |n| n != maybe_removed_node_index.0)
);
// Actually remove the node and hand its content to the caller.
let iter_node = self.tree.nodes.remove(maybe_removed_node_index.0);
break Some(PrunedNode {
index: maybe_removed_node_index,
is_prune_target_ancestor: iter_node.is_prune_target_ancestor,
user_data: iter_node.data,
});
}
}
// No lower bound is known. At most every node still in the tree gets removed.
fn size_hint(&self) -> (usize, Option<usize>) {
(0, Some(self.tree.nodes.len()))
}
}
impl<'a, T> Drop for PruneAncestorsIter<'a, T> {
    /// Completes the pruning, then checks (in debug builds) that the tree is consistent.
    fn drop(&mut self) {
        // Drain whatever the user didn't consume, so that dropping the iterator early still
        // completes the pruning.
        while self.next().is_some() {}

        // Pruning only uncles always keeps at least the root-most ancestor of the target.
        debug_assert!(!self.uncles_only || self.tree.first_root.is_some());

        debug_assert!(match self.tree.first_root {
            Some(root) => self.tree.nodes.contains(root),
            None => true,
        });

        // The target survives if and only if only uncles were pruned.
        debug_assert_eq!(self.uncles_only, self.tree.get(self.new_final).is_some());

        // Walking the whole tree exercises the link-consistency assertions of the
        // ancestry-order traversal.
        #[cfg(debug_assertions)]
        for _ in self.tree.iter_ancestry_order() {}
    }
}
/// Node removed from a [`ForkTree`] by [`ForkTree::prune_ancestors`] or
/// [`ForkTree::prune_uncles`].
pub struct PrunedNode<T> {
/// Index the node had in the tree. The node has been removed, so this index no longer
/// designates it.
pub index: NodeIndex,
/// `true` if the node was the target of the pruning or one of its ancestors. Always `false`
/// for nodes yielded by `prune_uncles`, as that operation keeps those nodes.
pub is_prune_target_ancestor: bool,
/// Value that was stored in the node.
pub user_data: T,
}
/// Identifier of a node within a [`ForkTree`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct NodeIndex(usize);

impl NodeIndex {
    /// Smallest possible value of a `NodeIndex`.
    pub const MIN: Self = Self(0);
    /// Largest possible value of a `NodeIndex`.
    pub const MAX: Self = Self(usize::MAX);

    /// Returns the value immediately following `self` in the [`Ord`] ordering, or `None` if
    /// `self` is [`NodeIndex::MAX`].
    pub fn inc(self) -> Option<Self> {
        let NodeIndex(raw) = self;
        if raw == usize::MAX {
            None
        } else {
            Some(Self(raw + 1))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::ForkTree;

    #[test]
    fn basic() {
        let mut tree = ForkTree::new();

        let node0 = tree.insert(None, 0);
        let node1 = tree.insert(Some(node0), 1);
        let node2 = tree.insert(Some(node1), 2);
        let node3 = tree.insert(Some(node2), 3);
        let node4 = tree.insert(Some(node2), 4);
        let node5 = tree.insert(Some(node0), 5);

        assert_eq!(tree.find(|v| *v == 0), Some(node0));
        assert_eq!(tree.find(|v| *v == 1), Some(node1));
        assert_eq!(tree.find(|v| *v == 2), Some(node2));
        assert_eq!(tree.find(|v| *v == 3), Some(node3));
        assert_eq!(tree.find(|v| *v == 4), Some(node4));
        assert_eq!(tree.find(|v| *v == 5), Some(node5));

        assert_eq!(
            tree.node_to_root_path(node3).collect::<Vec<_>>(),
            &[node3, node2, node1, node0]
        );
        assert_eq!(
            tree.node_to_root_path(node4).collect::<Vec<_>>(),
            &[node4, node2, node1, node0]
        );
        assert_eq!(
            tree.node_to_root_path(node1).collect::<Vec<_>>(),
            &[node1, node0]
        );
        assert_eq!(
            tree.node_to_root_path(node5).collect::<Vec<_>>(),
            &[node5, node0]
        );

        let iter = tree.prune_ancestors(node1);
        assert_eq!(
            iter.filter(|n| n.is_prune_target_ancestor)
                .map(|n| n.index)
                .collect::<Vec<_>>(),
            vec![node1, node0]
        );

        assert!(tree.get(node0).is_none());
        assert!(tree.get(node1).is_none());
        assert_eq!(tree.get(node2), Some(&2));
        assert_eq!(tree.get(node3), Some(&3));
        assert_eq!(tree.get(node4), Some(&4));
        assert!(tree.get(node5).is_none());

        assert_eq!(
            tree.node_to_root_path(node3).collect::<Vec<_>>(),
            &[node3, node2]
        );
        assert_eq!(
            tree.node_to_root_path(node4).collect::<Vec<_>>(),
            &[node4, node2]
        );
    }

    #[test]
    fn ascend_descend_when_common_ancestor_is_not_root() {
        let mut tree = ForkTree::new();

        let node0 = tree.insert(None, ());
        let node1 = tree.insert(Some(node0), ());
        let node2 = tree.insert(Some(node0), ());

        let (ascend, descend) = tree.ascend_and_descend(node1, node2);
        assert_eq!(ascend.collect::<Vec<_>>(), vec![node1]);
        assert_eq!(descend.collect::<Vec<_>>(), vec![node2]);

        assert_eq!(tree.common_ancestor(node1, node2), Some(node0));
    }

    #[test]
    fn ascend_descend_when_common_ancestor_is_root() {
        let mut tree = ForkTree::new();

        let node0 = tree.insert(None, ());
        let node1 = tree.insert(None, ());

        let (ascend, descend) = tree.ascend_and_descend(node0, node1);
        assert_eq!(ascend.collect::<Vec<_>>(), vec![node0]);
        assert_eq!(descend.collect::<Vec<_>>(), vec![node1]);

        assert_eq!(tree.common_ancestor(node0, node1), None);
    }

    #[test]
    fn root_to_node_path_is_reversed_node_to_root_path() {
        let mut tree = ForkTree::new();

        let node0 = tree.insert(None, 0);
        let node1 = tree.insert(Some(node0), 1);
        let node2 = tree.insert(Some(node1), 2);
        let node3 = tree.insert(Some(node2), 3);

        assert_eq!(
            tree.root_to_node_path(node3).collect::<Vec<_>>(),
            &[node0, node1, node2, node3]
        );
        assert_eq!(tree.root_to_node_path(node0).collect::<Vec<_>>(), &[node0]);
    }

    #[test]
    fn ancestry_queries() {
        let mut tree = ForkTree::new();

        let node0 = tree.insert(None, 0);
        let node1 = tree.insert(Some(node0), 1);
        let node2 = tree.insert(Some(node1), 2);
        let node3 = tree.insert(Some(node2), 3);
        let node4 = tree.insert(Some(node2), 4);
        let node5 = tree.insert(Some(node0), 5);

        assert_eq!(tree.len(), 6);

        assert!(tree.is_ancestor(node0, node3));
        assert!(tree.is_ancestor(node2, node2));
        assert!(!tree.is_ancestor(node3, node0));
        assert!(!tree.is_ancestor(node5, node3));

        assert_eq!(
            tree.ancestors(node3).collect::<Vec<_>>(),
            &[node2, node1, node0]
        );
        assert_eq!(tree.parent(node0), None);
        assert_eq!(tree.parent(node4), Some(node2));

        // Newly-inserted nodes are placed in front of their siblings.
        assert_eq!(
            tree.children(Some(node0)).collect::<Vec<_>>(),
            &[node5, node1]
        );
        assert_eq!(
            tree.children(Some(node2)).collect::<Vec<_>>(),
            &[node4, node3]
        );
        assert_eq!(tree.children(None).collect::<Vec<_>>(), &[node0]);

        assert_eq!(tree.common_ancestor(node3, node4), Some(node2));
        assert_eq!(tree.common_ancestor(node3, node5), Some(node0));
        assert_eq!(tree.common_ancestor(node1, node3), Some(node1));

        let (ascend, descend) = tree.ascend_and_descend(node3, node5);
        assert_eq!(ascend.collect::<Vec<_>>(), &[node3, node2, node1]);
        assert_eq!(descend.collect::<Vec<_>>(), &[node5]);

        assert_eq!(
            tree.iter_ancestry_order()
                .map(|(_, v)| *v)
                .collect::<Vec<_>>(),
            vec![0, 5, 1, 2, 4, 3]
        );
    }

    #[test]
    fn prune_uncles_keeps_target_ancestors_and_descendants() {
        let mut tree = ForkTree::new();

        let node0 = tree.insert(None, 0);
        let node1 = tree.insert(Some(node0), 1);
        let node2 = tree.insert(Some(node0), 2);
        let node3 = tree.insert(Some(node1), 3);

        let pruned = tree
            .prune_uncles(node1)
            .map(|n| n.index)
            .collect::<Vec<_>>();
        assert_eq!(pruned, vec![node2]);

        assert_eq!(tree.len(), 3);
        assert_eq!(tree.get(node0), Some(&0));
        assert_eq!(tree.get(node1), Some(&1));
        assert!(tree.get(node2).is_none());
        assert_eq!(tree.get(node3), Some(&3));

        assert_eq!(tree.children(None).collect::<Vec<_>>(), &[node0]);
        assert_eq!(tree.children(Some(node0)).collect::<Vec<_>>(), &[node1]);
        assert_eq!(
            tree.iter_ancestry_order()
                .map(|(_, v)| *v)
                .collect::<Vec<_>>(),
            vec![0, 1, 3]
        );
    }

    #[test]
    fn map_preserves_structure_and_indices() {
        let mut tree = ForkTree::new();

        let root = tree.insert(None, 1);
        let child = tree.insert(Some(root), 2);

        let tree = tree.map(|v| v * 10);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.get(root), Some(&10));
        assert_eq!(tree.get(child), Some(&20));
        assert_eq!(tree.parent(child), Some(root));
    }
}