use alloc::sync::Arc;
use core::{fmt::Debug, future::Future, task::Poll};
use futures::task::AtomicWaker;
use crate::queues::{DequeueError, EnqueueError};
use super::{BoundedReceiver, BoundedSender};
/// Sending half of the async bounded queue returned by [`async_queue`].
///
/// Wraps a synchronous [`BoundedSender`] together with the two wakers it
/// shares (via `Arc`) with the matching [`AsyncBoundedReceiver`].
// NOTE(review): there is no `Drop` impl that wakes `rx_waker`, so a receiver
// parked in `dequeue` is only re-polled if something else wakes it. Confirm
// whether dropping the inner `BoundedSender` closes the queue, and how a
// pending dequeue is expected to observe that.
pub struct AsyncBoundedSender<T> {
/// Woken after a successful enqueue so a receiver waiting in `dequeue` can retry.
rx_waker: Arc<AtomicWaker>,
/// Registered by a pending `EnqueueFuture` while the queue is full; woken by
/// the receiver after it dequeues an item.
tx_waker: Arc<AtomicWaker>,
/// The underlying synchronous sender.
queue: BoundedSender<T>,
}
/// Receiving half of the async bounded queue returned by [`async_queue`].
///
/// Wraps a synchronous [`BoundedReceiver`] together with the two wakers it
/// shares (via `Arc`) with the matching [`AsyncBoundedSender`].
// NOTE(review): there is no `Drop` impl that wakes `tx_waker`, so a sender
// parked in `enqueue` on a full queue is only re-polled if something else
// wakes it. Confirm whether dropping the inner `BoundedReceiver` closes the
// queue, and how a pending enqueue is expected to observe that.
pub struct AsyncBoundedReceiver<T> {
/// Registered by a pending `DequeueFuture` while the queue is empty; woken by
/// the sender after it enqueues an item.
rx_waker: Arc<AtomicWaker>,
/// Woken after a successful dequeue so a sender waiting in `enqueue` can retry.
tx_waker: Arc<AtomicWaker>,
/// The underlying synchronous receiver.
queue: BoundedReceiver<T>,
}
/// Future returned by [`AsyncBoundedSender::enqueue`].
///
/// Borrows the sender's wakers and queue for `'queue`. It resolves once the
/// carried item has been pushed, or once the queue reports it is closed (in
/// which case the item is handed back).
pub struct EnqueueFuture<'queue, T> {
/// Woken when the item is successfully enqueued.
rx_waker: &'queue AtomicWaker,
/// Where this future registers its task while the queue is full.
tx_waker: &'queue AtomicWaker,
/// The sender's underlying synchronous queue handle.
queue: &'queue mut BoundedSender<T>,
/// The item still waiting to be enqueued; `None` once it has been handed to
/// the queue or returned in an error.
data: Option<T>,
}
/// Future returned by [`AsyncBoundedReceiver::dequeue`].
///
/// Borrows the receiver's wakers and queue for `'queue`. It resolves to the
/// next item, or to an error once the queue reports it is closed.
pub struct DequeueFuture<'queue, T> {
/// Where this future registers its task while the queue is empty.
rx_waker: &'queue AtomicWaker,
/// Woken when an item is successfully dequeued.
tx_waker: &'queue AtomicWaker,
/// The receiver's underlying synchronous queue handle.
queue: &'queue mut BoundedReceiver<T>,
}
impl<T> AsyncBoundedSender<T> {
    /// Returns whether the underlying queue reports itself closed.
    pub fn is_closed(&self) -> bool {
        self.queue.is_closed()
    }

    /// Returns whether the underlying queue reports itself full.
    pub fn is_full(&self) -> bool {
        self.queue.is_full()
    }

    /// Returns a future that pushes `data` into the queue, waiting while it
    /// is full.
    ///
    /// The future borrows this sender mutably until it completes or is dropped.
    pub fn enqueue(&mut self, data: T) -> EnqueueFuture<'_, T> {
        let pending = Some(data);
        EnqueueFuture {
            queue: &mut self.queue,
            tx_waker: &self.tx_waker,
            rx_waker: &self.rx_waker,
            data: pending,
        }
    }

    /// Attempts to push `data` without waiting.
    ///
    /// On success, wakes any receiver parked in `dequeue`.
    ///
    /// # Errors
    ///
    /// Returns the item together with the underlying [`EnqueueError`] when
    /// the queue refuses it.
    pub fn try_enqueue(&mut self, data: T) -> Result<(), (T, EnqueueError)> {
        self.queue.try_enqueue(data)?;
        self.rx_waker.wake();
        Ok(())
    }
}
// Opaque representation: queue contents are never printed, so no
// `T: Debug` bound is required.
impl<T> Debug for AsyncBoundedSender<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Async-Bounded-Sender ()")
    }
}
impl<T> AsyncBoundedReceiver<T> {
    /// Returns whether the underlying queue reports itself closed.
    pub fn is_closed(&self) -> bool {
        self.queue.is_closed()
    }

    /// Returns whether the underlying queue reports itself empty.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }

    /// Returns a future that resolves to the next item, waiting while the
    /// queue is empty.
    ///
    /// The future borrows this receiver mutably until it completes or is dropped.
    pub fn dequeue(&mut self) -> DequeueFuture<'_, T> {
        DequeueFuture {
            queue: &mut self.queue,
            tx_waker: &self.tx_waker,
            rx_waker: &self.rx_waker,
        }
    }

    /// Attempts to pop an item without waiting.
    ///
    /// On success, wakes any sender parked in `enqueue`.
    ///
    /// # Errors
    ///
    /// Passes through the underlying [`DequeueError`] unchanged.
    pub fn try_dequeue(&mut self) -> Result<T, DequeueError> {
        let item = self.queue.try_dequeue()?;
        self.tx_waker.wake();
        Ok(item)
    }
}
// Opaque representation: queue contents are never printed, so no
// `T: Debug` bound is required.
impl<T> Debug for AsyncBoundedReceiver<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Async-Bounded-Receiver ()")
    }
}
// `poll` reaches `data` and `queue` through `Pin<&mut Self>`'s `DerefMut`,
// which requires `Self: Unpin`. The auto-derived impl would demand `T: Unpin`
// (because of `data: Option<T>`); this manual impl lifts that bound. It is
// sound because `data` is never pinned in place: `poll` only moves it in and
// out by value.
impl<'queue, T> Unpin for EnqueueFuture<'queue, T> {}
impl<'queue, T> Future for EnqueueFuture<'queue, T> {
    type Output = Result<(), (T, EnqueueError)>;

    /// Tries to push the pending item into the queue.
    ///
    /// Resolves to `Ok(())` once the item has been enqueued (waking the
    /// receiver's waker). Resolves to `Err((item, EnqueueError::Closed))` if
    /// the queue reports it is closed, handing the item back. While the queue
    /// is full, it registers the current task on `tx_waker` and returns
    /// `Pending`.
    ///
    /// After the future has resolved, further polls return `Ready(Ok(()))`.
    fn poll(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let data = match self.data.take() {
            Some(d) => d,
            // Already resolved on an earlier poll.
            None => return Poll::Ready(Ok(())),
        };

        // Fast path: try once without touching the waker slot.
        let data = match self.queue.try_enqueue(data) {
            Ok(_) => {
                self.rx_waker.wake();
                return Poll::Ready(Ok(()));
            }
            Err((d, EnqueueError::Full)) => d,
            Err((d, e @ EnqueueError::Closed)) => return Poll::Ready(Err((d, e))),
        };

        // The queue was full. Register *before* re-checking. If the receiver
        // dequeued between the attempt above and `register`, its
        // `tx_waker.wake()` found no waker (or a stale one) and that
        // notification is gone. Returning `Pending` without a retry would
        // park this task forever even though space is available.
        self.tx_waker.register(cx.waker());
        match self.queue.try_enqueue(data) {
            Ok(_) => {
                self.rx_waker.wake();
                Poll::Ready(Ok(()))
            }
            Err((d, EnqueueError::Full)) => {
                // Keep the item for the next poll; any dequeue from now on
                // will wake the waker registered above.
                self.data = Some(d);
                Poll::Pending
            }
            Err((d, e @ EnqueueError::Closed)) => Poll::Ready(Err((d, e))),
        }
    }
}
// Opaque representation: the pending item is never printed, so no
// `T: Debug` bound is required.
impl<T> Debug for EnqueueFuture<'_, T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Enqueue-Future ()")
    }
}
// `poll` reaches `queue` through `Pin<&mut Self>`'s `DerefMut`, which requires
// `Self: Unpin`. Every field of `DequeueFuture` is a reference, so the
// auto-derived impl would already apply. This explicit impl states that
// guarantee and keeps it in place if the struct ever changes.
impl<'queue, T> Unpin for DequeueFuture<'queue, T> {}
impl<'queue, T> Future for DequeueFuture<'queue, T> {
    type Output = Result<T, DequeueError>;

    /// Tries to pop the next item from the queue.
    ///
    /// Resolves to `Ok(item)` on success (waking the sender's waker), or to
    /// `Err(DequeueError::Closed)` if the queue reports it is closed. While
    /// the queue is empty, it registers the current task on `rx_waker` and
    /// returns `Pending`.
    fn poll(
        mut self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<Self::Output> {
        // Fast path: an item may already be available.
        match self.queue.try_dequeue() {
            Ok(d) => {
                self.tx_waker.wake();
                return Poll::Ready(Ok(d));
            }
            Err(DequeueError::Empty) => {}
            Err(DequeueError::Closed) => return Poll::Ready(Err(DequeueError::Closed)),
        }

        // The queue was empty. Register *before* re-checking. If the sender
        // enqueued between the attempt above and `register`, its
        // `rx_waker.wake()` found no waker (or a stale one) and that
        // notification is gone. Returning `Pending` without a retry would
        // park this task forever even though an item is waiting.
        self.rx_waker.register(cx.waker());
        match self.queue.try_dequeue() {
            Ok(d) => {
                self.tx_waker.wake();
                Poll::Ready(Ok(d))
            }
            // Any enqueue from now on will wake the waker registered above.
            Err(DequeueError::Empty) => Poll::Pending,
            Err(DequeueError::Closed) => Poll::Ready(Err(DequeueError::Closed)),
        }
    }
}
// Opaque representation: nothing generic is printed, so no `T: Debug` bound
// is required.
impl<T> Debug for DequeueFuture<'_, T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Dequeue-Future ()")
    }
}
/// Creates an async bounded queue and returns its `(receiver, sender)` halves.
///
/// `size` is forwarded unchanged to `super::queue`; its exact capacity
/// semantics are defined there. Both halves share the same pair of wakers:
/// `rx_waker` parks the receiver and `tx_waker` parks the sender.
pub fn async_queue<T>(size: usize) -> (AsyncBoundedReceiver<T>, AsyncBoundedSender<T>) {
    let (receiver_queue, sender_queue) = super::queue(size);

    let receiver_waker = Arc::new(AtomicWaker::new());
    let sender_waker = Arc::new(AtomicWaker::new());

    let sender = AsyncBoundedSender {
        rx_waker: Arc::clone(&receiver_waker),
        tx_waker: Arc::clone(&sender_waker),
        queue: sender_queue,
    };
    let receiver = AsyncBoundedReceiver {
        rx_waker: receiver_waker,
        tx_waker: sender_waker,
        queue: receiver_queue,
    };

    (receiver, sender)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A single item sent through the async API comes back out unchanged.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn enqueue_dequeue() {
        let (mut receiver, mut sender) = async_queue::<usize>(10);

        sender.enqueue(13).await.unwrap();

        let received = receiver.dequeue().await;
        assert_eq!(Ok(13), received);
    }
}