use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::task::{Context, Poll};
use atomic_waker::AtomicWaker;
use futures_core::Stream;
use crate::{FlushResult, Padded, Ring};
/// Shared state behind one [`AsyncProducer`] / [`AsyncConsumer`] pair: the
/// underlying ring buffer plus the waker the producer uses to wake a pending
/// consumer.
struct AsyncRing<T> {
// Slot storage plus the shared `head` (stored only by the consumer) and
// `flush` (stored only by the producer) indices.
ring: Ring<T>,
// Registered by `AsyncConsumer::poll_next`, fired by `AsyncProducer::flush`.
// `Padded` is presumably cache-line padding against false sharing — TODO
// confirm against its definition.
waker: Padded<AtomicWaker>,
}
// SAFETY: sharing the ring between the two halves is sound as long as the SPSC
// protocol holds: only the producer writes slots at or after the published
// `flush`, only the consumer reads slots in `head..flush`, and ownership of a
// slot is handed over through the Release/Acquire stores/loads of `head` and
// `flush`. `T: Send` is required because values move from the producer's
// thread to the consumer's.
unsafe impl<T: Send> Send for AsyncRing<T> {}
unsafe impl<T: Send> Sync for AsyncRing<T> {}
impl<T> Drop for AsyncRing<T> {
    /// Drops every item that was published (`flush`) but not yet consumed
    /// (`head`), then zeroes both indices.
    ///
    /// The indices are free-running counters that may wrap, so the number of
    /// pending items is their modular difference rather than a plain
    /// `head..flush` range (which would be empty after a wrap and leak).
    fn drop(&mut self) {
        let head = *self.ring.head.0.get_mut();
        let flush = *self.ring.flush.0.get_mut();
        let pending = flush.wrapping_sub(head);
        for offset in 0..pending {
            let slot = head.wrapping_add(offset) & self.ring.mask;
            // SAFETY: slots in the published-but-unconsumed range were
            // initialised by the producer and never read by the consumer, and
            // `&mut self` guarantees neither half is still alive.
            unsafe {
                self.ring.buf[slot].get_mut().assume_init_drop();
            }
        }
        // Reset the indices so any later cleanup of `ring` that walks
        // `head..flush` finds nothing left to drop.
        *self.ring.head.0.get_mut() = 0;
        *self.ring.flush.0.get_mut() = 0;
    }
}
pub struct AsyncProducer<T> {
ring: Arc<AsyncRing<T>>,
tail: usize,
cached_head: usize,
}
unsafe impl<T: Send> Send for AsyncProducer<T> {}
/// Receiving half of an [`async_spsc`] channel.
///
/// Also implements [`Stream`], which registers a waker when no item is ready.
pub struct AsyncConsumer<T> {
ring: Arc<AsyncRing<T>>,
// Index of the next slot to read; stored into the shared `head` after each pop.
head: usize,
// Consumer's last observed value of the producer's shared `flush` index;
// `pop` only reads slots in `head..cached_flush`.
cached_flush: usize,
}
// SAFETY: the consumer only reads slots in the published range and touches
// atomic indices; `T: Send` because popped values come from the producer's
// thread.
unsafe impl<T: Send> Send for AsyncConsumer<T> {}
/// Creates a bounded single-producer / single-consumer channel whose receiving
/// half implements [`Stream`].
///
/// `capacity` is forwarded unchanged to `Ring::new`; any rounding (the use of
/// a slot `mask` suggests a power of two) happens there — TODO confirm. Both
/// halves start with every index at zero and share one ring via `Arc`.
pub fn async_spsc<T>(capacity: usize) -> (AsyncProducer<T>, AsyncConsumer<T>) {
    let shared = Arc::new(AsyncRing {
        ring: Ring::new(capacity),
        waker: Padded(AtomicWaker::new()),
    });

    let producer = AsyncProducer {
        ring: Arc::clone(&shared),
        tail: 0,
        cached_head: 0,
    };
    let consumer = AsyncConsumer {
        ring: shared,
        head: 0,
        cached_flush: 0,
    };

    (producer, consumer)
}
impl<T> AsyncProducer<T> {
    /// Writes `val` into the next free slot without publishing it.
    ///
    /// The value becomes visible to the consumer only after a later
    /// [`flush`](Self::flush).
    ///
    /// # Errors
    ///
    /// Returns `Err(val)`, handing the value back, when the ring is full.
    #[inline]
    pub fn push(&mut self, val: T) -> Result<(), T> {
        if self.is_full() {
            return Err(val);
        }
        // SAFETY: `is_full` returned false, so `tail - head < capacity` and
        // (given the ring's invariant that `capacity` slots map to distinct
        // indices under `mask`) the slot at `tail` holds no unconsumed value.
        // The consumer never reads at or beyond the published `flush`, and
        // `tail` is at or beyond it.
        unsafe {
            (*self.ring.ring.buf[self.tail & self.ring.ring.mask].get()).write(val);
        }
        // Free-running counter: wrap instead of panicking on overflow,
        // matching the `wrapping_sub` used by `len` and the consumer.
        self.tail = self.tail.wrapping_add(1);
        Ok(())
    }

    /// Publishes every item pushed since the previous flush and wakes the
    /// consumer's registered waker, if any.
    ///
    /// Returns [`FlushResult::NothingToFlush`] when nothing was pushed since
    /// the last flush. Otherwise returns `Flushed { count, was_empty }`.
    /// `count` is the number of newly published items. `was_empty` reports
    /// whether, when `head` was sampled just before publishing, the consumer
    /// had consumed everything previously published. It is a snapshot and may
    /// already be stale.
    #[inline]
    pub fn flush(&mut self) -> FlushResult {
        // Only this producer stores `flush`, so a relaxed load is current.
        let prev_flush = self.ring.ring.flush.0.load(Ordering::Relaxed);
        if self.tail == prev_flush {
            return FlushResult::NothingToFlush;
        }
        let count = self.tail.wrapping_sub(prev_flush);
        self.cached_head = self.ring.ring.head.0.load(Ordering::Acquire);
        let was_empty = prev_flush == self.cached_head;
        // Release: the slot writes done by `push` happen-before the consumer's
        // Acquire load of `flush`.
        self.ring.ring.flush.0.store(self.tail, Ordering::Release);
        // Wake unconditionally. Gating on `was_empty` loses wakeups: between
        // the `head` load above and the store, the consumer can drain the last
        // item, re-check `flush` (still old), register and return `Pending`.
        // Every later flush would then also see `prev_flush != head` and skip
        // the wake, leaving the consumer asleep forever. `AtomicWaker::wake`
        // is a single atomic RMW when no waker is registered.
        self.ring.waker.0.wake();
        FlushResult::Flushed { count, was_empty }
    }

    /// Pushes `val` and immediately flushes.
    ///
    /// # Errors
    ///
    /// Returns `Err(val)` when the ring is full; nothing is flushed in that case.
    #[inline]
    pub fn push_and_flush(&mut self, val: T) -> Result<FlushResult, T> {
        self.push(val)?;
        Ok(self.flush())
    }

    /// Returns `true` when no slot is free for another `push`.
    ///
    /// Reloads the consumer's `head` only when the cached view says the ring
    /// is full, which is why this takes `&mut self`.
    #[inline]
    pub fn is_full(&mut self) -> bool {
        let capacity = self.ring.ring.capacity();
        if self.tail.wrapping_sub(self.cached_head) < capacity {
            return false;
        }
        // Acquire: pairs with the consumer's Release store of `head`, so its
        // reads of freed slots happen-before we overwrite them.
        self.cached_head = self.ring.ring.head.0.load(Ordering::Acquire);
        self.tail.wrapping_sub(self.cached_head) >= capacity
    }

    /// Capacity of the underlying ring.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.ring.ring.capacity()
    }

    /// Number of items not yet consumed, including pushed-but-unflushed ones.
    #[inline]
    pub fn len(&self) -> usize {
        self.tail
            .wrapping_sub(self.ring.ring.head.0.load(Ordering::Acquire))
    }

    /// Returns `true` when the consumer has consumed everything pushed so far.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.tail == self.ring.ring.head.0.load(Ordering::Acquire)
    }
}
impl<T> AsyncConsumer<T> {
    /// Pops the next item from the locally cached range of published items.
    ///
    /// Does not reload the producer's shared `flush` index. It returns `None`
    /// once the cached range is exhausted, even if more items were published
    /// since. Use [`prefetch`](Self::prefetch) or
    /// [`prefetch_and_pop`](Self::prefetch_and_pop) to refresh.
    #[inline]
    pub fn pop(&mut self) -> Option<T> {
        if self.head == self.cached_flush {
            return None;
        }
        // SAFETY: `head != cached_flush`, so the slot at `head` lies in the
        // published range. The producer initialised it before its Release
        // store of `flush`, which `prefetch` observed with Acquire, and it
        // will not overwrite the slot until it observes `head` moving past it.
        let val = unsafe {
            (*self.ring.ring.buf[self.head & self.ring.ring.mask].get()).assume_init_read()
        };
        // Free-running counter: wrap rather than panic on overflow, matching
        // the `wrapping_sub` used by `prefetch` and `len`.
        self.head = self.head.wrapping_add(1);
        // Release: the slot read above happens-before the producer's Acquire
        // load of `head`, after which it may reuse the slot.
        self.ring.ring.head.0.store(self.head, Ordering::Release);
        Some(val)
    }

    /// Refreshes the cached copy of the producer's published `flush` index.
    ///
    /// Returns how many items became visible since the previous refresh.
    #[inline]
    pub fn prefetch(&mut self) -> usize {
        let new_flush = self.ring.ring.flush.0.load(Ordering::Acquire);
        let count = new_flush.wrapping_sub(self.cached_flush);
        self.cached_flush = new_flush;
        count
    }

    /// Pops one item. If the cached range is empty, the cached `flush` index
    /// is refreshed first.
    #[inline]
    pub fn prefetch_and_pop(&mut self) -> Option<T> {
        if self.head == self.cached_flush {
            self.prefetch();
        }
        self.pop()
    }

    /// Returns `true` when no published item remains to be popped.
    ///
    /// The shared `flush` index is loaded only when the cached range is
    /// already exhausted.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.head == self.cached_flush
            && self.ring.ring.flush.0.load(Ordering::Acquire) == self.head
    }

    /// Capacity of the underlying ring.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.ring.ring.capacity()
    }

    /// Number of published items not yet popped (a snapshot of `flush - head`).
    #[inline]
    pub fn len(&self) -> usize {
        self.ring
            .ring
            .flush
            .0
            .load(Ordering::Acquire)
            .wrapping_sub(self.head)
    }
}
impl<T> Stream for AsyncConsumer<T> {
    type Item = T;

    /// Yields the next published item, or registers the task's waker and
    /// returns `Poll::Pending` when none is available.
    ///
    /// After registering, the ring is checked once more so that an item
    /// published between the first check and the registration is not missed.
    /// This implementation never yields `None`, so the stream does not end on
    /// its own.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let consumer = self.get_mut();

        // Fast path: something is already available.
        if let Some(item) = consumer.prefetch_and_pop() {
            return Poll::Ready(Some(item));
        }

        // Slow path: arm the waker, then re-check to close the race window.
        consumer.ring.waker.0.register(cx.waker());
        match consumer.prefetch_and_pop() {
            Some(item) => Poll::Ready(Some(item)),
            None => Poll::Pending,
        }
    }
}
impl<T> std::fmt::Debug for AsyncProducer<T> {
    /// Shows only the ring capacity; the remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AsyncProducer");
        builder.field("capacity", &self.capacity());
        builder.finish_non_exhaustive()
    }
}
impl<T> std::fmt::Debug for AsyncConsumer<T> {
    /// Shows only the ring capacity; the remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AsyncConsumer");
        builder.field("capacity", &self.capacity());
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use futures_lite::StreamExt;

    use super::*;

    /// Pushed values stay invisible to the consumer until `flush` runs.
    #[test]
    fn async_push_pop() {
        let (mut tx, mut rx) = async_spsc::<u32>(4);
        for v in [1, 2] {
            tx.push(v).unwrap();
        }
        assert!(rx.prefetch_and_pop().is_none());

        tx.flush();
        assert_eq!(rx.prefetch_and_pop(), Some(1));
        assert_eq!(rx.prefetch_and_pop(), Some(2));
    }

    /// The `Stream` impl yields flushed values in order.
    #[test]
    fn stream_impl() {
        futures_lite::future::block_on(async {
            let (mut tx, mut rx) = async_spsc::<u32>(8);
            let values = [10, 20, 30];
            for v in values {
                tx.push(v).unwrap();
            }
            tx.flush();

            for expected in values {
                assert_eq!(rx.next().await, Some(expected));
            }
        });
    }

    /// A consumer parked on an empty ring is woken by a later flush.
    #[test]
    fn stream_wakes_on_flush() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::time::Duration;

        let (mut tx, mut rx) = async_spsc::<u32>(8);
        let finished = Arc::new(AtomicBool::new(false));

        let waiter = {
            let finished = Arc::clone(&finished);
            std::thread::spawn(move || {
                futures_lite::future::block_on(async move {
                    let got = rx.next().await;
                    finished.store(true, Ordering::Release);
                    got
                })
            })
        };

        std::thread::sleep(Duration::from_millis(10));
        assert!(!finished.load(Ordering::Acquire));

        tx.push(42).unwrap();
        tx.flush();
        assert_eq!(waiter.join().unwrap(), Some(42));
    }

    /// Streams many values across threads, flushing in batches of 64.
    #[test]
    fn cross_thread_stream() {
        const TOTAL: u64 = 50_000;
        let (mut tx, rx) = async_spsc::<u64>(1024);

        let receiver = std::thread::spawn(move || {
            futures_lite::future::block_on(async {
                futures_lite::pin!(rx);
                let mut received = 0u64;
                while let Some(v) = rx.next().await {
                    assert_eq!(v, received);
                    received += 1;
                    if received == TOTAL {
                        break;
                    }
                }
                received
            })
        });

        for i in 0..TOTAL {
            // Ring full: publish what we have and let the consumer catch up.
            while tx.push(i).is_err() {
                tx.flush();
                std::thread::yield_now();
            }
            if i % 64 == 63 {
                tx.flush();
            }
        }
        tx.flush();

        assert_eq!(receiver.join().unwrap(), TOTAL);
    }
}