// polars_plan/dsl/expr.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::{Hash, Hasher};
3
4use bytes::Bytes;
5use polars_compute::rolling::QuantileMethod;
6use polars_core::chunked_array::cast::CastOptions;
7use polars_core::error::feature_gated;
8use polars_core::prelude::*;
9#[cfg(feature = "serde")]
10use serde::{Deserialize, Serialize};
11
12pub use super::expr_dyn_fn::*;
13use crate::prelude::*;
14
15#[derive(PartialEq, Clone, Hash)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub enum AggExpr {
18    Min {
19        input: Arc<Expr>,
20        propagate_nans: bool,
21    },
22    Max {
23        input: Arc<Expr>,
24        propagate_nans: bool,
25    },
26    Median(Arc<Expr>),
27    NUnique(Arc<Expr>),
28    First(Arc<Expr>),
29    Last(Arc<Expr>),
30    Mean(Arc<Expr>),
31    Implode(Arc<Expr>),
32    // include_nulls
33    Count(Arc<Expr>, bool),
34    Quantile {
35        expr: Arc<Expr>,
36        quantile: Arc<Expr>,
37        method: QuantileMethod,
38    },
39    Sum(Arc<Expr>),
40    AggGroups(Arc<Expr>),
41    Std(Arc<Expr>, u8),
42    Var(Arc<Expr>, u8),
43}
44
45impl AsRef<Expr> for AggExpr {
46    fn as_ref(&self) -> &Expr {
47        use AggExpr::*;
48        match self {
49            Min { input, .. } => input,
50            Max { input, .. } => input,
51            Median(e) => e,
52            NUnique(e) => e,
53            First(e) => e,
54            Last(e) => e,
55            Mean(e) => e,
56            Implode(e) => e,
57            Count(e, _) => e,
58            Quantile { expr, .. } => expr,
59            Sum(e) => e,
60            AggGroups(e) => e,
61            Std(e, _) => e,
62            Var(e, _) => e,
63        }
64    }
65}
66
67/// Expressions that can be used in various contexts.
68///
69/// Queries consist of multiple expressions.
70/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using
71/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info.
72#[derive(Clone, PartialEq)]
73#[must_use]
74#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
75pub enum Expr {
76    Alias(Arc<Expr>, PlSmallStr),
77    Column(PlSmallStr),
78    Columns(Arc<[PlSmallStr]>),
79    DtypeColumn(Vec<DataType>),
80    IndexColumn(Arc<[i64]>),
81    Literal(LiteralValue),
82    BinaryExpr {
83        left: Arc<Expr>,
84        op: Operator,
85        right: Arc<Expr>,
86    },
87    Cast {
88        expr: Arc<Expr>,
89        dtype: DataType,
90        options: CastOptions,
91    },
92    Sort {
93        expr: Arc<Expr>,
94        options: SortOptions,
95    },
96    Gather {
97        expr: Arc<Expr>,
98        idx: Arc<Expr>,
99        returns_scalar: bool,
100    },
101    SortBy {
102        expr: Arc<Expr>,
103        by: Vec<Expr>,
104        sort_options: SortMultipleOptions,
105    },
106    Agg(AggExpr),
107    /// A ternary operation
108    /// if true then "foo" else "bar"
109    Ternary {
110        predicate: Arc<Expr>,
111        truthy: Arc<Expr>,
112        falsy: Arc<Expr>,
113    },
114    Function {
115        /// function arguments
116        input: Vec<Expr>,
117        /// function to apply
118        function: FunctionExpr,
119        options: FunctionOptions,
120    },
121    Explode {
122        input: Arc<Expr>,
123        skip_empty: bool,
124    },
125    Filter {
126        input: Arc<Expr>,
127        by: Arc<Expr>,
128    },
129    /// Polars flavored window functions.
130    Window {
131        /// Also has the input. i.e. avg("foo")
132        function: Arc<Expr>,
133        partition_by: Vec<Expr>,
134        order_by: Option<(Arc<Expr>, SortOptions)>,
135        options: WindowType,
136    },
137    Wildcard,
138    Slice {
139        input: Arc<Expr>,
140        /// length is not yet known so we accept negative offsets
141        offset: Arc<Expr>,
142        length: Arc<Expr>,
143    },
144    /// Can be used in a select statement to exclude a column from selection
145    /// TODO: See if we can replace `Vec<Excluded>` with `Arc<Excluded>`
146    Exclude(Arc<Expr>, Vec<Excluded>),
147    /// Set root name as Alias
148    KeepName(Arc<Expr>),
149    Len,
150    /// Take the nth column in the `DataFrame`
151    Nth(i64),
152    RenameAlias {
153        function: SpecialEq<Arc<dyn RenameAliasFn>>,
154        expr: Arc<Expr>,
155    },
156    #[cfg(feature = "dtype-struct")]
157    Field(Arc<[PlSmallStr]>),
158    AnonymousFunction {
159        /// function arguments
160        input: Vec<Expr>,
161        /// function to apply
162        function: OpaqueColumnUdf,
163        /// output dtype of the function
164        output_type: GetOutput,
165        options: FunctionOptions,
166    },
167    SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
168    /// Expressions in this node should only be expanding
169    /// e.g.
170    /// `Expr::Columns`
171    /// `Expr::Dtypes`
172    /// `Expr::Wildcard`
173    /// `Expr::Exclude`
174    Selector(super::selector::Selector),
175}
176
177pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn ColumnsUdf>>>;
178pub(crate) fn new_column_udf<F: ColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
179    LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
180}
181
182#[derive(Clone)]
183pub enum LazySerde<T: Clone> {
184    Deserialized(T),
185    Bytes(Bytes),
186}
187
188impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
189    fn eq(&self, other: &Self) -> bool {
190        use LazySerde as L;
191        match (self, other) {
192            (L::Deserialized(a), L::Deserialized(b)) => a == b,
193            (L::Bytes(a), L::Bytes(b)) => {
194                std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len()
195            },
196            _ => false,
197        }
198    }
199}
200
201impl<T: Clone> Debug for LazySerde<T> {
202    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
203        match self {
204            Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
205            Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
206        }
207    }
208}
209
210impl OpaqueColumnUdf {
211    pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn ColumnsUdf>>> {
212        match self {
213            Self::Deserialized(t) => Ok(t),
214            Self::Bytes(_b) => {
215                feature_gated!("serde";"python", {
216                    crate::dsl::python_dsl::PythonUdfExpression::try_deserialize(_b.as_ref()).map(SpecialEq::new)
217                })
218            },
219        }
220    }
221}
222
223#[allow(clippy::derived_hash_with_manual_eq)]
224impl Hash for Expr {
225    fn hash<H: Hasher>(&self, state: &mut H) {
226        let d = std::mem::discriminant(self);
227        d.hash(state);
228        match self {
229            Expr::Column(name) => name.hash(state),
230            Expr::Columns(names) => names.hash(state),
231            Expr::DtypeColumn(dtypes) => dtypes.hash(state),
232            Expr::IndexColumn(indices) => indices.hash(state),
233            Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
234            Expr::Selector(s) => s.hash(state),
235            Expr::Nth(v) => v.hash(state),
236            Expr::Filter { input, by } => {
237                input.hash(state);
238                by.hash(state);
239            },
240            Expr::BinaryExpr { left, op, right } => {
241                left.hash(state);
242                right.hash(state);
243                std::mem::discriminant(op).hash(state)
244            },
245            Expr::Cast {
246                expr,
247                dtype,
248                options: strict,
249            } => {
250                expr.hash(state);
251                dtype.hash(state);
252                strict.hash(state)
253            },
254            Expr::Sort { expr, options } => {
255                expr.hash(state);
256                options.hash(state);
257            },
258            Expr::Alias(input, name) => {
259                input.hash(state);
260                name.hash(state)
261            },
262            Expr::KeepName(input) => input.hash(state),
263            Expr::Ternary {
264                predicate,
265                truthy,
266                falsy,
267            } => {
268                predicate.hash(state);
269                truthy.hash(state);
270                falsy.hash(state);
271            },
272            Expr::Function {
273                input,
274                function,
275                options,
276            } => {
277                input.hash(state);
278                std::mem::discriminant(function).hash(state);
279                options.hash(state);
280            },
281            Expr::Gather {
282                expr,
283                idx,
284                returns_scalar,
285            } => {
286                expr.hash(state);
287                idx.hash(state);
288                returns_scalar.hash(state);
289            },
290            // already hashed by discriminant
291            Expr::Wildcard | Expr::Len => {},
292            Expr::SortBy {
293                expr,
294                by,
295                sort_options,
296            } => {
297                expr.hash(state);
298                by.hash(state);
299                sort_options.hash(state);
300            },
301            Expr::Agg(input) => input.hash(state),
302            Expr::Explode { input, skip_empty } => {
303                skip_empty.hash(state);
304                input.hash(state)
305            },
306            Expr::Window {
307                function,
308                partition_by,
309                order_by,
310                options,
311            } => {
312                function.hash(state);
313                partition_by.hash(state);
314                order_by.hash(state);
315                options.hash(state);
316            },
317            Expr::Slice {
318                input,
319                offset,
320                length,
321            } => {
322                input.hash(state);
323                offset.hash(state);
324                length.hash(state);
325            },
326            Expr::Exclude(input, excl) => {
327                input.hash(state);
328                excl.hash(state);
329            },
330            Expr::RenameAlias { function: _, expr } => expr.hash(state),
331            Expr::AnonymousFunction {
332                input,
333                function: _,
334                output_type: _,
335                options,
336            } => {
337                input.hash(state);
338                options.hash(state);
339            },
340            Expr::SubPlan(_, names) => names.hash(state),
341            #[cfg(feature = "dtype-struct")]
342            Expr::Field(names) => names.hash(state),
343        }
344    }
345}
346
347impl Eq for Expr {}
348
349impl Default for Expr {
350    fn default() -> Self {
351        Expr::Literal(LiteralValue::Scalar(Scalar::default()))
352    }
353}
354
355#[derive(Debug, Clone, PartialEq, Eq, Hash)]
356#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
357pub enum Excluded {
358    Name(PlSmallStr),
359    Dtype(DataType),
360}
361
362impl Expr {
363    /// Get Field result of the expression. The schema is the input data.
364    pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
365        // this is not called much and the expression depth is typically shallow
366        let mut arena = Arena::with_capacity(5);
367        self.to_field_amortized(schema, ctxt, &mut arena)
368    }
369    pub(crate) fn to_field_amortized(
370        &self,
371        schema: &Schema,
372        ctxt: Context,
373        expr_arena: &mut Arena<AExpr>,
374    ) -> PolarsResult<Field> {
375        let root = to_aexpr(self.clone(), expr_arena)?;
376        expr_arena
377            .get(root)
378            .to_field_and_validate(schema, ctxt, expr_arena)
379    }
380
381    /// Extract a constant usize from an expression.
382    pub fn extract_usize(&self) -> PolarsResult<usize> {
383        match self {
384            Expr::Literal(n) => n.extract_usize(),
385            Expr::Cast { expr, dtype, .. } => {
386                // lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal.
387                if dtype.is_integer() {
388                    expr.extract_usize()
389                } else {
390                    polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
391                }
392            },
393            _ => {
394                polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
395            },
396        }
397    }
398
399    #[inline]
400    pub fn map_unary(self, function: impl Into<FunctionExpr>) -> Self {
401        Expr::n_ary(function, vec![self])
402    }
403    #[inline]
404    pub fn map_binary(self, function: impl Into<FunctionExpr>, rhs: Self) -> Self {
405        Expr::n_ary(function, vec![self, rhs])
406    }
407
408    #[inline]
409    pub fn map_ternary(self, function: impl Into<FunctionExpr>, arg1: Expr, arg2: Expr) -> Expr {
410        Expr::n_ary(function, vec![self, arg1, arg2])
411    }
412
413    #[inline]
414    pub fn try_map_n_ary(
415        self,
416        function: impl Into<FunctionExpr>,
417        exprs: impl IntoIterator<Item = PolarsResult<Expr>>,
418    ) -> PolarsResult<Expr> {
419        let exprs = exprs.into_iter();
420        let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
421        input.push(self);
422        for e in exprs {
423            input.push(e?);
424        }
425        Ok(Expr::n_ary(function, input))
426    }
427
428    #[inline]
429    pub fn map_n_ary(
430        self,
431        function: impl Into<FunctionExpr>,
432        exprs: impl IntoIterator<Item = Expr>,
433    ) -> Expr {
434        let exprs = exprs.into_iter();
435        let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
436        input.push(self);
437        input.extend(exprs);
438        Expr::n_ary(function, input)
439    }
440
441    #[inline]
442    pub fn n_ary(function: impl Into<FunctionExpr>, input: Vec<Expr>) -> Expr {
443        let function = function.into();
444        let options = function.function_options();
445        Expr::Function {
446            input,
447            function,
448            options,
449        }
450    }
451}
452
/// Operators usable in a binary expression.
///
/// Variant order determines the discriminant values (and so the derived `Hash` input,
/// and possibly serialized forms), so existing variants keep their positions.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    /// Displayed as `==`.
    Eq,
    /// Displayed as `==v`.
    EqValidity,
    /// Displayed as `!=`.
    NotEq,
    /// Displayed as `!=v`.
    NotEqValidity,
    /// Displayed as `<`.
    Lt,
    /// Displayed as `<=`.
    LtEq,
    /// Displayed as `>`.
    Gt,
    /// Displayed as `>=`.
    GtEq,
    /// Displayed as `+`.
    Plus,
    /// Displayed as `-`.
    Minus,
    /// Displayed as `*`.
    Multiply,
    /// Displayed as `//`.
    Divide,
    /// Displayed as `/`.
    TrueDivide,
    /// Displayed as `floor_div`.
    FloorDivide,
    /// Displayed as `%`.
    Modulus,
    /// Bitwise and, displayed as `&`.
    And,
    /// Bitwise or, displayed as `|`.
    Or,
    /// Bitwise xor, displayed as `^`.
    Xor,
    /// Logical and; also displayed as `&`.
    LogicalAnd,
    /// Logical or; also displayed as `|`.
    LogicalOr,
}
477
478impl Display for Operator {
479    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
480        use Operator::*;
481        let tkn = match self {
482            Eq => "==",
483            EqValidity => "==v",
484            NotEq => "!=",
485            NotEqValidity => "!=v",
486            Lt => "<",
487            LtEq => "<=",
488            Gt => ">",
489            GtEq => ">=",
490            Plus => "+",
491            Minus => "-",
492            Multiply => "*",
493            Divide => "//",
494            TrueDivide => "/",
495            FloorDivide => "floor_div",
496            Modulus => "%",
497            And | LogicalAnd => "&",
498            Or | LogicalOr => "|",
499            Xor => "^",
500        };
501        write!(f, "{tkn}")
502    }
503}
504
505impl Operator {
506    pub fn is_comparison(&self) -> bool {
507        matches!(
508            self,
509            Self::Eq
510                | Self::NotEq
511                | Self::Lt
512                | Self::LtEq
513                | Self::Gt
514                | Self::GtEq
515                | Self::EqValidity
516                | Self::NotEqValidity
517        )
518    }
519
520    pub fn is_bitwise(&self) -> bool {
521        matches!(self, Self::And | Self::Or | Self::Xor)
522    }
523
524    pub fn is_comparison_or_bitwise(&self) -> bool {
525        self.is_comparison() || self.is_bitwise()
526    }
527
528    pub fn swap_operands(self) -> Self {
529        match self {
530            Operator::Eq => Operator::Eq,
531            Operator::Gt => Operator::Lt,
532            Operator::GtEq => Operator::LtEq,
533            Operator::LtEq => Operator::GtEq,
534            Operator::Or => Operator::Or,
535            Operator::LogicalAnd => Operator::LogicalAnd,
536            Operator::LogicalOr => Operator::LogicalOr,
537            Operator::Xor => Operator::Xor,
538            Operator::NotEq => Operator::NotEq,
539            Operator::EqValidity => Operator::EqValidity,
540            Operator::NotEqValidity => Operator::NotEqValidity,
541            Operator::Divide => Operator::Multiply,
542            Operator::Multiply => Operator::Divide,
543            Operator::And => Operator::And,
544            Operator::Plus => Operator::Minus,
545            Operator::Minus => Operator::Plus,
546            Operator::Lt => Operator::Gt,
547            _ => unimplemented!(),
548        }
549    }
550
551    pub fn is_arithmetic(&self) -> bool {
552        !(self.is_comparison_or_bitwise())
553    }
554}