//! mini_lsm/iterators/merge_iterator.rs

use std::cmp::{self};
use std::collections::binary_heap::PeekMut;
use std::collections::BinaryHeap;

use anyhow::Result;

use crate::key::KeySlice;

use super::StorageIterator;
10
/// Pairs an iterator with its position in the original `iters` vector so the
/// heap ordering can break key ties deterministically (see `PartialOrd`
/// below: equal keys fall back to comparing the index).
struct HeapWrapper<I: StorageIterator>(pub usize, pub Box<I>);
12
13impl<I: StorageIterator> PartialEq for HeapWrapper<I> {
14 fn eq(&self, other: &Self) -> bool {
15 self.partial_cmp(other).unwrap() == cmp::Ordering::Equal
16 }
17}
18
// `partial_cmp` below never returns `None`, so equality is total and
// reflexive; marking the type `Eq` is sound.
impl<I: StorageIterator> Eq for HeapWrapper<I> {}
20
21impl<I: StorageIterator> PartialOrd for HeapWrapper<I> {
22 #[allow(clippy::non_canonical_partial_ord_impl)]
23 fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
24 match self.1.key().cmp(&other.1.key()) {
25 cmp::Ordering::Greater => Some(cmp::Ordering::Greater),
26 cmp::Ordering::Less => Some(cmp::Ordering::Less),
27 cmp::Ordering::Equal => self.0.partial_cmp(&other.0),
28 }
29 .map(|x| x.reverse())
30 }
31}
32
impl<I: StorageIterator> Ord for HeapWrapper<I> {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        // `partial_cmp` always returns `Some` (keys have a total order and
        // `usize` comparison is total), so this unwrap cannot panic.
        self.partial_cmp(other).unwrap()
    }
}
38
/// Merges multiple sorted iterators into a single sorted stream. When
/// several inner iterators sit on the same key, only the entry from the
/// iterator with the smallest index is emitted (see `next`, which advances
/// the duplicates past the key).
pub struct MergeIterator<I: StorageIterator> {
    // Remaining valid iterators, ordered by reversed (key, index) so the
    // max-heap pops the smallest pair first.
    iters: BinaryHeap<HeapWrapper<I>>,
    // The iterator currently positioned at the smallest (key, index).
    // `None` only when `create` was given an empty vector.
    current: Option<HeapWrapper<I>>,
}
45
46impl<I: StorageIterator> MergeIterator<I> {
47 pub fn create(iters: Vec<Box<I>>) -> Self {
48 if iters.is_empty() {
49 return Self {
50 iters: BinaryHeap::new(),
51 current: None,
52 };
53 }
54
55 let mut heap = BinaryHeap::new();
56
57 if iters.iter().all(|x| !x.is_valid()) {
58 let mut iters = iters;
60 return Self {
61 iters: heap,
62 current: Some(HeapWrapper(0, iters.pop().unwrap())),
63 };
64 }
65
66 for (idx, iter) in iters.into_iter().enumerate() {
67 if iter.is_valid() {
68 heap.push(HeapWrapper(idx, iter));
69 }
70 }
71
72 let current = heap.pop().unwrap();
73 Self {
74 iters: heap,
75 current: Some(current),
76 }
77 }
78}
79
impl<I: 'static + for<'a> StorageIterator<KeyType<'a> = KeySlice<'a>>> StorageIterator
    for MergeIterator<I>
{
    type KeyType<'a> = KeySlice<'a>;

    /// Key of the front iterator. Panics when `current` is `None` (the
    /// merge iterator was created from an empty vector); callers are
    /// expected to check `is_valid` first.
    fn key(&self) -> KeySlice {
        self.current.as_ref().unwrap().1.key()
    }

    /// Value of the front iterator; same panic caveat as `key`.
    fn value(&self) -> &[u8] {
        self.current.as_ref().unwrap().1.value()
    }

    /// Valid iff there is a front iterator and it is itself valid.
    fn is_valid(&self) -> bool {
        self.current
            .as_ref()
            .map(|x| x.1.is_valid())
            .unwrap_or(false)
    }

    fn next(&mut self) -> Result<()> {
        let current = self.current.as_mut().unwrap();
        // Step 1: advance every heap iterator sitting on the same key as
        // `current`, so a duplicated key is emitted only once — from
        // `current`, which holds the smallest (key, index) pair.
        while let Some(mut inner_iter) = self.iters.peek_mut() {
            debug_assert!(
                inner_iter.1.key() >= current.1.key(),
                "heap invariant violated"
            );
            if inner_iter.1.key() == current.1.key() {
                // If advancing fails, pop the broken iterator before
                // propagating the error so the heap never orders by a
                // stale/invalid key afterwards.
                if let e @ Err(_) = inner_iter.1.next() {
                    PeekMut::pop(inner_iter);
                    return e;
                }

                // Exhausted after the advance: drop it from the heap.
                if !inner_iter.1.is_valid() {
                    PeekMut::pop(inner_iter);
                }
            } else {
                break;
            }
        }

        // Step 2: advance the front iterator itself.
        current.1.next()?;

        // Step 3a: if the front iterator ran out, promote the heap top
        // (the next-smallest iterator), if any, to be the new front.
        if !current.1.is_valid() {
            if let Some(iter) = self.iters.pop() {
                *current = iter;
            }
            return Ok(());
        }

        // Step 3b: otherwise restore the invariant that `current` holds the
        // smallest (key, index). `HeapWrapper`'s ordering is reversed, so
        // `*current < *inner_iter` means the heap top now precedes
        // `current` — swap them.
        if let Some(mut inner_iter) = self.iters.peek_mut() {
            if *current < *inner_iter {
                std::mem::swap(&mut *inner_iter, current);
            }
        }

        Ok(())
    }

    /// Total count of underlying active iterators: the heap's contents plus
    /// `current`, each reporting its own recursive count.
    fn num_active_iterators(&self) -> usize {
        self.iters
            .iter()
            .map(|x| x.1.num_active_iterators())
            .sum::<usize>()
            + self
                .current
                .as_ref()
                .map(|x| x.1.num_active_iterators())
                .unwrap_or(0)
    }
}
155}