Skip to main content

palimpsest_dataflow/palimpsest/
eval.rs

1// Copyright 2026 Thousand Birds Inc.
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
//! Runtime expression evaluator.
//!
//! Compiles SQL expression strings (boolean predicates, projection
//! expressions, aggregate input columns, order-by keys) into closures
//! over [`Row`] values. Mirrors GlueSQL's `Evaluator` shape: the
//! parser produces an `Expr` AST; this module walks the AST and
//! returns a `Box<dyn Fn(&Row) -> Datum>` (or a typed wrapper) that
//! reads concrete column values out of the row at evaluation time.
//!
//! Notably, **no `$user.*` resolution lives here**. The permission
//! rewriter materializes user-context references into literal values
//! before the predicate reaches the dataflow
//! (`palimpsest_permissions::compile::CompiledPredicate::materialize`),
//! so the evaluator only deals with column references, literals,
//! comparisons, and boolean logic.
19
20use std::collections::BTreeMap;
21use std::fmt;
22
23use palimpsest_sql::catalog::ColumnType;
24use palimpsest_wal::Datum;
25use sqlparser::ast::{BinaryOperator, Expr, UnaryOperator, Value as SqlValue};
26use sqlparser::dialect::PostgreSqlDialect;
27use sqlparser::parser::Parser;
28use thiserror::Error;
29
30use crate::palimpsest::wal::Row;
31
32/// Closure that reads a single `Datum` out of a row.
33pub type ScalarFn = Box<dyn Fn(&Row) -> Datum + Send + Sync>;
34
35/// Closure that evaluates a boolean predicate over a row.
36pub type PredicateFn = Box<dyn Fn(&Row) -> bool + Send + Sync>;
37
38/// Closure that extracts an `i64`-coerced column value (for aggregate
39/// inputs). Returns 0 for `NULL` / non-numeric — same handling as
40/// SQL's implicit-coalesce-to-zero in `SUM` / `AVG`.
41pub type IntExtractor = Box<dyn Fn(&Row) -> i64 + Send + Sync>;
42
43/// Per-column metadata used during compilation to resolve identifiers
44/// to row indices. Built by the MIR walker from each node's output
45/// schema.
46#[derive(Debug, Clone, Default)]
47pub struct ScalarSchema {
48    columns: Vec<(String, ColumnType)>,
49    index: BTreeMap<String, usize>,
50}
51
52impl ScalarSchema {
53    /// Build a schema from a sequence of `(name, type)` pairs in
54    /// column order. The last column with a given name wins on
55    /// collision (mirroring SQL's "last alias wins" projection rule).
56    #[must_use]
57    pub fn from_pairs(columns: impl IntoIterator<Item = (String, ColumnType)>) -> Self {
58        let columns: Vec<_> = columns.into_iter().collect();
59        let mut index = BTreeMap::new();
60        for (i, (name, _)) in columns.iter().enumerate() {
61            index.insert(name.clone(), i);
62        }
63        Self { columns, index }
64    }
65
66    /// Row index of the column named `name`, if any.
67    #[must_use]
68    pub fn index_of(&self, name: &str) -> Option<usize> {
69        self.index.get(name).copied()
70    }
71
72    /// Declared type of the column named `name`, if any.
73    #[must_use]
74    pub fn column_type(&self, name: &str) -> Option<ColumnType> {
75        self.index.get(name).map(|&i| self.columns[i].1)
76    }
77
78    /// Ordered `(name, type)` pairs.
79    #[must_use]
80    pub fn columns(&self) -> &[(String, ColumnType)] {
81        &self.columns
82    }
83
84    /// Number of columns.
85    #[must_use]
86    pub fn len(&self) -> usize {
87        self.columns.len()
88    }
89
90    /// True when no columns are declared.
91    #[must_use]
92    pub fn is_empty(&self) -> bool {
93        self.columns.is_empty()
94    }
95}
96
97/// Errors raised during compile-time analysis. Runtime evaluation
98/// itself is total: every closure returns *some* `Datum` — invalid
99/// arithmetic / type mismatches surface as `Datum::Null`, matching
100/// SQL's three-valued semantics on most paths.
101#[derive(Debug, Error)]
102pub enum EvalError {
103    /// The SQL parser refused the expression.
104    #[error("parse error: {0}")]
105    Parse(String),
106    /// The expression uses a feature this evaluator doesn't implement.
107    #[error("unsupported expression: {0}")]
108    Unsupported(String),
109    /// An identifier didn't resolve against the input schema.
110    #[error("unknown column: {0}")]
111    UnknownColumn(String),
112}
113
114/// Compile `expr_sql` into a boolean predicate. Non-bool / null
115/// results count as `false`, matching `WHERE` semantics.
116///
117/// # Errors
118/// Returns [`EvalError`] on parse failure, unknown columns, or
119/// unsupported operator kinds.
120pub fn compile_predicate(expr_sql: &str, schema: &ScalarSchema) -> Result<PredicateFn, EvalError> {
121    let scalar = compile_scalar(expr_sql, schema)?;
122    Ok(Box::new(move |row| {
123        matches!(scalar(row), Datum::Bool(true))
124    }))
125}
126
127/// Compile `expr_sql` into a scalar closure.
128///
129/// # Errors
130/// See [`compile_predicate`].
131pub fn compile_scalar(expr_sql: &str, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
132    let expr = parse_expr(expr_sql)?;
133    compile_inner(&expr, schema)
134}
135
136/// Convenience: compile a single column reference into an `i64`
137/// extractor. Used by aggregate input expressions like `SUM(value)`,
138/// where the argument is a simple identifier. Also accepts `*` as
139/// a sentinel for `COUNT(*)`, returning a constant `0` (the aggregate
140/// only inspects the diff multiplicity in that case).
141///
142/// # Errors
143/// Returns `EvalError::UnknownColumn` if the named column isn't in
144/// `schema`, or `EvalError::Unsupported` for non-identifier inputs.
145pub fn compile_int_extractor(
146    arg_sql: &str,
147    schema: &ScalarSchema,
148) -> Result<IntExtractor, EvalError> {
149    let trimmed = arg_sql.trim();
150    if trimmed == "*" {
151        return Ok(Box::new(|_| 0));
152    }
153    let scalar = compile_scalar(trimmed, schema)?;
154    Ok(Box::new(move |row| match scalar(row) {
155        Datum::I64(v) => v,
156        Datum::I32(v) => i64::from(v),
157        Datum::I16(v) => i64::from(v),
158        _ => 0,
159    }))
160}
161
162fn parse_expr(sql: &str) -> Result<Expr, EvalError> {
163    let dialect = PostgreSqlDialect {};
164    let mut parser = Parser::new(&dialect)
165        .try_with_sql(sql)
166        .map_err(|err| EvalError::Parse(err.to_string()))?;
167    parser
168        .parse_expr()
169        .map_err(|err| EvalError::Parse(err.to_string()))
170}
171
172fn compile_inner(expr: &Expr, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
173    match expr {
174        Expr::Nested(inner) => compile_inner(inner, schema),
175        Expr::Identifier(ident) => identifier_scalar(&ident.value, schema),
176        Expr::CompoundIdentifier(parts) => {
177            // Treat `table.column` as just `column` for our flat row
178            // model. The MIR's BaseTable.project already pinned
179            // column ordering, so qualification is informational.
180            let last = parts
181                .last()
182                .ok_or_else(|| EvalError::Unsupported("empty compound identifier".to_owned()))?;
183            identifier_scalar(&last.value, schema)
184        }
185        Expr::Value(value) => value_scalar(value),
186        Expr::BinaryOp { left, op, right } => binary_scalar(left, op.clone(), right, schema),
187        Expr::UnaryOp { op, expr: inner } => unary_scalar(op.clone(), inner, schema),
188        Expr::IsNull(inner) => {
189            let target = compile_inner(inner, schema)?;
190            Ok(Box::new(move |row| {
191                Datum::Bool(matches!(target(row), Datum::Null))
192            }))
193        }
194        Expr::IsNotNull(inner) => {
195            let target = compile_inner(inner, schema)?;
196            Ok(Box::new(move |row| {
197                Datum::Bool(!matches!(target(row), Datum::Null))
198            }))
199        }
200        Expr::IsTrue(inner) => {
201            let target = compile_inner(inner, schema)?;
202            Ok(Box::new(move |row| {
203                Datum::Bool(matches!(target(row), Datum::Bool(true)))
204            }))
205        }
206        Expr::IsFalse(inner) => {
207            let target = compile_inner(inner, schema)?;
208            Ok(Box::new(move |row| {
209                Datum::Bool(matches!(target(row), Datum::Bool(false)))
210            }))
211        }
212        other => Err(EvalError::Unsupported(format!("{other:?}"))),
213    }
214}
215
216fn identifier_scalar(name: &str, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
217    let idx = schema
218        .index_of(name)
219        .ok_or_else(|| EvalError::UnknownColumn(name.to_owned()))?;
220    Ok(Box::new(move |row| {
221        row.get(idx).cloned().unwrap_or(Datum::Null)
222    }))
223}
224
225fn value_scalar(value: &SqlValue) -> Result<ScalarFn, EvalError> {
226    match value {
227        SqlValue::Boolean(b) => {
228            let b = *b;
229            Ok(Box::new(move |_| Datum::Bool(b)))
230        }
231        SqlValue::Number(n, _) => {
232            if let Ok(v) = n.parse::<i64>() {
233                Ok(Box::new(move |_| Datum::I64(v)))
234            } else if let Ok(v) = n.parse::<f64>() {
235                let bits = v.to_bits();
236                Ok(Box::new(move |_| Datum::F64(bits)))
237            } else {
238                Err(EvalError::Parse(format!("number literal '{n}'")))
239            }
240        }
241        SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
242            let bytes: bytes::Bytes = s.clone().into_bytes().into();
243            Ok(Box::new(move |_| Datum::Text(bytes.clone())))
244        }
245        SqlValue::Null => Ok(Box::new(|_| Datum::Null)),
246        other => Err(EvalError::Unsupported(format!("literal {other:?}"))),
247    }
248}
249
250fn binary_scalar(
251    left: &Expr,
252    op: BinaryOperator,
253    right: &Expr,
254    schema: &ScalarSchema,
255) -> Result<ScalarFn, EvalError> {
256    let l = compile_inner(left, schema)?;
257    let r = compile_inner(right, schema)?;
258    match op {
259        BinaryOperator::Eq => Ok(Box::new(move |row| Datum::Bool(datum_eq(&l(row), &r(row))))),
260        BinaryOperator::NotEq => Ok(Box::new(move |row| {
261            Datum::Bool(!datum_eq(&l(row), &r(row)))
262        })),
263        BinaryOperator::Lt => Ok(Box::new(move |row| {
264            datum_cmp_bool(&l(row), &r(row), |o| o.is_lt())
265        })),
266        BinaryOperator::LtEq => Ok(Box::new(move |row| {
267            datum_cmp_bool(&l(row), &r(row), |o| o.is_le())
268        })),
269        BinaryOperator::Gt => Ok(Box::new(move |row| {
270            datum_cmp_bool(&l(row), &r(row), |o| o.is_gt())
271        })),
272        BinaryOperator::GtEq => Ok(Box::new(move |row| {
273            datum_cmp_bool(&l(row), &r(row), |o| o.is_ge())
274        })),
275        BinaryOperator::And => Ok(Box::new(move |row| {
276            let lv = matches!(l(row), Datum::Bool(true));
277            if !lv {
278                return Datum::Bool(false);
279            }
280            Datum::Bool(matches!(r(row), Datum::Bool(true)))
281        })),
282        BinaryOperator::Or => Ok(Box::new(move |row| {
283            let lv = matches!(l(row), Datum::Bool(true));
284            if lv {
285                return Datum::Bool(true);
286            }
287            Datum::Bool(matches!(r(row), Datum::Bool(true)))
288        })),
289        other => Err(EvalError::Unsupported(format!("binary op {other:?}"))),
290    }
291}
292
293fn unary_scalar(
294    op: UnaryOperator,
295    inner: &Expr,
296    schema: &ScalarSchema,
297) -> Result<ScalarFn, EvalError> {
298    let e = compile_inner(inner, schema)?;
299    match op {
300        UnaryOperator::Not => Ok(Box::new(move |row| match e(row) {
301            Datum::Bool(b) => Datum::Bool(!b),
302            _ => Datum::Bool(false),
303        })),
304        UnaryOperator::Minus => Ok(Box::new(move |row| match e(row) {
305            Datum::I64(v) => Datum::I64(-v),
306            Datum::I32(v) => Datum::I32(-v),
307            Datum::I16(v) => Datum::I16(-v),
308            // `Datum::F{32,64}` store the bit pattern of the float
309            // rather than the float itself, so negation has to round-
310            // trip through the IEEE representation.
311            Datum::F64(v) => Datum::F64((-f64::from_bits(v)).to_bits()),
312            Datum::F32(v) => Datum::F32((-f32::from_bits(v)).to_bits()),
313            other => other,
314        })),
315        UnaryOperator::Plus => Ok(e),
316        other => Err(EvalError::Unsupported(format!("unary op {other:?}"))),
317    }
318}
319
320/// SQL equality with three-valued logic: NULL on either side → false.
321fn datum_eq(a: &Datum, b: &Datum) -> bool {
322    use Datum::{Bool, Null, Text, F32, F64, I16, I32, I64};
323    match (a, b) {
324        (Null, _) | (_, Null) => false,
325        (Bool(x), Bool(y)) => x == y,
326        (I64(x), I64(y)) => x == y,
327        (I32(x), I32(y)) => x == y,
328        (I16(x), I16(y)) => x == y,
329        // `F32`/`F64` store IEEE bit patterns; comparing the integer
330        // backing types gives canonical-bit equality rather than the
331        // float equality SQL expects. Decode before comparing.
332        (F64(x), F64(y)) => f64::from_bits(*x) == f64::from_bits(*y),
333        (F32(x), F32(y)) => f32::from_bits(*x) == f32::from_bits(*y),
334        (I64(x), I32(y)) => *x == i64::from(*y),
335        (I32(x), I64(y)) => i64::from(*x) == *y,
336        (I64(x), I16(y)) => *x == i64::from(*y),
337        (I16(x), I64(y)) => i64::from(*x) == *y,
338        (I32(x), I16(y)) => *x == i32::from(*y),
339        (I16(x), I32(y)) => i32::from(*x) == *y,
340        (Text(x), Text(y)) => x == y,
341        _ => false,
342    }
343}
344
345/// SQL ordering with three-valued logic: NULL on either side → false.
346fn datum_cmp_bool<F>(a: &Datum, b: &Datum, pick: F) -> Datum
347where
348    F: Fn(std::cmp::Ordering) -> bool,
349{
350    use std::cmp::Ordering;
351    use Datum::{Null, Text, F64, I16, I32, I64};
352    let ord = match (a, b) {
353        (Null, _) | (_, Null) => return Datum::Bool(false),
354        (I64(x), I64(y)) => x.cmp(y),
355        (I32(x), I32(y)) => x.cmp(y),
356        (I16(x), I16(y)) => x.cmp(y),
357        (F64(x), F64(y)) => f64::from_bits(*x)
358            .partial_cmp(&f64::from_bits(*y))
359            .unwrap_or(Ordering::Equal),
360        (I64(x), I32(y)) => x.cmp(&i64::from(*y)),
361        (I32(x), I64(y)) => i64::from(*x).cmp(y),
362        (Text(x), Text(y)) => x.cmp(y),
363        _ => return Datum::Bool(false),
364    };
365    Datum::Bool(pick(ord))
366}
367
368impl fmt::Display for ScalarSchema {
369    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370        f.write_str("(")?;
371        for (i, (name, ty)) in self.columns.iter().enumerate() {
372            if i > 0 {
373                f.write_str(", ")?;
374            }
375            write!(f, "{name}: {ty:?}")?;
376        }
377        f.write_str(")")
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use smallvec::smallvec;

    /// `(id INT, title TEXT, published BOOL)`: a minimal `posts` table.
    fn posts_schema() -> ScalarSchema {
        ScalarSchema::from_pairs([
            ("id".to_owned(), ColumnType::Int),
            ("title".to_owned(), ColumnType::Text),
            ("published".to_owned(), ColumnType::Bool),
        ])
    }

    /// UTF-8 text datum.
    fn text(s: &str) -> Datum {
        Datum::Text(s.as_bytes().to_vec().into())
    }

    /// One `posts` row, laid out in `posts_schema` column order.
    fn post(id: i64, title: &str, published: bool) -> Row {
        smallvec![Datum::I64(id), text(title), Datum::Bool(published)]
    }

    #[test]
    fn column_ref_extracts_value() {
        let published = compile_scalar("published", &posts_schema()).unwrap();
        assert_eq!(published(&post(1, "hi", true)), Datum::Bool(true));
    }

    #[test]
    fn predicate_equality_against_literal() {
        let is_published = compile_predicate("published = true", &posts_schema()).unwrap();
        assert!(is_published(&post(1, "a", true)));
        assert!(!is_published(&post(2, "b", false)));
    }

    #[test]
    fn predicate_or_short_circuits() {
        let visible = compile_predicate("published = true OR id = 99", &posts_schema()).unwrap();
        assert!(visible(&post(99, "c", false)));
    }

    #[test]
    fn predicate_with_inlined_admin_literal() {
        // The permission rewriter replaces `$user.is_admin` with a literal
        // before the predicate reaches the dataflow, so operators see
        // exactly this shape.
        let visible =
            compile_predicate("published = true OR true = true", &posts_schema()).unwrap();
        assert!(visible(&post(1, "x", false)));
    }

    #[test]
    fn predicate_ordering() {
        let below_five = compile_predicate("id < 5", &posts_schema()).unwrap();
        assert!(below_five(&post(3, "", true)));
        assert!(!below_five(&post(7, "", true)));
    }

    #[test]
    fn unknown_column_rejected_at_compile_time() {
        // `PredicateFn` isn't `Debug`, so match on the whole result instead
        // of calling `unwrap_err`.
        let outcome = compile_predicate("ghost = 1", &posts_schema());
        assert!(
            matches!(outcome, Err(EvalError::UnknownColumn(_))),
            "expected compile failure on unknown column"
        );
    }

    #[test]
    fn int_extractor_handles_star() {
        let count_star = compile_int_extractor("*", &posts_schema()).unwrap();
        assert_eq!(count_star(&post(42, "", true)), 0);
    }

    #[test]
    fn int_extractor_reads_named_column() {
        let id = compile_int_extractor("id", &posts_schema()).unwrap();
        assert_eq!(id(&post(42, "", true)), 42);
    }
}