spacetimedb_expr/
lib.rs

1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::ser::Serialize;
16use spacetimedb_lib::Timestamp;
17use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
18use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
19use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
20use spacetimedb_schema::schema::ColumnSchema;
21use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22use spacetimedb_sql_parser::parser::recursion;
23
24pub mod check;
25pub mod errors;
26pub mod expr;
27pub mod rls;
28pub mod statement;
29
30/// Type check and lower a [SqlExpr]
31pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
32    Ok(RelExpr::Select(
33        Box::new(input),
34        type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
35    ))
36}
37
38/// Type check a LIMIT clause
39pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
40    Ok(
41        parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
42            .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
43            .and_then(|n| {
44                n.into_u64()
45                    .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
46            })
47            .map(|n| ProjectList::Limit(Box::new(input), n))?,
48    )
49}
50
51/// Type check and lower a [ast::Project]
52pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
53    match proj {
54        ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
55        ast::Project::Star(None) => Ok(ProjectList::Name(vec![ProjectName::None(input)])),
56        ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
57            Ok(ProjectList::Name(vec![ProjectName::Some(input, var)]))
58        }
59        ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
60        ast::Project::Count(SqlIdent(alias)) => {
61            Ok(ProjectList::Agg(vec![input], AggType::Count, alias, AlgebraicType::U64))
62        }
63        ast::Project::Exprs(elems) => {
64            let mut projections = vec![];
65            let mut names = HashSet::new();
66
67            for ProjectElem(expr, SqlIdent(alias)) in elems {
68                if !names.insert(alias.clone()) {
69                    return Err(DuplicateName(alias.into_string()).into());
70                }
71
72                if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
73                    projections.push((alias, p));
74                }
75            }
76
77            Ok(ProjectList::List(vec![input], projections))
78        }
79    }
80}
81
82// These types determine the size of each stack frame during type checking.
83// Changing their sizes will require updating the recursion limit to avoid stack overflows.
84const _: () = assert!(size_of::<TypingResult<Expr>>() == 64);
85const _: () = assert!(size_of::<SqlExpr>() == 40);
86
87fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, depth: usize) -> TypingResult<Expr> {
88    recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?;
89
90    match (expr, expected) {
91        (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
92        (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
93        (SqlExpr::Lit(SqlLiteral::Str(_) | SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => {
94            Err(Unresolved::Literal.into())
95        }
96        (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
97            parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
98            ty.clone(),
99        )),
100        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
101            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
102            let ColumnSchema { col_pos, col_type, .. } = table_type
103                .get_column_by_name(&field)
104                .ok_or_else(|| Unresolved::var(&field))?;
105            Ok(Expr::Field(FieldProject {
106                table,
107                field: col_pos.idx(),
108                ty: col_type.clone(),
109            }))
110        }
111        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
112            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
113            let ColumnSchema { col_pos, col_type, .. } = table_type
114                .as_ref()
115                .get_column_by_name(&field)
116                .ok_or_else(|| Unresolved::var(&field))?;
117            if col_type != ty {
118                return Err(UnexpectedType::new(col_type, ty).into());
119            }
120            Ok(Expr::Field(FieldProject {
121                table,
122                field: col_pos.idx(),
123                ty: col_type.clone(),
124            }))
125        }
126        (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
127            let a = _type_expr(vars, *a, Some(&AlgebraicType::Bool), depth + 1)?;
128            let b = _type_expr(vars, *b, Some(&AlgebraicType::Bool), depth + 1)?;
129            Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
130        }
131        (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => {
132            let b = _type_expr(vars, *b, None, depth + 1)?;
133            let a = _type_expr(vars, *a, Some(b.ty()), depth + 1)?;
134            if !op_supports_type(op, a.ty()) {
135                return Err(InvalidOp::new(op, a.ty()).into());
136            }
137            Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
138        }
139        (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => {
140            let a = _type_expr(vars, *a, None, depth + 1)?;
141            let b = _type_expr(vars, *b, Some(a.ty()), depth + 1)?;
142            if !op_supports_type(op, a.ty()) {
143                return Err(InvalidOp::new(op, a.ty()).into());
144            }
145            Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
146        }
147        (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
148        // Both unqualified names as well as parameters are syntactic constructs.
149        // Unqualified names are qualified and parameters are resolved before type checking.
150        (SqlExpr::Var(_) | SqlExpr::Param(_), _) => unreachable!(),
151    }
152}
153
154/// Type check and lower a [SqlExpr] into a logical [Expr].
155pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
156    _type_expr(vars, expr, expected, 0)
157}
158
159/// Is this type compatible with this binary operator?
160fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
161    t.is_bool()
162        || t.is_integer()
163        || t.is_float()
164        || t.is_string()
165        || t.is_bytes()
166        || t.is_identity()
167        || t.is_connection_id()
168        || t.is_timestamp()
169}
170
171/// Parse an integer literal into an [AlgebraicValue]
172fn parse_int<Int, Val, ToInt, ToVal>(
173    literal: &str,
174    ty: AlgebraicType,
175    to_int: ToInt,
176    to_val: ToVal,
177) -> anyhow::Result<AlgebraicValue>
178where
179    Int: Into<Val>,
180    ToInt: FnOnce(&BigDecimal) -> Option<Int>,
181    ToVal: FnOnce(Val) -> AlgebraicValue,
182{
183    // Why are we using an arbitrary precision type?
184    // For scientific notation as well as i256 and u256.
185    BigDecimal::from_str(literal)
186        .ok()
187        .filter(|decimal| decimal.is_integer())
188        .ok_or_else(|| anyhow!("{literal} is not an integer"))
189        .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
190        .transpose()
191        .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
192}
193
194/// Parse a floating point literal into an [AlgebraicValue]
195fn parse_float<Float, Value, ToFloat, ToValue>(
196    literal: &str,
197    ty: AlgebraicType,
198    to_float: ToFloat,
199    to_value: ToValue,
200) -> anyhow::Result<AlgebraicValue>
201where
202    Float: Into<Value>,
203    ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
204    ToValue: FnOnce(Value) -> AlgebraicValue,
205{
206    BigDecimal::from_str(literal)
207        .ok()
208        .and_then(|decimal| to_float(&decimal))
209        .map(|value| value.into())
210        .map(to_value)
211        .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
212}
213
214/// Parses a source text literal as a particular type
215pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
216    let to_timestamp = || {
217        Timestamp::parse_from_rfc3339(value)?
218            .serialize(ValueSerializer)
219            .with_context(|| "Could not parse timestamp")
220    };
221    let to_bytes = || {
222        from_hex_pad::<Vec<u8>, _>(value)
223            .map(|v| v.into_boxed_slice())
224            .map(AlgebraicValue::Bytes)
225            .with_context(|| "Could not parse hex value")
226    };
227    let to_identity = || {
228        Identity::from_hex(value)
229            .map(AlgebraicValue::from)
230            .with_context(|| "Could not parse identity")
231    };
232    let to_connection_id = || {
233        ConnectionId::from_hex(value)
234            .map(AlgebraicValue::from)
235            .with_context(|| "Could not parse connection id")
236    };
237    let to_i256 = |decimal: &BigDecimal| {
238        i256::from_str_radix(
239            // Convert to decimal notation
240            &decimal.to_plain_string(),
241            10,
242        )
243        .ok()
244    };
245    let to_u256 = |decimal: &BigDecimal| {
246        u256::from_str_radix(
247            // Convert to decimal notation
248            &decimal.to_plain_string(),
249            10,
250        )
251        .ok()
252    };
253    match ty {
254        AlgebraicType::I8 => parse_int(
255            // Parse literal as I8
256            value,
257            AlgebraicType::I8,
258            BigDecimal::to_i8,
259            AlgebraicValue::I8,
260        ),
261        AlgebraicType::U8 => parse_int(
262            // Parse literal as U8
263            value,
264            AlgebraicType::U8,
265            BigDecimal::to_u8,
266            AlgebraicValue::U8,
267        ),
268        AlgebraicType::I16 => parse_int(
269            // Parse literal as I16
270            value,
271            AlgebraicType::I16,
272            BigDecimal::to_i16,
273            AlgebraicValue::I16,
274        ),
275        AlgebraicType::U16 => parse_int(
276            // Parse literal as U16
277            value,
278            AlgebraicType::U16,
279            BigDecimal::to_u16,
280            AlgebraicValue::U16,
281        ),
282        AlgebraicType::I32 => parse_int(
283            // Parse literal as I32
284            value,
285            AlgebraicType::I32,
286            BigDecimal::to_i32,
287            AlgebraicValue::I32,
288        ),
289        AlgebraicType::U32 => parse_int(
290            // Parse literal as U32
291            value,
292            AlgebraicType::U32,
293            BigDecimal::to_u32,
294            AlgebraicValue::U32,
295        ),
296        AlgebraicType::I64 => parse_int(
297            // Parse literal as I64
298            value,
299            AlgebraicType::I64,
300            BigDecimal::to_i64,
301            AlgebraicValue::I64,
302        ),
303        AlgebraicType::U64 => parse_int(
304            // Parse literal as U64
305            value,
306            AlgebraicType::U64,
307            BigDecimal::to_u64,
308            AlgebraicValue::U64,
309        ),
310        AlgebraicType::F32 => parse_float(
311            // Parse literal as F32
312            value,
313            AlgebraicType::F32,
314            BigDecimal::to_f32,
315            AlgebraicValue::F32,
316        ),
317        AlgebraicType::F64 => parse_float(
318            // Parse literal as F64
319            value,
320            AlgebraicType::F64,
321            BigDecimal::to_f64,
322            AlgebraicValue::F64,
323        ),
324        AlgebraicType::I128 => parse_int(
325            // Parse literal as I128
326            value,
327            AlgebraicType::I128,
328            BigDecimal::to_i128,
329            AlgebraicValue::I128,
330        ),
331        AlgebraicType::U128 => parse_int(
332            // Parse literal as U128
333            value,
334            AlgebraicType::U128,
335            BigDecimal::to_u128,
336            AlgebraicValue::U128,
337        ),
338        AlgebraicType::I256 => parse_int(
339            // Parse literal as I256
340            value,
341            AlgebraicType::I256,
342            to_i256,
343            AlgebraicValue::I256,
344        ),
345        AlgebraicType::U256 => parse_int(
346            // Parse literal as U256
347            value,
348            AlgebraicType::U256,
349            to_u256,
350            AlgebraicValue::U256,
351        ),
352        AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
353        t if t.is_timestamp() => to_timestamp(),
354        t if t.is_bytes() => to_bytes(),
355        t if t.is_identity() => to_identity(),
356        t if t.is_connection_id() => to_connection_id(),
357        t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
358    }
359}
360
/// Where a statement originated: a subscription or a query.
///
/// This is a plain fieldless tag, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementSource {
    /// The statement came from a subscription.
    Subscription,
    /// The statement came from a query.
    Query,
}
366
367/// A statement context.
368///
369/// This is a wrapper around a statement, its source, and the original SQL text.
370pub struct StatementCtx<'a> {
371    pub statement: Statement,
372    pub sql: &'a str,
373    pub source: StatementSource,
374}