use std::cmp;
use std::fmt;
use std::marker::PhantomData;
use std::mem;
use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicIsize, fence};
use std::sync::atomic::Ordering::{Acquire, Release, Relaxed, SeqCst};
use epoch::{self, Atomic};
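
// This is a lock-free work-stealing deque in the style of the Chase-Lev deque
// ("Dynamic Circular Work-Stealing Deque", SPAA 2005), with the memory
// orderings following Lê et al., "Correct and Efficient Work-Stealing for
// Weak Memory Models" (PPoPP 2013).
//
// The worker owns the bottom end of the deque: it pushes and pops there
// without contention in the common case. Stealers compete for the top end
// using compare-and-swap on the `top` index. The buffer is grown and shrunk
// only by the worker, and retired buffers are reclaimed through the
// epoch-based garbage collector so that concurrent stealers never read from
// freed memory.
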
/// Minimum buffer capacity for a deque.
const MIN_CAP: usize = 16;

/// A buffer that holds elements in a deque.
struct Buffer<T> {
    /// Pointer to the allocated memory.
    ptr: *mut T,
    /// Capacity of the buffer. Always a power of two.
    cap: usize,
}

impl<T> Buffer<T> {
    /// Returns a new buffer with the specified capacity.
    fn new(cap: usize) -> Self {
        // Allocate through `Vec`, then take ownership of the raw allocation.
        let mut v = Vec::with_capacity(cap);
        let ptr = v.as_mut_ptr();
        mem::forget(v);
        Buffer { ptr, cap }
    }

    /// Returns a pointer to the element at the specified `index`.
    unsafe fn at(&self, index: isize) -> *mut T {
        // `self.cap` is always a power of two.
        self.ptr.offset(index & (self.cap - 1) as isize)
    }

    /// Writes `value` into the slot at the specified `index`.
    unsafe fn write(&self, index: isize, value: T) {
        ptr::write(self.at(index), value)
    }

    /// Reads the value from the slot at the specified `index`.
    unsafe fn read(&self, index: isize) -> T {
        ptr::read(self.at(index))
    }
}

/// A work-stealing deque.
struct Deque<T> {
    /// The bottom index, manipulated only by the worker.
    bottom: AtomicIsize,
    /// The top index, advanced by stealers (and by the worker when popping the
    /// last element).
    top: AtomicIsize,
    /// The underlying buffer.
    buffer: Atomic<Buffer<T>>,
}

impl<T> Deque<T> {
    /// Returns a new, empty deque.
    fn new() -> Self {
        Deque {
            bottom: AtomicIsize::new(0),
            top: AtomicIsize::new(0),
            buffer: Atomic::new(Buffer::new(MIN_CAP), 0),
        }
    }

    /// Returns the number of elements in the deque.
    ///
    /// If used concurrently with other operations, the returned number is only
    /// an estimate.
    fn len(&self) -> usize {
        let b = self.bottom.load(Relaxed);
        let t = self.top.load(Relaxed);
        // The computed length can be negative if `b` and `t` were read while
        // racing with other operations, so clamp it at zero.
        cmp::max(b.wrapping_sub(t), 0) as usize
    }

    /// Resizes the buffer to the new capacity of `new_cap`.
    #[cold]
    unsafe fn resize(&self, new_cap: usize) {
        // Load the bottom, top, and buffer.
        let b = self.bottom.load(Relaxed);
        let t = self.top.load(Relaxed);
        let buffer = self.buffer.load_raw(Relaxed).0;

        // Allocate a new buffer.
        let new = Buffer::new(new_cap);

        // Copy data from the old buffer to the new one.
        let mut i = t;
        while i != b {
            ptr::copy_nonoverlapping((*buffer).at(i), new.at(i), 1);
            i = i.wrapping_add(1);
        }

        epoch::pin(|pin| {
            // Replace the old buffer with the new one.
            self.buffer.store_box(Box::new(new), 0, pin);

            // Defer destruction of the old buffer. The elements were copied out
            // rather than moved, so the allocation is freed without running
            // destructors, and only after all pinned stealers are done with it.
            let ptr = (*buffer).ptr;
            let cap = (*buffer).cap;
            epoch::defer_free(ptr, cap, pin);
            epoch::defer_free(buffer, 1, pin);

            // If the old buffer was large, flush the thread-local garbage so
            // that the memory gets deallocated as soon as possible.
            if mem::size_of::<T>() * cap >= 1 << 10 {
                epoch::flush(pin);
            }
        })
    }

    /// Pushes an element onto the bottom of the deque.
    fn push(&self, value: T) {
        unsafe {
            // Load the bottom, top, and buffer. The buffer doesn't have to be
            // epoch-protected because the worker is the only thread that grows
            // and shrinks it.
            let b = self.bottom.load(Relaxed);
            let t = self.top.load(Acquire);
            let mut buffer = self.buffer.load_raw(Relaxed).0;

            // Is the deque full?
            let len = b.wrapping_sub(t);
            let cap = (*buffer).cap;
            if len >= cap as isize {
                // Yes: grow the underlying buffer and reload it.
                self.resize(2 * cap);
                buffer = self.buffer.load_raw(Relaxed).0;
            }

            // Write `value` into the right slot, then publish it to stealers by
            // incrementing the bottom.
            (*buffer).write(b, value);
            fence(Release);
            self.bottom.store(b.wrapping_add(1), Relaxed);
        }
    }

    /// Pops an element from the bottom of the deque.
    fn pop(&self) -> Option<T> {
        // If the deque is empty, return early without incurring the cost of
        // the SeqCst fence below.
        let b = self.bottom.load(Relaxed);
        let t = self.top.load(Relaxed);
        if b.wrapping_sub(t) <= 0 {
            return None;
        }

        // Decrement the bottom, reserving the slot we intend to pop.
        let b = b.wrapping_sub(1);
        self.bottom.store(b, Relaxed);

        // Load the buffer. It doesn't have to be epoch-protected because the
        // worker is the only thread that grows and shrinks it.
        let buffer = self.buffer.load_raw(Relaxed).0;

        // A SeqCst fence is needed between the store to the bottom above and
        // the load of the top below, so that this pop is properly ordered with
        // respect to concurrent steals.
        fence(SeqCst);

        let t = self.top.load(Relaxed);

        // Compute the length after the bottom was decremented.
        let len = b.wrapping_sub(t);

        if len < 0 {
            // The deque was emptied by concurrent steals. Restore the bottom
            // to its original value.
            self.bottom.store(b.wrapping_add(1), Relaxed);
            None
        } else {
            // Read the value to be popped.
            let mut value = unsafe { Some((*buffer).read(b)) };

            // Are we popping the last element in the deque?
            if len == 0 {
                // Race against concurrent steals by incrementing the top.
                if self.top.compare_exchange(t, t.wrapping_add(1), SeqCst, Relaxed).is_err() {
                    // We lost the race: a stealer took the element. Forget our
                    // copy of it so that it isn't dropped twice.
                    mem::forget(value.take());
                }

                // Restore the bottom to its original value.
                self.bottom.store(b.wrapping_add(1), Relaxed);
            } else {
                // Shrink the buffer if the deque occupies less than a quarter
                // of its capacity.
                unsafe {
                    let cap = (*buffer).cap;
                    if cap > MIN_CAP && len < cap as isize / 4 {
                        self.resize(cap / 2);
                    }
                }
            }

            value
        }
    }

    /// Steals an element from the top of the deque.
    fn steal(&self) -> Option<T> {
        // Load the top.
        let mut t = self.top.load(Acquire);

        // A SeqCst fence is needed between the load of the top above and the
        // load of the bottom below. Pinning issues a fence, but only if the
        // thread isn't already pinned, so issue one explicitly in that case.
        if epoch::is_pinned() {
            fence(SeqCst);
        }

        epoch::pin(|pin| {
            // Loop until we successfully steal an element or find the deque empty.
            loop {
                // Is the deque empty?
                let b = self.bottom.load(Acquire);
                if b.wrapping_sub(t) <= 0 {
                    return None;
                }

                // Load the buffer and read the value at the top.
                let a = self.buffer.load(pin).unwrap();
                let value = unsafe { a.read(t) };

                // Try incrementing the top to claim the value.
                if self.top.compare_exchange(t, t.wrapping_add(1), SeqCst, Relaxed).is_ok() {
                    return Some(value);
                }

                // We lost the race. Forget our copy of the value so that it
                // isn't dropped twice.
                mem::forget(value);

                // Reload the top and issue another fence before retrying.
                t = self.top.load(Acquire);
                fence(SeqCst);
            }
        })
    }
}

impl<T> Drop for Deque<T> {
    fn drop(&mut self) {
        // We have `&mut self`, so no other thread can be accessing the deque.
        let b = self.bottom.load(Relaxed);
        let t = self.top.load(Relaxed);
        let buffer = self.buffer.load_raw(Relaxed).0;

        unsafe {
            // Drop all remaining elements in the deque.
            let mut i = t;
            while i != b {
                ptr::drop_in_place((*buffer).at(i));
                i = i.wrapping_add(1);
            }

            // Free the buffer's allocation (with length zero, so no elements
            // are dropped again), then the buffer itself.
            drop(Vec::from_raw_parts((*buffer).ptr, 0, (*buffer).cap));
            drop(Vec::from_raw_parts(buffer, 0, 1));
        }
    }
}

/// The worker side of a deque; it pushes and pops at the bottom.
pub struct Worker<T> {
    deque: Arc<Deque<T>>,
    _marker: PhantomData<*mut ()>, // !Send + !Sync by default
}

// A `Worker` can be sent to another thread, but cannot be shared between threads.
unsafe impl<T: Send> Send for Worker<T> {}

impl<T> Worker<T> {
    /// Returns the number of elements in the deque.
    pub fn len(&self) -> usize {
        self.deque.len()
    }

    /// Pushes an element onto the bottom of the deque.
    pub fn push(&self, value: T) {
        self.deque.push(value);
    }

    /// Pops an element from the bottom of the deque.
    pub fn pop(&self) -> Option<T> {
        self.deque.pop()
    }
}

impl<T> fmt::Debug for Worker<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Worker {{ ... }}")
    }
}

/// The stealer side of a deque; it steals from the top.
pub struct Stealer<T> {
    deque: Arc<Deque<T>>,
    _marker: PhantomData<*mut ()>, // !Send + !Sync by default
}

// A `Stealer` can be sent to and shared between threads.
unsafe impl<T: Send> Send for Stealer<T> {}
unsafe impl<T: Send> Sync for Stealer<T> {}

impl<T> Stealer<T> {
    /// Returns the number of elements in the deque.
    pub fn len(&self) -> usize {
        self.deque.len()
    }

    /// Steals an element from the top of the deque.
    pub fn steal(&self) -> Option<T> {
        self.deque.steal()
    }
}

impl<T> Clone for Stealer<T> {
    fn clone(&self) -> Self {
        Stealer {
            deque: self.deque.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T> fmt::Debug for Stealer<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "Stealer {{ ... }}")
    }
}

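/// Creates a new work-stealing deque, returning the worker and stealer halves.
///
/// A minimal usage sketch (illustrative only; the `deque` module path is an
/// assumption about where this file lives in the crate):
///
/// ```ignore
/// let (w, s) = deque::new();
/// w.push(1);
/// w.push(2);
/// assert_eq!(s.steal(), Some(1)); // stealers take from the top (the oldest element)
/// assert_eq!(w.pop(), Some(2));   // the worker pops from the bottom (the newest element)
/// ```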
pub fn new<T>() -> (Worker<T>, Stealer<T>) {
    let d = Arc::new(Deque::new());
    let worker = Worker {
        deque: d.clone(),
        _marker: PhantomData,
    };
    let stealer = Stealer {
        deque: d,
        _marker: PhantomData,
    };
    (worker, stealer)
}

#[cfg(test)]
mod tests {
    extern crate rand;

    use std::sync::{Arc, Mutex};
    use std::sync::atomic::{AtomicBool, AtomicUsize};
    use std::sync::atomic::Ordering::SeqCst;
    use std::thread;

    use epoch;
    use self::rand::Rng;

    #[test]
    fn smoke() {
        let (w, s) = super::new();
        assert_eq!(w.pop(), None);
        assert_eq!(s.steal(), None);
        assert_eq!(w.len(), 0);
        assert_eq!(s.len(), 0);

        w.push(1);
        assert_eq!(w.len(), 1);
        assert_eq!(s.len(), 1);
        assert_eq!(w.pop(), Some(1));
        assert_eq!(w.pop(), None);
        assert_eq!(s.steal(), None);
        assert_eq!(w.len(), 0);
        assert_eq!(s.len(), 0);

        w.push(2);
        assert_eq!(s.steal(), Some(2));
        assert_eq!(s.steal(), None);
        assert_eq!(w.pop(), None);
    }

    #[test]
    fn steal_push() {
        const STEPS: usize = 50_000;

        let (w, s) = super::new();
        let t = thread::spawn(move || {
            for i in 0..STEPS {
                loop {
                    if let Some(v) = s.steal() {
                        assert_eq!(i, v);
                        break;
                    }
                }
            }
        });

        for i in 0..STEPS {
            w.push(i);
        }
        t.join().unwrap();
    }

    #[test]
    fn stampede() {
        const COUNT: usize = 50_000;

        let (w, s) = super::new();
        for i in 0..COUNT {
            w.push(Box::new(i + 1));
        }
        let remaining = Arc::new(AtomicUsize::new(COUNT));

        let threads = (0..8).map(|_| {
            let s = s.clone();
            let remaining = remaining.clone();

            thread::spawn(move || {
                let mut last = 0;
                while remaining.load(SeqCst) > 0 {
                    if let Some(x) = s.steal() {
                        assert!(last < *x);
                        last = *x;
                        remaining.fetch_sub(1, SeqCst);
                    }
                }
            })
        }).collect::<Vec<_>>();

        let mut last = COUNT + 1;
        while remaining.load(SeqCst) > 0 {
            if let Some(x) = w.pop() {
                assert!(last > *x);
                last = *x;
                remaining.fetch_sub(1, SeqCst);
            }
        }

        for t in threads {
            t.join().unwrap();
        }
    }

    // Shared body for the `stress` and `stress_pinned` tests.
    fn run_stress() {
        const COUNT: usize = 50_000;

        let (w, s) = super::new();
        let done = Arc::new(AtomicBool::new(false));
        let hits = Arc::new(AtomicUsize::new(0));

        let threads = (0..8).map(|_| {
            let s = s.clone();
            let done = done.clone();
            let hits = hits.clone();

            thread::spawn(move || {
                while !done.load(SeqCst) {
                    if s.steal().is_some() {
                        hits.fetch_add(1, SeqCst);
                    }
                }
            })
        }).collect::<Vec<_>>();

        let mut rng = rand::thread_rng();
        let mut expected = 0;
        while expected < COUNT {
            if rng.gen_range(0, 3) == 0 {
                if w.pop().is_some() {
                    hits.fetch_add(1, SeqCst);
                }
            } else {
                w.push(expected);
                expected += 1;
            }
        }

        while hits.load(SeqCst) < COUNT {
            if w.pop().is_some() {
                hits.fetch_add(1, SeqCst);
            }
        }
        done.store(true, SeqCst);

        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn stress() {
        run_stress();
    }

    #[test]
    fn stress_pinned() {
        // Run the same stress test with the current thread pinned to the epoch GC.
        epoch::pin(|_| run_stress());
    }

    #[test]
    fn no_starvation() {
        const COUNT: usize = 50_000;

        let (w, s) = super::new();
        let done = Arc::new(AtomicBool::new(false));

        let (threads, hits): (Vec<_>, Vec<_>) = (0..8).map(|_| {
            let s = s.clone();
            let done = done.clone();
            let hits = Arc::new(AtomicUsize::new(0));

            let t = {
                let hits = hits.clone();
                thread::spawn(move || {
                    while !done.load(SeqCst) {
                        if s.steal().is_some() {
                            hits.fetch_add(1, SeqCst);
                        }
                    }
                })
            };

            (t, hits)
        }).unzip();

        let mut rng = rand::thread_rng();
        let mut my_hits = 0;
        loop {
            for i in 0..rng.gen_range(0, COUNT) {
                if rng.gen_range(0, 3) == 0 && my_hits == 0 {
                    if w.pop().is_some() {
                        my_hits += 1;
                    }
                } else {
                    w.push(i);
                }
            }

            // Stop only once the worker and every stealer have made progress.
            if my_hits > 0 && hits.iter().all(|h| h.load(SeqCst) > 0) {
                break;
            }
        }
        done.store(true, SeqCst);

        for t in threads {
            t.join().unwrap();
        }
    }

    #[test]
    fn destructors() {
        const COUNT: usize = 50_000;

        struct Elem(usize, Arc<Mutex<Vec<usize>>>);

        impl Drop for Elem {
            fn drop(&mut self) {
                self.1.lock().unwrap().push(self.0);
            }
        }

        let (w, s) = super::new();
        let dropped = Arc::new(Mutex::new(Vec::new()));
        let remaining = Arc::new(AtomicUsize::new(COUNT));
        for i in 0..COUNT {
            w.push(Elem(i, dropped.clone()));
        }

        let threads = (0..8).map(|_| {
            let s = s.clone();
            let remaining = remaining.clone();

            thread::spawn(move || {
                for _ in 0..1000 {
                    if s.steal().is_some() {
                        remaining.fetch_sub(1, SeqCst);
                    }
                }
            })
        }).collect::<Vec<_>>();

        for _ in 0..1000 {
            if w.pop().is_some() {
                remaining.fetch_sub(1, SeqCst);
            }
        }

        for t in threads {
            t.join().unwrap();
        }

        let rem = remaining.load(SeqCst);
        assert!(rem > 0);
        assert_eq!(w.len(), rem);
        assert_eq!(s.len(), rem);

        {
            let mut v = dropped.lock().unwrap();
            assert_eq!(v.len(), COUNT - rem);
            v.clear();
        }

        // Dropping the deque must drop all of its remaining elements.
        drop(w);
        drop(s);

        {
            let mut v = dropped.lock().unwrap();
            assert_eq!(v.len(), rem);
            v.sort();
            for w in v.windows(2) {
                assert_eq!(w[0] + 1, w[1]);
            }
        }
    }
}