use super::leaf::*;
use super::node::*;
use crate::{CACHE_LINE_SIZE, trace_log};
use alloc::alloc::{Layout, dealloc};
use core::borrow::Borrow;
use core::fmt;
use core::marker::PhantomData;
use core::mem::{MaybeUninit, align_of, needs_drop, size_of};
use core::ops::{Deref, DerefMut};
use core::ptr::{self, NonNull};
// Byte offset of the first key within the key area: the key area begins with
// the 8-byte `NodeHeader` (asserted in `cal_layout`).
const INTER_KEY_HEAD_SIZE: usize = 8;
// Byte offset of the first child pointer within the pointer area: pointers
// start immediately, with no header.
const INTER_PTR_HEAD_SIZE: usize = 0;
/// Internal (non-leaf) node of the tree: a thin handle around a raw node
/// allocation (`NodeBase`) holding up to `cap()` separator keys of type `K`
/// and `cap() + 1` raw child pointers. `V` only tags the child-node types.
pub(super) struct InterNode<K, V> {
    base: NodeBase,
    // Marker for the K/V type parameters; `fn(&K, &V)` keeps the node
    // covariant-free without implying ownership of K or V.
    _phan: PhantomData<fn(&K, &V)>,
}
impl<K, V> Clone for InterNode<K, V> {
#[inline(always)]
fn clone(&self) -> Self {
Self { base: self.base.clone(), _phan: Default::default() }
}
}
/// Expose the shared `NodeBase` helpers (header access, search, etc.)
/// directly on `InterNode` via auto-deref.
impl<K, V> Deref for InterNode<K, V> {
    type Target = NodeBase;
    fn deref(&self) -> &Self::Target {
        &self.base
    }
}
/// Mutable counterpart of `Deref`: pass-through to the underlying `NodeBase`.
impl<K, V> DerefMut for InterNode<K, V> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.base
    }
}
impl<K, V> InterNode<K, V> {
    // (key capacity, allocation layout), computed once at compile time.
    const LAYOUT: (u32, Layout) = Self::cal_layout();
    // A node holding fewer keys than this is considered underfull and is a
    // candidate for rebalancing (rotate/merge).
    pub(super) const UNDERFLOW_CAP: u32 = Self::LAYOUT.0 / 3;
    /// Compute the node's key capacity and its heap allocation layout.
    ///
    /// The node memory appears to be laid out as two `AREA_SIZE` regions: the
    /// key area (starting with the 8-byte `NodeHeader`, then packed keys) and
    /// the child-pointer area — TODO confirm against `NodeBase::item_ptr`.
    /// Capacity is the lesser of what fits in the key area and one less than
    /// what fits in the pointer area (an internal node has one more child
    /// than it has keys).
    const fn cal_layout() -> (u32, Layout) {
        // Node alignment is max(key alignment, pointer alignment); keys with
        // alignment greater than 8 are rejected at compile time.
        let mut align = align_of::<K>();
        assert!(align <= 8);
        let key_size = size_of::<K>();
        if align < PTR_ALIGN {
            align = PTR_ALIGN;
        }
        // The key area's 8-byte head must exactly hold the NodeHeader.
        assert!(size_of::<NodeHeader>() == INTER_KEY_HEAD_SIZE);
        assert!(key_size <= CACHE_LINE_SIZE - 16);
        let mut inter_key_cap = (AREA_SIZE - INTER_KEY_HEAD_SIZE) / key_size;
        let inter_value_cap = (AREA_SIZE - INTER_PTR_HEAD_SIZE) / PTR_SIZE;
        // Clamp: keys <= children - 1.
        if inter_key_cap > inter_value_cap - 1 {
            inter_key_cap = inter_value_cap - 1;
        }
        match Layout::from_size_align(NODE_SIZE, align) {
            Ok(l) => (inter_key_cap as u32, l),
            Err(_) => panic!("invalid layout"),
        }
    }
    /// Allocate an empty internal node at the given tree height.
    ///
    /// # Safety
    /// Key and child slots are uninitialized; the caller must populate child
    /// slot 0 (and any inserted entries) before the node is traversed.
    #[inline(always)]
    pub unsafe fn alloc(height: u32) -> Self {
        let mut base = NodeBase::_alloc(Self::LAYOUT.1);
        let header = base.get_header_mut();
        header.height = height; header.count = 0;
        Self { base, _phan: Default::default() }
    }
    /// Free this node's memory. When `DROP_ITEM` is true, drop every stored
    /// key first. Child pointers are raw and are never freed here — callers
    /// own the subtrees.
    #[inline(always)]
    pub fn dealloc<const DROP_ITEM: bool>(self) {
        unsafe {
            if DROP_ITEM && needs_drop::<K>() {
                let count = self.key_count();
                for i in 0..count {
                    (*self.key_ptr(i)).assume_init_drop();
                }
            }
            dealloc(self.base.header.as_ptr() as *mut u8, Self::LAYOUT.1);
        }
    }
    // Test helper: view the initialized keys as a slice.
    #[cfg(test)]
    pub(super) fn get_keys(&self) -> &[K] {
        self.base.get_array::<K>(INTER_KEY_HEAD_SIZE, 0)
    }
    /// Maximum number of separator keys this node type can hold.
    #[inline]
    pub const fn cap() -> u32 {
        Self::LAYOUT.0
    }
    /// Rebuild a typed handle from a raw header pointer.
    ///
    /// # Safety
    /// `header` must be non-null and point to a live internal-node allocation
    /// (debug-asserted via the leaf flag).
    #[inline(always)]
    pub(crate) unsafe fn from_header(header: *mut NodeHeader) -> Self {
        unsafe {
            debug_assert!(!(*header).is_leaf());
            Self {
                base: NodeBase { header: NonNull::new_unchecked(header) },
                _phan: Default::default(),
            }
        }
    }
    /// Store the leftmost child pointer (child slot 0, which has no
    /// associated separator key).
    #[inline(always)]
    pub(crate) fn set_left_ptr(&mut self, child_ptr: *mut NodeHeader) {
        unsafe {
            let p = self.child_ptr(0);
            p.write(child_ptr)
        }
    }
    /// True when no key slot remains.
    #[inline(always)]
    pub fn is_full(&self) -> bool {
        let avail = Self::cap() - self.key_count();
        avail == 0
    }
    /// Raw pointer to key slot `idx`.
    ///
    /// # Safety
    /// `idx` must be within the key area; slots at or beyond `key_count()`
    /// may be uninitialized.
    #[inline(always)]
    pub unsafe fn key_ptr(&self, idx: u32) -> *mut MaybeUninit<K> {
        unsafe { self.base.item_ptr::<MaybeUninit<K>>(INTER_KEY_HEAD_SIZE, idx) }
    }
    /// Raw pointer to child slot `idx` (child area starts `AREA_SIZE` bytes
    /// into the node).
    ///
    /// # Safety
    /// `idx` must be within the pointer area; valid slots are
    /// `0..=key_count()`.
    #[inline(always)]
    pub unsafe fn child_ptr(&self, idx: u32) -> *mut *mut NodeHeader {
        unsafe { self.base.item_ptr::<*mut NodeHeader>(AREA_SIZE + INTER_PTR_HEAD_SIZE, idx) }
    }
    /// Read the raw child pointer at `idx` without typing it.
    #[inline(always)]
    pub fn get_child_ptr(&self, idx: u32) -> *mut NodeHeader {
        unsafe { *self.child_ptr(idx) }
    }
    /// Load child `idx` as a typed node, dispatching on its leaf flag.
    /// Panics if the slot is null.
    #[inline]
    pub fn get_child(&self, idx: u32) -> Node<K, V> {
        unsafe {
            let child_ptr = *self.child_ptr(idx);
            if child_ptr.is_null() {
                panic!("{:?} child {idx} is null", self);
            } else if (*child_ptr).is_leaf() {
                Node::Leaf(LeafNode::<K, V>::from_header(child_ptr))
            } else {
                Node::Inter(InterNode::<K, V>::from_header(child_ptr))
            }
        }
    }
    /// Load child `idx`, asserting (debug builds) it is an internal node.
    /// Panics if the slot is null.
    #[inline]
    pub fn get_child_as_inter(&self, idx: u32) -> Self {
        unsafe {
            let child_ptr = *self.child_ptr(idx);
            if child_ptr.is_null() {
                panic!("{:?} child {idx} is null", self);
            } else {
                debug_assert!(!(*child_ptr).is_leaf());
                InterNode::<K, V>::from_header(child_ptr)
            }
        }
    }
    /// Load child `idx`, asserting (debug builds) it is a leaf node.
    /// Panics if the slot is null.
    #[inline]
    pub fn get_child_as_leaf(&self, idx: u32) -> LeafNode<K, V> {
        unsafe {
            let child_ptr = *self.child_ptr(idx);
            if child_ptr.is_null() {
                panic!("{:?} child {idx} is null", self);
            } else {
                debug_assert!((*child_ptr).is_leaf());
                LeafNode::<K, V>::from_header(child_ptr)
            }
        }
    }
}
impl<K: Ord, V> InterNode<K, V> {
    /// Build a new root after the old root split: `left_ptr` becomes the
    /// leftmost child and (`promote_key`, `right_ptr`) the single entry.
    #[inline(always)]
    pub fn new_root(
        height: u32, promote_key: K, left_ptr: *mut NodeHeader, right_ptr: *mut NodeHeader,
    ) -> Self {
        let mut root = unsafe { Self::alloc(height) };
        root.set_left_ptr(left_ptr);
        root.insert_no_split_with_idx(0, promote_key, right_ptr);
        root
    }
    /// Index of the child subtree to descend into for `key`. An exact match
    /// goes right (`idx + 1`): entries equal to a separator live in the
    /// separator's right subtree.
    #[inline(always)]
    pub fn search_child<Q>(&self, key: &Q) -> u32
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        let (idx, is_equal) = self.base._search::<K, Q>(INTER_KEY_HEAD_SIZE, key);
        if is_equal { idx + 1 } else { idx }
    }
    /// Insertion position for `key` among this node's separator keys
    /// (equal-match information discarded).
    #[inline(always)]
    pub fn search_key<Q>(&self, key: &Q) -> u32
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        let (idx, _is_equal) = self.base._search::<K, Q>(INTER_KEY_HEAD_SIZE, key);
        idx
    }
    // Test helper: locate the slot for `key` and insert without splitting.
    #[cfg(test)]
    pub fn insert_no_split(&mut self, key: K, ptr: *mut NodeHeader) {
        let idx = self.search_key(&key);
        debug_assert!(!self.is_full());
        self.insert_no_split_with_idx(idx, key, ptr);
    }
    /// Insert `key` at key slot `idx` and `ptr` at child slot `idx + 1`; the
    /// `AREA_SIZE + PTR_SIZE` offset shifts the pointer area by one slot so
    /// the child lands to the key's right. The node must have spare capacity.
    #[inline(always)]
    pub fn insert_no_split_with_idx(&mut self, idx: u32, key: K, ptr: *mut NodeHeader) {
        debug_assert!(self.key_count() < Self::cap());
        let _ = unsafe {
            self.base._insert::<K, *mut NodeHeader>(
                INTER_KEY_HEAD_SIZE,
                AREA_SIZE + PTR_SIZE, idx,
                key,
                ptr,
            )
        };
    }
    /// Insert `key` as the first separator and `left_ptr` as the new
    /// leftmost child; all existing keys and children shift right by one.
    #[inline(always)]
    pub fn insert_at_front(&mut self, left_ptr: *mut NodeHeader, key: K) {
        let count = self.key_count();
        debug_assert!(count < Self::cap(), "Node is full, cannot insert at front");
        unsafe {
            if count > 0 {
                // Shift keys [0, count) and children [0, count] right by one.
                let src_key = self.key_ptr(0);
                let dst_key = self.key_ptr(1);
                ptr::copy(src_key, dst_key, count as usize);
                let src_child = self.child_ptr(0);
                let dst_child = self.child_ptr(1);
                ptr::copy(src_child, dst_child, (count + 1) as usize);
            } else {
                // Only the leftmost child exists; move it to slot 1.
                *self.child_ptr(1) = *self.child_ptr(0);
            }
            (*self.key_ptr(0)).write(key);
            (*self.child_ptr(0)) = left_ptr;
            self.get_header_mut().count += 1;
        }
    }
    /// Append keys [start_idx, start_idx + copy_count) — and the children to
    /// their right — onto the end of `right_node`. Bitwise move: the source
    /// slots are left logically vacated by the caller adjusting counts.
    #[inline(always)]
    fn copy_right(&mut self, right_node: &mut Self, start_idx: u32, copy_count: u32) {
        let right_count = right_node.key_count();
        debug_assert!(start_idx + copy_count <= self.key_count());
        debug_assert!(right_count + copy_count <= Self::cap());
        unsafe {
            let src_key = self.key_ptr(start_idx) as *mut K;
            let dst_key = right_node.key_ptr(right_count) as *mut K;
            ptr::copy_nonoverlapping(src_key, dst_key, copy_count as usize);
            let src_child = self.child_ptr(start_idx + 1);
            let dst_child = right_node.child_ptr(right_count + 1);
            ptr::copy_nonoverlapping(src_child, dst_child, copy_count as usize);
            right_node.get_header_mut().count += copy_count;
        }
    }
    /// Merge the right sibling (`grand`'s child at `right_idx`) into `self`:
    /// the separator key is pulled down from `grand`, then all of `right`'s
    /// keys and children are appended. `right` is freed without dropping its
    /// keys since they were moved here bitwise.
    pub fn merge(&mut self, right: Self, grand: &mut Self, right_idx: u32) {
        let key = grand.remove_mid_child(right_idx);
        let right_count = right.key_count();
        let mut self_count = self.key_count();
        debug_assert!(right_count + self_count + 1 <= Self::cap());
        // Demoted separator pairs with `right`'s leftmost child.
        self.insert_no_split_with_idx(self_count, key, right.get_child_ptr(0));
        if right_count > 0 {
            unsafe {
                self_count += 1;
                let src_key = right.key_ptr(0) as *mut K;
                let dst_key = self.key_ptr(self_count) as *mut K;
                ptr::copy_nonoverlapping(src_key, dst_key, right_count as usize);
                let src_child = right.child_ptr(1);
                let dst_child = self.child_ptr(self_count + 1);
                ptr::copy_nonoverlapping(src_child, dst_child, right_count as usize);
                self.get_header_mut().count += right_count;
            }
        }
        right.dealloc::<false>();
    }
    /// Split this full node while inserting (`key`, `child_ptr`). Returns
    /// the newly allocated right sibling and the key to promote to the
    /// parent. Cases: insertion past the end (new node holds only the
    /// child), insertion exactly at the split point (the inserted key itself
    /// is promoted), and insertion on either side of the split point.
    pub fn insert_split(&mut self, key: K, child_ptr: *mut NodeHeader) -> (Self, K) {
        let cap = Self::cap();
        debug_assert_eq!(self.key_count(), Self::cap());
        let idx = self.search_key(&key);
        let mut new_node = unsafe { InterNode::<K, V>::alloc(self.height()) };
        if idx == cap {
            trace_log!("{self:?} insert_split new_node {new_node:?} at cap {idx} {child_ptr:p}");
            new_node.set_left_ptr(child_ptr);
            return (new_node, key);
        }
        let split_idx = cap >> 1;
        unsafe {
            if idx == split_idx {
                trace_log!(
                    "{self:?} insert_split new_node {new_node:?} at split_idx=idx {idx} {child_ptr:p}"
                );
                new_node.set_left_ptr(child_ptr);
                self.copy_right(&mut new_node, split_idx, cap - split_idx);
                self.get_header_mut().count = split_idx;
                return (new_node, key);
            }
            let promote_key = (*self.key_ptr(split_idx)).assume_init_read();
            new_node.set_left_ptr(*self.child_ptr(split_idx + 1));
            if idx < split_idx {
                // Fixed: previously interpolated `idx` twice instead of the
                // `split_idx` value the message names.
                trace_log!(
                    "{self:?} insert_split new_node {new_node:?} insert {idx} < split_idx {split_idx} {child_ptr:p}"
                );
                let right_count = cap - split_idx - 1;
                if right_count > 0 {
                    self.copy_right(&mut new_node, split_idx + 1, right_count);
                }
                self.get_header_mut().count = split_idx;
                self.insert_no_split_with_idx(idx, key, child_ptr);
            } else {
                // Fixed: same wrong interpolation as above.
                trace_log!(
                    "{self:?} insert_split new_node {new_node:?} insert {idx} > split_idx {split_idx} {child_ptr:p}"
                );
                if idx > split_idx + 1 {
                    self.copy_right(&mut new_node, split_idx + 1, idx - split_idx - 1);
                }
                new_node.insert_no_split_with_idx(idx - split_idx - 1, key, child_ptr);
                if idx < cap {
                    self.copy_right(&mut new_node, idx, cap - idx);
                }
                self.get_header_mut().count = split_idx;
            }
            (new_node, promote_key)
        }
    }
    /// Descend from `self`'s child at `idx` down to an internal node at
    /// `height`, following the leftmost (`left == true`) or rightmost
    /// branch. Every node traversed (including `self`) is pushed into
    /// `cache` together with the child index taken. Returns the reached
    /// node and the branch index within it.
    #[inline]
    pub fn find_child_branch(
        &self, height: u32, mut idx: u32, left: bool, mut cache: Option<&mut PathCache<K, V>>,
    ) -> (Self, u32) {
        debug_assert!(height > 0);
        let mut child = self.get_child_as_inter(idx);
        if let Some(_cache) = cache.as_mut() {
            _cache.push(self.clone(), idx);
        }
        idx = if left { 0 } else { child.key_count() };
        while child.height() > height {
            if let Some(_cache) = cache.as_mut() {
                _cache.push(child.clone(), idx);
            }
            child = child.get_child_as_inter(idx);
            // Fixed: recompute the branch index for the node we just
            // descended into. The rightmost slot depends on each node's own
            // key count; the old code reused the previous level's index,
            // which is only correct for the leftmost branch.
            idx = if left { 0 } else { child.key_count() };
        }
        (child, idx)
    }
    /// Overwrite the separator key at `idx` with `key`, returning the old
    /// key. NOTE(review): mutates through `&self` via a raw pointer
    /// (interior mutability); callers must guarantee exclusive access.
    #[inline(always)]
    pub fn change_key(&self, idx: u32, key: K) -> K {
        debug_assert!(self.key_count() > idx);
        unsafe {
            let k_ptr = self.key_ptr(idx);
            let old_key = (*k_ptr).assume_init_read();
            (*k_ptr).write(key);
            old_key
        }
    }
    /// Pop the last (key, child) pair. The node must hold at least one key.
    #[inline]
    pub fn remove_last_child(&mut self) -> (K, *mut NodeHeader) {
        let idx = self.key_count();
        debug_assert!(idx > 0);
        self.get_header_mut().count = idx - 1;
        unsafe {
            let key = (*self.key_ptr(idx - 1)).assume_init_read();
            let child = *self.child_ptr(idx);
            (key, child)
        }
    }
    /// Remove the first key and the leftmost child, shifting everything
    /// left by one; returns the removed key (the removed child pointer is
    /// simply overwritten — the caller must have taken ownership of it).
    #[inline(always)]
    pub fn remove_first_child(&mut self) -> K {
        let key_count = self.key_count();
        debug_assert!(key_count > 0);
        unsafe {
            let first_key_ptr = self.key_ptr(0);
            let first_child_ptr = self.child_ptr(0);
            let first_key = (*first_key_ptr).assume_init_read();
            ptr::copy(first_key_ptr.add(1), first_key_ptr, (key_count - 1) as usize);
            ptr::copy(first_child_ptr.add(1), first_child_ptr, key_count as usize);
            self.get_header_mut().count = key_count - 1;
            first_key
        }
    }
    /// Remove child slot `child_idx` (1-based relative to keys) together
    /// with the separator key to its left, shifting later entries down.
    /// Returns the removed key.
    #[inline(always)]
    pub fn remove_mid_child(&mut self, child_idx: u32) -> K {
        let key_count = self.key_count();
        debug_assert!(child_idx > 0);
        debug_assert!(child_idx <= key_count);
        unsafe {
            let key = (*self.key_ptr(child_idx - 1)).assume_init_read();
            if child_idx < key_count {
                ptr::copy(
                    self.key_ptr(child_idx),
                    self.key_ptr(child_idx - 1),
                    (key_count - child_idx) as usize,
                );
            }
            ptr::copy(
                self.child_ptr(child_idx + 1),
                self.child_ptr(child_idx),
                (key_count - child_idx) as usize,
            );
            self.get_header_mut().count = key_count - 1;
            key
        }
    }
    /// Insert (`key`, `child_ptr`) into `self` while rotating `self`'s first
    /// entry out to the `left` sibling: the first key is promoted into
    /// `parent` (replacing the separator at `my_idx - 1`), the demoted
    /// separator plus `self`'s old leftmost child are appended to `left`,
    /// and the new entry takes slot `insert_child_idx - 1` after the shift.
    /// Net key count of `self` is unchanged.
    #[inline]
    pub fn insert_rotate_left(
        &mut self, parent: &mut Self, my_idx: u32, left: &mut Self, insert_child_idx: u32, key: K,
        child_ptr: *mut NodeHeader,
    ) {
        debug_assert!(insert_child_idx <= self.key_count());
        debug_assert!(insert_child_idx > 0);
        unsafe {
            let first_key_ptr = self.key_ptr(0);
            let first_child_ptr = self.child_ptr(0);
            let first_key = (*first_key_ptr).assume_init_read();
            let first_child = *first_child_ptr;
            // Shift the slots before the insertion point down by one, then
            // drop the new entry into the gap.
            if insert_child_idx > 1 {
                ptr::copy(first_key_ptr.add(1), first_key_ptr, (insert_child_idx - 1) as usize);
            }
            ptr::copy(first_child_ptr.add(1), first_child_ptr, insert_child_idx as usize);
            (*self.key_ptr(insert_child_idx - 1)).write(key);
            (*self.child_ptr(insert_child_idx)) = child_ptr;
            let demote_key = parent.change_key(my_idx - 1, first_key);
            left.append(demote_key, first_child);
        }
    }
    /// Rotate one entry from `self` to the `right` sibling: the last key is
    /// promoted into `parent` (slot `my_idx`), and the demoted separator plus
    /// `self`'s last child become `right`'s new front entry.
    #[inline]
    pub fn rotate_right(&mut self, parent: &mut Self, my_idx: u32, right: &mut Self) {
        let (promote_key, child) = self.remove_last_child();
        let demote_key = parent.change_key(my_idx, promote_key);
        right.insert_at_front(child, demote_key);
    }
    /// Append (`key`, `child`) at the end of the node (no split).
    #[inline(always)]
    pub fn append(&mut self, key: K, child: *mut NodeHeader) {
        self.insert_no_split_with_idx(self.key_count(), key, child);
    }
}
impl<K, V> fmt::Debug for InterNode<K, V> {
    /// Compact diagnostic form: header pointer, height, and the child count
    /// (one more than the key count).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let header = self.base.header;
        let height = self.height();
        let children = self.key_count() + 1;
        write!(f, "InterNode({:p} height:{}, count:{})", header, height, children)
    }
}
impl<K: fmt::Debug, V> fmt::Display for InterNode<K, V> {
    /// Verbose form: header pointer, height, child count, then the leftmost
    /// child pointer followed by every `key|child-pointer` pair.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let key_count = self.key_count();
        write!(
            f,
            "InterNode({:p} height:{}, count:{}, keys: [",
            self.base.header,
            self.height(),
            key_count + 1
        )?;
        // SAFETY: slots read below are within 0..=key_count; key slots below
        // key_count are assumed initialized — presumed node invariant.
        unsafe {
            write!(f, "{:p}", *self.child_ptr(0))?;
            for i in 0..key_count {
                let key = (*self.key_ptr(i)).assume_init_ref();
                write!(f, ", {:?}|{:p}", key, *self.child_ptr(i + 1))?;
            }
        }
        write!(f, "])")
    }
}
#[cfg(test)]
/// Test-only identity comparison: two handles are equal iff they reference
/// the same underlying node allocation (not structural equality of keys).
#[cfg(test)]
impl<K, V> PartialEq for InterNode<K, V> {
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        self.get_ptr() == other.get_ptr()
    }
}
impl<K: Ord + fmt::Debug, V: fmt::Debug> InterNode<K, V> {
    /// Debug/test invariant check: the node holds at most `cap()` keys and
    /// its keys are strictly ascending. Panics on violation; an empty node
    /// trivially passes.
    pub fn validate(&self) {
        let count = self.key_count() as usize;
        if count == 0 {
            return;
        }
        assert!(
            count as u32 <= Self::cap(),
            "Internal node has too many keys: {} > {}",
            count,
            Self::cap()
        );
        // SAFETY: every index read below is < count, so the key slots are
        // assumed initialized — presumed node invariant.
        unsafe {
            for i in 0..count - 1 {
                let lhs = (*self.key_ptr(i as u32)).assume_init_ref();
                let rhs = (*self.key_ptr((i + 1) as u32)).assume_init_ref();
                assert!(
                    lhs < rhs,
                    "Internal node keys not sorted: {:?} >= {:?}",
                    lhs,
                    rhs
                );
            }
        }
    }
}