ndatafusion 0.1.1

Extensions and support for linear algebra in DataFusion
Documentation
use datafusion::arrow::datatypes::DataType;
use datafusion::common::Result;
use datafusion::logical_expr::{Signature, Volatility};

use crate::error::plan_error;

/// Scalar family that a function argument is expected to belong to.
///
/// Each variant selects which input types are accepted and which canonical
/// type they are coerced to by the coercion helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ScalarCoercion {
    /// Accepts `Boolean` or `Null`; coerces to `Boolean`.
    Boolean,
    /// Accepts signed integers or `Null` (coerced to `Int64`) and unsigned
    /// integers (coerced to `UInt64`).
    Integer,
    /// Accepts `Float32`/`Float64`, any signed or unsigned integer, or `Null`;
    /// coerces to `Float64`.
    Real,
}

/// Builds a fixed-arity signature via `Signature::any` for `arg_count`
/// arguments, marked as `Volatility::Immutable`.
pub(crate) fn any_signature(arg_count: usize) -> Signature {
    Signature::any(arg_count, Volatility::Immutable)
}

/// Builds the same signature as `Signature::any(arg_count, Immutable)` and
/// attaches `parameter_names` to it.
///
/// # Panics
///
/// Panics if `Signature::with_parameter_names` returns an error; per the
/// `expect` message, the names are required to match the fixed arity.
pub(crate) fn named_any_signature(arg_count: usize, parameter_names: &[&str]) -> Signature {
    Signature::any(arg_count, Volatility::Immutable)
        .with_parameter_names(parameter_names.to_vec())
        .expect("parameter names must match the fixed-arity signature")
}

/// Builds a signature via `Signature::user_defined`, marked as
/// `Volatility::Immutable`.
pub(crate) fn user_defined_signature() -> Signature {
    Signature::user_defined(Volatility::Immutable)
}

/// Builds a `Signature::user_defined` signature (immutable volatility) and
/// attaches `parameter_names` to it.
///
/// # Panics
///
/// Panics if `Signature::with_parameter_names` returns an error for the
/// supplied names.
pub(crate) fn named_user_defined_signature(parameter_names: &[&str]) -> Signature {
    Signature::user_defined(Volatility::Immutable)
        .with_parameter_names(parameter_names.to_vec())
        .expect("user-defined signatures accept named parameters")
}

/// Resolves the canonical type for a single scalar argument.
///
/// * `function_name` - name of the function being planned; forwarded to
///   `plan_error` when the argument is rejected.
/// * `position` - 1-based argument position, used only in the error message.
/// * `data_type` - the type the caller supplied for this argument.
/// * `coercion` - the scalar family the argument must belong to.
///
/// Returns `Boolean`, `Int64`, `UInt64`, or `Float64` depending on the family
/// and input type. `Null` is accepted by every family.
///
/// # Errors
///
/// Returns a plan error naming the position and the offending type when
/// `data_type` is not accepted by `coercion`.
fn coerce_scalar_type(
    function_name: &str,
    position: usize,
    data_type: &DataType,
    coercion: ScalarCoercion,
) -> Result<DataType> {
    let target = match (coercion, data_type) {
        (ScalarCoercion::Boolean, DataType::Boolean | DataType::Null) => DataType::Boolean,

        // Signed inputs (and untyped NULL) widen to the largest signed integer.
        (
            ScalarCoercion::Integer,
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::Null,
        ) => DataType::Int64,
        // Unsigned inputs keep their signedness and widen to UInt64.
        (
            ScalarCoercion::Integer,
            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64,
        ) => DataType::UInt64,

        // Every numeric input (and untyped NULL) is promoted to Float64.
        (
            ScalarCoercion::Real,
            DataType::Float32
            | DataType::Float64
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Null,
        ) => DataType::Float64,

        // Anything else is rejected with a family-specific description.
        (family, actual) => {
            let expected = match family {
                ScalarCoercion::Boolean => "a Boolean scalar",
                ScalarCoercion::Integer => "an integer scalar",
                ScalarCoercion::Real => "a numeric scalar",
            };
            return Err(plan_error(
                function_name,
                format!("argument {position} must be {expected}, found {actual}"),
            ));
        }
    };
    Ok(target)
}

/// Coerces selected arguments to their canonical scalar types.
///
/// * `function_name` - name of the function being planned, used in errors.
/// * `arg_types` - the argument types as supplied by the caller.
/// * `scalar_positions` - pairs of (1-based argument position, scalar family)
///   identifying which arguments to coerce.
///
/// Returns a copy of `arg_types` in which every listed position has been
/// replaced by its coerced type; unlisted positions are passed through.
///
/// # Errors
///
/// Returns a plan error when a listed position lies beyond `arg_types`, or
/// when the argument at that position is not accepted by its scalar family.
///
/// # Panics
///
/// Panics if a listed position is `0`; positions are 1-based.
pub(crate) fn coerce_scalar_arguments(
    function_name: &str,
    arg_types: &[DataType],
    scalar_positions: &[(usize, ScalarCoercion)],
) -> Result<Vec<DataType>> {
    let mut result = arg_types.to_vec();
    for &(position, coercion) in scalar_positions {
        let slot = position
            .checked_sub(1)
            .expect("argument positions are 1-based and must be positive");
        // Always read from the caller's original types, not the partially
        // coerced copy.
        match arg_types.get(slot) {
            Some(supplied) => {
                result[slot] = coerce_scalar_type(function_name, position, supplied, coercion)?;
            }
            None => {
                return Err(plan_error(function_name, format!("argument {position} is missing")));
            }
        }
    }
    Ok(result)
}

/// Coerces every argument from `start_position` onward to one scalar family.
///
/// * `function_name` - name of the function being planned, used in errors.
/// * `arg_types` - the argument types as supplied by the caller.
/// * `start_position` - 1-based position of the first argument to coerce.
/// * `coercion` - the scalar family applied to each trailing argument.
///
/// Returns a copy of `arg_types` in which the arguments at `start_position`
/// and beyond have been replaced by their coerced types. If `start_position`
/// lies past the end of `arg_types`, the types are returned unchanged.
///
/// # Errors
///
/// Returns a plan error when any trailing argument is not accepted by
/// `coercion`.
///
/// # Panics
///
/// Panics if `start_position` is `0`; positions are 1-based.
pub(crate) fn coerce_trailing_scalar_arguments(
    function_name: &str,
    arg_types: &[DataType],
    start_position: usize,
    coercion: ScalarCoercion,
) -> Result<Vec<DataType>> {
    // A plain `start_position - 1` would wrap to `usize::MAX` in release
    // builds and silently skip every argument; fail loudly instead, with the
    // same invariant message `coerce_scalar_arguments` uses.
    let first_index = start_position
        .checked_sub(1)
        .expect("argument positions are 1-based and must be positive");

    let mut coerced = arg_types.to_vec();
    for (index, data_type) in arg_types.iter().enumerate().skip(first_index) {
        // Error messages report 1-based positions.
        let position = index + 1;
        coerced[index] = coerce_scalar_type(function_name, position, data_type, coercion)?;
    }
    Ok(coerced)
}

// Unit tests for the signature builders and scalar coercion helpers.
#[cfg(test)]
mod tests {
    use datafusion::arrow::datatypes::DataType;

    use super::{
        ScalarCoercion, coerce_scalar_arguments, coerce_trailing_scalar_arguments,
        named_user_defined_signature,
    };

    // The names passed to the builder must be retrievable from the signature.
    #[test]
    fn named_user_defined_signature_preserves_parameter_names() {
        let signature = named_user_defined_signature(&["matrix", "k"]);

        assert_eq!(
            signature.parameter_names.as_deref(),
            Some(["matrix".to_string(), "k".to_string()].as_slice())
        );
    }

    // Listed positions are coerced (UInt16 -> UInt64, Int32 -> Float64) while
    // the unlisted first argument keeps its original type.
    #[test]
    fn scalar_argument_coercion_normalizes_numeric_controls() {
        let arg_types = vec![DataType::Float32, DataType::UInt16, DataType::Int32];
        let coerced = coerce_scalar_arguments("matrix_exp", &arg_types, &[
            (2, ScalarCoercion::Integer),
            (3, ScalarCoercion::Real),
        ])
        .expect("valid scalar controls should coerce");

        assert_eq!(coerced, vec![DataType::Float32, DataType::UInt64, DataType::Float64]);
    }

    // Everything from position 2 onward is coerced; signedness is preserved
    // (Int16 -> Int64, UInt32 -> UInt64).
    #[test]
    fn trailing_scalar_coercion_normalizes_variadic_integer_axes() {
        let arg_types = vec![DataType::Float64, DataType::Int16, DataType::UInt32];
        let coerced = coerce_trailing_scalar_arguments(
            "tensor_permute_axes",
            &arg_types,
            2,
            ScalarCoercion::Integer,
        )
        .expect("variadic integer axes should coerce");

        assert_eq!(coerced, vec![DataType::Float64, DataType::Int64, DataType::UInt64]);
    }

    // A string argument is not a valid Real scalar; the error must name both
    // the function and the offending position.
    #[test]
    fn scalar_argument_coercion_rejects_invalid_types() {
        let error = coerce_scalar_arguments(
            "matrix_power",
            &[DataType::Float64, DataType::Utf8],
            &[(2, ScalarCoercion::Real)],
        )
        .expect_err("non-numeric scalar controls must be rejected");

        let message = error.to_string();
        assert!(message.contains("matrix_power"));
        assert!(message.contains("argument 2 must be a numeric scalar"));
    }

    // An untyped NULL in a Boolean position is coerced to Boolean rather than
    // rejected.
    #[test]
    fn boolean_scalar_coercion_accepts_boolean_and_null() {
        let coerced = coerce_scalar_arguments(
            "linear_regression",
            &[DataType::Float32, DataType::Float32, DataType::Null],
            &[(3, ScalarCoercion::Boolean)],
        )
        .expect("null booleans should be castable for later validation");

        assert_eq!(coerced, vec![DataType::Float32, DataType::Float32, DataType::Boolean]);
    }
}