datafusion_functions_nested/
position.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_position and array_positions functions.
19
20use arrow::datatypes::DataType;
21use arrow::datatypes::{
22    DataType::{LargeList, List, UInt64},
23    Field,
24};
25use datafusion_expr::{
26    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
27};
28use datafusion_macros::user_doc;
29
30use std::any::Any;
31use std::sync::Arc;
32
33use arrow::array::{
34    types::UInt64Type, Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait,
35    UInt64Array,
36};
37use datafusion_common::cast::{
38    as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
39};
40use datafusion_common::{exec_err, internal_err, utils::take_function_args, Result};
41use itertools::Itertools;
42
43use crate::utils::{compare_element_to_list, make_scalar_function};
44
// Wires `ArrayPosition` into `make_udf_expr_and_func!`.
// NOTE(review): the macro is defined outside this file; it presumably emits the
// `array_position` expression helper and the `array_position_udf` accessor.
// The per-argument labels below follow the ones already written on the
// `array_positions` invocation of the same macro — confirm against the macro.
45make_udf_expr_and_func!(
46    ArrayPosition,
47    array_position,
48    array element index, // arg names
49    "searches for an element in the array, returns first occurrence.", // doc
50    array_position_udf // internal function name
51);
52
53#[user_doc(
54    doc_section(label = "Array Functions"),
55    description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found.",
56    syntax_example = "array_position(array, element)\narray_position(array, element, index)",
57    sql_example = r#"```sql
58> select array_position([1, 2, 2, 3, 1, 4], 2);
59+----------------------------------------------+
60| array_position(List([1,2,2,3,1,4]),Int64(2)) |
61+----------------------------------------------+
62| 2                                            |
63+----------------------------------------------+
64> select array_position([1, 2, 2, 3, 1, 4], 2, 3);
65+----------------------------------------------------+
66| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) |
67+----------------------------------------------------+
68| 3                                                  |
69+----------------------------------------------------+
70```"#,
71    argument(
72        name = "array",
73        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
74    ),
75    argument(
76        name = "element",
77        description = "Element to search for position in the array."
78    ),
79    argument(
80        name = "index",
81        description = "Index at which to start searching (1-indexed)."
82    )
83)]
/// Scalar UDF implementation of `array_position`.
///
/// Its `ScalarUDFImpl` impl reports the name `array_position`, always returns
/// `UInt64`, and evaluates through `array_position_inner`.
84#[derive(Debug)]
85pub struct ArrayPosition {
    /// Signature returned by `signature()`; built in `new()` with
    /// `Signature::array_and_element_and_optional_index(Volatility::Immutable)`.
86    signature: Signature,
    /// Alternative names returned by `aliases()`: `list_position`,
    /// `array_indexof` and `list_indexof`.
87    aliases: Vec<String>,
88}
89
90impl Default for ArrayPosition {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95impl ArrayPosition {
96    pub fn new() -> Self {
97        Self {
98            signature: Signature::array_and_element_and_optional_index(
99                Volatility::Immutable,
100            ),
101            aliases: vec![
102                String::from("list_position"),
103                String::from("array_indexof"),
104                String::from("list_indexof"),
105            ],
106        }
107    }
108}
109
110impl ScalarUDFImpl for ArrayPosition {
111    fn as_any(&self) -> &dyn Any {
112        self
113    }
114    fn name(&self) -> &str {
115        "array_position"
116    }
117
118    fn signature(&self) -> &Signature {
119        &self.signature
120    }
121
122    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
123        Ok(UInt64)
124    }
125
126    fn invoke_with_args(
127        &self,
128        args: datafusion_expr::ScalarFunctionArgs,
129    ) -> Result<ColumnarValue> {
130        make_scalar_function(array_position_inner)(&args.args)
131    }
132
133    fn aliases(&self) -> &[String] {
134        &self.aliases
135    }
136
137    fn documentation(&self) -> Option<&Documentation> {
138        self.doc()
139    }
140}
141
142/// Array_position SQL function
143pub fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
144    if args.len() < 2 || args.len() > 3 {
145        return exec_err!("array_position expects two or three arguments");
146    }
147    match &args[0].data_type() {
148        List(_) => general_position_dispatch::<i32>(args),
149        LargeList(_) => general_position_dispatch::<i64>(args),
150        array_type => exec_err!("array_position does not support type '{array_type:?}'."),
151    }
152}
153fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
154    let list_array = as_generic_list_array::<O>(&args[0])?;
155    let element_array = &args[1];
156
157    crate::utils::check_datatypes(
158        "array_position",
159        &[list_array.values(), element_array],
160    )?;
161
162    let arr_from = if args.len() == 3 {
163        as_int64_array(&args[2])?
164            .values()
165            .to_vec()
166            .iter()
167            .map(|&x| x - 1)
168            .collect::<Vec<_>>()
169    } else {
170        vec![0; list_array.len()]
171    };
172
173    // if `start_from` index is out of bounds, return error
174    for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
175        if let Some(arr) = arr {
176            if from < 0 || from as usize > arr.len() {
177                return internal_err!("start_from index out of bounds");
178            }
179        } else {
180            // We will get null if we got null in the array, so we don't need to check
181        }
182    }
183
184    generic_position::<O>(list_array, element_array, arr_from)
185}
186
187fn generic_position<OffsetSize: OffsetSizeTrait>(
188    list_array: &GenericListArray<OffsetSize>,
189    element_array: &ArrayRef,
190    arr_from: Vec<i64>, // 0-indexed
191) -> Result<ArrayRef> {
192    let mut data = Vec::with_capacity(list_array.len());
193
194    for (row_index, (list_array_row, &from)) in
195        list_array.iter().zip(arr_from.iter()).enumerate()
196    {
197        let from = from as usize;
198
199        if let Some(list_array_row) = list_array_row {
200            let eq_array =
201                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
202
203            // Collect `true`s in 1-indexed positions
204            let index = eq_array
205                .iter()
206                .skip(from)
207                .position(|e| e == Some(true))
208                .map(|index| (from + index + 1) as u64);
209
210            data.push(index);
211        } else {
212            data.push(None);
213        }
214    }
215
216    Ok(Arc::new(UInt64Array::from(data)))
217}
218
// Wires `ArrayPositions` into `make_udf_expr_and_func!`.
// NOTE(review): the macro is defined outside this file; it presumably emits the
// `array_positions` expression helper and the `array_positions_udf` accessor —
// confirm against the macro definition.
219make_udf_expr_and_func!(
220    ArrayPositions,
221    array_positions,
222    array element, // arg name
223    "searches for an element in the array, returns all occurrences.", // doc
224    array_positions_udf // internal function name
225);
226
227#[user_doc(
228    doc_section(label = "Array Functions"),
229    description = "Searches for an element in the array, returns all occurrences.",
230    syntax_example = "array_positions(array, element)",
231    sql_example = r#"```sql
232> select array_positions([1, 2, 2, 3, 1, 4], 2);
233+-----------------------------------------------+
234| array_positions(List([1,2,2,3,1,4]),Int64(2)) |
235+-----------------------------------------------+
236| [2, 3]                                        |
237+-----------------------------------------------+
238```"#,
239    argument(
240        name = "array",
241        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
242    ),
243    argument(
244        name = "element",
245        description = "Element to search for position in the array."
246    )
247)]
/// Scalar UDF implementation of `array_positions`.
///
/// Its `ScalarUDFImpl` impl reports the name `array_positions`, returns a
/// `List` of nullable `UInt64`, and evaluates through `array_positions_inner`.
248#[derive(Debug)]
249pub(super) struct ArrayPositions {
    /// Signature returned by `signature()`; built in `new()` with
    /// `Signature::array_and_element(Volatility::Immutable)`.
250    signature: Signature,
    /// Alternative names returned by `aliases()`: `list_positions`.
251    aliases: Vec<String>,
252}
253
254impl ArrayPositions {
255    pub fn new() -> Self {
256        Self {
257            signature: Signature::array_and_element(Volatility::Immutable),
258            aliases: vec![String::from("list_positions")],
259        }
260    }
261}
262
263impl ScalarUDFImpl for ArrayPositions {
264    fn as_any(&self) -> &dyn Any {
265        self
266    }
267    fn name(&self) -> &str {
268        "array_positions"
269    }
270
271    fn signature(&self) -> &Signature {
272        &self.signature
273    }
274
275    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
276        Ok(List(Arc::new(Field::new_list_field(UInt64, true))))
277    }
278
279    fn invoke_with_args(
280        &self,
281        args: datafusion_expr::ScalarFunctionArgs,
282    ) -> Result<ColumnarValue> {
283        make_scalar_function(array_positions_inner)(&args.args)
284    }
285
286    fn aliases(&self) -> &[String] {
287        &self.aliases
288    }
289
290    fn documentation(&self) -> Option<&Documentation> {
291        self.doc()
292    }
293}
294
295/// Array_positions SQL function
296pub fn array_positions_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
297    let [array, element] = take_function_args("array_positions", args)?;
298
299    match &array.data_type() {
300        List(_) => {
301            let arr = as_list_array(&array)?;
302            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
303            general_positions::<i32>(arr, element)
304        }
305        LargeList(_) => {
306            let arr = as_large_list_array(&array)?;
307            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
308            general_positions::<i64>(arr, element)
309        }
310        array_type => {
311            exec_err!("array_positions does not support type '{array_type:?}'.")
312        }
313    }
314}
315
316fn general_positions<OffsetSize: OffsetSizeTrait>(
317    list_array: &GenericListArray<OffsetSize>,
318    element_array: &ArrayRef,
319) -> Result<ArrayRef> {
320    let mut data = Vec::with_capacity(list_array.len());
321
322    for (row_index, list_array_row) in list_array.iter().enumerate() {
323        if let Some(list_array_row) = list_array_row {
324            let eq_array =
325                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
326
327            // Collect `true`s in 1-indexed positions
328            let indexes = eq_array
329                .iter()
330                .positions(|e| e == Some(true))
331                .map(|index| Some(index as u64 + 1))
332                .collect::<Vec<_>>();
333
334            data.push(Some(indexes));
335        } else {
336            data.push(None);
337        }
338    }
339
340    Ok(Arc::new(
341        ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
342    ))
343}