datafusion_spark/function/array/
array_contains.rs1use arrow::array::{
19 Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait,
20};
21use arrow::buffer::{BooleanBuffer, NullBuffer};
22use arrow::datatypes::DataType;
23use datafusion_common::{Result, exec_err};
24use datafusion_expr::{
25 ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27use datafusion_functions_nested::array_has::array_has_udf;
28use std::any::Any;
29use std::sync::Arc;
30
31#[derive(Debug, PartialEq, Eq, Hash)]
40pub struct SparkArrayContains {
41 signature: Signature,
42}
43
44impl Default for SparkArrayContains {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl SparkArrayContains {
51 pub fn new() -> Self {
52 Self {
53 signature: Signature::array_and_element(Volatility::Immutable),
54 }
55 }
56}
57
58impl ScalarUDFImpl for SparkArrayContains {
59 fn as_any(&self) -> &dyn Any {
60 self
61 }
62
63 fn name(&self) -> &str {
64 "array_contains"
65 }
66
67 fn signature(&self) -> &Signature {
68 &self.signature
69 }
70
71 fn return_type(&self, _: &[DataType]) -> Result<DataType> {
72 Ok(DataType::Boolean)
73 }
74
75 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
76 let haystack = args.args[0].clone();
77 let array_has_result = array_has_udf().invoke_with_args(args)?;
78
79 let result_array = array_has_result.to_array(1)?;
80 let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?;
81 Ok(ColumnarValue::Array(Arc::new(patched)))
82 }
83}
84
85fn apply_spark_null_semantics(
88 result: &BooleanArray,
89 haystack_arg: &ColumnarValue,
90) -> Result<BooleanArray> {
91 if result.false_count() == 0 || haystack_arg.data_type() == DataType::Null {
93 return Ok(result.clone());
94 }
95
96 let haystack = haystack_arg.to_array_of_size(result.len())?;
97
98 let row_has_nulls = compute_row_has_nulls(&haystack)?;
99
100 let keep_mask = result.values() | &!&row_has_nulls;
102 let new_validity = match result.nulls() {
103 Some(n) => n.inner() & &keep_mask,
104 None => keep_mask,
105 };
106
107 Ok(BooleanArray::new(
108 result.values().clone(),
109 Some(NullBuffer::new(new_validity)),
110 ))
111}
112
113fn compute_row_has_nulls(haystack: &dyn Array) -> Result<BooleanBuffer> {
115 match haystack.data_type() {
116 DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::<i32>()),
117 DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::<i64>()),
118 DataType::FixedSizeList(_, _) => {
119 let list = haystack.as_fixed_size_list();
120 let buf = match list.values().nulls() {
121 Some(nulls) => {
122 let validity = nulls.inner();
123 let vl = list.value_length() as usize;
124 let mut builder = BooleanBufferBuilder::new(list.len());
125 for i in 0..list.len() {
126 builder.append(validity.slice(i * vl, vl).count_set_bits() < vl);
127 }
128 builder.finish()
129 }
130 None => BooleanBuffer::new_unset(list.len()),
131 };
132 Ok(mask_with_list_nulls(buf, list.nulls()))
133 }
134 dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"),
135 }
136}
137
138fn generic_list_row_has_nulls<O: OffsetSizeTrait>(
140 list: &GenericListArray<O>,
141) -> Result<BooleanBuffer> {
142 let buf = match list.values().nulls() {
143 Some(nulls) => {
144 let validity = nulls.inner();
145 let offsets = list.offsets();
146 let mut builder = BooleanBufferBuilder::new(list.len());
147 for i in 0..list.len() {
148 let s = offsets[i].as_usize();
149 let len = offsets[i + 1].as_usize() - s;
150 builder.append(validity.slice(s, len).count_set_bits() < len);
151 }
152 builder.finish()
153 }
154 None => BooleanBuffer::new_unset(list.len()),
155 };
156 Ok(mask_with_list_nulls(buf, list.nulls()))
157}
158
159fn mask_with_list_nulls(
161 buf: BooleanBuffer,
162 list_nulls: Option<&NullBuffer>,
163) -> BooleanBuffer {
164 match list_nulls {
165 Some(n) => &buf & n.inner(),
166 None => buf,
167 }
168}