Skip to main content

datafusion_spark/function/array/
array_contains.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use arrow::array::{
    Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::DataType;
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::array_has::array_has_udf;
use std::any::Any;
use std::sync::Arc;
30
31/// Spark-compatible `array_contains` function.
32///
33/// Calls DataFusion's `array_has` and then applies Spark's null semantics:
34/// - If the result from `array_has` is `true`, return `true`.
35/// - If the result is `false` and the input array row contains any null elements,
36///   return `null` (because the element might have been the null).
37/// - If the result is `false` and the input array row has no null elements,
38///   return `false`.
39#[derive(Debug, PartialEq, Eq, Hash)]
40pub struct SparkArrayContains {
41    signature: Signature,
42}
43
44impl Default for SparkArrayContains {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl SparkArrayContains {
51    pub fn new() -> Self {
52        Self {
53            signature: Signature::array_and_element(Volatility::Immutable),
54        }
55    }
56}
57
58impl ScalarUDFImpl for SparkArrayContains {
59    fn as_any(&self) -> &dyn Any {
60        self
61    }
62
63    fn name(&self) -> &str {
64        "array_contains"
65    }
66
67    fn signature(&self) -> &Signature {
68        &self.signature
69    }
70
71    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
72        Ok(DataType::Boolean)
73    }
74
75    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
76        let haystack = args.args[0].clone();
77        let array_has_result = array_has_udf().invoke_with_args(args)?;
78
79        let result_array = array_has_result.to_array(1)?;
80        let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?;
81        Ok(ColumnarValue::Array(Arc::new(patched)))
82    }
83}
84
85/// For each row where `array_has` returned `false`, set the output to null
86/// if that row's input array contains any null elements.
87fn apply_spark_null_semantics(
88    result: &BooleanArray,
89    haystack_arg: &ColumnarValue,
90) -> Result<BooleanArray> {
91    // happy path
92    if result.false_count() == 0 || haystack_arg.data_type() == DataType::Null {
93        return Ok(result.clone());
94    }
95
96    let haystack = haystack_arg.to_array_of_size(result.len())?;
97
98    let row_has_nulls = compute_row_has_nulls(&haystack)?;
99
100    // A row keeps its validity when result is true OR the row has no nulls.
101    let keep_mask = result.values() | &!&row_has_nulls;
102    let new_validity = match result.nulls() {
103        Some(n) => n.inner() & &keep_mask,
104        None => keep_mask,
105    };
106
107    Ok(BooleanArray::new(
108        result.values().clone(),
109        Some(NullBuffer::new(new_validity)),
110    ))
111}
112
113/// Returns a per-row bitmap where bit i is set if row i's list contains any null element.
114fn compute_row_has_nulls(haystack: &dyn Array) -> Result<BooleanBuffer> {
115    match haystack.data_type() {
116        DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::<i32>()),
117        DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::<i64>()),
118        DataType::FixedSizeList(_, _) => {
119            let list = haystack.as_fixed_size_list();
120            let buf = match list.values().nulls() {
121                Some(nulls) => {
122                    let validity = nulls.inner();
123                    let vl = list.value_length() as usize;
124                    let mut builder = BooleanBufferBuilder::new(list.len());
125                    for i in 0..list.len() {
126                        builder.append(validity.slice(i * vl, vl).count_set_bits() < vl);
127                    }
128                    builder.finish()
129                }
130                None => BooleanBuffer::new_unset(list.len()),
131            };
132            Ok(mask_with_list_nulls(buf, list.nulls()))
133        }
134        dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"),
135    }
136}
137
138/// Computes per-row null presence for `List` and `LargeList` arrays.
139fn generic_list_row_has_nulls<O: OffsetSizeTrait>(
140    list: &GenericListArray<O>,
141) -> Result<BooleanBuffer> {
142    let buf = match list.values().nulls() {
143        Some(nulls) => {
144            let validity = nulls.inner();
145            let offsets = list.offsets();
146            let mut builder = BooleanBufferBuilder::new(list.len());
147            for i in 0..list.len() {
148                let s = offsets[i].as_usize();
149                let len = offsets[i + 1].as_usize() - s;
150                builder.append(validity.slice(s, len).count_set_bits() < len);
151            }
152            builder.finish()
153        }
154        None => BooleanBuffer::new_unset(list.len()),
155    };
156    Ok(mask_with_list_nulls(buf, list.nulls()))
157}
158
159/// Rows where the list itself is null should not be marked as "has nulls".
160fn mask_with_list_nulls(
161    buf: BooleanBuffer,
162    list_nulls: Option<&NullBuffer>,
163) -> BooleanBuffer {
164    match list_nulls {
165        Some(n) => &buf & n.inner(),
166        None => buf,
167    }
168}