use scirs2_core::ndarray::{Array2, ArrayView2};
use scirs2_symbolic::cas::{canonicalize, inverse_2x2, inverse_3x3, InverseResult};
use scirs2_symbolic::eml::{simplify_op, LoweredOp};
use std::sync::Arc;
use super::SymbolicLinalgError;
/// Structural pattern detected in a square symbolic matrix by [`recognize`].
///
/// `inverse_by_structure` uses the variant to pick an inversion strategy: `Diagonal`
/// and `LowRankUpdate` have dedicated paths, everything else goes to the general CAS
/// inverse.
#[derive(Debug, Clone, PartialEq)]
pub enum StructureKind {
/// Every entry has the same structural hash as `m[[0, 0]]` (this includes every
/// 1×1 matrix). Note: this is a constant matrix, not a multiple of the identity.
Scalar,
/// Every off-diagonal entry canonicalizes to zero.
Diagonal,
/// Each entry equals its down-right cyclic neighbour, so every row is a cyclic
/// shift of row 0. Only detected for n ≤ 8.
Circulant {
/// The entries of row 0.
first_row: Vec<LoweredOp>,
},
/// The matrix has the form `I + u vᵀ`: off-diagonal entries are literally
/// `u[r] * v[c]` and diagonal entries canonicalize to `1 + u[i] * v[i]`.
/// Only detected for n ≤ 8.
LowRankUpdate {
/// Left factors, one per row.
u: Vec<LoweredOp>,
/// Right factors, one per column.
v: Vec<LoweredOp>,
},
/// No recognized structure (also returned for non-square or empty input).
General,
}
fn zero_hash() -> u128 {
LoweredOp::Const(0.0).structural_hash()
}
/// Hash of the canonical form of `op`.
///
/// Hashing after `canonicalize` lets equivalent spellings of an expression compare
/// equal, to the extent that `canonicalize` normalizes them.
fn canon_hash(op: &LoweredOp) -> u128 {
    let canonical = canonicalize(op);
    canonical.hash()
}
/// Builds the product node `lhs * rhs`.
#[inline]
fn lo_mul(lhs: LoweredOp, rhs: LoweredOp) -> LoweredOp {
    LoweredOp::Mul(lhs.into(), rhs.into())
}
/// Builds the sum node `lhs + rhs`.
#[inline]
fn lo_add(lhs: LoweredOp, rhs: LoweredOp) -> LoweredOp {
    LoweredOp::Add(lhs.into(), rhs.into())
}
/// Builds the difference node `lhs - rhs`.
#[inline]
fn lo_sub(lhs: LoweredOp, rhs: LoweredOp) -> LoweredOp {
    LoweredOp::Sub(lhs.into(), rhs.into())
}
/// Builds the quotient node `lhs / rhs` (no zero check is performed here).
#[inline]
fn lo_div(lhs: LoweredOp, rhs: LoweredOp) -> LoweredOp {
    LoweredOp::Div(lhs.into(), rhs.into())
}
/// Builds a numeric constant leaf.
#[inline]
fn lo_const(value: f64) -> LoweredOp {
    LoweredOp::Const(value)
}
/// Returns an owned deep copy of the expression stored at `(r, c)`.
///
/// # Panics
/// Panics if `(r, c)` is out of bounds for `m`.
#[inline]
fn cell_op(m: &ArrayView2<Arc<LoweredOp>>, r: usize, c: usize) -> LoweredOp {
    (*m[[r, c]]).clone()
}
/// Classifies `m` into the most specific [`StructureKind`] the pattern checks can prove.
///
/// The checks run in this order, and the first match wins:
/// 1. `Scalar`: every entry has the same structural hash as `m[[0, 0]]`.
/// 2. `Diagonal`: every off-diagonal entry canonicalizes to zero.
/// 3. `Circulant`: (n ≤ 8 only) `m[[r, c]]` and `m[[(r+1)%n, (c+1)%n]]` are
///    structurally equal for every cell.
/// 4. `LowRankUpdate`: (n ≤ 8 only) `m` matches `I + u vᵀ`
///    (see `try_extract_low_rank_update`).
///
/// Non-square or empty matrices are reported as `General`.
pub fn recognize(m: ArrayView2<Arc<LoweredOp>>) -> StructureKind {
    let (rows, cols) = m.dim();
    if rows == 0 || rows != cols {
        return StructureKind::General;
    }
    let n = rows;

    // Fresh row-major walk over every (r, c) cell; called once per check.
    let cells = move || (0..n).flat_map(move |r| (0..n).map(move |c| (r, c)));

    let anchor = m[[0, 0]].structural_hash();
    if cells().all(|(r, c)| m[[r, c]].structural_hash() == anchor) {
        return StructureKind::Scalar;
    }

    let zero = zero_hash();
    if cells().all(|(r, c)| r == c || canon_hash(&m[[r, c]]) == zero) {
        return StructureKind::Diagonal;
    }

    // Circulant and low-rank detection are only attempted on small matrices.
    if n > 8 {
        return StructureKind::General;
    }

    let shift_invariant = cells().all(|(r, c)| {
        m[[r, c]].structural_hash() == m[[(r + 1) % n, (c + 1) % n]].structural_hash()
    });
    if shift_invariant {
        return StructureKind::Circulant {
            first_row: (0..n).map(|c| cell_op(&m, 0, c)).collect(),
        };
    }

    match try_extract_low_rank_update(&m, n) {
        Some((u, v)) => StructureKind::LowRankUpdate { u, v },
        None => StructureKind::General,
    }
}
/// Tries to read `m` (an `n × n` matrix) as the identity plus an outer product, `I + u vᵀ`.
///
/// Every off-diagonal entry must literally be a `Mul(left, right)` node. `left` must be
/// the same expression (by structural hash) across each row, and `right` the same
/// across each column. Those shared factors become `u[r]` and `v[c]`. Each diagonal
/// entry must then canonicalize to the same form as `1 + u[i] * v[i]`.
///
/// Returns `None` on the first mismatch, or if some row or column has no
/// off-diagonal entry (which happens for `n < 2`).
fn try_extract_low_rank_update(
    m: &ArrayView2<Arc<LoweredOp>>,
    n: usize,
) -> Option<(Vec<LoweredOp>, Vec<LoweredOp>)> {
    // Fills an empty slot with (hash, op), or checks an occupied slot's hash against op.
    fn claim(slot: &mut Option<(u128, LoweredOp)>, op: &LoweredOp) -> bool {
        let hash = op.structural_hash();
        if let Some((seen, _)) = slot.as_ref() {
            return *seen == hash;
        }
        *slot = Some((hash, op.clone()));
        true
    }

    let mut row_factor: Vec<Option<(u128, LoweredOp)>> = (0..n).map(|_| None).collect();
    let mut col_factor: Vec<Option<(u128, LoweredOp)>> = (0..n).map(|_| None).collect();

    for r in 0..n {
        for c in (0..n).filter(|&c| c != r) {
            let (left, right) = match m[[r, c]].as_ref() {
                LoweredOp::Mul(lhs, rhs) => (&**lhs, &**rhs),
                _ => return None,
            };
            if !claim(&mut row_factor[r], left) || !claim(&mut col_factor[c], right) {
                return None;
            }
        }
    }

    let u = row_factor
        .into_iter()
        .map(|slot| slot.map(|(_, op)| op))
        .collect::<Option<Vec<_>>>()?;
    let v = col_factor
        .into_iter()
        .map(|slot| slot.map(|(_, op)| op))
        .collect::<Option<Vec<_>>>()?;

    let diagonal_matches = (0..n).all(|i| {
        let expected = lo_add(lo_const(1.0), lo_mul(u[i].clone(), v[i].clone()));
        canon_hash(m[[i, i]].as_ref()) == canon_hash(&expected)
    });
    if diagonal_matches {
        Some((u, v))
    } else {
        None
    }
}
/// Computes the symbolic inverse of a square matrix, choosing a strategy from
/// [`recognize`].
///
/// * `Diagonal` → reciprocal of each diagonal entry.
/// * `LowRankUpdate` → Sherman–Morrison formula.
/// * `Scalar`, `Circulant`, `General` → closed-form CAS inverse.
///
/// # Errors
/// * [`SymbolicLinalgError::NotSquare`] if `m` is not square.
/// * Any error propagated from the selected strategy (for example a symbolically
///   singular matrix, or a size the CAS path does not support).
pub fn inverse_by_structure(
    m: ArrayView2<Arc<LoweredOp>>,
) -> Result<Array2<Arc<LoweredOp>>, SymbolicLinalgError> {
    let nrows = m.nrows();
    let ncols = m.ncols();
    if nrows != ncols {
        return Err(SymbolicLinalgError::NotSquare {
            rows: nrows,
            cols: ncols,
        });
    }
    let n = nrows;
    match recognize(m) {
        StructureKind::Diagonal => inverse_diagonal(m, n),
        StructureKind::LowRankUpdate { u, v } => inverse_low_rank_update(u, v, n),
        // No dedicated fast path for these yet. They are listed explicitly (no `_`)
        // so that adding a StructureKind variant forces a routing decision here.
        StructureKind::Scalar | StructureKind::Circulant { .. } | StructureKind::General => {
            inverse_general_cas(m, n)
        }
    }
}
fn inverse_diagonal(
m: ArrayView2<Arc<LoweredOp>>,
n: usize,
) -> Result<Array2<Arc<LoweredOp>>, SymbolicLinalgError> {
let zero = Arc::new(lo_const(0.0));
let result = Array2::from_shape_fn((n, n), |(r, c)| {
if r == c {
let diag_entry = cell_op(&m, r, c);
Arc::new(simplify_op(&lo_div(lo_const(1.0), diag_entry)))
} else {
Arc::clone(&zero)
}
});
Ok(result)
}
fn inverse_low_rank_update(
u: Vec<LoweredOp>,
v: Vec<LoweredOp>,
n: usize,
) -> Result<Array2<Arc<LoweredOp>>, SymbolicLinalgError> {
let dot: LoweredOp = (0..n).fold(lo_const(0.0), |acc, i| {
lo_add(acc, lo_mul(v[i].clone(), u[i].clone()))
});
let denom = simplify_op(&lo_add(lo_const(1.0), dot));
let result = Array2::from_shape_fn((n, n), |(r, c)| {
let correction = lo_div(lo_mul(u[r].clone(), v[c].clone()), denom.clone());
let entry = if r == c {
lo_sub(lo_const(1.0), correction)
} else {
lo_sub(lo_const(0.0), correction)
};
Arc::new(simplify_op(&entry))
});
Ok(result)
}
fn inverse_general_cas(
m: ArrayView2<Arc<LoweredOp>>,
n: usize,
) -> Result<Array2<Arc<LoweredOp>>, SymbolicLinalgError> {
match n {
2 => {
let arr: [[LoweredOp; 2]; 2] = [
[cell_op(&m, 0, 0), cell_op(&m, 0, 1)],
[cell_op(&m, 1, 0), cell_op(&m, 1, 1)],
];
match inverse_2x2(&arr) {
InverseResult::Invertible2(inv) => Ok(Array2::from_shape_fn((2, 2), |(r, c)| {
Arc::new(simplify_op(&inv[r][c]))
})),
InverseResult::Singular => Err(SymbolicLinalgError::EvalError(
"matrix is symbolically singular (zero determinant)".to_owned(),
)),
_ => Err(SymbolicLinalgError::EvalError(
"unexpected InverseResult variant for 2x2".to_owned(),
)),
}
}
3 => {
let arr: [[LoweredOp; 3]; 3] = [
[cell_op(&m, 0, 0), cell_op(&m, 0, 1), cell_op(&m, 0, 2)],
[cell_op(&m, 1, 0), cell_op(&m, 1, 1), cell_op(&m, 1, 2)],
[cell_op(&m, 2, 0), cell_op(&m, 2, 1), cell_op(&m, 2, 2)],
];
match inverse_3x3(&arr) {
InverseResult::Invertible3(inv) => Ok(Array2::from_shape_fn((3, 3), |(r, c)| {
Arc::new(simplify_op(&inv[r][c]))
})),
InverseResult::Singular => Err(SymbolicLinalgError::EvalError(
"matrix is symbolically singular (zero determinant)".to_owned(),
)),
_ => Err(SymbolicLinalgError::EvalError(
"unexpected InverseResult variant for 3x3".to_owned(),
)),
}
}
n => Err(SymbolicLinalgError::Unsupported { n, max: 3 }),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    use scirs2_symbolic::eml::eval::{eval_real, EvalCtx};

    /// Constant leaf.
    fn c(v: f64) -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Const(v))
    }

    /// Variable leaf referring to evaluation slot `i`.
    fn var(i: usize) -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Var(i))
    }

    fn mul_arc(a: Arc<LoweredOp>, b: Arc<LoweredOp>) -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Mul(
            Box::new(a.as_ref().clone()),
            Box::new(b.as_ref().clone()),
        ))
    }

    fn add_arc(a: Arc<LoweredOp>, b: Arc<LoweredOp>) -> Arc<LoweredOp> {
        Arc::new(LoweredOp::Add(
            Box::new(a.as_ref().clone()),
            Box::new(b.as_ref().clone()),
        ))
    }

    #[test]
    fn recognize_diagonal_2x2() {
        let zero = c(0.0);
        let mat = Array2::from_shape_fn((2, 2), |(r, col)| match (r, col) {
            (0, 0) => var(0),
            (1, 1) => var(1),
            _ => Arc::clone(&zero),
        });
        assert_eq!(recognize(mat.view()), StructureKind::Diagonal);
    }

    #[test]
    fn recognize_scalar_2x2() {
        let c0 = var(0);
        let mat = Array2::from_shape_fn((2, 2), |_| Arc::clone(&c0));
        assert_eq!(recognize(mat.view()), StructureKind::Scalar);
    }

    #[test]
    fn recognize_single_entry_is_scalar() {
        let mat = Array2::from_shape_fn((1, 1), |_| var(0));
        assert_eq!(recognize(mat.view()), StructureKind::Scalar);
    }

    #[test]
    fn recognize_non_square_and_empty_are_general() {
        let wide = Array2::from_shape_fn((2, 3), |_| var(0));
        assert_eq!(recognize(wide.view()), StructureKind::General);
        let empty = Array2::from_shape_fn((0, 0), |_| var(0));
        assert_eq!(recognize(empty.view()), StructureKind::General);
    }

    #[test]
    fn inverse_not_square_is_error() {
        let wide = Array2::from_shape_fn((2, 3), |_| var(0));
        assert!(matches!(
            inverse_by_structure(wide.view()),
            Err(SymbolicLinalgError::NotSquare { rows: 2, cols: 3 })
        ));
    }

    #[test]
    fn recognize_circulant_2x2() {
        let c0 = var(0);
        let c1 = var(1);
        let mat = Array2::from_shape_fn((2, 2), |(r, col)| {
            if r == col {
                Arc::clone(&c0)
            } else {
                Arc::clone(&c1)
            }
        });
        match recognize(mat.view()) {
            StructureKind::Circulant { first_row } => {
                assert_eq!(first_row.len(), 2);
                assert_eq!(first_row[0].structural_hash(), c0.structural_hash());
                assert_eq!(first_row[1].structural_hash(), c1.structural_hash());
            }
            other => panic!("expected Circulant, got {other:?}"),
        }
    }

    #[test]
    fn recognize_circulant_3x3() {
        let first_row = [var(0), var(1), var(2)];
        let n = 3usize;
        let mat =
            Array2::from_shape_fn((n, n), |(r, col)| Arc::clone(&first_row[(col + n - r) % n]));
        match recognize(mat.view()) {
            StructureKind::Circulant { first_row: fr } => {
                assert_eq!(fr.len(), 3);
                // The reported first row must be row 0 of the input, in order.
                for (got, want) in fr.iter().zip(first_row.iter()) {
                    assert_eq!(got.structural_hash(), want.structural_hash());
                }
            }
            other => panic!("expected Circulant(3×3), got {other:?}"),
        }
    }

    #[test]
    fn recognize_low_rank_update_2x2() {
        let u0 = var(0);
        let u1 = var(1);
        let v0 = var(2);
        let v1 = var(3);
        let mat = Array2::from_shape_fn((2, 2), |(r, col)| match (r, col) {
            (0, 0) => add_arc(c(1.0), mul_arc(Arc::clone(&u0), Arc::clone(&v0))),
            (0, 1) => mul_arc(Arc::clone(&u0), Arc::clone(&v1)),
            (1, 0) => mul_arc(Arc::clone(&u1), Arc::clone(&v0)),
            (1, 1) => add_arc(c(1.0), mul_arc(Arc::clone(&u1), Arc::clone(&v1))),
            _ => unreachable!(),
        });
        match recognize(mat.view()) {
            StructureKind::LowRankUpdate { u, v } => {
                assert_eq!(u.len(), 2);
                assert_eq!(v.len(), 2);
                // Row factors are the left operands, column factors the right operands.
                assert_eq!(u[0].structural_hash(), u0.structural_hash());
                assert_eq!(u[1].structural_hash(), u1.structural_hash());
                assert_eq!(v[0].structural_hash(), v0.structural_hash());
                assert_eq!(v[1].structural_hash(), v1.structural_hash());
            }
            other => panic!("expected LowRankUpdate, got {other:?}"),
        }
    }

    #[test]
    fn inverse_diagonal_numeric_eval() {
        let zero = c(0.0);
        let mat = Array2::from_shape_fn((2, 2), |(r, col)| match (r, col) {
            (0, 0) => var(0),
            (1, 1) => var(1),
            _ => Arc::clone(&zero),
        });
        let inv = inverse_by_structure(mat.view()).expect("inverse");
        let ctx = EvalCtx::new(&[2.0, 3.0]);
        let inv00 = eval_real(inv[[0, 0]].as_ref(), &ctx).expect("eval");
        let inv11 = eval_real(inv[[1, 1]].as_ref(), &ctx).expect("eval");
        let inv01 = eval_real(inv[[0, 1]].as_ref(), &ctx).expect("eval");
        let inv10 = eval_real(inv[[1, 0]].as_ref(), &ctx).expect("eval");
        assert!((inv00 - 0.5).abs() < 1e-10, "inv[0,0]={inv00}");
        assert!((inv11 - 1.0 / 3.0).abs() < 1e-10, "inv[1,1]={inv11}");
        assert!(inv01.abs() < 1e-10, "inv[0,1]={inv01}");
        assert!(inv10.abs() < 1e-10, "inv[1,0]={inv10}");
    }

    #[test]
    fn inverse_low_rank_update_numeric_eval() {
        // M = I + u vᵀ = [[4, 4], [6, 9]], det = 12.
        let u = [c(1.0), c(2.0)];
        let v = [c(3.0), c(4.0)];
        let mat = Array2::from_shape_fn((2, 2), |(r, col)| match (r, col) {
            (0, 0) => add_arc(c(1.0), mul_arc(Arc::clone(&u[0]), Arc::clone(&v[0]))),
            (0, 1) => mul_arc(Arc::clone(&u[0]), Arc::clone(&v[1])),
            (1, 0) => mul_arc(Arc::clone(&u[1]), Arc::clone(&v[0])),
            (1, 1) => add_arc(c(1.0), mul_arc(Arc::clone(&u[1]), Arc::clone(&v[1]))),
            _ => unreachable!(),
        });
        assert!(
            matches!(recognize(mat.view()), StructureKind::LowRankUpdate { .. }),
            "should detect LowRankUpdate"
        );
        let inv = inverse_by_structure(mat.view()).expect("inverse");
        let ctx = EvalCtx::new(&[]);
        let inv00 = eval_real(inv[[0, 0]].as_ref(), &ctx).expect("eval");
        let inv01 = eval_real(inv[[0, 1]].as_ref(), &ctx).expect("eval");
        let inv10 = eval_real(inv[[1, 0]].as_ref(), &ctx).expect("eval");
        let inv11 = eval_real(inv[[1, 1]].as_ref(), &ctx).expect("eval");
        assert!((inv00 - 0.75).abs() < 1e-10, "inv[0,0]={inv00}");
        assert!((inv01 - (-1.0 / 3.0)).abs() < 1e-10, "inv[0,1]={inv01}");
        assert!((inv10 - (-0.5)).abs() < 1e-10, "inv[1,0]={inv10}");
        assert!((inv11 - (1.0 / 3.0)).abs() < 1e-10, "inv[1,1]={inv11}");
    }

    #[test]
    fn recognize_general_2x2() {
        let mat = Array2::from_shape_fn((2, 2), |(r, col)| match (r, col) {
            (0, 0) => var(0),
            (0, 1) => var(1),
            (1, 0) => var(2),
            (1, 1) => c(1.0),
            _ => unreachable!(),
        });
        assert_eq!(recognize(mat.view()), StructureKind::General);
    }
}