use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign};
use std::iter::Sum;
/// Scalar bound shared by every matrix function in this module.
///
/// `TrigFloat` adds no methods of its own; it only bundles the required bounds
/// (`Float` arithmetic and transcendental functions, in-place `NumAssign`
/// operators, `Sum`, ndarray's `ScalarOperand`, and `Send + Sync`) under one
/// name. The blanket implementation that follows makes every type meeting those
/// bounds a `TrigFloat` automatically, so it never needs to be implemented by hand.
pub trait TrigFloat: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static {}

impl<T: Float + NumAssign + Sum + ScalarOperand + Send + Sync + 'static> TrigFloat for T {}
/// Dense matrix product `a · b`, used by the unit tests to check identities such
/// as `sin²(A) + cos²(A) = I`.
///
/// Compiled only for tests: nothing else in the module calls it, and leaving it
/// in regular builds triggers a `dead_code` warning.
///
/// # Panics
///
/// Panics if `a.ncols() != b.nrows()` (ndarray's `dot` shape check).
#[cfg(test)]
fn matmul_nn<F: TrigFloat>(a: &Array2<F>, b: &Array2<F>) -> Array2<F> {
    a.dot(b)
}
/// Evaluates `f(T)` for an upper-triangular matrix `T` with Parlett's recurrence.
///
/// Because `f(T)` commutes with `T`, equating entry `(i, j)` of `F·T` and `T·F`
/// for `i < j` gives
///
/// ```text
/// F[i,j]·(T[j,j] - T[i,i]) = T[i,j]·(F[j,j] - F[i,i])
///                          + sum_{i<k<j} (T[i,k]·F[k,j] - F[i,k]·T[k,j])
/// ```
///
/// The equation is solved column by column, walking each column upward. As a
/// result, every `F[i,k]` and `F[k,j]` on the right-hand side is already known.
///
/// When `T[i,i]` and `T[j,j]` (nearly) coincide, the equation no longer determines
/// `F[i,j]`. In that case the first-order limit `T[i,j]·f'(m)` is used instead:
/// `m` is the midpoint of the two diagonal entries, and `f'` is estimated by a
/// central difference of `scalar_fn`. The coupling sum is dropped for such pairs.
/// Consequently, a chain of three or more coalescing eigenvalues linked through
/// the strict upper triangle is only approximated to first order.
///
/// Only the diagonal and strict upper triangle of `t` are read.
///
/// * `t` - upper-triangular `n x n` matrix.
/// * `f_diag` - the scalar function applied to the diagonal, `f_diag[i] = f(t[i,i])`.
/// * `n` - dimension of `t`.
/// * `scalar_fn` - the scalar function itself, needed for the derivative estimate.
fn parlett_recurrence<F: TrigFloat>(
    t: &Array2<F>,
    f_diag: &[F],
    n: usize,
    scalar_fn: fn(F) -> F,
) -> Array2<F> {
    let two = F::one() + F::one();
    // Two diagonal entries closer than this (scaled by their magnitude, see
    // `scale` below) are treated as confluent.
    let confluence_tol = F::epsilon() * F::from(100.0).unwrap_or(F::one());
    // eps^(1/3) balances the O(h^2) truncation error of a central difference
    // against its O(eps/h) rounding error.
    let fd_step = F::epsilon().cbrt();

    let mut ft = Array2::<F>::zeros((n, n));
    for i in 0..n {
        ft[[i, i]] = f_diag[i];
    }
    for j in 1..n {
        for i in (0..j).rev() {
            let tii = t[[i, i]];
            let tjj = t[[j, j]];
            let tij = t[[i, j]];
            let denom = tjj - tii;
            let scale = F::one().max(tii.abs()).max(tjj.abs());
            let value = if denom.abs() <= confluence_tol * scale {
                // Confluent pair: the divided difference f[t_ii, t_jj] tends to f'.
                let mid = (tii + tjj) / two;
                let h = fd_step * F::one().max(mid.abs());
                let (upper, lower) = (mid + h, mid - h);
                // Divide by the spacing actually realised in floating point rather
                // than by 2h, so rounding in `mid ± h` does not bias the estimate.
                let deriv = (scalar_fn(upper) - scalar_fn(lower)) / (upper - lower);
                tij * deriv
            } else {
                let mut numer = tij * (ft[[j, j]] - ft[[i, i]]);
                for k in (i + 1)..j {
                    numer += t[[i, k]] * ft[[k, j]] - ft[[i, k]] * t[[k, j]];
                }
                numer / denom
            };
            ft[[i, j]] = value;
        }
    }
    ft
}

/// Lifts `scalar_fn` to a matrix function of the square matrix `a`.
///
/// `crate::decomposition::schur` supplies `(Q, T)`, `f(T)` is built with
/// [`parlett_recurrence`], and the result is returned as `Q·f(T)·Qᵀ`. Empty and
/// `1 x 1` inputs are handled directly without a decomposition.
///
/// `name` prefixes the error message returned for non-square input.
///
/// NOTE(review): the recurrence treats `T` as upper triangular and never reads
/// its subdiagonal. If `schur` returns a real quasi-triangular form (2x2
/// diagonal blocks for complex-conjugate eigenvalue pairs), the result will not
/// be `f(a)` for such matrices. Confirm against that routine's contract.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` if `a` is not square and propagates any
/// error from the Schur decomposition.
fn schur_function<F: TrigFloat>(
    a: &ArrayView2<F>,
    scalar_fn: fn(F) -> F,
    name: &str,
) -> LinalgResult<Array2<F>> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError(format!(
            "{name}: matrix must be square"
        )));
    }
    if n == 0 {
        return Ok(Array2::<F>::zeros((0, 0)));
    }
    if n == 1 {
        return Ok(Array2::from_elem((1, 1), scalar_fn(a[[0, 0]])));
    }
    let (q, t) = crate::decomposition::schur(a)?;
    let f_diag: Vec<F> = (0..n).map(|i| scalar_fn(t[[i, i]])).collect();
    let ft = parlett_recurrence(&t, &f_diag, n, scalar_fn);
    Ok(q.dot(&ft).dot(&q.t()))
}
pub fn sinm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
schur_function(a, |x: F| x.sin(), "sinm_schur")
}
pub fn cosm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
schur_function(a, |x: F| x.cos(), "cosm_schur")
}
/// Matrix tangent `tan(A) = cos(A)⁻¹·sin(A)`.
///
/// `sin(A)` and `cos(A)` are each computed with their own Schur evaluation (two
/// decompositions of `a`). The result is then obtained from
/// `crate::solve::solve_multiple`, which presumably solves `cos(A)·X = sin(A)`
/// for `X`; confirm against that routine.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` if `a` is not square. Errors from the Schur
/// evaluations and from the solve are propagated.
pub fn tanm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError("tanm_schur: matrix must be square".into()));
    }
    let numerator = sinm_schur(a)?;
    let denominator = cosm_schur(a)?;
    crate::solve::solve_multiple(&denominator.view(), &numerator.view(), None)
}
pub fn sinhm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
schur_function(a, |x: F| x.sinh(), "sinhm_schur")
}
pub fn coshm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
schur_function(a, |x: F| x.cosh(), "coshm_schur")
}
/// Matrix hyperbolic tangent `tanh(A) = cosh(A)⁻¹·sinh(A)`.
///
/// `sinh(A)` and `cosh(A)` are each computed with their own Schur evaluation.
/// The result is then obtained from `crate::solve::solve_multiple`, which
/// presumably solves `cosh(A)·X = sinh(A)` for `X`; confirm against that routine.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` if `a` is not square. Errors from the Schur
/// evaluations and from the solve are propagated.
pub fn tanhm_schur<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<Array2<F>> {
    if a.nrows() != a.ncols() {
        return Err(LinalgError::ShapeError("tanhm_schur: matrix must be square".into()));
    }
    let numerator = sinhm_schur(a)?;
    let denominator = coshm_schur(a)?;
    crate::solve::solve_multiple(&denominator.view(), &numerator.view(), None)
}
/// Applies an arbitrary scalar function `f` to the square matrix `a`.
///
/// This is the public entry point to the Schur-based machinery behind
/// [`sinm_schur`], [`cosm_schur`] and the other functions in this module.
/// `name` labels the error message returned for non-square input.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` if `a` is not square and propagates any
/// error from the Schur decomposition.
pub fn apply_schur<F: TrigFloat>(
    a: &ArrayView2<F>,
    f: fn(F) -> F,
    name: &str,
) -> LinalgResult<Array2<F>> {
    let f_of_a = schur_function(a, f, name)?;
    Ok(f_of_a)
}
/// Computes `(cos(A), sin(A))` together from a single matrix exponential.
///
/// Take the `2n x 2n` block matrix `M = [[0, -A], [A, 0]]`. Then
/// `M² = diag(-A², -A²)`, and splitting the exponential series into even and
/// odd powers gives
///
/// ```text
/// exp(M) = [[cos(A), -sin(A)],
///           [sin(A),  cos(A)]]
/// ```
///
/// `cos(A)` is therefore read from the top-left block and `sin(A)` from the
/// bottom-left block. The top-right block holds `-sin(A)`. This route does not
/// use a triangular Schur form.
///
/// Returns the tuple `(cos(A), sin(A))`.
///
/// # Errors
///
/// Returns `LinalgError::ShapeError` if `a` is not square and propagates any
/// error from the Padé matrix exponential.
pub fn sincos_expm<F: TrigFloat>(a: &ArrayView2<F>) -> LinalgResult<(Array2<F>, Array2<F>)> {
    let n = a.nrows();
    if a.ncols() != n {
        return Err(LinalgError::ShapeError("sincos_expm: matrix must be square".into()));
    }
    // M = [[0, -A], [A, 0]]
    let aug = Array2::from_shape_fn((2 * n, 2 * n), |(r, c)| match (r < n, c < n) {
        (true, false) => -a[[r, c - n]],
        (false, true) => a[[r - n, c]],
        _ => F::zero(),
    });
    let exp_aug = crate::matrix_functions::pade::pade_expm(&aug.view())?;
    let cos_a = Array2::from_shape_fn((n, n), |(i, j)| exp_aug[[i, j]]);
    // Bottom-left block is +sin(A); reading the top-right block would yield -sin(A).
    let sin_a = Array2::from_shape_fn((n, n), |(i, j)| exp_aug[[i + n, j]]);
    Ok((cos_a, sin_a))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_sinm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let s = sinm_schur(&a.view()).expect("sinm_schur zero");
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(s[[i, j]], 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_sinm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let s = sinm_schur(&a.view()).expect("sinm_schur diagonal");
        assert_abs_diff_eq!(s[[0, 0]], 0.5_f64.sin(), epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 1]], 1.0_f64.sin(), epsilon = 1e-10);
        assert!(s[[0, 1]].abs() < 1e-10);
        assert!(s[[1, 0]].abs() < 1e-10);
    }

    #[test]
    fn test_sinm_schur_nilpotent() {
        // N² = 0, so sin(N) = N exactly (repeated eigenvalue 0).
        let t_val = 0.1_f64;
        let a = array![[0.0, t_val], [0.0, 0.0]];
        let s = sinm_schur(&a.view()).expect("sinm_schur nilpotent");
        assert_abs_diff_eq!(s[[0, 0]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[0, 1]], t_val, epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 0]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(s[[1, 1]], 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_sinm_schur_upper_triangular() {
        // Non-normal matrix with distinct eigenvalues: for T = [[p, b], [0, q]],
        // f(T)[0][1] = b·(f(q) - f(p)) / (q - p).
        let a = array![[0.5_f64, 1.0], [0.0, 1.5]];
        let s = sinm_schur(&a.view()).expect("sinm_schur upper triangular");
        let expected_01 = (1.5_f64.sin() - 0.5_f64.sin()) / (1.5 - 0.5);
        assert_abs_diff_eq!(s[[0, 0]], 0.5_f64.sin(), epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 1]], 1.5_f64.sin(), epsilon = 1e-10);
        assert_abs_diff_eq!(s[[0, 1]], expected_01, epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 0]], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_cosm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let c = cosm_schur(&a.view()).expect("cosm_schur zero");
        assert_abs_diff_eq!(c[[0, 0]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 1]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[0, 1]], 0.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 0]], 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_cosm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let c = cosm_schur(&a.view()).expect("cosm_schur diagonal");
        assert_abs_diff_eq!(c[[0, 0]], 0.5_f64.cos(), epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 1.0_f64.cos(), epsilon = 1e-10);
    }

    #[test]
    fn test_sin2_cos2_identity() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.7]];
        let sin_a = sinm_schur(&a.view()).expect("sinm");
        let cos_a = cosm_schur(&a.view()).expect("cosm");
        let s2 = matmul_nn(&sin_a, &sin_a);
        let c2 = matmul_nn(&cos_a, &cos_a);
        for i in 0..2 {
            for j in 0..2 {
                let sum = s2[[i, j]] + c2[[i, j]];
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(sum, expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_tanm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let t = tanm_schur(&a.view()).expect("tanm_schur zero");
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(t[[i, j]], 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_tanm_schur_diagonal() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.5]];
        let t = tanm_schur(&a.view()).expect("tanm_schur diagonal");
        assert_abs_diff_eq!(t[[0, 0]], 0.3_f64.tan(), epsilon = 1e-10);
        assert_abs_diff_eq!(t[[1, 1]], 0.5_f64.tan(), epsilon = 1e-10);
    }

    #[test]
    fn test_sinhm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let s = sinhm_schur(&a.view()).expect("sinhm_schur zero");
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(s[[i, j]], 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_sinhm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let s = sinhm_schur(&a.view()).expect("sinhm_schur diagonal");
        assert_abs_diff_eq!(s[[0, 0]], 0.5_f64.sinh(), epsilon = 1e-10);
        assert_abs_diff_eq!(s[[1, 1]], 1.0_f64.sinh(), epsilon = 1e-10);
    }

    #[test]
    fn test_coshm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let c = coshm_schur(&a.view()).expect("coshm_schur zero");
        assert_abs_diff_eq!(c[[0, 0]], 1.0, epsilon = 1e-12);
        assert_abs_diff_eq!(c[[1, 1]], 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_coshm_schur_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let c = coshm_schur(&a.view()).expect("coshm_schur diagonal");
        assert_abs_diff_eq!(c[[0, 0]], 0.5_f64.cosh(), epsilon = 1e-10);
        assert_abs_diff_eq!(c[[1, 1]], 1.0_f64.cosh(), epsilon = 1e-10);
    }

    #[test]
    fn test_cosh2_sinh2_identity() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.7]];
        let sinh_a = sinhm_schur(&a.view()).expect("sinhm");
        let cosh_a = coshm_schur(&a.view()).expect("coshm");
        let c2 = matmul_nn(&cosh_a, &cosh_a);
        let s2 = matmul_nn(&sinh_a, &sinh_a);
        for i in 0..2 {
            for j in 0..2 {
                let diff = c2[[i, j]] - s2[[i, j]];
                let expected = if i == j { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(diff, expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_tanhm_schur_zero() {
        let a = array![[0.0_f64, 0.0], [0.0, 0.0]];
        let t = tanhm_schur(&a.view()).expect("tanhm_schur zero");
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(t[[i, j]], 0.0, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_tanhm_schur_diagonal() {
        let a = array![[0.3_f64, 0.0], [0.0, 0.5]];
        let t = tanhm_schur(&a.view()).expect("tanhm_schur diagonal");
        assert_abs_diff_eq!(t[[0, 0]], 0.3_f64.tanh(), epsilon = 1e-10);
        assert_abs_diff_eq!(t[[1, 1]], 0.5_f64.tanh(), epsilon = 1e-10);
    }

    #[test]
    fn test_sincos_expm_diagonal() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let (cos_a, sin_a) = sincos_expm(&a.view()).expect("sincos_expm failed");
        assert_abs_diff_eq!(cos_a[[0, 0]], 0.5_f64.cos(), epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[1, 1]], 1.0_f64.cos(), epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[0, 0]], 0.5_f64.sin(), epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[1, 1]], 1.0_f64.sin(), epsilon = 1e-10);
    }

    #[test]
    fn test_sincos_expm_rotation() {
        // A = θ·J with J = [[0, 1], [-1, 0]] and J² = -I. The power series then give
        // cos(A) = cosh(θ)·I and sin(A) = sinh(θ)·J. (cos θ and sin θ are the
        // entries of exp(A), not of cos(A) and sin(A).)
        let theta = 0.7_f64;
        let a = array![[0.0_f64, theta], [-theta, 0.0]];
        let (cos_a, sin_a) = sincos_expm(&a.view()).expect("sincos_expm rotation");
        assert_abs_diff_eq!(cos_a[[0, 0]], theta.cosh(), epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[1, 1]], theta.cosh(), epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(cos_a[[1, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[1, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[0, 1]], theta.sinh(), epsilon = 1e-10);
        assert_abs_diff_eq!(sin_a[[1, 0]], -theta.sinh(), epsilon = 1e-10);
    }

    #[test]
    fn test_sincos_expm_matches_schur_non_normal() {
        // The exponential-based and Schur-based routes must agree on a non-normal input.
        let a = array![[0.5_f64, 1.0], [0.0, 1.5]];
        let (cos_e, sin_e) = sincos_expm(&a.view()).expect("sincos_expm non-normal");
        let cos_s = cosm_schur(&a.view()).expect("cosm_schur non-normal");
        let sin_s = sinm_schur(&a.view()).expect("sinm_schur non-normal");
        for i in 0..2 {
            for j in 0..2 {
                assert_abs_diff_eq!(cos_e[[i, j]], cos_s[[i, j]], epsilon = 1e-10);
                assert_abs_diff_eq!(sin_e[[i, j]], sin_s[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_apply_schur_exp() {
        let a = array![[0.5_f64, 0.0], [0.0, 1.0]];
        let exp_a = apply_schur(&a.view(), |x: f64| x.exp(), "exp").expect("apply_schur exp");
        assert_abs_diff_eq!(exp_a[[0, 0]], 0.5_f64.exp(), epsilon = 1e-10);
        assert_abs_diff_eq!(exp_a[[1, 1]], 1.0_f64.exp(), epsilon = 1e-10);
    }

    #[test]
    fn test_apply_schur_sqrt() {
        let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
        let sqrt_a = apply_schur(&a.view(), |x: f64| x.sqrt(), "sqrt").expect("apply_schur sqrt");
        assert_abs_diff_eq!(sqrt_a[[0, 0]], 2.0, epsilon = 1e-8);
        assert_abs_diff_eq!(sqrt_a[[1, 1]], 3.0, epsilon = 1e-8);
    }
}