1use alloc::boxed::Box;
24use alloc::collections::BTreeSet;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33
34pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
36 if stmt.group_by.is_some() || stmt.having.is_some() {
37 return true;
38 }
39 for item in &stmt.items {
40 if let SelectItem::Expr { expr, .. } = item
41 && contains_aggregate(expr)
42 {
43 return true;
44 }
45 }
46 for o in &stmt.order_by {
47 if contains_aggregate(&o.expr) {
48 return true;
49 }
50 }
51 if let Some(h) = &stmt.having
52 && contains_aggregate(h)
53 {
54 return true;
55 }
56 false
57}
58
59pub fn contains_aggregate(e: &Expr) -> bool {
60 match e {
61 Expr::FunctionCall { name, args } => {
62 is_aggregate_name(name) || args.iter().any(contains_aggregate)
63 }
64 Expr::AggregateOrdered { .. } => true,
65 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
66 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
67 contains_aggregate(expr)
68 }
69 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
70 Expr::Extract { source, .. } => contains_aggregate(source),
71 Expr::ScalarSubquery(_)
76 | Expr::Exists { .. }
77 | Expr::InSubquery { .. }
78 | Expr::WindowFunction { .. }
79 | Expr::Literal(_)
80 | Expr::Placeholder(_)
81 | Expr::Column(_) => false,
82 Expr::Array(items) => items.iter().any(contains_aggregate),
86 Expr::ArraySubscript { target, index } => {
87 contains_aggregate(target) || contains_aggregate(index)
88 }
89 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
90 Expr::Case {
93 operand,
94 branches,
95 else_branch,
96 } => {
97 operand.as_deref().is_some_and(contains_aggregate)
98 || branches
99 .iter()
100 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
101 || else_branch.as_deref().is_some_and(contains_aggregate)
102 }
103 }
104}
105
/// Returns `true` when `name` is one of the aggregate function names this
/// module implements. The comparison ignores ASCII case.
pub fn is_aggregate_name(name: &str) -> bool {
    /// Lower-case spellings of every supported aggregate.
    const AGGREGATES: [&str; 11] = [
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
    ];
    AGGREGATES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
128
129#[derive(Debug, Default, Clone)]
131struct AggState {
132 count: i64,
133 sum_int: i64,
134 sum_float: f64,
135 extreme: Option<Value>,
136 use_float: bool,
137 items: Vec<Value>,
144 seen: BTreeSet<String>,
148 item_keys: Vec<Vec<Value>>,
152 separator: Option<String>,
158 bool_acc: Option<bool>,
162}
163
164#[derive(Debug, Clone)]
165struct AggSpec {
166 name: String, arg: Option<Expr>,
170 arg2: Option<Expr>,
176 distinct: bool,
179 order_by: Vec<spg_sql::ast::OrderBy>,
185}
186
187#[derive(Debug)]
190pub struct AggResult {
191 pub columns: Vec<ColumnSchema>,
192 pub rows: Vec<Row>,
193}
194
195#[allow(clippy::too_many_lines)]
198pub type CorrelatedEval<'a> = &'a dyn Fn(&Expr, &Row, &EvalContext<'_>) -> Result<Value, EvalError>;
205
206pub fn run(
207 stmt: &SelectStatement,
208 rows: &[&Row],
209 schema_cols: &[ColumnSchema],
210 table_alias: Option<&str>,
211 correlated_eval: Option<CorrelatedEval<'_>>,
212) -> Result<AggResult, EvalError> {
213 let ctx = EvalContext::new(schema_cols, table_alias);
214 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
215
216 let mut agg_specs: Vec<AggSpec> = Vec::new();
218 for item in &stmt.items {
219 if let SelectItem::Expr { expr, .. } = item {
220 collect_aggregates(expr, &mut agg_specs);
221 }
222 }
223 for o in &stmt.order_by {
224 collect_aggregates(&o.expr, &mut agg_specs);
225 }
226 if let Some(h) = &stmt.having {
227 collect_aggregates(h, &mut agg_specs);
228 }
229 validate_agg_arities(stmt, &agg_specs)?;
235
236 let mut groups: hashbrown::HashMap<String, (Vec<Value>, Vec<AggState>)> =
240 hashbrown::HashMap::new();
241 let mut key_order: Vec<String> = Vec::new();
242 if rows.is_empty() && group_exprs.is_empty() {
245 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
247 groups.insert(String::new(), (Vec::new(), init));
248 key_order.push(String::new());
249 }
250
251 let col_pos = |e: &Expr| -> Option<usize> {
262 if let Expr::Column(c) = e
265 && c.qualifier.is_some()
266 {
267 eval::find_column_pos(c, &ctx)
268 } else {
269 None
270 }
271 };
272 let group_pos: Vec<Option<usize>> = group_exprs.iter().map(col_pos).collect();
273 let all_groups_bound = group_pos.iter().all(Option::is_some);
274 let arg_pos: Vec<Option<usize>> = agg_specs
275 .iter()
276 .map(|spec| spec.arg.as_ref().and_then(|e| col_pos(e)))
277 .collect();
278 let ci_positions: Vec<usize> = group_exprs
279 .iter()
280 .enumerate()
281 .filter(|(_, g)| {
282 matches!(
283 eval::column_collation(g, &ctx),
284 Some(spg_storage::Collation::CaseInsensitive)
285 )
286 })
287 .map(|(i, _)| i)
288 .collect();
289 for row in rows {
290 if all_groups_bound && ci_positions.is_empty() && !group_exprs.is_empty() {
294 let refs: Vec<&Value> = group_pos
295 .iter()
296 .map(|p| row.values.get(p.unwrap()).unwrap_or(&Value::Null))
297 .collect();
298 let key = encode_key_refs(&refs);
299 let entry = match groups.entry_ref(key.as_str()) {
300 hashbrown::hash_map::EntryRef::Occupied(o) => o.into_mut(),
301 hashbrown::hash_map::EntryRef::Vacant(v) => {
302 key_order.push(key.clone());
303 let init: Vec<AggState> =
304 (0..agg_specs.len()).map(|_| AggState::default()).collect();
305 let owned: Vec<Value> = refs.iter().map(|v| (*v).clone()).collect();
306 v.insert((owned, init))
307 }
308 };
309 for (i, spec) in agg_specs.iter().enumerate() {
310 let arg_owned: Value;
311 let arg_ref: &Value = match (&arg_pos[i], &spec.arg) {
312 (Some(p), _) => row.values.get(*p).unwrap_or(&Value::Null),
313 (None, None) => {
314 arg_owned = Value::Bool(true);
315 &arg_owned
316 }
317 (None, Some(e)) => {
318 arg_owned = eval::eval_expr(e, row, &ctx)?;
319 &arg_owned
320 }
321 };
322 let arg2_val = match &spec.arg2 {
323 None => None,
324 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
325 };
326 let order_keys = if spec.order_by.is_empty() {
327 None
328 } else {
329 let mut keys = Vec::with_capacity(spec.order_by.len());
330 for o in &spec.order_by {
331 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
332 }
333 Some(keys)
334 };
335 if spec.distinct {
336 let dkey = encode_key_refs(core::slice::from_ref(&arg_ref));
337 if !entry.1[i].seen.insert(dkey) {
338 continue;
339 }
340 }
341 update_state(
342 &mut entry.1[i],
343 &spec.name,
344 arg_ref,
345 arg2_val.as_ref(),
346 order_keys,
347 )?;
348 }
349 continue;
350 }
351 let group_vals: Vec<Value> = group_exprs
352 .iter()
353 .map(|g| eval::eval_expr(g, row, &ctx))
354 .collect::<Result<_, _>>()?;
355 let key = if ci_positions.is_empty() {
359 encode_key(&group_vals)
360 } else {
361 let mut key_vals = group_vals.clone();
362 for &i in &ci_positions {
363 if let Value::Text(s) = &key_vals[i] {
364 key_vals[i] = Value::Text(s.to_ascii_lowercase());
365 }
366 }
367 encode_key(&key_vals)
368 };
369 let entry = match groups.entry_ref(key.as_str()) {
371 hashbrown::hash_map::EntryRef::Occupied(o) => o.into_mut(),
372 hashbrown::hash_map::EntryRef::Vacant(v) => {
373 key_order.push(key.clone());
374 let init: Vec<AggState> =
375 (0..agg_specs.len()).map(|_| AggState::default()).collect();
376 v.insert((group_vals.clone(), init))
377 }
378 };
379 for (i, spec) in agg_specs.iter().enumerate() {
380 let arg_val = match &spec.arg {
381 None => Value::Bool(true), Some(e) => eval::eval_expr(e, row, &ctx)?,
383 };
384 let arg2_val = match &spec.arg2 {
390 None => None,
391 Some(e) => Some(eval::eval_expr(e, row, &ctx)?),
392 };
393 let order_keys = if spec.order_by.is_empty() {
396 None
397 } else {
398 let mut keys = Vec::with_capacity(spec.order_by.len());
399 for o in &spec.order_by {
400 keys.push(eval::eval_expr(&o.expr, row, &ctx)?);
401 }
402 Some(keys)
403 };
404 if spec.distinct {
409 let key = encode_key(core::slice::from_ref(&arg_val));
410 if !entry.1[i].seen.insert(key) {
411 continue;
412 }
413 }
414 update_state(
415 &mut entry.1[i],
416 &spec.name,
417 &arg_val,
418 arg2_val.as_ref(),
419 order_keys,
420 )?;
421 }
422 }
423
424 let group_types: Vec<DataType> = if rows.is_empty() {
426 group_exprs.iter().map(|_| DataType::Text).collect()
429 } else {
430 let probe = rows[0];
431 group_exprs
432 .iter()
433 .map(|g| {
434 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
435 })
436 .collect::<Result<_, _>>()?
437 };
438 let agg_types: Vec<DataType> = agg_specs
439 .iter()
440 .map(|spec| infer_agg_type(spec, schema_cols))
441 .collect();
442 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
443 for (i, ty) in group_types.iter().enumerate() {
444 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
445 }
446 for (i, ty) in agg_types.iter().enumerate() {
447 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
448 }
449
450 let mut synth_rows: Vec<Row> = Vec::new();
452 for k in &key_order {
453 let (gvals, states) = &groups[k];
454 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
455 values.extend(gvals.iter().cloned());
456 for (i, st) in states.iter().enumerate() {
457 let st_sorted;
461 let st_final: &AggState =
462 if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
463 let mut idx: Vec<usize> = (0..st.items.len()).collect();
464 let ob = &agg_specs[i].order_by;
465 idx.sort_by(|&x, &y| {
466 for (k, o) in ob.iter().enumerate() {
467 let cmp = crate::order_by_value_cmp(
468 o.desc,
469 o.nulls_first,
470 &st.item_keys[x][k],
471 &st.item_keys[y][k],
472 );
473 if cmp != core::cmp::Ordering::Equal {
474 return cmp;
475 }
476 }
477 core::cmp::Ordering::Equal
478 });
479 let mut sorted = st.clone();
480 sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
481 st_sorted = sorted;
482 &st_sorted
483 } else {
484 st
485 };
486 values.push(finalize(&agg_specs[i].name, st_final));
487 }
488 synth_rows.push(Row::new(values));
489 }
490
491 let columns: Vec<ColumnSchema> = stmt
496 .items
497 .iter()
498 .map(|item| match item {
499 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
500 detail: "SELECT * with aggregates is not supported".into(),
501 }),
502 SelectItem::Expr { expr, alias } => {
503 let rewritten = rewrite_expr(expr, &group_exprs, &agg_specs);
504 let name = alias.clone().unwrap_or_else(|| expr.to_string());
505 Ok(ColumnSchema::new(
506 name,
507 agg_or_group_type(&rewritten, &synth_schema),
508 true,
509 ))
510 }
511 })
512 .collect::<Result<_, _>>()?;
513
514 let synth_ctx = EvalContext::new(&synth_schema, None);
519 let having_rewritten = stmt
520 .having
521 .as_ref()
522 .map(|h| rewrite_expr(h, &group_exprs, &agg_specs));
523 let items_rewritten: alloc::vec::Vec<Option<Expr>> = stmt
529 .items
530 .iter()
531 .map(|item| match item {
532 SelectItem::Expr { expr, .. } => Some(rewrite_expr(expr, &group_exprs, &agg_specs)),
533 SelectItem::Wildcard => None,
534 })
535 .collect();
536 let mut kept_synth: Vec<Row> = Vec::new();
537 let mut out_rows: Vec<Row> = Vec::new();
538 for srow in synth_rows {
539 if let Some(h) = &having_rewritten {
540 let cond = match correlated_eval {
541 Some(f) if crate::expr_has_subquery(h) => f(h, &srow, &synth_ctx)?,
542 _ => eval::eval_expr(h, &srow, &synth_ctx)?,
543 };
544 if !matches!(cond, Value::Bool(true)) {
545 continue;
546 }
547 }
548 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
549 for rewritten in items_rewritten.iter().flatten() {
550 values.push(match correlated_eval {
551 Some(f) if crate::expr_has_subquery(rewritten) => f(rewritten, &srow, &synth_ctx)?,
552 _ => eval::eval_expr(rewritten, &srow, &synth_ctx)?,
553 });
554 }
555 kept_synth.push(srow);
556 out_rows.push(Row::new(values));
557 }
558
559 if !stmt.order_by.is_empty() {
562 let rewritten: Vec<Expr> = stmt
565 .order_by
566 .iter()
567 .map(|o| rewrite_expr(&o.expr, &group_exprs, &agg_specs))
568 .collect();
569 let keys_meta: Vec<(bool, Option<bool>)> = stmt
570 .order_by
571 .iter()
572 .map(|o| (o.desc, o.nulls_first))
573 .collect();
574 let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
575 .into_iter()
576 .zip(out_rows)
577 .map(|(s, o)| {
578 let mut keys = Vec::with_capacity(rewritten.len());
579 for e in &rewritten {
580 keys.push(match correlated_eval {
581 Some(f) if crate::expr_has_subquery(e) => f(e, &s, &synth_ctx)?,
582 _ => eval::eval_expr(e, &s, &synth_ctx)?,
583 });
584 }
585 Ok::<_, EvalError>((keys, o))
586 })
587 .collect::<Result<_, _>>()?;
588 tagged.sort_by(|a, b| {
589 use core::cmp::Ordering;
590 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
591 let (desc, nf) = keys_meta[i];
592 let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
593 if cmp != Ordering::Equal {
594 return cmp;
595 }
596 }
597 Ordering::Equal
598 });
599 out_rows = tagged.into_iter().map(|(_, o)| o).collect();
600 }
601
602 Ok(AggResult {
603 columns,
604 rows: out_rows,
605 })
606}
607
608fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
614 fn walk(e: &Expr) -> Result<(), EvalError> {
615 if let Expr::FunctionCall { name, args } = e {
616 let lower = name.to_ascii_lowercase();
617 let expected: Option<usize> = match lower.as_str() {
618 "count_star" => Some(0),
619 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
620 | "bool_and" | "bool_or" | "every" => Some(1),
624 "string_agg" => Some(2),
625 _ => None,
626 };
627 if let Some(want) = expected
628 && args.len() != want
629 {
630 return Err(EvalError::TypeMismatch {
631 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
632 });
633 }
634 for a in args {
635 walk(a)?;
636 }
637 } else if let Expr::Binary { lhs, rhs, .. } = e {
638 walk(lhs)?;
639 walk(rhs)?;
640 } else if let Expr::Unary { expr, .. }
641 | Expr::Cast { expr, .. }
642 | Expr::IsNull { expr, .. } = e
643 {
644 walk(expr)?;
645 }
646 Ok(())
647 }
648 for item in &stmt.items {
649 if let SelectItem::Expr { expr, .. } = item {
650 walk(expr)?;
651 }
652 }
653 for o in &stmt.order_by {
654 walk(&o.expr)?;
655 }
656 if let Some(h) = &stmt.having {
657 walk(h)?;
658 }
659 Ok(())
660}
661
662fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
663 match e {
664 Expr::AggregateOrdered {
667 call,
668 order_by,
669 distinct,
670 } => {
671 if let Expr::FunctionCall { name, args } = call.as_ref() {
672 let lower = name.to_ascii_lowercase();
673 if is_aggregate_name(&lower) {
674 let canonical = if lower == "every" {
675 "bool_and".to_string()
676 } else {
677 lower
678 };
679 let spec = AggSpec {
680 name: canonical,
681 arg: args.first().cloned(),
682 arg2: if name.eq_ignore_ascii_case("string_agg") {
683 args.get(1).cloned()
684 } else {
685 None
686 },
687 distinct: *distinct,
688 order_by: order_by.clone(),
689 };
690 if !out.iter().any(|s| {
691 s.name == spec.name
692 && s.arg == spec.arg
693 && s.arg2 == spec.arg2
694 && s.distinct == spec.distinct
695 && s.order_by == spec.order_by
696 }) {
697 out.push(spec);
698 }
699 return;
700 }
701 }
702 collect_aggregates(call, out);
703 for o in order_by {
704 collect_aggregates(&o.expr, out);
705 }
706 }
707 Expr::FunctionCall { name, args } => {
708 let lower = name.to_ascii_lowercase();
709 if is_aggregate_name(&lower) {
710 let arg = if lower == "count_star" {
711 None
712 } else {
713 args.first().cloned()
714 };
715 let arg2 = if lower == "string_agg" {
719 args.get(1).cloned()
720 } else {
721 None
722 };
723 let canonical = if lower == "every" {
727 "bool_and".to_string()
728 } else {
729 lower
730 };
731 let spec = AggSpec {
732 name: canonical,
733 arg: arg.clone(),
734 arg2: arg2.clone(),
735 distinct: false,
736 order_by: Vec::new(),
737 };
738 if !out.iter().any(|s| {
739 s.name == spec.name
740 && s.arg == spec.arg
741 && s.arg2 == spec.arg2
742 && !s.distinct
743 && s.order_by == spec.order_by
744 }) {
745 out.push(spec);
746 }
747 } else {
750 for a in args {
751 collect_aggregates(a, out);
752 }
753 }
754 }
755 Expr::Binary { lhs, rhs, .. } => {
756 collect_aggregates(lhs, out);
757 collect_aggregates(rhs, out);
758 }
759 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
760 collect_aggregates(expr, out);
761 }
762 Expr::Like { expr, pattern, .. } => {
763 collect_aggregates(expr, out);
764 collect_aggregates(pattern, out);
765 }
766 Expr::Extract { source, .. } => collect_aggregates(source, out),
767 Expr::ScalarSubquery(_)
770 | Expr::Exists { .. }
771 | Expr::InSubquery { .. }
772 | Expr::WindowFunction { .. }
773 | Expr::Literal(_)
774 | Expr::Placeholder(_)
775 | Expr::Column(_) => {}
776 Expr::Array(items) => {
779 for elem in items {
780 collect_aggregates(elem, out);
781 }
782 }
783 Expr::ArraySubscript { target, index } => {
784 collect_aggregates(target, out);
785 collect_aggregates(index, out);
786 }
787 Expr::AnyAll { expr, array, .. } => {
788 collect_aggregates(expr, out);
789 collect_aggregates(array, out);
790 }
791 Expr::Case {
792 operand,
793 branches,
794 else_branch,
795 } => {
796 if let Some(o) = operand {
797 collect_aggregates(o, out);
798 }
799 for (w, t) in branches {
800 collect_aggregates(w, out);
801 collect_aggregates(t, out);
802 }
803 if let Some(e) = else_branch {
804 collect_aggregates(e, out);
805 }
806 }
807 }
808}
809
810fn update_state(
811 st: &mut AggState,
812 name: &str,
813 v: &Value,
814 arg2: Option<&Value>,
815 order_keys: Option<Vec<Value>>,
816) -> Result<(), EvalError> {
817 let is_null = matches!(v, Value::Null);
818 match name {
819 "count_star" => st.count += 1,
820 "count" => {
821 if !is_null {
822 st.count += 1;
823 }
824 }
825 "sum" | "avg" => {
826 if is_null {
827 return Ok(());
828 }
829 st.count += 1;
830 match v {
831 Value::Int(n) => st.sum_int += i64::from(*n),
832 Value::BigInt(n) => st.sum_int += *n,
833 Value::Float(x) => {
834 st.use_float = true;
835 st.sum_float += *x;
836 }
837 other => {
838 return Err(EvalError::TypeMismatch {
839 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
840 });
841 }
842 }
843 }
844 "min" => {
845 if is_null {
846 return Ok(());
847 }
848 match &st.extreme {
849 None => st.extreme = Some(v.clone()),
850 Some(cur) => {
851 if value_cmp(v, cur) == core::cmp::Ordering::Less {
852 st.extreme = Some(v.clone());
853 }
854 }
855 }
856 }
857 "max" => {
858 if is_null {
859 return Ok(());
860 }
861 match &st.extreme {
862 None => st.extreme = Some(v.clone()),
863 Some(cur) => {
864 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
865 st.extreme = Some(v.clone());
866 }
867 }
868 }
869 }
870 "string_agg" => {
878 if let Some(sep) = arg2
879 && let Value::Text(s) = sep
880 {
881 st.separator = Some(s.clone());
882 }
883 if is_null {
884 return Ok(());
885 }
886 if let Value::Text(s) = v {
887 st.items.push(Value::Text(s.clone()));
888 if let Some(k) = order_keys {
889 st.item_keys.push(k);
890 }
891 st.count += 1;
892 } else {
893 return Err(EvalError::TypeMismatch {
894 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
895 });
896 }
897 }
898 "array_agg" => {
904 st.items.push(v.clone());
905 if let Some(k) = order_keys {
906 st.item_keys.push(k);
907 }
908 st.count += 1;
909 }
910 "bool_and" => {
914 if is_null {
915 return Ok(());
916 }
917 let b = match v {
918 Value::Bool(b) => *b,
919 other => {
920 return Err(EvalError::TypeMismatch {
921 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
922 });
923 }
924 };
925 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
926 }
927 "bool_or" => {
930 if is_null {
931 return Ok(());
932 }
933 let b = match v {
934 Value::Bool(b) => *b,
935 other => {
936 return Err(EvalError::TypeMismatch {
937 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
938 });
939 }
940 };
941 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
942 }
943 _ => unreachable!("non-aggregate {name} in update_state"),
944 }
945 Ok(())
946}
947
948#[allow(clippy::cast_precision_loss)]
949fn finalize(name: &str, st: &AggState) -> Value {
950 match name {
951 "count" | "count_star" => Value::BigInt(st.count),
952 "sum" => {
953 if st.count == 0 {
954 Value::Null
955 } else if st.use_float {
956 Value::Float(st.sum_float + (st.sum_int as f64))
957 } else {
958 Value::BigInt(st.sum_int)
959 }
960 }
961 "avg" => {
962 if st.count == 0 {
963 Value::Null
964 } else {
965 let total = if st.use_float {
966 st.sum_float + (st.sum_int as f64)
967 } else {
968 st.sum_int as f64
969 };
970 Value::Float(total / (st.count as f64))
971 }
972 }
973 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
974 "string_agg" => {
978 if st.items.is_empty() {
979 return Value::Null;
980 }
981 let sep = st.separator.clone().unwrap_or_default();
982 let mut out = String::new();
983 for (i, item) in st.items.iter().enumerate() {
984 if i > 0 {
985 out.push_str(&sep);
986 }
987 if let Value::Text(s) = item {
988 out.push_str(s);
989 }
990 }
991 Value::Text(out)
992 }
993 "array_agg" => {
1000 if st.items.is_empty() {
1001 return Value::Null;
1002 }
1003 let probe = st.items.iter().find(|v| !v.is_null());
1004 match probe.and_then(spg_storage::Value::data_type) {
1005 Some(DataType::Int) | Some(DataType::SmallInt) => {
1006 let items: Vec<Option<i32>> = st
1007 .items
1008 .iter()
1009 .map(|v| match v {
1010 Value::Int(n) => Some(*n),
1011 Value::SmallInt(n) => Some(i32::from(*n)),
1012 _ => None,
1013 })
1014 .collect();
1015 Value::IntArray(items)
1016 }
1017 Some(DataType::BigInt) => {
1018 let items: Vec<Option<i64>> = st
1019 .items
1020 .iter()
1021 .map(|v| match v {
1022 Value::BigInt(n) => Some(*n),
1023 _ => None,
1024 })
1025 .collect();
1026 Value::BigIntArray(items)
1027 }
1028 _ => {
1029 let items: Vec<Option<String>> = st
1030 .items
1031 .iter()
1032 .map(|v| match v {
1033 Value::Text(s) => Some(s.clone()),
1034 Value::Null => None,
1035 other => Some(format!("{other:?}")),
1036 })
1037 .collect();
1038 Value::TextArray(items)
1039 }
1040 }
1041 }
1042 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
1046 _ => unreachable!(),
1047 }
1048}
1049
1050fn infer_agg_type(spec: &AggSpec, schema_cols: &[ColumnSchema]) -> DataType {
1051 let arg_ty = spec
1055 .arg
1056 .as_ref()
1057 .and_then(|a| crate::describe::describe_expr(a, schema_cols))
1058 .map(|shape| shape.ty);
1059 match spec.name.as_str() {
1060 "count" | "count_star" => DataType::BigInt,
1061 "sum" => match arg_ty {
1062 Some(DataType::Float) => DataType::Float,
1063 _ => DataType::BigInt,
1064 },
1065 "avg" => DataType::Float,
1066 "string_agg" => DataType::Text,
1068 "array_agg" => match arg_ty {
1069 Some(DataType::Int | DataType::SmallInt) => DataType::IntArray,
1070 Some(DataType::BigInt) => DataType::BigIntArray,
1071 _ => DataType::TextArray,
1072 },
1073 "bool_and" | "bool_or" => DataType::Bool,
1076 _ => arg_ty.unwrap_or(DataType::Text),
1078 }
1079}
1080
1081fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
1082 if let Expr::Column(c) = e
1083 && let Some(s) = synth.iter().find(|s| s.name == c.name)
1084 {
1085 return s.ty;
1086 }
1087 crate::describe::describe_expr(e, synth)
1093 .map(|shape| shape.ty)
1094 .unwrap_or(DataType::Text)
1095}
1096
1097fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
1098 if let Expr::AggregateOrdered {
1101 call,
1102 order_by,
1103 distinct,
1104 } = e
1105 && let Expr::FunctionCall { name, args } = call.as_ref()
1106 {
1107 let lower = name.to_ascii_lowercase();
1108 if is_aggregate_name(&lower) {
1109 let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
1110 let arg = args.first().cloned();
1111 let arg2 = if lower == "string_agg" {
1112 args.get(1).cloned()
1113 } else {
1114 None
1115 };
1116 for (i, spec) in aggs.iter().enumerate() {
1117 if spec.name == canonical
1118 && spec.arg == arg
1119 && spec.arg2 == arg2
1120 && spec.distinct == *distinct
1121 && spec.order_by == *order_by
1122 {
1123 return Expr::Column(spg_sql::ast::ColumnName {
1124 qualifier: None,
1125 name: format!("__agg_{i}"),
1126 });
1127 }
1128 }
1129 }
1130 }
1131 if let Expr::FunctionCall { name, args } = e {
1133 let lower = name.to_ascii_lowercase();
1134 if is_aggregate_name(&lower) {
1135 let arg = if lower == "count_star" {
1136 None
1137 } else {
1138 args.first().cloned()
1139 };
1140 let arg2 = if lower == "string_agg" {
1143 args.get(1).cloned()
1144 } else {
1145 None
1146 };
1147 let canonical: &str = if lower == "every" {
1151 "bool_and"
1152 } else {
1153 lower.as_str()
1154 };
1155 for (i, spec) in aggs.iter().enumerate() {
1156 if spec.name == canonical
1157 && spec.arg == arg
1158 && spec.arg2 == arg2
1159 && !spec.distinct
1160 && spec.order_by.is_empty()
1161 {
1162 return Expr::Column(spg_sql::ast::ColumnName {
1163 qualifier: None,
1164 name: format!("__agg_{i}"),
1165 });
1166 }
1167 }
1168 }
1169 }
1170 for (i, g) in group_exprs.iter().enumerate() {
1172 if g == e {
1173 return Expr::Column(spg_sql::ast::ColumnName {
1174 qualifier: None,
1175 name: format!("__grp_{i}"),
1176 });
1177 }
1178 }
1179 match e {
1181 Expr::AggregateOrdered {
1182 call,
1183 order_by,
1184 distinct,
1185 } => Expr::AggregateOrdered {
1186 call: Box::new(rewrite_expr(call, group_exprs, aggs)),
1187 distinct: *distinct,
1188 order_by: order_by
1189 .iter()
1190 .map(|o| spg_sql::ast::OrderBy {
1191 expr: rewrite_expr(&o.expr, group_exprs, aggs),
1192 desc: o.desc,
1193 nulls_first: o.nulls_first,
1194 })
1195 .collect(),
1196 },
1197 Expr::Binary { lhs, op, rhs } => Expr::Binary {
1198 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
1199 op: *op,
1200 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
1201 },
1202 Expr::Unary { op, expr } => Expr::Unary {
1203 op: *op,
1204 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1205 },
1206 Expr::Cast { expr, target } => Expr::Cast {
1207 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1208 target: *target,
1209 },
1210 Expr::IsNull { expr, negated } => Expr::IsNull {
1211 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1212 negated: *negated,
1213 },
1214 Expr::FunctionCall { name, args } => Expr::FunctionCall {
1215 name: name.clone(),
1216 args: args
1217 .iter()
1218 .map(|a| rewrite_expr(a, group_exprs, aggs))
1219 .collect(),
1220 },
1221 Expr::Like {
1222 expr,
1223 pattern,
1224 negated,
1225 case_insensitive,
1226 } => Expr::Like {
1227 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1228 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
1229 negated: *negated,
1230 case_insensitive: *case_insensitive,
1231 },
1232 Expr::Extract { field, source } => Expr::Extract {
1233 field: *field,
1234 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
1235 },
1236 Expr::ScalarSubquery(s) => {
1242 Expr::ScalarSubquery(Box::new(rewrite_group_keys_in_select(s, group_exprs)))
1243 }
1244 Expr::Exists { subquery, negated } => Expr::Exists {
1245 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
1246 negated: *negated,
1247 },
1248 Expr::InSubquery {
1249 expr,
1250 subquery,
1251 negated,
1252 } => Expr::InSubquery {
1253 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1254 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
1255 negated: *negated,
1256 },
1257 Expr::WindowFunction { .. } | Expr::Literal(_) | Expr::Placeholder(_) | Expr::Column(_) => {
1260 e.clone()
1261 }
1262 Expr::Array(items) => Expr::Array(
1264 items
1265 .iter()
1266 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
1267 .collect(),
1268 ),
1269 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
1270 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
1271 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
1272 },
1273 Expr::AnyAll {
1274 expr,
1275 op,
1276 array,
1277 is_any,
1278 } => Expr::AnyAll {
1279 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
1280 op: *op,
1281 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
1282 is_any: *is_any,
1283 },
1284 Expr::Case {
1285 operand,
1286 branches,
1287 else_branch,
1288 } => Expr::Case {
1289 operand: operand
1290 .as_deref()
1291 .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
1292 branches: branches
1293 .iter()
1294 .map(|(w, t)| {
1295 (
1296 rewrite_expr(w, group_exprs, aggs),
1297 rewrite_expr(t, group_exprs, aggs),
1298 )
1299 })
1300 .collect(),
1301 else_branch: else_branch
1302 .as_deref()
1303 .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
1304 },
1305 }
1306}
1307
1308fn rewrite_group_keys_in_select(
1313 s: &spg_sql::ast::SelectStatement,
1314 group_exprs: &[Expr],
1315) -> spg_sql::ast::SelectStatement {
1316 let mut out = s.clone();
1317 let _ = crate::walk_select_exprs_mut(&mut out, &mut |e| {
1318 *e = rewrite_expr(e, group_exprs, &[]);
1319 Ok(())
1320 });
1321 out
1322}
1323
1324fn encode_one(out: &mut String, v: &Value) {
1327 match v {
1328 Value::Null => out.push_str("N|"),
1329 Value::SmallInt(n) => {
1330 out.push('s');
1331 out.push_str(&n.to_string());
1332 out.push('|');
1333 }
1334 Value::Int(n) => {
1335 out.push('I');
1336 out.push_str(&n.to_string());
1337 out.push('|');
1338 }
1339 Value::BigInt(n) => {
1340 out.push('B');
1341 out.push_str(&n.to_string());
1342 out.push('|');
1343 }
1344 Value::Float(x) => {
1345 out.push('F');
1346 out.push_str(&x.to_string());
1347 out.push('|');
1348 }
1349 Value::Bool(b) => {
1350 out.push(if *b { 'T' } else { 'f' });
1351 out.push('|');
1352 }
1353 Value::Text(s) => {
1354 out.push('S');
1355 out.push_str(s);
1356 out.push('|');
1357 }
1358 Value::Vector(v) => {
1359 out.push('V');
1360 for x in v {
1361 out.push_str(&x.to_string());
1362 out.push(',');
1363 }
1364 out.push('|');
1365 }
1366 Value::Sq8Vector(q) => {
1372 out.push('Q');
1373 out.push_str(&q.min.to_string());
1374 out.push('@');
1375 out.push_str(&q.max.to_string());
1376 out.push(':');
1377 for b in &q.bytes {
1378 out.push_str(&b.to_string());
1379 out.push(',');
1380 }
1381 out.push('|');
1382 }
1383 Value::HalfVector(h) => {
1387 out.push('H');
1388 for b in &h.bytes {
1389 out.push_str(&b.to_string());
1390 out.push(',');
1391 }
1392 out.push('|');
1393 }
1394 Value::Numeric { scaled, scale } => {
1395 out.push('D');
1396 out.push_str(&scaled.to_string());
1397 out.push('@');
1398 out.push_str(&scale.to_string());
1399 out.push('|');
1400 }
1401 Value::Date(d) => {
1402 out.push('d');
1403 out.push_str(&d.to_string());
1404 out.push('|');
1405 }
1406 Value::Timestamp(t) => {
1407 out.push('t');
1408 out.push_str(&t.to_string());
1409 out.push('|');
1410 }
1411 Value::Interval { months, micros } => {
1412 out.push('i');
1413 out.push_str(&months.to_string());
1414 out.push('m');
1415 out.push_str(µs.to_string());
1416 out.push('|');
1417 }
1418 Value::Json(s) => {
1419 out.push('j');
1420 out.push_str(s);
1421 out.push('|');
1422 }
1423 _ => {
1428 out.push('?');
1429 out.push_str(&format!("{v:?}"));
1430 out.push('|');
1431 }
1432 }
1433}
1434
1435fn encode_key_refs(vals: &[&Value]) -> String {
1438 let mut out = String::new();
1439 for v in vals {
1440 encode_one(&mut out, v);
1441 }
1442 out
1443}
1444
1445pub(crate) fn encode_key(vals: &[Value]) -> String {
1446 let mut out = String::new();
1447 for v in vals {
1448 encode_one(&mut out, v);
1449 }
1450 out
1451}
1452
1453#[allow(clippy::cast_precision_loss)]
1454fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
1455 use core::cmp::Ordering::Equal;
1456 match (a, b) {
1457 (Value::Null, Value::Null) => Equal,
1458 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
1460 (Value::Int(x), Value::Int(y)) => x.cmp(y),
1461 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
1462 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
1463 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
1464 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
1465 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
1466 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
1467 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
1468 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
1469 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1470 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1471 _ => Equal,
1472 }
1473}