use super::*;
use std::{fmt::Debug, mem::MaybeUninit};
mod iterators;
pub mod sort_util;
pub mod swappable;
pub use self::iterators::StackSignal;
use self::{iterators::*, sort_util::*, swappable::*};
/// Identifier of a tree node; wraps the index used to look the node up in
/// the tree's id-to-slot table.
#[derive(Copy, Clone, Hash, Default, Debug, PartialEq, Eq)]
pub struct NodeID(pub usize);

impl NodeID {
    /// Returns the raw index wrapped by this identifier.
    pub fn as_usize(&self) -> usize {
        let Self(raw) = *self;
        raw
    }
}
/// Read-only view of a single node, as produced by the shared iterators.
#[derive(Copy, Clone)]
pub struct NodeInfo<'a, T> {
/// Id of the parent node, or `None` when the node has no parent.
pub parent: Option<NodeID>,
/// Id of the node itself.
pub id: NodeID,
/// Shared borrow of the node's payload.
pub val: &'a T,
}
/// Mutable view of a single node, as produced by the mutable iterators.
pub struct NodeInfoMut<'a, T> {
/// Id of the parent node, or `None` when the node has no parent.
pub parent: Option<NodeID>,
/// Id of the node itself.
pub id: NodeID,
/// Exclusive borrow of the node's payload.
pub val: &'a mut T,
}
/// Tree stored as a set of parallel vectors, one entry per slot.
///
/// `reconstruct_preorder` sorts the slots by pre-order rank, after which a
/// subtree occupies a contiguous run of slots. Removed slots are kept in the
/// vectors and counted by `nodes_deleted`; `len()` is `data.len()` minus that
/// count.
pub struct LinearTree<T> {
// Pre-order rank per slot, used as the sort key; `!0` marks an unranked slot.
order: Vec<u32>,
// Depth assigned during the pre-order traversal; `!0` is written on removal.
level: Vec<u32>,
// Slot pointer of each node's parent; `Ptr::null()` when there is none.
parent: Vec<Ptr>,
// Node payloads; removed slots hold zeroed, logically uninitialised memory.
data: Vec<MaybeUninit<T>>,
// Scratch stack reused while rebuilding parent pointers.
parent_stack: Vec<usize>,
// Id stored at each slot.
node_id: Vec<NodeID>,
// Next fresh id to hand out when no removed slot can be reused.
node_id_counter: usize,
// Number of removed slots currently kept in the vectors.
nodes_deleted: usize,
// Lookup table indexed by `NodeID`, yielding the slot that holds that id.
id_to_ptr_table: Vec<Ptr>,
}
impl<T> Default for LinearTree<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> LinearTree<T> {
/// Creates an empty tree.
///
/// Only the scratch parent stack is pre-allocated (128 entries); every
/// other vector starts empty.
pub fn new() -> Self {
    let parent_stack = Vec::with_capacity(128);
    Self {
        order: Vec::new(),
        level: Vec::new(),
        parent: Vec::new(),
        data: Vec::new(),
        parent_stack,
        node_id: Vec::new(),
        node_id_counter: 0,
        nodes_deleted: 0,
        id_to_ptr_table: Vec::new(),
    }
}
/// Returns the payloads of the first `len()` slots as a shared slice.
pub fn as_slice(&self) -> &[T] {
    let live = self.len();
    let base = self.data.as_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`. This relies on
    // the first `len()` slots being initialised, which callers that insert
    // through `add_deferred_reconstruction` must guarantee.
    unsafe { std::slice::from_raw_parts(base, live) }
}
/// Returns the payloads of the first `len()` slots as a mutable slice.
pub fn as_slice_mut(&mut self) -> &mut [T] {
    let live = self.len();
    let base = self.data.as_mut_ptr().cast::<T>();
    // SAFETY: `MaybeUninit<T>` has the same layout as `T`. This relies on
    // the first `len()` slots being initialised, which callers that insert
    // through `add_deferred_reconstruction` must guarantee.
    unsafe { std::slice::from_raw_parts_mut(base, live) }
}
/// Re-parents node `id` under `new_parent_id`.
///
/// Only the parent pointer is rewritten; this method does not call
/// `reconstruct_preorder` itself.
///
/// # Panics
///
/// Panics if `id` does not resolve to a slot inside the parent vector.
pub fn set_parent<NID: Copy + Into<NodeID>>(&mut self, id: NID, new_parent_id: NID) {
    let child_slot = self.resolve_id_to_ptr(id.into()).as_usize();
    let parent_ptr = self.resolve_id_to_ptr(new_parent_id.into());
    self.parent[child_slot] = parent_ptr;
}
/// Iterates over the first `len()` slots in storage order, yielding a
/// read-only [`NodeInfo`] for each.
pub fn iter(&self) -> impl Iterator<Item = NodeInfo<'_, T>> {
    let live = self.len();
    let ids = &self.node_id;
    self.data
        .iter()
        .zip(self.parent.iter())
        .enumerate()
        .take(live)
        .map(move |(slot, (cell, &parent_ptr))| {
            // SAFETY: relies on the first `len()` slots being initialised.
            let val = unsafe { cell.assume_init_ref() };
            let parent = if parent_ptr != Ptr::null() {
                Some(ids[parent_ptr.as_usize()])
            } else {
                None
            };
            NodeInfo {
                parent,
                id: ids[slot],
                val,
            }
        })
}
/// Iterates over the first `len()` slots in storage order, yielding a
/// [`NodeInfoMut`] with mutable access to each payload.
pub fn iter_mut(&mut self) -> impl Iterator<Item = NodeInfoMut<'_, T>> {
    let live = self.len();
    let ids = &self.node_id;
    let parents = &self.parent;
    self.data
        .iter_mut()
        .zip(parents.iter())
        .enumerate()
        .take(live)
        .map(move |(slot, (cell, &parent_ptr))| {
            // SAFETY: relies on the first `len()` slots being initialised.
            let val = unsafe { cell.assume_init_mut() };
            let parent = if parent_ptr != Ptr::null() {
                Some(ids[parent_ptr.as_usize()])
            } else {
                None
            };
            NodeInfoMut {
                parent,
                id: ids[slot],
                val,
            }
        })
}
/// Returns a shared reference to the payload of `node_id`.
///
/// Returns `None` when the id is unknown or refers to a node that has been
/// removed. Removed ids still resolve to a slot (the id table is rebuilt
/// over every slot), but that slot lies at or beyond `len()` and holds
/// zeroed memory, so it must never be exposed as a `T`.
pub fn get<NID>(&self, node_id: NID) -> Option<&T>
where
    NID: Copy + Into<NodeID>,
{
    let node_ptr = self.resolve_id_to_ptr(node_id.into());
    if node_ptr == Ptr::null() || node_ptr.as_usize() >= self.len() {
        return None;
    }
    self.data
        .get(node_ptr.as_usize())
        // SAFETY: the slot index is below `len()`, i.e. it is not one of
        // the removed (zeroed) slots kept at the tail of the vector.
        .map(|val| unsafe { val.assume_init_ref() })
}
pub fn get_mut<NID>(&mut self, node_id: NID) -> Option<&mut T>
where
NID: Copy + Into<NodeID>,
{
let node_ptr = self.resolve_id_to_ptr(node_id.into());
self.data
.get_mut(node_ptr.as_usize())
.map(|val| unsafe { val.assume_init_mut() })
}
/// Returns the raw, possibly uninitialised storage cell of `node_id`.
///
/// # Panics
///
/// Panics with `"node_id invalid"` when the id does not resolve to a slot
/// inside the data vector.
pub fn get_mut_uninit<NID>(&mut self, node_id: NID) -> &mut MaybeUninit<T>
where
    NID: Copy + Into<NodeID>,
{
    let slot = self.resolve_id_to_ptr(node_id.into()).as_usize();
    let cell = self.data.get_mut(slot);
    cell.expect("node_id invalid")
}
/// Returns the id of the parent of node `id`.
///
/// Returns `None` when the node has no parent, and also when `id` is
/// unknown or refers to a removed node. Previously an unknown id panicked
/// on the parent lookup, and a removed id read whatever stale parent entry
/// was left in its slot.
pub fn get_parent_id<NID>(&self, id: NID) -> Option<NodeID>
where
    NID: Copy + Into<NodeID>,
{
    let node_ptr = self.resolve_id_to_ptr(id.into());
    if node_ptr == Ptr::null() || node_ptr.as_usize() >= self.len() {
        return None;
    }
    let parent_ptr = self.parent[node_ptr.as_usize()];
    (parent_ptr != Ptr::null()).then(|| self.node_id[parent_ptr.as_usize()])
}
/// Translates a node id into its current slot pointer, or `Ptr::null()`
/// when the id lies outside the lookup table.
fn resolve_id_to_ptr(&self, id: NodeID) -> Ptr {
    match self.id_to_ptr_table.get(id.as_usize()) {
        Some(&ptr) => ptr,
        None => Ptr::null(),
    }
}
/// Inserts a node under `parent_id` without rebuilding the pre-order
/// layout, and returns the id assigned to it.
///
/// This method does not call `reconstruct_preorder`.
///
/// # Safety
///
/// `data` may be uninitialised, while the reading accessors (`iter`,
/// `as_slice`, ...) treat every slot below `len()` as initialised. The
/// caller must initialise the payload before any of them observes it.
///
/// # Panics
///
/// In debug builds, panics with `"always add parent first"` when slot 0
/// ends up with a non-null parent.
pub unsafe fn add_deferred_reconstruction(
    &mut self,
    data: MaybeUninit<T>,
    parent_id: NodeID,
) -> NodeID {
    let parent_ptr = self.resolve_id_to_ptr(parent_id);
    let (new_id, _slot) = self.allocate_node_uninit(data, parent_ptr);
    debug_assert!(self.parent[0] == Ptr::null(), "always add parent first");
    new_id
}
/// Inserts `data` as a child of `parent_id`, rebuilds the pre-order layout
/// and returns the id assigned to the new node.
///
/// An id that does not resolve (for example `NodeID::default()` on an
/// empty tree) yields a null parent pointer.
///
/// # Panics
///
/// In debug builds, panics with `"always add parent first"` when slot 0
/// ends up with a non-null parent.
pub fn add(&mut self, data: T, parent_id: NodeID) -> NodeID {
    let parent_ptr = self.resolve_id_to_ptr(parent_id);
    let (new_id, _slot) = self.allocate_node(data, parent_ptr);
    debug_assert!(self.parent[0] == Ptr::null(), "always add parent first");
    self.reconstruct_preorder();
    new_id
}
/// Recomputes the pre-order rank of every slot starting from slot 0, then
/// co-sorts the payload, level and id vectors by that rank.
///
/// The parent vector is not part of the sort.
fn recompute_prefix_ordering(&mut self) {
    self.parent_stack.clear();
    self.compute_pre_order_traversal(Ptr::from(0));
    quick_co_sort(
        &mut self.order,
        [
            &mut Swappable::new(&mut self.data),
            &mut Swappable::new(&mut self.level),
            &mut Swappable::new(&mut self.node_id),
        ],
    );
}
/// Rebuilds the whole linear layout after structural edits.
///
/// The three steps run in a fixed order: the slots are first re-sorted into
/// pre-order, then the parent pointers are rebuilt from the sorted levels,
/// and finally the id-to-slot table is refreshed for the new slot positions.
pub fn reconstruct_preorder(&mut self) {
self.recompute_prefix_ordering();
self.reconstruct_parent_pointers();
self.reconstruct_id_to_ptr_table();
}
/// Refreshes the id-to-slot table so every id points at the slot that
/// currently stores it. All slots are covered, removed ones included.
fn reconstruct_id_to_ptr_table(&mut self) {
    let slots = self.data.len();
    for (slot, nid) in self.node_id[..slots].iter().enumerate() {
        self.id_to_ptr_table[nid.as_usize()] = Ptr::from(slot);
    }
}
/// Rebuilds the parent pointer of slots `1..len()` from the level vector,
/// assuming the slots are already in pre-order.
///
/// A stack of ancestor slots is maintained while walking the slots in
/// order; the parent of a slot is whatever is on top of the stack once the
/// stack has been unwound to the slot's depth.
///
/// # Panics
///
/// Panics with `"root should exist"` if the ancestor stack runs empty.
fn reconstruct_parent_pointers(&mut self) {
    let live = self.len();
    let levels = &self.level;
    let stack = &mut self.parent_stack;
    let parents = &mut self.parent;

    stack.clear();
    stack.push(Ptr::from(0).as_usize());

    for slot in 1..live {
        let cur_level = levels[slot] as usize;
        let diff = cur_level as isize - levels[slot - 1] as isize;
        if diff <= 0 {
            // Not deeper than the previous slot: unwind to the entry at the
            // same depth, then drop that entry as well.
            while let Some(&top) = stack.last() {
                if levels[top] as usize == cur_level {
                    break;
                }
                stack.pop();
            }
            stack.pop();
        }
        let parent_slot = *stack.last().expect("root should exist");
        parents[slot] = Ptr::from(parent_slot);
        stack.push(slot);
    }
}
/// Returns `true` when `len()` is zero.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the number of slots in the data vector minus the number of
/// slots currently counted as removed.
pub fn len(&self) -> usize {
self.data.len() - self.nodes_deleted
}
/// Marks every slot as unranked (`!0`) and then assigns pre-order ranks
/// and levels by walking the tree from `root`.
fn compute_pre_order_traversal(&mut self, root: Ptr) {
    self.order.fill(!0);
    let mut next_rank = 0;
    self.compute_pre_order_traversal_helper(root, 0, &mut next_rank)
}
/// Recursive step of the pre-order traversal.
///
/// Writes `level` into the level vector for `root`, stores the current rank
/// in the order vector (only for slot 0 or for slots whose parent pointer is
/// non-null), bumps the rank counter, and recurses into every child reported
/// by `get_child_nodes` with `level + 1`.
fn compute_pre_order_traversal_helper(&mut self, root: Ptr, level: u32, order_idx: &mut u32) {
// Raw pointer to `self`, used below to recurse while the child iterator
// still borrows `self`.
let self_ptr = self as *mut Self;
let root_idx = root.as_usize();
self.level[root_idx] = level;
let is_root_or_non_root_but_has_parent =
root_idx == 0 || root_idx > 0 && self.parent[root_idx] != Ptr::null();
if is_root_or_non_root_but_has_parent {
self.order[root_idx] = *order_idx;
}
// The rank counter advances even when no rank was stored for this slot.
*order_idx += 1;
for child in self.get_child_nodes(root) {
// NOTE(review): this creates a second `&mut Self` while the iterator
// returned by `get_child_nodes` holds a shared borrow of `self`; the
// recursion writes `level` and `order` but not `parent`, which is what
// the iterator reads. Confirm this aliasing is acceptable.
let local_self = unsafe { &mut *self_ptr };
local_self.compute_pre_order_traversal_helper(child, level + 1, order_idx);
}
}
/// Yields, in ascending slot order, every slot in `0..len()` whose parent
/// pointer equals `root`.
///
/// Each call is a linear scan over the live slots.
fn get_child_nodes<PTR>(&self, root: PTR) -> impl Iterator<Item = Ptr> + '_
where
    PTR: Into<Ptr>,
{
    let wanted_parent = root.into();
    let parents = &self.parent;
    let live = self.len();
    (0..live)
        .map(Ptr::from)
        .filter(move |candidate| parents[candidate.as_usize()] == wanted_parent)
}
/// Stores an initialised payload in a slot whose parent pointer is
/// `parent`, returning the node's id and slot pointer.
fn allocate_node(&mut self, data: T, parent: Ptr) -> (NodeID, Ptr) {
    let cell = MaybeUninit::new(data);
    // SAFETY: the payload was initialised on the line above.
    unsafe { self.allocate_node_uninit(cell, parent) }
}
/// Stores `data` in a slot and returns the node's id and slot pointer.
///
/// When removed slots exist, the first one (at index `len()`) is reused
/// together with the id already stored there; otherwise a new slot is
/// appended to every vector and a fresh id is taken from the counter.
///
/// # Safety
///
/// `data` may be uninitialised; the caller is responsible for initialising
/// it before the slot is read as a `T`.
unsafe fn allocate_node_uninit(&mut self, data: MaybeUninit<T>, parent: Ptr) -> (NodeID, Ptr) {
    debug_assert!(
        self.nodes_deleted <= self.data.len(),
        "nodes_deleted cannot be greater than the length of the array"
    );

    if self.nodes_deleted == 0 {
        // No removed slot to recycle: grow every parallel vector by one.
        let slot = self.data.len();
        let fresh_id = NodeID(self.node_id_counter);
        self.order.push(!0);
        self.data.push(data);
        self.level.push(0);
        self.parent.push(parent);
        self.node_id.push(fresh_id);
        self.id_to_ptr_table.push(Ptr::from(slot));
        self.node_id_counter += 1;
        return (fresh_id, Ptr::from(slot));
    }

    // Recycle the first removed slot, keeping the id stored in it.
    let slot = self.data.len() - self.nodes_deleted;
    let reused_id = self.node_id[slot];
    self.order[slot] = !0;
    self.data[slot] = data;
    self.parent[slot] = parent;
    self.level[slot] = 0;
    self.nodes_deleted -= 1;
    (reused_id, Ptr::from(slot))
}
/// Iterates over the direct children of `root`.
///
/// Walks the whole subtree of `root` and keeps only the nodes whose parent
/// is `root` itself.
pub fn iter_children<NID: Into<NodeID>>(
    &self,
    root: NID,
) -> impl Iterator<Item = NodeInfo<'_, T>> {
    let root_id = root.into();
    self.iter_subtree(root_id)
        .filter(move |node| node.parent == Some(root_id))
}
/// Iterates mutably over the direct children of `root`.
///
/// Walks the whole subtree of `root` and keeps only the nodes whose parent
/// is `root` itself.
pub fn iter_children_mut<NID: Into<NodeID>>(
    &mut self,
    root: NID,
) -> impl Iterator<Item = NodeInfoMut<'_, T>> {
    let root_id = root.into();
    self.iter_subtree_mut(root_id)
        .filter(move |node| node.parent == Some(root_id))
}
/// Iterates over the descendants of `root` (excluding `root` itself).
///
/// Starting at the slot right after `root`, slots are yielded for as long
/// as they are below `len()` and their level is greater than `root`'s.
///
/// # Panics
///
/// Panics if `root` does not resolve to a slot inside the level vector.
pub fn iter_subtree<NID: Into<NodeID>>(
    &self,
    root: NID,
) -> impl Iterator<Item = NodeInfo<'_, T>> {
    let root_ptr = self.resolve_id_to_ptr(root.into());
    let live = self.len();
    let (level, node_id, data, parent) = (&self.level, &self.node_id, &self.data, &self.parent);
    let root_level = level[root_ptr];
    (root_ptr.as_usize() + 1..)
        .take_while(move |&slot| slot < live && level[slot] > root_level)
        .map(move |slot| {
            let parent_ptr = parent[slot];
            let parent_id = if parent_ptr != Ptr::null() {
                Some(node_id[parent_ptr])
            } else {
                None
            };
            NodeInfo {
                parent: parent_id,
                id: node_id[slot],
                // SAFETY: relies on slots below `len()` being initialised.
                val: unsafe { data[slot].assume_init_ref() },
            }
        })
}
/// Iterates mutably over the descendants of `root` (excluding `root`).
///
/// Starting at the slot right after `root`, slots are yielded for as long as
/// they are below `len()` and their level is greater than `root`'s.
///
/// # Panics
///
/// Panics if `root` does not resolve to a slot inside the level vector.
pub fn iter_subtree_mut<NID: Into<NodeID>>(
&mut self,
root: NID,
) -> impl Iterator<Item = NodeInfoMut<'_, T>> {
let ptr = self.resolve_id_to_ptr(root.into());
let len = self.len();
let level = &mut self.level;
// Raw pointers to the other vectors; they are re-borrowed inside the
// closure so each yielded item can carry a `&mut T`.
let node_id_ptr = &mut self.node_id as *mut Vec<_>;
let data_ptr = &mut self.data as *mut Vec<MaybeUninit<_>>;
let parent_ptr = &mut self.parent as *mut Vec<_>;
let root_level = level[ptr];
(ptr.as_usize() + 1..)
.take_while(move |&ptr| ptr < len && level[ptr] > root_level)
.map(move |ptr| unsafe {
// NOTE(review): every item re-creates `&mut` references to the whole
// vectors; the yielded `val` borrows differ per item because each slot
// index is visited once. Confirm this is sound under the aliasing rules.
let node_id = &mut *node_id_ptr;
let data = &mut *data_ptr;
let parent = &mut *parent_ptr;
NodeInfoMut {
parent: (parent[ptr] != Ptr::null()).then(|| node_id[parent[ptr]]),
id: node_id[ptr],
val: data[ptr].assume_init_mut(),
}
})
}
pub fn remove(&mut self, id: NodeID, removed_vals: &mut Vec<T>) {
removed_vals.clear();
let ptr = self.resolve_id_to_ptr(id);
let deleted_level = self.level[ptr];
let len = self.len();
self.remove_single_node(ptr, removed_vals);
let mut subtree_node = ptr + 1;
while subtree_node.as_usize() < len && self.level[subtree_node] > deleted_level {
self.remove_single_node(subtree_node, removed_vals);
subtree_node += 1;
}
self.reconstruct_preorder();
}
/// Marks the slot at `ptr` as removed and moves its payload into
/// `removed_vals`.
///
/// The slot's rank and level are set to `!0`, its parent pointer to null,
/// its storage is replaced with zeroed memory and the removed-slot counter
/// is incremented.
fn remove_single_node(&mut self, ptr: Ptr, removed_vals: &mut Vec<T>) {
    self.order[ptr] = !0;
    self.parent[ptr] = Ptr::null();
    self.level[ptr] = !0;
    let taken = std::mem::replace(&mut self.data[ptr], MaybeUninit::zeroed());
    // SAFETY: relies on the caller passing a slot that still holds an
    // initialised payload.
    removed_vals.push(unsafe { taken.assume_init() });
    self.nodes_deleted += 1;
}
/// Returns a `StackSignalIteratorMut` over this tree.
pub fn iter_mut_stack_signals(&mut self) -> StackSignalIteratorMut<'_, T> {
StackSignalIteratorMut::new(self)
}
/// Returns a `StackSignalIterator` over this tree.
pub fn iter_stack_signals(&self) -> StackSignalIterator<'_, T> {
StackSignalIterator::new(self)
}
/// Prints the tree to stdout, one node id per line, prefixed with a
/// dash-based indent that ends in `>` for indented lines.
pub fn print_by_ids(&mut self) {
    let indent = "--";
    let mut indents = String::new();
    for (signal, item, _) in StackSignalIteratorMut::new(self) {
        match signal {
            StackSignal::Push => indents.push_str(indent),
            StackSignal::Pop { n_times } => {
                for _ in 0..indent.len() * n_times {
                    indents.pop();
                }
            }
            StackSignal::Nop => {}
        }
        if indents.is_empty() {
            println!("{}{}", indents, item.as_usize());
        } else {
            // Temporarily turn the last dash into an arrow head.
            indents.pop();
            indents.push('>');
            println!("{}{}", indents, item.as_usize());
            indents.pop();
            indents.push('-');
        }
    }
}
}
impl<T> LinearTree<T>
where
    T: Debug + Display,
{
    /// Prints the tree to stdout, one payload per line, prefixed with a
    /// dash-based indent that ends in `>` for indented lines.
    pub fn print(&mut self) {
        let indent = "--";
        let mut indents = String::new();
        for (signal, _, item) in StackSignalIteratorMut::new(self) {
            match signal {
                StackSignal::Push => indents.push_str(indent),
                StackSignal::Pop { n_times } => {
                    for _ in 0..indent.len() * n_times {
                        indents.pop();
                    }
                }
                StackSignal::Nop => {}
            }
            if indents.is_empty() {
                println!("{}{}", indents, item);
            } else {
                // Temporarily turn the last dash into an arrow head.
                indents.pop();
                indents.push('>');
                println!("{}{}", indents, item);
                indents.pop();
                indents.push('-');
            }
        }
    }
}
// Smoke test: repeatedly removes a subtree and re-adds an equivalent one,
// printing the tree before and after. It has no assertions; it only checks
// that the sequence runs without panicking.
#[test]
pub fn remove_sanity() {
let mut removed_nodes = vec![];
let mut tree = LinearTree::<i32>::new();
// On an empty tree the default id does not resolve, so this node gets a
// null parent and becomes the root.
let root = tree.add(1, NodeID::default());
let lb = tree.add(2, root);
let rb = tree.add(3, root);
tree.add(5, lb);
tree.add(4, rb);
tree.add(7, rb);
tree.add(9, rb);
tree.print();
for _ in 0..10 {
// `rb` here is the binding from before the loop: the `let rb` below only
// shadows it for the rest of the current iteration.
// NOTE(review): from the second iteration on this presumably relies on the
// re-added node reusing the removed node's id — confirm.
tree.remove(rb, &mut removed_nodes);
let rb = tree.add(3, root);
tree.add(4, rb);
tree.add(7, rb);
tree.add(9, rb);
}
tree.print();
}
// Checks that removing a node hands its payload over to the caller without
// dropping it: the drop only happens once the output vector is cleared.
#[test]
pub fn drop_sanity() {
    use std::{cell::*, rc::*};

    struct HasHeapStuff {
        dropped: Rc<Cell<bool>>,
        _a: String,
        _b: Vec<i32>,
    }
    impl Drop for HasHeapStuff {
        fn drop(&mut self) {
            self.dropped.set(true);
        }
    }

    let drop_flag = Rc::new(Cell::new(false));
    let payload = HasHeapStuff {
        dropped: Rc::clone(&drop_flag),
        _a: String::from("hello world my name is adam poo poo head"),
        _b: vec![0; 1_000],
    };

    let mut tree = LinearTree::new();
    let node = tree.add(payload, NodeID::default());

    let mut removed_nodes = Vec::new();
    tree.remove(node, &mut removed_nodes);
    // The payload now lives in `removed_nodes`, so it must still be alive.
    assert!(!drop_flag.get());
    removed_nodes.clear();
    assert!(drop_flag.get());
}