Skip to main content

kyu_executor/operators/
aggregate.rs

1//! Aggregate operator — hash-based GROUP BY + aggregation.
2
3use hashbrown::HashMap;
4use kyu_common::KyuResult;
5use kyu_expression::{evaluate, BoundExpression};
6use kyu_planner::{AggFunc, AggregateSpec};
7use kyu_types::TypedValue;
8
9use crate::context::ExecutionContext;
10use crate::data_chunk::DataChunk;
11use crate::physical_plan::PhysicalOperator;
12
13pub struct AggregateOp {
14    pub child: Box<PhysicalOperator>,
15    pub group_by: Vec<BoundExpression>,
16    pub aggregates: Vec<AggregateSpec>,
17    result: Option<DataChunk>,
18}
19
20impl AggregateOp {
21    pub fn new(
22        child: PhysicalOperator,
23        group_by: Vec<BoundExpression>,
24        aggregates: Vec<AggregateSpec>,
25    ) -> Self {
26        Self {
27            child: Box::new(child),
28            group_by,
29            aggregates,
30            result: None,
31        }
32    }
33
34    pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
35        if self.result.is_some() {
36            // Already consumed.
37            return Ok(None);
38        }
39
40        // Drain child and accumulate.
41        let num_aggs = self.aggregates.len();
42        let num_groups = self.group_by.len();
43
44        // Map from group-by key → accumulator states.
45        let mut groups: HashMap<Vec<TypedValue>, Vec<AccState>> = HashMap::new();
46        let mut insertion_order: Vec<Vec<TypedValue>> = Vec::new();
47
48        while let Some(chunk) = self.child.next(ctx)? {
49            for row_idx in 0..chunk.num_rows() {
50                let row_ref = chunk.row_ref(row_idx);
51
52                let key: Vec<TypedValue> = self
53                    .group_by
54                    .iter()
55                    .map(|expr| evaluate(expr, &row_ref))
56                    .collect::<KyuResult<_>>()?;
57
58                let accs = groups.entry(key).or_insert_with_key(|k| {
59                    insertion_order.push(k.clone());
60                    (0..num_aggs).map(|_| AccState::new()).collect()
61                });
62
63                for (i, agg) in self.aggregates.iter().enumerate() {
64                    let val = if let Some(ref arg) = agg.arg {
65                        evaluate(arg, &row_ref)?
66                    } else {
67                        TypedValue::Null
68                    };
69                    accs[i].accumulate(agg.resolved_func, &val);
70                }
71            }
72        }
73
74        // If no groups and no group-by keys (e.g., COUNT(*) with no rows),
75        // produce one row with identity values.
76        if groups.is_empty() && num_groups == 0 {
77            let key = Vec::new();
78            let accs: Vec<AccState> = (0..num_aggs).map(|_| AccState::new()).collect();
79            groups.insert(key.clone(), accs);
80            insertion_order.push(key);
81        }
82
83        // Build result DataChunk.
84        let total_cols = num_groups + num_aggs;
85        let mut result_chunk = DataChunk::with_capacity(total_cols, insertion_order.len());
86
87        for key in &insertion_order {
88            let accs = groups.get(key).unwrap();
89            let mut row = key.clone();
90            for (i, agg) in self.aggregates.iter().enumerate() {
91                row.push(accs[i].finalize(agg.resolved_func));
92            }
93            result_chunk.append_row(&row);
94        }
95
96        self.result = Some(DataChunk::empty(0)); // Mark as consumed.
97
98        if result_chunk.is_empty() {
99            Ok(None)
100        } else {
101            Ok(Some(result_chunk))
102        }
103    }
104}
105
106/// Per-group accumulator state.
107struct AccState {
108    count: i64,
109    sum_i64: i64,
110    sum_f64: f64,
111    min: Option<TypedValue>,
112    max: Option<TypedValue>,
113    collected: Vec<TypedValue>,
114    is_float: bool,
115}
116
117impl AccState {
118    fn new() -> Self {
119        Self {
120            count: 0,
121            sum_i64: 0,
122            sum_f64: 0.0,
123            min: None,
124            max: None,
125            collected: Vec::new(),
126            is_float: false,
127        }
128    }
129
130    fn accumulate(&mut self, func: AggFunc, val: &TypedValue) {
131        match func {
132            AggFunc::Count => {
133                self.count += 1;
134            }
135            AggFunc::Sum => {
136                match val {
137                    TypedValue::Int64(v) => self.sum_i64 += v,
138                    TypedValue::Int32(v) => self.sum_i64 += *v as i64,
139                    TypedValue::Double(v) => {
140                        self.sum_f64 += v;
141                        self.is_float = true;
142                    }
143                    TypedValue::Float(v) => {
144                        self.sum_f64 += *v as f64;
145                        self.is_float = true;
146                    }
147                    _ => {}
148                }
149                self.count += 1;
150            }
151            AggFunc::Avg => {
152                match val {
153                    TypedValue::Int64(v) => self.sum_f64 += *v as f64,
154                    TypedValue::Int32(v) => self.sum_f64 += *v as f64,
155                    TypedValue::Double(v) => self.sum_f64 += v,
156                    TypedValue::Float(v) => self.sum_f64 += *v as f64,
157                    _ => {}
158                }
159                if *val != TypedValue::Null {
160                    self.count += 1;
161                }
162            }
163            AggFunc::Min => {
164                if *val != TypedValue::Null {
165                    self.min = Some(match &self.min {
166                        None => val.clone(),
167                        Some(current) => {
168                            if typed_value_lt(val, current) {
169                                val.clone()
170                            } else {
171                                current.clone()
172                            }
173                        }
174                    });
175                }
176            }
177            AggFunc::Max => {
178                if *val != TypedValue::Null {
179                    self.max = Some(match &self.max {
180                        None => val.clone(),
181                        Some(current) => {
182                            if typed_value_lt(current, val) {
183                                val.clone()
184                            } else {
185                                current.clone()
186                            }
187                        }
188                    });
189                }
190            }
191            AggFunc::Collect => {
192                self.collected.push(val.clone());
193            }
194        }
195    }
196
197    fn finalize(&self, func: AggFunc) -> TypedValue {
198        match func {
199            AggFunc::Count => TypedValue::Int64(self.count),
200            AggFunc::Sum => {
201                if self.is_float {
202                    TypedValue::Double(self.sum_f64 + self.sum_i64 as f64)
203                } else {
204                    TypedValue::Int64(self.sum_i64)
205                }
206            }
207            AggFunc::Avg => {
208                if self.count == 0 {
209                    TypedValue::Null
210                } else {
211                    TypedValue::Double(self.sum_f64 / self.count as f64)
212                }
213            }
214            AggFunc::Min => self.min.clone().unwrap_or(TypedValue::Null),
215            AggFunc::Max => self.max.clone().unwrap_or(TypedValue::Null),
216            AggFunc::Collect => TypedValue::List(self.collected.clone()),
217        }
218    }
219}
220
221fn typed_value_lt(a: &TypedValue, b: &TypedValue) -> bool {
222    match (a, b) {
223        (TypedValue::Int64(a), TypedValue::Int64(b)) => a < b,
224        (TypedValue::Int32(a), TypedValue::Int32(b)) => a < b,
225        (TypedValue::Double(a), TypedValue::Double(b)) => a < b,
226        (TypedValue::Float(a), TypedValue::Float(b)) => a < b,
227        (TypedValue::String(a), TypedValue::String(b)) => a < b,
228        _ => false,
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Wraps a string literal in a `TypedValue::String`.
    fn text(value: &str) -> TypedValue {
        TypedValue::String(SmolStr::new(value))
    }

    /// Storage with table 0 holding the rows ("A", 10), ("B", 20), ("A", 30).
    fn make_storage() -> MockStorage {
        let rows: Vec<Vec<TypedValue>> = [("A", 10), ("B", 20), ("A", 30)]
            .iter()
            .map(|&(label, amount)| vec![text(label), TypedValue::Int64(amount)])
            .collect();
        let mut storage = MockStorage::new();
        storage.insert_table(kyu_common::id::TableId(0), rows);
        storage
    }

    /// Scan operator over the given table.
    fn scan_table(table: kyu_common::id::TableId) -> PhysicalOperator {
        PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(table))
    }

    /// Spec for an argument-less `count` aliased as `cnt`.
    fn count_star_spec() -> AggregateSpec {
        AggregateSpec {
            function_name: SmolStr::new("count"),
            resolved_func: AggFunc::Count,
            arg: None,
            distinct: false,
            result_type: LogicalType::Int64,
            alias: SmolStr::new("cnt"),
        }
    }

    #[test]
    fn count_star_no_group_by() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);
        let mut op = AggregateOp::new(
            scan_table(kyu_common::id::TableId(0)),
            Vec::new(),
            vec![count_star_spec()],
        );

        let out = op.next(&ctx).unwrap().unwrap();

        assert_eq!(out.num_rows(), 1);
        assert_eq!(out.get_row(0), vec![TypedValue::Int64(3)]);
    }

    #[test]
    fn sum_with_group_by() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);
        let group_key = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::String,
        };
        let total = AggregateSpec {
            function_name: SmolStr::new("sum"),
            resolved_func: AggFunc::Sum,
            arg: Some(BoundExpression::Variable {
                index: 1,
                result_type: LogicalType::Int64,
            }),
            distinct: false,
            result_type: LogicalType::Int64,
            alias: SmolStr::new("total"),
        };
        let mut op = AggregateOp::new(
            scan_table(kyu_common::id::TableId(0)),
            vec![group_key],
            vec![total],
        );

        let out = op.next(&ctx).unwrap().unwrap();

        // Two groups, emitted in first-seen order: A then B.
        assert_eq!(out.num_rows(), 2);
        let first = out.get_row(0);
        let second = out.get_row(1);
        assert_eq!(first[0], text("A"));
        assert_eq!(first[1], TypedValue::Int64(40));
        assert_eq!(second[0], text("B"));
        assert_eq!(second[1], TypedValue::Int64(20));
    }

    #[test]
    fn count_star_empty_input() {
        let storage = MockStorage::new();
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);
        // Table 99 was never inserted, so the scan yields no rows.
        let mut op = AggregateOp::new(
            scan_table(kyu_common::id::TableId(99)),
            Vec::new(),
            vec![count_star_spec()],
        );

        let out = op.next(&ctx).unwrap().unwrap();

        assert_eq!(out.num_rows(), 1);
        assert_eq!(out.get_row(0), vec![TypedValue::Int64(0)]);
    }
}