1use std::borrow::Cow;
2use std::sync::Arc;
3
4use arrow::array::PrimitiveArray;
5use arrow::bitmap::Bitmap;
6use arrow::bitmap::bitmask::BitMask;
7use arrow::trusted_len::TrustMyLength;
8use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
9use polars_core::POOL;
10use polars_core::error::{PolarsResult, polars_bail, polars_ensure};
11use polars_core::frame::DataFrame;
12use polars_core::prelude::row_encode::encode_rows_unordered;
13use polars_core::prelude::{
14 AnyValue, ChunkCast, Column, CompatLevel, Float64Chunked, GroupPositions, GroupsType,
15 IDX_DTYPE, IntoColumn,
16};
17use polars_core::scalar::Scalar;
18use polars_core::series::{ChunkCompareEq, Series};
19use polars_utils::itertools::Itertools;
20use polars_utils::pl_str::PlSmallStr;
21use polars_utils::{IdxSize, UnitVec};
22use rayon::iter::{IntoParallelIterator, ParallelIterator};
23
24use crate::prelude::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
25use crate::state::ExecutionState;
26
27pub fn reverse<'a>(
28 inputs: &[Arc<dyn PhysicalExpr>],
29 df: &DataFrame,
30 groups: &'a GroupPositions,
31 state: &ExecutionState,
32) -> PolarsResult<AggregationContext<'a>> {
33 assert_eq!(inputs.len(), 1);
34
35 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
36
37 if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &ac.agg_state() {
39 return Ok(ac);
40 }
41
42 POOL.install(|| {
43 let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
44 GroupsType::Idx(idx) => idx
45 .into_par_iter()
46 .map(|(first, idx)| {
47 (
48 idx.last().copied().unwrap_or(first),
49 idx.iter().copied().rev().collect(),
50 )
51 })
52 .collect(),
53 GroupsType::Slice {
54 groups,
55 overlapping: _,
56 monotonic: _,
57 } => groups
58 .into_par_iter()
59 .map(|[start, len]| {
60 (
61 start + len.saturating_sub(1),
62 (*start..*start + *len).rev().collect(),
63 )
64 })
65 .collect(),
66 })
67 .into_sliceable();
68 ac.with_groups(positions);
69 });
70
71 Ok(ac)
72}
73
74pub fn null_count<'a>(
75 inputs: &[Arc<dyn PhysicalExpr>],
76 df: &DataFrame,
77 groups: &'a GroupPositions,
78 state: &ExecutionState,
79) -> PolarsResult<AggregationContext<'a>> {
80 assert_eq!(inputs.len(), 1);
81
82 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
83
84 if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
85 *s = s.is_null().cast(&IDX_DTYPE).unwrap().into_column();
86 return Ok(ac);
87 }
88
89 ac.groups();
90 let values = ac.flat_naive();
91 let name = values.name().clone();
92 let Some(validity) = values.rechunk_validity() else {
93 ac.state = AggState::AggregatedScalar(Column::new_scalar(
94 name,
95 (0 as IdxSize).into(),
96 groups.len(),
97 ));
98 return Ok(ac);
99 };
100
101 POOL.install(|| {
102 let validity = BitMask::from_bitmap(&validity);
103 let null_count: Vec<IdxSize> = match &**ac.groups.as_ref() {
104 GroupsType::Idx(idx) => idx
105 .into_par_iter()
106 .map(|(_, idx)| {
107 idx.iter()
108 .map(|i| IdxSize::from(!unsafe { validity.get_bit_unchecked(*i as usize) }))
109 .sum::<IdxSize>()
110 })
111 .collect(),
112 GroupsType::Slice {
113 groups,
114 overlapping: _,
115 monotonic: _,
116 } => groups
117 .into_par_iter()
118 .map(|[start, length]| {
119 unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
120 .unset_bits() as IdxSize
121 })
122 .collect(),
123 };
124
125 ac.state = AggState::AggregatedScalar(Column::new(name, null_count));
126 });
127
128 Ok(ac)
129}
130
131pub fn any<'a>(
132 inputs: &[Arc<dyn PhysicalExpr>],
133 df: &DataFrame,
134 groups: &'a GroupPositions,
135 state: &ExecutionState,
136 ignore_nulls: bool,
137) -> PolarsResult<AggregationContext<'a>> {
138 assert_eq!(inputs.len(), 1);
139
140 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
141
142 if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
143 if ignore_nulls {
144 *s = s
145 .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
146 .unwrap()
147 .into_column();
148 } else {
149 *s = s
150 .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
151 .unwrap()
152 .into_column();
153 }
154 return Ok(ac);
155 }
156
157 ac.groups();
158 let values = ac.flat_naive();
159 let values = values.bool()?;
160 let out = unsafe { values.agg_any(ac.groups.as_ref(), ignore_nulls) };
161 ac.state = AggState::AggregatedScalar(out.into_column());
162
163 Ok(ac)
164}
165
166pub fn all<'a>(
167 inputs: &[Arc<dyn PhysicalExpr>],
168 df: &DataFrame,
169 groups: &'a GroupPositions,
170 state: &ExecutionState,
171 ignore_nulls: bool,
172) -> PolarsResult<AggregationContext<'a>> {
173 assert_eq!(inputs.len(), 1);
174
175 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
176
177 if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
178 if ignore_nulls {
179 *s = s
180 .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
181 .unwrap()
182 .into_column();
183 } else {
184 *s = s
185 .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
186 .unwrap()
187 .into_column();
188 }
189 return Ok(ac);
190 }
191
192 ac.groups();
193 let values = ac.flat_naive();
194 let values = values.bool()?;
195 let out = unsafe { values.agg_all(ac.groups.as_ref(), ignore_nulls) };
196 ac.state = AggState::AggregatedScalar(out.into_column());
197
198 Ok(ac)
199}
200
/// Shared driver for the per-group bitwise reductions (`and`/`or`/`xor`).
///
/// `op` only names the operation in the error raised for an unsupported
/// dtype. Scalar states are validated (Boolean or primitive numeric) and then
/// returned unchanged. Otherwise `f` is applied to the flattened values and
/// the group positions, and its output becomes the aggregated scalar state.
#[cfg(feature = "bitwise")]
pub fn bitwise_agg<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    op: &'static str,
    f: impl Fn(&Column, &GroupsType) -> Column,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);

    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(col) | AggState::LiteralScalar(col) = &ac.state {
        let dtype = col.dtype();
        let is_supported = dtype.is_bool() || dtype.is_primitive_numeric();
        polars_ensure!(is_supported, op = op, dtype);
        return Ok(ac);
    }

    ac.groups();
    let reduced = {
        let flat = ac.flat_naive();
        f(flat.as_ref(), ac.groups.as_ref())
    };
    ac.state = AggState::AggregatedScalar(reduced.into_column());

    Ok(ac)
}
231
/// Per-group bitwise AND reduction, reported as `and_reduce` in dtype errors.
///
/// NOTE(review): `agg_and` is `unsafe`; presumably it requires the group
/// positions to be in bounds of the values, which `bitwise_agg` takes from the
/// same aggregation context.
#[cfg(feature = "bitwise")]
pub fn bitwise_and<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    bitwise_agg(inputs, df, groups, state, "and_reduce", |col, grps| unsafe {
        col.agg_and(grps)
    })
}
248
/// Per-group bitwise OR reduction, reported as `or_reduce` in dtype errors.
///
/// NOTE(review): `agg_or` is `unsafe`; presumably it requires the group
/// positions to be in bounds of the values, which `bitwise_agg` takes from the
/// same aggregation context.
#[cfg(feature = "bitwise")]
pub fn bitwise_or<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    bitwise_agg(
        inputs,
        df,
        groups,
        state,
        "or_reduce",
        |col, grps| unsafe { col.agg_or(grps) },
    )
}
260
/// Per-group bitwise XOR reduction, reported as `xor_reduce` in dtype errors.
///
/// NOTE(review): `agg_xor` is `unsafe`; presumably it requires the group
/// positions to be in bounds of the values, which `bitwise_agg` takes from the
/// same aggregation context.
#[cfg(feature = "bitwise")]
pub fn bitwise_xor<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    bitwise_agg(inputs, df, groups, state, "xor_reduce", |col, grps| unsafe {
        col.agg_xor(grps)
    })
}
277
/// Keeps, within every group of `ac`, only the rows whose bit is set in
/// `predicate`, and returns the re-grouped context.
///
/// `predicate` is indexed by row position in the flattened values, which is
/// the index space of the group indices used below.
///
/// NOTE(review): callers in this file pass a length-1 all-set bitmap to mean
/// "keep everything". That is only valid because such a predicate is caught
/// by the `unset_bits() == 0` early return.
pub fn drop_items<'a>(
    mut ac: AggregationContext<'a>,
    predicate: &Bitmap,
) -> PolarsResult<AggregationContext<'a>> {
    // Nothing is dropped.
    if predicate.unset_bits() == 0 {
        // Scalar states are turned into one-element lists so that, like the
        // filtering paths below, the result is a list per group. A length-1
        // column is broadcast to the number of groups.
        if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
            *c = c.as_list().into_column();
            if c.len() == 1 && ac.groups.len() != 1 {
                *c = c.new_from_index(0, ac.groups.len());
            }
            ac.state = AggState::AggregatedList(std::mem::take(c));
            ac.update_groups = UpdateGroups::WithSeriesLen;
        }
        return Ok(ac);
    }

    // From here on at least one row is removed. NOTE(review): presumably this
    // records that the output length no longer matches the input — confirm.
    ac.set_original_len(false);

    // Everything is dropped: every group becomes an empty list of the flat dtype.
    if predicate.set_bits() == 0 {
        let name = ac.agg_state().name();
        let dtype = ac.agg_state().flat_dtype();

        ac.state = AggState::AggregatedList(Column::new_scalar(
            name.clone(),
            Scalar::new(
                dtype.clone().implode(),
                AnyValue::List(Series::new_empty(PlSmallStr::EMPTY, dtype)),
            ),
            ac.groups.len(),
        ));
        ac.with_update_groups(UpdateGroups::WithSeriesLen);
        return Ok(ac);
    }

    // One value per group: group `i` becomes the slice `[i, 1]` if its value
    // is kept and `[i, 0]` otherwise, over the now non-aggregated values.
    if let AggState::AggregatedScalar(c) = &mut ac.state {
        ac.state = AggState::NotAggregated(std::mem::take(c));
        ac.groups = Cow::Owned(
            {
                let groups = predicate
                    .iter()
                    .enumerate_idx()
                    .map(|(i, p)| [i, IdxSize::from(p)])
                    .collect();
                GroupsType::new_slice(groups, false, true)
            }
            .into_sliceable(),
        );
        ac.update_groups = UpdateGroups::No;
        return Ok(ac);
    }

    // General case: rewrite every group into the index list of its kept rows.
    ac.groups();
    let predicate = BitMask::from_bitmap(predicate);
    POOL.install(|| {
        let positions = GroupsType::Idx(match &**ac.groups.as_ref() {
            // Filter each index list. An emptied group keeps its old `first`.
            // NOTE(review): `get_bit_unchecked` assumes every group index is
            // < predicate.len().
            GroupsType::Idx(idxs) => idxs
                .into_par_iter()
                .map(|(fst, idxs)| {
                    let out = idxs
                        .iter()
                        .copied()
                        .filter(|i| unsafe { predicate.get_bit_unchecked(*i as usize) })
                        .collect::<UnitVec<IdxSize>>();
                    (out.first().copied().unwrap_or(fst), out)
                })
                .collect(),
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic: _,
            } => groups
                .into_par_iter()
                .map(|[start, length]| {
                    // Predicate bits of this group only, indexed from 0.
                    let predicate =
                        unsafe { predicate.sliced_unchecked(*start as usize, *length as usize) };
                    let num_values = predicate.set_bits();

                    if num_values == 0 {
                        // Group emptied: `start` stays as the placeholder first index.
                        (*start, UnitVec::new())
                    } else if num_values == 1 {
                        // Exactly one kept row: `leading_zeros` (zero bits before
                        // the first set bit) gives its offset within the slice.
                        let item = *start + predicate.leading_zeros() as IdxSize;
                        let mut out = UnitVec::with_capacity(1);
                        out.push(item);
                        (item, out)
                    } else if num_values == *length as usize {
                        // Whole slice kept.
                        (*start, (*start..*start + *length).collect())
                    } else {
                        // The filter yields exactly one item per set bit, i.e.
                        // `num_values` items, which is what `TrustMyLength` promises.
                        let out = unsafe {
                            TrustMyLength::new(
                                (0..*length)
                                    .filter(|i| predicate.get_bit_unchecked(*i as usize))
                                    .map(|i| i + *start),
                                num_values,
                            )
                        };
                        let out = out.collect::<UnitVec<IdxSize>>();

                        // Cannot fail: at least two rows are kept in this branch.
                        (out.first().copied().unwrap(), out)
                    }
                })
                .collect(),
        })
        .into_sliceable();
        ac.with_groups(positions);
    });

    Ok(ac)
}
388
389pub fn drop_nans<'a>(
390 inputs: &[Arc<dyn PhysicalExpr>],
391 df: &DataFrame,
392 groups: &'a GroupPositions,
393 state: &ExecutionState,
394) -> PolarsResult<AggregationContext<'a>> {
395 assert_eq!(inputs.len(), 1);
396 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
397 ac.groups();
398 let predicate = if ac.agg_state().flat_dtype().is_float() {
399 let values = ac.flat_naive();
400 let mut values = values.is_nan().unwrap();
401 values.rechunk_mut();
402 values.downcast_as_array().values().clone()
403 } else {
404 Bitmap::new_with_value(false, 1)
405 };
406 let predicate = !&predicate;
407 drop_items(ac, &predicate)
408}
409
410pub fn drop_nulls<'a>(
411 inputs: &[Arc<dyn PhysicalExpr>],
412 df: &DataFrame,
413 groups: &'a GroupPositions,
414 state: &ExecutionState,
415) -> PolarsResult<AggregationContext<'a>> {
416 assert_eq!(inputs.len(), 1);
417 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
418 ac.groups();
419 let predicate = ac.flat_naive().as_ref().clone();
420 let predicate = predicate.rechunk_to_arrow(CompatLevel::newest());
421 let predicate = predicate
422 .validity()
423 .cloned()
424 .unwrap_or(Bitmap::new_with_value(true, 1));
425 drop_items(ac, &predicate)
426}
427
/// Generic per-group moment statistic (used by `skew` and `kurtosis`).
///
/// * `insert_one` feeds a single `f64` into an accumulator `S` (starting from
///   `S::default()`);
/// * `new_from_slice` builds an accumulator directly from `len` rows starting at
///   `start` of the array. It is used for slice groups, so null handling there
///   is up to that callback;
/// * `finalize` turns an accumulator into the per-group result.
///
/// For index groups, null rows are skipped when the values contain nulls.
/// Scalar states are treated as one-element groups (a null stays null).
///
/// # Errors
/// Fails when the values are not `Float64`.
#[cfg(feature = "moment")]
pub fn moment_agg<'a, S: Default>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,

    insert_one: impl Fn(&mut S, f64) + Send + Sync,
    new_from_slice: impl Fn(&PrimitiveArray<f64>, usize, usize) -> S + Send + Sync,
    finalize: impl Fn(S) -> Option<f64> + Send + Sync,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);
    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
        let scalars = s.f64()?;
        let single_value = |v: f64| {
            let mut acc = S::default();
            insert_one(&mut acc, v);
            finalize(acc)
        };
        let out: Float64Chunked = scalars
            .iter()
            .map(|opt| opt.and_then(|v| single_value(v)))
            .collect();
        *s = out.with_name(scalars.name().clone()).into_column();
        return Ok(ac);
    }

    ac.groups();

    let out_name = ac.get_values().name().clone();
    let flat = ac.flat_naive();
    let flat = flat.f64()?;
    let flat = flat.rechunk();
    let arr = flat.downcast_as_array();
    let values = arr.values();

    let per_group = POOL.install(|| match &**ac.groups.as_ref() {
        GroupsType::Idx(idx) => {
            // Only consult the validity when there is at least one null.
            let nulls = arr.validity().filter(|v| v.unset_bits() > 0);
            idx.into_par_iter()
                .map(|(_, members)| {
                    let mut acc = S::default();
                    match nulls {
                        Some(valid) => members
                            .iter()
                            .filter(|&&row| unsafe { valid.get_bit_unchecked(row as usize) })
                            .for_each(|&row| insert_one(&mut acc, values[row as usize])),
                        None => members
                            .iter()
                            .for_each(|&row| insert_one(&mut acc, values[row as usize])),
                    }
                    finalize(acc)
                })
                .collect::<Float64Chunked>()
        },
        GroupsType::Slice { groups: slices, .. } => slices
            .into_par_iter()
            .map(|&[offset, len]| finalize(new_from_slice(arr, offset as usize, len as usize)))
            .collect::<Float64Chunked>(),
    });

    ac.state = AggState::AggregatedScalar(per_group.with_name(out_name).into_column());
    Ok(ac)
}
506
/// Per-group sample skewness, computed through `moment_agg` with
/// `SkewState`. `bias` is forwarded to `SkewState::finalize`.
#[cfg(feature = "moment")]
pub fn skew<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::SkewState;

    let finish = move |moments: SkewState| moments.finalize(bias);
    moment_agg::<SkewState>(
        inputs,
        df,
        groups,
        state,
        SkewState::insert_one,
        SkewState::from_array,
        finish,
    )
}
526
/// Per-group kurtosis, computed through `moment_agg` with `KurtosisState`.
/// `fisher` and `bias` are forwarded to `KurtosisState::finalize`.
#[cfg(feature = "moment")]
pub fn kurtosis<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    fisher: bool,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::KurtosisState;

    let finish = move |moments: KurtosisState| moments.finalize(fisher, bias);
    moment_agg::<KurtosisState>(
        inputs,
        df,
        groups,
        state,
        KurtosisState::insert_one,
        KurtosisState::from_array,
        finish,
    )
}
547
/// Reduces every group to its unique values and returns a context whose
/// groups index only the retained rows.
///
/// `stable` is currently ignored. NOTE(review): whether the result preserves
/// first-occurrence order depends on `AmortizedUnique::retain_unique` /
/// `arg_unique` — confirm before relying on it.
///
/// # Errors
/// Fails for dtypes that contain objects.
pub fn unique<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    stable: bool,
) -> PolarsResult<AggregationContext<'a>> {
    _ = stable;

    assert_eq!(inputs.len(), 1);
    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
    ac.groups();

    // A single value per group is trivially unique. Wrap each value in a
    // one-element list, broadcasting a length-1 column to all groups.
    if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
        *c = c.as_list().into_column();
        if c.len() == 1 && ac.groups.len() != 1 {
            *c = c.new_from_index(0, ac.groups.len());
        }
        ac.state = AggState::AggregatedList(std::mem::take(c));
        ac.update_groups = UpdateGroups::WithSeriesLen;
        return Ok(ac);
    }

    // Normalize to a physical representation that the amortized-unique
    // routines can handle: strings are compared as binary, and nested types
    // as their (unordered) row encoding.
    let values = ac.flat_naive().to_physical_repr();
    let dtype = values.dtype();
    let values = if dtype.contains_objects() {
        polars_bail!(opq = unique, dtype);
    } else if let Some(ca) = values.try_str() {
        ca.as_binary().into_column()
    } else if dtype.is_nested() {
        encode_rows_unordered(&[values])?.into_column()
    } else {
        values
    };

    let values = values.rechunk_to_arrow(CompatLevel::newest());
    let values = values.as_ref();
    let state = amortized_unique_from_dtype(values.dtype());

    // rayon's `map_with` clones its seed for parallel splits. Cloning here
    // hands out a fresh empty state (`new_empty`) rather than copying any
    // accumulated contents.
    struct CloneWrapper(Box<dyn AmortizedUnique>);
    impl Clone for CloneWrapper {
        fn clone(&self) -> Self {
            Self(self.0.new_empty())
        }
    }

    POOL.install(|| {
        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
            // Filter a copy of each group's index list down to its unique rows.
            // An emptied group keeps its original `first`.
            // NOTE(review): `retain_unique` is unsafe; presumably it requires
            // every index to be in bounds of `values`.
            GroupsType::Idx(idx) => idx
                .into_par_iter()
                .map_with(CloneWrapper(state), |state, (first, idx)| {
                    let mut idx = idx.clone();
                    unsafe { state.0.retain_unique(values, &mut idx) };
                    (idx.first().copied().unwrap_or(first), idx)
                })
                .collect(),
            // Presumably collects the indices of unique rows within
            // `[start, start + len)` into `idx`.
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic: _,
            } => groups
                .into_par_iter()
                .map_with(CloneWrapper(state), |state, [start, len]| {
                    let mut idx = UnitVec::new();
                    state.0.arg_unique(values, &mut idx, *start, *len);
                    (idx.first().copied().unwrap_or(*start), idx)
                })
                .collect(),
        })
        .into_sliceable();
        ac.with_groups(positions);
    });

    Ok(ac)
}
623
624fn fw_bw_fill_null<'a>(
625 inputs: &[Arc<dyn PhysicalExpr>],
626 df: &DataFrame,
627 groups: &'a GroupPositions,
628 state: &ExecutionState,
629 f_idx: impl Fn(
630 std::iter::Copied<std::slice::Iter<'_, IdxSize>>,
631 BitMask<'_>,
632 usize,
633 ) -> UnitVec<IdxSize>
634 + Send
635 + Sync,
636 f_range: impl Fn(std::ops::Range<IdxSize>, BitMask<'_>, usize) -> UnitVec<IdxSize> + Send + Sync,
637) -> PolarsResult<AggregationContext<'a>> {
638 assert_eq!(inputs.len(), 1);
639 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
640 ac.groups();
641
642 if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &mut ac.state {
643 return Ok(ac);
644 }
645
646 let values = ac.flat_naive();
647 let Some(validity) = values.rechunk_validity() else {
648 return Ok(ac);
649 };
650
651 let validity = BitMask::from_bitmap(&validity);
652 POOL.install(|| {
653 let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
654 GroupsType::Idx(idx) => idx
655 .into_par_iter()
656 .map(|(first, idx)| {
657 let idx = f_idx(idx.iter().copied(), validity, idx.len());
658 (idx.first().copied().unwrap_or(first), idx)
659 })
660 .collect(),
661 GroupsType::Slice {
662 groups,
663 overlapping: _,
664 monotonic: _,
665 } => groups
666 .into_par_iter()
667 .map(|[start, len]| {
668 let idx = f_range(*start..*start + *len, validity, *len as usize);
669 (idx.first().copied().unwrap_or(*start), idx)
670 })
671 .collect(),
672 })
673 .into_sliceable();
674 ac.with_groups(positions);
675 });
676
677 Ok(ac)
678}
679
/// Forward-fills nulls within every group.
///
/// `limit` caps how many consecutive nulls may be filled from one valid value.
/// `None` is treated as `IdxSize::MAX`. Filling works by rewriting each
/// group's gather indices: a null position is pointed at the most recent
/// valid position of the same group. Leading nulls (before the first valid
/// value) and nulls beyond the limit keep their own index, i.e. stay null.
pub fn forward_fill_null<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    limit: Option<IdxSize>,
) -> PolarsResult<AggregationContext<'a>> {
    let limit = limit.unwrap_or(IdxSize::MAX);
    // A macro rather than a closure: the same body is instantiated for two
    // different iterator types (a copied slice iterator for index groups and
    // `Range<IdxSize>` for slice groups).
    macro_rules! arg_forward_fill {
        (
            $iter:ident,
            $validity:ident,
            $length:ident
        ) => {{
            |$iter, $validity, $length| {
                // Offset of the first valid row. An all-null group is returned as-is.
                let Some(start) = $iter
                    .clone()
                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
                else {
                    return $iter.collect();
                };

                let mut idx = UnitVec::with_capacity($length);
                let mut iter = $iter;
                // Leading nulls have nothing to fill from: keep their own indices.
                idx.extend((&mut iter).take(start));

                let mut current_limit = limit;
                // The first valid row; cannot fail since `position` found it.
                let mut value = iter.next().unwrap();
                idx.push(value);

                idx.extend(iter.map(|i| {
                    if unsafe { $validity.get_bit_unchecked(i as usize) } {
                        // A valid row becomes the new fill source and resets the budget.
                        current_limit = limit;
                        value = i;
                        i
                    } else if current_limit == 0 {
                        // Fill budget exhausted: the null keeps its own index.
                        i
                    } else {
                        current_limit -= 1;
                        value
                    }
                }));
                idx
            }
        }};
    }

    fw_bw_fill_null(
        inputs,
        df,
        groups,
        state,
        arg_forward_fill!(iter, validity, length),
        arg_forward_fill!(iter, validity, length),
    )
}
736
/// Backward-fills nulls within every group.
///
/// `limit` caps how many consecutive nulls may be filled from one valid value.
/// `None` is treated as `IdxSize::MAX`. Filling works by rewriting each
/// group's gather indices: a null position is pointed at the next valid
/// position of the same group. Trailing nulls (after the last valid value) and
/// nulls beyond the limit keep their own index, i.e. stay null.
pub fn backward_fill_null<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    limit: Option<IdxSize>,
) -> PolarsResult<AggregationContext<'a>> {
    let limit = limit.unwrap_or(IdxSize::MAX);
    // A macro rather than a closure: the same body is instantiated for two
    // different iterator types (a copied slice iterator for index groups and
    // `Range<IdxSize>` for slice groups).
    macro_rules! arg_backward_fill {
        (
            $iter:ident,
            $validity:ident,
            $length:ident
        ) => {{
            |$iter, $validity, $length| {
                // Number of trailing nulls, i.e. the distance from the end to
                // the last valid row. An all-null group is returned as-is.
                let Some(start) = $iter
                    .clone()
                    .rev()
                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
                else {
                    return $iter.collect();
                };

                let mut idx = UnitVec::from_iter($iter);
                let mut current_limit = limit;
                // The last valid row of the group.
                let mut value = idx[$length - start - 1];
                // Walk backwards over everything up to and including the last
                // valid row. Trailing nulls have nothing after them to fill from.
                for i in idx[..$length - start].iter_mut().rev() {
                    if unsafe { $validity.get_bit_unchecked(*i as usize) } {
                        // A valid row becomes the new fill source and resets the budget.
                        current_limit = limit;
                        value = *i;
                    } else if current_limit != 0 {
                        current_limit -= 1;
                        *i = value;
                    }
                    // Otherwise the budget is exhausted and the null keeps its own index.
                }

                idx
            }
        }};
    }

    fw_bw_fill_null(
        inputs,
        df,
        groups,
        state,
        arg_backward_fill!(iter, validity, length),
        arg_backward_fill!(iter, validity, length),
    )
}