mod batch;
mod consumer;
mod producer;
mod shared;
pub use batch::MpscBatchSlots;
pub use consumer::Consumer;
pub use producer::Producer;
use crate::ringbuffer::RingBuffer;
use cpu::{CachePadded, Cursor, SpinLoopHintWait, WaitStrategy};
use shared::{Shared, BITS_PER_WORD};
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize};
use std::sync::Arc;
/// Creates a multi-producer, single-consumer channel backed by a ring buffer of
/// `capacity` slots, using the default [`SpinLoopHintWait`] wait strategy.
///
/// Shorthand for `channel_with_wait(capacity, SpinLoopHintWait)`.
pub fn channel<T>(capacity: usize) -> (Producer<T>, Consumer<T, SpinLoopHintWait>) {
    let default_wait = SpinLoopHintWait;
    channel_with_wait::<T, SpinLoopHintWait>(capacity, default_wait)
}
/// Creates a multi-producer, single-consumer channel backed by a ring buffer of
/// `capacity` slots, handing `wait_strategy` to the returned [`Consumer`].
///
/// Slots are created with `RingBuffer::new` (no element factory). The consumer
/// starts reading at sequence 0, and exactly one [`Producer`] is returned; clone
/// it to obtain more.
pub fn channel_with_wait<T, W: WaitStrategy>(
    capacity: usize,
    wait_strategy: W,
) -> (Producer<T>, Consumer<T, W>) {
    // No factory: the type parameter only pins `F` so `create_shared` can infer it.
    let no_factory: Option<fn() -> T> = None;
    let (shared, strategy) = create_shared(capacity, wait_strategy, no_factory);

    let producer_handle = Arc::clone(&shared);
    let consumer = Consumer {
        shared,
        next_sequence: 0,
        _wait_strategy: strategy,
    };
    (
        Producer {
            shared: producer_handle,
        },
        consumer,
    )
}
/// Creates a multi-producer, single-consumer channel whose ring buffer is built
/// with `RingBuffer::with_factory(capacity, factory)`.
///
/// NOTE(review): `factory` is presumably called to pre-populate each slot;
/// confirm against `RingBuffer::with_factory`.
///
/// `wait_strategy` is stored in the returned [`Consumer`], which starts reading
/// at sequence 0. Exactly one [`Producer`] is returned; clone it for more.
pub fn channel_with_factory<T, F, W>(
    capacity: usize,
    factory: F,
    wait_strategy: W,
) -> (Producer<T>, Consumer<T, W>)
where
    F: Fn() -> T,
    W: WaitStrategy,
{
    let (shared, strategy) = create_shared(capacity, wait_strategy, Some(factory));

    let producer_handle = Arc::clone(&shared);
    let consumer = Consumer {
        shared,
        next_sequence: 0,
        _wait_strategy: strategy,
    };
    (
        Producer {
            shared: producer_handle,
        },
        consumer,
    )
}
/// Builds the state shared by every [`Producer`] and the [`Consumer`] of one channel.
///
/// * `capacity` – forwarded unchanged to the ring buffer and stored in `Shared`.
/// * `wait_strategy` – not used here; it is returned as-is so the caller can move
///   it into the consumer.
/// * `factory` – when `Some`, the buffer is built with `RingBuffer::with_factory`;
///   when `None`, with `RingBuffer::new`.
///
/// One cache-padded availability word is allocated per `BITS_PER_WORD` slots
/// (rounded up), each initialised with every bit set.
///
/// NOTE(review): presumably the all-ones value encodes the initial "not yet
/// published" state of the availability bitmap; confirm against `shared.rs`.
///
/// The producer count starts at 1, matching the single `Producer` handed out by
/// the public constructors. The channel starts open.
fn create_shared<T, F, W>(
    capacity: usize,
    wait_strategy: W,
    factory: Option<F>,
) -> (Arc<Shared<T>>, W)
where
    F: Fn() -> T,
    W: WaitStrategy,
{
    // `div_ceil` rounds up without forming `capacity + BITS_PER_WORD - 1`, which
    // would overflow for capacities close to `usize::MAX`.
    let availability_words = capacity.div_ceil(BITS_PER_WORD);
    let availability: Box<[CachePadded<AtomicU64>]> = (0..availability_words)
        .map(|_| CachePadded::new(AtomicU64::new(!0u64)))
        .collect();

    let buffer = match factory {
        Some(f) => RingBuffer::with_factory(capacity, f),
        None => RingBuffer::new(capacity),
    };

    let shared = Arc::new(Shared {
        buffer,
        claim_cursor: Cursor::new(),
        consumer_cursor: Cursor::new(),
        availability,
        availability_words,
        producer_count: CachePadded::new(AtomicUsize::new(1)),
        closed: CachePadded::new(AtomicBool::new(false)),
        capacity,
    });
    (shared, wait_strategy)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{RecvError, SendError, TryRecvError};
    use std::collections::HashSet;
    use std::thread;

    #[test]
    fn test_basic_send_recv() {
        let (tx, mut rx) = channel::<u64>(64);
        tx.send(42).unwrap();
        assert_eq!(rx.recv().unwrap(), 42);
    }

    #[test]
    fn test_multiple_sends() {
        let (tx, mut rx) = channel::<u64>(64);
        for i in 0..10 {
            tx.send(i).unwrap();
        }
        // Sends from a single producer are received in order.
        for i in 0..10 {
            assert_eq!(rx.recv().unwrap(), i);
        }
    }

    #[test]
    fn test_try_recv_empty() {
        let (_tx, mut rx) = channel::<u64>(64);
        assert!(
            matches!(rx.try_recv(), Err(TryRecvError::Empty)),
            "expected Empty error"
        );
    }

    #[test]
    fn test_producer_dropped() {
        let (tx, mut rx) = channel::<u64>(64);
        drop(tx);
        assert!(matches!(rx.recv(), Err(RecvError)), "expected RecvError");
    }

    #[test]
    fn test_consumer_dropped() {
        let (tx, rx) = channel::<u64>(64);
        drop(rx);
        // The rejected value is handed back inside the error.
        assert!(
            matches!(tx.send(42), Err(SendError(42))),
            "expected SendError"
        );
    }

    #[test]
    fn test_multiple_producers() {
        let (tx, mut rx) = channel::<u64>(1024);
        let tx1 = tx.clone();
        let tx2 = tx.clone();
        let h1 = thread::spawn(move || {
            for i in 0..100 {
                tx1.send(i).unwrap();
            }
        });
        let h2 = thread::spawn(move || {
            for i in 100..200 {
                tx2.send(i).unwrap();
            }
        });
        h1.join().unwrap();
        h2.join().unwrap();
        drop(tx);

        // Interleaving across threads is unspecified, so only check the set of values.
        let mut received = HashSet::new();
        while let Ok(v) = rx.try_recv() {
            received.insert(v);
        }
        assert_eq!(received.len(), 200);
        for i in 0..200 {
            assert!(received.contains(&i), "missing {}", i);
        }
    }

    #[test]
    fn test_clone_producer() {
        let (tx, mut rx) = channel::<u64>(64);
        let tx2 = tx.clone();
        let tx3 = tx.clone();
        tx.send(1).unwrap();
        tx2.send(2).unwrap();
        tx3.send(3).unwrap();
        // All clones feed the same consumer. Sends made one after another on a
        // single thread keep their order.
        for expected in 1..=3 {
            assert_eq!(rx.recv().unwrap(), expected);
        }
    }

    #[test]
    fn test_peek() {
        let (tx, mut rx) = channel::<u64>(64);
        tx.send(42).unwrap();
        // Peeking must not consume: the value stays visible and is still received.
        assert_eq!(*rx.peek().unwrap(), 42);
        assert_eq!(*rx.peek().unwrap(), 42);
        assert_eq!(rx.recv().unwrap(), 42);
        assert!(
            matches!(rx.try_recv(), Err(TryRecvError::Empty)),
            "expected Empty error after draining"
        );
    }

    #[test]
    fn test_pending() {
        let (tx, rx) = channel::<u64>(64);
        assert_eq!(rx.pending(), 0);
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        assert_eq!(rx.pending(), 2);
    }

    #[test]
    fn test_close() {
        let (tx, mut rx) = channel::<u64>(64);
        tx.send(42).unwrap();
        tx.close();
        // Items sent before `close` are still delivered; only then does recv fail.
        assert_eq!(rx.recv().unwrap(), 42);
        assert!(matches!(rx.recv(), Err(RecvError)), "expected RecvError");
    }

    #[test]
    fn test_debug() {
        let (tx, rx) = channel::<u64>(64);
        assert!(!format!("{:?}", tx).is_empty());
        assert!(!format!("{:?}", rx).is_empty());
    }
}