pub mod expm;
pub use expm::{expm_symbolic_2x2, expm_symbolic_3x3, ExpmSymbolicError};
pub mod recognize;
pub use recognize::{inverse_by_structure, recognize, StructureKind};
pub mod spectral;
pub use spectral::{
eigenpairs_symmetric_2x2, eigenvalues_circulant, structured_eigenvalues, StructuredEig,
};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};
use scirs2_symbolic::eml::{simplify_op, LoweredOp};
use std::sync::Arc;
/// Errors produced by the symbolic linear-algebra routines in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SymbolicLinalgError {
    /// The operation requires a square matrix but got `rows != cols`.
    NotSquare {
        /// Number of rows of the rejected matrix.
        rows: usize,
        /// Number of columns of the rejected matrix.
        cols: usize,
    },
    /// The operation does not handle an input of dimension `n`.
    Unsupported {
        /// Dimension (row count) of the rejected input.
        n: usize,
        /// Largest dimension the operation accepts.
        max: usize,
    },
    /// Evaluating a symbolic entry numerically failed; carries the evaluator's message.
    EvalError(String),
    /// The numeric linear-algebra backend failed; carries its message.
    LinalgError(String),
}

impl std::fmt::Display for SymbolicLinalgError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SymbolicLinalgError::NotSquare { rows, cols } => {
                write!(f, "matrix is not square ({rows}×{cols})")
            }
            // Only say the limit was "exceeded" when it actually was: fixed-size
            // routines (e.g. the 2×2 eigenvalue solver) also report smaller inputs here.
            SymbolicLinalgError::Unsupported { n, max } if n > max => {
                write!(
                    f,
                    "matrix size {n}×{n} exceeds maximum supported size {max}×{max}"
                )
            }
            SymbolicLinalgError::Unsupported { n, max } => write!(
                f,
                "matrix size {n}×{n} is not supported by this operation (maximum {max}×{max})"
            ),
            SymbolicLinalgError::EvalError(msg) => write!(f, "evaluation error: {msg}"),
            SymbolicLinalgError::LinalgError(msg) => write!(f, "linalg error: {msg}"),
        }
    }
}

impl std::error::Error for SymbolicLinalgError {}
/// Returns an owned copy of the symbolic entry stored at (`r`, `c`).
#[inline]
fn cell(m: &ArrayView2<Arc<LoweredOp>>, r: usize, c: usize) -> LoweredOp {
    let entry: &LoweredOp = &m[[r, c]];
    entry.clone()
}
/// Builds the expression node `a + b`.
#[inline]
fn add(a: LoweredOp, b: LoweredOp) -> LoweredOp {
    LoweredOp::Add(a.into(), b.into())
}
/// Builds the expression node `a - b`.
#[inline]
fn sub(a: LoweredOp, b: LoweredOp) -> LoweredOp {
    LoweredOp::Sub(a.into(), b.into())
}
/// Builds the expression node `a * b`.
#[inline]
fn mul(a: LoweredOp, b: LoweredOp) -> LoweredOp {
    LoweredOp::Mul(a.into(), b.into())
}
/// Builds the expression node `base ^ exp`.
#[inline]
fn pow(base: LoweredOp, exp: LoweredOp) -> LoweredOp {
    LoweredOp::Pow(base.into(), exp.into())
}
/// Builds the expression node `a / b`.
#[inline]
fn div(a: LoweredOp, b: LoweredOp) -> LoweredOp {
    LoweredOp::Div(a.into(), b.into())
}
/// Builds the expression node `√a`.
#[inline]
fn sqrt(a: LoweredOp) -> LoweredOp {
    LoweredOp::Sqrt(a.into())
}
/// Builds the expression node `-a`.
#[inline]
fn neg(a: LoweredOp) -> LoweredOp {
    LoweredOp::Neg(a.into())
}
/// Wraps the literal `value` as a constant expression node.
#[inline]
fn cnst(value: f64) -> LoweredOp {
    LoweredOp::Const(value)
}
/// Returns the `(n-1)×(n-1)` submatrix of the square `matrix` with row `row` and
/// column `col` removed. Entries are shared with the input via `Arc::clone`, not
/// deep-copied.
fn minor(matrix: &ArrayView2<Arc<LoweredOp>>, row: usize, col: usize) -> Array2<Arc<LoweredOp>> {
    let size = matrix.nrows();
    debug_assert!(size >= 1);
    // Maps an index of the reduced matrix back to the source, stepping over `removed`.
    let lift = |idx: usize, removed: usize| if idx >= removed { idx + 1 } else { idx };
    Array2::from_shape_fn((size - 1, size - 1), |(r, c)| {
        Arc::clone(&matrix[[lift(r, row), lift(c, col)]])
    })
}
/// Computes the simplified symbolic determinant of a square matrix of size at most 4×4.
///
/// A 0×0 matrix yields the constant `1`.
///
/// # Errors
/// - [`SymbolicLinalgError::NotSquare`] if `matrix` is not square.
/// - [`SymbolicLinalgError::Unsupported`] (with `max: 4`) for matrices larger than 4×4.
pub fn det_symbolic(matrix: ArrayView2<Arc<LoweredOp>>) -> Result<LoweredOp, SymbolicLinalgError> {
    let (rows, cols) = (matrix.nrows(), matrix.ncols());
    if rows == cols {
        det_recursive(&matrix, rows).map(|expr| simplify_op(&expr))
    } else {
        Err(SymbolicLinalgError::NotSquare { rows, cols })
    }
}
fn det_recursive(
matrix: &ArrayView2<Arc<LoweredOp>>,
n: usize,
) -> Result<LoweredOp, SymbolicLinalgError> {
match n {
0 => Ok(cnst(1.0)),
1 => Ok(cell(matrix, 0, 0)),
2 => {
let a = cell(matrix, 0, 0);
let b = cell(matrix, 0, 1);
let c = cell(matrix, 1, 0);
let d = cell(matrix, 1, 1);
Ok(sub(mul(a, d), mul(b, c)))
}
3 | 4 => {
let mut terms: Vec<LoweredOp> = Vec::with_capacity(n);
for j in 0..n {
let m_sub = minor(matrix, 0, j);
let sub_det = det_recursive(&m_sub.view(), n - 1)?;
let entry = cell(matrix, 0, j);
let product = mul(entry, sub_det);
if j % 2 == 0 {
terms.push(product);
} else {
terms.push(neg(product));
}
}
let mut acc = terms.remove(0);
for t in terms {
acc = add(acc, t);
}
Ok(acc)
}
n => Err(SymbolicLinalgError::Unsupported { n, max: 4 }),
}
}
/// Returns the closed-form symbolic eigenvalues `[λ₊, λ₋]` of the 2×2 matrix
/// `[[a, b], [c, d]]`, each passed through `simplify_op`.
///
/// The eigenvalues are `λ± = (tr ± √Δ) / 2` with `tr = a + d`. The discriminant is
/// built as `Δ = (a − d)² + 4bc`. This is algebraically equal to `tr² − 4·det`, but
/// avoids the catastrophic cancellation that form suffers when the eigenvalues are
/// close relative to their magnitude. For example, with diag(1e8, 1e8 + 1),
/// `tr² − 4·det` rounds to 0 in f64. The form is also manifestly non-negative for
/// symmetric inputs (`b = c`).
///
/// When `Δ < 0` at a point, the eigenvalues are complex. Evaluating the returned
/// expressions there then involves the square root of a negative number.
///
/// # Errors
/// - [`SymbolicLinalgError::NotSquare`] if `matrix` is not square.
/// - [`SymbolicLinalgError::Unsupported`] (with `max: 2`) for square matrices other
///   than 2×2.
pub fn eigenvalues_symbolic_2x2(
    matrix: ArrayView2<Arc<LoweredOp>>,
) -> Result<[LoweredOp; 2], SymbolicLinalgError> {
    let rows = matrix.nrows();
    let cols = matrix.ncols();
    // Report shape problems accurately: a 2×3 input is "not square", not "too large".
    if rows != cols {
        return Err(SymbolicLinalgError::NotSquare { rows, cols });
    }
    if rows != 2 {
        return Err(SymbolicLinalgError::Unsupported { n: rows, max: 2 });
    }
    let a = cell(&matrix, 0, 0);
    let b = cell(&matrix, 0, 1);
    let c = cell(&matrix, 1, 0);
    let d = cell(&matrix, 1, 1);
    let tr = add(a.clone(), d.clone());
    // Δ = (a − d)² + 4bc  (≡ tr² − 4(ad − bc), without the cancellation).
    let gap_sq = pow(sub(a, d), cnst(2.0));
    let discriminant = add(gap_sq, mul(cnst(4.0), mul(b, c)));
    let sqrt_disc = sqrt(discriminant);
    let lambda_plus = div(add(tr.clone(), sqrt_disc.clone()), cnst(2.0));
    let lambda_minus = div(sub(tr, sqrt_disc), cnst(2.0));
    Ok([simplify_op(&lambda_plus), simplify_op(&lambda_minus)])
}
/// Evaluates every symbolic entry of `matrix` at `point` and returns the numeric
/// condition number computed by `crate::cond`.
///
/// `point` supplies the variable bindings, in order. `crate::cond` is called with
/// `None` for both of its optional arguments.
///
/// # Errors
/// - [`SymbolicLinalgError::NotSquare`] if `matrix` is not square.
/// - [`SymbolicLinalgError::EvalError`] if an entry fails to evaluate. The first
///   failure in row-major order is reported.
/// - [`SymbolicLinalgError::LinalgError`] if `crate::cond` fails.
pub fn condition_number_symbolic(
    matrix: ArrayView2<Arc<LoweredOp>>,
    point: ArrayView1<f64>,
) -> Result<f64, SymbolicLinalgError> {
    let (rows, cols) = (matrix.nrows(), matrix.ncols());
    if rows != cols {
        return Err(SymbolicLinalgError::NotSquare { rows, cols });
    }
    let bindings = point.to_vec();
    let ctx = EvalCtx::new(&bindings);
    let mut numeric = Array2::<f64>::zeros((rows, cols));
    for ((i, j), slot) in numeric.indexed_iter_mut() {
        *slot = eval_real(matrix[[i, j]].as_ref(), &ctx)
            .map_err(|err| SymbolicLinalgError::EvalError(err.to_string()))?;
    }
    crate::cond(&numeric.view(), None, None)
        .map_err(|err| SymbolicLinalgError::LinalgError(err.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{arr1, arr2, Array2};
    use scirs2_symbolic::eml::{eval_real, EvalCtx};

    /// Constant leaf node.
    fn c(v: f64) -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Const(v))
    }

    /// Variable leaf node referring to binding index `i`.
    fn v(i: usize) -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Var(i))
    }

    /// Builds an `n`×`n` matrix with `diag(k)` on the diagonal. Every off-diagonal
    /// cell shares the single `off` node.
    fn diag_with(
        n: usize,
        diag: impl Fn(usize) -> Arc<LoweredOp>,
        off: &Arc<LoweredOp>,
    ) -> Array2<Arc<LoweredOp>> {
        Array2::from_shape_fn((n, n), |(row, col)| {
            if row == col {
                diag(row)
            } else {
                Arc::clone(off)
            }
        })
    }

    /// Evaluates `expr` with `vars` bound in order; panics if evaluation fails.
    fn eval_at(expr: &LoweredOp, vars: &[f64]) -> f64 {
        eval_real(expr, &EvalCtx::new(vars)).expect("eval")
    }

    #[test]
    fn det_2x2_diagonal_matches_product() {
        let mat = diag_with(2, v, &c(0.0));
        let expr = det_symbolic(mat.view()).expect("det");
        let val = eval_at(&expr, &[2.0, 3.0]);
        assert!((val - 6.0).abs() < 1e-12, "got {val}");
    }

    #[test]
    fn det_2x2_general() {
        let mat = Array2::from_shape_fn((2, 2), |(row, col)| v(row * 2 + col));
        let expr = det_symbolic(mat.view()).expect("det");
        let val = eval_at(&expr, &[1.0, 2.0, 3.0, 4.0]);
        assert!((val - (-2.0)).abs() < 1e-12, "got {val}");
    }

    #[test]
    fn det_3x3_diagonal() {
        let values = [2.0, 3.0, 5.0];
        let mat = diag_with(3, |k| c(values[k]), &c(0.0));
        let expr = det_symbolic(mat.view()).expect("det");
        let val = eval_at(&expr, &[]);
        assert!((val - 30.0).abs() < 1e-10, "got {val}");
    }

    #[test]
    fn det_3x3_known() {
        let entries = [[1.0, 2.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 5.0]];
        let mat = Array2::from_shape_fn((3, 3), |(row, col)| c(entries[row][col]));
        let expr = det_symbolic(mat.view()).expect("det");
        let val = eval_at(&expr, &[]);
        assert!((val - (-10.0)).abs() < 1e-10, "got {val}");
    }

    #[test]
    fn det_4x4_block_diagonal() {
        let mat = diag_with(4, v, &c(0.0));
        let expr = det_symbolic(mat.view()).expect("det");
        let val = eval_at(&expr, &[2.0, 3.0, 4.0, 5.0]);
        assert!((val - 120.0).abs() < 1e-8, "got {val}");
    }

    #[test]
    fn det_5x5_returns_unsupported() {
        let mat = Array2::from_elem((5, 5), c(1.0));
        let result = det_symbolic(mat.view());
        assert!(
            matches!(result, Err(SymbolicLinalgError::Unsupported { n: 5, max: 4 })),
            "expected Unsupported(5,4), got {result:?}"
        );
    }

    #[test]
    fn det_non_square_returns_err() {
        let mat = Array2::from_elem((2, 3), c(1.0));
        let result = det_symbolic(mat.view());
        assert!(
            matches!(result, Err(SymbolicLinalgError::NotSquare { rows: 2, cols: 3 })),
            "expected NotSquare(2,3), got {result:?}"
        );
    }

    #[test]
    fn eigenvalues_2x2_symmetric() {
        let mat = diag_with(2, |_| v(0), &c(1.0));
        let [lp, lm] = eigenvalues_symbolic_2x2(mat.view()).expect("eig");
        let ctx = EvalCtx::new(&[3.0]);
        let vp = eval_real(&lp, &ctx).expect("eval λ+");
        let vm = eval_real(&lm, &ctx).expect("eval λ-");
        assert!((vp - 4.0).abs() < 1e-10, "λ+ = {vp}");
        assert!((vm - 2.0).abs() < 1e-10, "λ- = {vm}");
    }

    #[test]
    fn eigenvalues_2x2_complex_at_point_returns_err() {
        // Rotation by 90°: eigenvalues ±i, so real evaluation must fail.
        let entries = [[0.0, -1.0], [1.0, 0.0]];
        let mat = Array2::from_shape_fn((2, 2), |(row, col)| c(entries[row][col]));
        let [lp, _lm] = eigenvalues_symbolic_2x2(mat.view()).expect("eig");
        let ctx = EvalCtx::new(&[]);
        assert!(
            eval_real(&lp, &ctx).is_err(),
            "expected Err for complex eigenvalue"
        );
    }

    #[test]
    fn eigenvalues_non_2x2_returns_err() {
        let mat = Array2::from_elem((3, 3), c(1.0));
        let result = eigenvalues_symbolic_2x2(mat.view());
        assert!(
            matches!(result, Err(SymbolicLinalgError::Unsupported { n: 3, max: 2 })),
            "expected Unsupported(3,2), got {result:?}"
        );
    }

    #[test]
    fn condition_number_2x2_diagonal_known() {
        let mat = diag_with(2, |_| v(0), &c(1.0));
        let kappa = condition_number_symbolic(mat.view(), arr1(&[3.0]).view()).expect("cond");
        assert!((kappa - 2.0).abs() < 1e-6, "cond = {kappa}");
    }

    #[test]
    fn condition_number_matches_numerical_baseline() {
        let mat = diag_with(2, v, &c(0.5));
        let symbolic_kappa =
            condition_number_symbolic(mat.view(), arr1(&[2.0, 3.0]).view()).expect("cond");
        let numeric = arr2(&[[2.0_f64, 0.5], [0.5, 3.0]]);
        let numeric_kappa = crate::cond(&numeric.view(), None, None).expect("cond baseline");
        assert!(
            (symbolic_kappa - numeric_kappa).abs() < 1e-8,
            "symbolic={symbolic_kappa}, numeric={numeric_kappa}"
        );
    }

    #[test]
    fn condition_number_non_square_returns_err() {
        let mat = Array2::from_elem((2, 3), c(1.0));
        let result = condition_number_symbolic(mat.view(), arr1(&[1.0, 2.0, 3.0]).view());
        assert!(
            matches!(result, Err(SymbolicLinalgError::NotSquare { rows: 2, cols: 3 })),
            "expected NotSquare, got {result:?}"
        );
    }
}