1use crate::{
2 db::{
3 DbSession, EntityFieldDescription, EntityResponse, EntitySchemaDescription,
4 MissingRowPolicy, PagedGroupedExecutionWithTrace, ProjectionResponse, Query, QueryError,
5 query::{
6 builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7 intent::IntentError,
8 plan::{
9 AggregateKind, FieldSlot, QueryMode,
10 expr::{Expr, ProjectionField},
11 },
12 },
13 sql::lowering::{
14 SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15 compile_sql_command, compile_sql_global_aggregate_command,
16 },
17 sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18 },
19 error::{ErrorClass, ErrorOrigin, InternalError},
20 traits::{CanisterKind, EntityKind, EntityValue},
21 value::Value,
22};
23
/// Classification of a parsed SQL statement by the session entrypoint that handles it.
///
/// Built by `DbSession::sql_statement_route`. Every entity-scoped variant carries the
/// entity name reported by the SQL parser for that statement.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlStatementRoute {
    /// A `SELECT` or `DELETE` statement.
    Query {
        /// Entity named by the statement.
        entity: String,
    },
    /// An `EXPLAIN` wrapping a `SELECT` or `DELETE` statement.
    Explain {
        /// Entity named by the explained statement.
        entity: String,
    },
    /// A `DESCRIBE` statement.
    Describe {
        /// Entity being described.
        entity: String,
    },
    /// A `SHOW INDEXES` statement.
    ShowIndexes {
        /// Entity whose indexes are listed.
        entity: String,
    },
    /// A `SHOW COLUMNS` statement.
    ShowColumns {
        /// Entity whose columns are listed.
        entity: String,
    },
    /// A `SHOW ENTITIES` statement; it is not scoped to any single entity.
    ShowEntities,
}
40
41impl SqlStatementRoute {
42 #[must_use]
47 pub const fn entity(&self) -> &str {
48 match self {
49 Self::Query { entity }
50 | Self::Explain { entity }
51 | Self::Describe { entity }
52 | Self::ShowIndexes { entity }
53 | Self::ShowColumns { entity } => entity.as_str(),
54 Self::ShowEntities => "",
55 }
56 }
57
58 #[must_use]
60 pub const fn is_explain(&self) -> bool {
61 matches!(self, Self::Explain { .. })
62 }
63
64 #[must_use]
66 pub const fn is_describe(&self) -> bool {
67 matches!(self, Self::Describe { .. })
68 }
69
70 #[must_use]
72 pub const fn is_show_indexes(&self) -> bool {
73 matches!(self, Self::ShowIndexes { .. })
74 }
75
76 #[must_use]
78 pub const fn is_show_columns(&self) -> bool {
79 matches!(self, Self::ShowColumns { .. })
80 }
81
82 #[must_use]
84 pub const fn is_show_entities(&self) -> bool {
85 matches!(self, Self::ShowEntities)
86 }
87}
88
/// Lane a compiled SQL command (or a parsed statement route) belongs to.
///
/// Session entrypoints compare the lane of the incoming statement against the lane
/// they serve, and use it to choose the diagnostic when the two differ.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqlLaneKind {
    /// Executable `SELECT` / `DELETE`.
    Query,
    /// Any `EXPLAIN` form.
    Explain,
    /// `DESCRIBE`.
    Describe,
    /// `SHOW INDEXES`.
    ShowIndexes,
    /// `SHOW COLUMNS`.
    ShowColumns,
    /// `SHOW ENTITIES`.
    ShowEntities,
}
99
/// Public `DbSession` SQL entrypoint that received a statement.
///
/// Paired with an `SqlLaneKind` to select the rejection message for a statement that
/// belongs to a different entrypoint.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqlSurface {
    /// `query_from_sql` (and the executors built on it).
    QueryFrom,
    /// `explain_sql`.
    Explain,
    /// `describe_sql`.
    Describe,
    /// `show_indexes_sql`.
    ShowIndexes,
    /// `show_columns_sql`.
    ShowColumns,
    /// `show_entities_sql`.
    ShowEntities,
}
110
111const fn sql_command_lane<E: EntityKind>(command: &SqlCommand<E>) -> SqlLaneKind {
113 match command {
114 SqlCommand::Query(_) => SqlLaneKind::Query,
115 SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
116 SqlLaneKind::Explain
117 }
118 SqlCommand::DescribeEntity => SqlLaneKind::Describe,
119 SqlCommand::ShowIndexesEntity => SqlLaneKind::ShowIndexes,
120 SqlCommand::ShowColumnsEntity => SqlLaneKind::ShowColumns,
121 SqlCommand::ShowEntities => SqlLaneKind::ShowEntities,
122 }
123}
124
125const fn sql_statement_route_lane(route: &SqlStatementRoute) -> SqlLaneKind {
127 match route {
128 SqlStatementRoute::Query { .. } => SqlLaneKind::Query,
129 SqlStatementRoute::Explain { .. } => SqlLaneKind::Explain,
130 SqlStatementRoute::Describe { .. } => SqlLaneKind::Describe,
131 SqlStatementRoute::ShowIndexes { .. } => SqlLaneKind::ShowIndexes,
132 SqlStatementRoute::ShowColumns { .. } => SqlLaneKind::ShowColumns,
133 SqlStatementRoute::ShowEntities => SqlLaneKind::ShowEntities,
134 }
135}
136
137const fn unsupported_sql_lane_message(surface: SqlSurface, lane: SqlLaneKind) -> &'static str {
139 match (surface, lane) {
140 (SqlSurface::QueryFrom, SqlLaneKind::Explain) => {
141 "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)"
142 }
143 (SqlSurface::QueryFrom, SqlLaneKind::Describe) => {
144 "query_from_sql does not accept DESCRIBE statements; use describe_sql(...)"
145 }
146 (SqlSurface::QueryFrom, SqlLaneKind::ShowIndexes) => {
147 "query_from_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)"
148 }
149 (SqlSurface::QueryFrom, SqlLaneKind::ShowColumns) => {
150 "query_from_sql does not accept SHOW COLUMNS statements; use show_columns_sql(...)"
151 }
152 (SqlSurface::QueryFrom, SqlLaneKind::ShowEntities) => {
153 "query_from_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)"
154 }
155 (SqlSurface::QueryFrom, SqlLaneKind::Query) => {
156 "query_from_sql requires one executable SELECT or DELETE statement"
157 }
158 (SqlSurface::Explain, SqlLaneKind::Describe) => {
159 "explain_sql does not accept DESCRIBE statements; use describe_sql(...)"
160 }
161 (SqlSurface::Explain, SqlLaneKind::ShowIndexes) => {
162 "explain_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)"
163 }
164 (SqlSurface::Explain, SqlLaneKind::ShowColumns) => {
165 "explain_sql does not accept SHOW COLUMNS statements; use show_columns_sql(...)"
166 }
167 (SqlSurface::Explain, SqlLaneKind::ShowEntities) => {
168 "explain_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)"
169 }
170 (SqlSurface::Explain, SqlLaneKind::Query | SqlLaneKind::Explain) => {
171 "explain_sql requires an EXPLAIN statement"
172 }
173 (SqlSurface::Describe, _) => "describe_sql requires a DESCRIBE statement",
174 (SqlSurface::ShowIndexes, _) => "show_indexes_sql requires a SHOW INDEXES statement",
175 (SqlSurface::ShowColumns, _) => "show_columns_sql requires a SHOW COLUMNS statement",
176 (SqlSurface::ShowEntities, _) => "show_entities_sql requires a SHOW ENTITIES statement",
177 }
178}
179
180fn unsupported_sql_lane_error(surface: SqlSurface, lane: SqlLaneKind) -> QueryError {
182 QueryError::execute(InternalError::classified(
183 ErrorClass::Unsupported,
184 ErrorOrigin::Query,
185 unsupported_sql_lane_message(surface, lane),
186 ))
187}
188
189fn compile_sql_command_ignore<E: EntityKind>(sql: &str) -> Result<SqlCommand<E>, QueryError> {
191 compile_sql_command::<E>(sql, MissingRowPolicy::Ignore).map_err(map_sql_lowering_error)
192}
193
194fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
196 match err {
197 SqlLoweringError::Query(err) => err,
198 SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
199 feature,
200 }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
201 other => QueryError::execute(InternalError::classified(
202 ErrorClass::Unsupported,
203 ErrorOrigin::Query,
204 format!("SQL query is not executable in this release: {other}"),
205 )),
206 }
207}
208
209fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
212 map_sql_lowering_error(SqlLoweringError::Parse(err))
213}
214
215fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
218 FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
219 QueryError::execute(crate::db::error::executor_unsupported(format!(
220 "unknown aggregate target field: {field}",
221 )))
222 })
223}
224
225fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
228 terminal: &SqlGlobalAggregateTerminal,
229) -> Result<AggregateExpr, QueryError> {
230 match terminal {
231 SqlGlobalAggregateTerminal::CountRows => Ok(count()),
232 SqlGlobalAggregateTerminal::CountField(field) => {
233 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
234
235 Ok(count_by(field.as_str()))
236 }
237 SqlGlobalAggregateTerminal::SumField(field) => {
238 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
239
240 Ok(sum(field.as_str()))
241 }
242 SqlGlobalAggregateTerminal::AvgField(field) => {
243 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
244
245 Ok(avg(field.as_str()))
246 }
247 SqlGlobalAggregateTerminal::MinField(field) => {
248 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
249
250 Ok(min_by(field.as_str()))
251 }
252 SqlGlobalAggregateTerminal::MaxField(field) => {
253 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
254
255 Ok(max_by(field.as_str()))
256 }
257 }
258}
259
260fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
262 let kind = match aggregate.kind() {
263 AggregateKind::Count => "COUNT",
264 AggregateKind::Sum => "SUM",
265 AggregateKind::Avg => "AVG",
266 AggregateKind::Exists => "EXISTS",
267 AggregateKind::First => "FIRST",
268 AggregateKind::Last => "LAST",
269 AggregateKind::Min => "MIN",
270 AggregateKind::Max => "MAX",
271 };
272 let distinct = if aggregate.is_distinct() {
273 "DISTINCT "
274 } else {
275 ""
276 };
277
278 if let Some(field) = aggregate.target_field() {
279 return format!("{kind}({distinct}{field})");
280 }
281
282 format!("{kind}({distinct}*)")
283}
284
285fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
287 match expr {
288 Expr::Field(field) => field.as_str().to_string(),
289 Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
290 Expr::Alias { name, .. } => name.as_str().to_string(),
291 Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
292 format!("expr_{ordinal}")
293 }
294 }
295}
296
297fn projection_labels_from_query<E: EntityKind>(
299 query: &Query<E>,
300) -> Result<Vec<String>, QueryError> {
301 let projection = query.plan()?.projection_spec();
302 let mut labels = Vec::with_capacity(projection.len());
303
304 for (ordinal, field) in projection.fields().enumerate() {
305 match field {
306 ProjectionField::Scalar {
307 expr: _,
308 alias: Some(alias),
309 } => labels.push(alias.as_str().to_string()),
310 ProjectionField::Scalar { expr, alias: None } => {
311 labels.push(projection_label_from_expr(expr, ordinal));
312 }
313 }
314 }
315
316 Ok(labels)
317}
318
319impl<C: CanisterKind> DbSession<C> {
320 pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
325 let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
326 match statement {
327 SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
328 entity: select.entity,
329 }),
330 SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
331 entity: delete.entity,
332 }),
333 SqlStatement::Explain(explain) => match explain.statement {
334 SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
335 entity: select.entity,
336 }),
337 SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
338 entity: delete.entity,
339 }),
340 },
341 SqlStatement::Describe(describe) => Ok(SqlStatementRoute::Describe {
342 entity: describe.entity,
343 }),
344 SqlStatement::ShowIndexes(show_indexes) => Ok(SqlStatementRoute::ShowIndexes {
345 entity: show_indexes.entity,
346 }),
347 SqlStatement::ShowColumns(show_columns) => Ok(SqlStatementRoute::ShowColumns {
348 entity: show_columns.entity,
349 }),
350 SqlStatement::ShowEntities(_) => Ok(SqlStatementRoute::ShowEntities),
351 }
352 }
353
354 pub fn describe_sql<E>(&self, sql: &str) -> Result<EntitySchemaDescription, QueryError>
356 where
357 E: EntityKind<Canister = C>,
358 {
359 let command = compile_sql_command_ignore::<E>(sql)?;
360 let lane = sql_command_lane(&command);
361
362 match command {
363 SqlCommand::DescribeEntity => Ok(self.describe_entity::<E>()),
364 SqlCommand::Query(_)
365 | SqlCommand::Explain { .. }
366 | SqlCommand::ExplainGlobalAggregate { .. }
367 | SqlCommand::ShowIndexesEntity
368 | SqlCommand::ShowColumnsEntity
369 | SqlCommand::ShowEntities => {
370 Err(unsupported_sql_lane_error(SqlSurface::Describe, lane))
371 }
372 }
373 }
374
375 pub fn show_indexes_sql<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
377 where
378 E: EntityKind<Canister = C>,
379 {
380 let command = compile_sql_command_ignore::<E>(sql)?;
381 let lane = sql_command_lane(&command);
382
383 match command {
384 SqlCommand::ShowIndexesEntity => Ok(self.show_indexes::<E>()),
385 SqlCommand::Query(_)
386 | SqlCommand::Explain { .. }
387 | SqlCommand::ExplainGlobalAggregate { .. }
388 | SqlCommand::DescribeEntity
389 | SqlCommand::ShowColumnsEntity
390 | SqlCommand::ShowEntities => {
391 Err(unsupported_sql_lane_error(SqlSurface::ShowIndexes, lane))
392 }
393 }
394 }
395
396 pub fn show_columns_sql<E>(&self, sql: &str) -> Result<Vec<EntityFieldDescription>, QueryError>
398 where
399 E: EntityKind<Canister = C>,
400 {
401 let command = compile_sql_command_ignore::<E>(sql)?;
402 let lane = sql_command_lane(&command);
403
404 match command {
405 SqlCommand::ShowColumnsEntity => Ok(self.show_columns::<E>()),
406 SqlCommand::Query(_)
407 | SqlCommand::Explain { .. }
408 | SqlCommand::ExplainGlobalAggregate { .. }
409 | SqlCommand::DescribeEntity
410 | SqlCommand::ShowIndexesEntity
411 | SqlCommand::ShowEntities => {
412 Err(unsupported_sql_lane_error(SqlSurface::ShowColumns, lane))
413 }
414 }
415 }
416
417 pub fn show_entities_sql(&self, sql: &str) -> Result<Vec<String>, QueryError> {
419 let statement = self.sql_statement_route(sql)?;
420 let lane = sql_statement_route_lane(&statement);
421 if lane != SqlLaneKind::ShowEntities {
422 return Err(unsupported_sql_lane_error(SqlSurface::ShowEntities, lane));
423 }
424
425 Ok(self.show_entities())
426 }
427
428 pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
433 where
434 E: EntityKind<Canister = C>,
435 {
436 let command = compile_sql_command_ignore::<E>(sql)?;
437 let lane = sql_command_lane(&command);
438
439 match command {
440 SqlCommand::Query(query) => Ok(query),
441 SqlCommand::Explain { .. }
442 | SqlCommand::ExplainGlobalAggregate { .. }
443 | SqlCommand::DescribeEntity
444 | SqlCommand::ShowIndexesEntity
445 | SqlCommand::ShowColumnsEntity
446 | SqlCommand::ShowEntities => {
447 Err(unsupported_sql_lane_error(SqlSurface::QueryFrom, lane))
448 }
449 }
450 }
451
452 pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
454 where
455 E: EntityKind<Canister = C>,
456 {
457 let query = self.query_from_sql::<E>(sql)?;
458 if query.has_grouping() {
459 return Err(QueryError::Intent(
460 IntentError::GroupedRequiresExecuteGrouped,
461 ));
462 }
463
464 match query.mode() {
465 QueryMode::Load(_) => projection_labels_from_query(&query),
466 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
467 ErrorClass::Unsupported,
468 ErrorOrigin::Query,
469 "sql_projection_columns only supports SELECT statements",
470 ))),
471 }
472 }
473
474 pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
476 where
477 E: EntityKind<Canister = C> + EntityValue,
478 {
479 let query = self.query_from_sql::<E>(sql)?;
480 if query.has_grouping() {
481 return Err(QueryError::Intent(
482 IntentError::GroupedRequiresExecuteGrouped,
483 ));
484 }
485
486 self.execute_query(&query)
487 }
488
489 pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
494 where
495 E: EntityKind<Canister = C> + EntityValue,
496 {
497 let query = self.query_from_sql::<E>(sql)?;
498 if query.has_grouping() {
499 return Err(QueryError::Intent(
500 IntentError::GroupedRequiresExecuteGrouped,
501 ));
502 }
503
504 match query.mode() {
505 QueryMode::Load(_) => {
506 self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
507 }
508 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
509 ErrorClass::Unsupported,
510 ErrorOrigin::Query,
511 "execute_sql_projection only supports SELECT statements",
512 ))),
513 }
514 }
515
516 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
521 where
522 E: EntityKind<Canister = C> + EntityValue,
523 {
524 let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
525 .map_err(map_sql_lowering_error)?;
526
527 match command.terminal() {
528 SqlGlobalAggregateTerminal::CountRows => self
529 .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
530 .map(|count| Value::Uint(u64::from(count))),
531 SqlGlobalAggregateTerminal::CountField(field) => {
532 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
533 self.execute_load_query_with(command.query(), |load, plan| {
534 load.values_by_slot(plan, target_slot)
535 })
536 .map(|values| {
537 let count = values
538 .into_iter()
539 .filter(|value| !matches!(value, Value::Null))
540 .count();
541 Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
542 })
543 }
544 SqlGlobalAggregateTerminal::SumField(field) => {
545 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
546 self.execute_load_query_with(command.query(), |load, plan| {
547 load.aggregate_sum_by_slot(plan, target_slot)
548 })
549 .map(|value| value.map_or(Value::Null, Value::Decimal))
550 }
551 SqlGlobalAggregateTerminal::AvgField(field) => {
552 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
553 self.execute_load_query_with(command.query(), |load, plan| {
554 load.aggregate_avg_by_slot(plan, target_slot)
555 })
556 .map(|value| value.map_or(Value::Null, Value::Decimal))
557 }
558 SqlGlobalAggregateTerminal::MinField(field) => {
559 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
560 let min_id = self.execute_load_query_with(command.query(), |load, plan| {
561 load.aggregate_min_by_slot(plan, target_slot)
562 })?;
563
564 match min_id {
565 Some(id) => self
566 .load::<E>()
567 .by_id(id)
568 .first_value_by(field)
569 .map(|value| value.unwrap_or(Value::Null)),
570 None => Ok(Value::Null),
571 }
572 }
573 SqlGlobalAggregateTerminal::MaxField(field) => {
574 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
575 let max_id = self.execute_load_query_with(command.query(), |load, plan| {
576 load.aggregate_max_by_slot(plan, target_slot)
577 })?;
578
579 match max_id {
580 Some(id) => self
581 .load::<E>()
582 .by_id(id)
583 .first_value_by(field)
584 .map(|value| value.unwrap_or(Value::Null)),
585 None => Ok(Value::Null),
586 }
587 }
588 }
589 }
590
591 pub fn execute_sql_grouped<E>(
593 &self,
594 sql: &str,
595 cursor_token: Option<&str>,
596 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
597 where
598 E: EntityKind<Canister = C> + EntityValue,
599 {
600 let query = self.query_from_sql::<E>(sql)?;
601 if !query.has_grouping() {
602 return Err(QueryError::execute(InternalError::classified(
603 ErrorClass::Unsupported,
604 ErrorOrigin::Query,
605 "execute_sql_grouped requires grouped SQL query intent",
606 )));
607 }
608
609 self.execute_grouped(&query, cursor_token)
610 }
611
612 pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
619 where
620 E: EntityKind<Canister = C> + EntityValue,
621 {
622 let command = compile_sql_command_ignore::<E>(sql)?;
623 let lane = sql_command_lane(&command);
624
625 match command {
626 SqlCommand::Query(_)
627 | SqlCommand::DescribeEntity
628 | SqlCommand::ShowIndexesEntity
629 | SqlCommand::ShowColumnsEntity
630 | SqlCommand::ShowEntities => {
631 Err(unsupported_sql_lane_error(SqlSurface::Explain, lane))
632 }
633 SqlCommand::Explain { mode, query } => match mode {
634 SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
635 SqlExplainMode::Execution => query.explain_execution_text(),
636 SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
637 },
638 SqlCommand::ExplainGlobalAggregate { mode, command } => {
639 Self::explain_sql_global_aggregate::<E>(mode, command)
640 }
641 }
642 }
643
644 fn explain_sql_global_aggregate<E>(
646 mode: SqlExplainMode,
647 command: SqlGlobalAggregateCommand<E>,
648 ) -> Result<String, QueryError>
649 where
650 E: EntityKind<Canister = C> + EntityValue,
651 {
652 match mode {
653 SqlExplainMode::Plan => {
654 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
657
658 Ok(command.query().explain()?.render_text_canonical())
659 }
660 SqlExplainMode::Execution => {
661 let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
662 let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
663
664 Ok(plan.execution_node_descriptor().render_text_tree())
665 }
666 SqlExplainMode::Json => {
667 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
670
671 Ok(command.query().explain()?.render_json_canonical())
672 }
673 }
674 }
675}