use std::{sync::Arc, vec::IntoIter};
use futures::{Stream, StreamExt};
use tokio::{
runtime::Runtime,
sync::mpsc::{self, Receiver},
task::JoinHandle,
};
use super::{
binary_heap::{BinaryHeap, PeekMut},
compare::Compare,
};
/// Bridges an asynchronous [`Stream`] to a blocking [`Iterator`].
///
/// A task spawned on the supplied runtime polls the stream and forwards every
/// item through a bounded channel of capacity 1; the `Iterator` implementation
/// of this type reads from the other end of that channel. The small capacity
/// bounds how far the producer task can run ahead of the consumer.
pub struct EagerStream<T> {
    /// Receiving half of the channel filled by the forwarding task.
    rx: Receiver<T>,
    /// Handle of the forwarding task, kept so it can be aborted on drop.
    task: JoinHandle<()>,
    /// Runtime the forwarding task was spawned on; also used to block on `rx`.
    runtime: Arc<Runtime>,
}

impl<T> EagerStream<T> {
    /// Spawns a task on `runtime` that drains `stream` into an internal
    /// channel and returns the blocking handle for the receiving side.
    ///
    /// The task finishes when the stream is exhausted or as soon as a send
    /// fails, which happens once the receiving half has been closed or dropped.
    pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
    where
        S: Stream<Item = T> + Send + 'static,
        T: Send + 'static,
    {
        let (tx, rx) = mpsc::channel(1);

        let forward = async move {
            // The stream is not required to be `Unpin`, so pin it in place
            // before polling it.
            futures::pin_mut!(stream);
            loop {
                let Some(item) = stream.next().await else {
                    break;
                };
                // A failed send means nobody is listening any more.
                if tx.send(item).await.is_err() {
                    break;
                }
            }
        };

        let task = runtime.spawn(forward);
        Self { rx, task, runtime }
    }
}
impl<T> Iterator for EagerStream<T> {
    type Item = T;

    /// Blocks the calling thread on the runtime until the channel yields its
    /// next value, and returns whatever `recv` produced.
    ///
    /// NOTE(review): Tokio documents that `Runtime::block_on` panics when
    /// called from within an asynchronous execution context, so this iterator
    /// is meant to be driven from a plain (non-async) thread.
    fn next(&mut self) -> Option<Self::Item> {
        let pending = self.rx.recv();
        self.runtime.block_on(pending)
    }
}
impl<T> Drop for EagerStream<T> {
    /// Shuts the pipeline down: first closes the receiving half of the
    /// channel, then aborts the forwarding task.
    fn drop(&mut self) {
        let Self { rx, task, .. } = self;
        rx.close();
        task.abort();
    }
}
/// Cursor over one input of the merge: an iterator of batches, flattened one
/// element at a time.
///
/// The cursor always holds one already-extracted element (`item`), the rest of
/// the batch that element came from, and the iterator over the batches that
/// have not been opened yet.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Element currently at the head of this input.
    pub item: T,
    /// Remaining elements of the batch `item` was taken from.
    batch: I::Item,
    /// Batches that have not been started yet.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds a cursor positioned on the first element of `iter`.
    ///
    /// Leading batches that turn out to be empty are consumed and discarded.
    /// Returns `None` when `iter` runs out without producing any element.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        // Advance through the batches until one of them yields an element,
        // keeping that element together with the rest of its batch.
        let (item, batch) = iter
            .by_ref()
            .find_map(|mut candidate| candidate.next().map(|first| (first, candidate)))?;
        Some(Self { item, batch, iter })
    }
}
/// K-way merge over inputs that each deliver their elements in batches.
///
/// Every input is tracked by an [`ElementBatchIter`] stored in a heap ordered
/// by the comparator `C`. Iteration only ever compares the current head
/// element of each input, so the output is globally ordered only when each
/// input is itself already ordered.
pub struct KMerge<I, T, C>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// One cursor per input that still has elements to contribute.
    heap: BinaryHeap<ElementBatchIter<I, T>, C>,
}

impl<I, T, C> KMerge<I, T, C>
where
    I: Iterator<Item = IntoIter<T>>,
    C: Compare<ElementBatchIter<I, T>>,
{
    /// Creates a merge with no inputs that will order cursors using `cmp`.
    pub fn new(cmp: C) -> Self {
        let heap = BinaryHeap::from_vec_cmp(Vec::new(), cmp);
        Self { heap }
    }

    /// Registers `s` as an additional input.
    ///
    /// Inputs that contain no element at all (no batches, or only empty
    /// batches) are dropped without being added to the heap.
    pub fn push_iter(&mut self, s: I) {
        let Some(cursor) = ElementBatchIter::new_from_iter(s) else {
            return;
        };
        self.heap.push(cursor);
    }

    /// Discards every input currently held by the merge.
    pub fn clear(&mut self) {
        self.heap.clear();
    }
}
impl<I, T, C> Iterator for KMerge<I, T, C>
where
    I: Iterator<Item = IntoIter<T>>,
    C: Compare<ElementBatchIter<I, T>>,
{
    type Item = T;

    /// Yields the head element of the cursor on top of the heap and advances
    /// that cursor to its next element.
    ///
    /// The cursor is refilled from its current batch if possible, otherwise
    /// from the next non-empty batch of its input. When the input has nothing
    /// left, the cursor is popped from the heap and its last element returned.
    /// Returns `None` once the heap holds no cursors.
    fn next(&mut self) -> Option<Self::Item> {
        // NOTE(review): presumably, like std's `PeekMut`, this guard restores
        // the heap order when it is dropped after a mutation; confirm in
        // `binary_heap`.
        let mut top = self.heap.peek_mut()?;

        // Common case: the current batch still has an element to take over.
        if let Some(successor) = top.batch.next() {
            return Some(std::mem::replace(&mut top.item, successor));
        }

        // Current batch is exhausted: look for the next batch of this input
        // that actually contains something, skipping empty ones.
        while let Some(mut fresh) = top.iter.next() {
            if let Some(successor) = fresh.next() {
                top.batch = fresh;
                return Some(std::mem::replace(&mut top.item, successor));
            }
        }

        // The whole input is drained: remove its cursor and hand out the
        // element it was still holding.
        Some(PeekMut::pop(top).item)
    }
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use rstest::rstest;
use super::*;
/// Test comparator that orders cursors by their current `item`, reversed.
///
/// NOTE(review): presumably the heap surfaces the greatest element under the
/// comparator, so the reversal is what produces the ascending output the
/// tests below expect — confirm against `binary_heap`.
struct OrdComparator;
// Comparator for merges over `i32` elements (used by the example-based tests).
impl<S> Compare<ElementBatchIter<S, i32>> for OrdComparator
where
S: Iterator<Item = IntoIter<i32>>,
{
fn compare(
&self,
l: &ElementBatchIter<S, i32>,
r: &ElementBatchIter<S, i32>,
) -> std::cmp::Ordering {
l.item.cmp(&r.item).reverse()
}
}
// Comparator for merges over `u64` elements (used by the property tests).
impl<S> Compare<ElementBatchIter<S, u64>> for OrdComparator
where
S: Iterator<Item = IntoIter<u64>>,
{
fn compare(
&self,
l: &ElementBatchIter<S, u64>,
r: &ElementBatchIter<S, u64>,
) -> std::cmp::Ordering {
l.item.cmp(&r.item).reverse()
}
}
// Two inputs with disjoint values; the second input's single batch falls
// between the two batches of the first input.
#[rstest]
fn test1() {
let iter_a = vec![vec![1, 2, 3].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
let iter_b = vec![vec![4, 5, 6].into_iter()].into_iter();
let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
kmerge.push_iter(iter_a);
kmerge.push_iter(iter_b);
let values: Vec<i32> = kmerge.collect();
assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
}
// Interleaved inputs where the value 6 occurs in both; both copies must be
// present in the output.
#[rstest]
fn test2() {
let iter_a = vec![vec![1, 2, 6].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
let iter_b = vec![vec![3, 4, 5, 6].into_iter()].into_iter();
let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
kmerge.push_iter(iter_a);
kmerge.push_iter(iter_b);
let values: Vec<i32> = kmerge.collect();
assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
}
// Three inputs with differing batch counts and duplicates both across inputs
// (4) and within a single batch (12).
#[rstest]
fn test3() {
let iter_a = vec![vec![1, 4, 7].into_iter(), vec![24, 35, 56].into_iter()].into_iter();
let iter_b = vec![vec![2, 4, 8].into_iter()].into_iter();
let iter_c = vec![vec![3, 5, 9].into_iter(), vec![12, 12, 90].into_iter()].into_iter();
let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
kmerge.push_iter(iter_a);
kmerge.push_iter(iter_b);
kmerge.push_iter(iter_c);
let values: Vec<i32> = kmerge.collect();
assert_eq!(
values,
vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
);
}
// An input containing an empty batch between two non-empty ones; the empty
// batch must be skipped without ending that input early.
#[rstest]
fn test5() {
let iter_a = vec![
vec![1, 3, 5].into_iter(),
vec![].into_iter(),
vec![7, 9, 11].into_iter(),
]
.into_iter();
let iter_b = vec![vec![2, 4, 6].into_iter()].into_iter();
let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
kmerge.push_iter(iter_a);
kmerge.push_iter(iter_b);
let values: Vec<i32> = kmerge.collect();
assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
}
/// A sorted sequence of `u64` values cut into consecutive batches, as
/// produced by `sorted_nested_vec_strategy`.
#[derive(Debug, Clone)]
struct SortedNestedVec(Vec<Vec<u64>>);
/// Strategy producing a `SortedNestedVec`: up to 100 arbitrary values are
/// sorted and then split at up to 10 randomly chosen cut points.
fn sorted_nested_vec_strategy() -> impl Strategy<Value = SortedNestedVec> {
prop::collection::vec(any::<u64>(), 0..=100).prop_flat_map(|mut flat_vec| {
flat_vec.sort_unstable();
let total_len = flat_vec.len();
// Nothing to split: represent the input as a single empty batch.
if total_len == 0 {
return Just(SortedNestedVec(vec![vec![]])).boxed();
}
prop::collection::vec(0..=total_len, 0..=10)
.prop_map(move |mut boundaries| {
// Always include both ends so every value lands in some batch.
boundaries.push(0);
boundaries.push(total_len);
boundaries.sort_unstable();
// After dedup the cut points are strictly increasing, so each
// window below describes a non-empty slice.
boundaries.dedup();
let mut nested_vec = Vec::new();
for [start, end] in boundaries.array_windows() {
nested_vec.push(flat_vec[*start..*end].to_vec());
}
SortedNestedVec(nested_vec)
})
.boxed()
})
}
proptest! {
// Merging 0 to 10 generated inputs must give exactly the same sequence as
// flattening all of them and sorting the result.
#[rstest]
fn prop_kmerge_equivalent_to_sort(
all_data in prop::collection::vec(sorted_nested_vec_strategy(), 0..=10)
) {
let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
let copy_data = all_data.clone();
for stream in copy_data {
let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
kmerge.push_iter(input);
}
let merged_data: Vec<u64> = kmerge.collect();
// Reference result: flatten every batch of every input, then sort.
let mut sorted_data: Vec<u64> = all_data
.into_iter()
.flat_map(|stream| stream.0.into_iter().flatten())
.collect();
sorted_data.sort_unstable();
prop_assert_eq!(merged_data.len(), sorted_data.len(), "Lengths should be equal");
prop_assert_eq!(merged_data, sorted_data, "Merged data should equal sorted data");
}
// The merged output of 1 to 5 generated inputs must be non-decreasing.
#[rstest]
fn prop_kmerge_preserves_sort_order(
all_data in prop::collection::vec(sorted_nested_vec_strategy(), 1..=5)
) {
let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);
for stream in all_data {
let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
kmerge.push_iter(input);
}
let merged_data: Vec<u64> = kmerge.collect();
for [a, b] in merged_data.array_windows() {
prop_assert!(a <= b, "Merged data should be sorted");
}
}
// Pushing up to 5 additional inputs that contain no batches at all must not
// change the merged output.
#[rstest]
fn prop_kmerge_handles_empty_iterators(
data in sorted_nested_vec_strategy(),
empty_count in 0usize..=5
) {
let mut kmerge_with_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
let mut kmerge_without_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
let input_with_empty = data.0.clone().into_iter().map(std::iter::IntoIterator::into_iter);
let input_without_empty = data.0.into_iter().map(std::iter::IntoIterator::into_iter);
kmerge_with_empty.push_iter(input_with_empty);
kmerge_without_empty.push_iter(input_without_empty);
// Only one of the two merges receives the batch-less inputs.
for _ in 0..empty_count {
let empty_vec: Vec<Vec<u64>> = vec![];
let empty_input = empty_vec.into_iter().map(std::iter::IntoIterator::into_iter);
kmerge_with_empty.push_iter(empty_input);
}
let result_with_empty: Vec<u64> = kmerge_with_empty.collect();
let result_without_empty: Vec<u64> = kmerge_without_empty.collect();
prop_assert_eq!(result_with_empty, result_without_empty, "Empty iterators should not affect result");
}
}
}