1use std::cmp::Ordering;
22use std::collections::BinaryHeap;
23use std::sync::Arc;
24
25use grafeo_common::types::{LogicalType, Value};
26
27use super::sort::SortKey;
28use super::value_utils::compare_values_with_nulls;
29use super::{Operator, OperatorResult};
30use crate::execution::DataChunk;
31use crate::execution::chunk::DataChunkBuilder;
32
/// Operator that keeps at most `limit` rows from its child, chosen and
/// ordered by `sort_keys`, and emits them in sorted order.
pub struct TopKOperator {
    /// Upstream operator whose chunks are consumed while building the heap.
    child: Box<dyn Operator>,
    /// Sort specification; held in an `Arc` so every `HeapEntry` can share it
    /// for its `Ord` implementation.
    sort_keys: Arc<Vec<SortKey>>,
    /// Maximum number of rows retained and emitted.
    limit: usize,
    /// Column types of the emitted chunks; its length is also the number of
    /// columns copied from each retained input row.
    output_schema: Vec<LogicalType>,
    /// Current phase of execution (building, draining, or done).
    state: TopKState,
    /// Test-only counter of how many rows had their full column values extracted.
    #[cfg(test)]
    materialized_rows: std::sync::atomic::AtomicUsize,
}
47
/// Execution phase of a `TopKOperator`.
enum TopKState {
    /// Input has not been consumed yet; rows are accumulated in a bounded heap.
    Building {
        /// Rows retained so far.
        heap: BinaryHeap<HeapEntry>,
        /// Id handed to the next retained row; incremented after each use and
        /// used by `HeapEntry::cmp` as the final tie-breaker.
        next_insertion_id: u64,
    },
    /// Input is fully consumed; the retained rows are being emitted.
    Draining {
        /// Retained rows, as produced by `BinaryHeap::into_sorted_vec`.
        rows: Vec<HeapEntry>,
        /// Index into `rows` of the next row to emit.
        position: usize,
    },
    /// Nothing left to emit.
    Done,
}
59
/// A single retained row together with everything needed to order it.
struct HeapEntry {
    /// Values of the sort-key columns for this row, one per entry in `sort_keys`.
    sort_values: Vec<Option<Value>>,
    /// Values of every output column for this row.
    row_values: Vec<Option<Value>>,
    /// Id assigned when the row was retained; defines equality and breaks
    /// ties between rows whose sort values compare equal.
    insertion_id: u64,
    /// Shared sort specification consulted by the `Ord` implementation.
    sort_keys: Arc<Vec<SortKey>>,
}
67
68impl TopKOperator {
69 #[must_use]
118 pub fn new(
119 child: Box<dyn Operator>,
120 sort_keys: Vec<SortKey>,
121 limit: usize,
122 output_schema: Vec<LogicalType>,
123 ) -> Self {
124 Self {
125 child,
126 sort_keys: Arc::new(sort_keys),
127 limit,
128 output_schema,
129 state: TopKState::Building {
130 heap: BinaryHeap::new(),
131 next_insertion_id: 0,
132 },
133 #[cfg(test)]
134 materialized_rows: std::sync::atomic::AtomicUsize::new(0),
135 }
136 }
137
138 #[must_use]
147 pub fn into_parts(self) -> (Box<dyn Operator>, Vec<SortKey>, usize) {
148 let sort_keys = Arc::try_unwrap(self.sort_keys).unwrap_or_else(|arc| (*arc).clone());
149 (self.child, sort_keys, self.limit)
150 }
151}
152
153impl Operator for TopKOperator {
154 fn next(&mut self) -> OperatorResult {
155 if matches!(self.state, TopKState::Building { .. }) {
156 let TopKState::Building {
157 mut heap,
158 mut next_insertion_id,
159 } = std::mem::replace(&mut self.state, TopKState::Done)
160 else {
161 unreachable!("matches! guard above")
162 };
163
164 let mut schema_checked = false;
165 while let Some(chunk) = self.child.next()? {
166 if !schema_checked {
167 debug_assert_eq!(
168 chunk.column_count(),
169 self.output_schema.len(),
170 "TopKOperator output_schema width must match child schema width",
171 );
172 schema_checked = true;
173 }
174
175 for row_idx in chunk.selected_indices() {
176 let new_sort_values =
177 extract_sort_values(&chunk, row_idx, self.sort_keys.as_slice());
178
179 let should_push = if heap.len() < self.limit {
180 true
181 } else if let Some(top) = heap.peek() {
182 row_beats_heap_top(&new_sort_values, top, self.sort_keys.as_slice())
183 } else {
184 false
186 };
187
188 if !should_push {
189 continue;
190 }
191
192 let row_values = extract_row_values(&chunk, row_idx, self.output_schema.len());
193 #[cfg(test)]
194 self.materialized_rows
195 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
196 let entry = HeapEntry {
197 sort_values: new_sort_values,
198 row_values,
199 insertion_id: next_insertion_id,
200 sort_keys: Arc::clone(&self.sort_keys),
201 };
202 next_insertion_id += 1;
203 if heap.len() < self.limit {
204 heap.push(entry);
205 } else {
206 let mut top = heap.peek_mut().expect("heap.len() == limit > 0");
210 *top = entry;
211 }
212 }
213 }
214
215 let rows = heap.into_sorted_vec();
216 self.state = TopKState::Draining { rows, position: 0 };
217 }
218
219 if let TopKState::Draining { rows, position } = &mut self.state {
220 if *position < rows.len() {
221 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
222 while *position < rows.len() && !builder.is_full() {
223 let entry = &rows[*position];
224 for col_idx in 0..self.output_schema.len() {
225 if let Some(dst_col) = builder.column_mut(col_idx) {
226 let val = entry.row_values[col_idx].clone().unwrap_or(Value::Null);
227 dst_col.push_value(val);
228 }
229 }
230 builder.advance_row();
231 *position += 1;
232 }
233 if builder.row_count() > 0 {
234 return Ok(Some(builder.finish()));
235 }
236 }
237 self.state = TopKState::Done;
238 }
239
240 Ok(None)
241 }
242
243 fn reset(&mut self) {
244 self.child.reset();
245 self.state = TopKState::Building {
246 heap: BinaryHeap::new(),
247 next_insertion_id: 0,
248 };
249 #[cfg(test)]
250 self.materialized_rows
251 .store(0, std::sync::atomic::Ordering::Relaxed);
252 }
253
254 fn name(&self) -> &'static str {
255 "TopK"
256 }
257
258 fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
259 self
260 }
261}
262
#[cfg(test)]
impl TopKOperator {
    /// Number of rows whose full column values have been extracted since
    /// construction or the last `reset`.
    pub(crate) fn materialized_rows(&self) -> usize {
        use std::sync::atomic::Ordering::Relaxed;

        self.materialized_rows.load(Relaxed)
    }
}
270
271fn extract_sort_values(
272 chunk: &DataChunk,
273 row_idx: usize,
274 sort_keys: &[SortKey],
275) -> Vec<Option<Value>> {
276 sort_keys
277 .iter()
278 .map(|k| chunk.column(k.column).and_then(|c| c.get_value(row_idx)))
279 .collect()
280}
281
282fn extract_row_values(chunk: &DataChunk, row_idx: usize, n_cols: usize) -> Vec<Option<Value>> {
283 (0..n_cols)
284 .map(|i| chunk.column(i).and_then(|c| c.get_value(row_idx)))
285 .collect()
286}
287
288fn row_beats_heap_top(new: &[Option<Value>], top: &HeapEntry, keys: &[SortKey]) -> bool {
294 use super::sort::SortDirection;
295 for (i, key) in keys.iter().enumerate() {
296 let cmp = compare_values_with_nulls(&new[i], &top.sort_values[i], key.null_order);
297 let user_cmp = match key.direction {
298 SortDirection::Ascending => cmp,
299 SortDirection::Descending => cmp.reverse(),
300 };
301 match user_cmp {
302 Ordering::Less => return true,
303 Ordering::Greater => return false,
304 Ordering::Equal => continue,
305 }
306 }
307 false
308}
309
310impl PartialEq for HeapEntry {
311 fn eq(&self, other: &Self) -> bool {
312 self.insertion_id == other.insertion_id
313 }
314}
315
// Marker impl: `BinaryHeap<HeapEntry>` needs `Ord`, which in turn requires `Eq`.
impl Eq for HeapEntry {}
317
318impl PartialOrd for HeapEntry {
319 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
320 Some(self.cmp(other))
321 }
322}
323
324impl Ord for HeapEntry {
325 fn cmp(&self, other: &Self) -> Ordering {
326 use super::sort::SortDirection;
327 for (i, key) in self.sort_keys.iter().enumerate() {
337 let cmp = compare_values_with_nulls(
338 &self.sort_values[i],
339 &other.sort_values[i],
340 key.null_order,
341 );
342 let heap_cmp = match key.direction {
343 SortDirection::Ascending => cmp,
344 SortDirection::Descending => cmp.reverse(),
345 };
346 if heap_cmp != Ordering::Equal {
347 return heap_cmp;
348 }
349 }
350 self.insertion_id.cmp(&other.insertion_id)
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::DataChunk;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test double that hands out a fixed list of chunks, one per `next` call.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            let Some(slot) = self.chunks.get_mut(self.position) else {
                return Ok(None);
            };
            // Hand the chunk out by value, leaving an empty placeholder behind.
            let chunk = std::mem::replace(slot, DataChunk::empty());
            self.position += 1;
            Ok(Some(chunk))
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }

        fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
            self
        }
    }

    /// Puts a `TopKOperator` on top of a `MockOperator` that replays `chunks`.
    fn top_k_over(
        chunks: Vec<DataChunk>,
        keys: Vec<SortKey>,
        limit: usize,
        schema: Vec<LogicalType>,
    ) -> TopKOperator {
        TopKOperator::new(Box::new(MockOperator::new(chunks)), keys, limit, schema)
    }

    /// Schema of a single Int64 column.
    fn int64_schema() -> Vec<LogicalType> {
        vec![LogicalType::Int64]
    }

    /// Schema of an Int64 column followed by a String column.
    fn int_str_schema() -> Vec<LogicalType> {
        vec![LogicalType::Int64, LogicalType::String]
    }

    /// Builds a one-column Int64 chunk from `values`.
    fn chunk_int64(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for value in values.iter().copied() {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Builds a one-column Int64 chunk in which `None` entries become nulls.
    fn chunk_nullable_int64(values: &[Option<i64>]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for value in values {
            let column = builder.column_mut(0).unwrap();
            match *value {
                Some(n) => column.push_int64(n),
                None => column.push_value(Value::Null),
            }
            builder.advance_row();
        }
        builder.finish()
    }

    /// Builds a two-column (Int64, String) chunk from `rows`.
    fn chunk_int_str(rows: &[(i64, &str)]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::String]);
        for &(number, text) in rows {
            builder.column_mut(0).unwrap().push_int64(number);
            builder.column_mut(1).unwrap().push_string(text);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Drains `op` and returns column 0 of every selected row as `i64`.
    fn collect_int64_col(op: &mut dyn Operator) -> Vec<i64> {
        let mut collected = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                let value = chunk.column(0).unwrap().get_int64(row).unwrap();
                collected.push(value);
            }
        }
        collected
    }

    /// Drains `op` and returns column 0 of every selected row as a raw value.
    fn collect_values(op: &mut dyn Operator) -> Vec<Option<Value>> {
        let mut collected = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                collected.push(chunk.column(0).unwrap().get_value(row));
            }
        }
        collected
    }

    /// Drains `op` and returns columns 0 and 1 of every selected row.
    fn collect_int_str(op: &mut dyn Operator) -> Vec<(i64, String)> {
        let mut collected = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                let number = chunk.column(0).unwrap().get_int64(row).unwrap();
                let text = chunk
                    .column(1)
                    .unwrap()
                    .get_string(row)
                    .unwrap()
                    .to_string();
                collected.push((number, text));
            }
        }
        collected
    }

    #[test]
    fn top_k_returns_top_k_descending() {
        let mut top_k = top_k_over(
            vec![chunk_int64(&[19, 88, 33, 8, 319])],
            vec![SortKey::descending(0)],
            3,
            int64_schema(),
        );
        assert_eq!(collect_int64_col(&mut top_k), vec![319, 88, 33]);
    }

    #[test]
    fn top_k_is_stable_on_ties_descending() {
        let input = chunk_int_str(&[(3, "Vincent"), (88, "Jules"), (3, "Mia"), (88, "Butch")]);
        let mut top_k = top_k_over(
            vec![input],
            vec![SortKey::descending(0)],
            2,
            int_str_schema(),
        );
        let out = collect_int_str(&mut top_k);
        assert_eq!(out, vec![(88, "Jules".into()), (88, "Butch".into())]);
    }

    #[test]
    fn top_k_is_stable_on_ties_ascending() {
        let input = chunk_int_str(&[(88, "Vincent"), (3, "Jules"), (88, "Mia"), (3, "Butch")]);
        let mut top_k = top_k_over(
            vec![input],
            vec![SortKey::ascending(0)],
            2,
            int_str_schema(),
        );
        let out = collect_int_str(&mut top_k);
        assert_eq!(out, vec![(3, "Jules".into()), (3, "Butch".into())]);
    }

    #[test]
    fn top_k_skips_materialization_for_losers() {
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let values: Vec<i64> = (0..1000_i64).map(|i| (i * 31 + 7) % 1000).collect();
        let mut top_k = top_k_over(
            vec![chunk_int64(&values)],
            vec![SortKey::ascending(0)],
            5,
            int64_schema(),
        );

        let out = collect_int64_col(&mut top_k);
        assert_eq!(out.len(), 5);

        let materialized = top_k.materialized_rows();
        assert!(
            materialized < 50,
            "expected < 50 materializations for k=5 over 1000 inputs, got {materialized}"
        );
    }

    #[test]
    fn top_k_multi_key_mixed_directions() {
        let input = chunk_int_str(&[(88, "5"), (88, "3"), (19, "8"), (88, "5b")]);
        let mut top_k = top_k_over(
            vec![input],
            vec![SortKey::descending(0), SortKey::ascending(1)],
            2,
            int_str_schema(),
        );
        let out = collect_int_str(&mut top_k);
        assert_eq!(out, vec![(88, "3".into()), (88, "5".into())]);
    }

    #[test]
    fn top_k_handles_nulls_first_ascending() {
        use super::super::sort::NullOrder;

        let input = chunk_nullable_int64(&[Some(19_i64), None, Some(88), None, Some(3)]);
        let mut top_k = top_k_over(
            vec![input],
            vec![SortKey::ascending(0).with_null_order(NullOrder::NullsFirst)],
            3,
            int64_schema(),
        );

        let out = collect_values(&mut top_k);
        assert_eq!(out.len(), 3);
        assert!(matches!(out[0], Some(Value::Null)));
        assert!(matches!(out[1], Some(Value::Null)));
        assert_eq!(out[2], Some(Value::Int64(3)));
    }

    #[test]
    fn top_k_handles_nulls_last_ascending() {
        use super::super::sort::NullOrder;

        let input = chunk_nullable_int64(&[Some(19_i64), None, Some(88), None, Some(3)]);
        let mut top_k = top_k_over(
            vec![input],
            vec![SortKey::ascending(0).with_null_order(NullOrder::NullsLast)],
            3,
            int64_schema(),
        );

        let out = collect_values(&mut top_k);
        assert_eq!(
            out,
            vec![
                Some(Value::Int64(3)),
                Some(Value::Int64(19)),
                Some(Value::Int64(88))
            ]
        );
    }

    #[test]
    fn top_k_empty_input() {
        let mut top_k = top_k_over(vec![], vec![SortKey::descending(0)], 5, int64_schema());
        assert_eq!(collect_int64_col(&mut top_k), Vec::<i64>::new());
    }

    #[test]
    fn top_k_k_zero_returns_no_rows() {
        let mut top_k = top_k_over(
            vec![chunk_int64(&[3, 19, 88])],
            vec![SortKey::descending(0)],
            0,
            int64_schema(),
        );
        assert_eq!(collect_int64_col(&mut top_k), Vec::<i64>::new());
    }

    #[test]
    fn top_k_k_greater_than_n() {
        let mut top_k = top_k_over(
            vec![chunk_int64(&[19, 88, 3])],
            vec![SortKey::descending(0)],
            10,
            int64_schema(),
        );
        assert_eq!(collect_int64_col(&mut top_k), vec![88, 19, 3]);
    }

    #[test]
    fn top_k_returns_top_k_ascending() {
        let mut top_k = top_k_over(
            vec![chunk_int64(&[19, 88, 33, 8, 319])],
            vec![SortKey::ascending(0)],
            3,
            int64_schema(),
        );
        assert_eq!(collect_int64_col(&mut top_k), vec![8, 19, 33]);
    }

    #[test]
    fn top_k_spans_multiple_input_chunks() {
        let inputs = vec![
            chunk_int64(&[19, 88]),
            chunk_int64(&[33, 8]),
            chunk_int64(&[40, 319]),
        ];
        let mut top_k = top_k_over(inputs, vec![SortKey::descending(0)], 3, int64_schema());
        assert_eq!(collect_int64_col(&mut top_k), vec![319, 88, 40]);
    }

    #[test]
    fn top_k_into_parts_round_trip() {
        let top_k = top_k_over(
            vec![chunk_int64(&[3, 19, 88])],
            vec![SortKey::descending(0)],
            5,
            int64_schema(),
        );
        let (mut child, sort_keys, limit) = top_k.into_parts();
        assert_eq!(sort_keys.len(), 1);
        assert_eq!(limit, 5);
        let chunk = child.next().unwrap().expect("mock yields one chunk");
        assert_eq!(chunk.row_count(), 3);
    }

    #[test]
    fn top_k_name() {
        let top_k = top_k_over(vec![], vec![SortKey::descending(0)], 5, int64_schema());
        assert_eq!(top_k.name(), "TopK");
    }

    #[test]
    fn top_k_into_any_downcasts() {
        let op: Box<dyn Operator> = Box::new(top_k_over(
            vec![],
            vec![SortKey::descending(0)],
            5,
            int64_schema(),
        ));
        let any = op.into_any();
        assert!(any.downcast::<TopKOperator>().is_ok());
    }
}