use crate::{
db::{
query::plan::expr::{BinaryOp, Expr, Function, UnaryOp},
sql::lowering::SqlLoweringError,
},
value::Value,
};
/// Checks that `expr` is admissible as a boolean `WHERE` predicate.
///
/// Accepted shapes:
/// - bare field references and `TRUE`/`FALSE`/`NULL` literals;
/// - `NOT p`, `p AND q`, `p OR q` where every operand is itself accepted;
/// - any other binary operator, delegated to `validate_where_bool_compare_expr`;
/// - function calls, delegated to `validate_where_bool_function_call`;
/// - `CASE` expressions whose conditions, branch results and `ELSE` result are
///   all accepted predicates (checked in source order, stopping at the first failure).
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_where_expression()` for aggregates,
/// non-boolean literals, and whenever any nested sub-expression is rejected.
pub(super) fn validate_where_bool_expr(expr: &Expr) -> Result<(), SqlLoweringError> {
    match expr {
        Expr::Field(_) | Expr::Literal(Value::Bool(_) | Value::Null) => Ok(()),
        Expr::Unary {
            op: UnaryOp::Not,
            expr: operand,
        } => validate_where_bool_expr(operand),
        Expr::Binary {
            op: BinaryOp::And | BinaryOp::Or,
            left,
            right,
        } => validate_where_bool_expr(left).and_then(|()| validate_where_bool_expr(right)),
        // Every remaining binary operator is treated as a comparison candidate.
        Expr::Binary { op, left, right } => validate_where_bool_compare_expr(*op, left, right),
        Expr::FunctionCall { function, args } => validate_where_bool_function_call(*function, args),
        Expr::Case {
            when_then_arms,
            else_expr,
        } => {
            // Each arm's condition is checked before its result, arm by arm.
            when_then_arms.iter().try_for_each(|arm| {
                validate_where_bool_expr(arm.condition())?;
                validate_where_bool_expr(arm.result())
            })?;
            validate_where_bool_expr(else_expr)
        }
        #[cfg(test)]
        Expr::Alias { .. } => Err(SqlLoweringError::unsupported_where_expression()),
        // Reached only by literals other than Bool/Null, thanks to the first arm.
        Expr::Aggregate(_) | Expr::Literal(_) => {
            Err(SqlLoweringError::unsupported_where_expression())
        }
    }
}
/// Validates a binary operator that is not `AND`/`OR` and is used directly as a
/// `WHERE` predicate.
///
/// Only the comparison operators (`Eq`, `Ne`, `Lt`, `Lte`, `Gt`, `Gte`) are
/// accepted. Both operands must also pass `where_compare_operand_is_admitted`.
/// Arithmetic operators are rejected because they do not yield a boolean.
/// `And`/`Or` are rejected here too; the visible caller handles them before
/// delegating.
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_where_expression()` when `op` is not a
/// comparison or either operand is not admitted.
fn validate_where_bool_compare_expr(
    op: BinaryOp,
    left: &Expr,
    right: &Expr,
) -> Result<(), SqlLoweringError> {
    // Exhaustive on purpose: adding a `BinaryOp` variant must force a decision here.
    let is_comparison = match op {
        BinaryOp::Eq
        | BinaryOp::Ne
        | BinaryOp::Lt
        | BinaryOp::Lte
        | BinaryOp::Gt
        | BinaryOp::Gte => true,
        BinaryOp::Or
        | BinaryOp::And
        | BinaryOp::Add
        | BinaryOp::Sub
        | BinaryOp::Mul
        | BinaryOp::Div => false,
    };

    if is_comparison
        && where_compare_operand_is_admitted(left)
        && where_compare_operand_is_admitted(right)
    {
        Ok(())
    } else {
        Err(SqlLoweringError::unsupported_where_expression())
    }
}
/// Validates a function call used directly as a `WHERE` predicate.
///
/// Accepted calls:
/// - `IsNull` / `IsNotNull` with exactly one argument that passes
///   `where_null_test_operand_is_admitted`;
/// - `StartsWith` / `EndsWith` / `Contains` with exactly two arguments that
///   both pass `where_compare_operand_is_admitted`.
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_where_expression()` for any other
/// function, a wrong argument count, or a rejected argument.
fn validate_where_bool_function_call(
    function: Function,
    args: &[Expr],
) -> Result<(), SqlLoweringError> {
    let admitted = match (function, args) {
        (Function::IsNull | Function::IsNotNull, [operand]) => {
            where_null_test_operand_is_admitted(operand)
        }
        (Function::StartsWith | Function::EndsWith | Function::Contains, [lhs, rhs]) => {
            where_compare_operand_is_admitted(lhs) && where_compare_operand_is_admitted(rhs)
        }
        // Unknown predicate functions and arity mismatches are all rejected.
        _ => false,
    };

    if admitted {
        Ok(())
    } else {
        Err(SqlLoweringError::unsupported_where_expression())
    }
}
/// Reports whether `expr` may appear as an operand of a `WHERE` comparison or
/// predicate function.
///
/// Admitted operands:
/// - field references and literals of any value;
/// - calls to the scalar functions in the allow-list below, when every argument
///   is itself an admitted operand (argument count is not checked here);
/// - arithmetic (`Add`, `Sub`, `Mul`, `Div`) over admitted operands;
/// - `CASE` expressions whose conditions pass `validate_where_bool_expr` and
///   whose branch results and `ELSE` result are admitted operands.
///
/// Everything else is rejected. That includes aggregates, unary operators,
/// functions outside the allow-list, and comparison/logical binary operators.
fn where_compare_operand_is_admitted(expr: &Expr) -> bool {
    match expr {
        Expr::Field(_) | Expr::Literal(_) => true,
        // One allow-list for all scalar functions. Previously this was split across two
        // match arms with identical bodies, so the lists could drift apart.
        Expr::FunctionCall {
            function:
                Function::Lower
                | Function::Upper
                | Function::Coalesce
                | Function::NullIf
                | Function::Trim
                | Function::Ltrim
                | Function::Rtrim
                | Function::Abs
                | Function::Ceil
                | Function::Ceiling
                | Function::Floor
                | Function::Length
                | Function::Left
                | Function::Right
                | Function::Position
                | Function::Replace
                | Function::Substring
                | Function::Round,
            args,
        } => args.iter().all(where_compare_operand_is_admitted),
        Expr::Binary {
            op: BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div,
            left,
            right,
        } => {
            where_compare_operand_is_admitted(left.as_ref())
                && where_compare_operand_is_admitted(right.as_ref())
        }
        Expr::Case {
            when_then_arms,
            else_expr,
        } => {
            // Conditions must be boolean predicates; results must be value operands.
            when_then_arms.iter().all(|arm| {
                validate_where_bool_expr(arm.condition()).is_ok()
                    && where_compare_operand_is_admitted(arm.result())
            }) && where_compare_operand_is_admitted(else_expr.as_ref())
        }
        Expr::Aggregate(_)
        | Expr::Unary { .. }
        | Expr::FunctionCall { .. }
        | Expr::Binary { .. } => false,
        #[cfg(test)]
        Expr::Alias { .. } => false,
    }
}
/// Reports whether `expr` may be the single argument of an `IsNull` /
/// `IsNotNull` predicate.
///
/// This currently applies exactly the same rule as
/// `where_compare_operand_is_admitted`.
fn where_null_test_operand_is_admitted(expr: &Expr) -> bool {
    // Kept as a separate entry point so null-test admission has its own name at call sites.
    where_compare_operand_is_admitted(expr)
}