datafusion-python 53.0.0

Apache DataFusion DataFrame and SQL Query Engine
Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;
use std::hash::{Hash, Hasher};
use std::ptr::NonNull;
use std::sync::Arc;

use arrow::datatypes::{Field, FieldRef};
use arrow::pyarrow::ToPyArrow;
use datafusion::arrow::array::{ArrayData, make_array};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType};
use datafusion::common::internal_err;
use datafusion::error::DataFusionError;
use datafusion::logical_expr::{
    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
    Volatility,
};
use datafusion_ffi::udf::FFI_ScalarUDF;
use datafusion_python_util::parse_volatility;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};

use crate::array::PyArrowArrayExportable;
use crate::errors::{PyDataFusionResult, to_datafusion_err};
use crate::expr::PyExpr;

/// This struct holds the Python written function that is a
/// ScalarUDF.
///
/// It adapts a Python callable to DataFusion's `ScalarUDFImpl`: on each
/// invocation the arguments are handed to the callable as PyArrow arrays and
/// the callable's result is converted back into an Arrow array.
#[derive(Debug)]
struct PythonFunctionScalarUDF {
    // Name under which the UDF is registered and reported.
    name: String,
    // The Python callable invoked from `invoke_with_args`.
    func: Py<PyAny>,
    // Exact-match signature derived from the declared input field types.
    signature: Signature,
    // Return field fixed at construction; returned for every call site.
    return_field: FieldRef,
}

impl PythonFunctionScalarUDF {
    fn new(
        name: String,
        func: Py<PyAny>,
        input_fields: Vec<Field>,
        return_field: Field,
        volatility: Volatility,
    ) -> Self {
        let input_types = input_fields.iter().map(|f| f.data_type().clone()).collect();
        let signature = Signature::exact(input_types, volatility);
        Self {
            name,
            func,
            signature,
            return_field: Arc::new(return_field),
        }
    }
}

impl Eq for PythonFunctionScalarUDF {}
impl PartialEq for PythonFunctionScalarUDF {
    /// Two UDFs are equal when their name, signature and return field match
    /// and their Python callables compare equal.
    ///
    /// The same Python object is always considered equal to itself. Other
    /// callables are compared with Python `==`, and an exception raised by
    /// `__eq__` is treated as "not equal".
    fn eq(&self, other: &Self) -> bool {
        if self.name != other.name
            || self.signature != other.signature
            || self.return_field != other.return_field
        {
            return false;
        }

        // Identity fast path: keeps `Eq` reflexive even when the callable's
        // `__eq__` raises or returns a falsy value, and avoids attaching to the
        // interpreter when both sides wrap the very same object.
        if self.func.as_ptr() == other.func.as_ptr() {
            return true;
        }

        Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
    }
}

impl Hash for PythonFunctionScalarUDF {
    /// Feeds the name, signature and return field into `state`, followed by
    /// Python's own hash of the callable.
    ///
    /// Callables whose `__hash__` fails (e.g. unhashable objects) contribute a
    /// hash of `0` instead of propagating the error.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let Self {
            name,
            func,
            signature,
            return_field,
        } = self;

        name.hash(state);
        signature.hash(state);
        return_field.hash(state);

        let callable_hash = Python::attach(|py| func.bind(py).hash().unwrap_or(0));
        state.write_isize(callable_hash);
    }
}

impl ScalarUDFImpl for PythonFunctionScalarUDF {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        &self.name
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Not supported: this UDF overrides `return_field_from_args`, so
    /// `return_type` is not expected to be reached; if it is, an internal
    /// error is returned.
    fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
        internal_err!(
            "return_type should not be called when return_field_from_args is implemented."
        )
    }

    /// Returns the return field supplied at construction, independent of the
    /// call-site arguments.
    fn return_field_from_args(
        &self,
        _args: ReturnFieldArgs,
    ) -> datafusion::common::Result<FieldRef> {
        Ok(Arc::clone(&self.return_field))
    }

    /// Evaluates the Python callable on one set of arguments.
    ///
    /// Each argument is materialized as an array of `number_rows` rows,
    /// exported to PyArrow together with its field, and passed positionally
    /// to the callable. The callable's result is converted back with
    /// `ArrayData::from_pyarrow_bound`. Failures from the Python call or from
    /// that conversion are reported as `DataFusionError::Execution` carrying
    /// the error's `Debug` text.
    fn invoke_with_args(
        &self,
        args: ScalarFunctionArgs,
    ) -> datafusion::common::Result<ColumnarValue> {
        let num_rows = args.number_rows;
        Python::attach(|py| {
            // 1. cast args to Pyarrow arrays (scalar arguments are expanded to
            //    `num_rows` rows by `to_array`).
            let py_args = args
                .args
                .into_iter()
                .zip(args.arg_fields)
                .map(|(arg, field)| {
                    let array = arg.to_array(num_rows)?;
                    PyArrowArrayExportable::new(array, field)
                        .to_pyarrow(py)
                        .map_err(to_datafusion_err)
                })
                .collect::<Result<Vec<_>, _>>()?;
            let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

            // 2. call function
            let value = self
                .func
                .call(py, py_args, None)
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

            // 3. cast to arrow::array::Array
            let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
            Ok(ColumnarValue::Array(make_array(array_data)))
        })
    }
}

/// Represents a PyScalarUDF
///
/// Python-facing wrapper around a DataFusion `ScalarUDF`, exposed to Python as
/// `datafusion.ScalarUDF`. It can be built from a Python callable or imported
/// from an object providing `__datafusion_scalar_udf__`.
#[pyclass(
    from_py_object,
    frozen,
    name = "ScalarUDF",
    module = "datafusion",
    subclass
)]
#[derive(Debug, Clone)]
pub struct PyScalarUDF {
    // The wrapped DataFusion UDF.
    pub(crate) function: ScalarUDF,
}

#[pymethods]
impl PyScalarUDF {
    /// Wraps the Python callable `func` as a scalar UDF named `name`.
    ///
    /// `input_types` fixes the exact argument signature, `return_type` is the
    /// field reported for every call, and `volatility` is parsed with
    /// `parse_volatility`, whose error (if any) is propagated to Python.
    #[new]
    #[pyo3(signature=(name, func, input_types, return_type, volatility))]
    fn new(
        name: String,
        func: Py<PyAny>,
        input_types: PyArrowType<Vec<Field>>,
        return_type: PyArrowType<Field>,
        volatility: &str,
    ) -> PyResult<Self> {
        let py_function = PythonFunctionScalarUDF::new(
            name,
            func,
            input_types.0,
            return_type.0,
            parse_volatility(volatility)?,
        );
        let function = ScalarUDF::new_from_impl(py_function);

        Ok(Self { function })
    }

    /// Imports a scalar UDF exported by another library through the
    /// `__datafusion_scalar_udf__` protocol.
    ///
    /// Fails if `func` has no `__datafusion_scalar_udf__`, if calling it fails
    /// or does not return a PyCapsule, or if the capsule fails the
    /// `datafusion_scalar_udf` name check.
    #[staticmethod]
    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if func.hasattr("__datafusion_scalar_udf__")? {
            let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>().map_err(to_datafusion_err)?;
            let data: NonNull<FFI_ScalarUDF> = capsule
                .pointer_checked(Some(c"datafusion_scalar_udf"))?
                .cast();
            // SAFETY: `pointer_checked` returned a non-null pointer only after
            // the capsule's name matched "datafusion_scalar_udf", the name the
            // producer is expected to use for an `FFI_ScalarUDF` payload.
            // `capsule` stays alive for the rest of this scope, and the borrow
            // is only used for the conversion on the next line.
            let udf = unsafe { data.as_ref() };
            let udf: Arc<dyn ScalarUDFImpl> = udf.into();

            Ok(Self {
                function: ScalarUDF::new_from_shared_impl(udf),
            })
        } else {
            Err(crate::errors::PyDataFusionError::Common(
                "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(),
            ))
        }
    }

    /// Builds an expression that applies this UDF to `args`.
    #[pyo3(signature = (*args))]
    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
        // `args` is owned, so move each inner `Expr` out rather than cloning it.
        let args = args.into_iter().map(|e| e.expr).collect();
        Ok(self.function.call(args).into())
    }

    /// Returns `ScalarUDF(<name>)`.
    fn __repr__(&self) -> PyResult<String> {
        Ok(format!("ScalarUDF({})", self.function.name()))
    }
}