// dbx_core/sql/executor/parallel_query.rs
//! Parallel Query Executor — Phase 2: Section 4.2
//!
//! Parallel processing at RecordBatch granularity: scans, filters, and
//! aggregations are parallelized on top of Rayon.
5use crate::error::DbxResult;
6use crate::sql::planner::PhysicalExpr;
7use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, RecordBatch};
8use arrow::compute;
9use arrow::datatypes::Schema;
10use rayon::prelude::*;
11use std::sync::Arc;
12
/// Parallel query executor.
///
/// Processes multiple RecordBatches concurrently on a Rayon work-stealing
/// thread pool. Falls back to sequential execution when the total row
/// count is below `min_rows_for_parallel`.
pub struct ParallelQueryExecutor {
    /// Minimum number of batches required before parallelizing.
    parallel_threshold: usize,
    /// Minimum total row count required before parallelizing.
    min_rows_for_parallel: usize,
    /// Thread pool to run on (the global Rayon pool when `None`).
    thread_pool: Option<Arc<rayon::ThreadPool>>,
}
25
26impl ParallelQueryExecutor {
27    /// 새 병렬 쿼리 실행기 생성
28    pub fn new() -> Self {
29        Self {
30            parallel_threshold: 2,
31            min_rows_for_parallel: 1000,
32            thread_pool: None,
33        }
34    }
35
36    /// 커스텀 스레드 풀 설정
37    pub fn with_thread_pool(mut self, pool: Arc<rayon::ThreadPool>) -> Self {
38        self.thread_pool = Some(pool);
39        self
40    }
41
42    /// 병렬화 batch 수 임계값 설정
43    pub fn with_threshold(mut self, threshold: usize) -> Self {
44        self.parallel_threshold = threshold;
45        self
46    }
47
48    /// 병렬화 최소 행 수 설정
49    pub fn with_min_rows(mut self, min_rows: usize) -> Self {
50        self.min_rows_for_parallel = min_rows;
51        self
52    }
53
54    /// 총 행 수가 임계값 이상인지 판단
55    fn should_parallelize(&self, batches: &[RecordBatch]) -> bool {
56        if batches.len() < self.parallel_threshold {
57            return false;
58        }
59        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
60        total_rows >= self.min_rows_for_parallel
61    }
62
63    /// 병렬 테이블 스캔 + 필터
64    ///
65    /// 여러 RecordBatch를 병렬로 필터링합니다.
66    pub fn par_filter(
67        &self,
68        batches: &[RecordBatch],
69        predicate: &PhysicalExpr,
70    ) -> DbxResult<Vec<RecordBatch>> {
71        if !self.should_parallelize(batches) {
72            // Sequential fallback (소규모 데이터)
73            return batches
74                .iter()
75                .filter_map(
76                    |batch| match Self::apply_filter_to_batch(batch, predicate) {
77                        Ok(Some(b)) if b.num_rows() > 0 => Some(Ok(b)),
78                        Ok(_) => None,
79                        Err(e) => Some(Err(e)),
80                    },
81                )
82                .collect();
83        }
84
85        // Parallel
86        let results: Vec<DbxResult<Option<RecordBatch>>> = self.run_parallel(batches, |batch| {
87            Self::apply_filter_to_batch(batch, predicate)
88        });
89
90        results
91            .into_iter()
92            .filter_map(|r| match r {
93                Ok(Some(b)) if b.num_rows() > 0 => Some(Ok(b)),
94                Ok(_) => None,
95                Err(e) => Some(Err(e)),
96            })
97            .collect()
98    }
99
100    /// 병렬 집계 (SUM, COUNT, AVG, MIN, MAX)
101    ///
102    /// 각 batch를 병렬로 부분 집계 후, 최종 집계합니다.
103    pub fn par_aggregate(
104        &self,
105        batches: &[RecordBatch],
106        column_idx: usize,
107        agg_type: AggregateType,
108    ) -> DbxResult<AggregateResult> {
109        if batches.is_empty() {
110            return Ok(AggregateResult::empty(agg_type));
111        }
112
113        // 행 수 기반 순차/병렬 분기
114        let partials: Vec<DbxResult<PartialAggregate>> = if self.should_parallelize(batches) {
115            self.run_parallel(batches, |batch| {
116                Self::partial_aggregate(batch, column_idx, agg_type)
117            })
118        } else {
119            batches
120                .iter()
121                .map(|batch| Self::partial_aggregate(batch, column_idx, agg_type))
122                .collect()
123        };
124
125        // Phase 2: merge partial results
126        let mut merged = PartialAggregate::empty(agg_type);
127        for partial in partials {
128            merged.merge(&partial?);
129        }
130
131        Ok(merged.finalize())
132    }
133
134    /// 병렬 프로젝션 (컬럼 선택)
135    pub fn par_project(
136        &self,
137        batches: &[RecordBatch],
138        indices: &[usize],
139    ) -> DbxResult<Vec<RecordBatch>> {
140        if !self.should_parallelize(batches) {
141            return batches
142                .iter()
143                .map(|batch| Self::project_batch(batch, indices))
144                .collect();
145        }
146
147        self.run_parallel(batches, |batch| Self::project_batch(batch, indices))
148            .into_iter()
149            .collect()
150    }
151
152    // ─── Internal helpers ───────────────────────────────
153
154    /// 단일 batch에 필터 적용
155    fn apply_filter_to_batch(
156        batch: &RecordBatch,
157        predicate: &PhysicalExpr,
158    ) -> DbxResult<Option<RecordBatch>> {
159        if batch.num_rows() == 0 {
160            return Ok(None);
161        }
162
163        let result = crate::sql::executor::evaluate_expr(predicate, batch)?;
164        let mask = result
165            .as_any()
166            .downcast_ref::<BooleanArray>()
167            .ok_or_else(|| crate::error::DbxError::TypeMismatch {
168                expected: "BooleanArray".to_string(),
169                actual: format!("{:?}", result.data_type()),
170            })?;
171
172        let filtered = compute::filter_record_batch(batch, mask)?;
173        if filtered.num_rows() > 0 {
174            Ok(Some(filtered))
175        } else {
176            Ok(None)
177        }
178    }
179
180    /// 단일 batch에 프로젝션 적용
181    fn project_batch(batch: &RecordBatch, indices: &[usize]) -> DbxResult<RecordBatch> {
182        let columns: Vec<ArrayRef> = indices
183            .iter()
184            .map(|&idx| Arc::clone(batch.column(idx)))
185            .collect();
186        let fields: Vec<_> = indices
187            .iter()
188            .map(|&idx| batch.schema().field(idx).clone())
189            .collect();
190        let schema = Arc::new(Schema::new(fields));
191        Ok(RecordBatch::try_new(schema, columns)?)
192    }
193
194    /// 단일 batch에 대한 부분 집계
195    fn partial_aggregate(
196        batch: &RecordBatch,
197        column_idx: usize,
198        agg_type: AggregateType,
199    ) -> DbxResult<PartialAggregate> {
200        let column = batch.column(column_idx);
201        let mut partial = PartialAggregate::empty(agg_type);
202
203        // Try as Int64 first, then Float64
204        if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() {
205            for i in 0..arr.len() {
206                if !arr.is_null(i) {
207                    let val = arr.value(i) as f64;
208                    partial.accumulate(val);
209                }
210            }
211        } else if let Some(arr) = column.as_any().downcast_ref::<Float64Array>() {
212            for i in 0..arr.len() {
213                if !arr.is_null(i) {
214                    partial.accumulate(arr.value(i));
215                }
216            }
217        }
218
219        Ok(partial)
220    }
221
222    /// Rayon 기반 병렬 실행 (스레드 풀 사용)
223    fn run_parallel<T, F>(&self, batches: &[RecordBatch], op: F) -> Vec<T>
224    where
225        T: Send,
226        F: Fn(&RecordBatch) -> T + Sync,
227    {
228        if let Some(pool) = &self.thread_pool {
229            pool.install(|| batches.par_iter().map(&op).collect())
230        } else {
231            batches.par_iter().map(&op).collect()
232        }
233    }
234}
235
236impl Default for ParallelQueryExecutor {
237    fn default() -> Self {
238        Self::new()
239    }
240}
241
/// Kind of aggregate operation to compute.
///
/// Fieldless and `Copy`; also derives `Eq` and `Hash` so it can be used
/// as a map/set key (e.g. when caching per-aggregate state).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateType {
    Sum,
    Count,
    Avg,
    Min,
    Max,
}
251
/// Partial aggregate result (mergeable across parallel workers).
#[derive(Debug, Clone)]
pub struct PartialAggregate {
    /// Which aggregate this partial is computing.
    pub agg_type: AggregateType,
    /// Running sum of accumulated values.
    pub sum: f64,
    /// Number of non-null values accumulated so far.
    pub count: u64,
    /// Smallest value seen so far (sentinel extreme while count == 0).
    pub min: f64,
    /// Largest value seen so far (sentinel extreme while count == 0).
    pub max: f64,
}
261
262impl PartialAggregate {
263    fn empty(agg_type: AggregateType) -> Self {
264        Self {
265            agg_type,
266            sum: 0.0,
267            count: 0,
268            min: f64::MAX,
269            max: f64::MIN,
270        }
271    }
272
273    fn accumulate(&mut self, val: f64) {
274        self.sum += val;
275        self.count += 1;
276        if val < self.min {
277            self.min = val;
278        }
279        if val > self.max {
280            self.max = val;
281        }
282    }
283
284    fn merge(&mut self, other: &PartialAggregate) {
285        self.sum += other.sum;
286        self.count += other.count;
287        if other.min < self.min {
288            self.min = other.min;
289        }
290        if other.max > self.max {
291            self.max = other.max;
292        }
293    }
294
295    fn finalize(&self) -> AggregateResult {
296        match self.agg_type {
297            AggregateType::Sum => AggregateResult {
298                value: self.sum,
299                count: self.count,
300            },
301            AggregateType::Count => AggregateResult {
302                value: self.count as f64,
303                count: self.count,
304            },
305            AggregateType::Avg => {
306                let avg = if self.count > 0 {
307                    self.sum / self.count as f64
308                } else {
309                    0.0
310                };
311                AggregateResult {
312                    value: avg,
313                    count: self.count,
314                }
315            }
316            AggregateType::Min => AggregateResult {
317                value: self.min,
318                count: self.count,
319            },
320            AggregateType::Max => AggregateResult {
321                value: self.max,
322                count: self.count,
323            },
324        }
325    }
326}
327
/// Final aggregate result.
#[derive(Debug, Clone)]
pub struct AggregateResult {
    /// The aggregate value (sum, count-as-f64, average, min, or max).
    pub value: f64,
    /// Number of values that contributed to `value`.
    pub count: u64,
}
334
335impl AggregateResult {
336    fn empty(_agg_type: AggregateType) -> Self {
337        Self {
338            value: 0.0,
339            count: 0,
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds a two-column (id: Int64, name: Utf8) batch for tests.
    fn make_test_batch(ids: &[i64], names: &[&str]) -> RecordBatch {
        let id_col = Arc::new(Int64Array::from(ids.to_vec()));
        let name_col = Arc::new(StringArray::from(names.to_vec()));
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        RecordBatch::try_new(Arc::new(schema), vec![id_col, name_col]).unwrap()
    }

    #[test]
    fn test_par_aggregate_sum() {
        let batches = vec![
            make_test_batch(&[1, 2, 3], &["a", "b", "c"]),
            make_test_batch(&[4, 5, 6], &["d", "e", "f"]),
            make_test_batch(&[7, 8, 9], &["g", "h", "i"]),
        ];

        let result = ParallelQueryExecutor::new()
            .par_aggregate(&batches, 0, AggregateType::Sum)
            .unwrap();

        // 1 + 2 + … + 9 = 45, over 9 rows.
        assert_eq!(result.value, 45.0);
        assert_eq!(result.count, 9);
    }

    #[test]
    fn test_par_aggregate_avg() {
        let batches = vec![
            make_test_batch(&[10, 20], &["a", "b"]),
            make_test_batch(&[30, 40], &["c", "d"]),
        ];

        let result = ParallelQueryExecutor::new()
            .par_aggregate(&batches, 0, AggregateType::Avg)
            .unwrap();

        // (10 + 20 + 30 + 40) / 4 = 25
        assert_eq!(result.value, 25.0);
    }

    #[test]
    fn test_par_aggregate_min_max() {
        let batches = vec![
            make_test_batch(&[5, 1, 8], &["a", "b", "c"]),
            make_test_batch(&[3, 9, 2], &["d", "e", "f"]),
        ];
        let executor = ParallelQueryExecutor::new();

        let min_result = executor
            .par_aggregate(&batches, 0, AggregateType::Min)
            .unwrap();
        let max_result = executor
            .par_aggregate(&batches, 0, AggregateType::Max)
            .unwrap();

        assert_eq!(min_result.value, 1.0);
        assert_eq!(max_result.value, 9.0);
    }

    #[test]
    fn test_par_project() {
        let batches = vec![
            make_test_batch(&[1, 2], &["a", "b"]),
            make_test_batch(&[3, 4], &["c", "d"]),
            make_test_batch(&[5, 6], &["e", "f"]),
        ];

        let projected = ParallelQueryExecutor::new()
            .par_project(&batches, &[0])
            .unwrap();

        // One output batch per input batch, each reduced to the `id` column.
        assert_eq!(projected.len(), 3);
        assert_eq!(projected[0].num_columns(), 1);
        assert_eq!(projected[0].schema().field(0).name(), "id");
    }

    #[test]
    fn test_par_aggregate_empty() {
        let no_batches: Vec<RecordBatch> = Vec::new();

        let result = ParallelQueryExecutor::new()
            .par_aggregate(&no_batches, 0, AggregateType::Count)
            .unwrap();

        assert_eq!(result.count, 0);
    }
}