use alloc::sync::Arc;
use core::{fmt::Debug, future::Future, task::Poll};
use futures::task::AtomicWaker;
use crate::queues::{DequeueError, EnqueueError};
use super::{queue, Receiver, Sender};
/// Receiving half of an asynchronous queue, produced by [`async_queue`].
///
/// Pairs the underlying [`Receiver`] with an [`AtomicWaker`] that is shared (via [`Arc`])
/// with the matching [`AsyncSender`], which uses it to wake a task waiting for data.
pub struct AsyncReceiver<T> {
    /// Waker slot shared with the sending half.
    waker: Arc<AtomicWaker>,
    /// Underlying receiving end, drained with `try_dequeue`.
    queue: Receiver<T>,
}
/// Sending half of an asynchronous queue, produced by [`async_queue`].
///
/// Holds the underlying [`Sender`] plus the [`AtomicWaker`] shared with the matching
/// [`AsyncReceiver`]; a successful `enqueue` wakes whichever task is registered there.
///
/// NOTE(review): within this file only `enqueue` calls `wake`. Dropping the sender does not
/// wake a dequeue that is already parked, so such a receiver may not observe the queue being
/// closed until something else polls it again. Confirm whether the queue module compensates.
pub struct AsyncSender<T> {
    /// Waker slot shared with the receiving half.
    waker: Arc<AtomicWaker>,
    /// Underlying sending end.
    queue: Sender<T>,
}
impl<T> AsyncReceiver<T> {
pub fn is_closed(&self) -> bool {
self.queue.is_closed()
}
pub fn try_dequeue(&mut self) -> Result<T, DequeueError> {
self.queue.try_dequeue()
}
pub fn dequeue(&mut self) -> DequeueFuture<'_, T> {
DequeueFuture {
waker: &self.waker,
queue: &mut self.queue,
}
}
}
impl<T> Debug for AsyncReceiver<T> {
    /// Emits a fixed placeholder; nothing about `T` is printed, so no `T: Debug` bound is needed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Async-Receiver ()")
    }
}
/// Future returned by [`AsyncReceiver::dequeue`].
///
/// Borrows the receiver's shared waker and (mutably) its queue for `'queue`, so the
/// receiver cannot be used elsewhere until this future is dropped or completes.
/// Like every future it does nothing until it is polled, hence `#[must_use]`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct DequeueFuture<'queue, T> {
    /// Waker slot shared with the sending half.
    waker: &'queue AtomicWaker,
    /// Receiving end this future drains.
    queue: &'queue mut Receiver<T>,
}
impl<'queue, T> Future for DequeueFuture<'queue, T> {
    type Output = Result<T, DequeueError>;

    /// Tries to take one element from the queue.
    ///
    /// Resolves to `Ok(item)` when an element is available, or to the queue's error when it is
    /// anything other than [`DequeueError::Empty`] (e.g. `Closed`). While the queue is empty,
    /// the current task's waker is stored in the shared [`AtomicWaker`] and `Pending` is
    /// returned; the sender wakes it after a successful enqueue.
    fn poll(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        // Fast path: no waker registration needed if the outcome is already known.
        match self.queue.try_dequeue() {
            Err(DequeueError::Empty) => {}
            ready => return Poll::Ready(ready),
        }

        // Register *before* re-checking. Registering only after observing `Empty` leaves a
        // window in which an `enqueue` + `wake` finds no (or a stale) waker. The wakeup is
        // then lost and this future would stay pending forever despite data being queued.
        self.waker.register(cx.waker());

        match self.queue.try_dequeue() {
            Err(DequeueError::Empty) => Poll::Pending,
            ready => Poll::Ready(ready),
        }
    }
}
impl<T> Debug for DequeueFuture<'_, T> {
    /// Emits a fixed placeholder; the borrowed queue is not inspected.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Async-Dequeue-Operation ()")
    }
}
impl<T> AsyncSender<T> {
    /// Pushes `data` onto the queue and then wakes the task registered by the receiver.
    ///
    /// # Errors
    /// If the underlying queue rejects the element, it is handed back together with the
    /// [`EnqueueError`] and no wakeup is issued.
    pub fn enqueue(&self, data: T) -> Result<(), (T, EnqueueError)> {
        self.queue.enqueue(data)?;
        // Wake only after the element is in the queue, so the woken task can observe it.
        self.waker.wake();
        Ok(())
    }

    /// Returns whether the underlying queue reports itself as closed.
    pub fn is_closed(&self) -> bool {
        self.queue.is_closed()
    }
}
impl<T> Debug for AsyncSender<T> {
    /// Emits a fixed placeholder; nothing about `T` is printed, so no `T: Debug` bound is needed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Async-Sender ()")
    }
}
/// Creates a connected `(receiver, sender)` pair on top of [`queue`].
///
/// Both halves hold the same [`AtomicWaker`] through an [`Arc`], which is how a successful
/// `enqueue` on the sender reaches a task waiting in [`AsyncReceiver::dequeue`].
pub fn async_queue<T>() -> (AsyncReceiver<T>, AsyncSender<T>) {
    let (rx_end, tx_end) = queue();
    let shared_waker = Arc::new(AtomicWaker::new());

    let receiver = AsyncReceiver {
        waker: Arc::clone(&shared_waker),
        queue: rx_end,
    };
    let sender = AsyncSender {
        waker: shared_waker,
        queue: tx_end,
    };
    (receiver, sender)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Element already queued: the dequeue resolves on its first poll.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn enqueue_dequeue() {
        let (mut rx, tx) = async_queue();
        tx.enqueue(13).unwrap();
        assert_eq!(Ok(13), rx.dequeue().await);
    }

    /// Empty queue: the dequeue must park (`Pending`) and be woken by a later `enqueue`
    /// issued from the same task.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn dequeue_is_woken_by_later_enqueue() {
        let (mut rx, tx) = async_queue();
        let (received, ()) = tokio::join!(rx.dequeue(), async {
            // Yield first so the dequeue is polled on an empty queue before anything is sent.
            tokio::task::yield_now().await;
            tx.enqueue(7).unwrap();
        });
        assert_eq!(Ok(7), received);
    }
}