use std::fmt;
use std::fmt::Formatter;
use std::hash::Hash;
use std::sync::Arc;
use prost::Message;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_mask::AllOr;
use vortex_mask::Mask;
use vortex_proto::expr as pb;
use vortex_session::VortexSession;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::ConstantArray;
use crate::arrays::bool::BoolArrayExt;
use crate::builders::ArrayBuilder;
use crate::builders::builder_with_capacity;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::expr::Expression;
use crate::scalar::Scalar;
use crate::scalar_fn::Arity;
use crate::scalar_fn::ChildName;
use crate::scalar_fn::ExecutionArgs;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
use crate::scalar_fn::fns::zip::zip_impl;
/// Options for the [`CaseWhen`] scalar function.
///
/// The expression's children are laid out as
/// `[when_0, then_0, when_1, then_1, ..., else?]`, where the trailing `else`
/// child is present only when `has_else` is `true`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct CaseWhenOptions {
    /// Number of `WHEN <condition> THEN <value>` pairs.
    pub num_when_then_pairs: u32,
    /// Whether an `ELSE` child follows the pairs.
    pub has_else: bool,
}

impl CaseWhenOptions {
    /// Total number of child expressions: two per pair, plus one for `ELSE`.
    pub fn num_children(&self) -> usize {
        let pair_children = 2 * self.num_when_then_pairs as usize;
        pair_children + usize::from(self.has_else)
    }
}

impl fmt::Display for CaseWhenOptions {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let Self {
            num_when_then_pairs,
            has_else,
        } = self;
        write!(
            f,
            "case_when(pairs={}, else={})",
            num_when_then_pairs, has_else
        )
    }
}
#[derive(Clone)]
pub struct CaseWhen;
impl ScalarFnVTable for CaseWhen {
type Options = CaseWhenOptions;
fn id(&self) -> ScalarFnId {
ScalarFnId::from("vortex.case_when")
}
fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
vortex_bail!("cannot serialize")
}
fn deserialize(
&self,
metadata: &[u8],
_session: &VortexSession,
) -> VortexResult<Self::Options> {
let opts = pb::CaseWhenOpts::decode(metadata)?;
if opts.num_children < 2 {
vortex_bail!(
"CaseWhen expects at least 2 children, got {}",
opts.num_children
);
}
Ok(CaseWhenOptions {
num_when_then_pairs: opts.num_children / 2,
has_else: opts.num_children % 2 == 1,
})
}
fn arity(&self, options: &Self::Options) -> Arity {
Arity::Exact(options.num_children())
}
fn child_name(&self, options: &Self::Options, child_idx: usize) -> ChildName {
let num_pair_children = options.num_when_then_pairs as usize * 2;
if child_idx < num_pair_children {
let pair_idx = child_idx / 2;
if child_idx.is_multiple_of(2) {
ChildName::from(Arc::from(format!("when_{pair_idx}")))
} else {
ChildName::from(Arc::from(format!("then_{pair_idx}")))
}
} else if options.has_else && child_idx == num_pair_children {
ChildName::from("else")
} else {
unreachable!("Invalid child index {} for CaseWhen", child_idx)
}
}
fn fmt_sql(
&self,
options: &Self::Options,
expr: &Expression,
f: &mut Formatter<'_>,
) -> fmt::Result {
write!(f, "CASE")?;
for i in 0..options.num_when_then_pairs as usize {
write!(
f,
" WHEN {} THEN {}",
expr.child(i * 2),
expr.child(i * 2 + 1)
)?;
}
if options.has_else {
let else_idx = options.num_when_then_pairs as usize * 2;
write!(f, " ELSE {}", expr.child(else_idx))?;
}
write!(f, " END")
}
fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
if options.num_when_then_pairs == 0 {
vortex_bail!("CaseWhen must have at least one WHEN/THEN pair");
}
let expected_len = options.num_children();
if arg_dtypes.len() != expected_len {
vortex_bail!(
"CaseWhen expects {expected_len} argument dtypes, got {}",
arg_dtypes.len()
);
}
let first_then = &arg_dtypes[1];
let mut result_dtype = first_then.clone();
for i in 1..options.num_when_then_pairs as usize {
let then_i = &arg_dtypes[i * 2 + 1];
if !first_then.eq_ignore_nullability(then_i) {
vortex_bail!(
"CaseWhen THEN dtypes must match (ignoring nullability), got {} and {}",
first_then,
then_i
);
}
result_dtype = result_dtype.union_nullability(then_i.nullability());
}
if options.has_else {
let else_dtype = &arg_dtypes[options.num_when_then_pairs as usize * 2];
if !result_dtype.eq_ignore_nullability(else_dtype) {
vortex_bail!(
"CaseWhen THEN and ELSE dtypes must match (ignoring nullability), got {} and {}",
first_then,
else_dtype
);
}
result_dtype = result_dtype.union_nullability(else_dtype.nullability());
} else {
result_dtype = result_dtype.as_nullable();
}
Ok(result_dtype)
}
fn execute(
&self,
options: &Self::Options,
args: &dyn ExecutionArgs,
ctx: &mut ExecutionCtx,
) -> VortexResult<ArrayRef> {
let row_count = args.row_count();
let num_pairs = options.num_when_then_pairs as usize;
let mut remaining = Mask::new_true(row_count);
let mut branches: Vec<(Mask, ArrayRef)> = Vec::with_capacity(num_pairs);
for i in 0..num_pairs {
if remaining.all_false() {
break;
}
let condition = args.get(i * 2)?;
let cond_bool = condition.execute::<BoolArray>(ctx)?;
let cond_mask = cond_bool.to_mask_fill_null_false();
let effective_mask = &remaining & &cond_mask;
if effective_mask.all_false() {
continue;
}
let then_value = args.get(i * 2 + 1)?;
remaining = remaining.bitand_not(&cond_mask);
branches.push((effective_mask, then_value));
}
let else_value: ArrayRef = if options.has_else {
args.get(num_pairs * 2)?
} else {
let then_dtype = args.get(1)?.dtype().as_nullable();
ConstantArray::new(Scalar::null(then_dtype), row_count).into_array()
};
if branches.is_empty() {
return Ok(else_value);
}
merge_case_branches(branches, else_value)
}
fn is_null_sensitive(&self, _options: &Self::Options) -> bool {
true
}
fn is_fallible(&self, _options: &Self::Options) -> bool {
false
}
}
/// Crossover for choosing the merge strategy. When there are more spans than
/// `len / SLICE_CROSSOVER_RUN_LEN` (i.e. runs average fewer than roughly this many
/// rows), rows are copied one at a time. Otherwise whole runs are sliced and copied.
const SLICE_CROSSOVER_RUN_LEN: usize = 4;

/// Combines disjoint `(mask, then_value)` branches over `else_value`.
///
/// Each branch supplies the rows selected by its mask, and all other rows come from
/// `else_value`. The output nullability is the union of `else_value` and every
/// branch array.
fn merge_case_branches(
    branches: Vec<(Mask, ArrayRef)>,
    else_value: ArrayRef,
) -> VortexResult<ArrayRef> {
    // A single branch is just a zip against the ELSE value.
    if let [(mask, then_value)] = branches.as_slice() {
        return zip_impl(then_value, &else_value, mask);
    }

    let mut output_nullability = else_value.dtype().nullability();
    for (_, branch) in &branches {
        output_nullability = output_nullability | branch.dtype().nullability();
    }
    let output_dtype = else_value.dtype().with_nullability(output_nullability);

    let branch_arrays: Vec<&ArrayRef> = branches.iter().map(|(_, array)| array).collect();

    // Flatten every branch mask into `(start, end, branch_index)` runs.
    let mut spans: Vec<(usize, usize, usize)> = Vec::new();
    for (idx, (mask, _)) in branches.iter().enumerate() {
        let slices = match mask.slices() {
            AllOr::All => return branch_arrays[idx].cast(output_dtype),
            AllOr::None => continue,
            AllOr::Some(slices) => slices,
        };
        spans.extend(slices.iter().map(|&(start, end)| (start, end, idx)));
    }

    if spans.is_empty() {
        return else_value.cast(output_dtype);
    }
    spans.sort_unstable_by_key(|&(start, _, _)| start);

    let total_rows = else_value.len();
    let builder = builder_with_capacity(&output_dtype, total_rows);
    let runs_are_short = spans.len() > total_rows / SLICE_CROSSOVER_RUN_LEN;
    if runs_are_short {
        merge_row_by_row(&branch_arrays, &else_value, &spans, &output_dtype, builder)
    } else {
        merge_run_by_run(&branch_arrays, &else_value, &spans, &output_dtype, builder)
    }
}
/// Builds the output one scalar at a time. This is used when the branch runs are
/// too short for slicing to pay off.
///
/// `spans` must be sorted by start and non-overlapping. Rows outside every span
/// come from `else_value`, and every value is cast to `output_dtype` before it is
/// appended.
fn merge_row_by_row(
    branch_arrays: &[&ArrayRef],
    else_value: &ArrayRef,
    spans: &[(usize, usize, usize)],
    output_dtype: &DType,
    mut builder: Box<dyn ArrayBuilder>,
) -> VortexResult<ArrayRef> {
    // Appends the rows `rows` of `source`, cast to the output dtype.
    let mut append_rows =
        |source: &ArrayRef, rows: std::ops::Range<usize>| -> VortexResult<()> {
            for row in rows {
                let value = source.scalar_at(row)?.cast(output_dtype)?;
                builder.append_scalar(&value)?;
            }
            Ok(())
        };

    let mut cursor = 0;
    for &(start, end, branch_idx) in spans {
        append_rows(else_value, cursor..start)?;
        append_rows(branch_arrays[branch_idx], start..end)?;
        cursor = end;
    }
    append_rows(else_value, cursor..else_value.len())?;

    Ok(builder.finish())
}
/// Builds the output by copying whole runs. This is used when the branch runs are
/// long enough that slicing beats per-row copying.
///
/// `spans` must be sorted by start and non-overlapping. Gaps between spans, and the
/// tail after the last span, are filled from `else_value`. Every copied run is cast
/// to `output_dtype`.
fn merge_run_by_run(
    branch_arrays: &[&ArrayRef],
    else_value: &ArrayRef,
    spans: &[(usize, usize, usize)],
    output_dtype: &DType,
    mut builder: Box<dyn ArrayBuilder>,
) -> VortexResult<ArrayRef> {
    let else_value = else_value.cast(output_dtype.clone())?;
    let len = else_value.len();

    for &(start, end, branch_idx) in spans {
        // Fill the gap before this span from ELSE.
        if builder.len() < start {
            builder.extend_from_array(&else_value.slice(builder.len()..start)?);
        }
        // Slice first, then cast. Only this span's rows are converted, instead of
        // re-casting the full branch array once for every span it contributes.
        let run = branch_arrays[branch_idx]
            .slice(start..end)?
            .cast(output_dtype.clone())?;
        builder.extend_from_array(&run);
    }

    // Trailing rows after the last span.
    if builder.len() < len {
        builder.extend_from_array(&else_value.slice(builder.len()..len)?);
    }

    Ok(builder.finish())
}
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_buffer::buffer;
    use vortex_error::VortexExpect as _;
    use vortex_session::VortexSession;

    use super::*;
    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute as _;
    use crate::arrays::BoolArray;
    use crate::arrays::PrimitiveArray;
    use crate::arrays::StructArray;
    use crate::assert_arrays_eq;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::expr::case_when;
    use crate::expr::case_when_no_else;
    use crate::expr::col;
    use crate::expr::eq;
    use crate::expr::get_item;
    use crate::expr::gt;
    use crate::expr::lit;
    use crate::expr::nested_case_when;
    use crate::expr::root;
    use crate::expr::test_harness;
    use crate::scalar::Scalar;
    use crate::session::ArraySession;

    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Applies `expr` to `array` and executes the result to a canonical array.
    fn evaluate_expr(expr: &Expression, array: &ArrayRef) -> ArrayRef {
        let mut ctx = SESSION.create_execution_ctx();
        array
            .clone()
            .apply(expr)
            .unwrap()
            .execute::<Canonical>(&mut ctx)
            .unwrap()
            .into_array()
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    fn test_display_with_else() {
        let expr = case_when(gt(col("value"), lit(0i32)), lit(100i32), lit(0i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_no_else() {
        let expr = case_when_no_else(gt(col("value"), lit(0i32)), lit(100i32));
        let display = format!("{}", expr);
        assert!(display.contains("CASE"));
        assert!(display.contains("WHEN"));
        assert!(display.contains("THEN"));
        assert!(!display.contains("ELSE"));
        assert!(display.contains("END"));
    }

    #[test]
    fn test_display_nested_nary() {
        let expr = nested_case_when(
            vec![
                (gt(col("x"), lit(10i32)), lit("high")),
                (gt(col("x"), lit(5i32)), lit("medium")),
            ],
            Some(lit("low")),
        );
        let display = format!("{}", expr);
        assert_eq!(display.matches("CASE").count(), 1);
        assert_eq!(display.matches("WHEN").count(), 2);
        assert_eq!(display.matches("THEN").count(), 2);
    }

    #[test]
    fn test_return_dtype_with_else() {
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_return_dtype_with_nullable_else() {
        let expr = case_when(
            lit(true),
            lit(100i32),
            lit(Scalar::null(DType::Primitive(
                PType::I32,
                Nullability::Nullable,
            ))),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_return_dtype_without_else_is_nullable() {
        let expr = case_when_no_else(lit(true), lit(100i32));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result_dtype = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::Nullable)
        );
    }

    #[test]
    fn test_return_dtype_with_struct_input() {
        let dtype = test_harness::struct_dtype();
        let expr = case_when(
            gt(get_item("col1", root()), lit(10u16)),
            lit(100i32),
            lit(0i32),
        );
        let result_dtype = expr.return_dtype(&dtype).unwrap();
        assert_eq!(
            result_dtype,
            DType::Primitive(PType::I32, Nullability::NonNullable)
        );
    }

    #[test]
    fn test_return_dtype_mismatched_then_else_errors() {
        let expr = case_when(lit(true), lit(100i32), lit("zero"));
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(
            err.to_string()
                .contains("THEN and ELSE dtypes must match (ignoring nullability)")
        );
    }

    #[test]
    fn test_arity_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(3));
    }

    #[test]
    fn test_arity_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(2));
    }

    #[test]
    fn test_child_names() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 1,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "else");
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    #[should_panic(expected = "cannot serialize")]
    fn test_serialization_roundtrip_nary_no_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 4,
            has_else: false,
        };
        let serialized = CaseWhen.serialize(&options).unwrap().unwrap();
        let deserialized = CaseWhen
            .deserialize(&serialized, &VortexSession::empty())
            .unwrap();
        assert_eq!(options, deserialized);
    }

    #[test]
    fn test_arity_nary_with_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(7));
    }

    #[test]
    fn test_arity_nary_without_else() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: false,
        };
        assert_eq!(CaseWhen.arity(&options), Arity::Exact(6));
    }

    #[test]
    fn test_child_names_nary() {
        let options = CaseWhenOptions {
            num_when_then_pairs: 3,
            has_else: true,
        };
        assert_eq!(CaseWhen.child_name(&options, 0).to_string(), "when_0");
        assert_eq!(CaseWhen.child_name(&options, 1).to_string(), "then_0");
        assert_eq!(CaseWhen.child_name(&options, 2).to_string(), "when_1");
        assert_eq!(CaseWhen.child_name(&options, 3).to_string(), "then_1");
        assert_eq!(CaseWhen.child_name(&options, 4).to_string(), "when_2");
        assert_eq!(CaseWhen.child_name(&options, 5).to_string(), "then_2");
        assert_eq!(CaseWhen.child_name(&options, 6).to_string(), "else");
    }

    #[test]
    fn test_return_dtype_nary_mismatched_then_types_errors() {
        let expr = nested_case_when(
            vec![(lit(true), lit(100i32)), (lit(false), lit("oops"))],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let err = expr.return_dtype(&input_dtype).unwrap_err();
        assert!(err.to_string().contains("THEN dtypes must match"));
    }

    #[test]
    fn test_return_dtype_nary_mixed_nullability() {
        let non_null_then = lit(100i32);
        let nullable_then = lit(Scalar::null(DType::Primitive(
            PType::I32,
            Nullability::Nullable,
        )));
        let expr = nested_case_when(
            vec![(lit(true), non_null_then), (lit(false), nullable_then)],
            Some(lit(0i32)),
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    #[test]
    fn test_return_dtype_nary_no_else_is_nullable() {
        let expr = nested_case_when(
            vec![(lit(true), lit(10i32)), (lit(false), lit(20i32))],
            None,
        );
        let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable);
        let result = expr.return_dtype(&input_dtype).unwrap();
        assert_eq!(result, DType::Primitive(PType::I32, Nullability::Nullable));
    }

    #[test]
    fn test_replace_children() {
        let expr = case_when(lit(true), lit(1i32), lit(0i32));
        expr.with_children([lit(false), lit(2i32), lit(3i32)])
            .vortex_expect("operation should succeed in test");
    }

    #[test]
    fn test_evaluate_simple_condition() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(100i32),
            lit(0i32),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_multiple_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 0, 30, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_first_match_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(3i32)), lit(200i32)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = case_when_no_else(gt(get_item("value", root()), lit(3i32)), lit(100i32));
        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None, Some(100), Some(100)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_all_conditions_false() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = case_when(
            gt(get_item("value", root()), lit(100i32)),
            lit(1i32),
            lit(0i32),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_all_conditions_true() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = case_when(
            gt(get_item("value", root()), lit(0i32)),
            lit(100i32),
            lit(0i32),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_all_true_no_else_returns_correct_dtype() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();
        let expr = case_when_no_else(gt(get_item("value", root()), lit(0i32)), lit(100i32));
        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(100i32), Some(100), Some(100)]).into_array()
        );
    }

    #[test]
    fn test_merge_case_branches_widens_nullability_of_later_branch() -> VortexResult<()> {
        let test_array = StructArray::from_fields(&[("value", buffer![0i32, 1, 2].into_array())])
            .unwrap()
            .into_array();
        let nullable_20 =
            Scalar::from(20i32).cast(&DType::Primitive(PType::I32, Nullability::Nullable))?;
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(0i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(nullable_20)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert!(
            result.dtype().is_nullable(),
            "result dtype must be Nullable, got {:?}",
            result.dtype()
        );
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10), Some(20), Some(0)]).into_array()
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_with_literal_condition() {
        let test_array = buffer![1i32, 2, 3].into_array();
        let expr = case_when(lit(true), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_bool_column_result() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            lit(true),
            lit(false),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            BoolArray::from_iter([false, false, true, true, true]).into_array()
        );
    }

    #[test]
    fn test_evaluate_with_nullable_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([Some(true), None, Some(false), None, Some(true)]).into_array(),
        )])
        .unwrap()
        .into_array();
        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 0, 0, 0, 100].into_array());
    }

    #[test]
    fn test_evaluate_with_nullable_result_values() {
        let test_array = StructArray::from_fields(&[
            ("value", buffer![1i32, 2, 3, 4, 5].into_array()),
            (
                "result",
                PrimitiveArray::from_option_iter([Some(10), None, Some(30), Some(40), Some(50)])
                    .into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let expr = case_when(
            gt(get_item("value", root()), lit(2i32)),
            get_item("result", root()),
            lit(0i32),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(0i32), Some(0), Some(30), Some(40), Some(50)])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_with_all_null_condition() {
        let test_array = StructArray::from_fields(&[(
            "cond",
            BoolArray::from_iter([None, None, None]).into_array(),
        )])
        .unwrap()
        .into_array();
        let expr = case_when(get_item("cond", root()), lit(100i32), lit(0i32));
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![0i32, 0, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_no_else_returns_null() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
            ],
            None,
        );
        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, None])
                .into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_many_conditions() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4, 5].into_array())])
                .unwrap()
                .into_array();
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
                (eq(get_item("value", root()), lit(3i32)), lit(30i32)),
                (eq(get_item("value", root()), lit(4i32)), lit(40i32)),
                (eq(get_item("value", root()), lit(5i32)), lit(50i32)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 30, 40, 50].into_array());
    }

    #[test]
    fn test_evaluate_nary_all_false_no_else() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(100i32)), lit(10i32)),
                (gt(get_item("value", root()), lit(200i32)), lit(20i32)),
            ],
            None,
        );
        let result = evaluate_expr(&expr, &test_array);
        assert!(result.dtype().is_nullable());
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array()
        );
    }

    #[test]
    fn test_evaluate_nary_overlapping_conditions_first_wins() {
        let test_array =
            StructArray::from_fields(&[("value", buffer![10i32, 20, 30].into_array())])
                .unwrap()
                .into_array();
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(5i32)), lit(1i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(2i32)),
                (gt(get_item("value", root()), lit(15i32)), lit(3i32)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![1i32, 1, 1].into_array());
    }

    #[test]
    fn test_evaluate_nary_early_exit_when_remaining_empty() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(0i32)), lit(100i32)),
                (gt(get_item("value", root()), lit(0i32)), lit(999i32)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![100i32, 100, 100].into_array());
    }

    #[test]
    fn test_evaluate_nary_skips_branch_with_empty_effective_mask() {
        let test_array = StructArray::from_fields(&[("value", buffer![1i32, 2, 3].into_array())])
            .unwrap()
            .into_array();
        let expr = nested_case_when(
            vec![
                (eq(get_item("value", root()), lit(1i32)), lit(10i32)),
                (eq(get_item("value", root()), lit(1i32)), lit(999i32)),
                (eq(get_item("value", root()), lit(2i32)), lit(20i32)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_evaluate_nary_string_output() -> VortexResult<()> {
        let test_array =
            StructArray::from_fields(&[("value", buffer![1i32, 2, 3, 4].into_array())])
                .unwrap()
                .into_array();
        let expr = nested_case_when(
            vec![
                (gt(get_item("value", root()), lit(2i32)), lit("high")),
                (gt(get_item("value", root()), lit(0i32)), lit("low")),
            ],
            Some(lit("none")),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_eq!(
            result.scalar_at(0)?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.scalar_at(1)?,
            Scalar::utf8("low", Nullability::NonNullable)
        );
        assert_eq!(
            result.scalar_at(2)?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        assert_eq!(
            result.scalar_at(3)?,
            Scalar::utf8("high", Nullability::NonNullable)
        );
        Ok(())
    }

    #[test]
    fn test_evaluate_nary_with_nullable_conditions() {
        let test_array = StructArray::from_fields(&[
            (
                "cond1",
                BoolArray::from_iter([Some(true), None, Some(false)]).into_array(),
            ),
            (
                "cond2",
                BoolArray::from_iter([Some(false), Some(true), None]).into_array(),
            ),
        ])
        .unwrap()
        .into_array();
        let expr = nested_case_when(
            vec![
                (get_item("cond1", root()), lit(10i32)),
                (get_item("cond2", root()), lit(20i32)),
            ],
            Some(lit(0i32)),
        );
        let result = evaluate_expr(&expr, &test_array);
        assert_arrays_eq!(result, buffer![10i32, 20, 0].into_array());
    }

    #[test]
    fn test_merge_case_branches_alternating_mask() -> VortexResult<()> {
        let n = 100usize;
        let branch0_mask = Mask::from_indices(n, (0..n).step_by(2).collect());
        let branch1_mask = Mask::from_indices(n, (1..n).step_by(2).collect());
        let result = merge_case_branches(
            vec![
                (
                    branch0_mask,
                    PrimitiveArray::from_option_iter(vec![Some(0i32); n]).into_array(),
                ),
                (
                    branch1_mask,
                    PrimitiveArray::from_option_iter(vec![Some(1i32); n]).into_array(),
                ),
            ],
            PrimitiveArray::from_option_iter(vec![Some(99i32); n]).into_array(),
        )?;
        let expected: Vec<Option<i32>> = (0..n)
            .map(|v| if v % 2 == 0 { Some(0) } else { Some(1) })
            .collect();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter(expected).into_array()
        );
        Ok(())
    }

    /// A few long runs keep `merge_case_branches` on the run-by-run path
    /// (3 spans <= 100 / SLICE_CROSSOVER_RUN_LEN). Per-row distinct values check
    /// that each run is copied from the correct offsets, and the null in branch 1
    /// checks that validity is carried through.
    #[test]
    fn test_merge_case_branches_long_runs() -> VortexResult<()> {
        let n = 100usize;
        let branch0_mask = Mask::from_indices(n, (0..10usize).chain(40..50).collect());
        let branch1_mask = Mask::from_indices(n, (20..30usize).collect());
        let branch0 = PrimitiveArray::from_option_iter((0..100i32).map(Some).collect::<Vec<_>>());
        let branch1 = PrimitiveArray::from_option_iter(
            (0..100i32)
                .map(|v| (v != 25).then_some(1000 + v))
                .collect::<Vec<_>>(),
        );
        let else_value =
            PrimitiveArray::from_option_iter((0..100i32).map(|v| Some(-v)).collect::<Vec<_>>());
        let result = merge_case_branches(
            vec![
                (branch0_mask, branch0.into_array()),
                (branch1_mask, branch1.into_array()),
            ],
            else_value.into_array(),
        )?;
        let expected: Vec<Option<i32>> = (0..100i32)
            .map(|v| match v {
                0..=9 | 40..=49 => Some(v),
                25 => None,
                20..=24 | 26..=29 => Some(1000 + v),
                _ => Some(-v),
            })
            .collect();
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_option_iter(expected).into_array()
        );
        Ok(())
    }
}