ndatafusion 0.1.1

Extensions and support for linear algebra in DataFusion
Documentation
use std::fmt::Display;

use datafusion::arrow::datatypes::DataType;
use datafusion::common::DataFusionError;

pub(crate) fn plan_error(function_name: &str, message: impl Display) -> DataFusionError {
    DataFusionError::Plan(format!("{function_name}: {message}"))
}

pub(crate) fn exec_error(function_name: &str, message: impl Display) -> DataFusionError {
    DataFusionError::Execution(format!("{function_name}: {message}"))
}

pub(crate) fn array_argument_required(function_name: &str, position: usize) -> DataFusionError {
    exec_error(function_name, format!("argument {position} must be an array column"))
}

pub(crate) fn scalar_argument_required(function_name: &str, position: usize) -> DataFusionError {
    plan_error(function_name, format!("argument {position} must be a non-null scalar"))
}

pub(crate) fn type_mismatch(
    function_name: &str,
    position: usize,
    expected: &str,
    actual: &DataType,
) -> DataFusionError {
    plan_error(function_name, format!("argument {position} expected {expected}, found {actual}"))
}

#[cfg(test)]
mod tests {
    use datafusion::arrow::datatypes::DataType;
    use datafusion::common::DataFusionError;

    use super::{
        array_argument_required, exec_error, plan_error, scalar_argument_required, type_mismatch,
    };

    #[test]
    fn plan_and_exec_error_helpers_prefix_function_name() {
        let plan = plan_error("vector_dot", "bad plan");
        let exec = exec_error("vector_dot", "bad exec");

        assert!(
            matches!(plan, DataFusionError::Plan(message) if message == "vector_dot: bad plan")
        );
        assert!(
            matches!(exec, DataFusionError::Execution(message) if message == "vector_dot: bad exec")
        );
    }

    #[test]
    fn argument_error_helpers_produce_stable_messages() {
        let array = array_argument_required("vector_dot", 1);
        let scalar = scalar_argument_required("linear_regression", 3);

        assert!(matches!(
            array,
            DataFusionError::Execution(message)
                if message == "vector_dot: argument 1 must be an array column"
        ));
        assert!(matches!(
            scalar,
            DataFusionError::Plan(message)
                if message == "linear_regression: argument 3 must be a non-null scalar"
        ));
    }

    #[test]
    fn type_mismatch_reports_expected_and_actual_types() {
        let error = type_mismatch(
            "tensor_sum_last_axis",
            1,
            "arrow.fixed_shape_tensor<Float64>",
            &DataType::Float64,
        );

        assert!(matches!(
            error,
            DataFusionError::Plan(message)
                if message == "tensor_sum_last_axis: argument 1 expected arrow.fixed_shape_tensor<Float64>, found Float64"
        ));
    }
}