use super::{fold_columns, fold_vals, DynProofPlan};
use crate::{
base::{
database::{
group_by_util::{aggregate_columns, AggregatedColumns},
Column, ColumnField, ColumnRef, ColumnType, LiteralValue, Table, TableEvaluation,
TableRef,
},
map::{IndexMap, IndexSet},
proof::{PlaceholderResult, ProofError},
scalar::Scalar,
slice_ops,
},
sql::{
proof::{
FinalRoundBuilder, FirstRoundBuilder, ProofPlan, ProverEvaluate,
SumcheckSubpolynomialType, VerificationBuilder,
},
proof_exprs::{AliasedDynProofExpr, DynProofExpr, ProofExpr},
proof_gadgets::{
final_round_evaluate_monotonic, first_round_evaluate_monotonic,
fold_log_expr::FoldLogExpr, verify_monotonic,
},
},
utils::log,
};
use alloc::{boxed::Box, vec, vec::Vec};
use bumpalo::Bump;
use core::iter;
use num_traits::One;
use serde::{Deserialize, Serialize};
use sqlparser::ast::Ident;
use tracing::{span, Level};
/// Proof plan for a filtered `GROUP BY` aggregation.
///
/// The grouping, `WHERE`, and SUM expressions are evaluated against the output of `input`.
/// The `WHERE` result is used as the row selection for `aggregate_columns`. The output
/// schema is: the grouping columns, then one column per entry of `sum_expr`, then a
/// `BigInt` column named `count_alias`.
///
/// [`AggregateExec::try_new`] only accepts two grouping shapes:
/// - no grouping columns, or
/// - exactly one grouping column whose type is neither `VarChar` nor `VarBinary`.
///
/// Because the type also derives `Deserialize`, a value can be constructed without going
/// through `try_new`. In that case the verifier rejects an unsupported shape with an error,
/// and the prover panics on it.
#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
pub struct AggregateExec {
    /// Grouping expressions (evaluated against `input`'s output) with their output aliases.
    group_by_exprs: Vec<AliasedDynProofExpr>,
    /// Expressions summed per group, with their output aliases.
    sum_expr: Vec<AliasedDynProofExpr>,
    /// Name of the trailing `BigInt` output column that holds the per-group count.
    count_alias: Ident,
    /// Child plan whose output rows are filtered and aggregated.
    input: Box<DynProofPlan>,
    /// Filter expression; its (boolean) result is the row selection passed to
    /// `aggregate_columns`.
    where_clause: DynProofExpr,
}
impl AggregateExec {
    /// Builds an `AggregateExec`, or returns `None` if its grouping shape is unsupported.
    ///
    /// The shape is supported when there are no grouping expressions, or exactly one whose
    /// type is neither `VarChar` nor `VarBinary`. This is exactly the set of inputs for which
    /// [`Self::try_get_is_uniqueness_provable`] returns `Some`.
    pub fn try_new(
        group_by_exprs: Vec<AliasedDynProofExpr>,
        sum_expr: Vec<AliasedDynProofExpr>,
        count_alias: Ident,
        input: Box<DynProofPlan>,
        where_clause: DynProofExpr,
    ) -> Option<Self> {
        let candidate = Self {
            group_by_exprs,
            sum_expr,
            count_alias,
            input,
            where_clause,
        };
        let is_supported = candidate.try_get_is_uniqueness_provable().is_some();
        is_supported.then_some(candidate)
    }

    /// Child plan whose output rows are filtered and aggregated.
    pub fn input(&self) -> &DynProofPlan {
        &*self.input
    }

    /// Filter expression whose result selects the rows that take part in the aggregation.
    pub fn where_clause(&self) -> &DynProofExpr {
        &self.where_clause
    }

    /// Grouping expressions together with their output aliases.
    pub fn group_by_exprs(&self) -> &[AliasedDynProofExpr] {
        self.group_by_exprs.as_slice()
    }

    /// Summed expressions together with their output aliases.
    pub fn sum_expr(&self) -> &[AliasedDynProofExpr] {
        self.sum_expr.as_slice()
    }

    /// Alias of the trailing `BigInt` count column.
    pub fn count_alias(&self) -> &Ident {
        &self.count_alias
    }

    /// Classifies the grouping shape for proving purposes.
    ///
    /// - `Some(false)`: no grouping expressions, so no uniqueness check is needed.
    /// - `Some(true)`: a single grouping expression of a type other than `VarChar` /
    ///   `VarBinary`, whose output uniqueness is proven with the monotonic gadget.
    /// - `None`: any other shape (several grouping expressions, or one variable-length
    ///   one), which is not supported.
    pub fn try_get_is_uniqueness_provable(&self) -> Option<bool> {
        match self.group_by_exprs.as_slice() {
            [] => Some(false),
            [single] => {
                let is_variable_length = matches!(
                    single.expr.data_type(),
                    ColumnType::VarChar | ColumnType::VarBinary
                );
                (!is_variable_length).then_some(true)
            }
            _ => None,
        }
    }
}
impl ProofPlan for AggregateExec {
fn verifier_evaluate<S: Scalar>(
&self,
builder: &mut impl VerificationBuilder<S>,
accessor: &IndexMap<TableRef, IndexMap<Ident, S>>,
chi_eval_map: &IndexMap<TableRef, (S, usize)>,
params: &[LiteralValue],
) -> Result<TableEvaluation<S>, ProofError> {
let alpha = builder.try_consume_post_result_challenge()?;
let beta = builder.try_consume_post_result_challenge()?;
let input_eval = self
.input
.verifier_evaluate(builder, accessor, chi_eval_map, params)?;
let input_chi_eval = input_eval.chi_eval();
let input_schema = self.input.get_column_result_fields();
let accessor = input_schema
.iter()
.zip(input_eval.column_evals())
.map(|(field, eval)| (field.name().clone(), *eval))
.collect::<IndexMap<_, _>>();
let fold_gadget = FoldLogExpr::new(alpha, beta);
let group_by_evals = self
.group_by_exprs
.iter()
.map(|aliased_expr| {
aliased_expr
.expr
.verifier_evaluate(builder, &accessor, input_chi_eval, params)
})
.collect::<Result<Vec<_>, _>>()?;
let g_in_star_eval = fold_gadget
.verify_evaluate(builder, &group_by_evals, input_chi_eval)?
.0;
let where_eval =
self.where_clause
.verifier_evaluate(builder, &accessor, input_chi_eval, params)?;
let aggregate_evals = self
.sum_expr
.iter()
.map(|aliased_expr| {
aliased_expr
.expr
.verifier_evaluate(builder, &accessor, input_chi_eval, params)
})
.collect::<Result<Vec<_>, _>>()?;
let sum_in_fold_eval = input_chi_eval + beta * fold_vals(beta, &aggregate_evals);
let output_chi_eval = builder.try_consume_chi_evaluation()?;
let group_by_result_columns_evals =
builder.try_consume_first_round_mle_evaluations(self.group_by_exprs.len())?;
let g_out_star_eval = fold_gadget
.verify_evaluate(builder, &group_by_result_columns_evals, output_chi_eval.0)?
.0;
match self.try_get_is_uniqueness_provable() {
Some(true) => {
verify_monotonic::<S, true, true>(
builder,
alpha,
beta,
group_by_result_columns_evals[0],
output_chi_eval.0,
)?;
}
Some(false) => (),
None => {
Err(ProofError::UnsupportedQueryPlan {
error: "AggregateExec with nonzero grouping columns and without provable uniqueness check not supported.",
})?;
}
}
let sum_result_columns_evals =
builder.try_consume_first_round_mle_evaluations(self.sum_expr.len() + 1)?;
let sum_out_fold_eval = fold_vals(beta, &sum_result_columns_evals);
builder.try_produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::ZeroSum,
g_in_star_eval * where_eval * sum_in_fold_eval - g_out_star_eval * sum_out_fold_eval,
3,
)?;
let column_evals = group_by_result_columns_evals
.into_iter()
.chain(sum_result_columns_evals)
.collect::<Vec<_>>();
Ok(TableEvaluation::new(column_evals, output_chi_eval))
}
fn get_column_result_fields(&self) -> Vec<ColumnField> {
self.group_by_exprs
.iter()
.map(|aliased_expr| {
ColumnField::new(aliased_expr.alias.clone(), aliased_expr.expr.data_type())
})
.chain(self.sum_expr.iter().map(|aliased_expr| {
ColumnField::new(aliased_expr.alias.clone(), aliased_expr.expr.data_type())
}))
.chain(iter::once(ColumnField::new(
self.count_alias.clone(),
ColumnType::BigInt,
)))
.collect()
}
fn get_column_references(&self) -> IndexSet<ColumnRef> {
self.input.get_column_references()
}
fn get_table_references(&self) -> IndexSet<TableRef> {
self.input.get_table_references()
}
}
impl ProverEvaluate for AggregateExec {
    /// First proving round for the aggregation.
    ///
    /// Evaluates the input plan, then evaluates the grouping, `WHERE`, and SUM expressions
    /// on its output. These are aggregated with `aggregate_columns`, and the aggregated table
    /// is returned with its columns in this order: grouping columns, sums, the `BigInt`
    /// count.
    ///
    /// Builder side effects, in order:
    /// 1. requests two post-result challenges;
    /// 2. produces the grouping result columns as intermediate MLEs;
    /// 3. records the output length (`count_column.len()`) as a chi evaluation length;
    /// 4. runs the first round of the monotonic gadget when there is exactly one grouping
    ///    column;
    /// 5. produces each SUM result column, then the count column, as intermediate MLEs.
    ///
    /// `verifier_evaluate` reads these MLEs back in the same order: grouping columns, then
    /// sums, then count.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - the `WHERE` expression does not yield a boolean column;
    /// - `aggregate_columns` fails;
    /// - the output table cannot be assembled;
    /// - `try_get_is_uniqueness_provable` returns `None` (unsupported grouping shape).
    #[tracing::instrument(
        name = "AggregateExec::first_round_evaluate",
        level = "debug",
        skip_all
    )]
    fn first_round_evaluate<'a, S: Scalar>(
        &self,
        builder: &mut FirstRoundBuilder<'a, S>,
        alloc: &'a Bump,
        table_map: &IndexMap<TableRef, Table<'a, S>>,
        params: &[LiteralValue],
    ) -> PlaceholderResult<Table<'a, S>> {
        log::log_memory_usage("Start");
        // `alpha` and `beta`; consumed at the top of `final_round_evaluate`.
        builder.request_post_result_challenges(2);
        let input = self
            .input
            .first_round_evaluate(builder, alloc, table_map, params)?;
        // All expressions are evaluated against the input plan's output table.
        let group_by_columns = self
            .group_by_exprs
            .iter()
            .map(|aliased_expr| -> PlaceholderResult<Column<'a, S>> {
                aliased_expr
                    .expr
                    .first_round_evaluate(alloc, &input, params)
            })
            .collect::<PlaceholderResult<Vec<_>>>()?;
        let selection_column: Column<'a, S> = self
            .where_clause
            .first_round_evaluate(alloc, &input, params)?;
        let selection = selection_column
            .as_boolean()
            .expect("selection is not boolean");
        let sum_columns = self
            .sum_expr
            .iter()
            .map(|aliased_expr| -> PlaceholderResult<Column<'a, S>> {
                aliased_expr
                    .expr
                    .first_round_evaluate(alloc, &input, params)
            })
            .collect::<PlaceholderResult<Vec<_>>>()?;
        // The two empty slices are further aggregate inputs this plan does not use
        // (presumably MAX/MIN columns; confirm against `aggregate_columns`).
        let AggregatedColumns {
            group_by_columns: group_by_result_columns,
            sum_columns: sum_result_columns,
            count_column,
            ..
        } = aggregate_columns(alloc, &group_by_columns, &sum_columns, &[], &[], selection)
            .expect("columns should be aggregatable");
        for column in &group_by_result_columns {
            builder.produce_intermediate_mle(*column);
        }
        // One count per output row, so this is the output length.
        builder.produce_chi_evaluation_length(count_column.len());
        // SUM results are exposed as scalar columns, followed by the count column. The
        // iterator is cloned below for the result table and then consumed for the MLEs.
        let sum_result_columns_iter = sum_result_columns
            .iter()
            .map(|col| Column::Scalar(col))
            .chain(iter::once(Column::BigInt(count_column)));
        // Names come from `get_column_result_fields`. They are paired by position with the
        // grouping columns, then the sums, then the count.
        let res = Table::<'a, S>::try_from_iter(
            self.get_column_result_fields()
                .into_iter()
                .map(|field| field.name())
                .zip(
                    group_by_result_columns
                        .iter()
                        .copied()
                        .chain(sum_result_columns_iter.clone()),
                ),
        )
        .expect("Failed to create table from column references");
        // Exactly one grouping column: start the monotonic gadget on its scalar values.
        if self
            .try_get_is_uniqueness_provable()
            .expect("Group by must be provable")
        {
            first_round_evaluate_monotonic(
                builder,
                alloc,
                alloc.alloc_slice_copy(&group_by_result_columns[0].to_scalar()),
            );
        }
        // Produced after the grouping columns: `verifier_evaluate` reads sums + count last.
        for column in sum_result_columns_iter {
            builder.produce_intermediate_mle(column);
        }
        log::log_memory_usage("End");
        Ok(res)
    }

    /// Final proving round for the aggregation.
    ///
    /// Steps:
    /// 1. consumes the `alpha`/`beta` challenges;
    /// 2. re-evaluates the input plan and the expressions;
    /// 3. recomputes the aggregation;
    /// 4. produces a degree-3 zero-sum sumcheck subpolynomial,
    ///    `g_in_star * selection * sum_in_fold - g_out_star * sum_out_fold`.
    ///
    /// In that subpolynomial:
    /// - `g_in_star` / `g_out_star` come from the `FoldLogExpr` gadget over the input
    ///   (`n` rows) and output (`m` rows) grouping columns;
    /// - `sum_in_fold` / `sum_out_fold` combine the SUM inputs and outputs with `beta` via
    ///   `fold_columns`.
    ///
    /// When there is exactly one grouping column, the final round of the monotonic gadget is
    /// also run on it. The output table is built the same way as in `first_round_evaluate`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as `first_round_evaluate`:
    /// - the `WHERE` expression does not yield a boolean column;
    /// - `aggregate_columns` fails;
    /// - the output table cannot be assembled;
    /// - the grouping shape is unsupported.
    #[expect(clippy::too_many_lines)]
    #[tracing::instrument(
        name = "AggregateExec::final_round_evaluate",
        level = "debug",
        skip_all
    )]
    fn final_round_evaluate<'a, S: Scalar>(
        &self,
        builder: &mut FinalRoundBuilder<'a, S>,
        alloc: &'a Bump,
        table_map: &IndexMap<TableRef, Table<'a, S>>,
        params: &[LiteralValue],
    ) -> PlaceholderResult<Table<'a, S>> {
        log::log_memory_usage("Start");
        // Challenges are consumed before the input plan, matching `verifier_evaluate`.
        let alpha = builder.consume_post_result_challenge();
        let beta = builder.consume_post_result_challenge();
        let input = self
            .input
            .final_round_evaluate(builder, alloc, table_map, params)?;
        let n = input.num_rows();
        let group_by_columns = self
            .group_by_exprs
            .iter()
            .map(|aliased_expr| -> PlaceholderResult<Column<'a, S>> {
                aliased_expr
                    .expr
                    .final_round_evaluate(builder, alloc, &input, params)
            })
            .collect::<PlaceholderResult<Vec<_>>>()?;
        // Input-side fold of the grouping columns over the `n` input rows.
        let fold_gadget = FoldLogExpr::new(alpha, beta);
        let g_in_star = fold_gadget
            .final_round_evaluate(builder, alloc, &group_by_columns, n)
            .0;
        let selection_column: Column<'a, S> = self
            .where_clause
            .final_round_evaluate(builder, alloc, &input, params)?;
        let selection = selection_column
            .as_boolean()
            .expect("selection is not boolean");
        let span = span!(
            Level::DEBUG,
            "AggregateExec::final_round_evaluate sum_columns"
        )
        .entered();
        let sum_columns = self
            .sum_expr
            .iter()
            .map(|aliased_expr| -> PlaceholderResult<Column<'a, S>> {
                aliased_expr
                    .expr
                    .final_round_evaluate(builder, alloc, &input, params)
            })
            .collect::<PlaceholderResult<Vec<_>>>()?;
        span.exit();
        let span = span!(
            Level::DEBUG,
            "AggregateExec::final_round_evaluate allocate sum_in_fold"
        )
        .entered();
        // Starts at 1 on every input row; this is the prover-side counterpart of the
        // verifier's `input_chi_eval` term.
        let sum_in_fold = alloc.alloc_slice_fill_copy(n, One::one());
        span.exit();
        // Fold the SUM inputs in with `beta`. The verifier evaluates this as
        // `input_chi_eval + beta * fold_vals(beta, &aggregate_evals)`.
        fold_columns(sum_in_fold, beta, beta, &sum_columns);
        // Same aggregation as in `first_round_evaluate`. The two empty slices are unused
        // aggregate inputs.
        let AggregatedColumns {
            group_by_columns: group_by_result_columns,
            sum_columns: sum_result_columns,
            count_column,
            ..
        } = aggregate_columns(alloc, &group_by_columns, &sum_columns, &[], &[], selection)
            .expect("columns should be aggregatable");
        // Output length `m`: one count per output row.
        let m = count_column.len();
        // Output-side fold of the grouping result columns over the `m` output rows.
        let g_out_star = fold_gadget
            .final_round_evaluate(builder, alloc, &group_by_result_columns, m)
            .0;
        if self
            .try_get_is_uniqueness_provable()
            .expect("Group by must be provable")
        {
            // Exactly one grouping column: finish the monotonic gadget on its output values.
            // This uses the same `<S, true, true>` parameters as `verify_monotonic` in
            // `verifier_evaluate`.
            let g_out_scalars = group_by_result_columns[0].to_scalar();
            let alloc_g_out_scalars = alloc.alloc_slice_copy(&g_out_scalars);
            final_round_evaluate_monotonic::<S, true, true>(
                builder,
                alloc,
                alpha,
                beta,
                alloc_g_out_scalars,
            );
        }
        // Result table: grouping columns, SUM results as scalar columns, then the count.
        let sum_result_columns_iter = sum_result_columns.iter().map(|col| Column::Scalar(col));
        let columns = group_by_result_columns
            .clone()
            .into_iter()
            .chain(sum_result_columns_iter.clone())
            .chain(iter::once(Column::BigInt(count_column)));
        let res = Table::<'a, S>::try_from_iter(
            self.get_column_result_fields()
                .into_iter()
                .map(|field| field.name())
                .zip(columns.clone()),
        )
        .expect("Failed to create table from column references");
        // Output-side payload:
        // - start from the count column cast into `S`;
        // - fold the SUM results in with `beta`.
        // The verifier evaluates this as `fold_vals(beta, &sum_result_columns_evals)` over
        // [sums..., count].
        let sum_out_fold = alloc.alloc_slice_fill_default(m);
        slice_ops::slice_cast_mut(count_column, sum_out_fold);
        fold_columns(sum_out_fold, beta, beta, &sum_result_columns);
        // The subpolynomial must sum to zero. That is, the sum over input rows of
        // g_in_star * selection * sum_in_fold equals the sum over output rows of
        // g_out_star * sum_out_fold.
        builder.produce_sumcheck_subpolynomial(
            SumcheckSubpolynomialType::ZeroSum,
            vec![
                (
                    S::one(),
                    vec![
                        Box::new(g_in_star as &[_]),
                        Box::new(selection),
                        Box::new(sum_in_fold as &[_]),
                    ],
                ),
                (
                    -S::one(),
                    vec![Box::new(g_out_star as &[_]), Box::new(sum_out_fold as &[_])],
                ),
            ],
        );
        log::log_memory_usage("End");
        Ok(res)
    }
}