use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, One};
use std::iter::Sum;
/// Scalar requirements shared by the matrix-square-root routines in this module:
/// a real floating-point type (`Float`) with compound assignment (`NumAssign`),
/// iterator summation (`Sum`), scalar–array arithmetic (`ScalarOperand`), that can be
/// sent and shared across threads (`Send + Sync`) and holds no borrowed data (`'static`).
pub trait SqrtFloat: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static {}

// Blanket impl: any type meeting the bounds is automatically a `SqrtFloat`.
impl<T: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static> SqrtFloat for T {}
/// Dense matrix product `a · b`.
///
/// Walks each row of `a` and accumulates `a[r, p] · b[p, ·]` into the output row in
/// increasing `p` order. Entries of `a` that compare equal to zero are skipped, so a
/// `0 · inf` or `0 · NaN` term coming from `b` is never accumulated.
///
/// # Errors
/// `ShapeError` when `a.ncols() != b.nrows()`.
fn matmul_nn<F: SqrtFloat>(a: &Array2<F>, b: &Array2<F>) -> LinalgResult<Array2<F>> {
    let (rows, inner) = a.dim();
    let (b_rows, cols) = b.dim();
    if inner != b_rows {
        return Err(LinalgError::ShapeError(format!(
            "sqrtm matmul: inner dims mismatch {} vs {}",
            inner, b_rows
        )));
    }
    let mut out = Array2::<F>::zeros((rows, cols));
    for r in 0..rows {
        for p in (0..inner).filter(|&p| a[[r, p]] != F::zero()) {
            let coeff = a[[r, p]];
            for c in 0..cols {
                out[[r, c]] += coeff * b[[p, c]];
            }
        }
    }
    Ok(out)
}
/// Frobenius norm `sqrt(Σ a_ij²)` of `a`.
///
/// Every entry is divided by the largest magnitude before squaring (the scaled
/// sum-of-squares idea behind LAPACK's `xNRM2`/`xLASSQ`). The result therefore does not
/// overflow to `inf` just because individual squares leave the representable range,
/// e.g. `f32` entries around `2e19`. Otherwise such an overflow would make the
/// Denman–Beavers relative-change test `inf / inf = NaN` and never converge.
///
/// When the largest magnitude is zero or non-finite, the plain sum of squares is used.
/// This keeps the `0`, `inf` and `NaN` results of the unscaled formula.
fn frobenius_norm<F: SqrtFloat>(a: &Array2<F>) -> F {
    // `Float::max` ignores NaN operands, so `scale` is the largest non-NaN magnitude.
    let scale = a.iter().fold(F::zero(), |m, &v| m.max(v.abs()));
    if scale == F::zero() || !scale.is_finite() {
        return a.iter().fold(F::zero(), |acc, &v| acc + v * v).sqrt();
    }
    let sum_sq = a.iter().fold(F::zero(), |acc, &v| {
        let r = v / scale;
        acc + r * r
    });
    scale * sum_sq.sqrt()
}
/// LU factorisation with partial (row) pivoting, computed on a copy of `a`.
///
/// Returns `(lu, perm)`:
/// * The strictly-lower part of `lu` holds the unit-lower factor `L`; its diagonal
///   ones are implicit.
/// * The diagonal and upper part of `lu` hold `U`.
/// * Row `i` of `L·U` equals row `perm[i]` of `a`.
///
/// Expects a square `a`; only `a.nrows()` is consulted.
///
/// # Errors
/// `SingularMatrixError` when a pivot column is numerically zero. That means its
/// largest candidate magnitude is exactly zero, or is below `1000·ε·min(1, max|a_ij|)`.
/// For matrices whose largest entry is at least 1 this is the fixed `1000·ε` cutoff.
/// For smaller matrices the cutoff scales with the matrix, so small-magnitude but
/// well-conditioned inputs (e.g. `1e-14·I` in `f64`, `1e-4·I` in `f32`) are not
/// rejected.
fn lu_factorize<F: SqrtFloat>(a: &Array2<F>) -> LinalgResult<(Array2<F>, Vec<usize>)> {
    let n = a.nrows();
    let mut lu = a.clone();
    let mut perm: Vec<usize> = (0..n).collect();

    let base_tol = F::epsilon() * F::from(1000.0).unwrap_or(F::one());
    // Largest non-NaN magnitude in the input (`Float::max` skips NaN).
    let a_max = a.iter().fold(F::zero(), |m, &v| m.max(v.abs()));
    // Relative cutoff, never stricter than the absolute `1000·ε`.
    let pivot_tol = base_tol * a_max.min(F::one());

    for k in 0..n {
        // Partial pivoting: choose the largest magnitude in column k at or below row k.
        let mut max_val = F::zero();
        let mut max_row = k;
        for i in k..n {
            let v = lu[[i, k]].abs();
            if v > max_val {
                max_val = v;
                max_row = i;
            }
        }
        // The explicit zero test matters when `pivot_tol` is itself zero (all-zero input).
        if max_val == F::zero() || max_val < pivot_tol {
            return Err(LinalgError::SingularMatrixError(
                "Matrix is (nearly) singular in LU factorization".into(),
            ));
        }
        if max_row != k {
            perm.swap(k, max_row);
            // Swap whole rows, including the multipliers already stored left of column k.
            for j in 0..n {
                lu.swap([k, j], [max_row, j]);
            }
        }
        // The pivot passed the check above, so it is safe to divide by it.
        let pivot = lu[[k, k]];
        for i in (k + 1)..n {
            lu[[i, k]] /= pivot;
            let l_ik = lu[[i, k]];
            for j in (k + 1)..n {
                let u_kj = lu[[k, j]];
                lu[[i, j]] -= l_ik * u_kj;
            }
        }
    }
    Ok((lu, perm))
}
/// Solves `A·X = B` for every column of `B` from the packed LU factors of `A`.
///
/// `lu` stores the unit-lower `L` strictly below the diagonal and `U` on and above it.
/// `perm[i]` names the row of `B` that lines up with row `i` of `L·U`. No pivot checks
/// are done here, so a zero on the diagonal of `U` produces `inf`/`NaN` entries.
fn lu_solve<F: SqrtFloat>(lu: &Array2<F>, perm: &[usize], b: &Array2<F>) -> Array2<F> {
    let dim = lu.nrows();
    let rhs_count = b.ncols();
    let mut solution = Array2::<F>::zeros((dim, rhs_count));
    for c in 0..rhs_count {
        // Bring this right-hand side into the pivoted row order.
        let mut work: Vec<F> = (0..dim).map(|r| b[[perm[r], c]]).collect();

        // Forward substitution with L (unit diagonal, so no division).
        for r in 1..dim {
            let reduced = (0..r).fold(work[r], |acc, k| acc - lu[[r, k]] * work[k]);
            work[r] = reduced;
        }

        // Back substitution with U, bottom-up. Entries below `r` already hold solved values.
        for r in (0..dim).rev() {
            let reduced = ((r + 1)..dim).fold(work[r], |acc, k| acc - lu[[r, k]] * work[k]);
            work[r] = reduced / lu[[r, r]];
        }

        for (r, &value) in work.iter().enumerate() {
            solution[[r, c]] = value;
        }
    }
    solution
}
/// Inverse of `a`, obtained by LU-factorising it and solving `A·X = I`.
///
/// # Errors
/// Propagates the `SingularMatrixError` raised by `lu_factorize`.
fn mat_inv<F: SqrtFloat>(a: &Array2<F>) -> LinalgResult<Array2<F>> {
    let (factors, pivots) = lu_factorize(a)?;
    let rhs = Array2::<F>::eye(a.nrows());
    Ok(lu_solve(&factors, &pivots, &rhs))
}
/// Matrix square root by the coupled Denman–Beavers iteration.
///
/// Starting from `X₀ = A` and `Y₀ = I`, each step computes
/// `X ← (X + Y⁻¹) / 2` and `Y ← (Y + X⁻¹) / 2`, both from the previous iterates.
/// `X` approximates `A^{1/2}`. Per Higham, the iteration converges to the principal
/// root when `A` has no eigenvalues on the closed negative real axis.
///
/// # Arguments
/// * `a`: square input matrix.
/// * `max_iter`: iteration cap, default `100`.
/// * `tol`: stop once `‖X_{k+1} − X_k‖_F / ‖X_{k+1}‖_F < tol`.
///   * The default is `1e-10`, raised to `100·ε` when that is larger.
///   * For `f32`, `1e-10` is below what the relative change can resolve, so it
///     would never be met.
///   * For `f64` the default stays exactly `1e-10`.
///   * An explicitly passed `tol` is used as given.
///
/// # Returns
/// The last iterate `X`. If `tol` is not reached within `max_iter` steps the last
/// iterate is still returned as `Ok` (no convergence error), so callers that need a
/// guarantee should check `X·X ≈ A`. An empty `0×0` input returns an empty matrix.
///
/// # Errors
/// * `ShapeError` if `a` is not square.
/// * `SingularMatrixError` if an iterate cannot be inverted.
pub fn sqrtm_denman_beavers<F: SqrtFloat>(
    a: &ArrayView2<F>,
    max_iter: Option<usize>,
    tol: Option<F>,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "sqrtm_denman_beavers: matrix must be square".into(),
        ));
    }
    if n == 0 {
        return Ok(Array2::<F>::zeros((0, 0)));
    }
    let max_iter = max_iter.unwrap_or(100);
    let tol = tol.unwrap_or_else(|| {
        let requested = F::from(1e-10).unwrap_or(F::epsilon());
        // The relative change stalls around machine precision; never ask for less.
        let floor = F::epsilon() * F::from(100.0).unwrap_or(F::one());
        requested.max(floor)
    });
    // Loop-invariant constants.
    let two = F::from(2.0).unwrap_or(F::one() + F::one());
    // Keeps the relative-change denominator away from zero.
    let norm_guard = F::from(1e-30).unwrap_or(F::epsilon());

    let mut x = a.to_owned();
    let mut y = Array2::<F>::eye(n);
    for _ in 0..max_iter {
        let x_inv = mat_inv(&x)?;
        let y_inv = mat_inv(&y)?;
        let x_next =
            Array2::<F>::from_shape_fn((n, n), |(i, j)| (x[[i, j]] + y_inv[[i, j]]) / two);
        let y_next =
            Array2::<F>::from_shape_fn((n, n), |(i, j)| (y[[i, j]] + x_inv[[i, j]]) / two);
        let step = Array2::<F>::from_shape_fn((n, n), |(i, j)| x_next[[i, j]] - x[[i, j]]);
        let rel_change = frobenius_norm(&step) / (frobenius_norm(&x_next) + norm_guard);
        x = x_next;
        y = y_next;
        if rel_change < tol {
            break;
        }
    }
    Ok(x)
}
/// Matrix square root by a determinant-scaled Denman–Beavers iteration.
///
/// Despite the name, this runs the same coupled `(X, Y)` iteration as
/// `sqrtm_denman_beavers`. The difference is the scaling:
/// * `A` is first multiplied by a positive scalar `c` chosen so that `|det(c·A)| = 1`.
/// * The final iterate is divided by `√c`, since `√(c·A) = √c·√A`.
///
/// `c` is derived from `compute_det_scale`, which yields Higham's first scaling factor
/// `μ₀ = |det A|^{-1/(2n)}`, or `1` when it declines to scale. A scaled DB step with
/// `μ₀` produces, up to the factor `1/μ₀`, the same iterate as a plain DB step on
/// `μ₀²·A`. So `c = μ₀²` is the equivalent pre-scaling.
///
/// # Arguments
/// * `a`: square input matrix.
/// * `max_iter`: iteration cap, default `50`.
/// * `tol`: stop once `‖X_{k+1} − X_k‖_F / ‖X_{k+1}‖_F < tol`.
///   * The default is `1e-12`, raised to `100·ε` when that is larger (e.g. for `f32`).
///   * For `f64` it stays exactly `1e-12`.
///   * An explicit `tol` is used as given.
///
/// # Returns
/// The un-scaled last iterate. If `tol` is not met within `max_iter` steps the last
/// iterate is still returned as `Ok`. An empty `0×0` input returns an empty matrix.
///
/// # Errors
/// * `ShapeError` if `a` is not square.
/// * `SingularMatrixError` if an iterate cannot be inverted.
pub fn sqrtm_product_db<F: SqrtFloat>(
    a: &ArrayView2<F>,
    max_iter: Option<usize>,
    tol: Option<F>,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "sqrtm_product_db: matrix must be square".into(),
        ));
    }
    if n == 0 {
        return Ok(Array2::<F>::zeros((0, 0)));
    }
    let max_iter = max_iter.unwrap_or(50);
    let tol = tol.unwrap_or_else(|| {
        let requested = F::from(1e-12).unwrap_or(F::epsilon());
        // The relative change stalls around machine precision; never ask for less.
        let floor = F::epsilon() * F::from(100.0).unwrap_or(F::one());
        requested.max(floor)
    });
    let two = F::from(2.0).unwrap_or(F::one() + F::one());
    // Keeps the relative-change denominator away from zero.
    let norm_guard = F::from(1e-30).unwrap_or(F::epsilon());

    let mu0 = compute_det_scale(&a.to_owned(), n);
    // The forward scaling and the final un-scaling both derive from this one value.
    // An unusable factor falls back to no scaling at all, never to a mismatched pair;
    // a mismatch would silently return √(c·A) instead of √A.
    let scale = mu0 * mu0;
    let scale = if scale.is_finite() && scale > F::zero() {
        scale
    } else {
        F::one()
    };
    // Finite for every positive `scale`, including subnormals.
    let unscale = F::one() / scale.sqrt();

    let mut x = a.mapv(|v| v * scale);
    let mut y = Array2::<F>::eye(n);
    for _ in 0..max_iter {
        let x_inv = mat_inv(&x)?;
        let y_inv = mat_inv(&y)?;
        let x_next =
            Array2::<F>::from_shape_fn((n, n), |(i, j)| (x[[i, j]] + y_inv[[i, j]]) / two);
        let y_next =
            Array2::<F>::from_shape_fn((n, n), |(i, j)| (y[[i, j]] + x_inv[[i, j]]) / two);
        let step = Array2::<F>::from_shape_fn((n, n), |(i, j)| x_next[[i, j]] - x[[i, j]]);
        let rel_change = frobenius_norm(&step) / (frobenius_norm(&x_next) + norm_guard);
        x = x_next;
        y = y_next;
        if rel_change < tol {
            break;
        }
    }
    Ok(x.mapv(|v| v * unscale))
}
/// Determinant-based scale factor `μ₀ = |det A|^{-1/(2n)}` for the Denman–Beavers
/// iteration.
///
/// `det A` comes from the LU factorisation as `sign(P) · Π u_ii`. Here `sign(P)` is
/// the parity of the row permutation applied by partial pivoting; it must be included,
/// or an SPD matrix whose pivoting swaps rows can look like it has a negative determinant.
///
/// Returns `1` ("do not scale") when:
/// * the LU factorisation fails,
/// * some pivot has magnitude below `1e-30`, or
/// * `det A` is negative.
///
/// `n` must equal the dimension of `a`.
fn compute_det_scale<F: SqrtFloat>(a: &Array2<F>, n: usize) -> F {
    let (lu, perm) = match lu_factorize(a) {
        Ok(factors) => factors,
        Err(_) => return F::one(),
    };
    let tiny = F::from(1e-30).unwrap_or(F::epsilon());
    let mut log_abs_det = F::zero();
    // Every transposition in `perm` and every negative pivot flips the sign of det A.
    let mut det_negative = permutation_is_odd(&perm);
    for i in 0..n {
        let d = lu[[i, i]];
        if d.abs() < tiny {
            return F::one();
        }
        if d < F::zero() {
            det_negative = !det_negative;
        }
        log_abs_det += d.abs().ln();
    }
    if det_negative {
        return F::one();
    }
    (-log_abs_det / F::from(2 * n).unwrap_or(F::one())).exp()
}

/// Returns `true` when `perm` (a permutation of `0..perm.len()`) is odd, i.e. it
/// factors into an odd number of transpositions. A cycle of length `L` contributes
/// `L - 1` transpositions.
fn permutation_is_odd(perm: &[usize]) -> bool {
    let mut visited = vec![false; perm.len()];
    let mut transpositions = 0usize;
    for start in 0..perm.len() {
        if visited[start] {
            continue;
        }
        let mut cycle_len = 0usize;
        let mut j = start;
        while !visited[j] {
            visited[j] = true;
            j = perm[j];
            cycle_len += 1;
        }
        transpositions += cycle_len - 1;
    }
    transpositions % 2 == 1
}
/// Square root of a (presumably symmetric positive-definite) matrix, delegated to
/// `crate::matrix_functions::fractional::spdmatrix_function` with a scalar
/// square-root map.
///
/// The scalar map clamps negative arguments to zero before taking the square root.
/// A `NaN` argument is not clamped and reaches `sqrt`, which returns `NaN`.
/// NOTE(review): the map's arguments are presumably the eigenvalues of `a`, and the
/// trailing `true` flag is forwarded unchanged. Confirm both against
/// `spdmatrix_function`.
///
/// # Errors
/// * `ShapeError` if `a` is not square.
/// * Any error returned by `spdmatrix_function`.
pub fn sqrtm_positive_definite<F>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>>
where
    F: Float + NumAssign + Sum + One + ScalarOperand + Send + Sync + 'static + std::fmt::Display,
{
    use crate::matrix_functions::fractional::spdmatrix_function;

    let dim = a.nrows();
    if dim != a.ncols() {
        return Err(LinalgError::ShapeError(
            "sqrtm_positive_definite: matrix must be square".into(),
        ));
    }
    if dim == 0 {
        return Ok(Array2::<F>::zeros((0, 0)));
    }
    // `x < 0` (not `x >= 0`) is tested so that NaN falls through to `sqrt`.
    let clamped_sqrt = |x: F| if x < F::zero() { F::zero() } else { x.sqrt() };
    spdmatrix_function(a, clamped_sqrt, true)
}
/// Matrix square-root entry point. It checks that `a` is square and then delegates to
/// `sqrtm_denman_beavers` with the same `max_iter` and `tol`.
///
/// # Errors
/// * `ShapeError` if `a` is not square.
/// * Any error from the Denman–Beavers routine.
pub fn sqrtm<F: SqrtFloat>(
    a: &ArrayView2<F>,
    max_iter: Option<usize>,
    tol: Option<F>,
) -> LinalgResult<Array2<F>> {
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError("sqrtm: matrix must be square".into()));
    }
    sqrtm_denman_beavers(a, max_iter, tol)
}
// Unit tests for the square-root routines. `assert_abs_diff_eq!` comes from the
// `approx` crate; test matrices are built with the `array!` macro from scirs2_core's
// ndarray re-export.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
// Denman–Beavers on diag(4, 9) with default settings: expect diag(2, 3) and
// (near-)zero off-diagonal entries.
#[test]
fn test_sqrtm_db_diagonal() {
let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
let s = sqrtm_denman_beavers(&a.view(), None, None).expect("sqrtm_db diagonal");
assert_abs_diff_eq!(s[[0, 0]], 2.0, epsilon = 1e-7);
assert_abs_diff_eq!(s[[1, 1]], 3.0, epsilon = 1e-7);
assert_abs_diff_eq!(s[[0, 1]], 0.0, epsilon = 1e-6);
assert_abs_diff_eq!(s[[1, 0]], 0.0, epsilon = 1e-6);
}
// The square root of the 3×3 identity must be the identity.
#[test]
fn test_sqrtm_db_identity() {
let a = Array2::<f64>::eye(3);
let s = sqrtm_denman_beavers(&a.view(), None, None).expect("sqrtm_db identity");
for i in 0..3 {
for j in 0..3 {
let expected = if i == j { 1.0 } else { 0.0 };
assert_abs_diff_eq!(s[[i, j]], expected, epsilon = 1e-7);
}
}
}
// Determinant-scaled variant on diag(9, 4): only the diagonal (3, 2) is checked.
#[test]
fn test_sqrtm_product_db_diagonal() {
let a = array![[9.0_f64, 0.0], [0.0, 4.0]];
let s = sqrtm_product_db(&a.view(), None, None).expect("sqrtm_product_db diagonal");
assert_abs_diff_eq!(s[[0, 0]], 3.0, epsilon = 1e-8);
assert_abs_diff_eq!(s[[1, 1]], 2.0, epsilon = 1e-8);
}
// Non-diagonal symmetric input: verify S·S reproduces A instead of comparing S to a
// closed form.
#[test]
fn test_sqrtm_verifies_s_squared() {
let a = array![[5.0_f64, 2.0], [2.0, 5.0]];
let s = sqrtm_denman_beavers(&a.view(), Some(200), Some(1e-12))
.expect("sqrtm_db square verify");
let s2 = matmul_nn(&s, &s).expect("s2 matmul");
for i in 0..2 {
for j in 0..2 {
assert_abs_diff_eq!(s2[[i, j]], a[[i, j]], epsilon = 1e-6);
}
}
}
// Positive-definite path (via spdmatrix_function) on diag(4, 9): diagonal only.
#[test]
fn test_sqrtm_pd_diagonal() {
let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
let s = sqrtm_positive_definite(&a.view()).expect("sqrtm_pd diagonal");
assert_abs_diff_eq!(s[[0, 0]], 2.0, epsilon = 1e-8);
assert_abs_diff_eq!(s[[1, 1]], 3.0, epsilon = 1e-8);
}
// 5×5 matrix 4·I through the public `sqrtm` entry point: every diagonal entry must be 2
// (off-diagonals are not checked).
#[test]
fn test_sqrtm_dispatch_large() {
let n = 5;
let mut a = Array2::<f64>::eye(n);
for i in 0..n {
a[[i, i]] = 4.0;
}
let s = sqrtm(&a.view(), None, None).expect("sqrtm dispatch large");
for i in 0..n {
assert_abs_diff_eq!(s[[i, i]], 2.0, epsilon = 1e-7);
}
}
// Non-symmetric (upper-triangular) input: verify S·S reproduces A.
#[test]
fn test_sqrtm_2x2_upper_triangular() {
let a = array![[4.0_f64, 2.0], [0.0, 9.0]];
let s =
sqrtm_denman_beavers(&a.view(), Some(200), Some(1e-12)).expect("sqrtm_db triangular");
let s2 = matmul_nn(&s, &s).expect("s2 matmul triangular");
for i in 0..2 {
for j in 0..2 {
assert_abs_diff_eq!(s2[[i, j]], a[[i, j]], epsilon = 1e-6);
}
}
}
}