spacetimedb_expr/
statement.rs

1use std::sync::Arc;
2
3use spacetimedb_lib::{identity::AuthCtx, st_var::StVarValue, AlgebraicType, AlgebraicValue, ProductValue};
4use spacetimedb_primitives::{ColId, TableId};
5use spacetimedb_schema::schema::{ColumnSchema, TableSchema};
6use spacetimedb_sql_parser::{
7    ast::{
8        sql::{SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlShow, SqlUpdate},
9        BinOp, SqlIdent, SqlLiteral,
10    },
11    parser::sql::parse_sql,
12};
13use thiserror::Error;
14
15use crate::{
16    check::Relvars,
17    errors::InvalidLiteral,
18    expr::{FieldProject, ProjectList, RelExpr, Relvar},
19    type_limit,
20};
21
22use super::{
23    check::{SchemaView, TypeChecker, TypingResult},
24    errors::{InsertFieldsError, InsertValuesError, TypingError, UnexpectedType, Unresolved},
25    expr::Expr,
26    parse, type_expr, type_proj, type_select, StatementCtx, StatementSource,
27};
28
29pub enum Statement {
30    Select(ProjectList),
31    DML(DML),
32}
33
34pub enum DML {
35    Insert(TableInsert),
36    Update(TableUpdate),
37    Delete(TableDelete),
38}
39
40impl DML {
41    /// Returns the schema of the table on which this mutation applies
42    pub fn table_schema(&self) -> &TableSchema {
43        match self {
44            Self::Insert(insert) => &insert.table,
45            Self::Delete(delete) => &delete.table,
46            Self::Update(update) => &update.table,
47        }
48    }
49
50    /// Returns the id of the table on which this mutation applies
51    pub fn table_id(&self) -> TableId {
52        self.table_schema().table_id
53    }
54
55    /// Returns the name of the table on which this mutation applies
56    pub fn table_name(&self) -> Box<str> {
57        self.table_schema().table_name.clone()
58    }
59}
60
61pub struct TableInsert {
62    pub table: Arc<TableSchema>,
63    pub rows: Box<[ProductValue]>,
64}
65
66pub struct TableDelete {
67    pub table: Arc<TableSchema>,
68    pub filter: Option<Expr>,
69}
70
71pub struct TableUpdate {
72    pub table: Arc<TableSchema>,
73    pub columns: Box<[(ColId, AlgebraicValue)]>,
74    pub filter: Option<Expr>,
75}
76
77pub struct SetVar {
78    pub name: String,
79    pub value: AlgebraicValue,
80}
81
82pub struct ShowVar {
83    pub name: String,
84}
85
86/// Type check an INSERT statement
87pub fn type_insert(insert: SqlInsert, tx: &impl SchemaView) -> TypingResult<TableInsert> {
88    let SqlInsert {
89        table: SqlIdent(table_name),
90        fields,
91        values,
92    } = insert;
93
94    let schema = tx
95        .schema(&table_name)
96        .ok_or_else(|| Unresolved::table(&table_name))
97        .map_err(TypingError::from)?;
98
99    // Expect n fields
100    let n = schema.columns().len();
101    if fields.len() != schema.columns().len() {
102        return Err(TypingError::from(InsertFieldsError {
103            table: table_name.into_string(),
104            nfields: fields.len(),
105            ncols: schema.columns().len(),
106        }));
107    }
108
109    let mut rows = Vec::new();
110    for row in values.0 {
111        // Expect each row to have n values
112        if row.len() != n {
113            return Err(TypingError::from(InsertValuesError {
114                table: table_name.into_string(),
115                values: row.len(),
116                fields: n,
117            }));
118        }
119        let mut values = Vec::new();
120        for (value, ty) in row.into_iter().zip(
121            schema
122                .as_ref()
123                .columns()
124                .iter()
125                .map(|ColumnSchema { col_type, .. }| col_type),
126        ) {
127            match (value, ty) {
128                (SqlLiteral::Bool(v), AlgebraicType::Bool) => {
129                    values.push(AlgebraicValue::Bool(v));
130                }
131                (SqlLiteral::Str(v), AlgebraicType::String) => {
132                    values.push(AlgebraicValue::String(v));
133                }
134                (SqlLiteral::Bool(_), _) => {
135                    return Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into());
136                }
137                (SqlLiteral::Str(_), _) => {
138                    return Err(UnexpectedType::new(&AlgebraicType::String, ty).into());
139                }
140                (SqlLiteral::Hex(v), ty) | (SqlLiteral::Num(v), ty) => {
141                    values.push(parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?);
142                }
143            }
144        }
145        rows.push(ProductValue::from(values));
146    }
147    let into = schema;
148    let rows = rows.into_boxed_slice();
149    Ok(TableInsert { table: into, rows })
150}
151
152/// Type check a DELETE statement
153pub fn type_delete(delete: SqlDelete, tx: &impl SchemaView) -> TypingResult<TableDelete> {
154    let SqlDelete {
155        table: SqlIdent(table_name),
156        filter,
157    } = delete;
158    let from = tx
159        .schema(&table_name)
160        .ok_or_else(|| Unresolved::table(&table_name))
161        .map_err(TypingError::from)?;
162    let mut vars = Relvars::default();
163    vars.insert(table_name.clone(), from.clone());
164    let expr = filter
165        .map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool)))
166        .transpose()?;
167    Ok(TableDelete {
168        table: from,
169        filter: expr,
170    })
171}
172
173/// Type check an UPDATE statement
174pub fn type_update(update: SqlUpdate, tx: &impl SchemaView) -> TypingResult<TableUpdate> {
175    let SqlUpdate {
176        table: SqlIdent(table_name),
177        assignments,
178        filter,
179    } = update;
180    let schema = tx
181        .schema(&table_name)
182        .ok_or_else(|| Unresolved::table(&table_name))
183        .map_err(TypingError::from)?;
184    let mut values = Vec::new();
185    for SqlSet(SqlIdent(field), lit) in assignments {
186        let ColumnSchema {
187            col_pos: col_id,
188            col_type: ty,
189            ..
190        } = schema
191            .as_ref()
192            .get_column_by_name(&field)
193            .ok_or_else(|| Unresolved::field(&table_name, &field))?;
194        match (lit, ty) {
195            (SqlLiteral::Bool(v), AlgebraicType::Bool) => {
196                values.push((*col_id, AlgebraicValue::Bool(v)));
197            }
198            (SqlLiteral::Str(v), AlgebraicType::String) => {
199                values.push((*col_id, AlgebraicValue::String(v)));
200            }
201            (SqlLiteral::Bool(_), _) => {
202                return Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into());
203            }
204            (SqlLiteral::Str(_), _) => {
205                return Err(UnexpectedType::new(&AlgebraicType::String, ty).into());
206            }
207            (SqlLiteral::Hex(v), ty) | (SqlLiteral::Num(v), ty) => {
208                values.push((
209                    *col_id,
210                    parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
211                ));
212            }
213        }
214    }
215    let mut vars = Relvars::default();
216    vars.insert(table_name.clone(), schema.clone());
217    let values = values.into_boxed_slice();
218    let filter = filter
219        .map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool)))
220        .transpose()?;
221    Ok(TableUpdate {
222        table: schema,
223        columns: values,
224        filter,
225    })
226}
227
228#[derive(Error, Debug)]
229#[error("{name} is not a valid system variable")]
230pub struct InvalidVar {
231    pub name: String,
232}
233
// Names of the system variables that `SET` may write and `SHOW` may read.
const VAR_ROW_LIMIT: &str = "row_limit";
const VAR_SLOW_QUERY: &str = "slow_ad_hoc_query_ms";
const VAR_SLOW_UPDATE: &str = "slow_tx_update_ms";
const VAR_SLOW_SUB: &str = "slow_subscription_query_ms";

/// Returns `true` if `var` is exactly (case-sensitively) one of the supported system variables.
fn is_var_valid(var: &str) -> bool {
    matches!(var, VAR_ROW_LIMIT | VAR_SLOW_QUERY | VAR_SLOW_UPDATE | VAR_SLOW_SUB)
}

/// Name of the system table that `SET` and `SHOW` are rewritten against.
const ST_VAR_NAME: &str = "st_var";
/// Name of the `st_var` column holding a variable's value.
const VALUE_COLUMN: &str = "value";
245
246/// The concept of `SET` only exists in the ast.
247/// We translate it here to an `INSERT` on the `st_var` system table.
248/// That is:
249///
250/// ```sql
251/// SET var TO ...
252/// ```
253///
254/// is rewritten as
255///
256/// ```sql
257/// INSERT INTO st_var (name, value) VALUES ('var', ...)
258/// ```
259pub fn type_and_rewrite_set(set: SqlSet, tx: &impl SchemaView) -> TypingResult<TableInsert> {
260    let SqlSet(SqlIdent(var_name), lit) = set;
261    if !is_var_valid(&var_name) {
262        return Err(InvalidVar {
263            name: var_name.into_string(),
264        }
265        .into());
266    }
267
268    match lit {
269        SqlLiteral::Bool(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::Bool).into()),
270        SqlLiteral::Str(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::String).into()),
271        SqlLiteral::Hex(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::bytes()).into()),
272        SqlLiteral::Num(n) => {
273            let table = tx.schema(ST_VAR_NAME).ok_or_else(|| Unresolved::table(ST_VAR_NAME))?;
274            let var_name = AlgebraicValue::String(var_name);
275            let sum_value = StVarValue::try_from_primitive(
276                parse(&n, &AlgebraicType::U64)
277                    .map_err(|_| InvalidLiteral::new(n.clone().into_string(), &AlgebraicType::U64))?,
278            )
279            .map_err(|_| InvalidLiteral::new(n.into_string(), &AlgebraicType::U64))?
280            .into();
281            Ok(TableInsert {
282                table,
283                rows: Box::new([ProductValue::from_iter([var_name, sum_value])]),
284            })
285        }
286    }
287}
288
289/// The concept of `SHOW` only exists in the ast.
290/// We translate it here to a `SELECT` on the `st_var` system table.
291/// That is:
292///
293/// ```sql
294/// SHOW var
295/// ```
296///
297/// is rewritten as
298///
299/// ```sql
300/// SELECT value FROM st_var WHERE name = 'var'
301/// ```
302pub fn type_and_rewrite_show(show: SqlShow, tx: &impl SchemaView) -> TypingResult<ProjectList> {
303    let SqlShow(SqlIdent(var_name)) = show;
304    if !is_var_valid(&var_name) {
305        return Err(InvalidVar {
306            name: var_name.into_string(),
307        }
308        .into());
309    }
310
311    let table_schema = tx.schema(ST_VAR_NAME).ok_or_else(|| Unresolved::table(ST_VAR_NAME))?;
312
313    let value_col_ty = table_schema
314        .as_ref()
315        .get_column(1)
316        .map(|ColumnSchema { col_type, .. }| col_type)
317        .ok_or_else(|| Unresolved::field(ST_VAR_NAME, VALUE_COLUMN))?;
318
319    // -------------------------------------------
320    // SELECT value FROM st_var WHERE name = 'var'
321    //                                ^^^^
322    // -------------------------------------------
323    let var_name_field = Expr::Field(FieldProject {
324        table: ST_VAR_NAME.into(),
325        // TODO: Avoid hard coding the field position.
326        // See `StVarFields` for the schema of `st_var`.
327        field: 0,
328        ty: AlgebraicType::String,
329    });
330
331    // -------------------------------------------
332    // SELECT value FROM st_var WHERE name = 'var'
333    //                                        ^^^
334    // -------------------------------------------
335    let var_name_value = Expr::Value(AlgebraicValue::String(var_name), AlgebraicType::String);
336
337    // -------------------------------------------
338    // SELECT value FROM st_var WHERE name = 'var'
339    //        ^^^^^
340    // -------------------------------------------
341    let column_list = vec![(
342        VALUE_COLUMN.into(),
343        FieldProject {
344            table: ST_VAR_NAME.into(),
345            // TODO: Avoid hard coding the field position.
346            // See `StVarFields` for the schema of `st_var`.
347            field: 1,
348            ty: value_col_ty.clone(),
349        },
350    )];
351
352    // -------------------------------------------
353    // SELECT value FROM st_var WHERE name = 'var'
354    //                   ^^^^^^
355    // -------------------------------------------
356    let relvar = RelExpr::RelVar(Relvar {
357        schema: table_schema,
358        alias: ST_VAR_NAME.into(),
359        delta: None,
360    });
361
362    let filter = Expr::BinOp(
363        // -------------------------------------------
364        // SELECT value FROM st_var WHERE name = 'var'
365        //                                    ^^^
366        // -------------------------------------------
367        BinOp::Eq,
368        Box::new(var_name_field),
369        Box::new(var_name_value),
370    );
371
372    Ok(ProjectList::List(
373        vec![RelExpr::Select(Box::new(relvar), filter)],
374        column_list,
375    ))
376}
377
378/// Type-checker for regular `SQL` queries
379struct SqlChecker;
380
381impl TypeChecker for SqlChecker {
382    type Ast = SqlSelect;
383    type Set = SqlSelect;
384
385    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
386        Self::type_set(ast, &mut Relvars::default(), tx)
387    }
388
389    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
390        match ast {
391            SqlSelect {
392                project,
393                from,
394                filter: None,
395                limit: None,
396            } => type_proj(Self::type_from(from, vars, tx)?, project, vars),
397            SqlSelect {
398                project,
399                from,
400                filter: None,
401                limit: Some(n),
402            } => type_limit(type_proj(Self::type_from(from, vars, tx)?, project, vars)?, &n),
403            SqlSelect {
404                project,
405                from,
406                filter: Some(expr),
407                limit: None,
408            } => type_proj(
409                type_select(Self::type_from(from, vars, tx)?, expr, vars)?,
410                project,
411                vars,
412            ),
413            SqlSelect {
414                project,
415                from,
416                filter: Some(expr),
417                limit: Some(n),
418            } => type_limit(
419                type_proj(
420                    type_select(Self::type_from(from, vars, tx)?, expr, vars)?,
421                    project,
422                    vars,
423                )?,
424                &n,
425            ),
426        }
427    }
428}
429
430pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<Statement> {
431    match parse_sql(sql)?.resolve_sender(auth.caller) {
432        SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)),
433        SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))),
434        SqlAst::Delete(delete) => Ok(Statement::DML(DML::Delete(type_delete(delete, tx)?))),
435        SqlAst::Update(update) => Ok(Statement::DML(DML::Update(type_update(update, tx)?))),
436        SqlAst::Set(set) => Ok(Statement::DML(DML::Insert(type_and_rewrite_set(set, tx)?))),
437        SqlAst::Show(show) => Ok(Statement::Select(type_and_rewrite_show(show, tx)?)),
438    }
439}
440
441/// Parse and type check a *general* query into a [StatementCtx].
442pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<StatementCtx<'a>> {
443    let statement = parse_and_type_sql(sql, tx, auth)?;
444    Ok(StatementCtx {
445        statement,
446        sql,
447        source: StatementSource::Query,
448    })
449}
450
#[cfg(test)]
mod tests {
    use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType};
    use spacetimedb_schema::def::ModuleDef;

    use crate::check::{
        test_utils::{build_module_def, SchemaViewer},
        SchemaView, TypingResult,
    };

    use super::Statement;

    /// Two tables, `t` and `s`, that share the column names `u32` and `arr`, so unqualified
    /// references to those names in a join are ambiguous.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "t",
                ProductType::from([
                    ("u32", AlgebraicType::U32),
                    ("f32", AlgebraicType::F32),
                    ("str", AlgebraicType::String),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                ]),
            ),
            (
                "s",
                ProductType::from([
                    ("id", AlgebraicType::identity()),
                    ("u32", AlgebraicType::U32),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                    ("bytes", AlgebraicType::bytes()),
                ]),
            ),
        ])
    }

    /// A wrapper around [super::parse_and_type_sql] that takes a dummy [AuthCtx]
    fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> {
        super::parse_and_type_sql(sql, tx, &AuthCtx::for_testing())
    }

    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        for sql in [
            "select str from t",
            "select str, arr from t",
            "select t.str, arr from t",
            "select * from t limit 5",
        ] {
            let result = parse_and_type_sql(sql, &tx);
            // Name the offending query so a failure is attributable without a debugger.
            assert!(result.is_ok(), "expected `{}` to type check", sql);
        }
    }

    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        for sql in [
            // Unqualified columns in a join
            "select id, str from s join t",
            // Wrong type for limit
            "select * from t limit '5'",
            // Unqualified column name in the filter of a join
            "select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD",
        ] {
            let result = parse_and_type_sql(sql, &tx);
            // Name the offending query so a failure is attributable without a debugger.
            assert!(result.is_err(), "expected `{}` to be rejected", sql);
        }
    }
}