datafusion_functions_nested/
position.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_position and array_positions functions.
19
20use arrow::datatypes::DataType;
21use arrow::datatypes::{
22    DataType::{LargeList, List, UInt64},
23    Field,
24};
25use datafusion_expr::{
26    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
27};
28use datafusion_macros::user_doc;
29
30use std::any::Any;
31use std::sync::Arc;
32
33use arrow::array::{
34    Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array,
35    types::UInt64Type,
36};
37use datafusion_common::cast::{
38    as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
39};
40use datafusion_common::{
41    Result, assert_or_internal_err, exec_err, utils::take_function_args,
42};
43use itertools::Itertools;
44
45use crate::utils::{compare_element_to_list, make_scalar_function};
46
// Invokes the crate's `make_udf_expr_and_func!` macro for `ArrayPosition`.
// Judging by its arguments, it presumably generates an
// `array_position(array, element, index)` expression helper and an
// `array_position_udf` accessor. See the macro definition for the exact items.
make_udf_expr_and_func!(
    ArrayPosition,
    array_position,
    array element index,
    "searches for an element in the array, returns first occurrence.",
    array_position_udf
);
54
55#[user_doc(
56    doc_section(label = "Array Functions"),
57    description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found.",
58    syntax_example = "array_position(array, element)\narray_position(array, element, index)",
59    sql_example = r#"```sql
60> select array_position([1, 2, 2, 3, 1, 4], 2);
61+----------------------------------------------+
62| array_position(List([1,2,2,3,1,4]),Int64(2)) |
63+----------------------------------------------+
64| 2                                            |
65+----------------------------------------------+
66> select array_position([1, 2, 2, 3, 1, 4], 2, 3);
67+----------------------------------------------------+
68| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) |
69+----------------------------------------------------+
70| 3                                                  |
71+----------------------------------------------------+
72```"#,
73    argument(
74        name = "array",
75        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
76    ),
77    argument(
78        name = "element",
79        description = "Element to search for position in the array."
80    ),
81    argument(
82        name = "index",
83        description = "Index at which to start searching (1-indexed)."
84    )
85)]
86#[derive(Debug, PartialEq, Eq, Hash)]
87pub struct ArrayPosition {
88    signature: Signature,
89    aliases: Vec<String>,
90}
91
92impl Default for ArrayPosition {
93    fn default() -> Self {
94        Self::new()
95    }
96}
97impl ArrayPosition {
98    pub fn new() -> Self {
99        Self {
100            signature: Signature::array_and_element_and_optional_index(
101                Volatility::Immutable,
102            ),
103            aliases: vec![
104                String::from("list_position"),
105                String::from("array_indexof"),
106                String::from("list_indexof"),
107            ],
108        }
109    }
110}
111
112impl ScalarUDFImpl for ArrayPosition {
113    fn as_any(&self) -> &dyn Any {
114        self
115    }
116    fn name(&self) -> &str {
117        "array_position"
118    }
119
120    fn signature(&self) -> &Signature {
121        &self.signature
122    }
123
124    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
125        Ok(UInt64)
126    }
127
128    fn invoke_with_args(
129        &self,
130        args: datafusion_expr::ScalarFunctionArgs,
131    ) -> Result<ColumnarValue> {
132        make_scalar_function(array_position_inner)(&args.args)
133    }
134
135    fn aliases(&self) -> &[String] {
136        &self.aliases
137    }
138
139    fn documentation(&self) -> Option<&Documentation> {
140        self.doc()
141    }
142}
143
144fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
145    if args.len() < 2 || args.len() > 3 {
146        return exec_err!("array_position expects two or three arguments");
147    }
148    match &args[0].data_type() {
149        List(_) => general_position_dispatch::<i32>(args),
150        LargeList(_) => general_position_dispatch::<i64>(args),
151        array_type => exec_err!("array_position does not support type '{array_type}'."),
152    }
153}
154
155fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
156    let list_array = as_generic_list_array::<O>(&args[0])?;
157    let element_array = &args[1];
158
159    crate::utils::check_datatypes(
160        "array_position",
161        &[list_array.values(), element_array],
162    )?;
163
164    let arr_from = if args.len() == 3 {
165        as_int64_array(&args[2])?
166            .values()
167            .to_vec()
168            .iter()
169            .map(|&x| x - 1)
170            .collect::<Vec<_>>()
171    } else {
172        vec![0; list_array.len()]
173    };
174
175    // if `start_from` index is out of bounds, return error
176    for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
177        // If `arr` is `None`: we will get null if we got null in the array, so we don't need to check
178        assert_or_internal_err!(
179            arr.is_none_or(|arr| from >= 0 && (from as usize) <= arr.len()),
180            "start_from index out of bounds"
181        );
182    }
183
184    generic_position::<O>(list_array, element_array, &arr_from)
185}
186
187fn generic_position<OffsetSize: OffsetSizeTrait>(
188    list_array: &GenericListArray<OffsetSize>,
189    element_array: &ArrayRef,
190    arr_from: &[i64], // 0-indexed
191) -> Result<ArrayRef> {
192    let mut data = Vec::with_capacity(list_array.len());
193
194    for (row_index, (list_array_row, &from)) in
195        list_array.iter().zip(arr_from.iter()).enumerate()
196    {
197        let from = from as usize;
198
199        if let Some(list_array_row) = list_array_row {
200            let eq_array =
201                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
202
203            // Collect `true`s in 1-indexed positions
204            let index = eq_array
205                .iter()
206                .skip(from)
207                .position(|e| e == Some(true))
208                .map(|index| (from + index + 1) as u64);
209
210            data.push(index);
211        } else {
212            data.push(None);
213        }
214    }
215
216    Ok(Arc::new(UInt64Array::from(data)))
217}
218
// Invokes the crate's `make_udf_expr_and_func!` macro for `ArrayPositions`.
// Judging by its arguments, it presumably generates an
// `array_positions(array, element)` expression helper and an
// `array_positions_udf` accessor. See the macro definition for the exact items.
make_udf_expr_and_func!(
    ArrayPositions,
    array_positions,
    array element, // arg name
    "searches for an element in the array, returns all occurrences.", // doc
    array_positions_udf // internal function name
);
226
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Searches for an element in the array, returns all occurrences.",
    syntax_example = "array_positions(array, element)",
    sql_example = r#"```sql
> select array_positions([1, 2, 2, 3, 1, 4], 2);
+-----------------------------------------------+
| array_positions(List([1,2,2,3,1,4]),Int64(2)) |
+-----------------------------------------------+
| [2, 3]                                        |
+-----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to search for position in the array."
    )
)]
/// Scalar UDF `array_positions` (alias `list_positions`).
///
/// Returns a list of the 1-indexed positions of every matching element, or
/// NULL when the input list is NULL.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayPositions {
    signature: Signature,
    aliases: Vec<String>,
}
253
254impl ArrayPositions {
255    pub fn new() -> Self {
256        Self {
257            signature: Signature::array_and_element(Volatility::Immutable),
258            aliases: vec![String::from("list_positions")],
259        }
260    }
261}
262
263impl ScalarUDFImpl for ArrayPositions {
264    fn as_any(&self) -> &dyn Any {
265        self
266    }
267    fn name(&self) -> &str {
268        "array_positions"
269    }
270
271    fn signature(&self) -> &Signature {
272        &self.signature
273    }
274
275    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
276        Ok(List(Arc::new(Field::new_list_field(UInt64, true))))
277    }
278
279    fn invoke_with_args(
280        &self,
281        args: datafusion_expr::ScalarFunctionArgs,
282    ) -> Result<ColumnarValue> {
283        make_scalar_function(array_positions_inner)(&args.args)
284    }
285
286    fn aliases(&self) -> &[String] {
287        &self.aliases
288    }
289
290    fn documentation(&self) -> Option<&Documentation> {
291        self.doc()
292    }
293}
294
295fn array_positions_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
296    let [array, element] = take_function_args("array_positions", args)?;
297
298    match &array.data_type() {
299        List(_) => {
300            let arr = as_list_array(&array)?;
301            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
302            general_positions::<i32>(arr, element)
303        }
304        LargeList(_) => {
305            let arr = as_large_list_array(&array)?;
306            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
307            general_positions::<i64>(arr, element)
308        }
309        array_type => {
310            exec_err!("array_positions does not support type '{array_type}'.")
311        }
312    }
313}
314
315fn general_positions<OffsetSize: OffsetSizeTrait>(
316    list_array: &GenericListArray<OffsetSize>,
317    element_array: &ArrayRef,
318) -> Result<ArrayRef> {
319    let mut data = Vec::with_capacity(list_array.len());
320
321    for (row_index, list_array_row) in list_array.iter().enumerate() {
322        if let Some(list_array_row) = list_array_row {
323            let eq_array =
324                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
325
326            // Collect `true`s in 1-indexed positions
327            let indexes = eq_array
328                .iter()
329                .positions(|e| e == Some(true))
330                .map(|index| Some(index as u64 + 1))
331                .collect::<Vec<_>>();
332
333            data.push(Some(indexes));
334        } else {
335            data.push(None);
336        }
337    }
338
339    Ok(Arc::new(
340        ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
341    ))
342}