1use std::borrow::Cow;
2use std::sync::Arc;
3
4use arrow::array::PrimitiveArray;
5use arrow::bitmap::Bitmap;
6use arrow::bitmap::bitmask::BitMask;
7use arrow::trusted_len::TrustMyLength;
8use polars_compute::rolling::QuantileMethod;
9use polars_compute::unique::{AmortizedUnique, amortized_unique_from_dtype};
10use polars_core::error::{PolarsResult, polars_bail, polars_ensure};
11use polars_core::frame::DataFrame;
12use polars_core::prelude::row_encode::encode_rows_unordered;
13use polars_core::prelude::{
14 AnyValue, BooleanChunked, ChunkCast, Column, CompatLevel, Float64Chunked, GroupPositions,
15 GroupsType, IDX_DTYPE, IntoColumn,
16};
17use polars_core::runtime::RAYON;
18use polars_core::scalar::Scalar;
19use polars_core::series::{ChunkCompareEq, Series};
20use polars_utils::itertools::Itertools;
21use polars_utils::pl_str::PlSmallStr;
22use polars_utils::{IdxSize, UnitVec};
23use rayon::iter::{IntoParallelIterator, ParallelIterator};
24
25use crate::expressions::evaluate_count_on_ac;
26use crate::prelude::{AggState, AggregationContext, PhysicalExpr, UpdateGroups};
27use crate::state::ExecutionState;
28
29pub fn reverse<'a>(
30 inputs: &[Arc<dyn PhysicalExpr>],
31 df: &DataFrame,
32 groups: &'a GroupPositions,
33 state: &ExecutionState,
34) -> PolarsResult<AggregationContext<'a>> {
35 assert_eq!(inputs.len(), 1);
36
37 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
38
39 if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &ac.agg_state() {
41 return Ok(ac);
42 }
43
44 RAYON.install(|| {
45 let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
46 GroupsType::Idx(idx) => idx
47 .into_par_iter()
48 .map(|(first, idx)| {
49 (
50 idx.last().copied().unwrap_or(first),
51 idx.iter().copied().rev().collect(),
52 )
53 })
54 .collect(),
55 GroupsType::Slice {
56 groups,
57 overlapping: _,
58 monotonic: _,
59 } => groups
60 .into_par_iter()
61 .map(|[start, len]| {
62 (
63 start + len.saturating_sub(1),
64 (*start..*start + *len).rev().collect(),
65 )
66 })
67 .collect(),
68 })
69 .into_sliceable();
70 ac.with_groups(positions);
71 });
72
73 Ok(ac)
74}
75
76pub fn null_count<'a>(
77 inputs: &[Arc<dyn PhysicalExpr>],
78 df: &DataFrame,
79 groups: &'a GroupPositions,
80 state: &ExecutionState,
81) -> PolarsResult<AggregationContext<'a>> {
82 assert_eq!(inputs.len(), 1);
83
84 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
85
86 if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
87 *s = s.is_null().cast(&IDX_DTYPE).unwrap().into_column();
88 return Ok(ac);
89 }
90
91 ac.groups();
92 let values = ac.flat_naive();
93 let name = values.name().clone();
94 let Some(validity) = values.rechunk_validity() else {
95 ac.state = AggState::AggregatedScalar(Column::new_scalar(
96 name,
97 (0 as IdxSize).into(),
98 groups.len(),
99 ));
100 return Ok(ac);
101 };
102
103 RAYON.install(|| {
104 let validity = BitMask::from_bitmap(&validity);
105 let null_count: Vec<IdxSize> = match &**ac.groups.as_ref() {
106 GroupsType::Idx(idx) => idx
107 .into_par_iter()
108 .map(|(_, idx)| {
109 idx.iter()
110 .map(|i| IdxSize::from(!unsafe { validity.get_bit_unchecked(*i as usize) }))
111 .sum::<IdxSize>()
112 })
113 .collect(),
114 GroupsType::Slice {
115 groups,
116 overlapping: _,
117 monotonic: _,
118 } => groups
119 .into_par_iter()
120 .map(|[start, length]| {
121 unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
122 .unset_bits() as IdxSize
123 })
124 .collect(),
125 };
126
127 ac.state = AggState::AggregatedScalar(Column::new(name, null_count));
128 });
129
130 Ok(ac)
131}
132
133pub fn has_nulls<'a>(
134 inputs: &[Arc<dyn PhysicalExpr>],
135 df: &DataFrame,
136 groups: &'a GroupPositions,
137 state: &ExecutionState,
138) -> PolarsResult<AggregationContext<'a>> {
139 assert_eq!(inputs.len(), 1);
140
141 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
142
143 if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
144 *s = s.is_null().into_column();
145 return Ok(ac);
146 }
147
148 ac.groups();
149 let values = ac.flat_naive();
150 let name = values.name().clone();
151 let Some(validity) = values.rechunk_validity() else {
152 ac.state = AggState::AggregatedScalar(Column::new_scalar(name, false.into(), groups.len()));
153 return Ok(ac);
154 };
155
156 RAYON.install(|| {
157 let validity = BitMask::from_bitmap(&validity);
158 let has_nulls: BooleanChunked = match &**ac.groups.as_ref() {
159 GroupsType::Idx(idx) => idx
160 .into_par_iter()
161 .map(|(_, idx)| {
162 idx.iter()
163 .any(|i| !unsafe { validity.get_bit_unchecked(*i as usize) })
164 })
165 .collect(),
166 GroupsType::Slice {
167 groups,
168 overlapping: _,
169 monotonic: _,
170 } => groups
171 .into_par_iter()
172 .map(|[start, length]| {
173 unsafe { validity.sliced_unchecked(*start as usize, *length as usize) }
174 .unset_bits()
175 > 0
176 })
177 .collect(),
178 };
179
180 ac.state = AggState::AggregatedScalar(has_nulls.with_name(name).into_column());
181 });
182
183 Ok(ac)
184}
185
186pub fn any<'a>(
187 inputs: &[Arc<dyn PhysicalExpr>],
188 df: &DataFrame,
189 groups: &'a GroupPositions,
190 state: &ExecutionState,
191 ignore_nulls: bool,
192) -> PolarsResult<AggregationContext<'a>> {
193 assert_eq!(inputs.len(), 1);
194
195 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
196
197 if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
198 if ignore_nulls {
199 *s = s
200 .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
201 .unwrap()
202 .into_column();
203 } else {
204 *s = s
205 .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
206 .unwrap()
207 .into_column();
208 }
209 return Ok(ac);
210 }
211
212 ac.groups();
213 let values = ac.flat_naive();
214 let values = values.bool()?;
215 let out = unsafe { values.agg_any(ac.groups.as_ref(), ignore_nulls) };
216 ac.state = AggState::AggregatedScalar(out.into_column());
217
218 Ok(ac)
219}
220
221pub fn all<'a>(
222 inputs: &[Arc<dyn PhysicalExpr>],
223 df: &DataFrame,
224 groups: &'a GroupPositions,
225 state: &ExecutionState,
226 ignore_nulls: bool,
227) -> PolarsResult<AggregationContext<'a>> {
228 assert_eq!(inputs.len(), 1);
229
230 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
231
232 if let AggState::AggregatedScalar(s) | AggState::LiteralScalar(s) = &mut ac.state {
233 if ignore_nulls {
234 *s = s
235 .equal_missing(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
236 .unwrap()
237 .into_column();
238 } else {
239 *s = s
240 .equal(&Column::new_scalar(PlSmallStr::EMPTY, true.into(), 1))
241 .unwrap()
242 .into_column();
243 }
244 return Ok(ac);
245 }
246
247 ac.groups();
248 let values = ac.flat_naive();
249 let values = values.bool()?;
250 let out = unsafe { values.agg_all(ac.groups.as_ref(), ignore_nulls) };
251 ac.state = AggState::AggregatedScalar(out.into_column());
252
253 Ok(ac)
254}
255
256pub fn is_empty<'a>(
257 inputs: &[Arc<dyn PhysicalExpr>],
258 df: &DataFrame,
259 groups: &'a GroupPositions,
260 state: &ExecutionState,
261 ignore_nulls: bool,
262) -> PolarsResult<AggregationContext<'a>> {
263 assert_eq!(inputs.len(), 1);
264
265 let ac = inputs[0].evaluate_on_groups(df, groups, state)?;
267 let counts = evaluate_count_on_ac(ac, !ignore_nulls)?;
268 let is_empty = counts.equal(&Column::new_scalar(
269 PlSmallStr::EMPTY,
270 Scalar::new_idxsize(0),
271 1,
272 ))?;
273 Ok(AggregationContext::from_agg_state(
274 AggState::AggregatedScalar(is_empty.into_column()),
275 Cow::Borrowed(groups),
276 ))
277}
278
/// Shared driver for the bitwise group reductions.
///
/// Evaluates the single input expression on `groups` and applies `f` to the flat
/// values together with the groups. Scalar states are returned unchanged after
/// checking that their dtype is boolean or a primitive numeric type; `op` is
/// only used to build the error for an unsupported dtype.
///
/// # Errors
///
/// Returns an error if evaluating the input fails or a scalar state has an
/// unsupported dtype.
///
/// # Panics
///
/// Panics if `inputs` does not contain exactly one expression.
#[cfg(feature = "bitwise")]
pub fn bitwise_agg<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    op: &'static str,
    f: impl Fn(&Column, &GroupsType) -> Column,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);

    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(col) | AggState::LiteralScalar(col) = &ac.state {
        let dtype = col.dtype();
        let is_supported = dtype.is_bool() || dtype.is_primitive_numeric();
        polars_ensure!(is_supported, op = op, dtype);
        return Ok(ac);
    }

    ac.groups();
    let reduced = {
        let flat = ac.flat_naive();
        f(flat.as_ref(), ac.groups.as_ref())
    };
    ac.state = AggState::AggregatedScalar(reduced.into_column());

    Ok(ac)
}
309
/// Group-wise bitwise AND reduction; delegates to [`bitwise_agg`] with `agg_and`.
#[cfg(feature = "bitwise")]
pub fn bitwise_and<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // NOTE(review): `agg_and` is unsafe; presumably it requires the positions to
    // be in bounds for the values — confirm against its safety contract.
    let and_reduce =
        |values: &Column, positions: &GroupsType| unsafe { values.agg_and(positions) };

    bitwise_agg(inputs, df, groups, state, "and_reduce", and_reduce)
}
326
/// Group-wise bitwise OR reduction; delegates to [`bitwise_agg`] with `agg_or`.
#[cfg(feature = "bitwise")]
pub fn bitwise_or<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // NOTE(review): `agg_or` is unsafe; presumably it requires the positions to
    // be in bounds for the values — confirm against its safety contract.
    let or_reduce =
        |values: &Column, positions: &GroupsType| unsafe { values.agg_or(positions) };

    bitwise_agg(inputs, df, groups, state, "or_reduce", or_reduce)
}
338
/// Group-wise bitwise XOR reduction; delegates to [`bitwise_agg`] with `agg_xor`.
#[cfg(feature = "bitwise")]
pub fn bitwise_xor<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    // NOTE(review): `agg_xor` is unsafe; presumably it requires the positions to
    // be in bounds for the values — confirm against its safety contract.
    let xor_reduce =
        |values: &Column, positions: &GroupsType| unsafe { values.agg_xor(positions) };

    bitwise_agg(inputs, df, groups, state, "xor_reduce", xor_reduce)
}
355
356pub fn drop_items<'a>(
357 mut ac: AggregationContext<'a>,
358 predicate: &Bitmap,
359) -> PolarsResult<AggregationContext<'a>> {
360 if predicate.unset_bits() == 0 {
362 if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
363 *c = c.as_list().into_column();
364 if c.len() == 1 && ac.groups.len() != 1 {
365 *c = c.new_from_index(0, ac.groups.len());
366 }
367 ac.state = AggState::AggregatedList(std::mem::take(c));
368 ac.update_groups = UpdateGroups::WithSeriesLen;
369 }
370 return Ok(ac);
371 }
372
373 ac.set_original_len(false);
374
375 if predicate.set_bits() == 0 {
377 let name = ac.agg_state().name();
378 let dtype = ac.agg_state().flat_dtype();
379
380 ac.state = AggState::AggregatedList(Column::new_scalar(
381 name.clone(),
382 Scalar::new(
383 dtype.clone().implode(),
384 AnyValue::List(Series::new_empty(PlSmallStr::EMPTY, dtype)),
385 ),
386 ac.groups.len(),
387 ));
388 ac.with_update_groups(UpdateGroups::WithSeriesLen);
389 return Ok(ac);
390 }
391
392 if let AggState::AggregatedScalar(c) = &mut ac.state {
393 ac.state = AggState::NotAggregated(std::mem::take(c));
394 ac.groups = Cow::Owned(
395 {
396 let groups = predicate
397 .iter()
398 .enumerate_idx()
399 .map(|(i, p)| [i, IdxSize::from(p)])
400 .collect();
401 GroupsType::new_slice(groups, false, true)
402 }
403 .into_sliceable(),
404 );
405 ac.update_groups = UpdateGroups::No;
406 return Ok(ac);
407 }
408
409 ac.groups();
410 let predicate = BitMask::from_bitmap(predicate);
411 RAYON.install(|| {
412 let positions = GroupsType::Idx(match &**ac.groups.as_ref() {
413 GroupsType::Idx(idxs) => idxs
414 .into_par_iter()
415 .map(|(fst, idxs)| {
416 let out = idxs
417 .iter()
418 .copied()
419 .filter(|i| unsafe { predicate.get_bit_unchecked(*i as usize) })
420 .collect::<UnitVec<IdxSize>>();
421 (out.first().copied().unwrap_or(fst), out)
422 })
423 .collect(),
424 GroupsType::Slice {
425 groups,
426 overlapping: _,
427 monotonic: _,
428 } => groups
429 .into_par_iter()
430 .map(|[start, length]| {
431 let predicate =
432 unsafe { predicate.sliced_unchecked(*start as usize, *length as usize) };
433 let num_values = predicate.set_bits();
434
435 if num_values == 0 {
436 (*start, UnitVec::new())
437 } else if num_values == 1 {
438 let item = *start + predicate.leading_zeros() as IdxSize;
439 let mut out = UnitVec::with_capacity(1);
440 out.push(item);
441 (item, out)
442 } else if num_values == *length as usize {
443 (*start, (*start..*start + *length).collect())
444 } else {
445 let out = unsafe {
446 TrustMyLength::new(
447 (0..*length)
448 .filter(|i| predicate.get_bit_unchecked(*i as usize))
449 .map(|i| i + *start),
450 num_values,
451 )
452 };
453 let out = out.collect::<UnitVec<IdxSize>>();
454
455 (out.first().copied().unwrap(), out)
456 }
457 })
458 .collect(),
459 })
460 .into_sliceable();
461 ac.with_groups(positions);
462 });
463
464 Ok(ac)
465}
466
467pub fn drop_nans<'a>(
468 inputs: &[Arc<dyn PhysicalExpr>],
469 df: &DataFrame,
470 groups: &'a GroupPositions,
471 state: &ExecutionState,
472) -> PolarsResult<AggregationContext<'a>> {
473 assert_eq!(inputs.len(), 1);
474 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
475 ac.groups();
476 let predicate = if ac.agg_state().flat_dtype().is_float() {
477 let values = ac.flat_naive();
478 let mut values = values.is_nan().unwrap();
479 values.rechunk_mut();
480 values.downcast_as_array().values().clone()
481 } else {
482 Bitmap::new_with_value(false, 1)
483 };
484 let predicate = !&predicate;
485 drop_items(ac, &predicate)
486}
487
488pub fn drop_nulls<'a>(
489 inputs: &[Arc<dyn PhysicalExpr>],
490 df: &DataFrame,
491 groups: &'a GroupPositions,
492 state: &ExecutionState,
493) -> PolarsResult<AggregationContext<'a>> {
494 assert_eq!(inputs.len(), 1);
495 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
496 ac.groups();
497 let predicate = ac.flat_naive().as_ref().clone();
498 let predicate = predicate.rechunk_to_arrow(CompatLevel::newest());
499 let predicate = predicate
500 .validity()
501 .cloned()
502 .unwrap_or(Bitmap::new_with_value(true, 1));
503 drop_items(ac, &predicate)
504}
505
506pub fn quantile<'a>(
507 inputs: &[Arc<dyn PhysicalExpr>],
508 df: &DataFrame,
509 groups: &'a GroupPositions,
510 state: &ExecutionState,
511 method: QuantileMethod,
512) -> PolarsResult<AggregationContext<'a>> {
513 assert!(inputs.len() == 2);
514
515 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
518 ac.set_groups_for_undefined_agg_states();
519
520 let keep_name = ac.get_values().name().clone();
522
523 let quantile_column = inputs[1].evaluate(df, state)?;
524 polars_ensure!(
525 quantile_column.len() <= 1,
526 ComputeError:
527 "polars only supports computing a single quantile in a groupby aggregation context"
528 );
529 polars_ensure!(
530 quantile_column.dtype().is_numeric(),
531 SchemaMismatch:
532 "expected expression of dtype 'numeric' for quantile, got '{}'",
533 quantile_column.dtype()
534 );
535 let quantile: f64 = quantile_column.get(0).unwrap().try_extract()?;
536
537 if let AggState::LiteralScalar(c) = &mut ac.state {
538 *c = c.quantile_reduce(quantile, method)?.into_column(keep_name);
539 return Ok(ac);
540 }
541
542 let mut agg = unsafe {
544 ac.flat_naive()
545 .into_owned()
546 .agg_quantile(ac.groups(), quantile, method)
547 };
548 agg.rename(keep_name);
549 Ok(AggregationContext::from_agg_state(
550 AggState::AggregatedScalar(agg),
551 Cow::Borrowed(groups),
552 ))
553}
554
/// Shared driver for moment-style group aggregations over `f64` values.
///
/// `S` is the accumulator type:
///
/// * `insert_one` feeds a single value into an accumulator,
/// * `new_from_slice` builds an accumulator from `(array, start, length)` and is
///   used for slice groups,
/// * `finalize` turns an accumulator into the optional result.
///
/// For index groups an accumulator is filled row by row, skipping rows whose
/// validity bit is unset. Scalar states get one fresh accumulator per non-null
/// value; null values stay null.
///
/// # Errors
///
/// Returns an error if evaluating the input fails or the values are not `f64`.
///
/// # Panics
///
/// Panics if `inputs` does not contain exactly one expression.
#[cfg(feature = "moment")]
pub fn moment_agg<'a, S: Default>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,

    insert_one: impl Fn(&mut S, f64) + Send + Sync,
    new_from_slice: impl Fn(&PrimitiveArray<f64>, usize, usize) -> S + Send + Sync,
    finalize: impl Fn(S) -> Option<f64> + Send + Sync,
) -> PolarsResult<AggregationContext<'a>> {
    assert_eq!(inputs.len(), 1);
    let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;

    if let AggState::AggregatedScalar(col) | AggState::LiteralScalar(col) = &mut ac.state {
        let floats = col.f64()?;
        let name = floats.name().clone();
        let moments: Float64Chunked = floats
            .iter()
            .map(|value| {
                // A null input stays null.
                let value = value?;
                let mut acc = S::default();
                insert_one(&mut acc, value);
                finalize(acc)
            })
            .collect();
        *col = moments.with_name(name).into_column();
        return Ok(ac);
    }

    ac.groups();

    let name = ac.get_values().name().clone();
    let flat = ac.flat_naive();
    let floats = flat.f64()?;
    let contiguous = floats.rechunk();
    let arr = contiguous.downcast_as_array();

    let moments = RAYON.install(|| match &**ac.groups.as_ref() {
        GroupsType::Idx(idx_groups) => {
            // Only consult the validity mask when it actually has unset bits.
            match arr.validity().filter(|v| v.unset_bits() > 0) {
                Some(validity) => idx_groups
                    .into_par_iter()
                    .map(|(_, members)| {
                        let mut acc = S::default();
                        for &i in members.iter() {
                            let row = i as usize;
                            // NOTE(review): assumes group indices are in bounds of
                            // the validity mask — confirm against the group invariants.
                            if unsafe { validity.get_bit_unchecked(row) } {
                                insert_one(&mut acc, arr.values()[row]);
                            }
                        }
                        finalize(acc)
                    })
                    .collect::<Float64Chunked>(),
                None => idx_groups
                    .into_par_iter()
                    .map(|(_, members)| {
                        let mut acc = S::default();
                        for &i in members.iter() {
                            insert_one(&mut acc, arr.values()[i as usize]);
                        }
                        finalize(acc)
                    })
                    .collect::<Float64Chunked>(),
            }
        },
        GroupsType::Slice {
            groups: slices,
            overlapping: _,
            monotonic: _,
        } => slices
            .into_par_iter()
            .map(|[start, length]| {
                let acc = new_from_slice(arr, *start as usize, *length as usize);
                finalize(acc)
            })
            .collect::<Float64Chunked>(),
    });

    ac.state = AggState::AggregatedScalar(moments.with_name(name).into_column());
    Ok(ac)
}
633
/// Group-wise skewness; delegates to [`moment_agg`] with `SkewState` as the
/// accumulator. `bias` is forwarded to `SkewState::finalize`.
#[cfg(feature = "moment")]
pub fn skew<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::SkewState;

    let finalize = |acc: SkewState| acc.finalize(bias);
    moment_agg::<SkewState>(
        inputs,
        df,
        groups,
        state,
        SkewState::insert_one,
        SkewState::from_array,
        finalize,
    )
}
653
/// Group-wise kurtosis; delegates to [`moment_agg`] with `KurtosisState` as the
/// accumulator. `fisher` and `bias` are forwarded to `KurtosisState::finalize`.
#[cfg(feature = "moment")]
pub fn kurtosis<'a>(
    inputs: &[Arc<dyn PhysicalExpr>],
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
    fisher: bool,
    bias: bool,
) -> PolarsResult<AggregationContext<'a>> {
    use polars_compute::moment::KurtosisState;

    let finalize = |acc: KurtosisState| acc.finalize(fisher, bias);
    moment_agg::<KurtosisState>(
        inputs,
        df,
        groups,
        state,
        KurtosisState::insert_one,
        KurtosisState::from_array,
        finalize,
    )
}
674
675pub fn unique<'a>(
676 inputs: &[Arc<dyn PhysicalExpr>],
677 df: &DataFrame,
678 groups: &'a GroupPositions,
679 state: &ExecutionState,
680 stable: bool,
681) -> PolarsResult<AggregationContext<'a>> {
682 _ = stable;
683
684 assert_eq!(inputs.len(), 1);
685 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
686 ac.groups();
687
688 if let AggState::AggregatedScalar(c) | AggState::LiteralScalar(c) = &mut ac.state {
689 *c = c.as_list().into_column();
690 if c.len() == 1 && ac.groups.len() != 1 {
691 *c = c.new_from_index(0, ac.groups.len());
692 }
693 ac.state = AggState::AggregatedList(std::mem::take(c));
694 ac.update_groups = UpdateGroups::WithSeriesLen;
695 return Ok(ac);
696 }
697
698 let values = ac.flat_naive().to_physical_repr();
699 let dtype = values.dtype();
700 let values = if dtype.contains_objects() {
701 polars_bail!(opq = unique, dtype);
702 } else if let Some(ca) = values.try_str() {
703 ca.as_binary().into_column()
704 } else if dtype.is_nested() {
705 encode_rows_unordered(&[values])?.into_column()
706 } else {
707 values
708 };
709
710 let values = values.rechunk_to_arrow(CompatLevel::newest());
711 let values = values.as_ref();
712 let state = amortized_unique_from_dtype(values.dtype());
713
714 struct CloneWrapper(Box<dyn AmortizedUnique>);
715 impl Clone for CloneWrapper {
716 fn clone(&self) -> Self {
717 Self(self.0.new_empty())
718 }
719 }
720
721 RAYON.install(|| {
722 let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
723 GroupsType::Idx(idx) => idx
724 .into_par_iter()
725 .map_with(CloneWrapper(state), |state, (first, idx)| {
726 let mut idx = idx.clone();
727 unsafe { state.0.retain_unique(values, &mut idx) };
728 (idx.first().copied().unwrap_or(first), idx)
729 })
730 .collect(),
731 GroupsType::Slice {
732 groups,
733 overlapping: _,
734 monotonic: _,
735 } => groups
736 .into_par_iter()
737 .map_with(CloneWrapper(state), |state, [start, len]| {
738 let mut idx = UnitVec::new();
739 state.0.arg_unique(values, &mut idx, *start, *len);
740 (idx.first().copied().unwrap_or(*start), idx)
741 })
742 .collect(),
743 })
744 .into_sliceable();
745 ac.with_groups(positions);
746 });
747
748 Ok(ac)
749}
750
751fn fw_bw_fill_null<'a>(
752 inputs: &[Arc<dyn PhysicalExpr>],
753 df: &DataFrame,
754 groups: &'a GroupPositions,
755 state: &ExecutionState,
756 f_idx: impl Fn(
757 std::iter::Copied<std::slice::Iter<'_, IdxSize>>,
758 BitMask<'_>,
759 usize,
760 ) -> UnitVec<IdxSize>
761 + Send
762 + Sync,
763 f_range: impl Fn(std::ops::Range<IdxSize>, BitMask<'_>, usize) -> UnitVec<IdxSize> + Send + Sync,
764) -> PolarsResult<AggregationContext<'a>> {
765 assert_eq!(inputs.len(), 1);
766 let mut ac = inputs[0].evaluate_on_groups(df, groups, state)?;
767 ac.groups();
768
769 if let AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) = &mut ac.state {
770 return Ok(ac);
771 }
772
773 let values = ac.flat_naive();
774 let Some(validity) = values.rechunk_validity() else {
775 return Ok(ac);
776 };
777
778 let validity = BitMask::from_bitmap(&validity);
779 RAYON.install(|| {
780 let positions = GroupsType::Idx(match &**ac.groups().as_ref() {
781 GroupsType::Idx(idx) => idx
782 .into_par_iter()
783 .map(|(first, idx)| {
784 let idx = f_idx(idx.iter().copied(), validity, idx.len());
785 (idx.first().copied().unwrap_or(first), idx)
786 })
787 .collect(),
788 GroupsType::Slice {
789 groups,
790 overlapping: _,
791 monotonic: _,
792 } => groups
793 .into_par_iter()
794 .map(|[start, len]| {
795 let idx = f_range(*start..*start + *len, validity, *len as usize);
796 (idx.first().copied().unwrap_or(*start), idx)
797 })
798 .collect(),
799 })
800 .into_sliceable();
801 ac.with_groups(positions);
802 });
803
804 Ok(ac)
805}
806
807pub fn forward_fill_null<'a>(
808 inputs: &[Arc<dyn PhysicalExpr>],
809 df: &DataFrame,
810 groups: &'a GroupPositions,
811 state: &ExecutionState,
812 limit: Option<IdxSize>,
813) -> PolarsResult<AggregationContext<'a>> {
814 let limit = limit.unwrap_or(IdxSize::MAX);
815 macro_rules! arg_forward_fill {
816 (
817 $iter:ident,
818 $validity:ident,
819 $length:ident
820 ) => {{
821 |$iter, $validity, $length| {
822 let Some(start) = $iter
823 .clone()
824 .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
825 else {
826 return $iter.collect();
827 };
828
829 let mut idx = UnitVec::with_capacity($length);
830 let mut iter = $iter;
831 idx.extend((&mut iter).take(start));
832
833 let mut current_limit = limit;
834 let mut value = iter.next().unwrap();
835 idx.push(value);
836
837 idx.extend(iter.map(|i| {
838 if unsafe { $validity.get_bit_unchecked(i as usize) } {
839 current_limit = limit;
840 value = i;
841 i
842 } else if current_limit == 0 {
843 i
844 } else {
845 current_limit -= 1;
846 value
847 }
848 }));
849 idx
850 }
851 }};
852 }
853
854 fw_bw_fill_null(
855 inputs,
856 df,
857 groups,
858 state,
859 arg_forward_fill!(iter, validity, length),
860 arg_forward_fill!(iter, validity, length),
861 )
862}
863
864pub fn backward_fill_null<'a>(
865 inputs: &[Arc<dyn PhysicalExpr>],
866 df: &DataFrame,
867 groups: &'a GroupPositions,
868 state: &ExecutionState,
869 limit: Option<IdxSize>,
870) -> PolarsResult<AggregationContext<'a>> {
871 let limit = limit.unwrap_or(IdxSize::MAX);
872 macro_rules! arg_backward_fill {
873 (
874 $iter:ident,
875 $validity:ident,
876 $length:ident
877 ) => {{
878 |$iter, $validity, $length| {
879 let Some(start) = $iter
880 .clone()
881 .rev()
882 .position(|i| unsafe { $validity.get_bit_unchecked(i as usize) })
883 else {
884 return $iter.collect();
885 };
886
887 let mut idx = UnitVec::from_iter($iter);
888 let mut current_limit = limit;
889 let mut value = idx[$length - start - 1];
890 for i in idx[..$length - start].iter_mut().rev() {
891 if unsafe { $validity.get_bit_unchecked(*i as usize) } {
892 current_limit = limit;
893 value = *i;
894 } else if current_limit != 0 {
895 current_limit -= 1;
896 *i = value;
897 }
898 }
899
900 idx
901 }
902 }};
903 }
904
905 fw_bw_fill_null(
906 inputs,
907 df,
908 groups,
909 state,
910 arg_backward_fill!(iter, validity, length),
911 arg_backward_fill!(iter, validity, length),
912 )
913}