use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
/// Trait alias bundling the numeric bounds required by the Padé-based
/// matrix-exponential routines in this module.
pub trait PadeFloat:
Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static
{
}
/// Blanket implementation: any type satisfying the bounds is a `PadeFloat`.
impl<T> PadeFloat for T where
T: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static
{
}
/// Dense square-matrix product `a * b` using the i-k-j loop order.
///
/// The i-k-j ordering walks rows of both `b` and `c` contiguously, which is
/// cache-friendly for row-major storage. Rows of `a` containing exact zeros
/// are skipped, which pays off on the structured intermediates built by the
/// Padé polynomials. Both inputs are assumed to be `n x n` with equal `n`.
fn matmul_nn<F: PadeFloat>(a: &Array2<F>, b: &Array2<F>) -> Array2<F> {
    let n = a.nrows();
    let mut c = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for k in 0..n {
            let aik = a[[i, k]];
            // Exact-zero skip: cheap test, common for sparse-ish inputs.
            if aik == F::zero() {
                continue;
            }
            for j in 0..n {
                // NumAssign is in the trait bounds, so `+=` is available.
                c[[i, j]] += aik * b[[k, j]];
            }
        }
    }
    c
}
/// Compute the matrix 1-norm (maximum absolute column sum) of `a`.
fn one_norm<F: PadeFloat>(a: &Array2<F>) -> F {
    let rows = a.nrows();
    let cols = a.ncols();
    (0..cols)
        .map(|j| {
            // Absolute column sum for column j.
            (0..rows)
                .map(|i| a[[i, j]].abs())
                .fold(F::zero(), |acc, v| acc + v)
        })
        .fold(F::zero(), |best, col_sum| {
            if col_sum > best {
                col_sum
            } else {
                best
            }
        })
}
/// Add `scale` to every diagonal entry of the square matrix `a` in place,
/// i.e. compute `a + scale * I` without allocating an identity matrix.
fn add_identity_scaled<F: PadeFloat>(a: &mut Array2<F>, scale: F) {
    let dim = a.nrows();
    for d in 0..dim {
        let updated = a[[d, d]] + scale;
        a[[d, d]] = updated;
    }
}
/// Return the Padé coefficients `b_k = (2m-k)! / (k! (m-k)!)` for the
/// order-`m` diagonal Padé approximant of the matrix exponential.
///
/// The orders used by the scaling-and-squaring driver (3, 5, 7, 9, 13) come
/// from precomputed tables. Any other order falls back to a multiplicative
/// recurrence that matches the same (unnormalized) convention as the tables
/// and avoids the overflow of evaluating large factorials directly. Any
/// common scale factor cancels in the N/D Padé ratio, so the resulting
/// matrix exponential is unaffected by the convention.
pub fn pade_coefficients(m: usize) -> Vec<f64> {
    match m {
        3 => vec![120.0, 60.0, 12.0, 1.0],
        5 => vec![30240.0, 15120.0, 3360.0, 420.0, 30.0, 1.0],
        7 => vec![
            17297280.0, 8648640.0, 1995840.0, 277200.0, 25200.0, 1512.0, 56.0, 1.0,
        ],
        9 => vec![
            17643225600.0,
            8821612800.0,
            2075673600.0,
            302702400.0,
            30270240.0,
            2162160.0,
            110880.0,
            3960.0,
            90.0,
            1.0,
        ],
        13 => vec![
            64764752532480000.0,
            32382376266240000.0,
            7771770303897600.0,
            1187353796428800.0,
            129060195264000.0,
            10559470521600.0,
            670442572800.0,
            33522128640.0,
            1323241920.0,
            40840800.0,
            960960.0,
            16380.0,
            182.0,
            1.0,
        ],
        _ => {
            // b_0 = (2m)! / m! = (m+1) * (m+2) * ... * (2m), then
            // b_{k+1} = b_k * (m - k) / ((k + 1) * (2m - k)).
            // The running-product form avoids the intermediate overflow of
            // computing (2m)! and friends directly.
            let mut coeffs = Vec::with_capacity(m + 1);
            let mut b = 1.0f64;
            for i in (m + 1)..=(2 * m) {
                b *= i as f64;
            }
            coeffs.push(b);
            for k in 0..m {
                b = b * (m - k) as f64 / (((k + 1) * (2 * m - k)) as f64);
                coeffs.push(b);
            }
            coeffs
        }
    }
}
/// Compute `n!` as an `f64`.
///
/// The empty product over `2..=n` yields 1.0 for `n` of 0 or 1, so no
/// special case is required.
fn factorial_f64(n: usize) -> f64 {
    (2..=n).map(|i| i as f64).product()
}
/// Compute `exp(a)` via the order-`m` diagonal Padé approximant, without
/// scaling. Callers are responsible for choosing `m` appropriate for the
/// norm of `a` (see `pade_expm` for the adaptive driver).
///
/// # Errors
/// Returns `ShapeError` if `a` is not square, or `ComputationError` if a
/// Padé coefficient cannot be represented in `F`.
pub fn expm_pade<F: PadeFloat>(a: &ArrayView2<F>, m: usize) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "expm_pade: matrix must be square".into(),
        ));
    }
    if n == 0 {
        // exp of the empty matrix is the empty matrix.
        return Ok(Array2::<F>::zeros((0, 0)));
    }
    let coeffs_f64 = pade_coefficients(m);
    // Fail loudly if a coefficient is unrepresentable in F; silently
    // substituting a fallback value would corrupt the approximant.
    let coeffs: Vec<F> = coeffs_f64
        .iter()
        .map(|&c| {
            F::from(c).ok_or_else(|| {
                LinalgError::ComputationError(
                    "expm_pade: cannot represent Pade coefficient in target float type".into(),
                )
            })
        })
        .collect::<Result<_, _>>()?;
    let a_owned = a.to_owned();
    // Order 13 uses the factored evaluation scheme; other orders use the
    // straightforward power-sum form.
    let (n_pade, d_pade) = if m == 13 {
        pade_13_polynomials(&a_owned, &coeffs)?
    } else {
        pade_general_polynomials(&a_owned, &coeffs, m)?
    };
    // exp(A) ~= D^{-1} N, obtained by solving the linear system D X = N.
    crate::solve::solve_multiple(&d_pade.view(), &n_pade.view(), None)
}
/// Evaluate the order-13 Padé numerator/denominator pair `(N, D)` using the
/// factored scheme (cf. Higham's expm), which needs only A^2, A^4 and A^6
/// plus two multiplications by A^6 and one final multiplication by A.
///
/// Returns `(N, D)` with `N = V + U` and `D = V - U`, where `U` collects the
/// odd-power terms and `V` the even-power terms of the degree-13 polynomial.
fn pade_13_polynomials<F: PadeFloat>(
    a: &Array2<F>,
    c: &[F],
) -> LinalgResult<(Array2<F>, Array2<F>)> {
    let n = a.nrows();
    let a2 = matmul_nn(a, a);
    let a4 = matmul_nn(&a2, &a2);
    let a6 = matmul_nn(&a2, &a4);
    // Inner odd-part polynomial: c13*A^6 + c11*A^4 + c9*A^2 + c7*I.
    let mut u_inner = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            u_inner[[i, j]] = c[13] * a6[[i, j]] + c[11] * a4[[i, j]] + c[9] * a2[[i, j]];
        }
    }
    add_identity_scaled(&mut u_inner, c[7]);
    // Inner even-part polynomial: c12*A^6 + c10*A^4 + c8*A^2 + c6*I.
    let mut v_inner = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            v_inner[[i, j]] = c[12] * a6[[i, j]] + c[10] * a4[[i, j]] + c[8] * a2[[i, j]];
        }
    }
    add_identity_scaled(&mut v_inner, c[6]);
    // u_rest = A^6 * u_inner + c5*A^4 + c3*A^2 + c1*I.
    let a6_u_inner = matmul_nn(&a6, &u_inner);
    let mut u_rest = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            u_rest[[i, j]] = a6_u_inner[[i, j]] + c[5] * a4[[i, j]] + c[3] * a2[[i, j]];
        }
    }
    add_identity_scaled(&mut u_rest, c[1]);
    // v = A^6 * v_inner + c4*A^4 + c2*A^2 + c0*I.
    let a6_v_inner = matmul_nn(&a6, &v_inner);
    let mut v = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            v[[i, j]] = a6_v_inner[[i, j]] + c[4] * a4[[i, j]] + c[2] * a2[[i, j]];
        }
    }
    add_identity_scaled(&mut v, c[0]);
    // The odd part carries one extra factor of A: u = A * u_rest.
    let u = matmul_nn(a, &u_rest);
    // N = V + U, D = V - U.
    let mut n_pade = Array2::<F>::zeros((n, n));
    let mut d_pade = Array2::<F>::zeros((n, n));
    for i in 0..n {
        for j in 0..n {
            n_pade[[i, j]] = v[[i, j]] + u[[i, j]];
            d_pade[[i, j]] = v[[i, j]] - u[[i, j]];
        }
    }
    Ok((n_pade, d_pade))
}
/// Evaluate the order-`m` Padé numerator/denominator pair `(N, D)` by direct
/// summation: `N = sum_k c_k A^k` and `D = sum_k (-1)^k c_k A^k` (the
/// denominator is the numerator evaluated at -A).
fn pade_general_polynomials<F: PadeFloat>(
    a: &Array2<F>,
    c: &[F],
    m: usize,
) -> LinalgResult<(Array2<F>, Array2<F>)> {
    let n = a.nrows();
    // Powers A^0 .. A^m, with A^0 = I.
    let mut a_powers: Vec<Array2<F>> = Vec::with_capacity(m + 1);
    a_powers.push(Array2::<F>::eye(n));
    if m >= 1 {
        a_powers.push(a.to_owned());
    }
    for _ in 2..=m {
        // Borrow the last power only for the multiply; the borrow ends
        // before the push, so no clone of the previous power is needed.
        let next = matmul_nn(a, a_powers.last().expect("a_powers is non-empty"));
        a_powers.push(next);
    }
    let mut n_pade = Array2::<F>::zeros((n, n));
    let mut d_pade = Array2::<F>::zeros((n, n));
    for (k, a_k) in a_powers.iter().enumerate() {
        // Alternating sign builds the denominator from the same terms.
        let sign = if k % 2 == 0 { F::one() } else { -F::one() };
        for i in 0..n {
            for j in 0..n {
                let term = c[k] * a_k[[i, j]];
                n_pade[[i, j]] += term;
                d_pade[[i, j]] += sign * term;
            }
        }
    }
    Ok((n_pade, d_pade))
}
// Norm thresholds for the scaling-and-squaring order selection: if
// ||A||_1 <= THETA[i], the Padé approximant of order PADE_ORDERS[i] is
// accurate enough without scaling. These match the standard double-precision
// theta values used by Higham-style expm implementations.
const THETA: [f64; 5] = [
1.495_585_217_958_292e-2, 2.539_398_330_063_23e-1, 9.504_178_996_162_932e-1, 2.097_847_961_257_068, 5.371_920_351_148_152, ];
// Padé orders paired 1:1 with the THETA thresholds above.
const PADE_ORDERS: [usize; 5] = [3, 5, 7, 9, 13];
/// Matrix exponential via scaling-and-squaring with Padé approximants.
///
/// Picks the smallest Padé order whose theta threshold covers `||A||_1`; if
/// the norm exceeds the order-13 threshold, `A` is scaled by `2^-s`,
/// exponentiated at order 13, and the result squared `s` times.
///
/// # Errors
/// Returns `ShapeError` for non-square input, or `ComputationError` if the
/// norm or the scale factor cannot be converted.
pub fn pade_expm<F: PadeFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "pade_expm: matrix must be square".into(),
        ));
    }
    if n == 0 {
        return Ok(Array2::<F>::zeros((0, 0)));
    }
    if n == 1 {
        // 1x1 case: the matrix exponential is the scalar exponential.
        let mut result = Array2::<F>::zeros((1, 1));
        result[[0, 0]] = a[[0, 0]].exp();
        return Ok(result);
    }
    let a_owned = a.to_owned();
    let norm = one_norm(&a_owned);
    let norm_f64 = norm
        .to_f64()
        .ok_or_else(|| LinalgError::ComputationError("Cannot convert norm to f64".into()))?;
    // Small norm: a single Padé evaluation at the matching order suffices.
    for (idx, &theta) in THETA.iter().enumerate() {
        if norm_f64 <= theta {
            let m = PADE_ORDERS[idx];
            return expm_pade(a, m);
        }
    }
    // Scaling and squaring: exp(A) = exp(A / 2^s)^(2^s), with s chosen so
    // the scaled norm falls under the order-13 threshold.
    let theta_13 = THETA[4];
    let s_f64 = (norm_f64 / theta_13).log2().ceil().max(0.0);
    let s = s_f64 as u32;
    let two_s = F::from(2.0_f64.powi(s as i32))
        .ok_or_else(|| LinalgError::ComputationError("Cannot convert 2^s".into()))?;
    let a_scaled = a_owned.map(|&x| x / two_s);
    let mut result = expm_pade(&a_scaled.view(), 13)?;
    for _ in 0..s {
        // matmul_nn borrows immutably and returns fresh storage, so no
        // clone of `result` is needed for the repeated squaring.
        result = matmul_nn(&result, &result);
    }
    Ok(result)
}
/// Fréchet derivative of the matrix exponential in direction `e`.
///
/// Uses the augmented-matrix identity
/// `expm([[A, E], [0, A]]) = [[expm(A), L(A, E)], [0, expm(A)]]`
/// and returns the pair `(expm(A), L(A, E))`.
///
/// # Errors
/// Returns `ShapeError` if `A` is not square or `E` does not match its shape.
pub fn expm_frechet<F: PadeFloat>(
    a: &ArrayView2<F>,
    e: &ArrayView2<F>,
) -> LinalgResult<(Array2<F>, Array2<F>)> {
    let dim = a.nrows();
    if a.ncols() != dim {
        return Err(LinalgError::ShapeError(
            "expm_frechet: A must be square".into(),
        ));
    }
    if e.nrows() != dim || e.ncols() != dim {
        return Err(LinalgError::ShapeError(
            "expm_frechet: E must have same shape as A".into(),
        ));
    }
    // Assemble the 2n x 2n block matrix [[A, E], [0, A]].
    let mut augmented = Array2::<F>::zeros((2 * dim, 2 * dim));
    for r in 0..dim {
        for q in 0..dim {
            let a_rq = a[[r, q]];
            augmented[[r, q]] = a_rq;
            augmented[[r + dim, q + dim]] = a_rq;
            augmented[[r, q + dim]] = e[[r, q]];
        }
    }
    let exponentiated = pade_expm(&augmented.view())?;
    // Top-left block is expm(A); top-right block is the Fréchet derivative.
    let mut exp_a = Array2::<F>::zeros((dim, dim));
    let mut frechet = Array2::<F>::zeros((dim, dim));
    for r in 0..dim {
        for q in 0..dim {
            exp_a[[r, q]] = exponentiated[[r, q]];
            frechet[[r, q]] = exponentiated[[r, q + dim]];
        }
    }
    Ok((exp_a, frechet))
}
/// Rough estimate of the condition number of the matrix exponential at `a`.
///
/// Probes up to 25 elementary perturbation directions E_ij, computes the
/// Fréchet derivative in each, and uses the largest derivative 1-norm as a
/// lower bound on ||L(A, .)||. The estimate kappa ~= ||L|| * ||A|| / ||exp(A)||
/// is clamped to be at least 1. NOTE(review): this is a sampling heuristic,
/// not an exact condition number — it can underestimate for n > 5 where not
/// all n^2 directions are probed.
pub fn expm_cond(a: &ArrayView2<f64>) -> LinalgResult<f64> {
let n = a.nrows();
if a.ncols() != n {
return Err(LinalgError::ShapeError(
"expm_cond: matrix must be square".into(),
));
}
let exp_a = pade_expm(a)?;
let norm_exp_a = one_norm(&exp_a);
// Guard against dividing by a (near-)zero exponential norm below.
if norm_exp_a < f64::EPSILON {
return Ok(f64::INFINITY);
}
let mut max_ratio: f64 = 0.0;
// Cap the number of probed directions at 25 to bound the cost.
let n_dirs = (n * n).min(25);
for dir in 0..n_dirs {
// Map the flat direction index to a (row, col) entry of E.
let row = dir / n;
let col = dir % n;
let mut e_mat = Array2::<f64>::zeros((n, n));
// The `min` clamps are defensive; dir < n*n already keeps indices in range.
e_mat[[row.min(n - 1), col.min(n - 1)]] = 1.0;
let (_, l_ae) = expm_frechet(a, &e_mat.view())?;
let norm_l = one_norm(&l_ae);
if norm_l > max_ratio {
max_ratio = norm_l;
}
}
let norm_a = one_norm(&a.to_owned());
let cond = max_ratio * norm_a / norm_exp_a;
// The condition number of any matrix function is at least 1.
Ok(cond.max(1.0))
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use scirs2_core::ndarray::array;
// m = 3 coefficient table has 4 entries with the expected endpoints.
#[test]
fn test_pade_coefficients_m3() {
let c = pade_coefficients(3);
assert_eq!(c.len(), 4);
assert!((c[0] - 120.0).abs() < 1e-6);
assert!((c[3] - 1.0).abs() < 1e-6);
}
// m = 13 coefficient table has 14 entries ending in 1.0.
#[test]
fn test_pade_coefficients_m13() {
let c = pade_coefficients(13);
assert_eq!(c.len(), 14);
assert!((c[13] - 1.0).abs() < 1e-6);
}
// exp(I) = e * I for the 2x2 identity.
#[test]
fn test_expm_pade_identity() {
let eye = array![[1.0, 0.0], [0.0, 1.0]];
let result = expm_pade(&eye.view(), 13).expect("pade failed");
let e_val = std::f64::consts::E;
assert_abs_diff_eq!(result[[0, 0]], e_val, epsilon = 1e-10);
assert_abs_diff_eq!(result[[1, 1]], e_val, epsilon = 1e-10);
assert_abs_diff_eq!(result[[0, 1]], 0.0, epsilon = 1e-10);
assert_abs_diff_eq!(result[[1, 0]], 0.0, epsilon = 1e-10);
}
// exp(0) = I.
#[test]
fn test_expm_pade_zero() {
let z = array![[0.0, 0.0], [0.0, 0.0]];
let result = expm_pade(&z.view(), 13).expect("pade failed");
assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(result[[1, 1]], 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(result[[0, 1]], 0.0, epsilon = 1e-12);
assert_abs_diff_eq!(result[[1, 0]], 0.0, epsilon = 1e-12);
}
// exp of the skew-symmetric generator gives a rotation by 1 radian.
#[test]
fn test_pade_expm_rotation() {
let a = array![[0.0_f64, 1.0], [-1.0, 0.0]];
let result = pade_expm(&a.view()).expect("pade_expm failed");
let cos1 = 1.0_f64.cos();
let sin1 = 1.0_f64.sin();
assert_abs_diff_eq!(result[[0, 0]], cos1, epsilon = 1e-12);
assert_abs_diff_eq!(result[[0, 1]], sin1, epsilon = 1e-12);
assert_abs_diff_eq!(result[[1, 0]], -sin1, epsilon = 1e-12);
assert_abs_diff_eq!(result[[1, 1]], cos1, epsilon = 1e-12);
}
// Norm above the order-13 threshold exercises the scaling-and-squaring path;
// the diagonal case has a known closed form exp(diag(d)) = diag(exp(d)).
#[test]
fn test_pade_expm_large_norm() {
let a = array![[10.0_f64, 0.0], [0.0, -5.0]];
let result = pade_expm(&a.view()).expect("pade_expm failed");
assert_abs_diff_eq!(result[[0, 0]], 10.0_f64.exp(), epsilon = 1e-4);
assert_abs_diff_eq!(result[[1, 1]], (-5.0_f64).exp(), epsilon = 1e-10);
assert_abs_diff_eq!(result[[0, 1]], 0.0, epsilon = 1e-10);
assert_abs_diff_eq!(result[[1, 0]], 0.0, epsilon = 1e-10);
}
// Nilpotent A with A^2 = 0, so exp(A) = I + A exactly.
#[test]
fn test_pade_expm_nilpotent() {
let a = array![[0.0_f64, 1.0], [0.0, 0.0]];
let result = pade_expm(&a.view()).expect("pade_expm failed");
assert_abs_diff_eq!(result[[0, 0]], 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(result[[0, 1]], 1.0, epsilon = 1e-12);
assert_abs_diff_eq!(result[[1, 0]], 0.0, epsilon = 1e-12);
assert_abs_diff_eq!(result[[1, 1]], 1.0, epsilon = 1e-12);
}
// For diagonal A and E = I, the Fréchet derivative is diag(exp(a_ii)).
#[test]
fn test_expm_frechet_diagonal() {
let a = array![[1.0_f64, 0.0], [0.0, 2.0]];
let e = array![[1.0_f64, 0.0], [0.0, 1.0]];
let (exp_a, l_ae) = expm_frechet(&a.view(), &e.view()).expect("frechet failed");
assert_abs_diff_eq!(exp_a[[0, 0]], 1.0_f64.exp(), epsilon = 1e-10);
assert_abs_diff_eq!(exp_a[[1, 1]], 2.0_f64.exp(), epsilon = 1e-10);
assert_abs_diff_eq!(l_ae[[0, 0]], 1.0_f64.exp(), epsilon = 1e-8);
assert_abs_diff_eq!(l_ae[[1, 1]], 2.0_f64.exp(), epsilon = 1e-8);
}
// The Fréchet derivative is linear in E: L(A, alpha*E) = alpha * L(A, E).
#[test]
fn test_expm_frechet_linearity() {
let a = array![[0.5_f64, 0.2], [0.1, 0.3]];
let e = array![[0.1_f64, 0.0], [0.0, 0.2]];
let alpha = 3.0_f64;
let mut e_scaled = Array2::<f64>::zeros((2, 2));
for i in 0..2 {
for j in 0..2 {
e_scaled[[i, j]] = alpha * e[[i, j]];
}
}
let (_, l1) = expm_frechet(&a.view(), &e.view()).expect("frechet failed");
let (_, l2) = expm_frechet(&a.view(), &e_scaled.view()).expect("frechet failed");
for i in 0..2 {
for j in 0..2 {
assert_abs_diff_eq!(l2[[i, j]], alpha * l1[[i, j]], epsilon = 1e-10);
}
}
}
// The condition estimate is finite and at least 1 for a well-behaved matrix.
#[test]
fn test_expm_cond_positive() {
let a = array![[0.5_f64, 0.0], [0.0, 0.5]];
let kappa = expm_cond(&a.view()).expect("cond failed");
assert!(kappa >= 1.0);
assert!(kappa.is_finite());
}
// The identity matrix is well-conditioned for the exponential.
#[test]
fn test_expm_cond_identity() {
let eye = array![[1.0_f64, 0.0], [0.0, 1.0]];
let kappa = expm_cond(&eye.view()).expect("cond failed");
assert!(kappa >= 1.0);
assert!(kappa < 1000.0);
}
// For a tiny-norm matrix every supported Padé order agrees with the
// adaptive driver to high accuracy.
#[test]
fn test_expm_pade_all_orders() {
let a = array![[0.01_f64, 0.005], [-0.005, 0.02]];
let expected = pade_expm(&a.view()).expect("expected failed");
for &m in &[3usize, 5, 7, 9, 13] {
let result = expm_pade(&a.view(), m).expect("pade order failed");
for i in 0..2 {
for j in 0..2 {
assert_abs_diff_eq!(result[[i, j]], expected[[i, j]], epsilon = 1e-8);
}
}
}
}
}