#[cfg(feature = "rayon")]
pub use rayon;
#[cfg(feature = "rayon")]
use rayon::prelude::*;
/// A single tree node: a payload plus index-based links to its parent and
/// children. The indices refer to slots of the owning `Tree`.
#[derive(Clone)]
pub struct Node<T> {
/// Payload stored in this node.
data: T,
/// Indices of this node's children, in the order they were attached.
children: Vec<usize>,
/// Index of the parent node; `None` until `set_parent` is called.
parent: Option<usize>,
}
impl<T> Node<T> {
pub fn new(data: T) -> Self {
Self {
data,
children: Vec::new(),
parent: None,
}
}
pub(crate) fn add_child(&mut self, child: usize) {
self.children.push(child);
}
pub(crate) fn set_parent(&mut self, parent: usize) {
self.parent = Some(parent);
}
}
/// A tree whose nodes live in a `Vec` and are addressed by index.
///
/// Removed nodes leave vacant slots that later insertions reuse.
#[derive(Clone)]
pub struct Tree<T> {
/// Node slots addressed by index; `None` marks a slot whose node was removed.
nodes: Vec<Option<Node<T>>>,
/// Indices of vacant slots in `nodes`; `add_node` reuses these before growing the vector.
free_list: Vec<usize>,
/// Number of occupied slots in `nodes`.
node_count: usize,
/// Scratch stack of `(node index, children already visited)` entries used by `traverse_subtree_mut`.
stack: Vec<(usize, bool)>,
}
impl<T> Default for Tree<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Tree<T> {
    /// Creates an empty tree.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            free_list: Vec::new(),
            node_count: 0,
            stack: Vec::new(),
        }
    }

    /// Inserts `data` as a detached node (no parent, no children) and returns
    /// its index.
    ///
    /// A slot released by [`Tree::remove_subtree`] is reused when one is
    /// available; otherwise the node is appended to the end of the storage.
    /// [`Tree::traverse`] and [`Tree::traverse_mut`] start at index 0, so only
    /// the node stored there acts as the root for them.
    pub fn add_node(&mut self, data: T) -> usize {
        let node = Node::new(data);
        self.node_count += 1;
        match self.free_list.pop() {
            Some(index) => {
                self.nodes[index] = Some(node);
                index
            }
            None => {
                self.nodes.push(Some(node));
                self.nodes.len() - 1
            }
        }
    }

    /// Inserts `data` as the last child of `parent` and returns the new
    /// node's index.
    ///
    /// # Panics
    ///
    /// Panics if `parent` does not refer to a live node. The check happens
    /// before anything is inserted, so a failed call leaves the tree unchanged.
    pub fn add_child(&mut self, parent: usize, data: T) -> usize {
        // Validate before allocating: `add_node` may hand out the very slot
        // `parent` points at (empty tree, or a freed slot on top of the free
        // list), which would make the new node its own parent and child and
        // turn every traversal through it into an infinite loop.
        assert!(
            matches!(self.nodes.get(parent), Some(Some(_))),
            "Tree::add_child: parent index {} does not refer to a live node",
            parent
        );
        let index = self.add_node(data);
        self.nodes[parent]
            .as_mut()
            .expect("parent was verified to be live above")
            .add_child(index);
        self.nodes[index]
            .as_mut()
            .expect("slot was just filled by add_node")
            .set_parent(parent);
        index
    }

    /// Inserts `data` as the last child of the node at index 0 and returns the
    /// new node's index.
    ///
    /// # Panics
    ///
    /// Panics if there is no live node at index 0 (see [`Tree::add_child`]).
    pub fn add_child_to_root(&mut self, data: T) -> usize {
        self.add_child(0, data)
    }

    /// Returns the payload stored at `index`, or `None` if the index is out of
    /// bounds or the slot is vacant.
    pub fn get(&self, index: usize) -> Option<&T> {
        self.nodes
            .get(index)
            .and_then(|slot| slot.as_ref().map(|node| &node.data))
    }

    /// Returns the payload stored at `index`.
    ///
    /// Despite the name this is a checked, safe accessor.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds or the slot is vacant.
    #[inline(always)]
    pub fn get_unchecked(&self, index: usize) -> &T {
        &self.nodes[index].as_ref().unwrap().data
    }

    /// Returns a mutable reference to the payload stored at `index`, or `None`
    /// if the index is out of bounds or the slot is vacant.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
        self.nodes
            .get_mut(index)
            .and_then(|slot| slot.as_mut().map(|node| &mut node.data))
    }

    /// Returns a mutable reference to the payload stored at `index`.
    ///
    /// Despite the name this is a checked, safe accessor.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds or the slot is vacant.
    #[inline(always)]
    pub fn get_unchecked_mut(&mut self, index: usize) -> &mut T {
        &mut self.nodes[index].as_mut().unwrap().data
    }

    /// Returns the parent index of the node at `index`, or `None` if that node
    /// has no parent.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds or the slot is vacant.
    pub fn parent_index_unchecked(&self, index: usize) -> Option<usize> {
        self.nodes[index].as_ref().unwrap().parent
    }

    /// Returns the child indices of the node at `index`, in insertion order.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of bounds or the slot is vacant.
    pub fn children(&self, index: usize) -> &[usize] {
        &self.nodes[index].as_ref().unwrap().children
    }

    /// Walks the tree depth-first starting at index 0.
    ///
    /// For every node, `before_processing_children` is called first, then the
    /// children are visited in insertion order, then
    /// `after_processing_the_subtree` is called. Both callbacks receive the
    /// node index, its payload and the caller-supplied state `s`.
    ///
    /// Does nothing when there is no live node at index 0.
    pub fn traverse<'a, S>(
        &'a self,
        mut before_processing_children: impl FnMut(usize, &'a T, &mut S),
        mut after_processing_the_subtree: impl FnMut(usize, &'a T, &mut S),
        s: &mut S,
    ) {
        if !matches!(self.nodes.first(), Some(Some(_))) {
            return;
        }
        // Each entry is (node index, whether its children were already pushed).
        let mut stack = vec![(0usize, false)];
        while let Some((index, children_visited)) = stack.pop() {
            let node = self.nodes[index]
                .as_ref()
                .expect("child links always refer to live nodes");
            if children_visited {
                after_processing_the_subtree(index, &node.data, s);
            } else {
                before_processing_children(index, &node.data, s);
                stack.push((index, true));
                // Reversed so the first child ends up on top and is visited first.
                stack.extend(node.children.iter().rev().map(|&child| (child, false)));
            }
        }
    }

    /// Mutable counterpart of [`Tree::traverse`]: walks the tree depth-first
    /// from index 0, handing the callbacks mutable access to each payload.
    ///
    /// Does nothing when there is no live node at index 0.
    pub fn traverse_mut<S>(
        &mut self,
        mut before_processing_children: impl FnMut(usize, &mut T, &mut S),
        mut after_processing_the_subtree: impl FnMut(usize, &mut T, &mut S),
        s: &mut S,
    ) {
        if matches!(self.nodes.first(), Some(Some(_))) {
            self.traverse_subtree_mut(
                0,
                &mut before_processing_children,
                &mut after_processing_the_subtree,
                s,
            );
        }
    }

    /// Walks the subtree rooted at `start` depth-first with mutable access to
    /// each payload.
    ///
    /// The callback order matches [`Tree::traverse`]. The walk reuses an
    /// internal scratch stack instead of allocating a new one per call.
    ///
    /// Does nothing when `start` is out of bounds or refers to a vacant slot.
    pub fn traverse_subtree_mut<S>(
        &mut self,
        start: usize,
        mut before_processing_children: impl FnMut(usize, &mut T, &mut S),
        mut after_processing_the_subtree: impl FnMut(usize, &mut T, &mut S),
        s: &mut S,
    ) {
        if !matches!(self.nodes.get(start), Some(Some(_))) {
            return;
        }
        // A callback that panicked during an earlier walk may have left
        // entries behind, so always start from a clean stack.
        self.stack.clear();
        self.stack.push((start, false));
        while let Some((index, children_visited)) = self.stack.pop() {
            let node = self.nodes[index]
                .as_mut()
                .expect("child links always refer to live nodes");
            if children_visited {
                after_processing_the_subtree(index, &mut node.data, s);
            } else {
                before_processing_children(index, &mut node.data, s);
                self.stack.push((index, true));
                // Reversed so the first child ends up on top and is visited first.
                self.stack
                    .extend(node.children.iter().rev().map(|&child| (child, false)));
            }
        }
    }

    /// Iterates over `(index, payload)` for every live node in ascending index
    /// order, skipping vacant slots. This is storage order, not tree order.
    pub fn iter(&self) -> impl Iterator<Item = (usize, &T)> {
        self.nodes
            .iter()
            .enumerate()
            .filter_map(|(index, slot)| slot.as_ref().map(|node| (index, &node.data)))
    }

    /// Mutable counterpart of [`Tree::iter`].
    pub fn iter_mut(&mut self) -> impl Iterator<Item = (usize, &mut T)> {
        self.nodes
            .iter_mut()
            .enumerate()
            .filter_map(|(index, slot)| slot.as_mut().map(|node| (index, &mut node.data)))
    }

    /// Returns `true` when the tree holds no live nodes.
    pub fn is_empty(&self) -> bool {
        self.node_count == 0
    }

    /// Returns the number of live nodes.
    pub fn len(&self) -> usize {
        self.node_count
    }

    /// Removes every node. Indices handed out earlier become invalid and
    /// numbering restarts at 0.
    pub fn clear(&mut self) {
        self.nodes.clear();
        self.free_list.clear();
        self.stack.clear();
        self.node_count = 0;
    }

    /// Removes the node at `index` together with all of its descendants and
    /// detaches it from its parent's child list.
    ///
    /// Freed slots are kept for reuse by [`Tree::add_node`]. If the tree ends
    /// up empty, the storage is reset so numbering restarts at 0. Does nothing
    /// when `index` is out of bounds or already vacant.
    pub fn remove_subtree(&mut self, index: usize) {
        let parent = match self.nodes.get(index) {
            Some(Some(node)) => node.parent,
            _ => return,
        };
        if let Some(parent_idx) = parent {
            if let Some(Some(parent_node)) = self.nodes.get_mut(parent_idx) {
                parent_node.children.retain(|&child| child != index);
            }
        }
        let mut removal_stack = vec![index];
        while let Some(current) = removal_stack.pop() {
            if let Some(node) = self.nodes[current].take() {
                removal_stack.extend(node.children);
                self.free_list.push(current);
                self.node_count -= 1;
            }
        }
        if self.node_count == 0 {
            // Nothing is live any more, so drop the vacant slots as well.
            self.nodes.clear();
            self.free_list.clear();
        }
    }
}
/// Parallel iteration, available when the `rayon` feature is enabled.
#[cfg(feature = "rayon")]
impl<T: Send + Sync> Tree<T> {
    /// Parallel iterator over `(index, payload)` for every live node; vacant
    /// slots are skipped.
    pub fn par_iter(&self) -> impl ParallelIterator<Item = (usize, &T)> {
        self.nodes
            .par_iter()
            .enumerate()
            .filter_map(|(index, slot)| {
                let node = slot.as_ref()?;
                Some((index, &node.data))
            })
    }

    /// Parallel iterator over `(index, mutable payload)` for every live node;
    /// vacant slots are skipped.
    pub fn par_iter_mut(&mut self) -> impl ParallelIterator<Item = (usize, &mut T)> {
        self.nodes
            .par_iter_mut()
            .enumerate()
            .filter_map(|(index, slot)| {
                let node = slot.as_mut()?;
                Some((index, &mut node.data))
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the four-node fixture shared by several tests and returns it
    /// with the indices `[root, child1, child2, child3]`:
    ///
    /// ```text
    /// 0
    /// ├── 1
    /// │   └── 3
    /// └── 2
    /// ```
    fn sample_tree() -> (Tree<i32>, [usize; 4]) {
        let mut tree = Tree::new();
        let root = tree.add_node(0);
        let child1 = tree.add_child(root, 1);
        let child2 = tree.add_child(root, 2);
        let child3 = tree.add_child(child1, 3);
        (tree, [root, child1, child2, child3])
    }

    /// Payloads, parent links and child lists are recorded as inserted.
    #[test]
    fn test_tree() {
        let (tree, [root, child1, child2, child3]) = sample_tree();

        assert_eq!(tree.get(root), Some(&0));
        assert_eq!(tree.get(child1), Some(&1));
        assert_eq!(tree.get(child2), Some(&2));
        assert_eq!(tree.get(child3), Some(&3));

        assert_eq!(tree.parent_index_unchecked(child1), Some(root));
        assert_eq!(tree.parent_index_unchecked(child2), Some(root));
        assert_eq!(tree.parent_index_unchecked(child3), Some(child1));

        assert_eq!(tree.children(root), &[child1, child2]);
        assert_eq!(tree.children(child1), &[child3]);
        assert!(tree.children(child2).is_empty());
        assert!(tree.children(child3).is_empty());
    }

    /// `iter` yields live nodes in index order.
    #[test]
    fn test_tree_iter() {
        let (tree, [root, child1, child2, child3]) = sample_tree();

        let mut iter = tree.iter();
        assert_eq!(iter.next(), Some((root, &0)));
        assert_eq!(iter.next(), Some((child1, &1)));
        assert_eq!(iter.next(), Some((child2, &2)));
        assert_eq!(iter.next(), Some((child3, &3)));
        assert_eq!(iter.next(), None);
    }

    /// `iter_mut` yields live nodes in index order.
    #[test]
    fn test_tree_iter_mut() {
        let (mut tree, [root, child1, child2, child3]) = sample_tree();

        let mut iter = tree.iter_mut();
        assert_eq!(iter.next(), Some((root, &mut 0)));
        assert_eq!(iter.next(), Some((child1, &mut 1)));
        assert_eq!(iter.next(), Some((child2, &mut 2)));
        assert_eq!(iter.next(), Some((child3, &mut 3)));
        assert_eq!(iter.next(), None);
    }

    /// `traverse` calls the "before" hook on entry and the "after" hook once
    /// the whole subtree is done, visiting children in insertion order.
    #[test]
    fn test_tree_traverse() {
        let (tree, _) = sample_tree();

        let mut result = vec![];
        tree.traverse(
            |index, node, result| result.push(format!("Calling handler for node {index}: {node}")),
            |index, _node, result| {
                result.push(format!(
                    "Finished handling node {index} and all its children"
                ))
            },
            &mut result,
        );

        let expected = vec![
            "Calling handler for node 0: 0",
            "Calling handler for node 1: 1",
            "Calling handler for node 3: 3",
            "Finished handling node 3 and all its children",
            "Finished handling node 1 and all its children",
            "Calling handler for node 2: 2",
            "Finished handling node 2 and all its children",
            "Finished handling node 0 and all its children",
        ];
        assert_eq!(result, expected);
    }

    /// Removing an inner node also removes its descendants and unlinks it
    /// from the parent.
    #[test]
    fn test_remove_subtree() {
        let mut tree = Tree::new();
        let root = tree.add_node("root");
        let child1 = tree.add_child(root, "child1");
        let child2 = tree.add_child(root, "child2");
        let grandchild1 = tree.add_child(child1, "grandchild1");
        let _grandchild2 = tree.add_child(child1, "grandchild2");
        assert_eq!(tree.len(), 5);

        tree.remove_subtree(child1);

        assert_eq!(tree.len(), 2);
        assert_eq!(tree.get(root), Some(&"root"));
        assert_eq!(tree.get(child1), None);
        assert_eq!(tree.get(child2), Some(&"child2"));
        assert_eq!(tree.get(grandchild1), None);
        assert_eq!(tree.children(root), &[child2]);
    }

    /// Removing a leaf only affects that leaf.
    #[test]
    fn test_remove_leaf_node() {
        let mut tree = Tree::new();
        let root = tree.add_node("root");
        let child1 = tree.add_child(root, "child1");
        let child2 = tree.add_child(root, "child2");

        tree.remove_subtree(child1);

        assert_eq!(tree.len(), 2);
        assert_eq!(tree.get(child1), None);
        assert_eq!(tree.children(root), &[child2]);
    }

    /// Removing the root empties the tree.
    #[test]
    fn test_remove_root() {
        let mut tree = Tree::new();
        let root = tree.add_node("root");
        tree.add_child(root, "child1");
        tree.add_child(root, "child2");

        tree.remove_subtree(root);

        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
    }

    /// A slot freed by a removal is handed out again by the next insertion.
    #[test]
    fn test_remove_and_reuse() {
        let (mut tree, [root, child1, _, _]) = sample_tree();

        tree.remove_subtree(child1);
        let new_child = tree.add_child(root, 10);

        assert!(new_child == 3 || new_child == 1);
        assert_eq!(tree.len(), 3);
        assert_eq!(tree.get(new_child), Some(&10));
    }

    /// `iter` skips the slots vacated by a removal.
    #[test]
    fn test_iter_after_remove() {
        let (mut tree, [root, child1, child2, _]) = sample_tree();

        tree.remove_subtree(child1);

        let items: Vec<_> = tree.iter().collect();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0], (root, &0));
        assert_eq!(items[1], (child2, &2));
    }

    /// `traverse` no longer reaches a removed subtree.
    #[test]
    fn test_traverse_after_remove() {
        let (mut tree, [_, child1, _, _]) = sample_tree();

        tree.remove_subtree(child1);

        let mut result = vec![];
        tree.traverse(
            |idx, data, result: &mut Vec<String>| result.push(format!("enter {idx}:{data}")),
            |idx, data, result: &mut Vec<String>| result.push(format!("leave {idx}:{data}")),
            &mut result,
        );
        assert_eq!(
            result,
            vec!["enter 0:0", "enter 2:2", "leave 2:2", "leave 0:0",]
        );
    }

    /// With slot 0 vacant, `traverse` visits nothing even though other nodes
    /// are still alive.
    #[test]
    fn test_traverse_after_removing_first_root() {
        let mut tree = Tree::new();
        let root0 = tree.add_node("root0");
        let root1 = tree.add_node("root1");
        tree.add_child(root1, "child");

        tree.remove_subtree(root0);

        let mut result = vec![];
        tree.traverse(
            |idx, data, result: &mut Vec<String>| result.push(format!("enter {idx}:{data}")),
            |idx, data, result: &mut Vec<String>| result.push(format!("leave {idx}:{data}")),
            &mut result,
        );
        assert!(result.is_empty());
    }

    /// With slot 0 vacant, `traverse_mut` visits nothing and leaves the
    /// remaining payloads untouched.
    #[test]
    fn test_traverse_mut_after_removing_first_root() {
        let mut tree = Tree::new();
        let root0 = tree.add_node(0);
        let root1 = tree.add_node(10);
        let child = tree.add_child(root1, 20);

        tree.remove_subtree(root0);

        let mut visited = vec![];
        tree.traverse_mut(
            |idx, data, visited: &mut Vec<(usize, i32)>| {
                *data += 1;
                visited.push((idx, *data));
            },
            |_, _, _| {},
            &mut visited,
        );
        assert!(visited.is_empty());
        assert_eq!(tree.get(root1), Some(&10));
        assert_eq!(tree.get(child), Some(&20));
    }

    /// Removing the same index twice is a no-op the second time.
    #[test]
    fn test_remove_idempotent() {
        let mut tree = Tree::new();
        let root = tree.add_node("root");
        let child = tree.add_child(root, "child");

        tree.remove_subtree(child);
        tree.remove_subtree(child);

        assert_eq!(tree.len(), 1);
    }

    /// Removing an index past the end of the storage is ignored.
    #[test]
    fn test_remove_out_of_bounds() {
        let mut tree = Tree::new();
        let root = tree.add_node("root");

        tree.remove_subtree(999);

        assert_eq!(tree.len(), 1);
        assert_eq!(tree.get(root), Some(&"root"));
    }

    /// After the tree has been emptied, numbering restarts at index 0.
    #[test]
    fn test_add_after_remove_root() {
        let mut tree = Tree::new();
        let root = tree.add_node("root");
        tree.add_child(root, "child");
        tree.remove_subtree(root);
        assert!(tree.is_empty());

        let new_root = tree.add_node("new_root");

        assert_eq!(new_root, 0);
        assert_eq!(tree.get(new_root), Some(&"new_root"));
        assert_eq!(tree.len(), 1);
    }
}