datafusion_python/
expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::logical_expr::expr::{AggregateFunctionParams, WindowFunctionParams};
19use datafusion::logical_expr::utils::exprlist_to_fields;
20use datafusion::logical_expr::{
21    ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
22};
23use pyo3::IntoPyObjectExt;
24use pyo3::{basic::CompareOp, prelude::*};
25use std::collections::HashMap;
26use std::convert::{From, Into};
27use std::sync::Arc;
28use window::PyWindowFrame;
29
30use datafusion::arrow::datatypes::{DataType, Field};
31use datafusion::arrow::pyarrow::PyArrowType;
32use datafusion::functions::core::expr_ext::FieldAccessor;
33use datafusion::logical_expr::{
34    col,
35    expr::{AggregateFunction, InList, InSubquery, ScalarFunction, WindowFunction},
36    lit, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast,
37};
38
39use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType};
40use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, PyDataFusionResult};
41use crate::expr::aggregate_expr::PyAggregateFunction;
42use crate::expr::binary_expr::PyBinaryExpr;
43use crate::expr::column::PyColumn;
44use crate::expr::literal::PyLiteral;
45use crate::functions::add_builder_fns_to_window;
46use crate::pyarrow_util::scalar_to_pyarrow;
47use crate::sql::logical::PyLogicalPlan;
48
49use self::alias::PyAlias;
50use self::bool_expr::{
51    PyIsFalse, PyIsNotFalse, PyIsNotNull, PyIsNotTrue, PyIsNotUnknown, PyIsNull, PyIsTrue,
52    PyIsUnknown, PyNegative, PyNot,
53};
54use self::like::{PyILike, PyLike, PySimilarTo};
55use self::scalar_variable::PyScalarVariable;
56
57pub mod aggregate;
58pub mod aggregate_expr;
59pub mod alias;
60pub mod analyze;
61pub mod between;
62pub mod binary_expr;
63pub mod bool_expr;
64pub mod case;
65pub mod cast;
66pub mod column;
67pub mod conditional_expr;
68pub mod copy_to;
69pub mod create_catalog;
70pub mod create_catalog_schema;
71pub mod create_external_table;
72pub mod create_function;
73pub mod create_index;
74pub mod create_memory_table;
75pub mod create_view;
76pub mod describe_table;
77pub mod distinct;
78pub mod dml;
79pub mod drop_catalog_schema;
80pub mod drop_function;
81pub mod drop_table;
82pub mod drop_view;
83pub mod empty_relation;
84pub mod exists;
85pub mod explain;
86pub mod extension;
87pub mod filter;
88pub mod grouping_set;
89pub mod in_list;
90pub mod in_subquery;
91pub mod join;
92pub mod like;
93pub mod limit;
94pub mod literal;
95pub mod logical_node;
96pub mod placeholder;
97pub mod projection;
98pub mod recursive_query;
99pub mod repartition;
100pub mod scalar_subquery;
101pub mod scalar_variable;
102pub mod signature;
103pub mod sort;
104pub mod sort_expr;
105pub mod statement;
106pub mod subquery;
107pub mod subquery_alias;
108pub mod table_scan;
109pub mod union;
110pub mod unnest;
111pub mod unnest_expr;
112pub mod values;
113pub mod window;
114
115use sort_expr::{to_sort_expressions, PySortExpr};
116
117/// A PyExpr that can be used on a DataFrame
118#[pyclass(name = "RawExpr", module = "datafusion.expr", subclass)]
119#[derive(Debug, Clone)]
120pub struct PyExpr {
121    pub expr: Expr,
122}
123
124impl From<PyExpr> for Expr {
125    fn from(expr: PyExpr) -> Expr {
126        expr.expr
127    }
128}
129
130impl From<Expr> for PyExpr {
131    fn from(expr: Expr) -> PyExpr {
132        PyExpr { expr }
133    }
134}
135
136/// Convert a list of DataFusion Expr to PyExpr
137pub fn py_expr_list(expr: &[Expr]) -> PyResult<Vec<PyExpr>> {
138    Ok(expr.iter().map(|e| PyExpr::from(e.clone())).collect())
139}
140
141#[pymethods]
142impl PyExpr {
143    /// Return the specific expression
144    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
145        Python::with_gil(|_| {
146            match &self.expr {
147            Expr::Alias(alias) => Ok(PyAlias::from(alias.clone()).into_bound_py_any(py)?),
148            Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_bound_py_any(py)?),
149            Expr::ScalarVariable(data_type, variables) => {
150                Ok(PyScalarVariable::new(data_type, variables).into_bound_py_any(py)?)
151            }
152            Expr::Like(value) => Ok(PyLike::from(value.clone()).into_bound_py_any(py)?),
153            Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_bound_py_any(py)?),
154            Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_bound_py_any(py)?),
155            Expr::Not(expr) => Ok(PyNot::new(*expr.clone()).into_bound_py_any(py)?),
156            Expr::IsNotNull(expr) => Ok(PyIsNotNull::new(*expr.clone()).into_bound_py_any(py)?),
157            Expr::IsNull(expr) => Ok(PyIsNull::new(*expr.clone()).into_bound_py_any(py)?),
158            Expr::IsTrue(expr) => Ok(PyIsTrue::new(*expr.clone()).into_bound_py_any(py)?),
159            Expr::IsFalse(expr) => Ok(PyIsFalse::new(*expr.clone()).into_bound_py_any(py)?),
160            Expr::IsUnknown(expr) => Ok(PyIsUnknown::new(*expr.clone()).into_bound_py_any(py)?),
161            Expr::IsNotTrue(expr) => Ok(PyIsNotTrue::new(*expr.clone()).into_bound_py_any(py)?),
162            Expr::IsNotFalse(expr) => Ok(PyIsNotFalse::new(*expr.clone()).into_bound_py_any(py)?),
163            Expr::IsNotUnknown(expr) => Ok(PyIsNotUnknown::new(*expr.clone()).into_bound_py_any(py)?),
164            Expr::Negative(expr) => Ok(PyNegative::new(*expr.clone()).into_bound_py_any(py)?),
165            Expr::AggregateFunction(expr) => {
166                Ok(PyAggregateFunction::from(expr.clone()).into_bound_py_any(py)?)
167            }
168            Expr::SimilarTo(value) => Ok(PySimilarTo::from(value.clone()).into_bound_py_any(py)?),
169            Expr::Between(value) => Ok(between::PyBetween::from(value.clone()).into_bound_py_any(py)?),
170            Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_bound_py_any(py)?),
171            Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_bound_py_any(py)?),
172            Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_bound_py_any(py)?),
173            Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
174                "Converting Expr::ScalarFunction to a Python object is not implemented: {:?}",
175                value
176            ))),
177            Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
178                "Converting Expr::WindowFunction to a Python object is not implemented: {:?}",
179                value
180            ))),
181            Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_bound_py_any(py)?),
182            Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_bound_py_any(py)?),
183            Expr::InSubquery(value) => {
184                Ok(in_subquery::PyInSubquery::from(value.clone()).into_bound_py_any(py)?)
185            }
186            Expr::ScalarSubquery(value) => {
187                Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_bound_py_any(py)?)
188            }
189            #[allow(deprecated)]
190            Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!(
191                "Converting Expr::Wildcard to a Python object is not implemented : {:?} {:?}",
192                qualifier, options
193            ))),
194            Expr::GroupingSet(value) => {
195                Ok(grouping_set::PyGroupingSet::from(value.clone()).into_bound_py_any(py)?)
196            }
197            Expr::Placeholder(value) => {
198                Ok(placeholder::PyPlaceholder::from(value.clone()).into_bound_py_any(py)?)
199            }
200            Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!(
201                "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {:?} - {:?}",
202                data_type, column
203            ))),
204            Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_bound_py_any(py)?),
205        }
206        })
207    }
208
209    /// Returns the name of this expression as it should appear in a schema. This name
210    /// will not include any CAST expressions.
211    fn schema_name(&self) -> PyResult<String> {
212        Ok(format!("{}", self.expr.schema_name()))
213    }
214
215    /// Returns a full and complete string representation of this expression.
216    fn canonical_name(&self) -> PyResult<String> {
217        Ok(format!("{}", self.expr))
218    }
219
220    /// Returns the name of the Expr variant.
221    /// Ex: 'IsNotNull', 'Literal', 'BinaryExpr', etc
222    fn variant_name(&self) -> PyResult<&str> {
223        Ok(self.expr.variant_name())
224    }
225
226    fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
227        let expr = match op {
228            CompareOp::Lt => self.expr.clone().lt(other.expr),
229            CompareOp::Le => self.expr.clone().lt_eq(other.expr),
230            CompareOp::Eq => self.expr.clone().eq(other.expr),
231            CompareOp::Ne => self.expr.clone().not_eq(other.expr),
232            CompareOp::Gt => self.expr.clone().gt(other.expr),
233            CompareOp::Ge => self.expr.clone().gt_eq(other.expr),
234        };
235        expr.into()
236    }
237
238    fn __repr__(&self) -> PyResult<String> {
239        Ok(format!("Expr({})", self.expr))
240    }
241
242    fn __add__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
243        Ok((self.expr.clone() + rhs.expr).into())
244    }
245
246    fn __sub__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
247        Ok((self.expr.clone() - rhs.expr).into())
248    }
249
250    fn __truediv__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
251        Ok((self.expr.clone() / rhs.expr).into())
252    }
253
254    fn __mul__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
255        Ok((self.expr.clone() * rhs.expr).into())
256    }
257
258    fn __mod__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
259        let expr = self.expr.clone() % rhs.expr;
260        Ok(expr.into())
261    }
262
263    fn __and__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
264        Ok(self.expr.clone().and(rhs.expr).into())
265    }
266
267    fn __or__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
268        Ok(self.expr.clone().or(rhs.expr).into())
269    }
270
271    fn __invert__(&self) -> PyResult<PyExpr> {
272        let expr = !self.expr.clone();
273        Ok(expr.into())
274    }
275
276    fn __getitem__(&self, key: &str) -> PyResult<PyExpr> {
277        Ok(self.expr.clone().field(key).into())
278    }
279
280    #[staticmethod]
281    pub fn literal(value: PyScalarValue) -> PyExpr {
282        lit(value.0).into()
283    }
284
285    #[staticmethod]
286    pub fn column(value: &str) -> PyExpr {
287        col(value).into()
288    }
289
290    /// assign a name to the PyExpr
291    #[pyo3(signature = (name, metadata=None))]
292    pub fn alias(&self, name: &str, metadata: Option<HashMap<String, String>>) -> PyExpr {
293        self.expr.clone().alias_with_metadata(name, metadata).into()
294    }
295
296    /// Create a sort PyExpr from an existing PyExpr.
297    #[pyo3(signature = (ascending=true, nulls_first=true))]
298    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PySortExpr {
299        self.expr.clone().sort(ascending, nulls_first).into()
300    }
301
302    pub fn is_null(&self) -> PyExpr {
303        self.expr.clone().is_null().into()
304    }
305
306    pub fn is_not_null(&self) -> PyExpr {
307        self.expr.clone().is_not_null().into()
308    }
309
310    pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
311        // self.expr.cast_to() requires DFSchema to validate that the cast
312        // is supported, omit that for now
313        let expr = Expr::Cast(Cast::new(Box::new(self.expr.clone()), to.0));
314        expr.into()
315    }
316
317    #[pyo3(signature = (low, high, negated=false))]
318    pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr {
319        let expr = Expr::Between(Between::new(
320            Box::new(self.expr.clone()),
321            negated,
322            Box::new(low.into()),
323            Box::new(high.into()),
324        ));
325        expr.into()
326    }
327
328    /// A Rex (Row Expression) specifies a single row of data. That specification
329    /// could include user defined functions or types. RexType identifies the row
330    /// as one of the possible valid `RexTypes`.
331    pub fn rex_type(&self) -> PyResult<RexType> {
332        Ok(match self.expr {
333            Expr::Alias(..) => RexType::Alias,
334            Expr::Column(..) => RexType::Reference,
335            Expr::ScalarVariable(..) | Expr::Literal(..) => RexType::Literal,
336            Expr::BinaryExpr { .. }
337            | Expr::Not(..)
338            | Expr::IsNotNull(..)
339            | Expr::Negative(..)
340            | Expr::IsNull(..)
341            | Expr::Like { .. }
342            | Expr::SimilarTo { .. }
343            | Expr::Between { .. }
344            | Expr::Case { .. }
345            | Expr::Cast { .. }
346            | Expr::TryCast { .. }
347            | Expr::ScalarFunction { .. }
348            | Expr::AggregateFunction { .. }
349            | Expr::WindowFunction { .. }
350            | Expr::InList { .. }
351            | Expr::Exists { .. }
352            | Expr::InSubquery { .. }
353            | Expr::GroupingSet(..)
354            | Expr::IsTrue(..)
355            | Expr::IsFalse(..)
356            | Expr::IsUnknown(_)
357            | Expr::IsNotTrue(..)
358            | Expr::IsNotFalse(..)
359            | Expr::Placeholder { .. }
360            | Expr::OuterReferenceColumn(_, _)
361            | Expr::Unnest(_)
362            | Expr::IsNotUnknown(_) => RexType::Call,
363            Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
364            #[allow(deprecated)]
365            Expr::Wildcard { .. } => {
366                return Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
367            }
368        })
369    }
370
371    /// Given the current `Expr` return the DataTypeMap which represents the
372    /// PythonType, Arrow DataType, and SqlType Enum which represents
373    pub fn types(&self) -> PyResult<DataTypeMap> {
374        Self::_types(&self.expr)
375    }
376
377    /// Extracts the Expr value into a PyObject that can be shared with Python
378    pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
379        match &self.expr {
380            Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py),
381            _ => Err(py_type_err(format!(
382                "Non Expr::Literal encountered in types: {:?}",
383                &self.expr
384            ))),
385        }
386    }
387
388    /// Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s),
389    /// store those operands in different datastructures. This function examines the Expr variant and returns
390    /// the operands to the calling logic as a Vec of PyExpr instances.
391    pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
392        match &self.expr {
393            // Expr variants that are themselves the operand to return
394            Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
395                Ok(vec![PyExpr::from(self.expr.clone())])
396            }
397
398            Expr::Alias(alias) => Ok(vec![PyExpr::from(*alias.expr.clone())]),
399
400            // Expr(s) that house the Expr instance to return in their bounded params
401            Expr::Not(expr)
402            | Expr::IsNull(expr)
403            | Expr::IsNotNull(expr)
404            | Expr::IsTrue(expr)
405            | Expr::IsFalse(expr)
406            | Expr::IsUnknown(expr)
407            | Expr::IsNotTrue(expr)
408            | Expr::IsNotFalse(expr)
409            | Expr::IsNotUnknown(expr)
410            | Expr::Negative(expr)
411            | Expr::Cast(Cast { expr, .. })
412            | Expr::TryCast(TryCast { expr, .. })
413            | Expr::InSubquery(InSubquery { expr, .. }) => Ok(vec![PyExpr::from(*expr.clone())]),
414
415            // Expr variants containing a collection of Expr(s) for operands
416            Expr::AggregateFunction(AggregateFunction {
417                params: AggregateFunctionParams { args, .. },
418                ..
419            })
420            | Expr::ScalarFunction(ScalarFunction { args, .. })
421            | Expr::WindowFunction(WindowFunction {
422                params: WindowFunctionParams { args, .. },
423                ..
424            }) => Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect()),
425
426            // Expr(s) that require more specific processing
427            Expr::Case(Case {
428                expr,
429                when_then_expr,
430                else_expr,
431            }) => {
432                let mut operands: Vec<PyExpr> = Vec::new();
433
434                if let Some(e) = expr {
435                    for (when, then) in when_then_expr {
436                        operands.push(PyExpr::from(Expr::BinaryExpr(BinaryExpr::new(
437                            Box::new(*e.clone()),
438                            Operator::Eq,
439                            Box::new(*when.clone()),
440                        ))));
441                        operands.push(PyExpr::from(*then.clone()));
442                    }
443                } else {
444                    for (when, then) in when_then_expr {
445                        operands.push(PyExpr::from(*when.clone()));
446                        operands.push(PyExpr::from(*then.clone()));
447                    }
448                };
449
450                if let Some(e) = else_expr {
451                    operands.push(PyExpr::from(*e.clone()));
452                };
453
454                Ok(operands)
455            }
456            Expr::InList(InList { expr, list, .. }) => {
457                let mut operands: Vec<PyExpr> = vec![PyExpr::from(*expr.clone())];
458                for list_elem in list {
459                    operands.push(PyExpr::from(list_elem.clone()));
460                }
461
462                Ok(operands)
463            }
464            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![
465                PyExpr::from(*left.clone()),
466                PyExpr::from(*right.clone()),
467            ]),
468            Expr::Like(Like { expr, pattern, .. }) => Ok(vec![
469                PyExpr::from(*expr.clone()),
470                PyExpr::from(*pattern.clone()),
471            ]),
472            Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![
473                PyExpr::from(*expr.clone()),
474                PyExpr::from(*pattern.clone()),
475            ]),
476            Expr::Between(Between {
477                expr,
478                negated: _,
479                low,
480                high,
481            }) => Ok(vec![
482                PyExpr::from(*expr.clone()),
483                PyExpr::from(*low.clone()),
484                PyExpr::from(*high.clone()),
485            ]),
486
487            // Currently un-support/implemented Expr types for Rex Call operations
488            Expr::GroupingSet(..)
489            | Expr::Unnest(_)
490            | Expr::OuterReferenceColumn(_, _)
491            | Expr::ScalarSubquery(..)
492            | Expr::Placeholder { .. }
493            | Expr::Exists { .. } => Err(py_runtime_err(format!(
494                "Unimplemented Expr type: {}",
495                self.expr
496            ))),
497
498            #[allow(deprecated)]
499            Expr::Wildcard { .. } => {
500                Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
501            }
502        }
503    }
504
505    /// Extracts the operator associated with a RexType::Call
506    pub fn rex_call_operator(&self) -> PyResult<String> {
507        Ok(match &self.expr {
508            Expr::BinaryExpr(BinaryExpr {
509                left: _,
510                op,
511                right: _,
512            }) => format!("{op}"),
513            Expr::ScalarFunction(ScalarFunction { func, args: _ }) => func.name().to_string(),
514            Expr::Cast { .. } => "cast".to_string(),
515            Expr::Between { .. } => "between".to_string(),
516            Expr::Case { .. } => "case".to_string(),
517            Expr::IsNull(..) => "is null".to_string(),
518            Expr::IsNotNull(..) => "is not null".to_string(),
519            Expr::IsTrue(_) => "is true".to_string(),
520            Expr::IsFalse(_) => "is false".to_string(),
521            Expr::IsUnknown(_) => "is unknown".to_string(),
522            Expr::IsNotTrue(_) => "is not true".to_string(),
523            Expr::IsNotFalse(_) => "is not false".to_string(),
524            Expr::IsNotUnknown(_) => "is not unknown".to_string(),
525            Expr::InList { .. } => "in list".to_string(),
526            Expr::Negative(..) => "negative".to_string(),
527            Expr::Not(..) => "not".to_string(),
528            Expr::Like(Like {
529                negated,
530                case_insensitive,
531                ..
532            }) => {
533                let name = if *case_insensitive { "ilike" } else { "like" };
534                if *negated {
535                    format!("not {name}")
536                } else {
537                    name.to_string()
538                }
539            }
540            Expr::SimilarTo(Like { negated, .. }) => {
541                if *negated {
542                    "not similar to".to_string()
543                } else {
544                    "similar to".to_string()
545                }
546            }
547            _ => {
548                return Err(py_type_err(format!(
549                    "Catch all triggered in get_operator_name: {:?}",
550                    &self.expr
551                )))
552            }
553        })
554    }
555
556    pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
557        self._column_name(&plan.plan()).map_err(py_runtime_err)
558    }
559
560    // Expression Function Builder functions
561
562    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
563        self.expr
564            .clone()
565            .order_by(to_sort_expressions(order_by))
566            .into()
567    }
568
569    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
570        self.expr.clone().filter(filter.expr.clone()).into()
571    }
572
573    pub fn distinct(&self) -> PyExprFuncBuilder {
574        self.expr.clone().distinct().into()
575    }
576
577    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
578        self.expr
579            .clone()
580            .null_treatment(Some(null_treatment.into()))
581            .into()
582    }
583
584    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
585        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
586        self.expr.clone().partition_by(partition_by).into()
587    }
588
589    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
590        self.expr.clone().window_frame(window_frame.into()).into()
591    }
592
593    #[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
594    pub fn over(
595        &self,
596        partition_by: Option<Vec<PyExpr>>,
597        window_frame: Option<PyWindowFrame>,
598        order_by: Option<Vec<PySortExpr>>,
599        null_treatment: Option<NullTreatment>,
600    ) -> PyDataFusionResult<PyExpr> {
601        match &self.expr {
602            Expr::AggregateFunction(agg_fn) => {
603                let window_fn = Expr::WindowFunction(WindowFunction::new(
604                    WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
605                    agg_fn.params.args.clone(),
606                ));
607
608                add_builder_fns_to_window(
609                    window_fn,
610                    partition_by,
611                    window_frame,
612                    order_by,
613                    null_treatment,
614                )
615            }
616            Expr::WindowFunction(_) => add_builder_fns_to_window(
617                self.expr.clone(),
618                partition_by,
619                window_frame,
620                order_by,
621                null_treatment,
622            ),
623            _ => Err(datafusion::error::DataFusionError::Plan(format!(
624                "Using {} with `over` is not allowed. Must use an aggregate or window function.",
625                self.expr.variant_name()
626            ))
627            .into()),
628        }
629    }
630}
631
632#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
633#[derive(Debug, Clone)]
634pub struct PyExprFuncBuilder {
635    pub builder: ExprFuncBuilder,
636}
637
638impl From<ExprFuncBuilder> for PyExprFuncBuilder {
639    fn from(builder: ExprFuncBuilder) -> Self {
640        Self { builder }
641    }
642}
643
644#[pymethods]
645impl PyExprFuncBuilder {
646    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
647        self.builder
648            .clone()
649            .order_by(to_sort_expressions(order_by))
650            .into()
651    }
652
653    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
654        self.builder.clone().filter(filter.expr.clone()).into()
655    }
656
657    pub fn distinct(&self) -> PyExprFuncBuilder {
658        self.builder.clone().distinct().into()
659    }
660
661    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
662        self.builder
663            .clone()
664            .null_treatment(Some(null_treatment.into()))
665            .into()
666    }
667
668    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
669        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
670        self.builder.clone().partition_by(partition_by).into()
671    }
672
673    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
674        self.builder
675            .clone()
676            .window_frame(window_frame.into())
677            .into()
678    }
679
680    pub fn build(&self) -> PyDataFusionResult<PyExpr> {
681        Ok(self.builder.clone().build().map(|expr| expr.into())?)
682    }
683}
684
685impl PyExpr {
686    pub fn _column_name(&self, plan: &LogicalPlan) -> PyDataFusionResult<String> {
687        let field = Self::expr_to_field(&self.expr, plan)?;
688        Ok(field.name().to_owned())
689    }
690
691    /// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against
692    pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> PyDataFusionResult<Arc<Field>> {
693        let fields = exprlist_to_fields(&[expr.clone()], input_plan)?;
694        Ok(fields[0].1.clone())
695    }
696    fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
697        match expr {
698            Expr::BinaryExpr(BinaryExpr {
699                left: _,
700                op,
701                right: _,
702            }) => match op {
703                Operator::Eq
704                | Operator::NotEq
705                | Operator::Lt
706                | Operator::LtEq
707                | Operator::Gt
708                | Operator::GtEq
709                | Operator::And
710                | Operator::Or
711                | Operator::IsDistinctFrom
712                | Operator::IsNotDistinctFrom
713                | Operator::RegexMatch
714                | Operator::RegexIMatch
715                | Operator::RegexNotMatch
716                | Operator::RegexNotIMatch
717                | Operator::LikeMatch
718                | Operator::ILikeMatch
719                | Operator::NotLikeMatch
720                | Operator::NotILikeMatch => DataTypeMap::map_from_arrow_type(&DataType::Boolean),
721                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
722                    DataTypeMap::map_from_arrow_type(&DataType::Int64)
723                }
724                Operator::Divide => DataTypeMap::map_from_arrow_type(&DataType::Float64),
725                Operator::StringConcat => DataTypeMap::map_from_arrow_type(&DataType::Utf8),
726                Operator::BitwiseShiftLeft
727                | Operator::BitwiseShiftRight
728                | Operator::BitwiseXor
729                | Operator::BitwiseAnd
730                | Operator::BitwiseOr => DataTypeMap::map_from_arrow_type(&DataType::Binary),
731                Operator::AtArrow
732                | Operator::ArrowAt
733                | Operator::Arrow
734                | Operator::LongArrow
735                | Operator::HashArrow
736                | Operator::HashLongArrow
737                | Operator::AtAt
738                | Operator::IntegerDivide
739                | Operator::HashMinus
740                | Operator::AtQuestion
741                | Operator::Question
742                | Operator::QuestionAnd
743                | Operator::QuestionPipe => Err(py_type_err(format!("Unsupported expr: ${op}"))),
744            },
745            Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
746            Expr::Literal(scalar_value) => DataTypeMap::map_from_scalar_value(scalar_value),
747            _ => Err(py_type_err(format!(
748                "Non Expr::Literal encountered in types: {:?}",
749                expr
750            ))),
751        }
752    }
753}
754
755/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
756pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
757    m.add_class::<PyExpr>()?;
758    m.add_class::<PyColumn>()?;
759    m.add_class::<PyLiteral>()?;
760    m.add_class::<PyBinaryExpr>()?;
761    m.add_class::<PyLiteral>()?;
762    m.add_class::<PyAggregateFunction>()?;
763    m.add_class::<PyNot>()?;
764    m.add_class::<PyIsNotNull>()?;
765    m.add_class::<PyIsNull>()?;
766    m.add_class::<PyIsTrue>()?;
767    m.add_class::<PyIsFalse>()?;
768    m.add_class::<PyIsUnknown>()?;
769    m.add_class::<PyIsNotTrue>()?;
770    m.add_class::<PyIsNotFalse>()?;
771    m.add_class::<PyIsNotUnknown>()?;
772    m.add_class::<PyNegative>()?;
773    m.add_class::<PyLike>()?;
774    m.add_class::<PyILike>()?;
775    m.add_class::<PySimilarTo>()?;
776    m.add_class::<PyScalarVariable>()?;
777    m.add_class::<alias::PyAlias>()?;
778    m.add_class::<in_list::PyInList>()?;
779    m.add_class::<exists::PyExists>()?;
780    m.add_class::<subquery::PySubquery>()?;
781    m.add_class::<in_subquery::PyInSubquery>()?;
782    m.add_class::<scalar_subquery::PyScalarSubquery>()?;
783    m.add_class::<placeholder::PyPlaceholder>()?;
784    m.add_class::<grouping_set::PyGroupingSet>()?;
785    m.add_class::<case::PyCase>()?;
786    m.add_class::<conditional_expr::PyCaseBuilder>()?;
787    m.add_class::<cast::PyCast>()?;
788    m.add_class::<cast::PyTryCast>()?;
789    m.add_class::<between::PyBetween>()?;
790    m.add_class::<explain::PyExplain>()?;
791    m.add_class::<limit::PyLimit>()?;
792    m.add_class::<aggregate::PyAggregate>()?;
793    m.add_class::<sort::PySort>()?;
794    m.add_class::<analyze::PyAnalyze>()?;
795    m.add_class::<empty_relation::PyEmptyRelation>()?;
796    m.add_class::<join::PyJoin>()?;
797    m.add_class::<join::PyJoinType>()?;
798    m.add_class::<join::PyJoinConstraint>()?;
799    m.add_class::<union::PyUnion>()?;
800    m.add_class::<unnest::PyUnnest>()?;
801    m.add_class::<unnest_expr::PyUnnestExpr>()?;
802    m.add_class::<extension::PyExtension>()?;
803    m.add_class::<filter::PyFilter>()?;
804    m.add_class::<projection::PyProjection>()?;
805    m.add_class::<table_scan::PyTableScan>()?;
806    m.add_class::<create_memory_table::PyCreateMemoryTable>()?;
807    m.add_class::<create_view::PyCreateView>()?;
808    m.add_class::<distinct::PyDistinct>()?;
809    m.add_class::<sort_expr::PySortExpr>()?;
810    m.add_class::<subquery_alias::PySubqueryAlias>()?;
811    m.add_class::<drop_table::PyDropTable>()?;
812    m.add_class::<repartition::PyPartitioning>()?;
813    m.add_class::<repartition::PyRepartition>()?;
814    m.add_class::<window::PyWindowExpr>()?;
815    m.add_class::<window::PyWindowFrame>()?;
816    m.add_class::<window::PyWindowFrameBound>()?;
817    m.add_class::<copy_to::PyCopyTo>()?;
818    m.add_class::<copy_to::PyFileType>()?;
819    m.add_class::<create_catalog::PyCreateCatalog>()?;
820    m.add_class::<create_catalog_schema::PyCreateCatalogSchema>()?;
821    m.add_class::<create_external_table::PyCreateExternalTable>()?;
822    m.add_class::<create_function::PyCreateFunction>()?;
823    m.add_class::<create_function::PyOperateFunctionArg>()?;
824    m.add_class::<create_function::PyCreateFunctionBody>()?;
825    m.add_class::<create_index::PyCreateIndex>()?;
826    m.add_class::<describe_table::PyDescribeTable>()?;
827    m.add_class::<dml::PyDmlStatement>()?;
828    m.add_class::<drop_catalog_schema::PyDropCatalogSchema>()?;
829    m.add_class::<drop_function::PyDropFunction>()?;
830    m.add_class::<drop_view::PyDropView>()?;
831    m.add_class::<recursive_query::PyRecursiveQuery>()?;
832
833    m.add_class::<statement::PyTransactionStart>()?;
834    m.add_class::<statement::PyTransactionEnd>()?;
835    m.add_class::<statement::PySetVariable>()?;
836    m.add_class::<statement::PyPrepare>()?;
837    m.add_class::<statement::PyExecute>()?;
838    m.add_class::<statement::PyDeallocate>()?;
839    m.add_class::<values::PyValues>()?;
840    m.add_class::<statement::PyTransactionAccessMode>()?;
841    m.add_class::<statement::PyTransactionConclusion>()?;
842    m.add_class::<statement::PyTransactionIsolationLevel>()?;
843
844    Ok(())
845}