mod kernel;
use std::fmt::Formatter;
pub use kernel::*;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_session::VortexSession;
use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrays::bool::BoolArrayExt;
use crate::arrays::masked::mask_validity_canonical;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::expr::Expression;
use crate::expr::and;
use crate::expr::lit;
use crate::scalar::Scalar;
use crate::scalar_fn::Arity;
use crate::scalar_fn::ChildName;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::ExecutionArgs;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
use crate::scalar_fn::SimplifyCtx;
use crate::scalar_fn::fns::literal::Literal;
#[derive(Clone)]
pub struct Mask;
impl ScalarFnVTable for Mask {
type Options = EmptyOptions;
fn id(&self) -> ScalarFnId {
ScalarFnId::from("vortex.mask")
}
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
Ok(Some(vec![]))
}
fn deserialize(
&self,
_metadata: &[u8],
_session: &VortexSession,
) -> VortexResult<Self::Options> {
Ok(EmptyOptions)
}
fn arity(&self, _options: &Self::Options) -> Arity {
Arity::Exact(2)
}
fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName {
match child_idx {
0 => ChildName::from("input"),
1 => ChildName::from("mask"),
_ => unreachable!("Invalid child index {} for Mask expression", child_idx),
}
}
fn fmt_sql(
&self,
_options: &Self::Options,
expr: &Expression,
f: &mut Formatter<'_>,
) -> std::fmt::Result {
write!(f, "mask(")?;
expr.child(0).fmt_sql(f)?;
write!(f, ", ")?;
expr.child(1).fmt_sql(f)?;
write!(f, ")")
}
fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
vortex_ensure!(
arg_dtypes[1] == DType::Bool(Nullability::NonNullable),
"The mask argument to 'mask' must be a non-nullable boolean array, got {}",
arg_dtypes[1]
);
Ok(arg_dtypes[0].as_nullable())
}
fn execute(
&self,
_options: &Self::Options,
args: &dyn ExecutionArgs,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let input = args.get(0)?;
let mask_array = args.get(1)?;
if let Some(result) = execute_constant(&input, &mask_array)? {
return Ok(result);
}
execute_canonical(input, mask_array, ctx)
}
fn simplify(
&self,
_options: &Self::Options,
expr: &Expression,
ctx: &dyn SimplifyCtx,
) -> VortexResult<Option<Expression>> {
let Some(mask_lit) = expr.child(1).as_opt::<Literal>() else {
return Ok(None);
};
let mask_lit = mask_lit
.as_bool()
.value()
.vortex_expect("Mask must be non-nullable");
if mask_lit {
Ok(Some(expr.child(0).clone()))
} else {
let input_dtype = ctx.return_dtype(expr.child(0))?;
Ok(Some(lit(Scalar::null(input_dtype.as_nullable()))))
}
}
fn validity(
&self,
_options: &Self::Options,
expression: &Expression,
) -> VortexResult<Option<Expression>> {
Ok(Some(and(
expression.child(0).validity()?,
expression.child(1).clone(),
)))
}
}
fn execute_constant(input: &ArrayRef, mask_array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
let len = input.len();
if let Some(constant_mask) = mask_array.as_opt::<Constant>() {
let mask_value = constant_mask.scalar().as_bool().value().unwrap_or(false);
return if mask_value {
input.cast(input.dtype().as_nullable()).map(Some)
} else {
Ok(Some(
ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
))
};
}
if let Some(constant_input) = input.as_opt::<Constant>()
&& constant_input.scalar().is_null()
{
return Ok(Some(
ConstantArray::new(Scalar::null(input.dtype().as_nullable()), len).into_array(),
));
}
Ok(None)
}
fn execute_canonical(
input: ArrayRef,
mask_array: ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let mask_bool = mask_array.execute::<BoolArray>(ctx)?;
let validity_mask = vortex_mask::Mask::from(mask_bool.to_bit_buffer());
let canonical = input.execute::<Canonical>(ctx)?;
Ok(mask_validity_canonical(canonical, &validity_mask, ctx)?.into_array())
}
#[cfg(test)]
mod test {
use vortex_error::VortexExpect;
use crate::dtype::DType;
use crate::dtype::Nullability::Nullable;
use crate::dtype::PType;
use crate::expr::lit;
use crate::expr::mask;
use crate::scalar::Scalar;
#[test]
fn test_simplify() {
let input_expr = lit(42u32);
let true_mask_expr = lit(true);
let false_mask_expr = lit(false);
let mask_true_expr = mask(input_expr.clone(), true_mask_expr);
let simplified_true = mask_true_expr
.optimize(&DType::Null)
.vortex_expect("Simplification");
assert_eq!(&simplified_true, &input_expr);
let mask_false_expr = mask(input_expr, false_mask_expr);
let simplified_false = mask_false_expr
.optimize(&DType::Null)
.vortex_expect("Simplification");
let expected_null_expr = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
assert_eq!(&simplified_false, &expected_null_expr);
}
}