use std::{
borrow::Borrow,
cmp::{self, Ordering},
collections::HashMap,
fmt::{self, Debug, Formatter},
hash::{BuildHasher, Hash, Hasher, RandomState},
iter::FusedIterator,
marker::PhantomData,
mem::MaybeUninit,
ptr::{self, NonNull},
};
use serde::{
de::{Deserialize, Deserializer, SeqAccess, Visitor},
ser::{Serialize, SerializeSeq, Serializer},
};
/// A set that stores each element once and indexes it two ways: by hash (through `map`)
/// and by `Ord` order (through an AVL tree rooted at `root`), giving keyed lookup plus
/// smallest/largest access and ascending iteration.
///
/// Every element lives in a single heap-allocated [`Entry`]; `map` and the tree point at
/// the same allocations. Because the map uses `Hash`/`Eq` while the tree uses `Ord`,
/// `T`'s ordering is expected to agree with its equality.
pub struct HashTree<T, S = RandomState> {
// Keys borrow the element stored inside each entry; values are the owning entry pointers
// (`Drop` frees entries by draining this map).
map: HashMap<DataRef<T>, NonNull<Entry<T>>, S>,
// Root of the AVL tree, or null when the tree is empty.
root: *mut Entry<T>,
}
/// A heap-allocated AVL node holding one element.
struct Entry<T> {
// The element. `MaybeUninit` lets it be moved out (`Entry::into_data`) or dropped
// explicitly (`Drop for HashTree`) separately from freeing the node.
data: MaybeUninit<T>,
// Left child (elements ordering before this one), or null.
left: *mut Entry<T>,
// Right child (elements ordering after this one), or null.
right: *mut Entry<T>,
// Cached height of the subtree rooted here; a leaf has height 1.
height: usize,
}
/// Hash-map key that points at an element. Hashing and equality are delegated to the
/// pointed-to `T`, so the pointer must stay valid for as long as the key is in use.
struct DataRef<T> {
data: *const T,
}
/// Newtype over a borrowed lookup key. `DataRef<T>` implements `Borrow<KeyWrapper<K>>`,
/// so the map can be queried with a plain `&K` (cast to `&KeyWrapper<K>`).
/// `repr(transparent)` is what makes that reference cast in `KeyWrapper::from_ref` valid.
#[repr(transparent)]
struct KeyWrapper<K>(K);
/// Borrowing in-order iterator over a [`HashTree`].
pub struct Iter<'a, T, S> {
// The iterated tree; used for `size_hint` and to tie yielded references to `'a`.
tree: &'a HashTree<T, S>,
// Explicit traversal stack; the last element is the node currently being visited.
stack: Vec<IterVisitor<T>>,
}
/// Traversal progress for one node on the [`Iter`] stack.
struct IterVisitor<T> {
entry: *const Entry<T>,
// Whether this node's own element has been yielded.
visited_self: bool,
// Whether the left subtree has been pushed; pre-set when there is no left child.
visited_left: bool,
// Whether the right subtree has been pushed; pre-set when there is no right child.
visited_right: bool,
}
/// Owning iterator that yields a [`HashTree`]'s elements by repeatedly popping the smallest.
pub struct IntoIter<T, S> {
tree: HashTree<T, S>,
}
impl<T, S> HashTree<T, S>
where
T: Eq + Ord + Hash,
S: BuildHasher,
{
#[inline]
pub fn contains<K>(&self, key: &K) -> bool
where
T: Borrow<K>,
K: Eq + Hash,
{
self.map.contains_key(KeyWrapper::from_ref(key))
}
#[inline]
pub fn insert(&mut self, data: T) -> Option<T> {
let maybe_old_entry = self.map.remove(&DataRef::from_ref(&data));
if let Some(old_entry) = maybe_old_entry {
self.root = remove_entry(self.root, old_entry.as_ptr());
reset_entry(old_entry.as_ptr());
}
let entry = Entry::new(data);
let entry_ptr = entry.as_ptr();
self.root = insert_entry(self.root, entry_ptr);
let data_ref = DataRef::from_entry_ptr(entry_ptr);
self.map.insert(data_ref, entry);
maybe_old_entry.map(|old_entry| Entry::<T>::into_data(old_entry.as_ptr()))
}
#[inline]
pub fn pop_smallest(&mut self) -> Option<T> {
if self.root.is_null() {
return None;
}
let entry_ptr = find_min(self.root);
self.root = remove_entry(self.root, entry_ptr);
let data_ref = DataRef::from_entry_ptr(entry_ptr);
self.map.remove(&data_ref).unwrap();
Some(Entry::<T>::into_data(entry_ptr))
}
#[inline]
pub fn pop_largest(&mut self) -> Option<T> {
if self.root.is_null() {
return None;
}
let entry_ptr = find_max(self.root);
self.root = remove_entry(self.root, entry_ptr);
let data_ref = DataRef::from_entry_ptr(entry_ptr);
self.map.remove(&data_ref).unwrap();
Some(Entry::<T>::into_data(entry_ptr))
}
#[inline]
pub fn get<K>(&self, key: &K) -> Option<&T>
where
T: Borrow<K>,
K: Eq + Hash,
{
let entry = self.map.get(KeyWrapper::from_ref(key))?;
let entry_ptr = entry.as_ptr();
let data = unsafe { (*entry_ptr).data.assume_init_ref() };
Some(data)
}
#[inline]
pub fn left<K>(&self, key: &K) -> Option<&T>
where
T: Borrow<K>,
K: Eq + Hash,
{
let entry = self.map.get(KeyWrapper::from_ref(key))?;
let entry_ptr = entry.as_ptr();
let left_ptr = unsafe { (*entry_ptr).left };
if left_ptr.is_null() {
return None;
}
let data = unsafe { (*left_ptr).data.assume_init_ref() };
Some(data)
}
#[inline]
pub fn right<K>(&self, key: &K) -> Option<&T>
where
T: Borrow<K>,
K: Eq + Hash,
{
let entry = self.map.get(KeyWrapper::from_ref(key))?;
let entry_ptr = entry.as_ptr();
let right_ptr = unsafe { (*entry_ptr).right };
if right_ptr.is_null() {
return None;
}
let data = unsafe { (*right_ptr).data.assume_init_ref() };
Some(data)
}
#[inline]
pub fn update<K, F>(&mut self, key: &K, mut f: F)
where
T: Borrow<K>,
K: Eq + Hash,
F: FnMut(&mut T),
{
let Some(entry) = self.map.remove(KeyWrapper::from_ref(key)) else {
return;
};
let entry_ptr = entry.as_ptr();
let data = unsafe { &mut *(*entry_ptr).data.as_mut_ptr() };
self.root = remove_entry(self.root, entry_ptr);
reset_entry(entry_ptr);
f(data);
let data_ref = DataRef::from_ref(data);
self.root = insert_entry(self.root, entry_ptr);
self.map.insert(data_ref, entry);
}
#[inline]
pub fn remove<K>(&mut self, key: &K) -> Option<T>
where
T: Borrow<K>,
K: Eq + Hash,
{
let entry = self.map.remove(KeyWrapper::from_ref(key))?;
let entry_ptr = entry.as_ptr();
self.root = remove_entry(self.root, entry_ptr);
Some(Entry::<T>::into_data(entry_ptr))
}
#[inline]
pub fn clear(&mut self) {
while self.pop_smallest().is_some() {}
}
}
impl<T, S> HashTree<T, S> {
    /// Creates an empty tree that hashes its elements with `hasher`.
    pub fn with_hasher(hasher: S) -> Self {
        let map = HashMap::with_hasher(hasher);
        Self {
            root: ptr::null_mut(),
            map,
        }
    }

    /// Returns `true` when the tree holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns the number of stored elements.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns an iterator over references to the elements in ascending order.
    #[inline]
    pub fn iter(&self) -> Iter<'_, T, S> {
        // An empty tree starts with an empty stack; otherwise traversal begins at the root.
        let stack = if self.root.is_null() {
            Vec::new()
        } else {
            vec![IterVisitor::new(self.root)]
        };
        Iter { tree: self, stack }
    }
}
impl<T> HashTree<T, RandomState> {
pub fn new() -> Self {
HashTree::with_hasher(RandomState::new())
}
}
impl<T, S> Default for HashTree<T, S>
where
S: Default,
{
fn default() -> Self {
HashTree::<T, S>::with_hasher(S::default())
}
}
impl<T, S> PartialEq for HashTree<T, S>
where
    T: PartialEq,
{
    /// Trees are equal when they have the same length and yield equal elements in the
    /// same (in-order) sequence.
    fn eq(&self, other: &Self) -> bool {
        // The length check is cheap and short-circuits the element walk.
        if self.len() != other.len() {
            return false;
        }
        self.iter().eq(other.iter())
    }
}
/// Element-wise equality is a full equivalence relation whenever `T`'s is.
impl<T, S> Eq for HashTree<T, S> where T: Eq {}
impl<T> Entry<T> {
    /// Boxes a new leaf node (no children, height 1) holding `data` and returns an owning
    /// pointer to it. The allocation is released later by `Entry::into_data`.
    fn new(data: T) -> NonNull<Self> {
        let leaf = Box::new(Entry {
            data: MaybeUninit::new(data),
            left: ptr::null_mut(),
            right: ptr::null_mut(),
            height: 1,
        });
        // A leaked box is never null, so this conversion cannot fail.
        NonNull::from(Box::leak(leaf))
    }

    /// Replaces the left child and recomputes this node's height.
    fn set_left(&mut self, left: *mut Entry<T>) {
        self.left = left;
        self.refresh_height();
    }

    /// Replaces the right child and recomputes this node's height.
    fn set_right(&mut self, right: *mut Entry<T>) {
        self.right = right;
        self.refresh_height();
    }

    /// Recomputes `height` from the children's cached heights (a null child counts as 0).
    fn refresh_height(&mut self) {
        let height_of = |child: *mut Entry<T>| -> usize {
            // SAFETY: non-null child pointers always refer to live nodes.
            unsafe { child.as_ref() }.map_or(0, |node| node.height)
        };
        self.height = cmp::max(height_of(self.left), height_of(self.right)) + 1;
    }

    /// Frees the node behind `entry_ptr` and returns the element it held.
    ///
    /// `entry_ptr` must come from `Entry::new` and must not be used afterwards.
    fn into_data(entry_ptr: *mut Entry<T>) -> T {
        // SAFETY: the caller hands over a unique, live, `Box`-allocated node whose data is
        // initialized.
        let Entry { data, .. } = *unsafe { Box::from_raw(entry_ptr) };
        unsafe { data.assume_init() }
    }
}
/// Inserts the detached node `entry` into the AVL subtree rooted at `root` and returns the
/// (possibly new) subtree root, rebalancing each node on the way back up.
///
/// If a node comparing `Equal` to `entry` is met, `entry` adopts that node's children and
/// takes its place; the displaced node is no longer reachable from the tree.
fn insert_entry<T>(root: *mut Entry<T>, entry: *mut Entry<T>) -> *mut Entry<T>
where
    T: Ord,
{
    if root.is_null() {
        return entry;
    }
    // SAFETY: `root` and `entry` are distinct live nodes with initialized data, and every
    // non-null child pointer refers to a live node.
    unsafe {
        let ordering = (*entry)
            .data
            .assume_init_ref()
            .cmp((*root).data.assume_init_ref());
        match ordering {
            Ordering::Equal => {
                let (adopted_left, adopted_right) = ((*root).left, (*root).right);
                (*entry).set_left(adopted_left);
                (*entry).set_right(adopted_right);
                entry
            }
            Ordering::Less => {
                let subtree = insert_entry((*root).left, entry);
                (*root).set_left(subtree);
                NonNull::new(balance_entry(root)).unwrap().as_ptr()
            }
            Ordering::Greater => {
                let subtree = insert_entry((*root).right, entry);
                (*root).set_right(subtree);
                NonNull::new(balance_entry(root)).unwrap().as_ptr()
            }
        }
    }
}
/// Unlinks `entry` from the AVL subtree rooted at `root` and returns the new subtree root.
///
/// `entry` is located by comparing its data with `Ord`. If no node compares equal, nothing
/// is unlinked. Heights are refreshed and nodes rebalanced along the search path. The
/// removed node's own `left`/`right`/`height` are left stale; callers free it or call
/// `reset_entry` before reusing it.
fn remove_entry<T>(root: *mut Entry<T>, entry: *mut Entry<T>) -> *mut Entry<T>
where
    T: Ord,
{
    if root.is_null() {
        return root;
    }
    // SAFETY: `root` is non-null and `entry` is a live node; both hold initialized data.
    let cmp = unsafe {
        let entry_ref = entry.as_ref().unwrap();
        entry_ref
            .data
            .assume_init_ref()
            .cmp((*root).data.assume_init_ref())
    };
    // SAFETY (all arms): child pointers are null or point at live nodes of this tree.
    match cmp {
        Ordering::Less => unsafe {
            let new_left = remove_entry((*root).left, entry);
            (*root).set_left(new_left);
            balance_entry(root)
        },
        Ordering::Greater => unsafe {
            let new_right = remove_entry((*root).right, entry);
            (*root).set_right(new_right);
            balance_entry(root)
        },
        Ordering::Equal => unsafe {
            let left = (*root).left;
            let right = (*root).right;
            // Zero or one child: the (possibly null) child replaces the removed node and
            // already carries a correct height.
            if left.is_null() {
                return right;
            }
            if right.is_null() {
                return left;
            }
            // Two children: splice in the in-order successor (minimum of the right subtree).
            let successor = find_min(right);
            let new_right = remove_entry(right, successor);
            (*successor).left = left;
            (*successor).right = new_right;
            // The successor's cached height belongs to its old position, and removing it may
            // have shortened the right side, so recompute and rebalance here.
            (*successor).refresh_height();
            balance_entry(successor)
        },
    }
}
/// Turns an unlinked node back into a fresh leaf: no children, height 1. Its data is
/// left untouched.
fn reset_entry<T>(entry: *mut Entry<T>) {
    // SAFETY: callers pass a live node. Only the link/height fields are written through the
    // raw pointer, so any borrow of `data` is unaffected.
    unsafe {
        (*entry).height = 1;
        (*entry).right = ptr::null_mut();
        (*entry).left = ptr::null_mut();
    }
}
/// Returns the leftmost (smallest) node of the subtree rooted at `root`, or null when the
/// subtree is empty.
fn find_min<T>(root: *mut Entry<T>) -> *mut Entry<T>
where
    T: Ord,
{
    let mut node = root;
    // SAFETY: `node` is only dereferenced when non-null, and non-null child pointers refer
    // to live nodes.
    while !node.is_null() && unsafe { !(*node).left.is_null() } {
        node = unsafe { (*node).left };
    }
    node
}
/// Returns the rightmost (largest) node of the subtree rooted at `root`, or null when the
/// subtree is empty.
fn find_max<T>(root: *mut Entry<T>) -> *mut Entry<T>
where
    T: Ord,
{
    let mut node = root;
    // SAFETY: `as_ref` yields `None` for null, and non-null child pointers refer to live nodes.
    while let Some(next) = unsafe { node.as_ref() }
        .map(|current| current.right)
        .filter(|right| !right.is_null())
    {
        node = next;
    }
    node
}
/// Restores the AVL property at `entry`, assuming its children are balanced and have
/// correct heights, and returns the new subtree root. A null `entry` is returned unchanged.
///
/// `ll_rotate` promotes the left child (a right rotation) and `rr_rotate` promotes the
/// right child (a left rotation); `lr_rotate`/`rl_rotate` are the double rotations.
fn balance_entry<T>(entry: *mut Entry<T>) -> *mut Entry<T> {
    let factor = balance_factor(entry);
    if factor > 1 {
        // SAFETY: a factor above 1 implies `entry` is non-null with a non-null left child.
        let left_factor = unsafe { balance_factor((*entry).left) };
        // A left child with factor 0 only occurs after a removal. The single rotation
        // restores balance there; the double rotation would not.
        return if left_factor >= 0 {
            ll_rotate(entry)
        } else {
            lr_rotate(entry)
        };
    }
    if factor < -1 {
        // SAFETY: a factor below -1 implies `entry` is non-null with a non-null right child.
        let right_factor = unsafe { balance_factor((*entry).right) };
        return if right_factor <= 0 {
            rr_rotate(entry)
        } else {
            rl_rotate(entry)
        };
    }
    entry
}
fn balance_factor<T>(entry: *mut Entry<T>) -> i64 {
if entry.is_null() {
return 0;
}
let left = unsafe { (*entry).left };
let right = unsafe { (*entry).right };
let left_height = if !left.is_null() {
unsafe { (*left).height }
} else {
0
};
let right_height = if !right.is_null() {
unsafe { (*right).height }
} else {
0
};
left_height as i64 - right_height as i64
}
/// Left rotation, used for a right-right imbalance: promotes `old_root`'s right child and
/// returns it as the new subtree root.
///
/// `old_root` is returned unchanged when it is null or has no right child.
fn rr_rotate<T>(old_root: *mut Entry<T>) -> *mut Entry<T> {
    // SAFETY: `old_root` is dereferenced only after the null check short-circuits.
    if old_root.is_null() || unsafe { (*old_root).right.is_null() } {
        return old_root;
    }
    // SAFETY: `old_root` and its right child are live, non-null nodes.
    unsafe {
        let pivot = (*old_root).right;
        let inner = (*pivot).left;
        (*old_root).right = inner;
        (*pivot).left = old_root;
        // Bottom-up: the demoted node first, then the promoted node above it.
        (*old_root).refresh_height();
        (*pivot).refresh_height();
        pivot
    }
}
/// Right rotation, used for a left-left imbalance: promotes `old_root`'s left child and
/// returns it as the new subtree root.
///
/// `old_root` is returned unchanged when it is null or has no left child.
fn ll_rotate<T>(old_root: *mut Entry<T>) -> *mut Entry<T> {
    // SAFETY: `old_root` is dereferenced only after the null check short-circuits.
    if old_root.is_null() || unsafe { (*old_root).left.is_null() } {
        return old_root;
    }
    // SAFETY: `old_root` and its left child are live, non-null nodes.
    unsafe {
        let pivot = (*old_root).left;
        let inner = (*pivot).right;
        (*old_root).left = inner;
        (*pivot).right = old_root;
        // Bottom-up: the demoted node first, then the promoted node above it.
        (*old_root).refresh_height();
        (*pivot).refresh_height();
        pivot
    }
}
/// Left-right double rotation: rotates `old_root`'s left child left, then rotates
/// `old_root` right. Returns the new subtree root; a null `old_root` is returned unchanged.
fn lr_rotate<T>(old_root: *mut Entry<T>) -> *mut Entry<T> {
    if old_root.is_null() {
        return old_root;
    }
    // SAFETY: `old_root` is a live, non-null node.
    let left = unsafe { (*old_root).left };
    let straightened = rr_rotate(left);
    unsafe { (*old_root).left = straightened };
    ll_rotate(old_root)
}
/// Right-left double rotation: rotates `old_root`'s right child right, then rotates
/// `old_root` left. Returns the new subtree root; a null `old_root` is returned unchanged.
fn rl_rotate<T>(old_root: *mut Entry<T>) -> *mut Entry<T> {
    if old_root.is_null() {
        return old_root;
    }
    // SAFETY: `old_root` is a live, non-null node.
    let right = unsafe { (*old_root).right };
    let straightened = ll_rotate(right);
    unsafe { (*old_root).right = straightened };
    rr_rotate(old_root)
}
impl<T> DataRef<T> {
fn from_ref(data: &T) -> Self {
DataRef {
data,
}
}
fn from_entry_ptr(entry_ptr: *mut Entry<T>) -> Self {
let data_ptr = unsafe { (*entry_ptr).data.as_ptr() };
DataRef {
data: data_ptr
}
}
}
impl<T> Hash for DataRef<T>
where
    T: Hash,
{
    /// Hashes the referenced element, so a key hashes exactly like the `T` it points at.
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        // SAFETY: a `DataRef` is only hashed while its element is alive.
        let element: &T = unsafe { &*self.data };
        element.hash(state);
    }
}
impl<T> PartialEq for DataRef<T>
where
    T: PartialEq,
{
    /// Compares the referenced elements, not the pointers.
    fn eq(&self, other: &Self) -> bool {
        // SAFETY: keys are only compared while both elements are alive.
        let (lhs, rhs): (&T, &T) = unsafe { (&*self.data, &*other.data) };
        lhs == rhs
    }
}
/// Key equality is an equivalence relation whenever the element type's is.
impl<T> Eq for DataRef<T> where T: Eq {}
impl<K> KeyWrapper<K> {
fn from_ref(key: &K) -> &Self {
unsafe { &*(key as *const K as *const KeyWrapper<K>) }
}
}
impl<K> Hash for KeyWrapper<K>
where
K: Hash,
{
fn hash<H>(&self, state: &mut H)
where
H: Hasher,
{
self.0.hash(state)
}
}
impl<K> PartialEq for KeyWrapper<K>
where
K: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}
/// Wrapper equality is an equivalence relation whenever the key type's is.
impl<K> Eq for KeyWrapper<K> where K: Eq {}
impl<K, T> Borrow<KeyWrapper<K>> for DataRef<T>
where
    T: Borrow<K>,
{
    /// Exposes the element's borrowed key form, so the map can be queried with a
    /// `&KeyWrapper<K>` built from a plain `&K`.
    fn borrow(&self) -> &KeyWrapper<K> {
        // SAFETY: keys in the map always point at live element data.
        let element: &T = unsafe { &*self.data };
        let key: &K = element.borrow();
        KeyWrapper::from_ref(key)
    }
}
/// In-order traversal driven by an explicit stack of `IterVisitor`s.
impl<'a, T, S> Iterator for Iter<'a, T, S> {
type Item = &'a T;
/// Advances the top visitor through its states: push the left child, yield the node
/// itself, push the right child, then pop. Returns `None` once the stack is empty.
fn next(&mut self) -> Option<Self::Item> {
loop {
let visitor = self.stack.last_mut()?;
if !visitor.visited_left {
visitor.visited_left = true;
let left_visitor = visitor.get_left_visitor();
self.stack.push(left_visitor);
} else if !visitor.visited_self {
visitor.visited_self = true;
return Some(visitor.get_ref());
} else if !visitor.visited_right {
visitor.visited_right = true;
let right_visitor = visitor.get_right_visitor();
self.stack.push(right_visitor);
} else {
// All three parts done: this subtree is exhausted.
self.stack.pop();
}
}
}
// NOTE(review): this always reports the tree's full length, even after elements have been
// yielded, so `ExactSizeIterator::len()` is wrong for a partially consumed iterator. A fix
// needs a `remaining` counter on `Iter`, initialised in `HashTree::iter` and decremented
// in `next`.
fn size_hint(&self) -> (usize, Option<usize>) {
(self.tree.len(), Some(self.tree.len()))
}
}
impl<T> IterVisitor<T> {
fn new(entry: *const Entry<T>) -> Self {
let visited_left = unsafe { (*entry).left.is_null() };
let visited_right = unsafe { (*entry).right.is_null() };
IterVisitor {
entry,
visited_self: false,
visited_left,
visited_right,
}
}
fn get_ref<'a>(&self) -> &'a T {
unsafe { (*self.entry).data.assume_init_ref() }
}
fn get_left_visitor(&self) -> Self {
let left = unsafe { (*self.entry).left };
IterVisitor::<T>::new(left)
}
fn get_right_visitor(&self) -> Self {
let right = unsafe { (*self.entry).right };
IterVisitor::<T>::new(right)
}
}
impl<'a, T, S> IntoIterator for &'a HashTree<T, S>
where
T: Eq + Hash,
S: BuildHasher,
{
type Item = &'a T;
type IntoIter = Iter<'a, T, S>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<T, S> Iterator for IntoIter<T, S>
where
    T: Eq + Ord + Hash,
    S: BuildHasher,
{
    type Item = T;

    /// Removes and yields the smallest remaining element.
    fn next(&mut self) -> Option<Self::Item> {
        self.tree.pop_smallest()
    }

    /// Exact bounds: each `next` pops one element, so the tree's length is exactly what is
    /// left. `ExactSizeIterator` depends on this: its default `len()` asserts that the
    /// bounds agree, and the default `(0, None)` hint made it panic.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.tree.len();
        (remaining, Some(remaining))
    }
}
/// `len()` is derived from `Iter::size_hint`.
// NOTE(review): `ExactSizeIterator::len()` asserts the size hint is exact; confirm that
// `Iter::size_hint` tracks the remaining count rather than the whole tree length.
impl<T, S> ExactSizeIterator for Iter<'_, T, S>
where
T: Eq + Hash,
S: BuildHasher,
{
}
/// `len()` is derived from `IntoIter::size_hint`.
// NOTE(review): the default `len()` panics unless `size_hint` returns equal bounds; confirm
// `IntoIter` overrides `size_hint` with the tree's remaining length.
impl<T, S> ExactSizeIterator for IntoIter<T, S>
where
T: Eq + Ord + Hash,
S: BuildHasher,
{
}
/// Once the traversal stack is empty, `next` keeps returning `None`.
impl<T, S> FusedIterator for Iter<'_, T, S>
where
T: Eq + Hash,
S: BuildHasher,
{
}
/// `pop_smallest` keeps returning `None` on an empty tree.
impl<T, S> FusedIterator for IntoIter<T, S>
where
T: Eq + Ord + Hash,
S: BuildHasher,
{
}
impl<T, S> IntoIterator for HashTree<T, S>
where
    T: Eq + Ord + Hash,
    S: BuildHasher,
{
    type Item = T;
    type IntoIter = IntoIter<T, S>;

    /// Consumes the tree, yielding its elements smallest first.
    fn into_iter(self) -> Self::IntoIter {
        let tree = self;
        IntoIter { tree }
    }
}
impl<T, S> FromIterator<T> for HashTree<T, S>
where
T: Eq + Ord + Hash,
S: Default + BuildHasher,
{
fn from_iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = T>,
{
let mut tree = HashTree::<T, S>::default();
for value in iter {
tree.insert(value);
}
tree
}
}
impl<T, S> Extend<T> for HashTree<T, S>
where
    T: Eq + Ord + Hash,
    S: BuildHasher,
{
    /// Inserts every item in order. Displaced equal elements are dropped.
    fn extend<I>(&mut self, iter: I)
    where
        I: IntoIterator<Item = T>,
    {
        iter.into_iter().for_each(|value| {
            self.insert(value);
        });
    }
}
impl<T, S> Hash for HashTree<T, S>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    /// Hashes the length followed by every element in iteration order, consistent with
    /// the order-based `PartialEq`.
    fn hash<H>(&self, state: &mut H)
    where
        H: Hasher,
    {
        self.len().hash(state);
        self.iter().for_each(|value| value.hash(state));
    }
}
impl<T, S> Debug for HashTree<T, S>
where
    T: Eq + Hash + Debug,
    S: BuildHasher,
{
    /// Formats as `HashTree([a, b, ...])` with elements in iteration order.
    fn fmt(&self, fmt: &mut Formatter<'_>) -> fmt::Result {
        write!(fmt, "HashTree(")?;
        fmt.debug_list().entries(self.iter()).finish()?;
        write!(fmt, ")")
    }
}
impl<T, S> Drop for HashTree<T, S> {
fn drop(&mut self) {
self.map.drain().for_each(|(_, entry)| unsafe {
let mut entry = *Box::from_raw(entry.as_ptr());
ptr::drop_in_place(entry.data.as_mut_ptr());
});
}
}
impl<T, S> Serialize for HashTree<T, S>
where
    T: Eq + Hash + Serialize,
    S: BuildHasher,
{
    /// Serializes as a sequence of known length, elements in iteration order.
    fn serialize<Se>(&self, serializer: Se) -> Result<Se::Ok, Se::Error>
    where
        Se: Serializer,
    {
        let mut seq = serializer.serialize_seq(Some(self.len()))?;
        self.iter()
            .try_for_each(|value| seq.serialize_element(value))?;
        seq.end()
    }
}
/// Serde visitor that rebuilds a [`HashTree`] from a sequence.
struct HashTreeVisitor<T, S> {
// Carries the element and hasher types only; no data is stored.
marker: PhantomData<(T, S)>,
}
impl<'de, T, S> Visitor<'de> for HashTreeVisitor<T, S>
where
    T: Eq + Ord + Hash + Deserialize<'de>,
    S: Default + BuildHasher,
{
    type Value = HashTree<T, S>;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a hash tree")
    }

    /// Inserts each sequence element in turn. A later duplicate replaces an earlier one.
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: SeqAccess<'de>,
    {
        let mut tree = HashTree::<T, S>::default();
        loop {
            match seq.next_element()? {
                Some(value) => {
                    tree.insert(value);
                }
                None => return Ok(tree),
            }
        }
    }
}
impl<'de, T, S> Deserialize<'de> for HashTree<T, S>
where
    T: Eq + Ord + Hash + Deserialize<'de>,
    S: Default + BuildHasher,
{
    /// Deserializes from a sequence of elements.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_seq(HashTreeVisitor {
            marker: PhantomData,
        })
    }
}
// SAFETY: `HashTree` uniquely owns its elements (in `Box`-allocated entries) and its hasher,
// and its raw pointers never escape except as borrows tied to `&self`. It is therefore as
// thread-safe as an owning container of `T` and `S`:
// - moving it to another thread moves `T` and `S` there, so it needs `T: Send, S: Send`;
// - sharing `&HashTree` hands out `&T` (via `get`/`iter`) and uses `&S` for hashing, so it
//   needs `T: Sync, S: Sync`.
// The previous unconditional impls let e.g. `HashTree<Rc<_>>` cross threads.
unsafe impl<T: Send, S: Send> Send for HashTree<T, S> {}
unsafe impl<T: Sync, S: Sync> Sync for HashTree<T, S> {}
#[cfg(test)]
mod tests {
    use std::{
        borrow::Borrow,
        cmp::Ordering,
        hash::{Hash, Hasher},
    };

    use droptest::{DropGuard, DropRegistry, assert_drop, assert_no_drop};
    use serde_test::{Token, assert_tokens};

    use crate::collections::hash_tree::{Entry, HashTree};

    /// Test element whose drop is observed through a `DropGuard`. It is ordered, hashed,
    /// compared and borrowed (`Borrow<u64>`) by `id` only.
    struct DroppableObject<'a> {
        id: u64,
        guard: DropGuard<'a, ()>,
    }

    #[test]
    fn it_inserts_correctly() {
        let mut tree = HashTree::<u64>::default();
        assert_eq!(tree.insert(1), None);
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 1);
        assert_eq!(root_height, 1);
        let (l, r) = get_entry_children(tree.root);
        assert!(l.is_null());
        assert!(r.is_null());
        assert_eq!(tree.insert(2), None);
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 1);
        assert_eq!(root_height, 2);
        let (l, r) = get_entry_children(tree.root);
        assert!(l.is_null());
        assert!(!r.is_null());
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 2);
        assert_eq!(r_height, 1);
        assert_eq!(tree.insert(3), None);
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 2);
        assert_eq!(root_height, 2);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 1);
        assert_eq!(l_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 3);
        assert_eq!(r_height, 1);
        assert_eq!(tree.insert(4), None);
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 2);
        assert_eq!(root_height, 3);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 1);
        assert_eq!(l_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 3);
        assert_eq!(r_height, 2);
        let (rl, rr) = get_entry_children(r);
        assert!(rl.is_null());
        assert!(!rr.is_null());
        let (rr_data, rr_height) = get_entry_info(rr);
        assert_eq!(rr_data, 4);
        assert_eq!(rr_height, 1);
        assert_eq!(tree.insert(5), None);
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 2);
        assert_eq!(root_height, 3);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 1);
        assert_eq!(l_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 4);
        assert_eq!(r_height, 2);
        let (rl, rr) = get_entry_children(r);
        assert!(!rl.is_null());
        assert!(!rr.is_null());
        let (rl_data, rl_height) = get_entry_info(rl);
        assert_eq!(rl_data, 3);
        assert_eq!(rl_height, 1);
        let (rr_data, rr_height) = get_entry_info(rr);
        assert_eq!(rr_data, 5);
        assert_eq!(rr_height, 1);
        assert_eq!(tree.insert(6), None);
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 4);
        assert_eq!(root_height, 3);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 2);
        assert_eq!(l_height, 2);
        let (ll, lr) = get_entry_children(l);
        assert!(!ll.is_null());
        assert!(!lr.is_null());
        let (ll_data, ll_height) = get_entry_info(ll);
        assert_eq!(ll_data, 1);
        assert_eq!(ll_height, 1);
        let (lr_data, lr_height) = get_entry_info(lr);
        assert_eq!(lr_data, 3);
        assert_eq!(lr_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 5);
        assert_eq!(r_height, 2);
        let (rl, rr) = get_entry_children(r);
        assert!(rl.is_null());
        assert!(!rr.is_null());
        let (rr_data, rr_height) = get_entry_info(rr);
        assert_eq!(rr_data, 6);
        assert_eq!(rr_height, 1);
        assert_eq!(tree.insert(1), Some(1));
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 4);
        assert_eq!(root_height, 3);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 2);
        assert_eq!(l_height, 2);
        let (ll, lr) = get_entry_children(l);
        assert!(!ll.is_null());
        assert!(!lr.is_null());
        let (ll_data, ll_height) = get_entry_info(ll);
        assert_eq!(ll_data, 1);
        assert_eq!(ll_height, 1);
        let (lr_data, lr_height) = get_entry_info(lr);
        assert_eq!(lr_data, 3);
        assert_eq!(lr_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 5);
        assert_eq!(r_height, 2);
        let (rl, rr) = get_entry_children(r);
        assert!(rl.is_null());
        assert!(!rr.is_null());
        let (rr_data, rr_height) = get_entry_info(rr);
        assert_eq!(rr_data, 6);
        assert_eq!(rr_height, 1);
        assert_eq!(tree.insert(4), Some(4));
        let (root_data, root_height) = get_entry_info(tree.root);
        assert_eq!(root_data, 3);
        assert_eq!(root_height, 3);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 2);
        assert_eq!(l_height, 2);
        let (ll, lr) = get_entry_children(l);
        assert!(!ll.is_null());
        assert!(lr.is_null());
        let (ll_data, ll_height) = get_entry_info(ll);
        assert_eq!(ll_data, 1);
        assert_eq!(ll_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 5);
        assert_eq!(r_height, 2);
        let (rl, rr) = get_entry_children(r);
        assert!(!rl.is_null());
        assert!(!rr.is_null());
        let (rl_data, rl_height) = get_entry_info(rl);
        assert_eq!(rl_data, 4);
        assert_eq!(rl_height, 1);
        let (rr_data, rr_height) = get_entry_info(rr);
        assert_eq!(rr_data, 6);
        assert_eq!(rr_height, 1);
    }

    #[test]
    fn it_balances_null_root() {
        use std::ptr;

        use crate::collections::hash_tree::balance_entry;
        balance_entry::<u64>(ptr::null_mut());
    }

    #[test]
    fn it_calculates_balance_factor_correctly() {
        use crate::collections::hash_tree::balance_factor;
        let entry_one = Entry::new(1).as_ptr();
        let entry_two = Entry::new(2).as_ptr();
        let entry_three = Entry::new(3).as_ptr();
        unsafe {
            (*entry_one).right = entry_two;
            (*entry_two).right = entry_three;
            (*entry_one).height = 3;
            (*entry_two).height = 2;
            (*entry_three).height = 1;
        }
        let factor = balance_factor(entry_one);
        assert_eq!(factor, -2);
    }

    #[test]
    fn it_rr_rotates_null_root() {
        use std::ptr;

        use crate::collections::hash_tree::rr_rotate;
        rr_rotate::<u64>(ptr::null_mut());
    }

    #[test]
    fn it_rr_rotates_null_children() {
        use crate::collections::hash_tree::rr_rotate;
        let entry = Entry::new(0).as_ptr();
        rr_rotate(entry);
    }

    #[test]
    fn it_ll_rotates_null_root() {
        use std::ptr;

        use crate::collections::hash_tree::ll_rotate;
        ll_rotate::<u64>(ptr::null_mut());
    }

    #[test]
    fn it_ll_rotates_null_children() {
        use crate::collections::hash_tree::ll_rotate;
        let entry = Entry::new(0).as_ptr();
        ll_rotate(entry);
    }

    #[test]
    fn it_lr_rotates_null_root() {
        use std::ptr;

        use crate::collections::hash_tree::lr_rotate;
        lr_rotate::<u64>(ptr::null_mut());
    }

    #[test]
    fn it_lr_rotates_null_children() {
        use crate::collections::hash_tree::lr_rotate;
        let entry = Entry::new(0).as_ptr();
        lr_rotate(entry);
    }

    #[test]
    fn it_rl_rotates_null_root() {
        use std::ptr;

        use crate::collections::hash_tree::rl_rotate;
        rl_rotate::<u64>(ptr::null_mut());
    }

    #[test]
    fn it_rl_rotates_null_children() {
        use crate::collections::hash_tree::rl_rotate;
        let entry = Entry::new(0).as_ptr();
        rl_rotate(entry);
    }

    #[test]
    fn it_rr_rotates_correctly() {
        use crate::collections::hash_tree::rr_rotate;
        let entry_one = Entry::new(1).as_ptr();
        let entry_two = Entry::new(2).as_ptr();
        let entry_three = Entry::new(3).as_ptr();
        unsafe {
            (*entry_one).right = entry_two;
            (*entry_two).right = entry_three;
            (*entry_one).height = 3;
            (*entry_two).height = 2;
            (*entry_three).height = 1;
            let root = rr_rotate(entry_one);
            assert_eq!((*root).data.assume_init(), 2);
            let (left, right) = get_entry_children(root);
            assert_eq!((*left).data.assume_init(), 1);
            assert_eq!((*right).data.assume_init(), 3);
            assert_eq!((*entry_one).height, 1);
            assert_eq!((*entry_two).height, 2);
            assert_eq!((*entry_three).height, 1);
        }
    }

    #[test]
    fn it_ll_rotates_correctly() {
        use crate::collections::hash_tree::ll_rotate;
        let entry_one = Entry::new(1).as_ptr();
        let entry_two = Entry::new(2).as_ptr();
        let entry_three = Entry::new(3).as_ptr();
        unsafe {
            (*entry_one).left = entry_two;
            (*entry_two).left = entry_three;
            (*entry_one).height = 3;
            (*entry_two).height = 2;
            (*entry_three).height = 1;
            let root = ll_rotate(entry_one);
            assert_eq!((*root).data.assume_init(), 2);
            let (left, right) = get_entry_children(root);
            assert_eq!((*left).data.assume_init(), 3);
            assert_eq!((*right).data.assume_init(), 1);
            assert_eq!((*entry_one).height, 1);
            assert_eq!((*entry_two).height, 2);
            assert_eq!((*entry_three).height, 1);
        }
    }

    #[test]
    fn it_lr_rotates_correctly() {
        use crate::collections::hash_tree::lr_rotate;
        let entry_one = Entry::new(1).as_ptr();
        let entry_two = Entry::new(2).as_ptr();
        let entry_three = Entry::new(3).as_ptr();
        unsafe {
            (*entry_one).left = entry_two;
            (*entry_two).left = entry_three;
            (*entry_one).height = 3;
            (*entry_two).height = 2;
            (*entry_three).height = 1;
            let root = lr_rotate(entry_one);
            assert_eq!((*root).data.assume_init(), 2);
            let (left, right) = get_entry_children(root);
            assert_eq!((*left).data.assume_init(), 3);
            assert_eq!((*right).data.assume_init(), 1);
            assert_eq!((*entry_one).height, 1);
            assert_eq!((*entry_two).height, 2);
            assert_eq!((*entry_three).height, 1);
        }
    }

    #[test]
    fn it_reorders_updates() {
        let mut tree = HashTree::<u64>::default();
        tree.insert(1);
        tree.insert(2);
        tree.insert(3);
        let (data, height) = get_entry_info(tree.root);
        assert_eq!(data, 2);
        assert_eq!(height, 2);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 1);
        assert_eq!(l_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 3);
        assert_eq!(r_height, 1);
        tree.update(&2, |value| *value = 4);
        let (data, height) = get_entry_info(tree.root);
        assert_eq!(data, 3);
        assert_eq!(height, 2);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 1);
        assert_eq!(l_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 4);
        assert_eq!(r_height, 1);
        tree.update(&4, |value| *value = 2);
        let (data, height) = get_entry_info(tree.root);
        assert_eq!(data, 2);
        assert_eq!(height, 2);
        let (l, r) = get_entry_children(tree.root);
        assert!(!l.is_null());
        assert!(!r.is_null());
        let (l_data, l_height) = get_entry_info(l);
        assert_eq!(l_data, 1);
        assert_eq!(l_height, 1);
        let (r_data, r_height) = get_entry_info(r);
        assert_eq!(r_data, 3);
        assert_eq!(r_height, 1);
    }

    #[test]
    fn it_drops_removed_object() {
        let registry = DropRegistry::default();
        let mut tree = HashTree::<DroppableObject>::default();
        let object1 = DroppableObject::new(&registry, 1);
        let object2 = DroppableObject::new(&registry, 2);
        let object1_guard_id = object1.guard.id();
        let object2_guard_id = object2.guard.id();
        tree.insert(object1);
        tree.insert(object2);
        tree.remove(&1);
        assert_drop!(registry, object1_guard_id);
        assert_no_drop!(registry, object2_guard_id);
    }

    #[test]
    fn it_drops_cleared_objects() {
        let registry = DropRegistry::default();
        let mut tree = HashTree::<DroppableObject>::default();
        let object1 = DroppableObject::new(&registry, 1);
        let object2 = DroppableObject::new(&registry, 2);
        let object1_guard_id = object1.guard.id();
        let object2_guard_id = object2.guard.id();
        tree.insert(object1);
        tree.insert(object2);
        tree.clear();
        assert_drop!(registry, object1_guard_id);
        assert_drop!(registry, object2_guard_id);
    }

    #[test]
    fn it_iters_correctly() {
        let mut tree = HashTree::<u32>::default();
        tree.insert(5);
        tree.insert(3);
        tree.insert(7);
        tree.insert(1);
        tree.insert(2);
        tree.insert(6);
        let mut iter = tree.iter();
        assert_eq!(iter.next(), Some(&1));
        assert_eq!(iter.next(), Some(&2));
        assert_eq!(iter.next(), Some(&3));
        assert_eq!(iter.next(), Some(&5));
        assert_eq!(iter.next(), Some(&6));
        assert_eq!(iter.next(), Some(&7));
        assert!(iter.next().is_none());
    }

    /// Popping from both ends and removing a node with two children keeps the remaining
    /// elements in ascending order.
    #[test]
    fn it_pops_and_removes_in_order() {
        let mut tree: HashTree<u32> = [4, 2, 6, 1, 3, 5, 7].into_iter().collect();
        assert_eq!(tree.pop_smallest(), Some(1));
        assert_eq!(tree.pop_largest(), Some(7));
        assert_eq!(tree.len(), 5);
        assert_eq!(tree.get(&3), Some(&3));
        assert_eq!(tree.remove(&4), Some(4));
        assert!(!tree.contains(&4));
        assert_eq!(tree.remove(&4), None);
        assert_eq!(tree.iter().copied().collect::<Vec<_>>(), vec![2, 3, 5, 6]);
    }

    #[test]
    fn it_ser_de_empty() {
        let tree = HashTree::<u32>::default();
        assert_tokens(&tree, &[
            Token::Seq {
                len: Some(0)
            },
            Token::SeqEnd,
        ]);
    }

    #[test]
    fn it_ser_de() {
        let tree: HashTree<u32> = [1, 2, 3, 4, 5, 6].into_iter().collect();
        assert_tokens(&tree, &[
            Token::Seq {
                len: Some(6)
            },
            Token::U32(1),
            Token::U32(2),
            Token::U32(3),
            Token::U32(4),
            Token::U32(5),
            Token::U32(6),
            Token::SeqEnd,
        ]);
    }

    /// Returns the raw `(left, right)` child pointers of a non-null `entry`.
    fn get_entry_children<T>(entry: *mut Entry<T>) -> (*mut Entry<T>, *mut Entry<T>) {
        let left = unsafe { (*entry).left };
        let right = unsafe { (*entry).right };
        (left, right)
    }

    /// Returns a bitwise copy of a non-null `entry`'s data together with its cached height.
    /// The copy is made with `assume_init_read`, so only use this with `Copy`-like `T`.
    fn get_entry_info<T>(entry: *mut Entry<T>) -> (T, usize) {
        let data = unsafe { (*entry).data.assume_init_read() };
        let height = unsafe { (*entry).height };
        (data, height)
    }

    impl<'a> DroppableObject<'a> {
        fn new(registry: &'a DropRegistry, id: u64) -> Self {
            DroppableObject {
                id,
                guard: registry.new_guard(),
            }
        }
    }

    impl PartialEq for DroppableObject<'_> {
        fn eq(&self, other: &Self) -> bool {
            self.id == other.id
        }
    }

    impl Eq for DroppableObject<'_> {}

    impl PartialOrd for DroppableObject<'_> {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            Some(self.cmp(other))
        }
    }

    impl Ord for DroppableObject<'_> {
        fn cmp(&self, other: &Self) -> Ordering {
            self.id.cmp(&other.id)
        }
    }

    impl Borrow<u64> for DroppableObject<'_> {
        fn borrow(&self) -> &u64 {
            &self.id
        }
    }

    impl Hash for DroppableObject<'_> {
        fn hash<H>(&self, state: &mut H)
        where
            H: Hasher,
        {
            self.id.hash(state)
        }
    }
}