use alloc::{sync::Arc, vec::Vec};
use core::{cell::UnsafeCell, mem::MaybeUninit, sync::atomic};
use crate::queues::{DequeueError, EnqueueError};
pub mod ncq;
pub mod scq;
/// The receiving half of a bounded queue, created together with a
/// [`BoundedSender`] by [`new_queue`]; both halves share the same backing
/// buffer and index queues through `Arc`s.
pub struct BoundedReceiver<T, UQ> {
    // Shared slot buffer. A slot only holds an initialized value while its
    // index is queued in `aq`.
    data: Arc<Vec<UnsafeCell<MaybeUninit<T>>>>,
    // "Available" queue: indices of slots that currently contain a value.
    aq: Arc<UQ>,
    // "Free" queue: indices of slots a sender may write into.
    fq: Arc<UQ>,
    // Number of live receiver handles; senders treat 0 as "closed".
    rx_count: Arc<atomic::AtomicU64>,
    // Number of live sender handles; receivers treat 0 as "closed".
    tx_count: Arc<atomic::AtomicU64>,
}
/// The sending half of a bounded queue, created together with a
/// [`BoundedReceiver`] by [`new_queue`]; both halves share the same backing
/// buffer and index queues through `Arc`s.
pub struct BoundedSender<T, UQ> {
    // Shared slot buffer. A slot only holds an initialized value while its
    // index is queued in `aq`.
    data: Arc<Vec<UnsafeCell<MaybeUninit<T>>>>,
    // "Available" queue: indices of slots that currently contain a value.
    aq: Arc<UQ>,
    // "Free" queue: indices of slots a sender may write into.
    fq: Arc<UQ>,
    // Number of live receiver handles; senders treat 0 as "closed".
    rx_count: Arc<atomic::AtomicU64>,
    // Number of live sender handles; receivers treat 0 as "closed".
    tx_count: Arc<atomic::AtomicU64>,
}
/// Abstraction over the index queues (e.g. [`ncq::Queue`], [`scq::Queue`])
/// used to pass buffer-slot indices between the sender and receiver halves.
pub trait UnderlyingQueue {
    /// Appends the given slot index to the queue.
    fn enqueue(&self, index: usize);
    /// Removes and returns the oldest slot index, or `None` if the queue is
    /// currently empty.
    fn dequeue(&self) -> Option<usize>;
}
/// Wires up a connected receiver/sender pair on top of the given index
/// queues.
///
/// The caller is expected to have pre-filled `fq` with every slot index in
/// `0..capacity`; `aq` starts out empty.
fn new_queue<T, UQ>(
    aq: UQ,
    fq: UQ,
    capacity: usize,
) -> (BoundedReceiver<T, UQ>, BoundedSender<T, UQ>) {
    // One uninitialized slot per possible in-flight element.
    let buffer: Vec<_> = (0..capacity)
        .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
        .collect();
    let data = Arc::new(buffer);
    let aq = Arc::new(aq);
    let fq = Arc::new(fq);
    // Exactly one receiver and one sender exist right after construction.
    let rx_count = Arc::new(atomic::AtomicU64::new(1));
    let tx_count = Arc::new(atomic::AtomicU64::new(1));

    let receiver = BoundedReceiver {
        data: Arc::clone(&data),
        aq: Arc::clone(&aq),
        fq: Arc::clone(&fq),
        rx_count: Arc::clone(&rx_count),
        tx_count: Arc::clone(&tx_count),
    };
    let sender = BoundedSender {
        data,
        aq,
        fq,
        rx_count,
        tx_count,
    };
    (receiver, sender)
}
// SAFETY: sharing a handle between threads (`Sync`) lets multiple threads
// move values of type `T` into/out of the shared slot buffer and call the
// index queues through `&UQ` concurrently, so both impls require `T: Send`
// and `UQ: Sync`. The original unbounded impls were unsound: they allowed
// e.g. sharing a sender of an `!Send` payload across threads.
unsafe impl<T, UQ> Sync for BoundedReceiver<T, UQ>
where
    T: Send,
    UQ: Sync,
{
}
unsafe impl<T, UQ> Sync for BoundedSender<T, UQ>
where
    T: Send,
    UQ: Sync,
{
}
// SAFETY: moving a handle to another thread moves its `Arc<UQ>` along with
// it, and `Arc<UQ>: Send` requires `UQ: Send + Sync`; `T: Send` is required
// because the buffered values are dropped/read on whichever thread holds the
// handle last.
unsafe impl<T, UQ> Send for BoundedReceiver<T, UQ>
where
    T: Send,
    UQ: Send + Sync,
{
}
unsafe impl<T, UQ> Send for BoundedSender<T, UQ>
where
    T: Send,
    UQ: Send + Sync,
{
}
/// Creates a bounded queue with `capacity` slots, backed by NCQ index
/// queues, returning its receiver and sender halves.
pub fn queue_ncq<T>(
    capacity: usize,
) -> (BoundedReceiver<T, ncq::Queue>, BoundedSender<T, ncq::Queue>) {
    let aq = ncq::Queue::new(capacity);
    let fq = ncq::Queue::new(capacity);
    // Every slot starts out free, so the free-queue holds all indices.
    (0..capacity).for_each(|index| fq.enqueue(index));
    new_queue(aq, fq, capacity)
}
/// Creates a bounded queue with `capacity` slots, backed by SCQ index
/// queues, returning its receiver and sender halves.
pub fn queue_scq<T>(
    capacity: usize,
) -> (BoundedReceiver<T, scq::Queue>, BoundedSender<T, scq::Queue>) {
    let aq = scq::Queue::new(capacity);
    let fq = scq::Queue::new(capacity);
    // Every slot starts out free, so the free-queue holds all indices.
    (0..capacity).for_each(|index| fq.enqueue(index));
    new_queue(aq, fq, capacity)
}
impl<T, UQ> BoundedSender<T, UQ>
where
    UQ: UnderlyingQueue,
{
    /// Attempts to enqueue `data` without blocking.
    ///
    /// On failure the value is handed back to the caller along with the
    /// reason: `Full` when no free slot is available, `Closed` when every
    /// receiver has been dropped.
    ///
    /// NOTE(review): the closed-check races with the last receiver dropping
    /// right after it; in that window a value can be published to `aq` that
    /// no one will ever dequeue (it is leaked, not dropped) — confirm this
    /// is an accepted trade-off.
    pub fn try_enqueue(&self, data: T) -> Result<(), (EnqueueError, T)> {
        if self.is_closed() {
            return Err((EnqueueError::Closed, data));
        }
        // Claim a free slot index; `None` means the buffer is full.
        let index = match self.fq.dequeue() {
            Some(i) => i,
            None => return Err((EnqueueError::Full, data)),
        };
        let bucket = self
            .data
            .get(index)
            .expect("The received Index should always be in the Bounds of the Data Buffer");
        let bucket_ptr = bucket.get();
        // SAFETY: the index was taken from the free-queue, so no other
        // thread reads or writes this slot until it is published via `aq`
        // below.
        unsafe { bucket_ptr.write(MaybeUninit::new(data)) };
        // Publish the filled slot to the receivers.
        self.aq.enqueue(index);
        Ok(())
    }
    /// Returns `true` once every receiver handle has been dropped.
    pub fn is_closed(&self) -> bool {
        self.rx_count.load(atomic::Ordering::Acquire) == 0
    }
}
impl<T, UQ> Drop for BoundedSender<T, UQ> {
    fn drop(&mut self) {
        // Deregister this sender; once the count reaches 0 the receivers
        // report `DequeueError::Closed` when the queue appears empty.
        self.tx_count.fetch_sub(1, atomic::Ordering::AcqRel);
    }
}
impl<T, UQ> BoundedReceiver<T, UQ>
where
    UQ: UnderlyingQueue,
{
    /// Attempts to dequeue a value without blocking.
    ///
    /// Returns `Err(Empty)` when no value is currently available but at
    /// least one sender is still alive, and `Err(Closed)` once every sender
    /// has been dropped *and* the queue has been drained.
    pub fn dequeue(&self) -> Result<T, DequeueError> {
        let index = match self.aq.dequeue() {
            Some(i) => i,
            None => {
                if !self.is_closed() {
                    return Err(DequeueError::Empty);
                }
                // A sender may have enqueued a final value after our first
                // `aq.dequeue()` attempt but before dropping (which is what
                // made `is_closed` return true). Re-check the queue once so
                // that such a value is not misreported as `Closed` and lost.
                match self.aq.dequeue() {
                    Some(i) => i,
                    None => return Err(DequeueError::Closed),
                }
            }
        };
        let bucket = self
            .data
            .get(index)
            .expect("The received Index should always be in the Bounds of the Data-Buffer");
        let bucket_ptr = bucket.get();
        // SAFETY: the index came from `aq`, so the slot was fully written by
        // a sender and is exclusively ours until recycled through `fq`.
        let data = unsafe { bucket_ptr.replace(MaybeUninit::uninit()).assume_init() };
        // Hand the now-empty slot back to the senders.
        self.fq.enqueue(index);
        Ok(data)
    }
    /// Returns `true` once every sender handle has been dropped.
    pub fn is_closed(&self) -> bool {
        self.tx_count.load(atomic::Ordering::Acquire) == 0
    }
}
impl<T, UQ> Drop for BoundedReceiver<T, UQ> {
    fn drop(&mut self) {
        // Deregister this receiver; once the count reaches 0 the senders
        // report `EnqueueError::Closed`.
        //
        // NOTE(review): values still sitting in `aq` when the last handle of
        // either side is dropped are never read back, so their `T`
        // destructors do not run — confirm this leak is acceptable.
        self.rx_count.fetch_sub(1, atomic::Ordering::AcqRel);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ncq_new() {
        let _ = queue_ncq::<u64>(10);
    }

    #[test]
    fn scq_new() {
        let _ = queue_scq::<u64>(10);
    }

    #[test]
    fn enqueue() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        assert_eq!(Ok(()), sender.try_enqueue(15));
        drop(receiver);
    }

    #[test]
    fn enqueue_full() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        // Fill every one of the 10 slots.
        for value in 0..10 {
            assert_eq!(Ok(()), sender.try_enqueue(value));
        }
        // The 11th attempt must fail and return the value.
        assert_eq!(Err((EnqueueError::Full, 15)), sender.try_enqueue(15));
        drop(receiver);
    }

    #[test]
    fn enqueue_closed() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        drop(receiver);
        assert_eq!(Err((EnqueueError::Closed, 15)), sender.try_enqueue(15));
    }

    #[test]
    fn dequeue_empty() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        assert_eq!(Err(DequeueError::Empty), receiver.dequeue());
        drop(sender);
    }

    #[test]
    fn dequeue_closed() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        drop(sender);
        assert_eq!(Err(DequeueError::Closed), receiver.dequeue());
    }

    #[test]
    fn enqueue_dequeue() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        assert_eq!(Ok(()), sender.try_enqueue(15));
        assert_eq!(Ok(15), receiver.dequeue());
    }

    #[test]
    fn enqueue_dequeue_fill_multiple() {
        // Cycle through the buffer several times to exercise index reuse.
        let (receiver, sender) = queue_ncq::<u64>(10);
        for value in 0..(5 * 10) {
            assert_eq!(Ok(()), sender.try_enqueue(value));
            assert_eq!(Ok(value), receiver.dequeue());
        }
    }

    #[test]
    fn receiver_closed() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        assert!(!receiver.is_closed());
        drop(sender);
        assert!(receiver.is_closed());
    }

    #[test]
    fn sending_closed() {
        let (receiver, sender) = queue_ncq::<u64>(10);
        assert!(!sender.is_closed());
        drop(receiver);
        assert!(sender.is_closed());
    }
}