use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
/// Scalar types accepted by the Schur-based matrix trigonometric routines in this module.
///
/// This is a bundle of the floating-point, compound-assignment, summation, ndarray
/// scalar-operand and thread-safety bounds those routines rely on. It is implemented
/// automatically for every type that satisfies all of them.
pub trait TrigFloat: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static {}

impl<T: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static> TrigFloat for T {}
/// Dense matrix product `A · B` using a plain i-k-j loop.
///
/// Accepts any conformable shapes (`m × k` times `k × p`) and returns an `m × p`
/// matrix. Zero entries of `A` are skipped, which saves work when `A` is sparse
/// or triangular.
///
/// # Panics
/// Panics if `a.ncols() != b.nrows()`.
fn matmul_nn<F: TrigFloat>(a: &Array2<F>, b: &Array2<F>) -> Array2<F> {
    let (rows, inner) = a.dim();
    let cols = b.ncols();
    assert_eq!(
        inner,
        b.nrows(),
        "matmul_nn: inner dimensions must agree"
    );
    let mut c = Array2::<F>::zeros((rows, cols));
    for i in 0..rows {
        for k in 0..inner {
            let aik = a[[i, k]];
            if aik == F::zero() {
                continue;
            }
            for j in 0..cols {
                c[[i, j]] += aik * b[[k, j]];
            }
        }
    }
    c
}
/// Builds `f(T)` for an upper-triangular `T` with the Parlett recurrence.
///
/// The diagonal is `f(T[i][i])`, supplied in `f_diag`. Each super-diagonal entry
/// follows from the commutation relation `f(T)·T = T·f(T)`:
///
/// ```text
/// F[i][j] = ( T[i][j]·(F[j][j] − F[i][i])
///           + Σ_{k=i+1}^{j−1} (T[i][k]·F[k][j] − F[i][k]·T[k][j]) ) / (T[j][j] − T[i][i])
/// ```
///
/// Columns are filled left to right, and each column bottom-up. This ensures every
/// `F[i][k]` and `F[k][j]` needed by the sum is already known.
///
/// When `|T[j][j] − T[i][i]|` is below `100·ε`, the division is unsafe. The entry is
/// then approximated as `f'(T[i][i])·T[i][j] + Σ (F[i][k]·T[k][j] − T[i][k]·F[k][j])`,
/// where `f'` is a central-difference estimate.
/// NOTE(review): that confluent approximation is exact for an isolated 2×2 pair.
/// Clusters of three or more nearly equal eigenvalues would need higher derivatives.
///
/// Only the upper triangle of `t` is read. `f_diag` must hold at least `n` values.
fn parlett_recurrence<F: TrigFloat>(
    t: &Array2<F>,
    f_diag: &[F],
    n: usize,
    scalar_fn: fn(F) -> F,
) -> Array2<F> {
    let mut ft = Array2::<F>::zeros((n, n));
    for i in 0..n {
        ft[[i, i]] = f_diag[i];
    }
    let thresh = F::epsilon() * F::from(100.0).unwrap_or(F::one());
    for j in 1..n {
        for i in (0..j).rev() {
            let fii = ft[[i, i]];
            let fjj = ft[[j, j]];
            let tij = t[[i, j]];
            let denom = t[[j, j]] - t[[i, i]];
            // Σ_k (F[i][k]·T[k][j] − T[i][k]·F[k][j]). This is the NEGATIVE of the
            // commutator sum in the recurrence documented above.
            let mut inner_sum = F::zero();
            for k in (i + 1)..j {
                inner_sum = inner_sum + ft[[i, k]] * t[[k, j]] - t[[i, k]] * ft[[k, j]];
            }
            ft[[i, j]] = if denom.abs() < thresh {
                let f_prime = numerical_derivative(scalar_fn, t[[i, i]]);
                f_prime * tij + inner_sum
            } else {
                // Numerator and denominator must both be "j minus i". Mixing
                // (F_ii − F_jj) with (T_jj − T_ii) negates the entry. Because inner_sum
                // is the negated commutator sum, it is subtracted here.
                ((fjj - fii) * tij - inner_sum) / denom
            };
        }
    }
    ft
}
/// Central-difference estimate of `f'(x)`.
///
/// The step is `ε^{1/3}·(1 + |x|)`, where `ε` is the machine epsilon of `F`. This
/// balances the O(h²) truncation error against the O(ε/h) rounding error for both
/// `f32` and `f64`; a fixed `1e-5` step is far too small for `f32`.
///
/// The quotient divides by the step actually realised in floating point,
/// `(x + h) − (x − h)`, rather than the nominal `2h`.
fn numerical_derivative<F: TrigFloat>(f: fn(F) -> F, x: F) -> F {
    let h = F::epsilon().cbrt() * (F::one() + x.abs());
    let x_hi = x + h;
    let x_lo = x - h;
    (f(x_hi) - f(x_lo)) / (x_hi - x_lo)
}
/// Applies a scalar function to a square matrix through its Schur form, `f(A) = Q·f(T)·Qᵀ`.
///
/// `Q` and `T` come from `crate::decomposition::schur`. `f(T)` is assembled by
/// `parlett_recurrence` from `scalar_fn` evaluated on the diagonal of `T`.
/// Empty and 1×1 inputs are answered directly, without a decomposition.
///
/// NOTE(review): the Parlett recurrence reads only the upper triangle of `T`. If
/// `schur` returns a real quasi-triangular factor (2×2 diagonal blocks for complex
/// eigenvalue pairs), the sub-diagonal is ignored and the result is not `f(A)`.
/// Confirm what `schur` guarantees.
///
/// # Errors
/// * `LinalgError::ShapeError` (message prefixed with `name`) when `a` is not square.
/// * Any error propagated from the Schur decomposition.
fn schur_function<F: TrigFloat>(
    a: &ArrayView2<F>,
    scalar_fn: fn(F) -> F,
    name: &str,
) -> LinalgResult<Array2<F>> {
    let dim = a.nrows();
    if dim != a.ncols() {
        return Err(LinalgError::ShapeError(format!(
            "{name}: matrix must be square"
        )));
    }
    match dim {
        0 => Ok(Array2::<F>::zeros((0, 0))),
        1 => Ok(Array2::from_elem((1, 1), scalar_fn(a[[0, 0]]))),
        _ => {
            let (q, t) = crate::decomposition::schur(a)?;
            let diag_vals: Vec<F> = (0..dim).map(|k| scalar_fn(t[[k, k]])).collect();
            let f_t = parlett_recurrence(&t, &diag_vals, dim, scalar_fn);
            let q_f = q.dot(&f_t);
            Ok(q_f.dot(&q.t()))
        }
    }
}
pub fn schur_apply<F: TrigFloat>(
a: &ArrayView2<F>,
scalar_fn: fn(F) -> F,
name: &str,
) -> LinalgResult<Array2<F>> {
schur_function(a, scalar_fn, name)
}
/// Matrix sine `sin(A)`, computed via Schur decomposition and the Parlett recurrence.
///
/// # Errors
/// `LinalgError::ShapeError` if `a` is not square; otherwise any Schur-decomposition error.
pub fn sinm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    let sine: fn(F) -> F = F::sin;
    schur_function(a, sine, "sinm_schur")
}
/// Matrix cosine `cos(A)`, computed via Schur decomposition and the Parlett recurrence.
///
/// # Errors
/// `LinalgError::ShapeError` if `a` is not square; otherwise any Schur-decomposition error.
pub fn cosm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    let cosine: fn(F) -> F = F::cos;
    schur_function(a, cosine, "cosm_schur")
}
/// Matrix tangent `tan(A)`, obtained from `sin(A)` and `cos(A)`.
///
/// Both factors are computed with `sinm_schur` and `cosm_schur`. Because they commute,
/// `tan(A)` is the solution `X` of `cos(A)·X = sin(A)`, which is handed to
/// `crate::solve::solve_multiple` (presumably solving `A·X = B` — confirm).
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` is not square.
/// * Errors from the Schur evaluations.
/// * Whatever `solve_multiple` reports, e.g. when `cos(A)` is singular.
pub fn tanm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    if a.ncols() != a.nrows() {
        return Err(LinalgError::ShapeError(
            "tanm_schur: matrix must be square".into(),
        ));
    }
    let numerator = sinm_schur(a)?;
    let denominator = cosm_schur(a)?;
    crate::solve::solve_multiple(&denominator.view(), &numerator.view(), None)
}
/// Matrix hyperbolic sine `sinh(A)`, computed via Schur decomposition and the Parlett recurrence.
///
/// # Errors
/// `LinalgError::ShapeError` if `a` is not square; otherwise any Schur-decomposition error.
pub fn sinhm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    let hyperbolic_sine: fn(F) -> F = F::sinh;
    schur_function(a, hyperbolic_sine, "sinhm_schur")
}
/// Matrix hyperbolic cosine `cosh(A)`, computed via Schur decomposition and the Parlett recurrence.
///
/// # Errors
/// `LinalgError::ShapeError` if `a` is not square; otherwise any Schur-decomposition error.
pub fn coshm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    let hyperbolic_cosine: fn(F) -> F = F::cosh;
    schur_function(a, hyperbolic_cosine, "coshm_schur")
}
/// Matrix hyperbolic tangent `tanh(A)`, obtained from `sinh(A)` and `cosh(A)`.
///
/// Both factors are computed with `sinhm_schur` and `coshm_schur`. Because they commute,
/// `tanh(A)` is the solution `X` of `cosh(A)·X = sinh(A)`, which is handed to
/// `crate::solve::solve_multiple` (presumably solving `A·X = B` — confirm).
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` is not square.
/// * Errors from the Schur evaluations.
/// * Whatever `solve_multiple` reports.
pub fn tanhm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    if a.ncols() != a.nrows() {
        return Err(LinalgError::ShapeError(
            "tanhm_schur: matrix must be square".into(),
        ));
    }
    let numerator = sinhm_schur(a)?;
    let denominator = coshm_schur(a)?;
    crate::solve::solve_multiple(&denominator.view(), &numerator.view(), None)
}
pub fn apply_schur<F: TrigFloat>(
a: &ArrayView2<F>,
f: fn(F) -> F,
name: &str,
) -> LinalgResult<Array2<F>> {
schur_function(a, f, name)
}
/// Computes `cos(A)` and `sin(A)` together for a real square matrix `A`.
///
/// For every square `A`, the blocks of `M = [[0, −A], [A, 0]]` commute, so
/// `exp(M) = [[cos A, −sin A], [sin A, cos A]]`. A single Padé exponential of this
/// `2n × 2n` matrix therefore yields both results. Unlike an eigendecomposition, it
/// remains well defined for defective (non-diagonalizable) and non-normal matrices.
///
/// If the Padé exponential reports an error, the function falls back to the
/// eigendecomposition route (`sincos_via_eig`).
///
/// # Returns
/// `(cos_a, sin_a)`, each `n × n`. An empty input yields two empty matrices.
///
/// # Errors
/// * `LinalgError::ShapeError` if `a` is not square.
/// * Any error from the eigendecomposition fallback, e.g. a singular eigenvector matrix.
pub fn sincos_expm<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<(Array2<F>, Array2<F>)> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(
            "sincos_expm: matrix must be square".into(),
        ));
    }
    if n == 0 {
        return Ok((Array2::<F>::zeros((0, 0)), Array2::<F>::zeros((0, 0))));
    }
    // Build M = [[0, -A], [A, 0]].
    let mut aug = Array2::<F>::zeros((2 * n, 2 * n));
    for i in 0..n {
        for j in 0..n {
            aug[[i, j + n]] = -a[[i, j]];
            aug[[i + n, j]] = a[[i, j]];
        }
    }
    match crate::matrix_functions::pade::pade_expm(&aug.view()) {
        Ok(exp_aug) => {
            // Top-left block of exp(M) is cos(A); bottom-left block is sin(A).
            let cos_a = Array2::from_shape_fn((n, n), |(i, j)| exp_aug[[i, j]]);
            let sin_a = Array2::from_shape_fn((n, n), |(i, j)| exp_aug[[i + n, j]]);
            Ok((cos_a, sin_a))
        }
        // The eigen route cannot handle defective matrices, but it is still a usable
        // fallback when the exponential itself fails.
        Err(_) => sincos_via_eig(a, n),
    }
}
fn sincos_via_eig<F: TrigFloat>(
a: &ArrayView2<F>,
n: usize,
) -> LinalgResult<(Array2<F>, Array2<F>)> {
use scirs2_core::numeric::Complex;
let (eigenvals, eigenvecs) = crate::eigen::eig(a, None)?;
let cos_eigs: Vec<Complex<F>> = eigenvals
.iter()
.map(|&lam| {
let (a, b) = (lam.re, lam.im);
let ca = a.cos();
let cb = b.cosh();
let sa = a.sin();
let sb = b.sinh();
Complex::new(ca * cb, -(sa * sb))
})
.collect();
let sin_eigs: Vec<Complex<F>> = eigenvals
.iter()
.map(|&lam| {
let (a, b) = (lam.re, lam.im);
let ca = a.cos();
let cb = b.cosh();
let sa = a.sin();
let sb = b.sinh();
Complex::new(sa * cb, ca * sb)
})
.collect();
let cos_d: Array2<Complex<F>> =
Array2::from_diag(&cos_eigs.iter().copied().collect::<Array1<_>>());
let sin_d: Array2<Complex<F>> =
Array2::from_diag(&sin_eigs.iter().copied().collect::<Array1<_>>());
let v_cos_d = eigenvecs.dot(&cos_d);
let v_sin_d = eigenvecs.dot(&sin_d);
let v_inv = complex_inv(&eigenvecs, n)?;
let cos_a_complex = v_cos_d.dot(&v_inv);
let sin_a_complex = v_sin_d.dot(&v_inv);
let mut cos_a = Array2::<F>::zeros((n, n));
let mut sin_a = Array2::<F>::zeros((n, n));
for i in 0..n {
for j in 0..n {
cos_a[[i, j]] = cos_a_complex[[i, j]].re;
sin_a[[i, j]] = sin_a_complex[[i, j]].re;
}
}
Ok((cos_a, sin_a))
}
use scirs2_core::ndarray::Array1;
/// Inverts an `n × n` complex matrix by Gauss–Jordan elimination with partial pivoting.
///
/// `m` is assumed to be `n × n` (not checked here). Each column pivots on the row
/// with the largest `|a[row, col]|²`. The pivot row is then normalised, and the
/// column is eliminated from every other row. The same row operations applied to
/// the identity produce the inverse.
///
/// # Errors
/// Returns `LinalgError::SingularMatrixError` when the best available pivot has
/// `|pivot|² < 1e-30`.
fn complex_inv<F: TrigFloat>(
    m: &Array2<scirs2_core::numeric::Complex<F>>,
    n: usize,
) -> LinalgResult<Array2<scirs2_core::numeric::Complex<F>>> {
    use scirs2_core::numeric::Complex;
    let zero = Complex::new(F::zero(), F::zero());
    let one = Complex::new(F::one(), F::zero());
    let singular_tol = F::from(1e-30).unwrap_or(F::epsilon());

    let mut a = m.to_owned();
    let mut inv = Array2::<Complex<F>>::zeros((n, n));
    for i in 0..n {
        inv[[i, i]] = one;
    }
    for col in 0..n {
        // Partial pivoting: choose the row with the largest magnitude in this column.
        let mut max_row = col;
        let mut max_val = a[[col, col]].norm_sqr();
        for row in (col + 1)..n {
            let v = a[[row, col]].norm_sqr();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        if max_val < singular_tol {
            return Err(LinalgError::SingularMatrixError(
                "sincos_via_eig: singular eigenvector matrix".to_string(),
            ));
        }
        if max_row != col {
            for j in 0..n {
                a.swap([col, j], [max_row, j]);
                inv.swap([col, j], [max_row, j]);
            }
        }
        let inv_pivot = a[[col, col]].inv();
        for j in 0..n {
            a[[col, j]] *= inv_pivot;
            inv[[col, j]] *= inv_pivot;
        }
        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = a[[row, col]];
            // Only exact zeros may be skipped. Leaving a small but nonzero entry
            // uneliminated silently corrupts the computed inverse.
            if factor == zero {
                continue;
            }
            for j in 0..n {
                let av = a[[col, j]] * factor;
                let iv = inv[[col, j]] * factor;
                a[[row, j]] -= av;
                inv[[row, j]] -= iv;
            }
        }
    }
    Ok(inv)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// All-zero 2×2 reference matrix.
    const ZERO_2X2: [[f64; 2]; 2] = [[0.0, 0.0], [0.0, 0.0]];
    /// 2×2 identity reference matrix.
    const IDENTITY_2X2: [[f64; 2]; 2] = [[1.0, 0.0], [0.0, 1.0]];

    /// Asserts every entry of a 2×2 matrix matches `expected` within absolute tolerance `eps`.
    fn assert_close_2x2(actual: &Array2<f64>, expected: [[f64; 2]; 2], eps: f64) {
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert_abs_diff_eq!(actual[[i, j]], want, epsilon = eps);
            }
        }
    }

    #[test]
    fn test_sinm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let s = sinm_schur(&a.view()).expect("sinm_schur zero");
        assert_close_2x2(&s, ZERO_2X2, 1e-12);
    }

    #[test]
    fn test_sinm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let s = sinm_schur(&a.view()).expect("sinm_schur diagonal");
        assert_abs_diff_eq!(s[[0, 0]], 0.5_f64.sin(), epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 1]], 1.0_f64.sin(), epsilon = 1e-10);
        assert!(s[[0, 1]].abs() < 1e-10);
        assert!(s[[1, 0]].abs() < 1e-10);
    }

    #[test]
    fn test_sinm_schur_nilpotent() {
        // For a nilpotent N with N² = 0, sin(N) = N.
        let t_val = 0.1_f64;
        let a = array![[0.0, t_val], [0.0, 0.0]];
        let s = sinm_schur(&a.view()).expect("sinm_schur nilpotent");
        assert_abs_diff_eq!(s[[0, 0]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[0, 1]], t_val, epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 0]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[1, 1]], 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_cosm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let c = cosm_schur(&a.view()).expect("cosm_schur zero");
        assert_close_2x2(&c, IDENTITY_2X2, 1e-12);
    }

    #[test]
    fn test_cosm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let c = cosm_schur(&a.view()).expect("cosm_schur diagonal");
        assert_abs_diff_eq!(c[[0, 0]], 0.5_f64.cos(), epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 1.0_f64.cos(), epsilon = 1e-10);
    }

    #[test]
    fn test_sin2_cos2_identity() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.7]];
        let sin_a = sinm_schur(&a.view()).expect("sinm");
        let cos_a = cosm_schur(&a.view()).expect("cosm");
        let total = &matmul_nn(&sin_a, &sin_a) + &matmul_nn(&cos_a, &cos_a);
        assert_close_2x2(&total, IDENTITY_2X2, 1e-10);
    }

    #[test]
    fn test_tanm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let t = tanm_schur(&a.view()).expect("tanm_schur zero");
        assert_close_2x2(&t, ZERO_2X2, 1e-12);
    }

    #[test]
    fn test_tanm_schur_diagonal() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.5]];
        let t = tanm_schur(&a.view()).expect("tanm_schur diagonal");
        assert_abs_diff_eq!(t[[0, 0]], 0.3_f64.tan(), epsilon = 1e-10);
        assert_abs_diff_eq!(t[[1, 1]], 0.5_f64.tan(), epsilon = 1e-10);
    }

    #[test]
    fn test_sinhm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let s = sinhm_schur(&a.view()).expect("sinhm_schur zero");
        assert_close_2x2(&s, ZERO_2X2, 1e-12);
    }

    #[test]
    fn test_sinhm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let s = sinhm_schur(&a.view()).expect("sinhm_schur diagonal");
        assert_abs_diff_eq!(s[[0, 0]], 0.5_f64.sinh(), epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 1]], 1.0_f64.sinh(), epsilon = 1e-10);
    }

    #[test]
    fn test_coshm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let c = coshm_schur(&a.view()).expect("coshm_schur zero");
        assert_abs_diff_eq!(c[[0, 0]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 1]], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_coshm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let c = coshm_schur(&a.view()).expect("coshm_schur diagonal");
        assert_abs_diff_eq!(c[[0, 0]], 0.5_f64.cosh(), epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 1.0_f64.cosh(), epsilon = 1e-10);
    }

    #[test]
    fn test_cosh2_sinh2_identity() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.7]];
        let sinh_a = sinhm_schur(&a.view()).expect("sinhm");
        let cosh_a = coshm_schur(&a.view()).expect("coshm");
        let difference = &matmul_nn(&cosh_a, &cosh_a) - &matmul_nn(&sinh_a, &sinh_a);
        assert_close_2x2(&difference, IDENTITY_2X2, 1e-10);
    }

    #[test]
    fn test_tanhm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let t = tanhm_schur(&a.view()).expect("tanhm_schur zero");
        assert_close_2x2(&t, ZERO_2X2, 1e-12);
    }

    #[test]
    fn test_tanhm_schur_diagonal() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.5]];
        let t = tanhm_schur(&a.view()).expect("tanhm_schur diagonal");
        assert_abs_diff_eq!(t[[0, 0]], 0.3_f64.tanh(), epsilon = 1e-10);
        assert_abs_diff_eq!(t[[1, 1]], 0.5_f64.tanh(), epsilon = 1e-10);
    }

    #[test]
    fn test_sincos_expm_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let (cos_a, sin_a) = sincos_expm(&a.view()).expect("sincos_expm failed");
        assert_abs_diff_eq!(cos_a[[0, 0]], 0.5_f64.cos(), epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[1, 1]], 1.0_f64.cos(), epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[0, 0]], 0.5_f64.sin(), epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[1, 1]], 1.0_f64.sin(), epsilon = 1e-10);
    }

    #[test]
    fn test_sincos_expm_rotation() {
        // A = [[0, θ], [-θ, 0]] satisfies A² = -θ²·I, hence cos(A) = cosh(θ)·I
        // and sin(A) = (sinh(θ)/θ)·A.
        let theta = 0.7_f64;
        let a = array![[0.0_f64, theta], [-theta, 0.0]];
        let (cos_a, sin_a) = sincos_expm(&a.view()).expect("sincos_expm rotation");
        assert_abs_diff_eq!(cos_a[[0, 0]], theta.cosh(), epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[1, 1]], theta.cosh(), epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[1, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[0, 1]], theta.sinh(), epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[1, 0]], -theta.sinh(), epsilon = 1e-10);
    }

    #[test]
    fn test_apply_schur_exp() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let exp_a = apply_schur(&a.view(), |x: f64| x.exp(), "exp").expect("apply_schur exp");
        assert_abs_diff_eq!(exp_a[[0, 0]], 0.5_f64.exp(), epsilon = 1e-10);
        assert_abs_diff_eq!(exp_a[[1, 1]], 1.0_f64.exp(), epsilon = 1e-10);
    }

    #[test]
    fn test_apply_schur_sqrt() {
        let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
        let sqrt_a = apply_schur(&a.view(), |x: f64| x.sqrt(), "sqrt").expect("apply_schur sqrt");
        assert_abs_diff_eq!(sqrt_a[[0, 0]], 2.0, epsilon = 1e-8);
        assert_abs_diff_eq!(sqrt_a[[1, 1]], 3.0, epsilon = 1e-8);
    }
}