use crate::csr::CsrMatrix;
use crate::error::{SparseError, SparseResult};
use scirs2_core::numeric::{Float, NumAssign, One, SparseElement, Zero};
use std::iter::Sum;
/// Computes the matrix exponential `exp(A)` of a square sparse matrix.
///
/// Uses scaling and squaring around a degree-13 Padé approximant:
/// - If the infinity norm of `A` is at most `theta_13`, the approximant is
///   evaluated on `A` directly.
/// - Otherwise `A` is divided by `2^s`, with the smallest `s` that brings the
///   norm to at most `theta_13`. The approximant is evaluated on the scaled
///   matrix and the result is squared `s` times.
///
/// # Errors
///
/// - [`SparseError::ValueError`] if `a` is not square, or if its computed
///   infinity norm is not finite.
/// - Any error propagated from the Padé evaluation (including iterative
///   solver failure) or from sparse matrix multiplication.
#[allow(dead_code)]
pub fn expm<F>(a: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
where
    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
{
    let (rows, cols) = a.shape();
    if rows != cols {
        return Err(SparseError::ValueError(
            "Matrix must be square for expm".to_string(),
        ));
    }
    let a_norm = matrix_inf_norm(a)?;
    // Guard the halving loop below: an infinite norm would never drop below
    // the threshold (inf / 2 == inf) and the loop would never terminate.
    if !a_norm.is_finite() {
        return Err(SparseError::ValueError(
            "Matrix must have finite entries for expm".to_string(),
        ));
    }
    // Norm threshold for the degree-13 Padé approximant (Higham, 2005).
    let theta_13 = F::from(5.371920351148152).expect("Failed to convert constant to float");
    if a_norm <= theta_13 {
        return pade_approximation(a, 13);
    }
    // Find the number of halvings s needed to bring the norm under theta_13.
    let mut s = 0;
    let mut scaled_norm = a_norm;
    let two = F::from(2.0).expect("Failed to convert constant to float");
    while scaled_norm > theta_13 {
        s += 1;
        scaled_norm /= two;
    }
    // Scaling by a power of two is exact in binary floating point.
    let scale_factor = two.powi(s);
    let scaled_a = scale_matrix(a, F::sparse_one() / scale_factor)?;
    let mut exp_scaled = pade_approximation(&scaled_a, 13)?;
    // Undo the scaling: exp(A) = exp(A / 2^s)^(2^s).
    for _ in 0..s {
        exp_scaled = exp_scaled.matmul(&exp_scaled)?;
    }
    Ok(exp_scaled)
}
/// Evaluates the `[p/p]` Padé approximant of the matrix exponential `exp(A)`.
///
/// How the approximant is assembled:
/// - The coefficients `c_k` come from `pade_coefficients`.
/// - The even-power terms `c_k A^k` are summed into `V`, and the odd-power
///   terms into `U`.
/// - The numerator is `N(A) = V + U` and the denominator is
///   `D(A) = N(-A) = V - U`.
/// - The approximant `D(A)^{-1} N(A)` is obtained by solving
///   `(V - U) X = V + U` with `sparse_solve`.
///
/// # Errors
///
/// Propagates errors from building the identity and zero matrices, from
/// sparse multiplication and addition, and from the iterative solve.
#[allow(dead_code)]
fn pade_approximation<F>(a: &CsrMatrix<F>, p: usize) -> SparseResult<CsrMatrix<F>>
where
    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
{
    let n = a.shape().0;
    // a_powers[k] holds A^k for k = 0..=p (A^1 is stored even when p == 0).
    let mut a_powers = Vec::with_capacity(p.max(1) + 1);
    a_powers.push(sparse_identity::<F>(n)?);
    a_powers.push(a.clone());
    for i in 2..=p {
        let power = a_powers[i - 1].matmul(a)?;
        a_powers.push(power);
    }
    let pade_coeffs: Vec<F> = pade_coefficients(p);
    let mut u = sparse_zero(n)?;
    let mut v = sparse_zero(n)?;
    for (i, coeff) in pade_coeffs.iter().enumerate() {
        let scaled_matrix = scale_matrix(&a_powers[i], *coeff)?;
        if i % 2 == 0 {
            v = sparse_add(&v, &scaled_matrix)?;
        } else {
            u = sparse_add(&u, &scaled_matrix)?;
        }
    }
    let neg_u = scale_matrix(
        &u,
        F::from(-1.0).expect("Failed to convert constant to float"),
    )?;
    let v_minus_u = sparse_add(&v, &neg_u)?;
    let v_plus_u = sparse_add(&v, &u)?;
    sparse_solve(&v_minus_u, &v_plus_u)
}

/// Returns the `p + 1` numerator coefficients of the `[p/p]` Padé approximant
/// of `exp(x)`:
///
/// `c_k = p! (2p - k)! / ((2p)! k! (p - k)!)` for `k = 0..=p`.
///
/// The denominator polynomial is `N(-x)`, so the same coefficients serve both.
/// Each coefficient is evaluated in `f64` as
/// `[p!/(p-k)!] / ([(2p)!/(2p-k)!] * k!)` and then converted to `F`.
///
/// NOTE(review): the f64 partial factorials overflow for large `p` (roughly
/// `p > 85`). `expm` only uses `p = 13`.
///
/// # Panics
///
/// Panics if a coefficient cannot be converted to `F`.
#[allow(dead_code)]
fn pade_coefficients<F>(p: usize) -> Vec<F>
where
    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
{
    let two_p = 2 * p;
    (0..=p)
        .map(|k| {
            // (2p)! / (2p - k)!  — empty product (1.0) when k == 0.
            let den = ((two_p - k + 1)..=two_p).fold(1.0_f64, |acc, i| acc * i as f64);
            // p! / (p - k)!
            let num = ((p - k + 1)..=p).fold(1.0_f64, |acc, i| acc * i as f64);
            let k_fact = (1..=k).fold(1.0_f64, |acc, i| acc * i as f64);
            F::from(num / (den * k_fact)).expect("Failed to convert to float")
        })
        .collect()
}
/// Returns the infinity norm `max_i sum_j |a_ij|` (maximum absolute row sum)
/// of `a`, computed from the stored entries of each CSR row.
///
/// Returns zero for a matrix with no rows. If any row sum is NaN, that NaN is
/// returned immediately so callers can detect invalid input. A plain `>`
/// comparison would otherwise skip NaN rows and understate the norm.
#[allow(dead_code)]
fn matrix_inf_norm<F>(a: &CsrMatrix<F>) -> SparseResult<F>
where
    F: Float + NumAssign + Sum + SparseElement + std::fmt::Debug,
{
    let mut max_row_sum = F::sparse_zero();
    for row in 0..a.rows() {
        let (start, end) = (a.indptr[row], a.indptr[row + 1]);
        let row_sum: F = a.data[start..end].iter().map(|x| x.abs()).sum();
        if row_sum.is_nan() {
            return Ok(row_sum);
        }
        if row_sum > max_row_sum {
            max_row_sum = row_sum;
        }
    }
    Ok(max_row_sum)
}
/// Returns a copy of `a` with every stored value multiplied by `scale`.
///
/// The row pointers and column indices are passed through unchanged. No
/// entries are pruned here, even if the product is zero.
#[allow(dead_code)]
fn scale_matrix<F>(a: &CsrMatrix<F>, scale: F) -> SparseResult<CsrMatrix<F>>
where
    F: Float + NumAssign + SparseElement,
{
    let scaled: Vec<F> = a.data.iter().map(|&value| value * scale).collect();
    let indptr = a.indptr.clone();
    let indices = a.indices.clone();
    CsrMatrix::from_raw_csr(scaled, indptr, indices, a.shape())
}
/// Builds the `n x n` sparse identity matrix: one stored `1` per diagonal
/// position and nothing else.
#[allow(dead_code)]
fn sparse_identity<F>(n: usize) -> SparseResult<CsrMatrix<F>>
where
    F: Float + Zero + One + SparseElement,
{
    let diagonal: Vec<usize> = (0..n).collect();
    let ones = vec![F::sparse_one(); n];
    CsrMatrix::new(ones, diagonal.clone(), diagonal, (n, n))
}
/// Returns an `n x n` sparse matrix with no stored entries (the zero matrix).
#[allow(dead_code)]
fn sparse_zero<F>(n: usize) -> SparseResult<CsrMatrix<F>>
where
    F: Float + Zero + SparseElement,
{
    let square = (n, n);
    let zeros = CsrMatrix::empty(square);
    Ok(zeros)
}
/// Returns the element-wise sum `a + b` of two sparse matrices of equal shape.
///
/// Rows are merged directly from the CSR arrays. A dense scratch row of
/// length `cols` is reused (and reset) for every row. The cost is therefore
/// proportional to the number of stored entries plus `rows + cols`, instead
/// of visiting all `rows * cols` positions.
///
/// Column indices are emitted in ascending order within each row. Only sums
/// that are exactly zero are dropped. Tiny values, NaN and infinities are
/// kept, so small-scale inputs are not truncated and invalid values are not
/// silently turned into zeros.
///
/// # Errors
///
/// Returns [`SparseError::ShapeMismatch`] if the shapes differ, and propagates
/// any error from `CsrMatrix::new`.
#[allow(dead_code)]
fn sparse_add<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
where
    F: Float + NumAssign + SparseElement,
{
    if a.shape() != b.shape() {
        return Err(SparseError::ShapeMismatch {
            expected: a.shape(),
            found: b.shape(),
        });
    }
    let (n_rows, n_cols) = a.shape();
    let zero = F::sparse_zero();
    let mut rows = Vec::new();
    let mut cols = Vec::new();
    let mut values = Vec::new();
    // Scratch state for one output row:
    // - `acc`: accumulated value per column
    // - `occupied`: whether the column has been touched in this row
    // - `touched`: the columns to read back and reset afterwards
    let mut acc = vec![zero; n_cols];
    let mut occupied = vec![false; n_cols];
    let mut touched: Vec<usize> = Vec::new();
    for row in 0..n_rows {
        for m in [a, b] {
            // NOTE(review): assumes `indices` holds valid column indices
            // (< cols), as the CSR format requires.
            for k in m.indptr[row]..m.indptr[row + 1] {
                let col = m.indices[k];
                if !occupied[col] {
                    occupied[col] = true;
                    touched.push(col);
                }
                acc[col] += m.data[k];
            }
        }
        touched.sort_unstable();
        for &col in &touched {
            let val = acc[col];
            if val != zero {
                rows.push(row);
                cols.push(col);
                values.push(val);
            }
            acc[col] = zero;
            occupied[col] = false;
        }
        touched.clear();
    }
    CsrMatrix::new(values, rows, cols, a.shape())
}
/// Solves `A X = B` for `X`, one column of `B` at a time, using BiCGSTAB.
///
/// Solver settings for every column:
/// - `rtol = 1e-10`, `atol = 1e-12`
/// - at most 1000 iterations
/// - no initial guess and no preconditioning
///
/// Solution entries with `|x| <= F::epsilon()` are not stored in the result.
///
/// # Errors
///
/// - [`SparseError::ShapeMismatch`] if `a` is not square, or if `b` does not
///   have `a.rows()` rows.
/// - [`SparseError::IterativeSolverFailure`] if BiCGSTAB does not converge
///   for some column.
/// - Errors returned by `bicgstab` itself are propagated.
#[allow(dead_code)]
fn sparse_solve<F>(a: &CsrMatrix<F>, b: &CsrMatrix<F>) -> SparseResult<CsrMatrix<F>>
where
    F: Float + NumAssign + Sum + SparseElement + 'static + std::fmt::Debug,
{
    use crate::linalg::interface::MatrixLinearOperator;
    use crate::linalg::iterative::bicgstab;
    use crate::linalg::iterative::BiCGSTABOptions;
    let n = a.rows();
    if a.cols() != n {
        return Err(SparseError::ShapeMismatch {
            expected: (n, n),
            found: a.shape(),
        });
    }
    if b.rows() != n {
        return Err(SparseError::ShapeMismatch {
            expected: (n, b.cols()),
            found: b.shape(),
        });
    }
    // The operator owns a copy of `a`. Build it once instead of cloning the
    // whole matrix for every right-hand-side column.
    let op = MatrixLinearOperator::new(a.clone());
    let mut result_rows = Vec::new();
    let mut result_cols = Vec::new();
    let mut result_values = Vec::new();
    for col in 0..b.cols() {
        let b_col = (0..n).map(|row| b.get(row, col)).collect::<Vec<_>>();
        let options = BiCGSTABOptions {
            rtol: F::from(1e-10).expect("Failed to convert constant to float"),
            atol: F::from(1e-12).expect("Failed to convert constant to float"),
            max_iter: 1000,
            x0: None,
            left_preconditioner: None,
            right_preconditioner: None,
        };
        let result = bicgstab(&op, &b_col, options)?;
        if !result.converged {
            return Err(SparseError::IterativeSolverFailure(format!(
                "BiCGSTAB failed to converge in {} iterations",
                result.iterations
            )));
        }
        for (row, &val) in result.x.iter().enumerate() {
            if val.abs() > F::epsilon() {
                result_rows.push(row);
                result_cols.push(col);
                result_values.push(val);
            }
        }
    }
    CsrMatrix::new(result_values, result_rows, result_cols, (n, b.cols()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_expm_identity() {
        let n = 3;
        let zero_matrix = sparse_zero::<f64>(n).expect("Operation failed");
        let exp_zero = expm(&zero_matrix).expect("Operation failed");
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                let actual = exp_zero.get(i, j);
                assert_relative_eq!(actual, expected, epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_expm_diagonal() {
        let n = 3;
        let diag_values = [0.5, 1.0, 2.0];
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut values = Vec::new();
        for (i, &val) in diag_values.iter().enumerate() {
            rows.push(i);
            cols.push(i);
            values.push(val);
        }
        let diag_matrix = CsrMatrix::new(values, rows, cols, (n, n)).expect("Operation failed");
        let exp_diag = expm(&diag_matrix).expect("Operation failed");
        for (i, &val) in diag_values.iter().enumerate() {
            let expected = val.exp();
            let actual = exp_diag.get(i, i);
            assert_relative_eq!(actual, expected, epsilon = 1e-10);
        }
        for i in 0..n {
            for j in 0..n {
                if i != j {
                    let actual = exp_diag.get(i, j);
                    assert_relative_eq!(actual, 0.0, epsilon = 1e-10);
                }
            }
        }
    }

    #[test]
    fn test_expm_small_matrix() {
        let rows = vec![0, 1];
        let cols = vec![1, 0];
        let values = vec![1.0, 0.0];
        let a = CsrMatrix::new(values, rows, cols, (2, 2)).expect("Operation failed");
        let exp_a = expm(&a).expect("Operation failed");
        assert_relative_eq!(exp_a.get(0, 0), 1.0, epsilon = 1e-10);
        assert_relative_eq!(exp_a.get(0, 1), 1.0, epsilon = 1e-10);
        assert_relative_eq!(exp_a.get(1, 0), 0.0, epsilon = 1e-10);
        assert_relative_eq!(exp_a.get(1, 1), 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_expm_scaling_and_squaring() {
        // Infinity norm 7.0 exceeds theta_13 (~5.37), so this exercises the
        // scale-by-2^s and repeated-squaring branch of expm.
        let diag_values = [6.0_f64, 7.0];
        let n = diag_values.len();
        let idx: Vec<usize> = (0..n).collect();
        let a = CsrMatrix::new(diag_values.to_vec(), idx.clone(), idx, (n, n))
            .expect("Operation failed");
        let exp_a = expm(&a).expect("Operation failed");
        for (i, &val) in diag_values.iter().enumerate() {
            assert_relative_eq!(exp_a.get(i, i), val.exp(), max_relative = 1e-8);
        }
        assert_relative_eq!(exp_a.get(0, 1), 0.0, epsilon = 1e-10);
        assert_relative_eq!(exp_a.get(1, 0), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_expm_rejects_non_square() {
        let a = CsrMatrix::new(vec![1.0_f64], vec![0], vec![0], (2, 3))
            .expect("Operation failed");
        assert!(matches!(expm(&a), Err(SparseError::ValueError(_))));
    }
}