use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
/// Scalar bound shared by every routine in this module.
///
/// This is a pure alias trait: it declares no items of its own and only bundles the
/// bounds the algorithms below rely on — floating-point arithmetic (`Float`),
/// compound assignment (`NumAssign`), iterator summation (`Sum`), use as the scalar
/// operand in array arithmetic (`ScalarOperand`), plus `Send + Sync + 'static`.
pub trait RootsFloat: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static {}

/// Blanket implementation: every type that satisfies the bundled bounds is
/// automatically a [`RootsFloat`].
impl<T: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static> RootsFloat for T {}
/// Multiplies the leading `n x n` blocks of `a` and `b`, where `n = a.nrows()`.
///
/// Both operands are indexed up to `n` in each dimension, so the call panics if
/// either of them is smaller than `n x n`.
///
/// Entries of `a` that compare equal to zero are skipped entirely, which means a
/// zero in `a` never combines with a non-finite entry of `b`.
fn matmul_nn<F: RootsFloat>(a: &Array2<F>, b: &Array2<F>) -> Array2<F> {
    let dim = a.nrows();
    let mut product = Array2::<F>::zeros((dim, dim));
    for row in 0..dim {
        for inner in 0..dim {
            let lhs = a[[row, inner]];
            // A zero factor contributes nothing to this row of the product.
            if lhs != F::zero() {
                for col in 0..dim {
                    product[[row, col]] += lhs * b[[inner, col]];
                }
            }
        }
    }
    product
}
/// Returns the Frobenius norm of `m`: the square root of the sum of the squares of
/// all entries, accumulated in the array's iteration order.
fn frobenius_norm<F: RootsFloat>(m: &Array2<F>) -> F {
    m.iter()
        .fold(F::zero(), |sum_sq, &entry| sum_sq + entry * entry)
        .sqrt()
}
/// Validates that `a` is square and returns its dimension.
///
/// `name` is the calling routine's name; it is only used as the prefix of the error
/// message.
///
/// # Errors
///
/// Returns [`LinalgError::ShapeError`] when the row and column counts differ.
fn check_square<F: RootsFloat>(a: &ArrayView2<F>, name: &str) -> LinalgResult<usize> {
    let (rows, cols) = a.dim();
    if rows == cols {
        Ok(rows)
    } else {
        Err(LinalgError::ShapeError(format!(
            "{name}: matrix must be square, got {}x{}",
            rows, cols
        )))
    }
}
/// Computes a real square root of the square matrix `a`.
///
/// Dispatch order:
/// 1. `0 x 0` input returns an empty matrix.
/// 2. `1 x 1` input returns the scalar square root.
/// 3. Input that [`is_diagonal`] accepts gets an entry-wise square root of its
///    diagonal (off-diagonal entries of the result are zero).
/// 4. Anything else goes through `crate::decomposition::schur`, takes the root of
///    the triangular factor with [`triangular_sqrt`], and is transformed back as
///    `Q * sqrt(T) * Q^T`.
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `a` is not square.
/// * [`LinalgError::DomainError`] if a negative scalar / diagonal entry is found.
/// * Any error reported by the Schur decomposition.
pub fn sqrtm<F: RootsFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    let n = check_square(a, "sqrtm")?;
    match n {
        0 => return Ok(Array2::<F>::zeros((0, 0))),
        1 => {
            let scalar = a[[0, 0]];
            if scalar < F::zero() {
                return Err(LinalgError::DomainError(
                    "sqrtm: negative scalar has no real principal square root".into(),
                ));
            }
            return Ok(Array2::from_elem((1, 1), scalar.sqrt()));
        }
        _ => {}
    }

    if !is_diagonal(a, n) {
        // General case: root of the triangular Schur factor, rotated back.
        let (q, t) = crate::decomposition::schur(a)?;
        let root_t = triangular_sqrt(&t, n)?;
        return Ok(q.dot(&root_t).dot(&q.t()));
    }

    // Diagonal fast path: the root is the entry-wise root of the diagonal.
    let mut root = Array2::<F>::zeros((n, n));
    for (i, &entry) in a.diag().iter().enumerate() {
        if entry < F::zero() {
            return Err(LinalgError::DomainError(
                "sqrtm: matrix has negative eigenvalues; principal square root does not exist"
                    .into(),
            ));
        }
        root[[i, i]] = entry.sqrt();
    }
    Ok(root)
}
/// Square root of an upper-triangular matrix by column-wise substitution.
///
/// Only the diagonal and the strict upper triangle of `t` are read. The diagonal of
/// the result is the entry-wise square root; each off-diagonal entry follows from
/// equating entry `(i, j)` of `S * S` with `t[i, j]`:
///
/// ```text
/// s_ij = (t_ij - sum_{i<k<j} s_ik * s_kj) / (s_ii + s_jj)
/// ```
///
/// When `|s_ii + s_jj|` is below `10 * epsilon` the entry is set to zero instead of
/// dividing.
///
/// # Errors
///
/// Returns [`LinalgError::DomainError`] if a diagonal entry of `t` is negative.
fn triangular_sqrt<F: RootsFloat>(t: &Array2<F>, n: usize) -> LinalgResult<Array2<F>> {
    let mut root = Array2::<F>::zeros((n, n));
    for d in 0..n {
        let pivot = t[[d, d]];
        if pivot < F::zero() {
            return Err(LinalgError::DomainError(
                "triangular_sqrt: negative diagonal entry (negative eigenvalue)".into(),
            ));
        }
        root[[d, d]] = pivot.sqrt();
    }

    let tiny = F::epsilon() * F::from(10.0).unwrap_or(F::one());
    for col in 1..n {
        // Walk each column upwards so every s_kj with k > row is already known.
        for row in (0..col).rev() {
            let coupling = ((row + 1)..col)
                .fold(F::zero(), |acc, k| acc + root[[row, k]] * root[[k, col]]);
            let pivot_sum = root[[row, row]] + root[[col, col]];
            root[[row, col]] = if pivot_sum.abs() < tiny {
                F::zero()
            } else {
                (t[[row, col]] - coupling) / pivot_sum
            };
        }
    }
    Ok(root)
}
/// Computes a matrix square root with the coupled Denman–Beavers iteration
///
/// ```text
/// Y_0 = A,  Z_0 = I
/// Y_{k+1} = (Y_k + Z_k^{-1}) / 2
/// Z_{k+1} = (Z_k + Y_k^{-1}) / 2
/// ```
///
/// and returns the final `Y` iterate.
///
/// # Arguments
///
/// * `a` - square input matrix.
/// * `max_iter` - iteration cap; defaults to `100`.
/// * `tol` - absolute tolerance on the Frobenius norm of `Y_{k+1} - Y_k`; defaults
///   to `1e-12` (or machine epsilon if that constant cannot be represented in `F`).
///
/// If the tolerance is not reached within `max_iter` iterations the last iterate is
/// still returned as `Ok`; callers that need a guarantee must verify the result.
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `a` is not square.
/// * Any error reported by `crate::inv` while inverting an iterate.
pub fn sqrtm_denman_beavers<F: RootsFloat>(
    a: &ArrayView2<F>,
    max_iter: Option<usize>,
    tol: Option<F>,
) -> LinalgResult<Array2<F>> {
    let n = check_square(a, "sqrtm_denman_beavers")?;
    // Same empty-matrix short-circuit as `sqrtm`, `pth_root` and `logm`; without it
    // the empty matrix would be handed to `crate::inv`.
    if n == 0 {
        return Ok(Array2::<F>::zeros((0, 0)));
    }
    let max_it = max_iter.unwrap_or(100);
    let eps = tol.unwrap_or_else(|| F::from(1e-12).unwrap_or(F::epsilon()));
    let half = F::from(0.5)
        .ok_or_else(|| LinalgError::ComputationError("Cannot convert 0.5".into()))?;

    let mut y = a.to_owned();
    let mut z = Array2::<F>::eye(n);
    for _ in 0..max_it {
        // Both inverses must come from the *current* iterates before either is updated.
        let z_inv = crate::inv(&z.view(), None)?;
        let y_inv = crate::inv(&y.view(), None)?;
        let y_new = (&y + &z_inv) * half;
        let z_new = (&z + &y_inv) * half;
        let diff = frobenius_norm(&(&y_new - &y));
        y = y_new;
        z = z_new;
        if diff < eps {
            return Ok(y);
        }
    }
    // Best effort: iteration cap reached without meeting the tolerance.
    Ok(y)
}
/// Computes a real `p`-th root of the square matrix `a`.
///
/// Dispatch order (after the shape check):
/// 1. `p == 0` is rejected, `p == 1` returns a copy of `a`.
/// 2. `0 x 0` input returns an empty matrix.
/// 3. Input that [`is_diagonal`] accepts gets `entry^(1/p)` on its diagonal.
/// 4. Anything else goes through `crate::decomposition::schur`, takes the root of
///    the triangular factor with [`triangular_pth_root`], and is transformed back
///    as `Q * root(T) * Q^T`.
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `a` is not square.
/// * [`LinalgError::ValueError`] if `p == 0`.
/// * [`LinalgError::DomainError`] if a negative diagonal entry is found.
/// * [`LinalgError::ComputationError`] if `p` cannot be converted to `F`.
/// * Any error reported by the Schur decomposition.
pub fn pth_root<F: RootsFloat>(a: &ArrayView2<F>, p: u32) -> LinalgResult<Array2<F>> {
    let n = check_square(a, "pth_root")?;
    match p {
        0 => {
            return Err(LinalgError::ValueError(
                "pth_root: p must be >= 1".into(),
            ))
        }
        1 => return Ok(a.to_owned()),
        _ => {}
    }
    if n == 0 {
        return Ok(Array2::<F>::zeros((0, 0)));
    }

    if !is_diagonal(a, n) {
        // General case: root of the triangular Schur factor, rotated back.
        let (q, t) = crate::decomposition::schur(a)?;
        let root_t = triangular_pth_root(&t, n, p)?;
        return Ok(q.dot(&root_t).dot(&q.t()));
    }

    // Diagonal fast path: raise every diagonal entry to the power 1/p.
    let exponent = F::one()
        / F::from(p).ok_or_else(|| LinalgError::ComputationError("Cannot convert p".into()))?;
    let mut root = Array2::<F>::zeros((n, n));
    for (i, &entry) in a.diag().iter().enumerate() {
        if entry < F::zero() {
            return Err(LinalgError::DomainError(
                "pth_root: matrix has negative eigenvalues".into(),
            ));
        }
        root[[i, i]] = entry.powf(exponent);
    }
    Ok(root)
}
/// Computes the upper-triangular `p`-th root `R` of an upper-triangular matrix `t`.
///
/// Only the diagonal and the strict upper triangle of `t` are read.
///
/// The diagonal is `r_ii = t_ii^(1/p)`. Every off-diagonal entry is obtained by
/// matching entry `(i, j)` of `R^p` with `t_ij`. Because `R` is upper triangular,
/// `(R^m)_ij` is affine in the single unknown `r_ij`:
///
/// ```text
/// (R^m)_ij = a_m * r_ij + b_m,        a_1 = 1,  b_1 = 0
/// a_(m+1)  = r_ii * a_m + r_jj^m
/// b_(m+1)  = r_ii * b_m + sum_{i<k<j} r_ik * (R^m)_kj
/// ```
///
/// so `r_ij = (t_ij - b_p) / a_p`, with `a_p = sum_k r_ii^(p-1-k) * r_jj^k`.
/// The cross terms `b_m` involve the intermediate powers `R^2 .. R^(p-1)`; using
/// only `sum r_ik * r_kj` is correct for `p == 2` but not for larger `p`, which is
/// why those powers are tracked explicitly here.
///
/// When `|a_p|` is below `10 * epsilon` the entry is set to zero instead of dividing.
///
/// Cost: `O(p * n^3)` time and `max(p - 1, 1)` stored `n x n` matrices.
///
/// # Errors
///
/// * [`LinalgError::ValueError`] if `p == 0`.
/// * [`LinalgError::ComputationError`] if `p` cannot be converted to `F`.
/// * [`LinalgError::DomainError`] if a diagonal entry of `t` is negative.
fn triangular_pth_root<F: RootsFloat>(
    t: &Array2<F>,
    n: usize,
    p: u32,
) -> LinalgResult<Array2<F>> {
    if p == 0 {
        return Err(LinalgError::ValueError(
            "triangular_pth_root: p must be >= 1".into(),
        ));
    }
    let p_inv = F::one()
        / F::from(p).ok_or_else(|| LinalgError::ComputationError("Cannot convert p".into()))?;
    let degree = p as usize;

    // powers[m] holds R^(m + 1); powers[0] is the root itself. R^p is never stored
    // because it is `t` by construction.
    let levels = (degree - 1).max(1);
    let mut powers: Vec<Array2<F>> = (0..levels)
        .map(|_| Array2::<F>::zeros((n, n)))
        .collect();

    // Diagonal: r_ii = t_ii^(1/p), and (R^m)_ii = r_ii^m.
    for i in 0..n {
        let val = t[[i, i]];
        if val < F::zero() {
            return Err(LinalgError::DomainError(
                "triangular_pth_root: negative eigenvalue encountered".into(),
            ));
        }
        let root = val.powf(p_inv);
        let mut running = root;
        for power in powers.iter_mut() {
            power[[i, i]] = running;
            running = running * root;
        }
    }

    let tiny = F::epsilon() * F::from(10.0).unwrap_or(F::one());
    // coeff[m] / offset[m] are a_(m+1) / b_(m+1) for the entry currently being solved.
    let mut coeff = vec![F::zero(); levels];
    let mut offset = vec![F::zero(); levels];

    for j in 1..n {
        // Walk each column upwards so every (R^m)_kj with k > i is already known.
        for i in (0..j).rev() {
            let rii = powers[0][[i, i]];
            let mut a = F::one();
            let mut b = F::zero();
            coeff[0] = a;
            offset[0] = b;
            for m in 1..degree {
                let prev = &powers[m - 1]; // R^m
                let mut cross = F::zero();
                for k in (i + 1)..j {
                    cross = cross + powers[0][[i, k]] * prev[[k, j]];
                }
                b = rii * b + cross;
                a = rii * a + prev[[j, j]];
                if m < levels {
                    coeff[m] = a;
                    offset[m] = b;
                }
            }
            // Here a == a_p and b == b_p.
            let r_ij = if a.abs() < tiny {
                F::zero()
            } else {
                (t[[i, j]] - b) / a
            };
            // Fill in entry (i, j) of every stored power for later columns/rows.
            for ((power, &a_m), &b_m) in powers.iter_mut().zip(coeff.iter()).zip(offset.iter()) {
                power[[i, j]] = a_m * r_ij + b_m;
            }
        }
    }
    Ok(powers.swap_remove(0))
}
fn gauss_legendre_nodes(order: usize) -> (Vec<f64>, Vec<f64>) {
match order {
1 => (vec![0.5], vec![1.0]),
2 => (
vec![0.2113248654051871, 0.7886751345948129],
vec![0.5, 0.5],
),
4 => (
vec![
0.06943184420297371,
0.33000947820757187,
0.6699905217924281,
0.9305681557970263,
],
vec![
0.17392742256872685,
0.32607257743127315,
0.32607257743127315,
0.17392742256872685,
],
),
8 => (
vec![
0.019855071751231884,
0.10166676129318664,
0.2372337950418355,
0.4082826787521751,
0.5917173212478249,
0.7627662049581645,
0.8983332387068134,
0.9801449282487682,
],
vec![
0.050614268145188,
0.11119051722492964,
0.15685332293894369,
0.18134189168918087,
0.18134189168918087,
0.15685332293894369,
0.11119051722492964,
0.050614268145188,
],
),
16 => (
vec![
0.005299532504175031,
0.027233228312309445,
0.06504581385637368,
0.11588846949991124,
0.17830370308927756,
0.24990745750488012,
0.3268553544165069,
0.4090169573769576,
0.5909830426230424,
0.6731446455834931,
0.7500925424951199,
0.8216962969107224,
0.8841115305000888,
0.9349541861436263,
0.9727667716876906,
0.9947004674958249,
],
vec![
0.013576229705877,
0.031126761969324,
0.047579255841244,
0.062314485627767,
0.074797994408289,
0.084578259697501,
0.091704110370050,
0.095879026375961,
0.095879026375961,
0.091704110370050,
0.084578259697501,
0.074797994408289,
0.062314485627767,
0.047579255841244,
0.031126761969324,
0.013576229705877,
],
),
_ => {
let h = 1.0 / (order as f64);
let nodes = (0..order).map(|i| (i as f64 + 0.5) * h).collect();
let weights = vec![h; order];
(nodes, weights)
}
}
}
/// Evaluates the quadrature sum `sum_k w_k * X * (I + t_k * X)^{-1}` with the
/// `(t_k, w_k)` rule supplied by [`gauss_legendre_nodes`] for the given `order`.
///
/// This is a numerical approximation of `integral_0^1 X (I + t X)^{-1} dt`, i.e. of
/// `log(I + X)`. Despite the function's name, the implementation is a quadrature
/// rather than an explicit rational (Padé) formula.
///
/// A node or weight that cannot be converted to `F` is replaced by zero.
///
/// # Errors
///
/// Propagates any error from `crate::inv` when `I + t_k * X` cannot be inverted.
fn logm_pade_approx<F: RootsFloat>(x: &Array2<F>, order: usize) -> LinalgResult<Array2<F>> {
    let n = x.nrows();
    let (nodes, weights) = gauss_legendre_nodes(order);
    let mut log_sum = Array2::<F>::zeros((n, n));

    for (&node, &weight) in nodes.iter().zip(weights.iter()) {
        let t_k = F::from(node).unwrap_or(F::zero());
        let w_k = F::from(weight).unwrap_or(F::zero());

        // shifted = I + t_k * X
        let shifted = Array2::from_shape_fn((n, n), |(i, j)| {
            let delta = if i == j { F::one() } else { F::zero() };
            delta + t_k * x[[i, j]]
        });
        let resolvent = crate::inv(&shifted.view(), None)?;
        let term = matmul_nn(x, &resolvent);

        log_sum.zip_mut_with(&term, |acc, &value| *acc += w_k * value);
    }
    Ok(log_sum)
}
/// Computes a real logarithm of the square matrix `a`.
///
/// Dispatch order:
/// 1. `0 x 0` input returns an empty matrix.
/// 2. `1 x 1` input returns the scalar natural logarithm.
/// 3. Input that [`is_diagonal`] accepts gets the entry-wise logarithm of its
///    diagonal (off-diagonal entries of the result are zero).
/// 4. Anything else goes through `crate::decomposition::schur`, takes the logarithm
///    of the triangular factor with [`triangular_logm`], and is transformed back as
///    `Q * log(T) * Q^T`.
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `a` is not square.
/// * [`LinalgError::DomainError`] if a scalar / diagonal entry is zero or negative.
/// * Any error reported by the Schur decomposition or by the triangular logarithm.
pub fn logm<F: RootsFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    let n = check_square(a, "logm")?;
    match n {
        0 => return Ok(Array2::<F>::zeros((0, 0))),
        1 => {
            let scalar = a[[0, 0]];
            if scalar <= F::zero() {
                return Err(LinalgError::DomainError(
                    "logm: cannot compute real logarithm of non-positive scalar".into(),
                ));
            }
            return Ok(Array2::from_elem((1, 1), scalar.ln()));
        }
        _ => {}
    }

    if !is_diagonal(a, n) {
        // General case: logarithm of the triangular Schur factor, rotated back.
        let (q, t) = crate::decomposition::schur(a)?;
        let log_t = triangular_logm(&t, n)?;
        return Ok(q.dot(&log_t).dot(&q.t()));
    }

    // Diagonal fast path: logarithm of each diagonal entry.
    let mut log_diag = Array2::<F>::zeros((n, n));
    for (i, &entry) in a.diag().iter().enumerate() {
        if entry <= F::zero() {
            return Err(LinalgError::DomainError(
                "logm: matrix has non-positive eigenvalue on diagonal".into(),
            ));
        }
        log_diag[[i, i]] = entry.ln();
    }
    Ok(log_diag)
}
/// Logarithm of an upper-triangular matrix via repeated square roots.
///
/// The matrix is replaced by its triangular square root until its Frobenius
/// distance from the identity drops below `0.5` (at most 50 times). The logarithm
/// of the reduced matrix is then approximated by [`logm_pade_approx`] with a
/// 16-point rule and multiplied by `2^s`, where `s` is the number of square roots
/// taken.
///
/// If the distance is still above the threshold after 50 square roots the routine
/// proceeds with the matrix it has.
///
/// # Errors
///
/// * [`LinalgError::DomainError`] if a diagonal entry of `t` is zero or negative.
/// * Any error from [`triangular_sqrt`] or [`logm_pade_approx`].
fn triangular_logm<F: RootsFloat>(t: &Array2<F>, n: usize) -> LinalgResult<Array2<F>> {
    const MAX_SCALINGS: usize = 50;

    if (0..n).any(|d| t[[d, d]] <= F::zero()) {
        return Err(LinalgError::DomainError(
            "triangular_logm: non-positive diagonal entry (non-positive eigenvalue)".into(),
        ));
    }
    if n == 1 {
        return Ok(Array2::from_elem((1, 1), t[[0, 0]].ln()));
    }

    let identity = Array2::<F>::eye(n);
    let threshold = F::from(0.5).unwrap_or(F::one());
    let mut reduced = t.to_owned();
    let mut sqrt_count = 0usize;
    while sqrt_count < MAX_SCALINGS {
        let distance = frobenius_norm(&(&reduced - &identity));
        if distance < threshold {
            break;
        }
        reduced = triangular_sqrt(&reduced, n)?;
        sqrt_count += 1;
    }

    // log(T) = 2^s * log(T^(1 / 2^s)); the quadrature works on T^(1 / 2^s) - I.
    let log_reduced = logm_pade_approx(&(&reduced - &identity), 16)?;
    let scale = (F::one() + F::one()).powi(sqrt_count as i32);
    Ok(log_reduced.mapv(|value| value * scale))
}
/// Reports whether the leading `n x n` block of `a` is diagonal, treating any
/// off-diagonal entry whose absolute value does not exceed machine epsilon as zero.
///
/// The tolerance is absolute (not scaled by the size of the matrix entries).
fn is_diagonal<F: RootsFloat>(a: &ArrayView2<F>, n: usize) -> bool {
    let cutoff = F::epsilon();
    !(0..n).any(|row| (0..n).any(|col| row != col && a[[row, col]].abs() > cutoff))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Asserts that every entry of `actual` lies within `tol` of the entry of
    /// `expected` at the same position.
    fn assert_entries_close(actual: &Array2<f64>, expected: &Array2<f64>, tol: f64) {
        for ((i, j), &want) in expected.indexed_iter() {
            assert_abs_diff_eq!(actual[[i, j]], want, epsilon = tol);
        }
    }

    // ---- sqrtm -------------------------------------------------------------

    /// Diagonal input takes the entry-wise fast path.
    #[test]
    fn test_sqrtm_diagonal() {
        let input = array![[4.0_f64, 0.0], [0.0, 9.0]];
        let root = sqrtm(&input.view()).expect("sqrtm failed");
        assert_abs_diff_eq!(root[[0, 0]], 2.0, epsilon = 1e-8);
        assert_abs_diff_eq!(root[[1, 1]], 3.0, epsilon = 1e-8);
        assert!(root[[0, 1]].abs() < 1e-10);
        assert!(root[[1, 0]].abs() < 1e-10);
    }

    /// The square root of the identity is the identity.
    #[test]
    fn test_sqrtm_identity() {
        let identity = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let root = sqrtm(&identity.view()).expect("sqrtm failed");
        assert_entries_close(&root, &identity, 1e-10);
    }

    /// Non-diagonal symmetric positive-definite input: squaring the result must
    /// reproduce the input.
    #[test]
    fn test_sqrtm_general_spd() {
        let input = array![[5.0_f64, 2.0], [2.0, 5.0]];
        let root = sqrtm(&input.view()).expect("sqrtm failed");
        let squared = matmul_nn(&root, &root);
        assert_entries_close(&squared, &input, 1e-8);
    }

    /// A negative diagonal entry must be rejected.
    #[test]
    fn test_sqrtm_negative_eigenvalue_fails() {
        let input = array![[-4.0_f64, 0.0], [0.0, 9.0]];
        assert!(sqrtm(&input.view()).is_err());
    }

    // ---- sqrtm_denman_beavers ----------------------------------------------

    /// Denman–Beavers with default settings on a diagonal matrix.
    #[test]
    fn test_sqrtm_db_diagonal() {
        let input = array![[4.0_f64, 0.0], [0.0, 9.0]];
        let root = sqrtm_denman_beavers(&input.view(), None, None).expect("sqrtm_db failed");
        assert_abs_diff_eq!(root[[0, 0]], 2.0, epsilon = 1e-8);
        assert_abs_diff_eq!(root[[1, 1]], 3.0, epsilon = 1e-8);
    }

    /// Denman–Beavers on a non-diagonal matrix: squaring must reproduce the input.
    #[test]
    fn test_sqrtm_db_general() {
        let input = array![[2.0_f64, 1.0], [1.0, 3.0]];
        let root = sqrtm_denman_beavers(&input.view(), None, None).expect("sqrtm_db failed");
        let squared = matmul_nn(&root, &root);
        assert_entries_close(&squared, &input, 1e-8);
    }

    // ---- pth_root ----------------------------------------------------------

    /// `p = 2` on a diagonal matrix.
    #[test]
    fn test_pth_root_square_root() {
        let input = array![[4.0_f64, 0.0], [0.0, 16.0]];
        let root = pth_root(&input.view(), 2).expect("pth_root failed");
        assert_abs_diff_eq!(root[[0, 0]], 2.0, epsilon = 1e-8);
        assert_abs_diff_eq!(root[[1, 1]], 4.0, epsilon = 1e-8);
    }

    /// `p = 3` on a diagonal matrix.
    #[test]
    fn test_pth_root_cube_root() {
        let input = array![[8.0_f64, 0.0], [0.0, 27.0]];
        let root = pth_root(&input.view(), 3).expect("pth_root cube failed");
        assert_abs_diff_eq!(root[[0, 0]], 2.0, epsilon = 1e-8);
        assert_abs_diff_eq!(root[[1, 1]], 3.0, epsilon = 1e-8);
    }

    /// Any root of the identity is the identity.
    #[test]
    fn test_pth_root_identity() {
        let identity = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let root = pth_root(&identity.view(), 5).expect("pth_root identity failed");
        assert_entries_close(&root, &identity, 1e-8);
    }

    /// `p = 1` returns the input unchanged, even when it is not diagonal.
    #[test]
    fn test_pth_root_p_eq_1() {
        let input = array![[3.0_f64, 1.0], [0.0, 2.0]];
        let root = pth_root(&input.view(), 1).expect("pth_root p=1 failed");
        assert_entries_close(&root, &input, 1e-12);
    }

    /// `p = 4`; note the input here is diagonal, so this exercises the fast path.
    #[test]
    fn test_pth_root_general_spd() {
        let input = array![[16.0_f64, 0.0], [0.0, 81.0]];
        let root = pth_root(&input.view(), 4).expect("pth_root 4th failed");
        assert_abs_diff_eq!(root[[0, 0]], 2.0, epsilon = 1e-8);
        assert_abs_diff_eq!(root[[1, 1]], 3.0, epsilon = 1e-8);
    }

    // ---- logm --------------------------------------------------------------

    /// The logarithm of the identity is the zero matrix.
    #[test]
    fn test_logm_identity() {
        let identity = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let log = logm(&identity.view()).expect("logm identity failed");
        assert_entries_close(&log, &Array2::<f64>::zeros((2, 2)), 1e-8);
    }

    /// `log(diag(e, e^2)) = diag(1, 2)`.
    #[test]
    fn test_logm_diagonal() {
        let e = std::f64::consts::E;
        let input = array![[e, 0.0], [0.0, e * e]];
        let log = logm(&input.view()).expect("logm diagonal failed");
        assert_abs_diff_eq!(log[[0, 0]], 1.0, epsilon = 1e-4);
        assert_abs_diff_eq!(log[[1, 1]], 2.0, epsilon = 1e-4);
        assert!(log[[0, 1]].abs() < 1e-8);
        assert!(log[[1, 0]].abs() < 1e-8);
    }

    /// `logm(expm(A))` must give back `A`.
    #[test]
    fn test_logm_inverse_of_expm() {
        let input = array![[0.5_f64, 0.2], [0.1, 0.3]];
        let exponential =
            crate::matrix_functions::pade::pade_expm(&input.view()).expect("expm failed");
        let round_trip = logm(&exponential.view()).expect("logm failed");
        assert_entries_close(&round_trip, &input, 1e-6);
    }

    /// `expm(logm(A))` must give back a symmetric positive-definite `A`.
    #[test]
    fn test_logm_spd_matrix() {
        let input = array![[5.0_f64, 2.0], [2.0, 5.0]];
        let log = logm(&input.view()).expect("logm SPD failed");
        let round_trip =
            crate::matrix_functions::pade::pade_expm(&log.view()).expect("expm failed");
        assert_entries_close(&round_trip, &input, 1e-6);
    }

    /// A negative diagonal entry must be rejected.
    #[test]
    fn test_logm_non_positive_fails() {
        let input = array![[-1.0_f64, 0.0], [0.0, 4.0]];
        assert!(logm(&input.view()).is_err());
    }
}