use alloc::sync::Arc;
use core::{fmt::Debug, future::Future, task::Poll};
use futures::task::AtomicWaker;
use crate::queues::{DequeueError, EnqueueError};
use super::{queue, UnboundedReceiver, UnboundedSender};
/// Sending half of an async unbounded queue, created by [`async_queue`].
///
/// Wraps the parent module's [`UnboundedSender`] and shares an [`AtomicWaker`] with the
/// matching [`AsyncUnboundedReceiver`], so that a successful enqueue can wake a receiver
/// that is waiting for data.
pub struct AsyncUnboundedSender<T> {
    /// Waker slot shared with the receiver; woken after every successful enqueue.
    rx_waker: Arc<AtomicWaker>,
    /// The underlying (non-async) queue handle.
    queue: UnboundedSender<T>,
}
/// Receiving half of an async unbounded queue, created by [`async_queue`].
///
/// Wraps the parent module's [`UnboundedReceiver`]. Its [`AtomicWaker`] is shared with
/// the matching [`AsyncUnboundedSender`]; a pending [`DequeueFuture`] registers the
/// current task's waker there while the queue is empty.
pub struct AsyncUnboundedReceiver<T> {
    /// Waker slot shared with the sender; the receiving task registers itself here.
    rx_waker: Arc<AtomicWaker>,
    /// The underlying (non-async) queue handle.
    queue: UnboundedReceiver<T>,
}
/// Future returned by [`AsyncUnboundedReceiver::dequeue`].
///
/// Resolves to `Ok(element)` once an element is available, or to
/// `Err(DequeueError::Closed)` when the underlying queue reports that it is closed.
/// It mutably borrows the receiver's queue for its whole lifetime, so at most one
/// `DequeueFuture` per receiver can exist at a time.
// Futures are lazy: calling `dequeue()` without `.await`/polling does nothing, which is
// almost always a bug, so let the compiler flag it.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct DequeueFuture<'queue, T> {
    /// Waker slot shared with the sender; the polling task registers itself here.
    rx_waker: &'queue AtomicWaker,
    /// Exclusive borrow of the receiver's underlying queue.
    queue: &'queue mut UnboundedReceiver<T>,
}
// NOTE(review): nothing in this file wakes `rx_waker` when the sender is dropped. If the
// underlying queue reports `DequeueError::Closed` to the receiver once the sender is gone
// (presumably — that logic lives in the parent module, not visible here), a receiver
// already parked in a `DequeueFuture` will not be re-polled to observe the closure.
// TODO confirm, and consider waking the receiver *after* the inner queue is closed on drop.
impl<T> AsyncUnboundedSender<T> {
    /// Returns whether the underlying queue has been closed.
    ///
    /// Delegates to the wrapped [`UnboundedSender`]; the shared waker is not touched.
    pub fn is_closed(&self) -> bool {
        let Self { queue, .. } = self;
        queue.is_closed()
    }

    /// Pushes `data` onto the queue, then wakes the receiver's registered task (if any).
    ///
    /// # Errors
    ///
    /// If the wrapped queue rejects the element, its error is forwarded with `?` and the
    /// receiver is not woken. The error tuple hands the element back to the caller.
    pub fn enqueue(&mut self, data: T) -> Result<(), (T, EnqueueError)> {
        let Self { rx_waker, queue } = self;
        // Wake strictly after the element is in the queue, so a woken receiver finds it
        // on its next poll.
        queue.enqueue(data)?;
        rx_waker.wake();
        Ok(())
    }
}
/// Opaque debug representation: prints a fixed label and never requires `T: Debug`.
impl<T> Debug for AsyncUnboundedSender<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Async-Unbounded-Sender ()")
    }
}
impl<T> AsyncUnboundedReceiver<T> {
    /// Returns whether the underlying queue has been closed.
    ///
    /// Delegates to the wrapped [`UnboundedReceiver`].
    pub fn is_closed(&self) -> bool {
        let Self { queue, .. } = self;
        queue.is_closed()
    }

    /// Returns a future that resolves to the next element of the queue.
    ///
    /// The future borrows this receiver mutably until it is dropped. While the queue is
    /// empty, it registers the polling task with the waker shared with the sender.
    pub fn dequeue(&mut self) -> DequeueFuture<'_, T> {
        // Deref the `Arc` explicitly: the future only needs a plain `&AtomicWaker`.
        let waker = &*self.rx_waker;
        let queue = &mut self.queue;
        DequeueFuture {
            rx_waker: waker,
            queue,
        }
    }

    /// Attempts to take an element immediately, without registering any waker.
    ///
    /// # Errors
    ///
    /// Forwards the wrapped queue's [`DequeueError`] unchanged.
    pub fn try_dequeue(&mut self) -> Result<T, DequeueError> {
        let Self { queue, .. } = self;
        queue.try_dequeue()
    }
}
/// Opaque debug representation: prints a fixed label and never requires `T: Debug`.
impl<T> Debug for AsyncUnboundedReceiver<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Async-Unbounded-Receiver ()")
    }
}
impl<T> Future for DequeueFuture<'_, T> {
    type Output = Result<T, DequeueError>;

    /// Attempts to take the next element from the queue.
    ///
    /// Resolves to `Ok(element)` when one is available and to
    /// `Err(DequeueError::Closed)` when the queue reports it is closed. While the queue
    /// is empty, the current task's waker is registered in the waker slot shared with
    /// the sender, and `Poll::Pending` is returned.
    fn poll(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        // Fast path: skip waker registration when a result is already available.
        match self.queue.try_dequeue() {
            Ok(data) => return Poll::Ready(Ok(data)),
            Err(DequeueError::Closed) => return Poll::Ready(Err(DequeueError::Closed)),
            Err(DequeueError::Empty) => {}
        }

        // Register *before* re-checking. If the sender enqueued (and called `wake`)
        // between the check above and this registration, that wake reached a stale or
        // missing waker. The second check below then sees the element instead of
        // parking forever. Any enqueue after registration wakes `cx`'s waker.
        self.rx_waker.register(cx.waker());

        match self.queue.try_dequeue() {
            Ok(data) => Poll::Ready(Ok(data)),
            Err(DequeueError::Closed) => Poll::Ready(Err(DequeueError::Closed)),
            Err(DequeueError::Empty) => Poll::Pending,
        }
    }
}
/// Opaque debug representation: prints a fixed label and never requires `T: Debug`.
impl<T> Debug for DequeueFuture<'_, T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Dequeue-Future ()")
    }
}
/// Creates a connected receiver/sender pair backed by a fresh unbounded queue.
///
/// Both halves hold a clone of one shared `AtomicWaker`. The receiver's pending
/// [`DequeueFuture`] registers its task there, and [`AsyncUnboundedSender::enqueue`]
/// wakes it after each successful enqueue.
pub fn async_queue<T>() -> (AsyncUnboundedReceiver<T>, AsyncUnboundedSender<T>) {
    let (rx_half, tx_half) = queue();
    let shared_waker = Arc::new(AtomicWaker::new());

    let receiver = AsyncUnboundedReceiver {
        rx_waker: Arc::clone(&shared_waker),
        queue: rx_half,
    };
    let sender = AsyncUnboundedSender {
        rx_waker: shared_waker,
        queue: tx_half,
    };

    (receiver, sender)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An element enqueued before dequeuing is returned immediately.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn enqueue_dequeue() {
        let (mut rx, mut tx) = async_queue();
        tx.enqueue(13).unwrap();
        assert_eq!(Ok(13), rx.dequeue().await);
    }

    /// A receiver already parked on an empty queue must be woken by a later enqueue
    /// and resolve to that element.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn pending_dequeue_is_woken_by_enqueue() {
        let (mut rx, mut tx) = async_queue();
        let (received, ()) = tokio::join!(rx.dequeue(), async {
            // Yield once so the dequeue future observes the empty queue and parks
            // before anything is enqueued.
            tokio::task::yield_now().await;
            tx.enqueue(42).unwrap();
        });
        assert_eq!(Ok(42), received);
    }
}