1use crate::{
8 db::{
9 DbSession, MissingRowPolicy, PersistedRow, QueryError,
10 executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
11 numeric::{
12 add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
13 compare_numeric_or_strict_order,
14 },
15 session::sql::surface::sql_statement_route_from_statement,
16 session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
17 sql::lowering::{
18 PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
19 SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
20 compile_sql_global_aggregate_command_core_from_prepared,
21 compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
22 prepare_sql_statement,
23 },
24 sql::parser::{SqlStatement, parse_sql},
25 },
26 traits::{CanisterKind, EntityValue},
27 value::Value,
28};
29
30#[derive(Clone, Copy, Debug, Eq, PartialEq)]
31pub(in crate::db::session::sql) enum SqlAggregateSurface {
32 QueryFrom,
33 ExecuteSql,
34 ExecuteSqlGrouped,
35}
36
37pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
38 parsed: &SqlParsedStatement,
39) -> bool {
40 is_sql_global_aggregate_statement(&parsed.statement)
41}
42
43pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
44 surface: SqlAggregateSurface,
45) -> &'static str {
46 match surface {
47 SqlAggregateSurface::QueryFrom => {
48 "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
49 }
50 SqlAggregateSurface::ExecuteSql => {
51 "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
52 }
53 SqlAggregateSurface::ExecuteSqlGrouped => {
54 "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
55 }
56 }
57}
58
59const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
60 match route {
61 SqlStatementRoute::Query { .. } => {
62 "execute_sql_aggregate requires constrained global aggregate SELECT"
63 }
64 SqlStatementRoute::Explain { .. } => {
65 "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
66 }
67 SqlStatementRoute::Describe { .. } => {
68 "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
69 }
70 SqlStatementRoute::ShowIndexes { .. } => {
71 "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
72 }
73 SqlStatementRoute::ShowColumns { .. } => {
74 "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
75 }
76 SqlStatementRoute::ShowEntities => {
77 "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
78 }
79 }
80}
81
/// Returns the rejection message `execute_sql_aggregate` reports for a grouped
/// `SELECT`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const GROUPED_REJECTION: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    GROUPED_REJECTION
}
85
86impl<C: CanisterKind> DbSession<C> {
87 pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
90 strategy: &PreparedSqlScalarAggregateStrategy,
91 ) -> String {
92 let kind = match strategy.runtime_descriptor() {
93 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
94 | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
95 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
96 kind: crate::db::query::plan::AggregateKind::Sum,
97 } => "SUM",
98 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
99 kind: crate::db::query::plan::AggregateKind::Avg,
100 } => "AVG",
101 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
102 kind: crate::db::query::plan::AggregateKind::Min,
103 } => "MIN",
104 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
105 kind: crate::db::query::plan::AggregateKind::Max,
106 } => "MAX",
107 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
108 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
109 unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
110 }
111 };
112
113 match strategy.projected_field() {
114 Some(field) if strategy.is_distinct() => format!("{kind}(DISTINCT {field})"),
115 Some(field) => format!("{kind}({field})"),
116 None => format!("{kind}(*)"),
117 }
118 }
119
120 fn dedup_structural_sql_aggregate_input_values(values: Vec<Value>) -> Vec<Value> {
123 let mut deduped = Vec::with_capacity(values.len());
124
125 for value in values {
126 if deduped.iter().any(|current| current == &value) {
127 continue;
128 }
129 deduped.push(value);
130 }
131
132 deduped
133 }
134
135 fn reduce_structural_sql_aggregate_field_values(
138 values: Vec<Value>,
139 strategy: &PreparedSqlScalarAggregateStrategy,
140 ) -> Result<Value, QueryError> {
141 let values = if strategy.is_distinct() {
142 Self::dedup_structural_sql_aggregate_input_values(values)
143 } else {
144 values
145 };
146
147 match strategy.runtime_descriptor() {
148 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
149 "COUNT(*) structural reduction does not consume projected field values",
150 )),
151 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
152 let count = values
153 .into_iter()
154 .filter(|value| !matches!(value, Value::Null))
155 .count();
156
157 Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
158 }
159 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
160 kind:
161 crate::db::query::plan::AggregateKind::Sum
162 | crate::db::query::plan::AggregateKind::Avg,
163 } => {
164 let mut sum = None;
165 let mut row_count = 0_u64;
166
167 for value in values {
168 if matches!(value, Value::Null) {
169 continue;
170 }
171
172 let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
173 QueryError::invariant(
174 "numeric SQL aggregate dispatch encountered non-numeric projected value",
175 )
176 })?;
177 sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
178 row_count = row_count.saturating_add(1);
179 }
180
181 match strategy.runtime_descriptor() {
182 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
183 kind: crate::db::query::plan::AggregateKind::Sum,
184 } => Ok(sum.map_or(Value::Null, Value::Decimal)),
185 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
186 kind: crate::db::query::plan::AggregateKind::Avg,
187 } => Ok(sum
188 .and_then(|sum| average_decimal_terms(sum, row_count))
189 .map_or(Value::Null, Value::Decimal)),
190 _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
191 }
192 }
193 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
194 kind:
195 crate::db::query::plan::AggregateKind::Min
196 | crate::db::query::plan::AggregateKind::Max,
197 } => {
198 let mut selected = None::<Value>;
199
200 for value in values {
201 if matches!(value, Value::Null) {
202 continue;
203 }
204
205 let replace = match selected.as_ref() {
206 None => true,
207 Some(current) => {
208 let ordering =
209 compare_numeric_or_strict_order(&value, current).ok_or_else(
210 || {
211 QueryError::invariant(
212 "extrema SQL aggregate dispatch encountered incomparable projected values",
213 )
214 },
215 )?;
216
217 match strategy.runtime_descriptor() {
218 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
219 kind: crate::db::query::plan::AggregateKind::Min,
220 } => ordering.is_lt(),
221 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
222 kind: crate::db::query::plan::AggregateKind::Max,
223 } => ordering.is_gt(),
224 _ => unreachable!(
225 "extrema SQL aggregate strategy drifted during reduction"
226 ),
227 }
228 }
229 };
230
231 if replace {
232 selected = Some(value);
233 }
234 }
235
236 Ok(selected.unwrap_or(Value::Null))
237 }
238 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
239 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
240 Err(QueryError::invariant(
241 "prepared SQL scalar aggregate strategy drifted outside SQL support",
242 ))
243 }
244 }
245 }
246
247 fn execute_structural_sql_aggregate_field_projection(
250 &self,
251 query: crate::db::query::intent::StructuralQuery,
252 authority: crate::db::executor::EntityAuthority,
253 ) -> Result<Vec<Value>, QueryError> {
254 let (_, rows, _) = self
255 .execute_structural_sql_projection(query, authority)?
256 .into_parts();
257 let mut projected = Vec::with_capacity(rows.len());
258
259 for row in rows {
260 let [value] = row.as_slice() else {
261 return Err(QueryError::invariant(
262 "structural SQL aggregate projection must emit exactly one field",
263 ));
264 };
265
266 projected.push(value.clone());
267 }
268
269 Ok(projected)
270 }
271
272 pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
276 &self,
277 command: SqlGlobalAggregateCommandCore,
278 authority: crate::db::executor::EntityAuthority,
279 label_override: Option<String>,
280 ) -> Result<SqlDispatchResult, QueryError> {
281 let model = authority.model();
282 let strategy = command
283 .prepared_scalar_strategy_with_model(model)
284 .map_err(QueryError::from_sql_lowering_error)?;
285 let label = label_override.unwrap_or_else(|| Self::sql_scalar_aggregate_label(&strategy));
286 let value = match strategy.runtime_descriptor() {
287 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
288 let (_, _, row_count) = self
289 .execute_structural_sql_projection(
290 command
291 .query()
292 .clone()
293 .select_fields([authority.primary_key_name()]),
294 authority,
295 )?
296 .into_parts();
297
298 Value::Uint(u64::from(row_count))
299 }
300 PreparedSqlScalarAggregateRuntimeDescriptor::CountField
301 | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
302 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
303 let Some(field) = strategy.projected_field() else {
304 return Err(QueryError::invariant(
305 "field-target SQL aggregate strategy requires projected field label",
306 ));
307 };
308 let values = self.execute_structural_sql_aggregate_field_projection(
309 command.query().clone().select_fields([field]),
310 authority,
311 )?;
312
313 Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
314 }
315 };
316
317 Ok(SqlDispatchResult::Projection {
318 columns: vec![label],
319 rows: vec![vec![value]],
320 row_count: 1,
321 })
322 }
323
324 pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
327 parsed: &SqlParsedStatement,
328 authority: crate::db::executor::EntityAuthority,
329 ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
330 compile_sql_global_aggregate_command_core_from_prepared(
331 parsed.prepare(authority.model().name())?,
332 authority.model(),
333 MissingRowPolicy::Ignore,
334 )
335 .map_err(QueryError::from_sql_lowering_error)
336 }
337
338 fn prepared_sql_scalar_target_slot_required(
341 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
342 message: &'static str,
343 ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
344 strategy
345 .target_slot()
346 .cloned()
347 .ok_or_else(|| QueryError::invariant(message))
348 }
349
350 fn execute_prepared_sql_scalar_count_rows<E>(
353 &self,
354 command: &SqlGlobalAggregateCommand<E>,
355 ) -> Result<Value, QueryError>
356 where
357 E: PersistedRow<Canister = C> + EntityValue,
358 {
359 self.execute_load_query_with(command.query(), |load, plan| {
360 load.execute_scalar_terminal_request(
361 plan,
362 crate::db::executor::ScalarTerminalBoundaryRequest::Count,
363 )?
364 .into_count()
365 })
366 .map(|count| Value::Uint(u64::from(count)))
367 }
368
369 fn execute_prepared_sql_scalar_count_field<E>(
372 &self,
373 command: &SqlGlobalAggregateCommand<E>,
374 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
375 ) -> Result<Value, QueryError>
376 where
377 E: PersistedRow<Canister = C> + EntityValue,
378 {
379 let target_slot = Self::prepared_sql_scalar_target_slot_required(
380 strategy,
381 "prepared COUNT(field) SQL aggregate strategy requires target slot",
382 )?;
383
384 self.execute_load_query_with(command.query(), |load, plan| {
385 load.execute_scalar_projection_boundary(
386 plan,
387 target_slot.clone(),
388 ScalarProjectionBoundaryRequest::CountNonNull,
389 )?
390 .into_count()
391 })
392 .map(|count| Value::Uint(u64::from(count)))
393 }
394
395 fn execute_prepared_sql_scalar_numeric_field<E>(
398 &self,
399 command: &SqlGlobalAggregateCommand<E>,
400 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
401 request: ScalarNumericFieldBoundaryRequest,
402 message: &'static str,
403 ) -> Result<Value, QueryError>
404 where
405 E: PersistedRow<Canister = C> + EntityValue,
406 {
407 let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
408
409 self.execute_load_query_with(command.query(), |load, plan| {
410 load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
411 })
412 .map(|value| value.map_or(Value::Null, Value::Decimal))
413 }
414
415 fn execute_prepared_sql_scalar_extrema_field<E>(
418 &self,
419 command: &SqlGlobalAggregateCommand<E>,
420 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
421 kind: crate::db::query::plan::AggregateKind,
422 ) -> Result<Value, QueryError>
423 where
424 E: PersistedRow<Canister = C> + EntityValue,
425 {
426 let target_slot = Self::prepared_sql_scalar_target_slot_required(
427 strategy,
428 "prepared extrema SQL aggregate strategy requires target slot",
429 )?;
430
431 self.execute_load_query_with(command.query(), |load, plan| {
432 load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
433 })
434 .map(|value| value.unwrap_or(Value::Null))
435 }
436
437 fn execute_prepared_sql_scalar_aggregate<E>(
441 &self,
442 command: &SqlGlobalAggregateCommand<E>,
443 ) -> Result<Value, QueryError>
444 where
445 E: PersistedRow<Canister = C> + EntityValue,
446 {
447 let strategy = command.prepared_scalar_strategy();
448
449 match strategy.runtime_descriptor() {
450 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
451 self.execute_prepared_sql_scalar_count_rows(command)
452 }
453 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
454 self.execute_prepared_sql_scalar_count_field(command, &strategy)
455 }
456 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
457 kind: crate::db::query::plan::AggregateKind::Sum,
458 } => self.execute_prepared_sql_scalar_numeric_field(
459 command,
460 &strategy,
461 ScalarNumericFieldBoundaryRequest::Sum,
462 "prepared SUM(field) SQL aggregate strategy requires target slot",
463 ),
464 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
465 kind: crate::db::query::plan::AggregateKind::Avg,
466 } => self.execute_prepared_sql_scalar_numeric_field(
467 command,
468 &strategy,
469 ScalarNumericFieldBoundaryRequest::Avg,
470 "prepared AVG(field) SQL aggregate strategy requires target slot",
471 ),
472 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
473 self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
474 }
475 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
476 Err(QueryError::invariant(
477 "prepared SQL scalar aggregate numeric runtime descriptor drift",
478 ))
479 }
480 }
481 }
482
483 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
488 where
489 E: PersistedRow<Canister = C> + EntityValue,
490 {
491 let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
494
495 match &statement {
498 SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
499 SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
500 return Err(QueryError::unsupported_query(
501 unsupported_sql_aggregate_grouped_message(),
502 ));
503 }
504 SqlStatement::Delete(_) => {
505 return Err(QueryError::unsupported_query(
506 "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
507 ));
508 }
509 _ => {
510 let route = sql_statement_route_from_statement(&statement);
511
512 return Err(QueryError::unsupported_query(
513 unsupported_sql_aggregate_surface_lane_message(&route),
514 ));
515 }
516 }
517
518 let prepared = prepare_sql_statement(statement, E::MODEL.name())
522 .map_err(QueryError::from_sql_lowering_error)?;
523 let command = compile_sql_global_aggregate_command_from_prepared::<E>(
524 prepared.clone(),
525 MissingRowPolicy::Ignore,
526 )
527 .map_err(QueryError::from_sql_lowering_error)?;
528 let strategy = command.prepared_scalar_strategy();
529
530 if strategy.is_distinct() {
533 let dispatch = compile_sql_global_aggregate_command_core_from_prepared(
534 prepared,
535 E::MODEL,
536 MissingRowPolicy::Ignore,
537 )
538 .map_err(QueryError::from_sql_lowering_error)?;
539 let authority = crate::db::executor::EntityAuthority::for_type::<E>();
540 let SqlDispatchResult::Projection { rows, .. } =
541 self.execute_sql_aggregate_dispatch_for_authority(dispatch, authority, None)?
542 else {
543 return Err(QueryError::invariant(
544 "DISTINCT SQL aggregate dispatch must finalize as one projection row",
545 ));
546 };
547 let Some(mut row) = rows.into_iter().next() else {
548 return Err(QueryError::invariant(
549 "DISTINCT SQL aggregate dispatch must emit one projection row",
550 ));
551 };
552 if row.len() != 1 {
553 return Err(QueryError::invariant(
554 "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
555 ));
556 }
557 let value = row.pop().ok_or_else(|| {
558 QueryError::invariant(
559 "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
560 )
561 })?;
562
563 return Ok(value);
564 }
565
566 self.execute_prepared_sql_scalar_aggregate(&command)
570 }
571}