quill_sql/execution/physical_plan/
aggregate.rs1use crate::catalog::SchemaRef;
2use crate::error::QuillSQLError;
3use crate::execution::physical_plan::PhysicalPlan;
4use crate::execution::{ExecutionContext, VolcanoExecutor};
5use crate::expression::{Expr, ExprTrait};
6use crate::function::Accumulator;
7use crate::utils::scalar::ScalarValue;
8use crate::{error::QuillSQLResult, storage::tuple::Tuple};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::{Arc, Mutex};
12
13#[derive(Debug)]
14pub struct PhysicalAggregate {
15 pub input: Arc<PhysicalPlan>,
17 pub group_exprs: Vec<Expr>,
19 pub aggr_exprs: Vec<Expr>,
21 pub schema: SchemaRef,
23
24 pub output_rows: Mutex<Vec<Tuple>>,
25 pub cursor: AtomicUsize,
26}
27
28impl PhysicalAggregate {
29 pub fn new(
30 input: Arc<PhysicalPlan>,
31 group_exprs: Vec<Expr>,
32 aggr_exprs: Vec<Expr>,
33 schema: SchemaRef,
34 ) -> Self {
35 Self {
36 input,
37 group_exprs,
38 aggr_exprs,
39 schema,
40 output_rows: Mutex::new(vec![]),
41 cursor: AtomicUsize::new(0),
42 }
43 }
44}
45
46impl PhysicalAggregate {
47 fn build_accumulators(&self) -> QuillSQLResult<Vec<Box<dyn Accumulator>>> {
48 self.aggr_exprs
49 .iter()
50 .map(|expr| {
51 if let Expr::AggregateFunction(aggr) = expr {
52 Ok(aggr.func_kind.create_accumulator())
53 } else {
54 Err(QuillSQLError::Execution(format!(
55 "aggr expr is not AggregateFunction instead of {}",
56 expr
57 )))
58 }
59 })
60 .collect::<QuillSQLResult<Vec<Box<dyn Accumulator>>>>()
61 }
62}
63
64impl VolcanoExecutor for PhysicalAggregate {
65 fn init(&self, context: &mut ExecutionContext) -> QuillSQLResult<()> {
66 self.input.init(context)?;
67 self.cursor.store(0, Ordering::SeqCst);
68 Ok(())
69 }
70
71 fn next(&self, context: &mut ExecutionContext) -> QuillSQLResult<Option<Tuple>> {
72 let output_rows_len = self.output_rows.lock().unwrap().len();
73 if output_rows_len == 0 {
75 let mut groups: HashMap<Vec<ScalarValue>, Vec<Box<dyn Accumulator>>> = HashMap::new();
76 while let Some(tuple) = self.input.next(context)? {
77 let group_key = self
78 .group_exprs
79 .iter()
80 .map(|e| e.evaluate(&tuple))
81 .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
82 let group_accumulators = if let Some(acc) = groups.get_mut(&group_key) {
83 acc
84 } else {
85 let accumulators = self.build_accumulators()?;
86 groups.insert(group_key.clone(), accumulators);
87 groups.get_mut(&group_key).unwrap()
88 };
89 for (idx, acc) in group_accumulators.iter_mut().enumerate() {
90 acc.update_value(&self.aggr_exprs[idx].evaluate(&tuple)?)?;
91 }
92 }
93
94 for (group_key, accumulators) in groups.into_iter() {
95 let mut values = accumulators
96 .iter()
97 .map(|acc| acc.evaluate())
98 .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
99 values.extend(group_key);
100 self.output_rows
101 .lock()
102 .unwrap()
103 .push(Tuple::new(self.schema.clone(), values));
104 }
105 }
106
107 let cursor = self.cursor.fetch_add(1, Ordering::SeqCst);
108 Ok(self.output_rows.lock().unwrap().get(cursor).cloned())
109 }
110
111 fn output_schema(&self) -> SchemaRef {
112 self.schema.clone()
113 }
114}
115
116impl std::fmt::Display for PhysicalAggregate {
117 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118 write!(f, "Aggregate")
119 }
120}