use alloc::borrow::Cow;
use alloc::boxed::Box;
use alloc::collections::BTreeSet;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use spg_sql::ast::{Expr, SelectItem, SelectStatement};
use spg_storage::{ColumnSchema, DataType, Row, Value};
use crate::eval::{self, EvalContext, EvalError};
use crate::join::RowRef;
/// Returns `true` when `stmt` must be executed through the aggregate path.
///
/// That is the case when the statement has a GROUP BY or a HAVING clause, or
/// when an aggregate call (see `contains_aggregate`) appears in a select-list
/// expression or in an ORDER BY expression.
pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
    // Any HAVING clause forces aggregation on its own, so there is no need to
    // look inside it for aggregate calls afterwards.
    if stmt.group_by.is_some() || stmt.having.is_some() {
        return true;
    }
    let in_select_list = stmt
        .items
        .iter()
        .any(|item| matches!(item, SelectItem::Expr { expr, .. } if contains_aggregate(expr)));
    in_select_list || stmt.order_by.iter().any(|o| contains_aggregate(&o.expr))
}
/// Reports whether an aggregate call occurs anywhere inside `e`.
///
/// An `AggregateOrdered` node always counts. A plain `FunctionCall` counts when
/// its name is an aggregate (see `is_aggregate_name`) or when one of its
/// arguments contains one. Subqueries and window functions are not searched,
/// and leaves (literals, placeholders, columns) never count.
pub fn contains_aggregate(e: &Expr) -> bool {
    match e {
        Expr::AggregateOrdered { .. } => true,
        Expr::ScalarSubquery(_)
        | Expr::Exists { .. }
        | Expr::InSubquery { .. }
        | Expr::WindowFunction { .. }
        | Expr::Literal(_)
        | Expr::Placeholder(_)
        | Expr::Column(_) => false,
        Expr::FunctionCall { name, args } => {
            if is_aggregate_name(name) {
                true
            } else {
                args.iter().any(contains_aggregate)
            }
        }
        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
            contains_aggregate(expr)
        }
        Expr::Extract { source, .. } => contains_aggregate(source),
        Expr::Binary { lhs, rhs, .. } => [lhs, rhs].into_iter().any(|side| contains_aggregate(side)),
        Expr::Like { expr, pattern, .. } => {
            contains_aggregate(expr) || contains_aggregate(pattern)
        }
        Expr::AnyAll { expr, array, .. } => {
            contains_aggregate(expr) || contains_aggregate(array)
        }
        Expr::ArraySubscript { target, index } => {
            contains_aggregate(target) || contains_aggregate(index)
        }
        Expr::Array(elements) => elements.iter().any(contains_aggregate),
        Expr::InList { expr, list, .. } => {
            contains_aggregate(expr) || list.iter().any(contains_aggregate)
        }
        Expr::Case {
            operand,
            branches,
            else_branch,
        } => {
            let in_branches = branches
                .iter()
                .any(|(when, then)| contains_aggregate(when) || contains_aggregate(then));
            in_branches
                || operand.as_deref().is_some_and(contains_aggregate)
                || else_branch.as_deref().is_some_and(contains_aggregate)
        }
    }
}
/// Returns `true` if `name`, compared ASCII case-insensitively, is one of the
/// aggregate functions this module accumulates.
///
/// The list includes the ordered-set and hypothetical-set functions used with
/// `WITHIN GROUP`, and `count_star`, which is the zero-argument form that
/// stands for `count(*)`.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: &[&str] = &[
        "count",
        "count_star",
        "sum",
        "min",
        "max",
        "avg",
        "string_agg",
        "array_agg",
        "bool_and",
        "bool_or",
        "every",
        "stddev",
        "stddev_samp",
        "stddev_pop",
        "variance",
        "var_samp",
        "var_pop",
        "bit_and",
        "bit_or",
        "bit_xor",
        "percentile_cont",
        "percentile_disc",
        "mode",
        "rank",
        "dense_rank",
        "percent_rank",
        "cume_dist",
        "covar_pop",
        "covar_samp",
        "corr",
        "regr_count",
        "regr_avgx",
        "regr_avgy",
        "regr_slope",
        "regr_intercept",
        "regr_r2",
        "regr_sxx",
        "regr_syy",
        "regr_sxy",
        "json_agg",
        "jsonb_agg",
        "json_object_agg",
        "jsonb_object_agg",
    ];
    AGGREGATES.iter().any(|known| known.eq_ignore_ascii_case(name))
}
/// Returns `true` for the two-argument statistical aggregates (covariance,
/// correlation and the `regr_*` family).
///
/// The comparison is exact (case-sensitive), so callers pass the already
/// lowercased canonical name.
fn is_regression_name(name: &str) -> bool {
    const REGRESSION: [&str; 12] = [
        "covar_pop",
        "covar_samp",
        "corr",
        "regr_count",
        "regr_avgx",
        "regr_avgy",
        "regr_slope",
        "regr_intercept",
        "regr_r2",
        "regr_sxx",
        "regr_syy",
        "regr_sxy",
    ];
    REGRESSION.contains(&name)
}
/// Returns `true` when the aggregate `name` (canonical, lowercase) consumes a
/// second argument.
///
/// That argument is the separator for `string_agg`, the value for
/// `json[b]_object_agg`, and the X input for the regression family.
fn agg_uses_second_arg(name: &str) -> bool {
    matches!(name, "string_agg" | "json_object_agg" | "jsonb_object_agg")
        || is_regression_name(name)
}
/// Returns `true` for the ordered-set aggregates (`percentile_cont`,
/// `percentile_disc`, `mode`), compared ASCII case-insensitively.
pub fn is_ordered_set_name(name: &str) -> bool {
    const ORDERED_SET: [&str; 3] = ["percentile_cont", "percentile_disc", "mode"];
    for candidate in ORDERED_SET {
        if candidate.eq_ignore_ascii_case(name) {
            return true;
        }
    }
    false
}
/// Returns `true` for the hypothetical-set aggregates (`rank`, `dense_rank`,
/// `percent_rank`, `cume_dist`), compared ASCII case-insensitively.
pub fn is_hypothetical_set_name(name: &str) -> bool {
    const HYPOTHETICAL_SET: [&str; 4] = ["rank", "dense_rank", "percent_rank", "cume_dist"];
    for candidate in HYPOTHETICAL_SET {
        if candidate.eq_ignore_ascii_case(name) {
            return true;
        }
    }
    false
}
/// Returns `true` for any aggregate that is written with
/// `WITHIN GROUP (ORDER BY …)`: the hypothetical-set and ordered-set families.
pub fn is_within_group_name(name: &str) -> bool {
    is_hypothetical_set_name(name) || is_ordered_set_name(name)
}
/// Running accumulator for one aggregate within one group.
///
/// One struct serves every aggregate kind. Each aggregate only touches the
/// fields it needs (see `update_state`); the remaining fields stay at their
/// `Default` values.
#[derive(Debug, Default, Clone)]
struct AggState {
    /// Number of inputs folded in. What counts depends on the aggregate:
    /// `count_star` counts every row, most others only non-NULL inputs.
    count: i64,
    /// Integer part of a `sum`/`avg`.
    sum_int: i64,
    /// Float part of a `sum`/`avg`; Σx for the variance/stddev family.
    sum_float: f64,
    /// Current `min`/`max` candidate.
    extreme: Option<Value>,
    /// Set once a `sum`/`avg` sees a float input; the result is then a float.
    use_float: bool,
    /// Collected values for `string_agg`, `array_agg`, `json[b]_agg` and the
    /// WITHIN GROUP aggregates; also the keys of `json[b]_object_agg`.
    items: Vec<Value>,
    /// Encoded argument values already folded in, used to skip repeats for
    /// DISTINCT aggregates.
    seen: BTreeSet<String>,
    /// Per-item ORDER BY key tuples, parallel to `items` when recorded.
    item_keys: Vec<Vec<Value>>,
    /// `string_agg` separator (the last text separator seen).
    separator: Option<String>,
    /// `bool_and`/`bool_or` accumulator; `None` until a non-NULL input arrives.
    bool_acc: Option<bool>,
    /// Σx² for the variance/stddev family.
    sum_sq: f64,
    /// `bit_and`/`bit_or`/`bit_xor` accumulator; `None` until a non-NULL input.
    bit_acc: Option<i64>,
    /// Regression/covariance sums over accepted (Y, X) pairs, where Y is the
    /// first argument and X the second: pair count, Σx, Σy, Σx², Σy², Σxy.
    reg_n: i64,
    reg_sx: f64,
    reg_sy: f64,
    reg_sxx: f64,
    reg_syy: f64,
    reg_sxy: f64,
    /// Values of `json[b]_object_agg`, parallel to `items` (which hold the keys).
    aux_items: Vec<Value>,
    /// For `first_ordered` specs: the best (ORDER BY keys, value) pair so far.
    first_best: Option<(Vec<Value>, Value)>,
}
/// One distinct aggregate call found in the statement, with everything needed
/// to accumulate it per group.
#[derive(Debug, Clone)]
struct AggSpec {
    /// Canonical lowercase name (`every` is stored as `bool_and`).
    name: String,
    /// Main argument. `None` for `count_star`. For WITHIN GROUP aggregates this
    /// is the first WITHIN GROUP sort expression.
    arg: Option<Expr>,
    /// Second argument, only for aggregates where `agg_uses_second_arg` holds.
    arg2: Option<Expr>,
    /// `DISTINCT` modifier.
    distinct: bool,
    /// Aggregate-level ORDER BY, or the WITHIN GROUP ordering.
    order_by: Vec<spg_sql::ast::OrderBy>,
    /// `FILTER (WHERE …)` predicate.
    filter: Option<Expr>,
    /// Direct argument of a WITHIN GROUP aggregate (e.g. the fraction of
    /// `percentile_cont`).
    direct_arg: Option<Expr>,
    /// `array_agg(x ORDER BY …)[1]` form: only the best element is tracked.
    first_ordered: bool,
}
/// Output of the aggregate executor `run`.
#[derive(Debug)]
pub struct AggResult {
    /// One nullable column per select item, named by its alias or, failing
    /// that, by the expression's text.
    pub columns: Vec<ColumnSchema>,
    /// Projected result rows, after HAVING filtering and ORDER BY sorting.
    pub rows: Vec<Row>,
    /// `(select-item index, rewritten expression)` for subquery-bearing items
    /// whose evaluation was deferred. Their slots in `rows` hold NULL.
    pub deferred: Vec<(usize, Expr)>,
    /// Synthetic per-group rows, index-aligned with `rows`. Empty unless
    /// `deferred` is non-empty.
    pub synth_rows: Vec<Row>,
    /// Schema of `synth_rows` (`__grp_*` columns, then `__agg_*` columns).
    /// Empty unless `deferred` is non-empty.
    pub synth_schema: Vec<ColumnSchema>,
}
/// Caller-supplied evaluator for expressions that contain subqueries.
///
/// This module calls it, instead of `eval::eval_expr`, for expressions where
/// `crate::expr_has_subquery` is true. It receives the expression, the row to
/// evaluate against, and that row's evaluation context. Presumably it also
/// resolves outer (correlated) references; confirm with its implementor.
// The `#[allow(clippy::too_many_lines)]` that used to sit here was removed: that
// lint only applies to function bodies, so on a type alias it did nothing.
pub type CorrelatedEval<'a> = &'a dyn Fn(&Expr, &Row, &EvalContext<'_>) -> Result<Value, EvalError>;
/// Intermediate result of `project_groups`: groups that passed HAVING, both as
/// projected output rows and as the synthetic rows they were computed from.
struct Projection {
    /// Output column schemas, one per select item.
    columns: Vec<ColumnSchema>,
    /// Projected rows, index-aligned with `kept_synth`.
    out_rows: Vec<Row>,
    /// Synthetic group rows that survived HAVING.
    kept_synth: Vec<Row>,
    /// Deferred `(select-item index, rewritten expression)` pairs.
    deferred: Vec<(usize, Expr)>,
    /// Statement ORDER BY expressions rewritten against the synthetic schema.
    order_rewritten: Vec<Expr>,
}
/// Evaluates an aggregating `SELECT` (GROUP BY / HAVING / aggregate calls)
/// over `rows`.
///
/// Pipeline:
/// 1. Collect every aggregate call in the select list, ORDER BY and HAVING
///    into deduplicated `AggSpec`s, then validate arities and WITHIN GROUP usage.
/// 2. Fold the rows into per-group accumulators (`accumulate_groups`).
/// 3. Finalize each group into a synthetic row `[group keys…, aggregates…]`.
/// 4. Apply HAVING and project the select list over the synthetic rows.
/// 5. Sort by the statement's ORDER BY, if any.
///
/// `synth_rows`/`synth_schema` in the result are only populated when some
/// select items were deferred (see `AggResult::deferred`); otherwise they are
/// empty.
///
/// # Errors
/// Arity or WITHIN GROUP validation failures, `SELECT *` combined with
/// aggregates, and any evaluation error raised while grouping, finalizing,
/// projecting or sorting.
pub(crate) fn run(
    stmt: &SelectStatement,
    rows: &[RowRef<'_>],
    schema_cols: &[ColumnSchema],
    table_alias: Option<&str>,
    correlated_eval: Option<CorrelatedEval<'_>>,
) -> Result<AggResult, EvalError> {
    // Borrow the GROUP BY list; it is only read, so cloning it is unnecessary.
    let group_exprs: &[Expr] = stmt.group_by.as_deref().unwrap_or_default();
    let mut agg_specs: Vec<AggSpec> = Vec::new();
    for item in &stmt.items {
        if let SelectItem::Expr { expr, .. } = item {
            collect_aggregates(expr, &mut agg_specs);
        }
    }
    for o in &stmt.order_by {
        collect_aggregates(&o.expr, &mut agg_specs);
    }
    if let Some(h) = &stmt.having {
        collect_aggregates(h, &mut agg_specs);
    }
    validate_agg_arities(stmt, &agg_specs)?;
    validate_within_group(&agg_specs)?;
    let order = accumulate_groups(
        rows,
        group_exprs,
        &agg_specs,
        schema_cols,
        table_alias,
        correlated_eval,
    )?;
    let synth_schema =
        build_synth_schema(rows, group_exprs, &agg_specs, schema_cols, table_alias)?;
    let synth_rows = finalize_synth_rows(
        &order,
        &agg_specs,
        &synth_schema,
        rows,
        schema_cols,
        table_alias,
    )?;
    let Projection {
        columns,
        mut out_rows,
        mut kept_synth,
        deferred,
        order_rewritten,
    } = project_groups(
        synth_rows,
        stmt,
        group_exprs,
        &agg_specs,
        &synth_schema,
        correlated_eval,
    )?;
    if !stmt.order_by.is_empty() {
        let (sorted_synth, sorted_out) = sort_synth_by_order_by(
            &synth_schema,
            &stmt.order_by,
            &order_rewritten,
            kept_synth,
            out_rows,
            correlated_eval,
        )?;
        kept_synth = sorted_synth;
        out_rows = sorted_out;
    }
    // The synthetic rows are only needed by the caller to evaluate deferred
    // items. `synth_schema` is not used after this point, so move it rather
    // than clone it.
    let (synth_rows_out, synth_schema_out) = if deferred.is_empty() {
        (Vec::new(), Vec::new())
    } else {
        (kept_synth, synth_schema)
    };
    Ok(AggResult {
        columns,
        rows: out_rows,
        deferred,
        synth_rows: synth_rows_out,
        synth_schema: synth_schema_out,
    })
}
/// Checks that every WITHIN GROUP aggregate is used in its supported shape.
///
/// Such an aggregate must:
/// * have a WITHIN GROUP ordering;
/// * have a direct argument, except `mode`;
/// * use at most one WITHIN GROUP sort key.
///
/// # Errors
/// `EvalError::TypeMismatch` describing the first violation found.
fn validate_within_group(agg_specs: &[AggSpec]) -> Result<(), EvalError> {
    let reject = |detail: String| EvalError::TypeMismatch { detail };
    for spec in agg_specs.iter().filter(|s| is_within_group_name(&s.name)) {
        let name = &spec.name;
        if spec.order_by.is_empty() {
            return Err(reject(format!(
                "{name}() requires WITHIN GROUP (ORDER BY …)"
            )));
        }
        let needs_direct_arg = spec.name != "mode";
        if needs_direct_arg && spec.direct_arg.is_none() {
            return Err(reject(format!("{name}() requires a direct argument")));
        }
        if spec.order_by.len() > 1 {
            return Err(reject(format!(
                "{name}() with multiple WITHIN GROUP sort keys is not supported yet"
            )));
        }
    }
    Ok(())
}
/// Scans `rows` once, assigns each row to its group, and folds every aggregate
/// in `agg_specs` into that group's `AggState`s.
///
/// Returns the groups in first-seen order as `(group key values, states)`,
/// where `states[i]` belongs to `agg_specs[i]`. With no rows and no GROUP BY, a
/// single empty group is still emitted, so that e.g. `count(*)` yields one row.
///
/// Each row takes one of two paths:
/// * Fast path: every GROUP BY expression is a qualified column resolved to a
///   position, and none has a case-insensitive collation. Group keys and
///   position-bound arguments / ORDER BY keys are read with `RowRef::get`.
///   `as_row()` is called only when some spec needs general evaluation
///   (`needs_mat`).
/// * General path: the row is obtained with `as_row()` and every expression is
///   evaluated. Text values of case-insensitive GROUP BY keys are
///   ASCII-lowercased for the grouping key only.
///
/// When `correlated_eval` is supplied, aggregate inputs that contain subqueries
/// are evaluated through it.
///
/// # Errors
/// Propagates evaluation errors from GROUP BY, FILTER, argument and ORDER BY
/// expressions, and type errors raised by `update_state`.
#[allow(clippy::too_many_lines, clippy::type_complexity)]
fn accumulate_groups(
    rows: &[RowRef<'_>],
    group_exprs: &[Expr],
    agg_specs: &[AggSpec],
    schema_cols: &[ColumnSchema],
    table_alias: Option<&str>,
    correlated_eval: Option<CorrelatedEval<'_>>,
) -> Result<Vec<(Vec<Value>, Vec<AggState>)>, EvalError> {
    let ctx = EvalContext::new(schema_cols, table_alias);
    // Groups in first-seen order; `groups` maps an encoded group key to an
    // index into `order`.
    let mut order: Vec<(Vec<Value>, Vec<AggState>)> = Vec::new();
    let mut groups: hashbrown::HashMap<String, usize> = hashbrown::HashMap::new();
    // An ungrouped aggregate over zero rows still yields one (empty) group.
    if rows.is_empty() && group_exprs.is_empty() {
        let init: Vec<AggState> = (0..agg_specs.len()).map(|_| AggState::default()).collect();
        order.push((Vec::new(), init));
    }
    // Resolves only *qualified* column references to a fixed position.
    // Anything else returns None and must be evaluated against a full row.
    let col_pos = |e: &Expr| -> Option<usize> {
        if let Expr::Column(c) = e
            && c.qualifier.is_some()
        {
            eval::find_column_pos(c, &ctx)
        } else {
            None
        }
    };
    // Positions of GROUP BY keys, aggregate arguments and aggregate ORDER BY
    // keys that the fast path can read directly from the row.
    let group_pos: Vec<Option<usize>> = group_exprs.iter().map(col_pos).collect();
    let all_groups_bound = group_pos.iter().all(Option::is_some);
    let arg_pos: Vec<Option<usize>> = agg_specs
        .iter()
        .map(|spec| spec.arg.as_ref().and_then(|e| col_pos(e)))
        .collect();
    let order_pos: Vec<Vec<Option<usize>>> = agg_specs
        .iter()
        .map(|spec| spec.order_by.iter().map(|o| col_pos(&o.expr)).collect())
        .collect();
    // On the fast path a full `Row` is needed only if some spec has an input
    // that cannot be read by position: a FILTER, an unbound argument, a second
    // argument, or an unbound ORDER BY key.
    let needs_mat = agg_specs.iter().enumerate().any(|(i, s)| {
        s.filter.is_some()
            || (s.arg.is_some() && arg_pos[i].is_none())
            || s.arg2.is_some()
            || order_pos[i].iter().any(Option::is_none)
    });
    // Indexes of GROUP BY keys with a case-insensitive collation. A non-empty
    // list forces the general path.
    let ci_positions: Vec<usize> = group_exprs
        .iter()
        .enumerate()
        .filter(|(_, g)| {
            matches!(
                eval::column_collation(g, &ctx),
                Some(spg_storage::Collation::CaseInsensitive)
            )
        })
        .map(|(i, _)| i)
        .collect();
    // Scratch buffers reused across rows by the fast path.
    let mut keybuf_s = String::new();
    let mut dkeybuf = String::new();
    let mut refs: Vec<&Value> = Vec::with_capacity(group_pos.len());
    // Only do per-row `expr_has_subquery` checks when a correlated evaluator
    // exists and some aggregate input actually contains a subquery.
    let any_agg_subquery = correlated_eval.is_some()
        && agg_specs.iter().any(|s| {
            s.filter
                .as_ref()
                .is_some_and(|e| crate::expr_has_subquery(e))
                || s.arg.as_ref().is_some_and(|e| crate::expr_has_subquery(e))
                || s.arg2.as_ref().is_some_and(|e| crate::expr_has_subquery(e))
                || s.order_by.iter().any(|o| crate::expr_has_subquery(&o.expr))
        });
    // Evaluates an aggregate input, routing subquery-bearing expressions
    // through `correlated_eval`.
    let eval_arg = |e: &Expr, r: &Row, c: &EvalContext<'_>| -> Result<Value, EvalError> {
        match correlated_eval {
            Some(f) if any_agg_subquery && crate::expr_has_subquery(e) => f(e, r, c),
            _ => eval::eval_expr(e, r, c),
        }
    };
    for row in rows {
        // Fast path. The condition is loop-invariant, so either every row
        // takes it or none does, and both paths never share `groups`.
        if all_groups_bound && ci_positions.is_empty() && !group_exprs.is_empty() {
            refs.clear();
            refs.extend(
                group_pos
                    .iter()
                    // `unwrap` is safe: `all_groups_bound` was checked above.
                    .map(|p| row.get(p.unwrap()).unwrap_or(&Value::Null)),
            );
            encode_key_refs_into(&refs, &mut keybuf_s);
            let idx = match groups.get(keybuf_s.as_str()) {
                Some(&i) => i,
                None => {
                    let i = order.len();
                    let init: Vec<AggState> =
                        (0..agg_specs.len()).map(|_| AggState::default()).collect();
                    let owned: Vec<Value> = refs.iter().map(|v| (*v).clone()).collect();
                    order.push((owned, init));
                    groups.insert(keybuf_s.clone(), i);
                    i
                }
            };
            let entry = &mut order[idx];
            // Materialise the row at most once per row, and only if needed.
            let mat: Option<Cow<'_, Row>> = if needs_mat { Some(row.as_row()) } else { None };
            for (i, spec) in agg_specs.iter().enumerate() {
                // FILTER (WHERE …): any result other than TRUE skips the row.
                if let Some(f) = &spec.filter
                    && !matches!(
                        eval_arg(f, mat.as_deref().expect("needs_mat for FILTER"), &ctx)?,
                        Value::Bool(true)
                    )
                {
                    continue;
                }
                // Argument: read by position when bound. With no argument
                // (e.g. `count_star`) a dummy TRUE is fed. Otherwise evaluate.
                let arg_owned: Value;
                let arg_ref: &Value = match (&arg_pos[i], &spec.arg) {
                    (Some(p), _) => row.get(*p).unwrap_or(&Value::Null),
                    (None, None) => {
                        arg_owned = Value::Bool(true);
                        &arg_owned
                    }
                    (None, Some(e)) => {
                        arg_owned = eval_arg(
                            e,
                            mat.as_deref().expect("needs_mat for non-bound arg"),
                            &ctx,
                        )?;
                        &arg_owned
                    }
                };
                let arg2_val = match &spec.arg2 {
                    None => None,
                    Some(e) => Some(eval_arg(
                        e,
                        mat.as_deref().expect("needs_mat for arg2"),
                        &ctx,
                    )?),
                };
                let order_keys = if spec.order_by.is_empty() {
                    None
                } else {
                    let mut keys = Vec::with_capacity(spec.order_by.len());
                    for (k, o) in spec.order_by.iter().enumerate() {
                        keys.push(match order_pos[i][k] {
                            Some(p) => row.get(p).cloned().unwrap_or(Value::Null),
                            None => eval_arg(
                                &o.expr,
                                mat.as_deref().expect("needs_mat for non-bound ORDER key"),
                                &ctx,
                            )?,
                        });
                    }
                    Some(keys)
                };
                // `array_agg(x ORDER BY …)[1]`: keep only the element whose
                // ORDER BY key sorts first. The comparison is strict, so ties
                // keep the earliest row.
                if spec.first_ordered {
                    if let Some(keys) = order_keys {
                        let st = &mut entry.1[i];
                        let better = match &st.first_best {
                            None => true,
                            Some((bk, _)) => {
                                cmp_order_keys(&spec.order_by, &keys, bk)
                                    == core::cmp::Ordering::Less
                            }
                        };
                        if better {
                            st.first_best = Some((keys, arg_ref.clone()));
                        }
                    }
                    continue;
                }
                // DISTINCT: skip values whose encoding was already folded in.
                if spec.distinct {
                    encode_key_refs_into(core::slice::from_ref(&arg_ref), &mut dkeybuf);
                    if entry.1[i].seen.contains(dkeybuf.as_str()) {
                        continue;
                    }
                    entry.1[i].seen.insert(dkeybuf.clone());
                }
                update_state(
                    &mut entry.1[i],
                    &spec.name,
                    arg_ref,
                    arg2_val.as_ref(),
                    order_keys,
                )?;
            }
            continue;
        }
        // General path: materialise the row and evaluate everything.
        let row_materialised = row.as_row();
        let row: &Row = &row_materialised;
        let group_vals: Vec<Value> = group_exprs
            .iter()
            .map(|g| eval::eval_expr(g, row, &ctx))
            .collect::<Result<_, _>>()?;
        // Case-insensitive keys are lowercased only inside the hash key. The
        // group keeps the values from its first row, with their original case.
        let key = if ci_positions.is_empty() {
            encode_key(&group_vals)
        } else {
            let mut key_vals = group_vals.clone();
            for &i in &ci_positions {
                if let Value::Text(s) = &key_vals[i] {
                    key_vals[i] = Value::Text(s.to_ascii_lowercase());
                }
            }
            encode_key(&key_vals)
        };
        let idx = match groups.get(key.as_str()) {
            Some(&i) => i,
            None => {
                let i = order.len();
                let init: Vec<AggState> =
                    (0..agg_specs.len()).map(|_| AggState::default()).collect();
                order.push((group_vals.clone(), init));
                groups.insert(key, i);
                i
            }
        };
        let entry = &mut order[idx];
        for (i, spec) in agg_specs.iter().enumerate() {
            if let Some(f) = &spec.filter
                && !matches!(eval_arg(f, row, &ctx)?, Value::Bool(true))
            {
                continue;
            }
            // No argument (e.g. `count_star`) feeds a dummy TRUE.
            let arg_val = match &spec.arg {
                None => Value::Bool(true),
                Some(e) => eval_arg(e, row, &ctx)?,
            };
            let arg2_val = match &spec.arg2 {
                None => None,
                Some(e) => Some(eval_arg(e, row, &ctx)?),
            };
            let order_keys = if spec.order_by.is_empty() {
                None
            } else {
                let mut keys = Vec::with_capacity(spec.order_by.len());
                for o in &spec.order_by {
                    keys.push(eval_arg(&o.expr, row, &ctx)?);
                }
                Some(keys)
            };
            // Same first-element tracking as on the fast path.
            if spec.first_ordered {
                if let Some(keys) = order_keys {
                    let st = &mut entry.1[i];
                    let better = match &st.first_best {
                        None => true,
                        Some((bk, _)) => {
                            cmp_order_keys(&spec.order_by, &keys, bk) == core::cmp::Ordering::Less
                        }
                    };
                    if better {
                        st.first_best = Some((keys, arg_val.clone()));
                    }
                }
                continue;
            }
            if spec.distinct {
                let key = encode_key(core::slice::from_ref(&arg_val));
                if !entry.1[i].seen.insert(key) {
                    continue;
                }
            }
            update_state(
                &mut entry.1[i],
                &spec.name,
                &arg_val,
                arg2_val.as_ref(),
                order_keys,
            )?;
        }
    }
    Ok(order)
}
/// Builds the schema of the synthetic per-group rows: one nullable
/// `__grp_{i}` column per GROUP BY expression, followed by one nullable
/// `__agg_{i}` column per aggregate spec.
///
/// A group column's type comes from the first row whose value for that
/// expression reports a data type (`Value::data_type()` returns `Some`; NULLs
/// presumably do not). Columns with no such value, including every column when
/// there are no rows, default to `Text`. Aggregate column types come from
/// `infer_agg_type`.
///
/// # Errors
/// Propagates errors from evaluating GROUP BY expressions on the probed rows.
fn build_synth_schema(
    rows: &[RowRef<'_>],
    group_exprs: &[Expr],
    agg_specs: &[AggSpec],
    schema_cols: &[ColumnSchema],
    table_alias: Option<&str>,
) -> Result<Vec<ColumnSchema>, EvalError> {
    let ctx = EvalContext::new(schema_cols, table_alias);
    // Probe past leading NULLs. Looking only at the first row would type a
    // column as `Text` whenever that row's key is NULL, even if every other
    // key is, say, an integer. Each row is materialised at most once and
    // probing stops as soon as every column has a type.
    let mut probed: Vec<Option<DataType>> = group_exprs.iter().map(|_| None).collect();
    for r in rows {
        if probed.iter().all(Option::is_some) {
            break;
        }
        let row = r.as_row();
        for (slot, g) in probed.iter_mut().zip(group_exprs) {
            if slot.is_none() {
                *slot = eval::eval_expr(g, &row, &ctx)?.data_type();
            }
        }
    }
    let mut synth_schema: Vec<ColumnSchema> =
        Vec::with_capacity(group_exprs.len() + agg_specs.len());
    for (i, ty) in probed.into_iter().enumerate() {
        synth_schema.push(ColumnSchema::new(
            format!("__grp_{i}"),
            ty.unwrap_or(DataType::Text),
            true,
        ));
    }
    for (i, spec) in agg_specs.iter().enumerate() {
        synth_schema.push(ColumnSchema::new(
            format!("__agg_{i}"),
            infer_agg_type(spec, schema_cols),
            true,
        ));
    }
    Ok(synth_schema)
}
/// Compares two ORDER BY key tuples lexicographically.
///
/// Key `k` is compared with the direction and NULL placement of
/// `order_by[k]`, and the first non-equal key decides. Both tuples are
/// expected to have at least `order_by.len()` entries.
fn cmp_order_keys(
    order_by: &[spg_sql::ast::OrderBy],
    a: &[Value],
    b: &[Value],
) -> core::cmp::Ordering {
    order_by
        .iter()
        .enumerate()
        .map(|(k, spec)| crate::order_by_value_cmp(spec.desc, spec.nulls_first, &a[k], &b[k]))
        .find(|ord| *ord != core::cmp::Ordering::Equal)
        .unwrap_or(core::cmp::Ordering::Equal)
}
/// Turns each accumulated group into one synthetic row, laid out as
/// `[group key values…, aggregate results…]` to match `synth_schema`.
///
/// Direct arguments of WITHIN GROUP aggregates are evaluated once, against the
/// first input row; with no input rows they are `None`. For an aggregate with
/// an ORDER BY, the collected items are stably sorted by their recorded keys
/// before finalizing, provided a key was recorded for every item.
///
/// # Errors
/// Propagates errors from evaluating the direct arguments.
fn finalize_synth_rows(
    order: &[(Vec<Value>, Vec<AggState>)],
    agg_specs: &[AggSpec],
    synth_schema: &[ColumnSchema],
    rows: &[RowRef<'_>],
    schema_cols: &[ColumnSchema],
    table_alias: Option<&str>,
) -> Result<Vec<Row>, EvalError> {
    /// Produces the final value of one aggregate for one group.
    fn finish(spec: &AggSpec, st: &AggState, direct: Option<&Value>) -> Value {
        if spec.first_ordered {
            return match &st.first_best {
                Some((_, v)) => v.clone(),
                None => Value::Null,
            };
        }
        let ob = &spec.order_by;
        let sorted_state;
        let st = if ob.is_empty() || st.item_keys.len() != st.items.len() {
            st
        } else {
            let mut perm: Vec<usize> = (0..st.items.len()).collect();
            // Stable sort, so items with equal keys keep insertion order.
            perm.sort_by(|&x, &y| cmp_order_keys(ob, &st.item_keys[x], &st.item_keys[y]));
            let mut copy = st.clone();
            copy.items = perm.iter().map(|&j| st.items[j].clone()).collect();
            sorted_state = copy;
            &sorted_state
        };
        if is_within_group_name(&spec.name) {
            finalize_ordered_set(&spec.name, st, direct, spec.order_by.first())
        } else {
            finalize(&spec.name, st)
        }
    }

    let ctx = EvalContext::new(schema_cols, table_alias);
    let first_row = rows.first();
    let mut direct_args: Vec<Option<Value>> = Vec::with_capacity(agg_specs.len());
    for spec in agg_specs {
        direct_args.push(match (&spec.direct_arg, first_row) {
            (Some(expr), Some(row)) => Some(eval::eval_expr(expr, &row.as_row(), &ctx)?),
            _ => None,
        });
    }
    let synth_rows: Vec<Row> = order
        .iter()
        .map(|(group_vals, states)| {
            let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
            values.extend(group_vals.iter().cloned());
            for (i, st) in states.iter().enumerate() {
                values.push(finish(&agg_specs[i], st, direct_args[i].as_ref()));
            }
            Row::new(values)
        })
        .collect();
    Ok(synth_rows)
}
/// Applies HAVING and evaluates the select list for every finalized group.
///
/// Select items, HAVING and ORDER BY are first rewritten with `rewrite_expr`,
/// so they read group keys and aggregate results from `synth_schema` columns.
/// Expressions accepted by `eval::fully_compilable` are compiled once and run
/// with `eval::eval_compiled`. The rest are interpreted, with subquery-bearing
/// expressions routed through `correlated_eval` when one is present.
///
/// Deferral happens only when all of these hold:
/// * `correlated_eval` is set;
/// * the query is not DISTINCT;
/// * neither HAVING nor ORDER BY contains a subquery.
///
/// Then select items that contain a subquery are not evaluated here. Their
/// output slot gets `Value::Null`, and the `(item index, rewritten expression)`
/// pair is returned in `deferred` so the caller can evaluate it later against
/// the kept synthetic rows.
///
/// # Errors
/// `EvalError::TypeMismatch` for `SELECT *`. Otherwise propagates evaluation
/// errors from HAVING and the select items.
#[allow(clippy::too_many_lines)]
fn project_groups(
    synth_rows: Vec<Row>,
    stmt: &SelectStatement,
    group_exprs: &[Expr],
    agg_specs: &[AggSpec],
    synth_schema: &[ColumnSchema],
    correlated_eval: Option<CorrelatedEval<'_>>,
) -> Result<Projection, EvalError> {
    // Output schemas: named by alias or expression text; typed from the
    // rewritten expression.
    let columns: Vec<ColumnSchema> = stmt
        .items
        .iter()
        .map(|item| match item {
            SelectItem::Wildcard => Err(EvalError::TypeMismatch {
                detail: "SELECT * with aggregates is not supported".into(),
            }),
            SelectItem::Expr { expr, alias } => {
                let rewritten = rewrite_expr(expr, group_exprs, agg_specs);
                let name = alias.clone().unwrap_or_else(|| expr.to_string());
                Ok(ColumnSchema::new(
                    name,
                    agg_or_group_type(&rewritten, synth_schema),
                    true,
                ))
            }
        })
        .collect::<Result<_, _>>()?;
    // Everything below is evaluated against the synthetic group rows.
    let synth_ctx = EvalContext::new(synth_schema, None);
    let having_rewritten = stmt
        .having
        .as_ref()
        .map(|h| rewrite_expr(h, group_exprs, agg_specs));
    let items_rewritten: alloc::vec::Vec<Option<Expr>> = stmt
        .items
        .iter()
        .map(|item| match item {
            SelectItem::Expr { expr, .. } => Some(rewrite_expr(expr, group_exprs, agg_specs)),
            SelectItem::Wildcard => None,
        })
        .collect();
    let order_rewritten: Vec<Expr> = stmt
        .order_by
        .iter()
        .map(|o| rewrite_expr(&o.expr, group_exprs, agg_specs))
        .collect();
    // Deferral is disabled whenever HAVING or ORDER BY needs a subquery, or
    // the query is DISTINCT (presumably because DISTINCT compares complete
    // output rows; confirm with the caller).
    let defer_enabled = correlated_eval.is_some()
        && !stmt.distinct
        && !having_rewritten
            .as_ref()
            .is_some_and(crate::expr_has_subquery)
        && !order_rewritten.iter().any(crate::expr_has_subquery);
    let deferred: Vec<(usize, Expr)> = if defer_enabled {
        items_rewritten
            .iter()
            .enumerate()
            .filter_map(|(i, r)| {
                r.as_ref()
                    .filter(|e| crate::expr_has_subquery(e))
                    .map(|e| (i, e.clone()))
            })
            .collect()
    } else {
        Vec::new()
    };
    // Compile HAVING and each non-deferred item once when possible; `None`
    // means "interpret".
    let having_compiled = having_rewritten
        .as_ref()
        .filter(|h| eval::fully_compilable(h))
        .map(|h| eval::compile_expr(h, &synth_ctx));
    let items_compiled: Vec<Option<eval::CompiledExpr>> = items_rewritten
        .iter()
        .enumerate()
        .map(|(i, r)| {
            r.as_ref()
                .filter(|e| !deferred.iter().any(|(c, _)| *c == i) && eval::fully_compilable(e))
                .map(|e| eval::compile_expr(e, &synth_ctx))
        })
        .collect();
    let mut kept_synth: Vec<Row> = Vec::new();
    let mut out_rows: Vec<Row> = Vec::new();
    // Evaluation stack shared by all compiled-expression runs.
    let mut stack: Vec<Value> = Vec::new();
    for srow in synth_rows {
        // HAVING: anything other than TRUE drops the group.
        if let Some(hc) = &having_compiled {
            let cond = eval::eval_compiled(hc, &srow, &synth_ctx, &mut stack)?;
            if !matches!(cond, Value::Bool(true)) {
                continue;
            }
        } else if let Some(h) = &having_rewritten {
            let cond = match correlated_eval {
                Some(f) if crate::expr_has_subquery(h) => f(h, &srow, &synth_ctx)?,
                _ => eval::eval_expr(h, &srow, &synth_ctx)?,
            };
            if !matches!(cond, Value::Bool(true)) {
                continue;
            }
        }
        let mut values: Vec<Value> = Vec::with_capacity(columns.len());
        for (i, rewritten) in items_rewritten.iter().enumerate() {
            let Some(rewritten) = rewritten else { continue };
            // Deferred items get a NULL placeholder for the caller to fill in.
            if deferred.iter().any(|(c, _)| *c == i) {
                values.push(Value::Null);
                continue;
            }
            values.push(if let Some(cc) = &items_compiled[i] {
                eval::eval_compiled(cc, &srow, &synth_ctx, &mut stack)?
            } else {
                match correlated_eval {
                    Some(f) if crate::expr_has_subquery(rewritten) => {
                        f(rewritten, &srow, &synth_ctx)?
                    }
                    _ => eval::eval_expr(rewritten, &srow, &synth_ctx)?,
                }
            });
        }
        // Keep the synthetic row index-aligned with its projected row.
        kept_synth.push(srow);
        out_rows.push(Row::new(values));
    }
    Ok(Projection {
        columns,
        out_rows,
        kept_synth,
        deferred,
        order_rewritten,
    })
}
/// Sorts the surviving groups by the statement's ORDER BY, keeping
/// `kept_synth[i]` and `out_rows[i]` paired.
///
/// Sort keys are the rewritten ORDER BY expressions, evaluated once per group
/// against its synthetic row:
/// * compiled when `eval::fully_compilable` allows it;
/// * otherwise interpreted, with subquery-bearing keys routed through
///   `correlated_eval`.
///
/// The sort is stable, so groups with equal keys keep their current order.
///
/// # Errors
/// Propagates errors from evaluating the sort keys.
fn sort_synth_by_order_by(
    synth_schema: &[ColumnSchema],
    order_by: &[spg_sql::ast::OrderBy],
    order_rewritten: &[Expr],
    kept_synth: Vec<Row>,
    out_rows: Vec<Row>,
    correlated_eval: Option<CorrelatedEval<'_>>,
) -> Result<(Vec<Row>, Vec<Row>), EvalError> {
    let ctx = EvalContext::new(synth_schema, None);
    let compiled: Vec<Option<eval::CompiledExpr>> = order_rewritten
        .iter()
        .map(|e| eval::fully_compilable(e).then(|| eval::compile_expr(e, &ctx)))
        .collect();
    let mut scratch: Vec<Value> = Vec::new();
    let mut keyed: Vec<(Vec<Value>, Row, Row)> = Vec::with_capacity(kept_synth.len());
    for (synth, out) in kept_synth.into_iter().zip(out_rows) {
        let keys = order_rewritten
            .iter()
            .zip(&compiled)
            .map(|(expr, comp)| match (comp, correlated_eval) {
                (Some(c), _) => eval::eval_compiled(c, &synth, &ctx, &mut scratch),
                (None, Some(f)) if crate::expr_has_subquery(expr) => f(expr, &synth, &ctx),
                (None, _) => eval::eval_expr(expr, &synth, &ctx),
            })
            .collect::<Result<Vec<Value>, EvalError>>()?;
        keyed.push((keys, synth, out));
    }
    keyed.sort_by(|(lhs, ..), (rhs, ..)| {
        lhs.iter()
            .zip(rhs)
            .enumerate()
            .map(|(i, (a, b))| {
                let spec = &order_by[i];
                crate::order_by_value_cmp(spec.desc, spec.nulls_first, a, b)
            })
            .find(|ord| ord.is_ne())
            .unwrap_or(core::cmp::Ordering::Equal)
    });
    Ok(keyed.into_iter().map(|(_, s, o)| (s, o)).unzip())
}
/// Rejects aggregate calls whose argument count does not match the aggregate's
/// fixed arity (e.g. `sum(a, b)` or `string_agg(x)`).
///
/// The select list, ORDER BY and HAVING are walked in full. This includes
/// aggregates nested inside:
/// * CASE, IN lists and LIKE;
/// * array constructors and subscripts;
/// * ANY/ALL and EXTRACT;
/// * `AggregateOrdered` wrappers (DISTINCT / ORDER BY / FILTER / WITHIN GROUP),
///   whose call, sort keys and filter are all checked.
///
/// Subqueries and window functions are not entered. Aggregates without a
/// fixed arity here (the WITHIN GROUP families) are not checked.
///
/// `_specs` is currently unused; it is kept so the signature stays stable.
///
/// # Errors
/// `EvalError::TypeMismatch` naming the function and the expected and actual
/// argument counts.
fn validate_agg_arities(stmt: &SelectStatement, _specs: &[AggSpec]) -> Result<(), EvalError> {
    fn walk(e: &Expr) -> Result<(), EvalError> {
        match e {
            Expr::FunctionCall { name, args } => {
                let lower = name.to_ascii_lowercase();
                let expected: Option<usize> = match lower.as_str() {
                    "count_star" => Some(0),
                    "count" | "sum" | "avg" | "min" | "max" | "array_agg" | "bool_and"
                    | "bool_or" | "every" | "stddev" | "stddev_samp" | "stddev_pop"
                    | "variance" | "var_samp" | "var_pop" | "bit_and" | "bit_or" | "bit_xor"
                    | "json_agg" | "jsonb_agg" => Some(1),
                    "string_agg" | "covar_pop" | "covar_samp" | "corr" | "regr_count"
                    | "regr_avgx" | "regr_avgy" | "regr_slope" | "regr_intercept"
                    | "regr_r2" | "regr_sxx" | "regr_syy" | "regr_sxy" | "json_object_agg"
                    | "jsonb_object_agg" => Some(2),
                    _ => None,
                };
                if let Some(want) = expected
                    && args.len() != want
                {
                    return Err(EvalError::TypeMismatch {
                        detail: alloc::format!("{lower}() takes {want} arg(s), got {}", args.len()),
                    });
                }
                args.iter().try_for_each(walk)
            }
            // Previously skipped entirely, so e.g. `string_agg(x ORDER BY y)` or
            // `count(a, b) FILTER (…)` escaped the arity check.
            Expr::AggregateOrdered {
                call,
                order_by,
                filter,
                ..
            } => {
                walk(call)?;
                for o in order_by {
                    walk(&o.expr)?;
                }
                filter.as_deref().map_or(Ok(()), walk)
            }
            Expr::Binary { lhs, rhs, .. } => {
                walk(lhs)?;
                walk(rhs)
            }
            Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
                walk(expr)
            }
            Expr::Like { expr, pattern, .. } => {
                walk(expr)?;
                walk(pattern)
            }
            Expr::Extract { source, .. } => walk(source),
            Expr::Array(items) => items.iter().try_for_each(walk),
            Expr::ArraySubscript { target, index } => {
                walk(target)?;
                walk(index)
            }
            Expr::AnyAll { expr, array, .. } => {
                walk(expr)?;
                walk(array)
            }
            Expr::InList { expr, list, .. } => {
                walk(expr)?;
                list.iter().try_for_each(walk)
            }
            Expr::Case {
                operand,
                branches,
                else_branch,
            } => {
                if let Some(o) = operand {
                    walk(o)?;
                }
                for (w, t) in branches {
                    walk(w)?;
                    walk(t)?;
                }
                if let Some(e) = else_branch {
                    walk(e)?;
                }
                Ok(())
            }
            // Subqueries and window functions have their own scope/validation.
            Expr::ScalarSubquery(_)
            | Expr::Exists { .. }
            | Expr::InSubquery { .. }
            | Expr::WindowFunction { .. }
            | Expr::Literal(_)
            | Expr::Placeholder(_)
            | Expr::Column(_) => Ok(()),
        }
    }
    for item in &stmt.items {
        if let SelectItem::Expr { expr, .. } = item {
            walk(expr)?;
        }
    }
    for o in &stmt.order_by {
        walk(&o.expr)?;
    }
    if let Some(h) = &stmt.having {
        walk(h)?;
    }
    Ok(())
}
/// Recognises the pattern `array_agg(x ORDER BY …)[1]`, i.e. "the first element
/// under the given ordering", so it can be computed without collecting the
/// whole array.
///
/// The pattern requires a non-DISTINCT `array_agg` with exactly one argument
/// and a non-empty ORDER BY, subscripted by the integer literal `1`. On a match
/// it returns `(argument, ORDER BY list, FILTER predicate)`.
fn first_ordered_array_agg(e: &Expr) -> Option<(&Expr, &[spg_sql::ast::OrderBy], Option<&Expr>)> {
    let Expr::ArraySubscript { target, index } = e else {
        return None;
    };
    let Expr::Literal(spg_sql::ast::Literal::Integer(1)) = index.as_ref() else {
        return None;
    };
    let Expr::AggregateOrdered {
        call,
        order_by,
        distinct,
        filter,
    } = target.as_ref()
    else {
        return None;
    };
    if *distinct || order_by.is_empty() {
        return None;
    }
    let Expr::FunctionCall { name, args } = call.as_ref() else {
        return None;
    };
    if args.len() == 1 && name.eq_ignore_ascii_case("array_agg") {
        Some((&args[0], order_by, filter.as_deref()))
    } else {
        None
    }
}
/// Walks `e` and appends one `AggSpec` to `out` for each aggregate call not
/// already present there.
///
/// * `AggregateOrdered` (an aggregate with DISTINCT / ORDER BY / FILTER /
///   WITHIN GROUP) whose call is an aggregate: the name is lowercased and
///   `every` is stored as `bool_and`. For WITHIN GROUP aggregates the first
///   sort expression becomes `arg`, and the call's first argument becomes
///   `direct_arg`.
/// * Plain aggregate `FunctionCall`: `count_star` gets no argument.
/// * `array_agg(x ORDER BY …)[1]` becomes a `first_ordered` spec that only
///   tracks the best element.
///
/// An equal spec already in `out` is reused rather than duplicated, so
/// identical aggregate calls share one accumulator. The arguments of a
/// recognised aggregate are not searched for nested aggregates, and
/// subqueries and window functions are not entered.
fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
    match e {
        Expr::AggregateOrdered {
            call,
            order_by,
            distinct,
            filter,
        } => {
            if let Expr::FunctionCall { name, args } = call.as_ref() {
                let lower = name.to_ascii_lowercase();
                if is_aggregate_name(&lower) {
                    let canonical = if lower == "every" {
                        "bool_and".to_string()
                    } else {
                        lower
                    };
                    // WITHIN GROUP aggregates aggregate over the sort
                    // expression; their call arguments are "direct" arguments.
                    let ordered_set = is_within_group_name(&canonical);
                    let (arg, direct_arg) = if ordered_set {
                        (
                            order_by.first().map(|o| o.expr.clone()),
                            args.first().cloned(),
                        )
                    } else {
                        (args.first().cloned(), None)
                    };
                    let spec = AggSpec {
                        name: canonical.clone(),
                        arg,
                        arg2: if agg_uses_second_arg(&canonical) {
                            args.get(1).cloned()
                        } else {
                            None
                        },
                        distinct: *distinct,
                        order_by: order_by.clone(),
                        filter: filter.as_deref().cloned(),
                        direct_arg,
                        first_ordered: false,
                    };
                    // Dedup on every defining field.
                    if !out.iter().any(|s| {
                        s.name == spec.name
                            && s.arg == spec.arg
                            && s.arg2 == spec.arg2
                            && s.distinct == spec.distinct
                            && s.order_by == spec.order_by
                            && s.filter == spec.filter
                            && s.direct_arg == spec.direct_arg
                            && s.first_ordered == spec.first_ordered
                    }) {
                        out.push(spec);
                    }
                    return;
                }
            }
            // Not an aggregate call: search the call and its sort keys.
            collect_aggregates(call, out);
            for o in order_by {
                collect_aggregates(&o.expr, out);
            }
        }
        Expr::FunctionCall { name, args } => {
            let lower = name.to_ascii_lowercase();
            if is_aggregate_name(&lower) {
                let arg = if lower == "count_star" {
                    None
                } else {
                    args.first().cloned()
                };
                let arg2 = if agg_uses_second_arg(&lower) {
                    args.get(1).cloned()
                } else {
                    None
                };
                let canonical = if lower == "every" {
                    "bool_and".to_string()
                } else {
                    lower
                };
                let spec = AggSpec {
                    name: canonical,
                    arg: arg.clone(),
                    arg2: arg2.clone(),
                    distinct: false,
                    order_by: Vec::new(),
                    filter: None,
                    direct_arg: None,
                    first_ordered: false,
                };
                // Dedup against plain (no DISTINCT / ORDER BY / FILTER) specs.
                // `direct_arg` is not compared here.
                if !out.iter().any(|s| {
                    s.name == spec.name
                        && s.arg == spec.arg
                        && s.arg2 == spec.arg2
                        && !s.distinct
                        && s.order_by == spec.order_by
                        && s.filter.is_none()
                        && !s.first_ordered
                }) {
                    out.push(spec);
                }
            } else {
                for a in args {
                    collect_aggregates(a, out);
                }
            }
        }
        Expr::Binary { lhs, rhs, .. } => {
            collect_aggregates(lhs, out);
            collect_aggregates(rhs, out);
        }
        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
            collect_aggregates(expr, out);
        }
        Expr::Like { expr, pattern, .. } => {
            collect_aggregates(expr, out);
            collect_aggregates(pattern, out);
        }
        Expr::InList { expr, list, .. } => {
            collect_aggregates(expr, out);
            for item in list {
                collect_aggregates(item, out);
            }
        }
        Expr::Extract { source, .. } => collect_aggregates(source, out),
        // Subqueries and window functions are handled elsewhere; leaves have
        // nothing to collect.
        Expr::ScalarSubquery(_)
        | Expr::Exists { .. }
        | Expr::InSubquery { .. }
        | Expr::WindowFunction { .. }
        | Expr::Literal(_)
        | Expr::Placeholder(_)
        | Expr::Column(_) => {}
        Expr::Array(items) => {
            for elem in items {
                collect_aggregates(elem, out);
            }
        }
        Expr::ArraySubscript { target, index } => {
            // `array_agg(x ORDER BY …)[1]` gets the specialised first-element spec.
            if let Some((arg, order_by, filter)) = first_ordered_array_agg(e) {
                let spec = AggSpec {
                    name: "array_agg".to_string(),
                    arg: Some(arg.clone()),
                    arg2: None,
                    distinct: false,
                    order_by: order_by.to_vec(),
                    filter: filter.cloned(),
                    direct_arg: None,
                    first_ordered: true,
                };
                if !out.iter().any(|s| {
                    s.name == spec.name
                        && s.arg == spec.arg
                        && s.order_by == spec.order_by
                        && s.filter == spec.filter
                        && s.first_ordered
                }) {
                    out.push(spec);
                }
                return;
            }
            collect_aggregates(target, out);
            collect_aggregates(index, out);
        }
        Expr::AnyAll { expr, array, .. } => {
            collect_aggregates(expr, out);
            collect_aggregates(array, out);
        }
        Expr::Case {
            operand,
            branches,
            else_branch,
        } => {
            if let Some(o) = operand {
                collect_aggregates(o, out);
            }
            for (w, t) in branches {
                collect_aggregates(w, out);
                collect_aggregates(t, out);
            }
            if let Some(e) = else_branch {
                collect_aggregates(e, out);
            }
        }
    }
}
fn update_state(
st: &mut AggState,
name: &str,
v: &Value,
arg2: Option<&Value>,
order_keys: Option<Vec<Value>>,
) -> Result<(), EvalError> {
let is_null = matches!(v, Value::Null);
match name {
"count_star" => st.count += 1,
"count" => {
if !is_null {
st.count += 1;
}
}
"sum" | "avg" => {
if is_null {
return Ok(());
}
st.count += 1;
match v {
Value::Int(n) => st.sum_int += i64::from(*n),
Value::BigInt(n) => st.sum_int += *n,
Value::Float(x) => {
st.use_float = true;
st.sum_float += *x;
}
other => {
return Err(EvalError::TypeMismatch {
detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
});
}
}
}
"min" => {
if is_null {
return Ok(());
}
match &st.extreme {
None => st.extreme = Some(v.clone()),
Some(cur) => {
if value_cmp(v, cur) == core::cmp::Ordering::Less {
st.extreme = Some(v.clone());
}
}
}
}
"max" => {
if is_null {
return Ok(());
}
match &st.extreme {
None => st.extreme = Some(v.clone()),
Some(cur) => {
if value_cmp(v, cur) == core::cmp::Ordering::Greater {
st.extreme = Some(v.clone());
}
}
}
}
"string_agg" => {
if let Some(sep) = arg2
&& let Value::Text(s) = sep
{
st.separator = Some(s.clone());
}
if is_null {
return Ok(());
}
if let Value::Text(s) = v {
st.items.push(Value::Text(s.clone()));
if let Some(k) = order_keys {
st.item_keys.push(k);
}
st.count += 1;
} else {
return Err(EvalError::TypeMismatch {
detail: format!("string_agg requires text value, got {:?}", v.data_type()),
});
}
}
"array_agg" => {
st.items.push(v.clone());
if let Some(k) = order_keys {
st.item_keys.push(k);
}
st.count += 1;
}
"bool_and" => {
if is_null {
return Ok(());
}
let b = match v {
Value::Bool(b) => *b,
other => {
return Err(EvalError::TypeMismatch {
detail: format!("bool_and requires bool, got {:?}", other.data_type()),
});
}
};
st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc && b));
}
"bool_or" => {
if is_null {
return Ok(());
}
let b = match v {
Value::Bool(b) => *b,
other => {
return Err(EvalError::TypeMismatch {
detail: format!("bool_or requires bool, got {:?}", other.data_type()),
});
}
};
st.bool_acc = Some(st.bool_acc.map_or(b, |acc| acc || b));
}
"stddev" | "stddev_samp" | "stddev_pop" | "variance" | "var_samp" | "var_pop" => {
if is_null {
return Ok(());
}
let x = match v {
Value::Int(n) => f64::from(*n),
Value::SmallInt(n) => f64::from(*n),
Value::BigInt(n) => *n as f64,
Value::Float(x) => *x,
other => {
return Err(EvalError::TypeMismatch {
detail: format!("{name} needs numeric, got {:?}", other.data_type()),
});
}
};
st.count += 1;
st.sum_float += x;
st.sum_sq += x * x;
}
"bit_and" | "bit_or" | "bit_xor" => {
if is_null {
return Ok(());
}
let n = match v {
Value::Int(n) => i64::from(*n),
Value::SmallInt(n) => i64::from(*n),
Value::BigInt(n) => *n,
other => {
return Err(EvalError::TypeMismatch {
detail: format!("{name} needs integer, got {:?}", other.data_type()),
});
}
};
st.bit_acc = Some(match (st.bit_acc, name) {
(None, _) => n,
(Some(acc), "bit_and") => acc & n,
(Some(acc), "bit_or") => acc | n,
(Some(acc), _) => acc ^ n, });
}
n if is_within_group_name(n) => {
if is_null {
return Ok(());
}
st.items.push(v.clone());
if let Some(k) = order_keys {
st.item_keys.push(k);
}
st.count += 1;
}
n if is_regression_name(n) => {
let (Some(y), Some(x)) = (agg_value_to_f64(v), arg2.and_then(agg_value_to_f64)) else {
return Ok(()); };
st.reg_n += 1;
st.reg_sx += x;
st.reg_sy += y;
st.reg_sxx += x * x;
st.reg_syy += y * y;
st.reg_sxy += x * y;
}
"json_agg" | "jsonb_agg" => {
st.items.push(v.clone());
st.count += 1;
}
"json_object_agg" | "jsonb_object_agg" => {
if is_null {
return Ok(());
}
st.items.push(v.clone());
st.aux_items.push(arg2.cloned().unwrap_or(Value::Null));
st.count += 1;
}
_ => unreachable!("non-aggregate {name} in update_state"),
}
Ok(())
}
/// Produces the final value of a plain (non ordered-set) aggregate from its
/// accumulated [`AggState`].
///
/// `name` must be one of the aggregate names handled by the `match` below (the
/// canonical lower-case spelling); any other name reaches `unreachable!()`.
///
/// Empty-input results visible here: `count`/`count_star` and `regr_count`
/// return their counters as-is; `bool_*` and `bit_*` return NULL when no
/// accumulator was ever set; every other aggregate returns NULL when nothing
/// was accumulated.
#[allow(clippy::cast_precision_loss)]
fn finalize(name: &str, st: &AggState) -> Value {
    match name {
        "count" | "count_star" => Value::BigInt(st.count),
        "sum" => {
            if st.count == 0 {
                Value::Null
            } else if st.use_float {
                // Mixed/float input: combine the float partial sum with the
                // exact integer partial sum.
                Value::Float(st.sum_float + (st.sum_int as f64))
            } else {
                Value::BigInt(st.sum_int)
            }
        }
        "avg" => {
            if st.count == 0 {
                Value::Null
            } else {
                let total = if st.use_float {
                    st.sum_float + (st.sum_int as f64)
                } else {
                    st.sum_int as f64
                };
                Value::Float(total / (st.count as f64))
            }
        }
        "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
        "string_agg" => {
            if st.items.is_empty() {
                return Value::Null;
            }
            let sep = st.separator.clone().unwrap_or_default();
            let mut out = String::new();
            for (i, item) in st.items.iter().enumerate() {
                if i > 0 {
                    out.push_str(&sep);
                }
                // NOTE(review): non-Text items contribute nothing, yet the
                // separator before them is still emitted — confirm that
                // update_state only ever stores Text here.
                if let Value::Text(s) = item {
                    out.push_str(s);
                }
            }
            Value::Text(out)
        }
        "array_agg" => {
            if st.items.is_empty() {
                return Value::Null;
            }
            // The array's element type is taken from the first non-NULL item.
            // Items of a different variant become NULL elements in the typed
            // integer arrays.
            let probe = st.items.iter().find(|v| !v.is_null());
            match probe.and_then(spg_storage::Value::data_type) {
                Some(DataType::Int) | Some(DataType::SmallInt) => {
                    let items: Vec<Option<i32>> = st
                        .items
                        .iter()
                        .map(|v| match v {
                            Value::Int(n) => Some(*n),
                            Value::SmallInt(n) => Some(i32::from(*n)),
                            _ => None,
                        })
                        .collect();
                    Value::IntArray(items)
                }
                Some(DataType::BigInt) => {
                    let items: Vec<Option<i64>> = st
                        .items
                        .iter()
                        .map(|v| match v {
                            Value::BigInt(n) => Some(*n),
                            _ => None,
                        })
                        .collect();
                    Value::BigIntArray(items)
                }
                _ => {
                    // NOTE(review): non-text values are rendered with `{:?}`,
                    // i.e. the Rust Debug form of `Value` (e.g. `Float(1.5)`),
                    // not an SQL text representation — confirm this is intended.
                    let items: Vec<Option<String>> = st
                        .items
                        .iter()
                        .map(|v| match v {
                            Value::Text(s) => Some(s.clone()),
                            Value::Null => None,
                            other => Some(format!("{other:?}")),
                        })
                        .collect();
                    Value::TextArray(items)
                }
            }
        }
        "bool_and" | "bool_or" => st.bool_acc.map_or(Value::Null, Value::Bool),
        "variance" | "var_samp" | "var_pop" | "stddev" | "stddev_samp" | "stddev_pop" => {
            let n = st.count;
            if n == 0 {
                return Value::Null;
            }
            let nf = n as f64;
            // Centered sum of squares from the running sums: Σx² − (Σx)²/n.
            let ss = st.sum_sq - (st.sum_float * st.sum_float) / nf;
            // `*_pop` divides by n; the sample variants (including the bare
            // `variance`/`stddev` names) divide by n − 1.
            let pop = name.ends_with("_pop");
            let denom = if pop { nf } else { nf - 1.0 };
            if denom <= 0.0 {
                // Sample variance of a single row is undefined.
                return Value::Null;
            }
            // Clamp tiny negative results caused by floating-point rounding.
            let var = (ss / denom).max(0.0);
            if name.starts_with("stddev") {
                Value::Float(crate::eval::f64_sqrt(var))
            } else {
                Value::Float(var)
            }
        }
        "bit_and" | "bit_or" | "bit_xor" => st.bit_acc.map_or(Value::Null, Value::BigInt),
        "regr_count" => Value::BigInt(st.reg_n),
        "covar_pop" | "covar_samp" | "corr" | "regr_avgx" | "regr_avgy" | "regr_slope"
        | "regr_intercept" | "regr_r2" | "regr_sxx" | "regr_syy" | "regr_sxy" => {
            // Two-argument aggregates over (y, x) pairs; update_state only
            // accumulates pairs where both sides are numeric.
            let n = st.reg_n;
            if n == 0 {
                return Value::Null;
            }
            let nf = n as f64;
            // Centered sums of squares / cross-products.
            let sxx = st.reg_sxx - st.reg_sx * st.reg_sx / nf;
            let syy = st.reg_syy - st.reg_sy * st.reg_sy / nf;
            let sxy = st.reg_sxy - st.reg_sx * st.reg_sy / nf;
            let avgx = st.reg_sx / nf;
            let avgy = st.reg_sy / nf;
            let out = match name {
                "regr_avgx" => Some(avgx),
                "regr_avgy" => Some(avgy),
                "regr_sxx" => Some(sxx),
                "regr_syy" => Some(syy),
                "regr_sxy" => Some(sxy),
                "covar_pop" => Some(sxy / nf),
                // Sample covariance needs at least two pairs.
                "covar_samp" => (n >= 2).then(|| sxy / (nf - 1.0)),
                // Slope/intercept are undefined when all x values are equal.
                "regr_slope" => (sxx != 0.0).then(|| sxy / sxx),
                "regr_intercept" => (sxx != 0.0).then(|| avgy - (sxy / sxx) * avgx),
                "corr" => {
                    let d = sxx * syy;
                    (d > 0.0).then(|| sxy / crate::eval::f64_sqrt(d))
                }
                "regr_r2" => {
                    // NULL for constant x, 1 for constant y, else sxy²/(sxx·syy).
                    if sxx == 0.0 {
                        None
                    } else if syy == 0.0 {
                        Some(1.0)
                    } else {
                        Some((sxy * sxy) / (sxx * syy))
                    }
                }
                _ => None,
            };
            out.map_or(Value::Null, Value::Float)
        }
        "json_agg" | "jsonb_agg" => {
            if st.items.is_empty() {
                return Value::Null;
            }
            // Every item, NULLs included, becomes one JSON array element.
            let mut out = String::from("[");
            for (i, item) in st.items.iter().enumerate() {
                if i > 0 {
                    out.push_str(", ");
                }
                out.push_str(&crate::json::value_to_json_text(item));
            }
            out.push(']');
            Value::Json(out)
        }
        "json_object_agg" | "jsonb_object_agg" => {
            if st.items.is_empty() {
                return Value::Null;
            }
            // Keys live in `items`; the matching values live at the same index
            // in `aux_items`.
            let mut out = String::from("{");
            for (i, key) in st.items.iter().enumerate() {
                if i > 0 {
                    out.push_str(", ");
                }
                // Keys are always emitted as JSON strings: text/json keys are
                // used verbatim, other values via their JSON text.
                let key_text = match key {
                    Value::Text(s) | Value::Json(s) => s.clone(),
                    other => crate::json::value_to_json_text(other),
                };
                out.push_str(&crate::json::value_to_json_text(&Value::Text(key_text)));
                out.push_str(": ");
                let val = st.aux_items.get(i).unwrap_or(&Value::Null);
                out.push_str(&crate::json::value_to_json_text(val));
            }
            out.push('}');
            Value::Json(out)
        }
        _ => unreachable!(),
    }
}
/// Widens a numeric `Value` (`SmallInt`, `Int`, `BigInt`, `Float`) to `f64`.
///
/// Returns `None` for NULL and for every non-numeric variant, so callers can
/// skip inputs that do not participate in floating-point aggregates.
#[allow(clippy::cast_precision_loss)]
fn agg_value_to_f64(v: &Value) -> Option<f64> {
    let widened = match *v {
        Value::Float(x) => x,
        Value::BigInt(n) => n as f64,
        Value::Int(n) => f64::from(n),
        Value::SmallInt(n) => f64::from(n),
        _ => return None,
    };
    Some(widened)
}
#[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
fn finalize_ordered_set(
name: &str,
st: &AggState,
direct: Option<&Value>,
order: Option<&spg_sql::ast::OrderBy>,
) -> Value {
let fraction = direct;
let items = &st.items;
if items.is_empty() {
return match name {
"rank" | "dense_rank" => Value::BigInt(1),
"percent_rank" => Value::Float(0.0),
"cume_dist" => Value::Float(1.0),
_ => Value::Null,
};
}
let n = items.len();
match name {
"rank" | "dense_rank" | "percent_rank" | "cume_dist" => {
let Some(h) = fraction else {
return Value::Null;
};
let (desc, nulls_first) = order.map_or((false, None), |o| (o.desc, o.nulls_first));
let mut before = 0usize; let mut before_or_eq = 0usize; let mut distinct_before = 0usize;
let mut last_before: Option<&Value> = None;
for it in items {
match crate::order_by_value_cmp(desc, nulls_first, it, h) {
core::cmp::Ordering::Less => {
before += 1;
before_or_eq += 1;
if last_before
.is_none_or(|p| value_cmp(p, it) != core::cmp::Ordering::Equal)
{
distinct_before += 1;
last_before = Some(it);
}
}
core::cmp::Ordering::Equal => before_or_eq += 1,
core::cmp::Ordering::Greater => {}
}
}
let nn = n as f64;
match name {
"rank" => Value::BigInt((before + 1) as i64),
"dense_rank" => Value::BigInt((distinct_before + 1) as i64),
"percent_rank" => Value::Float(before as f64 / nn),
"cume_dist" => Value::Float((before_or_eq as f64 + 1.0) / (nn + 1.0)),
_ => unreachable!(),
}
}
"mode" => {
let (mut best_i, mut best_cnt) = (0usize, 1usize);
let (mut run_i, mut run_cnt) = (0usize, 1usize);
for i in 1..n {
if value_cmp(&items[i], &items[run_i]) == core::cmp::Ordering::Equal {
run_cnt += 1;
} else {
run_i = i;
run_cnt = 1;
}
if run_cnt > best_cnt {
best_cnt = run_cnt;
best_i = run_i;
}
}
items[best_i].clone()
}
"percentile_disc" => {
let f = fraction
.and_then(agg_value_to_f64)
.unwrap_or(0.0)
.clamp(0.0, 1.0);
let idx = if f <= 0.0 {
0
} else {
(crate::eval::f64_ceil(f * n as f64) as usize)
.saturating_sub(1)
.min(n - 1)
};
items[idx].clone()
}
"percentile_cont" => {
let f = fraction
.and_then(agg_value_to_f64)
.unwrap_or(0.0)
.clamp(0.0, 1.0);
let Some(nums) = items
.iter()
.map(agg_value_to_f64)
.collect::<Option<Vec<f64>>>()
else {
return Value::Null; };
if n == 1 {
return Value::Float(nums[0]);
}
let rank = f * (n as f64 - 1.0);
let lo = crate::eval::f64_floor(rank) as usize;
let hi = crate::eval::f64_ceil(rank) as usize;
let frac = rank - lo as f64;
Value::Float(nums[lo] + (nums[hi] - nums[lo]) * frac)
}
_ => unreachable!(),
}
}
/// Infers the output column type of an aggregate slot for result metadata.
///
/// The argument's type is resolved through `crate::describe::describe_expr`
/// against `schema_cols`. It drives the result type of `first_ordered` slots,
/// `sum` (Float stays Float, everything else is BigInt), `array_agg` (element
/// family) and any name not listed explicitly. Unresolvable argument types
/// fall back to `Text`.
fn infer_agg_type(spec: &AggSpec, schema_cols: &[ColumnSchema]) -> DataType {
    let arg_ty = match &spec.arg {
        Some(arg) => crate::describe::describe_expr(arg, schema_cols).map(|shape| shape.ty),
        None => None,
    };
    if spec.first_ordered {
        return arg_ty.unwrap_or(DataType::Text);
    }
    match spec.name.as_str() {
        "count" | "count_star" | "bit_and" | "bit_or" | "bit_xor" | "regr_count" | "rank"
        | "dense_rank" => DataType::BigInt,
        "avg" | "stddev" | "stddev_samp" | "stddev_pop" | "variance" | "var_samp"
        | "var_pop" | "percentile_cont" | "covar_pop" | "covar_samp" | "corr"
        | "regr_avgx" | "regr_avgy" | "regr_slope" | "regr_intercept" | "regr_r2"
        | "regr_sxx" | "regr_syy" | "regr_sxy" | "percent_rank" | "cume_dist" => {
            DataType::Float
        }
        "bool_and" | "bool_or" => DataType::Bool,
        "string_agg" => DataType::Text,
        "json_agg" | "jsonb_agg" | "json_object_agg" | "jsonb_object_agg" => DataType::Json,
        "sum" if matches!(arg_ty, Some(DataType::Float)) => DataType::Float,
        "sum" => DataType::BigInt,
        "array_agg" => match arg_ty {
            Some(DataType::Int | DataType::SmallInt) => DataType::IntArray,
            Some(DataType::BigInt) => DataType::BigIntArray,
            _ => DataType::TextArray,
        },
        _ => arg_ty.unwrap_or(DataType::Text),
    }
}
/// Returns the type of `e` evaluated against the synthetic post-aggregation
/// schema `synth`.
///
/// A bare column whose name appears in `synth` takes that column's type
/// directly. Anything else is described via `crate::describe::describe_expr`,
/// defaulting to `Text` when the describer cannot determine a type.
fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
    let direct_hit = match e {
        Expr::Column(c) => synth.iter().find(|s| s.name == c.name).map(|s| s.ty),
        _ => None,
    };
    match direct_hit {
        Some(ty) => ty,
        None => crate::describe::describe_expr(e, synth).map_or(DataType::Text, |shape| shape.ty),
    }
}
/// Rewrites `e` so it can be evaluated against the synthetic post-aggregation
/// row.
///
/// If `e` itself computes one of the aggregate slots in `aggs`, it becomes a
/// reference to the synthetic column `__agg_{i}`. Otherwise, if it is
/// structurally equal to the `i`-th GROUP BY expression, it becomes
/// `__grp_{i}`. Otherwise its children are rewritten recursively. Subqueries
/// only get their GROUP BY keys rewritten (via `rewrite_group_keys_in_select`).
/// Window functions, literals, placeholders and plain columns are returned
/// unchanged.
fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
    if let Some(i) = matching_agg_slot(e, aggs) {
        return Expr::Column(spg_sql::ast::ColumnName {
            qualifier: None,
            name: format!("__agg_{i}"),
        });
    }
    if let Some(i) = group_exprs.iter().position(|g| g == e) {
        return Expr::Column(spg_sql::ast::ColumnName {
            qualifier: None,
            name: format!("__grp_{i}"),
        });
    }
    match e {
        Expr::AggregateOrdered {
            call,
            order_by,
            distinct,
            filter,
        } => Expr::AggregateOrdered {
            call: Box::new(rewrite_expr(call, group_exprs, aggs)),
            distinct: *distinct,
            order_by: order_by
                .iter()
                .map(|o| spg_sql::ast::OrderBy {
                    expr: rewrite_expr(&o.expr, group_exprs, aggs),
                    desc: o.desc,
                    nulls_first: o.nulls_first,
                })
                .collect(),
            filter: filter.clone(),
        },
        Expr::Binary { lhs, op, rhs } => Expr::Binary {
            lhs: Box::new(rewrite_expr(lhs, group_exprs, aggs)),
            op: *op,
            rhs: Box::new(rewrite_expr(rhs, group_exprs, aggs)),
        },
        Expr::Unary { op, expr } => Expr::Unary {
            op: *op,
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
        },
        Expr::Cast { expr, target } => Expr::Cast {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            target: *target,
        },
        Expr::IsNull { expr, negated } => Expr::IsNull {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            negated: *negated,
        },
        Expr::FunctionCall { name, args } => Expr::FunctionCall {
            name: name.clone(),
            args: args
                .iter()
                .map(|a| rewrite_expr(a, group_exprs, aggs))
                .collect(),
        },
        Expr::Like {
            expr,
            pattern,
            negated,
            case_insensitive,
        } => Expr::Like {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            pattern: Box::new(rewrite_expr(pattern, group_exprs, aggs)),
            negated: *negated,
            case_insensitive: *case_insensitive,
        },
        Expr::Extract { field, source } => Expr::Extract {
            field: *field,
            source: Box::new(rewrite_expr(source, group_exprs, aggs)),
        },
        Expr::ScalarSubquery(s) => {
            Expr::ScalarSubquery(Box::new(rewrite_group_keys_in_select(s, group_exprs)))
        }
        Expr::Exists { subquery, negated } => Expr::Exists {
            subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
            negated: *negated,
        },
        Expr::InSubquery {
            expr,
            subquery,
            negated,
        } => Expr::InSubquery {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            subquery: Box::new(rewrite_group_keys_in_select(subquery, group_exprs)),
            negated: *negated,
        },
        Expr::WindowFunction { .. } | Expr::Literal(_) | Expr::Placeholder(_) | Expr::Column(_) => {
            e.clone()
        }
        Expr::Array(items) => Expr::Array(
            items
                .iter()
                .map(|elem| rewrite_expr(elem, group_exprs, aggs))
                .collect(),
        ),
        Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
            target: Box::new(rewrite_expr(target, group_exprs, aggs)),
            index: Box::new(rewrite_expr(index, group_exprs, aggs)),
        },
        Expr::AnyAll {
            expr,
            op,
            array,
            is_any,
        } => Expr::AnyAll {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            op: *op,
            array: Box::new(rewrite_expr(array, group_exprs, aggs)),
            is_any: *is_any,
        },
        Expr::InList {
            expr,
            list,
            negated,
        } => Expr::InList {
            expr: Box::new(rewrite_expr(expr, group_exprs, aggs)),
            list: list
                .iter()
                .map(|item| rewrite_expr(item, group_exprs, aggs))
                .collect(),
            negated: *negated,
        },
        Expr::Case {
            operand,
            branches,
            else_branch,
        } => Expr::Case {
            operand: operand
                .as_deref()
                .map(|o| Box::new(rewrite_expr(o, group_exprs, aggs))),
            branches: branches
                .iter()
                .map(|(w, t)| {
                    (
                        rewrite_expr(w, group_exprs, aggs),
                        rewrite_expr(t, group_exprs, aggs),
                    )
                })
                .collect(),
            else_branch: else_branch
                .as_deref()
                .map(|e| Box::new(rewrite_expr(e, group_exprs, aggs))),
        },
    }
}

/// Returns the index in `aggs` of the aggregate slot that `e` itself computes.
///
/// Three shapes are recognized, in order:
/// 1. the pattern detected by `first_ordered_array_agg`, which only matches
///    `first_ordered` `array_agg` slots;
/// 2. an `AggregateOrdered` call (DISTINCT / ORDER BY / FILTER / WITHIN GROUP);
/// 3. a plain aggregate `FunctionCall`.
///
/// Shapes 2 and 3 never match a `first_ordered` slot. Shape 3 also never
/// matches a slot carrying a FILTER clause. So `count(*) FILTER (WHERE …)`
/// and `count(*)` in one query resolve to their own slots whatever order they
/// were collected in.
fn matching_agg_slot(e: &Expr, aggs: &[AggSpec]) -> Option<usize> {
    if let Some((arg, order_by, filter)) = first_ordered_array_agg(e) {
        let arg_owned = Some(arg.clone());
        let filter_owned = filter.cloned();
        for (i, spec) in aggs.iter().enumerate() {
            if spec.first_ordered
                && spec.name == "array_agg"
                && spec.arg == arg_owned
                && spec.order_by == *order_by
                && spec.filter == filter_owned
            {
                return Some(i);
            }
        }
    }
    if let Expr::AggregateOrdered {
        call,
        order_by,
        distinct,
        filter,
    } = e
        && let Expr::FunctionCall { name, args } = call.as_ref()
    {
        let lower = name.to_ascii_lowercase();
        if is_aggregate_name(&lower) {
            // `every` is an alias of `bool_and`.
            let canonical: &str = if lower == "every" { "bool_and" } else { &lower };
            // WITHIN GROUP aggregates take their aggregated value from the
            // ORDER BY key; their call argument is the direct argument.
            let (arg, direct_arg) = if is_within_group_name(canonical) {
                (
                    order_by.first().map(|o| o.expr.clone()),
                    args.first().cloned(),
                )
            } else {
                (args.first().cloned(), None)
            };
            let arg2 = if agg_uses_second_arg(canonical) {
                args.get(1).cloned()
            } else {
                None
            };
            let filter_owned = filter.as_deref().cloned();
            for (i, spec) in aggs.iter().enumerate() {
                if !spec.first_ordered
                    && spec.name == canonical
                    && spec.arg == arg
                    && spec.arg2 == arg2
                    && spec.distinct == *distinct
                    && spec.order_by == *order_by
                    && spec.filter == filter_owned
                    && spec.direct_arg == direct_arg
                {
                    return Some(i);
                }
            }
        }
    }
    if let Expr::FunctionCall { name, args } = e {
        let lower = name.to_ascii_lowercase();
        if is_aggregate_name(&lower) {
            let arg = if lower == "count_star" {
                None
            } else {
                args.first().cloned()
            };
            let arg2 = if agg_uses_second_arg(&lower) {
                args.get(1).cloned()
            } else {
                None
            };
            let canonical: &str = if lower == "every" {
                "bool_and"
            } else {
                lower.as_str()
            };
            // A plain call has no DISTINCT, ORDER BY or FILTER, so it may only
            // resolve to a slot that has none of them either.
            for (i, spec) in aggs.iter().enumerate() {
                if !spec.first_ordered
                    && spec.filter.is_none()
                    && spec.name == canonical
                    && spec.arg == arg
                    && spec.arg2 == arg2
                    && !spec.distinct
                    && spec.order_by.is_empty()
                {
                    return Some(i);
                }
            }
        }
    }
    None
}
/// Returns a copy of subquery `s` in which every expression equal to a GROUP BY
/// key of the outer query is replaced by its `__grp_{i}` synthetic column.
///
/// No aggregate slots are passed down (`&[]`), so only GROUP BY keys are
/// substituted inside the subquery.
fn rewrite_group_keys_in_select(
    s: &spg_sql::ast::SelectStatement,
    group_exprs: &[Expr],
) -> spg_sql::ast::SelectStatement {
    let mut rewritten = s.clone();
    // The callback below never fails; the walker's result is deliberately ignored.
    let _ = crate::walk_select_exprs_mut(&mut rewritten, &mut |expr| {
        let replacement = rewrite_expr(expr, group_exprs, &[]);
        *expr = replacement;
        Ok(())
    });
    rewritten
}
fn encode_one(out: &mut String, v: &Value) {
match v {
Value::Null => out.push_str("N|"),
Value::SmallInt(n) => {
out.push('s');
out.push_str(&n.to_string());
out.push('|');
}
Value::Int(n) => {
out.push('I');
out.push_str(&n.to_string());
out.push('|');
}
Value::BigInt(n) => {
out.push('B');
out.push_str(&n.to_string());
out.push('|');
}
Value::Float(x) => {
out.push('F');
out.push_str(&x.to_string());
out.push('|');
}
Value::Bool(b) => {
out.push(if *b { 'T' } else { 'f' });
out.push('|');
}
Value::Text(s) => {
out.push('S');
out.push_str(s);
out.push('|');
}
Value::Vector(v) => {
out.push('V');
for x in v {
out.push_str(&x.to_string());
out.push(',');
}
out.push('|');
}
Value::Sq8Vector(q) => {
out.push('Q');
out.push_str(&q.min.to_string());
out.push('@');
out.push_str(&q.max.to_string());
out.push(':');
for b in &q.bytes {
out.push_str(&b.to_string());
out.push(',');
}
out.push('|');
}
Value::HalfVector(h) => {
out.push('H');
for b in &h.bytes {
out.push_str(&b.to_string());
out.push(',');
}
out.push('|');
}
Value::Numeric { scaled, scale } => {
out.push('D');
out.push_str(&scaled.to_string());
out.push('@');
out.push_str(&scale.to_string());
out.push('|');
}
Value::Date(d) => {
out.push('d');
out.push_str(&d.to_string());
out.push('|');
}
Value::Timestamp(t) => {
out.push('t');
out.push_str(&t.to_string());
out.push('|');
}
Value::Interval { months, micros } => {
out.push('i');
out.push_str(&months.to_string());
out.push('m');
out.push_str(µs.to_string());
out.push('|');
}
Value::Json(s) => {
out.push('j');
out.push_str(s);
out.push('|');
}
_ => {
out.push('?');
out.push_str(&format!("{v:?}"));
out.push('|');
}
}
}
pub(crate) fn encode_key_refs(vals: &[&Value]) -> String {
let mut out = String::new();
for v in vals {
encode_one(&mut out, v);
}
out
}
/// Encodes a borrowed tuple of values into `out`, replacing its previous
/// contents but reusing its allocation (useful when building one key per row).
pub(crate) fn encode_key_refs_into(vals: &[&Value], out: &mut String) {
    out.clear();
    vals.iter().for_each(|v| encode_one(out, v));
}
/// Encodes an owned tuple of values into a fresh key string by concatenating
/// the per-value encodings in order.
pub(crate) fn encode_key(vals: &[Value]) -> String {
    vals.iter().fold(String::new(), |mut key, v| {
        encode_one(&mut key, v);
        key
    })
}
/// Comparator used by the aggregate code for its own value comparisons.
///
/// * NULL sorts after every non-NULL value, and two NULLs are equal.
/// * The integer variants (`SmallInt`, `Int`, `BigInt`) compare numerically
///   with each other. They compare against `Float` after widening to `f64`;
///   an unordered (NaN) comparison is reported as `Equal`.
/// * Text and booleans compare within their own variant.
/// * Any other pairing, including two values of an unlisted variant, is
///   reported as `Equal`.
///
/// `SmallInt` previously had no arms at all, so it compared `Equal` to every
/// value, itself included.
#[allow(clippy::cast_precision_loss)]
fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
    use core::cmp::Ordering::{Equal, Greater, Less};
    match (a, b) {
        (Value::Null, Value::Null) => Equal,
        (Value::Null, _) => Greater,
        (_, Value::Null) => Less,
        (Value::SmallInt(x), Value::SmallInt(y)) => x.cmp(y),
        (Value::SmallInt(x), Value::Int(y)) => i32::from(*x).cmp(y),
        (Value::Int(x), Value::SmallInt(y)) => x.cmp(&i32::from(*y)),
        (Value::SmallInt(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
        (Value::BigInt(x), Value::SmallInt(y)) => x.cmp(&i64::from(*y)),
        (Value::SmallInt(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
        (Value::Float(x), Value::SmallInt(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
        (Value::Int(x), Value::Int(y)) => x.cmp(y),
        (Value::BigInt(x), Value::BigInt(y)) => x.cmp(y),
        (Value::Int(x), Value::BigInt(y)) => i64::from(*x).cmp(y),
        (Value::BigInt(x), Value::Int(y)) => x.cmp(&i64::from(*y)),
        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y).unwrap_or(Equal),
        (Value::Int(x), Value::Float(y)) => f64::from(*x).partial_cmp(y).unwrap_or(Equal),
        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&f64::from(*y)).unwrap_or(Equal),
        (Value::BigInt(x), Value::Float(y)) => (*x as f64).partial_cmp(y).unwrap_or(Equal),
        (Value::Float(x), Value::BigInt(y)) => x.partial_cmp(&(*y as f64)).unwrap_or(Equal),
        (Value::Text(x), Value::Text(y)) => x.cmp(y),
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        _ => Equal,
    }
}