1use std::collections::BTreeMap;
21use std::fmt;
22
23use palimpsest_sql::catalog::ColumnType;
24use palimpsest_wal::Datum;
25use sqlparser::ast::{BinaryOperator, Expr, UnaryOperator, Value as SqlValue};
26use sqlparser::dialect::PostgreSqlDialect;
27use sqlparser::parser::Parser;
28use thiserror::Error;
29
30use crate::palimpsest::wal::Row;
31
/// Compiled scalar expression: a boxed, thread-safe closure that evaluates one
/// [`Row`] to a [`Datum`].
pub type ScalarFn = Box<dyn Fn(&Row) -> Datum + Send + Sync>;

/// Compiled predicate: a boxed, thread-safe closure that maps one [`Row`] to a
/// plain `bool`.
pub type PredicateFn = Box<dyn Fn(&Row) -> bool + Send + Sync>;

/// Compiled integer extractor: a boxed, thread-safe closure that maps one
/// [`Row`] to an `i64`.
pub type IntExtractor = Box<dyn Fn(&Row) -> i64 + Send + Sync>;
42
43#[derive(Debug, Clone, Default)]
47pub struct ScalarSchema {
48 columns: Vec<(String, ColumnType)>,
49 index: BTreeMap<String, usize>,
50}
51
52impl ScalarSchema {
53 #[must_use]
57 pub fn from_pairs(columns: impl IntoIterator<Item = (String, ColumnType)>) -> Self {
58 let columns: Vec<_> = columns.into_iter().collect();
59 let mut index = BTreeMap::new();
60 for (i, (name, _)) in columns.iter().enumerate() {
61 index.insert(name.clone(), i);
62 }
63 Self { columns, index }
64 }
65
66 #[must_use]
68 pub fn index_of(&self, name: &str) -> Option<usize> {
69 self.index.get(name).copied()
70 }
71
72 #[must_use]
74 pub fn column_type(&self, name: &str) -> Option<ColumnType> {
75 self.index.get(name).map(|&i| self.columns[i].1)
76 }
77
78 #[must_use]
80 pub fn columns(&self) -> &[(String, ColumnType)] {
81 &self.columns
82 }
83
84 #[must_use]
86 pub fn len(&self) -> usize {
87 self.columns.len()
88 }
89
90 #[must_use]
92 pub fn is_empty(&self) -> bool {
93 self.columns.is_empty()
94 }
95}
96
/// Errors reported while compiling a SQL expression into a closure.
#[derive(Debug, Error)]
pub enum EvalError {
    /// The SQL text could not be parsed, or a numeric literal could not be
    /// read as either an `i64` or an `f64`.
    #[error("parse error: {0}")]
    Parse(String),
    /// The expression, literal, or operator has no compiled form in this
    /// module; the payload is a debug rendering of the offending node.
    #[error("unsupported expression: {0}")]
    Unsupported(String),
    /// An identifier did not resolve to a column of the schema; the payload
    /// is the identifier that was looked up.
    #[error("unknown column: {0}")]
    UnknownColumn(String),
}
113
114pub fn compile_predicate(expr_sql: &str, schema: &ScalarSchema) -> Result<PredicateFn, EvalError> {
121 let scalar = compile_scalar(expr_sql, schema)?;
122 Ok(Box::new(move |row| {
123 matches!(scalar(row), Datum::Bool(true))
124 }))
125}
126
127pub fn compile_scalar(expr_sql: &str, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
132 let expr = parse_expr(expr_sql)?;
133 compile_inner(&expr, schema)
134}
135
136pub fn compile_int_extractor(
146 arg_sql: &str,
147 schema: &ScalarSchema,
148) -> Result<IntExtractor, EvalError> {
149 let trimmed = arg_sql.trim();
150 if trimmed == "*" {
151 return Ok(Box::new(|_| 0));
152 }
153 let scalar = compile_scalar(trimmed, schema)?;
154 Ok(Box::new(move |row| match scalar(row) {
155 Datum::I64(v) => v,
156 Datum::I32(v) => i64::from(v),
157 Datum::I16(v) => i64::from(v),
158 _ => 0,
159 }))
160}
161
162fn parse_expr(sql: &str) -> Result<Expr, EvalError> {
163 let dialect = PostgreSqlDialect {};
164 let mut parser = Parser::new(&dialect)
165 .try_with_sql(sql)
166 .map_err(|err| EvalError::Parse(err.to_string()))?;
167 parser
168 .parse_expr()
169 .map_err(|err| EvalError::Parse(err.to_string()))
170}
171
172fn compile_inner(expr: &Expr, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
173 match expr {
174 Expr::Nested(inner) => compile_inner(inner, schema),
175 Expr::Identifier(ident) => identifier_scalar(&ident.value, schema),
176 Expr::CompoundIdentifier(parts) => {
177 let last = parts
181 .last()
182 .ok_or_else(|| EvalError::Unsupported("empty compound identifier".to_owned()))?;
183 identifier_scalar(&last.value, schema)
184 }
185 Expr::Value(value) => value_scalar(value),
186 Expr::BinaryOp { left, op, right } => binary_scalar(left, op.clone(), right, schema),
187 Expr::UnaryOp { op, expr: inner } => unary_scalar(op.clone(), inner, schema),
188 Expr::IsNull(inner) => {
189 let target = compile_inner(inner, schema)?;
190 Ok(Box::new(move |row| {
191 Datum::Bool(matches!(target(row), Datum::Null))
192 }))
193 }
194 Expr::IsNotNull(inner) => {
195 let target = compile_inner(inner, schema)?;
196 Ok(Box::new(move |row| {
197 Datum::Bool(!matches!(target(row), Datum::Null))
198 }))
199 }
200 Expr::IsTrue(inner) => {
201 let target = compile_inner(inner, schema)?;
202 Ok(Box::new(move |row| {
203 Datum::Bool(matches!(target(row), Datum::Bool(true)))
204 }))
205 }
206 Expr::IsFalse(inner) => {
207 let target = compile_inner(inner, schema)?;
208 Ok(Box::new(move |row| {
209 Datum::Bool(matches!(target(row), Datum::Bool(false)))
210 }))
211 }
212 other => Err(EvalError::Unsupported(format!("{other:?}"))),
213 }
214}
215
216fn identifier_scalar(name: &str, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
217 let idx = schema
218 .index_of(name)
219 .ok_or_else(|| EvalError::UnknownColumn(name.to_owned()))?;
220 Ok(Box::new(move |row| {
221 row.get(idx).cloned().unwrap_or(Datum::Null)
222 }))
223}
224
225fn value_scalar(value: &SqlValue) -> Result<ScalarFn, EvalError> {
226 match value {
227 SqlValue::Boolean(b) => {
228 let b = *b;
229 Ok(Box::new(move |_| Datum::Bool(b)))
230 }
231 SqlValue::Number(n, _) => {
232 if let Ok(v) = n.parse::<i64>() {
233 Ok(Box::new(move |_| Datum::I64(v)))
234 } else if let Ok(v) = n.parse::<f64>() {
235 let bits = v.to_bits();
236 Ok(Box::new(move |_| Datum::F64(bits)))
237 } else {
238 Err(EvalError::Parse(format!("number literal '{n}'")))
239 }
240 }
241 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
242 let bytes: bytes::Bytes = s.clone().into_bytes().into();
243 Ok(Box::new(move |_| Datum::Text(bytes.clone())))
244 }
245 SqlValue::Null => Ok(Box::new(|_| Datum::Null)),
246 other => Err(EvalError::Unsupported(format!("literal {other:?}"))),
247 }
248}
249
250fn binary_scalar(
251 left: &Expr,
252 op: BinaryOperator,
253 right: &Expr,
254 schema: &ScalarSchema,
255) -> Result<ScalarFn, EvalError> {
256 let l = compile_inner(left, schema)?;
257 let r = compile_inner(right, schema)?;
258 match op {
259 BinaryOperator::Eq => Ok(Box::new(move |row| Datum::Bool(datum_eq(&l(row), &r(row))))),
260 BinaryOperator::NotEq => Ok(Box::new(move |row| {
261 Datum::Bool(!datum_eq(&l(row), &r(row)))
262 })),
263 BinaryOperator::Lt => Ok(Box::new(move |row| {
264 datum_cmp_bool(&l(row), &r(row), |o| o.is_lt())
265 })),
266 BinaryOperator::LtEq => Ok(Box::new(move |row| {
267 datum_cmp_bool(&l(row), &r(row), |o| o.is_le())
268 })),
269 BinaryOperator::Gt => Ok(Box::new(move |row| {
270 datum_cmp_bool(&l(row), &r(row), |o| o.is_gt())
271 })),
272 BinaryOperator::GtEq => Ok(Box::new(move |row| {
273 datum_cmp_bool(&l(row), &r(row), |o| o.is_ge())
274 })),
275 BinaryOperator::And => Ok(Box::new(move |row| {
276 let lv = matches!(l(row), Datum::Bool(true));
277 if !lv {
278 return Datum::Bool(false);
279 }
280 Datum::Bool(matches!(r(row), Datum::Bool(true)))
281 })),
282 BinaryOperator::Or => Ok(Box::new(move |row| {
283 let lv = matches!(l(row), Datum::Bool(true));
284 if lv {
285 return Datum::Bool(true);
286 }
287 Datum::Bool(matches!(r(row), Datum::Bool(true)))
288 })),
289 other => Err(EvalError::Unsupported(format!("binary op {other:?}"))),
290 }
291}
292
293fn unary_scalar(
294 op: UnaryOperator,
295 inner: &Expr,
296 schema: &ScalarSchema,
297) -> Result<ScalarFn, EvalError> {
298 let e = compile_inner(inner, schema)?;
299 match op {
300 UnaryOperator::Not => Ok(Box::new(move |row| match e(row) {
301 Datum::Bool(b) => Datum::Bool(!b),
302 _ => Datum::Bool(false),
303 })),
304 UnaryOperator::Minus => Ok(Box::new(move |row| match e(row) {
305 Datum::I64(v) => Datum::I64(-v),
306 Datum::I32(v) => Datum::I32(-v),
307 Datum::I16(v) => Datum::I16(-v),
308 Datum::F64(v) => Datum::F64((-f64::from_bits(v)).to_bits()),
312 Datum::F32(v) => Datum::F32((-f32::from_bits(v)).to_bits()),
313 other => other,
314 })),
315 UnaryOperator::Plus => Ok(e),
316 other => Err(EvalError::Unsupported(format!("unary op {other:?}"))),
317 }
318}
319
320fn datum_eq(a: &Datum, b: &Datum) -> bool {
322 use Datum::{Bool, Null, Text, F32, F64, I16, I32, I64};
323 match (a, b) {
324 (Null, _) | (_, Null) => false,
325 (Bool(x), Bool(y)) => x == y,
326 (I64(x), I64(y)) => x == y,
327 (I32(x), I32(y)) => x == y,
328 (I16(x), I16(y)) => x == y,
329 (F64(x), F64(y)) => f64::from_bits(*x) == f64::from_bits(*y),
333 (F32(x), F32(y)) => f32::from_bits(*x) == f32::from_bits(*y),
334 (I64(x), I32(y)) => *x == i64::from(*y),
335 (I32(x), I64(y)) => i64::from(*x) == *y,
336 (I64(x), I16(y)) => *x == i64::from(*y),
337 (I16(x), I64(y)) => i64::from(*x) == *y,
338 (I32(x), I16(y)) => *x == i32::from(*y),
339 (I16(x), I32(y)) => i32::from(*x) == *y,
340 (Text(x), Text(y)) => x == y,
341 _ => false,
342 }
343}
344
345fn datum_cmp_bool<F>(a: &Datum, b: &Datum, pick: F) -> Datum
347where
348 F: Fn(std::cmp::Ordering) -> bool,
349{
350 use std::cmp::Ordering;
351 use Datum::{Null, Text, F64, I16, I32, I64};
352 let ord = match (a, b) {
353 (Null, _) | (_, Null) => return Datum::Bool(false),
354 (I64(x), I64(y)) => x.cmp(y),
355 (I32(x), I32(y)) => x.cmp(y),
356 (I16(x), I16(y)) => x.cmp(y),
357 (F64(x), F64(y)) => f64::from_bits(*x)
358 .partial_cmp(&f64::from_bits(*y))
359 .unwrap_or(Ordering::Equal),
360 (I64(x), I32(y)) => x.cmp(&i64::from(*y)),
361 (I32(x), I64(y)) => i64::from(*x).cmp(y),
362 (Text(x), Text(y)) => x.cmp(y),
363 _ => return Datum::Bool(false),
364 };
365 Datum::Bool(pick(ord))
366}
367
368impl fmt::Display for ScalarSchema {
369 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370 f.write_str("(")?;
371 for (i, (name, ty)) in self.columns.iter().enumerate() {
372 if i > 0 {
373 f.write_str(", ")?;
374 }
375 write!(f, "{name}: {ty:?}")?;
376 }
377 f.write_str(")")
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use smallvec::smallvec;

    /// Fixture schema with the columns `id`, `title` and `published`.
    fn posts_schema() -> ScalarSchema {
        ScalarSchema::from_pairs(vec![
            (String::from("id"), ColumnType::Int),
            (String::from("title"), ColumnType::Text),
            (String::from("published"), ColumnType::Bool),
        ])
    }

    /// Wraps a string slice in a text datum.
    fn text(s: &str) -> Datum {
        Datum::Text(bytes::Bytes::from(s.as_bytes().to_vec()))
    }

    /// Builds a row matching `posts_schema`.
    fn post(id: i64, title: &str, published: bool) -> Row {
        smallvec![Datum::I64(id), text(title), Datum::Bool(published)]
    }

    #[test]
    fn column_ref_extracts_value() {
        let eval = compile_scalar("published", &posts_schema()).unwrap();
        assert_eq!(eval(&post(1, "hi", true)), Datum::Bool(true));
    }

    #[test]
    fn predicate_equality_against_literal() {
        let holds = compile_predicate("published = true", &posts_schema()).unwrap();
        assert!(holds(&post(1, "a", true)));
        assert!(!holds(&post(2, "b", false)));
    }

    #[test]
    fn predicate_or_short_circuits() {
        let holds = compile_predicate("published = true OR id = 99", &posts_schema()).unwrap();
        // Left side is false for a draft, so the right side decides.
        assert!(holds(&post(99, "c", false)));
    }

    #[test]
    fn predicate_with_inlined_admin_literal() {
        // The right-hand comparison is between two literals and is always true.
        let holds = compile_predicate("published = true OR true = true", &posts_schema()).unwrap();
        assert!(holds(&post(1, "x", false)));
    }

    #[test]
    fn predicate_ordering() {
        let holds = compile_predicate("id < 5", &posts_schema()).unwrap();
        assert!(holds(&post(3, "", true)));
        assert!(!holds(&post(7, "", true)));
    }

    #[test]
    fn unknown_column_rejected_at_compile_time() {
        let err = compile_predicate("ghost = 1", &posts_schema())
            .err()
            .expect("expected compile failure on unknown column");
        assert!(matches!(err, EvalError::UnknownColumn(_)));
    }

    #[test]
    fn int_extractor_handles_star() {
        let extract = compile_int_extractor("*", &posts_schema()).unwrap();
        assert_eq!(extract(&post(42, "", true)), 0);
    }

    #[test]
    fn int_extractor_reads_named_column() {
        let extract = compile_int_extractor("id", &posts_schema()).unwrap();
        assert_eq!(extract(&post(42, "", true)), 42);
    }
}