use arrow::array::{
Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::DataType;
use datafusion_common::{Result, exec_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::array_has::array_has_udf;
use std::any::Any;
use std::sync::Arc;
/// Scalar UDF exposed under the name `array_contains`.
///
/// Evaluation delegates to DataFusion's `array_has` and then post-processes
/// the boolean result: a row that evaluated to `false` while its list holds
/// at least one null element is turned into null (the "Spark null semantics"
/// that the helper names in this file refer to).
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArrayContains {
/// Argument signature: an array plus an element, with immutable volatility.
signature: Signature,
}
impl Default for SparkArrayContains {
fn default() -> Self {
Self::new()
}
}
impl SparkArrayContains {
    /// Creates the UDF with an immutable `array_and_element` signature.
    pub fn new() -> Self {
        let signature = Signature::array_and_element(Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for SparkArrayContains {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array_contains"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
Ok(DataType::Boolean)
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let haystack = args.args[0].clone();
let array_has_result = array_has_udf().invoke_with_args(args)?;
let result_array = array_has_result.to_array(1)?;
let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?;
Ok(ColumnarValue::Array(Arc::new(patched)))
}
}
/// Rewrites the validity of an `array_has` result: every row that is `false`
/// while its list contains at least one null element becomes null.
///
/// * `result` - boolean output of `array_has`.
/// * `haystack_arg` - the list argument that was searched; it is expanded to
///   `result.len()` rows before inspection.
///
/// The values buffer is reused unchanged; only the null buffer is rebuilt.
/// When there is no `false` row, or the list argument has type `Null`, the
/// input is returned as-is.
fn apply_spark_null_semantics(
    result: &BooleanArray,
    haystack_arg: &ColumnarValue,
) -> Result<BooleanArray> {
    // Fast path: nothing can change, so skip materialising the list argument.
    let nothing_to_patch =
        result.false_count() == 0 || haystack_arg.data_type() == DataType::Null;
    if nothing_to_patch {
        return Ok(result.clone());
    }

    let lists = haystack_arg.to_array_of_size(result.len())?;
    let contains_null = compute_row_has_nulls(&lists)?;

    // A row stays valid when it matched (`true`) or its list has no null
    // elements.
    let free_of_nulls = !&contains_null;
    let stays_valid = result.values() | &free_of_nulls;

    // Rows that were already null in `result` must remain null.
    let validity = if let Some(existing) = result.nulls() {
        existing.inner() & &stays_valid
    } else {
        stays_valid
    };

    let values = result.values().clone();
    Ok(BooleanArray::new(values, Some(NullBuffer::new(validity))))
}
/// Builds a per-row bitmap in which bit `i` is set when row `i` of `haystack`
/// is a non-null list that contains at least one null element.
///
/// Supports `List`, `LargeList` and `FixedSizeList`; any other data type
/// yields an execution error.
fn compute_row_has_nulls(haystack: &dyn Array) -> Result<BooleanBuffer> {
    match haystack.data_type() {
        DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::<i32>()),
        DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::<i64>()),
        DataType::FixedSizeList(_, _) => {
            let fixed = haystack.as_fixed_size_list();
            let row_count = fixed.len();

            let per_row = if let Some(child_nulls) = fixed.values().nulls() {
                let child_validity = child_nulls.inner();
                // Each row owns `width` consecutive child slots.
                let width = fixed.value_length() as usize;
                let mut bits = BooleanBufferBuilder::new(row_count);
                for row in 0..row_count {
                    let valid_slots =
                        child_validity.slice(row * width, width).count_set_bits();
                    // Fewer valid slots than slots means a null element.
                    bits.append(valid_slots < width);
                }
                bits.finish()
            } else {
                // No null buffer on the child values: no row can hold a null.
                BooleanBuffer::new_unset(row_count)
            };

            Ok(mask_with_list_nulls(per_row, fixed.nulls()))
        }
        dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"),
    }
}
/// Per-row "contains a null element" bitmap for offset-based list arrays
/// (`List` with `i32` offsets, `LargeList` with `i64` offsets).
///
/// Rows where the list itself is null are reported as unset.
fn generic_list_row_has_nulls<O: OffsetSizeTrait>(
    list: &GenericListArray<O>,
) -> Result<BooleanBuffer> {
    let row_count = list.len();

    let per_row = if let Some(child_nulls) = list.values().nulls() {
        let child_validity = child_nulls.inner();
        let mut bits = BooleanBufferBuilder::new(row_count);
        // Consecutive offset pairs delimit each row's range in the child
        // values.
        for bounds in list.offsets().windows(2) {
            let start = bounds[0].as_usize();
            let width = bounds[1].as_usize() - start;
            let valid_slots = child_validity.slice(start, width).count_set_bits();
            // Fewer valid slots than slots means a null element.
            bits.append(valid_slots < width);
        }
        bits.finish()
    } else {
        // No null buffer on the child values: no row can hold a null.
        BooleanBuffer::new_unset(row_count)
    };

    Ok(mask_with_list_nulls(per_row, list.nulls()))
}
/// Clears the bits of `buf` for rows whose list is itself null, so those rows
/// are never reported as "containing a null element".
///
/// When `list_nulls` is `None`, `buf` is returned untouched.
fn mask_with_list_nulls(
    buf: BooleanBuffer,
    list_nulls: Option<&NullBuffer>,
) -> BooleanBuffer {
    if let Some(row_validity) = list_nulls {
        // Validity bit 0 (null list) forces the corresponding output bit to 0.
        return &buf & row_validity.inner();
    }
    buf
}