1use std::cmp::Ordering;
2use std::fmt::Write;
3
4use arrow::array::PrimitiveArray;
5use arrow::bitmap::Bitmap;
6use arrow::trusted_len::TrustMyLength;
7use polars_core::error::feature_gated;
8use polars_core::prelude::row_encode::encode_rows_unordered;
9use polars_core::prelude::sort::perfect_sort;
10use polars_core::prelude::*;
11use polars_core::series::IsSorted;
12use polars_core::utils::_split_offsets;
13use polars_core::{POOL, downcast_as_macro_arg_physical};
14use polars_ops::frame::SeriesJoin;
15use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
16use polars_ops::prelude::*;
17use polars_plan::prelude::*;
18use polars_utils::UnitVec;
19use polars_utils::sync::SyncPtr;
20use polars_utils::vec::PushUnchecked;
21use rayon::prelude::*;
22
23use super::*;
24
/// Physical expression implementing a window function (`expr.over(...)`).
pub struct WindowExpr {
    /// Expressions producing the `partition_by` key columns; evaluated on the
    /// input frame in `evaluate`.
    pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
    /// Optional `order_by` expression and its sort options; used to sort the
    /// group positions before the function is applied.
    pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
    /// Column names handed to the `GroupBy` as the selected (apply) columns.
    pub(crate) apply_columns: Vec<PlSmallStr>,
    /// The inner function that is evaluated per group.
    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
    /// How per-group results are mapped back to rows (join / explode /
    /// groups-to-rows); drives `determine_map_strategy`.
    pub(crate) mapping: WindowMapping,
    /// The logical expression; used for error reporting and `as_expression`.
    pub(crate) expr: Expr,
    /// When set, group positions are always sorted in `evaluate`
    /// (`sort_groups` is forced to `true`).
    pub(crate) has_different_group_sources: bool,
    /// Pre-computed output schema field returned by `to_field`.
    pub(crate) output_field: Field,

    /// Checked in `evaluate_on_groups`: non-elementwise `partition_by` is
    /// rejected in aggregation context.
    pub(crate) all_group_by_are_elementwise: bool,
    /// Checked in `evaluate_on_groups`: non-elementwise `order_by` is
    /// rejected in aggregation context.
    pub(crate) order_by_is_elementwise: bool,
}
40
/// How the per-group aggregation result is mapped back onto the output rows.
#[cfg_attr(debug_assertions, derive(Debug))]
enum MapStrategy {
    /// Left-join the group results back to the original rows on the group keys.
    Join,
    /// Explode the aggregated result; output length is the sum of group lengths.
    Explode,
    /// Flatten the result and scatter it back to original row positions via an
    /// index sort (see `map_by_arg_sort`).
    Map,
    /// Result already has the right shape (or is a broadcastable literal).
    Nothing,
}
52
53impl WindowExpr {
    /// Build the gather indices that map the `flattened` (exploded) per-group
    /// result back to the original row order of the frame.
    ///
    /// `idx_mapping` collects `(original_row_index, flattened_position)` pairs;
    /// `perfect_sort` then scatters the flattened positions so that
    /// `take_idx[row] == position_in_flattened` for every row.
    fn map_list_agg_by_arg_sort(
        &self,
        out_column: Column,
        flattened: &Column,
        mut ac: AggregationContext,
        gb: GroupBy,
    ) -> PolarsResult<IdxCa> {
        let mut idx_mapping = Vec::with_capacity(out_column.len());

        // Output buffer for `perfect_sort`; below we may reuse the
        // `original_idx` allocation for it.
        let mut take_idx = vec![];

        // Fast path: the aggregation still uses the exact same group positions
        // as the group-by, so flattened positions pair directly with the group
        // row indices in one traversal.
        if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
            let mut iter = 0..flattened.len() as IdxSize;
            match ac.groups().as_ref().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        idx_mapping.extend(g.iter().copied().zip(&mut iter));
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        idx_mapping.extend((first..first + len).zip(&mut iter));
                    }
                },
            }
        }
        else {
            // Slow path: the aggregation changed the group positions (e.g. an
            // order-by re-sorted them). First flatten the *original* group-by
            // row indices, then pair them with the aggregation's (possibly
            // reordered) group row indices.
            let mut original_idx = Vec::with_capacity(out_column.len());
            match gb.get_groups().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        original_idx.extend_from_slice(g)
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        original_idx.extend(first..first + len)
                    }
                },
            };

            let mut original_idx_iter = original_idx.iter().copied();

            match ac.groups().as_ref().as_ref() {
                GroupsType::Idx(groups) => {
                    for g in groups.all() {
                        idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
                    }
                },
                GroupsType::Slice { groups, .. } => {
                    for &[first, len] in groups {
                        idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
                    }
                },
            }
            // Reuse the allocation of `original_idx` as the sort output buffer.
            original_idx.clear();
            take_idx = original_idx;
        }
        // SAFETY: relies on `idx_mapping` containing each row index exactly
        // once (a permutation) — inherited invariant of the group positions.
        unsafe { perfect_sort(&idx_mapping, &mut take_idx) };
        Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
    }
123
    /// `MapStrategy::Map`: scatter the flattened per-group result back to
    /// original row order, with a detailed error if any group's output length
    /// does not match the group length. The computed gather indices are cached
    /// under `cache_key` when window caching is enabled.
    #[allow(clippy::too_many_arguments)]
    fn map_by_arg_sort(
        &self,
        df: &DataFrame,
        out_column: Column,
        flattened: &Column,
        mut ac: AggregationContext,
        group_by_columns: &[Column],
        gb: GroupBy,
        cache_key: String,
        state: &ExecutionState,
    ) -> PolarsResult<Column> {
        // A row-mapped result must have exactly one value per input row.
        // Find the offending group (if any) so the error can show its keys.
        if flattened.len() != df.height() {
            let ca = out_column.list().unwrap();
            let non_matching_group =
                ca.into_iter()
                    .zip(ac.groups().iter())
                    .find(|(output, group)| {
                        if let Some(output) = output {
                            output.as_ref().len() != group.len()
                        } else {
                            false
                        }
                    });

            if let Some((output, group)) = non_matching_group {
                // Render the partition key values of the first row of the
                // mismatching group for the error message.
                let first = group.first();
                let group = group_by_columns
                    .iter()
                    .map(|s| format!("{}", s.get(first as usize).unwrap()))
                    .collect::<Vec<_>>();
                polars_bail!(
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group\
                    \n> group: {}\n> group length: {}\n> output: '{:?}'",
                    comma_delimited(String::new(), &group), group.len(), output.unwrap()
                );
            } else {
                polars_bail!(
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group"
                );
            };
        }

        // Compute (or fetch from the window cache) the row-order gather indices.
        let idx = if state.cache_window() {
            if let Some(idx) = state.window_cache.get_map(&cache_key) {
                idx
            } else {
                let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
                state.window_cache.insert_map(cache_key, idx.clone());
                idx
            }
        } else {
            Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
        };

        // SAFETY: `idx` was built from in-bounds row indices of `flattened`.
        unsafe { Ok(flattened.take_unchecked(&idx)) }
    }
210
211 fn run_aggregation<'a>(
212 &self,
213 df: &DataFrame,
214 state: &ExecutionState,
215 gb: &'a GroupBy,
216 ) -> PolarsResult<AggregationContext<'a>> {
217 let ac = self
218 .phys_function
219 .evaluate_on_groups(df, gb.get_groups(), state)?;
220 Ok(ac)
221 }
222
223 fn is_explicit_list_agg(&self) -> bool {
224 let mut explicit_list = false;
234 for e in &self.expr {
235 if let Expr::Over { function, .. } = e {
236 let mut finishes_list = false;
238 for e in &**function {
239 match e {
240 Expr::Agg(AggExpr::Implode(_)) => {
241 finishes_list = true;
242 },
243 Expr::Alias(_, _) => {},
244 _ => break,
245 }
246 }
247 explicit_list = finishes_list;
248 }
249 }
250
251 explicit_list
252 }
253
254 fn is_simple_column_expr(&self) -> bool {
255 let mut simple_col = false;
258 for e in &self.expr {
259 if let Expr::Over { function, .. } = e {
260 for e in &**function {
262 match e {
263 Expr::Column(_) => {
264 simple_col = true;
265 },
266 Expr::Alias(_, _) => {},
267 _ => break,
268 }
269 }
270 }
271 }
272 simple_col
273 }
274
275 fn is_aggregation(&self) -> bool {
276 let mut agg_col = false;
279 for e in &self.expr {
280 if let Expr::Over { function, .. } = e {
281 for e in &**function {
283 match e {
284 Expr::Agg(_) => {
285 agg_col = true;
286 },
287 Expr::Alias(_, _) => {},
288 _ => break,
289 }
290 }
291 }
292 }
293 agg_col
294 }
295
    /// Decide how the aggregation result is mapped back to rows, based on the
    /// requested `WindowMapping` and the aggregation state.
    fn determine_map_strategy(
        &self,
        ac: &mut AggregationContext,
        gb: &GroupBy,
    ) -> PolarsResult<MapStrategy> {
        match (self.mapping, ac.agg_state()) {
            // Explicit explode requested by the user.
            (WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
            // One scalar per group: broadcast back via a join on the keys.
            (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
            (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
            (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
                // With slice groups we can explode directly, provided the
                // output list lengths match the group lengths; otherwise a
                // positional remap is required.
                if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
                    ac.groups().as_ref().check_lengths(gb.get_groups())?;
                    Ok(MapStrategy::Explode)
                } else {
                    Ok(MapStrategy::Map)
                }
            },
            (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
                // A bare `col(..).over(..)` leaves values untouched; nothing
                // needs to be mapped back.
                if self.is_simple_column_expr() {
                    Ok(MapStrategy::Nothing)
                } else {
                    Ok(MapStrategy::Map)
                }
            },
            (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
            // A literal scalar is broadcast by the caller (`new_from_index`).
            (_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
        }
    }
343}
344
345pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
347 write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
348}
349
350impl PhysicalExpr for WindowExpr {
    /// Evaluate the window expression on `df`.
    ///
    /// Steps:
    /// 1. short-circuit empty frames with a correctly-typed all-null column,
    /// 2. evaluate the `partition_by` key columns,
    /// 3. build (or fetch from the window cache) the group positions,
    ///    optionally re-sorted by `order_by`,
    /// 4. run the inner function per group,
    /// 5. map the per-group result back onto rows per `determine_map_strategy`.
    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
        if df.height() == 0 {
            let field = self.phys_function.to_field(df.schema())?;
            match self.mapping {
                // The `join` mapping yields a list column, so wrap the dtype.
                WindowMapping::Join => {
                    return Ok(Column::full_null(
                        field.name().clone(),
                        0,
                        &DataType::List(Box::new(field.dtype().clone())),
                    ));
                },
                _ => {
                    return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
                },
            }
        }

        let mut group_by_columns = self
            .group_by
            .iter()
            .map(|e| e.evaluate(df, state))
            .collect::<PolarsResult<Vec<_>>>()?;

        let sorted_keys = group_by_columns.iter().all(|s| {
            matches!(
                s.is_sorted_flag(),
                IsSorted::Ascending | IsSorted::Descending
            )
        });
        let explicit_list_agg = self.is_explicit_list_agg();

        // Whether group positions must be materialized in sorted order.
        // Always required for an explode mapping; otherwise only for
        // non-trivial functions over already-sorted keys.
        let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
            (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());

        if self.has_different_group_sources {
            sort_groups = true
        }

        // Lazily build the group positions; applying `order_by` re-sorts rows
        // within each group.
        let create_groups = || {
            let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
            let mut groups = gb.into_groups();

            if let Some((order_by, options)) = &self.order_by {
                let order_by = order_by.evaluate(df, state)?;
                polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
                groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
                    .into_sliceable()
            }

            let out: PolarsResult<GroupPositions> = Ok(groups);
            out
        };

        // The cache key encodes the branch index, key column names and the
        // order-by clause, so identical windows share group positions.
        let (mut groups, cache_key) = if state.cache_window() {
            let mut cache_key = String::with_capacity(32 * group_by_columns.len());
            write!(&mut cache_key, "{}", state.branch_idx).unwrap();
            for s in &group_by_columns {
                cache_key.push_str(s.name());
            }
            if let Some((e, options)) = &self.order_by {
                let e = match e.as_expression() {
                    Some(e) => e,
                    None => {
                        polars_bail!(InvalidOperation: "cannot order by this expression in window function")
                    },
                };
                window_function_format_order_by(&mut cache_key, e, options)
            }

            let groups = match state.window_cache.get_groups(&cache_key) {
                Some(groups) => groups,
                None => create_groups()?,
            };
            (groups, cache_key)
        } else {
            (create_groups()?, "".to_string())
        };

        let apply_columns = self.apply_columns.clone();

        if sort_groups || state.cache_window() {
            groups.sort();
            state
                .window_cache
                .insert_groups(cache_key.clone(), groups.clone());
        }

        // Broadcast unit-length key columns to frame height; any other length
        // mismatch is an error.
        for col in group_by_columns.iter_mut() {
            if col.len() != df.height() {
                polars_ensure!(
                    col.len() == 1,
                    ShapeMismatch: "columns used as `partition_by` must have the same length as the DataFrame"
                );
                *col = col.new_from_index(0, df.height())
            }
        }

        let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));

        let mut ac = self.run_aggregation(df, state, &gb)?;

        use MapStrategy::*;

        match self.determine_map_strategy(&mut ac, &gb)? {
            Nothing => {
                let mut out = ac.flat_naive().into_owned();

                // A literal scalar is broadcast to frame height.
                if ac.is_literal() {
                    out = out.new_from_index(0, df.height())
                }
                Ok(out.into_column())
            },
            Explode => {
                let out = if self.phys_function.is_scalar() {
                    ac.get_values().clone()
                } else {
                    ac.aggregated().explode(ExplodeOptions {
                        empty_as_null: true,
                        keep_nulls: true,
                    })?
                };
                Ok(out.into_column())
            },
            Map => {
                let out_column = ac.aggregated();
                let flattened = out_column.explode(ExplodeOptions {
                    empty_as_null: true,
                    keep_nulls: true,
                })?;
                // SAFETY(review): lifetime erasure on the aggregation context;
                // `map_by_arg_sort` consumes it before `df`/`gb` are dropped.
                let ac = unsafe {
                    std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
                };
                self.map_by_arg_sort(
                    df,
                    out_column,
                    &flattened,
                    ac,
                    &group_by_columns,
                    gb,
                    cache_key,
                    state,
                )
            },
            Join => {
                let out_column = ac.aggregated();
                let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
                match (
                    &ac.update_groups,
                    set_by_groups(&out_column, &ac, df.height(), update_groups),
                ) {
                    // Fast path: numeric result scattered directly by groups.
                    (UpdateGroups::No, Some(out)) => Ok(out.into_column()),
                    // General path: left-join the result back on the keys.
                    (_, _) => {
                        let keys = gb.keys();

                        let get_join_tuples = || {
                            if group_by_columns.len() == 1 {
                                let mut left = group_by_columns[0].clone();
                                let mut right = keys[0].clone();

                                // Nested key dtypes are row-encoded to a flat
                                // binary key before hashing.
                                let (left, right) = if left.dtype().is_nested() {
                                    (
                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
                                            "".into(),
                                            row_encode::_get_rows_encoded_unordered(&[
                                                left.clone()
                                            ])?
                                            .into_array(),
                                        )
                                        .into_series(),
                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
                                            "".into(),
                                            row_encode::_get_rows_encoded_unordered(&[
                                                right.clone()
                                            ])?
                                            .into_array(),
                                        )
                                        .into_series(),
                                    )
                                } else {
                                    (
                                        left.into_materialized_series().clone(),
                                        right.into_materialized_series().clone(),
                                    )
                                };

                                PolarsResult::Ok(Arc::new(
                                    left.hash_join_left(&right, JoinValidation::ManyToMany, true)
                                        .unwrap()
                                        .1,
                                ))
                            } else {
                                let df_right =
                                    unsafe { DataFrame::new_unchecked_infer_height(keys) };
                                let df_left = unsafe {
                                    DataFrame::new_unchecked_infer_height(group_by_columns)
                                };
                                Ok(Arc::new(
                                    private_left_join_multiple_keys(&df_left, &df_right, true)?.1,
                                ))
                            }
                        };

                        // Join ids are cached per window cache key as well.
                        let join_opt_ids = if state.cache_window() {
                            if let Some(jt) = state.window_cache.get_join(&cache_key) {
                                jt
                            } else {
                                let jt = get_join_tuples()?;
                                state.window_cache.insert_join(cache_key, jt.clone());
                                jt
                            }
                        } else {
                            get_join_tuples()?
                        };

                        let out = materialize_column(&join_opt_ids, &out_column);
                        Ok(out.into_column())
                    },
                }
            },
        }
    }
633
    /// The output field is pre-computed at construction time; the input schema
    /// is not consulted.
    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
        Ok(self.output_field.clone())
    }
637
    /// A window expression always produces one value per row, never a scalar.
    fn is_scalar(&self) -> bool {
        false
    }
641
    /// Evaluate the window expression inside an aggregation context: every
    /// outer group is further partitioned by the window's `partition_by` keys
    /// ("subgroups"), the inner function is run once per subgroup, and gather
    /// indices are produced to place results back per row of each outer group.
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>> {
        // Only elementwise partition/order expressions are supported here.
        if self.group_by.is_empty()
            || !self.all_group_by_are_elementwise
            || (self.order_by.is_some() && !self.order_by_is_elementwise)
        {
            polars_bail!(
                InvalidOperation:
                "window expression with non-elementwise `partition_by` or `order_by` not allowed in aggregation context"
            );
        }

        // Height to broadcast unit-length columns to: the current element's
        // length if one is set on the state, else the frame height.
        let length_preserving_height = if let Some((c, _)) = state.element.as_ref() {
            c.len()
        } else {
            df.height()
        };

        let function_is_scalar = self.phys_function.is_scalar();
        // Non-scalar groups-to-rows output must be scattered back per row.
        let needs_remap_to_rows =
            matches!(self.mapping, WindowMapping::GroupsToRows) && !function_is_scalar;

        let partition_by_columns = self
            .group_by
            .iter()
            .map(|e| {
                let mut e = e.evaluate(df, state)?;
                if e.len() == 1 {
                    e = e.new_from_index(0, length_preserving_height);
                }
                assert_eq!(e.len(), length_preserving_height,);
                Ok(e)
            })
            .collect::<PolarsResult<Vec<_>>>()?;
        let order_by = match &self.order_by {
            None => None,
            Some((e, options)) => {
                let mut e = e.evaluate(df, state)?;
                if e.len() == 1 {
                    e = e.new_from_index(0, length_preserving_height);
                }
                assert_eq!(e.len(), length_preserving_height);
                // For row remapping, pre-compute an ordinal rank of the
                // order-by column so per-group sorts can compare plain ints.
                let arr: Option<PrimitiveArray<IdxSize>> = if needs_remap_to_rows {
                    feature_gated!("rank", {
                        use polars_ops::series::SeriesRank;
                        let arr = e.as_materialized_series().rank(
                            RankOptions {
                                method: RankMethod::Ordinal,
                                descending: false,
                            },
                            None,
                        );
                        let arr = arr.idx()?;
                        let arr = arr.rechunk();
                        Some(arr.downcast_as_array().clone())
                    })
                } else {
                    None
                };

                Some((e.clone(), arr, *options))
            },
        };

        // One partition id per row; multiple key columns are row-encoded first.
        let (num_unique_ids, unique_ids) = if partition_by_columns.len() == 1 {
            partition_by_columns[0].unique_id()?
        } else {
            ChunkUnique::unique_id(&encode_rows_unordered(&partition_by_columns)?)?
        };

        let subgroups_approx_capacity = groups.len();
        // Per (outer group x partition id): (first_row_idx, row indices).
        // These are the groups the inner function is evaluated over.
        let mut subgroups: Vec<(IdxSize, UnitVec<IdxSize>)> =
            Vec::with_capacity(subgroups_approx_capacity);

        // Per outer group: indices into the (flattened) function output that
        // place each row's result; unused for the explode mapping.
        let mut gather_indices_offset = 0;
        let mut gather_indices: Vec<(IdxSize, UnitVec<IdxSize>)> =
            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
                0
            } else {
                groups.len()
            });
        // Explode mapping: [subgroup_start, subgroup_count] per outer group
        // (rewritten to exploded offsets after evaluation).
        let mut strategy_explode_groups: Vec<[IdxSize; 2]> =
            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
                groups.len()
            } else {
                0
            });

        // Scratch buffers reused ("amortized") across outer groups.
        let mut amort_arg_sort = Vec::new();
        let mut amort_offsets = Vec::new();

        let mut amort_subgroups_order = Vec::with_capacity(num_unique_ids as usize);
        let mut amort_subgroups_sizes = Vec::with_capacity(num_unique_ids as usize);
        let mut amort_subgroups_indices = (0..num_unique_ids)
            .map(|_| (0, UnitVec::new()))
            .collect::<Vec<(IdxSize, UnitVec<IdxSize>)>>();

        // Processes one outer group. `$iter` yields the group's row indices;
        // `$get(pos)` maps a position within the group to its row index.
        // Buckets rows by partition id, records per-partition index lists into
        // `subgroups`, and (unless exploding) computes where each row's result
        // lands in the flattened output (`subgroup_gather_indices`).
        macro_rules! map_window_groups {
            ($iter:expr, $get:expr) => {
                let mut subgroup_gather_indices =
                    UnitVec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
                        0
                    } else {
                        $iter.len()
                    });

                amort_subgroups_order.clear();
                amort_subgroups_sizes.clear();
                amort_subgroups_sizes.resize(num_unique_ids as usize, 0);

                // First pass: count rows per partition id, remembering the
                // order in which ids first appear.
                for i in $iter.clone() {
                    let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                    let size = unsafe { amort_subgroups_sizes.get_unchecked_mut(id as usize) };
                    if *size == 0 {
                        unsafe { amort_subgroups_order.push_unchecked(id) };
                    }
                    *size += 1;
                }

                if matches!(self.mapping, WindowMapping::Explode) {
                    strategy_explode_groups.push([
                        subgroups.len() as IdxSize,
                        amort_subgroups_order.len() as IdxSize,
                    ]);
                }

                // Assign each partition its starting output offset: one slot
                // per row when remapping to rows, else one slot per subgroup.
                let mut offset = if needs_remap_to_rows {
                    gather_indices_offset
                } else {
                    subgroups.len() as IdxSize
                };
                for &id in &amort_subgroups_order {
                    let size = *unsafe { amort_subgroups_sizes.get_unchecked(id as usize) };
                    let (next_gather_idx, indices) =
                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                    indices.reserve(size as usize);
                    *next_gather_idx = offset;
                    offset += if needs_remap_to_rows { size } else { 1 };
                }

                if matches!(self.mapping, WindowMapping::Explode) {
                    // Explode: only collect the per-partition row indices.
                    for i in $iter {
                        let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                        let (_, indices) =
                            unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                        unsafe { indices.push_unchecked(i) };
                    }
                } else {
                    if needs_remap_to_rows && let Some((_, arr, options)) = &order_by {
                        // Row remap with order-by: argsort the group's rows by
                        // the precomputed ordinal rank, honoring descending /
                        // nulls-last, so each row knows its position within
                        // its partition's sorted output.
                        let arr = arr.as_ref().unwrap();
                        amort_arg_sort.clear();
                        amort_arg_sort.extend(0..$iter.len() as IdxSize);
                        match arr.validity() {
                            None => {
                                let arr = arr.values().as_slice();
                                amort_arg_sort.sort_by(|a, b| {
                                    let in_group_idx_a = $get(*a as usize) as usize;
                                    let in_group_idx_b = $get(*b as usize) as usize;

                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };

                                    let mut cmp = order_a.cmp(&order_b);
                                    if options.descending {
                                        cmp = cmp.reverse();
                                    }
                                    cmp
                                });
                            },
                            Some(validity) => {
                                let arr = arr.values().as_slice();
                                amort_arg_sort.sort_by(|a, b| {
                                    let in_group_idx_a = $get(*a as usize) as usize;
                                    let in_group_idx_b = $get(*b as usize) as usize;

                                    let is_valid_a =
                                        unsafe { validity.get_bit_unchecked(in_group_idx_a) };
                                    let is_valid_b =
                                        unsafe { validity.get_bit_unchecked(in_group_idx_b) };
                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };

                                    if !is_valid_a & !is_valid_b {
                                        return Ordering::Equal;
                                    }

                                    // Nulls sort first by default; reversed by
                                    // `descending` and/or `nulls_last`.
                                    let mut cmp = order_a.cmp(&order_b);
                                    if !is_valid_a {
                                        cmp = Ordering::Less;
                                    }
                                    if !is_valid_b {
                                        cmp = Ordering::Greater;
                                    }
                                    if options.descending
                                        | ((!is_valid_a | !is_valid_b) & options.nulls_last)
                                    {
                                        cmp = cmp.reverse();
                                    }
                                    cmp
                                });
                            },
                        }

                        // Per-row offset within its partition, in sort order.
                        amort_offsets.clear();
                        amort_offsets.resize($iter.len(), 0);
                        for &id in &amort_subgroups_order {
                            amort_subgroups_sizes[id as usize] = 0;
                        }

                        for &idx in &amort_arg_sort {
                            let in_group_idx = $get(idx as usize);
                            let id = *unsafe { unique_ids.get_unchecked(in_group_idx as usize) };
                            amort_offsets[idx as usize] = amort_subgroups_sizes[id as usize];
                            amort_subgroups_sizes[id as usize] += 1;
                        }

                        for (i, offset) in $iter.zip(&amort_offsets) {
                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                            let (next_gather_idx, indices) =
                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                            unsafe {
                                subgroup_gather_indices.push_unchecked(*next_gather_idx + *offset)
                            };
                            unsafe { indices.push_unchecked(i) };
                        }
                    } else {
                        // No per-row sort needed: rows take consecutive slots
                        // (row remap) or all point at the subgroup's scalar.
                        for i in $iter {
                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                            let (next_gather_idx, indices) =
                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                            unsafe { subgroup_gather_indices.push_unchecked(*next_gather_idx) };
                            *next_gather_idx += IdxSize::from(needs_remap_to_rows);
                            unsafe { indices.push_unchecked(i) };
                        }
                    }
                }

                // Move the collected per-partition index lists into `subgroups`
                // in first-appearance order, resetting the scratch slots.
                subgroups.extend(amort_subgroups_order.iter().map(|&id| {
                    let (_, indices) =
                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                    let indices = std::mem::take(indices);
                    (*unsafe { indices.get_unchecked(0) }, indices)
                }));

                if !matches!(self.mapping, WindowMapping::Explode) {
                    gather_indices_offset += subgroup_gather_indices.len() as IdxSize;
                    gather_indices.push((
                        subgroup_gather_indices.first().copied().unwrap_or(0),
                        subgroup_gather_indices,
                    ));
                }
            };
        }
        match groups.as_ref() {
            GroupsType::Idx(idxs) => {
                for g in idxs.all() {
                    map_window_groups!(g.iter().copied(), (|i: usize| g[i]));
                }
            },
            GroupsType::Slice {
                groups,
                overlapping: _,
                monotonic: _,
            } => {
                for [s, l] in groups.iter() {
                    let s = *s;
                    let l = *l;
                    // SAFETY: `s..s + l` yields exactly `l` elements.
                    let iter = unsafe { TrustMyLength::new(s..s + l, l as usize) };
                    map_window_groups!(iter, (|i: usize| s + i as IdxSize));
                }
            },
        }

        // Apply `order_by` within every subgroup, then run the inner function.
        let mut subgroups = GroupsType::Idx(subgroups.into());
        if let Some((order_by, _, options)) = order_by {
            subgroups =
                update_groups_sort_by(&subgroups, order_by.as_materialized_series(), &options)?;
        }
        let subgroups = subgroups.into_sliceable();
        let mut data = self
            .phys_function
            .evaluate_on_groups(df, &subgroups, state)?
            .finalize();

        let final_groups = if matches!(self.mapping, WindowMapping::Explode) {
            if !function_is_scalar {
                // Flatten the per-subgroup lists and rewrite each outer
                // group's [start, count] into exploded [offset, length].
                let (data_s, offsets) = data.list()?.explode_and_offsets(ExplodeOptions {
                    empty_as_null: false,
                    keep_nulls: false,
                })?;
                data = data_s.into_column();

                let mut exploded_offset = 0;
                for [start, length] in strategy_explode_groups.iter_mut() {
                    let exploded_start = exploded_offset;
                    let exploded_length = offsets
                        .lengths()
                        .skip(*start as usize)
                        .take(*length as usize)
                        .sum::<usize>() as IdxSize;
                    exploded_offset += exploded_length;
                    *start = exploded_start;
                    *length = exploded_length;
                }
            }
            GroupsType::new_slice(strategy_explode_groups, false, true)
        } else {
            if needs_remap_to_rows {
                // Each subgroup's output list must be exactly as long as the
                // subgroup itself, otherwise the gather indices are invalid.
                let data_l = data.list()?;
                assert_eq!(data_l.len(), subgroups.len());
                let lengths = data_l.lst_lengths();
                let length_mismatch = match subgroups.as_ref() {
                    GroupsType::Idx(idx) => idx
                        .all()
                        .iter()
                        .zip(&lengths)
                        .any(|(i, l)| i.len() as IdxSize != l.unwrap()),
                    GroupsType::Slice {
                        groups,
                        overlapping: _,
                        monotonic: _,
                    } => groups
                        .iter()
                        .zip(&lengths)
                        .any(|([_, i], l)| *i != l.unwrap()),
                };

                polars_ensure!(
                    !length_mismatch,
                    expr = self.expr, ShapeMismatch:
                    "the length of the window expression did not match that of the group"
                );

                data = data_l
                    .explode(ExplodeOptions {
                        empty_as_null: false,
                        keep_nulls: true,
                    })?
                    .into_column();
            }
            GroupsType::Idx(gather_indices.into())
        }
        .into_sliceable();

        Ok(AggregationContext {
            state: AggState::NotAggregated(data),
            groups: Cow::Owned(final_groups),
            update_groups: UpdateGroups::No,
            original_len: false,
        })
    }
1020
    /// The logical expression this physical node was built from.
    fn as_expression(&self) -> Option<&Expr> {
        Some(&self.expr)
    }
1024}
1025
/// Gather `out_column` by the left-join result ids, yielding one value per
/// original (left) row.
fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
    {
        use arrow::Either;
        use polars_ops::chunked_array::TakeChunked;

        match join_opt_ids {
            // Plain nullable indices into `out_column`.
            // SAFETY: join ids were produced against `out_column`'s keys and
            // are therefore in bounds.
            Either::Left(ids) => unsafe {
                IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
            },
            // Chunked id representation; same bounds reasoning as above.
            Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
        }
    }
}
1039
/// Fast path for the `Join` strategy: scatter one aggregated value per group
/// directly to the rows of that group, avoiding a hash join.
///
/// Returns `None` when the fast path does not apply: the groups were updated,
/// the original length was not preserved, or the dtype is not primitive
/// numeric (physically).
fn set_by_groups(
    s: &Column,
    ac: &AggregationContext,
    len: usize,
    update_groups: bool,
) -> Option<Column> {
    if update_groups || !ac.original_len {
        return None;
    }
    if s.dtype().to_physical().is_primitive_numeric() {
        let dtype = s.dtype();
        let s = s.to_physical_repr();

        macro_rules! dispatch {
            ($ca:expr) => {{ Some(set_numeric($ca, &ac.groups, len)) }};
        }
        // Dispatch on the physical numeric type, then restore the logical
        // dtype. SAFETY: the physical values were produced from `dtype`.
        downcast_as_macro_arg_physical!(&s, dispatch)
            .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
            .map(Column::from)
    } else {
        None
    }
}
1064
/// Scatter one aggregated value per group into a `len`-long output series, in
/// parallel: each group's value is written to every row position belonging to
/// that group.
///
/// SAFETY (review): the parallel raw-pointer writes rely on the groups being
/// disjoint and jointly covering all `len` positions — an invariant of the
/// caller (`set_by_groups` only runs when the original length is preserved).
/// Bounds are debug-asserted per write.
fn set_numeric<T: PolarsNumericType>(
    ca: &ChunkedArray<T>,
    groups: &GroupsType,
    len: usize,
) -> Series {
    let mut values = Vec::with_capacity(len);
    let ptr: *mut T::Native = values.as_mut_ptr();
    // SAFETY: each rayon task writes a disjoint set of positions (see above),
    // so sharing the raw pointer across threads is sound.
    let sync_ptr_values = unsafe { SyncPtr::new(ptr) };

    if ca.null_count() == 0 {
        // Dense path: no validity buffer needed.
        let ca = ca.rechunk();
        match groups {
            GroupsType::Idx(groups) => {
                let agg_vals = ca.cont_slice().expect("rechunked");
                POOL.install(|| {
                    agg_vals
                        .par_iter()
                        .zip(groups.all().par_iter())
                        .for_each(|(v, g)| {
                            let ptr = sync_ptr_values.get();
                            for idx in g.as_slice() {
                                debug_assert!((*idx as usize) < len);
                                unsafe { *ptr.add(*idx as usize) = *v }
                            }
                        })
                })
            },
            GroupsType::Slice { groups, .. } => {
                let agg_vals = ca.cont_slice().expect("rechunked");
                POOL.install(|| {
                    agg_vals
                        .par_iter()
                        .zip(groups.par_iter())
                        .for_each(|(v, [start, g_len])| {
                            let ptr = sync_ptr_values.get();
                            let start = *start as usize;
                            let end = start + *g_len as usize;
                            for idx in start..end {
                                debug_assert!(idx < len);
                                unsafe { *ptr.add(idx) = *v }
                            }
                        })
                });
            },
        }

        // SAFETY: all `len` positions were written above (coverage invariant).
        unsafe { values.set_len(len) }
        ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
    } else {
        // Nullable path: maintain a validity mask alongside the values.
        // Positions never written keep `false` (null); their value bytes are
        // masked out by the validity bitmap.
        let mut validity: Vec<bool> = vec![false; len];
        let validity_ptr = validity.as_mut_ptr();
        // SAFETY: same disjoint-write argument as for the values pointer.
        let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };

        // Parallelize over contiguous chunks of groups instead of per group.
        let n_threads = POOL.current_num_threads();
        let offsets = _split_offsets(ca.len(), n_threads);

        match groups {
            GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
                let offset = *offset;
                let offset_len = *offset_len;
                // Slice both the per-group values and the group list to this
                // task's range so they stay aligned.
                let ca = ca.slice(offset as i64, offset_len);
                let groups = &groups.all()[offset..offset + offset_len];
                let values_ptr = sync_ptr_values.get();
                let validity_ptr = sync_ptr_validity.get();

                ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
                    for idx in g.as_slice() {
                        let idx = *idx as usize;
                        debug_assert!(idx < len);
                        unsafe {
                            match opt_v {
                                Some(v) => {
                                    *values_ptr.add(idx) = v;
                                    *validity_ptr.add(idx) = true;
                                },
                                None => {
                                    *values_ptr.add(idx) = T::Native::default();
                                    *validity_ptr.add(idx) = false;
                                },
                            };
                        }
                    }
                })
            }),
            GroupsType::Slice { groups, .. } => {
                offsets.par_iter().for_each(|(offset, offset_len)| {
                    let offset = *offset;
                    let offset_len = *offset_len;
                    let ca = ca.slice(offset as i64, offset_len);
                    let groups = &groups[offset..offset + offset_len];
                    let values_ptr = sync_ptr_values.get();
                    let validity_ptr = sync_ptr_validity.get();

                    for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
                        let start = *start as usize;
                        let end = start + *g_len as usize;
                        for idx in start..end {
                            debug_assert!(idx < len);
                            unsafe {
                                match opt_v {
                                    Some(v) => {
                                        *values_ptr.add(idx) = v;
                                        *validity_ptr.add(idx) = true;
                                    },
                                    None => {
                                        *values_ptr.add(idx) = T::Native::default();
                                        *validity_ptr.add(idx) = false;
                                    },
                                };
                            }
                        }
                    }
                })
            },
        }
        // SAFETY: every readable position is either written or marked invalid
        // in the validity bitmap constructed below.
        unsafe { values.set_len(len) }
        let validity = Bitmap::from(validity);
        let arr = PrimitiveArray::new(
            T::get_static_dtype()
                .to_physical()
                .to_arrow(CompatLevel::newest()),
            values.into(),
            Some(validity),
        );
        Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
    }
}