use crate::sync::arcunique::ArcUnique;
use anyhow::Result;
use std::convert::TryFrom;
use std::ops::Deref;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::Arc;
/// A shared slot holding an [`Arc<T>`] that can be read and replaced
/// concurrently through `&self`.
///
/// The slot stores the raw pointer produced by `Arc::into_raw`. Accessors
/// temporarily swap it for null while they work on the value; other
/// accessors wait until a pointer is put back.
#[derive(Debug)]
pub struct AtomicBox<T> {
    /// Raw `Arc` pointer owning one strong reference, or null while an
    /// accessor has taken the value out of the slot.
    ptr: AtomicPtr<T>,
}
impl<T: Sized> AtomicBox<T> {
    /// Creates a new box holding `value`.
    ///
    /// The value is moved into a freshly allocated [`Arc`]; the box owns one
    /// strong reference to it.
    pub fn new(value: T) -> AtomicBox<T> {
        AtomicBox {
            ptr: AtomicPtr::new(Self::alloc_from(value)),
        }
    }

    /// Moves `value` into a new `Arc` and converts it to a raw pointer that
    /// carries one strong reference.
    #[inline]
    fn alloc_from(value: T) -> *mut T {
        Arc::into_raw(Arc::new(value)) as *mut T
    }

    /// Returns the strongest failure ordering allowed for a
    /// `compare_exchange` whose success ordering is `order`.
    ///
    /// A failure ordering may not contain a release component, so `Release`
    /// is weakened to `Relaxed` and `AcqRel` to `Acquire`.
    fn strongest_failure_ordering(order: Ordering) -> Ordering {
        use Ordering::*;
        match order {
            Release | Relaxed => Relaxed,
            Acquire | AcqRel => Acquire,
            SeqCst => SeqCst,
            // `Ordering` is `#[non_exhaustive]`: panic loudly rather than
            // invoking undefined behaviour if a new variant ever appears.
            _ => unreachable!("unsupported atomic ordering: {:?}", order),
        }
    }

    /// Single compare-and-exchange on the slot.
    ///
    /// Returns the pointer that was in the slot before the operation; the
    /// exchange succeeded iff the returned pointer equals `current`.
    fn compare_and_swap(&self, current: *mut T, new: *mut T, order: Ordering) -> *mut T {
        self.ptr
            .compare_exchange(current, new, order, Self::strongest_failure_ordering(order))
            .unwrap_or_else(|actual| actual)
    }

    /// Takes the stored `Arc` out of the slot, leaving the slot null.
    ///
    /// While the slot is null, every other `take` spins, so callers must hand
    /// a pointer back with `release` promptly.
    fn take(&self) -> Arc<T> {
        loop {
            let curr = self.ptr.load(Ordering::Acquire);
            if curr.is_null() {
                // Another accessor currently holds the value; tell the CPU we
                // are busy-waiting instead of hammering the cache line.
                std::hint::spin_loop();
                continue;
            }
            if self.compare_and_swap(curr, std::ptr::null_mut(), Ordering::AcqRel) == curr {
                // SAFETY: every non-null pointer stored in the slot comes from
                // `Arc::into_raw` and carries one strong reference; the
                // successful exchange moved that reference out of the slot and
                // into our hands exclusively.
                return unsafe { Arc::from_raw(curr) };
            }
        }
    }

    /// Puts `ptr` back into the slot, publishing it with `Release` ordering.
    ///
    /// `ptr` must come from `Arc::into_raw` and carry one strong reference.
    fn release(&self, ptr: *mut T) {
        self.ptr.store(ptr, Ordering::Release);
    }

    /// Returns a new strong reference to the current value.
    ///
    /// The value is briefly taken out of the slot (concurrent accessors spin
    /// meanwhile) so the reference count can be bumped safely, then put back.
    pub fn get(&self) -> Arc<T> {
        let val = self.take();
        let copy = Arc::clone(&val);
        self.release(Arc::into_raw(val) as *mut T);
        copy
    }

    /// Returns a raw pointer to the current value that owns one extra strong
    /// reference.
    ///
    /// The caller must eventually reclaim it with `Arc::from_raw`; otherwise
    /// the value is leaked.
    pub fn extract_mut_ptr(&mut self) -> *mut T {
        Arc::into_raw(self.get()) as *mut T
    }

    /// Passes a fresh handle to the current value through
    /// `ArcUnique::try_from` and returns a clone of the wrapped `Arc`.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `ArcUnique::try_from`.
    /// NOTE(review): `get()` always leaves a second strong reference inside
    /// the box, so if `ArcUnique` requires a uniquely owned `Arc` this may
    /// always fail — confirm against `ArcUnique`.
    pub fn extract(&self) -> Result<Arc<T>> {
        let au: ArcUnique<Arc<T>> = ArcUnique::try_from(self.get())?;
        Ok(au.deref().clone())
    }

    /// Replaces the stored value with `f(current)`.
    ///
    /// The current value is taken out of the slot for the whole duration of
    /// `f`, so concurrent `get`/`replace_with` calls spin until `f` returns.
    /// Updates from different threads are therefore serialized and none is
    /// lost.
    ///
    /// # Panics
    ///
    /// If `f` panics, the panic propagates and the slot is left null.
    /// Accessors on other threads then spin forever.
    /// NOTE(review): consider restoring the old value on unwind.
    pub fn replace_with<F>(&self, f: F)
    where
        F: Fn(Arc<T>) -> T,
    {
        let val = self.take();
        let new_val = f(val);
        self.release(Self::alloc_from(new_val));
    }
}
impl<T: Sized + PartialEq> PartialEq for AtomicBox<T> {
    /// Compares the values currently stored in the two boxes.
    ///
    /// Each side is snapshotted with `get`, one after the other. Under
    /// concurrent writes, the two values may therefore come from slightly
    /// different instants.
    fn eq(&self, other: &AtomicBox<T>) -> bool {
        // Compare the pointees, not the boxes: `self == other` would call this
        // method again and recurse until the stack overflows. Each `get`
        // returns the value to its slot before the next one starts, so
        // comparing a box with itself cannot self-deadlock.
        *self.get() == *other.get()
    }
}
impl<T: Sized> Drop for AtomicBox<T> {
fn drop(&mut self) {
unsafe { Arc::from_raw(self.ptr.load(Ordering::Acquire)) };
}
}
// SAFETY: an `AtomicBox<T>` shares ownership of its value exactly like an
// `Arc<T>`: `get` hands out `Arc<T>` clones through `&self`, and
// `replace_with` can drop the last reference to the old `T` on whichever
// thread calls it. Sending or sharing the box is therefore only sound when
// `T` is both `Send` and `Sync`, the same bounds `Arc<T>` uses.
unsafe impl<T: Sized + Send + Sync> Sync for AtomicBox<T> {}
unsafe impl<T: Sized + Send + Sync> Send for AtomicBox<T> {}
#[cfg(test)]
mod tests {
    use super::AtomicBox;
    use std::sync::{Arc, Barrier};
    use std::thread;

    #[test]
    #[cfg_attr(miri, ignore)]
    fn atomic_arc_new() {
        let b = AtomicBox::new(1024);
        assert_eq!(*b.get(), 1024);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn atomic_arc_replace_with() {
        let value: i64 = 1024;
        let b = AtomicBox::new(value);
        b.replace_with(|x| *x * 2);
        assert_eq!(*b.get(), value * 2);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn atomic_arc_replace_with_ten_times() {
        let value = 1024;
        let b = AtomicBox::new(value);
        for _ in 0..10 {
            b.replace_with(|x| *x * 2);
        }
        assert_eq!(*b.get(), value * 2_i32.pow(10));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn atomic_arc_replace_instance() {
        // Two handles to the same box must observe each other's writes.
        let b = Arc::new(AtomicBox::new(1024));
        let b1 = Arc::clone(&b);
        b1.replace_with(|x| *x * 2);
        assert_eq!(*b.get(), 2048);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn atomic_arc_threaded_leak_test() {
        // Ten threads each double the value once; no update may be lost.
        let val = Arc::new(AtomicBox::new(10));
        let guards: Vec<_> = (0..10)
            .map(|_| {
                let val_cpy = Arc::clone(&val);
                thread::spawn(move || val_cpy.replace_with(|x| *x * 2))
            })
            .collect();
        for g in guards {
            g.join().unwrap();
        }
        assert_eq!(*val.get(), 10 * 2_i32.pow(10));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn atomic_arc_threaded_contention() {
        let abox = Arc::new(AtomicBox::new(0));
        let thread_num = 10;
        // The barrier releases all threads at once to maximise contention.
        let barrier = Arc::new(Barrier::new(thread_num));
        let guards: Vec<_> = (0..thread_num)
            .map(|_| {
                let b = Arc::clone(&barrier);
                let cpy = Arc::clone(&abox);
                thread::spawn(move || {
                    b.wait();
                    for _ in 0..1000 {
                        cpy.replace_with(|x| *x + 100);
                    }
                })
            })
            .collect();
        for g in guards {
            g.join().unwrap();
        }
        assert_eq!(*abox.get(), thread_num * 1000 * 100);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn atomic_arc_vector_container() {
        let values: Vec<i32> = (0..10).map(|x: i32| x.pow(2)).collect();
        let abox = Arc::new(AtomicBox::new(vec![]));
        // Each thread appends exactly one element; only that element is moved
        // into the thread instead of a clone of the whole `values` vector.
        let guards: Vec<_> = values
            .iter()
            .copied()
            .map(|value| {
                let cpy = Arc::clone(&abox);
                thread::spawn(move || {
                    cpy.replace_with(|x| {
                        let mut nx = (*x).clone();
                        nx.push(value);
                        nx
                    })
                })
            })
            .collect();
        for g in guards {
            g.join().unwrap();
        }
        let stored = abox.get();
        assert_eq!(stored.len(), values.len());
        for v in &values {
            assert!(stored.contains(v));
        }
    }
}