1use crate::{
2 db::{
3 DbSession, EntityResponse, EntitySchemaDescription, MissingRowPolicy,
4 PagedGroupedExecutionWithTrace, ProjectionResponse, Query, QueryError,
5 query::{
6 builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7 intent::IntentError,
8 plan::{
9 AggregateKind, FieldSlot, QueryMode,
10 expr::{Expr, ProjectionField},
11 },
12 },
13 sql::lowering::{
14 SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15 compile_sql_command, compile_sql_global_aggregate_command,
16 },
17 sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18 },
19 error::{ErrorClass, ErrorOrigin, InternalError},
20 traits::{CanisterKind, EntityKind, EntityValue},
21 value::Value,
22};
23
/// Session surface a parsed SQL statement should be dispatched to.
///
/// Produced by `DbSession::sql_statement_route`. Every variant except
/// [`SqlStatementRoute::ShowEntities`] carries the entity name written in the statement.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlStatementRoute {
    /// A `SELECT` or `DELETE` statement.
    Query { entity: String },
    /// An `EXPLAIN` of a `SELECT` or `DELETE` statement.
    Explain { entity: String },
    /// A `DESCRIBE` statement.
    Describe { entity: String },
    /// A `SHOW INDEXES` statement.
    ShowIndexes { entity: String },
    /// A `SHOW ENTITIES` statement; it does not name a single entity.
    ShowEntities,
}

impl SqlStatementRoute {
    /// Entity name targeted by the statement.
    ///
    /// Returns an empty string for [`SqlStatementRoute::ShowEntities`], which names no entity.
    #[must_use]
    pub const fn entity(&self) -> &str {
        match self {
            Self::ShowEntities => "",
            Self::Query { entity }
            | Self::Explain { entity }
            | Self::Describe { entity }
            | Self::ShowIndexes { entity } => entity.as_str(),
        }
    }

    /// `true` when the statement is an `EXPLAIN`.
    #[must_use]
    pub const fn is_explain(&self) -> bool {
        matches!(*self, Self::Explain { .. })
    }

    /// `true` when the statement is a `DESCRIBE`.
    #[must_use]
    pub const fn is_describe(&self) -> bool {
        matches!(*self, Self::Describe { .. })
    }

    /// `true` when the statement is a `SHOW INDEXES`.
    #[must_use]
    pub const fn is_show_indexes(&self) -> bool {
        matches!(*self, Self::ShowIndexes { .. })
    }

    /// `true` when the statement is a `SHOW ENTITIES`.
    #[must_use]
    pub const fn is_show_entities(&self) -> bool {
        matches!(*self, Self::ShowEntities)
    }
}
80
/// Coarse statement category shared by compiled commands and parsed routes.
///
/// Used together with `SqlSurface` to choose the diagnostic emitted when a
/// statement reaches the wrong session entry point.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqlLaneKind {
    /// `SELECT` / `DELETE`.
    Query,
    /// `EXPLAIN`, including explained global aggregates.
    Explain,
    /// `DESCRIBE`.
    Describe,
    /// `SHOW INDEXES`.
    ShowIndexes,
    /// `SHOW ENTITIES`.
    ShowEntities,
}
90
/// Public SQL entry point on `DbSession` that is rejecting a statement.
///
/// Only used to pick the wording of the unsupported-lane error message.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqlSurface {
    /// `query_from_sql`.
    QueryFrom,
    /// `explain_sql`.
    Explain,
    /// `describe_sql`.
    Describe,
    /// `show_indexes_sql`.
    ShowIndexes,
    /// `show_entities_sql`.
    ShowEntities,
}
100
101const fn sql_command_lane<E: EntityKind>(command: &SqlCommand<E>) -> SqlLaneKind {
103 match command {
104 SqlCommand::Query(_) => SqlLaneKind::Query,
105 SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
106 SqlLaneKind::Explain
107 }
108 SqlCommand::DescribeEntity => SqlLaneKind::Describe,
109 SqlCommand::ShowIndexesEntity => SqlLaneKind::ShowIndexes,
110 SqlCommand::ShowEntities => SqlLaneKind::ShowEntities,
111 }
112}
113
114const fn sql_statement_route_lane(route: &SqlStatementRoute) -> SqlLaneKind {
116 match route {
117 SqlStatementRoute::Query { .. } => SqlLaneKind::Query,
118 SqlStatementRoute::Explain { .. } => SqlLaneKind::Explain,
119 SqlStatementRoute::Describe { .. } => SqlLaneKind::Describe,
120 SqlStatementRoute::ShowIndexes { .. } => SqlLaneKind::ShowIndexes,
121 SqlStatementRoute::ShowEntities => SqlLaneKind::ShowEntities,
122 }
123}
124
125const fn unsupported_sql_lane_message(surface: SqlSurface, lane: SqlLaneKind) -> &'static str {
127 match (surface, lane) {
128 (SqlSurface::QueryFrom, SqlLaneKind::Explain) => {
129 "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)"
130 }
131 (SqlSurface::QueryFrom, SqlLaneKind::Describe) => {
132 "query_from_sql does not accept DESCRIBE statements; use describe_sql(...)"
133 }
134 (SqlSurface::QueryFrom, SqlLaneKind::ShowIndexes) => {
135 "query_from_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)"
136 }
137 (SqlSurface::QueryFrom, SqlLaneKind::ShowEntities) => {
138 "query_from_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)"
139 }
140 (SqlSurface::QueryFrom, SqlLaneKind::Query) => {
141 "query_from_sql requires one executable SELECT or DELETE statement"
142 }
143 (SqlSurface::Explain, SqlLaneKind::Describe) => {
144 "explain_sql does not accept DESCRIBE statements; use describe_sql(...)"
145 }
146 (SqlSurface::Explain, SqlLaneKind::ShowIndexes) => {
147 "explain_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)"
148 }
149 (SqlSurface::Explain, SqlLaneKind::ShowEntities) => {
150 "explain_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)"
151 }
152 (SqlSurface::Explain, SqlLaneKind::Query | SqlLaneKind::Explain) => {
153 "explain_sql requires an EXPLAIN statement"
154 }
155 (SqlSurface::Describe, _) => "describe_sql requires a DESCRIBE statement",
156 (SqlSurface::ShowIndexes, _) => "show_indexes_sql requires a SHOW INDEXES statement",
157 (SqlSurface::ShowEntities, _) => "show_entities_sql requires a SHOW ENTITIES statement",
158 }
159}
160
161fn unsupported_sql_lane_error(surface: SqlSurface, lane: SqlLaneKind) -> QueryError {
163 QueryError::execute(InternalError::classified(
164 ErrorClass::Unsupported,
165 ErrorOrigin::Query,
166 unsupported_sql_lane_message(surface, lane),
167 ))
168}
169
170fn compile_sql_command_ignore<E: EntityKind>(sql: &str) -> Result<SqlCommand<E>, QueryError> {
172 compile_sql_command::<E>(sql, MissingRowPolicy::Ignore).map_err(map_sql_lowering_error)
173}
174
175fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
177 match err {
178 SqlLoweringError::Query(err) => err,
179 SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
180 feature,
181 }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
182 other => QueryError::execute(InternalError::classified(
183 ErrorClass::Unsupported,
184 ErrorOrigin::Query,
185 format!("SQL query is not executable in this release: {other}"),
186 )),
187 }
188}
189
190fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
193 map_sql_lowering_error(SqlLoweringError::Parse(err))
194}
195
196fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
199 FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
200 QueryError::execute(crate::db::error::executor_unsupported(format!(
201 "unknown aggregate target field: {field}",
202 )))
203 })
204}
205
206fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
209 terminal: &SqlGlobalAggregateTerminal,
210) -> Result<AggregateExpr, QueryError> {
211 match terminal {
212 SqlGlobalAggregateTerminal::CountRows => Ok(count()),
213 SqlGlobalAggregateTerminal::CountField(field) => {
214 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
215
216 Ok(count_by(field.as_str()))
217 }
218 SqlGlobalAggregateTerminal::SumField(field) => {
219 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
220
221 Ok(sum(field.as_str()))
222 }
223 SqlGlobalAggregateTerminal::AvgField(field) => {
224 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
225
226 Ok(avg(field.as_str()))
227 }
228 SqlGlobalAggregateTerminal::MinField(field) => {
229 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
230
231 Ok(min_by(field.as_str()))
232 }
233 SqlGlobalAggregateTerminal::MaxField(field) => {
234 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
235
236 Ok(max_by(field.as_str()))
237 }
238 }
239}
240
241fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
243 let kind = match aggregate.kind() {
244 AggregateKind::Count => "COUNT",
245 AggregateKind::Sum => "SUM",
246 AggregateKind::Avg => "AVG",
247 AggregateKind::Exists => "EXISTS",
248 AggregateKind::First => "FIRST",
249 AggregateKind::Last => "LAST",
250 AggregateKind::Min => "MIN",
251 AggregateKind::Max => "MAX",
252 };
253 let distinct = if aggregate.is_distinct() {
254 "DISTINCT "
255 } else {
256 ""
257 };
258
259 if let Some(field) = aggregate.target_field() {
260 return format!("{kind}({distinct}{field})");
261 }
262
263 format!("{kind}({distinct}*)")
264}
265
266fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
268 match expr {
269 Expr::Field(field) => field.as_str().to_string(),
270 Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
271 Expr::Alias { name, .. } => name.as_str().to_string(),
272 Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
273 format!("expr_{ordinal}")
274 }
275 }
276}
277
278fn projection_labels_from_query<E: EntityKind>(
280 query: &Query<E>,
281) -> Result<Vec<String>, QueryError> {
282 let projection = query.plan()?.projection_spec();
283 let mut labels = Vec::with_capacity(projection.len());
284
285 for (ordinal, field) in projection.fields().enumerate() {
286 match field {
287 ProjectionField::Scalar {
288 expr: _,
289 alias: Some(alias),
290 } => labels.push(alias.as_str().to_string()),
291 ProjectionField::Scalar { expr, alias: None } => {
292 labels.push(projection_label_from_expr(expr, ordinal));
293 }
294 }
295 }
296
297 Ok(labels)
298}
299
300impl<C: CanisterKind> DbSession<C> {
301 pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
306 let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
307 match statement {
308 SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
309 entity: select.entity,
310 }),
311 SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
312 entity: delete.entity,
313 }),
314 SqlStatement::Explain(explain) => match explain.statement {
315 SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
316 entity: select.entity,
317 }),
318 SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
319 entity: delete.entity,
320 }),
321 },
322 SqlStatement::Describe(describe) => Ok(SqlStatementRoute::Describe {
323 entity: describe.entity,
324 }),
325 SqlStatement::ShowIndexes(show_indexes) => Ok(SqlStatementRoute::ShowIndexes {
326 entity: show_indexes.entity,
327 }),
328 SqlStatement::ShowEntities(_) => Ok(SqlStatementRoute::ShowEntities),
329 }
330 }
331
332 pub fn describe_sql<E>(&self, sql: &str) -> Result<EntitySchemaDescription, QueryError>
334 where
335 E: EntityKind<Canister = C>,
336 {
337 let command = compile_sql_command_ignore::<E>(sql)?;
338 let lane = sql_command_lane(&command);
339
340 match command {
341 SqlCommand::DescribeEntity => Ok(self.describe_entity::<E>()),
342 SqlCommand::Query(_)
343 | SqlCommand::Explain { .. }
344 | SqlCommand::ExplainGlobalAggregate { .. }
345 | SqlCommand::ShowIndexesEntity
346 | SqlCommand::ShowEntities => {
347 Err(unsupported_sql_lane_error(SqlSurface::Describe, lane))
348 }
349 }
350 }
351
352 pub fn show_indexes_sql<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
354 where
355 E: EntityKind<Canister = C>,
356 {
357 let command = compile_sql_command_ignore::<E>(sql)?;
358 let lane = sql_command_lane(&command);
359
360 match command {
361 SqlCommand::ShowIndexesEntity => Ok(self.show_indexes::<E>()),
362 SqlCommand::Query(_)
363 | SqlCommand::Explain { .. }
364 | SqlCommand::ExplainGlobalAggregate { .. }
365 | SqlCommand::DescribeEntity
366 | SqlCommand::ShowEntities => {
367 Err(unsupported_sql_lane_error(SqlSurface::ShowIndexes, lane))
368 }
369 }
370 }
371
372 pub fn show_entities_sql(&self, sql: &str) -> Result<Vec<String>, QueryError> {
374 let statement = self.sql_statement_route(sql)?;
375 let lane = sql_statement_route_lane(&statement);
376 if lane != SqlLaneKind::ShowEntities {
377 return Err(unsupported_sql_lane_error(SqlSurface::ShowEntities, lane));
378 }
379
380 Ok(self.show_entities())
381 }
382
383 pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
388 where
389 E: EntityKind<Canister = C>,
390 {
391 let command = compile_sql_command_ignore::<E>(sql)?;
392 let lane = sql_command_lane(&command);
393
394 match command {
395 SqlCommand::Query(query) => Ok(query),
396 SqlCommand::Explain { .. }
397 | SqlCommand::ExplainGlobalAggregate { .. }
398 | SqlCommand::DescribeEntity
399 | SqlCommand::ShowIndexesEntity
400 | SqlCommand::ShowEntities => {
401 Err(unsupported_sql_lane_error(SqlSurface::QueryFrom, lane))
402 }
403 }
404 }
405
406 pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
408 where
409 E: EntityKind<Canister = C>,
410 {
411 let query = self.query_from_sql::<E>(sql)?;
412 if query.has_grouping() {
413 return Err(QueryError::Intent(
414 IntentError::GroupedRequiresExecuteGrouped,
415 ));
416 }
417
418 match query.mode() {
419 QueryMode::Load(_) => projection_labels_from_query(&query),
420 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
421 ErrorClass::Unsupported,
422 ErrorOrigin::Query,
423 "sql_projection_columns only supports SELECT statements",
424 ))),
425 }
426 }
427
428 pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
430 where
431 E: EntityKind<Canister = C> + EntityValue,
432 {
433 let query = self.query_from_sql::<E>(sql)?;
434 if query.has_grouping() {
435 return Err(QueryError::Intent(
436 IntentError::GroupedRequiresExecuteGrouped,
437 ));
438 }
439
440 self.execute_query(&query)
441 }
442
443 pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
448 where
449 E: EntityKind<Canister = C> + EntityValue,
450 {
451 let query = self.query_from_sql::<E>(sql)?;
452 if query.has_grouping() {
453 return Err(QueryError::Intent(
454 IntentError::GroupedRequiresExecuteGrouped,
455 ));
456 }
457
458 match query.mode() {
459 QueryMode::Load(_) => {
460 self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
461 }
462 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
463 ErrorClass::Unsupported,
464 ErrorOrigin::Query,
465 "execute_sql_projection only supports SELECT statements",
466 ))),
467 }
468 }
469
470 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
475 where
476 E: EntityKind<Canister = C> + EntityValue,
477 {
478 let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
479 .map_err(map_sql_lowering_error)?;
480
481 match command.terminal() {
482 SqlGlobalAggregateTerminal::CountRows => self
483 .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
484 .map(|count| Value::Uint(u64::from(count))),
485 SqlGlobalAggregateTerminal::CountField(field) => {
486 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
487 self.execute_load_query_with(command.query(), |load, plan| {
488 load.values_by_slot(plan, target_slot)
489 })
490 .map(|values| {
491 let count = values
492 .into_iter()
493 .filter(|value| !matches!(value, Value::Null))
494 .count();
495 Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
496 })
497 }
498 SqlGlobalAggregateTerminal::SumField(field) => {
499 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
500 self.execute_load_query_with(command.query(), |load, plan| {
501 load.aggregate_sum_by_slot(plan, target_slot)
502 })
503 .map(|value| value.map_or(Value::Null, Value::Decimal))
504 }
505 SqlGlobalAggregateTerminal::AvgField(field) => {
506 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
507 self.execute_load_query_with(command.query(), |load, plan| {
508 load.aggregate_avg_by_slot(plan, target_slot)
509 })
510 .map(|value| value.map_or(Value::Null, Value::Decimal))
511 }
512 SqlGlobalAggregateTerminal::MinField(field) => {
513 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
514 let min_id = self.execute_load_query_with(command.query(), |load, plan| {
515 load.aggregate_min_by_slot(plan, target_slot)
516 })?;
517
518 match min_id {
519 Some(id) => self
520 .load::<E>()
521 .by_id(id)
522 .first_value_by(field)
523 .map(|value| value.unwrap_or(Value::Null)),
524 None => Ok(Value::Null),
525 }
526 }
527 SqlGlobalAggregateTerminal::MaxField(field) => {
528 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
529 let max_id = self.execute_load_query_with(command.query(), |load, plan| {
530 load.aggregate_max_by_slot(plan, target_slot)
531 })?;
532
533 match max_id {
534 Some(id) => self
535 .load::<E>()
536 .by_id(id)
537 .first_value_by(field)
538 .map(|value| value.unwrap_or(Value::Null)),
539 None => Ok(Value::Null),
540 }
541 }
542 }
543 }
544
545 pub fn execute_sql_grouped<E>(
547 &self,
548 sql: &str,
549 cursor_token: Option<&str>,
550 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
551 where
552 E: EntityKind<Canister = C> + EntityValue,
553 {
554 let query = self.query_from_sql::<E>(sql)?;
555 if !query.has_grouping() {
556 return Err(QueryError::execute(InternalError::classified(
557 ErrorClass::Unsupported,
558 ErrorOrigin::Query,
559 "execute_sql_grouped requires grouped SQL query intent",
560 )));
561 }
562
563 self.execute_grouped(&query, cursor_token)
564 }
565
566 pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
573 where
574 E: EntityKind<Canister = C> + EntityValue,
575 {
576 let command = compile_sql_command_ignore::<E>(sql)?;
577 let lane = sql_command_lane(&command);
578
579 match command {
580 SqlCommand::Query(_)
581 | SqlCommand::DescribeEntity
582 | SqlCommand::ShowIndexesEntity
583 | SqlCommand::ShowEntities => {
584 Err(unsupported_sql_lane_error(SqlSurface::Explain, lane))
585 }
586 SqlCommand::Explain { mode, query } => match mode {
587 SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
588 SqlExplainMode::Execution => query.explain_execution_text(),
589 SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
590 },
591 SqlCommand::ExplainGlobalAggregate { mode, command } => {
592 Self::explain_sql_global_aggregate::<E>(mode, command)
593 }
594 }
595 }
596
597 fn explain_sql_global_aggregate<E>(
599 mode: SqlExplainMode,
600 command: SqlGlobalAggregateCommand<E>,
601 ) -> Result<String, QueryError>
602 where
603 E: EntityKind<Canister = C> + EntityValue,
604 {
605 match mode {
606 SqlExplainMode::Plan => {
607 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
610
611 Ok(command.query().explain()?.render_text_canonical())
612 }
613 SqlExplainMode::Execution => {
614 let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
615 let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
616
617 Ok(plan.execution_node_descriptor().render_text_tree())
618 }
619 SqlExplainMode::Json => {
620 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
623
624 Ok(command.query().explain()?.render_json_canonical())
625 }
626 }
627 }
628}