1use arrow::array::Scalar;
21use arrow::buffer::OffsetBuffer;
22use arrow::datatypes::DataType;
23use arrow::datatypes::{
24 DataType::{LargeList, List, UInt64},
25 Field,
26};
27use datafusion_common::ScalarValue;
28use datafusion_expr::{
29 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30 Volatility,
31};
32use datafusion_macros::user_doc;
33
34use std::sync::Arc;
35
36use arrow::array::{
37 Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array,
38 types::UInt64Type,
39};
40use datafusion_common::cast::{
41 as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
42};
43use datafusion_common::{Result, exec_err, utils::take_function_args};
44use itertools::Itertools;
45
46use crate::utils::{compare_element_to_list, make_scalar_function};
47
// Registers `ArrayPosition` with the UDF helper macro. Presumably this generates
// the `array_position` expression builder (taking `array element index`) and the
// `array_position_udf()` accessor — see the macro definition to confirm.
make_udf_expr_and_func!(
    ArrayPosition,
    array_position,
    array element index,
    "searches for an element in the array, returns first occurrence.",
    array_position_udf
);
55
/// Scalar UDF for `array_position`: finds the 1-based position of the first
/// occurrence of an element in a list, optionally starting the search at a
/// given 1-based index.
///
/// Also reachable under the aliases returned by `aliases()`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns the position of the first occurrence of the specified element in the array, or NULL if not found. Comparisons are done using `IS DISTINCT FROM` semantics, so NULL is considered to match NULL.",
    syntax_example = "array_position(array, element)\narray_position(array, element, index)",
    sql_example = r#"```sql
> select array_position([1, 2, 2, 3, 1, 4], 2);
+----------------------------------------------+
| array_position(List([1,2,2,3,1,4]),Int64(2)) |
+----------------------------------------------+
| 2 |
+----------------------------------------------+
> select array_position([1, 2, 2, 3, 1, 4], 2, 3);
+----------------------------------------------------+
| array_position(List([1,2,2,3,1,4]),Int64(2), Int64(3)) |
+----------------------------------------------------+
| 3 |
+----------------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "element", description = "Element to search for in the array."),
    argument(
        name = "index",
        description = "Index at which to start searching (1-indexed)."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayPosition {
    /// Accepted argument shapes: `(array, element)` or `(array, element, index)`.
    signature: Signature,
    /// Alternative SQL names under which this function is exposed.
    aliases: Vec<String>,
}
89
90impl Default for ArrayPosition {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95impl ArrayPosition {
96 pub fn new() -> Self {
97 Self {
98 signature: Signature::array_and_element_and_optional_index(
99 Volatility::Immutable,
100 ),
101 aliases: vec![
102 String::from("list_position"),
103 String::from("array_indexof"),
104 String::from("list_indexof"),
105 ],
106 }
107 }
108}
109
110impl ScalarUDFImpl for ArrayPosition {
111 fn name(&self) -> &str {
112 "array_position"
113 }
114
115 fn signature(&self) -> &Signature {
116 &self.signature
117 }
118
119 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
120 Ok(UInt64)
121 }
122
123 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
124 match try_array_position_scalar(&args.args)? {
125 Some(result) => Ok(result),
126 None => make_scalar_function(array_position_inner)(&args.args),
127 }
128 }
129
130 fn aliases(&self) -> &[String] {
131 &self.aliases
132 }
133
134 fn documentation(&self) -> Option<&Documentation> {
135 self.doc()
136 }
137}
138
139fn try_array_position_scalar(args: &[ColumnarValue]) -> Result<Option<ColumnarValue>> {
141 if args.len() < 2 || args.len() > 3 {
142 return exec_err!("array_position expects two or three arguments");
143 }
144
145 let scalar_needle = match &args[1] {
147 ColumnarValue::Scalar(s) => s,
148 ColumnarValue::Array(_) => return Ok(None),
149 };
150
151 if scalar_needle.data_type().is_nested() {
154 return Ok(None);
155 }
156
157 let (num_rows, all_inputs_scalar) = match (&args[0], args.get(2)) {
160 (ColumnarValue::Array(a), _) => (a.len(), false),
161 (_, Some(ColumnarValue::Array(a))) => (a.len(), false),
162 _ => (1, true),
163 };
164
165 let needle = scalar_needle.to_array_of_size(1)?;
166 let haystack = args[0].to_array(num_rows)?;
167 let arr_from = resolve_start_from(args.get(2), num_rows)?;
168
169 let result = match haystack.data_type() {
170 List(_) => {
171 let list = as_list_array(&haystack)?;
172 array_position_scalar::<i32>(list, &needle, &arr_from)
173 }
174 LargeList(_) => {
175 let list = as_large_list_array(&haystack)?;
176 array_position_scalar::<i64>(list, &needle, &arr_from)
177 }
178 t => exec_err!("array_position does not support type '{t}'"),
179 }?;
180
181 if all_inputs_scalar {
182 Ok(Some(ColumnarValue::Scalar(ScalarValue::try_from_array(
183 &result, 0,
184 )?)))
185 } else {
186 Ok(Some(ColumnarValue::Array(result)))
187 }
188}
189
190fn array_position_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
191 if args.len() < 2 || args.len() > 3 {
192 return exec_err!("array_position expects two or three arguments");
193 }
194 match &args[0].data_type() {
195 List(_) => general_position_dispatch::<i32>(args),
196 LargeList(_) => general_position_dispatch::<i64>(args),
197 dt => exec_err!("array_position does not support type '{dt}'"),
198 }
199}
200
201fn resolve_start_from(
204 third_arg: Option<&ColumnarValue>,
205 num_rows: usize,
206) -> Result<Vec<i64>> {
207 match third_arg {
208 None => Ok(vec![0i64; num_rows]),
209 Some(ColumnarValue::Scalar(ScalarValue::Int64(Some(v)))) => {
210 Ok(vec![v - 1; num_rows])
211 }
212 Some(ColumnarValue::Scalar(s)) => {
213 exec_err!("array_position expected Int64 for start_from, got {s}")
214 }
215 Some(ColumnarValue::Array(a)) => {
216 Ok(as_int64_array(a)?.values().iter().map(|&x| x - 1).collect())
217 }
218 }
219}
220
221fn array_position_scalar<O: OffsetSizeTrait>(
227 haystack: &GenericListArray<O>,
228 needle: &ArrayRef,
229 arr_from: &[i64], ) -> Result<ArrayRef> {
231 crate::utils::check_datatypes("array_position", &[haystack.values(), needle])?;
232
233 if haystack.len() == 0 {
234 return Ok(Arc::new(UInt64Array::new_null(0)));
235 }
236
237 let needle_datum = Scalar::new(Arc::clone(needle));
238 let validity = haystack.nulls();
239
240 let offsets = haystack.offsets();
244 let first_offset = offsets[0].as_usize();
245 let last_offset = offsets[haystack.len()].as_usize();
246 let visible_values = haystack
247 .values()
248 .slice(first_offset, last_offset - first_offset);
249
250 let eq_array = arrow_ord::cmp::not_distinct(&visible_values, &needle_datum)?;
253 let eq_bits = eq_array.values();
254
255 let mut result: Vec<Option<u64>> = Vec::with_capacity(haystack.len());
256 let mut matches = eq_bits.set_indices().peekable();
257
258 for i in 0..haystack.len() {
261 let start = offsets[i].as_usize() - first_offset;
262 let end = offsets[i + 1].as_usize() - first_offset;
263
264 if validity.is_some_and(|v| v.is_null(i)) {
265 while matches.peek().is_some_and(|&p| p < end) {
267 matches.next();
268 }
269 result.push(None);
270 continue;
271 }
272
273 let from = arr_from[i];
274 let row_len = end - start;
275 if !(from >= 0 && (from as usize) <= row_len) {
276 return exec_err!("start_from out of bounds: {}", from + 1);
277 }
278 let search_start = start + from as usize;
279
280 while matches.peek().is_some_and(|&p| p < search_start) {
282 matches.next();
283 }
284
285 if matches.peek().is_some_and(|&p| p < end) {
287 let pos = *matches.peek().unwrap();
288 result.push(Some((pos - start + 1) as u64));
289 while matches.peek().is_some_and(|&p| p < end) {
291 matches.next();
292 }
293 } else {
294 result.push(None);
295 }
296 }
297
298 debug_assert_eq!(result.len(), haystack.len());
299 Ok(Arc::new(UInt64Array::from(result)))
300}
301
302fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
303 let haystack = as_generic_list_array::<O>(&args[0])?;
304 let needle = &args[1];
305
306 crate::utils::check_datatypes("array_position", &[haystack.values(), needle])?;
307
308 let arr_from = if args.len() == 3 {
309 as_int64_array(&args[2])?
310 .values()
311 .iter()
312 .map(|&x| x - 1)
313 .collect::<Vec<_>>()
314 } else {
315 vec![0; haystack.len()]
316 };
317
318 for (row, &from) in haystack.iter().zip(arr_from.iter()) {
319 if !row.is_none_or(|row| from >= 0 && (from as usize) <= row.len()) {
320 return exec_err!("start_from out of bounds: {}", from + 1);
321 }
322 }
323
324 generic_position::<O>(haystack, needle, &arr_from)
325}
326
327fn generic_position<O: OffsetSizeTrait>(
328 haystack: &GenericListArray<O>,
329 needle: &ArrayRef,
330 arr_from: &[i64], ) -> Result<ArrayRef> {
332 let mut data = Vec::with_capacity(haystack.len());
333
334 for (row_index, (row, &from)) in haystack.iter().zip(arr_from.iter()).enumerate() {
335 let from = from as usize;
336
337 if let Some(row) = row {
338 let eq_array = compare_element_to_list(&row, needle, row_index, true)?;
339
340 let index = eq_array
342 .iter()
343 .skip(from)
344 .position(|e| e == Some(true))
345 .map(|index| (from + index + 1) as u64);
346
347 data.push(index);
348 } else {
349 data.push(None);
350 }
351 }
352
353 Ok(Arc::new(UInt64Array::from(data)))
354}
355
// Registers `ArrayPositions` with the UDF helper macro. Presumably this generates
// the `array_positions` expression builder (taking `array element`) and the
// `array_positions_udf()` accessor — see the macro definition to confirm.
make_udf_expr_and_func!(
    ArrayPositions,
    array_positions,
    array element,
    "searches for an element in the array, returns all occurrences.",
    array_positions_udf
);
363
/// Scalar UDF for `array_positions`: returns the 1-based positions of every
/// occurrence of an element in a list.
///
/// Also reachable under the aliases returned by `aliases()`.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Searches for an element in the array, returns all occurrences.",
    syntax_example = "array_positions(array, element)",
    sql_example = r#"```sql
> select array_positions([1, 2, 2, 3, 1, 4], 2);
+-----------------------------------------------+
| array_positions(List([1,2,2,3,1,4]),Int64(2)) |
+-----------------------------------------------+
| [2, 3] |
+-----------------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "element", description = "Element to search for in the array.")
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayPositions {
    /// Accepted argument shape: `(array, element)`.
    signature: Signature,
    /// Alternative SQL names under which this function is exposed.
    aliases: Vec<String>,
}
387
388impl Default for ArrayPositions {
389 fn default() -> Self {
390 Self::new()
391 }
392}
393
394impl ArrayPositions {
395 pub fn new() -> Self {
396 Self {
397 signature: Signature::array_and_element(Volatility::Immutable),
398 aliases: vec![String::from("list_positions")],
399 }
400 }
401}
402
403impl ScalarUDFImpl for ArrayPositions {
404 fn name(&self) -> &str {
405 "array_positions"
406 }
407
408 fn signature(&self) -> &Signature {
409 &self.signature
410 }
411
412 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
413 Ok(List(Arc::new(Field::new_list_field(UInt64, true))))
414 }
415
416 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
417 match try_array_positions_scalar(&args.args)? {
418 Some(result) => Ok(result),
419 None => make_scalar_function(array_positions_inner)(&args.args),
420 }
421 }
422
423 fn aliases(&self) -> &[String] {
424 &self.aliases
425 }
426
427 fn documentation(&self) -> Option<&Documentation> {
428 self.doc()
429 }
430}
431
432fn try_array_positions_scalar(args: &[ColumnarValue]) -> Result<Option<ColumnarValue>> {
434 let [haystack_arg, needle_arg] = take_function_args("array_positions", args)?;
435
436 let scalar_needle = match needle_arg {
437 ColumnarValue::Scalar(s) => s,
438 ColumnarValue::Array(_) => return Ok(None),
439 };
440
441 if scalar_needle.data_type().is_nested() {
444 return Ok(None);
445 }
446
447 let (num_rows, all_inputs_scalar) = match haystack_arg {
448 ColumnarValue::Array(a) => (a.len(), false),
449 ColumnarValue::Scalar(_) => (1, true),
450 };
451
452 let needle = scalar_needle.to_array_of_size(1)?;
453 let haystack = haystack_arg.to_array(num_rows)?;
454
455 let result = match haystack.data_type() {
456 List(_) => {
457 let list = as_list_array(&haystack)?;
458 array_positions_scalar::<i32>(list, &needle)
459 }
460 LargeList(_) => {
461 let list = as_large_list_array(&haystack)?;
462 array_positions_scalar::<i64>(list, &needle)
463 }
464 t => exec_err!("array_positions does not support type '{t}'"),
465 }?;
466
467 if all_inputs_scalar {
468 Ok(Some(ColumnarValue::Scalar(ScalarValue::try_from_array(
469 &result, 0,
470 )?)))
471 } else {
472 Ok(Some(ColumnarValue::Array(result)))
473 }
474}
475
476fn array_positions_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
477 let [haystack, needle] = take_function_args("array_positions", args)?;
478
479 match &haystack.data_type() {
480 List(_) => general_positions::<i32>(as_list_array(&haystack)?, needle),
481 LargeList(_) => general_positions::<i64>(as_large_list_array(&haystack)?, needle),
482 dt => exec_err!("array_positions does not support type '{dt}'"),
483 }
484}
485
486fn general_positions<O: OffsetSizeTrait>(
487 haystack: &GenericListArray<O>,
488 needle: &ArrayRef,
489) -> Result<ArrayRef> {
490 crate::utils::check_datatypes("array_positions", &[haystack.values(), needle])?;
491 let mut data = Vec::with_capacity(haystack.len());
492
493 for (row_index, row) in haystack.iter().enumerate() {
494 if let Some(row) = row {
495 let eq_array = compare_element_to_list(&row, needle, row_index, true)?;
496
497 let indexes = eq_array
499 .iter()
500 .positions(|e| e == Some(true))
501 .map(|index| Some(index as u64 + 1))
502 .collect::<Vec<_>>();
503
504 data.push(Some(indexes));
505 } else {
506 data.push(None);
507 }
508 }
509
510 Ok(Arc::new(
511 ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
512 ))
513}
514
515fn array_positions_scalar<O: OffsetSizeTrait>(
521 haystack: &GenericListArray<O>,
522 needle: &ArrayRef,
523) -> Result<ArrayRef> {
524 crate::utils::check_datatypes("array_positions", &[haystack.values(), needle])?;
525
526 let num_rows = haystack.len();
527 if num_rows == 0 {
528 return Ok(Arc::new(ListArray::try_new(
529 Arc::new(Field::new_list_field(UInt64, true)),
530 OffsetBuffer::new_zeroed(1),
531 Arc::new(UInt64Array::from(Vec::<u64>::new())),
532 None,
533 )?));
534 }
535
536 let needle_datum = Scalar::new(Arc::clone(needle));
537 let validity = haystack.nulls();
538
539 let offsets = haystack.offsets();
543 let first_offset = offsets[0].as_usize();
544 let last_offset = offsets[num_rows].as_usize();
545 let visible_values = haystack
546 .values()
547 .slice(first_offset, last_offset - first_offset);
548
549 let eq_array = arrow_ord::cmp::not_distinct(&visible_values, &needle_datum)?;
552 let eq_bits = eq_array.values();
553
554 let num_matches = eq_bits.count_set_bits();
555 let mut positions: Vec<u64> = Vec::with_capacity(num_matches);
556 let mut result_offsets: Vec<i32> = Vec::with_capacity(num_rows + 1);
557 result_offsets.push(0);
558 let mut matches = eq_bits.set_indices().peekable();
559
560 for i in 0..num_rows {
563 let start = offsets[i].as_usize() - first_offset;
564 let end = offsets[i + 1].as_usize() - first_offset;
565
566 if validity.is_some_and(|v| v.is_null(i)) {
567 while matches.peek().is_some_and(|&p| p < end) {
569 matches.next();
570 }
571 result_offsets.push(positions.len() as i32);
572 continue;
573 }
574
575 while let Some(pos) = matches.next_if(|&p| p < end) {
577 positions.push((pos - start + 1) as u64);
578 }
579 result_offsets.push(positions.len() as i32);
580 }
581
582 debug_assert_eq!(result_offsets.len(), num_rows + 1);
583 Ok(Arc::new(ListArray::try_new(
584 Arc::new(Field::new_list_field(UInt64, true)),
585 OffsetBuffer::new(result_offsets.into()),
586 Arc::new(UInt64Array::from(positions)),
587 validity.cloned(),
588 )?))
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::AsArray;
    use arrow::datatypes::Int32Type;
    use datafusion_common::config::ConfigOptions;

    /// Builds a four-row `Int32` list array and returns its middle two rows, so
    /// the haystack's first offset is non-zero and its child values extend past
    /// the visible rows on both sides.
    fn middle_two_rows(rows: Vec<Option<Vec<Option<i32>>>>) -> ListArray {
        ListArray::from_iter_primitive::<Int32Type, _, _>(rows).slice(1, 2)
    }

    /// Return type produced by `array_positions`.
    fn positions_return_type() -> DataType {
        List(Arc::new(Field::new_list_field(UInt64, true)))
    }

    /// Invokes `udf` over two rows with `haystack` as a column and `needle` as
    /// an `Int32` scalar, and materializes the result as an array.
    fn invoke_two_rows(
        udf: &impl ScalarUDFImpl,
        haystack: &ListArray,
        needle: i32,
        return_type: DataType,
    ) -> Result<ArrayRef> {
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(haystack.clone())),
                ColumnarValue::Scalar(ScalarValue::Int32(Some(needle))),
            ],
            arg_fields: vec![
                Arc::new(Field::new("haystack", haystack.data_type().clone(), true)),
                Arc::new(Field::new("needle", DataType::Int32, true)),
            ],
            number_rows: 2,
            return_field: Arc::new(Field::new("return", return_type, true)),
            config_options: Arc::new(ConfigOptions::default()),
        };
        udf.invoke_with_args(args)?.into_array(2)
    }

    /// The `UInt64` values stored in list row `row` of `output`.
    fn row_positions(output: &ArrayRef, row: usize) -> Vec<u64> {
        output
            .as_list::<i32>()
            .value(row)
            .as_primitive::<UInt64Type>()
            .values()
            .to_vec()
    }

    #[test]
    fn test_array_position_sliced_list() -> Result<()> {
        let haystack = middle_two_rows(vec![
            Some(vec![Some(10), Some(20)]),
            Some(vec![Some(30), Some(40)]),
            Some(vec![Some(50), Some(60)]),
            Some(vec![Some(70), Some(80)]),
        ]);

        // 10 and 70 exist only in the rows sliced away, so nothing matches.
        for needle in [10, 70] {
            let output = invoke_two_rows(&ArrayPosition::new(), &haystack, needle, UInt64)?;
            let output = output.as_primitive::<UInt64Type>();
            assert!(output.is_null(0), "needle {needle}, row 0");
            assert!(output.is_null(1), "needle {needle}, row 1");
        }

        Ok(())
    }

    #[test]
    fn test_array_positions_sliced_list() -> Result<()> {
        let haystack = middle_two_rows(vec![
            Some(vec![Some(10), Some(20), Some(30)]),
            Some(vec![Some(30), Some(40), Some(30)]),
            Some(vec![Some(50), Some(60), Some(30)]),
            Some(vec![Some(70), Some(80), Some(30)]),
        ]);
        let udf = ArrayPositions::new();

        let output = invoke_two_rows(&udf, &haystack, 30, positions_return_type())?;
        assert_eq!(row_positions(&output, 0), vec![1, 3]);
        assert_eq!(row_positions(&output, 1), vec![3]);

        // Needles present only outside the slice yield empty lists.
        for needle in [10, 70] {
            let output = invoke_two_rows(&udf, &haystack, needle, positions_return_type())?;
            let output = output.as_list::<i32>();
            assert!(output.value(0).is_empty(), "needle {needle}, row 0");
            assert!(output.value(1).is_empty(), "needle {needle}, row 1");
        }

        Ok(())
    }

    #[test]
    fn test_array_positions_sliced_list_with_nulls() -> Result<()> {
        let haystack = middle_two_rows(vec![
            Some(vec![Some(1), Some(2)]),
            None,
            Some(vec![Some(3), Some(1)]),
            Some(vec![Some(4), Some(5)]),
        ]);

        let output =
            invoke_two_rows(&ArrayPositions::new(), &haystack, 1, positions_return_type())?;
        let lists = output.as_list::<i32>();
        // The NULL row stays NULL even though the sliced-away row before it
        // contains the needle.
        assert!(lists.is_null(0));
        assert!(!lists.is_null(1));
        assert_eq!(row_positions(&output, 1), vec![2]);

        Ok(())
    }
}
753}