use std::cell::UnsafeCell;
use std::fmt;
use std::mem;
use std::ptr;
use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
use std::sync::atomic::{AtomicIsize, fence};
use std::sync::Arc;
use std::marker::PhantomData;
use std::cell::Cell;
use epoch::{self, Atomic, Shared, Owned};
/// Base-2 logarithm of the initial buffer size: `1 << 7 == 128` slots.
const MIN_BITS: u32 = 7;
/// Shrink divisor: `maybe_shrink` halves the buffer once the live element
/// count drops below `buffer size / K` (while still above `1 << MIN_BITS`).
const K: isize = 4;
/// Shared state of a Chase-Lev style work-stealing deque.
///
/// The live elements occupy the logical index range `[top, bottom)` of the
/// ring buffer in `array`. The owning `Worker` pushes and pops at `bottom`;
/// `Stealer`s take elements from `top`.
#[derive(Debug)]
struct Deque<T> {
    // One past the most recently pushed element; only stored to by the
    // owner-side `push` / `try_pop`.
    bottom: AtomicIsize,
    // Oldest live element; advanced by CAS from both stealers and the owner.
    top: AtomicIsize,
    // Current ring buffer. Replaced on resize, with the old buffer handed to
    // the epoch collector via `guard.unlinked`.
    array: Atomic<Buffer<T>>,
}
// Elements are moved out to whichever thread steals or pops them, hence the
// `T: Send` bound; all cross-thread access to the shared indices goes through
// atomics.
unsafe impl<T: Send> Send for Deque<T> {}
unsafe impl<T: Send> Sync for Deque<T> {}
/// The owner's handle to a deque: pushes and pops at the bottom end.
///
/// `deque()` creates exactly one `Worker` per deque. The `Cell<()>` marker
/// makes `Worker` `!Sync`, so it cannot be shared between threads by
/// reference, although it can still be moved to another thread.
#[derive(Debug)]
pub struct Worker<T> {
    deque: Arc<Deque<T>>,
    // Opts out of `Sync`; the owner-only `Deque::push` / `Deque::try_pop`
    // rely on there being no concurrent callers.
    marker: PhantomData<Cell<()>>,
}
/// A handle that steals elements from the top end of a deque.
///
/// Cloning only bumps the `Arc`; any number of `Stealer`s may share one deque.
#[derive(Debug)]
pub struct Stealer<T> {
    deque: Arc<Deque<T>>,
}
/// Outcome of a `Stealer::steal` attempt.
#[derive(PartialEq, Eq, Debug)]
pub enum Steal<T> {
    /// The deque had no elements when observed.
    Empty,
    /// Another thread claimed the targeted element first (the CAS on `top`
    /// failed); retrying may succeed.
    Abort,
    /// The stolen element.
    Data(T),
}
/// Fixed-size ring buffer backing the deque.
///
/// Elements live in the spare capacity of `storage`; its length stays 0, so
/// dropping a `Buffer` frees the allocation without dropping any element.
struct Buffer<T> {
    /// Backing allocation; slots are addressed through `as_mut_ptr`.
    storage: UnsafeCell<Vec<T>>,
    /// Base-2 logarithm of the slot count requested at construction.
    log_size: u32,
}

impl<T> fmt::Debug for Buffer<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Contents are never printed: slots may be uninitialized or already
        // moved out, and `T` is not required to implement `Debug`.
        f.write_str("Buffer { ... }")
    }
}
impl<T> Worker<T> {
    /// Pushes `t` onto the bottom of the deque (the end this worker pops from).
    pub fn push(&self, t: T) {
        let inner = &self.deque;
        // SAFETY: `Deque::push` is owner-only. `deque()` hands out a single
        // `Worker` per deque, no `Clone` impl for `Worker` is visible in this
        // file, and the `PhantomData<Cell<()>>` marker makes `Worker` `!Sync`,
        // so no other thread can be in `push`/`try_pop` at the same time.
        unsafe { inner.push(t) }
    }

    /// Pops the most recently pushed element. Returns `None` when the deque is
    /// empty or a stealer won the race for the last element.
    pub fn try_pop(&self) -> Option<T> {
        let inner = &self.deque;
        // SAFETY: same reasoning as in `push`; this handle is the only owner.
        unsafe { inner.try_pop() }
    }
}
impl<T> Stealer<T> {
pub fn steal(&self) -> Steal<T> {
self.deque.steal()
}
}
impl<T> Clone for Stealer<T> {
fn clone(&self) -> Stealer<T> {
Stealer { deque: self.deque.clone() }
}
}
/// Creates a new, empty work-stealing deque.
///
/// Returns the unique owning `Worker` (push / pop at the bottom) and an
/// initial `Stealer` (steal from the top), which may be cloned freely.
pub fn deque<T>() -> (Worker<T>, Stealer<T>) {
    let shared = Arc::new(Deque::new());
    let stealer_handle = Arc::clone(&shared);
    let worker = Worker { deque: shared, marker: PhantomData };
    let stealer = Stealer { deque: stealer_handle };
    (worker, stealer)
}
impl<T> Deque<T> {
    /// Creates an empty deque whose initial buffer has `1 << MIN_BITS` slots.
    fn new() -> Deque<T> {
        let array = Atomic::null();
        array.store(Some(Owned::new(Buffer::new(MIN_BITS))), SeqCst);
        Deque {
            bottom: AtomicIsize::new(0),
            top: AtomicIsize::new(0),
            array: array,
        }
    }

    /// Pushes `data` at the bottom (owner) end.
    ///
    /// The buffer is doubled first when it has at most one free slot.
    ///
    /// # Safety
    ///
    /// Must only be called by the deque's single owner (the `Worker`):
    /// `bottom` and `array` are updated without synchronizing against any
    /// other writer.
    unsafe fn push(&self, data: T) {
        let guard = epoch::pin();
        // Only the owner ever stores to `bottom`, so this relaxed load is
        // current, and it stays valid across a resize (which does not touch
        // `bottom`).
        let b = self.bottom.load(Relaxed);
        let t = self.top.load(Acquire);
        let mut a = self.array.load(Relaxed, &guard).unwrap();
        let size = b - t;
        if size >= (a.size() as isize) - 1 {
            a = self.swap_buffer(a, a.resize(b, t, 1), &guard);
        }
        a.put(b, data);
        // Make the element write visible before stealers can observe the new
        // `bottom`.
        fence(Release);
        self.bottom.store(b + 1, Relaxed);
    }

    /// Pops the most recently pushed element from the bottom (owner) end.
    ///
    /// Returns `None` when the deque is empty, or when a stealer won the race
    /// for the last remaining element.
    ///
    /// # Safety
    ///
    /// Same contract as `push`: only the deque's single owner may call this.
    unsafe fn try_pop(&self) -> Option<T> {
        let guard = epoch::pin();
        let b = self.bottom.load(Relaxed);
        let a = self.array.load(Relaxed, &guard).unwrap();
        // Reserve slot `b - 1` before reading `top`. The SeqCst fence orders
        // this store against the stealers' fence-then-load of `bottom`.
        self.bottom.store(b - 1, Relaxed);
        fence(SeqCst);
        let t = self.top.load(Relaxed);
        let size = b - t;
        if size <= 0 {
            // Nothing to pop: undo the reservation.
            self.bottom.store(b, Relaxed);
            None
        } else if size >= 2 {
            // At least one other element separates slot `b - 1` from the
            // stealers, so it is ours without a CAS.
            let data = a.get(b - 1);
            self.maybe_shrink(b - 1, t, &guard);
            Some(data)
        } else {
            // Exactly one element, at index `t == b - 1`. Race the stealers
            // for it by advancing `top`, then restore `bottom` to `b`
            // (== `t + 1`).
            let won = self.top.compare_exchange(t, t + 1, SeqCst, SeqCst).is_ok();
            self.bottom.store(b, Relaxed);
            if won {
                Some(a.get(t))
            } else {
                None
            }
        }
    }

    /// Attempts to take the element at `top`, the end opposite to the owner.
    ///
    /// This is a safe function and may be called from many threads at once.
    fn steal(&self) -> Steal<T> {
        let guard = epoch::pin();
        let t = self.top.load(Relaxed);
        // Order the `top` load before the `bottom` load; pairs with the SeqCst
        // fence in `try_pop`.
        fence(SeqCst);
        let b = self.bottom.load(Acquire);
        let size = b - t;
        if size <= 0 {
            return Steal::Empty
        }
        unsafe {
            let a = self.array.load(Acquire, &guard).unwrap();
            // Speculatively copy the element out. Ownership only becomes ours
            // if the CAS on `top` succeeds.
            let data = a.get(t);
            if self.top.compare_exchange(t, t + 1, SeqCst, SeqCst).is_ok() {
                Steal::Data(data)
            } else {
                // Another thread owns this element; forget our bitwise copy so
                // it is not dropped twice.
                mem::forget(data);
                Steal::Abort
            }
        }
    }

    /// Halves the buffer when the `b - t` live elements fill less than `1 / K`
    /// of it, provided more than `1 << MIN_BITS` elements remain.
    ///
    /// # Safety
    ///
    /// Owner-only, like `try_pop`, because it may replace `array`.
    unsafe fn maybe_shrink(&self, b: isize, t: isize, guard: &epoch::Guard) {
        let a = self.array.load(SeqCst, guard).unwrap();
        let size = b - t;
        if size < (a.size() as isize) / K && size > (1 << MIN_BITS) {
            self.swap_buffer(a, a.resize(b, t, -1), guard);
        }
    }

    /// Installs `buf` as the current buffer and retires `old` via
    /// `guard.unlinked`.
    ///
    /// Returns a reference to the new buffer that is valid for the guard's
    /// lifetime.
    ///
    /// # Safety
    ///
    /// Owner-only. `old` is expected to be the buffer currently stored in
    /// `array`, since it is handed to the epoch collector for reclamation.
    unsafe fn swap_buffer<'a>(&self,
                              old: Shared<'a, Buffer<T>>,
                              buf: Buffer<T>,
                              guard: &'a epoch::Guard)
                              -> Shared<'a, Buffer<T>> {
        let newbuf = Owned::new(buf);
        // Release: stealers that load the new pointer with Acquire also see
        // the elements that `resize` copied into it.
        let newbuf = self.array.store_and_ref(newbuf, Release, guard);
        guard.unlinked(old);
        newbuf
    }
}
impl<T> Drop for Deque<T> {
    /// Drops every element still in `[top, bottom)`, then hands the buffer to
    /// the epoch collector.
    fn drop(&mut self) {
        let guard = epoch::pin();
        let (front, back) = (self.top.load(Relaxed), self.bottom.load(Relaxed));
        let buffer = self.array.swap(None, Relaxed, &guard).unwrap();
        unsafe {
            // Each slot in [front, back) holds an element that nobody took.
            // Reading it out transfers ownership so it is dropped exactly once;
            // the buffer's own `Vec` has length 0 and drops nothing.
            (front..back).for_each(|i| drop(buffer.get(i)));
            guard.unlinked(buffer);
        }
    }
}
impl<T> Buffer<T> {
    /// Allocates an empty ring buffer with room for `1 << log_size` elements.
    fn new(log_size: u32) -> Buffer<T> {
        Buffer {
            storage: UnsafeCell::new(Vec::with_capacity(1 << log_size)),
            log_size: log_size,
        }
    }

    /// Number of slots in the ring: `1 << log_size`.
    ///
    /// Derived from `log_size` rather than `Vec::capacity()`. Capacity is only
    /// guaranteed to be *at least* the requested amount, and for zero-sized
    /// `T` it is `usize::MAX`. That used to make `push` resize on every call
    /// until `1 << log_size` overflowed.
    fn size(&self) -> usize {
        1 << self.log_size
    }

    /// Mask mapping a logical index to a slot: `size() - 1`, valid because
    /// `size()` is a power of two.
    fn mask(&self) -> isize {
        (self.size() - 1) as isize
    }

    /// Raw pointer to the slot for logical index `i`, wrapped modulo `size()`.
    ///
    /// # Safety
    ///
    /// `i` must be non-negative. The slot may be uninitialized.
    unsafe fn elem(&self, i: isize) -> *mut T {
        (*self.storage.get()).as_mut_ptr().offset(i & self.mask())
    }

    /// Moves the value out of slot `i` with a bitwise read.
    ///
    /// # Safety
    ///
    /// Slot `i` must be initialized. The caller takes ownership, so a given
    /// value must end up used (and dropped) by exactly one reader.
    unsafe fn get(&self, i: isize) -> T {
        ptr::read(self.elem(i))
    }

    /// Writes `t` into slot `i`.
    ///
    /// # Safety
    ///
    /// Any value previously in the slot is overwritten without being dropped.
    unsafe fn put(&self, i: isize, t: T) {
        ptr::write(self.elem(i), t);
    }

    /// Builds a buffer of `2^(log_size + delta)` slots and bitwise-moves the
    /// elements at logical indices `t..b` into it, keeping their indices.
    ///
    /// # Safety
    ///
    /// Slots `t..b` must be initialized. The values are moved, not cloned, so
    /// callers must ensure each one is ultimately consumed from only one
    /// buffer.
    unsafe fn resize(&self, b: isize, t: isize, delta: i32) -> Buffer<T> {
        let buf = Buffer::new(((self.log_size as i32) + delta) as u32);
        for i in t..b {
            buf.put(i, self.get(i));
        }
        buf
    }
}
// Tests for the work-stealing deque: single-threaded smoke checks plus
// multi-threaded push / steal / pop races.
#[cfg(test)]
mod tests {
    // NOTE(review): relies on the old `rand` API (`thread_rng`, two-argument
    // `gen_range`); presumably pinned in dev-dependencies — confirm.
    extern crate rand;
    use super::{deque, Worker, Stealer, Steal};
    use std::thread;
    use std::sync::Arc;
    // NOTE(review): `ATOMIC_*_INIT` are deprecated on newer toolchains in
    // favour of `AtomicBool::new(false)` / `AtomicUsize::new(0)` in statics.
    use std::sync::atomic::{AtomicBool, ATOMIC_BOOL_INIT,
                            AtomicUsize, ATOMIC_USIZE_INIT};
    use std::sync::atomic::Ordering::SeqCst;
    use self::rand::Rng;

    // Single-threaded sanity check of push / try_pop / steal, including a
    // steal through a cloned `Stealer`.
    #[test]
    fn smoke() {
        let (w, s) = deque();
        assert_eq!(w.try_pop(), None);
        assert_eq!(s.steal(), Steal::Empty);
        w.push(1);
        assert_eq!(w.try_pop(), Some(1));
        w.push(1);
        assert_eq!(s.steal(), Steal::Data(1));
        w.push(1);
        assert_eq!(s.clone().steal(), Steal::Data(1));
    }

    // One stealer thread must receive all AMT values that the worker pushes
    // concurrently; the worker never pops.
    #[test]
    fn stealpush() {
        static AMT: isize = 100000;
        let (w, s) = deque();
        let t = thread::spawn(move || {
            let mut left = AMT;
            // Spin until every element has been stolen; Empty / Abort just retry.
            while left > 0 {
                match s.steal() {
                    Steal::Data(i) => {
                        assert_eq!(i, 1);
                        left -= 1;
                    }
                    Steal::Abort | Steal::Empty => {}
                }
            }
        });
        for _ in 0..AMT {
            w.push(1);
        }
        t.join().unwrap();
    }

    // Same as `stealpush`, but with a two-field tuple element; any stolen
    // value other than `(1, 10)` fails the test.
    #[test]
    fn stealpush_large() {
        static AMT: isize = 100000;
        let (w, s) = deque();
        let t = thread::spawn(move || {
            let mut left = AMT;
            while left > 0 {
                match s.steal() {
                    Steal::Data((1, 10)) => { left -= 1; }
                    Steal::Data(..) => panic!(),
                    Steal::Abort | Steal::Empty => {}
                }
            }
        });
        for _ in 0..AMT {
            w.push((1, 10));
        }
        t.join().unwrap();
    }

    // Pre-fills the deque with `amt` boxed values, then lets `nthreads`
    // stealers and the worker (via try_pop) drain it concurrently until the
    // shared `remaining` counter reaches zero. Heap-allocated elements are
    // presumably used so that a duplicated or lost element shows up as a
    // memory error or a hang.
    fn stampede(w: Worker<Box<isize>>,
                s: Stealer<Box<isize>>,
                nthreads: isize,
                amt: usize) {
        for _ in 0..amt {
            w.push(Box::new(20));
        }
        let remaining = Arc::new(AtomicUsize::new(amt));
        let threads = (0..nthreads).map(|_| {
            let remaining = remaining.clone();
            let s = s.clone();
            thread::spawn(move || {
                while remaining.load(SeqCst) > 0 {
                    match s.steal() {
                        Steal::Data(val) => {
                            if *val == 20 {
                                remaining.fetch_sub(1, SeqCst);
                            } else {
                                panic!()
                            }
                        }
                        Steal::Abort | Steal::Empty => {}
                    }
                }
            })
        }).collect::<Vec<_>>();
        // The worker competes with the stealers from the bottom end.
        while remaining.load(SeqCst) > 0 {
            if let Some(val) = w.try_pop() {
                if *val == 20 {
                    remaining.fetch_sub(1, SeqCst);
                } else {
                    panic!()
                }
            }
        }
        for thread in threads.into_iter() {
            thread.join().unwrap();
        }
    }

    // One stampede: 8 stealers plus the worker over 10000 elements.
    #[test]
    fn run_stampede() {
        let (w, s) = deque();
        stampede(w, s, 8, 10000);
    }

    // Four independent deques, each stampeded from its own thread at once.
    #[test]
    fn many_stampede() {
        static AMT: usize = 4;
        let threads = (0..AMT).map(|_| {
            let (w, s) = deque();
            thread::spawn(|| {
                stampede(w, s, 4, 10000);
            })
        }).collect::<Vec<_>>();
        for thread in threads.into_iter() {
            thread.join().unwrap();
        }
    }

    // The worker randomly interleaves pushes (2/3 of iterations) and pops
    // (1/3) while 8 stealers run. Checks that the total number of consumed
    // elements (HITS) equals the number pushed.
    #[test]
    fn stress() {
        static AMT: isize = 100000;
        static NTHREADS: isize = 8;
        static DONE: AtomicBool = ATOMIC_BOOL_INIT;
        static HITS: AtomicUsize = ATOMIC_USIZE_INIT;
        let (w, s) = deque();
        let threads = (0..NTHREADS).map(|_| {
            let s = s.clone();
            thread::spawn(move || {
                loop {
                    match s.steal() {
                        Steal::Data(2) => { HITS.fetch_add(1, SeqCst); }
                        Steal::Data(..) => panic!(),
                        // Only exit on a non-Data result once the main thread
                        // has signalled completion.
                        _ if DONE.load(SeqCst) => break,
                        _ => {}
                    }
                }
            })
        }).collect::<Vec<_>>();
        let mut rng = rand::thread_rng();
        let mut expected = 0;
        while expected < AMT {
            if rng.gen_range(0, 3) == 2 {
                match w.try_pop() {
                    None => {}
                    Some(2) => { HITS.fetch_add(1, SeqCst); },
                    Some(_) => panic!(),
                }
            } else {
                expected += 1;
                w.push(2);
            }
        }
        // Help drain whatever is left until every pushed element is accounted for.
        while HITS.load(SeqCst) < AMT as usize {
            match w.try_pop() {
                None => {}
                Some(2) => { HITS.fetch_add(1, SeqCst); },
                Some(_) => panic!(),
            }
        }
        DONE.store(true, SeqCst);
        for thread in threads.into_iter() {
            thread.join().unwrap();
        }
        assert_eq!(HITS.load(SeqCst), expected as usize);
    }

    // Keeps pushing (and occasionally popping) until every stealer thread has
    // stolen at least one element and the worker has popped at least one.
    // The test hangs rather than fails if some participant is starved.
    #[test]
    fn no_starvation() {
        static AMT: isize = 10000;
        static NTHREADS: isize = 4;
        static DONE: AtomicBool = ATOMIC_BOOL_INIT;
        let (w, s) = deque();
        // One per-thread hit counter, shared between the thread and this test.
        let (threads, hits): (Vec<_>, Vec<_>) = (0..NTHREADS).map(|_| {
            let s = s.clone();
            let ctr = Arc::new(AtomicUsize::new(0));
            let ctr2 = ctr.clone();
            (thread::spawn(move || {
                loop {
                    match s.steal() {
                        Steal::Data((1, 2)) => { ctr.fetch_add(1, SeqCst); }
                        Steal::Data(..) => panic!(),
                        _ if DONE.load(SeqCst) => break,
                        _ => {}
                    }
                }
            }), ctr2)
        }).unzip();
        let mut rng = rand::thread_rng();
        let mut myhit = false;
        'outer: loop {
            for _ in 0..rng.gen_range(0, AMT) {
                if !myhit && rng.gen_range(0, 3) == 2 {
                    match w.try_pop() {
                        None => {}
                        Some((1, 2)) => myhit = true,
                        Some(_) => panic!(),
                    }
                } else {
                    w.push((1, 2));
                }
            }
            // Start another round of pushes if any stealer is still at zero.
            for slot in hits.iter() {
                let amt = slot.load(SeqCst);
                if amt == 0 { continue 'outer; }
            }
            if myhit {
                break
            }
        }
        DONE.store(true, SeqCst);
        for thread in threads.into_iter() {
            thread.join().unwrap();
        }
    }
}