1use crate::{
2 db::{
3 DbSession, EntityResponse, EntitySchemaDescription, MissingRowPolicy,
4 PagedGroupedExecutionWithTrace, ProjectionResponse, Query, QueryError,
5 query::{
6 builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7 intent::IntentError,
8 plan::{
9 AggregateKind, FieldSlot, QueryMode,
10 expr::{Expr, ProjectionField},
11 },
12 },
13 sql::lowering::{
14 SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15 compile_sql_command, compile_sql_global_aggregate_command,
16 },
17 sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18 },
19 error::{ErrorClass, ErrorOrigin, InternalError},
20 traits::{CanisterKind, EntityKind, EntityValue},
21 value::Value,
22};
23
/// Session entry point a parsed SQL statement should be dispatched to,
/// together with the entity name the statement targets.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SqlStatementRoute {
    /// A `SELECT` or `DELETE` statement.
    Query { entity: String },
    /// An `EXPLAIN` wrapping a `SELECT` or `DELETE` statement.
    Explain { entity: String },
    /// A `DESCRIBE` statement.
    Describe { entity: String },
    /// A `SHOW INDEXES` statement.
    ShowIndexes { entity: String },
}

impl SqlStatementRoute {
    /// Entity name targeted by the routed statement, whatever its kind.
    #[must_use]
    pub const fn entity(&self) -> &str {
        // Every variant carries the same `entity` field, so a single
        // irrefutable or-pattern extracts it without a match.
        let (Self::Query { entity }
        | Self::Explain { entity }
        | Self::Describe { entity }
        | Self::ShowIndexes { entity }) = self;

        entity.as_str()
    }

    /// `true` when this route is an `EXPLAIN` statement.
    #[must_use]
    pub const fn is_explain(&self) -> bool {
        matches!(*self, Self::Explain { .. })
    }

    /// `true` when this route is a `DESCRIBE` statement.
    #[must_use]
    pub const fn is_describe(&self) -> bool {
        matches!(*self, Self::Describe { .. })
    }

    /// `true` when this route is a `SHOW INDEXES` statement.
    #[must_use]
    pub const fn is_show_indexes(&self) -> bool {
        matches!(*self, Self::ShowIndexes { .. })
    }
}
69
70fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
72 match err {
73 SqlLoweringError::Query(err) => err,
74 SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
75 feature,
76 }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
77 other => QueryError::execute(InternalError::classified(
78 ErrorClass::Unsupported,
79 ErrorOrigin::Query,
80 format!("SQL query is not executable in this release: {other}"),
81 )),
82 }
83}
84
85fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
88 map_sql_lowering_error(SqlLoweringError::Parse(err))
89}
90
91fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
94 FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
95 QueryError::execute(crate::db::error::executor_unsupported(format!(
96 "unknown aggregate target field: {field}",
97 )))
98 })
99}
100
101fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
104 terminal: &SqlGlobalAggregateTerminal,
105) -> Result<AggregateExpr, QueryError> {
106 match terminal {
107 SqlGlobalAggregateTerminal::CountRows => Ok(count()),
108 SqlGlobalAggregateTerminal::CountField(field) => {
109 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
110
111 Ok(count_by(field.as_str()))
112 }
113 SqlGlobalAggregateTerminal::SumField(field) => {
114 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
115
116 Ok(sum(field.as_str()))
117 }
118 SqlGlobalAggregateTerminal::AvgField(field) => {
119 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
120
121 Ok(avg(field.as_str()))
122 }
123 SqlGlobalAggregateTerminal::MinField(field) => {
124 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
125
126 Ok(min_by(field.as_str()))
127 }
128 SqlGlobalAggregateTerminal::MaxField(field) => {
129 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
130
131 Ok(max_by(field.as_str()))
132 }
133 }
134}
135
136fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
138 let kind = match aggregate.kind() {
139 AggregateKind::Count => "COUNT",
140 AggregateKind::Sum => "SUM",
141 AggregateKind::Avg => "AVG",
142 AggregateKind::Exists => "EXISTS",
143 AggregateKind::First => "FIRST",
144 AggregateKind::Last => "LAST",
145 AggregateKind::Min => "MIN",
146 AggregateKind::Max => "MAX",
147 };
148 let distinct = if aggregate.is_distinct() {
149 "DISTINCT "
150 } else {
151 ""
152 };
153
154 if let Some(field) = aggregate.target_field() {
155 return format!("{kind}({distinct}{field})");
156 }
157
158 format!("{kind}({distinct}*)")
159}
160
161fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
163 match expr {
164 Expr::Field(field) => field.as_str().to_string(),
165 Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
166 Expr::Alias { name, .. } => name.as_str().to_string(),
167 Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
168 format!("expr_{ordinal}")
169 }
170 }
171}
172
173fn projection_labels_from_query<E: EntityKind>(
175 query: &Query<E>,
176) -> Result<Vec<String>, QueryError> {
177 let projection = query.plan()?.projection_spec();
178 let mut labels = Vec::with_capacity(projection.len());
179
180 for (ordinal, field) in projection.fields().enumerate() {
181 match field {
182 ProjectionField::Scalar {
183 expr: _,
184 alias: Some(alias),
185 } => labels.push(alias.as_str().to_string()),
186 ProjectionField::Scalar { expr, alias: None } => {
187 labels.push(projection_label_from_expr(expr, ordinal));
188 }
189 }
190 }
191
192 Ok(labels)
193}
194
195impl<C: CanisterKind> DbSession<C> {
196 pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
201 let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
202 match statement {
203 SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
204 entity: select.entity,
205 }),
206 SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
207 entity: delete.entity,
208 }),
209 SqlStatement::Explain(explain) => match explain.statement {
210 SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
211 entity: select.entity,
212 }),
213 SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
214 entity: delete.entity,
215 }),
216 },
217 SqlStatement::Describe(describe) => Ok(SqlStatementRoute::Describe {
218 entity: describe.entity,
219 }),
220 SqlStatement::ShowIndexes(show_indexes) => Ok(SqlStatementRoute::ShowIndexes {
221 entity: show_indexes.entity,
222 }),
223 }
224 }
225
226 pub fn describe_sql<E>(&self, sql: &str) -> Result<EntitySchemaDescription, QueryError>
228 where
229 E: EntityKind<Canister = C>,
230 {
231 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
232 .map_err(map_sql_lowering_error)?;
233
234 match command {
235 SqlCommand::DescribeEntity => Ok(self.describe_entity::<E>()),
236 SqlCommand::Query(_)
237 | SqlCommand::Explain { .. }
238 | SqlCommand::ExplainGlobalAggregate { .. }
239 | SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
240 ErrorClass::Unsupported,
241 ErrorOrigin::Query,
242 "describe_sql requires a DESCRIBE statement",
243 ))),
244 }
245 }
246
247 pub fn show_indexes_sql<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
249 where
250 E: EntityKind<Canister = C>,
251 {
252 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
253 .map_err(map_sql_lowering_error)?;
254
255 match command {
256 SqlCommand::ShowIndexesEntity => Ok(self.show_indexes::<E>()),
257 SqlCommand::Query(_)
258 | SqlCommand::Explain { .. }
259 | SqlCommand::ExplainGlobalAggregate { .. }
260 | SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
261 ErrorClass::Unsupported,
262 ErrorOrigin::Query,
263 "show_indexes_sql requires a SHOW INDEXES statement",
264 ))),
265 }
266 }
267
268 pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
273 where
274 E: EntityKind<Canister = C>,
275 {
276 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
277 .map_err(map_sql_lowering_error)?;
278
279 match command {
280 SqlCommand::Query(query) => Ok(query),
281 SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
282 Err(QueryError::execute(InternalError::classified(
283 ErrorClass::Unsupported,
284 ErrorOrigin::Query,
285 "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)",
286 )))
287 }
288 SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
289 ErrorClass::Unsupported,
290 ErrorOrigin::Query,
291 "query_from_sql does not accept DESCRIBE statements; use describe_sql(...)",
292 ))),
293 SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
294 ErrorClass::Unsupported,
295 ErrorOrigin::Query,
296 "query_from_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
297 ))),
298 }
299 }
300
301 pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
303 where
304 E: EntityKind<Canister = C>,
305 {
306 let query = self.query_from_sql::<E>(sql)?;
307 if query.has_grouping() {
308 return Err(QueryError::Intent(
309 IntentError::GroupedRequiresExecuteGrouped,
310 ));
311 }
312
313 match query.mode() {
314 QueryMode::Load(_) => projection_labels_from_query(&query),
315 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
316 ErrorClass::Unsupported,
317 ErrorOrigin::Query,
318 "sql_projection_columns only supports SELECT statements",
319 ))),
320 }
321 }
322
323 pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
325 where
326 E: EntityKind<Canister = C> + EntityValue,
327 {
328 let query = self.query_from_sql::<E>(sql)?;
329 if query.has_grouping() {
330 return Err(QueryError::Intent(
331 IntentError::GroupedRequiresExecuteGrouped,
332 ));
333 }
334
335 self.execute_query(&query)
336 }
337
338 pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
343 where
344 E: EntityKind<Canister = C> + EntityValue,
345 {
346 let query = self.query_from_sql::<E>(sql)?;
347 if query.has_grouping() {
348 return Err(QueryError::Intent(
349 IntentError::GroupedRequiresExecuteGrouped,
350 ));
351 }
352
353 match query.mode() {
354 QueryMode::Load(_) => {
355 self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
356 }
357 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
358 ErrorClass::Unsupported,
359 ErrorOrigin::Query,
360 "execute_sql_projection only supports SELECT statements",
361 ))),
362 }
363 }
364
365 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
370 where
371 E: EntityKind<Canister = C> + EntityValue,
372 {
373 let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
374 .map_err(map_sql_lowering_error)?;
375
376 match command.terminal() {
377 SqlGlobalAggregateTerminal::CountRows => self
378 .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
379 .map(|count| Value::Uint(u64::from(count))),
380 SqlGlobalAggregateTerminal::CountField(field) => {
381 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
382 self.execute_load_query_with(command.query(), |load, plan| {
383 load.values_by_slot(plan, target_slot)
384 })
385 .map(|values| {
386 let count = values
387 .into_iter()
388 .filter(|value| !matches!(value, Value::Null))
389 .count();
390 Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
391 })
392 }
393 SqlGlobalAggregateTerminal::SumField(field) => {
394 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
395 self.execute_load_query_with(command.query(), |load, plan| {
396 load.aggregate_sum_by_slot(plan, target_slot)
397 })
398 .map(|value| value.map_or(Value::Null, Value::Decimal))
399 }
400 SqlGlobalAggregateTerminal::AvgField(field) => {
401 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
402 self.execute_load_query_with(command.query(), |load, plan| {
403 load.aggregate_avg_by_slot(plan, target_slot)
404 })
405 .map(|value| value.map_or(Value::Null, Value::Decimal))
406 }
407 SqlGlobalAggregateTerminal::MinField(field) => {
408 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
409 let min_id = self.execute_load_query_with(command.query(), |load, plan| {
410 load.aggregate_min_by_slot(plan, target_slot)
411 })?;
412
413 match min_id {
414 Some(id) => self
415 .load::<E>()
416 .by_id(id)
417 .first_value_by(field)
418 .map(|value| value.unwrap_or(Value::Null)),
419 None => Ok(Value::Null),
420 }
421 }
422 SqlGlobalAggregateTerminal::MaxField(field) => {
423 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
424 let max_id = self.execute_load_query_with(command.query(), |load, plan| {
425 load.aggregate_max_by_slot(plan, target_slot)
426 })?;
427
428 match max_id {
429 Some(id) => self
430 .load::<E>()
431 .by_id(id)
432 .first_value_by(field)
433 .map(|value| value.unwrap_or(Value::Null)),
434 None => Ok(Value::Null),
435 }
436 }
437 }
438 }
439
440 pub fn execute_sql_grouped<E>(
442 &self,
443 sql: &str,
444 cursor_token: Option<&str>,
445 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
446 where
447 E: EntityKind<Canister = C> + EntityValue,
448 {
449 let query = self.query_from_sql::<E>(sql)?;
450 if !query.has_grouping() {
451 return Err(QueryError::execute(InternalError::classified(
452 ErrorClass::Unsupported,
453 ErrorOrigin::Query,
454 "execute_sql_grouped requires grouped SQL query intent",
455 )));
456 }
457
458 self.execute_grouped(&query, cursor_token)
459 }
460
461 pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
468 where
469 E: EntityKind<Canister = C> + EntityValue,
470 {
471 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
472 .map_err(map_sql_lowering_error)?;
473
474 match command {
475 SqlCommand::Query(_) => Err(QueryError::execute(InternalError::classified(
476 ErrorClass::Unsupported,
477 ErrorOrigin::Query,
478 "explain_sql requires an EXPLAIN statement",
479 ))),
480 SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
481 ErrorClass::Unsupported,
482 ErrorOrigin::Query,
483 "explain_sql does not accept DESCRIBE statements; use describe_sql(...)",
484 ))),
485 SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
486 ErrorClass::Unsupported,
487 ErrorOrigin::Query,
488 "explain_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
489 ))),
490 SqlCommand::Explain { mode, query } => match mode {
491 SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
492 SqlExplainMode::Execution => query.explain_execution_text(),
493 SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
494 },
495 SqlCommand::ExplainGlobalAggregate { mode, command } => {
496 Self::explain_sql_global_aggregate::<E>(mode, command)
497 }
498 }
499 }
500
501 fn explain_sql_global_aggregate<E>(
503 mode: SqlExplainMode,
504 command: SqlGlobalAggregateCommand<E>,
505 ) -> Result<String, QueryError>
506 where
507 E: EntityKind<Canister = C> + EntityValue,
508 {
509 match mode {
510 SqlExplainMode::Plan => {
511 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
514
515 Ok(command.query().explain()?.render_text_canonical())
516 }
517 SqlExplainMode::Execution => {
518 let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
519 let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
520
521 Ok(plan.execution_node_descriptor().render_text_tree())
522 }
523 SqlExplainMode::Json => {
524 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
527
528 Ok(command.query().explain()?.render_json_canonical())
529 }
530 }
531 }
532}