1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::ser::Serialize;
16use spacetimedb_lib::Timestamp;
17use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
18use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
19use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
20use spacetimedb_schema::schema::ColumnSchema;
21use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22use spacetimedb_sql_parser::parser::recursion;
23
24pub mod check;
25pub mod errors;
26pub mod expr;
27pub mod rls;
28pub mod statement;
29
30pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
32 Ok(RelExpr::Select(
33 Box::new(input),
34 type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
35 ))
36}
37
38pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
40 Ok(
41 parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
42 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
43 .and_then(|n| {
44 n.into_u64()
45 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
46 })
47 .map(|n| ProjectList::Limit(Box::new(input), n))?,
48 )
49}
50
51pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
53 match proj {
54 ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
55 ast::Project::Star(None) => Ok(ProjectList::Name(vec![ProjectName::None(input)])),
56 ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
57 Ok(ProjectList::Name(vec![ProjectName::Some(input, var)]))
58 }
59 ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
60 ast::Project::Count(SqlIdent(alias)) => {
61 Ok(ProjectList::Agg(vec![input], AggType::Count, alias, AlgebraicType::U64))
62 }
63 ast::Project::Exprs(elems) => {
64 let mut projections = vec![];
65 let mut names = HashSet::new();
66
67 for ProjectElem(expr, SqlIdent(alias)) in elems {
68 if !names.insert(alias.clone()) {
69 return Err(DuplicateName(alias.into_string()).into());
70 }
71
72 if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
73 projections.push((alias, p));
74 }
75 }
76
77 Ok(ProjectList::List(vec![input], projections))
78 }
79 }
80}
81
// Compile-time size checks on the types that `_type_expr` (below) moves around on
// every recursive step: it takes a `SqlExpr` by value and returns a
// `TypingResult<Expr>`, once per nested sub-expression.
// NOTE(review): presumably these exact sizes are pinned so that growth of either type
// (and thus of the recursive stack frame) is noticed at compile time. Confirm the
// intent before adjusting the constants.
const _: () = assert!(size_of::<TypingResult<Expr>>() == 64);
const _: () = assert!(size_of::<SqlExpr>() == 40);
86
87fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, depth: usize) -> TypingResult<Expr> {
88 recursion::guard(depth, recursion::MAX_RECURSION_TYP_EXPR, "expr::type_expr")?;
89
90 match (expr, expected) {
91 (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
92 (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
93 (SqlExpr::Lit(SqlLiteral::Str(_) | SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => {
94 Err(Unresolved::Literal.into())
95 }
96 (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
97 parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
98 ty.clone(),
99 )),
100 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
101 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
102 let ColumnSchema { col_pos, col_type, .. } = table_type
103 .get_column_by_name(&field)
104 .ok_or_else(|| Unresolved::var(&field))?;
105 Ok(Expr::Field(FieldProject {
106 table,
107 field: col_pos.idx(),
108 ty: col_type.clone(),
109 }))
110 }
111 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
112 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
113 let ColumnSchema { col_pos, col_type, .. } = table_type
114 .as_ref()
115 .get_column_by_name(&field)
116 .ok_or_else(|| Unresolved::var(&field))?;
117 if col_type != ty {
118 return Err(UnexpectedType::new(col_type, ty).into());
119 }
120 Ok(Expr::Field(FieldProject {
121 table,
122 field: col_pos.idx(),
123 ty: col_type.clone(),
124 }))
125 }
126 (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
127 let a = _type_expr(vars, *a, Some(&AlgebraicType::Bool), depth + 1)?;
128 let b = _type_expr(vars, *b, Some(&AlgebraicType::Bool), depth + 1)?;
129 Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
130 }
131 (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => {
132 let b = _type_expr(vars, *b, None, depth + 1)?;
133 let a = _type_expr(vars, *a, Some(b.ty()), depth + 1)?;
134 if !op_supports_type(op, a.ty()) {
135 return Err(InvalidOp::new(op, a.ty()).into());
136 }
137 Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
138 }
139 (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => {
140 let a = _type_expr(vars, *a, None, depth + 1)?;
141 let b = _type_expr(vars, *b, Some(a.ty()), depth + 1)?;
142 if !op_supports_type(op, a.ty()) {
143 return Err(InvalidOp::new(op, a.ty()).into());
144 }
145 Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
146 }
147 (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
148 (SqlExpr::Var(_) | SqlExpr::Param(_), _) => unreachable!(),
151 }
152}
153
154pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
156 _type_expr(vars, expr, expected, 0)
157}
158
159fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
161 t.is_bool()
162 || t.is_integer()
163 || t.is_float()
164 || t.is_string()
165 || t.is_bytes()
166 || t.is_identity()
167 || t.is_connection_id()
168 || t.is_timestamp()
169}
170
171fn parse_int<Int, Val, ToInt, ToVal>(
173 literal: &str,
174 ty: AlgebraicType,
175 to_int: ToInt,
176 to_val: ToVal,
177) -> anyhow::Result<AlgebraicValue>
178where
179 Int: Into<Val>,
180 ToInt: FnOnce(&BigDecimal) -> Option<Int>,
181 ToVal: FnOnce(Val) -> AlgebraicValue,
182{
183 BigDecimal::from_str(literal)
186 .ok()
187 .filter(|decimal| decimal.is_integer())
188 .ok_or_else(|| anyhow!("{literal} is not an integer"))
189 .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
190 .transpose()
191 .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
192}
193
194fn parse_float<Float, Value, ToFloat, ToValue>(
196 literal: &str,
197 ty: AlgebraicType,
198 to_float: ToFloat,
199 to_value: ToValue,
200) -> anyhow::Result<AlgebraicValue>
201where
202 Float: Into<Value>,
203 ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
204 ToValue: FnOnce(Value) -> AlgebraicValue,
205{
206 BigDecimal::from_str(literal)
207 .ok()
208 .and_then(|decimal| to_float(&decimal))
209 .map(|value| value.into())
210 .map(to_value)
211 .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
212}
213
214pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
216 let to_timestamp = || {
217 Timestamp::parse_from_rfc3339(value)?
218 .serialize(ValueSerializer)
219 .with_context(|| "Could not parse timestamp")
220 };
221 let to_bytes = || {
222 from_hex_pad::<Vec<u8>, _>(value)
223 .map(|v| v.into_boxed_slice())
224 .map(AlgebraicValue::Bytes)
225 .with_context(|| "Could not parse hex value")
226 };
227 let to_identity = || {
228 Identity::from_hex(value)
229 .map(AlgebraicValue::from)
230 .with_context(|| "Could not parse identity")
231 };
232 let to_connection_id = || {
233 ConnectionId::from_hex(value)
234 .map(AlgebraicValue::from)
235 .with_context(|| "Could not parse connection id")
236 };
237 let to_i256 = |decimal: &BigDecimal| {
238 i256::from_str_radix(
239 &decimal.to_plain_string(),
241 10,
242 )
243 .ok()
244 };
245 let to_u256 = |decimal: &BigDecimal| {
246 u256::from_str_radix(
247 &decimal.to_plain_string(),
249 10,
250 )
251 .ok()
252 };
253 match ty {
254 AlgebraicType::I8 => parse_int(
255 value,
257 AlgebraicType::I8,
258 BigDecimal::to_i8,
259 AlgebraicValue::I8,
260 ),
261 AlgebraicType::U8 => parse_int(
262 value,
264 AlgebraicType::U8,
265 BigDecimal::to_u8,
266 AlgebraicValue::U8,
267 ),
268 AlgebraicType::I16 => parse_int(
269 value,
271 AlgebraicType::I16,
272 BigDecimal::to_i16,
273 AlgebraicValue::I16,
274 ),
275 AlgebraicType::U16 => parse_int(
276 value,
278 AlgebraicType::U16,
279 BigDecimal::to_u16,
280 AlgebraicValue::U16,
281 ),
282 AlgebraicType::I32 => parse_int(
283 value,
285 AlgebraicType::I32,
286 BigDecimal::to_i32,
287 AlgebraicValue::I32,
288 ),
289 AlgebraicType::U32 => parse_int(
290 value,
292 AlgebraicType::U32,
293 BigDecimal::to_u32,
294 AlgebraicValue::U32,
295 ),
296 AlgebraicType::I64 => parse_int(
297 value,
299 AlgebraicType::I64,
300 BigDecimal::to_i64,
301 AlgebraicValue::I64,
302 ),
303 AlgebraicType::U64 => parse_int(
304 value,
306 AlgebraicType::U64,
307 BigDecimal::to_u64,
308 AlgebraicValue::U64,
309 ),
310 AlgebraicType::F32 => parse_float(
311 value,
313 AlgebraicType::F32,
314 BigDecimal::to_f32,
315 AlgebraicValue::F32,
316 ),
317 AlgebraicType::F64 => parse_float(
318 value,
320 AlgebraicType::F64,
321 BigDecimal::to_f64,
322 AlgebraicValue::F64,
323 ),
324 AlgebraicType::I128 => parse_int(
325 value,
327 AlgebraicType::I128,
328 BigDecimal::to_i128,
329 AlgebraicValue::I128,
330 ),
331 AlgebraicType::U128 => parse_int(
332 value,
334 AlgebraicType::U128,
335 BigDecimal::to_u128,
336 AlgebraicValue::U128,
337 ),
338 AlgebraicType::I256 => parse_int(
339 value,
341 AlgebraicType::I256,
342 to_i256,
343 AlgebraicValue::I256,
344 ),
345 AlgebraicType::U256 => parse_int(
346 value,
348 AlgebraicType::U256,
349 to_u256,
350 AlgebraicValue::U256,
351 ),
352 AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
353 t if t.is_timestamp() => to_timestamp(),
354 t if t.is_bytes() => to_bytes(),
355 t if t.is_identity() => to_identity(),
356 t if t.is_connection_id() => to_connection_id(),
357 t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
358 }
359}
360
/// The kind of origin a statement has; stored in `StatementCtx::source`.
///
/// NOTE(review): the variant meanings below are inferred from their names only —
/// confirm against the callers that construct and branch on this value.
pub enum StatementSource {
    /// Presumably a statement registered as a subscription (TODO confirm).
    Subscription,
    /// Presumably a one-off query (TODO confirm).
    Query,
}
366
/// Bundles a [`Statement`] with a borrowed SQL string and its [`StatementSource`].
pub struct StatementCtx<'a> {
    /// The statement itself.
    pub statement: Statement,
    /// SQL text borrowed for `'a`.
    /// NOTE(review): presumably the text `statement` was parsed from — confirm.
    pub sql: &'a str,
    /// Which kind of source the statement came from.
    pub source: StatementSource,
}