use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_symbolic::cas::matrix_exp::{expm_2x2, expm_3x3, expm_diag_2x2, expm_diag_3x3};
use scirs2_symbolic::eml::op::LoweredOp;
use std::sync::Arc;
/// Errors reported by the symbolic matrix-exponential wrappers in this module.
///
/// Besides `Debug`, the type derives `Clone`, `PartialEq` and `Eq` so callers
/// can duplicate errors and compare them directly (for example with
/// `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExpmSymbolicError {
    /// The matrix is square, but its side length is not the one the called
    /// function handles.
    WrongSize {
        /// Number of rows of the supplied matrix.
        got_rows: usize,
        /// Number of columns of the supplied matrix.
        got_cols: usize,
        /// Side length the called function requires.
        expected: usize,
    },
    /// The matrix has a different number of rows and columns.
    NotSquare {
        /// Number of rows of the supplied matrix.
        rows: usize,
        /// Number of columns of the supplied matrix.
        cols: usize,
    },
    /// The general 3×3 routine rejected the input. Per the `Display` message,
    /// that path needs all-constant entries or a diagonal matrix.
    CubicRootSymbolic,
}
impl std::fmt::Display for ExpmSymbolicError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ExpmSymbolicError::WrongSize {
got_rows,
got_cols,
expected,
} => write!(
f,
"matrix is {got_rows}×{got_cols} but expm_symbolic_{expected}x{expected} \
requires a {expected}×{expected} matrix"
),
ExpmSymbolicError::NotSquare { rows, cols } => {
write!(f, "matrix is not square ({rows}×{cols})")
}
ExpmSymbolicError::CubicRootSymbolic => write!(
f,
"expm_symbolic_3x3 requires all entries to be constant (Const) \
or the matrix to be diagonal; symbolic off-diagonal entries \
are not supported in the general 3×3 case"
),
}
}
}
// Empty impl: nothing is overridden, so the `Error` trait's default methods apply.
impl std::error::Error for ExpmSymbolicError {}
/// Copies a 2×2 view of shared expression nodes into an owned fixed-size array.
///
/// Each entry is cloned out of its `Arc`, so the result does not share
/// ownership with the input.
///
/// # Errors
///
/// * [`ExpmSymbolicError::NotSquare`] when the row and column counts differ.
/// * [`ExpmSymbolicError::WrongSize`] when the matrix is square but not 2×2.
fn array2_to_fixed_2x2(
    m: ArrayView2<Arc<LoweredOp>>,
) -> Result<[[LoweredOp; 2]; 2], ExpmSymbolicError> {
    match m.dim() {
        (2, 2) => {}
        (rows, cols) if rows != cols => {
            return Err(ExpmSymbolicError::NotSquare { rows, cols });
        }
        (got_rows, got_cols) => {
            return Err(ExpmSymbolicError::WrongSize {
                got_rows,
                got_cols,
                expected: 2,
            });
        }
    }
    // Row-major fill: the outer index is the row, the inner index the column.
    Ok(std::array::from_fn(|row| {
        std::array::from_fn(|col| (*m[[row, col]]).clone())
    }))
}
/// Copies a 3×3 view of shared expression nodes into an owned fixed-size array.
///
/// Each entry is cloned out of its `Arc`, so the result does not share
/// ownership with the input.
///
/// # Errors
///
/// * [`ExpmSymbolicError::NotSquare`] when the row and column counts differ.
/// * [`ExpmSymbolicError::WrongSize`] when the matrix is square but not 3×3.
fn array2_to_fixed_3x3(
    m: ArrayView2<Arc<LoweredOp>>,
) -> Result<[[LoweredOp; 3]; 3], ExpmSymbolicError> {
    match m.dim() {
        (3, 3) => {}
        (rows, cols) if rows != cols => {
            return Err(ExpmSymbolicError::NotSquare { rows, cols });
        }
        (got_rows, got_cols) => {
            return Err(ExpmSymbolicError::WrongSize {
                got_rows,
                got_cols,
                expected: 3,
            });
        }
    }
    // Row-major fill: the outer index is the row, the inner index the column.
    Ok(std::array::from_fn(|row| {
        std::array::from_fn(|col| (*m[[row, col]]).clone())
    }))
}
/// Converts an owned fixed-size 2×2 array into an `Array2` of `Arc`-wrapped nodes.
///
/// The input is taken by value, so each expression is moved into its `Arc`
/// rather than deep-cloned.
fn fixed_2x2_to_array2(m: [[LoweredOp; 2]; 2]) -> Array2<Arc<LoweredOp>> {
    // `IntoIterator::into_iter` is called explicitly so the array is consumed
    // by value regardless of the crate's edition.
    let flat: Vec<Arc<LoweredOp>> = IntoIterator::into_iter(m)
        .flatten()
        .map(Arc::new)
        .collect();
    // Rows were flattened in order, matching the row-major layout that
    // `from_shape_vec` expects for a plain `(rows, cols)` shape.
    Array2::from_shape_vec((2, 2), flat)
        .expect("a [[_; 2]; 2] array always flattens to exactly 4 elements")
}
/// Converts an owned fixed-size 3×3 array into an `Array2` of `Arc`-wrapped nodes.
///
/// The input is taken by value, so each expression is moved into its `Arc`
/// rather than deep-cloned.
fn fixed_3x3_to_array2(m: [[LoweredOp; 3]; 3]) -> Array2<Arc<LoweredOp>> {
    // `IntoIterator::into_iter` is called explicitly so the array is consumed
    // by value regardless of the crate's edition.
    let flat: Vec<Arc<LoweredOp>> = IntoIterator::into_iter(m)
        .flatten()
        .map(Arc::new)
        .collect();
    // Rows were flattened in order, matching the row-major layout that
    // `from_shape_vec` expects for a plain `(rows, cols)` shape.
    Array2::from_shape_vec((3, 3), flat)
        .expect("a [[_; 3]; 3] array always flattens to exactly 9 elements")
}
/// Symbolic matrix exponential of a 2×2 matrix of expression nodes.
///
/// The diagonal shortcut (`expm_diag_2x2`) is tried first; when it returns
/// `None`, the general routine (`expm_2x2`) is used instead.
///
/// # Errors
///
/// Returns [`ExpmSymbolicError::NotSquare`] or
/// [`ExpmSymbolicError::WrongSize`] when `m` is not a 2×2 matrix.
pub fn expm_symbolic_2x2(
    m: ArrayView2<Arc<LoweredOp>>,
) -> Result<Array2<Arc<LoweredOp>>, ExpmSymbolicError> {
    let entries = array2_to_fixed_2x2(m)?;
    let exponential = match expm_diag_2x2(&entries) {
        Some(diagonal) => diagonal,
        None => expm_2x2(&entries),
    };
    Ok(fixed_2x2_to_array2(exponential))
}
/// Symbolic matrix exponential of a 3×3 matrix of expression nodes.
///
/// The diagonal shortcut (`expm_diag_3x3`) is tried first; when it returns
/// `None`, the general routine (`expm_3x3`) is used instead.
///
/// # Errors
///
/// * [`ExpmSymbolicError::NotSquare`] or [`ExpmSymbolicError::WrongSize`]
///   when `m` is not a 3×3 matrix.
/// * [`ExpmSymbolicError::CubicRootSymbolic`] when the general routine
///   fails; its own error value is discarded.
pub fn expm_symbolic_3x3(
    m: ArrayView2<Arc<LoweredOp>>,
) -> Result<Array2<Arc<LoweredOp>>, ExpmSymbolicError> {
    let entries = array2_to_fixed_3x3(m)?;
    let exponential = match expm_diag_3x3(&entries) {
        Some(diagonal) => diagonal,
        None => match expm_3x3(&entries) {
            Ok(general) => general,
            Err(_) => return Err(ExpmSymbolicError::CubicRootSymbolic),
        },
    };
    Ok(fixed_3x3_to_array2(exponential))
}
// Unit tests: build small matrices of expression nodes, run the symbolic
// exponential, evaluate every checked entry numerically and compare it with a
// closed-form value (or check that the expected error variant is returned).
#[cfg(test)]
mod tests {
use super::*;
use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
/// Wraps a float in a shared `LoweredOp::Const` node.
fn c(v: f64) -> Arc<LoweredOp> {
Arc::new(LoweredOp::Const(v))
}
/// Wraps a variable index in a shared `LoweredOp::Var` node.
fn var(i: usize) -> Arc<LoweredOp> {
Arc::new(LoweredOp::Var(i))
}
/// Evaluates `op` against a context built from an empty slice and panics if
/// evaluation fails (presumably because the expression still contains
/// variables — confirm against `eval_real`).
fn eval(op: &LoweredOp) -> f64 {
let ctx = EvalCtx::new(&[]);
eval_real(op, &ctx).expect("expression must be constant for eval()")
}
/// exp(0) must be the 2×2 identity.
#[test]
fn test_expm_2x2_zero_matrix_gives_identity() {
let mat = Array2::from_shape_fn((2, 2), |_| c(0.0));
let result = expm_symbolic_2x2(mat.view()).expect("expm_symbolic_2x2 zero");
let tol = 1e-12;
assert!(
(eval(result[[0, 0]].as_ref()) - 1.0).abs() < tol,
"result[0][0] should be 1, got {}",
eval(result[[0, 0]].as_ref())
);
assert!(
eval(result[[0, 1]].as_ref()).abs() < tol,
"result[0][1] should be 0"
);
assert!(
eval(result[[1, 0]].as_ref()).abs() < tol,
"result[1][0] should be 0"
);
assert!(
(eval(result[[1, 1]].as_ref()) - 1.0).abs() < tol,
"result[1][1] should be 1"
);
}
/// exp(diag(1, 2)) must be diag(e, e²) with zero off-diagonal entries.
#[test]
fn test_expm_2x2_diagonal_entries() {
let mat = Array2::from_shape_fn(
(2, 2),
|(r, col)| {
if r == col {
c([1.0, 2.0][r])
} else {
c(0.0)
}
},
);
let result = expm_symbolic_2x2(mat.view()).expect("expm_symbolic_2x2 diag");
let e = std::f64::consts::E;
let tol = 1e-10;
assert!(
(eval(result[[0, 0]].as_ref()) - e).abs() < tol,
"result[0][0] = e¹"
);
assert!(
(eval(result[[1, 1]].as_ref()) - e * e).abs() < tol,
"result[1][1] = e²"
);
assert!(
eval(result[[0, 1]].as_ref()).abs() < tol,
"off-diagonal [0][1] = 0"
);
assert!(
eval(result[[1, 0]].as_ref()).abs() < tol,
"off-diagonal [1][0] = 0"
);
}
/// Non-diagonal constant case M = [[0, 2], [1, 0]].
/// Since M² = 2·I, exp(M) = cosh(√2)·I + (sinh(√2)/√2)·M.
#[test]
fn test_expm_2x2_general_off_diagonal() {
let mat = Array2::from_shape_fn((2, 2), |(r, col)| match (r, col) {
(0, 1) => c(2.0),
(1, 0) => c(1.0),
_ => c(0.0),
});
let result = expm_symbolic_2x2(mat.view()).expect("expm_symbolic_2x2 off-diagonal general");
let sqrt2 = 2.0_f64.sqrt();
let cosh_sqrt2 = sqrt2.cosh();
let sinh_sqrt2 = sqrt2.sinh();
let tol = 1e-8;
assert!(
(eval(result[[0, 0]].as_ref()) - cosh_sqrt2).abs() < tol,
"[0][0] = cosh(√2) = {cosh_sqrt2:.8}, got {}",
eval(result[[0, 0]].as_ref())
);
// Off-diagonal entries scale the corresponding entry of M by sinh(√2)/√2.
let expected_01 = 2.0 * sinh_sqrt2 / sqrt2;
assert!(
(eval(result[[0, 1]].as_ref()) - expected_01).abs() < tol,
"[0][1] = 2·sinh(√2)/√2 = {expected_01:.8}, got {}",
eval(result[[0, 1]].as_ref())
);
let expected_10 = sinh_sqrt2 / sqrt2;
assert!(
(eval(result[[1, 0]].as_ref()) - expected_10).abs() < tol,
"[1][0] = sinh(√2)/√2 = {expected_10:.8}, got {}",
eval(result[[1, 0]].as_ref())
);
assert!(
(eval(result[[1, 1]].as_ref()) - cosh_sqrt2).abs() < tol,
"[1][1] = cosh(√2)"
);
}
/// exp(M)·exp(−M) must be the identity (M and −M commute), checked for
/// M = [[1, 2], [3, 4]].
#[test]
fn test_expm_2x2_round_trip_inverse() {
let entries = [[1.0_f64, 2.0], [3.0, 4.0]];
let mat = Array2::from_shape_fn((2, 2), |(r, col)| c(entries[r][col]));
let neg_mat = Array2::from_shape_fn((2, 2), |(r, col)| c(-entries[r][col]));
let em = expm_symbolic_2x2(mat.view()).expect("expm(M)");
let enm = expm_symbolic_2x2(neg_mat.view()).expect("expm(-M)");
// Evaluate both symbolic results to plain floats before multiplying.
let em_f = [
[eval(em[[0, 0]].as_ref()), eval(em[[0, 1]].as_ref())],
[eval(em[[1, 0]].as_ref()), eval(em[[1, 1]].as_ref())],
];
let enm_f = [
[eval(enm[[0, 0]].as_ref()), eval(enm[[0, 1]].as_ref())],
[eval(enm[[1, 0]].as_ref()), eval(enm[[1, 1]].as_ref())],
];
// Hand-written 2×2 matrix product exp(M)·exp(−M).
let prod = [
[
em_f[0][0] * enm_f[0][0] + em_f[0][1] * enm_f[1][0],
em_f[0][0] * enm_f[0][1] + em_f[0][1] * enm_f[1][1],
],
[
em_f[1][0] * enm_f[0][0] + em_f[1][1] * enm_f[1][0],
em_f[1][0] * enm_f[0][1] + em_f[1][1] * enm_f[1][1],
],
];
// Looser tolerance than the other tests in this module (1e-6).
let tol = 1e-6;
assert!(
(prod[0][0] - 1.0).abs() < tol,
"I[0][0] = {:.2e}",
prod[0][0]
);
assert!(prod[0][1].abs() < tol, "I[0][1] = {:.2e}", prod[0][1]);
assert!(prod[1][0].abs() < tol, "I[1][0] = {:.2e}", prod[1][0]);
assert!(
(prod[1][1] - 1.0).abs() < tol,
"I[1][1] = {:.2e}",
prod[1][1]
);
}
/// exp(diag(1, 2, 3)) must have e, e², e³ on the diagonal; two of the
/// off-diagonal entries are spot-checked to be zero.
#[test]
fn test_expm_3x3_diagonal_entries() {
let mat = Array2::from_shape_fn((3, 3), |(r, col)| {
if r == col {
c([1.0, 2.0, 3.0][r])
} else {
c(0.0)
}
});
let result = expm_symbolic_3x3(mat.view()).expect("expm_symbolic_3x3 diag");
let e = std::f64::consts::E;
let tol = 1e-10;
assert!(
(eval(result[[0, 0]].as_ref()) - e).abs() < tol,
"diag[0] = e¹"
);
assert!(
(eval(result[[1, 1]].as_ref()) - e * e).abs() < tol,
"diag[1] = e²"
);
assert!(
(eval(result[[2, 2]].as_ref()) - e * e * e).abs() < tol,
"diag[2] = e³"
);
assert!(eval(result[[0, 1]].as_ref()).abs() < tol, "off[0][1] = 0");
assert!(eval(result[[1, 2]].as_ref()).abs() < tol, "off[1][2] = 0");
}
/// A 3×3 matrix whose nine entries are the variables Var(0)..Var(8) must be
/// rejected with `CubicRootSymbolic`.
#[test]
fn test_expm_3x3_symbolic_entries_returns_err() {
let mat = Array2::from_shape_fn((3, 3), |(r, col)| var(r * 3 + col));
let result = expm_symbolic_3x3(mat.view());
assert!(
matches!(result, Err(ExpmSymbolicError::CubicRootSymbolic)),
"expected CubicRootSymbolic, got {result:?}"
);
}
/// A square 3×3 input to the 2×2 entry point must yield `WrongSize` with
/// `expected: 2`.
#[test]
fn test_expm_2x2_wrong_size_returns_err() {
let mat = Array2::from_shape_fn((3, 3), |_| c(0.0));
let result = expm_symbolic_2x2(mat.view());
assert!(
matches!(
result,
Err(ExpmSymbolicError::WrongSize { expected: 2, .. })
),
"expected WrongSize(2), got {result:?}"
);
}
/// A 2×3 input to the 2×2 entry point must yield `NotSquare` carrying the
/// actual dimensions.
#[test]
fn test_expm_2x2_non_square_returns_err() {
let mat = Array2::from_shape_fn((2, 3), |_| c(0.0));
let result = expm_symbolic_2x2(mat.view());
assert!(
matches!(
result,
Err(ExpmSymbolicError::NotSquare { rows: 2, cols: 3 })
),
"expected NotSquare(2,3), got {result:?}"
);
}
/// A 3×4 input to the 3×3 entry point must yield `NotSquare` carrying the
/// actual dimensions.
#[test]
fn test_expm_3x3_non_square_returns_err() {
let mat = Array2::from_shape_fn((3, 4), |_| c(0.0));
let result = expm_symbolic_3x3(mat.view());
assert!(
matches!(
result,
Err(ExpmSymbolicError::NotSquare { rows: 3, cols: 4 })
),
"expected NotSquare(3,4), got {result:?}"
);
}
/// Non-diagonal, all-constant 3×3 case: M = [[0, 1, 0], [-1, 0, 0], [0, 0, 0]].
/// The top-left 2×2 block generates a plane rotation, so exp(M) has cos(1)
/// at [0][0], sin(1) at [0][1] and 1 at [2][2].
#[test]
fn test_expm_3x3_constant_general_matrix() {
let mat = Array2::from_shape_fn((3, 3), |(r, col)| {
let v = match (r, col) {
(0, 1) => 1.0,
(1, 0) => -1.0,
_ => 0.0,
};
c(v)
});
let result = expm_symbolic_3x3(mat.view())
.expect("expm_symbolic_3x3 should succeed for all-constant matrix");
let cos1 = 1.0_f64.cos();
let sin1 = 1.0_f64.sin();
let tol = 1e-8;
assert!(
(eval(result[[0, 0]].as_ref()) - cos1).abs() < tol,
"[0][0] = cos(1) = {cos1:.8}, got {}",
eval(result[[0, 0]].as_ref())
);
assert!(
(eval(result[[0, 1]].as_ref()) - sin1).abs() < tol,
"[0][1] = sin(1) = {sin1:.8}, got {}",
eval(result[[0, 1]].as_ref())
);
assert!(
(eval(result[[2, 2]].as_ref()) - 1.0).abs() < tol,
"[2][2] = 1"
);
}
}