use faer_traits::ComplexField;
use mdarray::{DSlice, Dense, Layout, tensor};
use mdarray_linalg::eig::{Eig, EigDecomp, EigError, EigResult, SchurError, SchurResult};
use num_complex::{Complex, ComplexFloat};
use crate::{Faer, into_faer, into_faer_mut};
/// Converts one complex scalar produced by faer into a `Complex` whose parts
/// have type `<$t as ComplexFloat>::Real`.
///
/// `$val` must evaluate to a value or reference exposing `re` and `im` fields;
/// it is evaluated exactly once. The parts are copied with `transmute_copy`
/// because the compiler cannot prove that faer's real type for `$t` is the same
/// type as `ComplexFloat::Real`, even though in practice both are the same
/// float type.
///
/// # Panics
///
/// Panics if the source real type and `<$t as ComplexFloat>::Real` differ in
/// size, since reinterpreting them would be undefined behaviour.
macro_rules! complex_from_faer {
    ($val:expr, $t:ty) => {{
        // Bind once so the argument (typically an indexing expression) is not
        // evaluated twice, once per component.
        let value = $val;
        // `transmute_copy` reads `size_of::<Dst>()` bytes from the source; a size
        // mismatch would be UB, so reject it. Both sides are constants after
        // monomorphisation, so this check is free in optimised builds.
        assert_eq!(
            ::std::mem::size_of_val(&value.re),
            ::std::mem::size_of::<<$t as ComplexFloat>::Real>(),
            "complex_from_faer!: faer real type and ComplexFloat::Real differ in size"
        );
        // SAFETY: the sizes were checked above. The crate assumes the two real
        // types are the same float type (f32/f64).
        // NOTE(review): confirm this holds for every `T` the backend is
        // instantiated with.
        let re: <$t as ComplexFloat>::Real = unsafe { ::std::mem::transmute_copy(&value.re) };
        // SAFETY: `im` has the same type as `re`, so the check above covers it too.
        let im: <$t as ComplexFloat>::Real = unsafe { ::std::mem::transmute_copy(&value.im) };
        Complex::new(re, im)
    }};
}
/// Eigen-decomposition routines backed by faer.
///
/// Every routine rejects non-square input with `EigError::NotSquareMatrix`.
/// Any faer failure is reported as `BackendDidNotConverge { iterations: 0 }`,
/// because faer's error carries no iteration count. Results are copied into
/// freshly allocated mdarray tensors: eigenvalues as a `1 x n` tensor and
/// eigenvectors as an `n x n` tensor.
impl<T> Eig<T> for Faer
where
    T: ComplexFloat
        + ComplexField
        + Default
        + std::convert::From<<T as num_complex::ComplexFloat>::Real>
        + 'static,
{
    /// General eigen-decomposition via faer's `eigen()`.
    ///
    /// Returns the eigenvalues and the right eigenvectors. The eigenvector
    /// matrix is copied entry-for-entry from faer's `U()`, presumably one
    /// eigenvector per column (faer convention; confirm).
    fn eig<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
        let n = square_order(a)?;
        let matrix = into_faer(a);
        let decomposition = matrix
            .eigen()
            .map_err(|_| EigError::BackendDidNotConverge { iterations: 0 })?;
        let values = decomposition.S();
        let vectors = decomposition.U();

        let zero = complex_zero::<T>();
        let mut eigenvalues = tensor![[zero; n]; 1];
        let mut right = tensor![[zero; n]; n];
        for col in 0..n {
            eigenvalues[[0, col]] = complex_from_faer!(&values[col], T);
            for row in 0..n {
                right[[row, col]] = complex_from_faer!(&vectors[(row, col)], T);
            }
        }

        Ok(EigDecomp {
            eigenvalues,
            left_eigenvectors: None,
            right_eigenvectors: Some(right),
        })
    }

    /// Left and right eigenvectors are not implemented for this backend yet;
    /// calling this panics via `todo!()`.
    fn eig_full<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> Result<EigDecomp<T>, EigError> {
        todo!();
    }

    /// Eigenvalues only, via faer's `eigenvalues()`. Both eigenvector fields
    /// are `None`.
    fn eig_values<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
        let n = square_order(a)?;
        let matrix = into_faer(a);
        let values = matrix
            .eigenvalues()
            .map_err(|_| EigError::BackendDidNotConverge { iterations: 0 })?;

        let zero = complex_zero::<T>();
        let mut eigenvalues = tensor![[zero; n]; 1];
        for k in 0..n {
            eigenvalues[[0, k]] = complex_from_faer!(&values[k], T);
        }

        Ok(EigDecomp {
            eigenvalues,
            left_eigenvectors: None,
            right_eigenvectors: None,
        })
    }

    /// Self-adjoint eigen-decomposition via faer's `self_adjoint_eigen`,
    /// called with `faer::Side::Lower`. Faer presumably reads only the lower
    /// triangle in that mode; confirm against the faer documentation.
    ///
    /// Only the real part of each eigenvalue is kept; the imaginary part is
    /// set to zero.
    fn eigh<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
        let n = square_order(a)?;
        let matrix = into_faer(a);
        let decomposition = matrix
            .self_adjoint_eigen(faer::Side::Lower)
            .map_err(|_| EigError::BackendDidNotConverge { iterations: 0 })?;
        let values = decomposition.S();
        let vectors = decomposition.U();
        let zero = complex_zero::<T>();

        let mut eigenvalues = tensor![[zero; n]; 1];
        {
            // Write through a faer view of the output tensor; the view's
            // borrow ends with this scope.
            let mut out = into_faer_mut(&mut eigenvalues);
            for k in 0..n {
                out[(0, k)] = Complex::new(values[k].re(), zero.re);
            }
        }

        let mut right = tensor![[zero; n]; n];
        {
            let mut out = into_faer_mut(&mut right);
            for row in 0..n {
                for col in 0..n {
                    let entry = vectors[(row, col)];
                    out[(row, col)] = Complex::new(entry.re(), entry.im());
                }
            }
        }

        Ok(EigDecomp {
            eigenvalues,
            left_eigenvectors: None,
            right_eigenvectors: Some(right),
        })
    }

    /// Delegates to [`Self::eigh`]; symmetric input goes through the same
    /// self-adjoint path.
    fn eigs<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> EigResult<T> {
        self.eigh(a)
    }

    /// Not implemented for this backend yet; panics via `todo!()`.
    fn schur<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> SchurResult<T> {
        todo!();
    }

    /// Not implemented for this backend yet; panics via `todo!()`.
    fn schur_overwrite<L: Layout>(
        &self,
        _a: &mut DSlice<T, 2, L>,
        _t: &mut DSlice<T, 2, Dense>,
        _z: &mut DSlice<T, 2, Dense>,
    ) -> Result<(), SchurError> {
        todo!();
    }

    /// Not implemented for this backend yet; panics via `todo!()`.
    fn schur_complex<L: Layout>(&self, _a: &mut DSlice<T, 2, L>) -> SchurResult<T> {
        todo!();
    }

    /// Not implemented for this backend yet; panics via `todo!()`.
    fn schur_complex_overwrite<L: Layout>(
        &self,
        _a: &mut DSlice<T, 2, L>,
        _t: &mut DSlice<T, 2, Dense>,
        _z: &mut DSlice<T, 2, Dense>,
    ) -> Result<(), SchurError> {
        todo!();
    }
}

/// Returns the order `n` of a square matrix, or `EigError::NotSquareMatrix`
/// when the row and column counts differ.
fn square_order<T, L: Layout>(a: &DSlice<T, 2, L>) -> Result<usize, EigError> {
    let (rows, cols) = *a.shape();
    if rows == cols {
        Ok(cols)
    } else {
        Err(EigError::NotSquareMatrix)
    }
}

/// Builds a complex value whose real and imaginary parts are both the real
/// part of `T::default()`. This is zero for the float/complex types this
/// backend is used with. It is used as the fill value for output tensors.
fn complex_zero<T: ComplexFloat + Default>() -> Complex<<T as ComplexFloat>::Real> {
    let part = T::default().re();
    Complex::new(part, part)
}