Skip to main content

datafusion_spark/function/array/
array_contains.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, AsArray, BooleanArray, BooleanBufferBuilder, GenericListArray, OffsetSizeTrait,
20};
21use arrow::buffer::{BooleanBuffer, NullBuffer};
22use arrow::datatypes::DataType;
23use datafusion_common::{Result, exec_err};
24use datafusion_expr::{
25    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27use datafusion_functions_nested::array_has::array_has_udf;
28use std::sync::Arc;
29
30/// Spark-compatible `array_contains` function.
31///
32/// Calls DataFusion's `array_has` and then applies Spark's null semantics:
33/// - If the result from `array_has` is `true`, return `true`.
34/// - If the result is `false` and the input array row contains any null elements,
35///   return `null` (because the element might have been the null).
36/// - If the result is `false` and the input array row has no null elements,
37///   return `false`.
38#[derive(Debug, PartialEq, Eq, Hash)]
39pub struct SparkArrayContains {
40    signature: Signature,
41}
42
43impl Default for SparkArrayContains {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl SparkArrayContains {
50    pub fn new() -> Self {
51        Self {
52            signature: Signature::array_and_element(Volatility::Immutable),
53        }
54    }
55}
56
57impl ScalarUDFImpl for SparkArrayContains {
58    fn name(&self) -> &str {
59        "array_contains"
60    }
61
62    fn signature(&self) -> &Signature {
63        &self.signature
64    }
65
66    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
67        Ok(DataType::Boolean)
68    }
69
70    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
71        let haystack = args.args[0].clone();
72        let array_has_result = array_has_udf().invoke_with_args(args)?;
73
74        let result_array = array_has_result.to_array(1)?;
75        let patched = apply_spark_null_semantics(result_array.as_boolean(), &haystack)?;
76        Ok(ColumnarValue::Array(Arc::new(patched)))
77    }
78}
79
80/// For each row where `array_has` returned `false`, set the output to null
81/// if that row's input array contains any null elements.
82fn apply_spark_null_semantics(
83    result: &BooleanArray,
84    haystack_arg: &ColumnarValue,
85) -> Result<BooleanArray> {
86    // happy path
87    if haystack_arg.data_type() == DataType::Null || !result.has_false() {
88        return Ok(result.clone());
89    }
90
91    let haystack = haystack_arg.to_array_of_size(result.len())?;
92
93    let row_has_nulls = compute_row_has_nulls(&haystack)?;
94
95    // A row keeps its validity when result is true OR the row has no nulls.
96    let keep_mask = result.values() | &!&row_has_nulls;
97    let new_validity = match result.nulls() {
98        Some(n) => n.inner() & &keep_mask,
99        None => keep_mask,
100    };
101
102    Ok(BooleanArray::new(
103        result.values().clone(),
104        Some(NullBuffer::new(new_validity)),
105    ))
106}
107
108/// Returns a per-row bitmap where bit i is set if row i's list contains any null element.
109fn compute_row_has_nulls(haystack: &dyn Array) -> Result<BooleanBuffer> {
110    match haystack.data_type() {
111        DataType::List(_) => generic_list_row_has_nulls(haystack.as_list::<i32>()),
112        DataType::LargeList(_) => generic_list_row_has_nulls(haystack.as_list::<i64>()),
113        DataType::FixedSizeList(_, _) => {
114            let list = haystack.as_fixed_size_list();
115            let buf = match list.values().nulls() {
116                Some(nulls) => {
117                    let validity = nulls.inner();
118                    let vl = list.value_length() as usize;
119                    let mut builder = BooleanBufferBuilder::new(list.len());
120                    for i in 0..list.len() {
121                        builder.append(validity.slice(i * vl, vl).count_set_bits() < vl);
122                    }
123                    builder.finish()
124                }
125                None => BooleanBuffer::new_unset(list.len()),
126            };
127            Ok(mask_with_list_nulls(buf, list.nulls()))
128        }
129        dt => exec_err!("compute_row_has_nulls: unsupported data type {dt}"),
130    }
131}
132
133/// Computes per-row null presence for `List` and `LargeList` arrays.
134fn generic_list_row_has_nulls<O: OffsetSizeTrait>(
135    list: &GenericListArray<O>,
136) -> Result<BooleanBuffer> {
137    let buf = match list.values().nulls() {
138        Some(nulls) => {
139            let validity = nulls.inner();
140            let offsets = list.offsets();
141            let mut builder = BooleanBufferBuilder::new(list.len());
142            for i in 0..list.len() {
143                let s = offsets[i].as_usize();
144                let len = offsets[i + 1].as_usize() - s;
145                builder.append(validity.slice(s, len).count_set_bits() < len);
146            }
147            builder.finish()
148        }
149        None => BooleanBuffer::new_unset(list.len()),
150    };
151    Ok(mask_with_list_nulls(buf, list.nulls()))
152}
153
154/// Rows where the list itself is null should not be marked as "has nulls".
155fn mask_with_list_nulls(
156    buf: BooleanBuffer,
157    list_nulls: Option<&NullBuffer>,
158) -> BooleanBuffer {
159    match list_nulls {
160        Some(n) => &buf & n.inner(),
161        None => buf,
162    }
163}