use super::{Column, Literal};
use crate::expressions::case::ResultState::{Complete, Empty, Partial};
use crate::expressions::try_cast;
use crate::PhysicalExpr;
use arrow::array::*;
use arrow::compute::kernels::zip::zip;
use arrow::compute::{
is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate,
SlicesIterator,
};
use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode};
use arrow::error::ArrowError;
use datafusion_common::cast::as_boolean_array;
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, DataFusionError, HashMap, HashSet,
Result, ScalarValue,
};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr_common::datum::compare_with_eq;
use itertools::Itertools;
use std::borrow::Cow;
use std::fmt::{Debug, Formatter};
use std::hash::Hash;
use std::{any::Any, sync::Arc};
/// A `(WHEN, THEN)` pair of physical expressions making up one CASE branch.
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
/// Evaluation strategy for a `CaseExpr`, chosen once by `CaseExpr::try_new` and
/// dispatched on by `CaseExpr::evaluate`.
///
/// `try_new` checks for a base expression first, so every variant other than
/// `WithExpression` implies the CASE has no base expression. Variants carrying a
/// `ProjectedCaseBody` can evaluate on a batch narrowed to the referenced columns.
#[derive(Debug, Hash, PartialEq, Eq)]
enum EvalMethod {
/// General form without a base expression, used when none of the specialized
/// strategies apply: `CASE WHEN cond THEN result [WHEN ...] [ELSE ...] END`.
NoExpression(ProjectedCaseBody),
/// A base expression is present: `CASE expr WHEN value THEN result ... END`.
WithExpression(ProjectedCaseBody),
/// Exactly one WHEN, no ELSE, and the THEN expression is a column reference.
InfallibleExprOrNull,
/// Exactly one WHEN, and both the THEN and ELSE expressions are literals.
ScalarOrScalar,
/// Exactly one WHEN with an ELSE, when the literal/literal case does not apply.
ExpressionOrExpression(ProjectedCaseBody),
}
/// The syntactic parts of a CASE expression.
#[derive(Debug, Hash, PartialEq, Eq)]
struct CaseBody {
/// Optional base expression (`CASE <expr> WHEN ...`); when present, each WHEN value
/// is compared with it for equality.
expr: Option<Arc<dyn PhysicalExpr>>,
/// The `(WHEN, THEN)` branches, in evaluation order.
when_then_expr: Vec<WhenThen>,
/// Optional ELSE expression. `CaseExpr::try_new` turns a literal NULL ELSE into `None`.
else_expr: Option<Arc<dyn PhysicalExpr>>,
}
impl CaseBody {
fn project(&self) -> Result<ProjectedCaseBody> {
let mut used_column_indices = HashSet::<usize>::new();
let mut collect_column_indices = |expr: &Arc<dyn PhysicalExpr>| {
expr.apply(|expr| {
if let Some(column) = expr.as_any().downcast_ref::<Column>() {
used_column_indices.insert(column.index());
}
Ok(TreeNodeRecursion::Continue)
})
.expect("Closure cannot fail");
};
if let Some(e) = &self.expr {
collect_column_indices(e);
}
self.when_then_expr.iter().for_each(|(w, t)| {
collect_column_indices(w);
collect_column_indices(t);
});
if let Some(e) = &self.else_expr {
collect_column_indices(e);
}
let column_index_map = used_column_indices
.iter()
.enumerate()
.map(|(projected, original)| (*original, projected))
.collect::<HashMap<usize, usize>>();
let project = |expr: &Arc<dyn PhysicalExpr>| -> Result<Arc<dyn PhysicalExpr>> {
Arc::clone(expr)
.transform_down(|e| {
if let Some(column) = e.as_any().downcast_ref::<Column>() {
let original = column.index();
let projected = *column_index_map.get(&original).unwrap();
if projected != original {
return Ok(Transformed::yes(Arc::new(Column::new(
column.name(),
projected,
))));
}
}
Ok(Transformed::no(e))
})
.map(|t| t.data)
};
let projected_body = CaseBody {
expr: self.expr.as_ref().map(project).transpose()?,
when_then_expr: self
.when_then_expr
.iter()
.map(|(e, t)| Ok((project(e)?, project(t)?)))
.collect::<Result<Vec<_>>>()?,
else_expr: self.else_expr.as_ref().map(project).transpose()?,
};
let projection = column_index_map
.iter()
.sorted_by_key(|(_, v)| **v)
.map(|(k, _)| *k)
.collect::<Vec<_>>();
Ok(ProjectedCaseBody {
projection,
body: projected_body,
})
}
}
/// A `CaseBody` rewritten to evaluate against a column-narrowed batch.
///
/// `projection[i]` is the original index of the column placed at position `i` of the
/// narrowed batch (built with `RecordBatch::project`). `body` refers to columns by
/// those narrowed positions. Produced by `CaseBody::project`.
#[derive(Debug, Hash, PartialEq, Eq)]
struct ProjectedCaseBody {
projection: Vec<usize>,
body: CaseBody,
}
/// Physical `CASE` expression.
///
/// Supports both `CASE WHEN cond THEN result [WHEN ...] [ELSE result] END` and
/// `CASE expr WHEN value THEN result [WHEN ...] [ELSE result] END`.
/// The evaluation strategy is chosen once in `try_new` and kept in `eval_method`.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct CaseExpr {
/// The parts of the expression as supplied to `try_new` (after ELSE normalization).
body: CaseBody,
/// Strategy used by `evaluate`.
eval_method: EvalMethod,
}
impl std::fmt::Display for CaseExpr {
    /// Renders `CASE [expr ]WHEN w THEN t ... [ELSE e ]END` using each sub-expression's
    /// `Display` form; every clause is followed by a single space.
    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("CASE ")?;
        if let Some(base) = &self.body.expr {
            write!(f, "{base} ")?;
        }
        self.body
            .when_then_expr
            .iter()
            .try_for_each(|(when, then)| write!(f, "WHEN {when} THEN {then} "))?;
        if let Some(otherwise) = &self.body.else_expr {
            write!(f, "ELSE {otherwise} ")?;
        }
        f.write_str("END")
    }
}
/// Returns `true` when `expr` is a plain column reference.
///
/// `CaseExpr::try_new` uses this to pick `InfallibleExprOrNull`. That strategy evaluates
/// the THEN expression over the whole batch, regardless of the WHEN result.
fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
    expr.as_any().downcast_ref::<Column>().is_some()
}
/// Builds a reusable [`FilterPredicate`] from `predicate`.
///
/// When `optimize` is set, the builder's `optimize()` step runs before building. This is
/// extra up-front work, meant for predicates applied to several arrays.
fn create_filter(predicate: &BooleanArray, optimize: bool) -> FilterPredicate {
    let builder = FilterBuilder::new(predicate);
    if optimize {
        builder.optimize().build()
    } else {
        builder.build()
    }
}
/// Returns `true` if `data_type` is built from several child arrays.
///
/// This covers a struct with more than one field, a single-field struct whose field
/// qualifies recursively, and a non-empty sparse union. `expr_or_expr` uses it to decide
/// whether to build an optimized filter. Presumably such arrays are filtered once per
/// child, so the optimization pays off.
fn multiple_arrays(data_type: &DataType) -> bool {
    match data_type {
        DataType::Struct(fields) => match fields.len() {
            0 => false,
            1 => multiple_arrays(fields[0].data_type()),
            _ => true,
        },
        DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
        _ => false,
    }
}
/// Applies `filter` to every column of `record_batch`, keeping the original schema.
///
/// # Errors
/// Returns the first error produced while filtering a column.
fn filter_record_batch(
    record_batch: &RecordBatch,
    filter: &FilterPredicate,
) -> std::result::Result<RecordBatch, ArrowError> {
    let schema = record_batch.schema();
    let mut columns = Vec::with_capacity(record_batch.num_columns());
    for column in record_batch.columns() {
        columns.push(filter_array(column, filter)?);
    }
    let row_count = filter.count();
    // SAFETY: the input is a valid RecordBatch, so its schema already matches these
    // columns (same count and types). All columns had equal length and were filtered
    // with the same predicate, so each now has exactly `filter.count()` rows.
    Ok(unsafe { RecordBatch::new_unchecked(schema, columns, row_count) })
}
#[inline(always)]
fn filter_array(
array: &dyn Array,
filter: &FilterPredicate,
) -> std::result::Result<ArrayRef, ArrowError> {
filter.filter(array)
}
/// Merges `truthy` and `falsy` into one array with `mask.len()` rows.
///
/// Output row `i` takes the next unused `truthy` value when `mask[i]` is set, and the
/// next unused `falsy` value otherwise. A non-scalar input is therefore expected to be
/// "compacted": it holds only the values for the rows it supplies, in row order. That is
/// what you get by evaluating an expression on a batch filtered by `mask` or its
/// negation. A scalar input is repeated for every row it supplies.
///
/// `mask` is expected to contain no nulls; the caller in this file maps nulls to `false`
/// with `prep_null_mask_filter` before calling.
///
/// # Errors
/// Propagates errors from turning a scalar into an array and from `zip`.
fn merge(
    mask: &BooleanArray,
    truthy: ColumnarValue,
    falsy: ColumnarValue,
) -> std::result::Result<ArrayRef, ArrowError> {
    let (truthy, truthy_is_scalar) = match truthy {
        ColumnarValue::Array(a) => (a, false),
        ColumnarValue::Scalar(s) => (s.to_array()?, true),
    };
    let (falsy, falsy_is_scalar) = match falsy {
        ColumnarValue::Array(a) => (a, false),
        ColumnarValue::Scalar(s) => (s.to_array()?, true),
    };

    // Two scalars: `zip` handles this directly.
    if truthy_is_scalar && falsy_is_scalar {
        return zip(mask, &Scalar::new(truthy), &Scalar::new(falsy));
    }

    let falsy = falsy.to_data();
    let truthy = truthy.to_data();

    // The output has one row per mask entry. Sizing by `truthy.len()` under-allocated
    // whenever `truthy` was compacted (matching rows only) or a scalar (length 1).
    let mut mutable = MutableArrayData::new(vec![&truthy, &falsy], false, mask.len());

    // `filled`: number of output rows written so far.
    // `*_offset`: number of values consumed from each non-scalar source.
    let mut filled = 0;
    let mut falsy_offset = 0;
    let mut truthy_offset = 0;

    // Each (start, end) is a run of set mask bits; the gap before it comes from `falsy`.
    SlicesIterator::new(mask).for_each(|(start, end)| {
        if start > filled {
            if falsy_is_scalar {
                for _ in filled..start {
                    mutable.extend(1, 0, 1);
                }
            } else {
                let falsy_length = start - filled;
                let falsy_end = falsy_offset + falsy_length;
                mutable.extend(1, falsy_offset, falsy_end);
                falsy_offset = falsy_end;
            }
        }
        if truthy_is_scalar {
            for _ in start..end {
                mutable.extend(0, 0, 1);
            }
        } else {
            let truthy_length = end - start;
            let truthy_end = truthy_offset + truthy_length;
            mutable.extend(0, truthy_offset, truthy_end);
            truthy_offset = truthy_end;
        }
        filled = end;
    });

    // Trailing unset mask bits take their values from `falsy`.
    if filled < mask.len() {
        if falsy_is_scalar {
            for _ in filled..mask.len() {
                mutable.extend(1, 0, 1);
            }
        } else {
            let falsy_length = mask.len() - filled;
            let falsy_end = falsy_offset + falsy_length;
            mutable.extend(1, falsy_offset, falsy_end);
        }
    }

    let data = mutable.freeze();
    Ok(make_array(data))
}
/// Interleaves several compacted partial result arrays into one array.
///
/// `indices[row]` names the entry of `values` that supplies `row`, or is the none
/// placeholder, in which case the row becomes NULL. Each array in `values` is consumed
/// front to back, in the order its rows appear in `indices`. Runs of consecutive rows
/// sourced from the same array are copied in one `extend` call.
///
/// An empty `indices` yields an empty array instead of panicking.
/// NOTE(review): `values` must be non-empty; `MutableArrayData::new` needs at least one
/// source array to learn the data type.
///
/// # Panics
/// In debug builds, panics if an index is out of range for `values`.
fn merge_n(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result<ArrayRef> {
    #[cfg(debug_assertions)]
    for ix in indices {
        if let Some(index) = ix.index() {
            assert!(
                index < values.len(),
                "Index out of bounds: {} >= {}",
                index,
                values.len()
            );
        }
    }

    let data_refs = values.iter().collect();
    let mut mutable = MutableArrayData::new(data_refs, true, indices.len());

    // Number of values already taken from each partial result array.
    let mut take_offsets = vec![0; values.len()];
    let mut start_row_ix = 0;
    while start_row_ix < indices.len() {
        let array_ix = indices[start_row_ix];
        // Length of the run of consecutive rows sourced from `array_ix` (always >= 1).
        let run_length = indices[start_row_ix..]
            .iter()
            .take_while(|ix| **ix == array_ix)
            .count();
        match array_ix.index() {
            None => mutable.extend_nulls(run_length),
            Some(index) => {
                let start_offset = take_offsets[index];
                let end_offset = start_offset + run_length;
                mutable.extend(index, start_offset, end_offset);
                take_offsets[index] = end_offset;
            }
        }
        start_row_ix += run_length;
    }

    Ok(make_array(mutable.freeze()))
}
/// Index of a partial result array in `ResultState::Partial`, or a "none" placeholder.
///
/// Stored as a `u32`, with `NONE_VALUE` reserved to mean "no value for this row".
/// Presumably this keeps the per-row index vector smaller than a `usize` would.
#[derive(Copy, Clone, PartialEq, Eq)]
struct PartialResultIndex {
index: u32,
}
/// Sentinel `PartialResultIndex` value meaning "no partial result covers this row".
/// Because it is reserved, valid indices are `0..u32::MAX`.
const NONE_VALUE: u32 = u32::MAX;
impl PartialResultIndex {
fn none() -> Self {
Self { index: NONE_VALUE }
}
fn zero() -> Self {
Self { index: 0 }
}
fn try_new(index: usize) -> Result<Self> {
let Ok(index) = u32::try_from(index) else {
return internal_err!("Partial result index exceeds limit");
};
if index == NONE_VALUE {
return internal_err!("Partial result index exceeds limit");
}
Ok(Self { index })
}
fn is_none(&self) -> bool {
self.index == NONE_VALUE
}
fn index(&self) -> Option<usize> {
if self.is_none() {
None
} else {
Some(self.index as usize)
}
}
}
impl Debug for PartialResultIndex {
    /// Prints the numeric index, or `null` for the placeholder.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.index() {
            None => write!(f, "null"),
            Some(ix) => write!(f, "{ix}"),
        }
    }
}
/// States a `ResultBuilder` moves through while CASE branches report their values.
enum ResultState {
/// No branch has produced a value yet.
Empty,
/// One or more branches produced values for subsets of the rows.
Partial {
/// Value array of each contributing branch, in the order they were added.
arrays: Vec<ArrayData>,
/// For every row of the batch, which entry of `arrays` supplies it, or the none
/// placeholder (such rows become NULL when merged by `merge_n`).
indices: Vec<PartialResultIndex>,
},
/// A single branch produced the value for every row.
Complete(ColumnarValue),
}
/// Collects per-branch CASE results for one batch and assembles the final value.
struct ResultBuilder {
/// Type of the CASE result; used for the NULL scalar returned when nothing was added.
data_type: DataType,
/// Number of rows in the batch being evaluated.
row_count: usize,
/// Current accumulation state.
state: ResultState,
}
impl ResultBuilder {
/// Creates a builder with no results for a batch of `row_count` rows whose CASE result
/// has type `data_type`.
fn new(data_type: &DataType, row_count: usize) -> Self {
Self {
data_type: data_type.clone(),
row_count,
state: Empty,
}
}
/// Records the value one branch produced for the rows listed in `row_indices`.
///
/// `row_indices` holds original row numbers within the batch. If it has `row_count`
/// entries, `value` becomes the complete result. Otherwise it is stored as a partial
/// result, with a scalar first expanded to `row_indices.len()` copies.
///
/// # Errors
/// Returns an internal error if an array `value` does not have `row_indices.len()`
/// elements, or if the current state rejects the result (see `set_complete_result`
/// and `add_partial_result`).
fn add_branch_result(
&mut self,
row_indices: &ArrayRef,
value: ColumnarValue,
) -> Result<()> {
match value {
ColumnarValue::Array(a) => {
if a.len() != row_indices.len() {
internal_err!("Array length must match row indices length")
} else if row_indices.len() == self.row_count {
self.set_complete_result(ColumnarValue::Array(a))
} else {
self.add_partial_result(row_indices, a.to_data())
}
}
ColumnarValue::Scalar(s) => {
if row_indices.len() == self.row_count {
self.set_complete_result(ColumnarValue::Scalar(s))
} else {
self.add_partial_result(
row_indices,
s.to_array_of_size(row_indices.len())?.to_data(),
)
}
}
}
}
/// Stores `row_values` as a partial result: `row_values[i]` is the value for row
/// `row_indices[i]`.
///
/// # Errors
/// Returns an internal error if `row_indices` contains nulls, if a complete result is
/// already set, or if the number of partial arrays reaches the `PartialResultIndex`
/// limit. In debug builds only, it also errors if a row already has a value.
///
/// # Panics
/// Panics if a row index is `>= row_count` (out-of-bounds indexing of `indices`).
/// NOTE(review): assumes `row_indices` is a `UInt32Array`; `as_primitive` is not
/// guarded here.
fn add_partial_result(
&mut self,
row_indices: &ArrayRef,
row_values: ArrayData,
) -> Result<()> {
if row_indices.null_count() != 0 {
return internal_err!("Row indices must not contain nulls");
}
match &mut self.state {
Empty => {
// First partial result: every row starts as "none", and the listed rows point
// at array 0.
let array_index = PartialResultIndex::zero();
let mut indices = vec![PartialResultIndex::none(); self.row_count];
for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
indices[*row_ix as usize] = array_index;
}
self.state = Partial {
arrays: vec![row_values],
indices,
};
Ok(())
}
Partial { arrays, indices } => {
// Register the new array first, then point each listed row at it.
let array_index = PartialResultIndex::try_new(arrays.len())?;
arrays.push(row_values);
for row_ix in row_indices.as_primitive::<UInt32Type>().values().iter() {
// Duplicate-assignment check; compiled only with debug assertions.
#[cfg(debug_assertions)]
if !indices[*row_ix as usize].is_none() {
return internal_err!("Duplicate value for row {}", *row_ix);
}
indices[*row_ix as usize] = array_index;
}
Ok(())
}
Complete(_) => internal_err!(
"Cannot add a partial result when complete result is already set"
),
}
}
/// Stores `value` as the result for every row.
///
/// # Errors
/// Returns an internal error unless the builder is still `Empty`.
fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> {
match &self.state {
Empty => {
self.state = Complete(value);
Ok(())
}
Partial { .. } => {
internal_err!(
"Cannot set a complete result when there are already partial results"
)
}
Complete(_) => internal_err!("Complete result already set"),
}
}
/// Produces the final value.
///
/// `Empty` yields a NULL scalar of `data_type`. `Partial` merges the partial arrays with
/// `merge_n`, so rows that no branch covered become NULL. `Complete` returns the stored
/// value unchanged.
fn finish(self) -> Result<ColumnarValue> {
match self.state {
Empty => {
Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
&self.data_type,
)?))
}
Partial { arrays, indices } => {
Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?))
}
Complete(v) => {
Ok(v)
}
}
}
}
impl CaseExpr {
    /// Creates a new CASE expression.
    ///
    /// * `expr` - optional base expression (`CASE expr WHEN ...`).
    /// * `when_then_expr` - the `(WHEN, THEN)` branches, in evaluation order.
    /// * `else_expr` - optional ELSE expression; a literal NULL is treated as absent.
    ///
    /// The evaluation strategy is chosen here, once, from the shape of the expression.
    ///
    /// # Errors
    /// Returns an execution error if `when_then_expr` is empty. Also propagates errors
    /// from projecting the body onto the columns it uses.
    pub fn try_new(
        expr: Option<Arc<dyn PhysicalExpr>>,
        when_then_expr: Vec<WhenThen>,
        else_expr: Option<Arc<dyn PhysicalExpr>>,
    ) -> Result<Self> {
        // `ELSE NULL` behaves exactly like a missing ELSE. SQL planning usually drops it
        // already, but other callers may not.
        let else_expr = else_expr.filter(|e| {
            !e.as_any()
                .downcast_ref::<Literal>()
                .map(|lit| lit.value().is_null())
                .unwrap_or(false)
        });

        if when_then_expr.is_empty() {
            return exec_err!("There must be at least one WHEN clause");
        }

        let body = CaseBody {
            expr,
            when_then_expr,
            else_expr,
        };

        let single_branch = body.when_then_expr.len() == 1;
        let first_then = &body.when_then_expr[0].1;
        let eval_method = match (&body.expr, &body.else_expr) {
            (Some(_), _) => EvalMethod::WithExpression(body.project()?),
            (None, None) if single_branch && is_cheap_and_infallible(first_then) => {
                EvalMethod::InfallibleExprOrNull
            }
            (None, Some(otherwise))
                if single_branch
                    && first_then.as_any().is::<Literal>()
                    && otherwise.as_any().is::<Literal>() =>
            {
                EvalMethod::ScalarOrScalar
            }
            (None, Some(_)) if single_branch => {
                EvalMethod::ExpressionOrExpression(body.project()?)
            }
            _ => EvalMethod::NoExpression(body.project()?),
        };

        Ok(Self { body, eval_method })
    }

    /// Returns the optional base expression (`CASE <expr> WHEN ...`).
    pub fn expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
        self.body.expr.as_ref()
    }

    /// Returns the `(WHEN, THEN)` branches in evaluation order.
    pub fn when_then_expr(&self) -> &[WhenThen] {
        self.body.when_then_expr.as_slice()
    }

    /// Returns the ELSE expression, or `None` if absent (or it was a literal NULL).
    pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
        self.body.else_expr.as_ref()
    }
}
impl CaseBody {
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
let mut data_type = DataType::Null;
for i in 0..self.when_then_expr.len() {
data_type = self.when_then_expr[i].1.data_type(input_schema)?;
if !data_type.equals_datatype(&DataType::Null) {
break;
}
}
if data_type.equals_datatype(&DataType::Null) {
if let Some(e) = &self.else_expr {
data_type = e.data_type(input_schema)?;
}
}
Ok(data_type)
}
fn case_when_with_expr(
&self,
batch: &RecordBatch,
return_type: &DataType,
) -> Result<ColumnarValue> {
let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
let mut remainder_rows: ArrayRef =
Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
let mut remainder_batch = Cow::Borrowed(batch);
let mut base_values = self
.expr
.as_ref()
.unwrap()
.evaluate(batch)?
.into_array(batch.num_rows())?;
if base_values.null_count() > 0 {
let base_not_nulls = is_not_null(base_values.as_ref())?;
let base_all_null = base_values.null_count() == remainder_batch.num_rows();
if let Some(e) = &self.else_expr {
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
if base_all_null {
let nulls_value = expr.evaluate(&remainder_batch)?;
result_builder.add_branch_result(&remainder_rows, nulls_value)?;
} else {
let nulls_filter = create_filter(¬(&base_not_nulls)?, true);
let nulls_batch =
filter_record_batch(&remainder_batch, &nulls_filter)?;
let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
let nulls_value = expr.evaluate(&nulls_batch)?;
result_builder.add_branch_result(&nulls_rows, nulls_value)?;
}
}
if base_all_null {
return result_builder.finish();
}
let not_null_filter = create_filter(&base_not_nulls, true);
remainder_batch =
Cow::Owned(filter_record_batch(&remainder_batch, ¬_null_filter)?);
remainder_rows = filter_array(&remainder_rows, ¬_null_filter)?;
base_values = filter_array(&base_values, ¬_null_filter)?;
}
let base_value_is_nested = base_values.data_type().is_nested();
for i in 0..self.when_then_expr.len() {
let when_expr = &self.when_then_expr[i].0;
let when_value = match when_expr.evaluate(&remainder_batch)? {
ColumnarValue::Array(a) => {
compare_with_eq(&a, &base_values, base_value_is_nested)
}
ColumnarValue::Scalar(s) => {
compare_with_eq(&s.to_scalar()?, &base_values, base_value_is_nested)
}
}?;
let when_true_count = when_value.true_count();
if when_true_count == 0 {
continue;
}
if when_true_count == remainder_batch.num_rows() {
let then_expression = &self.when_then_expr[i].1;
let then_value = then_expression.evaluate(&remainder_batch)?;
result_builder.add_branch_result(&remainder_rows, then_value)?;
return result_builder.finish();
}
let then_filter = create_filter(&when_value, true);
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
let then_rows = filter_array(&remainder_rows, &then_filter)?;
let then_expression = &self.when_then_expr[i].1;
let then_value = then_expression.evaluate(&then_batch)?;
result_builder.add_branch_result(&then_rows, then_value)?;
if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
return result_builder.finish();
}
let next_selection = match when_value.null_count() {
0 => not(&when_value),
_ => {
not(&prep_null_mask_filter(&when_value))
}
}?;
let next_filter = create_filter(&next_selection, true);
remainder_batch =
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
base_values = filter_array(&base_values, &next_filter)?;
}
if let Some(e) = &self.else_expr {
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
let else_value = expr.evaluate(&remainder_batch)?;
result_builder.add_branch_result(&remainder_rows, else_value)?;
}
result_builder.finish()
}
fn case_when_no_expr(
&self,
batch: &RecordBatch,
return_type: &DataType,
) -> Result<ColumnarValue> {
let mut result_builder = ResultBuilder::new(return_type, batch.num_rows());
let mut remainder_rows: ArrayRef =
Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
let mut remainder_batch = Cow::Borrowed(batch);
for i in 0..self.when_then_expr.len() {
let when_predicate = &self.when_then_expr[i].0;
let when_value = when_predicate
.evaluate(&remainder_batch)?
.into_array(remainder_batch.num_rows())?;
let when_value = as_boolean_array(&when_value).map_err(|_| {
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
})?;
let when_true_count = when_value.true_count();
if when_true_count == 0 {
continue;
}
if when_true_count == remainder_batch.num_rows() {
let then_expression = &self.when_then_expr[i].1;
let then_value = then_expression.evaluate(&remainder_batch)?;
result_builder.add_branch_result(&remainder_rows, then_value)?;
return result_builder.finish();
}
let then_filter = create_filter(when_value, true);
let then_batch = filter_record_batch(&remainder_batch, &then_filter)?;
let then_rows = filter_array(&remainder_rows, &then_filter)?;
let then_expression = &self.when_then_expr[i].1;
let then_value = then_expression.evaluate(&then_batch)?;
result_builder.add_branch_result(&then_rows, then_value)?;
if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 {
return result_builder.finish();
}
let next_selection = match when_value.null_count() {
0 => not(when_value),
_ => {
not(&prep_null_mask_filter(when_value))
}
}?;
let next_filter = create_filter(&next_selection, true);
remainder_batch =
Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?);
remainder_rows = filter_array(&remainder_rows, &next_filter)?;
}
if let Some(e) = &self.else_expr {
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
let else_value = expr.evaluate(&remainder_batch)?;
result_builder.add_branch_result(&remainder_rows, else_value)?;
}
result_builder.finish()
}
fn expr_or_expr(
&self,
batch: &RecordBatch,
when_value: &BooleanArray,
) -> Result<ColumnarValue> {
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
_ => {
Cow::Owned(prep_null_mask_filter(when_value))
}
};
let optimize_filter = batch.num_columns() > 1
|| (batch.num_columns() == 1 && multiple_arrays(batch.column(0).data_type()));
let when_filter = create_filter(&when_value, optimize_filter);
let then_batch = filter_record_batch(batch, &when_filter)?;
let then_value = self.when_then_expr[0].1.evaluate(&then_batch)?;
let else_selection = not(&when_value)?;
let else_filter = create_filter(&else_selection, optimize_filter);
let else_batch = filter_record_batch(batch, &else_filter)?;
let e = self.else_expr.as_ref().unwrap();
let return_type = self.data_type(&batch.schema())?;
let else_expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())
.unwrap_or_else(|_| Arc::clone(e));
let else_value = else_expr.evaluate(&else_batch)?;
Ok(ColumnarValue::Array(merge(
&when_value,
then_value,
else_value,
)?))
}
}
impl CaseExpr {
fn case_when_with_expr(
&self,
batch: &RecordBatch,
projected: &ProjectedCaseBody,
) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;
if projected.projection.len() < batch.num_columns() {
let projected_batch = batch.project(&projected.projection)?;
projected
.body
.case_when_with_expr(&projected_batch, &return_type)
} else {
self.body.case_when_with_expr(batch, &return_type)
}
}
fn case_when_no_expr(
&self,
batch: &RecordBatch,
projected: &ProjectedCaseBody,
) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;
if projected.projection.len() < batch.num_columns() {
let projected_batch = batch.project(&projected.projection)?;
projected
.body
.case_when_no_expr(&projected_batch, &return_type)
} else {
self.body.case_when_no_expr(batch, &return_type)
}
}
fn case_column_or_null(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let when_expr = &self.body.when_then_expr[0].0;
let then_expr = &self.body.when_then_expr[0].1;
match when_expr.evaluate(batch)? {
ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) => {
then_expr.evaluate(batch)
}
ColumnarValue::Scalar(_) => {
ScalarValue::try_from(self.data_type(&batch.schema())?)
.map(ColumnarValue::Scalar)
}
ColumnarValue::Array(bit_mask) => {
let bit_mask = bit_mask
.as_any()
.downcast_ref::<BooleanArray>()
.expect("predicate should evaluate to a boolean array");
let bit_mask = match bit_mask.null_count() {
0 => not(bit_mask)?,
_ => not(&prep_null_mask_filter(bit_mask))?,
};
match then_expr.evaluate(batch)? {
ColumnarValue::Array(array) => {
Ok(ColumnarValue::Array(nullif(&array, &bit_mask)?))
}
ColumnarValue::Scalar(_) => {
internal_err!("expression did not evaluate to an array")
}
}
}
}
}
fn scalar_or_scalar(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let return_type = self.data_type(&batch.schema())?;
let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
let when_value = when_value.into_array(batch.num_rows())?;
let when_value = as_boolean_array(&when_value).map_err(|_| {
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
})?;
let when_value = match when_value.null_count() {
0 => Cow::Borrowed(when_value),
_ => Cow::Owned(prep_null_mask_filter(when_value)),
};
let then_value = self.body.when_then_expr[0].1.evaluate(batch)?;
let then_value = Scalar::new(then_value.into_array(1)?);
let Some(e) = &self.body.else_expr else {
return internal_err!("expression did not evaluate to an array");
};
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type)?;
let else_ = Scalar::new(expr.evaluate(batch)?.into_array(1)?);
Ok(ColumnarValue::Array(zip(&when_value, &then_value, &else_)?))
}
fn expr_or_expr(
&self,
batch: &RecordBatch,
projected: &ProjectedCaseBody,
) -> Result<ColumnarValue> {
let when_value = self.body.when_then_expr[0].0.evaluate(batch)?;
let when_value = when_value.into_array(1)?;
let when_value = as_boolean_array(&when_value).map_err(|e| {
DataFusionError::Context(
"WHEN expression did not return a BooleanArray".to_string(),
Box::new(e),
)
})?;
let true_count = when_value.true_count();
if true_count == when_value.len() {
self.body.when_then_expr[0].1.evaluate(batch)
} else if true_count == 0 {
self.body.else_expr.as_ref().unwrap().evaluate(batch)
} else if projected.projection.len() < batch.num_columns() {
let projected_batch = batch.project(&projected.projection)?;
projected.body.expr_or_expr(&projected_batch, when_value)
} else {
self.body.expr_or_expr(batch, when_value)
}
}
}
impl PhysicalExpr for CaseExpr {
    /// Returns `self` as `Any` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The CASE result type; see `CaseBody::data_type`.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        self.body.data_type(input_schema)
    }

    /// Nullable if any THEN is nullable, or if the ELSE is nullable, or if there is no
    /// ELSE (unmatched rows produce NULL). Every THEN is checked before deciding, and
    /// the first error is returned.
    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
        let mut some_then_nullable = false;
        for (_, then) in &self.body.when_then_expr {
            some_then_nullable |= then.nullable(input_schema)?;
        }
        if some_then_nullable {
            return Ok(true);
        }
        match &self.body.else_expr {
            Some(otherwise) => otherwise.nullable(input_schema),
            None => Ok(true),
        }
    }

    /// Dispatches to the strategy chosen in `try_new`.
    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        match &self.eval_method {
            EvalMethod::WithExpression(projected) => {
                self.case_when_with_expr(batch, projected)
            }
            EvalMethod::NoExpression(projected) => self.case_when_no_expr(batch, projected),
            EvalMethod::InfallibleExprOrNull => self.case_column_or_null(batch),
            EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch),
            EvalMethod::ExpressionOrExpression(projected) => {
                self.expr_or_expr(batch, projected)
            }
        }
    }

    /// Children in order: base expression (if any), then WHEN/THEN pairs, then ELSE
    /// (if any).
    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        let body = &self.body;
        body.expr
            .iter()
            .chain(
                body.when_then_expr
                    .iter()
                    .flat_map(|(when, then)| [when, then]),
            )
            .chain(body.else_expr.iter())
            .collect()
    }

    /// Rebuilds the expression from `children`, laid out as produced by `children()`.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        if children.len() != self.children().len() {
            return internal_err!("CaseExpr: Wrong number of children");
        }
        let has_base = self.expr().is_some();
        let has_else = self.body.else_expr.is_some();
        let pairs_start = usize::from(has_base);
        let pairs_end = children.len() - usize::from(has_else);

        let base = has_base.then(|| Arc::clone(&children[0]));
        let otherwise = has_else.then(|| Arc::clone(&children[children.len() - 1]));
        let when_thens = children[pairs_start..pairs_end]
            .iter()
            .cloned()
            .tuples()
            .collect();

        Ok(Arc::new(CaseExpr::try_new(base, when_thens, otherwise)?))
    }

    /// SQL rendering: `CASE [expr ]WHEN w THEN t ... [ELSE e ]END`.
    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("CASE ")?;
        if let Some(base) = &self.body.expr {
            base.fmt_sql(f)?;
            f.write_str(" ")?;
        }
        for (when, then) in &self.body.when_then_expr {
            f.write_str("WHEN ")?;
            when.fmt_sql(f)?;
            f.write_str(" THEN ")?;
            then.fmt_sql(f)?;
            f.write_str(" ")?;
        }
        if let Some(otherwise) = &self.body.else_expr {
            f.write_str("ELSE ")?;
            otherwise.fmt_sql(f)?;
            f.write_str(" ")?;
        }
        f.write_str("END")
    }
}
/// Creates a `CASE` physical expression; a convenience wrapper around
/// [`CaseExpr::try_new`].
///
/// # Errors
/// Returns the same errors as `CaseExpr::try_new` (e.g. no WHEN branches).
pub fn case(
    expr: Option<Arc<dyn PhysicalExpr>>,
    when_thens: Vec<WhenThen>,
    else_expr: Option<Arc<dyn PhysicalExpr>>,
) -> Result<Arc<dyn PhysicalExpr>> {
    let case_expr = CaseExpr::try_new(expr, when_thens, else_expr)?;
    Ok(Arc::new(case_expr))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::expressions::{binary, cast, col, lit, BinaryExpr};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType::Float64;
use arrow::datatypes::Field;
use datafusion_common::cast::{as_float64_array, as_int32_array};
use datafusion_common::plan_err;
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_expr::type_coercion::binary::comparison_coercion;
use datafusion_expr::Operator;
use datafusion_physical_expr_common::physical_expr::fmt_sql;
#[test]
fn case_with_expr() -> Result<()> {
let batch = case_test_batch()?;
let schema = batch.schema();
let when1 = lit("foo");
let then1 = lit(123i32);
let when2 = lit("bar");
let then2 = lit(456i32);
let expr = generate_case_when_with_type_coercion(
Some(col("a", &schema)?),
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_int32_array(&result)?;
let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn case_with_expr_else() -> Result<()> {
let batch = case_test_batch()?;
let schema = batch.schema();
let when1 = lit("foo");
let then1 = lit(123i32);
let when2 = lit("bar");
let then2 = lit(456i32);
let else_value = lit(999i32);
let expr = generate_case_when_with_type_coercion(
Some(col("a", &schema)?),
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_int32_array(&result)?;
let expected =
&Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn case_with_expr_divide_by_zero() -> Result<()> {
let batch = case_test_batch1()?;
let schema = batch.schema();
let when1 = lit(0i32);
let then1 = lit(ScalarValue::Float64(None));
let else_value = binary(
lit(25.0f64),
Operator::Divide,
cast(col("a", &schema)?, &batch.schema(), Float64)?,
&batch.schema(),
)?;
let expr = generate_case_when_with_type_coercion(
Some(col("a", &schema)?),
vec![(when1, then1)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result =
as_float64_array(&result).expect("failed to downcast to Float64Array");
let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn case_without_expr() -> Result<()> {
let batch = case_test_batch()?;
let schema = batch.schema();
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit("foo"),
&batch.schema(),
)?;
let then1 = lit(123i32);
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit("bar"),
&batch.schema(),
)?;
let then2 = lit(456i32);
let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
None,
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_int32_array(&result)?;
let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn case_with_expr_when_null() -> Result<()> {
let batch = case_test_batch()?;
let schema = batch.schema();
let when1 = lit(ScalarValue::Utf8(None));
let then1 = lit(0i32);
let when2 = col("a", &schema)?;
let then2 = lit(123i32);
let else_value = lit(999i32);
let expr = generate_case_when_with_type_coercion(
Some(col("a", &schema)?),
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_int32_array(&result)?;
let expected =
&Int32Array::from(vec![Some(123), Some(123), Some(999), Some(123)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn case_without_expr_divide_by_zero() -> Result<()> {
let batch = case_test_batch1()?;
let schema = batch.schema();
let when1 = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &batch.schema())?;
let then1 = binary(
lit(25.0f64),
Operator::Divide,
cast(col("a", &schema)?, &batch.schema(), Float64)?,
&batch.schema(),
)?;
let x = lit(ScalarValue::Float64(None));
let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1)],
Some(x),
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result =
as_float64_array(&result).expect("failed to downcast to Float64Array");
let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
assert_eq!(expected, result);
Ok(())
}
fn case_test_batch1() -> Result<RecordBatch> {
let schema = Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
Field::new("c", DataType::Int32, true),
]);
let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
let b = Int32Array::from(vec![Some(3), None, Some(14), Some(7)]);
let c = Int32Array::from(vec![Some(0), Some(-3), Some(777), None]);
let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(a), Arc::new(b), Arc::new(c)],
)?;
Ok(batch)
}
#[test]
fn case_without_expr_else() -> Result<()> {
let batch = case_test_batch()?;
let schema = batch.schema();
let when1 = binary(
col("a", &schema)?,
Operator::Eq,
lit("foo"),
&batch.schema(),
)?;
let then1 = lit(123i32);
let when2 = binary(
col("a", &schema)?,
Operator::Eq,
lit("bar"),
&batch.schema(),
)?;
let then2 = lit(456i32);
let else_value = lit(999i32);
let expr = generate_case_when_with_type_coercion(
None,
vec![(when1, then1), (when2, then2)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result = as_int32_array(&result)?;
let expected =
&Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn case_with_type_cast() -> Result<()> {
let batch = case_test_batch()?;
let schema = batch.schema();
let when = binary(
col("a", &schema)?,
Operator::Eq,
lit("foo"),
&batch.schema(),
)?;
let then = lit(123.3f64);
let else_value = lit(999i32);
let expr = generate_case_when_with_type_coercion(
None,
vec![(when, then)],
Some(else_value),
schema.as_ref(),
)?;
let result = expr
.evaluate(&batch)?
.into_array(batch.num_rows())
.expect("Failed to convert to array");
let result =
as_float64_array(&result).expect("failed to downcast to Float64Array");
let expected =
&Float64Array::from(vec![Some(123.3), Some(999.0), Some(999.0), Some(999.0)]);
assert_eq!(expected, result);
Ok(())
}
#[test]
fn case_with_matches_and_nulls() -> Result<()> {
    // CASE WHEN load4 = 1.77 THEN load4 END
    //
    // Rows 1 and 2 of the fixture physically store 1.77 but are masked as NULL,
    // so they must come out NULL instead of matching the predicate.
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();
    let load4_matches =
        binary(col("load4", &schema)?, Operator::Eq, lit(1.77f64), &schema)?;

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(load4_matches, col("load4", &schema)?)],
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);
    Ok(())
}
#[test]
fn case_with_scalar_predicate() -> Result<()> {
    // CASE WHEN true THEN load4 END passes load4 straight through.
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();

    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(lit(true), col("load4", &schema)?)],
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");
    let expected = Float64Array::from(vec![
        Some(1.77),
        None,
        None,
        Some(1.78),
        None,
        Some(1.77),
    ]);
    assert_eq!(&expected, actual);

    // The same expression is re-evaluated against a different, single-row batch.
    let single_row = Float64Array::from(vec![Some(1.1)]);
    let single_batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(single_row.clone())],
    )?;
    let evaluated = case_expr
        .evaluate(&single_batch)?
        .into_array(single_batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");
    assert_eq!(&single_row, actual);
    Ok(())
}
#[test]
fn case_expr_matches_and_nulls() -> Result<()> {
    // CASE load4 WHEN 1.77 THEN load4 END: the "simple CASE" form, where the
    // base expression is compared against each WHEN value.
    let batch = case_test_batch_nulls()?;
    let schema = batch.schema();
    let load4 = col("load4", &schema)?;

    let case_expr = generate_case_when_with_type_coercion(
        Some(Arc::clone(&load4)),
        vec![(lit(1.77f64), load4)],
        None,
        schema.as_ref(),
    )?;

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual =
        as_float64_array(&evaluated).expect("failed to downcast to Float64Array");

    // Masked rows (1, 2, 4) and the non-matching 1.78 row (3) are all NULL.
    let expected =
        Float64Array::from(vec![Some(1.77), None, None, None, None, Some(1.77)]);
    assert_eq!(&expected, actual);
    Ok(())
}
#[test]
fn test_when_null_and_some_cond_else_null() -> Result<()> {
    // CASE WHEN NULL AND a = 'foo' THEN a END
    // The predicate is never true, and there is no ELSE, so every row is NULL.
    let batch = case_test_batch()?;
    let schema = batch.schema();

    let null_bool: Arc<dyn PhysicalExpr> =
        Arc::new(Literal::new(ScalarValue::Boolean(None)));
    let a_is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let predicate = binary(null_bool, Operator::And, a_is_foo, &schema)?;

    let case_expr = Arc::new(CaseExpr::try_new(
        None,
        vec![(predicate, col("a", &schema)?)],
        None,
    )?);

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let strings = as_string_array(&evaluated);
    assert_eq!(strings.logical_null_count(), batch.num_rows());
    Ok(())
}
/// Fixture: one nullable Utf8 column `a` = ["foo", "baz", NULL, "bar"].
fn case_test_batch() -> Result<RecordBatch> {
    let values = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
    // `?` converts the Arrow error into a DataFusion error.
    Ok(RecordBatch::try_new(schema, vec![Arc::new(values)])?)
}
fn case_test_batch_nulls() -> Result<RecordBatch> {
let load4: Float64Array = vec![
Some(1.77), Some(1.77), Some(1.77), Some(1.78), None, Some(1.77), ]
.into_iter()
.collect();
let null_buffer = Buffer::from([0b00101001u8]);
let load4 = load4
.into_data()
.into_builder()
.null_bit_buffer(Some(null_buffer))
.build()
.unwrap();
let load4: Float64Array = load4.into();
let batch =
RecordBatch::try_from_iter(vec![("load4", Arc::new(load4) as ArrayRef)])?;
Ok(batch)
}
#[test]
fn case_test_incompatible() -> Result<()> {
    let batch = case_test_batch()?;
    let schema = batch.schema();
    let is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;
    let is_bar = binary(col("a", &schema)?, Operator::Eq, lit("bar"), &schema)?;

    // Int32 and Boolean THEN branches share no common type, so coercion fails.
    let incompatible = generate_case_when_with_type_coercion(
        None,
        vec![
            (Arc::clone(&is_foo), lit(123i32)),
            (Arc::clone(&is_bar), lit(true)),
        ],
        None,
        schema.as_ref(),
    );
    assert!(incompatible.is_err());

    // Int32, Int64 and a Float64 ELSE all coerce to Float64.
    let compatible = generate_case_when_with_type_coercion(
        None,
        vec![(is_foo, lit(123i32)), (is_bar, lit(456i64))],
        Some(lit(1.23f64)),
        schema.as_ref(),
    );
    assert!(compatible.is_ok());
    let result_type = compatible.unwrap().data_type(schema.as_ref())?;
    assert_eq!(Float64, result_type);
    Ok(())
}
#[test]
fn case_eq() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let (when_foo, then_123) = (lit("foo"), lit(123i32));
    let (when_bar, then_456) = (lit("bar"), lit(456i32));
    let else_999 = lit(999i32);

    // Fresh vector of both WHEN/THEN pairs, sharing the same literal nodes.
    let both_branches = || {
        vec![
            (Arc::clone(&when_foo), Arc::clone(&then_123)),
            (Arc::clone(&when_bar), Arc::clone(&then_456)),
        ]
    };

    // expr1 and expr2 are built identically.
    let expr1 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        both_branches(),
        Some(Arc::clone(&else_999)),
        &schema,
    )?;
    let expr2 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        both_branches(),
        Some(Arc::clone(&else_999)),
        &schema,
    )?;
    // expr3: same branches, no ELSE.
    let expr3 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        both_branches(),
        None,
        &schema,
    )?;
    // expr4: same ELSE, only the first branch.
    let expr4 = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![(Arc::clone(&when_foo), Arc::clone(&then_123))],
        Some(Arc::clone(&else_999)),
        &schema,
    )?;

    // Equality must be symmetric in every pairing.
    assert!(expr1.eq(&expr2));
    assert!(expr2.eq(&expr1));
    assert!(expr2.ne(&expr3));
    assert!(expr3.ne(&expr2));
    assert!(expr1.ne(&expr4));
    assert!(expr4.ne(&expr1));
    Ok(())
}
/// Rewriting every Utf8 literal to upper case must change the CASE expression.
/// Top-down and bottom-up traversal must produce the same result.
#[test]
fn case_transform() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
    let when1 = lit("foo");
    let then1 = lit(123i32);
    let when2 = lit("bar");
    let then2 = lit(456i32);
    let else_value = lit(999i32);
    let expr = generate_case_when_with_type_coercion(
        Some(col("a", &schema)?),
        vec![
            (Arc::clone(&when1), Arc::clone(&then1)),
            (Arc::clone(&when2), Arc::clone(&then2)),
        ],
        Some(Arc::clone(&else_value)),
        &schema,
    )?;

    // Single rewrite rule shared by both traversals (previously duplicated
    // verbatim). It replaces Utf8 literals with their upper-case form and
    // leaves every other node untouched. It captures nothing, so it is `Copy`
    // and can be handed to both `transform` and `transform_down` by value.
    let uppercase_utf8_literals =
        |e: Arc<dyn PhysicalExpr>| -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
            let rewritten = e
                .as_any()
                .downcast_ref::<Literal>()
                .and_then(|literal| match literal.value() {
                    ScalarValue::Utf8(Some(str_value)) => {
                        Some(lit(str_value.to_uppercase()))
                    }
                    _ => None,
                });
            Ok(match rewritten {
                Some(new_expr) => Transformed::yes(new_expr),
                None => Transformed::no(e),
            })
        };

    let expr2 = Arc::clone(&expr)
        .transform(uppercase_utf8_literals)
        .data()
        .unwrap();
    let expr3 = Arc::clone(&expr)
        .transform_down(uppercase_utf8_literals)
        .data()
        .unwrap();

    assert!(expr.ne(&expr2));
    assert!(expr2.eq(&expr3));
    Ok(())
}
#[test]
fn test_column_or_null_specialization() -> Result<()> {
    // c1 = 0..1000; c2 = "string {i}", except every row with i % 7 == 0 is NULL.
    let mut c1_builder = Int32Builder::new();
    let mut c2_builder = StringBuilder::new();
    for row in 0..1000 {
        c1_builder.append_value(row);
        c2_builder.append_option((row % 7 != 0).then(|| format!("string {row}")));
    }
    let c1: ArrayRef = Arc::new(c1_builder.finish());
    let c2: ArrayRef = Arc::new(c2_builder.finish());

    let schema = Schema::new(vec![
        Field::new("c1", DataType::Int32, true),
        Field::new("c2", DataType::Utf8, true),
    ]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2]).unwrap();

    // CASE WHEN c1 <= 250 THEN c2 END is expected to use InfallibleExprOrNull.
    let predicate: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
        make_col("c1", 0),
        Operator::LtEq,
        make_lit_i32(250),
    ));
    let case_expr =
        CaseExpr::try_new(None, vec![(predicate, make_col("c2", 1))], None)?;
    assert!(matches!(case_expr.eval_method, EvalMethod::InfallibleExprOrNull));

    // 749 rows fail the predicate (c1 > 250), plus 36 rows in 0..=250 where c2
    // itself is NULL (multiples of 7), giving 785 NULLs.
    match case_expr.evaluate(&batch)? {
        ColumnarValue::Array(array) => {
            assert_eq!(1000, array.len());
            assert_eq!(785, array.null_count());
        }
        _ => unreachable!(),
    }
    Ok(())
}
#[test]
fn test_expr_or_expr_specialization() -> Result<()> {
    // CASE WHEN a <= 2 THEN b ELSE c END is expected to use the
    // ExpressionOrExpression evaluation method.
    let batch = case_test_batch1()?;
    let schema = batch.schema();
    let a_at_most_two = binary(col("a", &schema)?, Operator::LtEq, lit(2i32), &schema)?;
    let then_b = col("b", &schema)?;
    let else_c = col("c", &schema)?;

    let case_expr = CaseExpr::try_new(None, vec![(a_at_most_two, then_b)], Some(else_c))?;
    assert!(matches!(
        case_expr.eval_method,
        EvalMethod::ExpressionOrExpression(_)
    ));

    let evaluated = case_expr
        .evaluate(&batch)?
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_int32_array(&evaluated).expect("failed to downcast to Int32Array");

    // a = [1, 0, NULL, 5]: rows 0-1 take b, rows 2-3 (NULL / false) take c.
    let expected = Int32Array::from(vec![Some(3), None, Some(777), None]);
    assert_eq!(&expected, actual);
    Ok(())
}
/// Builds a `Column` physical expression for `name` at position `index`.
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
    let column = Column::new(name, index);
    Arc::new(column)
}
/// Builds a non-null Int32 literal physical expression holding `n`.
fn make_lit_i32(n: i32) -> Arc<dyn PhysicalExpr> {
    let value = ScalarValue::Int32(Some(n));
    Arc::new(Literal::new(value))
}
/// Builds a CASE expression after coercing every THEN branch and the optional
/// ELSE branch to one common type (via `try_cast`).
///
/// # Errors
/// Returns a planning error when the branches have no common type, and
/// propagates any error from `try_cast` or `case`.
fn generate_case_when_with_type_coercion(
    expr: Option<Arc<dyn PhysicalExpr>>,
    when_thens: Vec<WhenThen>,
    else_expr: Option<Arc<dyn PhysicalExpr>>,
    input_schema: &Schema,
) -> Result<Arc<dyn PhysicalExpr>> {
    let target_type =
        match get_case_common_type(&when_thens, else_expr.clone(), input_schema) {
            Some(data_type) => data_type,
            None => {
                return plan_err!(
                    "Can't get a common type for then {when_thens:?} and else {else_expr:?} expression"
                )
            }
        };

    // THEN branches are cast first (in order), then the ELSE branch.
    let mut coerced_when_thens = Vec::with_capacity(when_thens.len());
    for (when, then) in when_thens {
        let coerced_then = try_cast(then, input_schema, target_type.clone())?;
        coerced_when_thens.push((when, coerced_then));
    }
    let coerced_else = else_expr
        .map(|e| try_cast(e, input_schema, target_type.clone()))
        .transpose()?;

    case(expr, coerced_when_thens, coerced_else)
}
/// Finds the type that every THEN branch and the ELSE branch (if any) can be
/// coerced to, by folding the branch types pairwise with `comparison_coercion`.
///
/// Returns `None` in two cases:
/// - two branch types have no common type;
/// - there are no WHEN/THEN pairs and no ELSE expression to take a type from.
///
/// # Panics
/// Panics if a branch expression's type cannot be resolved against
/// `input_schema`. This is an invariant of the test fixtures that use this
/// helper.
fn get_case_common_type(
    when_thens: &[WhenThen],
    else_expr: Option<Arc<dyn PhysicalExpr>>,
    input_schema: &Schema,
) -> Option<DataType> {
    let thens_type = when_thens
        .iter()
        .map(|(_, then)| {
            then.data_type(input_schema)
                .expect("THEN expression type must be resolvable")
        })
        .collect::<Vec<_>>();
    // Seed the fold with the ELSE type, or with the first THEN type when there
    // is no ELSE. `first()?` avoids an index panic on an empty branch list.
    let else_type = match else_expr {
        None => thens_type.first()?.clone(),
        Some(else_phy_expr) => else_phy_expr
            .data_type(input_schema)
            .expect("ELSE expression type must be resolvable"),
    };
    thens_type
        .iter()
        .try_fold(else_type, |left_type, right_type| {
            comparison_coercion(&left_type, right_type)
        })
}
#[test]
fn test_fmt_sql() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
    let a_is_foo = binary(col("a", &schema)?, Operator::Eq, lit("foo"), &schema)?;

    // The Int32 ELSE is wrapped in TRY_CAST to match the Float64 THEN branch.
    let case_expr = generate_case_when_with_type_coercion(
        None,
        vec![(a_is_foo, lit(123.3f64))],
        Some(lit(999i32)),
        &schema,
    )?;

    // `Display` includes the column index (`a@0`)...
    assert_eq!(
        case_expr.to_string(),
        "CASE WHEN a@0 = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
    );
    // ...while `fmt_sql` renders the bare column name.
    assert_eq!(
        fmt_sql(case_expr.as_ref()).to_string(),
        "CASE WHEN a = foo THEN 123.3 ELSE TRY_CAST(999 AS Float64) END"
    );
    Ok(())
}
#[test]
fn test_merge_n() {
    let sources = [
        StringArray::from(vec![Some("A")]).to_data(),
        StringArray::from(vec![Some("B")]).to_data(),
        StringArray::from(vec![Some("C"), Some("D")]).to_data(),
    ];
    let idx = |i| PartialResultIndex::try_new(i).unwrap();
    let indices = vec![
        PartialResultIndex::none(),
        idx(1),
        idx(0),
        PartialResultIndex::none(),
        idx(2),
        idx(2),
    ];

    let merged = merge_n(&sources, &indices).unwrap();
    let merged = merged.as_string::<i32>();
    assert_eq!(merged.len(), indices.len());

    // `none()` rows are NULL. Rows pointing at the same source (row 4 and
    // row 5 both use source 2) take that source's values in order.
    let expected = [None, Some("B"), Some("A"), None, Some("C"), Some("D")];
    for (row, want) in expected.iter().enumerate() {
        match want {
            None => assert!(!merged.is_valid(row)),
            Some(value) => {
                assert!(merged.is_valid(row));
                assert_eq!(merged.value(row), *value);
            }
        }
    }
}
#[test]
fn test_merge() {
    // Expected output: `true` rows come, in order, from the first input and
    // `false` rows from the second.
    let truthy = Arc::new(StringArray::from(vec![Some("A"), Some("C")]));
    let falsy = Arc::new(StringArray::from(vec![Some("B")]));
    let mask = BooleanArray::from(vec![true, false, true]);

    let merged =
        merge(&mask, ColumnarValue::Array(truthy), ColumnarValue::Array(falsy)).unwrap();
    let merged = merged.as_string::<i32>();
    assert_eq!(merged.len(), mask.len());

    let expected = ["A", "B", "C"];
    for (row, value) in expected.iter().enumerate() {
        assert!(merged.is_valid(row));
        assert_eq!(merged.value(row), *value);
    }
}
}