use crate::fpu_check::FpuGuard;
use crate::gemm::GemmBackendHandle;
use mdarray::{DTensor, DynRank, Shape, Slice, ViewMut};
use num_complex::Complex;
/// Interface for fitters that can write evaluation / fitting results straight
/// into a caller-provided output view.
///
/// The two-letter infix in each `*_nd_XY_to` method name matches the element
/// types in its signature: `d` is `f64`, `z` is `Complex<f64>`, input first and
/// output second (for example `dz` reads `f64` and writes `Complex<f64>`).
///
/// Every `*_to` method ships with a default body that ignores all of its
/// arguments, leaves `out` untouched and returns `false`.
/// NOTE(review): `false` presumably means "not supported by this fitter" and
/// `true` "result written to `out`"; `dim` is presumably the axis the
/// operation acts along — confirm against the implementors.
pub trait InplaceFitter {
    /// Point count of the fitter (presumably the number of sampling points).
    fn n_points(&self) -> usize;

    /// Basis size of the fitter (presumably the number of coefficients).
    fn basis_size(&self) -> usize;

    /// Evaluate: `f64` coefficients into an `f64` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn evaluate_nd_dd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, coeffs, dim, out);
        false
    }

    /// Evaluate: `f64` coefficients into a `Complex<f64>` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn evaluate_nd_dz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, coeffs, dim, out);
        false
    }

    /// Evaluate: `Complex<f64>` coefficients into an `f64` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn evaluate_nd_zd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, coeffs, dim, out);
        false
    }

    /// Evaluate: `Complex<f64>` coefficients into a `Complex<f64>` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn evaluate_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, coeffs, dim, out);
        false
    }

    /// Fit: `f64` values into an `f64` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn fit_nd_dd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, values, dim, out);
        false
    }

    /// Fit: `f64` values into a `Complex<f64>` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn fit_nd_dz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, values, dim, out);
        false
    }

    /// Fit: `Complex<f64>` values into an `f64` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn fit_nd_zd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, values, dim, out);
        false
    }

    /// Fit: `Complex<f64>` values into a `Complex<f64>` output view.
    ///
    /// Default: does nothing and returns `false`.
    fn fit_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        // Arguments are deliberately discarded by the default.
        let _ = (backend, values, dim, out);
        false
    }
}
/// Builds the axis order that moves `dim` to the front.
///
/// The result starts with `dim`, followed by every index in `0..rank` other
/// than `dim`, in ascending order. No check is made that `dim < rank`.
pub(crate) fn make_perm_to_front(rank: usize, dim: usize) -> Vec<usize> {
    let mut order = Vec::with_capacity(rank);
    order.push(dim);
    // Remaining axes keep their relative order.
    order.extend((0..rank).filter(|&axis| axis != dim));
    order
}
/// Copies the elements of `src` into `dst`, pairing them up in the order
/// produced by `dst.iter_mut()`.
///
/// # Panics
/// Panics with "Source size mismatch" if `src.len() != dst.len()`.
pub(crate) fn copy_from_contiguous<T: Copy>(
    src: &[T],
    dst: &mut mdarray::Slice<T, mdarray::DynRank, mdarray::Strided>,
) {
    assert_eq!(src.len(), dst.len(), "Source size mismatch");
    dst.iter_mut()
        .zip(src)
        .for_each(|(slot, &value)| *slot = value);
}
/// Reinterprets a mutable complex slice as a mutable `f64` view with one
/// extra trailing axis of length 2.
///
/// The returned view has the dims of `out` followed by `2` and is built over
/// the same memory (the data pointer is cast from `Complex<f64>` to `f64`),
/// so writes through the view modify `out`. The mutable borrow of `out` lasts
/// for the lifetime `'a` of the returned view.
#[allow(dead_code)]
pub(crate) fn complex_slice_mut_as_real<'a>(
    out: &'a mut Slice<Complex<f64>, DynRank>,
) -> mdarray::ViewMut<'a, f64, DynRank, mdarray::Dense> {
    // Dims of the real view: the complex dims plus a trailing axis of length 2.
    let mut real_dims: Vec<usize> = Vec::with_capacity(out.rank() + 1);
    out.shape()
        .with_dims(|dims| real_dims.extend(dims.iter().copied()));
    real_dims.push(2);

    // SAFETY: relies on `Complex<f64>` being layout-compatible with
    // `[f64; 2]`, so `out.len()` complex elements cover exactly the
    // `2 * out.len()` reals described by `real_dims`. It also assumes `out`
    // is densely stored (NOTE(review): true if `Slice`'s default layout
    // parameter is `Dense` — confirm). The exclusive borrow of `out` is held
    // for `'a`, so the view is the only access path while it lives.
    unsafe {
        let shape: DynRank = Shape::from_dims(&real_dims[..]);
        let mapping = mdarray::DenseMapping::new(shape);
        mdarray::ViewMut::new_unchecked(out.as_mut_ptr() as *mut f64, mapping)
    }
}
/// Factors of a real SVD, stored with both orthogonal factors transposed
/// relative to the `u` / `vt` handed to [`RealSVD::new`].
pub(crate) struct RealSVD {
    /// Transpose of the `u` passed to `new`.
    pub ut: DTensor<f64, 2>,
    /// Singular values, as passed to `new`.
    pub s: Vec<f64>,
    /// Transpose of the `vt` passed to `new`.
    pub v: DTensor<f64, 2>,
}

impl RealSVD {
    /// Builds the stored form from `u`, `s` and `vt`.
    ///
    /// `ut` and `v` are materialised as owned transposes of `u` and `vt`.
    ///
    /// # Panics
    /// Panics if the column count of `u` or the row count of `vt` differs
    /// from `s.len()`.
    pub fn new(u: DTensor<f64, 2>, s: Vec<f64>, vt: DTensor<f64, 2>) -> Self {
        let n_sv = s.len();
        let u_cols = u.shape().1;
        let vt_rows = vt.shape().0;
        assert_eq!(
            u_cols, n_sv,
            "u.cols()={} must equal s.len()={}",
            u_cols, n_sv
        );
        assert_eq!(
            vt_rows, n_sv,
            "vt.rows()={} must equal s.len()={}",
            vt_rows, n_sv
        );

        let ut = u.transpose().to_tensor();
        let v = vt.transpose().to_tensor();

        // Re-check on the transposed factor that is actually stored.
        let v_cols = v.shape().1;
        assert_eq!(
            v_cols, n_sv,
            "v.cols()={} must equal s.len()={}",
            v_cols, n_sv
        );

        Self { ut, s, v }
    }
}
/// Factors of a complex SVD, stored in transposed form relative to the
/// `u` / `vt` handed to [`ComplexSVD::new`].
pub(crate) struct ComplexSVD {
    /// Conjugate transpose of the `u` passed to `new`.
    pub ut: DTensor<Complex<f64>, 2>,
    /// Singular values, as passed to `new`.
    pub s: Vec<f64>,
    /// Plain (non-conjugated) transpose of the `vt` passed to `new`.
    pub v: DTensor<Complex<f64>, 2>,
}

impl ComplexSVD {
    /// Builds the stored form from `u`, `s` and `vt`.
    ///
    /// `ut[i, j]` is `conj(u[j, i])`; `v` is `vt` transposed without
    /// conjugation.
    /// NOTE(review): whether the un-conjugated transpose is the intended `V`
    /// depends on the convention of the `vt` the caller supplies — confirm.
    ///
    /// # Panics
    /// Panics if the column count of `u` or the row count of `vt` differs
    /// from `s.len()`.
    pub fn new(u: DTensor<Complex<f64>, 2>, s: Vec<f64>, vt: DTensor<Complex<f64>, 2>) -> Self {
        let n_sv = s.len();
        let (u_rows, u_cols) = *u.shape();
        let vt_rows = vt.shape().0;
        assert_eq!(
            u_cols, n_sv,
            "u.cols()={} must equal s.len()={}",
            u_cols, n_sv
        );
        assert_eq!(
            vt_rows, n_sv,
            "vt.rows()={} must equal s.len()={}",
            vt_rows, n_sv
        );

        // Conjugate transpose of `u`: output (row, col) reads input (col, row).
        let ut = DTensor::<Complex<f64>, 2>::from_fn([u_cols, u_rows], |idx| {
            let (row, col) = (idx[0], idx[1]);
            u[[col, row]].conj()
        });
        let v = vt.transpose().to_tensor();

        // Re-check on the transposed factor that is actually stored.
        let v_cols = v.shape().1;
        assert_eq!(
            v_cols, n_sv,
            "v.cols()={} must equal s.len()={}",
            v_cols, n_sv
        );

        Self { ut, s, v }
    }
}
/// Runs the `Faer` SVD backend on a copy of `matrix` and packs the result
/// into a [`RealSVD`].
///
/// Singular values are read from row 0 of the backend's `s` output; `u` is
/// trimmed to its first `k` columns and `vt` to its first `k` rows, where `k`
/// is the smaller of the two dims of `s`.
///
/// # Panics
/// Panics with "SVD computation failed" if the backend returns an error, and
/// on any shape mismatch detected by `RealSVD::new`.
pub(crate) fn compute_real_svd(matrix: &DTensor<f64, 2>) -> RealSVD {
    use mdarray_linalg::prelude::SVD;
    use mdarray_linalg::svd::SVDDecomp;
    use mdarray_linalg_faer::Faer;

    // Guard stays alive until the function returns.
    // NOTE(review): presumably protects FPU state around the decomposition — confirm.
    let _guard = FpuGuard::new_protect_computation();

    // The backend takes its input mutably, so hand it a private copy.
    let mut work = matrix.clone();
    let SVDDecomp { u, s, vt } = Faer.svd(&mut *work).expect("SVD computation failed");

    let s_rows = s.shape().0;
    let s_cols = s.shape().1;
    let n_sv = s_rows.min(s_cols);
    let singular: Vec<f64> = (0..n_sv).map(|k| s[[0, k]]).collect();

    let u = u.view(.., ..n_sv).to_tensor();
    let vt = vt.view(..n_sv, ..).to_tensor();
    RealSVD::new(u, singular, vt)
}
/// Runs the `Faer` SVD backend on a copy of the complex `matrix` and packs
/// the result into a [`ComplexSVD`].
///
/// Singular values are taken as the real parts of row 0 of the backend's `s`
/// output; `u` is trimmed to its first `k` columns and `vt` to its first `k`
/// rows, where `k` is the smaller of the two dims of `s`.
///
/// # Panics
/// Panics with "Complex SVD computation failed" if the backend returns an
/// error, and on any shape mismatch detected by `ComplexSVD::new`.
pub(crate) fn compute_complex_svd(matrix: &DTensor<Complex<f64>, 2>) -> ComplexSVD {
    use mdarray_linalg::prelude::SVD;
    use mdarray_linalg::svd::SVDDecomp;
    use mdarray_linalg_faer::Faer;

    // Guard stays alive until the function returns.
    // NOTE(review): presumably protects FPU state around the decomposition — confirm.
    let _guard = FpuGuard::new_protect_computation();

    // The backend takes its input mutably, so hand it a private copy.
    let mut work = matrix.clone();
    let decomposition = Faer.svd(&mut *work);
    let SVDDecomp { u, s, vt } = decomposition.expect("Complex SVD computation failed");

    let s_rows = s.shape().0;
    let s_cols = s.shape().1;
    let n_sv = s_rows.min(s_cols);
    // `s` holds complex entries here; only the real part is kept.
    let singular: Vec<f64> = (0..n_sv).map(|k| s[[0, k]].re).collect();

    let u = u.view(.., ..n_sv).to_tensor();
    let vt = vt.view(..n_sv, ..).to_tensor();
    ComplexSVD::new(u, singular, vt)
}
/// Zips a real-part tensor and an imaginary-part tensor into one complex
/// tensor of the same shape: `out[i, j] = re[i, j] + i * im[i, j]`.
///
/// # Panics
/// Panics if `re` and `im` do not have the same shape. Without this check a
/// larger `im` would be silently truncated to the shape of `re`, and a
/// smaller one would fail with an out-of-bounds index inside the fill closure.
pub(crate) fn combine_complex(
    re: &DTensor<f64, 2>,
    im: &DTensor<f64, 2>,
) -> DTensor<Complex<f64>, 2> {
    let (n_points, extra_size) = *re.shape();
    let (im_rows, im_cols) = *im.shape();
    assert_eq!(
        (n_points, extra_size),
        (im_rows, im_cols),
        "re and im must have the same shape"
    );
    DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
        Complex::new(re[idx], im[idx])
    })
}
/// Returns a new real tensor with the same shape as `coeffs_2d`, holding the
/// real part of every element; imaginary parts are dropped.
pub(crate) fn extract_real_parts_coeffs(coeffs_2d: &DTensor<Complex<f64>, 2>) -> DTensor<f64, 2> {
    let (rows, cols) = *coeffs_2d.shape();
    DTensor::<f64, 2>::from_fn([rows, cols], |idx| {
        let value = coeffs_2d[idx];
        value.re
    })
}