#![allow(unsafe_code)]
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use std::sync::OnceLock;
use crossbeam_epoch::{self as epoch, Atomic, Guard, Owned};
use crate::key::Key;
use crate::model::LinearModel;
/// Lazily-initialized per-slot child-pointer array; left unset for leaf nodes.
type ChildArray<K, V> = OnceLock<Box<[Atomic<Node<K, V>>]>>;
/// Slot has never held data.
pub const SLOT_EMPTY: u8 = 0;
/// A writer currently holds the slot; its contents are in flux.
pub const SLOT_WRITING: u8 = 1;
/// Slot holds an initialized inline key/value pair.
pub const SLOT_DATA: u8 = 2;
/// Slot holds only a child-node pointer (no inline key/value).
pub const SLOT_CHILD: u8 = 3;
/// Slot holds a child pointer plus a displaced (stale) inline key/value pair.
pub const SLOT_CHILD_STALE: u8 = 4;
/// Slot's inline pair was logically deleted (still initialized in storage).
pub const SLOT_TOMBSTONE: u8 = 5;

/// Returns `true` when `state` denotes a slot that points at a child node,
/// whether fresh (`SLOT_CHILD`) or carrying stale inline data
/// (`SLOT_CHILD_STALE`).
#[inline]
pub fn is_child(state: u8) -> bool {
    matches!(state, SLOT_CHILD | SLOT_CHILD_STALE)
}
/// A node of the concurrent learned index.
///
/// Each of the node's slots is guarded by an atomic state byte (`SLOT_*`)
/// that gates access to an inline key/value pair and, for inner nodes, a
/// child pointer at the same index.
pub struct Node<K, V> {
    /// Learned model used by `predict_slot` to map a key to a slot index.
    model: LinearModel,
    /// Per-slot state bytes; slot count doubles as the node's capacity.
    states: Box<[AtomicU8]>,
    /// Inline keys; initialized only in states DATA / CHILD_STALE / TOMBSTONE.
    keys: Box<[UnsafeCell<MaybeUninit<K>>]>,
    /// Inline values; initialized in the same states as `keys`.
    values: Box<[UnsafeCell<MaybeUninit<V>>]>,
    /// Lazily-allocated child pointers; unset for pure leaves.
    children: ChildArray<K, V>,
    /// Live-key counter, maintained by callers via `inc_keys`/`dec_keys`.
    num_keys: AtomicUsize,
    /// Tombstone counter, maintained by callers via `inc/dec_tombstones`.
    num_tombstones: AtomicUsize,
    /// When set, `predict_slot` ignores the model and routes keys
    /// `<= split_key` to slot 0 and all others to the last slot.
    split_key: Option<K>,
}
impl<K, V> std::fmt::Debug for Node<K, V> {
    /// Renders a summary of the node (counters and shape) without touching
    /// any slot contents; counter loads are `Relaxed` snapshots.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("Node");
        dbg.field("model", &self.model);
        dbg.field("capacity", &self.states.len());
        dbg.field("num_keys", &self.num_keys.load(Ordering::Relaxed));
        dbg.field(
            "num_tombstones",
            &self.num_tombstones.load(Ordering::Relaxed),
        );
        dbg.field("has_split_key", &self.split_key.is_some());
        dbg.finish_non_exhaustive()
    }
}
impl<K: Key, V> Node<K, V> {
    /// Builds an inner node with `array_size` slots, eagerly allocating the
    /// child-pointer array (all pointers null, all slots `SLOT_EMPTY`).
    pub fn with_capacity(model: LinearModel, array_size: usize) -> Self {
        let children = OnceLock::new();
        // Freshly created lock: `set` cannot have raced, so ignore the result.
        let _ = children.set(
            (0..array_size)
                .map(|_| Atomic::null())
                .collect::<Vec<_>>()
                .into_boxed_slice(),
        );
        Self {
            model,
            states: (0..array_size)
                .map(|_| AtomicU8::new(SLOT_EMPTY))
                .collect::<Vec<_>>()
                .into_boxed_slice(),
            keys: (0..array_size)
                .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
                .collect::<Vec<_>>()
                .into_boxed_slice(),
            values: (0..array_size)
                .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
                .collect::<Vec<_>>()
                .into_boxed_slice(),
            children,
            num_keys: AtomicUsize::new(0),
            num_tombstones: AtomicUsize::new(0),
            split_key: None,
        }
    }

    /// Builds a leaf node: identical to `with_capacity` except the
    /// child-pointer array stays unset until `ensure_children` first runs.
    pub fn with_capacity_leaf(model: LinearModel, array_size: usize) -> Self {
        Self {
            model,
            states: (0..array_size)
                .map(|_| AtomicU8::new(SLOT_EMPTY))
                .collect::<Vec<_>>()
                .into_boxed_slice(),
            keys: (0..array_size)
                .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
                .collect::<Vec<_>>()
                .into_boxed_slice(),
            values: (0..array_size)
                .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
                .collect::<Vec<_>>()
                .into_boxed_slice(),
            children: OnceLock::new(),
            num_keys: AtomicUsize::new(0),
            num_tombstones: AtomicUsize::new(0),
            split_key: None,
        }
    }

    /// Builds an inner node that routes by a fixed `boundary` key instead of
    /// a learned model (see `predict_slot`).
    pub fn with_split_key(boundary: K, array_size: usize) -> Self {
        let mut node = Self::with_capacity(LinearModel::constant(), array_size);
        node.split_key = Some(boundary);
        node
    }

    /// Maps `key` to a slot index: when a split key is set, keys `<=` it go
    /// to slot 0 and the rest to the last slot; otherwise the learned model
    /// predicts a slot over `capacity()` positions.
    #[inline]
    pub fn predict_slot(&self, key: &K) -> usize {
        if let Some(ref sk) = self.split_key {
            return if key <= sk {
                0
            } else {
                // saturating_sub keeps a zero-capacity node from underflowing.
                self.states.len().saturating_sub(1)
            };
        }
        self.model.predict(key, self.states.len())
    }

    /// Number of slots in this node.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.states.len()
    }

    /// Returns the child-pointer array, allocating it (all null) on first use.
    #[inline]
    fn ensure_children(&self) -> &[Atomic<Self>] {
        self.children.get_or_init(|| {
            let cap = self.states.len();
            (0..cap)
                .map(|_| Atomic::null())
                .collect::<Vec<_>>()
                .into_boxed_slice()
        })
    }

    /// Writes `key`/`value` into an empty slot and marks it `SLOT_DATA`.
    ///
    /// Uses `Relaxed` stores throughout. NOTE(review): this looks intended
    /// for single-threaded construction before the node is published (the
    /// concurrent path is `cas_empty_to_data`) — confirm no concurrent
    /// reader can observe this slot mid-write.
    pub fn store_data(&self, idx: usize, key: K, value: V) {
        debug_assert_eq!(
            self.states[idx].load(Ordering::Relaxed),
            SLOT_EMPTY,
            "store_data called on non-empty slot {idx}"
        );
        // SAFETY: the slot is empty (debug-asserted above), so the cells hold
        // no initialized data that would be overwritten without being dropped.
        unsafe {
            (*self.keys[idx].get()) = MaybeUninit::new(key);
            (*self.values[idx].get()) = MaybeUninit::new(value);
        }
        self.states[idx].store(SLOT_DATA, Ordering::Relaxed);
    }

    /// Installs `child` into an empty slot and marks it `SLOT_CHILD`.
    ///
    /// Same `Relaxed`-ordering caveat as `store_data`: construction-time use.
    pub fn store_child(&self, idx: usize, child: Self) {
        debug_assert_eq!(
            self.states[idx].load(Ordering::Relaxed),
            SLOT_EMPTY,
            "store_child called on non-empty slot {idx}"
        );
        let children = self.ensure_children();
        // SAFETY: an unprotected guard skips epoch pinning; the slot was null,
        // so this store reclaims nothing. NOTE(review): sound only if callers
        // use this before the node is shared — confirm at call sites.
        unsafe {
            let guard = epoch::unprotected();
            children[idx].store(Owned::new(child).into_shared(guard), Ordering::Relaxed);
        }
        self.states[idx].store(SLOT_CHILD, Ordering::Relaxed);
    }

    /// Loads the slot's state byte with `Acquire`, pairing with the `Release`
    /// state stores so published slot contents are visible afterwards.
    #[inline]
    pub fn slot_state(&self, idx: usize) -> u8 {
        self.states[idx].load(Ordering::Acquire)
    }

    /// Borrows the inline key at `idx`.
    ///
    /// # Safety
    /// The caller must have observed (via an `Acquire` state load) a state in
    /// which the key is initialized — `SLOT_DATA`, `SLOT_CHILD_STALE`, or
    /// `SLOT_TOMBSTONE` — and must ensure the slot is not concurrently
    /// rewritten while the reference is alive.
    #[inline]
    pub unsafe fn read_key(&self, idx: usize) -> &K {
        (*self.keys[idx].get()).assume_init_ref()
    }

    /// Borrows the inline value at `idx`.
    ///
    /// # Safety
    /// Same contract as `read_key`.
    #[inline]
    pub unsafe fn read_value(&self, idx: usize) -> &V {
        (*self.values[idx].get()).assume_init_ref()
    }

    /// Loads the child pointer at `idx`, or null when the child array was
    /// never allocated (pure leaf).
    #[inline]
    pub fn load_child<'g>(
        &self,
        idx: usize,
        guard: &'g Guard,
    ) -> crossbeam_epoch::Shared<'g, Self> {
        self.children
            .get()
            .map_or_else(crossbeam_epoch::Shared::null, |children| {
                children[idx].load(Ordering::Acquire, guard)
            })
    }

    /// Direct access to the child atomic at `idx`, allocating the child
    /// array on first use.
    #[inline]
    pub fn child_atomic(&self, idx: usize) -> &Atomic<Self> {
        &self.ensure_children()[idx]
    }

    /// Direct access to the state byte at `idx`.
    #[inline]
    pub fn state_atomic(&self, idx: usize) -> &AtomicU8 {
        &self.states[idx]
    }

    /// Concurrent insert: claims an empty slot (EMPTY -> WRITING), writes the
    /// pair, then publishes it (WRITING -> DATA, `Release`).
    /// Returns `false` without writing if the slot was not empty.
    pub fn cas_empty_to_data(&self, idx: usize, key: K, value: V) -> bool {
        if self.states[idx]
            .compare_exchange(
                SLOT_EMPTY,
                SLOT_WRITING,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_err()
        {
            return false;
        }
        // SAFETY: we hold the slot exclusively in the WRITING state; no other
        // thread touches these cells until the Release store below.
        unsafe {
            (*self.keys[idx].get()) = MaybeUninit::new(key);
            (*self.values[idx].get()) = MaybeUninit::new(value);
        }
        self.states[idx].store(SLOT_DATA, Ordering::Release);
        true
    }

    /// Replaces a tombstoned slot with a child pointer
    /// (TOMBSTONE -> WRITING -> CHILD_STALE). The stale inline key/value
    /// remain initialized and are only dropped when the node itself drops.
    /// Returns `false` if the slot was not a tombstone.
    pub fn cas_tombstone_to_child_stale(&self, idx: usize, child: Self, guard: &Guard) -> bool {
        if self.states[idx]
            .compare_exchange(
                SLOT_TOMBSTONE,
                SLOT_WRITING,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_err()
        {
            return false;
        }
        self.ensure_children()[idx].store(Owned::new(child).into_shared(guard), Ordering::Release);
        self.states[idx].store(SLOT_CHILD_STALE, Ordering::Release);
        true
    }

    /// Replaces a data slot with a child pointer
    /// (DATA -> WRITING -> CHILD_STALE); the displaced inline pair remains
    /// initialized until the node drops. Returns `false` if the slot did
    /// not hold data.
    pub fn cas_data_to_child_stale(&self, idx: usize, child: Self, guard: &Guard) -> bool {
        if self.states[idx]
            .compare_exchange(SLOT_DATA, SLOT_WRITING, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            return false;
        }
        self.ensure_children()[idx].store(Owned::new(child).into_shared(guard), Ordering::Release);
        self.states[idx].store(SLOT_CHILD_STALE, Ordering::Release);
        true
    }

    /// Logically deletes a data slot (DATA -> TOMBSTONE). Only the state
    /// byte changes; the key/value stay initialized for later drop.
    pub fn cas_data_to_tombstone(&self, idx: usize) -> bool {
        self.states[idx]
            .compare_exchange(
                SLOT_DATA,
                SLOT_TOMBSTONE,
                Ordering::AcqRel,
                Ordering::Acquire,
            )
            .is_ok()
    }

    /// Bumps the live-key counter (callers pair this with slot writes).
    pub fn inc_keys(&self) {
        self.num_keys.fetch_add(1, Ordering::Relaxed);
    }

    /// Decrements the live-key counter.
    pub fn dec_keys(&self) {
        self.num_keys.fetch_sub(1, Ordering::Relaxed);
    }

    /// Bumps the tombstone counter.
    pub fn inc_tombstones(&self) {
        self.num_tombstones.fetch_add(1, Ordering::Relaxed);
    }

    /// Decrements the tombstone counter.
    pub fn dec_tombstones(&self) {
        self.num_tombstones.fetch_sub(1, Ordering::Relaxed);
    }

    /// Fraction of slots counted as tombstones; 0.0 for a zero-capacity node.
    pub fn tombstone_ratio(&self) -> f64 {
        let cap = self.states.len();
        if cap == 0 {
            return 0.0;
        }
        self.num_tombstones.load(Ordering::Relaxed) as f64 / cap as f64
    }

    /// Counts DATA slots in this subtree, recursing through child pointers.
    /// Under concurrent mutation this is a best-effort snapshot.
    pub fn total_keys(&self, guard: &Guard) -> usize {
        let children = self.children.get();
        let mut count = 0;
        for i in 0..self.states.len() {
            let state = self.states[i].load(Ordering::Acquire);
            match state {
                SLOT_DATA => count += 1,
                s if is_child(s) => {
                    if let Some(c) = children {
                        let child_shared = c[i].load(Ordering::Acquire, guard);
                        if !child_shared.is_null() {
                            // SAFETY: non-null pointer loaded under `guard`
                            // stays valid for the guard's lifetime.
                            count += unsafe { child_shared.deref() }.total_keys(guard);
                        }
                    }
                }
                _ => {}
            }
        }
        count
    }

    /// Estimates heap usage of this subtree: the node struct itself, its
    /// slot arrays, the child-pointer array, plus all reachable children.
    pub fn allocated_bytes(&self, guard: &Guard) -> usize {
        let node_size = std::mem::size_of::<Self>();
        let cap = self.states.len();
        let children = self.children.get();
        let children_cap = children.map_or(0, |c| c.len());
        let arrays_size = cap
            * (std::mem::size_of::<AtomicU8>()
                + std::mem::size_of::<UnsafeCell<MaybeUninit<K>>>()
                + std::mem::size_of::<UnsafeCell<MaybeUninit<V>>>())
            + children_cap * std::mem::size_of::<Atomic<Self>>();
        let mut total = node_size + arrays_size;
        if let Some(c) = children {
            for i in 0..cap {
                let state = self.states[i].load(Ordering::Acquire);
                if is_child(state) {
                    let child_shared = c[i].load(Ordering::Acquire, guard);
                    if !child_shared.is_null() {
                        // SAFETY: non-null pointer loaded under `guard`.
                        total += unsafe { child_shared.deref() }.allocated_bytes(guard);
                    }
                }
            }
        }
        total
    }

    /// Height of this subtree; a node with no children has depth 1.
    pub fn max_depth(&self, guard: &Guard) -> usize {
        let mut max_child_depth = 0;
        if let Some(children) = self.children.get() {
            for i in 0..self.states.len() {
                let state = self.states[i].load(Ordering::Acquire);
                if is_child(state) {
                    let child_shared = children[i].load(Ordering::Acquire, guard);
                    if !child_shared.is_null() {
                        // SAFETY: non-null pointer loaded under `guard`.
                        let depth = unsafe { child_shared.deref() }.max_depth(guard);
                        max_child_depth = max_child_depth.max(depth);
                    }
                }
            }
        }
        1 + max_child_depth
    }
}
impl<K, V> Drop for Node<K, V> {
    /// Drops every initialized inline pair and every owned child subtree.
    ///
    /// Inline key/value pairs are live in states DATA, CHILD_STALE (the
    /// displaced pair is kept until node drop), and TOMBSTONE (logical
    /// deletes never drop at deletion time).
    fn drop(&mut self) {
        // SAFETY: `&mut self` guarantees exclusive access — no other thread
        // can touch the slots, so `get_mut`/Relaxed reads and in-place drops
        // are sound, and the unprotected guard is fine because each child
        // pointer is taken over for immediate destruction here.
        unsafe {
            let guard = epoch::unprotected();
            let children = self.children.get();
            for i in 0..self.states.len() {
                let state = *self.states[i].get_mut();
                let has_inline =
                    state == SLOT_DATA || state == SLOT_CHILD_STALE || state == SLOT_TOMBSTONE;
                // needs_drop checks let trivially-destructible K/V skip the loop body.
                if has_inline && std::mem::needs_drop::<K>() {
                    std::ptr::drop_in_place((*self.keys[i].get()).as_mut_ptr());
                }
                if has_inline && std::mem::needs_drop::<V>() {
                    std::ptr::drop_in_place((*self.values[i].get()).as_mut_ptr());
                }
                if let Some(c) = children {
                    if is_child(state) {
                        let shared = c[i].load(Ordering::Relaxed, guard);
                        if !shared.is_null() {
                            // Recursively drops the child subtree.
                            drop(shared.into_owned());
                        }
                    }
                }
            }
        }
    }
}
// SAFETY: interior mutability (the UnsafeCell slots and child pointers) is
// gated by the per-slot atomic state machine, and K/V are required to be
// Send + Sync. NOTE(review): soundness ultimately depends on callers
// upholding the slot-state protocol documented on the unsafe accessors —
// confirm with the callers of `read_key`/`read_value`.
unsafe impl<K: Send + Sync, V: Send + Sync> Send for Node<K, V> {}
unsafe impl<K: Send + Sync, V: Send + Sync> Sync for Node<K, V> {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the current thread's epoch for tests that traverse children.
    fn guard() -> epoch::Guard {
        epoch::pin()
    }

    #[test]
    fn new_node_all_empty() {
        let pin = guard();
        let n = Node::<u64, String>::with_capacity(LinearModel::constant(), 10);
        assert_eq!(n.capacity(), 10);
        assert_eq!(n.total_keys(&pin), 0);
    }

    #[test]
    fn total_keys_empty() {
        let pin = guard();
        let n = Node::<u64, ()>::with_capacity(LinearModel::constant(), 5);
        assert_eq!(n.total_keys(&pin), 0);
    }

    #[test]
    fn total_keys_with_data() {
        let pin = guard();
        let n = Node::<u64, &str>::with_capacity(LinearModel::constant(), 5);
        for (slot, k, v) in [(0, 1, "a"), (2, 2, "b")] {
            n.store_data(slot, k, v);
            n.inc_keys();
        }
        assert_eq!(n.total_keys(&pin), 2);
    }

    #[test]
    fn total_keys_with_children() {
        let pin = guard();
        let child = Node::<u64, &str>::with_capacity(LinearModel::constant(), 3);
        for (slot, k, v) in [(0, 10, "x"), (1, 20, "y")] {
            child.store_data(slot, k, v);
            child.inc_keys();
        }
        let parent = Node::<u64, &str>::with_capacity(LinearModel::constant(), 5);
        parent.store_data(0, 1, "a");
        parent.inc_keys();
        parent.store_child(1, child);
        // One direct key plus two inside the child.
        assert_eq!(parent.total_keys(&pin), 3);
    }

    #[test]
    fn max_depth_leaf() {
        let pin = guard();
        let n = Node::<u64, ()>::with_capacity(LinearModel::constant(), 5);
        assert_eq!(n.max_depth(&pin), 1);
    }

    #[test]
    fn max_depth_nested() {
        let pin = guard();
        // Build a three-level chain: root -> mid -> leaf.
        let leaf = Node::<u64, ()>::with_capacity(LinearModel::constant(), 2);
        let mid = Node::<u64, ()>::with_capacity(LinearModel::constant(), 2);
        mid.store_child(0, leaf);
        let root = Node::<u64, ()>::with_capacity(LinearModel::constant(), 2);
        root.store_child(0, mid);
        assert_eq!(root.max_depth(&pin), 3);
    }

    #[test]
    fn store_and_read_data() {
        let n = Node::<u64, i32>::with_capacity(LinearModel::constant(), 4);
        n.store_data(1, 42, 100);
        n.inc_keys();
        assert_eq!(n.slot_state(1), SLOT_DATA);
        unsafe {
            assert_eq!(*n.read_key(1), 42);
            assert_eq!(*n.read_value(1), 100);
        }
    }

    #[test]
    fn cas_empty_to_data_success() {
        let pin = guard();
        let n = Node::<u64, &str>::with_capacity(LinearModel::constant(), 4);
        assert!(n.cas_empty_to_data(0, 10, "hello"));
        assert_eq!(n.slot_state(0), SLOT_DATA);
        unsafe {
            assert_eq!(*n.read_key(0), 10);
            assert_eq!(*n.read_value(0), "hello");
        }
        assert_eq!(n.total_keys(&pin), 1);
    }

    #[test]
    fn cas_empty_to_data_fails_on_occupied() {
        let n = Node::<u64, u64>::with_capacity(LinearModel::constant(), 4);
        assert!(n.cas_empty_to_data(0, 1, 10));
        // Second insert into the same slot must fail and leave the first pair intact.
        assert!(!n.cas_empty_to_data(0, 2, 20));
        unsafe {
            assert_eq!(*n.read_key(0), 1);
            assert_eq!(*n.read_value(0), 10);
        }
    }

    #[test]
    fn cas_data_to_child_stale() {
        let pin = guard();
        let n = Node::<u64, u64>::with_capacity(LinearModel::constant(), 4);
        n.store_data(0, 10, 100);
        n.inc_keys();
        let child = Node::<u64, u64>::with_capacity(LinearModel::constant(), 2);
        child.store_data(0, 10, 200);
        child.inc_keys();
        assert!(n.cas_data_to_child_stale(0, child, &pin));
        assert_eq!(n.slot_state(0), SLOT_CHILD_STALE);
        assert!(!n.load_child(0, &pin).is_null());
    }

    #[test]
    fn cas_data_to_tombstone() {
        let n = Node::<u64, u64>::with_capacity(LinearModel::constant(), 4);
        n.store_data(0, 10, 100);
        assert!(n.cas_data_to_tombstone(0));
        assert_eq!(n.slot_state(0), SLOT_TOMBSTONE);
    }

    #[test]
    fn drop_with_inline_data() {
        // Heap-owning values must be freed by Drop without leaking or crashing.
        let n = Node::<u64, String>::with_capacity(LinearModel::constant(), 4);
        n.store_data(0, 1, "hello".to_string());
        n.store_data(1, 2, "world".to_string());
        drop(n);
    }
}