1use crate::{
8 db::{
9 DbSession, MissingRowPolicy, PersistedRow, QueryError,
10 executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
11 numeric::{
12 add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
13 compare_numeric_or_strict_order,
14 },
15 session::sql::surface::sql_statement_route_from_statement,
16 session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
17 sql::lowering::{
18 PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
19 SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
20 compile_sql_global_aggregate_command_core_from_prepared,
21 compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
22 prepare_sql_statement,
23 },
24 sql::parser::{SqlStatement, parse_sql},
25 },
26 traits::{CanisterKind, EntityValue},
27 value::Value,
28};
29
30#[derive(Clone, Copy, Debug, Eq, PartialEq)]
31pub(in crate::db::session::sql) enum SqlAggregateSurface {
32 QueryFrom,
33 ExecuteSql,
34 ExecuteSqlGrouped,
35}
36
37pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
38 parsed: &SqlParsedStatement,
39) -> bool {
40 is_sql_global_aggregate_statement(&parsed.statement)
41}
42
43pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
44 surface: SqlAggregateSurface,
45) -> &'static str {
46 match surface {
47 SqlAggregateSurface::QueryFrom => {
48 "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
49 }
50 SqlAggregateSurface::ExecuteSql => {
51 "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
52 }
53 SqlAggregateSurface::ExecuteSqlGrouped => {
54 "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
55 }
56 }
57}
58
59const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
60 match route {
61 SqlStatementRoute::Query { .. } => {
62 "execute_sql_aggregate requires constrained global aggregate SELECT"
63 }
64 SqlStatementRoute::Explain { .. } => {
65 "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
66 }
67 SqlStatementRoute::Describe { .. } => {
68 "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
69 }
70 SqlStatementRoute::ShowIndexes { .. } => {
71 "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
72 }
73 SqlStatementRoute::ShowColumns { .. } => {
74 "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
75 }
76 SqlStatementRoute::ShowEntities => {
77 "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
78 }
79 }
80}
81
/// Rejection message used when `execute_sql_aggregate` receives a `SELECT`
/// with a non-empty `GROUP BY`, redirecting to `execute_sql_grouped(...)`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    MESSAGE
}
85
86impl<C: CanisterKind> DbSession<C> {
87 pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
90 strategy: &PreparedSqlScalarAggregateStrategy,
91 ) -> String {
92 let kind = match strategy.runtime_descriptor() {
93 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
94 | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
95 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
96 kind: crate::db::query::plan::AggregateKind::Sum,
97 } => "SUM",
98 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
99 kind: crate::db::query::plan::AggregateKind::Avg,
100 } => "AVG",
101 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
102 kind: crate::db::query::plan::AggregateKind::Min,
103 } => "MIN",
104 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
105 kind: crate::db::query::plan::AggregateKind::Max,
106 } => "MAX",
107 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
108 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
109 unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
110 }
111 };
112
113 match strategy.projected_field() {
114 Some(field) => format!("{kind}({field})"),
115 None => format!("{kind}(*)"),
116 }
117 }
118
119 fn reduce_structural_sql_aggregate_field_values(
122 values: Vec<Value>,
123 strategy: &PreparedSqlScalarAggregateStrategy,
124 ) -> Result<Value, QueryError> {
125 match strategy.runtime_descriptor() {
126 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
127 "COUNT(*) structural reduction does not consume projected field values",
128 )),
129 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
130 let count = values
131 .into_iter()
132 .filter(|value| !matches!(value, Value::Null))
133 .count();
134
135 Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
136 }
137 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
138 kind:
139 crate::db::query::plan::AggregateKind::Sum
140 | crate::db::query::plan::AggregateKind::Avg,
141 } => {
142 let mut sum = None;
143 let mut row_count = 0_u64;
144
145 for value in values {
146 if matches!(value, Value::Null) {
147 continue;
148 }
149
150 let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
151 QueryError::invariant(
152 "numeric SQL aggregate dispatch encountered non-numeric projected value",
153 )
154 })?;
155 sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
156 row_count = row_count.saturating_add(1);
157 }
158
159 match strategy.runtime_descriptor() {
160 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
161 kind: crate::db::query::plan::AggregateKind::Sum,
162 } => Ok(sum.map_or(Value::Null, Value::Decimal)),
163 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
164 kind: crate::db::query::plan::AggregateKind::Avg,
165 } => Ok(sum
166 .and_then(|sum| average_decimal_terms(sum, row_count))
167 .map_or(Value::Null, Value::Decimal)),
168 _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
169 }
170 }
171 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
172 kind:
173 crate::db::query::plan::AggregateKind::Min
174 | crate::db::query::plan::AggregateKind::Max,
175 } => {
176 let mut selected = None::<Value>;
177
178 for value in values {
179 if matches!(value, Value::Null) {
180 continue;
181 }
182
183 let replace = match selected.as_ref() {
184 None => true,
185 Some(current) => {
186 let ordering =
187 compare_numeric_or_strict_order(&value, current).ok_or_else(
188 || {
189 QueryError::invariant(
190 "extrema SQL aggregate dispatch encountered incomparable projected values",
191 )
192 },
193 )?;
194
195 match strategy.runtime_descriptor() {
196 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
197 kind: crate::db::query::plan::AggregateKind::Min,
198 } => ordering.is_lt(),
199 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
200 kind: crate::db::query::plan::AggregateKind::Max,
201 } => ordering.is_gt(),
202 _ => unreachable!(
203 "extrema SQL aggregate strategy drifted during reduction"
204 ),
205 }
206 }
207 };
208
209 if replace {
210 selected = Some(value);
211 }
212 }
213
214 Ok(selected.unwrap_or(Value::Null))
215 }
216 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
217 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
218 Err(QueryError::invariant(
219 "prepared SQL scalar aggregate strategy drifted outside SQL support",
220 ))
221 }
222 }
223 }
224
225 fn execute_structural_sql_aggregate_field_projection(
228 &self,
229 query: crate::db::query::intent::StructuralQuery,
230 authority: crate::db::executor::EntityAuthority,
231 ) -> Result<Vec<Value>, QueryError> {
232 let (_, rows, _) = self
233 .execute_structural_sql_projection(query, authority)?
234 .into_parts();
235 let mut projected = Vec::with_capacity(rows.len());
236
237 for row in rows {
238 let [value] = row.as_slice() else {
239 return Err(QueryError::invariant(
240 "structural SQL aggregate projection must emit exactly one field",
241 ));
242 };
243
244 projected.push(value.clone());
245 }
246
247 Ok(projected)
248 }
249
250 pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
254 &self,
255 command: SqlGlobalAggregateCommandCore,
256 authority: crate::db::executor::EntityAuthority,
257 ) -> Result<SqlDispatchResult, QueryError> {
258 let model = authority.model();
259 let strategy = command
260 .prepared_scalar_strategy_with_model(model)
261 .map_err(QueryError::from_sql_lowering_error)?;
262 let label = Self::sql_scalar_aggregate_label(&strategy);
263 let value = match strategy.runtime_descriptor() {
264 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
265 let (_, _, row_count) = self
266 .execute_structural_sql_projection(
267 command
268 .query()
269 .clone()
270 .select_fields([authority.primary_key_name()]),
271 authority,
272 )?
273 .into_parts();
274
275 Value::Uint(u64::from(row_count))
276 }
277 PreparedSqlScalarAggregateRuntimeDescriptor::CountField
278 | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
279 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
280 let Some(field) = strategy.projected_field() else {
281 return Err(QueryError::invariant(
282 "field-target SQL aggregate strategy requires projected field label",
283 ));
284 };
285 let values = self.execute_structural_sql_aggregate_field_projection(
286 command.query().clone().select_fields([field]),
287 authority,
288 )?;
289
290 Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
291 }
292 };
293
294 Ok(SqlDispatchResult::Projection {
295 columns: vec![label],
296 rows: vec![vec![value]],
297 row_count: 1,
298 })
299 }
300
301 pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
304 parsed: &SqlParsedStatement,
305 authority: crate::db::executor::EntityAuthority,
306 ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
307 compile_sql_global_aggregate_command_core_from_prepared(
308 parsed.prepare(authority.model().name())?,
309 authority.model(),
310 MissingRowPolicy::Ignore,
311 )
312 .map_err(QueryError::from_sql_lowering_error)
313 }
314
315 fn prepared_sql_scalar_target_slot_required(
318 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
319 message: &'static str,
320 ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
321 strategy
322 .target_slot()
323 .cloned()
324 .ok_or_else(|| QueryError::invariant(message))
325 }
326
327 fn execute_prepared_sql_scalar_count_rows<E>(
330 &self,
331 command: &SqlGlobalAggregateCommand<E>,
332 ) -> Result<Value, QueryError>
333 where
334 E: PersistedRow<Canister = C> + EntityValue,
335 {
336 self.execute_load_query_with(command.query(), |load, plan| {
337 load.execute_scalar_terminal_request(
338 plan,
339 crate::db::executor::ScalarTerminalBoundaryRequest::Count,
340 )?
341 .into_count()
342 })
343 .map(|count| Value::Uint(u64::from(count)))
344 }
345
346 fn execute_prepared_sql_scalar_count_field<E>(
349 &self,
350 command: &SqlGlobalAggregateCommand<E>,
351 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
352 ) -> Result<Value, QueryError>
353 where
354 E: PersistedRow<Canister = C> + EntityValue,
355 {
356 let target_slot = Self::prepared_sql_scalar_target_slot_required(
357 strategy,
358 "prepared COUNT(field) SQL aggregate strategy requires target slot",
359 )?;
360
361 self.execute_load_query_with(command.query(), |load, plan| {
362 load.execute_scalar_projection_boundary(
363 plan,
364 target_slot.clone(),
365 ScalarProjectionBoundaryRequest::CountNonNull,
366 )?
367 .into_count()
368 })
369 .map(|count| Value::Uint(u64::from(count)))
370 }
371
372 fn execute_prepared_sql_scalar_numeric_field<E>(
375 &self,
376 command: &SqlGlobalAggregateCommand<E>,
377 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
378 request: ScalarNumericFieldBoundaryRequest,
379 message: &'static str,
380 ) -> Result<Value, QueryError>
381 where
382 E: PersistedRow<Canister = C> + EntityValue,
383 {
384 let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
385
386 self.execute_load_query_with(command.query(), |load, plan| {
387 load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
388 })
389 .map(|value| value.map_or(Value::Null, Value::Decimal))
390 }
391
392 fn execute_prepared_sql_scalar_extrema_field<E>(
395 &self,
396 command: &SqlGlobalAggregateCommand<E>,
397 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
398 kind: crate::db::query::plan::AggregateKind,
399 ) -> Result<Value, QueryError>
400 where
401 E: PersistedRow<Canister = C> + EntityValue,
402 {
403 let target_slot = Self::prepared_sql_scalar_target_slot_required(
404 strategy,
405 "prepared extrema SQL aggregate strategy requires target slot",
406 )?;
407
408 self.execute_load_query_with(command.query(), |load, plan| {
409 load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
410 })
411 .map(|value| value.unwrap_or(Value::Null))
412 }
413
414 fn execute_prepared_sql_scalar_aggregate<E>(
418 &self,
419 command: &SqlGlobalAggregateCommand<E>,
420 ) -> Result<Value, QueryError>
421 where
422 E: PersistedRow<Canister = C> + EntityValue,
423 {
424 let strategy = command.prepared_scalar_strategy();
425
426 match strategy.runtime_descriptor() {
427 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
428 self.execute_prepared_sql_scalar_count_rows(command)
429 }
430 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
431 self.execute_prepared_sql_scalar_count_field(command, &strategy)
432 }
433 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
434 kind: crate::db::query::plan::AggregateKind::Sum,
435 } => self.execute_prepared_sql_scalar_numeric_field(
436 command,
437 &strategy,
438 ScalarNumericFieldBoundaryRequest::Sum,
439 "prepared SUM(field) SQL aggregate strategy requires target slot",
440 ),
441 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
442 kind: crate::db::query::plan::AggregateKind::Avg,
443 } => self.execute_prepared_sql_scalar_numeric_field(
444 command,
445 &strategy,
446 ScalarNumericFieldBoundaryRequest::Avg,
447 "prepared AVG(field) SQL aggregate strategy requires target slot",
448 ),
449 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
450 self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
451 }
452 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
453 Err(QueryError::invariant(
454 "prepared SQL scalar aggregate numeric runtime descriptor drift",
455 ))
456 }
457 }
458 }
459
460 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
465 where
466 E: PersistedRow<Canister = C> + EntityValue,
467 {
468 let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
471
472 match &statement {
475 SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
476 SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
477 return Err(QueryError::unsupported_query(
478 unsupported_sql_aggregate_grouped_message(),
479 ));
480 }
481 SqlStatement::Delete(_) => {
482 return Err(QueryError::unsupported_query(
483 "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
484 ));
485 }
486 _ => {
487 let route = sql_statement_route_from_statement(&statement);
488
489 return Err(QueryError::unsupported_query(
490 unsupported_sql_aggregate_surface_lane_message(&route),
491 ));
492 }
493 }
494
495 let command = compile_sql_global_aggregate_command_from_prepared::<E>(
499 prepare_sql_statement(statement, E::MODEL.name())
500 .map_err(QueryError::from_sql_lowering_error)?,
501 MissingRowPolicy::Ignore,
502 )
503 .map_err(QueryError::from_sql_lowering_error)?;
504
505 self.execute_prepared_sql_scalar_aggregate(&command)
509 }
510}