1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5use std::error::Error;
6
/// A value paired with the comparison function that defines its ordering.
///
/// [`BinaryHeap`] requires its elements to implement [`Ord`]. This wrapper
/// provides `PartialEq`, `Eq`, `PartialOrd` and `Ord` by delegating to the
/// stored comparator, so values can be ordered by an arbitrary function instead
/// of by their own `Ord` implementation.
struct OrderedWrapper<T, F> {
    value: T,
    compare: F,
}

impl<T, F> OrderedWrapper<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Wraps `value` together with the comparator used to order it.
    fn wrap(value: T, compare: F) -> Self {
        OrderedWrapper { value, compare }
    }

    /// Consumes the wrapper and returns the inner value, dropping the comparator.
    fn unwrap(self) -> T {
        self.value
    }
}

impl<T, F> PartialEq for OrderedWrapper<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Two wrappers are equal when the comparator reports `Ordering::Equal`,
    /// which keeps `==` consistent with `cmp`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl<T, F> Eq for OrderedWrapper<T, F> where F: Fn(&T, &T) -> Ordering {}

impl<T, F> PartialOrd for OrderedWrapper<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Always `Some`: the comparator defines a total order.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T, F> Ord for OrderedWrapper<T, F>
where
    F: Fn(&T, &T) -> Ordering,
{
    /// Orders by calling the comparator stored in `self` (the left-hand operand).
    /// The comparator stored in `other` is not consulted.
    fn cmp(&self, other: &Self) -> Ordering {
        (self.compare)(&self.value, &other.value)
    }
}
56
57pub struct BinaryHeapMerger<T, E, F, C>
62where
63 E: Error,
64 F: Fn(&T, &T) -> Ordering,
65 C: IntoIterator<Item = Result<T, E>>,
66{
67 items: BinaryHeap<(std::cmp::Reverse<OrderedWrapper<T, F>>, usize)>,
69 chunks: Vec<C::IntoIter>,
70 initiated: bool,
71 compare: F,
72}
73
74impl<T, E, F, C> BinaryHeapMerger<T, E, F, C>
75where
76 E: Error,
77 F: Fn(&T, &T) -> Ordering,
78 C: IntoIterator<Item = Result<T, E>>,
79{
80 pub fn new<I>(chunks: I, compare: F) -> Self
86 where
87 I: IntoIterator<Item = C>,
88 {
89 let chunks = Vec::from_iter(chunks.into_iter().map(|c| c.into_iter()));
90 let items = BinaryHeap::with_capacity(chunks.len());
91
92 return BinaryHeapMerger {
93 chunks,
94 items,
95 compare,
96 initiated: false,
97 };
98 }
99}
100
101impl<T, E, F, C> Iterator for BinaryHeapMerger<T, E, F, C>
102where
103 E: Error,
104 F: Fn(&T, &T) -> Ordering + Copy,
105 C: IntoIterator<Item = Result<T, E>>,
106{
107 type Item = Result<T, E>;
108
109 fn next(&mut self) -> Option<Self::Item> {
111 if !self.initiated {
112 for (idx, chunk) in self.chunks.iter_mut().enumerate() {
113 if let Some(item) = chunk.next() {
114 match item {
115 Ok(item) => self
116 .items
117 .push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)),
118 Err(err) => return Some(Err(err)),
119 }
120 }
121 }
122 self.initiated = true;
123 }
124
125 let (result, idx) = self.items.pop()?;
126 if let Some(item) = self.chunks[idx].next() {
127 match item {
128 Ok(item) => self
129 .items
130 .push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)),
131 Err(err) => return Some(Err(err)),
132 }
133 }
134
135 return Some(Ok(result.0.unwrap()));
136 }
137}
138
#[cfg(test)]
mod test {
    use rstest::*;
    use std::error::Error;
    use std::io::{self, ErrorKind};

    use super::BinaryHeapMerger;

    // Merges each case's `chunks` using natural `i32` ordering and checks the
    // merged stream against `expected_result` with `compare_vectors_of_result`.
    #[rstest]
    // No chunks at all.
    #[case(vec![], vec![])]
    // Only empty chunks.
    #[case(vec![vec![], vec![]], vec![])]
    // Several sorted chunks of different lengths, one of them empty.
    #[case(
        vec![
            vec![Ok(4), Ok(5), Ok(7)],
            vec![Ok(1), Ok(6)],
            vec![Ok(3)],
            vec![],
        ],
        vec![Ok(1), Ok(3), Ok(4), Ok(5), Ok(6), Ok(7)],
    )]
    // A chunk whose very first item is an error.
    #[case(
        vec![vec![Err(io::Error::new(ErrorKind::Other, "test error"))]],
        vec![Err(io::Error::new(ErrorKind::Other, "test error"))],
    )]
    // An error that follows a successful item in one of the chunks.
    #[case(
        vec![
            vec![Ok(3), Err(io::Error::new(ErrorKind::Other, "test error"))],
            vec![Ok(1), Ok(2)],
        ],
        vec![
            Ok(1),
            Ok(2),
            Err(io::Error::new(ErrorKind::Other, "test error")),
        ],
    )]
    fn test_merger(
        #[case] chunks: Vec<Vec<Result<i32, io::Error>>>,
        #[case] expected_result: Vec<Result<i32, io::Error>>,
    ) {
        let actual_result: Vec<Result<i32, io::Error>> =
            BinaryHeapMerger::new(chunks, i32::cmp).collect();
        let matches = compare_vectors_of_result(&actual_result, &expected_result);
        assert!(
            matches,
            "actual={:?}, expected={:?}",
            actual_result,
            expected_result
        );
    }

    /// Compares two vectors of results position by position. `Ok` values must
    /// be equal, and `Err` values must render the same `Display` text. Any
    /// `Ok`/`Err` mismatch fails.
    ///
    /// NOTE(review): `zip` stops at the shorter vector, so a length mismatch
    /// (extra or missing trailing items) is not detected. Consider also
    /// asserting equal lengths.
    fn compare_vectors_of_result<T: PartialEq, E: Error + 'static>(
        actual: &Vec<Result<T, E>>,
        expected: &Vec<Result<T, E>>,
    ) -> bool {
        for pair in actual.iter().zip(expected.iter()) {
            let same = match pair {
                (Ok(lhs), Ok(rhs)) => lhs == rhs,
                (Err(lhs), Err(rhs)) => lhs.to_string() == rhs.to_string(),
                _ => false,
            };
            if !same {
                return false;
            }
        }
        true
    }
}