use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use binary_heap_plus::BinaryHeap;
use binary_heap_plus::PeekMut;
use compare::Compare;
use futures::ready;
use futures::Stream;
use crate::comparators::Ascending;
use crate::comparators::Descending;
use crate::stream_more::comparators::FnCmp;
use crate::stream_more::kmerge::heap_entry::HeapEntry;
use crate::stream_more::kmerge::heap_entry::HeapEntryCmp;
use crate::stream_more::peeked::Peeked;
/// K-way merge over streams yielding `D`, built on a binary heap.
///
/// Every stream added through `merge` is wrapped in a `HeapEntry` and pushed onto `heap`;
/// the heap orders its entries with `HeapEntryCmp`, which is constructed from the
/// user-supplied comparator `C`.
///
/// NOTE(review): as with any k-way merge, the output is presumably only globally ordered
/// when every input stream is itself already ordered according to `C` — confirm with callers.
pub struct KMerge<'a, C, D>
where C: Compare<D>
{
/// Id given to the most recently merged stream. It starts at 0 and is incremented
/// before each new entry is tagged, so the first merged stream receives id 1.
curr_id: u64,
/// One entry per input stream that has not yet been found exhausted; an entry is
/// popped once its stream yields `None`.
heap: BinaryHeap<HeapEntry<'a, D>, HeapEntryCmp<D, C>>,
}
impl<'a, F, D> KMerge<'a, FnCmp<F>, D>
where
F: Fn(&D, &D) -> bool,
FnCmp<F>: Compare<D>,
{
pub fn by(first: F) -> Self {
Self::by_cmp(FnCmp(first))
}
}
impl<'a, D, C> KMerge<'a, C, D>
where
    C: Compare<D>,
{
    /// Creates an empty `KMerge` that orders items with the comparator `cmp`.
    ///
    /// No streams are attached yet; add them with `merge`.
    pub fn by_cmp(cmp: C) -> Self {
        // The heap stores whole entries, so the item comparator is wrapped in a
        // `HeapEntryCmp` before being given to the heap.
        let entry_cmp = HeapEntryCmp::new(cmp);
        let heap = BinaryHeap::<HeapEntry<D>, _>::from_vec_cmp(Vec::new(), entry_cmp);

        Self { curr_id: 0, heap }
    }

    /// Adds `stream` as another input and returns the merger, builder style.
    ///
    /// The stream is boxed and pinned, wrapped in a `HeapEntry` tagged with the next
    /// id (ids start at 1 and grow by one per call), and pushed onto the heap.
    pub fn merge(mut self, stream: impl Stream<Item = D> + Send + 'a) -> Self {
        self.curr_id += 1;

        let entry = HeapEntry::new(Box::pin(stream)).with_id(self.curr_id);
        self.heap.push(entry);

        self
    }
}
impl<'a, D> KMerge<'a, Descending, D>
where
    Descending: Compare<D>,
{
    /// Creates an empty `KMerge` ordered by the `Descending` comparator.
    ///
    /// NOTE(review): the name suggests the largest available item is yielded first —
    /// confirm against `Descending`.
    pub fn max() -> Self {
        Self::by_cmp(Descending)
    }
}
impl<'a, D> KMerge<'a, Ascending, D>
where
    Ascending: Compare<D>,
{
    /// Creates an empty `KMerge` ordered by the `Ascending` comparator.
    ///
    /// NOTE(review): the name suggests the smallest available item is yielded first —
    /// confirm against `Ascending`.
    pub fn min() -> Self {
        Self::by_cmp(Ascending)
    }
}
impl<'a, D, C> Stream for KMerge<'a, C, D>
where
    D: Unpin,
    C: Compare<D> + Unpin,
{
    type Item = D;

    /// Yields the next merged item.
    ///
    /// Each iteration inspects the entry at the top of the heap:
    /// - if it already holds a buffered item, that item is taken and returned;
    /// - otherwise its stream is polled: a produced item is buffered in the entry, and
    ///   an exhausted stream causes the entry to be removed from the heap.
    ///
    /// Returns `Poll::Pending` as soon as the polled stream is pending, and
    /// `Poll::Ready(None)` once the heap holds no entries.
    ///
    /// NOTE(review): this relies on `HeapEntryCmp` placing entries without a buffered
    /// item at the top of the heap, so that every live stream is polled before an
    /// item is emitted — confirm against `HeapEntryCmp`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        while let Some(mut top) = self.heap.peek_mut() {
            if top.peeked.has_peeked() {
                // Dropping `top` after the take lets the heap reposition this entry.
                return Poll::Ready(top.peeked.take());
            }

            // Bail out with `Pending` if the top stream has nothing ready yet.
            let polled = ready!(top.stream.as_mut().poll_next(cx));

            match polled {
                // Buffer the item; the guard is dropped at the end of this iteration.
                Some(item) => top.peeked = Peeked::Yes(item),
                // The stream is finished: discard its entry.
                None => {
                    PeekMut::pop(top);
                }
            }
        }

        // No entries left in the heap.
        Poll::Ready(None)
    }
}