use super::super::bounded;
use crate::queues::DequeueError;
use alloc::{boxed::Box, sync::Arc};
use core::{fmt::Debug, sync::atomic};
/// One link of the queue's linked list.
///
/// A node is heap-allocated, handed around as a raw pointer, and linked in both
/// directions: forwards through the atomic `next` pointer and backwards through
/// the plain `previous` pointer.
struct Node<T> {
    /// The payload carried by this node. `None` for the initial dummy node and
    /// after the receiver has taken the value out.
    data: Option<T>,
    /// The node that was the tail when this node was created (null for the
    /// initial dummy node).
    previous: *mut Node<T>,
    /// The node appended after this one, or null while this node is the last
    /// one in the list.
    next: atomic::AtomicPtr<Self>,
}
/// The producing half of the unbounded queue.
///
/// Values are appended with `enqueue`; the matching consuming half is
/// `UnboundedReceiver`.
pub struct UnboundedSender<T> {
// Flag shared with the receiver. It is flipped to `true` by whichever half is
// dropped first, and `enqueue` refuses new values once it is set.
closed: Arc<atomic::AtomicBool>,
// The most recently appended node (initially the shared dummy node).
tail: *mut Node<T>,
// Channel from which previously used node allocations are taken for reuse
// before falling back to a fresh allocation.
node_receiver: bounded::BoundedReceiver<Box<Node<T>>>,
}
impl<T> UnboundedSender<T> {
    /// Produces a node holding `data` whose back link is `previous` and whose
    /// forward link is null.
    ///
    /// A recycled allocation from `node_receiver` is reused when one is
    /// available; otherwise a new node is allocated.
    fn create_new_node(&mut self, data: T, previous: *mut Node<T>) -> Box<Node<T>> {
        if let Ok(mut recycled) = self.node_receiver.try_dequeue() {
            // A recycled node still carries its old links, so every field is
            // overwritten before the node is handed out again.
            recycled.data = Some(data);
            recycled.previous = previous;
            recycled
                .next
                .store(core::ptr::null_mut(), atomic::Ordering::Release);
            return recycled;
        }

        Box::new(Node {
            data: Some(data),
            previous,
            next: atomic::AtomicPtr::new(core::ptr::null_mut()),
        })
    }

    /// Appends `data` to the end of the queue.
    ///
    /// # Errors
    ///
    /// Hands `data` back as `Err(data)` when the shared `closed` flag is
    /// already set.
    pub fn enqueue(&mut self, data: T) -> Result<(), T> {
        if self.closed.load(atomic::Ordering::Acquire) {
            return Err(data);
        }

        let old_tail = self.tail;
        let fresh = Box::into_raw(self.create_new_node(data, old_tail));

        // Publishing the new node through the old tail's `next` pointer is what
        // makes it visible to the receiver; the `Release` store pairs with the
        // receiver's `Acquire` load of the same pointer.
        //
        // SAFETY: `tail` is only ever set to the dummy node created together
        // with this sender or to a node leaked via `Box::into_raw` right here,
        // and the receiver only retires a node after observing a non-null
        // `next` on it, which has not been stored for the current tail yet.
        unsafe { (*old_tail).next.store(fresh, atomic::Ordering::Release) };
        self.tail = fresh;
        Ok(())
    }
}
impl<T> Debug for UnboundedSender<T> {
    /// Writes a fixed placeholder; no queue contents or state are shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("UnboundedSender ()")
    }
}
impl<T> Drop for UnboundedSender<T> {
    /// Marks the queue as closed, or, when the flag was already set, frees the
    /// list by walking backwards from `tail`.
    ///
    /// The compare-exchange succeeds only for the first party to flip `closed`
    /// from `false` to `true`; that party leaves the nodes alone. A failed
    /// exchange means the flag was already `true`, so this side releases the
    /// nodes.
    ///
    /// NOTE(review): the backward walk is only sound if no node reachable
    /// through `previous` has already been freed or recycled — confirm that
    /// consumed nodes are unlinked from the `previous` chain.
    fn drop(&mut self) {
        let flag_was_already_set = self
            .closed
            .compare_exchange(
                false,
                true,
                atomic::Ordering::SeqCst,
                atomic::Ordering::SeqCst,
            )
            .is_err();
        if !flag_was_already_set {
            return;
        }

        let mut cursor = self.tail;
        while !cursor.is_null() {
            // SAFETY (assumed, see the review note above): every node on the
            // `previous` chain was leaked with `Box::into_raw` and is not owned
            // by anyone else any more.
            let node = unsafe { Box::from_raw(cursor) };
            cursor = node.previous;
            // `node` is dropped here, releasing the allocation and any value
            // still stored in it.
        }
    }
}
/// The consuming half of the unbounded queue.
///
/// Values are removed with `try_dequeue`; the matching producing half is
/// `UnboundedSender`.
pub struct UnboundedReceiver<T> {
// Flag shared with the sender. It is flipped to `true` by whichever half is
// dropped first.
closed: Arc<atomic::AtomicBool>,
// The current dummy node; the oldest queued value lives in the node that
// `head.next` points to.
head: *mut Node<T>,
// Channel through which retired head nodes are offered back to the sender for
// reuse.
node_return: bounded::BoundedSender<Box<Node<T>>>,
}
impl<T> UnboundedReceiver<T> {
    /// Removes and returns the oldest queued value without blocking.
    ///
    /// The node that held the value becomes the new head (dummy) node. The old
    /// head is offered back to the sender for reuse through `node_return`, and
    /// is freed when that hand-off fails.
    ///
    /// # Errors
    ///
    /// Returns [`DequeueError::Empty`] when no node is linked after the current
    /// head.
    ///
    /// # Panics
    ///
    /// Panics if a linked node carries no value, which would mean the list
    /// invariant has been broken.
    pub fn try_dequeue(&mut self) -> Result<T, DequeueError> {
        let prev_head_ptr = self.head;

        // SAFETY: `head` always points at a live node: it starts as the dummy
        // node and is only ever advanced to a node that is still linked, and
        // this receiver is the only party that retires head nodes.
        let next_ptr = unsafe { (*prev_head_ptr).next.load(atomic::Ordering::Acquire) };
        if next_ptr.is_null() {
            return Err(DequeueError::Empty);
        }

        // SAFETY: a non-null `next` was published by the sender with a
        // `Release` store after the node was fully initialised, and the sender
        // does not touch `data` of a published node. Only the `data` field is
        // borrowed, so no reference to the whole node is created while the
        // sender may still be using the node's `next` pointer.
        let data = unsafe { (*next_ptr).data.take() }
            .expect("a node linked into the queue must carry a value");

        // The old head is about to be freed or recycled, so the new head must
        // not keep pointing at it: the sender's `Drop` frees the list by
        // following `previous` links from the tail until it reaches null, and a
        // stale link here would make it free the retired node a second time.
        //
        // SAFETY: same liveness argument as above; the sender only reads
        // `previous` of published nodes while dropping after this receiver has
        // already been dropped, so this write cannot race with it.
        unsafe { (*next_ptr).previous = core::ptr::null_mut() };
        self.head = next_ptr;

        // SAFETY: the old head was created by `Box::into_raw`, is no longer
        // reachable from `head` or from any `previous` link, and the sender
        // stopped using it once it published the successor.
        let retired = unsafe { Box::from_raw(prev_head_ptr) };
        if let Err((node, _)) = self.node_return.try_enqueue(retired) {
            // The recycling channel did not take the node; release it instead.
            drop(node);
        }

        Ok(data)
    }

    /// Reports whether a value is currently available, i.e. whether a call to
    /// `try_dequeue` right now would find a node linked after the head.
    pub fn has_next(&self) -> bool {
        // SAFETY: `head` always points at a live node that only this receiver
        // may retire (see `try_dequeue`).
        let next_ptr = unsafe { (*self.head).next.load(atomic::Ordering::Acquire) };
        !next_ptr.is_null()
    }
}
impl<T> Debug for UnboundedReceiver<T> {
    /// Writes a fixed placeholder; no queue contents or state are shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("UnboundedReceiver ()")
    }
}
impl<T> Drop for UnboundedReceiver<T> {
    /// Marks the queue as closed, or, when the flag was already set, frees the
    /// list by walking forwards from `head`.
    ///
    /// The compare-exchange succeeds only for the first party to flip `closed`
    /// from `false` to `true`; that party leaves the nodes alone. A failed
    /// exchange means the flag was already `true`, so this side releases the
    /// head node and every node linked after it.
    fn drop(&mut self) {
        let flag_was_already_set = self
            .closed
            .compare_exchange(
                false,
                true,
                atomic::Ordering::SeqCst,
                atomic::Ordering::SeqCst,
            )
            .is_err();
        if !flag_was_already_set {
            return;
        }

        let mut cursor = self.head;
        while !cursor.is_null() {
            // SAFETY: `head` and every node reachable from it through `next`
            // were leaked with `Box::into_raw` and have not been retired yet;
            // retired nodes are never reachable from `head`.
            let node = unsafe { Box::from_raw(cursor) };
            cursor = node.next.load(atomic::Ordering::Acquire);
            // `node` is dropped here, releasing the allocation and any value
            // still stored in it.
        }
    }
}
/// Creates a connected receiver/sender pair for a new, empty queue.
///
/// Both halves start out pointing at the same dummy node (which carries no
/// value) and share one `closed` flag. They are additionally connected by a
/// side channel created with `bounded::queue(64)`, through which the receiver
/// returns used nodes to the sender for reuse (the `64` is presumably that
/// channel's capacity — confirm against `bounded::queue`).
pub fn unbounded_basic_queue<T>() -> (UnboundedReceiver<T>, UnboundedSender<T>) {
    let (node_rx, node_tx) = bounded::queue(64);

    // The list is never empty: it always contains at least this placeholder,
    // so `head` and `tail` are valid pointers from the start.
    let sentinel = Box::into_raw(Box::new(Node {
        data: None,
        previous: core::ptr::null_mut(),
        next: atomic::AtomicPtr::new(core::ptr::null_mut()),
    }));

    let closed = Arc::new(atomic::AtomicBool::new(false));

    let receiver = UnboundedReceiver {
        closed: Arc::clone(&closed),
        head: sentinel,
        node_return: node_tx,
    };
    let sender = UnboundedSender {
        closed,
        tail: sentinel,
        node_receiver: node_rx,
    };

    (receiver, sender)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single value makes it from the sender to the receiver.
    #[test]
    fn enqueue_dequeue() {
        let (mut rx, mut tx) = unbounded_basic_queue();

        tx.enqueue(13).unwrap();

        assert_eq!(Ok(13), rx.try_dequeue());
    }

    /// An empty queue reports `Empty` both before the first value arrives and
    /// after the last one has been taken.
    #[test]
    fn dequeue_empty() {
        let (mut rx, mut tx) = unbounded_basic_queue();
        assert_eq!(Err(DequeueError::Empty), rx.try_dequeue());

        tx.enqueue(13).unwrap();

        assert_eq!(Ok(13), rx.try_dequeue());
        assert_eq!(Err(DequeueError::Empty), rx.try_dequeue());
    }

    /// Alternating enqueue and dequeue returns each value right away.
    #[test]
    fn multiple_enqueue_dequeue() {
        let (mut rx, mut tx) = unbounded_basic_queue();

        for value in 13..=15 {
            tx.enqueue(value).unwrap();
            assert_eq!(Ok(value), rx.try_dequeue());
        }
    }

    /// Values queued back-to-back come out in the order they were put in.
    #[test]
    fn multiple_enqueue_dequeue_2() {
        let (mut rx, mut tx) = unbounded_basic_queue();

        for value in 13..=15 {
            tx.enqueue(value).unwrap();
        }

        for expected in 13..=15 {
            assert_eq!(Ok(expected), rx.try_dequeue());
        }
    }
}