use arrow::array::Array;
use arrow::buffer::NullBuffer;
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use std::sync::Arc;
/// Outcome of inspecting a function's arguments for NULLs, as produced by
/// `compute_null_mask` and consumed by `apply_null_mask`.
pub(crate) enum NullMaskResolution {
/// Every argument was a scalar and at least one of them was null;
/// `apply_null_mask` replaces a scalar result with a null scalar of the
/// return type.
ReturnNull,
/// No argument contributed a null buffer; the result is passed through
/// unchanged.
NoMask,
/// Union of the arguments' null buffers, to be merged into an array
/// result's own null buffer.
Apply(NullBuffer),
}
pub(crate) fn compute_null_mask(
args: &[ColumnarValue],
number_rows: usize,
) -> Result<NullMaskResolution> {
let all_scalars = args
.iter()
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
if all_scalars {
for arg in args {
if let ColumnarValue::Scalar(scalar) = arg
&& scalar.is_null()
{
return Ok(NullMaskResolution::ReturnNull);
}
}
Ok(NullMaskResolution::NoMask)
} else {
let array_len = args
.iter()
.find_map(|arg| match arg {
ColumnarValue::Array(array) => Some(array.len()),
_ => None,
})
.unwrap_or(number_rows);
let arrays: Result<Vec<_>> = args
.iter()
.map(|arg| match arg {
ColumnarValue::Array(array) => Ok(Arc::clone(array)),
ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len),
})
.collect();
let arrays = arrays?;
let combined_nulls = arrays
.iter()
.map(|arr| arr.nulls())
.fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls));
match combined_nulls {
Some(nulls) => Ok(NullMaskResolution::Apply(nulls)),
None => Ok(NullMaskResolution::NoMask),
}
}
}
/// Applies a previously computed [`NullMaskResolution`] to a function result.
///
/// * `NoMask` — `result` is returned unchanged.
/// * `ReturnNull` — a scalar `result` is replaced by a null scalar built from
///   `return_type`; an array `result` is returned unchanged.
/// * `Apply(mask)` — for an array `result`, its null buffer is unioned with
///   `mask` and the array is rebuilt with the combined nulls; a scalar
///   `result` is returned unchanged.
///
/// # Errors
///
/// Returns an error if a null scalar cannot be created for `return_type`, or
/// if rebuilding the array data with the combined null buffer fails.
pub(crate) fn apply_null_mask(
    result: ColumnarValue,
    null_mask: NullMaskResolution,
    return_type: &DataType,
) -> Result<ColumnarValue> {
    match null_mask {
        NullMaskResolution::NoMask => Ok(result),
        NullMaskResolution::ReturnNull => match result {
            ColumnarValue::Scalar(_) => {
                Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?))
            }
            // An array result is passed through as-is for this resolution.
            array @ ColumnarValue::Array(_) => Ok(array),
        },
        NullMaskResolution::Apply(null_mask) => match result {
            ColumnarValue::Array(array) => {
                let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask));
                let new_array = array
                    .into_data()
                    .into_builder()
                    .nulls(combined_nulls)
                    .build()?;
                // `make_array` already returns an `ArrayRef`; wrapping it in
                // another `Arc` would only add an allocation and an extra
                // level of indirection.
                Ok(ColumnarValue::Array(arrow::array::make_array(new_array)))
            }
            // A scalar result is passed through as-is for this resolution.
            scalar @ ColumnarValue::Scalar(_) => Ok(scalar),
        },
    }
}