Skip to main content

icydb_core/db/session/sql/
aggregate.rs

1//! Module: db::session::sql::aggregate
2//! Responsibility: session-owned execution and shaping helpers for lowered SQL
3//! scalar aggregate commands.
4//! Does not own: aggregate lowering or aggregate executor route selection.
5//! Boundary: binds lowered SQL aggregate commands onto authority-aware planning and result shaping.
6
7use crate::{
8    db::{
9        DbSession, MissingRowPolicy, PersistedRow, QueryError,
10        executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
11        numeric::{
12            add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
13            compare_numeric_or_strict_order,
14        },
15        session::sql::surface::sql_statement_route_from_statement,
16        session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
17        sql::lowering::{
18            PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
19            SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
20            compile_sql_global_aggregate_command_core_from_prepared,
21            compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
22            prepare_sql_statement,
23        },
24        sql::parser::{SqlStatement, parse_sql},
25    },
26    traits::{CanisterKind, EntityValue},
27    value::Value,
28};
29
30#[derive(Clone, Copy, Debug, Eq, PartialEq)]
31pub(in crate::db::session::sql) enum SqlAggregateSurface {
32    QueryFrom,
33    ExecuteSql,
34    ExecuteSqlGrouped,
35}
36
37pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
38    parsed: &SqlParsedStatement,
39) -> bool {
40    is_sql_global_aggregate_statement(&parsed.statement)
41}
42
43pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
44    surface: SqlAggregateSurface,
45) -> &'static str {
46    match surface {
47        SqlAggregateSurface::QueryFrom => {
48            "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
49        }
50        SqlAggregateSurface::ExecuteSql => {
51            "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
52        }
53        SqlAggregateSurface::ExecuteSqlGrouped => {
54            "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
55        }
56    }
57}
58
59const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
60    match route {
61        SqlStatementRoute::Query { .. } => {
62            "execute_sql_aggregate requires constrained global aggregate SELECT"
63        }
64        SqlStatementRoute::Insert { .. } => {
65            "execute_sql_aggregate rejects INSERT; use execute_sql_dispatch"
66        }
67        SqlStatementRoute::Update { .. } => {
68            "execute_sql_aggregate rejects UPDATE; use execute_sql_dispatch"
69        }
70        SqlStatementRoute::Explain { .. } => {
71            "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
72        }
73        SqlStatementRoute::Describe { .. } => {
74            "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
75        }
76        SqlStatementRoute::ShowIndexes { .. } => {
77            "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
78        }
79        SqlStatementRoute::ShowColumns { .. } => {
80            "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
81        }
82        SqlStatementRoute::ShowEntities => {
83            "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
84        }
85    }
86}
87
/// Return the rejection message `execute_sql_aggregate` reports for grouped
/// `SELECT` statements, pointing callers at `execute_sql_grouped(...)`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    MESSAGE
}
91
92impl<C: CanisterKind> DbSession<C> {
93    // Build the canonical SQL aggregate label projected by the prepared
94    // aggregate strategy so unified dispatch rows stay parser-stable.
95    pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
96        strategy: &PreparedSqlScalarAggregateStrategy,
97    ) -> String {
98        let kind = match strategy.runtime_descriptor() {
99            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
100            | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
101            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
102                kind: crate::db::query::plan::AggregateKind::Sum,
103            } => "SUM",
104            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
105                kind: crate::db::query::plan::AggregateKind::Avg,
106            } => "AVG",
107            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
108                kind: crate::db::query::plan::AggregateKind::Min,
109            } => "MIN",
110            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
111                kind: crate::db::query::plan::AggregateKind::Max,
112            } => "MAX",
113            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
114            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
115                unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
116            }
117        };
118
119        match strategy.projected_field() {
120            Some(field) if strategy.is_distinct() => format!("{kind}(DISTINCT {field})"),
121            Some(field) => format!("{kind}({field})"),
122            None => format!("{kind}(*)"),
123        }
124    }
125
126    // Deduplicate one projected aggregate input stream while preserving the
127    // first-observed value order used by SQL aggregate reduction.
128    fn dedup_structural_sql_aggregate_input_values(values: Vec<Value>) -> Vec<Value> {
129        let mut deduped = Vec::with_capacity(values.len());
130
131        for value in values {
132            if deduped.iter().any(|current| current == &value) {
133                continue;
134            }
135            deduped.push(value);
136        }
137
138        deduped
139    }
140
141    // Reduce one structural aggregate field projection into canonical aggregate
142    // value semantics for the unified SQL dispatch/query surface.
143    fn reduce_structural_sql_aggregate_field_values(
144        values: Vec<Value>,
145        strategy: &PreparedSqlScalarAggregateStrategy,
146    ) -> Result<Value, QueryError> {
147        let values = if strategy.is_distinct() {
148            Self::dedup_structural_sql_aggregate_input_values(values)
149        } else {
150            values
151        };
152
153        match strategy.runtime_descriptor() {
154            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
155                "COUNT(*) structural reduction does not consume projected field values",
156            )),
157            PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
158                let count = values
159                    .into_iter()
160                    .filter(|value| !matches!(value, Value::Null))
161                    .count();
162
163                Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
164            }
165            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
166                kind:
167                    crate::db::query::plan::AggregateKind::Sum
168                    | crate::db::query::plan::AggregateKind::Avg,
169            } => {
170                let mut sum = None;
171                let mut row_count = 0_u64;
172
173                for value in values {
174                    if matches!(value, Value::Null) {
175                        continue;
176                    }
177
178                    let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
179                        QueryError::invariant(
180                            "numeric SQL aggregate dispatch encountered non-numeric projected value",
181                        )
182                    })?;
183                    sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
184                    row_count = row_count.saturating_add(1);
185                }
186
187                match strategy.runtime_descriptor() {
188                    PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
189                        kind: crate::db::query::plan::AggregateKind::Sum,
190                    } => Ok(sum.map_or(Value::Null, Value::Decimal)),
191                    PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
192                        kind: crate::db::query::plan::AggregateKind::Avg,
193                    } => Ok(sum
194                        .and_then(|sum| average_decimal_terms(sum, row_count))
195                        .map_or(Value::Null, Value::Decimal)),
196                    _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
197                }
198            }
199            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
200                kind:
201                    crate::db::query::plan::AggregateKind::Min
202                    | crate::db::query::plan::AggregateKind::Max,
203            } => {
204                let mut selected = None::<Value>;
205
206                for value in values {
207                    if matches!(value, Value::Null) {
208                        continue;
209                    }
210
211                    let replace = match selected.as_ref() {
212                        None => true,
213                        Some(current) => {
214                            let ordering =
215                                compare_numeric_or_strict_order(&value, current).ok_or_else(
216                                    || {
217                                        QueryError::invariant(
218                                            "extrema SQL aggregate dispatch encountered incomparable projected values",
219                                        )
220                                    },
221                                )?;
222
223                            match strategy.runtime_descriptor() {
224                                PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
225                                    kind: crate::db::query::plan::AggregateKind::Min,
226                                } => ordering.is_lt(),
227                                PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
228                                    kind: crate::db::query::plan::AggregateKind::Max,
229                                } => ordering.is_gt(),
230                                _ => unreachable!(
231                                    "extrema SQL aggregate strategy drifted during reduction"
232                                ),
233                            }
234                        }
235                    };
236
237                    if replace {
238                        selected = Some(value);
239                    }
240                }
241
242                Ok(selected.unwrap_or(Value::Null))
243            }
244            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
245            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
246                Err(QueryError::invariant(
247                    "prepared SQL scalar aggregate strategy drifted outside SQL support",
248                ))
249            }
250        }
251    }
252
253    // Project one single-field structural query and return its canonical field
254    // values for aggregate reduction.
255    fn execute_structural_sql_aggregate_field_projection(
256        &self,
257        query: crate::db::query::intent::StructuralQuery,
258        authority: crate::db::executor::EntityAuthority,
259    ) -> Result<Vec<Value>, QueryError> {
260        let (_, rows, _) = self
261            .execute_structural_sql_projection(query, authority)?
262            .into_parts();
263        let mut projected = Vec::with_capacity(rows.len());
264
265        for row in rows {
266            let [value] = row.as_slice() else {
267                return Err(QueryError::invariant(
268                    "structural SQL aggregate projection must emit exactly one field",
269                ));
270            };
271
272            projected.push(value.clone());
273        }
274
275        Ok(projected)
276    }
277
278    // Execute one generic-free prepared SQL aggregate command through the
279    // structural SQL projection path and package the result as one row-shaped
280    // dispatch payload for unified SQL loops.
281    pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
282        &self,
283        command: SqlGlobalAggregateCommandCore,
284        authority: crate::db::executor::EntityAuthority,
285        label_override: Option<String>,
286    ) -> Result<SqlDispatchResult, QueryError> {
287        let model = authority.model();
288        let strategy = command
289            .prepared_scalar_strategy_with_model(model)
290            .map_err(QueryError::from_sql_lowering_error)?;
291        let label = label_override.unwrap_or_else(|| Self::sql_scalar_aggregate_label(&strategy));
292        let value = match strategy.runtime_descriptor() {
293            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
294                let (_, _, row_count) = self
295                    .execute_structural_sql_projection(
296                        command
297                            .query()
298                            .clone()
299                            .select_fields([authority.primary_key_name()]),
300                        authority,
301                    )?
302                    .into_parts();
303
304                Value::Uint(u64::from(row_count))
305            }
306            PreparedSqlScalarAggregateRuntimeDescriptor::CountField
307            | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
308            | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
309                let Some(field) = strategy.projected_field() else {
310                    return Err(QueryError::invariant(
311                        "field-target SQL aggregate strategy requires projected field label",
312                    ));
313                };
314                let values = self.execute_structural_sql_aggregate_field_projection(
315                    command.query().clone().select_fields([field]),
316                    authority,
317                )?;
318
319                Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
320            }
321        };
322
323        Ok(SqlDispatchResult::Projection {
324            columns: vec![label],
325            rows: vec![vec![value]],
326            row_count: 1,
327        })
328    }
329
330    // Compile one already-parsed SQL aggregate statement into the shared
331    // generic-free aggregate command used by unified dispatch/query surfaces.
332    pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
333        parsed: &SqlParsedStatement,
334        authority: crate::db::executor::EntityAuthority,
335    ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
336        compile_sql_global_aggregate_command_core_from_prepared(
337            parsed.prepare(authority.model().name())?,
338            authority.model(),
339            MissingRowPolicy::Ignore,
340        )
341        .map_err(QueryError::from_sql_lowering_error)
342    }
343
344    // Require one resolved target slot from a prepared field-target SQL
345    // aggregate strategy before dispatching into execution families.
346    fn prepared_sql_scalar_target_slot_required(
347        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
348        message: &'static str,
349    ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
350        strategy
351            .target_slot()
352            .cloned()
353            .ok_or_else(|| QueryError::invariant(message))
354    }
355
356    // Execute prepared COUNT(*) through the shared existing-rows scalar
357    // terminal boundary.
358    fn execute_prepared_sql_scalar_count_rows<E>(
359        &self,
360        command: &SqlGlobalAggregateCommand<E>,
361    ) -> Result<Value, QueryError>
362    where
363        E: PersistedRow<Canister = C> + EntityValue,
364    {
365        self.execute_load_query_with(command.query(), |load, plan| {
366            load.execute_scalar_terminal_request(
367                plan,
368                crate::db::executor::ScalarTerminalBoundaryRequest::Count,
369            )?
370            .into_count()
371        })
372        .map(|count| Value::Uint(u64::from(count)))
373    }
374
375    // Execute prepared COUNT(field) through the shared scalar projection
376    // boundary.
377    fn execute_prepared_sql_scalar_count_field<E>(
378        &self,
379        command: &SqlGlobalAggregateCommand<E>,
380        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
381    ) -> Result<Value, QueryError>
382    where
383        E: PersistedRow<Canister = C> + EntityValue,
384    {
385        let target_slot = Self::prepared_sql_scalar_target_slot_required(
386            strategy,
387            "prepared COUNT(field) SQL aggregate strategy requires target slot",
388        )?;
389
390        self.execute_load_query_with(command.query(), |load, plan| {
391            load.execute_scalar_projection_boundary(
392                plan,
393                target_slot.clone(),
394                ScalarProjectionBoundaryRequest::CountNonNull,
395            )?
396            .into_count()
397        })
398        .map(|count| Value::Uint(u64::from(count)))
399    }
400
401    // Execute prepared SUM/AVG(field) through the shared numeric field
402    // boundary.
403    fn execute_prepared_sql_scalar_numeric_field<E>(
404        &self,
405        command: &SqlGlobalAggregateCommand<E>,
406        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
407        request: ScalarNumericFieldBoundaryRequest,
408        message: &'static str,
409    ) -> Result<Value, QueryError>
410    where
411        E: PersistedRow<Canister = C> + EntityValue,
412    {
413        let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
414
415        self.execute_load_query_with(command.query(), |load, plan| {
416            load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
417        })
418        .map(|value| value.map_or(Value::Null, Value::Decimal))
419    }
420
421    // Execute prepared MIN/MAX(field) through the shared extrema-value
422    // boundary.
423    fn execute_prepared_sql_scalar_extrema_field<E>(
424        &self,
425        command: &SqlGlobalAggregateCommand<E>,
426        strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
427        kind: crate::db::query::plan::AggregateKind,
428    ) -> Result<Value, QueryError>
429    where
430        E: PersistedRow<Canister = C> + EntityValue,
431    {
432        let target_slot = Self::prepared_sql_scalar_target_slot_required(
433            strategy,
434            "prepared extrema SQL aggregate strategy requires target slot",
435        )?;
436
437        self.execute_load_query_with(command.query(), |load, plan| {
438            load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
439        })
440        .map(|value| value.unwrap_or(Value::Null))
441    }
442
443    // Execute one prepared typed SQL scalar aggregate strategy through the
444    // existing aggregate boundary families without rediscovering behavior from
445    // raw SQL terminal variants at the session layer.
446    fn execute_prepared_sql_scalar_aggregate<E>(
447        &self,
448        command: &SqlGlobalAggregateCommand<E>,
449    ) -> Result<Value, QueryError>
450    where
451        E: PersistedRow<Canister = C> + EntityValue,
452    {
453        let strategy = command.prepared_scalar_strategy();
454
455        match strategy.runtime_descriptor() {
456            PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
457                self.execute_prepared_sql_scalar_count_rows(command)
458            }
459            PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
460                self.execute_prepared_sql_scalar_count_field(command, &strategy)
461            }
462            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
463                kind: crate::db::query::plan::AggregateKind::Sum,
464            } => self.execute_prepared_sql_scalar_numeric_field(
465                command,
466                &strategy,
467                ScalarNumericFieldBoundaryRequest::Sum,
468                "prepared SUM(field) SQL aggregate strategy requires target slot",
469            ),
470            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
471                kind: crate::db::query::plan::AggregateKind::Avg,
472            } => self.execute_prepared_sql_scalar_numeric_field(
473                command,
474                &strategy,
475                ScalarNumericFieldBoundaryRequest::Avg,
476                "prepared AVG(field) SQL aggregate strategy requires target slot",
477            ),
478            PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
479                self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
480            }
481            PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
482                Err(QueryError::invariant(
483                    "prepared SQL scalar aggregate numeric runtime descriptor drift",
484                ))
485            }
486        }
487    }
488
489    /// Execute one reduced SQL global aggregate `SELECT` statement.
490    ///
491    /// This entrypoint is intentionally constrained to one aggregate terminal
492    /// shape per statement and preserves existing terminal semantics.
493    pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
494    where
495        E: PersistedRow<Canister = C> + EntityValue,
496    {
497        // Parse once into one owned statement so the aggregate lane can keep
498        // its surface checks and lowering on the same statement instance.
499        let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
500
501        // First keep wrong-lane traffic on an explicit aggregate-surface
502        // contract instead of relying on generic lowering failures.
503        match &statement {
504            SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
505            SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
506                return Err(QueryError::unsupported_query(
507                    unsupported_sql_aggregate_grouped_message(),
508                ));
509            }
510            SqlStatement::Delete(_) => {
511                return Err(QueryError::unsupported_query(
512                    "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
513                ));
514            }
515            _ => {
516                let route = sql_statement_route_from_statement(&statement);
517
518                return Err(QueryError::unsupported_query(
519                    unsupported_sql_aggregate_surface_lane_message(&route),
520                ));
521            }
522        }
523
524        // First lower the SQL surface onto the existing single-terminal
525        // aggregate command authority so execution never has to rediscover the
526        // accepted aggregate shape family.
527        let prepared = prepare_sql_statement(statement, E::MODEL.name())
528            .map_err(QueryError::from_sql_lowering_error)?;
529        let command = compile_sql_global_aggregate_command_from_prepared::<E>(
530            prepared.clone(),
531            MissingRowPolicy::Ignore,
532        )
533        .map_err(QueryError::from_sql_lowering_error)?;
534        let strategy = command.prepared_scalar_strategy();
535
536        // DISTINCT field aggregates reuse the existing structural projection +
537        // reduction lane so SQL deduplicates aggregate inputs before folding.
538        if strategy.is_distinct() {
539            let dispatch = compile_sql_global_aggregate_command_core_from_prepared(
540                prepared,
541                E::MODEL,
542                MissingRowPolicy::Ignore,
543            )
544            .map_err(QueryError::from_sql_lowering_error)?;
545            let authority = crate::db::executor::EntityAuthority::for_type::<E>();
546            let SqlDispatchResult::Projection { rows, .. } =
547                self.execute_sql_aggregate_dispatch_for_authority(dispatch, authority, None)?
548            else {
549                return Err(QueryError::invariant(
550                    "DISTINCT SQL aggregate dispatch must finalize as one projection row",
551                ));
552            };
553            let Some(mut row) = rows.into_iter().next() else {
554                return Err(QueryError::invariant(
555                    "DISTINCT SQL aggregate dispatch must emit one projection row",
556                ));
557            };
558            if row.len() != 1 {
559                return Err(QueryError::invariant(
560                    "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
561                ));
562            }
563            let value = row.pop().ok_or_else(|| {
564                QueryError::invariant(
565                    "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
566                )
567            })?;
568
569            return Ok(value);
570        }
571
572        // Then dispatch through one prepared typed-scalar aggregate strategy so
573        // SQL aggregate execution and SQL aggregate explain consume the same
574        // behavioral source instead of matching raw terminal variants twice.
575        self.execute_prepared_sql_scalar_aggregate(&command)
576    }
577}