1use alloc::boxed::Box;
24use alloc::collections::BTreeSet;
25use alloc::format;
26use alloc::string::{String, ToString};
27use alloc::vec::Vec;
28
29use spg_sql::ast::{Expr, SelectItem, SelectStatement};
30use spg_storage::{ColumnSchema, DataType, Row, Value};
31
32use crate::eval::{self, EvalContext, EvalError};
33use crate::join::RowRef;
34
35pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
37 if stmt.group_by.is_some() || stmt.having.is_some() {
38 return true;
39 }
40 for item in &stmt.items {
41 if let SelectItem::Expr { expr, .. } = item
42 && contains_aggregate(expr)
43 {
44 return true;
45 }
46 }
47 for o in &stmt.order_by {
48 if contains_aggregate(&o.expr) {
49 return true;
50 }
51 }
52 if let Some(h) = &stmt.having
53 && contains_aggregate(h)
54 {
55 return true;
56 }
57 false
58}
59
60pub fn contains_aggregate(e: &Expr) -> bool {
61 match e {
62 Expr::FunctionCall { name, args } => {
63 is_aggregate_name(name) || args.iter().any(contains_aggregate)
64 }
65 Expr::AggregateOrdered { .. } => true,
66 Expr::Binary { lhs, rhs, .. } => contains_aggregate(lhs) || contains_aggregate(rhs),
67 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
68 contains_aggregate(expr)
69 }
70 Expr::Like { expr, pattern, .. } => contains_aggregate(expr) || contains_aggregate(pattern),
71 Expr::Extract { source, .. } => contains_aggregate(source),
72 Expr::ScalarSubquery(_)
77 | Expr::Exists { .. }
78 | Expr::InSubquery { .. }
79 | Expr::WindowFunction { .. }
80 | Expr::Literal(_)
81 | Expr::Placeholder(_)
82 | Expr::Column(_) => false,
83 Expr::Array(items) => items.iter().any(contains_aggregate),
87 Expr::ArraySubscript { target, index } => {
88 contains_aggregate(target) || contains_aggregate(index)
89 }
90 Expr::AnyAll { expr, array, .. } => contains_aggregate(expr) || contains_aggregate(array),
91 Expr::InList { expr, list, .. } => {
92 contains_aggregate(expr) || list.iter().any(contains_aggregate)
93 }
94 Expr::Case {
97 operand,
98 branches,
99 else_branch,
100 } => {
101 operand.as_deref().is_some_and(contains_aggregate)
102 || branches
103 .iter()
104 .any(|(w, t)| contains_aggregate(w) || contains_aggregate(t))
105 || else_branch.as_deref().is_some_and(contains_aggregate)
106 }
107 }
108}
109
/// Returns `true` when `name` is one of the aggregate function names this
/// module recognises.
///
/// The comparison ignores ASCII case and does not allocate.
pub fn is_aggregate_name(name: &str) -> bool {
    /// Every recognised aggregate name, in lower case.
    const AGGREGATE_NAMES: &[&str] = &[
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
        "stddev",
        "stddev_samp",
        "stddev_pop",
        "variance",
        "var_samp",
        "var_pop",
        "bit_and",
        "bit_or",
        "bit_xor",
        "percentile_cont",
        "percentile_disc",
        "mode",
        "rank",
        "dense_rank",
        "percent_rank",
        "cume_dist",
        "covar_pop",
        "covar_samp",
        "corr",
        "regr_count",
        "regr_avgx",
        "regr_avgy",
        "regr_slope",
        "regr_intercept",
        "regr_r2",
        "regr_sxx",
        "regr_syy",
        "regr_sxy",
        "json_agg",
        "jsonb_agg",
        "json_object_agg",
        "jsonb_object_agg",
    ];
    AGGREGATE_NAMES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
150
/// Returns `true` when `name` is one of the two-argument statistical
/// aggregates (`covar_*`, `corr`, `regr_*`).
///
/// The comparison is exact (case-sensitive).
fn is_regression_name(name: &str) -> bool {
    const REGRESSION_NAMES: [&str; 12] = [
        "covar_pop",
        "covar_samp",
        "corr",
        "regr_count",
        "regr_avgx",
        "regr_avgy",
        "regr_slope",
        "regr_intercept",
        "regr_r2",
        "regr_sxx",
        "regr_syy",
        "regr_sxy",
    ];
    REGRESSION_NAMES.iter().any(|known| *known == name)
}
169
170fn agg_uses_second_arg(name: &str) -> bool {
174 name == "string_agg"
175 || name == "json_object_agg"
176 || name == "jsonb_object_agg"
177 || is_regression_name(name)
178}
179
/// Returns `true` for `percentile_cont`, `percentile_disc` and `mode`,
/// ignoring ASCII case.
pub fn is_ordered_set_name(name: &str) -> bool {
    name.eq_ignore_ascii_case("percentile_cont")
        || name.eq_ignore_ascii_case("percentile_disc")
        || name.eq_ignore_ascii_case("mode")
}
193
/// Returns `true` for `rank`, `dense_rank`, `percent_rank` and `cume_dist`,
/// ignoring ASCII case.
pub fn is_hypothetical_set_name(name: &str) -> bool {
    for candidate in ["rank", "dense_rank", "percent_rank", "cume_dist"] {
        if candidate.eq_ignore_ascii_case(name) {
            return true;
        }
    }
    false
}
203
204pub fn is_within_group_name(name: &str) -> bool {
207 is_ordered_set_name(name) || is_hypothetical_set_name(name)
208}
209
210#[derive(Debug, Default, Clone)]
212struct AggState {
213 count: i64,
214 sum_int: i64,
215 sum_float: f64,
216 extreme: Option<Value>,
217 use_float: bool,
218 items: Vec<Value>,
225 seen: BTreeSet<String>,
229 item_keys: Vec<Vec<Value>>,
233 separator: Option<String>,
239 bool_acc: Option<bool>,
243 sum_sq: f64,
246 bit_acc: Option<i64>,
249 reg_n: i64,
254 reg_sx: f64,
255 reg_sy: f64,
256 reg_sxx: f64,
257 reg_syy: f64,
258 reg_sxy: f64,
259 aux_items: Vec<Value>,
262}
263
264#[derive(Debug, Clone)]
265struct AggSpec {
266 name: String, arg: Option<Expr>,
270 arg2: Option<Expr>,
276 distinct: bool,
279 order_by: Vec<spg_sql::ast::OrderBy>,
285 filter: Option<Expr>,
290 direct_arg: Option<Expr>,
295}
296
297#[derive(Debug)]
300pub struct AggResult {
301 pub columns: Vec<ColumnSchema>,
302 pub rows: Vec<Row>,
303 pub deferred: Vec<(usize, Expr)>,
311 pub synth_rows: Vec<Row>,
314 pub synth_schema: Vec<ColumnSchema>,
316}
317
318#[allow(clippy::too_many_lines)]
321pub type CorrelatedEval<'a> = &'a dyn Fn(&Expr, &Row, &EvalContext<'_>) -> Result<Value, EvalError>;
328
329struct Projection {
334 columns: Vec<ColumnSchema>,
335 out_rows: Vec<Row>,
336 kept_synth: Vec<Row>,
337 deferred: Vec<(usize, Expr)>,
338 order_rewritten: Vec<Expr>,
339}
340
341pub(crate) fn run(
342 stmt: &SelectStatement,
343 rows: &[RowRef<'_>],
344 schema_cols: &[ColumnSchema],
345 table_alias: Option<&str>,
346 correlated_eval: Option<CorrelatedEval<'_>>,
347) -> Result<AggResult, EvalError> {
348 let group_exprs: Vec<Expr> = stmt.group_by.clone().unwrap_or_default();
349
350 let mut agg_specs: Vec<AggSpec> = Vec::new();
352 for item in &stmt.items {
353 if let SelectItem::Expr { expr, .. } = item {
354 collect_aggregates(expr, &mut agg_specs);
355 }
356 }
357 for o in &stmt.order_by {
358 collect_aggregates(&o.expr, &mut agg_specs);
359 }
360 if let Some(h) = &stmt.having {
361 collect_aggregates(h, &mut agg_specs);
362 }
363 validate_agg_arities(stmt, &agg_specs)?;
369 validate_within_group(&agg_specs)?;
370
371 let order = accumulate_groups(
373 rows,
374 &group_exprs,
375 &agg_specs,
376 schema_cols,
377 table_alias,
378 correlated_eval,
379 )?;
380
381 let synth_schema =
383 build_synth_schema(rows, &group_exprs, &agg_specs, schema_cols, table_alias)?;
384 let synth_rows = finalize_synth_rows(
385 &order,
386 &agg_specs,
387 &synth_schema,
388 rows,
389 schema_cols,
390 table_alias,
391 )?;
392
393 let Projection {
395 columns,
396 mut out_rows,
397 mut kept_synth,
398 deferred,
399 order_rewritten,
400 } = project_groups(
401 synth_rows,
402 stmt,
403 &group_exprs,
404 &agg_specs,
405 &synth_schema,
406 correlated_eval,
407 )?;
408
409 if !stmt.order_by.is_empty() {
411 let (sorted_synth, sorted_out) = sort_synth_by_order_by(
412 &synth_schema,
413 &stmt.order_by,
414 &order_rewritten,
415 kept_synth,
416 out_rows,
417 correlated_eval,
418 )?;
419 kept_synth = sorted_synth;
420 out_rows = sorted_out;
421 }
422
423 let (synth_rows_out, synth_schema_out) = if deferred.is_empty() {
424 (Vec::new(), Vec::new())
425 } else {
426 (kept_synth, synth_schema.clone())
427 };
428 Ok(AggResult {
429 columns,
430 rows: out_rows,
431 deferred,
432 synth_rows: synth_rows_out,
433 synth_schema: synth_schema_out,
434 })
435}
436
437fn validate_within_group(agg_specs: &[AggSpec]) -> Result<(), EvalError> {
442 for spec in agg_specs {
446 if is_within_group_name(&spec.name) {
447 if spec.order_by.is_empty() {
448 return Err(EvalError::TypeMismatch {
449 detail: format!("{}() requires WITHIN GROUP (ORDER BY …)", spec.name),
450 });
451 }
452 if spec.name != "mode" && spec.direct_arg.is_none() {
456 return Err(EvalError::TypeMismatch {
457 detail: format!("{}() requires a direct argument", spec.name),
458 });
459 }
460 if spec.order_by.len() > 1 {
464 return Err(EvalError::TypeMismatch {
465 detail: format!(
466 "{}() with multiple WITHIN GROUP sort keys is not supported yet",
467 spec.name
468 ),
469 });
470 }
471 }
472 }
473 Ok(())
474}
475
476#[allow(clippy::too_many_lines, clippy::type_complexity)]
480fn accumulate_groups(
481 rows: &[RowRef<'_>],
482 group_exprs: &[Expr],
483 agg_specs: &[AggSpec],
484 schema_cols: &[ColumnSchema],
485 table_alias: Option<&str>,
486 correlated_eval: Option<CorrelatedEval<'_>>,
487) -> Result<Vec<(Vec<Value>, Vec<AggState>)>, EvalError> {
488 let ctx = EvalContext::new(schema_cols, table_alias);
489 let mut order: Vec<(Vec<Value>, Vec<AggState>)> = Vec::new();
496 let mut groups: hashbrown::HashMap<String, usize> = hashbrown::HashMap::new();
497 if rows.is_empty() && group_exprs.is_empty() {
500 let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
503 order.push((Vec::new(), init));
504 }
505
506 let col_pos = |e: &Expr| -> Option<usize> {
517 if let Expr::Column(c) = e
520 && c.qualifier.is_some()
521 {
522 eval::find_column_pos(c, &ctx)
523 } else {
524 None
525 }
526 };
527 let group_pos: Vec<Option<usize>> = group_exprs.iter().map(col_pos).collect();
528 let all_groups_bound = group_pos.iter().all(Option::is_some);
529 let arg_pos: Vec<Option<usize>> = agg_specs
530 .iter()
531 .map(|spec| spec.arg.as_ref().and_then(|e| col_pos(e)))
532 .collect();
533 let ci_positions: Vec<usize> = group_exprs
534 .iter()
535 .enumerate()
536 .filter(|(_, g)| {
537 matches!(
538 eval::column_collation(g, &ctx),
539 Some(spg_storage::Collation::CaseInsensitive)
540 )
541 })
542 .map(|(i, _)| i)
543 .collect();
544 let mut keybuf_s = String::new();
549 let mut dkeybuf = String::new();
550 let mut refs: Vec<&Value> = Vec::with_capacity(group_pos.len());
551 let any_agg_subquery = correlated_eval.is_some()
565 && agg_specs.iter().any(|s| {
566 s.filter
567 .as_ref()
568 .is_some_and(|e| crate::expr_has_subquery(e))
569 || s.arg.as_ref().is_some_and(|e| crate::expr_has_subquery(e))
570 || s.arg2.as_ref().is_some_and(|e| crate::expr_has_subquery(e))
571 || s.order_by.iter().any(|o| crate::expr_has_subquery(&o.expr))
572 });
573 let eval_arg = |e: &Expr, r: &Row, c: &EvalContext<'_>| -> Result<Value, EvalError> {
574 match correlated_eval {
575 Some(f) if any_agg_subquery && crate::expr_has_subquery(e) => f(e, r, c),
576 _ => eval::eval_expr(e, r, c),
577 }
578 };
579 for row in rows {
580 if all_groups_bound && ci_positions.is_empty() && !group_exprs.is_empty() {
584 refs.clear();
585 refs.extend(
586 group_pos
587 .iter()
588 .map(|p| row.get(p.unwrap()).unwrap_or(&Value::Null)),
589 );
590 encode_key_refs_into(&refs, &mut keybuf_s);
591 let idx = match groups.get(keybuf_s.as_str()) {
592 Some(&i) => i,
593 None => {
594 let i = order.len();
595 let init: Vec<AggState> =
596 (0..agg_specs.len()).map(|_| AggState::default()).collect();
597 let owned: Vec<Value> = refs.iter().map(|v| (*v).clone()).collect();
598 order.push((owned, init));
599 groups.insert(keybuf_s.clone(), i);
600 i
601 }
602 };
603 let entry = &mut order[idx];
604 for (i, spec) in agg_specs.iter().enumerate() {
605 if let Some(f) = &spec.filter
609 && !matches!(eval_arg(f, &row.as_row(), &ctx)?, Value::Bool(true))
610 {
611 continue;
612 }
613 let arg_owned: Value;
614 let arg_ref: &Value = match (&arg_pos[i], &spec.arg) {
615 (Some(p), _) => row.get(*p).unwrap_or(&Value::Null),
616 (None, None) => {
617 arg_owned = Value::Bool(true);
618 &arg_owned
619 }
620 (None, Some(e)) => {
621 arg_owned = eval_arg(e, &row.as_row(), &ctx)?;
622 &arg_owned
623 }
624 };
625 let arg2_val = match &spec.arg2 {
626 None => None,
627 Some(e) => Some(eval_arg(e, &row.as_row(), &ctx)?),
628 };
629 let order_keys = if spec.order_by.is_empty() {
630 None
631 } else {
632 let mut keys = Vec::with_capacity(spec.order_by.len());
633 for o in &spec.order_by {
634 keys.push(eval_arg(&o.expr, &row.as_row(), &ctx)?);
635 }
636 Some(keys)
637 };
638 if spec.distinct {
639 encode_key_refs_into(core::slice::from_ref(&arg_ref), &mut dkeybuf);
640 if entry.1[i].seen.contains(dkeybuf.as_str()) {
641 continue;
642 }
643 entry.1[i].seen.insert(dkeybuf.clone());
644 }
645 update_state(
646 &mut entry.1[i],
647 &spec.name,
648 arg_ref,
649 arg2_val.as_ref(),
650 order_keys,
651 )?;
652 }
653 continue;
654 }
655 let row_materialised = row.as_row();
660 let row: &Row = &row_materialised;
661 let group_vals: Vec<Value> = group_exprs
662 .iter()
663 .map(|g| eval::eval_expr(g, row, &ctx))
664 .collect::<Result<_, _>>()?;
665 let key = if ci_positions.is_empty() {
669 encode_key(&group_vals)
670 } else {
671 let mut key_vals = group_vals.clone();
672 for &i in &ci_positions {
673 if let Value::Text(s) = &key_vals[i] {
674 key_vals[i] = Value::Text(s.to_ascii_lowercase());
675 }
676 }
677 encode_key(&key_vals)
678 };
679 let idx = match groups.get(key.as_str()) {
681 Some(&i) => i,
682 None => {
683 let i = order.len();
684 let init: Vec<AggState> =
685 (0..agg_specs.len()).map(|_| AggState::default()).collect();
686 order.push((group_vals.clone(), init));
687 groups.insert(key, i);
688 i
689 }
690 };
691 let entry = &mut order[idx];
692 for (i, spec) in agg_specs.iter().enumerate() {
693 if let Some(f) = &spec.filter
696 && !matches!(eval_arg(f, row, &ctx)?, Value::Bool(true))
697 {
698 continue;
699 }
700 let arg_val = match &spec.arg {
701 None => Value::Bool(true), Some(e) => eval_arg(e, row, &ctx)?,
703 };
704 let arg2_val = match &spec.arg2 {
710 None => None,
711 Some(e) => Some(eval_arg(e, row, &ctx)?),
712 };
713 let order_keys = if spec.order_by.is_empty() {
716 None
717 } else {
718 let mut keys = Vec::with_capacity(spec.order_by.len());
719 for o in &spec.order_by {
720 keys.push(eval_arg(&o.expr, row, &ctx)?);
721 }
722 Some(keys)
723 };
724 if spec.distinct {
729 let key = encode_key(core::slice::from_ref(&arg_val));
730 if !entry.1[i].seen.insert(key) {
731 continue;
732 }
733 }
734 update_state(
735 &mut entry.1[i],
736 &spec.name,
737 &arg_val,
738 arg2_val.as_ref(),
739 order_keys,
740 )?;
741 }
742 }
743 Ok(order)
744}
745
746fn build_synth_schema(
750 rows: &[RowRef<'_>],
751 group_exprs: &[Expr],
752 agg_specs: &[AggSpec],
753 schema_cols: &[ColumnSchema],
754 table_alias: Option<&str>,
755) -> Result<Vec<ColumnSchema>, EvalError> {
756 let ctx = EvalContext::new(schema_cols, table_alias);
757 let group_types: Vec<DataType> = if rows.is_empty() {
759 group_exprs.iter().map(|_| DataType::Text).collect()
762 } else {
763 let probe_row = rows[0].as_row();
764 let probe: &Row = &probe_row;
765 group_exprs
766 .iter()
767 .map(|g| {
768 eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
769 })
770 .collect::<Result<_, _>>()?
771 };
772 let agg_types: Vec<DataType> = agg_specs
773 .iter()
774 .map(|spec| infer_agg_type(spec, schema_cols))
775 .collect();
776 let mut synth_schema: Vec<ColumnSchema> = Vec::new();
777 for (i, ty) in group_types.iter().enumerate() {
778 synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
779 }
780 for (i, ty) in agg_types.iter().enumerate() {
781 synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
782 }
783 Ok(synth_schema)
784}
785
786fn finalize_synth_rows(
790 order: &[(Vec<Value>, Vec<AggState>)],
791 agg_specs: &[AggSpec],
792 synth_schema: &[ColumnSchema],
793 rows: &[RowRef<'_>],
794 schema_cols: &[ColumnSchema],
795 table_alias: Option<&str>,
796) -> Result<Vec<Row>, EvalError> {
797 let ctx = EvalContext::new(schema_cols, table_alias);
798 let direct_arg_vals: Vec<Option<Value>> = agg_specs
801 .iter()
802 .map(|spec| match (&spec.direct_arg, rows.first()) {
803 (Some(e), Some(r)) => eval::eval_expr(e, &r.as_row(), &ctx).map(Some),
804 _ => Ok(None),
805 })
806 .collect::<Result<_, _>>()?;
807
808 let mut synth_rows: Vec<Row> = Vec::new();
810 for (gvals, states) in order {
811 let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
812 values.extend(gvals.iter().cloned());
813 for (i, st) in states.iter().enumerate() {
814 let st_sorted;
818 let st_final: &AggState =
819 if !agg_specs[i].order_by.is_empty() && st.item_keys.len() == st.items.len() {
820 let mut idx: Vec<usize> = (0..st.items.len()).collect();
821 let ob = &agg_specs[i].order_by;
822 idx.sort_by(|&x, &y| {
823 for (k, o) in ob.iter().enumerate() {
824 let cmp = crate::order_by_value_cmp(
825 o.desc,
826 o.nulls_first,
827 &st.item_keys[x][k],
828 &st.item_keys[y][k],
829 );
830 if cmp != core::cmp::Ordering::Equal {
831 return cmp;
832 }
833 }
834 core::cmp::Ordering::Equal
835 });
836 let mut sorted = st.clone();
837 sorted.items = idx.iter().map(|&j| st.items[j].clone()).collect();
838 st_sorted = sorted;
839 &st_sorted
840 } else {
841 st
842 };
843 let v = if is_within_group_name(&agg_specs[i].name) {
846 finalize_ordered_set(
847 &agg_specs[i].name,
848 st_final,
849 direct_arg_vals[i].as_ref(),
850 agg_specs[i].order_by.first(),
851 )
852 } else {
853 finalize(&agg_specs[i].name, st_final)
854 };
855 values.push(v);
856 }
857 synth_rows.push(Row::new(values));
858 }
859 Ok(synth_rows)
860}
861
862#[allow(clippy::too_many_lines)]
867fn project_groups(
868 synth_rows: Vec<Row>,
869 stmt: &SelectStatement,
870 group_exprs: &[Expr],
871 agg_specs: &[AggSpec],
872 synth_schema: &[ColumnSchema],
873 correlated_eval: Option<CorrelatedEval<'_>>,
874) -> Result<Projection, EvalError> {
875 let columns: Vec<ColumnSchema> = stmt
880 .items
881 .iter()
882 .map(|item| match item {
883 SelectItem::Wildcard => Err(EvalError::TypeMismatch {
884 detail: "SELECT * with aggregates is not supported".into(),
885 }),
886 SelectItem::Expr { expr, alias } => {
887 let rewritten = rewrite_expr(expr, group_exprs, agg_specs);
888 let name = alias.clone().unwrap_or_else(|| expr.to_string());
889 Ok(ColumnSchema::new(
890 name,
891 agg_or_group_type(&rewritten, synth_schema),
892 true,
893 ))
894 }
895 })
896 .collect::<Result<_, _>>()?;
897
898 let synth_ctx = EvalContext::new(synth_schema, None);
903 let having_rewritten = stmt
904 .having
905 .as_ref()
906 .map(|h| rewrite_expr(h, group_exprs, agg_specs));
907 let items_rewritten: alloc::vec::Vec<Option<Expr>> = stmt
913 .items
914 .iter()
915 .map(|item| match item {
916 SelectItem::Expr { expr, .. } => Some(rewrite_expr(expr, group_exprs, agg_specs)),
917 SelectItem::Wildcard => None,
918 })
919 .collect();
920 let order_rewritten: Vec<Expr> = stmt
925 .order_by
926 .iter()
927 .map(|o| rewrite_expr(&o.expr, group_exprs, agg_specs))
928 .collect();
929 let defer_enabled = correlated_eval.is_some()
930 && !stmt.distinct
931 && !having_rewritten
932 .as_ref()
933 .is_some_and(crate::expr_has_subquery)
934 && !order_rewritten.iter().any(crate::expr_has_subquery);
935 let deferred: Vec<(usize, Expr)> = if defer_enabled {
936 items_rewritten
937 .iter()
938 .enumerate()
939 .filter_map(|(i, r)| {
940 r.as_ref()
941 .filter(|e| crate::expr_has_subquery(e))
942 .map(|e| (i, e.clone()))
943 })
944 .collect()
945 } else {
946 Vec::new()
947 };
948 let having_compiled = having_rewritten
954 .as_ref()
955 .filter(|h| eval::fully_compilable(h))
956 .map(|h| eval::compile_expr(h, &synth_ctx));
957 let items_compiled: Vec<Option<eval::CompiledExpr>> = items_rewritten
958 .iter()
959 .enumerate()
960 .map(|(i, r)| {
961 r.as_ref()
962 .filter(|e| !deferred.iter().any(|(c, _)| *c == i) && eval::fully_compilable(e))
963 .map(|e| eval::compile_expr(e, &synth_ctx))
964 })
965 .collect();
966 let mut kept_synth: Vec<Row> = Vec::new();
967 let mut out_rows: Vec<Row> = Vec::new();
968 let mut stack: Vec<Value> = Vec::new();
969 for srow in synth_rows {
970 if let Some(hc) = &having_compiled {
971 let cond = eval::eval_compiled(hc, &srow, &synth_ctx, &mut stack)?;
972 if !matches!(cond, Value::Bool(true)) {
973 continue;
974 }
975 } else if let Some(h) = &having_rewritten {
976 let cond = match correlated_eval {
977 Some(f) if crate::expr_has_subquery(h) => f(h, &srow, &synth_ctx)?,
978 _ => eval::eval_expr(h, &srow, &synth_ctx)?,
979 };
980 if !matches!(cond, Value::Bool(true)) {
981 continue;
982 }
983 }
984 let mut values: Vec<Value> = Vec::with_capacity(columns.len());
985 for (i, rewritten) in items_rewritten.iter().enumerate() {
986 let Some(rewritten) = rewritten else { continue };
987 if deferred.iter().any(|(c, _)| *c == i) {
988 values.push(Value::Null);
989 continue;
990 }
991 values.push(if let Some(cc) = &items_compiled[i] {
992 eval::eval_compiled(cc, &srow, &synth_ctx, &mut stack)?
993 } else {
994 match correlated_eval {
995 Some(f) if crate::expr_has_subquery(rewritten) => {
996 f(rewritten, &srow, &synth_ctx)?
997 }
998 _ => eval::eval_expr(rewritten, &srow, &synth_ctx)?,
999 }
1000 });
1001 }
1002 kept_synth.push(srow);
1003 out_rows.push(Row::new(values));
1004 }
1005 Ok(Projection {
1006 columns,
1007 out_rows,
1008 kept_synth,
1009 deferred,
1010 order_rewritten,
1011 })
1012}
1013
1014fn sort_synth_by_order_by(
1018 synth_schema: &[ColumnSchema],
1019 order_by: &[spg_sql::ast::OrderBy],
1020 order_rewritten: &[Expr],
1021 mut kept_synth: Vec<Row>,
1022 mut out_rows: Vec<Row>,
1023 correlated_eval: Option<CorrelatedEval<'_>>,
1024) -> Result<(Vec<Row>, Vec<Row>), EvalError> {
1025 let synth_ctx = EvalContext::new(synth_schema, None);
1026 let keys_meta: Vec<(bool, Option<bool>)> =
1031 order_by.iter().map(|o| (o.desc, o.nulls_first)).collect();
1032 let order_compiled: Vec<Option<eval::CompiledExpr>> = order_rewritten
1035 .iter()
1036 .map(|e| {
1037 Some(e)
1038 .filter(|e| eval::fully_compilable(e))
1039 .map(|e| eval::compile_expr(e, &synth_ctx))
1040 })
1041 .collect();
1042 let mut keystack: Vec<Value> = Vec::new();
1046 let mut tagged: Vec<(Vec<Value>, Row, Row)> = Vec::with_capacity(kept_synth.len());
1047 for (s, o) in kept_synth.into_iter().zip(out_rows) {
1048 let mut keys = Vec::with_capacity(order_rewritten.len());
1049 for (e, oc) in order_rewritten.iter().zip(&order_compiled) {
1050 keys.push(if let Some(oc) = oc {
1051 eval::eval_compiled(oc, &s, &synth_ctx, &mut keystack)?
1052 } else {
1053 match correlated_eval {
1054 Some(f) if crate::expr_has_subquery(e) => f(e, &s, &synth_ctx)?,
1055 _ => eval::eval_expr(e, &s, &synth_ctx)?,
1056 }
1057 });
1058 }
1059 tagged.push((keys, s, o));
1060 }
1061 tagged.sort_by(|a, b| {
1062 use core::cmp::Ordering;
1063 for (i, (ka, kb)) in a.0.iter().zip(b.0.iter()).enumerate() {
1064 let (desc, nf) = keys_meta[i];
1065 let cmp = crate::order_by_value_cmp(desc, nf, ka, kb);
1066 if cmp != Ordering::Equal {
1067 return cmp;
1068 }
1069 }
1070 Ordering::Equal
1071 });
1072 kept_synth = Vec::with_capacity(tagged.len());
1073 out_rows = Vec::with_capacity(tagged.len());
1074 for (_, s, o) in tagged {
1075 kept_synth.push(s);
1076 out_rows.push(o);
1077 }
1078 Ok((kept_synth, out_rows))
1079}
1080
1081fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
1087 fn walk(e: &Expr) -> Result<(), EvalError> {
1088 if let Expr::FunctionCall { name, args } = e {
1089 let lower = name.to_ascii_lowercase();
1090 let expected: Option<usize> = match lower.as_str() {
1091 "count_star" => Some(0),
1092 "count" | "sum" | "avg" | "min" | "max" | "array_agg"
1093 | "bool_and" | "bool_or" | "every"
1097 | "stddev" | "stddev_samp" | "stddev_pop"
1100 | "variance" | "var_samp" | "var_pop"
1101 | "bit_and" | "bit_or" | "bit_xor"
1102 | "json_agg" | "jsonb_agg" => Some(1),
1103 "string_agg"
1106 | "covar_pop" | "covar_samp" | "corr"
1107 | "regr_count" | "regr_avgx" | "regr_avgy" | "regr_slope"
1108 | "regr_intercept" | "regr_r2" | "regr_sxx" | "regr_syy" | "regr_sxy"
1109 | "json_object_agg" | "jsonb_object_agg" => Some(2),
1110 _ => None,
1111 };
1112 if let Some(want) = expected
1113 && args.len() != want
1114 {
1115 return Err(EvalError::TypeMismatch {
1116 detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
1117 });
1118 }
1119 for a in args {
1120 walk(a)?;
1121 }
1122 } else if let Expr::Binary { lhs, rhs, .. } = e {
1123 walk(lhs)?;
1124 walk(rhs)?;
1125 } else if let Expr::Unary { expr, .. }
1126 | Expr::Cast { expr, .. }
1127 | Expr::IsNull { expr, .. } = e
1128 {
1129 walk(expr)?;
1130 }
1131 Ok(())
1132 }
1133 for item in &stmt.items {
1134 if let SelectItem::Expr { expr, .. } = item {
1135 walk(expr)?;
1136 }
1137 }
1138 for o in &stmt.order_by {
1139 walk(&o.expr)?;
1140 }
1141 if let Some(h) = &stmt.having {
1142 walk(h)?;
1143 }
1144 Ok(())
1145}
1146
1147fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
1148 match e {
1149 Expr::AggregateOrdered {
1152 call,
1153 order_by,
1154 distinct,
1155 filter,
1156 } => {
1157 if let Expr::FunctionCall { name, args } = call.as_ref() {
1158 let lower = name.to_ascii_lowercase();
1159 if is_aggregate_name(&lower) {
1160 let canonical = if lower == "every" {
1161 "bool_and".to_string()
1162 } else {
1163 lower
1164 };
1165 let ordered_set = is_within_group_name(&canonical);
1170 let (arg, direct_arg) = if ordered_set {
1171 (
1172 order_by.first().map(|o| o.expr.clone()),
1173 args.first().cloned(),
1174 )
1175 } else {
1176 (args.first().cloned(), None)
1177 };
1178 let spec = AggSpec {
1179 name: canonical.clone(),
1180 arg,
1181 arg2: if agg_uses_second_arg(&canonical) {
1182 args.get(1).cloned()
1183 } else {
1184 None
1185 },
1186 distinct: *distinct,
1187 order_by: order_by.clone(),
1188 filter: filter.as_deref().cloned(),
1189 direct_arg,
1190 };
1191 if !out.iter().any(|s| {
1192 s.name == spec.name
1193 && s.arg == spec.arg
1194 && s.arg2 == spec.arg2
1195 && s.distinct == spec.distinct
1196 && s.order_by == spec.order_by
1197 && s.filter == spec.filter
1198 && s.direct_arg == spec.direct_arg
1199 }) {
1200 out.push(spec);
1201 }
1202 return;
1203 }
1204 }
1205 collect_aggregates(call, out);
1206 for o in order_by {
1207 collect_aggregates(&o.expr, out);
1208 }
1209 }
1210 Expr::FunctionCall { name, args } => {
1211 let lower = name.to_ascii_lowercase();
1212 if is_aggregate_name(&lower) {
1213 let arg = if lower == "count_star" {
1214 None
1215 } else {
1216 args.first().cloned()
1217 };
1218 let arg2 = if agg_uses_second_arg(&lower) {
1222 args.get(1).cloned()
1223 } else {
1224 None
1225 };
1226 let canonical = if lower == "every" {
1230 "bool_and".to_string()
1231 } else {
1232 lower
1233 };
1234 let spec = AggSpec {
1235 name: canonical,
1236 arg: arg.clone(),
1237 arg2: arg2.clone(),
1238 distinct: false,
1239 order_by: Vec::new(),
1240 filter: None,
1241 direct_arg: None,
1242 };
1243 if !out.iter().any(|s| {
1244 s.name == spec.name
1245 && s.arg == spec.arg
1246 && s.arg2 == spec.arg2
1247 && !s.distinct
1248 && s.order_by == spec.order_by
1249 && s.filter.is_none()
1250 }) {
1251 out.push(spec);
1252 }
1253 } else {
1256 for a in args {
1257 collect_aggregates(a, out);
1258 }
1259 }
1260 }
1261 Expr::Binary { lhs, rhs, .. } => {
1262 collect_aggregates(lhs, out);
1263 collect_aggregates(rhs, out);
1264 }
1265 Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
1266 collect_aggregates(expr, out);
1267 }
1268 Expr::Like { expr, pattern, .. } => {
1269 collect_aggregates(expr, out);
1270 collect_aggregates(pattern, out);
1271 }
1272 Expr::InList { expr, list, .. } => {
1273 collect_aggregates(expr, out);
1274 for item in list {
1275 collect_aggregates(item, out);
1276 }
1277 }
1278 Expr::Extract { source, .. } => collect_aggregates(source, out),
1279 Expr::ScalarSubquery(_)
1282 | Expr::Exists { .. }
1283 | Expr::InSubquery { .. }
1284 | Expr::WindowFunction { .. }
1285 | Expr::Literal(_)
1286 | Expr::Placeholder(_)
1287 | Expr::Column(_) => {}
1288 Expr::Array(items) => {
1291 for elem in items {
1292 collect_aggregates(elem, out);
1293 }
1294 }
1295 Expr::ArraySubscript { target, index } => {
1296 collect_aggregates(target, out);
1297 collect_aggregates(index, out);
1298 }
1299 Expr::AnyAll { expr, array, .. } => {
1300 collect_aggregates(expr, out);
1301 collect_aggregates(array, out);
1302 }
1303 Expr::Case {
1304 operand,
1305 branches,
1306 else_branch,
1307 } => {
1308 if let Some(o) = operand {
1309 collect_aggregates(o, out);
1310 }
1311 for (w, t) in branches {
1312 collect_aggregates(w, out);
1313 collect_aggregates(t, out);
1314 }
1315 if let Some(e) = else_branch {
1316 collect_aggregates(e, out);
1317 }
1318 }
1319 }
1320}
1321
1322fn update_state(
1323 st: &mut AggState,
1324 name: &str,
1325 v: &Value,
1326 arg2: Option<&Value>,
1327 order_keys: Option<Vec<Value>>,
1328) -> Result<(), EvalError> {
1329 let is_null = matches!(v, Value::Null);
1330 match name {
1331 "count_star" => st.count += 1,
1332 "count" => {
1333 if !is_null {
1334 st.count += 1;
1335 }
1336 }
1337 "sum" | "avg" => {
1338 if is_null {
1339 return Ok(());
1340 }
1341 st.count += 1;
1342 match v {
1343 Value::Int(n) => st.sum_int += i64::from(*n),
1344 Value::BigInt(n) => st.sum_int += *n,
1345 Value::Float(x) => {
1346 st.use_float = true;
1347 st.sum_float += *x;
1348 }
1349 other => {
1350 return Err(EvalError::TypeMismatch {
1351 detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
1352 });
1353 }
1354 }
1355 }
1356 "min" => {
1357 if is_null {
1358 return Ok(());
1359 }
1360 match &st.extreme {
1361 None => st.extreme = Some(v.clone()),
1362 Some(cur) => {
1363 if value_cmp(v, cur) == core::cmp::Ordering::Less {
1364 st.extreme = Some(v.clone());
1365 }
1366 }
1367 }
1368 }
1369 "max" => {
1370 if is_null {
1371 return Ok(());
1372 }
1373 match &st.extreme {
1374 None => st.extreme = Some(v.clone()),
1375 Some(cur) => {
1376 if value_cmp(v, cur) == core::cmp::Ordering::Greater {
1377 st.extreme = Some(v.clone());
1378 }
1379 }
1380 }
1381 }
1382 "string_agg" => {
1390 if let Some(sep) = arg2
1391 && let Value::Text(s) = sep
1392 {
1393 st.separator = Some(s.clone());
1394 }
1395 if is_null {
1396 return Ok(());
1397 }
1398 if let Value::Text(s) = v {
1399 st.items.push(Value::Text(s.clone()));
1400 if let Some(k) = order_keys {
1401 st.item_keys.push(k);
1402 }
1403 st.count += 1;
1404 } else {
1405 return Err(EvalError::TypeMismatch {
1406 detail: format!("string_agg requires text value, got {:?}", v.data_type()),
1407 });
1408 }
1409 }
1410 "array_agg" => {
1416 st.items.push(v.clone());
1417 if let Some(k) = order_keys {
1418 st.item_keys.push(k);
1419 }
1420 st.count += 1;
1421 }
1422 "bool_and" => {
1426 if is_null {
1427 return Ok(());
1428 }
1429 let b = match v {
1430 Value::Bool(b) => *b,
1431 other => {
1432 return Err(EvalError::TypeMismatch {
1433 detail: format!("bool_and requires bool, got {:?}", other.data_type()),
1434 });
1435 }
1436 };
1437 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
1438 }
1439 "bool_or" => {
1442 if is_null {
1443 return Ok(());
1444 }
1445 let b = match v {
1446 Value::Bool(b) => *b,
1447 other => {
1448 return Err(EvalError::TypeMismatch {
1449 detail: format!("bool_or requires bool, got {:?}", other.data_type()),
1450 });
1451 }
1452 };
1453 st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
1454 }
1455 "stddev" | "stddev_samp" | "stddev_pop" | "variance" | "var_samp" | "var_pop" => {
1459 if is_null {
1460 return Ok(());
1461 }
1462 let x = match v {
1463 Value::Int(n) => f64::from(*n),
1464 Value::SmallInt(n) => f64::from(*n),
1465 Value::BigInt(n) => *n as f64,
1466 Value::Float(x) => *x,
1467 other => {
1468 return Err(EvalError::TypeMismatch {
1469 detail: format!("{name} needs numeric, got {:?}", other.data_type()),
1470 });
1471 }
1472 };
1473 st.count += 1;
1474 st.sum_float += x;
1475 st.sum_sq += x * x;
1476 }
1477 "bit_and" | "bit_or" | "bit_xor" => {
1479 if is_null {
1480 return Ok(());
1481 }
1482 let n = match v {
1483 Value::Int(n) => i64::from(*n),
1484 Value::SmallInt(n) => i64::from(*n),
1485 Value::BigInt(n) => *n,
1486 other => {
1487 return Err(EvalError::TypeMismatch {
1488 detail: format!("{name} needs integer, got {:?}", other.data_type()),
1489 });
1490 }
1491 };
1492 st.bit_acc = Some(match (st.bit_acc, name) {
1493 (None, _) => n,
1494 (Some(acc), "bit_and") => acc & n,
1495 (Some(acc), "bit_or") => acc | n,
1496 (Some(acc), _) => acc ^ n, });
1498 }
1499 n if is_within_group_name(n) => {
1504 if is_null {
1505 return Ok(());
1506 }
1507 st.items.push(v.clone());
1508 if let Some(k) = order_keys {
1509 st.item_keys.push(k);
1510 }
1511 st.count += 1;
1512 }
1513 n if is_regression_name(n) => {
1517 let (Some(y), Some(x)) = (agg_value_to_f64(v), arg2.and_then(agg_value_to_f64)) else {
1518 return Ok(()); };
1520 st.reg_n += 1;
1521 st.reg_sx += x;
1522 st.reg_sy += y;
1523 st.reg_sxx += x * x;
1524 st.reg_syy += y * y;
1525 st.reg_sxy += x * y;
1526 }
1527 "json_agg" | "jsonb_agg" => {
1530 st.items.push(v.clone());
1531 st.count += 1;
1532 }
1533 "json_object_agg" | "jsonb_object_agg" => {
1537 if is_null {
1538 return Ok(());
1539 }
1540 st.items.push(v.clone());
1541 st.aux_items.push(arg2.cloned().unwrap_or(Value::Null));
1542 st.count += 1;
1543 }
1544 _ => unreachable!("non-aggregate {name} in update_state"),
1545 }
1546 Ok(())
1547}
1548
1549#[allow(clippy::cast_precision_loss)]
1550fn finalize(name: &str, st: &AggState) -> Value {
1551 match name {
1552 "count" | "count_star" => Value::BigInt(st.count),
1553 "sum" => {
1554 if st.count == 0 {
1555 Value::Null
1556 } else if st.use_float {
1557 Value::Float(st.sum_float + (st.sum_int as f64))
1558 } else {
1559 Value::BigInt(st.sum_int)
1560 }
1561 }
1562 "avg" => {
1563 if st.count == 0 {
1564 Value::Null
1565 } else {
1566 let total = if st.use_float {
1567 st.sum_float + (st.sum_int as f64)
1568 } else {
1569 st.sum_int as f64
1570 };
1571 Value::Float(total / (st.count as f64))
1572 }
1573 }
1574 "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
1575 "string_agg" => {
1579 if st.items.is_empty() {
1580 return Value::Null;
1581 }
1582 let sep = st.separator.clone().unwrap_or_default();
1583 let mut out = String::new();
1584 for (i, item) in st.items.iter().enumerate() {
1585 if i > 0 {
1586 out.push_str(&sep);
1587 }
1588 if let Value::Text(s) = item {
1589 out.push_str(s);
1590 }
1591 }
1592 Value::Text(out)
1593 }
1594 "array_agg" => {
1601 if st.items.is_empty() {
1602 return Value::Null;
1603 }
1604 let probe = st.items.iter().find(|v| !v.is_null());
1605 match probe.and_then(spg_storage::Value::data_type) {
1606 Some(DataType::Int) | Some(DataType::SmallInt) => {
1607 let items: Vec<Option<i32>> = st
1608 .items
1609 .iter()
1610 .map(|v| match v {
1611 Value::Int(n) => Some(*n),
1612 Value::SmallInt(n) => Some(i32::from(*n)),
1613 _ => None,
1614 })
1615 .collect();
1616 Value::IntArray(items)
1617 }
1618 Some(DataType::BigInt) => {
1619 let items: Vec<Option<i64>> = st
1620 .items
1621 .iter()
1622 .map(|v| match v {
1623 Value::BigInt(n) => Some(*n),
1624 _ => None,
1625 })
1626 .collect();
1627 Value::BigIntArray(items)
1628 }
1629 _ => {
1630 let items: Vec<Option<String>> = st
1631 .items
1632 .iter()
1633 .map(|v| match v {
1634 Value::Text(s) => Some(s.clone()),
1635 Value::Null => None,
1636 other => Some(format!("{other:?}")),
1637 })
1638 .collect();
1639 Value::TextArray(items)
1640 }
1641 }
1642 }
1643 "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
1647 "variance" | "var_samp" | "var_pop" | "stddev" | "stddev_samp" | "stddev_pop" => {
1651 let n = st.count;
1652 if n == 0 {
1653 return Value::Null;
1654 }
1655 let nf = n as f64;
1656 let ss = st.sum_sq - (st.sum_float * st.sum_float) / nf;
1658 let pop = name.ends_with("_pop");
1659 let denom = if pop { nf } else { nf - 1.0 };
1660 if denom <= 0.0 {
1661 return Value::Null;
1663 }
1664 let var = (ss / denom).max(0.0); if name.starts_with("stddev") {
1666 Value::Float(crate::eval::f64_sqrt(var))
1667 } else {
1668 Value::Float(var)
1669 }
1670 }
1671 "bit_and" | "bit_or" | "bit_xor" => st.bit_acc.map_or(Value::Null, Value::BigInt),
1674 "regr_count" => Value::BigInt(st.reg_n),
1678 "covar_pop" | "covar_samp" | "corr" | "regr_avgx" | "regr_avgy" | "regr_slope"
1679 | "regr_intercept" | "regr_r2" | "regr_sxx" | "regr_syy" | "regr_sxy" => {
1680 let n = st.reg_n;
1681 if n == 0 {
1682 return Value::Null;
1683 }
1684 let nf = n as f64;
1685 let sxx = st.reg_sxx - st.reg_sx * st.reg_sx / nf;
1686 let syy = st.reg_syy - st.reg_sy * st.reg_sy / nf;
1687 let sxy = st.reg_sxy - st.reg_sx * st.reg_sy / nf;
1688 let avgx = st.reg_sx / nf;
1689 let avgy = st.reg_sy / nf;
1690 let out = match name {
1691 "regr_avgx" => Some(avgx),
1692 "regr_avgy" => Some(avgy),
1693 "regr_sxx" => Some(sxx),
1694 "regr_syy" => Some(syy),
1695 "regr_sxy" => Some(sxy),
1696 "covar_pop" => Some(sxy / nf),
1697 "covar_samp" => (n >= 2).then(|| sxy / (nf - 1.0)),
1698 "regr_slope" => (sxx != 0.0).then(|| sxy / sxx),
1699 "regr_intercept" => (sxx != 0.0).then(|| avgy - (sxy / sxx) * avgx),
1700 "corr" => {
1701 let d = sxx * syy;
1702 (d > 0.0).then(|| sxy / crate::eval::f64_sqrt(d))
1703 }
1704 "regr_r2" => {
1706 if sxx == 0.0 {
1707 None
1708 } else if syy == 0.0 {
1709 Some(1.0)
1710 } else {
1711 Some((sxy * sxy) / (sxx * syy))
1712 }
1713 }
1714 _ => None,
1715 };
1716 out.map_or(Value::Null, Value::Float)
1717 }
1718 "json_agg" | "jsonb_agg" => {
1721 if st.items.is_empty() {
1722 return Value::Null;
1723 }
1724 let mut out = String::from("[");
1725 for (i, item) in st.items.iter().enumerate() {
1726 if i > 0 {
1727 out.push_str(", ");
1728 }
1729 out.push_str(&crate::json::value_to_json_text(item));
1730 }
1731 out.push(']');
1732 Value::Json(out)
1733 }
1734 "json_object_agg" | "jsonb_object_agg" => {
1737 if st.items.is_empty() {
1738 return Value::Null;
1739 }
1740 let mut out = String::from("{");
1741 for (i, key) in st.items.iter().enumerate() {
1742 if i > 0 {
1743 out.push_str(", ");
1744 }
1745 let key_text = match key {
1747 Value::Text(s) | Value::Json(s) => s.clone(),
1748 other => crate::json::value_to_json_text(other),
1749 };
1750 out.push_str(&crate::json::value_to_json_text(&Value::Text(key_text)));
1751 out.push_str(": ");
1752 let val = st.aux_items.get(i).unwrap_or(&Value::Null);
1753 out.push_str(&crate::json::value_to_json_text(val));
1754 }
1755 out.push('}');
1756 Value::Json(out)
1757 }
1758 _ => unreachable!(),
1761 }
1762}
1763
1764fn agg_value_to_f64(v: &Value) -> Option<f64> {
1766 match v {
1767 Value::Int(n) => Some(f64::from(*n)),
1768 Value::SmallInt(n) => Some(f64::from(*n)),
1769 Value::BigInt(n) => Some(*n as f64),
1770 Value::Float(x) => Some(*x),
1771 _ => None,
1772 }
1773}
1774
1775#[allow(
1782 clippy::cast_precision_loss,
1783 clippy::cast_possible_truncation,
1784 clippy::cast_sign_loss
1785)]
1786fn finalize_ordered_set(
1787 name: &str,
1788 st: &AggState,
1789 direct: Option<&Value>,
1790 order: Option<&spg_sql::ast::OrderBy>,
1791) -> Value {
1792 let fraction = direct;
1793 let items = &st.items;
1794 if items.is_empty() {
1795 return match name {
1798 "rank" | "dense_rank" => Value::BigInt(1),
1799 "percent_rank" => Value::Float(0.0),
1800 "cume_dist" => Value::Float(1.0),
1801 _ => Value::Null,
1802 };
1803 }
1804 let n = items.len();
1805 match name {
1806 "rank" | "dense_rank" | "percent_rank" | "cume_dist" => {
1809 let Some(h) = fraction else {
1810 return Value::Null;
1811 };
1812 let (desc, nulls_first) = order.map_or((false, None), |o| (o.desc, o.nulls_first));
1813 let mut before = 0usize; let mut before_or_eq = 0usize; let mut distinct_before = 0usize;
1816 let mut last_before: Option<&Value> = None;
1817 for it in items {
1818 match crate::order_by_value_cmp(desc, nulls_first, it, h) {
1819 core::cmp::Ordering::Less => {
1820 before += 1;
1821 before_or_eq += 1;
1822 if last_before
1823 .is_none_or(|p| value_cmp(p, it) != core::cmp::Ordering::Equal)
1824 {
1825 distinct_before += 1;
1826 last_before = Some(it);
1827 }
1828 }
1829 core::cmp::Ordering::Equal => before_or_eq += 1,
1830 core::cmp::Ordering::Greater => {}
1831 }
1832 }
1833 let nn = n as f64;
1834 match name {
1835 "rank" => Value::BigInt((before + 1) as i64),
1836 "dense_rank" => Value::BigInt((distinct_before + 1) as i64),
1837 "percent_rank" => Value::Float(before as f64 / nn),
1838 "cume_dist" => Value::Float((before_or_eq as f64 + 1.0) / (nn + 1.0)),
1839 _ => unreachable!(),
1840 }
1841 }
1842 "mode" => {
1846 let (mut best_i, mut best_cnt) = (0usize, 1usize);
1847 let (mut run_i, mut run_cnt) = (0usize, 1usize);
1848 for i in 1..n {
1849 if value_cmp(&items[i], &items[run_i]) == core::cmp::Ordering::Equal {
1850 run_cnt += 1;
1851 } else {
1852 run_i = i;
1853 run_cnt = 1;
1854 }
1855 if run_cnt > best_cnt {
1856 best_cnt = run_cnt;
1857 best_i = run_i;
1858 }
1859 }
1860 items[best_i].clone()
1861 }
1862 "percentile_disc" => {
1864 let f = fraction
1865 .and_then(agg_value_to_f64)
1866 .unwrap_or(0.0)
1867 .clamp(0.0, 1.0);
1868 let idx = if f <= 0.0 {
1869 0
1870 } else {
1871 (crate::eval::f64_ceil(f * n as f64) as usize)
1872 .saturating_sub(1)
1873 .min(n - 1)
1874 };
1875 items[idx].clone()
1876 }
1877 "percentile_cont" => {
1879 let f = fraction
1880 .and_then(agg_value_to_f64)
1881 .unwrap_or(0.0)
1882 .clamp(0.0, 1.0);
1883 let Some(nums) = items
1884 .iter()
1885 .map(agg_value_to_f64)
1886 .collect::<Option<Vec<f64>>>()
1887 else {
1888 return Value::Null; };
1890 if n == 1 {
1891 return Value::Float(nums[0]);
1892 }
1893 let rank = f * (n as f64 - 1.0);
1894 let lo = crate::eval::f64_floor(rank) as usize;
1895 let hi = crate::eval::f64_ceil(rank) as usize;
1896 let frac = rank - lo as f64;
1897 Value::Float(nums[lo] + (nums[hi] - nums[lo]) * frac)
1898 }
1899 _ => unreachable!(),
1900 }
1901}
1902
1903fn infer_agg_type(spec: &AggSpec, schema_cols: &[ColumnSchema]) -> DataType {
1904 let arg_ty = spec
1908 .arg
1909 .as_ref()
1910 .and_then(|a| crate::describe::describe_expr(a, schema_cols))
1911 .map(|shape| shape.ty);
1912 match spec.name.as_str() {
1913 "count" | "count_star" => DataType::BigInt,
1914 "sum" => match arg_ty {
1915 Some(DataType::Float) => DataType::Float,
1916 _ => DataType::BigInt,
1917 },
1918 "avg" => DataType::Float,
1919 "string_agg" => DataType::Text,
1921 "array_agg" => match arg_ty {
1922 Some(DataType::Int | DataType::SmallInt) => DataType::IntArray,
1923 Some(DataType::BigInt) => DataType::BigIntArray,
1924 _ => DataType::TextArray,
1925 },
1926 "bool_and" | "bool_or" => DataType::Bool,
1929 "stddev" | "stddev_samp" | "stddev_pop" | "variance" | "var_samp" | "var_pop"
1933 | "percentile_cont" | "covar_pop" | "covar_samp" | "corr" | "regr_avgx" | "regr_avgy"
1934 | "regr_slope" | "regr_intercept" | "regr_r2" | "regr_sxx" | "regr_syy" | "regr_sxy" => {
1935 DataType::Float
1936 }
1937 "bit_and" | "bit_or" | "bit_xor" | "regr_count" | "rank" | "dense_rank" => DataType::BigInt,
1940 "percent_rank" | "cume_dist" => DataType::Float,
1942 "json_agg" | "jsonb_agg" | "json_object_agg" | "jsonb_object_agg" => DataType::Json,
1944 _ => arg_ty.unwrap_or(DataType::Text),
1948 }
1949}
1950
1951fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
1952 if let Expr::Column(c) = e
1953 && let Some(s) = synth.iter().find(|s| s.name == c.name)
1954 {
1955 return s.ty;
1956 }
1957 crate::describe::describe_expr(e, synth)
1963 .map(|shape| shape.ty)
1964 .unwrap_or(DataType::Text)
1965}
1966
1967fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
1968 if let Expr::AggregateOrdered {
1971 call,
1972 order_by,
1973 distinct,
1974 filter,
1975 } = e
1976 && let Expr::FunctionCall { name, args } = call.as_ref()
1977 {
1978 let lower = name.to_ascii_lowercase();
1979 if is_aggregate_name(&lower) {
1980 let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
1981 let (arg, direct_arg) = if is_within_group_name(canonical) {
1984 (
1985 order_by.first().map(|o| o.expr.clone()),
1986 args.first().cloned(),
1987 )
1988 } else {
1989 (args.first().cloned(), None)
1990 };
1991 let arg2 = if agg_uses_second_arg(canonical) {
1992 args.get(1).cloned()
1993 } else {
1994 None
1995 };
1996 let filter_owned = filter.as_deref().cloned();
1997 for (i, spec) in aggs.iter().enumerate() {
1998 if spec.name == canonical
1999 && spec.arg == arg
2000 && spec.arg2 == arg2
2001 && spec.distinct == *distinct
2002 && spec.order_by == *order_by
2003 && spec.filter == filter_owned
2004 && spec.direct_arg == direct_arg
2005 {
2006 return Expr::Column(spg_sql::ast::ColumnName {
2007 qualifier: None,
2008 name: format!("__agg_{i}"),
2009 });
2010 }
2011 }
2012 }
2013 }
2014 if let Expr::FunctionCall { name, args } = e {
2016 let lower = name.to_ascii_lowercase();
2017 if is_aggregate_name(&lower) {
2018 let arg = if lower == "count_star" {
2019 None
2020 } else {
2021 args.first().cloned()
2022 };
2023 let arg2 = if agg_uses_second_arg(&lower) {
2027 args.get(1).cloned()
2028 } else {
2029 None
2030 };
2031 let canonical: &str = if lower == "every" {
2035 "bool_and"
2036 } else {
2037 lower.as_str()
2038 };
2039 for (i, spec) in aggs.iter().enumerate() {
2040 if spec.name == canonical
2041 && spec.arg == arg
2042 && spec.arg2 == arg2
2043 && !spec.distinct
2044 && spec.order_by.is_empty()
2045 {
2046 return Expr::Column(spg_sql::ast::ColumnName {
2047 qualifier: None,
2048 name: format!("__agg_{i}"),
2049 });
2050 }
2051 }
2052 }
2053 }
2054 for (i, g) in group_exprs.iter().enumerate() {
2056 if g == e {
2057 return Expr::Column(spg_sql::ast::ColumnName {
2058 qualifier: None,
2059 name: format!("__grp_{i}"),
2060 });
2061 }
2062 }
2063 match e {
2065 Expr::AggregateOrdered {
2066 call,
2067 order_by,
2068 distinct,
2069 filter,
2070 } => Expr::AggregateOrdered {
2071 call: Box::new(rewrite_expr(call, group_exprs, aggs)),
2072 distinct: *distinct,
2073 order_by: order_by
2074 .iter()
2075 .map(|o| spg_sql::ast::OrderBy {
2076 expr: rewrite_expr(&o.expr, group_exprs, aggs),
2077 desc: o.desc,
2078 nulls_first: o.nulls_first,
2079 })
2080 .collect(),
2081 filter: filter.clone(),
2084 },
2085 Expr::Binary { lhs, op, rhs } => Expr::Binary {
2086 lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
2087 op: *op,
2088 rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
2089 },
2090 Expr::Unary { op, expr } => Expr::Unary {
2091 op: *op,
2092 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
2093 },
2094 Expr::Cast { expr, target } => Expr::Cast {
2095 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
2096 target: *target,
2097 },
2098 Expr::IsNull { expr, negated } => Expr::IsNull {
2099 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
2100 negated: *negated,
2101 },
2102 Expr::FunctionCall { name, args } => Expr::FunctionCall {
2103 name: name.clone(),
2104 args: args
2105 .iter()
2106 .map(|a| rewrite_expr(a, group_exprs, aggs))
2107 .collect(),
2108 },
2109 Expr::Like {
2110 expr,
2111 pattern,
2112 negated,
2113 case_insensitive,
2114 } => Expr::Like {
2115 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
2116 pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
2117 negated: *negated,
2118 case_insensitive: *case_insensitive,
2119 },
2120 Expr::Extract { field, source } => Expr::Extract {
2121 field: *field,
2122 source: Box::new(rewrite_expr(source, group_exprs, aggs)),
2123 },
2124 Expr::ScalarSubquery(s) => {
2130 Expr::ScalarSubquery(Box::new(rewrite_group_keys_in_select(s, group_exprs)))
2131 }
2132 Expr::Exists { subquery, negated } => Expr::Exists {
2133 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
2134 negated: *negated,
2135 },
2136 Expr::InSubquery {
2137 expr,
2138 subquery,
2139 negated,
2140 } => Expr::InSubquery {
2141 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
2142 subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
2143 negated: *negated,
2144 },
2145 Expr::WindowFunction { .. } | Expr::Literal(_) | Expr::Placeholder(_) | Expr::Column(_) => {
2148 e.clone()
2149 }
2150 Expr::Array(items) => Expr::Array(
2152 items
2153 .iter()
2154 .map(|elem| rewrite_expr(elem, group_exprs, aggs))
2155 .collect(),
2156 ),
2157 Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
2158 target: Box::new(rewrite_expr(target, group_exprs, aggs)),
2159 index: Box::new(rewrite_expr(index, group_exprs, aggs)),
2160 },
2161 Expr::AnyAll {
2162 expr,
2163 op,
2164 array,
2165 is_any,
2166 } => Expr::AnyAll {
2167 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
2168 op: *op,
2169 array: Box::new(rewrite_expr(array, group_exprs, aggs)),
2170 is_any: *is_any,
2171 },
2172 Expr::InList {
2173 expr,
2174 list,
2175 negated,
2176 } => Expr::InList {
2177 expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
2178 list: list
2179 .iter()
2180 .map(|item| rewrite_expr(item, group_exprs, aggs))
2181 .collect(),
2182 negated: *negated,
2183 },
2184 Expr::Case {
2185 operand,
2186 branches,
2187 else_branch,
2188 } => Expr::Case {
2189 operand: operand
2190 .as_deref()
2191 .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
2192 branches: branches
2193 .iter()
2194 .map(|(w, t)| {
2195 (
2196 rewrite_expr(w, group_exprs, aggs),
2197 rewrite_expr(t, group_exprs, aggs),
2198 )
2199 })
2200 .collect(),
2201 else_branch: else_branch
2202 .as_deref()
2203 .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
2204 },
2205 }
2206}
2207
2208fn rewrite_group_keys_in_select(
2213 s: &spg_sql::ast::SelectStatement,
2214 group_exprs: &[Expr],
2215) -> spg_sql::ast::SelectStatement {
2216 let mut out = s.clone();
2217 let _ = crate::walk_select_exprs_mut(&mut out, &mut |e| {
2218 *e = rewrite_expr(e, group_exprs, &[]);
2219 Ok(())
2220 });
2221 out
2222}
2223
2224fn encode_one(out: &mut String, v: &Value) {
2227 match v {
2228 Value::Null => out.push_str("N|"),
2229 Value::SmallInt(n) => {
2230 out.push('s');
2231 out.push_str(&n.to_string());
2232 out.push('|');
2233 }
2234 Value::Int(n) => {
2235 out.push('I');
2236 out.push_str(&n.to_string());
2237 out.push('|');
2238 }
2239 Value::BigInt(n) => {
2240 out.push('B');
2241 out.push_str(&n.to_string());
2242 out.push('|');
2243 }
2244 Value::Float(x) => {
2245 out.push('F');
2246 out.push_str(&x.to_string());
2247 out.push('|');
2248 }
2249 Value::Bool(b) => {
2250 out.push(if *b { 'T' } else { 'f' });
2251 out.push('|');
2252 }
2253 Value::Text(s) => {
2254 out.push('S');
2255 out.push_str(s);
2256 out.push('|');
2257 }
2258 Value::Vector(v) => {
2259 out.push('V');
2260 for x in v {
2261 out.push_str(&x.to_string());
2262 out.push(',');
2263 }
2264 out.push('|');
2265 }
2266 Value::Sq8Vector(q) => {
2272 out.push('Q');
2273 out.push_str(&q.min.to_string());
2274 out.push('@');
2275 out.push_str(&q.max.to_string());
2276 out.push(':');
2277 for b in &q.bytes {
2278 out.push_str(&b.to_string());
2279 out.push(',');
2280 }
2281 out.push('|');
2282 }
2283 Value::HalfVector(h) => {
2287 out.push('H');
2288 for b in &h.bytes {
2289 out.push_str(&b.to_string());
2290 out.push(',');
2291 }
2292 out.push('|');
2293 }
2294 Value::Numeric { scaled, scale } => {
2295 out.push('D');
2296 out.push_str(&scaled.to_string());
2297 out.push('@');
2298 out.push_str(&scale.to_string());
2299 out.push('|');
2300 }
2301 Value::Date(d) => {
2302 out.push('d');
2303 out.push_str(&d.to_string());
2304 out.push('|');
2305 }
2306 Value::Timestamp(t) => {
2307 out.push('t');
2308 out.push_str(&t.to_string());
2309 out.push('|');
2310 }
2311 Value::Interval { months, micros } => {
2312 out.push('i');
2313 out.push_str(&months.to_string());
2314 out.push('m');
2315 out.push_str(µs.to_string());
2316 out.push('|');
2317 }
2318 Value::Json(s) => {
2319 out.push('j');
2320 out.push_str(s);
2321 out.push('|');
2322 }
2323 _ => {
2328 out.push('?');
2329 out.push_str(&format!("{v:?}"));
2330 out.push('|');
2331 }
2332 }
2333}
2334
2335pub(crate) fn encode_key_refs(vals: &[&Value]) -> String {
2338 let mut out = String::new();
2339 for v in vals {
2340 encode_one(&mut out, v);
2341 }
2342 out
2343}
2344
2345pub(crate) fn encode_key_refs_into(vals: &[&Value], out: &mut String) {
2351 out.clear();
2352 for v in vals {
2353 encode_one(out, v);
2354 }
2355}
2356
2357pub(crate) fn encode_key(vals: &[Value]) -> String {
2358 let mut out = String::new();
2359 for v in vals {
2360 encode_one(&mut out, v);
2361 }
2362 out
2363}
2364
2365#[allow(clippy::cast_precision_loss)]
2366fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
2367 use core::cmp::Ordering::Equal;
2368 match (a, b) {
2369 (Value::Null, Value::Null) => Equal,
2370 (Value::Null, _) => core::cmp::Ordering::Greater, (_, Value::Null) => core::cmp::Ordering::Less,
2372 (Value::Int(x), Value::Int(y)) => x.cmp(y),
2373 (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
2374 (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
2375 (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
2376 (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
2377 (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
2378 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
2379 (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
2380 (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
2381 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2382 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2383 _ => Equal,
2384 }
2385}