polars-lazy 0.26.1

Lazy query engine for the Polars DataFrame library
Documentation
mod aggregation;
mod alias;
mod apply;
mod binary;
mod cast;
mod column;
mod count;
mod filter;
mod group_iter;
mod literal;
mod slice;
mod sort;
mod sortby;
mod take;
mod ternary;
mod window;

use std::borrow::Cow;
use std::fmt::{Display, Formatter};

pub(crate) use aggregation::*;
pub(crate) use alias::*;
pub(crate) use apply::*;
pub(crate) use binary::*;
pub(crate) use cast::*;
pub(crate) use column::*;
pub(crate) use count::*;
pub(crate) use filter::*;
pub(crate) use literal::*;
use polars_arrow::export::arrow::array::ListArray;
use polars_arrow::export::arrow::offset::Offsets;
use polars_arrow::trusted_len::PushUnchecked;
use polars_arrow::utils::CustomIterTools;
use polars_core::frame::groupby::GroupsProxy;
use polars_core::prelude::*;
use polars_io::predicates::PhysicalIoExpr;
pub(crate) use slice::*;
pub(crate) use sort::*;
pub(crate) use sortby::*;
pub(crate) use take::*;
pub(crate) use ternary::*;
pub(crate) use window::*;

use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;

/// The aggregation state of the `Series` held by an `AggregationContext`.
#[derive(Clone, Debug)]
pub(crate) enum AggState {
    /// Already aggregated: `.agg_list(group_tuples)` is called
    /// and produced a `Series` of dtype `List`.
    AggregatedList(Series),
    /// Already aggregated: `.agg_list(group_tuples)` is called
    /// and produced a `Series` of any dtype that is not nested.
    /// Think of `sum`, `mean`, `variance`-like aggregations.
    AggregatedFlat(Series),
    /// Not yet aggregated: `agg_list` still has to be called.
    NotAggregated(Series),
    /// A literal value. It is never aggregated with the group tuples; the
    /// `AggregationContext` methods broadcast it over the groups instead.
    Literal(Series),
}

impl AggState {
    /// Reports whether this state may be aggregated with `groups`.
    ///
    /// Literal series are not safe to aggregate. A `NotAggregated` series of
    /// length 1 is treated as such a literal when there is more than one
    /// group, or when the first group holds more than one index. Every other
    /// state is reported as safe.
    fn safe_to_agg(&self, groups: &GroupsProxy) -> bool {
        if let AggState::NotAggregated(s) = self {
            // Only a unit-length series can be a literal in disguise.
            if s.len() != 1 {
                return true;
            }
            // One value cannot be spread over several groups.
            if groups.len() > 1 {
                return false;
            }
            // At most one group left: safe unless that group needs more than
            // the single value we have.
            groups.is_empty() || groups.get(0).len() <= 1
        } else {
            true
        }
    }
}

/// Lazy update strategy for the groups of an `AggregationContext`.
///
/// The pending update is applied the next time
/// `AggregationContext::groups` is called.
#[cfg_attr(debug_assertions, derive(Debug))]
#[derive(PartialEq)]
pub(crate) enum UpdateGroups {
    /// Don't update the groups.
    No,
    /// Use the length of the current groups to determine new sorted indexes; preferred
    /// for performance.
    WithGroupsLen,
    /// Use the series' list offsets to determine the new group lengths.
    /// This one should be used when the length has changed. Note that
    /// the series should be in an aggregated (list) state or else it will panic.
    WithSeriesLen,
    /// Same as `WithSeriesLen`, but uses a series given by the caller
    /// instead of the one held by the context.
    WithSeriesLenOwned(Series),
}

/// A `Series` together with the group tuples it belongs to, as passed between
/// expressions that are evaluated on groups.
#[cfg_attr(debug_assertions, derive(Debug))]
pub struct AggregationContext<'a> {
    /// The series and its aggregation state, see `AggState`:
    /// already aggregated (as list or flat), not yet aggregated
    /// (still needs the group tuples to aggregate), or a literal.
    state: AggState,
    /// Group tuples for the `AggState`.
    groups: Cow<'a, GroupsProxy>,
    /// If the group tuples are already used in a level above
    /// and the series is exploded, the group tuples are sorted,
    /// e.g. the exploded `Series` is grouped per group.
    sorted: bool,
    /// This is used to determine if we need to update the groups
    /// into sorted groups. We do this lazily, so that this work is only
    /// done when the groups are needed.
    update_groups: UpdateGroups,
    /// This is true when the `Series` and `GroupsProxy` still have all
    /// their original values. Not the case when filtered.
    original_len: bool,
}

impl<'a> AggregationContext<'a> {
    /// Returns the group tuples, applying any pending lazy update first.
    ///
    /// Depending on `update_groups` the groups are left as they are, rebuilt
    /// as contiguous slices from the current group lengths, or rebuilt from
    /// the sub-list lengths of a list series. After the call no update is
    /// pending anymore.
    pub(crate) fn groups(&mut self) -> &Cow<'a, GroupsProxy> {
        match self.update_groups {
            UpdateGroups::No => {}
            UpdateGroups::WithSeriesLen => {
                let list = self.series().clone();
                self.det_groups_from_list(&list);
            }
            UpdateGroups::WithSeriesLenOwned(ref owned) => {
                let list = owned.clone();
                self.det_groups_from_list(&list);
            }
            UpdateGroups::WithGroupsLen => {
                // Index groups are unordered, but the series was aggregated
                // with them, so the exploded series is laid out group after
                // group. Recreate the group tuples as back-to-back slices.
                // Slice groups are already in the correct order.
                if let GroupsProxy::Idx(idx_groups) = self.groups.as_ref() {
                    let mut start = 0 as IdxSize;
                    let slices = idx_groups
                        .iter()
                        .map(|entry| {
                            let len = entry.1.len() as IdxSize;
                            let slice = [start, len];
                            start += len;
                            slice
                        })
                        .collect();
                    self.groups = Cow::Owned(GroupsProxy::Slice {
                        groups: slices,
                        rolling: false,
                    });
                }
                self.update_groups = UpdateGroups::No;
            }
        }
        &self.groups
    }

    /// Borrows the held series, whatever aggregation state it is in.
    pub(crate) fn series(&self) -> &Series {
        match &self.state {
            AggState::NotAggregated(inner)
            | AggState::AggregatedFlat(inner)
            | AggState::AggregatedList(inner)
            | AggState::Literal(inner) => inner,
        }
    }

    /// Borrows the current aggregation state.
    pub(crate) fn agg_state(&self) -> &AggState {
        &self.state
    }

    /// Returns `true` when the state is `NotAggregated` or `Literal`.
    pub(crate) fn is_not_aggregated(&self) -> bool {
        match self.state {
            AggState::NotAggregated(_) | AggState::Literal(_) => true,
            AggState::AggregatedList(_) | AggState::AggregatedFlat(_) => false,
        }
    }

    pub(crate) fn is_aggregated(&self) -> bool {
        !self.is_not_aggregated()
    }
    /// Returns `true` when the state is `Literal`.
    pub(crate) fn is_literal(&self) -> bool {
        match self.state {
            AggState::Literal(_) => true,
            AggState::NotAggregated(_)
            | AggState::AggregatedList(_)
            | AggState::AggregatedFlat(_) => false,
        }
    }

    /// Adopts the groups of `other` when `self` merely borrows its groups
    /// while `other` owns its groups; in every other case `self` is left
    /// untouched. `other` is consumed either way.
    pub(crate) fn combine_groups(&mut self, other: AggregationContext) -> &mut Self {
        let self_is_borrowed = matches!(self.groups, Cow::Borrowed(_));
        if self_is_borrowed {
            if let Cow::Owned(owned) = other.groups {
                self.groups = Cow::Owned(owned);
            }
        }
        self
    }

    /// Builds a context from a series and its group tuples.
    ///
    /// # Arguments
    /// - `aggregated` tells whether the series is the result of an aggregation.
    ///   A `List` dtype alone does not imply that, because a list can also
    ///   simply be the dtype of the column.
    ///
    /// # Panics
    /// Panics if `aggregated` is `true` and the series length differs from
    /// the number of groups.
    fn new(
        series: Series,
        groups: Cow<'a, GroupsProxy>,
        aggregated: bool,
    ) -> AggregationContext<'a> {
        let state = if aggregated {
            // An aggregated series carries exactly one row per group.
            assert_eq!(series.len(), groups.len());
            if matches!(series.dtype(), &DataType::List(_)) {
                AggState::AggregatedList(series)
            } else {
                AggState::AggregatedFlat(series)
            }
        } else {
            AggState::NotAggregated(series)
        };

        Self {
            state,
            groups,
            sorted: false,
            update_groups: UpdateGroups::No,
            original_len: true,
        }
    }

    /// Builds a context holding `lit` in the `AggState::Literal` state.
    ///
    /// The context starts unsorted, without a pending group update and with
    /// `original_len` set to `true`.
    fn from_literal(lit: Series, groups: Cow<'a, GroupsProxy>) -> AggregationContext<'a> {
        Self {
            state: AggState::Literal(lit),
            groups,
            sorted: false,
            update_groups: UpdateGroups::No,
            original_len: true,
        }
    }

    /// Sets the `original_len` flag, which records whether the series and the
    /// groups still have all their original values, and returns `self` for chaining.
    pub(crate) fn set_original_len(&mut self, original_len: bool) -> &mut Self {
        self.original_len = original_len;
        self
    }

    /// Sets the pending lazy group update and returns `self` for chaining.
    /// The update itself is only carried out by a later call to `groups()`.
    pub(crate) fn with_update_groups(&mut self, update: UpdateGroups) -> &mut Self {
        self.update_groups = update;
        self
    }

    /// Determines new group tuples from the sub-list lengths of the list series `s`.
    ///
    /// The groups are replaced by contiguous slices `[offset, len]`, one per
    /// sub-list of `s`, that index into the exploded form of `s`. `explode`
    /// fills an empty (or null) sub-list with a single null row, so such a
    /// group keeps `len == 0` while the running offset still advances by one.
    /// Afterwards no group update is pending anymore.
    ///
    /// The lengths are always taken from `s`, regardless of the number of
    /// chunks and of the series currently held by the context.
    ///
    /// # Panics
    /// Panics if `s` is not of dtype `List`.
    pub(crate) fn det_groups_from_list(&mut self, s: &Series) {
        let list = s
            .list()
            .expect("impl error, should be a list at this point");

        let mut offset = 0 as IdxSize;
        // Emits the slice for a sub-list holding `len` values and advances the
        // running offset, accounting for the null row `explode` inserts for an
        // empty sub-list.
        let mut next_group = |len: IdxSize| {
            let out = [offset, len];
            offset += len + (len == 0) as IdxSize;
            out
        };

        let groups = match list.chunks().len() {
            1 => {
                // Fast path: read the lengths straight from the offsets buffer.
                let arr = list.downcast_iter().next().unwrap();
                let offsets = arr.offsets().as_slice();

                // The offsets of a sliced array do not have to start at zero.
                let mut previous = offsets[0];
                offsets[1..]
                    .iter()
                    .map(|&o| {
                        let len = (o - previous) as IdxSize;
                        previous = o;
                        next_group(len)
                    })
                    .collect_trusted()
            }
            _ => list
                .amortized_iter()
                .map(|opt_sub| {
                    let len = match opt_sub {
                        Some(sub) => sub.as_ref().len() as IdxSize,
                        None => 0,
                    };
                    next_group(len)
                })
                .collect_trusted(),
        };

        self.groups = Cow::Owned(GroupsProxy::Slice {
            groups,
            rolling: false,
        });
        self.update_groups = UpdateGroups::No;
    }

    /// Makes sure the series is sorted by its groups.
    ///
    /// In a binary expression one side can be aggregated while the other is
    /// not. Flattening both naively would leave one sorted by group and the
    /// other not. This aggregates a `NotAggregated` series to a list so both
    /// sides agree. It is a no-op for every other state, and for a
    /// `NotAggregated` series that is not safe to aggregate.
    pub(crate) fn sort_by_groups(&mut self) {
        // The groups are updated lazily; bring them up to date before they
        // are used to sort.
        self.groups();
        if let AggState::NotAggregated(s) = &self.state {
            // We should not aggregate literals!!
            if !self.state.safe_to_agg(&self.groups) {
                return;
            }
            // safety:
            // groups are in bounds
            let list = unsafe { s.agg_list(&self.groups) };
            self.state = AggState::AggregatedList(list);
            self.update_groups = UpdateGroups::WithGroupsLen;
        }
    }

    /// Replaces the held series and derives its new aggregation state.
    ///
    /// # Arguments
    /// - `aggregated` tells whether the series is a list due to aggregation
    ///   (it could also be a list because that is the column's dtype).
    ///
    /// # Panics
    /// Panics if `aggregated` is `true`, the series has dtype `List` and its
    /// length differs from the number of groups.
    pub(crate) fn with_series(&mut self, series: Series, aggregated: bool) -> &mut Self {
        let is_list = matches!(series.dtype(), &DataType::List(_));
        let new_state = if aggregated && is_list {
            assert_eq!(series.len(), self.groups.len());
            AggState::AggregatedList(series)
        } else if aggregated || matches!(self.state, AggState::AggregatedFlat(_)) {
            // Either explicitly aggregated, or the previous state was already
            // aggregated to e.g. a sum or min: even if this series was
            // flattened it could never retrieve the length it had before
            // grouping, so it stays in the flat aggregated state.
            AggState::AggregatedFlat(series)
        } else {
            AggState::NotAggregated(series)
        };
        self.state = new_state;
        self
    }

    /// Replaces the state with `AggState::Literal(series)` and returns `self`
    /// for chaining. Groups and flags are left untouched.
    pub(crate) fn with_literal(&mut self, series: Series) -> &mut Self {
        self.state = AggState::Literal(series);
        self
    }

    /// Replaces the group tuples with `groups`.
    ///
    /// The held series is flattened first, because an aggregated series
    /// belongs to the old groups. Any pending lazy group update is discarded.
    pub(crate) fn with_groups(&mut self, groups: GroupsProxy) -> &mut Self {
        // In case of new groups, a series always needs to be flattened.
        let flat = self.flat_naive().into_owned();
        self.with_series(flat, false);
        // Make sure that a previous setting is not used.
        self.update_groups = UpdateGroups::No;
        self.groups = Cow::Owned(groups);
        self
    }

    /// Returns the aggregated version of the series.
    ///
    /// - Already aggregated states return their series as is.
    /// - A `NotAggregated` series is aggregated to a list with the (updated)
    ///   groups; the context then switches to `AggregatedList`, is marked as
    ///   sorted and schedules a `WithGroupsLen` group update.
    /// - A `Literal` is repeated once per group and reshaped to one row per
    ///   group; the state of the context is not changed.
    pub(crate) fn aggregated(&mut self) -> Series {
        // Work on a clone of the state: `self.groups()` may instantiate new
        // groups and thus can be expensive, so only the branches that need
        // the groups call it.
        let state = self.state.clone();
        match state {
            AggState::AggregatedList(done) | AggState::AggregatedFlat(done) => done,
            AggState::Literal(lit) => {
                self.groups();
                let n_groups = self.groups.len();
                lit.new_from_index(0, n_groups)
                    .reshape(&[n_groups as i64, -1])
                    .unwrap()
            }
            AggState::NotAggregated(flat) => {
                // The groups are determined lazily, and a flat series is
                // aggregated to a list with those groups, so they must be
                // brought up to date first.
                self.groups();
                #[cfg(debug_assertions)]
                {
                    if self.groups.len() > flat.len() {
                        eprintln!("groups may be out of bounds; more groups than elements in a series is only possible in dynamic groupby")
                    }
                }

                // safety:
                // groups are in bounds
                let list = unsafe { flat.agg_list(&self.groups) };

                self.update_groups = UpdateGroups::WithGroupsLen;
                self.sorted = true;
                self.state = AggState::AggregatedList(list.clone());
                list
            }
        }
    }

    /// Returns the final aggregated version of the series.
    ///
    /// A `Literal` is repeated once per group, without being reshaped to a
    /// list. Every other state is delegated to `aggregated`.
    pub(crate) fn finalize(&mut self) -> Series {
        if let AggState::Literal(lit) = &self.state {
            // Clone first so the borrow of the state ends before the groups,
            // which are updated lazily, are touched.
            let lit = lit.clone();
            self.groups();
            let n_groups = self.groups.len();
            lit.new_from_index(0, n_groups)
        } else {
            self.aggregated()
        }
    }

    /// Get the aggregated version of the series for use in arity (binary/ternary)
    /// operations.
    ///
    /// Different from [`Self::aggregated`], in arity operations we expect literals
    /// to expand to the size of the group, e.g.:
    ///
    /// lit(9) in groups [[1, 1], [2, 2, 2]]
    /// becomes: [[9, 9], [9, 9, 9]]
    ///
    /// whereas in [`Self::aggregated`] this becomes [9, 9].
    ///
    /// This is because comparisons need to create masks that have a correct length.
    /// Any state other than `Literal` is delegated to [`Self::aggregated`].
    fn aggregated_arity_operation(&mut self) -> Series {
        if let AggState::Literal(s) = self.agg_state() {
            // clone to stop the borrow of `self`, so `groups()` can be called
            let s = s.clone();
            let groups = self.groups();

            // one offset per group, plus the leading zero
            let mut offsets = Vec::with_capacity(groups.len() + 1);

            let mut last_offset = 0i64;
            offsets.push(last_offset);
            for g in groups.iter() {
                last_offset += g.len() as i64;
                // safety:
                // we allocated enough: capacity is `groups.len() + 1` and we push
                // the leading zero plus exactly one offset per group
                unsafe { offsets.push_unchecked(last_offset) };
            }
            // the literal value repeated once for every row of every group
            let values = s.new_from_index(0, last_offset as usize);
            let values = values.array_ref(0).clone();
            // Safety:
            // offsets are monotonically increasing: they start at zero and
            // each step adds a group length
            let arr = unsafe {
                ListArray::<i64>::new(
                    DataType::List(Box::new(s.dtype().clone())).to_arrow(),
                    Offsets::new_unchecked(offsets).into(),
                    values,
                    None,
                )
            };
            Series::try_from((s.name(), Box::new(arr) as ArrayRef)).unwrap()
        } else {
            self.aggregated()
        }
    }

    // If a binary or ternary function has both of these branches true, it should
    // flatten the list
    fn arity_should_explode(&self) -> bool {
        use AggState::*;
        match self.agg_state() {
            Literal(s) => s.len() == 1,
            AggregatedFlat(_) => true,
            _ => false,
        }
    }

    /// Returns the not-aggregated (flat) version of the series.
    ///
    /// An `AggregatedList` is exploded into a new series; every other state
    /// is returned borrowed and unchanged.
    ///
    /// It is called naive because, if a previous expression has filtered or
    /// sorted the data, that information lives in the group tuples and not in
    /// the flattened series.
    pub(crate) fn flat_naive(&self) -> Cow<'_, Series> {
        match &self.state {
            AggState::NotAggregated(flat)
            | AggState::AggregatedFlat(flat)
            | AggState::Literal(flat) => Cow::Borrowed(flat),
            AggState::AggregatedList(list) => {
                #[cfg(debug_assertions)]
                {
                    // Panic so that accidental explodes of overlapping groups
                    // are found; they are unwanted because they can create a
                    // lot of data.
                    if let GroupsProxy::Slice { rolling: true, .. } = self.groups.as_ref() {
                        panic!("implementation error, polars should not hit this branch for overlapping groups")
                    }
                }

                let exploded = list.explode().unwrap();
                Cow::Owned(exploded)
            }
        }
    }

    /// Moves the series out of the context, leaving a default `Series` in its
    /// place. The aggregation state variant itself is not changed.
    pub(crate) fn take(&mut self) -> Series {
        let slot = match &mut self.state {
            AggState::NotAggregated(inner)
            | AggState::AggregatedFlat(inner)
            | AggState::AggregatedList(inner)
            | AggState::Literal(inner) => inner,
        };
        std::mem::take(slot)
    }
}

/// Take a DataFrame and evaluate the expressions.
/// Implement this for Column, lt, eq, etc.
pub trait PhysicalExpr: Send + Sync {
    /// The logical expression this physical expression stems from, if any.
    /// Defaults to `None`.
    fn as_expression(&self) -> Option<&Expr> {
        None
    }

    /// Take a DataFrame and evaluate the expression.
    fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> PolarsResult<Series>;

    /// Some expressions that are not aggregations can be done per group.
    /// Think of sort, slice, filter, shift, etc.
    ///
    /// This method is called by an aggregation function.
    ///
    /// In case of a simple expr, like 'column', the groups are ignored and the column is returned.
    /// In case of an expr where group behavior makes sense, this method is called.
    /// For a filter operation for instance, a Series is created per group and filtered.
    ///
    /// An implementation of this method may apply an aggregation on the groups only. For instance
    /// on a shift, the groups are first aggregated to a `ListChunked` and the shift is applied per
    /// group. The implementation then has to return the `Series` exploded (because a later aggregation
    /// will use the group tuples to aggregate). The group tuples also have to be updated, because
    /// aggregation to a list sorts the exploded `Series` by group.
    ///
    /// This has some gotchas. An implementation may also change the group tuples instead of
    /// the `Series`.
    ///
    // we allow this because we pass the vec to the Cow
    // Note to self: Don't be smart and dispatch to evaluate as default implementation
    // this means filters will be incorrect and lead to invalid results down the line
    #[allow(clippy::ptr_arg)]
    fn evaluate_on_groups<'a>(
        &self,
        df: &DataFrame,
        groups: &'a GroupsProxy,
        state: &ExecutionState,
    ) -> PolarsResult<AggregationContext<'a>>;

    /// Get the output field of this expr.
    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field>;

    /// Convert to a partitioned aggregator.
    /// Defaults to `None`.
    fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
        None
    }

    /// Can take `&dyn Statistics` and determine if a file should be
    /// read -> `true`
    /// or not -> `false`.
    /// Defaults to `None`.
    #[cfg(feature = "parquet")]
    fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
        None
    }

    // NOTE(review): presumably reports whether this expression may be used as
    // an aggregation; confirm against the implementors.
    fn is_valid_aggregation(&self) -> bool;

    /// Whether this expression is a literal. Defaults to `false`.
    fn is_literal(&self) -> bool {
        false
    }
}

impl Display for &dyn PhysicalExpr {
    /// Writes the underlying logical expression, or nothing at all when the
    /// physical expression does not expose one.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        if let Some(e) = self.as_expression() {
            write!(f, "{e}")
        } else {
            Ok(())
        }
    }
}

/// Wrapper struct that allows us to use a `PhysicalExpr` in polars-io.
///
/// This is used to filter rows during the scan of a file.
pub struct PhysicalIoHelper {
    /// The wrapped physical expression.
    pub expr: Arc<dyn PhysicalExpr>,
}

impl PhysicalIoExpr for PhysicalIoHelper {
    fn evaluate(&self, df: &DataFrame) -> PolarsResult<Series> {
        self.expr.evaluate(df, &Default::default())
    }

    #[cfg(feature = "parquet")]
    fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
        self.expr.as_stats_evaluator()
    }
}

/// A physical expression that can be aggregated per partition, with the
/// partitioned results merged afterwards.
pub trait PartitionedAggregation: Send + Sync + PhysicalExpr {
    /// This is called in partitioned aggregation.
    /// Partitioned results may differ from aggregation results.
    /// For instance, for a `mean` operation a partitioned result
    /// needs to return the `sum` and the `valid_count` (length - null count).
    ///
    /// A final aggregation can then take the sum of sums and sum of valid_counts
    /// to produce a final mean.
    #[allow(clippy::ptr_arg)]
    fn evaluate_partitioned(
        &self,
        df: &DataFrame,
        groups: &GroupsProxy,
        state: &ExecutionState,
    ) -> PolarsResult<Series>;

    /// Called to merge all the partitioned results into a final aggregate.
    #[allow(clippy::ptr_arg)]
    fn finalize(
        &self,
        partitioned: Series,
        groups: &GroupsProxy,
        state: &ExecutionState,
    ) -> PolarsResult<Series>;
}