1use crate::{
7 db::{
8 DbSession, MissingRowPolicy, PersistedRow, QueryError,
9 executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
10 numeric::{
11 add_decimal_terms, average_decimal_terms, coerce_numeric_decimal,
12 compare_numeric_or_strict_order,
13 },
14 session::sql::surface::sql_statement_route_from_statement,
15 session::sql::{SqlDispatchResult, SqlParsedStatement, SqlStatementRoute},
16 sql::lowering::{
17 PreparedSqlScalarAggregateRuntimeDescriptor, PreparedSqlScalarAggregateStrategy,
18 SqlGlobalAggregateCommand, SqlGlobalAggregateCommandCore,
19 compile_sql_global_aggregate_command_core_from_prepared,
20 compile_sql_global_aggregate_command_from_prepared, is_sql_global_aggregate_statement,
21 prepare_sql_statement,
22 },
23 sql::parser::{SqlStatement, parse_sql},
24 },
25 traits::{CanisterKind, EntityValue},
26 value::Value,
27};
28
29#[derive(Clone, Copy, Debug, Eq, PartialEq)]
30pub(in crate::db::session::sql) enum SqlAggregateSurface {
31 QueryFrom,
32 ExecuteSql,
33 ExecuteSqlGrouped,
34}
35
36pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
37 parsed: &SqlParsedStatement,
38) -> bool {
39 is_sql_global_aggregate_statement(&parsed.statement)
40}
41
42pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
43 surface: SqlAggregateSurface,
44) -> &'static str {
45 match surface {
46 SqlAggregateSurface::QueryFrom => {
47 "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
48 }
49 SqlAggregateSurface::ExecuteSql => {
50 "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
51 }
52 SqlAggregateSurface::ExecuteSqlGrouped => {
53 "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
54 }
55 }
56}
57
58const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
59 match route {
60 SqlStatementRoute::Query { .. } => {
61 "execute_sql_aggregate requires constrained global aggregate SELECT"
62 }
63 SqlStatementRoute::Explain { .. } => {
64 "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
65 }
66 SqlStatementRoute::Describe { .. } => {
67 "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
68 }
69 SqlStatementRoute::ShowIndexes { .. } => {
70 "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
71 }
72 SqlStatementRoute::ShowColumns { .. } => {
73 "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
74 }
75 SqlStatementRoute::ShowEntities => {
76 "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
77 }
78 }
79}
80
/// Message `execute_sql_aggregate` reports for grouped `SELECT` statements,
/// which belong to the `execute_sql_grouped(...)` lane instead.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    MESSAGE
}
84
85impl<C: CanisterKind> DbSession<C> {
86 pub(in crate::db::session::sql) fn sql_scalar_aggregate_label(
89 strategy: &PreparedSqlScalarAggregateStrategy,
90 ) -> String {
91 let kind = match strategy.runtime_descriptor() {
92 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows
93 | PreparedSqlScalarAggregateRuntimeDescriptor::CountField => "COUNT",
94 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
95 kind: crate::db::query::plan::AggregateKind::Sum,
96 } => "SUM",
97 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
98 kind: crate::db::query::plan::AggregateKind::Avg,
99 } => "AVG",
100 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
101 kind: crate::db::query::plan::AggregateKind::Min,
102 } => "MIN",
103 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
104 kind: crate::db::query::plan::AggregateKind::Max,
105 } => "MAX",
106 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
107 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
108 unreachable!("prepared SQL scalar aggregate strategy drifted outside SQL support")
109 }
110 };
111
112 match strategy.projected_field() {
113 Some(field) => format!("{kind}({field})"),
114 None => format!("{kind}(*)"),
115 }
116 }
117
118 fn reduce_structural_sql_aggregate_field_values(
121 values: Vec<Value>,
122 strategy: &PreparedSqlScalarAggregateStrategy,
123 ) -> Result<Value, QueryError> {
124 match strategy.runtime_descriptor() {
125 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => Err(QueryError::invariant(
126 "COUNT(*) structural reduction does not consume projected field values",
127 )),
128 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
129 let count = values
130 .into_iter()
131 .filter(|value| !matches!(value, Value::Null))
132 .count();
133
134 Ok(Value::Uint(u64::try_from(count).unwrap_or(u64::MAX)))
135 }
136 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
137 kind:
138 crate::db::query::plan::AggregateKind::Sum
139 | crate::db::query::plan::AggregateKind::Avg,
140 } => {
141 let mut sum = None;
142 let mut row_count = 0_u64;
143
144 for value in values {
145 if matches!(value, Value::Null) {
146 continue;
147 }
148
149 let decimal = coerce_numeric_decimal(&value).ok_or_else(|| {
150 QueryError::invariant(
151 "numeric SQL aggregate dispatch encountered non-numeric projected value",
152 )
153 })?;
154 sum = Some(sum.map_or(decimal, |current| add_decimal_terms(current, decimal)));
155 row_count = row_count.saturating_add(1);
156 }
157
158 match strategy.runtime_descriptor() {
159 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
160 kind: crate::db::query::plan::AggregateKind::Sum,
161 } => Ok(sum.map_or(Value::Null, Value::Decimal)),
162 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
163 kind: crate::db::query::plan::AggregateKind::Avg,
164 } => Ok(sum
165 .and_then(|sum| average_decimal_terms(sum, row_count))
166 .map_or(Value::Null, Value::Decimal)),
167 _ => unreachable!("numeric SQL aggregate strategy drifted during reduction"),
168 }
169 }
170 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
171 kind:
172 crate::db::query::plan::AggregateKind::Min
173 | crate::db::query::plan::AggregateKind::Max,
174 } => {
175 let mut selected = None::<Value>;
176
177 for value in values {
178 if matches!(value, Value::Null) {
179 continue;
180 }
181
182 let replace = match selected.as_ref() {
183 None => true,
184 Some(current) => {
185 let ordering =
186 compare_numeric_or_strict_order(&value, current).ok_or_else(
187 || {
188 QueryError::invariant(
189 "extrema SQL aggregate dispatch encountered incomparable projected values",
190 )
191 },
192 )?;
193
194 match strategy.runtime_descriptor() {
195 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
196 kind: crate::db::query::plan::AggregateKind::Min,
197 } => ordering.is_lt(),
198 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField {
199 kind: crate::db::query::plan::AggregateKind::Max,
200 } => ordering.is_gt(),
201 _ => unreachable!(
202 "extrema SQL aggregate strategy drifted during reduction"
203 ),
204 }
205 }
206 };
207
208 if replace {
209 selected = Some(value);
210 }
211 }
212
213 Ok(selected.unwrap_or(Value::Null))
214 }
215 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
216 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
217 Err(QueryError::invariant(
218 "prepared SQL scalar aggregate strategy drifted outside SQL support",
219 ))
220 }
221 }
222 }
223
224 fn execute_structural_sql_aggregate_field_projection(
227 &self,
228 query: crate::db::query::intent::StructuralQuery,
229 authority: crate::db::executor::EntityAuthority,
230 ) -> Result<Vec<Value>, QueryError> {
231 let (_, rows, _) = self
232 .execute_structural_sql_projection(query, authority)?
233 .into_parts();
234 let mut projected = Vec::with_capacity(rows.len());
235
236 for row in rows {
237 let [value] = row.as_slice() else {
238 return Err(QueryError::invariant(
239 "structural SQL aggregate projection must emit exactly one field",
240 ));
241 };
242
243 projected.push(value.clone());
244 }
245
246 Ok(projected)
247 }
248
249 pub(in crate::db::session::sql) fn execute_sql_aggregate_dispatch_for_authority(
253 &self,
254 command: SqlGlobalAggregateCommandCore,
255 authority: crate::db::executor::EntityAuthority,
256 ) -> Result<SqlDispatchResult, QueryError> {
257 let model = authority.model();
258 let strategy = command
259 .prepared_scalar_strategy_with_model(model)
260 .map_err(QueryError::from_sql_lowering_error)?;
261 let label = Self::sql_scalar_aggregate_label(&strategy);
262 let value = match strategy.runtime_descriptor() {
263 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
264 let (_, _, row_count) = self
265 .execute_structural_sql_projection(
266 command
267 .query()
268 .clone()
269 .select_fields([authority.primary_key_name()]),
270 authority,
271 )?
272 .into_parts();
273
274 Value::Uint(u64::from(row_count))
275 }
276 PreparedSqlScalarAggregateRuntimeDescriptor::CountField
277 | PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. }
278 | PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { .. } => {
279 let Some(field) = strategy.projected_field() else {
280 return Err(QueryError::invariant(
281 "field-target SQL aggregate strategy requires projected field label",
282 ));
283 };
284 let values = self.execute_structural_sql_aggregate_field_projection(
285 command.query().clone().select_fields([field]),
286 authority,
287 )?;
288
289 Self::reduce_structural_sql_aggregate_field_values(values, &strategy)?
290 }
291 };
292
293 Ok(SqlDispatchResult::Projection {
294 columns: vec![label],
295 rows: vec![vec![value]],
296 row_count: 1,
297 })
298 }
299
300 pub(in crate::db::session::sql) fn compile_sql_aggregate_command_core_for_authority(
303 parsed: &SqlParsedStatement,
304 authority: crate::db::executor::EntityAuthority,
305 ) -> Result<SqlGlobalAggregateCommandCore, QueryError> {
306 compile_sql_global_aggregate_command_core_from_prepared(
307 parsed.prepare(authority.model().name())?,
308 authority.model(),
309 MissingRowPolicy::Ignore,
310 )
311 .map_err(QueryError::from_sql_lowering_error)
312 }
313
314 fn prepared_sql_scalar_target_slot_required(
317 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
318 message: &'static str,
319 ) -> Result<crate::db::query::plan::FieldSlot, QueryError> {
320 strategy
321 .target_slot()
322 .cloned()
323 .ok_or_else(|| QueryError::invariant(message))
324 }
325
326 fn execute_prepared_sql_scalar_count_rows<E>(
329 &self,
330 command: &SqlGlobalAggregateCommand<E>,
331 ) -> Result<Value, QueryError>
332 where
333 E: PersistedRow<Canister = C> + EntityValue,
334 {
335 self.execute_load_query_with(command.query(), |load, plan| {
336 load.execute_scalar_terminal_request(
337 plan,
338 crate::db::executor::ScalarTerminalBoundaryRequest::Count,
339 )?
340 .into_count()
341 })
342 .map(|count| Value::Uint(u64::from(count)))
343 }
344
345 fn execute_prepared_sql_scalar_count_field<E>(
348 &self,
349 command: &SqlGlobalAggregateCommand<E>,
350 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
351 ) -> Result<Value, QueryError>
352 where
353 E: PersistedRow<Canister = C> + EntityValue,
354 {
355 let target_slot = Self::prepared_sql_scalar_target_slot_required(
356 strategy,
357 "prepared COUNT(field) SQL aggregate strategy requires target slot",
358 )?;
359
360 self.execute_load_query_with(command.query(), |load, plan| {
361 load.execute_scalar_projection_boundary(
362 plan,
363 target_slot.clone(),
364 ScalarProjectionBoundaryRequest::CountNonNull,
365 )?
366 .into_count()
367 })
368 .map(|count| Value::Uint(u64::from(count)))
369 }
370
371 fn execute_prepared_sql_scalar_numeric_field<E>(
374 &self,
375 command: &SqlGlobalAggregateCommand<E>,
376 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
377 request: ScalarNumericFieldBoundaryRequest,
378 message: &'static str,
379 ) -> Result<Value, QueryError>
380 where
381 E: PersistedRow<Canister = C> + EntityValue,
382 {
383 let target_slot = Self::prepared_sql_scalar_target_slot_required(strategy, message)?;
384
385 self.execute_load_query_with(command.query(), |load, plan| {
386 load.execute_numeric_field_boundary(plan, target_slot.clone(), request)
387 })
388 .map(|value| value.map_or(Value::Null, Value::Decimal))
389 }
390
391 fn execute_prepared_sql_scalar_extrema_field<E>(
394 &self,
395 command: &SqlGlobalAggregateCommand<E>,
396 strategy: &crate::db::sql::lowering::PreparedSqlScalarAggregateStrategy,
397 kind: crate::db::query::plan::AggregateKind,
398 ) -> Result<Value, QueryError>
399 where
400 E: PersistedRow<Canister = C> + EntityValue,
401 {
402 let target_slot = Self::prepared_sql_scalar_target_slot_required(
403 strategy,
404 "prepared extrema SQL aggregate strategy requires target slot",
405 )?;
406
407 self.execute_load_query_with(command.query(), |load, plan| {
408 load.execute_scalar_extrema_value_boundary(plan, target_slot.clone(), kind)
409 })
410 .map(|value| value.unwrap_or(Value::Null))
411 }
412
413 fn execute_prepared_sql_scalar_aggregate<E>(
417 &self,
418 command: &SqlGlobalAggregateCommand<E>,
419 ) -> Result<Value, QueryError>
420 where
421 E: PersistedRow<Canister = C> + EntityValue,
422 {
423 let strategy = command.prepared_scalar_strategy();
424
425 match strategy.runtime_descriptor() {
426 PreparedSqlScalarAggregateRuntimeDescriptor::CountRows => {
427 self.execute_prepared_sql_scalar_count_rows(command)
428 }
429 PreparedSqlScalarAggregateRuntimeDescriptor::CountField => {
430 self.execute_prepared_sql_scalar_count_field(command, &strategy)
431 }
432 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
433 kind: crate::db::query::plan::AggregateKind::Sum,
434 } => self.execute_prepared_sql_scalar_numeric_field(
435 command,
436 &strategy,
437 ScalarNumericFieldBoundaryRequest::Sum,
438 "prepared SUM(field) SQL aggregate strategy requires target slot",
439 ),
440 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField {
441 kind: crate::db::query::plan::AggregateKind::Avg,
442 } => self.execute_prepared_sql_scalar_numeric_field(
443 command,
444 &strategy,
445 ScalarNumericFieldBoundaryRequest::Avg,
446 "prepared AVG(field) SQL aggregate strategy requires target slot",
447 ),
448 PreparedSqlScalarAggregateRuntimeDescriptor::ExtremalWinnerField { kind } => {
449 self.execute_prepared_sql_scalar_extrema_field(command, &strategy, kind)
450 }
451 PreparedSqlScalarAggregateRuntimeDescriptor::NumericField { .. } => {
452 Err(QueryError::invariant(
453 "prepared SQL scalar aggregate numeric runtime descriptor drift",
454 ))
455 }
456 }
457 }
458
459 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
464 where
465 E: PersistedRow<Canister = C> + EntityValue,
466 {
467 let statement = parse_sql(sql).map_err(QueryError::from_sql_parse_error)?;
470
471 match &statement {
474 SqlStatement::Select(_) if is_sql_global_aggregate_statement(&statement) => {}
475 SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
476 return Err(QueryError::unsupported_query(
477 unsupported_sql_aggregate_grouped_message(),
478 ));
479 }
480 SqlStatement::Delete(_) => {
481 return Err(QueryError::unsupported_query(
482 "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
483 ));
484 }
485 _ => {
486 let route = sql_statement_route_from_statement(&statement);
487
488 return Err(QueryError::unsupported_query(
489 unsupported_sql_aggregate_surface_lane_message(&route),
490 ));
491 }
492 }
493
494 let command = compile_sql_global_aggregate_command_from_prepared::<E>(
498 prepare_sql_statement(statement, E::MODEL.name())
499 .map_err(QueryError::from_sql_lowering_error)?,
500 MissingRowPolicy::Ignore,
501 )
502 .map_err(QueryError::from_sql_lowering_error)?;
503
504 self.execute_prepared_sql_scalar_aggregate(&command)
508 }
509}