Skip to main content

polars_core/chunked_array/ops/
gather.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use std::sync::OnceLock;
3
4use arrow::bitmap::Bitmap;
5use arrow::bitmap::bitmask::BitMask;
6use polars_compute::gather::take_unchecked;
7use polars_error::polars_ensure;
8use polars_utils::index::check_bounds;
9
10use crate::prelude::*;
11use crate::series::IsSorted;
12use crate::utils::Container;
13
14pub fn check_bounds_nulls(idx: &PrimitiveArray<IdxSize>, len: IdxSize) -> PolarsResult<()> {
15    let mask = BitMask::from_bitmap(idx.validity().unwrap());
16
17    // We iterate in chunks to make the inner loop branch-free.
18    for (block_idx, block) in idx.values().chunks(32).enumerate() {
19        let mut in_bounds = 0;
20        for (i, x) in block.iter().enumerate() {
21            in_bounds |= ((*x < len) as u32) << i;
22        }
23        let m = mask.get_u32(32 * block_idx);
24        polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds");
25    }
26    Ok(())
27}
28
29pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> {
30    let all_valid = indices.downcast_iter().all(|a| {
31        if a.null_count() == 0 {
32            check_bounds(a.values(), len).is_ok()
33        } else {
34            check_bounds_nulls(a, len).is_ok()
35        }
36    });
37    polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds");
38    Ok(())
39}
40
41impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTake<I> for ChunkedArray<T>
42where
43    ChunkedArray<T>: ChunkTakeUnchecked<I>,
44{
45    /// Gather values from ChunkedArray by index.
46    fn take(&self, indices: &I) -> PolarsResult<Self> {
47        check_bounds(indices.as_ref(), self.len() as IdxSize)?;
48
49        // SAFETY: we just checked the indices are valid.
50        Ok(unsafe { self.take_unchecked(indices) })
51    }
52}
53
54impl<T: PolarsDataType> ChunkTake<IdxCa> for ChunkedArray<T>
55where
56    ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
57{
58    /// Gather values from ChunkedArray by index.
59    fn take(&self, indices: &IdxCa) -> PolarsResult<Self> {
60        check_bounds_ca(indices, self.len() as IdxSize)?;
61
62        // SAFETY: we just checked the indices are valid.
63        Ok(unsafe { self.take_unchecked(indices) })
64    }
65}
66
67/// Computes cumulative lengths for efficient branchless binary search
68/// lookup. The first element is always 0, and the last length of arrs
69/// is always ignored (as we already checked that all indices are
70/// in-bounds we don't need to check against the last length).
71fn cumulative_lengths<A: StaticArray>(arrs: &[&A]) -> Vec<IdxSize> {
72    let mut ret = Vec::with_capacity(arrs.len());
73    let mut cumsum: IdxSize = 0;
74    for arr in arrs {
75        ret.push(cumsum);
76        cumsum = cumsum.checked_add(arr.len().try_into().unwrap()).unwrap();
77    }
78    ret
79}
80
81#[rustfmt::skip]
82#[inline]
83fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize]) -> (usize, usize) {
84    let chunk_idx = cumlens.partition_point(|cl| idx >= *cl) - 1;
85    (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
86}
87
88#[inline]
89unsafe fn target_value_unchecked<'a, A: StaticArray>(
90    targets: &[&'a A],
91    cumlens: &[IdxSize],
92    idx: IdxSize,
93) -> A::ValueT<'a> {
94    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
95    let arr = targets.get_unchecked(chunk_idx);
96    arr.value_unchecked(arr_idx)
97}
98
99#[inline]
100unsafe fn target_get_unchecked<'a, A: StaticArray>(
101    targets: &[&'a A],
102    cumlens: &[IdxSize],
103    idx: IdxSize,
104) -> Option<A::ValueT<'a>> {
105    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
106    let arr = targets.get_unchecked(chunk_idx);
107    arr.get_unchecked(arr_idx)
108}
109
110unsafe fn gather_idx_array_unchecked<A: StaticArray>(
111    dtype: ArrowDataType,
112    targets: &[&A],
113    has_nulls: bool,
114    indices: &[IdxSize],
115) -> A {
116    let it = indices.iter().copied();
117    if targets.len() == 1 {
118        let target = targets.first().unwrap();
119        if has_nulls {
120            it.map(|i| target.get_unchecked(i as usize))
121                .collect_arr_trusted_with_dtype(dtype)
122        } else if let Some(sl) = target.as_slice() {
123            // Avoid the Arc overhead from value_unchecked.
124            it.map(|i| sl.get_unchecked(i as usize).clone())
125                .collect_arr_trusted_with_dtype(dtype)
126        } else {
127            it.map(|i| target.value_unchecked(i as usize))
128                .collect_arr_trusted_with_dtype(dtype)
129        }
130    } else {
131        let cumlens = cumulative_lengths(targets);
132        if has_nulls {
133            it.map(|i| target_get_unchecked(targets, &cumlens, i))
134                .collect_arr_trusted_with_dtype(dtype)
135        } else {
136            it.map(|i| target_value_unchecked(targets, &cumlens, i))
137                .collect_arr_trusted_with_dtype(dtype)
138        }
139    }
140}
141
142impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
143where
144    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
145{
146    /// Gather values from ChunkedArray by index.
147    unsafe fn take_unchecked(&self, indices: &I) -> Self {
148        let ca = self;
149        let targets: Vec<_> = ca.downcast_iter().collect();
150        let arr = gather_idx_array_unchecked(
151            ca.dtype().to_arrow(CompatLevel::newest()),
152            &targets,
153            ca.null_count() > 0,
154            indices.as_ref(),
155        );
156        ChunkedArray::from_chunk_iter_like(ca, [arr])
157    }
158}
159
160pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
161    use crate::series::IsSorted::*;
162    match (sorted_arr, sorted_idx) {
163        (_, Not) => Not,
164        (Not, _) => Not,
165        (Ascending, Ascending) => Ascending,
166        (Ascending, Descending) => Descending,
167        (Descending, Ascending) => Descending,
168        (Descending, Descending) => Ascending,
169    }
170}
171
172impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
173where
174    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
175{
176    /// Gather values from ChunkedArray by index.
177    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
178        let ca = self;
179        let targets_have_nulls = ca.null_count() > 0;
180        let targets: Vec<_> = ca.downcast_iter().collect();
181
182        let chunks = indices.downcast_iter().map(|idx_arr| {
183            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
184            if idx_arr.null_count() == 0 {
185                gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
186            } else if targets.len() == 1 {
187                let target = targets.first().unwrap();
188                if targets_have_nulls {
189                    idx_arr
190                        .iter()
191                        .map(|i| target.get_unchecked(*i? as usize))
192                        .collect_arr_trusted_with_dtype(dtype)
193                } else {
194                    idx_arr
195                        .iter()
196                        .map(|i| Some(target.value_unchecked(*i? as usize)))
197                        .collect_arr_trusted_with_dtype(dtype)
198                }
199            } else {
200                let cumlens = cumulative_lengths(&targets);
201                if targets_have_nulls {
202                    idx_arr
203                        .iter()
204                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
205                        .collect_arr_trusted_with_dtype(dtype)
206                } else {
207                    idx_arr
208                        .iter()
209                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
210                        .collect_arr_trusted_with_dtype(dtype)
211                }
212            }
213        });
214
215        let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks);
216        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
217
218        out.set_sorted_flag(sorted_flag);
219        out
220    }
221}
222
223impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {
224    /// Gather values from ChunkedArray by index.
225    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
226        let ca = self;
227        let targets_have_nulls = ca.null_count() > 0;
228        let targets: Vec<_> = ca.downcast_iter().collect();
229
230        let chunks = indices.downcast_iter().map(|idx_arr| {
231            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
232            if targets.len() == 1 {
233                let target = targets.first().unwrap();
234                take_unchecked(&**target, idx_arr)
235            } else {
236                let cumlens = cumulative_lengths(&targets);
237                if targets_have_nulls {
238                    let arr: BinaryViewArray = idx_arr
239                        .iter()
240                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
241                        .collect_arr_trusted_with_dtype(dtype);
242                    arr.to_boxed()
243                } else {
244                    let arr: BinaryViewArray = idx_arr
245                        .iter()
246                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
247                        .collect_arr_trusted_with_dtype(dtype);
248                    arr.to_boxed()
249                }
250            }
251        });
252
253        let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
254        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
255        out.set_sorted_flag(sorted_flag);
256        out
257    }
258}
259
260impl ChunkTakeUnchecked<IdxCa> for StringChunked {
261    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
262        let ca = self;
263        let targets_have_nulls = ca.null_count() > 0;
264        let targets: Vec<_> = ca.downcast_iter().collect();
265
266        let chunks = indices.downcast_iter().map(|idx_arr| {
267            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
268            if targets.len() == 1 {
269                let target = targets.first().unwrap();
270                take_unchecked(&**target, idx_arr)
271            } else {
272                let cumlens = cumulative_lengths(&targets);
273                if targets_have_nulls {
274                    let arr: Utf8ViewArray = idx_arr
275                        .iter()
276                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
277                        .collect_arr_trusted_with_dtype(dtype);
278                    arr.to_boxed()
279                } else {
280                    let arr: Utf8ViewArray = idx_arr
281                        .iter()
282                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
283                        .collect_arr_trusted_with_dtype(dtype);
284                    arr.to_boxed()
285                }
286            }
287        });
288
289        let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
290        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
291        out.set_sorted_flag(sorted_flag);
292        out
293    }
294}
295
296impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
297    /// Gather values from ChunkedArray by index.
298    unsafe fn take_unchecked(&self, indices: &I) -> Self {
299        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
300        self.take_unchecked(&indices)
301    }
302}
303
304impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
305    /// Gather values from ChunkedArray by index.
306    unsafe fn take_unchecked(&self, indices: &I) -> Self {
307        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
308        self.take_unchecked(&indices)
309    }
310}
311
#[cfg(feature = "dtype-struct")]
impl ChunkTakeUnchecked<IdxCa> for StructChunked {
    /// Gathers struct rows by index.
    ///
    /// Both `self` and `indices` are rechunked first. Their chunks are then paired up and
    /// fed to the `take_unchecked` kernel.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let values = self.rechunk();
        let idx = indices.rechunk();

        let mut gathered = Vec::new();
        for (arr, idx_arr) in values.downcast_iter().zip(idx.downcast_iter()) {
            gathered.push(take_unchecked(arr, idx_arr));
        }
        self.copy_with_chunks(gathered)
    }
}
326
#[cfg(feature = "dtype-struct")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StructChunked {
    /// Wraps the raw indices in an `IdxCa` (via `IdxCa::mmap_slice`) and forwards to the
    /// `IdxCa` implementation.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        <Self as ChunkTakeUnchecked<IdxCa>>::take_unchecked(self, &idx_ca)
    }
}
334
impl IdxCa {
    /// Builds a temporary `IdxCa` (via `IdxCa::with_chunk`) from `idx`, passes it to `f`,
    /// and returns whatever `f` returns.
    ///
    /// An entry is marked valid in the validity bitmap iff `is_null_idx()` is false for it.
    /// The raw values are reinterpreted as `IdxSize` with `bytemuck::cast_slice` and wrapped
    /// with `arrow::ffi::mmap::slice`, rather than being collected into a new buffer.
    ///
    /// NOTE(review): `mmap::slice` is `unsafe` and presumably yields an array that borrows
    /// `idx` without a tracked lifetime. The closure-based API suggests `ca` must not
    /// escape `f`, but nothing prevents `f` from returning a clone of `ca` inside `T`.
    /// Confirm that callers only use `ca` within `f`.
    pub fn with_nullable_idx<T, F: FnOnce(&IdxCa) -> T>(idx: &[NullableIdxSize], f: F) -> T {
        // Validity is the complement of each entry's null marker.
        let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted();
        // Reinterpret the nullable indices as plain `IdxSize` values. Null entries keep
        // their raw bits; the validity bitmap masks them out.
        let idx = bytemuck::cast_slice::<_, IdxSize>(idx);
        // NOTE(review): no SAFETY justification was given here. Soundness appears to rely on
        // `ca` not outliving `idx` (see the doc comment on this function).
        let arr = unsafe { arrow::ffi::mmap::slice(idx) };
        let arr = arr.with_validity_typed(Some(validity));
        let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr);

        f(&ca)
    }
}
346
#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    /// Gather values from ChunkedArray by index.
    ///
    /// Null indices yield nulls. If `self` has several chunks and `should_rechunk` reports a
    /// high values-per-index ratio, both sides are rechunked and gathered in one kernel
    /// call. Otherwise each chunk of `indices` produces one output chunk, and the sorted
    /// flag is derived from the sortedness of `self` and `indices`.
    ///
    /// # Safety
    /// Every non-null index must be `< self.len()`.
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        // Taking nested types by value is expensive, so above a certain values-per-index
        // ratio we rechunk first, so that we can memcopy internally.
        if self.n_chunks() > 1 && should_rechunk(self.len(), indices.len()) {
            let chunks = vec![take_unchecked(
                self.rechunk().downcast_as_array(),
                indices.rechunk().downcast_as_array(),
            )];
            return self.copy_with_chunks(chunks);
        }
        let ca = self;
        let targets_have_nulls = ca.null_count() > 0;
        let targets: Vec<_> = ca.downcast_iter().collect();
        // Invariant across index chunks: compute once instead of per chunk.
        let arrow_dtype = ca.dtype().to_arrow(CompatLevel::newest());
        let cumlens = if targets.len() > 1 {
            cumulative_lengths(&targets)
        } else {
            Vec::new()
        };

        let chunks = indices.downcast_iter().map(|idx_arr| {
            if targets.len() == 1 {
                let target = targets.first().unwrap();
                take_unchecked(&**target, idx_arr)
            } else {
                let dtype = arrow_dtype.clone();
                let arr: FixedSizeListArray = if targets_have_nulls {
                    idx_arr
                        .iter()
                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
                        .collect_arr_trusted_with_dtype(dtype)
                } else {
                    idx_arr
                        .iter()
                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
                        .collect_arr_trusted_with_dtype(dtype)
                };
                arr.to_boxed()
            }
        });

        let mut out = ca.with_chunks(chunks.collect());
        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
        out.set_sorted_flag(sorted_flag);
        out
    }
}
392
#[cfg(feature = "dtype-array")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {
    /// Wraps the raw indices in an `IdxCa` (via `IdxCa::mmap_slice`) and forwards to the
    /// `IdxCa` implementation.
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx_ca = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        <Self as ChunkTakeUnchecked<IdxCa>>::take_unchecked(self, &idx_ca)
    }
}
400
401impl ChunkTakeUnchecked<IdxCa> for ListChunked {
402    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
403        // Taking nested types by value is expensive, so at a certain len[n] ratio
404        // we rechunk first, so that we can memcopy internally
405        if self.n_chunks() > 1 && should_rechunk(self.len(), indices.len()) {
406            let chunks = vec![take_unchecked(
407                self.rechunk().downcast_as_array(),
408                indices.rechunk().downcast_as_array(),
409            )];
410            return self.copy_with_chunks(chunks);
411        }
412        let ca = self;
413        let targets_have_nulls = ca.null_count() > 0;
414        let targets: Vec<_> = ca.downcast_iter().collect();
415
416        let chunks = indices.downcast_iter().map(|idx_arr| {
417            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
418            if targets.len() == 1 {
419                let target = targets.first().unwrap();
420                take_unchecked(&**target, idx_arr)
421            } else {
422                let cumlens = cumulative_lengths(&targets);
423                if targets_have_nulls {
424                    let arr: ListArray<i64> = idx_arr
425                        .iter()
426                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
427                        .collect_arr_trusted_with_dtype(dtype);
428                    arr.to_boxed()
429                } else {
430                    let arr: ListArray<i64> = idx_arr
431                        .iter()
432                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
433                        .collect_arr_trusted_with_dtype(dtype);
434                    arr.to_boxed()
435                }
436            }
437        });
438
439        let mut out = ca.with_chunks(chunks.collect());
440        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
441        out.set_sorted_flag(sorted_flag);
442        out
443    }
444}
445
446impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ListChunked {
447    unsafe fn take_unchecked(&self, indices: &I) -> Self {
448        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
449        self.take_unchecked(&idx)
450    }
451}
452
/// Returns `true` when there are more than `gather_ratio()` values per requested index,
/// which is when rechunking before a nested gather is expected to pay off.
/// Always `false` when there are no indices.
fn should_rechunk(n_values: usize, n_indices: usize) -> bool {
    if n_indices == 0 {
        return false;
    }
    n_values / n_indices > gather_ratio()
}

/// The values-per-index threshold used by `should_rechunk`.
///
/// Read from the `POLARS_GATHER_RECHUNK_RATIO` environment variable on first use (default
/// 64 when unset or not valid Unicode) and cached in a process-wide `OnceLock` afterwards.
///
/// # Panics
/// Panics if the variable is set but does not parse as a `usize`.
fn gather_ratio() -> usize {
    static GATHER_RECHUNK_RATIO: OnceLock<usize> = OnceLock::new();
    const NAME: &str = "POLARS_GATHER_RECHUNK_RATIO";
    const DEFAULT_RATIO: usize = 64;

    *GATHER_RECHUNK_RATIO.get_or_init(|| match std::env::var(NAME) {
        Ok(raw) => raw
            .parse::<usize>()
            .unwrap_or_else(|_| panic!("invalid value for {NAME}: {raw}")),
        Err(_) => DEFAULT_RATIO,
    })
}