use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};
use atomic_waker::AtomicWaker;
use futures_core::Stream;
use crate::{Padded, Ring};
/// State shared (through an `Arc`) by the single [`AsyncProducer`] and the
/// single [`AsyncConsumer`] returned from [`async_spsc`].
///
/// Wraps the [`Ring`] and adds one waker per side, so each half can park its
/// task until the other half makes progress.
struct AsyncRing<T> {
    /// Underlying ring; besides the stored values it carries the `head` /
    /// `flush` indices and the `producer_dropped` / `consumer_dropped` flags
    /// used by the handles below.
    ring: Ring<T>,
    /// Registered by the consumer's `Stream::poll_next`; woken by the producer
    /// in `flush` and when the producer is dropped.
    /// NOTE(review): `Padded` presumably pads to a cache line to avoid false
    /// sharing between the two wakers — confirm.
    consumer_waker: Padded<AtomicWaker>,
    /// Registered by `PushFuture::poll`; woken by the consumer in `release`
    /// and when the consumer is dropped.
    producer_waker: Padded<AtomicWaker>,
}
// SAFETY: NOTE(review) — presumably sound because `T` values are only moved
// into and out of the ring (never shared by reference between threads), and
// the SPSC split guarantees a single writer for each index. Only `T: Send` is
// required, not `T: Sync`. Confirm against `Ring`'s implementation.
unsafe impl<T: Send> Send for AsyncRing<T> {}
unsafe impl<T: Send> Sync for AsyncRing<T> {}
// Runs once the last `Arc` handle (producer or consumer) has been dropped.
impl<T> Drop for AsyncRing<T> {
    fn drop(&mut self) {
        // Presumably drops every element still stored in the ring — confirm
        // against `Ring::drop_remaining`.
        self.ring.drop_remaining();
        // Reset the shared indices so the ring looks empty afterwards.
        // NOTE(review): presumably this keeps any cleanup inside `Ring` itself
        // from dropping the same elements a second time — confirm.
        // `get_mut` is usable because `&mut self` proves exclusive access.
        *self.ring.head.0.get_mut() = 0;
        *self.ring.flush.0.get_mut() = 0;
    }
}
/// Producing half of the channel created by [`async_spsc`].
///
/// Values written with `push` are not visible to the consumer until `flush`
/// is called. Dropping the producer flushes and then marks the channel as
/// closed from the producer side.
pub struct AsyncProducer<T> {
    /// Shared ring and wakers, co-owned with the consumer.
    ring: Arc<AsyncRing<T>>,
    /// Producer-local write index; published to the consumer by `flush`.
    tail: usize,
    /// Producer-local copy of the consumer's position, updated by `Ring::push`
    /// and `Ring::is_full`. NOTE(review): presumably only refreshed when the
    /// ring looks full — confirm.
    cached_head: usize,
}
// SAFETY: the handle only holds an `Arc<AsyncRing<T>>` and plain indices.
// With `AsyncRing<T>: Send + Sync` for `T: Send`, this matches what the auto
// trait would derive; it is kept explicit here.
unsafe impl<T: Send> Send for AsyncProducer<T> {}
/// Consuming half of the channel created by [`async_spsc`].
///
/// Newly flushed values are picked up by `prefetch`. Consumed slots are only
/// handed back to the producer by `release`. The type also implements
/// `Stream`, which does both automatically.
pub struct AsyncConsumer<T> {
    /// Shared ring and wakers, co-owned with the producer.
    ring: Arc<AsyncRing<T>>,
    /// Consumer-local read index; published to the producer by `release`.
    head: usize,
    /// Consumer-local copy of the producer's published `flush` index, updated
    /// by `prefetch`.
    cached_flush: usize,
}
// SAFETY: the handle only holds an `Arc<AsyncRing<T>>` and plain indices.
// With `AsyncRing<T>: Send + Sync` for `T: Send`, this matches what the auto
// trait would derive; it is kept explicit here.
unsafe impl<T: Send> Send for AsyncConsumer<T> {}
/// Creates an async single-producer/single-consumer channel.
///
/// `capacity` is passed straight to `Ring::new`.
/// NOTE(review): any rounding or validation of it happens there — confirm.
/// Both returned handles share one ring via `Arc`, and all their local
/// indices start at zero.
pub fn async_spsc<T>(capacity: usize) -> (AsyncProducer<T>, AsyncConsumer<T>) {
    let shared = Arc::new(AsyncRing {
        ring: Ring::new(capacity),
        consumer_waker: Padded(AtomicWaker::new()),
        producer_waker: Padded(AtomicWaker::new()),
    });
    let producer = AsyncProducer {
        ring: Arc::clone(&shared),
        tail: 0,
        cached_head: 0,
    };
    let consumer = AsyncConsumer {
        ring: shared,
        head: 0,
        cached_flush: 0,
    };
    (producer, consumer)
}
impl<T> AsyncProducer<T> {
    /// Writes `val` into the next free slot without publishing it.
    ///
    /// The value becomes visible to the consumer only after
    /// [`flush`](Self::flush).
    ///
    /// # Errors
    /// Returns `Err(val)`, giving the value back, when no slot is free.
    #[inline]
    pub fn push(&mut self, val: T) -> Result<(), T> {
        let Self {
            ring,
            tail,
            cached_head,
        } = self;
        ring.ring.push(tail, cached_head, val)
    }

    /// Publishes every value pushed so far (Release store of the write
    /// index) and wakes a consumer parked in `poll_next`.
    #[inline]
    pub fn flush(&mut self) {
        let shared = &*self.ring;
        shared.ring.flush.0.store(self.tail, Ordering::Release);
        shared.consumer_waker.0.wake();
    }

    /// Pushes `val` and, only if that succeeded, flushes immediately.
    ///
    /// # Errors
    /// Returns `Err(val)` when the ring is full. Nothing is flushed in that case.
    #[inline]
    pub fn push_and_flush(&mut self, val: T) -> Result<(), T> {
        self.push(val).map(|()| self.flush())
    }

    /// Returns a future that waits for a free slot and then pushes `val`.
    ///
    /// The pushed value is not flushed by the future. The future resolves to
    /// `Err(val)` if the ring is full and the consumer has been dropped.
    #[inline]
    pub fn push_async(&mut self, val: T) -> PushFuture<'_, T> {
        let val = Some(val);
        PushFuture {
            val,
            producer: self,
        }
    }

    /// Reports whether the ring currently has no free slot.
    ///
    /// Takes `&mut self` because the ring may refresh the cached consumer
    /// position while answering.
    #[inline]
    pub fn is_full(&mut self) -> bool {
        let tail = self.tail;
        self.ring.ring.is_full(tail, &mut self.cached_head)
    }

    /// Capacity of the underlying ring.
    #[inline]
    pub fn capacity(&self) -> usize {
        let inner = &self.ring.ring;
        inner.capacity()
    }

    /// Length as seen from the producer side.
    /// NOTE(review): exact semantics are defined by `Ring::producer_len`.
    #[inline]
    pub fn len(&self) -> usize {
        let inner = &self.ring.ring;
        inner.producer_len(self.tail)
    }

    /// Emptiness as seen from the producer side (see `Ring::producer_is_empty`).
    #[inline]
    pub fn is_empty(&self) -> bool {
        let inner = &self.ring.ring;
        inner.producer_is_empty(self.tail)
    }
}
/// Future returned by [`AsyncProducer::push_async`].
///
/// Resolves to `Ok(())` once the value has been written into the ring; the
/// value itself still needs a flush to become visible to the consumer.
/// Resolves to `Err(val)`, handing the value back, if the ring is full and the
/// consumer has been dropped.
// A hand-written future does nothing unless polled; warn on accidental drops
// such as `producer.push_async(v);` without `.await`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct PushFuture<'a, T> {
    /// Producer the value is pushed through. The exclusive borrow keeps any
    /// other push from interleaving while the future is alive.
    producer: &'a mut AsyncProducer<T>,
    /// Value still waiting for a slot; `None` once the future has completed.
    val: Option<T>,
}

impl<T> std::fmt::Debug for PushFuture<'_, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PushFuture")
            .field("has_value", &self.val.is_some())
            .finish_non_exhaustive()
    }
}

// `poll` only ever uses `Pin::get_mut` and never pin-projects a field, so the
// future is `Unpin` for every `T`, not just for `T: Unpin`.
impl<T> Unpin for PushFuture<'_, T> {}
impl<T> Future for PushFuture<'_, T> {
    type Output = Result<(), T>;

    /// Tries to enqueue the pending value, parking the task on
    /// `producer_waker` while the ring is full.
    ///
    /// Returns `Ready(Ok(()))` once the value is in the ring. Returns
    /// `Ready(Err(val))` when the ring is full and the consumer has been
    /// dropped.
    ///
    /// # Panics
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let val = this.val.take().expect("PushFuture polled after completion");

        // Fast path: a slot is free, so the waker is never touched.
        let val = match this.producer.push(val) {
            Ok(()) => return Poll::Ready(Ok(())),
            Err(val) => val,
        };
        if this
            .producer
            .ring
            .ring
            .consumer_dropped
            .load(Ordering::Acquire)
        {
            return Poll::Ready(Err(val));
        }

        // The ring is full. Publish everything pushed so far before parking:
        // if the ring were full of never-flushed values, the consumer could
        // never see them, would never release their slots, and this future
        // would wait forever.
        this.producer.flush();

        this.producer.ring.producer_waker.0.register(cx.waker());

        // Re-check BOTH wake conditions after registering. A `release` or a
        // consumer drop that happened before `register` does not wake us, so
        // without this second look at `consumer_dropped` the future could stay
        // pending forever after the consumer has gone away.
        let val = match this.producer.push(val) {
            Ok(()) => return Poll::Ready(Ok(())),
            Err(val) => val,
        };
        if this
            .producer
            .ring
            .ring
            .consumer_dropped
            .load(Ordering::Acquire)
        {
            return Poll::Ready(Err(val));
        }

        this.val = Some(val);
        Poll::Pending
    }
}
impl<T> AsyncConsumer<T> {
    /// Takes the next value from the range already picked up by
    /// [`prefetch`](Self::prefetch).
    ///
    /// The slot is not handed back to the producer until
    /// [`release`](Self::release) is called.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        let Self {
            ring,
            head,
            cached_flush,
        } = self;
        ring.ring.pop(head, *cached_flush)
    }

    /// Publishes the read position, returning consumed slots to the producer,
    /// and wakes a producer parked in `push_async`.
    #[inline]
    pub fn release(&mut self) {
        let shared = &*self.ring;
        shared.ring.release(self.head);
        shared.producer_waker.0.wake();
    }

    /// Refreshes the local copy of the producer's flushed position.
    ///
    /// NOTE(review): the returned count is presumably the number of values now
    /// available — confirm against `Ring::prefetch`.
    #[inline]
    pub fn prefetch(&mut self) -> usize {
        let Self {
            ring, cached_flush, ..
        } = self;
        ring.ring.prefetch(cached_flush)
    }

    /// Convenience wrapper that prefetches when the local range is used up,
    /// pops one value, and releases its slot right away if one was popped.
    #[inline]
    pub fn prefetch_and_pop(&mut self) -> Option<T> {
        let exhausted = self.head == self.cached_flush;
        if exhausted {
            self.prefetch();
        }
        self.pop().map(|val| {
            self.release();
            val
        })
    }

    /// Emptiness as seen from the consumer side (see `Ring::consumer_is_empty`).
    #[inline]
    pub fn is_empty(&self) -> bool {
        let inner = &self.ring.ring;
        inner.consumer_is_empty(self.head, self.cached_flush)
    }

    /// Capacity of the underlying ring.
    #[inline]
    pub fn capacity(&self) -> usize {
        let inner = &self.ring.ring;
        inner.capacity()
    }

    /// Length as seen from the consumer side.
    /// NOTE(review): exact semantics are defined by `Ring::consumer_len`.
    #[inline]
    pub fn len(&self) -> usize {
        let inner = &self.ring.ring;
        inner.consumer_len(self.head)
    }
}
impl<T> Stream for AsyncConsumer<T> {
    type Item = T;

    /// Yields the next flushed value.
    ///
    /// When nothing is available, the method releases consumed slots and parks
    /// on `consumer_waker`. It returns `Ready(None)` only once the producer has
    /// been dropped AND every value it flushed has been drained.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.head == this.cached_flush {
            this.prefetch();
        }
        if let Some(val) = this.pop() {
            return Poll::Ready(Some(val));
        }

        // Nothing visible. Hand consumed slots back to the producer (this
        // also wakes it), then register before re-checking so a concurrent
        // `flush` or producer drop cannot be missed.
        this.release();
        this.ring.consumer_waker.0.register(cx.waker());

        // Read the close flag BEFORE re-reading `flush`. The producer's Drop
        // stores its final `flush` and then sets `producer_dropped`, both with
        // Release ordering. Once this Acquire load sees the flag set, the
        // prefetch below therefore observes every value ever flushed
        // (assuming `prefetch` does an Acquire load of `flush` — confirm).
        // Reading the flag after the prefetch, as before, could return `None`
        // while the producer's final batch was still sitting in the ring.
        let producer_gone = this.ring.ring.producer_dropped.load(Ordering::Acquire);
        if this.head == this.cached_flush {
            this.prefetch();
        }
        match this.pop() {
            Some(val) => Poll::Ready(Some(val)),
            None if producer_gone => Poll::Ready(None),
            None => Poll::Pending,
        }
    }
}
// Closing the consumer side: free every slot it still holds, then tell the
// producer that nobody will read any more.
impl<T> Drop for AsyncConsumer<T> {
    fn drop(&mut self) {
        // Return popped-but-unreleased slots (this also wakes the producer).
        self.release();
        // Publish the close and wake the producer again, so a parked
        // `PushFuture` re-polls and observes the flag.
        let shared = &self.ring;
        shared.ring.consumer_dropped.store(true, Ordering::Release);
        shared.producer_waker.0.wake();
    }
}
// Closing the producer side: publish anything still pending, then signal
// end-of-stream to the consumer.
impl<T> Drop for AsyncProducer<T> {
    fn drop(&mut self) {
        // Make every pushed value visible, so the consumer can drain it.
        self.flush();
        // Only then mark the channel closed and wake the consumer, so a
        // parked `poll_next` re-polls and can observe the flag.
        let shared = &self.ring;
        shared.ring.producer_dropped.store(true, Ordering::Release);
        shared.consumer_waker.0.wake();
    }
}
// Debug output reports only the ring capacity. `T` is not required to be
// `Debug`, so stored values are never printed.
impl<T> std::fmt::Debug for AsyncProducer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.capacity();
        let mut builder = f.debug_struct("AsyncProducer");
        builder.field("capacity", &capacity);
        builder.finish_non_exhaustive()
    }
}
impl<T> std::fmt::Debug for AsyncConsumer<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let capacity = self.capacity();
        let mut builder = f.debug_struct("AsyncConsumer");
        builder.field("capacity", &capacity);
        builder.finish_non_exhaustive()
    }
}
// Single-thread and cross-thread tests for the async SPSC channel.
#[cfg(test)]
mod tests {
    use futures_lite::StreamExt;
    use super::*;

    // Pushed values stay invisible to the consumer until `flush`.
    #[test]
    fn async_push_pop() {
        let (mut p, mut c) = async_spsc::<u32>(4);
        p.push(1).unwrap();
        p.push(2).unwrap();
        assert!(c.prefetch_and_pop().is_none());
        p.flush();
        assert_eq!(c.prefetch_and_pop(), Some(1));
        assert_eq!(c.prefetch_and_pop(), Some(2));
    }

    // The `Stream` impl yields flushed values in FIFO order.
    #[test]
    fn stream_impl() {
        futures_lite::future::block_on(async {
            let (mut p, mut c) = async_spsc::<u32>(8);
            p.push(10).unwrap();
            p.push(20).unwrap();
            p.push(30).unwrap();
            p.flush();
            assert_eq!(c.next().await, Some(10));
            assert_eq!(c.next().await, Some(20));
            assert_eq!(c.next().await, Some(30));
        });
    }

    // A consumer parked in `next().await` on another thread is woken by `flush`.
    #[test]
    fn stream_wakes_on_flush() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let (mut p, mut c) = async_spsc::<u32>(8);
        let done = Arc::new(AtomicBool::new(false));
        let done2 = done.clone();
        let handle = std::thread::spawn(move || {
            futures_lite::future::block_on(async {
                let val = c.next().await;
                done2.store(true, Ordering::Release);
                val
            })
        });
        // Give the consumer time to park; nothing has been flushed yet, so it
        // must not have completed.
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(!done.load(Ordering::Acquire));
        p.push(42).unwrap();
        p.flush();
        let val = handle.join().unwrap();
        assert_eq!(val, Some(42));
    }

    // Ordered high-volume transfer with the sync `push`. The producer flushes
    // every 64 values, and also whenever the ring is full, so the consumer can
    // see (and then release) those slots.
    #[test]
    fn cross_thread_stream() {
        let (mut p, c) = async_spsc::<u64>(1024);
        let n = 50_000u64;
        let receiver = std::thread::spawn(move || {
            futures_lite::future::block_on(async {
                futures_lite::pin!(c);
                let mut received = 0u64;
                while let Some(v) = c.next().await {
                    assert_eq!(v, received);
                    received += 1;
                    if received == n {
                        break;
                    }
                }
                received
            })
        });
        for i in 0..n {
            while p.push(i).is_err() {
                p.flush();
                std::thread::yield_now();
            }
            if i % 64 == 63 {
                p.flush();
            }
        }
        p.flush();
        let count = receiver.join().unwrap();
        assert_eq!(count, n);
    }

    // Ping-pong through a rendezvous channel: each value must be received
    // before the next is sent, exercising repeated consumer park/wake cycles.
    #[test]
    fn alternating_push_pop_wakes() {
        use std::sync::mpsc;
        let (mut p, c) = async_spsc::<u32>(8);
        let (tx, rx) = mpsc::sync_channel::<u32>(0);
        let handle = std::thread::spawn(move || {
            futures_lite::future::block_on(async {
                futures_lite::pin!(c);
                for _ in 0..5 {
                    let val = c.next().await.unwrap();
                    tx.send(val).unwrap();
                }
            });
        });
        for i in 0..5 {
            p.push_and_flush(i).unwrap();
            let val = rx.recv_timeout(std::time::Duration::from_secs(3)).unwrap();
            assert_eq!(val, i);
        }
        handle.join().unwrap();
    }

    // `push_async` stays pending while the ring is full and completes once the
    // consumer releases a slot.
    #[test]
    fn push_async_blocks_when_full() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let (mut p, mut c) = async_spsc::<u32>(4);
        for i in 0..4 {
            p.push(i).unwrap();
        }
        p.flush();
        assert!(p.push(99).is_err());
        let pushed = Arc::new(AtomicBool::new(false));
        let pushed2 = pushed.clone();
        let handle = std::thread::spawn(move || {
            futures_lite::future::block_on(async {
                p.push_async(99).await.unwrap();
                p.flush();
                pushed2.store(true, Ordering::Release);
            });
        });
        std::thread::sleep(std::time::Duration::from_millis(20));
        assert!(!pushed.load(Ordering::Acquire), "should be blocked");
        // Free exactly one slot (value 0); `release` wakes the blocked producer.
        c.prefetch();
        c.pop();
        c.release();
        handle.join().unwrap();
        assert!(pushed.load(Ordering::Acquire));
        // Drain the remaining 1, 2, 3; the fourth pop must be the 99.
        c.prefetch();
        c.pop(); c.pop(); c.pop(); let val = c.pop(); assert_eq!(val, Some(99));
    }

    // End-to-end backpressure with `push_async` on a small ring while the
    // consumer drains through the `Stream` impl.
    #[test]
    fn push_async_cross_thread() {
        let (mut p, c) = async_spsc::<u64>(64);
        let n = 100_000u64;
        let receiver = std::thread::spawn(move || {
            futures_lite::future::block_on(async {
                futures_lite::pin!(c);
                let mut received = 0u64;
                while let Some(v) = c.next().await {
                    assert_eq!(v, received);
                    received += 1;
                    if received == n {
                        break;
                    }
                }
                received
            })
        });
        futures_lite::future::block_on(async {
            for i in 0..n {
                p.push_async(i).await.unwrap();
                if i % 64 == 63 {
                    p.flush();
                }
            }
            p.flush();
        });
        let count = receiver.join().unwrap();
        assert_eq!(count, n);
    }

    // With a full ring and the consumer gone, `push_async` resolves to `Err`
    // and hands the value back.
    #[test]
    fn push_async_returns_err_on_consumer_drop() {
        let (mut p, c) = async_spsc::<u32>(4);
        for i in 0..4 {
            p.push(i).unwrap();
        }
        p.flush();
        drop(c);
        let result = futures_lite::future::block_on(async { p.push_async(99).await });
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), 99);
    }
}