use crate::{
db::{
executor::{
PreparedExecutionPlan,
pipeline::{
contracts::{CursorEmissionMode, LoadExecutor, ProjectionMaterializationMode},
entrypoints::execute_prepared_scalar_aggregate_kernel_row_sink_for_canister,
runtime::compile_retained_slot_layout_for_mode_with_extra_slots,
},
projection::{
ProjectionEvalError, ScalarProjectionExpr,
eval_canonical_scalar_projection_expr_with_required_value_reader_cow,
},
terminal::{KernelRow, RetainedSlotLayout},
},
numeric::{
add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
compare_numeric_or_strict_order,
},
query::plan::AccessPlannedQuery,
},
error::InternalError,
model::entity::EntityModel,
traits::{EntityKind, EntityValue},
types::Decimal,
value::Value,
};
/// A batch of scalar aggregate terminals that are reduced together.
///
/// Terminals are stored in interned form: expression-backed inputs and filters
/// are kept once in `input_exprs` / `filter_exprs`, and each terminal refers to
/// them by index into those tables.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::db) struct PreparedScalarAggregateTerminalSet {
// Terminals in caller order, with expressions replaced by table indices.
terminals: Vec<InternedPreparedScalarAggregateTerminal>,
// Deduplicated aggregate input expressions, indexed by `InternedScalarAggregateInput::Expr`.
input_exprs: Vec<ScalarProjectionExpr>,
// Deduplicated aggregate FILTER expressions, indexed by each terminal's `filter`.
filter_exprs: Vec<ScalarProjectionExpr>,
}
impl PreparedScalarAggregateTerminalSet {
    /// Build a terminal set from already-prepared terminals.
    ///
    /// Every terminal is converted to its interned form, moving its input and
    /// filter expressions into the set-wide expression tables. Expressions that
    /// compare equal share a single table entry.
    #[must_use]
    pub(in crate::db) fn new(terminals: Vec<PreparedScalarAggregateTerminal>) -> Self {
        let mut input_exprs = Vec::new();
        let mut filter_exprs = Vec::new();

        let mut interned = Vec::with_capacity(terminals.len());
        for terminal in terminals {
            interned.push(terminal.into_interned(&mut input_exprs, &mut filter_exprs));
        }

        Self {
            terminals: interned,
            input_exprs,
            filter_exprs,
        }
    }

    /// Report whether the set contains no terminals at all.
    const fn is_empty(&self) -> bool {
        self.terminals.is_empty()
    }

    /// Compile the retained-slot layout needed to evaluate every terminal.
    ///
    /// The extra slots handed to the layout compiler are collected, in order,
    /// from direct field inputs, then the input expression table, then the
    /// filter expression table.
    ///
    /// # Errors
    ///
    /// Returns a query-executor invariant error when the layout compiler
    /// yields no layout.
    fn retained_slot_layout(
        &self,
        model: &EntityModel,
        plan: &AccessPlannedQuery,
    ) -> Result<RetainedSlotLayout, InternalError> {
        let mut extra_slots = Vec::new();

        // Field-backed inputs contribute their slot directly.
        for input in self.terminals.iter().map(|terminal| &terminal.input) {
            input.extend_referenced_slots(&mut extra_slots);
        }

        // Expression-backed inputs and filters contribute through the shared tables.
        for expr in self.input_exprs.iter().chain(self.filter_exprs.iter()) {
            expr.extend_referenced_slots(&mut extra_slots);
        }

        let Some(layout) = compile_retained_slot_layout_for_mode_with_extra_slots(
            model,
            plan,
            ProjectionMaterializationMode::RetainSlotRows,
            CursorEmissionMode::Suppress,
            extra_slots.as_slice(),
        ) else {
            return Err(InternalError::query_executor_invariant(
                "scalar aggregate terminal execution requires a retained-slot layout",
            ));
        };

        Ok(layout)
    }
}
/// One scalar aggregate terminal before its expressions are interned.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::db) struct PreparedScalarAggregateTerminal {
// Which aggregate is computed.
kind: ScalarAggregateTerminalKind,
// What the aggregate consumes: whole rows, one field slot, or an expression.
input: ScalarAggregateInput,
// Optional per-terminal filter expression; rows failing it are skipped for this terminal.
filter: Option<ScalarProjectionExpr>,
// Whether repeated input values are ingested only once.
distinct: bool,
// Result policy for an empty input, derived from `kind` at construction.
empty_behavior: AggregateEmptyBehavior,
}
impl PreparedScalarAggregateTerminal {
    /// Assemble a terminal from parts the caller has already validated.
    ///
    /// The empty-input behavior is not supplied by the caller; it is derived
    /// from `kind`.
    #[must_use]
    pub(in crate::db) const fn from_validated_parts(
        kind: ScalarAggregateTerminalKind,
        input: ScalarAggregateInput,
        filter: Option<ScalarProjectionExpr>,
        distinct: bool,
    ) -> Self {
        let empty_behavior = kind.empty_behavior();

        Self {
            kind,
            input,
            filter,
            distinct,
            empty_behavior,
        }
    }

    /// Convert this terminal into its interned form.
    ///
    /// An expression input is moved into `input_exprs` and a filter into
    /// `filter_exprs`; the returned terminal keeps only the table indices.
    /// Row and field inputs are carried over unchanged.
    fn into_interned(
        self,
        input_exprs: &mut Vec<ScalarProjectionExpr>,
        filter_exprs: &mut Vec<ScalarProjectionExpr>,
    ) -> InternedPreparedScalarAggregateTerminal {
        let Self {
            kind,
            input,
            filter,
            distinct,
            empty_behavior,
        } = self;

        let input = match input {
            ScalarAggregateInput::Expr(expr) => {
                InternedScalarAggregateInput::Expr(intern_scalar_terminal_expr(input_exprs, expr))
            }
            ScalarAggregateInput::Field { slot, field } => {
                InternedScalarAggregateInput::Field { slot, field }
            }
            ScalarAggregateInput::Rows => InternedScalarAggregateInput::Rows,
        };

        let filter = match filter {
            Some(expr) => Some(intern_scalar_terminal_expr(filter_exprs, expr)),
            None => None,
        };

        InternedPreparedScalarAggregateTerminal {
            kind,
            input,
            filter,
            distinct,
            empty_behavior,
        }
    }
}
/// The aggregate function computed by a scalar aggregate terminal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(in crate::db) enum ScalarAggregateTerminalKind {
// COUNT(*): counts admitted rows; never consumes projected values.
CountRows,
// Counts non-null input values.
CountValues,
// Adds non-null numeric input values as decimals.
Sum,
// Averages non-null numeric input values as decimals.
Avg,
// Keeps the smallest non-null input value.
Min,
// Keeps the largest non-null input value.
Max,
}
impl ScalarAggregateTerminalKind {
    /// Select the result policy used when the aggregate saw no input.
    ///
    /// Counting aggregates report zero; every other aggregate reports null.
    const fn empty_behavior(self) -> AggregateEmptyBehavior {
        // Classify first so the match stays exhaustive over all kinds.
        let is_counting = match self {
            Self::CountRows | Self::CountValues => true,
            Self::Sum | Self::Avg | Self::Min | Self::Max => false,
        };

        if is_counting {
            AggregateEmptyBehavior::Zero
        } else {
            AggregateEmptyBehavior::Null
        }
    }
}
/// What a scalar aggregate terminal consumes, before expression interning.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(in crate::db) enum ScalarAggregateInput {
// The row itself; no value is projected.
Rows,
// One retained slot read directly; `field` is carried for the missing-value error.
Field { slot: usize, field: String },
// A projection expression evaluated against the row.
Expr(ScalarProjectionExpr),
}
/// A prepared terminal whose expressions have been replaced by table indices.
#[derive(Clone, Debug, Eq, PartialEq)]
struct InternedPreparedScalarAggregateTerminal {
// Which aggregate is computed.
kind: ScalarAggregateTerminalKind,
// Input source; expression inputs are an index into the input expression table.
input: InternedScalarAggregateInput,
// Index into the filter expression table, when the terminal has a filter.
filter: Option<usize>,
// Whether repeated input values are ingested only once.
distinct: bool,
// Result policy for an empty input.
empty_behavior: AggregateEmptyBehavior,
}
/// Interned counterpart of `ScalarAggregateInput`.
#[derive(Clone, Debug, Eq, PartialEq)]
enum InternedScalarAggregateInput {
// The row itself; no value is projected.
Rows,
// One retained slot read directly; `field` is carried for the missing-value error.
Field { slot: usize, field: String },
// Index of the expression in the set's input expression table.
Expr(usize),
}
impl InternedScalarAggregateInput {
    /// Append the slot read by a field input to `slots`, skipping duplicates.
    ///
    /// Row inputs and expression inputs add nothing here; only the `Field`
    /// variant carries a slot of its own.
    fn extend_referenced_slots(&self, slots: &mut Vec<usize>) {
        let referenced = match self {
            Self::Field { slot, .. } => *slot,
            Self::Rows | Self::Expr(_) => return,
        };

        let already_listed = slots.iter().any(|known| *known == referenced);
        if !already_listed {
            slots.push(referenced);
        }
    }
}
/// The value an aggregate reports when it has nothing to report.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum AggregateEmptyBehavior {
// Report an unsigned zero.
Zero,
// Report null.
Null,
}
/// Running state for one scalar aggregate terminal.
struct ScalarAggregateReducerState {
// The terminal being reduced.
terminal: InternedPreparedScalarAggregateTerminal,
// Values already admitted for a DISTINCT terminal; stays empty otherwise.
distinct_values: Vec<Value>,
// Number of rows or values ingested so far (saturating).
count: u64,
// Running decimal total for SUM/AVG; `None` until the first numeric value.
sum: Option<Decimal>,
// Current extremum for MIN/MAX; `None` until the first value.
selected: Option<Value>,
}
impl ScalarAggregateReducerState {
    /// Create an empty reducer for one interned terminal.
    const fn new(terminal: InternedPreparedScalarAggregateTerminal) -> Self {
        Self {
            terminal,
            distinct_values: Vec::new(),
            count: 0,
            sum: None,
            selected: None,
        }
    }

    /// Decide whether `value` should be ingested under the DISTINCT policy.
    ///
    /// Non-DISTINCT terminals admit everything. DISTINCT terminals admit a
    /// value the first time it is seen and remember it; later equal values are
    /// rejected.
    fn admit_distinct_value(&mut self, value: &Value) -> bool {
        if !self.terminal.distinct {
            return true;
        }
        if self.distinct_values.contains(value) {
            return false;
        }
        self.distinct_values.push(value.clone());
        true
    }

    /// Count one admitted row for a row-counting terminal.
    ///
    /// # Errors
    ///
    /// Returns an invariant error when the terminal is marked DISTINCT.
    fn ingest_row(&mut self) -> Result<(), InternalError> {
        if self.terminal.distinct {
            return Err(InternalError::query_executor_invariant(
                "COUNT(*) scalar aggregate terminal cannot be DISTINCT",
            ));
        }
        self.count = self.count.saturating_add(1);
        Ok(())
    }

    /// Fold one projected value into the aggregate.
    ///
    /// Null values are ignored, and so are values rejected by the DISTINCT
    /// policy.
    ///
    /// # Errors
    ///
    /// Returns an invariant error when SUM/AVG receives a value that cannot be
    /// coerced to a decimal, when MIN/MAX receives a value that cannot be
    /// compared with the current extremum, or when the terminal counts rows
    /// and therefore must not receive projected values.
    fn ingest_value(&mut self, value: Value) -> Result<(), InternalError> {
        // Reject NULL before touching the DISTINCT set: NULL never contributes,
        // so it must not be scanned against or stored among the distinct values.
        if matches!(value, Value::Null) || !self.admit_distinct_value(&value) {
            return Ok(());
        }
        match self.terminal.kind {
            ScalarAggregateTerminalKind::CountValues => {
                self.count = self.count.saturating_add(1);
                Ok(())
            }
            ScalarAggregateTerminalKind::Sum | ScalarAggregateTerminalKind::Avg => {
                let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
                    InternalError::query_executor_invariant(format!(
                        "scalar aggregate numeric terminal encountered non-numeric value: {value:?}",
                    ))
                })?;
                self.sum = Some(
                    self.sum
                        .map_or(decimal, |current| add_decimal_terms(current, decimal)),
                );
                // AVG divides by this count at finalization.
                self.count = self.count.saturating_add(1);
                Ok(())
            }
            ScalarAggregateTerminalKind::Min | ScalarAggregateTerminalKind::Max => {
                let replace = match self.selected.as_ref() {
                    // The first value always becomes the extremum.
                    None => true,
                    Some(current) => {
                        let ordering = compare_numeric_or_strict_order(&value, current)
                            .ok_or_else(|| {
                                InternalError::query_executor_invariant(format!(
                                    "scalar aggregate extrema terminal encountered incomparable values: left={value:?} right={current:?}",
                                ))
                            })?;
                        match self.terminal.kind {
                            ScalarAggregateTerminalKind::Min => ordering.is_lt(),
                            ScalarAggregateTerminalKind::Max => ordering.is_gt(),
                            // Unreachable from the enclosing arm; kept so the match
                            // stays exhaustive without a catch-all.
                            ScalarAggregateTerminalKind::CountRows
                            | ScalarAggregateTerminalKind::CountValues
                            | ScalarAggregateTerminalKind::Sum
                            | ScalarAggregateTerminalKind::Avg => {
                                return Err(InternalError::query_executor_invariant(
                                    "scalar aggregate extrema terminal kind mismatch",
                                ));
                            }
                        }
                    }
                };
                if replace {
                    self.selected = Some(value);
                }
                Ok(())
            }
            ScalarAggregateTerminalKind::CountRows => Err(InternalError::query_executor_invariant(
                "COUNT(*) scalar aggregate terminal cannot consume projected values",
            )),
        }
    }

    /// Consume the reducer and produce the aggregate result.
    ///
    /// Counts are reported as unsigned integers. SUM, AVG, MIN and MAX fall
    /// back to the terminal's empty value when nothing was accumulated (or,
    /// for AVG, when the average helper yields no value).
    fn finalize(self) -> Value {
        match self.terminal.kind {
            ScalarAggregateTerminalKind::CountRows | ScalarAggregateTerminalKind::CountValues => {
                Value::Uint(self.count)
            }
            ScalarAggregateTerminalKind::Sum => {
                self.sum.map_or_else(|| self.empty_value(), Value::Decimal)
            }
            ScalarAggregateTerminalKind::Avg => self
                .sum
                .and_then(|sum| average_decimal_terms(sum, self.count))
                .map_or_else(|| self.empty_value(), Value::Decimal),
            ScalarAggregateTerminalKind::Min | ScalarAggregateTerminalKind::Max => {
                // Computed up front because `selected` is moved out of `self` below.
                let empty_value = self.empty_value();
                self.selected.unwrap_or(empty_value)
            }
        }
    }

    /// The value reported for an empty aggregate, per the terminal's policy.
    const fn empty_value(&self) -> Value {
        match self.terminal.empty_behavior {
            AggregateEmptyBehavior::Zero => Value::Uint(0),
            AggregateEmptyBehavior::Null => Value::Null,
        }
    }
}
/// Drives all reducers of one terminal set over a stream of rows.
struct ScalarAggregateReducerRuntime {
// One reducer per terminal, in terminal order.
reducers: Vec<ScalarAggregateReducerState>,
// Shared input expression table, indexed by expression-backed inputs.
input_exprs: Vec<ScalarProjectionExpr>,
// Shared filter expression table, indexed by each terminal's filter.
filter_exprs: Vec<ScalarProjectionExpr>,
// Per-row cache of evaluated input expressions; reset at the start of every row.
input_expr_values: Vec<Option<Value>>,
// Per-row cache of evaluated filter expressions; reset at the start of every row.
filter_expr_values: Vec<Option<Value>>,
}
impl ScalarAggregateReducerRuntime {
fn new(terminals: PreparedScalarAggregateTerminalSet) -> Self {
let reducers = terminals
.terminals
.into_iter()
.map(ScalarAggregateReducerState::new)
.collect();
let input_expr_values = Vec::with_capacity(terminals.input_exprs.len());
let filter_expr_values = Vec::with_capacity(terminals.filter_exprs.len());
Self {
reducers,
input_exprs: terminals.input_exprs,
filter_exprs: terminals.filter_exprs,
input_expr_values,
filter_expr_values,
}
}
fn ingest_row(&mut self, row: &KernelRow) -> Result<(), InternalError> {
reset_scalar_terminal_expr_values(&mut self.input_expr_values, self.input_exprs.len());
reset_scalar_terminal_expr_values(&mut self.filter_expr_values, self.filter_exprs.len());
for reducer in &mut self.reducers {
if !terminal_filter_matches(
&reducer.terminal,
self.filter_exprs.as_slice(),
row,
&mut self.filter_expr_values,
)? {
continue;
}
match &reducer.terminal.input {
InternedScalarAggregateInput::Rows => reducer.ingest_row()?,
InternedScalarAggregateInput::Field { slot, field } => {
let value = row.slot_ref(*slot).cloned().ok_or_else(|| {
ProjectionEvalError::MissingFieldValue {
field: field.clone(),
index: *slot,
}
.into_invalid_logical_plan_internal_error()
})?;
reducer.ingest_value(value)?;
}
InternedScalarAggregateInput::Expr(expr_index) => {
let value = cached_scalar_terminal_expr_value(
self.input_exprs.as_slice(),
row,
&mut self.input_expr_values,
*expr_index,
"input",
)?
.clone();
reducer.ingest_value(value)?;
}
}
}
Ok(())
}
fn finalize(self) -> Vec<Value> {
self.reducers
.into_iter()
.map(ScalarAggregateReducerState::finalize)
.collect()
}
}
impl<E> LoadExecutor<E>
where
    E: EntityKind + EntityValue,
{
    /// Execute a set of scalar aggregate terminals over one prepared plan.
    ///
    /// Returns one value per terminal, in the order the terminals were given.
    /// An empty terminal set returns an empty vector without executing the
    /// plan.
    ///
    /// # Errors
    ///
    /// Propagates failures from compiling the retained-slot layout, from the
    /// kernel row sink, and from reducing any row.
    pub(in crate::db) fn execute_scalar_aggregate_terminals(
        &self,
        plan: PreparedExecutionPlan<E>,
        terminals: PreparedScalarAggregateTerminalSet,
    ) -> Result<Vec<Value>, InternalError> {
        if terminals.is_empty() {
            return Ok(Vec::new());
        }

        let load_plan = plan.into_prepared_load_plan();
        let layout = terminals
            .retained_slot_layout(load_plan.authority().model(), load_plan.logical_plan())?;

        // The layout only borrows `terminals`, so the set can now move into the runtime.
        let mut runtime = ScalarAggregateReducerRuntime::new(terminals);
        execute_prepared_scalar_aggregate_kernel_row_sink_for_canister(
            &self.db,
            self.debug,
            load_plan,
            layout,
            |row| runtime.ingest_row(row),
        )?;

        let values = runtime.finalize();
        Ok(values)
    }
}
/// Return the table index for `expr`, adding it to `exprs` when no equal
/// expression is already present.
fn intern_scalar_terminal_expr(
    exprs: &mut Vec<ScalarProjectionExpr>,
    expr: ScalarProjectionExpr,
) -> usize {
    match exprs.iter().position(|candidate| *candidate == expr) {
        Some(existing) => existing,
        None => {
            // The new entry lands at the end, so its index is the old length.
            let appended = exprs.len();
            exprs.push(expr);
            appended
        }
    }
}
/// Reset a per-row expression cache to exactly `len` empty entries.
///
/// Existing entries are dropped; the vector's allocation is kept for reuse.
fn reset_scalar_terminal_expr_values(values: &mut Vec<Option<Value>>, len: usize) {
    values.clear();
    values.extend(std::iter::repeat_with(|| None).take(len));
}
/// Return the value of interned expression `index` for the current row,
/// evaluating it only if the row cache does not hold it yet.
///
/// `label` names the expression table ("input" or "filter" at the call sites
/// in this file) and only appears in error messages.
///
/// # Errors
///
/// Returns an invariant error when `index` is outside the expression table or
/// the row cache, and propagates any expression evaluation error.
fn cached_scalar_terminal_expr_value<'a>(
    exprs: &[ScalarProjectionExpr],
    row: &KernelRow,
    values: &'a mut [Option<Value>],
    index: usize,
    label: &str,
) -> Result<&'a Value, InternalError> {
    let Some(expr) = exprs.get(index) else {
        return Err(InternalError::query_executor_invariant(format!(
            "scalar aggregate terminal {label} expression index missing from expression table",
        )));
    };
    let Some(cache_entry) = values.get_mut(index) else {
        return Err(InternalError::query_executor_invariant(format!(
            "scalar aggregate terminal {label} expression index missing from row buffer",
        )));
    };

    // Fill the cache entry on first use for this row.
    if cache_entry.is_none() {
        let evaluated = evaluate_scalar_terminal_expr(expr, row)?;
        *cache_entry = Some(evaluated);
    }

    match cache_entry.as_ref() {
        Some(cached) => Ok(cached),
        None => Err(InternalError::query_executor_invariant(format!(
            "scalar aggregate terminal {label} expression evaluation produced no row value",
        ))),
    }
}
/// Decide whether `row` passes the terminal's filter.
///
/// A terminal without a filter accepts every row. Otherwise the (cached)
/// filter expression value decides: `true` accepts, `false` and null reject.
///
/// # Errors
///
/// Propagates evaluation errors and returns an invariant error when the
/// filter expression yields anything other than a boolean or null.
fn terminal_filter_matches(
    terminal: &InternedPreparedScalarAggregateTerminal,
    filter_exprs: &[ScalarProjectionExpr],
    row: &KernelRow,
    filter_expr_values: &mut [Option<Value>],
) -> Result<bool, InternalError> {
    let filter_index = match terminal.filter {
        Some(index) => index,
        None => return Ok(true),
    };

    let outcome = cached_scalar_terminal_expr_value(
        filter_exprs,
        row,
        filter_expr_values,
        filter_index,
        "filter",
    )?;

    match outcome {
        Value::Bool(accepted) => Ok(*accepted),
        // A null filter result rejects the row rather than raising an error.
        Value::Null => Ok(false),
        found => Err(InternalError::query_executor_invariant(format!(
            "scalar aggregate terminal filter expression produced non-boolean value: {found:?}",
        ))),
    }
}
/// Evaluate one projection expression against a row and return an owned value.
///
/// Slot reads borrow from the row; a slot the row cannot supply is reported
/// as a missing field value named `slot[<index>]`.
///
/// # Errors
///
/// Propagates the missing-slot error and any error from the projection
/// evaluator.
fn evaluate_scalar_terminal_expr(
    expr: &ScalarProjectionExpr,
    row: &KernelRow,
) -> Result<Value, InternalError> {
    let mut read_slot = |slot: usize| match row.slot_ref(slot) {
        Some(stored) => Ok(std::borrow::Cow::Borrowed(stored)),
        None => Err(ProjectionEvalError::MissingFieldValue {
            field: format!("slot[{slot}]"),
            index: slot,
        }
        .into_invalid_logical_plan_internal_error()),
    };

    let evaluated =
        eval_canonical_scalar_projection_expr_with_required_value_reader_cow(expr, &mut read_slot)?;
    Ok(evaluated.into_owned())
}
#[cfg(test)]
mod tests {
    use crate::{db::query::plan::expr::BinaryOp, value::Value};

    use super::*;

    /// Literal unsigned-integer expression.
    fn literal_uint(value: u64) -> ScalarProjectionExpr {
        ScalarProjectionExpr::Literal(Value::Uint(value))
    }

    /// Binary expression over two unsigned-integer literals.
    fn binary_uint_expr(op: BinaryOp, left: u64, right: u64) -> ScalarProjectionExpr {
        ScalarProjectionExpr::Binary {
            op,
            left: Box::new(literal_uint(left)),
            right: Box::new(literal_uint(right)),
        }
    }

    /// Input expression reused by several terminals: `41 + 1`.
    fn repeated_input_expr() -> ScalarProjectionExpr {
        binary_uint_expr(BinaryOp::Add, 41, 1)
    }

    /// Filter expression reused by several terminals: `42 >= 1`.
    fn repeated_filter_expr() -> ScalarProjectionExpr {
        binary_uint_expr(BinaryOp::Gte, 42, 1)
    }

    // Two terminals with equal input and filter expressions must share one
    // table entry each and both point at index 0.
    #[test]
    fn scalar_aggregate_terminal_set_interns_duplicate_input_and_filter_exprs() {
        let shared_input = repeated_input_expr();
        let shared_filter = repeated_filter_expr();

        let sum_terminal = PreparedScalarAggregateTerminal::from_validated_parts(
            ScalarAggregateTerminalKind::Sum,
            ScalarAggregateInput::Expr(shared_input.clone()),
            Some(shared_filter.clone()),
            false,
        );
        let avg_terminal = PreparedScalarAggregateTerminal::from_validated_parts(
            ScalarAggregateTerminalKind::Avg,
            ScalarAggregateInput::Expr(shared_input),
            Some(shared_filter),
            false,
        );

        let terminals = PreparedScalarAggregateTerminalSet::new(vec![sum_terminal, avg_terminal]);

        assert_eq!(
            terminals.input_exprs.len(),
            1,
            "duplicate SUM/AVG input expressions should share one interned input expression",
        );
        assert_eq!(
            terminals.filter_exprs.len(),
            1,
            "duplicate aggregate FILTER expressions should share one interned filter expression",
        );
        for terminal in &terminals.terminals {
            assert!(
                matches!(terminal.input, InternedScalarAggregateInput::Expr(0)),
                "every expression-backed terminal should point at the shared input expression",
            );
        }
        for terminal in &terminals.terminals {
            assert!(
                terminal.filter == Some(0),
                "every filtered terminal should point at the shared filter expression",
            );
        }
    }

    // A field-backed input stays a direct slot read and does not occupy a
    // slot in the input expression table.
    #[test]
    fn scalar_aggregate_terminal_set_keeps_field_inputs_out_of_expr_table() {
        let count_age = PreparedScalarAggregateTerminal::from_validated_parts(
            ScalarAggregateTerminalKind::CountValues,
            ScalarAggregateInput::Field {
                slot: 2,
                field: "age".to_string(),
            },
            None,
            false,
        );
        let sum_expr = PreparedScalarAggregateTerminal::from_validated_parts(
            ScalarAggregateTerminalKind::Sum,
            ScalarAggregateInput::Expr(repeated_input_expr()),
            None,
            false,
        );

        let terminals = PreparedScalarAggregateTerminalSet::new(vec![count_age, sum_expr]);

        assert_eq!(
            terminals.input_exprs.len(),
            1,
            "only expression-backed aggregate inputs should enter the input expression table",
        );

        let field_backed = &terminals.terminals[0];
        assert!(
            matches!(
                field_backed.input,
                InternedScalarAggregateInput::Field { slot: 2, .. }
            ),
            "field-backed aggregate inputs should remain direct retained-slot reads",
        );

        let expr_backed = &terminals.terminals[1];
        assert!(
            matches!(expr_backed.input, InternedScalarAggregateInput::Expr(0)),
            "the expression-backed aggregate input should point at its interned expression",
        );
    }
}