Skip to main content

dbsp/trace/cursor/
cursor_list.rs

1//! A generic cursor implementation merging multiple cursors.
2
3use crate::dynamic::{DataTrait, Factory, WeightTrait};
4use crate::trace::cursor::Position;
5use crate::utils::binary_heap::BinaryHeap;
6use std::marker::PhantomData;
7use std::mem::{take, transmute};
8use std::{any::TypeId, cmp::Ordering};
9
10use super::{Cursor, Direction};
11
12/// Provides a cursor interface over a list of cursors.
13///
14/// The `CursorList` tracks the indices of cursors with the minimum key, and
15/// the indices of cursors with the minimum key and minimum value.
16///
17/// # Design
18///
19/// The implementation uses a binary heap to keep the cursors partially
20/// sorted by key and value. This allows us to find the new min/max
21/// key/value in logarithmic time compared to linear time for a naive linear scan.
22///
23/// Why binary heap and not a tournament tree? Tournament trees are generally better
24/// for merge sorting, as they require exactly log(n) comparisons to find the new
25/// winner vs the worst-case 2 x log(n) comparisons for a binary heap. The problem
26/// with tournament trees is that they don't offer an efficient way to peek _all_
27/// max/min values, and not just one (you need to remove or update the current max first
28/// and pay another log(n) comparison to find the next max). In contrast, our customized
29/// binary heap implementation offers an efficient way to peek all max/min values at once.
30pub struct CursorList<K, V, T, R: WeightTrait, C>
31where
32    K: DataTrait + ?Sized,
33    V: DataTrait + ?Sized,
34    T: 'static,
35    R: WeightTrait + ?Sized,
36    C: Cursor<K, V, T, R>,
37{
38    cursors: Vec<C>,
39
40    /// Indexes of cursors that hold the current minimum key.
41    current_key: Vec<usize>,
42
43    /// Indexes of cursors in `current_key` in `key_priority_heap`.
44    current_key_indexes: Vec<usize>,
45
46    /// A priority heap that keeps cursors partially sorted by key.
47    /// The first element is the index of the cursor in `cursors`,
48    /// the second element is a pointer to the key. It should be equal
49    /// to `cursors[index].key()`.
50    ///
51    /// # Safety
52    ///
53    /// We assume that the key reference returned by the cursor remains valid
54    /// until the cursor moves to the next key. This is not officially part of the
55    /// cursor API and I hope we can find a more strongly typed way to achieve this;
56    /// however both in-memory and file-backed batches we have today are well-behaved
57    /// in this regard.
58    key_priority_heap: Vec<(usize, &'static K)>,
59
60    /// Indexes of cursors that hold the current minimum value.
61    current_val: Vec<usize>,
62
63    /// Indexes of cursors in `current_val` in `val_priority_heap`.
64    current_val_indexes: Vec<usize>,
65
66    /// A priority heap that keeps cursors partially sorted by value.
67    val_priority_heap: Vec<(usize, &'static V)>,
68
69    /// The direction is None after calling `seek_key_exact`. In this case we don't sort the cursors by key.
70    /// step_key, seek_key, etc., cannot be called on such a cursor.
71    #[cfg(debug_assertions)]
72    key_direction: Option<Direction>,
73    #[cfg(debug_assertions)]
74    val_direction: Direction,
75    weight: Box<R>,
76    /// Scratch space for use by `peek_all`.
77    scratch: Vec<usize>,
78    weight_factory: &'static dyn Factory<R>,
79    __type: PhantomData<fn(&K, &V, &T, &R)>,
80}
81
82impl<K, V, T, R, C> CursorList<K, V, T, R, C>
83where
84    K: DataTrait + ?Sized,
85    V: DataTrait + ?Sized,
86    R: WeightTrait + ?Sized,
87    C: Cursor<K, V, T, R>,
88    T: 'static,
89{
90    /// Creates a new cursor list from pre-existing cursors.
91    pub fn new(weight_factory: &'static dyn Factory<R>, cursors: Vec<C>) -> Self {
92        let num_cursors = cursors.len();
93        let mut result = Self {
94            cursors,
95            current_key: Vec::new(),
96            current_key_indexes: Vec::new(),
97            key_priority_heap: Vec::with_capacity(num_cursors),
98            current_val: Vec::new(),
99            current_val_indexes: Vec::new(),
100            val_priority_heap: Vec::with_capacity(num_cursors),
101            #[cfg(debug_assertions)]
102            key_direction: Some(Direction::Forward),
103            #[cfg(debug_assertions)]
104            val_direction: Direction::Forward,
105            weight: weight_factory.default_box(),
106            weight_factory,
107            scratch: Vec::new(),
108            __type: PhantomData,
109        };
110
111        result.minimize_keys();
112
113        result.skip_zero_weight_keys_forward();
114
115        result
116    }
117
118    #[cfg(debug_assertions)]
119    fn set_key_direction(&mut self, direction: Option<Direction>) {
120        self.key_direction = direction;
121    }
122
123    #[cfg(not(debug_assertions))]
124    fn set_key_direction(&mut self, _direction: Option<Direction>) {}
125
126    #[cfg(debug_assertions)]
127    fn set_val_direction(&mut self, direction: Direction) {
128        self.val_direction = direction;
129    }
130
131    #[cfg(not(debug_assertions))]
132    fn set_val_direction(&mut self, _direction: Direction) {}
133
134    #[cfg(debug_assertions)]
135    fn assert_key_direction(&self, direction: Direction) {
136        debug_assert_eq!(self.key_direction, Some(direction));
137    }
138
139    #[cfg(not(debug_assertions))]
140    fn assert_key_direction(&self, _direction: Direction) {}
141
142    #[cfg(debug_assertions)]
143    fn assert_val_direction(&self, direction: Direction) {
144        debug_assert_eq!(self.val_direction, direction);
145    }
146
147    #[cfg(not(debug_assertions))]
148    fn assert_val_direction(&self, _direction: Direction) {}
149
150    fn is_zero_weight(&mut self) -> bool {
151        if TypeId::of::<T>() == TypeId::of::<()>() {
152            debug_assert!(self.key_valid());
153            debug_assert!(self.val_valid());
154            debug_assert!(self.cursors[self.current_val[0]].val_valid());
155            self.weight.as_mut().set_zero();
156            for &index in self.current_val.iter() {
157                // TODO: use weight_checked
158                self.cursors[index].map_times(&mut |_, w| self.weight.add_assign(w));
159            }
160            self.weight.is_zero()
161        } else {
162            false
163        }
164    }
165
166    fn skip_zero_weight_vals_forward(&mut self) {
167        self.assert_val_direction(Direction::Forward);
168
169        while self.val_valid() && self.is_zero_weight() {
170            for &index in self.current_val.iter() {
171                self.cursors[index].step_val();
172            }
173            self.update_min_vals();
174        }
175    }
176
177    fn skip_zero_weight_vals_reverse(&mut self) {
178        self.assert_val_direction(Direction::Backward);
179
180        while self.val_valid() && self.is_zero_weight() {
181            for &index in self.current_val.iter() {
182                self.cursors[index].step_val_reverse();
183            }
184            self.update_max_vals();
185        }
186    }
187
188    fn skip_zero_weight_keys_forward(&mut self) {
189        self.assert_key_direction(Direction::Forward);
190
191        while self.key_valid() {
192            self.skip_zero_weight_vals_forward();
193            if self.val_valid() {
194                break;
195            }
196            for &index in self.current_key.iter() {
197                self.cursors[index].step_key();
198            }
199            self.set_val_direction(Direction::Forward);
200            self.update_min_keys();
201        }
202    }
203
204    fn skip_zero_weight_keys_reverse(&mut self) {
205        self.assert_key_direction(Direction::Backward);
206
207        while self.key_valid() {
208            self.skip_zero_weight_vals_forward();
209            if self.val_valid() {
210                break;
211            }
212            for &index in self.current_key.iter() {
213                self.cursors[index].step_key_reverse();
214            }
215            self.set_val_direction(Direction::Forward);
216            self.update_max_keys();
217        }
218    }
219
220    /// Find all cursors that point to the maximum (according to `comparator`) key and store the indexes in `current_key`.
221    ///
222    /// Uses naive linear scan to find the maximum key by performing num_cursors comparisons.
223    /// This is a fallback used for small number of cursors where a linear scan is faster than
224    /// maintaining the binary heap.
225    fn current_keys_linear(&mut self, comparator: impl Fn(&K, &K) -> Ordering) {
226        self.current_key.clear();
227
228        // Determine the index of the cursor with minimum key.
229        let mut max_key_opt: Option<&K> = None;
230        for (index, cursor) in self.cursors.iter().enumerate() {
231            if let Some(key) = cursor.get_key() {
232                if let Some(max_key_opt) = &mut max_key_opt {
233                    match (comparator)(key, max_key_opt) {
234                        Ordering::Greater => {
235                            *max_key_opt = key;
236                            self.current_key.clear();
237                            self.current_key.push(index);
238                        }
239                        Ordering::Equal => {
240                            self.current_key.push(index);
241                        }
242                        _ => (),
243                    }
244                } else {
245                    max_key_opt = Some(key);
246                    self.current_key.push(index);
247                }
248            }
249        }
250    }
251
252    /// Update current_key to contain the indexes of cursors with the maximum (according to `comparator`) key.
253    ///
254    /// If `rebuild` is true, assumes that all cursors have moved from their previous positions.
255    /// Otherwise, it assumes that only the cursors in `current_key` have moved.
256    ///
257    /// The implementation uses maximize_keys_linear for a small number of cursors. For larger numbers of cursors,
258    /// it uses a binary heap to keep the cursors partially sorted by key.
259    fn sort_keys(&mut self, comparator: impl Fn(&K, &K) -> Ordering, rebuild: bool) {
260        self.assert_val_direction(Direction::Forward);
261
262        // Use linear scan for small number of cursors.
263        //
264        // Binary heap cost:
265        //
266        // if rebuild is true:
267        //   cost <= 2*num_cursors + 2.
268        // else:
269        //   costs <= 2*log2(num_cursors) + 2.
270        //
271        // The +2 component above is the cost of peek_all assuming a single max element;
272        // in general it's 2x the number of max elements.
273        //
274        // Linear scan cost: num_cursors-1 in both cases.
275        //
276        // We pick 5 as a break-even point beyond which the binary heap is faster.
277        match self.cursors.len() {
278            0 => {
279                self.current_key.clear();
280                return;
281            }
282            1 => {
283                self.current_key.clear();
284                // SAFETY: `cursors.len() = 1`.
285                if self.cursors[0].key_valid() {
286                    self.current_key.push(0);
287                }
288                return;
289            }
290            n if n <= 5 => {
291                self.current_keys_linear(comparator);
292                return;
293            }
294            _ => {}
295        }
296
297        let cmp = |a: &(usize, &'static K), b: &(usize, &'static K)| comparator(a.1, b.1);
298
299        let heap = if rebuild {
300            // Build heap from scratch.
301            self.key_priority_heap.clear();
302            for (i, cursor) in self.cursors.iter().enumerate() {
303                if let Some(key) = cursor.get_key() {
304                    // SAFETY: We will not access the key after the cursor is moved.
305                    self.key_priority_heap
306                        .push((i, unsafe { transmute::<&K, &'static K>(key) }))
307                }
308            }
309
310            BinaryHeap::<(usize, &'static K), _>::from_vec(take(&mut self.key_priority_heap), cmp)
311        } else {
312            // Start with a previously built heap.
313            let mut heap = unsafe {
314                BinaryHeap::<(usize, &'static K), _>::from_vec_unchecked(
315                    take(&mut self.key_priority_heap),
316                    cmp,
317                )
318            };
319
320            // We may have moved cursors in `current_key` since the last time we built the heap
321            // and need to update the heap accordingly.
322            // IMPORTANT: we iterate over the indexes returned by peek_all in reverse order, which
323            // guarantees the modifying or removing the elements does not affect the indexes of the
324            // remaining elements.
325            for (pos, i) in self
326                .current_key_indexes
327                .iter()
328                .rev()
329                .zip(self.current_key.iter().rev())
330            {
331                // If the cursor still points to a valid key, update the key in the heap; otherwise,
332                // remove the cursor from the heap.
333                if let Some(key) = unsafe { self.cursors.get_unchecked(*i).get_key() } {
334                    // SAFETY: We know that `index` is in bounds since it's in `current_key`.
335                    // SAFETY: We will not access the key after the cursor is moved.
336                    unsafe {
337                        heap.update_pos_sift_down(*pos, ((*i), transmute::<&K, &'static K>(key)))
338                    };
339                } else {
340                    heap.remove(*pos);
341                }
342            }
343
344            heap
345        };
346
347        self.current_key.clear();
348        self.current_key_indexes.clear();
349
350        // Find the new set of cursors with the maximum key and record them in `current_key` and `current_key_indexes`,
351        // without removing them from the heap. This is 2x more efficient that removing them and re-inserting during the
352        // next call to sort_keys.
353        heap.peek_all(
354            |pos, &(i, _)| {
355                self.current_key.push(i);
356                self.current_key_indexes.push(pos);
357            },
358            &mut self.scratch,
359        );
360
361        self.key_priority_heap = heap.into_vec();
362    }
363
364    /// Find all cursors that point to the maximum (according to `comparator`) value and store the indexes in `current_val`.
365    ///
366    /// Uses naive linear scan to find the maximum value by performing current_key.len() comparisons.
367    /// This is a fallback used for small number of cursors where a linear scan is faster than
368    /// maintaining the binary heap.
369    fn current_vals_linear(&mut self, comparator: impl Fn(&V, &V) -> Ordering) {
370        self.current_val.clear();
371
372        let mut max_val: Option<&V> = None;
373        for &index in self.current_key.iter() {
374            // SAFETY: `index` is in bounds since it's in `current_key`.
375            if let Some(val) = unsafe { self.cursors.get_unchecked(index).get_val() } {
376                if let Some(max_val) = &mut max_val {
377                    match (comparator)(val, max_val) {
378                        Ordering::Greater => {
379                            *max_val = val;
380                            self.current_val.clear();
381                            self.current_val.push(index);
382                        }
383                        Ordering::Equal => self.current_val.push(index),
384                        _ => (),
385                    }
386                } else {
387                    max_val = Some(val);
388                    self.current_val.push(index);
389                }
390            }
391        }
392    }
393
394    /// Update `current_val` to contain the indexes of cursors with the maximum (according to `comparator`) values.
395    ///
396    /// If `rebuild` is true, assumes that all cursors have moved from their previous positions.
397    /// Otherwise, it assumes that only the cursors in `current_val` have moved.
398    ///
399    /// The implementation uses maximize_vals_linear for a small number of cursors. For larger numbers of cursors,
400    /// it uses a binary heap to keep the cursors partially sorted by value.
401    fn sort_vals(&mut self, comparator: impl Fn(&V, &V) -> Ordering, rebuild: bool) {
402        debug_assert!(
403            self.current_key
404                .iter()
405                .all(|i| self.cursors[*i].key_valid())
406        );
407
408        match self.current_key.len() {
409            0 => {
410                self.current_val.clear();
411                return;
412            }
413            1 => {
414                self.current_val.clear();
415                let i = unsafe { *self.current_key.get_unchecked(0) };
416                if unsafe { self.cursors.get_unchecked(i) }.val_valid() {
417                    self.current_val.push(i);
418                }
419                return;
420            }
421            n if n <= 5 => {
422                self.current_vals_linear(comparator);
423                return;
424            }
425            _ => {}
426        }
427
428        let cmp = |a: &(usize, &'static V), b: &(usize, &'static V)| comparator(a.1, b.1);
429
430        let heap = if rebuild {
431            self.val_priority_heap.clear();
432            for i in self.current_key.iter() {
433                if let Some(val) = unsafe { self.cursors.get_unchecked(*i).get_val() } {
434                    // SAFETY: We will not access the value after the cursor is moved.
435                    self.val_priority_heap
436                        .push(((*i), unsafe { transmute::<&V, &'static V>(val) }));
437                }
438            }
439            BinaryHeap::<(usize, &'static V), _>::from_vec(take(&mut self.val_priority_heap), cmp)
440        } else {
441            let mut heap = unsafe {
442                BinaryHeap::<(usize, &'static V), _>::from_vec_unchecked(
443                    take(&mut self.val_priority_heap),
444                    cmp,
445                )
446            };
447
448            for (pos, i) in self
449                .current_val_indexes
450                .iter()
451                .rev()
452                .zip(self.current_val.iter().rev())
453            {
454                if let Some(val) = unsafe { self.cursors.get_unchecked(*i).get_val() } {
455                    // SAFETY: We will not access the value after the cursor is moved.
456                    unsafe {
457                        heap.update_pos_sift_down(*pos, ((*i), transmute::<&V, &'static V>(val)))
458                    };
459                } else {
460                    heap.remove(*pos);
461                }
462            }
463
464            heap
465        };
466
467        self.current_val.clear();
468        self.current_val_indexes.clear();
469        heap.peek_all(
470            |pos, &(i, _)| {
471                self.current_val.push(i);
472                self.current_val_indexes.push(pos);
473            },
474            &mut self.scratch,
475        );
476
477        self.val_priority_heap = heap.into_vec();
478    }
479
480    /// Sort all cursors by key (ascending); initialize current_key with the indices of cursors with the minimum key.
481    ///
482    /// Invoked after calling seek_key, seek_key_with, or rewind_keys on all cursors.
483    ///
484    /// Once finished, it invokes `minimize_vals()` to ensure the value cursor is
485    /// in a consistent state as well.
486    fn minimize_keys(&mut self) {
487        self.assert_key_direction(Direction::Forward);
488
489        self.sort_keys(|c1, c2| c2.cmp(c1), true);
490        self.minimize_vals();
491    }
492
493    /// Sort all cursors by key (descending); initialize current_key with the indices of cursors with the maximum key.
494    ///
495    /// Invoked after calling seek_key_reverse, seek_key_with_reverse, or fast_forward_keys on all cursors.
496    ///
497    /// Once finished, it invokes `minimize_vals()` to ensure the value cursor is
498    /// in a consistent state as well.
499    fn maximize_keys(&mut self) {
500        self.assert_key_direction(Direction::Backward);
501
502        self.sort_keys(|c1, c2| c1.cmp(c2), true);
503        self.minimize_vals();
504    }
505
506    /// Sort all cursors is self.current_key by values (ascending); initialize current_val with the
507    /// indices of cursors with the minimum value.
508    fn minimize_vals(&mut self) {
509        self.assert_val_direction(Direction::Forward);
510
511        self.sort_vals(|c1, c2| c2.cmp(c1), true);
512    }
513
514    /// Sort all cursors is self.current_key by values (descending); initialize current_val with the
515    /// indices of cursors with the maximum value.
516    fn maximize_vals(&mut self) {
517        self.assert_val_direction(Direction::Backward);
518
519        self.sort_vals(|c1, c2| c1.cmp(c2), true);
520    }
521
522    /// Update the current_key array after stepping all cursors in self.current_key.
523    ///
524    /// Assumes that cursors in key_priority_heap are valid and sorted by key (ascending).
525    ///
526    /// Inserts valid cursors in current_key to key_priority_heap and pops the new min key cursors
527    /// from key_priority_heap into current_key.
528    fn update_min_keys(&mut self) {
529        self.assert_key_direction(Direction::Forward);
530
531        self.sort_keys(|c1, c2| c2.cmp(c1), false);
532        self.minimize_vals();
533    }
534
535    /// Update the current_key array after stepping all cursors in self.current_key in reverse.
536    ///
537    /// Assumes that cursors in key_priority_heap are valid and sorted by key (descending).
538    ///
539    /// Inserts valid cursors in current_key to key_priority_heap and pops the new max key cursors
540    /// from key_priority_heap into current_key.
541    fn update_max_keys(&mut self) {
542        self.assert_key_direction(Direction::Backward);
543
544        self.sort_keys(|c1, c2| c1.cmp(c2), false);
545        self.minimize_vals();
546    }
547
548    /// Update the current_val array after stepping all cursors in self.current_val with step_val.
549    ///
550    /// Assumes that cursors in val_priority_heap are valid and sorted by value (ascending).
551    ///
552    /// Inserts valid cursors in current_val to val_priority_heap and pops the new min value cursors
553    /// from val_priority_heap into current_val.
554    fn update_min_vals(&mut self) {
555        self.assert_val_direction(Direction::Forward);
556        self.sort_vals(|c1, c2| c2.cmp(c1), false);
557    }
558
559    /// Update the current_val array after stepping all cursors in self.current_val with step_val_reverse.
560    ///
561    /// Assumes that cursors in val_priority_heap are valid and sorted by value (descending).
562    ///
563    /// Inserts valid cursors in current_val to val_priority_heap and pops the new max value cursors
564    /// from val_priority_heap into current_val.
565    fn update_max_vals(&mut self) {
566        self.assert_val_direction(Direction::Backward);
567        self.sort_vals(|c1, c2| c1.cmp(c2), false);
568    }
569}
570
571impl<K, V, T, R, C: Cursor<K, V, T, R>> Cursor<K, V, T, R> for CursorList<K, V, T, R, C>
572where
573    K: DataTrait + ?Sized,
574    V: DataTrait + ?Sized,
575    R: WeightTrait + ?Sized,
576    T: 'static,
577{
578    fn weight_factory(&self) -> &'static dyn Factory<R> {
579        self.weight_factory
580    }
581
582    fn key_valid(&self) -> bool {
583        !self.current_key.is_empty()
584    }
585
586    fn val_valid(&self) -> bool {
587        !self.current_val.is_empty()
588    }
589
590    fn key(&self) -> &K {
591        debug_assert!(self.key_valid());
592        debug_assert!(self.cursors[self.current_key[0]].key_valid());
593        self.cursors[self.current_key[0]].key()
594    }
595
596    fn val(&self) -> &V {
597        debug_assert!(self.key_valid());
598        debug_assert!(self.val_valid());
599        debug_assert!(self.cursors[self.current_val[0]].val_valid());
600        self.cursors[self.current_val[0]].val()
601    }
602
603    fn map_times(&mut self, logic: &mut dyn FnMut(&T, &R)) {
604        debug_assert!(self.key_valid());
605        debug_assert!(self.val_valid());
606        for &index in self.current_val.iter() {
607            self.cursors[index].map_times(logic);
608        }
609    }
610
611    fn map_times_through(&mut self, upper: &T, logic: &mut dyn FnMut(&T, &R)) {
612        debug_assert!(self.key_valid());
613        debug_assert!(self.val_valid());
614        for &index in self.current_val.iter() {
615            self.cursors[index].map_times_through(upper, logic);
616        }
617    }
618
619    fn weight(&mut self) -> &R
620    where
621        T: PartialEq<()>,
622    {
623        self.weight_checked()
624    }
625
626    fn weight_checked(&mut self) -> &R {
627        if TypeId::of::<T>() == TypeId::of::<()>() {
628            debug_assert!(self.key_valid());
629            debug_assert!(self.val_valid());
630            debug_assert!(self.cursors[self.current_val[0]].val_valid());
631
632            // Weight should already be computed by `is_zero_weight`, which is always
633            // called as part of every operation that moves the cursor.
634            debug_assert!(!self.weight.is_zero());
635
636            &self.weight
637        } else {
638            panic!("CursorList::weight_checked called on non-unit timestamp type");
639        }
640    }
641
642    fn map_values(&mut self, logic: &mut dyn FnMut(&V, &R))
643    where
644        T: PartialEq<()>,
645    {
646        debug_assert!(self.key_valid());
647        while self.val_valid() {
648            let val = self.val();
649            logic(val, &self.weight);
650            self.step_val();
651        }
652    }
653
654    fn step_key(&mut self) {
655        self.assert_key_direction(Direction::Forward);
656
657        for &index in self.current_key.iter() {
658            debug_assert!(self.cursors[index].key_valid());
659            self.cursors[index].step_key();
660        }
661
662        self.set_val_direction(Direction::Forward);
663        self.update_min_keys();
664        self.skip_zero_weight_keys_forward();
665    }
666
667    fn step_key_reverse(&mut self) {
668        self.assert_key_direction(Direction::Backward);
669
670        for &index in self.current_key.iter() {
671            debug_assert!(self.cursors[index].key_valid());
672            self.cursors[index].step_key_reverse();
673        }
674
675        self.set_val_direction(Direction::Forward);
676        self.update_max_keys();
677        self.skip_zero_weight_keys_reverse();
678    }
679
680    fn seek_key(&mut self, key: &K) {
681        self.assert_key_direction(Direction::Forward);
682
683        for cursor in self.cursors.iter_mut() {
684            cursor.seek_key(key);
685        }
686
687        self.set_val_direction(Direction::Forward);
688        self.minimize_keys();
689        self.skip_zero_weight_keys_forward();
690    }
691
692    fn seek_key_exact(&mut self, key: &K, hash: Option<u64>) -> bool {
693        self.set_key_direction(None);
694
695        let hash = hash.unwrap_or_else(|| key.default_hash());
696        self.current_key.clear();
697
698        let mut result = false;
699
700        for (index, cursor) in self.cursors.iter_mut().enumerate() {
701            if cursor.seek_key_exact(key, Some(hash)) {
702                self.current_key.push(index);
703                result = true;
704            }
705        }
706
707        self.set_val_direction(Direction::Forward);
708        self.minimize_vals();
709
710        if result {
711            self.skip_zero_weight_vals_forward();
712            self.val_valid()
713        } else {
714            false
715        }
716    }
717
718    fn seek_key_with(&mut self, predicate: &dyn Fn(&K) -> bool) {
719        self.assert_key_direction(Direction::Forward);
720
721        for cursor in self.cursors.iter_mut() {
722            cursor.seek_key_with(&predicate);
723        }
724
725        self.set_val_direction(Direction::Forward);
726        self.minimize_keys();
727        self.skip_zero_weight_keys_forward();
728    }
729
730    fn seek_key_with_reverse(&mut self, predicate: &dyn Fn(&K) -> bool) {
731        self.assert_key_direction(Direction::Backward);
732
733        for cursor in self.cursors.iter_mut() {
734            cursor.seek_key_with_reverse(&predicate);
735        }
736
737        self.set_val_direction(Direction::Forward);
738        self.maximize_keys();
739        self.skip_zero_weight_keys_reverse();
740    }
741
742    fn seek_key_reverse(&mut self, key: &K) {
743        self.assert_key_direction(Direction::Backward);
744
745        for cursor in self.cursors.iter_mut() {
746            cursor.seek_key_reverse(key);
747        }
748
749        self.set_val_direction(Direction::Forward);
750        self.maximize_keys();
751        self.skip_zero_weight_keys_reverse();
752    }
753
754    fn step_val(&mut self) {
755        self.assert_val_direction(Direction::Forward);
756
757        for &index in self.current_val.iter() {
758            debug_assert!(self.cursors[index].key_valid());
759            debug_assert!(self.cursors[index].val_valid());
760            self.cursors[index].step_val();
761        }
762        self.update_min_vals();
763
764        self.skip_zero_weight_vals_forward();
765    }
766
767    fn seek_val(&mut self, val: &V) {
768        self.assert_val_direction(Direction::Forward);
769
770        for &index in self.current_key.iter() {
771            debug_assert!(self.cursors[index].key_valid());
772            self.cursors[index].seek_val(val);
773        }
774        self.minimize_vals();
775        self.skip_zero_weight_vals_forward();
776    }
777
778    fn seek_val_with(&mut self, predicate: &dyn Fn(&V) -> bool) {
779        self.assert_val_direction(Direction::Forward);
780
781        for &index in self.current_key.iter() {
782            debug_assert!(self.cursors[index].key_valid());
783            self.cursors[index].seek_val_with(predicate);
784        }
785        self.minimize_vals();
786        self.skip_zero_weight_vals_forward();
787    }
788
789    fn rewind_keys(&mut self) {
790        self.set_key_direction(Some(Direction::Forward));
791
792        for cursor in self.cursors.iter_mut() {
793            cursor.rewind_keys();
794        }
795
796        self.set_val_direction(Direction::Forward);
797        self.minimize_keys();
798        self.skip_zero_weight_keys_forward();
799    }
800
801    fn fast_forward_keys(&mut self) {
802        self.set_key_direction(Some(Direction::Backward));
803
804        for cursor in self.cursors.iter_mut() {
805            cursor.fast_forward_keys();
806        }
807
808        self.set_val_direction(Direction::Forward);
809        self.maximize_keys();
810        self.skip_zero_weight_keys_reverse();
811    }
812
813    fn rewind_vals(&mut self) {
814        for &index in self.current_key.iter() {
815            self.cursors[index].rewind_vals();
816        }
817
818        self.set_val_direction(Direction::Forward);
819        self.minimize_vals();
820        self.skip_zero_weight_vals_forward();
821    }
822
823    fn step_val_reverse(&mut self) {
824        self.assert_val_direction(Direction::Backward);
825
826        for &index in self.current_val.iter() {
827            debug_assert!(self.cursors[index].key_valid());
828            debug_assert!(self.cursors[index].val_valid());
829            self.cursors[index].step_val_reverse();
830        }
831        self.update_max_vals();
832        self.skip_zero_weight_vals_reverse();
833    }
834
835    fn seek_val_reverse(&mut self, val: &V) {
836        self.assert_val_direction(Direction::Backward);
837
838        for &index in self.current_key.iter() {
839            debug_assert!(self.cursors[index].key_valid());
840            self.cursors[index].seek_val_reverse(val);
841        }
842        self.maximize_vals();
843        self.skip_zero_weight_vals_reverse();
844    }
845
846    fn seek_val_with_reverse(&mut self, predicate: &dyn Fn(&V) -> bool) {
847        self.assert_val_direction(Direction::Backward);
848
849        for &index in self.current_key.iter() {
850            debug_assert!(self.cursors[index].key_valid());
851            self.cursors[index].seek_val_with_reverse(predicate);
852        }
853        self.maximize_vals();
854        self.skip_zero_weight_vals_reverse();
855    }
856
857    fn fast_forward_vals(&mut self) {
858        for &index in self.current_key.iter() {
859            debug_assert!(self.cursors[index].key_valid());
860            self.cursors[index].fast_forward_vals();
861        }
862
863        self.set_val_direction(Direction::Backward);
864        self.maximize_vals();
865        self.skip_zero_weight_vals_reverse();
866    }
867
868    fn position(&self) -> Option<Position> {
869        let mut num_keys = 0;
870        let mut current_key = 0;
871
872        for cursor in self.cursors.iter() {
873            let position = cursor.position().unwrap();
874            num_keys += position.total;
875            current_key += position.offset;
876        }
877        Some(Position {
878            total: num_keys,
879            offset: current_key,
880        })
881    }
882}
883
#[cfg(test)]
mod test {
    use super::*;
    use crate::IndexedZSetReader;
    use crate::utils::Tup2;
    use crate::{
        dynamic::DowncastTrait,
        indexed_zset,
        trace::{BatchReader, BatchReaderFactories},
    };
    use proptest::{collection::vec, prelude::*};

    pub type TestBatch = crate::OrdIndexedZSet<u64, u64>;

    /// Collect (key, value, weight) tuples from a cursor by iterating forward.
    ///
    /// Panics if the cursor yields a zero weight: zero-weight tuples must have
    /// been skipped by the cursor itself.
    pub fn cursor_to_tuples<C>(cursor: &mut C) -> Vec<(u64, u64, i64)>
    where
        C: Cursor<crate::dynamic::DynData, crate::dynamic::DynData, (), crate::DynZWeight>,
    {
        let mut result = Vec::new();
        while cursor.key_valid() {
            while cursor.val_valid() {
                let k = *cursor.key().downcast_checked::<u64>();
                let v = *cursor.val().downcast_checked::<u64>();
                let w = *cursor.weight().downcast_checked::<i64>();
                assert_ne!(w, 0);
                result.push((k, v, w));

                cursor.step_val();
            }
            cursor.step_key();
        }
        result
    }

    /// Merge `batches` through a `CursorList` and collect its output with
    /// [`cursor_to_tuples`].
    ///
    /// Returns an empty vector for an empty slice, since there is no batch to
    /// take a weight factory from in that case.
    fn cursor_list_tuples(batches: &[TestBatch]) -> Vec<(u64, u64, i64)> {
        if batches.is_empty() {
            return Vec::new();
        }

        let cursors: Vec<_> = batches.iter().map(|b| b.cursor()).collect();
        let weight_factory = batches[0].factories().weight_factory();
        let mut cursor_list = CursorList::new(weight_factory, cursors);

        cursor_to_tuples(&mut cursor_list)
    }

    /// Reference implementation of a merge, independent of `CursorList`.
    ///
    /// Concatenates the tuples of all `batches`, sorts them, sums the weights
    /// of equal (key, value) pairs and drops pairs whose total weight is zero.
    pub fn merged_batch_to_tuples(batches: &[TestBatch]) -> Vec<(u64, u64, i64)> {
        if batches.is_empty() {
            return Vec::new();
        }

        let mut tuples = batches.iter().flat_map(|b| b.iter()).collect::<Vec<_>>();
        tuples.sort();

        // `dedup_by` passes (later, retained-earlier); fold the later element's
        // weight into the retained one before the later element is dropped.
        tuples.dedup_by(|(key1, v1, w1), (key2, v2, w2)| {
            if key1 == key2 && v1 == v2 {
                *w2 += std::mem::replace(w1, 0);
                true
            } else {
                false
            }
        });
        tuples.retain(|(_, _, w)| *w != 0);
        tuples
    }

    #[test]
    fn cursor_list_matches_merge_batches() {
        let batch1: TestBatch = indexed_zset! { 1 => { 1 => 1, 2 => 2 }, 2 => { 1 => 1 } };
        let batch2: TestBatch = indexed_zset! { 1 => { 2 => -1, 3 => 2 } };
        let batch3: TestBatch = indexed_zset! { 2 => { 2 => 1 }, 3 => { 1 => 1 }, 4 => { 1 => 1 } };

        let batches = [batch1, batch2, batch3];

        let cursor_output = cursor_list_tuples(&batches);
        let expected = merged_batch_to_tuples(&batches);

        assert_eq!(cursor_output, expected);
    }

    #[test]
    fn cursor_list_empty_batches() {
        let batch1: TestBatch = indexed_zset! {};
        let batch2: TestBatch = indexed_zset! { 1 => { 1 => 1 } };
        let batch3: TestBatch = indexed_zset! {};

        let batches = [batch1, batch2, batch3];

        let cursor_output = cursor_list_tuples(&batches);
        let expected = merged_batch_to_tuples(&batches);

        assert_eq!(cursor_output, expected);
        assert_eq!(cursor_output, vec![(1, 1, 1)]);
    }

    #[test]
    fn cursor_list_single_batch() {
        let batch: TestBatch = indexed_zset! { 1 => { 1 => 1, 2 => 2 }, 2 => { 1 => -1 } };

        let batches = [batch];

        let cursor_output = cursor_list_tuples(&batches);
        let expected = merged_batch_to_tuples(&batches);

        assert_eq!(cursor_output, expected);
    }

    #[test]
    fn cursor_list_weights_consolidate() {
        // Multiple batches with same (k,v) - weights should sum
        let batch1: TestBatch = indexed_zset! { 1 => { 1 => 2, 2 => 1 } };
        let batch2: TestBatch = indexed_zset! { 1 => { 1 => 3, 2 => -1 } };
        let batch3: TestBatch = indexed_zset! { 1 => { 1 => -1 } };

        let batches = [batch1, batch2, batch3];

        let cursor_output = cursor_list_tuples(&batches);
        let expected = merged_batch_to_tuples(&batches);

        assert_eq!(cursor_output, expected);
        // (1,1): 2+3-1=4, (1,2): 1-1=0 (filtered out)
        assert_eq!(cursor_output, vec![(1, 1, 4)]);
    }

    /// Strategy producing between 1 and `max_num_batches` batches, each holding
    /// up to `max_batch_size` random (key, value, weight) tuples.
    fn batches(
        (max_key, max_value, weight_min, weight_max, max_batch_size, max_num_batches): (
            u64,
            u64,
            i64,
            i64,
            usize,
            usize,
        ),
    ) -> impl Strategy<Value = Vec<TestBatch>> {
        let tuple_strategy = (0..max_key, 0..max_value, weight_min..=weight_max)
            .prop_map(|(k, v, w)| Tup2(Tup2(k, v), w));
        vec(
            vec(tuple_strategy, 0..max_batch_size)
                .prop_map(|tuples| TestBatch::from_tuples((), tuples)),
            1..=max_num_batches,
        )
    }

    fn test_cursor_list_matches_merge_batches(batches: &[TestBatch]) {
        let cursor_output = cursor_list_tuples(batches);
        let expected = merged_batch_to_tuples(batches);

        assert_eq!(cursor_output, expected);
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        #[test]
        fn proptest_cursor_list_matches_merge_batches1(batches in batches((100, 100, 0, 2, 100, 100))) {
            test_cursor_list_matches_merge_batches(&batches);
        }

        #[test]
        fn proptest_cursor_list_matches_merge_batches2(batches in batches((100, 100, -2, 2, 100, 100))) {
            test_cursor_list_matches_merge_batches(&batches);
        }

        #[test]
        fn proptest_cursor_list_matches_merge_batches3(batches in batches((1, 1, -2, 2, 100, 100))) {
            test_cursor_list_matches_merge_batches(&batches);
        }

        #[test]
        fn proptest_cursor_list_matches_merge_batches4(batches in batches((1000, 1, -2, 2, 100, 100))) {
            test_cursor_list_matches_merge_batches(&batches);
        }

        #[test]
        fn proptest_cursor_list_matches_merge_batches5(batches in batches((1, 1000, -2, 2, 100, 100))) {
            test_cursor_list_matches_merge_batches(&batches);
        }

    }
}