use crate::*;
use core::ops::DerefMut;
use core::pin::Pin;
use core::task::{Context, Poll};
/// Polls a sequence of [`Peekable`] ordered streams and returns the earliest buffered item.
///
/// Each stream is peeked once, in iteration order. The bound handed to a stream is the
/// earlier of the caller's `before` and the ordering of the best item found so far, so
/// streams later in the sequence only need to report items that would displace the current
/// candidate (an item with an equal ordering also displaces it). If a candidate was found,
/// its stream is advanced with `poll_next_before` and that result is returned.
///
/// Without a candidate, the result is:
/// * `Poll::Pending` if any stream is still pending: it might yet yield an item ordered
///   before `before`, so neither `NoneBefore` nor `Terminated` can be promised;
/// * `PollResult::NoneBefore` if some stream is still live, i.e. it reported `NoneBefore`
///   or it holds an item ordered after the bound;
/// * `PollResult::Terminated` once every stream has terminated.
fn poll_multiple<I, P, S>(
    streams: I,
    cx: &mut Context<'_>,
    before: Option<&S::Ordering>,
) -> Poll<PollResult<S::Ordering, S::Data>>
where
    I: IntoIterator<Item = Pin<P>>,
    P: DerefMut<Target = Peekable<S>>,
    S: OrderedStream,
{
    // Stream holding the lowest-ordered buffered item seen so far.
    let mut best: Option<Pin<P>> = None;
    // Some stream is not terminated but has nothing to offer before the bound.
    let mut has_data = false;
    // Some stream could not yet decide what its next item is.
    let mut has_pending = false;
    for mut stream in streams {
        // Only items ordered before both the caller's bound and the current best matter.
        let best_before = best.as_ref().and_then(|p| p.item().map(|i| &i.0));
        let before = match (before, best_before) {
            (Some(a), Some(b)) if a < b => Some(a),
            (_, Some(b)) => Some(b),
            (a, None) => a,
        };
        match stream.as_mut().poll_peek_before(cx, before) {
            Poll::Pending => has_pending = true,
            Poll::Ready(PollResult::Terminated) => {}
            Poll::Ready(PollResult::NoneBefore) => has_data = true,
            Poll::Ready(PollResult::Item { ordering, .. }) => match before {
                // The buffered item is ordered after the bound. The stream is still live,
                // so the join must not report `Terminated` on its account.
                Some(max) if max < ordering => has_data = true,
                _ => best = Some(stream),
            },
        }
    }
    match best {
        // NOTE(review): a stream that returned `Pending` (possibly after being polled with a
        // looser bound than the chosen item's ordering) is not waited on here, so it might
        // later yield an earlier item. Waiting would require a second pass over the streams;
        // confirm this trade-off against the `OrderedStream` ordering contract.
        Some(mut stream) => stream.as_mut().poll_next_before(cx, before),
        // A pending stream may still produce an item ordered before `before`, so promising
        // `NoneBefore` (or `Terminated`) here could break ordering for the caller.
        None if has_pending => Poll::Pending,
        None if has_data => Poll::Ready(PollResult::NoneBefore),
        None => Poll::Ready(PollResult::Terminated),
    }
}
/// Joins a collection `C` of [`Peekable`] ordered streams into a single ordered stream.
///
/// The collection is iterated mutably on every poll. The bounds `C` has to satisfy are
/// stated on the [`OrderedStream`] and [`FusedOrderedStream`] impls rather than on the type.
#[derive(Clone, Debug, Default)]
pub struct JoinMultiple<C>(pub C);

// The inner streams are only ever pinned via `Pin::new` on `S: Unpin` items, and no pinned
// projection to `C` is offered, so the wrapper is `Unpin` whatever `C` is.
impl<C> Unpin for JoinMultiple<C> {}
impl<C, S> OrderedStream for JoinMultiple<C>
where
    for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
    S: OrderedStream + Unpin,
{
    type Ordering = S::Ordering;
    type Data = S::Data;

    /// Pins each stream of the collection in place and delegates to `poll_multiple`.
    ///
    /// `JoinMultiple` is `Unpin` and `S: Unpin`, so the wrapper can be unpinned with
    /// `get_mut` and every inner [`Peekable`] re-pinned with `Pin::new`.
    fn poll_next_before(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        before: Option<&S::Ordering>,
    ) -> Poll<PollResult<S::Ordering, S::Data>> {
        let collection = &mut self.get_mut().0;
        let pinned = IntoIterator::into_iter(collection).map(|stream| Pin::new(stream));
        poll_multiple(pinned, cx, before)
    }
}
impl<C, S> FusedOrderedStream for JoinMultiple<C>
where
    for<'a> &'a mut C: IntoIterator<Item = &'a mut Peekable<S>>,
    for<'a> &'a C: IntoIterator<Item = &'a Peekable<S>>,
    S: OrderedStream + Unpin,
{
    /// Reports termination only when every wrapped stream reports it.
    ///
    /// Stops at the first live stream; an empty collection counts as terminated.
    fn is_terminated(&self) -> bool {
        for stream in &self.0 {
            if !stream.is_terminated() {
                return false;
            }
        }
        true
    }
}
pin_project_lite::pin_project! {
/// Joins a pinned collection of [`Peekable`] ordered streams into one ordered stream.
///
/// Unlike [`JoinMultiple`], the streams need not be `Unpin`. Instead the collection is
/// structurally pinned, and `Pin<&mut C>` must be iterable as `Pin<&mut Peekable<S>>`
/// (see the [`OrderedStream`] impl).
#[derive(Debug,Default,Clone)]
pub struct JoinMultiplePin<C> {
// The collection of peekable streams; `#[pin]` makes it structurally pinned.
#[pin]
pub streams: C,
}
}
impl<C> JoinMultiplePin<C> {
    /// Returns a pinned mutable reference to the wrapped stream collection.
    ///
    /// `streams` is declared `#[pin]`, so projecting through `self` keeps the collection
    /// pinned for the caller.
    pub fn as_pin_mut(self: Pin<&mut Self>) -> Pin<&mut C> {
        let projection = self.project();
        projection.streams
    }
}
impl<C, S> OrderedStream for JoinMultiplePin<C>
where
    for<'a> Pin<&'a mut C>: IntoIterator<Item = Pin<&'a mut Peekable<S>>>,
    S: OrderedStream,
{
    type Ordering = S::Ordering;
    type Data = S::Data;

    /// Projects to the pinned collection and delegates to `poll_multiple`.
    fn poll_next_before(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        before: Option<&S::Ordering>,
    ) -> Poll<PollResult<S::Ordering, S::Data>> {
        let pinned_streams: Pin<&mut C> = self.as_pin_mut();
        poll_multiple(pinned_streams, cx, before)
    }
}
#[cfg(test)]
mod test {
    extern crate alloc;
    use crate::FromStream;
    use crate::JoinMultiple;
    use crate::OrderedStreamExt;
    use alloc::boxed::Box;
    use alloc::vec::Vec;
    use core::pin::Pin;
    use futures_core::Stream;
    use futures_util::stream::iter;

    /// Two interleaved, individually ordered sources must come out of `JoinMultiple` in
    /// global serial order, and the join must end once both sources are exhausted.
    #[test]
    fn join_multiple() {
        futures_executor::block_on(async {
            struct Message {
                serial: u32,
            }
            struct RemoteLogSource {
                stream: Pin<Box<dyn Stream<Item = Message>>>,
            }
            let mut logs = [
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 1 },
                        Message { serial: 4 },
                        Message { serial: 5 },
                    ])),
                },
                RemoteLogSource {
                    stream: Box::pin(iter([
                        Message { serial: 2 },
                        Message { serial: 3 },
                        Message { serial: 6 },
                    ])),
                },
            ];
            let streams: Vec<_> = logs
                .iter_mut()
                .map(|s| FromStream::with_ordering(&mut s.stream, |m| m.serial).peekable())
                .collect();
            let mut joined = JoinMultiple(streams);
            for expected in 1..=6u32 {
                let msg = joined.next().await.unwrap();
                assert_eq!(msg.serial, expected);
            }
            // Both sources are drained, so the joined stream must report its end.
            assert!(joined.next().await.is_none());
        });
    }
}