#![warn(rust_2018_idioms)]
use std::cmp::Ordering;
use std::task::Poll;
use binary_heap_plus::BinaryHeap;
use compare::Compare;
use futures::future::{join_all, JoinAll};
use futures::{ready, stream::StreamFuture, FutureExt, Stream, StreamExt};
use pin_project_lite::pin_project;
/// One heap entry of a k-way merge: the latest item received from a source
/// stream (`head`) paired with that same stream (`tail`), so the stream can
/// be polled again once `head` has been yielded.
///
/// The comparators in this module rank entries by `head` alone.
#[derive(Debug)]
pub struct HeadTail<S: Stream> {
    head: S::Item,
    tail: S,
}
pin_project! {
/// Stream returned by `kmerge`, `kmerge_by` and `kmerge_by_key`: merges
/// several input streams by repeatedly yielding the pending head that ranks
/// highest under the comparator `C`.
#[must_use = "stream adaptors are lazy and do nothing unless consumed"]
pub struct KWayMergeBy<S, C>
where
S: Stream,
S: Unpin,
C: Compare<HeadTail<S>>
{
// Future resolving to the first item (or end) of every input stream; it is
// `Some` until it completes, after which its results are moved into `heap`.
initial: Option<JoinAll<StreamFuture<S>>>,
// The stream whose head was yielded most recently; it has to produce its
// next item (or end) before another item can be chosen.
next: Option<S>,
// Pending heads of the remaining input streams, ordered by `C`.
heap: BinaryHeap<HeadTail<S>, C>,
}
}
#[must_use = "stream adaptors are lazy and do nothing unless consumed"]
pub type KWayMerge<I> = KWayMergeBy<I, OrdComparator>;
/// Comparator that ranks merge entries by the [`Ord`] implementation of
/// their heads; used by [`kmerge`].
#[derive(Debug, Clone, Copy, Default)]
pub struct OrdComparator;
impl<S> Compare<HeadTail<S>> for OrdComparator
where
S: Stream,
S::Item: Ord,
{
fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
l.head.cmp(&r.head)
}
}
/// Comparator created by [`kmerge_by`]: ranks merge entries by passing their
/// heads to a caller-supplied comparison function.
pub struct FnComparator<F> {
    // Called as `f(left_head, right_head)`.
    f: F,
}
/// Ranks two entries by applying the wrapped comparison function to their heads.
impl<S, F> Compare<HeadTail<S>> for FnComparator<F>
where
    S: Stream,
    F: Fn(&S::Item, &S::Item) -> Ordering,
{
    fn compare(&self, left: &HeadTail<S>, right: &HeadTail<S>) -> Ordering {
        let order_by = &self.f;
        order_by(&left.head, &right.head)
    }
}
/// Comparator created by [`kmerge_by_key`]: ranks merge entries by an
/// [`Ord`] key extracted from each head.
pub struct KeyComparator<F> {
    // Key extraction function, applied to a head each time it is compared.
    f: F,
}
impl<S, F, O> Compare<HeadTail<S>> for KeyComparator<F>
where
S: Stream,
F: Fn(&S::Item) -> O,
O: Ord,
{
fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
(self.f)(&l.head).cmp(&(self.f)(&r.head))
}
}
/// Merges several streams into one, always yielding the greatest pending
/// item according to its [`Ord`] implementation.
///
/// Nothing is yielded until every input has produced its first item (or
/// ended); those first items are awaited concurrently. After an item is
/// yielded, the stream it came from must produce its next item (or end)
/// before the following item is chosen. When every input is sorted in
/// descending order, the output is their descending k-way merge.
pub fn kmerge<S>(xs: impl IntoIterator<Item = S>) -> KWayMerge<S>
where
    S: Stream + Unpin,
    S::Item: Ord,
{
    kmerge_generic(xs, OrdComparator)
}
/// Merges several streams into one, ranking pending items with the
/// comparison function `f`.
///
/// `f(a, b)` decides how item `a` ranks against item `b`; the item that
/// compares as [`Ordering::Greater`] is yielded first. Passing a function
/// equivalent to `Ord::cmp` behaves like [`kmerge`].
pub fn kmerge_by<S, F>(xs: impl IntoIterator<Item = S>, f: F) -> KWayMergeBy<S, FnComparator<F>>
where
    S: Stream + Unpin,
    F: Fn(&S::Item, &S::Item) -> Ordering,
{
    let comparator = FnComparator { f };
    kmerge_generic(xs, comparator)
}
/// Merges several streams into one, ranking pending items by the key that
/// `f` extracts from them; the item with the greatest key is yielded first.
///
/// The key is recomputed for every comparison, so `f` should be cheap.
pub fn kmerge_by_key<S, F, O>(
    xs: impl IntoIterator<Item = S>,
    f: F,
) -> KWayMergeBy<S, KeyComparator<F>>
where
    S: Stream + Unpin,
    F: Fn(&S::Item) -> O,
    O: Ord,
{
    let comparator = KeyComparator { f };
    kmerge_generic(xs, comparator)
}
fn kmerge_generic<S, C>(xs: impl IntoIterator<Item = S>, cmp: C) -> KWayMergeBy<S, C>
where
S: Stream + Unpin,
C: Compare<HeadTail<S>>,
{
let iter = xs.into_iter();
let (min_size, _) = iter.size_hint();
assert_stream::<S::Item, _>(KWayMergeBy {
initial: Some(join_all(iter.map(|x| x.into_future()))),
next: None,
heap: BinaryHeap::from_vec_cmp(Vec::with_capacity(min_size), cmp),
})
}
impl<S, C> Stream for KWayMergeBy<S, C>
where
    S: Stream + Unpin,
    C: Compare<HeadTail<S>>,
{
    type Item = S::Item;

    /// Yields the pending item that ranks highest under `C`.
    ///
    /// 1. On the first call, waits for the first item of every input stream
    ///    and seeds the heap with them. Inputs that end immediately are dropped.
    /// 2. If the previous call yielded an item, waits for that item's source
    ///    stream to produce its next item and pushes it onto the heap. A source
    ///    stream that has ended is dropped here, so it is never polled again.
    /// 3. Pops the top entry, yields its head and keeps its stream for step 2.
    fn poll_next(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let this = self.project();

        if let Some(init_fut) = this.initial.as_mut() {
            let firsts = ready!(init_fut.poll_unpin(cx));
            *this.initial = None;
            this.heap.extend(
                firsts
                    .into_iter()
                    .filter_map(|(head, tail)| head.map(|head| HeadTail { head, tail })),
            );
        }

        if let Some(mut tail) = this.next.take() {
            match tail.poll_next_unpin(cx) {
                // Nothing can be yielded yet: this stream's next item might
                // outrank everything in the heap. Keep it for the next poll.
                Poll::Pending => {
                    *this.next = Some(tail);
                    return Poll::Pending;
                }
                Poll::Ready(Some(head)) => this.heap.push(HeadTail { head, tail }),
                // Exhausted: drop the stream so it is never polled after
                // returning `None`, even if this merge is polled after it ends.
                Poll::Ready(None) => {}
            }
        }

        match this.heap.pop() {
            None => Poll::Ready(None),
            Some(HeadTail { head, tail }) => {
                *this.next = Some(tail);
                Poll::Ready(Some(head))
            }
        }
    }
}
/// Identity function that returns `stream` unchanged.
///
/// Calling it makes the compiler check that `stream` is a [`Stream`] whose
/// `Item` is `T`.
fn assert_stream<T, S: Stream<Item = T>>(stream: S) -> S {
    stream
}
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::time::Duration;

    use futures::stream;
    use futures::FutureExt;
    use futures::Stream;
    use futures::StreamExt;
    use tokio::sync::oneshot;
    use tokio::time;
    use tokio_stream::wrappers::IntervalStream;

    use super::*;

    // Inputs sorted in descending order merge into one descending sequence.
    #[tokio::test]
    async fn sync() {
        let streams = vec![stream::iter(vec![5, 3, 1]), stream::iter(vec![4, 3, 2])];
        assert_eq!(
            kmerge(streams).collect::<Vec<usize>>().await,
            vec![5, 4, 3, 3, 2, 1],
        );
    }

    // A comparator equivalent to `Ord::cmp` matches the behaviour of `kmerge`.
    #[tokio::test]
    async fn by() {
        let streams = vec![stream::iter(vec![5, 3, 1]), stream::iter(vec![4, 3, 2])];
        let stream = kmerge_by(streams, |x: &usize, y: &usize| x.cmp(y));
        assert_eq!(stream.collect::<Vec<usize>>().await, vec![5, 4, 3, 3, 2, 1]);
    }

    // Items are ranked by the extracted key only; the `&str` part is ignored.
    #[tokio::test]
    async fn by_key() {
        let streams = vec![
            stream::iter(vec![("a", 5), ("a", 3)]),
            stream::iter(vec![("b", 4), ("b", 4)]),
        ];
        let stream = kmerge_by_key(streams, |x: &(&'static str, usize)| x.1);
        assert_eq!(
            stream.collect::<Vec<_>>().await,
            vec![("a", 5), ("b", 4), ("b", 4), ("a", 3)]
        );
    }

    // Smoke test with timer-driven inputs that are not immediately ready.
    #[tokio::test]
    async fn kmerge_async() {
        let streams = vec![
            IntervalStream::new(time::interval(Duration::from_nanos(1))),
            IntervalStream::new(time::interval(Duration::from_nanos(2))),
        ];
        let result = kmerge(streams).take(10).collect::<Vec<_>>().await;
        assert_eq!(result.len(), 10);
    }

    // Each stream's first item is only produced once the other stream has been
    // polled. This completes only if the first items are awaited concurrently.
    #[tokio::test]
    async fn concurrent_initialization() {
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();
        let s1 = async move {
            tx1.send(1).unwrap();
            rx2.await.unwrap()
        }
        .into_stream();
        let s2 = async move {
            tx2.send(2).unwrap();
            rx1.await.unwrap()
        }
        .into_stream();
        let streams: Vec<Pin<Box<dyn Stream<Item = i32>>>> = vec![Box::pin(s1), Box::pin(s2)];
        let result = kmerge(streams).collect::<Vec<_>>().await;
        assert_eq!(result, vec![2, 1]);
    }
}
// `cfg(doctest)` is only set while rustdoc collects doctests. In that
// context, the README is pulled in as documentation so that its Rust code
// examples are compiled and run as doctests.
#[cfg(doctest)]
doc_comment::doctest!("../README.md");