// polars_expr/dispatch/groups_dispatch.rs

1use std::borrow::Cow;
2use std::sync::Arc;
3
4use arrow::array::PrimitiveArray;
5use arrow::bitmap::Bitmap;
6use arrow::bitmap::bitmask::BitMask;
7use arrow::trusted_len::TrustMyLength;
8use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
9use polars_core::POOL;
10use polars_core::error::{PolarsResult, polars_bail, polars_ensure};
11use polars_core::frame::DataFrame;
12use polars_core::prelude::row_encode::encode_rows_unordered;
13use polars_core::prelude::{
14    AnyValue, ChunkCast, Column, CompatLevel, Float64Chunked, GroupPositions, GroupsType,
15    IDX_DTYPE, IntoColumn,
16};
17use polars_core::scalar::Scalar;
18use polars_core::series::{ChunkCompareEq, Series};
19use polars_utils::itertools::Itertools;
20use polars_utils::pl_str::PlSmallStr;
21use polars_utils::{IdxSize, UnitVec};
22use rayon::iter::{IntoParallelIterator, ParallelIterator};
23
24use crate::prelude::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
25use crate::state::ExecutionState;
26
27pub fn reverse<'a>(
28    inputs: &[Arc<dyn PhysicalExpr>],
29    df: &DataFrame,
30    groups: &'a GroupPositions,
31    state: &ExecutionState,
32) -> PolarsResult<AggregationContext<'a>> {
33    assert_eq!(inputs.len(), 1);
34
35    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
36
37    // Length preserving operation on scalars keeps scalar.
38    if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &ac.agg_state() {
39        return Ok(ac);
40    }
41
42    POOL.install(|| {
43        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
44            GroupsType::Idx(idx) => idx
45                .into_par_iter()
46                .map(|(first, idx)| {
47                    (
48                        idx.last().copied().unwrap_or(first),
49                        idx.iter().copied().rev().collect(),
50                    )
51                })
52                .collect(),
53            GroupsType::Slice {
54                groups,
55                overlapping: _,
56                monotonic: _,
57            } => groups
58                .into_par_iter()
59                .map(|[start, len]| {
60                    (
61                        start + len.saturating_sub(1),
62                        (*start..*start + *len).rev().collect(),
63                    )
64                })
65                .collect(),
66        })
67        .into_sliceable();
68        ac.with_groups(positions);
69    });
70
71    Ok(ac)
72}
73
74pub fn null_count<'a>(
75    inputs: &[Arc<dyn PhysicalExpr>],
76    df: &DataFrame,
77    groups: &'a GroupPositions,
78    state: &ExecutionState,
79) -> PolarsResult<AggregationContext<'a>> {
80    assert_eq!(inputs.len(), 1);
81
82    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
83
84    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
85        *s = s.is_null().cast(&IDX_DTYPE).unwrap().into_column();
86        return Ok(ac);
87    }
88
89    ac.groups();
90    let values = ac.flat_naive();
91    let name = values.name().clone();
92    let Some(validity) = values.rechunk_validity() else {
93        ac.state = AggState::AggregatedScalar(Column::new_scalar(
94            name,
95            (0 as IdxSize).into(),
96            groups.len(),
97        ));
98        return Ok(ac);
99    };
100
101    POOL.install(|| {
102        let validity = BitMask::from_bitmap(&validity);
103        let null_count: Vec<IdxSize> = match &**ac.groups.as_ref() {
104            GroupsType::Idx(idx) => idx
105                .into_par_iter()
106                .map(|(_, idx)| {
107                    idx.iter()
108                        .map(|i| IdxSize::from(!unsafe { validity.get_bit_unchecked(*i as usize) }))
109                        .sum::<IdxSize>()
110                })
111                .collect(),
112            GroupsType::Slice {
113                groups,
114                overlapping: _,
115                monotonic: _,
116            } => groups
117                .into_par_iter()
118                .map(|[start, length]| {
119                    unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
120                        .unset_bits() as IdxSize
121                })
122                .collect(),
123        };
124
125        ac.state = AggState::AggregatedScalar(Column::new(name, null_count));
126    });
127
128    Ok(ac)
129}
130
131pub fn any<'a>(
132    inputs: &[Arc<dyn PhysicalExpr>],
133    df: &DataFrame,
134    groups: &'a GroupPositions,
135    state: &ExecutionState,
136    ignore_nulls: bool,
137) -> PolarsResult<AggregationContext<'a>> {
138    assert_eq!(inputs.len(), 1);
139
140    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
141
142    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
143        if ignore_nulls {
144            *s = s
145                .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
146                .unwrap()
147                .into_column();
148        } else {
149            *s = s
150                .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
151                .unwrap()
152                .into_column();
153        }
154        return Ok(ac);
155    }
156
157    ac.groups();
158    let values = ac.flat_naive();
159    let values = values.bool()?;
160    let out = unsafe { values.agg_any(ac.groups.as_ref(), ignore_nulls) };
161    ac.state = AggState::AggregatedScalar(out.into_column());
162
163    Ok(ac)
164}
165
166pub fn all<'a>(
167    inputs: &[Arc<dyn PhysicalExpr>],
168    df: &DataFrame,
169    groups: &'a GroupPositions,
170    state: &ExecutionState,
171    ignore_nulls: bool,
172) -> PolarsResult<AggregationContext<'a>> {
173    assert_eq!(inputs.len(), 1);
174
175    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
176
177    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
178        if ignore_nulls {
179            *s = s
180                .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
181                .unwrap()
182                .into_column();
183        } else {
184            *s = s
185                .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
186                .unwrap()
187                .into_column();
188        }
189        return Ok(ac);
190    }
191
192    ac.groups();
193    let values = ac.flat_naive();
194    let values = values.bool()?;
195    let out = unsafe { values.agg_all(ac.groups.as_ref(), ignore_nulls) };
196    ac.state = AggState::AggregatedScalar(out.into_column());
197
198    Ok(ac)
199}
200
/// Shared driver for the group-wise bitwise reductions (`and`, `or`, `xor`).
///
/// `op` is only used to build the error message. When the input already holds
/// one value per group (`AggregatedScalar` / `LiteralScalar`), the context is
/// returned unchanged after checking that the dtype is boolean or primitive
/// numeric. Otherwise `f` is applied to the flat values and the materialized
/// groups, and its output becomes an `AggregatedScalar`.
#[cfg(feature = "bitwise")]
pub fn bitwise_agg<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    op: &'static str,
    f: impl Fn(&Column, &GroupsType) -> Column,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);

    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(col) | AggState::LiteralScalar(col) = &ac.state {
        let dtype = col.dtype();
        let supported = dtype.is_bool() || dtype.is_primitive_numeric();
        polars_ensure!(supported, op = op, dtype);
        return Ok(ac);
    }

    ac.groups();
    let flat = ac.flat_naive();
    let reduced = f(flat.as_ref(), ac.groups.as_ref()).into_column();
    ac.state = AggState::AggregatedScalar(reduced);

    Ok(ac)
}

/// Group-wise bitwise AND reduction (`and_reduce`).
#[cfg(feature = "bitwise")]
pub fn bitwise_and<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // SAFETY (assumed): `bitwise_agg` hands over groups materialized for `values`.
    let reduce = |values: &Column, grouping: &GroupsType| unsafe { values.agg_and(grouping) };
    bitwise_agg(inputs, df, groups, state, "and_reduce", reduce)
}

/// Group-wise bitwise OR reduction (`or_reduce`).
#[cfg(feature = "bitwise")]
pub fn bitwise_or<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // SAFETY (assumed): `bitwise_agg` hands over groups materialized for `values`.
    let reduce = |values: &Column, grouping: &GroupsType| unsafe { values.agg_or(grouping) };
    bitwise_agg(inputs, df, groups, state, "or_reduce", reduce)
}

/// Group-wise bitwise XOR reduction (`xor_reduce`).
#[cfg(feature = "bitwise")]
pub fn bitwise_xor<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // SAFETY (assumed): `bitwise_agg` hands over groups materialized for `values`.
    let reduce = |values: &Column, grouping: &GroupsType| unsafe { values.agg_xor(grouping) };
    bitwise_agg(inputs, df, groups, state, "xor_reduce", reduce)
}

278pub fn drop_items<'a>(
279    mut ac: AggregationContext<'a>,
280    predicate: &Bitmap,
281) -> PolarsResult<AggregationContext<'a>> {
282    // No elements are filtered out.
283    if predicate.unset_bits() == 0 {
284        if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
285            *c = c.as_list().into_column();
286            if c.len() == 1 && ac.groups.len() != 1 {
287                *c = c.new_from_index(0, ac.groups.len());
288            }
289            ac.state = AggState::AggregatedList(std::mem::take(c));
290            ac.update_groups = UpdateGroups::WithSeriesLen;
291        }
292        return Ok(ac);
293    }
294
295    ac.set_original_len(false);
296
297    // All elements are filtered out.
298    if predicate.set_bits() == 0 {
299        let name = ac.agg_state().name();
300        let dtype = ac.agg_state().flat_dtype();
301
302        ac.state = AggState::AggregatedList(Column::new_scalar(
303            name.clone(),
304            Scalar::new(
305                dtype.clone().implode(),
306                AnyValue::List(Series::new_empty(PlSmallStr::EMPTY, dtype)),
307            ),
308            ac.groups.len(),
309        ));
310        ac.with_update_groups(UpdateGroups::WithSeriesLen);
311        return Ok(ac);
312    }
313
314    if let AggState::AggregatedScalar(c) = &mut ac.state {
315        ac.state = AggState::NotAggregated(std::mem::take(c));
316        ac.groups = Cow::Owned(
317            {
318                let groups = predicate
319                    .iter()
320                    .enumerate_idx()
321                    .map(|(i, p)| [i, IdxSize::from(p)])
322                    .collect();
323                GroupsType::new_slice(groups, false, true)
324            }
325            .into_sliceable(),
326        );
327        ac.update_groups = UpdateGroups::No;
328        return Ok(ac);
329    }
330
331    ac.groups();
332    let predicate = BitMask::from_bitmap(predicate);
333    POOL.install(|| {
334        let positions = GroupsType::Idx(match &**ac.groups.as_ref() {
335            GroupsType::Idx(idxs) => idxs
336                .into_par_iter()
337                .map(|(fst, idxs)| {
338                    let out = idxs
339                        .iter()
340                        .copied()
341                        .filter(|i| unsafe { predicate.get_bit_unchecked(*i as usize) })
342                        .collect::<UnitVec<IdxSize>>();
343                    (out.first().copied().unwrap_or(fst), out)
344                })
345                .collect(),
346            GroupsType::Slice {
347                groups,
348                overlapping: _,
349                monotonic: _,
350            } => groups
351                .into_par_iter()
352                .map(|[start, length]| {
353                    let predicate =
354                        unsafe { predicate.sliced_unchecked(*start as usize, *length as usize) };
355                    let num_values = predicate.set_bits();
356
357                    if num_values == 0 {
358                        (*start, UnitVec::new())
359                    } else if num_values == 1 {
360                        let item = *start + predicate.leading_zeros() as IdxSize;
361                        let mut out = UnitVec::with_capacity(1);
362                        out.push(item);
363                        (item, out)
364                    } else if num_values == *length as usize {
365                        (*start, (*start..*start + *length).collect())
366                    } else {
367                        let out = unsafe {
368                            TrustMyLength::new(
369                                (0..*length)
370                                    .filter(|i| predicate.get_bit_unchecked(*i as usize))
371                                    .map(|i| i + *start),
372                                num_values,
373                            )
374                        };
375                        let out = out.collect::<UnitVec<IdxSize>>();
376
377                        (out.first().copied().unwrap(), out)
378                    }
379                })
380                .collect(),
381        })
382        .into_sliceable();
383        ac.with_groups(positions);
384    });
385
386    Ok(ac)
387}
388
389pub fn drop_nans<'a>(
390    inputs: &[Arc<dyn PhysicalExpr>],
391    df: &DataFrame,
392    groups: &'a GroupPositions,
393    state: &ExecutionState,
394) -> PolarsResult<AggregationContext<'a>> {
395    assert_eq!(inputs.len(), 1);
396    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
397    ac.groups();
398    let predicate = if ac.agg_state().flat_dtype().is_float() {
399        let values = ac.flat_naive();
400        let mut values = values.is_nan().unwrap();
401        values.rechunk_mut();
402        values.downcast_as_array().values().clone()
403    } else {
404        Bitmap::new_with_value(false, 1)
405    };
406    let predicate = !&predicate;
407    drop_items(ac, &predicate)
408}
409
410pub fn drop_nulls<'a>(
411    inputs: &[Arc<dyn PhysicalExpr>],
412    df: &DataFrame,
413    groups: &'a GroupPositions,
414    state: &ExecutionState,
415) -> PolarsResult<AggregationContext<'a>> {
416    assert_eq!(inputs.len(), 1);
417    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
418    ac.groups();
419    let predicate = ac.flat_naive().as_ref().clone();
420    let predicate = predicate.rechunk_to_arrow(CompatLevel::newest());
421    let predicate = predicate
422        .validity()
423        .cloned()
424        .unwrap_or(Bitmap::new_with_value(true, 1));
425    drop_items(ac, &predicate)
426}
427
/// Shared implementation of the group-wise moment aggregations (`skew`,
/// `kurtosis`).
///
/// The values must be `Float64`; otherwise `f64()` returns an error. For every
/// group a fresh `S::default()` is fed with the group's values, and `finalize`
/// turns it into the group's result (`None` becomes null).
///
/// * `insert_one` adds one value to a state. On the `Idx` path, only non-null
///   values are inserted.
/// * `new_from_slice(arr, start, len)` builds a state from a contiguous range.
///   It is used for `Slice` groups only, so null handling there is left to the
///   callback.
/// * `finalize` turns a state into the output value.
#[cfg(feature = "moment")]
pub fn moment_agg<'a, S: Default>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,

    insert_one: impl Fn(&mut S, f64) + Send + Sync,
    new_from_slice: impl Fn(&PrimitiveArray<f64>, usize, usize) -> S + Send + Sync,
    finalize: impl Fn(S) -> Option<f64> + Send + Sync,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);
    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    // One value per group: each result is the moment of a single observation.
    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
        let ca = s.f64()?;
        let of_single = |v: f64| {
            let mut st = S::default();
            insert_one(&mut st, v);
            finalize(st)
        };
        let out: Float64Chunked = ca.iter().map(|opt| opt.and_then(&of_single)).collect();
        *s = out.with_name(ca.name().clone()).into_column();
        return Ok(ac);
    }

    ac.groups();

    let name = ac.get_values().name().clone();
    let flat = ac.flat_naive();
    let flat = flat.f64()?;
    let flat = flat.rechunk();
    let arr = flat.downcast_as_array();

    let out = POOL.install(|| match &**ac.groups.as_ref() {
        GroupsType::Idx(idx) => {
            let values = arr.values();
            match arr.validity().filter(|v| v.unset_bits() > 0) {
                Some(validity) => idx
                    .into_par_iter()
                    .map(|(_, members)| {
                        let st = members
                            .iter()
                            .map(|&i| i as usize)
                            // SAFETY: group indices address rows of `arr`.
                            .filter(|&i| unsafe { validity.get_bit_unchecked(i) })
                            .fold(S::default(), |mut st, i| {
                                insert_one(&mut st, values[i]);
                                st
                            });
                        finalize(st)
                    })
                    .collect::<Float64Chunked>(),
                None => idx
                    .into_par_iter()
                    .map(|(_, members)| {
                        let st = members.iter().fold(S::default(), |mut st, &i| {
                            insert_one(&mut st, values[i as usize]);
                            st
                        });
                        finalize(st)
                    })
                    .collect::<Float64Chunked>(),
            }
        },
        GroupsType::Slice { groups, .. } => groups
            .into_par_iter()
            .map(|&[start, len]| finalize(new_from_slice(arr, start as usize, len as usize)))
            .collect::<Float64Chunked>(),
    });

    ac.state = AggState::AggregatedScalar(out.with_name(name).into_column());
    Ok(ac)
}

/// Group-wise skewness; `bias` is forwarded to `SkewState::finalize`.
#[cfg(feature = "moment")]
pub fn skew<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::SkewState;

    let finish = move |s: SkewState| s.finalize(bias);
    moment_agg::<SkewState>(
        inputs,
        df,
        groups,
        state,
        SkewState::insert_one,
        SkewState::from_array,
        finish,
    )
}

/// Group-wise kurtosis; `fisher` and `bias` are forwarded to
/// `KurtosisState::finalize`.
#[cfg(feature = "moment")]
pub fn kurtosis<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    fisher: bool,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::KurtosisState;

    let finish = move |s: KurtosisState| s.finalize(fisher, bias);
    moment_agg::<KurtosisState>(
        inputs,
        df,
        groups,
        state,
        KurtosisState::insert_one,
        KurtosisState::from_array,
        finish,
    )
}

548pub fn unique<'a>(
549    inputs: &[Arc<dyn PhysicalExpr>],
550    df: &DataFrame,
551    groups: &'a GroupPositions,
552    state: &ExecutionState,
553    stable: bool,
554) -> PolarsResult<AggregationContext<'a>> {
555    _ = stable;
556
557    assert_eq!(inputs.len(), 1);
558    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
559    ac.groups();
560
561    if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
562        *c = c.as_list().into_column();
563        if c.len() == 1 && ac.groups.len() != 1 {
564            *c = c.new_from_index(0, ac.groups.len());
565        }
566        ac.state = AggState::AggregatedList(std::mem::take(c));
567        ac.update_groups = UpdateGroups::WithSeriesLen;
568        return Ok(ac);
569    }
570
571    let values = ac.flat_naive().to_physical_repr();
572    let dtype = values.dtype();
573    let values = if dtype.contains_objects() {
574        polars_bail!(opq = unique, dtype);
575    } else if let Some(ca) = values.try_str() {
576        ca.as_binary().into_column()
577    } else if dtype.is_nested() {
578        encode_rows_unordered(&[values])?.into_column()
579    } else {
580        values
581    };
582
583    let values = values.rechunk_to_arrow(CompatLevel::newest());
584    let values = values.as_ref();
585    let state = amortized_unique_from_dtype(values.dtype());
586
587    struct CloneWrapper(Box<dyn AmortizedUnique>);
588    impl Clone for CloneWrapper {
589        fn clone(&self) -> Self {
590            Self(self.0.new_empty())
591        }
592    }
593
594    POOL.install(|| {
595        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
596            GroupsType::Idx(idx) => idx
597                .into_par_iter()
598                .map_with(CloneWrapper(state), |state, (first, idx)| {
599                    let mut idx = idx.clone();
600                    unsafe { state.0.retain_unique(values, &mut idx) };
601                    (idx.first().copied().unwrap_or(first), idx)
602                })
603                .collect(),
604            GroupsType::Slice {
605                groups,
606                overlapping: _,
607                monotonic: _,
608            } => groups
609                .into_par_iter()
610                .map_with(CloneWrapper(state), |state, [start, len]| {
611                    let mut idx = UnitVec::new();
612                    state.0.arg_unique(values, &mut idx, *start, *len);
613                    (idx.first().copied().unwrap_or(*start), idx)
614                })
615                .collect(),
616        })
617        .into_sliceable();
618        ac.with_groups(positions);
619    });
620
621    Ok(ac)
622}
623
624fn fw_bw_fill_null<'a>(
625    inputs: &[Arc<dyn PhysicalExpr>],
626    df: &DataFrame,
627    groups: &'a GroupPositions,
628    state: &ExecutionState,
629    f_idx: impl Fn(
630        std::iter::Copied<std::slice::Iter<'_, IdxSize>>,
631        BitMask<'_>,
632        usize,
633    ) -> UnitVec<IdxSize>
634    + Send
635    + Sync,
636    f_range: impl Fn(std::ops::Range<IdxSize>, BitMask<'_>, usize) -> UnitVec<IdxSize> + Send + Sync,
637) -> PolarsResult<AggregationContext<'a>> {
638    assert_eq!(inputs.len(), 1);
639    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
640    ac.groups();
641
642    if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &mut ac.state {
643        return Ok(ac);
644    }
645
646    let values = ac.flat_naive();
647    let Some(validity) = values.rechunk_validity() else {
648        return Ok(ac);
649    };
650
651    let validity = BitMask::from_bitmap(&validity);
652    POOL.install(|| {
653        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
654            GroupsType::Idx(idx) => idx
655                .into_par_iter()
656                .map(|(first, idx)| {
657                    let idx = f_idx(idx.iter().copied(), validity, idx.len());
658                    (idx.first().copied().unwrap_or(first), idx)
659                })
660                .collect(),
661            GroupsType::Slice {
662                groups,
663                overlapping: _,
664                monotonic: _,
665            } => groups
666                .into_par_iter()
667                .map(|[start, len]| {
668                    let idx = f_range(*start..*start + *len, validity, *len as usize);
669                    (idx.first().copied().unwrap_or(*start), idx)
670                })
671                .collect(),
672        })
673        .into_sliceable();
674        ac.with_groups(positions);
675    });
676
677    Ok(ac)
678}
679
680pub fn forward_fill_null<'a>(
681    inputs: &[Arc<dyn PhysicalExpr>],
682    df: &DataFrame,
683    groups: &'a GroupPositions,
684    state: &ExecutionState,
685    limit: Option<IdxSize>,
686) -> PolarsResult<AggregationContext<'a>> {
687    let limit = limit.unwrap_or(IdxSize::MAX);
688    macro_rules! arg_forward_fill {
689        (
690            $iter:ident,
691            $validity:ident,
692            $length:ident
693        ) => {{
694            |$iter, $validity, $length| {
695                let Some(start) = $iter
696                    .clone()
697                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
698                else {
699                    return $iter.collect();
700                };
701
702                let mut idx = UnitVec::with_capacity($length);
703                let mut iter = $iter;
704                idx.extend((&mut iter).take(start));
705
706                let mut current_limit = limit;
707                let mut value = iter.next().unwrap();
708                idx.push(value);
709
710                idx.extend(iter.map(|i| {
711                    if unsafe { $validity.get_bit_unchecked(i as usize) } {
712                        current_limit = limit;
713                        value = i;
714                        i
715                    } else if current_limit == 0 {
716                        i
717                    } else {
718                        current_limit -= 1;
719                        value
720                    }
721                }));
722                idx
723            }
724        }};
725    }
726
727    fw_bw_fill_null(
728        inputs,
729        df,
730        groups,
731        state,
732        arg_forward_fill!(iter, validity, length),
733        arg_forward_fill!(iter, validity, length),
734    )
735}
736
737pub fn backward_fill_null<'a>(
738    inputs: &[Arc<dyn PhysicalExpr>],
739    df: &DataFrame,
740    groups: &'a GroupPositions,
741    state: &ExecutionState,
742    limit: Option<IdxSize>,
743) -> PolarsResult<AggregationContext<'a>> {
744    let limit = limit.unwrap_or(IdxSize::MAX);
745    macro_rules! arg_backward_fill {
746        (
747            $iter:ident,
748            $validity:ident,
749            $length:ident
750        ) => {{
751            |$iter, $validity, $length| {
752                let Some(start) = $iter
753                    .clone()
754                    .rev()
755                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
756                else {
757                    return $iter.collect();
758                };
759
760                let mut idx = UnitVec::from_iter($iter);
761                let mut current_limit = limit;
762                let mut value = idx[$length - start - 1];
763                for i in idx[..$length - start].iter_mut().rev() {
764                    if unsafe { $validity.get_bit_unchecked(*i as usize) } {
765                        current_limit = limit;
766                        value = *i;
767                    } else if current_limit != 0 {
768                        current_limit -= 1;
769                        *i = value;
770                    }
771                }
772
773                idx
774            }
775        }};
776    }
777
778    fw_bw_fill_null(
779        inputs,
780        df,
781        groups,
782        state,
783        arg_backward_fill!(iter, validity, length),
784        arg_backward_fill!(iter, validity, length),
785    )
786}