use std::{
collections::VecDeque,
sync::{Arc, Mutex},
task::Waker,
};
use futures::{stream::BoxStream, Stream, StreamExt};
use tokio::sync::Semaphore;
use tokio_util::sync::PollSemaphore;
/// Tells apart the two handles returned by [`SharedStream::new`].
#[derive(Debug, Clone, Copy, PartialEq)]
enum Side {
    /// First handle of the pair.
    Left,
    /// Second handle of the pair.
    Right,
}
/// Limit on how many items may be buffered between the two halves of a
/// [`SharedStream`].
///
/// The buffer holds clones of items that one half has already pulled from the
/// inner stream but the other half has not consumed yet.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Capacity {
    /// At most this many items may wait in the buffer. A half that would exceed
    /// the limit returns `Pending` until the other half consumes a buffered item.
    ///
    /// `Bounded(0)` leaves no buffer slot at all: while both halves are alive,
    /// neither of them can ever pull an item from the inner stream.
    Bounded(u32),
    /// The buffer grows without limit.
    Unbounded,
}
/// State shared, behind a mutex, by the two halves of a [`SharedStream`].
struct InnerState<'a, T> {
    /// The wrapped stream. It is temporarily `None` while one half polls it with
    /// the lock released.
    inner: Option<BoxStream<'a, T>>,
    /// Set once `inner` has yielded `None`.
    exhausted: bool,
    /// Half that currently owns polling of `inner`. It stays set while `inner`
    /// returns `Pending`; meanwhile the other half parks its waker in `waker`.
    polling: Option<Side>,
    /// Waker of the half that found the other half owning `inner`. It is woken
    /// when the owner gets a `Ready` result.
    waker: Option<Waker>,
    /// Clones of items pulled by one half and not yet consumed by the other, in
    /// stream order.
    buffer: VecDeque<T>,
    /// How many entries of `buffer` are still owed to the left half.
    left_buffered: u32,
    /// How many entries of `buffer` are still owed to the right half.
    right_buffered: u32,
    /// Bounded mode: one permit per free buffer slot. `None` when unbounded.
    available_buffer: Option<PollSemaphore>,
}
/// One of the two handles created by [`SharedStream::new`]. Each handle yields
/// the items of the wrapped stream in order; an item pulled by one handle is
/// buffered as a clone for the other.
pub struct SharedStream<'a, T: Clone> {
    /// Which handle of the pair this is.
    side: Side,
    /// State shared with the other handle.
    state: Arc<Mutex<InnerState<'a, T>>>,
}
impl<'a, T: Clone> SharedStream<'a, T> {
    /// Splits `inner` into a `(left, right)` pair of handles over the same stream.
    ///
    /// Whichever handle pulls an item from `inner` first keeps a clone of it in a
    /// shared buffer for the other handle.
    ///
    /// With [`Capacity::Bounded`], at most that many clones may wait in the buffer.
    /// A handle that would exceed the limit returns `Pending` until the other
    /// handle consumes one. With [`Capacity::Unbounded`] the buffer grows as
    /// needed.
    pub fn new(inner: BoxStream<'a, T>, capacity: Capacity) -> (Self, Self) {
        let available_buffer = match capacity {
            Capacity::Bounded(slots) => {
                let semaphore = Arc::new(Semaphore::new(slots as usize));
                Some(PollSemaphore::new(semaphore))
            }
            Capacity::Unbounded => None,
        };
        let shared = Arc::new(Mutex::new(InnerState {
            inner: Some(inner),
            buffer: VecDeque::new(),
            polling: None,
            waker: None,
            exhausted: false,
            left_buffered: 0,
            right_buffered: 0,
            available_buffer,
        }));
        let left = Self {
            side: Side::Left,
            state: Arc::clone(&shared),
        };
        let right = Self {
            side: Side::Right,
            state: shared,
        };
        (left, right)
    }
}
impl<'a, T: Clone> Stream for SharedStream<'a, T> {
    type Item = T;

    /// Yields the next item of the shared stream for this half.
    ///
    /// Resolution order:
    /// 1. Items the other half already pulled for this half are served from the
    ///    shared buffer first. In bounded mode, popping one frees a buffer permit.
    /// 2. Once the inner stream has ended (and the buffer is drained for this
    ///    half), `None` is returned.
    /// 3. In bounded mode, a permit is reserved for the clone that will be
    ///    buffered for the other half. If none is free, this half waits until the
    ///    other half drains the buffer.
    /// 4. If the other half currently owns the inner stream, this half's waker is
    ///    stored and `Pending` is returned. The owner wakes it on a `Ready` result,
    ///    or when it is dropped.
    /// 5. Otherwise this half claims the inner stream, polls it with the lock
    ///    released, and buffers a clone of any item for the other half, if that
    ///    half still exists.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned, or if the inner stream is missing
    /// because a previous poll of it panicked.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let mut inner_state = self.state.lock().unwrap();
        let can_take_buffered = match self.side {
            Side::Left => inner_state.left_buffered > 0,
            Side::Right => inner_state.right_buffered > 0,
        };
        if can_take_buffered {
            let item = inner_state.buffer.pop_front();
            match self.side {
                Side::Left => inner_state.left_buffered -= 1,
                Side::Right => inner_state.right_buffered -= 1,
            }
            // The popped clone's buffer slot is free again.
            if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
                available_buffer.add_permits(1);
            }
            return std::task::Poll::Ready(item);
        }
        if inner_state.exhausted {
            return std::task::Poll::Ready(None);
        }
        // `new` creates exactly two handles sharing `state`, so a strong count of 1
        // means the other half has been dropped. Reserving buffer slots for a half
        // that will never drain them would eventually block this half forever.
        let other_alive = Arc::strong_count(&self.state) > 1;
        let permit = match inner_state.available_buffer.as_mut() {
            Some(available_buffer) if other_alive => match available_buffer.poll_acquire(cx) {
                std::task::Poll::Ready(permit) => {
                    Some(permit.expect("buffer semaphore is never closed"))
                }
                std::task::Poll::Pending => return std::task::Poll::Pending,
            },
            _ => None,
        };
        if let Some(polling_side) = inner_state.polling {
            if polling_side != self.side {
                // The other half owns the inner stream and will wake us. Overwrite,
                // rather than assert the slot is empty: this half may legitimately be
                // polled again before then (spurious wake-up, `select!`, a fresh
                // `next()` future), and only the most recent waker must be woken.
                inner_state.waker = Some(cx.waker().clone());
                return std::task::Poll::Pending;
            }
        }
        inner_state.polling = Some(self.side);
        let mut to_poll = inner_state
            .inner
            .take()
            .expect("Other half of shared stream panic'd while polling inner stream");
        // Poll with the lock released; the other half only ever parks or drains the
        // buffer while `polling` names this half.
        drop(inner_state);
        let res = to_poll.poll_next_unpin(cx);
        let mut inner_state = self.state.lock().unwrap();
        match &res {
            std::task::Poll::Ready(None) => {
                inner_state.exhausted = true;
                inner_state.polling = None;
            }
            std::task::Poll::Ready(Some(item)) => {
                inner_state.polling = None;
                // Re-check: the other half may have been dropped while we polled.
                if Arc::strong_count(&self.state) > 1 {
                    // The clone now occupies the reserved slot; the other half hands
                    // the permit back when it pops the clone.
                    if let Some(permit) = permit {
                        permit.forget();
                    }
                    match self.side {
                        Side::Left => inner_state.right_buffered += 1,
                        Side::Right => inner_state.left_buffered += 1,
                    }
                    inner_state.buffer.push_back(item.clone());
                }
            }
            std::task::Poll::Pending => {
                // Keep owning the inner stream: it registered this half's waker, so
                // only this half will be woken when it can make progress.
            }
        }
        inner_state.inner = Some(to_poll);
        // On a `Ready` result, wake the other half if it parked while we polled.
        let to_wake = if res.is_ready() {
            inner_state.waker.take()
        } else {
            None
        };
        drop(inner_state);
        if let Some(waker) = to_wake {
            waker.wake();
        }
        res
    }
}

/// Releases this half's claims on the shared state, so that the other half can
/// keep making progress without it.
impl<'a, T: Clone> Drop for SharedStream<'a, T> {
    fn drop(&mut self) {
        // Never panic inside `drop`: recover the guard from a poisoned mutex.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let own_buffered = match self.side {
            Side::Left => std::mem::take(&mut state.left_buffered),
            Side::Right => std::mem::take(&mut state.right_buffered),
        };
        if own_buffered > 0 {
            // A half pulls from the inner stream only while its own count is zero,
            // so whenever this half has buffered items, every buffered item is owed
            // to it. Nobody will consume them any more.
            state.buffer.clear();
            if let Some(available_buffer) = state.available_buffer.as_mut() {
                available_buffer.add_permits(own_buffered as usize);
            }
        }
        let to_wake = match state.polling {
            Some(owner) if owner == self.side => {
                // This half owned the inner stream (its last poll returned
                // `Pending`), so the inner stream holds our waker and the other
                // half would wait forever. Hand ownership back and wake the other
                // half so it polls the inner stream itself.
                state.polling = None;
                state.waker.take()
            }
            Some(_) => {
                // The other half owns the inner stream, so any parked waker is
                // this half's and will never be needed again.
                state.waker = None;
                None
            }
            None => None,
        };
        drop(state);
        if let Some(waker) = to_wake {
            waker.wake();
        }
    }
}
/// Extension trait that adds [`share`](SharedStreamExt::share) to streams.
pub trait SharedStreamExt<'a>: Stream + Send
where
    <Self as Stream>::Item: Clone,
{
    /// Splits this stream into two [`SharedStream`] handles. `capacity` limits
    /// how many items may be buffered between the two handles.
    fn share(self, capacity: Capacity)
        -> (SharedStream<'a, Self::Item>, SharedStream<'a, Self::Item>);
}
/// Boxed streams are handed to [`SharedStream::new`] as-is.
impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> {
    fn share(self, capacity: Capacity) -> (SharedStream<'a, T>, SharedStream<'a, T>) {
        SharedStream::<'a, T>::new(self, capacity)
    }
}
#[cfg(test)]
mod tests {
    use futures::{FutureExt, StreamExt};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::utils::futures::{Capacity, SharedStreamExt};

    /// Polls `fut` once with a no-op waker and reports whether it is still pending.
    fn is_pending(fut: &mut (impl std::future::Future + Unpin)) -> bool {
        let noop_waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&noop_waker);
        fut.poll_unpin(&mut context).is_pending()
    }

    #[tokio::test]
    async fn test_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);
        for i in 0..3 {
            tx.send(i).await.unwrap();
        }
        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));
        // Left runs ahead until the two-slot buffer owed to right is full.
        assert_eq!(left.next().await.unwrap(), 0);
        assert_eq!(left.next().await.unwrap(), 1);
        let mut left_fut = left.next();
        assert!(is_pending(&mut left_fut));
        // Right draining one item frees a slot for left.
        assert_eq!(right.next().await.unwrap(), 0);
        assert_eq!(left_fut.await.unwrap(), 2);
        assert_eq!(right.next().await.unwrap(), 1);
        assert_eq!(right.next().await.unwrap(), 2);
        // Both halves wait on an empty inner stream; right owns the polling.
        let mut right_fut = right.next();
        let mut left_fut = left.next();
        assert!(is_pending(&mut right_fut));
        assert!(is_pending(&mut left_fut));
        tx.send(3).await.unwrap();
        assert_eq!(right_fut.await.unwrap(), 3);
        assert_eq!(left_fut.await.unwrap(), 3);
        drop(tx);
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);
    }

    #[tokio::test]
    async fn test_unbounded_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);
        for i in 0..10 {
            tx.send(i).await.unwrap();
        }
        drop(tx);
        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Unbounded);
        for i in 0..10 {
            assert_eq!(left.next().await.unwrap(), i);
        }
        assert_eq!(left.next().await, None);
        // Everything left consumed is still buffered for right.
        for i in 0..10 {
            assert_eq!(right.next().await.unwrap(), i);
        }
        assert_eq!(right.next().await, None);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn stress_shared_stream() {
        const ITEMS: u32 = 1000;
        for _ in 0..100 {
            let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
            let inner_stream = ReceiverStream::new(rx);
            let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));
            // Each task checks ordering and reports how many items it saw, so a
            // premature `None` on either half fails the test.
            let left_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = left.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
                counter
            });
            let right_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = right.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
                counter
            });
            for i in 0..ITEMS {
                tx.send(i).await.unwrap();
            }
            drop(tx);
            assert_eq!(left_handle.await.unwrap(), ITEMS);
            assert_eq!(right_handle.await.unwrap(), ITEMS);
        }
    }
}