use std::cell::RefCell;
use std::iter;
use itertools::Itertools;
use vortex_utils::aliases::hash_map::HashMap;
use super::relation::Relation;
use crate::dtype::Field;
use crate::dtype::FieldName;
use crate::dtype::FieldPath;
use crate::dtype::FieldPathSet;
use crate::expr::Expression;
use crate::expr::StatsCatalog;
use crate::expr::get_item;
use crate::expr::root;
use crate::expr::stats::Stat;
/// The statistics a pruning predicate depends on, as a relation from the field path a
/// statistic describes to the [`Stat`] kinds requested for it.
///
/// Produced by [`checked_pruning_expr`], which inserts one `(field path, stat)` pair for
/// every statistic that was requested while building the predicate.
pub type RequiredStats = Relation<FieldPath, Stat>;
/// A [`StatsCatalog`] that resolves every statistic to a top-level field of the root
/// expression and remembers which `(field path, stat)` pairs were requested.
///
/// The field name for a statistic comes from [`field_path_stat_field_name`].
#[derive(Default)]
pub(crate) struct TrackingStatsCatalog {
    // Every stat requested so far, mapped to the expression that was handed back for it.
    // Wrapped in a `RefCell` because `StatsCatalog::stats_ref` only receives `&self`.
    usage: RefCell<HashMap<(FieldPath, Stat), Expression>>,
}
impl TrackingStatsCatalog {
    /// Consumes the catalog and returns every `(field path, stat)` pair that was requested,
    /// each mapped to the expression that `stats_ref` returned for it.
    fn into_usages(self) -> HashMap<(FieldPath, Stat), Expression> {
        let Self { usage } = self;
        usage.into_inner()
    }
}
/// A [`StatsCatalog`] that only exposes statistics listed in `available_stats`.
///
/// A lookup for a listed statistic is delegated to `inner`, which also records it.
/// Any other lookup returns `None`, so unavailable statistics are never referenced.
struct ScopeStatsCatalog<'a> {
    // Resolves, and records the usage of, statistics that pass the availability check.
    inner: TrackingStatsCatalog,
    // Paths of the form `<field path>` extended with `<stat name>` that may be referenced.
    available_stats: &'a FieldPathSet,
}
impl StatsCatalog for ScopeStatsCatalog<'_> {
    /// Returns the inner catalog's expression for `stat` on `field_path`, but only when
    /// `field_path` extended with the stat's name is present in `available_stats`.
    fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
        let qualified_stat = field_path.clone().push(stat.name());
        if !self.available_stats.contains(&qualified_stat) {
            // Bail out before touching `inner`, so unavailable stats are never recorded.
            return None;
        }
        self.inner.stats_ref(field_path, stat)
    }
}
impl StatsCatalog for TrackingStatsCatalog {
    /// Resolves `stat` on `field_path` to `get_item(<name>, root())`, where `<name>` is
    /// produced by [`field_path_stat_field_name`], and records the request.
    ///
    /// Always returns `Some`.
    fn stats_ref(&self, field_path: &FieldPath, stat: Stat) -> Option<Expression> {
        let column_name = field_path_stat_field_name(field_path, stat);
        let stat_expr = get_item(column_name, root());
        let key = (field_path.clone(), stat);
        // Requesting the same stat again just overwrites the entry with an equivalent expression.
        self.usage.borrow_mut().insert(key, stat_expr.clone());
        Some(stat_expr)
    }
}
#[doc(hidden)]
pub fn field_path_stat_field_name(field_path: &FieldPath, stat: Stat) -> FieldName {
field_path
.parts()
.iter()
.map(|f| match f {
Field::Name(n) => n.as_ref(),
Field::ElementType => todo!("element type not currently handled"),
})
.chain(iter::once(stat.name()))
.join("_")
.into()
}
pub fn checked_pruning_expr(
expr: &Expression,
available_stats: &FieldPathSet,
) -> Option<(Expression, RequiredStats)> {
let catalog = ScopeStatsCatalog {
inner: Default::default(),
available_stats,
};
let expr = expr.stat_falsification(&catalog)?;
let mut relation: Relation<FieldPath, Stat> = Relation::new();
for ((field_path, stat), _) in catalog.inner.into_usages() {
relation.insert(field_path, stat)
}
Some((expr, relation))
}
#[cfg(test)]
mod tests {
    use rstest::fixture;
    use rstest::rstest;
    use vortex_utils::aliases::hash_set::HashSet;

    use super::HashMap;
    use crate::dtype::DType;
    use crate::dtype::FieldName;
    use crate::dtype::FieldNames;
    use crate::dtype::FieldPath;
    use crate::dtype::FieldPathSet;
    use crate::dtype::Nullability;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::between;
    use crate::expr::cast;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::gt_eq;
    use crate::expr::lit;
    use crate::expr::lt;
    use crate::expr::lt_eq;
    use crate::expr::not_eq;
    use crate::expr::or;
    use crate::expr::pruning::checked_pruning_expr;
    use crate::expr::pruning::field_path_stat_field_name;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::scalar_fn::fns::between::BetweenOptions;
    use crate::scalar_fn::fns::between::StrictComparison;

    /// Min and max statistics for the top-level columns `a` and `b`.
    #[fixture]
    fn available_stats() -> FieldPathSet {
        let field_a = FieldPath::from_name("a");
        let field_b = FieldPath::from_name("b");
        FieldPathSet::from_iter([
            field_a.clone().push(Stat::Min.name()),
            field_a.push(Stat::Max.name()),
            field_b.clone().push(Stat::Min.name()),
            field_b.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_equals(available_stats: FieldPathSet) {
        let name = FieldName::from("a");
        let literal_eq = lit(42);
        let eq_expr = eq(get_item("a", root()), literal_eq.clone());
        let (converted, _refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        // `a = 42` is falsified when the zone lies entirely above or entirely below 42.
        let expected_expr = or(
            gt(
                get_item(
                    field_path_stat_field_name(&FieldPath::from_name(name.clone()), Stat::Min),
                    root(),
                ),
                literal_eq.clone(),
            ),
            gt(
                literal_eq,
                col(field_path_stat_field_name(
                    &FieldPath::from_name(name),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let eq_expr = eq(col(column.clone()), col(other_col.clone()));
        let (converted, refs) = checked_pruning_expr(&eq_expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        // `a = b` is falsified when the two value ranges do not overlap.
        let expected_expr = or(
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            gt(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_not_equals_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let expr = not_eq(col(column.clone()), col(other_col.clone()));
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min, Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max, Stat::Min])
                )
            ])
        );
        let expected_expr = and(
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column.clone()),
                    Stat::Min,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col.clone()),
                    Stat::Max,
                )),
            ),
            eq(
                col(field_path_stat_field_name(
                    &FieldPath::from_name(column),
                    Stat::Max,
                )),
                col(field_path_stat_field_name(
                    &FieldPath::from_name(other_col),
                    Stat::Min,
                )),
            ),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_gt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let expr = gt(col(column.clone()), other_expr);
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        // Only max(a) and min(b) are needed to falsify `a > b`.
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Max])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Min])
                )
            ])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Min,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_gt_value(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let value = lit(42);
        let expr = gt(col(column.clone()), value.clone());
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name(column.clone()),
                HashSet::from_iter([Stat::Max])
            )])
        );
        let expected_expr = lt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Max,
            )),
            value,
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_lt_column(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let other_col = FieldName::from("b");
        let other_expr = col(other_col.clone());
        let expr = lt(col(column.clone()), other_expr);
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        // Only min(a) and max(b) are needed to falsify `a < b`.
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([
                (
                    FieldPath::from_name(column.clone()),
                    HashSet::from_iter([Stat::Min])
                ),
                (
                    FieldPath::from_name(other_col.clone()),
                    HashSet::from_iter([Stat::Max])
                )
            ])
        );
        let expected_expr = gt_eq(
            col(field_path_stat_field_name(
                &FieldPath::from_name(column),
                Stat::Min,
            )),
            col(field_path_stat_field_name(
                &FieldPath::from_name(other_col),
                Stat::Max,
            )),
        );
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_lt_value(available_stats: FieldPathSet) {
        let expr = lt(col("a"), lit(42));
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(FieldPath::from_name("a"), HashSet::from_iter([Stat::Min]))])
        );
        let expected_expr = gt_eq(col("a_min"), lit(42));
        assert_eq!(&converted, &expected_expr);
    }

    #[rstest]
    fn pruning_identity(available_stats: FieldPathSet) {
        let expr = or(lt(col("a"), lit(10)), gt(col("a"), lit(50)));
        let (predicate, _) = checked_pruning_expr(&expr, &available_stats).unwrap();
        let expected_expr = and(gt_eq(col("a_min"), lit(10)), lt_eq(col("a_max"), lit(50)));
        assert_eq!(&predicate.to_string(), &expected_expr.to_string());
    }

    #[rstest]
    fn pruning_and_or_operators(available_stats: FieldPathSet) {
        let column = FieldName::from("a");
        let and_expr = and(gt(col(column.clone()), lit(10)), lt(col(column), lit(50)));
        let (predicate, _) = checked_pruning_expr(&and_expr, &available_stats).unwrap();
        // A conjunction is falsified when either of its sides is falsified.
        assert_eq!(
            &predicate,
            &or(
                lt_eq(col(FieldName::from("a_max")), lit(10)),
                gt_eq(col(FieldName::from("a_min")), lit(50)),
            ),
        );
    }

    #[rstest]
    fn test_gt_eq_with_booleans(available_stats: FieldPathSet) {
        // None of `x`, `y`, `z` has statistics in the fixture, so no predicate can be built.
        let expr = gt_eq(col("x"), gt(col("y"), col("z")));
        assert!(checked_pruning_expr(&expr, &available_stats).is_none());
    }

    /// Min/max for `float_col` and `int_col`, plus a NaN count for `float_col` only.
    #[fixture]
    fn available_stats_with_nans() -> FieldPathSet {
        let float_col = FieldPath::from_name("float_col");
        let int_col = FieldPath::from_name("int_col");
        FieldPathSet::from_iter([
            float_col.clone().push(Stat::Min.name()),
            float_col.clone().push(Stat::Max.name()),
            float_col.push(Stat::NaNCount.name()),
            int_col.clone().push(Stat::Min.name()),
            int_col.push(Stat::Max.name()),
        ])
    }

    #[rstest]
    fn pruning_checks_nans(available_stats_with_nans: FieldPathSet) {
        let expr = gt_eq(col("float_col"), lit(f32::NAN));
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        // The float comparison is guarded on the zone having no NaNs. The second conjunct
        // appears to compare the literal's own NaN count (1 for NaN, 0 otherwise) against 0.
        assert_eq!(
            &converted,
            &and(
                and(
                    eq(col("float_col_nan_count"), lit(0u64)),
                    eq(lit(1u64), lit(0u64)),
                ),
                lt(col("float_col_max"), lit(f32::NAN)),
            )
        );

        let expr = and(
            gt(col("float_col"), lit(10f32)),
            lt(col("int_col"), lit(10)),
        );
        let (converted, _) = checked_pruning_expr(&expr, &available_stats_with_nans).unwrap();
        // Only the float column gets a NaN guard; the integer side is a plain min check.
        assert_eq!(
            &converted,
            &or(
                and(
                    and(
                        eq(col("float_col_nan_count"), lit(0u64)),
                        eq(lit(0u64), lit(0u64)),
                    ),
                    lt_eq(col("float_col_max"), lit(10f32)),
                ),
                gt_eq(col("int_col_min"), lit(10)),
            )
        )
    }

    #[rstest]
    fn pruning_between(available_stats: FieldPathSet) {
        let expr = between(
            col("a"),
            lit(10),
            lit(50),
            BetweenOptions {
                lower_strict: StrictComparison::NonStrict,
                upper_strict: StrictComparison::NonStrict,
            },
        );
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        // `10 <= a <= 50` is falsified when the zone is entirely below 10 or above 50.
        assert_eq!(
            &converted,
            &or(gt(lit(10), col("a_max")), gt(col("a_min"), lit(50)))
        );
    }

    #[rstest]
    fn pruning_cast_get_item_eq(available_stats: FieldPathSet) {
        let struct_dtype = DType::Struct(
            StructFields::new(
                FieldNames::from([FieldName::from("a"), FieldName::from("b")]),
                vec![
                    DType::Utf8(Nullability::Nullable),
                    DType::Utf8(Nullability::Nullable),
                ],
            ),
            Nullability::NonNullable,
        );
        // A field read through a cast of the root still resolves to the stats of `a`.
        let expr = eq(get_item("a", cast(root(), struct_dtype)), lit("value"));
        let (converted, refs) = checked_pruning_expr(&expr, &available_stats).unwrap();
        assert_eq!(
            refs.map(),
            &HashMap::from_iter([(
                FieldPath::from_name("a"),
                HashSet::from_iter([Stat::Min, Stat::Max])
            )])
        );
        assert_eq!(
            &converted,
            &or(
                gt(col("a_min"), lit("value")),
                gt(lit("value"), col("a_max"))
            )
        );
    }
}