use crate::utils::split_disjunction;
use crate::{split_conjunction, PhysicalExpr};
use datafusion_common::{Column, ScalarValue};
use datafusion_expr::Operator;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
/// A constraint on the values of a single column that must hold for an
/// analyzed predicate to evaluate to true; produced by
/// [`LiteralGuarantee::analyze`].
///
/// With [`Guarantee::In`] the column must equal one of `literals`; with
/// [`Guarantee::NotIn`] it must equal none of them.
#[derive(Debug, Clone, PartialEq)]
pub struct LiteralGuarantee {
/// The column the guarantee applies to (built from the column name alone by
/// `try_new`).
pub column: Column,
/// Whether the column must be in, or must not be in, `literals`.
pub guarantee: Guarantee,
/// The distinct literal values the guarantee refers to.
pub literals: HashSet<ScalarValue>,
}
/// The kind of constraint a [`LiteralGuarantee`] places on its column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Guarantee {
/// The column must equal one of the literals (derived from `=`, `IN`, or an
/// `OR` of `=` comparisons on one column).
In,
/// The column must not equal any of the literals (derived from `!=` or
/// `NOT IN`).
NotIn,
}
impl LiteralGuarantee {
fn try_new<'a>(
column_name: impl Into<String>,
guarantee: Guarantee,
literals: impl IntoIterator<Item = &'a ScalarValue>,
) -> Option<Self> {
let literals: HashSet<_> = literals.into_iter().cloned().collect();
Some(Self {
column: Column::from_name(column_name),
guarantee,
literals,
})
}
pub fn analyze(expr: &Arc<dyn PhysicalExpr>) -> Vec<LiteralGuarantee> {
split_conjunction(expr)
.into_iter()
.fold(GuaranteeBuilder::new(), |builder, expr| {
if let Some(cel) = ColOpLit::try_new(expr) {
return builder.aggregate_conjunct(cel);
} else if let Some(inlist) = expr
.as_any()
.downcast_ref::<crate::expressions::InListExpr>()
{
let col = inlist
.expr()
.as_any()
.downcast_ref::<crate::expressions::Column>();
let Some(col) = col else {
return builder;
};
let literals = inlist
.list()
.iter()
.map(|e| e.as_any().downcast_ref::<crate::expressions::Literal>())
.collect::<Option<Vec<_>>>();
let Some(literals) = literals else {
return builder;
};
let guarantee = if inlist.negated() {
Guarantee::NotIn
} else {
Guarantee::In
};
builder.aggregate_multi_conjunct(
col,
guarantee,
literals.iter().map(|e| e.value()),
)
} else {
let disjunctions = split_disjunction(expr);
let terms = disjunctions
.iter()
.filter_map(|expr| ColOpLit::try_new(expr))
.collect::<Vec<_>>();
if terms.is_empty() {
return builder;
}
if terms.len() != disjunctions.len() {
return builder;
}
let first_term = &terms[0];
if terms.iter().all(|term| {
term.col.name() == first_term.col.name()
&& term.guarantee == Guarantee::In
}) {
builder.aggregate_multi_conjunct(
first_term.col,
Guarantee::In,
terms.iter().map(|term| term.lit.value()),
)
} else {
builder
}
}
})
.build()
}
}
impl Display for LiteralGuarantee {
    /// Formats the guarantee as `<column> in (<literals>)` or
    /// `<column> not in (<literals>)`, with literals separated by `", "`.
    ///
    /// `literals` is a `HashSet`, whose iteration order is unspecified and may
    /// differ from run to run. The rendered literals are therefore sorted so
    /// that the same guarantee always prints the same text (e.g. in plan
    /// output or test expectations).
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Sort the rendered strings rather than the values themselves: strings
        // always have a total order, whatever mix of literal types is present.
        let mut rendered: Vec<String> =
            self.literals.iter().map(|lit| lit.to_string()).collect();
        rendered.sort_unstable();

        let keyword = match self.guarantee {
            Guarantee::In => "in",
            Guarantee::NotIn => "not in",
        };
        write!(
            f,
            "{} {} ({})",
            self.column.name,
            keyword,
            rendered.join(", ")
        )
    }
}
/// Accumulates [`LiteralGuarantee`]s while [`LiteralGuarantee::analyze`]
/// folds over the conjuncts of a predicate.
#[derive(Debug, Default)]
struct GuaranteeBuilder<'a> {
/// Guarantees in first-seen order; `None` marks a guarantee that a later
/// conjunct invalidated (such entries are dropped by `build`).
guarantees: Vec<Option<LiteralGuarantee>>,
/// Maps each `(column, guarantee kind)` pair to its index in `guarantees`.
map: HashMap<(&'a crate::expressions::Column, Guarantee), usize>,
}
impl<'a> GuaranteeBuilder<'a> {
    /// Creates a builder holding no guarantees.
    fn new() -> Self {
        Default::default()
    }

    /// Folds a single `col <op> literal` conjunct into the builder.
    fn aggregate_conjunct(self, col_op_lit: ColOpLit<'a>) -> Self {
        self.aggregate_multi_conjunct(
            col_op_lit.col,
            col_op_lit.guarantee,
            [col_op_lit.lit.value()],
        )
    }

    /// Folds a conjunct stating that `col` is (`In`) or is not (`NotIn`) one
    /// of `new_values` into the builder.
    ///
    /// Guarantees are keyed by `(col, guarantee)`:
    /// * the first conjunct for a key records a new guarantee;
    /// * later `NotIn` conjuncts add their values to the excluded set
    ///   (`a != 1 AND a != 2` => `a not in (1, 2)`);
    /// * later `In` conjuncts intersect with the allowed set
    ///   (`a IN (1,2,3) AND a IN (2,3,4)` => `a in (2, 3)`). An empty
    ///   intersection invalidates the key's guarantee, and every later
    ///   conjunct for that key is then ignored.
    fn aggregate_multi_conjunct(
        mut self,
        col: &'a crate::expressions::Column,
        guarantee: Guarantee,
        new_values: impl IntoIterator<Item = &'a ScalarValue>,
    ) -> Self {
        let key = (col, guarantee);
        if let Some(&index) = self.map.get(&key) {
            let entry = &mut self.guarantees[index];
            let Some(existing) = entry else {
                // An earlier conjunct already invalidated this guarantee.
                return self;
            };
            match existing.guarantee {
                // `a != x AND a != y`: `a` must avoid every listed value.
                // `HashSet::extend` drops duplicates itself.
                Guarantee::NotIn => {
                    existing.literals.extend(new_values.into_iter().cloned());
                }
                // `a IN s AND a IN t`: `a` must lie in both sets.
                Guarantee::In => {
                    let intersection: HashSet<ScalarValue> = new_values
                        .into_iter()
                        .filter(|value| existing.literals.contains(*value))
                        .cloned()
                        .collect();
                    if intersection.is_empty() {
                        // No value satisfies every conjunct; drop the guarantee.
                        *entry = None;
                    } else {
                        existing.literals = intersection;
                    }
                }
            }
        } else if let Some(new_guarantee) =
            LiteralGuarantee::try_new(col.name(), guarantee, new_values)
        {
            // First conjunct for this key; `try_new` deduplicates the values.
            self.map.insert(key, self.guarantees.len());
            self.guarantees.push(Some(new_guarantee));
        }
        self
    }

    /// Returns the surviving guarantees in first-seen order, skipping any
    /// that were invalidated.
    fn build(self) -> Vec<LiteralGuarantee> {
        self.guarantees.into_iter().flatten().collect()
    }
}
/// A comparison between a bare column and a literal: `col = lit`,
/// `col != lit`, or the mirrored `lit = col` / `lit != col`.
struct ColOpLit<'a> {
/// The column operand.
col: &'a crate::expressions::Column,
/// `In` for `=`, `NotIn` for `!=`.
guarantee: Guarantee,
/// The literal operand.
lit: &'a crate::expressions::Literal,
}
impl<'a> ColOpLit<'a> {
    /// Recognizes `col = lit`, `col != lit`, `lit = col` and `lit != col`.
    ///
    /// Returns `None` when `expr` is not a binary expression, when its
    /// operator is anything other than `=` / `!=`, or when its operands are
    /// not exactly one bare column and one literal.
    fn try_new(expr: &'a Arc<dyn PhysicalExpr>) -> Option<Self> {
        let binary = expr
            .as_any()
            .downcast_ref::<crate::expressions::BinaryExpr>()?;
        let guarantee = match binary.op() {
            Operator::Eq => Guarantee::In,
            Operator::NotEq => Guarantee::NotIn,
            _ => return None,
        };
        let lhs = binary.left().as_any();
        let rhs = binary.right().as_any();
        let (col, lit) = match (
            lhs.downcast_ref::<crate::expressions::Column>(),
            rhs.downcast_ref::<crate::expressions::Literal>(),
        ) {
            // col <op> literal
            (Some(col), Some(lit)) => (col, lit),
            // literal <op> col; bails out with `None` if that fails too
            _ => (
                rhs.downcast_ref::<crate::expressions::Column>()?,
                lhs.downcast_ref::<crate::expressions::Literal>()?,
            ),
        };
        Some(Self {
            col,
            guarantee,
            lit,
        })
    }
}
// Unit tests for `LiteralGuarantee::analyze`. Each case builds a logical
// expression, converts it to a physical expression against a fixed schema
// (`a: Utf8`, `b: Int32`, both non-nullable) and compares the derived
// guarantees with the expected list. The comparison is `Vec` equality, so the
// order of the guarantees matters; the literals inside each guarantee are a
// `HashSet`, so their order does not.
#[cfg(test)]
mod test {
use super::*;
use crate::create_physical_expr;
use crate::execution_props::ExecutionProps;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use datafusion_common::ToDFSchema;
use datafusion_expr::expr_fn::*;
use datafusion_expr::{lit, Expr};
use itertools::Itertools;
use std::sync::OnceLock;
// A predicate that references no column yields no guarantees.
#[test]
fn test_literal() {
test_analyze(lit(true), vec![])
}
// A single `=` / `!=` comparison, with the column on either side.
#[test]
fn test_single() {
test_analyze(col("a").eq(lit("foo")), vec![in_guarantee("a", ["foo"])]);
test_analyze(lit("foo").eq(col("a")), vec![in_guarantee("a", ["foo"])]);
test_analyze(
col("a").not_eq(lit("foo")),
vec![not_in_guarantee("a", ["foo"])],
);
test_analyze(
lit("foo").not_eq(col("a")),
vec![not_in_guarantee("a", ["foo"])],
);
}
#[test]
fn test_conjunction_single_column() {
// b = 1 AND b = 2: the `In` sets {1} and {2} do not intersect, so the
// guarantee is discarded.
test_analyze(col("b").eq(lit(1)).and(col("b").eq(lit(2))), vec![]);
// `In` and `NotIn` on the same column are tracked as separate guarantees,
// listed in the order they first appear.
test_analyze(
col("b").eq(lit(1)).and(col("b").not_eq(lit(2))),
vec![
in_guarantee("b", [1]),
not_in_guarantee("b", [2]),
],
);
test_analyze(
col("b").not_eq(lit(1)).and(col("b").eq(lit(2))),
vec![
not_in_guarantee("b", [1]),
in_guarantee("b", [2]),
],
);
// Repeated `!=` conjuncts accumulate into one `NotIn` set.
test_analyze(
col("b").not_eq(lit(1)).and(col("b").not_eq(lit(2))),
vec![not_in_guarantee("b", [1, 2])],
);
test_analyze(
col("b")
.not_eq(lit(1))
.and(col("b").not_eq(lit(2)))
.and(col("b").not_eq(lit(3))),
vec![not_in_guarantee("b", [1, 2, 3])],
);
// An interleaved `=` does not break up the `NotIn` accumulation.
test_analyze(
col("b")
.not_eq(lit(1))
.and(col("b").eq(lit(2)))
.and(col("b").not_eq(lit(3))),
vec![not_in_guarantee("b", [1, 3]), in_guarantee("b", [2])],
);
test_analyze(
col("b")
.not_eq(lit(1))
.and(col("b").not_eq(lit(2)))
.and(col("b").eq(lit(3))),
vec![not_in_guarantee("b", [1, 2]), in_guarantee("b", [3])],
);
// Unsupported operators such as `>` are ignored without affecting the rest.
test_analyze(
col("b")
.not_eq(lit(1))
.and(col("b").not_eq(lit(2)))
.and(col("b").gt(lit(3))),
vec![not_in_guarantee("b", [1, 2])],
);
}
#[test]
fn test_conjunction_multi_column() {
// Guarantees on different columns are independent.
test_analyze(
col("a").eq(lit("foo")).and(col("b").eq(lit(1))),
vec![
in_guarantee("a", ["foo"]),
in_guarantee("b", [1]),
],
);
test_analyze(
col("a").not_eq(lit("foo")).and(col("b").not_eq(lit(1))),
vec![not_in_guarantee("a", ["foo"]), not_in_guarantee("b", [1])],
);
// a = foo AND a = bar: empty intersection, no guarantee.
test_analyze(
col("a").eq(lit("foo")).and(col("a").eq(lit("bar"))),
vec![],
);
test_analyze(
col("a").eq(lit("foo")).and(col("a").not_eq(lit("bar"))),
vec![in_guarantee("a", ["foo"]), not_in_guarantee("a", ["bar"])],
);
test_analyze(
col("a").not_eq(lit("foo")).and(col("a").not_eq(lit("bar"))),
vec![not_in_guarantee("a", ["foo", "bar"])],
);
test_analyze(
col("a")
.not_eq(lit("foo"))
.and(col("a").not_eq(lit("bar")))
.and(col("a").not_eq(lit("baz"))),
vec![not_in_guarantee("a", ["foo", "bar", "baz"])],
);
// The same conjunct twice intersects with itself and keeps {foo}.
let expr = col("a").eq(lit("foo"));
test_analyze(expr.clone().and(expr), vec![in_guarantee("a", ["foo"])]);
// The `>` conjunct is ignored; b = 10 still yields In {10}.
test_analyze(
col("b").gt(lit(5)).and(col("b").eq(lit(10))),
vec![in_guarantee("b", [10])],
);
// b > 10 AND b = 10 can never be true, but the `=` conjunct alone still
// yields In {10}.
test_analyze(
col("b").gt(lit(10)).and(col("b").eq(lit(10))),
vec![
in_guarantee("b", [10]),
],
);
// An `OR` of `!=` terms contributes nothing, but does not invalidate the
// `a != foo` conjunct.
test_analyze(
col("a")
.not_eq(lit("foo"))
.and(col("a").not_eq(lit("bar")).or(col("a").not_eq(lit("baz")))),
vec![not_in_guarantee("a", ["foo"])],
);
}
#[test]
fn test_conjunction_and_disjunction_single_column() {
test_analyze(
col("b").not_eq(lit(1)).and(col("b").gt(lit(2))),
vec![
not_in_guarantee("b", [1]),
],
);
// b = 1 AND (b = 2 OR b = 3): {1} intersected with {2, 3} is empty.
test_analyze(
col("b")
.eq(lit(1))
.and(col("b").eq(lit(2)).or(col("b").eq(lit(3)))),
vec![
],
);
}
#[test]
fn test_disjunction_single_column() {
// An `OR` of `=` terms on one column becomes a single `In` guarantee.
test_analyze(
col("b").eq(lit(1)).or(col("b").eq(lit(2))),
vec![in_guarantee("b", [1, 2])],
);
// Any `!=` term in the disjunction means nothing can be inferred.
test_analyze(col("b").not_eq(lit(1)).or(col("b").eq(lit(2))), vec![]);
test_analyze(col("b").eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]);
test_analyze(col("b").not_eq(lit(1)).or(col("b").not_eq(lit(2))), vec![]);
// `lit("b").eq(lit(3))` compares two literals (not a column), so the whole
// disjunction is unsupported.
test_analyze(
col("b")
.not_eq(lit(1))
.or(col("b").not_eq(lit(2)))
.or(lit("b").eq(lit(3))),
vec![],
);
test_analyze(
col("b")
.eq(lit(1))
.or(col("b").eq(lit(2)))
.or(col("b").eq(lit(3))),
vec![in_guarantee("b", [1, 2, 3])],
);
test_analyze(
col("b")
.eq(lit(1))
.or(col("b").eq(lit(2)))
.or(lit("b").eq(lit(3))),
vec![],
);
}
#[test]
fn test_disjunction_multi_column() {
// Disjunctions spanning more than one column yield nothing.
test_analyze(
col("a").eq(lit("foo")).or(col("b").eq(lit(1))),
vec![],
);
test_analyze(
col("a").not_eq(lit("foo")).or(col("b").not_eq(lit(1))),
vec![],
);
test_analyze(
col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))),
vec![in_guarantee("a", ["foo", "bar"])],
);
// Duplicate literals collapse in the set.
test_analyze(
col("a").eq(lit("foo")).or(col("a").eq(lit("foo"))),
vec![in_guarantee("a", ["foo"])],
);
test_analyze(
col("a").not_eq(lit("foo")).or(col("a").not_eq(lit("bar"))),
vec![],
);
test_analyze(
col("a")
.eq(lit("foo"))
.or(col("a").eq(lit("bar")))
.or(col("a").eq(lit("baz"))),
vec![in_guarantee("a", ["foo", "bar", "baz"])],
);
// {foo, bar} intersected with {baz} is empty.
test_analyze(
(col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))))
.and(col("a").eq(lit("baz"))),
vec![],
);
test_analyze(
(col("a").eq(lit("foo")).or(col("a").eq(lit("bar"))))
.and(col("b").eq(lit(1))),
vec![in_guarantee("a", ["foo", "bar"]), in_guarantee("b", [1])],
);
test_analyze(
col("a")
.eq(lit("foo"))
.or(col("a").eq(lit("bar")))
.or(col("b").eq(lit(1))),
vec![],
);
}
// A lone `IN` / `NOT IN` list, including a larger list of 24 literals.
#[test]
fn test_single_inlist() {
test_analyze(
col("b").in_list(vec![lit(1), lit(2), lit(3)], false),
vec![in_guarantee("b", [1, 2, 3])],
);
test_analyze(
col("b").in_list(vec![lit(1), lit(2), lit(3)], true),
vec![not_in_guarantee("b", [1, 2, 3])],
);
test_analyze(
col("b").in_list((1..25).map(lit).collect_vec(), false),
vec![in_guarantee("b", 1..25)],
);
}
#[test]
fn test_inlist_conjunction() {
// IN AND IN: intersection.
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)),
vec![in_guarantee("b", [2, 3])],
);
// NOT IN AND IN: tracked separately.
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], true)
.and(col("b").in_list(vec![lit(2), lit(3), lit(4)], false)),
vec![
not_in_guarantee("b", [1, 2, 3]),
in_guarantee("b", [2, 3, 4]),
],
);
// NOT IN AND NOT IN: union.
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], true)
.and(col("b").in_list(vec![lit(2), lit(3), lit(4)], true)),
vec![not_in_guarantee("b", [1, 2, 3, 4])],
);
// IN (1, 2, 3) AND b = 4: empty intersection.
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.and(col("b").eq(lit(4))),
vec![],
);
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.and(col("b").eq(lit(2))),
vec![in_guarantee("b", [2])],
);
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.and(col("b").not_eq(lit(2))),
vec![in_guarantee("b", [1, 2, 3]), not_in_guarantee("b", [2])],
);
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], true)
.and(col("b").not_eq(lit(4))),
vec![not_in_guarantee("b", [1, 2, 3, 4])],
);
// The extra `!= 2` is already in the NOT IN set.
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], true)
.and(col("b").not_eq(lit(2))),
vec![not_in_guarantee("b", [1, 2, 3])],
);
}
#[test]
fn test_inlist_with_disjunction() {
// IN (1, 2, 3) AND (b = 3 OR b = 4): intersection {3}.
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))),
vec![in_guarantee("b", [3])],
);
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.and(col("b").eq(lit(4)).or(col("b").eq(lit(5)))),
vec![],
);
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], true)
.and(col("b").eq(lit(3)).or(col("b").eq(lit(4)))),
vec![not_in_guarantee("b", [1, 2, 3]), in_guarantee("b", [3, 4])],
);
// An IN list inside an `OR` is not a `col <op> literal` term, so the
// disjunction yields nothing.
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.or(col("b").eq(lit(2))),
vec![],
);
test_analyze(
col("b")
.in_list(vec![lit(1), lit(2), lit(3)], false)
.or(col("b").not_eq(lit(3))),
vec![],
);
}
// Plans `expr` against the shared schema, runs `analyze`, and asserts the
// result equals `expected`; the failure message includes both the logical and
// physical expressions.
fn test_analyze(expr: Expr, expected: Vec<LiteralGuarantee>) {
println!("Begin analyze of {expr}");
let schema = schema();
let physical_expr = logical2physical(&expr, &schema);
let actual = LiteralGuarantee::analyze(&physical_expr);
assert_eq!(
expected, actual,
"expr: {expr}\
\n\nexpected: {expected:#?}\
\n\nactual: {actual:#?}\
\n\nexpr: {expr:#?}\
\n\nphysical_expr: {physical_expr:#?}"
);
}
// Builds an expected `In` guarantee on `column` from values convertible to
// `ScalarValue`.
fn in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee
where
I: IntoIterator<Item = S>,
S: Into<ScalarValue> + 'a,
{
let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect();
LiteralGuarantee::try_new(column, Guarantee::In, literals.iter()).unwrap()
}
// Builds an expected `NotIn` guarantee on `column` from values convertible to
// `ScalarValue`.
fn not_in_guarantee<'a, I, S>(column: &str, literals: I) -> LiteralGuarantee
where
I: IntoIterator<Item = S>,
S: Into<ScalarValue> + 'a,
{
let literals: Vec<_> = literals.into_iter().map(|s| s.into()).collect();
LiteralGuarantee::try_new(column, Guarantee::NotIn, literals.iter()).unwrap()
}
// Converts a logical `Expr` into a physical expression for `schema`, panicking
// if planning fails.
fn logical2physical(expr: &Expr, schema: &Schema) -> Arc<dyn PhysicalExpr> {
let df_schema = schema.clone().to_dfschema().unwrap();
let execution_props = ExecutionProps::new();
create_physical_expr(expr, &df_schema, &execution_props).unwrap()
}
// Returns the shared test schema (`a: Utf8`, `b: Int32`), creating it on first
// use.
fn schema() -> SchemaRef {
SCHEMA
.get_or_init(|| {
Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false),
]))
})
.clone()
}
// Lazily initialized schema shared by all tests.
static SCHEMA: OnceLock<SchemaRef> = OnceLock::new();
}