Skip to main content

polars_expr/dispatch/
groups_dispatch.rs

1use std::borrow::Cow;
2use std::sync::Arc;
3
4use arrow::array::PrimitiveArray;
5use arrow::bitmap::Bitmap;
6use arrow::bitmap::bitmask::BitMask;
7use arrow::trusted_len::TrustMyLength;
8use polars_compute::rolling::QuantileMethod;
9use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
10use polars_core::error::{PolarsResult, polars_bail, polars_ensure};
11use polars_core::frame::DataFrame;
12use polars_core::prelude::row_encode::encode_rows_unordered;
13use polars_core::prelude::{
14    AnyValue, BooleanChunked, ChunkCast, Column, CompatLevel, Float64Chunked, GroupPositions,
15    GroupsType, IDX_DTYPE, IntoColumn,
16};
17use polars_core::runtime::RAYON;
18use polars_core::scalar::Scalar;
19use polars_core::series::{ChunkCompareEq, Series};
20use polars_utils::itertools::Itertools;
21use polars_utils::pl_str::PlSmallStr;
22use polars_utils::{IdxSize, UnitVec};
23use rayon::iter::{IntoParallelIterator, ParallelIterator};
24
25use crate::expressions::evaluate_count_on_ac;
26use crate::prelude::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
27use crate::state::ExecutionState;
28
29pub fn reverse<'a>(
30    inputs: &[Arc<dyn PhysicalExpr>],
31    df: &DataFrame,
32    groups: &'a GroupPositions,
33    state: &ExecutionState,
34) -> PolarsResult<AggregationContext<'a>> {
35    assert_eq!(inputs.len(), 1);
36
37    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
38
39    // Length preserving operation on scalars keeps scalar.
40    if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &ac.agg_state() {
41        return Ok(ac);
42    }
43
44    RAYON.install(|| {
45        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
46            GroupsType::Idx(idx) => idx
47                .into_par_iter()
48                .map(|(first, idx)| {
49                    (
50                        idx.last().copied().unwrap_or(first),
51                        idx.iter().copied().rev().collect(),
52                    )
53                })
54                .collect(),
55            GroupsType::Slice {
56                groups,
57                overlapping: _,
58                monotonic: _,
59            } => groups
60                .into_par_iter()
61                .map(|[start, len]| {
62                    (
63                        start + len.saturating_sub(1),
64                        (*start..*start + *len).rev().collect(),
65                    )
66                })
67                .collect(),
68        })
69        .into_sliceable();
70        ac.with_groups(positions);
71    });
72
73    Ok(ac)
74}
75
76pub fn null_count<'a>(
77    inputs: &[Arc<dyn PhysicalExpr>],
78    df: &DataFrame,
79    groups: &'a GroupPositions,
80    state: &ExecutionState,
81) -> PolarsResult<AggregationContext<'a>> {
82    assert_eq!(inputs.len(), 1);
83
84    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
85
86    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
87        *s = s.is_null().cast(&IDX_DTYPE).unwrap().into_column();
88        return Ok(ac);
89    }
90
91    ac.groups();
92    let values = ac.flat_naive();
93    let name = values.name().clone();
94    let Some(validity) = values.rechunk_validity() else {
95        ac.state = AggState::AggregatedScalar(Column::new_scalar(
96            name,
97            (0 as IdxSize).into(),
98            groups.len(),
99        ));
100        return Ok(ac);
101    };
102
103    RAYON.install(|| {
104        let validity = BitMask::from_bitmap(&validity);
105        let null_count: Vec<IdxSize> = match &**ac.groups.as_ref() {
106            GroupsType::Idx(idx) => idx
107                .into_par_iter()
108                .map(|(_, idx)| {
109                    idx.iter()
110                        .map(|i| IdxSize::from(!unsafe { validity.get_bit_unchecked(*i as usize) }))
111                        .sum::<IdxSize>()
112                })
113                .collect(),
114            GroupsType::Slice {
115                groups,
116                overlapping: _,
117                monotonic: _,
118            } => groups
119                .into_par_iter()
120                .map(|[start, length]| {
121                    unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
122                        .unset_bits() as IdxSize
123                })
124                .collect(),
125        };
126
127        ac.state = AggState::AggregatedScalar(Column::new(name, null_count));
128    });
129
130    Ok(ac)
131}
132
133pub fn has_nulls<'a>(
134    inputs: &[Arc<dyn PhysicalExpr>],
135    df: &DataFrame,
136    groups: &'a GroupPositions,
137    state: &ExecutionState,
138) -> PolarsResult<AggregationContext<'a>> {
139    assert_eq!(inputs.len(), 1);
140
141    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
142
143    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
144        *s = s.is_null().into_column();
145        return Ok(ac);
146    }
147
148    ac.groups();
149    let values = ac.flat_naive();
150    let name = values.name().clone();
151    let Some(validity) = values.rechunk_validity() else {
152        ac.state = AggState::AggregatedScalar(Column::new_scalar(name, false.into(), groups.len()));
153        return Ok(ac);
154    };
155
156    RAYON.install(|| {
157        let validity = BitMask::from_bitmap(&validity);
158        let has_nulls: BooleanChunked = match &**ac.groups.as_ref() {
159            GroupsType::Idx(idx) => idx
160                .into_par_iter()
161                .map(|(_, idx)| {
162                    idx.iter()
163                        .any(|i| !unsafe { validity.get_bit_unchecked(*i as usize) })
164                })
165                .collect(),
166            GroupsType::Slice {
167                groups,
168                overlapping: _,
169                monotonic: _,
170            } => groups
171                .into_par_iter()
172                .map(|[start, length]| {
173                    unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
174                        .unset_bits()
175                        > 0
176                })
177                .collect(),
178        };
179
180        ac.state = AggState::AggregatedScalar(has_nulls.with_name(name).into_column());
181    });
182
183    Ok(ac)
184}
185
186pub fn any<'a>(
187    inputs: &[Arc<dyn PhysicalExpr>],
188    df: &DataFrame,
189    groups: &'a GroupPositions,
190    state: &ExecutionState,
191    ignore_nulls: bool,
192) -> PolarsResult<AggregationContext<'a>> {
193    assert_eq!(inputs.len(), 1);
194
195    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
196
197    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
198        if ignore_nulls {
199            *s = s
200                .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
201                .unwrap()
202                .into_column();
203        } else {
204            *s = s
205                .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
206                .unwrap()
207                .into_column();
208        }
209        return Ok(ac);
210    }
211
212    ac.groups();
213    let values = ac.flat_naive();
214    let values = values.bool()?;
215    let out = unsafe { values.agg_any(ac.groups.as_ref(), ignore_nulls) };
216    ac.state = AggState::AggregatedScalar(out.into_column());
217
218    Ok(ac)
219}
220
221pub fn all<'a>(
222    inputs: &[Arc<dyn PhysicalExpr>],
223    df: &DataFrame,
224    groups: &'a GroupPositions,
225    state: &ExecutionState,
226    ignore_nulls: bool,
227) -> PolarsResult<AggregationContext<'a>> {
228    assert_eq!(inputs.len(), 1);
229
230    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
231
232    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
233        if ignore_nulls {
234            *s = s
235                .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
236                .unwrap()
237                .into_column();
238        } else {
239            *s = s
240                .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
241                .unwrap()
242                .into_column();
243        }
244        return Ok(ac);
245    }
246
247    ac.groups();
248    let values = ac.flat_naive();
249    let values = values.bool()?;
250    let out = unsafe { values.agg_all(ac.groups.as_ref(), ignore_nulls) };
251    ac.state = AggState::AggregatedScalar(out.into_column());
252
253    Ok(ac)
254}
255
256pub fn is_empty<'a>(
257    inputs: &[Arc<dyn PhysicalExpr>],
258    df: &DataFrame,
259    groups: &'a GroupPositions,
260    state: &ExecutionState,
261    ignore_nulls: bool,
262) -> PolarsResult<AggregationContext<'a>> {
263    assert_eq!(inputs.len(), 1);
264
265    // TODO: dedicated impl.
266    let ac = inputs[0].evaluate_on_groups(df, groups, state)?;
267    let counts = evaluate_count_on_ac(ac, !ignore_nulls)?;
268    let is_empty = counts.equal(&Column::new_scalar(
269        PlSmallStr::EMPTY,
270        Scalar::new_idxsize(0),
271        1,
272    ))?;
273    Ok(AggregationContext::from_agg_state(
274        AggState::AggregatedScalar(is_empty.into_column()),
275        Cow::Borrowed(groups),
276    ))
277}
278
/// Shared driver for the per-group bitwise reductions.
///
/// * Scalar states are returned unchanged after checking that their dtype is boolean
///   or a primitive numeric; any other dtype yields an error that mentions `op`.
/// * Any other state is reduced by calling `f` with the flat values and the current
///   groups, and the result is stored as an `AggregatedScalar`.
///
/// # Panics
/// Panics if `inputs` does not hold exactly one expression.
#[cfg(feature = "bitwise")]
pub fn bitwise_agg<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    op: &'static str,
    f: impl Fn(&Column, &GroupsType) -> Column,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);

    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &ac.state {
        let dtype = s.dtype();
        let supported = dtype.is_bool() || dtype.is_primitive_numeric();
        polars_ensure!(supported, op = op, dtype);
        return Ok(ac);
    }

    // NOTE(review): presumably brings the group positions up to date before they are
    // read through the `ac.groups` field below — confirm against `AggregationContext`.
    ac.groups();
    let values = ac.flat_naive();
    let reduced = f(values.as_ref(), ac.groups.as_ref());
    ac.state = AggState::AggregatedScalar(reduced.into_column());

    Ok(ac)
}
309
/// Per-group bitwise AND: forwards to `bitwise_agg` with the operation name
/// `"and_reduce"` and `Column::agg_and` as the reducer.
#[cfg(feature = "bitwise")]
pub fn bitwise_and<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // NOTE(review): `agg_and` is unsafe; assumed to require positions that are in
    // bounds for the column it is called on — confirm.
    let reduce = |column: &Column, positions: &GroupsType| unsafe { column.agg_and(positions) };
    bitwise_agg(inputs, df, groups, state, "and_reduce", reduce)
}
326
/// Per-group bitwise OR: forwards to `bitwise_agg` with the operation name
/// `"or_reduce"` and `Column::agg_or` as the reducer.
#[cfg(feature = "bitwise")]
pub fn bitwise_or<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // NOTE(review): `agg_or` is unsafe; assumed to require positions that are in
    // bounds for the column it is called on — confirm.
    let reduce = |column: &Column, positions: &GroupsType| unsafe { column.agg_or(positions) };
    bitwise_agg(inputs, df, groups, state, "or_reduce", reduce)
}
338
/// Per-group bitwise XOR: forwards to `bitwise_agg` with the operation name
/// `"xor_reduce"` and `Column::agg_xor` as the reducer.
#[cfg(feature = "bitwise")]
pub fn bitwise_xor<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // NOTE(review): `agg_xor` is unsafe; assumed to require positions that are in
    // bounds for the column it is called on — confirm.
    let reduce = |column: &Column, positions: &GroupsType| unsafe { column.agg_xor(positions) };
    bitwise_agg(inputs, df, groups, state, "xor_reduce", reduce)
}
355
356pub fn drop_items<'a>(
357    mut ac: AggregationContext<'a>,
358    predicate: &Bitmap,
359) -> PolarsResult<AggregationContext<'a>> {
360    // No elements are filtered out.
361    if predicate.unset_bits() == 0 {
362        if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
363            *c = c.as_list().into_column();
364            if c.len() == 1 && ac.groups.len() != 1 {
365                *c = c.new_from_index(0, ac.groups.len());
366            }
367            ac.state = AggState::AggregatedList(std::mem::take(c));
368            ac.update_groups = UpdateGroups::WithSeriesLen;
369        }
370        return Ok(ac);
371    }
372
373    ac.set_original_len(false);
374
375    // All elements are filtered out.
376    if predicate.set_bits() == 0 {
377        let name = ac.agg_state().name();
378        let dtype = ac.agg_state().flat_dtype();
379
380        ac.state = AggState::AggregatedList(Column::new_scalar(
381            name.clone(),
382            Scalar::new(
383                dtype.clone().implode(),
384                AnyValue::List(Series::new_empty(PlSmallStr::EMPTY, dtype)),
385            ),
386            ac.groups.len(),
387        ));
388        ac.with_update_groups(UpdateGroups::WithSeriesLen);
389        return Ok(ac);
390    }
391
392    if let AggState::AggregatedScalar(c) = &mut ac.state {
393        ac.state = AggState::NotAggregated(std::mem::take(c));
394        ac.groups = Cow::Owned(
395            {
396                let groups = predicate
397                    .iter()
398                    .enumerate_idx()
399                    .map(|(i, p)| [i, IdxSize::from(p)])
400                    .collect();
401                GroupsType::new_slice(groups, false, true)
402            }
403            .into_sliceable(),
404        );
405        ac.update_groups = UpdateGroups::No;
406        return Ok(ac);
407    }
408
409    ac.groups();
410    let predicate = BitMask::from_bitmap(predicate);
411    RAYON.install(|| {
412        let positions = GroupsType::Idx(match &**ac.groups.as_ref() {
413            GroupsType::Idx(idxs) => idxs
414                .into_par_iter()
415                .map(|(fst, idxs)| {
416                    let out = idxs
417                        .iter()
418                        .copied()
419                        .filter(|i| unsafe { predicate.get_bit_unchecked(*i as usize) })
420                        .collect::<UnitVec<IdxSize>>();
421                    (out.first().copied().unwrap_or(fst), out)
422                })
423                .collect(),
424            GroupsType::Slice {
425                groups,
426                overlapping: _,
427                monotonic: _,
428            } => groups
429                .into_par_iter()
430                .map(|[start, length]| {
431                    let predicate =
432                        unsafe { predicate.sliced_unchecked(*start as usize, *length as usize) };
433                    let num_values = predicate.set_bits();
434
435                    if num_values == 0 {
436                        (*start, UnitVec::new())
437                    } else if num_values == 1 {
438                        let item = *start + predicate.leading_zeros() as IdxSize;
439                        let mut out = UnitVec::with_capacity(1);
440                        out.push(item);
441                        (item, out)
442                    } else if num_values == *length as usize {
443                        (*start, (*start..*start + *length).collect())
444                    } else {
445                        let out = unsafe {
446                            TrustMyLength::new(
447                                (0..*length)
448                                    .filter(|i| predicate.get_bit_unchecked(*i as usize))
449                                    .map(|i| i + *start),
450                                num_values,
451                            )
452                        };
453                        let out = out.collect::<UnitVec<IdxSize>>();
454
455                        (out.first().copied().unwrap(), out)
456                    }
457                })
458                .collect(),
459        })
460        .into_sliceable();
461        ac.with_groups(positions);
462    });
463
464    Ok(ac)
465}
466
467pub fn drop_nans<'a>(
468    inputs: &[Arc<dyn PhysicalExpr>],
469    df: &DataFrame,
470    groups: &'a GroupPositions,
471    state: &ExecutionState,
472) -> PolarsResult<AggregationContext<'a>> {
473    assert_eq!(inputs.len(), 1);
474    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
475    ac.groups();
476    let predicate = if ac.agg_state().flat_dtype().is_float() {
477        let values = ac.flat_naive();
478        let mut values = values.is_nan().unwrap();
479        values.rechunk_mut();
480        values.downcast_as_array().values().clone()
481    } else {
482        Bitmap::new_with_value(false, 1)
483    };
484    let predicate = !&predicate;
485    drop_items(ac, &predicate)
486}
487
488pub fn drop_nulls<'a>(
489    inputs: &[Arc<dyn PhysicalExpr>],
490    df: &DataFrame,
491    groups: &'a GroupPositions,
492    state: &ExecutionState,
493) -> PolarsResult<AggregationContext<'a>> {
494    assert_eq!(inputs.len(), 1);
495    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
496    ac.groups();
497    let predicate = ac.flat_naive().as_ref().clone();
498    let predicate = predicate.rechunk_to_arrow(CompatLevel::newest());
499    let predicate = predicate
500        .validity()
501        .cloned()
502        .unwrap_or(Bitmap::new_with_value(true, 1));
503    drop_items(ac, &predicate)
504}
505
506pub fn quantile<'a>(
507    inputs: &[Arc<dyn PhysicalExpr>],
508    df: &DataFrame,
509    groups: &'a GroupPositions,
510    state: &ExecutionState,
511    method: QuantileMethod,
512) -> PolarsResult<AggregationContext<'a>> {
513    assert!(inputs.len() == 2);
514
515    // AggregatedScalar has no defined group structure. We fix it up here, so that we can
516    // reliably call `agg_quantile` functions with the groups.
517    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
518    ac.set_groups_for_undefined_agg_states();
519
520    // Don't change names by aggregations as is done in polars-core.
521    let keep_name = ac.get_values().name().clone();
522
523    let quantile_column = inputs[1].evaluate(df, state)?;
524    polars_ensure!(
525        quantile_column.len() <= 1,
526        ComputeError:
527            "polars only supports computing a single quantile in a groupby aggregation context"
528    );
529    polars_ensure!(
530        quantile_column.dtype().is_numeric(),
531        SchemaMismatch:
532            "expected expression of dtype 'numeric' for quantile, got '{}'",
533        quantile_column.dtype()
534    );
535    let quantile: f64 = quantile_column.get(0).unwrap().try_extract()?;
536
537    if let AggState::LiteralScalar(c) = &mut ac.state {
538        *c = c.quantile_reduce(quantile, method)?.into_column(keep_name);
539        return Ok(ac);
540    }
541
542    // SAFETY: groups are in bounds.
543    let mut agg = unsafe {
544        ac.flat_naive()
545            .into_owned()
546            .agg_quantile(ac.groups(), quantile, method)
547    };
548    agg.rename(keep_name);
549    Ok(AggregationContext::from_agg_state(
550        AggState::AggregatedScalar(agg),
551        Cow::Borrowed(groups),
552    ))
553}
554
/// Shared driver for the per-group moment statistics over `f64` values.
///
/// The accumulator type `S` is supplied by the caller together with three callbacks:
///
/// * `insert_one` adds a single value to an accumulator,
/// * `new_from_slice` builds an accumulator from `(array, start, length)`,
/// * `finalize` turns an accumulator into the optional result.
///
/// Scalar states are mapped element-wise (each non-null value is fed into a fresh
/// accumulator). For index groups the values are inserted one by one, skipping rows
/// whose validity bit is unset; for slice groups `new_from_slice` is used.
///
/// # Errors
/// Returns an error if the input cannot be evaluated or its values are not `f64`.
///
/// # Panics
/// Panics if `inputs` does not hold exactly one expression.
#[cfg(feature = "moment")]
pub fn moment_agg<'a, S: Default>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,

    insert_one: impl Fn(&mut S, f64) + Send + Sync,
    new_from_slice: impl Fn(&PrimitiveArray<f64>, usize, usize) -> S + Send + Sync,
    finalize: impl Fn(S) -> Option<f64> + Send + Sync,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);
    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
        let ca = s.f64()?;
        // Each scalar is a group of one: feed it into a fresh accumulator.
        let of_single = |value: f64| {
            let mut acc = S::default();
            insert_one(&mut acc, value);
            finalize(acc)
        };
        let mapped = ca
            .iter()
            .map(|opt| opt.and_then(&of_single))
            .collect::<Float64Chunked>()
            .with_name(ca.name().clone())
            .into_column();
        *s = mapped;
        return Ok(ac);
    }

    ac.groups();

    let name = ac.get_values().name().clone();
    let flat = ac.flat_naive();
    let floats = flat.f64()?;
    let floats = floats.rechunk();
    let arr = floats.downcast_as_array();

    let out = RAYON.install(|| match &**ac.groups.as_ref() {
        GroupsType::Idx(idx_groups) => {
            // Only consult the validity mask when it actually marks a null.
            let validity = arr.validity().filter(|v| v.unset_bits() > 0);
            idx_groups
                .into_par_iter()
                .map(|(_, members)| {
                    let mut acc = S::default();
                    if let Some(validity) = validity {
                        for &i in members.iter() {
                            // NOTE(review): relies on every group index being within
                            // the validity mask — confirm.
                            if unsafe { validity.get_bit_unchecked(i as usize) } {
                                insert_one(&mut acc, arr.values()[i as usize]);
                            }
                        }
                    } else {
                        for &i in members.iter() {
                            insert_one(&mut acc, arr.values()[i as usize]);
                        }
                    }
                    finalize(acc)
                })
                .collect::<Float64Chunked>()
        },
        GroupsType::Slice { groups: slices, .. } => slices
            .into_par_iter()
            .map(|&[start, length]| {
                let acc = new_from_slice(arr, start as usize, length as usize);
                finalize(acc)
            })
            .collect::<Float64Chunked>(),
    });

    ac.state = AggState::AggregatedScalar(out.with_name(name).into_column());
    Ok(ac)
}
633
/// Per-group skewness: drives `moment_agg` with `SkewState`, finalizing every
/// accumulator with the given `bias` flag.
#[cfg(feature = "moment")]
pub fn skew<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::SkewState;

    let finish = |acc: SkewState| acc.finalize(bias);
    moment_agg::<SkewState>(
        inputs,
        df,
        groups,
        state,
        SkewState::insert_one,
        SkewState::from_array,
        finish,
    )
}
653
/// Per-group kurtosis: drives `moment_agg` with `KurtosisState`, finalizing every
/// accumulator with the given `fisher` and `bias` flags.
#[cfg(feature = "moment")]
pub fn kurtosis<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    fisher: bool,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::KurtosisState;

    let finish = |acc: KurtosisState| acc.finalize(fisher, bias);
    moment_agg::<KurtosisState>(
        inputs,
        df,
        groups,
        state,
        KurtosisState::insert_one,
        KurtosisState::from_array,
        finish,
    )
}
674
675pub fn unique<'a>(
676    inputs: &[Arc<dyn PhysicalExpr>],
677    df: &DataFrame,
678    groups: &'a GroupPositions,
679    state: &ExecutionState,
680    stable: bool,
681) -> PolarsResult<AggregationContext<'a>> {
682    _ = stable;
683
684    assert_eq!(inputs.len(), 1);
685    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
686    ac.groups();
687
688    if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
689        *c = c.as_list().into_column();
690        if c.len() == 1 && ac.groups.len() != 1 {
691            *c = c.new_from_index(0, ac.groups.len());
692        }
693        ac.state = AggState::AggregatedList(std::mem::take(c));
694        ac.update_groups = UpdateGroups::WithSeriesLen;
695        return Ok(ac);
696    }
697
698    let values = ac.flat_naive().to_physical_repr();
699    let dtype = values.dtype();
700    let values = if dtype.contains_objects() {
701        polars_bail!(opq = unique, dtype);
702    } else if let Some(ca) = values.try_str() {
703        ca.as_binary().into_column()
704    } else if dtype.is_nested() {
705        encode_rows_unordered(&[values])?.into_column()
706    } else {
707        values
708    };
709
710    let values = values.rechunk_to_arrow(CompatLevel::newest());
711    let values = values.as_ref();
712    let state = amortized_unique_from_dtype(values.dtype());
713
714    struct CloneWrapper(Box<dyn AmortizedUnique>);
715    impl Clone for CloneWrapper {
716        fn clone(&self) -> Self {
717            Self(self.0.new_empty())
718        }
719    }
720
721    RAYON.install(|| {
722        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
723            GroupsType::Idx(idx) => idx
724                .into_par_iter()
725                .map_with(CloneWrapper(state), |state, (first, idx)| {
726                    let mut idx = idx.clone();
727                    unsafe { state.0.retain_unique(values, &mut idx) };
728                    (idx.first().copied().unwrap_or(first), idx)
729                })
730                .collect(),
731            GroupsType::Slice {
732                groups,
733                overlapping: _,
734                monotonic: _,
735            } => groups
736                .into_par_iter()
737                .map_with(CloneWrapper(state), |state, [start, len]| {
738                    let mut idx = UnitVec::new();
739                    state.0.arg_unique(values, &mut idx, *start, *len);
740                    (idx.first().copied().unwrap_or(*start), idx)
741                })
742                .collect(),
743        })
744        .into_sliceable();
745        ac.with_groups(positions);
746    });
747
748    Ok(ac)
749}
750
751fn fw_bw_fill_null<'a>(
752    inputs: &[Arc<dyn PhysicalExpr>],
753    df: &DataFrame,
754    groups: &'a GroupPositions,
755    state: &ExecutionState,
756    f_idx: impl Fn(
757        std::iter::Copied<std::slice::Iter<'_, IdxSize>>,
758        BitMask<'_>,
759        usize,
760    ) -> UnitVec<IdxSize>
761    + Send
762    + Sync,
763    f_range: impl Fn(std::ops::Range<IdxSize>, BitMask<'_>, usize) -> UnitVec<IdxSize> + Send + Sync,
764) -> PolarsResult<AggregationContext<'a>> {
765    assert_eq!(inputs.len(), 1);
766    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
767    ac.groups();
768
769    if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &mut ac.state {
770        return Ok(ac);
771    }
772
773    let values = ac.flat_naive();
774    let Some(validity) = values.rechunk_validity() else {
775        return Ok(ac);
776    };
777
778    let validity = BitMask::from_bitmap(&validity);
779    RAYON.install(|| {
780        let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
781            GroupsType::Idx(idx) => idx
782                .into_par_iter()
783                .map(|(first, idx)| {
784                    let idx = f_idx(idx.iter().copied(), validity, idx.len());
785                    (idx.first().copied().unwrap_or(first), idx)
786                })
787                .collect(),
788            GroupsType::Slice {
789                groups,
790                overlapping: _,
791                monotonic: _,
792            } => groups
793                .into_par_iter()
794                .map(|[start, len]| {
795                    let idx = f_range(*start..*start + *len, validity, *len as usize);
796                    (idx.first().copied().unwrap_or(*start), idx)
797                })
798                .collect(),
799        })
800        .into_sliceable();
801        ac.with_groups(positions);
802    });
803
804    Ok(ac)
805}
806
807pub fn forward_fill_null<'a>(
808    inputs: &[Arc<dyn PhysicalExpr>],
809    df: &DataFrame,
810    groups: &'a GroupPositions,
811    state: &ExecutionState,
812    limit: Option<IdxSize>,
813) -> PolarsResult<AggregationContext<'a>> {
814    let limit = limit.unwrap_or(IdxSize::MAX);
815    macro_rules! arg_forward_fill {
816        (
817            $iter:ident,
818            $validity:ident,
819            $length:ident
820        ) => {{
821            |$iter, $validity, $length| {
822                let Some(start) = $iter
823                    .clone()
824                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
825                else {
826                    return $iter.collect();
827                };
828
829                let mut idx = UnitVec::with_capacity($length);
830                let mut iter = $iter;
831                idx.extend((&mut iter).take(start));
832
833                let mut current_limit = limit;
834                let mut value = iter.next().unwrap();
835                idx.push(value);
836
837                idx.extend(iter.map(|i| {
838                    if unsafe { $validity.get_bit_unchecked(i as usize) } {
839                        current_limit = limit;
840                        value = i;
841                        i
842                    } else if current_limit == 0 {
843                        i
844                    } else {
845                        current_limit -= 1;
846                        value
847                    }
848                }));
849                idx
850            }
851        }};
852    }
853
854    fw_bw_fill_null(
855        inputs,
856        df,
857        groups,
858        state,
859        arg_forward_fill!(iter, validity, length),
860        arg_forward_fill!(iter, validity, length),
861    )
862}
863
864pub fn backward_fill_null<'a>(
865    inputs: &[Arc<dyn PhysicalExpr>],
866    df: &DataFrame,
867    groups: &'a GroupPositions,
868    state: &ExecutionState,
869    limit: Option<IdxSize>,
870) -> PolarsResult<AggregationContext<'a>> {
871    let limit = limit.unwrap_or(IdxSize::MAX);
872    macro_rules! arg_backward_fill {
873        (
874            $iter:ident,
875            $validity:ident,
876            $length:ident
877        ) => {{
878            |$iter, $validity, $length| {
879                let Some(start) = $iter
880                    .clone()
881                    .rev()
882                    .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
883                else {
884                    return $iter.collect();
885                };
886
887                let mut idx = UnitVec::from_iter($iter);
888                let mut current_limit = limit;
889                let mut value = idx[$length - start - 1];
890                for i in idx[..$length - start].iter_mut().rev() {
891                    if unsafe { $validity.get_bit_unchecked(*i as usize) } {
892                        current_limit = limit;
893                        value = *i;
894                    } else if current_limit != 0 {
895                        current_limit -= 1;
896                        *i = value;
897                    }
898                }
899
900                idx
901            }
902        }};
903    }
904
905    fw_bw_fill_null(
906        inputs,
907        df,
908        groups,
909        state,
910        arg_backward_fill!(iter, validity, length),
911        arg_backward_fill!(iter, validity, length),
912    )
913}