quill_sql/execution/physical_plan/
aggregate.rs

1use crate::catalog::SchemaRef;
2use crate::error::QuillSQLError;
3use crate::execution::physical_plan::PhysicalPlan;
4use crate::execution::{ExecutionContext, VolcanoExecutor};
5use crate::expression::{Expr, ExprTrait};
6use crate::function::Accumulator;
7use crate::utils::scalar::ScalarValue;
8use crate::{error::QuillSQLResult, storage::tuple::Tuple};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicUsize, Ordering};
11use std::sync::{Arc, Mutex};
12
13#[derive(Debug)]
14pub struct PhysicalAggregate {
15    /// The incoming physical plan
16    pub input: Arc<PhysicalPlan>,
17    /// Grouping expressions
18    pub group_exprs: Vec<Expr>,
19    /// Aggregate expressions
20    pub aggr_exprs: Vec<Expr>,
21    /// The schema description of the aggregate output
22    pub schema: SchemaRef,
23
24    pub output_rows: Mutex<Vec<Tuple>>,
25    pub cursor: AtomicUsize,
26}
27
28impl PhysicalAggregate {
29    pub fn new(
30        input: Arc<PhysicalPlan>,
31        group_exprs: Vec<Expr>,
32        aggr_exprs: Vec<Expr>,
33        schema: SchemaRef,
34    ) -> Self {
35        Self {
36            input,
37            group_exprs,
38            aggr_exprs,
39            schema,
40            output_rows: Mutex::new(vec![]),
41            cursor: AtomicUsize::new(0),
42        }
43    }
44}
45
46impl PhysicalAggregate {
47    fn build_accumulators(&self) -> QuillSQLResult<Vec<Box<dyn Accumulator>>> {
48        self.aggr_exprs
49            .iter()
50            .map(|expr| {
51                if let Expr::AggregateFunction(aggr) = expr {
52                    Ok(aggr.func_kind.create_accumulator())
53                } else {
54                    Err(QuillSQLError::Execution(format!(
55                        "aggr expr is not AggregateFunction instead of {}",
56                        expr
57                    )))
58                }
59            })
60            .collect::<QuillSQLResult<Vec<Box<dyn Accumulator>>>>()
61    }
62}
63
64impl VolcanoExecutor for PhysicalAggregate {
65    fn init(&self, context: &mut ExecutionContext) -> QuillSQLResult<()> {
66        self.input.init(context)?;
67        self.cursor.store(0, Ordering::SeqCst);
68        Ok(())
69    }
70
71    fn next(&self, context: &mut ExecutionContext) -> QuillSQLResult<Option<Tuple>> {
72        let output_rows_len = self.output_rows.lock().unwrap().len();
73        // build output rows
74        if output_rows_len == 0 {
75            let mut groups: HashMap<Vec<ScalarValue>, Vec<Box<dyn Accumulator>>> = HashMap::new();
76            while let Some(tuple) = self.input.next(context)? {
77                let group_key = self
78                    .group_exprs
79                    .iter()
80                    .map(|e| e.evaluate(&tuple))
81                    .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
82                let group_accumulators = if let Some(acc) = groups.get_mut(&group_key) {
83                    acc
84                } else {
85                    let accumulators = self.build_accumulators()?;
86                    groups.insert(group_key.clone(), accumulators);
87                    groups.get_mut(&group_key).unwrap()
88                };
89                for (idx, acc) in group_accumulators.iter_mut().enumerate() {
90                    acc.update_value(&self.aggr_exprs[idx].evaluate(&tuple)?)?;
91                }
92            }
93
94            for (group_key, accumulators) in groups.into_iter() {
95                let mut values = accumulators
96                    .iter()
97                    .map(|acc| acc.evaluate())
98                    .collect::<QuillSQLResult<Vec<ScalarValue>>>()?;
99                values.extend(group_key);
100                self.output_rows
101                    .lock()
102                    .unwrap()
103                    .push(Tuple::new(self.schema.clone(), values));
104            }
105        }
106
107        let cursor = self.cursor.fetch_add(1, Ordering::SeqCst);
108        Ok(self.output_rows.lock().unwrap().get(cursor).cloned())
109    }
110
111    fn output_schema(&self) -> SchemaRef {
112        self.schema.clone()
113    }
114}
115
116impl std::fmt::Display for PhysicalAggregate {
117    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
118        write!(f, "Aggregate")
119    }
120}