use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::interop::ndarray_compat::{from_ndarray, to_ndarray};
use num_traits::{Float, NumAssign, NumCast};
use scirs2_core::ndarray::{
Array1, Array2, ArrayView1, ArrayView2, Ix1, Ix2, IxDyn, ScalarOperand,
};
use std::fmt::Debug;
use std::iter::Sum;
pub use scirs2_linalg::LinalgError;
/// Result type used by `scirs2_linalg` routines, carrying a [`LinalgError`] on failure.
pub type LinalgResult<T> = std::result::Result<T, LinalgError>;
/// Pair of complex arrays returned by [`eig`]: the eigenvalues first, then the
/// eigenvectors reshaped into a 2-D array.
pub type ComplexEigResult<T> = (
Array<scirs2_core::numeric::Complex<T>>,
Array<scirs2_core::numeric::Complex<T>>,
);
/// Copies a two-dimensional NumRS2 [`Array`] into an owned `ndarray` [`Array2`].
///
/// The elements are taken from `Array::to_vec` and laid out as a
/// `(rows, cols)` matrix with `Array2::from_shape_vec`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] when `arr` is not 2-D.
/// * [`NumRs2Error::ConversionError`] when `from_shape_vec` rejects the data.
fn to_array2<T>(arr: &Array<T>) -> Result<Array2<T>>
where
    T: Clone + Debug,
{
    let dims = arr.shape();
    let ndim = dims.len();
    if ndim == 2 {
        let flat = arr.to_vec();
        Array2::from_shape_vec((dims[0], dims[1]), flat).map_err(|err| {
            NumRs2Error::ConversionError(format!("Failed to convert to Array2: {}", err))
        })
    } else {
        Err(NumRs2Error::DimensionMismatch(format!(
            "Expected 2D array, got {}D",
            ndim
        )))
    }
}
/// Copies a one-dimensional NumRS2 [`Array`] into an owned `ndarray` [`Array1`].
///
/// # Errors
///
/// [`NumRs2Error::DimensionMismatch`] when `arr` is not 1-D.
fn to_array1<T>(arr: &Array<T>) -> Result<Array1<T>>
where
    T: Clone + Debug,
{
    let ndim = arr.shape().len();
    if ndim != 1 {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Expected 1D array, got {}D",
            ndim
        )));
    }
    Ok(Array1::from_vec(arr.to_vec()))
}
/// Moves an `ndarray` [`Array2`] into a NumRS2 [`Array`] with the same
/// `[rows, cols]` shape.
///
/// The elements are consumed with `into_iter` and then reshaped. This never
/// fails; the `Result` return keeps the signature uniform with the other
/// converters.
fn from_array2<T>(arr: Array2<T>) -> Result<Array<T>>
where
    T: Clone + Debug + NumCast,
{
    let (rows, cols) = arr.dim();
    let flat = arr.into_iter().collect::<Vec<T>>();
    Ok(Array::from_vec(flat).reshape(&[rows, cols]))
}
/// Moves an `ndarray` [`Array1`] into a one-dimensional NumRS2 [`Array`].
///
/// This never fails; the `Result` return keeps the signature uniform with the
/// other converters.
fn from_array1<T>(arr: Array1<T>) -> Result<Array<T>>
where
    T: Clone + Debug + NumCast,
{
    let elements = arr.into_iter().collect::<Vec<T>>();
    Ok(Array::from_vec(elements))
}
fn linalg_to_numrs2_error(e: LinalgError) -> NumRs2Error {
match e {
LinalgError::SingularMatrixError(s) => {
NumRs2Error::InvalidOperation(format!("Singular matrix: {}", s))
}
LinalgError::DimensionError(s) => NumRs2Error::DimensionMismatch(s),
LinalgError::ShapeError(s) => NumRs2Error::DimensionMismatch(s),
LinalgError::NonPositiveDefiniteError(s) => {
NumRs2Error::InvalidOperation(format!("Matrix is not positive definite: {}", s))
}
LinalgError::ConvergenceError(s) => {
NumRs2Error::ComputationError(format!("Convergence failed: {}", s))
}
_ => NumRs2Error::ComputationError(format!("Linear algebra error: {}", e)),
}
}
/// Inner product of two 1-D arrays, computed by
/// `scirs2_linalg::blas_accelerated::dot`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `x` or `y` is not 1-D.
/// * Backend failures are translated by `linalg_to_numrs2_error`. A length
///   mismatch is presumably reported there; confirm against the scirs2 docs.
pub fn dot<T>(x: &Array<T>, y: &Array<T>) -> Result<T>
where
    T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
    let (lhs, rhs) = (to_array1(x)?, to_array1(y)?);
    scirs2_linalg::blas_accelerated::dot(&lhs.view(), &rhs.view())
        .map_err(linalg_to_numrs2_error)
}
/// Vector norm of a 1-D array, computed by `scirs2_linalg::blas_accelerated::norm`.
///
/// For example, `[3, 4]` yields `5`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `x` is not 1-D.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn norm<T>(x: &Array<T>) -> Result<T>
where
    T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
    let vector = to_array1(x)?;
    let view = vector.view();
    scirs2_linalg::blas_accelerated::norm(&view).map_err(linalg_to_numrs2_error)
}
pub fn gemv<T>(alpha: T, a: &Array<T>, x: &Array<T>, beta: T, y: &Array<T>) -> Result<Array<T>>
where
T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
let a_nd = to_array2(a)?;
let x_nd = to_array1(x)?;
let y_nd = to_array1(y)?;
let result = scirs2_linalg::blas_accelerated::gemv(
alpha,
&a_nd.view(),
&x_nd.view(),
beta,
&y_nd.view(),
)
.map_err(linalg_to_numrs2_error)?;
from_array1(result)
}
/// Plain matrix-vector product `a * x`.
///
/// Implemented as [`gemv`] with `alpha = 1`, `beta = 0` and a zero-filled `y`
/// whose length is the number of rows of `a`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] (`"Matrix must be 2D"`) when `a` is not 2-D.
/// * Any error from [`gemv`] is propagated.
pub fn matvec<T>(a: &Array<T>, x: &Array<T>) -> Result<Array<T>>
where
    T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
    let dims = a.shape();
    match dims.len() {
        2 => {
            // With beta = 0 the y term vanishes; y only supplies the output length.
            let zero_rhs = Array::zeros(&[dims[0]]);
            gemv(T::one(), a, x, T::zero(), &zero_rhs)
        }
        _ => Err(NumRs2Error::DimensionMismatch(
            "Matrix must be 2D".to_string(),
        )),
    }
}
pub fn gemm<T>(alpha: T, a: &Array<T>, b: &Array<T>, beta: T, c: &Array<T>) -> Result<Array<T>>
where
T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
let a_nd = to_array2(a)?;
let b_nd = to_array2(b)?;
let c_nd = to_array2(c)?;
let result = scirs2_linalg::blas_accelerated::gemm(
alpha,
&a_nd.view(),
&b_nd.view(),
beta,
&c_nd.view(),
)
.map_err(linalg_to_numrs2_error)?;
from_array2(result)
}
/// Matrix product `a * b` of two 2-D arrays, computed by
/// `scirs2_linalg::blas_accelerated::matmul`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] for non-2-D inputs.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn matmul<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>>
where
    T: Float + NumAssign + Clone + Debug + NumCast + 'static,
{
    let (lhs, rhs) = (to_array2(a)?, to_array2(b)?);
    let product = scirs2_linalg::blas_accelerated::matmul(&lhs.view(), &rhs.view())
        .map_err(linalg_to_numrs2_error)?;
    from_array2(product)
}
/// LU decomposition of a 2-D array.
///
/// Returns the `(p, l, u)` triple produced by `scirs2_linalg::lu`, in that
/// order. See the scirs2 documentation for its exact permutation convention.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `a` is not 2-D.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn lu<T>(a: &Array<T>) -> Result<(Array<T>, Array<T>, Array<T>)>
where
    T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
    let matrix = to_array2(a)?;
    let (perm, lower, upper) =
        scirs2_linalg::lu(&matrix.view(), None).map_err(linalg_to_numrs2_error)?;
    let perm = from_array2(perm)?;
    let lower = from_array2(lower)?;
    let upper = from_array2(upper)?;
    Ok((perm, lower, upper))
}
pub fn qr<T>(a: &Array<T>) -> Result<(Array<T>, Array<T>)>
where
T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
let a_nd = to_array2(a)?;
let (q, r) = scirs2_linalg::qr(&a_nd.view(), None).map_err(linalg_to_numrs2_error)?;
Ok((from_array2(q)?, from_array2(r)?))
}
/// Singular value decomposition of a 2-D array.
///
/// Returns `(u, s, vt)` as produced by `scirs2_linalg::svd`, where `s` is 1-D.
/// `full_matrices` is forwarded unchanged to the backend.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `a` is not 2-D.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn svd<T>(a: &Array<T>, full_matrices: bool) -> Result<(Array<T>, Array<T>, Array<T>)>
where
    T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
    let matrix = to_array2(a)?;
    let (left, sigma, right_t) = scirs2_linalg::svd(&matrix.view(), full_matrices, None)
        .map_err(linalg_to_numrs2_error)?;
    let left = from_array2(left)?;
    let sigma = from_array1(sigma)?;
    let right_t = from_array2(right_t)?;
    Ok((left, sigma, right_t))
}
pub fn cholesky<T>(a: &Array<T>) -> Result<Array<T>>
where
T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
let a_nd = to_array2(a)?;
let l = scirs2_linalg::cholesky(&a_nd.view(), None).map_err(linalg_to_numrs2_error)?;
from_array2(l)
}
/// General (non-symmetric) eigendecomposition of a 2-D array.
///
/// Returns complex eigenvalues as a 1-D array and eigenvectors as a 2-D array
/// with the same shape the backend (`scirs2_linalg::eig`) returned.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `a` is not 2-D.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn eig<T>(a: &Array<T>) -> Result<ComplexEigResult<T>>
where
    T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
    let matrix = to_array2(a)?;
    let (values, vectors) =
        scirs2_linalg::eig(&matrix.view(), None).map_err(linalg_to_numrs2_error)?;
    let (rows, cols) = (vectors.shape()[0], vectors.shape()[1]);
    let value_buf: Vec<scirs2_core::numeric::Complex<T>> = values.into_iter().collect();
    let vector_buf: Vec<scirs2_core::numeric::Complex<T>> = vectors.iter().cloned().collect();
    Ok((
        Array::from_vec(value_buf),
        Array::from_vec(vector_buf).reshape(&[rows, cols]),
    ))
}
pub fn eigh<T>(a: &Array<T>) -> Result<(Array<T>, Array<T>)>
where
T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
let a_nd = to_array2(a)?;
let (eigenvalues, eigenvectors) =
scirs2_linalg::eigh(&a_nd.view(), None).map_err(linalg_to_numrs2_error)?;
Ok((from_array1(eigenvalues)?, from_array2(eigenvectors)?))
}
/// Complex eigenvalues of a general 2-D array, computed by `scirs2_linalg::eigvals`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `a` is not 2-D.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn eigvals<T>(a: &Array<T>) -> Result<Array<scirs2_core::numeric::Complex<T>>>
where
    T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
    let matrix = to_array2(a)?;
    let values = scirs2_linalg::eigvals(&matrix.view(), None).map_err(linalg_to_numrs2_error)?;
    Ok(Array::from_vec(values.into_iter().collect::<Vec<_>>()))
}
pub fn eigvalsh<T>(a: &Array<T>) -> Result<Array<T>>
where
T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
let a_nd = to_array2(a)?;
let eigenvalues =
scirs2_linalg::eigvalsh(&a_nd.view(), None).map_err(linalg_to_numrs2_error)?;
from_array1(eigenvalues)
}
pub fn solve<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>>
where
T: Float
+ NumAssign
+ Clone
+ Debug
+ NumCast
+ Sum
+ Send
+ Sync
+ ScalarOperand
+ num_traits::One
+ 'static,
{
let a_nd = to_array2(a)?;
let b_nd = to_array1(b)?;
let x =
scirs2_linalg::solve(&a_nd.view(), &b_nd.view(), None).map_err(linalg_to_numrs2_error)?;
from_array1(x)
}
pub fn inv<T>(a: &Array<T>) -> Result<Array<T>>
where
T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
let a_nd = to_array2(a)?;
let a_inv = scirs2_linalg::inv(&a_nd.view(), None).map_err(linalg_to_numrs2_error)?;
from_array2(a_inv)
}
/// Determinant of a 2-D array, computed by `scirs2_linalg::det`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `a` is not 2-D.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn det<T>(a: &Array<T>) -> Result<T>
where
    T: Float + NumAssign + Clone + Debug + NumCast + Sum + Send + Sync + ScalarOperand + 'static,
{
    let matrix = to_array2(a)?;
    let view = matrix.view();
    scirs2_linalg::det(&view, None).map_err(linalg_to_numrs2_error)
}
/// Least-squares solution of `a * x ≈ b` via `scirs2_linalg::lstsq`.
///
/// `a` must be 2-D and `b` 1-D. Only the solution vector (the backend
/// result's `x` field) is returned; any other fields of that result are
/// discarded.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] for inputs of the wrong rank.
/// * Backend failures are translated by `linalg_to_numrs2_error`.
pub fn lstsq<T>(a: &Array<T>, b: &Array<T>) -> Result<Array<T>>
where
    T: Float
        + NumAssign
        + Clone
        + Debug
        + NumCast
        + Sum
        + Send
        + Sync
        + ScalarOperand
        + num_traits::One
        + 'static,
{
    let design = to_array2(a)?;
    let target = to_array1(b)?;
    scirs2_linalg::lstsq(&design.view(), &target.view(), None)
        .map_err(linalg_to_numrs2_error)
        .and_then(|solution| from_array1(solution.x))
}
/// Namespace for BLAS-style routines that write their result into a
/// caller-supplied output array instead of returning a new one.
#[derive(Debug, Clone, Copy)]
pub struct AcceleratedBlas;

impl AcceleratedBlas {
    /// In-place general matrix multiply: `c = alpha * op(a) * op(b) + beta * c`.
    ///
    /// `op(a)` is the transpose of `a` when `trans_a` is `true`, and `a` itself
    /// otherwise. `trans_b` controls `op(b)` the same way, matching the
    /// `TRANSA`/`TRANSB` arguments of BLAS `?gemm`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the free function [`gemm`] (non-2-D inputs,
    /// incompatible shapes, backend failures). On error, `c` is left unchanged.
    pub fn gemm<T>(
        a: &Array<T>,
        b: &Array<T>,
        c: &mut Array<T>,
        alpha: T,
        beta: T,
        trans_a: bool,
        trans_b: bool,
    ) -> Result<()>
    where
        T: Float + NumAssign + Clone + Debug + NumCast + 'static,
    {
        // Deferred-initialised owners hold a transposed copy only when one is
        // requested, so the non-transposed path borrows the input directly.
        let a_transposed;
        let op_a = if trans_a {
            a_transposed = a.transpose();
            &a_transposed
        } else {
            a
        };
        let b_transposed;
        let op_b = if trans_b {
            b_transposed = b.transpose();
            &b_transposed
        } else {
            b
        };
        let result = gemm(alpha, op_a, op_b, beta, c)?;
        *c = result;
        Ok(())
    }

    /// In-place matrix-vector multiply: `y = alpha * op(a) * x + beta * y`.
    ///
    /// `op(a)` is the transpose of `a` when `trans` is `true`, and `a` itself
    /// otherwise, matching the `TRANS` argument of BLAS `?gemv`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the free function [`gemv`]. On error, `y` is
    /// left unchanged.
    pub fn gemv<T>(
        a: &Array<T>,
        x: &Array<T>,
        y: &mut Array<T>,
        alpha: T,
        beta: T,
        trans: bool,
    ) -> Result<()>
    where
        T: Float + NumAssign + Clone + Debug + NumCast + 'static,
    {
        let a_transposed;
        let op_a = if trans {
            a_transposed = a.transpose();
            &a_transposed
        } else {
            a
        };
        let result = gemv(alpha, op_a, x, beta, y)?;
        *y = result;
        Ok(())
    }

    /// Inner product of two 1-D arrays; forwards to the free function [`dot`].
    pub fn dot<T>(x: &Array<T>, y: &Array<T>) -> Result<T>
    where
        T: Float + NumAssign + Clone + Debug + NumCast + 'static,
    {
        dot(x, y)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Tolerance used for every floating-point comparison in this module.
    const EPS: f64 = 1e-10;

    /// Builds a 2x2 matrix from its four entries given in row-major order.
    fn mat2(values: [f64; 4]) -> Array<f64> {
        Array::from_vec(values.to_vec()).reshape(&[2, 2])
    }

    /// Asserts that the 2x2 `actual` equals `expected` (rows listed top to
    /// bottom) entry by entry within `EPS`, naming the first mismatching index.
    fn assert_matrix_close(actual: &Array<f64>, expected: [[f64; 2]; 2]) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                let got = actual.get(&[i, j]).expect("valid index");
                assert!(
                    approx::relative_eq!(got, want, epsilon = EPS),
                    "entry ({}, {}): got {}, expected {}",
                    i,
                    j,
                    got,
                    want
                );
            }
        }
    }

    /// Asserts that the 1-D `actual` equals `expected` element-wise within `EPS`.
    fn assert_vector_close(actual: &Array<f64>, expected: &[f64]) {
        for (i, &want) in expected.iter().enumerate() {
            let got = actual.get(&[i]).expect("valid index");
            assert!(
                approx::relative_eq!(got, want, epsilon = EPS),
                "entry {}: got {}, expected {}",
                i,
                got,
                want
            );
        }
    }

    #[test]
    fn test_dot_product() {
        let x = Array::from_vec(vec![1.0f64, 2.0, 3.0]);
        let y = Array::from_vec(vec![4.0f64, 5.0, 6.0]);
        let result = dot(&x, &y).expect("dot product should succeed");
        assert_relative_eq!(result, 32.0, epsilon = EPS);
    }

    #[test]
    fn test_norm() {
        let x = Array::from_vec(vec![3.0f64, 4.0]);
        let result = norm(&x).expect("norm should succeed");
        assert_relative_eq!(result, 5.0, epsilon = EPS);
    }

    #[test]
    fn test_matmul() {
        let a = mat2([1.0, 2.0, 3.0, 4.0]);
        let b = mat2([5.0, 6.0, 7.0, 8.0]);
        let c = matmul(&a, &b).expect("matmul should succeed");
        assert_matrix_close(&c, [[19.0, 22.0], [43.0, 50.0]]);
    }

    #[test]
    fn test_matvec() {
        let a = mat2([1.0, 2.0, 3.0, 4.0]);
        let x = Array::from_vec(vec![1.0f64, 2.0]);
        let y = matvec(&a, &x).expect("matvec should succeed");
        assert_vector_close(&y, &[5.0, 11.0]);
    }

    #[test]
    fn test_solve() {
        let a = mat2([2.0, 1.0, 1.0, 3.0]);
        let b = Array::from_vec(vec![5.0f64, 6.0]);
        let x = solve(&a, &b).expect("solve should succeed");
        assert_vector_close(&x, &[1.8, 1.4]);
    }

    #[test]
    fn test_inv() {
        let a = mat2([4.0, 7.0, 2.0, 6.0]);
        let a_inv = inv(&a).expect("inverse should succeed");
        let identity = matmul(&a, &a_inv).expect("matmul should succeed");
        assert_matrix_close(&identity, [[1.0, 0.0], [0.0, 1.0]]);
    }

    #[test]
    fn test_det() {
        let a = mat2([1.0, 2.0, 3.0, 4.0]);
        let d = det(&a).expect("determinant should succeed");
        assert_relative_eq!(d, -2.0, epsilon = EPS);
    }

    #[test]
    fn test_qr() {
        let a = mat2([1.0, 2.0, 3.0, 4.0]);
        let (q, r) = qr(&a).expect("QR decomposition should succeed");
        // Orthogonality requires the whole of Q * Q^T to be the identity, not
        // just its diagonal.
        let q_t = q.transpose();
        let identity = matmul(&q, &q_t).expect("matmul should succeed");
        assert_matrix_close(&identity, [[1.0, 0.0], [0.0, 1.0]]);
        let reconstructed = matmul(&q, &r).expect("matmul should succeed");
        assert_matrix_close(&reconstructed, [[1.0, 2.0], [3.0, 4.0]]);
    }

    #[test]
    fn test_cholesky() {
        let a = mat2([4.0, 2.0, 2.0, 3.0]);
        let l = cholesky(&a).expect("Cholesky decomposition should succeed");
        let l_t = l.transpose();
        let reconstructed = matmul(&l, &l_t).expect("matmul should succeed");
        assert_matrix_close(&reconstructed, [[4.0, 2.0], [2.0, 3.0]]);
    }

    #[test]
    fn test_eigh() {
        let a = mat2([2.0, 1.0, 1.0, 2.0]);
        let (eigenvalues, _eigenvectors) = eigh(&a).expect("eigendecomposition should succeed");
        // Sort so the assertions do not depend on the backend's output order.
        let mut eigs = eigenvalues.to_vec();
        eigs.sort_by(|a, b| a.total_cmp(b));
        assert_relative_eq!(eigs[0], 1.0, epsilon = EPS);
        assert_relative_eq!(eigs[1], 3.0, epsilon = EPS);
    }

    #[test]
    fn test_accelerated_blas_gemm() {
        let a = mat2([1.0, 2.0, 3.0, 4.0]);
        let b = mat2([5.0, 6.0, 7.0, 8.0]);
        let mut c = Array::zeros(&[2, 2]);
        AcceleratedBlas::gemm(&a, &b, &mut c, 1.0, 0.0, false, false).expect("gemm should succeed");
        assert_matrix_close(&c, [[19.0, 22.0], [43.0, 50.0]]);
    }
}