1use crate::error::DbxResult;
6use crate::sql::planner::PhysicalExpr;
7use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, RecordBatch};
8use arrow::compute;
9use arrow::datatypes::Schema;
10use rayon::prelude::*;
11use std::sync::Arc;
12
/// Executes filter/aggregate/project operators over slices of Arrow
/// `RecordBatch`es, switching between sequential and rayon-parallel
/// execution based on configurable thresholds.
pub struct ParallelQueryExecutor {
    /// Minimum number of batches before parallel execution is considered.
    parallel_threshold: usize,
    /// Minimum total row count (across all batches) before going parallel.
    min_rows_for_parallel: usize,
    /// Optional dedicated rayon pool; when `None`, the global pool is used.
    thread_pool: Option<Arc<rayon::ThreadPool>>,
}
25
26impl ParallelQueryExecutor {
27 pub fn new() -> Self {
29 Self {
30 parallel_threshold: 2,
31 min_rows_for_parallel: 1000,
32 thread_pool: None,
33 }
34 }
35
36 pub fn with_thread_pool(mut self, pool: Arc<rayon::ThreadPool>) -> Self {
38 self.thread_pool = Some(pool);
39 self
40 }
41
42 pub fn with_threshold(mut self, threshold: usize) -> Self {
44 self.parallel_threshold = threshold;
45 self
46 }
47
48 pub fn with_min_rows(mut self, min_rows: usize) -> Self {
50 self.min_rows_for_parallel = min_rows;
51 self
52 }
53
54 fn should_parallelize(&self, batches: &[RecordBatch]) -> bool {
56 if batches.len() < self.parallel_threshold {
57 return false;
58 }
59 let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
60 total_rows >= self.min_rows_for_parallel
61 }
62
63 pub fn par_filter(
67 &self,
68 batches: &[RecordBatch],
69 predicate: &PhysicalExpr,
70 ) -> DbxResult<Vec<RecordBatch>> {
71 if !self.should_parallelize(batches) {
72 return batches
74 .iter()
75 .filter_map(
76 |batch| match Self::apply_filter_to_batch(batch, predicate) {
77 Ok(Some(b)) if b.num_rows() > 0 => Some(Ok(b)),
78 Ok(_) => None,
79 Err(e) => Some(Err(e)),
80 },
81 )
82 .collect();
83 }
84
85 let results: Vec<DbxResult<Option<RecordBatch>>> = self.run_parallel(batches, |batch| {
87 Self::apply_filter_to_batch(batch, predicate)
88 });
89
90 results
91 .into_iter()
92 .filter_map(|r| match r {
93 Ok(Some(b)) if b.num_rows() > 0 => Some(Ok(b)),
94 Ok(_) => None,
95 Err(e) => Some(Err(e)),
96 })
97 .collect()
98 }
99
100 pub fn par_aggregate(
104 &self,
105 batches: &[RecordBatch],
106 column_idx: usize,
107 agg_type: AggregateType,
108 ) -> DbxResult<AggregateResult> {
109 if batches.is_empty() {
110 return Ok(AggregateResult::empty(agg_type));
111 }
112
113 let partials: Vec<DbxResult<PartialAggregate>> = if self.should_parallelize(batches) {
115 self.run_parallel(batches, |batch| {
116 Self::partial_aggregate(batch, column_idx, agg_type)
117 })
118 } else {
119 batches
120 .iter()
121 .map(|batch| Self::partial_aggregate(batch, column_idx, agg_type))
122 .collect()
123 };
124
125 let mut merged = PartialAggregate::empty(agg_type);
127 for partial in partials {
128 merged.merge(&partial?);
129 }
130
131 Ok(merged.finalize())
132 }
133
134 pub fn par_project(
136 &self,
137 batches: &[RecordBatch],
138 indices: &[usize],
139 ) -> DbxResult<Vec<RecordBatch>> {
140 if !self.should_parallelize(batches) {
141 return batches
142 .iter()
143 .map(|batch| Self::project_batch(batch, indices))
144 .collect();
145 }
146
147 self.run_parallel(batches, |batch| Self::project_batch(batch, indices))
148 .into_iter()
149 .collect()
150 }
151
152 fn apply_filter_to_batch(
156 batch: &RecordBatch,
157 predicate: &PhysicalExpr,
158 ) -> DbxResult<Option<RecordBatch>> {
159 if batch.num_rows() == 0 {
160 return Ok(None);
161 }
162
163 let result = crate::sql::executor::evaluate_expr(predicate, batch)?;
164 let mask = result
165 .as_any()
166 .downcast_ref::<BooleanArray>()
167 .ok_or_else(|| crate::error::DbxError::TypeMismatch {
168 expected: "BooleanArray".to_string(),
169 actual: format!("{:?}", result.data_type()),
170 })?;
171
172 let filtered = compute::filter_record_batch(batch, mask)?;
173 if filtered.num_rows() > 0 {
174 Ok(Some(filtered))
175 } else {
176 Ok(None)
177 }
178 }
179
180 fn project_batch(batch: &RecordBatch, indices: &[usize]) -> DbxResult<RecordBatch> {
182 let columns: Vec<ArrayRef> = indices
183 .iter()
184 .map(|&idx| Arc::clone(batch.column(idx)))
185 .collect();
186 let fields: Vec<_> = indices
187 .iter()
188 .map(|&idx| batch.schema().field(idx).clone())
189 .collect();
190 let schema = Arc::new(Schema::new(fields));
191 Ok(RecordBatch::try_new(schema, columns)?)
192 }
193
194 fn partial_aggregate(
196 batch: &RecordBatch,
197 column_idx: usize,
198 agg_type: AggregateType,
199 ) -> DbxResult<PartialAggregate> {
200 let column = batch.column(column_idx);
201 let mut partial = PartialAggregate::empty(agg_type);
202
203 if let Some(arr) = column.as_any().downcast_ref::<Int64Array>() {
205 for i in 0..arr.len() {
206 if !arr.is_null(i) {
207 let val = arr.value(i) as f64;
208 partial.accumulate(val);
209 }
210 }
211 } else if let Some(arr) = column.as_any().downcast_ref::<Float64Array>() {
212 for i in 0..arr.len() {
213 if !arr.is_null(i) {
214 partial.accumulate(arr.value(i));
215 }
216 }
217 }
218
219 Ok(partial)
220 }
221
222 fn run_parallel<T, F>(&self, batches: &[RecordBatch], op: F) -> Vec<T>
224 where
225 T: Send,
226 F: Fn(&RecordBatch) -> T + Sync,
227 {
228 if let Some(pool) = &self.thread_pool {
229 pool.install(|| batches.par_iter().map(&op).collect())
230 } else {
231 batches.par_iter().map(&op).collect()
232 }
233 }
234}
235
236impl Default for ParallelQueryExecutor {
237 fn default() -> Self {
238 Self::new()
239 }
240}
241
/// Aggregate functions supported by `ParallelQueryExecutor::par_aggregate`.
///
/// Derives `Eq` and `Hash` in addition to `PartialEq` (valid since the enum
/// is fieldless) so it can be used as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggregateType {
    /// Sum of all non-null values.
    Sum,
    /// Number of non-null values.
    Count,
    /// Arithmetic mean of non-null values (0.0 when there are none).
    Avg,
    /// Smallest non-null value.
    Min,
    /// Largest non-null value.
    Max,
}
251
/// Per-batch intermediate aggregation state. Partials computed over disjoint
/// batches are combined with `merge` and turned into a final answer with
/// `finalize`.
#[derive(Debug, Clone)]
pub struct PartialAggregate {
    /// Which aggregate this partial state is being computed for.
    pub agg_type: AggregateType,
    /// Running sum of accumulated values.
    pub sum: f64,
    /// Number of values accumulated so far.
    pub count: u64,
    /// Smallest value seen; initialized to `f64::MAX` as a sentinel.
    pub min: f64,
    /// Largest value seen; initialized to `f64::MIN` as a sentinel.
    pub max: f64,
}
261
262impl PartialAggregate {
263 fn empty(agg_type: AggregateType) -> Self {
264 Self {
265 agg_type,
266 sum: 0.0,
267 count: 0,
268 min: f64::MAX,
269 max: f64::MIN,
270 }
271 }
272
273 fn accumulate(&mut self, val: f64) {
274 self.sum += val;
275 self.count += 1;
276 if val < self.min {
277 self.min = val;
278 }
279 if val > self.max {
280 self.max = val;
281 }
282 }
283
284 fn merge(&mut self, other: &PartialAggregate) {
285 self.sum += other.sum;
286 self.count += other.count;
287 if other.min < self.min {
288 self.min = other.min;
289 }
290 if other.max > self.max {
291 self.max = other.max;
292 }
293 }
294
295 fn finalize(&self) -> AggregateResult {
296 match self.agg_type {
297 AggregateType::Sum => AggregateResult {
298 value: self.sum,
299 count: self.count,
300 },
301 AggregateType::Count => AggregateResult {
302 value: self.count as f64,
303 count: self.count,
304 },
305 AggregateType::Avg => {
306 let avg = if self.count > 0 {
307 self.sum / self.count as f64
308 } else {
309 0.0
310 };
311 AggregateResult {
312 value: avg,
313 count: self.count,
314 }
315 }
316 AggregateType::Min => AggregateResult {
317 value: self.min,
318 count: self.count,
319 },
320 AggregateType::Max => AggregateResult {
321 value: self.max,
322 count: self.count,
323 },
324 }
325 }
326}
327
/// Final result of an aggregation: the computed value plus the number of
/// accumulated input values that contributed to it.
#[derive(Debug, Clone)]
pub struct AggregateResult {
    /// The aggregate value (sum, count-as-f64, avg, min, or max).
    pub value: f64,
    /// Number of values that were accumulated across all batches.
    pub count: u64,
}
334
335impl AggregateResult {
336 fn empty(_agg_type: AggregateType) -> Self {
337 Self {
338 value: 0.0,
339 count: 0,
340 }
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Builds a two-column batch (`id: Int64`, `name: Utf8`) from parallel
    /// slices of ids and names.
    fn make_test_batch(ids: &[i64], names: &[&str]) -> RecordBatch {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, false),
        ];
        let id_col: ArrayRef = Arc::new(Int64Array::from(ids.to_vec()));
        let name_col: ArrayRef = Arc::new(StringArray::from(names.to_vec()));
        RecordBatch::try_new(Arc::new(Schema::new(fields)), vec![id_col, name_col]).unwrap()
    }

    #[test]
    fn test_par_aggregate_sum() {
        let exec = ParallelQueryExecutor::new();
        let input = vec![
            make_test_batch(&[1, 2, 3], &["a", "b", "c"]),
            make_test_batch(&[4, 5, 6], &["d", "e", "f"]),
            make_test_batch(&[7, 8, 9], &["g", "h", "i"]),
        ];

        let summed = exec.par_aggregate(&input, 0, AggregateType::Sum).unwrap();
        assert_eq!(summed.value, 45.0);
        assert_eq!(summed.count, 9);
    }

    #[test]
    fn test_par_aggregate_avg() {
        let exec = ParallelQueryExecutor::new();
        let input = vec![
            make_test_batch(&[10, 20], &["a", "b"]),
            make_test_batch(&[30, 40], &["c", "d"]),
        ];

        let averaged = exec.par_aggregate(&input, 0, AggregateType::Avg).unwrap();
        assert_eq!(averaged.value, 25.0);
    }

    #[test]
    fn test_par_aggregate_min_max() {
        let exec = ParallelQueryExecutor::new();
        let input = vec![
            make_test_batch(&[5, 1, 8], &["a", "b", "c"]),
            make_test_batch(&[3, 9, 2], &["d", "e", "f"]),
        ];

        let smallest = exec.par_aggregate(&input, 0, AggregateType::Min).unwrap();
        assert_eq!(smallest.value, 1.0);

        let largest = exec.par_aggregate(&input, 0, AggregateType::Max).unwrap();
        assert_eq!(largest.value, 9.0);
    }

    #[test]
    fn test_par_project() {
        let exec = ParallelQueryExecutor::new();
        let input = vec![
            make_test_batch(&[1, 2], &["a", "b"]),
            make_test_batch(&[3, 4], &["c", "d"]),
            make_test_batch(&[5, 6], &["e", "f"]),
        ];

        // Keep only the first column ("id") of every batch.
        let projected = exec.par_project(&input, &[0]).unwrap();
        assert_eq!(projected.len(), 3);
        assert_eq!(projected[0].num_columns(), 1);
        assert_eq!(projected[0].schema().field(0).name(), "id");
    }

    #[test]
    fn test_par_aggregate_empty() {
        let exec = ParallelQueryExecutor::new();
        let input: Vec<RecordBatch> = Vec::new();

        let counted = exec.par_aggregate(&input, 0, AggregateType::Count).unwrap();
        assert_eq!(counted.count, 0);
    }
}