1use std::fmt::Display;
21use std::mem::ManuallyDrop;
22use std::sync::Arc;
23
24use arrow_array::builder::{BufferBuilder, UInt32Builder};
25use arrow_array::cast::AsArray;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::{
29 ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
30 bit_util,
31};
32use arrow_data::{ArrayDataBuilder, transform::MutableArrayData};
33use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
34
35use num_traits::Zero;
36
37pub fn take(
89 values: &dyn Array,
90 indices: &dyn Array,
91 options: Option<TakeOptions>,
92) -> Result<ArrayRef, ArrowError> {
93 let options = options.unwrap_or_default();
94 downcast_integer_array!(
95 indices => {
96 if options.check_bounds {
97 check_bounds(values.len(), indices)?;
98 }
99 let indices = indices.to_indices();
100 take_impl(values, &indices)
101 },
102 d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
103 )
104}
105
106pub fn take_arrays(
155 arrays: &[ArrayRef],
156 indices: &dyn Array,
157 options: Option<TakeOptions>,
158) -> Result<Vec<ArrayRef>, ArrowError> {
159 arrays
160 .iter()
161 .map(|array| take(array.as_ref(), indices, options.clone()))
162 .collect()
163}
164
165fn check_bounds<T: ArrowPrimitiveType>(
167 len: usize,
168 indices: &PrimitiveArray<T>,
169) -> Result<(), ArrowError>
170where
171 T::Native: Display,
172{
173 let len = match T::Native::from_usize(len) {
174 Some(len) => len,
175 None => {
176 if T::DATA_TYPE.is_integer() {
177 return Ok(());
179 } else {
180 return Err(ArrowError::ComputeError("Cast to usize failed".to_string()));
181 }
182 }
183 };
184
185 if indices.null_count() > 0 {
186 indices.iter().flatten().try_for_each(|index| {
187 if index >= len {
188 return Err(ArrowError::ComputeError(format!(
189 "Array index out of bounds, cannot get item at index {index} from {len} entries"
190 )));
191 }
192 Ok(())
193 })
194 } else {
195 let in_bounds = indices.values().iter().fold(true, |in_bounds, &i| {
196 in_bounds & (i >= T::Native::ZERO) & (i < len)
197 });
198
199 if !in_bounds {
200 for &index in indices.values() {
201 if index < T::Native::ZERO || index >= len {
202 return Err(ArrowError::ComputeError(format!(
203 "Array index out of bounds, cannot get item at index {index} from {len} entries"
204 )));
205 }
206 }
207 }
208
209 Ok(())
210 }
211}
212
213#[inline(never)]
214fn take_impl<IndexType: ArrowPrimitiveType>(
215 values: &dyn Array,
216 indices: &PrimitiveArray<IndexType>,
217) -> Result<ArrayRef, ArrowError> {
218 if indices.is_empty() {
219 return Ok(new_empty_array(values.data_type()));
220 }
221 downcast_primitive_array! {
222 values => Ok(Arc::new(take_primitive(values, indices)?)),
223 DataType::Boolean => {
224 let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
225 Ok(Arc::new(take_boolean(values, indices)))
226 }
227 DataType::Utf8 => {
228 Ok(Arc::new(take_bytes(values.as_string::<i32>(), indices)?))
229 }
230 DataType::LargeUtf8 => {
231 Ok(Arc::new(take_bytes(values.as_string::<i64>(), indices)?))
232 }
233 DataType::Utf8View => {
234 Ok(Arc::new(take_byte_view(values.as_string_view(), indices)?))
235 }
236 DataType::List(_) => {
237 Ok(Arc::new(take_list::<_, Int32Type>(values.as_list(), indices)?))
238 }
239 DataType::LargeList(_) => {
240 Ok(Arc::new(take_list::<_, Int64Type>(values.as_list(), indices)?))
241 }
242 DataType::ListView(_) => {
243 Ok(Arc::new(take_list_view::<_, Int32Type>(values.as_list_view(), indices)?))
244 }
245 DataType::LargeListView(_) => {
246 Ok(Arc::new(take_list_view::<_, Int64Type>(values.as_list_view(), indices)?))
247 }
248 DataType::FixedSizeList(_, length) => {
249 let values = values
250 .as_any()
251 .downcast_ref::<FixedSizeListArray>()
252 .unwrap();
253 Ok(Arc::new(take_fixed_size_list(
254 values,
255 indices,
256 *length as u32,
257 )?))
258 }
259 DataType::Map(_, _) => {
260 let list_arr = ListArray::from(values.as_map().clone());
261 let list_data = take_list::<_, Int32Type>(&list_arr, indices)?;
262 let builder = list_data.into_data().into_builder().data_type(values.data_type().clone());
263 Ok(Arc::new(MapArray::from(unsafe { builder.build_unchecked() })))
264 }
265 DataType::Struct(fields) => {
266 let array: &StructArray = values.as_struct();
267 let arrays = array
268 .columns()
269 .iter()
270 .map(|a| take_impl(a.as_ref(), indices))
271 .collect::<Result<Vec<ArrayRef>, _>>()?;
272 let fields: Vec<(FieldRef, ArrayRef)> =
273 fields.iter().cloned().zip(arrays).collect();
274
275 let is_valid: Buffer = indices
277 .iter()
278 .map(|index| {
279 if let Some(index) = index {
280 array.is_valid(index.to_usize().unwrap())
281 } else {
282 false
283 }
284 })
285 .collect();
286
287 if fields.is_empty() {
288 let nulls = NullBuffer::new(BooleanBuffer::new(is_valid, 0, indices.len()));
289 Ok(Arc::new(StructArray::new_empty_fields(indices.len(), Some(nulls))))
290 } else {
291 Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
292 }
293 }
294 DataType::Dictionary(_, _) => downcast_dictionary_array! {
295 values => Ok(Arc::new(take_dict(values, indices)?)),
296 t => unimplemented!("Take not supported for dictionary type {:?}", t)
297 }
298 DataType::RunEndEncoded(_, _) => downcast_run_array! {
299 values => Ok(Arc::new(take_run(values, indices)?)),
300 t => unimplemented!("Take not supported for run type {:?}", t)
301 }
302 DataType::Binary => {
303 Ok(Arc::new(take_bytes(values.as_binary::<i32>(), indices)?))
304 }
305 DataType::LargeBinary => {
306 Ok(Arc::new(take_bytes(values.as_binary::<i64>(), indices)?))
307 }
308 DataType::BinaryView => {
309 Ok(Arc::new(take_byte_view(values.as_binary_view(), indices)?))
310 }
311 DataType::FixedSizeBinary(size) => {
312 let values = values
313 .as_any()
314 .downcast_ref::<FixedSizeBinaryArray>()
315 .unwrap();
316 Ok(Arc::new(take_fixed_size_binary(values, indices, *size)?))
317 }
318 DataType::Null => {
319 if values.len() >= indices.len() {
321 Ok(values.slice(0, indices.len()))
324 } else {
325 Ok(new_null_array(&DataType::Null, indices.len()))
327 }
328 }
329 DataType::Union(fields, UnionMode::Sparse) => {
330 let mut children = Vec::with_capacity(fields.len());
331 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
332 let type_ids = take_native(values.type_ids(), indices);
333 for (type_id, _field) in fields.iter() {
334 let values = values.child(type_id);
335 let values = take_impl(values, indices)?;
336 children.push(values);
337 }
338 let array = UnionArray::try_new(fields.clone(), type_ids, None, children)?;
339 Ok(Arc::new(array))
340 }
341 DataType::Union(fields, UnionMode::Dense) => {
342 let values = values.as_any().downcast_ref::<UnionArray>().unwrap();
343
344 let type_ids = <PrimitiveArray<Int8Type>>::try_new(take_native(values.type_ids(), indices), None)?;
345 let offsets = <PrimitiveArray<Int32Type>>::try_new(take_native(values.offsets().unwrap(), indices), None)?;
346
347 let children = fields.iter()
348 .map(|(field_type_id, _)| {
349 let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
350
351 let indices = crate::filter::filter(&offsets, &mask)?;
352
353 let values = values.child(field_type_id);
354
355 take_impl(values, indices.as_primitive::<Int32Type>())
356 })
357 .collect::<Result<_, _>>()?;
358
359 let mut child_offsets = [0; 128];
360
361 let offsets = type_ids.values()
362 .iter()
363 .map(|&i| {
364 let offset = child_offsets[i as usize];
365
366 child_offsets[i as usize] += 1;
367
368 offset
369 })
370 .collect();
371
372 let (_, type_ids, _) = type_ids.into_parts();
373
374 let array = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)?;
375
376 Ok(Arc::new(array))
377 }
378 t => unimplemented!("Take not supported for data type {:?}", t)
379 }
380}
381
/// Options controlling the behaviour of [`take`].
#[derive(Clone, Debug, Default)]
pub struct TakeOptions {
    /// When `true`, [`take`] validates the indices against the length of the values
    /// before gathering and reports out-of-bounds indices as an error. When `false`
    /// (the default), no validation is performed and an out-of-bounds index may panic.
    pub check_bounds: bool,
}
390
391fn take_primitive<T, I>(
401 values: &PrimitiveArray<T>,
402 indices: &PrimitiveArray<I>,
403) -> Result<PrimitiveArray<T>, ArrowError>
404where
405 T: ArrowPrimitiveType,
406 I: ArrowPrimitiveType,
407{
408 let values_buf = take_native(values.values(), indices);
409 let nulls = take_nulls(values.nulls(), indices);
410 Ok(PrimitiveArray::try_new(values_buf, nulls)?.with_data_type(values.data_type().clone()))
411}
412
413#[inline(never)]
414fn take_nulls<I: ArrowPrimitiveType>(
415 values: Option<&NullBuffer>,
416 indices: &PrimitiveArray<I>,
417) -> Option<NullBuffer> {
418 match values.filter(|n| n.null_count() > 0) {
419 Some(n) => NullBuffer::from_unsliced_buffer(
420 take_bits(n.inner(), indices).into_inner(),
421 indices.len(),
422 ),
423 None => indices.nulls().cloned(),
424 }
425}
426
427#[inline(never)]
428fn take_native<T: ArrowNativeType, I: ArrowPrimitiveType>(
429 values: &[T],
430 indices: &PrimitiveArray<I>,
431) -> ScalarBuffer<T> {
432 match indices.nulls().filter(|n| n.null_count() > 0) {
433 Some(n) => indices
434 .values()
435 .iter()
436 .enumerate()
437 .map(|(idx, index)| match values.get(index.as_usize()) {
438 Some(v) => *v,
439 None => match unsafe { n.inner().value_unchecked(idx) } {
441 false => T::default(),
442 true => panic!("Out-of-bounds index {index:?}"),
443 },
444 })
445 .collect(),
446 None => indices
447 .values()
448 .iter()
449 .map(|index| values[index.as_usize()])
450 .collect(),
451 }
452}
453
454#[inline(never)]
455fn take_bits<I: ArrowPrimitiveType>(
456 values: &BooleanBuffer,
457 indices: &PrimitiveArray<I>,
458) -> BooleanBuffer {
459 let len = indices.len();
460
461 match indices.nulls().filter(|n| n.null_count() > 0) {
462 Some(nulls) => {
463 let mut output_buffer = MutableBuffer::new_null(len);
464 let output_slice = output_buffer.as_slice_mut();
465 nulls.valid_indices().for_each(|idx| {
466 if values.value(unsafe { indices.value_unchecked(idx).as_usize() }) {
468 unsafe { bit_util::set_bit_raw(output_slice.as_mut_ptr(), idx) };
470 }
471 });
472 BooleanBuffer::new(output_buffer.into(), 0, len)
473 }
474 None => {
475 BooleanBuffer::collect_bool(len, |idx: usize| {
476 values.value(unsafe { indices.value_unchecked(idx).as_usize() })
478 })
479 }
480 }
481}
482
483fn take_boolean<IndexType: ArrowPrimitiveType>(
485 values: &BooleanArray,
486 indices: &PrimitiveArray<IndexType>,
487) -> BooleanArray {
488 let val_buf = take_bits(values.values(), indices);
489 let null_buf = take_nulls(values.nulls(), indices);
490 BooleanArray::new(val_buf, null_buf)
491}
492
/// Takes variable-length byte values (strings or binary) by index into a newly
/// allocated `GenericByteArray`.
///
/// Works in two passes: first compute output offsets and the total byte count, then
/// copy the selected byte ranges into one reserved buffer. Null output slots become
/// zero-length entries.
///
/// # Errors
/// Returns [`ArrowError::OffsetOverflowError`] if the total output size does not fit in
/// `T::Offset`.
///
/// # Panics
/// Panics if a non-null index is out of bounds (via indexing `input_offsets`).
fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
    array: &GenericByteArray<T>,
    indices: &PrimitiveArray<IndexType>,
) -> Result<GenericByteArray<T>, ArrowError> {
    let mut values: Vec<u8> = Vec::new();
    let mut offsets = Vec::with_capacity(indices.len() + 1);
    offsets.push(T::Offset::default());

    let input_offsets = array.value_offsets();
    let mut capacity = 0;
    let nulls = take_nulls(array.nulls(), indices);

    match nulls.as_ref().filter(|n| n.null_count() > 0) {
        // No output nulls, so every index is valid and gets copied.
        None => {
            // Pass 1: compute offsets and total byte length; indexing `input_offsets`
            // bounds-checks every index.
            for index in indices.values() {
                let index = index.as_usize();
                let start = input_offsets[index].as_usize();
                let end = input_offsets[index + 1].as_usize();
                capacity += end - start;
                offsets.push(
                    T::Offset::from_usize(capacity)
                        .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?,
                );
            }

            values.reserve(capacity);

            let dst = values.spare_capacity_mut();
            debug_assert!(dst.len() >= capacity);
            let mut offset = 0;

            // Pass 2: copy each value's bytes into the reserved, uninitialized tail.
            for index in indices.values() {
                // SAFETY: every index was bounds-checked in pass 1, and the bytes written
                // sum to `capacity`, which was reserved above.
                unsafe {
                    let data: &[u8] = array.value_unchecked(index.as_usize()).as_ref();
                    std::ptr::copy_nonoverlapping(
                        data.as_ptr(),
                        dst.get_unchecked_mut(offset..).as_mut_ptr().cast::<u8>(),
                        data.len(),
                    );
                    offset += data.len();
                }
            }

            // SAFETY: exactly `capacity` bytes were initialized by the copies above.
            unsafe {
                values.set_len(capacity);
            }
        }
        Some(output_nulls) => {
            // Byte ranges to copy, one per valid output slot.
            let mut source_ranges = Vec::with_capacity(indices.len() - output_nulls.null_count());
            let mut last_filled = 0;

            // Pre-size offsets so runs of null slots can be filled in place.
            offsets.resize(indices.len() + 1, T::Offset::default());

            for i in output_nulls.valid_indices() {
                let current_offset = T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
                // Null slots between the previous valid slot and `i` are zero-length.
                if last_filled < i {
                    offsets[last_filled + 1..=i].fill(current_offset);
                }

                // SAFETY: `valid_indices` yields positions below `indices.len()`.
                let index = unsafe { indices.value_unchecked(i) }.as_usize();
                let start = input_offsets[index].as_usize();
                let end = input_offsets[index + 1].as_usize();
                capacity += end - start;
                offsets[i + 1] = T::Offset::from_usize(capacity)
                    .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;

                source_ranges.push((start, end));
                last_filled = i + 1;
            }

            // Trailing null slots share the final offset.
            let final_offset = T::Offset::from_usize(capacity)
                .ok_or_else(|| ArrowError::OffsetOverflowError(capacity))?;
            offsets[last_filled + 1..].fill(final_offset);
            values.reserve(capacity);

            debug_assert_eq!(
                source_ranges.iter().map(|(s, e)| e - s).sum::<usize>(),
                capacity,
                "capacity must equal total bytes across all ranges"
            );

            let src = array.value_data();
            let src = src.as_ptr();
            let dst = values.spare_capacity_mut();
            debug_assert!(dst.len() >= capacity);

            let mut offset = 0;

            for (start, end) in source_ranges.into_iter() {
                let value_len = end - start;
                // SAFETY: `start..end` comes from the input offsets, which index into
                // `value_data`. The destination writes stay within the reserved
                // `capacity` bytes.
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        src.add(start),
                        dst.get_unchecked_mut(offset..).as_mut_ptr().cast::<u8>(),
                        value_len,
                    );
                    offset += value_len;
                }
            }
            // SAFETY: exactly `capacity` bytes were initialized by the copies above.
            unsafe { values.set_len(capacity) };
        }
    };

    // SAFETY: offsets start at zero, never decrease, and end at `values.len()`.
    let array = unsafe {
        let offsets = OffsetBuffer::new_unchecked(offsets.into());
        GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
    };

    Ok(array)
}
623
624fn take_byte_view<T: ByteViewType, IndexType: ArrowPrimitiveType>(
626 array: &GenericByteViewArray<T>,
627 indices: &PrimitiveArray<IndexType>,
628) -> Result<GenericByteViewArray<T>, ArrowError> {
629 let new_views = take_native(array.views(), indices);
630 let new_nulls = take_nulls(array.nulls(), indices);
631 Ok(unsafe {
633 GenericByteViewArray::new_unchecked(new_views, array.data_buffers().to_vec(), new_nulls)
634 })
635}
636
637fn take_list<IndexType, OffsetType>(
642 values: &GenericListArray<OffsetType::Native>,
643 indices: &PrimitiveArray<IndexType>,
644) -> Result<GenericListArray<OffsetType::Native>, ArrowError>
645where
646 IndexType: ArrowPrimitiveType,
647 OffsetType: ArrowPrimitiveType,
648 OffsetType::Native: OffsetSizeTrait,
649 PrimitiveArray<OffsetType>: From<Vec<OffsetType::Native>>,
650{
651 let list_offsets = values.value_offsets();
652 let child_data = values.values().to_data();
653 let nulls = take_nulls(values.nulls(), indices);
654
655 let mut new_offsets = Vec::with_capacity(indices.len() + 1);
656 new_offsets.push(OffsetType::Native::zero());
657
658 let use_nulls = child_data.null_count() > 0;
659
660 let capacity = child_data
661 .len()
662 .checked_div(values.len())
663 .map(|v| v * indices.len())
664 .unwrap_or_default();
665
666 let mut array_data = MutableArrayData::new(vec![&child_data], use_nulls, capacity);
667
668 match nulls.as_ref().filter(|n| n.null_count() > 0) {
669 None => {
670 for index in indices.values() {
671 let ix = index.as_usize();
672 let start = list_offsets[ix].as_usize();
673 let end = list_offsets[ix + 1].as_usize();
674 array_data.extend(0, start, end);
675 new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
676 }
677 }
678 Some(output_nulls) => {
679 assert_eq!(output_nulls.len(), indices.len());
680
681 let mut last_filled = 0;
682 for i in output_nulls.valid_indices() {
683 let current = OffsetType::Native::from_usize(array_data.len()).unwrap();
684 if last_filled < i {
686 new_offsets.extend(std::iter::repeat_n(current, i - last_filled));
687 }
688
689 let ix = unsafe { indices.value_unchecked(i) }.as_usize();
691 let start = list_offsets[ix].as_usize();
692 let end = list_offsets[ix + 1].as_usize();
693 array_data.extend(0, start, end);
694 new_offsets.push(OffsetType::Native::from_usize(array_data.len()).unwrap());
695 last_filled = i + 1;
696 }
697
698 let final_offset = OffsetType::Native::from_usize(array_data.len()).unwrap();
700 new_offsets.extend(std::iter::repeat_n(
701 final_offset,
702 indices.len() - last_filled,
703 ));
704 }
705 };
706
707 assert_eq!(
708 new_offsets.len(),
709 indices.len() + 1,
710 "New offsets was filled under/over the expected capacity"
711 );
712
713 let child_data = array_data.freeze();
714 let value_offsets = Buffer::from_vec(new_offsets);
715
716 let list_data = ArrayDataBuilder::new(values.data_type().clone())
717 .len(indices.len())
718 .nulls(nulls)
719 .offset(0)
720 .add_child_data(child_data)
721 .add_buffer(value_offsets);
722
723 let list_data = unsafe { list_data.build_unchecked() };
724 Ok(GenericListArray::<OffsetType::Native>::from(list_data))
725}
726
727fn take_list_view<IndexType, OffsetType>(
728 values: &GenericListViewArray<OffsetType::Native>,
729 indices: &PrimitiveArray<IndexType>,
730) -> Result<GenericListViewArray<OffsetType::Native>, ArrowError>
731where
732 IndexType: ArrowPrimitiveType,
733 OffsetType: ArrowPrimitiveType,
734 OffsetType::Native: OffsetSizeTrait,
735{
736 let taken_offsets = take_native(values.offsets(), indices);
737 let taken_sizes = take_native(values.sizes(), indices);
738 let nulls = take_nulls(values.nulls(), indices);
739
740 let list_view_data = ArrayDataBuilder::new(values.data_type().clone())
741 .len(indices.len())
742 .nulls(nulls)
743 .buffers(vec![taken_offsets.into(), taken_sizes.into()])
744 .child_data(vec![values.values().to_data()]);
745
746 let list_view_data = unsafe { list_view_data.build_unchecked() };
748
749 Ok(GenericListViewArray::<OffsetType::Native>::from(
750 list_view_data,
751 ))
752}
753
754fn take_fixed_size_list<IndexType: ArrowPrimitiveType>(
760 values: &FixedSizeListArray,
761 indices: &PrimitiveArray<IndexType>,
762 length: <UInt32Type as ArrowPrimitiveType>::Native,
763) -> Result<FixedSizeListArray, ArrowError> {
764 let list_indices = take_value_indices_from_fixed_size_list(values, indices, length)?;
765 let taken = take_impl::<UInt32Type>(values.values().as_ref(), &list_indices)?;
766
767 let num_bytes = bit_util::ceil(indices.len(), 8);
769 let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
770 let null_slice = null_buf.as_slice_mut();
771
772 for i in 0..indices.len() {
773 let index = indices
774 .value(i)
775 .to_usize()
776 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
777 if !indices.is_valid(i) || values.is_null(index) {
778 bit_util::unset_bit(null_slice, i);
779 }
780 }
781
782 let list_data = ArrayDataBuilder::new(values.data_type().clone())
783 .len(indices.len())
784 .null_bit_buffer(Some(null_buf.into()))
785 .offset(0)
786 .add_child_data(taken.into_data());
787
788 let list_data = unsafe { list_data.build_unchecked() };
789
790 Ok(FixedSizeListArray::from(list_data))
791}
792
793fn take_fixed_size_binary<IndexType: ArrowPrimitiveType>(
799 values: &FixedSizeBinaryArray,
800 indices: &PrimitiveArray<IndexType>,
801 size: i32,
802) -> Result<FixedSizeBinaryArray, ArrowError> {
803 let size_usize = usize::try_from(size).map_err(|_| {
804 ArrowError::InvalidArgumentError(format!("Cannot convert size '{}' to usize", size))
805 })?;
806
807 let result_buffer = match size_usize {
808 1 => take_fixed_size::<IndexType, 1>(values.values(), indices),
809 2 => take_fixed_size::<IndexType, 2>(values.values(), indices),
810 4 => take_fixed_size::<IndexType, 4>(values.values(), indices),
811 8 => take_fixed_size::<IndexType, 8>(values.values(), indices),
812 16 => take_fixed_size::<IndexType, 16>(values.values(), indices),
813 _ => take_fixed_size_binary_buffer_dynamic_length(values, indices, size_usize),
814 };
815
816 let value_nulls = take_nulls(values.nulls(), indices);
817 let final_nulls = NullBuffer::union(value_nulls.as_ref(), indices.nulls());
818 let array_data = ArrayDataBuilder::new(DataType::FixedSizeBinary(size))
819 .len(indices.len())
820 .nulls(final_nulls)
821 .offset(0)
822 .add_buffer(result_buffer)
823 .build()?;
824
825 return Ok(FixedSizeBinaryArray::from(array_data));
826
827 #[inline(never)]
829 fn take_fixed_size_binary_buffer_dynamic_length<IndexType: ArrowPrimitiveType>(
830 values: &FixedSizeBinaryArray,
831 indices: &PrimitiveArray<IndexType>,
832 size_usize: usize,
833 ) -> Buffer {
834 let values_buffer = values.values().as_slice();
835 let mut values_buffer_builder = BufferBuilder::new(indices.len() * size_usize);
836
837 if indices.null_count() == 0 {
838 let array_iter = indices.values().iter().map(|idx| {
839 let offset = idx.as_usize() * size_usize;
840 &values_buffer[offset..offset + size_usize]
841 });
842 for slice in array_iter {
843 values_buffer_builder.append_slice(slice);
844 }
845 } else {
846 let array_iter = indices.iter().map(|idx| {
849 idx.map(|idx| {
850 let offset = idx.as_usize() * size_usize;
851 &values_buffer[offset..offset + size_usize]
852 })
853 });
854 for slice in array_iter {
855 match slice {
856 None => values_buffer_builder.append_n(size_usize, 0),
857 Some(slice) => values_buffer_builder.append_slice(slice),
858 }
859 }
860 }
861
862 values_buffer_builder.finish()
863 }
864}
865
866fn take_fixed_size<IndexType: ArrowPrimitiveType, const N: usize>(
879 buffer: &Buffer,
880 indices: &PrimitiveArray<IndexType>,
881) -> Buffer {
882 assert_eq!(
883 buffer.len() % N,
884 0,
885 "Invalid array length in take_fixed_size"
886 );
887
888 let ptr = buffer.as_ptr();
889 let chunk_ptr = ptr.cast::<[u8; N]>();
890 let chunk_len = buffer.len() / N;
891 let buffer: &[[u8; N]] = unsafe {
892 std::slice::from_raw_parts(chunk_ptr, chunk_len)
895 };
896
897 let result_buffer = match indices.nulls().filter(|n| n.null_count() > 0) {
898 Some(n) => indices
899 .values()
900 .iter()
901 .enumerate()
902 .map(|(idx, index)| match buffer.get(index.as_usize()) {
903 Some(v) => *v,
904 None => match unsafe { n.inner().value_unchecked(idx) } {
906 false => [0u8; N],
907 true => panic!("Out-of-bounds index {index:?}"),
908 },
909 })
910 .collect::<Vec<_>>(),
911 None => indices
912 .values()
913 .iter()
914 .map(|index| buffer[index.as_usize()])
915 .collect::<Vec<_>>(),
916 };
917
918 let mut vec = ManuallyDrop::new(result_buffer); let ptr = vec.as_mut_ptr();
920 let len = vec.len();
921 let cap = vec.capacity();
922 let result_buffer = unsafe {
923 Vec::from_raw_parts(ptr.cast::<u8>(), len * N, cap * N)
925 };
926
927 Buffer::from_vec(result_buffer)
928}
929
930fn take_dict<T: ArrowDictionaryKeyType, I: ArrowPrimitiveType>(
935 values: &DictionaryArray<T>,
936 indices: &PrimitiveArray<I>,
937) -> Result<DictionaryArray<T>, ArrowError> {
938 let new_keys = take_primitive(values.keys(), indices)?;
939 Ok(unsafe { DictionaryArray::new_unchecked(new_keys, values.values().clone()) })
940}
941
942fn take_run<T: RunEndIndexType, I: ArrowPrimitiveType>(
951 run_array: &RunArray<T>,
952 logical_indices: &PrimitiveArray<I>,
953) -> Result<RunArray<T>, ArrowError> {
954 let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
956
957 let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
961 let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
962 let mut new_physical_len = 1;
963 for ix in 1..physical_indices.len() {
964 if physical_indices[ix] != physical_indices[ix - 1] {
965 take_value_indices.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
966 new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
967 new_physical_len += 1;
968 }
969 }
970 take_value_indices
971 .append(I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap());
972 new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
973 let new_run_ends = unsafe {
974 ArrayDataBuilder::new(T::DATA_TYPE)
977 .len(new_physical_len)
978 .null_count(0)
979 .add_buffer(new_run_ends_builder.finish())
980 .build_unchecked()
981 };
982
983 let take_value_indices: PrimitiveArray<I> = unsafe {
984 ArrayDataBuilder::new(I::DATA_TYPE)
987 .len(new_physical_len)
988 .null_count(0)
989 .add_buffer(take_value_indices.finish())
990 .build_unchecked()
991 .into()
992 };
993
994 let new_values = take(run_array.values(), &take_value_indices, None)?;
995
996 let builder = ArrayDataBuilder::new(run_array.data_type().clone())
997 .len(physical_indices.len())
998 .add_child_data(new_run_ends)
999 .add_child_data(new_values.into_data());
1000 let array_data = unsafe {
1001 builder.build_unchecked()
1004 };
1005 Ok(array_data.into())
1006}
1007
1008fn take_value_indices_from_fixed_size_list<IndexType>(
1010 list: &FixedSizeListArray,
1011 indices: &PrimitiveArray<IndexType>,
1012 length: <UInt32Type as ArrowPrimitiveType>::Native,
1013) -> Result<PrimitiveArray<UInt32Type>, ArrowError>
1014where
1015 IndexType: ArrowPrimitiveType,
1016{
1017 let mut values = UInt32Builder::with_capacity(length as usize * indices.len());
1018
1019 for i in 0..indices.len() {
1020 if indices.is_valid(i) {
1021 let index = indices
1022 .value(i)
1023 .to_usize()
1024 .ok_or_else(|| ArrowError::ComputeError("Cast to usize failed".to_string()))?;
1025 let start = list.value_offset(index) as <UInt32Type as ArrowPrimitiveType>::Native;
1026
1027 unsafe {
1029 values.append_trusted_len_iter(start..start + length);
1030 }
1031 } else {
1032 values.append_nulls(length as usize);
1033 }
1034 }
1035
1036 Ok(values.finish())
1037}
1038
1039trait ToIndices {
1042 type T: ArrowPrimitiveType;
1043
1044 fn to_indices(&self) -> PrimitiveArray<Self::T>;
1045}
1046
1047macro_rules! to_indices_reinterpret {
1048 ($t:ty, $o:ty) => {
1049 impl ToIndices for PrimitiveArray<$t> {
1050 type T = $o;
1051
1052 fn to_indices(&self) -> PrimitiveArray<$o> {
1053 let cast = ScalarBuffer::new(self.values().inner().clone(), 0, self.len());
1054 PrimitiveArray::new(cast, self.nulls().cloned())
1055 }
1056 }
1057 };
1058}
1059
1060macro_rules! to_indices_identity {
1061 ($t:ty) => {
1062 impl ToIndices for PrimitiveArray<$t> {
1063 type T = $t;
1064
1065 fn to_indices(&self) -> PrimitiveArray<$t> {
1066 self.clone()
1067 }
1068 }
1069 };
1070}
1071
1072macro_rules! to_indices_widening {
1073 ($t:ty, $o:ty) => {
1074 impl ToIndices for PrimitiveArray<$t> {
1075 type T = UInt32Type;
1076
1077 fn to_indices(&self) -> PrimitiveArray<$o> {
1078 let cast = self.values().iter().copied().map(|x| x as _).collect();
1079 PrimitiveArray::new(cast, self.nulls().cloned())
1080 }
1081 }
1082 };
1083}
1084
1085to_indices_widening!(UInt8Type, UInt32Type);
1086to_indices_widening!(Int8Type, UInt32Type);
1087
1088to_indices_widening!(UInt16Type, UInt32Type);
1089to_indices_widening!(Int16Type, UInt32Type);
1090
1091to_indices_identity!(UInt32Type);
1092to_indices_reinterpret!(Int32Type, UInt32Type);
1093
1094to_indices_identity!(UInt64Type);
1095to_indices_reinterpret!(Int64Type, UInt64Type);
1096
1097pub fn take_record_batch(
1136 record_batch: &RecordBatch,
1137 indices: &dyn Array,
1138) -> Result<RecordBatch, ArrowError> {
1139 let columns = record_batch
1140 .columns()
1141 .iter()
1142 .map(|c| take(c, indices, None))
1143 .collect::<Result<Vec<_>, _>>()?;
1144 RecordBatch::try_new(record_batch.schema(), columns)
1145}
1146
1147#[cfg(test)]
1148mod tests {
1149 use super::*;
1150 use arrow_array::builder::*;
1151 use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
1152 use arrow_data::ArrayData;
1153 use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
1154 use num_traits::ToPrimitive;
1155
1156 fn test_take_decimal_arrays(
1157 data: Vec<Option<i128>>,
1158 index: &UInt32Array,
1159 options: Option<TakeOptions>,
1160 expected_data: Vec<Option<i128>>,
1161 precision: &u8,
1162 scale: &i8,
1163 ) -> Result<(), ArrowError> {
1164 let output = data
1165 .into_iter()
1166 .collect::<Decimal128Array>()
1167 .with_precision_and_scale(*precision, *scale)
1168 .unwrap();
1169
1170 let expected = expected_data
1171 .into_iter()
1172 .collect::<Decimal128Array>()
1173 .with_precision_and_scale(*precision, *scale)
1174 .unwrap();
1175
1176 let expected = Arc::new(expected) as ArrayRef;
1177 let output = take(&output, index, options).unwrap();
1178 assert_eq!(&output, &expected);
1179 Ok(())
1180 }
1181
1182 fn test_take_boolean_arrays(
1183 data: Vec<Option<bool>>,
1184 index: &UInt32Array,
1185 options: Option<TakeOptions>,
1186 expected_data: Vec<Option<bool>>,
1187 ) {
1188 let output = BooleanArray::from(data);
1189 let expected = Arc::new(BooleanArray::from(expected_data)) as ArrayRef;
1190 let output = take(&output, index, options).unwrap();
1191 assert_eq!(&output, &expected)
1192 }
1193
1194 fn test_take_primitive_arrays<T>(
1195 data: Vec<Option<T::Native>>,
1196 index: &UInt32Array,
1197 options: Option<TakeOptions>,
1198 expected_data: Vec<Option<T::Native>>,
1199 ) -> Result<(), ArrowError>
1200 where
1201 T: ArrowPrimitiveType,
1202 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1203 {
1204 let output = PrimitiveArray::<T>::from(data);
1205 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1206 let output = take(&output, index, options)?;
1207 assert_eq!(&output, &expected);
1208 Ok(())
1209 }
1210
1211 fn test_take_primitive_arrays_non_null<T>(
1212 data: Vec<T::Native>,
1213 index: &UInt32Array,
1214 options: Option<TakeOptions>,
1215 expected_data: Vec<Option<T::Native>>,
1216 ) -> Result<(), ArrowError>
1217 where
1218 T: ArrowPrimitiveType,
1219 PrimitiveArray<T>: From<Vec<T::Native>>,
1220 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1221 {
1222 let output = PrimitiveArray::<T>::from(data);
1223 let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
1224 let output = take(&output, index, options)?;
1225 assert_eq!(&output, &expected);
1226 Ok(())
1227 }
1228
1229 fn test_take_impl_primitive_arrays<T, I>(
1230 data: Vec<Option<T::Native>>,
1231 index: &PrimitiveArray<I>,
1232 options: Option<TakeOptions>,
1233 expected_data: Vec<Option<T::Native>>,
1234 ) where
1235 T: ArrowPrimitiveType,
1236 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
1237 I: ArrowPrimitiveType,
1238 {
1239 let output = PrimitiveArray::<T>::from(data);
1240 let expected = PrimitiveArray::<T>::from(expected_data);
1241 let output = take(&output, index, options).unwrap();
1242 let output = output.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
1243 assert_eq!(output, &expected)
1244 }
1245
1246 fn create_test_struct(values: Vec<Option<(Option<bool>, Option<i32>)>>) -> StructArray {
1248 let mut struct_builder = StructBuilder::new(
1249 Fields::from(vec![
1250 Field::new("a", DataType::Boolean, true),
1251 Field::new("b", DataType::Int32, true),
1252 ]),
1253 vec![
1254 Box::new(BooleanBuilder::with_capacity(values.len())),
1255 Box::new(Int32Builder::with_capacity(values.len())),
1256 ],
1257 );
1258
1259 for value in values {
1260 struct_builder
1261 .field_builder::<BooleanBuilder>(0)
1262 .unwrap()
1263 .append_option(value.and_then(|v| v.0));
1264 struct_builder
1265 .field_builder::<Int32Builder>(1)
1266 .unwrap()
1267 .append_option(value.and_then(|v| v.1));
1268 struct_builder.append(value.is_some());
1269 }
1270 struct_builder.finish()
1271 }
1272
1273 #[test]
1274 fn test_take_decimal128_non_null_indices() {
1275 let index = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
1276 let precision: u8 = 10;
1277 let scale: i8 = 5;
1278 test_take_decimal_arrays(
1279 vec![None, Some(3), Some(5), Some(2), Some(3), None],
1280 &index,
1281 None,
1282 vec![None, None, Some(2), Some(3), Some(3), Some(5)],
1283 &precision,
1284 &scale,
1285 )
1286 .unwrap();
1287 }
1288
#[test]
fn test_take_decimal128() {
    // All values are valid; index slot 1 is null and position 3 is taken twice.
    let (precision, scale): (u8, i8) = (10, 5);
    let values = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];

    test_take_decimal_arrays(values, &indices, None, expected, &precision, &scale).unwrap();
}
1304
#[test]
fn test_take_primitive_non_null_indices() {
    // Non-null indices gathering from values that contain nulls.
    let values = vec![None, Some(3), Some(5), Some(2), Some(3), None];
    let indices = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let expected = vec![None, None, Some(2), Some(3), Some(3), Some(5)];

    test_take_primitive_arrays::<Int8Type>(values, &indices, None, expected).unwrap();
}
1316
#[test]
fn test_take_primitive_non_null_values() {
    // Fully valid values; the null in the output comes from index slot 1.
    let values = vec![Some(0), Some(1), Some(2), Some(3), Some(4)];
    let indices = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    let expected = vec![Some(3), None, Some(1), Some(3), Some(2)];

    test_take_primitive_arrays::<Int8Type>(values, &indices, None, expected).unwrap();
}
1328
#[test]
fn test_take_primitive_non_null() {
    // Neither the values nor the indices contain nulls.
    let values = vec![Some(0), Some(3), Some(5), Some(2), Some(3), Some(1)];
    let indices = UInt32Array::from(vec![0, 5, 3, 1, 4, 2]);
    let expected = vec![Some(0), Some(1), Some(2), Some(3), Some(3), Some(5)];

    test_take_primitive_arrays::<Int8Type>(values, &indices, None, expected).unwrap();
}
1340
#[test]
fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
    let full_index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    // Slicing gives the index array a non-zero offset into its buffers.
    let sliced = full_index.slice(2, 4);
    assert_eq!(sliced, UInt32Array::from(vec![Some(2), Some(3), None, None]));

    test_take_primitive_arrays_non_null::<Int64Type>(
        vec![0, 10, 20, 30, 40, 50],
        &sliced,
        None,
        vec![Some(20), Some(30), None, None],
    )
    .unwrap();
}
1360
#[test]
fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
    let full_index = UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
    // Slicing gives the index array a non-zero offset into its buffers.
    let sliced = full_index.slice(2, 4);
    assert_eq!(sliced, UInt32Array::from(vec![Some(2), Some(3), None, None]));

    // Only the valid tail of the values is selected by the valid index slots.
    test_take_primitive_arrays::<Int64Type>(
        vec![None, None, Some(20), Some(30), Some(40), Some(50)],
        &sliced,
        None,
        vec![Some(20), Some(30), None, None],
    )
    .unwrap();
}
1380
#[test]
fn test_take_primitive() {
    // Shared index: a null at position 1 and a repeated lookup of position 3.
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Runs `take` on a `$t` array built from `$values` and checks it against `$expected`.
    macro_rules! check_take {
        ($t:ty, $values:expr, $expected:expr) => {
            test_take_primitive_arrays::<$t>($values, &index, None, $expected).unwrap()
        };
    }

    // Every integer width, on non-negative data that fits all of them.
    // UInt64 was previously missing; Int64 had been listed twice instead.
    macro_rules! check_non_negative {
        ($($t:ty),+ $(,)?) => {$(
            check_take!(
                $t,
                vec![Some(0), None, Some(2), Some(3), None],
                vec![Some(3), None, None, Some(3), Some(2)]
            );
        )+};
    }
    check_non_negative!(
        Int8Type,
        Int16Type,
        Int32Type,
        Int64Type,
        UInt8Type,
        UInt16Type,
        UInt32Type,
        UInt64Type,
    );

    // Signed types, including a negative value.
    macro_rules! check_signed {
        ($($t:ty),+ $(,)?) => {$(
            check_take!(
                $t,
                vec![Some(0), None, Some(2), Some(-15), None],
                vec![Some(-15), None, None, Some(-15), Some(2)]
            );
        )+};
    }
    check_signed!(
        Int64Type,
        IntervalYearMonthType,
        DurationSecondType,
        DurationMillisecondType,
        DurationMicrosecondType,
        DurationNanosecondType,
    );

    // Structured interval types.
    let (dt_zero, dt_two, dt_neg) = (
        IntervalDayTime::new(0, 0),
        IntervalDayTime::new(2, 0),
        IntervalDayTime::new(-15, 0),
    );
    check_take!(
        IntervalDayTimeType,
        vec![Some(dt_zero), None, Some(dt_two), Some(dt_neg), None],
        vec![Some(dt_neg), None, None, Some(dt_neg), Some(dt_two)]
    );

    let (mdn_zero, mdn_two, mdn_neg) = (
        IntervalMonthDayNano::new(0, 0, 0),
        IntervalMonthDayNano::new(2, 0, 0),
        IntervalMonthDayNano::new(-15, 0, 0),
    );
    check_take!(
        IntervalMonthDayNanoType,
        vec![Some(mdn_zero), None, Some(mdn_two), Some(mdn_neg), None],
        vec![Some(mdn_neg), None, None, Some(mdn_neg), Some(mdn_two)]
    );

    // Floating point.
    check_take!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
    check_take!(
        Float64Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1544
#[test]
fn test_take_preserve_timezone() {
    let input = TimestampNanosecondArray::from(vec![1_639_715_368_000_000_000; 2])
        .with_timezone("UTC".to_string());
    let index = Int64Array::from(vec![Some(0), None]);

    let result = take(&input, &index, None).unwrap();
    // The timezone attached to the input's data type must survive the take.
    assert_eq!(
        result.data_type(),
        &DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into()))
    );
}
1562
#[test]
fn test_take_impl_primitive_with_int64_indices() {
    let index = Int64Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Runs the generic helper for value type `$t` against the shared Int64 index.
    macro_rules! check {
        ($t:ty, $values:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$t, Int64Type>($values, &index, None, $expected)
        };
    }

    check!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        Int64Type,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        UInt64Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1607
#[test]
fn test_take_impl_primitive_with_uint8_indices() {
    let index = UInt8Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);

    // Runs the generic helper for value type `$t` against the shared UInt8 index.
    macro_rules! check {
        ($t:ty, $values:expr, $expected:expr) => {
            test_take_impl_primitive_arrays::<$t, UInt8Type>($values, &index, None, $expected)
        };
    }

    check!(
        Int16Type,
        vec![Some(0), None, Some(2), Some(3), None],
        vec![Some(3), None, None, Some(3), Some(2)]
    );
    check!(
        DurationMillisecondType,
        vec![Some(0), None, Some(2), Some(-15), None],
        vec![Some(-15), None, None, Some(-15), Some(2)]
    );
    check!(
        Float32Type,
        vec![Some(0.0), None, Some(2.21), Some(-3.1), None],
        vec![Some(-3.1), None, None, Some(-3.1), Some(2.21)]
    );
}
1636
#[test]
fn test_take_bool() {
    let values = vec![Some(false), None, Some(true), Some(false), None];
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
    // Output slot 1 is null from the index; slot 2 selects a null value.
    let expected = vec![Some(false), None, None, Some(false), Some(true)];
    test_take_boolean_arrays(values, &index, None, expected);
}
1648
#[test]
fn test_take_bool_nullable_index() {
    // Null index slots (0, 2, 4) hold values far past the end of the
    // three-element input; only the valid slots (selecting 0, 1, 2) may be used.
    let index = UInt32Array::new(
        ScalarBuffer::from(vec![99_u32, 0, 999, 1, 9999, 2]),
        Some(NullBuffer::from(&[false, true, false, true, false, true])),
    );
    test_take_boolean_arrays(
        vec![Some(true), None, Some(false)],
        &index,
        None,
        vec![None, Some(true), None, None, None, Some(false)],
    );
}
1671
#[test]
fn test_take_bool_nullable_index_nonnull_values() {
    // Same garbage-under-null index as `test_take_bool_nullable_index`, but every
    // input value is valid, so output nulls come only from the index.
    let index = UInt32Array::new(
        ScalarBuffer::from(vec![99_u32, 0, 999, 1, 9999, 2]),
        Some(NullBuffer::from(&[false, true, false, true, false, true])),
    );
    test_take_boolean_arrays(
        vec![Some(true), Some(true), Some(false)],
        &index,
        None,
        vec![None, Some(true), None, Some(true), None, Some(false)],
    );
}
1694
#[test]
fn test_take_bool_with_offset() {
    let full_index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
    // Keep [1, 3, 2, null] so the index carries a non-zero offset.
    let index = full_index.slice(2, 4);

    test_take_boolean_arrays(
        vec![Some(false), None, Some(true), Some(false), None],
        &index,
        None,
        vec![None, Some(false), Some(true), None],
    );
}
1712
1713 fn _test_take_string<'a, K>()
1714 where
1715 K: Array + PartialEq + From<Vec<Option<&'a str>>> + 'static,
1716 {
1717 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4)]);
1718
1719 let array = K::from(vec![
1720 Some("one"),
1721 None,
1722 Some("three"),
1723 Some("four"),
1724 Some("five"),
1725 ]);
1726 let actual = take(&array, &index, None).unwrap();
1727 assert_eq!(actual.len(), index.len());
1728
1729 let actual = actual.as_any().downcast_ref::<K>().unwrap();
1730
1731 let expected = K::from(vec![Some("four"), None, None, Some("four"), Some("five")]);
1732
1733 assert_eq!(actual, &expected);
1734 }
1735
#[test]
fn test_take_string() {
    // Variant with i32 offsets.
    _test_take_string::<StringArray>();
}
1740
#[test]
fn test_take_large_string() {
    // Variant with i64 offsets.
    _test_take_string::<LargeStringArray>();
}
1745
#[test]
fn test_take_slice_string() {
    let strings = StringArray::from(vec![Some("hello"), None, Some("world"), None, Some("hi")]);
    let all_indices = Int32Array::from(vec![Some(0), Some(1), None, Some(0), Some(2)]);
    // Sliced index: [1, null, 0, 2].
    let sliced_indices = all_indices.slice(1, 4);

    let taken = take(&strings, &sliced_indices, None).unwrap();
    let expected = StringArray::from(vec![None, None, Some("hello"), Some("world")]);
    assert_eq!(taken.as_ref(), &expected);
}
1755
#[test]
fn test_take_bytes_sliced_values() {
    let values = StringArray::from(vec![
        Some("aaa"),
        Some("bbb"),
        None,
        Some("ccccc"),
        Some("dd"),
        None,
        Some("eeee"),
    ]);
    // Values with a non-zero offset: [null, "ccccc", "dd", null, "eeee"].
    let sliced = values.slice(2, 5);
    let take_sliced = |indices: &Int32Array| take(&sliced, indices, None).unwrap();

    let all_valid = take_sliced(&Int32Array::from(vec![1, 2, 4, 1]));
    assert_eq!(
        all_valid.as_string::<i32>(),
        &StringArray::from(vec![Some("ccccc"), Some("dd"), Some("eeee"), Some("ccccc")])
    );

    let with_nulls =
        take_sliced(&Int32Array::from(vec![Some(1), None, Some(0), Some(4), Some(3)]));
    assert_eq!(
        with_nulls.as_string::<i32>(),
        &StringArray::from(vec![Some("ccccc"), None, None, Some("eeee"), None])
    );
}
1790
1791 fn _test_byte_view<T>()
1792 where
1793 T: ByteViewType,
1794 str: AsRef<T::Native>,
1795 T::Native: PartialEq,
1796 {
1797 let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(4), Some(2)]);
1798 let array = {
1799 let mut builder = GenericByteViewBuilder::<T>::new();
1801 builder.append_value("hello");
1802 builder.append_value("world");
1803 builder.append_null();
1804 builder.append_value("large payload over 12 bytes");
1805 builder.append_value("lulu");
1806 builder.finish()
1807 };
1808
1809 let actual = take(&array, &index, None).unwrap();
1810
1811 assert_eq!(actual.len(), index.len());
1812
1813 let expected = {
1814 let mut builder = GenericByteViewBuilder::<T>::new();
1816 builder.append_value("large payload over 12 bytes");
1817 builder.append_null();
1818 builder.append_value("world");
1819 builder.append_value("large payload over 12 bytes");
1820 builder.append_value("lulu");
1821 builder.append_null();
1822 builder.finish()
1823 };
1824
1825 assert_eq!(actual.as_ref(), &expected);
1826 }
1827
#[test]
fn test_take_string_view() {
    // Utf8View flavour of the shared byte-view test.
    _test_byte_view::<StringViewType>();
}
1832
#[test]
fn test_take_binary_view() {
    // BinaryView flavour of the shared byte-view test.
    _test_byte_view::<BinaryViewType>();
}
1837
macro_rules! test_take_list {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Source lists (no nulls): [[0, 0, 0], [-1, -2, -1], [], [2, 3]].
        let child = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
        let offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, false)));
        let source = $list_array_type::from(
            ArrayData::builder(data_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&offsets))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(2), Some(0)]);
        let taken = take(&source, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected: [[2, 3], null, [-1, -2, -1], [], [0, 0, 0]]; validity mirrors the index.
        let expected_child = Int32Array::from(vec![
            Some(2),
            Some(3),
            Some(-1),
            Some(-2),
            Some(-1),
            Some(0),
            Some(0),
            Some(0),
        ])
        .into_data();
        let expected_offsets: [$offset_type; 6] = [0, 2, 2, 5, 5, 8];
        let expected = $list_array_type::from(
            ArrayData::builder(data_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
1892
macro_rules! test_take_list_with_value_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Source lists (all valid, with null elements):
        // [[0, null, 0], [-1, -2, 3], [null], [5, null]].
        let child = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            None,
            Some(5),
            None,
        ])
        .into_data();
        let offsets: [$offset_type; 5] = [0, 3, 6, 7, 9];
        let data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
        let source = $list_array_type::from(
            ArrayData::builder(data_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&offsets))
                .null_bit_buffer(Some(Buffer::from([0b11111111])))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
        let taken = take(&source, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected: [[null], null, [-1, -2, 3], [5, null], [0, null, 0]].
        let expected_child = Int32Array::from(vec![
            None,
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let expected_offsets: [$offset_type; 6] = [0, 1, 1, 4, 6, 9];
        let expected = $list_array_type::from(
            ArrayData::builder(data_type)
                .len(5)
                .nulls(index.nulls().cloned())
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
1960
macro_rules! test_take_list_with_nulls {
    ($offset_type:ty, $list_data_type:ident, $list_array_type:ident) => {{
        // Source lists, slot 2 null (bitmap 0b11111011):
        // [[0, null, 0], [-1, -2, 3], null, [5, null]].
        let child = Int32Array::from(vec![
            Some(0),
            None,
            Some(0),
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
        ])
        .into_data();
        let offsets: [$offset_type; 5] = [0, 3, 6, 6, 8];
        let data_type =
            DataType::$list_data_type(Arc::new(Field::new_list_field(DataType::Int32, true)));
        let source = $list_array_type::from(
            ArrayData::builder(data_type.clone())
                .len(4)
                .add_buffer(Buffer::from_slice_ref(&offsets))
                .null_bit_buffer(Some(Buffer::from([0b11111011])))
                .add_child_data(child)
                .build()
                .unwrap(),
        );

        let index = UInt32Array::from(vec![Some(2), None, Some(1), Some(3), Some(0)]);
        let taken = take(&source, &index, None).unwrap();
        let taken: &$list_array_type = taken.as_any().downcast_ref::<$list_array_type>().unwrap();

        // Expected: [null, null, [-1, -2, 3], [5, null], [0, null, 0]].
        let expected_child = Int32Array::from(vec![
            Some(-1),
            Some(-2),
            Some(3),
            Some(5),
            None,
            Some(0),
            None,
            Some(0),
        ])
        .into_data();
        let expected_offsets: [$offset_type; 6] = [0, 0, 0, 3, 5, 8];
        // Validity bits 2, 3 and 4 set.
        let expected_validity: [u8; 1] = [0b0001_1100];
        let expected = $list_array_type::from(
            ArrayData::builder(data_type)
                .len(5)
                .null_bit_buffer(Some(Buffer::from(expected_validity)))
                .add_buffer(Buffer::from_slice_ref(&expected_offsets))
                .add_child_data(expected_child)
                .build()
                .unwrap(),
        );

        assert_eq!(taken, &expected);
    }};
}
2030
2031 fn test_take_list_view_generic<OffsetType: OffsetSizeTrait, ValuesType: ArrowPrimitiveType, F>(
2032 values: Vec<Option<Vec<Option<ValuesType::Native>>>>,
2033 take_indices: Vec<Option<usize>>,
2034 expected: Vec<Option<Vec<Option<ValuesType::Native>>>>,
2035 mapper: F,
2036 ) where
2037 F: Fn(GenericListViewArray<OffsetType>) -> GenericListViewArray<OffsetType>,
2038 {
2039 let mut list_view_array =
2040 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
2041
2042 for value in values {
2043 list_view_array.append_option(value);
2044 }
2045 let list_view_array = list_view_array.finish();
2046 let list_view_array = mapper(list_view_array);
2047
2048 let mut indices = UInt64Builder::new();
2049 for idx in take_indices {
2050 indices.append_option(idx.map(|i| i.to_u64().unwrap()));
2051 }
2052 let indices = indices.finish();
2053
2054 let taken = take(&list_view_array, &indices, None)
2055 .unwrap()
2056 .as_list_view()
2057 .clone();
2058
2059 let mut expected_array =
2060 GenericListViewBuilder::<OffsetType, _>::new(PrimitiveBuilder::<ValuesType>::new());
2061 for value in expected {
2062 expected_array.append_option(value);
2063 }
2064 let expected_array = expected_array.finish();
2065
2066 assert_eq!(taken, expected_array);
2067 }
2068
/// Runs `test_take_list_view_generic` for both i32 and i64 offsets.
/// The form without `transform` uses the identity mapping.
macro_rules! list_view_test_case {
    (values: $values:expr, indices: $indices:expr, expected: $expected: expr) => {{
        list_view_test_case!(
            values: $values,
            transform: |x| x,
            indices: $indices,
            expected: $expected
        )
    }};
    (values: $values:expr, transform: $fn:expr, indices: $indices:expr, expected: $expected: expr) => {{
        test_take_list_view_generic::<i32, Int8Type, _>($values, $indices, $expected, $fn);
        test_take_list_view_generic::<i64, Int8Type, _>($values, $indices, $expected, $fn);
    }};
}
2079
2080 fn do_take_fixed_size_list_test<T>(
2081 length: <Int32Type as ArrowPrimitiveType>::Native,
2082 input_data: Vec<Option<Vec<Option<T::Native>>>>,
2083 indices: Vec<<UInt32Type as ArrowPrimitiveType>::Native>,
2084 expected_data: Vec<Option<Vec<Option<T::Native>>>>,
2085 ) where
2086 T: ArrowPrimitiveType,
2087 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2088 {
2089 let indices = UInt32Array::from(indices);
2090
2091 let input_array = FixedSizeListArray::from_iter_primitive::<T, _, _>(input_data, length);
2092
2093 let output = take_fixed_size_list(&input_array, &indices, length as u32).unwrap();
2094
2095 let expected = FixedSizeListArray::from_iter_primitive::<T, _, _>(expected_data, length);
2096
2097 assert_eq!(&output, &expected)
2098 }
2099
#[test]
fn test_take_list() {
    // `List` with i32 offsets.
    test_take_list! { i32, List, ListArray }
}
2104
#[test]
fn test_take_large_list() {
    // `LargeList` with i64 offsets.
    test_take_list! { i64, LargeList, LargeListArray }
}
2109
#[test]
fn test_take_list_with_value_nulls() {
    // `List` with i32 offsets.
    test_take_list_with_value_nulls! { i32, List, ListArray }
}
2114
#[test]
fn test_take_large_list_with_value_nulls() {
    // `LargeList` with i64 offsets.
    test_take_list_with_value_nulls! { i64, LargeList, LargeListArray }
}
2119
#[test]
fn test_test_take_list_with_nulls() {
    // `List` with i32 offsets.
    test_take_list_with_nulls! { i32, List, ListArray }
}
2124
#[test]
fn test_test_take_large_list_with_nulls() {
    // `LargeList` with i64 offsets.
    test_take_list_with_nulls! { i64, LargeList, LargeListArray }
}
2129
#[test]
fn test_test_take_list_view_reversed() {
    // Reverse three rows; the middle row is null.
    let first = Some(vec![Some(1), None, Some(3)]);
    let last = Some(vec![Some(7), Some(8), None]);
    list_view_test_case! {
        values: vec![first.clone(), None, last.clone()],
        indices: vec![Some(2), Some(1), Some(0)],
        expected: vec![last.clone(), None, first.clone()]
    }
}
2147
#[test]
fn test_take_list_view_null_indices() {
    // Null index slots yield null rows regardless of the source.
    let first = Some(vec![Some(1), None, Some(3)]);
    let last = Some(vec![Some(7), Some(8), None]);
    list_view_test_case! {
        values: vec![first.clone(), None, last.clone()],
        indices: vec![None, Some(0), None],
        expected: vec![None, first.clone(), None]
    }
}
2161
#[test]
fn test_take_list_view_null_values() {
    // Repeatedly selecting the null row (index 1), plus null index slots:
    // every output row is null.
    let first = Some(vec![Some(1), None, Some(3)]);
    let last = Some(vec![Some(7), Some(8), None]);
    list_view_test_case! {
        values: vec![first.clone(), None, last.clone()],
        indices: [vec![Some(1); 3], vec![None; 2]].concat(),
        expected: (0..5).map(|_| None).collect()
    }
}
2175
#[test]
fn test_take_list_view_sliced() {
    // After slice(2, 4) the source rows are [null, [2, 3], [4, 5], null].
    list_view_test_case! {
        values: vec![
            Some(vec![Some(1)]),
            None,
            None,
            Some(vec![Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
            None,
        ],
        transform: |array| array.slice(2, 4),
        indices: vec![Some(0), Some(3), None, Some(1), Some(2)],
        expected: vec![
            None,
            None,
            None,
            Some(vec![Some(2), Some(3)]),
            Some(vec![Some(4), Some(5)]),
        ]
    }
}
2195
#[test]
fn test_take_fixed_size_list() {
    // Width 3, reversing rows that contain null elements.
    let row0 = Some(vec![None, Some(1), Some(2)]);
    let row1 = Some(vec![Some(3), Some(4), None]);
    let row2 = Some(vec![Some(6), Some(7), Some(8)]);
    do_take_fixed_size_list_test::<Int32Type>(
        3,
        vec![row0.clone(), row1.clone(), row2.clone()],
        vec![2, 1, 0],
        vec![row2, row1, row0],
    );

    // Width 1, rows [1] through [8].
    do_take_fixed_size_list_test::<UInt8Type>(
        1,
        (1..=8).map(|v| Some(vec![Some(v)])).collect(),
        vec![2, 7, 0],
        vec![Some(vec![Some(3)]), Some(vec![Some(8)]), Some(vec![Some(1)])],
    );

    // Width 3 with a null row (index 2) selected twice.
    let first = Some(vec![Some(10), Some(11), Some(12)]);
    let second = Some(vec![Some(13), Some(14), Some(15)]);
    let fourth = Some(vec![Some(16), Some(17), Some(18)]);
    do_take_fixed_size_list_test::<UInt64Type>(
        3,
        vec![first.clone(), second.clone(), None, fourth.clone()],
        vec![3, 2, 1, 2, 0],
        vec![fourth, None, second, None, first],
    );
}
2251
#[test]
fn test_take_fixed_size_binary_with_nulls_indices() {
    // Four 4-byte values; row `b` is [b, b, b, b].
    let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
        (1u8..=4).map(|b| Some(vec![b; 4])),
        4,
    )
    .unwrap();
    let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);

    let result = take_fixed_size_binary(&fsb, &indices, 4).unwrap();
    assert_eq!(result.len(), 4);
    assert_eq!(result.null_count(), 2);
    let validity: Vec<bool> = result.nulls().unwrap().iter().collect();
    assert_eq!(validity, vec![true, false, false, true]);
}
2277
#[test]
fn test_take_fixed_size_binary_with_nulls_indices_not_optimized_length() {
    // Same as the 4-byte case but 5 bytes wide; row `b` is [b, b, b, b, 0x01].
    // NOTE(review): presumably a width with no specialised code path — confirm in the kernel.
    let fsb = FixedSizeBinaryArray::try_from_sparse_iter_with_size(
        (1u8..=4).map(|b| Some(vec![b, b, b, b, 0x01])),
        5,
    )
    .unwrap();
    let indices = UInt32Array::from(vec![Some(0), None, None, Some(3)]);

    let result = take_fixed_size_binary(&fsb, &indices, 5).unwrap();
    assert_eq!(result.len(), 4);
    assert_eq!(result.null_count(), 2);
    let validity: Vec<bool> = result.nulls().unwrap().iter().collect();
    assert_eq!(validity, vec![true, false, false, true]);
}
2306
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_list_out_of_bounds() {
    // Three lists over eight child values: [[0, 0, 0], [-1, -2, -1], [2, 3]].
    let child = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 3]).into_data();
    let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, false)));
    let list_array = ListArray::from(
        ArrayData::builder(data_type)
            .len(3)
            .add_buffer(Buffer::from_slice_ref([0, 3, 6, 8]))
            .add_child_data(child)
            .build()
            .unwrap(),
    );

    // Bounds checking is not requested, so an out-of-range index panics
    // instead of producing an error.
    take(&list_array, &UInt32Array::from(vec![1000]), None).unwrap();
}
2331
#[test]
fn test_take_map() {
    let values = Int32Array::from(vec![1, 2, 3, 4]);
    // Two maps: {a: 1, b: 2, c: 3} and {a: 4}.
    let maps =
        MapArray::new_from_strings(vec!["a", "b", "c", "a"].into_iter(), &values, &[0, 3, 4])
            .unwrap();

    let taken = take(&maps, &UInt32Array::from(vec![0]), None).unwrap();

    let first_map = MapArray::new_from_strings(
        vec!["a", "b", "c"].into_iter(),
        &values.slice(0, 3),
        &[0, 3],
    )
    .unwrap();
    let expected: ArrayRef = Arc::new(first_map);
    assert_eq!(&expected, &taken);
}
2352
#[test]
fn test_take_struct() {
    let array = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);

    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_struct();
    assert_eq!(taken.len(), index.len());
    // Only the null struct at position 4 contributes a null.
    assert_eq!(taken.null_count(), 1);

    let expected = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        Some((Some(true), Some(42))),
        Some((Some(false), Some(19))),
        None,
    ]);
    assert_eq!(&expected, taken);

    // A struct with no child fields still carries its own validity through take.
    let empty_fields = StructArray::new_empty_fields(
        6,
        Some(NullBuffer::from(&[false, true, false, true, false, true])),
    );
    let taken_empty = take(&empty_fields, &UInt32Array::from(vec![0, 2, 1, 4]), None).unwrap();
    let expected_empty = StructArray::new_empty_fields(
        4,
        Some(NullBuffer::from(&[false, false, true, false])),
    );
    assert_eq!(&expected_empty, taken_empty.as_struct());
}
2389
#[test]
fn test_take_struct_with_null_indices() {
    let array = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    // Null index slots 0 and 3, plus the null struct selected by index 4,
    // give three output nulls.
    let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);

    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_struct();
    assert_eq!(taken.len(), index.len());
    assert_eq!(taken.null_count(), 3);

    let expected = create_test_struct(vec![
        None,
        Some((Some(true), Some(31))),
        Some((Some(false), Some(28))),
        None,
        Some((Some(true), Some(42))),
        None,
    ]);
    assert_eq!(&expected, taken);
}
2417
#[test]
fn test_take_out_of_bounds() {
    // Index 6 is past the end of the five-element input.
    let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(6)]);
    let checked = Some(TakeOptions { check_bounds: true });

    let outcome = test_take_primitive_arrays::<Int64Type>(
        vec![Some(0), None, Some(2), Some(3), None],
        &index,
        checked,
        vec![None],
    );
    // With bounds checking enabled the failure is reported as an error.
    assert!(outcome.is_err());
}
2432
#[test]
#[should_panic(expected = "index out of bounds: the len is 4 but the index is 1000")]
fn test_take_out_of_bounds_panic() {
    // Without bounds checking, index 1000 into four values panics.
    let values = vec![Some(0), Some(1), Some(2), Some(3)];
    let index = UInt32Array::from(vec![Some(1000)]);
    test_take_primitive_arrays::<Int64Type>(values, &index, None, vec![None]).unwrap();
}
2446
#[test]
fn test_null_array_smaller_than_indices() {
    // Without bounds checking, index 15 is accepted for a two-element NullArray;
    // the output length follows the indices.
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let result = take(&NullArray::new(2), &indices, None).unwrap();
    assert_eq!(&result, &(Arc::new(NullArray::new(3)) as ArrayRef));
}
2456
#[test]
fn test_null_array_larger_than_indices() {
    // The output length follows the three indices, not the five input slots.
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);
    let result = take(&NullArray::new(5), &indices, None).unwrap();
    assert_eq!(&result, &(Arc::new(NullArray::new(3)) as ArrayRef));
}
2466
/// With `check_bounds` enabled, even a `NullArray` source rejects an index
/// beyond its length, and the error names the index and the length.
#[test]
fn test_null_array_indices_out_of_bounds() {
    let checked = Some(TakeOptions { check_bounds: true });
    let indices = UInt32Array::from(vec![Some(0), None, Some(15)]);

    let err = take(&NullArray::new(5), &indices, checked).unwrap_err();
    assert_eq!(
        err.to_string(),
        "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries"
    );
}
2478
/// Taking from a dictionary array remaps keys only. The result's dictionary
/// values equal the input's, and both null indices and null source slots
/// produce null keys.
#[test]
fn test_take_dict() {
    let mut dict_builder = StringDictionaryBuilder::<Int16Type>::new();
    for value in ["foo", "bar", ""] {
        dict_builder.append(value).unwrap();
    }
    dict_builder.append_null();
    for value in ["foo", "bar", "bar", "foo"] {
        dict_builder.append(value).unwrap();
    }
    let array = dict_builder.finish();
    let dict_values = array
        .values()
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    let indices = UInt32Array::from(vec![
        Some(0), // "foo"
        Some(7), // "foo"
        None,
        Some(5), // "bar"
        Some(6), // "bar"
        Some(2), // ""
        Some(3), // null slot in the source
    ]);

    let taken = take(&array, &indices, None).unwrap();
    let taken = taken
        .as_any()
        .downcast_ref::<DictionaryArray<Int16Type>>()
        .unwrap();
    let taken_values: StringArray = taken.values().to_data().into();

    let expected_values = StringArray::from(vec!["foo", "bar", ""]);
    assert_eq!(&expected_values, dict_values);
    assert_eq!(&expected_values, &taken_values);

    let expected_keys =
        Int16Array::from(vec![Some(0), Some(0), None, Some(1), Some(1), Some(2), None]);
    assert_eq!(taken.keys(), &expected_keys);
}
2530
2531 fn build_generic_list<S, T>(data: Vec<Option<Vec<T::Native>>>) -> GenericListArray<S>
2532 where
2533 S: OffsetSizeTrait + 'static,
2534 T: ArrowPrimitiveType,
2535 PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
2536 {
2537 GenericListArray::from_iter_primitive::<T, _, _>(
2538 data.iter()
2539 .map(|x| x.as_ref().map(|x| x.iter().map(|x| Some(*x)))),
2540 )
2541 }
2542
2543 fn test_take_sliced_list_generic<S: OffsetSizeTrait + 'static>() {
2544 let list = build_generic_list::<S, Int32Type>(vec![
2545 Some(vec![0, 1]),
2546 Some(vec![2, 3, 4]),
2547 None,
2548 Some(vec![]),
2549 Some(vec![5, 6]),
2550 Some(vec![7]),
2551 ]);
2552 let sliced = list.slice(1, 4);
2553 let indices = UInt32Array::from(vec![Some(3), Some(0), None, Some(2), Some(1)]);
2554
2555 let taken = take(&sliced, &indices, None).unwrap();
2556 let taken = taken.as_list::<S>();
2557
2558 let expected = build_generic_list::<S, Int32Type>(vec![
2559 Some(vec![5, 6]),
2560 Some(vec![2, 3, 4]),
2561 None,
2562 Some(vec![]),
2563 None,
2564 ]);
2565
2566 assert_eq!(taken, &expected);
2567 }
2568
2569 fn test_take_sliced_list_with_value_nulls_generic<S: OffsetSizeTrait + 'static>() {
2570 let list = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
2571 Some(vec![Some(10)]),
2572 Some(vec![None, Some(1)]),
2573 None,
2574 Some(vec![Some(2), None]),
2575 Some(vec![]),
2576 Some(vec![Some(3)]),
2577 ]);
2578 let sliced = list.slice(1, 4);
2579 let indices = UInt32Array::from(vec![Some(2), Some(0), None, Some(3), Some(1)]);
2580
2581 let taken = take(&sliced, &indices, None).unwrap();
2582 let taken = taken.as_list::<S>();
2583
2584 let expected = GenericListArray::<S>::from_iter_primitive::<Int32Type, _, _>(vec![
2585 Some(vec![Some(2), None]),
2586 Some(vec![None, Some(1)]),
2587 None,
2588 Some(vec![]),
2589 None,
2590 ]);
2591
2592 assert_eq!(taken, &expected);
2593 }
2594
/// Sliced-list take with 32-bit offsets (`ListArray`).
#[test]
fn test_take_sliced_list() {
    test_take_sliced_list_generic::<i32>()
}
2599
/// Sliced-list take with 64-bit offsets (`LargeListArray`).
#[test]
fn test_take_sliced_large_list() {
    test_take_sliced_list_generic::<i64>()
}
2604
/// Sliced-list take with element-level nulls, 32-bit offsets.
#[test]
fn test_take_sliced_list_with_value_nulls() {
    test_take_sliced_list_with_value_nulls_generic::<i32>()
}
2609
/// Sliced-list take with element-level nulls, 64-bit offsets.
#[test]
fn test_take_sliced_large_list_with_value_nulls() {
    test_take_sliced_list_with_value_nulls_generic::<i64>()
}
2614
/// `take_run` output as pinned by the assertions below. Consecutive indices that
/// land in the same source run (2 and 3; 4 and 6) share one output run. Indices
/// from different source runs stay in separate runs even when their values are
/// equal.
#[test]
fn test_take_runs() {
    let logical_values = [1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
    let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    builder.extend(logical_values.iter().copied().map(Some));
    let run_array = builder.finish();

    let take_indices: PrimitiveArray<Int32Type> =
        [7, 2, 3, 7, 11, 4, 6].into_iter().collect();
    let taken = take_run(&run_array, &take_indices).unwrap();

    assert_eq!(taken.len(), 7);
    assert_eq!(taken.run_ends().len(), 7);
    assert_eq!(taken.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
    assert_eq!(
        taken.values().as_primitive::<Int32Type>().values(),
        &[2, 2, 2, 2, 1]
    );
}
2635
/// `take_run` on a sliced run array resolves indices relative to the slice.
#[test]
fn test_take_runs_sliced() {
    let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    builder.extend([1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 6, 6].into_iter().map(Some));
    // Logical view after slicing: [3, 3, 3, 4, 4, 5]
    let run_array = builder.finish().slice(4, 6);

    let take_indices: PrimitiveArray<Int32Type> = [0, 5, 5, 1, 4].into_iter().collect();
    let taken = take_run(&run_array, &take_indices).unwrap();
    let taken = taken.downcast::<Int32Array>().unwrap();

    let actual: Vec<i32> = taken.into_iter().flatten().collect();
    assert_eq!(vec![3, 5, 5, 3, 4], actual);
}
2656
/// For lists of size 3, each list index `i` expands to the child value indices
/// `3*i, 3*i + 1, 3*i + 2`, emitted in the order of the list indices. Repeats
/// are allowed, and null lists are expanded the same way.
#[test]
fn test_take_value_index_from_fixed_list() {
    let list = FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
        vec![
            Some(vec![Some(1), Some(2), None]),
            Some(vec![Some(4), None, Some(6)]),
            None,
            Some(vec![None, Some(8), Some(9)]),
        ],
        3,
    );

    let cases: [(Vec<u32>, Vec<u32>); 2] = [
        (vec![2, 1, 0], vec![6, 7, 8, 3, 4, 5, 0, 1, 2]),
        (
            vec![3, 2, 1, 2, 0],
            vec![9, 10, 11, 6, 7, 8, 3, 4, 5, 6, 7, 8, 0, 1, 2],
        ),
    ];
    for (list_indices, expected) in cases {
        let indices = UInt32Array::from(list_indices);
        let indexed = take_value_indices_from_fixed_size_list(&list, &indices, 3).unwrap();
        assert_eq!(indexed, UInt32Array::from(expected));
    }
}
2682
/// Null index slots yield null outputs, even when the value stored in the
/// index slot (400) is out of range for the values array.
#[test]
fn test_take_null_indices() {
    let validity = NullBuffer::from(vec![true, true, false, false]);
    let indices = Int32Array::new(vec![1, 2, 400, 400].into(), Some(validity));
    let source = Int32Array::from(vec![1, 23, 4, 5]);

    let taken = take(&source, &indices, None).unwrap();
    let gathered: Vec<Option<i32>> = taken.as_primitive::<Int32Type>().iter().collect();
    assert_eq!(gathered, vec![Some(23), Some(4), None, None]);
}
2698
/// A null index into a fixed-size list produces a null list. Its child slots
/// read back as nulls.
#[test]
fn test_take_fixed_size_list_null_indices() {
    let child = Arc::new(Int32Array::from(vec![0, 1, 2, 3]));
    let item_field = Arc::new(Field::new_list_field(child.data_type().clone(), true));
    // Two lists of size 2: [0, 1] and [2, 3].
    let lists = FixedSizeListArray::try_new(item_field, 2, child, None).unwrap();
    let indices = Int32Array::from_iter([Some(0), None]);

    let taken = take(&lists, &indices, None).unwrap();
    let child_values: Vec<_> = taken
        .as_fixed_size_list()
        .values()
        .as_primitive::<Int32Type>()
        .iter()
        .collect();
    assert_eq!(child_values, &[Some(0), Some(1), None, None])
}
2715
/// For byte arrays, null index slots yield nulls regardless of the value stored
/// in the slot (400). A null source value taken by a valid index stays null.
#[test]
fn test_take_bytes_null_indices() {
    let validity = NullBuffer::from_iter(vec![true, true, false, false]);
    let indices = Int32Array::new(vec![0, 1, 400, 400].into(), Some(validity));
    let source = StringArray::from(vec![Some("foo"), None]);

    let taken = take(&source, &indices, None).unwrap();
    let gathered: Vec<Option<&str>> = taken.as_string::<i32>().iter().collect();
    assert_eq!(gathered, vec![Some("foo"), None, None, None]);
}
2727
/// Sparse union take where every slot has type id 1. The string child of the
/// result holds the strings gathered at the requested positions.
#[test]
fn test_take_union_sparse() {
    let structs = create_test_struct(vec![
        Some((Some(true), Some(42))),
        Some((Some(false), Some(28))),
        Some((Some(false), Some(19))),
        Some((Some(true), Some(31))),
        None,
    ]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), None, Some("d")]);

    let struct_field = Arc::new(Field::new("f1", structs.data_type().clone(), true));
    let string_field = Arc::new(Field::new("f2", strings.data_type().clone(), true));
    let union_fields = [(0, struct_field), (1, string_field)].into_iter().collect();
    let type_ids = ScalarBuffer::<i8>::from(vec![1_i8; 5]);
    let children = vec![Arc::new(structs) as ArrayRef, Arc::new(strings)];
    let array = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    let string_child = taken.child(1);
    let gathered: Vec<_> = string_child
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap()
        .iter()
        .collect();
    assert_eq!(
        vec![Some("a"), None, None, Some("a"), Some("c"), Some("d")],
        gathered
    );
}
2766
/// Dense union take. The result's type ids follow the taken slots. Each child is
/// rebuilt from only the selected values, in order of appearance and including
/// repeats. Offsets are renumbered to point into the rebuilt children.
#[test]
fn test_take_union_dense() {
    // Source: slot i has type id TYPE_IDS[i] and points at child[OFFSETS[i]].
    //   type ids: [0, 1, 1, 0, 0, 1, 0]
    //   offsets:  [0, 0, 1, 1, 2, 2, 3]
    let ints = UInt32Array::from(vec![10, 20, 30, 40]);
    let strings = StringArray::from(vec![Some("a"), None, Some("c"), Some("d")]);

    let union_fields = [
        (0, Arc::new(Field::new("f1", ints.data_type().clone(), true))),
        (1, Arc::new(Field::new("f2", strings.data_type().clone(), true))),
    ]
    .into_iter()
    .collect();

    let array = UnionArray::try_new(
        union_fields,
        ScalarBuffer::<i8>::from(vec![0, 1, 1, 0, 0, 1, 0]),
        Some(ScalarBuffer::<i32>::from(vec![0, 0, 1, 1, 2, 2, 3])),
        vec![Arc::new(ints), Arc::new(strings)],
    )
    .unwrap();

    let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
    let taken = take(&array, &index, None).unwrap();
    let taken = taken.as_any().downcast_ref::<UnionArray>().unwrap();

    assert_eq!(
        taken.offsets(),
        Some(&ScalarBuffer::from(vec![0_i32, 1, 0, 2, 1, 3]))
    );
    assert_eq!(
        taken.type_ids(),
        &ScalarBuffer::from(vec![0_i8, 0, 1, 0, 1, 0])
    );
    assert_eq!(
        UInt32Array::from(taken.child(0).to_data()),
        UInt32Array::from(vec![10, 20, 10, 30])
    );
    assert_eq!(
        StringArray::from(taken.child(1).to_data()),
        StringArray::from(vec![Some("a"), None])
    );
}
2823
/// Dense union take compared against a union built directly from the expected
/// appends, i.e. the source appends replayed in index order.
#[test]
fn test_take_union_dense_using_builder() {
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("a", 1).unwrap();
    source.append::<Float64Type>("b", 3.0).unwrap();
    source.append::<Int32Type>("a", 4).unwrap();
    source.append::<Int32Type>("a", 5).unwrap();
    source.append::<Float64Type>("b", 2.0).unwrap();
    let union = source.build().unwrap();

    // Indices [2, 0, 1, 2] select a:4, a:1, b:3.0, a:4.
    let mut expected = UnionBuilder::new_dense();
    expected.append::<Int32Type>("a", 4).unwrap();
    expected.append::<Int32Type>("a", 1).unwrap();
    expected.append::<Float64Type>("b", 3.0).unwrap();
    expected.append::<Int32Type>("a", 4).unwrap();
    let expected = expected.build().unwrap();

    let indices = UInt32Array::from(vec![2, 0, 1, 2]);
    let taken = take(&union, &indices, None).unwrap();
    assert_eq!(expected.to_data(), taken.to_data());
}
2852
/// Regression test for arrow-rs issue 6206. It takes from a dense union whose
/// slots all have type id 0, the only child, and checks the output length.
#[test]
fn test_take_union_dense_all_match_issue_6206() {
    let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]);
    let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
    let type_ids = ScalarBuffer::from(vec![0_i8; 5]);
    let offsets = ScalarBuffer::from_iter(0_i32..5);

    let array = UnionArray::try_new(fields, type_ids, Some(offsets), vec![ints]).unwrap();

    let indices = Int64Array::from(vec![0, 2, 4]);
    let taken = take(&array, &indices, None).unwrap();
    assert_eq!(taken.len(), 3);
}
2870
2871 fn offset_overflow_fixture() -> (StringArray, usize) {
2876 let value_len = 1_000_000;
2877 let values = StringArray::from(vec![Some("a".repeat(value_len))]);
2878 let n = i32::MAX as usize / value_len + 1;
2879 (values, n)
2880 }
2881
/// Repeating a large string past `i32` offset capacity returns an
/// `OffsetOverflowError` instead of panicking.
#[test]
fn test_take_bytes_offset_overflow() {
    let (values, n) = offset_overflow_fixture();
    // Every index selects the single large value.
    let result = take(&values, &Int32Array::from(vec![0; n]), None);
    assert!(matches!(result, Err(ArrowError::OffsetOverflowError(_))));
}
2891
/// Offset overflow must also be reported when the index array has a validity
/// buffer. The indices are one leading null followed by `n` valid indices, all
/// pointing at the large value.
#[test]
fn test_take_bytes_offset_overflow_nullable() {
    let (values, n) = offset_overflow_fixture();

    let mut validity = vec![false];
    validity.extend(std::iter::repeat_n(true, n));
    let indices = Int32Array::new(vec![0i32; n + 1].into(), Some(NullBuffer::from(validity)));

    let result = take(&values, &indices, None);
    assert!(matches!(result, Err(ArrowError::OffsetOverflowError(_))));
}
2908
/// Taking from a run array with an empty index array succeeds. The result is an
/// empty `RunArray` with no nulls, no run ends and no values.
#[test]
fn test_take_run_empty_indices() {
    let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
    builder.extend([Some(1), Some(1), Some(2), Some(2)]);
    let run_array = builder.finish();
    let no_indices = PrimitiveArray::<Int32Type>::from(Vec::<i32>::new());

    let result = take_impl(&run_array, &no_indices).expect("take_run with empty indices");
    assert_eq!(result.len(), 0);
    assert_eq!(result.null_count(), 0);

    let runs = result
        .as_any()
        .downcast_ref::<RunArray<Int32Type>>()
        .expect("result should be a RunArray");
    assert_eq!(runs.run_ends().len(), 0);
    assert_eq!(runs.values().len(), 0);
}
2932}