Skip to main content

polars_expr/expressions/
window.rs

1use std::fmt::Write;
2
3use arrow::array::PrimitiveArray;
4use arrow::bitmap::Bitmap;
5use arrow::trusted_len::TrustMyLength;
6use polars_core::downcast_as_macro_arg_physical;
7use polars_core::error::feature_gated;
8use polars_core::prelude::row_encode::encode_rows_unordered;
9use polars_core::prelude::sort::perfect_sort;
10use polars_core::prelude::*;
11use polars_core::runtime::RAYON;
12use polars_core::series::IsSorted;
13use polars_core::utils::_split_offsets;
14use polars_ops::frame::SeriesJoin;
15use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
16use polars_ops::prelude::*;
17use polars_plan::prelude::*;
18use polars_utils::UnitVec;
19use polars_utils::sync::SyncPtr;
20use polars_utils::vec::PushUnchecked;
21use rayon::prelude::*;
22
23use super::*;
24
25pub struct WindowExpr {
26    /// the root column that the Function will be applied on.
27    /// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index
28    pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
29    pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
30    pub(crate) apply_columns: Vec<PlSmallStr>,
31    pub(crate) phys_function: Arc<dyn PhysicalExpr>,
32    pub(crate) mapping: WindowMapping,
33    pub(crate) expr: Expr,
34    pub(crate) has_different_group_sources: bool,
35    pub(crate) output_field: Field,
36
37    pub(crate) all_group_by_are_elementwise: bool,
38    pub(crate) order_by_is_elementwise: bool,
39}
40
/// Strategy used to bring the per-group result of a window function back to row level.
#[cfg_attr(debug_assertions, derive(Debug))]
enum MapStrategy {
    /// Join the aggregated values back by key. This is the most expensive strategy and is used
    /// for reduced aggregations.
    Join,
    /// Explode the aggregated result right away.
    Explode,
    /// Use an arg_sort to map the values back to their original locations.
    Map,
    /// Leave the values untouched.
    Nothing,
}
52
53impl WindowExpr {
54    fn map_list_agg_by_arg_sort(
55        &self,
56        out_column: Column,
57        flattened: &Column,
58        mut ac: AggregationContext,
59        gb: GroupBy,
60    ) -> PolarsResult<IdxCa> {
61        // idx (new-idx, original-idx)
62        let mut idx_mapping = Vec::with_capacity(out_column.len());
63
64        // we already set this buffer so we can reuse the `original_idx` buffer
65        // that saves an allocation
66        let mut take_idx = vec![];
67
68        // groups are not changed, we can map by doing a standard arg_sort.
69        if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
70            let mut iter = 0..flattened.len() as IdxSize;
71            match ac.groups().as_ref().as_ref() {
72                GroupsType::Idx(groups) => {
73                    for g in groups.all() {
74                        idx_mapping.extend(g.iter().copied().zip(&mut iter));
75                    }
76                },
77                GroupsType::Slice { groups, .. } => {
78                    for &[first, len] in groups {
79                        idx_mapping.extend((first..first + len).zip(&mut iter));
80                    }
81                },
82            }
83        }
84        // groups are changed, we use the new group indexes as arguments of the arg_sort
85        // and sort by the old indexes
86        else {
87            let mut original_idx = Vec::with_capacity(out_column.len());
88            match gb.get_groups().as_ref() {
89                GroupsType::Idx(groups) => {
90                    for g in groups.all() {
91                        original_idx.extend_from_slice(g)
92                    }
93                },
94                GroupsType::Slice { groups, .. } => {
95                    for &[first, len] in groups {
96                        original_idx.extend(first..first + len)
97                    }
98                },
99            };
100
101            let mut original_idx_iter = original_idx.iter().copied();
102
103            match ac.groups().as_ref().as_ref() {
104                GroupsType::Idx(groups) => {
105                    for g in groups.all() {
106                        idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
107                    }
108                },
109                GroupsType::Slice { groups, .. } => {
110                    for &[first, len] in groups {
111                        idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
112                    }
113                },
114            }
115            original_idx.clear();
116            take_idx = original_idx;
117        }
118        // SAFETY:
119        // we only have unique indices ranging from 0..len
120        unsafe { perfect_sort(&idx_mapping, &mut take_idx) };
121        Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
122    }
123
124    #[allow(clippy::too_many_arguments)]
125    fn map_by_arg_sort(
126        &self,
127        df: &DataFrame,
128        out_column: Column,
129        flattened: &Column,
130        mut ac: AggregationContext,
131        group_by_columns: &[Column],
132        gb: GroupBy,
133        cache_key: String,
134        state: &ExecutionState,
135    ) -> PolarsResult<Column> {
136        // we use an arg_sort to map the values back
137
138        // This is a bit more complicated because the final group tuples may differ from the original
139        // so we use the original indices as idx values to arg_sort the original column
140        //
141        // The example below shows the naive version without group tuple mapping
142
143        // columns
144        // a b a a
145        //
146        // agg list
147        // [0, 2, 3]
148        // [1]
149        //
150        // flatten
151        //
152        // [0, 2, 3, 1]
153        //
154        // arg_sort
155        //
156        // [0, 3, 1, 2]
157        //
158        // take by arg_sorted indexes and voila groups mapped
159        // [0, 1, 2, 3]
160
161        if flattened.len() != df.height() {
162            let ca = out_column.list().unwrap();
163            let non_matching_group =
164                ca.series_iter()
165                    .zip(ac.groups().iter())
166                    .find(|(output, group)| {
167                        if let Some(output) = output {
168                            output.as_ref().len() != group.len()
169                        } else {
170                            false
171                        }
172                    });
173
174            if let Some((output, group)) = non_matching_group {
175                let first = group.first();
176                let group = group_by_columns
177                    .iter()
178                    .map(|s| format!("{}", s.get(first as usize).unwrap()))
179                    .collect::<Vec<_>>();
180                polars_bail!(
181                    expr = self.expr, ShapeMismatch:
182                    "the length of the window expression did not match that of the group\
183                    \n> group: {}\n> group length: {}\n> output: '{:?}'",
184                    comma_delimited(String::new(), &group), group.len(), output.unwrap()
185                );
186            } else {
187                polars_bail!(
188                    expr = self.expr, ShapeMismatch:
189                    "the length of the window expression did not match that of the group"
190                );
191            };
192        }
193
194        let idx = if state.cache_window() {
195            if let Some(idx) = state.window_cache.get_map(&cache_key) {
196                idx
197            } else {
198                let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
199                state.window_cache.insert_map(cache_key, idx.clone());
200                idx
201            }
202        } else {
203            Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
204        };
205
206        // SAFETY:
207        // groups should always be in bounds.
208        unsafe { Ok(flattened.take_unchecked(&idx)) }
209    }
210
211    fn run_aggregation<'a>(
212        &self,
213        df: &DataFrame,
214        state: &ExecutionState,
215        gb: &'a GroupBy,
216    ) -> PolarsResult<AggregationContext<'a>> {
217        let ac = self
218            .phys_function
219            .evaluate_on_groups(df, gb.get_groups(), state)?;
220        Ok(ac)
221    }
222
223    fn is_explicit_list_agg(&self) -> bool {
224        // col("foo").implode()
225        // col("foo").implode().alias()
226        // ..
227        // col("foo").implode().alias().alias()
228        //
229        // but not:
230        // col("foo").implode().sum().alias()
231        // ..
232        // col("foo").min()
233        let mut explicit_list = false;
234        for e in &self.expr {
235            if let Expr::Over { function, .. } = e {
236                // or list().alias
237                let mut finishes_list = false;
238                for e in &**function {
239                    match e {
240                        Expr::Agg(AggExpr::Implode { .. }) => {
241                            finishes_list = true;
242                        },
243                        Expr::Alias(_, _) => {},
244                        _ => break,
245                    }
246                }
247                explicit_list = finishes_list;
248            }
249        }
250
251        explicit_list
252    }
253
254    fn is_simple_column_expr(&self) -> bool {
255        // col()
256        // or col().alias()
257        let mut simple_col = false;
258        for e in &self.expr {
259            if let Expr::Over { function, .. } = e {
260                // or list().alias
261                for e in &**function {
262                    match e {
263                        Expr::Column(_) => {
264                            simple_col = true;
265                        },
266                        Expr::Alias(_, _) => {},
267                        _ => break,
268                    }
269                }
270            }
271        }
272        simple_col
273    }
274
275    fn is_aggregation(&self) -> bool {
276        // col()
277        // or col().agg()
278        let mut agg_col = false;
279        for e in &self.expr {
280            if let Expr::Over { function, .. } = e {
281                // or list().alias
282                for e in &**function {
283                    match e {
284                        Expr::Agg(_) => {
285                            agg_col = true;
286                        },
287                        Expr::Alias(_, _) => {},
288                        _ => break,
289                    }
290                }
291            }
292        }
293        agg_col
294    }
295
296    fn determine_map_strategy(
297        &self,
298        ac: &mut AggregationContext,
299        gb: &GroupBy,
300    ) -> PolarsResult<MapStrategy> {
301        match (self.mapping, ac.agg_state()) {
302            // Explode
303            // `(col("x").sum() * col("y")).list().over("groups").flatten()`
304            (WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
305            // // explicit list
306            // // `(col("x").sum() * col("y")).list().over("groups")`
307            // (false, false, _) => Ok(MapStrategy::Join),
308            // aggregations
309            //`sum("foo").over("groups")`
310            (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
311            // no explicit aggregations, map over the groups
312            //`(col("x").sum() * col("y")).over("groups")`
313            (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
314            // no explicit aggregations, map over the groups
315            //`(col("x").sum() * col("y")).over("groups")`
316            (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
317                if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
318                    // Result can be directly exploded if the input was sorted.
319                    ac.groups().as_ref().check_lengths(gb.get_groups())?;
320                    Ok(MapStrategy::Explode)
321                } else {
322                    Ok(MapStrategy::Map)
323                }
324            },
325            // no aggregations, just return column
326            // or an aggregation that has been flattened
327            // we have to check which one
328            //`col("foo").over("groups")`
329            (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
330                // col()
331                // or col().alias()
332                if self.is_simple_column_expr() {
333                    Ok(MapStrategy::Nothing)
334                } else {
335                    Ok(MapStrategy::Map)
336                }
337            },
338            (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
339            // literals, do nothing and let broadcast
340            (_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
341        }
342    }
343}
344
345// Utility to create partitions and cache keys
346pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
347    write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
348}
349
350impl PhysicalExpr for WindowExpr {
    // Note: this was first implemented with expression evaluation, but that performed very poorly.
    // Therefore we chose the group_by -> apply -> self-join approach.

    // This first cached the group_by and the join tuples, but running rayon under a mutex leads
    // to deadlocks: https://github.com/rayon-rs/rayon/issues/592
356    fn evaluate_impl(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
357        // This method does the following:
358        // 1. determine group_by tuples based on the group_column
359        // 2. apply an aggregation function
360        // 3. join the results back to the original dataframe
361        //    this stores all group values on the original df size
362        //
363        //      we have several strategies for this
364        //      - 3.1 JOIN
365        //          Use a join for aggregations like
366        //              `sum("foo").over("groups")`
367        //          and explicit `list` aggregations
368        //              `(col("x").sum() * col("y")).list().over("groups")`
369        //
370        //      - 3.2 EXPLODE
371        //          Explicit list aggregations that are followed by `over().flatten()`
372        //          # the fastest method to do things over groups when the groups are sorted.
373        //          # note that it will require an explicit `list()` call from now on.
374        //              `(col("x").sum() * col("y")).list().over("groups").flatten()`
375        //
376        //      - 3.3. MAP to original locations
377        //          This will be done for list aggregations that are not explicitly aggregated as list
378        //              `(col("x").sum() * col("y")).over("groups")
379        //          This can be used to reverse, sort, shuffle etc. the values in a group
380
381        // 4. select the final column and return
382
383        if df.height() == 0 {
384            let field = self.phys_function.to_field(df.schema())?;
385            match self.mapping {
386                WindowMapping::Join => {
387                    return Ok(Column::full_null(
388                        field.name().clone(),
389                        0,
390                        &DataType::List(Box::new(field.dtype().clone())),
391                    ));
392                },
393                _ => {
394                    return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
395                },
396            }
397        }
398
399        let mut group_by_columns = self
400            .group_by
401            .iter()
402            .map(|e| e.evaluate(df, state))
403            .collect::<PolarsResult<Vec<_>>>()?;
404
405        // if the keys are sorted
406        let sorted_keys = group_by_columns.iter().all(|s| {
407            matches!(
408                s.is_sorted_flag(),
409                IsSorted::Ascending | IsSorted::Descending
410            )
411        });
412        let explicit_list_agg = self.is_explicit_list_agg();
413
414        // if we flatten this column we need to make sure the groups are sorted.
415        let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
416            // if not
417            //      `col().over()`
418            // and not
419            //      `col().list().over`
420            // and not
421            //      `col().sum()`
422            // and keys are sorted
423            //  we may optimize with explode call
424            (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());
425
426        // overwrite sort_groups for some expressions
427        // TODO: fully understand the rationale is here.
428        if self.has_different_group_sources {
429            sort_groups = true
430        }
431
432        let create_groups = || {
433            let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
434            let mut groups = gb.into_groups();
435
436            if let Some((order_by, options)) = &self.order_by {
437                let order_by = order_by.evaluate(df, state)?;
438                polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
439                groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
440                    .into_sliceable()
441            }
442
443            let out: PolarsResult<GroupPositions> = Ok(groups);
444            out
445        };
446
447        // Try to get cached grouptuples
448        let (mut groups, cache_key) = if state.cache_window() {
449            let mut cache_key = String::with_capacity(32 * group_by_columns.len());
450            write!(&mut cache_key, "{}", state.branch_idx).unwrap();
451            for s in &group_by_columns {
452                cache_key.push_str(s.name());
453            }
454            if let Some((e, options)) = &self.order_by {
455                let e = match e.as_expression() {
456                    Some(e) => e,
457                    None => {
458                        polars_bail!(InvalidOperation: "cannot order by this expression in window function")
459                    },
460                };
461                window_function_format_order_by(&mut cache_key, e, options)
462            }
463
464            let groups = match state.window_cache.get_groups(&cache_key) {
465                Some(groups) => groups,
466                None => create_groups()?,
467            };
468            (groups, cache_key)
469        } else {
470            (create_groups()?, "".to_string())
471        };
472
473        // 2. create GroupBy object and apply aggregation
474        let apply_columns = self.apply_columns.clone();
475
476        // some window expressions need sorted groups
477        // to make sure that the caches align we sort
478        // the groups, so that the cached groups and join keys
479        // are consistent among all windows
480        if sort_groups || state.cache_window() {
481            groups.sort_by_first_idx();
482            state
483                .window_cache
484                .insert_groups(cache_key.clone(), groups.clone());
485        }
486
487        // broadcast if required
488        for col in group_by_columns.iter_mut() {
489            if col.len() != df.height() {
490                polars_ensure!(
491                    col.len() == 1,
492                    ShapeMismatch: "columns used as `partition_by` must have the same length as the DataFrame"
493                );
494                *col = col.new_from_index(0, df.height())
495            }
496        }
497
498        let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));
499
500        let mut ac = self.run_aggregation(df, state, &gb)?;
501
502        use MapStrategy::*;
503
504        match self.determine_map_strategy(&mut ac, &gb)? {
505            Nothing => {
506                let mut out = ac.flat_naive().into_owned();
507
508                if ac.is_literal() {
509                    out = out.new_from_index(0, df.height())
510                }
511                Ok(out.into_column())
512            },
513            Explode => {
514                let out = if self.phys_function.is_scalar() {
515                    ac.get_values().clone()
516                } else {
517                    ac.aggregated().explode(ExplodeOptions {
518                        empty_as_null: true,
519                        keep_nulls: true,
520                    })?
521                };
522                Ok(out.into_column())
523            },
524            Map => {
525                // TODO!
526                // investigate if sorted arrays can be return directly
527                let out_column = ac.aggregated();
528                let flattened = out_column.explode(ExplodeOptions {
529                    empty_as_null: true,
530                    keep_nulls: true,
531                })?;
532                // we extend the lifetime as we must convince the compiler that ac lives
533                // long enough. We drop `GrouBy` when we are done with `ac`.
534                let ac = unsafe {
535                    std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
536                };
537                self.map_by_arg_sort(
538                    df,
539                    out_column,
540                    &flattened,
541                    ac,
542                    &group_by_columns,
543                    gb,
544                    cache_key,
545                    state,
546                )
547            },
548            Join => {
549                let out_column = ac.aggregated();
550                // we try to flatten/extend the array by repeating the aggregated value n times
551                // where n is the number of members in that group. That way we can try to reuse
552                // the same map by arg_sort logic as done for listed aggregations
553                let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
554                match (
555                    &ac.update_groups,
556                    set_by_groups(
557                        &out_column,
558                        &ac,
559                        gb.get_groups(),
560                        df.height(),
561                        update_groups,
562                    ),
563                ) {
564                    // for aggregations that reduce like sum, mean, first and are numeric
565                    // we take the group locations to directly map them to the right place
566                    (UpdateGroups::No, Some(out)) => Ok(out.into_column()),
567                    (_, _) => {
568                        let keys = gb.keys();
569
570                        let get_join_tuples = || {
571                            if group_by_columns.len() == 1 {
572                                let mut left = group_by_columns[0].clone();
573                                // group key from right column
574                                let mut right = keys[0].clone();
575
576                                let (left, right) = if left.dtype().is_nested() {
577                                    (
578                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
579                                            "".into(),
580                                            row_encode::_get_rows_encoded_unordered(&[
581                                                left.clone()
582                                            ])?
583                                            .into_array(),
584                                        )
585                                        .into_series(),
586                                        ChunkedArray::<BinaryOffsetType>::with_chunk(
587                                            "".into(),
588                                            row_encode::_get_rows_encoded_unordered(&[
589                                                right.clone()
590                                            ])?
591                                            .into_array(),
592                                        )
593                                        .into_series(),
594                                    )
595                                } else {
596                                    (
597                                        left.into_materialized_series().clone(),
598                                        right.into_materialized_series().clone(),
599                                    )
600                                };
601
602                                PolarsResult::Ok(Arc::new(
603                                    left.hash_join_left(&right, JoinValidation::ManyToMany, true)
604                                        .unwrap()
605                                        .1,
606                                ))
607                            } else {
608                                Ok(Arc::new(
609                                    private_left_join_multiple_keys(
610                                        &group_by_columns,
611                                        &keys,
612                                        true,
613                                    )?
614                                    .1,
615                                ))
616                            }
617                        };
618
619                        // try to get cached join_tuples
620                        let join_opt_ids = if state.cache_window() {
621                            if let Some(jt) = state.window_cache.get_join(&cache_key) {
622                                jt
623                            } else {
624                                let jt = get_join_tuples()?;
625                                state.window_cache.insert_join(cache_key, jt.clone());
626                                jt
627                            }
628                        } else {
629                            get_join_tuples()?
630                        };
631
632                        let out = materialize_column(&join_opt_ids, &out_column);
633                        Ok(out.into_column())
634                    },
635                }
636            },
637        }
638    }
639
640    fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
641        Ok(self.output_field.clone())
642    }
643
644    fn is_scalar(&self) -> bool {
645        false
646    }
647
648    #[allow(clippy::ptr_arg)]
649    fn evaluate_on_groups_impl<'a>(
650        &self,
651        df: &DataFrame,
652        groups: &'a GroupPositions,
653        state: &ExecutionState,
654    ) -> PolarsResult<AggregationContext<'a>> {
655        if self.group_by.is_empty()
656            || !self.all_group_by_are_elementwise
657            || (self.order_by.is_some() && !self.order_by_is_elementwise)
658        {
659            polars_bail!(
660                InvalidOperation:
661                "window expression with non-elementwise `partition_by` or `order_by` not allowed in aggregation context"
662            );
663        }
664
665        let length_preserving_height = if let Some((c, _)) = state.element.as_ref() {
666            c.len()
667        } else {
668            df.height()
669        };
670
671        let function_is_scalar = self.phys_function.is_scalar();
672        let needs_remap_to_rows =
673            matches!(self.mapping, WindowMapping::GroupsToRows) && !function_is_scalar;
674
675        let partition_by_columns = self
676            .group_by
677            .iter()
678            .map(|e| {
679                let mut e = e.evaluate(df, state)?;
680                if e.len() == 1 {
681                    e = e.new_from_index(0, length_preserving_height);
682                }
683                // Sanity check: Length Preserving.
684                assert_eq!(e.len(), length_preserving_height,);
685                Ok(e)
686            })
687            .collect::<PolarsResult<Vec<_>>>()?;
688        let order_by = match &self.order_by {
689            None => None,
690            Some((e, options)) => {
691                let mut e = e.evaluate(df, state)?;
692                if e.len() == 1 {
693                    e = e.new_from_index(0, length_preserving_height);
694                }
695                // Sanity check: Length Preserving.
696                assert_eq!(e.len(), length_preserving_height);
697                let arr: Option<PrimitiveArray<IdxSize>> = if needs_remap_to_rows {
698                    feature_gated!("rank", {
699                        // Performance: precompute the rank here, so we can avoid dispatching per group
700                        // later.
701                        use polars_ops::series::SeriesRank;
702                        let arr = e.as_materialized_series().rank(
703                            RankOptions {
704                                method: RankMethod::Ordinal,
705                                descending: false,
706                            },
707                            None,
708                        );
709                        let arr = arr.idx()?;
710                        let arr = arr.rechunk();
711                        Some(arr.downcast_as_array().clone())
712                    })
713                } else {
714                    None
715                };
716
717                Some((e.clone(), arr, *options))
718            },
719        };
720
721        let (num_unique_ids, unique_ids) = if partition_by_columns.len() == 1 {
722            partition_by_columns[0].unique_id()?
723        } else {
724            ChunkUnique::unique_id(&encode_rows_unordered(&partition_by_columns)?)?
725        };
726
727        // All the groups within the existing groups.
728        let subgroups_approx_capacity = groups.len();
729        let mut subgroups: Vec<(IdxSize, UnitVec<IdxSize>)> =
730            Vec::with_capacity(subgroups_approx_capacity);
731
732        // Indices for the output groups. Not used with `WindowMapping::Explode`.
733        let mut gather_indices_offset = 0;
734        let mut gather_indices: Vec<(IdxSize, UnitVec<IdxSize>)> =
735            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
736                0
737            } else {
738                groups.len()
739            });
740        // Slices for the output groups. Only used with `WindowMapping::Explode`.
741        let mut strategy_explode_groups: Vec<[IdxSize; 2]> =
742            Vec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
743                groups.len()
744            } else {
745                0
746            });
747
748        // Amortized vectors to reorder based on `order_by`.
749        let mut amort_arg_sort = Vec::new();
750        let mut amort_offsets = Vec::new();
751
752        // Amortized vectors to gather per group data.
753        let mut amort_subgroups_order = Vec::with_capacity(num_unique_ids as usize);
754        let mut amort_subgroups_sizes = Vec::with_capacity(num_unique_ids as usize);
755        let mut amort_subgroups_indices = (0..num_unique_ids)
756            .map(|_| (0, UnitVec::new()))
757            .collect::<Vec<(IdxSize, UnitVec<IdxSize>)>>();
758
759        macro_rules! map_window_groups {
760            ($iter:expr, $get:expr) => {
761                let mut subgroup_gather_indices =
762                    UnitVec::with_capacity(if matches!(self.mapping, WindowMapping::Explode) {
763                        0
764                    } else {
765                        $iter.len()
766                    });
767
768                amort_subgroups_order.clear();
769                amort_subgroups_sizes.clear();
770                amort_subgroups_sizes.resize(num_unique_ids as usize, 0);
771
772                // Determine sizes per subgroup.
773                for i in $iter.clone() {
774                    let id = *unsafe { unique_ids.get_unchecked(i as usize) };
775                    let size = unsafe { amort_subgroups_sizes.get_unchecked_mut(id as usize) };
776                    if *size == 0 {
777                        unsafe { amort_subgroups_order.push_unchecked(id) };
778                    }
779                    *size += 1;
780                }
781
782                if matches!(self.mapping, WindowMapping::Explode) {
783                    strategy_explode_groups.push([
784                        subgroups.len() as IdxSize,
785                        amort_subgroups_order.len() as IdxSize,
786                    ]);
787                }
788
789                // Set starting gather indices and reserve capacity per subgroup.
790                let mut offset = if needs_remap_to_rows {
791                    gather_indices_offset
792                } else {
793                    subgroups.len() as IdxSize
794                };
795                for &id in &amort_subgroups_order {
796                    let size = *unsafe { amort_subgroups_sizes.get_unchecked(id as usize) };
797                    let (next_gather_idx, indices) =
798                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
799                    indices.reserve(size as usize);
800                    *next_gather_idx = offset;
801                    offset += if needs_remap_to_rows { size } else { 1 };
802                }
803
804                // Collect gather indices.
805                if matches!(self.mapping, WindowMapping::Explode) {
806                    for i in $iter {
807                        let id = *unsafe { unique_ids.get_unchecked(i as usize) };
808                        let (_, indices) =
809                            unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
810                        unsafe { indices.push_unchecked(i) };
811                    }
812                } else {
813                    // If we are remapping exploded rows back to rows and are reordering, we need
814                    // to ensure we reorder the gather indices as well. Reordering the `subgroup`
815                    // indices is done later.
816                    //
817                    // We having precalculated both the `unique_ids` and `order_by_ranks` in
818                    // efficient kernels, we can now relatively efficient arg_sort per group. This
819                    // is still horrendously slow, but at least not as bad as it would be if you
820                    // did this naively.
821                    if needs_remap_to_rows && let Some((_, arr, options)) = &order_by {
822                        let arr = arr.as_ref().unwrap();
823                        amort_arg_sort.clear();
824                        amort_arg_sort.extend(0..$iter.len() as IdxSize);
825                        match arr.validity() {
826                            None => {
827                                let arr = arr.values().as_slice();
828                                amort_arg_sort.sort_by(|a, b| {
829                                    let in_group_idx_a = $get(*a as usize) as usize;
830                                    let in_group_idx_b = $get(*b as usize) as usize;
831
832                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
833                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };
834
835                                    let mut cmp = order_a.cmp(&order_b);
836                                    // Performance: This can generally be handled branchlessly.
837                                    if options.descending {
838                                        cmp = cmp.reverse();
839                                    }
840                                    cmp
841                                });
842                            },
843                            Some(validity) => {
844                                let arr = arr.values().as_slice();
845                                amort_arg_sort.sort_by(|a, b| {
846                                    let in_group_idx_a = $get(*a as usize) as usize;
847                                    let in_group_idx_b = $get(*b as usize) as usize;
848
849                                    let is_valid_a =
850                                        unsafe { validity.get_bit_unchecked(in_group_idx_a) };
851                                    let is_valid_b =
852                                        unsafe { validity.get_bit_unchecked(in_group_idx_b) };
853
854                                    if !(is_valid_a & is_valid_b) {
855                                        let mut cmp = is_valid_a.cmp(&is_valid_b);
856                                        if options.nulls_last {
857                                            cmp = cmp.reverse();
858                                        }
859                                        return cmp;
860                                    }
861
862                                    let order_a = unsafe { arr.get_unchecked(in_group_idx_a) };
863                                    let order_b = unsafe { arr.get_unchecked(in_group_idx_b) };
864
865                                    let mut cmp = order_a.cmp(&order_b);
866                                    if options.descending {
867                                        cmp = cmp.reverse();
868                                    }
869                                    cmp
870                                });
871                            },
872                        }
873
874                        amort_offsets.clear();
875                        amort_offsets.resize($iter.len(), 0);
876                        for &id in &amort_subgroups_order {
877                            amort_subgroups_sizes[id as usize] = 0;
878                        }
879
880                        for &idx in &amort_arg_sort {
881                            let in_group_idx = $get(idx as usize);
882                            let id = *unsafe { unique_ids.get_unchecked(in_group_idx as usize) };
883                            amort_offsets[idx as usize] = amort_subgroups_sizes[id as usize];
884                            amort_subgroups_sizes[id as usize] += 1;
885                        }
886
887                        for (i, offset) in $iter.zip(&amort_offsets) {
888                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
889                            let (next_gather_idx, indices) =
890                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
891                            unsafe {
892                                subgroup_gather_indices.push_unchecked(*next_gather_idx + *offset)
893                            };
894                            unsafe { indices.push_unchecked(i) };
895                        }
896                    } else {
897                        for i in $iter {
898                            let id = *unsafe { unique_ids.get_unchecked(i as usize) };
899                            let (next_gather_idx, indices) =
900                                unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
901                            unsafe { subgroup_gather_indices.push_unchecked(*next_gather_idx) };
902                            *next_gather_idx += IdxSize::from(needs_remap_to_rows);
903                            unsafe { indices.push_unchecked(i) };
904                        }
905                    }
906                }
907
908                // Push groups into nested_groups.
909                subgroups.extend(amort_subgroups_order.iter().map(|&id| {
910                    let (_, indices) =
911                        unsafe { amort_subgroups_indices.get_unchecked_mut(id as usize) };
912                    let indices = std::mem::take(indices);
913                    (*unsafe { indices.get_unchecked(0) }, indices)
914                }));
915
916                if !matches!(self.mapping, WindowMapping::Explode) {
917                    gather_indices_offset += subgroup_gather_indices.len() as IdxSize;
918                    gather_indices.push((
919                        subgroup_gather_indices.first().copied().unwrap_or(0),
920                        subgroup_gather_indices,
921                    ));
922                }
923            };
924        }
925        match groups.as_ref() {
926            GroupsType::Idx(idxs) => {
927                for g in idxs.all() {
928                    map_window_groups!(g.iter().copied(), (|i: usize| g[i]));
929                }
930            },
931            GroupsType::Slice {
932                groups,
933                overlapping: _,
934                monotonic: _,
935            } => {
936                for [s, l] in groups.iter() {
937                    let s = *s;
938                    let l = *l;
939                    let iter = unsafe { TrustMyLength::new(s..s + l, l as usize) };
940                    map_window_groups!(iter, (|i: usize| s + i as IdxSize));
941                }
942            },
943        }
944
945        let mut subgroups = GroupsType::Idx(subgroups.into());
946        if let Some((order_by, _, options)) = order_by {
947            subgroups =
948                update_groups_sort_by(&subgroups, order_by.as_materialized_series(), &options)?;
949        }
950        let subgroups = subgroups.into_sliceable();
951        let mut data = self
952            .phys_function
953            .evaluate_on_groups(df, &subgroups, state)?
954            .finalize();
955
956        let final_groups = if matches!(self.mapping, WindowMapping::Explode) {
957            if !function_is_scalar {
958                let (data_s, offsets) = data.list()?.explode_and_offsets(ExplodeOptions {
959                    empty_as_null: false,
960                    keep_nulls: false,
961                })?;
962                data = data_s.into_column();
963
964                let mut exploded_offset = 0;
965                for [start, length] in strategy_explode_groups.iter_mut() {
966                    let exploded_start = exploded_offset;
967                    let exploded_length = offsets
968                        .lengths()
969                        .skip(*start as usize)
970                        .take(*length as usize)
971                        .sum::<usize>() as IdxSize;
972                    exploded_offset += exploded_length;
973                    *start = exploded_start;
974                    *length = exploded_length;
975                }
976            }
977            GroupsType::new_slice(strategy_explode_groups, false, true)
978        } else {
979            if needs_remap_to_rows {
980                let data_l = data.list()?;
981                assert_eq!(data_l.len(), subgroups.len());
982                let lengths = data_l.lst_lengths();
983                let length_mismatch = match subgroups.as_ref() {
984                    GroupsType::Idx(idx) => idx
985                        .all()
986                        .iter()
987                        .zip(lengths.iter())
988                        .any(|(i, l)| i.len() as IdxSize != l.unwrap()),
989                    GroupsType::Slice {
990                        groups,
991                        overlapping: _,
992                        monotonic: _,
993                    } => groups
994                        .iter()
995                        .zip(lengths.iter())
996                        .any(|([_, i], l)| *i != l.unwrap()),
997                };
998
999                polars_ensure!(
1000                    !length_mismatch,
1001                    expr = self.expr, ShapeMismatch:
1002                    "the length of the window expression did not match that of the group"
1003                );
1004
1005                data = data_l
1006                    .explode(ExplodeOptions {
1007                        empty_as_null: false,
1008                        keep_nulls: true,
1009                    })?
1010                    .into_column();
1011            }
1012            GroupsType::Idx(gather_indices.into())
1013        }
1014        .into_sliceable();
1015
1016        Ok(AggregationContext {
1017            state: AggState::NotAggregated(data),
1018            groups: Cow::Owned(final_groups),
1019            update_groups: UpdateGroups::No,
1020            original_len: false,
1021        })
1022    }
1023
1024    fn as_expression(&self) -> Option<&Expr> {
1025        Some(&self.expr)
1026    }
1027}
1028
1029fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
1030    {
1031        use arrow::Either;
1032        use polars_ops::chunked_array::TakeChunked;
1033
1034        match join_opt_ids {
1035            Either::Left(ids) => unsafe {
1036                IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
1037            },
1038            Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
1039        }
1040    }
1041}
1042
1043/// Simple reducing aggregation can be set by the groups
1044fn set_by_groups(
1045    s: &Column,
1046    ac: &AggregationContext,
1047    gb_groups: &GroupPositions,
1048    len: usize,
1049    update_groups: bool,
1050) -> Option<Column> {
1051    let groups = match ac.agg_state() {
1052        AggState::AggregatedScalar(_) | AggState::LiteralScalar(_) => gb_groups,
1053        AggState::NotAggregated(_) | AggState::AggregatedList(_) => {
1054            if update_groups || !ac.original_len {
1055                return None;
1056            } else {
1057                &ac.groups
1058            }
1059        },
1060    };
1061
1062    if s.dtype().to_physical().is_primitive_numeric() {
1063        let dtype = s.dtype();
1064        let s = s.to_physical_repr();
1065
1066        macro_rules! dispatch {
1067            ($ca:expr) => {{ Some(set_numeric($ca, groups, len)) }};
1068        }
1069
1070        downcast_as_macro_arg_physical!(&s, dispatch)
1071            .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
1072            .map(Column::from)
1073    } else {
1074        None
1075    }
1076}
1077
1078fn set_numeric<T: PolarsNumericType>(
1079    ca: &ChunkedArray<T>,
1080    groups: &GroupsType,
1081    len: usize,
1082) -> Series {
1083    let mut values = Vec::with_capacity(len);
1084    let ptr: *mut T::Native = values.as_mut_ptr();
1085    // SAFETY:
1086    // we will write from different threads but we will never alias.
1087    let sync_ptr_values = unsafe { SyncPtr::new(ptr) };
1088
1089    if ca.null_count() == 0 {
1090        let ca = ca.rechunk();
1091        match groups {
1092            GroupsType::Idx(groups) => {
1093                let agg_vals = ca.cont_slice().expect("rechunked");
1094                RAYON.install(|| {
1095                    agg_vals
1096                        .par_iter()
1097                        .zip(groups.all().par_iter())
1098                        .for_each(|(v, g)| {
1099                            let ptr = sync_ptr_values.get();
1100                            for idx in g.as_slice() {
1101                                debug_assert!((*idx as usize) < len);
1102                                unsafe { *ptr.add(*idx as usize) = *v }
1103                            }
1104                        })
1105                })
1106            },
1107            GroupsType::Slice { groups, .. } => {
1108                let agg_vals = ca.cont_slice().expect("rechunked");
1109                RAYON.install(|| {
1110                    agg_vals
1111                        .par_iter()
1112                        .zip(groups.par_iter())
1113                        .for_each(|(v, [start, g_len])| {
1114                            let ptr = sync_ptr_values.get();
1115                            let start = *start as usize;
1116                            let end = start + *g_len as usize;
1117                            for idx in start..end {
1118                                debug_assert!(idx < len);
1119                                unsafe { *ptr.add(idx) = *v }
1120                            }
1121                        })
1122                });
1123            },
1124        }
1125
1126        // SAFETY: we have written all slots
1127        unsafe { values.set_len(len) }
1128        ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
1129    } else {
1130        // We don't use a mutable bitmap as bits will have race conditions!
1131        // A single byte might alias if we write from single threads.
1132        let mut validity: Vec<bool> = vec![false; len];
1133        let validity_ptr = validity.as_mut_ptr();
1134        let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };
1135
1136        let n_threads = RAYON.current_num_threads();
1137        let offsets = _split_offsets(ca.len(), n_threads);
1138
1139        match groups {
1140            GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
1141                let offset = *offset;
1142                let offset_len = *offset_len;
1143                let ca = ca.slice(offset as i64, offset_len);
1144                let groups = &groups.all()[offset..offset + offset_len];
1145                let values_ptr = sync_ptr_values.get();
1146                let validity_ptr = sync_ptr_validity.get();
1147
1148                ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
1149                    for idx in g.as_slice() {
1150                        let idx = *idx as usize;
1151                        debug_assert!(idx < len);
1152                        unsafe {
1153                            match opt_v {
1154                                Some(v) => {
1155                                    *values_ptr.add(idx) = v;
1156                                    *validity_ptr.add(idx) = true;
1157                                },
1158                                None => {
1159                                    *values_ptr.add(idx) = T::Native::default();
1160                                    *validity_ptr.add(idx) = false;
1161                                },
1162                            };
1163                        }
1164                    }
1165                })
1166            }),
1167            GroupsType::Slice { groups, .. } => {
1168                offsets.par_iter().for_each(|(offset, offset_len)| {
1169                    let offset = *offset;
1170                    let offset_len = *offset_len;
1171                    let ca = ca.slice(offset as i64, offset_len);
1172                    let groups = &groups[offset..offset + offset_len];
1173                    let values_ptr = sync_ptr_values.get();
1174                    let validity_ptr = sync_ptr_validity.get();
1175
1176                    for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
1177                        let start = *start as usize;
1178                        let end = start + *g_len as usize;
1179                        for idx in start..end {
1180                            debug_assert!(idx < len);
1181                            unsafe {
1182                                match opt_v {
1183                                    Some(v) => {
1184                                        *values_ptr.add(idx) = v;
1185                                        *validity_ptr.add(idx) = true;
1186                                    },
1187                                    None => {
1188                                        *values_ptr.add(idx) = T::Native::default();
1189                                        *validity_ptr.add(idx) = false;
1190                                    },
1191                                };
1192                            }
1193                        }
1194                    }
1195                })
1196            },
1197        }
1198        // SAFETY: we have written all slots
1199        unsafe { values.set_len(len) }
1200        let validity = Bitmap::from(validity);
1201        let arr = PrimitiveArray::new(
1202            T::get_static_dtype()
1203                .to_physical()
1204                .to_arrow(CompatLevel::newest()),
1205            values.into(),
1206            Some(validity),
1207        );
1208        Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
1209    }
1210}