use std::collections::BTreeMap;
use std::fmt;
use palimpsest_sql::catalog::ColumnType;
use palimpsest_wal::Datum;
use sqlparser::ast::{BinaryOperator, Expr, UnaryOperator, Value as SqlValue};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use thiserror::Error;
use crate::palimpsest::wal::Row;
/// Compiled scalar expression: evaluates one [`Row`] to a [`Datum`].
///
/// Boxed and `Send + Sync` so a compiled expression can be stored and shared.
pub type ScalarFn = Box<dyn Fn(&Row) -> Datum + Send + Sync>;
/// Compiled predicate: evaluates one [`Row`] to a plain `bool`.
pub type PredicateFn = Box<dyn Fn(&Row) -> bool + Send + Sync>;
/// Compiled integer extractor: evaluates one [`Row`] to an `i64`.
pub type IntExtractor = Box<dyn Fn(&Row) -> i64 + Send + Sync>;
/// Ordered list of named, typed columns plus a name-to-position lookup.
///
/// A column's position is the offset later used to read its value out of a
/// row (`identifier_scalar` calls `row.get(idx)` with it).
#[derive(Debug, Clone, Default)]
pub struct ScalarSchema {
/// Columns in the order they were supplied to `from_pairs`.
columns: Vec<(String, ColumnType)>,
/// Column name mapped to its offset in `columns`.
index: BTreeMap<String, usize>,
}
impl ScalarSchema {
    /// Builds a schema from `(name, type)` pairs, preserving iteration order.
    ///
    /// If a name appears more than once, name lookups resolve to its last
    /// occurrence; every pair is still kept in [`ScalarSchema::columns`].
    #[must_use]
    pub fn from_pairs(columns: impl IntoIterator<Item = (String, ColumnType)>) -> Self {
        let columns = columns.into_iter().collect::<Vec<_>>();
        let mut index = BTreeMap::new();
        // `extend` inserts pair by pair, so a repeated name overwrites the
        // earlier position exactly like successive `insert` calls would.
        index.extend(
            columns
                .iter()
                .enumerate()
                .map(|(position, (name, _))| (name.clone(), position)),
        );
        Self { columns, index }
    }

    /// Returns the position of the column called `name`, if there is one.
    #[must_use]
    pub fn index_of(&self, name: &str) -> Option<usize> {
        let position = self.index.get(name)?;
        Some(*position)
    }

    /// Returns the declared type of the column called `name`, if there is one.
    #[must_use]
    pub fn column_type(&self, name: &str) -> Option<ColumnType> {
        let position = self.index_of(name)?;
        Some(self.columns[position].1)
    }

    /// Returns all `(name, type)` pairs in declaration order.
    #[must_use]
    pub fn columns(&self) -> &[(String, ColumnType)] {
        self.columns.as_slice()
    }

    /// Returns the number of columns.
    #[must_use]
    pub fn len(&self) -> usize {
        self.columns().len()
    }

    /// Returns `true` when the schema has no columns.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.columns.is_empty()
    }
}
/// Errors reported while compiling a SQL expression string into a closure.
#[derive(Debug, Error)]
pub enum EvalError {
/// The SQL text could not be parsed, or a numeric literal fit neither
/// `i64` nor `f64`. Carries the underlying message.
#[error("parse error: {0}")]
Parse(String),
/// The expression, literal, or operator is valid SQL but is not handled
/// by this compiler. Carries a description of the offending node.
#[error("unsupported expression: {0}")]
Unsupported(String),
/// The expression references a column name that the schema does not
/// contain. Carries that name.
#[error("unknown column: {0}")]
UnknownColumn(String),
}
/// Compiles `expr_sql` into a row predicate.
///
/// The expression is compiled as a scalar; the predicate holds only when that
/// scalar evaluates to `Datum::Bool(true)`. Every other result, including
/// NULL and non-boolean values, counts as `false`.
///
/// # Errors
///
/// Returns whatever [`compile_scalar`] reports for `expr_sql`.
pub fn compile_predicate(expr_sql: &str, schema: &ScalarSchema) -> Result<PredicateFn, EvalError> {
    let scalar = compile_scalar(expr_sql, schema)?;
    let predicate: PredicateFn = Box::new(move |row| match scalar(row) {
        Datum::Bool(flag) => flag,
        _ => false,
    });
    Ok(predicate)
}
/// Parses `expr_sql` as a single SQL expression and compiles it against
/// `schema` into a closure that evaluates it for one row.
///
/// # Errors
///
/// Returns the parse error from `parse_expr`, or the compile error from
/// `compile_inner` (unsupported construct or unknown column).
pub fn compile_scalar(expr_sql: &str, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
    parse_expr(expr_sql).and_then(|expr| compile_inner(&expr, schema))
}
/// Compiles an argument string into a closure that yields an `i64` per row.
///
/// After trimming, a bare `*` compiles to an extractor that always returns
/// `0` and never looks at the row. Anything else is compiled as a scalar
/// expression; `I64`, `I32` and `I16` results are widened to `i64`, and every
/// other result (NULL included) becomes `0`.
///
/// # Errors
///
/// Returns whatever [`compile_scalar`] reports for a non-`*` argument.
pub fn compile_int_extractor(
    arg_sql: &str,
    schema: &ScalarSchema,
) -> Result<IntExtractor, EvalError> {
    let arg = arg_sql.trim();
    let extractor: IntExtractor = if arg == "*" {
        Box::new(|_| 0)
    } else {
        let scalar = compile_scalar(arg, schema)?;
        Box::new(move |row| match scalar(row) {
            Datum::I64(wide) => wide,
            Datum::I32(narrow) => i64::from(narrow),
            Datum::I16(narrow) => i64::from(narrow),
            _ => 0,
        })
    };
    Ok(extractor)
}
fn parse_expr(sql: &str) -> Result<Expr, EvalError> {
let dialect = PostgreSqlDialect {};
let mut parser = Parser::new(&dialect)
.try_with_sql(sql)
.map_err(|err| EvalError::Parse(err.to_string()))?;
parser
.parse_expr()
.map_err(|err| EvalError::Parse(err.to_string()))
}
fn compile_inner(expr: &Expr, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
match expr {
Expr::Nested(inner) => compile_inner(inner, schema),
Expr::Identifier(ident) => identifier_scalar(&ident.value, schema),
Expr::CompoundIdentifier(parts) => {
let last = parts
.last()
.ok_or_else(|| EvalError::Unsupported("empty compound identifier".to_owned()))?;
identifier_scalar(&last.value, schema)
}
Expr::Value(value) => value_scalar(value),
Expr::BinaryOp { left, op, right } => binary_scalar(left, op.clone(), right, schema),
Expr::UnaryOp { op, expr: inner } => unary_scalar(op.clone(), inner, schema),
Expr::IsNull(inner) => {
let target = compile_inner(inner, schema)?;
Ok(Box::new(move |row| {
Datum::Bool(matches!(target(row), Datum::Null))
}))
}
Expr::IsNotNull(inner) => {
let target = compile_inner(inner, schema)?;
Ok(Box::new(move |row| {
Datum::Bool(!matches!(target(row), Datum::Null))
}))
}
Expr::IsTrue(inner) => {
let target = compile_inner(inner, schema)?;
Ok(Box::new(move |row| {
Datum::Bool(matches!(target(row), Datum::Bool(true)))
}))
}
Expr::IsFalse(inner) => {
let target = compile_inner(inner, schema)?;
Ok(Box::new(move |row| {
Datum::Bool(matches!(target(row), Datum::Bool(false)))
}))
}
other => Err(EvalError::Unsupported(format!("{other:?}"))),
}
}
/// Compiles a column reference into a closure that reads that column.
///
/// The column position is resolved once, at compile time. At evaluation time
/// the closure clones the value at that position, or yields `Datum::Null`
/// when the row is too short to contain it.
///
/// # Errors
///
/// Returns [`EvalError::UnknownColumn`] when `schema` has no column `name`.
fn identifier_scalar(name: &str, schema: &ScalarSchema) -> Result<ScalarFn, EvalError> {
    let Some(position) = schema.index_of(name) else {
        return Err(EvalError::UnknownColumn(name.to_owned()));
    };
    let reader: ScalarFn = Box::new(move |row| match row.get(position) {
        Some(datum) => datum.clone(),
        None => Datum::Null,
    });
    Ok(reader)
}
/// Compiles a SQL literal into a closure that returns the same datum for
/// every row.
///
/// Booleans become `Datum::Bool`. Numbers are tried as `i64` first and fall
/// back to `f64` (stored as its bit pattern). Single- and double-quoted
/// strings become `Datum::Text`, and `NULL` becomes `Datum::Null`.
///
/// # Errors
///
/// Returns [`EvalError::Parse`] for a number that is neither a valid `i64`
/// nor a valid `f64`, and [`EvalError::Unsupported`] for any other literal.
fn value_scalar(value: &SqlValue) -> Result<ScalarFn, EvalError> {
    let literal: ScalarFn = match value {
        SqlValue::Boolean(flag) => {
            let flag = *flag;
            Box::new(move |_| Datum::Bool(flag))
        }
        SqlValue::Number(digits, _) => {
            if let Ok(int) = digits.parse::<i64>() {
                Box::new(move |_| Datum::I64(int))
            } else if let Ok(float) = digits.parse::<f64>() {
                let bits = float.to_bits();
                Box::new(move |_| Datum::F64(bits))
            } else {
                return Err(EvalError::Parse(format!("number literal '{digits}'")));
            }
        }
        SqlValue::SingleQuotedString(text) | SqlValue::DoubleQuotedString(text) => {
            // Convert once at compile time; the closure only clones the
            // resulting `Bytes` handle per row.
            let payload = bytes::Bytes::from(text.clone().into_bytes());
            Box::new(move |_| Datum::Text(payload.clone()))
        }
        SqlValue::Null => Box::new(|_| Datum::Null),
        other => return Err(EvalError::Unsupported(format!("literal {other:?}"))),
    };
    Ok(literal)
}
/// Compiles a binary operator applied to two sub-expressions.
///
/// Both operands are compiled first, so an error in either one is reported
/// before the operator itself is examined.
///
/// Supported operators and their evaluation rules:
///
/// * `=`, `<>`, `<`, `<=`, `>`, `>=` yield `Datum::Bool`. If either operand
///   is NULL the result is `false`. This includes `<>`: a comparison against
///   NULL is unknown in SQL, so `NULL <> x` must not come out as true merely
///   because `NULL = x` is not true.
/// * `<>` on two non-NULL values is the negation of `datum_eq`, so values of
///   types that `datum_eq` does not relate compare as "not equal".
/// * `AND` / `OR` treat any operand other than `Datum::Bool(true)` as false
///   and evaluate the right operand only when the left one does not already
///   decide the result.
///
/// # Errors
///
/// Returns [`EvalError::Unsupported`] for any other operator, and propagates
/// errors from compiling either operand.
fn binary_scalar(
    left: &Expr,
    op: BinaryOperator,
    right: &Expr,
    schema: &ScalarSchema,
) -> Result<ScalarFn, EvalError> {
    let l = compile_inner(left, schema)?;
    let r = compile_inner(right, schema)?;
    match op {
        BinaryOperator::Eq => Ok(Box::new(move |row| {
            Datum::Bool(datum_eq(&l(row), &r(row)))
        })),
        BinaryOperator::NotEq => Ok(Box::new(move |row| {
            let lhs = l(row);
            let rhs = r(row);
            // `datum_eq` reports `false` for NULL operands; negating that
            // blindly would turn "unknown" into "true".
            let both_known = !matches!(lhs, Datum::Null) && !matches!(rhs, Datum::Null);
            Datum::Bool(both_known && !datum_eq(&lhs, &rhs))
        })),
        BinaryOperator::Lt => Ok(Box::new(move |row| {
            datum_cmp_bool(&l(row), &r(row), |o| o.is_lt())
        })),
        BinaryOperator::LtEq => Ok(Box::new(move |row| {
            datum_cmp_bool(&l(row), &r(row), |o| o.is_le())
        })),
        BinaryOperator::Gt => Ok(Box::new(move |row| {
            datum_cmp_bool(&l(row), &r(row), |o| o.is_gt())
        })),
        BinaryOperator::GtEq => Ok(Box::new(move |row| {
            datum_cmp_bool(&l(row), &r(row), |o| o.is_ge())
        })),
        BinaryOperator::And => Ok(Box::new(move |row| {
            // Short-circuit: the right operand is not evaluated when the left
            // one is anything other than TRUE.
            if !matches!(l(row), Datum::Bool(true)) {
                return Datum::Bool(false);
            }
            Datum::Bool(matches!(r(row), Datum::Bool(true)))
        })),
        BinaryOperator::Or => Ok(Box::new(move |row| {
            // Short-circuit: the right operand is not evaluated when the left
            // one is already TRUE.
            if matches!(l(row), Datum::Bool(true)) {
                return Datum::Bool(true);
            }
            Datum::Bool(matches!(r(row), Datum::Bool(true)))
        })),
        other => Err(EvalError::Unsupported(format!("binary op {other:?}"))),
    }
}
/// Compiles a unary operator applied to a sub-expression.
///
/// * `NOT` inverts a `Datum::Bool`; any other operand (NULL included) yields
///   `Datum::Bool(false)`.
/// * Unary `-` negates integer and float datums and returns every other
///   datum unchanged. Negating the minimum value of an integer type has no
///   representable result; it yields `Datum::Null` rather than panicking
///   (overflow-checked builds) or silently wrapping back to the same negative
///   value (unchecked builds), because the compiled closure has no error
///   channel.
/// * Unary `+` returns the operand's closure as-is.
///
/// # Errors
///
/// Returns [`EvalError::Unsupported`] for any other operator, and propagates
/// errors from compiling the operand.
fn unary_scalar(
    op: UnaryOperator,
    inner: &Expr,
    schema: &ScalarSchema,
) -> Result<ScalarFn, EvalError> {
    let e = compile_inner(inner, schema)?;
    match op {
        UnaryOperator::Not => Ok(Box::new(move |row| match e(row) {
            Datum::Bool(b) => Datum::Bool(!b),
            _ => Datum::Bool(false),
        })),
        UnaryOperator::Minus => Ok(Box::new(move |row| match e(row) {
            Datum::I64(v) => v.checked_neg().map_or(Datum::Null, Datum::I64),
            Datum::I32(v) => v.checked_neg().map_or(Datum::Null, Datum::I32),
            Datum::I16(v) => v.checked_neg().map_or(Datum::Null, Datum::I16),
            // Floats are stored as bit patterns: decode, negate, re-encode.
            Datum::F64(v) => Datum::F64((-f64::from_bits(v)).to_bits()),
            Datum::F32(v) => Datum::F32((-f32::from_bits(v)).to_bits()),
            other => other,
        })),
        UnaryOperator::Plus => Ok(e),
        other => Err(EvalError::Unsupported(format!("unary op {other:?}"))),
    }
}
/// Tests two datums for SQL-style equality, collapsed to a plain `bool`.
///
/// NULL never equals anything, NULL included. Same-typed booleans, integers
/// and text compare by value; floats are decoded from their bit patterns and
/// compared as floats. Integers of different widths are compared after
/// widening the narrower one. Every other pairing is unequal.
fn datum_eq(a: &Datum, b: &Datum) -> bool {
    use Datum::{Bool, Null, Text, F32, F64, I16, I32, I64};
    match (a, b) {
        (Null, _) | (_, Null) => false,
        (Bool(lhs), Bool(rhs)) => lhs == rhs,
        (Text(lhs), Text(rhs)) => lhs == rhs,
        (I64(lhs), I64(rhs)) => lhs == rhs,
        (I32(lhs), I32(rhs)) => lhs == rhs,
        (I16(lhs), I16(rhs)) => lhs == rhs,
        (F64(lhs), F64(rhs)) => f64::from_bits(*lhs) == f64::from_bits(*rhs),
        (F32(lhs), F32(rhs)) => f32::from_bits(*lhs) == f32::from_bits(*rhs),
        // Equality is symmetric, so each mixed-width pair is handled once
        // for both operand orders.
        (I64(wide), I32(narrow)) | (I32(narrow), I64(wide)) => *wide == i64::from(*narrow),
        (I64(wide), I16(narrow)) | (I16(narrow), I64(wide)) => *wide == i64::from(*narrow),
        (I32(wide), I16(narrow)) | (I16(narrow), I32(wide)) => *wide == i32::from(*narrow),
        _ => false,
    }
}
/// Orders two datums and maps the ordering through `pick` into a
/// `Datum::Bool`.
///
/// The result is `Datum::Bool(false)` without calling `pick` when the pair
/// has no ordering: either operand is NULL, either float is NaN, or the two
/// types are not comparable with each other.
///
/// Comparable pairs are same-typed integers, floats and text, plus every
/// mixed-width integer pairing that `datum_eq` accepts (the narrower value is
/// widened first). That keeps `=` and the ordering operators in agreement
/// about which operand types can be compared.
fn datum_cmp_bool<F>(a: &Datum, b: &Datum, pick: F) -> Datum
where
    F: Fn(std::cmp::Ordering) -> bool,
{
    use Datum::{Null, Text, F32, F64, I16, I32, I64};
    let ord = match (a, b) {
        (Null, _) | (_, Null) => None,
        (I64(x), I64(y)) => Some(x.cmp(y)),
        (I32(x), I32(y)) => Some(x.cmp(y)),
        (I16(x), I16(y)) => Some(x.cmp(y)),
        // `partial_cmp` is `None` when either side is NaN; such a comparison
        // is false for every operator, matching `datum_eq` and IEEE 754.
        (F64(x), F64(y)) => f64::from_bits(*x).partial_cmp(&f64::from_bits(*y)),
        (F32(x), F32(y)) => f32::from_bits(*x).partial_cmp(&f32::from_bits(*y)),
        (I64(x), I32(y)) => Some(x.cmp(&i64::from(*y))),
        (I32(x), I64(y)) => Some(i64::from(*x).cmp(y)),
        (I64(x), I16(y)) => Some(x.cmp(&i64::from(*y))),
        (I16(x), I64(y)) => Some(i64::from(*x).cmp(y)),
        (I32(x), I16(y)) => Some(x.cmp(&i32::from(*y))),
        (I16(x), I32(y)) => Some(i32::from(*x).cmp(y)),
        (Text(x), Text(y)) => Some(x.cmp(y)),
        _ => None,
    };
    Datum::Bool(ord.map_or(false, pick))
}
/// Renders the schema as `(name: Type, name: Type, ...)`, using each column
/// type's `Debug` form; an empty schema renders as `()`.
impl fmt::Display for ScalarSchema {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("(")?;
        // Empty before the first column, ", " before every later one.
        let mut separator = "";
        for (name, ty) in &self.columns {
            write!(f, "{separator}{name}: {ty:?}")?;
            separator = ", ";
        }
        f.write_str(")")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use smallvec::smallvec;

    /// Fixture schema with columns `id`, `title` and `published`.
    fn posts_schema() -> ScalarSchema {
        ScalarSchema::from_pairs([
            ("id".to_owned(), ColumnType::Int),
            ("title".to_owned(), ColumnType::Text),
            ("published".to_owned(), ColumnType::Bool),
        ])
    }

    /// Builds a text datum from a string slice.
    fn text(s: &str) -> Datum {
        Datum::Text(s.to_owned().into_bytes().into())
    }

    /// Builds a row laid out like `posts_schema`.
    fn post(id: i64, title: &str, published: bool) -> Row {
        smallvec![Datum::I64(id), text(title), Datum::Bool(published)]
    }

    #[test]
    fn column_ref_extracts_value() {
        let scalar = compile_scalar("published", &posts_schema()).unwrap();
        assert_eq!(scalar(&post(1, "hi", true)), Datum::Bool(true));
    }

    #[test]
    fn predicate_equality_against_literal() {
        let is_published = compile_predicate("published = true", &posts_schema()).unwrap();
        assert!(is_published(&post(1, "a", true)));
        assert!(!is_published(&post(2, "b", false)));
    }

    #[test]
    fn predicate_or_short_circuits() {
        let visible = compile_predicate("published = true OR id = 99", &posts_schema()).unwrap();
        assert!(visible(&post(99, "c", false)));
    }

    #[test]
    fn predicate_with_inlined_admin_literal() {
        let visible =
            compile_predicate("published = true OR true = true", &posts_schema()).unwrap();
        assert!(visible(&post(1, "x", false)));
    }

    #[test]
    fn predicate_ordering() {
        let below_five = compile_predicate("id < 5", &posts_schema()).unwrap();
        assert!(below_five(&post(3, "", true)));
        assert!(!below_five(&post(7, "", true)));
    }

    #[test]
    fn unknown_column_rejected_at_compile_time() {
        // `PredicateFn` is not `Debug`, so the outcome is inspected with
        // pattern matching instead of `unwrap_err`.
        let outcome = compile_predicate("ghost = 1", &posts_schema());
        assert!(
            outcome.is_err(),
            "expected compile failure on unknown column"
        );
        assert!(matches!(outcome, Err(EvalError::UnknownColumn(_))));
    }

    #[test]
    fn int_extractor_handles_star() {
        let extract = compile_int_extractor("*", &posts_schema()).unwrap();
        assert_eq!(extract(&post(42, "", true)), 0);
    }

    #[test]
    fn int_extractor_reads_named_column() {
        let extract = compile_int_extractor("id", &posts_schema()).unwrap();
        assert_eq!(extract(&post(42, "", true)), 42);
    }
}