use vortex_error::VortexResult;
use crate::aggregate_fn::AggregateFnRef;
use crate::dtype::DType;
use crate::expr::Expression;
use crate::expr::lit;
use crate::expr::traversal::NodeExt;
use crate::expr::traversal::Transformed;
use crate::scalar::Scalar;
use crate::scalar_fn::fns::stat::StatFn;
/// Resolves statistic expressions ([`StatFn`] nodes) into concrete expressions.
///
/// [`bind_stats`] walks a predicate and consults an implementor of this trait for every
/// stat it finds. An implementor might, for example, map a stat to a column holding a
/// precomputed value.
pub trait StatBinder {
    /// The dtype against which each stat expression's return dtype is computed.
    fn scope(&self) -> &DType;

    /// Attempts to bind `aggregate_fn` applied to `input`.
    ///
    /// `stat_dtype` is the return dtype of the stat expression being bound, computed
    /// against [`StatBinder::scope`]. Return `Ok(None)` when this binder cannot supply the
    /// stat; the caller then substitutes [`StatBinder::missing_stat`].
    fn bind_aggregate(
        &self,
        input: &Expression,
        aggregate_fn: &AggregateFnRef,
        stat_dtype: &DType,
    ) -> VortexResult<Option<Expression>>;

    /// Produces the replacement for a stat that [`StatBinder::bind_aggregate`] could not bind.
    ///
    /// The default implementation yields a null literal of `dtype` made nullable.
    fn missing_stat(&self, dtype: DType) -> VortexResult<Expression> {
        let placeholder = null_expr(dtype);
        Ok(placeholder)
    }
}
/// Rewrites every [`StatFn`] node in `predicate` into an expression supplied by `binder`.
///
/// For each stat node, the stat's return dtype is computed against [`StatBinder::scope`]
/// and [`StatBinder::bind_aggregate`] is asked for a replacement. When it returns `None`,
/// the node is replaced with [`StatBinder::missing_stat`] for that dtype instead. Nodes
/// that are not stats are returned unchanged by the rewrite closure. The resulting
/// expression is not simplified afterwards.
///
/// # Errors
///
/// Propagates errors from computing a stat's return dtype, from the binder's callbacks,
/// and from the tree traversal.
pub fn bind_stats<B: StatBinder + ?Sized>(
    predicate: Expression,
    binder: &B,
) -> VortexResult<Expression> {
    // A shared borrow is enough: the closure below already borrows `binder`, so there is
    // no need to clone the scope dtype.
    let scope = binder.scope();
    Ok(predicate
        .transform_down(|expr| {
            if !expr.is::<StatFn>() {
                return Ok(Transformed::no(expr));
            }
            // Compute the stat's dtype once; it is needed both to bind the stat and, when
            // binding fails, to build the missing-stat replacement.
            let stat_dtype = expr.return_dtype(scope)?;
            let replacement = match bind_stat_fn(&expr, &stat_dtype, binder)? {
                Some(bound) => bound,
                None => binder.missing_stat(stat_dtype)?,
            };
            Ok(Transformed::yes(replacement))
        })?
        .into_inner())
}

/// Asks `binder` to bind a single [`StatFn`] expression.
///
/// `expr` is expected to be a `StatFn` node (the caller checks this). Its first child is
/// passed as the aggregate input, together with the node's aggregate function and its
/// precomputed return dtype `stat_dtype`.
fn bind_stat_fn(
    expr: &Expression,
    stat_dtype: &DType,
    binder: &(impl StatBinder + ?Sized),
) -> VortexResult<Option<Expression>> {
    let options = expr.as_::<StatFn>();
    binder.bind_aggregate(expr.child(0), options.aggregate_fn(), stat_dtype)
}
/// Builds a literal null whose dtype is the nullable variant of `dtype`.
fn null_expr(dtype: DType) -> Expression {
    let nullable_dtype = dtype.as_nullable();
    let null_scalar = Scalar::null(nullable_dtype);
    lit(null_scalar)
}
#[cfg(test)]
mod tests {
    use vortex_error::VortexResult;

    use super::*;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::expr::and;
    use crate::expr::col;
    use crate::expr::get_item;
    use crate::expr::is_null;
    use crate::expr::or;
    use crate::expr::root;
    use crate::expr::stats::Stat;
    use crate::stats::all_non_nan;
    use crate::stats::nan_count;

    /// Binder over a struct with one non-nullable `f32` field `f`.
    ///
    /// When `bind_nan_count` is set, `NaNCount` requests are served from a
    /// `f_nan_count` field; every other request is left unbound.
    struct TestBinder {
        input_scope: DType,
        bind_nan_count: bool,
    }

    impl TestBinder {
        fn new(bind_nan_count: bool) -> Self {
            let f_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
            let input_scope = DType::Struct(
                StructFields::from_iter([("f", f_dtype)]),
                Nullability::NonNullable,
            );
            Self {
                input_scope,
                bind_nan_count,
            }
        }
    }

    impl StatBinder for TestBinder {
        fn scope(&self) -> &DType {
            &self.input_scope
        }

        fn bind_aggregate(
            &self,
            _input: &Expression,
            aggregate_fn: &AggregateFnRef,
            _stat_dtype: &DType,
        ) -> VortexResult<Option<Expression>> {
            // Only NaNCount has a backing field, and only when the binder is configured for it.
            let requested = Stat::from_aggregate_fn(aggregate_fn);
            let serves_request = requested == Some(Stat::NaNCount) && self.bind_nan_count;
            Ok(serves_request.then(|| get_item("f_nan_count", root())))
        }
    }

    /// The nullable-bool null literal that unbound boolean stats are expected to become.
    fn null_bool() -> Expression {
        lit(Scalar::null(DType::Bool(Nullability::Nullable)))
    }

    #[test]
    fn nan_count_binds_to_direct_stat_slot() -> VortexResult<()> {
        let bound = bind_stats(nan_count(col("f")), &TestBinder::new(true))?;
        assert_eq!(bound, col("f_nan_count"));
        Ok(())
    }

    #[test]
    fn all_non_nan_does_not_derive_from_nan_count() -> VortexResult<()> {
        // The binder only serves NaNCount, so all_non_nan must not be derived from it.
        let bound = bind_stats(all_non_nan(col("f")), &TestBinder::new(true))?;
        assert_eq!(bound, null_bool());
        Ok(())
    }

    #[test]
    fn missing_stats_bind_to_null_without_reducing() -> VortexResult<()> {
        let binder = TestBinder::new(false);

        // The surrounding boolean operators must be kept as-is, not folded.
        let conjunction = bind_stats(and(lit(false), all_non_nan(col("f"))), &binder)?;
        assert_eq!(conjunction, and(lit(false), null_bool()));

        let disjunction = bind_stats(or(lit(true), all_non_nan(col("f"))), &binder)?;
        assert_eq!(disjunction, or(lit(true), null_bool()));
        Ok(())
    }

    #[test]
    fn unrelated_expressions_do_not_request_nan_count() -> VortexResult<()> {
        let expr = is_null(col("f"));
        let bound = bind_stats(expr.clone(), &TestBinder::new(false))?;
        assert_eq!(bound, expr);
        Ok(())
    }
}