1mod aggregation;
2mod alias;
3mod apply;
4mod binary;
5mod cast;
6mod column;
7mod count;
8mod element;
9mod eval;
10#[cfg(feature = "dtype-struct")]
11mod field;
12mod filter;
13mod gather;
14mod group_iter;
15mod len;
16mod literal;
17#[cfg(feature = "dynamic_group_by")]
18mod rolling;
19mod slice;
20mod sort;
21mod sortby;
22#[cfg(feature = "dtype-struct")]
23mod structeval;
24mod ternary;
25mod window;
26
27use std::borrow::Cow;
28use std::fmt::{Display, Formatter};
29
30pub(crate) use aggregation::*;
31pub(crate) use alias::*;
32pub(crate) use apply::*;
33use arrow::array::ArrayRef;
34use arrow::bitmap::MutableBitmap;
35use arrow::legacy::utils::CustomIterTools;
36pub(crate) use binary::*;
37pub(crate) use cast::*;
38pub(crate) use column::*;
39pub(crate) use count::*;
40pub(crate) use element::*;
41pub(crate) use eval::*;
42#[cfg(feature = "dtype-struct")]
43pub(crate) use field::*;
44pub(crate) use filter::*;
45pub(crate) use gather::*;
46pub(crate) use len::*;
47pub(crate) use literal::*;
48use polars_core::prelude::*;
49use polars_io::predicates::PhysicalIoExpr;
50use polars_plan::prelude::*;
51#[cfg(feature = "dynamic_group_by")]
52pub(crate) use rolling::RollingExpr;
53pub(crate) use slice::*;
54pub(crate) use sort::*;
55pub(crate) use sortby::*;
56#[cfg(feature = "dtype-struct")]
57pub(crate) use structeval::*;
58pub(crate) use ternary::*;
59pub use window::window_function_format_order_by;
60pub(crate) use window::*;
61
62use crate::state::ExecutionState;
63
/// How the values held by an [`AggregationContext`] relate to its groups.
#[derive(Clone, Debug)]
pub enum AggState {
    /// Values already aggregated into a list column holding one list per group.
    AggregatedList(Column),
    /// Values already aggregated to a single value per group.
    AggregatedScalar(Column),
    /// Flat values that have not been aggregated yet; the groups index into this column.
    NotAggregated(Column),
    /// A literal (normally of length 1) whose first value is broadcast to every group.
    LiteralScalar(Column),
}
78
79impl AggState {
80 fn try_map<F>(&self, func: F) -> PolarsResult<Self>
81 where
82 F: FnOnce(&Column) -> PolarsResult<Column>,
83 {
84 Ok(match self {
85 AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
86 AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
87 AggState::LiteralScalar(c) => AggState::LiteralScalar(func(c)?),
88 AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
89 })
90 }
91
92 fn is_scalar(&self) -> bool {
93 matches!(self, Self::AggregatedScalar(_))
94 }
95
96 pub fn name(&self) -> &PlSmallStr {
97 match self {
98 AggState::AggregatedList(s)
99 | AggState::NotAggregated(s)
100 | AggState::LiteralScalar(s)
101 | AggState::AggregatedScalar(s) => s.name(),
102 }
103 }
104
105 pub fn rename(&mut self, name: PlSmallStr) {
106 match self {
107 AggState::AggregatedList(s)
108 | AggState::NotAggregated(s)
109 | AggState::LiteralScalar(s)
110 | AggState::AggregatedScalar(s) => s.rename(name),
111 }
112 }
113
114 pub fn flat_dtype(&self) -> &DataType {
115 match self {
116 AggState::AggregatedList(s) => s.dtype().inner_dtype().unwrap(),
117 AggState::NotAggregated(s)
118 | AggState::LiteralScalar(s)
119 | AggState::AggregatedScalar(s) => s.dtype(),
120 }
121 }
122}
123
/// A pending recomputation of `AggregationContext::groups`, applied lazily the next
/// time the groups are read through `AggregationContext::groups()`.
#[derive(Debug, PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// The stored groups are up to date.
    No,
    /// Rebuild the groups as back-to-back slices that keep each group's current length.
    WithGroupsLen,
    /// Rebuild the groups from the lengths of the list values currently held.
    WithSeriesLen,
}
137
/// The result of evaluating an expression over groups: the values together with the
/// group positions they relate to.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// The values and how they relate to `groups` (flat, aggregated or literal).
    pub(crate) state: AggState,
    /// Group positions; may borrow the caller's groups until they have to be rebuilt.
    pub(crate) groups: Cow<'a, GroupPositions>,
    /// Pending recomputation of `groups`, applied by `groups()`.
    pub(crate) update_groups: UpdateGroups,
    /// NOTE(review): appears to record whether the values still have their original
    /// (pre-grouping) length; it starts `true` and is cleared whenever groups are
    /// replaced or values are re-aggregated — confirm against the readers of this flag.
    pub(crate) original_len: bool,
}
157
158impl<'a> AggregationContext<'a> {
159 pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
160 match self.update_groups {
161 UpdateGroups::No => {},
162 UpdateGroups::WithGroupsLen => {
163 let mut offset = 0 as IdxSize;
168
169 match self.groups.as_ref().as_ref() {
170 GroupsType::Idx(groups) => {
171 let groups = groups
172 .iter()
173 .map(|g| {
174 let len = g.1.len() as IdxSize;
175 let new_offset = offset + len;
176 let out = [offset, len];
177 offset = new_offset;
178 out
179 })
180 .collect();
181 self.groups =
182 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
183 },
184 GroupsType::Slice { groups, .. } => {
188 let groups = groups
190 .iter()
191 .map(|g| {
192 let len = g[1];
193 let new = [offset, g[1]];
194 offset += len;
195 new
196 })
197 .collect();
198 self.groups =
199 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable())
200 },
201 }
202 self.update_groups = UpdateGroups::No;
203 },
204 UpdateGroups::WithSeriesLen => {
205 let s = self.get_values().clone();
206 self.det_groups_from_list(s.as_materialized_series());
207 },
208 }
209 &self.groups
210 }
211
212 pub(crate) fn get_values(&self) -> &Column {
213 match &self.state {
214 AggState::NotAggregated(s)
215 | AggState::AggregatedScalar(s)
216 | AggState::AggregatedList(s) => s,
217 AggState::LiteralScalar(s) => s,
218 }
219 }
220
221 pub fn agg_state(&self) -> &AggState {
222 &self.state
223 }
224
225 pub(crate) fn is_not_aggregated(&self) -> bool {
226 matches!(
227 &self.state,
228 AggState::NotAggregated(_) | AggState::LiteralScalar(_)
229 )
230 }
231
232 pub(crate) fn is_aggregated(&self) -> bool {
233 !self.is_not_aggregated()
234 }
235
236 pub(crate) fn is_literal(&self) -> bool {
237 matches!(self.state, AggState::LiteralScalar(_))
238 }
239
240 fn new(
244 column: Column,
245 groups: Cow<'a, GroupPositions>,
246 aggregated: bool,
247 ) -> AggregationContext<'a> {
248 let series = if aggregated {
249 assert_eq!(column.len(), groups.len());
250 AggState::AggregatedScalar(column)
251 } else {
252 AggState::NotAggregated(column)
253 };
254
255 Self {
256 state: series,
257 groups,
258 update_groups: UpdateGroups::No,
259 original_len: true,
260 }
261 }
262
263 fn with_agg_state(&mut self, agg_state: AggState) {
264 self.state = agg_state;
265 }
266
267 fn rename(&mut self, name: PlSmallStr) {
268 self.state.rename(name);
269 }
270
271 pub(crate) fn from_agg_state(
272 agg_state: AggState,
273 groups: Cow<'a, GroupPositions>,
274 ) -> AggregationContext<'a> {
275 Self {
276 state: agg_state,
277 groups,
278 update_groups: UpdateGroups::No,
279 original_len: true,
280 }
281 }
282
283 pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
284 self.original_len = original_len;
285 self
286 }
287
288 pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
289 self.update_groups = update;
290 self
291 }
292
293 fn det_groups_from_list(&mut self, s: &Series) {
294 let mut offset = 0 as IdxSize;
295 let list = s
296 .list()
297 .expect("impl error, should be a list at this point");
298
299 match list.chunks().len() {
300 1 => {
301 let arr = list.downcast_iter().next().unwrap();
302 let offsets = arr.offsets().as_slice();
303
304 let mut previous = 0i64;
305 let groups = offsets[1..]
306 .iter()
307 .map(|&o| {
308 let len = (o - previous) as IdxSize;
309 let new_offset = offset + len;
310
311 previous = o;
312 let out = [offset, len];
313 offset = new_offset;
314 out
315 })
316 .collect_trusted();
317 self.groups =
318 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
319 },
320 _ => {
321 let groups = {
322 self.get_values()
323 .list()
324 .expect("impl error, should be a list at this point")
325 .amortized_iter()
326 .map(|s| {
327 if let Some(s) = s {
328 let len = s.as_ref().len() as IdxSize;
329 let new_offset = offset + len;
330 let out = [offset, len];
331 offset = new_offset;
332 out
333 } else {
334 [offset, 0]
335 }
336 })
337 .collect_trusted()
338 };
339 self.groups =
340 Cow::Owned(GroupsType::new_slice(groups, false, true).into_sliceable());
341 },
342 }
343 self.update_groups = UpdateGroups::No;
344 }
345
346 pub(crate) fn with_values(
350 &mut self,
351 column: Column,
352 aggregated: bool,
353 expr: Option<&Expr>,
354 ) -> PolarsResult<&mut Self> {
355 self.with_values_and_args(
356 column,
357 aggregated,
358 expr,
359 false,
360 self.agg_state().is_scalar(),
361 )
362 }
363
364 pub(crate) fn with_values_and_args(
365 &mut self,
366 column: Column,
367 aggregated: bool,
368 expr: Option<&Expr>,
369 preserve_literal: bool,
372 returns_scalar: bool,
373 ) -> PolarsResult<&mut Self> {
374 self.state = match (aggregated, column.dtype()) {
375 (true, &DataType::List(_)) if !returns_scalar => {
376 if column.len() != self.groups.len() {
377 let fmt_expr = if let Some(e) = expr {
378 format!("'{e:?}' ")
379 } else {
380 String::new()
381 };
382 polars_bail!(
383 ComputeError:
384 "aggregation expression '{}' produced a different number of elements: {} \
385 than the number of groups: {} (this is likely invalid)",
386 fmt_expr, column.len(), self.groups.len(),
387 );
388 }
389 AggState::AggregatedList(column)
390 },
391 (true, _) => AggState::AggregatedScalar(column),
392 _ => {
393 match self.state {
394 AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
397 AggState::LiteralScalar(_) if column.len() == 1 && preserve_literal => {
399 AggState::LiteralScalar(column)
400 },
401 _ => AggState::NotAggregated(column.into_column()),
402 }
403 },
404 };
405 Ok(self)
406 }
407
408 pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
409 self.state = AggState::LiteralScalar(column);
410 self
411 }
412
413 pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
415 if let AggState::AggregatedList(_) = self.agg_state() {
416 self.with_values(self.flat_naive().into_owned(), false, None)
418 .unwrap();
419 }
420 self.groups = Cow::Owned(groups);
421 self.update_groups = UpdateGroups::No;
423 self.original_len = false;
425 self
426 }
427
428 pub fn normalize_values(&mut self) {
430 self.set_original_len(false);
431 self.groups();
432 let values = self.flat_naive();
433 let values = unsafe { values.agg_list(&self.groups) };
434 self.state = AggState::AggregatedList(values);
435 self.with_update_groups(UpdateGroups::WithGroupsLen);
436 }
437
438 pub fn aggregated_as_list<'b>(&'b mut self) -> Cow<'b, ListChunked> {
440 self.aggregated();
441 let out = self.get_values();
442 match self.agg_state() {
443 AggState::AggregatedScalar(_) => Cow::Owned(out.as_list()),
444 _ => Cow::Borrowed(out.list().unwrap()),
445 }
446 }
447
448 pub fn aggregated(&mut self) -> Column {
450 match self.state.clone() {
453 AggState::NotAggregated(s) => {
454 self.groups();
459 #[cfg(debug_assertions)]
460 {
461 if self.groups.len() > s.len() {
462 polars_warn!(
463 "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
464 )
465 }
466 }
467
468 let out = unsafe { s.agg_list(&self.groups) };
471 self.state = AggState::AggregatedList(out.clone());
472
473 self.update_groups = UpdateGroups::WithGroupsLen;
474 out
475 },
476 AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
477 AggState::LiteralScalar(s) => {
478 let rows = self.groups.len();
479 let s = s.implode().unwrap();
480 let s = s.new_from_index(0, rows);
481 let s = s.into_column();
482 self.state = AggState::AggregatedList(s.clone());
483 self.with_update_groups(UpdateGroups::WithSeriesLen);
484 s.clone()
485 },
486 }
487 }
488
489 pub fn finalize(&mut self) -> Column {
491 match &self.state {
494 AggState::LiteralScalar(c) => {
495 let c = c.clone();
496 self.groups();
497 let rows = self.groups.len();
498 c.new_from_index(0, rows)
499 },
500 _ => self.aggregated(),
501 }
502 }
503
504 fn arity_should_explode(&self) -> bool {
507 use AggState::*;
508 match self.agg_state() {
509 LiteralScalar(s) => s.len() == 1,
510 AggregatedScalar(_) => true,
511 _ => false,
512 }
513 }
514
515 pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
516 let _ = self.groups();
517 let groups = self.groups;
518 match self.state {
519 AggState::NotAggregated(c) => (c, groups),
520 AggState::AggregatedScalar(c) => (c, groups),
521 AggState::LiteralScalar(c) => (c, groups),
522 AggState::AggregatedList(c) => {
523 let flattened = c
524 .explode(ExplodeOptions {
525 empty_as_null: false,
526 keep_nulls: true,
527 })
528 .unwrap();
529 let groups = groups.into_owned();
530 let groups = groups.unroll();
555 (flattened, Cow::Owned(groups))
556 },
557 }
558 }
559
560 pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
565 match &self.state {
566 AggState::NotAggregated(c) => Cow::Borrowed(c),
567 AggState::AggregatedList(c) => {
568 if cfg!(debug_assertions) {
569 if self.groups.is_overlapping() {
572 polars_warn!(
573 "performance - an aggregated list with overlapping groups may consume excessive memory"
574 )
575 }
576 }
577
578 Cow::Owned(
580 c.explode(ExplodeOptions {
581 empty_as_null: false,
582 keep_nulls: true,
583 })
584 .unwrap(),
585 )
586 },
587 AggState::AggregatedScalar(c) => Cow::Borrowed(c),
588 AggState::LiteralScalar(c) => Cow::Borrowed(c),
589 }
590 }
591
592 fn flat_naive_length(&self) -> usize {
593 match &self.state {
594 AggState::NotAggregated(c) => c.len(),
595 AggState::AggregatedList(c) => c.list().unwrap().inner_length(),
596 AggState::AggregatedScalar(c) => c.len(),
597 AggState::LiteralScalar(_) => 1,
598 }
599 }
600
601 pub(crate) fn take(&mut self) -> Column {
603 let c = match &mut self.state {
604 AggState::NotAggregated(c)
605 | AggState::AggregatedScalar(c)
606 | AggState::AggregatedList(c) => c,
607 AggState::LiteralScalar(c) => c,
608 };
609 std::mem::take(c)
610 }
611
612 fn groups_cover_all_values(&mut self) -> bool {
614 if matches!(
615 self.state,
616 AggState::LiteralScalar(_) | AggState::AggregatedScalar(_)
617 ) {
618 return true;
619 }
620
621 let num_values = self.flat_naive_length();
622 match self.groups().as_ref().as_ref() {
623 GroupsType::Idx(groups) => {
624 let mut seen = MutableBitmap::from_len_zeroed(num_values);
625 for (_, g) in groups {
626 for i in g.iter() {
627 unsafe { seen.set_unchecked(*i as usize, true) };
628 }
629 }
630 seen.unset_bits() == 0
631 },
632 GroupsType::Slice {
633 groups,
634 overlapping: true,
635 monotonic: _,
636 } => {
637 let mut offset = 0;
639 let mut covers_all = true;
640 for [start, length] in groups {
641 covers_all &= *start <= offset;
642 offset = start + length;
643 }
644 covers_all && offset == num_values as IdxSize
645 },
646
647 GroupsType::Slice {
649 groups,
650 overlapping: false,
651 monotonic: _,
652 } => groups.iter().map(|[_, l]| *l as usize).sum::<usize>() == num_values,
653 }
654 }
655
656 pub(crate) fn set_groups_for_undefined_agg_states(&mut self) {
659 match &self.state {
660 AggState::AggregatedList(_) | AggState::NotAggregated(_) => {},
661 AggState::AggregatedScalar(c) => {
662 assert_eq!(self.update_groups, UpdateGroups::No);
663 self.groups = Cow::Owned({
664 let groups = (0..c.len() as IdxSize).map(|i| [i, 1]).collect();
665 GroupsType::new_slice(groups, false, true).into_sliceable()
666 });
667 self.set_original_len(false);
668 },
669 AggState::LiteralScalar(c) => {
670 assert_eq!(c.len(), 1);
671 assert_eq!(self.update_groups, UpdateGroups::No);
672 self.groups = Cow::Owned({
673 let groups = vec![[0, 1]; self.groups.len()];
674 GroupsType::new_slice(groups, true, true).into_sliceable()
675 });
676 self.set_original_len(false);
677 },
678 }
679 }
680
681 pub fn into_static(&self) -> AggregationContext<'static> {
682 let groups: GroupPositions = GroupPositions::to_owned(&self.groups);
683 let groups: Cow<'static, GroupPositions> = Cow::Owned(groups);
684 AggregationContext {
685 state: self.state.clone(),
686 groups,
687 update_groups: self.update_groups,
688 original_len: self.original_len,
689 }
690 }
691}
692
693pub trait PhysicalExpr: Send + Sync {
696 fn as_expression(&self) -> Option<&Expr> {
697 None
698 }
699
700 fn as_column(&self) -> Option<PlSmallStr> {
701 None
702 }
703
704 fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult<Column> {
709 self.evaluate_impl(df, state).map_err(|e| {
710 if let Some(expr) = self.as_expression() {
711 e.with_expr_context(expr.to_string().into())
712 } else {
713 e
714 }
715 })
716 }
717
718 fn evaluate_impl(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;
719
720 #[allow(clippy::ptr_arg)]
743 fn evaluate_on_groups<'a>(
744 &self,
745 df: &DataFrame,
746 groups: &'a GroupPositions,
747 state: &ExecutionState,
748 ) -> PolarsResult<AggregationContext<'a>> {
749 self.evaluate_on_groups_impl(df, groups, state)
750 .map_err(|e| {
751 if let Some(expr) = self.as_expression() {
752 e.with_expr_context(expr.to_string().into())
753 } else {
754 e
755 }
756 })
757 }
758
759 #[allow(clippy::ptr_arg)]
760 fn evaluate_on_groups_impl<'a>(
761 &self,
762 df: &DataFrame,
763 groups: &'a GroupPositions,
764 state: &ExecutionState,
765 ) -> PolarsResult<AggregationContext<'a>>;
766
767 fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;
769
770 fn is_literal(&self) -> bool {
771 false
772 }
773 fn is_scalar(&self) -> bool;
774}
775
776impl Display for &dyn PhysicalExpr {
777 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
778 match self.as_expression() {
779 None => Ok(()),
780 Some(e) => write!(f, "{e:?}"),
781 }
782 }
783}
784
/// Adapter exposing a [`PhysicalExpr`] through the [`PhysicalIoExpr`] interface.
pub struct PhysicalIoHelper {
    /// The wrapped expression; expected to evaluate to a boolean column.
    pub expr: Arc<dyn PhysicalExpr>,
    /// Whether `expr` contains a window (`over`) or rolling expression. When set, the
    /// execution state is flagged before evaluation.
    pub has_window_function: bool,
}
792
793impl PhysicalIoExpr for PhysicalIoHelper {
794 fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
795 let mut state: ExecutionState = Default::default();
796 if self.has_window_function {
797 state.insert_has_window_function_flag();
798 }
799 self.expr.evaluate(df, &state).map(|c| {
800 debug_assert_eq!(c.dtype(), &DataType::Boolean);
802 (if c.len() == 1 && df.height() != 1 {
803 c.new_from_index(0, df.height())
805 } else {
806 c
807 })
808 .take_materialized_series()
809 })
810 }
811}
812
813pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
814 let has_window_function = if let Some(expr) = expr.as_expression() {
815 expr.into_iter().any(|expr| {
816 #[cfg(feature = "dynamic_group_by")]
817 if matches!(expr, Expr::Rolling { .. }) {
818 return true;
819 }
820
821 matches!(expr, Expr::Over { .. })
822 })
823 } else {
824 false
825 };
826 Arc::new(PhysicalIoHelper {
827 expr,
828 has_window_function,
829 }) as Arc<dyn PhysicalIoExpr>
830}