stream_kmerge/
lib.rs

1#![warn(rust_2018_idioms)]
2use std::cmp::Ordering;
3use std::task::Poll;
4
5use binary_heap_plus::BinaryHeap;
6use compare::Compare;
7use futures::future::{join_all, JoinAll};
8use futures::{ready, stream::StreamFuture, FutureExt, Stream, StreamExt};
9use pin_project_lite::pin_project;
10
11/// Head element and tail stream pair
12#[derive(Debug)]
13pub struct HeadTail<S>
14where
15    S: Stream,
16{
17    head: S::Item,
18    tail: S,
19}
20
21pin_project! {
22    /// A stream adaptor that merges an abitrary number of base streams
23    /// according to an ordering function.
24    ///
25    /// Iterator element type is `S::Item`.
26    ///
27    /// See [`.kmerge_by()`](crate::kmerge_by) for more
28    /// information.
29    #[must_use = "stream adaptors are lazy and do nothing unless consumed"]
30    pub struct KWayMergeBy<S, C>
31    where
32        S: Stream,
33        S: Unpin,
34        C: Compare<HeadTail<S>>
35    {
36        initial: Option<JoinAll<StreamFuture<S>>>,
37        next: Option<S>,
38        heap: BinaryHeap<HeadTail<S>, C>,
39    }
40}
41
42/// A stream adaptor that merges an abitrary number of base streams in descending order.
43/// If all base streams are sorted (descending), the result is sorted.
44///
45/// Stream element type is `S::Item`.
46///
47/// See [`.kmerge()`](crate::kmerge) for more information.
48#[must_use = "stream adaptors are lazy and do nothing unless consumed"]
49pub type KWayMerge<I> = KWayMergeBy<I, OrdComparator>;
50
51pub struct OrdComparator;
52
53impl<S> Compare<HeadTail<S>> for OrdComparator
54where
55    S: Stream,
56    S::Item: Ord,
57{
58    fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
59        l.head.cmp(&r.head)
60    }
61}
62
63pub struct FnComparator<F> {
64    f: F,
65}
66
67impl<S, F> Compare<HeadTail<S>> for FnComparator<F>
68where
69    S: Stream,
70    F: Fn(&S::Item, &S::Item) -> Ordering,
71{
72    fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
73        (self.f)(&l.head, &r.head)
74    }
75}
76
77pub struct KeyComparator<F> {
78    f: F,
79}
80
81impl<S, F, O> Compare<HeadTail<S>> for KeyComparator<F>
82where
83    S: Stream,
84    F: Fn(&S::Item) -> O,
85    O: Ord,
86{
87    fn compare(&self, l: &HeadTail<S>, r: &HeadTail<S>) -> std::cmp::Ordering {
88        (self.f)(&l.head).cmp(&(self.f)(&r.head))
89    }
90}
91
92/// Create a stream that merges elements of the contained streams using
93/// the ordering function.
94///
95/// ```
96/// # tokio_test::block_on(async {
97/// use futures::{stream, StreamExt};
98/// use stream_kmerge::kmerge;
99///
100/// let streams = vec![stream::iter(vec![5, 3, 1]), stream::iter(vec![4, 3, 2])];
101///
102/// assert_eq!(
103///     kmerge(streams).collect::<Vec<usize>>().await,
104///     vec![5, 4, 3, 3, 2, 1],
105/// );
106/// # })
107/// ```
108pub fn kmerge<S>(xs: impl IntoIterator<Item = S>) -> KWayMerge<S>
109where
110    S: Stream + Unpin,
111    S::Item: Ord,
112{
113    assert_stream::<S::Item, _>(kmerge_generic(xs, OrdComparator))
114}
115
116/// Create a stream that merges elements of the contained streams.
117///
118/// ```
119/// # tokio_test::block_on(async {
120/// use futures::{stream, StreamExt};
121/// use stream_kmerge::kmerge_by;
122///
123/// let streams = vec![stream::iter(vec![1, 3, 5]), stream::iter(vec![2, 3, 4])];
124///
125/// assert_eq!(
126///     kmerge_by(streams, |x: &usize, y: &usize| y.cmp(&x)).collect::<Vec<usize>>().await,
127///     vec![1, 2, 3, 3, 4, 5],
128/// );
129/// # })
130/// ```
131pub fn kmerge_by<S, F>(xs: impl IntoIterator<Item = S>, f: F) -> KWayMergeBy<S, FnComparator<F>>
132where
133    S: Stream + Unpin,
134    F: Fn(&S::Item, &S::Item) -> Ordering,
135{
136    kmerge_generic(xs, FnComparator { f })
137}
138
139/// Create a stream that merges elements of the contained streams.
140///
141/// ```
142/// # tokio_test::block_on(async {
143/// use futures::{stream, StreamExt};
144/// use stream_kmerge::kmerge_by_key;
145///
146/// let streams = vec![stream::iter(vec![("a", 5), ("a", 3)]), stream::iter(vec![("b", 4), ("b", 4)])];
147///
148/// assert_eq!(
149///     kmerge_by_key(streams, |x: &(&'static str, usize)| x.1).collect::<Vec<_>>().await,
150///     vec![("a", 5), ("b", 4), ("b", 4), ("a", 3)],
151/// );
152/// # })
153/// ```
154pub fn kmerge_by_key<S, F, O>(
155    xs: impl IntoIterator<Item = S>,
156    f: F,
157) -> KWayMergeBy<S, KeyComparator<F>>
158where
159    S: Stream + Unpin,
160    F: Fn(&S::Item) -> O,
161    O: Ord,
162{
163    kmerge_generic(xs, KeyComparator { f })
164}
165
166/// This was originally meant to be [`kmerge_by`], but triggers a compiler bug, if you directly pass
167/// a closure without type hint for `less_than` (https://github.com/rust-lang/rust/issues/81511).
168/// More specifically, the following code fails to compile:
169///
170/// ```norust
171/// use stream_kmerge::kmerge_generic;
172/// use futures::stream;
173///
174/// kmerge_generic(vec![stream::empty::<usize>()], |x, y| x < y);
175/// ```
176///
177/// If you add a type hint to the closure parameter's, it works:
178///
179/// ```norust
180/// use stream_kmerge::kmerge_generic;
181/// use futures::stream;
182///
183/// kmerge_generic(vec![stream::empty::<usize>()], |x: &_, y: &_| x < y);
184/// ```
185///
186/// The error message is
187/// ```norust
188/// error[E0308]: mismatched types
189///    --> src/lib.rs:165:1
190///     |
191/// 7   | kmerge_generic(vec![stream::empty::<usize>()], |x, y| x < y);
192///     | ^^^^^^^^^^^^^^ lifetime mismatch
193///     |
194///     = note: expected type `for<'r, 's> FnMut<(&'r usize, &'s usize)>`
195///                found type `FnMut<(&usize, &usize)>`
196/// note: this closure does not fulfill the lifetime requirements
197///    --> src/lib.rs:165:48
198///     |
199/// 7   | kmerge_generic(vec![stream::empty::<usize>()], |x, y| x < y);
200///     |                                                ^^^^^^^^^^^^
201/// note: the lifetime requirement is introduced here
202///    --> /home/thomas/src/stream-kmerge/src/lib.rs:171:8
203///     |
204/// 171 |     F: KMergePredicate<S::Item>,
205///     |        ^^^^^^^^^^^^^^^^^^^^^^^^
206///
207/// error: aborting due to previous error
208///
209/// For more information about this error, try `rustc --explain E0308`.
210/// Couldn't compile the test.
211/// ```
212///
213/// Therefore, `kmerge_generic` is private and the public [`kmerge_by`] explicitly takes a closure
214/// instead of [`KMergePredicate`].
215fn kmerge_generic<S, C>(xs: impl IntoIterator<Item = S>, cmp: C) -> KWayMergeBy<S, C>
216where
217    S: Stream + Unpin,
218    C: Compare<HeadTail<S>>,
219{
220    let iter = xs.into_iter();
221    let (min_size, _) = iter.size_hint();
222    assert_stream::<S::Item, _>(KWayMergeBy {
223        initial: Some(join_all(iter.map(|x| x.into_future()))),
224        next: None,
225        heap: BinaryHeap::from_vec_cmp(Vec::with_capacity(min_size), cmp),
226    })
227}
228
229impl<S, C> Stream for KWayMergeBy<S, C>
230where
231    S: Stream + Unpin,
232    C: Compare<HeadTail<S>>,
233{
234    type Item = S::Item;
235
236    fn poll_next(
237        self: std::pin::Pin<&mut Self>,
238        cx: &mut std::task::Context<'_>,
239    ) -> std::task::Poll<Option<Self::Item>> {
240        let this = self.project();
241        if let Some(init_fut) = this.initial.as_mut() {
242            let xs = ready!(init_fut.poll_unpin(cx));
243            *this.initial = None;
244            this.heap.extend(
245                xs.into_iter().filter_map(|(head_option, tail)| {
246                    head_option.map(|head| HeadTail { head, tail })
247                }),
248            );
249        }
250
251        if let Some(ref mut next_stream) = this.next {
252            if let Some(item) = ready!(next_stream.next().poll_unpin(cx)) {
253                this.heap.push(HeadTail {
254                    head: item,
255                    tail: this.next.take().unwrap(),
256                });
257            }
258        }
259
260        match this.heap.pop() {
261            None => Poll::Ready(None),
262            Some(HeadTail { head, tail }) => {
263                this.next.replace(tail);
264
265                Poll::Ready(Some(head))
266            }
267        }
268    }
269}
270
271// Just a helper function to ensure the streams we're returning all have the
272// right implementations.
273fn assert_stream<T, S>(stream: S) -> S
274where
275    S: Stream<Item = T>,
276{
277    stream
278}
279
#[cfg(test)]
mod test {
    use std::pin::Pin;
    use std::time::Duration;

    use futures::stream;
    use futures::FutureExt;
    use futures::Stream;
    use futures::StreamExt;
    use tokio::sync::oneshot;
    use tokio::time;
    use tokio_stream::wrappers::IntervalStream;

    use super::*;

    /// Two descending in-memory streams merge into one descending sequence.
    #[tokio::test]
    async fn sync() {
        let inputs = vec![stream::iter(vec![5, 3, 1]), stream::iter(vec![4, 3, 2])];

        let merged: Vec<usize> = kmerge(inputs).collect().await;

        assert_eq!(merged, vec![5, 4, 3, 3, 2, 1]);
    }

    /// A natural-order closure behaves like `kmerge`.
    #[tokio::test]
    async fn by() {
        let inputs = vec![stream::iter(vec![5, 3, 1]), stream::iter(vec![4, 3, 2])];

        let merged: Vec<usize> = kmerge_by(inputs, |a: &usize, b: &usize| a.cmp(b))
            .collect()
            .await;

        assert_eq!(merged, vec![5, 4, 3, 3, 2, 1]);
    }

    /// Items are ordered by the extracted key (the tuple's second field).
    #[tokio::test]
    async fn by_key() {
        let inputs = vec![
            stream::iter(vec![("a", 5), ("a", 3)]),
            stream::iter(vec![("b", 4), ("b", 4)]),
        ];

        let merged: Vec<_> = kmerge_by_key(inputs, |item: &(&'static str, usize)| item.1)
            .collect()
            .await;

        assert_eq!(merged, vec![("a", 5), ("b", 4), ("b", 4), ("a", 3)]);
    }

    /// Merging works for streams that are not immediately ready.
    #[tokio::test]
    async fn kmerge_async() {
        let fast = IntervalStream::new(time::interval(Duration::from_nanos(1)));
        let slow = IntervalStream::new(time::interval(Duration::from_nanos(2)));

        let ticks: Vec<_> = kmerge(vec![fast, slow]).take(10).collect().await;

        assert_eq!(ticks.len(), 10);
    }

    /// Each stream can only produce its first item after the other one has
    /// been polled, so this completes only if the first items of all streams
    /// are awaited concurrently rather than one after another.
    #[tokio::test]
    async fn concurrent_initialization() {
        let (tx1, rx1) = oneshot::channel();
        let (tx2, rx2) = oneshot::channel();

        let first = async move {
            tx1.send(1).unwrap();
            rx2.await.unwrap()
        }
        .into_stream();
        let second = async move {
            tx2.send(2).unwrap();
            rx1.await.unwrap()
        }
        .into_stream();

        let inputs: Vec<Pin<Box<dyn Stream<Item = i32>>>> =
            vec![Box::pin(first), Box::pin(second)];

        let merged: Vec<_> = kmerge(inputs).collect().await;
        assert_eq!(merged, vec![2, 1]);
    }
}
361
362#[cfg(doctest)]
363doc_comment::doctest!("../README.md");