#![deny(clippy::large_stack_arrays, clippy::large_types_passed_by_value)]
use std::alloc::{alloc_zeroed, Layout};
use std::ptr::null_mut;
use std::sync::{
atomic::{
AtomicBool, AtomicI16, AtomicI32, AtomicI64, AtomicI8, AtomicIsize, AtomicPtr, AtomicU16,
AtomicU32, AtomicU64, AtomicU8, AtomicUsize, Ordering,
},
Arc,
};
/// Number of slots in every level of the tree: one for each value a `u16` key chunk can take.
const FANOUT: usize = u16::MAX as usize + 1;
fn _impl_send_sync() {
fn is_send<T: Send>() {}
fn is_sync<T: Sync>() {}
is_send::<PageTable<AtomicU8>>();
is_sync::<PageTable<AtomicU8>>();
}
/// Marker for types that may be stored in a [`PageTable`].
///
/// Values of implementing types are never built by a constructor: the table
/// zero-fills freshly allocated memory (see `punch_through`) and then hands out
/// references into it. The all-zero byte pattern therefore has to be a valid
/// value of the type. The `Sync` supertrait requires every implementor to be
/// shareable between threads.
///
/// NOTE(review): this is a safe trait even though an incorrect impl would let
/// `punch_through` produce references to invalid values — consider whether it
/// should be an `unsafe trait`.
pub trait Zeroable: Sync {}
// Composite types are `Zeroable` whenever every element is: arrays of any
// length, and tuples of arity two through six.
impl<T, const LEN: usize> Zeroable for [T; LEN] where T: Zeroable {}

impl<A, B> Zeroable for (A, B)
where
    A: Zeroable,
    B: Zeroable,
{
}

impl<A, B, C> Zeroable for (A, B, C)
where
    A: Zeroable,
    B: Zeroable,
    C: Zeroable,
{
}

impl<A, B, C, D> Zeroable for (A, B, C, D)
where
    A: Zeroable,
    B: Zeroable,
    C: Zeroable,
    D: Zeroable,
{
}

impl<A, B, C, D, E> Zeroable for (A, B, C, D, E)
where
    A: Zeroable,
    B: Zeroable,
    C: Zeroable,
    D: Zeroable,
    E: Zeroable,
{
}

impl<A, B, C, D, E, F> Zeroable for (A, B, C, D, E, F)
where
    A: Zeroable,
    B: Zeroable,
    C: Zeroable,
    D: Zeroable,
    E: Zeroable,
    F: Zeroable,
{
}
/// Emits an empty `impl Zeroable for <type> {}` for each type in a
/// comma-separated list (no trailing comma).
macro_rules! impl_zeroable {
    ($($atomic:ty),+) => {
        $(
            impl Zeroable for $atomic {}
        )+
    };
}

// The boolean atomic plus every signed and unsigned integer atomic imported above.
impl_zeroable!(
    AtomicBool,
    AtomicU8,
    AtomicU16,
    AtomicU32,
    AtomicU64,
    AtomicUsize,
    AtomicI8,
    AtomicI16,
    AtomicI32,
    AtomicI64,
    AtomicIsize
);
// `AtomicPtr<T>` is `Zeroable` for any pointee `T`; there is deliberately no bound on `T`.
// NOTE(review): this relies on an all-zero `AtomicPtr` being the null pointer, the same
// assumption `punch_through` makes when it calls `is_null` on zero-filled child slots.
impl<T> Zeroable for AtomicPtr<T> {}
/// A lazily populated table that maps every `u64` key to a `T`.
///
/// Storage for a key is created the first time that key is passed to
/// [`PageTable::get`]. Clones of a table share the same storage through an `Arc`.
pub struct PageTable<T: Zeroable> {
// Shared roots of the tree; every clone points at the same `PageTableInner`.
inner: Arc<PageTableInner<T>>,
}
impl<T: Zeroable> Default for PageTable<T> {
    /// Builds an empty table, exactly as [`PageTable::new`] does.
    fn default() -> Self {
        Self {
            inner: Arc::default(),
        }
    }
}
impl<T: Zeroable> Clone for PageTable<T> {
    /// Returns another handle to the same table; only the `Arc` reference
    /// count changes, no table contents are copied.
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
/// The shared state behind a [`PageTable`]: four independent root pointers,
/// one per tree depth, each null until first needed.
///
/// `PageTable::get` splits a key into four 16-bit chunks (most significant
/// first) and picks the root by how many leading chunks are zero, so small
/// keys are reached through fewer levels.
///
/// NOTE(review): every field is an `AtomicPtr`, which is `Send + Sync` for any
/// pointee, so this struct's auto traits do not depend on `T` — confirm that
/// is intended for a `T` that is not `Send`.
pub struct PageTableInner<T: Zeroable> {
// Root used when the most significant chunk is non-zero (full four-level walk).
l1: AtomicPtr<L1<T>>,
// Root used when only the most significant chunk is zero (three-level walk).
l2: AtomicPtr<L2<T>>,
// Root used when the two most significant chunks are zero (two-level walk).
l3: AtomicPtr<L3<T>>,
// Leaf used directly when the three most significant chunks are zero.
l4: AtomicPtr<L4<T>>,
}
impl<T: Zeroable> Default for PageTableInner<T> {
fn default() -> Self {
PageTableInner {
l1: AtomicPtr::new(null_mut()),
l2: AtomicPtr::new(null_mut()),
l3: AtomicPtr::new(null_mut()),
l4: AtomicPtr::new(null_mut()),
}
}
}
impl<T: Zeroable> Drop for PageTableInner<T> {
    /// Frees whichever of the four roots were installed, deepest tree first
    /// (`l1`, then `l2`, `l3`, `l4`). Dropping a root node in turn runs that
    /// node's own destructor.
    fn drop(&mut self) {
        /// Takes back ownership of the node behind `root`, if there is one, and drops it.
        fn release<Node>(root: &AtomicPtr<Node>) {
            let raw = root.load(Ordering::Acquire);
            if raw.is_null() {
                return;
            }
            // SAFETY: in the visible code a non-null root is only ever stored
            // by `punch_through`, and `&mut self` in `drop` guarantees nobody
            // else can still be using the node, so it is reclaimed exactly once.
            unsafe { drop(Box::from_raw(raw)) }
        }

        release(&self.l1);
        release(&self.l2);
        release(&self.l3);
        release(&self.l4);
    }
}
/// Outermost tree level. `PageTable::get` indexes it with the most significant
/// 16 bits of the key; each slot is either null or points at an `L2`.
struct L1<T> {
children: [AtomicPtr<L2<T>>; FANOUT],
}
/// Second tree level. `PageTable::get` indexes it with the second 16-bit chunk
/// of the key; each slot is either null or points at an `L3`.
struct L2<T> {
children: [AtomicPtr<L3<T>>; FANOUT],
}
/// Third tree level. `PageTable::get` indexes it with the third 16-bit chunk
/// of the key; each slot is either null or points at an `L4` leaf.
struct L3<T> {
children: [AtomicPtr<L4<T>>; FANOUT],
}
/// Leaf level. Holds the stored values themselves, indexed by the least
/// significant 16 bits of the key.
struct L4<T> {
children: [T; FANOUT],
}
/// Returns the child that `atomic_ptr` points at, installing a zero-filled
/// child first if the slot is still null.
///
/// Several threads may call this on the same slot at once. Each thread that
/// sees null prepares its own child and tries to publish it with a single
/// compare-exchange; exactly one wins, the others discard their candidate and
/// use the winner's child, so every caller gets a reference to the same node.
///
/// The new child is never constructed: it is zero-filled memory reinterpreted
/// as `Child`, which is why `Child` must be `Zeroable`.
///
/// Aborts through [`std::alloc::handle_alloc_error`] if the allocation fails.
#[inline]
fn punch_through<Child: Zeroable>(atomic_ptr: &AtomicPtr<Child>) -> &Child {
    let mut ptr = atomic_ptr.load(Ordering::Acquire);
    if ptr.is_null() {
        let layout = Layout::new::<Child>();
        let new_child_ptr: *mut Child = if layout.size() == 0 {
            // Calling `alloc_zeroed` with a zero-sized layout is undefined
            // behaviour. A zero-sized child needs no storage at all; a
            // well-aligned dangling pointer is a valid, non-null handle for
            // it and is also what `Box` uses for zero-sized values.
            std::ptr::NonNull::<Child>::dangling().as_ptr()
        } else {
            // SAFETY: `layout` has non-zero size (checked just above).
            let raw = unsafe { alloc_zeroed(layout) };
            if raw.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            raw.cast::<Child>()
        };
        match atomic_ptr.compare_exchange(
            null_mut(),
            new_child_ptr,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                ptr = new_child_ptr;
            }
            Err(cur_ptr) => {
                // Another thread installed a child first; the exchange can
                // only fail because the slot was no longer null.
                ptr = cur_ptr;
                // SAFETY: `new_child_ptr` was never published, so this thread
                // is its sole owner. It is either a global-allocator block
                // with the layout of `Child` or, for a zero-sized `Child`, an
                // aligned dangling pointer; `Box::from_raw` accepts both.
                unsafe {
                    drop(Box::from_raw(new_child_ptr));
                }
            }
        }
    };
    // SAFETY: `ptr` is non-null on every path: it was loaded non-null, is the
    // child this thread just installed, or is the child that beat it.
    unsafe { &*ptr }
}
/// Follows slot `key` of `parent` and returns the child behind it, creating
/// that child first (via `punch_through`) if the slot is still null.
#[inline]
fn traverse_or_install<Child: Zeroable>(parent: &[AtomicPtr<Child>; FANOUT], key: u16) -> &Child {
    // `key` is the index into the `FANOUT`-slot array of child pointers.
    punch_through(&parent[usize::from(key)])
}
impl<T: Zeroable> PageTable<T> {
    /// Creates an empty table. No tree node exists until a key is first looked up.
    pub fn new() -> Self {
        Self {
            inner: Arc::default(),
        }
    }

    /// Returns the slot for `key`, creating any missing nodes on the way to it.
    ///
    /// The key is split into four 16-bit chunks, most significant first. Keys
    /// whose leading chunks are zero start from a shallower root, so they pass
    /// through fewer levels:
    ///
    /// * three leading zero chunks: the `l4` leaf is used directly;
    /// * two: `l3` root, then the leaf;
    /// * one: `l2` root, then two more levels;
    /// * none: the full walk from the `l1` root.
    ///
    /// The last chunk always selects the value inside the leaf.
    pub fn get(&self, key: u64) -> &T {
        // The truncating casts are intentional: each keeps exactly one 16-bit chunk.
        let k1 = (key >> 48) as u16;
        let k2 = (key >> 32) as u16;
        let k3 = (key >> 16) as u16;
        let k4 = key as u16;

        let roots = &*self.inner;
        let leaf = match (k1, k2, k3) {
            (0, 0, 0) => punch_through(&roots.l4),
            (0, 0, _) => {
                let l3 = punch_through(&roots.l3);
                traverse_or_install(&l3.children, k3)
            }
            (0, _, _) => {
                let l2 = punch_through(&roots.l2);
                let l3 = traverse_or_install(&l2.children, k2);
                traverse_or_install(&l3.children, k3)
            }
            _ => {
                let l1 = punch_through(&roots.l1);
                let l2 = traverse_or_install(&l1.children, k1);
                let l3 = traverse_or_install(&l2.children, k2);
                traverse_or_install(&l3.children, k3)
            }
        };
        &leaf.children[usize::from(k4)]
    }
}
impl<T, I> std::ops::Index<I> for PageTable<T>
where
    T: Zeroable,
    I: Into<u64>,
{
    type Output = T;

    /// `table[key]` converts `key` to a `u64` and forwards to [`PageTable::get`].
    fn index(&self, index: I) -> &Self::Output {
        let key: u64 = index.into();
        self.get(key)
    }
}
// Tree nodes are themselves created zero-filled by `punch_through`, so each
// level has to be `Zeroable` as long as the stored value type is.
impl<T> Zeroable for L1<T> where T: Zeroable {}
impl<T> Zeroable for L2<T> where T: Zeroable {}
impl<T> Zeroable for L3<T> where T: Zeroable {}
impl<T> Zeroable for L4<T> where T: Zeroable {}
/// Generates `impl<T> Drop for $parent` for a node whose `children` field is a
/// collection of `AtomicPtr<$child>`: on drop, every non-null child pointer is
/// turned back into a `Box` and dropped, which in turn runs the child's own
/// destructor.
macro_rules! impl_drop_children {
    ($parent:ty, $child:ty) => {
        impl<T> Drop for $parent {
            fn drop(&mut self) {
                self.children
                    .iter()
                    .map(|slot| -> *mut $child { slot.load(Ordering::Acquire) })
                    .filter(|raw| !raw.is_null())
                    .for_each(|raw| unsafe { drop(Box::from_raw(raw)) });
            }
        }
    };
}
// Each interior level owns its children through raw pointers, so it gets a
// generated `Drop` that frees them. `L4` gets none: its `children` are the
// stored `T` values themselves rather than pointers.
impl_drop_children!(L1<T>, L2<T>);
impl_drop_children!(L2<T>, L3<T>);
impl_drop_children!(L3<T>, L4<T>);
/// Increments every key in `0..N` once, checks each reads back as 1, then does
/// the same for `u64::MAX`, which takes the full four-level path.
#[test]
fn smoke() {
    // Under Miri a single key is used instead of the full range.
    const N: u64 = if cfg!(miri) { 1 } else { 100_000_000 };

    let table = PageTable::<AtomicU64>::default();
    let bump = |key: u64| table.get(key).fetch_add(1, Ordering::Relaxed);
    let read = |key: u64| table.get(key).load(Ordering::Relaxed);

    (0..N).for_each(|key| {
        bump(key);
    });
    (0..N).for_each(|key| assert_eq!(read(key), 1));

    bump(u64::MAX);
    assert_eq!(read(u64::MAX), 1);
}