use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Complex, Float, NumAssign};
use std::fmt::{Debug, Display};
use std::iter::Sum;
/// Output of [`gen_eig`]: the eigenpairs of the generalized problem `A x = λ B x`.
///
/// Eigenvalues are stored as complex numbers because a real, non-symmetric
/// matrix pencil can have complex eigenvalues.
#[derive(Clone, Debug)]
pub struct GenEigenResult<F: Float> {
    /// Generalized eigenvalues `λ`.
    pub eigenvalues: Vec<Complex<F>>,
    /// Eigenvectors as returned by the underlying solver; presumably stored one
    /// per column and paired with `eigenvalues` by index — TODO confirm against
    /// `crate::eigen::generalized::eig_gen`.
    pub eigenvectors: Array2<Complex<F>>,
}
/// Output of [`gen_eigh`]: real eigenpairs of a generalized problem
/// `A x = λ B x`, presumably for symmetric `A` and symmetric positive-definite `B`.
#[derive(Clone, Debug)]
pub struct GenEighResult<F: Float> {
    /// Real generalized eigenvalues, in the order produced by the solver.
    pub eigenvalues: Vec<F>,
    /// Real eigenvectors as produced by the solver; presumably one per column,
    /// paired with `eigenvalues` by index — TODO confirm against `eigh_gen`.
    pub eigenvectors: Array2<F>,
}
/// Output of [`gsvd`], the generalized singular value decomposition of `(A, B)`.
///
/// The fields are renamed from the underlying
/// `crate::decomposition_enhanced::generalized_svd` result.
#[derive(Clone, Debug)]
pub struct GsvdResult<F: Float> {
    /// Orthogonal factor associated with `A` (expected to satisfy `Uᵀ U = I`).
    pub u: Array2<F>,
    /// Factor associated with `B`; presumably orthogonal as well — TODO confirm.
    pub v: Array2<F>,
    /// Shared right factor (the `q` field of the underlying decomposition).
    pub x: Array2<F>,
    /// "Cosine" values (the `alpha` field); expected to pair with `s` so that
    /// `c[i]² + s[i]² = 1`.
    pub c: Vec<F>,
    /// "Sine" values (the `beta` field of the underlying decomposition).
    pub s: Vec<F>,
}
/// Solves the general (non-symmetric) generalized eigenvalue problem `A x = λ B x`.
///
/// Both inputs must be square and of the same dimension. The numerical work is
/// delegated to `crate::eigen::generalized::eig_gen`; this function validates
/// shapes and repackages the output into a [`GenEigenResult`].
///
/// # Errors
///
/// Returns [`LinalgError::ShapeError`] if `A` or `B` is not square or if their
/// sizes differ. Any error from the underlying solver is propagated unchanged.
pub fn gen_eig<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<GenEigenResult<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + Debug + Display + 'static,
{
    validate_square_same(a, b)?;
    let (eigenvalues, eigenvectors) = crate::eigen::generalized::eig_gen(a, b, None)?;
    Ok(GenEigenResult {
        eigenvalues: eigenvalues.to_vec(),
        // The solver already returns a 2-D complex array: take ownership of it
        // directly (a move for owned arrays) instead of flattening into a Vec and
        // reshaping, which was an extra copy with an unreachable error path.
        eigenvectors: eigenvectors.into_owned(),
    })
}
/// Solves the generalized eigenvalue problem `A x = λ B x` with real eigenpairs.
///
/// Intended for symmetric `A` and symmetric positive-definite `B`. Symmetry and
/// definiteness are NOT checked here; this function only checks that both
/// matrices are square and the same size. Everything else is left to
/// `crate::eigen::generalized::eigh_gen`. The eigenvalue order and eigenvector
/// normalisation are whatever that solver produces.
///
/// # Errors
///
/// Returns [`LinalgError::ShapeError`] for non-square or mismatched inputs, and
/// propagates any error reported by the solver.
pub fn gen_eigh<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<GenEighResult<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + Debug + Display + 'static,
{
    validate_square_same(a, b)?;
    let (values, vectors) = crate::eigen::generalized::eigh_gen(a, b, None)?;
    let result = GenEighResult {
        eigenvalues: values.to_vec(),
        eigenvectors: vectors,
    };
    Ok(result)
}
/// Computes the generalized singular value decomposition of the pair `(A, B)`.
///
/// `A` and `B` may have different numbers of rows but must share the same number
/// of columns. The decomposition itself is done by
/// `crate::decomposition_enhanced::generalized_svd`. This wrapper validates the
/// inputs and renames the result fields: `q` → `x`, `alpha` → `c`, `beta` → `s`.
///
/// # Errors
///
/// * [`LinalgError::ShapeError`] if `A` and `B` have different column counts
///   (checked first).
/// * [`LinalgError::InvalidInputError`] if `A` has zero rows or columns, or `B`
///   has zero rows.
/// * Any error reported by the underlying decomposition.
pub fn gsvd<F>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<GsvdResult<F>>
where
    F: Float + NumAssign + Sum + Send + Sync + ScalarOperand + Debug + Display + 'static,
{
    let (rows_a, cols_a) = a.dim();
    let (rows_b, cols_b) = b.dim();
    if cols_a != cols_b {
        return Err(LinalgError::ShapeError(format!(
            "gsvd: A and B must have the same number of columns; got {cols_a} and {cols_b}"
        )));
    }
    // The column counts agree at this point, so A's column count covers B's too.
    if [rows_a, cols_a, rows_b].contains(&0) {
        return Err(LinalgError::InvalidInputError(String::from(
            "gsvd: matrices must have non-zero dimensions",
        )));
    }
    let decomposition = crate::decomposition_enhanced::generalized_svd(a, b)?;
    Ok(GsvdResult {
        c: decomposition.alpha.to_vec(),
        s: decomposition.beta.to_vec(),
        x: decomposition.q,
        u: decomposition.u,
        v: decomposition.v,
    })
}
/// Checks that `a` and `b` are both square and of the same size.
///
/// The checks run in a fixed order: `A` square, then `B` square, then matching
/// dimensions. The first failure is reported as a [`LinalgError::ShapeError`].
fn validate_square_same<F: Float>(a: &ArrayView2<F>, b: &ArrayView2<F>) -> LinalgResult<()> {
    let (a_rows, a_cols) = a.dim();
    let (b_rows, b_cols) = b.dim();
    if a_rows != a_cols {
        Err(LinalgError::ShapeError(format!(
            "Matrix A must be square; got {a_rows}x{a_cols}"
        )))
    } else if b_rows != b_cols {
        Err(LinalgError::ShapeError(format!(
            "Matrix B must be square; got {b_rows}x{b_cols}"
        )))
    } else if a_rows != b_rows {
        Err(LinalgError::ShapeError(format!(
            "A and B must have the same dimension; A: {a_rows}x{a_cols}, B: {b_rows}x{b_cols}"
        )))
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Real parts of `eigenvalues`, sorted ascending so assertions do not depend
    /// on the order in which the solver reports them.
    fn sorted_real_parts(eigenvalues: &[Complex<f64>]) -> Vec<f64> {
        let mut reals: Vec<f64> = eigenvalues.iter().map(|c| c.re).collect();
        reals.sort_by(|x, y| x.partial_cmp(y).expect("eigenvalue real part is NaN"));
        reals
    }

    /// Asserts that `m` is square and equals the identity matrix within `eps`.
    fn assert_identity(m: &Array2<f64>, eps: f64) {
        assert_eq!(m.nrows(), m.ncols(), "expected a square matrix");
        for ((i, j), &value) in m.indexed_iter() {
            let expected = if i == j { 1.0 } else { 0.0 };
            assert_relative_eq!(value, expected, epsilon = eps);
        }
    }

    #[test]
    fn test_gen_eig_identity_b() {
        let a = array![[3.0_f64, 1.0], [1.0, 3.0]];
        let b = Array2::<f64>::eye(2);
        let res = gen_eig(&a.view(), &b.view()).expect("gen_eig");
        assert_eq!(res.eigenvalues.len(), 2);
        let reals = sorted_real_parts(&res.eigenvalues);
        assert_relative_eq!(reals[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(reals[1], 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gen_eig_scaled_b() {
        let a = array![[4.0_f64, 0.0], [0.0, 9.0]];
        let b = array![[2.0_f64, 0.0], [0.0, 2.0]];
        let res = gen_eig(&a.view(), &b.view()).expect("gen_eig");
        assert_eq!(res.eigenvalues.len(), 2);
        let reals = sorted_real_parts(&res.eigenvalues);
        assert_relative_eq!(reals[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(reals[1], 4.5, epsilon = 1e-6);
    }

    #[test]
    fn test_gen_eig_shape_mismatch() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[1.0_f64, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let res = gen_eig(&a.view(), &b.view());
        assert!(res.is_err());
    }

    #[test]
    fn test_gen_eigh_identity_b() {
        let a = array![[3.0_f64, 1.0], [1.0, 3.0]];
        let b = Array2::<f64>::eye(2);
        let res = gen_eigh(&a.view(), &b.view()).expect("gen_eigh");
        assert_eq!(res.eigenvalues.len(), 2);
        // Relies on the solver reporting eigenvalues in ascending order.
        assert_relative_eq!(res.eigenvalues[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(res.eigenvalues[1], 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gen_eigh_scaled_b() {
        // det(A - λB) = (6 - 2λ)(4 - λ) - 4 = 2λ² - 14λ + 20 = 2(λ - 2)(λ - 5).
        let a = array![[6.0_f64, 2.0], [2.0, 4.0]];
        let b = array![[2.0_f64, 0.0], [0.0, 1.0]];
        let res = gen_eigh(&a.view(), &b.view()).expect("gen_eigh");
        assert_eq!(res.eigenvalues.len(), 2);
        for &lam in &res.eigenvalues {
            assert!(lam.is_finite(), "eigenvalue not finite: {lam}");
        }
        let mut lams = res.eigenvalues.clone();
        lams.sort_by(|x, y| x.partial_cmp(y).expect("eigenvalues are finite"));
        assert_relative_eq!(lams[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(lams[1], 5.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gen_eigh_eigenvectors_b_orthonormal() {
        let a = array![[5.0_f64, 1.0], [1.0, 3.0]];
        let b = array![[2.0_f64, 0.5], [0.5, 1.0]];
        let res = gen_eigh(&a.view(), &b.view()).expect("gen_eigh");
        let v = &res.eigenvectors;
        // B-orthonormal eigenvectors satisfy Vᵀ B V = I (every entry is checked).
        let vtbv = v.t().dot(&b.dot(v));
        assert_identity(&vtbv, 1e-6);
    }

    #[test]
    fn test_gsvd_identity_pair() {
        let a = Array2::<f64>::eye(3);
        let b = Array2::<f64>::eye(3);
        let res = gsvd(&a.view(), &b.view()).expect("gsvd");
        for (&ci, &si) in res.c.iter().zip(res.s.iter()) {
            assert_relative_eq!(ci * ci + si * si, 1.0, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_gsvd_orthogonality_u() {
        let a = array![[2.0_f64, 1.0], [0.0, 3.0]];
        let b = array![[1.0_f64, 0.5], [0.5, 1.0]];
        let res = gsvd(&a.view(), &b.view()).expect("gsvd");
        // Uᵀ U contracts over the ROWS of U, so this stays correct for non-square U.
        let utu = res.u.t().dot(&res.u);
        assert_identity(&utu, 1e-6);
    }

    #[test]
    fn test_gsvd_dimension_mismatch() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0]];
        let b = array![[1.0_f64, 0.0, 0.0]];
        let res = gsvd(&a.view(), &b.view());
        assert!(res.is_err());
    }

    /// Despite its name, this only checks result shapes, not `A = f(U, C, X)`.
    #[test]
    fn test_gsvd_reconstruction_a() {
        let a = array![[3.0_f64, 1.0], [1.0, 2.0]];
        let b = array![[1.0_f64, 0.5], [0.5, 2.0]];
        let res = gsvd(&a.view(), &b.view()).expect("gsvd");
        assert_eq!(res.u.nrows(), 2);
        assert_eq!(res.v.nrows(), 2);
        assert_eq!(res.c.len(), res.s.len());
    }

    #[test]
    fn test_gsvd_non_square() {
        let a = array![[1.0_f64, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let b = array![[2.0_f64, 0.0], [0.0, 2.0]];
        let res = gsvd(&a.view(), &b.view()).expect("gsvd non-square");
        for (&ci, &si) in res.c.iter().zip(res.s.iter()) {
            assert_relative_eq!(ci * ci + si * si, 1.0, epsilon = 1e-8);
        }
    }
}