use arrow_array::types::ArrowPrimitiveType;
use arrow_array::{Array, FixedSizeListArray, ListArray, PrimitiveArray, StructArray};
use arrow_schema::Field;
use nabled_core::scalar::NabledReal;
use ndarrow::NdarrowElement;
use super::{
ArrowInteropError, csr_matrix_batch_view, csr_matrix_view_from_batch_row,
csr_matrix_view_from_columns, csr_matrix_view_from_extension, fixed_size_list_from_owned,
fixed_size_list_view, primitive_array_from_owned, primitive_array_view,
variable_shape_tensor_batch_view,
};
/// Owned pieces of one CSR matrix, in the order the batch helpers below pass
/// them to `ndarrow::csr_batch_to_extension_array`.
type CsrBatchRowParts<T> = (
    [usize; 2], // [nrows, ncols]
    Vec<i32>,   // row pointers
    Vec<u32>,   // column indices
    Vec<T>,     // stored values
);
// Expands to a `_columns` / `_extension` pair of Arrow wrappers around an
// iterative sparse solver called as `(matrix_view, rhs_view, tolerance, max_iterations)`.
macro_rules! sparse_iterative_solver_wrappers {
    ($columns_name:ident, $extension_name:ident, $call:path) => {
        #[doc = concat!(
            "Runs `",
            stringify!($call),
            "` on a CSR matrix read from the `indices`/`values` list columns and `ncols`."
        )]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Returns an error if the matrix or `rhs` cannot be viewed, or if the solver fails."]
        pub fn $columns_name<T>(
            indices: &ListArray,
            values: &ListArray,
            ncols: usize,
            rhs: &PrimitiveArray<T>,
            tolerance: T::Native,
            max_iterations: usize,
        ) -> Result<PrimitiveArray<T>, ArrowInteropError>
        where
            T: ArrowPrimitiveType,
            T::Native: NabledReal + NdarrowElement,
        {
            let system = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
            let rhs_values = primitive_array_view(rhs)?;
            $call(&system, &rhs_values, tolerance, max_iterations)
                .map(primitive_array_from_owned::<T>)
                .map_err(ArrowInteropError::from)
        }

        #[doc = concat!(
            "Runs `",
            stringify!($call),
            "` on a CSR matrix read from `field`/`matrix` via `csr_matrix_view_from_extension`."
        )]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Returns an error if the matrix or `rhs` cannot be viewed, or if the solver fails."]
        pub fn $extension_name<T>(
            field: &Field,
            matrix: &StructArray,
            rhs: &PrimitiveArray<T>,
            tolerance: T::Native,
            max_iterations: usize,
        ) -> Result<PrimitiveArray<T>, ArrowInteropError>
        where
            T: ArrowPrimitiveType,
            T::Native: NabledReal + NdarrowElement,
        {
            let system = csr_matrix_view_from_extension::<T>(field, matrix)?;
            let rhs_values = primitive_array_view(rhs)?;
            $call(&system, &rhs_values, tolerance, max_iterations)
                .map(primitive_array_from_owned::<T>)
                .map_err(ArrowInteropError::from)
        }
    };
}
/// Splits an owned CSR matrix into [`CsrBatchRowParts`]. Row pointers are
/// narrowed to `i32` and column indices to `u32`.
///
/// # Errors
///
/// Returns an `ndarrow::NdarrowError::ShapeMismatch` (converted into
/// `ArrowInteropError`) for the first row pointer that does not fit in `i32`
/// or the first column index that does not fit in `u32`.
fn csr_matrix_to_batch_parts<T: NabledReal>(
    matrix: crate::linalg::sparse::CsrMatrix<T>,
) -> Result<CsrBatchRowParts<T>, ArrowInteropError> {
    let shape = [matrix.nrows, matrix.ncols];

    let mut row_ptrs = Vec::new();
    for pointer in matrix.indptr {
        let narrowed = i32::try_from(pointer).map_err(|_| ndarrow::NdarrowError::ShapeMismatch {
            message: format!("CSR row pointer {pointer} exceeds i32 limits for batch output"),
        })?;
        row_ptrs.push(narrowed);
    }

    let mut col_indices = Vec::new();
    for column in matrix.indices {
        let narrowed = u32::try_from(column).map_err(|_| ndarrow::NdarrowError::ShapeMismatch {
            message: format!("CSR column index {column} exceeds u32 limits for batch output"),
        })?;
        col_indices.push(narrowed);
    }

    Ok((shape, row_ptrs, col_indices, matrix.data))
}
pub fn matvec_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
vector: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
let vector_view = primitive_array_view(vector)?;
let output = crate::linalg::sparse::matvec_view(&matrix_view, &vector_view)?;
Ok(primitive_array_from_owned::<T>(output))
}
/// Multiplies a CSR matrix by a dense `vector`. The matrix is read from
/// `field`/`matrix` via `csr_matrix_view_from_extension`.
///
/// # Errors
///
/// Returns an error if the matrix or vector cannot be viewed, or if
/// `crate::linalg::sparse::matvec_view` fails.
pub fn matvec_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    vector: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let product = crate::linalg::sparse::matvec_view(&csr, &primitive_array_view(vector)?)?;
    Ok(primitive_array_from_owned::<T>(product))
}
pub fn matmat_dense_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
right: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
let right_view = fixed_size_list_view::<T>(right)?;
let output = crate::linalg::sparse::matmat_dense_view(&matrix_view, &right_view)?;
fixed_size_list_from_owned::<T>(output)
}
/// Multiplies a CSR matrix, read from `field`/`matrix` via
/// `csr_matrix_view_from_extension`, by a dense fixed-size-list matrix.
///
/// # Errors
///
/// Returns an error if either operand cannot be viewed, if the product fails,
/// or if the result cannot be converted back into a `FixedSizeListArray`.
pub fn matmat_dense_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    right: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let sparse_lhs = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let dense_rhs = fixed_size_list_view::<T>(right)?;
    let product = crate::linalg::sparse::matmat_dense_view(&sparse_lhs, &dense_rhs)?;
    fixed_size_list_from_owned::<T>(product)
}
pub fn sparse_lu_factor_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
) -> Result<crate::linalg::sparse::SparseLUFactorization<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
Ok(crate::linalg::sparse::sparse_lu_factor_view(&matrix_view)?)
}
/// Computes a sparse LU factorization of a CSR matrix read from
/// `field`/`matrix` via `csr_matrix_view_from_extension`.
///
/// # Errors
///
/// Returns an error if the matrix cannot be viewed or the factorization fails.
pub fn sparse_lu_factor_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
) -> Result<crate::linalg::sparse::SparseLUFactorization<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let factorization = crate::linalg::sparse::sparse_lu_factor_view(&csr)?;
    Ok(factorization)
}
pub fn sparse_lu_solve_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
rhs: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
let rhs_view = primitive_array_view(rhs)?;
let output = crate::linalg::sparse::sparse_lu_solve_view(&matrix_view, &rhs_view)?;
Ok(primitive_array_from_owned::<T>(output))
}
/// Solves the sparse linear system for `rhs` with
/// `crate::linalg::sparse::sparse_lu_solve_view`. The matrix is read from
/// `field`/`matrix` via `csr_matrix_view_from_extension`.
///
/// # Errors
///
/// Returns an error if the matrix or `rhs` cannot be viewed, or if the solve fails.
pub fn sparse_lu_solve_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    rhs: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let system = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let solution = crate::linalg::sparse::sparse_lu_solve_view(&system, &primitive_array_view(rhs)?)?;
    Ok(primitive_array_from_owned::<T>(solution))
}
// Each invocation expands (see `sparse_iterative_solver_wrappers`) into a
// `_columns` and an `_extension` function wrapping the named solver.
sparse_iterative_solver_wrappers! {
    jacobi_solve_csr_columns,
    jacobi_solve_csr_extension,
    crate::linalg::sparse::jacobi_solve_view
}
sparse_iterative_solver_wrappers! {
    gauss_seidel_solve_csr_columns,
    gauss_seidel_solve_csr_extension,
    crate::linalg::sparse::gauss_seidel_solve_view
}
sparse_iterative_solver_wrappers! {
    conjugate_gradient_solve_csr_columns,
    conjugate_gradient_solve_csr_extension,
    crate::linalg::sparse::conjugate_gradient_solve_view
}
sparse_iterative_solver_wrappers! {
    pcg_solve_csr_columns,
    pcg_solve_csr_extension,
    crate::linalg::sparse::pcg_solve_view
}
pub fn transpose_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
) -> Result<crate::linalg::sparse::CsrMatrix<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
Ok(crate::linalg::sparse::transpose_view(&matrix_view)?)
}
/// Transposes a CSR matrix read from `field`/`matrix` via
/// `csr_matrix_view_from_extension`. The result is returned as an owned `CsrMatrix`.
///
/// # Errors
///
/// Returns an error if the matrix cannot be viewed or the transpose fails.
pub fn transpose_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
) -> Result<crate::linalg::sparse::CsrMatrix<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let transposed = crate::linalg::sparse::transpose_view(&csr)?;
    Ok(transposed)
}
pub fn csr_to_csc_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
) -> Result<crate::linalg::sparse::CscMatrix<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
Ok(crate::linalg::sparse::csr_to_csc_view(&matrix_view)?)
}
/// Converts a CSR matrix, read from `field`/`matrix` via
/// `csr_matrix_view_from_extension`, into an owned `CscMatrix`.
///
/// # Errors
///
/// Returns an error if the matrix cannot be viewed or the conversion fails.
pub fn csr_to_csc_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
) -> Result<crate::linalg::sparse::CscMatrix<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let csc = crate::linalg::sparse::csr_to_csc_view(&csr)?;
    Ok(csc)
}
pub fn matmat_sparse_csr_columns<T>(
left_indices: &ListArray,
left_values: &ListArray,
left_ncols: usize,
right_indices: &ListArray,
right_values: &ListArray,
right_ncols: usize,
) -> Result<crate::linalg::sparse::CsrMatrix<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let left_view = csr_matrix_view_from_columns::<T>(left_indices, left_values, left_ncols)?;
let right_view = csr_matrix_view_from_columns::<T>(right_indices, right_values, right_ncols)?;
Ok(crate::linalg::sparse::matmat_sparse_view(&left_view, &right_view)?)
}
/// Multiplies two CSR matrices, each read via `csr_matrix_view_from_extension`.
/// The left operand is viewed before the right.
///
/// # Errors
///
/// Returns an error if either matrix cannot be viewed or the product fails.
pub fn matmat_sparse_csr_extension<T>(
    left_field: &Field,
    left: &StructArray,
    right_field: &Field,
    right: &StructArray,
) -> Result<crate::linalg::sparse::CsrMatrix<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let lhs = csr_matrix_view_from_extension::<T>(left_field, left)?;
    let rhs = csr_matrix_view_from_extension::<T>(right_field, right)?;
    let product = crate::linalg::sparse::matmat_sparse_view(&lhs, &rhs)?;
    Ok(product)
}
/// Computes one sparse matrix-vector product per batch row.
///
/// `matrices` is read with `csr_matrix_batch_view` and `vectors` with
/// `variable_shape_tensor_batch_view`. Output row `i` is the product of matrix
/// row `i` and vector row `i`. The outputs are assembled with
/// `ndarrow::arrays_to_variable_shape_tensor`, passing `field.name()` as the name.
///
/// # Errors
///
/// Returns [`ArrowInteropError::InvalidShape`] if the two batches have different
/// row counts, or if a vector row is not one-dimensional. The latter message
/// names the offending row. Errors from the batch views, the per-row product,
/// and the output conversion are propagated.
pub fn matvec_csr_batch_extension<T>(
    field: &Field,
    matrices: &StructArray,
    vectors_field: &Field,
    vectors: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    if matrices.len() != vectors.len() {
        return Err(ArrowInteropError::InvalidShape(format!(
            "sparse matrix batch row count mismatch: {} vs {}",
            matrices.len(),
            vectors.len()
        )));
    }
    let matrix_batch = csr_matrix_batch_view::<T>(field, matrices)?;
    let vector_batch = variable_shape_tensor_batch_view::<T>(vectors_field, vectors)?;
    let mut outputs = Vec::with_capacity(matrix_batch.len());
    for row in 0..matrix_batch.len() {
        let matrix_view = csr_matrix_view_from_batch_row::<T>(matrix_batch.row(row)?)?;
        let vector_view = vector_batch
            .row(row)?
            .as_array_viewd()?
            .into_dimensionality::<ndarray::Ix1>()
            // Include the batch row and the expected rank so failures in large
            // batches can be located; the bare ndarray message has neither.
            .map_err(|error| {
                ArrowInteropError::InvalidShape(format!(
                    "vector batch row {row} must be one-dimensional: {error}"
                ))
            })?;
        outputs.push(crate::linalg::sparse::matvec_view(&matrix_view, &vector_view)?.into_dyn());
    }
    // One output dimension with no fixed size (presumably the uniform-shape
    // hint of ndarrow; confirm against its docs).
    Ok(ndarrow::arrays_to_variable_shape_tensor(field.name(), outputs, Some(vec![None]))?)
}
/// Computes one sparse-times-dense matrix product per batch row.
///
/// `matrices` is read with `csr_matrix_batch_view` and `right` with
/// `variable_shape_tensor_batch_view`. Output row `i` is the product of sparse
/// row `i` and dense row `i`. The outputs are assembled with
/// `ndarrow::arrays_to_variable_shape_tensor`, passing `field.name()` as the name.
///
/// # Errors
///
/// Returns [`ArrowInteropError::InvalidShape`] if the two batches have different
/// row counts, or if a dense row is not two-dimensional. The latter message
/// names the offending row. Errors from the batch views, the per-row product,
/// and the output conversion are propagated.
pub fn matmat_dense_csr_batch_extension<T>(
    field: &Field,
    matrices: &StructArray,
    right_field: &Field,
    right: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    if matrices.len() != right.len() {
        return Err(ArrowInteropError::InvalidShape(format!(
            "sparse matrix batch row count mismatch: {} vs {}",
            matrices.len(),
            right.len()
        )));
    }
    let matrix_batch = csr_matrix_batch_view::<T>(field, matrices)?;
    let right_batch = variable_shape_tensor_batch_view::<T>(right_field, right)?;
    let mut outputs = Vec::with_capacity(matrix_batch.len());
    for row in 0..matrix_batch.len() {
        let matrix_view = csr_matrix_view_from_batch_row::<T>(matrix_batch.row(row)?)?;
        let right_view = right_batch
            .row(row)?
            .as_array_viewd()?
            .into_dimensionality::<ndarray::Ix2>()
            // Include the batch row and the expected rank so failures in large
            // batches can be located; the bare ndarray message has neither.
            .map_err(|error| {
                ArrowInteropError::InvalidShape(format!(
                    "dense batch row {row} must be two-dimensional: {error}"
                ))
            })?;
        outputs
            .push(crate::linalg::sparse::matmat_dense_view(&matrix_view, &right_view)?.into_dyn());
    }
    // Two output dimensions with no fixed size (presumably the uniform-shape
    // hint of ndarrow; confirm against its docs).
    Ok(ndarrow::arrays_to_variable_shape_tensor(field.name(), outputs, Some(vec![None, None]))?)
}
pub fn transpose_csr_batch_extension<T>(
field: &Field,
matrices: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_batch = csr_matrix_batch_view::<T>(field, matrices)?;
let mut shapes = Vec::with_capacity(matrix_batch.len());
let mut row_ptrs = Vec::with_capacity(matrix_batch.len());
let mut col_indices = Vec::with_capacity(matrix_batch.len());
let mut values = Vec::with_capacity(matrix_batch.len());
for row in 0..matrix_batch.len() {
let matrix_view = csr_matrix_view_from_batch_row::<T>(matrix_batch.row(row)?)?;
let transposed = crate::linalg::sparse::transpose_view(&matrix_view)?;
let (shape, row_ptr, cols, row_values) = csr_matrix_to_batch_parts(transposed)?;
shapes.push(shape);
row_ptrs.push(row_ptr);
col_indices.push(cols);
values.push(row_values);
}
Ok(ndarrow::csr_batch_to_extension_array(
field.name(),
shapes,
row_ptrs,
col_indices,
values,
)?)
}
/// Multiplies two batches of CSR matrices row by row. Both batches are read
/// with `csr_matrix_batch_view`. The products are re-encoded with
/// `ndarrow::csr_batch_to_extension_array`, passing `left_field.name()` as the name.
///
/// # Errors
///
/// Returns [`ArrowInteropError::InvalidShape`] if the batches have different row
/// counts. Errors from the batch views, the per-row product, the index-narrowing
/// in `csr_matrix_to_batch_parts`, and the output encoding are propagated.
pub fn matmat_sparse_csr_batch_extension<T>(
    left_field: &Field,
    left: &StructArray,
    right_field: &Field,
    right: &StructArray,
) -> Result<(Field, StructArray), ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let (left_rows, right_rows) = (left.len(), right.len());
    if left_rows != right_rows {
        return Err(ArrowInteropError::InvalidShape(format!(
            "sparse matrix batch row count mismatch: {} vs {}",
            left_rows, right_rows
        )));
    }
    let left_batch = csr_matrix_batch_view::<T>(left_field, left)?;
    let right_batch = csr_matrix_batch_view::<T>(right_field, right)?;

    // First compute every product, then split the parts into per-field columns.
    let mut products = Vec::with_capacity(left_batch.len());
    for row in 0..left_batch.len() {
        let lhs = csr_matrix_view_from_batch_row::<T>(left_batch.row(row)?)?;
        let rhs = csr_matrix_view_from_batch_row::<T>(right_batch.row(row)?)?;
        let product = crate::linalg::sparse::matmat_sparse_view(&lhs, &rhs)?;
        products.push(csr_matrix_to_batch_parts(product)?);
    }

    let mut shapes = Vec::with_capacity(products.len());
    let mut row_ptrs = Vec::with_capacity(products.len());
    let mut col_indices = Vec::with_capacity(products.len());
    let mut values = Vec::with_capacity(products.len());
    for (shape, row_ptr, cols, data) in products {
        shapes.push(shape);
        row_ptrs.push(row_ptr);
        col_indices.push(cols);
        values.push(data);
    }
    ndarrow::csr_batch_to_extension_array(left_field.name(), shapes, row_ptrs, col_indices, values)
        .map_err(ArrowInteropError::from)
}
pub fn batched_matvec_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
batch_vectors: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
let batch_vectors_view = fixed_size_list_view::<T>(batch_vectors)?;
let output = crate::linalg::sparse::batched_matvec_view(&matrix_view, &batch_vectors_view)?;
fixed_size_list_from_owned::<T>(output)
}
/// Applies one CSR matrix, read from `field`/`matrix` via
/// `csr_matrix_view_from_extension`, to a batch of vectors given as a
/// fixed-size-list array.
///
/// # Errors
///
/// Returns an error if either input cannot be viewed, if the batched product
/// fails, or if the result cannot be converted back into a `FixedSizeListArray`.
pub fn batched_matvec_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    batch_vectors: &FixedSizeListArray,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let vectors = fixed_size_list_view::<T>(batch_vectors)?;
    let products = crate::linalg::sparse::batched_matvec_view(&csr, &vectors)?;
    fixed_size_list_from_owned::<T>(products)
}
pub fn jacobi_preconditioner_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
) -> Result<crate::linalg::sparse::JacobiPreconditioner<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
Ok(crate::linalg::sparse::jacobi_preconditioner_view(&matrix_view)?)
}
/// Builds a `JacobiPreconditioner` for a CSR matrix read from `field`/`matrix`
/// via `csr_matrix_view_from_extension`.
///
/// # Errors
///
/// Returns an error if the matrix cannot be viewed or the preconditioner
/// construction fails.
pub fn jacobi_preconditioner_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
) -> Result<crate::linalg::sparse::JacobiPreconditioner<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let preconditioner = crate::linalg::sparse::jacobi_preconditioner_view(&csr)?;
    Ok(preconditioner)
}
pub fn apply_jacobi_preconditioner<T>(
preconditioner: &crate::linalg::sparse::JacobiPreconditioner<T::Native>,
rhs: &PrimitiveArray<T>,
) -> Result<PrimitiveArray<T>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let rhs_view = primitive_array_view(rhs)?;
let output = crate::linalg::sparse::apply_jacobi_preconditioner(preconditioner, &rhs_view)?;
Ok(primitive_array_from_owned::<T>(output))
}
// Expands to a `_columns` / `_extension` pair of Arrow wrappers around a sparse
// factorization routine that takes only the matrix view and returns
// `crate::linalg::sparse::$result_ty<T::Native>`.
macro_rules! sparse_factorization_wrappers {
    ($columns_name:ident, $extension_name:ident, $result_ty:ident, $call:path) => {
        #[doc = concat!(
            "Computes a `",
            stringify!($result_ty),
            "` with `",
            stringify!($call),
            "` for a CSR matrix read from the `indices`/`values` list columns and `ncols`."
        )]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Returns an error if the matrix cannot be viewed or the factorization fails."]
        pub fn $columns_name<T>(
            indices: &ListArray,
            values: &ListArray,
            ncols: usize,
        ) -> Result<crate::linalg::sparse::$result_ty<T::Native>, ArrowInteropError>
        where
            T: ArrowPrimitiveType,
            T::Native: NabledReal + NdarrowElement,
        {
            let csr = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
            $call(&csr).map_err(ArrowInteropError::from)
        }

        #[doc = concat!(
            "Computes a `",
            stringify!($result_ty),
            "` with `",
            stringify!($call),
            "` for a CSR matrix read from `field`/`matrix` via `csr_matrix_view_from_extension`."
        )]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Returns an error if the matrix cannot be viewed or the factorization fails."]
        pub fn $extension_name<T>(
            field: &Field,
            matrix: &StructArray,
        ) -> Result<crate::linalg::sparse::$result_ty<T::Native>, ArrowInteropError>
        where
            T: ArrowPrimitiveType,
            T::Native: NabledReal + NdarrowElement,
        {
            let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
            $call(&csr).map_err(ArrowInteropError::from)
        }
    };
}
// Each invocation expands (see `sparse_factorization_wrappers`) into a
// `_columns` and an `_extension` function returning the named factorization.
sparse_factorization_wrappers! {
    ilu0_factor_csr_columns,
    ilu0_factor_csr_extension,
    ILU0Factorization,
    crate::linalg::sparse::ilu0_factor_view
}
sparse_factorization_wrappers! {
    ic0_factor_csr_columns,
    ic0_factor_csr_extension,
    IC0Factorization,
    crate::linalg::sparse::ic0_factor_view
}
sparse_factorization_wrappers! {
    ildl0_factor_csr_columns,
    ildl0_factor_csr_extension,
    ILDL0Factorization,
    crate::linalg::sparse::ildl0_factor_view
}
pub fn ilut_factor_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
drop_tolerance: T::Native,
max_fill: usize,
) -> Result<crate::linalg::sparse::ILUTFactorization<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
Ok(crate::linalg::sparse::ilut_factor_view(&matrix_view, drop_tolerance, max_fill)?)
}
/// Computes an `ILUTFactorization` of a CSR matrix read from `field`/`matrix`
/// via `csr_matrix_view_from_extension`. `drop_tolerance` and `max_fill` are
/// forwarded unchanged to `crate::linalg::sparse::ilut_factor_view`.
///
/// # Errors
///
/// Returns an error if the matrix cannot be viewed or the factorization fails.
pub fn ilut_factor_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    drop_tolerance: T::Native,
    max_fill: usize,
) -> Result<crate::linalg::sparse::ILUTFactorization<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let factorization = crate::linalg::sparse::ilut_factor_view(&csr, drop_tolerance, max_fill)?;
    Ok(factorization)
}
pub fn iluk_factor_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
level_of_fill: usize,
) -> Result<crate::linalg::sparse::ILUKFactorization<T::Native>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
Ok(crate::linalg::sparse::iluk_factor_view(&matrix_view, level_of_fill)?)
}
/// Computes an `ILUKFactorization` of a CSR matrix read from `field`/`matrix`
/// via `csr_matrix_view_from_extension`. `level_of_fill` is forwarded unchanged
/// to `crate::linalg::sparse::iluk_factor_view`.
///
/// # Errors
///
/// Returns an error if the matrix cannot be viewed or the factorization fails.
pub fn iluk_factor_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    level_of_fill: usize,
) -> Result<crate::linalg::sparse::ILUKFactorization<T::Native>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let csr = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let factorization = crate::linalg::sparse::iluk_factor_view(&csr, level_of_fill)?;
    Ok(factorization)
}
// Expands to one Arrow wrapper that applies a stored factorization of type
// `crate::linalg::sparse::$factor_ty<T::Native>` to a dense right-hand side.
macro_rules! sparse_apply_preconditioner_wrappers {
    ($name:ident, $factor_ty:ident, $call:path) => {
        #[doc = concat!(
            "Applies a `",
            stringify!($factor_ty),
            "` to `rhs` via `",
            stringify!($call),
            "`."
        )]
        #[doc = ""]
        #[doc = "# Errors"]
        #[doc = ""]
        #[doc = "Returns an error if `rhs` cannot be viewed or the underlying call fails."]
        pub fn $name<T>(
            factorization: &crate::linalg::sparse::$factor_ty<T::Native>,
            rhs: &PrimitiveArray<T>,
        ) -> Result<PrimitiveArray<T>, ArrowInteropError>
        where
            T: ArrowPrimitiveType,
            T::Native: NabledReal + NdarrowElement,
        {
            $call(factorization, &primitive_array_view(rhs)?)
                .map(primitive_array_from_owned::<T>)
                .map_err(ArrowInteropError::from)
        }
    };
}
// Each invocation expands (see `sparse_apply_preconditioner_wrappers`) into one
// `apply_*_preconditioner` function for the named factorization type.
sparse_apply_preconditioner_wrappers! {
    apply_ilu0_preconditioner,
    ILU0Factorization,
    crate::linalg::sparse::apply_ilu0_preconditioner
}
sparse_apply_preconditioner_wrappers! {
    apply_ilut_preconditioner,
    ILUTFactorization,
    crate::linalg::sparse::apply_ilut_preconditioner
}
sparse_apply_preconditioner_wrappers! {
    apply_iluk_preconditioner,
    ILUKFactorization,
    crate::linalg::sparse::apply_iluk_preconditioner
}
sparse_apply_preconditioner_wrappers! {
    apply_ic0_preconditioner,
    IC0Factorization,
    crate::linalg::sparse::apply_ic0_preconditioner
}
sparse_apply_preconditioner_wrappers! {
    apply_ildl0_preconditioner,
    ILDL0Factorization,
    crate::linalg::sparse::apply_ildl0_preconditioner
}
pub fn sparse_lu_solve_with_factorization_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
rhs: &PrimitiveArray<T>,
factorization: &crate::linalg::sparse::SparseLUFactorization<T::Native>,
) -> Result<PrimitiveArray<T>, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
let rhs_view = primitive_array_view(rhs)?;
let output = crate::linalg::sparse::sparse_lu_solve_with_factorization_view(
&matrix_view,
&rhs_view,
factorization,
)?;
Ok(primitive_array_from_owned::<T>(output))
}
/// Solves for `rhs` using an existing `SparseLUFactorization` together with the
/// CSR matrix read from `field`/`matrix` via `csr_matrix_view_from_extension`.
///
/// # Errors
///
/// Returns an error if the matrix or `rhs` cannot be viewed, or if the solve fails.
pub fn sparse_lu_solve_with_factorization_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    rhs: &PrimitiveArray<T>,
    factorization: &crate::linalg::sparse::SparseLUFactorization<T::Native>,
) -> Result<PrimitiveArray<T>, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let system = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let rhs_values = primitive_array_view(rhs)?;
    let solution = crate::linalg::sparse::sparse_lu_solve_with_factorization_view(
        &system,
        &rhs_values,
        factorization,
    )?;
    Ok(primitive_array_from_owned::<T>(solution))
}
pub fn sparse_lu_solve_multiple_with_factorization_csr_columns<T>(
indices: &ListArray,
values: &ListArray,
ncols: usize,
rhs: &FixedSizeListArray,
factorization: &crate::linalg::sparse::SparseLUFactorization<T::Native>,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
T: ArrowPrimitiveType,
T::Native: NabledReal + NdarrowElement,
{
let matrix_view = csr_matrix_view_from_columns::<T>(indices, values, ncols)?;
let rhs_view = fixed_size_list_view::<T>(rhs)?;
let output = crate::linalg::sparse::sparse_lu_solve_multiple_with_factorization_view(
&matrix_view,
&rhs_view,
factorization,
)?;
fixed_size_list_from_owned::<T>(output)
}
/// Solves for several right-hand sides, given as a fixed-size-list array, using
/// an existing `SparseLUFactorization` and the CSR matrix read from
/// `field`/`matrix` via `csr_matrix_view_from_extension`.
///
/// # Errors
///
/// Returns an error if the matrix or `rhs` cannot be viewed, if the solve fails,
/// or if the solutions cannot be converted back into a `FixedSizeListArray`.
pub fn sparse_lu_solve_multiple_with_factorization_csr_extension<T>(
    field: &Field,
    matrix: &StructArray,
    rhs: &FixedSizeListArray,
    factorization: &crate::linalg::sparse::SparseLUFactorization<T::Native>,
) -> Result<FixedSizeListArray, ArrowInteropError>
where
    T: ArrowPrimitiveType,
    T::Native: NabledReal + NdarrowElement,
{
    let system = csr_matrix_view_from_extension::<T>(field, matrix)?;
    let rhs_matrix = fixed_size_list_view::<T>(rhs)?;
    let solutions = crate::linalg::sparse::sparse_lu_solve_multiple_with_factorization_view(
        &system,
        &rhs_matrix,
        factorization,
    )?;
    fixed_size_list_from_owned::<T>(solutions)
}