use faer_traits::ComplexField;
use mdarray::{DSlice, Layout, tensor};
use mdarray_linalg::identity;
use mdarray_linalg::solve::{Solve, SolveError, SolveResult, SolveResultType};
use num_complex::ComplexFloat;
use faer::linalg::solvers::Solve as FaerSolve;
use crate::{Faer, into_faer, into_faer_mut};
/// Dense linear-system solver backed by faer's partially pivoted LU factorization.
impl<T> Solve<T> for Faer
where
    T: ComplexFloat
        + ComplexField
        + Default
        + std::convert::From<<T as num_complex::ComplexFloat>::Real>
        + 'static,
{
    /// Solves `A X = B` and returns `X` in a freshly allocated `m × b_n` tensor.
    ///
    /// `a` is only read: the LU factorization is computed into faer-owned storage.
    /// The returned `p` is always the `m × m` identity, because faer's `solve` on the
    /// LU factorization solves the original (unpermuted) system directly.
    ///
    /// # Errors
    /// Returns `SolveError::InvalidDimensions` if `a` is not square or if `b` does
    /// not have as many rows as `a`.
    fn solve<La: Layout, Lb: Layout>(
        &self,
        a: &mut DSlice<T, 2, La>,
        b: &DSlice<T, 2, Lb>,
    ) -> SolveResultType<T> {
        let (b_m, b_n) = *b.shape();
        let m = check_system_dims(*a.shape(), b_m)?;

        let solver = into_faer_mut(a).partial_piv_lu();
        let x_faer = solver.solve(into_faer(b));

        // Copy faer's owned result into an mdarray tensor of the same shape.
        let mut x_mda = tensor![[T::default(); b_n]; m];
        let mut x_faer_mut = into_faer_mut(&mut x_mda);
        for i in 0..m {
            for j in 0..b_n {
                x_faer_mut[(i, j)] = x_faer[(i, j)];
            }
        }

        let p_mda = identity(m);
        Ok(SolveResult { x: x_mda, p: p_mda })
    }

    /// Solves `A X = B`, writing `X` into `b` and the `m × m` identity into `p`.
    ///
    /// `a` is only read, despite being taken by `&mut`. As in `solve`, `p` is the
    /// identity because faer's LU `solve` already accounts for the row pivoting.
    ///
    /// # Errors
    /// Returns `SolveError::InvalidDimensions` if `a` is not square, if `b` does not
    /// have as many rows as `a`, or if `p` is not `m × m`. In every error case,
    /// neither `b` nor `p` is modified.
    fn solve_overwrite<La: Layout, Lb: Layout, Lp: Layout>(
        &self,
        a: &mut DSlice<T, 2, La>,
        b: &mut DSlice<T, 2, Lb>,
        p: &mut DSlice<T, 2, Lp>,
    ) -> Result<(), SolveError> {
        let (b_m, b_n) = *b.shape();
        let m = check_system_dims(*a.shape(), b_m)?;

        // Validate `p` before touching `b`. An undersized `p` previously caused an
        // out-of-bounds index panic after `b` had already been overwritten.
        let (p_m, p_n) = *p.shape();
        if p_m != m || p_n != m {
            return Err(SolveError::InvalidDimensions);
        }

        let solver = into_faer(a).partial_piv_lu();
        // Solve from an owned copy of `b`, because the result is written back into `b`.
        let x_faer = solver.solve(into_faer(b).to_owned());

        let mut b_faer_mut = into_faer_mut(b);
        for i in 0..m {
            for j in 0..b_n {
                b_faer_mut[(i, j)] = x_faer[(i, j)];
            }
        }

        let mut p_faer = into_faer_mut(p);
        for i in 0..m {
            for j in 0..m {
                p_faer[(i, j)] = if i == j { T::one() } else { T::zero() };
            }
        }
        Ok(())
    }
}

/// Checks that the coefficient matrix is square and that the right-hand side
/// has one row per unknown.
///
/// `a_shape` is the `(rows, cols)` of `A` and `b_rows` is the row count of `B`.
/// Returns the system order `m` on success and `SolveError::InvalidDimensions`
/// otherwise.
fn check_system_dims(a_shape: (usize, usize), b_rows: usize) -> Result<usize, SolveError> {
    let (m, n) = a_shape;
    if m != n || b_rows != m {
        return Err(SolveError::InvalidDimensions);
    }
    Ok(m)
}