Skip to main content

icydb_core/db/session/sql/
aggregate.rs

1//! Module: db::session::sql::aggregate
2//! Responsibility: module-local ownership and contracts for db::session::sql::aggregate.
3//! Does not own: cross-module orchestration outside this module.
4//! Boundary: exposes this module API while keeping implementation details internal.
5
6use crate::{
7    db::{
8        DbSession, MissingRowPolicy, PersistedRow, QueryError,
9        executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
10        numeric::{
11            add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
12            compare_numeric_or_strict_order,
13        },
14        session::sql::surface::sql_statement_route_from_statement,
15        session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
16        sql::lowering::{
17            PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
18            SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
19            compile_sql_global_aggregate_command_core_from_prepared,
20            compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
21            prepare_sql_statement,
22        },
23        sql::parser::{SqlStatement, parse_sql},
24    },
25    traits::{CanisterKind, EntityValue},
26    value::Value,
27};
28
29#[derive(Clone, Copy, Debug, Eq, PartialEq)]
30pub(in crate::db::session::sql) enum SqlAggregateSurface {
31    QueryFrom,
32    ExecuteSql,
33    ExecuteSqlGrouped,
34}
35
36pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
37    parsed: &SqlParsedStatement,
38) -> bool {
39    is_sql_global_aggregate_statement(&parsed.statement)
40}
41
42pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
43    surface: SqlAggregateSurface,
44) -> &'static str {
45    match surface {
46        SqlAggregateSurface::QueryFrom => {
47            "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
48        }
49        SqlAggregateSurface::ExecuteSql => {
50            "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
51        }
52        SqlAggregateSurface::ExecuteSqlGrouped => {
53            "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
54        }
55    }
56}
57
58const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
59    match route {
60        SqlStatementRoute::Query { .. } => {
61            "execute_sql_aggregate requires constrained global aggregate SELECT"
62        }
63        SqlStatementRoute::Explain { .. } => {
64            "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
65        }
66        SqlStatementRoute::Describe { .. } => {
67            "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
68        }
69        SqlStatementRoute::ShowIndexes { .. } => {
70            "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
71        }
72        SqlStatementRoute::ShowColumns { .. } => {
73            "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
74        }
75        SqlStatementRoute::ShowEntities => {
76            "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
77        }
78    }
79}
80
// Rejection message `execute_sql_aggregate` reports for a grouped `SELECT`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const GROUPED_REJECTION: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    GROUPED_REJECTION
}
84
85impl<C: CanisterKind> DbSession<C> {
86    // Build the canonical SQL aggregate label projected by the prepared
87    // aggregate strategy so unified dispatch rows stay parser-stable.
88    pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
89        strategy: &PreparedSqlScalarAggregateStrategy,
90    ) -> String {
91        let kind = match strategy.runtime_descriptor() {
92            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
93            | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
94            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
95                kind: crate::db::query::plan::AggregateKind::Sum,
96            } => "SUM",
97            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
98                kind: crate::db::query::plan::AggregateKind::Avg,
99            } => "AVG",
100            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
101                kind: crate::db::query::plan::AggregateKind::Min,
102            } => "MIN",
103            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
104                kind: crate::db::query::plan::AggregateKind::Max,
105            } => "MAX",
106            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
107            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
108                unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
109            }
110        };
111
112        match strategy.projected_field() {
113            Some(field) => format!("{kind}({field})"),
114            None => format!("{kind}(*)"),
115        }
116    }
117
118    // Reduce one structural aggregate field projection into canonical aggregate
119    // value semantics for the unified SQL dispatch/query surface.
120    fn reduce_structural_sql_aggregate_field_values(
121        values: Vec<Value>,
122        strategy: &PreparedSqlScalarAggregateStrategy,
123    ) -> Result<Value, QueryError> {
124        match strategy.runtime_descriptor() {
125            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
126                "COUNT(*) structural reduction does not consume projected field values",
127            )),
128            PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
129                let count = values
130                    .into_iter()
131                    .filter(|value| !matches!(value, Value::Null))
132                    .count();
133
134                Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
135            }
136            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
137                kind:
138                    crate::db::query::plan::AggregateKind::Sum
139                    | crate::db::query::plan::AggregateKind::Avg,
140            } => {
141                let mut sum = None;
142                let mut row_count = 0_u64;
143
144                for value in values {
145                    if matches!(value, Value::Null) {
146                        continue;
147                    }
148
149                    let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
150                        QueryError::invariant(
151                            "numeric SQL aggregate dispatch encountered non-numeric projected value",
152                        )
153                    })?;
154                    sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
155                    row_count = row_count.saturating_add(1);
156                }
157
158                match strategy.runtime_descriptor() {
159                    PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
160                        kind: crate::db::query::plan::AggregateKind::Sum,
161                    } => Ok(sum.map_or(Value::Null, Value::Decimal)),
162                    PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
163                        kind: crate::db::query::plan::AggregateKind::Avg,
164                    } => Ok(sum
165                        .and_then(|sum| average_decimal_terms(sum, row_count))
166                        .map_or(Value::Null, Value::Decimal)),
167                    _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
168                }
169            }
170            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
171                kind:
172                    crate::db::query::plan::AggregateKind::Min
173                    | crate::db::query::plan::AggregateKind::Max,
174            } => {
175                let mut selected = None::<Value>;
176
177                for value in values {
178                    if matches!(value, Value::Null) {
179                        continue;
180                    }
181
182                    let replace = match selected.as_ref() {
183                        None => true,
184                        Some(current) => {
185                            let ordering =
186                                compare_numeric_or_strict_order(&value, current).ok_or_else(
187                                    || {
188                                        QueryError::invariant(
189                                            "extrema SQL aggregate dispatch encountered incomparable projected values",
190                                        )
191                                    },
192                                )?;
193
194                            match strategy.runtime_descriptor() {
195                                PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
196                                    kind: crate::db::query::plan::AggregateKind::Min,
197                                } => ordering.is_lt(),
198                                PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
199                                    kind: crate::db::query::plan::AggregateKind::Max,
200                                } => ordering.is_gt(),
201                                _ => unreachable!(
202                                    "extrema SQL aggregate strategy drifted during reduction"
203                                ),
204                            }
205                        }
206                    };
207
208                    if replace {
209                        selected = Some(value);
210                    }
211                }
212
213                Ok(selected.unwrap_or(Value::Null))
214            }
215            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
216            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
217                Err(QueryError::invariant(
218                    "prepared SQL scalar aggregate strategy drifted outside SQL support",
219                ))
220            }
221        }
222    }
223
224    // Project one single-field structural query and return its canonical field
225    // values for aggregate reduction.
226    fn execute_structural_sql_aggregate_field_projection(
227        &self,
228        query: crate::db::query::intent::StructuralQuery,
229        authority: crate::db::executor::EntityAuthority,
230    ) -> Result<Vec<Value>, QueryError> {
231        let (_, rows, _) = self
232            .execute_structural_sql_projection(query, authority)?
233            .into_parts();
234        let mut projected = Vec::with_capacity(rows.len());
235
236        for row in rows {
237            let [value] = row.as_slice() else {
238                return Err(QueryError::invariant(
239                    "structural SQL aggregate projection must emit exactly one field",
240                ));
241            };
242
243            projected.push(value.clone());
244        }
245
246        Ok(projected)
247    }
248
249    // Execute one generic-free prepared SQL aggregate command through the
250    // structural SQL projection path and package the result as one row-shaped
251    // dispatch payload for unified SQL loops.
252    pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
253        &self,
254        command: SqlGlobalAggregateCommandCore,
255        authority: crate::db::executor::EntityAuthority,
256    ) -> Result<SqlDispatchResult, QueryError> {
257        let model = authority.model();
258        let strategy = command
259            .prepared_scalar_strategy_with_model(model)
260            .map_err(QueryError::from_sql_lowering_error)?;
261        let label = Self::sql_scalar_aggregate_label(&strategy);
262        let value = match strategy.runtime_descriptor() {
263            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
264                let (_, _, row_count) = self
265                    .execute_structural_sql_projection(
266                        command
267                            .query()
268                            .clone()
269                            .select_fields([authority.primary_key_name()]),
270                        authority,
271                    )?
272                    .into_parts();
273
274                Value::Uint(u64::from(row_count))
275            }
276            PreparedSqlScalarAggregateRuntimeDescriptor::CountField
277            | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
278            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
279                let Some(field) = strategy.projected_field() else {
280                    return Err(QueryError::invariant(
281                        "field-target SQL aggregate strategy requires projected field label",
282                    ));
283                };
284                let values = self.execute_structural_sql_aggregate_field_projection(
285                    command.query().clone().select_fields([field]),
286                    authority,
287                )?;
288
289                Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
290            }
291        };
292
293        Ok(SqlDispatchResult::Projection {
294            columns: vec![label],
295            rows: vec![vec![value]],
296            row_count: 1,
297        })
298    }
299
300    // Compile one already-parsed SQL aggregate statement into the shared
301    // generic-free aggregate command used by unified dispatch/query surfaces.
302    pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
303        parsed: &SqlParsedStatement,
304        authority: crate::db::executor::EntityAuthority,
305    ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
306        compile_sql_global_aggregate_command_core_from_prepared(
307            parsed.prepare(authority.model().name())?,
308            authority.model(),
309            MissingRowPolicy::Ignore,
310        )
311        .map_err(QueryError::from_sql_lowering_error)
312    }
313
314    // Require one resolved target slot from a prepared field-target SQL
315    // aggregate strategy before dispatching into execution families.
316    fn prepared_sql_scalar_target_slot_required(
317        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
318        message: &'static str,
319    ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
320        strategy
321            .target_slot()
322            .cloned()
323            .ok_or_else(|| QueryError::invariant(message))
324    }
325
326    // Execute prepared COUNT(*) through the shared existing-rows scalar
327    // terminal boundary.
328    fn execute_prepared_sql_scalar_count_rows<E>(
329        &self,
330        command: &SqlGlobalAggregateCommand<E>,
331    ) -> Result<Value, QueryError>
332    where
333        E: PersistedRow<Canister = C> + EntityValue,
334    {
335        self.execute_load_query_with(command.query(), |load, plan| {
336            load.execute_scalar_terminal_request(
337                plan,
338                crate::db::executor::ScalarTerminalBoundaryRequest::Count,
339            )?
340            .into_count()
341        })
342        .map(|count| Value::Uint(u64::from(count)))
343    }
344
345    // Execute prepared COUNT(field) through the shared scalar projection
346    // boundary.
347    fn execute_prepared_sql_scalar_count_field<E>(
348        &self,
349        command: &SqlGlobalAggregateCommand<E>,
350        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
351    ) -> Result<Value, QueryError>
352    where
353        E: PersistedRow<Canister = C> + EntityValue,
354    {
355        let target_slot = Self::prepared_sql_scalar_target_slot_required(
356            strategy,
357            "prepared COUNT(field) SQL aggregate strategy requires target slot",
358        )?;
359
360        self.execute_load_query_with(command.query(), |load, plan| {
361            load.execute_scalar_projection_boundary(
362                plan,
363                target_slot.clone(),
364                ScalarProjectionBoundaryRequest::CountNonNull,
365            )?
366            .into_count()
367        })
368        .map(|count| Value::Uint(u64::from(count)))
369    }
370
371    // Execute prepared SUM/AVG(field) through the shared numeric field
372    // boundary.
373    fn execute_prepared_sql_scalar_numeric_field<E>(
374        &self,
375        command: &SqlGlobalAggregateCommand<E>,
376        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
377        request: ScalarNumericFieldBoundaryRequest,
378        message: &'static str,
379    ) -> Result<Value, QueryError>
380    where
381        E: PersistedRow<Canister = C> + EntityValue,
382    {
383        let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
384
385        self.execute_load_query_with(command.query(), |load, plan| {
386            load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
387        })
388        .map(|value| value.map_or(Value::Null, Value::Decimal))
389    }
390
391    // Execute prepared MIN/MAX(field) through the shared extrema-value
392    // boundary.
393    fn execute_prepared_sql_scalar_extrema_field<E>(
394        &self,
395        command: &SqlGlobalAggregateCommand<E>,
396        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
397        kind: crate::db::query::plan::AggregateKind,
398    ) -> Result<Value, QueryError>
399    where
400        E: PersistedRow<Canister = C> + EntityValue,
401    {
402        let target_slot = Self::prepared_sql_scalar_target_slot_required(
403            strategy,
404            "prepared extrema SQL aggregate strategy requires target slot",
405        )?;
406
407        self.execute_load_query_with(command.query(), |load, plan| {
408            load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
409        })
410        .map(|value| value.unwrap_or(Value::Null))
411    }
412
413    // Execute one prepared typed SQL scalar aggregate strategy through the
414    // existing aggregate boundary families without rediscovering behavior from
415    // raw SQL terminal variants at the session layer.
416    fn execute_prepared_sql_scalar_aggregate<E>(
417        &self,
418        command: &SqlGlobalAggregateCommand<E>,
419    ) -> Result<Value, QueryError>
420    where
421        E: PersistedRow<Canister = C> + EntityValue,
422    {
423        let strategy = command.prepared_scalar_strategy();
424
425        match strategy.runtime_descriptor() {
426            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
427                self.execute_prepared_sql_scalar_count_rows(command)
428            }
429            PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
430                self.execute_prepared_sql_scalar_count_field(command, &strategy)
431            }
432            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
433                kind: crate::db::query::plan::AggregateKind::Sum,
434            } => self.execute_prepared_sql_scalar_numeric_field(
435                command,
436                &strategy,
437                ScalarNumericFieldBoundaryRequest::Sum,
438                "prepared SUM(field) SQL aggregate strategy requires target slot",
439            ),
440            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
441                kind: crate::db::query::plan::AggregateKind::Avg,
442            } => self.execute_prepared_sql_scalar_numeric_field(
443                command,
444                &strategy,
445                ScalarNumericFieldBoundaryRequest::Avg,
446                "prepared AVG(field) SQL aggregate strategy requires target slot",
447            ),
448            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
449                self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
450            }
451            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
452                Err(QueryError::invariant(
453                    "prepared SQL scalar aggregate numeric runtime descriptor drift",
454                ))
455            }
456        }
457    }
458
459    /// Execute one reduced SQL global aggregate `SELECT` statement.
460    ///
461    /// This entrypoint is intentionally constrained to one aggregate terminal
462    /// shape per statement and preserves existing terminal semantics.
463    pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
464    where
465        E: PersistedRow<Canister = C> + EntityValue,
466    {
467        // Parse once into one owned statement so the aggregate lane can keep
468        // its surface checks and lowering on the same statement instance.
469        let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
470
471        // First keep wrong-lane traffic on an explicit aggregate-surface
472        // contract instead of relying on generic lowering failures.
473        match &statement {
474            SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
475            SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
476                return Err(QueryError::unsupported_query(
477                    unsupported_sql_aggregate_grouped_message(),
478                ));
479            }
480            SqlStatement::Delete(_) => {
481                return Err(QueryError::unsupported_query(
482                    "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
483                ));
484            }
485            _ => {
486                let route = sql_statement_route_from_statement(&statement);
487
488                return Err(QueryError::unsupported_query(
489                    unsupported_sql_aggregate_surface_lane_message(&route),
490                ));
491            }
492        }
493
494        // First lower the SQL surface onto the existing single-terminal
495        // aggregate command authority so execution never has to rediscover the
496        // accepted aggregate shape family.
497        let command = compile_sql_global_aggregate_command_from_prepared::<E>(
498            prepare_sql_statement(statement, E::MODEL.name())
499                .map_err(QueryError::from_sql_lowering_error)?,
500            MissingRowPolicy::Ignore,
501        )
502        .map_err(QueryError::from_sql_lowering_error)?;
503
504        // Then dispatch through one prepared typed-scalar aggregate strategy so
505        // SQL aggregate execution and SQL aggregate explain consume the same
506        // behavioral source instead of matching raw terminal variants twice.
507        self.execute_prepared_sql_scalar_aggregate(&command)
508    }
509}