Skip to main content

datafusion_python/
expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::convert::{From, Into};
20use std::sync::Arc;
21
22use datafusion::arrow::datatypes::{DataType, Field};
23use datafusion::arrow::pyarrow::PyArrowType;
24use datafusion::functions::core::expr_ext::FieldAccessor;
25use datafusion::logical_expr::expr::{
26    AggregateFunction, AggregateFunctionParams, FieldMetadata, InList, InSubquery, ScalarFunction,
27    SetComparison, WindowFunction,
28};
29use datafusion::logical_expr::utils::exprlist_to_fields;
30use datafusion::logical_expr::{
31    Between, BinaryExpr, Case, Cast, Expr, ExprFuncBuilder, ExprFunctionExt, Like, LogicalPlan,
32    Operator, TryCast, WindowFunctionDefinition, col, lit, lit_with_metadata,
33};
34use pyo3::IntoPyObjectExt;
35use pyo3::basic::CompareOp;
36use pyo3::prelude::*;
37use window::PyWindowFrame;
38
39use self::alias::PyAlias;
40use self::bool_expr::{
41    PyIsFalse, PyIsNotFalse, PyIsNotNull, PyIsNotTrue, PyIsNotUnknown, PyIsNull, PyIsTrue,
42    PyIsUnknown, PyNegative, PyNot,
43};
44use self::like::{PyILike, PyLike, PySimilarTo};
45use self::scalar_variable::PyScalarVariable;
46use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType};
47use crate::errors::{PyDataFusionResult, py_runtime_err, py_type_err, py_unsupported_variant_err};
48use crate::expr::aggregate_expr::PyAggregateFunction;
49use crate::expr::binary_expr::PyBinaryExpr;
50use crate::expr::column::PyColumn;
51use crate::expr::literal::PyLiteral;
52use crate::functions::add_builder_fns_to_window;
53use crate::pyarrow_util::scalar_to_pyarrow;
54use crate::sql::logical::PyLogicalPlan;
55
56pub mod aggregate;
57pub mod aggregate_expr;
58pub mod alias;
59pub mod analyze;
60pub mod between;
61pub mod binary_expr;
62pub mod bool_expr;
63pub mod case;
64pub mod cast;
65pub mod column;
66pub mod conditional_expr;
67pub mod copy_to;
68pub mod create_catalog;
69pub mod create_catalog_schema;
70pub mod create_external_table;
71pub mod create_function;
72pub mod create_index;
73pub mod create_memory_table;
74pub mod create_view;
75pub mod describe_table;
76pub mod distinct;
77pub mod dml;
78pub mod drop_catalog_schema;
79pub mod drop_function;
80pub mod drop_table;
81pub mod drop_view;
82pub mod empty_relation;
83pub mod exists;
84pub mod explain;
85pub mod extension;
86pub mod filter;
87pub mod grouping_set;
88pub mod in_list;
89pub mod in_subquery;
90pub mod join;
91pub mod like;
92pub mod limit;
93pub mod literal;
94pub mod logical_node;
95pub mod placeholder;
96pub mod projection;
97pub mod recursive_query;
98pub mod repartition;
99pub mod scalar_subquery;
100pub mod scalar_variable;
101pub mod set_comparison;
102pub mod signature;
103pub mod sort;
104pub mod sort_expr;
105pub mod statement;
106pub mod subquery;
107pub mod subquery_alias;
108pub mod table_scan;
109pub mod union;
110pub mod unnest;
111pub mod unnest_expr;
112pub mod values;
113pub mod window;
114
115use sort_expr::{PySortExpr, to_sort_expressions};
116
117/// A PyExpr that can be used on a DataFrame
118#[pyclass(
119    from_py_object,
120    frozen,
121    name = "RawExpr",
122    module = "datafusion.expr",
123    subclass
124)]
125#[derive(Debug, Clone)]
126pub struct PyExpr {
127    pub expr: Expr,
128}
129
130impl From<PyExpr> for Expr {
131    fn from(expr: PyExpr) -> Expr {
132        expr.expr
133    }
134}
135
136impl From<Expr> for PyExpr {
137    fn from(expr: Expr) -> PyExpr {
138        PyExpr { expr }
139    }
140}
141
142/// Convert a list of DataFusion Expr to PyExpr
143pub fn py_expr_list(expr: &[Expr]) -> PyResult<Vec<PyExpr>> {
144    Ok(expr.iter().map(|e| PyExpr::from(e.clone())).collect())
145}
146
147#[pymethods]
148impl PyExpr {
149    /// Return the specific expression
150    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
151        Python::attach(|_| match &self.expr {
152            Expr::Alias(alias) => Ok(PyAlias::from(alias.clone()).into_bound_py_any(py)?),
153            Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_bound_py_any(py)?),
154            Expr::ScalarVariable(field, variables) => {
155                Ok(PyScalarVariable::new(field, variables).into_bound_py_any(py)?)
156            }
157            Expr::Like(value) => Ok(PyLike::from(value.clone()).into_bound_py_any(py)?),
158            Expr::Literal(value, metadata) => Ok(PyLiteral::new_with_metadata(
159                value.clone(),
160                metadata.clone(),
161            )
162            .into_bound_py_any(py)?),
163            Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_bound_py_any(py)?),
164            Expr::Not(expr) => Ok(PyNot::new(*expr.clone()).into_bound_py_any(py)?),
165            Expr::IsNotNull(expr) => Ok(PyIsNotNull::new(*expr.clone()).into_bound_py_any(py)?),
166            Expr::IsNull(expr) => Ok(PyIsNull::new(*expr.clone()).into_bound_py_any(py)?),
167            Expr::IsTrue(expr) => Ok(PyIsTrue::new(*expr.clone()).into_bound_py_any(py)?),
168            Expr::IsFalse(expr) => Ok(PyIsFalse::new(*expr.clone()).into_bound_py_any(py)?),
169            Expr::IsUnknown(expr) => Ok(PyIsUnknown::new(*expr.clone()).into_bound_py_any(py)?),
170            Expr::IsNotTrue(expr) => Ok(PyIsNotTrue::new(*expr.clone()).into_bound_py_any(py)?),
171            Expr::IsNotFalse(expr) => Ok(PyIsNotFalse::new(*expr.clone()).into_bound_py_any(py)?),
172            Expr::IsNotUnknown(expr) => {
173                Ok(PyIsNotUnknown::new(*expr.clone()).into_bound_py_any(py)?)
174            }
175            Expr::Negative(expr) => Ok(PyNegative::new(*expr.clone()).into_bound_py_any(py)?),
176            Expr::AggregateFunction(expr) => {
177                Ok(PyAggregateFunction::from(expr.clone()).into_bound_py_any(py)?)
178            }
179            Expr::SimilarTo(value) => Ok(PySimilarTo::from(value.clone()).into_bound_py_any(py)?),
180            Expr::Between(value) => {
181                Ok(between::PyBetween::from(value.clone()).into_bound_py_any(py)?)
182            }
183            Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_bound_py_any(py)?),
184            Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_bound_py_any(py)?),
185            Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_bound_py_any(py)?),
186            Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
187                "Converting Expr::ScalarFunction to a Python object is not implemented: {value:?}"
188            ))),
189            Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
190                "Converting Expr::WindowFunction to a Python object is not implemented: {value:?}"
191            ))),
192            Expr::InList(value) => {
193                Ok(in_list::PyInList::from(value.clone()).into_bound_py_any(py)?)
194            }
195            Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_bound_py_any(py)?),
196            Expr::InSubquery(value) => {
197                Ok(in_subquery::PyInSubquery::from(value.clone()).into_bound_py_any(py)?)
198            }
199            Expr::ScalarSubquery(value) => {
200                Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_bound_py_any(py)?)
201            }
202            #[allow(deprecated)]
203            Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!(
204                "Converting Expr::Wildcard to a Python object is not implemented : {qualifier:?} {options:?}"
205            ))),
206            Expr::GroupingSet(value) => {
207                Ok(grouping_set::PyGroupingSet::from(value.clone()).into_bound_py_any(py)?)
208            }
209            Expr::Placeholder(value) => {
210                Ok(placeholder::PyPlaceholder::from(value.clone()).into_bound_py_any(py)?)
211            }
212            Expr::OuterReferenceColumn(data_type, column) => {
213                Err(py_unsupported_variant_err(format!(
214                    "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {data_type:?} - {column:?}"
215                )))
216            }
217            Expr::Unnest(value) => {
218                Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_bound_py_any(py)?)
219            }
220            Expr::SetComparison(value) => {
221                Ok(set_comparison::PySetComparison::from(value.clone()).into_bound_py_any(py)?)
222            }
223        })
224    }
225
226    /// Returns the name of this expression as it should appear in a schema. This name
227    /// will not include any CAST expressions.
228    fn schema_name(&self) -> PyResult<String> {
229        Ok(format!("{}", self.expr.schema_name()))
230    }
231
232    /// Returns a full and complete string representation of this expression.
233    fn canonical_name(&self) -> PyResult<String> {
234        Ok(format!("{}", self.expr))
235    }
236
237    /// Returns the name of the Expr variant.
238    /// Ex: 'IsNotNull', 'Literal', 'BinaryExpr', etc
239    fn variant_name(&self) -> PyResult<&str> {
240        Ok(self.expr.variant_name())
241    }
242
243    fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
244        let expr = match op {
245            CompareOp::Lt => self.expr.clone().lt(other.expr),
246            CompareOp::Le => self.expr.clone().lt_eq(other.expr),
247            CompareOp::Eq => self.expr.clone().eq(other.expr),
248            CompareOp::Ne => self.expr.clone().not_eq(other.expr),
249            CompareOp::Gt => self.expr.clone().gt(other.expr),
250            CompareOp::Ge => self.expr.clone().gt_eq(other.expr),
251        };
252        expr.into()
253    }
254
255    fn __repr__(&self) -> PyResult<String> {
256        Ok(format!("Expr({})", self.expr))
257    }
258
259    fn __add__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
260        Ok((self.expr.clone() + rhs.expr).into())
261    }
262
263    fn __sub__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
264        Ok((self.expr.clone() - rhs.expr).into())
265    }
266
267    fn __truediv__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
268        Ok((self.expr.clone() / rhs.expr).into())
269    }
270
271    fn __mul__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
272        Ok((self.expr.clone() * rhs.expr).into())
273    }
274
275    fn __mod__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
276        let expr = self.expr.clone() % rhs.expr;
277        Ok(expr.into())
278    }
279
280    fn __and__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
281        Ok(self.expr.clone().and(rhs.expr).into())
282    }
283
284    fn __or__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
285        Ok(self.expr.clone().or(rhs.expr).into())
286    }
287
288    fn __invert__(&self) -> PyResult<PyExpr> {
289        let expr = !self.expr.clone();
290        Ok(expr.into())
291    }
292
293    fn __getitem__(&self, key: &str) -> PyResult<PyExpr> {
294        Ok(self.expr.clone().field(key).into())
295    }
296
297    #[staticmethod]
298    pub fn literal(value: PyScalarValue) -> PyExpr {
299        lit(value.0).into()
300    }
301
302    #[staticmethod]
303    pub fn literal_with_metadata(
304        value: PyScalarValue,
305        metadata: HashMap<String, String>,
306    ) -> PyExpr {
307        let metadata = FieldMetadata::new(metadata.into_iter().collect());
308        lit_with_metadata(value.0, Some(metadata)).into()
309    }
310
311    #[staticmethod]
312    pub fn column(value: &str) -> PyExpr {
313        col(value).into()
314    }
315
316    /// assign a name to the PyExpr
317    #[pyo3(signature = (name, metadata=None))]
318    pub fn alias(&self, name: &str, metadata: Option<HashMap<String, String>>) -> PyExpr {
319        let metadata = metadata.map(|m| FieldMetadata::new(m.into_iter().collect()));
320        self.expr.clone().alias_with_metadata(name, metadata).into()
321    }
322
323    /// Create a sort PyExpr from an existing PyExpr.
324    #[pyo3(signature = (ascending=true, nulls_first=true))]
325    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PySortExpr {
326        self.expr.clone().sort(ascending, nulls_first).into()
327    }
328
329    pub fn is_null(&self) -> PyExpr {
330        self.expr.clone().is_null().into()
331    }
332
333    pub fn is_not_null(&self) -> PyExpr {
334        self.expr.clone().is_not_null().into()
335    }
336
337    pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
338        // self.expr.cast_to() requires DFSchema to validate that the cast
339        // is supported, omit that for now
340        let expr = Expr::Cast(Cast::new(Box::new(self.expr.clone()), to.0));
341        expr.into()
342    }
343
344    #[pyo3(signature = (low, high, negated=false))]
345    pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr {
346        let expr = Expr::Between(Between::new(
347            Box::new(self.expr.clone()),
348            negated,
349            Box::new(low.into()),
350            Box::new(high.into()),
351        ));
352        expr.into()
353    }
354
355    /// A Rex (Row Expression) specifies a single row of data. That specification
356    /// could include user defined functions or types. RexType identifies the row
357    /// as one of the possible valid `RexTypes`.
358    pub fn rex_type(&self) -> PyResult<RexType> {
359        Ok(match self.expr {
360            Expr::Alias(..) => RexType::Alias,
361            Expr::Column(..) => RexType::Reference,
362            Expr::ScalarVariable(..) | Expr::Literal(..) => RexType::Literal,
363            Expr::BinaryExpr { .. }
364            | Expr::Not(..)
365            | Expr::IsNotNull(..)
366            | Expr::Negative(..)
367            | Expr::IsNull(..)
368            | Expr::Like { .. }
369            | Expr::SimilarTo { .. }
370            | Expr::Between { .. }
371            | Expr::Case { .. }
372            | Expr::Cast { .. }
373            | Expr::TryCast { .. }
374            | Expr::ScalarFunction { .. }
375            | Expr::AggregateFunction { .. }
376            | Expr::WindowFunction { .. }
377            | Expr::InList { .. }
378            | Expr::Exists { .. }
379            | Expr::InSubquery { .. }
380            | Expr::GroupingSet(..)
381            | Expr::IsTrue(..)
382            | Expr::IsFalse(..)
383            | Expr::IsUnknown(_)
384            | Expr::IsNotTrue(..)
385            | Expr::IsNotFalse(..)
386            | Expr::Placeholder { .. }
387            | Expr::OuterReferenceColumn(_, _)
388            | Expr::Unnest(_)
389            | Expr::IsNotUnknown(_)
390            | Expr::SetComparison(_) => RexType::Call,
391            Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
392            #[allow(deprecated)]
393            Expr::Wildcard { .. } => {
394                return Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"));
395            }
396        })
397    }
398
399    /// Given the current `Expr` return the DataTypeMap which represents the
400    /// PythonType, Arrow DataType, and SqlType Enum which represents
401    pub fn types(&self) -> PyResult<DataTypeMap> {
402        Self::_types(&self.expr)
403    }
404
405    /// Extracts the Expr value into a Py<PyAny> that can be shared with Python
406    pub fn python_value<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
407        match &self.expr {
408            Expr::Literal(scalar_value, _) => scalar_to_pyarrow(scalar_value, py),
409            _ => Err(py_type_err(format!(
410                "Non Expr::Literal encountered in types: {:?}",
411                &self.expr
412            ))),
413        }
414    }
415
416    /// Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s),
417    /// store those operands in different datastructures. This function examines the Expr variant and returns
418    /// the operands to the calling logic as a Vec of PyExpr instances.
419    pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
420        match &self.expr {
421            // Expr variants that are themselves the operand to return
422            Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
423                Ok(vec![PyExpr::from(self.expr.clone())])
424            }
425
426            Expr::Alias(alias) => Ok(vec![PyExpr::from(*alias.expr.clone())]),
427
428            // Expr(s) that house the Expr instance to return in their bounded params
429            Expr::Not(expr)
430            | Expr::IsNull(expr)
431            | Expr::IsNotNull(expr)
432            | Expr::IsTrue(expr)
433            | Expr::IsFalse(expr)
434            | Expr::IsUnknown(expr)
435            | Expr::IsNotTrue(expr)
436            | Expr::IsNotFalse(expr)
437            | Expr::IsNotUnknown(expr)
438            | Expr::Negative(expr)
439            | Expr::Cast(Cast { expr, .. })
440            | Expr::TryCast(TryCast { expr, .. })
441            | Expr::InSubquery(InSubquery { expr, .. })
442            | Expr::SetComparison(SetComparison { expr, .. }) => {
443                Ok(vec![PyExpr::from(*expr.clone())])
444            }
445
446            // Expr variants containing a collection of Expr(s) for operands
447            Expr::AggregateFunction(AggregateFunction {
448                params: AggregateFunctionParams { args, .. },
449                ..
450            })
451            | Expr::ScalarFunction(ScalarFunction { args, .. }) => {
452                Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
453            }
454            Expr::WindowFunction(boxed_window_fn) => {
455                let args = &boxed_window_fn.params.args;
456                Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
457            }
458
459            // Expr(s) that require more specific processing
460            Expr::Case(Case {
461                expr,
462                when_then_expr,
463                else_expr,
464            }) => {
465                let mut operands: Vec<PyExpr> = Vec::new();
466
467                if let Some(e) = expr {
468                    for (when, then) in when_then_expr {
469                        operands.push(PyExpr::from(Expr::BinaryExpr(BinaryExpr::new(
470                            Box::new(*e.clone()),
471                            Operator::Eq,
472                            Box::new(*when.clone()),
473                        ))));
474                        operands.push(PyExpr::from(*then.clone()));
475                    }
476                } else {
477                    for (when, then) in when_then_expr {
478                        operands.push(PyExpr::from(*when.clone()));
479                        operands.push(PyExpr::from(*then.clone()));
480                    }
481                };
482
483                if let Some(e) = else_expr {
484                    operands.push(PyExpr::from(*e.clone()));
485                };
486
487                Ok(operands)
488            }
489            Expr::InList(InList { expr, list, .. }) => {
490                let mut operands: Vec<PyExpr> = vec![PyExpr::from(*expr.clone())];
491                for list_elem in list {
492                    operands.push(PyExpr::from(list_elem.clone()));
493                }
494
495                Ok(operands)
496            }
497            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![
498                PyExpr::from(*left.clone()),
499                PyExpr::from(*right.clone()),
500            ]),
501            Expr::Like(Like { expr, pattern, .. }) => Ok(vec![
502                PyExpr::from(*expr.clone()),
503                PyExpr::from(*pattern.clone()),
504            ]),
505            Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![
506                PyExpr::from(*expr.clone()),
507                PyExpr::from(*pattern.clone()),
508            ]),
509            Expr::Between(Between {
510                expr,
511                negated: _,
512                low,
513                high,
514            }) => Ok(vec![
515                PyExpr::from(*expr.clone()),
516                PyExpr::from(*low.clone()),
517                PyExpr::from(*high.clone()),
518            ]),
519
520            // Currently un-support/implemented Expr types for Rex Call operations
521            Expr::GroupingSet(..)
522            | Expr::Unnest(_)
523            | Expr::OuterReferenceColumn(_, _)
524            | Expr::ScalarSubquery(..)
525            | Expr::Placeholder { .. }
526            | Expr::Exists { .. } => Err(py_runtime_err(format!(
527                "Unimplemented Expr type: {}",
528                self.expr
529            ))),
530
531            #[allow(deprecated)]
532            Expr::Wildcard { .. } => {
533                Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
534            }
535        }
536    }
537
538    /// Extracts the operator associated with a RexType::Call
539    pub fn rex_call_operator(&self) -> PyResult<String> {
540        Ok(match &self.expr {
541            Expr::BinaryExpr(BinaryExpr {
542                left: _,
543                op,
544                right: _,
545            }) => format!("{op}"),
546            Expr::ScalarFunction(ScalarFunction { func, args: _ }) => func.name().to_string(),
547            Expr::Cast { .. } => "cast".to_string(),
548            Expr::Between { .. } => "between".to_string(),
549            Expr::Case { .. } => "case".to_string(),
550            Expr::IsNull(..) => "is null".to_string(),
551            Expr::IsNotNull(..) => "is not null".to_string(),
552            Expr::IsTrue(_) => "is true".to_string(),
553            Expr::IsFalse(_) => "is false".to_string(),
554            Expr::IsUnknown(_) => "is unknown".to_string(),
555            Expr::IsNotTrue(_) => "is not true".to_string(),
556            Expr::IsNotFalse(_) => "is not false".to_string(),
557            Expr::IsNotUnknown(_) => "is not unknown".to_string(),
558            Expr::InList { .. } => "in list".to_string(),
559            Expr::Negative(..) => "negative".to_string(),
560            Expr::Not(..) => "not".to_string(),
561            Expr::Like(Like {
562                negated,
563                case_insensitive,
564                ..
565            }) => {
566                let name = if *case_insensitive { "ilike" } else { "like" };
567                if *negated {
568                    format!("not {name}")
569                } else {
570                    name.to_string()
571                }
572            }
573            Expr::SimilarTo(Like { negated, .. }) => {
574                if *negated {
575                    "not similar to".to_string()
576                } else {
577                    "similar to".to_string()
578                }
579            }
580            _ => {
581                return Err(py_type_err(format!(
582                    "Catch all triggered in get_operator_name: {:?}",
583                    &self.expr
584                )));
585            }
586        })
587    }
588
589    pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
590        self._column_name(&plan.plan()).map_err(py_runtime_err)
591    }
592
593    // Expression Function Builder functions
594
595    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
596        self.expr
597            .clone()
598            .order_by(to_sort_expressions(order_by))
599            .into()
600    }
601
602    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
603        self.expr.clone().filter(filter.expr.clone()).into()
604    }
605
606    pub fn distinct(&self) -> PyExprFuncBuilder {
607        self.expr.clone().distinct().into()
608    }
609
610    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
611        self.expr
612            .clone()
613            .null_treatment(Some(null_treatment.into()))
614            .into()
615    }
616
617    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
618        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
619        self.expr.clone().partition_by(partition_by).into()
620    }
621
622    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
623        self.expr.clone().window_frame(window_frame.into()).into()
624    }
625
626    #[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
627    pub fn over(
628        &self,
629        partition_by: Option<Vec<PyExpr>>,
630        window_frame: Option<PyWindowFrame>,
631        order_by: Option<Vec<PySortExpr>>,
632        null_treatment: Option<NullTreatment>,
633    ) -> PyDataFusionResult<PyExpr> {
634        match &self.expr {
635            Expr::AggregateFunction(agg_fn) => {
636                let window_fn = Expr::WindowFunction(Box::new(WindowFunction::new(
637                    WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
638                    agg_fn.params.args.clone(),
639                )));
640
641                add_builder_fns_to_window(
642                    window_fn,
643                    partition_by,
644                    window_frame,
645                    order_by,
646                    null_treatment,
647                )
648            }
649            Expr::WindowFunction(_) => add_builder_fns_to_window(
650                self.expr.clone(),
651                partition_by,
652                window_frame,
653                order_by,
654                null_treatment,
655            ),
656            _ => Err(datafusion::error::DataFusionError::Plan(format!(
657                "Using {} with `over` is not allowed. Must use an aggregate or window function.",
658                self.expr.variant_name()
659            ))
660            .into()),
661        }
662    }
663}
664
665#[pyclass(
666    from_py_object,
667    frozen,
668    name = "ExprFuncBuilder",
669    module = "datafusion.expr",
670    subclass
671)]
672#[derive(Debug, Clone)]
673pub struct PyExprFuncBuilder {
674    pub builder: ExprFuncBuilder,
675}
676
677impl From<ExprFuncBuilder> for PyExprFuncBuilder {
678    fn from(builder: ExprFuncBuilder) -> Self {
679        Self { builder }
680    }
681}
682
683#[pymethods]
684impl PyExprFuncBuilder {
685    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
686        self.builder
687            .clone()
688            .order_by(to_sort_expressions(order_by))
689            .into()
690    }
691
692    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
693        self.builder.clone().filter(filter.expr.clone()).into()
694    }
695
696    pub fn distinct(&self) -> PyExprFuncBuilder {
697        self.builder.clone().distinct().into()
698    }
699
700    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
701        self.builder
702            .clone()
703            .null_treatment(Some(null_treatment.into()))
704            .into()
705    }
706
707    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
708        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
709        self.builder.clone().partition_by(partition_by).into()
710    }
711
712    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
713        self.builder
714            .clone()
715            .window_frame(window_frame.into())
716            .into()
717    }
718
719    pub fn build(&self) -> PyDataFusionResult<PyExpr> {
720        Ok(self.builder.clone().build().map(|expr| expr.into())?)
721    }
722}
723
724impl PyExpr {
725    pub fn _column_name(&self, plan: &LogicalPlan) -> PyDataFusionResult<String> {
726        let field = Self::expr_to_field(&self.expr, plan)?;
727        Ok(field.name().to_owned())
728    }
729
730    /// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against
731    pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> PyDataFusionResult<Arc<Field>> {
732        let fields = exprlist_to_fields(std::slice::from_ref(expr), input_plan)?;
733        Ok(fields[0].1.clone())
734    }
735    fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
736        match expr {
737            Expr::BinaryExpr(BinaryExpr {
738                left: _,
739                op,
740                right: _,
741            }) => match op {
742                Operator::Eq
743                | Operator::NotEq
744                | Operator::Lt
745                | Operator::LtEq
746                | Operator::Gt
747                | Operator::GtEq
748                | Operator::And
749                | Operator::Or
750                | Operator::IsDistinctFrom
751                | Operator::IsNotDistinctFrom
752                | Operator::RegexMatch
753                | Operator::RegexIMatch
754                | Operator::RegexNotMatch
755                | Operator::RegexNotIMatch
756                | Operator::LikeMatch
757                | Operator::ILikeMatch
758                | Operator::NotLikeMatch
759                | Operator::NotILikeMatch => DataTypeMap::map_from_arrow_type(&DataType::Boolean),
760                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
761                    DataTypeMap::map_from_arrow_type(&DataType::Int64)
762                }
763                Operator::Divide => DataTypeMap::map_from_arrow_type(&DataType::Float64),
764                Operator::StringConcat => DataTypeMap::map_from_arrow_type(&DataType::Utf8),
765                Operator::BitwiseShiftLeft
766                | Operator::BitwiseShiftRight
767                | Operator::BitwiseXor
768                | Operator::BitwiseAnd
769                | Operator::BitwiseOr => DataTypeMap::map_from_arrow_type(&DataType::Binary),
770                Operator::AtArrow
771                | Operator::ArrowAt
772                | Operator::Arrow
773                | Operator::LongArrow
774                | Operator::HashArrow
775                | Operator::HashLongArrow
776                | Operator::AtAt
777                | Operator::IntegerDivide
778                | Operator::HashMinus
779                | Operator::AtQuestion
780                | Operator::Question
781                | Operator::QuestionAnd
782                | Operator::QuestionPipe
783                | Operator::Colon => Err(py_type_err(format!("Unsupported expr: ${op}"))),
784            },
785            Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
786            Expr::Literal(scalar_value, _) => DataTypeMap::map_from_scalar_value(scalar_value),
787            _ => Err(py_type_err(format!(
788                "Non Expr::Literal encountered in types: {expr:?}"
789            ))),
790        }
791    }
792}
793
794/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
795pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
796    m.add_class::<PyExpr>()?;
797    m.add_class::<PyColumn>()?;
798    m.add_class::<PyLiteral>()?;
799    m.add_class::<PyBinaryExpr>()?;
800    m.add_class::<PyLiteral>()?;
801    m.add_class::<PyAggregateFunction>()?;
802    m.add_class::<PyNot>()?;
803    m.add_class::<PyIsNotNull>()?;
804    m.add_class::<PyIsNull>()?;
805    m.add_class::<PyIsTrue>()?;
806    m.add_class::<PyIsFalse>()?;
807    m.add_class::<PyIsUnknown>()?;
808    m.add_class::<PyIsNotTrue>()?;
809    m.add_class::<PyIsNotFalse>()?;
810    m.add_class::<PyIsNotUnknown>()?;
811    m.add_class::<PyNegative>()?;
812    m.add_class::<PyLike>()?;
813    m.add_class::<PyILike>()?;
814    m.add_class::<PySimilarTo>()?;
815    m.add_class::<PyScalarVariable>()?;
816    m.add_class::<alias::PyAlias>()?;
817    m.add_class::<in_list::PyInList>()?;
818    m.add_class::<exists::PyExists>()?;
819    m.add_class::<subquery::PySubquery>()?;
820    m.add_class::<in_subquery::PyInSubquery>()?;
821    m.add_class::<scalar_subquery::PyScalarSubquery>()?;
822    m.add_class::<placeholder::PyPlaceholder>()?;
823    m.add_class::<grouping_set::PyGroupingSet>()?;
824    m.add_class::<case::PyCase>()?;
825    m.add_class::<conditional_expr::PyCaseBuilder>()?;
826    m.add_class::<cast::PyCast>()?;
827    m.add_class::<cast::PyTryCast>()?;
828    m.add_class::<between::PyBetween>()?;
829    m.add_class::<explain::PyExplain>()?;
830    m.add_class::<limit::PyLimit>()?;
831    m.add_class::<aggregate::PyAggregate>()?;
832    m.add_class::<sort::PySort>()?;
833    m.add_class::<analyze::PyAnalyze>()?;
834    m.add_class::<empty_relation::PyEmptyRelation>()?;
835    m.add_class::<join::PyJoin>()?;
836    m.add_class::<join::PyJoinType>()?;
837    m.add_class::<join::PyJoinConstraint>()?;
838    m.add_class::<union::PyUnion>()?;
839    m.add_class::<unnest::PyUnnest>()?;
840    m.add_class::<unnest_expr::PyUnnestExpr>()?;
841    m.add_class::<extension::PyExtension>()?;
842    m.add_class::<filter::PyFilter>()?;
843    m.add_class::<projection::PyProjection>()?;
844    m.add_class::<table_scan::PyTableScan>()?;
845    m.add_class::<create_memory_table::PyCreateMemoryTable>()?;
846    m.add_class::<create_view::PyCreateView>()?;
847    m.add_class::<distinct::PyDistinct>()?;
848    m.add_class::<sort_expr::PySortExpr>()?;
849    m.add_class::<subquery_alias::PySubqueryAlias>()?;
850    m.add_class::<drop_table::PyDropTable>()?;
851    m.add_class::<repartition::PyPartitioning>()?;
852    m.add_class::<repartition::PyRepartition>()?;
853    m.add_class::<window::PyWindowExpr>()?;
854    m.add_class::<window::PyWindowFrame>()?;
855    m.add_class::<window::PyWindowFrameBound>()?;
856    m.add_class::<copy_to::PyCopyTo>()?;
857    m.add_class::<copy_to::PyFileType>()?;
858    m.add_class::<create_catalog::PyCreateCatalog>()?;
859    m.add_class::<create_catalog_schema::PyCreateCatalogSchema>()?;
860    m.add_class::<create_external_table::PyCreateExternalTable>()?;
861    m.add_class::<create_function::PyCreateFunction>()?;
862    m.add_class::<create_function::PyOperateFunctionArg>()?;
863    m.add_class::<create_function::PyCreateFunctionBody>()?;
864    m.add_class::<create_index::PyCreateIndex>()?;
865    m.add_class::<describe_table::PyDescribeTable>()?;
866    m.add_class::<dml::PyDmlStatement>()?;
867    m.add_class::<drop_catalog_schema::PyDropCatalogSchema>()?;
868    m.add_class::<drop_function::PyDropFunction>()?;
869    m.add_class::<drop_view::PyDropView>()?;
870    m.add_class::<recursive_query::PyRecursiveQuery>()?;
871
872    m.add_class::<statement::PyTransactionStart>()?;
873    m.add_class::<statement::PyTransactionEnd>()?;
874    m.add_class::<statement::PySetVariable>()?;
875    m.add_class::<statement::PyPrepare>()?;
876    m.add_class::<statement::PyExecute>()?;
877    m.add_class::<statement::PyDeallocate>()?;
878    m.add_class::<values::PyValues>()?;
879    m.add_class::<statement::PyTransactionAccessMode>()?;
880    m.add_class::<statement::PyTransactionConclusion>()?;
881    m.add_class::<statement::PyTransactionIsolationLevel>()?;
882
883    Ok(())
884}