1mod array_node;
2mod treebuilder;
3
4use std::{cmp::Ordering, marker::PhantomData};
5
6use crate::{orderer::Orderer, run::Run};
7
8use self::{
9 array_node::{TreeNode, Winner},
10 treebuilder::LoserTreeBuilder,
11};
12
/// K-way merger over a set of runs ("tapes") built as a loser tree.
///
/// Each step yields the head item that `orderer` ranks smallest across all
/// live tapes. Exhausted tapes compare after every live tape.
pub struct LoserTree<T, R, O> {
    /// One slot per internal tree node, holding the tape index of the loser
    /// of the match played at that node.
    loser_indices: Vec<u32>,
    /// The runs being merged. Exhausted tapes stay in place until the next
    /// rebuild removes them.
    tapes: Vec<R>,
    /// Decides the relative order of two items.
    orderer: O,
    /// Ties the item type `T` to the struct without storing a `T`.
    phantom: PhantomData<T>,
    /// The tape whose head item is currently the overall smallest.
    winner: Winner,
    /// Number of tapes in `tapes` that are not yet exhausted.
    remaining_tapes: usize,
}
24
/// Returns the largest power of two that is less than or equal to `number`.
///
/// Exact powers map to themselves, e.g. `3 -> 2`, `4 -> 4`, `5 -> 4`.
/// Returns `0` for `0`, since no power of two is `<= 0`.
fn previous_power_of_two(number: usize) -> usize {
    if number == 0 {
        // `leading_zeros` would equal `usize::BITS` here, and computing the
        // shift below would underflow.
        return 0;
    }
    let highest_set_bit = usize::BITS - 1 - number.leading_zeros();
    1 << highest_set_bit
}
32
33fn get_candidate<T>(runs: &[impl Run<T>], candidate: Winner) -> Option<&T> {
34 runs[candidate.idx as usize].peek()
35}
36
37fn compare_winners<T>(
38 runs: &[impl Run<T>],
39 orderer: &impl Orderer<T>,
40 left: Winner,
41 right: Winner,
42) -> Ordering {
43 match (get_candidate(runs, left), get_candidate(runs, right)) {
44 (Some(l), Some(r)) => orderer.compare(l, r),
45 (Some(_), None) => Ordering::Less,
46 (None, Some(_)) => Ordering::Greater,
47 (None, None) => Ordering::Equal,
48 }
49}
50
51impl<T, R, O> LoserTree<T, R, O>
52where
53 R: Run<T>,
54 O: Orderer<T>,
55{
56 pub fn new(tapes: Vec<R>, orderer: O) -> Self {
59 let remaining_tapes = tapes.len();
60 let mut result = Self {
61 loser_indices: Vec::new(),
62 remaining_tapes,
63 tapes,
64 orderer,
65 winner: Winner { idx: u32::MAX },
66 phantom: PhantomData,
67 };
68
69 result.winner = result.rebuild_tree();
70
71 result
72 }
73
74 pub fn remaining_items(&self) -> usize {
76 self.tapes.iter().map(|t| t.remaining_items()).sum()
77 }
78
79 pub fn next(&mut self) -> Option<T> {
82 if self.tapes.len() <= 1 {
83 return self.tapes.first_mut()?.next();
84 }
85
86 let winning_tape = &mut self.tapes[self.winner.idx as usize];
87 let winning_value = winning_tape.next()?;
88 let tape_exhausted = winning_tape.peek().is_none();
89
90 self.winner = if tape_exhausted {
91 let none = winning_tape.next();
96 debug_assert!(none.is_none());
97
98 self.remove_winner(self.winner)
99 } else {
100 self.replay_matches(self.winner)
101 };
102
103 Some(winning_value)
104 }
105
106 fn rebuild_tree(&mut self) -> Winner {
109 self.tapes.retain(|t| t.peek().is_some());
110 self.remaining_tapes = self.tapes.len();
111
112 if self.tapes.len() > 1 {
113 LoserTreeBuilder::new(
114 |left, right| compare_winners(&self.tapes, &self.orderer, left, right),
115 &mut self.loser_indices,
116 )
117 .build(self.tapes.len())
118 } else {
119 Winner { idx: 0 }
120 }
121 }
122
123 fn compare_winners(&self, left: Winner, right: Winner) -> Ordering {
124 compare_winners(&self.tapes, &self.orderer, left, right)
125 }
126
127 fn replay_matches(&mut self, previous_winner: Winner) -> Winner {
130 let mut winner = previous_winner;
131 let mut current_node = self.get_leaf_node(previous_winner).parent();
132 loop {
133 let challenger = Winner {
134 idx: self.loser_indices[current_node.idx],
135 };
136
137 if self.compare_winners(challenger, winner).is_lt() {
138 self.loser_indices[current_node.idx] = winner.idx;
140 winner = challenger;
141 }
142
143 if current_node.is_root() {
144 return winner;
145 }
146 current_node = current_node.parent();
147 }
148 }
149
150 fn remove_winner(&mut self, previous_winner: Winner) -> Winner {
153 debug_assert!(!self.loser_indices.is_empty());
155
156 self.remaining_tapes -= 1;
157
158 let number_of_tapes = self.tapes.len();
159 let rebuild_threshold = previous_power_of_two(number_of_tapes - 1);
160
161 if self.remaining_tapes <= rebuild_threshold {
162 self.rebuild_tree()
165 } else {
166 self.replay_matches(previous_winner)
171 }
172 }
173
174 fn get_leaf_node(&self, leaf: Winner) -> TreeNode {
176 let tree_size = self.tapes.len();
177 TreeNode::leaf_for_winner(leaf, tree_size)
178 }
179}
180
181impl<T, R, O> Iterator for LoserTree<T, R, O>
182where
183 R: Run<T>,
184 O: Orderer<T>,
185{
186 type Item = T;
187
188 fn next(&mut self) -> Option<Self::Item> {
189 self.next()
190 }
191 fn size_hint(&self) -> (usize, Option<usize>) {
192 let remaining = self.remaining_items();
193 (remaining, Some(remaining))
194 }
195}
196impl<T, R, O> ExactSizeIterator for LoserTree<T, R, O>
197where
198 R: Run<T>,
199 O: Orderer<T>,
200{
201 fn len(&self) -> usize {
202 self.remaining_items()
203 }
204}
205
#[cfg(test)]
mod test {

    use crate::{orderer::OrdOrderer, run::buf_run::BufRun};

    use super::LoserTree;

    /// Merges `runs` through a `LoserTree` and checks two things: the reported
    /// length, and that the output equals the sorted concatenation of all runs.
    fn run_merge_test(runs: Vec<Vec<u32>>) {
        let mut expected: Vec<u32> = runs.iter().flatten().copied().collect();
        expected.sort_unstable();

        let buf_runs = runs.iter().cloned().map(BufRun::new).collect();
        let merger: LoserTree<u32, _, _> = LoserTree::new(buf_runs, OrdOrderer::new());
        assert_eq!(merger.len(), expected.len(), "length for runs {runs:?}");

        let result: Vec<u32> = merger.collect();
        assert_eq!(expected, result, "merging runs {runs:?}");
    }

    #[test]
    fn test_merge_runs() {
        let run_1 = vec![1, 3, 5, 7];
        let run_4 = vec![0, 2, 4, 6];
        let run_3 = vec![8, 10, 12, 14];
        let run_2 = vec![9, 11, 13, 15];

        run_merge_test(vec![run_1, run_2, run_3, run_4]);
    }

    #[test]
    fn test_merge_unbalanced() {
        let run_1 = vec![1, 4];
        let run_2 = vec![5, 6, 7];
        let run_3 = vec![2, 3];

        run_merge_test(vec![run_1, run_3, run_2]);
    }

    #[test]
    fn test_merge_five() {
        let runs = vec![
            vec![20, 73],
            vec![29, 73],
            vec![3, 84],
            vec![33, 70],
            vec![63, 95],
        ];
        run_merge_test(runs);
    }

    #[test]
    fn test_merge_no_runs() {
        run_merge_test(vec![]);
    }

    #[test]
    fn test_merge_single_run() {
        run_merge_test(vec![vec![1, 2, 3]]);
    }

    #[test]
    fn test_merge_with_empty_runs() {
        // Empty tapes are dropped by the initial rebuild.
        run_merge_test(vec![vec![], vec![2, 4], vec![], vec![1, 3], vec![]]);
    }

    #[test]
    fn test_merge_duplicates() {
        run_merge_test(vec![vec![1, 1, 2], vec![1, 2, 2], vec![0, 1]]);
    }

    #[cfg(not(miri))]
    mod random {
        use std::sync::{Arc, Mutex};

        use rand::{rngs::ThreadRng, RngCore};

        use super::run_merge_test;

        /// Produces a sorted run of `len` random values.
        fn generate_run(rng: &mut ThreadRng, len: usize) -> Vec<u32> {
            let mut run: Vec<u32> = (0..len).map(|_| rng.next_u32()).collect();
            run.sort_unstable();
            run
        }

        #[test]
        fn test_merge_runs_random() {
            let params = (1..100).flat_map(move |runs| {
                (1..20).flat_map(move |items| (1..5).map(move |_| (runs, items)))
            });

            let params = Arc::new(Mutex::new(params));

            let threads: Vec<_> = (0..num_cpus::get())
                .map(|_| {
                    let params = params.clone();
                    std::thread::spawn(move || {
                        let mut rng = rand::thread_rng();
                        loop {
                            // Bind first so the lock guard is released before
                            // the (slow) merge runs.
                            let next = params.lock().unwrap().next();
                            let Some((num_runs, num_items)) = next else {
                                break;
                            };
                            let runs: Vec<_> =
                                core::iter::repeat_with(|| generate_run(&mut rng, num_items))
                                    .take(num_runs)
                                    .collect();
                            run_merge_test(runs);
                        }
                    })
                })
                .collect();

            threads.into_iter().for_each(|t| t.join().unwrap());
        }
    }
}