Skip to main content

icydb_core/db/session/sql/
aggregate.rs

1//! Module: db::session::sql::aggregate
2//! Responsibility: session-owned execution and shaping helpers for lowered SQL
3//! scalar aggregate commands.
4//! Does not own: aggregate lowering or aggregate executor route selection.
5//! Boundary: binds lowered SQL aggregate commands onto authority-aware planning and result shaping.
6
7use crate::{
8    db::{
9        DbSession, MissingRowPolicy, PersistedRow, QueryError,
10        executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
11        numeric::{
12            add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
13            compare_numeric_or_strict_order,
14        },
15        session::sql::surface::sql_statement_route_from_statement,
16        session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
17        sql::lowering::{
18            PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
19            SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
20            compile_sql_global_aggregate_command_core_from_prepared,
21            compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
22            prepare_sql_statement,
23        },
24        sql::parser::{SqlStatement, parse_sql},
25    },
26    traits::{CanisterKind, EntityValue},
27    value::Value,
28};
29
// Names the SQL entrypoint that rejected a global aggregate SELECT, so
// `unsupported_sql_aggregate_lane_message` can mention that entrypoint in the
// rejection text.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(in crate::db::session::sql) enum SqlAggregateSurface {
    // Rejection is reported on behalf of `query_from_sql`.
    QueryFrom,
    // Rejection is reported on behalf of `execute_sql`.
    ExecuteSql,
    // Rejection is reported on behalf of `execute_sql_grouped`.
    ExecuteSqlGrouped,
}
36
37pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
38    parsed: &SqlParsedStatement,
39) -> bool {
40    is_sql_global_aggregate_statement(&parsed.statement)
41}
42
43pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
44    surface: SqlAggregateSurface,
45) -> &'static str {
46    match surface {
47        SqlAggregateSurface::QueryFrom => {
48            "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
49        }
50        SqlAggregateSurface::ExecuteSql => {
51            "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
52        }
53        SqlAggregateSurface::ExecuteSqlGrouped => {
54            "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
55        }
56    }
57}
58
59const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
60    match route {
61        SqlStatementRoute::Query { .. } => {
62            "execute_sql_aggregate requires constrained global aggregate SELECT"
63        }
64        SqlStatementRoute::Explain { .. } => {
65            "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
66        }
67        SqlStatementRoute::Describe { .. } => {
68            "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
69        }
70        SqlStatementRoute::ShowIndexes { .. } => {
71            "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
72        }
73        SqlStatementRoute::ShowColumns { .. } => {
74            "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
75        }
76        SqlStatementRoute::ShowEntities => {
77            "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
78        }
79    }
80}
81
// Static rejection message used when `execute_sql_aggregate` receives a
// grouped SELECT; it points the caller at `execute_sql_grouped(...)`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    MESSAGE
}
85
86impl<C: CanisterKind> DbSession<C> {
87    // Build the canonical SQL aggregate label projected by the prepared
88    // aggregate strategy so unified dispatch rows stay parser-stable.
89    pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
90        strategy: &PreparedSqlScalarAggregateStrategy,
91    ) -> String {
92        let kind = match strategy.runtime_descriptor() {
93            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
94            | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
95            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
96                kind: crate::db::query::plan::AggregateKind::Sum,
97            } => "SUM",
98            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
99                kind: crate::db::query::plan::AggregateKind::Avg,
100            } => "AVG",
101            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
102                kind: crate::db::query::plan::AggregateKind::Min,
103            } => "MIN",
104            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
105                kind: crate::db::query::plan::AggregateKind::Max,
106            } => "MAX",
107            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
108            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
109                unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
110            }
111        };
112
113        match strategy.projected_field() {
114            Some(field) if strategy.is_distinct() => format!("{kind}(DISTINCT {field})"),
115            Some(field) => format!("{kind}({field})"),
116            None => format!("{kind}(*)"),
117        }
118    }
119
120    // Deduplicate one projected aggregate input stream while preserving the
121    // first-observed value order used by SQL aggregate reduction.
122    fn dedup_structural_sql_aggregate_input_values(values: Vec<Value>) -> Vec<Value> {
123        let mut deduped = Vec::with_capacity(values.len());
124
125        for value in values {
126            if deduped.iter().any(|current| current == &value) {
127                continue;
128            }
129            deduped.push(value);
130        }
131
132        deduped
133    }
134
135    // Reduce one structural aggregate field projection into canonical aggregate
136    // value semantics for the unified SQL dispatch/query surface.
137    fn reduce_structural_sql_aggregate_field_values(
138        values: Vec<Value>,
139        strategy: &PreparedSqlScalarAggregateStrategy,
140    ) -> Result<Value, QueryError> {
141        let values = if strategy.is_distinct() {
142            Self::dedup_structural_sql_aggregate_input_values(values)
143        } else {
144            values
145        };
146
147        match strategy.runtime_descriptor() {
148            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
149                "COUNT(*) structural reduction does not consume projected field values",
150            )),
151            PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
152                let count = values
153                    .into_iter()
154                    .filter(|value| !matches!(value, Value::Null))
155                    .count();
156
157                Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
158            }
159            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
160                kind:
161                    crate::db::query::plan::AggregateKind::Sum
162                    | crate::db::query::plan::AggregateKind::Avg,
163            } => {
164                let mut sum = None;
165                let mut row_count = 0_u64;
166
167                for value in values {
168                    if matches!(value, Value::Null) {
169                        continue;
170                    }
171
172                    let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
173                        QueryError::invariant(
174                            "numeric SQL aggregate dispatch encountered non-numeric projected value",
175                        )
176                    })?;
177                    sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
178                    row_count = row_count.saturating_add(1);
179                }
180
181                match strategy.runtime_descriptor() {
182                    PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
183                        kind: crate::db::query::plan::AggregateKind::Sum,
184                    } => Ok(sum.map_or(Value::Null, Value::Decimal)),
185                    PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
186                        kind: crate::db::query::plan::AggregateKind::Avg,
187                    } => Ok(sum
188                        .and_then(|sum| average_decimal_terms(sum, row_count))
189                        .map_or(Value::Null, Value::Decimal)),
190                    _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
191                }
192            }
193            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
194                kind:
195                    crate::db::query::plan::AggregateKind::Min
196                    | crate::db::query::plan::AggregateKind::Max,
197            } => {
198                let mut selected = None::<Value>;
199
200                for value in values {
201                    if matches!(value, Value::Null) {
202                        continue;
203                    }
204
205                    let replace = match selected.as_ref() {
206                        None => true,
207                        Some(current) => {
208                            let ordering =
209                                compare_numeric_or_strict_order(&value, current).ok_or_else(
210                                    || {
211                                        QueryError::invariant(
212                                            "extrema SQL aggregate dispatch encountered incomparable projected values",
213                                        )
214                                    },
215                                )?;
216
217                            match strategy.runtime_descriptor() {
218                                PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
219                                    kind: crate::db::query::plan::AggregateKind::Min,
220                                } => ordering.is_lt(),
221                                PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
222                                    kind: crate::db::query::plan::AggregateKind::Max,
223                                } => ordering.is_gt(),
224                                _ => unreachable!(
225                                    "extrema SQL aggregate strategy drifted during reduction"
226                                ),
227                            }
228                        }
229                    };
230
231                    if replace {
232                        selected = Some(value);
233                    }
234                }
235
236                Ok(selected.unwrap_or(Value::Null))
237            }
238            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
239            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
240                Err(QueryError::invariant(
241                    "prepared SQL scalar aggregate strategy drifted outside SQL support",
242                ))
243            }
244        }
245    }
246
247    // Project one single-field structural query and return its canonical field
248    // values for aggregate reduction.
249    fn execute_structural_sql_aggregate_field_projection(
250        &self,
251        query: crate::db::query::intent::StructuralQuery,
252        authority: crate::db::executor::EntityAuthority,
253    ) -> Result<Vec<Value>, QueryError> {
254        let (_, rows, _) = self
255            .execute_structural_sql_projection(query, authority)?
256            .into_parts();
257        let mut projected = Vec::with_capacity(rows.len());
258
259        for row in rows {
260            let [value] = row.as_slice() else {
261                return Err(QueryError::invariant(
262                    "structural SQL aggregate projection must emit exactly one field",
263                ));
264            };
265
266            projected.push(value.clone());
267        }
268
269        Ok(projected)
270    }
271
272    // Execute one generic-free prepared SQL aggregate command through the
273    // structural SQL projection path and package the result as one row-shaped
274    // dispatch payload for unified SQL loops.
275    pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
276        &self,
277        command: SqlGlobalAggregateCommandCore,
278        authority: crate::db::executor::EntityAuthority,
279        label_override: Option<String>,
280    ) -> Result<SqlDispatchResult, QueryError> {
281        let model = authority.model();
282        let strategy = command
283            .prepared_scalar_strategy_with_model(model)
284            .map_err(QueryError::from_sql_lowering_error)?;
285        let label = label_override.unwrap_or_else(|| Self::sql_scalar_aggregate_label(&strategy));
286        let value = match strategy.runtime_descriptor() {
287            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
288                let (_, _, row_count) = self
289                    .execute_structural_sql_projection(
290                        command
291                            .query()
292                            .clone()
293                            .select_fields([authority.primary_key_name()]),
294                        authority,
295                    )?
296                    .into_parts();
297
298                Value::Uint(u64::from(row_count))
299            }
300            PreparedSqlScalarAggregateRuntimeDescriptor::CountField
301            | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
302            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
303                let Some(field) = strategy.projected_field() else {
304                    return Err(QueryError::invariant(
305                        "field-target SQL aggregate strategy requires projected field label",
306                    ));
307                };
308                let values = self.execute_structural_sql_aggregate_field_projection(
309                    command.query().clone().select_fields([field]),
310                    authority,
311                )?;
312
313                Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
314            }
315        };
316
317        Ok(SqlDispatchResult::Projection {
318            columns: vec![label],
319            rows: vec![vec![value]],
320            row_count: 1,
321        })
322    }
323
324    // Compile one already-parsed SQL aggregate statement into the shared
325    // generic-free aggregate command used by unified dispatch/query surfaces.
326    pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
327        parsed: &SqlParsedStatement,
328        authority: crate::db::executor::EntityAuthority,
329    ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
330        compile_sql_global_aggregate_command_core_from_prepared(
331            parsed.prepare(authority.model().name())?,
332            authority.model(),
333            MissingRowPolicy::Ignore,
334        )
335        .map_err(QueryError::from_sql_lowering_error)
336    }
337
338    // Require one resolved target slot from a prepared field-target SQL
339    // aggregate strategy before dispatching into execution families.
340    fn prepared_sql_scalar_target_slot_required(
341        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
342        message: &'static str,
343    ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
344        strategy
345            .target_slot()
346            .cloned()
347            .ok_or_else(|| QueryError::invariant(message))
348    }
349
350    // Execute prepared COUNT(*) through the shared existing-rows scalar
351    // terminal boundary.
352    fn execute_prepared_sql_scalar_count_rows<E>(
353        &self,
354        command: &SqlGlobalAggregateCommand<E>,
355    ) -> Result<Value, QueryError>
356    where
357        E: PersistedRow<Canister = C> + EntityValue,
358    {
359        self.execute_load_query_with(command.query(), |load, plan| {
360            load.execute_scalar_terminal_request(
361                plan,
362                crate::db::executor::ScalarTerminalBoundaryRequest::Count,
363            )?
364            .into_count()
365        })
366        .map(|count| Value::Uint(u64::from(count)))
367    }
368
369    // Execute prepared COUNT(field) through the shared scalar projection
370    // boundary.
371    fn execute_prepared_sql_scalar_count_field<E>(
372        &self,
373        command: &SqlGlobalAggregateCommand<E>,
374        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
375    ) -> Result<Value, QueryError>
376    where
377        E: PersistedRow<Canister = C> + EntityValue,
378    {
379        let target_slot = Self::prepared_sql_scalar_target_slot_required(
380            strategy,
381            "prepared COUNT(field) SQL aggregate strategy requires target slot",
382        )?;
383
384        self.execute_load_query_with(command.query(), |load, plan| {
385            load.execute_scalar_projection_boundary(
386                plan,
387                target_slot.clone(),
388                ScalarProjectionBoundaryRequest::CountNonNull,
389            )?
390            .into_count()
391        })
392        .map(|count| Value::Uint(u64::from(count)))
393    }
394
395    // Execute prepared SUM/AVG(field) through the shared numeric field
396    // boundary.
397    fn execute_prepared_sql_scalar_numeric_field<E>(
398        &self,
399        command: &SqlGlobalAggregateCommand<E>,
400        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
401        request: ScalarNumericFieldBoundaryRequest,
402        message: &'static str,
403    ) -> Result<Value, QueryError>
404    where
405        E: PersistedRow<Canister = C> + EntityValue,
406    {
407        let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
408
409        self.execute_load_query_with(command.query(), |load, plan| {
410            load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
411        })
412        .map(|value| value.map_or(Value::Null, Value::Decimal))
413    }
414
415    // Execute prepared MIN/MAX(field) through the shared extrema-value
416    // boundary.
417    fn execute_prepared_sql_scalar_extrema_field<E>(
418        &self,
419        command: &SqlGlobalAggregateCommand<E>,
420        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
421        kind: crate::db::query::plan::AggregateKind,
422    ) -> Result<Value, QueryError>
423    where
424        E: PersistedRow<Canister = C> + EntityValue,
425    {
426        let target_slot = Self::prepared_sql_scalar_target_slot_required(
427            strategy,
428            "prepared extrema SQL aggregate strategy requires target slot",
429        )?;
430
431        self.execute_load_query_with(command.query(), |load, plan| {
432            load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
433        })
434        .map(|value| value.unwrap_or(Value::Null))
435    }
436
437    // Execute one prepared typed SQL scalar aggregate strategy through the
438    // existing aggregate boundary families without rediscovering behavior from
439    // raw SQL terminal variants at the session layer.
440    fn execute_prepared_sql_scalar_aggregate<E>(
441        &self,
442        command: &SqlGlobalAggregateCommand<E>,
443    ) -> Result<Value, QueryError>
444    where
445        E: PersistedRow<Canister = C> + EntityValue,
446    {
447        let strategy = command.prepared_scalar_strategy();
448
449        match strategy.runtime_descriptor() {
450            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
451                self.execute_prepared_sql_scalar_count_rows(command)
452            }
453            PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
454                self.execute_prepared_sql_scalar_count_field(command, &strategy)
455            }
456            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
457                kind: crate::db::query::plan::AggregateKind::Sum,
458            } => self.execute_prepared_sql_scalar_numeric_field(
459                command,
460                &strategy,
461                ScalarNumericFieldBoundaryRequest::Sum,
462                "prepared SUM(field) SQL aggregate strategy requires target slot",
463            ),
464            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
465                kind: crate::db::query::plan::AggregateKind::Avg,
466            } => self.execute_prepared_sql_scalar_numeric_field(
467                command,
468                &strategy,
469                ScalarNumericFieldBoundaryRequest::Avg,
470                "prepared AVG(field) SQL aggregate strategy requires target slot",
471            ),
472            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
473                self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
474            }
475            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
476                Err(QueryError::invariant(
477                    "prepared SQL scalar aggregate numeric runtime descriptor drift",
478                ))
479            }
480        }
481    }
482
483    /// Execute one reduced SQL global aggregate `SELECT` statement.
484    ///
485    /// This entrypoint is intentionally constrained to one aggregate terminal
486    /// shape per statement and preserves existing terminal semantics.
487    pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
488    where
489        E: PersistedRow<Canister = C> + EntityValue,
490    {
491        // Parse once into one owned statement so the aggregate lane can keep
492        // its surface checks and lowering on the same statement instance.
493        let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
494
495        // First keep wrong-lane traffic on an explicit aggregate-surface
496        // contract instead of relying on generic lowering failures.
497        match &statement {
498            SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
499            SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
500                return Err(QueryError::unsupported_query(
501                    unsupported_sql_aggregate_grouped_message(),
502                ));
503            }
504            SqlStatement::Delete(_) => {
505                return Err(QueryError::unsupported_query(
506                    "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
507                ));
508            }
509            _ => {
510                let route = sql_statement_route_from_statement(&statement);
511
512                return Err(QueryError::unsupported_query(
513                    unsupported_sql_aggregate_surface_lane_message(&route),
514                ));
515            }
516        }
517
518        // First lower the SQL surface onto the existing single-terminal
519        // aggregate command authority so execution never has to rediscover the
520        // accepted aggregate shape family.
521        let prepared = prepare_sql_statement(statement, E::MODEL.name())
522            .map_err(QueryError::from_sql_lowering_error)?;
523        let command = compile_sql_global_aggregate_command_from_prepared::<E>(
524            prepared.clone(),
525            MissingRowPolicy::Ignore,
526        )
527        .map_err(QueryError::from_sql_lowering_error)?;
528        let strategy = command.prepared_scalar_strategy();
529
530        // DISTINCT field aggregates reuse the existing structural projection +
531        // reduction lane so SQL deduplicates aggregate inputs before folding.
532        if strategy.is_distinct() {
533            let dispatch = compile_sql_global_aggregate_command_core_from_prepared(
534                prepared,
535                E::MODEL,
536                MissingRowPolicy::Ignore,
537            )
538            .map_err(QueryError::from_sql_lowering_error)?;
539            let authority = crate::db::executor::EntityAuthority::for_type::<E>();
540            let SqlDispatchResult::Projection { rows, .. } =
541                self.execute_sql_aggregate_dispatch_for_authority(dispatch, authority, None)?
542            else {
543                return Err(QueryError::invariant(
544                    "DISTINCT SQL aggregate dispatch must finalize as one projection row",
545                ));
546            };
547            let Some(mut row) = rows.into_iter().next() else {
548                return Err(QueryError::invariant(
549                    "DISTINCT SQL aggregate dispatch must emit one projection row",
550                ));
551            };
552            if row.len() != 1 {
553                return Err(QueryError::invariant(
554                    "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
555                ));
556            }
557            let value = row.pop().ok_or_else(|| {
558                QueryError::invariant(
559                    "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
560                )
561            })?;
562
563            return Ok(value);
564        }
565
566        // Then dispatch through one prepared typed-scalar aggregate strategy so
567        // SQL aggregate execution and SQL aggregate explain consume the same
568        // behavioral source instead of matching raw terminal variants twice.
569        self.execute_prepared_sql_scalar_aggregate(&command)
570    }
571}