use crate::eigen::zolotarev::{evaluate_rational, zolotarev_sign};
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign};
use std::fmt::Debug;
use std::iter::Sum;
/// Tuning parameters shared by the Zolotarev-based matrix functions in this module.
#[derive(Debug, Clone)]
pub struct ZolotarevConfig {
    /// Degree of the Zolotarev rational sign approximation requested from
    /// `zolotarev_sign` by `sqrtm_zolotarev` and `signm_zolotarev`.
    pub degree: usize,
    /// Nominal lower end of the normalised spectral interval.
    /// NOTE(review): nothing in this file reads this field — `sqrtm_zolotarev` and
    /// `signm_zolotarev` derive their own interval bound from the computed spectrum.
    pub delta: f64,
    /// Tolerance handed to the near-identity power series used by `logm_zolotarev`;
    /// the series stops once every entry of a new term is within this value.
    pub tol: f64,
}

impl Default for ZolotarevConfig {
    /// Degree 8, `delta = 0.1`, `tol = 1e-10`.
    fn default() -> Self {
        const DEFAULT_DEGREE: usize = 8;
        const DEFAULT_DELTA: f64 = 0.1;
        const DEFAULT_TOL: f64 = 1e-10;
        ZolotarevConfig {
            degree: DEFAULT_DEGREE,
            delta: DEFAULT_DELTA,
            tol: DEFAULT_TOL,
        }
    }
}
/// Selects which matrix function `matfun_auto` evaluates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum MatFun {
    /// Principal square root; dispatched to `sqrtm_zolotarev`.
    Sqrt,
    /// Matrix sign function; dispatched to `signm_zolotarev`.
    Sign,
    /// Matrix logarithm; dispatched to `logm_zolotarev`.
    Log,
}
/// Scalar bound required by every routine in this module: a floating-point type
/// with compound assignment (`NumAssign`), iterator summation (`Sum`), ndarray
/// scalar-operand support, conversion from primitive numbers (`FromPrimitive`),
/// and `Send + Sync + 'static`.
pub trait ZolotarevFloat:
    Float
    + NumAssign
    + Sum
    + Debug
    + Clone
    + scirs2_core::ndarray::ScalarOperand
    + scirs2_core::numeric::FromPrimitive
    + Send
    + Sync
    + 'static
{
}
// Blanket implementation: every type satisfying the bounds above is a
// `ZolotarevFloat` (the tests in this file rely on this for `f64`).
impl<F> ZolotarevFloat for F where
    F: Float
        + NumAssign
        + Sum
        + Debug
        + Clone
        + scirs2_core::ndarray::ScalarOperand
        + scirs2_core::numeric::FromPrimitive
        + Send
        + Sync
        + 'static
{
}
/// Dense product `a · b` using a plain triple loop: each row of `a` accumulates
/// scaled rows of `b`, and exact-zero entries of `a` are skipped.
///
/// # Errors
/// `LinalgError::DimensionError` when `a.ncols() != b.nrows()`.
fn matmul<F: ZolotarevFloat>(a: &Array2<F>, b: &Array2<F>) -> LinalgResult<Array2<F>> {
    let (rows, inner) = a.dim();
    let (inner_b, cols) = b.dim();
    if inner != inner_b {
        return Err(LinalgError::DimensionError(format!(
            "matmul: inner dims {} != {}",
            inner, inner_b
        )));
    }
    let mut product = Array2::<F>::zeros((rows, cols));
    for (a_row, mut out_row) in a.outer_iter().zip(product.outer_iter_mut()) {
        for (&coef, b_row) in a_row.iter().zip(b.outer_iter()) {
            // Skipping exact zeros saves work and means 0·∞ is never formed.
            if coef == F::zero() {
                continue;
            }
            out_row.zip_mut_with(&b_row, |acc, &entry| *acc += coef * entry);
        }
    }
    Ok(product)
}
/// Frobenius norm `sqrt(Σ a_ij²)`, accumulated in iteration order without any
/// rescaling (so extremely large entries can overflow to infinity).
fn frob_norm<F: ZolotarevFloat>(a: &Array2<F>) -> F {
    a.iter()
        .fold(F::zero(), |sum_sq, &entry| sum_sq + entry * entry)
        .sqrt()
}
/// In-place-style LU factorisation with partial (row) pivoting: returns `(lu, piv)`
/// such that `P A = L U`.
///
/// `lu` packs the multipliers of the unit-lower-triangular `L` strictly below the
/// diagonal and `U` on and above it. `piv` is a permutation vector: entries are
/// swapped in lock-step with whole rows, so row `i` of the factored matrix came
/// from row `piv[i]` of `a`.
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` is not square.
/// * `LinalgError::SingularMatrixError` if the largest candidate pivot in some
///   column has magnitude below `1e-30`.
fn lu_partial<F: ZolotarevFloat>(a: &Array2<F>) -> LinalgResult<(Array2<F>, Vec<usize>)> {
    let dim = a.nrows();
    if a.ncols() != dim {
        return Err(LinalgError::ShapeError(
            "lu_partial: must be square".to_string(),
        ));
    }
    let pivot_floor = F::from_f64(1e-30).unwrap_or(F::zero());
    let mut factors = a.clone();
    let mut perm: Vec<usize> = (0..dim).collect();
    for col in 0..dim {
        // First row (at or below the diagonal) holding the largest |entry| of this column.
        let (pivot_row, pivot_abs) =
            (col..dim).fold((col, F::zero()), |(best_row, best_abs), row| {
                let candidate = factors[[row, col]].abs();
                if candidate > best_abs {
                    (row, candidate)
                } else {
                    (best_row, best_abs)
                }
            });
        if pivot_abs < pivot_floor {
            return Err(LinalgError::SingularMatrixError(
                "lu_partial: singular matrix".to_string(),
            ));
        }
        if pivot_row != col {
            for j in 0..dim {
                factors.swap([col, j], [pivot_row, j]);
            }
            perm.swap(col, pivot_row);
        }
        let diag = factors[[col, col]];
        for row in (col + 1)..dim {
            factors[[row, col]] /= diag;
            let multiplier = factors[[row, col]];
            for j in (col + 1)..dim {
                let upper = factors[[col, j]];
                factors[[row, j]] -= multiplier * upper;
            }
        }
    }
    Ok((factors, perm))
}
/// Solves `A X = B` for `X` from the packed LU factorisation produced by `lu_partial`.
///
/// `lu` stores `L` strictly below the diagonal (unit diagonal implied) and `U` on
/// and above it, with `P A = L U`. `piv` is the row permutation returned next to
/// it: row `i` of `P A` is row `piv[i]` of `A`. It is a permutation vector, not a
/// LAPACK-style sequence of row interchanges, so it is applied directly as
/// `(P B)[i] = B[piv[i]]`.
///
/// # Errors
/// * `LinalgError::DimensionError` if `lu` is not square, if `piv` / `b` do not
///   have one entry / row per row of `lu`, or if `piv` holds an out-of-range index.
/// * `LinalgError::SingularMatrixError` if a diagonal entry of `U` has magnitude
///   below `1e-30`.
fn lu_solve<F: ZolotarevFloat>(
    lu: &Array2<F>,
    piv: &[usize],
    b: &Array2<F>,
) -> LinalgResult<Array2<F>> {
    let n = lu.nrows();
    let nrhs = b.ncols();
    if lu.ncols() != n || piv.len() != n || b.nrows() != n {
        return Err(LinalgError::DimensionError(format!(
            "lu_solve: factor is {}x{}, pivot has {} entries, rhs has {} rows",
            n,
            lu.ncols(),
            piv.len(),
            b.nrows()
        )));
    }
    // Permute the right-hand side. `piv` must NOT be replayed as successive swaps
    // (i <-> piv[i]): that scrambles the rows whenever any pivoting occurred.
    let mut x = Array2::<F>::zeros((n, nrhs));
    for (i, &src) in piv.iter().enumerate() {
        if src >= n {
            return Err(LinalgError::DimensionError(format!(
                "lu_solve: pivot index {} out of range for size {}",
                src, n
            )));
        }
        for j in 0..nrhs {
            x[[i, j]] = b[[src, j]];
        }
    }
    // Forward substitution with the unit-lower-triangular L.
    for k in 0..n {
        for i in (k + 1)..n {
            let l_ik = lu[[i, k]];
            for j in 0..nrhs {
                let x_kj = x[[k, j]];
                x[[i, j]] -= l_ik * x_kj;
            }
        }
    }
    // Back substitution with U.
    let tiny = F::from_f64(1e-30).unwrap_or(F::zero());
    for k in (0..n).rev() {
        let u_kk = lu[[k, k]];
        if u_kk.abs() < tiny {
            return Err(LinalgError::SingularMatrixError(
                "lu_solve: singular diagonal".to_string(),
            ));
        }
        for j in 0..nrhs {
            x[[k, j]] /= u_kk;
        }
        for i in 0..k {
            let u_ik = lu[[i, k]];
            for j in 0..nrhs {
                let x_kj = x[[k, j]];
                x[[i, j]] -= u_ik * x_kj;
            }
        }
    }
    Ok(x)
}
/// Inverse of a square matrix, obtained by factoring with `lu_partial` and solving
/// against the identity with `lu_solve`. Errors from either step are propagated.
fn mat_inv<F: ZolotarevFloat>(a: &Array2<F>) -> LinalgResult<Array2<F>> {
    let (factors, perm) = lu_partial(a)?;
    lu_solve(&factors, &perm, &Array2::<F>::eye(a.nrows()))
}
/// Builds the `n × n` diagonal matrix `diag(r(t_00), …, r(t_{n-1,n-1}))`, where `n`
/// is `t.nrows()`. Only the diagonal of `t` is read; off-diagonal entries of the
/// result are zero.
fn apply_rational_to_diagonal<F, R>(t: &Array2<F>, r: R) -> Array2<F>
where
    F: ZolotarevFloat,
    R: Fn(F) -> F,
{
    let size = t.nrows();
    let mut diag_out = Array2::<F>::zeros((size, size));
    for (idx, slot) in diag_out.diag_mut().iter_mut().enumerate() {
        *slot = r(t[[idx, idx]]);
    }
    diag_out
}
/// Symmetric eigendecomposition `A = Q T Qᵀ`, returned as `(Q, T)` where `T` is
/// diagonal and holds the eigenvalues.
///
/// Delegates to `crate::eigen::eigh`. If that fails, `(I, A)` is a valid
/// decomposition only when `A` is already diagonal, so that trivial result is
/// returned in exactly that case. Otherwise the eigensolver's error is
/// propagated. Falling back to `(I, A)` for a general matrix would make callers
/// silently read its diagonal as the spectrum.
///
/// NOTE(review): callers form `Q f(T) Qᵀ`, which assumes `Q` is orthogonal,
/// i.e. a symmetric `A`. Confirm how `eigh` treats non-symmetric input.
fn schur_spd<F: ZolotarevFloat>(a: &ArrayView2<F>) -> LinalgResult<(Array2<F>, Array2<F>)> {
    let n = a.nrows();
    match crate::eigen::eigh(a, None) {
        Ok((eigenvalues, eigenvectors)) => {
            let mut t = Array2::<F>::zeros((n, n));
            for i in 0..n {
                t[[i, i]] = eigenvalues[i];
            }
            Ok((eigenvectors, t))
        }
        Err(err) => {
            let is_diagonal = a
                .indexed_iter()
                .all(|((i, j), &v)| i == j || v == F::zero());
            if is_diagonal {
                Ok((Array2::<F>::eye(n), a.to_owned()))
            } else {
                Err(err)
            }
        }
    }
}
/// Principal square root of a symmetric positive-definite matrix.
///
/// Decomposes `A = Q Λ Qᵀ` with `schur_spd` and returns `Q Λ^{1/2} Qᵀ`. Before
/// roots are taken, each eigenvalue divided by the largest is passed through the
/// degree-`config.degree` Zolotarev sign approximant (`zolotarev_sign` /
/// `evaluate_rational`), and the result must be positive. This acts as a
/// consistency check on the spectrum.
///
/// # Errors
/// * `LinalgError::ShapeError` for an empty or non-square `a`.
/// * `LinalgError::NonPositiveDefiniteError` if any eigenvalue is `<= 0` or the
///   Zolotarev sign check fails for some eigenvalue.
/// * Any error from `schur_spd`, `zolotarev_sign` or `matmul`.
pub fn sqrtm_zolotarev<F>(a: &ArrayView2<F>, config: &ZolotarevConfig) -> LinalgResult<Array2<F>>
where
    F: ZolotarevFloat,
{
    let n = a.nrows();
    if n == 0 {
        return Err(LinalgError::ShapeError(
            "sqrtm_zolotarev: empty matrix".to_string(),
        ));
    }
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "sqrtm_zolotarev: must be square".to_string(),
        ));
    }
    let (q, t) = schur_spd(a)?;
    let mut lambda_min = F::infinity();
    let mut lambda_max = F::neg_infinity();
    for i in 0..n {
        let v = t[[i, i]];
        if v < lambda_min {
            lambda_min = v;
        }
        if v > lambda_max {
            lambda_max = v;
        }
    }
    // Any non-NaN eigenvalue <= 0 forces lambda_min <= 0, so this single check
    // covers every eigenvalue. No per-eigenvalue re-check is needed below.
    if lambda_min <= F::zero() {
        return Err(LinalgError::NonPositiveDefiniteError(
            "sqrtm_zolotarev: matrix has non-positive eigenvalue".to_string(),
        ));
    }
    // Lower end of the approximation interval for λ/λ_max, clamped to [1e-6, 0.9].
    let delta_eff = (lambda_min / lambda_max)
        .to_f64()
        .unwrap_or(0.01)
        .max(1e-6)
        .min(0.9);
    let approx = zolotarev_sign::<F>(config.degree, delta_eff)?;
    // Q·Λ^{1/2} is formed by scaling column i of Q by sqrt(λ_i). This gives the same
    // result as multiplying by the diagonal matrix, but skips a full O(n³) product.
    let mut q_scaled = q.clone();
    for i in 0..n {
        let lambda = t[[i, i]];
        let is_positive = evaluate_rational(lambda / lambda_max, &approx) > F::zero();
        if !is_positive {
            return Err(LinalgError::NonPositiveDefiniteError(
                "sqrtm_zolotarev: Zolotarev sign check failed".to_string(),
            ));
        }
        let root = lambda.sqrt();
        q_scaled.column_mut(i).mapv_inplace(|v| v * root);
    }
    let qt = q.t().to_owned();
    matmul(&q_scaled, &qt)
}
/// Matrix sign function via eigendecomposition and a Zolotarev rational approximant.
///
/// With `A = Q Λ Qᵀ` from `schur_spd`, every eigenvalue is divided by the spectral
/// radius `ρ = max |λ|` and mapped through the degree-`config.degree` approximant
/// from `zolotarev_sign`. The result is `Q r(Λ/ρ) Qᵀ`. The approximant's interval
/// parameter is the smallest `|λ| / ρ` among eigenvalues with `|λ| > 1e-15`,
/// clamped to `[1e-6, 0.9]`.
///
/// # Errors
/// * `LinalgError::ShapeError` for an empty or non-square `a`.
/// * `LinalgError::SingularMatrixError` when `ρ < 1e-15` (an absolute threshold).
/// * Any error from `schur_spd`, `zolotarev_sign` or `matmul`.
pub fn signm_zolotarev<F>(a: &ArrayView2<F>, config: &ZolotarevConfig) -> LinalgResult<Array2<F>>
where
    F: ZolotarevFloat,
{
    let dim = a.nrows();
    if dim == 0 {
        return Err(LinalgError::ShapeError(
            "signm_zolotarev: empty matrix".to_string(),
        ));
    }
    if a.ncols() != dim {
        return Err(LinalgError::ShapeError(
            "signm_zolotarev: must be square".to_string(),
        ));
    }
    let (basis, spectrum) = schur_spd(a)?;
    let tiny = F::from_f64(1e-15).unwrap_or(F::zero());
    let magnitudes: Vec<F> = (0..dim).map(|i| spectrum[[i, i]].abs()).collect();
    let radius = magnitudes
        .iter()
        .fold(F::zero(), |best, &m| if m > best { m } else { best });
    if radius < tiny {
        return Err(LinalgError::SingularMatrixError(
            "signm_zolotarev: matrix has zero spectral radius".to_string(),
        ));
    }
    // Smallest non-negligible magnitude; (near-)zero eigenvalues are ignored here.
    let smallest = magnitudes.iter().fold(radius, |best, &m| {
        if m < best && m > tiny {
            m
        } else {
            best
        }
    });
    let delta_eff = (smallest / radius)
        .to_f64()
        .unwrap_or(0.1)
        .max(1e-6)
        .min(0.9);
    let approx = zolotarev_sign::<F>(config.degree, delta_eff)?;
    let sign_t =
        apply_rational_to_diagonal(&spectrum, |v| evaluate_rational(v / radius, &approx));
    let rotated = matmul(&basis, &sign_t)?;
    let basis_t = basis.t().to_owned();
    matmul(&rotated, &basis_t)
}
/// Matrix logarithm of a square matrix.
///
/// With `X = A - I`:
/// * if `‖X‖_F < 0.5`, sums the series `log(I + X) = X - X²/2 + X³/3 - …` via
///   `logm_near_identity`, using `config.tol` as the term tolerance;
/// * otherwise evaluates `log(A) = ∫₀¹ X [(1 - t) A + t I]⁻¹ dt` with the 16-point
///   Gauss–Legendre rule from `gauss_legendre_16`, inverting each shifted matrix
///   with `mat_inv`.
///
/// `config.degree` is not used here. NOTE(review): the principal logarithm needs
/// no eigenvalues on the closed negative real axis; nothing here checks that.
///
/// # Errors
/// * `LinalgError::ShapeError` for an empty or non-square `a`.
/// * Any error from `mat_inv` (e.g. a singular shifted matrix) or `matmul`.
pub fn logm_zolotarev<F>(a: &ArrayView2<F>, config: &ZolotarevConfig) -> LinalgResult<Array2<F>>
where
    F: ZolotarevFloat,
{
    let n = a.nrows();
    if n == 0 {
        return Err(LinalgError::ShapeError(
            "logm_zolotarev: empty matrix".to_string(),
        ));
    }
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "logm_zolotarev: must be square".to_string(),
        ));
    }
    let a_owned = a.to_owned();
    let x = &a_owned - &Array2::<F>::eye(n);
    let tol = F::from_f64(config.tol).unwrap_or(F::from_f64(1e-10).unwrap_or(F::epsilon()));
    let series_radius = F::from_f64(0.5).unwrap_or(F::one());
    if frob_norm(&x) < series_radius {
        return logm_near_identity(&x, tol);
    }
    let (nodes, weights) = gauss_legendre_16();
    let mut acc = Array2::<F>::zeros((n, n));
    for (node, weight) in nodes.iter().copied().zip(weights.iter().copied()) {
        let t = F::from_f64(node).unwrap_or(F::zero());
        let w = F::from_f64(weight).unwrap_or(F::zero());
        let s = F::one() - t;
        // Shifted matrix M(t) = (1 - t) A + t I.
        let mut shifted = a_owned.mapv(|v| s * v);
        for i in 0..n {
            shifted[[i, i]] += t;
        }
        let integrand = matmul(&x, &mat_inv(&shifted)?)?;
        for ((i, j), out) in acc.indexed_iter_mut() {
            *out += w * integrand[[i, j]];
        }
    }
    Ok(acc)
}
/// Truncated series `log(I + X) = Σ_{k≥1} (-1)^{k+1} X^k / k`, up to 30 terms.
///
/// Stops early once at least three terms have been added and every entry of the
/// latest term has magnitude `<= tol`. No error estimate is reported if 30 terms
/// are not enough. The caller (`logm_zolotarev`) only uses this for
/// `‖X‖_F < 0.5`.
///
/// # Errors
/// Propagates any error from `matmul` while forming the next power of `X`.
fn logm_near_identity<F: ZolotarevFloat>(x: &Array2<F>, tol: F) -> LinalgResult<Array2<F>> {
    const MAX_TERMS: usize = 30;
    let dim = x.nrows();
    let mut acc = Array2::<F>::zeros((dim, dim));
    let mut x_pow = x.clone();
    for k in 1..=MAX_TERMS {
        // Odd powers enter with +1/k, even powers with -1/k.
        let sign = if k % 2 == 1 { F::one() } else { -F::one() };
        let coeff = sign / F::from_usize(k).unwrap_or(F::one());
        let mut significant = false;
        for ((i, j), out) in acc.indexed_iter_mut() {
            let term = coeff * x_pow[[i, j]];
            *out += term;
            if term.abs() > tol {
                significant = true;
            }
        }
        if !significant && k > 2 {
            break;
        }
        x_pow = matmul(&x_pow, x)?;
    }
    Ok(acc)
}
/// 16-point Gauss–Legendre rule mapped from `[-1, 1]` onto `[0, 1]`.
///
/// Returns `(nodes, weights)` with nodes in ascending order, all strictly inside
/// `(0, 1)`, and weights summing to 1, so that `Σ w_i f(x_i) ≈ ∫₀¹ f(t) dt`.
/// The rule is exact (up to rounding) for polynomials of degree ≤ 31.
///
/// The table is stored as its positive half on `[-1, 1]` and mirrored, because
/// the rule is symmetric. The earlier hand-typed table had two wrong weights
/// (≈0.1494513 and ≈0.1690047 instead of ≈0.1495960 and ≈0.1691565). That made
/// the weights sum to ≈0.9997 and biased every quadrature result.
fn gauss_legendre_16() -> ([f64; 16], [f64; 16]) {
    // Positive abscissae on [-1, 1], ascending, with their weights.
    const HALF_NODES: [f64; 8] = [
        0.095_012_509_837_637_4,
        0.281_603_550_779_258_9,
        0.458_016_777_657_227_4,
        0.617_876_244_402_643_7,
        0.755_404_408_355_003,
        0.865_631_202_387_831_7,
        0.944_575_023_073_232_6,
        0.989_400_934_991_649_9,
    ];
    const HALF_WEIGHTS: [f64; 8] = [
        0.189_450_610_455_068_5,
        0.182_603_415_044_923_6,
        0.169_156_519_395_002_5,
        0.149_595_988_816_576_7,
        0.124_628_971_255_533_9,
        0.095_158_511_682_492_8,
        0.062_253_523_938_647_9,
        0.027_152_459_411_754_1,
    ];
    let mut nodes = [0f64; 16];
    let mut weights = [0f64; 16];
    for (k, (&x, &w)) in HALF_NODES.iter().zip(HALF_WEIGHTS.iter()).enumerate() {
        // Index 7 - k receives the mirrored node -x, index 8 + k receives +x, which
        // keeps the full table in ascending order. Affine map: s = (x + 1) / 2.
        nodes[7 - k] = (1.0 - x) / 2.0;
        nodes[8 + k] = (1.0 + x) / 2.0;
        weights[7 - k] = w / 2.0;
        weights[8 + k] = w / 2.0;
    }
    (nodes, weights)
}
pub fn matfun_auto<F>(
a: &ArrayView2<F>,
fun: MatFun,
config: &ZolotarevConfig,
) -> LinalgResult<Array2<F>>
where
F: ZolotarevFloat,
{
match fun {
MatFun::Sqrt => sqrtm_zolotarev(a, config),
MatFun::Sign => signm_zolotarev(a, config),
MatFun::Log => logm_zolotarev(a, config),
}
}
// Unit tests for the Zolotarev matrix-function routines. Fixtures are small f64
// matrices; errors are measured with `frob_norm` unless an entry-wise check is used.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    // Diagonal SPD fixture diag(4, 9); its square root is diag(2, 3).
    fn spd2() -> Array2<f64> {
        array![[4.0, 0.0], [0.0, 9.0]]
    }
    // Builds the n×n matrix s·I.
    fn scaled_id(n: usize, s: f64) -> Array2<f64> {
        let mut m = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            m[[i, i]] = s;
        }
        m
    }
    // Pins the documented defaults of `ZolotarevConfig::default()`.
    #[test]
    fn test_zolotarev_config_defaults() {
        let cfg = ZolotarevConfig::default();
        assert_eq!(cfg.degree, 8);
        assert!((cfg.delta - 0.1).abs() < 1e-12);
        assert!(cfg.tol < 1e-9);
    }
    // sqrt(A)·sqrt(A) must reproduce A for a diagonal SPD matrix.
    #[test]
    fn test_sqrtm_zolotarev_spd2() {
        let a = spd2();
        let config = ZolotarevConfig::default();
        let s = sqrtm_zolotarev(&a.view(), &config).expect("sqrtm failed");
        let ss = matmul(&s, &s).expect("matmul");
        let err = frob_norm(&(ss - a));
        assert!(err < 1e-8, "sqrtm·sqrtm ≠ a: Frobenius error = {}", err);
    }
    // Checks that the diagonal entries of sqrt(A) are positive for a non-diagonal
    // SPD matrix. (The messages say "eigenvalue", but the diagonal is what is inspected.)
    #[test]
    fn test_sqrtm_positive_definite_eigenvalues() {
        let a = array![[5.0_f64, 1.0], [1.0, 3.0]];
        let config = ZolotarevConfig::default();
        let s = sqrtm_zolotarev(&a.view(), &config).expect("sqrtm");
        assert!(
            s[[0, 0]] > 0.0,
            "sqrt eigenvalue 0 not positive: {}",
            s[[0, 0]]
        );
        assert!(
            s[[1, 1]] > 0.0,
            "sqrt eigenvalue 1 not positive: {}",
            s[[1, 1]]
        );
    }
    // sqrt(I) = I.
    #[test]
    fn test_sqrtm_identity() {
        let n = 4;
        let id = Array2::<f64>::eye(n);
        let config = ZolotarevConfig::default();
        let s = sqrtm_zolotarev(&id.view(), &config).expect("sqrtm identity");
        let err = frob_norm(&(s - &id));
        assert!(err < 1e-8, "sqrt(I) ≠ I: error = {}", err);
    }
    // sign(A) = I when every eigenvalue of A is positive. The tolerance absorbs the
    // rational approximation error.
    #[test]
    fn test_signm_positive_definite_is_identity() {
        let a = spd2();
        let config = ZolotarevConfig::default();
        let s = signm_zolotarev(&a.view(), &config).expect("signm");
        let id = Array2::<f64>::eye(2);
        let err = frob_norm(&(s - id));
        assert!(err < 1e-6, "sign(PD) ≠ I: error = {}", err);
    }
    // sign(A²) = I for the positive-definite A² = diag(4, 9).
    #[test]
    fn test_signm_of_squared_positive_definite() {
        let a = array![[2.0_f64, 0.0], [0.0, 3.0]];
        let a2 = matmul(&a, &a).expect("matmul");
        let config = ZolotarevConfig::default();
        let s = signm_zolotarev(&a2.view(), &config).expect("signm a2");
        let id = Array2::<f64>::eye(2);
        let err = frob_norm(&(s - id));
        assert!(err < 1e-6, "sign(A²) ≠ I: error = {}", err);
    }
    // log(I) = 0. Here X = A - I = 0, so the near-identity series path is taken.
    #[test]
    fn test_logm_identity_is_zero() {
        let n = 2;
        let id = Array2::<f64>::eye(n);
        let config = ZolotarevConfig::default();
        let l = logm_zolotarev(&id.view(), &config).expect("logm I");
        let err = frob_norm(&l);
        assert!(err < 1e-8, "logm(I) ≠ 0: Frobenius norm = {}", err);
    }
    // exp(log(A)) ≈ A using `crate::matrix_functions::expm`. ‖A - I‖_F > 0.5 here,
    // so the quadrature path is exercised.
    // NOTE(review): the 5e-3 tolerance is loose and could likely be tightened.
    #[test]
    fn test_logm_explogm_roundtrip() {
        let a = array![[3.0_f64, 0.5], [0.5, 2.0]];
        let config = ZolotarevConfig::default();
        let la = logm_zolotarev(&a.view(), &config).expect("logm");
        let ela = crate::matrix_functions::expm(&la.view(), None).expect("expm");
        let err = frob_norm(&(ela - &a));
        assert!(err < 5e-3, "exp(log(A)) ≠ A: error = {}", err);
    }
    // For A = I + εX with small ε, log(A) ≈ εX to first order. The 1e-4 entry-wise
    // tolerance absorbs the O(ε²) remainder.
    #[test]
    fn test_logm_near_identity() {
        let eps = 0.01_f64;
        let x = array![[0.5_f64, 0.3], [0.1, 0.4]];
        let id = Array2::<f64>::eye(2);
        let a = &id + &x * eps;
        let config = ZolotarevConfig::default();
        let l = logm_zolotarev(&a.view(), &config).expect("logm near id");
        for i in 0..2 {
            for j in 0..2 {
                let expected = eps * x[[i, j]];
                assert!(
                    (l[[i, j]] - expected).abs() < 1e-4,
                    "logm near id mismatch [{},{}]: got {} expected {}",
                    i,
                    j,
                    l[[i, j]],
                    expected
                );
            }
        }
    }
    // The three dispatch tests check that `matfun_auto` matches the direct call.
    #[test]
    fn test_matfun_auto_sqrt_dispatch() {
        let a = spd2();
        let config = ZolotarevConfig::default();
        let s_direct = sqrtm_zolotarev(&a.view(), &config).expect("direct sqrt");
        let s_auto = matfun_auto(&a.view(), MatFun::Sqrt, &config).expect("auto sqrt");
        let err = frob_norm(&(s_direct - s_auto));
        assert!(
            err < 1e-12,
            "matfun_auto Sqrt diverges from direct: {}",
            err
        );
    }
    #[test]
    fn test_matfun_auto_sign_dispatch() {
        let a = spd2();
        let config = ZolotarevConfig::default();
        let s_direct = signm_zolotarev(&a.view(), &config).expect("direct sign");
        let s_auto = matfun_auto(&a.view(), MatFun::Sign, &config).expect("auto sign");
        let err = frob_norm(&(s_direct - s_auto));
        assert!(
            err < 1e-12,
            "matfun_auto Sign diverges from direct: {}",
            err
        );
    }
    #[test]
    fn test_matfun_auto_log_dispatch() {
        let a = array![[2.0_f64, 0.0], [0.0, 3.0]];
        let config = ZolotarevConfig::default();
        let l_direct = logm_zolotarev(&a.view(), &config).expect("direct logm");
        let l_auto = matfun_auto(&a.view(), MatFun::Log, &config).expect("auto logm");
        let err = frob_norm(&(l_direct - l_auto));
        assert!(err < 1e-12, "matfun_auto Log diverges from direct: {}", err);
    }
    // sqrt(s·I) = sqrt(s)·I for s = 9, n = 3.
    #[test]
    fn test_sqrtm_scaled_identity() {
        let s = 9.0_f64;
        let n = 3;
        let a = scaled_id(n, s);
        let config = ZolotarevConfig::default();
        let sq = sqrtm_zolotarev(&a.view(), &config).expect("sqrtm scaled id");
        let expected = scaled_id(n, s.sqrt());
        let err = frob_norm(&(sq - expected));
        assert!(err < 1e-8, "sqrtm(s·I) error = {}", err);
    }
}