spacetimedb_expr/
lib.rs

1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::ser::Serialize;
16use spacetimedb_lib::Timestamp;
17use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
18use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
19use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
20use spacetimedb_schema::schema::ColumnSchema;
21use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22
23pub mod check;
24pub mod errors;
25pub mod expr;
26pub mod statement;
27
28/// Type check and lower a [SqlExpr]
29pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
30    Ok(RelExpr::Select(
31        Box::new(input),
32        type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
33    ))
34}
35
36/// Type check a LIMIT clause
37pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
38    Ok(
39        parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
40            .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
41            .and_then(|n| {
42                n.into_u64()
43                    .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
44            })
45            .map(|n| ProjectList::Limit(Box::new(input), n))?,
46    )
47}
48
49/// Type check and lower a [ast::Project]
50pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
51    match proj {
52        ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
53        ast::Project::Star(None) => Ok(ProjectList::Name(ProjectName::None(input))),
54        ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
55            Ok(ProjectList::Name(ProjectName::Some(input, var)))
56        }
57        ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
58        ast::Project::Count(SqlIdent(alias)) => Ok(ProjectList::Agg(input, AggType::Count, alias, AlgebraicType::U64)),
59        ast::Project::Exprs(elems) => {
60            let mut projections = vec![];
61            let mut names = HashSet::new();
62
63            for ProjectElem(expr, SqlIdent(alias)) in elems {
64                if !names.insert(alias.clone()) {
65                    return Err(DuplicateName(alias.into_string()).into());
66                }
67
68                if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
69                    projections.push((alias, p));
70                }
71            }
72
73            Ok(ProjectList::List(input, projections))
74        }
75    }
76}
77
78/// Type check and lower a [SqlExpr] into a logical [Expr].
79pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
80    match (expr, expected) {
81        (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
82        (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
83        (SqlExpr::Lit(SqlLiteral::Str(_) | SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => {
84            Err(Unresolved::Literal.into())
85        }
86        (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
87            parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
88            ty.clone(),
89        )),
90        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
91            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
92            let ColumnSchema { col_pos, col_type, .. } = table_type
93                .get_column_by_name(&field)
94                .ok_or_else(|| Unresolved::var(&field))?;
95            Ok(Expr::Field(FieldProject {
96                table,
97                field: col_pos.idx(),
98                ty: col_type.clone(),
99            }))
100        }
101        (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
102            let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
103            let ColumnSchema { col_pos, col_type, .. } = table_type
104                .as_ref()
105                .get_column_by_name(&field)
106                .ok_or_else(|| Unresolved::var(&field))?;
107            if col_type != ty {
108                return Err(UnexpectedType::new(col_type, ty).into());
109            }
110            Ok(Expr::Field(FieldProject {
111                table,
112                field: col_pos.idx(),
113                ty: col_type.clone(),
114            }))
115        }
116        (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
117            let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
118            let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
119            Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
120        }
121        (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => match (*a, *b) {
122            (a, b @ SqlExpr::Lit(_)) | (b @ SqlExpr::Lit(_), a) | (a, b) => {
123                let a = type_expr(vars, a, None)?;
124                let b = type_expr(vars, b, Some(a.ty()))?;
125                if !op_supports_type(op, a.ty()) {
126                    return Err(InvalidOp::new(op, a.ty()).into());
127                }
128                Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
129            }
130        },
131        (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
132        (SqlExpr::Var(_), _) => unreachable!(),
133    }
134}
135
136/// Is this type compatible with this binary operator?
137fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
138    t.is_bool()
139        || t.is_integer()
140        || t.is_float()
141        || t.is_string()
142        || t.is_bytes()
143        || t.is_identity()
144        || t.is_connection_id()
145        || t.is_timestamp()
146}
147
148/// Parse an integer literal into an [AlgebraicValue]
149fn parse_int<Int, Val, ToInt, ToVal>(
150    literal: &str,
151    ty: AlgebraicType,
152    to_int: ToInt,
153    to_val: ToVal,
154) -> anyhow::Result<AlgebraicValue>
155where
156    Int: Into<Val>,
157    ToInt: FnOnce(&BigDecimal) -> Option<Int>,
158    ToVal: FnOnce(Val) -> AlgebraicValue,
159{
160    // Why are we using an arbitrary precision type?
161    // For scientific notation as well as i256 and u256.
162    BigDecimal::from_str(literal)
163        .ok()
164        .filter(|decimal| decimal.is_integer())
165        .ok_or_else(|| anyhow!("{literal} is not an integer"))
166        .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
167        .transpose()
168        .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
169}
170
171/// Parse a floating point literal into an [AlgebraicValue]
172fn parse_float<Float, Value, ToFloat, ToValue>(
173    literal: &str,
174    ty: AlgebraicType,
175    to_float: ToFloat,
176    to_value: ToValue,
177) -> anyhow::Result<AlgebraicValue>
178where
179    Float: Into<Value>,
180    ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
181    ToValue: FnOnce(Value) -> AlgebraicValue,
182{
183    BigDecimal::from_str(literal)
184        .ok()
185        .and_then(|decimal| to_float(&decimal))
186        .map(|value| value.into())
187        .map(to_value)
188        .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
189}
190
191/// Parses a source text literal as a particular type
192pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
193    let to_timestamp = || {
194        Timestamp::parse_from_str(value)?
195            .serialize(ValueSerializer)
196            .with_context(|| "Could not parse timestamp")
197    };
198    let to_bytes = || {
199        from_hex_pad::<Vec<u8>, _>(value)
200            .map(|v| v.into_boxed_slice())
201            .map(AlgebraicValue::Bytes)
202            .with_context(|| "Could not parse hex value")
203    };
204    let to_identity = || {
205        Identity::from_hex(value)
206            .map(AlgebraicValue::from)
207            .with_context(|| "Could not parse identity")
208    };
209    let to_connection_id = || {
210        ConnectionId::from_hex(value)
211            .map(AlgebraicValue::from)
212            .with_context(|| "Could not parse connection id")
213    };
214    let to_i256 = |decimal: &BigDecimal| {
215        i256::from_str_radix(
216            // Convert to decimal notation
217            &decimal.to_plain_string(),
218            10,
219        )
220        .ok()
221    };
222    let to_u256 = |decimal: &BigDecimal| {
223        u256::from_str_radix(
224            // Convert to decimal notation
225            &decimal.to_plain_string(),
226            10,
227        )
228        .ok()
229    };
230    match ty {
231        AlgebraicType::I8 => parse_int(
232            // Parse literal as I8
233            value,
234            AlgebraicType::I8,
235            BigDecimal::to_i8,
236            AlgebraicValue::I8,
237        ),
238        AlgebraicType::U8 => parse_int(
239            // Parse literal as U8
240            value,
241            AlgebraicType::U8,
242            BigDecimal::to_u8,
243            AlgebraicValue::U8,
244        ),
245        AlgebraicType::I16 => parse_int(
246            // Parse literal as I16
247            value,
248            AlgebraicType::I16,
249            BigDecimal::to_i16,
250            AlgebraicValue::I16,
251        ),
252        AlgebraicType::U16 => parse_int(
253            // Parse literal as U16
254            value,
255            AlgebraicType::U16,
256            BigDecimal::to_u16,
257            AlgebraicValue::U16,
258        ),
259        AlgebraicType::I32 => parse_int(
260            // Parse literal as I32
261            value,
262            AlgebraicType::I32,
263            BigDecimal::to_i32,
264            AlgebraicValue::I32,
265        ),
266        AlgebraicType::U32 => parse_int(
267            // Parse literal as U32
268            value,
269            AlgebraicType::U32,
270            BigDecimal::to_u32,
271            AlgebraicValue::U32,
272        ),
273        AlgebraicType::I64 => parse_int(
274            // Parse literal as I64
275            value,
276            AlgebraicType::I64,
277            BigDecimal::to_i64,
278            AlgebraicValue::I64,
279        ),
280        AlgebraicType::U64 => parse_int(
281            // Parse literal as U64
282            value,
283            AlgebraicType::U64,
284            BigDecimal::to_u64,
285            AlgebraicValue::U64,
286        ),
287        AlgebraicType::F32 => parse_float(
288            // Parse literal as F32
289            value,
290            AlgebraicType::F32,
291            BigDecimal::to_f32,
292            AlgebraicValue::F32,
293        ),
294        AlgebraicType::F64 => parse_float(
295            // Parse literal as F64
296            value,
297            AlgebraicType::F64,
298            BigDecimal::to_f64,
299            AlgebraicValue::F64,
300        ),
301        AlgebraicType::I128 => parse_int(
302            // Parse literal as I128
303            value,
304            AlgebraicType::I128,
305            BigDecimal::to_i128,
306            AlgebraicValue::I128,
307        ),
308        AlgebraicType::U128 => parse_int(
309            // Parse literal as U128
310            value,
311            AlgebraicType::U128,
312            BigDecimal::to_u128,
313            AlgebraicValue::U128,
314        ),
315        AlgebraicType::I256 => parse_int(
316            // Parse literal as I256
317            value,
318            AlgebraicType::I256,
319            to_i256,
320            AlgebraicValue::I256,
321        ),
322        AlgebraicType::U256 => parse_int(
323            // Parse literal as U256
324            value,
325            AlgebraicType::U256,
326            to_u256,
327            AlgebraicValue::U256,
328        ),
329        AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
330        t if t.is_timestamp() => to_timestamp(),
331        t if t.is_bytes() => to_bytes(),
332        t if t.is_identity() => to_identity(),
333        t if t.is_connection_id() => to_connection_id(),
334        t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
335    }
336}
337
/// The source of a statement: whether it was submitted as a subscription or as a query.
///
/// This is a plain tag with unit variants, so it is cheap to copy and can be compared
/// or hashed directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StatementSource {
    Subscription,
    Query,
}
343
344/// A statement context.
345///
346/// This is a wrapper around a statement, its source, and the original SQL text.
347pub struct StatementCtx<'a> {
348    pub statement: Statement,
349    pub sql: &'a str,
350    pub source: StatementSource,
351}