use crate::{
db::{
query::plan::expr::{BinaryOp, Expr, Function, UnaryOp},
sql::lowering::SqlLoweringError,
},
value::Value,
};
/// Checks that `expr` has a shape admitted as a boolean predicate in a SQL `WHERE` clause.
///
/// Admitted shapes:
/// - a bare field reference (its type is not checked here);
/// - a `Bool` or `Null` literal;
/// - `NOT <predicate>`;
/// - `<predicate> AND/OR <predicate>`;
/// - any other binary operator, delegated to `validate_where_bool_compare_expr`;
/// - a function call, delegated to `validate_where_bool_function_call`;
/// - a `CASE` whose every `WHEN` condition, `THEN` result and `ELSE` branch is itself an
///   admitted predicate.
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_where_expression()` for anything else. This
/// includes aggregates, non-boolean literals, aliases (test builds only), and any
/// rejected sub-expression.
pub(super) fn validate_where_bool_expr(expr: &Expr) -> Result<(), SqlLoweringError> {
    match expr {
        Expr::Field(_) | Expr::Literal(Value::Bool(_) | Value::Null) => Ok(()),
        Expr::Unary {
            op: UnaryOp::Not,
            expr: operand,
        } => validate_where_bool_expr(operand.as_ref()),
        Expr::Binary { op, left, right } => {
            if matches!(op, BinaryOp::And | BinaryOp::Or) {
                // Logical connectives: each side must be a predicate in its own right,
                // checked left first so the left-hand error wins.
                validate_where_bool_expr(left.as_ref())?;
                validate_where_bool_expr(right.as_ref())
            } else {
                // Comparisons and arithmetic are vetted by the dedicated helper.
                validate_where_bool_compare_expr(*op, left, right)
            }
        }
        Expr::FunctionCall { function, args } => validate_where_bool_function_call(*function, args),
        Expr::Case {
            when_then_arms,
            else_expr,
        } => {
            // Every branch of the CASE feeds the WHERE result, so every branch must be a
            // predicate; arms are checked in order and the first failure is returned.
            when_then_arms.iter().try_for_each(|arm| {
                validate_where_bool_expr(arm.condition())?;
                validate_where_bool_expr(arm.result())
            })?;
            validate_where_bool_expr(else_expr.as_ref())
        }
        // Non-boolean literals fall through to here; boolean/NULL ones matched above.
        Expr::Aggregate(_) | Expr::Literal(_) => {
            Err(SqlLoweringError::unsupported_where_expression())
        }
        #[cfg(test)]
        Expr::Alias { .. } => Err(SqlLoweringError::unsupported_where_expression()),
    }
}
/// Validates a binary `WHERE` expression that is not a logical `AND`/`OR` connective.
///
/// The expression is accepted only when `op` is a comparison operator (`Eq`, `Ne`, `Lt`,
/// `Lte`, `Gt`, `Gte`) and both `left` and `right` pass
/// `where_compare_operand_is_admitted`.
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_where_expression()` in these cases:
/// - `op` is a logical or arithmetic operator;
/// - either operand is not an admitted comparison operand.
fn validate_where_bool_compare_expr(
    op: BinaryOp,
    left: &Expr,
    right: &Expr,
) -> Result<(), SqlLoweringError> {
    // Exhaustive on purpose: adding a `BinaryOp` variant must force a decision here.
    let is_comparison = match op {
        BinaryOp::Eq
        | BinaryOp::Ne
        | BinaryOp::Lt
        | BinaryOp::Lte
        | BinaryOp::Gt
        | BinaryOp::Gte => true,
        BinaryOp::Or
        | BinaryOp::And
        | BinaryOp::Add
        | BinaryOp::Sub
        | BinaryOp::Mul
        | BinaryOp::Div => false,
    };

    let admitted = is_comparison
        && where_compare_operand_is_admitted(left)
        && where_compare_operand_is_admitted(right);

    if admitted {
        Ok(())
    } else {
        Err(SqlLoweringError::unsupported_where_expression())
    }
}
/// Validates a function call used directly as a `WHERE` predicate.
///
/// Accepted forms:
/// - `IsNull(x)` / `IsNotNull(x)` with exactly one argument that passes
///   `where_null_test_operand_is_admitted`;
/// - `StartsWith(t, s)` / `EndsWith(t, s)` / `Contains(t, s)` with exactly two arguments,
///   where `t` passes `where_text_target_is_admitted` and `s` is a text literal.
///
/// # Errors
///
/// Returns `SqlLoweringError::unsupported_where_expression()` for any other function, or
/// for one of the above functions called with a non-matching argument list.
fn validate_where_bool_function_call(
    function: Function,
    args: &[Expr],
) -> Result<(), SqlLoweringError> {
    let admitted = match (function, args) {
        (Function::IsNull | Function::IsNotNull, [operand]) => {
            where_null_test_operand_is_admitted(operand)
        }
        (
            Function::StartsWith | Function::EndsWith | Function::Contains,
            [target, Expr::Literal(Value::Text(_))],
        ) => where_text_target_is_admitted(target),
        // Wrong arity, non-text needle, or a function not usable as a predicate.
        _ => false,
    };

    if admitted {
        Ok(())
    } else {
        Err(SqlLoweringError::unsupported_where_expression())
    }
}
/// Reports whether `expr` may appear on either side of a `WHERE` comparison.
///
/// Admitted operands:
/// - a field reference;
/// - a literal of any value;
/// - a `Lower`/`Upper` call whose single argument is a field reference.
///
/// Every other expression shape returns `false`.
fn where_compare_operand_is_admitted(expr: &Expr) -> bool {
    match expr {
        Expr::Field(_) | Expr::Literal(_) => true,
        Expr::FunctionCall { function, args } => {
            matches!(function, Function::Lower | Function::Upper)
                && matches!(args.as_slice(), [Expr::Field(_)])
        }
        Expr::Aggregate(_) | Expr::Unary { .. } | Expr::Binary { .. } | Expr::Case { .. } => {
            false
        }
        #[cfg(test)]
        Expr::Alias { .. } => false,
    }
}
/// Reports whether `expr` may be the single argument of an `IsNull`/`IsNotNull` predicate.
///
/// Only field references and literals are admitted. Every other shape returns `false`.
const fn where_null_test_operand_is_admitted(expr: &Expr) -> bool {
    match expr {
        Expr::Field(_) | Expr::Literal(_) => true,
        Expr::Aggregate(_)
        | Expr::Unary { .. }
        | Expr::Binary { .. }
        | Expr::Case { .. }
        | Expr::FunctionCall { .. } => false,
        #[cfg(test)]
        Expr::Alias { .. } => false,
    }
}
/// Reports whether `expr` may be the first (searched) argument of a
/// `StartsWith`/`EndsWith`/`Contains` predicate.
///
/// Admitted targets:
/// - a field reference;
/// - a `Lower`/`Upper` call whose single argument is a field reference.
///
/// Literals are deliberately rejected here, unlike for comparison operands.
fn where_text_target_is_admitted(expr: &Expr) -> bool {
    match expr {
        Expr::Field(_) => true,
        Expr::FunctionCall { function, args } => {
            let is_case_mapping = matches!(function, Function::Lower | Function::Upper);
            is_case_mapping && matches!(args.as_slice(), [Expr::Field(_)])
        }
        Expr::Aggregate(_)
        | Expr::Literal(_)
        | Expr::Unary { .. }
        | Expr::Binary { .. }
        | Expr::Case { .. } => false,
        #[cfg(test)]
        Expr::Alias { .. } => false,
    }
}