#[cfg(test)]
mod tests;
use crate::{
db::{
access::AccessPathKind,
cursor::{ContinuationRuntime, LoopAction},
data::DataRow,
direction::Direction,
executor::{
AccessScanContinuationInput, AccessStreamBindings, ExecutableAccess, ExecutablePlan,
ExecutionKernel, KeyStreamLoopControl, PreparedAggregatePlan, TraversalRuntime,
aggregate::field::{
AggregateFieldValueError, FieldSlot,
extract_numeric_field_decimal_from_decoded_slot,
extract_numeric_field_decimal_with_slot_reader,
resolve_numeric_aggregate_target_slot_from_planner_slot_with_model,
},
aggregate::{
PreparedAggregateStreamingInputs, PreparedAggregateStreamingInputsCore,
PreparedScalarNumericAggregateStrategy, PreparedScalarNumericBoundary,
PreparedScalarNumericExecutionState, PreparedScalarNumericOp,
PreparedScalarNumericPayload,
},
pipeline::contracts::LoadExecutor,
plan_metrics::record_rows_scanned_for_path,
terminal::{RowDecoder, RowLayout},
},
numeric::{add_decimal_terms, average_decimal_terms},
query::plan::{ExecutionOrderContract, FieldSlot as PlannedFieldSlot},
},
error::InternalError,
traits::{EntityKind, EntityValue},
types::Decimal,
value::Value,
};
use std::cell::RefCell;
/// Which scalar numeric aggregate a caller requests over a single field.
///
/// The `*Distinct` variants are executed through the grouped global-distinct
/// route; the plain variants fold the field's values directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(in crate::db) enum ScalarNumericFieldBoundaryRequest {
    /// `SUM(field)`.
    Sum,
    /// `SUM(DISTINCT field)`.
    SumDistinct,
    /// `AVG(field)`.
    Avg,
    /// `AVG(DISTINCT field)`.
    AvgDistinct,
}
impl ScalarNumericFieldBoundaryRequest {
    /// Maps the request onto the numeric fold it performs, ignoring the
    /// `DISTINCT` qualifier.
    const fn prepared_op(self) -> PreparedScalarNumericOp {
        match self {
            Self::Sum => PreparedScalarNumericOp::Sum,
            Self::SumDistinct => PreparedScalarNumericOp::Sum,
            Self::Avg => PreparedScalarNumericOp::Avg,
            Self::AvgDistinct => PreparedScalarNumericOp::Avg,
        }
    }

    /// Returns `true` for the `DISTINCT` variants, which must be executed via
    /// the grouped global-distinct route rather than a scalar fold.
    const fn requires_global_distinct(self) -> bool {
        match self {
            Self::SumDistinct | Self::AvgDistinct => true,
            Self::Sum | Self::Avg => false,
        }
    }
}
impl<E> LoadExecutor<E>
where
E: EntityKind + EntityValue,
{
/// Executes a scalar numeric aggregate (`SUM`/`AVG`, optionally `DISTINCT`)
/// over `target_field` for `plan`.
///
/// Preparation and execution are two separate phases; a failure in either is
/// returned unchanged.
pub(in crate::db) fn execute_numeric_field_boundary(
    &self,
    plan: ExecutablePlan<E>,
    target_field: PlannedFieldSlot,
    request: ScalarNumericFieldBoundaryRequest,
) -> Result<Option<Decimal>, InternalError> {
    let aggregate_plan = plan.into_prepared_aggregate_plan();
    self.prepare_scalar_numeric_boundary(aggregate_plan, target_field, request)
        .and_then(|state| self.execute_prepared_scalar_numeric_boundary(state))
}
/// Resolves the numeric target slot for `request` and prepares the payload
/// that execution will consume, bundling both into one execution state.
fn prepare_scalar_numeric_boundary(
    &self,
    plan: PreparedAggregatePlan,
    target_field: PlannedFieldSlot,
    request: ScalarNumericFieldBoundaryRequest,
) -> Result<PreparedScalarNumericExecutionState<'_>, InternalError> {
    // Boundary resolution only reads `plan`, so it must happen before `plan`
    // is moved into payload preparation.
    let resolved = Self::resolve_prepared_scalar_numeric_boundary(&plan, &target_field, request)?;
    let payload = self.prepare_scalar_numeric_payload(plan, &resolved, request)?;

    Ok(PreparedScalarNumericExecutionState {
        boundary: resolved,
        payload,
    })
}
/// Runs a prepared scalar numeric boundary.
///
/// The global-distinct payload executes its grouped route and decodes the
/// single output value. The aggregate payload returns `Ok(None)` without
/// scanning when its window is provably empty. Otherwise it folds the field
/// either by streaming keys or from a materialized page, as selected during
/// preparation.
fn execute_prepared_scalar_numeric_boundary(
    &self,
    prepared_state: PreparedScalarNumericExecutionState<'_>,
) -> Result<Option<Decimal>, InternalError> {
    let PreparedScalarNumericExecutionState { boundary, payload } = prepared_state;

    let (strategy, inputs) = match payload {
        PreparedScalarNumericPayload::GlobalDistinct { route } => {
            let grouped = self.execute_prepared_global_distinct_grouped_aggregate(*route)?;
            return decode_global_distinct_numeric_output(grouped, boundary.op);
        }
        PreparedScalarNumericPayload::Aggregate { strategy, prepared } => (strategy, *prepared),
    };

    // Nothing can contribute to the fold, so skip every scan.
    if inputs.window_is_provably_empty() {
        return Ok(None);
    }

    match strategy {
        PreparedScalarNumericAggregateStrategy::Streaming => {
            Self::aggregate_numeric_field_from_streaming(
                inputs.into_core(),
                &boundary.target_field_name,
                boundary.field_slot,
                boundary.op,
            )
        }
        PreparedScalarNumericAggregateStrategy::Materialized => {
            // The layout must be captured before `inputs` is consumed by the
            // page stage.
            let row_layout = RowLayout::from_model(inputs.authority.model());
            let (rows, _) = self
                .execute_scalar_materialized_page_stage(inputs)?
                .into_parts();
            Self::aggregate_numeric_field_from_materialized(
                rows,
                &row_layout,
                &boundary.target_field_name,
                boundary.field_slot,
                boundary.op,
            )
        }
    }
}
fn aggregate_numeric_field_from_materialized(
rows: Vec<DataRow>,
row_layout: &RowLayout,
target_field: &str,
field_slot: FieldSlot,
kind: PreparedScalarNumericOp,
) -> Result<Option<Decimal>, InternalError> {
let mut accumulator = NumericAggregateAccumulator::new();
for (data_key, raw_row) in rows {
let value = RowDecoder::decode_required_slot_value(
row_layout,
data_key.storage_key(),
&raw_row,
field_slot.index,
)?;
let value =
extract_numeric_field_decimal_from_decoded_slot(target_field, field_slot, value)
.map_err(AggregateFieldValueError::into_internal_error)?;
accumulator.add(value);
}
finalize_numeric_field_output(accumulator, kind)
}
/// Decides whether the numeric fold may stream keys instead of materializing
/// a page.
///
/// All of the following must hold:
/// - the plan passes the predicate check;
/// - the plan resolves to a single access path whose kind supports streaming
///   numeric folds;
/// - any page window is safe for that path kind.
fn streaming_numeric_field_aggregate_eligible(
    prepared: &PreparedAggregateStreamingInputs<'_>,
) -> bool {
    if !Self::aggregate_predicate_safe(prepared) {
        return false;
    }

    let resolved = prepared.logical_plan.access.resolve_strategy();
    match resolved.as_path() {
        Some(path) => {
            let kind = path.capabilities().kind();
            Self::aggregate_access_path_safe(kind)
                && Self::aggregate_page_window_safe(prepared, kind)
        }
        // Strategies that are not a single access path never stream.
        None => false,
    }
}
/// Predicate gate for streaming.
///
/// Delegates entirely to `has_no_predicate_or_distinct` on the prepared
/// inputs.
const fn aggregate_predicate_safe(inputs: &PreparedAggregateStreamingInputs<'_>) -> bool {
    inputs.has_no_predicate_or_distinct()
}
/// Access-path gate for streaming.
///
/// Asks the path kind whether it supports a streaming numeric fold at all.
const fn aggregate_access_path_safe(kind: AccessPathKind) -> bool {
    kind.supports_streaming_numeric_fold()
}
/// Page-window gate for streaming.
///
/// Unpaged plans always qualify. A paged plan qualifies only when all of the
/// following hold, checked in this order:
/// - it has an order spec;
/// - that order gives an explicit direction for the model's primary key;
/// - `path_kind` supports streaming numeric folds over a paged primary-key
///   window.
fn aggregate_page_window_safe(
    prepared: &PreparedAggregateStreamingInputs<'_>,
    path_kind: AccessPathKind,
) -> bool {
    let unpaged = prepared.page_spec().is_none();
    unpaged
        || (prepared.order_spec().is_some()
            && prepared
                .explicit_primary_key_order_direction(prepared.authority.model().primary_key.name)
                .is_some()
            && path_kind.supports_streaming_numeric_fold_for_paged_primary_key_window())
}
/// Folds `target_field` over rows produced by driving the plan's ordered key
/// stream directly, without materializing a result page first.
///
/// A `ContinuationRuntime` built from the plan's window cursor contract
/// decides, before each fetch and for each delivered row, whether to skip,
/// emit, or stop. Window handling therefore happens inline while the stream
/// is driven.
///
/// # Errors
/// Propagates errors from:
/// - opening the key stream;
/// - driving it via `drive_field_row_stream`;
/// - extracting a numeric value from a row slot;
/// - `finalize_numeric_field_output`.
fn aggregate_numeric_field_from_streaming(
prepared: PreparedAggregateStreamingInputsCore,
target_field: &str,
field_slot: FieldSlot,
kind: PreparedScalarNumericOp,
) -> Result<Option<Decimal>, InternalError> {
// Read everything that needs `prepared` as a whole before it is
// destructured into its parts below.
let consistency = prepared.consistency();
let direction = Self::aggregate_numeric_stream_direction(&prepared);
let row_layout = RowLayout::from_model(prepared.authority.model());
let PreparedAggregateStreamingInputsCore {
authority,
store,
logical_plan,
execution_preparation,
index_prefix_specs,
index_range_specs,
..
} = prepared;
// Both `pre_key` and `on_key` below mutate the continuation, so it is held
// in a `RefCell` and each closure takes a short-lived mutable borrow.
let continuation = RefCell::new(ContinuationRuntime::from_window(
ExecutionKernel::window_cursor_contract(&logical_plan, None),
));
// Wrap the strict-mode index predicate program, when one exists, with no
// rejected-key counter attached.
let index_predicate_execution = execution_preparation.strict_mode().map(|program| {
crate::db::index::predicate::IndexPredicateExecution {
program,
rejected_keys_counter: None,
}
});
// NOTE(review): the `None` passed to `AccessScanContinuationInput::new` looks
// like "no resume cursor"; the meaning of the separate `None` argument to
// `ExecutableAccess::new` is defined elsewhere. Confirm both.
let access = ExecutableAccess::new(
&logical_plan.access,
AccessStreamBindings::new(
index_prefix_specs.as_slice(),
index_range_specs.as_slice(),
AccessScanContinuationInput::new(None, direction),
),
None,
index_predicate_execution,
);
let runtime = TraversalRuntime::new(store, authority.entity_tag());
let mut key_stream = runtime.ordered_key_stream_from_runtime_access(access)?;
let mut rows_scanned = 0usize;
let mut accumulator = NumericAggregateAccumulator::new();
// Consulted before each key fetch; translates the continuation's
// pre-fetch decision into key-stream loop control.
let mut pre_key =
|| Self::loop_control_from_continuation_action(continuation.borrow_mut().pre_fetch());
let mut on_key = |_data_key,
row: Option<crate::db::executor::terminal::page::KernelRow>|
-> Result<KeyStreamLoopControl, InternalError> {
// A key delivered without a row is passed through as `Emit`. It is not
// counted and not shown to the continuation.
// NOTE(review): presumably a key whose row was absent under
// `consistency`. Confirm.
let Some(row) = row else {
return Ok(KeyStreamLoopControl::Emit);
};
// Counted before the continuation decides. Rows that are skipped, and the
// row that triggers `Stop`, are therefore included in the metric.
rows_scanned = rows_scanned.saturating_add(1);
match continuation.borrow_mut().accept_row() {
LoopAction::Skip => return Ok(KeyStreamLoopControl::Skip),
LoopAction::Emit => {}
LoopAction::Stop => return Ok(KeyStreamLoopControl::Stop),
}
// Only rows the continuation emits contribute to the fold.
let value = extract_numeric_field_decimal_with_slot_reader(
target_field,
field_slot,
&mut |index| row.slot(index),
)
.map_err(AggregateFieldValueError::into_internal_error)?;
accumulator.add(value);
Ok(KeyStreamLoopControl::Emit)
};
Self::drive_field_row_stream(
store,
&row_layout,
consistency,
key_stream.as_mut(),
&mut pre_key,
&mut on_key,
)?;
// Reached only when the stream was driven without error. The `?` above
// returns early otherwise, and in that case no scan metric is recorded.
record_rows_scanned_for_path(authority.entity_path(), rows_scanned);
finalize_numeric_field_output(accumulator, kind)
}
/// Scan direction for the streaming fold.
///
/// Taken as the primary scan direction of the order contract derived from the
/// plan's order spec.
/// NOTE(review): the leading `false` flag is interpreted by
/// `ExecutionOrderContract::from_plan`. Confirm its meaning there.
fn aggregate_numeric_stream_direction(
    prepared: &PreparedAggregateStreamingInputsCore,
) -> Direction {
    let order = prepared.order_spec();
    let contract = ExecutionOrderContract::from_plan(false, order);
    contract.primary_scan_direction()
}
/// Translates a continuation decision into the matching key-stream loop
/// control. The mapping is one-to-one, variant by variant.
const fn loop_control_from_continuation_action(action: LoopAction) -> KeyStreamLoopControl {
    match action {
        LoopAction::Stop => KeyStreamLoopControl::Stop,
        LoopAction::Emit => KeyStreamLoopControl::Emit,
        LoopAction::Skip => KeyStreamLoopControl::Skip,
    }
}
/// Resolves `target_field` against the plan's entity model into a numeric
/// `FieldSlot`.
///
/// The slot is packaged together with the field's name and the fold that
/// `request` maps to.
///
/// # Errors
/// Fails when the planner slot does not resolve to a numeric aggregate
/// target. The `AggregateFieldValueError` is converted into `InternalError`.
fn resolve_prepared_scalar_numeric_boundary(
    plan: &PreparedAggregatePlan,
    target_field: &PlannedFieldSlot,
    request: ScalarNumericFieldBoundaryRequest,
) -> Result<PreparedScalarNumericBoundary, InternalError> {
    let authority = plan.authority();
    let model = authority.model();
    let field_slot =
        resolve_numeric_aggregate_target_slot_from_planner_slot_with_model(model, target_field)
            .map_err(AggregateFieldValueError::into_internal_error)?;

    let target_field_name = target_field.field().to_string();
    let op = request.prepared_op();
    Ok(PreparedScalarNumericBoundary {
        target_field_name,
        field_slot,
        op,
    })
}
/// Builds the execution payload for `request`.
///
/// - `DISTINCT` requests get a boxed global-distinct grouped route.
/// - Other requests get boxed scalar aggregate inputs, tagged with the
///   streaming strategy when eligible and the materialized strategy
///   otherwise.
fn prepare_scalar_numeric_payload(
    &self,
    plan: PreparedAggregatePlan,
    boundary: &PreparedScalarNumericBoundary,
    request: ScalarNumericFieldBoundaryRequest,
) -> Result<PreparedScalarNumericPayload<'_>, InternalError> {
    if !request.requires_global_distinct() {
        let inputs = self.prepare_scalar_aggregate_boundary(plan)?;
        let strategy = if !Self::streaming_numeric_field_aggregate_eligible(&inputs) {
            PreparedScalarNumericAggregateStrategy::Materialized
        } else {
            PreparedScalarNumericAggregateStrategy::Streaming
        };
        return Ok(PreparedScalarNumericPayload::Aggregate {
            strategy,
            prepared: Box::new(inputs),
        });
    }

    let route = self.prepare_global_distinct_grouped_route(
        plan,
        boundary.op.aggregate_kind(),
        &boundary.target_field_name,
    )?;
    Ok(PreparedScalarNumericPayload::GlobalDistinct {
        route: Box::new(route),
    })
}
}
/// Running state for a scalar numeric fold: the decimal sum of every value
/// added so far and how many values were added.
#[derive(Clone, Copy)]
struct NumericAggregateAccumulator {
    // Sum of every value passed to `add`.
    sum: Decimal,
    // Number of values passed to `add`; saturates at `u64::MAX`.
    row_count: u64,
}
impl NumericAggregateAccumulator {
    /// Creates an empty accumulator: zero sum, zero rows.
    const fn new() -> Self {
        Self {
            sum: Decimal::ZERO,
            row_count: 0,
        }
    }

    /// Folds `value` into the running sum and bumps the row count.
    fn add(&mut self, value: Decimal) {
        let Self { sum, row_count } = self;
        *sum = add_numeric_decimal(*sum, value);
        *row_count = row_count.saturating_add(1);
    }
}
fn finalize_numeric_field_output(
accumulator: NumericAggregateAccumulator,
kind: PreparedScalarNumericOp,
) -> Result<Option<Decimal>, InternalError> {
if accumulator.row_count == 0 {
return Ok(None);
}
let output = match kind {
PreparedScalarNumericOp::Sum => accumulator.sum,
PreparedScalarNumericOp::Avg => {
let Some(avg) = average_decimal_terms(accumulator.sum, accumulator.row_count) else {
return Err(kind.avg_divisor_conversion_invariant());
};
avg
}
};
Ok(Some(output))
}
/// Decodes the single value produced by the global-distinct grouped route.
///
/// - A missing value or `Value::Null` becomes `Ok(None)`.
/// - `Value::Decimal` is unwrapped.
/// - Any other value is reported via
///   `op.grouped_distinct_output_type_mismatch`.
fn decode_global_distinct_numeric_output(
    value: Option<Value>,
    op: PreparedScalarNumericOp,
) -> Result<Option<Decimal>, InternalError> {
    let Some(output) = value else {
        return Ok(None);
    };

    match output {
        Value::Decimal(decimal) => Ok(Some(decimal)),
        Value::Null => Ok(None),
        other => Err(op.grouped_distinct_output_type_mismatch(&other)),
    }
}
/// Adds one decimal `term` to a `running` sum by delegating to the shared
/// `add_decimal_terms` numeric helper.
fn add_numeric_decimal(running: Decimal, term: Decimal) -> Decimal {
    add_decimal_terms(running, term)
}