waitfree-sync 0.3.3

A collection of wait-free data structures
Documentation
//! This is a partial, wait-free, single-producer multiple-consumer (SPMC) queue
//! for sending data from a real-time thread to multiple non-real-time threads.
//!
//!
//!
use crate::{
    error::NoSpaceLeftError,
    import::{Arc, AtomicUsize, Ordering, UnsafeCell},
    spsc::is_power_of_two,
};
use crossbeam_utils::CachePadded;
use std::{fmt::Debug, mem::zeroed};

pub fn spmc_once<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
    if !is_power_of_two(capacity) {
        panic!("The SIZE must be a power of 2")
    }

    let chan = Arc::new(SpmcOnce::new(capacity));

    // SAFETY: This is always true!
    let r = Receiver::new(chan.clone()).unwrap();
    let w = Sender::new(chan);

    (w, r)
}

#[derive(Debug)]
struct Slot<T: Sized> {
    value: UnsafeCell<Option<T>>,
}
impl<T> Slot<T> {
    fn new() -> Self {
        Self {
            value: UnsafeCell::new(None),
        }
    }
}

#[derive(Debug)]
struct SpmcOnce<T> {
    mem: Box<[Slot<T>]>,
    // The mask is written when this structure is created and is then only read.
    // Therefore, we do not need Atomic here.
    mask: usize,
    write: CachePadded<AtomicUsize>,
    read_allocate: CachePadded<AtomicUsize>,
    read: CachePadded<AtomicUsize>,
}

impl<T> SpmcOnce<T> {
    fn new(capacity: usize) -> Self {
        let mut buffer = Vec::with_capacity(capacity);
        for _ in 0..capacity {
            buffer.push(Slot::new());
        }
        let buffer: Box<[Slot<T>]> = buffer.into_boxed_slice();
        SpmcOnce {
            mem: buffer,
            mask: capacity - 1,
            write: CachePadded::new(0.into()),
            read_allocate: CachePadded::new(0.into()),
            read: CachePadded::new(0.into()),
        }
    }
    #[inline]
    fn capacity(&self) -> usize {
        self.mask + 1
    }
}

#[derive(Debug)]
pub struct Receiver<T> {
    shared: Arc<SpmcOnce<T>>,
}
unsafe impl<T: Send> Send for Receiver<T> {}
unsafe impl<T: Send> Sync for Receiver<T> {}

impl<T> Receiver<T> {
    fn new(shared: Arc<SpmcOnce<T>>) -> Option<Self> {
        Some(Receiver { shared })
    }
}

impl<T> Clone for Receiver<T> {
    fn clone(&self) -> Self {
        let shared = self.shared.clone();
        Self { shared }
    }
}

impl<T> Receiver<T> {
    pub fn try_recv(&mut self) -> Option<T>
    where
        T: Clone,
    {
        if self.shared.read.load(Ordering::Acquire) == self.shared.write.load(Ordering::Acquire) {
            return None;
        }
        self.shared.read_allocate.fetch_add(1, Ordering::Acquire);
        let rpos = self.read & self.shared.mask;

        let slot = unsafe { self.shared.mem.get_unchecked(rpos) };
        #[cfg(not(loom))]
        let data = unsafe { slot.value.get().replace(None) };
        #[cfg(loom)]
        let data = unsafe {
            spmc.mem[self.read & spsc.mask]
                .value
                .get_mut()
                .with(|ptr| ptr.replace(std::mem::zeroed()))
        };
        slot.read_tracking
            .fetch_add(self.receiver_id, Ordering::Release);
        self.read += 1;
        data
    }
    #[inline]
    pub fn capacity(&self) -> usize {
        self.shared.capacity()
    }
}

#[derive(Debug)]
pub struct Sender<T> {
    shared: Arc<SpmcOnce<T>>,
    write: usize,
    local_read_index: usize,
}
unsafe impl<T: Send> Send for Sender<T> {}
unsafe impl<T: Send> Sync for Sender<T> {}
impl<T> Sender<T> {
    fn new(shared: Arc<SpmcOnce<T>>) -> Self {
        Sender {
            shared,
            write: 0,
            local_read_index: 0,
        }
    }
}

impl<T> Sender<T> {
    pub fn try_send(&mut self, data: T) -> Result<(), NoSpaceLeftError<T>>
    where
        T: Clone,
    {
        if self.write - self.local_read_index == self.shared.capacity() {
            if self
                .shared
                .slot_completed((self.local_read_index + 1) & self.shared.mask)
            {
                self.local_read_index += 1;
                debug_assert!(self.write - self.local_read_index != self.shared.capacity());
            } else {
                return Err(NoSpaceLeftError(data));
            }
        }

        let wpos = self.write & self.shared.mask;
        let slot = unsafe { self.shared.mem.get_unchecked(wpos) };
        #[cfg(not(loom))]
        unsafe {
            slot.value.get().write(Some(data))
        };
        #[cfg(loom)]
        unsafe {
            slot.value.get_mut().with(|ptr| ptr.write(Some(data)))
        };

        self.write += 1;
        self.shared.write.store(self.write, Ordering::Release);

        Ok(())
    }
    pub fn capacity(&self) -> usize {
        self.shared.capacity()
    }
}

#[cfg(not(loom))]
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn smoke() {
        let (mut tx, mut rx) = spmc(8);

        tx.try_send(56).unwrap();
        assert_eq!(rx.try_recv(), Some(56));
        let mut rx2 = rx.clone();
        assert_eq!(rx2.try_recv(), None);

        tx.try_send(6798).unwrap();
        assert_eq!(rx.try_recv(), Some(6798));
        assert_eq!(rx2.try_recv(), Some(6798));
    }

    #[test]
    fn test_multiple_reader() {
        let (mut tx, mut rx) = spmc(8);
        assert_eq!(rx.shared.reader_ids.num_readers(), 1);
        tx.try_send(56).unwrap();
        tx.try_send(57).unwrap();
        let mut rx2 = rx.clone();
        tx.try_send(58).unwrap();
        assert_eq!(rx.try_recv(), Some(56));
        assert_eq!(rx.try_recv(), Some(57));
        assert_eq!(rx2.try_recv(), Some(58));
        drop(rx2);
        assert_eq!(rx.try_recv(), Some(58));
        assert_eq!(rx.try_recv(), None);
    }

    #[test]
    fn test_reader_tracking() {
        let rids = ReaderIds {
            reserved_ids: 0.into(),
        };
        assert_eq!(rids.get_free_id(), Some(1));
        assert_eq!(rids.reserved_ids.load(Ordering::Relaxed), 1);

        let rids = ReaderIds {
            reserved_ids: 0b001011.into(),
        };
        assert_eq!(rids.get_free_id(), Some(0b100));
        assert_eq!(rids.reserved_ids.load(Ordering::Relaxed), 0b001111);

        let rids = ReaderIds {
            reserved_ids: usize::MAX.into(),
        };
        assert_eq!(rids.get_free_id(), None);
        assert_eq!(rids.reserved_ids.load(Ordering::Relaxed), usize::MAX);
    }
}