polars_core/chunked_array/ops/gather.rs

#![allow(unsafe_op_in_unsafe_fn)]
use std::sync::OnceLock;

use arrow::bitmap::Bitmap;
use arrow::bitmap::bitmask::BitMask;
use polars_compute::gather::take_unchecked;
use polars_error::polars_ensure;
use polars_utils::index::check_bounds;

use crate::prelude::*;
use crate::series::IsSorted;
use crate::utils::Container;

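/// Check that every non-null index in `idx` is in bounds (`< len`).
///
/// Indices are processed in blocks of 32 so each block can be validated against
/// its validity bits with a single `u32` comparison. Assumes `idx` has a
/// validity bitmap.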
pub fn check_bounds_nulls(idx: &PrimitiveArray<IdxSize>, len: IdxSize) -> PolarsResult<()> {
    let mask = BitMask::from_bitmap(idx.validity().unwrap());

    for (block_idx, block) in idx.values().chunks(32).enumerate() {
        let mut in_bounds = 0;
        for (i, x) in block.iter().enumerate() {
            in_bounds |= ((*x < len) as u32) << i;
        }
        let m = mask.get_u32(32 * block_idx);
        polars_ensure!(m == m & in_bounds, ComputeError: "gather indices are out of bounds");
    }
    Ok(())
}

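/// Bounds-check every chunk of an `IdxCa`, using the null-aware path only for
/// chunks that actually contain nulls.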
pub fn check_bounds_ca(indices: &IdxCa, len: IdxSize) -> PolarsResult<()> {
    let all_valid = indices.downcast_iter().all(|a| {
        if a.null_count() == 0 {
            check_bounds(a.values(), len).is_ok()
        } else {
            check_bounds_nulls(a, len).is_ok()
        }
    });
    polars_ensure!(all_valid, OutOfBounds: "gather indices are out of bounds");
    Ok(())
}

impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTake<I> for ChunkedArray<T>
where
    ChunkedArray<T>: ChunkTakeUnchecked<I>,
{
    fn take(&self, indices: &I) -> PolarsResult<Self> {
        check_bounds(indices.as_ref(), self.len() as IdxSize)?;

        // SAFETY: we just checked the indices are in bounds.
        Ok(unsafe { self.take_unchecked(indices) })
    }
}

impl<T: PolarsDataType> ChunkTake<IdxCa> for ChunkedArray<T>
where
    ChunkedArray<T>: ChunkTakeUnchecked<IdxCa>,
{
    fn take(&self, indices: &IdxCa) -> PolarsResult<Self> {
        check_bounds_ca(indices, self.len() as IdxSize)?;

        // SAFETY: we just checked the indices are in bounds.
        Ok(unsafe { self.take_unchecked(indices) })
    }
}

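/// Starting offset of each chunk, i.e. the exclusive prefix sum of the chunk lengths.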
fn cumulative_lengths<A: StaticArray>(arrs: &[&A]) -> Vec<IdxSize> {
    let mut ret = Vec::with_capacity(arrs.len());
    let mut cumsum: IdxSize = 0;
    for arr in arrs {
        ret.push(cumsum);
        cumsum = cumsum.checked_add(arr.len().try_into().unwrap()).unwrap();
    }
    ret
}

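/// Translate a global index into `(chunk index, index within that chunk)` using
/// the chunk offsets produced by `cumulative_lengths`.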
#[rustfmt::skip]
#[inline]
fn resolve_chunked_idx(idx: IdxSize, cumlens: &[IdxSize]) -> (usize, usize) {
    let chunk_idx = cumlens.partition_point(|cl| idx >= *cl) - 1;
    (chunk_idx, (idx - cumlens[chunk_idx]) as usize)
}

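/// Value at global index `idx` across the chunked `targets`, with no bounds or
/// validity checks (`target_get_unchecked` below is the validity-aware variant).
///
/// # Safety
/// `idx` must be in bounds for the combined length of `targets`.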
#[inline]
unsafe fn target_value_unchecked<'a, A: StaticArray>(
    targets: &[&'a A],
    cumlens: &[IdxSize],
    idx: IdxSize,
) -> A::ValueT<'a> {
    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
    let arr = targets.get_unchecked(chunk_idx);
    arr.value_unchecked(arr_idx)
}

#[inline]
unsafe fn target_get_unchecked<'a, A: StaticArray>(
    targets: &[&'a A],
    cumlens: &[IdxSize],
    idx: IdxSize,
) -> Option<A::ValueT<'a>> {
    let (chunk_idx, arr_idx) = resolve_chunked_idx(idx, cumlens);
    let arr = targets.get_unchecked(chunk_idx);
    arr.get_unchecked(arr_idx)
}

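/// Gather `indices` from `targets` into a single array with the given `dtype`.
///
/// A single chunk without nulls is gathered directly from its values (using the
/// contiguous slice when available); otherwise each index is resolved to a
/// `(chunk, offset)` pair via `cumulative_lengths`.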
unsafe fn gather_idx_array_unchecked<A: StaticArray>(
    dtype: ArrowDataType,
    targets: &[&A],
    has_nulls: bool,
    indices: &[IdxSize],
) -> A {
    let it = indices.iter().copied();
    if targets.len() == 1 {
        let target = targets.first().unwrap();
        if has_nulls {
            it.map(|i| target.get_unchecked(i as usize))
                .collect_arr_trusted_with_dtype(dtype)
        } else if let Some(sl) = target.as_slice() {
            it.map(|i| sl.get_unchecked(i as usize).clone())
                .collect_arr_trusted_with_dtype(dtype)
        } else {
            it.map(|i| target.value_unchecked(i as usize))
                .collect_arr_trusted_with_dtype(dtype)
        }
    } else {
        let cumlens = cumulative_lengths(targets);
        if has_nulls {
            it.map(|i| target_get_unchecked(targets, &cumlens, i))
                .collect_arr_trusted_with_dtype(dtype)
        } else {
            it.map(|i| target_value_unchecked(targets, &cumlens, i))
                .collect_arr_trusted_with_dtype(dtype)
        }
    }
}

impl<T: PolarsDataType, I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ChunkedArray<T>
where
    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
{
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let ca = self;
        let targets: Vec<_> = ca.downcast_iter().collect();
        let arr = gather_idx_array_unchecked(
            ca.dtype().to_arrow(CompatLevel::newest()),
            &targets,
            ca.null_count() > 0,
            indices.as_ref(),
        );
        ChunkedArray::from_chunk_iter_like(ca, [arr])
    }
}

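/// Combine the sortedness of the gathered values with the sortedness of the
/// indices: the result is sorted only if both inputs are, and the direction
/// flips when exactly one of them is descending.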
pub fn _update_gather_sorted_flag(sorted_arr: IsSorted, sorted_idx: IsSorted) -> IsSorted {
    use crate::series::IsSorted::*;
    match (sorted_arr, sorted_idx) {
        (_, Not) => Not,
        (Not, _) => Not,
        (Ascending, Ascending) => Ascending,
        (Ascending, Descending) => Descending,
        (Descending, Ascending) => Descending,
        (Descending, Descending) => Ascending,
    }
}

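// Gather for flat (non-view, non-nested, non-struct) types with a possibly
// nullable index array: null indices become null values in the output, and the
// sorted flag is propagated from the inputs.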
impl<T: PolarsDataType> ChunkTakeUnchecked<IdxCa> for ChunkedArray<T>
where
    T: PolarsDataType<HasViews = FalseT, IsStruct = FalseT, IsNested = FalseT>,
{
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let ca = self;
        let targets_have_nulls = ca.null_count() > 0;
        let targets: Vec<_> = ca.downcast_iter().collect();

        let chunks = indices.downcast_iter().map(|idx_arr| {
            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
            if idx_arr.null_count() == 0 {
                gather_idx_array_unchecked(dtype, &targets, targets_have_nulls, idx_arr.values())
            } else if targets.len() == 1 {
                let target = targets.first().unwrap();
                if targets_have_nulls {
                    idx_arr
                        .iter()
                        .map(|i| target.get_unchecked(*i? as usize))
                        .collect_arr_trusted_with_dtype(dtype)
                } else {
                    idx_arr
                        .iter()
                        .map(|i| Some(target.value_unchecked(*i? as usize)))
                        .collect_arr_trusted_with_dtype(dtype)
                }
            } else {
                let cumlens = cumulative_lengths(&targets);
                if targets_have_nulls {
                    idx_arr
                        .iter()
                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
                        .collect_arr_trusted_with_dtype(dtype)
                } else {
                    idx_arr
                        .iter()
                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
                        .collect_arr_trusted_with_dtype(dtype)
                }
            }
        });

        let mut out = ChunkedArray::from_chunk_iter_like(ca, chunks);
        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());

        out.set_sorted_flag(sorted_flag);
        out
    }
}

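// For the view types (binary/string), the single-chunk case delegates to
// `polars_compute::gather::take_unchecked`; multi-chunk targets fall back to a
// per-index gather across chunks.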
impl ChunkTakeUnchecked<IdxCa> for BinaryChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let ca = self;
        let targets_have_nulls = ca.null_count() > 0;
        let targets: Vec<_> = ca.downcast_iter().collect();

        let chunks = indices.downcast_iter().map(|idx_arr| {
            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
            if targets.len() == 1 {
                let target = targets.first().unwrap();
                take_unchecked(&**target, idx_arr)
            } else {
                let cumlens = cumulative_lengths(&targets);
                if targets_have_nulls {
                    let arr: BinaryViewArray = idx_arr
                        .iter()
                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                } else {
                    let arr: BinaryViewArray = idx_arr
                        .iter()
                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                }
            }
        });

        let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
        out.set_sorted_flag(sorted_flag);
        out
    }
}

impl ChunkTakeUnchecked<IdxCa> for StringChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let ca = self;
        let targets_have_nulls = ca.null_count() > 0;
        let targets: Vec<_> = ca.downcast_iter().collect();

        let chunks = indices.downcast_iter().map(|idx_arr| {
            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
            if targets.len() == 1 {
                let target = targets.first().unwrap();
                take_unchecked(&**target, idx_arr)
            } else {
                let cumlens = cumulative_lengths(&targets);
                if targets_have_nulls {
                    let arr: Utf8ViewArray = idx_arr
                        .iter()
                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                } else {
                    let arr: Utf8ViewArray = idx_arr
                        .iter()
                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                }
            }
        });

        let mut out = ChunkedArray::from_chunks(ca.name().clone(), chunks.collect());
        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
        out.set_sorted_flag(sorted_flag);
        out
    }
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for BinaryChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&indices)
    }
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StringChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let indices = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&indices)
    }
}

#[cfg(feature = "dtype-struct")]
impl ChunkTakeUnchecked<IdxCa> for StructChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        let a = self.rechunk();
        let index = indices.rechunk();

        let chunks = a
            .downcast_iter()
            .zip(index.downcast_iter())
            .map(|(arr, idx)| take_unchecked(arr, idx))
            .collect::<Vec<_>>();
        self.copy_with_chunks(chunks)
    }
}

#[cfg(feature = "dtype-struct")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for StructChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx)
    }
}

impl IdxCa {
    /// Run `f` against a temporary nullable `IdxCa` that borrows `idx`; entries
    /// flagged by `NullableIdxSize::is_null_idx` become nulls.
    pub fn with_nullable_idx<T, F: FnOnce(&IdxCa) -> T>(idx: &[NullableIdxSize], f: F) -> T {
        let validity: Bitmap = idx.iter().map(|idx| !idx.is_null_idx()).collect_trusted();
        let idx = bytemuck::cast_slice::<_, IdxSize>(idx);
        let arr = unsafe { arrow::ffi::mmap::slice(idx) };
        let arr = arr.with_validity_typed(Some(validity));
        let ca = IdxCa::with_chunk(PlSmallStr::EMPTY, arr);

        f(&ca)
    }
}

#[cfg(feature = "dtype-array")]
impl ChunkTakeUnchecked<IdxCa> for ArrayChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        // If the target has many chunks relative to the number of indices,
        // rechunk both sides and gather once.
        if self.n_chunks() > 1 && should_rechunk(self.len(), indices.len()) {
            let chunks = vec![take_unchecked(
                self.rechunk().downcast_as_array(),
                indices.rechunk().downcast_as_array(),
            )];
            return self.copy_with_chunks(chunks);
        }
        let ca = self;
        let targets_have_nulls = ca.null_count() > 0;
        let targets: Vec<_> = ca.downcast_iter().collect();

        let chunks = indices.downcast_iter().map(|idx_arr| {
            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
            if targets.len() == 1 {
                let target = targets.first().unwrap();
                take_unchecked(&**target, idx_arr)
            } else {
                let cumlens = cumulative_lengths(&targets);
                if targets_have_nulls {
                    let arr: FixedSizeListArray = idx_arr
                        .iter()
                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                } else {
                    let arr: FixedSizeListArray = idx_arr
                        .iter()
                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                }
            }
        });

        let mut out = ca.with_chunks(chunks.collect());
        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
        out.set_sorted_flag(sorted_flag);
        out
    }
}

#[cfg(feature = "dtype-array")]
impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ArrayChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx)
    }
}

impl ChunkTakeUnchecked<IdxCa> for ListChunked {
    unsafe fn take_unchecked(&self, indices: &IdxCa) -> Self {
        // If the target has many chunks relative to the number of indices,
        // rechunk both sides and gather once.
        if self.n_chunks() > 1 && should_rechunk(self.len(), indices.len()) {
            let chunks = vec![take_unchecked(
                self.rechunk().downcast_as_array(),
                indices.rechunk().downcast_as_array(),
            )];
            return self.copy_with_chunks(chunks);
        }
        let ca = self;
        let targets_have_nulls = ca.null_count() > 0;
        let targets: Vec<_> = ca.downcast_iter().collect();

        let chunks = indices.downcast_iter().map(|idx_arr| {
            let dtype = ca.dtype().to_arrow(CompatLevel::newest());
            if targets.len() == 1 {
                let target = targets.first().unwrap();
                take_unchecked(&**target, idx_arr)
            } else {
                let cumlens = cumulative_lengths(&targets);
                if targets_have_nulls {
                    let arr: ListArray<i64> = idx_arr
                        .iter()
                        .map(|i| target_get_unchecked(&targets, &cumlens, *i?))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                } else {
                    let arr: ListArray<i64> = idx_arr
                        .iter()
                        .map(|i| Some(target_value_unchecked(&targets, &cumlens, *i?)))
                        .collect_arr_trusted_with_dtype(dtype);
                    arr.to_boxed()
                }
            }
        });

        let mut out = ca.with_chunks(chunks.collect());
        let sorted_flag = _update_gather_sorted_flag(ca.is_sorted_flag(), indices.is_sorted_flag());
        out.set_sorted_flag(sorted_flag);
        out
    }
}

impl<I: AsRef<[IdxSize]> + ?Sized> ChunkTakeUnchecked<I> for ListChunked {
    unsafe fn take_unchecked(&self, indices: &I) -> Self {
        let idx = IdxCa::mmap_slice(PlSmallStr::EMPTY, indices.as_ref());
        self.take_unchecked(&idx)
    }
}

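// Rechunk heuristic for nested gathers: rechunk first when the target has many
// more values than there are indices. The ratio defaults to 64 and can be
// overridden via the `POLARS_GATHER_RECHUNK_RATIO` environment variable.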
fn should_rechunk(n_values: usize, n_indices: usize) -> bool {
    n_indices > 0 && { (n_values / n_indices) > gather_ratio() }
}

fn gather_ratio() -> usize {
    return *GATHER_RECHUNK_RATIO.get_or_init(|| {
        const NAME: &str = "POLARS_GATHER_RECHUNK_RATIO";
        std::env::var(NAME)
            .map(|x| {
                x.parse::<usize>()
                    .unwrap_or_else(|_| panic!("invalid value for {NAME}: {x}"))
            })
            .unwrap_or(const { 64 })
    });

    static GATHER_RECHUNK_RATIO: OnceLock<usize> = OnceLock::new();
}