1use std::{collections::HashSet, ops::Deref, str::FromStr};
2
3use crate::statement::Statement;
4use anyhow::anyhow;
5use anyhow::bail;
6use anyhow::Context;
7use bigdecimal::BigDecimal;
8use bigdecimal::ToPrimitive;
9use check::{Relvars, TypingResult};
10use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, UnexpectedType, Unresolved};
11use ethnum::i256;
12use ethnum::u256;
13use expr::AggType;
14use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr};
15use spacetimedb_lib::ser::Serialize;
16use spacetimedb_lib::Timestamp;
17use spacetimedb_lib::{from_hex_pad, AlgebraicType, AlgebraicValue, ConnectionId, Identity};
18use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type;
19use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
20use spacetimedb_schema::schema::ColumnSchema;
21use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral};
22
23pub mod check;
24pub mod errors;
25pub mod expr;
26pub mod rls;
27pub mod statement;
28
29pub(crate) fn type_select(input: RelExpr, expr: SqlExpr, vars: &Relvars) -> TypingResult<RelExpr> {
31 Ok(RelExpr::Select(
32 Box::new(input),
33 type_expr(vars, expr, Some(&AlgebraicType::Bool))?,
34 ))
35}
36
37pub(crate) fn type_limit(input: ProjectList, limit: &str) -> TypingResult<ProjectList> {
39 Ok(
40 parse_int(limit, AlgebraicType::U64, BigDecimal::to_u64, AlgebraicValue::U64)
41 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
42 .and_then(|n| {
43 n.into_u64()
44 .map_err(|_| InvalidLiteral::new(limit.to_owned(), &AlgebraicType::U64))
45 })
46 .map(|n| ProjectList::Limit(Box::new(input), n))?,
47 )
48}
49
50pub(crate) fn type_proj(input: RelExpr, proj: ast::Project, vars: &Relvars) -> TypingResult<ProjectList> {
52 match proj {
53 ast::Project::Star(None) if input.nfields() > 1 => Err(InvalidWildcard::Join.into()),
54 ast::Project::Star(None) => Ok(ProjectList::Name(vec![ProjectName::None(input)])),
55 ast::Project::Star(Some(SqlIdent(var))) if input.has_field(&var) => {
56 Ok(ProjectList::Name(vec![ProjectName::Some(input, var)]))
57 }
58 ast::Project::Star(Some(SqlIdent(var))) => Err(Unresolved::var(&var).into()),
59 ast::Project::Count(SqlIdent(alias)) => {
60 Ok(ProjectList::Agg(vec![input], AggType::Count, alias, AlgebraicType::U64))
61 }
62 ast::Project::Exprs(elems) => {
63 let mut projections = vec![];
64 let mut names = HashSet::new();
65
66 for ProjectElem(expr, SqlIdent(alias)) in elems {
67 if !names.insert(alias.clone()) {
68 return Err(DuplicateName(alias.into_string()).into());
69 }
70
71 if let Expr::Field(p) = type_expr(vars, expr.into(), None)? {
72 projections.push((alias, p));
73 }
74 }
75
76 Ok(ProjectList::List(vec![input], projections))
77 }
78 }
79}
80
81pub(crate) fn type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>) -> TypingResult<Expr> {
83 match (expr, expected) {
84 (SqlExpr::Lit(SqlLiteral::Bool(v)), None | Some(AlgebraicType::Bool)) => Ok(Expr::bool(v)),
85 (SqlExpr::Lit(SqlLiteral::Bool(_)), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
86 (SqlExpr::Lit(SqlLiteral::Str(_) | SqlLiteral::Num(_) | SqlLiteral::Hex(_)), None) => {
87 Err(Unresolved::Literal.into())
88 }
89 (SqlExpr::Lit(SqlLiteral::Str(v) | SqlLiteral::Num(v) | SqlLiteral::Hex(v)), Some(ty)) => Ok(Expr::Value(
90 parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?,
91 ty.clone(),
92 )),
93 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), None) => {
94 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
95 let ColumnSchema { col_pos, col_type, .. } = table_type
96 .get_column_by_name(&field)
97 .ok_or_else(|| Unresolved::var(&field))?;
98 Ok(Expr::Field(FieldProject {
99 table,
100 field: col_pos.idx(),
101 ty: col_type.clone(),
102 }))
103 }
104 (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), Some(ty)) => {
105 let table_type = vars.deref().get(&table).ok_or_else(|| Unresolved::var(&table))?;
106 let ColumnSchema { col_pos, col_type, .. } = table_type
107 .as_ref()
108 .get_column_by_name(&field)
109 .ok_or_else(|| Unresolved::var(&field))?;
110 if col_type != ty {
111 return Err(UnexpectedType::new(col_type, ty).into());
112 }
113 Ok(Expr::Field(FieldProject {
114 table,
115 field: col_pos.idx(),
116 ty: col_type.clone(),
117 }))
118 }
119 (SqlExpr::Log(a, b, op), None | Some(AlgebraicType::Bool)) => {
120 let a = type_expr(vars, *a, Some(&AlgebraicType::Bool))?;
121 let b = type_expr(vars, *b, Some(&AlgebraicType::Bool))?;
122 Ok(Expr::LogOp(op, Box::new(a), Box::new(b)))
123 }
124 (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => {
125 let b = type_expr(vars, *b, None)?;
126 let a = type_expr(vars, *a, Some(b.ty()))?;
127 if !op_supports_type(op, a.ty()) {
128 return Err(InvalidOp::new(op, a.ty()).into());
129 }
130 Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
131 }
132 (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) => {
133 let a = type_expr(vars, *a, None)?;
134 let b = type_expr(vars, *b, Some(a.ty()))?;
135 if !op_supports_type(op, a.ty()) {
136 return Err(InvalidOp::new(op, a.ty()).into());
137 }
138 Ok(Expr::BinOp(op, Box::new(a), Box::new(b)))
139 }
140 (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()),
141 (SqlExpr::Var(_) | SqlExpr::Param(_), _) => unreachable!(),
144 }
145}
146
147fn op_supports_type(_op: BinOp, t: &AlgebraicType) -> bool {
149 t.is_bool()
150 || t.is_integer()
151 || t.is_float()
152 || t.is_string()
153 || t.is_bytes()
154 || t.is_identity()
155 || t.is_connection_id()
156 || t.is_timestamp()
157}
158
159fn parse_int<Int, Val, ToInt, ToVal>(
161 literal: &str,
162 ty: AlgebraicType,
163 to_int: ToInt,
164 to_val: ToVal,
165) -> anyhow::Result<AlgebraicValue>
166where
167 Int: Into<Val>,
168 ToInt: FnOnce(&BigDecimal) -> Option<Int>,
169 ToVal: FnOnce(Val) -> AlgebraicValue,
170{
171 BigDecimal::from_str(literal)
174 .ok()
175 .filter(|decimal| decimal.is_integer())
176 .ok_or_else(|| anyhow!("{literal} is not an integer"))
177 .map(|decimal| to_int(&decimal).map(|val| val.into()).map(to_val))
178 .transpose()
179 .ok_or_else(|| anyhow!("{literal} is out of bounds for type {}", fmt_algebraic_type(&ty)))?
180}
181
182fn parse_float<Float, Value, ToFloat, ToValue>(
184 literal: &str,
185 ty: AlgebraicType,
186 to_float: ToFloat,
187 to_value: ToValue,
188) -> anyhow::Result<AlgebraicValue>
189where
190 Float: Into<Value>,
191 ToFloat: FnOnce(&BigDecimal) -> Option<Float>,
192 ToValue: FnOnce(Value) -> AlgebraicValue,
193{
194 BigDecimal::from_str(literal)
195 .ok()
196 .and_then(|decimal| to_float(&decimal))
197 .map(|value| value.into())
198 .map(to_value)
199 .ok_or_else(|| anyhow!("{literal} is not a valid {}", fmt_algebraic_type(&ty)))
200}
201
202pub(crate) fn parse(value: &str, ty: &AlgebraicType) -> anyhow::Result<AlgebraicValue> {
204 let to_timestamp = || {
205 Timestamp::parse_from_rfc3339(value)?
206 .serialize(ValueSerializer)
207 .with_context(|| "Could not parse timestamp")
208 };
209 let to_bytes = || {
210 from_hex_pad::<Vec<u8>, _>(value)
211 .map(|v| v.into_boxed_slice())
212 .map(AlgebraicValue::Bytes)
213 .with_context(|| "Could not parse hex value")
214 };
215 let to_identity = || {
216 Identity::from_hex(value)
217 .map(AlgebraicValue::from)
218 .with_context(|| "Could not parse identity")
219 };
220 let to_connection_id = || {
221 ConnectionId::from_hex(value)
222 .map(AlgebraicValue::from)
223 .with_context(|| "Could not parse connection id")
224 };
225 let to_i256 = |decimal: &BigDecimal| {
226 i256::from_str_radix(
227 &decimal.to_plain_string(),
229 10,
230 )
231 .ok()
232 };
233 let to_u256 = |decimal: &BigDecimal| {
234 u256::from_str_radix(
235 &decimal.to_plain_string(),
237 10,
238 )
239 .ok()
240 };
241 match ty {
242 AlgebraicType::I8 => parse_int(
243 value,
245 AlgebraicType::I8,
246 BigDecimal::to_i8,
247 AlgebraicValue::I8,
248 ),
249 AlgebraicType::U8 => parse_int(
250 value,
252 AlgebraicType::U8,
253 BigDecimal::to_u8,
254 AlgebraicValue::U8,
255 ),
256 AlgebraicType::I16 => parse_int(
257 value,
259 AlgebraicType::I16,
260 BigDecimal::to_i16,
261 AlgebraicValue::I16,
262 ),
263 AlgebraicType::U16 => parse_int(
264 value,
266 AlgebraicType::U16,
267 BigDecimal::to_u16,
268 AlgebraicValue::U16,
269 ),
270 AlgebraicType::I32 => parse_int(
271 value,
273 AlgebraicType::I32,
274 BigDecimal::to_i32,
275 AlgebraicValue::I32,
276 ),
277 AlgebraicType::U32 => parse_int(
278 value,
280 AlgebraicType::U32,
281 BigDecimal::to_u32,
282 AlgebraicValue::U32,
283 ),
284 AlgebraicType::I64 => parse_int(
285 value,
287 AlgebraicType::I64,
288 BigDecimal::to_i64,
289 AlgebraicValue::I64,
290 ),
291 AlgebraicType::U64 => parse_int(
292 value,
294 AlgebraicType::U64,
295 BigDecimal::to_u64,
296 AlgebraicValue::U64,
297 ),
298 AlgebraicType::F32 => parse_float(
299 value,
301 AlgebraicType::F32,
302 BigDecimal::to_f32,
303 AlgebraicValue::F32,
304 ),
305 AlgebraicType::F64 => parse_float(
306 value,
308 AlgebraicType::F64,
309 BigDecimal::to_f64,
310 AlgebraicValue::F64,
311 ),
312 AlgebraicType::I128 => parse_int(
313 value,
315 AlgebraicType::I128,
316 BigDecimal::to_i128,
317 AlgebraicValue::I128,
318 ),
319 AlgebraicType::U128 => parse_int(
320 value,
322 AlgebraicType::U128,
323 BigDecimal::to_u128,
324 AlgebraicValue::U128,
325 ),
326 AlgebraicType::I256 => parse_int(
327 value,
329 AlgebraicType::I256,
330 to_i256,
331 AlgebraicValue::I256,
332 ),
333 AlgebraicType::U256 => parse_int(
334 value,
336 AlgebraicType::U256,
337 to_u256,
338 AlgebraicValue::U256,
339 ),
340 AlgebraicType::String => Ok(AlgebraicValue::String(value.into())),
341 t if t.is_timestamp() => to_timestamp(),
342 t if t.is_bytes() => to_bytes(),
343 t if t.is_identity() => to_identity(),
344 t if t.is_connection_id() => to_connection_id(),
345 t => bail!("Literal values for type {} are not supported", fmt_algebraic_type(t)),
346 }
347}
348
/// Identifies how a statement was submitted: as part of a subscription or as a
/// query.
pub enum StatementSource {
    /// The statement was submitted as a subscription.
    Subscription,
    /// The statement was submitted as a query.
    Query,
}
354
/// Bundles a [`Statement`] with its SQL text and the [`StatementSource`] it
/// was submitted through.
pub struct StatementCtx<'a> {
    /// The statement itself.
    pub statement: Statement,
    /// SQL text, borrowed for `'a` — presumably the text `statement` was
    /// produced from; confirm at construction sites.
    pub sql: &'a str,
    /// How the statement was submitted.
    pub source: StatementSource,
}