Skip to main content

arrow_select/
take.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines take kernel for [Array]
19
20use std::fmt::Display;
21use std::mem::ManuallyDrop;
22use std::sync::Arc;
23
24use arrow_array::builder::{BufferBuilder, UInt32Builder};
25use arrow_array::cast::AsArray;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::{
29    ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
30    bit_util,
31};
32use arrow_data::{ArrayDataBuilder, transform::MutableArrayData};
33use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
34
35use num_traits::Zero;
36
37/// Take elements by index from [Array], creating a new [Array] from those indexes.
38///
39/// ```text
40/// ┌─────────────────┐      ┌─────────┐                              ┌─────────────────┐
41/// │        A        │      │    0    │                              │        A        │
42/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
43/// │        D        │      │    2    │                              │        B        │
44/// ├─────────────────┤      ├─────────┤   take(values, indices)      ├─────────────────┤
45/// │        B        │      │    3    │ ─────────────────────────▶   │        C        │
46/// ├─────────────────┤      ├─────────┤                              ├─────────────────┤
47/// │        C        │      │    1    │                              │        D        │
48/// ├─────────────────┤      └─────────┘                              └─────────────────┘
49/// │        E        │
50/// └─────────────────┘
51///    values array          indices array                              result
52/// ```
53///
54/// For selecting values by index from multiple arrays see [`crate::interleave`]
55///
56/// Note that this kernel, similar to other kernels in this crate,
57/// will avoid allocating where not necessary. Consequently
58/// the returned array may share buffers with the inputs
59///
60/// # Errors
61/// This function errors whenever:
62/// * An index cannot be casted to `usize` (typically 32 bit architectures)
63/// * An index is out of bounds and `options` is set to check bounds.
64///
65/// # Safety
66///
67/// When `options` is not set to check bounds, taking indexes after `len` will panic.
68///
69/// # See also
70/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
71///   the results into a single array.
72///
73/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
74///
75/// # Examples
76/// ```
77/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
78/// # use arrow_select::take::take;
79/// let values = StringArray::from(vec!["zero", "one", "two"]);
80///
81/// // Take items at index 2, and 1:
82/// let indices = UInt32Array::from(vec![2, 1]);
83/// let taken = take(&values, &indices, None).unwrap();
84/// let taken = taken.as_string::<i32>();
85///
86/// assert_eq!(*taken, StringArray::from(vec!["two", "one"]));
87/// ```
88pub fn take(
89    values: &dyn Array,
90    indices: &dyn Array,
91    options: Option<TakeOptions>,
92) -> Result<ArrayRef, ArrowError> {
93    let options = options.unwrap_or_default();
94    downcast_integer_array!(
95        indices => {
96            if options.check_bounds {
97                check_bounds(values.len(), indices)?;
98            }
99            let indices = indices.to_indices();
100            take_impl(values, &indices)
101        },
102        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
103    )
104}
105
106/// For each [ArrayRef] in the [`Vec<ArrayRef>`], take elements by index and create a new
107/// [`Vec<ArrayRef>`] from those indices.
108///
109/// ```text
110/// ┌────────┬────────┐
111/// │        │        │           ┌────────┐                                ┌────────┬────────┐
112/// │   A    │   1    │           │        │                                │        │        │
113/// ├────────┼────────┤           │   0    │                                │   A    │   1    │
114/// │        │        │           ├────────┤                                ├────────┼────────┤
115/// │   D    │   4    │           │        │                                │        │        │
116/// ├────────┼────────┤           │   2    │  take_arrays(values,indices)   │   B    │   2    │
117/// │        │        │           ├────────┤                                ├────────┼────────┤
118/// │   B    │   2    │           │        │  ───────────────────────────►  │        │        │
119/// ├────────┼────────┤           │   3    │                                │   C    │   3    │
120/// │        │        │           ├────────┤                                ├────────┼────────┤
121/// │   C    │   3    │           │        │                                │        │        │
122/// ├────────┼────────┤           │   1    │                                │   D    │   4    │
123/// │        │        │           └────────┘                                └────────┼────────┘
124/// │   E    │   5    │
125/// └────────┴────────┘
126///    values arrays             indices array                                      result
127/// ```
128///
129/// # Errors
130/// This function errors whenever:
131/// * An index cannot be casted to `usize` (typically 32 bit architectures)
132/// * An index is out of bounds and `options` is set to check bounds.
133///
134/// # Safety
135///
136/// When `options` is not set to check bounds, taking indexes after `len` will panic.
137///
138/// # Examples
139/// ```
140/// # use std::sync::Arc;
141/// # use arrow_array::{StringArray, UInt32Array, cast::AsArray};
142/// # use arrow_select::take::{take, take_arrays};
143/// let string_values = Arc::new(StringArray::from(vec!["zero", "one", "two"]));
144/// let values = Arc::new(UInt32Array::from(vec![0, 1, 2]));
145///
146/// // Take items at index 2, and 1:
147/// let indices = UInt32Array::from(vec![2, 1]);
148/// let taken_arrays = take_arrays(&[string_values, values], &indices, None).unwrap();
149/// let taken_string = taken_arrays[0].as_string::<i32>();
150/// assert_eq!(*taken_string, StringArray::from(vec!["two", "one"]));
151/// let taken_values = taken_arrays[1].as_primitive();
152/// assert_eq!(*taken_values, UInt32Array::from(vec![2, 1]));
153/// ```
154pub fn take_arrays(
155    arrays: &[ArrayRef],
156    indices: &dyn Array,
157    options: Option<TakeOptions>,
158) -> Result<Vec<ArrayRef>, ArrowError> {
159    arrays
160        .iter()
161        .map(|array| take(array.as_ref(), indices, options.clone()))
162        .collect()
163}
164
165/// Verifies that the non-null values of `indices` are all `< len`
166fn check_bounds<T: ArrowPrimitiveType>(
167    len: usize,
168    indices: &PrimitiveArray<T>,
169) -> Result<(), ArrowError>
170where
171    T::Native: Display,
172{
173    let len = match T::Native::from_usize(len) {
174        Some(len) => len,
175        None => {
176            if T::DATA_TYPE.is_integer() {
177                // the biggest representable value for T::Native is lower than len, e.g: u8::MAX < 512, no need to check bounds
178                return Ok(());
179            } else {
180                return Err(ArrowError::ComputeError("Cast to usize failed".to_string()));
181            }
182        }
183    };
184
185    if indices.null_count() > 0 {
186        indices.iter().flatten().try_for_each(|index| {
187            if index >= len {
188                return Err(ArrowError::ComputeError(format!(
189                    "Array index out of bounds, cannot get item at index {index} from {len} entries"
190                )));
191            }
192            Ok(())
193        })
194    } else {
195        let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
196            in_bounds & (i >= T::Native::ZERO) & (i < len)
197        });
198
199        if !in_bounds {
200            for &index in indices.values() {
201                if index < T::Native::ZERO || index >= len {
202                    return Err(ArrowError::ComputeError(format!(
203                        "Array index out of bounds, cannot get item at index {index} from {len} entries"
204                    )));
205                }
206            }
207        }
208
209        Ok(())
210    }
211}
212
213#[inline(never)]
214fn take_impl<IndexType: ArrowPrimitiveType>(
215    values: &dyn Array,
216    indices: &PrimitiveArray<IndexType>,
217) -> Result<ArrayRef, ArrowError> {
218    if indices.is_empty() {
219        return Ok(new_empty_array(values.data_type()));
220    }
221    downcast_primitive_array! {
222        values => Ok(Arc::new(take_primitive(values, indices)?)),
223        DataType::Boolean => {
224            let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
225            Ok(Arc::new(take_boolean(values, indices)))
226        }
227        DataType::Utf8 => {
228            Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
229        }
230        DataType::LargeUtf8 => {
231            Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
232        }
233        DataType::Utf8View => {
234            Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
235        }
236        DataType::List(_) => {
237            Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
238        }
239        DataType::LargeList(_) => {
240            Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
241        }
242        DataType::ListView(_) => {
243            Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
244        }
245        DataType::LargeListView(_) => {
246            Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
247        }
248        DataType::FixedSizeList(_, length) => {
249            let values = values
250                .as_any()
251                .downcast_ref::<FixedSizeListArray>()
252                .unwrap();
253            Ok(Arc::new(take_fixed_size_list(
254                values,
255                indices,
256                *length as u32,
257            )?))
258        }
259        DataType::Map(_, _) => {
260            let list_arr = ListArray::from(values.as_map().clone());
261            let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
262            let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
263            Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
264        }
265        DataType::Struct(fields) => {
266            let array: &StructArray = values.as_struct();
267            let arrays  = array
268                .columns()
269                .iter()
270                .map(|a| take_impl(a.as_ref(), indices))
271                .collect::<Result<Vec<ArrayRef>, _>>()?;
272            let fields: Vec<(FieldRef, ArrayRef)> =
273                fields.iter().cloned().zip(arrays).collect();
274
275            // Create the null bit buffer.
276            let is_valid: Buffer = indices
277                .iter()
278                .map(|index| {
279                    if let Some(index) = index {
280                        array.is_valid(index.to_usize().unwrap())
281                    } else {
282                        false
283                    }
284                })
285                .collect();
286
287            if fields.is_empty() {
288                let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
289                Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
290            } else {
291                Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
292            }
293        }
294        DataType::Dictionary(_, _) => downcast_dictionary_array! {
295            values => Ok(Arc::new(take_dict(values, indices)?)),
296            t => unimplemented!("Take not supported for dictionary type {:?}", t)
297        }
298        DataType::RunEndEncoded(_, _) => downcast_run_array! {
299            values => Ok(Arc::new(take_run(values, indices)?)),
300            t => unimplemented!("Take not supported for run type {:?}", t)
301        }
302        DataType::Binary => {
303            Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
304        }
305        DataType::LargeBinary => {
306            Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
307        }
308        DataType::BinaryView => {
309            Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
310        }
311        DataType::FixedSizeBinary(size) => {
312            let values = values
313                .as_any()
314                .downcast_ref::<FixedSizeBinaryArray>()
315                .unwrap();
316            Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
317        }
318        DataType::Null => {
319            // Take applied to a null array produces a null array.
320            if values.len() >= indices.len() {
321                // If the existing null array is as big as the indices, we can use a slice of it
322                // to avoid allocating a new null array.
323                Ok(values.slice(0, indices.len()))
324            } else {
325                // If the existing null array isn't big enough, create a new one.
326                Ok(new_null_array(&DataType::Null, indices.len()))
327            }
328        }
329        DataType::Union(fields, UnionMode::Sparse) => {
330            let mut children = Vec::with_capacity(fields.len());
331            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
332            let type_ids = take_native(values.type_ids(), indices);
333            for (type_id, _field) in fields.iter() {
334                let values = values.child(type_id);
335                let values = take_impl(values, indices)?;
336                children.push(values);
337            }
338            let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
339            Ok(Arc::new(array))
340        }
341        DataType::Union(fields, UnionMode::Dense) => {
342            let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
343
344            let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
345            let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;
346
347            let children = fields.iter()
348                .map(|(field_type_id, _)| {
349                    let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
350
351                    let indices = crate::filter::filter(&offsets, &mask)?;
352
353                    let values = values.child(field_type_id);
354
355                    take_impl(values, indices.as_primitive::<Int32Type>())
356                })
357                .collect::<Result<_, _>>()?;
358
359            let mut child_offsets = [0; 128];
360
361            let offsets = type_ids.values()
362                .iter()
363                .map(|&i| {
364                    let offset = child_offsets[i as usize];
365
366                    child_offsets[i as usize] += 1;
367
368                    offset
369                })
370                .collect();
371
372            let (_, type_ids, _) = type_ids.into_parts();
373
374            let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
375
376            Ok(Arc::new(array))
377        }
378        t => unimplemented!("Take not supported for data type {:?}", t)
379    }
380}
381
/// Settings controlling the behaviour of the `take` kernels.
#[derive(Debug, Default, Clone)]
pub struct TakeOptions {
    /// Whether to validate the indices against the values before taking.
    ///
    /// When `true`, out-of-bounds indices are reported as an `ArrowError`.
    /// When `false` (the default), out-of-bounds indices make the kernel panic.
    pub check_bounds: bool,
}
390
391/// `take` implementation for all primitive arrays
392///
393/// This checks if an `indices` slot is populated, and gets the value from `values`
394///  as the populated index.
395/// If the `indices` slot is null, a null value is returned.
396/// For example, given:
397///     values:  [1, 2, 3, null, 5]
398///     indices: [0, null, 4, 3]
399/// The result is: [1 (slot 0), null (null slot), 5 (slot 4), null (slot 3)]
400fn take_primitive<T, I>(
401    values: &PrimitiveArray<T>,
402    indices: &PrimitiveArray<I>,
403) -> Result<PrimitiveArray<T>, ArrowError>
404where
405    T: ArrowPrimitiveType,
406    I: ArrowPrimitiveType,
407{
408    let values_buf = take_native(values.values(), indices);
409    let nulls = take_nulls(values.nulls(), indices);
410    Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
411}
412
413#[inline(never)]
414fn take_nulls<I: ArrowPrimitiveType>(
415    values: Option<&NullBuffer>,
416    indices: &PrimitiveArray<I>,
417) -> Option<NullBuffer> {
418    match values.filter(|n| n.null_count() > 0) {
419        Some(n) => NullBuffer::from_unsliced_buffer(
420            take_bits(n.inner(), indices).into_inner(),
421            indices.len(),
422        ),
423        None => indices.nulls().cloned(),
424    }
425}
426
427#[inline(never)]
428fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
429    values: &[T],
430    indices: &PrimitiveArray<I>,
431) -> ScalarBuffer<T> {
432    match indices.nulls().filter(|n| n.null_count() > 0) {
433        Some(n) => indices
434            .values()
435            .iter()
436            .enumerate()
437            .map(|(idx, index)| match values.get(index.as_usize()) {
438                Some(v) => *v,
439                // SAFETY: idx<indices.len()
440                None => match unsafe { n.inner().value_unchecked(idx) } {
441                    false => T::default(),
442                    true => panic!("Out-of-bounds index {index:?}"),
443                },
444            })
445            .collect(),
446        None => indices
447            .values()
448            .iter()
449            .map(|index| values[index.as_usize()])
450            .collect(),
451    }
452}
453
454#[inline(never)]
455fn take_bits<I: ArrowPrimitiveType>(
456    values: &BooleanBuffer,
457    indices: &PrimitiveArray<I>,
458) -> BooleanBuffer {
459    let len = indices.len();
460
461    match indices.nulls().filter(|n| n.null_count() > 0) {
462        Some(nulls) => {
463            let mut output_buffer = MutableBuffer::new_null(len);
464            let output_slice = output_buffer.as_slice_mut();
465            nulls.valid_indices().for_each(|idx| {
466                // SAFETY: idx is a valid index in indices.nulls() --> idx<indices.len()
467                if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
468                    // SAFETY: MutableBuffer was created with space for indices.len() bit, and idx < indices.len()
469                    unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
470                }
471            });
472            BooleanBuffer::new(output_buffer.into(), 0, len)
473        }
474        None => {
475            BooleanBuffer::collect_bool(len, |idx: usize| {
476                // SAFETY: idx<indices.len()
477                values.value(unsafe { indices.value_unchecked(idx).as_usize() })
478            })
479        }
480    }
481}
482
483/// `take` implementation for boolean arrays
484fn take_boolean<IndexType: ArrowPrimitiveType>(
485    values: &BooleanArray,
486    indices: &PrimitiveArray<IndexType>,
487) -> BooleanArray {
488    let val_buf = take_bits(values.values(), indices);
489    let null_buf = take_nulls(values.nulls(), indices);
490    BooleanArray::new(val_buf, null_buf)
491}
492
493/// `take` implementation for string arrays
494fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
495    array: &GenericByteArray<T>,
496    indices: &PrimitiveArray<IndexType>,
497) -> Result<GenericByteArray<T>, ArrowError> {
498    let mut values: Vec<u8> = Vec::new();
499    let mut offsets = Vec::with_capacity(indices.len() + 1);
500    offsets.push(T::Offset::default());
501
502    let input_offsets = array.value_offsets();
503    let mut capacity = 0;
504    let nulls = take_nulls(array.nulls(), indices);
505
506    // Branch on output nulls — `None` means every output slot is valid.
507    match nulls.as_ref().filter(|n| n.null_count() > 0) {
508        // Fast path: no nulls in output, every index is valid.
509        None => {
510            for index in indices.values() {
511                let index = index.as_usize();
512                let start = input_offsets[index].as_usize();
513                let end = input_offsets[index + 1].as_usize();
514                capacity += end - start;
515                offsets.push(
516                    T::Offset::from_usize(capacity)
517                        .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
518                );
519            }
520
521            values.reserve(capacity);
522
523            let dst = values.spare_capacity_mut();
524            debug_assert!(dst.len() >= capacity);
525            let mut offset = 0;
526
527            for index in indices.values() {
528                // SAFETY: in-bounds proven by the first loop's bounds-checked offset access.
529                // dst asserted above to include the required capacity.
530                unsafe {
531                    let data: &[u8] = array.value_unchecked(index.as_usize()).as_ref();
532                    std::ptr::copy_nonoverlapping(
533                        data.as_ptr(),
534                        dst.get_unchecked_mut(offset..).as_mut_ptr().cast::<u8>(),
535                        data.len(),
536                    );
537                    offset += data.len();
538                }
539            }
540
541            // SAFETY: wrote exactly `capacity` bytes above; reserved on line above.
542            unsafe {
543                values.set_len(capacity);
544            }
545        }
546        // Nullable path: only process valid (non-null) output positions.
547        Some(output_nulls) => {
548            let mut source_ranges = Vec::with_capacity(indices.len() - output_nulls.null_count());
549            let mut last_filled = 0;
550
551            // Pre-fill offsets; we overwrite valid positions below.
552            offsets.resize(indices.len() + 1, T::Offset::default());
553
554            // Pass 1: find all valid ranges that need to be copied.
555            for i in output_nulls.valid_indices() {
556                let current_offset = T::Offset::from_usize(capacity)
557                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
558                // Fill offsets for skipped null slots so they get zero-length ranges.
559                if last_filled < i {
560                    offsets[last_filled + 1..=i].fill(current_offset);
561                }
562
563                // SAFETY: `i` comes from a validity bitmap over `indices`, so it is in-bounds.
564                let index = unsafe { indices.value_unchecked(i) }.as_usize();
565                let start = input_offsets[index].as_usize();
566                let end = input_offsets[index + 1].as_usize();
567                capacity += end - start;
568                offsets[i + 1] = T::Offset::from_usize(capacity)
569                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
570
571                source_ranges.push((start, end));
572                last_filled = i + 1;
573            }
574
575            // Fill trailing null offsets after the last valid position.
576            let final_offset = T::Offset::from_usize(capacity)
577                .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
578            offsets[last_filled + 1..].fill(final_offset);
579            // Pass 2: copy byte data for all collected ranges.
580            values.reserve(capacity);
581            debug_assert_eq!(
582                source_ranges.iter().map(|(s, e)| e - s).sum::<usize>(),
583                capacity,
584                "capacity must equal total bytes across all ranges"
585            );
586
587            let src = array.value_data();
588            let src = src.as_ptr();
589            let dst = values.spare_capacity_mut();
590            debug_assert!(dst.len() >= capacity);
591
592            let mut offset = 0;
593
594            for (start, end) in source_ranges.into_iter() {
595                let value_len = end - start;
596                // SAFETY: caller guarantees each (start, end) is in-bounds of `src`.
597                // `dst` asserted above to include the required capacity.
598                // The regions don't overlap (src is input, dst is a fresh allocation).
599                unsafe {
600                    std::ptr::copy_nonoverlapping(
601                        src.add(start),
602                        dst.get_unchecked_mut(offset..).as_mut_ptr().cast::<u8>(),
603                        value_len,
604                    );
605                    offset += value_len;
606                }
607            }
608            // SAFETY: caller guarantees `capacity` == total bytes across all ranges,
609            // so the loop above wrote exactly `capacity` bytes.
610            unsafe { values.set_len(capacity) };
611        }
612    };
613
614    // SAFETY: offsets are monotonically increasing and in-bounds of `values`,
615    // and `nulls` (if present) has length == `indices.len()`.
616    let array = unsafe {
617        let offsets = OffsetBuffer::new_unchecked(offsets.into());
618        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
619    };
620
621    Ok(array)
622}
623
624/// `take` implementation for byte view arrays
625fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
626    array: &GenericByteViewArray<T>,
627    indices: &PrimitiveArray<IndexType>,
628) -> Result<GenericByteViewArray<T>, ArrowError> {
629    let new_views = take_native(array.views(), indices);
630    let new_nulls = take_nulls(array.nulls(), indices);
631    // Safety:  array.views was valid, and take_native copies only valid values, and verifies bounds
632    Ok(unsafe {
633        GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
634    })
635}
636
637/// `take` implementation for list arrays
638///
639/// Copies the selected list entries' child slices into a new child array
640/// via `MutableArrayData`, then reconstructs a list array with new offsets
641fn take_list<IndexType, OffsetType>(
642    values: &GenericListArray<OffsetType::Native>,
643    indices: &PrimitiveArray<IndexType>,
644) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
645where
646    IndexType: ArrowPrimitiveType,
647    OffsetType: ArrowPrimitiveType,
648    OffsetType::Native: OffsetSizeTrait,
649    PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
650{
651    let list_offsets = values.value_offsets();
652    let child_data = values.values().to_data();
653    let nulls = take_nulls(values.nulls(), indices);
654
655    let mut new_offsets = Vec::with_capacity(indices.len() + 1);
656    new_offsets.push(OffsetType::Native::zero());
657
658    let use_nulls = child_data.null_count() > 0;
659
660    let capacity = child_data
661        .len()
662        .checked_div(values.len())
663        .map(|v| v * indices.len())
664        .unwrap_or_default();
665
666    let mut array_data = MutableArrayData::new(vec![&child_data], use_nulls, capacity);
667
668    match nulls.as_ref().filter(|n| n.null_count() > 0) {
669        None => {
670            for index in indices.values() {
671                let ix = index.as_usize();
672                let start = list_offsets[ix].as_usize();
673                let end = list_offsets[ix + 1].as_usize();
674                array_data.extend(0, start, end);
675                new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
676            }
677        }
678        Some(output_nulls) => {
679            assert_eq!(output_nulls.len(), indices.len());
680
681            let mut last_filled = 0;
682            for i in output_nulls.valid_indices() {
683                let current = OffsetType::Native::from_usize(array_data.len()).unwrap();
684                // Filling offsets for the null values between the two valid indices
685                if last_filled < i {
686                    new_offsets.extend(std::iter::repeat_n(current, i - last_filled));
687                }
688
689                // SAFETY: `i` comes from validity bitmap over `indices`, so in-bounds.
690                let ix = unsafe { indices.value_unchecked(i) }.as_usize();
691                let start = list_offsets[ix].as_usize();
692                let end = list_offsets[ix + 1].as_usize();
693                array_data.extend(0, start, end);
694                new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
695                last_filled = i + 1;
696            }
697
698            // Filling offsets for null values at the end
699            let final_offset = OffsetType::Native::from_usize(array_data.len()).unwrap();
700            new_offsets.extend(std::iter::repeat_n(
701                final_offset,
702                indices.len() - last_filled,
703            ));
704        }
705    };
706
707    assert_eq!(
708        new_offsets.len(),
709        indices.len() + 1,
710        "New offsets was filled under/over the expected capacity"
711    );
712
713    let child_data = array_data.freeze();
714    let value_offsets = Buffer::from_vec(new_offsets);
715
716    let list_data = ArrayDataBuilder::new(values.data_type().clone())
717        .len(indices.len())
718        .nulls(nulls)
719        .offset(0)
720        .add_child_data(child_data)
721        .add_buffer(value_offsets);
722
723    let list_data = unsafe { list_data.build_unchecked() };
724    Ok(GenericListArray::<OffsetType::Native>::from(list_data))
725}
726
727fn take_list_view<IndexType, OffsetType>(
728    values: &GenericListViewArray<OffsetType::Native>,
729    indices: &PrimitiveArray<IndexType>,
730) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
731where
732    IndexType: ArrowPrimitiveType,
733    OffsetType: ArrowPrimitiveType,
734    OffsetType::Native: OffsetSizeTrait,
735{
736    let taken_offsets = take_native(values.offsets(), indices);
737    let taken_sizes = take_native(values.sizes(), indices);
738    let nulls = take_nulls(values.nulls(), indices);
739
740    let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
741        .len(indices.len())
742        .nulls(nulls)
743        .buffers(vec![taken_offsets.into(), taken_sizes.into()])
744        .child_data(vec![values.values().to_data()]);
745
746    // SAFETY: all buffers and child nodes for ListView added in constructor
747    let list_view_data = unsafe { list_view_data.build_unchecked() };
748
749    Ok(GenericListViewArray::<OffsetType::Native>::from(
750        list_view_data,
751    ))
752}
753
754/// `take` implementation for `FixedSizeListArray`
755///
756/// Calculates the index and indexed offset for the inner array,
757/// applying `take` on the inner array, then reconstructing a list array
758/// with the indexed offsets
759fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
760    values: &FixedSizeListArray,
761    indices: &PrimitiveArray<IndexType>,
762    length: <UInt32Type as ArrowPrimitiveType>::Native,
763) -> Result<FixedSizeListArray, ArrowError> {
764    let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
765    let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
766
767    // determine null count and null buffer, which are a function of `values` and `indices`
768    let num_bytes = bit_util::ceil(indices.len(), 8);
769    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
770    let null_slice = null_buf.as_slice_mut();
771
772    for i in 0..indices.len() {
773        let index = indices
774            .value(i)
775            .to_usize()
776            .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
777        if !indices.is_valid(i) || values.is_null(index) {
778            bit_util::unset_bit(null_slice, i);
779        }
780    }
781
782    let list_data = ArrayDataBuilder::new(values.data_type().clone())
783        .len(indices.len())
784        .null_bit_buffer(Some(null_buf.into()))
785        .offset(0)
786        .add_child_data(taken.into_data());
787
788    let list_data = unsafe { list_data.build_unchecked() };
789
790    Ok(FixedSizeListArray::from(list_data))
791}
792
793/// The take kernel implementation for `FixedSizeBinaryArray`.
794///
795/// The computation is done in two steps:
796/// - Compute the values buffer
797/// - Compute the null buffer
798fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
799    values: &FixedSizeBinaryArray,
800    indices: &PrimitiveArray<IndexType>,
801    size: i32,
802) -> Result<FixedSizeBinaryArray, ArrowError> {
803    let size_usize = usize::try_from(size).map_err(|_| {
804        ArrowError::InvalidArgumentError(format!("Cannot convert size '{}' to usize", size))
805    })?;
806
807    let result_buffer = match size_usize {
808        1 => take_fixed_size::<IndexType, 1>(values.values(), indices),
809        2 => take_fixed_size::<IndexType, 2>(values.values(), indices),
810        4 => take_fixed_size::<IndexType, 4>(values.values(), indices),
811        8 => take_fixed_size::<IndexType, 8>(values.values(), indices),
812        16 => take_fixed_size::<IndexType, 16>(values.values(), indices),
813        _ => take_fixed_size_binary_buffer_dynamic_length(values, indices, size_usize),
814    };
815
816    let value_nulls = take_nulls(values.nulls(), indices);
817    let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
818    let array_data = ArrayDataBuilder::new(DataType::FixedSizeBinary(size))
819        .len(indices.len())
820        .nulls(final_nulls)
821        .offset(0)
822        .add_buffer(result_buffer)
823        .build()?;
824
825    return Ok(FixedSizeBinaryArray::from(array_data));
826
827    /// Implementation of the take kernel for fixed size binary arrays.
828    #[inline(never)]
829    fn take_fixed_size_binary_buffer_dynamic_length<IndexType: ArrowPrimitiveType>(
830        values: &FixedSizeBinaryArray,
831        indices: &PrimitiveArray<IndexType>,
832        size_usize: usize,
833    ) -> Buffer {
834        let values_buffer = values.values().as_slice();
835        let mut values_buffer_builder = BufferBuilder::new(indices.len() * size_usize);
836
837        if indices.null_count() == 0 {
838            let array_iter = indices.values().iter().map(|idx| {
839                let offset = idx.as_usize() * size_usize;
840                &values_buffer[offset..offset + size_usize]
841            });
842            for slice in array_iter {
843                values_buffer_builder.append_slice(slice);
844            }
845        } else {
846            // The indices nullability cannot be ignored here because the values buffer may contain
847            // nulls which should not cause a panic.
848            let array_iter = indices.iter().map(|idx| {
849                idx.map(|idx| {
850                    let offset = idx.as_usize() * size_usize;
851                    &values_buffer[offset..offset + size_usize]
852                })
853            });
854            for slice in array_iter {
855                match slice {
856                    None => values_buffer_builder.append_n(size_usize, 0),
857                    Some(slice) => values_buffer_builder.append_slice(slice),
858                }
859            }
860        }
861
862        values_buffer_builder.finish()
863    }
864}
865
866/// Implements the take kernel semantics over a flat [`Buffer`], interpreting it as a slice of
867/// `&[[u8; N]]`, where `N` is a compile-time constant. The usage of a flat [`Buffer`] allows using
868/// this kernel without an available [`ArrowPrimitiveType`] (e.g., for `[u8; 5]`).
869///
870/// # Using This Function in the Primitive Take Kernel
871///
872/// This function is basically the same as [`take_native`] but just on a flat [`Buffer`] instead of
873/// the primitive [`ScalarBuffer`]. Ideally, the [`take_primitive`] kernel should just use this
874/// more general function. However, the "idiomatic code" requires the
875/// [feature(generic_const_exprs)](https://github.com/rust-lang/rust/issues/76560) for calling
876/// `take_fixed_size<I, { size_of::<T::Native> () } >(...)`. Once this feature has been stabilized,
877/// we can use this function also in the primitive kernels.
878fn take_fixed_size<IndexType: ArrowPrimitiveType, const N: usize>(
879    buffer: &Buffer,
880    indices: &PrimitiveArray<IndexType>,
881) -> Buffer {
882    assert_eq!(
883        buffer.len() % N,
884        0,
885        "Invalid array length in take_fixed_size"
886    );
887
888    let ptr = buffer.as_ptr();
889    let chunk_ptr = ptr.cast::<[u8; N]>();
890    let chunk_len = buffer.len() / N;
891    let buffer: &[[u8; N]] = unsafe {
892        // SAFETY: interpret an already valid slice as a slice of N-byte chunks. N divides buffer
893        // length without remainder.
894        std::slice::from_raw_parts(chunk_ptr, chunk_len)
895    };
896
897    let result_buffer = match indices.nulls().filter(|n| n.null_count() > 0) {
898        Some(n) => indices
899            .values()
900            .iter()
901            .enumerate()
902            .map(|(idx, index)| match buffer.get(index.as_usize()) {
903                Some(v) => *v,
904                // SAFETY: idx<indices.len()
905                None => match unsafe { n.inner().value_unchecked(idx) } {
906                    false => [0u8; N],
907                    true => panic!("Out-of-bounds index {index:?}"),
908                },
909            })
910            .collect::<Vec<_>>(),
911        None => indices
912            .values()
913            .iter()
914            .map(|index| buffer[index.as_usize()])
915            .collect::<Vec<_>>(),
916    };
917
918    let mut vec = ManuallyDrop::new(result_buffer); // Prevent de-allocation
919    let ptr = vec.as_mut_ptr();
920    let len = vec.len();
921    let cap = vec.capacity();
922    let result_buffer = unsafe {
923        // SAFETY: flattening an already valid Vec.
924        Vec::from_raw_parts(ptr.cast::<u8>(), len * N, cap * N)
925    };
926
927    Buffer::from_vec(result_buffer)
928}
929
930/// `take` implementation for dictionary arrays
931///
932/// applies `take` to the keys of the dictionary array and returns a new dictionary array
933/// with the same dictionary values and reordered keys
934fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
935    values: &DictionaryArray<T>,
936    indices: &PrimitiveArray<I>,
937) -> Result<DictionaryArray<T>, ArrowError> {
938    let new_keys = take_primitive(values.keys(), indices)?;
939    Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
940}
941
942/// `take` implementation for run arrays
943///
944/// Finds physical indices for the given logical indices and builds output run array
945/// by taking values in the input run_array.values at the physical indices.
946/// The output run array will be run encoded on the physical indices and not on output values.
947/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
948/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
949/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
950fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
951    run_array: &RunArray<T>,
952    logical_indices: &PrimitiveArray<I>,
953) -> Result<RunArray<T>, ArrowError> {
954    // get physical indices for the input logical indices
955    let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
956
957    // Run encode the physical indices into new_run_ends_builder
958    // Keep track of the physical indices to take in take_value_indices
959    // `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
960    let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
961    let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
962    let mut new_physical_len = 1;
963    for ix in 1..physical_indices.len() {
964        if physical_indices[ix] != physical_indices[ix - 1] {
965            take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
966            new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
967            new_physical_len += 1;
968        }
969    }
970    take_value_indices
971        .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
972    new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
973    let new_run_ends = unsafe {
974        // Safety:
975        // The function builds a valid run_ends array and hence need not be validated.
976        ArrayDataBuilder::new(T::DATA_TYPE)
977            .len(new_physical_len)
978            .null_count(0)
979            .add_buffer(new_run_ends_builder.finish())
980            .build_unchecked()
981    };
982
983    let take_value_indices: PrimitiveArray<I> = unsafe {
984        // Safety:
985        // The function builds a valid take_value_indices array and hence need not be validated.
986        ArrayDataBuilder::new(I::DATA_TYPE)
987            .len(new_physical_len)
988            .null_count(0)
989            .add_buffer(take_value_indices.finish())
990            .build_unchecked()
991            .into()
992    };
993
994    let new_values = take(run_array.values(), &take_value_indices, None)?;
995
996    let builder = ArrayDataBuilder::new(run_array.data_type().clone())
997        .len(physical_indices.len())
998        .add_child_data(new_run_ends)
999        .add_child_data(new_values.into_data());
1000    let array_data = unsafe {
1001        // Safety:
1002        //  This function builds a valid run array and hence can skip validation.
1003        builder.build_unchecked()
1004    };
1005    Ok(array_data.into())
1006}
1007
1008/// Takes/filters a fixed size list array's inner data using the offsets of the list array.
1009fn take_value_indices_from_fixed_size_list<IndexType>(
1010    list: &FixedSizeListArray,
1011    indices: &PrimitiveArray<IndexType>,
1012    length: <UInt32Type as ArrowPrimitiveType>::Native,
1013) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
1014where
1015    IndexType: ArrowPrimitiveType,
1016{
1017    let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
1018
1019    for i in 0..indices.len() {
1020        if indices.is_valid(i) {
1021            let index = indices
1022                .value(i)
1023                .to_usize()
1024                .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
1025            let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
1026
1027            // Safety: Range always has known length.
1028            unsafe {
1029                values.append_trusted_len_iter(start..start + length);
1030            }
1031        } else {
1032            values.append_nulls(length as usize);
1033        }
1034    }
1035
1036    Ok(values.finish())
1037}
1038
/// To avoid generating take implementations for every index type, instead we
/// only generate for UInt32 and UInt64 and coerce inputs to these types
///
/// Implemented in this file for `PrimitiveArray`s of the 8-, 16-, 32- and 64-bit
/// signed and unsigned integer types.
trait ToIndices {
    /// The index type the array is coerced to.
    type T: ArrowPrimitiveType;

    /// Returns the contents of `self` as an array of [`Self::T`] indices.
    fn to_indices(&self) -> PrimitiveArray<Self::T>;
}
1046
// Implements `ToIndices` for `PrimitiveArray<$t>` by re-wrapping the array's existing
// values buffer as a `ScalarBuffer` of `$o`; the validity buffer is carried over.
macro_rules! to_indices_reinterpret {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<Self::T> {
                let shared = self.values().inner().clone();
                let reinterpreted = ScalarBuffer::new(shared, 0, self.len());
                let validity = self.nulls().cloned();
                PrimitiveArray::new(reinterpreted, validity)
            }
        }
    };
}
1059
// Implements `ToIndices` for index types that are already one of the generated
// index types: the array is simply cloned.
macro_rules! to_indices_identity {
    ($t:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $t;

            fn to_indices(&self) -> PrimitiveArray<Self::T> {
                Clone::clone(self)
            }
        }
    };
}
1071
// Implements `ToIndices` for `PrimitiveArray<$t>` by converting every value to the
// native type of `$o` with an `as` cast; the validity buffer is carried over.
//
// The associated type is `$o` (it used to be hard-coded to `UInt32Type`, which only
// compiled when the macro was invoked with exactly that output type).
macro_rules! to_indices_widening {
    ($t:ty, $o:ty) => {
        impl ToIndices for PrimitiveArray<$t> {
            type T = $o;

            fn to_indices(&self) -> PrimitiveArray<$o> {
                let cast = self.values().iter().copied().map(|x| x as _).collect();
                PrimitiveArray::new(cast, self.nulls().cloned())
            }
        }
    };
}
1084
// 8- and 16-bit index arrays are converted value by value into `UInt32Type` indices.
// For the signed variants the `as` cast turns negative values into large unsigned
// values (presumably rejected later as out of bounds - TODO confirm).
to_indices_widening!(UInt8Type, UInt32Type);
to_indices_widening!(Int8Type, UInt32Type);

to_indices_widening!(UInt16Type, UInt32Type);
to_indices_widening!(Int16Type, UInt32Type);

// 32-bit index arrays: unsigned is used as-is, signed is re-wrapped as unsigned.
to_indices_identity!(UInt32Type);
to_indices_reinterpret!(Int32Type, UInt32Type);

// 64-bit index arrays: unsigned is used as-is, signed is re-wrapped as unsigned.
to_indices_identity!(UInt64Type);
to_indices_reinterpret!(Int64Type, UInt64Type);
1096
1097/// Take rows by index from [`RecordBatch`] and returns a new [`RecordBatch`] from those indexes.
1098///
1099/// This function will call [`take`] on each array of the [`RecordBatch`] and assemble a new [`RecordBatch`].
1100///
1101/// # Example
1102/// ```
1103/// # use std::sync::Arc;
1104/// # use arrow_array::{StringArray, Int32Array, UInt32Array, RecordBatch};
1105/// # use arrow_schema::{DataType, Field, Schema};
1106/// # use arrow_select::take::take_record_batch;
1107/// let schema = Arc::new(Schema::new(vec![
1108///     Field::new("a", DataType::Int32, true),
1109///     Field::new("b", DataType::Utf8, true),
1110/// ]));
1111/// let batch = RecordBatch::try_new(
1112///     schema.clone(),
1113///     vec![
1114///         Arc::new(Int32Array::from_iter_values(0..20)),
1115///         Arc::new(StringArray::from_iter_values(
1116///             (0..20).map(|i| format!("str-{}", i)),
1117///         )),
1118///     ],
1119/// )
1120/// .unwrap();
1121///
1122/// let indices = UInt32Array::from(vec![1, 5, 10]);
1123/// let taken = take_record_batch(&batch, &indices).unwrap();
1124///
1125/// let expected = RecordBatch::try_new(
1126///     schema,
1127///     vec![
1128///         Arc::new(Int32Array::from(vec![1, 5, 10])),
1129///         Arc::new(StringArray::from(vec!["str-1", "str-5", "str-10"])),
1130///     ],
1131/// )
1132/// .unwrap();
1133/// assert_eq!(taken, expected);
1134/// ```
1135pub fn take_record_batch(
1136    record_batch: &RecordBatch,
1137    indices: &dyn Array,
1138) -> Result<RecordBatch, ArrowError> {
1139    let columns = record_batch
1140        .columns()
1141        .iter()
1142        .map(|c| take(c, indices, None))
1143        .collect::<Result<Vec<_>, _>>()?;
1144    RecordBatch::try_new(record_batch.schema(), columns)
1145}
1146
1147#[cfg(test)]
1148mod tests {
1149    use super::*;
1150    use arrow_array::builder::*;
1151    use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1152    use arrow_data::ArrayData;
1153    use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1154    use num_traits::ToPrimitive;
1155
1156    fn test_take_decimal_arrays(
1157        data: Vec<Option<i128>>,
1158        index: &UInt32Array,
1159        options: Option<TakeOptions>,
1160        expected_data: Vec<Option<i128>>,
1161        precision: &u8,
1162        scale: &i8,
1163    ) -> Result<(), ArrowError> {
1164        let output = data
1165            .into_iter()
1166            .collect::<Decimal128Array>()
1167            .with_precision_and_scale(*precision, *scale)
1168            .unwrap();
1169
1170        let expected = expected_data
1171            .into_iter()
1172            .collect::<Decimal128Array>()
1173            .with_precision_and_scale(*precision, *scale)
1174            .unwrap();
1175
1176        let expected = Arc::new(expected) as ArrayRef;
1177        let output = take(&output, index, options).unwrap();
1178        assert_eq!(&output, &expected);
1179        Ok(())
1180    }
1181
1182    fn test_take_boolean_arrays(
1183        data: Vec<Option<bool>>,
1184        index: &UInt32Array,
1185        options: Option<TakeOptions>,
1186        expected_data: Vec<Option<bool>>,
1187    ) {
1188        let output = BooleanArray::from(data);
1189        let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1190        let output = take(&output, index, options).unwrap();
1191        assert_eq!(&output, &expected)
1192    }
1193
1194    fn test_take_primitive_arrays<T>(
1195        data: Vec<Option<T::Native>>,
1196        index: &UInt32Array,
1197        options: Option<TakeOptions>,
1198        expected_data: Vec<Option<T::Native>>,
1199    ) -> Result<(), ArrowError>
1200    where
1201        T: ArrowPrimitiveType,
1202        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1203    {
1204        let output = PrimitiveArray::<T>::from(data);
1205        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1206        let output = take(&output, index, options)?;
1207        assert_eq!(&output, &expected);
1208        Ok(())
1209    }
1210
1211    fn test_take_primitive_arrays_non_null<T>(
1212        data: Vec<T::Native>,
1213        index: &UInt32Array,
1214        options: Option<TakeOptions>,
1215        expected_data: Vec<Option<T::Native>>,
1216    ) -> Result<(), ArrowError>
1217    where
1218        T: ArrowPrimitiveType,
1219        PrimitiveArray<T>: From<Vec<T::Native>>,
1220        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1221    {
1222        let output = PrimitiveArray::<T>::from(data);
1223        let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1224        let output = take(&output, index, options)?;
1225        assert_eq!(&output, &expected);
1226        Ok(())
1227    }
1228
1229    fn test_take_impl_primitive_arrays<T, I>(
1230        data: Vec<Option<T::Native>>,
1231        index: &PrimitiveArray<I>,
1232        options: Option<TakeOptions>,
1233        expected_data: Vec<Option<T::Native>>,
1234    ) where
1235        T: ArrowPrimitiveType,
1236        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1237        I: ArrowPrimitiveType,
1238    {
1239        let output = PrimitiveArray::<T>::from(data);
1240        let expected = PrimitiveArray::<T>::from(expected_data);
1241        let output = take(&output, index, options).unwrap();
1242        let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1243        assert_eq!(output, &expected)
1244    }
1245
1246    // create a simple struct for testing purposes
1247    fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1248        let mut struct_builder = StructBuilder::new(
1249            Fields::from(vec![
1250                Field::new("a", DataType::Boolean, true),
1251                Field::new("b", DataType::Int32, true),
1252            ]),
1253            vec![
1254                Box::new(BooleanBuilder::with_capacity(values.len())),
1255                Box::new(Int32Builder::with_capacity(values.len())),
1256            ],
1257        );
1258
1259        for value in values {
1260            struct_builder
1261                .field_builder::<BooleanBuilder>(0)
1262                .unwrap()
1263                .append_option(value.and_then(|v| v.0));
1264            struct_builder
1265                .field_builder::<Int32Builder>(1)
1266                .unwrap()
1267                .append_option(value.and_then(|v| v.1));
1268            struct_builder.append(value.is_some());
1269        }
1270        struct_builder.finish()
1271    }
1272
1273    #[test]
1274    fn test_take_decimal128_non_null_indices() {
1275        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1276        let precision: u8 = 10;
1277        let scale: i8 = 5;
1278        test_take_decimal_arrays(
1279            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1280            &index,
1281            None,
1282            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1283            &precision,
1284            &scale,
1285        )
1286        .unwrap();
1287    }
1288
1289    #[test]
1290    fn test_take_decimal128() {
1291        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1292        let precision: u8 = 10;
1293        let scale: i8 = 5;
1294        test_take_decimal_arrays(
1295            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1296            &index,
1297            None,
1298            vec![Some(3), None, Some(1), Some(3), Some(2)],
1299            &precision,
1300            &scale,
1301        )
1302        .unwrap();
1303    }
1304
1305    #[test]
1306    fn test_take_primitive_non_null_indices() {
1307        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1308        test_take_primitive_arrays::<Int8Type>(
1309            vec![None, Some(3), Some(5), Some(2), Some(3), None],
1310            &index,
1311            None,
1312            vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1313        )
1314        .unwrap();
1315    }
1316
1317    #[test]
1318    fn test_take_primitive_non_null_values() {
1319        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1320        test_take_primitive_arrays::<Int8Type>(
1321            vec![Some(0), Some(1), Some(2), Some(3), Some(4)],
1322            &index,
1323            None,
1324            vec![Some(3), None, Some(1), Some(3), Some(2)],
1325        )
1326        .unwrap();
1327    }
1328
1329    #[test]
1330    fn test_take_primitive_non_null() {
1331        let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1332        test_take_primitive_arrays::<Int8Type>(
1333            vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)],
1334            &index,
1335            None,
1336            vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)],
1337        )
1338        .unwrap();
1339    }
1340
1341    #[test]
1342    fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
1343        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1344        let index = index.slice(2, 4);
1345        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1346
1347        assert_eq!(
1348            index,
1349            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1350        );
1351
1352        test_take_primitive_arrays_non_null::<Int64Type>(
1353            vec![0, 10, 20, 30, 40, 50],
1354            index,
1355            None,
1356            vec![Some(20), Some(30), None, None],
1357        )
1358        .unwrap();
1359    }
1360
1361    #[test]
1362    fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
1363        let index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
1364        let index = index.slice(2, 4);
1365        let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();
1366
1367        assert_eq!(
1368            index,
1369            &UInt32Array::from(vec![Some(2), Some(3), None, None])
1370        );
1371
1372        test_take_primitive_arrays::<Int64Type>(
1373            vec![None, None, Some(20), Some(30), Some(40), Some(50)],
1374            index,
1375            None,
1376            vec![Some(20), Some(30), None, None],
1377        )
1378        .unwrap();
1379    }
1380
1381    #[test]
1382    fn test_take_primitive() {
1383        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1384
1385        // int8
1386        test_take_primitive_arrays::<Int8Type>(
1387            vec![Some(0), None, Some(2), Some(3), None],
1388            &index,
1389            None,
1390            vec![Some(3), None, None, Some(3), Some(2)],
1391        )
1392        .unwrap();
1393
1394        // int16
1395        test_take_primitive_arrays::<Int16Type>(
1396            vec![Some(0), None, Some(2), Some(3), None],
1397            &index,
1398            None,
1399            vec![Some(3), None, None, Some(3), Some(2)],
1400        )
1401        .unwrap();
1402
1403        // int32
1404        test_take_primitive_arrays::<Int32Type>(
1405            vec![Some(0), None, Some(2), Some(3), None],
1406            &index,
1407            None,
1408            vec![Some(3), None, None, Some(3), Some(2)],
1409        )
1410        .unwrap();
1411
1412        // int64
1413        test_take_primitive_arrays::<Int64Type>(
1414            vec![Some(0), None, Some(2), Some(3), None],
1415            &index,
1416            None,
1417            vec![Some(3), None, None, Some(3), Some(2)],
1418        )
1419        .unwrap();
1420
1421        // uint8
1422        test_take_primitive_arrays::<UInt8Type>(
1423            vec![Some(0), None, Some(2), Some(3), None],
1424            &index,
1425            None,
1426            vec![Some(3), None, None, Some(3), Some(2)],
1427        )
1428        .unwrap();
1429
1430        // uint16
1431        test_take_primitive_arrays::<UInt16Type>(
1432            vec![Some(0), None, Some(2), Some(3), None],
1433            &index,
1434            None,
1435            vec![Some(3), None, None, Some(3), Some(2)],
1436        )
1437        .unwrap();
1438
1439        // uint32
1440        test_take_primitive_arrays::<UInt32Type>(
1441            vec![Some(0), None, Some(2), Some(3), None],
1442            &index,
1443            None,
1444            vec![Some(3), None, None, Some(3), Some(2)],
1445        )
1446        .unwrap();
1447
1448        // int64
1449        test_take_primitive_arrays::<Int64Type>(
1450            vec![Some(0), None, Some(2), Some(-15), None],
1451            &index,
1452            None,
1453            vec![Some(-15), None, None, Some(-15), Some(2)],
1454        )
1455        .unwrap();
1456
1457        // interval_year_month
1458        test_take_primitive_arrays::<IntervalYearMonthType>(
1459            vec![Some(0), None, Some(2), Some(-15), None],
1460            &index,
1461            None,
1462            vec![Some(-15), None, None, Some(-15), Some(2)],
1463        )
1464        .unwrap();
1465
1466        // interval_day_time
1467        let v1 = IntervalDayTime::new(0, 0);
1468        let v2 = IntervalDayTime::new(2, 0);
1469        let v3 = IntervalDayTime::new(-15, 0);
1470        test_take_primitive_arrays::<IntervalDayTimeType>(
1471            vec![Some(v1), None, Some(v2), Some(v3), None],
1472            &index,
1473            None,
1474            vec![Some(v3), None, None, Some(v3), Some(v2)],
1475        )
1476        .unwrap();
1477
1478        // interval_month_day_nano
1479        let v1 = IntervalMonthDayNano::new(0, 0, 0);
1480        let v2 = IntervalMonthDayNano::new(2, 0, 0);
1481        let v3 = IntervalMonthDayNano::new(-15, 0, 0);
1482        test_take_primitive_arrays::<IntervalMonthDayNanoType>(
1483            vec![Some(v1), None, Some(v2), Some(v3), None],
1484            &index,
1485            None,
1486            vec![Some(v3), None, None, Some(v3), Some(v2)],
1487        )
1488        .unwrap();
1489
1490        // duration_second
1491        test_take_primitive_arrays::<DurationSecondType>(
1492            vec![Some(0), None, Some(2), Some(-15), None],
1493            &index,
1494            None,
1495            vec![Some(-15), None, None, Some(-15), Some(2)],
1496        )
1497        .unwrap();
1498
1499        // duration_millisecond
1500        test_take_primitive_arrays::<DurationMillisecondType>(
1501            vec![Some(0), None, Some(2), Some(-15), None],
1502            &index,
1503            None,
1504            vec![Some(-15), None, None, Some(-15), Some(2)],
1505        )
1506        .unwrap();
1507
1508        // duration_microsecond
1509        test_take_primitive_arrays::<DurationMicrosecondType>(
1510            vec![Some(0), None, Some(2), Some(-15), None],
1511            &index,
1512            None,
1513            vec![Some(-15), None, None, Some(-15), Some(2)],
1514        )
1515        .unwrap();
1516
1517        // duration_nanosecond
1518        test_take_primitive_arrays::<DurationNanosecondType>(
1519            vec![Some(0), None, Some(2), Some(-15), None],
1520            &index,
1521            None,
1522            vec![Some(-15), None, None, Some(-15), Some(2)],
1523        )
1524        .unwrap();
1525
1526        // float32
1527        test_take_primitive_arrays::<Float32Type>(
1528            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1529            &index,
1530            None,
1531            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1532        )
1533        .unwrap();
1534
1535        // float64
1536        test_take_primitive_arrays::<Float64Type>(
1537            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1538            &index,
1539            None,
1540            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1541        )
1542        .unwrap();
1543    }
1544
1545    #[test]
1546    fn test_take_preserve_timezone() {
1547        let index = Int64Array::from(vec![Some(0), None]);
1548
1549        let input = TimestampNanosecondArray::from(vec![
1550            1_639_715_368_000_000_000,
1551            1_639_715_368_000_000_000,
1552        ])
1553        .with_timezone("UTC".to_string());
1554        let result = take(&input, &index, None).unwrap();
1555        match result.data_type() {
1556            DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
1557                assert_eq!(tz.clone(), Some("UTC".into()))
1558            }
1559            _ => panic!(),
1560        }
1561    }
1562
1563    #[test]
1564    fn test_take_impl_primitive_with_int64_indices() {
1565        let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1566
1567        // int16
1568        test_take_impl_primitive_arrays::<Int16Type, Int64Type>(
1569            vec![Some(0), None, Some(2), Some(3), None],
1570            &index,
1571            None,
1572            vec![Some(3), None, None, Some(3), Some(2)],
1573        );
1574
1575        // int64
1576        test_take_impl_primitive_arrays::<Int64Type, Int64Type>(
1577            vec![Some(0), None, Some(2), Some(-15), None],
1578            &index,
1579            None,
1580            vec![Some(-15), None, None, Some(-15), Some(2)],
1581        );
1582
1583        // uint64
1584        test_take_impl_primitive_arrays::<UInt64Type, Int64Type>(
1585            vec![Some(0), None, Some(2), Some(3), None],
1586            &index,
1587            None,
1588            vec![Some(3), None, None, Some(3), Some(2)],
1589        );
1590
1591        // duration_millisecond
1592        test_take_impl_primitive_arrays::<DurationMillisecondType, Int64Type>(
1593            vec![Some(0), None, Some(2), Some(-15), None],
1594            &index,
1595            None,
1596            vec![Some(-15), None, None, Some(-15), Some(2)],
1597        );
1598
1599        // float32
1600        test_take_impl_primitive_arrays::<Float32Type, Int64Type>(
1601            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1602            &index,
1603            None,
1604            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1605        );
1606    }
1607
1608    #[test]
1609    fn test_take_impl_primitive_with_uint8_indices() {
1610        let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1611
1612        // int16
1613        test_take_impl_primitive_arrays::<Int16Type, UInt8Type>(
1614            vec![Some(0), None, Some(2), Some(3), None],
1615            &index,
1616            None,
1617            vec![Some(3), None, None, Some(3), Some(2)],
1618        );
1619
1620        // duration_millisecond
1621        test_take_impl_primitive_arrays::<DurationMillisecondType, UInt8Type>(
1622            vec![Some(0), None, Some(2), Some(-15), None],
1623            &index,
1624            None,
1625            vec![Some(-15), None, None, Some(-15), Some(2)],
1626        );
1627
1628        // float32
1629        test_take_impl_primitive_arrays::<Float32Type, UInt8Type>(
1630            vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
1631            &index,
1632            None,
1633            vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)],
1634        );
1635    }
1636
1637    #[test]
1638    fn test_take_bool() {
1639        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
1640        // boolean
1641        test_take_boolean_arrays(
1642            vec![Some(false), None, Some(true), Some(false), None],
1643            &index,
1644            None,
1645            vec![Some(false), None, None, Some(false), Some(true)],
1646        );
1647    }
1648
1649    #[test]
1650    fn test_take_bool_nullable_index() {
1651        // indices where the masked invalid elements would be out of bounds
1652        let index_data = ArrayData::try_new(
1653            DataType::UInt32,
1654            6,
1655            Some(Buffer::from_iter(vec![
1656                false, true, false, true, false, true,
1657            ])),
1658            0,
1659            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1660            vec![],
1661        )
1662        .unwrap();
1663        let index = UInt32Array::from(index_data);
1664        test_take_boolean_arrays(
1665            vec![Some(true), None, Some(false)],
1666            &index,
1667            None,
1668            vec![None, Some(true), None, None, None, Some(false)],
1669        );
1670    }
1671
1672    #[test]
1673    fn test_take_bool_nullable_index_nonnull_values() {
1674        // indices where the masked invalid elements would be out of bounds
1675        let index_data = ArrayData::try_new(
1676            DataType::UInt32,
1677            6,
1678            Some(Buffer::from_iter(vec![
1679                false, true, false, true, false, true,
1680            ])),
1681            0,
1682            vec![Buffer::from_iter(vec![99, 0, 999, 1, 9999, 2])],
1683            vec![],
1684        )
1685        .unwrap();
1686        let index = UInt32Array::from(index_data);
1687        test_take_boolean_arrays(
1688            vec![Some(true), Some(true), Some(false)],
1689            &index,
1690            None,
1691            vec![None, Some(true), None, Some(true), None, Some(false)],
1692        );
1693    }
1694
1695    #[test]
1696    fn test_take_bool_with_offset() {
1697        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
1698        let index = index.slice(2, 4);
1699        let index = index
1700            .as_any()
1701            .downcast_ref::<PrimitiveArray<UInt32Type>>()
1702            .unwrap();
1703
1704        // boolean
1705        test_take_boolean_arrays(
1706            vec![Some(false), None, Some(true), Some(false), None],
1707            index,
1708            None,
1709            vec![None, Some(false), Some(true), None],
1710        );
1711    }
1712
1713    fn _test_take_string<'a, K>()
1714    where
1715        K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1716    {
1717        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1718
1719        let array = K::from(vec![
1720            Some("one"),
1721            None,
1722            Some("three"),
1723            Some("four"),
1724            Some("five"),
1725        ]);
1726        let actual = take(&array, &index, None).unwrap();
1727        assert_eq!(actual.len(), index.len());
1728
1729        let actual = actual.as_any().downcast_ref::<K>().unwrap();
1730
1731        let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1732
1733        assert_eq!(actual, &expected);
1734    }
1735
1736    #[test]
1737    fn test_take_string() {
1738        _test_take_string::<StringArray>()
1739    }
1740
1741    #[test]
1742    fn test_take_large_string() {
1743        _test_take_string::<LargeStringArray>()
1744    }
1745
1746    #[test]
1747    fn test_take_slice_string() {
1748        let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
1749        let indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
1750        let indices_slice = indices.slice(1, 4);
1751        let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
1752        let result = take(&strings, &indices_slice, None).unwrap();
1753        assert_eq!(result.as_ref(), &expected);
1754    }
1755
1756    /// Take from a *sliced* byte array, i.e. one whose value offsets do not
1757    /// start at zero. This exercises copying byte data out of an array with a
1758    /// non-zero base offset for both the no-null fast path and the nullable
1759    /// path (null indices and selected null values).
1760    #[test]
1761    fn test_take_bytes_sliced_values() {
1762        let values = StringArray::from(vec![
1763            Some("aaa"),
1764            Some("bbb"),
1765            None,
1766            Some("ccccc"),
1767            Some("dd"),
1768            None,
1769            Some("eeee"),
1770        ]);
1771        // Slice so the underlying value offsets no longer start at 0:
1772        // sliced == [None, "ccccc", "dd", None, "eeee"]
1773        let sliced = values.slice(2, 5);
1774
1775        // Fast path: every output slot is valid (no null indices, no null
1776        // values selected).
1777        let indices = Int32Array::from(vec![1, 2, 4, 1]);
1778        let result = take(&sliced, &indices, None).unwrap();
1779        let expected =
1780            StringArray::from(vec![Some("ccccc"), Some("dd"), Some("eeee"), Some("ccccc")]);
1781        assert_eq!(result.as_string::<i32>(), &expected);
1782
1783        // Nullable path: a null index (position 1) and selected null values
1784        // (sliced indices 0 and 3 are null).
1785        let indices = Int32Array::from(vec![Some(1), None, Some(0), Some(4), Some(3)]);
1786        let result = take(&sliced, &indices, None).unwrap();
1787        let expected = StringArray::from(vec![Some("ccccc"), None, None, Some("eeee"), None]);
1788        assert_eq!(result.as_string::<i32>(), &expected);
1789    }
1790
1791    fn _test_byte_view<T>()
1792    where
1793        T: ByteViewType,
1794        str: AsRef<T::Native>,
1795        T::Native: PartialEq,
1796    {
1797        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1798        let array = {
1799            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1800            let mut builder = GenericByteViewBuilder::<T>::new();
1801            builder.append_value("hello");
1802            builder.append_value("world");
1803            builder.append_null();
1804            builder.append_value("large payload over 12 bytes");
1805            builder.append_value("lulu");
1806            builder.finish()
1807        };
1808
1809        let actual = take(&array, &index, None).unwrap();
1810
1811        assert_eq!(actual.len(), index.len());
1812
1813        let expected = {
1814            // ["large payload over 12 bytes", null, "world", "large payload over 12 bytes", "lulu", null]
1815            let mut builder = GenericByteViewBuilder::<T>::new();
1816            builder.append_value("large payload over 12 bytes");
1817            builder.append_null();
1818            builder.append_value("world");
1819            builder.append_value("large payload over 12 bytes");
1820            builder.append_value("lulu");
1821            builder.append_null();
1822            builder.finish()
1823        };
1824
1825        assert_eq!(actual.as_ref(), &expected);
1826    }
1827
1828    #[test]
1829    fn test_take_string_view() {
1830        _test_byte_view::<StringViewType>()
1831    }
1832
1833    #[test]
1834    fn test_take_binary_view() {
1835        _test_byte_view::<BinaryViewType>()
1836    }
1837
1838    macro_rules! test_take_list {
1839        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1840            // Construct a value array, [[0,0,0], [-1,-2,-1], [], [2,3]]
1841            let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
1842            // Construct offsets
1843            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1844            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1845            // Construct a list array from the above two
1846            let list_data_type =
1847                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
1848            let list_data = ArrayData::builder(list_data_type.clone())
1849                .len(4)
1850                .add_buffer(value_offsets)
1851                .add_child_data(value_data)
1852                .build()
1853                .unwrap();
1854            let list_array = $list_array_type::from(list_data);
1855
1856            // index returns: [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1857            let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
1858
1859            let a = take(&list_array, &index, None).unwrap();
1860            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1861
1862            // construct a value array with expected results:
1863            // [[2,3], null, [-1,-2,-1], [], [0,0,0]]
1864            let expected_data = Int32Array::from(vec![
1865                Some(2),
1866                Some(3),
1867                Some(-1),
1868                Some(-2),
1869                Some(-1),
1870                Some(0),
1871                Some(0),
1872                Some(0),
1873            ])
1874            .into_data();
1875            // construct offsets
1876            let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
1877            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1878            // construct list array from the two
1879            let expected_list_data = ArrayData::builder(list_data_type)
1880                .len(5)
1881                // null buffer remains the same as only the indices have nulls
1882                .nulls(index.nulls().cloned())
1883                .add_buffer(expected_offsets)
1884                .add_child_data(expected_data)
1885                .build()
1886                .unwrap();
1887            let expected_list_array = $list_array_type::from(expected_list_data);
1888
1889            assert_eq!(a, &expected_list_array);
1890        }};
1891    }
1892
1893    macro_rules! test_take_list_with_value_nulls {
1894        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1895            // Construct a value array, [[0,null,0], [-1,-2,3], [null], [5,null]]
1896            let value_data = Int32Array::from(vec![
1897                Some(0),
1898                None,
1899                Some(0),
1900                Some(-1),
1901                Some(-2),
1902                Some(3),
1903                None,
1904                Some(5),
1905                None,
1906            ])
1907            .into_data();
1908            // Construct offsets
1909            let value_offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
1910            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1911            // Construct a list array from the above two
1912            let list_data_type =
1913                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1914            let list_data = ArrayData::builder(list_data_type.clone())
1915                .len(4)
1916                .add_buffer(value_offsets)
1917                .null_bit_buffer(Some(Buffer::from([0b11111111])))
1918                .add_child_data(value_data)
1919                .build()
1920                .unwrap();
1921            let list_array = $list_array_type::from(list_data);
1922
1923            // index returns: [[null], null, [-1,-2,3], [2,null], [0,null,0]]
1924            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1925
1926            let a = take(&list_array, &index, None).unwrap();
1927            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1928
1929            // construct a value array with expected results:
1930            // [[null], null, [-1,-2,3], [5,null], [0,null,0]]
1931            let expected_data = Int32Array::from(vec![
1932                None,
1933                Some(-1),
1934                Some(-2),
1935                Some(3),
1936                Some(5),
1937                None,
1938                Some(0),
1939                None,
1940                Some(0),
1941            ])
1942            .into_data();
1943            // construct offsets
1944            let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
1945            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
1946            // construct list array from the two
1947            let expected_list_data = ArrayData::builder(list_data_type)
1948                .len(5)
1949                // null buffer remains the same as only the indices have nulls
1950                .nulls(index.nulls().cloned())
1951                .add_buffer(expected_offsets)
1952                .add_child_data(expected_data)
1953                .build()
1954                .unwrap();
1955            let expected_list_array = $list_array_type::from(expected_list_data);
1956
1957            assert_eq!(a, &expected_list_array);
1958        }};
1959    }
1960
1961    macro_rules! test_take_list_with_nulls {
1962        ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
1963            // Construct a value array, [[0,null,0], [-1,-2,3], null, [5,null]]
1964            let value_data = Int32Array::from(vec![
1965                Some(0),
1966                None,
1967                Some(0),
1968                Some(-1),
1969                Some(-2),
1970                Some(3),
1971                Some(5),
1972                None,
1973            ])
1974            .into_data();
1975            // Construct offsets
1976            let value_offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
1977            let value_offsets = Buffer::from_slice_ref(&value_offsets);
1978            // Construct a list array from the above two
1979            let list_data_type =
1980                DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
1981            let list_data = ArrayData::builder(list_data_type.clone())
1982                .len(4)
1983                .add_buffer(value_offsets)
1984                .null_bit_buffer(Some(Buffer::from([0b11111011])))
1985                .add_child_data(value_data)
1986                .build()
1987                .unwrap();
1988            let list_array = $list_array_type::from(list_data);
1989
1990            // index returns: [null, null, [-1,-2,3], [5,null], [0,null,0]]
1991            let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
1992
1993            let a = take(&list_array, &index, None).unwrap();
1994            let a: &$list_array_type = a.as_any().downcast_ref::<$list_array_type>().unwrap();
1995
1996            // construct a value array with expected results:
1997            // [null, null, [-1,-2,3], [5,null], [0,null,0]]
1998            let expected_data = Int32Array::from(vec![
1999                Some(-1),
2000                Some(-2),
2001                Some(3),
2002                Some(5),
2003                None,
2004                Some(0),
2005                None,
2006                Some(0),
2007            ])
2008            .into_data();
2009            // construct offsets
2010            let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
2011            let expected_offsets = Buffer::from_slice_ref(&expected_offsets);
2012            // construct list array from the two
2013            let mut null_bits: [u8; 1] = [0; 1];
2014            bit_util::set_bit(&mut null_bits, 2);
2015            bit_util::set_bit(&mut null_bits, 3);
2016            bit_util::set_bit(&mut null_bits, 4);
2017            let expected_list_data = ArrayData::builder(list_data_type)
2018                .len(5)
2019                // null buffer must be recalculated as both values and indices have nulls
2020                .null_bit_buffer(Some(Buffer::from(null_bits)))
2021                .add_buffer(expected_offsets)
2022                .add_child_data(expected_data)
2023                .build()
2024                .unwrap();
2025            let expected_list_array = $list_array_type::from(expected_list_data);
2026
2027            assert_eq!(a, &expected_list_array);
2028        }};
2029    }
2030
2031    fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
2032        values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
2033        take_indices: Vec<Option<usize>>,
2034        expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
2035        mapper: F,
2036    ) where
2037        F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
2038    {
2039        let mut list_view_array =
2040            GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
2041
2042        for value in values {
2043            list_view_array.append_option(value);
2044        }
2045        let list_view_array = list_view_array.finish();
2046        let list_view_array = mapper(list_view_array);
2047
2048        let mut indices = UInt64Builder::new();
2049        for idx in take_indices {
2050            indices.append_option(idx.map(|i| i.to_u64().unwrap()));
2051        }
2052        let indices = indices.finish();
2053
2054        let taken = take(&list_view_array, &indices, None)
2055            .unwrap()
2056            .as_list_view()
2057            .clone();
2058
2059        let mut expected_array =
2060            GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
2061        for value in expected {
2062            expected_array.append_option(value);
2063        }
2064        let expected_array = expected_array.finish();
2065
2066        assert_eq!(taken, expected_array);
2067    }
2068
2069    macro_rules! list_view_test_case {
2070        (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
2071            test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, |x| x);
2072            test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, |x| x);
2073        }};
2074        (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
2075            test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
2076            test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
2077        }};
2078    }
2079
2080    fn do_take_fixed_size_list_test<T>(
2081        length: <Int32Type as ArrowPrimitiveType>::Native,
2082        input_data: Vec<Option<Vec<Option<T::Native>>>>,
2083        indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
2084        expected_data: Vec<Option<Vec<Option<T::Native>>>>,
2085    ) where
2086        T: ArrowPrimitiveType,
2087        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2088    {
2089        let indices = UInt32Array::from(indices);
2090
2091        let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
2092
2093        let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
2094
2095        let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
2096
2097        assert_eq!(&output, &expected)
2098    }
2099
2100    #[test]
2101    fn test_take_list() {
2102        test_take_list!(i32, List, ListArray);
2103    }
2104
2105    #[test]
2106    fn test_take_large_list() {
2107        test_take_list!(i64, LargeList, LargeListArray);
2108    }
2109
2110    #[test]
2111    fn test_take_list_with_value_nulls() {
2112        test_take_list_with_value_nulls!(i32, List, ListArray);
2113    }
2114
2115    #[test]
2116    fn test_take_large_list_with_value_nulls() {
2117        test_take_list_with_value_nulls!(i64, LargeList, LargeListArray);
2118    }
2119
2120    #[test]
2121    fn test_test_take_list_with_nulls() {
2122        test_take_list_with_nulls!(i32, List, ListArray);
2123    }
2124
2125    #[test]
2126    fn test_test_take_large_list_with_nulls() {
2127        test_take_list_with_nulls!(i64, LargeList, LargeListArray);
2128    }
2129
2130    #[test]
2131    fn test_test_take_list_view_reversed() {
2132        // Take reversed indices
2133        list_view_test_case! {
2134            values: vec![
2135                Some(vec![Some(1), None, Some(3)]),
2136                None,
2137                Some(vec![Some(7), Some(8), None]),
2138            ],
2139            indices: vec![Some(2), Some(1), Some(0)],
2140            expected: vec![
2141                Some(vec![Some(7), Some(8), None]),
2142                None,
2143                Some(vec![Some(1), None, Some(3)]),
2144            ]
2145        }
2146    }
2147
2148    #[test]
2149    fn test_take_list_view_null_indices() {
2150        // Take with null indices
2151        list_view_test_case! {
2152            values: vec![
2153                Some(vec![Some(1), None, Some(3)]),
2154                None,
2155                Some(vec![Some(7), Some(8), None]),
2156            ],
2157            indices: vec![None, Some(0), None],
2158            expected: vec![None, Some(vec![Some(1), None, Some(3)]), None]
2159        }
2160    }
2161
2162    #[test]
2163    fn test_take_list_view_null_values() {
2164        // Take at null values
2165        list_view_test_case! {
2166            values: vec![
2167                Some(vec![Some(1), None, Some(3)]),
2168                None,
2169                Some(vec![Some(7), Some(8), None]),
2170            ],
2171            indices: vec![Some(1), Some(1), Some(1), None, None],
2172            expected: vec![None; 5]
2173        }
2174    }
2175
2176    #[test]
2177    fn test_take_list_view_sliced() {
2178        // Take null indices/values, with slicing.
2179        list_view_test_case! {
2180            values: vec![
2181                Some(vec![Some(1)]),
2182                None,
2183                None,
2184                Some(vec![Some(2), Some(3)]),
2185                Some(vec![Some(4), Some(5)]),
2186                None,
2187            ],
2188            transform: |l| l.slice(2, 4),
2189            indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
2190            expected: vec![
2191                None, None, None, Some(vec![Some(2), Some(3)]), Some(vec![Some(4), Some(5)])
2192            ]
2193        }
2194    }
2195
2196    #[test]
2197    fn test_take_fixed_size_list() {
2198        do_take_fixed_size_list_test::<Int32Type>(
2199            3,
2200            vec![
2201                Some(vec![None, Some(1), Some(2)]),
2202                Some(vec![Some(3), Some(4), None]),
2203                Some(vec![Some(6), Some(7), Some(8)]),
2204            ],
2205            vec![2, 1, 0],
2206            vec![
2207                Some(vec![Some(6), Some(7), Some(8)]),
2208                Some(vec![Some(3), Some(4), None]),
2209                Some(vec![None, Some(1), Some(2)]),
2210            ],
2211        );
2212
2213        do_take_fixed_size_list_test::<UInt8Type>(
2214            1,
2215            vec![
2216                Some(vec![Some(1)]),
2217                Some(vec![Some(2)]),
2218                Some(vec![Some(3)]),
2219                Some(vec![Some(4)]),
2220                Some(vec![Some(5)]),
2221                Some(vec![Some(6)]),
2222                Some(vec![Some(7)]),
2223                Some(vec![Some(8)]),
2224            ],
2225            vec![2, 7, 0],
2226            vec![
2227                Some(vec![Some(3)]),
2228                Some(vec![Some(8)]),
2229                Some(vec![Some(1)]),
2230            ],
2231        );
2232
2233        do_take_fixed_size_list_test::<UInt64Type>(
2234            3,
2235            vec![
2236                Some(vec![Some(10), Some(11), Some(12)]),
2237                Some(vec![Some(13), Some(14), Some(15)]),
2238                None,
2239                Some(vec![Some(16), Some(17), Some(18)]),
2240            ],
2241            vec![3, 2, 1, 2, 0],
2242            vec![
2243                Some(vec![Some(16), Some(17), Some(18)]),
2244                None,
2245                Some(vec![Some(13), Some(14), Some(15)]),
2246                None,
2247                Some(vec![Some(10), Some(11), Some(12)]),
2248            ],
2249        );
2250    }
2251
2252    #[test]
2253    fn test_take_fixed_size_binary_with_nulls_indices() {
2254        let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
2255            [
2256                Some(vec![0x01, 0x01, 0x01, 0x01]),
2257                Some(vec![0x02, 0x02, 0x02, 0x02]),
2258                Some(vec![0x03, 0x03, 0x03, 0x03]),
2259                Some(vec![0x04, 0x04, 0x04, 0x04]),
2260            ]
2261            .into_iter(),
2262            4,
2263        )
2264        .unwrap();
2265
2266        // The two middle indices are null -> Should be null in the output.
2267        let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
2268
2269        let result = take_fixed_size_binary(&fsb, &indices, 4).unwrap();
2270        assert_eq!(result.len(), 4);
2271        assert_eq!(result.null_count(), 2);
2272        assert_eq!(
2273            result.nulls().unwrap().iter().collect::<Vec<_>>(),
2274            vec![true, false, false, true]
2275        );
2276    }
2277
2278    /// The [`take_fixed_size_binary`] kernel contains optimizations that provide a faster
2279    /// implementation for commonly-used value lengths. This test uses a value length that is not
2280    /// optimized to test both code paths.
2281    #[test]
2282    fn test_take_fixed_size_binary_with_nulls_indices_not_optimized_length() {
2283        let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
2284            [
2285                Some(vec![0x01, 0x01, 0x01, 0x01, 0x01]),
2286                Some(vec![0x02, 0x02, 0x02, 0x02, 0x01]),
2287                Some(vec![0x03, 0x03, 0x03, 0x03, 0x01]),
2288                Some(vec![0x04, 0x04, 0x04, 0x04, 0x01]),
2289            ]
2290            .into_iter(),
2291            5,
2292        )
2293        .unwrap();
2294
2295        // The two middle indices are null -> Should be null in the output.
2296        let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);
2297
2298        let result = take_fixed_size_binary(&fsb, &indices, 5).unwrap();
2299        assert_eq!(result.len(), 4);
2300        assert_eq!(result.null_count(), 2);
2301        assert_eq!(
2302            result.nulls().unwrap().iter().collect::<Vec<_>>(),
2303            vec![true, false, false, true]
2304        );
2305    }
2306
2307    #[test]
2308    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2309    fn test_take_list_out_of_bounds() {
2310        // Construct a value array, [[0,0,0], [-1,-2,-1], [2,3]]
2311        let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
2312        // Construct offsets
2313        let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]);
2314        // Construct a list array from the above two
2315        let list_data_type =
2316            DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
2317        let list_data = ArrayData::builder(list_data_type)
2318            .len(3)
2319            .add_buffer(value_offsets)
2320            .add_child_data(value_data)
2321            .build()
2322            .unwrap();
2323        let list_array = ListArray::from(list_data);
2324
2325        let index = UInt32Array::from(vec![1000]);
2326
2327        // A panic is expected here since we have not supplied the check_bounds
2328        // option.
2329        take(&list_array, &index, None).unwrap();
2330    }
2331
2332    #[test]
2333    fn test_take_map() {
2334        let values = Int32Array::from(vec![1, 2, 3, 4]);
2335        let array =
2336            MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
2337                .unwrap();
2338
2339        let index = UInt32Array::from(vec![0]);
2340
2341        let result = take(&array, &index, None).unwrap();
2342        let expected: ArrayRef = Arc::new(
2343            MapArray::new_from_strings(
2344                vec!["a", "b", "c"].into_iter(),
2345                &values.slice(0, 3),
2346                &[0, 3],
2347            )
2348            .unwrap(),
2349        );
2350        assert_eq!(&expected, &result);
2351    }
2352
2353    #[test]
2354    fn test_take_struct() {
2355        let array = create_test_struct(vec![
2356            Some((Some(true), Some(42))),
2357            Some((Some(false), Some(28))),
2358            Some((Some(false), Some(19))),
2359            Some((Some(true), Some(31))),
2360            None,
2361        ]);
2362
2363        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
2364        let actual = take(&array, &index, None).unwrap();
2365        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2366        assert_eq!(index.len(), actual.len());
2367        assert_eq!(1, actual.null_count());
2368
2369        let expected = create_test_struct(vec![
2370            Some((Some(true), Some(42))),
2371            Some((Some(true), Some(31))),
2372            Some((Some(false), Some(28))),
2373            Some((Some(true), Some(42))),
2374            Some((Some(false), Some(19))),
2375            None,
2376        ]);
2377
2378        assert_eq!(&expected, actual);
2379
2380        let nulls = NullBuffer::from(&[false, true, false, true, false, true]);
2381        let empty_struct_arr = StructArray::new_empty_fields(6, Some(nulls));
2382        let index = UInt32Array::from(vec![0, 2, 1, 4]);
2383        let actual = take(&empty_struct_arr, &index, None).unwrap();
2384
2385        let expected_nulls = NullBuffer::from(&[false, false, true, false]);
2386        let expected_struct_arr = StructArray::new_empty_fields(4, Some(expected_nulls));
2387        assert_eq!(&expected_struct_arr, actual.as_struct());
2388    }
2389
2390    #[test]
2391    fn test_take_struct_with_null_indices() {
2392        let array = create_test_struct(vec![
2393            Some((Some(true), Some(42))),
2394            Some((Some(false), Some(28))),
2395            Some((Some(false), Some(19))),
2396            Some((Some(true), Some(31))),
2397            None,
2398        ]);
2399
2400        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
2401        let actual = take(&array, &index, None).unwrap();
2402        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
2403        assert_eq!(index.len(), actual.len());
2404        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
2405
2406        let expected = create_test_struct(vec![
2407            None,
2408            Some((Some(true), Some(31))),
2409            Some((Some(false), Some(28))),
2410            None,
2411            Some((Some(true), Some(42))),
2412            None,
2413        ]);
2414
2415        assert_eq!(&expected, actual);
2416    }
2417
2418    #[test]
2419    fn test_take_out_of_bounds() {
2420        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
2421        let take_opt = TakeOptions { check_bounds: true };
2422
2423        // int64
2424        let result = test_take_primitive_arrays::<Int64Type>(
2425            vec![Some(0), None, Some(2), Some(3), None],
2426            &index,
2427            Some(take_opt),
2428            vec![None],
2429        );
2430        assert!(result.is_err());
2431    }
2432
2433    #[test]
2434    #[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
2435    fn test_take_out_of_bounds_panic() {
2436        let index = UInt32Array::from(vec![Some(1000)]);
2437
2438        test_take_primitive_arrays::<Int64Type>(
2439            vec![Some(0), Some(1), Some(2), Some(3)],
2440            &index,
2441            None,
2442            vec![None],
2443        )
2444        .unwrap();
2445    }
2446
2447    #[test]
2448    fn test_null_array_smaller_than_indices() {
2449        let values = NullArray::new(2);
2450        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2451
2452        let result = take(&values, &indices, None).unwrap();
2453        let expected: ArrayRef = Arc::new(NullArray::new(3));
2454        assert_eq!(&result, &expected);
2455    }
2456
2457    #[test]
2458    fn test_null_array_larger_than_indices() {
2459        let values = NullArray::new(5);
2460        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2461
2462        let result = take(&values, &indices, None).unwrap();
2463        let expected: ArrayRef = Arc::new(NullArray::new(3));
2464        assert_eq!(&result, &expected);
2465    }
2466
2467    #[test]
2468    fn test_null_array_indices_out_of_bounds() {
2469        let values = NullArray::new(5);
2470        let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
2471
2472        let result = take(&values, &indices, Some(TakeOptions { check_bounds: true }));
2473        assert_eq!(
2474            result.unwrap_err().to_string(),
2475            "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
2476        );
2477    }
2478
2479    #[test]
2480    fn test_take_dict() {
2481        let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
2482
2483        dict_builder.append("foo").unwrap();
2484        dict_builder.append("bar").unwrap();
2485        dict_builder.append("").unwrap();
2486        dict_builder.append_null();
2487        dict_builder.append("foo").unwrap();
2488        dict_builder.append("bar").unwrap();
2489        dict_builder.append("bar").unwrap();
2490        dict_builder.append("foo").unwrap();
2491
2492        let array = dict_builder.finish();
2493        let dict_values = array.values().clone();
2494        let dict_values = dict_values.as_any().downcast_ref::<StringArray>().unwrap();
2495
2496        let indices = UInt32Array::from(vec![
2497            Some(0), // first "foo"
2498            Some(7), // last "foo"
2499            None,    // null index should return null
2500            Some(5), // second "bar"
2501            Some(6), // another "bar"
2502            Some(2), // empty string
2503            Some(3), // input is null at this index
2504        ]);
2505
2506        let result = take(&array, &indices, None).unwrap();
2507        let result = result
2508            .as_any()
2509            .downcast_ref::<DictionaryArray<Int16Type>>()
2510            .unwrap();
2511
2512        let result_values: StringArray = result.values().to_data().into();
2513
2514        // dictionary values should stay the same
2515        let expected_values = StringArray::from(vec!["foo", "bar", ""]);
2516        assert_eq!(&expected_values, dict_values);
2517        assert_eq!(&expected_values, &result_values);
2518
2519        let expected_keys = Int16Array::from(vec![
2520            Some(0),
2521            Some(0),
2522            None,
2523            Some(1),
2524            Some(1),
2525            Some(2),
2526            None,
2527        ]);
2528        assert_eq!(result.keys(), &expected_keys);
2529    }
2530
2531    fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2532    where
2533        S: OffsetSizeTrait + 'static,
2534        T: ArrowPrimitiveType,
2535        PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2536    {
2537        GenericListArray::from_iter_primitive::<T, _, _>(
2538            data.iter()
2539                .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2540        )
2541    }
2542
2543    fn test_take_sliced_list_generic<S: OffsetSizeTrait + 'static>() {
2544        let list = build_generic_list::<S, Int32Type>(vec![
2545            Some(vec![0, 1]),
2546            Some(vec![2, 3, 4]),
2547            None,
2548            Some(vec![]),
2549            Some(vec![5, 6]),
2550            Some(vec![7]),
2551        ]);
2552        let sliced = list.slice(1, 4);
2553        let indices = UInt32Array::from(vec![Some(3), Some(0), None, Some(2), Some(1)]);
2554
2555        let taken = take(&sliced, &indices, None).unwrap();
2556        let taken = taken.as_list::<S>();
2557
2558        let expected = build_generic_list::<S, Int32Type>(vec![
2559            Some(vec![5, 6]),
2560            Some(vec![2, 3, 4]),
2561            None,
2562            Some(vec![]),
2563            None,
2564        ]);
2565
2566        assert_eq!(taken, &expected);
2567    }
2568
2569    fn test_take_sliced_list_with_value_nulls_generic<S: OffsetSizeTrait + 'static>() {
2570        let list = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
2571            Some(vec![Some(10)]),
2572            Some(vec![None, Some(1)]),
2573            None,
2574            Some(vec![Some(2), None]),
2575            Some(vec![]),
2576            Some(vec![Some(3)]),
2577        ]);
2578        let sliced = list.slice(1, 4);
2579        let indices = UInt32Array::from(vec![Some(2), Some(0), None, Some(3), Some(1)]);
2580
2581        let taken = take(&sliced, &indices, None).unwrap();
2582        let taken = taken.as_list::<S>();
2583
2584        let expected = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
2585            Some(vec![Some(2), None]),
2586            Some(vec![None, Some(1)]),
2587            None,
2588            Some(vec![]),
2589            None,
2590        ]);
2591
2592        assert_eq!(taken, &expected);
2593    }
2594
2595    #[test]
2596    fn test_take_sliced_list() {
2597        test_take_sliced_list_generic::<i32>();
2598    }
2599
2600    #[test]
2601    fn test_take_sliced_large_list() {
2602        test_take_sliced_list_generic::<i64>();
2603    }
2604
2605    #[test]
2606    fn test_take_sliced_list_with_value_nulls() {
2607        test_take_sliced_list_with_value_nulls_generic::<i32>();
2608    }
2609
2610    #[test]
2611    fn test_take_sliced_large_list_with_value_nulls() {
2612        test_take_sliced_list_with_value_nulls_generic::<i64>();
2613    }
2614
2615    #[test]
2616    fn test_take_runs() {
2617        let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
2618
2619        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2620        builder.extend(logical_array.into_iter().map(Some));
2621        let run_array = builder.finish();
2622
2623        let take_indices: PrimitiveArray<Int32Type> =
2624            vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
2625
2626        let take_out = take_run(&run_array, &take_indices).unwrap();
2627
2628        assert_eq!(take_out.len(), 7);
2629        assert_eq!(take_out.run_ends().len(), 7);
2630        assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
2631
2632        let take_out_values = take_out.values().as_primitive::<Int32Type>();
2633        assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
2634    }
2635
2636    #[test]
2637    fn test_take_runs_sliced() {
2638        let logical_array: Vec<i32> = vec![1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6];
2639
2640        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2641        builder.extend(logical_array.into_iter().map(Some));
2642        let run_array = builder.finish();
2643
2644        let run_array = run_array.slice(4, 6); // [3, 3, 3, 4, 4, 5]
2645
2646        let take_indices: PrimitiveArray<Int32Type> = vec![0, 5, 5, 1, 4].into_iter().collect();
2647
2648        let result = take_run(&run_array, &take_indices).unwrap();
2649        let result = result.downcast::<Int32Array>().unwrap();
2650
2651        let expected = vec![3, 5, 5, 3, 4];
2652        let actual = result.into_iter().flatten().collect::<Vec<_>>();
2653
2654        assert_eq!(expected, actual);
2655    }
2656
2657    #[test]
2658    fn test_take_value_index_from_fixed_list() {
2659        let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
2660            vec![
2661                Some(vec![Some(1), Some(2), None]),
2662                Some(vec![Some(4), None, Some(6)]),
2663                None,
2664                Some(vec![None, Some(8), Some(9)]),
2665            ],
2666            3,
2667        );
2668
2669        let indices = UInt32Array::from(vec![2, 1, 0]);
2670        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2671
2672        assert_eq!(indexed, UInt32Array::from(vec![6, 7, 8, 3, 4, 5, 0, 1, 2]));
2673
2674        let indices = UInt32Array::from(vec![3, 2, 1, 2, 0]);
2675        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
2676
2677        assert_eq!(
2678            indexed,
2679            UInt32Array::from(vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2])
2680        );
2681    }
2682
2683    #[test]
2684    fn test_take_null_indices() {
2685        // Build indices with values that are out of bounds, but masked by null mask
2686        let indices = Int32Array::new(
2687            vec![1, 2, 400, 400].into(),
2688            Some(NullBuffer::from(vec![true, true, false, false])),
2689        );
2690        let values = Int32Array::from(vec![1, 23, 4, 5]);
2691        let r = take(&values, &indices, None).unwrap();
2692        let values = r
2693            .as_primitive::<Int32Type>()
2694            .into_iter()
2695            .collect::<Vec<_>>();
2696        assert_eq!(&values, &[Some(23), Some(4), None, None])
2697    }
2698
2699    #[test]
2700    fn test_take_fixed_size_list_null_indices() {
2701        let indices = Int32Array::from_iter([Some(0), None]);
2702        let values = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
2703        let arr_field = Arc::new(Field::new_list_field(values.data_type().clone(), true));
2704        let values = FixedSizeListArray::try_new(arr_field, 2, values, None).unwrap();
2705
2706        let r = take(&values, &indices, None).unwrap();
2707        let values = r
2708            .as_fixed_size_list()
2709            .values()
2710            .as_primitive::<Int32Type>()
2711            .into_iter()
2712            .collect::<Vec<_>>();
2713        assert_eq!(values, &[Some(0), Some(1), None, None])
2714    }
2715
2716    #[test]
2717    fn test_take_bytes_null_indices() {
2718        let indices = Int32Array::new(
2719            vec![0, 1, 400, 400].into(),
2720            Some(NullBuffer::from_iter(vec![true, true, false, false])),
2721        );
2722        let values = StringArray::from(vec![Some("foo"), None]);
2723        let r = take(&values, &indices, None).unwrap();
2724        let values = r.as_string::<i32>().iter().collect::<Vec<_>>();
2725        assert_eq!(&values, &[Some("foo"), None, None, None])
2726    }
2727
2728    #[test]
2729    fn test_take_union_sparse() {
2730        let structs = create_test_struct(vec![
2731            Some((Some(true), Some(42))),
2732            Some((Some(false), Some(28))),
2733            Some((Some(false), Some(19))),
2734            Some((Some(true), Some(31))),
2735            None,
2736        ]);
2737        let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);
2738        let type_ids = [1; 5].into_iter().collect::<ScalarBuffer<i8>>();
2739
2740        let union_fields = [
2741            (
2742                0,
2743                Arc::new(Field::new("f1", structs.data_type().clone(), true)),
2744            ),
2745            (
2746                1,
2747                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2748            ),
2749        ]
2750        .into_iter()
2751        .collect();
2752        let children = vec![Arc::new(structs) as Arc<dyn Array>, Arc::new(strings)];
2753        let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();
2754
2755        let indices = vec![0, 3, 1, 0, 2, 4];
2756        let index = UInt32Array::from(indices.clone());
2757        let actual = take(&array, &index, None).unwrap();
2758        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2759        let strings = actual.child(1);
2760        let strings = strings.as_any().downcast_ref::<StringArray>().unwrap();
2761
2762        let actual = strings.iter().collect::<Vec<_>>();
2763        let expected = vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")];
2764        assert_eq!(expected, actual);
2765    }
2766
2767    #[test]
2768    fn test_take_union_dense() {
2769        let type_ids = vec![0, 1, 1, 0, 0, 1, 0];
2770        let offsets = vec![0, 0, 1, 1, 2, 2, 3];
2771        let ints = vec![10, 20, 30, 40];
2772        let strings = vec![Some("a"), None, Some("c"), Some("d")];
2773
2774        let indices = vec![0, 3, 1, 0, 2, 4];
2775
2776        let taken_type_ids = vec![0, 0, 1, 0, 1, 0];
2777        let taken_offsets = vec![0, 1, 0, 2, 1, 3];
2778        let taken_ints = vec![10, 20, 10, 30];
2779        let taken_strings = vec![Some("a"), None];
2780
2781        let type_ids = <ScalarBuffer<i8>>::from(type_ids);
2782        let offsets = <ScalarBuffer<i32>>::from(offsets);
2783        let ints = UInt32Array::from(ints);
2784        let strings = StringArray::from(strings);
2785
2786        let union_fields = [
2787            (
2788                0,
2789                Arc::new(Field::new("f1", ints.data_type().clone(), true)),
2790            ),
2791            (
2792                1,
2793                Arc::new(Field::new("f2", strings.data_type().clone(), true)),
2794            ),
2795        ]
2796        .into_iter()
2797        .collect();
2798
2799        let array = UnionArray::try_new(
2800            union_fields,
2801            type_ids,
2802            Some(offsets),
2803            vec![Arc::new(ints), Arc::new(strings)],
2804        )
2805        .unwrap();
2806
2807        let index = UInt32Array::from(indices);
2808
2809        let actual = take(&array, &index, None).unwrap();
2810        let actual = actual.as_any().downcast_ref::<UnionArray>().unwrap();
2811
2812        assert_eq!(actual.offsets(), Some(&ScalarBuffer::from(taken_offsets)));
2813        assert_eq!(actual.type_ids(), &ScalarBuffer::from(taken_type_ids));
2814        assert_eq!(
2815            UInt32Array::from(actual.child(0).to_data()),
2816            UInt32Array::from(taken_ints)
2817        );
2818        assert_eq!(
2819            StringArray::from(actual.child(1).to_data()),
2820            StringArray::from(taken_strings)
2821        );
2822    }
2823
2824    #[test]
2825    fn test_take_union_dense_using_builder() {
2826        let mut builder = UnionBuilder::new_dense();
2827
2828        builder.append::<Int32Type>("a", 1).unwrap();
2829        builder.append::<Float64Type>("b", 3.0).unwrap();
2830        builder.append::<Int32Type>("a", 4).unwrap();
2831        builder.append::<Int32Type>("a", 5).unwrap();
2832        builder.append::<Float64Type>("b", 2.0).unwrap();
2833
2834        let union = builder.build().unwrap();
2835
2836        let indices = UInt32Array::from(vec![2, 0, 1, 2]);
2837
2838        let mut builder = UnionBuilder::new_dense();
2839
2840        builder.append::<Int32Type>("a", 4).unwrap();
2841        builder.append::<Int32Type>("a", 1).unwrap();
2842        builder.append::<Float64Type>("b", 3.0).unwrap();
2843        builder.append::<Int32Type>("a", 4).unwrap();
2844
2845        let taken = builder.build().unwrap();
2846
2847        assert_eq!(
2848            taken.to_data(),
2849            take(&union, &indices, None).unwrap().to_data()
2850        );
2851    }
2852
2853    #[test]
2854    fn test_take_union_dense_all_match_issue_6206() {
2855        let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
2856        let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2857
2858        let array = UnionArray::try_new(
2859            fields,
2860            ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2861            Some(ScalarBuffer::from_iter(0_i32..5)),
2862            vec![ints],
2863        )
2864        .unwrap();
2865
2866        let indicies = Int64Array::from(vec![0, 2, 4]);
2867        let array = take(&array, &indicies, None).unwrap();
2868        assert_eq!(array.len(), 3);
2869    }
2870
2871    /// Fixture for the offset-overflow tests: a single large value plus the
2872    /// number of times it must be selected so the cumulative offset exceeds
2873    /// `i32::MAX`. Using a large value keeps the index count (and the test
2874    /// runtime) small.
2875    fn offset_overflow_fixture() -> (StringArray, usize) {
2876        let value_len = 1_000_000;
2877        let values = StringArray::from(vec![Some("a".repeat(value_len))]);
2878        let n = i32::MAX as usize / value_len + 1;
2879        (values, n)
2880    }
2881
2882    #[test]
2883    fn test_take_bytes_offset_overflow() {
2884        let (values, n) = offset_overflow_fixture();
2885        let indices = Int32Array::from(vec![0; n]);
2886        assert!(matches!(
2887            take(&values, &indices, None),
2888            Err(ArrowError::OffsetOverflowError(_))
2889        ));
2890    }
2891
2892    /// The offset-overflow error must also be produced on the nullable code
2893    /// path (when the output contains nulls), not only on the no-null fast path.
2894    #[test]
2895    fn test_take_bytes_offset_overflow_nullable() {
2896        let (values, n) = offset_overflow_fixture();
2897        // A null index forces the output to contain nulls, exercising the
2898        // nullable code path.
2899        let validity =
2900            NullBuffer::from_iter(std::iter::once(false).chain(std::iter::repeat_n(true, n)));
2901        let indices = Int32Array::new(vec![0i32; n + 1].into(), Some(validity));
2902
2903        assert!(matches!(
2904            take(&values, &indices, None),
2905            Err(ArrowError::OffsetOverflowError(_))
2906        ));
2907    }
2908
2909    #[test]
2910    fn test_take_run_empty_indices() {
2911        let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
2912        builder.extend([Some(1), Some(1), Some(2), Some(2)]);
2913        let run_array = builder.finish();
2914
2915        let logical_indices: PrimitiveArray<Int32Type> = PrimitiveArray::from(Vec::<i32>::new());
2916
2917        let result = take_impl(&run_array, &logical_indices).expect("take_run with empty indices");
2918
2919        // Verify the result is a valid empty RunArray
2920        assert_eq!(result.len(), 0);
2921        assert_eq!(result.null_count(), 0);
2922
2923        // Verify that the result can be downcast and used without validation errors
2924        // This specifically tests that "The values in run_ends array should be strictly positive" is not triggered
2925        let run_result = result
2926            .as_any()
2927            .downcast_ref::<RunArray<Int32Type>>()
2928            .expect("result should be a RunArray");
2929        assert_eq!(run_result.run_ends().len(), 0);
2930        assert_eq!(run_result.values().len(), 0);
2931    }
2932}