quill_sql/execution/physical_plan/
aggregate.rs1use crate::catalog::SchemaRef;
4use crate::error::QuillSQLError;
5use crate::execution::physical_plan::PhysicalPlan;
6use crate::execution::{ExecutionContext, VolcanoExecutor};
7use crate::expression::Expr;
8use crate::function::Accumulator;
9use crate::utils::scalar::ScalarValue;
10use crate::{error::QuillSQLResult, storage::tuple::Tuple};
11use std::cell::RefCell;
12use std::collections::HashMap;
13use std::rc::Rc;
14use std::sync::atomic::{AtomicUsize, Ordering};
15
/// Physical aggregation operator.
///
/// Groups the tuples produced by `input` by the values of `group_exprs` and evaluates each of
/// `aggr_exprs` once per group. Each output row holds the aggregate results followed by the
/// group-key values.
#[derive(Debug)]
pub struct PhysicalAggregate {
    /// Child plan whose tuples are aggregated.
    pub input: Rc<PhysicalPlan>,
    /// Expressions evaluated against every input tuple to form its group key.
    pub group_exprs: Vec<Expr>,
    /// Aggregate expressions; each must be an `Expr::AggregateFunction`, otherwise building
    /// accumulators fails with an execution error.
    pub aggr_exprs: Vec<Expr>,
    /// Schema of the emitted rows.
    pub schema: SchemaRef,

    /// Materialized result rows, handed out one at a time by `next`.
    pub output_rows: RefCell<Vec<Tuple>>,
    /// Index of the next row in `output_rows` to return; reset to 0 by `init`.
    pub cursor: AtomicUsize,
}
30
31impl PhysicalAggregate {
32 pub fn new(
33 input: Rc<PhysicalPlan>,
34 group_exprs: Vec<Expr>,
35 aggr_exprs: Vec<Expr>,
36 schema: SchemaRef,
37 ) -> Self {
38 Self {
39 input,
40 group_exprs,
41 aggr_exprs,
42 schema,
43 output_rows: RefCell::new(vec![]),
44 cursor: AtomicUsize::new(0),
45 }
46 }
47}
48
49impl PhysicalAggregate {
50 fn build_accumulators(&self) -> QuillSQLResult<Vec<Box<dyn Accumulator>>> {
51 self.aggr_exprs
52 .iter()
53 .map(|expr| {
54 if let Expr::AggregateFunction(aggr) = expr {
55 Ok(aggr.func_kind.create_accumulator())
56 } else {
57 Err(QuillSQLError::Execution(format!(
58 "aggr expr is not AggregateFunction instead of {}",
59 expr
60 )))
61 }
62 })
63 .collect::<QuillSQLResult<Vec<Box<dyn Accumulator>>>>()
64 }
65}
66
67impl VolcanoExecutor for PhysicalAggregate {
68 fn init(&self, context: &mut ExecutionContext) -> QuillSQLResult<()> {
69 self.input.init(context)?;
70 self.cursor.store(0, Ordering::SeqCst);
71 Ok(())
72 }
73
74 fn next(&self, context: &mut ExecutionContext) -> QuillSQLResult<Option<Tuple>> {
75 let output_rows_len = self.output_rows.borrow().len();
76 if output_rows_len == 0 {
78 let mut groups: HashMap<Vec<ScalarValue>, Vec<Box<dyn Accumulator>>> = HashMap::new();
79 while let Some(tuple) = self.input.next(context)? {
80 let group_key = self
81 .group_exprs
82 .iter()
83 .map(|e| context.eval_expr(e, &tuple))
84 .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
85 let group_accumulators = if let Some(acc) = groups.get_mut(&group_key) {
86 acc
87 } else {
88 let accumulators = self.build_accumulators()?;
89 groups.insert(group_key.clone(), accumulators);
90 groups.get_mut(&group_key).unwrap()
91 };
92 for (idx, acc) in group_accumulators.iter_mut().enumerate() {
93 acc.update_value(&context.eval_expr(&self.aggr_exprs[idx], &tuple)?)?;
94 }
95 }
96
97 for (group_key, accumulators) in groups.into_iter() {
98 let mut values = accumulators
99 .iter()
100 .map(|acc| acc.evaluate())
101 .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
102 values.extend(group_key);
103 self.output_rows
104 .borrow_mut()
105 .push(Tuple::new(self.schema.clone(), values));
106 }
107 }
108
109 let cursor = self.cursor.fetch_add(1, Ordering::SeqCst);
110 Ok(self.output_rows.borrow().get(cursor).cloned())
111 }
112
113 fn output_schema(&self) -> SchemaRef {
114 self.schema.clone()
115 }
116}
117
118impl std::fmt::Display for PhysicalAggregate {
119 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120 write!(f, "Aggregate")
121 }
122}