datafusion-python 53.0.0

Apache DataFusion DataFrame and SQL Query Engine
Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::convert::TryFrom;
use std::result::Result;

use datafusion::common::{Column, ScalarValue};
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{Between, BinaryExpr, Expr, Operator};
/// Converts a Datafusion logical plan expression (Expr) into a PyArrow compute expression
use pyo3::{IntoPyObjectExt, prelude::*};

use crate::errors::{PyDataFusionError, PyDataFusionResult};
use crate::pyarrow_util::scalar_to_pyarrow;

/// A PyArrow expression object derived from a DataFusion filter [`Expr`].
///
/// This is a thin newtype over the owned Python object; see the
/// `TryFrom<&Expr>` implementation in this file for how it is built.
#[derive(Debug)]
#[repr(transparent)]
pub(crate) struct PyArrowFilterExpression(Py<PyAny>);

/// Resolves a DataFusion binary [`Operator`] to the matching callable of a Python module.
///
/// `op` is the module the callable is looked up on; the caller in this file passes
/// Python's built-in `operator` module.
///
/// # Errors
///
/// Returns [`PyDataFusionError::Common`] for any operator other than the comparison
/// operators and `And`/`Or`, and propagates a Python error if the attribute lookup fails.
fn operator_to_py<'py>(
    operator: &Operator,
    op: &Bound<'py, PyModule>,
) -> PyDataFusionResult<Bound<'py, PyAny>> {
    // Pick the attribute name first so that the lookup happens in exactly one place.
    let attr_name = match operator {
        Operator::Eq => "eq",
        Operator::NotEq => "ne",
        Operator::Lt => "lt",
        Operator::LtEq => "le",
        Operator::Gt => "gt",
        Operator::GtEq => "ge",
        // Trailing underscores: `and` / `or` are reserved words in Python.
        Operator::And => "and_",
        Operator::Or => "or_",
        _ => {
            return Err(PyDataFusionError::Common(format!(
                "Unsupported operator {operator:?}"
            )));
        }
    };
    Ok(op.getattr(attr_name)?)
}

/// Converts a slice of literal expressions into plain Python objects.
///
/// Only non-null boolean, integer, floating-point and `Utf8` literals are accepted.
/// Conversion stops at the first element that cannot be handled.
///
/// # Errors
///
/// Returns [`PyDataFusionError::Common`] when an element is not an `Expr::Literal` or
/// holds a scalar variant outside the supported set, and propagates any Python
/// conversion error.
fn extract_scalar_list<'py>(
    exprs: &[Expr],
    py: Python<'py>,
) -> PyDataFusionResult<Vec<Bound<'py, PyAny>>> {
    let mut py_values = Vec::with_capacity(exprs.len());

    for expr in exprs {
        let Expr::Literal(value, _) = expr else {
            return Err(PyDataFusionError::Common(format!(
                "Only a list of Literals are supported got {expr:?}"
            )));
        };

        // TODO: should we also leverage `ScalarValue::to_pyarrow` here?
        let py_value = match value {
            ScalarValue::Boolean(Some(b)) => b.into_bound_py_any(py)?,
            ScalarValue::Int8(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::Int16(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::Int32(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::Int64(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::UInt8(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::UInt16(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::UInt32(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::UInt64(Some(i)) => i.into_bound_py_any(py)?,
            ScalarValue::Float32(Some(f)) => f.into_bound_py_any(py)?,
            ScalarValue::Float64(Some(f)) => f.into_bound_py_any(py)?,
            ScalarValue::Utf8(Some(s)) => s.into_bound_py_any(py)?,
            // Null literals and every other scalar type are rejected.
            _ => {
                return Err(PyDataFusionError::Common(format!(
                    "PyArrow can't handle ScalarValue: {value:?}"
                )));
            }
        };
        py_values.push(py_value);
    }

    Ok(py_values)
}

impl PyArrowFilterExpression {
    pub fn inner(&self) -> &Py<PyAny> {
        &self.0
    }
}

impl TryFrom<&Expr> for PyArrowFilterExpression {
    type Error = PyDataFusionError;

    // Converts a Datafusion filter Expr into an expression string that can be evaluated by Python
    // Note that pyarrow.compute.{field,scalar} are put into Python globals() when evaluated
    // isin, is_null, and is_valid (~is_null) are methods of pyarrow.dataset.Expression
    // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow-dataset-expression
    fn try_from(expr: &Expr) -> Result<Self, Self::Error> {
        Python::attach(|py| {
            let pc = Python::import(py, "pyarrow.compute")?;
            let op_module = Python::import(py, "operator")?;
            let pc_expr: PyDataFusionResult<Bound<'_, PyAny>> = match expr {
                Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
                Expr::Literal(scalar, _) => Ok(scalar_to_pyarrow(scalar, py)?),
                Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                    let operator = operator_to_py(op, &op_module)?;
                    let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
                    let right = PyArrowFilterExpression::try_from(right.as_ref())?.0;
                    Ok(operator.call1((left, right))?)
                }
                Expr::Not(expr) => {
                    let operator = op_module.getattr("invert")?;
                    let py_expr = PyArrowFilterExpression::try_from(expr.as_ref())?.0;
                    Ok(operator.call1((py_expr,))?)
                }
                Expr::IsNotNull(expr) => {
                    let py_expr = PyArrowFilterExpression::try_from(expr.as_ref())?
                        .0
                        .into_bound(py);
                    Ok(py_expr.call_method0("is_valid")?)
                }
                Expr::IsNull(expr) => {
                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
                        .0
                        .into_bound(py);

                    // https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow.dataset.Expression.is_null
                    // Whether floating-point NaNs are considered null.
                    let nan_is_null = false;

                    let res = expr.call_method1("is_null", (nan_is_null,))?;
                    Ok(res)
                }
                Expr::Between(Between {
                    expr,
                    negated,
                    low,
                    high,
                }) => {
                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?.0;
                    let low = PyArrowFilterExpression::try_from(low.as_ref())?.0;
                    let high = PyArrowFilterExpression::try_from(high.as_ref())?.0;
                    let and = op_module.getattr("and_")?;
                    let le = op_module.getattr("le")?;
                    let invert = op_module.getattr("invert")?;

                    // scalar <= field() returns a boolean expression so we need to use and to combine these
                    let ret = and.call1((
                        le.call1((low, expr.clone_ref(py)))?,
                        le.call1((expr, high))?,
                    ))?;

                    Ok(if *negated { invert.call1((ret,))? } else { ret })
                }
                Expr::InList(InList {
                    expr,
                    list,
                    negated,
                }) => {
                    let expr = PyArrowFilterExpression::try_from(expr.as_ref())?
                        .0
                        .into_bound(py);
                    let scalars = extract_scalar_list(list, py)?;
                    let ret = expr.call_method1("isin", (scalars,))?;
                    let invert = op_module.getattr("invert")?;

                    Ok(if *negated { invert.call1((ret,))? } else { ret })
                }
                _ => Err(PyDataFusionError::Common(format!(
                    "Unsupported Datafusion expression {expr:?}"
                ))),
            };
            Ok(PyArrowFilterExpression(pc_expr?.into()))
        })
    }
}