Skip to main content

icydb_core/db/session/sql/
aggregate.rs

1//! Module: db::session::sql::aggregate
2//! Responsibility: module-local ownership and contracts for db::session::sql::aggregate.
3//! Does not own: cross-module orchestration outside this module.
4//! Boundary: exposes this module API while keeping implementation details internal.
5
use crate::{
    db::{
        DbSession, MissingRowPolicy, PersistedRow, QueryError,
        executor::{
            ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest,
            ScalarTerminalBoundaryRequest,
        },
        query::plan::AggregateKind,
        session::sql::explain::resolve_sql_aggregate_target_slot,
        session::sql::surface::sql_statement_route_from_statement,
        session::sql::{SqlParsedStatement, SqlStatementRoute},
        sql::lowering::{
            SqlGlobalAggregateTerminal, compile_sql_global_aggregate_command_from_prepared,
            is_sql_global_aggregate_statement, prepare_sql_statement,
        },
        sql::parser::{SqlStatement, parse_sql},
    },
    traits::{CanisterKind, EntityValue},
    value::Value,
};
23
24#[derive(Clone, Copy, Debug, Eq, PartialEq)]
25pub(in crate::db::session::sql) enum SqlAggregateSurface {
26    QueryFrom,
27    ExecuteSql,
28    ExecuteSqlGrouped,
29    ExecuteSqlDispatch,
30    GeneratedQuerySurface,
31}
32
33pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
34    parsed: &SqlParsedStatement,
35) -> bool {
36    is_sql_global_aggregate_statement(&parsed.statement)
37}
38
39pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
40    surface: SqlAggregateSurface,
41) -> &'static str {
42    match surface {
43        SqlAggregateSurface::QueryFrom => {
44            "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
45        }
46        SqlAggregateSurface::ExecuteSql => {
47            "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
48        }
49        SqlAggregateSurface::ExecuteSqlGrouped => {
50            "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
51        }
52        SqlAggregateSurface::ExecuteSqlDispatch => {
53            "execute_sql_dispatch rejects global aggregate SELECT; use execute_sql_aggregate(...)"
54        }
55        SqlAggregateSurface::GeneratedQuerySurface => {
56            "generated SQL query surface rejects global aggregate SELECT; use execute_sql_aggregate(...)"
57        }
58    }
59}
60
61const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
62    match route {
63        SqlStatementRoute::Query { .. } => {
64            "execute_sql_aggregate requires constrained global aggregate SELECT"
65        }
66        SqlStatementRoute::Explain { .. } => {
67            "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
68        }
69        SqlStatementRoute::Describe { .. } => {
70            "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
71        }
72        SqlStatementRoute::ShowIndexes { .. } => {
73            "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
74        }
75        SqlStatementRoute::ShowColumns { .. } => {
76            "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
77        }
78        SqlStatementRoute::ShowEntities => {
79            "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
80        }
81    }
82}
83
/// Rejection message used when `execute_sql_aggregate` receives a `SELECT`
/// with a non-empty `GROUP BY`; it redirects the caller to
/// `execute_sql_grouped(...)`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    MESSAGE
}
87
88impl<C: CanisterKind> DbSession<C> {
89    /// Execute one reduced SQL global aggregate `SELECT` statement.
90    ///
91    /// This entrypoint is intentionally constrained to one aggregate terminal
92    /// shape per statement and preserves existing terminal semantics.
93    pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
94    where
95        E: PersistedRow<Canister = C> + EntityValue,
96    {
97        // Parse once into one owned statement so the aggregate lane can keep
98        // its surface checks and lowering on the same statement instance.
99        let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
100
101        // First keep wrong-lane traffic on an explicit aggregate-surface
102        // contract instead of relying on generic lowering failures.
103        match &statement {
104            SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
105            SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
106                return Err(QueryError::unsupported_query(
107                    unsupported_sql_aggregate_grouped_message(),
108                ));
109            }
110            SqlStatement::Delete(_) => {
111                return Err(QueryError::unsupported_query(
112                    "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
113                ));
114            }
115            _ => {
116                let route = sql_statement_route_from_statement(&statement);
117
118                return Err(QueryError::unsupported_query(
119                    unsupported_sql_aggregate_surface_lane_message(&route),
120                ));
121            }
122        }
123
124        // First lower the SQL surface onto the existing single-terminal
125        // aggregate command authority so execution never has to rediscover the
126        // accepted aggregate shape family.
127        let command = compile_sql_global_aggregate_command_from_prepared::<E>(
128            prepare_sql_statement(statement, E::MODEL.name())
129                .map_err(QueryError::from_sql_lowering_error)?,
130            MissingRowPolicy::Ignore,
131        )
132        .map_err(QueryError::from_sql_lowering_error)?;
133
134        // Then dispatch each accepted terminal onto the existing load/query
135        // boundaries instead of reopening aggregate execution ownership here.
136        match command.terminal() {
137            SqlGlobalAggregateTerminal::CountRows => self
138                .execute_load_query_with(command.query(), |load, plan| {
139                    load.execute_scalar_terminal_request(
140                        plan,
141                        crate::db::executor::ScalarTerminalBoundaryRequest::Count,
142                    )?
143                    .into_count()
144                })
145                .map(|count| Value::Uint(u64::from(count))),
146            SqlGlobalAggregateTerminal::CountField(field) => {
147                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
148
149                self.execute_load_query_with(command.query(), |load, plan| {
150                    load.execute_scalar_projection_boundary(
151                        plan,
152                        target_slot,
153                        ScalarProjectionBoundaryRequest::CountNonNull,
154                    )?
155                    .into_count()
156                })
157                .map(|count| Value::Uint(u64::from(count)))
158            }
159            SqlGlobalAggregateTerminal::SumField(field) => {
160                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
161
162                self.execute_load_query_with(command.query(), |load, plan| {
163                    load.execute_numeric_field_boundary(
164                        plan,
165                        target_slot,
166                        ScalarNumericFieldBoundaryRequest::Sum,
167                    )
168                })
169                .map(|value| value.map_or(Value::Null, Value::Decimal))
170            }
171            SqlGlobalAggregateTerminal::AvgField(field) => {
172                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
173
174                self.execute_load_query_with(command.query(), |load, plan| {
175                    load.execute_numeric_field_boundary(
176                        plan,
177                        target_slot,
178                        ScalarNumericFieldBoundaryRequest::Avg,
179                    )
180                })
181                .map(|value| value.map_or(Value::Null, Value::Decimal))
182            }
183            SqlGlobalAggregateTerminal::MinField(field) => {
184                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
185
186                self.execute_load_query_with(command.query(), |load, plan| {
187                    load.execute_scalar_extrema_value_boundary(
188                        plan,
189                        target_slot,
190                        AggregateKind::Min,
191                    )
192                })
193                .map(|value| value.unwrap_or(Value::Null))
194            }
195            SqlGlobalAggregateTerminal::MaxField(field) => {
196                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
197
198                self.execute_load_query_with(command.query(), |load, plan| {
199                    load.execute_scalar_extrema_value_boundary(
200                        plan,
201                        target_slot,
202                        AggregateKind::Max,
203                    )
204                })
205                .map(|value| value.unwrap_or(Value::Null))
206            }
207        }
208    }
209}