1use arrow::array::Scalar;
21use arrow::datatypes::DataType;
22use arrow::datatypes::{
23 DataType::{LargeList, List, UInt64},
24 Field,
25};
26use datafusion_common::ScalarValue;
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
29};
30use datafusion_macros::user_doc;
31
32use std::any::Any;
33use std::sync::Arc;
34
35use arrow::array::{
36 Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array,
37 types::UInt64Type,
38};
39use datafusion_common::cast::{
40 as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
41};
42use datafusion_common::{Result, exec_err, utils::take_function_args};
43use itertools::Itertools;
44
45use crate::utils::{compare_element_to_list, make_scalar_function};
46
// Wires `ArrayPosition` into the crate's UDF registration macro.
// Argument order (per this macro's convention — confirm against its definition):
// the UDF type, the expression-builder function name, that builder's argument
// names (`array element index`), a one-line description, and the name of the
// function that returns the shared `ScalarUDF` instance.
make_udf_expr_and_func!(
    ArrayPosition,
    array_position,
    array element index,
    "searches for an element in the array, returns first occurrence.",
    array_position_udf
);
54
// NOTE(review): in the second `sql_example` table below, the header row
// (`...Int64(2), Int64(3))`) is wider than its `+---+` border lines and contains
// an extra space before `Int64(3)`; consider regenerating that example.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL.",
    syntax_example = "array_position(array, element)\narray_position(array, element, index)",
    sql_example = r#"```sql
> select array_position([1, 2, 2, 3, 1, 4], 2);
+----------------------------------------------+
| array_position(List([1,2,2,3,1,4]),Int64(2)) |
+----------------------------------------------+
| 2 |
+----------------------------------------------+
> select array_position([1, 2, 2, 3, 1, 4], 2, 3);
+----------------------------------------------------+
| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) |
+----------------------------------------------------+
| 3 |
+----------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "element", description = "Element to search for in the array."),
    argument(
        name = "index",
        description = "Index at which to start searching (1-indexed)."
    )
)]
/// Scalar UDF implementing `array_position(array, element[, index])`.
///
/// Produces the 1-based position (as `UInt64`) of the first element of `array`
/// matching `element`, optionally starting the search at the 1-based `index`.
/// Also registered under the `list_position`, `array_indexof` and
/// `list_indexof` aliases.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayPosition {
    /// Accepts `(array, element)` or `(array, element, index)`.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
88
89impl Default for ArrayPosition {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94impl ArrayPosition {
95 pub fn new() -> Self {
96 Self {
97 signature: Signature::array_and_element_and_optional_index(
98 Volatility::Immutable,
99 ),
100 aliases: vec![
101 String::from("list_position"),
102 String::from("array_indexof"),
103 String::from("list_indexof"),
104 ],
105 }
106 }
107}
108
109impl ScalarUDFImpl for ArrayPosition {
110 fn as_any(&self) -> &dyn Any {
111 self
112 }
113 fn name(&self) -> &str {
114 "array_position"
115 }
116
117 fn signature(&self) -> &Signature {
118 &self.signature
119 }
120
121 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
122 Ok(UInt64)
123 }
124
125 fn invoke_with_args(
126 &self,
127 args: datafusion_expr::ScalarFunctionArgs,
128 ) -> Result<ColumnarValue> {
129 let [first_arg, second_arg, third_arg @ ..] = args.args.as_slice() else {
130 return exec_err!("array_position expects two or three arguments");
131 };
132
133 match second_arg {
134 ColumnarValue::Scalar(scalar_element) => {
135 if scalar_element.data_type().is_nested() {
138 return make_scalar_function(array_position_inner)(&args.args);
139 }
140
141 let (num_rows, all_inputs_scalar) = match (first_arg, third_arg.first()) {
144 (ColumnarValue::Array(a), _) => (a.len(), false),
145 (_, Some(ColumnarValue::Array(a))) => (a.len(), false),
146 _ => (1, true),
147 };
148
149 let element_arr = scalar_element.to_array_of_size(1)?;
150 let haystack = first_arg.to_array(num_rows)?;
151 let arr_from = resolve_start_from(third_arg.first(), num_rows)?;
152
153 let result = match haystack.data_type() {
154 List(_) => {
155 let list = as_generic_list_array::<i32>(&haystack)?;
156 array_position_scalar::<i32>(list, &element_arr, &arr_from)
157 }
158 LargeList(_) => {
159 let list = as_generic_list_array::<i64>(&haystack)?;
160 array_position_scalar::<i64>(list, &element_arr, &arr_from)
161 }
162 t => exec_err!("array_position does not support type '{t}'."),
163 }?;
164
165 if all_inputs_scalar {
166 Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
167 &result, 0,
168 )?))
169 } else {
170 Ok(ColumnarValue::Array(result))
171 }
172 }
173 ColumnarValue::Array(_) => {
174 make_scalar_function(array_position_inner)(&args.args)
175 }
176 }
177 }
178
179 fn aliases(&self) -> &[String] {
180 &self.aliases
181 }
182
183 fn documentation(&self) -> Option<&Documentation> {
184 self.doc()
185 }
186}
187
188fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
189 if args.len() < 2 || args.len() > 3 {
190 return exec_err!("array_position expects two or three arguments");
191 }
192 match &args[0].data_type() {
193 List(_) => general_position_dispatch::<i32>(args),
194 LargeList(_) => general_position_dispatch::<i64>(args),
195 array_type => exec_err!("array_position does not support type '{array_type}'."),
196 }
197}
198
199fn resolve_start_from(
202 third_arg: Option<&ColumnarValue>,
203 num_rows: usize,
204) -> Result<Vec<i64>> {
205 match third_arg {
206 None => Ok(vec![0i64; num_rows]),
207 Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => {
208 Ok(vec![v - 1; num_rows])
209 }
210 Some(ColumnarValue::Scalar(s)) => {
211 exec_err!("array_position expected Int64 for start_from, got {s}")
212 }
213 Some(ColumnarValue::Array(a)) => {
214 Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect())
215 }
216 }
217}
218
219fn array_position_scalar<O: OffsetSizeTrait>(
225 list_array: &GenericListArray<O>,
226 element_array: &ArrayRef,
227 arr_from: &[i64], ) -> Result<ArrayRef> {
229 crate::utils::check_datatypes(
230 "array_position",
231 &[list_array.values(), element_array],
232 )?;
233
234 if list_array.len() == 0 {
235 return Ok(Arc::new(UInt64Array::new_null(0)));
236 }
237
238 let element_datum = Scalar::new(Arc::clone(element_array));
239 let validity = list_array.nulls();
240
241 let offsets = list_array.offsets();
244 let first_offset = offsets[0].as_usize();
245 let last_offset = offsets[list_array.len()].as_usize();
246 let visible_values = list_array
247 .values()
248 .slice(first_offset, last_offset - first_offset);
249
250 let eq_array = arrow_ord::cmp::not_distinct(&visible_values, &element_datum)?;
253 let eq_bits = eq_array.values();
254
255 let mut result: Vec<Option<u64>> = Vec::with_capacity(list_array.len());
256 let mut matches = eq_bits.set_indices().peekable();
257
258 for i in 0..list_array.len() {
261 let start = offsets[i].as_usize() - first_offset;
262 let end = offsets[i + 1].as_usize() - first_offset;
263
264 if validity.is_some_and(|v| v.is_null(i)) {
265 while matches.peek().is_some_and(|&p| p < end) {
267 matches.next();
268 }
269 result.push(None);
270 continue;
271 }
272
273 let from = arr_from[i];
274 let row_len = end - start;
275 if !(from >= 0 && (from as usize) <= row_len) {
276 return exec_err!("start_from out of bounds: {}", from + 1);
277 }
278 let search_start = start + from as usize;
279
280 while matches.peek().is_some_and(|&p| p < search_start) {
282 matches.next();
283 }
284
285 if matches.peek().is_some_and(|&p| p < end) {
287 let pos = *matches.peek().unwrap();
288 result.push(Some((pos - start + 1) as u64));
289 while matches.peek().is_some_and(|&p| p < end) {
291 matches.next();
292 }
293 } else {
294 result.push(None);
295 }
296 }
297
298 debug_assert_eq!(result.len(), list_array.len());
299 Ok(Arc::new(UInt64Array::from(result)))
300}
301
302fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
303 let list_array = as_generic_list_array::<O>(&args[0])?;
304 let element_array = &args[1];
305
306 crate::utils::check_datatypes(
307 "array_position",
308 &[list_array.values(), element_array],
309 )?;
310
311 let arr_from = if args.len() == 3 {
312 as_int64_array(&args[2])?
313 .values()
314 .iter()
315 .map(|&x| x - 1)
316 .collect::<Vec<_>>()
317 } else {
318 vec![0; list_array.len()]
319 };
320
321 for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
322 if !arr.is_none_or(|arr| from >= 0 && (from as usize) <= arr.len()) {
324 return exec_err!("start_from out of bounds: {}", from + 1);
325 }
326 }
327
328 generic_position::<O>(list_array, element_array, &arr_from)
329}
330
331fn generic_position<OffsetSize: OffsetSizeTrait>(
332 list_array: &GenericListArray<OffsetSize>,
333 element_array: &ArrayRef,
334 arr_from: &[i64], ) -> Result<ArrayRef> {
336 let mut data = Vec::with_capacity(list_array.len());
337
338 for (row_index, (list_array_row, &from)) in
339 list_array.iter().zip(arr_from.iter()).enumerate()
340 {
341 let from = from as usize;
342
343 if let Some(list_array_row) = list_array_row {
344 let eq_array =
345 compare_element_to_list(&list_array_row, element_array, row_index, true)?;
346
347 let index = eq_array
349 .iter()
350 .skip(from)
351 .position(|e| e == Some(true))
352 .map(|index| (from + index + 1) as u64);
353
354 data.push(index);
355 } else {
356 data.push(None);
357 }
358 }
359
360 Ok(Arc::new(UInt64Array::from(data)))
361}
362
// Wires `ArrayPositions` into the crate's UDF registration macro.
// Argument order (per this macro's convention — confirm against its definition):
// the UDF type, the expression-builder function name, that builder's argument
// names (`array element`), a one-line description, and the name of the
// function that returns the shared `ScalarUDF` instance.
make_udf_expr_and_func!(
    ArrayPositions,
    array_positions,
    array element,
    "searches for an element in the array, returns all occurrences.",
    array_positions_udf
);
370
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Searches for an element in the array, returns all occurrences.",
    syntax_example = "array_positions(array, element)",
    sql_example = r#"```sql
> select array_positions([1, 2, 2, 3, 1, 4], 2);
+-----------------------------------------------+
| array_positions(List([1,2,2,3,1,4]),Int64(2)) |
+-----------------------------------------------+
| [2, 3] |
+-----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "element",
        description = "Element to search for position in the array."
    )
)]
/// Scalar UDF implementing `array_positions(array, element)`.
///
/// Produces a `List<UInt64>` with the 1-based positions of every element of
/// `array` that matches `element`. Also registered under `list_positions`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub(super) struct ArrayPositions {
    /// Accepts exactly `(array, element)`.
    signature: Signature,
    /// Alternative SQL names for this function.
    aliases: Vec<String>,
}
397
398impl ArrayPositions {
399 pub fn new() -> Self {
400 Self {
401 signature: Signature::array_and_element(Volatility::Immutable),
402 aliases: vec![String::from("list_positions")],
403 }
404 }
405}
406
407impl ScalarUDFImpl for ArrayPositions {
408 fn as_any(&self) -> &dyn Any {
409 self
410 }
411 fn name(&self) -> &str {
412 "array_positions"
413 }
414
415 fn signature(&self) -> &Signature {
416 &self.signature
417 }
418
419 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
420 Ok(List(Arc::new(Field::new_list_field(UInt64, true))))
421 }
422
423 fn invoke_with_args(
424 &self,
425 args: datafusion_expr::ScalarFunctionArgs,
426 ) -> Result<ColumnarValue> {
427 make_scalar_function(array_positions_inner)(&args.args)
428 }
429
430 fn aliases(&self) -> &[String] {
431 &self.aliases
432 }
433
434 fn documentation(&self) -> Option<&Documentation> {
435 self.doc()
436 }
437}
438
439fn array_positions_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
440 let [array, element] = take_function_args("array_positions", args)?;
441
442 match &array.data_type() {
443 List(_) => {
444 let arr = as_list_array(&array)?;
445 crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
446 general_positions::<i32>(arr, element)
447 }
448 LargeList(_) => {
449 let arr = as_large_list_array(&array)?;
450 crate::utils::check_datatypes("array_positions", &[arr.values(), element])?;
451 general_positions::<i64>(arr, element)
452 }
453 array_type => {
454 exec_err!("array_positions does not support type '{array_type}'.")
455 }
456 }
457}
458
459fn general_positions<OffsetSize: OffsetSizeTrait>(
460 list_array: &GenericListArray<OffsetSize>,
461 element_array: &ArrayRef,
462) -> Result<ArrayRef> {
463 let mut data = Vec::with_capacity(list_array.len());
464
465 for (row_index, list_array_row) in list_array.iter().enumerate() {
466 if let Some(list_array_row) = list_array_row {
467 let eq_array =
468 compare_element_to_list(&list_array_row, element_array, row_index, true)?;
469
470 let indexes = eq_array
472 .iter()
473 .positions(|e| e == Some(true))
474 .map(|index| Some(index as u64 + 1))
475 .collect::<Vec<_>>();
476
477 data.push(Some(indexes));
478 } else {
479 data.push(None);
480 }
481 }
482
483 Ok(Arc::new(
484 ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
485 ))
486}
487
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::AsArray;
    use arrow::datatypes::Int32Type;
    use datafusion_common::config::ConfigOptions;
    use datafusion_expr::ScalarFunctionArgs;

    /// Invokes `array_position` on `haystack` with a scalar `Int32` needle and
    /// an optional scalar `Int64` start index, and returns the result as an array.
    fn invoke_position(
        haystack: &ListArray,
        needle: i32,
        start: Option<i64>,
    ) -> Result<ArrayRef> {
        let rows = haystack.len();
        let mut args = vec![
            ColumnarValue::Array(Arc::new(haystack.clone())),
            ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))),
        ];
        let mut arg_fields = vec![
            Arc::new(Field::new("haystack", haystack.data_type().clone(), true)),
            Arc::new(Field::new("needle", DataType::Int32, true)),
        ];
        if let Some(start) = start {
            args.push(ColumnarValue::Scalar(ScalarValue::Int64(Some(start))));
            arg_fields.push(Arc::new(Field::new("index", DataType::Int64, true)));
        }

        ArrayPosition::new()
            .invoke_with_args(ScalarFunctionArgs {
                args,
                arg_fields,
                number_rows: rows,
                return_field: Arc::new(Field::new("return", UInt64, true)),
                config_options: Arc::new(ConfigOptions::default()),
            })?
            .into_array(rows)
    }

    #[test]
    fn test_array_position_sliced_list() -> Result<()> {
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(10), Some(20)]),
            Some(vec![Some(30), Some(40)]),
            Some(vec![Some(50), Some(60)]),
            Some(vec![Some(70), Some(80)]),
        ]);
        // Keep only rows [30, 40] and [50, 60]. Values outside the slice must
        // never be reported as matches.
        let sliced = list.slice(1, 2);

        for hidden in [10, 70] {
            let output = invoke_position(&sliced, hidden, None)?;
            let output = output.as_primitive::<UInt64Type>();
            assert!(output.is_null(0));
            assert!(output.is_null(1));
        }

        // Matches inside the slice are reported relative to their own row.
        // Without these assertions, an implementation that always returned
        // NULL would pass this test.
        let output = invoke_position(&sliced, 40, None)?;
        let output = output.as_primitive::<UInt64Type>();
        assert_eq!(output.value(0), 2);
        assert!(output.is_null(1));

        let output = invoke_position(&sliced, 50, None)?;
        let output = output.as_primitive::<UInt64Type>();
        assert!(output.is_null(0));
        assert_eq!(output.value(1), 1);

        Ok(())
    }

    #[test]
    fn test_array_position_start_index_with_null_row() -> Result<()> {
        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
            Some(vec![Some(1), Some(2), Some(2)]),
            None,
            Some(vec![Some(2), Some(3)]),
        ]);

        // Search from the (1-based) index 3:
        // - row 0 skips its first `2` and finds the second one at position 3;
        // - the NULL row stays NULL;
        // - row 2 has length 2, so index 3 (== len + 1) is allowed and scans nothing.
        let output = invoke_position(&list, 2, Some(3))?;
        let output = output.as_primitive::<UInt64Type>();
        assert_eq!(output.value(0), 3);
        assert!(output.is_null(1));
        assert!(output.is_null(2));

        Ok(())
    }
}