use crossbeam_epoch::{self as epoch, Atomic, Guard, Owned, Shared};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Outcome of `VersionedPointer::compare_and_swap`.
///
/// `Ok` carries the pointer crossbeam's `compare_exchange` reports on success; `Err` carries
/// the pointer actually found in the cell together with the rejected node, which is handed
/// back to the caller instead of being dropped.
type CompareExchangeResult<'g, T> = Result<
    Shared<'g, VersionedNode<T>>,
    (Shared<'g, VersionedNode<T>>, Owned<VersionedNode<T>>),
>;
/// Convenience wrapper around crossbeam-epoch's default collector, paired with a standalone
/// atomic counter.
///
/// The counter exposed through `current_epoch` / `advance` lives entirely inside this struct:
/// bumping it does not touch crossbeam's own epoch, and `pin`, `defer` and `flush` delegate
/// straight to crossbeam regardless of the counter's value.
pub struct EpochManager {
    // Starts at zero; only `advance` modifies it.
    global_epoch: Arc<AtomicUsize>,
    _phantom: std::marker::PhantomData<()>,
}

impl EpochManager {
    /// Builds a manager whose counter reads zero.
    pub fn new() -> Self {
        let counter = AtomicUsize::new(0);
        Self {
            _phantom: std::marker::PhantomData,
            global_epoch: Arc::new(counter),
        }
    }

    /// Pins the calling thread with `crossbeam_epoch::pin` and returns the resulting guard.
    pub fn pin(&self) -> Guard {
        epoch::pin()
    }

    /// Increments the counter by one with `Release` ordering (atomic adds wrap on overflow).
    pub fn advance(&self) {
        let _previous = self.global_epoch.fetch_add(1, Ordering::Release);
    }

    /// Reads the counter with `Acquire` ordering.
    pub fn current_epoch(&self) -> usize {
        self.global_epoch.load(Ordering::Acquire)
    }

    /// Hands `f` to `guard.defer`, which runs it at some point after every thread pinned at
    /// the time of the call has unpinned.
    pub fn defer<F: FnOnce() + Send + 'static>(&self, guard: &Guard, f: F) {
        guard.defer(f);
    }

    /// Forwards to `Guard::flush`, pushing the thread-local deferred functions onward.
    pub fn flush(&self, guard: &Guard) {
        guard.flush();
    }
}

impl Default for EpochManager {
    /// Equivalent to `EpochManager::new()`.
    fn default() -> Self {
        EpochManager::new()
    }
}
pub struct VersionedPointer<T> {
ptr: Atomic<VersionedNode<T>>,
}
pub struct VersionedNode<T> {
data: T,
version: usize,
}
impl<T> VersionedPointer<T> {
pub fn new(data: T) -> Self {
let node = VersionedNode { data, version: 0 };
Self {
ptr: Atomic::new(node),
}
}
pub fn load<'g>(&self, guard: &'g Guard) -> Option<&'g T> {
let shared = self.ptr.load(Ordering::Acquire, guard);
unsafe { shared.as_ref().map(|node| &node.data) }
}
pub fn compare_and_swap<'g>(
&self,
current: Shared<'g, VersionedNode<T>>,
new: Owned<VersionedNode<T>>,
guard: &'g Guard,
) -> CompareExchangeResult<'g, T> {
match self
.ptr
.compare_exchange(current, new, Ordering::Release, Ordering::Acquire, guard)
{
Ok(shared) => Ok(shared),
Err(e) => Err((e.current, e.new)),
}
}
pub fn update(&self, data: T, version: usize, guard: &Guard) -> bool {
let current = self.ptr.load(Ordering::Acquire, guard);
if let Some(current_node) = unsafe { current.as_ref() } {
if current_node.version >= version {
return false;
}
}
let new_node = VersionedNode { data, version };
let new = Owned::new(new_node);
match self.compare_and_swap(current, new, guard) {
Ok(_) => {
if !current.is_null() {
unsafe {
guard.defer_destroy(current);
}
}
true
}
Err((_, returned)) => {
drop(returned);
false
}
}
}
}
pub struct HazardPointer<T> {
inner: Atomic<T>,
}
impl<T> HazardPointer<T> {
pub fn new(data: T) -> Self {
Self {
inner: Atomic::new(data),
}
}
pub fn load<'g>(&self, guard: &'g Guard) -> Shared<'g, T> {
self.inner.load(Ordering::Acquire, guard)
}
pub fn store(&self, new: Owned<T>, guard: &Guard) {
let old = self.inner.swap(new, Ordering::Release, guard);
if !old.is_null() {
unsafe {
guard.defer_destroy(old);
}
}
}
pub fn compare_and_swap<'g>(
&self,
current: Shared<'g, T>,
new: Owned<T>,
guard: &'g Guard,
) -> Result<Shared<'g, T>, (Shared<'g, T>, Owned<T>)> {
match self
.inner
.compare_exchange(current, new, Ordering::Release, Ordering::Acquire, guard)
{
Ok(shared) => Ok(shared),
Err(e) => Err((e.current, e.new)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_epoch_manager() {
        let manager = EpochManager::new();
        let initial_epoch = manager.current_epoch();
        manager.advance();
        assert_eq!(manager.current_epoch(), initial_epoch + 1);
        let guard = manager.pin();
        // Smoke-test the delegating helpers.
        manager.defer(&guard, || {});
        manager.flush(&guard);
        drop(guard);
    }

    #[test]
    fn test_versioned_pointer() {
        let ptr = VersionedPointer::new(42);
        let guard = epoch::pin();
        assert_eq!(ptr.load(&guard), Some(&42));
        assert!(ptr.update(100, 1, &guard));
        assert_eq!(ptr.load(&guard), Some(&100));
        let result = ptr.update(50, 0, &guard);
        assert!(!result, "Update with outdated version should fail");
        assert!(
            !ptr.update(7, 1, &guard),
            "Update with an equal version should fail"
        );
        assert_eq!(ptr.load(&guard), Some(&100));
    }

    #[test]
    fn test_concurrent_updates() {
        let ptr = Arc::new(VersionedPointer::new(0));
        let num_threads = 4;
        let updates_per_thread = 1000;
        let handles: Vec<_> = (0..num_threads)
            .map(|i| {
                let ptr = Arc::clone(&ptr);
                thread::spawn(move || {
                    for j in 0..updates_per_thread {
                        // Re-pin per update so retired nodes can actually be reclaimed.
                        let guard = epoch::pin();
                        let version = i * updates_per_thread + j;
                        ptr.update(version as i32, version, &guard);
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("thread should not panic");
        }
        let guard = epoch::pin();
        let final_value = ptr.load(&guard).expect("load should succeed");
        // Every stored value is one of the versions written above (or the initial 0).
        let total = i32::try_from(num_threads * updates_per_thread)
            .expect("version count fits in i32");
        assert!(
            (0..total).contains(final_value),
            "final value {} was never written",
            final_value
        );
    }

    #[test]
    fn test_hazard_pointer() {
        let hp = HazardPointer::new("initial");
        let guard = epoch::pin();
        hp.store(Owned::new("updated"), &guard);
        let loaded = hp.load(&guard);
        // SAFETY: `loaded` was obtained under `guard`, which is still pinned.
        unsafe {
            assert_eq!(
                loaded.as_ref().expect("operation should succeed"),
                &"updated"
            );
        }
    }

    #[test]
    fn test_hazard_pointer_compare_and_swap() {
        let hp = HazardPointer::new(1);
        let guard = epoch::pin();
        let original = hp.load(&guard);
        assert!(hp.compare_and_swap(original, Owned::new(2), &guard).is_ok());
        // `original` is no longer current, so this CAS must fail and hand back both the
        // actual value and the rejected allocation.
        match hp.compare_and_swap(original, Owned::new(3), &guard) {
            Ok(_) => panic!("CAS against a stale pointer should fail"),
            Err((actual, rejected)) => {
                // SAFETY: `actual` was obtained under `guard`, which is still pinned.
                assert_eq!(unsafe { actual.as_ref() }, Some(&2));
                assert_eq!(*rejected, 3);
            }
        }
        // SAFETY: the first CAS unlinked `original`; `compare_and_swap` does not retire it,
        // so do it here exactly once to avoid a leak.
        unsafe { guard.defer_destroy(original) };
    }
}