use crate::{
db::{
DbSession, MissingRowPolicy, PersistedRow, Query, QueryError,
executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
query::plan::{AggregateKind, FieldSlot},
session::sql::explain::resolve_sql_aggregate_target_slot,
session::sql::{SqlParsedStatement, SqlStatementRoute},
sql::lowering::{
SqlGlobalAggregateTerminal, compile_sql_global_aggregate_command,
is_sql_global_aggregate_statement,
},
sql::parser::SqlStatement,
},
traits::{CanisterKind, EntityValue},
types::Id,
value::Value,
};
/// SQL entry points that refuse global aggregate `SELECT` statements.
///
/// Each variant selects the surface-specific wording returned by
/// `unsupported_sql_aggregate_lane_message`, so the rejection names the
/// API the caller actually used.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(in crate::db::session::sql) enum SqlAggregateSurface {
    /// The `query_from_sql` surface.
    QueryFrom,
    /// The `execute_sql` surface.
    ExecuteSql,
    /// The `execute_sql_grouped` surface.
    ExecuteSqlGrouped,
    /// The `execute_sql_dispatch` surface.
    ExecuteSqlDispatch,
    /// The generated SQL query surface.
    GeneratedQuerySurface,
}
/// Reports whether `parsed` is a global aggregate `SELECT`.
///
/// Such statements belong to the dedicated `execute_sql_aggregate` lane, and
/// other SQL surfaces reject them.
pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
    parsed: &SqlParsedStatement,
) -> bool {
    let statement = &parsed.statement;
    is_sql_global_aggregate_statement(statement)
}
/// Returns the rejection message a non-aggregate SQL `surface` reports when it
/// receives a global aggregate `SELECT`.
///
/// Every message names the rejecting surface and points the caller at
/// `execute_sql_aggregate(...)`.
pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
    surface: SqlAggregateSurface,
) -> &'static str {
    // `concat!` of literals yields a `&'static str` at compile time, which keeps
    // this usable from a `const fn` while sharing one suffix across surfaces.
    macro_rules! rejection {
        ($surface_name:literal) => {
            concat!(
                $surface_name,
                " rejects global aggregate SELECT; use execute_sql_aggregate(...)"
            )
        };
    }

    match surface {
        SqlAggregateSurface::QueryFrom => rejection!("query_from_sql"),
        SqlAggregateSurface::ExecuteSql => rejection!("execute_sql"),
        SqlAggregateSurface::ExecuteSqlGrouped => rejection!("execute_sql_grouped"),
        SqlAggregateSurface::ExecuteSqlDispatch => rejection!("execute_sql_dispatch"),
        SqlAggregateSurface::GeneratedQuerySurface => rejection!("generated SQL query surface"),
    }
}
/// Returns the rejection message `execute_sql_aggregate` reports for a
/// statement whose `route` it does not serve.
///
/// A plain query route is rejected because it is not a constrained global
/// aggregate `SELECT`. Every other route is redirected to
/// `execute_sql_dispatch`.
const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
    // Builds "execute_sql_aggregate rejects <STATEMENT>; use execute_sql_dispatch"
    // as a compile-time literal.
    macro_rules! dispatch_only {
        ($statement:literal) => {
            concat!(
                "execute_sql_aggregate rejects ",
                $statement,
                "; use execute_sql_dispatch"
            )
        };
    }

    match route {
        SqlStatementRoute::Query { .. } => {
            "execute_sql_aggregate requires constrained global aggregate SELECT"
        }
        SqlStatementRoute::Explain { .. } => dispatch_only!("EXPLAIN"),
        SqlStatementRoute::Describe { .. } => dispatch_only!("DESCRIBE"),
        SqlStatementRoute::ShowIndexes { .. } => dispatch_only!("SHOW INDEXES"),
        SqlStatementRoute::ShowColumns { .. } => dispatch_only!("SHOW COLUMNS"),
        SqlStatementRoute::ShowEntities => dispatch_only!("SHOW ENTITIES"),
    }
}
/// Rejection message `execute_sql_aggregate` reports for a `SELECT` that has a
/// non-empty `GROUP BY`. It directs the caller to `execute_sql_grouped(...)`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";
    MESSAGE
}
impl<C: CanisterKind> DbSession<C> {
pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
where
E: PersistedRow<Canister = C> + EntityValue,
{
let parsed = self.parse_sql_statement(sql)?;
match &parsed.statement {
SqlStatement::Select(_) if is_sql_global_aggregate_statement(&parsed.statement) => {}
SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
return Err(QueryError::unsupported_query(
unsupported_sql_aggregate_grouped_message(),
));
}
SqlStatement::Delete(_) => {
return Err(QueryError::unsupported_query(
"execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
));
}
_ => {
return Err(QueryError::unsupported_query(
unsupported_sql_aggregate_surface_lane_message(parsed.route()),
));
}
}
let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
.map_err(QueryError::from_sql_lowering_error)?;
match command.terminal() {
SqlGlobalAggregateTerminal::CountRows => self
.execute_load_query_with(command.query(), |load, plan| {
load.execute_scalar_terminal_request(
plan,
crate::db::executor::ScalarTerminalBoundaryRequest::Count,
)?
.into_count()
})
.map(|count| Value::Uint(u64::from(count))),
SqlGlobalAggregateTerminal::CountField(field) => {
let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
self.execute_load_query_with(command.query(), |load, plan| {
load.execute_scalar_projection_boundary(
plan,
target_slot,
ScalarProjectionBoundaryRequest::Values,
)?
.into_values()
})
.map(|values| {
let count = values
.into_iter()
.filter(|value| !matches!(value, Value::Null))
.count();
Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
})
}
SqlGlobalAggregateTerminal::SumField(field) => {
let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
self.execute_load_query_with(command.query(), |load, plan| {
load.execute_numeric_field_boundary(
plan,
target_slot,
ScalarNumericFieldBoundaryRequest::Sum,
)
})
.map(|value| value.map_or(Value::Null, Value::Decimal))
}
SqlGlobalAggregateTerminal::AvgField(field) => {
let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
self.execute_load_query_with(command.query(), |load, plan| {
load.execute_numeric_field_boundary(
plan,
target_slot,
ScalarNumericFieldBoundaryRequest::Avg,
)
})
.map(|value| value.map_or(Value::Null, Value::Decimal))
}
SqlGlobalAggregateTerminal::MinField(field) => self
.execute_ranked_sql_aggregate_field::<E>(
command.query(),
field,
AggregateKind::Min,
),
SqlGlobalAggregateTerminal::MaxField(field) => self
.execute_ranked_sql_aggregate_field::<E>(
command.query(),
field,
AggregateKind::Max,
),
}
}
fn execute_ranked_sql_aggregate_field<E>(
&self,
query: &Query<E>,
field: &str,
kind: AggregateKind,
) -> Result<Value, QueryError>
where
E: PersistedRow<Canister = C> + EntityValue,
{
let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
let matched_id = self.execute_ranked_sql_aggregate_id(query, target_slot, kind)?;
match matched_id {
Some(id) => self
.load::<E>()
.by_id(id)
.first_value_by(field)
.map(|value| value.unwrap_or(Value::Null)),
None => Ok(Value::Null),
}
}
fn execute_ranked_sql_aggregate_id<E>(
&self,
query: &Query<E>,
target_slot: FieldSlot,
kind: AggregateKind,
) -> Result<Option<Id<E>>, QueryError>
where
E: PersistedRow<Canister = C> + EntityValue,
{
if !kind.is_extrema() {
return Err(QueryError::invariant(
"ranked SQL aggregate id helper only supports MIN/MAX",
));
}
self.execute_load_query_with(query, |load, plan| {
load.execute_scalar_terminal_request(
plan,
crate::db::executor::ScalarTerminalBoundaryRequest::IdBySlot {
kind,
target_field: target_slot,
},
)?
.into_id()
})
}
}