use std::collections::VecDeque;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, Waker};
/// Queue state shared by the consumer handle and every sender; always accessed through a mutex.
struct State<T> {
/// Items pushed by senders and not yet taken by the consumer (pushed at the back, popped from the front).
buffer: VecDeque<T>,
/// Waker stored by a poll that returned `Pending`; taken and woken when an item is pushed or the stream closes.
waker: Option<Waker>,
/// Maximum number of items `buffer` is allowed to hold; set at construction.
capacity: usize,
/// Set to `true` when the last sender handle is dropped.
closed: bool,
}
/// Blocking-side coordination shared by the consumer handle and every sender.
struct BackPressure {
/// Condition variable that `push_or_block` callers wait on while the buffer is full.
cvar: Condvar,
/// `true` once the consumer handle has been dropped; this is also the mutex passed to `cvar` waits.
consumer_gone: Mutex<bool>,
/// Number of live sender handles; starts at 1, incremented on clone and decremented on drop.
sender_count: AtomicUsize,
}
/// Consumer handle of a bounded queue fed by one or more [`AsyncStreamSender`]s.
///
/// Dropping this handle marks the consumer as gone and wakes every blocked sender.
pub struct BoundedAsyncStream<T> {
/// Buffer, waker and close flag shared with the senders.
state: Arc<Mutex<State<T>>>,
/// Condition variable, consumer-gone flag and sender count shared with the senders.
back_pressure: Arc<BackPressure>,
}
/// Producer handle of a [`BoundedAsyncStream`].
///
/// Cloning creates another producer for the same queue; the stream is marked closed when the
/// last handle is dropped.
pub struct AsyncStreamSender<T> {
/// Buffer, waker and close flag shared with the consumer.
state: Arc<Mutex<State<T>>>,
/// Condition variable, consumer-gone flag and sender count shared with the consumer.
back_pressure: Arc<BackPressure>,
}
impl<T> Clone for AsyncStreamSender<T> {
    /// Returns another sender for the same queue and records it in the live-sender count.
    fn clone(&self) -> Self {
        let shared = &self.back_pressure;
        // The caller already owns a live handle, so the count cannot concurrently reach zero;
        // a relaxed increment is sufficient here.
        shared.sender_count.fetch_add(1, Ordering::Relaxed);

        let state = Arc::clone(&self.state);
        let back_pressure = Arc::clone(shared);
        Self {
            state,
            back_pressure,
        }
    }
}
impl<T> fmt::Debug for BoundedAsyncStream<T> {
    /// Prints a snapshot of the buffer length, capacity and closed flag; other fields are elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each accessor takes the state lock separately, so the three values are individual
        // snapshots rather than one atomic view.
        let buffered = self.buffered_count();
        let capacity = self.capacity();
        let is_closed = self.is_closed();

        let mut out = f.debug_struct("BoundedAsyncStream");
        out.field("buffered", &buffered);
        out.field("capacity", &capacity);
        out.field("is_closed", &is_closed);
        out.finish_non_exhaustive()
    }
}
impl<T> fmt::Debug for AsyncStreamSender<T> {
    /// Prints only the type name; every field is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("AsyncStreamSender");
        out.finish_non_exhaustive()
    }
}
impl<T> BoundedAsyncStream<T> {
    /// Creates a bounded stream together with its first sender.
    ///
    /// `capacity` is the maximum number of items kept in the buffer.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    #[must_use]
    pub fn new(capacity: usize) -> (Self, AsyncStreamSender<T>) {
        assert!(capacity > 0, "BoundedAsyncStream capacity must be > 0");
        let state = Arc::new(Mutex::new(State {
            buffer: VecDeque::with_capacity(capacity),
            waker: None,
            capacity,
            closed: false,
        }));
        let back_pressure = Arc::new(BackPressure {
            cvar: Condvar::new(),
            consumer_gone: Mutex::new(false),
            // The sender returned below is the first live handle.
            sender_count: AtomicUsize::new(1),
        });
        let stream = Self {
            state: Arc::clone(&state),
            back_pressure: Arc::clone(&back_pressure),
        };
        let sender = AsyncStreamSender {
            state,
            back_pressure,
        };
        (stream, sender)
    }

    /// Returns a future that resolves to the next item, or `None` once the stream is closed
    /// and the buffer is empty.
    #[must_use]
    pub const fn next(&self) -> NextItem<'_, T> {
        NextItem { stream: self }
    }

    /// Takes the oldest buffered item without waiting.
    ///
    /// Returns `None` when the buffer is empty or the state mutex is poisoned. Taking an item
    /// frees a slot, so senders blocked on a full buffer are notified.
    #[must_use]
    pub fn try_next(&self) -> Option<T> {
        // The lock guard is a temporary of this statement, so it is released before notifying.
        let item = self.state.lock().ok()?.buffer.pop_front()?;
        self.back_pressure.cvar.notify_all();
        Some(item)
    }

    /// Returns `true` once every sender has been dropped (or the state mutex is poisoned).
    #[must_use]
    pub fn is_closed(&self) -> bool {
        self.state.lock().map_or(true, |s| s.closed)
    }

    /// Returns the number of items currently buffered (0 if the state mutex is poisoned).
    #[must_use]
    pub fn buffered_count(&self) -> usize {
        self.state.lock().map_or(0, |s| s.buffer.len())
    }

    /// Returns the configured buffer capacity (0 if the state mutex is poisoned).
    #[must_use]
    pub fn capacity(&self) -> usize {
        self.state.lock().map_or(0, |s| s.capacity)
    }

    /// Discards every buffered item.
    ///
    /// Does nothing if the state mutex is poisoned. When items were actually discarded,
    /// senders blocked on a full buffer are notified that space is available again.
    pub fn clear_buffer(&self) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if state.buffer.is_empty() {
            return;
        }
        state.buffer.clear();
        drop(state);
        self.back_pressure.cvar.notify_all();
    }
}
impl<T> Drop for BoundedAsyncStream<T> {
fn drop(&mut self) {
if let Ok(mut consumer_gone) = self.back_pressure.consumer_gone.lock() {
*consumer_gone = true;
}
self.back_pressure.cvar.notify_all();
}
}
impl<T> AsyncStreamSender<T> {
    /// Pushes `item` without ever blocking, evicting the oldest buffered item when the buffer
    /// is full.
    ///
    /// The consumer's registered waker, if any, is woken afterwards. If the state mutex is
    /// poisoned the item is dropped and nothing else happens.
    pub fn push(&self, item: T) {
        let Ok(mut state) = self.state.lock() else {
            return;
        };
        if state.buffer.len() >= state.capacity {
            // Lossy mode: make room by discarding the oldest entry.
            state.buffer.pop_front();
        }
        state.buffer.push_back(item);
        if let Some(waker) = state.waker.take() {
            waker.wake();
        }
    }

    /// Pushes `item`, blocking the calling thread while the buffer is full.
    ///
    /// # Errors
    ///
    /// Hands `item` back as `Err(item)` when the consumer handle has been dropped (checked
    /// before every attempt, including after each wake-up) or when one of the internal
    /// mutexes is poisoned.
    pub fn push_or_block(&self, item: T) -> Result<(), T> {
        /// Upper bound on a single wait. Notifiers that free a slot may do so without holding
        /// the flag mutex, so such a notification can slip in just before this thread parks;
        /// re-checking after this interval recovers from that instead of blocking forever.
        const MISSED_WAKEUP_BACKSTOP: std::time::Duration =
            std::time::Duration::from_millis(50);

        loop {
            let Ok(mut state) = self.state.lock() else {
                return Err(item);
            };
            // Taken while the state lock is still held and kept until the wait below releases
            // it: the consumer's drop sets the flag under this mutex, so it can no longer land
            // between the check and the wait.
            let Ok(consumer_gone) = self.back_pressure.consumer_gone.lock() else {
                return Err(item);
            };
            if *consumer_gone {
                return Err(item);
            }
            if state.buffer.len() < state.capacity {
                drop(consumer_gone);
                state.buffer.push_back(item);
                if let Some(waker) = state.waker.take() {
                    waker.wake();
                }
                return Ok(());
            }
            // Release the buffer so the consumer can make progress, then park. The outcome
            // (woken, timed out, or poisoned) does not matter: every condition is re-checked
            // at the top of the loop.
            drop(state);
            drop(
                self.back_pressure
                    .cvar
                    .wait_timeout(consumer_gone, MISSED_WAKEUP_BACKSTOP),
            );
        }
    }

    /// Returns the number of items currently buffered (0 if the state mutex is poisoned).
    #[must_use]
    pub fn buffered_count(&self) -> usize {
        self.state.lock().map_or(0, |s| s.buffer.len())
    }

    /// Returns `true` once the consumer handle has been dropped (or the flag mutex is poisoned).
    #[must_use]
    pub fn is_consumer_gone(&self) -> bool {
        self.back_pressure.consumer_gone.lock().map_or(true, |g| *g)
    }
}
impl<T> Drop for AsyncStreamSender<T> {
    /// Decrements the live-sender count; the last sender marks the stream closed and wakes the
    /// consumer. The condvar is notified on every drop.
    fn drop(&mut self) {
        let shared = &self.back_pressure;
        let was_last_sender = shared.sender_count.fetch_sub(1, Ordering::AcqRel) == 1;

        if was_last_sender {
            // A poisoned state mutex is left untouched.
            if let Ok(mut guard) = self.state.lock() {
                guard.closed = true;
                let parked = guard.waker.take();
                if let Some(waker) = parked {
                    waker.wake();
                }
            }
        }

        shared.cvar.notify_all();
    }
}
/// Future created by `BoundedAsyncStream::next`.
///
/// Resolves to `Some(item)` when an item is buffered, or to `None` once the stream is closed
/// and the buffer is empty.
pub struct NextItem<'a, T> {
/// Stream whose shared state this future polls.
stream: &'a BoundedAsyncStream<T>,
}
impl<T> fmt::Debug for NextItem<'_, T> {
    /// Prints only the type name; the borrowed stream is elided.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("NextItem");
        out.finish_non_exhaustive()
    }
}
impl<T> Future for NextItem<'_, T> {
type Output = Option<T>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Ok(mut state) = self.stream.state.lock() else {
return Poll::Ready(None);
};
if let Some(item) = state.buffer.pop_front() {
self.stream.back_pressure.cvar.notify_all();
return Poll::Ready(Some(item));
}
if state.closed {
return Poll::Ready(None);
}
let waker = cx.waker();
match state.waker {
Some(ref existing) if existing.will_wake(waker) => {}
_ => state.waker = Some(waker.clone()),
}
Poll::Pending
}
}
#[cfg(feature = "futures-stream")]
impl<T: 'static> futures_core::Stream for BoundedAsyncStream<T> {
    type Item = T;

    /// Yields the oldest buffered item, ends the stream with `None` when it is closed and
    /// empty (or its state mutex is poisoned), and otherwise registers the task's waker and
    /// stays pending.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        let this = &*self;
        let Ok(mut guard) = this.state.lock() else {
            return Poll::Ready(None);
        };

        let popped = guard.buffer.pop_front();
        if popped.is_some() {
            // A slot was freed: let senders waiting on the condvar re-check the buffer.
            this.back_pressure.cvar.notify_all();
            return Poll::Ready(popped);
        }
        if guard.closed {
            return Poll::Ready(None);
        }

        let current = cx.waker();
        // Skip the clone when the stored waker already wakes this task.
        let already_registered = matches!(&guard.waker, Some(stored) if stored.will_wake(current));
        if !already_registered {
            guard.waker = Some(current.clone());
        }
        Poll::Pending
    }
}