use std::{
alloc::{alloc_zeroed, Layout},
mem::{align_of, size_of},
sync::atomic::Ordering::Relaxed,
};
use super::*;
/// log2 of the number of slots in every page table node.
const FANFACTOR: u64 = 18;
/// Mask selecting the low `FANFACTOR` bits of a page id (the in-node index).
const FAN_MASK: u64 = !(!0_u64 << FANFACTOR);
/// Number of slots in every page table node (`2 ^ FANFACTOR`).
const FANOUT: u64 = FAN_MASK + 1;
/// Identifier of a logical page; the key used by every `PageTable` operation.
pub type PageId = u64;
/// Size in bytes of a first-level `Node1` (instantiated with `()`).
#[doc(hidden)]
pub const PAGETABLE_NODE_SZ: usize = size_of::<Node1<()>>();
/// Splits a page id into its first-level and second-level node indices.
///
/// The high bits (`i >> FANFACTOR`) select the `Node1` slot, and the low
/// `FANFACTOR` bits select the `Node2` slot.
///
/// # Panics
/// On 64-bit targets, panics if `i >= 2 ^ (2 * FANFACTOR)`. Such a key would
/// produce a first-level index of `FANOUT` or more, which is out of bounds for
/// the node arrays. The check previously used `<=`, letting `2 ^ 36` through.
#[inline(always)]
fn split_fanout(i: u64) -> (u64, u64) {
    #[cfg(target_pointer_width = "64")]
    assert!(
        i < 1 << (FANFACTOR * 2),
        "trying to access key of {}, which is \
         not lower than 2 ^ {}",
        i,
        (FANFACTOR * 2)
    );
    let left = i >> FANFACTOR;
    let right = i & FAN_MASK;
    (left, right)
}
/// First level of the page table: `FANOUT` slots pointing at second-level
/// nodes, each one installed on first use by `traverse`.
struct Node1<T>
where
    T: Send + 'static,
{
    children: [Atomic<Node2<T>>; FANOUT as usize],
}
/// Second level of the page table: `FANOUT` slots, each holding a (possibly
/// null) pointer to the value stored for one page id.
struct Node2<T>
where
    T: Send + 'static,
{
    children: [Atomic<T>; FANOUT as usize],
}
impl<T: Send + 'static> Node1<T> {
    /// Allocates a zero-filled `Node1` directly on the heap.
    ///
    /// The node holds `FANOUT` (2^18) child pointers, so it is allocated with
    /// `alloc_zeroed` instead of being built on the stack and moved into a
    /// `Box`. This relies on an all-zero `Atomic` being a null pointer: every
    /// child starts out empty, which is what `traverse` and `Drop` expect.
    ///
    /// Aborts via `handle_alloc_error` if the allocation fails. Previously a
    /// null pointer was handed to `Box::from_raw`, which is undefined behavior.
    fn new() -> Box<Self> {
        let size = size_of::<Self>();
        let align = align_of::<Self>();
        // SAFETY: `size`/`align` come from a real sized type, so the layout is
        // valid. The pointer is checked for null before ownership is taken, and
        // zeroed memory is a valid `Node1` (all children null).
        unsafe {
            let layout = Layout::from_size_align_unchecked(size, align);
            // The allocator honours `layout`, which carries `Self`'s alignment.
            #[allow(clippy::cast_ptr_alignment)]
            let ptr = alloc_zeroed(layout) as *mut Self;
            if ptr.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            Box::from_raw(ptr)
        }
    }
}
impl<T: Send + 'static> Node2<T> {
    /// Allocates a zero-filled `Node2` on the heap and returns it as an
    /// epoch-managed `Owned`.
    ///
    /// Like `Node1::new`, this relies on an all-zero `Atomic` being a null
    /// pointer, so every slot starts out empty.
    ///
    /// Aborts via `handle_alloc_error` if the allocation fails. Previously a
    /// null pointer was passed to `Owned::from_raw`, which is undefined
    /// behavior.
    fn new() -> Owned<Self> {
        let size = size_of::<Self>();
        let align = align_of::<Self>();
        // SAFETY: `size`/`align` come from a real sized type, so the layout is
        // valid. The pointer is checked for null before ownership is taken, and
        // zeroed memory is a valid `Node2` (all slots null).
        unsafe {
            let layout = Layout::from_size_align_unchecked(size, align);
            // The allocator honours `layout`, which carries `Self`'s alignment.
            #[allow(clippy::cast_ptr_alignment)]
            let ptr = alloc_zeroed(layout) as *mut Self;
            if ptr.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            Owned::from_raw(ptr)
        }
    }
}
impl<T: Send + 'static> Drop for Node1<T> {
    /// Frees every second-level node that was installed under this node.
    ///
    /// `traverse` installs children at whatever index the page id maps to, so
    /// the array can be sparse (e.g. only slots 0 and 321 populated). Null slots
    /// are therefore skipped rather than treated as the end of the array. The
    /// previous `break` leaked every node after the first gap.
    fn drop(&mut self) {
        for child in self.children.iter() {
            // SAFETY: `&mut self` gives exclusive access to this node, so no
            // other thread can observe these slots. An unprotected load and
            // taking ownership of each installed child is therefore sound.
            unsafe {
                let shared_child = child.load(Relaxed, &unprotected());
                if !shared_child.is_null() {
                    drop(shared_child.into_owned());
                }
            }
        }
    }
}
impl<T: Send + 'static> Drop for Node2<T> {
    /// Frees every value still stored in this node.
    ///
    /// Slots are indexed directly by page id, and `PageTable::del` / `swap`
    /// can null out any of them, so the array may contain holes. Null slots
    /// are skipped instead of stopping at the first one. Stopping there would
    /// leak every value stored after a hole.
    fn drop(&mut self) {
        for child in self.children.iter() {
            // SAFETY: `&mut self` gives exclusive access to this node, so no
            // other thread can observe these slots. An unprotected load and
            // taking ownership of each stored value is therefore sound.
            unsafe {
                let shared_child = child.load(Relaxed, &unprotected());
                if !shared_child.is_null() {
                    drop(shared_child.into_owned());
                }
            }
        }
    }
}
/// Concurrent map from `PageId` to epoch-managed `T` values.
///
/// Stored as a two-level table of atomic pointers indexed by the high and low
/// `FANFACTOR` bits of the page id.
pub struct PageTable<T: 'static + Send + Sync> {
    head: Atomic<Node1<T>>,
}
impl<T> Default for PageTable<T>
where
T: 'static + Send + Sync,
{
fn default() -> Self {
let head = Node1::new();
Self {
head: Atomic::from(head),
}
}
}
impl<T> PageTable<T>
where
    T: 'static + Send + Sync,
{
    /// Unconditionally stores `new` as the entry for `pid` and returns the
    /// previous entry, which may be null.
    ///
    /// The returned pointer is no longer reachable from the table but is NOT
    /// scheduled for destruction here; reclaiming it is up to the caller.
    pub fn swap<'g>(
        &self,
        pid: PageId,
        new: Shared<'g, T>,
        guard: &'g Guard,
    ) -> Shared<'g, T> {
        let tip = traverse(self.head.load(Acquire, guard), pid, guard);
        debug_delay();
        tip.swap(new, SeqCst, guard)
    }

    /// Replaces the entry for `pid` with `new` only if it currently equals
    /// `old`.
    ///
    /// On success, a non-null `old` is handed to `guard.defer_destroy`, and
    /// `Ok(new)` is returned. On failure, returns `Err(current)` with the
    /// value actually stored. The table does not take `new` in that case, so
    /// it remains the caller's responsibility.
    pub fn cas<'g>(
        &self,
        pid: PageId,
        old: Shared<'g, T>,
        new: Shared<'g, T>,
        guard: &'g Guard,
    ) -> std::result::Result<Shared<'g, T>, Shared<'g, T>> {
        debug_delay();
        let tip = traverse(self.head.load(Acquire, guard), pid, guard);
        debug_delay();
        // SeqCst rather than Release: a Release-only CAS performs a Relaxed
        // load on failure. The `current` pointer returned on failure is handed
        // to callers who dereference it, so that load must synchronize with
        // the store that published it.
        let _ = tip
            .compare_and_set(old, new, SeqCst, guard)
            .map_err(|e| e.current)?;
        if !old.is_null() {
            // SAFETY: the successful CAS unlinked `old` from the table. Other
            // pinned threads may still be reading it, so destruction is
            // deferred to the epoch collector rather than done immediately.
            unsafe {
                guard.defer_destroy(old);
            }
        }
        Ok(new)
    }

    /// Returns the entry for `pid`, or `None` if the slot is empty.
    ///
    /// Note that the lookup goes through `traverse`, which allocates the
    /// second-level node for `pid` if it does not exist yet.
    pub fn get<'g>(
        &self,
        pid: PageId,
        guard: &'g Guard,
    ) -> Option<Shared<'g, T>> {
        debug_delay();
        let tip = traverse(self.head.load(Acquire, guard), pid, guard);
        let res = tip.load(Acquire, guard);
        if res.is_null() {
            None
        } else {
            Some(res)
        }
    }

    /// Clears the entry for `pid` and returns the removed value, if any.
    ///
    /// The removed value is passed to `guard.defer_destroy`, so it must not
    /// be used beyond the lifetime of `guard`.
    pub fn del<'g>(
        &self,
        pid: PageId,
        guard: &'g Guard,
    ) -> Option<Shared<'g, T>> {
        debug_delay();
        let old = self.swap(pid, Shared::null(), guard);
        if old.is_null() {
            None
        } else {
            // SAFETY: `swap` just unlinked `old`; destruction is deferred
            // until concurrently pinned readers are done with it.
            unsafe {
                guard.defer_destroy(old);
            }
            Some(old)
        }
    }
}
fn traverse<'g, T: 'static + Send>(
head: Shared<'g, Node1<T>>,
k: PageId,
guard: &'g Guard,
) -> &'g Atomic<T> {
let (l1k, l2k) = split_fanout(k);
debug_delay();
let l1 = unsafe { &head.deref().children };
debug_delay();
let mut l2_ptr = l1[usize::try_from(l1k).unwrap()].load(Acquire, guard);
if l2_ptr.is_null() {
let next_child = Node2::new().into_shared(guard);
debug_delay();
let ret = l1[usize::try_from(l1k).unwrap()]
.compare_and_set(l2_ptr, next_child, Release, guard);
match ret {
Ok(_) => {
l2_ptr = next_child;
}
Err(e) => {
l2_ptr = e.current;
}
}
}
debug_delay();
let l2 = unsafe { &l2_ptr.deref().children };
&l2[usize::try_from(l2k).unwrap()]
}
impl<T> Drop for PageTable<T>
where
    T: 'static + Send + Sync,
{
    /// Reclaims the first-level node. Dropping it runs `Node1`'s `Drop`, which
    /// is responsible for releasing the nodes below it.
    fn drop(&mut self) {
        // SAFETY: `&mut self` guarantees no other thread can reach `head`, so
        // an unprotected load and taking ownership is sound.
        let root = unsafe { self.head.load(Relaxed, &unprotected()).into_owned() };
        drop(root);
    }
}
#[test]
fn test_split_fanout() {
    // (page id, expected (first-level index, second-level index))
    let cases = [
        (0b11_1111_1111_1111_1111, (0, 0b11_1111_1111_1111_1111)),
        (0b111_1111_1111_1111_1111, (0b1, 0b11_1111_1111_1111_1111)),
    ];
    for &(key, expected) in cases.iter() {
        assert_eq!(split_fanout(key), expected);
    }
}
#[test]
fn basic_functionality() {
    let guard = pin();
    let table = PageTable::default();

    // Insert, read back, replace, then delete page 0.
    let five = Owned::new(5).into_shared(&guard);
    table.cas(0, Shared::null(), five, &guard).unwrap();
    let current = table.get(0, &guard).unwrap();
    assert_eq!(unsafe { current.deref() }, &5);
    let six = Owned::new(6).into_shared(&guard);
    table.cas(0, current, six, &guard).unwrap();
    assert_eq!(unsafe { table.get(0, &guard).unwrap().deref() }, &6);
    table.del(0, &guard);
    assert!(table.get(0, &guard).is_none());

    // Two neighbouring ids that share a second-level node far from slot 0.
    let far = 321 << FANFACTOR;
    let next = far + 1;
    let two = Owned::new(2).into_shared(&guard);
    table.cas(far, Shared::null(), two, &guard).unwrap();
    assert_eq!(unsafe { table.get(far, &guard).unwrap().deref() }, &2);
    assert!(table.get(next, &guard).is_none());
    let three = Owned::new(3).into_shared(&guard);
    table.cas(next, Shared::null(), three, &guard).unwrap();
    assert_eq!(unsafe { table.get(next, &guard).unwrap().deref() }, &3);
    assert_eq!(unsafe { table.get(far, &guard).unwrap().deref() }, &2);
}