1use crate::{
8 db::{
9 DbSession, MissingRowPolicy, PersistedRow, QueryError,
10 executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
11 numeric::{
12 add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
13 compare_numeric_or_strict_order,
14 },
15 session::sql::surface::sql_statement_route_from_statement,
16 session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
17 sql::lowering::{
18 PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
19 SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
20 compile_sql_global_aggregate_command_core_from_prepared,
21 compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
22 prepare_sql_statement,
23 },
24 sql::parser::{SqlStatement, parse_sql},
25 },
26 traits::{CanisterKind, EntityValue},
27 value::Value,
28};
29
30#[derive(Clone, Copy, Debug, Eq, PartialEq)]
31pub(in crate::db::session::sql) enum SqlAggregateSurface {
32 QueryFrom,
33 ExecuteSql,
34 ExecuteSqlGrouped,
35}
36
37pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
38 parsed: &SqlParsedStatement,
39) -> bool {
40 is_sql_global_aggregate_statement(&parsed.statement)
41}
42
43pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
44 surface: SqlAggregateSurface,
45) -> &'static str {
46 match surface {
47 SqlAggregateSurface::QueryFrom => {
48 "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
49 }
50 SqlAggregateSurface::ExecuteSql => {
51 "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
52 }
53 SqlAggregateSurface::ExecuteSqlGrouped => {
54 "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
55 }
56 }
57}
58
59const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
60 match route {
61 SqlStatementRoute::Query { .. } => {
62 "execute_sql_aggregate requires constrained global aggregate SELECT"
63 }
64 SqlStatementRoute::Insert { .. } => {
65 "execute_sql_aggregate rejects INSERT; use create(...) or insert(...)"
66 }
67 SqlStatementRoute::Update { .. } => "execute_sql_aggregate rejects UPDATE; use update(...)",
68 SqlStatementRoute::Explain { .. } => "execute_sql_aggregate rejects EXPLAIN",
69 SqlStatementRoute::Describe { .. } => "execute_sql_aggregate rejects DESCRIBE",
70 SqlStatementRoute::ShowIndexes { .. } => "execute_sql_aggregate rejects SHOW INDEXES",
71 SqlStatementRoute::ShowColumns { .. } => "execute_sql_aggregate rejects SHOW COLUMNS",
72 SqlStatementRoute::ShowEntities => "execute_sql_aggregate rejects SHOW ENTITIES",
73 }
74}
75
/// Rejection message `execute_sql_aggregate` reports for a SELECT with a
/// non-empty GROUP BY, redirecting the caller to `execute_sql_grouped(...)`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const GROUPED_REJECTION: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    GROUPED_REJECTION
}
79
80impl<C: CanisterKind> DbSession<C> {
81 pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
84 strategy: &PreparedSqlScalarAggregateStrategy,
85 ) -> String {
86 let kind = match strategy.runtime_descriptor() {
87 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
88 | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
89 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
90 kind: crate::db::query::plan::AggregateKind::Sum,
91 } => "SUM",
92 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
93 kind: crate::db::query::plan::AggregateKind::Avg,
94 } => "AVG",
95 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
96 kind: crate::db::query::plan::AggregateKind::Min,
97 } => "MIN",
98 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
99 kind: crate::db::query::plan::AggregateKind::Max,
100 } => "MAX",
101 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
102 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
103 unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
104 }
105 };
106
107 match strategy.projected_field() {
108 Some(field) if strategy.is_distinct() => format!("{kind}(DISTINCT {field})"),
109 Some(field) => format!("{kind}({field})"),
110 None => format!("{kind}(*)"),
111 }
112 }
113
114 fn dedup_structural_sql_aggregate_input_values(values: Vec<Value>) -> Vec<Value> {
117 let mut deduped = Vec::with_capacity(values.len());
118
119 for value in values {
120 if deduped.iter().any(|current| current == &value) {
121 continue;
122 }
123 deduped.push(value);
124 }
125
126 deduped
127 }
128
129 fn reduce_structural_sql_aggregate_field_values(
132 values: Vec<Value>,
133 strategy: &PreparedSqlScalarAggregateStrategy,
134 ) -> Result<Value, QueryError> {
135 let values = if strategy.is_distinct() {
136 Self::dedup_structural_sql_aggregate_input_values(values)
137 } else {
138 values
139 };
140
141 match strategy.runtime_descriptor() {
142 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
143 "COUNT(*) structural reduction does not consume projected field values",
144 )),
145 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
146 let count = values
147 .into_iter()
148 .filter(|value| !matches!(value, Value::Null))
149 .count();
150
151 Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
152 }
153 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
154 kind:
155 crate::db::query::plan::AggregateKind::Sum
156 | crate::db::query::plan::AggregateKind::Avg,
157 } => {
158 let mut sum = None;
159 let mut row_count = 0_u64;
160
161 for value in values {
162 if matches!(value, Value::Null) {
163 continue;
164 }
165
166 let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
167 QueryError::invariant(
168 "numeric SQL aggregate dispatch encountered non-numeric projected value",
169 )
170 })?;
171 sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
172 row_count = row_count.saturating_add(1);
173 }
174
175 match strategy.runtime_descriptor() {
176 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
177 kind: crate::db::query::plan::AggregateKind::Sum,
178 } => Ok(sum.map_or(Value::Null, Value::Decimal)),
179 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
180 kind: crate::db::query::plan::AggregateKind::Avg,
181 } => Ok(sum
182 .and_then(|sum| average_decimal_terms(sum, row_count))
183 .map_or(Value::Null, Value::Decimal)),
184 _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
185 }
186 }
187 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
188 kind:
189 crate::db::query::plan::AggregateKind::Min
190 | crate::db::query::plan::AggregateKind::Max,
191 } => {
192 let mut selected = None::<Value>;
193
194 for value in values {
195 if matches!(value, Value::Null) {
196 continue;
197 }
198
199 let replace = match selected.as_ref() {
200 None => true,
201 Some(current) => {
202 let ordering =
203 compare_numeric_or_strict_order(&value, current).ok_or_else(
204 || {
205 QueryError::invariant(
206 "extrema SQL aggregate dispatch encountered incomparable projected values",
207 )
208 },
209 )?;
210
211 match strategy.runtime_descriptor() {
212 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
213 kind: crate::db::query::plan::AggregateKind::Min,
214 } => ordering.is_lt(),
215 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
216 kind: crate::db::query::plan::AggregateKind::Max,
217 } => ordering.is_gt(),
218 _ => unreachable!(
219 "extrema SQL aggregate strategy drifted during reduction"
220 ),
221 }
222 }
223 };
224
225 if replace {
226 selected = Some(value);
227 }
228 }
229
230 Ok(selected.unwrap_or(Value::Null))
231 }
232 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
233 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
234 Err(QueryError::invariant(
235 "prepared SQL scalar aggregate strategy drifted outside SQL support",
236 ))
237 }
238 }
239 }
240
241 fn execute_structural_sql_aggregate_field_projection(
244 &self,
245 query: crate::db::query::intent::StructuralQuery,
246 authority: crate::db::executor::EntityAuthority,
247 ) -> Result<Vec<Value>, QueryError> {
248 let (_, rows, _) = self
249 .execute_structural_sql_projection(query, authority)?
250 .into_parts();
251 let mut projected = Vec::with_capacity(rows.len());
252
253 for row in rows {
254 let [value] = row.as_slice() else {
255 return Err(QueryError::invariant(
256 "structural SQL aggregate projection must emit exactly one field",
257 ));
258 };
259
260 projected.push(value.clone());
261 }
262
263 Ok(projected)
264 }
265
266 pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
270 &self,
271 command: SqlGlobalAggregateCommandCore,
272 authority: crate::db::executor::EntityAuthority,
273 label_override: Option<String>,
274 ) -> Result<SqlDispatchResult, QueryError> {
275 let model = authority.model();
276 let strategy = command
277 .prepared_scalar_strategy_with_model(model)
278 .map_err(QueryError::from_sql_lowering_error)?;
279 let label = label_override.unwrap_or_else(|| Self::sql_scalar_aggregate_label(&strategy));
280 let value = match strategy.runtime_descriptor() {
281 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
282 let (_, _, row_count) = self
283 .execute_structural_sql_projection(
284 command
285 .query()
286 .clone()
287 .select_fields([authority.primary_key_name()]),
288 authority,
289 )?
290 .into_parts();
291
292 Value::Uint(u64::from(row_count))
293 }
294 PreparedSqlScalarAggregateRuntimeDescriptor::CountField
295 | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
296 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
297 let Some(field) = strategy.projected_field() else {
298 return Err(QueryError::invariant(
299 "field-target SQL aggregate strategy requires projected field label",
300 ));
301 };
302 let values = self.execute_structural_sql_aggregate_field_projection(
303 command.query().clone().select_fields([field]),
304 authority,
305 )?;
306
307 Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
308 }
309 };
310
311 Ok(SqlDispatchResult::Projection {
312 columns: vec![label],
313 rows: vec![vec![value]],
314 row_count: 1,
315 })
316 }
317
318 pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
321 parsed: &SqlParsedStatement,
322 authority: crate::db::executor::EntityAuthority,
323 ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
324 compile_sql_global_aggregate_command_core_from_prepared(
325 parsed.prepare(authority.model().name())?,
326 authority.model(),
327 MissingRowPolicy::Ignore,
328 )
329 .map_err(QueryError::from_sql_lowering_error)
330 }
331
332 fn prepared_sql_scalar_target_slot_required(
335 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
336 message: &'static str,
337 ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
338 strategy
339 .target_slot()
340 .cloned()
341 .ok_or_else(|| QueryError::invariant(message))
342 }
343
344 fn execute_prepared_sql_scalar_count_rows<E>(
347 &self,
348 command: &SqlGlobalAggregateCommand<E>,
349 ) -> Result<Value, QueryError>
350 where
351 E: PersistedRow<Canister = C> + EntityValue,
352 {
353 self.execute_load_query_with(command.query(), |load, plan| {
354 load.execute_scalar_terminal_request(
355 plan,
356 crate::db::executor::ScalarTerminalBoundaryRequest::Count,
357 )?
358 .into_count()
359 })
360 .map(|count| Value::Uint(u64::from(count)))
361 }
362
363 fn execute_prepared_sql_scalar_count_field<E>(
366 &self,
367 command: &SqlGlobalAggregateCommand<E>,
368 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
369 ) -> Result<Value, QueryError>
370 where
371 E: PersistedRow<Canister = C> + EntityValue,
372 {
373 let target_slot = Self::prepared_sql_scalar_target_slot_required(
374 strategy,
375 "prepared COUNT(field) SQL aggregate strategy requires target slot",
376 )?;
377
378 self.execute_load_query_with(command.query(), |load, plan| {
379 load.execute_scalar_projection_boundary(
380 plan,
381 target_slot.clone(),
382 ScalarProjectionBoundaryRequest::CountNonNull,
383 )?
384 .into_count()
385 })
386 .map(|count| Value::Uint(u64::from(count)))
387 }
388
389 fn execute_prepared_sql_scalar_numeric_field<E>(
392 &self,
393 command: &SqlGlobalAggregateCommand<E>,
394 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
395 request: ScalarNumericFieldBoundaryRequest,
396 message: &'static str,
397 ) -> Result<Value, QueryError>
398 where
399 E: PersistedRow<Canister = C> + EntityValue,
400 {
401 let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
402
403 self.execute_load_query_with(command.query(), |load, plan| {
404 load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
405 })
406 .map(|value| value.map_or(Value::Null, Value::Decimal))
407 }
408
409 fn execute_prepared_sql_scalar_extrema_field<E>(
412 &self,
413 command: &SqlGlobalAggregateCommand<E>,
414 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
415 kind: crate::db::query::plan::AggregateKind,
416 ) -> Result<Value, QueryError>
417 where
418 E: PersistedRow<Canister = C> + EntityValue,
419 {
420 let target_slot = Self::prepared_sql_scalar_target_slot_required(
421 strategy,
422 "prepared extrema SQL aggregate strategy requires target slot",
423 )?;
424
425 self.execute_load_query_with(command.query(), |load, plan| {
426 load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
427 })
428 .map(|value| value.unwrap_or(Value::Null))
429 }
430
431 fn execute_prepared_sql_scalar_aggregate<E>(
435 &self,
436 command: &SqlGlobalAggregateCommand<E>,
437 ) -> Result<Value, QueryError>
438 where
439 E: PersistedRow<Canister = C> + EntityValue,
440 {
441 let strategy = command.prepared_scalar_strategy();
442
443 match strategy.runtime_descriptor() {
444 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
445 self.execute_prepared_sql_scalar_count_rows(command)
446 }
447 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
448 self.execute_prepared_sql_scalar_count_field(command, &strategy)
449 }
450 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
451 kind: crate::db::query::plan::AggregateKind::Sum,
452 } => self.execute_prepared_sql_scalar_numeric_field(
453 command,
454 &strategy,
455 ScalarNumericFieldBoundaryRequest::Sum,
456 "prepared SUM(field) SQL aggregate strategy requires target slot",
457 ),
458 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
459 kind: crate::db::query::plan::AggregateKind::Avg,
460 } => self.execute_prepared_sql_scalar_numeric_field(
461 command,
462 &strategy,
463 ScalarNumericFieldBoundaryRequest::Avg,
464 "prepared AVG(field) SQL aggregate strategy requires target slot",
465 ),
466 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
467 self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
468 }
469 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
470 Err(QueryError::invariant(
471 "prepared SQL scalar aggregate numeric runtime descriptor drift",
472 ))
473 }
474 }
475 }
476
477 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
482 where
483 E: PersistedRow<Canister = C> + EntityValue,
484 {
485 let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
488
489 match &statement {
492 SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
493 SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
494 return Err(QueryError::unsupported_query(
495 unsupported_sql_aggregate_grouped_message(),
496 ));
497 }
498 SqlStatement::Delete(_) => {
499 return Err(QueryError::unsupported_query(
500 "execute_sql_aggregate rejects DELETE; use delete::<E>()",
501 ));
502 }
503 _ => {
504 let route = sql_statement_route_from_statement(&statement);
505
506 return Err(QueryError::unsupported_query(
507 unsupported_sql_aggregate_surface_lane_message(&route),
508 ));
509 }
510 }
511
512 let prepared = prepare_sql_statement(statement, E::MODEL.name())
516 .map_err(QueryError::from_sql_lowering_error)?;
517 let command = compile_sql_global_aggregate_command_from_prepared::<E>(
518 prepared.clone(),
519 MissingRowPolicy::Ignore,
520 )
521 .map_err(QueryError::from_sql_lowering_error)?;
522 let strategy = command.prepared_scalar_strategy();
523
524 if strategy.is_distinct() {
527 let dispatch = compile_sql_global_aggregate_command_core_from_prepared(
528 prepared,
529 E::MODEL,
530 MissingRowPolicy::Ignore,
531 )
532 .map_err(QueryError::from_sql_lowering_error)?;
533 let authority = crate::db::executor::EntityAuthority::for_type::<E>();
534 let SqlDispatchResult::Projection { rows, .. } =
535 self.execute_sql_aggregate_dispatch_for_authority(dispatch, authority, None)?
536 else {
537 return Err(QueryError::invariant(
538 "DISTINCT SQL aggregate dispatch must finalize as one projection row",
539 ));
540 };
541 let Some(mut row) = rows.into_iter().next() else {
542 return Err(QueryError::invariant(
543 "DISTINCT SQL aggregate dispatch must emit one projection row",
544 ));
545 };
546 if row.len() != 1 {
547 return Err(QueryError::invariant(
548 "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
549 ));
550 }
551 let value = row.pop().ok_or_else(|| {
552 QueryError::invariant(
553 "DISTINCT SQL aggregate dispatch must emit exactly one projected value",
554 )
555 })?;
556
557 return Ok(value);
558 }
559
560 self.execute_prepared_sql_scalar_aggregate(&command)
564 }
565}