use crate::gemm::GemmBackendHandle;
use mdarray::{DTensor, DView, DynRank, Shape, Slice, ViewMut};
use num_complex::Complex;
use std::sync::OnceLock;
use super::common::{InplaceFitter, RealSVD, compute_real_svd};
/// Writes `values_re[k] + i * values_im[k]` into `out[k]` for every element, walking all
/// three arrays in the same (row-major memory) order.
///
/// # Panics
/// Panics if `values_re`, `values_im` and `out` do not all hold the same number of elements.
/// This is checked in release builds too, so a mismatch can never leave part of `out` unwritten
/// or touch memory outside it.
#[inline]
fn interleave_to_complex(
    values_re: &DTensor<f64, 2>,
    values_im: &DTensor<f64, 2>,
    out: &mut mdarray::DViewMut<'_, Complex<f64>, 2>,
) {
    let total = values_re.len();
    assert_eq!(
        values_im.len(),
        total,
        "real and imaginary parts must have the same number of elements"
    );
    assert_eq!(
        out.len(),
        total,
        "output must have as many elements as the real/imaginary inputs"
    );
    // Safe element-wise zip: no bounds checks per element and no raw-pointer arithmetic.
    for (dst, (&re, &im)) in out
        .iter_mut()
        .zip(values_re.iter().zip(values_im.iter()))
    {
        *dst = Complex::new(re, im);
    }
}
/// Splits each complex row `i` of `values` into two real rows of `out`.
///
/// Row `2 * i` of `out` receives the real parts and row `2 * i + 1` receives the imaginary parts.
/// This matches the row layout of `ComplexToRealFitter::matrix_real`.
///
/// # Panics
/// Panics if `values` has shape `(n_points, extra_size)` and `out` does not have shape
/// `(2 * n_points, extra_size)`.
#[inline]
fn flatten_complex_to_real_rows(values: &DView<'_, Complex<f64>, 2>, out: &mut DTensor<f64, 2>) {
    let (n_points, extra_size) = *values.shape();
    // Checked unconditionally: the raw-pointer writes below are only in bounds if this holds.
    assert_eq!(
        *out.shape(),
        (2 * n_points, extra_size),
        "out must have shape (2 * n_points, extra_size)"
    );
    let src_ptr = values.as_ptr();
    let dst_ptr = out.as_mut_ptr();
    // SAFETY: `DView<_, 2>` and `DTensor<_, 2>` use the dense (contiguous row-major) layout.
    // `values` therefore holds `n_points * extra_size` elements and `out` holds
    // `2 * n_points * extra_size`, per the assert above. Every offset computed below,
    // `i * extra_size + j` for the source and `(2i | 2i + 1) * extra_size + j` for the
    // destination, is strictly inside those ranges.
    unsafe {
        for i in 0..n_points {
            let src_row = src_ptr.add(i * extra_size);
            let re_row = dst_ptr.add(2 * i * extra_size);
            let im_row = re_row.add(extra_size);
            for j in 0..extra_size {
                let val = *src_row.add(j);
                *re_row.add(j) = val.re;
                *im_row.add(j) = val.im;
            }
        }
    }
}
/// Splits each complex column `j` of `values` into two adjacent real columns of `out`.
///
/// Column `2 * j` of `out` receives the real parts and column `2 * j + 1` the imaginary parts.
/// This is the transposed counterpart of `flatten_complex_to_real_rows`.
///
/// # Panics
/// Panics if `values` has shape `(extra_size, n_points)` and `out` does not have shape
/// `(extra_size, 2 * n_points)`.
#[inline]
fn flatten_complex_to_real_cols(values: &DView<'_, Complex<f64>, 2>, out: &mut DTensor<f64, 2>) {
    let (extra_size, n_points) = *values.shape();
    // Checked unconditionally: the raw-pointer writes below are only in bounds if this holds.
    assert_eq!(
        *out.shape(),
        (extra_size, 2 * n_points),
        "out must have shape (extra_size, 2 * n_points)"
    );
    let src_ptr = values.as_ptr();
    let dst_ptr = out.as_mut_ptr();
    // SAFETY: `DView<_, 2>` and `DTensor<_, 2>` use the dense (contiguous row-major) layout.
    // `values` therefore holds `extra_size * n_points` elements and `out` holds
    // `extra_size * 2 * n_points`, per the assert above. Every offset computed below,
    // `i * n_points + j` for the source and `i * 2 * n_points + (2j | 2j + 1)` for the
    // destination, is strictly inside those ranges.
    unsafe {
        for i in 0..extra_size {
            let src_row = src_ptr.add(i * n_points);
            let dst_row = dst_ptr.add(i * 2 * n_points);
            for j in 0..n_points {
                let val = *src_row.add(j);
                *dst_row.add(2 * j) = val.re;
                *dst_row.add(2 * j + 1) = val.im;
            }
        }
    }
}
/// Least-squares fitter for a complex sampling matrix `A` (shape `(n_points, basis_size)`)
/// whose basis coefficients are constrained to be real.
///
/// Fitting works on the equivalent real system obtained by stacking the real and imaginary
/// parts of every row of `A`. Evaluation multiplies real coefficients by `Re(A)` and `Im(A)`
/// separately.
pub(crate) struct ComplexToRealFitter {
    /// Stacked real matrix, shape `(2 * n_points, basis_size)`: row `2i` is `Re(A[i, :])`,
    /// row `2i + 1` is `Im(A[i, :])`.
    matrix_real: DTensor<f64, 2>,
    /// `Re(A)`, shape `(n_points, basis_size)`.
    matrix_re: DTensor<f64, 2>,
    /// `Im(A)`, shape `(n_points, basis_size)`.
    matrix_im: DTensor<f64, 2>,
    /// Transpose of `matrix_re`, shape `(basis_size, n_points)`.
    matrix_re_t: DTensor<f64, 2>,
    /// Transpose of `matrix_im`, shape `(basis_size, n_points)`.
    matrix_im_t: DTensor<f64, 2>,
    /// The original complex sampling matrix.
    pub matrix: DTensor<Complex<f64>, 2>,
    /// SVD of `matrix_real`, computed lazily on first use.
    svd: OnceLock<RealSVDExtended>,
    /// Number of complex sampling points (rows of `matrix`).
    n_points: usize,
}
/// A real SVD bundled with explicit transposes of its factor matrices.
///
/// `svd` provides `ut` and `v`; `u` and `vt` are their transposes, materialized once so that
/// fits along the last axis can multiply from the right without transposing on every call.
struct RealSVDExtended {
    svd: RealSVD,
    /// Transpose of `svd.ut`, shape `(rows, min_dim)`.
    u: DTensor<f64, 2>,
    /// Transpose of `svd.v`, shape `(min_dim, cols)`.
    vt: DTensor<f64, 2>,
}
impl RealSVDExtended {
    /// Computes the SVD of `matrix` and caches the transposed factors alongside it.
    fn from_matrix(matrix: &DTensor<f64, 2>) -> Self {
        let svd = compute_real_svd(matrix);
        let n_singular = svd.s.len();
        let (n_rows, n_cols) = *matrix.shape();
        let u = {
            let ut = &svd.ut;
            DTensor::<f64, 2>::from_fn([n_rows, n_singular], |idx| ut[[idx[1], idx[0]]])
        };
        let vt = {
            let v = &svd.v;
            DTensor::<f64, 2>::from_fn([n_singular, n_cols], |idx| v[[idx[1], idx[0]]])
        };
        Self { svd, u, vt }
    }
}
impl ComplexToRealFitter {
    /// Creates a fitter for the complex sampling matrix `matrix_complex` of shape
    /// `(n_points, basis_size)`.
    ///
    /// The following are precomputed:
    /// - the stacked real matrix used for fitting (row `2i` = `Re(A[i, :])`,
    ///   row `2i + 1` = `Im(A[i, :])`);
    /// - the separate real/imaginary parts used for evaluation;
    /// - their transposes.
    ///
    /// The SVD needed for fitting is computed lazily on first use.
    pub fn new(matrix_complex: &DTensor<Complex<f64>, 2>) -> Self {
        let (n_points, basis_size) = *matrix_complex.shape();
        let matrix_real = DTensor::<f64, 2>::from_fn([2 * n_points, basis_size], |idx| {
            let val = matrix_complex[[idx[0] / 2, idx[1]]];
            if idx[0] % 2 == 0 {
                val.re
            } else {
                val.im
            }
        });
        let matrix_re =
            DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| matrix_complex[idx].re);
        let matrix_im =
            DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| matrix_complex[idx].im);
        let matrix_re_t =
            DTensor::<f64, 2>::from_fn([basis_size, n_points], |idx| matrix_re[[idx[1], idx[0]]]);
        let matrix_im_t =
            DTensor::<f64, 2>::from_fn([basis_size, n_points], |idx| matrix_im[[idx[1], idx[0]]]);
        Self {
            matrix_real,
            matrix_re,
            matrix_im,
            matrix_re_t,
            matrix_im_t,
            matrix: matrix_complex.clone(),
            svd: OnceLock::new(),
            n_points,
        }
    }

    /// Number of complex sampling points (rows of the sampling matrix).
    pub fn n_points(&self) -> usize {
        self.n_points
    }

    /// Number of (real) basis coefficients (columns of the sampling matrix).
    pub fn basis_size(&self) -> usize {
        self.matrix_real.shape().1
    }

    /// Evaluates real coefficients at every sampling point.
    ///
    /// Returns `n_points` complex values.
    ///
    /// # Panics
    /// Panics if `coeffs.len() != basis_size`.
    pub fn evaluate(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &[f64],
    ) -> Vec<Complex<f64>> {
        assert_eq!(
            coeffs.len(),
            self.basis_size(),
            "coeffs.len()={} must equal basis_size={}",
            coeffs.len(),
            self.basis_size()
        );
        let basis_size = coeffs.len();
        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, 1], |idx| coeffs[idx[0]]);
        let coeffs_view = coeffs_2d.view(.., ..);
        let values_2d = self.evaluate_2d(backend, &coeffs_view);
        let n_points = self.n_points();
        (0..n_points).map(|i| values_2d[[i, 0]]).collect()
    }

    /// Fits real coefficients to `values` in the least-squares sense.
    ///
    /// Returns `basis_size` coefficients.
    ///
    /// # Panics
    /// Panics if `values.len() != n_points`.
    pub fn fit(&self, backend: Option<&GemmBackendHandle>, values: &[Complex<f64>]) -> Vec<f64> {
        assert_eq!(
            values.len(),
            self.n_points(),
            "values.len()={} must equal n_points={}",
            values.len(),
            self.n_points()
        );
        let n = values.len();
        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n, 1], |idx| values[idx[0]]);
        let values_view = values_2d.view(.., ..);
        let coeffs_2d = self.fit_2d(backend, &values_view);
        let basis_size = self.basis_size();
        (0..basis_size).map(|i| coeffs_2d[[i, 0]]).collect()
    }

    /// Evaluates a `(basis_size, extra_size)` coefficient matrix.
    ///
    /// Returns a freshly allocated `(n_points, extra_size)` complex matrix.
    ///
    /// # Panics
    /// Panics if `coeffs_2d.shape().0 != basis_size`.
    pub fn evaluate_2d(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs_2d: &DView<'_, f64, 2>,
    ) -> DTensor<Complex<f64>, 2> {
        let (basis_size, extra_size) = *coeffs_2d.shape();
        assert_eq!(
            basis_size,
            self.basis_size(),
            "coeffs_2d.shape().0={} must equal basis_size={}",
            basis_size,
            self.basis_size()
        );
        let n_points = self.n_points();
        let mut out = DTensor::<Complex<f64>, 2>::zeros([n_points, extra_size]);
        let mut out_view = out.view_mut(.., ..);
        self.evaluate_2d_to(backend, coeffs_2d, &mut out_view);
        out
    }

    /// Evaluates a `(basis_size, extra_size)` coefficient matrix into `out`.
    ///
    /// `out` must have shape `(n_points, extra_size)`.
    ///
    /// # Panics
    /// Panics on any shape mismatch.
    pub fn evaluate_2d_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs_2d: &DView<'_, f64, 2>,
        out: &mut mdarray::DViewMut<'_, Complex<f64>, 2>,
    ) {
        let (basis_size, extra_size) = *coeffs_2d.shape();
        let (out_rows, out_cols) = *out.shape();
        assert_eq!(
            basis_size,
            self.basis_size(),
            "coeffs_2d.shape().0={} must equal basis_size={}",
            basis_size,
            self.basis_size()
        );
        assert_eq!(
            out_rows,
            self.n_points(),
            "out.shape().0={} must equal n_points={}",
            out_rows,
            self.n_points()
        );
        assert_eq!(
            out_cols, extra_size,
            "out.shape().1={} must equal extra_size={}",
            out_cols, extra_size
        );
        // Same kernel as the axis-0 case of the dim-aware evaluator; shared to avoid drift.
        self.evaluate_2d_to_dim(backend, coeffs_2d, out, 0);
    }

    /// Evaluates real coefficients along one axis of a 2D array.
    ///
    /// - `dim == 0`: `coeffs_2d` is `(basis_size, extra_size)` and `out` is
    ///   `(n_points, extra_size)`. Computed as `Re(A)·C + i·Im(A)·C`.
    /// - any other `dim`: `coeffs_2d` is `(extra_size, basis_size)` and `out` is
    ///   `(extra_size, n_points)`. Computed as `C·Re(A)^T + i·C·Im(A)^T`.
    fn evaluate_2d_to_dim(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs_2d: &DView<'_, f64, 2>,
        out: &mut mdarray::DViewMut<'_, Complex<f64>, 2>,
        dim: usize,
    ) {
        use crate::gemm::matmul_par_view;
        let (coeffs_rows, coeffs_cols) = *coeffs_2d.shape();
        let (out_rows, out_cols) = *out.shape();
        let n_points = self.n_points();
        if dim == 0 {
            let basis_size = coeffs_rows;
            let extra_size = coeffs_cols;
            assert_eq!(basis_size, self.basis_size());
            assert_eq!(out_rows, n_points);
            assert_eq!(out_cols, extra_size);
            let matrix_re_view = self.matrix_re.view(.., ..);
            let matrix_im_view = self.matrix_im.view(.., ..);
            let values_re = matmul_par_view(&matrix_re_view, coeffs_2d, backend);
            let values_im = matmul_par_view(&matrix_im_view, coeffs_2d, backend);
            interleave_to_complex(&values_re, &values_im, out);
        } else {
            let extra_size = coeffs_rows;
            let basis_size = coeffs_cols;
            assert_eq!(basis_size, self.basis_size());
            assert_eq!(out_rows, extra_size);
            assert_eq!(out_cols, n_points);
            let matrix_re_t_view = self.matrix_re_t.view(.., ..);
            let matrix_im_t_view = self.matrix_im_t.view(.., ..);
            let values_re = matmul_par_view(coeffs_2d, &matrix_re_t_view, backend);
            let values_im = matmul_par_view(coeffs_2d, &matrix_im_t_view, backend);
            interleave_to_complex(&values_re, &values_im, out);
        }
    }

    /// Evaluates real coefficients along axis `dim` of an N-d array, writing complex values
    /// into `out`.
    ///
    /// `coeffs` must have extent `basis_size` along `dim`. `out` must have extent `n_points`
    /// along `dim` and match `coeffs` on every other axis. Always returns `true`.
    ///
    /// # Panics
    /// Panics if `dim >= rank`, if the ranks differ, or on any extent mismatch.
    pub fn evaluate_nd_dz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        let rank = coeffs.rank();
        let basis_size = self.basis_size();
        let n_points = self.n_points();
        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
        assert_eq!(out.rank(), rank);
        assert_eq!(coeffs.shape().dim(dim), basis_size);
        assert_eq!(out.shape().dim(dim), n_points);
        // Every other axis must agree between `coeffs` and `out`. The unchecked views below
        // assume `out` holds exactly `n_points * extra_size` elements; without this check, a
        // smaller `out` would be written out of bounds.
        let mut extra_size = 1usize;
        for axis in (0..rank).filter(|&axis| axis != dim) {
            let extent = coeffs.shape().dim(axis);
            let out_extent = out.shape().dim(axis);
            assert_eq!(
                out_extent, extent,
                "out.shape().dim({})={} must equal coeffs.shape().dim({})={}",
                axis, out_extent, axis, extent
            );
            extra_size *= extent;
        }
        if dim == 0 {
            // SAFETY: `coeffs` is dense with `basis_size * extra_size` elements.
            let coeffs_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((basis_size, extra_size));
                mdarray::DView::<'_, f64, 2>::new_unchecked(coeffs.as_ptr(), mapping)
            };
            // SAFETY: `out` is dense with `n_points * extra_size` elements (checked above).
            let mut out_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((n_points, extra_size));
                mdarray::DViewMut::<'_, Complex<f64>, 2>::new_unchecked(out.as_mut_ptr(), mapping)
            };
            self.evaluate_2d_to_dim(backend, &coeffs_2d, &mut out_2d, 0);
        } else if dim == rank - 1 {
            // SAFETY: `coeffs` is dense with `extra_size * basis_size` elements.
            let coeffs_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((extra_size, basis_size));
                mdarray::DView::<'_, f64, 2>::new_unchecked(coeffs.as_ptr(), mapping)
            };
            // SAFETY: `out` is dense with `extra_size * n_points` elements (checked above).
            let mut out_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((extra_size, n_points));
                mdarray::DViewMut::<'_, Complex<f64>, 2>::new_unchecked(out.as_mut_ptr(), mapping)
            };
            self.evaluate_2d_to_dim(backend, &coeffs_2d, &mut out_2d, 1);
        } else {
            self.evaluate_nd_dz_to_batched(backend, coeffs, dim, out);
        }
        true
    }

    /// Handles an interior `dim` by treating the array as
    /// `(batch, basis_size, extra)` / `(batch, n_points, extra)`.
    ///
    /// `batch` is the product of the axes before `dim` and `extra` the product of the axes
    /// after it. Each batch slab is evaluated with the axis-0 kernel. The caller must have
    /// validated that `out` matches `coeffs` on every axis except `dim`.
    fn evaluate_nd_dz_to_batched(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) {
        let rank = coeffs.rank();
        let basis_size = self.basis_size();
        let n_points = self.n_points();
        let mut batch_size = 1usize;
        let mut extra_size = 1usize;
        coeffs.shape().with_dims(|dims| {
            for i in 0..dim {
                batch_size *= dims[i];
            }
            for i in (dim + 1)..rank {
                extra_size *= dims[i];
            }
        });
        let coeffs_batch_stride = basis_size * extra_size;
        let out_batch_stride = n_points * extra_size;
        let coeffs_ptr = coeffs.as_ptr();
        let out_ptr = out.as_mut_ptr();
        for b in 0..batch_size {
            // SAFETY: `coeffs` is dense with `batch_size * basis_size * extra_size` elements,
            // so slab `b` lies entirely inside it.
            let coeffs_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((basis_size, extra_size));
                mdarray::DView::<'_, f64, 2>::new_unchecked(
                    coeffs_ptr.add(b * coeffs_batch_stride),
                    mapping,
                )
            };
            // SAFETY: `out` is dense with `batch_size * n_points * extra_size` elements (shapes
            // validated by the caller), so slab `b` lies entirely inside it.
            let mut out_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((n_points, extra_size));
                mdarray::DViewMut::<'_, Complex<f64>, 2>::new_unchecked(
                    out_ptr.add(b * out_batch_stride),
                    mapping,
                )
            };
            self.evaluate_2d_to_dim(backend, &coeffs_2d, &mut out_2d, 0);
        }
    }

    /// Fits an `(n_points, extra_size)` complex value matrix.
    ///
    /// Returns a freshly allocated `(basis_size, extra_size)` real coefficient matrix.
    ///
    /// # Panics
    /// Panics if `values_2d.shape().0 != n_points`.
    pub fn fit_2d(
        &self,
        backend: Option<&GemmBackendHandle>,
        values_2d: &DView<'_, Complex<f64>, 2>,
    ) -> DTensor<f64, 2> {
        let (n_points, extra_size) = *values_2d.shape();
        assert_eq!(
            n_points,
            self.n_points(),
            "values_2d.shape().0={} must equal n_points={}",
            n_points,
            self.n_points()
        );
        let basis_size = self.basis_size();
        let mut out = DTensor::<f64, 2>::zeros([basis_size, extra_size]);
        let mut out_view = out.view_mut(.., ..);
        self.fit_2d_to(backend, values_2d, &mut out_view);
        out
    }

    /// Fits an `(n_points, extra_size)` complex value matrix into `out`.
    ///
    /// `out` must have shape `(basis_size, extra_size)`.
    ///
    /// # Panics
    /// Panics on any shape mismatch.
    pub fn fit_2d_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values_2d: &DView<'_, Complex<f64>, 2>,
        out: &mut mdarray::DViewMut<'_, f64, 2>,
    ) {
        let (n_points, extra_size) = *values_2d.shape();
        let (out_rows, out_cols) = *out.shape();
        assert_eq!(
            n_points,
            self.n_points(),
            "values_2d.shape().0={} must equal n_points={}",
            n_points,
            self.n_points()
        );
        assert_eq!(
            out_rows,
            self.basis_size(),
            "out.shape().0={} must equal basis_size={}",
            out_rows,
            self.basis_size()
        );
        assert_eq!(
            out_cols, extra_size,
            "out.shape().1={} must equal extra_size={}",
            out_cols, extra_size
        );
        // Same kernel as the axis-0 case of the dim-aware fitter; shared to avoid drift.
        self.fit_2d_to_dim(backend, values_2d, out, 0);
    }

    /// Returns the SVD of `matrix_real`, computing it on the first call.
    ///
    /// Prints a warning to stderr when `2 * n_points < basis_size`, i.e. when the stacked
    /// real system has fewer rows than unknowns.
    fn get_svd(&self) -> &RealSVDExtended {
        self.svd.get_or_init(|| {
            let n_points = self.n_points();
            let basis_size = self.basis_size();
            let effective_points = 2 * n_points;
            if effective_points < basis_size {
                eprintln!(
                    "Warning: Effective number of sampling points ({} × 2 = {}) is less than basis size ({}). \
                     Fitting may be ill-conditioned.",
                    n_points, effective_points, basis_size
                );
            }
            RealSVDExtended::from_matrix(&self.matrix_real)
        })
    }

    /// Fits real coefficients along one axis of a 2D complex array using the SVD of the
    /// stacked real matrix (`x = V S^-1 U^T y`).
    ///
    /// - `dim == 0`: `values_2d` is `(n_points, extra_size)` and `out` is
    ///   `(basis_size, extra_size)`.
    /// - any other `dim`: `values_2d` is `(extra_size, n_points)` and `out` is
    ///   `(extra_size, basis_size)`; computed as the transpose, `y^T U S^-1 V^T`.
    fn fit_2d_to_dim(
        &self,
        backend: Option<&GemmBackendHandle>,
        values_2d: &DView<'_, Complex<f64>, 2>,
        out: &mut mdarray::DViewMut<'_, f64, 2>,
        dim: usize,
    ) {
        use crate::gemm::{matmul_par_to_viewmut, matmul_par_view};
        let (values_rows, values_cols) = *values_2d.shape();
        let (out_rows, out_cols) = *out.shape();
        let svd_ext = self.get_svd();
        let svd = &svd_ext.svd;
        let min_dim = svd.s.len();
        let n_points = self.n_points();
        let basis_size = self.basis_size();
        if dim == 0 {
            let extra_size = values_cols;
            assert_eq!(values_rows, n_points);
            assert_eq!(out_rows, basis_size);
            assert_eq!(out_cols, extra_size);
            let mut values_flat = DTensor::<f64, 2>::zeros([2 * n_points, extra_size]);
            flatten_complex_to_real_rows(values_2d, &mut values_flat);
            let ut_view = svd.ut.view(.., ..);
            let values_flat_view = values_flat.view(.., ..);
            let mut ut_values = matmul_par_view(&ut_view, &values_flat_view, backend);
            // Apply S^-1 row by row.
            for i in 0..min_dim {
                for j in 0..extra_size {
                    ut_values[[i, j]] /= svd.s[i];
                }
            }
            let v_view = svd.v.view(.., ..);
            let ut_values_view = ut_values.view(.., ..);
            matmul_par_to_viewmut(&v_view, &ut_values_view, out, backend);
        } else {
            let extra_size = values_rows;
            assert_eq!(values_cols, n_points);
            assert_eq!(out_rows, extra_size);
            assert_eq!(out_cols, basis_size);
            let mut values_flat = DTensor::<f64, 2>::zeros([extra_size, 2 * n_points]);
            flatten_complex_to_real_cols(values_2d, &mut values_flat);
            let u_view = svd_ext.u.view(.., ..);
            let values_flat_view = values_flat.view(.., ..);
            let mut values_u = matmul_par_view(&values_flat_view, &u_view, backend);
            // Apply S^-1 column by column.
            for i in 0..extra_size {
                for j in 0..min_dim {
                    values_u[[i, j]] /= svd.s[j];
                }
            }
            let vt_view = svd_ext.vt.view(.., ..);
            let values_u_view = values_u.view(.., ..);
            matmul_par_to_viewmut(&values_u_view, &vt_view, out, backend);
        }
    }

    /// Fits real coefficients along axis `dim` of an N-d complex array, writing them into `out`.
    ///
    /// `values` must have extent `n_points` along `dim`. `out` must have extent `basis_size`
    /// along `dim` and match `values` on every other axis. Always returns `true`.
    ///
    /// # Panics
    /// Panics if `dim >= rank`, if the ranks differ, or on any extent mismatch.
    pub fn fit_nd_zd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        let rank = values.rank();
        let basis_size = self.basis_size();
        let n_points = self.n_points();
        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
        assert_eq!(out.rank(), rank);
        assert_eq!(values.shape().dim(dim), n_points);
        assert_eq!(out.shape().dim(dim), basis_size);
        // Every other axis must agree between `values` and `out`. The unchecked views below
        // assume `out` holds exactly `basis_size * extra_size` elements.
        let mut extra_size = 1usize;
        for axis in (0..rank).filter(|&axis| axis != dim) {
            let extent = values.shape().dim(axis);
            let out_extent = out.shape().dim(axis);
            assert_eq!(
                out_extent, extent,
                "out.shape().dim({})={} must equal values.shape().dim({})={}",
                axis, out_extent, axis, extent
            );
            extra_size *= extent;
        }
        if dim == 0 {
            // SAFETY: `values` is dense with `n_points * extra_size` elements.
            let values_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((n_points, extra_size));
                mdarray::DView::<'_, Complex<f64>, 2>::new_unchecked(values.as_ptr(), mapping)
            };
            // SAFETY: `out` is dense with `basis_size * extra_size` elements (checked above).
            let mut out_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((basis_size, extra_size));
                mdarray::DViewMut::<'_, f64, 2>::new_unchecked(out.as_mut_ptr(), mapping)
            };
            self.fit_2d_to_dim(backend, &values_2d, &mut out_2d, 0);
        } else if dim == rank - 1 {
            // SAFETY: `values` is dense with `extra_size * n_points` elements.
            let values_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((extra_size, n_points));
                mdarray::DView::<'_, Complex<f64>, 2>::new_unchecked(values.as_ptr(), mapping)
            };
            // SAFETY: `out` is dense with `extra_size * basis_size` elements (checked above).
            let mut out_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((extra_size, basis_size));
                mdarray::DViewMut::<'_, f64, 2>::new_unchecked(out.as_mut_ptr(), mapping)
            };
            self.fit_2d_to_dim(backend, &values_2d, &mut out_2d, 1);
        } else {
            self.fit_nd_zd_to_batched(backend, values, dim, out);
        }
        true
    }

    /// Handles an interior `dim` by treating the arrays as
    /// `(batch, n_points, extra)` / `(batch, basis_size, extra)`.
    ///
    /// `batch` is the product of the axes before `dim` and `extra` the product of the axes
    /// after it. Each batch slab is fitted with the axis-0 kernel. The caller must have
    /// validated that `out` matches `values` on every axis except `dim`.
    fn fit_nd_zd_to_batched(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) {
        let rank = values.rank();
        let n_points = self.n_points();
        let basis_size = self.basis_size();
        let mut batch_size = 1usize;
        let mut extra_size = 1usize;
        values.shape().with_dims(|dims| {
            for i in 0..dim {
                batch_size *= dims[i];
            }
            for i in (dim + 1)..rank {
                extra_size *= dims[i];
            }
        });
        let values_batch_stride = n_points * extra_size;
        let out_batch_stride = basis_size * extra_size;
        let values_ptr = values.as_ptr();
        let out_ptr = out.as_mut_ptr();
        for b in 0..batch_size {
            // SAFETY: `values` is dense with `batch_size * n_points * extra_size` elements,
            // so slab `b` lies entirely inside it.
            let values_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((n_points, extra_size));
                mdarray::DView::<'_, Complex<f64>, 2>::new_unchecked(
                    values_ptr.add(b * values_batch_stride),
                    mapping,
                )
            };
            // SAFETY: `out` is dense with `batch_size * basis_size * extra_size` elements
            // (shapes validated by the caller), so slab `b` lies entirely inside it.
            let mut out_2d = unsafe {
                let mapping = mdarray::DenseMapping::new((basis_size, extra_size));
                mdarray::DViewMut::<'_, f64, 2>::new_unchecked(
                    out_ptr.add(b * out_batch_stride),
                    mapping,
                )
            };
            self.fit_2d_to_dim(backend, &values_2d, &mut out_2d, 0);
        }
    }
}
impl InplaceFitter for ComplexToRealFitter {
    fn n_points(&self) -> usize {
        self.n_points()
    }
    fn basis_size(&self) -> usize {
        self.basis_size()
    }
    /// Real coefficients -> complex values; forwards to the inherent implementation.
    fn evaluate_nd_dz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        ComplexToRealFitter::evaluate_nd_dz_to(self, backend, coeffs, dim, out)
    }
    /// Complex coefficients -> complex values along axis `dim`.
    ///
    /// The sampling map is linear, so `A (a + i b) = A a + i (A b)`. The real parts `a` and the
    /// imaginary parts `b` are each evaluated with the real-coefficient kernel and then
    /// recombined. When every imaginary part is exactly zero, the second evaluation is
    /// skipped and the result is just `A a`.
    fn evaluate_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        use mdarray::Shape;
        let mut coeffs_shape: Vec<usize> = Vec::with_capacity(coeffs.rank());
        coeffs
            .shape()
            .with_dims(|dims| coeffs_shape.extend_from_slice(dims));
        let re_buffer: Vec<f64> = coeffs.iter().map(|c| c.re).collect();
        let im_buffer: Vec<f64> = coeffs.iter().map(|c| c.im).collect();

        // Real part: evaluate straight into `out`.
        let re_shape: DynRank = Shape::from_dims(&coeffs_shape[..]);
        // SAFETY: `re_buffer` is contiguous with exactly `coeffs.len()` elements, matching the
        // dense mapping built from `coeffs`' own shape, and it outlives `re_view`.
        let re_view = unsafe {
            let mapping = mdarray::DenseMapping::new(re_shape);
            mdarray::View::new_unchecked(re_buffer.as_ptr(), mapping)
        };
        if !ComplexToRealFitter::evaluate_nd_dz_to(self, backend, &re_view, dim, out) {
            return false;
        }

        // Purely real coefficients: nothing further to add.
        if im_buffer.iter().all(|&x| x == 0.0) {
            return true;
        }

        // Imaginary part: evaluate into a scratch buffer shaped like `out`, then add i * (A b).
        let mut out_shape: Vec<usize> = Vec::with_capacity(out.rank());
        out.shape()
            .with_dims(|dims| out_shape.extend_from_slice(dims));
        let mut im_values: Vec<Complex<f64>> = vec![Complex::new(0.0, 0.0); out.len()];
        {
            let im_shape: DynRank = Shape::from_dims(&coeffs_shape[..]);
            // SAFETY: `im_buffer` is contiguous with exactly `coeffs.len()` elements, matching
            // the dense mapping built from `coeffs`' shape, and it outlives `im_view`.
            let im_view = unsafe {
                let mapping = mdarray::DenseMapping::new(im_shape);
                mdarray::View::new_unchecked(im_buffer.as_ptr(), mapping)
            };
            let im_out_shape: DynRank = Shape::from_dims(&out_shape[..]);
            // SAFETY: `im_values` has exactly `out.len()` elements, matching the dense mapping
            // built from `out`'s shape. The view is dropped at the end of this scope, before
            // `im_values` is read again.
            let mut im_out = unsafe {
                let mapping = mdarray::DenseMapping::new(im_out_shape);
                mdarray::ViewMut::new_unchecked(im_values.as_mut_ptr(), mapping)
            };
            if !ComplexToRealFitter::evaluate_nd_dz_to(self, backend, &im_view, dim, &mut im_out)
            {
                return false;
            }
        }
        for (dst, v) in out.iter_mut().zip(im_values.iter()) {
            // i * (v.re + i v.im) = -v.im + i v.re
            *dst += Complex::new(-v.im, v.re);
        }
        true
    }
    /// Complex values -> real coefficients; forwards to the inherent implementation.
    fn fit_nd_zd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        ComplexToRealFitter::fit_nd_zd_to(self, backend, values, dim, out)
    }
    /// Complex values -> complex coefficients.
    ///
    /// This fitter only produces real coefficients, so the real least-squares fit is computed
    /// into a scratch buffer and written to `out` with zero imaginary parts.
    fn fit_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        use mdarray::Shape;
        let mut out_shape: Vec<usize> = Vec::with_capacity(out.rank());
        out.shape()
            .with_dims(|dims| out_shape.extend_from_slice(dims));
        let total = out.len();
        let mut real_buffer: Vec<f64> = vec![0.0; total];
        let shape: DynRank = Shape::from_dims(&out_shape[..]);
        // SAFETY: `real_buffer` has exactly `out.len()` elements, matching the dense mapping
        // built from `out`'s shape, and it outlives `real_view`.
        let mut real_view = unsafe {
            let mapping = mdarray::DenseMapping::new(shape);
            mdarray::ViewMut::new_unchecked(real_buffer.as_mut_ptr(), mapping)
        };
        if !ComplexToRealFitter::fit_nd_zd_to(self, backend, values, dim, &mut real_view) {
            return false;
        }
        for (c, r) in out.iter_mut().zip(real_buffer.iter()) {
            *c = Complex::new(*r, 0.0);
        }
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mdarray::DTensor;
    use num_complex::Complex;

    /// Test matrix with entries `x^j + i * sin(x * j * 0.1)`.
    ///
    /// `x = row / n_points` is the normalized row index and `j` is the column index.
    fn power_sine_matrix(n_points: usize, basis_size: usize) -> DTensor<Complex<f64>, 2> {
        DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
            let x = idx[0] as f64 / (n_points as f64);
            let order = idx[1] as i32;
            Complex::new(x.powi(order), (x * (order as f64) * 0.1).sin())
        })
    }

    #[test]
    fn test_roundtrip() {
        let (n_points, basis_size) = (10, 5);
        let fitter = ComplexToRealFitter::new(&power_sine_matrix(n_points, basis_size));
        let coeffs: Vec<f64> = (1..=basis_size).map(|k| k as f64 * 0.5).collect();

        let values = fitter.evaluate(None, &coeffs);
        assert_eq!(values.len(), n_points);
        // The matrix is genuinely complex, so the samples must carry imaginary parts.
        assert!(values.iter().any(|z| z.im.abs() > 1e-10));

        let fitted_coeffs = fitter.fit(None, &values);
        assert_eq!(fitted_coeffs.len(), basis_size);
        coeffs
            .iter()
            .zip(&fitted_coeffs)
            .for_each(|(expected, actual)| {
                let error = (expected - actual).abs();
                assert!(error < 1e-10, "Roundtrip error: {}", error);
            });
    }

    #[test]
    fn test_overdetermined() {
        let (n_points, basis_size) = (20, 5);
        let matrix = DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
            let row = idx[0] as f64;
            let col = idx[1] as f64;
            let phase = 2.0 * std::f64::consts::PI * row * col / (n_points as f64);
            Complex::new(phase.cos(), phase.sin()) / (col + 1.0)
        });
        let fitter = ComplexToRealFitter::new(&matrix);
        let coeffs: Vec<f64> = (0..basis_size).map(|k| k as f64 * 0.3).collect();
        let fitted_coeffs = fitter.fit(None, &fitter.evaluate(None, &coeffs));
        for (expected, actual) in coeffs.iter().zip(&fitted_coeffs) {
            let error = (expected - actual).abs();
            assert!(error < 1e-10, "Overdetermined roundtrip error: {}", error);
        }
    }

    #[test]
    fn test_nd_roundtrip() {
        use mdarray::{DynRank, Tensor};
        let (n_points, basis_size) = (8, 4);
        let (extra1, extra2) = (3, 2);
        let fitter = ComplexToRealFitter::new(&power_sine_matrix(n_points, basis_size));

        // Basis along the first axis of a 2D array.
        {
            let coeffs = Tensor::<f64, DynRank>::from_fn(&[basis_size, extra1][..], |idx| {
                (idx[0] + idx[1]) as f64 * 0.3
            });
            let mut values = Tensor::<Complex<f64>, DynRank>::zeros(&[n_points, extra1][..]);
            let mut fitted = Tensor::<f64, DynRank>::zeros(&[basis_size, extra1][..]);
            fitter.evaluate_nd_dz_to(None, &coeffs.expr(), 0, &mut values.expr_mut());
            fitter.fit_nd_zd_to(None, &values.expr(), 0, &mut fitted.expr_mut());
            for i in 0..basis_size {
                for j in 0..extra1 {
                    let at = [i, j];
                    let error = (coeffs[&at[..]] - fitted[&at[..]]).abs();
                    assert!(
                        error < 1e-8,
                        "dim=0 roundtrip error at [{}, {}]: {}",
                        i,
                        j,
                        error
                    );
                }
            }
        }

        // Basis along the last axis of a 2D array.
        {
            let coeffs = Tensor::<f64, DynRank>::from_fn(&[extra1, basis_size][..], |idx| {
                (idx[0] + idx[1]) as f64 * 0.3
            });
            let mut values = Tensor::<Complex<f64>, DynRank>::zeros(&[extra1, n_points][..]);
            let mut fitted = Tensor::<f64, DynRank>::zeros(&[extra1, basis_size][..]);
            fitter.evaluate_nd_dz_to(None, &coeffs.expr(), 1, &mut values.expr_mut());
            fitter.fit_nd_zd_to(None, &values.expr(), 1, &mut fitted.expr_mut());
            for i in 0..extra1 {
                for j in 0..basis_size {
                    let at = [i, j];
                    let error = (coeffs[&at[..]] - fitted[&at[..]]).abs();
                    assert!(
                        error < 1e-8,
                        "dim=1 roundtrip error at [{}, {}]: {}",
                        i,
                        j,
                        error
                    );
                }
            }
        }

        // Basis along an interior axis of a 3D array (exercises the batched path).
        {
            let coeffs = Tensor::<f64, DynRank>::from_fn(
                &[extra1, basis_size, extra2][..],
                |idx| (idx[0] + idx[1] + idx[2]) as f64 * 0.2,
            );
            let mut values =
                Tensor::<Complex<f64>, DynRank>::zeros(&[extra1, n_points, extra2][..]);
            let mut fitted = Tensor::<f64, DynRank>::zeros(&[extra1, basis_size, extra2][..]);
            fitter.evaluate_nd_dz_to(None, &coeffs.expr(), 1, &mut values.expr_mut());
            fitter.fit_nd_zd_to(None, &values.expr(), 1, &mut fitted.expr_mut());
            for i in 0..extra1 {
                for j in 0..basis_size {
                    for k in 0..extra2 {
                        let at = [i, j, k];
                        let error = (coeffs[&at[..]] - fitted[&at[..]]).abs();
                        assert!(
                            error < 1e-8,
                            "dim=1 (3D) roundtrip error at [{}, {}, {}]: {}",
                            i,
                            j,
                            k,
                            error
                        );
                    }
                }
            }
        }
    }
}