datafusion_python/
expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::collections::HashMap;
19use std::convert::{From, Into};
20use std::sync::Arc;
21
22use datafusion::arrow::datatypes::{DataType, Field};
23use datafusion::arrow::pyarrow::PyArrowType;
24use datafusion::functions::core::expr_ext::FieldAccessor;
25use datafusion::logical_expr::expr::{
26    AggregateFunction, AggregateFunctionParams, FieldMetadata, InList, InSubquery, ScalarFunction,
27    WindowFunction,
28};
29use datafusion::logical_expr::utils::exprlist_to_fields;
30use datafusion::logical_expr::{
31    col, lit, lit_with_metadata, Between, BinaryExpr, Case, Cast, Expr, ExprFuncBuilder,
32    ExprFunctionExt, Like, LogicalPlan, Operator, TryCast, WindowFunctionDefinition,
33};
34use pyo3::basic::CompareOp;
35use pyo3::prelude::*;
36use pyo3::IntoPyObjectExt;
37use window::PyWindowFrame;
38
39use self::alias::PyAlias;
40use self::bool_expr::{
41    PyIsFalse, PyIsNotFalse, PyIsNotNull, PyIsNotTrue, PyIsNotUnknown, PyIsNull, PyIsTrue,
42    PyIsUnknown, PyNegative, PyNot,
43};
44use self::like::{PyILike, PyLike, PySimilarTo};
45use self::scalar_variable::PyScalarVariable;
46use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType};
47use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, PyDataFusionResult};
48use crate::expr::aggregate_expr::PyAggregateFunction;
49use crate::expr::binary_expr::PyBinaryExpr;
50use crate::expr::column::PyColumn;
51use crate::expr::literal::PyLiteral;
52use crate::functions::add_builder_fns_to_window;
53use crate::pyarrow_util::scalar_to_pyarrow;
54use crate::sql::logical::PyLogicalPlan;
55
56pub mod aggregate;
57pub mod aggregate_expr;
58pub mod alias;
59pub mod analyze;
60pub mod between;
61pub mod binary_expr;
62pub mod bool_expr;
63pub mod case;
64pub mod cast;
65pub mod column;
66pub mod conditional_expr;
67pub mod copy_to;
68pub mod create_catalog;
69pub mod create_catalog_schema;
70pub mod create_external_table;
71pub mod create_function;
72pub mod create_index;
73pub mod create_memory_table;
74pub mod create_view;
75pub mod describe_table;
76pub mod distinct;
77pub mod dml;
78pub mod drop_catalog_schema;
79pub mod drop_function;
80pub mod drop_table;
81pub mod drop_view;
82pub mod empty_relation;
83pub mod exists;
84pub mod explain;
85pub mod extension;
86pub mod filter;
87pub mod grouping_set;
88pub mod in_list;
89pub mod in_subquery;
90pub mod join;
91pub mod like;
92pub mod limit;
93pub mod literal;
94pub mod logical_node;
95pub mod placeholder;
96pub mod projection;
97pub mod recursive_query;
98pub mod repartition;
99pub mod scalar_subquery;
100pub mod scalar_variable;
101pub mod signature;
102pub mod sort;
103pub mod sort_expr;
104pub mod statement;
105pub mod subquery;
106pub mod subquery_alias;
107pub mod table_scan;
108pub mod union;
109pub mod unnest;
110pub mod unnest_expr;
111pub mod values;
112pub mod window;
113
114use sort_expr::{to_sort_expressions, PySortExpr};
115
116/// A PyExpr that can be used on a DataFrame
117#[pyclass(frozen, name = "RawExpr", module = "datafusion.expr", subclass)]
118#[derive(Debug, Clone)]
119pub struct PyExpr {
120    pub expr: Expr,
121}
122
123impl From<PyExpr> for Expr {
124    fn from(expr: PyExpr) -> Expr {
125        expr.expr
126    }
127}
128
129impl From<Expr> for PyExpr {
130    fn from(expr: Expr) -> PyExpr {
131        PyExpr { expr }
132    }
133}
134
135/// Convert a list of DataFusion Expr to PyExpr
136pub fn py_expr_list(expr: &[Expr]) -> PyResult<Vec<PyExpr>> {
137    Ok(expr.iter().map(|e| PyExpr::from(e.clone())).collect())
138}
139
140#[pymethods]
141impl PyExpr {
142    /// Return the specific expression
143    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
144        Python::attach(|_| {
145            match &self.expr {
146            Expr::Alias(alias) => Ok(PyAlias::from(alias.clone()).into_bound_py_any(py)?),
147            Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_bound_py_any(py)?),
148            Expr::ScalarVariable(data_type, variables) => {
149                Ok(PyScalarVariable::new(data_type, variables).into_bound_py_any(py)?)
150            }
151            Expr::Like(value) => Ok(PyLike::from(value.clone()).into_bound_py_any(py)?),
152            Expr::Literal(value, metadata) => Ok(PyLiteral::new_with_metadata(value.clone(), metadata.clone()).into_bound_py_any(py)?),
153            Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_bound_py_any(py)?),
154            Expr::Not(expr) => Ok(PyNot::new(*expr.clone()).into_bound_py_any(py)?),
155            Expr::IsNotNull(expr) => Ok(PyIsNotNull::new(*expr.clone()).into_bound_py_any(py)?),
156            Expr::IsNull(expr) => Ok(PyIsNull::new(*expr.clone()).into_bound_py_any(py)?),
157            Expr::IsTrue(expr) => Ok(PyIsTrue::new(*expr.clone()).into_bound_py_any(py)?),
158            Expr::IsFalse(expr) => Ok(PyIsFalse::new(*expr.clone()).into_bound_py_any(py)?),
159            Expr::IsUnknown(expr) => Ok(PyIsUnknown::new(*expr.clone()).into_bound_py_any(py)?),
160            Expr::IsNotTrue(expr) => Ok(PyIsNotTrue::new(*expr.clone()).into_bound_py_any(py)?),
161            Expr::IsNotFalse(expr) => Ok(PyIsNotFalse::new(*expr.clone()).into_bound_py_any(py)?),
162            Expr::IsNotUnknown(expr) => Ok(PyIsNotUnknown::new(*expr.clone()).into_bound_py_any(py)?),
163            Expr::Negative(expr) => Ok(PyNegative::new(*expr.clone()).into_bound_py_any(py)?),
164            Expr::AggregateFunction(expr) => {
165                Ok(PyAggregateFunction::from(expr.clone()).into_bound_py_any(py)?)
166            }
167            Expr::SimilarTo(value) => Ok(PySimilarTo::from(value.clone()).into_bound_py_any(py)?),
168            Expr::Between(value) => Ok(between::PyBetween::from(value.clone()).into_bound_py_any(py)?),
169            Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_bound_py_any(py)?),
170            Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_bound_py_any(py)?),
171            Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_bound_py_any(py)?),
172            Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
173                "Converting Expr::ScalarFunction to a Python object is not implemented: {value:?}"
174            ))),
175            Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
176                "Converting Expr::WindowFunction to a Python object is not implemented: {value:?}"
177            ))),
178            Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_bound_py_any(py)?),
179            Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_bound_py_any(py)?),
180            Expr::InSubquery(value) => {
181                Ok(in_subquery::PyInSubquery::from(value.clone()).into_bound_py_any(py)?)
182            }
183            Expr::ScalarSubquery(value) => {
184                Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_bound_py_any(py)?)
185            }
186            #[allow(deprecated)]
187            Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!(
188                "Converting Expr::Wildcard to a Python object is not implemented : {qualifier:?} {options:?}"
189            ))),
190            Expr::GroupingSet(value) => {
191                Ok(grouping_set::PyGroupingSet::from(value.clone()).into_bound_py_any(py)?)
192            }
193            Expr::Placeholder(value) => {
194                Ok(placeholder::PyPlaceholder::from(value.clone()).into_bound_py_any(py)?)
195            }
196            Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!(
197                "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {data_type:?} - {column:?}"
198            ))),
199            Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_bound_py_any(py)?),
200        }
201        })
202    }
203
204    /// Returns the name of this expression as it should appear in a schema. This name
205    /// will not include any CAST expressions.
206    fn schema_name(&self) -> PyResult<String> {
207        Ok(format!("{}", self.expr.schema_name()))
208    }
209
210    /// Returns a full and complete string representation of this expression.
211    fn canonical_name(&self) -> PyResult<String> {
212        Ok(format!("{}", self.expr))
213    }
214
215    /// Returns the name of the Expr variant.
216    /// Ex: 'IsNotNull', 'Literal', 'BinaryExpr', etc
217    fn variant_name(&self) -> PyResult<&str> {
218        Ok(self.expr.variant_name())
219    }
220
221    fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
222        let expr = match op {
223            CompareOp::Lt => self.expr.clone().lt(other.expr),
224            CompareOp::Le => self.expr.clone().lt_eq(other.expr),
225            CompareOp::Eq => self.expr.clone().eq(other.expr),
226            CompareOp::Ne => self.expr.clone().not_eq(other.expr),
227            CompareOp::Gt => self.expr.clone().gt(other.expr),
228            CompareOp::Ge => self.expr.clone().gt_eq(other.expr),
229        };
230        expr.into()
231    }
232
233    fn __repr__(&self) -> PyResult<String> {
234        Ok(format!("Expr({})", self.expr))
235    }
236
237    fn __add__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
238        Ok((self.expr.clone() + rhs.expr).into())
239    }
240
241    fn __sub__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
242        Ok((self.expr.clone() - rhs.expr).into())
243    }
244
245    fn __truediv__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
246        Ok((self.expr.clone() / rhs.expr).into())
247    }
248
249    fn __mul__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
250        Ok((self.expr.clone() * rhs.expr).into())
251    }
252
253    fn __mod__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
254        let expr = self.expr.clone() % rhs.expr;
255        Ok(expr.into())
256    }
257
258    fn __and__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
259        Ok(self.expr.clone().and(rhs.expr).into())
260    }
261
262    fn __or__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
263        Ok(self.expr.clone().or(rhs.expr).into())
264    }
265
266    fn __invert__(&self) -> PyResult<PyExpr> {
267        let expr = !self.expr.clone();
268        Ok(expr.into())
269    }
270
271    fn __getitem__(&self, key: &str) -> PyResult<PyExpr> {
272        Ok(self.expr.clone().field(key).into())
273    }
274
275    #[staticmethod]
276    pub fn literal(value: PyScalarValue) -> PyExpr {
277        lit(value.0).into()
278    }
279
280    #[staticmethod]
281    pub fn literal_with_metadata(
282        value: PyScalarValue,
283        metadata: HashMap<String, String>,
284    ) -> PyExpr {
285        let metadata = FieldMetadata::new(metadata.into_iter().collect());
286        lit_with_metadata(value.0, Some(metadata)).into()
287    }
288
289    #[staticmethod]
290    pub fn column(value: &str) -> PyExpr {
291        col(value).into()
292    }
293
294    /// assign a name to the PyExpr
295    #[pyo3(signature = (name, metadata=None))]
296    pub fn alias(&self, name: &str, metadata: Option<HashMap<String, String>>) -> PyExpr {
297        let metadata = metadata.map(|m| FieldMetadata::new(m.into_iter().collect()));
298        self.expr.clone().alias_with_metadata(name, metadata).into()
299    }
300
301    /// Create a sort PyExpr from an existing PyExpr.
302    #[pyo3(signature = (ascending=true, nulls_first=true))]
303    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PySortExpr {
304        self.expr.clone().sort(ascending, nulls_first).into()
305    }
306
307    pub fn is_null(&self) -> PyExpr {
308        self.expr.clone().is_null().into()
309    }
310
311    pub fn is_not_null(&self) -> PyExpr {
312        self.expr.clone().is_not_null().into()
313    }
314
315    pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
316        // self.expr.cast_to() requires DFSchema to validate that the cast
317        // is supported, omit that for now
318        let expr = Expr::Cast(Cast::new(Box::new(self.expr.clone()), to.0));
319        expr.into()
320    }
321
322    #[pyo3(signature = (low, high, negated=false))]
323    pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr {
324        let expr = Expr::Between(Between::new(
325            Box::new(self.expr.clone()),
326            negated,
327            Box::new(low.into()),
328            Box::new(high.into()),
329        ));
330        expr.into()
331    }
332
333    /// A Rex (Row Expression) specifies a single row of data. That specification
334    /// could include user defined functions or types. RexType identifies the row
335    /// as one of the possible valid `RexTypes`.
336    pub fn rex_type(&self) -> PyResult<RexType> {
337        Ok(match self.expr {
338            Expr::Alias(..) => RexType::Alias,
339            Expr::Column(..) => RexType::Reference,
340            Expr::ScalarVariable(..) | Expr::Literal(..) => RexType::Literal,
341            Expr::BinaryExpr { .. }
342            | Expr::Not(..)
343            | Expr::IsNotNull(..)
344            | Expr::Negative(..)
345            | Expr::IsNull(..)
346            | Expr::Like { .. }
347            | Expr::SimilarTo { .. }
348            | Expr::Between { .. }
349            | Expr::Case { .. }
350            | Expr::Cast { .. }
351            | Expr::TryCast { .. }
352            | Expr::ScalarFunction { .. }
353            | Expr::AggregateFunction { .. }
354            | Expr::WindowFunction { .. }
355            | Expr::InList { .. }
356            | Expr::Exists { .. }
357            | Expr::InSubquery { .. }
358            | Expr::GroupingSet(..)
359            | Expr::IsTrue(..)
360            | Expr::IsFalse(..)
361            | Expr::IsUnknown(_)
362            | Expr::IsNotTrue(..)
363            | Expr::IsNotFalse(..)
364            | Expr::Placeholder { .. }
365            | Expr::OuterReferenceColumn(_, _)
366            | Expr::Unnest(_)
367            | Expr::IsNotUnknown(_) => RexType::Call,
368            Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
369            #[allow(deprecated)]
370            Expr::Wildcard { .. } => {
371                return Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
372            }
373        })
374    }
375
376    /// Given the current `Expr` return the DataTypeMap which represents the
377    /// PythonType, Arrow DataType, and SqlType Enum which represents
378    pub fn types(&self) -> PyResult<DataTypeMap> {
379        Self::_types(&self.expr)
380    }
381
382    /// Extracts the Expr value into a Py<PyAny> that can be shared with Python
383    pub fn python_value<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
384        match &self.expr {
385            Expr::Literal(scalar_value, _) => scalar_to_pyarrow(scalar_value, py),
386            _ => Err(py_type_err(format!(
387                "Non Expr::Literal encountered in types: {:?}",
388                &self.expr
389            ))),
390        }
391    }
392
393    /// Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s),
394    /// store those operands in different datastructures. This function examines the Expr variant and returns
395    /// the operands to the calling logic as a Vec of PyExpr instances.
396    pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
397        match &self.expr {
398            // Expr variants that are themselves the operand to return
399            Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
400                Ok(vec![PyExpr::from(self.expr.clone())])
401            }
402
403            Expr::Alias(alias) => Ok(vec![PyExpr::from(*alias.expr.clone())]),
404
405            // Expr(s) that house the Expr instance to return in their bounded params
406            Expr::Not(expr)
407            | Expr::IsNull(expr)
408            | Expr::IsNotNull(expr)
409            | Expr::IsTrue(expr)
410            | Expr::IsFalse(expr)
411            | Expr::IsUnknown(expr)
412            | Expr::IsNotTrue(expr)
413            | Expr::IsNotFalse(expr)
414            | Expr::IsNotUnknown(expr)
415            | Expr::Negative(expr)
416            | Expr::Cast(Cast { expr, .. })
417            | Expr::TryCast(TryCast { expr, .. })
418            | Expr::InSubquery(InSubquery { expr, .. }) => Ok(vec![PyExpr::from(*expr.clone())]),
419
420            // Expr variants containing a collection of Expr(s) for operands
421            Expr::AggregateFunction(AggregateFunction {
422                params: AggregateFunctionParams { args, .. },
423                ..
424            })
425            | Expr::ScalarFunction(ScalarFunction { args, .. }) => {
426                Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
427            }
428            Expr::WindowFunction(boxed_window_fn) => {
429                let args = &boxed_window_fn.params.args;
430                Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
431            }
432
433            // Expr(s) that require more specific processing
434            Expr::Case(Case {
435                expr,
436                when_then_expr,
437                else_expr,
438            }) => {
439                let mut operands: Vec<PyExpr> = Vec::new();
440
441                if let Some(e) = expr {
442                    for (when, then) in when_then_expr {
443                        operands.push(PyExpr::from(Expr::BinaryExpr(BinaryExpr::new(
444                            Box::new(*e.clone()),
445                            Operator::Eq,
446                            Box::new(*when.clone()),
447                        ))));
448                        operands.push(PyExpr::from(*then.clone()));
449                    }
450                } else {
451                    for (when, then) in when_then_expr {
452                        operands.push(PyExpr::from(*when.clone()));
453                        operands.push(PyExpr::from(*then.clone()));
454                    }
455                };
456
457                if let Some(e) = else_expr {
458                    operands.push(PyExpr::from(*e.clone()));
459                };
460
461                Ok(operands)
462            }
463            Expr::InList(InList { expr, list, .. }) => {
464                let mut operands: Vec<PyExpr> = vec![PyExpr::from(*expr.clone())];
465                for list_elem in list {
466                    operands.push(PyExpr::from(list_elem.clone()));
467                }
468
469                Ok(operands)
470            }
471            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![
472                PyExpr::from(*left.clone()),
473                PyExpr::from(*right.clone()),
474            ]),
475            Expr::Like(Like { expr, pattern, .. }) => Ok(vec![
476                PyExpr::from(*expr.clone()),
477                PyExpr::from(*pattern.clone()),
478            ]),
479            Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![
480                PyExpr::from(*expr.clone()),
481                PyExpr::from(*pattern.clone()),
482            ]),
483            Expr::Between(Between {
484                expr,
485                negated: _,
486                low,
487                high,
488            }) => Ok(vec![
489                PyExpr::from(*expr.clone()),
490                PyExpr::from(*low.clone()),
491                PyExpr::from(*high.clone()),
492            ]),
493
494            // Currently un-support/implemented Expr types for Rex Call operations
495            Expr::GroupingSet(..)
496            | Expr::Unnest(_)
497            | Expr::OuterReferenceColumn(_, _)
498            | Expr::ScalarSubquery(..)
499            | Expr::Placeholder { .. }
500            | Expr::Exists { .. } => Err(py_runtime_err(format!(
501                "Unimplemented Expr type: {}",
502                self.expr
503            ))),
504
505            #[allow(deprecated)]
506            Expr::Wildcard { .. } => {
507                Err(py_unsupported_variant_err("Expr::Wildcard is unsupported"))
508            }
509        }
510    }
511
512    /// Extracts the operator associated with a RexType::Call
513    pub fn rex_call_operator(&self) -> PyResult<String> {
514        Ok(match &self.expr {
515            Expr::BinaryExpr(BinaryExpr {
516                left: _,
517                op,
518                right: _,
519            }) => format!("{op}"),
520            Expr::ScalarFunction(ScalarFunction { func, args: _ }) => func.name().to_string(),
521            Expr::Cast { .. } => "cast".to_string(),
522            Expr::Between { .. } => "between".to_string(),
523            Expr::Case { .. } => "case".to_string(),
524            Expr::IsNull(..) => "is null".to_string(),
525            Expr::IsNotNull(..) => "is not null".to_string(),
526            Expr::IsTrue(_) => "is true".to_string(),
527            Expr::IsFalse(_) => "is false".to_string(),
528            Expr::IsUnknown(_) => "is unknown".to_string(),
529            Expr::IsNotTrue(_) => "is not true".to_string(),
530            Expr::IsNotFalse(_) => "is not false".to_string(),
531            Expr::IsNotUnknown(_) => "is not unknown".to_string(),
532            Expr::InList { .. } => "in list".to_string(),
533            Expr::Negative(..) => "negative".to_string(),
534            Expr::Not(..) => "not".to_string(),
535            Expr::Like(Like {
536                negated,
537                case_insensitive,
538                ..
539            }) => {
540                let name = if *case_insensitive { "ilike" } else { "like" };
541                if *negated {
542                    format!("not {name}")
543                } else {
544                    name.to_string()
545                }
546            }
547            Expr::SimilarTo(Like { negated, .. }) => {
548                if *negated {
549                    "not similar to".to_string()
550                } else {
551                    "similar to".to_string()
552                }
553            }
554            _ => {
555                return Err(py_type_err(format!(
556                    "Catch all triggered in get_operator_name: {:?}",
557                    &self.expr
558                )))
559            }
560        })
561    }
562
563    pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
564        self._column_name(&plan.plan()).map_err(py_runtime_err)
565    }
566
567    // Expression Function Builder functions
568
569    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
570        self.expr
571            .clone()
572            .order_by(to_sort_expressions(order_by))
573            .into()
574    }
575
576    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
577        self.expr.clone().filter(filter.expr.clone()).into()
578    }
579
580    pub fn distinct(&self) -> PyExprFuncBuilder {
581        self.expr.clone().distinct().into()
582    }
583
584    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
585        self.expr
586            .clone()
587            .null_treatment(Some(null_treatment.into()))
588            .into()
589    }
590
591    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
592        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
593        self.expr.clone().partition_by(partition_by).into()
594    }
595
596    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
597        self.expr.clone().window_frame(window_frame.into()).into()
598    }
599
600    #[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
601    pub fn over(
602        &self,
603        partition_by: Option<Vec<PyExpr>>,
604        window_frame: Option<PyWindowFrame>,
605        order_by: Option<Vec<PySortExpr>>,
606        null_treatment: Option<NullTreatment>,
607    ) -> PyDataFusionResult<PyExpr> {
608        match &self.expr {
609            Expr::AggregateFunction(agg_fn) => {
610                let window_fn = Expr::WindowFunction(Box::new(WindowFunction::new(
611                    WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
612                    agg_fn.params.args.clone(),
613                )));
614
615                add_builder_fns_to_window(
616                    window_fn,
617                    partition_by,
618                    window_frame,
619                    order_by,
620                    null_treatment,
621                )
622            }
623            Expr::WindowFunction(_) => add_builder_fns_to_window(
624                self.expr.clone(),
625                partition_by,
626                window_frame,
627                order_by,
628                null_treatment,
629            ),
630            _ => Err(datafusion::error::DataFusionError::Plan(format!(
631                "Using {} with `over` is not allowed. Must use an aggregate or window function.",
632                self.expr.variant_name()
633            ))
634            .into()),
635        }
636    }
637}
638
639#[pyclass(frozen, name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
640#[derive(Debug, Clone)]
641pub struct PyExprFuncBuilder {
642    pub builder: ExprFuncBuilder,
643}
644
645impl From<ExprFuncBuilder> for PyExprFuncBuilder {
646    fn from(builder: ExprFuncBuilder) -> Self {
647        Self { builder }
648    }
649}
650
651#[pymethods]
652impl PyExprFuncBuilder {
653    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
654        self.builder
655            .clone()
656            .order_by(to_sort_expressions(order_by))
657            .into()
658    }
659
660    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
661        self.builder.clone().filter(filter.expr.clone()).into()
662    }
663
664    pub fn distinct(&self) -> PyExprFuncBuilder {
665        self.builder.clone().distinct().into()
666    }
667
668    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
669        self.builder
670            .clone()
671            .null_treatment(Some(null_treatment.into()))
672            .into()
673    }
674
675    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
676        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
677        self.builder.clone().partition_by(partition_by).into()
678    }
679
680    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
681        self.builder
682            .clone()
683            .window_frame(window_frame.into())
684            .into()
685    }
686
687    pub fn build(&self) -> PyDataFusionResult<PyExpr> {
688        Ok(self.builder.clone().build().map(|expr| expr.into())?)
689    }
690}
691
692impl PyExpr {
693    pub fn _column_name(&self, plan: &LogicalPlan) -> PyDataFusionResult<String> {
694        let field = Self::expr_to_field(&self.expr, plan)?;
695        Ok(field.name().to_owned())
696    }
697
698    /// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against
699    pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> PyDataFusionResult<Arc<Field>> {
700        let fields = exprlist_to_fields(std::slice::from_ref(expr), input_plan)?;
701        Ok(fields[0].1.clone())
702    }
703    fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
704        match expr {
705            Expr::BinaryExpr(BinaryExpr {
706                left: _,
707                op,
708                right: _,
709            }) => match op {
710                Operator::Eq
711                | Operator::NotEq
712                | Operator::Lt
713                | Operator::LtEq
714                | Operator::Gt
715                | Operator::GtEq
716                | Operator::And
717                | Operator::Or
718                | Operator::IsDistinctFrom
719                | Operator::IsNotDistinctFrom
720                | Operator::RegexMatch
721                | Operator::RegexIMatch
722                | Operator::RegexNotMatch
723                | Operator::RegexNotIMatch
724                | Operator::LikeMatch
725                | Operator::ILikeMatch
726                | Operator::NotLikeMatch
727                | Operator::NotILikeMatch => DataTypeMap::map_from_arrow_type(&DataType::Boolean),
728                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
729                    DataTypeMap::map_from_arrow_type(&DataType::Int64)
730                }
731                Operator::Divide => DataTypeMap::map_from_arrow_type(&DataType::Float64),
732                Operator::StringConcat => DataTypeMap::map_from_arrow_type(&DataType::Utf8),
733                Operator::BitwiseShiftLeft
734                | Operator::BitwiseShiftRight
735                | Operator::BitwiseXor
736                | Operator::BitwiseAnd
737                | Operator::BitwiseOr => DataTypeMap::map_from_arrow_type(&DataType::Binary),
738                Operator::AtArrow
739                | Operator::ArrowAt
740                | Operator::Arrow
741                | Operator::LongArrow
742                | Operator::HashArrow
743                | Operator::HashLongArrow
744                | Operator::AtAt
745                | Operator::IntegerDivide
746                | Operator::HashMinus
747                | Operator::AtQuestion
748                | Operator::Question
749                | Operator::QuestionAnd
750                | Operator::QuestionPipe => Err(py_type_err(format!("Unsupported expr: ${op}"))),
751            },
752            Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
753            Expr::Literal(scalar_value, _) => DataTypeMap::map_from_scalar_value(scalar_value),
754            _ => Err(py_type_err(format!(
755                "Non Expr::Literal encountered in types: {expr:?}"
756            ))),
757        }
758    }
759}
760
761/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
762pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
763    m.add_class::<PyExpr>()?;
764    m.add_class::<PyColumn>()?;
765    m.add_class::<PyLiteral>()?;
766    m.add_class::<PyBinaryExpr>()?;
767    m.add_class::<PyLiteral>()?;
768    m.add_class::<PyAggregateFunction>()?;
769    m.add_class::<PyNot>()?;
770    m.add_class::<PyIsNotNull>()?;
771    m.add_class::<PyIsNull>()?;
772    m.add_class::<PyIsTrue>()?;
773    m.add_class::<PyIsFalse>()?;
774    m.add_class::<PyIsUnknown>()?;
775    m.add_class::<PyIsNotTrue>()?;
776    m.add_class::<PyIsNotFalse>()?;
777    m.add_class::<PyIsNotUnknown>()?;
778    m.add_class::<PyNegative>()?;
779    m.add_class::<PyLike>()?;
780    m.add_class::<PyILike>()?;
781    m.add_class::<PySimilarTo>()?;
782    m.add_class::<PyScalarVariable>()?;
783    m.add_class::<alias::PyAlias>()?;
784    m.add_class::<in_list::PyInList>()?;
785    m.add_class::<exists::PyExists>()?;
786    m.add_class::<subquery::PySubquery>()?;
787    m.add_class::<in_subquery::PyInSubquery>()?;
788    m.add_class::<scalar_subquery::PyScalarSubquery>()?;
789    m.add_class::<placeholder::PyPlaceholder>()?;
790    m.add_class::<grouping_set::PyGroupingSet>()?;
791    m.add_class::<case::PyCase>()?;
792    m.add_class::<conditional_expr::PyCaseBuilder>()?;
793    m.add_class::<cast::PyCast>()?;
794    m.add_class::<cast::PyTryCast>()?;
795    m.add_class::<between::PyBetween>()?;
796    m.add_class::<explain::PyExplain>()?;
797    m.add_class::<limit::PyLimit>()?;
798    m.add_class::<aggregate::PyAggregate>()?;
799    m.add_class::<sort::PySort>()?;
800    m.add_class::<analyze::PyAnalyze>()?;
801    m.add_class::<empty_relation::PyEmptyRelation>()?;
802    m.add_class::<join::PyJoin>()?;
803    m.add_class::<join::PyJoinType>()?;
804    m.add_class::<join::PyJoinConstraint>()?;
805    m.add_class::<union::PyUnion>()?;
806    m.add_class::<unnest::PyUnnest>()?;
807    m.add_class::<unnest_expr::PyUnnestExpr>()?;
808    m.add_class::<extension::PyExtension>()?;
809    m.add_class::<filter::PyFilter>()?;
810    m.add_class::<projection::PyProjection>()?;
811    m.add_class::<table_scan::PyTableScan>()?;
812    m.add_class::<create_memory_table::PyCreateMemoryTable>()?;
813    m.add_class::<create_view::PyCreateView>()?;
814    m.add_class::<distinct::PyDistinct>()?;
815    m.add_class::<sort_expr::PySortExpr>()?;
816    m.add_class::<subquery_alias::PySubqueryAlias>()?;
817    m.add_class::<drop_table::PyDropTable>()?;
818    m.add_class::<repartition::PyPartitioning>()?;
819    m.add_class::<repartition::PyRepartition>()?;
820    m.add_class::<window::PyWindowExpr>()?;
821    m.add_class::<window::PyWindowFrame>()?;
822    m.add_class::<window::PyWindowFrameBound>()?;
823    m.add_class::<copy_to::PyCopyTo>()?;
824    m.add_class::<copy_to::PyFileType>()?;
825    m.add_class::<create_catalog::PyCreateCatalog>()?;
826    m.add_class::<create_catalog_schema::PyCreateCatalogSchema>()?;
827    m.add_class::<create_external_table::PyCreateExternalTable>()?;
828    m.add_class::<create_function::PyCreateFunction>()?;
829    m.add_class::<create_function::PyOperateFunctionArg>()?;
830    m.add_class::<create_function::PyCreateFunctionBody>()?;
831    m.add_class::<create_index::PyCreateIndex>()?;
832    m.add_class::<describe_table::PyDescribeTable>()?;
833    m.add_class::<dml::PyDmlStatement>()?;
834    m.add_class::<drop_catalog_schema::PyDropCatalogSchema>()?;
835    m.add_class::<drop_function::PyDropFunction>()?;
836    m.add_class::<drop_view::PyDropView>()?;
837    m.add_class::<recursive_query::PyRecursiveQuery>()?;
838
839    m.add_class::<statement::PyTransactionStart>()?;
840    m.add_class::<statement::PyTransactionEnd>()?;
841    m.add_class::<statement::PySetVariable>()?;
842    m.add_class::<statement::PyPrepare>()?;
843    m.add_class::<statement::PyExecute>()?;
844    m.add_class::<statement::PyDeallocate>()?;
845    m.add_class::<values::PyValues>()?;
846    m.add_class::<statement::PyTransactionAccessMode>()?;
847    m.add_class::<statement::PyTransactionConclusion>()?;
848    m.add_class::<statement::PyTransactionIsolationLevel>()?;
849
850    Ok(())
851}