use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::fmt::Debug;
use std::iter::Sum;
use crate::error::{LinalgError, LinalgResult};
/// Selects which matrix function `f` the Krylov routines should apply.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MatrixFunctionType {
    /// exp(A)
    Exponential,
    /// sqrt(A)
    SquareRoot,
    /// log(A)
    Logarithm,
    /// Matrix sign function sign(A)
    Sign,
    /// exp(t·A); the payload is the scale factor t applied before exponentiation
    ExponentialScaled(f64),
    /// cos(A)
    Cosine,
    /// sin(A)
    Sine,
}
/// Tuning parameters for the Krylov matrix-function routines.
#[derive(Debug, Clone)]
pub struct MatrixFunctionParams {
    /// Which matrix function to apply.
    pub function: MatrixFunctionType,
    /// Convergence tolerance (used by the a-posteriori error estimate).
    pub tol: f64,
    /// Upper bound on the Krylov-subspace dimension.
    pub max_krylov_dim: usize,
    /// Use the symmetric Lanczos process instead of Arnoldi.
    pub use_lanczos: bool,
    /// Optional restart dimension.
    /// NOTE(review): not consulted by any routine in this file — presumably
    /// intended for a restarted variant; confirm against callers.
    pub restart_dim: Option<usize>,
}
impl Default for MatrixFunctionParams {
    /// Defaults: matrix exponential, tolerance `1e-10`, Krylov dimension 50,
    /// Arnoldi process (no Lanczos), no restarting.
    fn default() -> Self {
        Self {
            restart_dim: None,
            use_lanczos: false,
            max_krylov_dim: 50,
            tol: 1e-10,
            function: MatrixFunctionType::Exponential,
        }
    }
}
impl MatrixFunctionParams {
    /// Parameters for `function` with every other field at its default.
    pub fn new(function: MatrixFunctionType) -> Self {
        let mut params = Self::default();
        params.function = function;
        params
    }

    /// Override the convergence tolerance.
    pub fn with_tol(self, tol: f64) -> Self {
        Self { tol, ..self }
    }

    /// Override the maximum Krylov-subspace dimension.
    pub fn with_max_krylov_dim(self, dim: usize) -> Self {
        Self { max_krylov_dim: dim, ..self }
    }

    /// Select the symmetric Lanczos process instead of Arnoldi.
    pub fn with_lanczos(self) -> Self {
        Self { use_lanczos: true, ..self }
    }

    /// Enable restarting with the given subspace dimension.
    pub fn with_restart(self, restart_dim: usize) -> Self {
        Self { restart_dim: Some(restart_dim), ..self }
    }
}
/// Output of [`ArnoldiIteration::run`].
#[derive(Debug, Clone)]
pub struct ArnoldiResult {
    /// Orthonormal Krylov basis, shape (n, m).
    pub v: Array2<f64>,
    /// Extended upper-Hessenberg projection, shape (m+1, m).
    pub h: Array2<f64>,
    /// Number of Krylov vectors actually built.
    pub m: usize,
    /// True when the iteration stopped early because the subspace became invariant.
    pub happy_breakdown: bool,
    /// Final subdiagonal entry h[m, m-1] (set to 0.0 on happy breakdown).
    pub residual_norm: f64,
}
/// Namespace type for the Arnoldi process.
pub struct ArnoldiIteration;

impl ArnoldiIteration {
    /// Build an orthonormal basis of the Krylov subspace
    /// span{v0, A·v0, A²·v0, ...} together with the extended upper-Hessenberg
    /// projection H.
    ///
    /// Orthogonalization is done with Gram–Schmidt plus one full
    /// re-orthogonalization pass per step ("twice is enough").
    ///
    /// # Errors
    /// - `ShapeError` if `a` is not square or `max_dim == 0`.
    /// - `DimensionError` if `v0` does not match `a`'s size.
    /// - `InvalidInputError` if `v0` has (near-)zero norm.
    pub fn run(
        a: &ArrayView2<f64>,
        v0: &ArrayView1<f64>,
        max_dim: usize,
        tol: Option<f64>,
    ) -> LinalgResult<ArnoldiResult> {
        let n = a.nrows();
        if a.ncols() != n {
            return Err(LinalgError::ShapeError(
                "ArnoldiIteration: A must be square".to_string(),
            ));
        }
        if v0.len() != n {
            return Err(LinalgError::DimensionError(format!(
                "ArnoldiIteration: v0 has length {} but A is {}×{}",
                v0.len(),
                n,
                n
            )));
        }
        if max_dim == 0 {
            return Err(LinalgError::ShapeError(
                "ArnoldiIteration: max_dim must be positive".to_string(),
            ));
        }
        let tolerance = tol.unwrap_or(1e-14);
        // The Krylov dimension can never exceed the problem size.
        let m = max_dim.min(n);
        let v0_norm = v0.iter().map(|&v| v * v).sum::<f64>().sqrt();
        if v0_norm < tolerance {
            return Err(LinalgError::InvalidInputError(
                "ArnoldiIteration: starting vector is zero".to_string(),
            ));
        }
        // First basis vector: normalized v0.
        let mut v_mat = Array2::zeros((n, m + 1));
        for i in 0..n {
            v_mat[[i, 0]] = v0[i] / v0_norm;
        }
        let mut h_mat = Array2::zeros((m + 1, m));
        let mut actual_m = 0;
        let mut happy = false;
        let mut res_norm = 0.0;
        for j in 0..m {
            // w = A * v_j (dense matrix-vector product).
            let mut w = Array1::zeros(n);
            for i in 0..n {
                let mut val = 0.0f64;
                for k in 0..n {
                    val += a[[i, k]] * v_mat[[k, j]];
                }
                w[i] = val;
            }
            // Gram–Schmidt against all previous basis vectors; projections
            // become column j of H.
            for i in 0..=j {
                let h_ij = (0..n).map(|k| w[k] * v_mat[[k, i]]).sum::<f64>();
                h_mat[[i, j]] = h_ij;
                for k in 0..n {
                    w[k] -= h_ij * v_mat[[k, i]];
                }
            }
            // Second orthogonalization pass to repair rounding-induced loss of
            // orthogonality; the corrections are folded into H.
            for i in 0..=j {
                let correction = (0..n).map(|k| w[k] * v_mat[[k, i]]).sum::<f64>();
                h_mat[[i, j]] += correction;
                for k in 0..n {
                    w[k] -= correction * v_mat[[k, i]];
                }
            }
            let h_next = w.iter().map(|&v| v * v).sum::<f64>().sqrt();
            h_mat[[j + 1, j]] = h_next;
            actual_m = j + 1;
            // Happy breakdown: the subspace is A-invariant; stop early.
            if h_next < tolerance {
                happy = true;
                res_norm = 0.0;
                break;
            }
            res_norm = h_next;
            if j + 1 < m {
                for k in 0..n {
                    v_mat[[k, j + 1]] = w[k] / h_next;
                }
            }
        }
        // Trim the workspace to the dimension actually reached.
        let v_out = v_mat.slice(scirs2_core::ndarray::s![.., ..actual_m]).to_owned();
        let h_out = h_mat
            .slice(scirs2_core::ndarray::s![..actual_m + 1, ..actual_m])
            .to_owned();
        Ok(ArnoldiResult {
            v: v_out,
            h: h_out,
            m: actual_m,
            happy_breakdown: happy,
            residual_norm: res_norm,
        })
    }
}
/// Output of [`LanczosIteration::run`].
#[derive(Debug, Clone)]
pub struct LanczosResult {
    /// Orthonormal Krylov basis, shape (n, m).
    pub v: Array2<f64>,
    /// Diagonal of the projected tridiagonal matrix, length m.
    pub alpha: Array1<f64>,
    /// Off-diagonal of the projected tridiagonal matrix, length m-1.
    pub beta: Array1<f64>,
    /// Number of Krylov vectors actually built.
    pub m: usize,
    /// True when the iteration stopped early (invariant subspace reached).
    pub breakdown: bool,
}
/// Namespace type for the symmetric Lanczos process.
pub struct LanczosIteration;

impl LanczosIteration {
    /// Run the Lanczos three-term recurrence starting from `v0` for at most
    /// `max_dim` steps, with a full re-orthogonalization pass per step.
    ///
    /// NOTE(review): the three-term recurrence is only valid for symmetric
    /// `a`; symmetry is not checked here.
    ///
    /// # Errors
    /// - `ShapeError` if `a` is not square.
    /// - `DimensionError` if `v0` does not match `a`'s size.
    /// - `InvalidInputError` if `v0` has (near-)zero norm.
    pub fn run(
        a: &ArrayView2<f64>,
        v0: &ArrayView1<f64>,
        max_dim: usize,
        tol: Option<f64>,
    ) -> LinalgResult<LanczosResult> {
        let n = a.nrows();
        if a.ncols() != n {
            return Err(LinalgError::ShapeError(
                "LanczosIteration: A must be square".to_string(),
            ));
        }
        if v0.len() != n {
            return Err(LinalgError::DimensionError(format!(
                "LanczosIteration: v0 has length {} but n={}",
                v0.len(),
                n
            )));
        }
        let tolerance = tol.unwrap_or(1e-14);
        let m = max_dim.min(n);
        let v0_norm = v0.iter().map(|&v| v * v).sum::<f64>().sqrt();
        if v0_norm < tolerance {
            return Err(LinalgError::InvalidInputError(
                "LanczosIteration: starting vector is zero".to_string(),
            ));
        }
        // First basis vector: normalized v0.
        let mut v_mat = Array2::zeros((n, m + 1));
        for i in 0..n {
            v_mat[[i, 0]] = v0[i] / v0_norm;
        }
        let mut alpha_vec = vec![0.0f64; m];
        let mut beta_vec = vec![0.0f64; m];
        let mut beta_prev = 0.0f64;
        let mut actual_m = 0;
        let mut breakdown = false;
        for j in 0..m {
            // w = A * v_j (dense matrix-vector product).
            let mut w = Array1::zeros(n);
            for i in 0..n {
                let mut val = 0.0f64;
                for k in 0..n {
                    val += a[[i, k]] * v_mat[[k, j]];
                }
                w[i] = val;
            }
            // Three-term recurrence: subtract beta_{j-1} * v_{j-1} ...
            if j > 0 {
                for k in 0..n {
                    w[k] -= beta_prev * v_mat[[k, j - 1]];
                }
            }
            // ... and alpha_j * v_j, where alpha_j = v_j^T (A v_j - beta_{j-1} v_{j-1}).
            let alpha_j = (0..n).map(|k| v_mat[[k, j]] * w[k]).sum::<f64>();
            alpha_vec[j] = alpha_j;
            for k in 0..n {
                w[k] -= alpha_j * v_mat[[k, j]];
            }
            // Full re-orthogonalization against all previous vectors to fight
            // the loss of orthogonality typical of plain Lanczos.
            for i in 0..=j {
                let corr = (0..n).map(|k| w[k] * v_mat[[k, i]]).sum::<f64>();
                for k in 0..n {
                    w[k] -= corr * v_mat[[k, i]];
                }
            }
            let beta_j = w.iter().map(|&v| v * v).sum::<f64>().sqrt();
            actual_m = j + 1;
            // Breakdown: the Krylov subspace is invariant; stop early.
            // (beta_vec[j] intentionally stays 0 in this case.)
            if beta_j < tolerance {
                breakdown = true;
                break;
            }
            beta_vec[j] = beta_j;
            beta_prev = beta_j;
            if j + 1 < m {
                for k in 0..n {
                    v_mat[[k, j + 1]] = w[k] / beta_j;
                }
            }
        }
        // Trim outputs to the dimension actually reached.
        let v_out = v_mat.slice(scirs2_core::ndarray::s![.., ..actual_m]).to_owned();
        let alpha_out = Array1::from_vec(alpha_vec[..actual_m].to_vec());
        let beta_out = if actual_m > 1 {
            Array1::from_vec(beta_vec[..actual_m - 1].to_vec())
        } else {
            Array1::zeros(0)
        };
        Ok(LanczosResult {
            v: v_out,
            alpha: alpha_out,
            beta: beta_out,
            m: actual_m,
            breakdown,
        })
    }

    /// Assemble the m×m symmetric tridiagonal matrix T with `alpha` on the
    /// diagonal and `beta` on the sub/super-diagonals.
    pub fn tridiagonal_matrix(result: &LanczosResult) -> Array2<f64> {
        let m = result.m;
        let mut t = Array2::zeros((m, m));
        for i in 0..m {
            t[[i, i]] = result.alpha[i];
            if i + 1 < m {
                let beta = result.beta[i];
                t[[i, i + 1]] = beta;
                t[[i + 1, i]] = beta;
            }
        }
        t
    }
}
/// Approximates `f(A) * v` by projecting onto a Krylov subspace and
/// evaluating `f` on the small projected matrix.
pub struct MatrixFunctionInterpolation;

impl MatrixFunctionInterpolation {
    /// Approximate `f(A) * v` for the function selected in `params`.
    ///
    /// A (near-)zero `v` short-circuits to the zero vector. Dispatches to the
    /// Lanczos path when `params.use_lanczos` is set (appropriate for
    /// symmetric `a`), otherwise to the Arnoldi path.
    pub fn apply(
        a: &ArrayView2<f64>,
        v: &ArrayView1<f64>,
        params: &MatrixFunctionParams,
    ) -> LinalgResult<Array1<f64>> {
        let n = a.nrows();
        if a.ncols() != n {
            return Err(LinalgError::ShapeError(
                "MatrixFunctionInterpolation: A must be square".to_string(),
            ));
        }
        if v.len() != n {
            return Err(LinalgError::DimensionError(format!(
                "MatrixFunctionInterpolation: v has length {} but n={}",
                v.len(),
                n
            )));
        }
        let v_norm = v.iter().map(|&vi| vi * vi).sum::<f64>().sqrt();
        if v_norm < 1e-300 {
            return Ok(Array1::zeros(n));
        }
        match params.use_lanczos {
            true => Self::apply_lanczos(a, v, v_norm, params),
            false => Self::apply_arnoldi(a, v, v_norm, params),
        }
    }

    /// Arnoldi path: result = ||v|| · V_m · f(H_m) · e_1.
    fn apply_arnoldi(
        a: &ArrayView2<f64>,
        v: &ArrayView1<f64>,
        v_norm: f64,
        params: &MatrixFunctionParams,
    ) -> LinalgResult<Array1<f64>> {
        let krylov = ArnoldiIteration::run(a, v, params.max_krylov_dim, Some(1e-14))?;
        let dim = krylov.m;
        // Square leading block H_m of the extended Hessenberg matrix.
        let h_small = krylov
            .h
            .slice(scirs2_core::ndarray::s![..dim, ..dim])
            .to_owned();
        let f_h = apply_dense_function(&h_small.view(), params)?;
        let n = v.len();
        let mut out = Array1::zeros(n);
        // out = v_norm * V * (first column of f(H_m)), i.e. V f(H_m) e_1.
        for row in 0..n {
            let mut acc = 0.0f64;
            for col in 0..dim {
                acc += krylov.v[[row, col]] * f_h[[col, 0]];
            }
            out[row] = v_norm * acc;
        }
        Ok(out)
    }

    /// Lanczos path: result = ||v|| · V_m · f(T_m) · e_1.
    fn apply_lanczos(
        a: &ArrayView2<f64>,
        v: &ArrayView1<f64>,
        v_norm: f64,
        params: &MatrixFunctionParams,
    ) -> LinalgResult<Array1<f64>> {
        let krylov = LanczosIteration::run(a, v, params.max_krylov_dim, Some(1e-14))?;
        let dim = krylov.m;
        let tridiag = LanczosIteration::tridiagonal_matrix(&krylov);
        let f_t = apply_dense_function(&tridiag.view(), params)?;
        let n = v.len();
        let mut out = Array1::zeros(n);
        // out = v_norm * V * (first column of f(T_m)), i.e. V f(T_m) e_1.
        for row in 0..n {
            let mut acc = 0.0f64;
            for col in 0..dim {
                acc += krylov.v[[row, col]] * f_t[[col, 0]];
            }
            out[row] = v_norm * acc;
        }
        Ok(out)
    }
}
/// Computes `exp(t·A) * v` via Krylov-subspace projection.
pub struct MatrixExpKrylov;

impl MatrixExpKrylov {
    /// Approximate `exp(t * A) * v`.
    ///
    /// Projects A onto a Krylov subspace of dimension at most
    /// `max_krylov_dim` (default `min(50, n)`), exponentiates the small
    /// projected matrix, and lifts the result back. If a cheap a-posteriori
    /// error estimate exceeds `tol` (default 1e-10), retries once with a
    /// doubled subspace dimension.
    ///
    /// # Errors
    /// - `ShapeError` if `a` is not square.
    /// - `DimensionError` if `v` does not match `a`'s size.
    pub fn apply(
        a: &ArrayView2<f64>,
        v: &ArrayView1<f64>,
        t: f64,
        max_krylov_dim: Option<usize>,
        tol: Option<f64>,
    ) -> LinalgResult<Array1<f64>> {
        let n = a.nrows();
        if a.ncols() != n {
            return Err(LinalgError::ShapeError(
                "MatrixExpKrylov: A must be square".to_string(),
            ));
        }
        if v.len() != n {
            return Err(LinalgError::DimensionError(format!(
                "MatrixExpKrylov: v has length {} but n={}",
                v.len(),
                n
            )));
        }
        let m = max_krylov_dim.unwrap_or_else(|| 50.min(n));
        let tolerance = tol.unwrap_or(1e-10);
        let v_norm = v.iter().map(|&vi| vi * vi).sum::<f64>().sqrt();
        if v_norm < 1e-300 {
            return Ok(Array1::zeros(n));
        }
        if t == 0.0 {
            // exp(0) = I, so the result is v itself.
            return Ok(v.to_owned());
        }
        if n <= m + 5 {
            // Small problem: the Krylov subspace (nearly) spans the full
            // space, so delegate directly to the interpolation path.
            // BUGFIX: was the corrupted token `¶ms` instead of `&params`.
            let params = MatrixFunctionParams::new(MatrixFunctionType::ExponentialScaled(t))
                .with_max_krylov_dim(n.min(m))
                .with_tol(tolerance);
            return MatrixFunctionInterpolation::apply(a, v, &params);
        }
        let arnoldi_result = ArnoldiIteration::run(a, v, m, Some(1e-14))?;
        let m_actual = arnoldi_result.m;
        let h_core = arnoldi_result
            .h
            .slice(scirs2_core::ndarray::s![..m_actual, ..m_actual])
            .to_owned();
        // exp(t*A) v ≈ ||v|| * V_m exp(t*H_m) e_1.
        let th = h_core.mapv(|h_ij| h_ij * t);
        let exp_th = pade_expm(&th.view())?;
        let mut result = Array1::zeros(n);
        for i in 0..n {
            let mut val = 0.0f64;
            for j in 0..m_actual {
                val += arnoldi_result.v[[i, j]] * exp_th[[j, 0]];
            }
            result[i] = v_norm * val;
        }
        // Heuristic error estimate from the last subdiagonal residual and the
        // last entry of the projected exponential's first column.
        let err_est = if m_actual >= 1 {
            let h_residual = arnoldi_result.residual_norm;
            let exp_last = exp_th[[m_actual - 1, 0]].abs();
            v_norm * h_residual * exp_last * t.abs()
        } else {
            0.0
        };
        if err_est > tolerance && m_actual < n {
            // Retry once with a doubled Krylov dimension.
            let m2 = (m_actual * 2).min(n);
            let params = MatrixFunctionParams::new(MatrixFunctionType::ExponentialScaled(t))
                .with_max_krylov_dim(m2)
                .with_tol(tolerance);
            return MatrixFunctionInterpolation::apply(a, v, &params);
        }
        Ok(result)
    }
}
fn apply_dense_function(
h: &ArrayView2<f64>,
params: &MatrixFunctionParams,
) -> LinalgResult<Array2<f64>> {
match params.function {
MatrixFunctionType::Exponential => pade_expm(h),
MatrixFunctionType::ExponentialScaled(t) => {
let ht = h.mapv(|v| v * t);
pade_expm(&ht.view())
}
MatrixFunctionType::SquareRoot => dense_sqrtm(h),
MatrixFunctionType::Logarithm => dense_logm(h),
MatrixFunctionType::Sign => dense_signm(h),
MatrixFunctionType::Cosine => dense_cosm(h),
MatrixFunctionType::Sine => dense_sinm(h),
}
}
/// Dense matrix exponential via a [7/7] diagonal Padé approximant with
/// scaling and squaring: scale A by 2^-s until its 1-norm is below θ₇,
/// evaluate r(A) = (V - U)^{-1}(V + U), then square the result s times.
///
/// BUGFIX: the previous coefficient set (1, 1/2, 1/12, 1/120, 1/720,
/// 1/30240, 1/1209600) did not form a Padé approximant of e^x with this
/// odd/even split — it effectively approximated e^(2x) for any matrix with
/// n >= 2 (the n == 1 fast path masked the defect in the existing tests).
/// Replaced with the correct normalized [7/7] coefficients and the matching
/// scaling threshold θ₇ instead of θ₁₃.
pub fn pade_expm(a: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
    let n = a.nrows();
    if n == 0 {
        return Ok(Array2::zeros((0, 0)));
    }
    if n == 1 {
        // Scalar fast path: exact exponential.
        let mut r = Array2::zeros((1, 1));
        r[[0, 0]] = a[[0, 0]].exp();
        return Ok(r);
    }
    let a_norm = matrix_1norm(a);
    // θ₇: largest 1-norm for which the [7/7] approximant attains
    // double-precision accuracy (Higham, scaling-and-squaring analysis).
    let theta = 0.950417899616293;
    let mut s = 0i32;
    let mut scale = 1.0f64;
    if a_norm > theta {
        s = (a_norm / theta).log2().ceil() as i32;
        scale = 2.0f64.powi(-s);
    }
    let a_scaled = if scale != 1.0 {
        a.mapv(|v| v * scale)
    } else {
        a.to_owned()
    };
    // Normalized [7/7] diagonal Padé coefficients b_k (b_0 = 1):
    // exp(x) ≈ p(x)/p(-x) with p(x) = Σ b_k x^k.
    let b = [
        1.0_f64,          // b0
        0.5,              // b1
        3.0 / 26.0,       // b2
        5.0 / 312.0,      // b3
        5.0 / 3432.0,     // b4
        1.0 / 11440.0,    // b5
        1.0 / 308880.0,   // b6
        1.0 / 17297280.0, // b7
    ];
    let ident = Array2::eye(n);
    let a2 = matmul_dense(&a_scaled.view(), &a_scaled.view());
    let a4 = matmul_dense(&a2.view(), &a2.view());
    let a6 = matmul_dense(&a4.view(), &a2.view());
    // Odd part: U = A (b1 I + b3 A² + b5 A⁴ + b7 A⁶).
    let u_poly = b[7] * &a6 + b[5] * &a4 + b[3] * &a2 + b[1] * &ident;
    let u = matmul_dense(&a_scaled.view(), &u_poly.view());
    // Even part: V = b0 I + b2 A² + b4 A⁴ + b6 A⁶.
    let v = b[6] * &a6 + b[4] * &a4 + b[2] * &a2 + b[0] * &ident;
    let vpu = &v + &u;
    let vmu = &v - &u;
    // exp(A/2^s) ≈ (V - U)^{-1} (V + U).
    let r = solve_dense(&vmu.view(), &vpu.view())?;
    // Undo the scaling by repeated squaring.
    let mut result = r;
    for _ in 0..s {
        result = matmul_dense(&result.view(), &result.view());
    }
    Ok(result)
}
/// Dense matrix square root via the Denman–Beavers iteration:
/// X → sqrt(A), Y → inv(sqrt(A)).
///
/// NOTE(review): the iteration converges only for matrices without
/// eigenvalues on the closed negative real axis — not checked here.
///
/// # Errors
/// Propagates `SingularMatrixError` from the inner linear solves.
fn dense_sqrtm(a: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
    let n = a.nrows();
    // Hoisted loop-invariant identity (was rebuilt twice per iteration).
    let ident = Array2::<f64>::eye(n);
    let mut x = a.to_owned();
    let mut y = Array2::eye(n);
    for _ in 0..50 {
        let x_inv = solve_dense(&x.view(), &ident.view())?;
        let y_inv = solve_dense(&y.view(), &ident.view())?;
        // Coupled update keeps the iteration numerically stable.
        let x_new = 0.5 * (&x + &y_inv);
        let y_new = 0.5 * (&y + &x_inv);
        let diff = (&x_new - &x).mapv(|v: f64| v.abs());
        let err = diff.sum();
        let x_norm = x_new.mapv(|v: f64| v.abs()).sum();
        x = x_new;
        y = y_new;
        // Stop once the relative change is at machine-precision level.
        if err < 1e-14 * x_norm {
            break;
        }
    }
    Ok(x)
}
/// Dense matrix logarithm via inverse scaling and squaring: take repeated
/// square roots until A^(1/2^s) is close to I, evaluate log with a series,
/// then multiply by 2^s.
///
/// # Errors
/// Propagates errors from `dense_sqrtm` and the inner linear solve.
fn dense_logm(a: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
    let n = a.nrows();
    let mut a_scaled = a.to_owned();
    let mut s = 0i32;
    // Square-root until the max-row-sum distance to I drops below 0.5.
    for _ in 0..50 {
        let diff: f64 = (0..n)
            .map(|i| {
                (0..n)
                    .map(|j| {
                        let v = a_scaled[[i, j]] - if i == j { 1.0 } else { 0.0 };
                        v.abs()
                    })
                    .sum::<f64>()
            })
            .fold(0.0, f64::max);
        if diff <= 0.5 {
            break;
        }
        a_scaled = dense_sqrtm(&a_scaled.view())?;
        s += 1;
    }
    let ident = Array2::<f64>::eye(n);
    let b = &a_scaled - &ident;
    let b_norm = matrix_1norm(&b.view());
    let log_scaled = if b_norm < 0.1 {
        // Mercator series: log(I + B) = B - B²/2 + B³/3 - ...
        let mut log_a = b.clone();
        let mut bpow = b.clone();
        for k in 2..=20usize {
            bpow = matmul_dense(&bpow.view(), &b.view());
            let sign = if k % 2 == 0 { -1.0 } else { 1.0 };
            log_a = log_a + (sign / k as f64) * &bpow;
            let bpow_norm = matrix_1norm(&bpow.view());
            if bpow_norm / (k as f64) < 1e-15 {
                break;
            }
        }
        log_a
    } else {
        // atanh form: log(A) = 2 atanh(Z) with Z = (A + I)^{-1} (A - I);
        // note 2I + B = A + I since B = A - I.
        let two_plus_b = 2.0 * &ident + &b;
        let z = solve_dense(&two_plus_b.view(), &b.view())?;
        // PERF: Z² is loop-invariant — hoisted (was recomputed every iteration).
        let z2 = matmul_dense(&z.view(), &z.view());
        let mut atanh_z = z.clone();
        let mut zpow = z.clone();
        for k in 1..=20usize {
            // zpow advances through the odd powers Z^(2k+1).
            zpow = matmul_dense(&zpow.view(), &z2.view());
            let coeff = 1.0 / (2 * k + 1) as f64;
            atanh_z = atanh_z + coeff * &zpow;
            let err = matrix_1norm(&zpow.view()) * coeff;
            if err < 1e-15 {
                break;
            }
        }
        2.0 * atanh_z
    };
    // Undo the s square roots: log(A) = 2^s log(A^(1/2^s)).
    Ok(2.0f64.powi(s) * log_scaled)
}
/// Dense matrix sign function via the Newton iteration
/// X ← (X + X^{-1}) / 2.
///
/// # Errors
/// Propagates `SingularMatrixError` from the inner linear solves.
fn dense_signm(a: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
    let n = a.nrows();
    // Hoisted loop-invariant identity (was rebuilt every iteration).
    let ident = Array2::<f64>::eye(n);
    let mut x = a.to_owned();
    for _ in 0..50 {
        let x_inv = solve_dense(&x.view(), &ident.view())?;
        let x_new = 0.5 * (&x + &x_inv);
        let diff_norm = (&x_new - &x).mapv(|v: f64| v.abs()).sum();
        // Guard against division by zero in the relative-change test.
        let x_norm = x_new.mapv(|v: f64| v.abs()).sum().max(1e-300);
        x = x_new;
        if diff_norm / x_norm < 1e-14 {
            break;
        }
    }
    Ok(x)
}
/// Dense matrix cosine via the truncated Taylor series
/// cos(A) = I - A²/2! + A⁴/4! - ... (at most 16 terms).
fn dense_cosm(a: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
    let n = a.nrows();
    let a_sq = matmul_dense(&a.view(), &a.view());
    let mut acc = Array2::<f64>::eye(n);
    let mut even_power = Array2::<f64>::eye(n);
    let mut sign = 1.0_f64;
    for term in 1..=16usize {
        let order = 2 * term;
        // even_power advances through A², A⁴, ...
        even_power = matmul_dense(&even_power.view(), &a_sq.view());
        // (2k)! computed from scratch to match the series coefficient exactly.
        let mut factorial = 1.0f64;
        for i in 1..=(order as u64) {
            factorial *= i as f64;
        }
        sign = -sign;
        acc = acc + (sign / factorial) * &even_power;
        // Stop once the next term's contribution is negligible.
        if matrix_1norm(&even_power.view()) / factorial < 1e-15 {
            break;
        }
    }
    Ok(acc)
}
/// Dense matrix sine via the truncated Taylor series
/// sin(A) = A - A³/3! + A⁵/5! - ... (at most 16 terms).
///
/// BUGFIX: the odd power was advanced by A² twice per iteration (plus a
/// redundant multiply by I), which produced A - A⁵/3! + A⁹/5! - ... instead
/// of the sine series. It now advances by A² exactly once per term.
fn dense_sinm(a: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
    let a2 = matmul_dense(&a.view(), &a.view());
    let mut result = a.to_owned();
    let mut apow = a.to_owned();
    let mut sign = 1.0_f64;
    for k in 1..=16usize {
        let two_k_plus_1 = 2 * k + 1;
        // Advance the odd power: A^(2k+1) = A^(2k-1) * A².
        apow = matmul_dense(&apow.view(), &a2.view());
        // (2k+1)! for the series coefficient.
        let mut fact = 1.0f64;
        for i in 1..=(two_k_plus_1 as u64) {
            fact *= i as f64;
        }
        sign = -sign;
        result = result + (sign / fact) * &apow;
        // Stop once the next term's contribution is negligible.
        let apow_norm = matrix_1norm(&apow.view());
        if apow_norm / fact < 1e-15 {
            break;
        }
    }
    Ok(result)
}
/// Induced 1-norm of `a`: the maximum absolute column sum.
fn matrix_1norm(a: &ArrayView2<f64>) -> f64 {
    let (rows, cols) = a.dim();
    (0..cols)
        .map(|col| (0..rows).map(|row| a[[row, col]].abs()).sum::<f64>())
        .fold(0.0_f64, f64::max)
}
/// Naive dense matrix product C = A · B (triple loop, ikj order).
///
/// Assumes `a.ncols() == b.nrows()`; no shape check is performed.
fn matmul_dense(a: &ArrayView2<f64>, b: &ArrayView2<f64>) -> Array2<f64> {
    let (rows, inner) = a.dim();
    let (_, cols) = b.dim();
    let mut product = Array2::zeros((rows, cols));
    for row in 0..rows {
        for mid in 0..inner {
            let a_entry = a[[row, mid]];
            for col in 0..cols {
                product[[row, col]] += a_entry * b[[mid, col]];
            }
        }
    }
    product
}
/// Solve A X = B for dense square A via LU factorization with partial pivoting.
///
/// `a` must be n×n and `b` n×nb; returns the n×nb solution matrix.
///
/// # Errors
/// Returns `SingularMatrixError` when a pivot smaller than 1e-300 in
/// magnitude is encountered.
fn solve_dense(a: &ArrayView2<f64>, b: &ArrayView2<f64>) -> LinalgResult<Array2<f64>> {
    let n = a.nrows();
    let nb = b.ncols();
    // In-place LU factorization of a copy of A; pivot[k] records which row
    // was swapped into position k.
    let mut lu = a.to_owned();
    let mut pivot = vec![0usize; n];
    for k in 0..n {
        // Partial pivoting: pick the largest-magnitude entry in column k.
        let mut max_val = lu[[k, k]].abs();
        let mut max_row = k;
        for i in (k + 1)..n {
            let v = lu[[i, k]].abs();
            if v > max_val {
                max_val = v;
                max_row = i;
            }
        }
        pivot[k] = max_row;
        if max_row != k {
            // Swap rows k and max_row of the working factorization.
            for j in 0..n {
                let tmp = lu[[k, j]];
                lu[[k, j]] = lu[[max_row, j]];
                lu[[max_row, j]] = tmp;
            }
        }
        let diag = lu[[k, k]];
        if diag.abs() < 1e-300 {
            return Err(LinalgError::SingularMatrixError(format!(
                "solve_dense: matrix is singular at column {k}"
            )));
        }
        // Gaussian elimination below the pivot; multipliers are stored in the
        // strict lower triangle of `lu`.
        for i in (k + 1)..n {
            lu[[i, k]] /= diag;
            for j in (k + 1)..n {
                let luk = lu[[i, k]];
                let lukj = lu[[k, j]];
                lu[[i, j]] -= luk * lukj;
            }
        }
    }
    // Apply the recorded row swaps to the right-hand side.
    let mut x = b.to_owned();
    for k in 0..n {
        let p = pivot[k];
        if p != k {
            for j in 0..nb {
                let tmp = x[[k, j]];
                x[[k, j]] = x[[p, j]];
                x[[p, j]] = tmp;
            }
        }
    }
    // Forward substitution with the unit lower triangle L.
    for k in 0..n {
        for i in (k + 1)..n {
            let lk = lu[[i, k]];
            for j in 0..nb {
                let xkj = x[[k, j]];
                x[[i, j]] -= lk * xkj;
            }
        }
    }
    // Back substitution with the upper triangle U.
    for k in (0..n).rev() {
        let diag = lu[[k, k]];
        for j in 0..nb {
            x[[k, j]] /= diag;
        }
        for i in 0..k {
            let uk = lu[[i, k]];
            for j in 0..nb {
                let xkj = x[[k, j]];
                x[[i, j]] -= uk * xkj;
            }
        }
    }
    Ok(x)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Arnoldi on a non-symmetric (Jordan-block-like) matrix: the computed
    // basis must satisfy V^T V = I.
    #[test]
    fn test_arnoldi_iteration() {
        let a = array![
            [2.0, 1.0, 0.0, 0.0],
            [0.0, 2.0, 1.0, 0.0],
            [0.0, 0.0, 2.0, 1.0],
            [0.0, 0.0, 0.0, 2.0],
        ];
        let v = array![1.0, 0.0, 0.0, 0.0];
        let result = ArnoldiIteration::run(&a.view(), &v.view(), 3, None)
            .expect("Arnoldi failed");
        let vt_v = matmul_dense(&result.v.t(), &result.v.view());
        for i in 0..result.m {
            for j in 0..result.m {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (vt_v[[i, j]] - expected).abs() < 1e-12,
                    "V^T V not identity at ({},{}) = {}",
                    i,
                    j,
                    vt_v[[i, j]]
                );
            }
        }
    }

    // Lanczos on a symmetric matrix: basis orthonormality (looser tolerance
    // than Arnoldi since only a three-term recurrence plus re-orth is used).
    #[test]
    fn test_lanczos_iteration() {
        let a = array![
            [4.0, 1.0, 0.0, 0.0],
            [1.0, 3.0, 1.0, 0.0],
            [0.0, 1.0, 2.0, 1.0],
            [0.0, 0.0, 1.0, 1.0],
        ];
        let v = array![1.0, 0.0, 0.0, 0.0];
        let result = LanczosIteration::run(&a.view(), &v.view(), 3, None)
            .expect("Lanczos failed");
        let vt_v = matmul_dense(&result.v.t(), &result.v.view());
        for i in 0..result.m {
            for j in 0..result.m {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (vt_v[[i, j]] - expected).abs() < 1e-10,
                    "Lanczos V^T V not identity at ({},{}) = {}",
                    i,
                    j,
                    vt_v[[i, j]]
                );
            }
        }
    }

    // 1×1 problem: exp(t*A)*v reduces to the scalar 3 * exp(2 * 0.5).
    #[test]
    fn test_matrix_exp_krylov_scalar() {
        let a = array![[2.0_f64]];
        let v = array![3.0_f64];
        let t = 0.5;
        let result = MatrixExpKrylov::apply(&a.view(), &v.view(), t, None, None)
            .expect("MatrixExpKrylov failed");
        let expected = 3.0 * (2.0_f64 * 0.5).exp();
        assert!(
            (result[0] - expected).abs() < 1e-8,
            "MatrixExpKrylov scalar: {} vs {}",
            result[0],
            expected
        );
    }

    // A = I: Arnoldi breaks down happily after one step and the result must
    // be e * v (elementwise scaling by exp(1)).
    #[test]
    fn test_matrix_function_interpolation_identity() {
        let n = 4;
        let a = Array2::eye(n);
        let v: Array1<f64> = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let params = MatrixFunctionParams::new(MatrixFunctionType::Exponential)
            .with_max_krylov_dim(n)
            .with_tol(1e-10);
        let result = MatrixFunctionInterpolation::apply(&a.view(), &v.view(), &params)
            .expect("MatrixFunctionInterpolation failed");
        let expected_scale = 1.0_f64.exp();
        for i in 0..n {
            let expected = expected_scale * v[i];
            assert!(
                (result[i] - expected).abs() < 1e-6,
                "exp(I) * v failed at {i}: {} vs {}",
                result[i],
                expected
            );
        }
    }

    // exp(0) must be exactly the identity matrix.
    #[test]
    fn test_pade_expm_zero() {
        let a = Array2::<f64>::zeros((3, 3));
        let result = pade_expm(&a.view()).expect("pade_expm of zero failed");
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (result[[i, j]] - expected).abs() < 1e-12,
                    "pade_expm(0) failed at ({},{})={}",
                    i,
                    j,
                    result[[i, j]]
                );
            }
        }
    }

    // The projected Lanczos matrix T must be symmetric.
    #[test]
    fn test_lanczos_tridiagonal() {
        let a = array![
            [3.0, 1.0, 0.0],
            [1.0, 3.0, 1.0],
            [0.0, 1.0, 3.0],
        ];
        let v = array![1.0, 1.0, 1.0];
        let result = LanczosIteration::run(&a.view(), &v.view(), 3, None)
            .expect("Lanczos failed");
        let t = LanczosIteration::tridiagonal_matrix(&result);
        for i in 0..result.m {
            for j in 0..result.m {
                assert!((t[[i, j]] - t[[j, i]]).abs() < 1e-12);
            }
        }
    }
}