1use crate::{
2 db::{
3 DbSession, EntityFieldDescription, EntityResponse, EntitySchemaDescription,
4 MissingRowPolicy, PagedGroupedExecutionWithTrace, ProjectionResponse, Query, QueryError,
5 query::{
6 builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7 intent::IntentError,
8 plan::{
9 AggregateKind, FieldSlot, QueryMode,
10 expr::{Expr, ProjectionField},
11 },
12 },
13 sql::lowering::{
14 SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15 compile_sql_command, compile_sql_global_aggregate_command,
16 },
17 sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18 },
19 error::{ErrorClass, ErrorOrigin, InternalError},
20 traits::{CanisterKind, EntityKind, EntityValue},
21 value::Value,
22};
23
/// Classification of one SQL statement by the session entry point that serves it.
///
/// Returned by `DbSession::sql_statement_route`, which only parses the SQL text, so the
/// `entity` fields hold the entity name exactly as it was written in the statement.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SqlStatementRoute {
    /// A `SELECT` or `DELETE` statement.
    Query { entity: String },
    /// An `EXPLAIN` wrapping a `SELECT` or `DELETE` statement.
    Explain { entity: String },
    /// A `DESCRIBE` statement.
    Describe { entity: String },
    /// A `SHOW INDEXES` statement.
    ShowIndexes { entity: String },
    /// A `SHOW COLUMNS` statement.
    ShowColumns { entity: String },
    /// A `SHOW ENTITIES` / `SHOW TABLES` statement; it names no target entity.
    ShowEntities,
}

impl SqlStatementRoute {
    /// Entity name targeted by the statement.
    ///
    /// `ShowEntities` has no target entity, so it yields the empty string.
    #[must_use]
    pub const fn entity(&self) -> &str {
        match self {
            Self::ShowEntities => "",
            Self::Query { entity }
            | Self::Explain { entity }
            | Self::Describe { entity }
            | Self::ShowIndexes { entity }
            | Self::ShowColumns { entity } => entity.as_str(),
        }
    }

    /// `true` for an `EXPLAIN` statement.
    #[must_use]
    pub const fn is_explain(&self) -> bool {
        matches!(*self, Self::Explain { .. })
    }

    /// `true` for a `DESCRIBE` statement.
    #[must_use]
    pub const fn is_describe(&self) -> bool {
        matches!(*self, Self::Describe { .. })
    }

    /// `true` for a `SHOW INDEXES` statement.
    #[must_use]
    pub const fn is_show_indexes(&self) -> bool {
        matches!(*self, Self::ShowIndexes { .. })
    }

    /// `true` for a `SHOW COLUMNS` statement.
    #[must_use]
    pub const fn is_show_columns(&self) -> bool {
        matches!(*self, Self::ShowColumns { .. })
    }

    /// `true` for a `SHOW ENTITIES` / `SHOW TABLES` statement.
    #[must_use]
    pub const fn is_show_entities(&self) -> bool {
        matches!(*self, Self::ShowEntities)
    }
}
88
/// Statement family a SQL command or parsed route belongs to.
///
/// Used to pick the error message when a statement reaches the wrong entry point.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SqlLaneKind {
    /// Executable `SELECT` / `DELETE`.
    Query,
    /// `EXPLAIN` of a query or global aggregate.
    Explain,
    /// `DESCRIBE`.
    Describe,
    /// `SHOW INDEXES`.
    ShowIndexes,
    /// `SHOW COLUMNS`.
    ShowColumns,
    /// `SHOW ENTITIES` / `SHOW TABLES`.
    ShowEntities,
}
99
/// Public SQL entry point on `DbSession` that received a statement.
///
/// Paired with `SqlLaneKind` to select a rejection message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SqlSurface {
    /// `query_from_sql` (and the executors built on it).
    QueryFrom,
    /// `explain_sql`.
    Explain,
    /// `describe_sql`.
    Describe,
    /// `show_indexes_sql`.
    ShowIndexes,
    /// `show_columns_sql`.
    ShowColumns,
    /// `show_entities_sql`.
    ShowEntities,
}
110
111const fn sql_command_lane<E: EntityKind>(command: &SqlCommand<E>) -> SqlLaneKind {
113 match command {
114 SqlCommand::Query(_) => SqlLaneKind::Query,
115 SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
116 SqlLaneKind::Explain
117 }
118 SqlCommand::DescribeEntity => SqlLaneKind::Describe,
119 SqlCommand::ShowIndexesEntity => SqlLaneKind::ShowIndexes,
120 SqlCommand::ShowColumnsEntity => SqlLaneKind::ShowColumns,
121 SqlCommand::ShowEntities => SqlLaneKind::ShowEntities,
122 }
123}
124
125const fn sql_statement_route_lane(route: &SqlStatementRoute) -> SqlLaneKind {
127 match route {
128 SqlStatementRoute::Query { .. } => SqlLaneKind::Query,
129 SqlStatementRoute::Explain { .. } => SqlLaneKind::Explain,
130 SqlStatementRoute::Describe { .. } => SqlLaneKind::Describe,
131 SqlStatementRoute::ShowIndexes { .. } => SqlLaneKind::ShowIndexes,
132 SqlStatementRoute::ShowColumns { .. } => SqlLaneKind::ShowColumns,
133 SqlStatementRoute::ShowEntities => SqlLaneKind::ShowEntities,
134 }
135}
136
137const fn unsupported_sql_lane_message(surface: SqlSurface, lane: SqlLaneKind) -> &'static str {
139 match (surface, lane) {
140 (SqlSurface::QueryFrom, SqlLaneKind::Explain) => {
141 "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)"
142 }
143 (SqlSurface::QueryFrom, SqlLaneKind::Describe) => {
144 "query_from_sql does not accept DESCRIBE statements; use describe_sql(...)"
145 }
146 (SqlSurface::QueryFrom, SqlLaneKind::ShowIndexes) => {
147 "query_from_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)"
148 }
149 (SqlSurface::QueryFrom, SqlLaneKind::ShowColumns) => {
150 "query_from_sql does not accept SHOW COLUMNS statements; use show_columns_sql(...)"
151 }
152 (SqlSurface::QueryFrom, SqlLaneKind::ShowEntities) => {
153 "query_from_sql does not accept SHOW ENTITIES/SHOW TABLES statements; use show_entities_sql(...)"
154 }
155 (SqlSurface::QueryFrom, SqlLaneKind::Query) => {
156 "query_from_sql requires one executable SELECT or DELETE statement"
157 }
158 (SqlSurface::Explain, SqlLaneKind::Describe) => {
159 "explain_sql does not accept DESCRIBE statements; use describe_sql(...)"
160 }
161 (SqlSurface::Explain, SqlLaneKind::ShowIndexes) => {
162 "explain_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)"
163 }
164 (SqlSurface::Explain, SqlLaneKind::ShowColumns) => {
165 "explain_sql does not accept SHOW COLUMNS statements; use show_columns_sql(...)"
166 }
167 (SqlSurface::Explain, SqlLaneKind::ShowEntities) => {
168 "explain_sql does not accept SHOW ENTITIES/SHOW TABLES statements; use show_entities_sql(...)"
169 }
170 (SqlSurface::Explain, SqlLaneKind::Query | SqlLaneKind::Explain) => {
171 "explain_sql requires an EXPLAIN statement"
172 }
173 (SqlSurface::Describe, _) => "describe_sql requires a DESCRIBE statement",
174 (SqlSurface::ShowIndexes, _) => "show_indexes_sql requires a SHOW INDEXES statement",
175 (SqlSurface::ShowColumns, _) => "show_columns_sql requires a SHOW COLUMNS statement",
176 (SqlSurface::ShowEntities, _) => {
177 "show_entities_sql requires a SHOW ENTITIES or SHOW TABLES statement"
178 }
179 }
180}
181
182fn unsupported_sql_lane_error(surface: SqlSurface, lane: SqlLaneKind) -> QueryError {
184 QueryError::execute(InternalError::classified(
185 ErrorClass::Unsupported,
186 ErrorOrigin::Query,
187 unsupported_sql_lane_message(surface, lane),
188 ))
189}
190
191fn compile_sql_command_ignore<E: EntityKind>(sql: &str) -> Result<SqlCommand<E>, QueryError> {
193 compile_sql_command::<E>(sql, MissingRowPolicy::Ignore).map_err(map_sql_lowering_error)
194}
195
196fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
198 match err {
199 SqlLoweringError::Query(err) => err,
200 SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
201 feature,
202 }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
203 other => QueryError::execute(InternalError::classified(
204 ErrorClass::Unsupported,
205 ErrorOrigin::Query,
206 format!("SQL query is not executable in this release: {other}"),
207 )),
208 }
209}
210
211fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
214 map_sql_lowering_error(SqlLoweringError::Parse(err))
215}
216
217fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
220 FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
221 QueryError::execute(crate::db::error::executor_unsupported(format!(
222 "unknown aggregate target field: {field}",
223 )))
224 })
225}
226
227fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
230 terminal: &SqlGlobalAggregateTerminal,
231) -> Result<AggregateExpr, QueryError> {
232 match terminal {
233 SqlGlobalAggregateTerminal::CountRows => Ok(count()),
234 SqlGlobalAggregateTerminal::CountField(field) => {
235 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
236
237 Ok(count_by(field.as_str()))
238 }
239 SqlGlobalAggregateTerminal::SumField(field) => {
240 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
241
242 Ok(sum(field.as_str()))
243 }
244 SqlGlobalAggregateTerminal::AvgField(field) => {
245 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
246
247 Ok(avg(field.as_str()))
248 }
249 SqlGlobalAggregateTerminal::MinField(field) => {
250 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
251
252 Ok(min_by(field.as_str()))
253 }
254 SqlGlobalAggregateTerminal::MaxField(field) => {
255 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
256
257 Ok(max_by(field.as_str()))
258 }
259 }
260}
261
262fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
264 let kind = match aggregate.kind() {
265 AggregateKind::Count => "COUNT",
266 AggregateKind::Sum => "SUM",
267 AggregateKind::Avg => "AVG",
268 AggregateKind::Exists => "EXISTS",
269 AggregateKind::First => "FIRST",
270 AggregateKind::Last => "LAST",
271 AggregateKind::Min => "MIN",
272 AggregateKind::Max => "MAX",
273 };
274 let distinct = if aggregate.is_distinct() {
275 "DISTINCT "
276 } else {
277 ""
278 };
279
280 if let Some(field) = aggregate.target_field() {
281 return format!("{kind}({distinct}{field})");
282 }
283
284 format!("{kind}({distinct}*)")
285}
286
287fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
289 match expr {
290 Expr::Field(field) => field.as_str().to_string(),
291 Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
292 Expr::Alias { name, .. } => name.as_str().to_string(),
293 Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
294 format!("expr_{ordinal}")
295 }
296 }
297}
298
299fn projection_labels_from_query<E: EntityKind>(
301 query: &Query<E>,
302) -> Result<Vec<String>, QueryError> {
303 let projection = query.plan()?.projection_spec();
304 let mut labels = Vec::with_capacity(projection.len());
305
306 for (ordinal, field) in projection.fields().enumerate() {
307 match field {
308 ProjectionField::Scalar {
309 expr: _,
310 alias: Some(alias),
311 } => labels.push(alias.as_str().to_string()),
312 ProjectionField::Scalar { expr, alias: None } => {
313 labels.push(projection_label_from_expr(expr, ordinal));
314 }
315 }
316 }
317
318 Ok(labels)
319}
320
321impl<C: CanisterKind> DbSession<C> {
322 pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
327 let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
328 match statement {
329 SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
330 entity: select.entity,
331 }),
332 SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
333 entity: delete.entity,
334 }),
335 SqlStatement::Explain(explain) => match explain.statement {
336 SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
337 entity: select.entity,
338 }),
339 SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
340 entity: delete.entity,
341 }),
342 },
343 SqlStatement::Describe(describe) => Ok(SqlStatementRoute::Describe {
344 entity: describe.entity,
345 }),
346 SqlStatement::ShowIndexes(show_indexes) => Ok(SqlStatementRoute::ShowIndexes {
347 entity: show_indexes.entity,
348 }),
349 SqlStatement::ShowColumns(show_columns) => Ok(SqlStatementRoute::ShowColumns {
350 entity: show_columns.entity,
351 }),
352 SqlStatement::ShowEntities(_) => Ok(SqlStatementRoute::ShowEntities),
353 }
354 }
355
356 pub fn describe_sql<E>(&self, sql: &str) -> Result<EntitySchemaDescription, QueryError>
358 where
359 E: EntityKind<Canister = C>,
360 {
361 let command = compile_sql_command_ignore::<E>(sql)?;
362 let lane = sql_command_lane(&command);
363
364 match command {
365 SqlCommand::DescribeEntity => Ok(self.describe_entity::<E>()),
366 SqlCommand::Query(_)
367 | SqlCommand::Explain { .. }
368 | SqlCommand::ExplainGlobalAggregate { .. }
369 | SqlCommand::ShowIndexesEntity
370 | SqlCommand::ShowColumnsEntity
371 | SqlCommand::ShowEntities => {
372 Err(unsupported_sql_lane_error(SqlSurface::Describe, lane))
373 }
374 }
375 }
376
377 pub fn show_indexes_sql<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
379 where
380 E: EntityKind<Canister = C>,
381 {
382 let command = compile_sql_command_ignore::<E>(sql)?;
383 let lane = sql_command_lane(&command);
384
385 match command {
386 SqlCommand::ShowIndexesEntity => Ok(self.show_indexes::<E>()),
387 SqlCommand::Query(_)
388 | SqlCommand::Explain { .. }
389 | SqlCommand::ExplainGlobalAggregate { .. }
390 | SqlCommand::DescribeEntity
391 | SqlCommand::ShowColumnsEntity
392 | SqlCommand::ShowEntities => {
393 Err(unsupported_sql_lane_error(SqlSurface::ShowIndexes, lane))
394 }
395 }
396 }
397
398 pub fn show_columns_sql<E>(&self, sql: &str) -> Result<Vec<EntityFieldDescription>, QueryError>
400 where
401 E: EntityKind<Canister = C>,
402 {
403 let command = compile_sql_command_ignore::<E>(sql)?;
404 let lane = sql_command_lane(&command);
405
406 match command {
407 SqlCommand::ShowColumnsEntity => Ok(self.show_columns::<E>()),
408 SqlCommand::Query(_)
409 | SqlCommand::Explain { .. }
410 | SqlCommand::ExplainGlobalAggregate { .. }
411 | SqlCommand::DescribeEntity
412 | SqlCommand::ShowIndexesEntity
413 | SqlCommand::ShowEntities => {
414 Err(unsupported_sql_lane_error(SqlSurface::ShowColumns, lane))
415 }
416 }
417 }
418
419 pub fn show_entities_sql(&self, sql: &str) -> Result<Vec<String>, QueryError> {
421 let statement = self.sql_statement_route(sql)?;
422 let lane = sql_statement_route_lane(&statement);
423 if lane != SqlLaneKind::ShowEntities {
424 return Err(unsupported_sql_lane_error(SqlSurface::ShowEntities, lane));
425 }
426
427 Ok(self.show_entities())
428 }
429
430 pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
435 where
436 E: EntityKind<Canister = C>,
437 {
438 let command = compile_sql_command_ignore::<E>(sql)?;
439 let lane = sql_command_lane(&command);
440
441 match command {
442 SqlCommand::Query(query) => Ok(query),
443 SqlCommand::Explain { .. }
444 | SqlCommand::ExplainGlobalAggregate { .. }
445 | SqlCommand::DescribeEntity
446 | SqlCommand::ShowIndexesEntity
447 | SqlCommand::ShowColumnsEntity
448 | SqlCommand::ShowEntities => {
449 Err(unsupported_sql_lane_error(SqlSurface::QueryFrom, lane))
450 }
451 }
452 }
453
454 pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
456 where
457 E: EntityKind<Canister = C>,
458 {
459 let query = self.query_from_sql::<E>(sql)?;
460 if query.has_grouping() {
461 return Err(QueryError::Intent(
462 IntentError::GroupedRequiresExecuteGrouped,
463 ));
464 }
465
466 match query.mode() {
467 QueryMode::Load(_) => projection_labels_from_query(&query),
468 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
469 ErrorClass::Unsupported,
470 ErrorOrigin::Query,
471 "sql_projection_columns only supports SELECT statements",
472 ))),
473 }
474 }
475
476 pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
478 where
479 E: EntityKind<Canister = C> + EntityValue,
480 {
481 let query = self.query_from_sql::<E>(sql)?;
482 if query.has_grouping() {
483 return Err(QueryError::Intent(
484 IntentError::GroupedRequiresExecuteGrouped,
485 ));
486 }
487
488 self.execute_query(&query)
489 }
490
491 pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
496 where
497 E: EntityKind<Canister = C> + EntityValue,
498 {
499 let query = self.query_from_sql::<E>(sql)?;
500 if query.has_grouping() {
501 return Err(QueryError::Intent(
502 IntentError::GroupedRequiresExecuteGrouped,
503 ));
504 }
505
506 match query.mode() {
507 QueryMode::Load(_) => {
508 self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
509 }
510 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
511 ErrorClass::Unsupported,
512 ErrorOrigin::Query,
513 "execute_sql_projection only supports SELECT statements",
514 ))),
515 }
516 }
517
518 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
523 where
524 E: EntityKind<Canister = C> + EntityValue,
525 {
526 let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
527 .map_err(map_sql_lowering_error)?;
528
529 match command.terminal() {
530 SqlGlobalAggregateTerminal::CountRows => self
531 .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
532 .map(|count| Value::Uint(u64::from(count))),
533 SqlGlobalAggregateTerminal::CountField(field) => {
534 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
535 self.execute_load_query_with(command.query(), |load, plan| {
536 load.values_by_slot(plan, target_slot)
537 })
538 .map(|values| {
539 let count = values
540 .into_iter()
541 .filter(|value| !matches!(value, Value::Null))
542 .count();
543 Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
544 })
545 }
546 SqlGlobalAggregateTerminal::SumField(field) => {
547 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
548 self.execute_load_query_with(command.query(), |load, plan| {
549 load.aggregate_sum_by_slot(plan, target_slot)
550 })
551 .map(|value| value.map_or(Value::Null, Value::Decimal))
552 }
553 SqlGlobalAggregateTerminal::AvgField(field) => {
554 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
555 self.execute_load_query_with(command.query(), |load, plan| {
556 load.aggregate_avg_by_slot(plan, target_slot)
557 })
558 .map(|value| value.map_or(Value::Null, Value::Decimal))
559 }
560 SqlGlobalAggregateTerminal::MinField(field) => {
561 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
562 let min_id = self.execute_load_query_with(command.query(), |load, plan| {
563 load.aggregate_min_by_slot(plan, target_slot)
564 })?;
565
566 match min_id {
567 Some(id) => self
568 .load::<E>()
569 .by_id(id)
570 .first_value_by(field)
571 .map(|value| value.unwrap_or(Value::Null)),
572 None => Ok(Value::Null),
573 }
574 }
575 SqlGlobalAggregateTerminal::MaxField(field) => {
576 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
577 let max_id = self.execute_load_query_with(command.query(), |load, plan| {
578 load.aggregate_max_by_slot(plan, target_slot)
579 })?;
580
581 match max_id {
582 Some(id) => self
583 .load::<E>()
584 .by_id(id)
585 .first_value_by(field)
586 .map(|value| value.unwrap_or(Value::Null)),
587 None => Ok(Value::Null),
588 }
589 }
590 }
591 }
592
593 pub fn execute_sql_grouped<E>(
595 &self,
596 sql: &str,
597 cursor_token: Option<&str>,
598 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
599 where
600 E: EntityKind<Canister = C> + EntityValue,
601 {
602 let query = self.query_from_sql::<E>(sql)?;
603 if !query.has_grouping() {
604 return Err(QueryError::execute(InternalError::classified(
605 ErrorClass::Unsupported,
606 ErrorOrigin::Query,
607 "execute_sql_grouped requires grouped SQL query intent",
608 )));
609 }
610
611 self.execute_grouped(&query, cursor_token)
612 }
613
614 pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
621 where
622 E: EntityKind<Canister = C> + EntityValue,
623 {
624 let command = compile_sql_command_ignore::<E>(sql)?;
625 let lane = sql_command_lane(&command);
626
627 match command {
628 SqlCommand::Query(_)
629 | SqlCommand::DescribeEntity
630 | SqlCommand::ShowIndexesEntity
631 | SqlCommand::ShowColumnsEntity
632 | SqlCommand::ShowEntities => {
633 Err(unsupported_sql_lane_error(SqlSurface::Explain, lane))
634 }
635 SqlCommand::Explain { mode, query } => match mode {
636 SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
637 SqlExplainMode::Execution => query.explain_execution_text(),
638 SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
639 },
640 SqlCommand::ExplainGlobalAggregate { mode, command } => {
641 Self::explain_sql_global_aggregate::<E>(mode, command)
642 }
643 }
644 }
645
646 fn explain_sql_global_aggregate<E>(
648 mode: SqlExplainMode,
649 command: SqlGlobalAggregateCommand<E>,
650 ) -> Result<String, QueryError>
651 where
652 E: EntityKind<Canister = C> + EntityValue,
653 {
654 match mode {
655 SqlExplainMode::Plan => {
656 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
659
660 Ok(command.query().explain()?.render_text_canonical())
661 }
662 SqlExplainMode::Execution => {
663 let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
664 let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
665
666 Ok(plan.execution_node_descriptor().render_text_tree())
667 }
668 SqlExplainMode::Json => {
669 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
672
673 Ok(command.query().explain()?.render_json_canonical())
674 }
675 }
676 }
677}