#![deny(warnings, missing_docs)]
pub use std::sync::mpsc::{SendError, TrySendError};
pub use std::sync::mpsc::{TryRecvError, RecvError, RecvTimeoutError};
use std::{mem, ops, ptr, usize};
use std::sync::{Arc, Mutex, MutexGuard, Condvar};
use std::sync::atomic::{AtomicUsize, AtomicBool, Ordering};
use std::time::{Duration, Instant};
/// The sending half of a channel.
///
/// Cloneable, so multiple producers can share one queue; the channel is
/// closed automatically once the last `Sender` is dropped.
pub struct Sender<T> {
    inner: Arc<Inner<T>>,
}
/// The receiving half of a channel.
///
/// Cloneable, so multiple consumers can share one queue; the channel is
/// closed automatically once the last `Receiver` is dropped.
pub struct Receiver<T> {
    inner: Arc<Inner<T>>,
}
/// State shared by every `Sender`/`Receiver` handle: a singly linked queue
/// with a sentinel node, guarded by separate locks at each end (two-lock
/// queue style), so one producer and one consumer can proceed concurrently.
struct Inner<T> {
    // Maximum queue length; `usize::MAX` for `unbounded()`.
    capacity: usize,
    // Current number of queued elements; read/updated with atomics so both
    // ends can consult it without taking the other end's lock.
    len: AtomicUsize,
    // Cleared once by `close()`; checked by every (try_/timeout_) operation.
    is_open: AtomicBool,
    // Pointer to the sentinel node in front of the first element;
    // receivers lock this.
    head: Mutex<NodePtr<T>>,
    // Receivers block here (paired with `head`) while the queue is empty.
    not_empty: Condvar,
    // Count of live `Sender` handles; the last one to drop closes the channel.
    num_tx: AtomicUsize,
    // Pointer to the last node in the queue; senders lock this.
    tail: Mutex<NodePtr<T>>,
    // Senders block here (paired with `tail`) while the queue is full.
    not_full: Condvar,
    // Count of live `Receiver` handles; the last one to drop closes the channel.
    num_rx: AtomicUsize,
}
/// Error returned by `Sender::send_timeout`.
///
/// Either way the unsent value is handed back to the caller. Derives mirror
/// `std::sync::mpsc::SendTimeoutError` so callers can debug-print, compare,
/// and clone the error (each impl exists only when `T` supports it, so the
/// derives are backward-compatible).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SendTimeoutError<T> {
    /// The channel was closed before the value could be sent.
    Disconnected(T),
    /// The timeout elapsed while the channel remained full.
    Timeout(T),
}
/// Create a bounded channel that can hold at most `capacity` elements,
/// returning the producer and consumer handles to its shared queue.
pub fn channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    let shared = Arc::new(Inner::with_capacity(capacity));
    let rx = Receiver { inner: shared.clone() };
    let tx = Sender { inner: shared };
    (tx, rx)
}
/// Create a channel with no practical capacity limit
/// (a bounded channel of `usize::MAX` elements).
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
    channel(usize::MAX)
}
impl<T> Sender<T> {
    /// Send `t`, blocking while the channel is at capacity.
    ///
    /// Returns `Err(SendError(t))`, handing the value back, if the channel
    /// is closed.
    pub fn send(&self, t: T) -> Result<(), SendError<T>> {
        self.inner.push(t)
    }
    /// Like `send`, but gives up once `timeout` has elapsed.
    pub fn send_timeout(&self, t: T, timeout: Duration) -> Result<(), SendTimeoutError<T>> {
        self.inner.push_timeout(t, timeout)
    }
    /// Non-blocking send: fails immediately with `Full` or `Disconnected`.
    pub fn try_send(&self, t: T) -> Result<(), TrySendError<T>> {
        self.inner.try_push(t)
    }
    /// Number of elements currently queued in the channel.
    pub fn len(&self) -> usize {
        self.inner.len()
    }
    /// Close the channel, waking every blocked sender and receiver.
    pub fn close(&self) {
        self.inner.close();
    }
    /// Whether the channel is still open.
    pub fn is_open(&self) -> bool {
        self.inner.is_open.load(Ordering::SeqCst)
    }
    /// Maximum number of queued elements (`usize::MAX` for `unbounded`).
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Sender<T> {
let inner = self.inner.clone();
self.inner.num_tx.fetch_add(1, Ordering::SeqCst);
Sender { inner: inner }
}
}
impl<T> Drop for Sender<T> {
    fn drop(&mut self) {
        // `fetch_sub` returns the count *before* the decrement, so a result
        // of 1 means this was the last live sender: close the channel so
        // blocked receivers wake up and observe the disconnect.
        let previous = self.inner.num_tx.fetch_sub(1, Ordering::SeqCst);
        if previous == 1 {
            self.inner.close();
        }
    }
}
impl<T> Receiver<T> {
    /// Receive a value, blocking while the channel is empty.
    ///
    /// Returns `Err(RecvError)` once the channel is closed and drained.
    pub fn recv(&self) -> Result<T, RecvError> {
        self.inner.pop()
    }
    /// Like `recv`, but gives up once `timeout` has elapsed.
    pub fn recv_timeout(&self, timeout: Duration) -> Result<T, RecvTimeoutError> {
        self.inner.pop_timeout(timeout)
    }
    /// Non-blocking receive: fails immediately with `Empty` or `Disconnected`.
    pub fn try_recv(&self) -> Result<T, TryRecvError> {
        self.inner.try_pop()
    }
    /// Number of elements currently queued in the channel.
    pub fn len(&self) -> usize {
        self.inner.len()
    }
    /// Close the channel, waking every blocked sender and receiver.
    pub fn close(&self) {
        self.inner.close();
    }
    /// Whether the channel is still open.
    pub fn is_open(&self) -> bool {
        self.inner.is_open.load(Ordering::SeqCst)
    }
    /// Maximum number of queued elements (`usize::MAX` for `unbounded`).
    pub fn capacity(&self) -> usize {
        self.inner.capacity
    }
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Receiver<T> {
let inner = self.inner.clone();
self.inner.num_rx.fetch_add(1, Ordering::SeqCst);
Receiver { inner: inner }
}
}
impl<T> Drop for Receiver<T> {
    fn drop(&mut self) {
        // `fetch_sub` returns the count *before* the decrement, so a result
        // of 1 means this was the last live receiver: close the channel so
        // blocked senders wake up and observe the disconnect.
        let previous = self.inner.num_rx.fetch_sub(1, Ordering::SeqCst);
        if previous == 1 {
            self.inner.close();
        }
    }
}
impl<T> Inner<T> {
    /// Create the shared channel state for the given capacity.
    ///
    /// The queue starts with a single value-less sentinel node that both
    /// `head` and `tail` point at; real elements are linked after it.
    fn with_capacity(capacity: usize) -> Inner<T> {
        let head = NodePtr::new(Box::new(Node::empty()));
        Inner {
            capacity: capacity,
            len: AtomicUsize::new(0),
            is_open: AtomicBool::new(true),
            head: Mutex::new(head),
            not_empty: Condvar::new(),
            num_tx: AtomicUsize::new(1),
            tail: Mutex::new(head),
            not_full: Condvar::new(),
            num_rx: AtomicUsize::new(1),
        }
    }
    /// Current number of queued elements.
    fn len(&self) -> usize {
        self.len.load(Ordering::SeqCst)
    }
    /// Blocking push: waits on `not_full` while the queue is at capacity.
    /// The element is handed back in the error if the channel closes.
    fn push(&self, el: T) -> Result<(), SendError<T>> {
        let node = Box::new(Node::new(el));
        let mut tail = self.tail.lock().ok().expect("something went wrong");
        while self.len.load(Ordering::Acquire) == self.capacity {
            // Bail out (returning the value) if the channel closed while full.
            if !self.is_open.load(Ordering::Relaxed) {
                return Err(SendError(node.into_inner()));
            }
            tail = self.not_full.wait(tail).ok().expect("something went wrong");
        }
        // Re-check after the wait loop: `close()` may have raced with us.
        if !self.is_open.load(Ordering::Relaxed) {
            return Err(SendError(node.into_inner()));
        }
        self.enqueue(node, tail);
        Ok(())
    }
    /// Non-blocking push; fails with `Disconnected` or `Full`, returning
    /// the element either way.
    fn try_push(&self, el: T) -> Result<(), TrySendError<T>> {
        let node = Box::new(Node::new(el));
        let tail = self.tail.lock().ok().expect("something went wrong");
        if !self.is_open.load(Ordering::Relaxed) {
            return Err(TrySendError::Disconnected(node.into_inner()));
        }
        if self.len.load(Ordering::Acquire) == self.capacity {
            return Err(TrySendError::Full(node.into_inner()));
        }
        self.enqueue(node, tail);
        Ok(())
    }
    /// Push with a deadline: waits on `not_full` until there is room, the
    /// channel closes, or `dur` elapses (measured from entry).
    fn push_timeout(&self, el: T, dur: Duration) -> Result<(), SendTimeoutError<T>> {
        let node = Box::new(Node::new(el));
        let mut tail = self.tail.lock().ok().expect("something went wrong");
        if self.len.load(Ordering::Acquire) == self.capacity {
            let mut now = Instant::now();
            let deadline = now + dur;
            loop {
                if now >= deadline {
                    return Err(SendTimeoutError::Timeout(node.into_inner()));
                }
                if !self.is_open.load(Ordering::Relaxed) {
                    return Err(SendTimeoutError::Disconnected(node.into_inner()));
                }
                // Wait only for the time remaining until the deadline; the
                // wait can also wake spuriously, hence the re-check below.
                tail = self.not_full
                    .wait_timeout(tail, deadline.duration_since(now))
                    .ok()
                    .expect("something went wrong")
                    .0;
                if self.len.load(Ordering::Acquire) != self.capacity {
                    break;
                }
                now = Instant::now();
            }
        }
        if !self.is_open.load(Ordering::Relaxed) {
            return Err(SendTimeoutError::Disconnected(node.into_inner()));
        }
        self.enqueue(node, tail);
        Ok(())
    }
    /// Link `el` after the current tail. Takes the caller's tail guard so
    /// the capacity check and the insertion happen under one lock.
    fn enqueue(&self, el: Box<Node<T>>, mut tail: MutexGuard<NodePtr<T>>) {
        let ptr = NodePtr::new(el);
        // Hook the new node after the old tail, then advance the tail pointer.
        tail.next = ptr;
        *tail = ptr;
        // `fetch_add` returns the length *before* the increment.
        let len = self.len.fetch_add(1, Ordering::Release);
        if len + 1 < self.capacity {
            // Still room after this insert: pass the baton to another sender.
            self.not_full.notify_one();
        }
        drop(tail);
        if len == 0 {
            // Queue went empty -> non-empty: wake a receiver. Taking the head
            // lock first avoids a lost wakeup against a receiver that has
            // checked `len` but not yet started waiting on `not_empty`.
            let _l = self.head
                .lock()
                .ok()
                .expect("something went wrong");
            self.not_empty.notify_one();
        }
    }
    /// Blocking pop: waits on `not_empty` while the queue is empty.
    fn pop(&self) -> Result<T, RecvError> {
        let mut head = self.head.lock().ok().expect("something went wrong");
        while self.len.load(Ordering::Acquire) == 0 {
            // Closed and drained: report disconnection.
            if !self.is_open.load(Ordering::Relaxed) {
                return Err(RecvError);
            }
            head = self.not_empty.wait(head).ok().expect("something went wrong");
        }
        Ok(self.dequeue(head))
    }
    /// Non-blocking pop; distinguishes a merely empty queue from a closed one.
    fn try_pop(&self) -> Result<T, TryRecvError> {
        let head = self.head.lock().ok().expect("something went wrong");
        if self.len.load(Ordering::Acquire) == 0 {
            if !self.is_open.load(Ordering::Relaxed) {
                return Err(TryRecvError::Disconnected);
            } else {
                return Err(TryRecvError::Empty);
            }
        }
        Ok(self.dequeue(head))
    }
    /// Pop with a deadline: waits on `not_empty` until an element arrives,
    /// the channel closes, or `dur` elapses (measured from entry).
    fn pop_timeout(&self, dur: Duration) -> Result<T, RecvTimeoutError> {
        let mut head = self.head.lock().ok().expect("something went wrong");
        if self.len.load(Ordering::Acquire) == 0 {
            let mut now = Instant::now();
            let deadline = now + dur;
            loop {
                if now >= deadline {
                    return Err(RecvTimeoutError::Timeout);
                }
                if !self.is_open.load(Ordering::Relaxed) {
                    return Err(RecvTimeoutError::Disconnected);
                }
                // Wait only for the remaining time; re-check `len` afterwards
                // because the wait can wake spuriously or time out.
                head = self.not_empty
                    .wait_timeout(head, deadline.duration_since(now))
                    .ok()
                    .expect("something went wrong")
                    .0;
                if self.len.load(Ordering::Acquire) != 0 {
                    break;
                }
                now = Instant::now();
            }
        }
        Ok(self.dequeue(head))
    }
    /// Unlink and return the first element. Takes the caller's head guard so
    /// the emptiness check and the removal happen under one lock.
    fn dequeue(&self, mut head: MutexGuard<NodePtr<T>>) -> T {
        let h = *head;
        let mut first = h.next;
        // `first` becomes the new sentinel once its item is taken out below.
        *head = first;
        let val = first.item.take().expect("item already consumed");
        // `fetch_sub` returns the length *before* the decrement.
        let len = self.len.fetch_sub(1, Ordering::Release);
        if len > 1 {
            // More elements remain: pass the baton to another receiver.
            self.not_empty.notify_one();
        }
        drop(head);
        // Free the old sentinel now that nothing points at it anymore.
        h.free();
        if len == self.capacity {
            // Queue went full -> not-full: wake a sender. Taking the tail
            // lock first avoids a lost wakeup against a sender that has
            // checked `len` but not yet started waiting on `not_full`.
            let _l = self.tail
                .lock()
                .ok()
                .expect("something went wrong");
            self.not_full.notify_one();
        }
        val
    }
    /// Close the channel (idempotent) and wake everything blocked on it.
    fn close(&self) {
        // `swap` returns the previous state, so only the first close notifies.
        if self.is_open.swap(false, Ordering::SeqCst) {
            self.notify_tx();
            self.notify_rx();
        }
    }
    /// Wake all receivers blocked on `not_empty` (they will observe the
    /// closed flag and return `Disconnected`).
    fn notify_tx(&self) {
        let _lock = self.head.lock().expect("something went wrong");
        self.not_empty.notify_all();
    }
    /// Wake all senders blocked on `not_full` (they will observe the
    /// closed flag and return `Disconnected`).
    fn notify_rx(&self) {
        let _lock = self.tail.lock().expect("something went wrong");
        self.not_full.notify_all();
    }
}
impl<T> Drop for Inner<T> {
    /// Drain the queue so every remaining element's destructor runs, then
    /// release the last node.
    fn drop(&mut self) {
        while let Ok(_) = self.try_pop() {
        }
        // `dequeue` frees the *previous* sentinel on each pop, so after the
        // drain exactly one node (the current sentinel, allocated either by
        // `with_capacity` or by the last `enqueue`) is still live. Without
        // freeing it here, every channel would leak one `Node` allocation.
        let sentinel = *self.head.lock().ok().expect("something went wrong");
        sentinel.free();
    }
}
/// One queue link: an optional payload plus a raw pointer to the next node.
/// `item` is `None` only for the sentinel that sits in front of the queue.
struct Node<T> {
    next: NodePtr<T>,
    item: Option<T>,
}
impl<T> Node<T> {
    /// Build a node carrying `val` with no successor yet.
    fn new(val: T) -> Node<T> {
        Node {
            item: Some(val),
            next: NodePtr::null(),
        }
    }
    /// Build a value-less sentinel node.
    fn empty() -> Node<T> {
        Node {
            item: None,
            next: NodePtr::null(),
        }
    }
    /// Consume the node and yield its payload; panics on a sentinel.
    fn into_inner(self) -> T {
        self.item.unwrap()
    }
}
/// Raw, unchecked pointer to a heap-allocated `Node`.
///
/// Ownership is managed manually: `NodePtr::new` leaks a `Box` into the
/// pointer and `NodePtr::free` reclaims it; `Copy` means double-free and
/// use-after-free are the caller's responsibility.
struct NodePtr<T> {
    ptr: *mut Node<T>,
}
impl<T> NodePtr<T> {
    /// Leak `node` and keep a raw pointer to it.
    ///
    /// `Box::into_raw` is the guaranteed conversion for this; transmuting a
    /// `Box` relies on unspecified layout and is exactly what the docs for
    /// `into_raw` exist to replace.
    fn new(node: Box<Node<T>>) -> NodePtr<T> {
        NodePtr { ptr: Box::into_raw(node) }
    }
    /// The null pointer, used as the "no successor" marker.
    fn null() -> NodePtr<T> {
        NodePtr { ptr: ptr::null_mut() }
    }
    /// Reclaim and drop the pointed-to node.
    ///
    /// SAFETY contract (unchanged from the original): `self.ptr` must have
    /// come from `NodePtr::new` and must not already have been freed;
    /// calling this on `NodePtr::null()` is undefined behavior.
    fn free(self) {
        let _owned: Box<Node<T>> = unsafe { Box::from_raw(self.ptr) };
    }
}
impl<T> ops::Deref for NodePtr<T> {
    type Target = Node<T>;
    // Dereference the raw pointer as a shared `Node` reference. Callers must
    // guarantee the pointer is non-null and still live: dereferencing a
    // `NodePtr::null()` (e.g. the sentinel's `next` when the queue is empty)
    // would be undefined behavior.
    fn deref(&self) -> &Node<T> {
        unsafe { mem::transmute(self.ptr) }
    }
}
impl<T> ops::DerefMut for NodePtr<T> {
    // Same contract as `deref`, plus the caller must ensure no aliasing
    // access exists (the head/tail mutexes provide this in practice).
    fn deref_mut(&mut self) -> &mut Node<T> {
        unsafe { mem::transmute(self.ptr) }
    }
}
// `NodePtr` is a plain machine pointer, so duplicating it is a bitwise copy;
// note that copies share ownership of the pointed-to node.
impl<T> Clone for NodePtr<T> {
    fn clone(&self) -> NodePtr<T> {
        NodePtr { ptr: self.ptr }
    }
}
impl<T> Copy for NodePtr<T> {}
// SAFETY: a `NodePtr` may move between threads when the payload is `Send`;
// all shared access to the nodes is serialized by the head/tail mutexes.
unsafe impl<T> Send for NodePtr<T> where T: Send {}
#[cfg(test)]
mod test {
    use super::*;
    use std::thread;
    use std::time::{Duration, Instant};
    // Basic send/recv round trip on one thread, checking `len` bookkeeping
    // and that an empty open channel reports `TryRecvError::Empty`.
    #[test]
    fn single_thread_send_recv() {
        let (tx, rx) = channel(1024);
        assert_eq!(0, tx.len());
        assert_eq!(0, rx.len());
        tx.send("hello").unwrap();
        assert_eq!(1, tx.len());
        assert_eq!(1, rx.len());
        assert_eq!("hello", rx.recv().unwrap());
        assert_eq!(0, tx.len());
        assert_eq!(0, rx.len());
        assert_eq!(TryRecvError::Empty, rx.try_recv().unwrap_err());
    }
    // `send_timeout` on a full channel should fail after roughly the
    // requested duration (between 1x and 2x).
    #[test]
    fn single_thread_send_timeout() {
        let (tx, _rx) = channel(1);
        tx.try_send("hello").unwrap();
        let now = Instant::now();
        let dur = Duration::from_millis(200);
        assert!(tx.send_timeout("world", dur).is_err());
        let act = now.elapsed();
        assert!(act >= dur);
        assert!(act < dur * 2);
    }
    // `recv_timeout` on an empty channel should fail after roughly the
    // requested duration (between 1x and 2x).
    #[test]
    fn single_thread_recv_timeout() {
        let (_tx, rx) = channel::<u32>(1024);
        let now = Instant::now();
        let dur = Duration::from_millis(200);
        assert!(rx.recv_timeout(dur).is_err());
        let act = now.elapsed();
        assert!(act >= dur);
        assert!(act < dur * 2);
    }
    // One producer thread, one consumer: values arrive in order, and
    // dropping the producer disconnects the channel.
    #[test]
    fn single_consumer_single_producer() {
        let (tx, rx) = channel(1024);
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));
            for i in 0..10_000 {
                tx.send(i).unwrap();
            }
        });
        for i in 0..10_000 {
            assert_eq!(i, rx.recv().unwrap());
        }
        assert!(rx.recv().is_err());
    }
    // Ten producers, one consumer: per-producer ordering is preserved even
    // though the global interleaving is arbitrary.
    #[test]
    fn single_consumer_multi_producer() {
        let (tx, rx) = channel(1024);
        for t in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for i in 0..10_000 {
                    tx.send((t, i)).unwrap();
                }
            });
        }
        // Drop the original handle so the channel closes when the spawned
        // producers finish.
        drop(tx);
        let mut vals = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        for _ in 0..10 * 10_000 {
            let (t, v) = rx.recv().unwrap();
            assert_eq!(vals[t], v);
            vals[t] += 1;
        }
        assert!(rx.recv().is_err());
        for i in 0..10 {
            assert_eq!(vals[i], 10_000);
        }
    }
    // Five producers and five consumers: each consumer sees per-producer
    // values strictly increasing, and merging all consumers' logs recovers
    // every value exactly once.
    #[test]
    fn multi_consumer_multi_producer() {
        let (tx, rx) = channel(1024);
        let (res_tx, res_rx) = channel(1024);
        const PER_PRODUCER: usize = 10_000;
        for t in 0..5 {
            let tx = tx.clone();
            thread::spawn(move || {
                for i in 1..PER_PRODUCER {
                    tx.send((t, i)).unwrap();
                    if i % 10 == 0 {
                        thread::yield_now();
                    }
                }
            });
        }
        drop(tx);
        for _ in 0..5 {
            let rx = rx.clone();
            let res_tx = res_tx.clone();
            thread::spawn(move || {
                let mut vals = vec![];
                let mut per_producer = [0, 0, 0, 0, 0];
                loop {
                    let (t, v) = match rx.recv() {
                        Ok(v) => v,
                        _ => break,
                    };
                    // Values from one producer must reach any single
                    // consumer in increasing order.
                    assert!(per_producer[t] < v);
                    per_producer[t] = v;
                    vals.push((t, v));
                    if v % 10 == 0 {
                        thread::yield_now();
                    }
                }
                res_tx.send(vals).unwrap();
            });
        }
        drop(rx);
        drop(res_tx);
        let mut all_vals = vec![];
        for _ in 0..5 {
            let vals = res_rx.recv().unwrap();
            for &v in vals.iter() {
                all_vals.push(v);
            }
        }
        all_vals.sort();
        // After sorting, each producer's values must be the exact run 1..PER_PRODUCER.
        let mut per_producer = [1, 1, 1, 1, 1];
        for &(t, v) in all_vals.iter() {
            assert_eq!(per_producer[t], v);
            per_producer[t] += 1;
        }
        for &v in per_producer.iter() {
            assert_eq!(PER_PRODUCER, v);
        }
    }
    // Capacity enforcement: the ninth `try_send` fails with `Full` until a
    // slot is freed, and FIFO order is maintained throughout.
    #[test]
    fn queue_with_capacity() {
        let (tx, rx) = channel(8);
        for i in 0..8 {
            assert!(tx.try_send(i).is_ok());
        }
        assert_eq!(TrySendError::Full(8), tx.try_send(8).unwrap_err());
        assert_eq!(0, rx.try_recv().unwrap());
        assert!(tx.try_send(8).is_ok());
        for i in 1..9 {
            assert_eq!(i, rx.try_recv().unwrap());
        }
    }
    // Stress blocking sends against a tiny capacity: all values must still
    // get through, then the drained channel reports disconnection.
    #[test]
    fn multi_producer_at_capacity() {
        let (tx, rx) = channel(8);
        for _ in 0..8 {
            let tx = tx.clone();
            thread::spawn(move || {
                for i in 0..1_000 {
                    tx.send(i).unwrap();
                }
            });
        }
        drop(tx);
        for _ in 0..8 * 1_000 {
            rx.recv().unwrap();
        }
        rx.recv().unwrap_err();
    }
    // Explicit close from the sender side: queued values remain readable,
    // then both ends observe the closed channel.
    #[test]
    fn test_tx_shutdown() {
        let (tx, rx) = channel(1024);
        {
            let tx = tx.clone();
            thread::spawn(move || {
                tx.send("hello").unwrap();
                tx.close();
            });
        }
        assert_eq!("hello", rx.recv().unwrap());
        assert!(rx.recv().is_err());
        assert!(tx.send("goodbye").is_err());
    }
    // Explicit close from the receiver side behaves the same way.
    #[test]
    fn test_rx_shutdown() {
        let (tx, rx) = channel(1024);
        {
            let tx = tx.clone();
            let rx = rx.clone();
            thread::spawn(move || {
                tx.send("hello").unwrap();
                rx.close();
            });
        }
        assert_eq!("hello", rx.recv().unwrap());
        assert!(rx.recv().is_err());
        assert!(tx.send("goodbye").is_err());
    }
}