use std::{
sync::Arc,
cell::UnsafeCell,
};
use crate::queue::{
waker::Checker,
unbounded::MpmcUnbounded,
bounded::{Bounded, MpmcBounded},
error::{SendError, RecvError, TryRecvError, TrySendError},
};
/// Creates a bounded MPMC channel and returns its `(sender, receiver)` halves.
///
/// `size` is handed unchanged to `MpmcBounded::new` (presumably the capacity;
/// the tests in this file also use `bounded(0)` — confirm the exact semantics
/// against `MpmcBounded`).
///
/// Both halves point at one queue shared through an `Arc`. Either half can be
/// cloned to add more producers or consumers.
#[inline]
pub fn bounded<T: Send>(size: u32) -> (BSender<T>, BReceiver<T>) {
    let shared = Arc::new(UnsafeCell::new(MpmcBounded::new(size)));
    let sender = BSender::new(Arc::clone(&shared));
    let receiver = BReceiver::new(shared);
    (sender, receiver)
}
/// Sending half of a channel created by [`bounded`].
///
/// Every clone refers to the same underlying `MpmcBounded` queue.
pub struct BSender<T> {
    inner: Arc<UnsafeCell<MpmcBounded<T>>>,
}

// SAFETY: a `BSender` reaches the queue only through the shared `Arc`, and this
// layer adds no locking. Moving or sharing it across threads is therefore sound
// only if `MpmcBounded<T>` synchronizes concurrent access itself.
// NOTE(review): that guarantee lives in `MpmcBounded`, which is not visible
// here — confirm it before relying on these impls.
unsafe impl<T: Send> Send for BSender<T> {}
unsafe impl<T: Send> Sync for BSender<T> {}

impl<T: Send> Clone for BSender<T> {
    /// Returns another handle to the same queue. Only the `Arc` reference count
    /// is bumped; the queue itself is not copied.
    #[inline]
    fn clone(&self) -> Self {
        BSender {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<T: Send> BSender<T> {
#[inline]
fn new(inner: Arc<UnsafeCell<MpmcBounded<T>>>) -> Self <> {
Self { inner }
}
#[inline]
pub fn try_send(&self, value: T) -> Result<(), TrySendError<T>> {
unsafe { (*self.inner.get()).try_send(value) }
}
#[inline]
pub fn send(&self, value: T) -> Result<(), SendError<T>> {
unsafe { (*self.inner.get()).send(value, (*self.inner.get()).cast()) }
}
#[inline]
pub fn length(&self) -> u32 {
unsafe { (*self.inner.get()).length() }
}
#[inline]
pub fn close(&self) {
unsafe { (*self.inner.get()).close() }
}
#[inline]
pub fn is_close(&self) -> bool {
unsafe { (*self.inner.get()).is_close() }
}
}
/// Receiving half of a channel created by [`bounded`].
///
/// Every clone refers to the same underlying `MpmcBounded` queue, so several
/// consumers may drain it.
pub struct BReceiver<T> {
    inner: Arc<UnsafeCell<MpmcBounded<T>>>,
}

// SAFETY: a `BReceiver` reaches the queue only through the shared `Arc`, and this
// layer adds no locking. Moving or sharing it across threads is therefore sound
// only if `MpmcBounded<T>` synchronizes concurrent access itself.
// NOTE(review): confirm that guarantee in `MpmcBounded`.
unsafe impl<T: Send> Send for BReceiver<T> {}
unsafe impl<T: Send> Sync for BReceiver<T> {}

impl<T: Send> Clone for BReceiver<T> {
    /// Returns another handle to the same queue. Only the `Arc` reference count
    /// is bumped.
    #[inline]
    fn clone(&self) -> Self {
        BReceiver {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<T: Send> BReceiver<T> {
#[inline]
fn new(inner: Arc<UnsafeCell<MpmcBounded<T>>>) -> Self <> {
Self { inner }
}
#[inline]
pub fn try_recv(&self) -> Result<T, TryRecvError> {
unsafe { (*self.inner.get()).try_recv() }
}
#[inline]
pub fn recv(&self) -> Result<T, RecvError> {
unsafe { (*self.inner.get()).recv((*self.inner.get()).cast()) }
}
#[inline]
pub fn length(&self) -> u32 {
unsafe { (*self.inner.get()).length() }
}
#[inline]
pub fn close(&self) {
unsafe { (*self.inner.get()).close() }
}
#[inline]
pub fn is_close(&self) -> bool {
unsafe { (*self.inner.get()).is_close() }
}
}
/// Creates an unbounded MPMC channel and returns its `(sender, receiver)` halves.
///
/// Both halves share a single `MpmcUnbounded::default()` queue through an `Arc`.
/// Clone either half to add producers or consumers.
#[inline]
pub fn unbounded<T: Send>() -> (USender<T>, UReceiver<T>) {
    let shared = Arc::new(UnsafeCell::new(MpmcUnbounded::default()));
    let sender = USender::new(Arc::clone(&shared));
    let receiver = UReceiver::new(shared);
    (sender, receiver)
}
/// Sending half of a channel created by [`unbounded`].
///
/// Every clone refers to the same underlying `MpmcUnbounded` queue.
pub struct USender<T> {
    inner: Arc<UnsafeCell<MpmcUnbounded<T>>>,
}

// SAFETY: a `USender` reaches the queue only through the shared `Arc`, and this
// layer adds no locking. Moving or sharing it across threads is therefore sound
// only if `MpmcUnbounded<T>` synchronizes concurrent access itself.
// NOTE(review): confirm that guarantee in `MpmcUnbounded`.
unsafe impl<T: Send> Send for USender<T> {}
unsafe impl<T: Send> Sync for USender<T> {}

impl<T: Send> Clone for USender<T> {
    /// Returns another handle to the same queue. Only the `Arc` reference count
    /// is bumped.
    #[inline]
    fn clone(&self) -> Self {
        USender {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<T: Send> USender<T> {
#[inline]
fn new(inner: Arc<UnsafeCell<MpmcUnbounded<T>>>) -> Self <> {
Self { inner }
}
#[inline]
pub fn send(&self, value: T) -> Result<(), SendError<T>> {
unsafe { (*self.inner.get()).send(value) }
}
#[inline]
pub fn close(&self) {
unsafe { (*self.inner.get()).close() }
}
}
/// Receiving half of a channel created by [`unbounded`].
///
/// Every clone refers to the same underlying `MpmcUnbounded` queue.
pub struct UReceiver<T> {
    inner: Arc<UnsafeCell<MpmcUnbounded<T>>>,
}

// SAFETY: a `UReceiver` reaches the queue only through the shared `Arc`, and this
// layer adds no locking. Moving or sharing it across threads is therefore sound
// only if `MpmcUnbounded<T>` synchronizes concurrent access itself.
// NOTE(review): confirm that guarantee in `MpmcUnbounded`.
unsafe impl<T: Send> Send for UReceiver<T> {}
unsafe impl<T: Send> Sync for UReceiver<T> {}

impl<T: Send> Clone for UReceiver<T> {
    /// Returns another handle to the same queue. Only the `Arc` reference count
    /// is bumped.
    #[inline]
    fn clone(&self) -> Self {
        UReceiver {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<T: Send> UReceiver<T> {
#[inline]
fn new(inner: Arc<UnsafeCell<MpmcUnbounded<T>>>) -> Self <> {
Self { inner }
}
#[inline]
pub fn recv(&self) -> Result<T, RecvError> {
unsafe { (*self.inner.get()).recv() }
}
#[inline]
pub fn close(&self) {
unsafe { (*self.inner.get()).close() }
}
}
/// Unit tests for the bounded and unbounded channel front-ends in this file.
///
/// Gated on `cfg(test)` so the helper functions and the `use` below are not
/// compiled into normal builds, where they would be unused or dead code.
#[cfg(test)]
mod test {
    use super::{bounded, unbounded, BReceiver, BSender, UReceiver, USender};

    /// Compile-time check that `T: Send`.
    fn is_send<T: Send>() {}

    /// Compile-time check that `T: Sync`.
    fn is_sync<T: Sync>() {}

    #[test]
    fn bounds() {
        is_send::<BSender<i32>>();
        is_send::<BReceiver<i32>>();
        // The handles also `unsafe impl Sync`; pin that too.
        is_sync::<BSender<i32>>();
        is_sync::<BReceiver<i32>>();
    }

    #[test]
    fn unbound() {
        is_send::<USender<i32>>();
        is_send::<UReceiver<i32>>();
        is_sync::<USender<i32>>();
        is_sync::<UReceiver<i32>>();
    }

    #[test]
    fn send_recv() {
        let (tx_b, rx_b) = bounded(3);
        tx_b.try_send(1).unwrap();
        assert_eq!(rx_b.try_recv().unwrap(), 1);
        let (tx_u, rx_u) = unbounded();
        tx_u.send(1).unwrap();
        assert_eq!(rx_u.recv().unwrap(), 1);
    }

    #[test]
    fn send_shared_recv() {
        let (tx_b1, rx_b) = bounded(4);
        let tx_b2 = tx_b1.clone();
        tx_b1.send(1).unwrap();
        assert_eq!(rx_b.recv().unwrap(), 1);
        tx_b2.send(2).unwrap();
        assert_eq!(rx_b.recv().unwrap(), 2);
        let (tx_u1, rx_u) = unbounded();
        let tx_u2 = tx_u1.clone();
        tx_u1.send(1).unwrap();
        assert_eq!(rx_u.recv().unwrap(), 1);
        tx_u2.send(2).unwrap();
        assert_eq!(rx_u.recv().unwrap(), 2);
    }

    #[test]
    fn send_recv_threads() {
        let (tx_b, rx_b) = bounded(4);
        let thread = std::thread::spawn(move || {
            tx_b.send(1).unwrap();
        });
        assert_eq!(rx_b.recv().unwrap(), 1);
        thread.join().unwrap();
        let (tx_u, rx_u) = unbounded();
        let thread = std::thread::spawn(move || {
            tx_u.send(1).unwrap();
        });
        assert_eq!(rx_u.recv().unwrap(), 1);
        thread.join().unwrap();
    }

    #[test]
    fn send_recv_threads_no_capacity() {
        let (tx, rx) = bounded(0);
        let thread = std::thread::spawn(move || {
            tx.send(1).unwrap();
            tx.send(2).unwrap();
        });
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert_eq!(rx.recv().unwrap(), 1);
        std::thread::sleep(std::time::Duration::from_millis(100));
        assert_eq!(rx.recv().unwrap(), 2);
        thread.join().unwrap();
    }

    #[test]
    fn send_close_gets_none() {
        let (tx_b, rx_b) = bounded::<i32>(1);
        let thread = std::thread::spawn(move || {
            assert!(rx_b.recv().is_err());
        });
        tx_b.close();
        thread.join().unwrap();
        let (tx_u, rx_u) = unbounded::<i32>();
        let thread = std::thread::spawn(move || {
            assert!(rx_u.recv().is_err());
        });
        tx_u.close();
        thread.join().unwrap();
    }

    #[test]
    fn mpsc_no_capacity() {
        let amt = 30000;
        let nthreads = 7;
        let (tx, rx) = bounded(0);
        let senders: Vec<_> = (0..nthreads)
            .map(|_| {
                let txc = tx.clone();
                std::thread::spawn(move || {
                    for _ in 0..amt {
                        assert_eq!(txc.send(1), Ok(()));
                    }
                })
            })
            .collect();
        for _ in 0..amt * nthreads {
            assert_eq!(rx.recv(), Ok(1));
        }
        // Join every sender so none outlives the test and any panic it hit is
        // reported here, instead of being lost in a detached thread.
        for sender in senders {
            sender.join().expect("oops! the child thread panicked");
        }
    }

    #[test]
    fn mpmc_no_capacity() {
        let amt = 50000;
        let nthreads_send = 7;
        let nthreads_recv = 7;
        let (tx, rx) = bounded(0);
        let mut receiving_threads = Vec::new();
        let mut sending_threads = Vec::new();
        for _ in 0..nthreads_send {
            let txc = tx.clone();
            let child = std::thread::spawn(move || {
                for _ in 0..amt {
                    assert_eq!(txc.send(1), Ok(()));
                }
            });
            sending_threads.push(child);
        }
        for _ in 0..nthreads_recv {
            let rxc = rx.clone();
            let thread = std::thread::spawn(move || {
                for _ in 0..amt {
                    assert_eq!(rxc.recv(), Ok(1));
                }
            });
            receiving_threads.push(thread);
        }
        for thread in sending_threads {
            thread.join().expect("oops! the child thread panicked");
        }
        for thread in receiving_threads {
            thread.join().expect("oops! the child thread panicked");
        }
    }
}