1use alloc::boxed::Box;
24use alloc::collections::BTreeMap;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
65 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
66 contains_aggregate(expr)
67 }
68 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
69 Expr::Extract { source, .. } => contains_aggregate(source),
70 Expr::ScalarSubquery(_)
75 | Expr::Exists { .. }
76 | Expr::InSubquery { .. }
77 | Expr::WindowFunction { .. }
78 | Expr::Literal(_)
79 | Expr::Placeholder(_)
80 | Expr::Column(_) => false,
81 Expr::Array(items) => items.iter().any(contains_aggregate),
85 Expr::ArraySubscript { target, index } => {
86 contains_aggregate(target) || contains_aggregate(index)
87 }
88 Expr::AnyAll { expr, array, .. } => {
89 contains_aggregate(expr) || contains_aggregate(array)
90 }
91 }
92}
93
/// Returns `true` when `name` is one of the supported aggregate functions
/// (`count`, `count_star`, `sum`, `min`, `max`, `avg`).
///
/// The comparison ignores ASCII case and performs no allocation, which
/// matters because this predicate is called for every function-call node
/// during expression walks.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATE_NAMES: [&str; 6] = ["count", "count_star", "sum", "min", "max", "avg"];
    AGGREGATE_NAMES
        .iter()
        .any(|agg| name.eq_ignore_ascii_case(agg))
}
100
101#[derive(Debug, Default, Clone)]
103struct AggState {
104 count: i64,
105 sum_int: i64,
106 sum_float: f64,
107 extreme: Option<Value>,
108 use_float: bool,
109}
110
111#[derive(Debug, Clone)]
112struct AggSpec {
113 name: String, arg: Option<Expr>,
116}
117
118#[derive(Debug)]
121pub struct AggResult {
122 pub columns: Vec<ColumnSchema>,
123 pub rows: Vec<Row>,
124}
125
126#[allow(clippy::too_many_lines)]
129pub fn run(
130 stmt: &SelectStatement,
131 rows: &[&Row],
132 schema_cols: &[ColumnSchema],
133 table_alias: Option<&str>,
134) -> Result<AggResult, EvalError> {
135 let ctx = EvalContext::new(schema_cols, table_alias);
136 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
137
138 let mut agg_specs: Vec<AggSpec> = Vec::new();
140 for item in &stmt.items {
141 if let SelectItem::Expr { expr, .. } = item {
142 collect_aggregates(expr, &mut agg_specs);
143 }
144 }
145 for o in &stmt.order_by {
146 collect_aggregates(&o.expr, &mut agg_specs);
147 }
148 if let Some(h) = &stmt.having {
149 collect_aggregates(h, &mut agg_specs);
150 }
151
152 let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
155 let mut key_order: Vec<String> = Vec::new();
156 if rows.is_empty() && group_exprs.is_empty() {
159 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
161 groups.insert(String::new(), (Vec::new(), init));
162 key_order.push(String::new());
163 }
164
165 for row in rows {
166 let group_vals: Vec<Value> = group_exprs
167 .iter()
168 .map(|g| eval::eval_expr(g, row, &ctx))
169 .collect::<Result<_, _>>()?;
170 let key = encode_key(&group_vals);
171 let entry = groups.entry(key.clone()).or_insert_with(|| {
172 key_order.push(key.clone());
173 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
174 (group_vals.clone(), init)
175 });
176 for (i, spec) in agg_specs.iter().enumerate() {
177 let arg_val = match &spec.arg {
178 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
180 };
181 update_state(&mut entry.1[i], &spec.name, &arg_val)?;
182 }
183 }
184
185 let group_types: Vec<DataType> = if rows.is_empty() {
187 group_exprs.iter().map(|_| DataType::Text).collect()
190 } else {
191 let probe = rows[0];
192 group_exprs
193 .iter()
194 .map(|g| {
195 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
196 })
197 .collect::<Result<_, _>>()?
198 };
199 let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
200 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
201 for (i, ty) in group_types.iter().enumerate() {
202 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
203 }
204 for (i, ty) in agg_types.iter().enumerate() {
205 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
206 }
207
208 let mut synth_rows: Vec<Row> = Vec::new();
210 for k in &key_order {
211 let (gvals, states) = &groups[k];
212 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
213 values.extend(gvals.iter().cloned());
214 for (i, st) in states.iter().enumerate() {
215 values.push(finalize(&agg_specs[i].name, st));
216 }
217 synth_rows.push(Row::new(values));
218 }
219
220 let columns: Vec<ColumnSchema> = stmt
225 .items
226 .iter()
227 .map(|item| match item {
228 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
229 detail: "SELECT * with aggregates is not supported".into(),
230 }),
231 SelectItem::Expr { expr, alias } => {
232 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
233 let name = alias.clone().unwrap_or_else(|| expr.to_string());
234 Ok(ColumnSchema::new(
235 name,
236 agg_or_group_type(&rewritten, &synth_schema),
237 true,
238 ))
239 }
240 })
241 .collect::<Result<_, _>>()?;
242
243 let synth_ctx = EvalContext::new(&synth_schema, None);
248 let having_rewritten = stmt
249 .having
250 .as_ref()
251 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
252 let mut kept_synth: Vec<Row> = Vec::new();
253 let mut out_rows: Vec<Row> = Vec::new();
254 for srow in synth_rows {
255 if let Some(h) = &having_rewritten {
256 let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
257 if !matches!(cond, Value::Bool(true)) {
258 continue;
259 }
260 }
261 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
262 for item in &stmt.items {
263 if let SelectItem::Expr { expr, .. } = item {
264 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
265 values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
266 }
267 }
268 kept_synth.push(srow);
269 out_rows.push(Row::new(values));
270 }
271
272 if !stmt.order_by.is_empty() {
275 let rewritten: Vec<Expr> = stmt
278 .order_by
279 .iter()
280 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
281 .collect();
282 let descs: Vec<bool> = stmt.order_by.iter().map(|o| o.desc).collect();
283 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
284 .into_iter()
285 .zip(out_rows)
286 .map(|(s, o)| {
287 let mut keys = Vec::with_capacity(rewritten.len());
288 for e in &rewritten {
289 keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
290 }
291 Ok::<_, EvalError>((keys, o))
292 })
293 .collect::<Result<_, _>>()?;
294 tagged.sort_by(|a, b| {
295 use core::cmp::Ordering;
296 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
297 let cmp = value_cmp(ka, kb);
298 let cmp = if descs[i] { cmp.reverse() } else { cmp };
299 if cmp != Ordering::Equal {
300 return cmp;
301 }
302 }
303 Ordering::Equal
304 });
305 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
306 }
307
308 Ok(AggResult {
309 columns,
310 rows: out_rows,
311 })
312}
313
314fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
315 match e {
316 Expr::FunctionCall { name, args } => {
317 let lower = name.to_ascii_lowercase();
318 if is_aggregate_name(&lower) {
319 let arg = if lower == "count_star" {
320 None
321 } else {
322 args.first().cloned()
323 };
324 let spec = AggSpec {
325 name: lower,
326 arg: arg.clone(),
327 };
328 if !out.iter().any(|s| s.name == spec.name && s.arg == spec.arg) {
329 out.push(spec);
330 }
331 } else {
334 for a in args {
335 collect_aggregates(a, out);
336 }
337 }
338 }
339 Expr::Binary { lhs, rhs, .. } => {
340 collect_aggregates(lhs, out);
341 collect_aggregates(rhs, out);
342 }
343 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
344 collect_aggregates(expr, out);
345 }
346 Expr::Like { expr, pattern, .. } => {
347 collect_aggregates(expr, out);
348 collect_aggregates(pattern, out);
349 }
350 Expr::Extract { source, .. } => collect_aggregates(source, out),
351 Expr::ScalarSubquery(_)
354 | Expr::Exists { .. }
355 | Expr::InSubquery { .. }
356 | Expr::WindowFunction { .. }
357 | Expr::Literal(_)
358 | Expr::Placeholder(_)
359 | Expr::Column(_) => {}
360 Expr::Array(items) => {
363 for elem in items {
364 collect_aggregates(elem, out);
365 }
366 }
367 Expr::ArraySubscript { target, index } => {
368 collect_aggregates(target, out);
369 collect_aggregates(index, out);
370 }
371 Expr::AnyAll { expr, array, .. } => {
372 collect_aggregates(expr, out);
373 collect_aggregates(array, out);
374 }
375 }
376}
377
378fn update_state(st: &mut AggState, name: &str, v: &Value) -> Result<(), EvalError> {
379 let is_null = matches!(v, Value::Null);
380 match name {
381 "count_star" => st.count += 1,
382 "count" => {
383 if !is_null {
384 st.count += 1;
385 }
386 }
387 "sum" | "avg" => {
388 if is_null {
389 return Ok(());
390 }
391 st.count += 1;
392 match v {
393 Value::Int(n) => st.sum_int += i64::from(*n),
394 Value::BigInt(n) => st.sum_int += *n,
395 Value::Float(x) => {
396 st.use_float = true;
397 st.sum_float += *x;
398 }
399 other => {
400 return Err(EvalError::TypeMismatch {
401 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
402 });
403 }
404 }
405 }
406 "min" => {
407 if is_null {
408 return Ok(());
409 }
410 match &st.extreme {
411 None => st.extreme = Some(v.clone()),
412 Some(cur) => {
413 if value_cmp(v, cur) == core::cmp::Ordering::Less {
414 st.extreme = Some(v.clone());
415 }
416 }
417 }
418 }
419 "max" => {
420 if is_null {
421 return Ok(());
422 }
423 match &st.extreme {
424 None => st.extreme = Some(v.clone()),
425 Some(cur) => {
426 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
427 st.extreme = Some(v.clone());
428 }
429 }
430 }
431 }
432 _ => unreachable!("non-aggregate {name} in update_state"),
433 }
434 Ok(())
435}
436
437#[allow(clippy::cast_precision_loss)]
438fn finalize(name: &str, st: &AggState) -> Value {
439 match name {
440 "count" | "count_star" => Value::BigInt(st.count),
441 "sum" => {
442 if st.count == 0 {
443 Value::Null
444 } else if st.use_float {
445 Value::Float(st.sum_float + (st.sum_int as f64))
446 } else {
447 Value::BigInt(st.sum_int)
448 }
449 }
450 "avg" => {
451 if st.count == 0 {
452 Value::Null
453 } else {
454 let total = if st.use_float {
455 st.sum_float + (st.sum_int as f64)
456 } else {
457 st.sum_int as f64
458 };
459 Value::Float(total / (st.count as f64))
460 }
461 }
462 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
463 _ => unreachable!(),
464 }
465}
466
467fn infer_agg_type(spec: &AggSpec) -> DataType {
468 match spec.name.as_str() {
469 "count" | "count_star" | "sum" => DataType::BigInt,
473 "avg" => DataType::Float,
474 _ => DataType::Text,
477 }
478}
479
480fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
481 if let Expr::Column(c) = e
482 && let Some(s) = synth.iter().find(|s| s.name == c.name)
483 {
484 return s.ty;
485 }
486 DataType::Text
489}
490
491fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
492 if let Expr::FunctionCall { name, args } = e {
494 let lower = name.to_ascii_lowercase();
495 if is_aggregate_name(&lower) {
496 let arg = if lower == "count_star" {
497 None
498 } else {
499 args.first().cloned()
500 };
501 for (i, spec) in aggs.iter().enumerate() {
502 if spec.name == lower && spec.arg == arg {
503 return Expr::Column(spg_sql::ast::ColumnName {
504 qualifier: None,
505 name: format!("__agg_{i}"),
506 });
507 }
508 }
509 }
510 }
511 for (i, g) in group_exprs.iter().enumerate() {
513 if g == e {
514 return Expr::Column(spg_sql::ast::ColumnName {
515 qualifier: None,
516 name: format!("__grp_{i}"),
517 });
518 }
519 }
520 match e {
522 Expr::Binary { lhs, op, rhs } => Expr::Binary {
523 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
524 op: *op,
525 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
526 },
527 Expr::Unary { op, expr } => Expr::Unary {
528 op: *op,
529 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
530 },
531 Expr::Cast { expr, target } => Expr::Cast {
532 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
533 target: *target,
534 },
535 Expr::IsNull { expr, negated } => Expr::IsNull {
536 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
537 negated: *negated,
538 },
539 Expr::FunctionCall { name, args } => Expr::FunctionCall {
540 name: name.clone(),
541 args: args
542 .iter()
543 .map(|a| rewrite_expr(a, group_exprs, aggs))
544 .collect(),
545 },
546 Expr::Like {
547 expr,
548 pattern,
549 negated,
550 } => Expr::Like {
551 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
552 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
553 negated: *negated,
554 },
555 Expr::Extract { field, source } => Expr::Extract {
556 field: *field,
557 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
558 },
559 Expr::ScalarSubquery(_)
562 | Expr::Exists { .. }
563 | Expr::InSubquery { .. }
564 | Expr::WindowFunction { .. }
565 | Expr::Literal(_)
566 | Expr::Placeholder(_)
567 | Expr::Column(_) => e.clone(),
568 Expr::Array(items) => Expr::Array(
570 items
571 .iter()
572 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
573 .collect(),
574 ),
575 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
576 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
577 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
578 },
579 Expr::AnyAll {
580 expr,
581 op,
582 array,
583 is_any,
584 } => Expr::AnyAll {
585 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
586 op: *op,
587 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
588 is_any: *is_any,
589 },
590 }
591}
592
593fn encode_key(vals: &[Value]) -> String {
595 let mut out = String::new();
596 for v in vals {
597 match v {
598 Value::Null => out.push_str("N|"),
599 Value::SmallInt(n) => {
600 out.push('s');
601 out.push_str(&n.to_string());
602 out.push('|');
603 }
604 Value::Int(n) => {
605 out.push('I');
606 out.push_str(&n.to_string());
607 out.push('|');
608 }
609 Value::BigInt(n) => {
610 out.push('B');
611 out.push_str(&n.to_string());
612 out.push('|');
613 }
614 Value::Float(x) => {
615 out.push('F');
616 out.push_str(&x.to_string());
617 out.push('|');
618 }
619 Value::Bool(b) => {
620 out.push(if *b { 'T' } else { 'f' });
621 out.push('|');
622 }
623 Value::Text(s) => {
624 out.push('S');
625 out.push_str(s);
626 out.push('|');
627 }
628 Value::Vector(v) => {
629 out.push('V');
630 for x in v {
631 out.push_str(&x.to_string());
632 out.push(',');
633 }
634 out.push('|');
635 }
636 Value::Sq8Vector(q) => {
642 out.push('Q');
643 out.push_str(&q.min.to_string());
644 out.push('@');
645 out.push_str(&q.max.to_string());
646 out.push(':');
647 for b in &q.bytes {
648 out.push_str(&b.to_string());
649 out.push(',');
650 }
651 out.push('|');
652 }
653 Value::HalfVector(h) => {
657 out.push('H');
658 for b in &h.bytes {
659 out.push_str(&b.to_string());
660 out.push(',');
661 }
662 out.push('|');
663 }
664 Value::Numeric { scaled, scale } => {
665 out.push('D');
666 out.push_str(&scaled.to_string());
667 out.push('@');
668 out.push_str(&scale.to_string());
669 out.push('|');
670 }
671 Value::Date(d) => {
672 out.push('d');
673 out.push_str(&d.to_string());
674 out.push('|');
675 }
676 Value::Timestamp(t) => {
677 out.push('t');
678 out.push_str(&t.to_string());
679 out.push('|');
680 }
681 Value::Interval { months, micros } => {
682 out.push('i');
683 out.push_str(&months.to_string());
684 out.push('m');
685 out.push_str(µs.to_string());
686 out.push('|');
687 }
688 Value::Json(s) => {
689 out.push('j');
690 out.push_str(s);
691 out.push('|');
692 }
693 _ => {
698 out.push('?');
699 out.push_str(&format!("{v:?}"));
700 out.push('|');
701 }
702 }
703 }
704 out
705}
706
707#[allow(clippy::cast_precision_loss)]
708fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
709 use core::cmp::Ordering::Equal;
710 match (a, b) {
711 (Value::Null, Value::Null) => Equal,
712 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
714 (Value::Int(x), Value::Int(y)) => x.cmp(y),
715 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
716 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
717 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
718 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
719 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
720 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
721 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
722 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
723 (Value::Text(x), Value::Text(y)) => x.cmp(y),
724 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
725 _ => Equal,
726 }
727}