
polars_expr/expressions/window.rs

1use std::cmp::Ordering;
2use std::fmt::Write;
3
4use arrow::array::PrimitiveArray;
5use arrow::bitmap::Bitmap;
6use arrow::trusted_len::TrustMyLength;
7use polars_core::error::feature_gated;
8use polars_core::prelude::row_encode::encode_rows_unordered;
9use polars_core::prelude::sort::perfect_sort;
10use polars_core::prelude::*;
11use polars_core::series::IsSorted;
12use polars_core::utils::_split_offsets;
13use polars_core::{POOL, downcast_as_macro_arg_physical};
14use polars_ops::frame::SeriesJoin;
15use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
16use polars_ops::prelude::*;
17use polars_plan::prelude::*;
18use polars_utils::UnitVec;
19use polars_utils::sync::SyncPtr;
20use polars_utils::vec::PushUnchecked;
21use rayon::prelude::*;
22
23use super::*;
24
25pub struct WindowExpr {
26    /// the root column that the Function will be applied on.
27    /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index
28    pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
29    pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
30    pub(crate) apply_columns: Vec<PlSmallStr>,
31    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
32    pub(crate) mapping: WindowMapping,
33    pub(crate) expr: Expr,
34    pub(crate) has_different_group_sources: bool,
35    pub(crate) output_field: Field,
36
37    pub(crate) all_group_by_are_elementwise: bool,
38    pub(crate) order_by_is_elementwise: bool,
39}
40
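// Strategy used in `evaluate` to turn the per-group aggregation results
// into the output column.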
41#[cfg_attr(debug_assertions, derive(Debug))]
42enum MapStrategy {
43    // Join by key; this is the most expensive strategy.
44    // Used for reduced (scalar) aggregations.
45    Join,
46    // explode now
47    Explode,
48    // Use an arg_sort to map the values back
49    Map,
50    Nothing,
51}
52
53impl WindowExpr {
54    fn map_list_agg_by_arg_sort(
55        &self,
56        out_column: Column,
57        flattened: &Column,
58        mut ac: AggregationContext,
59        gb: GroupBy,
60    ) -> PolarsResult<IdxCa> {
61        // idx (new-idx, original-idx)
62        let mut idx_mapping = Vec::with_capacity(out_column.len());
63
64        // start with an empty buffer; if the groups changed we reuse the `original_idx`
65        // allocation for it below, which saves an allocation
66        let mut take_idx = vec![];
67
68        // groups are not changed, we can map by doing a standard arg_sort.
69        if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
70            let mut iter = 0..flattened.len() as IdxSize;
71            match ac.groups().as_ref().as_ref() {
72                GroupsType::Idx(groups) => {
73                    for g in groups.all() {
74                        idx_mapping.extend(g.iter().copied().zip(&mut iter));
75                    }
76                },
77                GroupsType::Slice { groups, .. } => {
78                    for &[first, len] in groups {
79                        idx_mapping.extend((first..first + len).zip(&mut iter));
80                    }
81                },
82            }
83        }
84        // groups are changed, we use the new group indexes as arguments of the arg_sort
85        // and sort by the old indexes
86        else {
87            let mut original_idx = Vec::with_capacity(out_column.len());
88            match gb.get_groups().as_ref() {
89                GroupsType::Idx(groups) => {
90                    for g in groups.all() {
91                        original_idx.extend_from_slice(g)
92                    }
93                },
94                GroupsType::Slice { groups, .. } => {
95                    for &[first, len] in groups {
96                        original_idx.extend(first..first + len)
97                    }
98                },
99            };
100
101            let mut original_idx_iter = original_idx.iter().copied();
102
103            match ac.groups().as_ref().as_ref() {
104                GroupsType::Idx(groups) => {
105                    for g in groups.all() {
106                        idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
107                    }
108                },
109                GroupsType::Slice { groups, .. } => {
110                    for &[first, len] in groups {
111                        idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
112                    }
113                },
114            }
115            original_idx.clear();
116            take_idx = original_idx;
117        }
118        // SAFETY:
119        // we only have unique indices ranging from 0..len
120        unsafe { perfect_sort(&idx_mapping, &mut take_idx) };
121        Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
122    }
123
124    #[allow(clippy::too_many_arguments)]
125    fn map_by_arg_sort(
126        &self,
127        df: &DataFrame,
128        out_column: Column,
129        flattened: &Column,
130        mut ac: AggregationContext,
131        group_by_columns: &[Column],
132        gb: GroupBy,
133        cache_key: String,
134        state: &ExecutionState,
135    ) -> PolarsResult<Column> {
136        // we use an arg_sort to map the values back
137
138        // This is a bit more complicated because the final group tuples may differ from the original
139        // so we use the original indices as idx values to arg_sort the original column
140        //
141        // The example below shows the naive version without group tuple mapping
142
143        // columns
144        // a b a a
145        //
146        // agg list
147        // [0, 2, 3]
148        // [1]
149        //
150        // flatten
151        //
152        // [0, 2, 3, 1]
153        //
154        // arg_sort
155        //
156        // [0, 3, 1, 2]
157        //
158        // take by arg_sorted indexes and voila groups mapped
159        // [0, 1, 2, 3]
160
161        if flattened.len() != df.height() {
162            let ca = out_column.list().unwrap();
163            let non_matching_group =
164                ca.into_iter()
165                    .zip(ac.groups().iter())
166                    .find(|(output, group)| {
167                        if let Some(output) = output {
168                            output.as_ref().len() != group.len()
169                        } else {
170                            false
171                        }
172                    });
173
174            if let Some((output, group)) = non_matching_group {
175                let first = group.first();
176                let group = group_by_columns
177                    .iter()
178                    .map(|s| format!("{}", s.get(first as usize).unwrap()))
179                    .collect::<Vec<_>>();
180                polars_bail!(
181                    expr = self.expr, ShapeMismatch:
182                    "the length of the window expression did not match that of the group\
183                    \n> group: {}\n> group length: {}\n> output: '{:?}'",
184                    comma_delimited(String::new(), &group), group.len(), output.unwrap()
185                );
186            } else {
187                polars_bail!(
188                    expr = self.expr, ShapeMismatch:
189                    "the length of the window expression did not match that of the group"
190                );
191            };
192        }
193
194        let idx = if state.cache_window() {
195            if let Some(idx) = state.window_cache.get_map(&cache_key) {
196                idx
197            } else {
198                let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
199                state.window_cache.insert_map(cache_key, idx.clone());
200                idx
201            }
202        } else {
203            Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
204        };
205
206        // SAFETY:
207        // groups should always be in bounds.
208        unsafe { Ok(flattened.take_unchecked(&idx)) }
209    }
210
211    fn run_aggregation<'a>(
212        &self,
213        df: &DataFrame,
214        state: &ExecutionState,
215        gb: &'a GroupBy,
216    ) -> PolarsResult<AggregationContext<'a>> {
217        let ac = self
218            .phys_function
219            .evaluate_on_groups(df, gb.get_groups(), state)?;
220        Ok(ac)
221    }
222
223    fn is_explicit_list_agg(&self) -> bool {
224        // col("foo").implode()
225        // col("foo").implode().alias()
226        // ..
227        // col("foo").implode().alias().alias()
228        //
229        // but not:
230        // col("foo").implode().sum().alias()
231        // ..
232        // col("foo").min()
233        let mut explicit_list = false;
234        for e in &self.expr {
235            if let Expr::Over { function, .. } = e {
236                // or list().alias
237                let mut finishes_list = false;
238                for e in &**function {
239                    match e {
240                        Expr::Agg(AggExpr::Implode(_)) => {
241                            finishes_list = true;
242                        },
243                        Expr::Alias(_, _) => {},
244                        _ => break,
245                    }
246                }
247                explicit_list = finishes_list;
248            }
249        }
250
251        explicit_list
252    }
253
254    fn is_simple_column_expr(&self) -> bool {
255        // col()
256        // or col().alias()
257        let mut simple_col = false;
258        for e in &self.expr {
259            if let Expr::Over { function, .. } = e {
260                // or list().alias
261                for e in &**function {
262                    match e {
263                        Expr::Column(_) => {
264                            simple_col = true;
265                        },
266                        Expr::Alias(_, _) => {},
267                        _ => break,
268                    }
269                }
270            }
271        }
272        simple_col
273    }
274
275    fn is_aggregation(&self) -> bool {
276        // col()
277        // or col().agg()
278        let mut agg_col = false;
279        for e in &self.expr {
280            if let Expr::Over { function, .. } = e {
281                // or list().alias
282                for e in &**function {
283                    match e {
284                        Expr::Agg(_) => {
285                            agg_col = true;
286                        },
287                        Expr::Alias(_, _) => {},
288                        _ => break,
289                    }
290                }
291            }
292        }
293        agg_col
294    }
295
296    fn determine_map_strategy(
297        &self,
298        ac: &mut AggregationContext,
299        gb: &GroupBy,
300    ) -> PolarsResult<MapStrategy> {
301        match (self.mapping, ac.agg_state()) {
302            // Explode
303            // `(col("x").sum() * col("y")).list().over("groups").flatten()`
304            (WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
305            // // explicit list
306            // // `(col("x").sum() * col("y")).list().over("groups")`
307            // (false, false, _) => Ok(MapStrategy::Join),
308            // aggregations
309            //`sum("foo").over("groups")`
310            (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
311            // no explicit aggregations, map over the groups
312            //`(col("x").sum() * col("y")).over("groups")`
313            (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
314            // no explicit aggregations, map over the groups
315            //`(col("x").sum() * col("y")).over("groups")`
316            (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
317                if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
318                    // Result can be directly exploded if the input was sorted.
319                    ac.groups().as_ref().check_lengths(gb.get_groups())?;
320                    Ok(MapStrategy::Explode)
321                } else {
322                    Ok(MapStrategy::Map)
323                }
324            },
325            // no aggregations, just return column
326            // or an aggregation that has been flattened
327            // we have to check which one
328            //`col("foo").over("groups")`
329            (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
330                // col()
331                // or col().alias()
332                if self.is_simple_column_expr() {
333                    Ok(MapStrategy::Nothing)
334                } else {
335                    Ok(MapStrategy::Map)
336                }
337            },
338            (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
339            // literals, do nothing and let broadcast
340            (_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
341        }
342    }
343}
344
345// Utility to create partitions and cache keys
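// Appends a `_PL_{order_by_expr:?}{descending}_{nulls_last}` suffix so that window
// expressions that differ only in their `order_by` get distinct cache keys.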
346pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
347    write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
348}
349
350impl PhysicalExpr for WindowExpr {
351    // Note: this was first implemented with expression evaluation, but that performed really badly.
352    // Therefore we chose the group_by -> apply -> self join approach.
353
354    // This first cached the group_by and the join tuples, but rayon under a mutex leads to deadlocks:
355    // https://github.com/rayon-rs/rayon/issues/592
356    fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
357        // This method does the following:
358        // 1. determine group_by tuples based on the group_column
359        // 2. apply an aggregation function
360        // 3. join the results back to the original dataframe
361        //    this stores all group values at the original row positions of the df
362        //
363        //      we have several strategies for this
364        //      - 3.1 JOIN
365        //          Use a join for aggregations like
366        //              `sum("foo").over("groups")`
367        //          and explicit `list` aggregations
368        //              `(col("x").sum() * col("y")).list().over("groups")`
369        //
370        //      - 3.2 EXPLODE
371        //          Explicit list aggregations that are followed by `over().flatten()`
372        //          # the fastest method to do things over groups when the groups are sorted.
373        //          # note that it will require an explicit `list()` call from now on.
374        //              `(col("x").sum() * col("y")).list().over("groups").flatten()`
375        //
376        //      - 3.3. MAP to original locations
377        //          This will be done for list aggregations that are not explicitly aggregated as list
378        //              `(col("x").sum() * col("y")).over("groups")
379        //          This can be used to reverse, sort, shuffle etc. the values in a group
380
381        // 4. select the final column and return
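        //
        // For example (illustrative), with groups = [a, b, a] and x = [1, 2, 3]:
        //   sum("x").over("groups")           -> JOIN -> [4, 2, 4]
        //   col("x").reverse().over("groups") -> MAP  -> [3, 2, 1]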
382
383        if df.height() == 0 {
384            let field = self.phys_function.to_field(df.schema())?;
385            match self.mapping {
386                WindowMapping::Join => {
387                    return Ok(Column::full_null(
388                        field.name().clone(),
389                        0,
390                        &DataType::List(Box::new(field.dtype().clone())),
391                    ));
392                },
393                _ => {
394                    return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
395                },
396            }
397        }
398
399        let mut group_by_columns = self
400            .group_by
401            .iter()
402            .map(|e| e.evaluate(df, state))
403            .collect::<PolarsResult<Vec<_>>>()?;
404
405        // if the keys are sorted
406        let sorted_keys = group_by_columns.iter().all(|s| {
407            matches!(
408                s.is_sorted_flag(),
409                IsSorted::Ascending | IsSorted::Descending
410            )
411        });
412        let explicit_list_agg = self.is_explicit_list_agg();
413
414        // if we flatten this column we need to make sure the groups are sorted.
415        let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
416            // if not
417            //      `col().over()`
418            // and not
419            //      `col().list().over`
420            // and not
421            //      `col().sum()`
422            // and keys are sorted
423            //  we may optimize with explode call
424            (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());
425
426        // overwrite sort_groups for some expressions
427        // TODO: fully understand the rationale here.
428        if self.has_different_group_sources {
429            sort_groups = true
430        }
431
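        // Compute the group tuples: group by the evaluated key columns and, if an
        // `order_by` expression is given, re-sort the indices within each group by it.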
432        let create_groups = || {
433            let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
434            let mut groups = gb.into_groups();
435
436            if let Some((order_by, options)) = &self.order_by {
437                let order_by = order_by.evaluate(df, state)?;
438                polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
439                groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
440                    .into_sliceable()
441            }
442
443            let out: PolarsResult<GroupPositions> = Ok(groups);
444            out
445        };
446
447        // Try to get cached group tuples
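        // The cache key is the branch index plus the partition-by column names,
        // plus an order-by descriptor when present.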
448        let (mut groups, cache_key) = if state.cache_window() {
449            let mut cache_key = String::with_capacity(32 * group_by_columns.len());
450            write!(&mut cache_key, "{}", state.branch_idx).unwrap();
451            for s in &group_by_columns {
452                cache_key.push_str(s.name());
453            }
454            if let Some((e, options)) = &self.order_by {
455                let e = match e.as_expression() {
456                    Some(e) => e,
457                    None => {
458                        polars_bail!(InvalidOperation: "cannot order by this expression in window function")
459                    },
460                };
461                window_function_format_order_by(&mut cache_key, e, options)
462            }
463
464            let groups = match state.window_cache.get_groups(&cache_key) {
465                Some(groups) => groups,
466                None => create_groups()?,
467            };
468            (groups, cache_key)
469        } else {
470            (create_groups()?, "".to_string())
471        };
472
473        // 2. create GroupBy object and apply aggregation
474        let apply_columns = self.apply_columns.clone();
475
476        // Some window expressions need sorted groups. We also sort when the
477        // window cache is enabled, so that the cached groups and the join
478        // keys derived from them stay consistent among all window
479        // expressions.
480        if sort_groups || state.cache_window() {
481            groups.sort();
482            state
483                .window_cache
484                .insert_groups(cache_key.clone(), groups.clone());
485        }
486
487        // broadcast if required
488        for col in group_by_columns.iter_mut() {
489            if col.len() != df.height() {
490                polars_ensure!(
491                    col.len() == 1,
492                    ShapeMismatch: "columns used as `partition_by` must have the same length as the DataFrame"
493                );
494                *col = col.new_from_index(0, df.height())
495            }
496        }
497
498        let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));
499
500        let mut ac = self.run_aggregation(df, state, &gb)?;
501
502        use MapStrategy::*;
503
504        match self.determine_map_strategy(&mut ac, &gb)? {
505            Nothing => {
506                let mut out = ac.flat_naive().into_owned();
507
508                if ac.is_literal() {
509                    out = out.new_from_index(0, df.height())
510                }
511                Ok(out.into_column())
512            },
513            Explode => {
514                let out = if self.phys_function.is_scalar() {
515                    ac.get_values().clone()
516                } else {
517                    ac.aggregated().explode(ExplodeOptions {
518                        empty_as_null: true,
519                        keep_nulls: true,
520                    })?
521                };
522                Ok(out.into_column())
523            },
524            Map => {
525                // TODO!
526                // investigate if sorted arrays can be returned directly
527                let out_column = ac.aggregated();
528                let flattened = out_column.explode(ExplodeOptions {
529                    empty_as_null: true,
530                    keep_nulls: true,
531                })?;
532                // we extend the lifetime as we must convince the compiler that ac lives
533                // long enough. We drop `GroupBy` when we are done with `ac`.
534                let ac = unsafe {
535                    std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
536                };
537                self.map_by_arg_sort(
538                    df,
539                    out_column,
540                    &flattened,
541                    ac,
542                    &group_by_columns,
543                    gb,
544                    cache_key,
545                    state,
546                )
547            },
548            Join => {
549                let out_column = ac.aggregated();
550                // we try to flatten/extend the array by repeating the aggregated value n times
551                // where n is the number of members in that group. That way we can try to reuse
552                // the same map by arg_sort logic as done for listed aggregations
553                let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
554                match (
555                    &ac.update_groups,
556                    set_by_groups(&out_column, &ac, df.height(), update_groups),
557                ) {
558                    // For aggregations that reduce to a scalar (sum, mean, first, ...) and are numeric,
559                    // we use the group locations to write the values directly to the right place
560                    (UpdateGroups::No, Some(out)) => Ok(out.into_column()),
561                    (_, _) => {
562                        let keys = gb.keys();
563
564                        let get_join_tuples = || {
565                            if group_by_columns.len() == 1 {
566                                let mut left = group_by_columns[0].clone();
567                                // group key from right column
568                                let mut right = keys[0].clone();
569
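                                // Nested key dtypes are row-encoded into a binary
                                // representation first so they can be used as join keys.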
570                                let (left, right) = if left.dtype().is_nested() {
571                                    (
572                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
573                                            "".into(),
574                                            row_encode::_get_rows_encoded_unordered(&[
575                                                left.clone()
576                                            ])?
577                                            .into_array(),
578                                        )
579                                        .into_series(),
580                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
581                                            "".into(),
582                                            row_encode::_get_rows_encoded_unordered(&[
583                                                right.clone()
584                                            ])?
585                                            .into_array(),
586                                        )
587                                        .into_series(),
588                                    )
589                                } else {
590                                    (
591                                        left.into_materialized_series().clone(),
592                                        right.into_materialized_series().clone(),
593                                    )
594                                };
595
596                                PolarsResult::Ok(Arc::new(
597                                    left.hash_join_left(&right, JoinValidation::ManyToMany, true)
598                                        .unwrap()
599                                        .1,
600                                ))
601                            } else {
602                                let df_right =
603                                    unsafe { DataFrame::new_unchecked_infer_height(keys) };
604                                let df_left = unsafe {
605                                    DataFrame::new_unchecked_infer_height(group_by_columns)
606                                };
607                                Ok(Arc::new(
608                                    private_left_join_multiple_keys(&df_left, &df_right, true)?.1,
609                                ))
610                            }
611                        };
612
613                        // try to get cached join_tuples
614                        let join_opt_ids = if state.cache_window() {
615                            if let Some(jt) = state.window_cache.get_join(&cache_key) {
616                                jt
617                            } else {
618                                let jt = get_join_tuples()?;
619                                state.window_cache.insert_join(cache_key, jt.clone());
620                                jt
621                            }
622                        } else {
623                            get_join_tuples()?
624                        };
625
626                        let out = materialize_column(&join_opt_ids, &out_column);
627                        Ok(out.into_column())
628                    },
629                }
630            },
631        }
632    }
633
634    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
635        Ok(self.output_field.clone())
636    }
637
638    fn is_scalar(&self) -> bool {
639        false
640    }
641
642    #[allow(clippy::ptr_arg)]
643    fn evaluate_on_groups<'a>(
644        &self,
645        df: &DataFrame,
646        groups: &'a GroupPositions,
647        state: &ExecutionState,
648    ) -> PolarsResult<AggregationContext<'a>> {
649        if self.group_by.is_empty()
650            || !self.all_group_by_are_elementwise
651            || (self.order_by.is_some() && !self.order_by_is_elementwise)
652        {
653            polars_bail!(
654                InvalidOperation:
655                "window expression with non-elementwise `partition_by` or `order_by` not allowed in aggregation context"
656            );
657        }
658
659        let length_preserving_height = if let Some((c, _)) = state.element.as_ref() {
660            c.len()
661        } else {
662            df.height()
663        };
664
665        let function_is_scalar = self.phys_function.is_scalar();
666        let needs_remap_to_rows =
667            matches!(self.mapping, WindowMapping::GroupsToRows) && !function_is_scalar;
668
669        let partition_by_columns = self
670            .group_by
671            .iter()
672            .map(|e| {
673                let mut e = e.evaluate(df, state)?;
674                if e.len() == 1 {
675                    e = e.new_from_index(0, length_preserving_height);
676                }
677                // Sanity check: Length Preserving.
678                assert_eq!(e.len(), length_preserving_height,);
679                Ok(e)
680            })
681            .collect::<PolarsResult<Vec<_>>>()?;
682        let order_by = match &self.order_by {
683            None => None,
684            Some((e, options)) => {
685                let mut e = e.evaluate(df, state)?;
686                if e.len() == 1 {
687                    e = e.new_from_index(0, length_preserving_height);
688                }
689                // Sanity check: Length Preserving.
690                assert_eq!(e.len(), length_preserving_height);
691                let arr: Option<PrimitiveArray<IdxSize>> = if needs_remap_to_rows {
692                    feature_gated!("rank", {
693                        // Performance: precompute the rank here, so we can avoid dispatching per group
694                        // later.
695                        use polars_ops::series::SeriesRank;
696                        let arr = e.as_materialized_series().rank(
697                            RankOptions {
698                                method: RankMethod::Ordinal,
699                                descending: false,
700                            },
701                            None,
702                        );
703                        let arr = arr.idx()?;
704                        let arr = arr.rechunk();
705                        Some(arr.downcast_as_array().clone())
706                    })
707                } else {
708                    None
709                };
710
711                Some((e.clone(), arr, *options))
712            },
713        };
714
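        // Assign every row a partition id in `0..num_unique_ids`: a single key column
        // provides it directly, multiple key columns are row-encoded first.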
715        let (num_unique_ids, unique_ids) = if partition_by_columns.len() == 1 {
716            partition_by_columns[0].unique_id()?
717        } else {
718            ChunkUnique::unique_id(&encode_rows_unordered(&partition_by_columns)?)?
719        };
720
721        // All the groups within the existing groups.
722        let subgroups_approx_capacity = groups.len();
723        let mut subgroups: Vec<(IdxSize, UnitVec<IdxSize>)> =
724            Vec::with_capacity(subgroups_approx_capacity);
725
726        // Indices for the output groups. Not used with `WindowMapping::Explode`.
727        let mut gather_indices_offset = 0;
728        let mut gather_indices: Vec<(IdxSize, UnitVec<IdxSize>)> =
729            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
730                0
731            } else {
732                groups.len()
733            });
734        // Slices for the output groups. Only used with `WindowMapping::Explode`.
735        let mut strategy_explode_groups: Vec<[IdxSize; 2]> =
736            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
737                groups.len()
738            } else {
739                0
740            });
741
742        // Amortized vectors to reorder based on `order_by`.
743        let mut amort_arg_sort = Vec::new();
744        let mut amort_offsets = Vec::new();
745
746        // Amortized vectors to gather per group data.
747        let mut amort_subgroups_order = Vec::with_capacity(num_unique_ids as usize);
748        let mut amort_subgroups_sizes = Vec::with_capacity(num_unique_ids as usize);
749        let mut amort_subgroups_indices = (0..num_unique_ids)
750            .map(|_| (0, UnitVec::new()))
751            .collect::<Vec<(IdxSize, UnitVec<IdxSize>)>>();
752
753        macro_rules! map_window_groups {
754            ($iter:expr, $get:expr) => {
755                let mut subgroup_gather_indices =
756                    UnitVec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
757                        0
758                    } else {
759                        $iter.len()
760                    });
761
762                amort_subgroups_order.clear();
763                amort_subgroups_sizes.clear();
764                amort_subgroups_sizes.resize(num_unique_ids as usize, 0);
765
766                // Determine sizes per subgroup.
767                for i in $iter.clone() {
768                    let id = *unsafe { unique_ids.get_unchecked(i as usize) };
769                    let size = unsafe { amort_subgroups_sizes.get_unchecked_mut(id as usize) };
770                    if *size == 0 {
771                        unsafe { amort_subgroups_order.push_unchecked(id) };
772                    }
773                    *size += 1;
774                }
775
776                if matches!(self.mapping, WindowMapping::Explode) {
777                    strategy_explode_groups.push([
778                        subgroups.len() as IdxSize,
779                        amort_subgroups_order.len() as IdxSize,
780                    ]);
781                }
782
783                // Set starting gather indices and reserve capacity per subgroup.
784                let mut offset = if needs_remap_to_rows {
785                    gather_indices_offset
786                } else {
787                    subgroups.len() as IdxSize
788                };
789                for &id in &amort_subgroups_order {
790                    let size = *unsafe { amort_subgroups_sizes.get_unchecked(id as usize) };
791                    let (next_gather_idx, indices) =
792                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
793                    indices.reserve(size as usize);
794                    *next_gather_idx = offset;
795                    offset += if needs_remap_to_rows { size } else { 1 };
796                }
797
798                // Collect gather indices.
799                if matches!(self.mapping, WindowMapping::Explode) {
800                    for i in $iter {
801                        let id = *unsafe { unique_ids.get_unchecked(i as usize) };
802                        let (_, indices) =
803                            unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
804                        unsafe { indices.push_unchecked(i) };
805                    }
806                } else {
807                    // If we are remapping exploded rows back to rows and are reordering, we need
808                    // to ensure we reorder the gather indices as well. Reordering the `subgroup`
809                    // indices is done later.
810                    //
811                    // Having precalculated both the `unique_ids` and the order-by ranks in
812                    // efficient kernels, we can now arg_sort per group relatively efficiently. This
813                    // is still horrendously slow, but at least not as bad as it would be if you
814                    // did this naively.
815                    if needs_remap_to_rows && let Some((_, arr, options)) = &order_by {
816                        let arr = arr.as_ref().unwrap();
817                        amort_arg_sort.clear();
818                        amort_arg_sort.extend(0..$iter.len() as IdxSize);
819                        match arr.validity() {
820                            None => {
821                                let arr = arr.values().as_slice();
822                                amort_arg_sort.sort_by(|a, b| {
823                                    let in_group_idx_a = $get(*a as usize) as usize;
824                                    let in_group_idx_b = $get(*b as usize) as usize;
825
826                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
827                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };
828
829                                    let mut cmp = order_a.cmp(&order_b);
830                                    // Performance: This can generally be handled branchlessly.
831                                    if options.descending {
832                                        cmp = cmp.reverse();
833                                    }
834                                    cmp
835                                });
836                            },
837                            Some(validity) => {
838                                let arr = arr.values().as_slice();
839                                amort_arg_sort.sort_by(|a, b| {
840                                    let in_group_idx_a = $get(*a as usize) as usize;
841                                    let in_group_idx_b = $get(*b as usize) as usize;
842
843                                    let is_valid_a =
844                                        unsafe { validity.get_bit_unchecked(in_group_idx_a) };
845                                    let is_valid_b =
846                                        unsafe { validity.get_bit_unchecked(in_group_idx_b) };
847                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
848                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };
849
850                                    if !is_valid_a & !is_valid_b {
851                                        return Ordering::Equal;
852                                    }
853
854                                    let mut cmp = order_a.cmp(&order_b);
855                                    if !is_valid_a {
856                                        cmp = Ordering::Less;
857                                    }
858                                    if !is_valid_b {
859                                        cmp = Ordering::Greater;
860                                    }
861                                    if options.descending
862                                        | ((!is_valid_a | !is_valid_b) & options.nulls_last)
863                                    {
864                                        cmp = cmp.reverse();
865                                    }
866                                    cmp
867                                });
868                            },
869                        }
870
871                        amort_offsets.clear();
872                        amort_offsets.resize($iter.len(), 0);
873                        for &id in &amort_subgroups_order {
874                            amort_subgroups_sizes[id as usize] = 0;
875                        }
876
877                        for &idx in &amort_arg_sort {
878                            let in_group_idx = $get(idx as usize);
879                            let id = *unsafe { unique_ids.get_unchecked(in_group_idx as usize) };
880                            amort_offsets[idx as usize] = amort_subgroups_sizes[id as usize];
881                            amort_subgroups_sizes[id as usize] += 1;
882                        }
883
884                        for (i, offset) in $iter.zip(&amort_offsets) {
885                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
886                            let (next_gather_idx, indices) =
887                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
888                            unsafe {
889                                subgroup_gather_indices.push_unchecked(*next_gather_idx + *offset)
890                            };
891                            unsafe { indices.push_unchecked(i) };
892                        }
893                    } else {
894                        for i in $iter {
895                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
896                            let (next_gather_idx, indices) =
897                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
898                            unsafe { subgroup_gather_indices.push_unchecked(*next_gather_idx) };
899                            *next_gather_idx += IdxSize::from(needs_remap_to_rows);
900                            unsafe { indices.push_unchecked(i) };
901                        }
902                    }
903                }
904
905                // Push the per-partition subgroups collected above into `subgroups`.
906                subgroups.extend(amort_subgroups_order.iter().map(|&id| {
907                    let (_, indices) =
908                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
909                    let indices = std::mem::take(indices);
910                    (*unsafe { indices.get_unchecked(0) }, indices)
911                }));
912
913                if !matches!(self.mapping, WindowMapping::Explode) {
914                    gather_indices_offset += subgroup_gather_indices.len() as IdxSize;
915                    gather_indices.push((
916                        subgroup_gather_indices.first().copied().unwrap_or(0),
917                        subgroup_gather_indices,
918                    ));
919                }
920            };
921        }
922        match groups.as_ref() {
923            GroupsType::Idx(idxs) => {
924                for g in idxs.all() {
925                    map_window_groups!(g.iter().copied(), (|i: usize| g[i]));
926                }
927            },
928            GroupsType::Slice {
929                groups,
930                overlapping: _,
931                monotonic: _,
932            } => {
933                for [s, l] in groups.iter() {
934                    let s = *s;
935                    let l = *l;
936                    let iter = unsafe { TrustMyLength::new(s..s + l, l as usize) };
937                    map_window_groups!(iter, (|i: usize| s + i as IdxSize));
938                }
939            },
940        }
941
942        let mut subgroups = GroupsType::Idx(subgroups.into());
943        if let Some((order_by, _, options)) = order_by {
944            subgroups =
945                update_groups_sort_by(&subgroups, order_by.as_materialized_series(), &options)?;
946        }
947        let subgroups = subgroups.into_sliceable();
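        // Evaluate the window function once over all per-partition subgroups nested
        // inside the outer groups.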
948        let mut data = self
949            .phys_function
950            .evaluate_on_groups(df, &subgroups, state)?
951            .finalize();
952
953        let final_groups = if matches!(self.mapping, WindowMapping::Explode) {
954            if !function_is_scalar {
955                let (data_s, offsets) = data.list()?.explode_and_offsets(ExplodeOptions {
956                    empty_as_null: false,
957                    keep_nulls: false,
958                })?;
959                data = data_s.into_column();
960
961                let mut exploded_offset = 0;
962                for [start, length] in strategy_explode_groups.iter_mut() {
963                    let exploded_start = exploded_offset;
964                    let exploded_length = offsets
965                        .lengths()
966                        .skip(*start as usize)
967                        .take(*length as usize)
968                        .sum::<usize>() as IdxSize;
969                    exploded_offset += exploded_length;
970                    *start = exploded_start;
971                    *length = exploded_length;
972                }
973            }
974            GroupsType::new_slice(strategy_explode_groups, false, true)
975        } else {
976            if needs_remap_to_rows {
977                let data_l = data.list()?;
978                assert_eq!(data_l.len(), subgroups.len());
979                let lengths = data_l.lst_lengths();
980                let length_mismatch = match subgroups.as_ref() {
981                    GroupsType::Idx(idx) => idx
982                        .all()
983                        .iter()
984                        .zip(&lengths)
985                        .any(|(i, l)| i.len() as IdxSize != l.unwrap()),
986                    GroupsType::Slice {
987                        groups,
988                        overlapping: _,
989                        monotonic: _,
990                    } => groups
991                        .iter()
992                        .zip(&lengths)
993                        .any(|([_, i], l)| *i != l.unwrap()),
994                };
995
996                polars_ensure!(
997                    !length_mismatch,
998                    expr = self.expr, ShapeMismatch:
999                    "the length of the window expression did not match that of the group"
1000                );
1001
1002                data = data_l
1003                    .explode(ExplodeOptions {
1004                        empty_as_null: false,
1005                        keep_nulls: true,
1006                    })?
1007                    .into_column();
1008            }
1009            GroupsType::Idx(gather_indices.into())
1010        }
1011        .into_sliceable();
1012
1013        Ok(AggregationContext {
1014            state: AggState::NotAggregated(data),
1015            groups: Cow::Owned(final_groups),
1016            update_groups: UpdateGroups::No,
1017            original_len: false,
1018        })
1019    }
1020
1021    fn as_expression(&self) -> Option<&Expr> {
1022        Some(&self.expr)
1023    }
1024}
1025
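/// Gather the aggregated values according to the left-join ids; rows without a
/// match become null.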
1026fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
1027    {
1028        use arrow::Either;
1029        use polars_ops::chunked_array::TakeChunked;
1030
1031        match join_opt_ids {
1032            Either::Left(ids) => unsafe {
1033                IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
1034            },
1035            Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
1036        }
1037    }
1038}
1039
1040/// Simple reducing aggregations can be set by the groups.
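/// The scalar result of each group is written directly into every row position of
/// that group, which avoids the join done in the `Join` strategy.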
1041fn set_by_groups(
1042    s: &Column,
1043    ac: &AggregationContext,
1044    len: usize,
1045    update_groups: bool,
1046) -> Option<Column> {
1047    if update_groups || !ac.original_len {
1048        return None;
1049    }
1050    if s.dtype().to_physical().is_primitive_numeric() {
1051        let dtype = s.dtype();
1052        let s = s.to_physical_repr();
1053
1054        macro_rules! dispatch {
1055            ($ca:expr) => {{ Some(set_numeric($ca, &ac.groups, len)) }};
1056        }
1057        downcast_as_macro_arg_physical!(&s, dispatch)
1058            .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
1059            .map(Column::from)
1060    } else {
1061        None
1062    }
1063}
1064
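/// Scatter each group's aggregated scalar into all row positions of that group,
/// filling the output buffer in parallel.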
1065fn set_numeric<T: PolarsNumericType>(
1066    ca: &ChunkedArray<T>,
1067    groups: &GroupsType,
1068    len: usize,
1069) -> Series {
1070    let mut values = Vec::with_capacity(len);
1071    let ptr: *mut T::Native = values.as_mut_ptr();
1072    // SAFETY:
1073    // we will write from different threads but we will never alias.
1074    let sync_ptr_values = unsafe { SyncPtr::new(ptr) };
1075
1076    if ca.null_count() == 0 {
1077        let ca = ca.rechunk();
1078        match groups {
1079            GroupsType::Idx(groups) => {
1080                let agg_vals = ca.cont_slice().expect("rechunked");
1081                POOL.install(|| {
1082                    agg_vals
1083                        .par_iter()
1084                        .zip(groups.all().par_iter())
1085                        .for_each(|(v, g)| {
1086                            let ptr = sync_ptr_values.get();
1087                            for idx in g.as_slice() {
1088                                debug_assert!((*idx as usize) < len);
1089                                unsafe { *ptr.add(*idx as usize) = *v }
1090                            }
1091                        })
1092                })
1093            },
1094            GroupsType::Slice { groups, .. } => {
1095                let agg_vals = ca.cont_slice().expect("rechunked");
1096                POOL.install(|| {
1097                    agg_vals
1098                        .par_iter()
1099                        .zip(groups.par_iter())
1100                        .for_each(|(v, [start, g_len])| {
1101                            let ptr = sync_ptr_values.get();
1102                            let start = *start as usize;
1103                            let end = start + *g_len as usize;
1104                            for idx in start..end {
1105                                debug_assert!(idx < len);
1106                                unsafe { *ptr.add(idx) = *v }
1107                            }
1108                        })
1109                });
1110            },
1111        }
1112
1113        // SAFETY: we have written all slots
1114        unsafe { values.set_len(len) }
1115        ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
1116    } else {
1117        // We don't use a mutable bitmap, as writing individual bits would race:
1118        // a single byte holds 8 bits, so writes from different threads could alias it.
1119        let mut validity: Vec<bool> = vec![false; len];
1120        let validity_ptr = validity.as_mut_ptr();
1121        let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };
1122
1123        let n_threads = POOL.current_num_threads();
1124        let offsets = _split_offsets(ca.len(), n_threads);
1125
1126        match groups {
1127            GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
1128                let offset = *offset;
1129                let offset_len = *offset_len;
1130                let ca = ca.slice(offset as i64, offset_len);
1131                let groups = &groups.all()[offset..offset + offset_len];
1132                let values_ptr = sync_ptr_values.get();
1133                let validity_ptr = sync_ptr_validity.get();
1134
1135                ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
1136                    for idx in g.as_slice() {
1137                        let idx = *idx as usize;
1138                        debug_assert!(idx < len);
1139                        unsafe {
1140                            match opt_v {
1141                                Some(v) => {
1142                                    *values_ptr.add(idx) = v;
1143                                    *validity_ptr.add(idx) = true;
1144                                },
1145                                None => {
1146                                    *values_ptr.add(idx) = T::Native::default();
1147                                    *validity_ptr.add(idx) = false;
1148                                },
1149                            };
1150                        }
1151                    }
1152                })
1153            }),
1154            GroupsType::Slice { groups, .. } => {
1155                offsets.par_iter().for_each(|(offset, offset_len)| {
1156                    let offset = *offset;
1157                    let offset_len = *offset_len;
1158                    let ca = ca.slice(offset as i64, offset_len);
1159                    let groups = &groups[offset..offset + offset_len];
1160                    let values_ptr = sync_ptr_values.get();
1161                    let validity_ptr = sync_ptr_validity.get();
1162
1163                    for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
1164                        let start = *start as usize;
1165                        let end = start + *g_len as usize;
1166                        for idx in start..end {
1167                            debug_assert!(idx < len);
1168                            unsafe {
1169                                match opt_v {
1170                                    Some(v) => {
1171                                        *values_ptr.add(idx) = v;
1172                                        *validity_ptr.add(idx) = true;
1173                                    },
1174                                    None => {
1175                                        *values_ptr.add(idx) = T::Native::default();
1176                                        *validity_ptr.add(idx) = false;
1177                                    },
1178                                };
1179                            }
1180                        }
1181                    }
1182                })
1183            },
1184        }
1185        // SAFETY: we have written all slots
1186        unsafe { values.set_len(len) }
1187        let validity = Bitmap::from(validity);
1188        let arr = PrimitiveArray::new(
1189            T::get_static_dtype()
1190                .to_physical()
1191                .to_arrow(CompatLevel::newest()),
1192            values.into(),
1193            Some(validity),
1194        );
1195        Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
1196    }
1197}