1use crate::{
2 db::{
3 DbSession, EntityResponse, EntitySchemaDescription, MissingRowPolicy,
4 PagedGroupedExecutionWithTrace, ProjectionResponse, Query, QueryError,
5 query::{
6 builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7 intent::IntentError,
8 plan::{
9 AggregateKind, FieldSlot, QueryMode,
10 expr::{Expr, ProjectionField},
11 },
12 },
13 sql::lowering::{
14 SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15 compile_sql_command, compile_sql_global_aggregate_command,
16 },
17 sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18 },
19 error::{ErrorClass, ErrorOrigin, InternalError},
20 traits::{CanisterKind, EntityKind, EntityValue},
21 value::Value,
22};
23
/// Routing summary for one parsed SQL statement.
///
/// Every variant except [`SqlStatementRoute::ShowEntities`] carries the entity name taken
/// from the parsed statement.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlStatementRoute {
    /// A `SELECT` or `DELETE` statement targeting `entity`.
    Query { entity: String },
    /// An `EXPLAIN` statement wrapping a `SELECT` or `DELETE` on `entity`.
    Explain { entity: String },
    /// A `DESCRIBE` statement for `entity`.
    Describe { entity: String },
    /// A `SHOW INDEXES` statement for `entity`.
    ShowIndexes { entity: String },
    /// A `SHOW ENTITIES` statement; it names no entity.
    ShowEntities,
}
39
40impl SqlStatementRoute {
41 #[must_use]
46 pub const fn entity(&self) -> &str {
47 match self {
48 Self::Query { entity }
49 | Self::Explain { entity }
50 | Self::Describe { entity }
51 | Self::ShowIndexes { entity } => entity.as_str(),
52 Self::ShowEntities => "",
53 }
54 }
55
56 #[must_use]
58 pub const fn is_explain(&self) -> bool {
59 matches!(self, Self::Explain { .. })
60 }
61
62 #[must_use]
64 pub const fn is_describe(&self) -> bool {
65 matches!(self, Self::Describe { .. })
66 }
67
68 #[must_use]
70 pub const fn is_show_indexes(&self) -> bool {
71 matches!(self, Self::ShowIndexes { .. })
72 }
73
74 #[must_use]
76 pub const fn is_show_entities(&self) -> bool {
77 matches!(self, Self::ShowEntities)
78 }
79}
80
81fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
83 match err {
84 SqlLoweringError::Query(err) => err,
85 SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
86 feature,
87 }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
88 other => QueryError::execute(InternalError::classified(
89 ErrorClass::Unsupported,
90 ErrorOrigin::Query,
91 format!("SQL query is not executable in this release: {other}"),
92 )),
93 }
94}
95
96fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
99 map_sql_lowering_error(SqlLoweringError::Parse(err))
100}
101
102fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
105 FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
106 QueryError::execute(crate::db::error::executor_unsupported(format!(
107 "unknown aggregate target field: {field}",
108 )))
109 })
110}
111
112fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
115 terminal: &SqlGlobalAggregateTerminal,
116) -> Result<AggregateExpr, QueryError> {
117 match terminal {
118 SqlGlobalAggregateTerminal::CountRows => Ok(count()),
119 SqlGlobalAggregateTerminal::CountField(field) => {
120 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
121
122 Ok(count_by(field.as_str()))
123 }
124 SqlGlobalAggregateTerminal::SumField(field) => {
125 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
126
127 Ok(sum(field.as_str()))
128 }
129 SqlGlobalAggregateTerminal::AvgField(field) => {
130 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
131
132 Ok(avg(field.as_str()))
133 }
134 SqlGlobalAggregateTerminal::MinField(field) => {
135 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
136
137 Ok(min_by(field.as_str()))
138 }
139 SqlGlobalAggregateTerminal::MaxField(field) => {
140 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
141
142 Ok(max_by(field.as_str()))
143 }
144 }
145}
146
147fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
149 let kind = match aggregate.kind() {
150 AggregateKind::Count => "COUNT",
151 AggregateKind::Sum => "SUM",
152 AggregateKind::Avg => "AVG",
153 AggregateKind::Exists => "EXISTS",
154 AggregateKind::First => "FIRST",
155 AggregateKind::Last => "LAST",
156 AggregateKind::Min => "MIN",
157 AggregateKind::Max => "MAX",
158 };
159 let distinct = if aggregate.is_distinct() {
160 "DISTINCT "
161 } else {
162 ""
163 };
164
165 if let Some(field) = aggregate.target_field() {
166 return format!("{kind}({distinct}{field})");
167 }
168
169 format!("{kind}({distinct}*)")
170}
171
172fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
174 match expr {
175 Expr::Field(field) => field.as_str().to_string(),
176 Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
177 Expr::Alias { name, .. } => name.as_str().to_string(),
178 Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
179 format!("expr_{ordinal}")
180 }
181 }
182}
183
184fn projection_labels_from_query<E: EntityKind>(
186 query: &Query<E>,
187) -> Result<Vec<String>, QueryError> {
188 let projection = query.plan()?.projection_spec();
189 let mut labels = Vec::with_capacity(projection.len());
190
191 for (ordinal, field) in projection.fields().enumerate() {
192 match field {
193 ProjectionField::Scalar {
194 expr: _,
195 alias: Some(alias),
196 } => labels.push(alias.as_str().to_string()),
197 ProjectionField::Scalar { expr, alias: None } => {
198 labels.push(projection_label_from_expr(expr, ordinal));
199 }
200 }
201 }
202
203 Ok(labels)
204}
205
206impl<C: CanisterKind> DbSession<C> {
207 pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
212 let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
213 match statement {
214 SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
215 entity: select.entity,
216 }),
217 SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
218 entity: delete.entity,
219 }),
220 SqlStatement::Explain(explain) => match explain.statement {
221 SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
222 entity: select.entity,
223 }),
224 SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
225 entity: delete.entity,
226 }),
227 },
228 SqlStatement::Describe(describe) => Ok(SqlStatementRoute::Describe {
229 entity: describe.entity,
230 }),
231 SqlStatement::ShowIndexes(show_indexes) => Ok(SqlStatementRoute::ShowIndexes {
232 entity: show_indexes.entity,
233 }),
234 SqlStatement::ShowEntities(_) => Ok(SqlStatementRoute::ShowEntities),
235 }
236 }
237
238 pub fn describe_sql<E>(&self, sql: &str) -> Result<EntitySchemaDescription, QueryError>
240 where
241 E: EntityKind<Canister = C>,
242 {
243 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
244 .map_err(map_sql_lowering_error)?;
245
246 match command {
247 SqlCommand::DescribeEntity => Ok(self.describe_entity::<E>()),
248 SqlCommand::Query(_)
249 | SqlCommand::Explain { .. }
250 | SqlCommand::ExplainGlobalAggregate { .. }
251 | SqlCommand::ShowIndexesEntity
252 | SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
253 ErrorClass::Unsupported,
254 ErrorOrigin::Query,
255 "describe_sql requires a DESCRIBE statement",
256 ))),
257 }
258 }
259
260 pub fn show_indexes_sql<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
262 where
263 E: EntityKind<Canister = C>,
264 {
265 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
266 .map_err(map_sql_lowering_error)?;
267
268 match command {
269 SqlCommand::ShowIndexesEntity => Ok(self.show_indexes::<E>()),
270 SqlCommand::Query(_)
271 | SqlCommand::Explain { .. }
272 | SqlCommand::ExplainGlobalAggregate { .. }
273 | SqlCommand::DescribeEntity
274 | SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
275 ErrorClass::Unsupported,
276 ErrorOrigin::Query,
277 "show_indexes_sql requires a SHOW INDEXES statement",
278 ))),
279 }
280 }
281
282 pub fn show_entities_sql(&self, sql: &str) -> Result<Vec<String>, QueryError> {
284 let statement = self.sql_statement_route(sql)?;
285 if !statement.is_show_entities() {
286 return Err(QueryError::execute(InternalError::classified(
287 ErrorClass::Unsupported,
288 ErrorOrigin::Query,
289 "show_entities_sql requires a SHOW ENTITIES statement",
290 )));
291 }
292
293 Ok(self.show_entities())
294 }
295
296 pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
301 where
302 E: EntityKind<Canister = C>,
303 {
304 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
305 .map_err(map_sql_lowering_error)?;
306
307 match command {
308 SqlCommand::Query(query) => Ok(query),
309 SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
310 Err(QueryError::execute(InternalError::classified(
311 ErrorClass::Unsupported,
312 ErrorOrigin::Query,
313 "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)",
314 )))
315 }
316 SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
317 ErrorClass::Unsupported,
318 ErrorOrigin::Query,
319 "query_from_sql does not accept DESCRIBE statements; use describe_sql(...)",
320 ))),
321 SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
322 ErrorClass::Unsupported,
323 ErrorOrigin::Query,
324 "query_from_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
325 ))),
326 SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
327 ErrorClass::Unsupported,
328 ErrorOrigin::Query,
329 "query_from_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)",
330 ))),
331 }
332 }
333
334 pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
336 where
337 E: EntityKind<Canister = C>,
338 {
339 let query = self.query_from_sql::<E>(sql)?;
340 if query.has_grouping() {
341 return Err(QueryError::Intent(
342 IntentError::GroupedRequiresExecuteGrouped,
343 ));
344 }
345
346 match query.mode() {
347 QueryMode::Load(_) => projection_labels_from_query(&query),
348 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
349 ErrorClass::Unsupported,
350 ErrorOrigin::Query,
351 "sql_projection_columns only supports SELECT statements",
352 ))),
353 }
354 }
355
356 pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
358 where
359 E: EntityKind<Canister = C> + EntityValue,
360 {
361 let query = self.query_from_sql::<E>(sql)?;
362 if query.has_grouping() {
363 return Err(QueryError::Intent(
364 IntentError::GroupedRequiresExecuteGrouped,
365 ));
366 }
367
368 self.execute_query(&query)
369 }
370
371 pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
376 where
377 E: EntityKind<Canister = C> + EntityValue,
378 {
379 let query = self.query_from_sql::<E>(sql)?;
380 if query.has_grouping() {
381 return Err(QueryError::Intent(
382 IntentError::GroupedRequiresExecuteGrouped,
383 ));
384 }
385
386 match query.mode() {
387 QueryMode::Load(_) => {
388 self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
389 }
390 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
391 ErrorClass::Unsupported,
392 ErrorOrigin::Query,
393 "execute_sql_projection only supports SELECT statements",
394 ))),
395 }
396 }
397
398 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
403 where
404 E: EntityKind<Canister = C> + EntityValue,
405 {
406 let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
407 .map_err(map_sql_lowering_error)?;
408
409 match command.terminal() {
410 SqlGlobalAggregateTerminal::CountRows => self
411 .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
412 .map(|count| Value::Uint(u64::from(count))),
413 SqlGlobalAggregateTerminal::CountField(field) => {
414 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
415 self.execute_load_query_with(command.query(), |load, plan| {
416 load.values_by_slot(plan, target_slot)
417 })
418 .map(|values| {
419 let count = values
420 .into_iter()
421 .filter(|value| !matches!(value, Value::Null))
422 .count();
423 Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
424 })
425 }
426 SqlGlobalAggregateTerminal::SumField(field) => {
427 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
428 self.execute_load_query_with(command.query(), |load, plan| {
429 load.aggregate_sum_by_slot(plan, target_slot)
430 })
431 .map(|value| value.map_or(Value::Null, Value::Decimal))
432 }
433 SqlGlobalAggregateTerminal::AvgField(field) => {
434 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
435 self.execute_load_query_with(command.query(), |load, plan| {
436 load.aggregate_avg_by_slot(plan, target_slot)
437 })
438 .map(|value| value.map_or(Value::Null, Value::Decimal))
439 }
440 SqlGlobalAggregateTerminal::MinField(field) => {
441 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
442 let min_id = self.execute_load_query_with(command.query(), |load, plan| {
443 load.aggregate_min_by_slot(plan, target_slot)
444 })?;
445
446 match min_id {
447 Some(id) => self
448 .load::<E>()
449 .by_id(id)
450 .first_value_by(field)
451 .map(|value| value.unwrap_or(Value::Null)),
452 None => Ok(Value::Null),
453 }
454 }
455 SqlGlobalAggregateTerminal::MaxField(field) => {
456 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
457 let max_id = self.execute_load_query_with(command.query(), |load, plan| {
458 load.aggregate_max_by_slot(plan, target_slot)
459 })?;
460
461 match max_id {
462 Some(id) => self
463 .load::<E>()
464 .by_id(id)
465 .first_value_by(field)
466 .map(|value| value.unwrap_or(Value::Null)),
467 None => Ok(Value::Null),
468 }
469 }
470 }
471 }
472
473 pub fn execute_sql_grouped<E>(
475 &self,
476 sql: &str,
477 cursor_token: Option<&str>,
478 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
479 where
480 E: EntityKind<Canister = C> + EntityValue,
481 {
482 let query = self.query_from_sql::<E>(sql)?;
483 if !query.has_grouping() {
484 return Err(QueryError::execute(InternalError::classified(
485 ErrorClass::Unsupported,
486 ErrorOrigin::Query,
487 "execute_sql_grouped requires grouped SQL query intent",
488 )));
489 }
490
491 self.execute_grouped(&query, cursor_token)
492 }
493
494 pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
501 where
502 E: EntityKind<Canister = C> + EntityValue,
503 {
504 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
505 .map_err(map_sql_lowering_error)?;
506
507 match command {
508 SqlCommand::Query(_) => Err(QueryError::execute(InternalError::classified(
509 ErrorClass::Unsupported,
510 ErrorOrigin::Query,
511 "explain_sql requires an EXPLAIN statement",
512 ))),
513 SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
514 ErrorClass::Unsupported,
515 ErrorOrigin::Query,
516 "explain_sql does not accept DESCRIBE statements; use describe_sql(...)",
517 ))),
518 SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
519 ErrorClass::Unsupported,
520 ErrorOrigin::Query,
521 "explain_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
522 ))),
523 SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
524 ErrorClass::Unsupported,
525 ErrorOrigin::Query,
526 "explain_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)",
527 ))),
528 SqlCommand::Explain { mode, query } => match mode {
529 SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
530 SqlExplainMode::Execution => query.explain_execution_text(),
531 SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
532 },
533 SqlCommand::ExplainGlobalAggregate { mode, command } => {
534 Self::explain_sql_global_aggregate::<E>(mode, command)
535 }
536 }
537 }
538
539 fn explain_sql_global_aggregate<E>(
541 mode: SqlExplainMode,
542 command: SqlGlobalAggregateCommand<E>,
543 ) -> Result<String, QueryError>
544 where
545 E: EntityKind<Canister = C> + EntityValue,
546 {
547 match mode {
548 SqlExplainMode::Plan => {
549 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
552
553 Ok(command.query().explain()?.render_text_canonical())
554 }
555 SqlExplainMode::Execution => {
556 let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
557 let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
558
559 Ok(plan.execution_node_descriptor().render_text_tree())
560 }
561 SqlExplainMode::Json => {
562 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
565
566 Ok(command.query().explain()?.render_json_canonical())
567 }
568 }
569 }
570}