use std::{mem, ptr};
use std::sync::atomic::{self, AtomicPtr};
use local;
use garbage::Garbage;
use guard::Guard;
/// An atomic slot that is either empty or holds an owned, heap-allocated `T`.
///
/// Internally the slot is an `AtomicPtr<T>`: null means "empty", and any other
/// value points to a box owned by the slot. Values displaced by `store`, `swap`,
/// `compare_and_store` and `compare_and_swap` are handed to `local::add_garbage`
/// rather than being freed immediately.
pub struct AtomicOption<T> {
    /// Null when empty; otherwise the raw pointer of the box owned by the slot.
    inner: AtomicPtr<T>,
}

/// Creates an empty slot.
///
/// Written by hand instead of derived: `#[derive(Default)]` would require
/// `T: Default`, although an empty slot never needs to construct a `T`.
impl<T> Default for AtomicOption<T> {
    fn default() -> AtomicOption<T> {
        AtomicOption {
            inner: AtomicPtr::new(ptr::null_mut()),
        }
    }
}
impl<T> AtomicOption<T> {
pub fn new(init: Option<Box<T>>) -> AtomicOption<T> {
AtomicOption {
inner: AtomicPtr::new(init.map_or(ptr::null_mut(), Box::into_raw)),
}
}
pub fn load(&self, ordering: atomic::Ordering) -> Option<Guard<T>> {
Guard::maybe_new(|| unsafe {
self.inner.load(ordering).as_ref()
})
}
pub fn store(&self, new: Option<Box<T>>, ordering: atomic::Ordering) {
let new = new.map_or(ptr::null_mut(), |new| Box::into_raw(new));
let ptr = self.inner.swap(new, ordering);
if !ptr.is_null() {
local::add_garbage(unsafe { Garbage::new_box(ptr) });
}
}
pub fn swap(&self, new: Option<Box<T>>, ordering: atomic::Ordering) -> Option<Guard<T>> {
let new_ptr = new.map_or(ptr::null_mut(), Box::into_raw);
Guard::maybe_new(|| unsafe {
self.inner.swap(new_ptr, ordering).as_ref()
}).map(|guard| {
local::add_garbage(unsafe { Garbage::new_box(&*guard) });
guard
})
}
pub fn compare_and_store(&self, old: Option<*const T>, mut new: Option<Box<T>>, ordering: atomic::Ordering)
-> Result<(), Option<Box<T>>> {
let new_ptr = new.as_mut().map_or(ptr::null_mut(), |x| &mut **x);
let old_ptr = old.map_or(ptr::null_mut(), |x| x as *mut T);
let ptr = self.inner.compare_and_swap(old_ptr, new_ptr, ordering);
if ptr == old_ptr {
mem::forget(new);
if !old_ptr.is_null() {
local::add_garbage(unsafe { Garbage::new_box(old_ptr) });
}
Ok(())
} else {
Err(new)
}
}
pub fn compare_and_swap(&self, old: Option<*const T>, mut new: Option<Box<T>>, ordering: atomic::Ordering)
-> Result<Option<Guard<T>>, (Option<Guard<T>>, Option<Box<T>>)> {
let new_ptr = new.as_mut().map_or(ptr::null_mut(), |x| &mut **x);
let old_ptr = old.map_or(ptr::null_mut(), |x| x as *mut T);
let guard = Guard::maybe_new(|| {
unsafe { self.inner.compare_and_swap(old_ptr, new_ptr, ordering).as_ref() }
});
let guard_ptr = guard.as_ref().map_or(ptr::null_mut(), |x| &**x as *const T as *mut T);
if guard_ptr == old_ptr {
mem::forget(new);
if !old_ptr.is_null() {
local::add_garbage(unsafe { Garbage::new_box(old_ptr) });
}
Ok(guard)
} else {
Err((guard, new))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{atomic, Arc};
    use std::sync::atomic::AtomicUsize;
    use std::thread;

    /// Runs `basic()` from its destructor, so the checks also execute while a
    /// thread's thread-local storage is being torn down (see `tls`).
    struct Basic;

    impl Drop for Basic {
        fn drop(&mut self) {
            basic();
        }
    }

    thread_local! {
        static BASIC: Basic = Basic;
    }

    /// Single-threaded sanity checks for `load` and `swap`.
    fn basic() {
        let opt = AtomicOption::default();
        assert!(opt.load(atomic::Ordering::Relaxed).is_none());
        assert!(opt.swap(None, atomic::Ordering::Relaxed).is_none());
        assert!(opt.load(atomic::Ordering::Relaxed).is_none());
        assert!(opt.swap(Some(Box::new(42)), atomic::Ordering::Relaxed).is_none());
        assert_eq!(*opt.load(atomic::Ordering::Relaxed).unwrap(), 42);
        assert_eq!(*opt.swap(Some(Box::new(43)), atomic::Ordering::Relaxed).unwrap(), 42);
        assert_eq!(*opt.load(atomic::Ordering::Relaxed).unwrap(), 43);
    }

    #[test]
    fn basic_properties() {
        basic()
    }

    /// Exercises the success and failure paths of `compare_and_swap` and
    /// `compare_and_store`.
    #[test]
    fn cas() {
        let bx1 = Box::new(1);
        let ptr1 = &*bx1 as *const usize;
        let bx2 = Box::new(1);
        let ptr2 = &*bx2 as *const usize;
        let opt = AtomicOption::new(Some(bx1));
        assert_eq!(ptr1, &*opt.compare_and_swap(Some(ptr2), None, atomic::Ordering::Relaxed).unwrap_err().0.unwrap());
        assert_eq!(ptr1, &*opt.load(atomic::Ordering::Relaxed).unwrap());
        assert_eq!(ptr1, &*opt.compare_and_swap(None, Some(Box::new(2)), atomic::Ordering::Relaxed).unwrap_err().0.unwrap());
        assert_eq!(ptr1, &*opt.load(atomic::Ordering::Relaxed).unwrap());
        opt.compare_and_swap(Some(ptr1), None, atomic::Ordering::Relaxed).unwrap();
        assert!(opt.load(atomic::Ordering::Relaxed).is_none());
        opt.compare_and_swap(None, Some(bx2), atomic::Ordering::Relaxed).unwrap();
        assert_eq!(ptr2, &*opt.load(atomic::Ordering::Relaxed).unwrap());
        opt.compare_and_store(Some(ptr2), None, atomic::Ordering::Relaxed).unwrap();
        // A pointer that is certainly not in the (now empty) slot. The box stays
        // owned here so it is freed at the end of the test instead of leaking.
        let stray = Box::new(2);
        opt.compare_and_store(Some(&*stray as *const usize), None, atomic::Ordering::Relaxed).unwrap_err();
        assert!(opt.load(atomic::Ordering::Relaxed).is_none());
        for _ in 0..4 {
            ::gc();
        }
    }

    /// Many threads load and store concurrently while `::gc()` runs.
    #[test]
    fn spam() {
        let opt = Arc::new(AtomicOption::default());
        let mut handles = Vec::new();
        for _ in 0..16 {
            let opt = opt.clone();
            handles.push(thread::spawn(move || {
                for i in 0..1_000_001 {
                    let _ = opt.load(atomic::Ordering::Relaxed);
                    opt.store(Some(Box::new(i)), atomic::Ordering::Relaxed);
                }
                opt
            }))
        }
        ::gc();
        ::gc();
        for handle in handles {
            handle.join().unwrap();
        }
        // Every thread's last store is 1_000_000, so that is the final value
        // regardless of how the threads interleaved.
        assert_eq!(*opt.load(atomic::Ordering::Relaxed).unwrap(), 1_000_000);
    }

    /// Every stored value must eventually be dropped exactly once.
    #[test]
    fn drop() {
        #[derive(Clone)]
        struct Dropper {
            d: Arc<AtomicUsize>,
        }

        impl Drop for Dropper {
            fn drop(&mut self) {
                self.d.fetch_add(1, atomic::Ordering::Relaxed);
            }
        }

        let drops = Arc::new(AtomicUsize::default());
        let opt = Arc::new(AtomicOption::new(None));
        let d = Dropper {
            d: drops.clone(),
        };
        let mut handles = Vec::new();
        for _ in 0..16 {
            let d = d.clone();
            let opt = opt.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..1_000_000 {
                    opt.store(Some(Box::new(d.clone())), atomic::Ordering::Relaxed);
                }
            }))
        }
        for handle in handles {
            handle.join().unwrap();
        }
        opt.store(None, atomic::Ordering::Relaxed);
        ::gc();
        // 16 threads x 1_000_000 stored clones, plus each thread's own clone of
        // `d` dropped at thread exit; the original `d` is still alive here.
        // NOTE(review): relies on `::gc()` reclaiming all queued garbage.
        assert_eq!(drops.load(atomic::Ordering::Relaxed), 16_000_000 + 16);
    }

    /// Runs `basic()` inside TLS destructors on several short-lived threads.
    #[test]
    fn tls() {
        for _ in 0..4 {
            // Touching `BASIC` initialises it, so `Basic::drop` runs at thread exit.
            thread::spawn(|| BASIC.with(|_| {})).join().unwrap();
        }
    }
}