use hashbrown::{HashMap, HashSet};
use kyu_common::KyuResult;
use kyu_expression::{evaluate, BoundExpression};
use kyu_planner::{AggFunc, AggregateSpec};
use kyu_types::TypedValue;
8
9use crate::context::ExecutionContext;
10use crate::data_chunk::DataChunk;
11use crate::physical_plan::PhysicalOperator;
12
/// Blocking hash-aggregation operator.
///
/// Consumes every chunk produced by `child`, groups rows by the values of
/// `group_by`, and evaluates `aggregates` per group. The whole result is
/// emitted from a single call to `next`.
pub struct AggregateOp {
    /// Input operator whose rows are aggregated.
    pub child: Box<PhysicalOperator>,
    /// Expressions whose evaluated values form the grouping key (empty for a
    /// global aggregate).
    pub group_by: Vec<BoundExpression>,
    /// Aggregate functions computed for each group, in output-column order
    /// (after the grouping columns).
    pub aggregates: Vec<AggregateSpec>,
    /// Set once the aggregation has been emitted; `next` uses it as an
    /// "exhausted" flag and returns `Ok(None)` afterwards.
    result: Option<DataChunk>,
}
19
20impl AggregateOp {
21 pub fn new(
22 child: PhysicalOperator,
23 group_by: Vec<BoundExpression>,
24 aggregates: Vec<AggregateSpec>,
25 ) -> Self {
26 Self {
27 child: Box::new(child),
28 group_by,
29 aggregates,
30 result: None,
31 }
32 }
33
34 pub fn next(&mut self, ctx: &ExecutionContext<'_>) -> KyuResult<Option<DataChunk>> {
35 if self.result.is_some() {
36 return Ok(None);
38 }
39
40 let num_aggs = self.aggregates.len();
42 let num_groups = self.group_by.len();
43
44 let mut groups: HashMap<Vec<TypedValue>, Vec<AccState>> = HashMap::new();
46 let mut insertion_order: Vec<Vec<TypedValue>> = Vec::new();
47
48 while let Some(chunk) = self.child.next(ctx)? {
49 for row_idx in 0..chunk.num_rows() {
50 let row_ref = chunk.row_ref(row_idx);
51
52 let key: Vec<TypedValue> = self
53 .group_by
54 .iter()
55 .map(|expr| evaluate(expr, &row_ref))
56 .collect::<KyuResult<_>>()?;
57
58 let accs = groups.entry(key).or_insert_with_key(|k| {
59 insertion_order.push(k.clone());
60 (0..num_aggs).map(|_| AccState::new()).collect()
61 });
62
63 for (i, agg) in self.aggregates.iter().enumerate() {
64 let val = if let Some(ref arg) = agg.arg {
65 evaluate(arg, &row_ref)?
66 } else {
67 TypedValue::Null
68 };
69 accs[i].accumulate(agg.resolved_func, &val);
70 }
71 }
72 }
73
74 if groups.is_empty() && num_groups == 0 {
77 let key = Vec::new();
78 let accs: Vec<AccState> = (0..num_aggs).map(|_| AccState::new()).collect();
79 groups.insert(key.clone(), accs);
80 insertion_order.push(key);
81 }
82
83 let total_cols = num_groups + num_aggs;
85 let mut result_chunk = DataChunk::with_capacity(total_cols, insertion_order.len());
86
87 for key in &insertion_order {
88 let accs = groups.get(key).unwrap();
89 let mut row = key.clone();
90 for (i, agg) in self.aggregates.iter().enumerate() {
91 row.push(accs[i].finalize(agg.resolved_func));
92 }
93 result_chunk.append_row(&row);
94 }
95
96 self.result = Some(DataChunk::empty(0)); if result_chunk.is_empty() {
99 Ok(None)
100 } else {
101 Ok(Some(result_chunk))
102 }
103 }
104}
105
/// Running state for a single aggregate within a single group.
///
/// One struct serves every `AggFunc`; each function reads and writes only the
/// fields it needs (see `accumulate` / `finalize`).
struct AccState {
    /// Number of values folded in: the COUNT result and the AVG divisor.
    count: i64,
    /// Integer portion of a SUM (Int64 / Int32 inputs).
    sum_i64: i64,
    /// Floating portion of a SUM (Double / Float inputs), or the running AVG total.
    sum_f64: f64,
    /// Smallest non-NULL value seen so far (MIN).
    min: Option<TypedValue>,
    /// Largest non-NULL value seen so far (MAX).
    max: Option<TypedValue>,
    /// Values gathered by COLLECT, in arrival order.
    collected: Vec<TypedValue>,
    /// Set once SUM has seen a floating-point input; `finalize` then returns a Double.
    is_float: bool,
}
116
117impl AccState {
118 fn new() -> Self {
119 Self {
120 count: 0,
121 sum_i64: 0,
122 sum_f64: 0.0,
123 min: None,
124 max: None,
125 collected: Vec::new(),
126 is_float: false,
127 }
128 }
129
130 fn accumulate(&mut self, func: AggFunc, val: &TypedValue) {
131 match func {
132 AggFunc::Count => {
133 self.count += 1;
134 }
135 AggFunc::Sum => {
136 match val {
137 TypedValue::Int64(v) => self.sum_i64 += v,
138 TypedValue::Int32(v) => self.sum_i64 += *v as i64,
139 TypedValue::Double(v) => {
140 self.sum_f64 += v;
141 self.is_float = true;
142 }
143 TypedValue::Float(v) => {
144 self.sum_f64 += *v as f64;
145 self.is_float = true;
146 }
147 _ => {}
148 }
149 self.count += 1;
150 }
151 AggFunc::Avg => {
152 match val {
153 TypedValue::Int64(v) => self.sum_f64 += *v as f64,
154 TypedValue::Int32(v) => self.sum_f64 += *v as f64,
155 TypedValue::Double(v) => self.sum_f64 += v,
156 TypedValue::Float(v) => self.sum_f64 += *v as f64,
157 _ => {}
158 }
159 if *val != TypedValue::Null {
160 self.count += 1;
161 }
162 }
163 AggFunc::Min => {
164 if *val != TypedValue::Null {
165 self.min = Some(match &self.min {
166 None => val.clone(),
167 Some(current) => {
168 if typed_value_lt(val, current) {
169 val.clone()
170 } else {
171 current.clone()
172 }
173 }
174 });
175 }
176 }
177 AggFunc::Max => {
178 if *val != TypedValue::Null {
179 self.max = Some(match &self.max {
180 None => val.clone(),
181 Some(current) => {
182 if typed_value_lt(current, val) {
183 val.clone()
184 } else {
185 current.clone()
186 }
187 }
188 });
189 }
190 }
191 AggFunc::Collect => {
192 self.collected.push(val.clone());
193 }
194 }
195 }
196
197 fn finalize(&self, func: AggFunc) -> TypedValue {
198 match func {
199 AggFunc::Count => TypedValue::Int64(self.count),
200 AggFunc::Sum => {
201 if self.is_float {
202 TypedValue::Double(self.sum_f64 + self.sum_i64 as f64)
203 } else {
204 TypedValue::Int64(self.sum_i64)
205 }
206 }
207 AggFunc::Avg => {
208 if self.count == 0 {
209 TypedValue::Null
210 } else {
211 TypedValue::Double(self.sum_f64 / self.count as f64)
212 }
213 }
214 AggFunc::Min => self.min.clone().unwrap_or(TypedValue::Null),
215 AggFunc::Max => self.max.clone().unwrap_or(TypedValue::Null),
216 AggFunc::Collect => TypedValue::List(self.collected.clone()),
217 }
218 }
219}
220
221fn typed_value_lt(a: &TypedValue, b: &TypedValue) -> bool {
222 match (a, b) {
223 (TypedValue::Int64(a), TypedValue::Int64(b)) => a < b,
224 (TypedValue::Int32(a), TypedValue::Int32(b)) => a < b,
225 (TypedValue::Double(a), TypedValue::Double(b)) => a < b,
226 (TypedValue::Float(a), TypedValue::Float(b)) => a < b,
227 (TypedValue::String(a), TypedValue::String(b)) => a < b,
228 _ => false,
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::MockStorage;
    use kyu_common::id::TableId;
    use kyu_types::LogicalType;
    use smol_str::SmolStr;

    /// Table 0 holds three `(name, value)` rows: ("A", 10), ("B", 20), ("A", 30).
    fn make_storage() -> MockStorage {
        let rows = [("A", 10), ("B", 20), ("A", 30)]
            .iter()
            .map(|&(name, value)| {
                vec![
                    TypedValue::String(SmolStr::new(name)),
                    TypedValue::Int64(value),
                ]
            })
            .collect::<Vec<_>>();
        let mut storage = MockStorage::new();
        storage.insert_table(TableId(0), rows);
        storage
    }

    /// A node scan over `table`.
    fn scan_table(table: TableId) -> PhysicalOperator {
        PhysicalOperator::ScanNode(crate::operators::scan::ScanNodeOp::new(table))
    }

    /// `count(*) AS cnt`.
    fn count_star() -> AggregateSpec {
        AggregateSpec {
            function_name: SmolStr::new("count"),
            resolved_func: AggFunc::Count,
            arg: None,
            distinct: false,
            result_type: LogicalType::Int64,
            alias: SmolStr::new("cnt"),
        }
    }

    #[test]
    fn count_star_no_group_by() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);
        let mut agg = AggregateOp::new(scan_table(TableId(0)), Vec::new(), vec![count_star()]);

        let chunk = agg.next(&ctx).unwrap().unwrap();
        assert_eq!(chunk.num_rows(), 1);
        assert_eq!(chunk.get_row(0), vec![TypedValue::Int64(3)]);
    }

    #[test]
    fn sum_with_group_by() {
        let storage = make_storage();
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);

        let name_col = BoundExpression::Variable {
            index: 0,
            result_type: LogicalType::String,
        };
        let total = AggregateSpec {
            function_name: SmolStr::new("sum"),
            resolved_func: AggFunc::Sum,
            arg: Some(BoundExpression::Variable {
                index: 1,
                result_type: LogicalType::Int64,
            }),
            distinct: false,
            result_type: LogicalType::Int64,
            alias: SmolStr::new("total"),
        };
        let mut agg = AggregateOp::new(scan_table(TableId(0)), vec![name_col], vec![total]);

        let chunk = agg.next(&ctx).unwrap().unwrap();
        assert_eq!(chunk.num_rows(), 2);

        let rows: Vec<_> = (0..2).map(|i| chunk.get_row(i)).collect();
        assert_eq!(
            rows,
            vec![
                vec![TypedValue::String(SmolStr::new("A")), TypedValue::Int64(40)],
                vec![TypedValue::String(SmolStr::new("B")), TypedValue::Int64(20)],
            ]
        );
    }

    #[test]
    fn count_star_empty_input() {
        let storage = MockStorage::new();
        let ctx = ExecutionContext::new(kyu_catalog::CatalogContent::new(), &storage);
        let mut agg = AggregateOp::new(scan_table(TableId(99)), Vec::new(), vec![count_star()]);

        let chunk = agg.next(&ctx).unwrap().unwrap();
        assert_eq!(chunk.num_rows(), 1);
        assert_eq!(chunk.get_row(0), vec![TypedValue::Int64(0)]);
    }
}