Skip to main content

quill_sql/execution/physical_plan/
aggregate.rs

1//! Hash-based GROUP BY operator aggregating entire input in memory.
2
3use crate::catalog::SchemaRef;
4use crate::error::QuillSQLError;
5use crate::execution::physical_plan::PhysicalPlan;
6use crate::execution::{ExecutionContext, VolcanoExecutor};
7use crate::expression::Expr;
8use crate::function::Accumulator;
9use crate::utils::scalar::ScalarValue;
10use crate::{error::QuillSQLResult, storage::tuple::Tuple};
11use std::cell::RefCell;
12use std::collections::HashMap;
13use std::rc::Rc;
14use std::sync::atomic::{AtomicUsize, Ordering};
15
16#[derive(Debug)]
17pub struct PhysicalAggregate {
18    /// The incoming physical plan
19    pub input: Rc<PhysicalPlan>,
20    /// Grouping expressions
21    pub group_exprs: Vec<Expr>,
22    /// Aggregate expressions
23    pub aggr_exprs: Vec<Expr>,
24    /// The schema description of the aggregate output
25    pub schema: SchemaRef,
26
27    pub output_rows: RefCell<Vec<Tuple>>,
28    pub cursor: AtomicUsize,
29}
30
31impl PhysicalAggregate {
32    pub fn new(
33        input: Rc<PhysicalPlan>,
34        group_exprs: Vec<Expr>,
35        aggr_exprs: Vec<Expr>,
36        schema: SchemaRef,
37    ) -> Self {
38        Self {
39            input,
40            group_exprs,
41            aggr_exprs,
42            schema,
43            output_rows: RefCell::new(vec![]),
44            cursor: AtomicUsize::new(0),
45        }
46    }
47}
48
49impl PhysicalAggregate {
50    fn build_accumulators(&self) -> QuillSQLResult<Vec<Box<dyn Accumulator>>> {
51        self.aggr_exprs
52            .iter()
53            .map(|expr| {
54                if let Expr::AggregateFunction(aggr) = expr {
55                    Ok(aggr.func_kind.create_accumulator())
56                } else {
57                    Err(QuillSQLError::Execution(format!(
58                        "aggr expr is not AggregateFunction instead of {}",
59                        expr
60                    )))
61                }
62            })
63            .collect::<QuillSQLResult<Vec<Box<dyn Accumulator>>>>()
64    }
65}
66
67impl VolcanoExecutor for PhysicalAggregate {
68    fn init(&self, context: &mut ExecutionContext) -> QuillSQLResult<()> {
69        self.input.init(context)?;
70        self.cursor.store(0, Ordering::SeqCst);
71        Ok(())
72    }
73
74    fn next(&self, context: &mut ExecutionContext) -> QuillSQLResult<Option<Tuple>> {
75        let output_rows_len = self.output_rows.borrow().len();
76        // build output rows
77        if output_rows_len == 0 {
78            let mut groups: HashMap<Vec<ScalarValue>, Vec<Box<dyn Accumulator>>> = HashMap::new();
79            while let Some(tuple) = self.input.next(context)? {
80                let group_key = self
81                    .group_exprs
82                    .iter()
83                    .map(|e| context.eval_expr(e, &tuple))
84                    .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
85                let group_accumulators = if let Some(acc) = groups.get_mut(&group_key) {
86                    acc
87                } else {
88                    let accumulators = self.build_accumulators()?;
89                    groups.insert(group_key.clone(), accumulators);
90                    groups.get_mut(&group_key).unwrap()
91                };
92                for (idx, acc) in group_accumulators.iter_mut().enumerate() {
93                    acc.update_value(&context.eval_expr(&self.aggr_exprs[idx], &tuple)?)?;
94                }
95            }
96
97            for (group_key, accumulators) in groups.into_iter() {
98                let mut values = accumulators
99                    .iter()
100                    .map(|acc| acc.evaluate())
101                    .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
102                values.extend(group_key);
103                self.output_rows
104                    .borrow_mut()
105                    .push(Tuple::new(self.schema.clone(), values));
106            }
107        }
108
109        let cursor = self.cursor.fetch_add(1, Ordering::SeqCst);
110        Ok(self.output_rows.borrow().get(cursor).cloned())
111    }
112
113    fn output_schema(&self) -> SchemaRef {
114        self.schema.clone()
115    }
116}
117
118impl std::fmt::Display for PhysicalAggregate {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(f, "Aggregate")
121    }
122}