1use alloc::boxed::Box;
24use alloc::collections::BTreeMap;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
65 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
66 contains_aggregate(expr)
67 }
68 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
69 Expr::Extract { source, .. } => contains_aggregate(source),
70 Expr::ScalarSubquery(_)
75 | Expr::Exists { .. }
76 | Expr::InSubquery { .. }
77 | Expr::WindowFunction { .. }
78 | Expr::Literal(_)
79 | Expr::Placeholder(_)
80 | Expr::Column(_) => false,
81 }
82}
83
/// Returns `true` if `name`, compared ASCII-case-insensitively, is one of the
/// supported aggregate functions: `count`, `count_star`, `sum`, `min`, `max`
/// or `avg`.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: [&str; 6] = ["count", "count_star", "sum", "min", "max", "avg"];
    // Case-insensitive comparison without allocating a lowercase copy.
    AGGREGATES.iter().any(|agg| name.eq_ignore_ascii_case(agg))
}
90
/// Running accumulator for one aggregate call within one group.
///
/// Which fields are meaningful depends on the aggregate; they are written by
/// `update_state` and read by `finalize`.
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Values counted by COUNT (non-NULL inputs) or COUNT(*) (every row), or
    /// non-NULL inputs folded in by SUM/AVG. For SUM/AVG it is the AVG
    /// divisor, and zero makes both finalize to NULL.
    count: i64,
    /// Integer part of the SUM/AVG running total.
    sum_int: i64,
    /// Floating-point part of the SUM/AVG running total.
    sum_float: f64,
    /// Current MIN/MAX candidate; `None` until a non-NULL input is seen.
    extreme: Option<Value>,
    /// Set once part of the total is held in `sum_float`. It makes SUM
    /// finalize to a float (`sum_float + sum_int`) instead of a `BigInt`.
    use_float: bool,
}
100
/// One distinct aggregate call found in the query, identified by its name
/// and argument expression.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Lowercase aggregate name, e.g. `"sum"` or `"count_star"`.
    name: String,
    /// The call's first argument expression. It is `None` for `count_star`,
    /// which counts rows rather than evaluating an argument, and for a call
    /// written with no arguments.
    arg: Option<Expr>,
}
107
/// Output of the aggregate executor `run`: the projected column schemas and
/// the result rows.
#[derive(Debug)]
pub struct AggResult {
    /// One schema per select-list item, in select-list order.
    pub columns: Vec<ColumnSchema>,
    /// Result rows, after HAVING filtering and ORDER BY sorting.
    pub rows: Vec<Row>,
}
115
116#[allow(clippy::too_many_lines)]
119pub fn run(
120 stmt: &SelectStatement,
121 rows: &[&Row],
122 schema_cols: &[ColumnSchema],
123 table_alias: Option<&str>,
124) -> Result<AggResult, EvalError> {
125 let ctx = EvalContext::new(schema_cols, table_alias);
126 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
127
128 let mut agg_specs: Vec<AggSpec> = Vec::new();
130 for item in &stmt.items {
131 if let SelectItem::Expr { expr, .. } = item {
132 collect_aggregates(expr, &mut agg_specs);
133 }
134 }
135 for o in &stmt.order_by {
136 collect_aggregates(&o.expr, &mut agg_specs);
137 }
138 if let Some(h) = &stmt.having {
139 collect_aggregates(h, &mut agg_specs);
140 }
141
142 let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
145 let mut key_order: Vec<String> = Vec::new();
146 if rows.is_empty() && group_exprs.is_empty() {
149 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
151 groups.insert(String::new(), (Vec::new(), init));
152 key_order.push(String::new());
153 }
154
155 for row in rows {
156 let group_vals: Vec<Value> = group_exprs
157 .iter()
158 .map(|g| eval::eval_expr(g, row, &ctx))
159 .collect::<Result<_, _>>()?;
160 let key = encode_key(&group_vals);
161 let entry = groups.entry(key.clone()).or_insert_with(|| {
162 key_order.push(key.clone());
163 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
164 (group_vals.clone(), init)
165 });
166 for (i, spec) in agg_specs.iter().enumerate() {
167 let arg_val = match &spec.arg {
168 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
170 };
171 update_state(&mut entry.1[i], &spec.name, &arg_val)?;
172 }
173 }
174
175 let group_types: Vec<DataType> = if rows.is_empty() {
177 group_exprs.iter().map(|_| DataType::Text).collect()
180 } else {
181 let probe = rows[0];
182 group_exprs
183 .iter()
184 .map(|g| {
185 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
186 })
187 .collect::<Result<_, _>>()?
188 };
189 let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();
190 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
191 for (i, ty) in group_types.iter().enumerate() {
192 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
193 }
194 for (i, ty) in agg_types.iter().enumerate() {
195 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
196 }
197
198 let mut synth_rows: Vec<Row> = Vec::new();
200 for k in &key_order {
201 let (gvals, states) = &groups[k];
202 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
203 values.extend(gvals.iter().cloned());
204 for (i, st) in states.iter().enumerate() {
205 values.push(finalize(&agg_specs[i].name, st));
206 }
207 synth_rows.push(Row::new(values));
208 }
209
210 let columns: Vec<ColumnSchema> = stmt
215 .items
216 .iter()
217 .map(|item| match item {
218 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
219 detail: "SELECT * with aggregates is not supported".into(),
220 }),
221 SelectItem::Expr { expr, alias } => {
222 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
223 let name = alias.clone().unwrap_or_else(|| expr.to_string());
224 Ok(ColumnSchema::new(
225 name,
226 agg_or_group_type(&rewritten, &synth_schema),
227 true,
228 ))
229 }
230 })
231 .collect::<Result<_, _>>()?;
232
233 let synth_ctx = EvalContext::new(&synth_schema, None);
238 let having_rewritten = stmt
239 .having
240 .as_ref()
241 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
242 let mut kept_synth: Vec<Row> = Vec::new();
243 let mut out_rows: Vec<Row> = Vec::new();
244 for srow in synth_rows {
245 if let Some(h) = &having_rewritten {
246 let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
247 if !matches!(cond, Value::Bool(true)) {
248 continue;
249 }
250 }
251 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
252 for item in &stmt.items {
253 if let SelectItem::Expr { expr, .. } = item {
254 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
255 values.push(eval::eval_expr(&rewritten, &srow, &synth_ctx)?);
256 }
257 }
258 kept_synth.push(srow);
259 out_rows.push(Row::new(values));
260 }
261
262 if !stmt.order_by.is_empty() {
265 let rewritten: Vec<Expr> = stmt
268 .order_by
269 .iter()
270 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
271 .collect();
272 let descs: Vec<bool> = stmt.order_by.iter().map(|o| o.desc).collect();
273 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
274 .into_iter()
275 .zip(out_rows)
276 .map(|(s, o)| {
277 let mut keys = Vec::with_capacity(rewritten.len());
278 for e in &rewritten {
279 keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
280 }
281 Ok::<_, EvalError>((keys, o))
282 })
283 .collect::<Result<_, _>>()?;
284 tagged.sort_by(|a, b| {
285 use core::cmp::Ordering;
286 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
287 let cmp = value_cmp(ka, kb);
288 let cmp = if descs[i] { cmp.reverse() } else { cmp };
289 if cmp != Ordering::Equal {
290 return cmp;
291 }
292 }
293 Ordering::Equal
294 });
295 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
296 }
297
298 Ok(AggResult {
299 columns,
300 rows: out_rows,
301 })
302}
303
304fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
305 match e {
306 Expr::FunctionCall { name, args } => {
307 let lower = name.to_ascii_lowercase();
308 if is_aggregate_name(&lower) {
309 let arg = if lower == "count_star" {
310 None
311 } else {
312 args.first().cloned()
313 };
314 let spec = AggSpec {
315 name: lower,
316 arg: arg.clone(),
317 };
318 if !out.iter().any(|s| s.name == spec.name && s.arg == spec.arg) {
319 out.push(spec);
320 }
321 } else {
324 for a in args {
325 collect_aggregates(a, out);
326 }
327 }
328 }
329 Expr::Binary { lhs, rhs, .. } => {
330 collect_aggregates(lhs, out);
331 collect_aggregates(rhs, out);
332 }
333 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
334 collect_aggregates(expr, out);
335 }
336 Expr::Like { expr, pattern, .. } => {
337 collect_aggregates(expr, out);
338 collect_aggregates(pattern, out);
339 }
340 Expr::Extract { source, .. } => collect_aggregates(source, out),
341 Expr::ScalarSubquery(_)
344 | Expr::Exists { .. }
345 | Expr::InSubquery { .. }
346 | Expr::WindowFunction { .. }
347 | Expr::Literal(_)
348 | Expr::Placeholder(_)
349 | Expr::Column(_) => {}
350 }
351}
352
353fn update_state(st: &mut AggState, name: &str, v: &Value) -> Result<(), EvalError> {
354 let is_null = matches!(v, Value::Null);
355 match name {
356 "count_star" => st.count += 1,
357 "count" => {
358 if !is_null {
359 st.count += 1;
360 }
361 }
362 "sum" | "avg" => {
363 if is_null {
364 return Ok(());
365 }
366 st.count += 1;
367 match v {
368 Value::Int(n) => st.sum_int += i64::from(*n),
369 Value::BigInt(n) => st.sum_int += *n,
370 Value::Float(x) => {
371 st.use_float = true;
372 st.sum_float += *x;
373 }
374 other => {
375 return Err(EvalError::TypeMismatch {
376 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
377 });
378 }
379 }
380 }
381 "min" => {
382 if is_null {
383 return Ok(());
384 }
385 match &st.extreme {
386 None => st.extreme = Some(v.clone()),
387 Some(cur) => {
388 if value_cmp(v, cur) == core::cmp::Ordering::Less {
389 st.extreme = Some(v.clone());
390 }
391 }
392 }
393 }
394 "max" => {
395 if is_null {
396 return Ok(());
397 }
398 match &st.extreme {
399 None => st.extreme = Some(v.clone()),
400 Some(cur) => {
401 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
402 st.extreme = Some(v.clone());
403 }
404 }
405 }
406 }
407 _ => unreachable!("non-aggregate {name} in update_state"),
408 }
409 Ok(())
410}
411
412#[allow(clippy::cast_precision_loss)]
413fn finalize(name: &str, st: &AggState) -> Value {
414 match name {
415 "count" | "count_star" => Value::BigInt(st.count),
416 "sum" => {
417 if st.count == 0 {
418 Value::Null
419 } else if st.use_float {
420 Value::Float(st.sum_float + (st.sum_int as f64))
421 } else {
422 Value::BigInt(st.sum_int)
423 }
424 }
425 "avg" => {
426 if st.count == 0 {
427 Value::Null
428 } else {
429 let total = if st.use_float {
430 st.sum_float + (st.sum_int as f64)
431 } else {
432 st.sum_int as f64
433 };
434 Value::Float(total / (st.count as f64))
435 }
436 }
437 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
438 _ => unreachable!(),
439 }
440}
441
442fn infer_agg_type(spec: &AggSpec) -> DataType {
443 match spec.name.as_str() {
444 "count" | "count_star" | "sum" => DataType::BigInt,
448 "avg" => DataType::Float,
449 _ => DataType::Text,
452 }
453}
454
455fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
456 if let Expr::Column(c) = e
457 && let Some(s) = synth.iter().find(|s| s.name == c.name)
458 {
459 return s.ty;
460 }
461 DataType::Text
464}
465
466fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
467 if let Expr::FunctionCall { name, args } = e {
469 let lower = name.to_ascii_lowercase();
470 if is_aggregate_name(&lower) {
471 let arg = if lower == "count_star" {
472 None
473 } else {
474 args.first().cloned()
475 };
476 for (i, spec) in aggs.iter().enumerate() {
477 if spec.name == lower && spec.arg == arg {
478 return Expr::Column(spg_sql::ast::ColumnName {
479 qualifier: None,
480 name: format!("__agg_{i}"),
481 });
482 }
483 }
484 }
485 }
486 for (i, g) in group_exprs.iter().enumerate() {
488 if g == e {
489 return Expr::Column(spg_sql::ast::ColumnName {
490 qualifier: None,
491 name: format!("__grp_{i}"),
492 });
493 }
494 }
495 match e {
497 Expr::Binary { lhs, op, rhs } => Expr::Binary {
498 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
499 op: *op,
500 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
501 },
502 Expr::Unary { op, expr } => Expr::Unary {
503 op: *op,
504 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
505 },
506 Expr::Cast { expr, target } => Expr::Cast {
507 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
508 target: *target,
509 },
510 Expr::IsNull { expr, negated } => Expr::IsNull {
511 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
512 negated: *negated,
513 },
514 Expr::FunctionCall { name, args } => Expr::FunctionCall {
515 name: name.clone(),
516 args: args
517 .iter()
518 .map(|a| rewrite_expr(a, group_exprs, aggs))
519 .collect(),
520 },
521 Expr::Like {
522 expr,
523 pattern,
524 negated,
525 } => Expr::Like {
526 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
527 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
528 negated: *negated,
529 },
530 Expr::Extract { field, source } => Expr::Extract {
531 field: *field,
532 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
533 },
534 Expr::ScalarSubquery(_)
537 | Expr::Exists { .. }
538 | Expr::InSubquery { .. }
539 | Expr::WindowFunction { .. }
540 | Expr::Literal(_)
541 | Expr::Placeholder(_)
542 | Expr::Column(_) => e.clone(),
543 }
544}
545
546fn encode_key(vals: &[Value]) -> String {
548 let mut out = String::new();
549 for v in vals {
550 match v {
551 Value::Null => out.push_str("N|"),
552 Value::SmallInt(n) => {
553 out.push('s');
554 out.push_str(&n.to_string());
555 out.push('|');
556 }
557 Value::Int(n) => {
558 out.push('I');
559 out.push_str(&n.to_string());
560 out.push('|');
561 }
562 Value::BigInt(n) => {
563 out.push('B');
564 out.push_str(&n.to_string());
565 out.push('|');
566 }
567 Value::Float(x) => {
568 out.push('F');
569 out.push_str(&x.to_string());
570 out.push('|');
571 }
572 Value::Bool(b) => {
573 out.push(if *b { 'T' } else { 'f' });
574 out.push('|');
575 }
576 Value::Text(s) => {
577 out.push('S');
578 out.push_str(s);
579 out.push('|');
580 }
581 Value::Vector(v) => {
582 out.push('V');
583 for x in v {
584 out.push_str(&x.to_string());
585 out.push(',');
586 }
587 out.push('|');
588 }
589 Value::Sq8Vector(q) => {
595 out.push('Q');
596 out.push_str(&q.min.to_string());
597 out.push('@');
598 out.push_str(&q.max.to_string());
599 out.push(':');
600 for b in &q.bytes {
601 out.push_str(&b.to_string());
602 out.push(',');
603 }
604 out.push('|');
605 }
606 Value::HalfVector(h) => {
610 out.push('H');
611 for b in &h.bytes {
612 out.push_str(&b.to_string());
613 out.push(',');
614 }
615 out.push('|');
616 }
617 Value::Numeric { scaled, scale } => {
618 out.push('D');
619 out.push_str(&scaled.to_string());
620 out.push('@');
621 out.push_str(&scale.to_string());
622 out.push('|');
623 }
624 Value::Date(d) => {
625 out.push('d');
626 out.push_str(&d.to_string());
627 out.push('|');
628 }
629 Value::Timestamp(t) => {
630 out.push('t');
631 out.push_str(&t.to_string());
632 out.push('|');
633 }
634 Value::Interval { months, micros } => {
635 out.push('i');
636 out.push_str(&months.to_string());
637 out.push('m');
638 out.push_str(µs.to_string());
639 out.push('|');
640 }
641 Value::Json(s) => {
642 out.push('j');
643 out.push_str(s);
644 out.push('|');
645 }
646 _ => {
651 out.push('?');
652 out.push_str(&format!("{v:?}"));
653 out.push('|');
654 }
655 }
656 }
657 out
658}
659
660#[allow(clippy::cast_precision_loss)]
661fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
662 use core::cmp::Ordering::Equal;
663 match (a, b) {
664 (Value::Null, Value::Null) => Equal,
665 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
667 (Value::Int(x), Value::Int(y)) => x.cmp(y),
668 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
669 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
670 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
671 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
672 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
673 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
674 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
675 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
676 (Value::Text(x), Value::Text(y)) => x.cmp(y),
677 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
678 _ => Equal,
679 }
680}