datafusion_functions_nested/
position.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_position and array_positions functions.
19
20use arrow::datatypes::DataType;
21use arrow::datatypes::{
22    DataType::{LargeList, List, UInt64},
23    Field,
24};
25use datafusion_expr::{
26    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
27};
28use datafusion_macros::user_doc;
29
30use std::any::Any;
31use std::sync::Arc;
32
33use arrow::array::{
34    types::UInt64Type, Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait,
35    UInt64Array,
36};
37use datafusion_common::cast::{
38    as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
39};
40use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result};
41use itertools::Itertools;
42
43use crate::utils::{compare_element_to_list, make_scalar_function};
44
45make_udf_expr_and_func!(
46    ArrayPosition,
47    array_position,
48    array element index,
49    "searches for an element in the array, returns first occurrence.",
50    array_position_udf
51);
52
53#[user_doc(
54    doc_section(label = "Array Functions"),
55    description = "Returns the position of the first occurrence of the specified element in the array.",
56    syntax_example = "array_position(array, element)\narray_position(array, element, index)",
57    sql_example = r#"```sql
58> select array_position([1, 2, 2, 3, 1, 4], 2);
59+----------------------------------------------+
60| array_position(List([1,2,2,3,1,4]),Int64(2)) |
61+----------------------------------------------+
62| 2                                            |
63+----------------------------------------------+
64> select array_position([1, 2, 2, 3, 1, 4], 2, 3);
65+----------------------------------------------------+
66| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) |
67+----------------------------------------------------+
68| 3                                                  |
69+----------------------------------------------------+
70```"#,
71    argument(
72        name = "array",
73        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
74    ),
75    argument(
76        name = "element",
77        description = "Element to search for position in the array."
78    ),
79    argument(name = "index", description = "Index at which to start searching.")
80)]
81#[derive(Debug)]
82pub struct ArrayPosition {
83    signature: Signature,
84    aliases: Vec<String>,
85}
86
87impl Default for ArrayPosition {
88    fn default() -> Self {
89        Self::new()
90    }
91}
92impl ArrayPosition {
93    pub fn new() -> Self {
94        Self {
95            signature: Signature::array_and_element_and_optional_index(
96                Volatility::Immutable,
97            ),
98            aliases: vec![
99                String::from("list_position"),
100                String::from("array_indexof"),
101                String::from("list_indexof"),
102            ],
103        }
104    }
105}
106
107impl ScalarUDFImpl for ArrayPosition {
108    fn as_any(&self) -> &dyn Any {
109        self
110    }
111    fn name(&self) -> &str {
112        "array_position"
113    }
114
115    fn signature(&self) -> &Signature {
116        &self.signature
117    }
118
119    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
120        Ok(UInt64)
121    }
122
123    fn invoke_with_args(
124        &self,
125        args: datafusion_expr::ScalarFunctionArgs,
126    ) -> Result<ColumnarValue> {
127        make_scalar_function(array_position_inner)(&args.args)
128    }
129
130    fn aliases(&self) -> &[String] {
131        &self.aliases
132    }
133
134    fn documentation(&self) -> Option<&Documentation> {
135        self.doc()
136    }
137}
138
139/// Array_position SQL function
140pub fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
141    if args.len() < 2 || args.len() > 3 {
142        return exec_err!("array_position expects two or three arguments");
143    }
144    match &args[0].data_type() {
145        List(_) => general_position_dispatch::<i32>(args),
146        LargeList(_) => general_position_dispatch::<i64>(args),
147        array_type => exec_err!("array_position does not support type '{array_type:?}'."),
148    }
149}
150fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
151    let list_array = as_generic_list_array::<O>(&args[0])?;
152    let element_array = &args[1];
153
154    crate::utils::check_datatypes(
155        "array_position",
156        &[list_array.values(), element_array],
157    )?;
158
159    let arr_from = if args.len() == 3 {
160        as_int64_array(&args[2])?
161            .values()
162            .to_vec()
163            .iter()
164            .map(|&x| x - 1)
165            .collect::<Vec<_>>()
166    } else {
167        vec![0; list_array.len()]
168    };
169
170    // if `start_from` index is out of bounds, return error
171    for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
172        if let Some(arr) = arr {
173            if from < 0 || from as usize >= arr.len() {
174                return internal_err!("start_from index out of bounds");
175            }
176        } else {
177            // We will get null if we got null in the array, so we don't need to check
178        }
179    }
180
181    generic_position::<O>(list_array, element_array, arr_from)
182}
183
184fn generic_position<OffsetSize: OffsetSizeTrait>(
185    list_array: &GenericListArray<OffsetSize>,
186    element_array: &ArrayRef,
187    arr_from: Vec<i64>, // 0-indexed
188) -> Result<ArrayRef> {
189    let mut data = Vec::with_capacity(list_array.len());
190
191    for (row_index, (list_array_row, &from)) in
192        list_array.iter().zip(arr_from.iter()).enumerate()
193    {
194        let from = from as usize;
195
196        if let Some(list_array_row) = list_array_row {
197            let eq_array =
198                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
199
200            // Collect `true`s in 1-indexed positions
201            let index = eq_array
202                .iter()
203                .skip(from)
204                .position(|e| e == Some(true))
205                .map(|index| (from + index + 1) as u64);
206
207            data.push(index);
208        } else {
209            data.push(None);
210        }
211    }
212
213    Ok(Arc::new(UInt64Array::from(data)))
214}
215
216make_udf_expr_and_func!(
217    ArrayPositions,
218    array_positions,
219    array element, // arg name
220    "searches for an element in the array, returns all occurrences.", // doc
221    array_positions_udf // internal function name
222);
223
224#[user_doc(
225    doc_section(label = "Array Functions"),
226    description = "Searches for an element in the array, returns all occurrences.",
227    syntax_example = "array_positions(array, element)",
228    sql_example = r#"```sql
229> select array_positions([1, 2, 2, 3, 1, 4], 2);
230+-----------------------------------------------+
231| array_positions(List([1,2,2,3,1,4]),Int64(2)) |
232+-----------------------------------------------+
233| [2, 3]                                        |
234+-----------------------------------------------+
235```"#,
236    argument(
237        name = "array",
238        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
239    ),
240    argument(
241        name = "element",
242        description = "Element to search for position in the array."
243    )
244)]
245#[derive(Debug)]
246pub(super) struct ArrayPositions {
247    signature: Signature,
248    aliases: Vec<String>,
249}
250
251impl ArrayPositions {
252    pub fn new() -> Self {
253        Self {
254            signature: Signature::array_and_element(Volatility::Immutable),
255            aliases: vec![String::from("list_positions")],
256        }
257    }
258}
259
260impl ScalarUDFImpl for ArrayPositions {
261    fn as_any(&self) -> &dyn Any {
262        self
263    }
264    fn name(&self) -> &str {
265        "array_positions"
266    }
267
268    fn signature(&self) -> &Signature {
269        &self.signature
270    }
271
272    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
273        Ok(List(Arc::new(Field::new_list_field(UInt64, true))))
274    }
275
276    fn invoke_with_args(
277        &self,
278        args: datafusion_expr::ScalarFunctionArgs,
279    ) -> Result<ColumnarValue> {
280        make_scalar_function(array_positions_inner)(&args.args)
281    }
282
283    fn aliases(&self) -> &[String] {
284        &self.aliases
285    }
286
287    fn documentation(&self) -> Option<&Documentation> {
288        self.doc()
289    }
290}
291
292/// Array_positions SQL function
293pub fn array_positions_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
294    let [array, element] = take_function_args("array_positions", args)?;
295
296    match &array.data_type() {
297        List(_) => {
298            let arr = as_list_array(&array)?;
299            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
300            general_positions::<i32>(arr, element)
301        }
302        LargeList(_) => {
303            let arr = as_large_list_array(&array)?;
304            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
305            general_positions::<i64>(arr, element)
306        }
307        array_type => {
308            exec_err!("array_positions does not support type '{array_type:?}'.")
309        }
310    }
311}
312
313fn general_positions<OffsetSize: OffsetSizeTrait>(
314    list_array: &GenericListArray<OffsetSize>,
315    element_array: &ArrayRef,
316) -> Result<ArrayRef> {
317    let mut data = Vec::with_capacity(list_array.len());
318
319    for (row_index, list_array_row) in list_array.iter().enumerate() {
320        if let Some(list_array_row) = list_array_row {
321            let eq_array =
322                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
323
324            // Collect `true`s in 1-indexed positions
325            let indexes = eq_array
326                .iter()
327                .positions(|e| e == Some(true))
328                .map(|index| Some(index as u64 + 1))
329                .collect::<Vec<_>>();
330
331            data.push(Some(indexes));
332        } else {
333            data.push(None);
334        }
335    }
336
337    Ok(Arc::new(
338        ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
339    ))
340}