use crate::arena::{Arena, EntryId};
pub use children::*;
pub use error::*;
pub use location::*;
pub use node_mut::*;
pub use node_ref::*;
mod children;
mod error;
mod location;
mod node_mut;
mod node_ref;
/// A tree whose nodes live in an arena and are addressed by [`NodeId`].
///
/// A tree always has exactly one root node, created by [`Tree::new_with_root`].
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Tree<T: Clone> {
/// Id of the root node (assigned in `new_with_root`).
root: NodeId,
/// Storage for every node of the tree; nodes refer to each other by `NodeId`.
pub(crate) arena: Arena<Node<T>>,
}
/// A single tree node: its payload plus the links that place it in the tree.
///
/// A parent points only at its first child; the children themselves form a list linked in
/// both directions through `previous_sibling` / `next_sibling`, and each child points back
/// at its parent.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct Node<T: Clone> {
/// The value stored in this node.
data: T,
/// The parent node; `None` for the root and for nodes not (yet) linked into the tree.
parent: Option<NodeId>,
/// Head of this node's child list, if it has any children.
first_child: Option<NodeId>,
/// The sibling immediately before this node; `None` for a first child.
previous_sibling: Option<NodeId>,
/// The sibling immediately after this node; `None` for a last child.
next_sibling: Option<NodeId>,
}
impl<T: Clone> Node<T> {
fn new(data: T) -> Self {
Self {
data,
parent: None,
first_child: None,
previous_sibling: None,
next_sibling: None,
}
}
}
/// Identifier of a node within a [`Tree`].
///
/// Returned by [`Tree::insert`]; pass it to [`Tree::get`] / [`Tree::get_mut`] to look the
/// node up again (those return `None` if no node with this id exists).
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct NodeId(pub(crate) EntryId);
impl<T: Clone> Tree<T> {
#[must_use]
pub fn new_with_root(value: T) -> Self {
let mut arena = Arena::new();
let root = NodeId(arena.insert(Node::new(value)));
Self { root, arena }
}
fn create_node(&mut self, value: T) -> NodeId {
let id = self.arena.insert(Node::new(value));
NodeId(id)
}
fn validate_location(&self, location: Location) -> Result<(), InvalidLocation> {
match location {
Location::AfterSibling(id) => {
if id == self.root {
return Err(InvalidLocation::RootCannotHaveSiblings);
}
self.arena
.get(id.0)
.ok_or(InvalidLocation::NoSuchNode(id))?;
}
Location::FirstChildOf(id) => {
self.arena
.get(id.0)
.ok_or(InvalidLocation::NoSuchNode(id))?;
}
}
Ok(())
}
pub fn insert(&mut self, value: T, location: Location) -> Result<NodeId, InvalidLocation> {
self.validate_location(location)?;
let node = self.create_node(value);
self.attach(node, location)
.expect("unreachable; location validated");
Ok(node)
}
fn attach(&mut self, node_id: NodeId, location: Location) -> Result<(), InvalidLocation> {
match location {
Location::AfterSibling(previous_sibling_id) => {
let previous_sibling = self
.arena
.get(previous_sibling_id.0)
.ok_or(InvalidLocation::NoSuchNode(previous_sibling_id))?;
let next_sibling_id = previous_sibling.next_sibling;
let parent_id = match previous_sibling.parent {
None => return Err(InvalidLocation::RootCannotHaveSiblings),
Some(id) => id,
};
let node = &mut self.arena[node_id.0];
node.next_sibling = next_sibling_id;
node.previous_sibling = Some(previous_sibling_id);
node.parent = Some(parent_id);
self.arena[previous_sibling_id.0].next_sibling = Some(node_id);
if let Some(next_sibling_id) = next_sibling_id {
self.arena[next_sibling_id.0].previous_sibling = Some(node_id);
}
Ok(())
}
Location::FirstChildOf(parent_id) => {
let parent = self
.arena
.get_mut(parent_id.0)
.ok_or(InvalidLocation::NoSuchNode(parent_id))?;
let original_first_child_id = parent.first_child;
parent.first_child = Some(node_id);
if let Some(original_first_child_id) = original_first_child_id {
self.arena[original_first_child_id.0].previous_sibling = Some(node_id);
}
let node = &mut self.arena[node_id.0];
node.parent = Some(parent_id);
node.next_sibling = original_first_child_id;
node.previous_sibling = None;
Ok(())
}
}
}
fn detach(&mut self, id: NodeId) -> Result<(), RemoveByIdError> {
let node = self
.arena
.get(id.0)
.ok_or(RemoveByIdError::NoSuchNode(id))?;
let parent_id = node.parent.ok_or(RemoveByIdError::CannotRemoveRoot)?;
let next_sibling = node.next_sibling;
let previous_sibling = node.previous_sibling;
match node.previous_sibling {
None => {
self.arena[parent_id.0].first_child = next_sibling;
}
Some(previous_id) => {
self.arena[previous_id.0].next_sibling = next_sibling;
}
}
if let Some(next_sibling) = next_sibling {
self.arena[next_sibling.0].previous_sibling = previous_sibling;
}
Ok(())
}
#[cfg(test)]
pub(crate) fn validate(&self) {
use std::collections::HashSet;
self.arena.validate();
fn collect_ids<T: Clone>(
arena: &Arena<Node<T>>,
found_ids: &mut HashSet<NodeId>,
id: NodeId,
) {
let had_id = !found_ids.insert(id);
assert!(!had_id, "Circular graph: second occurrence of {:?}", id);
let node = &arena[id.0];
if let Some(next_sibling) = node.next_sibling {
assert_eq!(
arena[next_sibling.0].previous_sibling,
Some(id),
"Inconsistent sibling references"
);
collect_ids(arena, found_ids, next_sibling);
}
if let Some(first_child) = node.first_child {
assert_eq!(
arena[first_child.0].parent,
Some(id),
"Inconsistent parent-child references"
);
collect_ids(arena, found_ids, first_child);
}
}
let expected_ids: HashSet<_> = self.arena.iter_items().map(|(id, _)| NodeId(id)).collect();
let mut found_ids = HashSet::new();
collect_ids(&self.arena, &mut found_ids, self.root);
assert!(
self.arena[self.root.0].next_sibling.is_none(),
"Root has sibling"
);
assert_eq!(expected_ids, found_ids)
}
#[must_use]
pub fn root(&self) -> NodeRef<T> {
NodeRef::new(self, self.root)
}
#[must_use]
pub fn root_mut(&mut self) -> NodeMut<T> {
NodeMut::new(self, self.root)
}
#[must_use]
pub fn get(&self, id: NodeId) -> Option<NodeRef<T>> {
self.arena.get(id.0).map(|_| NodeRef::new(self, id))
}
#[must_use]
pub fn get_mut(&mut self, id: NodeId) -> Option<NodeMut<T>> {
if self.arena.get(id.0).is_some() {
Some(NodeMut::new(self, id))
} else {
None
}
}
pub fn iter_nodes(&self) -> impl Iterator<Item = NodeRef<T>> {
self.arena
.iter_items()
.map(move |(index, _)| NodeRef::new(self, NodeId(index)))
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn new_with_root() {
        let tree = Tree::new_with_root(42);
        assert_eq!(tree.root().data(), &42);
        tree.validate();
    }

    /// Covers one successful insert plus both `InvalidLocation` errors; failed inserts must
    /// not add children.
    #[test]
    fn insert() {
        let mut tree = Tree::new_with_root(42);
        let location = Location::FirstChildOf(tree.root().id());
        assert!(tree.insert(1, location).is_ok());
        let invalid_location = Location::AfterSibling(tree.root().id());
        assert!(matches!(
            tree.insert(2, invalid_location),
            Err(InvalidLocation::RootCannotHaveSiblings)
        ));
        let invalid_location = Location::FirstChildOf(NodeId(EntryId(404)));
        assert!(matches!(
            tree.insert(3, invalid_location),
            Err(InvalidLocation::NoSuchNode(_))
        ));
        let children: Vec<_> = tree.root().children().map(|node| *node.data()).collect();
        assert_eq!(children, &[1]);
        tree.validate();
    }

    /// Inserting after a sibling that already has a successor must splice the new node in
    /// the middle of the child list (and keep both sibling directions consistent).
    #[test]
    fn insert_after_sibling() {
        let mut tree = Tree::new_with_root(0);
        let root_id = tree.root().id();
        let first = tree.insert(1, Location::FirstChildOf(root_id)).unwrap();
        let third = tree.insert(3, Location::AfterSibling(first)).unwrap();
        tree.insert(2, Location::AfterSibling(first)).unwrap();
        tree.insert(4, Location::AfterSibling(third)).unwrap();
        let children: Vec<_> = tree.root().children().map(|node| *node.data()).collect();
        assert_eq!(children, vec![1, 2, 3, 4]);
        tree.validate();
    }

    #[test]
    fn data() {
        let mut tree = Tree::new_with_root(42);
        assert_eq!(tree.root().data(), &42);
        assert_eq!(tree.root_mut().data(), &42);
        tree.validate();
    }

    #[test]
    fn data_mut() {
        let mut tree = Tree::new_with_root(42);
        assert_eq!(tree.root().data(), &42);
        *tree.root_mut().data_mut() = 100;
        assert_eq!(tree.root().data(), &100);
        tree.validate();
    }

    #[test]
    fn push_front_child() {
        let mut tree = Tree::new_with_root(42);
        let child_id = tree.root_mut().push_front_child(43).id();
        assert_eq!(tree.get(child_id).unwrap().data(), &43);
        tree.validate();
    }

    #[test]
    fn push_next_sibling() {
        let mut tree = Tree::new_with_root(42);
        let sibling_id = tree
            .root_mut()
            .push_front_child(43)
            .push_next_sibling(44)
            .unwrap()
            .id();
        assert_eq!(tree.get(sibling_id).unwrap().data(), &44);
        tree.validate();
    }

    /// Lookups of an id that was never allocated return `None` rather than panicking.
    #[test]
    fn get_nonexistent() {
        let mut tree = Tree::new_with_root(42);
        assert!(tree.get(NodeId(EntryId(404))).is_none());
        assert!(tree.get_mut(NodeId(EntryId(404))).is_none());
        tree.validate();
    }

    /// `iter_nodes` yields every node; order is not asserted, hence the `HashSet`.
    #[test]
    fn iter_items() {
        let mut tree = Tree::new_with_root(42);
        let mut root = tree.root_mut();
        root.push_front_child(43);
        root.push_front_child(44);
        let set: HashSet<_> = tree.iter_nodes().map(|node| *node.data()).collect();
        assert_eq!(set.len(), 3);
        assert!(set.contains(&42));
        assert!(set.contains(&43));
        assert!(set.contains(&44));
        tree.validate();
    }

    #[test]
    fn remove_root_fails() {
        let mut tree = Tree::new_with_root(42);
        assert!(tree.root_mut().remove().is_err());
        tree.validate();
    }

    #[test]
    fn remove_node() {
        let mut tree = Tree::new_with_root(42);
        let mut root = tree.root_mut();
        let node_id = root.push_front_child(43).id();
        root.push_front_child(44);
        let set: HashSet<_> = tree.iter_nodes().map(|node| *node.data()).collect();
        assert_eq!(set.len(), 3);
        tree.get_mut(node_id).unwrap().remove().unwrap();
        let set: HashSet<_> = tree.iter_nodes().map(|node| *node.data()).collect();
        assert_eq!(set.len(), 2);
        assert!(set.contains(&42));
        assert!(set.contains(&44));
        tree.validate();
    }

    /// Removing a node hands the removed node's value and its descendants' values to the
    /// consumer.
    #[test]
    fn remove_node_with_consumer() {
        let mut tree = Tree::new_with_root(0);
        let mut root = tree.root_mut();
        let mut node = root.push_front_child(1);
        let node_id = node.id();
        node.push_front_child(2).push_next_sibling(3).unwrap();
        let mut removed = HashSet::new();
        tree.get_mut(node_id)
            .unwrap()
            .remove_with_consumer(|value| {
                removed.insert(value);
            })
            .expect("cannot remove");
        assert_eq!(removed.len(), 3);
        assert!(removed.contains(&1));
        assert!(removed.contains(&2));
        assert!(removed.contains(&3));
        tree.validate();
    }

    #[test]
    fn move_node() {
        let mut tree = Tree::new_with_root(0);
        let b_id = tree.root_mut().push_front_child(2).id();
        tree.root_mut().push_front_child(1);
        let children: Vec<_> = tree.root().children().map(|node| *node.data()).collect();
        assert_eq!(children, vec![1, 2]);
        let root_id = tree.root().id();
        tree.get_mut(b_id)
            .unwrap()
            .move_to(Location::FirstChildOf(root_id))
            .expect("Could not move");
        let children: Vec<_> = tree.root().children().map(|node| *node.data()).collect();
        assert_eq!(children, vec![2, 1]);
        tree.validate();
    }

    /// Moving a node to "after itself" is accepted, even for the root.
    #[test]
    fn move_identity() {
        let mut tree = Tree::new_with_root(0);
        let root_id = tree.root().id();
        tree.root_mut()
            .move_to(Location::AfterSibling(root_id))
            .expect("Could not move");
        tree.validate();
    }

    /// A node cannot become a child of its own descendant.
    #[test]
    fn move_under_child_fails() {
        let mut tree = Tree::new_with_root(0);
        let mut root = tree.root_mut();
        let mut a = root.push_front_child(1);
        let b_id = a.push_front_child(2).id();
        let result = a.move_to(Location::FirstChildOf(b_id));
        assert!(result.is_err());
        tree.validate();
    }

    /// A clone taken before a node is created must not see that node. `validate` is not
    /// called here because the node from `create_node` is deliberately left unlinked, which
    /// `validate` would report as unreachable.
    #[test]
    fn persistence() {
        let mut tree = Tree::new_with_root(42);
        let old = tree.clone();
        let index = tree.create_node(5);
        assert!(tree.get(index).is_some());
        assert!(old.get(index).is_none());
    }

    /// Corrupts the tree by making the root its own next sibling; `validate` must panic.
    #[test]
    #[should_panic]
    fn validate() {
        let mut tree = Tree::new_with_root(0);
        tree.arena[tree.root.0].next_sibling = Some(NodeId(EntryId(0)));
        tree.validate();
    }
}