Skip to main content

datafusion_functions_nested/
position.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_position and array_positions functions.
19
20use arrow::array::Scalar;
21use arrow::datatypes::DataType;
22use arrow::datatypes::{
23    DataType::{LargeList, List, UInt64},
24    Field,
25};
26use datafusion_common::ScalarValue;
27use datafusion_expr::{
28    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_macros::user_doc;
31
32use std::any::Any;
33use std::sync::Arc;
34
35use arrow::array::{
36    Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array,
37    types::UInt64Type,
38};
39use datafusion_common::cast::{
40    as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
41};
42use datafusion_common::{Result, exec_err, utils::take_function_args};
43use itertools::Itertools;
44
45use crate::utils::{compare_element_to_list, make_scalar_function};
46
// Wires `ArrayPosition` up through `make_udf_expr_and_func!` (macro defined
// elsewhere in the crate); the per-argument notes mirror the sibling
// `array_positions` invocation further down in this file.
make_udf_expr_and_func!(
    ArrayPosition,
    array_position,
    array element index, // arg names
    "searches for an element in the array, returns first occurrence.", // doc
    array_position_udf // internal function name
);
54
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL.",
    syntax_example = "array_position(array, element)\narray_position(array, element, index)",
    sql_example = r#"```sql
> select array_position([1, 2, 2, 3, 1, 4], 2);
+----------------------------------------------+
| array_position(List([1,2,2,3,1,4]),Int64(2)) |
+----------------------------------------------+
| 2                                            |
+----------------------------------------------+
> select array_position([1, 2, 2, 3, 1, 4], 2, 3);
+----------------------------------------------------+
| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) |
+----------------------------------------------------+
| 3                                                  |
+----------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "element", description = "Element to search for in the array."),
    argument(
        name = "index",
        description = "Index at which to start searching (1-indexed)."
    )
)]
/// Scalar UDF behind `array_position` (and its aliases): reports the
/// 1-indexed position of the first matching element of a list, as `UInt64`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayPosition {
    /// Argument signature: an array, an element and an optional start index.
    signature: Signature,
    /// Alternative names this function is also reachable under.
    aliases: Vec<String>,
}
88
89impl Default for ArrayPosition {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94impl ArrayPosition {
95    pub fn new() -> Self {
96        Self {
97            signature: Signature::array_and_element_and_optional_index(
98                Volatility::Immutable,
99            ),
100            aliases: vec![
101                String::from("list_position"),
102                String::from("array_indexof"),
103                String::from("list_indexof"),
104            ],
105        }
106    }
107}
108
109impl ScalarUDFImpl for ArrayPosition {
110    fn as_any(&self) -> &dyn Any {
111        self
112    }
113    fn name(&self) -> &str {
114        "array_position"
115    }
116
117    fn signature(&self) -> &Signature {
118        &self.signature
119    }
120
121    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
122        Ok(UInt64)
123    }
124
125    fn invoke_with_args(
126        &self,
127        args: datafusion_expr::ScalarFunctionArgs,
128    ) -> Result<ColumnarValue> {
129        let [first_arg, second_arg, third_arg @ ..] = args.args.as_slice() else {
130            return exec_err!("array_position expects two or three arguments");
131        };
132
133        match second_arg {
134            ColumnarValue::Scalar(scalar_element) => {
135                // Nested element types (List, Struct) can't use the fast path
136                // (because Arrow's `non_distinct` does not support them).
137                if scalar_element.data_type().is_nested() {
138                    return make_scalar_function(array_position_inner)(&args.args);
139                }
140
141                // Determine batch length from whichever argument is columnar;
142                // if all inputs are scalar, batch length is 1.
143                let (num_rows, all_inputs_scalar) = match (first_arg, third_arg.first()) {
144                    (ColumnarValue::Array(a), _) => (a.len(), false),
145                    (_, Some(ColumnarValue::Array(a))) => (a.len(), false),
146                    _ => (1, true),
147                };
148
149                let element_arr = scalar_element.to_array_of_size(1)?;
150                let haystack = first_arg.to_array(num_rows)?;
151                let arr_from = resolve_start_from(third_arg.first(), num_rows)?;
152
153                let result = match haystack.data_type() {
154                    List(_) => {
155                        let list = as_generic_list_array::<i32>(&haystack)?;
156                        array_position_scalar::<i32>(list, &element_arr, &arr_from)
157                    }
158                    LargeList(_) => {
159                        let list = as_generic_list_array::<i64>(&haystack)?;
160                        array_position_scalar::<i64>(list, &element_arr, &arr_from)
161                    }
162                    t => exec_err!("array_position does not support type '{t}'."),
163                }?;
164
165                if all_inputs_scalar {
166                    Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
167                        &result, 0,
168                    )?))
169                } else {
170                    Ok(ColumnarValue::Array(result))
171                }
172            }
173            ColumnarValue::Array(_) => {
174                make_scalar_function(array_position_inner)(&args.args)
175            }
176        }
177    }
178
179    fn aliases(&self) -> &[String] {
180        &self.aliases
181    }
182
183    fn documentation(&self) -> Option<&Documentation> {
184        self.doc()
185    }
186}
187
188fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
189    if args.len() < 2 || args.len() > 3 {
190        return exec_err!("array_position expects two or three arguments");
191    }
192    match &args[0].data_type() {
193        List(_) => general_position_dispatch::<i32>(args),
194        LargeList(_) => general_position_dispatch::<i64>(args),
195        array_type => exec_err!("array_position does not support type '{array_type}'."),
196    }
197}
198
199/// Resolves the optional `start_from` argument into a `Vec<i64>` of
200/// 0-indexed starting positions.
201fn resolve_start_from(
202    third_arg: Option<&ColumnarValue>,
203    num_rows: usize,
204) -> Result<Vec<i64>> {
205    match third_arg {
206        None => Ok(vec![0i64; num_rows]),
207        Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => {
208            Ok(vec![v - 1; num_rows])
209        }
210        Some(ColumnarValue::Scalar(s)) => {
211            exec_err!("array_position expected Int64 for start_from, got {s}")
212        }
213        Some(ColumnarValue::Array(a)) => {
214            Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect())
215        }
216    }
217}
218
219/// Fast path for `array_position` when the element is a scalar.
220///
221/// Performs a single bulk `not_distinct` comparison of the scalar element
222/// against the entire flattened values buffer, then walks the result bitmap
223/// using offsets to find per-row first-match positions.
224fn array_position_scalar<O: OffsetSizeTrait>(
225    list_array: &GenericListArray<O>,
226    element_array: &ArrayRef,
227    arr_from: &[i64], // 0-indexed
228) -> Result<ArrayRef> {
229    crate::utils::check_datatypes(
230        "array_position",
231        &[list_array.values(), element_array],
232    )?;
233
234    if list_array.len() == 0 {
235        return Ok(Arc::new(UInt64Array::new_null(0)));
236    }
237
238    let element_datum = Scalar::new(Arc::clone(element_array));
239    let validity = list_array.nulls();
240
241    // Only compare the visible portion of the values buffer, which avoids
242    // wasted work for sliced ListArrays.
243    let offsets = list_array.offsets();
244    let first_offset = offsets[0].as_usize();
245    let last_offset = offsets[list_array.len()].as_usize();
246    let visible_values = list_array
247        .values()
248        .slice(first_offset, last_offset - first_offset);
249
250    // `not_distinct` treats NULL=NULL as true, matching the semantics of
251    // `array_position`
252    let eq_array = arrow_ord::cmp::not_distinct(&visible_values, &element_datum)?;
253    let eq_bits = eq_array.values();
254
255    let mut result: Vec<Option<u64>> = Vec::with_capacity(list_array.len());
256    let mut matches = eq_bits.set_indices().peekable();
257
258    // Match positions are relative to visible_values (0-based), so
259    // subtract first_offset from each offset when comparing.
260    for i in 0..list_array.len() {
261        let start = offsets[i].as_usize() - first_offset;
262        let end = offsets[i + 1].as_usize() - first_offset;
263
264        if validity.is_some_and(|v| v.is_null(i)) {
265            // Null row -> null output; advance past matches in range
266            while matches.peek().is_some_and(|&p| p < end) {
267                matches.next();
268            }
269            result.push(None);
270            continue;
271        }
272
273        let from = arr_from[i];
274        let row_len = end - start;
275        if !(from >= 0 && (from as usize) <= row_len) {
276            return exec_err!("start_from out of bounds: {}", from + 1);
277        }
278        let search_start = start + from as usize;
279
280        // Advance past matches before search_start
281        while matches.peek().is_some_and(|&p| p < search_start) {
282            matches.next();
283        }
284
285        // First match in [search_start, end)?
286        if matches.peek().is_some_and(|&p| p < end) {
287            let pos = *matches.peek().unwrap();
288            result.push(Some((pos - start + 1) as u64));
289            // Advance past remaining matches in this row
290            while matches.peek().is_some_and(|&p| p < end) {
291                matches.next();
292            }
293        } else {
294            result.push(None);
295        }
296    }
297
298    debug_assert_eq!(result.len(), list_array.len());
299    Ok(Arc::new(UInt64Array::from(result)))
300}
301
302fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
303    let list_array = as_generic_list_array::<O>(&args[0])?;
304    let element_array = &args[1];
305
306    crate::utils::check_datatypes(
307        "array_position",
308        &[list_array.values(), element_array],
309    )?;
310
311    let arr_from = if args.len() == 3 {
312        as_int64_array(&args[2])?
313            .values()
314            .iter()
315            .map(|&x| x - 1)
316            .collect::<Vec<_>>()
317    } else {
318        vec![0; list_array.len()]
319    };
320
321    for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
322        // If `arr` is `None`: we will get null if we got null in the array, so we don't need to check
323        if !arr.is_none_or(|arr| from >= 0 && (from as usize) <= arr.len()) {
324            return exec_err!("start_from out of bounds: {}", from + 1);
325        }
326    }
327
328    generic_position::<O>(list_array, element_array, &arr_from)
329}
330
331fn generic_position<OffsetSize: OffsetSizeTrait>(
332    list_array: &GenericListArray<OffsetSize>,
333    element_array: &ArrayRef,
334    arr_from: &[i64], // 0-indexed
335) -> Result<ArrayRef> {
336    let mut data = Vec::with_capacity(list_array.len());
337
338    for (row_index, (list_array_row, &from)) in
339        list_array.iter().zip(arr_from.iter()).enumerate()
340    {
341        let from = from as usize;
342
343        if let Some(list_array_row) = list_array_row {
344            let eq_array =
345                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
346
347            // Collect `true`s in 1-indexed positions
348            let index = eq_array
349                .iter()
350                .skip(from)
351                .position(|e| e == Some(true))
352                .map(|index| (from + index + 1) as u64);
353
354            data.push(index);
355        } else {
356            data.push(None);
357        }
358    }
359
360    Ok(Arc::new(UInt64Array::from(data)))
361}
362
// Wires `ArrayPositions` up through `make_udf_expr_and_func!` (macro defined
// elsewhere in the crate).
make_udf_expr_and_func!(
    ArrayPositions,
    array_positions,
    array element, // arg names
    "searches for an element in the array, returns all occurrences.", // doc
    array_positions_udf // internal function name
);
370
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Searches for an element in the array, returns all occurrences.",
    syntax_example = "array_positions(array, element)",
    sql_example = r#"```sql
> select array_positions([1, 2, 2, 3, 1, 4], 2);
+-----------------------------------------------+
| array_positions(List([1,2,2,3,1,4]),Int64(2)) |
+-----------------------------------------------+
| [2, 3]                                        |
+-----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to search for position in the array."
    )
)]
/// Scalar UDF behind `array_positions` (alias `list_positions`): reports the
/// 1-indexed positions of every matching element as a list of `UInt64`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayPositions {
    /// Argument signature: an array and an element.
    signature: Signature,
    /// Alternative names this function is also reachable under.
    aliases: Vec<String>,
}
397
398impl ArrayPositions {
399    pub fn new() -> Self {
400        Self {
401            signature: Signature::array_and_element(Volatility::Immutable),
402            aliases: vec![String::from("list_positions")],
403        }
404    }
405}
406
407impl ScalarUDFImpl for ArrayPositions {
408    fn as_any(&self) -> &dyn Any {
409        self
410    }
411    fn name(&self) -> &str {
412        "array_positions"
413    }
414
415    fn signature(&self) -> &Signature {
416        &self.signature
417    }
418
419    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
420        Ok(List(Arc::new(Field::new_list_field(UInt64, true))))
421    }
422
423    fn invoke_with_args(
424        &self,
425        args: datafusion_expr::ScalarFunctionArgs,
426    ) -> Result<ColumnarValue> {
427        make_scalar_function(array_positions_inner)(&args.args)
428    }
429
430    fn aliases(&self) -> &[String] {
431        &self.aliases
432    }
433
434    fn documentation(&self) -> Option<&Documentation> {
435        self.doc()
436    }
437}
438
439fn array_positions_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
440    let [array, element] = take_function_args("array_positions", args)?;
441
442    match &array.data_type() {
443        List(_) => {
444            let arr = as_list_array(&array)?;
445            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
446            general_positions::<i32>(arr, element)
447        }
448        LargeList(_) => {
449            let arr = as_large_list_array(&array)?;
450            crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
451            general_positions::<i64>(arr, element)
452        }
453        array_type => {
454            exec_err!("array_positions does not support type '{array_type}'.")
455        }
456    }
457}
458
459fn general_positions<OffsetSize: OffsetSizeTrait>(
460    list_array: &GenericListArray<OffsetSize>,
461    element_array: &ArrayRef,
462) -> Result<ArrayRef> {
463    let mut data = Vec::with_capacity(list_array.len());
464
465    for (row_index, list_array_row) in list_array.iter().enumerate() {
466        if let Some(list_array_row) = list_array_row {
467            let eq_array =
468                compare_element_to_list(&list_array_row, element_array, row_index, true)?;
469
470            // Collect `true`s in 1-indexed positions
471            let indexes = eq_array
472                .iter()
473                .positions(|e| e == Some(true))
474                .map(|index| Some(index as u64 + 1))
475                .collect::<Vec<_>>();
476
477            data.push(Some(indexes));
478        } else {
479            data.push(None);
480        }
481    }
482
483    Ok(Arc::new(
484        ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
485    ))
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::AsArray;
    use arrow::datatypes::Int32Type;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarFunctionArgs;

    /// The scalar-needle path must only look at the visible rows of a sliced
    /// list: values that live solely in sliced-away rows are never found.
    #[test]
    fn test_array_position_sliced_list() -> Result<()> {
        // [[10, 20], [30, 40], [50, 60], [70, 80]]  →  slice(1,2)  →  [[30, 40], [50, 60]]
        let full_list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(10), Some(20)]),
            Some(vec![Some(30), Some(40)]),
            Some(vec![Some(50), Some(60)]),
            Some(vec![Some(70), Some(80)]),
        ]);
        let sliced = full_list.slice(1, 2);

        let haystack_field =
            Arc::new(Field::new("haystack", sliced.data_type().clone(), true));
        let needle_field = Arc::new(Field::new("needle", DataType::Int32, true));
        let return_field = Arc::new(Field::new("return", UInt64, true));

        // Runs `array_position(sliced, needle)` and materialises two rows.
        let position_of = |needle: i32| -> Result<ArrayRef> {
            let call = ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Array(Arc::new(sliced.clone())),
                    ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))),
                ],
                arg_fields: vec![
                    Arc::clone(&haystack_field),
                    Arc::clone(&needle_field),
                ],
                number_rows: 2,
                return_field: Arc::clone(&return_field),
                config_options: Arc::new(ConfigOptions::default()),
            };
            ArrayPosition::new().invoke_with_args(call)?.into_array(2)
        };

        // 10 exists only in the sliced-away prefix row, 70 only in the
        // sliced-away suffix row; neither may be found.
        for needle in [10, 70] {
            let output = position_of(needle)?;
            let output = output.as_primitive::<UInt64Type>();
            assert!(output.is_null(0));
            assert!(output.is_null(1));
        }

        Ok(())
    }
}
543}