mod normalize;
#[cfg(test)]
mod tests;
mod validate;
use crate::{
db::{
predicate::{
CoercionId, CompareFieldsPredicate, CompareOp, ComparePredicate, MembershipCompareLeaf,
Predicate, collapse_membership_compare_leaves, normalize as normalize_predicate,
},
query::plan::expr::{
Expr, derive_normalized_bool_expr_predicate_subset, is_normalized_bool_expr,
},
sql::{
lowering::{
SqlLoweringError,
expr::{SqlExprPhase, lower_sql_expr},
},
parser::{SqlExpr, SqlExprBinaryOp, SqlScalarFunction},
},
},
value::Value,
};
/// Lowers a SQL `WHERE` expression all the way down to a [`Predicate`].
///
/// The expression is first lowered and normalized as a boolean [`Expr`]; a
/// predicate is then derived from the SQL form and the lowered form.
///
/// # Errors
///
/// Propagates any lowering/validation error, and returns
/// `SqlLoweringError::unsupported_where_expression()` when no predicate can be
/// derived from the expression.
pub(in crate::db::sql::lowering) fn lower_sql_where_expr(
    expr: &SqlExpr,
) -> Result<Predicate, SqlLoweringError> {
    let lowered = lower_sql_where_bool_expr(expr)?;
    let Some(predicate) = derive_sql_where_expr_predicate_subset(expr, &lowered) else {
        return Err(SqlLoweringError::unsupported_where_expression());
    };
    Ok(predicate)
}
/// Derives a predicate for a `WHERE` clause from either of its two forms.
///
/// A top-level SQL membership test (`IN` / `NOT IN`) is tried first, directly
/// on `sql_expr`. If that does not apply, derivation falls back to the already
/// lowered boolean expression `lowered_expr`.
pub(in crate::db::sql::lowering) fn derive_sql_where_expr_predicate_subset(
    sql_expr: &SqlExpr,
    lowered_expr: &Expr,
) -> Option<Predicate> {
    let membership = derive_top_level_sql_membership_predicate_subset(sql_expr);
    if membership.is_some() {
        return membership;
    }
    derive_normalized_bool_expr_predicate_subset(lowered_expr)
}
/// Derives a normalized predicate straight from the SQL expression, without
/// going through the lowered [`Expr`] form.
///
/// Applies only when a membership test is reachable through `AND` chains.
/// Returns `None` otherwise, or when any conjunct cannot be expressed as a
/// predicate.
pub(in crate::db::sql::lowering) fn derive_sql_where_expr_predicate_only_subset(
    sql_expr: &SqlExpr,
) -> Option<Predicate> {
    let raw_predicate = predicate_only_sql_expr_contains_membership(sql_expr)
        .then(|| derive_sql_where_expr_predicate_only_subset_impl(sql_expr))
        .flatten()?;
    Some(normalize_predicate(&raw_predicate))
}
/// Recursive worker for [`derive_sql_where_expr_predicate_only_subset`].
///
/// Accepted shapes:
/// - membership tests;
/// - `AND` of two accepted sub-expressions;
/// - simple binary comparisons.
///
/// Every other expression kind yields `None`, and so does the whole `AND` tree
/// that contains it. The result is not normalized here.
fn derive_sql_where_expr_predicate_only_subset_impl(sql_expr: &SqlExpr) -> Option<Predicate> {
    match sql_expr {
        SqlExpr::Binary {
            op: SqlExprBinaryOp::And,
            left,
            right,
        } => {
            let lhs = derive_sql_where_expr_predicate_only_subset_impl(left.as_ref())?;
            let rhs = derive_sql_where_expr_predicate_only_subset_impl(right.as_ref())?;
            Some(Predicate::And(vec![lhs, rhs]))
        }
        // Non-AND binaries: only comparison operators produce a predicate;
        // `OR` and arithmetic are rejected inside the helper.
        SqlExpr::Binary { op, left, right } => {
            derive_sql_binary_compare_predicate(*op, left.as_ref(), right.as_ref())
        }
        SqlExpr::Membership { .. } => derive_top_level_sql_membership_predicate_subset(sql_expr),
        SqlExpr::Field(_)
        | SqlExpr::FieldPath { .. }
        | SqlExpr::Aggregate(_)
        | SqlExpr::Literal(_)
        | SqlExpr::Param { .. }
        | SqlExpr::NullTest { .. }
        | SqlExpr::Like { .. }
        | SqlExpr::FunctionCall { .. }
        | SqlExpr::Unary { .. }
        | SqlExpr::Case { .. } => None,
    }
}
/// Reports whether a membership test (`IN` / `NOT IN`) is reachable from
/// `sql_expr` by descending only through `AND` nodes.
fn predicate_only_sql_expr_contains_membership(sql_expr: &SqlExpr) -> bool {
    match sql_expr {
        SqlExpr::Binary {
            op: SqlExprBinaryOp::And,
            left,
            right,
        } => {
            let found_on_left = predicate_only_sql_expr_contains_membership(left.as_ref());
            found_on_left || predicate_only_sql_expr_contains_membership(right.as_ref())
        }
        SqlExpr::Membership { .. } => true,
        SqlExpr::Binary { .. }
        | SqlExpr::Field(_)
        | SqlExpr::FieldPath { .. }
        | SqlExpr::Aggregate(_)
        | SqlExpr::Literal(_)
        | SqlExpr::Param { .. }
        | SqlExpr::NullTest { .. }
        | SqlExpr::Like { .. }
        | SqlExpr::FunctionCall { .. }
        | SqlExpr::Unary { .. }
        | SqlExpr::Case { .. } => false,
    }
}
/// Lowers one SQL binary comparison into a single predicate leaf.
///
/// Supported operand shapes:
/// - `field <op> literal`;
/// - `literal <op> field`, rewritten with the flipped operator;
/// - `field <op> field`;
/// - `LOWER(field) <op> 'text'`, using `TextCasefold` coercion.
///
/// Returns `None` in three cases:
/// - non-comparison operators;
/// - any operand that is a `NULL` literal;
/// - every other operand shape.
fn derive_sql_binary_compare_predicate(
    op: SqlExprBinaryOp,
    left: &SqlExpr,
    right: &SqlExpr,
) -> Option<Predicate> {
    let compare_op = sql_compare_op(op)?;

    // Comparisons that mention a NULL literal are not lowered here.
    let mentions_null = [left, right]
        .iter()
        .any(|operand| matches!(operand, SqlExpr::Literal(Value::Null)));
    if mentions_null {
        return None;
    }

    let predicate = match (left, right) {
        (SqlExpr::Field(field), SqlExpr::Literal(value)) => {
            Predicate::Compare(ComparePredicate::with_coercion(
                field,
                compare_op,
                value.clone(),
                sql_compare_literal_coercion(compare_op, value),
            ))
        }
        (SqlExpr::Literal(value), SqlExpr::Field(field)) => {
            // Rewrite `literal <op> field` as `field <flipped op> literal`.
            let flipped_op = compare_op.flipped();
            Predicate::Compare(ComparePredicate::with_coercion(
                field,
                flipped_op,
                value.clone(),
                sql_compare_literal_coercion(flipped_op, value),
            ))
        }
        (SqlExpr::Field(left_field), SqlExpr::Field(right_field)) => {
            Predicate::CompareFields(CompareFieldsPredicate::with_coercion(
                left_field,
                compare_op,
                right_field,
                sql_compare_field_coercion(compare_op),
            ))
        }
        (
            SqlExpr::FunctionCall {
                function: SqlScalarFunction::Lower,
                args,
            },
            SqlExpr::Literal(Value::Text(text)),
        ) => {
            // Only `LOWER(<plain field>)` is accepted.
            let [SqlExpr::Field(field)] = args.as_slice() else {
                return None;
            };
            Predicate::Compare(ComparePredicate::with_coercion(
                field,
                compare_op,
                Value::Text(text.clone()),
                CoercionId::TextCasefold,
            ))
        }
        _ => return None,
    };

    Some(predicate)
}
/// Maps a SQL comparison operator to its predicate [`CompareOp`].
///
/// Logical (`AND`/`OR`) and arithmetic operators map to `None`.
const fn sql_compare_op(op: SqlExprBinaryOp) -> Option<CompareOp> {
    let compare = match op {
        SqlExprBinaryOp::Eq => CompareOp::Eq,
        SqlExprBinaryOp::Ne => CompareOp::Ne,
        SqlExprBinaryOp::Lt => CompareOp::Lt,
        SqlExprBinaryOp::Lte => CompareOp::Lte,
        SqlExprBinaryOp::Gt => CompareOp::Gt,
        SqlExprBinaryOp::Gte => CompareOp::Gte,
        SqlExprBinaryOp::And
        | SqlExprBinaryOp::Or
        | SqlExprBinaryOp::Add
        | SqlExprBinaryOp::Sub
        | SqlExprBinaryOp::Mul
        | SqlExprBinaryOp::Div => return None,
    };
    Some(compare)
}
/// Picks the coercion for a `field <op> literal` comparison.
///
/// - Text and `Nat*` literals always compare strictly.
/// - `Float32`/`Float64`/`Decimal` literals widen numerically only for
///   ordering-family operators, and compare strictly otherwise.
/// - Any other literal widens numerically if it supports numeric coercion,
///   and compares strictly otherwise.
const fn sql_compare_literal_coercion(op: CompareOp, value: &Value) -> CoercionId {
    match value {
        Value::Text(_) | Value::Nat64(_) | Value::Nat128(_) | Value::NatBig(_) => {
            CoercionId::Strict
        }
        Value::Float32(_) | Value::Float64(_) | Value::Decimal(_)
            if op.is_ordering_family() =>
        {
            CoercionId::NumericWiden
        }
        Value::Float32(_) | Value::Float64(_) | Value::Decimal(_) => CoercionId::Strict,
        _ if value.supports_numeric_coercion() => CoercionId::NumericWiden,
        _ => CoercionId::Strict,
    }
}
/// Picks the coercion for a `field <op> field` comparison.
///
/// - Ordering-family operators widen numerically.
/// - All other operators compare strictly.
///
/// # Panics
///
/// Panics if `op` does not support field-to-field comparison. Callers only
/// pass operators produced by `sql_compare_op`.
fn sql_compare_field_coercion(op: CompareOp) -> CoercionId {
    match (op.supports_field_compare(), op.is_ordering_family()) {
        (false, _) => unreachable!("sql predicate lowering invariant"),
        (true, true) => CoercionId::NumericWiden,
        (true, false) => CoercionId::Strict,
    }
}
/// Lowers, validates and normalizes a `WHERE` boolean expression in the
/// `Where` phase, without scalar CASE canonicalization.
pub(in crate::db::sql::lowering) fn lower_sql_where_bool_expr(
    expr: &SqlExpr,
) -> Result<Expr, SqlLoweringError> {
    let scalar_case_canonicalization = false;
    lower_sql_bool_expr_internal(expr, scalar_case_canonicalization, SqlExprPhase::Where)
}
/// Lowers, validates and normalizes a boolean expression in the
/// `PreAggregate` phase, without scalar CASE canonicalization.
pub(in crate::db::sql::lowering) fn lower_sql_pre_aggregate_bool_expr(
    expr: &SqlExpr,
) -> Result<Expr, SqlLoweringError> {
    let phase = SqlExprPhase::PreAggregate;
    lower_sql_bool_expr_internal(expr, false, phase)
}
/// Lowers, validates and normalizes a `WHERE` boolean expression in the
/// `Where` phase, with scalar CASE canonicalization enabled.
pub(in crate::db::sql::lowering) fn lower_sql_scalar_where_bool_expr(
    expr: &SqlExpr,
) -> Result<Expr, SqlLoweringError> {
    let scalar_case_canonicalization = true;
    lower_sql_bool_expr_internal(expr, scalar_case_canonicalization, SqlExprPhase::Where)
}
/// Shared pipeline behind the `lower_sql_*_bool_expr` entry points.
///
/// Steps:
/// 1. Lower `expr` in `phase`.
/// 2. Validate the result as a `WHERE`-style boolean expression.
/// 3. Normalize it, using `normalize_scalar_where_bool_expr` when
///    `scalar_case_canonicalization` is set and `normalize_where_bool_expr`
///    otherwise.
///
/// Debug builds also assert that normalization kept the expression valid and
/// produced a normalized boolean expression.
///
/// # Errors
///
/// Propagates errors from lowering and from validation.
fn lower_sql_bool_expr_internal(
    expr: &SqlExpr,
    scalar_case_canonicalization: bool,
    phase: SqlExprPhase,
) -> Result<Expr, SqlLoweringError> {
    let lowered = lower_sql_expr(expr, phase)?;
    validate::validate_where_bool_expr(&lowered)?;

    let normalized;
    if scalar_case_canonicalization {
        normalized = normalize::normalize_scalar_where_bool_expr(lowered);
    } else {
        normalized = normalize::normalize_where_bool_expr(lowered);
    }

    debug_assert!(
        validate::validate_where_bool_expr(&normalized).is_ok(),
        "WHERE normalization must not widen or narrow clause admissibility",
    );
    debug_assert!(is_normalized_bool_expr(&normalized));
    Ok(normalized)
}
/// Derives a compare predicate from a top-level `IN` / `NOT IN` expression.
///
/// Returns `None` in three cases:
/// - `expr` is not a membership test;
/// - its target is not a plain field or `LOWER(field)`;
/// - any list value has no admissible coercion.
///
/// Otherwise the per-value leaves are handed to
/// `collapse_membership_compare_leaves`, which may itself return `None`.
fn derive_top_level_sql_membership_predicate_subset(expr: &SqlExpr) -> Option<Predicate> {
    let SqlExpr::Membership {
        expr: target,
        values,
        negated,
    } = expr
    else {
        return None;
    };

    let (field, fixed_coercion) = sql_membership_target(target.as_ref())?;
    let target_op = if *negated {
        CompareOp::NotIn
    } else {
        CompareOp::In
    };

    let mut leaves = Vec::with_capacity(values.len());
    for value in values.iter() {
        // A single unsupported value disqualifies the whole list.
        let coercion = sql_membership_value_coercion(value, fixed_coercion)?;
        leaves.push(MembershipCompareLeaf::new(field, value.clone(), coercion));
    }

    collapse_membership_compare_leaves(leaves, target_op).map(Predicate::Compare)
}
/// Resolves the left-hand side of a membership test.
///
/// Returns the field name plus an optional fixed coercion:
/// - a plain field yields no fixed coercion;
/// - `LOWER(field)` yields `TextCasefold`;
/// - anything else yields `None`.
fn sql_membership_target(expr: &SqlExpr) -> Option<(&str, Option<CoercionId>)> {
    match expr {
        SqlExpr::Field(field) => Some((field.as_str(), None)),
        SqlExpr::FunctionCall {
            function: SqlScalarFunction::Lower,
            args,
        } => {
            let [SqlExpr::Field(field)] = args.as_slice() else {
                return None;
            };
            Some((field.as_str(), Some(CoercionId::TextCasefold)))
        }
        _ => None,
    }
}
/// Chooses the coercion for one `IN` / `NOT IN` list value.
///
/// Returns `None` when the value cannot be expressed as a membership leaf:
/// - With a fixed `TextCasefold` coercion (a `LOWER(field)` target), only
///   text values are accepted.
/// - Any other fixed coercion rejects every value.
/// - Without a fixed coercion:
///   - `NULL`, list and map values are rejected;
///   - values that support numeric coercion widen numerically;
///   - everything else compares strictly.
///
/// `NULL` is rejected for consistency with `derive_sql_binary_compare_predicate`,
/// which refuses any comparison against a `NULL` literal. Under SQL
/// three-valued logic, `x IN (…, NULL)` and `x NOT IN (…, NULL)` do not behave
/// like plain equality leaves (`NOT IN` with a `NULL` member is never true).
/// Rejecting here sends such lists to the general lowering path instead.
const fn sql_membership_value_coercion(
    value: &Value,
    fixed: Option<CoercionId>,
) -> Option<CoercionId> {
    match fixed {
        Some(CoercionId::TextCasefold) if matches!(value, Value::Text(_)) => {
            Some(CoercionId::TextCasefold)
        }
        Some(_) => None,
        None if matches!(value, Value::Null | Value::List(_) | Value::Map(_)) => None,
        None if value.supports_numeric_coercion() => Some(CoercionId::NumericWiden),
        None => Some(CoercionId::Strict),
    }
}