1use crate::{
8 db::{
9 DbSession, MissingRowPolicy, PersistedRow, QueryError,
10 executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
11 numeric::{
12 add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
13 compare_numeric_or_strict_order,
14 },
15 session::sql::surface::sql_statement_route_from_statement,
16 session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
17 sql::lowering::{
18 PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
19 SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
20 compile_sql_global_aggregate_command_core_from_prepared,
21 compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
22 prepare_sql_statement,
23 },
24 sql::parser::{SqlStatement, parse_sql},
25 },
26 traits::{CanisterKind, EntityValue},
27 value::Value,
28};
29
30#[derive(Clone, Copy, Debug, Eq, PartialEq)]
31pub(in crate::db::session::sql) enum SqlAggregateSurface {
32 QueryFrom,
33 ExecuteSql,
34 ExecuteSqlGrouped,
35}
36
37pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
38 parsed: &SqlParsedStatement,
39) -> bool {
40 is_sql_global_aggregate_statement(&parsed.statement)
41}
42
43pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
44 surface: SqlAggregateSurface,
45) -> &'static str {
46 match surface {
47 SqlAggregateSurface::QueryFrom => {
48 "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
49 }
50 SqlAggregateSurface::ExecuteSql => {
51 "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
52 }
53 SqlAggregateSurface::ExecuteSqlGrouped => {
54 "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
55 }
56 }
57}
58
59const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
60 match route {
61 SqlStatementRoute::Query { .. } => {
62 "execute_sql_aggregate requires constrained global aggregate SELECT"
63 }
64 SqlStatementRoute::Insert { .. } => {
65 "execute_sql_aggregate rejects INSERT; use execute_sql_dispatch"
66 }
67 SqlStatementRoute::Update { .. } => {
68 "execute_sql_aggregate rejects UPDATE; use execute_sql_dispatch"
69 }
70 SqlStatementRoute::Explain { .. } => {
71 "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
72 }
73 SqlStatementRoute::Describe { .. } => {
74 "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
75 }
76 SqlStatementRoute::ShowIndexes { .. } => {
77 "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
78 }
79 SqlStatementRoute::ShowColumns { .. } => {
80 "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
81 }
82 SqlStatementRoute::ShowEntities => {
83 "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
84 }
85 }
86}
87
/// Rejection message for `execute_sql_aggregate` when it receives a grouped
/// `SELECT`; grouped statements belong to `execute_sql_grouped(...)`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    MESSAGE
}
91
92impl<C: CanisterKind> DbSession<C> {
93 pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
96 strategy: &PreparedSqlScalarAggregateStrategy,
97 ) -> String {
98 let kind = match strategy.runtime_descriptor() {
99 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
100 | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
101 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
102 kind: crate::db::query::plan::AggregateKind::Sum,
103 } => "SUM",
104 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
105 kind: crate::db::query::plan::AggregateKind::Avg,
106 } => "AVG",
107 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
108 kind: crate::db::query::plan::AggregateKind::Min,
109 } => "MIN",
110 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
111 kind: crate::db::query::plan::AggregateKind::Max,
112 } => "MAX",
113 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
114 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
115 unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
116 }
117 };
118
119 match strategy.projected_field() {
120 Some(field) if strategy.is_distinct() => format!("{kind}(DISTINCT {field})"),
121 Some(field) => format!("{kind}({field})"),
122 None => format!("{kind}(*)"),
123 }
124 }
125
126 fn dedup_structural_sql_aggregate_input_values(values: Vec<Value>) -> Vec<Value> {
129 let mut deduped = Vec::with_capacity(values.len());
130
131 for value in values {
132 if deduped.iter().any(|current| current == &value) {
133 continue;
134 }
135 deduped.push(value);
136 }
137
138 deduped
139 }
140
141 fn reduce_structural_sql_aggregate_field_values(
144 values: Vec<Value>,
145 strategy: &PreparedSqlScalarAggregateStrategy,
146 ) -> Result<Value, QueryError> {
147 let values = if strategy.is_distinct() {
148 Self::dedup_structural_sql_aggregate_input_values(values)
149 } else {
150 values
151 };
152
153 match strategy.runtime_descriptor() {
154 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
155 "COUNT(*) structural reduction does not consume projected field values",
156 )),
157 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
158 let count = values
159 .into_iter()
160 .filter(|value| !matches!(value, Value::Null))
161 .count();
162
163 Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
164 }
165 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
166 kind:
167 crate::db::query::plan::AggregateKind::Sum
168 | crate::db::query::plan::AggregateKind::Avg,
169 } => {
170 let mut sum = None;
171 let mut row_count = 0_u64;
172
173 for value in values {
174 if matches!(value, Value::Null) {
175 continue;
176 }
177
178 let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
179 QueryError::invariant(
180 "numeric SQL aggregate dispatch encountered non-numeric projected value",
181 )
182 })?;
183 sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
184 row_count = row_count.saturating_add(1);
185 }
186
187 match strategy.runtime_descriptor() {
188 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
189 kind: crate::db::query::plan::AggregateKind::Sum,
190 } => Ok(sum.map_or(Value::Null, Value::Decimal)),
191 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
192 kind: crate::db::query::plan::AggregateKind::Avg,
193 } => Ok(sum
194 .and_then(|sum| average_decimal_terms(sum, row_count))
195 .map_or(Value::Null, Value::Decimal)),
196 _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
197 }
198 }
199 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
200 kind:
201 crate::db::query::plan::AggregateKind::Min
202 | crate::db::query::plan::AggregateKind::Max,
203 } => {
204 let mut selected = None::<Value>;
205
206 for value in values {
207 if matches!(value, Value::Null) {
208 continue;
209 }
210
211 let replace = match selected.as_ref() {
212 None => true,
213 Some(current) => {
214 let ordering =
215 compare_numeric_or_strict_order(&value, current).ok_or_else(
216 || {
217 QueryError::invariant(
218 "extrema SQL aggregate dispatch encountered incomparable projected values",
219 )
220 },
221 )?;
222
223 match strategy.runtime_descriptor() {
224 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
225 kind: crate::db::query::plan::AggregateKind::Min,
226 } => ordering.is_lt(),
227 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
228 kind: crate::db::query::plan::AggregateKind::Max,
229 } => ordering.is_gt(),
230 _ => unreachable!(
231 "extrema SQL aggregate strategy drifted during reduction"
232 ),
233 }
234 }
235 };
236
237 if replace {
238 selected = Some(value);
239 }
240 }
241
242 Ok(selected.unwrap_or(Value::Null))
243 }
244 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
245 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
246 Err(QueryError::invariant(
247 "prepared SQL scalar aggregate strategy drifted outside SQL support",
248 ))
249 }
250 }
251 }
252
253 fn execute_structural_sql_aggregate_field_projection(
256 &self,
257 query: crate::db::query::intent::StructuralQuery,
258 authority: crate::db::executor::EntityAuthority,
259 ) -> Result<Vec<Value>, QueryError> {
260 let (_, rows, _) = self
261 .execute_structural_sql_projection(query, authority)?
262 .into_parts();
263 let mut projected = Vec::with_capacity(rows.len());
264
265 for row in rows {
266 let [value] = row.as_slice() else {
267 return Err(QueryError::invariant(
268 "structural SQL aggregate projection must emit exactly one field",
269 ));
270 };
271
272 projected.push(value.clone());
273 }
274
275 Ok(projected)
276 }
277
278 pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
282 &self,
283 command: SqlGlobalAggregateCommandCore,
284 authority: crate::db::executor::EntityAuthority,
285 label_override: Option<String>,
286 ) -> Result<SqlDispatchResult, QueryError> {
287 let model = authority.model();
288 let strategy = command
289 .prepared_scalar_strategy_with_model(model)
290 .map_err(QueryError::from_sql_lowering_error)?;
291 let label = label_override.unwrap_or_else(|| Self::sql_scalar_aggregate_label(&strategy));
292 let value = match strategy.runtime_descriptor() {
293 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
294 let (_, _, row_count) = self
295 .execute_structural_sql_projection(
296 command
297 .query()
298 .clone()
299 .select_fields([authority.primary_key_name()]),
300 authority,
301 )?
302 .into_parts();
303
304 Value::Uint(u64::from(row_count))
305 }
306 PreparedSqlScalarAggregateRuntimeDescriptor::CountField
307 | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
308 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
309 let Some(field) = strategy.projected_field() else {
310 return Err(QueryError::invariant(
311 "field-target SQL aggregate strategy requires projected field label",
312 ));
313 };
314 let values = self.execute_structural_sql_aggregate_field_projection(
315 command.query().clone().select_fields([field]),
316 authority,
317 )?;
318
319 Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
320 }
321 };
322
323 Ok(SqlDispatchResult::Projection {
324 columns: vec![label],
325 rows: vec![vec![value]],
326 row_count: 1,
327 })
328 }
329
330 pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
333 parsed: &SqlParsedStatement,
334 authority: crate::db::executor::EntityAuthority,
335 ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
336 compile_sql_global_aggregate_command_core_from_prepared(
337 parsed.prepare(authority.model().name())?,
338 authority.model(),
339 MissingRowPolicy::Ignore,
340 )
341 .map_err(QueryError::from_sql_lowering_error)
342 }
343
344 fn prepared_sql_scalar_target_slot_required(
347 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
348 message: &'static str,
349 ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
350 strategy
351 .target_slot()
352 .cloned()
353 .ok_or_else(|| QueryError::invariant(message))
354 }
355
356 fn execute_prepared_sql_scalar_count_rows<E>(
359 &self,
360 command: &SqlGlobalAggregateCommand<E>,
361 ) -> Result<Value, QueryError>
362 where
363 E: PersistedRow<Canister = C> + EntityValue,
364 {
365 self.execute_load_query_with(command.query(), |load, plan| {
366 load.execute_scalar_terminal_request(
367 plan,
368 crate::db::executor::ScalarTerminalBoundaryRequest::Count,
369 )?
370 .into_count()
371 })
372 .map(|count| Value::Uint(u64::from(count)))
373 }
374
375 fn execute_prepared_sql_scalar_count_field<E>(
378 &self,
379 command: &SqlGlobalAggregateCommand<E>,
380 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
381 ) -> Result<Value, QueryError>
382 where
383 E: PersistedRow<Canister = C> + EntityValue,
384 {
385 let target_slot = Self::prepared_sql_scalar_target_slot_required(
386 strategy,
387 "prepared COUNT(field) SQL aggregate strategy requires target slot",
388 )?;
389
390 self.execute_load_query_with(command.query(), |load, plan| {
391 load.execute_scalar_projection_boundary(
392 plan,
393 target_slot.clone(),
394 ScalarProjectionBoundaryRequest::CountNonNull,
395 )?
396 .into_count()
397 })
398 .map(|count| Value::Uint(u64::from(count)))
399 }
400
401 fn execute_prepared_sql_scalar_numeric_field<E>(
404 &self,
405 command: &SqlGlobalAggregateCommand<E>,
406 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
407 request: ScalarNumericFieldBoundaryRequest,
408 message: &'static str,
409 ) -> Result<Value, QueryError>
410 where
411 E: PersistedRow<Canister = C> + EntityValue,
412 {
413 let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
414
415 self.execute_load_query_with(command.query(), |load, plan| {
416 load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
417 })
418 .map(|value| value.map_or(Value::Null, Value::Decimal))
419 }
420
421 fn execute_prepared_sql_scalar_extrema_field<E>(
424 &self,
425 command: &SqlGlobalAggregateCommand<E>,
426 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
427 kind: crate::db::query::plan::AggregateKind,
428 ) -> Result<Value, QueryError>
429 where
430 E: PersistedRow<Canister = C> + EntityValue,
431 {
432 let target_slot = Self::prepared_sql_scalar_target_slot_required(
433 strategy,
434 "prepared extrema SQL aggregate strategy requires target slot",
435 )?;
436
437 self.execute_load_query_with(command.query(), |load, plan| {
438 load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
439 })
440 .map(|value| value.unwrap_or(Value::Null))
441 }
442
443 fn execute_prepared_sql_scalar_aggregate<E>(
447 &self,
448 command: &SqlGlobalAggregateCommand<E>,
449 ) -> Result<Value, QueryError>
450 where
451 E: PersistedRow<Canister = C> + EntityValue,
452 {
453 let strategy = command.prepared_scalar_strategy();
454
455 match strategy.runtime_descriptor() {
456 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
457 self.execute_prepared_sql_scalar_count_rows(command)
458 }
459 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
460 self.execute_prepared_sql_scalar_count_field(command, &strategy)
461 }
462 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
463 kind: crate::db::query::plan::AggregateKind::Sum,
464 } => self.execute_prepared_sql_scalar_numeric_field(
465 command,
466 &strategy,
467 ScalarNumericFieldBoundaryRequest::Sum,
468 "prepared SUM(field) SQL aggregate strategy requires target slot",
469 ),
470 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
471 kind: crate::db::query::plan::AggregateKind::Avg,
472 } => self.execute_prepared_sql_scalar_numeric_field(
473 command,
474 &strategy,
475 ScalarNumericFieldBoundaryRequest::Avg,
476 "prepared AVG(field) SQL aggregate strategy requires target slot",
477 ),
478 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
479 self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
480 }
481 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
482 Err(QueryError::invariant(
483 "prepared SQL scalar aggregate numeric runtime descriptor drift",
484 ))
485 }
486 }
487 }
488
489 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
494 where
495 E: PersistedRow<Canister = C> + EntityValue,
496 {
497 let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
500
501 match &statement {
504 SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
505 SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
506 return Err(QueryError::unsupported_query(
507 unsupported_sql_aggregate_grouped_message(),
508 ));
509 }
510 SqlStatement::Delete(_) => {
511 return Err(QueryError::unsupported_query(
512 "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
513 ));
514 }
515 _ => {
516 let route = sql_statement_route_from_statement(&statement);
517
518 return Err(QueryError::unsupported_query(
519 unsupported_sql_aggregate_surface_lane_message(&route),
520 ));
521 }
522 }
523
524 let prepared = prepare_sql_statement(statement, E::MODEL.name())
528 .map_err(QueryError::from_sql_lowering_error)?;
529 let command = compile_sql_global_aggregate_command_from_prepared::<E>(
530 prepared.clone(),
531 MissingRowPolicy::Ignore,
532 )
533 .map_err(QueryError::from_sql_lowering_error)?;
534 let strategy = command.prepared_scalar_strategy();
535
536 if strategy.is_distinct() {
539 let dispatch = compile_sql_global_aggregate_command_core_from_prepared(
540 prepared,
541 E::MODEL,
542 MissingRowPolicy::Ignore,
543 )
544 .map_err(QueryError::from_sql_lowering_error)?;
545 let authority = crate::db::executor::EntityAuthority::for_type::<E>();
546 let SqlDispatchResult::Projection { rows, .. } =
547 self.execute_sql_aggregate_dispatch_for_authority(dispatch, authority, None)?
548 else {
549 return Err(QueryError::invariant(
550 "DISTINCT SQL aggregate dispatch must finalize as one projection row",
551 ));
552 };
553 let Some(mut row) = rows.into_iter().next() else {
554 return Err(QueryError::invariant(
555 "DISTINCT SQL aggregate dispatch must emit one projection row",
556 ));
557 };
558 if row.len() != 1 {
559 return Err(QueryError::invariant(
560 "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
561 ));
562 }
563 let value = row.pop().ok_or_else(|| {
564 QueryError::invariant(
565 "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
566 )
567 })?;
568
569 return Ok(value);
570 }
571
572 self.execute_prepared_sql_scalar_aggregate(&command)
576 }
577}