Skip to main content

icydb_core/db/session/
sql.rs

1use crate::{
2    db::{
3        DbSession, EntityResponse, EntitySchemaDescription, MissingRowPolicy,
4        PagedGroupedExecutionWithTrace, ProjectionResponse, Query, QueryError,
5        query::{
6            builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7            intent::IntentError,
8            plan::{
9                AggregateKind, FieldSlot, QueryMode,
10                expr::{Expr, ProjectionField},
11            },
12        },
13        sql::lowering::{
14            SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15            compile_sql_command, compile_sql_global_aggregate_command,
16        },
17        sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18    },
19    error::{ErrorClass, ErrorOrigin, InternalError},
20    traits::{CanisterKind, EntityKind, EntityValue},
21    value::Value,
22};
23
///
/// SqlStatementRoute
///
/// Routing metadata for one parsed reduced-SQL statement.
/// Each variant names the SQL surface the statement targets; every surface
/// except `ShowEntities` also carries the parsed entity identifier.
///
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlStatementRoute {
    /// Query surface, carrying the parsed entity identifier.
    Query { entity: String },
    /// `EXPLAIN` surface, carrying the parsed entity identifier.
    Explain { entity: String },
    /// `DESCRIBE` surface, carrying the parsed entity identifier.
    Describe { entity: String },
    /// `SHOW INDEXES` surface, carrying the parsed entity identifier.
    ShowIndexes { entity: String },
    /// `SHOW ENTITIES` surface; no entity identifier is attached.
    ShowEntities,
}

impl SqlStatementRoute {
    /// Borrow the entity identifier parsed for this statement.
    ///
    /// `SHOW ENTITIES` has no entity identifier, so this accessor yields an
    /// empty string for that variant.
    #[must_use]
    pub const fn entity(&self) -> &str {
        // The match stays exhaustive on purpose: a newly added variant must
        // decide explicitly which identifier (if any) it exposes.
        match self {
            Self::ShowEntities => "",
            Self::Query { entity }
            | Self::Explain { entity }
            | Self::Describe { entity }
            | Self::ShowIndexes { entity } => entity.as_str(),
        }
    }

    /// Report whether this route is the `EXPLAIN` surface.
    #[must_use]
    pub const fn is_explain(&self) -> bool {
        matches!(self, Self::Explain { .. })
    }

    /// Report whether this route is the `DESCRIBE` surface.
    #[must_use]
    pub const fn is_describe(&self) -> bool {
        matches!(self, Self::Describe { .. })
    }

    /// Report whether this route is the `SHOW INDEXES` surface.
    #[must_use]
    pub const fn is_show_indexes(&self) -> bool {
        matches!(self, Self::ShowIndexes { .. })
    }

    /// Report whether this route is the `SHOW ENTITIES` surface.
    #[must_use]
    pub const fn is_show_entities(&self) -> bool {
        matches!(self, Self::ShowEntities)
    }
}
80
81// Map SQL frontend parse/lowering failures into query-facing execution errors.
82fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
83    match err {
84        SqlLoweringError::Query(err) => err,
85        SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
86            feature,
87        }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
88        other => QueryError::execute(InternalError::classified(
89            ErrorClass::Unsupported,
90            ErrorOrigin::Query,
91            format!("SQL query is not executable in this release: {other}"),
92        )),
93    }
94}
95
96// Map reduced SQL parse failures through the same query-facing classification
97// policy used by SQL lowering entrypoints.
98fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
99    map_sql_lowering_error(SqlLoweringError::Parse(err))
100}
101
102// Resolve one aggregate target field through planner slot contracts before
103// aggregate terminal execution.
104fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
105    FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
106        QueryError::execute(crate::db::error::executor_unsupported(format!(
107            "unknown aggregate target field: {field}",
108        )))
109    })
110}
111
112// Convert one lowered global SQL aggregate terminal into aggregate expression
113// contracts used by aggregate explain execution descriptors.
114fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
115    terminal: &SqlGlobalAggregateTerminal,
116) -> Result<AggregateExpr, QueryError> {
117    match terminal {
118        SqlGlobalAggregateTerminal::CountRows => Ok(count()),
119        SqlGlobalAggregateTerminal::CountField(field) => {
120            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
121
122            Ok(count_by(field.as_str()))
123        }
124        SqlGlobalAggregateTerminal::SumField(field) => {
125            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
126
127            Ok(sum(field.as_str()))
128        }
129        SqlGlobalAggregateTerminal::AvgField(field) => {
130            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
131
132            Ok(avg(field.as_str()))
133        }
134        SqlGlobalAggregateTerminal::MinField(field) => {
135            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
136
137            Ok(min_by(field.as_str()))
138        }
139        SqlGlobalAggregateTerminal::MaxField(field) => {
140            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
141
142            Ok(max_by(field.as_str()))
143        }
144    }
145}
146
147// Render one aggregate expression into a canonical projection column label.
148fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
149    let kind = match aggregate.kind() {
150        AggregateKind::Count => "COUNT",
151        AggregateKind::Sum => "SUM",
152        AggregateKind::Avg => "AVG",
153        AggregateKind::Exists => "EXISTS",
154        AggregateKind::First => "FIRST",
155        AggregateKind::Last => "LAST",
156        AggregateKind::Min => "MIN",
157        AggregateKind::Max => "MAX",
158    };
159    let distinct = if aggregate.is_distinct() {
160        "DISTINCT "
161    } else {
162        ""
163    };
164
165    if let Some(field) = aggregate.target_field() {
166        return format!("{kind}({distinct}{field})");
167    }
168
169    format!("{kind}({distinct}*)")
170}
171
172// Render one projection expression into a canonical output label.
173fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
174    match expr {
175        Expr::Field(field) => field.as_str().to_string(),
176        Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
177        Expr::Alias { name, .. } => name.as_str().to_string(),
178        Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
179            format!("expr_{ordinal}")
180        }
181    }
182}
183
184// Derive canonical projection column labels from one planned query projection spec.
185fn projection_labels_from_query<E: EntityKind>(
186    query: &Query<E>,
187) -> Result<Vec<String>, QueryError> {
188    let projection = query.plan()?.projection_spec();
189    let mut labels = Vec::with_capacity(projection.len());
190
191    for (ordinal, field) in projection.fields().enumerate() {
192        match field {
193            ProjectionField::Scalar {
194                expr: _,
195                alias: Some(alias),
196            } => labels.push(alias.as_str().to_string()),
197            ProjectionField::Scalar { expr, alias: None } => {
198                labels.push(projection_label_from_expr(expr, ordinal));
199            }
200        }
201    }
202
203    Ok(labels)
204}
205
206impl<C: CanisterKind> DbSession<C> {
207    /// Parse one reduced SQL statement into canonical routing metadata.
208    ///
209    /// This method is the SQL dispatch authority for entity/surface routing
210    /// outside typed-entity lowering paths.
211    pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
212        let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
213        match statement {
214            SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
215                entity: select.entity,
216            }),
217            SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
218                entity: delete.entity,
219            }),
220            SqlStatement::Explain(explain) => match explain.statement {
221                SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
222                    entity: select.entity,
223                }),
224                SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
225                    entity: delete.entity,
226                }),
227            },
228            SqlStatement::Describe(describe) => Ok(SqlStatementRoute::Describe {
229                entity: describe.entity,
230            }),
231            SqlStatement::ShowIndexes(show_indexes) => Ok(SqlStatementRoute::ShowIndexes {
232                entity: show_indexes.entity,
233            }),
234            SqlStatement::ShowEntities(_) => Ok(SqlStatementRoute::ShowEntities),
235        }
236    }
237
238    /// Execute one reduced SQL `DESCRIBE` statement for entity `E`.
239    pub fn describe_sql<E>(&self, sql: &str) -> Result<EntitySchemaDescription, QueryError>
240    where
241        E: EntityKind<Canister = C>,
242    {
243        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
244            .map_err(map_sql_lowering_error)?;
245
246        match command {
247            SqlCommand::DescribeEntity => Ok(self.describe_entity::<E>()),
248            SqlCommand::Query(_)
249            | SqlCommand::Explain { .. }
250            | SqlCommand::ExplainGlobalAggregate { .. }
251            | SqlCommand::ShowIndexesEntity
252            | SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
253                ErrorClass::Unsupported,
254                ErrorOrigin::Query,
255                "describe_sql requires a DESCRIBE statement",
256            ))),
257        }
258    }
259
260    /// Execute one reduced SQL `SHOW INDEXES` statement for entity `E`.
261    pub fn show_indexes_sql<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
262    where
263        E: EntityKind<Canister = C>,
264    {
265        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
266            .map_err(map_sql_lowering_error)?;
267
268        match command {
269            SqlCommand::ShowIndexesEntity => Ok(self.show_indexes::<E>()),
270            SqlCommand::Query(_)
271            | SqlCommand::Explain { .. }
272            | SqlCommand::ExplainGlobalAggregate { .. }
273            | SqlCommand::DescribeEntity
274            | SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
275                ErrorClass::Unsupported,
276                ErrorOrigin::Query,
277                "show_indexes_sql requires a SHOW INDEXES statement",
278            ))),
279        }
280    }
281
282    /// Execute one reduced SQL `SHOW ENTITIES` statement.
283    pub fn show_entities_sql(&self, sql: &str) -> Result<Vec<String>, QueryError> {
284        let statement = self.sql_statement_route(sql)?;
285        if !statement.is_show_entities() {
286            return Err(QueryError::execute(InternalError::classified(
287                ErrorClass::Unsupported,
288                ErrorOrigin::Query,
289                "show_entities_sql requires a SHOW ENTITIES statement",
290            )));
291        }
292
293        Ok(self.show_entities())
294    }
295
296    /// Build one typed query intent from one reduced SQL statement.
297    ///
298    /// This parser/lowering entrypoint is intentionally constrained to the
299    /// executable subset wired in the current release.
300    pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
301    where
302        E: EntityKind<Canister = C>,
303    {
304        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
305            .map_err(map_sql_lowering_error)?;
306
307        match command {
308            SqlCommand::Query(query) => Ok(query),
309            SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
310                Err(QueryError::execute(InternalError::classified(
311                    ErrorClass::Unsupported,
312                    ErrorOrigin::Query,
313                    "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)",
314                )))
315            }
316            SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
317                ErrorClass::Unsupported,
318                ErrorOrigin::Query,
319                "query_from_sql does not accept DESCRIBE statements; use describe_sql(...)",
320            ))),
321            SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
322                ErrorClass::Unsupported,
323                ErrorOrigin::Query,
324                "query_from_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
325            ))),
326            SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
327                ErrorClass::Unsupported,
328                ErrorOrigin::Query,
329                "query_from_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)",
330            ))),
331        }
332    }
333
334    /// Derive canonical projection column labels for one reduced SQL `SELECT` statement.
335    pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
336    where
337        E: EntityKind<Canister = C>,
338    {
339        let query = self.query_from_sql::<E>(sql)?;
340        if query.has_grouping() {
341            return Err(QueryError::Intent(
342                IntentError::GroupedRequiresExecuteGrouped,
343            ));
344        }
345
346        match query.mode() {
347            QueryMode::Load(_) => projection_labels_from_query(&query),
348            QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
349                ErrorClass::Unsupported,
350                ErrorOrigin::Query,
351                "sql_projection_columns only supports SELECT statements",
352            ))),
353        }
354    }
355
356    /// Execute one reduced SQL `SELECT`/`DELETE` statement for entity `E`.
357    pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
358    where
359        E: EntityKind<Canister = C> + EntityValue,
360    {
361        let query = self.query_from_sql::<E>(sql)?;
362        if query.has_grouping() {
363            return Err(QueryError::Intent(
364                IntentError::GroupedRequiresExecuteGrouped,
365            ));
366        }
367
368        self.execute_query(&query)
369    }
370
371    /// Execute one reduced SQL `SELECT` statement and return projection-shaped rows.
372    ///
373    /// This surface keeps `execute_sql(...)` stable for callers
374    /// that consume full entity rows.
375    pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
376    where
377        E: EntityKind<Canister = C> + EntityValue,
378    {
379        let query = self.query_from_sql::<E>(sql)?;
380        if query.has_grouping() {
381            return Err(QueryError::Intent(
382                IntentError::GroupedRequiresExecuteGrouped,
383            ));
384        }
385
386        match query.mode() {
387            QueryMode::Load(_) => {
388                self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
389            }
390            QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
391                ErrorClass::Unsupported,
392                ErrorOrigin::Query,
393                "execute_sql_projection only supports SELECT statements",
394            ))),
395        }
396    }
397
398    /// Execute one reduced SQL global aggregate `SELECT` statement.
399    ///
400    /// This entrypoint is intentionally constrained to one aggregate terminal
401    /// shape per statement and preserves existing terminal semantics.
402    pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
403    where
404        E: EntityKind<Canister = C> + EntityValue,
405    {
406        let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
407            .map_err(map_sql_lowering_error)?;
408
409        match command.terminal() {
410            SqlGlobalAggregateTerminal::CountRows => self
411                .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
412                .map(|count| Value::Uint(u64::from(count))),
413            SqlGlobalAggregateTerminal::CountField(field) => {
414                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
415                self.execute_load_query_with(command.query(), |load, plan| {
416                    load.values_by_slot(plan, target_slot)
417                })
418                .map(|values| {
419                    let count = values
420                        .into_iter()
421                        .filter(|value| !matches!(value, Value::Null))
422                        .count();
423                    Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
424                })
425            }
426            SqlGlobalAggregateTerminal::SumField(field) => {
427                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
428                self.execute_load_query_with(command.query(), |load, plan| {
429                    load.aggregate_sum_by_slot(plan, target_slot)
430                })
431                .map(|value| value.map_or(Value::Null, Value::Decimal))
432            }
433            SqlGlobalAggregateTerminal::AvgField(field) => {
434                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
435                self.execute_load_query_with(command.query(), |load, plan| {
436                    load.aggregate_avg_by_slot(plan, target_slot)
437                })
438                .map(|value| value.map_or(Value::Null, Value::Decimal))
439            }
440            SqlGlobalAggregateTerminal::MinField(field) => {
441                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
442                let min_id = self.execute_load_query_with(command.query(), |load, plan| {
443                    load.aggregate_min_by_slot(plan, target_slot)
444                })?;
445
446                match min_id {
447                    Some(id) => self
448                        .load::<E>()
449                        .by_id(id)
450                        .first_value_by(field)
451                        .map(|value| value.unwrap_or(Value::Null)),
452                    None => Ok(Value::Null),
453                }
454            }
455            SqlGlobalAggregateTerminal::MaxField(field) => {
456                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
457                let max_id = self.execute_load_query_with(command.query(), |load, plan| {
458                    load.aggregate_max_by_slot(plan, target_slot)
459                })?;
460
461                match max_id {
462                    Some(id) => self
463                        .load::<E>()
464                        .by_id(id)
465                        .first_value_by(field)
466                        .map(|value| value.unwrap_or(Value::Null)),
467                    None => Ok(Value::Null),
468                }
469            }
470        }
471    }
472
473    /// Execute one reduced SQL grouped `SELECT` statement and return grouped rows.
474    pub fn execute_sql_grouped<E>(
475        &self,
476        sql: &str,
477        cursor_token: Option<&str>,
478    ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
479    where
480        E: EntityKind<Canister = C> + EntityValue,
481    {
482        let query = self.query_from_sql::<E>(sql)?;
483        if !query.has_grouping() {
484            return Err(QueryError::execute(InternalError::classified(
485                ErrorClass::Unsupported,
486                ErrorOrigin::Query,
487                "execute_sql_grouped requires grouped SQL query intent",
488            )));
489        }
490
491        self.execute_grouped(&query, cursor_token)
492    }
493
494    /// Explain one reduced SQL statement for entity `E`.
495    ///
496    /// Supported modes:
497    /// - `EXPLAIN ...` -> logical plan text
498    /// - `EXPLAIN EXECUTION ...` -> execution descriptor text
499    /// - `EXPLAIN JSON ...` -> logical plan canonical JSON
500    pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
501    where
502        E: EntityKind<Canister = C> + EntityValue,
503    {
504        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
505            .map_err(map_sql_lowering_error)?;
506
507        match command {
508            SqlCommand::Query(_) => Err(QueryError::execute(InternalError::classified(
509                ErrorClass::Unsupported,
510                ErrorOrigin::Query,
511                "explain_sql requires an EXPLAIN statement",
512            ))),
513            SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
514                ErrorClass::Unsupported,
515                ErrorOrigin::Query,
516                "explain_sql does not accept DESCRIBE statements; use describe_sql(...)",
517            ))),
518            SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
519                ErrorClass::Unsupported,
520                ErrorOrigin::Query,
521                "explain_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
522            ))),
523            SqlCommand::ShowEntities => Err(QueryError::execute(InternalError::classified(
524                ErrorClass::Unsupported,
525                ErrorOrigin::Query,
526                "explain_sql does not accept SHOW ENTITIES statements; use show_entities_sql(...)",
527            ))),
528            SqlCommand::Explain { mode, query } => match mode {
529                SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
530                SqlExplainMode::Execution => query.explain_execution_text(),
531                SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
532            },
533            SqlCommand::ExplainGlobalAggregate { mode, command } => {
534                Self::explain_sql_global_aggregate::<E>(mode, command)
535            }
536        }
537    }
538
539    // Render one EXPLAIN payload for constrained global aggregate SQL command.
540    fn explain_sql_global_aggregate<E>(
541        mode: SqlExplainMode,
542        command: SqlGlobalAggregateCommand<E>,
543    ) -> Result<String, QueryError>
544    where
545        E: EntityKind<Canister = C> + EntityValue,
546    {
547        match mode {
548            SqlExplainMode::Plan => {
549                // Keep explain validation parity with execution by requiring the
550                // target field to resolve before returning explain output.
551                let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
552
553                Ok(command.query().explain()?.render_text_canonical())
554            }
555            SqlExplainMode::Execution => {
556                let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
557                let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
558
559                Ok(plan.execution_node_descriptor().render_text_tree())
560            }
561            SqlExplainMode::Json => {
562                // Keep explain validation parity with execution by requiring the
563                // target field to resolve before returning explain output.
564                let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
565
566                Ok(command.query().explain()?.render_json_canonical())
567            }
568        }
569    }
570}