#![warn(missing_copy_implementations, missing_debug_implementations, missing_docs)]
#![cfg_attr(feature="valgrind", feature(alloc_system))]
#[cfg(feature="valgrind")]
extern crate alloc_system;
extern crate hazard;
use std::error;
use std::fmt;
use std::ptr;
use std::usize;
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicPtr, AtomicUsize};
use std::sync::atomic::Ordering::*;
use hazard::{AlignVec, BoxMemory, Memory, Pointers};
// Number of pointer-sized words used to pad the hot atomics in `Queue`:
// POINTERS * size_of::<usize>() is 128 bytes on both widths (32*4, 16*8) —
// presumably targeting a 128-byte destructive-interference span; confirm.
#[cfg(target_pointer_width="32")]
const POINTERS: usize = 32;
#[cfg(target_pointer_width="64")]
const POINTERS: usize = 16;
// Sentinel value stored in `Node::dequeuer` while no consumer has claimed
// the node yet.
const INVALID: usize = usize::MAX;
/// Reasons `Consumer::consume` can fail without yielding an item.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ConsumeError {
    /// The queue was empty and every `Producer` handle has been dropped,
    /// so no further items can ever arrive.
    Disconnected,
    /// The queue was empty, but producers still exist and may push items.
    Empty,
}
impl ConsumeError {
    /// Static human-readable message shared by `Display` and
    /// `Error::description`.
    fn message(self) -> &'static str {
        match self {
            ConsumeError::Disconnected => "the queue was empty and had no remaining producers",
            ConsumeError::Empty => "the queue was empty",
        }
    }
}
impl error::Error for ConsumeError {
    // Kept for backward compatibility with callers that still invoke it;
    // `Error::description` is deprecated in favour of `Display`.
    fn description(&self) -> &str {
        self.message()
    }
}
impl fmt::Display for ConsumeError {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // Write the message directly instead of routing through the
        // deprecated `error::Error::description`.
        formatter.write_str(self.message())
    }
}
/// Receiving half of the queue: a dequeuer slot index (field 0) plus a
/// shared handle to the queue state (field 1).
#[derive(Debug)]
pub struct Consumer<T>(usize, Arc<Queue<T>>);
impl<T> Consumer<T> {
pub fn consume(&self) -> Result<T, ConsumeError> {
self.1.consume(self.0)
}
pub fn try_clone(&self) -> Option<Self> {
if let Some(thread) = self.1.deqthreads.lock().unwrap().pop() {
self.1.consumers.fetch_add(1, Release);
Some(Consumer(thread, self.1.clone()))
} else {
None
}
}
}
impl<T> Clone for Consumer<T> {
    /// Panics when no consumer slots remain; call `try_clone` to detect
    /// exhaustion without panicking.
    fn clone(&self) -> Self {
        match self.try_clone() {
            Some(consumer) => consumer,
            None => panic!("too many consumer clones"),
        }
    }
}
impl<T> Drop for Consumer<T> {
    // Returns this handle's dequeuer slot to the free list and lowers the
    // live-consumer count that producers check for disconnection.
    fn drop(&mut self) {
        // NOTE(review): the slot is recycled before the count is decremented,
        // so a concurrent try_clone can observe the two effects out of order —
        // presumably benign, but confirm against the counting protocol.
        self.1.deqthreads.lock().unwrap().push(self.0);
        self.1.consumers.fetch_sub(1, Release);
    }
}
unsafe impl<T> Send for Consumer<T> where T: Send { }
/// Singly-linked queue node.
#[derive(Debug)]
struct Node<T> {
    item: Option<T>,          // payload; None for sentinel/dummy nodes
    enqueuer: usize,          // slot index of the producer that created it
    dequeuer: AtomicUsize,    // consumer slot it was granted to; INVALID until claimed
    next: AtomicPtr<Node<T>>, // successor in the list; null at the tail
}
impl<T> Node<T> {
    /// Builds a fresh, unlinked node: no successor, no assigned dequeuer.
    fn new(item: Option<T>, enqueuer: usize) -> Self {
        Node {
            item,
            enqueuer,
            dequeuer: AtomicUsize::new(INVALID),
            next: AtomicPtr::new(ptr::null_mut()),
        }
    }
}
unsafe impl<T> Send for Queue<T> where T: Send { }
/// Error returned by `Producer::produce`; the rejected item is handed back
/// as the payload so the caller can retry or recover it.
///
/// NOTE(review): the message says "full", but `Queue::produce` returns this
/// when every consumer handle has been dropped — confirm the intended wording.
/// The string is kept as-is since callers may match on it.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct ProduceError<T>(pub T);
impl<T> error::Error for ProduceError<T> {
    // Kept for backward compatibility with callers that still invoke it;
    // `Error::description` is deprecated in favour of `Display`.
    fn description(&self) -> &str {
        "the queue was full"
    }
}
impl<T> fmt::Debug for ProduceError<T> {
    // Manual impl: T need not be Debug, so the payload is elided.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("ProduceError(..)")
    }
}
impl<T> fmt::Display for ProduceError<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // Write the message directly instead of routing through the
        // deprecated `error::Error::description`.
        formatter.write_str("the queue was full")
    }
}
/// Sending half of the queue: an enqueuer slot index (field 0) plus a
/// shared handle to the queue state (field 1).
#[derive(Debug)]
pub struct Producer<T>(usize, Arc<Queue<T>>);
impl<T> Producer<T> {
pub fn produce(&self, item: T) -> Result<(), ProduceError<T>> {
self.1.produce(self.0, item)
}
pub fn try_clone(&self) -> Option<Self> {
if let Some(thread) = self.1.enqthreads.lock().unwrap().pop() {
self.1.producers.fetch_add(1, Release);
Some(Producer(thread, self.1.clone()))
} else {
None
}
}
}
impl<T> Clone for Producer<T> {
    /// Panics when no producer slots remain; call `try_clone` to detect
    /// exhaustion without panicking.
    fn clone(&self) -> Self {
        match self.try_clone() {
            Some(producer) => producer,
            None => panic!("too many producer clones"),
        }
    }
}
impl<T> Drop for Producer<T> {
    // Returns this handle's enqueuer slot to the free list and lowers the
    // live-producer count that consumers check for disconnection.
    fn drop(&mut self) {
        // NOTE(review): the slot is recycled before the count is decremented,
        // mirroring Consumer::drop — presumably benign; confirm.
        self.1.enqthreads.lock().unwrap().push(self.0);
        self.1.producers.fetch_sub(1, Release);
    }
}
unsafe impl<T> Send for Producer<T> where T: Send { }
// Hazard-pointer slot indices; `Pointers::new` below allocates 3 slots per
// thread. Producers only ever use slot 0 (WRITE); consumers use all three.
const WRITE: usize = 0;
const READ: usize = 0; // shares slot 0 with WRITE — NOTE(review): producer and
                       // consumer thread ids both start at 0, so confirm the
                       // hazard domain keeps the two id ranges from colliding.
const NEXT: usize = 1;
const DEQUEUE: usize = 2;
/// Shared queue state.
///
/// `#[repr(C)]` preserves the declared field order so `_wpadding`/`_rpadding`
/// keep the producer-side and consumer-side hot atomics apart — presumably on
/// separate cache lines (POINTERS words per side); confirm intent.
#[derive(Debug)]
#[repr(C)]
struct Queue<T> {
    write: AtomicPtr<Node<T>>,              // tail: last linked node
    producers: AtomicUsize,                 // count of live Producer handles
    _wpadding: [usize; POINTERS - 2],
    enqreq: AlignVec<AtomicPtr<Node<T>>>,   // pending enqueue request per producer slot
    read: AtomicPtr<Node<T>>,               // head: starts at the sentinel
    consumers: AtomicUsize,                 // count of live Consumer handles
    _rpadding: [usize; POINTERS - 2],
    deqreq: AlignVec<AtomicPtr<Node<T>>>,   // node most recently granted to each consumer slot
    deqold: AlignVec<AtomicPtr<Node<T>>>,   // previously granted node, pending retirement
    pointers: Pointers<Node<T>, BoxMemory>, // hazard-pointer domain shared by all threads
    sentinel: *mut Node<T>,                 // initial dummy node, freed in Drop
    enqthreads: Mutex<Vec<usize>>,          // free producer slot indices for try_clone
    deqthreads: Mutex<Vec<usize>>,          // free consumer slot indices for try_clone
}
impl<T> Queue<T> {
    /// Allocates shared state for `producers` enqueuer slots and `consumers`
    /// dequeuer slots; slot 0 on each side belongs to the initial handles and
    /// the remaining ids go into the free lists used by `try_clone`.
    fn new(producers: usize, consumers: usize) -> Arc<Self> {
        // Head (`read`) and tail (`write`) both start at a shared sentinel.
        let sentinel = BoxMemory.allocate(Node::new(None, 0));
        let enqreq = (0..producers).map(|_| AtomicPtr::new(ptr::null_mut())).collect();
        // Per-consumer dummy nodes; deallocated in Drop.
        let deqold = (0..consumers).map(|_| {
            AtomicPtr::new(BoxMemory.allocate(Node::new(None, 0)))
        }).collect();
        let deqreq = (0..consumers).map(|_| {
            AtomicPtr::new(BoxMemory.allocate(Node::new(None, 0)))
        }).collect();
        Arc::new(Queue {
            write: AtomicPtr::new(sentinel),
            producers: AtomicUsize::new(1),
            _wpadding: [0; POINTERS - 2],
            enqreq: AlignVec::new(enqreq),
            read: AtomicPtr::new(sentinel),
            consumers: AtomicUsize::new(1),
            _rpadding: [0; POINTERS - 2],
            deqreq: AlignVec::new(deqreq),
            deqold: AlignVec::new(deqold),
            // Three hazard slots per thread (WRITE/READ share 0, NEXT, DEQUEUE).
            pointers: Pointers::new(BoxMemory, producers + consumers, 3, 512),
            sentinel,
            enqthreads: Mutex::new((1..producers).collect()),
            deqthreads: Mutex::new((1..consumers).collect()),
        })
    }
    /// Enqueues `item` on behalf of enqueuer slot `thread`, or returns it
    /// inside `ProduceError` when every consumer handle has been dropped.
    ///
    /// Turn-based helping: the request is published in `enqreq[thread]` and
    /// every producer then serves pending requests in round-robin order, so a
    /// stalled requester still gets its node linked. NOTE(review): the loop
    /// runs exactly `enqreq.len()` rounds and then unconditionally clears the
    /// request — presumably that bound guarantees completion; confirm against
    /// the helping protocol.
    fn produce(&self, thread: usize, item: T) -> Result<(), ProduceError<T>> {
        if self.consumers.load(Acquire) == 0 {
            return Err(ProduceError(item));
        }
        let node = BoxMemory.allocate(Node::new(Some(item), thread));
        // Publish our request so other producers can help complete it.
        self.enqreq[thread].store(node, SeqCst);
        for _ in 0..self.enqreq.len() {
            // A helper may already have linked our node.
            if self.enqreq[thread].load(SeqCst).is_null() {
                self.pointers.clear(thread, WRITE);
            }
            // Protect the tail with a hazard pointer, then re-validate it.
            let write = self.pointers.mark_ptr(thread, WRITE, self.write.load(SeqCst));
            if write != self.write.load(SeqCst) {
                continue;
            }
            // The tail node's own enqueue request is now satisfied; clear it.
            let enqueuer = unsafe { (*write).enqueuer };
            if self.enqreq[enqueuer].load(SeqCst) == write {
                exchange_ptr(&self.enqreq[enqueuer], write, ptr::null_mut());
            }
            // Link the next pending request, scanning round-robin after the
            // tail's enqueuer so every requester eventually gets a turn.
            for turn in 1..(self.enqreq.len() + 1) {
                let node = self.enqreq[(enqueuer + turn) % self.enqreq.len()].load(SeqCst);
                if !node.is_null() {
                    exchange_ptr(unsafe { &(*write).next }, ptr::null_mut(), node);
                    break;
                }
            }
            // Swing the tail over whatever got linked (by us or a helper).
            let next = unsafe { (*write).next.load(SeqCst) };
            if !next.is_null() {
                exchange_ptr(&self.write, write, next);
            }
        }
        self.enqreq[thread].store(ptr::null_mut(), Release);
        self.pointers.clear(thread, WRITE);
        Ok(())
    }
    /// Tries to hand node `next` to some pending dequeuer, scanning slots in
    /// round-robin order starting after the dequeuer recorded on the current
    /// head `read`. Returns true when `next` ended up claimed by anyone.
    fn assign(&self, read: *mut Node<T>, next: *mut Node<T>) -> bool {
        let dequeuer = unsafe { (*read).dequeuer.load(SeqCst) };
        for turn in 1..(self.deqreq.len() + 1) {
            // wrapping_add: the head's dequeuer may still be INVALID (sentinel).
            let thread = dequeuer.wrapping_add(turn) % self.deqreq.len();
            // Only slots whose deqreq equals deqold have announced a request
            // (consume() stores req into deqold to signal one); skip the rest.
            if self.deqreq[thread].load(SeqCst) != self.deqold[thread].load(SeqCst) {
                continue;
            }
            let dequeuer = unsafe { (*next).dequeuer.load(SeqCst) };
            if dequeuer == INVALID {
                exchange_usize(unsafe { &(*next).dequeuer }, INVALID, thread);
            }
            break;
        }
        let dequeuer = unsafe { (*next).dequeuer.load(SeqCst) };
        dequeuer != INVALID
    }
    /// Publishes `next` to its assigned dequeuer's slot and advances `read`
    /// past the old head.
    fn close(&self, thread: usize, read: *mut Node<T>, next: *mut Node<T>) {
        let dequeuer = unsafe { (*next).dequeuer.load(SeqCst) };
        if dequeuer == thread {
            // Granting to ourselves: a plain store on our own slot suffices.
            self.deqreq[dequeuer].store(next, Release);
        } else {
            // Helping another thread: hazard-protect its slot and CAS so a
            // stale helper cannot clobber a newer grant.
            let node = self.pointers.mark_ptr(thread, DEQUEUE, self.deqreq[dequeuer].load(SeqCst));
            if node != next && read == self.read.load(SeqCst) {
                exchange_ptr(&self.deqreq[dequeuer], node, next);
            }
        }
        exchange_ptr(&self.read, read, next);
    }
    /// Withdraws this thread's dequeue request after observing an apparently
    /// empty queue, first helping to finish any dequeue already in flight.
    fn rollback(&self, thread: usize, old: *mut Node<T>, req: *mut Node<T>) {
        // Restore deqold so deqreq != deqold again: we are no longer requesting.
        self.deqold[thread].store(old, SeqCst);
        let read = self.read.load(SeqCst);
        if self.deqreq[thread].load(SeqCst) != req || read == self.write.load(SeqCst) {
            return;
        }
        // The queue turned out to be non-empty; help complete one dequeue so
        // no granted node is stranded.
        self.pointers.mark_ptr(thread, READ, read);
        if read != self.read.load(SeqCst) {
            return;
        }
        let next = self.pointers.mark_ptr(thread, NEXT, unsafe { (*read).next.load(SeqCst) });
        if read != self.read.load(SeqCst) {
            return;
        }
        if !self.assign(read, next) {
            // No other slot is requesting: claim the node for this thread.
            // (Was `usize::max_value()` — same value as INVALID, spelled
            // consistently with the rest of the file.)
            exchange_usize(unsafe { &(*next).dequeuer }, INVALID, thread);
        }
        self.close(thread, read, next);
    }
    /// Dequeues one item for dequeuer slot `thread`, or reports why the queue
    /// is empty.
    fn consume(&self, thread: usize) -> Result<T, ConsumeError> {
        let old = self.deqold[thread].load(SeqCst);
        let req = self.deqreq[thread].load(SeqCst);
        // Announce a pending request: assign() treats deqreq == deqold as
        // "this slot wants a node".
        self.deqold[thread].store(req, SeqCst);
        for _ in 0..self.deqreq.len() {
            // A helper satisfied our request by replacing deqreq.
            if self.deqreq[thread].load(SeqCst) != req {
                break;
            }
            // Hazard-protect the head, then re-validate it.
            let read = self.pointers.mark_ptr(thread, READ, self.read.load(SeqCst));
            if read != self.read.load(SeqCst) {
                continue;
            }
            if read == self.write.load(SeqCst) {
                // Queue looks empty: withdraw the request (helping first).
                self.rollback(thread, old, req);
                if self.deqreq[thread].load(SeqCst) != req {
                    // A grant raced in after all; record it and take the item.
                    self.deqold[thread].store(req, Relaxed);
                    break;
                }
                self.pointers.clear(thread, READ);
                self.pointers.clear(thread, NEXT);
                self.pointers.clear(thread, DEQUEUE);
                if self.producers.load(Acquire) == 0 {
                    return Err(ConsumeError::Disconnected);
                } else {
                    return Err(ConsumeError::Empty);
                }
            }
            let next = self.pointers.mark_ptr(thread, NEXT, unsafe { (*read).next.load(SeqCst) });
            if read != self.read.load(SeqCst) {
                continue;
            }
            // Grant the successor to some requesting slot and advance `read`.
            if self.assign(read, next) {
                self.close(thread, read, next);
            }
        }
        let node = self.deqreq[thread].load(SeqCst);
        // Help advance `read` past our granted node if that step is pending.
        let read = self.pointers.mark_ptr(thread, READ, self.read.load(SeqCst));
        let next = unsafe { (*read).next.load(SeqCst) };
        if read == self.read.load(SeqCst) && node == next {
            exchange_ptr(&self.read, read, next);
        }
        self.pointers.clear(thread, READ);
        self.pointers.clear(thread, NEXT);
        self.pointers.clear(thread, DEQUEUE);
        // The node granted two requests ago can no longer be reached; retire
        // it through the hazard-pointer domain.
        self.pointers.retire(thread, old);
        // NOTE(review): `node` was granted exclusively to this slot, so the
        // mutable access through the raw pointer is presumably unaliased.
        Ok(unsafe { (*node).item.take().unwrap() })
    }
}
impl<T> Drop for Queue<T> {
    /// Drains remaining items, then frees the sentinel and the per-consumer
    /// bookkeeping nodes. Runs with exclusive access (all handles dropped).
    fn drop(&mut self) {
        // Drain BEFORE freeing anything: consume(0) dereferences `read`,
        // which still equals `sentinel` whenever nothing was ever dequeued.
        // The original freed the sentinel first — a use-after-free.
        while self.consume(0).is_ok() { }
        // The sentinel never flows through deqreq/deqold, so the drain above
        // cannot have retired it; release it explicitly now.
        unsafe { BoxMemory.deallocate(self.sentinel); }
        for (req, old) in self.deqreq.iter().zip(self.deqold.iter()) {
            unsafe { BoxMemory.deallocate(req.load(Relaxed)); }
            unsafe { BoxMemory.deallocate(old.load(Relaxed)); }
        }
    }
}
unsafe impl<T> Sync for Queue<T> where T: Send { }
/// Best-effort CAS: installs `new` only if `atomic` still holds `current`.
/// The outcome is deliberately ignored — callers re-read the atomic and
/// retry or move on, so losing the race is harmless.
fn exchange_ptr<T>(atomic: &AtomicPtr<T>, current: *mut T, new: *mut T) {
    if atomic.compare_exchange(current, new, SeqCst, SeqCst).is_ok() {
        // Won the race to publish `new`; nothing further to do.
    }
}
/// Best-effort CAS on a usize: installs `new` only if `atomic` still holds
/// `current`; the outcome is deliberately ignored (callers re-check).
fn exchange_usize(atomic: &AtomicUsize, current: usize, new: usize) {
    if atomic.compare_exchange(current, new, SeqCst, SeqCst).is_ok() {
        // CAS succeeded; nothing further to do.
    }
}
pub fn channel<T>(producers: usize, consumers: usize) -> (Producer<T>, Consumer<T>) {
let queue = Queue::new(producers + 1, consumers + 1);
(Producer(0, queue.clone()), Consumer(0, queue))
}