use std::collections::HashMap;
use std::ptr::NonNull;
/// An unbalanced binary search tree whose nodes are individually heap-allocated
/// and linked through raw [`NonNull`] pointers.
///
/// Besides the usual key lookup, the tree keeps a side index keyed by the
/// address of each stored value, so a node can be recovered from a pointer to
/// its value without walking the tree.
pub struct FastPtrBTree<K, V>
where
    K: Ord,
{
    /// Topmost node, or `None` while the tree is empty.
    root: Link<K, V>,
    /// Number of key/value pairs currently stored.
    len: usize,
    /// Maps the address of every node's `value` field (as a `usize`) to the
    /// node that contains it.
    value_to_node: HashMap<usize, NonNull<BTreeNode<K, V>>>,
}
/// An optional non-null pointer to a heap-allocated node; `None` stands for a
/// missing child (or an empty tree when used as the root).
type Link<K, V> = Option<NonNull<BTreeNode<K, V>>>;

/// One node of the tree: a key/value pair plus links to its two children.
pub struct BTreeNode<K, V> {
    /// Key this node is ordered by.
    key: K,
    /// Payload stored under `key`.
    value: V,
    /// Child reached when the searched key compares less than `key`.
    left: Link<K, V>,
    /// Child reached when the searched key compares greater than `key`.
    right: Link<K, V>,
}

impl<K, V> BTreeNode<K, V> {
    /// Borrows the key stored in this node.
    pub fn key(&self) -> &K {
        let Self { key, .. } = self;
        key
    }

    /// Borrows the value stored in this node.
    pub fn value(&self) -> &V {
        let Self { value, .. } = self;
        value
    }
}
impl<K: Ord, V> FastPtrBTree<K, V> {
    /// Creates an empty tree with an empty value-pointer index.
    pub fn new() -> Self {
        Self {
            root: None,
            len: 0,
            value_to_node: HashMap::new(),
        }
    }

    /// Returns the number of key/value pairs stored in the tree.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the tree holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns a shared reference to the value stored under `key`, or `None`
    /// if the key is not present.
    pub fn get(&self, key: &K) -> Option<&V> {
        let node = self.find_node(key)?;
        // SAFETY: `find_node` only yields pointers to nodes linked into this
        // tree. Nothing in this impl frees a node, and the returned borrow is
        // tied to `&self`, so the node outlives the reference.
        Some(unsafe { &node.as_ref().value })
    }

    /// Returns a mutable reference to the value stored under `key`, or `None`
    /// if the key is not present.
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        let mut node = self.find_node(key)?;
        // SAFETY: the node belongs to this tree (see `get`), and `&mut self`
        // guarantees no other reference into the tree is live.
        Some(unsafe { &mut node.as_mut().value })
    }

    /// Returns the address of the value stored under `key` as a raw pointer,
    /// or `None` if the key is not present.
    ///
    /// The pointer can later be handed to [`Self::node_from_value_ptr`] to get
    /// back to the owning node.
    pub fn value_ptr_of_key(&self, key: &K) -> Option<*const V> {
        let node = self.find_node(key)?;
        // SAFETY: the node belongs to this tree and is alive; only the address
        // of its `value` field is taken.
        Some(unsafe { std::ptr::addr_of!(node.as_ref().value) })
    }

    /// Looks up the node whose value lives at `value_ptr`.
    ///
    /// `value_ptr` is never dereferenced: only its address is used as a key
    /// into the index, so a pointer that does not match a value stored in
    /// this tree simply yields `None`.
    pub fn node_from_value_ptr(&self, value_ptr: *const V) -> Option<&BTreeNode<K, V>> {
        let node = self.value_to_node.get(&(value_ptr as usize))?;
        // SAFETY: the index only contains pointers to nodes linked into this
        // tree, which stay allocated for as long as `&self` is borrowed.
        Some(unsafe { node.as_ref() })
    }

    /// Mutable counterpart of [`Self::node_from_value_ptr`].
    ///
    /// The returned reference covers the whole node, child links included.
    /// Callers must not move or swap the node's contents with a node from
    /// another tree (for example via `std::mem::swap`): the pointer index of
    /// each tree would then refer to nodes that tree no longer owns.
    pub fn node_from_value_ptr_mut(&mut self, value_ptr: *const V) -> Option<&mut BTreeNode<K, V>> {
        let mut node = *self.value_to_node.get(&(value_ptr as usize))?;
        // SAFETY: the pointer comes from the index (see `node_from_value_ptr`)
        // and `&mut self` guarantees exclusive access to every node.
        Some(unsafe { node.as_mut() })
    }

    /// Inserts `value` under `key`.
    ///
    /// If the key was already present, its value is replaced in place and the
    /// previous value is returned; otherwise a new node is allocated, linked
    /// in, registered in the value-pointer index, and `None` is returned.
    pub fn insert(&mut self, key: K, value: V) -> Option<V> {
        use std::cmp::Ordering;

        // `cur` always points at the link that either holds the next node to
        // inspect or is the empty slot where a new node has to be attached.
        let mut cur = &mut self.root;
        while let Some(mut nn) = *cur {
            // SAFETY: every link in the tree points to a live node allocated
            // by `insert`, and `&mut self` gives exclusive access to it.
            let node = unsafe { nn.as_mut() };
            match key.cmp(&node.key) {
                Ordering::Less => cur = &mut node.left,
                Ordering::Greater => cur = &mut node.right,
                Ordering::Equal => {
                    // The value is overwritten inside the existing node, so
                    // its address (and therefore the index entry created when
                    // the node was inserted) stays valid as is.
                    return Some(std::mem::replace(&mut node.value, value));
                }
            }
        }

        let node = NonNull::from(Box::leak(Box::new(BTreeNode {
            key,
            value,
            left: None,
            right: None,
        })));
        // SAFETY: `node` points at the allocation leaked just above; nothing
        // else references it yet.
        let value_addr = unsafe { std::ptr::addr_of!(node.as_ref().value) } as usize;
        self.value_to_node.insert(value_addr, node);
        *cur = Some(node);
        self.len += 1;
        None
    }

    /// Walks down from the root and returns the node whose key equals `key`,
    /// or `None` when the search runs off the tree.
    fn find_node(&self, key: &K) -> Option<NonNull<BTreeNode<K, V>>> {
        use std::cmp::Ordering;

        let mut cur = self.root;
        while let Some(nn) = cur {
            // SAFETY: every link in the tree points to a live node allocated
            // by `insert`; `&self` rules out concurrent mutation.
            let node = unsafe { nn.as_ref() };
            cur = match key.cmp(&node.key) {
                Ordering::Less => node.left,
                Ordering::Greater => node.right,
                Ordering::Equal => return Some(nn),
            };
        }
        None
    }
}
impl<K: Ord, V> Default for FastPtrBTree<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<K: Ord, V> Drop for FastPtrBTree<K, V> {
fn drop(&mut self) {
let mut stack = Vec::new();
if let Some(root) = self.root {
stack.push(root);
}
while let Some(node_ptr) = stack.pop() {
unsafe {
let node = node_ptr.as_ref();
if let Some(left) = node.left {
stack.push(left);
}
if let Some(right) = node.right {
stack.push(right);
}
drop(Box::from_raw(node_ptr.as_ptr()));
}
}
}
}
#[cfg(test)]
mod tests {
    use super::FastPtrBTree;

    /// A value pointer handed out by the tree resolves back to the node that
    /// stores that value.
    #[test]
    fn fetch_node_from_its_value_ptr() {
        let mut tree = FastPtrBTree::new();
        for (key, label) in [(7, "root"), (3, "left"), (9, "right")] {
            tree.insert(key, String::from(label));
        }

        let left_value = tree.value_ptr_of_key(&3).unwrap();
        let found = tree.node_from_value_ptr(left_value).unwrap();

        assert_eq!(*found.key(), 3);
        assert_eq!(found.value(), "left");
    }

    /// Replacing the value of an existing key keeps the pointer index usable:
    /// the pointer taken before the replacement still resolves to the node,
    /// which now holds the new value.
    #[test]
    fn pointer_index_updates_when_replacing_value() {
        let mut tree = FastPtrBTree::new();
        tree.insert(1, String::from("a"));
        let ptr_before = tree.value_ptr_of_key(&1).unwrap();

        let replaced = tree.insert(1, String::from("b"));
        assert_eq!(replaced.as_deref(), Some("a"));

        let found = tree.node_from_value_ptr(ptr_before).unwrap();
        assert_eq!(found.value(), "b");
    }
}