use futures_util::future::{FutureExt, Shared};
use futures_util::stream::{FusedStream, Stream};
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
/// `(lower, upper)` bounds on the number of remaining items, in the same shape
/// as [`Stream::size_hint`].
type SizeHint = (usize, Option<usize>);

/// Future resolving to the next item of the wrapped stream.
///
/// On success it also hands back a fresh shared future for the item after
/// that, plus the stream's size hint at that point. The stream is moved out of
/// this future when it resolves, so `inner` is `None` afterwards.
#[cfg_attr(test, derive(Debug))]
struct InnerFuture<S: Stream + Unpin> {
    /// Source stream; taken once this future has produced its output.
    inner: Option<S>,
}
impl<S: Stream + Unpin> InnerFuture<S> {
    /// Wraps `stream` so that its next item can be awaited as a future.
    pub(crate) fn new(stream: S) -> Self {
        Self { inner: Some(stream) }
    }
}
impl<S> Future for InnerFuture<S>
where
    S: Stream + Unpin,
    S::Item: Clone,
{
    /// `Some((item, future_for_following_item, size_hint_after_item))`, or
    /// `None` when the source stream is exhausted.
    type Output = Option<(S::Item, Shared<Self>, SizeHint)>;

    /// Polls the source stream for one item.
    ///
    /// When an item arrives, the stream is moved into a new shared
    /// `InnerFuture` for the next position. Polling again after this future
    /// has resolved yields `Ready(None)` because the stream is gone.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let polled = match self.inner.as_mut() {
            None => return Poll::Ready(None),
            Some(source) => ready!(Pin::new(source).poll_next(cx)),
        };

        // Whatever the outcome, this future is now finished with the stream.
        // `taken` is always `Some` here (we just polled through it), so `zip`
        // only combines it with the item. On end-of-stream it drops the stream.
        let taken = self.inner.take();
        let output = polled.zip(taken).map(|(item, rest)| {
            let hint = rest.size_hint();
            (item, InnerFuture::new(rest).shared(), hint)
        });
        Poll::Ready(output)
    }
}
/// A cloneable handle onto a single underlying stream.
///
/// The source stream is polled once per item. Each item is then cloned out to
/// every handle that reaches that position, hence the `S::Item: Clone`
/// requirement. A clone starts at the current position of the handle it was
/// cloned from.
#[derive(Debug)]
pub struct SharedStream<S: Stream + Unpin>
where
    S::Item: Clone,
{
    /// Shared future for this handle's next item; `None` once the end was seen.
    future: Option<Shared<InnerFuture<S>>>,
    /// Size hint last reported by the source at this handle's position.
    size_hint: SizeHint,
    /// Live-handle statistics, shared with clones of this handle.
    #[cfg(feature = "stats")]
    stats: crate::stats::Stats,
}
impl<S> Clone for SharedStream<S>
where
S: Stream + Unpin,
S::Item: Clone,
{
fn clone(&self) -> Self {
let s = Self {
future: self.future.clone(),
size_hint: self.size_hint,
#[cfg(feature = "stats")]
stats: self.stats.clone(),
};
#[cfg(feature = "stats")]
if self.future.is_some() {
s.stats.increment();
}
s
}
}
impl<S> SharedStream<S>
where
S: Stream + Unpin,
S::Item: Clone,
{
pub fn new(stream: S) -> Self {
let size_hint = stream.size_hint();
Self {
future: InnerFuture::new(stream).shared().into(),
size_hint,
#[cfg(feature = "stats")]
stats: crate::stats::Stats::new(),
}
}
#[cfg(feature = "stats")]
#[cfg_attr(docsrs, doc(cfg(feature = "stats")))]
pub fn stats(&self) -> crate::stats::Stats {
self.stats.clone()
}
}
impl<S> Stream for SharedStream<S>
where
S: Stream + Unpin,
S::Item: Clone,
{
type Item = S::Item;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let poll_result = match &mut self.future {
Some(f) => Pin::new(f).poll(cx),
None => return Poll::Ready(None),
};
match poll_result {
Poll::Pending => Poll::Pending,
Poll::Ready(Some((item, next_shared_future, size_hint))) => {
self.future = next_shared_future.into();
self.size_hint = size_hint;
Poll::Ready(Some(item))
}
Poll::Ready(None) => {
self.future.take();
#[cfg(feature = "stats")]
{
self.stats.decrement();
}
Poll::Ready(None)
}
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.size_hint
}
}
impl<S> FusedStream for SharedStream<S>
where
    S: Stream + Unpin,
    S::Item: Clone,
{
    /// Returns `true` once this handle has yielded `None`.
    ///
    /// `poll_next` clears the shared future at that point and never restores
    /// it, so the result stays `true` afterwards.
    fn is_terminated(&self) -> bool {
        let Self { future, .. } = self;
        future.is_none()
    }
}
#[cfg(feature = "stats")]
impl<S> Drop for SharedStream<S>
where
    S: Stream + Unpin,
    S::Item: Clone,
{
    /// Releases this handle's live count.
    ///
    /// If the handle already observed the end of the stream, `poll_next`
    /// decremented the count at that point, so nothing is done here.
    fn drop(&mut self) {
        let still_live = self.future.is_some();
        if still_live {
            self.stats.decrement();
        }
    }
}
// Behavioral tests for `SharedStream` and `InnerFuture`, run on tokio's test
// runtime; `static_assertions` provides the compile-time trait checks.
#[cfg(test)]
mod tests {
use super::*;
use futures::StreamExt;
use futures_util::stream;
// A single handle yields every source item and reports the source's initial size hint.
#[tokio::test]
async fn test_basic_shared_stream_works() {
let original_data = vec![1, 2, 3, 4, 5];
let stream = stream::iter(original_data.clone());
let shared_stream = SharedStream::new(stream);
assert_eq!(shared_stream.size_hint(), (5, Some(5)));
let collected: Vec<i32> = shared_stream.collect().await;
assert_eq!(collected, original_data);
}
// Clones made before any polling each receive the full sequence when driven concurrently.
#[tokio::test]
async fn test_multiple_clones_get_same_data() {
let data = vec![10, 20, 30];
let stream = stream::iter(data.clone());
let shared_stream = SharedStream::new(stream);
let clone1 = shared_stream.clone();
let clone2 = shared_stream.clone();
let clone3 = shared_stream.clone();
assert_eq!(clone1.size_hint(), (3, Some(3)));
assert_eq!(clone2.size_hint(), (3, Some(3)));
assert_eq!(clone3.size_hint(), (3, Some(3)));
let (result1, result2, result3) = tokio::join!(
clone1.collect::<Vec<i32>>(),
clone2.collect::<Vec<i32>>(),
clone3.collect::<Vec<i32>>()
);
assert_eq!(result1, data);
assert_eq!(result2, data);
assert_eq!(result3, data);
}
// A clone taken after one item was consumed starts at the original's current
// position (and inherits its size hint), not at the beginning.
#[tokio::test]
async fn test_clone_after_partial_consumption() {
use futures_util::StreamExt;
let numbers = vec![100, 200, 300, 400];
let stream = stream::iter(numbers.clone());
let mut shared_stream = SharedStream::new(stream);
assert_eq!(shared_stream.size_hint(), (4, Some(4)));
let first_item = shared_stream.next().await;
assert_eq!(first_item, Some(100));
assert_eq!(shared_stream.size_hint(), (3, Some(3)));
let cloned_stream = shared_stream.clone();
assert_eq!(cloned_stream.size_hint(), (3, Some(3)));
let clone_result: Vec<i32> = cloned_stream.collect().await;
assert_eq!(clone_result, vec![200, 300, 400]);
}
// Non-`Copy` items (`String`) are cloned out to every handle.
#[tokio::test]
async fn test_with_string_data() {
let messages = vec!["hello".to_string(), "world".to_string()];
let stream = stream::iter(messages.clone());
let shared_stream = SharedStream::new(stream);
let clone1 = shared_stream.clone();
let clone2 = shared_stream.clone();
let (result1, result2) = tokio::join!(
clone1.collect::<Vec<String>>(),
clone2.collect::<Vec<String>>()
);
assert_eq!(result1, messages);
assert_eq!(result2, messages);
}
// An empty source: every handle reports a zero size hint and collects nothing.
#[tokio::test]
async fn test_empty_stream_behavior() {
let empty_stream = stream::iter(Vec::<i32>::new());
let shared_stream = SharedStream::new(empty_stream);
assert_eq!(shared_stream.size_hint(), (0, Some(0)));
let clone1 = shared_stream.clone();
let clone2 = shared_stream.clone();
assert_eq!(clone1.size_hint(), (0, Some(0)));
assert_eq!(clone2.size_hint(), (0, Some(0)));
let (result1, result2) =
tokio::join!(clone1.collect::<Vec<i32>>(), clone2.collect::<Vec<i32>>());
assert!(result1.is_empty());
assert!(result2.is_empty());
}
// Draining clones does not advance the original handle: it still yields the
// single item afterwards, then ends.
#[tokio::test]
async fn test_single_item_stream() {
use futures_util::StreamExt;
let single_item = vec![42];
let stream = stream::iter(single_item.clone());
let mut shared_stream = SharedStream::new(stream);
assert_eq!(shared_stream.size_hint(), (1, Some(1)));
let clone1 = shared_stream.clone();
let clone2 = shared_stream.clone();
assert_eq!(clone1.size_hint(), (1, Some(1)));
assert_eq!(clone2.size_hint(), (1, Some(1)));
let (result1, result2) =
tokio::join!(clone1.collect::<Vec<i32>>(), clone2.collect::<Vec<i32>>());
assert_eq!(result1, single_item);
assert_eq!(result2, single_item);
let item = shared_stream.next().await;
assert_eq!(item, Some(42));
assert_eq!(shared_stream.size_hint(), (0, Some(0)));
let remaining: Vec<i32> = shared_stream.collect().await;
assert!(remaining.is_empty());
}
// Twenty clones collected together via `join_all` all see the same items.
#[tokio::test]
async fn test_many_clones_stress_test() {
let data = vec![1, 2, 3];
let stream = stream::iter(data.clone());
let shared_stream = SharedStream::new(stream);
let mut clone_futures = Vec::new();
for _ in 0..20 {
let clone = shared_stream.clone();
clone_futures.push(clone.collect::<Vec<i32>>());
}
let all_results = futures_util::future::join_all(clone_futures).await;
for result in all_results {
assert_eq!(result, data);
}
}
// A `!Unpin` source can still be shared once boxed with `Box::pin`, which
// satisfies the `S: Unpin` bound on `SharedStream`.
#[tokio::test]
async fn test_not_unpin_stream_with_box_pin() {
use std::marker::PhantomPinned;
use std::task::Context;
#[derive(Clone)]
struct NotUnpinStream {
data: Vec<i32>,
index: usize,
_pin: PhantomPinned,
}
impl Stream for NotUnpinStream {
type Item = i32;
fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// SAFETY: the reference is used only to read `data` and bump `index` in
// place; nothing is moved out of the pinned value, so it is never relocated.
let this = unsafe { self.get_unchecked_mut() };
if this.index < this.data.len() {
let item = this.data[this.index];
this.index += 1;
Poll::Ready(Some(item))
} else {
Poll::Ready(None)
}
}
}
let not_unpin_stream = NotUnpinStream {
data: vec![10, 20, 30],
index: 0,
_pin: PhantomPinned,
};
// Guard the premise of this test: the raw stream really is `!Unpin`.
static_assertions::assert_not_impl_any!(NotUnpinStream: Unpin);
let pinned_stream = Box::pin(not_unpin_stream);
let shared_stream = SharedStream::new(pinned_stream);
let clone1 = shared_stream.clone();
let clone2 = shared_stream.clone();
let (result1, result2) =
tokio::join!(clone1.collect::<Vec<i32>>(), clone2.collect::<Vec<i32>>());
assert_eq!(result1, vec![10, 20, 30]);
assert_eq!(result2, vec![10, 20, 30]);
}
// Compile-time check: a `SharedStream` over a `stream::Iter` of `i32` is
// both `Send` and `Sync`.
#[test]
fn test_send_sync_bounds() {
fn assert_send<T: Send>() {}
fn assert_sync<T: Sync>() {}
type TestStream = SharedStream<futures_util::stream::Iter<std::vec::IntoIter<i32>>>;
assert_send::<TestStream>();
assert_sync::<TestStream>();
static_assertions::assert_impl_all!(TestStream: Send, Sync);
}
// Clones moved into separately spawned tokio tasks each collect the full sequence.
#[tokio::test]
async fn test_cross_thread_sharing() {
use tokio::task;
let data = vec![1, 2, 3, 4, 5];
let stream = stream::iter(data.clone());
let shared_stream = SharedStream::new(stream);
let stream1 = shared_stream.clone();
let stream2 = shared_stream.clone();
let handle1 = task::spawn(async move { stream1.collect::<Vec<i32>>().await });
let handle2 = task::spawn(async move { stream2.collect::<Vec<i32>>().await });
let (result1, result2) = tokio::join!(handle1, handle2);
assert_eq!(result1.unwrap(), data);
assert_eq!(result2.unwrap(), data);
}
// `is_terminated` stays false until `None` has been yielded. After that,
// repeated polls (including on a clone of the terminated handle) keep
// returning `None`.
#[tokio::test]
async fn test_next_after_stream_exhausted() {
use futures_util::stream::FusedStream;
use futures_util::StreamExt;
let data = vec![1, 2, 3];
let stream = stream::iter(data.clone());
let mut shared_stream = SharedStream::new(stream);
assert!(!shared_stream.is_terminated());
let mut collected = Vec::new();
while let Some(item) = shared_stream.next().await {
collected.push(item);
if collected.len() < data.len() {
assert!(!shared_stream.is_terminated());
}
}
assert_eq!(collected, data);
assert!(shared_stream.is_terminated());
let result = shared_stream.next().await;
assert_eq!(result, None);
assert!(shared_stream.is_terminated());
let result2 = shared_stream.next().await;
assert_eq!(result2, None);
assert!(shared_stream.is_terminated());
let mut cloned_stream = shared_stream.clone();
assert!(cloned_stream.is_terminated());
let result3 = cloned_stream.next().await;
assert_eq!(result3, None);
assert!(cloned_stream.is_terminated());
}
// A source that returns `Pending` once (after waking its own task) is still
// drained completely.
#[tokio::test]
async fn test_pending_future_behavior() {
use futures_util::StreamExt;
use std::pin::Pin;
use std::task::{Context, Poll};
#[derive(Clone)]
struct PendingOnceStream {
data: Vec<i32>,
index: usize,
has_returned_pending: bool,
}
impl Stream for PendingOnceStream {
type Item = i32;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
if !this.has_returned_pending {
this.has_returned_pending = true;
// Wake before returning `Pending` so the executor re-polls us.
cx.waker().wake_by_ref();
return Poll::Pending;
}
if this.index < this.data.len() {
let item = this.data[this.index];
this.index += 1;
Poll::Ready(Some(item))
} else {
Poll::Ready(None)
}
}
}
let pending_stream = PendingOnceStream {
data: vec![100, 200],
index: 0,
has_returned_pending: false,
};
let shared_stream = SharedStream::new(pending_stream);
let result = shared_stream.collect::<Vec<i32>>().await;
assert_eq!(result, vec![100, 200]);
}
// Polling `InnerFuture` directly: the first poll yields the item, and a second
// poll yields `None` because the stream was moved into the next future.
#[tokio::test]
async fn test_inner_future_direct_poll() {
use futures_util::stream;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
let stream = stream::iter(vec![42]);
let mut inner_future = InnerFuture::new(stream);
let waker = futures_util::task::noop_waker();
let mut cx = Context::from_waker(&waker);
let result = Pin::new(&mut inner_future).poll(&mut cx);
assert!(matches!(result, Poll::Ready(Some((42, _, _)))));
let result = Pin::new(&mut inner_future).poll(&mut cx);
assert!(matches!(result, Poll::Ready(None)));
}
// With the `stats` feature, `active_clones` goes up when a live handle is
// cloned. It goes down when a live handle is dropped or observes the end of
// the stream. This also holds for an empty source and for clones of a
// partially consumed stream.
#[tokio::test]
#[cfg(feature = "stats")]
async fn test_stats() {
use futures_util::stream;
let stream = stream::iter(vec![1, 2, 3]);
let shared = SharedStream::new(stream);
let stats = shared.stats();
assert_eq!(stats.active_clones(), 1);
let clone1 = shared.clone();
assert_eq!(stats.active_clones(), 2);
let clone2 = clone1.clone();
assert_eq!(stats.active_clones(), 3);
drop(clone2);
assert_eq!(stats.active_clones(), 2);
let _orig_collected: Vec<i32> = shared.collect().await;
assert_eq!(stats.active_clones(), 1);
drop(clone1);
assert_eq!(stats.active_clones(), 0);
let empty_stream = stream::iter(Vec::<i32>::new());
let shared_empty = SharedStream::new(empty_stream);
let stats_empty = shared_empty.stats();
assert_eq!(stats_empty.active_clones(), 1);
let clone_empty = shared_empty.clone();
assert_eq!(stats_empty.active_clones(), 2);
drop(clone_empty);
assert_eq!(stats_empty.active_clones(), 1);
let _ = shared_empty.collect::<Vec<i32>>().await;
assert_eq!(stats_empty.active_clones(), 0);
let stream2 = stream::iter(vec![10, 20, 30]);
let mut shared2 = SharedStream::new(stream2);
let stats2 = shared2.stats();
assert_eq!(stats2.active_clones(), 1);
let first = shared2.next().await;
assert_eq!(first, Some(10));
let clone_after = shared2.clone();
assert_eq!(stats2.active_clones(), 2);
drop(clone_after);
assert_eq!(stats2.active_clones(), 1);
let _ = shared2.collect::<Vec<i32>>().await;
assert_eq!(stats2.active_clones(), 0);
}
}