Skip to main content

icydb_core/db/session/
sql.rs

1use crate::{
2    db::{
3        DbSession, EntityResponse, EntitySchemaDescription, MissingRowPolicy,
4        PagedGroupedExecutionWithTrace, ProjectionResponse, Query, QueryError,
5        query::{
6            builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7            intent::IntentError,
8            plan::{
9                AggregateKind, FieldSlot, QueryMode,
10                expr::{Expr, ProjectionField},
11            },
12        },
13        sql::lowering::{
14            SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15            compile_sql_command, compile_sql_global_aggregate_command,
16        },
17        sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18    },
19    error::{ErrorClass, ErrorOrigin, InternalError},
20    traits::{CanisterKind, EntityKind, EntityValue},
21    value::Value,
22};
23
///
/// SqlStatementRoute
///
/// Routing metadata for one parsed reduced-SQL statement.
/// Each variant names the SQL surface the statement belongs to
/// (`Query` / `Explain` / `Describe` / `ShowIndexes`) and carries the entity
/// identifier reported by the parser.
///
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlStatementRoute {
    Query { entity: String },
    Explain { entity: String },
    Describe { entity: String },
    ShowIndexes { entity: String },
}

impl SqlStatementRoute {
    /// Return the entity identifier carried by this route, whatever its surface.
    #[must_use]
    pub const fn entity(&self) -> &str {
        // All variants share the same single-field shape, so one irrefutable
        // or-pattern is enough to reach the payload.
        let (Self::Query { entity }
        | Self::Explain { entity }
        | Self::Describe { entity }
        | Self::ShowIndexes { entity }) = self;

        entity.as_str()
    }

    /// `true` only for the `Explain` variant.
    #[must_use]
    pub const fn is_explain(&self) -> bool {
        matches!(self, Self::Explain { .. })
    }

    /// `true` only for the `Describe` variant.
    #[must_use]
    pub const fn is_describe(&self) -> bool {
        matches!(self, Self::Describe { .. })
    }

    /// `true` only for the `ShowIndexes` variant (`SHOW INDEXES` surface).
    #[must_use]
    pub const fn is_show_indexes(&self) -> bool {
        matches!(self, Self::ShowIndexes { .. })
    }
}
69
70// Map SQL frontend parse/lowering failures into query-facing execution errors.
71fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
72    match err {
73        SqlLoweringError::Query(err) => err,
74        SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
75            feature,
76        }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
77        other => QueryError::execute(InternalError::classified(
78            ErrorClass::Unsupported,
79            ErrorOrigin::Query,
80            format!("SQL query is not executable in this release: {other}"),
81        )),
82    }
83}
84
85// Map reduced SQL parse failures through the same query-facing classification
86// policy used by SQL lowering entrypoints.
87fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
88    map_sql_lowering_error(SqlLoweringError::Parse(err))
89}
90
91// Resolve one aggregate target field through planner slot contracts before
92// aggregate terminal execution.
93fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
94    FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
95        QueryError::execute(crate::db::error::executor_unsupported(format!(
96            "unknown aggregate target field: {field}",
97        )))
98    })
99}
100
101// Convert one lowered global SQL aggregate terminal into aggregate expression
102// contracts used by aggregate explain execution descriptors.
103fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
104    terminal: &SqlGlobalAggregateTerminal,
105) -> Result<AggregateExpr, QueryError> {
106    match terminal {
107        SqlGlobalAggregateTerminal::CountRows => Ok(count()),
108        SqlGlobalAggregateTerminal::CountField(field) => {
109            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
110
111            Ok(count_by(field.as_str()))
112        }
113        SqlGlobalAggregateTerminal::SumField(field) => {
114            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
115
116            Ok(sum(field.as_str()))
117        }
118        SqlGlobalAggregateTerminal::AvgField(field) => {
119            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
120
121            Ok(avg(field.as_str()))
122        }
123        SqlGlobalAggregateTerminal::MinField(field) => {
124            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
125
126            Ok(min_by(field.as_str()))
127        }
128        SqlGlobalAggregateTerminal::MaxField(field) => {
129            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
130
131            Ok(max_by(field.as_str()))
132        }
133    }
134}
135
136// Render one aggregate expression into a canonical projection column label.
137fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
138    let kind = match aggregate.kind() {
139        AggregateKind::Count => "COUNT",
140        AggregateKind::Sum => "SUM",
141        AggregateKind::Avg => "AVG",
142        AggregateKind::Exists => "EXISTS",
143        AggregateKind::First => "FIRST",
144        AggregateKind::Last => "LAST",
145        AggregateKind::Min => "MIN",
146        AggregateKind::Max => "MAX",
147    };
148    let distinct = if aggregate.is_distinct() {
149        "DISTINCT "
150    } else {
151        ""
152    };
153
154    if let Some(field) = aggregate.target_field() {
155        return format!("{kind}({distinct}{field})");
156    }
157
158    format!("{kind}({distinct}*)")
159}
160
161// Render one projection expression into a canonical output label.
162fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
163    match expr {
164        Expr::Field(field) => field.as_str().to_string(),
165        Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
166        Expr::Alias { name, .. } => name.as_str().to_string(),
167        Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
168            format!("expr_{ordinal}")
169        }
170    }
171}
172
173// Derive canonical projection column labels from one planned query projection spec.
174fn projection_labels_from_query<E: EntityKind>(
175    query: &Query<E>,
176) -> Result<Vec<String>, QueryError> {
177    let projection = query.plan()?.projection_spec();
178    let mut labels = Vec::with_capacity(projection.len());
179
180    for (ordinal, field) in projection.fields().enumerate() {
181        match field {
182            ProjectionField::Scalar {
183                expr: _,
184                alias: Some(alias),
185            } => labels.push(alias.as_str().to_string()),
186            ProjectionField::Scalar { expr, alias: None } => {
187                labels.push(projection_label_from_expr(expr, ordinal));
188            }
189        }
190    }
191
192    Ok(labels)
193}
194
195impl<C: CanisterKind> DbSession<C> {
196    /// Parse one reduced SQL statement into canonical routing metadata.
197    ///
198    /// This method is the SQL dispatch authority for entity/surface routing
199    /// outside typed-entity lowering paths.
200    pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
201        let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
202        match statement {
203            SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
204                entity: select.entity,
205            }),
206            SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
207                entity: delete.entity,
208            }),
209            SqlStatement::Explain(explain) => match explain.statement {
210                SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
211                    entity: select.entity,
212                }),
213                SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
214                    entity: delete.entity,
215                }),
216            },
217            SqlStatement::Describe(describe) => Ok(SqlStatementRoute::Describe {
218                entity: describe.entity,
219            }),
220            SqlStatement::ShowIndexes(show_indexes) => Ok(SqlStatementRoute::ShowIndexes {
221                entity: show_indexes.entity,
222            }),
223        }
224    }
225
226    /// Execute one reduced SQL `DESCRIBE` statement for entity `E`.
227    pub fn describe_sql<E>(&self, sql: &str) -> Result<EntitySchemaDescription, QueryError>
228    where
229        E: EntityKind<Canister = C>,
230    {
231        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
232            .map_err(map_sql_lowering_error)?;
233
234        match command {
235            SqlCommand::DescribeEntity => Ok(self.describe_entity::<E>()),
236            SqlCommand::Query(_)
237            | SqlCommand::Explain { .. }
238            | SqlCommand::ExplainGlobalAggregate { .. }
239            | SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
240                ErrorClass::Unsupported,
241                ErrorOrigin::Query,
242                "describe_sql requires a DESCRIBE statement",
243            ))),
244        }
245    }
246
247    /// Execute one reduced SQL `SHOW INDEXES` statement for entity `E`.
248    pub fn show_indexes_sql<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
249    where
250        E: EntityKind<Canister = C>,
251    {
252        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
253            .map_err(map_sql_lowering_error)?;
254
255        match command {
256            SqlCommand::ShowIndexesEntity => Ok(self.show_indexes::<E>()),
257            SqlCommand::Query(_)
258            | SqlCommand::Explain { .. }
259            | SqlCommand::ExplainGlobalAggregate { .. }
260            | SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
261                ErrorClass::Unsupported,
262                ErrorOrigin::Query,
263                "show_indexes_sql requires a SHOW INDEXES statement",
264            ))),
265        }
266    }
267
268    /// Build one typed query intent from one reduced SQL statement.
269    ///
270    /// This parser/lowering entrypoint is intentionally constrained to the
271    /// executable subset wired in the current release.
272    pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
273    where
274        E: EntityKind<Canister = C>,
275    {
276        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
277            .map_err(map_sql_lowering_error)?;
278
279        match command {
280            SqlCommand::Query(query) => Ok(query),
281            SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
282                Err(QueryError::execute(InternalError::classified(
283                    ErrorClass::Unsupported,
284                    ErrorOrigin::Query,
285                    "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)",
286                )))
287            }
288            SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
289                ErrorClass::Unsupported,
290                ErrorOrigin::Query,
291                "query_from_sql does not accept DESCRIBE statements; use describe_sql(...)",
292            ))),
293            SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
294                ErrorClass::Unsupported,
295                ErrorOrigin::Query,
296                "query_from_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
297            ))),
298        }
299    }
300
301    /// Derive canonical projection column labels for one reduced SQL `SELECT` statement.
302    pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
303    where
304        E: EntityKind<Canister = C>,
305    {
306        let query = self.query_from_sql::<E>(sql)?;
307        if query.has_grouping() {
308            return Err(QueryError::Intent(
309                IntentError::GroupedRequiresExecuteGrouped,
310            ));
311        }
312
313        match query.mode() {
314            QueryMode::Load(_) => projection_labels_from_query(&query),
315            QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
316                ErrorClass::Unsupported,
317                ErrorOrigin::Query,
318                "sql_projection_columns only supports SELECT statements",
319            ))),
320        }
321    }
322
323    /// Execute one reduced SQL `SELECT`/`DELETE` statement for entity `E`.
324    pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
325    where
326        E: EntityKind<Canister = C> + EntityValue,
327    {
328        let query = self.query_from_sql::<E>(sql)?;
329        if query.has_grouping() {
330            return Err(QueryError::Intent(
331                IntentError::GroupedRequiresExecuteGrouped,
332            ));
333        }
334
335        self.execute_query(&query)
336    }
337
338    /// Execute one reduced SQL `SELECT` statement and return projection-shaped rows.
339    ///
340    /// This surface keeps `execute_sql(...)` stable for callers
341    /// that consume full entity rows.
342    pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
343    where
344        E: EntityKind<Canister = C> + EntityValue,
345    {
346        let query = self.query_from_sql::<E>(sql)?;
347        if query.has_grouping() {
348            return Err(QueryError::Intent(
349                IntentError::GroupedRequiresExecuteGrouped,
350            ));
351        }
352
353        match query.mode() {
354            QueryMode::Load(_) => {
355                self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
356            }
357            QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
358                ErrorClass::Unsupported,
359                ErrorOrigin::Query,
360                "execute_sql_projection only supports SELECT statements",
361            ))),
362        }
363    }
364
365    /// Execute one reduced SQL global aggregate `SELECT` statement.
366    ///
367    /// This entrypoint is intentionally constrained to one aggregate terminal
368    /// shape per statement and preserves existing terminal semantics.
369    pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
370    where
371        E: EntityKind<Canister = C> + EntityValue,
372    {
373        let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
374            .map_err(map_sql_lowering_error)?;
375
376        match command.terminal() {
377            SqlGlobalAggregateTerminal::CountRows => self
378                .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
379                .map(|count| Value::Uint(u64::from(count))),
380            SqlGlobalAggregateTerminal::CountField(field) => {
381                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
382                self.execute_load_query_with(command.query(), |load, plan| {
383                    load.values_by_slot(plan, target_slot)
384                })
385                .map(|values| {
386                    let count = values
387                        .into_iter()
388                        .filter(|value| !matches!(value, Value::Null))
389                        .count();
390                    Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
391                })
392            }
393            SqlGlobalAggregateTerminal::SumField(field) => {
394                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
395                self.execute_load_query_with(command.query(), |load, plan| {
396                    load.aggregate_sum_by_slot(plan, target_slot)
397                })
398                .map(|value| value.map_or(Value::Null, Value::Decimal))
399            }
400            SqlGlobalAggregateTerminal::AvgField(field) => {
401                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
402                self.execute_load_query_with(command.query(), |load, plan| {
403                    load.aggregate_avg_by_slot(plan, target_slot)
404                })
405                .map(|value| value.map_or(Value::Null, Value::Decimal))
406            }
407            SqlGlobalAggregateTerminal::MinField(field) => {
408                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
409                let min_id = self.execute_load_query_with(command.query(), |load, plan| {
410                    load.aggregate_min_by_slot(plan, target_slot)
411                })?;
412
413                match min_id {
414                    Some(id) => self
415                        .load::<E>()
416                        .by_id(id)
417                        .first_value_by(field)
418                        .map(|value| value.unwrap_or(Value::Null)),
419                    None => Ok(Value::Null),
420                }
421            }
422            SqlGlobalAggregateTerminal::MaxField(field) => {
423                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
424                let max_id = self.execute_load_query_with(command.query(), |load, plan| {
425                    load.aggregate_max_by_slot(plan, target_slot)
426                })?;
427
428                match max_id {
429                    Some(id) => self
430                        .load::<E>()
431                        .by_id(id)
432                        .first_value_by(field)
433                        .map(|value| value.unwrap_or(Value::Null)),
434                    None => Ok(Value::Null),
435                }
436            }
437        }
438    }
439
440    /// Execute one reduced SQL grouped `SELECT` statement and return grouped rows.
441    pub fn execute_sql_grouped<E>(
442        &self,
443        sql: &str,
444        cursor_token: Option<&str>,
445    ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
446    where
447        E: EntityKind<Canister = C> + EntityValue,
448    {
449        let query = self.query_from_sql::<E>(sql)?;
450        if !query.has_grouping() {
451            return Err(QueryError::execute(InternalError::classified(
452                ErrorClass::Unsupported,
453                ErrorOrigin::Query,
454                "execute_sql_grouped requires grouped SQL query intent",
455            )));
456        }
457
458        self.execute_grouped(&query, cursor_token)
459    }
460
461    /// Explain one reduced SQL statement for entity `E`.
462    ///
463    /// Supported modes:
464    /// - `EXPLAIN ...` -> logical plan text
465    /// - `EXPLAIN EXECUTION ...` -> execution descriptor text
466    /// - `EXPLAIN JSON ...` -> logical plan canonical JSON
467    pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
468    where
469        E: EntityKind<Canister = C> + EntityValue,
470    {
471        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
472            .map_err(map_sql_lowering_error)?;
473
474        match command {
475            SqlCommand::Query(_) => Err(QueryError::execute(InternalError::classified(
476                ErrorClass::Unsupported,
477                ErrorOrigin::Query,
478                "explain_sql requires an EXPLAIN statement",
479            ))),
480            SqlCommand::DescribeEntity => Err(QueryError::execute(InternalError::classified(
481                ErrorClass::Unsupported,
482                ErrorOrigin::Query,
483                "explain_sql does not accept DESCRIBE statements; use describe_sql(...)",
484            ))),
485            SqlCommand::ShowIndexesEntity => Err(QueryError::execute(InternalError::classified(
486                ErrorClass::Unsupported,
487                ErrorOrigin::Query,
488                "explain_sql does not accept SHOW INDEXES statements; use show_indexes_sql(...)",
489            ))),
490            SqlCommand::Explain { mode, query } => match mode {
491                SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
492                SqlExplainMode::Execution => query.explain_execution_text(),
493                SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
494            },
495            SqlCommand::ExplainGlobalAggregate { mode, command } => {
496                Self::explain_sql_global_aggregate::<E>(mode, command)
497            }
498        }
499    }
500
501    // Render one EXPLAIN payload for constrained global aggregate SQL command.
502    fn explain_sql_global_aggregate<E>(
503        mode: SqlExplainMode,
504        command: SqlGlobalAggregateCommand<E>,
505    ) -> Result<String, QueryError>
506    where
507        E: EntityKind<Canister = C> + EntityValue,
508    {
509        match mode {
510            SqlExplainMode::Plan => {
511                // Keep explain validation parity with execution by requiring the
512                // target field to resolve before returning explain output.
513                let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
514
515                Ok(command.query().explain()?.render_text_canonical())
516            }
517            SqlExplainMode::Execution => {
518                let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
519                let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
520
521                Ok(plan.execution_node_descriptor().render_text_tree())
522            }
523            SqlExplainMode::Json => {
524                // Keep explain validation parity with execution by requiring the
525                // target field to resolve before returning explain output.
526                let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
527
528                Ok(command.query().explain()?.render_json_canonical())
529            }
530        }
531    }
532}