use std::sync::atomic::{self, AtomicPtr};
use std::marker::PhantomData;
use std::ptr;
use {Guard, add_garbage_box};
/// A lock-free LIFO stack (Treiber stack) holding values of type `T`.
///
/// All mutation happens through compare-and-swap on `head`, so the stack can
/// be shared between threads by reference.
pub struct Treiber<T> {
    /// Pointer to the top node; null when the stack is empty.
    head: AtomicPtr<Node<T>>,
    /// Records that the stack logically owns the `T` values it stores.
    _marker: PhantomData<T>,
}
impl<T> Treiber<T> {
    /// Creates an empty stack (null `head`).
    pub fn new() -> Treiber<T> {
        Treiber {
            head: AtomicPtr::default(),
            _marker: PhantomData,
        }
    }

    /// Removes the top element and returns a guard giving access to its item,
    /// or `None` if the stack is observed to be empty.
    ///
    /// The unlinked node is handed to `add_garbage_box`, which is assumed to
    /// defer freeing it until no guard still protects it (TODO confirm against
    /// the crate's reclamation docs).
    pub fn pop(&self) -> Option<Guard<T>> {
        // Acquire: the loaded node is dereferenced (its `next` and `item` are
        // read), so this load must synchronize with the Release CAS in `push`
        // that published it.
        let mut snapshot = Guard::maybe_new(|| unsafe {
            self.head.load(atomic::Ordering::Acquire).as_ref()
        });
        while let Some(old) = snapshot {
            // Try to swing `head` from `old` to its successor. Either way the
            // CAS yields the value of `head` it observed: `old` itself on
            // success, the current top on failure.
            //
            // The failure ordering must be Acquire: on failure the observed
            // node becomes the next `old`, whose `next`/`item` we dereference.
            // (The former `compare_and_swap(.., Release)` loaded it Relaxed,
            // which made those reads a data race.)
            snapshot = Guard::maybe_new(|| unsafe {
                self.head
                    .compare_exchange(
                        old.as_ptr() as *mut _,
                        old.next,
                        atomic::Ordering::AcqRel,
                        atomic::Ordering::Acquire,
                    )
                    .unwrap_or_else(|current| current)
                    .as_ref()
            });
            if let Some(ref new) = snapshot {
                // The CAS succeeded exactly when it observed `old`.
                if new.as_ptr() == old.as_ptr() {
                    unsafe { add_garbage_box(old.as_ptr()); }
                    return Some(old.map(|x| &x.item));
                }
            } else {
                // `head` was null: the stack became empty while we retried.
                break;
            }
        }
        None
    }

    /// Pushes `item` onto the top of the stack.
    pub fn push(&self, item: T)
    where
        T: 'static,
    {
        // Relaxed is sufficient here: the observed head is only used as a
        // pointer value (stored into `next`), never dereferenced by `push`.
        let mut snapshot = Guard::maybe_new(|| unsafe {
            self.head.load(atomic::Ordering::Relaxed).as_ref()
        });
        let node = Box::into_raw(Box::new(Node {
            item,
            next: ptr::null_mut(),
        }));
        loop {
            let next = snapshot.map_or(ptr::null_mut(), |x| x.as_ptr() as *mut _);
            // The new node is not yet reachable by other threads, so this
            // write is exclusive.
            unsafe { (*node).next = next; }
            // Release publishes the node's `item` and `next` to any thread
            // that later loads `head` with Acquire.
            let observed = Guard::maybe_new(|| unsafe {
                self.head
                    .compare_exchange(
                        next,
                        node,
                        atomic::Ordering::Release,
                        atomic::Ordering::Relaxed,
                    )
                    .unwrap_or_else(|current| current)
                    .as_ref()
            });
            match observed {
                // The CAS observed the expected top, so it installed `node`.
                Some(ref new) if new.as_ptr() == next => break,
                None if next.is_null() => break,
                // Another thread changed `head`; retry on top of what we saw.
                new => snapshot = new,
            }
        }
    }
}
impl<T> Drop for Treiber<T> {
    /// Frees every node still linked from `head`, dropping the items in them.
    ///
    /// `&mut self` means no other thread can reach the stack, so the head is
    /// read through `get_mut` without atomic synchronization.
    fn drop(&mut self) {
        let top = *self.head.get_mut();
        if top.is_null() {
            return;
        }
        // SAFETY: every node linked from `head` was produced by
        // `Box::into_raw` in `push`, and we have exclusive access. `destroy`
        // frees the nodes below `top`; `top` itself is reclaimed here.
        unsafe {
            (*top).destroy();
            drop(Box::from_raw(top));
        }
    }
}
/// One heap-allocated cell of the stack: a stored value plus a raw link to
/// the node beneath it.
struct Node<T> {
    /// The value held in this cell.
    item: T,
    /// The next node toward the bottom of the stack, or null at the bottom.
    next: *mut Node<T>,
}
impl<T> Node<T> {
unsafe fn destroy(&mut self) {
if !self.next.is_null() {
(*self.next).destroy();
drop(Box::from_raw(self.next as *mut Node<T>));
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;

    /// Test payload that increments a shared counter every time a value is
    /// dropped, so tests can count how many items have been destroyed.
    #[derive(Clone)]
    struct Dropper {
        d: Arc<AtomicUsize>,
    }

    impl Drop for Dropper {
        fn drop(&mut self) {
            self.d.fetch_add(1, atomic::Ordering::Relaxed);
        }
    }

    /// Repeatedly creates and drops an empty stack, then checks that an
    /// unrelated heap value `b` is still intact (presumably a guard against
    /// `Drop` touching memory it does not own).
    #[test]
    fn empty() {
        for _ in 0..1000 {
            let b = Box::new(20);
            Treiber::<u8>::new();
            assert_eq!(*b, 20);
        }
    }

    /// Drops a stack that still holds items, exercising `Drop`'s node-freeing path.
    #[test]
    fn just_push() {
        let stack = Treiber::new();
        stack.push(1);
        stack.push(2);
        stack.push(3);
        drop(stack);
    }

    /// Single-threaded LIFO order; popping an exhausted stack yields `None`.
    #[test]
    fn simple1() {
        let stack = Treiber::new();
        stack.push(1);
        stack.push(200);
        stack.push(44);
        assert_eq!(*stack.pop().unwrap(), 44);
        assert_eq!(*stack.pop().unwrap(), 200);
        assert_eq!(*stack.pop().unwrap(), 1);
        assert!(stack.pop().is_none());
        // Crate-root `gc`; presumably reclaims nodes queued by `pop` via
        // `add_garbage_box` — confirm against the crate root.
        ::gc();
    }

    /// Interleaved pushes and pops repeated 16 times; repeated pops on an
    /// empty stack keep returning `None`.
    #[test]
    fn simple2() {
        let stack = Treiber::new();
        for _ in 0..16 {
            stack.push(1);
            stack.push(200);
            stack.push(44);
            assert_eq!(*stack.pop().unwrap(), 44);
            assert_eq!(*stack.pop().unwrap(), 200);
            stack.push(20000);
            assert_eq!(*stack.pop().unwrap(), 20000);
            assert_eq!(*stack.pop().unwrap(), 1);
            assert!(stack.pop().is_none());
            assert!(stack.pop().is_none());
            assert!(stack.pop().is_none());
            assert!(stack.pop().is_none());
        }
        ::gc();
    }

    /// Fills and drains the stack twice with 10000 items, checking strict
    /// LIFO order each time.
    #[test]
    fn simple3() {
        let stack = Treiber::new();
        for i in 0..10000 {
            stack.push(i);
        }
        for i in (0..10000).rev() {
            assert_eq!(*stack.pop().unwrap(), i);
        }
        for i in 0..10000 {
            stack.push(i);
        }
        for i in (0..10000).rev() {
            assert_eq!(*stack.pop().unwrap(), i);
        }
        assert!(stack.pop().is_none());
        assert!(stack.pop().is_none());
        assert!(stack.pop().is_none());
        assert!(stack.pop().is_none());
    }

    /// 16 threads each push then pop 1,000,000 times. Every thread has pushed
    /// more than it has popped whenever it pops, so each pop must find an item.
    #[test]
    fn push_pop() {
        let stack = Arc::new(Treiber::new());
        let mut j = Vec::new();
        for _ in 0..16 {
            let s = stack.clone();
            j.push(thread::spawn(move || {
                for _ in 0..1_000_000 {
                    s.push(23);
                    assert_eq!(*s.pop().unwrap(), 23);
                }
            }));
        }
        for i in j {
            i.join().unwrap();
        }
    }

    /// A single counter value circulates between 16 threads. Each thread
    /// pops it and pushes back `value + n` for n in 0..1001, so the final
    /// value is 16 * (0 + 1 + ... + 1000) = 16 * 1000 * 1001 / 2.
    #[test]
    fn increment() {
        let stack = Arc::new(Treiber::<u64>::new());
        stack.push(0);
        let mut j = Vec::new();
        for _ in 0..16 {
            let s = stack.clone();
            j.push(thread::spawn(move || {
                for n in 0..1001 {
                    // Spin until this thread wins the single value.
                    loop {
                        if let Some(x) = s.pop() {
                            s.push(*x + n);
                            break;
                        }
                    }
                }
            }));
        }
        for i in j {
            i.join().unwrap();
        }
        assert_eq!(*stack.pop().unwrap(), 16 * 1000 * 1001 / 2);
    }

    /// Starts with 1000 items of 10 (total 10000). Threads repeatedly pop
    /// two items `a`, `b` and push back `a + 1` and `b - 1`, which preserves
    /// the total, so draining the stack must sum to 10000.
    #[test]
    fn sum() {
        let stack = Arc::new(Treiber::<i64>::new());
        let mut j = Vec::new();
        for _ in 0..1000 {
            stack.push(10);
        }
        for _ in 0..16 {
            let s = stack.clone();
            j.push(thread::spawn(move || {
                for _ in 0..100000 {
                    loop {
                        if let Some(a) = s.pop() {
                            loop {
                                if let Some(b) = s.pop() {
                                    s.push(*a + 1);
                                    s.push(*b - 1);
                                    break;
                                }
                            }
                            break;
                        }
                    }
                }
            }));
        }
        for i in j {
            i.join().unwrap();
        }
        let mut sum = 0;
        while let Some(x) = stack.pop() {
            sum += *x;
        }
        assert_eq!(sum, 10000);
    }

    /// Counts item destructions. Each of 16 threads pushes 20 clones and pops
    /// 2, and its own captured clone `d` is dropped when the thread ends.
    /// After `gc`, that gives 32 popped items + 16 thread clones. Dropping
    /// the stack then frees the remaining 16 * 20 - 32 items, for a total of
    /// 20 * 16 + 16. The outer `d` is still alive and is not counted.
    #[test]
    fn drop1() {
        let drops = Arc::new(AtomicUsize::default());
        let stack = Arc::new(Treiber::new());
        let d = Dropper {
            d: drops.clone(),
        };
        let mut j = Vec::new();
        for _ in 0..16 {
            let d = d.clone();
            let stack = stack.clone();
            j.push(thread::spawn(move || {
                for _ in 0..20 {
                    stack.push(d.clone());
                }
                stack.pop();
                stack.pop();
            }))
        }
        for i in j {
            i.join().unwrap();
        }
        // Presumably runs the deferred destructors of the popped items.
        ::gc();
        assert_eq!(drops.load(atomic::Ordering::Relaxed), 32 + 16);
        drop(stack);
        ::gc();
        assert_eq!(drops.load(atomic::Ordering::Relaxed), 20 * 16 + 16);
    }

    /// Dropping the stack runs `A`'s panicking destructor on the remaining
    /// items; the panic must propagate out of the test.
    #[test]
    #[should_panic]
    fn panic_in_dtor() {
        struct A;
        impl Drop for A {
            fn drop(&mut self) {
                panic!();
            }
        }
        let stack = Treiber::new();
        stack.push(Box::new(A));
        stack.push(Box::new(A));
        stack.push(Box::new(A));
    }
}