use std::{
cmp::Ordering,
collections::BinaryHeap,
task::{Context, Poll},
};
use bytes::Bytes;
use crate::{
base::{Comparer, InternalKey, StorageResult},
iterator::StorageIterator,
};
/// One source of a [`MergingIterator`], paired with the identifier used to
/// break ties between sources.
///
/// Entries are ordered (see the `Ord` impl) so that a `BinaryHeap` keeps the
/// entry with the smallest current key on top; among entries with equal keys,
/// the one with the larger `source_id` is on top.
pub(crate) struct MergingIteratorHeapEntry<'i, C>
where
    C: Comparer,
{
    /// The wrapped source iterator, type-erased so that different iterator
    /// types can share one heap.
    pub iter: Box<dyn StorageIterator<'i, C> + 'i>,
    /// Tie-breaker for sources whose current keys compare equal.
    pub source_id: u64,
}
impl<'i, C> MergingIteratorHeapEntry<'i, C>
where
    C: Comparer,
{
    /// Boxes `iter` and tags it with `source_id`.
    ///
    /// The iterator is stored as-is; it is not polled here.
    pub fn new<I>(iter: I, source_id: u64) -> Self
    where
        I: StorageIterator<'i, C> + 'i,
    {
        let iter: Box<dyn StorageIterator<'i, C> + 'i> = Box::new(iter);
        Self { iter, source_id }
    }
}
impl<'i, C: Comparer> PartialEq for MergingIteratorHeapEntry<'i, C> {
fn eq(&self, other: &Self) -> bool {
self.cmp(other).is_eq()
}
}
/// Marker impl: the equality defined by `PartialEq` above is a full
/// equivalence relation because it is derived from `Ord`.
impl<'i, C> Eq for MergingIteratorHeapEntry<'i, C> where C: Comparer {}
/// The partial order simply forwards to the total order from `Ord`, so it
/// never returns `None`.
impl<'i, C> PartialOrd for MergingIteratorHeapEntry<'i, C>
where
    C: Comparer,
{
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}
/// Heap ordering for merge sources.
///
/// `BinaryHeap` is a max-heap, so "greater" here means "comes out first":
///
/// * an entry whose iterator has a current key is greater than one without;
/// * between two entries with keys, the one with the *smaller* key is greater
///   (the key comparison is reversed);
/// * on equal keys, the entry with the larger `source_id` is greater.
impl<'i, C> Ord for MergingIteratorHeapEntry<'i, C>
where
    C: Comparer,
{
    fn cmp(&self, rhs: &Self) -> Ordering {
        let lhs_key = self.iter.key();
        let rhs_key = rhs.iter.key();
        match (lhs_key, rhs_key) {
            (None, None) => Ordering::Equal,
            (None, Some(_)) => Ordering::Less,
            (Some(_), None) => Ordering::Greater,
            (Some(lhs), Some(rhs_k)) => {
                // Reverse the key order so the smallest key surfaces first.
                match lhs.cmp(rhs_k).reverse() {
                    Ordering::Equal => self.source_id.cmp(&rhs.source_id),
                    by_key => by_key,
                }
            }
        }
    }
}
/// Tracks whether a [`MergingIterator`] still holds sources that have not
/// been placed on its heap.
enum MergingIteratorState<'i, C>
where
    C: Comparer,
{
    /// Some sources are not on the heap yet; they are kept here until they
    /// have been positioned.
    Initializing {
        sources: Vec<MergingIteratorHeapEntry<'i, C>>,
    },
    /// Every remaining source lives on the heap.
    Active,
}
/// Merges several [`StorageIterator`] sources into a single stream ordered
/// by key, using a binary heap of [`MergingIteratorHeapEntry`] values.
pub(crate) struct MergingIterator<'i, C>
where
    C: Comparer,
{
    /// Whether there are still sources waiting to be put on the heap.
    state: MergingIteratorState<'i, C>,
    /// Sources that currently expose a key; the top entry supplies `current`.
    heap: BinaryHeap<MergingIteratorHeapEntry<'i, C>>,
    /// Copy of the key/value pair most recently yielded by `poll_next`.
    current: Option<(InternalKey<C>, Bytes)>,
}
impl<'i, C: Comparer> MergingIterator<'i, C> {
pub(crate) fn new(sources: Vec<MergingIteratorHeapEntry<'i, C>>) -> Self {
Self {
state: MergingIteratorState::Initializing { sources },
heap: Default::default(),
current: None,
}
}
}
impl<'i, C: Comparer> StorageIterator<'i, C> for MergingIterator<'i, C> {
    /// Advances the merged stream by one entry.
    ///
    /// Returns `Ready(Ok(Some(())))` when a new entry is available through
    /// [`key`](Self::key) / [`value`](Self::value), `Ready(Ok(None))` when all
    /// sources are exhausted, `Pending` when a source is not ready yet, and
    /// `Ready(Err(_))` when a source reports an error.
    ///
    /// # Panics
    ///
    /// Panics if `current` is set while the heap is empty, or if the source on
    /// top of the heap has no key/value; both indicate a broken internal
    /// invariant.
    fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<StorageResult<Option<()>>> {
        // Phase 1: position every source that is not on the heap yet.
        if let MergingIteratorState::Initializing { ref mut sources } = self.state {
            trace!(sources = sources.len(), "initializing merging iterator");
            let mut i = 0;
            while i < sources.len() {
                match sources[i].iter.poll_next(cx) {
                    Poll::Ready(Ok(Some(()))) => {
                        trace!("source ready");
                        // `swap_remove` moves another element into slot `i`,
                        // so `i` is intentionally not advanced.
                        let entry = sources.swap_remove(i);
                        self.heap.push(entry);
                    }
                    Poll::Ready(Ok(None)) => {
                        trace!("source empty");
                        sources.swap_remove(i);
                    }
                    Poll::Pending => {
                        trace!("source still pending");
                        i += 1;
                    }
                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                }
            }
            if sources.is_empty() {
                trace!("finished initializing merging iterator");
                self.state = MergingIteratorState::Active;
                if self.heap.is_empty() {
                    return Poll::Ready(Ok(None));
                }
            } else {
                trace!("initializing finished, but still incomplete");
                return Poll::Pending;
            }
        }

        // Phase 2: if an entry was already yielded, step the source it came
        // from (the heap top) before picking the next entry.
        if self.current.is_some() {
            trace!("polling sources");
            let mut top = self
                .heap
                .pop()
                .expect("heap cannot be empty if current is not");
            match top.iter.poll_next(cx) {
                Poll::Ready(Ok(Some(()))) => self.heap.push(top),
                // Source exhausted: drop it.
                Poll::Ready(Ok(None)) => {}
                Poll::Pending => {
                    self.heap.push(top);
                    return Poll::Pending;
                }
                Poll::Ready(Err(err)) => {
                    // Do not silently lose the source on error: with `current`
                    // still set, a later poll would otherwise either panic on
                    // an empty heap or advance an unrelated source and skip
                    // its entry.
                    if top.iter.key().is_some() {
                        self.heap.push(top);
                    } else {
                        // The source no longer exposes a key, so it cannot go
                        // back on the heap; forget the yielded entry so the
                        // next poll does not advance a different source.
                        self.current = None;
                    }
                    return Poll::Ready(Err(err));
                }
            }
        }

        // Phase 3: publish a copy of whatever is now on top of the heap.
        if let Some(top) = self.heap.peek() {
            self.current = Some((
                top.iter
                    .key()
                    .expect("iterators on heap must not be exhausted")
                    .clone(),
                top.iter
                    .value()
                    .expect("iterators on heap must not be exhausted")
                    .clone(),
            ));
            trace!(current = ?self.current, "got current value");
            Poll::Ready(Ok(Some(())))
        } else {
            self.current = None;
            Poll::Ready(Ok(None))
        }
    }

    /// Seeks every source to `key` and rebuilds the heap from the sources that
    /// still expose a key afterwards.
    ///
    /// On success the current entry is cleared; the next `poll_next` publishes
    /// the heap top without advancing any source.
    ///
    /// If a source returns `Pending` or an error, all sources are parked in
    /// the `Initializing` state instead of being dropped, so none is lost.
    /// The caller is expected to poll `poll_seek` again until it is ready;
    /// doing so seeks every parked source again, including those that had
    /// already completed.
    fn poll_seek(&mut self, key: &[u8], cx: &mut Context<'_>) -> Poll<StorageResult<()>> {
        // Whatever was yielded before the seek is no longer meaningful.
        self.current = None;

        // Gather every source, whether it was on the heap or still waiting.
        let mut entries: Vec<_> = self.heap.drain().collect();
        if let MergingIteratorState::Initializing { ref mut sources } = self.state {
            entries.append(sources);
        }

        let mut interrupted = None;
        let mut i = 0;
        while i < entries.len() {
            match entries[i].iter.poll_seek(key, cx) {
                Poll::Ready(Ok(())) => {
                    if entries[i].iter.key().is_some() {
                        i += 1;
                    } else {
                        // Nothing left in this source after the seek.
                        entries.swap_remove(i);
                    }
                }
                Poll::Pending => {
                    interrupted = Some(Poll::Pending);
                    break;
                }
                Poll::Ready(Err(e)) => {
                    interrupted = Some(Poll::Ready(Err(e)));
                    break;
                }
            }
        }

        if let Some(outcome) = interrupted {
            // Keep ownership of the sources across the early return. They are
            // parked off-heap because some of them may not expose a key.
            self.state = MergingIteratorState::Initializing { sources: entries };
            return outcome;
        }

        self.state = MergingIteratorState::Active;
        self.heap = entries.into_iter().collect();
        Poll::Ready(Ok(()))
    }

    /// Key of the entry most recently yielded by `poll_next`, if any.
    fn key(&self) -> Option<&InternalKey<C>> {
        self.current.as_ref().map(|(k, _v)| k)
    }

    /// Value of the entry most recently yielded by `poll_next`, if any.
    fn value(&self) -> Option<&Bytes> {
        self.current.as_ref().map(|(_k, v)| v)
    }
}
#[cfg(test)]
mod tests {
    use tempest_core::test_utils::setup_tracing;

    use crate::{base::DefaultComparer, iterator::mock::MockIterator};

    use super::*;

    /// Keys from two sources with disjoint keys come out in ascending order.
    #[tokio::test]
    async fn test_merging_interleave() {
        setup_tracing();
        let sources = vec![
            MergingIteratorHeapEntry::new(MockIterator::new().add(1, "a").add(3, "c"), 1),
            MergingIteratorHeapEntry::new(MockIterator::new().add(2, "b").add(4, "d"), 2),
        ];
        let mut merger = MergingIterator::<DefaultComparer>::new(sources);

        let mut seen = Vec::new();
        while let Ok(Some(())) = merger.next().await {
            seen.push(merger.key().unwrap().test_key_as_u64());
        }

        assert_eq!(seen, vec![1, 2, 3, 4]);
    }

    /// For the same key, the source with the larger id is yielded first.
    #[tokio::test]
    async fn test_merging_source_priority() {
        setup_tracing();
        let sources = vec![
            MergingIteratorHeapEntry::new(MockIterator::new().add(1, "old"), 10),
            MergingIteratorHeapEntry::new(MockIterator::new().add(1, "new"), 100),
        ];
        let mut merger = MergingIterator::<DefaultComparer>::new(sources);

        assert!(matches!(merger.next().await, Ok(Some(()))));
        assert_eq!(merger.value().unwrap(), &Bytes::from("new"));

        assert!(matches!(merger.next().await, Ok(Some(()))));
        assert_eq!(merger.value().unwrap(), &Bytes::from("old"));
    }

    /// A source configured with `pending_once(true)` still delivers its entry.
    #[tokio::test]
    async fn test_merging_pending_propagation() {
        setup_tracing();
        let sources = vec![MergingIteratorHeapEntry::new(
            MockIterator::<DefaultComparer>::new()
                .add(1, "val")
                .pending_once(true),
            1,
        )];
        let mut merger = MergingIterator::<DefaultComparer>::new(sources);

        let first = merger.next().await;
        assert!(matches!(first, Ok(Some(()))));
        assert_eq!(merger.key().unwrap().test_key_as_u64(), 1);
    }
}