1use std::sync::Arc;
2
3use spacetimedb_lib::{identity::AuthCtx, st_var::StVarValue, AlgebraicType, AlgebraicValue, ProductValue};
4use spacetimedb_primitives::{ColId, TableId};
5use spacetimedb_schema::schema::{ColumnSchema, TableSchema};
6use spacetimedb_sql_parser::{
7 ast::{
8 sql::{SqlAst, SqlDelete, SqlInsert, SqlSelect, SqlSet, SqlShow, SqlUpdate},
9 BinOp, SqlIdent, SqlLiteral,
10 },
11 parser::sql::parse_sql,
12};
13use thiserror::Error;
14
15use crate::{
16 check::Relvars,
17 errors::InvalidLiteral,
18 expr::{FieldProject, ProjectList, RelExpr, Relvar},
19 type_limit,
20};
21
22use super::{
23 check::{SchemaView, TypeChecker, TypingResult},
24 errors::{InsertFieldsError, InsertValuesError, TypingError, UnexpectedType, Unresolved},
25 expr::Expr,
26 parse, type_expr, type_proj, type_select, StatementCtx, StatementSource,
27};
28
/// A type-checked SQL statement.
pub enum Statement {
    /// A read-only query. `SHOW` statements are also rewritten into this form.
    Select(ProjectList),
    /// A mutation. `SET` statements are rewritten into an insert.
    DML(DML),
}

/// A type-checked data-manipulation statement against a single table.
pub enum DML {
    Insert(TableInsert),
    Update(TableUpdate),
    Delete(TableDelete),
}
39
40impl DML {
41 pub fn table_schema(&self) -> &TableSchema {
43 match self {
44 Self::Insert(insert) => &insert.table,
45 Self::Delete(delete) => &delete.table,
46 Self::Update(update) => &update.table,
47 }
48 }
49
50 pub fn table_id(&self) -> TableId {
52 self.table_schema().table_id
53 }
54
55 pub fn table_name(&self) -> Box<str> {
57 self.table_schema().table_name.clone()
58 }
59}
60
/// A typed `INSERT`: the target table and the rows to add to it.
pub struct TableInsert {
    /// Schema of the target table.
    pub table: Arc<TableSchema>,
    /// The rows to insert, one `ProductValue` per row.
    pub rows: Box<[ProductValue]>,
}

/// A typed `DELETE`: the target table and an optional row filter.
pub struct TableDelete {
    /// Schema of the target table.
    pub table: Arc<TableSchema>,
    /// The typed `WHERE` clause, or `None` if the statement had none
    /// (presumably meaning every row is deleted — confirm with the executor).
    pub filter: Option<Expr>,
}

/// A typed `UPDATE`: the target table, the column assignments and an optional row filter.
pub struct TableUpdate {
    /// Schema of the target table.
    pub table: Arc<TableSchema>,
    /// `(column, new value)` pairs, one per `SET` assignment, in statement order.
    pub columns: Box<[(ColId, AlgebraicValue)]>,
    /// The typed `WHERE` clause, or `None` if the statement had none.
    pub filter: Option<Expr>,
}

/// A system-variable assignment: a variable name and the value assigned to it.
/// NOTE(review): not constructed in the visible code; `SET` is rewritten into a
/// `TableInsert` by `type_and_rewrite_set` instead — confirm whether this is still used.
pub struct SetVar {
    pub name: String,
    pub value: AlgebraicValue,
}

/// A system-variable lookup by name.
/// NOTE(review): not constructed in the visible code; `SHOW` is rewritten into a
/// `ProjectList` by `type_and_rewrite_show` instead — confirm whether this is still used.
pub struct ShowVar {
    pub name: String,
}
85
86pub fn type_insert(insert: SqlInsert, tx: &impl SchemaView) -> TypingResult<TableInsert> {
88 let SqlInsert {
89 table: SqlIdent(table_name),
90 fields,
91 values,
92 } = insert;
93
94 let schema = tx
95 .schema(&table_name)
96 .ok_or_else(|| Unresolved::table(&table_name))
97 .map_err(TypingError::from)?;
98
99 let n = schema.columns().len();
101 if fields.len() != schema.columns().len() {
102 return Err(TypingError::from(InsertFieldsError {
103 table: table_name.into_string(),
104 nfields: fields.len(),
105 ncols: schema.columns().len(),
106 }));
107 }
108
109 let mut rows = Vec::new();
110 for row in values.0 {
111 if row.len() != n {
113 return Err(TypingError::from(InsertValuesError {
114 table: table_name.into_string(),
115 values: row.len(),
116 fields: n,
117 }));
118 }
119 let mut values = Vec::new();
120 for (value, ty) in row.into_iter().zip(
121 schema
122 .as_ref()
123 .columns()
124 .iter()
125 .map(|ColumnSchema { col_type, .. }| col_type),
126 ) {
127 match (value, ty) {
128 (SqlLiteral::Bool(v), AlgebraicType::Bool) => {
129 values.push(AlgebraicValue::Bool(v));
130 }
131 (SqlLiteral::Str(v), AlgebraicType::String) => {
132 values.push(AlgebraicValue::String(v));
133 }
134 (SqlLiteral::Bool(_), _) => {
135 return Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into());
136 }
137 (SqlLiteral::Str(_), _) => {
138 return Err(UnexpectedType::new(&AlgebraicType::String, ty).into());
139 }
140 (SqlLiteral::Hex(v), ty) | (SqlLiteral::Num(v), ty) => {
141 values.push(parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?);
142 }
143 }
144 }
145 rows.push(ProductValue::from(values));
146 }
147 let into = schema;
148 let rows = rows.into_boxed_slice();
149 Ok(TableInsert { table: into, rows })
150}
151
152pub fn type_delete(delete: SqlDelete, tx: &impl SchemaView) -> TypingResult<TableDelete> {
154 let SqlDelete {
155 table: SqlIdent(table_name),
156 filter,
157 } = delete;
158 let from = tx
159 .schema(&table_name)
160 .ok_or_else(|| Unresolved::table(&table_name))
161 .map_err(TypingError::from)?;
162 let mut vars = Relvars::default();
163 vars.insert(table_name.clone(), from.clone());
164 let expr = filter
165 .map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool)))
166 .transpose()?;
167 Ok(TableDelete {
168 table: from,
169 filter: expr,
170 })
171}
172
173pub fn type_update(update: SqlUpdate, tx: &impl SchemaView) -> TypingResult<TableUpdate> {
175 let SqlUpdate {
176 table: SqlIdent(table_name),
177 assignments,
178 filter,
179 } = update;
180 let schema = tx
181 .schema(&table_name)
182 .ok_or_else(|| Unresolved::table(&table_name))
183 .map_err(TypingError::from)?;
184 let mut values = Vec::new();
185 for SqlSet(SqlIdent(field), lit) in assignments {
186 let ColumnSchema {
187 col_pos: col_id,
188 col_type: ty,
189 ..
190 } = schema
191 .as_ref()
192 .get_column_by_name(&field)
193 .ok_or_else(|| Unresolved::field(&table_name, &field))?;
194 match (lit, ty) {
195 (SqlLiteral::Bool(v), AlgebraicType::Bool) => {
196 values.push((*col_id, AlgebraicValue::Bool(v)));
197 }
198 (SqlLiteral::Str(v), AlgebraicType::String) => {
199 values.push((*col_id, AlgebraicValue::String(v)));
200 }
201 (SqlLiteral::Bool(_), _) => {
202 return Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into());
203 }
204 (SqlLiteral::Str(_), _) => {
205 return Err(UnexpectedType::new(&AlgebraicType::String, ty).into());
206 }
207 (SqlLiteral::Hex(v), ty) | (SqlLiteral::Num(v), ty) => {
208 values.push((
209 *col_id,
210 parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
211 ));
212 }
213 }
214 }
215 let mut vars = Relvars::default();
216 vars.insert(table_name.clone(), schema.clone());
217 let values = values.into_boxed_slice();
218 let filter = filter
219 .map(|expr| type_expr(&vars, expr, Some(&AlgebraicType::Bool)))
220 .transpose()?;
221 Ok(TableUpdate {
222 table: schema,
223 columns: values,
224 filter,
225 })
226}
227
/// Error returned when a `SET` or `SHOW` statement names a variable that is not
/// one of the system variables accepted by `is_var_valid`.
#[derive(Error, Debug)]
#[error("{name} is not a valid system variable")]
pub struct InvalidVar {
    /// The unrecognized variable name, as written in the statement.
    pub name: String,
}
233
// Names of the system variables that `SET` and `SHOW` accept.
const VAR_ROW_LIMIT: &str = "row_limit";
const VAR_SLOW_QUERY: &str = "slow_ad_hoc_query_ms";
const VAR_SLOW_UPDATE: &str = "slow_tx_update_ms";
const VAR_SLOW_SUB: &str = "slow_subscription_query_ms";

/// Returns `true` iff `var` is exactly (case-sensitively) one of the known
/// system variable names.
fn is_var_valid(var: &str) -> bool {
    matches!(var, VAR_ROW_LIMIT | VAR_SLOW_QUERY | VAR_SLOW_UPDATE | VAR_SLOW_SUB)
}

// The system table that stores variable values, and the name under which
// `SHOW` reports a variable's value.
const ST_VAR_NAME: &str = "st_var";
const VALUE_COLUMN: &str = "value";
245
246pub fn type_and_rewrite_set(set: SqlSet, tx: &impl SchemaView) -> TypingResult<TableInsert> {
260 let SqlSet(SqlIdent(var_name), lit) = set;
261 if !is_var_valid(&var_name) {
262 return Err(InvalidVar {
263 name: var_name.into_string(),
264 }
265 .into());
266 }
267
268 match lit {
269 SqlLiteral::Bool(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::Bool).into()),
270 SqlLiteral::Str(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::String).into()),
271 SqlLiteral::Hex(_) => Err(UnexpectedType::new(&AlgebraicType::U64, &AlgebraicType::bytes()).into()),
272 SqlLiteral::Num(n) => {
273 let table = tx.schema(ST_VAR_NAME).ok_or_else(|| Unresolved::table(ST_VAR_NAME))?;
274 let var_name = AlgebraicValue::String(var_name);
275 let sum_value = StVarValue::try_from_primitive(
276 parse(&n, &AlgebraicType::U64)
277 .map_err(|_| InvalidLiteral::new(n.clone().into_string(), &AlgebraicType::U64))?,
278 )
279 .map_err(|_| InvalidLiteral::new(n.into_string(), &AlgebraicType::U64))?
280 .into();
281 Ok(TableInsert {
282 table,
283 rows: Box::new([ProductValue::from_iter([var_name, sum_value])]),
284 })
285 }
286 }
287}
288
289pub fn type_and_rewrite_show(show: SqlShow, tx: &impl SchemaView) -> TypingResult<ProjectList> {
303 let SqlShow(SqlIdent(var_name)) = show;
304 if !is_var_valid(&var_name) {
305 return Err(InvalidVar {
306 name: var_name.into_string(),
307 }
308 .into());
309 }
310
311 let table_schema = tx.schema(ST_VAR_NAME).ok_or_else(|| Unresolved::table(ST_VAR_NAME))?;
312
313 let value_col_ty = table_schema
314 .as_ref()
315 .get_column(1)
316 .map(|ColumnSchema { col_type, .. }| col_type)
317 .ok_or_else(|| Unresolved::field(ST_VAR_NAME, VALUE_COLUMN))?;
318
319 let var_name_field = Expr::Field(FieldProject {
324 table: ST_VAR_NAME.into(),
325 field: 0,
328 ty: AlgebraicType::String,
329 });
330
331 let var_name_value = Expr::Value(AlgebraicValue::String(var_name), AlgebraicType::String);
336
337 let column_list = vec![(
342 VALUE_COLUMN.into(),
343 FieldProject {
344 table: ST_VAR_NAME.into(),
345 field: 1,
348 ty: value_col_ty.clone(),
349 },
350 )];
351
352 let relvar = RelExpr::RelVar(Relvar {
357 schema: table_schema,
358 alias: ST_VAR_NAME.into(),
359 delta: None,
360 });
361
362 let filter = Expr::BinOp(
363 BinOp::Eq,
368 Box::new(var_name_field),
369 Box::new(var_name_value),
370 );
371
372 Ok(ProjectList::List(
373 vec![RelExpr::Select(Box::new(relvar), filter)],
374 column_list,
375 ))
376}
377
378struct SqlChecker;
380
381impl TypeChecker for SqlChecker {
382 type Ast = SqlSelect;
383 type Set = SqlSelect;
384
385 fn type_ast(ast: Self::Ast, tx: &impl SchemaView) -> TypingResult<ProjectList> {
386 Self::type_set(ast, &mut Relvars::default(), tx)
387 }
388
389 fn type_set(ast: Self::Set, vars: &mut Relvars, tx: &impl SchemaView) -> TypingResult<ProjectList> {
390 match ast {
391 SqlSelect {
392 project,
393 from,
394 filter: None,
395 limit: None,
396 } => type_proj(Self::type_from(from, vars, tx)?, project, vars),
397 SqlSelect {
398 project,
399 from,
400 filter: None,
401 limit: Some(n),
402 } => type_limit(type_proj(Self::type_from(from, vars, tx)?, project, vars)?, &n),
403 SqlSelect {
404 project,
405 from,
406 filter: Some(expr),
407 limit: None,
408 } => type_proj(
409 type_select(Self::type_from(from, vars, tx)?, expr, vars)?,
410 project,
411 vars,
412 ),
413 SqlSelect {
414 project,
415 from,
416 filter: Some(expr),
417 limit: Some(n),
418 } => type_limit(
419 type_proj(
420 type_select(Self::type_from(from, vars, tx)?, expr, vars)?,
421 project,
422 vars,
423 )?,
424 &n,
425 ),
426 }
427 }
428}
429
430pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<Statement> {
431 match parse_sql(sql)?.resolve_sender(auth.caller) {
432 SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)),
433 SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))),
434 SqlAst::Delete(delete) => Ok(Statement::DML(DML::Delete(type_delete(delete, tx)?))),
435 SqlAst::Update(update) => Ok(Statement::DML(DML::Update(type_update(update, tx)?))),
436 SqlAst::Set(set) => Ok(Statement::DML(DML::Insert(type_and_rewrite_set(set, tx)?))),
437 SqlAst::Show(show) => Ok(Statement::Select(type_and_rewrite_show(show, tx)?)),
438 }
439}
440
441pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<StatementCtx<'a>> {
443 let statement = parse_and_type_sql(sql, tx, auth)?;
444 Ok(StatementCtx {
445 statement,
446 sql,
447 source: StatementSource::Query,
448 })
449}
450
#[cfg(test)]
mod tests {
    use super::Statement;
    use crate::ast::LogOp;
    use crate::check::{
        test_utils::{build_module_def, SchemaViewer},
        Relvars, SchemaView, TypingResult,
    };
    use crate::type_expr;
    use spacetimedb_lib::{identity::AuthCtx, AlgebraicType, ProductType};
    use spacetimedb_schema::def::ModuleDef;
    use spacetimedb_sql_parser::ast::{SqlExpr, SqlLiteral};

    /// Test module with two tables, `t` and `s`, that share a `u32` column.
    fn module_def() -> ModuleDef {
        build_module_def(vec![
            (
                "t",
                ProductType::from([
                    ("u32", AlgebraicType::U32),
                    ("f32", AlgebraicType::F32),
                    ("str", AlgebraicType::String),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                ]),
            ),
            (
                "s",
                ProductType::from([
                    ("id", AlgebraicType::identity()),
                    ("u32", AlgebraicType::U32),
                    ("arr", AlgebraicType::array(AlgebraicType::String)),
                    ("bytes", AlgebraicType::bytes()),
                ]),
            ),
        ])
    }

    /// Parses and type checks `sql` with a test-only auth context.
    fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> {
        super::parse_and_type_sql(sql, tx, &AuthCtx::for_testing())
    }

    #[test]
    fn valid() {
        let tx = SchemaViewer(module_def());

        for sql in [
            "select str from t",
            "select str, arr from t",
            "select t.str, arr from t",
            "select * from t limit 5",
        ] {
            // Report which query failed and why, rather than a bare `is_ok()` failure.
            if let Err(err) = parse_and_type_sql(sql, &tx) {
                panic!("`{sql}` should type check, but failed with: {err}");
            }
        }
    }

    #[test]
    fn invalid() {
        let tx = SchemaViewer(module_def());

        for sql in [
            // Join whose projection uses unqualified column names
            "select id, str from s join t",
            // String literal used as the limit
            "select * from t limit '5'",
            // Join whose filter uses an unqualified column name
            "select t.* from t join s on t.u32 = s.u32 where bytes = 0xABCD",
        ] {
            assert!(
                parse_and_type_sql(sql, &tx).is_err(),
                "`{sql}` should fail to type check"
            );
        }
    }

    /// Deeply nested boolean expressions: 2_501 levels are rejected with a
    /// "Recursion limit exceeded" error, while 2_500 levels type check.
    #[test]
    fn typing_recursion() {
        let build_query = |total, sep: char| {
            let mut expr = SqlExpr::Lit(SqlLiteral::Bool(true));
            for _ in 1..total {
                let next = SqlExpr::Log(
                    Box::new(SqlExpr::Lit(SqlLiteral::Bool(true))),
                    Box::new(SqlExpr::Lit(SqlLiteral::Bool(false))),
                    LogOp::And,
                );
                expr = SqlExpr::Log(Box::new(expr), Box::new(next), LogOp::And);
            }
            type_expr(&Relvars::default(), expr, Some(&AlgebraicType::Bool))
                .map_err(|e| e.to_string().split(sep).next().unwrap_or_default().to_string())
        };
        assert_eq!(build_query(2_501, ','), Err("Recursion limit exceeded".to_string()));

        assert!(build_query(2_500, ',').is_ok());
    }
}