datafusion-python 53.0.0

Apache DataFusion DataFrame and SQL Query Engine
Documentation
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use std::ptr::NonNull;
use std::sync::Arc;

use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::common::ScalarValue;
use datafusion::error::{DataFusionError, Result};
use datafusion::logical_expr::{
    Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf,
};
use datafusion_ffi::udaf::FFI_AggregateUDF;
use datafusion_python_util::parse_volatility;
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyTuple};

use crate::common::data_type::PyScalarValue;
use crate::errors::{PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::expr::PyExpr;

/// Adapter that lets a Python object act as a DataFusion [`Accumulator`].
///
/// Every trait method is forwarded to a same-purpose method on the wrapped
/// Python object (`state`, `evaluate`, `update`, `merge`, `retract_batch`,
/// `supports_retract_batch`) while attached to the Python interpreter.
#[derive(Debug)]
struct RustAccumulator {
    /// The Python accumulator instance that receives the forwarded calls.
    accum: Py<PyAny>,
}

impl RustAccumulator {
    fn new(accum: Py<PyAny>) -> Self {
        Self { accum }
    }
}

impl Accumulator for RustAccumulator {
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        Python::attach(|py| -> PyResult<Vec<ScalarValue>> {
            let values = self.accum.bind(py).call_method0("state")?;
            let mut scalars = Vec::new();
            for item in values.try_iter()? {
                let item: Bound<'_, PyAny> = item?;
                let scalar = item.extract::<PyScalarValue>()?.0;
                scalars.push(scalar);
            }
            Ok(scalars)
        })
        .map_err(|e| DataFusionError::Execution(format!("{e}")))
    }

    fn evaluate(&mut self) -> Result<ScalarValue> {
        Python::attach(|py| -> PyResult<ScalarValue> {
            let value = self.accum.bind(py).call_method0("evaluate")?;
            value.extract::<PyScalarValue>().map(|v| v.0)
        })
        .map_err(|e| DataFusionError::Execution(format!("{e}")))
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        Python::attach(|py| {
            // 1. cast args to Pyarrow array
            let py_args = values
                .iter()
                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
                .collect::<Vec<_>>();
            let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

            // 2. call function
            self.accum
                .bind(py)
                .call_method1("update", py_args)
                .map_err(|e| DataFusionError::Execution(format!("{e}")))?;

            Ok(())
        })
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        Python::attach(|py| {
            // // 1. cast states to Pyarrow arrays
            let py_states: Result<Vec<Bound<'_, PyAny>>> = states
                .iter()
                .map(|state| {
                    state
                        .to_data()
                        .to_pyarrow(py)
                        .map_err(|e| DataFusionError::Execution(format!("{e}")))
                })
                .collect();

            // 2. call merge
            self.accum
                .bind(py)
                .call_method1("merge", (py_states?,))
                .map_err(|e| DataFusionError::Execution(format!("{e}")))?;

            Ok(())
        })
    }

    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        Python::attach(|py| {
            // 1. cast args to Pyarrow array
            let py_args = values
                .iter()
                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
                .collect::<Vec<_>>();
            let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

            // 2. call function
            self.accum
                .bind(py)
                .call_method1("retract_batch", py_args)
                .map_err(|e| DataFusionError::Execution(format!("{e}")))?;

            Ok(())
        })
    }

    fn supports_retract_batch(&self) -> bool {
        Python::attach(
            |py| match self.accum.bind(py).call_method0("supports_retract_batch") {
                Ok(x) => x.extract().unwrap_or(false),
                Err(_) => false,
            },
        )
    }
}

/// Builds a DataFusion accumulator factory from a Python callable.
///
/// Each invocation of the returned factory calls `accum` with no arguments
/// while attached to the Python interpreter and wraps the object it returns in
/// a [`RustAccumulator`]. The factory's own arguments are ignored.
///
/// # Errors
/// The factory returns `DataFusionError::Execution` carrying the exception
/// text if the Python call raises.
pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
    Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
        let instance = Python::attach(|py| accum.call0(py))
            .map_err(|e| DataFusionError::Execution(format!("{e}")))?;
        let wrapped = RustAccumulator::new(instance);
        Ok(Box::new(wrapped))
    })
}

/// Reconstructs an [`AggregateUDF`] from a PyCapsule named
/// `datafusion_aggregate_udf`.
///
/// # Errors
/// Propagates the error from `pointer_checked` when the capsule does not pass
/// its name/pointer check.
fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
    let raw = capsule.pointer_checked(Some(c"datafusion_aggregate_udf"))?;
    let ffi_ptr: NonNull<FFI_AggregateUDF> = raw.cast();

    // SAFETY: `pointer_checked` matched the capsule name above. This assumes
    // whoever created a capsule under that name stored a valid
    // `FFI_AggregateUDF` that stays alive for the duration of this borrow
    // (NOTE(review): not verifiable from this file).
    let ffi_udaf = unsafe { ffi_ptr.as_ref() };
    let shared_impl: Arc<dyn AggregateUDFImpl> = ffi_udaf.into();

    Ok(AggregateUDF::new_from_shared_impl(shared_impl))
}

/// Represents an AggregateUDF
///
/// Python-facing wrapper around a DataFusion [`AggregateUDF`], exported to
/// Python under the class name `AggregateUDF` in module `datafusion`.
#[pyclass(
    from_py_object,
    frozen,
    name = "AggregateUDF",
    module = "datafusion",
    subclass
)]
#[derive(Debug, Clone)]
pub struct PyAggregateUDF {
    /// The wrapped DataFusion aggregate function.
    pub(crate) function: AggregateUDF,
}

#[pymethods]
impl PyAggregateUDF {
    #[new]
    #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility))]
    fn new(
        name: &str,
        accumulator: Py<PyAny>,
        input_type: PyArrowType<Vec<DataType>>,
        return_type: PyArrowType<DataType>,
        state_type: PyArrowType<Vec<DataType>>,
        volatility: &str,
    ) -> PyResult<Self> {
        let function = create_udaf(
            name,
            input_type.0,
            Arc::new(return_type.0),
            parse_volatility(volatility)?,
            to_rust_accumulator(accumulator),
            Arc::new(state_type.0),
        );
        Ok(Self { function })
    }

    #[staticmethod]
    pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
        if func.is_instance_of::<PyCapsule>() {
            let capsule = func.cast::<PyCapsule>().map_err(py_datafusion_err)?;
            let function = aggregate_udf_from_capsule(capsule)?;
            return Ok(Self { function });
        }

        if func.hasattr("__datafusion_aggregate_udf__")? {
            let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
            let capsule = capsule.cast::<PyCapsule>().map_err(py_datafusion_err)?;
            let function = aggregate_udf_from_capsule(capsule)?;
            return Ok(Self { function });
        }

        Err(crate::errors::PyDataFusionError::Common(
            "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
        ))
    }

    /// creates a new PyExpr with the call of the udf
    #[pyo3(signature = (*args))]
    fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {
        let args = args.iter().map(|e| e.expr.clone()).collect();
        Ok(self.function.call(args).into())
    }

    fn __repr__(&self) -> PyResult<String> {
        Ok(format!("AggregateUDF({})", self.function.name()))
    }
}