datafusion_python/
expr.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion::logical_expr::utils::exprlist_to_fields;
19use datafusion::logical_expr::{
20    ExprFuncBuilder, ExprFunctionExt, LogicalPlan, WindowFunctionDefinition,
21};
22use pyo3::{basic::CompareOp, prelude::*};
23use std::convert::{From, Into};
24use std::sync::Arc;
25use window::PyWindowFrame;
26
27use datafusion::arrow::datatypes::{DataType, Field};
28use datafusion::arrow::pyarrow::PyArrowType;
29use datafusion::functions::core::expr_ext::FieldAccessor;
30use datafusion::logical_expr::{
31    col,
32    expr::{AggregateFunction, InList, InSubquery, ScalarFunction, WindowFunction},
33    lit, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast,
34};
35
36use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType};
37use crate::errors::{
38    py_runtime_err, py_type_err, py_unsupported_variant_err, PyDataFusionError, PyDataFusionResult,
39};
40use crate::expr::aggregate_expr::PyAggregateFunction;
41use crate::expr::binary_expr::PyBinaryExpr;
42use crate::expr::column::PyColumn;
43use crate::expr::literal::PyLiteral;
44use crate::functions::add_builder_fns_to_window;
45use crate::pyarrow_util::scalar_to_pyarrow;
46use crate::sql::logical::PyLogicalPlan;
47
48use self::alias::PyAlias;
49use self::bool_expr::{
50    PyIsFalse, PyIsNotFalse, PyIsNotNull, PyIsNotTrue, PyIsNotUnknown, PyIsNull, PyIsTrue,
51    PyIsUnknown, PyNegative, PyNot,
52};
53use self::like::{PyILike, PyLike, PySimilarTo};
54use self::scalar_variable::PyScalarVariable;
55
56pub mod aggregate;
57pub mod aggregate_expr;
58pub mod alias;
59pub mod analyze;
60pub mod between;
61pub mod binary_expr;
62pub mod bool_expr;
63pub mod case;
64pub mod cast;
65pub mod column;
66pub mod conditional_expr;
67pub mod create_memory_table;
68pub mod create_view;
69pub mod distinct;
70pub mod drop_table;
71pub mod empty_relation;
72pub mod exists;
73pub mod explain;
74pub mod extension;
75pub mod filter;
76pub mod grouping_set;
77pub mod in_list;
78pub mod in_subquery;
79pub mod join;
80pub mod like;
81pub mod limit;
82pub mod literal;
83pub mod logical_node;
84pub mod placeholder;
85pub mod projection;
86pub mod repartition;
87pub mod scalar_subquery;
88pub mod scalar_variable;
89pub mod signature;
90pub mod sort;
91pub mod sort_expr;
92pub mod subquery;
93pub mod subquery_alias;
94pub mod table_scan;
95pub mod union;
96pub mod unnest;
97pub mod unnest_expr;
98pub mod window;
99
100use sort_expr::{to_sort_expressions, PySortExpr};
101
/// A PyExpr that can be used on a DataFrame.
///
/// Thin Python-facing wrapper around a DataFusion logical [`Expr`], exposed to Python as
/// `datafusion.expr.Expr`. It converts to and from [`Expr`] through `From` impls.
#[pyclass(name = "Expr", module = "datafusion.expr", subclass)]
#[derive(Debug, Clone)]
pub struct PyExpr {
    /// The wrapped DataFusion logical expression.
    pub expr: Expr,
}
108
109impl From<PyExpr> for Expr {
110    fn from(expr: PyExpr) -> Expr {
111        expr.expr
112    }
113}
114
115impl From<Expr> for PyExpr {
116    fn from(expr: Expr) -> PyExpr {
117        PyExpr { expr }
118    }
119}
120
121/// Convert a list of DataFusion Expr to PyExpr
122pub fn py_expr_list(expr: &[Expr]) -> PyResult<Vec<PyExpr>> {
123    Ok(expr.iter().map(|e| PyExpr::from(e.clone())).collect())
124}
125
126#[pymethods]
127impl PyExpr {
128    /// Return the specific expression
129    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
130        Python::with_gil(|_| {
131            match &self.expr {
132            Expr::Alias(alias) => Ok(PyAlias::from(alias.clone()).into_py(py)),
133            Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
134            Expr::ScalarVariable(data_type, variables) => {
135                Ok(PyScalarVariable::new(data_type, variables).into_py(py))
136            }
137            Expr::Like(value) => Ok(PyLike::from(value.clone()).into_py(py)),
138            Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_py(py)),
139            Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_py(py)),
140            Expr::Not(expr) => Ok(PyNot::new(*expr.clone()).into_py(py)),
141            Expr::IsNotNull(expr) => Ok(PyIsNotNull::new(*expr.clone()).into_py(py)),
142            Expr::IsNull(expr) => Ok(PyIsNull::new(*expr.clone()).into_py(py)),
143            Expr::IsTrue(expr) => Ok(PyIsTrue::new(*expr.clone()).into_py(py)),
144            Expr::IsFalse(expr) => Ok(PyIsFalse::new(*expr.clone()).into_py(py)),
145            Expr::IsUnknown(expr) => Ok(PyIsUnknown::new(*expr.clone()).into_py(py)),
146            Expr::IsNotTrue(expr) => Ok(PyIsNotTrue::new(*expr.clone()).into_py(py)),
147            Expr::IsNotFalse(expr) => Ok(PyIsNotFalse::new(*expr.clone()).into_py(py)),
148            Expr::IsNotUnknown(expr) => Ok(PyIsNotUnknown::new(*expr.clone()).into_py(py)),
149            Expr::Negative(expr) => Ok(PyNegative::new(*expr.clone()).into_py(py)),
150            Expr::AggregateFunction(expr) => {
151                Ok(PyAggregateFunction::from(expr.clone()).into_py(py))
152            }
153            Expr::SimilarTo(value) => Ok(PySimilarTo::from(value.clone()).into_py(py)),
154            Expr::Between(value) => Ok(between::PyBetween::from(value.clone()).into_py(py)),
155            Expr::Case(value) => Ok(case::PyCase::from(value.clone()).into_py(py)),
156            Expr::Cast(value) => Ok(cast::PyCast::from(value.clone()).into_py(py)),
157            Expr::TryCast(value) => Ok(cast::PyTryCast::from(value.clone()).into_py(py)),
158            Expr::ScalarFunction(value) => Err(py_unsupported_variant_err(format!(
159                "Converting Expr::ScalarFunction to a Python object is not implemented: {:?}",
160                value
161            ))),
162            Expr::WindowFunction(value) => Err(py_unsupported_variant_err(format!(
163                "Converting Expr::WindowFunction to a Python object is not implemented: {:?}",
164                value
165            ))),
166            Expr::InList(value) => Ok(in_list::PyInList::from(value.clone()).into_py(py)),
167            Expr::Exists(value) => Ok(exists::PyExists::from(value.clone()).into_py(py)),
168            Expr::InSubquery(value) => {
169                Ok(in_subquery::PyInSubquery::from(value.clone()).into_py(py))
170            }
171            Expr::ScalarSubquery(value) => {
172                Ok(scalar_subquery::PyScalarSubquery::from(value.clone()).into_py(py))
173            }
174            Expr::Wildcard { qualifier, options } => Err(py_unsupported_variant_err(format!(
175                "Converting Expr::Wildcard to a Python object is not implemented : {:?} {:?}",
176                qualifier, options
177            ))),
178            Expr::GroupingSet(value) => {
179                Ok(grouping_set::PyGroupingSet::from(value.clone()).into_py(py))
180            }
181            Expr::Placeholder(value) => {
182                Ok(placeholder::PyPlaceholder::from(value.clone()).into_py(py))
183            }
184            Expr::OuterReferenceColumn(data_type, column) => Err(py_unsupported_variant_err(format!(
185                "Converting Expr::OuterReferenceColumn to a Python object is not implemented: {:?} - {:?}",
186                data_type, column
187            ))),
188            Expr::Unnest(value) => Ok(unnest_expr::PyUnnestExpr::from(value.clone()).into_py(py)),
189        }
190        })
191    }
192
193    /// Returns the name of this expression as it should appear in a schema. This name
194    /// will not include any CAST expressions.
195    fn schema_name(&self) -> PyResult<String> {
196        Ok(format!("{}", self.expr.schema_name()))
197    }
198
199    /// Returns a full and complete string representation of this expression.
200    fn canonical_name(&self) -> PyResult<String> {
201        Ok(format!("{}", self.expr))
202    }
203
204    /// Returns the name of the Expr variant.
205    /// Ex: 'IsNotNull', 'Literal', 'BinaryExpr', etc
206    fn variant_name(&self) -> PyResult<&str> {
207        Ok(self.expr.variant_name())
208    }
209
210    fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
211        let expr = match op {
212            CompareOp::Lt => self.expr.clone().lt(other.expr),
213            CompareOp::Le => self.expr.clone().lt_eq(other.expr),
214            CompareOp::Eq => self.expr.clone().eq(other.expr),
215            CompareOp::Ne => self.expr.clone().not_eq(other.expr),
216            CompareOp::Gt => self.expr.clone().gt(other.expr),
217            CompareOp::Ge => self.expr.clone().gt_eq(other.expr),
218        };
219        expr.into()
220    }
221
222    fn __repr__(&self) -> PyResult<String> {
223        Ok(format!("Expr({})", self.expr))
224    }
225
226    fn __add__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
227        Ok((self.expr.clone() + rhs.expr).into())
228    }
229
230    fn __sub__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
231        Ok((self.expr.clone() - rhs.expr).into())
232    }
233
234    fn __truediv__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
235        Ok((self.expr.clone() / rhs.expr).into())
236    }
237
238    fn __mul__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
239        Ok((self.expr.clone() * rhs.expr).into())
240    }
241
242    fn __mod__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
243        let expr = self.expr.clone() % rhs.expr;
244        Ok(expr.into())
245    }
246
247    fn __and__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
248        Ok(self.expr.clone().and(rhs.expr).into())
249    }
250
251    fn __or__(&self, rhs: PyExpr) -> PyResult<PyExpr> {
252        Ok(self.expr.clone().or(rhs.expr).into())
253    }
254
255    fn __invert__(&self) -> PyResult<PyExpr> {
256        let expr = !self.expr.clone();
257        Ok(expr.into())
258    }
259
260    fn __getitem__(&self, key: &str) -> PyResult<PyExpr> {
261        Ok(self.expr.clone().field(key).into())
262    }
263
264    #[staticmethod]
265    pub fn literal(value: PyScalarValue) -> PyExpr {
266        lit(value.0).into()
267    }
268
269    #[staticmethod]
270    pub fn column(value: &str) -> PyExpr {
271        col(value).into()
272    }
273
274    /// assign a name to the PyExpr
275    pub fn alias(&self, name: &str) -> PyExpr {
276        self.expr.clone().alias(name).into()
277    }
278
279    /// Create a sort PyExpr from an existing PyExpr.
280    #[pyo3(signature = (ascending=true, nulls_first=true))]
281    pub fn sort(&self, ascending: bool, nulls_first: bool) -> PySortExpr {
282        self.expr.clone().sort(ascending, nulls_first).into()
283    }
284
285    pub fn is_null(&self) -> PyExpr {
286        self.expr.clone().is_null().into()
287    }
288
289    pub fn is_not_null(&self) -> PyExpr {
290        self.expr.clone().is_not_null().into()
291    }
292
293    pub fn cast(&self, to: PyArrowType<DataType>) -> PyExpr {
294        // self.expr.cast_to() requires DFSchema to validate that the cast
295        // is supported, omit that for now
296        let expr = Expr::Cast(Cast::new(Box::new(self.expr.clone()), to.0));
297        expr.into()
298    }
299
300    #[pyo3(signature = (low, high, negated=false))]
301    pub fn between(&self, low: PyExpr, high: PyExpr, negated: bool) -> PyExpr {
302        let expr = Expr::Between(Between::new(
303            Box::new(self.expr.clone()),
304            negated,
305            Box::new(low.into()),
306            Box::new(high.into()),
307        ));
308        expr.into()
309    }
310
311    /// A Rex (Row Expression) specifies a single row of data. That specification
312    /// could include user defined functions or types. RexType identifies the row
313    /// as one of the possible valid `RexTypes`.
314    pub fn rex_type(&self) -> PyResult<RexType> {
315        Ok(match self.expr {
316            Expr::Alias(..) => RexType::Alias,
317            Expr::Column(..) => RexType::Reference,
318            Expr::ScalarVariable(..) | Expr::Literal(..) => RexType::Literal,
319            Expr::BinaryExpr { .. }
320            | Expr::Not(..)
321            | Expr::IsNotNull(..)
322            | Expr::Negative(..)
323            | Expr::IsNull(..)
324            | Expr::Like { .. }
325            | Expr::SimilarTo { .. }
326            | Expr::Between { .. }
327            | Expr::Case { .. }
328            | Expr::Cast { .. }
329            | Expr::TryCast { .. }
330            | Expr::ScalarFunction { .. }
331            | Expr::AggregateFunction { .. }
332            | Expr::WindowFunction { .. }
333            | Expr::InList { .. }
334            | Expr::Wildcard { .. }
335            | Expr::Exists { .. }
336            | Expr::InSubquery { .. }
337            | Expr::GroupingSet(..)
338            | Expr::IsTrue(..)
339            | Expr::IsFalse(..)
340            | Expr::IsUnknown(_)
341            | Expr::IsNotTrue(..)
342            | Expr::IsNotFalse(..)
343            | Expr::Placeholder { .. }
344            | Expr::OuterReferenceColumn(_, _)
345            | Expr::Unnest(_)
346            | Expr::IsNotUnknown(_) => RexType::Call,
347            Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
348        })
349    }
350
351    /// Given the current `Expr` return the DataTypeMap which represents the
352    /// PythonType, Arrow DataType, and SqlType Enum which represents
353    pub fn types(&self) -> PyResult<DataTypeMap> {
354        Self::_types(&self.expr)
355    }
356
357    /// Extracts the Expr value into a PyObject that can be shared with Python
358    pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
359        match &self.expr {
360            Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py),
361            _ => Err(py_type_err(format!(
362                "Non Expr::Literal encountered in types: {:?}",
363                &self.expr
364            ))),
365        }
366    }
367
368    /// Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s),
369    /// store those operands in different datastructures. This function examines the Expr variant and returns
370    /// the operands to the calling logic as a Vec of PyExpr instances.
371    pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
372        match &self.expr {
373            // Expr variants that are themselves the operand to return
374            Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
375                Ok(vec![PyExpr::from(self.expr.clone())])
376            }
377
378            Expr::Alias(alias) => Ok(vec![PyExpr::from(*alias.expr.clone())]),
379
380            // Expr(s) that house the Expr instance to return in their bounded params
381            Expr::Not(expr)
382            | Expr::IsNull(expr)
383            | Expr::IsNotNull(expr)
384            | Expr::IsTrue(expr)
385            | Expr::IsFalse(expr)
386            | Expr::IsUnknown(expr)
387            | Expr::IsNotTrue(expr)
388            | Expr::IsNotFalse(expr)
389            | Expr::IsNotUnknown(expr)
390            | Expr::Negative(expr)
391            | Expr::Cast(Cast { expr, .. })
392            | Expr::TryCast(TryCast { expr, .. })
393            | Expr::InSubquery(InSubquery { expr, .. }) => Ok(vec![PyExpr::from(*expr.clone())]),
394
395            // Expr variants containing a collection of Expr(s) for operands
396            Expr::AggregateFunction(AggregateFunction { args, .. })
397            | Expr::ScalarFunction(ScalarFunction { args, .. })
398            | Expr::WindowFunction(WindowFunction { args, .. }) => {
399                Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
400            }
401
402            // Expr(s) that require more specific processing
403            Expr::Case(Case {
404                expr,
405                when_then_expr,
406                else_expr,
407            }) => {
408                let mut operands: Vec<PyExpr> = Vec::new();
409
410                if let Some(e) = expr {
411                    for (when, then) in when_then_expr {
412                        operands.push(PyExpr::from(Expr::BinaryExpr(BinaryExpr::new(
413                            Box::new(*e.clone()),
414                            Operator::Eq,
415                            Box::new(*when.clone()),
416                        ))));
417                        operands.push(PyExpr::from(*then.clone()));
418                    }
419                } else {
420                    for (when, then) in when_then_expr {
421                        operands.push(PyExpr::from(*when.clone()));
422                        operands.push(PyExpr::from(*then.clone()));
423                    }
424                };
425
426                if let Some(e) = else_expr {
427                    operands.push(PyExpr::from(*e.clone()));
428                };
429
430                Ok(operands)
431            }
432            Expr::InList(InList { expr, list, .. }) => {
433                let mut operands: Vec<PyExpr> = vec![PyExpr::from(*expr.clone())];
434                for list_elem in list {
435                    operands.push(PyExpr::from(list_elem.clone()));
436                }
437
438                Ok(operands)
439            }
440            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![
441                PyExpr::from(*left.clone()),
442                PyExpr::from(*right.clone()),
443            ]),
444            Expr::Like(Like { expr, pattern, .. }) => Ok(vec![
445                PyExpr::from(*expr.clone()),
446                PyExpr::from(*pattern.clone()),
447            ]),
448            Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![
449                PyExpr::from(*expr.clone()),
450                PyExpr::from(*pattern.clone()),
451            ]),
452            Expr::Between(Between {
453                expr,
454                negated: _,
455                low,
456                high,
457            }) => Ok(vec![
458                PyExpr::from(*expr.clone()),
459                PyExpr::from(*low.clone()),
460                PyExpr::from(*high.clone()),
461            ]),
462
463            // Currently un-support/implemented Expr types for Rex Call operations
464            Expr::GroupingSet(..)
465            | Expr::Unnest(_)
466            | Expr::OuterReferenceColumn(_, _)
467            | Expr::Wildcard { .. }
468            | Expr::ScalarSubquery(..)
469            | Expr::Placeholder { .. }
470            | Expr::Exists { .. } => Err(py_runtime_err(format!(
471                "Unimplemented Expr type: {}",
472                self.expr
473            ))),
474        }
475    }
476
477    /// Extracts the operator associated with a RexType::Call
478    pub fn rex_call_operator(&self) -> PyResult<String> {
479        Ok(match &self.expr {
480            Expr::BinaryExpr(BinaryExpr {
481                left: _,
482                op,
483                right: _,
484            }) => format!("{op}"),
485            Expr::ScalarFunction(ScalarFunction { func, args: _ }) => func.name().to_string(),
486            Expr::Cast { .. } => "cast".to_string(),
487            Expr::Between { .. } => "between".to_string(),
488            Expr::Case { .. } => "case".to_string(),
489            Expr::IsNull(..) => "is null".to_string(),
490            Expr::IsNotNull(..) => "is not null".to_string(),
491            Expr::IsTrue(_) => "is true".to_string(),
492            Expr::IsFalse(_) => "is false".to_string(),
493            Expr::IsUnknown(_) => "is unknown".to_string(),
494            Expr::IsNotTrue(_) => "is not true".to_string(),
495            Expr::IsNotFalse(_) => "is not false".to_string(),
496            Expr::IsNotUnknown(_) => "is not unknown".to_string(),
497            Expr::InList { .. } => "in list".to_string(),
498            Expr::Negative(..) => "negative".to_string(),
499            Expr::Not(..) => "not".to_string(),
500            Expr::Like(Like {
501                negated,
502                case_insensitive,
503                ..
504            }) => {
505                let name = if *case_insensitive { "ilike" } else { "like" };
506                if *negated {
507                    format!("not {name}")
508                } else {
509                    name.to_string()
510                }
511            }
512            Expr::SimilarTo(Like { negated, .. }) => {
513                if *negated {
514                    "not similar to".to_string()
515                } else {
516                    "similar to".to_string()
517                }
518            }
519            _ => {
520                return Err(py_type_err(format!(
521                    "Catch all triggered in get_operator_name: {:?}",
522                    &self.expr
523                )))
524            }
525        })
526    }
527
528    pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
529        self._column_name(&plan.plan()).map_err(py_runtime_err)
530    }
531
532    // Expression Function Builder functions
533
534    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
535        self.expr
536            .clone()
537            .order_by(to_sort_expressions(order_by))
538            .into()
539    }
540
541    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
542        self.expr.clone().filter(filter.expr.clone()).into()
543    }
544
545    pub fn distinct(&self) -> PyExprFuncBuilder {
546        self.expr.clone().distinct().into()
547    }
548
549    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
550        self.expr
551            .clone()
552            .null_treatment(Some(null_treatment.into()))
553            .into()
554    }
555
556    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
557        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
558        self.expr.clone().partition_by(partition_by).into()
559    }
560
561    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
562        self.expr.clone().window_frame(window_frame.into()).into()
563    }
564
565    #[pyo3(signature = (partition_by=None, window_frame=None, order_by=None, null_treatment=None))]
566    pub fn over(
567        &self,
568        partition_by: Option<Vec<PyExpr>>,
569        window_frame: Option<PyWindowFrame>,
570        order_by: Option<Vec<PySortExpr>>,
571        null_treatment: Option<NullTreatment>,
572    ) -> PyDataFusionResult<PyExpr> {
573        match &self.expr {
574            Expr::AggregateFunction(agg_fn) => {
575                let window_fn = Expr::WindowFunction(WindowFunction::new(
576                    WindowFunctionDefinition::AggregateUDF(agg_fn.func.clone()),
577                    agg_fn.args.clone(),
578                ));
579
580                add_builder_fns_to_window(
581                    window_fn,
582                    partition_by,
583                    window_frame,
584                    order_by,
585                    null_treatment,
586                )
587            }
588            Expr::WindowFunction(_) => add_builder_fns_to_window(
589                self.expr.clone(),
590                partition_by,
591                window_frame,
592                order_by,
593                null_treatment,
594            ),
595            _ => Err(
596                PyDataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
597                    format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name()),
598                ))
599            ),
600        }
601    }
602}
603
/// Python wrapper around DataFusion's [`ExprFuncBuilder`], exposed to Python as
/// `datafusion.expr.ExprFuncBuilder`.
///
/// Returned by the builder-style methods on `Expr` (`order_by`, `filter`, `distinct`,
/// `null_treatment`, `partition_by`, `window_frame`); call `build()` to obtain the
/// finished expression.
#[pyclass(name = "ExprFuncBuilder", module = "datafusion.expr", subclass)]
#[derive(Debug, Clone)]
pub struct PyExprFuncBuilder {
    /// The underlying DataFusion builder. The Python-facing methods clone it before
    /// each chained call, so an existing Python object is never mutated.
    pub builder: ExprFuncBuilder,
}
609
610impl From<ExprFuncBuilder> for PyExprFuncBuilder {
611    fn from(builder: ExprFuncBuilder) -> Self {
612        Self { builder }
613    }
614}
615
616#[pymethods]
617impl PyExprFuncBuilder {
618    pub fn order_by(&self, order_by: Vec<PySortExpr>) -> PyExprFuncBuilder {
619        self.builder
620            .clone()
621            .order_by(to_sort_expressions(order_by))
622            .into()
623    }
624
625    pub fn filter(&self, filter: PyExpr) -> PyExprFuncBuilder {
626        self.builder.clone().filter(filter.expr.clone()).into()
627    }
628
629    pub fn distinct(&self) -> PyExprFuncBuilder {
630        self.builder.clone().distinct().into()
631    }
632
633    pub fn null_treatment(&self, null_treatment: NullTreatment) -> PyExprFuncBuilder {
634        self.builder
635            .clone()
636            .null_treatment(Some(null_treatment.into()))
637            .into()
638    }
639
640    pub fn partition_by(&self, partition_by: Vec<PyExpr>) -> PyExprFuncBuilder {
641        let partition_by = partition_by.iter().map(|e| e.expr.clone()).collect();
642        self.builder.clone().partition_by(partition_by).into()
643    }
644
645    pub fn window_frame(&self, window_frame: PyWindowFrame) -> PyExprFuncBuilder {
646        self.builder
647            .clone()
648            .window_frame(window_frame.into())
649            .into()
650    }
651
652    pub fn build(&self) -> PyDataFusionResult<PyExpr> {
653        Ok(self.builder.clone().build().map(|expr| expr.into())?)
654    }
655}
656
657impl PyExpr {
658    pub fn _column_name(&self, plan: &LogicalPlan) -> PyDataFusionResult<String> {
659        let field = Self::expr_to_field(&self.expr, plan)?;
660        Ok(field.name().to_owned())
661    }
662
663    /// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against
664    pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> PyDataFusionResult<Arc<Field>> {
665        match expr {
666            Expr::Wildcard { .. } => {
667                // Since * could be any of the valid column names just return the first one
668                Ok(Arc::new(input_plan.schema().field(0).clone()))
669            }
670            _ => {
671                let fields = exprlist_to_fields(&[expr.clone()], input_plan)?;
672                Ok(fields[0].1.clone())
673            }
674        }
675    }
676    fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
677        match expr {
678            Expr::BinaryExpr(BinaryExpr {
679                left: _,
680                op,
681                right: _,
682            }) => match op {
683                Operator::Eq
684                | Operator::NotEq
685                | Operator::Lt
686                | Operator::LtEq
687                | Operator::Gt
688                | Operator::GtEq
689                | Operator::And
690                | Operator::Or
691                | Operator::IsDistinctFrom
692                | Operator::IsNotDistinctFrom
693                | Operator::RegexMatch
694                | Operator::RegexIMatch
695                | Operator::RegexNotMatch
696                | Operator::RegexNotIMatch
697                | Operator::LikeMatch
698                | Operator::ILikeMatch
699                | Operator::NotLikeMatch
700                | Operator::NotILikeMatch => DataTypeMap::map_from_arrow_type(&DataType::Boolean),
701                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
702                    DataTypeMap::map_from_arrow_type(&DataType::Int64)
703                }
704                Operator::Divide => DataTypeMap::map_from_arrow_type(&DataType::Float64),
705                Operator::StringConcat => DataTypeMap::map_from_arrow_type(&DataType::Utf8),
706                Operator::BitwiseShiftLeft
707                | Operator::BitwiseShiftRight
708                | Operator::BitwiseXor
709                | Operator::BitwiseAnd
710                | Operator::BitwiseOr => DataTypeMap::map_from_arrow_type(&DataType::Binary),
711                Operator::AtArrow | Operator::ArrowAt => {
712                    Err(py_type_err(format!("Unsupported expr: ${op}")))
713                }
714            },
715            Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
716            Expr::Literal(scalar_value) => DataTypeMap::map_from_scalar_value(scalar_value),
717            _ => Err(py_type_err(format!(
718                "Non Expr::Literal encountered in types: {:?}",
719                expr
720            ))),
721        }
722    }
723}
724
725/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
726pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
727    m.add_class::<PyExpr>()?;
728    m.add_class::<PyColumn>()?;
729    m.add_class::<PyLiteral>()?;
730    m.add_class::<PyBinaryExpr>()?;
731    m.add_class::<PyLiteral>()?;
732    m.add_class::<PyAggregateFunction>()?;
733    m.add_class::<PyNot>()?;
734    m.add_class::<PyIsNotNull>()?;
735    m.add_class::<PyIsNull>()?;
736    m.add_class::<PyIsTrue>()?;
737    m.add_class::<PyIsFalse>()?;
738    m.add_class::<PyIsUnknown>()?;
739    m.add_class::<PyIsNotTrue>()?;
740    m.add_class::<PyIsNotFalse>()?;
741    m.add_class::<PyIsNotUnknown>()?;
742    m.add_class::<PyNegative>()?;
743    m.add_class::<PyLike>()?;
744    m.add_class::<PyILike>()?;
745    m.add_class::<PySimilarTo>()?;
746    m.add_class::<PyScalarVariable>()?;
747    m.add_class::<alias::PyAlias>()?;
748    m.add_class::<in_list::PyInList>()?;
749    m.add_class::<exists::PyExists>()?;
750    m.add_class::<subquery::PySubquery>()?;
751    m.add_class::<in_subquery::PyInSubquery>()?;
752    m.add_class::<scalar_subquery::PyScalarSubquery>()?;
753    m.add_class::<placeholder::PyPlaceholder>()?;
754    m.add_class::<grouping_set::PyGroupingSet>()?;
755    m.add_class::<case::PyCase>()?;
756    m.add_class::<conditional_expr::PyCaseBuilder>()?;
757    m.add_class::<cast::PyCast>()?;
758    m.add_class::<cast::PyTryCast>()?;
759    m.add_class::<between::PyBetween>()?;
760    m.add_class::<explain::PyExplain>()?;
761    m.add_class::<limit::PyLimit>()?;
762    m.add_class::<aggregate::PyAggregate>()?;
763    m.add_class::<sort::PySort>()?;
764    m.add_class::<analyze::PyAnalyze>()?;
765    m.add_class::<empty_relation::PyEmptyRelation>()?;
766    m.add_class::<join::PyJoin>()?;
767    m.add_class::<join::PyJoinType>()?;
768    m.add_class::<join::PyJoinConstraint>()?;
769    m.add_class::<union::PyUnion>()?;
770    m.add_class::<unnest::PyUnnest>()?;
771    m.add_class::<unnest_expr::PyUnnestExpr>()?;
772    m.add_class::<extension::PyExtension>()?;
773    m.add_class::<filter::PyFilter>()?;
774    m.add_class::<projection::PyProjection>()?;
775    m.add_class::<table_scan::PyTableScan>()?;
776    m.add_class::<create_memory_table::PyCreateMemoryTable>()?;
777    m.add_class::<create_view::PyCreateView>()?;
778    m.add_class::<distinct::PyDistinct>()?;
779    m.add_class::<sort_expr::PySortExpr>()?;
780    m.add_class::<subquery_alias::PySubqueryAlias>()?;
781    m.add_class::<drop_table::PyDropTable>()?;
782    m.add_class::<repartition::PyPartitioning>()?;
783    m.add_class::<repartition::PyRepartition>()?;
784    m.add_class::<window::PyWindowExpr>()?;
785    m.add_class::<window::PyWindowFrame>()?;
786    m.add_class::<window::PyWindowFrameBound>()?;
787    Ok(())
788}