Skip to main content

icydb_core/db/session/
sql.rs

1use crate::{
2    db::{
3        DbSession, EntityResponse, MissingRowPolicy, PagedGroupedExecutionWithTrace,
4        ProjectionResponse, Query, QueryError,
5        query::{
6            builder::aggregate::{AggregateExpr, avg, count, count_by, max_by, min_by, sum},
7            intent::IntentError,
8            plan::{
9                AggregateKind, FieldSlot, QueryMode,
10                expr::{Expr, ProjectionField},
11            },
12        },
13        sql::lowering::{
14            SqlCommand, SqlGlobalAggregateCommand, SqlGlobalAggregateTerminal, SqlLoweringError,
15            compile_sql_command, compile_sql_global_aggregate_command,
16        },
17        sql::parser::{SqlExplainMode, SqlExplainTarget, SqlStatement, parse_sql},
18    },
19    error::{ErrorClass, ErrorOrigin, InternalError},
20    traits::{CanisterKind, EntityKind, EntityValue},
21    value::Value,
22};
23
///
/// SqlStatementRoute
///
/// Routing metadata for one reduced SQL statement: which surface it targets
/// (plain query vs `EXPLAIN`) and the entity identifier exactly as parsed.
///
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SqlStatementRoute {
    /// A directly executable statement (`SELECT` / `DELETE`).
    Query { entity: String },
    /// An `EXPLAIN` statement wrapping a `SELECT` / `DELETE`.
    Explain { entity: String },
}

impl SqlStatementRoute {
    /// Parsed entity identifier, regardless of which surface is targeted.
    #[must_use]
    pub const fn entity(&self) -> &str {
        let (Self::Query { entity } | Self::Explain { entity }) = self;
        entity.as_str()
    }

    /// `true` only for the `EXPLAIN` surface.
    #[must_use]
    pub const fn is_explain(&self) -> bool {
        match self {
            Self::Explain { .. } => true,
            Self::Query { .. } => false,
        }
    }
}
51
52// Map SQL frontend parse/lowering failures into query-facing execution errors.
53fn map_sql_lowering_error(err: SqlLoweringError) -> QueryError {
54    match err {
55        SqlLoweringError::Query(err) => err,
56        SqlLoweringError::Parse(crate::db::sql::parser::SqlParseError::UnsupportedFeature {
57            feature,
58        }) => QueryError::execute(InternalError::query_unsupported_sql_feature(feature)),
59        other => QueryError::execute(InternalError::classified(
60            ErrorClass::Unsupported,
61            ErrorOrigin::Query,
62            format!("SQL query is not executable in this release: {other}"),
63        )),
64    }
65}
66
67// Map reduced SQL parse failures through the same query-facing classification
68// policy used by SQL lowering entrypoints.
69fn map_sql_parse_error(err: crate::db::sql::parser::SqlParseError) -> QueryError {
70    map_sql_lowering_error(SqlLoweringError::Parse(err))
71}
72
73// Resolve one aggregate target field through planner slot contracts before
74// aggregate terminal execution.
75fn resolve_sql_aggregate_target_slot<E: EntityKind>(field: &str) -> Result<FieldSlot, QueryError> {
76    FieldSlot::resolve(E::MODEL, field).ok_or_else(|| {
77        QueryError::execute(crate::db::error::executor_unsupported(format!(
78            "unknown aggregate target field: {field}",
79        )))
80    })
81}
82
83// Convert one lowered global SQL aggregate terminal into aggregate expression
84// contracts used by aggregate explain execution descriptors.
85fn sql_global_aggregate_terminal_to_expr<E: EntityKind>(
86    terminal: &SqlGlobalAggregateTerminal,
87) -> Result<AggregateExpr, QueryError> {
88    match terminal {
89        SqlGlobalAggregateTerminal::CountRows => Ok(count()),
90        SqlGlobalAggregateTerminal::CountField(field) => {
91            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
92
93            Ok(count_by(field.as_str()))
94        }
95        SqlGlobalAggregateTerminal::SumField(field) => {
96            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
97
98            Ok(sum(field.as_str()))
99        }
100        SqlGlobalAggregateTerminal::AvgField(field) => {
101            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
102
103            Ok(avg(field.as_str()))
104        }
105        SqlGlobalAggregateTerminal::MinField(field) => {
106            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
107
108            Ok(min_by(field.as_str()))
109        }
110        SqlGlobalAggregateTerminal::MaxField(field) => {
111            let _ = resolve_sql_aggregate_target_slot::<E>(field)?;
112
113            Ok(max_by(field.as_str()))
114        }
115    }
116}
117
118// Render one aggregate expression into a canonical projection column label.
119fn projection_label_from_aggregate(aggregate: &AggregateExpr) -> String {
120    let kind = match aggregate.kind() {
121        AggregateKind::Count => "COUNT",
122        AggregateKind::Sum => "SUM",
123        AggregateKind::Avg => "AVG",
124        AggregateKind::Exists => "EXISTS",
125        AggregateKind::First => "FIRST",
126        AggregateKind::Last => "LAST",
127        AggregateKind::Min => "MIN",
128        AggregateKind::Max => "MAX",
129    };
130    let distinct = if aggregate.is_distinct() {
131        "DISTINCT "
132    } else {
133        ""
134    };
135
136    if let Some(field) = aggregate.target_field() {
137        return format!("{kind}({distinct}{field})");
138    }
139
140    format!("{kind}({distinct}*)")
141}
142
143// Render one projection expression into a canonical output label.
144fn projection_label_from_expr(expr: &Expr, ordinal: usize) -> String {
145    match expr {
146        Expr::Field(field) => field.as_str().to_string(),
147        Expr::Aggregate(aggregate) => projection_label_from_aggregate(aggregate),
148        Expr::Alias { name, .. } => name.as_str().to_string(),
149        Expr::Literal(_) | Expr::Unary { .. } | Expr::Binary { .. } => {
150            format!("expr_{ordinal}")
151        }
152    }
153}
154
155// Derive canonical projection column labels from one planned query projection spec.
156fn projection_labels_from_query<E: EntityKind>(
157    query: &Query<E>,
158) -> Result<Vec<String>, QueryError> {
159    let projection = query.plan()?.projection_spec();
160    let mut labels = Vec::with_capacity(projection.len());
161
162    for (ordinal, field) in projection.fields().enumerate() {
163        match field {
164            ProjectionField::Scalar {
165                expr: _,
166                alias: Some(alias),
167            } => labels.push(alias.as_str().to_string()),
168            ProjectionField::Scalar { expr, alias: None } => {
169                labels.push(projection_label_from_expr(expr, ordinal));
170            }
171        }
172    }
173
174    Ok(labels)
175}
176
177impl<C: CanisterKind> DbSession<C> {
178    /// Parse one reduced SQL statement into canonical routing metadata.
179    ///
180    /// This method is the SQL dispatch authority for entity/surface routing
181    /// outside typed-entity lowering paths.
182    pub fn sql_statement_route(&self, sql: &str) -> Result<SqlStatementRoute, QueryError> {
183        let statement = parse_sql(sql).map_err(map_sql_parse_error)?;
184        match statement {
185            SqlStatement::Select(select) => Ok(SqlStatementRoute::Query {
186                entity: select.entity,
187            }),
188            SqlStatement::Delete(delete) => Ok(SqlStatementRoute::Query {
189                entity: delete.entity,
190            }),
191            SqlStatement::Explain(explain) => match explain.statement {
192                SqlExplainTarget::Select(select) => Ok(SqlStatementRoute::Explain {
193                    entity: select.entity,
194                }),
195                SqlExplainTarget::Delete(delete) => Ok(SqlStatementRoute::Explain {
196                    entity: delete.entity,
197                }),
198            },
199        }
200    }
201
202    /// Build one typed query intent from one reduced SQL statement.
203    ///
204    /// This parser/lowering entrypoint is intentionally constrained to the
205    /// executable subset wired in the current release.
206    pub fn query_from_sql<E>(&self, sql: &str) -> Result<Query<E>, QueryError>
207    where
208        E: EntityKind<Canister = C>,
209    {
210        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
211            .map_err(map_sql_lowering_error)?;
212
213        match command {
214            SqlCommand::Query(query) => Ok(query),
215            SqlCommand::Explain { .. } | SqlCommand::ExplainGlobalAggregate { .. } => {
216                Err(QueryError::execute(InternalError::classified(
217                    ErrorClass::Unsupported,
218                    ErrorOrigin::Query,
219                    "query_from_sql does not accept EXPLAIN statements; use explain_sql(...)",
220                )))
221            }
222        }
223    }
224
225    /// Derive canonical projection column labels for one reduced SQL `SELECT` statement.
226    pub fn sql_projection_columns<E>(&self, sql: &str) -> Result<Vec<String>, QueryError>
227    where
228        E: EntityKind<Canister = C>,
229    {
230        let query = self.query_from_sql::<E>(sql)?;
231        if query.has_grouping() {
232            return Err(QueryError::Intent(
233                IntentError::GroupedRequiresExecuteGrouped,
234            ));
235        }
236
237        match query.mode() {
238            QueryMode::Load(_) => projection_labels_from_query(&query),
239            QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
240                ErrorClass::Unsupported,
241                ErrorOrigin::Query,
242                "sql_projection_columns only supports SELECT statements",
243            ))),
244        }
245    }
246
247    /// Execute one reduced SQL `SELECT`/`DELETE` statement for entity `E`.
248    pub fn execute_sql<E>(&self, sql: &str) -> Result<EntityResponse<E>, QueryError>
249    where
250        E: EntityKind<Canister = C> + EntityValue,
251    {
252        let query = self.query_from_sql::<E>(sql)?;
253        if query.has_grouping() {
254            return Err(QueryError::Intent(
255                IntentError::GroupedRequiresExecuteGrouped,
256            ));
257        }
258
259        self.execute_query(&query)
260    }
261
262    /// Execute one reduced SQL `SELECT` statement and return projection-shaped rows.
263    ///
264    /// This surface keeps `execute_sql(...)` backwards-compatible for callers
265    /// that currently consume full entity rows.
266    pub fn execute_sql_projection<E>(&self, sql: &str) -> Result<ProjectionResponse<E>, QueryError>
267    where
268        E: EntityKind<Canister = C> + EntityValue,
269    {
270        let query = self.query_from_sql::<E>(sql)?;
271        if query.has_grouping() {
272            return Err(QueryError::Intent(
273                IntentError::GroupedRequiresExecuteGrouped,
274            ));
275        }
276
277        match query.mode() {
278            QueryMode::Load(_) => {
279                self.execute_load_query_with(&query, |load, plan| load.execute_projection(plan))
280            }
281            QueryMode::Delete(_) => Err(QueryError::execute(InternalError::classified(
282                ErrorClass::Unsupported,
283                ErrorOrigin::Query,
284                "execute_sql_projection only supports SELECT statements",
285            ))),
286        }
287    }
288
289    /// Execute one reduced SQL global aggregate `SELECT` statement.
290    ///
291    /// This entrypoint is intentionally constrained to one aggregate terminal
292    /// shape per statement and preserves existing terminal semantics.
293    pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
294    where
295        E: EntityKind<Canister = C> + EntityValue,
296    {
297        let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
298            .map_err(map_sql_lowering_error)?;
299
300        match command.terminal() {
301            SqlGlobalAggregateTerminal::CountRows => self
302                .execute_load_query_with(command.query(), |load, plan| load.aggregate_count(plan))
303                .map(|count| Value::Uint(u64::from(count))),
304            SqlGlobalAggregateTerminal::CountField(field) => {
305                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
306                self.execute_load_query_with(command.query(), |load, plan| {
307                    load.values_by_slot(plan, target_slot)
308                })
309                .map(|values| {
310                    let count = values
311                        .into_iter()
312                        .filter(|value| !matches!(value, Value::Null))
313                        .count();
314                    Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
315                })
316            }
317            SqlGlobalAggregateTerminal::SumField(field) => {
318                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
319                self.execute_load_query_with(command.query(), |load, plan| {
320                    load.aggregate_sum_by_slot(plan, target_slot)
321                })
322                .map(|value| value.map_or(Value::Null, Value::Decimal))
323            }
324            SqlGlobalAggregateTerminal::AvgField(field) => {
325                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
326                self.execute_load_query_with(command.query(), |load, plan| {
327                    load.aggregate_avg_by_slot(plan, target_slot)
328                })
329                .map(|value| value.map_or(Value::Null, Value::Decimal))
330            }
331            SqlGlobalAggregateTerminal::MinField(field) => {
332                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
333                let min_id = self.execute_load_query_with(command.query(), |load, plan| {
334                    load.aggregate_min_by_slot(plan, target_slot)
335                })?;
336
337                match min_id {
338                    Some(id) => self
339                        .load::<E>()
340                        .by_id(id)
341                        .first_value_by(field)
342                        .map(|value| value.unwrap_or(Value::Null)),
343                    None => Ok(Value::Null),
344                }
345            }
346            SqlGlobalAggregateTerminal::MaxField(field) => {
347                let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
348                let max_id = self.execute_load_query_with(command.query(), |load, plan| {
349                    load.aggregate_max_by_slot(plan, target_slot)
350                })?;
351
352                match max_id {
353                    Some(id) => self
354                        .load::<E>()
355                        .by_id(id)
356                        .first_value_by(field)
357                        .map(|value| value.unwrap_or(Value::Null)),
358                    None => Ok(Value::Null),
359                }
360            }
361        }
362    }
363
364    /// Execute one reduced SQL grouped `SELECT` statement and return grouped rows.
365    pub fn execute_sql_grouped<E>(
366        &self,
367        sql: &str,
368        cursor_token: Option<&str>,
369    ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
370    where
371        E: EntityKind<Canister = C> + EntityValue,
372    {
373        let query = self.query_from_sql::<E>(sql)?;
374        if !query.has_grouping() {
375            return Err(QueryError::execute(InternalError::classified(
376                ErrorClass::Unsupported,
377                ErrorOrigin::Query,
378                "execute_sql_grouped requires grouped SQL query intent",
379            )));
380        }
381
382        self.execute_grouped(&query, cursor_token)
383    }
384
385    /// Explain one reduced SQL statement for entity `E`.
386    ///
387    /// Supported modes:
388    /// - `EXPLAIN ...` -> logical plan text
389    /// - `EXPLAIN EXECUTION ...` -> execution descriptor text
390    /// - `EXPLAIN JSON ...` -> logical plan canonical JSON
391    pub fn explain_sql<E>(&self, sql: &str) -> Result<String, QueryError>
392    where
393        E: EntityKind<Canister = C> + EntityValue,
394    {
395        let command = compile_sql_command::<E>(sql, MissingRowPolicy::Ignore)
396            .map_err(map_sql_lowering_error)?;
397
398        match command {
399            SqlCommand::Query(_) => Err(QueryError::execute(InternalError::classified(
400                ErrorClass::Unsupported,
401                ErrorOrigin::Query,
402                "explain_sql requires an EXPLAIN statement",
403            ))),
404            SqlCommand::Explain { mode, query } => match mode {
405                SqlExplainMode::Plan => Ok(query.explain()?.render_text_canonical()),
406                SqlExplainMode::Execution => query.explain_execution_text(),
407                SqlExplainMode::Json => Ok(query.explain()?.render_json_canonical()),
408            },
409            SqlCommand::ExplainGlobalAggregate { mode, command } => {
410                Self::explain_sql_global_aggregate::<E>(mode, command)
411            }
412        }
413    }
414
415    // Render one EXPLAIN payload for constrained global aggregate SQL command.
416    fn explain_sql_global_aggregate<E>(
417        mode: SqlExplainMode,
418        command: SqlGlobalAggregateCommand<E>,
419    ) -> Result<String, QueryError>
420    where
421        E: EntityKind<Canister = C> + EntityValue,
422    {
423        match mode {
424            SqlExplainMode::Plan => {
425                // Keep explain validation parity with execution by requiring the
426                // target field to resolve before returning explain output.
427                let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
428
429                Ok(command.query().explain()?.render_text_canonical())
430            }
431            SqlExplainMode::Execution => {
432                let aggregate = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
433                let plan = Self::explain_load_query_terminal_with(command.query(), aggregate)?;
434
435                Ok(plan.execution_node_descriptor().render_text_tree())
436            }
437            SqlExplainMode::Json => {
438                // Keep explain validation parity with execution by requiring the
439                // target field to resolve before returning explain output.
440                let _ = sql_global_aggregate_terminal_to_expr::<E>(command.terminal())?;
441
442                Ok(command.query().explain()?.render_json_canonical())
443            }
444        }
445    }
446}