1use std::fmt::Write;
2
3use arrow::array::PrimitiveArray;
4use arrow::bitmap::Bitmap;
5use polars_core::prelude::sort::perfect_sort;
6use polars_core::prelude::*;
7use polars_core::series::IsSorted;
8use polars_core::utils::_split_offsets;
9use polars_core::{POOL, downcast_as_macro_arg_physical};
10use polars_ops::frame::SeriesJoin;
11use polars_ops::frame::join::{ChunkJoinOptIds, private_left_join_multiple_keys};
12use polars_ops::prelude::*;
13use polars_plan::prelude::*;
14use polars_utils::sync::SyncPtr;
15use rayon::prelude::*;
16
17use super::*;
18
19pub struct WindowExpr {
20 pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
23 pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
24 pub(crate) apply_columns: Vec<PlSmallStr>,
25 pub(crate) phys_function: Arc<dyn PhysicalExpr>,
26 pub(crate) mapping: WindowMapping,
27 pub(crate) expr: Expr,
28 pub(crate) has_different_group_sources: bool,
29 pub(crate) output_field: Field,
30}
31
/// The way `WindowExpr::evaluate` turns the aggregation result back into a
/// column of the output.
#[cfg_attr(debug_assertions, derive(Debug))]
enum MapStrategy {
    /// Join the aggregated values back onto the rows via the partition keys
    /// (or write them directly by group index when that fast path applies).
    Join,
    /// Return the exploded aggregation result as is.
    Explode,
    /// Explode the aggregation result and reorder it with computed take
    /// indices (see `map_by_arg_sort`).
    Map,
    /// Return the flat values untouched; a literal is only broadcast to the
    /// frame height.
    Nothing,
}
43
44impl WindowExpr {
45 fn map_list_agg_by_arg_sort(
46 &self,
47 out_column: Column,
48 flattened: &Column,
49 mut ac: AggregationContext,
50 gb: GroupBy,
51 ) -> PolarsResult<IdxCa> {
52 let mut idx_mapping = Vec::with_capacity(out_column.len());
54
55 let mut take_idx = vec![];
58
59 if std::ptr::eq(ac.groups().as_ref(), gb.get_groups()) {
61 let mut iter = 0..flattened.len() as IdxSize;
62 match ac.groups().as_ref().as_ref() {
63 GroupsType::Idx(groups) => {
64 for g in groups.all() {
65 idx_mapping.extend(g.iter().copied().zip(&mut iter));
66 }
67 },
68 GroupsType::Slice { groups, .. } => {
69 for &[first, len] in groups {
70 idx_mapping.extend((first..first + len).zip(&mut iter));
71 }
72 },
73 }
74 }
75 else {
78 let mut original_idx = Vec::with_capacity(out_column.len());
79 match gb.get_groups().as_ref() {
80 GroupsType::Idx(groups) => {
81 for g in groups.all() {
82 original_idx.extend_from_slice(g)
83 }
84 },
85 GroupsType::Slice { groups, .. } => {
86 for &[first, len] in groups {
87 original_idx.extend(first..first + len)
88 }
89 },
90 };
91
92 let mut original_idx_iter = original_idx.iter().copied();
93
94 match ac.groups().as_ref().as_ref() {
95 GroupsType::Idx(groups) => {
96 for g in groups.all() {
97 idx_mapping.extend(g.iter().copied().zip(&mut original_idx_iter));
98 }
99 },
100 GroupsType::Slice { groups, .. } => {
101 for &[first, len] in groups {
102 idx_mapping.extend((first..first + len).zip(&mut original_idx_iter));
103 }
104 },
105 }
106 original_idx.clear();
107 take_idx = original_idx;
108 }
109 unsafe { perfect_sort(&idx_mapping, &mut take_idx) };
112 Ok(IdxCa::from_vec(PlSmallStr::EMPTY, take_idx))
113 }
114
115 #[allow(clippy::too_many_arguments)]
116 fn map_by_arg_sort(
117 &self,
118 df: &DataFrame,
119 out_column: Column,
120 flattened: &Column,
121 mut ac: AggregationContext,
122 group_by_columns: &[Column],
123 gb: GroupBy,
124 cache_key: String,
125 state: &ExecutionState,
126 ) -> PolarsResult<Column> {
127 if flattened.len() != df.height() {
153 let ca = out_column.list().unwrap();
154 let non_matching_group =
155 ca.into_iter()
156 .zip(ac.groups().iter())
157 .find(|(output, group)| {
158 if let Some(output) = output {
159 output.as_ref().len() != group.len()
160 } else {
161 false
162 }
163 });
164
165 if let Some((output, group)) = non_matching_group {
166 let first = group.first();
167 let group = group_by_columns
168 .iter()
169 .map(|s| format!("{}", s.get(first as usize).unwrap()))
170 .collect::<Vec<_>>();
171 polars_bail!(
172 expr = self.expr, ShapeMismatch:
173 "the length of the window expression did not match that of the group\
174 \n> group: {}\n> group length: {}\n> output: '{:?}'",
175 comma_delimited(String::new(), &group), group.len(), output.unwrap()
176 );
177 } else {
178 polars_bail!(
179 expr = self.expr, ShapeMismatch:
180 "the length of the window expression did not match that of the group"
181 );
182 };
183 }
184
185 let idx = if state.cache_window() {
186 if let Some(idx) = state.window_cache.get_map(&cache_key) {
187 idx
188 } else {
189 let idx = Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?);
190 state.window_cache.insert_map(cache_key, idx.clone());
191 idx
192 }
193 } else {
194 Arc::new(self.map_list_agg_by_arg_sort(out_column, flattened, ac, gb)?)
195 };
196
197 unsafe { Ok(flattened.take_unchecked(&idx)) }
200 }
201
202 fn run_aggregation<'a>(
203 &self,
204 df: &DataFrame,
205 state: &ExecutionState,
206 gb: &'a GroupBy,
207 ) -> PolarsResult<AggregationContext<'a>> {
208 let ac = self
209 .phys_function
210 .evaluate_on_groups(df, gb.get_groups(), state)?;
211 Ok(ac)
212 }
213
214 fn is_explicit_list_agg(&self) -> bool {
215 let mut explicit_list = false;
225 for e in &self.expr {
226 if let Expr::Window { function, .. } = e {
227 let mut finishes_list = false;
229 for e in &**function {
230 match e {
231 Expr::Agg(AggExpr::Implode(_)) => {
232 finishes_list = true;
233 },
234 Expr::Alias(_, _) => {},
235 _ => break,
236 }
237 }
238 explicit_list = finishes_list;
239 }
240 }
241
242 explicit_list
243 }
244
245 fn is_simple_column_expr(&self) -> bool {
246 let mut simple_col = false;
249 for e in &self.expr {
250 if let Expr::Window { function, .. } = e {
251 for e in &**function {
253 match e {
254 Expr::Column(_) => {
255 simple_col = true;
256 },
257 Expr::Alias(_, _) => {},
258 _ => break,
259 }
260 }
261 }
262 }
263 simple_col
264 }
265
266 fn is_aggregation(&self) -> bool {
267 let mut agg_col = false;
270 for e in &self.expr {
271 if let Expr::Window { function, .. } = e {
272 for e in &**function {
274 match e {
275 Expr::Agg(_) => {
276 agg_col = true;
277 },
278 Expr::Alias(_, _) => {},
279 _ => break,
280 }
281 }
282 }
283 }
284 agg_col
285 }
286
287 fn determine_map_strategy(
288 &self,
289 ac: &mut AggregationContext,
290 gb: &GroupBy,
291 ) -> PolarsResult<MapStrategy> {
292 match (self.mapping, ac.agg_state()) {
293 (WindowMapping::Explode, _) => Ok(MapStrategy::Explode),
296 (_, AggState::AggregatedScalar(_)) => Ok(MapStrategy::Join),
302 (WindowMapping::Join, AggState::AggregatedList(_)) => Ok(MapStrategy::Join),
305 (WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
308 if let GroupsType::Slice { .. } = gb.get_groups().as_ref() {
309 ac.groups().as_ref().check_lengths(gb.get_groups())?;
311 Ok(MapStrategy::Explode)
312 } else {
313 Ok(MapStrategy::Map)
314 }
315 },
316 (WindowMapping::GroupsToRows, AggState::NotAggregated(_)) => {
321 if self.is_simple_column_expr() {
324 Ok(MapStrategy::Nothing)
325 } else {
326 Ok(MapStrategy::Map)
327 }
328 },
329 (WindowMapping::Join, AggState::NotAggregated(_)) => Ok(MapStrategy::Join),
330 (_, AggState::LiteralScalar(_)) => Ok(MapStrategy::Nothing),
332 }
333 }
334}
335
336pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
338 write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
339}
340
341impl PhysicalExpr for WindowExpr {
342 fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
348 if df.is_empty() {
375 let field = self.phys_function.to_field(df.schema())?;
376 match self.mapping {
377 WindowMapping::Join => {
378 return Ok(Column::full_null(
379 field.name().clone(),
380 0,
381 &DataType::List(Box::new(field.dtype().clone())),
382 ));
383 },
384 _ => {
385 return Ok(Column::full_null(field.name().clone(), 0, field.dtype()));
386 },
387 }
388 }
389
390 let mut group_by_columns = self
391 .group_by
392 .iter()
393 .map(|e| e.evaluate(df, state))
394 .collect::<PolarsResult<Vec<_>>>()?;
395
396 let sorted_keys = group_by_columns.iter().all(|s| {
398 matches!(
399 s.is_sorted_flag(),
400 IsSorted::Ascending | IsSorted::Descending
401 )
402 });
403 let explicit_list_agg = self.is_explicit_list_agg();
404
405 let mut sort_groups = matches!(self.mapping, WindowMapping::Explode) ||
407 (!self.is_simple_column_expr() && !explicit_list_agg && sorted_keys && !self.is_aggregation());
416
417 if self.has_different_group_sources {
420 sort_groups = true
421 }
422
423 let create_groups = || {
424 let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
425 let mut groups = gb.take_groups();
426
427 if let Some((order_by, options)) = &self.order_by {
428 let order_by = order_by.evaluate(df, state)?;
429 polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
430 groups = update_groups_sort_by(&groups, order_by.as_materialized_series(), options)?
431 .into_sliceable()
432 }
433
434 let out: PolarsResult<GroupPositions> = Ok(groups);
435 out
436 };
437
438 let (mut groups, cache_key) = if state.cache_window() {
440 let mut cache_key = String::with_capacity(32 * group_by_columns.len());
441 write!(&mut cache_key, "{}", state.branch_idx).unwrap();
442 for s in &group_by_columns {
443 cache_key.push_str(s.name());
444 }
445 if let Some((e, options)) = &self.order_by {
446 let e = match e.as_expression() {
447 Some(e) => e,
448 None => {
449 polars_bail!(InvalidOperation: "cannot order by this expression in window function")
450 },
451 };
452 window_function_format_order_by(&mut cache_key, e, options)
453 }
454
455 let groups = match state.window_cache.get_groups(&cache_key) {
456 Some(groups) => groups,
457 None => create_groups()?,
458 };
459 (groups, cache_key)
460 } else {
461 (create_groups()?, "".to_string())
462 };
463
464 let apply_columns = self.apply_columns.clone();
466
467 if sort_groups || state.cache_window() {
472 groups.sort();
473 state
474 .window_cache
475 .insert_groups(cache_key.clone(), groups.clone());
476 }
477
478 for col in group_by_columns.iter_mut() {
480 if col.len() != df.height() {
481 polars_ensure!(
482 col.len() == 1,
483 ShapeMismatch: "columns used as `partition_by` must have the same length as the DataFrame"
484 );
485 *col = col.new_from_index(0, df.height())
486 }
487 }
488
489 let gb = GroupBy::new(df, group_by_columns.clone(), groups, Some(apply_columns));
490
491 let mut ac = self.run_aggregation(df, state, &gb)?;
492
493 use MapStrategy::*;
494
495 match self.determine_map_strategy(&mut ac, &gb)? {
496 Nothing => {
497 let mut out = ac.flat_naive().into_owned();
498
499 if ac.is_literal() {
500 out = out.new_from_index(0, df.height())
501 }
502 Ok(out.into_column())
503 },
504 Explode => {
505 let out = if self.phys_function.is_scalar() {
506 ac.get_values().clone()
507 } else {
508 ac.aggregated().explode(false)?
509 };
510 Ok(out.into_column())
511 },
512 Map => {
513 let out_column = ac.aggregated();
516 let flattened = out_column.explode(false)?;
517 let ac = unsafe {
520 std::mem::transmute::<AggregationContext<'_>, AggregationContext<'static>>(ac)
521 };
522 self.map_by_arg_sort(
523 df,
524 out_column,
525 &flattened,
526 ac,
527 &group_by_columns,
528 gb,
529 cache_key,
530 state,
531 )
532 },
533 Join => {
534 let out_column = ac.aggregated();
535 let update_groups = !matches!(&ac.update_groups, UpdateGroups::No);
539 match (
540 &ac.update_groups,
541 set_by_groups(&out_column, &ac, df.height(), update_groups),
542 ) {
543 (UpdateGroups::No, Some(out)) => Ok(out.into_column()),
546 (_, _) => {
547 let keys = gb.keys();
548
549 let get_join_tuples = || {
550 if group_by_columns.len() == 1 {
551 let mut left = group_by_columns[0].clone();
552 let mut right = keys[0].clone();
554
555 let (left, right) = if left.dtype().is_nested() {
556 (
557 ChunkedArray::<BinaryOffsetType>::with_chunk(
558 "".into(),
559 row_encode::_get_rows_encoded_unordered(&[
560 left.clone()
561 ])?
562 .into_array(),
563 )
564 .into_series(),
565 ChunkedArray::<BinaryOffsetType>::with_chunk(
566 "".into(),
567 row_encode::_get_rows_encoded_unordered(&[
568 right.clone()
569 ])?
570 .into_array(),
571 )
572 .into_series(),
573 )
574 } else {
575 (
576 left.into_materialized_series().clone(),
577 right.into_materialized_series().clone(),
578 )
579 };
580
581 PolarsResult::Ok(Arc::new(
582 left.hash_join_left(&right, JoinValidation::ManyToMany, true)
583 .unwrap()
584 .1,
585 ))
586 } else {
587 let df_right =
588 unsafe { DataFrame::new_no_checks_height_from_first(keys) };
589 let df_left = unsafe {
590 DataFrame::new_no_checks_height_from_first(group_by_columns)
591 };
592 Ok(Arc::new(
593 private_left_join_multiple_keys(&df_left, &df_right, true)?.1,
594 ))
595 }
596 };
597
598 let join_opt_ids = if state.cache_window() {
600 if let Some(jt) = state.window_cache.get_join(&cache_key) {
601 jt
602 } else {
603 let jt = get_join_tuples()?;
604 state.window_cache.insert_join(cache_key, jt.clone());
605 jt
606 }
607 } else {
608 get_join_tuples()?
609 };
610
611 let out = materialize_column(&join_opt_ids, &out_column);
612 Ok(out.into_column())
613 },
614 }
615 },
616 }
617 }
618
619 fn to_field(&self, _input_schema: &Schema) -> PolarsResult<Field> {
620 Ok(self.output_field.clone())
621 }
622
623 fn is_scalar(&self) -> bool {
624 false
625 }
626
627 #[allow(clippy::ptr_arg)]
628 fn evaluate_on_groups<'a>(
629 &self,
630 _df: &DataFrame,
631 _groups: &'a GroupPositions,
632 _state: &ExecutionState,
633 ) -> PolarsResult<AggregationContext<'a>> {
634 polars_bail!(InvalidOperation: "window expression not allowed in aggregation");
635 }
636
637 fn as_expression(&self) -> Option<&Expr> {
638 Some(&self.expr)
639 }
640}
641
642fn materialize_column(join_opt_ids: &ChunkJoinOptIds, out_column: &Column) -> Column {
643 {
644 use arrow::Either;
645 use polars_ops::chunked_array::TakeChunked;
646
647 match join_opt_ids {
648 Either::Left(ids) => unsafe {
649 IdxCa::with_nullable_idx(ids, |idx| out_column.take_unchecked(idx))
650 },
651 Either::Right(ids) => unsafe { out_column.take_opt_chunked_unchecked(ids, false) },
652 }
653 }
654}
655
656fn set_by_groups(
658 s: &Column,
659 ac: &AggregationContext,
660 len: usize,
661 update_groups: bool,
662) -> Option<Column> {
663 if update_groups || !ac.original_len {
664 return None;
665 }
666 if s.dtype().to_physical().is_primitive_numeric() {
667 let dtype = s.dtype();
668 let s = s.to_physical_repr();
669
670 macro_rules! dispatch {
671 ($ca:expr) => {{ Some(set_numeric($ca, &ac.groups, len)) }};
672 }
673 downcast_as_macro_arg_physical!(&s, dispatch)
674 .map(|s| unsafe { s.from_physical_unchecked(dtype) }.unwrap())
675 .map(Column::from)
676 } else {
677 None
678 }
679}
680
681fn set_numeric<T: PolarsNumericType>(
682 ca: &ChunkedArray<T>,
683 groups: &GroupsType,
684 len: usize,
685) -> Series {
686 let mut values = Vec::with_capacity(len);
687 let ptr: *mut T::Native = values.as_mut_ptr();
688 let sync_ptr_values = unsafe { SyncPtr::new(ptr) };
691
692 if ca.null_count() == 0 {
693 let ca = ca.rechunk();
694 match groups {
695 GroupsType::Idx(groups) => {
696 let agg_vals = ca.cont_slice().expect("rechunked");
697 POOL.install(|| {
698 agg_vals
699 .par_iter()
700 .zip(groups.all().par_iter())
701 .for_each(|(v, g)| {
702 let ptr = sync_ptr_values.get();
703 for idx in g.as_slice() {
704 debug_assert!((*idx as usize) < len);
705 unsafe { *ptr.add(*idx as usize) = *v }
706 }
707 })
708 })
709 },
710 GroupsType::Slice { groups, .. } => {
711 let agg_vals = ca.cont_slice().expect("rechunked");
712 POOL.install(|| {
713 agg_vals
714 .par_iter()
715 .zip(groups.par_iter())
716 .for_each(|(v, [start, g_len])| {
717 let ptr = sync_ptr_values.get();
718 let start = *start as usize;
719 let end = start + *g_len as usize;
720 for idx in start..end {
721 debug_assert!(idx < len);
722 unsafe { *ptr.add(idx) = *v }
723 }
724 })
725 });
726 },
727 }
728
729 unsafe { values.set_len(len) }
731 ChunkedArray::<T>::new_vec(ca.name().clone(), values).into_series()
732 } else {
733 let mut validity: Vec<bool> = vec![false; len];
736 let validity_ptr = validity.as_mut_ptr();
737 let sync_ptr_validity = unsafe { SyncPtr::new(validity_ptr) };
738
739 let n_threads = POOL.current_num_threads();
740 let offsets = _split_offsets(ca.len(), n_threads);
741
742 match groups {
743 GroupsType::Idx(groups) => offsets.par_iter().for_each(|(offset, offset_len)| {
744 let offset = *offset;
745 let offset_len = *offset_len;
746 let ca = ca.slice(offset as i64, offset_len);
747 let groups = &groups.all()[offset..offset + offset_len];
748 let values_ptr = sync_ptr_values.get();
749 let validity_ptr = sync_ptr_validity.get();
750
751 ca.iter().zip(groups.iter()).for_each(|(opt_v, g)| {
752 for idx in g.as_slice() {
753 let idx = *idx as usize;
754 debug_assert!(idx < len);
755 unsafe {
756 match opt_v {
757 Some(v) => {
758 *values_ptr.add(idx) = v;
759 *validity_ptr.add(idx) = true;
760 },
761 None => {
762 *values_ptr.add(idx) = T::Native::default();
763 *validity_ptr.add(idx) = false;
764 },
765 };
766 }
767 }
768 })
769 }),
770 GroupsType::Slice { groups, .. } => {
771 offsets.par_iter().for_each(|(offset, offset_len)| {
772 let offset = *offset;
773 let offset_len = *offset_len;
774 let ca = ca.slice(offset as i64, offset_len);
775 let groups = &groups[offset..offset + offset_len];
776 let values_ptr = sync_ptr_values.get();
777 let validity_ptr = sync_ptr_validity.get();
778
779 for (opt_v, [start, g_len]) in ca.iter().zip(groups.iter()) {
780 let start = *start as usize;
781 let end = start + *g_len as usize;
782 for idx in start..end {
783 debug_assert!(idx < len);
784 unsafe {
785 match opt_v {
786 Some(v) => {
787 *values_ptr.add(idx) = v;
788 *validity_ptr.add(idx) = true;
789 },
790 None => {
791 *values_ptr.add(idx) = T::Native::default();
792 *validity_ptr.add(idx) = false;
793 },
794 };
795 }
796 }
797 }
798 })
799 },
800 }
801 unsafe { values.set_len(len) }
803 let validity = Bitmap::from(validity);
804 let arr = PrimitiveArray::new(
805 T::get_static_dtype()
806 .to_physical()
807 .to_arrow(CompatLevel::newest()),
808 values.into(),
809 Some(validity),
810 );
811 Series::try_from((ca.name().clone(), arr.boxed())).unwrap()
812 }
813}