use either_or_both::EitherOrBoth;
use futures::Stream;
use futures::StreamExt;
use futures::stream::Fuse;
use pin_project_lite::pin_project;
use std::{
cmp::{self, Ordering},
pin::Pin,
task::{Context, Poll},
};
pin_project! {
    /// Stream that merges a left and a right stream by comparing their next items.
    ///
    /// Each emitted item is an [`EitherOrBoth::Left`], [`EitherOrBoth::Right`], or
    /// [`EitherOrBoth::Both`], chosen by `comparison`. Items from each side are emitted
    /// in the order that side produced them. The adapter does not check that either
    /// input is itself ordered by `comparison`, so the merged output is ordered only if
    /// both inputs are.
    ///
    /// The bound on `F` is placed on the `impl` blocks rather than on the struct, which
    /// only needs `L: Stream` / `R: Stream` to name the buffered item types.
    #[must_use = "streams do nothing unless polled"]
    pub(crate) struct MergeJoinBy<L, R, F>
    where
        L: Stream,
        R: Stream
    {
        // Left input, fused so it may be polled again after it has finished.
        #[pin]
        left: Fuse<L>,
        // Right input, fused for the same reason.
        #[pin]
        right: Fuse<R>,
        // Item already pulled from `left` but not yet emitted.
        left_queued: Option<L::Item>,
        // Item already pulled from `right` but not yet emitted.
        right_queued: Option<R::Item>,
        // Orders a left item relative to a right item.
        comparison: F,
    }
}
impl<L, R, F> MergeJoinBy<L, R, F>
where
    L: Stream,
    R: Stream,
    F: Fn(&L::Item, &R::Item) -> Ordering,
{
    /// Creates a merge-join of `left_stream` and `right_stream` ordered by `comparison`.
    ///
    /// Each input is wrapped in [`Fuse`], so an exhausted side can keep being polled and
    /// can be asked whether it is done. Both one-item buffers start out empty.
    pub(crate) fn new(left_stream: L, right_stream: R, comparison: F) -> Self {
        let left = left_stream.fuse();
        let right = right_stream.fuse();
        Self {
            comparison,
            left,
            right,
            left_queued: None,
            right_queued: None,
        }
    }
}
impl<L, R, F> Stream for MergeJoinBy<L, R, F>
where
    L: Stream,
    R: Stream,
    F: Fn(&L::Item, &R::Item) -> Ordering,
{
    type Item = EitherOrBoth<L::Item, R::Item>;

    /// Produces the next merged item.
    ///
    /// Each side has a one-item buffer. Every empty buffer is refilled by polling its
    /// input; a `Pending` result registers `cx`'s waker with that input. Then:
    /// - both buffered: `comparison` decides. `Less` emits the left item, `Greater` the
    ///   right item (the other item stays buffered), and `Equal` emits both as
    ///   `EitherOrBoth::Both`.
    /// - one buffered: it is emitted once the opposite input is exhausted; until then
    ///   the stream is `Pending`.
    /// - none buffered: the stream ends once both inputs are exhausted.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        if this.left_queued.is_none() {
            if let Poll::Ready(Some(item)) = this.left.as_mut().poll_next(cx) {
                *this.left_queued = Some(item);
            }
        }
        if this.right_queued.is_none() {
            if let Poll::Ready(Some(item)) = this.right.as_mut().poll_next(cx) {
                *this.right_queued = Some(item);
            }
        }

        // Decide what to emit while only borrowing the buffers, so a panicking
        // `comparison` leaves both buffered items in place. A lone buffered item whose
        // opposite side is exhausted is mapped to `Less`/`Greater` so it reuses the
        // emission code below.
        let side = match (this.left_queued.as_ref(), this.right_queued.as_ref()) {
            (Some(left), Some(right)) => (this.comparison)(left, right),
            (Some(_), None) if this.right.is_done() => Ordering::Less,
            (None, Some(_)) if this.left.is_done() => Ordering::Greater,
            (None, None) if this.left.is_done() && this.right.is_done() => {
                return Poll::Ready(None);
            }
            // Any empty, unfinished side was polled above and returned `Pending`, so
            // its waker is registered.
            _ => return Poll::Pending,
        };

        let item: Option<Self::Item> = match side {
            Ordering::Less => this.left_queued.take().map(EitherOrBoth::Left),
            Ordering::Greater => this.right_queued.take().map(EitherOrBoth::Right),
            Ordering::Equal => this
                .left_queued
                .take()
                .zip(this.right_queued.take())
                .map(|(left, right)| EitherOrBoth::Both(left, right)),
        };
        debug_assert!(item.is_some(), "the buffer(s) selected above are non-empty");
        Poll::Ready(item)
    }

    /// Bounds the number of items still to be emitted.
    ///
    /// Each emitted item consumes at most one item from each side, so there are at
    /// least as many outputs as the larger side's remaining total. Each emitted item
    /// also consumes at least one item, so there are at most as many outputs as both
    /// totals combined. Buffered items count toward their side's total. The upper bound
    /// is `None` if either input is unbounded or the sum overflows `usize`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left_queued_len = usize::from(self.left_queued.is_some());
        let right_queued_len = usize::from(self.right_queued.is_some());
        let (left_lower, left_upper) = self.left.size_hint();
        let (right_lower, right_upper) = self.right.size_hint();

        let lower = cmp::max(
            left_lower.saturating_add(left_queued_len),
            right_lower.saturating_add(right_queued_len),
        );
        let upper = left_upper
            .zip(right_upper)
            .and_then(|(l_upper, r_upper)| l_upper.checked_add(r_upper))
            .and_then(|total| total.checked_add(left_queued_len))
            .and_then(|total| total.checked_add(right_queued_len));
        (lower, upper)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::StreamTools;
    use futures::{executor::block_on_stream, stream};
    use quickcheck_macros::quickcheck;
    use std::collections::BTreeSet;

    /// Merging two sorted, duplicate-free inputs yields a sorted sequence of values.
    #[quickcheck]
    fn merge_of_sorteds_is_sorted(left: BTreeSet<isize>, right: BTreeSet<isize>) -> bool {
        let merged = stream::iter(left).merge_join_by(stream::iter(right), Ord::cmp);
        block_on_stream(merged)
            .flat_map(|either_or_both| either_or_both.into_iter())
            .is_sorted()
    }

    /// Every input element is emitted exactly once and in input order. `Both` only
    /// pairs elements that compare equal.
    #[quickcheck]
    fn merge_emits_each_element_once(left: BTreeSet<isize>, right: BTreeSet<isize>) -> bool {
        let expected_left: Vec<isize> = left.iter().copied().collect();
        let expected_right: Vec<isize> = right.iter().copied().collect();
        let merged = stream::iter(left).merge_join_by(stream::iter(right), Ord::cmp);

        let mut seen_left = Vec::new();
        let mut seen_right = Vec::new();
        for item in block_on_stream(merged) {
            match item {
                EitherOrBoth::Left(l) => seen_left.push(l),
                EitherOrBoth::Right(r) => seen_right.push(r),
                EitherOrBoth::Both(l, r) => {
                    if l != r {
                        return false;
                    }
                    seen_left.push(l);
                    seen_right.push(r);
                }
            }
        }
        seen_left == expected_left && seen_right == expected_right
    }

    /// For exact-size inputs the hint is exact on the lower side (`max`) and the upper
    /// side (`sum`), and the actual output count respects it.
    #[quickcheck]
    fn size_hints_dont_lie(left: Vec<isize>, right: Vec<isize>) -> bool {
        let expected_lower = cmp::max(left.len(), right.len());
        let expected_upper = left.len().checked_add(right.len());
        let left_stream = stream::iter(left);
        let right_stream = stream::iter(right);
        let stream = left_stream.merge_join_by(right_stream, Ord::cmp);
        let (lower, upper) = stream.size_hint();
        assert_eq!(expected_lower, lower);
        assert_eq!(expected_upper, upper);
        let actual_size = block_on_stream(stream).count();
        lower <= actual_size && upper.is_none_or(|limit| actual_size <= limit)
    }
}