polars_expr/expressions/mod.rs

1mod aggregation;
2mod alias;
3mod apply;
4mod binary;
5mod cast;
6mod column;
7mod count;
8mod filter;
9mod gather;
10mod group_iter;
11mod literal;
12#[cfg(feature = "dynamic_group_by")]
13mod rolling;
14mod slice;
15mod sort;
16mod sortby;
17mod ternary;
18mod window;
19
20use std::borrow::Cow;
21use std::fmt::{Display, Formatter};
22
23pub(crate) use aggregation::*;
24pub(crate) use alias::*;
25pub(crate) use apply::*;
26use arrow::array::ArrayRef;
27use arrow::legacy::utils::CustomIterTools;
28pub(crate) use binary::*;
29pub(crate) use cast::*;
30pub(crate) use column::*;
31pub(crate) use count::*;
32pub(crate) use filter::*;
33pub(crate) use gather::*;
34pub(crate) use literal::*;
35use polars_core::prelude::*;
36use polars_io::predicates::PhysicalIoExpr;
37use polars_plan::prelude::*;
38#[cfg(feature = "dynamic_group_by")]
39pub(crate) use rolling::RollingExpr;
40pub(crate) use slice::*;
41pub(crate) use sort::*;
42pub(crate) use sortby::*;
43pub(crate) use ternary::*;
44pub use window::window_function_format_order_by;
45pub(crate) use window::*;
46
47use crate::state::ExecutionState;
48
49#[derive(Clone, Debug)]
50pub enum AggState {
51    /// Already aggregated: `.agg_list(group_tuples)` is called
52    /// and produced a `Series` of dtype `List`
53    AggregatedList(Column),
54    /// Already aggregated: `.agg` is called on an aggregation
55    /// that produces a scalar.
56    /// think of `sum`, `mean`, `variance` like aggregations.
57    AggregatedScalar(Column),
58    /// Not yet aggregated: `agg_list` still has to be called.
59    NotAggregated(Column),
60    Literal(Column),
61}
62
63impl AggState {
64    fn try_map<F>(&self, func: F) -> PolarsResult<Self>
65    where
66        F: FnOnce(&Column) -> PolarsResult<Column>,
67    {
68        Ok(match self {
69            AggState::AggregatedList(c) => AggState::AggregatedList(func(c)?),
70            AggState::AggregatedScalar(c) => AggState::AggregatedScalar(func(c)?),
71            AggState::Literal(c) => AggState::Literal(func(c)?),
72            AggState::NotAggregated(c) => AggState::NotAggregated(func(c)?),
73        })
74    }
75}
76
/// Strategy for lazily rebuilding the group tuples of an [`AggregationContext`].
///
/// The update is deferred until the groups are actually requested.
#[derive(PartialEq, Clone, Copy)]
#[cfg_attr(debug_assertions, derive(Debug))]
pub(crate) enum UpdateGroups {
    /// Leave the groups untouched.
    No,
    /// Build new, ordered groups from the lengths of the current groups. This is the
    /// preferred option for performance.
    WithGroupsLen,
    /// Derive the new group lengths from the list offsets of the values. Use this one when
    /// the lengths have changed. The values must be in an aggregated (list) state, otherwise
    /// this panics.
    WithSeriesLen,
}
91
92#[cfg_attr(debug_assertions, derive(Debug))]
93pub struct AggregationContext<'a> {
94    /// Can be in one of two states
95    /// 1. already aggregated as list
96    /// 2. flat (still needs the grouptuples to aggregate)
97    state: AggState,
98    /// group tuples for AggState
99    groups: Cow<'a, GroupPositions>,
100    /// if the group tuples are already used in a level above
101    /// and the series is exploded, the group tuples are sorted
102    /// e.g. the exploded Series is grouped per group.
103    sorted: bool,
104    /// This is used to determined if we need to update the groups
105    /// into a sorted groups. We do this lazily, so that this work only is
106    /// done when the groups are needed
107    update_groups: UpdateGroups,
108    /// This is true when the Series and Groups still have all
109    /// their original values. Not the case when filtered
110    original_len: bool,
111}
112
113impl<'a> AggregationContext<'a> {
114    pub(crate) fn dtype(&self) -> DataType {
115        match &self.state {
116            AggState::Literal(s) => s.dtype().clone(),
117            AggState::AggregatedList(s) => s.list().unwrap().inner_dtype().clone(),
118            AggState::AggregatedScalar(s) => s.dtype().clone(),
119            AggState::NotAggregated(s) => s.dtype().clone(),
120        }
121    }
122    pub(crate) fn groups(&mut self) -> &Cow<'a, GroupPositions> {
123        match self.update_groups {
124            UpdateGroups::No => {},
125            UpdateGroups::WithGroupsLen => {
126                // the groups are unordered
127                // and the series is aggregated with this groups
128                // so we need to recreate new grouptuples that
129                // match the exploded Series
130                let mut offset = 0 as IdxSize;
131
132                match self.groups.as_ref().as_ref() {
133                    GroupsType::Idx(groups) => {
134                        let groups = groups
135                            .iter()
136                            .map(|g| {
137                                let len = g.1.len() as IdxSize;
138                                let new_offset = offset + len;
139                                let out = [offset, len];
140                                offset = new_offset;
141                                out
142                            })
143                            .collect();
144                        self.groups = Cow::Owned(
145                            GroupsType::Slice {
146                                groups,
147                                rolling: false,
148                            }
149                            .into_sliceable(),
150                        )
151                    },
152                    // sliced groups are already in correct order
153                    GroupsType::Slice { .. } => {},
154                }
155                self.update_groups = UpdateGroups::No;
156            },
157            UpdateGroups::WithSeriesLen => {
158                let s = self.get_values().clone();
159                self.det_groups_from_list(s.as_materialized_series());
160            },
161        }
162        &self.groups
163    }
164
165    pub(crate) fn get_values(&self) -> &Column {
166        match &self.state {
167            AggState::NotAggregated(s)
168            | AggState::AggregatedScalar(s)
169            | AggState::AggregatedList(s) => s,
170            AggState::Literal(s) => s,
171        }
172    }
173
174    pub fn agg_state(&self) -> &AggState {
175        &self.state
176    }
177
178    pub(crate) fn is_not_aggregated(&self) -> bool {
179        matches!(
180            &self.state,
181            AggState::NotAggregated(_) | AggState::Literal(_)
182        )
183    }
184
185    pub(crate) fn is_aggregated(&self) -> bool {
186        !self.is_not_aggregated()
187    }
188
189    pub(crate) fn is_literal(&self) -> bool {
190        matches!(self.state, AggState::Literal(_))
191    }
192
193    /// # Arguments
194    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
195    ///   the columns dtype)
196    fn new(
197        column: Column,
198        groups: Cow<'a, GroupPositions>,
199        aggregated: bool,
200    ) -> AggregationContext<'a> {
201        let series = match (aggregated, column.dtype()) {
202            (true, &DataType::List(_)) => {
203                assert_eq!(column.len(), groups.len());
204                AggState::AggregatedList(column)
205            },
206            (true, _) => {
207                assert_eq!(column.len(), groups.len());
208                AggState::AggregatedScalar(column)
209            },
210            _ => AggState::NotAggregated(column),
211        };
212
213        Self {
214            state: series,
215            groups,
216            sorted: false,
217            update_groups: UpdateGroups::No,
218            original_len: true,
219        }
220    }
221
222    fn with_agg_state(&mut self, agg_state: AggState) {
223        self.state = agg_state;
224    }
225
226    fn from_agg_state(
227        agg_state: AggState,
228        groups: Cow<'a, GroupPositions>,
229    ) -> AggregationContext<'a> {
230        Self {
231            state: agg_state,
232            groups,
233            sorted: false,
234            update_groups: UpdateGroups::No,
235            original_len: true,
236        }
237    }
238
239    fn from_literal(lit: Column, groups: Cow<'a, GroupPositions>) -> AggregationContext<'a> {
240        Self {
241            state: AggState::Literal(lit),
242            groups,
243            sorted: false,
244            update_groups: UpdateGroups::No,
245            original_len: true,
246        }
247    }
248
249    pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
250        self.original_len = original_len;
251        self
252    }
253
254    pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
255        self.update_groups = update;
256        self
257    }
258
259    pub(crate) fn det_groups_from_list(&mut self, s: &Series) {
260        let mut offset = 0 as IdxSize;
261        let list = s
262            .list()
263            .expect("impl error, should be a list at this point");
264
265        match list.chunks().len() {
266            1 => {
267                let arr = list.downcast_iter().next().unwrap();
268                let offsets = arr.offsets().as_slice();
269
270                let mut previous = 0i64;
271                let groups = offsets[1..]
272                    .iter()
273                    .map(|&o| {
274                        let len = (o - previous) as IdxSize;
275                        // explode will fill empty rows with null, so we must increment the group
276                        // offset accordingly
277                        let new_offset = offset + len + (len == 0) as IdxSize;
278
279                        previous = o;
280                        let out = [offset, len];
281                        offset = new_offset;
282                        out
283                    })
284                    .collect_trusted();
285                self.groups = Cow::Owned(
286                    GroupsType::Slice {
287                        groups,
288                        rolling: false,
289                    }
290                    .into_sliceable(),
291                );
292            },
293            _ => {
294                let groups = {
295                    self.get_values()
296                        .list()
297                        .expect("impl error, should be a list at this point")
298                        .amortized_iter()
299                        .map(|s| {
300                            if let Some(s) = s {
301                                let len = s.as_ref().len() as IdxSize;
302                                let new_offset = offset + len;
303                                let out = [offset, len];
304                                offset = new_offset;
305                                out
306                            } else {
307                                [offset, 0]
308                            }
309                        })
310                        .collect_trusted()
311                };
312                self.groups = Cow::Owned(
313                    GroupsType::Slice {
314                        groups,
315                        rolling: false,
316                    }
317                    .into_sliceable(),
318                );
319            },
320        }
321        self.update_groups = UpdateGroups::No;
322    }
323
324    /// # Arguments
325    /// - `aggregated` sets if the Series is a list due to aggregation (could also be a list because its
326    ///   the columns dtype)
327    pub(crate) fn with_values(
328        &mut self,
329        column: Column,
330        aggregated: bool,
331        expr: Option<&Expr>,
332    ) -> PolarsResult<&mut Self> {
333        self.with_values_and_args(column, aggregated, expr, false)
334    }
335
336    pub(crate) fn with_values_and_args(
337        &mut self,
338        column: Column,
339        aggregated: bool,
340        expr: Option<&Expr>,
341        // if the applied function was a `map` instead of an `apply`
342        // this will keep functions applied over literals as literals: F(lit) = lit
343        mapped: bool,
344    ) -> PolarsResult<&mut Self> {
345        self.state = match (aggregated, column.dtype()) {
346            (true, &DataType::List(_)) => {
347                if column.len() != self.groups.len() {
348                    let fmt_expr = if let Some(e) = expr {
349                        format!("'{e:?}' ")
350                    } else {
351                        String::new()
352                    };
353                    polars_bail!(
354                        ComputeError:
355                        "aggregation expression '{}' produced a different number of elements: {} \
356                        than the number of groups: {} (this is likely invalid)",
357                        fmt_expr, column.len(), self.groups.len(),
358                    );
359                }
360                AggState::AggregatedList(column)
361            },
362            (true, _) => AggState::AggregatedScalar(column),
363            _ => {
364                match self.state {
365                    // already aggregated to sum, min even this series was flattened it never could
366                    // retrieve the length before grouping, so it stays  in this state.
367                    AggState::AggregatedScalar(_) => AggState::AggregatedScalar(column),
368                    // applying a function on a literal, keeps the literal state
369                    AggState::Literal(_) if column.len() == 1 && mapped => {
370                        AggState::Literal(column)
371                    },
372                    _ => AggState::NotAggregated(column.into_column()),
373                }
374            },
375        };
376        Ok(self)
377    }
378
379    pub(crate) fn with_literal(&mut self, column: Column) -> &mut Self {
380        self.state = AggState::Literal(column);
381        self
382    }
383
384    /// Update the group tuples
385    pub(crate) fn with_groups(&mut self, groups: GroupPositions) -> &mut Self {
386        if let AggState::AggregatedList(_) = self.agg_state() {
387            // In case of new groups, a series always needs to be flattened
388            self.with_values(self.flat_naive().into_owned(), false, None)
389                .unwrap();
390        }
391        self.groups = Cow::Owned(groups);
392        // make sure that previous setting is not used
393        self.update_groups = UpdateGroups::No;
394        self
395    }
396
397    /// Get the aggregated version of the series.
398    pub fn aggregated(&mut self) -> Column {
399        // we clone, because we only want to call `self.groups()` if needed.
400        // self groups may instantiate new groups and thus can be expensive.
401        match self.state.clone() {
402            AggState::NotAggregated(s) => {
403                // The groups are determined lazily and in case of a flat/non-aggregated
404                // series we use the groups to aggregate the list
405                // because this is lazy, we first must to update the groups
406                // by calling .groups()
407                self.groups();
408                #[cfg(debug_assertions)]
409                {
410                    if self.groups.len() > s.len() {
411                        polars_warn!("groups may be out of bounds; more groups than elements in a series is only possible in dynamic group_by")
412                    }
413                }
414
415                // SAFETY:
416                // groups are in bounds
417                let out = unsafe { s.agg_list(&self.groups) };
418                self.state = AggState::AggregatedList(out.clone());
419
420                self.sorted = true;
421                self.update_groups = UpdateGroups::WithGroupsLen;
422                out
423            },
424            AggState::AggregatedList(s) | AggState::AggregatedScalar(s) => s.into_column(),
425            AggState::Literal(s) => {
426                self.groups();
427                let rows = self.groups.len();
428                let s = s.new_from_index(0, rows);
429                let out = s
430                    .reshape_list(&[
431                        ReshapeDimension::new_dimension(rows as u64),
432                        ReshapeDimension::Infer,
433                    ])
434                    .unwrap();
435                self.state = AggState::AggregatedList(out.clone());
436                out.into_column()
437            },
438        }
439    }
440
441    /// Get the final aggregated version of the series.
442    pub fn finalize(&mut self) -> Column {
443        // we clone, because we only want to call `self.groups()` if needed.
444        // self groups may instantiate new groups and thus can be expensive.
445        match &self.state {
446            AggState::Literal(c) => {
447                let c = c.clone();
448                self.groups();
449                let rows = self.groups.len();
450                c.new_from_index(0, rows)
451            },
452            _ => self.aggregated(),
453        }
454    }
455
456    // If a binary or ternary function has both of these branches true, it should
457    // flatten the list
458    fn arity_should_explode(&self) -> bool {
459        use AggState::*;
460        match self.agg_state() {
461            Literal(s) => s.len() == 1,
462            AggregatedScalar(_) => true,
463            _ => false,
464        }
465    }
466
467    pub fn get_final_aggregation(mut self) -> (Column, Cow<'a, GroupPositions>) {
468        let _ = self.groups();
469        let groups = self.groups;
470        match self.state {
471            AggState::NotAggregated(c) => (c, groups),
472            AggState::AggregatedScalar(c) => (c, groups),
473            AggState::Literal(c) => (c, groups),
474            AggState::AggregatedList(c) => {
475                let flattened = c.explode().unwrap();
476                let groups = groups.into_owned();
477                // unroll the possible flattened state
478                // say we have groups with overlapping windows:
479                //
480                // offset, len
481                // 0, 1
482                // 0, 2
483                // 0, 4
484                //
485                // gets aggregation
486                //
487                // [0]
488                // [0, 1],
489                // [0, 1, 2, 3]
490                //
491                // before aggregation the column was
492                // [0, 1, 2, 3]
493                // but explode on this list yields
494                // [0, 0, 1, 0, 1, 2, 3]
495                //
496                // so we unroll the groups as
497                //
498                // [0, 1]
499                // [1, 2]
500                // [3, 4]
501                let groups = groups.unroll();
502                (flattened, Cow::Owned(groups))
503            },
504        }
505    }
506
507    /// Get the not-aggregated version of the series.
508    /// Note that we call it naive, because if a previous expr
509    /// has filtered or sorted this, this information is in the
510    /// group tuples not the flattened series.
511    pub(crate) fn flat_naive(&self) -> Cow<'_, Column> {
512        match &self.state {
513            AggState::NotAggregated(c) => Cow::Borrowed(c),
514            AggState::AggregatedList(c) => {
515                #[cfg(debug_assertions)]
516                {
517                    // panic so we find cases where we accidentally explode overlapping groups
518                    // we don't want this as this can create a lot of data
519                    if let GroupsType::Slice { rolling: true, .. } = self.groups.as_ref().as_ref() {
520                        panic!("implementation error, polars should not hit this branch for overlapping groups")
521                    }
522                }
523
524                Cow::Owned(c.explode().unwrap())
525            },
526            AggState::AggregatedScalar(c) => Cow::Borrowed(c),
527            AggState::Literal(c) => Cow::Borrowed(c),
528        }
529    }
530
531    /// Take the series.
532    pub(crate) fn take(&mut self) -> Column {
533        let c = match &mut self.state {
534            AggState::NotAggregated(c)
535            | AggState::AggregatedScalar(c)
536            | AggState::AggregatedList(c) => c,
537            AggState::Literal(c) => c,
538        };
539        std::mem::take(c)
540    }
541}
542
543/// Take a DataFrame and evaluate the expressions.
544/// Implement this for Column, lt, eq, etc
545pub trait PhysicalExpr: Send + Sync {
546    fn as_expression(&self) -> Option<&Expr> {
547        None
548    }
549
550    /// Take a DataFrame and evaluate the expression.
551    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Column>;
552
553    /// Attempt to cheaply evaluate this expression in-line without a DataFrame context.
554    /// This is used by StatsEvaluator when skipping files / row groups using a predicate.
555    /// TODO: Maybe in the future we can do this evaluation in-line at the optimizer stage?
556    ///
557    /// Do not implement this directly - instead implement `evaluate_inline_impl`
558    fn evaluate_inline(&self) -> Option<Column> {
559        self.evaluate_inline_impl(4)
560    }
561
562    /// Implementation of `evaluate_inline`
563    fn evaluate_inline_impl(&self, _depth_limit: u8) -> Option<Column> {
564        None
565    }
566
567    /// Some expression that are not aggregations can be done per group
568    /// Think of sort, slice, filter, shift, etc.
569    /// defaults to ignoring the group
570    ///
571    /// This method is called by an aggregation function.
572    ///
573    /// In case of a simple expr, like 'column', the groups are ignored and the column is returned.
574    /// In case of an expr where group behavior makes sense, this method is called.
575    /// For a filter operation for instance, a Series is created per groups and filtered.
576    ///
577    /// An implementation of this method may apply an aggregation on the groups only. For instance
578    /// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per
579    /// group. The implementation then has to return the `Series` exploded (because a later aggregation
580    /// will use the group tuples to aggregate). The group tuples also have to be updated, because
581    /// aggregation to a list sorts the exploded `Series` by group.
582    ///
583    /// This has some gotcha's. An implementation may also change the group tuples instead of
584    /// the `Series`.
585    ///
586    // we allow this because we pass the vec to the Cow
587    // Note to self: Don't be smart and dispatch to evaluate as default implementation
588    // this means filters will be incorrect and lead to invalid results down the line
589    #[allow(clippy::ptr_arg)]
590    fn evaluate_on_groups<'a>(
591        &self,
592        df: &DataFrame,
593        groups: &'a GroupPositions,
594        state: &ExecutionState,
595    ) -> PolarsResult<AggregationContext<'a>>;
596
597    /// Get the output field of this expr
598    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;
599
600    /// Convert to a partitioned aggregator.
601    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
602        None
603    }
604
605    /// Get the variables that are used in the expression i.e. live variables.
606    /// This can contain duplicates.
607    fn collect_live_columns(&self, lv: &mut PlIndexSet<PlSmallStr>);
608
609    /// Can take &dyn Statistics and determine of a file should be
610    /// read -> `true`
611    /// or not -> `false`
612    fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
613        None
614    }
615    fn is_literal(&self) -> bool {
616        false
617    }
618    fn is_scalar(&self) -> bool;
619}
620
621impl Display for &dyn PhysicalExpr {
622    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
623        match self.as_expression() {
624            None => Ok(()),
625            Some(e) => write!(f, "{e:?}"),
626        }
627    }
628}
629
630/// Wrapper struct that allow us to use a PhysicalExpr in polars-io.
631///
632/// This is used to filter rows during the scan of file.
633pub struct PhysicalIoHelper {
634    pub expr: Arc<dyn PhysicalExpr>,
635    pub has_window_function: bool,
636}
637
638impl PhysicalIoExpr for PhysicalIoHelper {
639    fn evaluate_io(&self, df: &DataFrame) -> PolarsResult<Series> {
640        let mut state: ExecutionState = Default::default();
641        if self.has_window_function {
642            state.insert_has_window_function_flag();
643        }
644        self.expr
645            .evaluate(df, &state)
646            .map(|c| c.take_materialized_series())
647    }
648
649    fn collect_live_columns(&self, live_columns: &mut PlIndexSet<PlSmallStr>) {
650        self.expr.collect_live_columns(live_columns);
651    }
652
653    #[cfg(feature = "parquet")]
654    fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
655        self.expr.as_stats_evaluator()
656    }
657}
658
659pub fn phys_expr_to_io_expr(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalIoExpr> {
660    let has_window_function = if let Some(expr) = expr.as_expression() {
661        expr.into_iter()
662            .any(|expr| matches!(expr, Expr::Window { .. }))
663    } else {
664        false
665    };
666    Arc::new(PhysicalIoHelper {
667        expr,
668        has_window_function,
669    }) as Arc<dyn PhysicalIoExpr>
670}
671
672pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {
673    /// This is called in partitioned aggregation.
674    /// Partitioned results may differ from aggregation results.
675    /// For instance, for a `mean` operation a partitioned result
676    /// needs to return the `sum` and the `valid_count` (length - null count).
677    ///
678    /// A final aggregation can then take the sum of sums and sum of valid_counts
679    /// to produce a final mean.
680    #[allow(clippy::ptr_arg)]
681    fn evaluate_partitioned(
682        &self,
683        df: &DataFrame,
684        groups: &GroupPositions,
685        state: &ExecutionState,
686    ) -> PolarsResult<Column>;
687
688    /// Called to merge all the partitioned results in a final aggregate.
689    #[allow(clippy::ptr_arg)]
690    fn finalize(
691        &self,
692        partitioned: Column,
693        groups: &GroupPositions,
694        state: &ExecutionState,
695    ) -> PolarsResult<Column>;
696}