use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use spg_sql::ast::{Expr, SelectItem, SelectStatement};
use spg_storage::{ColumnSchema, DataType, Row, Value};
use crate::eval::{self, EvalContext, EvalError};
/// Reports whether `stmt` has to go through the aggregation path.
///
/// That is the case when the statement carries a `GROUP BY` or `HAVING`
/// clause, or when any select-list expression or `ORDER BY` key contains an
/// aggregate call (see [`contains_aggregate`]). Wildcard select items never
/// count as aggregates.
pub fn uses_aggregate(stmt: &SelectStatement) -> bool {
    // A grouping clause alone is enough; no expression needs inspecting.
    let has_grouping_clause = stmt.group_by.is_some() || stmt.having.is_some();

    has_grouping_clause
        || stmt
            .items
            .iter()
            .any(|item| matches!(item, SelectItem::Expr { expr, .. } if contains_aggregate(expr)))
        || stmt
            .order_by
            .iter()
            .any(|key| contains_aggregate(&key.expr))
}
/// Returns `true` when `e` contains a call to an aggregate function anywhere
/// in the part of the tree this walker descends into.
///
/// Subqueries (`ScalarSubquery`, `Exists`, `InSubquery`) and window functions
/// are treated as opaque leaves: nothing inside them is inspected.
pub fn contains_aggregate(e: &Expr) -> bool {
    /// True when either of two sub-expressions contains an aggregate.
    fn either(a: &Expr, b: &Expr) -> bool {
        contains_aggregate(a) || contains_aggregate(b)
    }

    match e {
        // Leaves and opaque nodes: nothing to descend into.
        Expr::Literal(_)
        | Expr::Placeholder(_)
        | Expr::Column(_)
        | Expr::ScalarSubquery(_)
        | Expr::Exists { .. }
        | Expr::InSubquery { .. }
        | Expr::WindowFunction { .. } => false,

        // A call is an aggregate itself, or may wrap one in its arguments.
        Expr::FunctionCall { name, args } => {
            if is_aggregate_name(name) {
                true
            } else {
                args.iter().any(contains_aggregate)
            }
        }

        // Single-child nodes.
        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
            contains_aggregate(expr)
        }
        Expr::Extract { source, .. } => contains_aggregate(source),

        // Two-child nodes.
        Expr::Binary { lhs, rhs, .. } => either(lhs, rhs),
        Expr::Like { expr, pattern, .. } => either(expr, pattern),
        Expr::ArraySubscript { target, index } => either(target, index),
        Expr::AnyAll { expr, array, .. } => either(expr, array),

        // Variadic node.
        Expr::Array(items) => items.iter().any(contains_aggregate),
    }
}
/// Returns `true` when `name` is one of the supported aggregate functions,
/// compared ASCII case-insensitively.
///
/// `count_star` is the name under which `COUNT(*)` is matched here.
pub fn is_aggregate_name(name: &str) -> bool {
    const AGGREGATES: [&str; 6] = ["count", "count_star", "sum", "min", "max", "avg"];
    AGGREGATES
        .iter()
        .any(|known| name.eq_ignore_ascii_case(known))
}
/// Running accumulator for one aggregate call within one group.
///
/// The same struct backs every aggregate kind; which fields are meaningful
/// depends on the aggregate name passed to `update_state` / `finalize`.
#[derive(Debug, Default, Clone)]
struct AggState {
/// Rows counted so far: every row for `count_star`, only non-NULL inputs for
/// `count`, `sum` and `avg`. Also serves as the divisor for `avg`.
count: i64,
/// Running total of the integer-valued inputs seen by `sum` / `avg`.
sum_int: i64,
/// Running total of the `Float` inputs seen by `sum` / `avg`.
sum_float: f64,
/// Current best value for `min` / `max`; `None` until a non-NULL input is seen.
extreme: Option<Value>,
/// Set once any `Float` input has been accumulated; `finalize` then reports
/// `sum` as a float instead of a big integer.
use_float: bool,
}
/// One distinct aggregate call discovered in the query.
///
/// `name` is the ASCII-lowercased function name. `arg` is a clone of the
/// call's first argument expression, or `None` for `count_star` (whose
/// arguments are ignored) and for calls written without arguments.
#[derive(Debug, Clone)]
struct AggSpec {
name: String, arg: Option<Expr>,
}
/// Output of the aggregation stage produced by `run`.
#[derive(Debug)]
pub struct AggResult {
/// Schema of each output column, one per select-list item, in select-list order.
pub columns: Vec<ColumnSchema>,
/// Output rows (one per surviving group); values line up with `columns`.
pub rows: Vec<Row>,
}
/// Executes the aggregation stage of a `SELECT` over the supplied `rows`.
///
/// The work happens in four steps:
///
/// 1. Collect the distinct aggregate calls that appear in the select list,
///    the `ORDER BY` keys and the `HAVING` clause.
/// 2. Bucket `rows` by the encoded values of the `GROUP BY` expressions and
///    fold every row into the per-group aggregate states.
/// 3. Materialise one synthetic row per group whose columns are the group
///    values (`__grp_<i>`) followed by the finalized aggregates (`__agg_<i>`).
/// 4. Rewrite `HAVING`, the select list and `ORDER BY` so they reference those
///    synthetic columns, then filter, project and sort the synthetic rows.
///
/// Groups are emitted in the order in which they were first seen unless an
/// `ORDER BY` is present (the sort is stable). When there are no input rows
/// and no `GROUP BY`, a single group with untouched aggregate states is still
/// produced.
///
/// Group column types are probed from the first input row (falling back to
/// `Text` when the value has no type or there are no rows).
///
/// # Errors
///
/// Returns an [`EvalError`] when evaluating any expression fails, when an
/// aggregate receives an input it cannot accumulate, or when the select list
/// contains a wildcard.
#[allow(clippy::too_many_lines)]
pub fn run(
    stmt: &SelectStatement,
    rows: &[&Row],
    schema_cols: &[ColumnSchema],
    table_alias: Option<&str>,
) -> Result<AggResult, EvalError> {
    use alloc::collections::btree_map::Entry;

    let ctx = EvalContext::new(schema_cols, table_alias);
    // Borrow the GROUP BY list instead of cloning it; absent means "no keys".
    let group_exprs: &[Expr] = stmt.group_by.as_deref().unwrap_or(&[]);

    // ---- 1. Discover the distinct aggregate calls -------------------------
    let mut agg_specs: Vec<AggSpec> = Vec::new();
    for item in &stmt.items {
        if let SelectItem::Expr { expr, .. } = item {
            collect_aggregates(expr, &mut agg_specs);
        }
    }
    for o in &stmt.order_by {
        collect_aggregates(&o.expr, &mut agg_specs);
    }
    if let Some(h) = &stmt.having {
        collect_aggregates(h, &mut agg_specs);
    }
    let agg_count = agg_specs.len();
    let fresh_states =
        |n: usize| -> Vec<AggState> { (0..n).map(|_| AggState::default()).collect() };

    // ---- 2. Bucket rows into groups and accumulate -------------------------
    // `groups` is keyed by the encoded group values; `key_order` remembers the
    // order in which groups were first seen so output order is deterministic.
    let mut groups: BTreeMap<String, (Vec<Value>, Vec<AggState>)> = BTreeMap::new();
    let mut key_order: Vec<String> = Vec::new();
    if rows.is_empty() && group_exprs.is_empty() {
        // An ungrouped aggregate over zero rows still yields exactly one row.
        groups.insert(String::new(), (Vec::new(), fresh_states(agg_count)));
        key_order.push(String::new());
    }
    for row in rows {
        let group_vals: Vec<Value> = group_exprs
            .iter()
            .map(|g| eval::eval_expr(g, row, &ctx))
            .collect::<Result<_, _>>()?;
        // The entry API lets the key be moved into the map: it is cloned only
        // once per *new group* (for `key_order`), not once per input row.
        let group = match groups.entry(encode_key(&group_vals)) {
            Entry::Occupied(slot) => slot.into_mut(),
            Entry::Vacant(slot) => {
                key_order.push(slot.key().clone());
                slot.insert((group_vals, fresh_states(agg_count)))
            }
        };
        // `group.1` always has exactly one state per aggregate spec.
        for (state, spec) in group.1.iter_mut().zip(&agg_specs) {
            let arg_val = match &spec.arg {
                // No argument expression: feed a non-NULL dummy value.
                None => Value::Bool(true),
                Some(e) => eval::eval_expr(e, row, &ctx)?,
            };
            update_state(state, &spec.name, &arg_val)?;
        }
    }

    // ---- 3. Build the synthetic schema and rows ----------------------------
    let group_types: Vec<DataType> = match rows.first() {
        None => group_exprs.iter().map(|_| DataType::Text).collect(),
        Some(&probe) => group_exprs
            .iter()
            .map(|g| {
                eval::eval_expr(g, probe, &ctx).map(|v| v.data_type().unwrap_or(DataType::Text))
            })
            .collect::<Result<_, _>>()?,
    };
    let agg_types: Vec<DataType> = agg_specs.iter().map(infer_agg_type).collect();

    let mut synth_schema: Vec<ColumnSchema> =
        Vec::with_capacity(group_types.len() + agg_types.len());
    for (i, ty) in group_types.iter().enumerate() {
        synth_schema.push(ColumnSchema::new(format!("__grp_{i}"), *ty, true));
    }
    for (i, ty) in agg_types.iter().enumerate() {
        synth_schema.push(ColumnSchema::new(format!("__agg_{i}"), *ty, true));
    }

    let mut synth_rows: Vec<Row> = Vec::with_capacity(key_order.len());
    for k in &key_order {
        let (gvals, states) = &groups[k];
        let mut values: Vec<Value> = Vec::with_capacity(synth_schema.len());
        values.extend(gvals.iter().cloned());
        values.extend(
            states
                .iter()
                .zip(&agg_specs)
                .map(|(st, spec)| finalize(&spec.name, st)),
        );
        synth_rows.push(Row::new(values));
    }

    // ---- 4. Rewrite, filter, project, sort ---------------------------------
    // Each select item is rewritten exactly once; the rewritten expression is
    // reused both for the output schema and for projecting every group row.
    let mut columns: Vec<ColumnSchema> = Vec::new();
    let mut projections: Vec<Expr> = Vec::new();
    for item in &stmt.items {
        match item {
            SelectItem::Wildcard => {
                return Err(EvalError::TypeMismatch {
                    detail: "SELECT * with aggregates is not supported".into(),
                });
            }
            SelectItem::Expr { expr, alias } => {
                let rewritten = rewrite_expr(expr, group_exprs, &agg_specs);
                let name = alias.clone().unwrap_or_else(|| expr.to_string());
                columns.push(ColumnSchema::new(
                    name,
                    agg_or_group_type(&rewritten, &synth_schema),
                    true,
                ));
                projections.push(rewritten);
            }
        }
    }

    let synth_ctx = EvalContext::new(&synth_schema, None);
    let having_rewritten = stmt
        .having
        .as_ref()
        .map(|h| rewrite_expr(h, group_exprs, &agg_specs));

    // `kept_synth[i]` is the synthetic row that produced `out_rows[i]`; it is
    // kept so ORDER BY keys can be evaluated against the synthetic schema.
    let mut kept_synth: Vec<Row> = Vec::new();
    let mut out_rows: Vec<Row> = Vec::new();
    for srow in synth_rows {
        if let Some(h) = &having_rewritten {
            let cond = eval::eval_expr(h, &srow, &synth_ctx)?;
            // Anything other than a definite TRUE (including NULL) drops the group.
            if !matches!(cond, Value::Bool(true)) {
                continue;
            }
        }
        let mut values: Vec<Value> = Vec::with_capacity(projections.len());
        for projection in &projections {
            values.push(eval::eval_expr(projection, &srow, &synth_ctx)?);
        }
        kept_synth.push(srow);
        out_rows.push(Row::new(values));
    }

    if !stmt.order_by.is_empty() {
        let sort_exprs: Vec<Expr> = stmt
            .order_by
            .iter()
            .map(|o| rewrite_expr(&o.expr, group_exprs, &agg_specs))
            .collect();
        let descs: Vec<bool> = stmt.order_by.iter().map(|o| o.desc).collect();
        // Pair every output row with its pre-computed sort keys.
        let mut tagged: Vec<(Vec<Value>, Row)> = kept_synth
            .into_iter()
            .zip(out_rows)
            .map(|(s, o)| {
                let mut keys = Vec::with_capacity(sort_exprs.len());
                for e in &sort_exprs {
                    keys.push(eval::eval_expr(e, &s, &synth_ctx)?);
                }
                Ok::<_, EvalError>((keys, o))
            })
            .collect::<Result<_, _>>()?;
        tagged.sort_by(|a, b| {
            use core::cmp::Ordering;
            for ((ka, kb), desc) in a.0.iter().zip(b.0.iter()).zip(&descs) {
                let cmp = value_cmp(ka, kb);
                let cmp = if *desc { cmp.reverse() } else { cmp };
                if cmp != Ordering::Equal {
                    return cmp;
                }
            }
            Ordering::Equal
        });
        out_rows = tagged.into_iter().map(|(_, o)| o).collect();
    }

    Ok(AggResult {
        columns,
        rows: out_rows,
    })
}
/// Walks `e` and appends every distinct aggregate call it finds to `out`.
///
/// Two calls are considered the same aggregate when their lowercased names
/// and their first arguments are equal; duplicates are not appended, so the
/// index of a spec in `out` is stable and can be used to name its column.
/// For `count_star` the argument list is ignored.
///
/// The arguments of an aggregate call are not searched for further
/// aggregates, and subqueries / window functions are not descended into.
fn collect_aggregates(e: &Expr, out: &mut Vec<AggSpec>) {
    match e {
        Expr::FunctionCall { name, args } => {
            let lower = name.to_ascii_lowercase();
            if !is_aggregate_name(&lower) {
                // Ordinary function: an aggregate may still hide in its arguments.
                for a in args {
                    collect_aggregates(a, out);
                }
                return;
            }
            let arg: Option<&Expr> = if lower == "count_star" {
                None
            } else {
                args.first()
            };
            // Compare by reference first; the argument expression is cloned
            // only when a new spec really has to be recorded.
            let already_known = out
                .iter()
                .any(|s| s.name == lower && s.arg.as_ref() == arg);
            if !already_known {
                out.push(AggSpec {
                    name: lower,
                    arg: arg.cloned(),
                });
            }
        }
        Expr::Binary { lhs, rhs, .. } => {
            collect_aggregates(lhs, out);
            collect_aggregates(rhs, out);
        }
        Expr::Unary { expr, .. } | Expr::Cast { expr, .. } | Expr::IsNull { expr, .. } => {
            collect_aggregates(expr, out);
        }
        Expr::Like { expr, pattern, .. } => {
            collect_aggregates(expr, out);
            collect_aggregates(pattern, out);
        }
        Expr::Extract { source, .. } => collect_aggregates(source, out),
        // Leaves and opaque nodes contribute nothing.
        Expr::ScalarSubquery(_)
        | Expr::Exists { .. }
        | Expr::InSubquery { .. }
        | Expr::WindowFunction { .. }
        | Expr::Literal(_)
        | Expr::Placeholder(_)
        | Expr::Column(_) => {}
        Expr::Array(items) => {
            for elem in items {
                collect_aggregates(elem, out);
            }
        }
        Expr::ArraySubscript { target, index } => {
            collect_aggregates(target, out);
            collect_aggregates(index, out);
        }
        Expr::AnyAll { expr, array, .. } => {
            collect_aggregates(expr, out);
            collect_aggregates(array, out);
        }
    }
}
/// Folds one input value `v` into the accumulator `st` for the aggregate
/// called `name` (which must already be lowercased).
///
/// * `count_star` counts every call.
/// * `count` counts non-NULL inputs.
/// * `sum` / `avg` ignore NULL, add integer inputs (`SmallInt`, `Int`,
///   `BigInt`) to the integer total and `Float` inputs to the float total.
/// * `min` / `max` ignore NULL and keep the smallest / largest value
///   according to `value_cmp`.
///
/// # Errors
///
/// Returns `EvalError::TypeMismatch` when `sum` / `avg` receive a non-NULL
/// value that is not one of the numeric kinds listed above.
///
/// # Panics
///
/// Panics if `name` is not one of the supported aggregate names.
fn update_state(st: &mut AggState, name: &str, v: &Value) -> Result<(), EvalError> {
    use core::cmp::Ordering;

    let is_null = matches!(v, Value::Null);
    match name {
        "count_star" => st.count += 1,
        "count" => {
            if !is_null {
                st.count += 1;
            }
        }
        "sum" | "avg" => {
            if is_null {
                return Ok(());
            }
            st.count += 1;
            // NOTE(review): the integer total uses plain `+=`, which panics on
            // overflow in debug builds and wraps in release builds — confirm
            // whether an explicit overflow error is wanted here.
            match v {
                Value::SmallInt(n) => st.sum_int += i64::from(*n),
                Value::Int(n) => st.sum_int += i64::from(*n),
                Value::BigInt(n) => st.sum_int += *n,
                Value::Float(x) => {
                    st.use_float = true;
                    st.sum_float += *x;
                }
                other => {
                    return Err(EvalError::TypeMismatch {
                        detail: format!("sum/avg need numeric, got {:?}", other.data_type()),
                    });
                }
            }
        }
        "min" | "max" => {
            if is_null {
                return Ok(());
            }
            // The candidate replaces the current extreme only when it is
            // strictly better, so the first of several equal values is kept.
            let wanted = if name == "min" {
                Ordering::Less
            } else {
                Ordering::Greater
            };
            let replace = match &st.extreme {
                None => true,
                Some(cur) => value_cmp(v, cur) == wanted,
            };
            if replace {
                st.extreme = Some(v.clone());
            }
        }
        _ => unreachable!("non-aggregate {name} in update_state"),
    }
    Ok(())
}
/// Turns the accumulated state `st` into the final value of the aggregate
/// called `name` (already lowercased).
///
/// * `count` / `count_star` yield the count as `BigInt`.
/// * `sum` yields `Null` when nothing was accumulated, a `Float` when any
///   float input was seen, otherwise a `BigInt`.
/// * `avg` yields `Null` when nothing was accumulated, otherwise a `Float`.
/// * `min` / `max` yield the stored extreme, or `Null` when there is none.
///
/// # Panics
///
/// Panics if `name` is not one of the supported aggregate names.
#[allow(clippy::cast_precision_loss)]
fn finalize(name: &str, st: &AggState) -> Value {
    let nothing_accumulated = st.count == 0;
    match name {
        "count" | "count_star" => Value::BigInt(st.count),
        "min" | "max" => st.extreme.clone().unwrap_or(Value::Null),
        "sum" | "avg" if nothing_accumulated => Value::Null,
        "sum" => {
            if st.use_float {
                // Mixed input: fold the integer part into the float total.
                Value::Float(st.sum_float + (st.sum_int as f64))
            } else {
                Value::BigInt(st.sum_int)
            }
        }
        "avg" => {
            let numerator = if st.use_float {
                st.sum_float + (st.sum_int as f64)
            } else {
                st.sum_int as f64
            };
            let denominator = st.count as f64;
            Value::Float(numerator / denominator)
        }
        _ => unreachable!(),
    }
}
/// Picks the declared column type for an aggregate's synthetic column from
/// its name alone: `avg` is `Float`, the counting aggregates and `sum` are
/// `BigInt`, and everything else (i.e. `min` / `max`) is reported as `Text`.
fn infer_agg_type(spec: &AggSpec) -> DataType {
    let name = spec.name.as_str();
    if name == "avg" {
        DataType::Float
    } else if matches!(name, "count" | "count_star" | "sum") {
        DataType::BigInt
    } else {
        DataType::Text
    }
}
/// Resolves the output type of a rewritten select expression.
///
/// When `e` is a bare column reference whose name matches a column of the
/// synthetic schema `synth`, that column's type is returned. Any other
/// expression, or an unknown column name, is reported as `Text`.
fn agg_or_group_type(e: &Expr, synth: &[ColumnSchema]) -> DataType {
    let Expr::Column(col) = e else {
        return DataType::Text;
    };
    match synth.iter().find(|s| s.name == col.name) {
        Some(schema) => schema.ty,
        None => DataType::Text,
    }
}
/// Rewrites `e` so that it can be evaluated against a synthetic group row.
///
/// * An aggregate call that matches `aggs[i]` (same lowercased name and same
///   first argument) becomes a reference to column `__agg_<i>`.
/// * Any sub-expression equal to `group_exprs[i]` becomes a reference to
///   column `__grp_<i>`.
/// * Everything else is rebuilt with its children rewritten recursively;
///   leaves, subqueries and window functions are cloned unchanged.
///
/// The aggregate check runs before the group-expression check, and both run
/// before descending, so the outermost match wins.
fn rewrite_expr(e: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Expr {
    /// Builds an unqualified column reference to a synthetic column.
    fn synth_column(name: String) -> Expr {
        Expr::Column(spg_sql::ast::ColumnName {
            qualifier: None,
            name,
        })
    }

    /// Rewrites a child expression and boxes the result.
    fn boxed(child: &Expr, group_exprs: &[Expr], aggs: &[AggSpec]) -> Box<Expr> {
        Box::new(rewrite_expr(child, group_exprs, aggs))
    }

    if let Expr::FunctionCall { name, args } = e {
        let lower = name.to_ascii_lowercase();
        if is_aggregate_name(&lower) {
            // Compare the argument by reference; no clone is needed just to
            // look the aggregate up.
            let arg: Option<&Expr> = if lower == "count_star" {
                None
            } else {
                args.first()
            };
            let found = aggs
                .iter()
                .position(|spec| spec.name == lower && spec.arg.as_ref() == arg);
            if let Some(i) = found {
                return synth_column(format!("__agg_{i}"));
            }
            // An aggregate with no matching spec falls through and is treated
            // like any other expression below.
        }
    }

    if let Some(i) = group_exprs.iter().position(|g| g == e) {
        return synth_column(format!("__grp_{i}"));
    }

    match e {
        Expr::Binary { lhs, op, rhs } => Expr::Binary {
            lhs: boxed(lhs, group_exprs, aggs),
            op: *op,
            rhs: boxed(rhs, group_exprs, aggs),
        },
        Expr::Unary { op, expr } => Expr::Unary {
            op: *op,
            expr: boxed(expr, group_exprs, aggs),
        },
        Expr::Cast { expr, target } => Expr::Cast {
            expr: boxed(expr, group_exprs, aggs),
            target: *target,
        },
        Expr::IsNull { expr, negated } => Expr::IsNull {
            expr: boxed(expr, group_exprs, aggs),
            negated: *negated,
        },
        Expr::FunctionCall { name, args } => Expr::FunctionCall {
            name: name.clone(),
            args: args
                .iter()
                .map(|a| rewrite_expr(a, group_exprs, aggs))
                .collect(),
        },
        Expr::Like {
            expr,
            pattern,
            negated,
        } => Expr::Like {
            expr: boxed(expr, group_exprs, aggs),
            pattern: boxed(pattern, group_exprs, aggs),
            negated: *negated,
        },
        Expr::Extract { field, source } => Expr::Extract {
            field: *field,
            source: boxed(source, group_exprs, aggs),
        },
        // Nothing to rewrite inside these; copy them as they are.
        Expr::ScalarSubquery(_)
        | Expr::Exists { .. }
        | Expr::InSubquery { .. }
        | Expr::WindowFunction { .. }
        | Expr::Literal(_)
        | Expr::Placeholder(_)
        | Expr::Column(_) => e.clone(),
        Expr::Array(items) => Expr::Array(
            items
                .iter()
                .map(|elem| rewrite_expr(elem, group_exprs, aggs))
                .collect(),
        ),
        Expr::ArraySubscript { target, index } => Expr::ArraySubscript {
            target: boxed(target, group_exprs, aggs),
            index: boxed(index, group_exprs, aggs),
        },
        Expr::AnyAll {
            expr,
            op,
            array,
            is_any,
        } => Expr::AnyAll {
            expr: boxed(expr, group_exprs, aggs),
            op: *op,
            array: boxed(array, group_exprs, aggs),
            is_any: *is_any,
        },
    }
}
fn encode_key(vals: &[Value]) -> String {
let mut out = String::new();
for v in vals {
match v {
Value::Null => out.push_str("N|"),
Value::SmallInt(n) => {
out.push('s');
out.push_str(&n.to_string());
out.push('|');
}
Value::Int(n) => {
out.push('I');
out.push_str(&n.to_string());
out.push('|');
}
Value::BigInt(n) => {
out.push('B');
out.push_str(&n.to_string());
out.push('|');
}
Value::Float(x) => {
out.push('F');
out.push_str(&x.to_string());
out.push('|');
}
Value::Bool(b) => {
out.push(if *b { 'T' } else { 'f' });
out.push('|');
}
Value::Text(s) => {
out.push('S');
out.push_str(s);
out.push('|');
}
Value::Vector(v) => {
out.push('V');
for x in v {
out.push_str(&x.to_string());
out.push(',');
}
out.push('|');
}
Value::Sq8Vector(q) => {
out.push('Q');
out.push_str(&q.min.to_string());
out.push('@');
out.push_str(&q.max.to_string());
out.push(':');
for b in &q.bytes {
out.push_str(&b.to_string());
out.push(',');
}
out.push('|');
}
Value::HalfVector(h) => {
out.push('H');
for b in &h.bytes {
out.push_str(&b.to_string());
out.push(',');
}
out.push('|');
}
Value::Numeric { scaled, scale } => {
out.push('D');
out.push_str(&scaled.to_string());
out.push('@');
out.push_str(&scale.to_string());
out.push('|');
}
Value::Date(d) => {
out.push('d');
out.push_str(&d.to_string());
out.push('|');
}
Value::Timestamp(t) => {
out.push('t');
out.push_str(&t.to_string());
out.push('|');
}
Value::Interval { months, micros } => {
out.push('i');
out.push_str(&months.to_string());
out.push('m');
out.push_str(µs.to_string());
out.push('|');
}
Value::Json(s) => {
out.push('j');
out.push_str(s);
out.push('|');
}
_ => {
out.push('?');
out.push_str(&format!("{v:?}"));
out.push('|');
}
}
}
out
}
/// Total-order comparison of two values, used for `MIN` / `MAX` and for
/// sorting `ORDER BY` keys.
///
/// * `Null` compares greater than every non-NULL value (so NULLs sort last
///   in ascending order) and equal to another `Null`.
/// * Integer kinds (`SmallInt`, `Int`, `BigInt`) are widened to `i64` and
///   compared exactly.
/// * If either side is a `Float`, both sides are compared as `f64`; a
///   comparison involving NaN is reported as `Equal`.
/// * `Text` and `Bool` compare with their natural ordering.
/// * Every other combination is reported as `Equal`.
// NOTE(review): Date / Timestamp / Numeric / Interval values still fall into
// the `Equal` fallback, so MIN/MAX/ORDER BY do not order them — confirm the
// payload types and add arms if ordering is expected for them.
#[allow(clippy::cast_precision_loss)]
fn value_cmp(a: &Value, b: &Value) -> core::cmp::Ordering {
    use core::cmp::Ordering::{Equal, Greater, Less};

    /// The value as an `i64` when it is one of the integer kinds.
    fn as_int(v: &Value) -> Option<i64> {
        match v {
            Value::SmallInt(n) => Some(i64::from(*n)),
            Value::Int(n) => Some(i64::from(*n)),
            Value::BigInt(n) => Some(*n),
            _ => None,
        }
    }

    /// The value as an `f64` when it is a float or one of the integer kinds.
    fn as_float(v: &Value) -> Option<f64> {
        match v {
            Value::Float(x) => Some(*x),
            other => as_int(other).map(|n| n as f64),
        }
    }

    match (a, b) {
        (Value::Null, Value::Null) => Equal,
        (Value::Null, _) => Greater,
        (_, Value::Null) => Less,
        (Value::Text(x), Value::Text(y)) => x.cmp(y),
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        _ => {
            // Prefer the exact integer comparison when both sides allow it.
            if let (Some(x), Some(y)) = (as_int(a), as_int(b)) {
                return x.cmp(&y);
            }
            match (as_float(a), as_float(b)) {
                (Some(x), Some(y)) => x.partial_cmp(&y).unwrap_or(Equal),
                _ => Equal,
            }
        }
    }
}