use std::sync::Arc;
use arc_swap::ArcSwapOption;
use super::node::OverlayNode;
use crate::persistent_artrie_core::key_encoding::KeyEncoding;
/// Value reported by [`AtomicNodePtr::as_raw`] when the pointer holds no node.
const NULL_PTR: u64 = 0;
/// A nullable, atomically replaceable shared pointer to an [`OverlayNode`].
///
/// Backed by [`ArcSwapOption`]; every operation takes `&self`, so one
/// `AtomicNodePtr` can be loaded from and updated through a shared reference.
/// Nodes are handed out as `Arc`s, so a loaded node stays alive even if the
/// pointer is later redirected elsewhere.
pub struct AtomicNodePtr<K: KeyEncoding, V = ()> {
// Current node; `None` is the null state.
ptr: ArcSwapOption<OverlayNode<K, V>>,
}
impl<K: KeyEncoding, V> std::fmt::Debug for AtomicNodePtr<K, V> {
    /// Formats the pointer as `AtomicNodePtr { is_null: .. }`.
    ///
    /// Only the null/non-null state is reported; the node itself is not
    /// formatted, so this impl places no `Debug` bound on `K` or `V`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let null_now = self.ptr.load().is_none();
        let mut builder = f.debug_struct("AtomicNodePtr");
        builder.field("is_null", &null_now);
        builder.finish()
    }
}
impl<K: KeyEncoding, V: Clone> AtomicNodePtr<K, V> {
    /// Creates a pointer that initially holds `node`.
    pub fn new(node: Arc<OverlayNode<K, V>>) -> Self {
        Self {
            ptr: ArcSwapOption::new(Some(node)),
        }
    }

    /// Creates a pointer that initially holds no node (null).
    pub fn null() -> Self {
        Self {
            ptr: ArcSwapOption::empty(),
        }
    }

    /// Returns `true` if the pointer currently holds no node.
    ///
    /// This is a snapshot: another thread may store or take a node right
    /// after the check.
    #[inline]
    pub fn is_null(&self) -> bool {
        self.ptr.load().is_none()
    }

    /// Returns an owned handle to the current node, or `None` if null.
    pub fn load(&self) -> Option<Arc<OverlayNode<K, V>>> {
        self.ptr.load_full()
    }

    /// Returns an owned handle to the current node.
    ///
    /// Despite the name, the null case *is* checked: this is a convenience for
    /// call sites that already know the pointer is non-null.
    ///
    /// # Panics
    ///
    /// Panics if the pointer is null.
    #[inline]
    pub fn load_unchecked(&self) -> Arc<OverlayNode<K, V>> {
        self.load()
            .expect("AtomicNodePtr::load_unchecked called on null pointer")
    }

    /// Unconditionally replaces the current value with `node`.
    pub fn store(&self, node: Arc<OverlayNode<K, V>>) {
        self.ptr.store(Some(node));
    }

    /// Atomically removes and returns the current node, leaving the pointer
    /// null. Returns `None` if it was already null.
    pub fn take(&self) -> Option<Arc<OverlayNode<K, V>>> {
        self.ptr.swap(None)
    }

    /// Atomically replaces the current node with `new` if the current node is
    /// `expected`, compared by pointer identity ([`Arc::ptr_eq`]), not by value.
    ///
    /// The caller keeps `expected` alive for the whole call, so its allocation
    /// cannot be freed and reused; a pointer match therefore really means
    /// "still the same node".
    ///
    /// # Returns
    ///
    /// * `Ok(previous)`: the exchange happened; `previous` is `expected`.
    /// * `Err(current)`: the exchange did not happen and `new` was not stored.
    ///
    /// # Null current value
    ///
    /// If the pointer is null, the exchange fails. The `Err(Arc<_>)` return
    /// type cannot express "null", so a freshly allocated, empty
    /// `OverlayNode` is returned instead. That node is never stored, so
    /// retrying with it as `expected` will always fail. Callers that may
    /// observe a null pointer should use [`Self::compare_exchange_option`].
    pub fn compare_exchange(
        &self,
        expected: &Arc<OverlayNode<K, V>>,
        new: Arc<OverlayNode<K, V>>,
    ) -> Result<Arc<OverlayNode<K, V>>, Arc<OverlayNode<K, V>>> {
        let prev = self.ptr.compare_and_swap(expected, Some(new));
        match &*prev {
            Some(p) if Arc::ptr_eq(p, expected) => Ok(Arc::clone(p)),
            Some(p) => Err(Arc::clone(p)),
            // Kept for backward compatibility; see "Null current value" above.
            None => Err(Arc::new(OverlayNode::new())),
        }
    }

    /// Null-aware compare-and-exchange.
    ///
    /// Atomically replaces the current value with `new` if the current value
    /// is `expected`. Either side may be null (`None`). Non-null values are
    /// compared by pointer identity ([`Arc::ptr_eq`]).
    ///
    /// # Returns
    ///
    /// * `Ok(previous)`: the exchange happened; `previous` is the old value
    ///   (identical to `expected`).
    /// * `Err(current)`: the exchange did not happen and `new` was not stored.
    ///   `current` is the value actually observed, which is `None` when the
    ///   pointer was null.
    ///
    /// Unlike [`Self::compare_exchange`], the `Err` value is always the real
    /// current value, so it can be fed straight back in as `expected` when
    /// retrying.
    pub fn compare_exchange_option(
        &self,
        expected: Option<&Arc<OverlayNode<K, V>>>,
        new: Option<Arc<OverlayNode<K, V>>>,
    ) -> Result<Option<Arc<OverlayNode<K, V>>>, Option<Arc<OverlayNode<K, V>>>> {
        // An owned `Option<Arc>` lets `compare_and_swap` compare against both
        // the null and non-null cases; cloning an `Arc` only bumps a refcount.
        let expected_owned: Option<Arc<OverlayNode<K, V>>> = expected.cloned();
        let prev = self.ptr.compare_and_swap(&expected_owned, new);
        let current: Option<Arc<OverlayNode<K, V>>> = (*prev).clone();
        let swapped = match (current.as_ref(), expected) {
            (None, None) => true,
            (Some(cur), Some(exp)) => Arc::ptr_eq(cur, exp),
            _ => false,
        };
        if swapped {
            Ok(current)
        } else {
            Err(current)
        }
    }

    /// Same as [`Self::compare_exchange`].
    ///
    /// It never fails spuriously; the name exists for parity with the
    /// standard atomics' `compare_exchange_weak`.
    pub fn compare_exchange_weak(
        &self,
        expected: &Arc<OverlayNode<K, V>>,
        new: Arc<OverlayNode<K, V>>,
    ) -> Result<Arc<OverlayNode<K, V>>, Arc<OverlayNode<K, V>>> {
        self.compare_exchange(expected, new)
    }

    /// Stores `new` only if the pointer is currently null.
    ///
    /// Returns `Ok(())` if `new` was installed. Otherwise returns
    /// `Err(current)` with the node that was already present, and `new` is
    /// not stored.
    pub fn try_init(&self, new: Arc<OverlayNode<K, V>>) -> Result<(), Arc<OverlayNode<K, V>>> {
        let prev = self
            .ptr
            .compare_and_swap(&None::<Arc<OverlayNode<K, V>>>, Some(new));
        match &*prev {
            None => Ok(()),
            Some(p) => Err(Arc::clone(p)),
        }
    }

    /// Returns the address of the current node as a `u64`, or [`NULL_PTR`]
    /// when the pointer is null.
    ///
    /// The value is only good for identity comparison or diagnostics. It does
    /// not keep the node alive and must not be turned back into a reference.
    #[inline]
    pub fn as_raw(&self) -> u64 {
        self.ptr
            .load()
            .as_ref()
            .map(|node| Arc::as_ptr(node) as u64)
            .unwrap_or(NULL_PTR)
    }
}
impl<K: KeyEncoding, V: Clone> Clone for AtomicNodePtr<K, V> {
    /// Creates a new, independent pointer that starts out holding the same
    /// node as `self` (shared via `Arc`, not deep-copied), or null if `self`
    /// is currently null.
    ///
    /// The two pointers are separate cells: a later store through one is not
    /// visible through the other.
    fn clone(&self) -> Self {
        self.load().map_or_else(Self::null, Self::new)
    }
}
impl<K: KeyEncoding, V: Clone> Default for AtomicNodePtr<K, V> {
fn default() -> Self {
Self::null()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistent_artrie_core::key_encoding::{ByteKey, CharKey};
    use crate::persistent_artrie_core::overlay::node::Child;
    use crate::persistent_artrie_core::swizzled_ptr::{NodeType, SwizzledPtr};
    use std::thread;

    type ByteNode = OverlayNode<ByteKey, ()>;
    type ByteAtomicNodePtr = AtomicNodePtr<ByteKey, ()>;
    type CharNode = OverlayNode<CharKey, ()>;
    type CharAtomicNodePtr = AtomicNodePtr<CharKey, ()>;

    #[test]
    fn test_new_and_load_byte() {
        let node = Arc::new(ByteNode::new());
        let ptr = ByteAtomicNodePtr::new(node);
        let loaded = ptr.load().expect("should load");
        assert_eq!(loaded.num_children(), 0);
    }

    #[test]
    fn test_null_pointer_char() {
        let ptr = CharAtomicNodePtr::null();
        assert!(ptr.is_null());
        assert!(ptr.load().is_none());
    }

    #[test]
    fn test_store_byte() {
        let node1 = Arc::new(ByteNode::new());
        let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
        let node2 = Arc::new(node1.with_child(b'a', child));
        let ptr = ByteAtomicNodePtr::new(node1);
        assert_eq!(ptr.load().expect("should load").num_children(), 0);
        ptr.store(node2);
        assert_eq!(ptr.load().expect("should load").num_children(), 1);
    }

    #[test]
    fn test_take_char() {
        let node = Arc::new(CharNode::new());
        let ptr = CharAtomicNodePtr::new(node);
        assert!(!ptr.is_null());
        let taken = ptr.take();
        assert!(taken.is_some());
        assert!(ptr.is_null());
        assert!(ptr.take().is_none());
    }

    #[test]
    fn test_compare_exchange_success_byte() {
        let node1 = Arc::new(ByteNode::new());
        let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::Node4));
        let node2 = Arc::new(node1.with_child(b'a', child));
        let ptr = ByteAtomicNodePtr::new(node1.clone());
        assert!(ptr.compare_exchange(&node1, node2).is_ok());
        assert_eq!(ptr.load().expect("should load").num_children(), 1);
    }

    #[test]
    fn test_compare_exchange_failure_char() {
        let node1 = Arc::new(CharNode::new());
        let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::CharNode4));
        let node2 = Arc::new(node1.with_child('a' as u32, child));
        let node3 = Arc::new(CharNode::new());
        let ptr = CharAtomicNodePtr::new(node1.clone());
        assert!(ptr.compare_exchange(&node1, node2).is_ok());
        let result = ptr.compare_exchange(&node1, node3);
        assert!(result.is_err());
        assert_eq!(ptr.load().expect("should load").num_children(), 1);
    }

    /// A CAS against a null pointer must fail and must not install `new`.
    #[test]
    fn test_compare_exchange_on_null_char() {
        let ptr = CharAtomicNodePtr::null();
        let expected = Arc::new(CharNode::new());
        let result = ptr.compare_exchange(&expected, Arc::new(CharNode::new()));
        assert!(result.is_err());
        assert!(ptr.is_null());
    }

    #[test]
    fn test_try_init_byte() {
        let ptr = ByteAtomicNodePtr::null();
        let node = Arc::new(ByteNode::new());
        assert!(ptr.try_init(node).is_ok());
        assert!(!ptr.is_null());
        let other = Arc::new(ByteNode::new());
        assert!(ptr.try_init(other).is_err());
    }

    #[test]
    fn test_clone_char() {
        let child = Child::OnDisk(SwizzledPtr::on_disk(1, 100, NodeType::CharNode4));
        let node = Arc::new(CharNode::new().with_child('a' as u32, child));
        let ptr1 = CharAtomicNodePtr::new(node);
        let ptr2 = ptr1.clone();
        assert_eq!(ptr1.load().expect("load").num_children(), 1);
        assert_eq!(ptr2.load().expect("load").num_children(), 1);
    }

    #[test]
    fn test_load_unchecked_byte() {
        let node = Arc::new(ByteNode::new());
        let ptr = ByteAtomicNodePtr::new(node);
        assert_eq!(ptr.load_unchecked().num_children(), 0);
    }

    #[test]
    #[should_panic(expected = "null pointer")]
    fn test_load_unchecked_panics_on_null_char() {
        let ptr = CharAtomicNodePtr::null();
        let _loaded = ptr.load_unchecked();
    }

    /// `as_raw` reports `NULL_PTR` when empty and the node's address otherwise.
    #[test]
    fn test_as_raw_tracks_identity_byte() {
        let ptr = ByteAtomicNodePtr::null();
        assert_eq!(ptr.as_raw(), NULL_PTR);
        let node = Arc::new(ByteNode::new());
        ptr.store(Arc::clone(&node));
        assert_eq!(ptr.as_raw(), Arc::as_ptr(&node) as u64);
    }

    /// CAS contract: a mismatched `expected` fails, a matching one succeeds,
    /// and a stale `expected` after a winning CAS fails and reports the winner.
    fn check_cas_contract<K: KeyEncoding>() {
        let n1 = Arc::new(OverlayNode::<K, ()>::new());
        let n2 = Arc::new(n1.as_final());
        let n3 = Arc::new(OverlayNode::<K, ()>::new());
        let ptr = AtomicNodePtr::<K, ()>::new(Arc::clone(&n1));
        assert!(ptr.compare_exchange(&n3, Arc::clone(&n2)).is_err());
        assert!(ptr.compare_exchange(&n1, Arc::clone(&n2)).is_ok());
        let actual = ptr
            .compare_exchange(&n1, Arc::clone(&n3))
            .expect_err("stale expected after a winning CAS must fail");
        assert!(Arc::ptr_eq(&actual, &n2));
    }

    /// Several threads race load-modify-CAS cycles on one pointer; at least
    /// one CAS must win and the final node must reflect some insertion.
    fn check_concurrent_cas<K: KeyEncoding>()
    where
        K::Unit: TryFrom<u32>,
        <K::Unit as TryFrom<u32>>::Error: std::fmt::Debug,
    {
        const THREADS: u32 = 8;
        const ATTEMPTS: u32 = 64;
        let ptr = Arc::new(AtomicNodePtr::<K, ()>::new(Arc::new(
            OverlayNode::<K, ()>::new(),
        )));
        let handles: Vec<_> = (0..THREADS)
            .map(|t| {
                let ptr = Arc::clone(&ptr);
                thread::spawn(move || {
                    let mut wins = 0usize;
                    for i in 0..ATTEMPTS {
                        let id = t * ATTEMPTS + i;
                        let cur = ptr
                            .load()
                            .unwrap_or_else(|| Arc::new(OverlayNode::<K, ()>::new()));
                        // Keep keys below 250 so they fit in a byte-sized unit.
                        let key = K::Unit::try_from(id % 250).expect("unit fits");
                        let child = Child::OnDisk(SwizzledPtr::on_disk(id, 0, NodeType::Node4));
                        let next = Arc::new(cur.with_child(key, child));
                        if ptr.compare_exchange(&cur, next).is_ok() {
                            wins += 1;
                        }
                    }
                    wins
                })
            })
            .collect();
        let total: usize = handles
            .into_iter()
            .map(|h| h.join().expect("thread join"))
            .sum();
        assert!(total > 0, "at least one CAS must win");
        assert!(ptr.load().expect("final load").num_children() > 0);
    }

    /// Churns every reference-moving operation and then checks that the
    /// long-lived nodes are back to a single owner. A leaked `Arc` anywhere in
    /// `AtomicNodePtr` would leave a strong count above 1.
    fn check_no_leak_churn<K: KeyEncoding>() {
        let a = Arc::new(OverlayNode::<K, ()>::new());
        let b = Arc::new(OverlayNode::<K, ()>::new());
        for _ in 0..500 {
            let ptr = AtomicNodePtr::<K, ()>::new(Arc::clone(&a));
            drop(ptr.load());
            ptr.store(Arc::clone(&b));
            assert!(ptr.compare_exchange(&b, Arc::clone(&a)).is_ok());
            assert!(ptr.try_init(Arc::clone(&b)).is_err());
            drop(ptr.clone());
            assert!(ptr.take().is_some());
            ptr.store(Arc::clone(&b));
            drop(ptr);
        }
        assert_eq!(Arc::strong_count(&a), 1, "node `a` leaked a reference");
        assert_eq!(Arc::strong_count(&b), 1, "node `b` leaked a reference");
    }

    #[test]
    fn generic_cas_contract_byte() {
        check_cas_contract::<ByteKey>();
    }

    #[test]
    fn generic_cas_contract_char() {
        check_cas_contract::<CharKey>();
    }

    #[test]
    fn generic_concurrent_cas_byte() {
        check_concurrent_cas::<ByteKey>();
    }

    #[test]
    fn generic_concurrent_cas_char() {
        check_concurrent_cas::<CharKey>();
    }

    #[test]
    fn generic_no_leak_churn_byte() {
        check_no_leak_churn::<ByteKey>();
    }

    #[test]
    fn generic_no_leak_churn_char() {
        check_no_leak_churn::<CharKey>();
    }
}