use std::any::Any;
use std::sync::{Arc, LazyLock};
use datafusion::arrow::array::types::{ArrowPrimitiveType, Float32Type, Float64Type};
use datafusion::arrow::datatypes::{DataType, FieldRef};
use datafusion::common::Result;
use datafusion::logical_expr::{
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
Signature,
};
use nabled::core::prelude::NabledReal;
use ndarray::{Array1, ArrayView1, ArrayView2, Axis};
use ndarrow::NdarrowElement;
use super::common::{
complex_fixed_shape_tensor_view3, complex_fixed_size_list_array_from_flat_rows,
expect_fixed_size_list_arg, expect_real_scalar_arg, expect_real_scalar_argument,
expect_usize_scalar_arg, expect_usize_scalar_argument, fixed_shape_tensor_view3,
fixed_size_list_array_from_flat_rows, fixed_size_list_view2, nullable_or,
};
use super::docs::iterative_doc;
use crate::error::exec_error;
use crate::metadata::{
complex_vector_field, parse_complex_matrix_batch_field, parse_complex_vector_field,
parse_matrix_batch_field, parse_vector_field, vector_field,
};
use crate::signatures::{ScalarCoercion, coerce_scalar_arguments, named_user_defined_signature};
/// Checks the plan-time contract for a batched real square linear system.
///
/// Arg 1 must be a matrix batch and arg 2 a vector batch with the same value
/// type; every matrix must be square and the rhs length must equal the matrix
/// order. On success returns the shared value type and the system size.
fn return_square_system(
    args: &ReturnFieldArgs<'_>,
    function_name: &str,
) -> Result<(DataType, usize)> {
    let matrix = parse_matrix_batch_field(&args.arg_fields[0], function_name, 1)?;
    let vector = parse_vector_field(&args.arg_fields[1], function_name, 2)?;
    // Contracts are checked in a fixed order (value type, squareness, rhs
    // length) so the first violated one is the one reported.
    if matrix.value_type != vector.value_type {
        let detail =
            format!("value type mismatch: matrix {}, rhs {}", matrix.value_type, vector.value_type);
        Err(exec_error(function_name, detail))
    } else if matrix.rows != matrix.cols {
        let detail = format!(
            "{function_name} requires square matrices, found ({}, {})",
            matrix.rows, matrix.cols
        );
        Err(exec_error(function_name, detail))
    } else if vector.len != matrix.cols {
        let detail =
            format!("rhs vector length mismatch: expected {}, found {}", matrix.cols, vector.len);
        Err(exec_error(function_name, detail))
    } else {
        Ok((matrix.value_type, matrix.rows))
    }
}
/// Complex-system counterpart of [`return_square_system`].
///
/// The element type is implied by the complex encoding, so only the shape
/// contracts are checked; returns the system size on success.
fn return_square_complex_system(args: &ReturnFieldArgs<'_>, function_name: &str) -> Result<usize> {
    let matrix = parse_complex_matrix_batch_field(&args.arg_fields[0], function_name, 1)?;
    let (_vector_field, vector) =
        parse_complex_vector_field(&args.arg_fields[1], function_name, 2)?;
    // Squareness is reported before any rhs-length mismatch.
    if matrix.rows != matrix.cols {
        let detail = format!(
            "{function_name} requires square matrices, found ({}, {})",
            matrix.rows, matrix.cols
        );
        return Err(exec_error(function_name, detail));
    }
    if vector.len == matrix.cols {
        Ok(matrix.rows)
    } else {
        let detail =
            format!("rhs vector length mismatch: expected {}, found {}", matrix.cols, vector.len);
        Err(exec_error(function_name, detail))
    }
}
/// Ensures the convergence tolerance is finite and strictly positive.
fn validate_tolerance(function_name: &str, tolerance: f64) -> Result<f64> {
    // Non-finite values (NaN, ±inf) are reported before non-positive ones so
    // they receive the clearer "must be finite" message.
    match tolerance {
        t if !t.is_finite() => Err(exec_error(function_name, "tolerance must be finite")),
        t if t <= 0.0 => Err(exec_error(function_name, "tolerance must be positive")),
        t => Ok(t),
    }
}
/// Ensures the iteration cap is strictly positive.
fn validate_max_iterations(function_name: &str, max_iterations: usize) -> Result<usize> {
    if max_iterations > 0 {
        Ok(max_iterations)
    } else {
        Err(exec_error(function_name, "max_iterations must be greater than 0"))
    }
}
/// Builds an `f32` iterative-solver config from the user-supplied scalars.
///
/// Validates the tolerance and iteration cap, then narrows the tolerance to
/// `f32` and re-validates the narrowed value.
///
/// # Errors
/// Fails when the tolerance is non-finite or non-positive, when the iteration
/// cap is zero, or when narrowing to `f32` loses the positivity/finiteness
/// invariant (e.g. a tiny positive f64 underflows to `0.0` in f32).
fn iterative_config_f32(
    function_name: &str,
    tolerance: f64,
    max_iterations: usize,
) -> Result<nabled::ml::iterative::IterativeConfig<f32>> {
    let tolerance = validate_tolerance(function_name, tolerance)?;
    let max_iterations = validate_max_iterations(function_name, max_iterations)?;
    // Narrow with a plain cast: `as` rounds to the nearest representable f32.
    // The previous string round-trip (`to_string().parse::<f32>()`) could not
    // actually fail — Rust's float parser saturates instead of erroring — and
    // silently accepted a subnormal-underflow result of 0.0, defeating the
    // positivity check above. Re-validate the narrowed value explicitly.
    let tolerance_f32 = tolerance as f32;
    if !tolerance_f32.is_finite() || tolerance_f32 <= 0.0 {
        return Err(exec_error(
            function_name,
            format!("tolerance could not be represented in matrix value type: {tolerance}"),
        ));
    }
    Ok(nabled::ml::iterative::IterativeConfig { tolerance: tolerance_f32, max_iterations })
}
/// Builds an `f64` iterative-solver config from the user-supplied scalars.
///
/// No narrowing is needed for `f64`: both knobs are validated and the config
/// is assembled directly (tolerance is checked first, matching the f32 path).
fn iterative_config_f64(
    function_name: &str,
    tolerance: f64,
    max_iterations: usize,
) -> Result<nabled::ml::iterative::IterativeConfig<f64>> {
    Ok(nabled::ml::iterative::IterativeConfig {
        tolerance: validate_tolerance(function_name, tolerance)?,
        max_iterations: validate_max_iterations(function_name, max_iterations)?,
    })
}
/// Runs a real-valued iterative solver over a batch of square systems.
///
/// Arg 1 of `args` is the matrix batch and arg 2 the rhs vector batch, both as
/// fixed-size lists. `op` solves one `(matrix, rhs, config)` system; each
/// per-row solution is appended to a flat buffer and re-wrapped as a
/// fixed-size list array with one solution vector per input row.
///
/// # Errors
/// Returns an execution error when batch lengths disagree, matrices are not
/// square, the rhs length does not match the matrix order, or `op` fails.
fn invoke_iterative_solver<T, E>(
    args: &ScalarFunctionArgs,
    function_name: &str,
    config: &nabled::ml::iterative::IterativeConfig<T::Native>,
    op: impl Fn(
        &ArrayView2<'_, T::Native>,
        &ArrayView1<'_, T::Native>,
        &nabled::ml::iterative::IterativeConfig<T::Native>,
    ) -> std::result::Result<Array1<T::Native>, E>,
) -> Result<ColumnarValue>
where
    T: ArrowPrimitiveType,
    T::Native:
        NdarrowElement + NabledReal + std::ops::SubAssign + nabled::linalg::lu::LuProviderScalar,
    E: std::fmt::Display,
{
    let matrices = expect_fixed_size_list_arg(args, 1, function_name)?;
    let rhs = expect_fixed_size_list_arg(args, 2, function_name)?;
    // Axis 0 indexes the batch; axes 1/2 are the matrix rows/cols.
    let matrix_view = fixed_shape_tensor_view3::<T>(&args.arg_fields[0], matrices, function_name)?;
    let rhs_view = fixed_size_list_view2::<T>(rhs, function_name)?;
    // Runtime shape checks: the plan-time contract may not cover this path, so
    // batch length, squareness, and rhs length are all re-verified here.
    if matrix_view.len_of(Axis(0)) != rhs_view.nrows() {
        return Err(exec_error(
            function_name,
            format!(
                "batch length mismatch: {} matrices vs {} rhs vectors",
                matrix_view.len_of(Axis(0)),
                rhs_view.nrows()
            ),
        ));
    }
    if matrix_view.len_of(Axis(1)) != matrix_view.len_of(Axis(2)) {
        return Err(exec_error(
            function_name,
            format!(
                "{function_name} requires square matrices, found ({}, {})",
                matrix_view.len_of(Axis(1)),
                matrix_view.len_of(Axis(2))
            ),
        ));
    }
    if rhs_view.ncols() != matrix_view.len_of(Axis(2)) {
        return Err(exec_error(
            function_name,
            format!(
                "rhs vector length mismatch: expected {}, found {}",
                matrix_view.len_of(Axis(2)),
                rhs_view.ncols()
            ),
        ));
    }
    // One solve per batch row; solver failure aborts the whole invocation.
    let mut output = Vec::with_capacity(rhs_view.len());
    for row in 0..matrix_view.len_of(Axis(0)) {
        let solution =
            op(&matrix_view.index_axis(Axis(0), row), &rhs_view.index_axis(Axis(0), row), config)
                .map_err(|error| exec_error(function_name, error))?;
        output.extend(solution.iter().copied());
    }
    let output = fixed_size_list_array_from_flat_rows::<T>(
        function_name,
        rhs_view.nrows(),
        rhs_view.ncols(),
        &output,
    )?;
    Ok(ColumnarValue::Array(Arc::new(output)))
}
/// Complex (`Complex64`) counterpart of [`invoke_iterative_solver`].
///
/// Same batch/shape validation and per-row solve loop, but views the inputs
/// through the complex ndarrow adapters and the config is always `f64`.
///
/// # Errors
/// Returns an execution error when batch lengths disagree, matrices are not
/// square, the rhs length does not match the matrix order, or `op` fails.
fn invoke_complex_iterative_solver<E>(
    args: &ScalarFunctionArgs,
    function_name: &str,
    config: &nabled::ml::iterative::IterativeConfig<f64>,
    op: impl Fn(
        &ArrayView2<'_, num_complex::Complex64>,
        &ArrayView1<'_, num_complex::Complex64>,
        &nabled::ml::iterative::IterativeConfig<f64>,
    ) -> std::result::Result<Array1<num_complex::Complex64>, E>,
) -> Result<ColumnarValue>
where
    E: std::fmt::Display,
{
    let matrices = expect_fixed_size_list_arg(args, 1, function_name)?;
    let rhs = expect_fixed_size_list_arg(args, 2, function_name)?;
    // Axis 0 indexes the batch; axes 1/2 are the matrix rows/cols.
    let matrix_view =
        complex_fixed_shape_tensor_view3(&args.arg_fields[0], matrices, function_name)?;
    let rhs_view =
        ndarrow::complex64_as_array_view2(rhs).map_err(|error| exec_error(function_name, error))?;
    // Runtime shape checks, mirroring the real-valued solver path.
    if matrix_view.len_of(Axis(0)) != rhs_view.nrows() {
        return Err(exec_error(
            function_name,
            format!(
                "batch length mismatch: {} matrices vs {} rhs vectors",
                matrix_view.len_of(Axis(0)),
                rhs_view.nrows()
            ),
        ));
    }
    if matrix_view.len_of(Axis(1)) != matrix_view.len_of(Axis(2)) {
        return Err(exec_error(
            function_name,
            format!(
                "{function_name} requires square matrices, found ({}, {})",
                matrix_view.len_of(Axis(1)),
                matrix_view.len_of(Axis(2))
            ),
        ));
    }
    if rhs_view.ncols() != matrix_view.len_of(Axis(2)) {
        return Err(exec_error(
            function_name,
            format!(
                "rhs vector length mismatch: expected {}, found {}",
                matrix_view.len_of(Axis(2)),
                rhs_view.ncols()
            ),
        ));
    }
    // One solve per batch row; solver failure aborts the whole invocation.
    let mut output = Vec::with_capacity(rhs_view.len());
    for row in 0..matrix_view.len_of(Axis(0)) {
        let solution =
            op(&matrix_view.index_axis(Axis(0), row), &rhs_view.index_axis(Axis(0), row), config)
                .map_err(|error| exec_error(function_name, error))?;
        output.extend(solution.iter().copied());
    }
    let output = complex_fixed_size_list_array_from_flat_rows(
        function_name,
        rhs_view.nrows(),
        rhs_view.ncols(),
        output,
    )?;
    Ok(ColumnarValue::Array(Arc::new(output)))
}
/// Batched real conjugate-gradient solver UDF (`matrix_conjugate_gradient`).
#[derive(Debug, PartialEq, Eq, Hash)]
struct MatrixConjugateGradient {
    signature: Signature,
}
impl MatrixConjugateGradient {
    /// Builds the UDF with its named four-argument signature.
    fn new() -> Self {
        let argument_names = ["matrix", "rhs", "tolerance", "max_iterations"];
        Self { signature: named_user_defined_signature(&argument_names) }
    }
}
impl ScalarUDFImpl for MatrixConjugateGradient {
    fn as_any(&self) -> &dyn Any { self }
    fn name(&self) -> &'static str { "matrix_conjugate_gradient" }
    fn signature(&self) -> &Signature { &self.signature }
    // Arg 3 (tolerance) is coerced to a real scalar and arg 4 (max_iterations)
    // to an integer before planning continues.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        coerce_scalar_arguments(self.name(), arg_types, &[
            (3, ScalarCoercion::Real),
            (4, ScalarCoercion::Integer),
        ])
    }
    // The return type depends on field metadata, so only
    // return_field_from_args is supported.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        datafusion::common::internal_err!("return_field_from_args should be used instead")
    }
    // Plan-time validation: square-system contract plus tolerance/iteration
    // bounds, then an output vector field matching the rhs shape.
    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
        let (value_type, len) = return_square_system(&args, self.name())?;
        let tolerance = expect_real_scalar_argument(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_argument(&args, 4, self.name())?;
        // Validation only: the values are re-read from ScalarFunctionArgs at
        // invoke time.
        let _ = validate_tolerance(self.name(), tolerance)?;
        let _ = validate_max_iterations(self.name(), max_iterations)?;
        vector_field(self.name(), &value_type, len, nullable_or(args.arg_fields))
    }
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // NOTE(review): the checks below duplicate return_square_system against
        // the invoke-time fields — presumably defensive re-validation; confirm.
        let matrix = parse_matrix_batch_field(&args.arg_fields[0], self.name(), 1)?;
        let vector = parse_vector_field(&args.arg_fields[1], self.name(), 2)?;
        if matrix.value_type != vector.value_type {
            return Err(exec_error(
                self.name(),
                format!(
                    "value type mismatch: matrix {}, rhs {}",
                    matrix.value_type, vector.value_type
                ),
            ));
        }
        if matrix.rows != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "{} requires square matrices, found ({}, {})",
                    self.name(),
                    matrix.rows,
                    matrix.cols
                ),
            ));
        }
        if vector.len != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "rhs vector length mismatch: expected {}, found {}",
                    matrix.cols, vector.len
                ),
            ));
        }
        let tolerance = expect_real_scalar_arg(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_arg(&args, 4, self.name())?;
        // Dispatch on the element type; the solver driver is generic.
        match matrix.value_type {
            DataType::Float32 => {
                let config = iterative_config_f32(self.name(), tolerance, max_iterations)?;
                invoke_iterative_solver::<Float32Type, _>(
                    &args,
                    self.name(),
                    &config,
                    nabled::ml::iterative::conjugate_gradient_view,
                )
            }
            DataType::Float64 => {
                let config = iterative_config_f64(self.name(), tolerance, max_iterations)?;
                invoke_iterative_solver::<Float64Type, _>(
                    &args,
                    self.name(),
                    &config,
                    nabled::ml::iterative::conjugate_gradient_view,
                )
            }
            actual => {
                Err(exec_error(self.name(), format!("unsupported matrix value type {actual}")))
            }
        }
    }
    fn documentation(&self) -> Option<&Documentation> {
        // Built once and shared; LazyLock defers construction to first use.
        static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
            iterative_doc(
                "Solve each square linear system in the batch with conjugate gradient.",
                "matrix_conjugate_gradient(matrix_batch, rhs_batch, tolerance => 1e-6, \
                 max_iterations => 64)",
            )
            .with_argument(
                "matrix",
                "Square dense matrix batch in canonical fixed-shape tensor form.",
            )
            .with_argument(
                "rhs",
                "Dense vector batch containing one right-hand side per matrix row.",
            )
            .with_argument("tolerance", "Positive finite convergence tolerance.")
            .with_argument("max_iterations", "Positive integer iteration cap.")
            .with_alternative_syntax(
                "matrix_conjugate_gradient(matrix => matrix_batch, rhs => rhs_batch, tolerance => \
                 1e-6, max_iterations => 64)",
            )
            .build()
        });
        Some(&DOCUMENTATION)
    }
}
/// Batched real GMRES solver UDF (`matrix_gmres`).
#[derive(Debug, PartialEq, Eq, Hash)]
struct MatrixGmres {
    signature: Signature,
}
impl MatrixGmres {
    /// Builds the UDF with its named four-argument signature.
    fn new() -> Self {
        let argument_names = ["matrix", "rhs", "tolerance", "max_iterations"];
        Self { signature: named_user_defined_signature(&argument_names) }
    }
}
impl ScalarUDFImpl for MatrixGmres {
    fn as_any(&self) -> &dyn Any { self }
    fn name(&self) -> &'static str { "matrix_gmres" }
    fn signature(&self) -> &Signature { &self.signature }
    // Arg 3 (tolerance) is coerced to a real scalar and arg 4 (max_iterations)
    // to an integer before planning continues.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        coerce_scalar_arguments(self.name(), arg_types, &[
            (3, ScalarCoercion::Real),
            (4, ScalarCoercion::Integer),
        ])
    }
    // The return type depends on field metadata, so only
    // return_field_from_args is supported.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        datafusion::common::internal_err!("return_field_from_args should be used instead")
    }
    // Plan-time validation: square-system contract plus tolerance/iteration
    // bounds, then an output vector field matching the rhs shape.
    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
        let (value_type, len) = return_square_system(&args, self.name())?;
        let tolerance = expect_real_scalar_argument(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_argument(&args, 4, self.name())?;
        // Validation only: the values are re-read from ScalarFunctionArgs at
        // invoke time.
        let _ = validate_tolerance(self.name(), tolerance)?;
        let _ = validate_max_iterations(self.name(), max_iterations)?;
        vector_field(self.name(), &value_type, len, nullable_or(args.arg_fields))
    }
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // NOTE(review): the checks below duplicate return_square_system against
        // the invoke-time fields — presumably defensive re-validation; confirm.
        let matrix = parse_matrix_batch_field(&args.arg_fields[0], self.name(), 1)?;
        let vector = parse_vector_field(&args.arg_fields[1], self.name(), 2)?;
        if matrix.value_type != vector.value_type {
            return Err(exec_error(
                self.name(),
                format!(
                    "value type mismatch: matrix {}, rhs {}",
                    matrix.value_type, vector.value_type
                ),
            ));
        }
        if matrix.rows != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "{} requires square matrices, found ({}, {})",
                    self.name(),
                    matrix.rows,
                    matrix.cols
                ),
            ));
        }
        if vector.len != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "rhs vector length mismatch: expected {}, found {}",
                    matrix.cols, vector.len
                ),
            ));
        }
        let tolerance = expect_real_scalar_arg(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_arg(&args, 4, self.name())?;
        // Dispatch on the element type; the solver driver is generic.
        match matrix.value_type {
            DataType::Float32 => {
                let config = iterative_config_f32(self.name(), tolerance, max_iterations)?;
                invoke_iterative_solver::<Float32Type, _>(
                    &args,
                    self.name(),
                    &config,
                    nabled::ml::iterative::gmres_view,
                )
            }
            DataType::Float64 => {
                let config = iterative_config_f64(self.name(), tolerance, max_iterations)?;
                invoke_iterative_solver::<Float64Type, _>(
                    &args,
                    self.name(),
                    &config,
                    nabled::ml::iterative::gmres_view,
                )
            }
            actual => {
                Err(exec_error(self.name(), format!("unsupported matrix value type {actual}")))
            }
        }
    }
    fn documentation(&self) -> Option<&Documentation> {
        // Built once and shared; LazyLock defers construction to first use.
        static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
            iterative_doc(
                "Solve each square linear system in the batch with GMRES.",
                "matrix_gmres(matrix_batch, rhs_batch, tolerance => 1e-6, max_iterations => 64)",
            )
            .with_argument(
                "matrix",
                "Square dense matrix batch in canonical fixed-shape tensor form.",
            )
            .with_argument(
                "rhs",
                "Dense vector batch containing one right-hand side per matrix row.",
            )
            .with_argument("tolerance", "Positive finite convergence tolerance.")
            .with_argument("max_iterations", "Positive integer iteration cap.")
            .with_alternative_syntax(
                "matrix_gmres(matrix => matrix_batch, rhs => rhs_batch, tolerance => 1e-6, \
                 max_iterations => 64)",
            )
            .build()
        });
        Some(&DOCUMENTATION)
    }
}
/// Batched complex conjugate-gradient solver UDF
/// (`matrix_conjugate_gradient_complex`).
#[derive(Debug, PartialEq, Eq, Hash)]
struct MatrixConjugateGradientComplex {
    signature: Signature,
}
impl MatrixConjugateGradientComplex {
    /// Builds the UDF with its named four-argument signature.
    fn new() -> Self {
        let argument_names = ["matrix", "rhs", "tolerance", "max_iterations"];
        Self { signature: named_user_defined_signature(&argument_names) }
    }
}
impl ScalarUDFImpl for MatrixConjugateGradientComplex {
    fn as_any(&self) -> &dyn Any { self }
    fn name(&self) -> &'static str { "matrix_conjugate_gradient_complex" }
    fn signature(&self) -> &Signature { &self.signature }
    // Arg 3 (tolerance) is coerced to a real scalar and arg 4 (max_iterations)
    // to an integer before planning continues.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        coerce_scalar_arguments(self.name(), arg_types, &[
            (3, ScalarCoercion::Real),
            (4, ScalarCoercion::Integer),
        ])
    }
    // The return type depends on field metadata, so only
    // return_field_from_args is supported.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        datafusion::common::internal_err!("return_field_from_args should be used instead")
    }
    // Plan-time validation: complex square-system contract plus
    // tolerance/iteration bounds, then a complex output vector field.
    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
        let len = return_square_complex_system(&args, self.name())?;
        let tolerance = expect_real_scalar_argument(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_argument(&args, 4, self.name())?;
        // Validation only: the values are re-read from ScalarFunctionArgs at
        // invoke time.
        let _ = validate_tolerance(self.name(), tolerance)?;
        let _ = validate_max_iterations(self.name(), max_iterations)?;
        complex_vector_field(self.name(), len, nullable_or(args.arg_fields))
    }
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // NOTE(review): duplicates return_square_complex_system against the
        // invoke-time fields — presumably defensive re-validation; confirm.
        let matrix = parse_complex_matrix_batch_field(&args.arg_fields[0], self.name(), 1)?;
        let (_rhs_field, vector) = parse_complex_vector_field(&args.arg_fields[1], self.name(), 2)?;
        if matrix.rows != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "{} requires square matrices, found ({}, {})",
                    self.name(),
                    matrix.rows,
                    matrix.cols
                ),
            ));
        }
        if vector.len != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "rhs vector length mismatch: expected {}, found {}",
                    matrix.cols, vector.len
                ),
            ));
        }
        let tolerance = expect_real_scalar_arg(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_arg(&args, 4, self.name())?;
        // Complex solvers always use the f64 config.
        let config = iterative_config_f64(self.name(), tolerance, max_iterations)?;
        invoke_complex_iterative_solver(
            &args,
            self.name(),
            &config,
            nabled::ml::iterative::conjugate_gradient_complex_view,
        )
    }
    fn documentation(&self) -> Option<&Documentation> {
        // Built once and shared; LazyLock defers construction to first use.
        static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
            iterative_doc(
                "Solve each square complex linear system in the batch with conjugate gradient.",
                "matrix_conjugate_gradient_complex(matrix_batch, rhs_batch, tolerance => 1e-6, \
                 max_iterations => 64)",
            )
            .with_argument(
                "matrix",
                "Square dense complex matrix batch in canonical fixed-shape tensor form.",
            )
            .with_argument(
                "rhs",
                "Dense complex vector batch containing one right-hand side per matrix row.",
            )
            .with_argument("tolerance", "Positive finite convergence tolerance.")
            .with_argument("max_iterations", "Positive integer iteration cap.")
            .with_alternative_syntax(
                "matrix_conjugate_gradient_complex(matrix => matrix_batch, rhs => rhs_batch, \
                 tolerance => 1e-6, max_iterations => 64)",
            )
            .build()
        });
        Some(&DOCUMENTATION)
    }
}
/// Batched complex GMRES solver UDF (`matrix_gmres_complex`).
#[derive(Debug, PartialEq, Eq, Hash)]
struct MatrixGmresComplex {
    signature: Signature,
}
impl MatrixGmresComplex {
    /// Builds the UDF with its named four-argument signature.
    fn new() -> Self {
        let argument_names = ["matrix", "rhs", "tolerance", "max_iterations"];
        Self { signature: named_user_defined_signature(&argument_names) }
    }
}
impl ScalarUDFImpl for MatrixGmresComplex {
    fn as_any(&self) -> &dyn Any { self }
    fn name(&self) -> &'static str { "matrix_gmres_complex" }
    fn signature(&self) -> &Signature { &self.signature }
    // Arg 3 (tolerance) is coerced to a real scalar and arg 4 (max_iterations)
    // to an integer before planning continues.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        coerce_scalar_arguments(self.name(), arg_types, &[
            (3, ScalarCoercion::Real),
            (4, ScalarCoercion::Integer),
        ])
    }
    // The return type depends on field metadata, so only
    // return_field_from_args is supported.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        datafusion::common::internal_err!("return_field_from_args should be used instead")
    }
    // Plan-time validation: complex square-system contract plus
    // tolerance/iteration bounds, then a complex output vector field.
    fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
        let len = return_square_complex_system(&args, self.name())?;
        let tolerance = expect_real_scalar_argument(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_argument(&args, 4, self.name())?;
        // Validation only: the values are re-read from ScalarFunctionArgs at
        // invoke time.
        let _ = validate_tolerance(self.name(), tolerance)?;
        let _ = validate_max_iterations(self.name(), max_iterations)?;
        complex_vector_field(self.name(), len, nullable_or(args.arg_fields))
    }
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        // NOTE(review): duplicates return_square_complex_system against the
        // invoke-time fields — presumably defensive re-validation; confirm.
        let matrix = parse_complex_matrix_batch_field(&args.arg_fields[0], self.name(), 1)?;
        let (_rhs_field, vector) = parse_complex_vector_field(&args.arg_fields[1], self.name(), 2)?;
        if matrix.rows != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "{} requires square matrices, found ({}, {})",
                    self.name(),
                    matrix.rows,
                    matrix.cols
                ),
            ));
        }
        if vector.len != matrix.cols {
            return Err(exec_error(
                self.name(),
                format!(
                    "rhs vector length mismatch: expected {}, found {}",
                    matrix.cols, vector.len
                ),
            ));
        }
        let tolerance = expect_real_scalar_arg(&args, 3, self.name())?;
        let max_iterations = expect_usize_scalar_arg(&args, 4, self.name())?;
        // Complex solvers always use the f64 config.
        let config = iterative_config_f64(self.name(), tolerance, max_iterations)?;
        invoke_complex_iterative_solver(
            &args,
            self.name(),
            &config,
            nabled::ml::iterative::gmres_complex_view,
        )
    }
    fn documentation(&self) -> Option<&Documentation> {
        // Built once and shared; LazyLock defers construction to first use.
        static DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
            iterative_doc(
                "Solve each square complex linear system in the batch with GMRES.",
                "matrix_gmres_complex(matrix_batch, rhs_batch, tolerance => 1e-6, max_iterations \
                 => 64)",
            )
            .with_argument(
                "matrix",
                "Square dense complex matrix batch in canonical fixed-shape tensor form.",
            )
            .with_argument(
                "rhs",
                "Dense complex vector batch containing one right-hand side per matrix row.",
            )
            .with_argument("tolerance", "Positive finite convergence tolerance.")
            .with_argument("max_iterations", "Positive integer iteration cap.")
            .with_alternative_syntax(
                "matrix_gmres_complex(matrix => matrix_batch, rhs => rhs_batch, tolerance => \
                 1e-6, max_iterations => 64)",
            )
            .build()
        });
        Some(&DOCUMENTATION)
    }
}
/// Returns the `matrix_conjugate_gradient` scalar UDF ready for registration.
#[must_use]
pub fn matrix_conjugate_gradient_udf() -> Arc<ScalarUDF> {
    let udf = ScalarUDF::new_from_impl(MatrixConjugateGradient::new());
    Arc::new(udf)
}
/// Returns the `matrix_gmres` scalar UDF ready for registration.
#[must_use]
pub fn matrix_gmres_udf() -> Arc<ScalarUDF> {
    let udf = ScalarUDF::new_from_impl(MatrixGmres::new());
    Arc::new(udf)
}
/// Returns the `matrix_conjugate_gradient_complex` scalar UDF ready for
/// registration.
#[must_use]
pub fn matrix_conjugate_gradient_complex_udf() -> Arc<ScalarUDF> {
    let udf = ScalarUDF::new_from_impl(MatrixConjugateGradientComplex::new());
    Arc::new(udf)
}
/// Returns the `matrix_gmres_complex` scalar UDF ready for registration.
#[must_use]
pub fn matrix_gmres_complex_udf() -> Arc<ScalarUDF> {
    let udf = ScalarUDF::new_from_impl(MatrixGmresComplex::new());
    Arc::new(udf)
}
#[cfg(test)]
mod tests {
    use datafusion::arrow::datatypes::DataType;
    use datafusion::common::ScalarValue;
    use datafusion::logical_expr::ReturnFieldArgs;
    use super::*;
    use crate::metadata::{fixed_shape_tensor_field, vector_field};
    // Exercises return_square_system's three failure modes and both config
    // builders' validation paths with a 2x2 Float64 system.
    #[test]
    fn iterative_helpers_validate_square_system_contracts_and_configs() {
        let matrix =
            fixed_shape_tensor_field("matrix", &DataType::Float64, &[2, 2], false).expect("matrix");
        let rhs = vector_field("rhs", &DataType::Float64, 2, false).expect("rhs");
        let scalar_arguments = [None, None];
        // Happy path: matching types, square matrix, matching rhs length.
        let args = ReturnFieldArgs {
            arg_fields: &[Arc::clone(&matrix), Arc::clone(&rhs)],
            scalar_arguments: &scalar_arguments,
        };
        let (value_type, len) =
            return_square_system(&args, "matrix_conjugate_gradient").expect("square contract");
        assert_eq!(value_type, DataType::Float64);
        assert_eq!(len, 2);
        // Float32 rhs against a Float64 matrix must be rejected.
        let wrong_type = vector_field("rhs", &DataType::Float32, 2, false).expect("rhs");
        let wrong_type_args = ReturnFieldArgs {
            arg_fields: &[Arc::clone(&matrix), wrong_type],
            scalar_arguments: &scalar_arguments,
        };
        let error = return_square_system(&wrong_type_args, "matrix_conjugate_gradient")
            .expect_err("value type mismatch should fail");
        assert!(error.to_string().contains("value type mismatch"), "unexpected error: {error}");
        // A 2x3 matrix batch must be rejected as non-square.
        let non_square =
            fixed_shape_tensor_field("matrix", &DataType::Float64, &[2, 3], false).expect("matrix");
        let non_square_args = ReturnFieldArgs {
            arg_fields: &[Arc::clone(&non_square), Arc::clone(&rhs)],
            scalar_arguments: &scalar_arguments,
        };
        let error = return_square_system(&non_square_args, "matrix_conjugate_gradient")
            .expect_err("non-square matrix should fail");
        assert!(
            error.to_string().contains("requires square matrices"),
            "unexpected error: {error}"
        );
        // A length-1 rhs against a 2x2 matrix must be rejected.
        let short_rhs = vector_field("rhs", &DataType::Float64, 1, false).expect("rhs");
        let short_rhs_args = ReturnFieldArgs {
            arg_fields: &[matrix, short_rhs],
            scalar_arguments: &scalar_arguments,
        };
        let error = return_square_system(&short_rhs_args, "matrix_conjugate_gradient")
            .expect_err("rhs length mismatch should fail");
        assert!(
            error.to_string().contains("rhs vector length mismatch"),
            "unexpected error: {error}"
        );
        // Config builders: valid inputs round-trip into the config structs.
        let config32 =
            iterative_config_f32("matrix_conjugate_gradient", 1.0e-3, 8).expect("f32 config");
        assert_eq!(config32.max_iterations, 8);
        assert!((config32.tolerance - 1.0e-3_f32).abs() < f32::EPSILON);
        let config64 =
            iterative_config_f64("matrix_conjugate_gradient", 1.0e-6, 16).expect("f64 config");
        assert_eq!(config64.max_iterations, 16);
        assert!((config64.tolerance - 1.0e-6).abs() < f64::EPSILON);
        // Config builders: each validation failure surfaces its own message.
        let error = iterative_config_f64("matrix_conjugate_gradient", 0.0, 16)
            .expect_err("non-positive tolerance should fail");
        assert!(
            error.to_string().contains("tolerance must be positive"),
            "unexpected error: {error}"
        );
        let error = iterative_config_f64("matrix_conjugate_gradient", f64::INFINITY, 16)
            .expect_err("non-finite tolerance should fail");
        assert!(
            error.to_string().contains("tolerance must be finite"),
            "unexpected error: {error}"
        );
        let error = iterative_config_f64("matrix_conjugate_gradient", 1.0e-6, 0)
            .expect_err("zero iterations should fail");
        assert!(
            error.to_string().contains("max_iterations must be greater than 0"),
            "unexpected error: {error}"
        );
        // NOTE(review): the assertion below compares a freshly built scalar
        // against an identical literal — it is tautological and presumably a
        // placeholder for a real scalar-argument-plumbing test; confirm intent.
        let scalar = ScalarValue::Float64(Some(1.0e-4));
        let scalar_arguments = [None, None, Some(&scalar), None];
        assert_eq!(scalar_arguments[2], Some(&ScalarValue::Float64(Some(1.0e-4))));
    }
}