// spacetimedb_expr/lib.rs

1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::ser::Serialize;
16use spacetimedb_lib::Timestamp;
17use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
18use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
19use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
20use spacetimedb_schema::schema::ColumnSchema;
21use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22
23pub mod check;
24pub mod errors;
25pub mod expr;
26pub mod rls;
27pub mod statement;
28
29/// Type check and lower a [SqlExpr]
30pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
31    Ok(RelExpr::Select(
32        Box::new(input),
33        type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
34    ))
35}
36
37/// Type check a LIMIT clause
38pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
39    Ok(
40        parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
41            .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
42            .and_then(|n| {
43                n.into_u64()
44                    .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
45            })
46            .map(|n| ProjectList::Limit(Box::new(input), n))?,
47    )
48}
49
50/// Type check and lower a [ast::Project]
51pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
52    match proj {
53        ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
54        ast::Project::Star(None) => Ok(ProjectList::Name(vec![ProjectName::None(input)])),
55        ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
56            Ok(ProjectList::Name(vec![ProjectName::Some(input, var)]))
57        }
58        ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
59        ast::Project::Count(SqlIdent(alias)) => {
60            Ok(ProjectList::Agg(vec![input], AggType::Count, alias, AlgebraicType::U64))
61        }
62        ast::Project::Exprs(elems) => {
63            let mut projections = vec![];
64            let mut names = HashSet::new();
65
66            for ProjectElem(expr, SqlIdent(alias)) in elems {
67                if !names.insert(alias.clone()) {
68                    return Err(DuplicateName(alias.into_string()).into());
69                }
70
71                if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
72                    projections.push((alias, p));
73                }
74            }
75
76            Ok(ProjectList::List(vec![input], projections))
77        }
78    }
79}
80
81/// Type check and lower a [SqlExpr] into a logical [Expr].
82pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
83    match (expr, expected) {
84        (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
85        (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
86        (SqlExpr::Lit(SqlLiteral::Str(_) | SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => {
87            Err(Unresolved::Literal.into())
88        }
89        (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
90            parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
91            ty.clone(),
92        )),
93        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
94            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
95            let ColumnSchema { col_pos, col_type, .. } = table_type
96                .get_column_by_name(&field)
97                .ok_or_else(|| Unresolved::var(&field))?;
98            Ok(Expr::Field(FieldProject {
99                table,
100                field: col_pos.idx(),
101                ty: col_type.clone(),
102            }))
103        }
104        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
105            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
106            let ColumnSchema { col_pos, col_type, .. } = table_type
107                .as_ref()
108                .get_column_by_name(&field)
109                .ok_or_else(|| Unresolved::var(&field))?;
110            if col_type != ty {
111                return Err(UnexpectedType::new(col_type, ty).into());
112            }
113            Ok(Expr::Field(FieldProject {
114                table,
115                field: col_pos.idx(),
116                ty: col_type.clone(),
117            }))
118        }
119        (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
120            let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
121            let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
122            Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
123        }
124        (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => {
125            let b = type_expr(vars, *b, None)?;
126            let a = type_expr(vars, *a, Some(b.ty()))?;
127            if !op_supports_type(op, a.ty()) {
128                return Err(InvalidOp::new(op, a.ty()).into());
129            }
130            Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
131        }
132        (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => {
133            let a = type_expr(vars, *a, None)?;
134            let b = type_expr(vars, *b, Some(a.ty()))?;
135            if !op_supports_type(op, a.ty()) {
136                return Err(InvalidOp::new(op, a.ty()).into());
137            }
138            Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
139        }
140        (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
141        // Both unqualified names as well as parameters are syntactic constructs.
142        // Unqualified names are qualified and parameters are resolved before type checking.
143        (SqlExpr::Var(_) | SqlExpr::Param(_), _) => unreachable!(),
144    }
145}
146
147/// Is this type compatible with this binary operator?
148fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
149    t.is_bool()
150        || t.is_integer()
151        || t.is_float()
152        || t.is_string()
153        || t.is_bytes()
154        || t.is_identity()
155        || t.is_connection_id()
156        || t.is_timestamp()
157}
158
159/// Parse an integer literal into an [AlgebraicValue]
160fn parse_int<Int, Val, ToInt, ToVal>(
161    literal: &str,
162    ty: AlgebraicType,
163    to_int: ToInt,
164    to_val: ToVal,
165) -> anyhow::Result<AlgebraicValue>
166where
167    Int: Into<Val>,
168    ToInt: FnOnce(&BigDecimal) -> Option<Int>,
169    ToVal: FnOnce(Val) -> AlgebraicValue,
170{
171    // Why are we using an arbitrary precision type?
172    // For scientific notation as well as i256 and u256.
173    BigDecimal::from_str(literal)
174        .ok()
175        .filter(|decimal| decimal.is_integer())
176        .ok_or_else(|| anyhow!("{literal} is not an integer"))
177        .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
178        .transpose()
179        .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
180}
181
182/// Parse a floating point literal into an [AlgebraicValue]
183fn parse_float<Float, Value, ToFloat, ToValue>(
184    literal: &str,
185    ty: AlgebraicType,
186    to_float: ToFloat,
187    to_value: ToValue,
188) -> anyhow::Result<AlgebraicValue>
189where
190    Float: Into<Value>,
191    ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
192    ToValue: FnOnce(Value) -> AlgebraicValue,
193{
194    BigDecimal::from_str(literal)
195        .ok()
196        .and_then(|decimal| to_float(&decimal))
197        .map(|value| value.into())
198        .map(to_value)
199        .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
200}
201
202/// Parses a source text literal as a particular type
203pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
204    let to_timestamp = || {
205        Timestamp::parse_from_rfc3339(value)?
206            .serialize(ValueSerializer)
207            .with_context(|| "Could not parse timestamp")
208    };
209    let to_bytes = || {
210        from_hex_pad::<Vec<u8>, _>(value)
211            .map(|v| v.into_boxed_slice())
212            .map(AlgebraicValue::Bytes)
213            .with_context(|| "Could not parse hex value")
214    };
215    let to_identity = || {
216        Identity::from_hex(value)
217            .map(AlgebraicValue::from)
218            .with_context(|| "Could not parse identity")
219    };
220    let to_connection_id = || {
221        ConnectionId::from_hex(value)
222            .map(AlgebraicValue::from)
223            .with_context(|| "Could not parse connection id")
224    };
225    let to_i256 = |decimal: &BigDecimal| {
226        i256::from_str_radix(
227            // Convert to decimal notation
228            &decimal.to_plain_string(),
229            10,
230        )
231        .ok()
232    };
233    let to_u256 = |decimal: &BigDecimal| {
234        u256::from_str_radix(
235            // Convert to decimal notation
236            &decimal.to_plain_string(),
237            10,
238        )
239        .ok()
240    };
241    match ty {
242        AlgebraicType::I8 => parse_int(
243            // Parse literal as I8
244            value,
245            AlgebraicType::I8,
246            BigDecimal::to_i8,
247            AlgebraicValue::I8,
248        ),
249        AlgebraicType::U8 => parse_int(
250            // Parse literal as U8
251            value,
252            AlgebraicType::U8,
253            BigDecimal::to_u8,
254            AlgebraicValue::U8,
255        ),
256        AlgebraicType::I16 => parse_int(
257            // Parse literal as I16
258            value,
259            AlgebraicType::I16,
260            BigDecimal::to_i16,
261            AlgebraicValue::I16,
262        ),
263        AlgebraicType::U16 => parse_int(
264            // Parse literal as U16
265            value,
266            AlgebraicType::U16,
267            BigDecimal::to_u16,
268            AlgebraicValue::U16,
269        ),
270        AlgebraicType::I32 => parse_int(
271            // Parse literal as I32
272            value,
273            AlgebraicType::I32,
274            BigDecimal::to_i32,
275            AlgebraicValue::I32,
276        ),
277        AlgebraicType::U32 => parse_int(
278            // Parse literal as U32
279            value,
280            AlgebraicType::U32,
281            BigDecimal::to_u32,
282            AlgebraicValue::U32,
283        ),
284        AlgebraicType::I64 => parse_int(
285            // Parse literal as I64
286            value,
287            AlgebraicType::I64,
288            BigDecimal::to_i64,
289            AlgebraicValue::I64,
290        ),
291        AlgebraicType::U64 => parse_int(
292            // Parse literal as U64
293            value,
294            AlgebraicType::U64,
295            BigDecimal::to_u64,
296            AlgebraicValue::U64,
297        ),
298        AlgebraicType::F32 => parse_float(
299            // Parse literal as F32
300            value,
301            AlgebraicType::F32,
302            BigDecimal::to_f32,
303            AlgebraicValue::F32,
304        ),
305        AlgebraicType::F64 => parse_float(
306            // Parse literal as F64
307            value,
308            AlgebraicType::F64,
309            BigDecimal::to_f64,
310            AlgebraicValue::F64,
311        ),
312        AlgebraicType::I128 => parse_int(
313            // Parse literal as I128
314            value,
315            AlgebraicType::I128,
316            BigDecimal::to_i128,
317            AlgebraicValue::I128,
318        ),
319        AlgebraicType::U128 => parse_int(
320            // Parse literal as U128
321            value,
322            AlgebraicType::U128,
323            BigDecimal::to_u128,
324            AlgebraicValue::U128,
325        ),
326        AlgebraicType::I256 => parse_int(
327            // Parse literal as I256
328            value,
329            AlgebraicType::I256,
330            to_i256,
331            AlgebraicValue::I256,
332        ),
333        AlgebraicType::U256 => parse_int(
334            // Parse literal as U256
335            value,
336            AlgebraicType::U256,
337            to_u256,
338            AlgebraicValue::U256,
339        ),
340        AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
341        t if t.is_timestamp() => to_timestamp(),
342        t if t.is_bytes() => to_bytes(),
343        t if t.is_identity() => to_identity(),
344        t if t.is_connection_id() => to_connection_id(),
345        t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
346    }
347}
348
/// The source of a statement
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StatementSource {
    /// The statement came from a subscription.
    Subscription,
    /// The statement came from a query.
    Query,
}
354
355/// A statement context.
356///
357/// This is a wrapper around a statement, its source, and the original SQL text.
358pub struct StatementCtx<'a> {
359    pub statement: Statement,
360    pub sql: &'a str,
361    pub source: StatementSource,
362}