1use crate::error::Result;
4use crate::executor::scan::{ColumnData, RecordBatch};
5use crate::parser::ast::OrderByExpr;
6use rayon::prelude::*;
7use std::cmp::Ordering;
8use std::collections::BinaryHeap;
9use std::sync::Arc;
10
/// Tuning knobs for parallel query execution.
#[derive(Debug, Clone)]
pub struct ParallelConfig {
    /// Desired worker-thread count (defaults to the number of logical CPUs).
    /// NOTE(review): not consulted by the code visible in this file — rayon's
    /// global pool is used; confirm whether a dedicated pool was intended.
    pub num_threads: usize,
    /// When every batch has fewer rows than this, work is done sequentially,
    /// since fan-out overhead would outweigh the benefit.
    pub min_batch_size: usize,
}
19
/// Position of the next unconsumed row during the k-way merge.
#[derive(Debug, Clone, Copy)]
struct BatchCursor {
    /// Index of the batch within the input slice.
    batch_idx: usize,
    /// Row offset within that batch.
    row_idx: usize,
}
28
/// One heap entry in the k-way merge: a cursor into a batch plus the shared
/// context (`batches`, `order_by`) needed to compare rows on the fly.
struct MergeEntry<'a> {
    /// Row this entry represents.
    cursor: BatchCursor,
    /// All input batches; `cursor.batch_idx` indexes into this slice.
    batches: &'a [RecordBatch],
    /// Sort keys driving the comparison.
    order_by: &'a [OrderByExpr],
}
38
39impl<'a> Eq for MergeEntry<'a> {}
40
41impl<'a> PartialEq for MergeEntry<'a> {
42 fn eq(&self, other: &Self) -> bool {
43 self.cmp(other) == Ordering::Equal
44 }
45}
46
47impl<'a> Ord for MergeEntry<'a> {
48 fn cmp(&self, other: &Self) -> Ordering {
49 compare_rows(
51 self.batches,
52 self.cursor.batch_idx,
53 self.cursor.row_idx,
54 other.cursor.batch_idx,
55 other.cursor.row_idx,
56 self.order_by,
57 )
58 .reverse()
59 }
60}
61
impl<'a> PartialOrd for MergeEntry<'a> {
    /// Delegates to the total order defined by `Ord`; never returns `None`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
67
/// An owned copy of a single cell, buffered while the k-way merge
/// re-assembles output columns row by row. A `None` payload is SQL NULL.
#[derive(Debug, Clone)]
enum ColumnValue {
    Boolean(Option<bool>),
    Int32(Option<i32>),
    Int64(Option<i64>),
    Float32(Option<f32>),
    Float64(Option<f64>),
    String(Option<String>),
    Binary(Option<bytes::Bytes>),
}
79
80impl Default for ParallelConfig {
81 fn default() -> Self {
82 Self {
83 num_threads: num_cpus(),
84 min_batch_size: 1000,
85 }
86 }
87}
88
/// Runs per-batch work across rayon's thread pool and provides batch
/// partitioning and ordered-merge utilities.
pub struct ParallelExecutor {
    /// Thresholds controlling when parallelism kicks in.
    config: ParallelConfig,
}
94
95impl ParallelExecutor {
96 pub fn new(config: ParallelConfig) -> Self {
98 Self { config }
99 }
100
101 pub fn execute_parallel<F>(
103 &self,
104 batches: Vec<RecordBatch>,
105 func: F,
106 ) -> Result<Vec<RecordBatch>>
107 where
108 F: Fn(&RecordBatch) -> Result<RecordBatch> + Send + Sync,
109 {
110 if batches.len() < 2
111 || batches
112 .iter()
113 .all(|b| b.num_rows < self.config.min_batch_size)
114 {
115 return batches.into_iter().map(|batch| func(&batch)).collect();
117 }
118
119 let func = Arc::new(func);
121 let results: Result<Vec<_>> = batches
122 .par_iter()
123 .map(|batch| {
124 let f = func.clone();
125 f(batch)
126 })
127 .collect();
128
129 results
130 }
131
132 pub fn partition_batches(
134 &self,
135 batches: Vec<RecordBatch>,
136 num_partitions: usize,
137 ) -> Vec<Vec<RecordBatch>> {
138 if num_partitions == 0 || batches.is_empty() {
139 return vec![batches];
140 }
141
142 let total_rows: usize = batches.iter().map(|b| b.num_rows).sum();
143 let rows_per_partition = total_rows.div_ceil(num_partitions);
144
145 let mut partitions: Vec<Vec<RecordBatch>> = vec![Vec::new(); num_partitions];
146 let mut current_partition = 0;
147 let mut current_partition_rows = 0;
148
149 for batch in batches {
150 if current_partition_rows >= rows_per_partition
151 && current_partition < num_partitions - 1
152 {
153 current_partition += 1;
154 current_partition_rows = 0;
155 }
156
157 current_partition_rows += batch.num_rows;
158 partitions[current_partition].push(batch);
159 }
160
161 partitions
162 }
163
164 pub fn merge_batches(
169 &self,
170 batches: Vec<RecordBatch>,
171 order_by: Option<&[OrderByExpr]>,
172 ) -> Result<Vec<RecordBatch>> {
173 if batches.is_empty() {
174 return Ok(vec![]);
175 }
176
177 let Some(order_by) = order_by else {
179 return Ok(batches);
180 };
181
182 if batches.len() == 1 {
184 return Ok(batches);
185 }
186
187 self.k_way_merge(batches, order_by)
189 }
190
    /// Merges already-sorted batches into one globally sorted batch using a
    /// binary heap of per-batch cursors (classic k-way merge).
    ///
    /// Assumes every input batch shares the schema of `batches[0]` and is
    /// itself sorted by `order_by` — TODO confirm upstream guarantees both.
    ///
    /// # Errors
    /// Propagates failures from `RecordBatch::new` on the merged output.
    fn k_way_merge(
        &self,
        batches: Vec<RecordBatch>,
        order_by: &[OrderByExpr],
    ) -> Result<Vec<RecordBatch>> {
        if batches.is_empty() {
            return Ok(vec![]);
        }

        let schema = batches[0].schema.clone();

        // Seed one cursor per non-empty batch, pointing at its first row.
        let cursors: Vec<BatchCursor> = batches
            .iter()
            .enumerate()
            .filter(|(_, batch)| batch.num_rows > 0)
            .map(|(idx, _)| BatchCursor {
                batch_idx: idx,
                row_idx: 0,
            })
            .collect();

        if cursors.is_empty() {
            return Ok(vec![]);
        }

        // `MergeEntry::cmp` reverses the row comparison, so this max-heap
        // pops the smallest remaining row first.
        let mut heap = BinaryHeap::new();
        for cursor in &cursors {
            heap.push(MergeEntry {
                cursor: *cursor,
                batches: &batches,
                order_by,
            });
        }

        let total_rows: usize = batches.iter().map(|b| b.num_rows).sum();

        // Column-major output buffers, pre-sized to the final row count.
        let num_columns = schema.fields.len();
        let mut output_columns: Vec<Vec<Option<ColumnValue>>> =
            vec![Vec::with_capacity(total_rows); num_columns];

        // Repeatedly take the globally smallest row, copy its cells, and
        // re-insert the source batch's cursor advanced by one row.
        while let Some(entry) = heap.pop() {
            let batch = &batches[entry.cursor.batch_idx];
            let row_idx = entry.cursor.row_idx;

            for (col_idx, column) in batch.columns.iter().enumerate() {
                let value = extract_value(column, row_idx);
                output_columns[col_idx].push(value);
            }

            let next_row = row_idx + 1;
            if next_row < batch.num_rows {
                heap.push(MergeEntry {
                    cursor: BatchCursor {
                        batch_idx: entry.cursor.batch_idx,
                        row_idx: next_row,
                    },
                    batches: &batches,
                    order_by,
                });
            }
        }

        // Convert buffered cells back into typed columns, per schema field.
        let merged_columns: Vec<ColumnData> = output_columns
            .into_iter()
            .zip(schema.fields.iter())
            .map(|(values, field)| values_to_column_data(values, &field.data_type))
            .collect();

        let merged_batch = RecordBatch::new(schema, merged_columns, total_rows)?;

        Ok(vec![merged_batch])
    }
274}
275
/// A linear sequence of asynchronous stages; each stage consumes the
/// previous stage's output batches.
pub struct Pipeline {
    /// Stages executed in insertion order.
    stages: Vec<Box<dyn PipelineStage>>,
}
281
282impl Pipeline {
283 pub fn new() -> Self {
285 Self { stages: Vec::new() }
286 }
287
288 pub fn add_stage<S: PipelineStage + 'static>(mut self, stage: S) -> Self {
290 self.stages.push(Box::new(stage));
291 self
292 }
293
294 pub async fn execute(&self, input: Vec<RecordBatch>) -> Result<Vec<RecordBatch>> {
296 let mut current = input;
297
298 for stage in &self.stages {
299 current = stage.execute(current).await?;
300 }
301
302 Ok(current)
303 }
304}
305
impl Default for Pipeline {
    /// Equivalent to [`Pipeline::new`]: an empty pipeline.
    fn default() -> Self {
        Self::new()
    }
}
311
/// A single asynchronous transformation step in a [`Pipeline`].
#[async_trait::async_trait]
pub trait PipelineStage: Send + Sync {
    /// Transforms the input batches into the batches for the next stage.
    async fn execute(&self, input: Vec<RecordBatch>) -> Result<Vec<RecordBatch>>;
}
318
/// Runs independent closures in parallel via rayon.
pub struct TaskScheduler {
    /// Requested worker count. NOTE(review): `schedule` runs on rayon's
    /// global pool, so this value is informational only — confirm whether
    /// a dedicated pool was intended.
    num_workers: usize,
}
324
325impl TaskScheduler {
326 pub fn new(num_workers: usize) -> Self {
328 Self { num_workers }
329 }
330
331 pub fn schedule<F, T>(&self, tasks: Vec<F>) -> Vec<T>
333 where
334 F: Fn() -> T + Send,
335 T: Send,
336 {
337 tasks.into_par_iter().map(|task| task()).collect()
338 }
339
340 pub fn num_workers(&self) -> usize {
342 self.num_workers
343 }
344}
345
/// Number of logical CPUs available, falling back to 4 when the platform
/// cannot report its parallelism.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 4,
    }
}
352
353fn compare_rows(
355 batches: &[RecordBatch],
356 batch_a: usize,
357 row_a: usize,
358 batch_b: usize,
359 row_b: usize,
360 order_by: &[OrderByExpr],
361) -> Ordering {
362 use crate::parser::ast::Expr;
363
364 let batch_a = &batches[batch_a];
365 let batch_b = &batches[batch_b];
366
367 for order in order_by {
368 let column_name = match &order.expr {
370 Expr::Column { name, .. } => name,
371 _ => continue, };
373
374 let col_idx_a = batch_a.schema.index_of(column_name);
376 let col_idx_b = batch_b.schema.index_of(column_name);
377
378 if let (Some(idx_a), Some(idx_b)) = (col_idx_a, col_idx_b) {
379 let ordering = compare_column_values(
380 &batch_a.columns[idx_a],
381 row_a,
382 &batch_b.columns[idx_b],
383 row_b,
384 order.nulls_first,
385 );
386
387 let ordering = if order.asc {
388 ordering
389 } else {
390 ordering.reverse()
391 };
392
393 if ordering != Ordering::Equal {
394 return ordering;
395 }
396 }
397 }
398
399 Ordering::Equal
400}
401
402fn compare_column_values(
404 col_a: &ColumnData,
405 row_a: usize,
406 col_b: &ColumnData,
407 row_b: usize,
408 nulls_first: bool,
409) -> Ordering {
410 match (col_a, col_b) {
411 (ColumnData::Boolean(data_a), ColumnData::Boolean(data_b)) => {
412 compare_optional(&data_a[row_a], &data_b[row_b], nulls_first)
413 }
414 (ColumnData::Int32(data_a), ColumnData::Int32(data_b)) => {
415 compare_optional(&data_a[row_a], &data_b[row_b], nulls_first)
416 }
417 (ColumnData::Int64(data_a), ColumnData::Int64(data_b)) => {
418 compare_optional(&data_a[row_a], &data_b[row_b], nulls_first)
419 }
420 (ColumnData::Float32(data_a), ColumnData::Float32(data_b)) => {
421 let val_a = &data_a[row_a];
422 let val_b = &data_b[row_b];
423 match (val_a, val_b) {
424 (Some(a), Some(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
425 (Some(_), None) => {
426 if nulls_first {
427 Ordering::Greater
428 } else {
429 Ordering::Less
430 }
431 }
432 (None, Some(_)) => {
433 if nulls_first {
434 Ordering::Less
435 } else {
436 Ordering::Greater
437 }
438 }
439 (None, None) => Ordering::Equal,
440 }
441 }
442 (ColumnData::Float64(data_a), ColumnData::Float64(data_b)) => {
443 let val_a = &data_a[row_a];
444 let val_b = &data_b[row_b];
445 match (val_a, val_b) {
446 (Some(a), Some(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
447 (Some(_), None) => {
448 if nulls_first {
449 Ordering::Greater
450 } else {
451 Ordering::Less
452 }
453 }
454 (None, Some(_)) => {
455 if nulls_first {
456 Ordering::Less
457 } else {
458 Ordering::Greater
459 }
460 }
461 (None, None) => Ordering::Equal,
462 }
463 }
464 (ColumnData::String(data_a), ColumnData::String(data_b)) => {
465 compare_optional(&data_a[row_a], &data_b[row_b], nulls_first)
466 }
467 (ColumnData::Binary(data_a), ColumnData::Binary(data_b)) => {
468 compare_optional(&data_a[row_a], &data_b[row_b], nulls_first)
469 }
470 _ => Ordering::Equal, }
472}
473
/// Null-aware total-order comparison. With `nulls_first` a NULL sorts
/// before every value; otherwise it sorts after. Two NULLs are equal.
fn compare_optional<T: Ord>(a: &Option<T>, b: &Option<T>, nulls_first: bool) -> Ordering {
    // Where a lone NULL on the left-hand side lands relative to a value.
    let null_ord = if nulls_first {
        Ordering::Less
    } else {
        Ordering::Greater
    };

    match (a, b) {
        (Some(x), Some(y)) => x.cmp(y),
        (None, None) => Ordering::Equal,
        (None, Some(_)) => null_ord,
        (Some(_), None) => null_ord.reverse(),
    }
}
495
496fn extract_value(column: &ColumnData, row_idx: usize) -> Option<ColumnValue> {
498 match column {
499 ColumnData::Boolean(data) => Some(ColumnValue::Boolean(data[row_idx])),
500 ColumnData::Int32(data) => Some(ColumnValue::Int32(data[row_idx])),
501 ColumnData::Int64(data) => Some(ColumnValue::Int64(data[row_idx])),
502 ColumnData::Float32(data) => Some(ColumnValue::Float32(data[row_idx])),
503 ColumnData::Float64(data) => Some(ColumnValue::Float64(data[row_idx])),
504 ColumnData::String(data) => Some(ColumnValue::String(data[row_idx].clone())),
505 ColumnData::Binary(data) => Some(ColumnValue::Binary(data[row_idx].clone())),
506 }
507}
508
509fn values_to_column_data(
511 values: Vec<Option<ColumnValue>>,
512 data_type: &crate::executor::scan::DataType,
513) -> ColumnData {
514 use crate::executor::scan::DataType;
515
516 match data_type {
517 DataType::Boolean => {
518 let data: Vec<Option<bool>> = values
519 .into_iter()
520 .map(|v| {
521 v.and_then(|val| {
522 if let ColumnValue::Boolean(b) = val {
523 b
524 } else {
525 None
526 }
527 })
528 })
529 .collect();
530 ColumnData::Boolean(data)
531 }
532 DataType::Int32 => {
533 let data: Vec<Option<i32>> = values
534 .into_iter()
535 .map(|v| {
536 v.and_then(|val| {
537 if let ColumnValue::Int32(i) = val {
538 i
539 } else {
540 None
541 }
542 })
543 })
544 .collect();
545 ColumnData::Int32(data)
546 }
547 DataType::Int64 => {
548 let data: Vec<Option<i64>> = values
549 .into_iter()
550 .map(|v| {
551 v.and_then(|val| {
552 if let ColumnValue::Int64(i) = val {
553 i
554 } else {
555 None
556 }
557 })
558 })
559 .collect();
560 ColumnData::Int64(data)
561 }
562 DataType::Float32 => {
563 let data: Vec<Option<f32>> = values
564 .into_iter()
565 .map(|v| {
566 v.and_then(|val| {
567 if let ColumnValue::Float32(f) = val {
568 f
569 } else {
570 None
571 }
572 })
573 })
574 .collect();
575 ColumnData::Float32(data)
576 }
577 DataType::Float64 => {
578 let data: Vec<Option<f64>> = values
579 .into_iter()
580 .map(|v| {
581 v.and_then(|val| {
582 if let ColumnValue::Float64(f) = val {
583 f
584 } else {
585 None
586 }
587 })
588 })
589 .collect();
590 ColumnData::Float64(data)
591 }
592 DataType::String => {
593 let data: Vec<Option<String>> = values
594 .into_iter()
595 .map(|v| {
596 v.and_then(|val| {
597 if let ColumnValue::String(s) = val {
598 s
599 } else {
600 None
601 }
602 })
603 })
604 .collect();
605 ColumnData::String(data)
606 }
607 DataType::Binary => {
608 let data: Vec<Option<bytes::Bytes>> = values
609 .into_iter()
610 .map(|v| {
611 v.and_then(|val| {
612 if let ColumnValue::Binary(b) = val {
613 b
614 } else {
615 None
616 }
617 })
618 })
619 .collect();
620 ColumnData::Binary(data)
621 }
622 DataType::Geometry => {
623 ColumnData::Binary(vec![None; values.len()])
625 }
626 }
627}
628
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
#[allow(clippy::panic)]
mod tests {
    use super::*;
    use crate::executor::scan::{ColumnData, DataType, Field, Schema};
    use std::sync::Arc;

    // Smoke test: an identity task applied in parallel preserves batch count.
    #[test]
    fn test_parallel_executor() -> Result<()> {
        let config = ParallelConfig::default();
        let executor = ParallelExecutor::new(config);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "value".to_string(),
            DataType::Int64,
            false,
        )]));

        let mut batches = Vec::new();
        for i in 0..5 {
            let columns = vec![ColumnData::Int64(vec![Some(i), Some(i + 1)])];
            batches.push(RecordBatch::new(schema.clone(), columns, 2)?);
        }

        let results = executor.execute_parallel(batches, |batch| Ok(batch.clone()))?;

        assert_eq!(results.len(), 5);

        Ok(())
    }

    // 10 single-row batches into 3 partitions: exactly 3 groups come back.
    #[test]
    fn test_partition_batches() {
        let config = ParallelConfig::default();
        let executor = ParallelExecutor::new(config);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "value".to_string(),
            DataType::Int64,
            false,
        )]));

        let mut batches = Vec::new();
        for i in 0..10 {
            let columns = vec![ColumnData::Int64(vec![Some(i)])];
            if let Ok(batch) = RecordBatch::new(schema.clone(), columns, 1) {
                batches.push(batch);
            }
        }

        let partitions = executor.partition_batches(batches, 3);
        assert_eq!(partitions.len(), 3);
    }

    // Without an ORDER BY, merge_batches must pass batches through unchanged.
    #[test]
    fn test_merge_batches_no_order() -> Result<()> {
        let config = ParallelConfig::default();
        let executor = ParallelExecutor::new(config);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "value".to_string(),
            DataType::Int64,
            false,
        )]));

        let mut batches = Vec::new();
        for i in 0..3 {
            let columns = vec![ColumnData::Int64(vec![Some(i), Some(i + 1)])];
            batches.push(RecordBatch::new(schema.clone(), columns, 2)?);
        }

        let merged = executor.merge_batches(batches, None)?;
        assert_eq!(merged.len(), 3); Ok(())
    }

    // Three interleaved sorted batches merge into one batch with ids 1..=9.
    #[test]
    fn test_merge_batches_with_order() -> Result<()> {
        use crate::parser::ast::{Expr, OrderByExpr};

        let config = ParallelConfig::default();
        let executor = ParallelExecutor::new(config);

        let schema = Arc::new(Schema::new(vec![
            Field::new("id".to_string(), DataType::Int64, false),
            Field::new("value".to_string(), DataType::Int64, false),
        ]));

        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![
                ColumnData::Int64(vec![Some(1), Some(4), Some(7)]),
                ColumnData::Int64(vec![Some(10), Some(40), Some(70)]),
            ],
            3,
        )?;

        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![
                ColumnData::Int64(vec![Some(2), Some(5), Some(8)]),
                ColumnData::Int64(vec![Some(20), Some(50), Some(80)]),
            ],
            3,
        )?;

        let batch3 = RecordBatch::new(
            schema.clone(),
            vec![
                ColumnData::Int64(vec![Some(3), Some(6), Some(9)]),
                ColumnData::Int64(vec![Some(30), Some(60), Some(90)]),
            ],
            3,
        )?;

        let order_by = vec![OrderByExpr {
            expr: Expr::Column {
                table: None,
                name: "id".to_string(),
            },
            asc: true,
            nulls_first: false,
        }];

        let merged = executor.merge_batches(vec![batch1, batch2, batch3], Some(&order_by))?;

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].num_rows, 9);

        let ColumnData::Int64(data) = &merged[0].columns[0] else {
            panic!("Expected Int64 column");
        };
        for i in 0..9 {
            assert_eq!(data[i], Some((i + 1) as i64));
        }

        Ok(())
    }

    // DESC ordering: two descending batches interleave into one descending run.
    #[test]
    fn test_merge_batches_descending() -> Result<()> {
        use crate::parser::ast::{Expr, OrderByExpr};

        let config = ParallelConfig::default();
        let executor = ParallelExecutor::new(config);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "score".to_string(),
            DataType::Float64,
            false,
        )]));

        let batch1 = RecordBatch::new(
            schema.clone(),
            vec![ColumnData::Float64(vec![Some(9.5), Some(7.5), Some(5.5)])],
            3,
        )?;

        let batch2 = RecordBatch::new(
            schema.clone(),
            vec![ColumnData::Float64(vec![Some(8.5), Some(6.5), Some(4.5)])],
            3,
        )?;

        let order_by = vec![OrderByExpr {
            expr: Expr::Column {
                table: None,
                name: "score".to_string(),
            },
            asc: false, nulls_first: false,
        }];

        let merged = executor.merge_batches(vec![batch1, batch2], Some(&order_by))?;

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].num_rows, 6);

        let ColumnData::Float64(data) = &merged[0].columns[0] else {
            panic!("Expected Float64 column");
        };
        let expected = [9.5, 8.5, 7.5, 6.5, 5.5, 4.5];
        for (i, &exp) in expected.iter().enumerate() {
            assert_eq!(data[i], Some(exp));
        }

        Ok(())
    }

    // NULL placement: nulls_first=false puts NULLs last; nulls_first=true first.
    #[test]
    fn test_merge_batches_with_nulls() -> Result<()> {
        use crate::parser::ast::{Expr, OrderByExpr};

        let config = ParallelConfig::default();
        let executor = ParallelExecutor::new(config);

        let schema = Arc::new(Schema::new(vec![Field::new(
            "value".to_string(),
            DataType::Int32,
            true,
        )]));

        let batch1_nulls_last = RecordBatch::new(
            schema.clone(),
            vec![ColumnData::Int32(vec![Some(1), Some(5), None])],
            3,
        )?;

        let batch2_nulls_last = RecordBatch::new(
            schema.clone(),
            vec![ColumnData::Int32(vec![Some(3), Some(7), None])], 3,
        )?;

        let order_by = vec![OrderByExpr {
            expr: Expr::Column {
                table: None,
                name: "value".to_string(),
            },
            asc: true,
            nulls_first: false,
        }];

        let merged =
            executor.merge_batches(vec![batch1_nulls_last, batch2_nulls_last], Some(&order_by))?;

        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].num_rows, 6);

        let ColumnData::Int32(data) = &merged[0].columns[0] else {
            panic!("Expected Int32 column");
        };
        // NULLS LAST: values ascending, then the two NULLs.
        assert_eq!(data[0], Some(1));
        assert_eq!(data[1], Some(3));
        assert_eq!(data[2], Some(5));
        assert_eq!(data[3], Some(7));
        assert_eq!(data[4], None);
        assert_eq!(data[5], None);

        let batch1_nulls_first = RecordBatch::new(
            schema.clone(),
            vec![ColumnData::Int32(vec![None, Some(1), Some(5)])],
            3,
        )?;

        let batch2_nulls_first = RecordBatch::new(
            schema.clone(),
            vec![ColumnData::Int32(vec![None, Some(3), Some(7)])],
            3,
        )?;

        let order_by_nulls_first = vec![OrderByExpr {
            expr: Expr::Column {
                table: None,
                name: "value".to_string(),
            },
            asc: true,
            nulls_first: true,
        }];

        let merged2 = executor.merge_batches(
            vec![batch1_nulls_first, batch2_nulls_first],
            Some(&order_by_nulls_first),
        )?;

        let ColumnData::Int32(data) = &merged2[0].columns[0] else {
            panic!("Expected Int32 column");
        };
        // NULLS FIRST: the two NULLs lead, then values ascending.
        assert_eq!(data[0], None);
        assert_eq!(data[1], None);
        assert_eq!(data[2], Some(1));
        assert_eq!(data[3], Some(3));
        assert_eq!(data[4], Some(5));
        assert_eq!(data[5], Some(7));

        Ok(())
    }
}