use crate::array::Array;
#[allow(unused_imports)] use crate::error::{NumRs2Error, Result};
#[cfg(feature = "lapack")]
use num_traits::{Float, NumCast, Zero};
#[cfg(feature = "lapack")]
use scirs2_core::linalg::{
eig_ndarray, eig_symmetric, eigvals_ndarray, eigvals_symmetric, Eigenvalue,
};
#[allow(unused_imports)] use scirs2_core::ndarray::ArrayView2;
use scirs2_core::Complex;
#[cfg(feature = "lapack")]
use std::fmt::Debug;
/// Result of a general (non-symmetric) eigendecomposition: `(eigenvalues, eigenvectors)`.
///
/// `eig` fills this with a 1-D array of complex eigenvalues and an `n x n` array of
/// complex eigenvector entries. NOTE(review): presumably one eigenvector per column
/// (LAPACK convention) — confirm against `scirs2_core::linalg::eig_ndarray`.
pub type EigResult<T> = (Array<Complex<T>>, Array<Complex<T>>);
#[cfg(feature = "lapack")]
/// Eigendecomposition of a real symmetric matrix.
///
/// The input is converted entry-by-entry to `f64` and passed to
/// `scirs2_core::linalg::eig_symmetric`. The eigenvalues and eigenvectors are then
/// converted back to `T`. Eigenvalues come back in whatever order the backend returns.
/// NOTE(review): eigenvectors are presumably stored one per column (LAPACK
/// convention) — confirm against `eig_symmetric`.
///
/// # Arguments
/// * `a` - a 2-D square array. Symmetry is assumed, not checked here.
/// * `_uplo` - accepted for NumPy-style API compatibility; currently ignored.
///
/// # Returns
/// `(eigenvalues, eigenvectors)`: a 1-D array of length `n` and an `n x n` array.
///
/// # Errors
/// * `DimensionMismatch` if `a` is not a 2-D square array.
/// * Any error from `Array::view_2d`.
/// * `ComputationError` if an entry cannot be converted to or from `f64`, or if the
///   backend decomposition fails.
pub fn eigh<T>(a: &Array<T>, _uplo: &str) -> Result<(Array<T>, Array<T>)>
where
    T: Float + Clone + Debug + 'static,
{
    let shape = a.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(NumRs2Error::DimensionMismatch(
            "eigendecomposition requires a square matrix".to_string(),
        ));
    }
    let a_view: ArrayView2<T> = a.view_2d()?;
    let n = a_view.nrows();

    // One conversion path serves every `T`, including `f64`. The decomposition is
    // O(n^3), so this O(n^2) copy is negligible, and `f64 -> f64` conversion is
    // exact, so no `unsafe` type punning is needed to special-case `f64`.
    let mut a_f64 = scirs2_core::ndarray::Array2::<f64>::zeros((n, n));
    for (dst, src) in a_f64.iter_mut().zip(a_view.iter()) {
        *dst = src.to_f64().ok_or_else(|| {
            NumRs2Error::ComputationError("Cannot convert to f64".to_string())
        })?;
    }

    let result = eig_symmetric(&a_f64).map_err(|e| {
        NumRs2Error::ComputationError(format!("Eigendecomposition failed: {:?}", e))
    })?;

    let eigenvalues = result
        .eigenvalues
        .iter()
        .map(|&v| {
            T::from(v).ok_or_else(|| NumRs2Error::ComputationError("Conversion failed".to_string()))
        })
        .collect::<Result<Vec<T>>>()?;
    // `iter()` walks the n x n matrix in logical row-major order, which matches the
    // `reshape(&[n, n])` below regardless of the backend's memory layout.
    let eigenvectors = result
        .eigenvectors
        .iter()
        .map(|&v| {
            T::from(v).ok_or_else(|| NumRs2Error::ComputationError("Conversion failed".to_string()))
        })
        .collect::<Result<Vec<T>>>()?;

    Ok((
        Array::from_vec(eigenvalues),
        Array::from_vec(eigenvectors).reshape(&[n, n]),
    ))
}
#[cfg(feature = "lapack")]
/// Eigenvalues of a real symmetric matrix.
///
/// The input is converted entry-by-entry to `f64` and passed to
/// `scirs2_core::linalg::eigvals_symmetric`. The results are converted back to `T`
/// and returned in whatever order the backend produces them.
///
/// # Arguments
/// * `a` - a 2-D square array. Symmetry is assumed, not checked here.
/// * `_uplo` - accepted for NumPy-style API compatibility; currently ignored.
///
/// # Returns
/// A 1-D array holding the eigenvalues.
///
/// # Errors
/// * `DimensionMismatch` if `a` is not a 2-D square array.
/// * Any error from `Array::view_2d`.
/// * `ComputationError` if an entry cannot be converted to or from `f64`, or if the
///   backend computation fails.
pub fn eigvalsh<T>(a: &Array<T>, _uplo: &str) -> Result<Array<T>>
where
    T: Float + Clone + Debug + 'static,
{
    let shape = a.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(NumRs2Error::DimensionMismatch(
            "eigendecomposition requires a square matrix".to_string(),
        ));
    }
    let a_view: ArrayView2<T> = a.view_2d()?;
    let n = a_view.nrows();

    // One conversion path serves every `T`, including `f64`. The computation is
    // O(n^3), so this O(n^2) copy is negligible, and `f64 -> f64` conversion is
    // exact, so no `unsafe` type punning is needed to special-case `f64`.
    let mut a_f64 = scirs2_core::ndarray::Array2::<f64>::zeros((n, n));
    for (dst, src) in a_f64.iter_mut().zip(a_view.iter()) {
        *dst = src.to_f64().ok_or_else(|| {
            NumRs2Error::ComputationError("Cannot convert to f64".to_string())
        })?;
    }

    let result = eigvals_symmetric(&a_f64).map_err(|e| {
        NumRs2Error::ComputationError(format!("Eigenvalue computation failed: {:?}", e))
    })?;

    let eigenvalues = result
        .iter()
        .map(|&v| {
            T::from(v).ok_or_else(|| NumRs2Error::ComputationError("Conversion failed".to_string()))
        })
        .collect::<Result<Vec<T>>>()?;
    Ok(Array::from_vec(eigenvalues))
}
#[cfg(feature = "lapack")]
/// General (non-symmetric) eigendecomposition of a square real matrix.
///
/// Every entry is widened to `f64` and handed to `scirs2_core::linalg::eig_ndarray`.
/// The complex eigenvalues and eigenvector entries it returns are narrowed back to
/// `Complex<T>`.
///
/// # Returns
/// `(eigenvalues, eigenvectors)`: a 1-D array of length `n` and an `n x n` array,
/// filled row-major from the backend's real and imaginary eigenvector matrices.
///
/// # Errors
/// * `DimensionMismatch` if `a` is not a 2-D square array.
/// * Any error from `Array::view_2d`.
/// * `ComputationError` on `f64` conversion failure, on backend failure, or if the
///   backend omits the real or imaginary eigenvector part.
pub fn eig<T>(a: &Array<T>) -> Result<EigResult<T>>
where
    T: Float + Clone + Debug,
{
    let dims = a.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(NumRs2Error::DimensionMismatch(
            "eigendecomposition requires a square matrix".to_string(),
        ));
    }
    let view: ArrayView2<T> = a.view_2d()?;
    let size = view.nrows();

    // Widen the input in row-major order; the first unconvertible entry aborts.
    let mut input = scirs2_core::ndarray::Array2::<f64>::zeros((size, size));
    for ((row, col), slot) in input.indexed_iter_mut() {
        *slot = view[[row, col]].to_f64().ok_or_else(|| {
            NumRs2Error::ComputationError("Cannot convert to f64".to_string())
        })?;
    }

    let decomposition = eig_ndarray(&input).map_err(|e| {
        NumRs2Error::ComputationError(format!("Eigendecomposition failed: {:?}", e))
    })?;

    // Narrow one backend `f64` back to `T`.
    let narrow = |x: f64| -> Result<T> {
        T::from(x).ok_or_else(|| NumRs2Error::ComputationError("Conversion failed".to_string()))
    };

    let values = decomposition
        .eigenvalues
        .iter()
        .map(|ev| Ok(Complex::new(narrow(ev.real)?, narrow(ev.imag)?)))
        .collect::<Result<Vec<Complex<T>>>>()?;

    let real_parts = decomposition.eigenvectors_real.ok_or_else(|| {
        NumRs2Error::ComputationError("Eigenvectors real part missing".to_string())
    })?;
    let imag_parts = decomposition.eigenvectors_imag.ok_or_else(|| {
        NumRs2Error::ComputationError("Eigenvectors imag part missing".to_string())
    })?;

    // Visit (row, col) in row-major order so the flat buffer matches the reshape below.
    let vectors = (0..size)
        .flat_map(|row| (0..size).map(move |col| (row, col)))
        .map(|(row, col)| {
            Ok(Complex::new(
                narrow(real_parts[[row, col]])?,
                narrow(imag_parts[[row, col]])?,
            ))
        })
        .collect::<Result<Vec<Complex<T>>>>()?;

    Ok((
        Array::from_vec(values),
        Array::from_vec(vectors).reshape(&[size, size]),
    ))
}
#[cfg(feature = "lapack")]
/// Eigenvalues of a general (non-symmetric) square real matrix.
///
/// Every entry is widened to `f64` and passed to `scirs2_core::linalg::eigvals_ndarray`.
/// The complex results are narrowed back to `Complex<T>` in the order the backend
/// returns them.
///
/// # Errors
/// * `DimensionMismatch` if `a` is not a 2-D square array.
/// * Any error from `Array::view_2d`.
/// * `ComputationError` on `f64` conversion failure or backend failure.
pub fn eigvals<T>(a: &Array<T>) -> Result<Array<Complex<T>>>
where
    T: Float + Clone + Debug,
{
    let dims = a.shape();
    let is_square = dims.len() == 2 && dims[0] == dims[1];
    if !is_square {
        return Err(NumRs2Error::DimensionMismatch(
            "eigendecomposition requires a square matrix".to_string(),
        ));
    }
    let view: ArrayView2<T> = a.view_2d()?;
    let size = view.nrows();

    // Widen the input in row-major order; the first unconvertible entry aborts.
    let mut input = scirs2_core::ndarray::Array2::<f64>::zeros((size, size));
    for ((row, col), slot) in input.indexed_iter_mut() {
        *slot = view[[row, col]].to_f64().ok_or_else(|| {
            NumRs2Error::ComputationError("Cannot convert to f64".to_string())
        })?;
    }

    let spectrum = eigvals_ndarray(&input).map_err(|e| {
        NumRs2Error::ComputationError(format!("Eigenvalue computation failed: {:?}", e))
    })?;

    // Narrow one backend `f64` back to `T`.
    let narrow = |x: f64| -> Result<T> {
        T::from(x).ok_or_else(|| NumRs2Error::ComputationError("Conversion failed".to_string()))
    };

    let values = spectrum
        .iter()
        .map(|ev| Ok(Complex::new(narrow(ev.real)?, narrow(ev.imag)?)))
        .collect::<Result<Vec<Complex<T>>>>()?;
    Ok(Array::from_vec(values))
}
#[cfg(feature = "lapack")]
/// Reports whether a real symmetric matrix is positive definite.
///
/// The eigenvalues are computed with `eigvalsh`. The result is `true` only if every
/// eigenvalue is strictly greater than zero:
/// * A `NaN` eigenvalue compares false, so it yields `false`.
/// * If `eigvalsh` returns no eigenvalues, the result is vacuously `true`.
///
/// NOTE(review): the input is not checked for symmetry; only the values returned by
/// `eigvalsh` are inspected.
///
/// # Errors
/// Propagates any error from `eigvalsh`, such as non-square input or backend failure.
pub fn is_positive_definite<T>(a: &Array<T>) -> Result<bool>
where
    T: Float + Clone + Debug + PartialOrd + Zero + 'static,
{
    let spectrum = eigvalsh(a, "lower")?.to_vec();
    let threshold = T::zero();
    Ok(spectrum.into_iter().all(|lambda| lambda > threshold))
}
#[cfg(feature = "lapack")]
/// Method-call forms of this module's eigenvalue routines.
///
/// Each method forwards directly to the module-level function of the same purpose,
/// so error conditions are exactly those of the forwarded-to function.
impl<T> Array<T>
where
    T: Float + Clone + Debug + 'static,
{
    /// Symmetric eigendecomposition; forwards to the module-level `eigh`.
    pub fn eigh(&self, uplo: &str) -> Result<(Array<T>, Array<T>)> {
        self::eigh(self, uplo)
    }

    /// Symmetric eigenvalues only; forwards to the module-level `eigvalsh`.
    pub fn eigvalsh(&self, uplo: &str) -> Result<Array<T>> {
        self::eigvalsh(self, uplo)
    }

    /// General (non-symmetric) eigendecomposition; forwards to the module-level `eig`.
    pub fn eig_general(&self) -> Result<EigResult<T>> {
        self::eig(self)
    }

    /// General (non-symmetric) eigenvalues; forwards to the module-level `eigvals`.
    pub fn eigvals(&self) -> Result<Array<Complex<T>>> {
        self::eigvals(self)
    }

    /// Positive-definiteness test; forwards to the module-level `is_positive_definite`.
    pub fn is_positive_definite(&self) -> Result<bool>
    where
        T: PartialOrd + Zero,
    {
        self::is_positive_definite(self)
    }
}
#[cfg(all(test, feature = "lapack"))]
mod tests {
    use super::*;

    #[test]
    fn test_symmetric_eigenvalues() {
        // Tridiagonal [2, -1] matrix: eigenvalues are 2 - sqrt(2), 2, 2 + sqrt(2).
        let a =
            Array::from_vec(vec![2.0, -1.0, 0.0, -1.0, 2.0, -1.0, 0.0, -1.0, 2.0]).reshape(&[3, 3]);
        let eigenvalues = eigvalsh(&a, "lower").expect("eigvalsh should succeed");
        assert_eq!(eigenvalues.shape(), vec![3]);
        let eig_data = eigenvalues.to_vec();
        let expected = [2.0 - 2.0_f64.sqrt(), 2.0, 2.0 + 2.0_f64.sqrt()];
        for i in 0..3 {
            assert!(num_traits::Float::abs(eig_data[i] - expected[i]) < 1e-10);
        }
    }

    #[test]
    fn test_symmetric_eigenvectors() {
        let a = Array::from_vec(vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]).reshape(&[3, 3]);
        let (eigenvalues, eigenvectors) = eigh(&a, "lower").expect("eigh should succeed");
        assert_eq!(eigenvalues.shape(), vec![3]);
        assert_eq!(eigenvectors.shape(), vec![3, 3]);
        let eig_data = eigenvalues.to_vec();
        assert!(num_traits::Float::abs(eig_data[0] - 1.0) < 1e-10);
        assert!(num_traits::Float::abs(eig_data[1] - 2.0) < 1e-10);
        assert!(num_traits::Float::abs(eig_data[2] - 3.0) < 1e-10);
        // Each column should have unit Euclidean norm.
        let vecs = eigenvectors.to_vec();
        for i in 0..3 {
            let mut norm_squared = 0.0;
            for j in 0..3 {
                norm_squared += vecs[j * 3 + i] * vecs[j * 3 + i];
            }
            assert!(num_traits::Float::abs(norm_squared - 1.0) < 1e-10);
        }
    }

    #[test]
    fn test_symmetric_f32_input() {
        // Exercises the non-f64 conversion path: [[2, 1], [1, 2]] has eigenvalues 1 and 3.
        let a = Array::from_vec(vec![2.0_f32, 1.0, 1.0, 2.0]).reshape(&[2, 2]);
        let mut values = eigvalsh(&a, "lower")
            .expect("eigvalsh should succeed for f32")
            .to_vec();
        values.sort_by(|x, y| x.partial_cmp(y).expect("eigenvalues should not be NaN"));
        assert_eq!(values.len(), 2);
        assert!(num_traits::Float::abs(values[0] - 1.0) < 1e-5);
        assert!(num_traits::Float::abs(values[1] - 3.0) < 1e-5);
        let (_, vectors) = eigh(&a, "lower").expect("eigh should succeed for f32");
        assert_eq!(vectors.shape(), vec![2, 2]);
    }

    #[test]
    fn test_non_square_rejected() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
        assert!(matches!(
            eigh(&a, "lower"),
            Err(NumRs2Error::DimensionMismatch(_))
        ));
        assert!(matches!(
            eigvalsh(&a, "lower"),
            Err(NumRs2Error::DimensionMismatch(_))
        ));
        assert!(matches!(eig(&a), Err(NumRs2Error::DimensionMismatch(_))));
        assert!(matches!(eigvals(&a), Err(NumRs2Error::DimensionMismatch(_))));
    }

    #[test]
    fn test_general_eigenvalues() {
        // Dominant eigenvalue of this matrix is about 16.12.
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]).reshape(&[3, 3]);
        let eigenvalues = eigvals(&a).expect("eigvals should succeed for general matrix");
        assert_eq!(eigenvalues.shape(), vec![3]);
        assert_eq!(eigenvalues.size(), 3);
        let has_large_eigenvalue = eigenvalues.to_vec().iter().any(|ev| ev.re > 15.0);
        assert!(has_large_eigenvalue);
    }

    #[test]
    fn test_complex_eigenvalues() {
        // A rotation by theta has eigenvalues e^{+i theta} and e^{-i theta}, both of modulus 1.
        let theta = std::f64::consts::PI / 4.0;
        let a = Array::from_vec(vec![theta.cos(), -theta.sin(), theta.sin(), theta.cos()])
            .reshape(&[2, 2]);
        let eigenvalues = eigvals(&a).expect("eigvals should succeed for rotation matrix");
        assert_eq!(eigenvalues.shape(), vec![2]);
        let eig_data = eigenvalues.to_vec();
        for eigenvalue in eig_data {
            let magnitude = (eigenvalue.re * eigenvalue.re + eigenvalue.im * eigenvalue.im).sqrt();
            assert!((magnitude - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_positive_definite() {
        let a =
            Array::from_vec(vec![2.0, -1.0, 0.0, -1.0, 2.0, -1.0, 0.0, -1.0, 2.0]).reshape(&[3, 3]);
        let is_pd = a
            .is_positive_definite()
            .expect("is_positive_definite should succeed");
        assert!(is_pd);
        // Zero on the diagonal rules out positive definiteness.
        let b = Array::from_vec(vec![1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0]).reshape(&[3, 3]);
        let is_pd = b
            .is_positive_definite()
            .expect("is_positive_definite should succeed");
        assert!(!is_pd);
    }
}