extsort_iter/merge/
mod.rs

1mod array_node;
2mod treebuilder;
3
4use std::{cmp::Ordering, marker::PhantomData};
5
6use crate::{orderer::Orderer, run::Run};
7
8use self::{
9    array_node::{TreeNode, Winner},
10    treebuilder::LoserTreeBuilder,
11};
12
/// Merges the items of several tapes by repeatedly yielding the head of the
/// currently winning tape.
///
/// invariants:
/// the winner must always point to the tape whose
/// head is the smallest element.
pub struct LoserTree<T, R, O> {
    /// one tape index per inner tree node; `replay_matches` swaps the entry
    /// with the current winner whenever the stored tape wins the comparison.
    loser_indices: Vec<u32>,
    /// the tapes being merged; exhausted tapes are dropped by `rebuild_tree`.
    tapes: Vec<R>,
    /// decides the order between the heads of two tapes.
    orderer: O,
    /// ties the item type `T` to the tree without storing a `T`.
    phantom: PhantomData<T>,
    /// the tape `next` reads its next item from.
    winner: Winner,
    /// number of tapes that still hold items; reset by `rebuild_tree`
    /// and decremented by `remove_winner`.
    remaining_tapes: usize,
}
24
/// Returns the largest power of 2 less or equal to the provided number.
///
/// No power of two is less or equal to zero, so `0` is returned for an input
/// of `0`. Without that guard the bit-position computation below would
/// underflow (a panic in debug builds, a meaningless shift in release builds).
fn previous_power_of_two(number: usize) -> usize {
    if number == 0 {
        return 0;
    }
    // position of the highest set bit, counted from the least significant bit.
    // `number != 0` guarantees `leading_zeros() <= usize::BITS - 1`.
    let highest_set_bit = usize::BITS - 1 - number.leading_zeros();
    1usize << highest_set_bit
}
32
33fn get_candidate<T>(runs: &[impl Run<T>], candidate: Winner) -> Option<&T> {
34    runs[candidate.idx as usize].peek()
35}
36
37fn compare_winners<T>(
38    runs: &[impl Run<T>],
39    orderer: &impl Orderer<T>,
40    left: Winner,
41    right: Winner,
42) -> Ordering {
43    match (get_candidate(runs, left), get_candidate(runs, right)) {
44        (Some(l), Some(r)) => orderer.compare(l, r),
45        (Some(_), None) => Ordering::Less,
46        (None, Some(_)) => Ordering::Greater,
47        (None, None) => Ordering::Equal,
48    }
49}
50
51impl<T, R, O> LoserTree<T, R, O>
52where
53    R: Run<T>,
54    O: Orderer<T>,
55{
56    /// Constructs a new loser tree from the given tapes
57    /// and a provided ordering instruction.
58    pub fn new(tapes: Vec<R>, orderer: O) -> Self {
59        let remaining_tapes = tapes.len();
60        let mut result = Self {
61            loser_indices: Vec::new(),
62            remaining_tapes,
63            tapes,
64            orderer,
65            winner: Winner { idx: u32::MAX },
66            phantom: PhantomData,
67        };
68
69        result.winner = result.rebuild_tree();
70
71        result
72    }
73
74    /// returns the remaining items of all merged runs.
75    pub fn remaining_items(&self) -> usize {
76        self.tapes.iter().map(|t| t.remaining_items()).sum()
77    }
78
79    /// advances the internal state
80    /// Once this method returns None, it will never yield any elements again.
81    pub fn next(&mut self) -> Option<T> {
82        if self.tapes.len() <= 1 {
83            return self.tapes.first_mut()?.next();
84        }
85
86        let winning_tape = &mut self.tapes[self.winner.idx as usize];
87        let winning_value = winning_tape.next()?;
88        let tape_exhausted = winning_tape.peek().is_none();
89
90        self.winner = if tape_exhausted {
91            // while we surely know that the next result must be a None
92            // because the peek call did not return anything,
93            // reading the tape past the end will allow it to release
94            // backing resources already.
95            let none = winning_tape.next();
96            debug_assert!(none.is_none());
97
98            self.remove_winner(self.winner)
99        } else {
100            self.replay_matches(self.winner)
101        };
102
103        Some(winning_value)
104    }
105
106    /// rebuilds the loser tree, returning the new winner leaf
107    /// after reconstruction.
108    fn rebuild_tree(&mut self) -> Winner {
109        self.tapes.retain(|t| t.peek().is_some());
110        self.remaining_tapes = self.tapes.len();
111
112        if self.tapes.len() > 1 {
113            LoserTreeBuilder::new(
114                |left, right| compare_winners(&self.tapes, &self.orderer, left, right),
115                &mut self.loser_indices,
116            )
117            .build(self.tapes.len())
118        } else {
119            Winner { idx: 0 }
120        }
121    }
122
123    fn compare_winners(&self, left: Winner, right: Winner) -> Ordering {
124        compare_winners(&self.tapes, &self.orderer, left, right)
125    }
126
127    /// replay the matches from the previous winner back up to the root.
128    /// this must be applied to the tree after the winner was modified.
129    fn replay_matches(&mut self, previous_winner: Winner) -> Winner {
130        let mut winner = previous_winner;
131        let mut current_node = self.get_leaf_node(previous_winner).parent();
132        loop {
133            let challenger = Winner {
134                idx: self.loser_indices[current_node.idx],
135            };
136
137            if self.compare_winners(challenger, winner).is_lt() {
138                // the challenger won, note the previous winner in the tree and continue with the challenger
139                self.loser_indices[current_node.idx] = winner.idx;
140                winner = challenger;
141            }
142
143            if current_node.is_root() {
144                return winner;
145            }
146            current_node = current_node.parent();
147        }
148    }
149
150    /// removes the previous winner node from the tree, shrinking it in the process.
151    /// this involves a recomputation of the tree as a new node must be the winner after.
152    fn remove_winner(&mut self, previous_winner: Winner) -> Winner {
153        // we must have at least two runs remaining for it to make sense to remove one.
154        debug_assert!(!self.loser_indices.is_empty());
155
156        self.remaining_tapes -= 1;
157
158        let number_of_tapes = self.tapes.len();
159        let rebuild_threshold = previous_power_of_two(number_of_tapes - 1);
160
161        if self.remaining_tapes <= rebuild_threshold {
162            // we have exhausted enough tapes that the tree will be one level less deep.
163            // we can take this opportunity to drop runs as well as rebuild the tree
164            self.rebuild_tree()
165        } else {
166            // we have not exhausted enough tapes for it to actually make
167            // sense to rebuild the tree. Instead we rely on the fact that
168            // an exhausted tape is always compared to be greater than anything
169            // else and just replay to the root as usual.
170            self.replay_matches(previous_winner)
171        }
172    }
173
174    /// computes the tree node the leaf _would_ occuppy if we did store the leaves
175    fn get_leaf_node(&self, leaf: Winner) -> TreeNode {
176        let tree_size = self.tapes.len();
177        TreeNode::leaf_for_winner(leaf, tree_size)
178    }
179}
180
181impl<T, R, O> Iterator for LoserTree<T, R, O>
182where
183    R: Run<T>,
184    O: Orderer<T>,
185{
186    type Item = T;
187
188    fn next(&mut self) -> Option<Self::Item> {
189        self.next()
190    }
191    fn size_hint(&self) -> (usize, Option<usize>) {
192        let remaining = self.remaining_items();
193        (remaining, Some(remaining))
194    }
195}
196impl<T, R, O> ExactSizeIterator for LoserTree<T, R, O>
197where
198    R: Run<T>,
199    O: Orderer<T>,
200{
201    fn len(&self) -> usize {
202        self.remaining_items()
203    }
204}
205
#[cfg(test)]
mod test {

    use crate::{orderer::OrdOrderer, run::buf_run::BufRun};

    use super::LoserTree;

    /// Merges `runs` through a `LoserTree` and asserts that the output equals
    /// all input items sorted together. The runs are printed on a mismatch.
    fn run_merge_test(runs: Vec<Vec<u32>>) {
        let tapes: Vec<_> = runs.iter().map(|run| BufRun::new(run.clone())).collect();
        let mut merger = LoserTree::new(tapes, OrdOrderer::new());

        // drain the merger through its inherent `next` method
        let result: Vec<u32> = std::iter::from_fn(|| merger.next()).collect();

        let mut expected: Vec<u32> = runs.concat();
        expected.sort();

        let matches = expected == result;
        if !matches {
            runs.iter().for_each(|run| println!("run: {run:?}"));
        }
        assert_eq!(expected, result);
    }

    #[test]
    fn test_merge_runs() {
        run_merge_test(vec![
            vec![1, 3, 5, 7],
            vec![9, 11, 13, 15],
            vec![8, 10, 12, 14],
            vec![0, 2, 4, 6],
        ]);
    }

    #[test]
    fn test_merge_unbalanced() {
        run_merge_test(vec![vec![1, 4], vec![2, 3], vec![5, 6, 7]]);
    }

    #[test]
    fn test_merge_five() {
        let pairs = [[20, 73], [29, 73], [3, 84], [33, 70], [63, 95]];
        let runs = pairs.iter().map(|pair| pair.to_vec()).collect();
        run_merge_test(runs);
    }

    #[cfg(not(miri))]
    // skipped under miri for no other reason than that it would run too slowly there
    mod random {
        use std::sync::{Arc, Mutex};

        use rand::{rngs::ThreadRng, RngCore};

        use super::run_merge_test;

        /// Produces a sorted run of `len` random values.
        fn generate_run(rng: &mut ThreadRng, len: usize) -> Vec<u32> {
            let mut run: Vec<u32> = (0..len).map(|_| rng.next_u32()).collect();
            run.sort();
            run
        }

        /// Merges randomly generated runs for many (run count, run length)
        /// combinations, spread over one worker thread per cpu.
        #[test]
        fn test_merge_runs_random() {
            // every (runs, items) combination is handed out four times
            let params = (1..100).flat_map(move |runs| {
                (1..20).flat_map(move |items| (1..5).map(move |_| (runs, items)))
            });

            let queue = Arc::new(Mutex::new(params));

            let mut workers = Vec::new();
            for _ in 0..num_cpus::get() {
                let queue = Arc::clone(&queue);
                workers.push(std::thread::spawn(move || {
                    let mut rng = rand::thread_rng();
                    loop {
                        // the guard is a temporary of this statement, so the
                        // lock is released again before the merge runs.
                        let job = queue.lock().unwrap().next();
                        match job {
                            None => break,
                            Some((num_runs, num_items)) => {
                                let mut runs = Vec::with_capacity(num_runs);
                                for _ in 0..num_runs {
                                    runs.push(generate_run(&mut rng, num_items));
                                }
                                run_merge_test(runs);
                            }
                        }
                    }
                }));
            }

            for worker in workers {
                worker.join().unwrap();
            }
        }
    }
}
314}