polars_plan/dsl/expr/
mod.rs

1mod expr_dyn_fn;
2use std::fmt::{Debug, Display, Formatter};
3use std::hash::{Hash, Hasher};
4
5use bytes::Bytes;
6pub use expr_dyn_fn::*;
7use polars_compute::rolling::QuantileMethod;
8use polars_core::chunked_array::cast::CastOptions;
9use polars_core::error::feature_gated;
10use polars_core::prelude::*;
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13#[cfg(feature = "serde")]
14pub mod named_serde;
15#[cfg(feature = "serde")]
16mod serde_expr;
17
18use crate::prelude::*;
19
/// Aggregation expressions of the DSL.
///
/// Every variant stores the expression being aggregated behind an `Arc<Expr>`;
/// the `AsRef<Expr>` impl of this type hands out a reference to it.
#[derive(PartialEq, Clone, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AggExpr {
    /// Minimum of `input`.
    Min {
        input: Arc<Expr>,
        // NOTE(review): presumably controls whether NaN values propagate into the
        // result -- confirm against the implementation.
        propagate_nans: bool,
    },
    /// Maximum of `input`.
    Max {
        input: Arc<Expr>,
        // NOTE(review): presumably controls whether NaN values propagate into the
        // result -- confirm against the implementation.
        propagate_nans: bool,
    },
    /// Median of the wrapped expression.
    Median(Arc<Expr>),
    /// Number of unique values of the wrapped expression.
    NUnique(Arc<Expr>),
    /// First value of the wrapped expression.
    First(Arc<Expr>),
    /// Last value of the wrapped expression.
    Last(Arc<Expr>),
    /// Mean of the wrapped expression.
    Mean(Arc<Expr>),
    // NOTE(review): presumably gathers the values into a single list -- confirm.
    Implode(Arc<Expr>),
    /// Count of the wrapped expression; the `bool` is the `include_nulls` flag.
    Count(Arc<Expr>, bool),
    /// Quantile of `expr`. The requested quantile is itself an expression, and
    /// `method` is the `QuantileMethod` to use.
    Quantile {
        expr: Arc<Expr>,
        quantile: Arc<Expr>,
        method: QuantileMethod,
    },
    /// Sum of the wrapped expression.
    Sum(Arc<Expr>),
    // NOTE(review): presumably yields the group indices themselves -- confirm.
    AggGroups(Arc<Expr>),
    /// Standard deviation of the wrapped expression.
    // NOTE(review): the `u8` is presumably the delta degrees of freedom (ddof) -- confirm.
    Std(Arc<Expr>, u8),
    /// Variance of the wrapped expression.
    // NOTE(review): the `u8` is presumably the delta degrees of freedom (ddof) -- confirm.
    Var(Arc<Expr>, u8),
}
49
50impl AsRef<Expr> for AggExpr {
51    fn as_ref(&self) -> &Expr {
52        use AggExpr::*;
53        match self {
54            Min { input, .. } => input,
55            Max { input, .. } => input,
56            Median(e) => e,
57            NUnique(e) => e,
58            First(e) => e,
59            Last(e) => e,
60            Mean(e) => e,
61            Implode(e) => e,
62            Count(e, _) => e,
63            Quantile { expr, .. } => expr,
64            Sum(e) => e,
65            AggGroups(e) => e,
66            Std(e, _) => e,
67            Var(e, _) => e,
68        }
69    }
70}
71
/// Expressions that can be used in various contexts.
///
/// Queries consist of multiple expressions.
/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using
/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info.
///
/// Child expressions are stored behind `Arc<Expr>` (or in a `Vec<Expr>`), so an
/// `Expr` forms a tree.
#[derive(Clone, PartialEq)]
#[must_use]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Expr {
    /// An expression together with a name.
    // NOTE(review): presumably the name becomes the output column name -- confirm.
    Alias(Arc<Expr>, PlSmallStr),
    /// A single column, referenced by name.
    Column(PlSmallStr),
    /// Several columns, referenced by name.
    Columns(Arc<[PlSmallStr]>),
    /// Columns referenced by data type.
    DtypeColumn(Vec<DataType>),
    /// Columns referenced by integer index.
    IndexColumn(Arc<[i64]>),
    /// A literal value.
    Literal(LiteralValue),
    /// The binary operator `op` applied to `left` and `right`.
    BinaryExpr {
        left: Arc<Expr>,
        op: Operator,
        right: Arc<Expr>,
    },
    /// Cast of `expr` to `dtype`, configured by `options`.
    Cast {
        expr: Arc<Expr>,
        dtype: DataType,
        options: CastOptions,
    },
    /// Sort of `expr`, configured by `options`.
    Sort {
        expr: Arc<Expr>,
        options: SortOptions,
    },
    /// Gather values of `expr` at the positions given by `idx`.
    Gather {
        expr: Arc<Expr>,
        idx: Arc<Expr>,
        // NOTE(review): presumably marks that the result is a single value
        // rather than a column -- confirm.
        returns_scalar: bool,
    },
    /// Sort of `expr` by the expressions in `by`, configured by `sort_options`.
    SortBy {
        expr: Arc<Expr>,
        by: Vec<Expr>,
        sort_options: SortMultipleOptions,
    },
    /// An aggregation, see [`AggExpr`].
    Agg(AggExpr),
    /// A ternary operation
    /// if true then "foo" else "bar"
    Ternary {
        predicate: Arc<Expr>,
        truthy: Arc<Expr>,
        falsy: Arc<Expr>,
    },
    /// Application of a `FunctionExpr` to `input`.
    Function {
        /// function arguments
        input: Vec<Expr>,
        /// function to apply
        function: FunctionExpr,
        options: FunctionOptions,
    },
    /// Explode of `input`.
    Explode {
        input: Arc<Expr>,
        // NOTE(review): presumably controls how empty lists are treated -- confirm.
        skip_empty: bool,
    },
    /// Filter of `input` by the expression `by`.
    Filter {
        input: Arc<Expr>,
        by: Arc<Expr>,
    },
    /// Polars flavored window functions.
    Window {
        /// Also has the input. i.e. avg("foo")
        function: Arc<Expr>,
        partition_by: Vec<Expr>,
        /// Optional ordering expression together with its sort options.
        order_by: Option<(Arc<Expr>, SortOptions)>,
        options: WindowType,
    },
    /// Wildcard selection.
    Wildcard,
    /// Slice of `input` described by the `offset` and `length` expressions.
    Slice {
        input: Arc<Expr>,
        /// length is not yet known so we accept negative offsets
        offset: Arc<Expr>,
        length: Arc<Expr>,
    },
    /// Can be used in a select statement to exclude a column from selection
    /// TODO: See if we can replace `Vec<Excluded>` with `Arc<Excluded>`
    Exclude(Arc<Expr>, Vec<Excluded>),
    /// Set root name as Alias
    KeepName(Arc<Expr>),
    /// Length expression; carries no payload.
    // NOTE(review): presumably the row count of the current context -- confirm.
    Len,
    /// Take the nth column in the `DataFrame`
    Nth(i64),
    /// `expr` whose alias is produced by the `RenameAliasFn` in `function`.
    RenameAlias {
        function: SpecialEq<Arc<dyn RenameAliasFn>>,
        expr: Arc<Expr>,
    },
    /// Fields referenced by name; only available with the `dtype-struct` feature.
    #[cfg(feature = "dtype-struct")]
    Field(Arc<[PlSmallStr]>),
    /// Application of an opaque user-defined function to `input`.
    AnonymousFunction {
        /// function arguments
        input: Vec<Expr>,
        /// function to apply
        function: OpaqueColumnUdf,
        /// output dtype of the function
        output_type: GetOutput,
        options: FunctionOptions,
    },
    /// A `DslPlan` embedded in an expression, together with a list of names.
    // NOTE(review): the meaning of the names is not visible here -- confirm.
    SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
    /// Expressions in this node should only be expanding expressions, e.g.:
    ///
    /// - `Expr::Columns`
    /// - `Expr::DtypeColumn`
    /// - `Expr::Wildcard`
    /// - `Expr::Exclude`
    Selector(super::selector::Selector),
}
181
/// A value of type `T` that may be present directly, as raw bytes, or as a
/// named entry that is looked up at runtime.
// NOTE(review): the `Bytes` variant presumably holds the serialized form of `T`
// that is deserialized on demand -- confirm against the serde code.
#[derive(Clone)]
pub enum LazySerde<T: Clone> {
    /// The value itself.
    Deserialized(T),
    /// A raw byte buffer.
    Bytes(Bytes),
    // Used by cloud
    Named {
        // Name and payload are used by the NamedRegistry
        // To load the function `T` at runtime.
        name: String,
        payload: Option<Bytes>,
        // Sometimes we need the function `T` before sending
        // to a different machine, so optionally set it as well.
        value: Option<T>,
    },
}
197
198impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
199    fn eq(&self, other: &Self) -> bool {
200        use LazySerde as L;
201        match (self, other) {
202            (L::Deserialized(a), L::Deserialized(b)) => a == b,
203            (L::Bytes(a), L::Bytes(b)) => {
204                std::ptr::eq(a.as_ptr(), b.as_ptr()) && a.len() == b.len()
205            },
206            (
207                L::Named {
208                    name: l,
209                    payload: pl,
210                    value: _,
211                },
212                L::Named {
213                    name: r,
214                    payload: pr,
215                    value: _,
216                },
217            ) => {
218                #[cfg(debug_assertions)]
219                {
220                    if l == r {
221                        assert_eq!(pl, pr, "name should point to unique payload")
222                    }
223                }
224                _ = pl;
225                _ = pr;
226                l == r
227            },
228            _ => false,
229        }
230    }
231}
232
233impl<T: Clone> Debug for LazySerde<T> {
234    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
235        match self {
236            Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
237            Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
238            Self::Named {
239                name,
240                payload: _,
241                value: _,
242            } => write!(f, "lazy-serde<Named>: {}", name),
243        }
244    }
245}
246
247#[allow(clippy::derived_hash_with_manual_eq)]
248impl Hash for Expr {
249    fn hash<H: Hasher>(&self, state: &mut H) {
250        let d = std::mem::discriminant(self);
251        d.hash(state);
252        match self {
253            Expr::Column(name) => name.hash(state),
254            Expr::Columns(names) => names.hash(state),
255            Expr::DtypeColumn(dtypes) => dtypes.hash(state),
256            Expr::IndexColumn(indices) => indices.hash(state),
257            Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
258            Expr::Selector(s) => s.hash(state),
259            Expr::Nth(v) => v.hash(state),
260            Expr::Filter { input, by } => {
261                input.hash(state);
262                by.hash(state);
263            },
264            Expr::BinaryExpr { left, op, right } => {
265                left.hash(state);
266                right.hash(state);
267                std::mem::discriminant(op).hash(state)
268            },
269            Expr::Cast {
270                expr,
271                dtype,
272                options: strict,
273            } => {
274                expr.hash(state);
275                dtype.hash(state);
276                strict.hash(state)
277            },
278            Expr::Sort { expr, options } => {
279                expr.hash(state);
280                options.hash(state);
281            },
282            Expr::Alias(input, name) => {
283                input.hash(state);
284                name.hash(state)
285            },
286            Expr::KeepName(input) => input.hash(state),
287            Expr::Ternary {
288                predicate,
289                truthy,
290                falsy,
291            } => {
292                predicate.hash(state);
293                truthy.hash(state);
294                falsy.hash(state);
295            },
296            Expr::Function {
297                input,
298                function,
299                options,
300            } => {
301                input.hash(state);
302                std::mem::discriminant(function).hash(state);
303                options.hash(state);
304            },
305            Expr::Gather {
306                expr,
307                idx,
308                returns_scalar,
309            } => {
310                expr.hash(state);
311                idx.hash(state);
312                returns_scalar.hash(state);
313            },
314            // already hashed by discriminant
315            Expr::Wildcard | Expr::Len => {},
316            Expr::SortBy {
317                expr,
318                by,
319                sort_options,
320            } => {
321                expr.hash(state);
322                by.hash(state);
323                sort_options.hash(state);
324            },
325            Expr::Agg(input) => input.hash(state),
326            Expr::Explode { input, skip_empty } => {
327                skip_empty.hash(state);
328                input.hash(state)
329            },
330            Expr::Window {
331                function,
332                partition_by,
333                order_by,
334                options,
335            } => {
336                function.hash(state);
337                partition_by.hash(state);
338                order_by.hash(state);
339                options.hash(state);
340            },
341            Expr::Slice {
342                input,
343                offset,
344                length,
345            } => {
346                input.hash(state);
347                offset.hash(state);
348                length.hash(state);
349            },
350            Expr::Exclude(input, excl) => {
351                input.hash(state);
352                excl.hash(state);
353            },
354            Expr::RenameAlias { function: _, expr } => expr.hash(state),
355            Expr::AnonymousFunction {
356                input,
357                function: _,
358                output_type: _,
359                options,
360            } => {
361                input.hash(state);
362                options.hash(state);
363            },
364            Expr::SubPlan(_, names) => names.hash(state),
365            #[cfg(feature = "dtype-struct")]
366            Expr::Field(names) => names.hash(state),
367        }
368    }
369}
370
// Marker impl: `Expr` derives `PartialEq`; this additionally declares it `Eq`,
// as required alongside the manual `Hash` impl for use as a hash-map key.
impl Eq for Expr {}
372
373impl Default for Expr {
374    fn default() -> Self {
375        Expr::Literal(LiteralValue::Scalar(Scalar::default()))
376    }
377}
378
/// One entry of the exclusion list carried by `Expr::Exclude`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Excluded {
    /// Exclusion by name.
    Name(PlSmallStr),
    /// Exclusion by data type.
    Dtype(DataType),
}
385
386impl Expr {
387    /// Get Field result of the expression. The schema is the input data.
388    pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
389        // this is not called much and the expression depth is typically shallow
390        let mut arena = Arena::with_capacity(5);
391        self.to_field_amortized(schema, ctxt, &mut arena)
392    }
393    pub(crate) fn to_field_amortized(
394        &self,
395        schema: &Schema,
396        ctxt: Context,
397        expr_arena: &mut Arena<AExpr>,
398    ) -> PolarsResult<Field> {
399        let root = to_aexpr(self.clone(), expr_arena)?;
400        expr_arena
401            .get(root)
402            .to_field_and_validate(schema, ctxt, expr_arena)
403    }
404
405    /// Extract a constant usize from an expression.
406    pub fn extract_usize(&self) -> PolarsResult<usize> {
407        match self {
408            Expr::Literal(n) => n.extract_usize(),
409            Expr::Cast { expr, dtype, .. } => {
410                // lit(x, dtype=...) are Cast expressions. We verify the inner expression is literal.
411                if dtype.is_integer() {
412                    expr.extract_usize()
413                } else {
414                    polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
415                }
416            },
417            _ => {
418                polars_bail!(InvalidOperation: "expression must be constant literal to extract integer")
419            },
420        }
421    }
422
423    #[inline]
424    pub fn map_unary(self, function: impl Into<FunctionExpr>) -> Self {
425        Expr::n_ary(function, vec![self])
426    }
427    #[inline]
428    pub fn map_binary(self, function: impl Into<FunctionExpr>, rhs: Self) -> Self {
429        Expr::n_ary(function, vec![self, rhs])
430    }
431
432    #[inline]
433    pub fn map_ternary(self, function: impl Into<FunctionExpr>, arg1: Expr, arg2: Expr) -> Expr {
434        Expr::n_ary(function, vec![self, arg1, arg2])
435    }
436
437    #[inline]
438    pub fn try_map_n_ary(
439        self,
440        function: impl Into<FunctionExpr>,
441        exprs: impl IntoIterator<Item = PolarsResult<Expr>>,
442    ) -> PolarsResult<Expr> {
443        let exprs = exprs.into_iter();
444        let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
445        input.push(self);
446        for e in exprs {
447            input.push(e?);
448        }
449        Ok(Expr::n_ary(function, input))
450    }
451
452    #[inline]
453    pub fn map_n_ary(
454        self,
455        function: impl Into<FunctionExpr>,
456        exprs: impl IntoIterator<Item = Expr>,
457    ) -> Expr {
458        let exprs = exprs.into_iter();
459        let mut input = Vec::with_capacity(exprs.size_hint().0 + 1);
460        input.push(self);
461        input.extend(exprs);
462        Expr::n_ary(function, input)
463    }
464
465    #[inline]
466    pub fn n_ary(function: impl Into<FunctionExpr>, input: Vec<Expr>) -> Expr {
467        let function = function.into();
468        let options = function.function_options();
469        Expr::Function {
470            input,
471            function,
472            options,
473        }
474    }
475}
476
/// Operators usable in `Expr::BinaryExpr`.
///
/// The token named on each variant is the text written by this type's
/// `Display` impl.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    /// Displayed as `==`.
    Eq,
    /// Displayed as `==v`.
    // NOTE(review): presumably an equality that also compares validity (nulls) -- confirm.
    EqValidity,
    /// Displayed as `!=`.
    NotEq,
    /// Displayed as `!=v`.
    // NOTE(review): presumably an inequality that also compares validity (nulls) -- confirm.
    NotEqValidity,
    /// Displayed as `<`.
    Lt,
    /// Displayed as `<=`.
    LtEq,
    /// Displayed as `>`.
    Gt,
    /// Displayed as `>=`.
    GtEq,
    /// Displayed as `+`.
    Plus,
    /// Displayed as `-`.
    Minus,
    /// Displayed as `*`.
    Multiply,
    /// Displayed as `//`.
    // NOTE(review): exact division semantics versus `TrueDivide`/`FloorDivide`
    // are not visible in this file -- confirm.
    Divide,
    /// Displayed as `/`.
    TrueDivide,
    /// Displayed as `floor_div`.
    FloorDivide,
    /// Displayed as `%`.
    Modulus,
    /// Displayed as `&`; classified as bitwise by `is_bitwise`.
    And,
    /// Displayed as `|`; classified as bitwise by `is_bitwise`.
    Or,
    /// Displayed as `^`; classified as bitwise by `is_bitwise`.
    Xor,
    /// Displayed as `&`, like `And`, but not classified as bitwise by `is_bitwise`.
    LogicalAnd,
    /// Displayed as `|`, like `Or`, but not classified as bitwise by `is_bitwise`.
    LogicalOr,
}
501
502impl Display for Operator {
503    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
504        use Operator::*;
505        let tkn = match self {
506            Eq => "==",
507            EqValidity => "==v",
508            NotEq => "!=",
509            NotEqValidity => "!=v",
510            Lt => "<",
511            LtEq => "<=",
512            Gt => ">",
513            GtEq => ">=",
514            Plus => "+",
515            Minus => "-",
516            Multiply => "*",
517            Divide => "//",
518            TrueDivide => "/",
519            FloorDivide => "floor_div",
520            Modulus => "%",
521            And | LogicalAnd => "&",
522            Or | LogicalOr => "|",
523            Xor => "^",
524        };
525        write!(f, "{tkn}")
526    }
527}
528
529impl Operator {
530    pub fn is_comparison(&self) -> bool {
531        matches!(
532            self,
533            Self::Eq
534                | Self::NotEq
535                | Self::Lt
536                | Self::LtEq
537                | Self::Gt
538                | Self::GtEq
539                | Self::EqValidity
540                | Self::NotEqValidity
541        )
542    }
543
544    pub fn is_bitwise(&self) -> bool {
545        matches!(self, Self::And | Self::Or | Self::Xor)
546    }
547
548    pub fn is_comparison_or_bitwise(&self) -> bool {
549        self.is_comparison() || self.is_bitwise()
550    }
551
552    pub fn swap_operands(self) -> Self {
553        match self {
554            Operator::Eq => Operator::Eq,
555            Operator::Gt => Operator::Lt,
556            Operator::GtEq => Operator::LtEq,
557            Operator::LtEq => Operator::GtEq,
558            Operator::Or => Operator::Or,
559            Operator::LogicalAnd => Operator::LogicalAnd,
560            Operator::LogicalOr => Operator::LogicalOr,
561            Operator::Xor => Operator::Xor,
562            Operator::NotEq => Operator::NotEq,
563            Operator::EqValidity => Operator::EqValidity,
564            Operator::NotEqValidity => Operator::NotEqValidity,
565            Operator::Divide => Operator::Multiply,
566            Operator::Multiply => Operator::Divide,
567            Operator::And => Operator::And,
568            Operator::Plus => Operator::Minus,
569            Operator::Minus => Operator::Plus,
570            Operator::Lt => Operator::Gt,
571            _ => unimplemented!(),
572        }
573    }
574
575    pub fn is_arithmetic(&self) -> bool {
576        !(self.is_comparison_or_bitwise())
577    }
578}