ext_sort/
merger.rs

1//! Binary heap merger.
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5use std::error::Error;
6
/// Pairs a value with the comparator that defines how it is ordered.
///
/// The comparison traits implemented for this type delegate to `compare`, so wrapped
/// values can live in ordered collections using a caller-supplied ordering.
struct OrderedWrapper<T, F: Fn(&T, &T) -> Ordering> {
    /// The wrapped value.
    value: T,
    /// Comparator applied to `value` and the value of another wrapper.
    compare: F,
}
15
16impl<T, F> OrderedWrapper<T, F>
17where
18    F: Fn(&T, &T) -> Ordering,
19{
20    fn wrap(value: T, compare: F) -> Self {
21        OrderedWrapper { value, compare }
22    }
23
24    fn unwrap(self) -> T {
25        self.value
26    }
27}
28
29impl<T, F> PartialEq for OrderedWrapper<T, F>
30where
31    F: Fn(&T, &T) -> Ordering,
32{
33    fn eq(&self, other: &Self) -> bool {
34        self.cmp(other) == Ordering::Equal
35    }
36}
37
38impl<T, F> Eq for OrderedWrapper<T, F> where F: Fn(&T, &T) -> Ordering {}
39
40impl<T, F> PartialOrd for OrderedWrapper<T, F>
41where
42    F: Fn(&T, &T) -> Ordering,
43{
44    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
45        Some(self.cmp(other))
46    }
47}
48impl<T, F> Ord for OrderedWrapper<T, F>
49where
50    F: Fn(&T, &T) -> Ordering,
51{
52    fn cmp(&self, other: &Self) -> Ordering {
53        (self.compare)(&self.value, &other.value)
54    }
55}
56
57/// Binary heap merger implementation.
58/// Merges multiple sorted inputs into a single sorted output.
59/// Time complexity is *m* \* log(*n*) in worst case where *m* is the number of items,
60/// *n* is the number of chunks (inputs).
61pub struct BinaryHeapMerger<T, E, F, C>
62where
63    E: Error,
64    F: Fn(&T, &T) -> Ordering,
65    C: IntoIterator<Item = Result<T, E>>,
66{
67    // binary heap is max-heap by default so we reverse it to convert it to min-heap
68    items: BinaryHeap<(std::cmp::Reverse<OrderedWrapper<T, F>>, usize)>,
69    chunks: Vec<C::IntoIter>,
70    initiated: bool,
71    compare: F,
72}
73
74impl<T, E, F, C> BinaryHeapMerger<T, E, F, C>
75where
76    E: Error,
77    F: Fn(&T, &T) -> Ordering,
78    C: IntoIterator<Item = Result<T, E>>,
79{
80    /// Creates an instance of a binary heap merger using chunks as inputs.
81    /// Chunk items should be sorted in ascending order otherwise the result is undefined.
82    ///
83    /// # Arguments
84    /// * `chunks` - Chunks to be merged in a single sorted one
85    pub fn new<I>(chunks: I, compare: F) -> Self
86    where
87        I: IntoIterator<Item = C>,
88    {
89        let chunks = Vec::from_iter(chunks.into_iter().map(|c| c.into_iter()));
90        let items = BinaryHeap::with_capacity(chunks.len());
91
92        return BinaryHeapMerger {
93            chunks,
94            items,
95            compare,
96            initiated: false,
97        };
98    }
99}
100
101impl<T, E, F, C> Iterator for BinaryHeapMerger<T, E, F, C>
102where
103    E: Error,
104    F: Fn(&T, &T) -> Ordering + Copy,
105    C: IntoIterator<Item = Result<T, E>>,
106{
107    type Item = Result<T, E>;
108
109    /// Returns the next item from the inputs in ascending order.
110    fn next(&mut self) -> Option<Self::Item> {
111        if !self.initiated {
112            for (idx, chunk) in self.chunks.iter_mut().enumerate() {
113                if let Some(item) = chunk.next() {
114                    match item {
115                        Ok(item) => self
116                            .items
117                            .push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)),
118                        Err(err) => return Some(Err(err)),
119                    }
120                }
121            }
122            self.initiated = true;
123        }
124
125        let (result, idx) = self.items.pop()?;
126        if let Some(item) = self.chunks[idx].next() {
127            match item {
128                Ok(item) => self
129                    .items
130                    .push((std::cmp::Reverse(OrderedWrapper::wrap(item, self.compare)), idx)),
131                Err(err) => return Some(Err(err)),
132            }
133        }
134
135        return Some(Ok(result.0.unwrap()));
136    }
137}
138
#[cfg(test)]
mod test {
    use rstest::*;
    use std::error::Error;
    use std::io::{self, ErrorKind};

    use super::BinaryHeapMerger;

    /// Merges `chunks` with `BinaryHeapMerger` and checks the produced sequence against
    /// `expected_result`, item by item and in length.
    #[rstest]
    // no chunks at all
    #[case(
        vec![],
        vec![],
    )]
    // only empty chunks
    #[case(
        vec![
            vec![],
            vec![]
        ],
        vec![],
    )]
    // chunks of different lengths, including an empty one
    #[case(
        vec![
            vec![Ok(4), Ok(5), Ok(7)],
            vec![Ok(1), Ok(6)],
            vec![Ok(3)],
            vec![],
        ],
        vec![Ok(1), Ok(3), Ok(4), Ok(5), Ok(6), Ok(7)],
    )]
    // error while loading the first item of a chunk
    #[case(
        vec![
            vec![Result::Err(io::Error::new(ErrorKind::Other, "test error"))]
        ],
        vec![
            Result::Err(io::Error::new(ErrorKind::Other, "test error"))
        ],
    )]
    // error while refilling the heap: the item popped at that moment (3) is not yielded
    #[case(
        vec![
            vec![Ok(3), Result::Err(io::Error::new(ErrorKind::Other, "test error"))],
            vec![Ok(1), Ok(2)],
        ],
        vec![
            Ok(1),
            Ok(2),
            Result::Err(io::Error::new(ErrorKind::Other, "test error")),
        ],
    )]
    fn test_merger(
        #[case] chunks: Vec<Vec<Result<i32, io::Error>>>,
        #[case] expected_result: Vec<Result<i32, io::Error>>,
    ) {
        let merger = BinaryHeapMerger::new(chunks, i32::cmp);
        let actual_result: Vec<Result<i32, io::Error>> = merger.collect();
        assert!(
            compare_vectors_of_result::<_, io::Error>(&actual_result, &expected_result),
            "actual={:?}, expected={:?}",
            actual_result,
            expected_result
        );
    }

    /// Returns `true` when both sequences have the same length and match pairwise.
    ///
    /// `Ok` values are compared with `==`; `Err` values are compared by their rendered
    /// message, since the error type is not required to implement `PartialEq`.
    fn compare_vectors_of_result<T: PartialEq, E: Error + 'static>(
        actual: &[Result<T, E>],
        expected: &[Result<T, E>],
    ) -> bool {
        // `zip` stops at the shorter input, so missing or surplus items would go unnoticed
        // without an explicit length check.
        actual.len() == expected.len()
            && actual.iter().zip(expected).all(|pair| match pair {
                (Ok(actual_value), Ok(expected_value)) => actual_value == expected_value,
                (Err(actual_err), Err(expected_err)) => {
                    actual_err.to_string() == expected_err.to_string()
                }
                _ => false,
            })
    }
}