use std::{
collections::VecDeque,
future::Future,
pin::Pin,
sync::{Arc, Mutex, MutexGuard},
task::{Context, Poll, Waker},
};
/// Double-ended queue with a single asynchronous consumer.
///
/// * `queue` is `Some` while the deque is open and becomes `None` once
///   [`AsyncDeque::close`] has been called; values pushed after that are dropped.
/// * `waker` holds the waker stored by the last `poll_pop` call that found the
///   queue empty. It is taken and woken as soon as an item arrives or the deque
///   is closed.
#[derive(Debug)]
struct AsyncDeque<T> {
    queue: Option<VecDeque<T>>,
    waker: Option<Waker>,
}

impl<T> AsyncDeque<T> {
    /// Wakes the task parked in `poll_pop`, if any, and forgets its waker.
    fn wake_consumer(&mut self) {
        if let Some(waker) = self.waker.take() {
            waker.wake();
        }
    }

    /// Appends `value` at the back and wakes the waiting consumer.
    ///
    /// Does nothing (dropping `value`) when the deque is closed.
    fn push_back(&mut self, value: T) {
        if let Some(queue) = &mut self.queue {
            queue.push_back(value);
            self.wake_consumer();
        }
    }

    /// Inserts `value` at the front, so it is the next item popped, and wakes
    /// the waiting consumer.
    ///
    /// Does nothing (dropping `value`) when the deque is closed.
    fn push_front(&mut self, value: T) {
        if let Some(queue) = &mut self.queue {
            queue.push_front(value);
            self.wake_consumer();
        }
    }

    /// Tries to pop the front item.
    ///
    /// Returns `Ready(Some(item))` when an item is available, `Ready(None)` when
    /// the deque is closed, and otherwise stores the waker from `cx` and returns
    /// `Pending`.
    ///
    /// # Panics
    ///
    /// Panics if a waker is already stored and `Waker::will_wake` reports that
    /// it would not wake the task polling now: only one task may wait at a time.
    fn poll_pop(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let queue = match &mut self.queue {
            Some(queue) => queue,
            None => return Poll::Ready(None),
        };
        if let Some(item) = queue.pop_front() {
            return Poll::Ready(Some(item));
        }
        if let Some(registered) = &self.waker {
            assert!(
                registered.will_wake(cx.waker()),
                "Multiple tasks are attempting to wait on the same AsyncDeque. This is a bug, please report it."
            );
        }
        // Store the waker of the current poll; any previously stored waker
        // has just been checked to wake the same task.
        self.waker = Some(cx.waker().clone());
        Poll::Pending
    }

    /// Number of queued items; `0` once the deque is closed.
    fn len(&self) -> usize {
        self.queue.as_ref().map_or(0, VecDeque::len)
    }

    /// `true` when nothing is queued, which is always the case after `close`.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Closes the deque: queued items are dropped and the waiting consumer is
    /// woken so that its next poll observes `None`.
    pub fn close(&mut self) {
        self.queue = None;
        self.wake_consumer();
    }
}

impl<T> Extend<T> for AsyncDeque<T> {
    /// Appends every item of `iter` at the back.
    ///
    /// The waiting consumer is woken only if the queue holds something
    /// afterwards, so extending with an empty iterator causes no spurious
    /// wakeup. When the deque is closed, `iter` is dropped without being
    /// consumed.
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        if let Some(queue) = &mut self.queue {
            queue.extend(iter);
            let has_items = !queue.is_empty();
            if has_items {
                self.wake_consumer();
            }
        }
    }
}
/// Cloneable handle to an [`AsyncDeque`] shared behind an `Arc<Mutex<_>>`.
///
/// Every clone refers to the same underlying queue.
#[derive(Debug)]
pub struct ArcAsyncDeque<T>(Arc<Mutex<AsyncDeque<T>>>);

impl<T> ArcAsyncDeque<T> {
    /// Creates an open, empty deque with room for 8 items.
    pub fn new() -> Self {
        Self::with_capacity(8)
    }

    /// Creates an open, empty deque with room for `capacity` items.
    pub fn with_capacity(capacity: usize) -> Self {
        Self(Arc::new(Mutex::new(AsyncDeque {
            queue: Some(VecDeque::with_capacity(capacity)),
            waker: None,
        })))
    }

    /// Locks the shared state.
    ///
    /// `poll_pop` panics on purpose while the lock is held when a second task
    /// starts waiting. That panic poisons the mutex, so a poisoned lock is
    /// recovered here instead of unwrapped; otherwise every other handle would
    /// panic on its next access.
    fn lock_guard(&self) -> MutexGuard<'_, AsyncDeque<T>> {
        self.0
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Inserts `value` at the front of the queue; ignored once closed.
    pub fn push_front(&self, value: T) {
        self.lock_guard().push_front(value);
    }

    /// Appends `value` at the back of the queue; ignored once closed.
    pub fn push_back(&self, value: T) {
        self.lock_guard().push_back(value);
    }

    /// Returns a clone of this handle, to be awaited as a future that resolves
    /// to the next item, or to `None` once the deque is closed.
    pub fn pop(&self) -> Self {
        self.clone()
    }

    /// Polls for the next item, storing the waker from `cx` when none is ready.
    ///
    /// # Panics
    ///
    /// Panics if a different task is already waiting on this deque.
    pub fn poll_pop(&self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        self.lock_guard().poll_pop(cx)
    }

    /// Number of queued items; `0` once the deque is closed.
    pub fn len(&self) -> usize {
        self.lock_guard().len()
    }

    /// `true` when no items are queued.
    pub fn is_empty(&self) -> bool {
        self.lock_guard().is_empty()
    }

    /// Closes the deque, dropping queued items and waking the waiting consumer.
    pub fn close(&self) {
        self.lock_guard().close();
    }
}
impl<T> Default for ArcAsyncDeque<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Clone for ArcAsyncDeque<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Awaiting a handle resolves to the next item, or `None` once the deque is
/// closed.
impl<T> Future for ArcAsyncDeque<T> {
    type Output = Option<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `poll_pop` takes `&self`, so a shared reborrow of the pinned handle
        // is all that is needed.
        let handle: &Self = &*self;
        handle.poll_pop(cx)
    }
}
/// Streams the queued items; the stream ends when the deque is closed.
///
/// No `T: Unpin` bound is required: the handle only holds an `Arc`, and
/// `poll_pop` takes `&self`, so items are never accessed through the pin. This
/// matches the unbounded `Future` impl.
impl<T> futures::Stream for ArcAsyncDeque<T> {
    type Item = T;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.poll_pop(cx)
    }
}
/// Lets a shared handle be extended (`(&deque).extend(items)`) without
/// exclusive access; the items are appended under the lock.
impl<T> Extend<T> for &ArcAsyncDeque<T> {
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        // Use the same locking helper as every other operation on the handle
        // so the lock-acquisition policy lives in one place.
        self.lock_guard().extend(iter);
    }
}
// Tests for `ArcAsyncDeque`, run on the tokio test runtime.
#[cfg(test)]
mod tests {
// `FutureExt` provides `poll_unpin`, used by `cancel` to poll a future once by hand.
use futures::FutureExt;
use super::*;
// Items come out front-to-back; `push_front` makes its value the next one popped.
#[tokio::test]
async fn push_pop() {
let deque = ArcAsyncDeque::new();
assert!(deque.is_empty());
deque.push_back(1);
deque.push_back(2);
assert_eq!(deque.len(), 2);
assert_eq!(deque.pop().await, Some(1));
assert_eq!(deque.pop().await, Some(2));
let deque = ArcAsyncDeque::with_capacity(2);
deque.push_back(1);
deque.push_front(2);
assert_eq!(deque.len(), 2);
assert_eq!(deque.pop().await, Some(2));
assert_eq!(deque.pop().await, Some(1));
}
// Closing discards the queued items and makes `pop` resolve to `None`.
#[tokio::test]
async fn close() {
let deque = ArcAsyncDeque::new();
assert!(deque.is_empty());
deque.push_back(1);
deque.push_back(2);
assert_eq!(deque.len(), 2);
deque.close();
assert!(deque.is_empty());
assert_eq!(deque.pop().await, None);
}
// A `pop` racing against a branch that pushes and then never completes must
// still resolve with the pushed item.
#[tokio::test]
async fn wake() {
let deque = ArcAsyncDeque::new();
tokio::select! {
item = deque.pop() => {
assert_eq!(item, Some(1));
}
_ = async {
deque.push_back(1);
std::future::pending::<()>().await;
} => unreachable!()
}
// NOTE(review): this second scenario repeats the first one verbatim; it may
// have been meant to cover another producer such as `push_front` - confirm.
let deque = ArcAsyncDeque::new();
tokio::select! {
item = deque.pop() => {
assert_eq!(item, Some(1));
}
_ = async {
deque.push_back(1);
std::future::pending::<()>().await;
} => unreachable!()
}
}
// Each `pop()` future is polled exactly once and then dropped. The repeated
// `Pending` polls at the end all come from this test's task, so they are
// expected not to trip the multiple-waiters panic.
#[tokio::test]
async fn cancel() {
let deque = ArcAsyncDeque::new();
let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
assert_eq!(poll, Poll::Pending);
// `Extend` is implemented for `&ArcAsyncDeque`, hence the explicit borrow.
(&deque).extend([654]);
let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
assert_eq!(poll, Poll::Ready(Some(654)));
let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
assert_eq!(poll, Poll::Pending);
let poll = core::future::poll_fn(|cx| Poll::Ready(deque.pop().poll_unpin(cx))).await;
assert_eq!(poll, Poll::Pending);
}
// Two spawned tasks wait on the same deque; the second one is expected to fail,
// which surfaces as an `Err` from its join handle.
#[tokio::test]
async fn racing() {
let deque: ArcAsyncDeque<()> = ArcAsyncDeque::new();
let consumer = tokio::spawn(deque.pop());
// Presumably gives the first task a chance to be polled and store its
// waker before the second one starts - relies on tokio's scheduling.
tokio::task::yield_now().await;
let abuse = tokio::spawn(deque.pop());
tokio::task::yield_now().await;
_ = consumer;
assert!(abuse.await.is_err());
}
}