use crate::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, Zero};
use scirs2_core::Complex;
use std::iter::Sum;
/// Selects which triangle (upper or lower) of a symmetric matrix a routine should read.
///
/// NOTE(review): the functions shown alongside this enum accept a `UPLO` but discard it
/// (`let _ = uplo;` / `_uplo`) and pass the whole matrix to the underlying `crate::` routine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UPLO {
/// Refer to the upper triangle.
Upper,
/// Refer to the lower triangle.
Lower,
}
/// Method-style access to common linear-algebra routines on `ndarray` arrays.
///
/// `A` is the element type and `S` the storage of the array the trait is implemented for;
/// `S` also fixes the storage type accepted for right-hand sides in `solve`/`solve_into`.
pub trait ArrayLinalgExt<A, S: scirs2_core::ndarray::RawData> {
/// Singular value decomposition, returned as `(U, singular values, Vt)`.
fn svd(&self, compute_uv: bool) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)>;
/// Eigenvalues and eigenvectors of a general square matrix (complex-valued).
#[allow(clippy::type_complexity)]
fn eig(
&self,
) -> LinalgResult<(
Array1<scirs2_core::Complex<A>>,
Array2<scirs2_core::Complex<A>>,
)>;
/// Eigenvalues and eigenvectors of a symmetric matrix.
/// NOTE(review): the 2-D impl in this file ignores `uplo`.
fn eigh(&self, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>;
/// Eigenvalues only of a symmetric matrix. NOTE(review): `uplo` is ignored by the 2-D impl.
fn eigvalsh(&self, uplo: UPLO) -> LinalgResult<Array1<A>>;
/// Matrix inverse.
fn inv(&self) -> LinalgResult<Array2<A>>;
/// Solve `self * x = b` for a single right-hand-side vector.
fn solve(&self, b: &ArrayBase<S, Ix1>) -> LinalgResult<Array1<A>>;
/// Solve `self * X = B` for a matrix of right-hand sides.
fn solve_into(&self, b: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>;
/// Entrywise Euclidean norm. The 2-D impl computes the same value as `norm_fro`.
fn norm_l2(&self) -> A;
/// Frobenius norm: square root of the sum of squared entries.
fn norm_fro(&self) -> A;
/// Determinant.
fn det(&self) -> LinalgResult<A>;
/// QR decomposition `(Q, R)`.
fn qr(&self) -> LinalgResult<(Array2<A>, Array2<A>)>;
/// LU decomposition as a triple of matrices (ordering as returned by `crate::lu`).
fn lu(&self) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)>;
/// Cholesky factor (convention as returned by `crate::cholesky`).
fn cholesky(&self) -> LinalgResult<Array2<A>>;
}
impl<A, S> ArrayLinalgExt<A, S> for ArrayBase<S, Ix2>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    // Each decomposition borrows `self` as a view and hands it to the matching `crate::`
    // routine, passing `None` for that routine's optional trailing argument.

    fn svd(&self, compute_uv: bool) -> LinalgResult<(Array2<A>, Array1<A>, Array2<A>)> {
        let view = self.view();
        crate::svd(&view, compute_uv, None)
    }

    fn eig(
        &self,
    ) -> LinalgResult<(
        Array1<scirs2_core::Complex<A>>,
        Array2<scirs2_core::Complex<A>>,
    )> {
        let view = self.view();
        crate::eig(&view, None)
    }

    /// `_uplo` is accepted for API compatibility but not consulted.
    fn eigh(&self, _uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)> {
        let view = self.view();
        crate::eigh(&view, None)
    }

    /// `_uplo` is accepted for API compatibility but not consulted.
    fn eigvalsh(&self, _uplo: UPLO) -> LinalgResult<Array1<A>> {
        let view = self.view();
        crate::eigvalsh(&view, None)
    }

    fn inv(&self) -> LinalgResult<Array2<A>> {
        let view = self.view();
        crate::inv(&view, None)
    }

    fn solve(&self, b: &ArrayBase<S, Ix1>) -> LinalgResult<Array1<A>> {
        let (lhs, rhs) = (self.view(), b.view());
        crate::solve(&lhs, &rhs, None)
    }

    fn solve_into(&self, b: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>> {
        let (lhs, rhs) = (self.view(), b.view());
        crate::solve_multiple(&lhs, &rhs, None)
    }

    /// Entrywise Euclidean norm; same value as `norm_fro`, not the spectral norm.
    fn norm_l2(&self) -> A {
        let sum_of_squares: A = self.iter().map(|&x| x * x).sum();
        sum_of_squares.sqrt()
    }

    fn norm_fro(&self) -> A {
        let sum_of_squares: A = self.iter().map(|&x| x * x).sum();
        sum_of_squares.sqrt()
    }

    fn det(&self) -> LinalgResult<A> {
        let view = self.view();
        crate::det(&view, None)
    }

    fn qr(&self) -> LinalgResult<(Array2<A>, Array2<A>)> {
        let view = self.view();
        crate::qr(&view, None)
    }

    fn lu(&self) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)> {
        let view = self.view();
        crate::lu(&view, None)
    }

    fn cholesky(&self) -> LinalgResult<Array2<A>> {
        let view = self.view();
        crate::cholesky(&view, None)
    }
}
/// Generic "solve a linear system" interface parameterised by the element type `A`.
///
/// NOTE(review): no implementation of this trait is visible in this source; what `self`
/// and `rhs` represent is defined by implementors — confirm before relying on it.
pub trait Solve<A> {
/// Type produced by a successful solve.
type Output;
/// Solve the system described by `self` against the right-hand side `rhs`.
fn solve(&self, rhs: &Self) -> LinalgResult<Self::Output>;
}
/// Singular value decomposition with implementor-chosen output container types.
pub trait SVD {
/// Container for the singular values.
type S;
/// Container for the left singular vectors.
type U;
/// Container for the transposed right singular vectors.
type Vt;
/// Compute the decomposition, returned as `(U, S, Vt)`.
fn svd(&self, compute_uv: bool) -> LinalgResult<(Self::U, Self::S, Self::Vt)>;
}
impl<A, S> SVD for ArrayBase<S, Ix2>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
type S = Array1<A>;
type U = Array2<A>;
type Vt = Array2<A>;
fn svd(&self, compute_uv: bool) -> LinalgResult<(Self::U, Self::S, Self::Vt)> {
ArrayLinalgExt::svd(self, compute_uv)
}
}
/// Eigen-decomposition with implementor-chosen eigenvalue/eigenvector containers.
///
/// The 2-D array implementation in this file produces complex-valued containers.
pub trait Eig {
/// Container for the eigenvalues.
type EigVal;
/// Container for the eigenvectors.
type EigVec;
/// Compute `(eigenvalues, eigenvectors)`.
fn eig(&self) -> LinalgResult<(Self::EigVal, Self::EigVec)>;
}
impl<A, S> Eig for ArrayBase<S, Ix2>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
type EigVal = Array1<scirs2_core::Complex<A>>;
type EigVec = Array2<scirs2_core::Complex<A>>;
fn eig(&self) -> LinalgResult<(Self::EigVal, Self::EigVec)> {
ArrayLinalgExt::eig(self)
}
}
/// Symmetric eigen-decomposition with implementor-chosen output containers.
pub trait Eigh {
/// Container for the eigenvalues.
type EigVal;
/// Container for the eigenvectors.
type EigVec;
/// Compute `(eigenvalues, eigenvectors)`; `uplo` names the triangle to read.
/// NOTE(review): the 2-D array impl in this file ignores `uplo`.
fn eigh(&self, uplo: UPLO) -> LinalgResult<(Self::EigVal, Self::EigVec)>;
}
impl<A, S> Eigh for ArrayBase<S, Ix2>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
type EigVal = Array1<A>;
type EigVec = Array2<A>;
fn eigh(&self, uplo: UPLO) -> LinalgResult<(Self::EigVal, Self::EigVec)> {
ArrayLinalgExt::eigh(self, uplo)
}
}
/// Matrix inversion with an implementor-chosen output type.
pub trait Inverse {
/// Type of the computed inverse.
type Output;
/// Compute the inverse of `self`.
fn inv(&self) -> LinalgResult<Self::Output>;
}
impl<A, S> Inverse for ArrayBase<S, Ix2>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
type Output = Array2<A>;
fn inv(&self) -> LinalgResult<Self::Output> {
ArrayLinalgExt::inv(self)
}
}
/// Norms returning a scalar of element type `A`.
pub trait Norm<A> {
/// Default norm. The impls in this file use Frobenius for 2-D arrays and Euclidean for 1-D.
fn norm(&self) -> A;
/// Entrywise Euclidean norm (square root of the sum of squares).
fn norm_l2(&self) -> A;
/// Frobenius norm; for 1-D arrays the impl in this file returns the Euclidean norm.
fn norm_fro(&self) -> A;
}
impl<A, S> Norm<A> for ArrayBase<S, Ix2>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
fn norm(&self) -> A {
ArrayLinalgExt::norm_fro(self)
}
fn norm_l2(&self) -> A {
ArrayLinalgExt::norm_l2(self)
}
fn norm_fro(&self) -> A {
ArrayLinalgExt::norm_fro(self)
}
}
impl<A, S> Norm<A> for ArrayBase<S, Ix1>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    /// For vectors the default norm is the Euclidean norm.
    fn norm(&self) -> A {
        <Self as Norm<A>>::norm_l2(self)
    }

    /// Euclidean norm: square root of the sum of squared entries.
    fn norm_l2(&self) -> A {
        let sum_of_squares: A = self.iter().map(|&x| x * x).sum();
        sum_of_squares.sqrt()
    }

    /// For a vector the Frobenius norm coincides with the Euclidean norm.
    fn norm_fro(&self) -> A {
        <Self as Norm<A>>::norm_l2(self)
    }
}
/// `(U, singular values, Vt)` triple, in the order returned by the `svd` wrappers in this file.
pub type SvdResult<A> = (Array2<A>, Array1<A>, Array2<A>);
/// Singular value decomposition of any 2-D array.
///
/// Wraps `crate::svd`, passing `None` for its optional final argument; `compute_uv` is
/// forwarded unchanged.
pub fn svd<A, S>(a: &ArrayBase<S, Ix2>, compute_uv: bool) -> LinalgResult<SvdResult<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::svd(&view, compute_uv, None)
}
/// Eigenvalues and eigenvectors of a general square matrix.
///
/// Wraps `crate::eig`, passing `None` for its optional final argument. Both results are
/// complex because a real non-symmetric matrix can have complex eigenpairs.
#[allow(clippy::type_complexity)]
pub fn eig<A, S>(
    a: &ArrayBase<S, Ix2>,
) -> LinalgResult<(
    Array1<scirs2_core::Complex<A>>,
    Array2<scirs2_core::Complex<A>>,
)>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::eig(&view, None)
}
/// Eigenvalues and eigenvectors of a symmetric matrix, via `crate::eigh`.
///
/// `uplo` exists for API compatibility only: it is not consulted, and the whole matrix is
/// passed to `crate::eigh`.
pub fn eigh<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    // Deliberately unused (see the doc comment).
    let _: UPLO = uplo;
    let view = a.view();
    crate::eigh(&view, None)
}
/// Eigenvalues of a symmetric matrix, via `crate::eigvalsh`.
///
/// `uplo` exists for API compatibility only and is not consulted.
pub fn eigvalsh<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array1<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    // Deliberately unused (see the doc comment).
    let _: UPLO = uplo;
    let view = a.view();
    crate::eigvalsh(&view, None)
}
/// Eigenvalues of a general square matrix.
///
/// Computes the full `crate::eig` decomposition and discards the eigenvectors.
pub fn eigvals<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array1<scirs2_core::Complex<A>>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    crate::eig(&a.view(), None).map(|(eigenvalues, _eigenvectors)| eigenvalues)
}
pub fn eigvals_banded<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array1<A>>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
eigvalsh(a, uplo)
}
pub fn eig_banded<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<(Array1<A>, Array2<A>)>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
eigh(a, uplo)
}
pub fn eigh_tridiagonal<A>(d: &Array1<A>, e: &Array1<A>) -> LinalgResult<(Array1<A>, Array2<A>)>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
let n = d.len();
let mut mat = Array2::zeros((n, n));
for i in 0..n {
mat[[i, i]] = d[i];
if i < n - 1 {
mat[[i, i + 1]] = e[i];
mat[[i + 1, i]] = e[i];
}
}
eigh(&mat, UPLO::Lower)
}
/// Eigenvalues of a real symmetric tridiagonal matrix.
///
/// Runs [`eigh_tridiagonal`] and discards the eigenvectors.
pub fn eigvalsh_tridiagonal<A>(d: &Array1<A>, e: &Array1<A>) -> LinalgResult<Array1<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
{
    eigh_tridiagonal(d, e).map(|(eigenvalues, _eigenvectors)| eigenvalues)
}
/// Matrix inverse; wraps `crate::inv`, passing `None` for its optional final argument.
pub fn inv<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::inv(&view, None)
}
/// Determinant; wraps `crate::det`, passing `None` for its optional final argument.
pub fn det<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<A>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::det(&view, None)
}
/// QR decomposition `(Q, R)`; wraps `crate::qr`, passing `None` for its optional argument.
pub fn qr<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::qr(&view, None)
}
pub fn rq<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
let t = a.t();
let (q, r) = crate::qr(&t.view(), None)?;
Ok((r.reversed_axes(), q.reversed_axes()))
}
/// LU decomposition as a matrix triple; wraps `crate::lu`, passing `None` for its
/// optional final argument.
pub fn lu<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>, Array2<A>)>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::lu(&view, None)
}
/// Cholesky factorisation via `crate::cholesky`.
///
/// `uplo` exists for API compatibility only: it is not consulted, so the returned factor
/// follows whatever convention `crate::cholesky` uses regardless of `uplo`.
pub fn cholesky<A, S>(a: &ArrayBase<S, Ix2>, uplo: UPLO) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    // Deliberately unused (see the doc comment).
    let _: UPLO = uplo;
    let view = a.view();
    crate::cholesky(&view, None)
}
/// Solve `a x = b` for a single right-hand-side vector via `crate::solve`.
///
/// `a` and `b` may use different storage types.
pub fn compat_solve<A, S1, S2>(
    a: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix1>,
) -> LinalgResult<Array1<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
{
    let (lhs, rhs) = (a.view(), b.view());
    crate::solve(&lhs, &rhs, None)
}
/// Solve a banded system via `crate::structured_solvers::solve_banded`.
///
/// `l_and_u` is the `(lower, upper)` bandwidth pair, unpacked into the two leading
/// arguments of the underlying solver. The storage layout expected for `ab` is defined by
/// that solver.
pub fn solve_banded<A, S1, S2>(
    l_and_u: (usize, usize),
    ab: &ArrayBase<S1, Ix2>,
    b: &ArrayBase<S2, Ix1>,
) -> LinalgResult<Array1<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
{
    let (lower_bw, upper_bw) = l_and_u;
    let (band_view, rhs_view) = (ab.view(), b.view());
    crate::structured_solvers::solve_banded(lower_bw, upper_bw, &band_view, &rhs_view)
}
pub fn solve_triangular<A, S1, S2>(
a: &ArrayBase<S1, Ix2>,
b: &ArrayBase<S2, Ix1>,
lower: bool,
) -> LinalgResult<Array1<A>>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S1: Data<Elem = A>,
S2: Data<Elem = A>,
{
let _ = (a, b, lower);
Err(LinalgError::ComputationError(
"solve_triangular not yet implemented".to_string(),
))
}
/// Least-squares solution of `a x ≈ b` via `crate::lstsq`.
///
/// Returns only the solution vector (the `x` field of the underlying result).
pub fn lstsq<A, S1, S2>(a: &ArrayBase<S1, Ix2>, b: &ArrayBase<S2, Ix1>) -> LinalgResult<Array1<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S1: Data<Elem = A>,
    S2: Data<Elem = A>,
{
    crate::lstsq(&a.view(), &b.view(), None).map(|solution| solution.x)
}
pub fn pinv<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
let (u, s, vt) = crate::svd(&a.view(), true, None)?;
let threshold = A::from(1e-15)
.ok_or_else(|| LinalgError::ComputationError("Failed to convert threshold".to_string()))?
* s[[0]];
let s_inv: Array1<A> = s.map(|&val| {
if val > threshold {
A::one() / val
} else {
A::zero()
}
});
Ok(vt.t().dot(&Array2::from_diag(&s_inv)).dot(&u.t()))
}
/// Numerical rank: the number of singular values strictly greater than a cutoff.
///
/// If `tol` is `None`, the cutoff is `s_max * max(nrows, ncols) * eps`. If the dimension
/// cannot be converted to `A`, a factor of one is used instead.
///
/// # Errors
/// Propagates any error from `crate::svd`.
pub fn matrix_rank<A, S>(a: &ArrayBase<S, Ix2>, tol: Option<A>) -> LinalgResult<usize>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let (_, singular_values, _) = crate::svd(&a.view(), false, None)?;
    let cutoff = match tol {
        Some(user_tol) => user_tol,
        None => {
            let largest = singular_values
                .iter()
                .fold(A::zero(), |best, &v| if v > best { v } else { best });
            let scale = A::from(a.nrows().max(a.ncols())).unwrap_or_else(A::one);
            largest * scale * A::epsilon()
        }
    };
    let rank = singular_values.iter().filter(|&&v| v > cutoff).count();
    Ok(rank)
}
/// 2-norm condition number `s_max / s_min` computed from the singular values.
///
/// Returns infinity when the smallest singular value is exactly zero (rank-deficient
/// input) and when there are no singular values at all.
///
/// # Errors
/// Propagates any error from `crate::svd`.
pub fn cond<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<A>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let (_, s, _) = crate::svd(&a.view(), false, None)?;
    let s_max = s.iter().fold(A::zero(), |acc, &v| if v > acc { v } else { acc });
    // Zero singular values must take part in the minimum: skipping them would report a
    // finite condition number for a singular matrix.
    let s_min = s.iter().fold(s_max, |acc, &v| if v < acc { v } else { acc });
    if s_min == A::zero() {
        return Ok(A::infinity());
    }
    Ok(s_max / s_min)
}
/// Matrix norm selected by `ord` (NumPy naming):
///
/// * `None` / `"fro"` – Frobenius norm.
/// * `"2"` – spectral norm (largest singular value).
/// * `"nuc"` – nuclear norm (sum of singular values).
/// * `"1"` – maximum absolute column sum.
/// * `"inf"` – maximum absolute row sum.
///
/// Empty inputs give zero for every order.
///
/// # Errors
/// Propagates `crate::svd` errors for `"2"` and `"nuc"`. Returns
/// `LinalgError::ComputationError` for any other `ord`.
pub fn norm<A, S>(a: &ArrayBase<S, Ix2>, ord: Option<&str>) -> LinalgResult<A>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    // Running maximum. Every quantity folded below is non-negative, so zero is a valid seed.
    let max_of = |acc: A, v: A| if v > acc { v } else { acc };
    match ord {
        None | Some("fro") => Ok(ArrayLinalgExt::norm_fro(a)),
        Some("2") => {
            let (_, s, _) = crate::svd(&a.view(), false, None)?;
            // Do not assume `s` is sorted (and do not index into a possibly empty array).
            Ok(s.iter().copied().fold(A::zero(), max_of))
        }
        Some("nuc") => {
            let (_, s, _) = crate::svd(&a.view(), false, None)?;
            Ok(s.iter().copied().sum::<A>())
        }
        Some("1") => Ok(a
            .columns()
            .into_iter()
            .map(|col| col.iter().map(|&x| x.abs()).sum::<A>())
            .fold(A::zero(), max_of)),
        Some("inf") => Ok(a
            .rows()
            .into_iter()
            .map(|row| row.iter().map(|&x| x.abs()).sum::<A>())
            .fold(A::zero(), max_of)),
        _ => Err(LinalgError::ComputationError(format!(
            "norm ord={:?} not implemented",
            ord
        ))),
    }
}
/// Vector norm of a 1-D array, selected by the integer order `ord`:
///
/// * `None` / `Some(2)` – Euclidean norm `sqrt(sum x_i^2)`.
/// * `Some(1)` – sum of absolute values.
/// * `Some(0)` – number of non-zero entries (NumPy convention).
/// * `Some(p)` for any other `p` – the general p-norm `(sum |x_i|^p)^(1/p)`.
pub fn vector_norm<A, S>(a: &ArrayBase<S, Ix1>, ord: Option<i32>) -> A
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    match ord {
        None | Some(2) => a.iter().map(|&x| x * x).sum::<A>().sqrt(),
        Some(1) => a.iter().map(|&x| x.abs()).sum::<A>(),
        Some(0) => a
            .iter()
            .filter(|&&x| x != A::zero())
            .fold(A::zero(), |count, _| count + A::one()),
        Some(p) => match A::from(p) {
            Some(p) => a
                .iter()
                .map(|&x| x.abs().powf(p))
                .sum::<A>()
                .powf(p.recip()),
            // Every i32 converts into f32/f64. Keep the Euclidean norm as a defensive
            // fallback for element types where the conversion fails.
            None => a.iter().map(|&x| x * x).sum::<A>().sqrt(),
        },
    }
}
/// Schur decomposition; wraps `crate::schur` for any 2-D array storage.
pub fn schur<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::schur(&view)
}
pub fn polar<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<(Array2<A>, Array2<A>)>
where
A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
S: Data<Elem = A>,
{
let (u, s, vt) = crate::svd(&a.view(), true, None)?;
let unitary = u.dot(&vt);
let hermitian = vt.t().dot(&Array2::from_diag(&s)).dot(&vt);
Ok((unitary, hermitian))
}
/// Matrix exponential; wraps `crate::expm`, passing `None` for its optional argument.
pub fn expm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::expm(&view, None)
}
/// Matrix logarithm; wraps `crate::logm` for any 2-D array storage.
pub fn logm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::logm(&view)
}
/// Matrix square root via `crate::sqrtm`, with an iteration cap of 100 and a tolerance
/// of `1e-8`.
///
/// # Errors
/// Returns `LinalgError::ComputationError` if `1e-8` cannot be represented in `A`, and
/// propagates any error from `crate::sqrtm`.
pub fn sqrtm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let tol = match A::from(1e-8) {
        Some(converted) => converted,
        None => {
            return Err(LinalgError::ComputationError(
                "Failed to convert tolerance".to_string(),
            ))
        }
    };
    let view = a.view();
    crate::sqrtm(&view, 100, tol)
}
/// Matrix sine; wraps `crate::sinm` for any 2-D array storage.
pub fn sinm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::sinm(&view)
}
/// Matrix cosine; wraps `crate::cosm` for any 2-D array storage.
pub fn cosm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::cosm(&view)
}
/// Matrix tangent; wraps `crate::tanm` for any 2-D array storage.
pub fn tanm<A, S>(a: &ArrayBase<S, Ix2>) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let view = a.view();
    crate::tanm(&view)
}
/// Apply a scalar function to a matrix through its symmetric eigen-decomposition.
///
/// Computes `V diag(func(lambda)) V^T` from `crate::eigh`. The result is only meaningful
/// when `a` is symmetric, because a symmetric decomposition is assumed.
///
/// # Errors
/// Propagates any error from `crate::eigh`.
pub fn funm<A, S, F>(a: &ArrayBase<S, Ix2>, func: F) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
    F: Fn(A) -> A,
{
    let (eigenvalues, eigenvectors) = crate::eigh(&a.view(), None)?;
    let mapped: Array1<A> = eigenvalues.mapv(&func);
    let scaled = eigenvectors.dot(&Array2::from_diag(&mapped));
    Ok(scaled.dot(&eigenvectors.t()))
}
/// Real matrix power `a^p`, computed by applying `x.powf(p)` to the eigenvalues via
/// [`funm`].
///
/// Inherits `funm`'s assumption that `a` is symmetric. Negative eigenvalues with a
/// non-integer `p` produce NaN entries, because `powf` does.
pub fn fractionalmatrix_power<A, S>(a: &ArrayBase<S, Ix2>, p: A) -> LinalgResult<Array2<A>>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static,
    S: Data<Elem = A>,
{
    let power = move |x: A| x.powf(p);
    funm(a, power)
}
/// Build a block-diagonal matrix with `blocks` placed along the diagonal in order.
///
/// The result has `sum(nrows)` rows and `sum(ncols)` columns; everything outside the blocks
/// is zero. An empty slice yields a `0 x 0` matrix.
pub fn block_diag<A>(blocks: &[Array2<A>]) -> Array2<A>
where
    A: Float + NumAssign + Sum + Send + Sync + ScalarOperand + 'static + Zero,
{
    let (total_rows, total_cols) = blocks
        .iter()
        .fold((0usize, 0usize), |(rows, cols), blk| {
            (rows + blk.nrows(), cols + blk.ncols())
        });
    let mut out = Array2::zeros((total_rows, total_cols));
    // Top-left corner of the next block to place.
    let (mut r0, mut c0) = (0usize, 0usize);
    for blk in blocks {
        let (h, w) = blk.dim();
        out.slice_mut(scirs2_core::ndarray::s![r0..r0 + h, c0..c0 + w])
            .assign(blk);
        r0 += h;
        c0 += w;
    }
    out
}