spacetimedb_expr/
check.rs

1use std::collections::HashMap;
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::expr::{Expr, ProjectList, ProjectName, Relvar};
6use crate::{expr::LeftDeepJoin, statement::Statement};
7use spacetimedb_lib::AlgebraicType;
8use spacetimedb_primitives::TableId;
9use spacetimedb_schema::schema::TableSchema;
10use spacetimedb_sql_parser::ast::BinOp;
11use spacetimedb_sql_parser::{
12    ast::{sub::SqlSelect, SqlFrom, SqlIdent, SqlJoin},
13    parser::sub::parse_subscription,
14};
15
16use super::{
17    errors::{DuplicateName, TypingError, Unresolved, Unsupported},
18    expr::RelExpr,
19    type_expr, type_proj, type_select, StatementCtx, StatementSource,
20};
21
22/// The result of type checking and name resolution
23pub type TypingResult<T> = core::result::Result<T, TypingError>;
24
25/// A view of the database schema
26pub trait SchemaView {
27    fn table_id(&self, name: &str) -> Option<TableId>;
28    fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>>;
29
30    fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
31        self.table_id(name).and_then(|table_id| self.schema_for_table(table_id))
32    }
33}
34
35#[derive(Default)]
36pub struct Relvars(HashMap<Box<str>, Arc<TableSchema>>);
37
38impl Deref for Relvars {
39    type Target = HashMap<Box<str>, Arc<TableSchema>>;
40    fn deref(&self) -> &Self::Target {
41        &self.0
42    }
43}
44
45impl DerefMut for Relvars {
46    fn deref_mut(&mut self) -> &mut Self::Target {
47        &mut self.0
48    }
49}
50
51pub trait TypeChecker {
52    type Ast;
53    type Set;
54
55    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList>;
56
57    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList>;
58
59    fn type_from(from: SqlFrom, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<RelExpr> {
60        match from {
61            SqlFrom::Expr(SqlIdent(name), SqlIdent(alias)) => {
62                let schema = Self::type_relvar(tx, &name)?;
63                vars.insert(alias.clone(), schema.clone());
64                Ok(RelExpr::RelVar(Relvar {
65                    schema,
66                    alias,
67                    delta: None,
68                }))
69            }
70            SqlFrom::Join(SqlIdent(name), SqlIdent(alias), joins) => {
71                let schema = Self::type_relvar(tx, &name)?;
72                vars.insert(alias.clone(), schema.clone());
73                let mut join = RelExpr::RelVar(Relvar {
74                    schema,
75                    alias,
76                    delta: None,
77                });
78
79                for SqlJoin {
80                    var: SqlIdent(name),
81                    alias: SqlIdent(alias),
82                    on,
83                } in joins
84                {
85                    // Check for duplicate aliases
86                    if vars.contains_key(&alias) {
87                        return Err(DuplicateName(alias.into_string()).into());
88                    }
89
90                    let lhs = Box::new(join);
91                    let rhs = Relvar {
92                        schema: Self::type_relvar(tx, &name)?,
93                        alias,
94                        delta: None,
95                    };
96
97                    vars.insert(rhs.alias.clone(), rhs.schema.clone());
98
99                    if let Some(on) = on {
100                        if let Expr::BinOp(BinOp::Eq, a, b) = type_expr(vars, on, Some(&AlgebraicType::Bool))? {
101                            if let (Expr::Field(a), Expr::Field(b)) = (*a, *b) {
102                                join = RelExpr::EqJoin(LeftDeepJoin { lhs, rhs }, a, b);
103                                continue;
104                            }
105                        }
106                        unreachable!("Unreachability guaranteed by parser")
107                    }
108
109                    join = RelExpr::LeftDeepJoin(LeftDeepJoin { lhs, rhs });
110                }
111
112                Ok(join)
113            }
114        }
115    }
116
117    fn type_relvar(tx: &impl SchemaView, name: &str) -> TypingResult<Arc<TableSchema>> {
118        tx.schema(name)
119            .ok_or_else(|| Unresolved::table(name))
120            .map_err(TypingError::from)
121    }
122}
123
124/// Type checker for subscriptions
125struct SubChecker;
126
127impl TypeChecker for SubChecker {
128    type Ast = SqlSelect;
129    type Set = SqlSelect;
130
131    fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
132        Self::type_set(ast, &mut Relvars::default(), tx)
133    }
134
135    fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
136        match ast {
137            SqlSelect {
138                project,
139                from,
140                filter: None,
141            } => {
142                let input = Self::type_from(from, vars, tx)?;
143                type_proj(input, project, vars)
144            }
145            SqlSelect {
146                project,
147                from,
148                filter: Some(expr),
149            } => {
150                let input = Self::type_from(from, vars, tx)?;
151                type_proj(type_select(input, expr, vars)?, project, vars)
152            }
153        }
154    }
155}
156
157/// Parse and type check a subscription query
158pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView) -> TypingResult<ProjectName> {
159    expect_table_type(SubChecker::type_ast(parse_subscription(sql)?, tx)?)
160}
161
162/// Type check a subscription query
163pub fn type_subscription(ast: SqlSelect, tx: &impl SchemaView) -> TypingResult<ProjectName> {
164    expect_table_type(SubChecker::type_ast(ast, tx)?)
165}
166
167/// Parse and type check a *subscription* query into a `StatementCtx`
168pub fn compile_sql_sub<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<StatementCtx<'a>> {
169    Ok(StatementCtx {
170        statement: Statement::Select(ProjectList::Name(parse_and_type_sub(sql, tx)?)),
171        sql,
172        source: StatementSource::Subscription,
173    })
174}
175
176/// Returns an error if the input type is not a table type or relvar
177fn expect_table_type(expr: ProjectList) -> TypingResult<ProjectName> {
178    match expr {
179        ProjectList::Name(proj) => Ok(proj),
180        ProjectList::Limit(input, _) => expect_table_type(*input),
181        ProjectList::List(..) | ProjectList::Agg(..) => Err(Unsupported::ReturnType.into()),
182    }
183}
184
185pub mod test_utils {
186    use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
187    use spacetimedb_primitives::TableId;
188    use spacetimedb_schema::{
189        def::ModuleDef,
190        schema::{Schema, TableSchema},
191    };
192    use std::sync::Arc;
193
194    use super::SchemaView;
195
196    pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef {
197        let mut builder = RawModuleDefV9Builder::new();
198        for (name, ty) in types {
199            builder.build_table_with_new_type(name, ty, true);
200        }
201        builder.finish().try_into().expect("failed to generate module def")
202    }
203
204    pub struct SchemaViewer(pub ModuleDef);
205
206    impl SchemaView for SchemaViewer {
207        fn table_id(&self, name: &str) -> Option<TableId> {
208            match name {
209                "t" => Some(TableId(0)),
210                "s" => Some(TableId(1)),
211                _ => None,
212            }
213        }
214
215        fn schema_for_table(&self, table_id: TableId) -> Option<Arc<TableSchema>> {
216            match table_id.idx() {
217                0 => Some((TableId(0), "t")),
218                1 => Some((TableId(1), "s")),
219                _ => None,
220            }
221            .and_then(|(table_id, name)| {
222                self.0
223                    .table(name)
224                    .map(|def| Arc::new(TableSchema::from_module_def(&self.0, def, (), table_id)))
225            })
226        }
227    }
228}
229
#[cfg(test)]
mod tests {
    use crate::check::test_utils::{build_module_def, SchemaViewer};
    use spacetimedb_lib::{AlgebraicType, ProductType};
    use spacetimedb_schema::def::ModuleDef;

    use super::parse_and_type_sub;

    /// A query to type check, paired with a description reported on failure.
    struct TestCase {
        sql: &'static str,
        msg: &'static str,
    }

    /// Schema shared by every test: a wide table `t` with one column per literal
    /// type, and a smaller table `s` that shares the `u32` and `arr` columns.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "t",
                ProductType::from([
                    ("ts", AlgebraicType::timestamp()),
                    ("i8", AlgebraicType::I8),
                    ("u8", AlgebraicType::U8),
                    ("i16", AlgebraicType::I16),
                    ("u16", AlgebraicType::U16),
                    ("i32", AlgebraicType::I32),
                    ("u32", AlgebraicType::U32),
                    ("i64", AlgebraicType::I64),
                    ("u64", AlgebraicType::U64),
                    ("int", AlgebraicType::U32),
                    ("f32", AlgebraicType::F32),
                    ("f64", AlgebraicType::F64),
                    ("i128", AlgebraicType::I128),
                    ("u128", AlgebraicType::U128),
                    ("i256", AlgebraicType::I256),
                    ("u256", AlgebraicType::U256),
                    ("str", AlgebraicType::String),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                ]),
            ),
            (
                "s",
                ProductType::from([
                    ("id", AlgebraicType::identity()),
                    ("u32", AlgebraicType::U32),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                    ("bytes", AlgebraicType::bytes()),
                ]),
            ),
        ])
    }

    #[test]
    fn valid_literals() {
        let tx = SchemaViewer(module_def());

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where i32 = -1",
                msg: "Leading `-`",
            },
            TestCase {
                sql: "select * from t where u32 = +1",
                msg: "Leading `+`",
            },
            TestCase {
                sql: "select * from t where u32 = 1e3",
                msg: "Scientific notation",
            },
            TestCase {
                sql: "select * from t where u32 = 1E3",
                msg: "Case insensitive scientific notation",
            },
            TestCase {
                sql: "select * from t where f32 = 1e3",
                msg: "Integers can parse as floats",
            },
            TestCase {
                sql: "select * from t where f32 = 1e-3",
                msg: "Negative exponent",
            },
            TestCase {
                sql: "select * from t where f32 = 0.1",
                msg: "Standard decimal notation",
            },
            TestCase {
                sql: "select * from t where f32 = .1",
                msg: "Leading `.`",
            },
            TestCase {
                sql: "select * from t where f32 = 1e40",
                msg: "Infinity",
            },
            TestCase {
                sql: "select * from t where u256 = 1e40",
                msg: "u256",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30Z'",
                msg: "timestamp",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30.123Z'",
                msg: "timestamp ms",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10T15:45:30.123456789Z'",
                msg: "timestamp ns",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10 15:45:30+02:00'",
                msg: "timestamp with timezone",
            },
            TestCase {
                sql: "select * from t where ts = '2025-02-10 15:45:30.123+02:00'",
                msg: "timestamp ms with timezone",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_ok(), "name: {}, error: {}", msg, result.unwrap_err());
        }
    }

    #[test]
    fn valid_literals_for_type() {
        let tx = SchemaViewer(module_def());

        for ty in [
            "i8", "u8", "i16", "u16", "i32", "u32", "i64", "u64", "f32", "f64", "i128", "u128", "i256", "u256",
        ] {
            let sql = format!("select * from t where {ty} = 127");
            let result = parse_and_type_sub(&sql, &tx);
            assert!(result.is_ok(), "Failed to parse {ty}: {}", result.unwrap_err());
        }
    }

    #[test]
    fn invalid_literals() {
        let tx = SchemaViewer(module_def());

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t where u8 = -1",
                msg: "Negative integer for unsigned column",
            },
            TestCase {
                sql: "select * from t where u8 = 1e3",
                msg: "Out of bounds",
            },
            TestCase {
                sql: "select * from t where u8 = 0.1",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where u32 = 1e-3",
                msg: "Float as integer",
            },
            TestCase {
                sql: "select * from t where i32 = 1e-3",
                msg: "Float as integer",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}: `{sql}` unexpectedly type checked");
        }
    }

    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from t",
                msg: "Can select * on any table",
            },
            TestCase {
                sql: "select * from t where true",
                msg: "Boolean literals are valid in WHERE clause",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1",
                msg: "Can qualify column references with table name",
            },
            TestCase {
                sql: "select * from t where u32 = 1",
                msg: "Can leave columns unqualified when unambiguous",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1 or t.str = ''",
                msg: "Type OR with qualified column references",
            },
            TestCase {
                sql: "select * from s where s.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with mixed qualified and unqualified column references",
            },
            TestCase {
                sql: "select * from s as r where r.bytes = 0xABCD or bytes = X'ABCD'",
                msg: "Type OR with table alias",
            },
            TestCase {
                sql: "select t.* from t join s",
                msg: "Type cross join + projection",
            },
            TestCase {
                sql: "select t.* from t join s join s as r where t.u32 = s.u32 and s.u32 = r.u32",
                msg: "Type self join + projection",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = s.u32 where t.f32 = 0.1",
                msg: "Type inner join + projection",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            // Report the query and the typing error so failures are diagnosable.
            assert!(result.is_ok(), "{msg}: `{sql}`: {}", result.unwrap_err());
        }
    }

    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        for TestCase { sql, msg } in [
            TestCase {
                sql: "select * from r",
                msg: "Table r does not exist",
            },
            TestCase {
                sql: "select * from t where t.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t as r where r.a = 1",
                msg: "Field a does not exist on table t",
            },
            TestCase {
                sql: "select * from t where u32 = 'str'",
                msg: "Field u32 is not a string",
            },
            TestCase {
                sql: "select * from t where t.u32 = 1.3",
                msg: "Field u32 is not a float",
            },
            TestCase {
                sql: "select * from t as r where t.u32 = 5",
                msg: "t is not in scope after alias",
            },
            TestCase {
                sql: "select u32 from t",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select * from t join s",
                msg: "Subscriptions must be typed to a single table",
            },
            TestCase {
                sql: "select t.* from t join t",
                msg: "Self join requires aliases",
            },
            TestCase {
                sql: "select t.* from t join s on t.arr = s.arr",
                msg: "Product values are not comparable",
            },
            TestCase {
                sql: "select t.* from t join s on t.u32 = r.u32 join s as r",
                msg: "Alias r is not in scope when it is referenced",
            },
            TestCase {
                sql: "select * from t limit 5",
                msg: "Subscriptions do not support limit",
            },
        ] {
            let result = parse_and_type_sub(sql, &tx);
            assert!(result.is_err(), "{msg}: `{sql}` unexpectedly type checked");
        }
    }
}