use crate::expr::Expression;
use crate::expr::and_collect;
use crate::expr::forms::conjuncts;
use crate::expr::lit;
use crate::scalar_fn::ScalarFnVTableExt;
use crate::scalar_fn::fns::between::Between;
use crate::scalar_fn::fns::between::BetweenOptions;
use crate::scalar_fn::fns::between::StrictComparison;
use crate::scalar_fn::fns::binary::Binary;
use crate::scalar_fn::fns::get_item::GetItem;
use crate::scalar_fn::fns::literal::Literal;
use crate::scalar_fn::fns::operators::Operator;
/// Rewrites pairs of conjuncts that bound the same column from both sides into a
/// single [`Between`] expression.
///
/// `expr` is split into its top-level conjuncts, which are visited in order. Each
/// conjunct is paired with the first *later* conjunct for which `maybe_match`
/// produces a `Between`. The consumed partner is dropped, and the fused expression
/// takes the position of the earlier conjunct. Conjuncts without a partner are
/// kept unchanged, so relative order is preserved.
///
/// The surviving conjuncts are recombined with `and_collect`. If nothing remains
/// to combine, the result is `lit(true)`.
pub fn find_between(expr: Expression) -> Expression {
    let mut pending = conjuncts(&expr);
    let mut rest = Vec::with_capacity(pending.len());
    let mut idx = 0;
    // `pending` shrinks whenever a partner is consumed, so its length is
    // re-checked on every iteration instead of iterating a precomputed range.
    while idx < pending.len() {
        let current = &pending[idx];
        let partner = pending[idx + 1..]
            .iter()
            .enumerate()
            .find_map(|(offset, candidate)| {
                maybe_match(current, candidate).map(|fused| (idx + 1 + offset, fused))
            });
        match partner {
            Some((partner_idx, fused)) => {
                rest.push(fused);
                // The partner is always after `idx`, so removing it does not
                // disturb the conjuncts that are still to be visited before it.
                pending.remove(partner_idx);
            }
            None => rest.push(current.clone()),
        }
        idx += 1;
    }
    and_collect(rest).unwrap_or_else(|| lit(true))
}
fn maybe_match(lhs: &Expression, rhs: &Expression) -> Option<Expression> {
let (Some(lhs_op), Some(rhs_op)) = (lhs.as_opt::<Binary>(), rhs.as_opt::<Binary>()) else {
return None;
};
let lhs_lhs = lhs.child(0);
let lhs_rhs = lhs.child(1);
let rhs_lhs = rhs.child(0);
let rhs_rhs = rhs.child(1);
if lhs_lhs.eq(lhs_rhs) || rhs_lhs.eq(rhs_rhs) {
return None;
}
let lhs = match (lhs_lhs.is::<GetItem>(), lhs_rhs.is::<GetItem>()) {
(true, false) => lhs.clone(),
(false, true) => Binary.new_expr(lhs_op.swap()?, [lhs_rhs.clone(), lhs_lhs.clone()]),
_ => return None,
};
let lhs_op = lhs.as_::<Binary>();
let lhs_lhs = lhs.child(0);
let rhs = match (rhs_lhs.is::<GetItem>(), rhs_rhs.is::<GetItem>()) {
(true, false) => rhs.clone(),
(false, true) => Binary.new_expr(rhs_op.swap()?, [rhs_rhs.clone(), rhs_lhs.clone()]),
_ => return None,
};
let rhs_op = rhs.as_::<Binary>();
let rhs_lhs = rhs.child(0);
if !lhs_lhs.eq(rhs_lhs) {
return None;
}
let target = lhs_lhs.clone();
let (lower, upper) = match (lhs_op, rhs_op) {
(Operator::Lt | Operator::Lte, Operator::Gt | Operator::Gte) => (rhs, lhs),
(Operator::Gt | Operator::Gte, Operator::Lt | Operator::Lte) => (lhs, rhs),
_ => return None,
};
let lower_op = lower.as_::<Binary>();
let lower_rhs = lower.child(1);
let upper_op = upper.as_::<Binary>();
let upper_rhs = upper.child(1);
let _ = lower_rhs.as_opt::<Literal>()?;
let _ = upper_rhs.as_opt::<Literal>()?;
let lower_strict = is_strict_comparison(*lower_op)?;
let upper_strict = is_strict_comparison(*upper_op)?;
Some(Between.new_expr(
BetweenOptions {
lower_strict,
upper_strict,
},
[target, lower_rhs.clone(), upper_rhs.clone()],
))
}
/// Maps an ordering comparison to the strictness of the bound it expresses.
///
/// `<` and `>` yield [`StrictComparison::Strict`], and `<=` and `>=` yield
/// [`StrictComparison::NonStrict`]. Any other operator yields `None`.
fn is_strict_comparison(op: Operator) -> Option<StrictComparison> {
    let strictness = match op {
        Operator::Gt | Operator::Lt => StrictComparison::Strict,
        Operator::Gte | Operator::Lte => StrictComparison::NonStrict,
        _ => return None,
    };
    Some(strictness)
}
#[cfg(test)]
mod tests {
    use super::find_between;
    use crate::expr::and;
    use crate::expr::between;
    use crate::expr::col;
    use crate::expr::gt;
    use crate::expr::gt_eq;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::lt_eq;
    use crate::scalar_fn::fns::between::BetweenOptions;
    use crate::scalar_fn::fns::between::StrictComparison;

    /// Contradictory bounds (`x >= 100 AND x < -100`) are still fused.
    ///
    /// The rewrite is purely structural: it does not check that the lower bound
    /// is below the upper bound. The resulting `Between` is just as unsatisfiable
    /// as the original conjunction.
    #[test]
    fn test_bad_match() {
        let expr = and(lt_eq(lit(100), col("x")), gt(lit(-100), col("x")));
        let find = find_between(expr);
        assert_eq!(
            &find,
            &between(
                col("x"),
                lit(100),
                lit(-100),
                BetweenOptions {
                    lower_strict: StrictComparison::NonStrict,
                    upper_strict: StrictComparison::Strict,
                }
            )
        );
    }

    /// Both comparisons have the column on the right-hand side.
    #[test]
    fn test_match_between() {
        let expr = and(lt(lit(2), col("x")), gt_eq(lit(5), col("x")));
        let find = find_between(expr);
        assert_eq!(
            &between(
                col("x"),
                lit(2),
                lit(5),
                BetweenOptions {
                    lower_strict: StrictComparison::Strict,
                    upper_strict: StrictComparison::NonStrict,
                }
            ),
            &find
        );
    }

    /// Both comparisons have the column on the left-hand side.
    #[test]
    fn test_match_2_between() {
        let expr = and(gt_eq(col("x"), lit(2)), lt(col("x"), lit(5)));
        let find = find_between(expr);
        assert_eq!(
            &between(
                col("x"),
                lit(2),
                lit(5),
                BetweenOptions {
                    lower_strict: StrictComparison::NonStrict,
                    upper_strict: StrictComparison::Strict,
                }
            ),
            &find
        );
    }

    /// The two comparisons put the column on opposite sides.
    #[test]
    fn test_match_3_between() {
        let expr = and(gt_eq(col("x"), lit(2)), gt_eq(lit(5), col("x")));
        let find = find_between(expr);
        assert_eq!(
            &between(
                col("x"),
                lit(2),
                lit(5),
                BetweenOptions {
                    lower_strict: StrictComparison::NonStrict,
                    upper_strict: StrictComparison::NonStrict,
                }
            ),
            &find
        );
    }

    /// The upper bound appears before the lower bound.
    #[test]
    fn test_match_4_between() {
        let expr = and(gt_eq(lit(5), col("x")), lt(lit(2), col("x")));
        let find = find_between(expr);
        assert_eq!(
            &between(
                col("x"),
                lit(2),
                lit(5),
                BetweenOptions {
                    lower_strict: StrictComparison::Strict,
                    upper_strict: StrictComparison::NonStrict,
                }
            ),
            &find
        );
    }

    /// An unrelated leading conjunct is kept, and the fused bound follows it.
    #[test]
    fn test_match_5_between() {
        let expr = and(
            and(gt_eq(col("y"), lit(10)), gt_eq(lit(5), col("x"))),
            lt(lit(2), col("x")),
        );
        let find = find_between(expr);
        assert_eq!(
            &and(
                gt_eq(col("y"), lit(10)),
                between(
                    col("x"),
                    lit(2),
                    lit(5),
                    BetweenOptions {
                        lower_strict: StrictComparison::Strict,
                        upper_strict: StrictComparison::NonStrict,
                    }
                )
            ),
            &find
        );
    }

    /// The fused bound takes the position of its first conjunct, even when an
    /// unrelated conjunct separates the pair.
    #[test]
    fn test_match_6_between() {
        let expr = and(
            and(gt_eq(lit(5), col("x")), gt_eq(col("y"), lit(10))),
            lt(lit(2), col("x")),
        );
        let find = find_between(expr);
        assert_eq!(
            &and(
                between(
                    col("x"),
                    lit(2),
                    lit(5),
                    BetweenOptions {
                        lower_strict: StrictComparison::Strict,
                        upper_strict: StrictComparison::NonStrict,
                    }
                ),
                gt_eq(col("y"), lit(10)),
            ),
            &find
        );
    }

    /// Two lower bounds on the same column are not fused.
    #[test]
    fn test_no_match_same_direction() {
        let expr = and(gt(col("x"), lit(1)), gt_eq(col("x"), lit(2)));
        let find = find_between(expr);
        assert_eq!(&and(gt(col("x"), lit(1)), gt_eq(col("x"), lit(2))), &find);
    }

    /// Bounds on different columns are not fused.
    #[test]
    fn test_no_match_different_columns() {
        let expr = and(gt(col("x"), lit(1)), lt(col("y"), lit(5)));
        let find = find_between(expr);
        assert_eq!(&and(gt(col("x"), lit(1)), lt(col("y"), lit(5))), &find);
    }

    /// A column-vs-column comparison is not a literal bound and is not fused.
    #[test]
    fn test_no_match_column_bound() {
        let expr = and(gt(col("x"), col("y")), lt(col("x"), lit(5)));
        let find = find_between(expr);
        assert_eq!(&and(gt(col("x"), col("y")), lt(col("x"), lit(5))), &find);
    }

    /// A lone comparison has nothing to pair with and is returned unchanged.
    #[test]
    fn test_single_conjunct_unchanged() {
        let find = find_between(gt(col("x"), lit(1)));
        assert_eq!(&gt(col("x"), lit(1)), &find);
    }
}