polars_expr/dispatch/
groups_dispatch.rs

1use std::borrow::Cow;
2use std::sync::Arc;
3
4use arrow::array::PrimitiveArray;
5use arrow::bitmap::Bitmap;
6use arrow::bitmap::bitmask::BitMask;
7use arrow::trusted_len::TrustMyLength;
8use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
9use polars_core::POOL;
10use polars_core::error::{PolarsResult, polars_bail, polars_ensure};
11use polars_core::frame::DataFrame;
12use polars_core::prelude::row_encode::encode_rows_unordered;
13use polars_core::prelude::{
14    AnyValue, ChunkCast, Column, CompatLevel, Float64Chunked, GroupPositions, GroupsType,
15    IDX_DTYPE, IntoColumn,
16};
17use polars_core::scalar::Scalar;
18use polars_core::series::{ChunkCompareEq, Series};
19use polars_utils::itertools::Itertools;
20use polars_utils::pl_str::PlSmallStr;
21use polars_utils::{IdxSize, UnitVec};
22use rayon::iter::{IntoParallelIterator, ParallelIterator};
23
24use crate::prelude::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
25use crate::state::ExecutionState;
26
27pub fn reverse<'a>(
28    inputs: &[Arc<dyn PhysicalExpr>],
29    df: &DataFrame,
30    groups: &'a GroupPositions,
31    state: &ExecutionState,
32) -> PolarsResult<AggregationContext<'a>> {
33    assert_eq!(inputs.len(), 1);
34
35    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
36
37    // Length preserving operation on scalars keeps scalar.
38    if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &ac.agg_state() {
39        return Ok(ac);
40    }
41
42    POOL.install(|| {
43        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
44            GroupsType::Idx(idx) => idx
45                .into_par_iter()
46                .map(|(first, idx)| {
47                    (
48                        idx.last().copied().unwrap_or(first),
49                        idx.iter().copied().rev().collect(),
50                    )
51                })
52                .collect(),
53            GroupsType::Slice {
54                groups,
55                overlapping: _,
56            } => groups
57                .into_par_iter()
58                .map(|[start, len]| {
59                    (
60                        start + len.saturating_sub(1),
61                        (*start..*start + *len).rev().collect(),
62                    )
63                })
64                .collect(),
65        })
66        .into_sliceable();
67        ac.with_groups(positions);
68    });
69
70    Ok(ac)
71}
72
73pub fn null_count<'a>(
74    inputs: &[Arc<dyn PhysicalExpr>],
75    df: &DataFrame,
76    groups: &'a GroupPositions,
77    state: &ExecutionState,
78) -> PolarsResult<AggregationContext<'a>> {
79    assert_eq!(inputs.len(), 1);
80
81    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
82
83    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
84        *s = s.is_null().cast(&IDX_DTYPE).unwrap().into_column();
85        return Ok(ac);
86    }
87
88    ac.groups();
89    let values = ac.flat_naive();
90    let name = values.name().clone();
91    let Some(validity) = values.rechunk_validity() else {
92        ac.state = AggState::AggregatedScalar(Column::new_scalar(
93            name,
94            (0 as IdxSize).into(),
95            groups.len(),
96        ));
97        return Ok(ac);
98    };
99
100    POOL.install(|| {
101        let validity = BitMask::from_bitmap(&validity);
102        let null_count: Vec<IdxSize> = match &**ac.groups.as_ref() {
103            GroupsType::Idx(idx) => idx
104                .into_par_iter()
105                .map(|(_, idx)| {
106                    idx.iter()
107                        .map(|i| IdxSize::from(!unsafe { validity.get_bit_unchecked(*i as usize) }))
108                        .sum::<IdxSize>()
109                })
110                .collect(),
111            GroupsType::Slice {
112                groups,
113                overlapping: _,
114            } => groups
115                .into_par_iter()
116                .map(|[start, length]| {
117                    unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
118                        .unset_bits() as IdxSize
119                })
120                .collect(),
121        };
122
123        ac.state = AggState::AggregatedScalar(Column::new(name, null_count));
124    });
125
126    Ok(ac)
127}
128
129pub fn any<'a>(
130    inputs: &[Arc<dyn PhysicalExpr>],
131    df: &DataFrame,
132    groups: &'a GroupPositions,
133    state: &ExecutionState,
134    ignore_nulls: bool,
135) -> PolarsResult<AggregationContext<'a>> {
136    assert_eq!(inputs.len(), 1);
137
138    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
139
140    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
141        if ignore_nulls {
142            *s = s
143                .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
144                .unwrap()
145                .into_column();
146        } else {
147            *s = s
148                .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
149                .unwrap()
150                .into_column();
151        }
152        return Ok(ac);
153    }
154
155    ac.groups();
156    let values = ac.flat_naive();
157    let values = values.bool()?;
158    let out = unsafe { values.agg_any(ac.groups.as_ref(), ignore_nulls) };
159    ac.state = AggState::AggregatedScalar(out.into_column());
160
161    Ok(ac)
162}
163
164pub fn all<'a>(
165    inputs: &[Arc<dyn PhysicalExpr>],
166    df: &DataFrame,
167    groups: &'a GroupPositions,
168    state: &ExecutionState,
169    ignore_nulls: bool,
170) -> PolarsResult<AggregationContext<'a>> {
171    assert_eq!(inputs.len(), 1);
172
173    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
174
175    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
176        if ignore_nulls {
177            *s = s
178                .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
179                .unwrap()
180                .into_column();
181        } else {
182            *s = s
183                .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
184                .unwrap()
185                .into_column();
186        }
187        return Ok(ac);
188    }
189
190    ac.groups();
191    let values = ac.flat_naive();
192    let values = values.bool()?;
193    let out = unsafe { values.agg_all(ac.groups.as_ref(), ignore_nulls) };
194    ac.state = AggState::AggregatedScalar(out.into_column());
195
196    Ok(ac)
197}
198
/// Shared driver for the per-group bitwise reductions (`and`/`or`/`xor`).
///
/// `op` is the operation name used in the dtype error, and `f` reduces the flattened values over
/// the groups. Scalar states hold a single value per group and are returned unchanged, but only
/// after checking that the dtype is boolean or primitive numeric.
#[cfg(feature = "bitwise")]
pub fn bitwise_agg<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    op: &'static str,
    f: impl Fn(&Column, &GroupsType) -> Column,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);

    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    match &ac.state {
        AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) => {
            let dtype = s.dtype();
            polars_ensure!(
                dtype.is_bool() | dtype.is_primitive_numeric(),
                op = op,
                dtype
            );
            Ok(ac)
        },
        _ => {
            ac.groups();
            let flat = ac.flat_naive();
            let reduced = f(flat.as_ref(), ac.groups.as_ref());
            ac.state = AggState::AggregatedScalar(reduced.into_column());
            Ok(ac)
        },
    }
}
229
/// Per-group bitwise AND (`and_reduce`) of boolean or primitive numeric values.
#[cfg(feature = "bitwise")]
pub fn bitwise_and<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // SAFETY: relies on `bitwise_agg` passing groups that are in bounds of `values`.
    let reduce = |values: &Column, grouping: &GroupsType| unsafe { values.agg_and(grouping) };
    bitwise_agg(inputs, df, groups, state, "and_reduce", reduce)
}
246
/// Per-group bitwise OR (`or_reduce`) of boolean or primitive numeric values.
#[cfg(feature = "bitwise")]
pub fn bitwise_or<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // SAFETY: relies on `bitwise_agg` passing groups that are in bounds of `values`.
    let reduce = |values: &Column, grouping: &GroupsType| unsafe { values.agg_or(grouping) };
    bitwise_agg(inputs, df, groups, state, "or_reduce", reduce)
}
258
/// Per-group bitwise XOR (`xor_reduce`) of boolean or primitive numeric values.
#[cfg(feature = "bitwise")]
pub fn bitwise_xor<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // SAFETY: relies on `bitwise_agg` passing groups that are in bounds of `values`.
    let reduce = |values: &Column, grouping: &GroupsType| unsafe { values.agg_xor(grouping) };
    bitwise_agg(inputs, df, groups, state, "xor_reduce", reduce)
}
275
276pub fn drop_items<'a>(
277    mut ac: AggregationContext<'a>,
278    predicate: &Bitmap,
279) -> PolarsResult<AggregationContext<'a>> {
280    // No elements are filtered out.
281    if predicate.unset_bits() == 0 {
282        if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
283            *c = c.as_list().into_column();
284        }
285        return Ok(ac);
286    }
287
288    ac.set_original_len(false);
289
290    // All elements are filtered out.
291    if predicate.set_bits() == 0 {
292        let name = ac.agg_state().name();
293        let dtype = ac.agg_state().flat_dtype();
294
295        ac.state = AggState::AggregatedList(Column::new_scalar(
296            name.clone(),
297            Scalar::new(
298                dtype.clone().implode(),
299                AnyValue::List(Series::new_empty(PlSmallStr::EMPTY, dtype)),
300            ),
301            ac.groups.len(),
302        ));
303        ac.with_update_groups(UpdateGroups::WithSeriesLen);
304        return Ok(ac);
305    }
306
307    if let AggState::LiteralScalar(c) = &ac.state {
308        ac.state =
309            AggState::AggregatedList(c.as_list().into_column().new_from_index(0, predicate.len()));
310        ac.groups = Cow::Owned(
311            GroupsType::Slice {
312                groups: predicate.iter().map(|p| [0, IdxSize::from(p)]).collect(),
313                overlapping: true,
314            }
315            .into_sliceable(),
316        );
317        return Ok(ac);
318    }
319
320    if let AggState::AggregatedScalar(c) = &mut ac.state {
321        ac.state = AggState::AggregatedList(c.as_list().into_column());
322        ac.groups = Cow::Owned(
323            GroupsType::Slice {
324                groups: predicate
325                    .iter()
326                    .enumerate_idx()
327                    .map(|(i, p)| [i, IdxSize::from(p)])
328                    .collect(),
329                overlapping: false,
330            }
331            .into_sliceable(),
332        );
333        return Ok(ac);
334    }
335
336    ac.groups();
337    let predicate = BitMask::from_bitmap(predicate);
338    POOL.install(|| {
339        let positions = GroupsType::Idx(match &**ac.groups.as_ref() {
340            GroupsType::Idx(idxs) => idxs
341                .into_par_iter()
342                .map(|(fst, idxs)| {
343                    let out = idxs
344                        .iter()
345                        .copied()
346                        .filter(|i| unsafe { predicate.get_bit_unchecked(*i as usize) })
347                        .collect::<UnitVec<IdxSize>>();
348                    (out.first().copied().unwrap_or(fst), out)
349                })
350                .collect(),
351            GroupsType::Slice {
352                groups,
353                overlapping: _,
354            } => groups
355                .into_par_iter()
356                .map(|[start, length]| {
357                    let predicate =
358                        unsafe { predicate.sliced_unchecked(*start as usize, *length as usize) };
359                    let num_values = predicate.set_bits();
360
361                    if num_values == 0 {
362                        (*start, UnitVec::new())
363                    } else if num_values == 1 {
364                        let item = *start + predicate.leading_zeros() as IdxSize;
365                        let mut out = UnitVec::with_capacity(1);
366                        out.push(item);
367                        (item, out)
368                    } else if num_values == *length as usize {
369                        (*start, (*start..*start + *length).collect())
370                    } else {
371                        let out = unsafe {
372                            TrustMyLength::new(
373                                (0..*length)
374                                    .filter(|i| predicate.get_bit_unchecked(*i as usize))
375                                    .map(|i| i + *start),
376                                num_values,
377                            )
378                        };
379                        let out = out.collect::<UnitVec<IdxSize>>();
380
381                        (out.first().copied().unwrap(), out)
382                    }
383                })
384                .collect(),
385        })
386        .into_sliceable();
387        ac.with_groups(positions);
388    });
389
390    Ok(ac)
391}
392
393pub fn drop_nans<'a>(
394    inputs: &[Arc<dyn PhysicalExpr>],
395    df: &DataFrame,
396    groups: &'a GroupPositions,
397    state: &ExecutionState,
398) -> PolarsResult<AggregationContext<'a>> {
399    assert_eq!(inputs.len(), 1);
400    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
401    ac.groups();
402    let predicate = if ac.agg_state().flat_dtype().is_float() {
403        let values = ac.flat_naive();
404        let mut values = values.is_nan().unwrap();
405        values.rechunk_mut();
406        values.downcast_as_array().values().clone()
407    } else {
408        Bitmap::new_with_value(true, 1)
409    };
410    drop_items(ac, &predicate)
411}
412
413pub fn drop_nulls<'a>(
414    inputs: &[Arc<dyn PhysicalExpr>],
415    df: &DataFrame,
416    groups: &'a GroupPositions,
417    state: &ExecutionState,
418) -> PolarsResult<AggregationContext<'a>> {
419    assert_eq!(inputs.len(), 1);
420    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
421    ac.groups();
422    let predicate = ac.flat_naive().as_ref().clone();
423    let predicate = predicate.rechunk_to_arrow(CompatLevel::newest());
424    let predicate = predicate
425        .validity()
426        .cloned()
427        .unwrap_or(Bitmap::new_with_value(true, 1));
428    drop_items(ac, &predicate)
429}
430
/// Shared driver for the grouped moment statistics (skewness, kurtosis).
///
/// * `insert_one` adds a single observation to an accumulator `S`.
/// * `new_from_slice(array, start, length)` builds an accumulator from a contiguous slice of
///   the flattened values.
/// * `finalize` turns an accumulator into the statistic; `None` becomes null.
///
/// The values must be `Float64` (the `f64()` downcasts error otherwise). Scalar states map null
/// to null and otherwise evaluate a one-element sample. For index groups, null entries are
/// skipped.
#[cfg(feature = "moment")]
pub fn moment_agg<'a, S: Default>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    insert_one: impl Fn(&mut S, f64) + Send + Sync,
    new_from_slice: impl Fn(&PrimitiveArray<f64>, usize, usize) -> S + Send + Sync,
    finalize: impl Fn(S) -> Option<f64> + Send + Sync,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);
    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
        let ca = s.f64()?;
        let name = ca.name().clone();
        let stats: Float64Chunked = ca
            .iter()
            .map(|opt_v| {
                opt_v.and_then(|v| {
                    let mut acc = S::default();
                    insert_one(&mut acc, v);
                    finalize(acc)
                })
            })
            .collect();
        *s = stats.with_name(name).into_column();
        return Ok(ac);
    }

    ac.groups();

    let name = ac.get_values().name().clone();
    let flat = ac.flat_naive();
    let floats = flat.f64()?;
    let floats = floats.rechunk();
    let arr = floats.downcast_as_array();

    let stats = POOL.install(|| match &**ac.groups.as_ref() {
        // Only consult the validity when it actually marks at least one null.
        GroupsType::Idx(idx) => match arr.validity().filter(|v| v.unset_bits() > 0) {
            Some(validity) => idx
                .into_par_iter()
                .map(|(_, members)| {
                    let mut acc = S::default();
                    members
                        .iter()
                        .map(|&i| i as usize)
                        // SAFETY: relies on the group indices being in bounds of `arr`.
                        .filter(|&i| unsafe { validity.get_bit_unchecked(i) })
                        .for_each(|i| insert_one(&mut acc, arr.values()[i]));
                    finalize(acc)
                })
                .collect::<Float64Chunked>(),
            None => idx
                .into_par_iter()
                .map(|(_, members)| {
                    let mut acc = S::default();
                    members
                        .iter()
                        .for_each(|&i| insert_one(&mut acc, arr.values()[i as usize]));
                    finalize(acc)
                })
                .collect::<Float64Chunked>(),
        },
        GroupsType::Slice { groups, .. } => groups
            .into_par_iter()
            .map(|&[start, length]| {
                finalize(new_from_slice(arr, start as usize, length as usize))
            })
            .collect::<Float64Chunked>(),
    });

    ac.state = AggState::AggregatedScalar(stats.with_name(name).into_column());
    Ok(ac)
}
508
/// Per-group skewness; `bias` is forwarded to `SkewState::finalize`.
#[cfg(feature = "moment")]
pub fn skew<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::SkewState;

    let finalize = move |acc: SkewState| acc.finalize(bias);
    moment_agg::<SkewState>(
        inputs,
        df,
        groups,
        state,
        SkewState::insert_one,
        SkewState::from_array,
        finalize,
    )
}
528
/// Per-group kurtosis; `fisher` and `bias` are forwarded to `KurtosisState::finalize`.
#[cfg(feature = "moment")]
pub fn kurtosis<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    fisher: bool,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::KurtosisState;

    let finalize = move |acc: KurtosisState| acc.finalize(fisher, bias);
    moment_agg::<KurtosisState>(
        inputs,
        df,
        groups,
        state,
        KurtosisState::insert_one,
        KurtosisState::from_array,
        finalize,
    )
}
549
550pub fn unique<'a>(
551    inputs: &[Arc<dyn PhysicalExpr>],
552    df: &DataFrame,
553    groups: &'a GroupPositions,
554    state: &ExecutionState,
555    stable: bool,
556) -> PolarsResult<AggregationContext<'a>> {
557    _ = stable;
558
559    assert_eq!(inputs.len(), 1);
560    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
561    ac.groups();
562
563    if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
564        *c = c.as_list().into_column();
565        return Ok(ac);
566    }
567
568    let values = ac.flat_naive().to_physical_repr();
569    let dtype = values.dtype();
570    let values = if dtype.contains_objects() {
571        polars_bail!(opq = unique, dtype);
572    } else if let Some(ca) = values.try_str() {
573        ca.as_binary().into_column()
574    } else if dtype.is_nested() {
575        encode_rows_unordered(&[values])?.into_column()
576    } else {
577        values
578    };
579
580    let values = values.rechunk_to_arrow(CompatLevel::newest());
581    let values = values.as_ref();
582    let state = amortized_unique_from_dtype(values.dtype());
583
584    struct CloneWrapper(Box<dyn AmortizedUnique>);
585    impl Clone for CloneWrapper {
586        fn clone(&self) -> Self {
587            Self(self.0.new_empty())
588        }
589    }
590
591    POOL.install(|| {
592        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
593            GroupsType::Idx(idx) => idx
594                .into_par_iter()
595                .map_with(CloneWrapper(state), |state, (first, idx)| {
596                    let mut idx = idx.clone();
597                    unsafe { state.0.retain_unique(values, &mut idx) };
598                    (idx.first().copied().unwrap_or(first), idx)
599                })
600                .collect(),
601            GroupsType::Slice {
602                groups,
603                overlapping: _,
604            } => groups
605                .into_par_iter()
606                .map_with(CloneWrapper(state), |state, [start, len]| {
607                    let mut idx = UnitVec::new();
608                    state.0.arg_unique(values, &mut idx, *start, *len);
609                    (idx.first().copied().unwrap_or(*start), idx)
610                })
611                .collect(),
612        })
613        .into_sliceable();
614        ac.with_groups(positions);
615    });
616
617    Ok(ac)
618}
619
620fn fw_bw_fill_null<'a>(
621    inputs: &[Arc<dyn PhysicalExpr>],
622    df: &DataFrame,
623    groups: &'a GroupPositions,
624    state: &ExecutionState,
625    f_idx: impl Fn(
626        std::iter::Copied<std::slice::Iter<'_, IdxSize>>,
627        BitMask<'_>,
628        usize,
629    ) -> UnitVec<IdxSize>
630    + Send
631    + Sync,
632    f_range: impl Fn(std::ops::Range<IdxSize>, BitMask<'_>, usize) -> UnitVec<IdxSize> + Send + Sync,
633) -> PolarsResult<AggregationContext<'a>> {
634    assert_eq!(inputs.len(), 1);
635    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
636    ac.groups();
637
638    if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &mut ac.state {
639        return Ok(ac);
640    }
641
642    let values = ac.flat_naive();
643    let Some(validity) = values.rechunk_validity() else {
644        return Ok(ac);
645    };
646
647    let validity = BitMask::from_bitmap(&validity);
648    POOL.install(|| {
649        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
650            GroupsType::Idx(idx) => idx
651                .into_par_iter()
652                .map(|(first, idx)| {
653                    let idx = f_idx(idx.iter().copied(), validity, idx.len());
654                    (idx.first().copied().unwrap_or(first), idx)
655                })
656                .collect(),
657            GroupsType::Slice {
658                groups,
659                overlapping: _,
660            } => groups
661                .into_par_iter()
662                .map(|[start, len]| {
663                    let idx = f_range(*start..*start + *len, validity, *len as usize);
664                    (idx.first().copied().unwrap_or(*start), idx)
665                })
666                .collect(),
667        })
668        .into_sliceable();
669        ac.with_groups(positions);
670    });
671
672    Ok(ac)
673}
674
675pub fn forward_fill_null<'a>(
676    inputs: &[Arc<dyn PhysicalExpr>],
677    df: &DataFrame,
678    groups: &'a GroupPositions,
679    state: &ExecutionState,
680    limit: Option<IdxSize>,
681) -> PolarsResult<AggregationContext<'a>> {
682    let limit = limit.unwrap_or(IdxSize::MAX);
683    macro_rules! arg_forward_fill {
684        (
685            $iter:ident,
686            $validity:ident,
687            $length:ident
688        ) => {{
689            |$iter, $validity, $length| {
690                let Some(start) = $iter
691                    .clone()
692                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
693                else {
694                    return $iter.collect();
695                };
696
697                let mut idx = UnitVec::with_capacity($length);
698                let mut iter = $iter;
699                idx.extend((&mut iter).take(start));
700
701                let mut current_limit = limit;
702                let mut value = iter.next().unwrap();
703                idx.push(value);
704
705                idx.extend(iter.map(|i| {
706                    if unsafe { $validity.get_bit_unchecked(i as usize) } {
707                        current_limit = limit;
708                        value = i;
709                        i
710                    } else if current_limit == 0 {
711                        i
712                    } else {
713                        current_limit -= 1;
714                        value
715                    }
716                }));
717                idx
718            }
719        }};
720    }
721
722    fw_bw_fill_null(
723        inputs,
724        df,
725        groups,
726        state,
727        arg_forward_fill!(iter, validity, length),
728        arg_forward_fill!(iter, validity, length),
729    )
730}
731
732pub fn backward_fill_null<'a>(
733    inputs: &[Arc<dyn PhysicalExpr>],
734    df: &DataFrame,
735    groups: &'a GroupPositions,
736    state: &ExecutionState,
737    limit: Option<IdxSize>,
738) -> PolarsResult<AggregationContext<'a>> {
739    let limit = limit.unwrap_or(IdxSize::MAX);
740    macro_rules! arg_backward_fill {
741        (
742            $iter:ident,
743            $validity:ident,
744            $length:ident
745        ) => {{
746            |$iter, $validity, $length| {
747                let Some(start) = $iter
748                    .clone()
749                    .rev()
750                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
751                else {
752                    return $iter.collect();
753                };
754
755                let mut idx = UnitVec::from_iter($iter);
756                let mut current_limit = limit;
757                let mut value = idx[$length - start - 1];
758                for i in idx[..$length - start].iter_mut().rev() {
759                    if unsafe { $validity.get_bit_unchecked(*i as usize) } {
760                        current_limit = limit;
761                        value = *i;
762                    } else if current_limit != 0 {
763                        current_limit -= 1;
764                        *i = value;
765                    }
766                }
767
768                idx
769            }
770        }};
771    }
772
773    fw_bw_fill_null(
774        inputs,
775        df,
776        groups,
777        state,
778        arg_backward_fill!(iter, validity, length),
779        arg_backward_fill!(iter, validity, length),
780    )
781}