use alloc::vec::Vec;
use core::{
pin::Pin,
task::{Context, Poll},
};
use futures_core::Stream;
use crate::{futures_unordered::MIN_CAPACITY, FuturesUnorderedBounded, MergeBounded};
/// A stream that merges a growable set of streams of type `S`, yielding items
/// from whichever of them is ready.
///
/// Streams are stored in a list of [`MergeBounded`] groups. New streams are
/// pushed into the last group; when that group is full, a new group with twice
/// its capacity is appended.
pub struct MergeUnbounded<S> {
    /// The bounded groups that hold the merged streams. The last group is the
    /// one newly pushed streams are placed into.
    pub(crate) groups: Vec<MergeBounded<S>>,
    /// Index of the group that will be polled first on the next `poll_next`.
    poll_next: usize,
}
impl<S> Default for MergeUnbounded<S> {
fn default() -> Self {
Self::new()
}
}
impl<S> MergeUnbounded<S> {
pub const fn new() -> Self {
Self {
groups: Vec::new(),
poll_next: 0,
}
}
#[track_caller]
pub fn push(&mut self, stream: S) {
let last = match self.groups.last_mut() {
Some(last) => last,
None => {
self.groups.push(MergeBounded {
streams: FuturesUnorderedBounded::new(MIN_CAPACITY),
});
self.groups.last_mut().unwrap()
}
};
match last.try_push(stream) {
Ok(()) => {}
Err(stream) => {
let mut next = MergeBounded {
streams: FuturesUnorderedBounded::new(last.streams.capacity() * 2),
};
next.push(stream);
self.groups.push(next);
}
}
}
pub fn is_empty(&self) -> bool {
self.groups.iter().all(|g| g.streams.is_empty())
}
pub fn len(&self) -> usize {
self.groups.iter().map(|g| g.streams.len()).sum()
}
}
impl<S: Stream + Unpin> Stream for MergeUnbounded<S> {
    type Item = S::Item;

    /// Polls the groups round-robin, starting at the group indexed by
    /// `poll_next`.
    ///
    /// Returns:
    /// - `Ready(Some(item))` as soon as a group yields an item. The cursor then
    ///   moves past that group, so the next call starts with the following
    ///   group and an always-ready group cannot starve the others.
    /// - `Ready(None)` when there are no groups, or when the only remaining
    ///   group reports exhaustion.
    /// - `Pending` when each group was visited once in this call and none
    ///   produced an item.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let Self { groups, poll_next } = &mut *self;
        if groups.is_empty() {
            return Poll::Ready(None);
        }

        // Visit each group at most once per call. Removing an exhausted group
        // shifts the later groups down by one, so the cursor must not advance
        // in that case.
        for _ in 0..groups.len() {
            if *poll_next >= groups.len() {
                *poll_next = 0;
            }

            match Pin::new(&mut groups[*poll_next]).poll_next(cx) {
                Poll::Ready(Some(item)) => {
                    // Start the next poll at the following group (fairness).
                    *poll_next += 1;
                    return Poll::Ready(Some(item));
                }
                Poll::Ready(None) => {
                    let group = groups.remove(*poll_next);
                    debug_assert!(group.streams.is_empty());

                    if groups.is_empty() {
                        // Keep the now-empty group so later pushes can reuse it.
                        groups.push(group);
                        return Poll::Ready(None);
                    }

                    if *poll_next == groups.len() {
                        // The removed group was the last one. `push` appends into
                        // the last group, and it has the largest capacity (each
                        // new group doubles the previous last), so put it back at
                        // the end and wrap the cursor.
                        groups.push(group);
                        *poll_next = 0;
                    }
                }
                Poll::Pending => {
                    *poll_next += 1;
                }
            }
        }

        Poll::Pending
    }
}
impl<S: Stream + Unpin> FromIterator<S> for MergeUnbounded<S> {
    /// Builds a merge by pushing every stream produced by `iter`, in order,
    /// into an initially empty [`MergeUnbounded`].
    fn from_iter<T>(iter: T) -> Self
    where
        T: IntoIterator<Item = S>,
    {
        iter.into_iter().fold(Self::new(), |mut merged, stream| {
            merged.push(stream);
            merged
        })
    }
}
#[cfg(test)]
mod tests {
    use core::cell::RefCell;
    use core::task::Waker;

    use super::*;
    use alloc::collections::VecDeque;
    use alloc::rc::Rc;
    use futures::executor::block_on;
    use futures::executor::LocalPool;
    use futures::stream;
    use futures::task::LocalSpawnExt;
    use futures::StreamExt;

    /// Merges four finite streams built via `FromIterator`. Checks that the sum
    /// of everything yielded equals the sum of all inputs. Item order is not
    /// asserted.
    #[test]
    fn merge_tuple_4() {
        block_on(async {
            let a = stream::repeat(2).take(2);
            let b = stream::repeat(3).take(3);
            let c = stream::repeat(5).take(5);
            let d = stream::repeat(7).take(7);
            let mut s: MergeUnbounded<_> = [a, b, c, d].into_iter().collect();
            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
            }
            // Each `repeat(k).take(k)` contributes k * k.
            assert_eq!(counter, 4 + 9 + 25 + 49);
        });
    }

    /// Exercises `push`, `len` and `is_empty` across the lifecycle: empty,
    /// filled, drained to exhaustion, then refilled after exhaustion.
    #[test]
    fn add_streams() {
        block_on(async {
            let a = stream::repeat(2).take(2);
            let b = stream::repeat(3).take(3);
            let mut s = MergeUnbounded::default();
            // A fresh merge has no groups, so it ends immediately.
            assert_eq!(s.next().await, None);
            assert!(s.is_empty());
            assert_eq!(s.len(), 0);
            s.push(a);
            s.push(b);
            assert!(!s.is_empty());
            assert_eq!(s.len(), 2);
            let mut counter = 0;
            while let Some(n) = s.next().await {
                counter += n;
                // While items are still being yielded, the merge still holds
                // at least one source stream.
                assert!(!s.is_empty());
            }
            assert!(s.is_empty());
            assert_eq!(s.len(), 0);
            // Pushing after the merge reported exhaustion must work, and the
            // new stream must be polled.
            let b = stream::repeat(4).take(4);
            s.push(b);
            assert!(!s.is_empty());
            assert_eq!(s.len(), 1);
            while let Some(n) = s.next().await {
                counter += n;
            }
            assert_eq!(counter, 4 + 9 + 16);
            assert!(s.is_empty());
            assert_eq!(s.len(), 0);
        });
    }

    /// Merges three hand-rolled single-threaded channels. Their items are sent
    /// by a concurrently joined future, so the merge may find its receivers
    /// empty and depend on the wakers the senders fire.
    #[test]
    fn merge_channels() {
        // Shared channel state: buffered items, the waker stored by the last
        // receiver poll that found the queue empty, and whether the sender is
        // gone.
        struct LocalChannel<T> {
            queue: VecDeque<T>,
            waker: Option<Waker>,
            closed: bool,
        }

        // Receiving half; implements `Stream` over the shared queue.
        struct LocalReceiver<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> Stream for LocalReceiver<T> {
            type Item = T;

            // Yields a queued item if there is one. Otherwise ends if the channel
            // is closed, or stores the waker and returns `Pending`.
            fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
                let mut channel = self.channel.borrow_mut();
                match channel.queue.pop_front() {
                    Some(item) => Poll::Ready(Some(item)),
                    None => {
                        if channel.closed {
                            Poll::Ready(None)
                        } else {
                            channel.waker = Some(cx.waker().clone());
                            Poll::Pending
                        }
                    }
                }
            }
        }

        // Sending half; dropping it closes the channel.
        struct LocalSender<T> {
            channel: Rc<RefCell<LocalChannel<T>>>,
        }

        impl<T> LocalSender<T> {
            // Queues `item` and wakes the receiver if it left a waker.
            fn send(&self, item: T) {
                let mut channel = self.channel.borrow_mut();
                channel.queue.push_back(item);
                let _ = channel.waker.take().map(Waker::wake);
            }
        }

        impl<T> Drop for LocalSender<T> {
            // Marks the channel closed and wakes the receiver, so it can drain
            // the remaining items and then return `None`.
            fn drop(&mut self) {
                let mut channel = self.channel.borrow_mut();
                channel.closed = true;
                let _ = channel.waker.take().map(Waker::wake);
            }
        }

        // Creates a connected sender/receiver pair sharing one `LocalChannel`.
        fn local_channel<T>() -> (LocalSender<T>, LocalReceiver<T>) {
            let channel = Rc::new(RefCell::new(LocalChannel {
                queue: VecDeque::new(),
                waker: None,
                closed: false,
            }));
            (
                LocalSender {
                    channel: channel.clone(),
                },
                LocalReceiver { channel },
            )
        }

        let mut pool = LocalPool::new();
        // Set by the spawned task once its assertion has run.
        let done = Rc::new(RefCell::new(false));
        let done2 = done.clone();
        pool.spawner()
            .spawn_local(async move {
                let (send1, receive1) = local_channel();
                let (send2, receive2) = local_channel();
                let (send3, receive3) = local_channel();
                let (count, ()) = futures::future::join(
                    // Consumer: sums every item from the three merged receivers.
                    async {
                        let s: MergeUnbounded<_> =
                            [receive1, receive2, receive3].into_iter().collect();
                        s.fold(0, |a, b| async move { a + b }).await
                    },
                    // Producer: sends 1..=4 on every channel, then closes them all.
                    async {
                        for i in 1..=4 {
                            send1.send(i);
                            send2.send(i);
                            send3.send(i);
                        }
                        drop(send1);
                        drop(send2);
                        drop(send3);
                    },
                )
                .await;
                // 3 channels * (1 + 2 + 3 + 4)
                assert_eq!(count, 30);
                *done2.borrow_mut() = true;
            })
            .unwrap();
        // NOTE(review): if the task never completes (e.g. a lost wakeup), this
        // loop spins forever instead of failing the test.
        while !*done.borrow() {
            pool.run_until_stalled();
        }
    }
}