polars_expr/expressions/
mod.rs

1mod aggregation;
2mod alias;
3mod apply;
4mod binary;
5mod cast;
6mod column;
7mod count;
8mod eval;
9mod filter;
10mod gather;
11mod group_iter;
12mod literal;
13#[cfg(feature = "dynamic_group_by")]
14mod rolling;
15mod slice;
16mod sort;
17mod sortby;
18mod ternary;
19mod window;
20
21use std::borrow::Cow;
22use std::fmt::{Display, Formatter};
23
24pub(crate) use aggregation::*;
25pub(crate) use alias::*;
26pub(crate) use apply::*;
27use arrow::array::ArrayRef;
28use arrow::legacy::utils::CustomIterTools;
29pub(crate) use binary::*;
30pub(crate) use cast::*;
31pub(crate) use column::*;
32pub(crate) use count::*;
33pub(crate) use eval::*;
34pub(crate) use filter::*;
35pub(crate) use gather::*;
36pub(crate) use literal::*;
37use polars_core::prelude::*;
38use polars_io::predicates::PhysicalIoExpr;
39use polars_plan::prelude::*;
40#[cfg(feature = "dynamic_group_by")]
41pub(crate) use rolling::RollingExpr;
42pub(crate) use slice::*;
43pub(crate) use sort::*;
44pub(crate) use sortby::*;
45pub(crate) use ternary::*;
46pub use window::window_function_format_order_by;
47pub(crate) use window::*;
48
49use crate::state::ExecutionState;
50
/// The state of the values held by an [`AggregationContext`]: whether they still have to be
/// aggregated per group, have already been aggregated (to a list or to a scalar), or are a
/// literal that is independent of the groups.
#[derive(Clone, Debug)]
pub enum AggState {
    /// Already aggregated: `.agg_list(group_tuples)` is called
    /// and produced a column of dtype `List` (one list per group).
    AggregatedList(Column),
    /// Already aggregated: `.agg` is called on an aggregation
    /// that produces a scalar (one value per group).
    /// Think of `sum`, `mean`, `variance` like aggregations.
    AggregatedScalar(Column),
    /// Not yet aggregated: `agg_list` still has to be called.
    NotAggregated(Column),
    /// A literal that is not tied to the groups. When the context is aggregated or finalized,
    /// its first value is repeated once per group.
    Literal(Column),
}
64
65impl AggState {
66    fn try_map<F>(&self, func: F) -> PolarsResult<Self>
67    where
68        F: FnOnce(&Column) -> PolarsResult<Column>,
69    {
70        Ok(match self {
71            AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
72            AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
73            AggState::Literal(c) => AggState::Literal(func(c)?),
74            AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
75        })
76    }
77
78    fn is_scalar(&self) -> bool {
79        matches!(self, Self::AggregatedScalar(_))
80    }
81}
82
// Lazy update strategy for the group tuples of an `AggregationContext`; the pending update is
// applied the next time `AggregationContext::groups` is called.
#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(PartialEq, Clone, Copy)]
pub(crate) enum UpdateGroups {
    /// Don't update groups.
    No,
    /// Use the length of the current groups to determine new sorted slice groups. This is
    /// preferred for performance.
    WithGroupsLen,
    /// Use the list offsets of the values to determine the new group lengths.
    /// This one should be used when the length has changed. Note that the values
    /// must be in a list (aggregated) state, or else it will panic.
    WithSeriesLen,
}
97
/// Values of an expression evaluated in a group-by context, together with the group tuples
/// they belong to.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// Can be in one of the four [`AggState`] states:
    /// 1. not aggregated (flat; still needs the group tuples to aggregate)
    /// 2. already aggregated as list
    /// 3. already aggregated to a scalar per group
    /// 4. a literal
    state: AggState,
    /// Group tuples for the `state`.
    groups: Cow<'a, GroupPositions>,
    /// If the group tuples are already used in a level above
    /// and the series is exploded, the group tuples are sorted,
    /// e.g. the exploded Series is grouped per group.
    sorted: bool,
    /// This is used to determine if we need to update the groups
    /// into sorted groups. We do this lazily, so that this work is only
    /// done when the groups are needed.
    update_groups: UpdateGroups,
    /// This is true when the Series and Groups still have all
    /// their original values. Not the case when filtered.
    original_len: bool,
}
118
119impl<'a> AggregationContext<'a> {
120    pub(crate) fn dtype(&self) -> DataType {
121        match &self.state {
122            AggState::Literal(s) => s.dtype().clone(),
123            AggState::AggregatedList(s) => s.list().unwrap().inner_dtype().clone(),
124            AggState::AggregatedScalar(s) => s.dtype().clone(),
125            AggState::NotAggregated(s) => s.dtype().clone(),
126        }
127    }
128    pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
129        match self.update_groups {
130            UpdateGroups::No => {},
131            UpdateGroups::WithGroupsLen => {
132                // the groups are unordered
133                // and the series is aggregated with this groups
134                // so we need to recreate new grouptuples that
135                // match the exploded Series
136                let mut offset = 0 as IdxSize;
137
138                match self.groups.as_ref().as_ref() {
139                    GroupsType::Idx(groups) => {
140                        let groups = groups
141                            .iter()
142                            .map(|g| {
143                                let len = g.1.len() as IdxSize;
144                                let new_offset = offset + len;
145                                let out = [offset, len];
146                                offset = new_offset;
147                                out
148                            })
149                            .collect();
150                        self.groups = Cow::Owned(
151                            GroupsType::Slice {
152                                groups,
153                                rolling: false,
154                            }
155                            .into_sliceable(),
156                        )
157                    },
158                    // sliced groups are already in correct order
159                    GroupsType::Slice { .. } => {},
160                }
161                self.update_groups = UpdateGroups::No;
162            },
163            UpdateGroups::WithSeriesLen => {
164                let s = self.get_values().clone();
165                self.det_groups_from_list(s.as_materialized_series());
166            },
167        }
168        &self.groups
169    }
170
171    pub(crate) fn get_values(&self) -> &Column {
172        match &self.state {
173            AggState::NotAggregated(s)
174            | AggState::AggregatedScalar(s)
175            | AggState::AggregatedList(s) => s,
176            AggState::Literal(s) => s,
177        }
178    }
179
180    pub fn agg_state(&self) -> &AggState {
181        &self.state
182    }
183
184    pub(crate) fn is_not_aggregated(&self) -> bool {
185        matches!(
186            &self.state,
187            AggState::NotAggregated(_) | AggState::Literal(_)
188        )
189    }
190
191    pub(crate) fn is_aggregated(&self) -> bool {
192        !self.is_not_aggregated()
193    }
194
195    pub(crate) fn is_literal(&self) -> bool {
196        matches!(self.state, AggState::Literal(_))
197    }
198
199    /// # Arguments
200    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
201    ///   the columns dtype)
202    fn new(
203        column: Column,
204        groups: Cow<'a, GroupPositions>,
205        aggregated: bool,
206    ) -> AggregationContext<'a> {
207        let series = match (aggregated, column.dtype()) {
208            (true, &DataType::List(_)) => {
209                assert_eq!(column.len(), groups.len());
210                AggState::AggregatedList(column)
211            },
212            (true, _) => {
213                assert_eq!(column.len(), groups.len());
214                AggState::AggregatedScalar(column)
215            },
216            _ => AggState::NotAggregated(column),
217        };
218
219        Self {
220            state: series,
221            groups,
222            sorted: false,
223            update_groups: UpdateGroups::No,
224            original_len: true,
225        }
226    }
227
228    fn with_agg_state(&mut self, agg_state: AggState) {
229        self.state = agg_state;
230    }
231
232    fn from_agg_state(
233        agg_state: AggState,
234        groups: Cow<'a, GroupPositions>,
235    ) -> AggregationContext<'a> {
236        Self {
237            state: agg_state,
238            groups,
239            sorted: false,
240            update_groups: UpdateGroups::No,
241            original_len: true,
242        }
243    }
244
245    fn from_literal(lit: Column, groups: Cow<'a, GroupPositions>) -> AggregationContext<'a> {
246        Self {
247            state: AggState::Literal(lit),
248            groups,
249            sorted: false,
250            update_groups: UpdateGroups::No,
251            original_len: true,
252        }
253    }
254
255    pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
256        self.original_len = original_len;
257        self
258    }
259
260    pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
261        self.update_groups = update;
262        self
263    }
264
265    fn det_groups_from_list(&mut self, s: &Series) {
266        let mut offset = 0 as IdxSize;
267        let list = s
268            .list()
269            .expect("impl error, should be a list at this point");
270
271        match list.chunks().len() {
272            1 => {
273                let arr = list.downcast_iter().next().unwrap();
274                let offsets = arr.offsets().as_slice();
275
276                let mut previous = 0i64;
277                let groups = offsets[1..]
278                    .iter()
279                    .map(|&o| {
280                        let len = (o - previous) as IdxSize;
281                        // explode will fill empty rows with null, so we must increment the group
282                        // offset accordingly
283                        let new_offset = offset + len + (len == 0) as IdxSize;
284
285                        previous = o;
286                        let out = [offset, len];
287                        offset = new_offset;
288                        out
289                    })
290                    .collect_trusted();
291                self.groups = Cow::Owned(
292                    GroupsType::Slice {
293                        groups,
294                        rolling: false,
295                    }
296                    .into_sliceable(),
297                );
298            },
299            _ => {
300                let groups = {
301                    self.get_values()
302                        .list()
303                        .expect("impl error, should be a list at this point")
304                        .amortized_iter()
305                        .map(|s| {
306                            if let Some(s) = s {
307                                let len = s.as_ref().len() as IdxSize;
308                                let new_offset = offset + len;
309                                let out = [offset, len];
310                                offset = new_offset;
311                                out
312                            } else {
313                                [offset, 0]
314                            }
315                        })
316                        .collect_trusted()
317                };
318                self.groups = Cow::Owned(
319                    GroupsType::Slice {
320                        groups,
321                        rolling: false,
322                    }
323                    .into_sliceable(),
324                );
325            },
326        }
327        self.update_groups = UpdateGroups::No;
328    }
329
330    /// # Arguments
331    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
332    ///   the columns dtype)
333    pub(crate) fn with_values(
334        &mut self,
335        column: Column,
336        aggregated: bool,
337        expr: Option<&Expr>,
338    ) -> PolarsResult<&mut Self> {
339        self.with_values_and_args(
340            column,
341            aggregated,
342            expr,
343            false,
344            self.agg_state().is_scalar(),
345        )
346    }
347
348    pub(crate) fn with_values_and_args(
349        &mut self,
350        column: Column,
351        aggregated: bool,
352        expr: Option<&Expr>,
353        // if the applied function was a `map` instead of an `apply`
354        // this will keep functions applied over literals as literals: F(lit) = lit
355        mapped: bool,
356        returns_scalar: bool,
357    ) -> PolarsResult<&mut Self> {
358        self.state = match (aggregated, column.dtype()) {
359            (true, &DataType::List(_)) if !returns_scalar => {
360                if column.len() != self.groups.len() {
361                    let fmt_expr = if let Some(e) = expr {
362                        format!("'{e:?}' ")
363                    } else {
364                        String::new()
365                    };
366                    polars_bail!(
367                        ComputeError:
368                        "aggregation expression '{}' produced a different number of elements: {} \
369                        than the number of groups: {} (this is likely invalid)",
370                        fmt_expr, column.len(), self.groups.len(),
371                    );
372                }
373                AggState::AggregatedList(column)
374            },
375            (true, _) => AggState::AggregatedScalar(column),
376            _ => {
377                match self.state {
378                    // already aggregated to sum, min even this series was flattened it never could
379                    // retrieve the length before grouping, so it stays  in this state.
380                    AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
381                    // applying a function on a literal, keeps the literal state
382                    AggState::Literal(_) if column.len() == 1 && mapped => {
383                        AggState::Literal(column)
384                    },
385                    _ => AggState::NotAggregated(column.into_column()),
386                }
387            },
388        };
389        Ok(self)
390    }
391
392    pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
393        self.state = AggState::Literal(column);
394        self
395    }
396
397    /// Update the group tuples
398    pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
399        if let AggState::AggregatedList(_) = self.agg_state() {
400            // In case of new groups, a series always needs to be flattened
401            self.with_values(self.flat_naive().into_owned(), false, None)
402                .unwrap();
403        }
404        self.groups = Cow::Owned(groups);
405        // make sure that previous setting is not used
406        self.update_groups = UpdateGroups::No;
407        self
408    }
409
410    pub(crate) fn _implode_no_agg(&mut self) {
411        match self.state.clone() {
412            AggState::NotAggregated(_) => {
413                let _ = self.aggregated();
414                let AggState::AggregatedList(s) = self.state.clone() else {
415                    unreachable!()
416                };
417                self.state = AggState::AggregatedScalar(s);
418            },
419            AggState::AggregatedList(s) => {
420                self.state = AggState::AggregatedScalar(s);
421            },
422            _ => unreachable!("should only be called in non-agg/list-agg state by aggregation.rs"),
423        }
424    }
425
426    /// Get the aggregated version of the series.
427    pub fn aggregated(&mut self) -> Column {
428        // we clone, because we only want to call `self.groups()` if needed.
429        // self groups may instantiate new groups and thus can be expensive.
430        match self.state.clone() {
431            AggState::NotAggregated(s) => {
432                // The groups are determined lazily and in case of a flat/non-aggregated
433                // series we use the groups to aggregate the list
434                // because this is lazy, we first must to update the groups
435                // by calling .groups()
436                self.groups();
437                #[cfg(debug_assertions)]
438                {
439                    if self.groups.len() > s.len() {
440                        polars_warn!(
441                            "groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by"
442                        )
443                    }
444                }
445
446                // SAFETY:
447                // groups are in bounds
448                let out = unsafe { s.agg_list(&self.groups) };
449                self.state = AggState::AggregatedList(out.clone());
450
451                self.sorted = true;
452                self.update_groups = UpdateGroups::WithGroupsLen;
453                out
454            },
455            AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
456            AggState::Literal(s) => {
457                self.groups();
458                let rows = self.groups.len();
459                let s = s.new_from_index(0, rows);
460                let out = s
461                    .reshape_list(&[
462                        ReshapeDimension::new_dimension(rows as u64),
463                        ReshapeDimension::Infer,
464                    ])
465                    .unwrap();
466                self.state = AggState::AggregatedList(out.clone());
467                out.into_column()
468            },
469        }
470    }
471
472    /// Get the final aggregated version of the series.
473    pub fn finalize(&mut self) -> Column {
474        // we clone, because we only want to call `self.groups()` if needed.
475        // self groups may instantiate new groups and thus can be expensive.
476        match &self.state {
477            AggState::Literal(c) => {
478                let c = c.clone();
479                self.groups();
480                let rows = self.groups.len();
481                c.new_from_index(0, rows)
482            },
483            _ => self.aggregated(),
484        }
485    }
486
487    // If a binary or ternary function has both of these branches true, it should
488    // flatten the list
489    fn arity_should_explode(&self) -> bool {
490        use AggState::*;
491        match self.agg_state() {
492            Literal(s) => s.len() == 1,
493            AggregatedScalar(_) => true,
494            _ => false,
495        }
496    }
497
498    pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
499        let _ = self.groups();
500        let groups = self.groups;
501        match self.state {
502            AggState::NotAggregated(c) => (c, groups),
503            AggState::AggregatedScalar(c) => (c, groups),
504            AggState::Literal(c) => (c, groups),
505            AggState::AggregatedList(c) => {
506                let flattened = c.explode(false).unwrap();
507                let groups = groups.into_owned();
508                // unroll the possible flattened state
509                // say we have groups with overlapping windows:
510                //
511                // offset, len
512                // 0, 1
513                // 0, 2
514                // 0, 4
515                //
516                // gets aggregation
517                //
518                // [0]
519                // [0, 1],
520                // [0, 1, 2, 3]
521                //
522                // before aggregation the column was
523                // [0, 1, 2, 3]
524                // but explode on this list yields
525                // [0, 0, 1, 0, 1, 2, 3]
526                //
527                // so we unroll the groups as
528                //
529                // [0, 1]
530                // [1, 2]
531                // [3, 4]
532                let groups = groups.unroll();
533                (flattened, Cow::Owned(groups))
534            },
535        }
536    }
537
538    /// Get the not-aggregated version of the series.
539    /// Note that we call it naive, because if a previous expr
540    /// has filtered or sorted this, this information is in the
541    /// group tuples not the flattened series.
542    pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
543        match &self.state {
544            AggState::NotAggregated(c) => Cow::Borrowed(c),
545            AggState::AggregatedList(c) => {
546                #[cfg(debug_assertions)]
547                {
548                    // panic so we find cases where we accidentally explode overlapping groups
549                    // we don't want this as this can create a lot of data
550                    if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() {
551                        panic!(
552                            "implementation error, polars should not hit this branch for overlapping groups"
553                        )
554                    }
555                }
556
557                Cow::Owned(c.explode(false).unwrap())
558            },
559            AggState::AggregatedScalar(c) => Cow::Borrowed(c),
560            AggState::Literal(c) => Cow::Borrowed(c),
561        }
562    }
563
564    /// Take the series.
565    pub(crate) fn take(&mut self) -> Column {
566        let c = match &mut self.state {
567            AggState::NotAggregated(c)
568            | AggState::AggregatedScalar(c)
569            | AggState::AggregatedList(c) => c,
570            AggState::Literal(c) => c,
571        };
572        std::mem::take(c)
573    }
574}
575
/// A physical (executable) expression: takes a DataFrame and evaluates the expression on it,
/// either over the whole frame (`evaluate`) or per group (`evaluate_on_groups`).
/// Implemented by the physical counterparts of expressions such as column references and
/// comparisons (`lt`, `eq`, etc).
pub trait PhysicalExpr: Send + Sync {
    /// The logical expression this physical expression was built from, if it is available.
    /// Defaults to `None`.
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    /// Take a DataFrame and evaluate the expression.
    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;

    /// Attempt to cheaply evaluate this expression in-line without a DataFrame context.
    /// This is used by StatsEvaluator when skipping files / row groups using a predicate.
    /// TODO: Maybe in the future we can do this evaluation in-line at the optimizer stage?
    ///
    /// Do not implement this directly - instead implement `evaluate_inline_impl`
    fn evaluate_inline(&self) -> Option<Column> {
        self.evaluate_inline_impl(4)
    }

    /// Implementation of `evaluate_inline`; `evaluate_inline` calls it with a depth limit of 4.
    /// Defaults to `None` (cannot be evaluated in-line).
    fn evaluate_inline_impl(&self, _depth_limit: u8) -> Option<Column> {
        None
    }

    /// Some expressions that are not aggregations can be done per group.
    /// Think of sort, slice, filter, shift, etc.
    ///
    /// This method is called by an aggregation function.
    ///
    /// In case of a simple expr, like 'column', the groups are ignored and the column is returned.
    /// In case of an expr where group behavior makes sense, this method is called.
    /// For a filter operation for instance, a Series is created per groups and filtered.
    ///
    /// An implementation of this method may apply an aggregation on the groups only. For instance
    /// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per
    /// group. The implementation then has to return the `Series` exploded (because a later aggregation
    /// will use the group tuples to aggregate). The group tuples also have to be updated, because
    /// aggregation to a list sorts the exploded `Series` by group.
    ///
    /// This has some gotchas. An implementation may also change the group tuples instead of
    /// the `Series`.
    ///
    // we allow this because we pass the vec to the Cow
    // Note to self: Don't be smart and dispatch to evaluate as default implementation
    // this means filters will be incorrect and lead to invalid results down the line
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>>;

    /// Get the output field of this expr
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;

    /// Convert to a partitioned aggregator. Defaults to `None` (not supported).
    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        None
    }

    /// Whether this expression is a literal. Defaults to `false`.
    fn is_literal(&self) -> bool {
        false
    }
    /// Whether this expression produces a scalar. NOTE(review): presumably one value
    /// regardless of input length; confirm against the implementations.
    fn is_scalar(&self) -> bool;
}
643
644impl Display for &dyn PhysicalExpr {
645    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
646        match self.as_expression() {
647            None => Ok(()),
648            Some(e) => write!(f, "{e:?}"),
649        }
650    }
651}
652
/// Wrapper struct that allows us to use a PhysicalExpr in polars-io.
///
/// This is used to filter rows during the scan of a file.
pub struct PhysicalIoHelper {
    /// The expression evaluated by `evaluate_io`.
    pub expr: Arc<dyn PhysicalExpr>,
    /// When `true`, `evaluate_io` sets the window-function flag on its `ExecutionState`
    /// before evaluating `expr`.
    pub has_window_function: bool,
}
660
661impl PhysicalIoExpr for PhysicalIoHelper {
662    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
663        let mut state: ExecutionState = Default::default();
664        if self.has_window_function {
665            state.insert_has_window_function_flag();
666        }
667        self.expr
668            .evaluate(df, &state)
669            .map(|c| c.take_materialized_series())
670    }
671}
672
673pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
674    let has_window_function = if let Some(expr) = expr.as_expression() {
675        expr.into_iter()
676            .any(|expr| matches!(expr, Expr::Window { .. }))
677    } else {
678        false
679    };
680    Arc::new(PhysicalIoHelper {
681        expr,
682        has_window_function,
683    }) as Arc<dyn PhysicalIoExpr>
684}
685
/// A physical expression whose aggregation can be computed in two phases: first per
/// partition (`evaluate_partitioned`), then merged into the final result (`finalize`).
pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {
    /// This is called in partitioned aggregation.
    /// Partitioned results may differ from aggregation results.
    /// For instance, for a `mean` operation a partitioned result
    /// needs to return the `sum` and the `valid_count` (length - null count).
    ///
    /// A final aggregation can then take the sum of sums and sum of valid_counts
    /// to produce a final mean.
    #[allow(clippy::ptr_arg)]
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column>;

    /// Called to merge all the partitioned results in a final aggregate.
    #[allow(clippy::ptr_arg)]
    fn finalize(
        &self,
        partitioned: Column,
        groups: &GroupPositions,
        state: &ExecutionState,
    ) -> PolarsResult<Column>;
}