use bytemuck::{Pod, Zeroable};
use std::cmp::max;
/// Index value meaning "no node". Node indices are 1-based (`node!` subtracts 1), so 0
/// never names a real slot; it marks an empty root and a missing child link.
const SENTINEL: u32 = 0;
/// Slots of a node's `registers` array, read and written as `registers[register as usize]`.
///
/// `Left` and `Right` hold 1-based child indices (or `SENTINEL`). `Height` holds the node's
/// height while it is live, and the next free-list index while the slot is free.
#[derive(Copy, Clone)]
enum Register {
    Left = 0,
    Right = 1,
    Height = 2,
}
/// Slots of the allocator header's `fields` array, accessed as `fields[field as usize]`.
enum Field {
    /// Index of the root node, or `SENTINEL` when the tree is empty.
    Root = 0,
    /// Number of live nodes.
    Size = 1,
    /// Maximum number of live nodes.
    Capacity = 2,
    /// First slot on the free list of released nodes.
    FreeListHead = 3,
    /// Next never-used node index.
    Sequence = 4,
}
/// One step of a root-to-node descent: `(parent, branch taken from parent, node)`.
/// The root's entry has neither a parent nor a branch: `(None, None, root)`.
type Ancestor = (Option<u32>, Option<Register>, u32);
/// Accesses the node whose 1-based index is `$idx`; it is stored at `$slots[$idx - 1]`.
///
/// Expands to a place expression, so the result can be read, borrowed (`&mut node!(..)`)
/// or assigned. `$idx` must not be `SENTINEL` (0), because the subtraction would underflow.
macro_rules! node {
    ( $slots:expr, $idx:expr ) => {
        $slots[($idx - 1) as usize]
    };
}
/// Generates the read-only API shared by `AVLTree` and `AVLTreeMut`.
///
/// The target type must have an `allocator` field (the header) and a `nodes` field (the
/// node slots indexed through `node!`).
macro_rules! readonly_impl {
    ( $tree:tt ) => {
        impl<
                'a,
                K: PartialOrd + Default + Copy + Clone + Pod + Zeroable,
                V: Default + Copy + Clone + Pod + Zeroable,
            > $tree<'a, K, V>
        {
            /// Bytes needed to back a tree of `capacity` nodes: the allocator header
            /// followed by `capacity` node slots.
            pub const fn data_len(capacity: usize) -> usize {
                let header = std::mem::size_of::<Allocator>();
                let slot = std::mem::size_of::<Node<K, V>>();
                header + slot * capacity
            }

            /// Maximum number of nodes recorded in the header.
            pub fn capacity(&self) -> usize {
                let stored = self.allocator.get_field(Field::Capacity);
                stored as usize
            }

            /// Number of nodes currently stored.
            pub fn len(&self) -> usize {
                let stored = self.allocator.get_field(Field::Size);
                stored as usize
            }

            /// `true` once the stored node count has reached the capacity.
            pub fn is_full(&self) -> bool {
                let size = self.allocator.get_field(Field::Size);
                let limit = self.allocator.get_field(Field::Capacity);
                size >= limit
            }

            /// `true` when no nodes are stored.
            pub fn is_empty(&self) -> bool {
                self.len() == 0
            }

            /// Returns a copy of the value stored under `key`, if present.
            pub fn get(&self, key: &K) -> Option<V> {
                let index = self.find(key)?;
                Some(node!(self.nodes, index).value)
            }

            /// Returns the smallest key in the tree, or `None` when it is empty.
            pub fn lowest(&self) -> Option<K> {
                let mut cursor = self.allocator.get_field(Field::Root);
                if cursor == SENTINEL {
                    return None;
                }
                // Follow left links until a node has no left child.
                loop {
                    let next = node!(self.nodes, cursor).get_register(Register::Left);
                    if next == SENTINEL {
                        return Some(node!(self.nodes, cursor).key);
                    }
                    cursor = next;
                }
            }

            /// `true` when `key` is stored in the tree.
            pub fn contains(&self, key: &K) -> bool {
                self.find(key).is_some()
            }

            /// Descends from the root looking for `key` and returns its 1-based node index.
            ///
            /// A key that is neither less than nor greater than a node's key (per
            /// `PartialOrd`) is treated as equal to it.
            fn find(&self, key: &K) -> Option<u32> {
                let mut cursor = self.allocator.get_field(Field::Root);
                while cursor != SENTINEL {
                    let current = &node!(self.nodes, cursor);
                    cursor = if *key < current.key {
                        current.get_register(Register::Left)
                    } else if *key > current.key {
                        current.get_register(Register::Right)
                    } else {
                        return Some(cursor);
                    };
                }
                None
            }
        }
    };
}
/// Read-only view of an AVL tree stored in a borrowed byte buffer.
pub struct AVLTree<'a, K, V>
where
    K: PartialOrd + Default + Copy + Clone + Pod + Zeroable,
    V: Default + Copy + Clone + Pod + Zeroable,
{
    /// Header holding the root, size, capacity and free-list bookkeeping.
    allocator: &'a Allocator,
    /// Node slots; node index `i` lives at `nodes[i - 1]`.
    nodes: &'a [Node<K, V>],
}
readonly_impl!(AVLTree);
impl<'a, K, V> AVLTree<'a, K, V>
where
    K: PartialOrd + Default + Copy + Clone + Pod + Zeroable,
    V: Default + Copy + Clone + Pod + Zeroable,
{
    /// Views `bytes` as a read-only tree. The first `size_of::<Allocator>()` bytes are the
    /// allocator header; the remainder is node storage.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is shorter than the header, or if either `bytemuck` cast fails
    /// (misaligned input, or node storage whose length is not a multiple of the node size).
    pub fn from_bytes(bytes: &'a [u8]) -> Self {
        let header_len = std::mem::size_of::<Allocator>();
        let (header, body) = bytes.split_at(header_len);
        Self {
            allocator: bytemuck::from_bytes::<Allocator>(header),
            nodes: bytemuck::cast_slice(body),
        }
    }
}
/// Mutable view of an AVL tree stored in a borrowed byte buffer.
pub struct AVLTreeMut<'a, K, V>
where
    K: PartialOrd + Default + Copy + Clone + Pod + Zeroable,
    V: Default + Copy + Clone + Pod + Zeroable,
{
    /// Header holding the root, size, capacity and free-list bookkeeping.
    allocator: &'a mut Allocator,
    /// Node slots; node index `i` lives at `nodes[i - 1]`.
    nodes: &'a mut [Node<K, V>],
}
readonly_impl!(AVLTreeMut);
impl<
        'a,
        K: PartialOrd + Default + Copy + Clone + Pod + Zeroable,
        V: Default + Copy + Clone + Pod + Zeroable,
    > AVLTreeMut<'a, K, V>
{
    /// Views `bytes` as a mutable tree. The first `size_of::<Allocator>()` bytes are the
    /// allocator header; the remainder is node storage.
    ///
    /// # Panics
    ///
    /// Panics if `bytes` is shorter than the header, or if either `bytemuck` cast fails
    /// (misaligned input, or node storage whose length is not a multiple of the node size).
    pub fn from_bytes_mut(bytes: &'a mut [u8]) -> Self {
        let (allocator, nodes) = bytes.split_at_mut(std::mem::size_of::<Allocator>());
        let allocator = bytemuck::from_bytes_mut::<Allocator>(allocator);
        let nodes = bytemuck::cast_slice_mut(nodes);
        Self { allocator, nodes }
    }

    /// Resets the allocator header to an empty tree holding at most `capacity` nodes.
    ///
    /// Node storage is not cleared here. `add` fully re-initializes every slot it hands
    /// out, so stale bytes in the node region never reach the tree. `capacity` should not
    /// exceed the number of node slots in the buffer; otherwise allocating a slot past the
    /// end of the node slice panics.
    pub fn initialize(&mut self, capacity: u32) {
        self.allocator.initialize(capacity)
    }

    /// Returns a mutable reference to the value stored under `key`, if present.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        self.find(key)
            .map(|node_index| &mut node!(self.nodes, node_index).value)
    }

    /// Inserts `key` with `value` and returns the 1-based index of the new node.
    ///
    /// Returns `None` and leaves the tree unchanged in two cases:
    /// - `key` is already present;
    /// - the tree is full, including a tree whose capacity is 0.
    pub fn insert(&mut self, key: K, value: V) -> Option<u32> {
        // A full tree cannot accept a new key, and a duplicate key would be rejected
        // anyway, so bail out before touching the allocator. Doing this before the
        // empty-root fast path keeps a capacity-0 tree from reaching `add` (which panics).
        if self.is_full() {
            return None;
        }
        let mut reference_node = self.allocator.get_field(Field::Root);
        if reference_node == SENTINEL {
            let root = self.add(key, value);
            self.allocator.set_field(Field::Root, root);
            return Some(root);
        }
        // Nodes visited on the way down, root first; `rebalance` walks them bottom-up.
        let mut path: Vec<Ancestor> = Vec::with_capacity(self.path_capacity());
        path.push((None, None, reference_node));
        loop {
            let current_key = node!(self.nodes, reference_node).key;
            let parent = reference_node;
            let branch = if key < current_key {
                reference_node = node!(self.nodes, parent).get_register(Register::Left);
                Register::Left
            } else if key > current_key {
                reference_node = node!(self.nodes, parent).get_register(Register::Right);
                Register::Right
            } else {
                // Duplicate key: nothing has been modified.
                return None;
            };
            if reference_node == SENTINEL {
                // Found the empty link: hang the new leaf off `parent`.
                reference_node = self.add(key, value);
                self.update_child(parent, branch, reference_node);
                break;
            }
            path.push((Some(parent), Some(branch), reference_node));
        }
        self.rebalance(path);
        Some(reference_node)
    }

    /// Removes `key` and returns its value, or `None` when the key is not present.
    ///
    /// A node with two children is replaced by its in-order successor (the leftmost node
    /// of its right subtree) through link updates. The removed slot is then returned to
    /// the free list.
    pub fn remove(&mut self, key: &K) -> Option<V> {
        let mut node_index = self.allocator.get_field(Field::Root);
        if node_index == SENTINEL {
            return None;
        }
        // Each step records `(parent, branch, child)`. When the key is found, the last
        // entry is the target node itself.
        let mut path: Vec<Ancestor> = Vec::with_capacity(self.path_capacity());
        path.push((None, None, node_index));
        while node_index != SENTINEL {
            let current_key = node!(self.nodes, node_index).key;
            let parent = node_index;
            let branch = if *key < current_key {
                node_index = node!(self.nodes, parent).get_register(Register::Left);
                Register::Left
            } else if *key > current_key {
                node_index = node!(self.nodes, parent).get_register(Register::Right);
                Register::Right
            } else {
                break;
            };
            path.push((Some(parent), Some(branch), node_index));
        }
        if node_index == SENTINEL {
            // Key not present; nothing has been modified.
            return None;
        }
        let left = node!(self.nodes, node_index).get_register(Register::Left);
        let right = node!(self.nodes, node_index).get_register(Register::Right);
        let replacement = if left != SENTINEL && right != SENTINEL {
            // Two children: find the in-order successor by walking the left spine of the
            // right subtree, recording each step for rebalancing.
            let mut leftmost = right;
            let mut leftmost_parent = SENTINEL;
            let mut inner_path = Vec::with_capacity(self.path_capacity());
            while node!(self.nodes, leftmost).get_register(Register::Left) != SENTINEL {
                leftmost_parent = leftmost;
                leftmost = node!(self.nodes, leftmost).get_register(Register::Left);
                inner_path.push((Some(leftmost_parent), Some(Register::Left), leftmost));
            }
            if leftmost_parent != SENTINEL {
                // Detach the successor: its right subtree takes its former place.
                self.update_child(
                    leftmost_parent,
                    Register::Left,
                    node!(self.nodes, leftmost).get_register(Register::Right),
                );
            }
            // The successor adopts the target's children.
            self.update_child(leftmost, Register::Left, left);
            if right != leftmost {
                self.update_child(leftmost, Register::Right, right);
            }
            // Point the target's parent (if any) at the successor, and let the successor
            // take over the target's path entry.
            let (parent, branch, _) = path.pop().unwrap();
            if let Some(parent) = parent {
                self.update_child(parent, branch.expect("invalid tree structure"), leftmost);
            }
            path.push((parent, branch, leftmost));
            if right != leftmost {
                path.push((Some(leftmost), Some(Register::Right), right));
            }
            // The last spine entry describes the successor's old position, which no longer
            // exists; the remaining spine nodes are rebalanced before the successor.
            if !inner_path.is_empty() {
                inner_path.pop();
            }
            path.extend(inner_path);
            leftmost
        } else {
            // Zero or one child: that child (or SENTINEL) replaces the target.
            let child = if left == SENTINEL && right == SENTINEL {
                SENTINEL
            } else if left != SENTINEL {
                left
            } else {
                right
            };
            let (parent, branch, _) = path.pop().unwrap();
            if let Some(parent) = parent {
                self.update_child(parent, branch.expect("invalid tree structure"), child);
                if child != SENTINEL {
                    path.push((Some(parent), branch, child));
                }
            }
            child
        };
        if node_index == self.allocator.get_field(Field::Root) {
            self.allocator.set_field(Field::Root, replacement);
        }
        self.rebalance(path);
        self.remove_node(node_index)
    }

    /// Capacity hint for the ancestor vectors built by `insert` and `remove`.
    ///
    /// AVL height grows at most about 1.44 * log2(n + 2), so reserving that many entries
    /// (plus one) usually avoids reallocating during a descent. It is only a hint; the
    /// vectors still grow if it is ever too small.
    fn path_capacity(&self) -> usize {
        let nodes = (self.len() + 2) as f64;
        (1.45 * nodes.log2()).ceil() as usize + 1
    }

    /// Allocates a slot for `key`/`value`, increments the size counter and returns the
    /// slot's 1-based index.
    ///
    /// The slot comes from the free list when that is non-empty; otherwise it is the next
    /// never-used index from the sequence counter. The slot is fully re-initialized: both
    /// child links are cleared and the height is 0.
    ///
    /// # Panics
    ///
    /// Panics if every index up to the capacity has already been handed out.
    fn add(&mut self, key: K, value: V) -> u32 {
        let free_node = self.allocator.get_field(Field::FreeListHead);
        let sequence = self.allocator.get_field(Field::Sequence);
        if free_node == sequence {
            // The free list is terminated by the current sequence value, so
            // `head == sequence` means it is empty: take a fresh slot.
            if (sequence - 1) == self.allocator.get_field(Field::Capacity) {
                panic!(
                    "tree is full ({} nodes)",
                    self.allocator.get_field(Field::Size)
                );
            }
            self.allocator.set_field(Field::Sequence, sequence + 1);
            self.allocator.set_field(Field::FreeListHead, sequence + 1);
        } else {
            // Pop the free list. A freed slot stores the next free index in its height
            // register (see `remove_node`).
            self.allocator.set_field(
                Field::FreeListHead,
                node!(self.nodes, free_node).get_register(Register::Height),
            );
        }
        // Reset every register, not just the height. A never-used slot may still hold
        // stale child links if the buffer was not zeroed before `initialize`.
        node!(self.nodes, free_node).initialize(key, value);
        self.allocator
            .set_field(Field::Size, self.allocator.get_field(Field::Size) + 1);
        free_node
    }

    /// Walks `path` from the deepest entry back to the root.
    ///
    /// At each node it refreshes the height and applies a single or double rotation
    /// wherever the balance factor falls outside [-1, 1]. A rotated subtree's new root is
    /// re-attached to the recorded parent, or becomes the tree root when the entry has no
    /// parent.
    fn rebalance(&mut self, path: Vec<Ancestor>) {
        for (parent, branch, child) in path.iter().rev() {
            let left = node!(self.nodes, *child).get_register(Register::Left);
            let right = node!(self.nodes, *child).get_register(Register::Right);
            let balance_factor = self.balance_factor(left, right);
            let index = if balance_factor > 1 {
                // Left-heavy. If the left child leans right (left-right case), rotate it
                // left first, then rotate `child` right.
                let left_left = node!(self.nodes, left).get_register(Register::Left);
                let left_right = node!(self.nodes, left).get_register(Register::Right);
                let left_balance_factor = self.balance_factor(left_left, left_right);
                if left_balance_factor < 0 {
                    let index = self.left_rotate(left);
                    self.update_child(*child, Register::Left, index);
                }
                Some(self.right_rotate(*child))
            } else if balance_factor < -1 {
                // Mirror image: right-heavy, handling the right-left case first.
                let right_left = node!(self.nodes, right).get_register(Register::Left);
                let right_right = node!(self.nodes, right).get_register(Register::Right);
                let right_balance_factor = self.balance_factor(right_left, right_right);
                if right_balance_factor > 0 {
                    let index = self.right_rotate(right);
                    self.update_child(*child, Register::Right, index);
                }
                Some(self.left_rotate(*child))
            } else {
                // Within balance: only the height may have changed.
                self.update_height(*child);
                None
            };
            if let Some(index) = index {
                if let Some(parent) = parent {
                    self.update_child(*parent, branch.expect("invalid tree structure"), index);
                } else {
                    self.allocator.set_field(Field::Root, index);
                    self.update_height(index);
                }
            }
        }
    }

    /// Left subtree height minus right subtree height.
    ///
    /// An absent subtree (`SENTINEL`) counts as 0; a present one counts as its stored
    /// height plus 1.
    fn balance_factor(&self, left: u32, right: u32) -> i32 {
        let left_height = if left != SENTINEL {
            node!(self.nodes, left).get_register(Register::Height) as i32 + 1
        } else {
            0
        };
        let right_height = if right != SENTINEL {
            node!(self.nodes, right).get_register(Register::Height) as i32 + 1
        } else {
            0
        };
        left_height - right_height
    }

    /// Rotates the subtree rooted at `index` to the left and returns the new subtree root
    /// (the former right child). Heights of both moved nodes are updated, lower one first.
    fn left_rotate(&mut self, index: u32) -> u32 {
        let right = node!(self.nodes, index).get_register(Register::Right);
        let right_left = node!(self.nodes, right).get_register(Register::Left);
        self.update_child(index, Register::Right, right_left);
        self.update_child(right, Register::Left, index);
        right
    }

    /// Rotates the subtree rooted at `index` to the right and returns the new subtree root
    /// (the former left child). Heights of both moved nodes are updated, lower one first.
    fn right_rotate(&mut self, index: u32) -> u32 {
        let left = node!(self.nodes, index).get_register(Register::Left);
        let left_right = node!(self.nodes, left).get_register(Register::Right);
        self.update_child(index, Register::Left, left_right);
        self.update_child(left, Register::Right, index);
        left
    }

    /// Sets `parent`'s `branch` link to `child`, then recomputes `parent`'s height.
    ///
    /// # Panics
    ///
    /// Panics with "invalid branch" if `branch` is `Register::Height`.
    #[inline]
    fn update_child(&mut self, parent: u32, branch: Register, child: u32) {
        match branch {
            Register::Left | Register::Right => {
                node!(self.nodes, parent).set_register(branch, child)
            }
            Register::Height => panic!("invalid branch"),
        }
        self.update_height(parent);
    }

    /// Recomputes the height of `index` from its children's stored heights.
    ///
    /// A leaf has height 0; otherwise the height is one more than the taller child.
    fn update_height(&mut self, index: u32) {
        let left = node!(self.nodes, index).get_register(Register::Left);
        let right = node!(self.nodes, index).get_register(Register::Right);
        let height = if left == SENTINEL && right == SENTINEL {
            0
        } else {
            let left_height = if left != SENTINEL {
                node!(self.nodes, left).get_register(Register::Height)
            } else {
                0
            };
            let right_height = if right != SENTINEL {
                node!(self.nodes, right).get_register(Register::Height)
            } else {
                0
            };
            max(left_height, right_height) + 1
        };
        node!(self.nodes, index).set_register(Register::Height, height);
    }

    /// Releases slot `index` and returns the value it held (`None` for `SENTINEL`).
    ///
    /// The slot is cleared and pushed onto the free list, with the previous head stored
    /// in its height register. The size counter is decremented.
    fn remove_node(&mut self, index: u32) -> Option<V> {
        if index == SENTINEL {
            return None;
        }
        let node = &mut node!(self.nodes, index);
        let value = node.value;
        node.initialize(K::default(), V::default());
        let free_list_head = self.allocator.get_field(Field::FreeListHead);
        node.set_register(Register::Height, free_list_head);
        self.allocator.set_field(Field::FreeListHead, index);
        self.allocator
            .set_field(Field::Size, self.allocator.get_field(Field::Size) - 1);
        Some(value)
    }
}
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
pub struct Allocator {
fields: [u32; 6],
}
impl Allocator {
pub fn initialize(&mut self, capacity: u32) {
self.fields = [SENTINEL, 0, capacity, 1, 1, 0];
}
#[inline(always)]
fn get_field(&self, field: Field) -> u32 {
self.fields[field as usize]
}
#[inline(always)]
fn set_field(&mut self, field: Field, value: u32) {
self.fields[field as usize] = value;
}
}
/// One node slot: four `u32` registers (indexed by `Register`; the fourth is unused),
/// followed by the key and the value.
///
/// `#[repr(C)]` keeps the fields in this order so the slot can be viewed as raw bytes.
#[repr(C)]
#[derive(Clone, Copy, Default)]
pub struct Node<K, V>
where
    K: PartialOrd + Copy + Clone + Default + Pod + Zeroable,
    V: Default + Copy + Clone + Pod + Zeroable,
{
    registers: [u32; 4],
    key: K,
    value: V,
}

impl<K, V> Node<K, V>
where
    K: PartialOrd + Copy + Clone + Default + Pod + Zeroable,
    V: Default + Copy + Clone + Pod + Zeroable,
{
    /// Stores `key` and `value` and clears the registers: no children, height 0, and a
    /// zero spare word.
    fn initialize(&mut self, key: K, value: V) {
        self.key = key;
        self.value = value;
        self.registers = [SENTINEL, SENTINEL, 0, 0];
    }

    /// Reads one register.
    #[inline(always)]
    fn get_register(&self, register: Register) -> u32 {
        let slot = register as usize;
        self.registers[slot]
    }

    /// Overwrites one register.
    #[inline(always)]
    fn set_register(&mut self, register: Register, value: u32) {
        let slot = register as usize;
        self.registers[slot] = value;
    }
}
// SAFETY: every field of `Node` is `Zeroable` (`[u32; 4]`, and `K`/`V` through their
// `Zeroable` bounds), so the all-zero bit pattern is a valid `Node<K, V>`.
unsafe impl<
K: PartialOrd + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
> Zeroable for Node<K, V>
{
}
// SAFETY: `Node` is `#[repr(C)]` and `Copy`, and each field (`[u32; 4]`, `K`, `V`) is
// `Pod`. `Pod` additionally requires the type to contain no padding bytes.
// NOTE(review): nothing here enforces the no-padding requirement. The layout is 16 bytes of
// registers followed by `key: K` and `value: V`, so some K/V combinations leave gaps
// (e.g. `K = u8`, `V = u64`), which would violate `Pod`'s contract. Confirm that callers
// only use key/value types that pack without padding (the tests use `u64`/`u64`), or add
// a check such as `size_of::<Node<K, V>>() == 16 + size_of::<K>() + size_of::<V>()`.
unsafe impl<
K: PartialOrd + Copy + Clone + Default + Pod + Zeroable,
V: Default + Copy + Clone + Pod + Zeroable,
> Pod for Node<K, V>
{
}
#[cfg(test)]
mod tests {
    use crate::collections::AVLTreeMut;

    /// Allocates zeroed backing storage for a `u64 -> u64` tree of `capacity` nodes.
    ///
    /// The storage is made of `u64` words, so the byte view handed to `from_bytes_mut` is
    /// 8-byte aligned. `bytemuck` requires that alignment for the `Node<u64, u64>` slice,
    /// and a plain `[0u8; N]` only guarantees 1-byte alignment. Using the heap also keeps
    /// the large fixtures off the test thread's stack.
    fn storage(capacity: usize) -> Vec<u64> {
        let len = AVLTreeMut::<u64, u64>::data_len(capacity);
        vec![0u64; (len + 7) / 8]
    }

    /// Builds an initialized, empty tree of `capacity` nodes on top of `words`.
    fn new_tree(words: &mut [u64], capacity: usize) -> AVLTreeMut<'_, u64, u64> {
        let len = AVLTreeMut::<u64, u64>::data_len(capacity);
        let bytes = &mut bytemuck::cast_slice_mut::<u64, u8>(words)[..len];
        let mut tree = AVLTreeMut::from_bytes_mut(bytes);
        tree.initialize(capacity as u32);
        tree
    }

    #[test]
    fn test_insert() {
        const CAPACITY: usize = 10;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        for key in 0..CAPACITY as u64 {
            assert!(tree.insert(key, key * 10).is_some());
        }
        assert_eq!(tree.len(), CAPACITY);
        for key in 0..CAPACITY as u64 {
            assert_eq!(tree.get(&key), Some(key * 10));
        }
    }

    #[test]
    fn test_duplicate_insert_is_rejected() {
        const CAPACITY: usize = 4;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        tree.insert(3, 30).unwrap();
        assert!(tree.insert(3, 99).is_none());
        assert_eq!(tree.get(&3), Some(30));
        assert_eq!(tree.len(), 1);
    }

    #[test]
    fn test_large_insert() {
        const CAPACITY: usize = 10_000;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        for key in 1..=CAPACITY as u64 {
            tree.insert(key, key).unwrap();
        }
        assert_eq!(tree.len(), CAPACITY);
        for key in 1..=CAPACITY as u64 {
            assert_eq!(tree.get(&key), Some(key));
        }
    }

    #[test]
    fn test_large_remove() {
        const CAPACITY: usize = 10_000;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        for key in 1..=CAPACITY as u64 {
            tree.insert(key, key).unwrap();
        }
        assert_eq!(tree.len(), CAPACITY);
        for key in 1..=CAPACITY as u64 {
            assert_eq!(tree.remove(&key), Some(key));
        }
        assert_eq!(tree.len(), 0);
        assert!(tree.is_empty());
    }

    #[test]
    fn test_large_remove_add() {
        const CAPACITY: usize = 10_000;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        for key in 1..=CAPACITY as u64 {
            tree.insert(key, key).unwrap();
        }
        assert_eq!(tree.len(), CAPACITY);
        for key in 1..=CAPACITY as u64 {
            assert_eq!(tree.remove(&key), Some(key));
        }
        assert_eq!(tree.len(), 0);
        for key in 1..=CAPACITY as u64 {
            tree.insert(key, key).unwrap();
        }
        assert_eq!(tree.len(), CAPACITY);
        for key in 1..=CAPACITY as u64 {
            assert_eq!(tree.get(&key), Some(key));
        }
    }

    #[test]
    fn test_remove_interleaved() {
        const CAPACITY: usize = 100;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        for key in 1..=CAPACITY as u64 {
            tree.insert(key, key).unwrap();
        }
        // Removing every other key exercises the two-children (successor) path.
        for key in (2..=CAPACITY as u64).step_by(2) {
            assert_eq!(tree.remove(&key), Some(key));
        }
        assert_eq!(tree.len(), CAPACITY / 2);
        for key in 1..=CAPACITY as u64 {
            assert_eq!(tree.contains(&key), key % 2 == 1);
        }
        assert_eq!(tree.lowest(), Some(1));
    }

    #[test]
    fn test_remove_missing_key() {
        const CAPACITY: usize = 8;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        assert_eq!(tree.remove(&1), None);
        for key in 1..=5u64 {
            tree.insert(key, key).unwrap();
        }
        assert_eq!(tree.remove(&42), None);
        assert_eq!(tree.len(), 5);
    }

    #[test]
    fn test_lowest_tracks_minimum() {
        const CAPACITY: usize = 16;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        assert_eq!(tree.lowest(), None);
        for key in (1..=CAPACITY as u64).rev() {
            tree.insert(key, 0).unwrap();
        }
        assert_eq!(tree.lowest(), Some(1));
        tree.remove(&1).unwrap();
        assert_eq!(tree.lowest(), Some(2));
    }

    #[test]
    fn test_insert_when_full() {
        const CAPACITY: usize = 10;
        let mut words = storage(CAPACITY);
        let mut tree = new_tree(&mut words, CAPACITY);
        for key in 0..CAPACITY as u64 {
            tree.insert(key, key).unwrap();
        }
        assert_eq!(tree.len(), CAPACITY);
        assert!(tree.is_full());
        assert!(tree.insert(10, 0).is_none());
        tree.remove(&0).unwrap();
        tree.insert(10, 0).unwrap();
        assert!(tree.is_full());
        assert!(tree.insert(20, 0).is_none());
    }
}