1use crate::{
2 db::{
3 DbSession, EntityResponse, MissingRowPolicy, PagedGroupedExecutionWithTrace,
4 ProjectionResponse, Query, QueryError,
5 query::{
6 builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7 intent::IntentError,
8 plan::{
9 AggregateKind, FieldSlot, QueryMode,
10 expr::{Expr, ProjectionField},
11 },
12 },
13 sql::lowering::{
14 SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15 compile_sql_command, compile_sql_global_aggregate_command,
16 },
17 sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18 },
19 error::{ErrorClass, ErrorOrigin, InternalError},
20 traits::{CanisterKind, EntityKind, EntityValue},
21 value::Value,
22};
23
/// Routing summary for a parsed SQL statement.
///
/// Produced by `DbSession::sql_statement_route` from the parsed statement
/// alone, so a caller can pick the target entity and the matching entry point
/// before any entity-specific lowering happens.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlStatementRoute {
    /// A directly executable statement (`SELECT` or `DELETE`).
    Query {
        /// Name of the entity the statement targets.
        entity: String,
    },
    /// An `EXPLAIN` wrapping a `SELECT` or `DELETE`.
    Explain {
        /// Name of the entity the explained statement targets.
        entity: String,
    },
}

impl SqlStatementRoute {
    /// Name of the entity this statement targets, whatever the route kind.
    #[must_use]
    pub const fn entity(&self) -> &str {
        match self {
            Self::Query { entity } => entity.as_str(),
            Self::Explain { entity } => entity.as_str(),
        }
    }

    /// `true` for an `EXPLAIN` statement, `false` for a directly executable one.
    #[must_use]
    pub const fn is_explain(&self) -> bool {
        match self {
            Self::Explain { .. } => true,
            Self::Query { .. } => false,
        }
    }
}
51
52fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
54 match err {
55 SqlLoweringError::Query(err) => err,
56 SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
57 feature,
58 }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
59 other => QueryError::execute(InternalError::classified(
60 ErrorClass::Unsupported,
61 ErrorOrigin::Query,
62 format!("SQL query is not executable in this release: {other}"),
63 )),
64 }
65}
66
67fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
70 map_sql_lowering_error(SqlLoweringError::Parse(err))
71}
72
73fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
76 FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
77 QueryError::execute(crate::db::error::executor_unsupported(format!(
78 "unknown aggregate target field: {field}",
79 )))
80 })
81}
82
83fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
86 terminal: &SqlGlobalAggregateTerminal,
87) -> Result<AggregateExpr, QueryError> {
88 match terminal {
89 SqlGlobalAggregateTerminal::CountRows => Ok(count()),
90 SqlGlobalAggregateTerminal::CountField(field) => {
91 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
92
93 Ok(count_by(field.as_str()))
94 }
95 SqlGlobalAggregateTerminal::SumField(field) => {
96 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
97
98 Ok(sum(field.as_str()))
99 }
100 SqlGlobalAggregateTerminal::AvgField(field) => {
101 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
102
103 Ok(avg(field.as_str()))
104 }
105 SqlGlobalAggregateTerminal::MinField(field) => {
106 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
107
108 Ok(min_by(field.as_str()))
109 }
110 SqlGlobalAggregateTerminal::MaxField(field) => {
111 let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
112
113 Ok(max_by(field.as_str()))
114 }
115 }
116}
117
118fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
120 let kind = match aggregate.kind() {
121 AggregateKind::Count => "COUNT",
122 AggregateKind::Sum => "SUM",
123 AggregateKind::Avg => "AVG",
124 AggregateKind::Exists => "EXISTS",
125 AggregateKind::First => "FIRST",
126 AggregateKind::Last => "LAST",
127 AggregateKind::Min => "MIN",
128 AggregateKind::Max => "MAX",
129 };
130 let distinct = if aggregate.is_distinct() {
131 "DISTINCT "
132 } else {
133 ""
134 };
135
136 if let Some(field) = aggregate.target_field() {
137 return format!("{kind}({distinct}{field})");
138 }
139
140 format!("{kind}({distinct}*)")
141}
142
143fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
145 match expr {
146 Expr::Field(field) => field.as_str().to_string(),
147 Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
148 Expr::Alias { name, .. } => name.as_str().to_string(),
149 Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
150 format!("expr_{ordinal}")
151 }
152 }
153}
154
155fn projection_labels_from_query<E: EntityKind>(
157 query: &Query<E>,
158) -> Result<Vec<String>, QueryError> {
159 let projection = query.plan()?.projection_spec();
160 let mut labels = Vec::with_capacity(projection.len());
161
162 for (ordinal, field) in projection.fields().enumerate() {
163 match field {
164 ProjectionField::Scalar {
165 expr: _,
166 alias: Some(alias),
167 } => labels.push(alias.as_str().to_string()),
168 ProjectionField::Scalar { expr, alias: None } => {
169 labels.push(projection_label_from_expr(expr, ordinal));
170 }
171 }
172 }
173
174 Ok(labels)
175}
176
177impl<C: CanisterKind> DbSession<C> {
178 pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
183 let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
184 match statement {
185 SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
186 entity: select.entity,
187 }),
188 SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
189 entity: delete.entity,
190 }),
191 SqlStatement::Explain(explain) => match explain.statement {
192 SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
193 entity: select.entity,
194 }),
195 SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
196 entity: delete.entity,
197 }),
198 },
199 }
200 }
201
202 pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
207 where
208 E: EntityKind<Canister = C>,
209 {
210 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
211 .map_err(map_sql_lowering_error)?;
212
213 match command {
214 SqlCommand::Query(query) => Ok(query),
215 SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
216 Err(QueryError::execute(InternalError::classified(
217 ErrorClass::Unsupported,
218 ErrorOrigin::Query,
219 "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)",
220 )))
221 }
222 }
223 }
224
225 pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
227 where
228 E: EntityKind<Canister = C>,
229 {
230 let query = self.query_from_sql::<E>(sql)?;
231 if query.has_grouping() {
232 return Err(QueryError::Intent(
233 IntentError::GroupedRequiresExecuteGrouped,
234 ));
235 }
236
237 match query.mode() {
238 QueryMode::Load(_) => projection_labels_from_query(&query),
239 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
240 ErrorClass::Unsupported,
241 ErrorOrigin::Query,
242 "sql_projection_columns only supports SELECT statements",
243 ))),
244 }
245 }
246
247 pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
249 where
250 E: EntityKind<Canister = C> + EntityValue,
251 {
252 let query = self.query_from_sql::<E>(sql)?;
253 if query.has_grouping() {
254 return Err(QueryError::Intent(
255 IntentError::GroupedRequiresExecuteGrouped,
256 ));
257 }
258
259 self.execute_query(&query)
260 }
261
262 pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
267 where
268 E: EntityKind<Canister = C> + EntityValue,
269 {
270 let query = self.query_from_sql::<E>(sql)?;
271 if query.has_grouping() {
272 return Err(QueryError::Intent(
273 IntentError::GroupedRequiresExecuteGrouped,
274 ));
275 }
276
277 match query.mode() {
278 QueryMode::Load(_) => {
279 self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
280 }
281 QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
282 ErrorClass::Unsupported,
283 ErrorOrigin::Query,
284 "execute_sql_projection only supports SELECT statements",
285 ))),
286 }
287 }
288
289 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
294 where
295 E: EntityKind<Canister = C> + EntityValue,
296 {
297 let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
298 .map_err(map_sql_lowering_error)?;
299
300 match command.terminal() {
301 SqlGlobalAggregateTerminal::CountRows => self
302 .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
303 .map(|count| Value::Uint(u64::from(count))),
304 SqlGlobalAggregateTerminal::CountField(field) => {
305 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
306 self.execute_load_query_with(command.query(), |load, plan| {
307 load.values_by_slot(plan, target_slot)
308 })
309 .map(|values| {
310 let count = values
311 .into_iter()
312 .filter(|value| !matches!(value, Value::Null))
313 .count();
314 Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
315 })
316 }
317 SqlGlobalAggregateTerminal::SumField(field) => {
318 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
319 self.execute_load_query_with(command.query(), |load, plan| {
320 load.aggregate_sum_by_slot(plan, target_slot)
321 })
322 .map(|value| value.map_or(Value::Null, Value::Decimal))
323 }
324 SqlGlobalAggregateTerminal::AvgField(field) => {
325 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
326 self.execute_load_query_with(command.query(), |load, plan| {
327 load.aggregate_avg_by_slot(plan, target_slot)
328 })
329 .map(|value| value.map_or(Value::Null, Value::Decimal))
330 }
331 SqlGlobalAggregateTerminal::MinField(field) => {
332 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
333 let min_id = self.execute_load_query_with(command.query(), |load, plan| {
334 load.aggregate_min_by_slot(plan, target_slot)
335 })?;
336
337 match min_id {
338 Some(id) => self
339 .load::<E>()
340 .by_id(id)
341 .first_value_by(field)
342 .map(|value| value.unwrap_or(Value::Null)),
343 None => Ok(Value::Null),
344 }
345 }
346 SqlGlobalAggregateTerminal::MaxField(field) => {
347 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
348 let max_id = self.execute_load_query_with(command.query(), |load, plan| {
349 load.aggregate_max_by_slot(plan, target_slot)
350 })?;
351
352 match max_id {
353 Some(id) => self
354 .load::<E>()
355 .by_id(id)
356 .first_value_by(field)
357 .map(|value| value.unwrap_or(Value::Null)),
358 None => Ok(Value::Null),
359 }
360 }
361 }
362 }
363
364 pub fn execute_sql_grouped<E>(
366 &self,
367 sql: &str,
368 cursor_token: Option<&str>,
369 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
370 where
371 E: EntityKind<Canister = C> + EntityValue,
372 {
373 let query = self.query_from_sql::<E>(sql)?;
374 if !query.has_grouping() {
375 return Err(QueryError::execute(InternalError::classified(
376 ErrorClass::Unsupported,
377 ErrorOrigin::Query,
378 "execute_sql_grouped requires grouped SQL query intent",
379 )));
380 }
381
382 self.execute_grouped(&query, cursor_token)
383 }
384
385 pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
392 where
393 E: EntityKind<Canister = C> + EntityValue,
394 {
395 let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
396 .map_err(map_sql_lowering_error)?;
397
398 match command {
399 SqlCommand::Query(_) => Err(QueryError::execute(InternalError::classified(
400 ErrorClass::Unsupported,
401 ErrorOrigin::Query,
402 "explain_sql requires an EXPLAIN statement",
403 ))),
404 SqlCommand::Explain { mode, query } => match mode {
405 SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
406 SqlExplainMode::Execution => query.explain_execution_text(),
407 SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
408 },
409 SqlCommand::ExplainGlobalAggregate { mode, command } => {
410 Self::explain_sql_global_aggregate::<E>(mode, command)
411 }
412 }
413 }
414
415 fn explain_sql_global_aggregate<E>(
417 mode: SqlExplainMode,
418 command: SqlGlobalAggregateCommand<E>,
419 ) -> Result<String, QueryError>
420 where
421 E: EntityKind<Canister = C> + EntityValue,
422 {
423 match mode {
424 SqlExplainMode::Plan => {
425 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
428
429 Ok(command.query().explain()?.render_text_canonical())
430 }
431 SqlExplainMode::Execution => {
432 let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
433 let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
434
435 Ok(plan.execution_node_descriptor().render_text_tree())
436 }
437 SqlExplainMode::Json => {
438 let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
441
442 Ok(command.query().explain()?.render_json_canonical())
443 }
444 }
445 }
446}