use super::simple::svd_faer;
use faer_traits::ComplexField;
use mdarray::{DSlice, DTensor, Dense, Layout, tensor};
use mdarray_linalg::svd::{SVD, SVDDecomp, SVDError};
use num_complex::ComplexFloat;
use crate::Faer;
/// [`SVD`] backend for [`Faer`]; every entry point delegates to [`svd_faer`].
///
/// The allocating variants (`svd`, `svd_s`) create their output tensors filled with
/// `T::default()` and sized from the input shape. The `*_overwrite` variants write
/// into caller-provided buffers instead.
impl<T> SVD<T> for Faer
where
    T: ComplexFloat
        + ComplexField
        + Default
        + std::convert::From<<T as num_complex::ComplexFloat>::Real>
        + 'static,
{
    /// Computes the full SVD of `a` and returns `s`, `u` and `vt` in an [`SVDDecomp`].
    ///
    /// For an `m × n` input, `u` is allocated as `m × m`, `vt` as `n × n`, and `s` as
    /// `min(m, n) × min(m, n)`. `svd_faer` decides where inside `s` the singular values
    /// are written (NOTE(review): confirm diagonal vs. first row).
    ///
    /// `a` is handed to `svd_faer` mutably. It may be used as scratch space, so do not
    /// rely on its contents afterwards (TODO confirm against `svd_faer`).
    ///
    /// # Errors
    ///
    /// Any [`SVDError`] reported by `svd_faer` is returned unchanged.
    fn svd<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<SVDDecomp<T>, SVDError> {
        let (m, n) = *a.shape();
        let min_mn = m.min(n);
        let mut s = tensor![[T::default(); min_mn]; min_mn];
        let mut u = tensor![[T::default(); m]; m];
        let mut vt = tensor![[T::default(); n]; n];
        // Propagate the backend's own error rather than replacing it with a fabricated
        // `BackendDidNotConverge { superdiagonals: 0 }`, so this reports exactly what
        // `svd_overwrite` would report for the same input.
        svd_faer(a, &mut s, Some(&mut u), Some(&mut vt))?;
        Ok(SVDDecomp { s, u, vt })
    }

    /// Computes only the singular values of `a`. No singular vectors are requested.
    ///
    /// The returned tensor is allocated as `min(m, n) × min(m, n)` for an `m × n` input.
    ///
    /// # Errors
    ///
    /// Any [`SVDError`] reported by `svd_faer` is returned unchanged.
    fn svd_s<L: Layout>(&self, a: &mut DSlice<T, 2, L>) -> Result<DTensor<T, 2>, SVDError> {
        let (m, n) = *a.shape();
        let min_mn = m.min(n);
        let mut s = tensor![[T::default(); min_mn]; min_mn];
        // `u`/`vt` are `None`, so their layout parameters cannot be inferred and are
        // pinned to `Dense` explicitly.
        svd_faer::<T, L, Dense, Dense, Dense>(a, &mut s, None, None)?;
        Ok(s)
    }

    /// Computes the full SVD of `a`, writing into the caller-provided `s`, `u` and `vt`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`SVDError`] `svd_faer` reports.
    fn svd_overwrite<L: Layout, Ls: Layout, Lu: Layout, Lvt: Layout>(
        &self,
        a: &mut DSlice<T, 2, L>,
        s: &mut DSlice<T, 2, Ls>,
        u: &mut DSlice<T, 2, Lu>,
        vt: &mut DSlice<T, 2, Lvt>,
    ) -> Result<(), SVDError> {
        svd_faer::<T, L, Ls, Lu, Lvt>(a, s, Some(u), Some(vt))
    }

    /// Computes only the singular values of `a`, writing them into the caller-provided `s`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`SVDError`] `svd_faer` reports.
    fn svd_overwrite_s<L: Layout, Ls: Layout>(
        &self,
        a: &mut DSlice<T, 2, L>,
        s: &mut DSlice<T, 2, Ls>,
    ) -> Result<(), SVDError> {
        // `u`/`vt` are `None`; `Dense` only satisfies their unused layout parameters.
        svd_faer::<T, L, Ls, Dense, Dense>(a, s, None, None)
    }
}