1use std::fmt::Write;
2
3use arrow::array::PrimitiveArray;
4use arrow::bitmap::Bitmap;
5use arrow::trusted_len::TrustMyLength;
6use polars_core::downcast_as_macro_arg_physical;
7use polars_core::error::feature_gated;
8use polars_core::prelude::row_encode::encode_rows_unordered;
9use polars_core::prelude::sort::perfect_sort;
10use polars_core::prelude::*;
11use polars_core::runtime::RAYON;
12use polars_core::series::IsSorted;
13use polars_core::utils::_split_offsets;
14use polars_ops::frame::SeriesJoin;
15use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
16use polars_ops::prelude::*;
17use polars_plan::prelude::*;
18use polars_utils::UnitVec;
19use polars_utils::sync::SyncPtr;
20use polars_utils::vec::PushUnchecked;
21use rayon::prelude::*;
22
23use super::*;
24
/// Physical expression that executes a window function, i.e. an `Expr::Over` node.
///
/// The function in `phys_function` is evaluated per partition and the result is mapped
/// back according to `mapping`.
pub struct WindowExpr {
    /// Expressions evaluated against the input frame to obtain the partition key columns.
    pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
    /// Optional expression (with its sort options) used to re-sort the rows of each group.
    pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
    /// Column names handed to `GroupBy::new` as the selected columns.
    pub(crate) apply_columns: Vec<PlSmallStr>,
    /// The function that is evaluated on the groups.
    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
    /// Requested way of mapping the per-group results back to the output.
    pub(crate) mapping: WindowMapping,
    /// The logical expression; used for error context, for inspecting the shape of the
    /// windowed function and returned by `as_expression`.
    pub(crate) expr: Expr,
    /// When set, the groups are always created sorted.
    pub(crate) has_different_group_sources: bool,
    /// Field returned from `to_field`.
    pub(crate) output_field: Field,

    /// Must be `true` for the expression to be evaluated in an aggregation context.
    pub(crate) all_group_by_are_elementwise: bool,
    /// Must be `true` (when `order_by` is set) for the expression to be evaluated in an
    /// aggregation context.
    pub(crate) order_by_is_elementwise: bool,
}
40
/// How the aggregated window result is turned into the output column; chosen by
/// `WindowExpr::determine_map_strategy` and dispatched on in `evaluate_impl`.
#[cfg_attr(debug_assertions, derive(Debug))]
enum MapStrategy {
    /// Scatter the per-group values directly when possible, otherwise left-join the
    /// partition keys against the group keys and gather the aggregated values.
    Join,
    /// Explode the aggregated list (or take the values as they are for scalar functions).
    Explode,
    /// Explode the aggregated list and reorder it with an arg-sort based index mapping.
    Map,
    /// Return the flat values as they are (a literal is broadcast to the frame height).
    Nothing,
}
52
53impl WindowExpr {
54 fn map_list_agg_by_arg_sort(
55 &self,
56 out_column: Column,
57 flattened: &Column,
58 mut ac: AggregationContext,
59 gb: GroupBy,
60 ) -> PolarsResult<IdxCa> {
61 let mut idx_mapping = Vec::with_capacity(out_column.len());
63
64 let mut take_idx = vec![];
67
68 if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
70 let mut iter = 0..flattened.len() as IdxSize;
71 match ac.groups().as_ref().as_ref() {
72 GroupsType::Idx(groups) => {
73 for g in groups.all() {
74 idx_mapping.extend(g.iter().copied().zip(&mut iter));
75 }
76 },
77 GroupsType::Slice { groups, .. } => {
78 for &[first, len] in groups {
79 idx_mapping.extend((first..first + len).zip(&mut iter));
80 }
81 },
82 }
83 }
84 else {
87 let mut original_idx = Vec::with_capacity(out_column.len());
88 match gb.get_groups().as_ref() {
89 GroupsType::Idx(groups) => {
90 for g in groups.all() {
91 original_idx.extend_from_slice(g)
92 }
93 },
94 GroupsType::Slice { groups, .. } => {
95 for &[first, len] in groups {
96 original_idx.extend(first..first + len)
97 }
98 },
99 };
100
101 let mut original_idx_iter = original_idx.iter().copied();
102
103 match ac.groups().as_ref().as_ref() {
104 GroupsType::Idx(groups) => {
105 for g in groups.all() {
106 idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
107 }
108 },
109 GroupsType::Slice { groups, .. } => {
110 for &[first, len] in groups {
111 idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
112 }
113 },
114 }
115 original_idx.clear();
116 take_idx = original_idx;
117 }
118 unsafe { perfect_sort(&idx_mapping, &mut take_idx) };
121 Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
122 }
123
124 #[allow(clippy::too_many_arguments)]
125 fn map_by_arg_sort(
126 &self,
127 df: &DataFrame,
128 out_column: Column,
129 flattened: &Column,
130 mut ac: AggregationContext,
131 group_by_columns: &[Column],
132 gb: GroupBy,
133 cache_key: String,
134 state: &ExecutionState,
135 ) -> PolarsResult<Column> {
136 if flattened.len() != df.height() {
162 let ca = out_column.list().unwrap();
163 let non_matching_group =
164 ca.series_iter()
165 .zip(ac.groups().iter())
166 .find(|(output, group)| {
167 if let Some(output) = output {
168 output.as_ref().len() != group.len()
169 } else {
170 false
171 }
172 });
173
174 if let Some((output, group)) = non_matching_group {
175 let first = group.first();
176 let group = group_by_columns
177 .iter()
178 .map(|s| format!("{}", s.get(first as usize).unwrap()))
179 .collect::<Vec<_>>();
180 polars_bail!(
181 expr = self.expr, ShapeMismatch:
182 "the length of the window expression did not match that of the group\
183 \n> group: {}\n> group length: {}\n> output: '{:?}'",
184 comma_delimited(String::new(), &group), group.len(), output.unwrap()
185 );
186 } else {
187 polars_bail!(
188 expr = self.expr, ShapeMismatch:
189 "the length of the window expression did not match that of the group"
190 );
191 };
192 }
193
194 let idx = if state.cache_window() {
195 if let Some(idx) = state.window_cache.get_map(&cache_key) {
196 idx
197 } else {
198 let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
199 state.window_cache.insert_map(cache_key, idx.clone());
200 idx
201 }
202 } else {
203 Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
204 };
205
206 unsafe { Ok(flattened.take_unchecked(&idx)) }
209 }
210
211 fn run_aggregation<'a>(
212 &self,
213 df: &DataFrame,
214 state: &ExecutionState,
215 gb: &'a GroupBy,
216 ) -> PolarsResult<AggregationContext<'a>> {
217 let ac = self
218 .phys_function
219 .evaluate_on_groups(df, gb.get_groups(), state)?;
220 Ok(ac)
221 }
222
223 fn is_explicit_list_agg(&self) -> bool {
224 let mut explicit_list = false;
234 for e in &self.expr {
235 if let Expr::Over { function, .. } = e {
236 let mut finishes_list = false;
238 for e in &**function {
239 match e {
240 Expr::Agg(AggExpr::Implode { .. }) => {
241 finishes_list = true;
242 },
243 Expr::Alias(_, _) => {},
244 _ => break,
245 }
246 }
247 explicit_list = finishes_list;
248 }
249 }
250
251 explicit_list
252 }
253
254 fn is_simple_column_expr(&self) -> bool {
255 let mut simple_col = false;
258 for e in &self.expr {
259 if let Expr::Over { function, .. } = e {
260 for e in &**function {
262 match e {
263 Expr::Column(_) => {
264 simple_col = true;
265 },
266 Expr::Alias(_, _) => {},
267 _ => break,
268 }
269 }
270 }
271 }
272 simple_col
273 }
274
275 fn is_aggregation(&self) -> bool {
276 let mut agg_col = false;
279 for e in &self.expr {
280 if let Expr::Over { function, .. } = e {
281 for e in &**function {
283 match e {
284 Expr::Agg(_) => {
285 agg_col = true;
286 },
287 Expr::Alias(_, _) => {},
288 _ => break,
289 }
290 }
291 }
292 }
293 agg_col
294 }
295
296 fn determine_map_strategy(
297 &self,
298 ac: &mut AggregationContext,
299 gb: &GroupBy,
300 ) -> PolarsResult<MapStrategy> {
301 match (self.mapping, ac.agg_state()) {
302 (WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
305 (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
311 (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
314 (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
317 if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
318 ac.groups().as_ref().check_lengths(gb.get_groups())?;
320 Ok(MapStrategy::Explode)
321 } else {
322 Ok(MapStrategy::Map)
323 }
324 },
325 (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
330 if self.is_simple_column_expr() {
333 Ok(MapStrategy::Nothing)
334 } else {
335 Ok(MapStrategy::Map)
336 }
337 },
338 (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
339 (_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
341 }
342 }
343}
344
345pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
347 write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
348}
349
350impl PhysicalExpr for WindowExpr {
351 fn evaluate_impl(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
357 if df.height() == 0 {
384 let field = self.phys_function.to_field(df.schema())?;
385 match self.mapping {
386 WindowMapping::Join => {
387 return Ok(Column::full_null(
388 field.name().clone(),
389 0,
390 &DataType::List(Box::new(field.dtype().clone())),
391 ));
392 },
393 _ => {
394 return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
395 },
396 }
397 }
398
399 let mut group_by_columns = self
400 .group_by
401 .iter()
402 .map(|e| e.evaluate(df, state))
403 .collect::<PolarsResult<Vec<_>>>()?;
404
405 let sorted_keys = group_by_columns.iter().all(|s| {
407 matches!(
408 s.is_sorted_flag(),
409 IsSorted::Ascending | IsSorted::Descending
410 )
411 });
412 let explicit_list_agg = self.is_explicit_list_agg();
413
414 let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
416 (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());
425
426 if self.has_different_group_sources {
429 sort_groups = true
430 }
431
432 let create_groups = || {
433 let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
434 let mut groups = gb.into_groups();
435
436 if let Some((order_by, options)) = &self.order_by {
437 let order_by = order_by.evaluate(df, state)?;
438 polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
439 groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
440 .into_sliceable()
441 }
442
443 let out: PolarsResult<GroupPositions> = Ok(groups);
444 out
445 };
446
447 let (mut groups, cache_key) = if state.cache_window() {
449 let mut cache_key = String::with_capacity(32 * group_by_columns.len());
450 write!(&mut cache_key, "{}", state.branch_idx).unwrap();
451 for s in &group_by_columns {
452 cache_key.push_str(s.name());
453 }
454 if let Some((e, options)) = &self.order_by {
455 let e = match e.as_expression() {
456 Some(e) => e,
457 None => {
458 polars_bail!(InvalidOperation: "cannot order by this expression in window function")
459 },
460 };
461 window_function_format_order_by(&mut cache_key, e, options)
462 }
463
464 let groups = match state.window_cache.get_groups(&cache_key) {
465 Some(groups) => groups,
466 None => create_groups()?,
467 };
468 (groups, cache_key)
469 } else {
470 (create_groups()?, "".to_string())
471 };
472
473 let apply_columns = self.apply_columns.clone();
475
476 if sort_groups || state.cache_window() {
481 groups.sort_by_first_idx();
482 state
483 .window_cache
484 .insert_groups(cache_key.clone(), groups.clone());
485 }
486
487 for col in group_by_columns.iter_mut() {
489 if col.len() != df.height() {
490 polars_ensure!(
491 col.len() == 1,
492 ShapeMismatch: "columns used as `partition_by` must have the same length as the DataFrame"
493 );
494 *col = col.new_from_index(0, df.height())
495 }
496 }
497
498 let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));
499
500 let mut ac = self.run_aggregation(df, state, &gb)?;
501
502 use MapStrategy::*;
503
504 match self.determine_map_strategy(&mut ac, &gb)? {
505 Nothing => {
506 let mut out = ac.flat_naive().into_owned();
507
508 if ac.is_literal() {
509 out = out.new_from_index(0, df.height())
510 }
511 Ok(out.into_column())
512 },
513 Explode => {
514 let out = if self.phys_function.is_scalar() {
515 ac.get_values().clone()
516 } else {
517 ac.aggregated().explode(ExplodeOptions {
518 empty_as_null: true,
519 keep_nulls: true,
520 })?
521 };
522 Ok(out.into_column())
523 },
524 Map => {
525 let out_column = ac.aggregated();
528 let flattened = out_column.explode(ExplodeOptions {
529 empty_as_null: true,
530 keep_nulls: true,
531 })?;
532 let ac = unsafe {
535 std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
536 };
537 self.map_by_arg_sort(
538 df,
539 out_column,
540 &flattened,
541 ac,
542 &group_by_columns,
543 gb,
544 cache_key,
545 state,
546 )
547 },
548 Join => {
549 let out_column = ac.aggregated();
550 let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
554 match (
555 &ac.update_groups,
556 set_by_groups(
557 &out_column,
558 &ac,
559 gb.get_groups(),
560 df.height(),
561 update_groups,
562 ),
563 ) {
564 (UpdateGroups::No, Some(out)) => Ok(out.into_column()),
567 (_, _) => {
568 let keys = gb.keys();
569
570 let get_join_tuples = || {
571 if group_by_columns.len() == 1 {
572 let mut left = group_by_columns[0].clone();
573 let mut right = keys[0].clone();
575
576 let (left, right) = if left.dtype().is_nested() {
577 (
578 ChunkedArray::<BinaryOffsetType>::with_chunk(
579 "".into(),
580 row_encode::_get_rows_encoded_unordered(&[
581 left.clone()
582 ])?
583 .into_array(),
584 )
585 .into_series(),
586 ChunkedArray::<BinaryOffsetType>::with_chunk(
587 "".into(),
588 row_encode::_get_rows_encoded_unordered(&[
589 right.clone()
590 ])?
591 .into_array(),
592 )
593 .into_series(),
594 )
595 } else {
596 (
597 left.into_materialized_series().clone(),
598 right.into_materialized_series().clone(),
599 )
600 };
601
602 PolarsResult::Ok(Arc::new(
603 left.hash_join_left(&right, JoinValidation::ManyToMany, true)
604 .unwrap()
605 .1,
606 ))
607 } else {
608 Ok(Arc::new(
609 private_left_join_multiple_keys(
610 &group_by_columns,
611 &keys,
612 true,
613 )?
614 .1,
615 ))
616 }
617 };
618
619 let join_opt_ids = if state.cache_window() {
621 if let Some(jt) = state.window_cache.get_join(&cache_key) {
622 jt
623 } else {
624 let jt = get_join_tuples()?;
625 state.window_cache.insert_join(cache_key, jt.clone());
626 jt
627 }
628 } else {
629 get_join_tuples()?
630 };
631
632 let out = materialize_column(&join_opt_ids, &out_column);
633 Ok(out.into_column())
634 },
635 }
636 },
637 }
638 }
639
640 fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
641 Ok(self.output_field.clone())
642 }
643
/// A window expression never reports itself as scalar.
fn is_scalar(&self) -> bool {
    false
}
647
/// Evaluates the window expression inside an aggregation context.
///
/// Every incoming group in `groups` is split into subgroups by the unique id of the
/// partition key, the window function is evaluated once on all subgroups, and a new set
/// of groups is built that points each incoming group at its results:
///
/// * `WindowMapping::Explode`: slice groups spanning the subgroups of each incoming
///   group (rescaled to exploded lengths when the function is not scalar).
/// * otherwise: index groups with one gather index per row of the incoming group.
///
/// # Errors
///
/// * `InvalidOperation` when there are no partition expressions, or when the partition /
///   order-by expressions are not elementwise.
/// * `ShapeMismatch` when a non-scalar `GroupsToRows` function returns a list whose
///   lengths differ from the subgroup lengths.
///
/// # Panics
///
/// Panics when an evaluated partition or order-by column does not have the expected
/// height (after broadcasting unit-length columns).
#[allow(clippy::ptr_arg)]
fn evaluate_on_groups_impl<'a>(
    &self,
    df: &DataFrame,
    groups: &'a GroupPositions,
    state: &ExecutionState,
) -> PolarsResult<AggregationContext<'a>> {
    if self.group_by.is_empty()
        || !self.all_group_by_are_elementwise
        || (self.order_by.is_some() && !self.order_by_is_elementwise)
    {
        polars_bail!(
            InvalidOperation:
            "window expression with non-elementwise `partition_by` or `order_by` not allowed in aggregation context"
        );
    }

    // Expected column height: the length of `state.element` when set, else the frame height.
    let length_preserving_height = if let Some((c, _)) = state.element.as_ref() {
        c.len()
    } else {
        df.height()
    };

    let function_is_scalar = self.phys_function.is_scalar();
    // Only non-scalar `GroupsToRows` functions yield one value per row that has to be
    // routed back to the row it belongs to.
    let needs_remap_to_rows =
        matches!(self.mapping, WindowMapping::GroupsToRows) && !function_is_scalar;

    let partition_by_columns = self
        .group_by
        .iter()
        .map(|e| {
            let mut e = e.evaluate(df, state)?;
            // Broadcast unit-length results.
            if e.len() == 1 {
                e = e.new_from_index(0, length_preserving_height);
            }
            assert_eq!(e.len(), length_preserving_height,);
            Ok(e)
        })
        .collect::<PolarsResult<Vec<_>>>()?;
    let order_by = match &self.order_by {
        None => None,
        Some((e, options)) => {
            let mut e = e.evaluate(df, state)?;
            if e.len() == 1 {
                e = e.new_from_index(0, length_preserving_height);
            }
            assert_eq!(e.len(), length_preserving_height);
            // When rows have to be remapped, precompute an ordinal rank of the order-by
            // column; it is compared in the per-group sort below.
            let arr: Option<PrimitiveArray<IdxSize>> = if needs_remap_to_rows {
                feature_gated!("rank", {
                    use polars_ops::series::SeriesRank;
                    let arr = e.as_materialized_series().rank(
                        RankOptions {
                            method: RankMethod::Ordinal,
                            descending: false,
                        },
                        None,
                    );
                    let arr = arr.idx()?;
                    let arr = arr.rechunk();
                    Some(arr.downcast_as_array().clone())
                })
            } else {
                None
            };

            Some((e.clone(), arr, *options))
        },
    };

    // One id per row identifying its partition key; multiple key columns are
    // row-encoded first.
    let (num_unique_ids, unique_ids) = if partition_by_columns.len() == 1 {
        partition_by_columns[0].unique_id()?
    } else {
        ChunkUnique::unique_id(&encode_rows_unordered(&partition_by_columns)?)?
    };

    // All subgroups of all incoming groups, stored as (first index, indices).
    let subgroups_approx_capacity = groups.len();
    let mut subgroups: Vec<(IdxSize, UnitVec<IdxSize>)> =
        Vec::with_capacity(subgroups_approx_capacity);

    // Running offset into the exploded output (used when `needs_remap_to_rows`).
    let mut gather_indices_offset = 0;
    // Output groups for the non-explode mappings: one entry per incoming group.
    let mut gather_indices: Vec<(IdxSize, UnitVec<IdxSize>)> =
        Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
            0
        } else {
            groups.len()
        });
    // Output groups for the explode mapping: [first subgroup, number of subgroups].
    let mut strategy_explode_groups: Vec<[IdxSize; 2]> =
        Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
            groups.len()
        } else {
            0
        });

    // Scratch buffers reused across incoming groups.
    let mut amort_arg_sort = Vec::new();
    let mut amort_offsets = Vec::new();

    // Scratch state indexed by unique id, reused across incoming groups.
    let mut amort_subgroups_order = Vec::with_capacity(num_unique_ids as usize);
    let mut amort_subgroups_sizes = Vec::with_capacity(num_unique_ids as usize);
    let mut amort_subgroups_indices = (0..num_unique_ids)
        .map(|_| (0, UnitVec::new()))
        .collect::<Vec<(IdxSize, UnitVec<IdxSize>)>>();

    // Processes one incoming group. `$iter` yields the group's row indices, `$get` maps
    // a position within the group to its row index.
    macro_rules! map_window_groups {
        ($iter:expr, $get:expr) => {
            let mut subgroup_gather_indices =
                UnitVec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
                    0
                } else {
                    $iter.len()
                });

            amort_subgroups_order.clear();
            amort_subgroups_sizes.clear();
            amort_subgroups_sizes.resize(num_unique_ids as usize, 0);

            // Pass 1: count rows per id and record the ids in order of first appearance.
            for i in $iter.clone() {
                // NOTE(review): the unchecked accesses assume row indices are within
                // `unique_ids` and ids are below `num_unique_ids` — confirm.
                let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                let size = unsafe { amort_subgroups_sizes.get_unchecked_mut(id as usize) };
                if *size == 0 {
                    unsafe { amort_subgroups_order.push_unchecked(id) };
                }
                *size += 1;
            }

            if matches!(self.mapping, WindowMapping::Explode) {
                strategy_explode_groups.push([
                    subgroups.len() as IdxSize,
                    amort_subgroups_order.len() as IdxSize,
                ]);
            }

            // Assign every subgroup its gather index: an offset into the exploded output
            // (advancing by the subgroup size) when remapping to rows, otherwise the
            // subgroup's position (advancing by one).
            let mut offset = if needs_remap_to_rows {
                gather_indices_offset
            } else {
                subgroups.len() as IdxSize
            };
            for &id in &amort_subgroups_order {
                let size = *unsafe { amort_subgroups_sizes.get_unchecked(id as usize) };
                let (next_gather_idx, indices) =
                    unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                indices.reserve(size as usize);
                *next_gather_idx = offset;
                offset += if needs_remap_to_rows { size } else { 1 };
            }

            // Pass 2: distribute the row indices over the subgroups.
            if matches!(self.mapping, WindowMapping::Explode) {
                // No gather indices are needed for the explode mapping.
                for i in $iter {
                    let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                    let (_, indices) =
                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                    unsafe { indices.push_unchecked(i) };
                }
            } else {
                if needs_remap_to_rows && let Some((_, arr, options)) = &order_by {
                    // With an order-by, a row's value sits at the row's sorted position
                    // within its subgroup, so arg-sort the group by the precomputed rank.
                    let arr = arr.as_ref().unwrap();
                    amort_arg_sort.clear();
                    amort_arg_sort.extend(0..$iter.len() as IdxSize);
                    match arr.validity() {
                        None => {
                            let arr = arr.values().as_slice();
                            amort_arg_sort.sort_by(|a, b| {
                                let in_group_idx_a = $get(*a as usize) as usize;
                                let in_group_idx_b = $get(*b as usize) as usize;

                                let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
                                let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };

                                let mut cmp = order_a.cmp(&order_b);
                                if options.descending {
                                    cmp = cmp.reverse();
                                }
                                cmp
                            });
                        },
                        Some(validity) => {
                            let arr = arr.values().as_slice();
                            amort_arg_sort.sort_by(|a, b| {
                                let in_group_idx_a = $get(*a as usize) as usize;
                                let in_group_idx_b = $get(*b as usize) as usize;

                                let is_valid_a =
                                    unsafe { validity.get_bit_unchecked(in_group_idx_a) };
                                let is_valid_b =
                                    unsafe { validity.get_bit_unchecked(in_group_idx_b) };

                                // At least one null: order by validity only (nulls
                                // first, flipped by `nulls_last`).
                                if !(is_valid_a & is_valid_b) {
                                    let mut cmp = is_valid_a.cmp(&is_valid_b);
                                    if options.nulls_last {
                                        cmp = cmp.reverse();
                                    }
                                    return cmp;
                                }

                                let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
                                let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };

                                let mut cmp = order_a.cmp(&order_b);
                                if options.descending {
                                    cmp = cmp.reverse();
                                }
                                cmp
                            });
                        },
                    }

                    // Reuse the size counters as per-subgroup running positions.
                    amort_offsets.clear();
                    amort_offsets.resize($iter.len(), 0);
                    for &id in &amort_subgroups_order {
                        amort_subgroups_sizes[id as usize] = 0;
                    }

                    // In sorted order, hand every row its position within its subgroup.
                    for &idx in &amort_arg_sort {
                        let in_group_idx = $get(idx as usize);
                        let id = *unsafe { unique_ids.get_unchecked(in_group_idx as usize) };
                        amort_offsets[idx as usize] = amort_subgroups_sizes[id as usize];
                        amort_subgroups_sizes[id as usize] += 1;
                    }

                    for (i, offset) in $iter.zip(&amort_offsets) {
                        let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                        let (next_gather_idx, indices) =
                            unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                        unsafe {
                            subgroup_gather_indices.push_unchecked(*next_gather_idx + *offset)
                        };
                        unsafe { indices.push_unchecked(i) };
                    }
                } else {
                    for i in $iter {
                        let id = *unsafe { unique_ids.get_unchecked(i as usize) };
                        let (next_gather_idx, indices) =
                            unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                        unsafe { subgroup_gather_indices.push_unchecked(*next_gather_idx) };
                        // Advances only when remapping to rows; otherwise every row of
                        // the subgroup shares the subgroup's single gather index.
                        *next_gather_idx += IdxSize::from(needs_remap_to_rows);
                        unsafe { indices.push_unchecked(i) };
                    }
                }
            }

            // Move the collected indices out of the scratch state, in first-appearance order.
            subgroups.extend(amort_subgroups_order.iter().map(|&id| {
                let (_, indices) =
                    unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
                let indices = std::mem::take(indices);
                (*unsafe { indices.get_unchecked(0) }, indices)
            }));

            if !matches!(self.mapping, WindowMapping::Explode) {
                gather_indices_offset += subgroup_gather_indices.len() as IdxSize;
                gather_indices.push((
                    subgroup_gather_indices.first().copied().unwrap_or(0),
                    subgroup_gather_indices,
                ));
            }
        };
    }
    match groups.as_ref() {
        GroupsType::Idx(idxs) => {
            for g in idxs.all() {
                map_window_groups!(g.iter().copied(), (|i: usize| g[i]));
            }
        },
        GroupsType::Slice {
            groups,
            overlapping: _,
            monotonic: _,
        } => {
            for [s, l] in groups.iter() {
                let s = *s;
                let l = *l;
                // The range `s..s + l` yields exactly `l` items.
                let iter = unsafe { TrustMyLength::new(s..s + l, l as usize) };
                map_window_groups!(iter, (|i: usize| s + i as IdxSize));
            }
        },
    }

    let mut subgroups = GroupsType::Idx(subgroups.into());
    // Apply the order-by to the rows of every subgroup before evaluating the function.
    if let Some((order_by, _, options)) = order_by {
        subgroups =
            update_groups_sort_by(&subgroups, order_by.as_materialized_series(), &options)?;
    }
    let subgroups = subgroups.into_sliceable();
    let mut data = self
        .phys_function
        .evaluate_on_groups(df, &subgroups, state)?
        .finalize();

    let final_groups = if matches!(self.mapping, WindowMapping::Explode) {
        if !function_is_scalar {
            let (data_s, offsets) = data.list()?.explode_and_offsets(ExplodeOptions {
                empty_as_null: false,
                keep_nulls: false,
            })?;
            data = data_s.into_column();

            // Rewrite [first subgroup, number of subgroups] into
            // [start, length] measured in exploded values.
            let mut exploded_offset = 0;
            for [start, length] in strategy_explode_groups.iter_mut() {
                let exploded_start = exploded_offset;
                let exploded_length = offsets
                    .lengths()
                    .skip(*start as usize)
                    .take(*length as usize)
                    .sum::<usize>() as IdxSize;
                exploded_offset += exploded_length;
                *start = exploded_start;
                *length = exploded_length;
            }
        }
        GroupsType::new_slice(strategy_explode_groups, false, true)
    } else {
        if needs_remap_to_rows {
            let data_l = data.list()?;
            assert_eq!(data_l.len(), subgroups.len());
            // Each subgroup must have produced exactly one value per row.
            let lengths = data_l.lst_lengths();
            let length_mismatch = match subgroups.as_ref() {
                GroupsType::Idx(idx) => idx
                    .all()
                    .iter()
                    .zip(lengths.iter())
                    .any(|(i, l)| i.len() as IdxSize != l.unwrap()),
                GroupsType::Slice {
                    groups,
                    overlapping: _,
                    monotonic: _,
                } => groups
                    .iter()
                    .zip(lengths.iter())
                    .any(|([_, i], l)| *i != l.unwrap()),
            };

            polars_ensure!(
                !length_mismatch,
                expr = self.expr, ShapeMismatch:
                "the length of the window expression did not match that of the group"
            );

            data = data_l
                .explode(ExplodeOptions {
                    empty_as_null: false,
                    keep_nulls: true,
                })?
                .into_column();
        }
        GroupsType::Idx(gather_indices.into())
    }
    .into_sliceable();

    Ok(AggregationContext {
        state: AggState::NotAggregated(data),
        groups: Cow::Owned(final_groups),
        update_groups: UpdateGroups::No,
        original_len: false,
    })
}
1023
1024 fn as_expression(&self) -> Option<&Expr> {
1025 Some(&self.expr)
1026 }
1027}
1028
1029fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
1030 {
1031 use arrow::Either;
1032 use polars_ops::chunked_array::TakeChunked;
1033
1034 match join_opt_ids {
1035 Either::Left(ids) => unsafe {
1036 IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
1037 },
1038 Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
1039 }
1040 }
1041}
1042
1043fn set_by_groups(
1045 s: &Column,
1046 ac: &AggregationContext,
1047 gb_groups: &GroupPositions,
1048 len: usize,
1049 update_groups: bool,
1050) -> Option<Column> {
1051 let groups = match ac.agg_state() {
1052 AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) => gb_groups,
1053 AggState::NotAggregated(_) | AggState::AggregatedList(_) => {
1054 if update_groups || !ac.original_len {
1055 return None;
1056 } else {
1057 &ac.groups
1058 }
1059 },
1060 };
1061
1062 if s.dtype().to_physical().is_primitive_numeric() {
1063 let dtype = s.dtype();
1064 let s = s.to_physical_repr();
1065
1066 macro_rules! dispatch {
1067 ($ca:expr) => {{ Some(set_numeric($ca, groups, len)) }};
1068 }
1069
1070 downcast_as_macro_arg_physical!(&s, dispatch)
1071 .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
1072 .map(Column::from)
1073 } else {
1074 None
1075 }
1076}
1077
1078fn set_numeric<T: PolarsNumericType>(
1079 ca: &ChunkedArray<T>,
1080 groups: &GroupsType,
1081 len: usize,
1082) -> Series {
1083 let mut values = Vec::with_capacity(len);
1084 let ptr: *mut T::Native = values.as_mut_ptr();
1085 let sync_ptr_values = unsafe { SyncPtr::new(ptr) };
1088
1089 if ca.null_count() == 0 {
1090 let ca = ca.rechunk();
1091 match groups {
1092 GroupsType::Idx(groups) => {
1093 let agg_vals = ca.cont_slice().expect("rechunked");
1094 RAYON.install(|| {
1095 agg_vals
1096 .par_iter()
1097 .zip(groups.all().par_iter())
1098 .for_each(|(v, g)| {
1099 let ptr = sync_ptr_values.get();
1100 for idx in g.as_slice() {
1101 debug_assert!((*idx as usize) < len);
1102 unsafe { *ptr.add(*idx as usize) = *v }
1103 }
1104 })
1105 })
1106 },
1107 GroupsType::Slice { groups, .. } => {
1108 let agg_vals = ca.cont_slice().expect("rechunked");
1109 RAYON.install(|| {
1110 agg_vals
1111 .par_iter()
1112 .zip(groups.par_iter())
1113 .for_each(|(v, [start, g_len])| {
1114 let ptr = sync_ptr_values.get();
1115 let start = *start as usize;
1116 let end = start + *g_len as usize;
1117 for idx in start..end {
1118 debug_assert!(idx < len);
1119 unsafe { *ptr.add(idx) = *v }
1120 }
1121 })
1122 });
1123 },
1124 }
1125
1126 unsafe { values.set_len(len) }
1128 ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
1129 } else {
1130 let mut validity: Vec<bool> = vec![false; len];
1133 let validity_ptr = validity.as_mut_ptr();
1134 let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };
1135
1136 let n_threads = RAYON.current_num_threads();
1137 let offsets = _split_offsets(ca.len(), n_threads);
1138
1139 match groups {
1140 GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
1141 let offset = *offset;
1142 let offset_len = *offset_len;
1143 let ca = ca.slice(offset as i64, offset_len);
1144 let groups = &groups.all()[offset..offset + offset_len];
1145 let values_ptr = sync_ptr_values.get();
1146 let validity_ptr = sync_ptr_validity.get();
1147
1148 ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
1149 for idx in g.as_slice() {
1150 let idx = *idx as usize;
1151 debug_assert!(idx < len);
1152 unsafe {
1153 match opt_v {
1154 Some(v) => {
1155 *values_ptr.add(idx) = v;
1156 *validity_ptr.add(idx) = true;
1157 },
1158 None => {
1159 *values_ptr.add(idx) = T::Native::default();
1160 *validity_ptr.add(idx) = false;
1161 },
1162 };
1163 }
1164 }
1165 })
1166 }),
1167 GroupsType::Slice { groups, .. } => {
1168 offsets.par_iter().for_each(|(offset, offset_len)| {
1169 let offset = *offset;
1170 let offset_len = *offset_len;
1171 let ca = ca.slice(offset as i64, offset_len);
1172 let groups = &groups[offset..offset + offset_len];
1173 let values_ptr = sync_ptr_values.get();
1174 let validity_ptr = sync_ptr_validity.get();
1175
1176 for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
1177 let start = *start as usize;
1178 let end = start + *g_len as usize;
1179 for idx in start..end {
1180 debug_assert!(idx < len);
1181 unsafe {
1182 match opt_v {
1183 Some(v) => {
1184 *values_ptr.add(idx) = v;
1185 *validity_ptr.add(idx) = true;
1186 },
1187 None => {
1188 *values_ptr.add(idx) = T::Native::default();
1189 *validity_ptr.add(idx) = false;
1190 },
1191 };
1192 }
1193 }
1194 }
1195 })
1196 },
1197 }
1198 unsafe { values.set_len(len) }
1200 let validity = Bitmap::from(validity);
1201 let arr = PrimitiveArray::new(
1202 T::get_static_dtype()
1203 .to_physical()
1204 .to_arrow(CompatLevel::newest()),
1205 values.into(),
1206 Some(validity),
1207 );
1208 Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
1209 }
1210}