use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::ptr;
use std::sync::atomic::AtomicPtr;
use std::sync::atomic::Ordering;
/// An owned heap value (`Box<T>`) held behind an [`AtomicPtr`], so that the
/// whole box can be exchanged through a shared `&AtomicBox<T>`.
///
/// The stored pointer always comes from `Box::into_raw`, so it is never null,
/// and it is owned exclusively by this `AtomicBox` until it is swapped out.
pub struct AtomicBox<T> {
    // Raw pointer obtained from `Box::into_raw`; ownership of the allocation lives here.
    ptr: AtomicPtr<T>,
    // Marks logical ownership of a `Box<T>` (for drop check) and makes the
    // auto-derived `Send` depend on `T: Send`.
    phantom: PhantomData<Box<T>>,
}

// SAFETY: through `&AtomicBox<T>` callers can only exchange whole boxes (or
// print the stored address); no `&self` method dereferences the pointer or
// hands out `&T`. Sharing therefore only moves `T` values between threads,
// which is exactly what `T: Send` permits.
unsafe impl<T: Send> Sync for AtomicBox<T> {}

impl<T> AtomicBox<T> {
    /// Takes ownership of `value` and stores it behind the atomic pointer.
    pub fn new(value: Box<T>) -> AtomicBox<T> {
        let raw = Box::into_raw(value);
        Self {
            ptr: AtomicPtr::new(raw),
            phantom: PhantomData,
        }
    }

    /// Atomically installs `other` and returns the box that was stored before.
    pub fn swap(&self, other: Box<T>) -> Box<T> {
        let incoming = Box::into_raw(other);
        // AcqRel: the release half publishes the contents of `incoming` to
        // whichever thread later swaps it out; the acquire half makes the
        // previous owner's writes to `outgoing` visible here.
        let outgoing = self.ptr.swap(incoming, Ordering::AcqRel);
        // SAFETY: `outgoing` was produced by `Box::into_raw` (in `new` or an
        // earlier swap), and the atomic swap hands each stored pointer to
        // exactly one caller, so we are its sole owner now.
        unsafe { Box::from_raw(outgoing) }
    }

    /// Atomically installs `other` and drops the box that was stored before.
    pub fn store(&self, other: Box<T>) {
        let previous = self.swap(other);
        mem::drop(previous);
    }

    /// Atomically exchanges the stored box with `*other`, in place.
    pub fn swap_mut(&self, other: &mut Box<T>) {
        // SAFETY: `*other` is bitwise moved out here and refilled by the
        // `ptr::write` below without being read or dropped in between; the
        // intervening `swap` cannot panic, so the moved-from slot is never
        // observed.
        let taken = unsafe { ptr::read(other) };
        let previous = self.swap(taken);
        // SAFETY: `other` is a valid, aligned `&mut Box<T>` whose old value
        // was already moved out above, so overwriting it leaks nothing.
        unsafe { ptr::write(other, previous) };
    }

    /// Consumes the `AtomicBox` and returns the box it held.
    pub fn into_inner(self) -> Box<T> {
        // Suppress our `Drop`, which would otherwise free the allocation we return.
        let mut this = mem::ManuallyDrop::new(self);
        let raw = *this.ptr.get_mut();
        // SAFETY: `raw` came from `Box::into_raw`, and with `Drop` suppressed,
        // ownership transfers solely to the returned box.
        unsafe { Box::from_raw(raw) }
    }

    /// Mutably borrows the stored value; `&mut self` rules out concurrent swaps.
    pub fn get_mut(&mut self) -> &mut T {
        let raw = *self.ptr.get_mut();
        // SAFETY: `raw` is a live box allocation owned by `self`, and the
        // exclusive borrow of `self` keeps it from being swapped out while
        // the returned reference is alive.
        unsafe { &mut *raw }
    }
}

impl<T> Drop for AtomicBox<T> {
    /// Frees the box that is currently stored.
    fn drop(&mut self) {
        let raw = *self.ptr.get_mut();
        // SAFETY: `raw` came from `Box::into_raw` and is still owned by `self`;
        // `drop` runs at most once, so the allocation is reclaimed exactly once.
        mem::drop(unsafe { Box::from_raw(raw) });
    }
}

impl<T> Default for AtomicBox<T>
where
    Box<T>: Default,
{
    /// Wraps a default-constructed `Box<T>`.
    fn default() -> AtomicBox<T> {
        Self::new(Box::<T>::default())
    }
}

impl<T> fmt::Debug for AtomicBox<T> {
    /// Prints `AtomicBox(<address>)`. The pointee is never dereferenced, so
    /// `T` needs no `Debug` bound; the address is only a relaxed snapshot and
    /// may be stale if another thread swaps concurrently.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        let addr = self.ptr.load(Ordering::Relaxed);
        // `Pointer::fmt` receives the caller's formatter, so flags like `#` are honoured.
        f.write_str("AtomicBox(")
            .and_then(|()| fmt::Pointer::fmt(&addr, f))
            .and_then(|()| f.write_str(")"))
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Barrier};

    use super::*;

    /// Adds its second field to the shared counter when dropped, so tests can
    /// observe exactly which boxes were freed and when.
    struct AddOnDrop(Arc<AtomicUsize>, usize);

    impl Drop for AddOnDrop {
        fn drop(&mut self) {
            self.0.fetch_add(self.1, Ordering::Relaxed);
        }
    }

    #[test]
    fn atomic_box_swap_works() {
        let b = AtomicBox::new(Box::new("hello world"));
        let bis = Box::new("bis");
        assert_eq!(b.swap(bis), Box::new("hello world"));
        assert_eq!(b.swap(Box::new("")), Box::new("bis"));
    }

    #[test]
    fn atomic_box_store_works() {
        let b = AtomicBox::new(Box::new("hello world"));
        let bis = Box::new("bis");
        b.store(bis);
        assert_eq!(b.into_inner(), Box::new("bis"));
    }

    #[test]
    fn atomic_box_swap_mut_works() {
        let b = AtomicBox::new(Box::new("hello world"));
        let mut bis = Box::new("bis");
        b.swap_mut(&mut bis);
        assert_eq!(bis, Box::new("hello world"));
        b.swap_mut(&mut bis);
        assert_eq!(bis, Box::new("bis"));
    }

    #[test]
    fn atomic_box_get_mut_works() {
        let mut b = AtomicBox::new(Box::new(1));
        *b.get_mut() += 41;
        assert_eq!(*b.into_inner(), 42);
    }

    #[test]
    fn atomic_box_default_works() {
        let b: AtomicBox<Vec<u8>> = Default::default();
        assert!(b.into_inner().is_empty());
    }

    #[test]
    fn atomic_box_pointer_identity() {
        let box1 = Box::new(1);
        let p1 = format!("{box1:p}");
        let atom = AtomicBox::new(box1);
        let box2 = Box::new(2);
        let p2 = format!("{box2:p}");
        assert_ne!(p2, p1);
        let box3 = atom.swap(box2);
        let p3 = format!("{box3:p}");
        assert_eq!(p3, p1);
        let box4 = atom.swap(Box::new(5));
        let p4 = format!("{box4:p}");
        assert_eq!(p4, p2);
    }

    #[test]
    fn atomic_box_drops() {
        let n = Arc::new(AtomicUsize::new(0));
        {
            let ab = AtomicBox::new(Box::new(AddOnDrop(n.clone(), 5)));
            assert_eq!(n.load(Ordering::Relaxed), 0);
            let first = ab.swap(Box::new(AddOnDrop(n.clone(), 13)));
            assert_eq!(n.load(Ordering::Relaxed), 0);
            drop(first);
            assert_eq!(n.load(Ordering::Relaxed), 5);
        }
        assert_eq!(n.load(Ordering::Relaxed), 5 + 13);
    }

    #[test]
    fn atomic_box_into_inner_drops_once() {
        let n = Arc::new(AtomicUsize::new(0));
        let ab = AtomicBox::new(Box::new(AddOnDrop(n.clone(), 7)));
        let inner = ab.into_inner();
        // `into_inner` must hand ownership over without freeing anything itself.
        assert_eq!(n.load(Ordering::Relaxed), 0);
        drop(inner);
        assert_eq!(n.load(Ordering::Relaxed), 7);
    }

    #[test]
    fn atomic_threads() {
        const NTHREADS: usize = 9;
        let gate = Arc::new(Barrier::new(NTHREADS));
        let abox: Arc<AtomicBox<Vec<u8>>> = Arc::new(Default::default());
        let handles: Vec<_> = (0..NTHREADS as u8)
            .map(|t| {
                let my_gate = Arc::clone(&gate);
                let my_box = Arc::clone(&abox);
                std::thread::spawn(move || {
                    my_gate.wait();
                    let mut my_vec = Box::new(vec![]);
                    for _ in 0..100 {
                        my_vec = my_box.swap(my_vec);
                        my_vec.push(t);
                    }
                    my_vec
                })
            })
            .collect();
        // Every push lands in exactly one of the circulating vectors, so each
        // thread id must appear exactly 100 times across all of them.
        let mut counts = [0usize; NTHREADS];
        for h in handles {
            for val in *h.join().unwrap() {
                counts[usize::from(val)] += 1;
            }
        }
        for val in *abox.swap(Box::new(vec![])) {
            counts[usize::from(val)] += 1;
        }
        assert_eq!(counts, [100; NTHREADS]);
    }

    #[test]
    fn debug_fmt() {
        let my_box = Box::new(32);
        let expected = format!("AtomicBox({my_box:p})");
        assert_eq!(format!("{:?}", AtomicBox::new(my_box)), expected);
    }
}