use crate::{
error::NoSpaceLeftError,
import::{Arc, AtomicUsize, Ordering, UnsafeCell},
spsc::is_power_of_two,
};
use crossbeam_utils::CachePadded;
use std::{fmt::Debug, mem::zeroed};
/// Creates a bounded single-producer, multi-consumer channel whose ring buffer
/// holds `capacity` values.
///
/// Receivers move values out of the buffer, so each value is handed to a
/// single receiver. Clone the returned [`Receiver`] to add consumers.
///
/// # Panics
///
/// Panics if `capacity` is zero or not a power of two.
pub fn spmc_once<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    // Zero is rejected explicitly rather than trusting the helper: the buffer
    // derives its index mask as `capacity - 1`, which would underflow and leave
    // a `usize::MAX` mask over an empty slice.
    if capacity == 0 || !is_power_of_two(capacity) {
        panic!("The SIZE must be a power of 2")
    }
    let chan = Arc::new(SpmcOnce::new(capacity));
    let receiver = Receiver::new(Arc::clone(&chan)).expect("Receiver::new always returns Some");
    let sender = Sender::new(chan);
    (sender, receiver)
}
/// One cell of the ring buffer; `None` means the cell currently holds no value.
#[derive(Debug)]
struct Slot<T> {
    value: UnsafeCell<Option<T>>,
}
impl<T> Slot<T> {
    /// Creates an empty slot.
    fn new() -> Self {
        let empty = UnsafeCell::new(Option::None);
        Slot { value: empty }
    }
}
/// State shared by the sender and all receivers: a ring of slots plus the
/// atomic cursors the channel halves coordinate through.
#[derive(Debug)]
struct SpmcOnce<T> {
    /// Ring storage; always exactly `mask + 1` slots long.
    mem: Box<[Slot<T>]>,
    /// `capacity - 1`; positions are reduced into `mem` with `position & mask`.
    mask: usize,
    /// Number of values the sender has published (stored by `Sender::try_send`).
    write: CachePadded<AtomicUsize>,
    /// Receiver-side claim counter, advanced in `Receiver::try_recv`.
    read_allocate: CachePadded<AtomicUsize>,
    /// Receiver-side read counter; the channel halves compare it against `write`.
    read: CachePadded<AtomicUsize>,
}
impl<T> SpmcOnce<T> {
    /// Allocates `capacity` empty slots with every cursor at zero.
    ///
    /// `capacity` is expected to be a non-zero power of two; the mask is
    /// computed as `capacity - 1`.
    fn new(capacity: usize) -> Self {
        let mem: Box<[Slot<T>]> = (0..capacity).map(|_| Slot::new()).collect();
        let mask = capacity - 1;
        Self {
            mem,
            mask,
            write: CachePadded::new(AtomicUsize::new(0)),
            read_allocate: CachePadded::new(AtomicUsize::new(0)),
            read: CachePadded::new(AtomicUsize::new(0)),
        }
    }

    /// Number of slots in the ring.
    #[inline]
    fn capacity(&self) -> usize {
        let highest_index = self.mask;
        highest_index + 1
    }
}
/// Consuming half of an [`spmc_once`] channel.
///
/// Cloning a `Receiver` yields another handle onto the same shared buffer.
#[derive(Debug)]
pub struct Receiver<T> {
    shared: Arc<SpmcOnce<T>>,
}
// The shared buffer stores values in `UnsafeCell`s, which are `!Sync`, so the
// auto traits are not derived for `Arc<SpmcOnce<T>>`; they are asserted here.
// NOTE(review): soundness depends on the receive protocol in `try_recv`.
unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}
impl<T> Receiver<T> {
    /// Wraps the shared buffer in a receiver handle; this never returns `None`.
    fn new(shared: Arc<SpmcOnce<T>>) -> Option<Self> {
        let receiver = Receiver { shared };
        Some(receiver)
    }
}
impl<T> Clone for Receiver<T> {
    /// Creates another handle onto the same channel by bumping the `Arc`
    /// reference count. It is implemented by hand so that `T: Clone` is not
    /// required.
    fn clone(&self) -> Self {
        Receiver {
            shared: Arc::clone(&self.shared),
        }
    }
}
impl<T> Receiver<T> {
pub fn try_recv(&mut self) -> Option<T>
where
T: Clone,
{
if self.shared.read.load(Ordering::Acquire) == self.shared.write.load(Ordering::Acquire) {
return None;
}
self.shared.read_allocate.fetch_add(1, Ordering::Acquire);
let rpos = self.read & self.shared.mask;
let slot = unsafe { self.shared.mem.get_unchecked(rpos) };
#[cfg(not(loom))]
let data = unsafe { slot.value.get().replace(None) };
#[cfg(loom)]
let data = unsafe {
spmc.mem[self.read & spsc.mask]
.value
.get_mut()
.with(|ptr| ptr.replace(std::mem::zeroed()))
};
slot.read_tracking
.fetch_add(self.receiver_id, Ordering::Release);
self.read += 1;
data
}
#[inline]
pub fn capacity(&self) -> usize {
self.shared.capacity()
}
}
/// Producing half of an [`spmc_once`] channel.
#[derive(Debug)]
pub struct Sender<T> {
    shared: Arc<SpmcOnce<T>>,
    /// Count of values this sender has written so far.
    write: usize,
    /// Sender-side cached read position, compared against `write` to detect a
    /// full buffer without touching shared state on every send.
    local_read_index: usize,
}
// The shared buffer stores values in `UnsafeCell`s, which are `!Sync`, so the
// auto traits are not derived for `Arc<SpmcOnce<T>>`; they are asserted here.
// NOTE(review): soundness depends on the send protocol in `try_send`.
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
impl<T> Sender<T> {
    /// Wraps the shared buffer in a sender whose local cursors start at zero.
    fn new(shared: Arc<SpmcOnce<T>>) -> Self {
        let (write, local_read_index) = (0, 0);
        Self {
            shared,
            write,
            local_read_index,
        }
    }
}
impl<T> Sender<T> {
    /// Publishes `data` into the next free slot of the ring buffer, where a
    /// single receiver will take it.
    ///
    /// # Errors
    ///
    /// Returns [`NoSpaceLeftError`] carrying `data` back when all `capacity`
    /// slots still hold values that receivers have not finished taking.
    pub fn try_send(&mut self, data: T) -> Result<(), NoSpaceLeftError<T>> {
        let capacity = self.shared.capacity();
        if self.write.wrapping_sub(self.local_read_index) == capacity {
            // The cached completion count says the ring is full, so refresh it.
            // Acquire pairs with the receivers' Release when they finish moving
            // a value out. Any slot at or below the new count is therefore no
            // longer being read.
            self.local_read_index = self.shared.read.load(Ordering::Acquire);
            if self.write.wrapping_sub(self.local_read_index) == capacity {
                return Err(NoSpaceLeftError(data));
            }
        }
        let wpos = self.write & self.shared.mask;
        // SAFETY: `mask == capacity - 1` and `mem` holds exactly `capacity`
        // slots, so the masked index is in bounds.
        let slot = unsafe { self.shared.mem.get_unchecked(wpos) };
        // SAFETY: `try_send` takes `&mut self`, so no other send runs on this
        // handle concurrently. The check above ensures the slot's previous
        // occupant (sequence `write - capacity`) has been fully taken by its
        // receiver. Receivers only claim sequences below the published `write`,
        // so none touches this slot until the Release store below. The previous
        // value was moved out (left as `None`), so overwriting without dropping
        // leaks nothing.
        #[cfg(not(loom))]
        unsafe {
            slot.value.get().write(Some(data))
        };
        #[cfg(loom)]
        unsafe {
            slot.value.get_mut().with(|ptr| ptr.write(Some(data)))
        };
        self.write = self.write.wrapping_add(1);
        // Release publishes the slot write to receivers that Acquire-load `write`.
        self.shared.write.store(self.write, Ordering::Release);
        Ok(())
    }

    /// Number of slots in the channel's ring buffer.
    pub fn capacity(&self) -> usize {
        self.shared.capacity()
    }
}
#[cfg(not(loom))]
#[cfg(test)]
mod test {
    //! Tests for the single-delivery semantics of `spmc_once`.
    use super::*;

    /// A value taken by one receiver is not seen again by a clone.
    #[test]
    fn smoke() {
        let (mut tx, mut rx) = spmc_once(8);
        assert!(tx.try_send(56).is_ok());
        assert_eq!(rx.try_recv(), Some(56));
        let mut rx2 = rx.clone();
        assert_eq!(rx2.try_recv(), None);
        assert!(tx.try_send(6798).is_ok());
        assert_eq!(rx2.try_recv(), Some(6798));
        assert_eq!(rx.try_recv(), None);
    }

    /// Receivers share one queue and take values in publication order.
    #[test]
    fn test_multiple_reader() {
        let (mut tx, mut rx) = spmc_once(8);
        assert!(tx.try_send(56).is_ok());
        assert!(tx.try_send(57).is_ok());
        let mut rx2 = rx.clone();
        assert!(tx.try_send(58).is_ok());
        assert_eq!(rx.try_recv(), Some(56));
        assert_eq!(rx2.try_recv(), Some(57));
        drop(rx2);
        assert_eq!(rx.try_recv(), Some(58));
        assert_eq!(rx.try_recv(), None);
    }

    /// A full ring rejects the send and hands the value back; space frees up
    /// once a receiver takes a value.
    #[test]
    fn test_full_buffer_returns_value() {
        let (mut tx, mut rx) = spmc_once(2);
        assert!(tx.try_send(1).is_ok());
        assert!(tx.try_send(2).is_ok());
        match tx.try_send(3) {
            Err(NoSpaceLeftError(v)) => assert_eq!(v, 3),
            Ok(()) => panic!("send into a full buffer must fail"),
        }
        assert_eq!(rx.try_recv(), Some(1));
        assert!(tx.try_send(3).is_ok());
        assert_eq!(rx.try_recv(), Some(2));
        assert_eq!(rx.try_recv(), Some(3));
        assert_eq!(rx.try_recv(), None);
    }

    #[test]
    #[should_panic]
    fn rejects_non_power_of_two() {
        let _ = spmc_once::<u8>(3);
    }

    /// Under contention every value is received exactly once.
    #[test]
    fn concurrent_receivers_each_value_once() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        use std::thread;

        const N: u64 = 2_000;
        const RECEIVERS: usize = 3;
        let (mut tx, rx) = spmc_once::<u64>(4);
        let received = Arc::new(AtomicUsize::new(0));

        let consumers: Vec<_> = (0..RECEIVERS)
            .map(|_| {
                let mut rx = rx.clone();
                let received = Arc::clone(&received);
                thread::spawn(move || {
                    let mut sum = 0u64;
                    while received.load(Ordering::SeqCst) < N as usize {
                        match rx.try_recv() {
                            Some(v) => {
                                sum += v;
                                received.fetch_add(1, Ordering::SeqCst);
                            }
                            None => thread::yield_now(),
                        }
                    }
                    sum
                })
            })
            .collect();

        for v in 1..=N {
            let mut item = v;
            while let Err(NoSpaceLeftError(back)) = tx.try_send(item) {
                item = back;
                thread::yield_now();
            }
        }

        let total: u64 = consumers.into_iter().map(|h| h.join().unwrap()).sum();
        assert_eq!(total, N * (N + 1) / 2);
        assert_eq!(received.load(Ordering::SeqCst), N as usize);
    }
}