1use crate::fpu_check::FpuGuard;
7use crate::gemm::GemmBackendHandle;
8use mdarray::{DTensor, DynRank, Shape, Slice, ViewMut};
9use num_complex::Complex;
10
11pub trait InplaceFitter {
34 fn n_points(&self) -> usize;
36
37 fn basis_size(&self) -> usize;
39
40 fn evaluate_nd_dd_to(
42 &self,
43 backend: Option<&GemmBackendHandle>,
44 coeffs: &Slice<f64, DynRank>,
45 dim: usize,
46 out: &mut ViewMut<'_, f64, DynRank>,
47 ) -> bool {
48 let _ = (backend, coeffs, dim, out);
49 false
50 }
51
52 fn evaluate_nd_dz_to(
54 &self,
55 backend: Option<&GemmBackendHandle>,
56 coeffs: &Slice<f64, DynRank>,
57 dim: usize,
58 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
59 ) -> bool {
60 let _ = (backend, coeffs, dim, out);
61 false
62 }
63
64 fn evaluate_nd_zd_to(
66 &self,
67 backend: Option<&GemmBackendHandle>,
68 coeffs: &Slice<Complex<f64>, DynRank>,
69 dim: usize,
70 out: &mut ViewMut<'_, f64, DynRank>,
71 ) -> bool {
72 let _ = (backend, coeffs, dim, out);
73 false
74 }
75
76 fn evaluate_nd_zz_to(
78 &self,
79 backend: Option<&GemmBackendHandle>,
80 coeffs: &Slice<Complex<f64>, DynRank>,
81 dim: usize,
82 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
83 ) -> bool {
84 let _ = (backend, coeffs, dim, out);
85 false
86 }
87
88 fn fit_nd_dd_to(
90 &self,
91 backend: Option<&GemmBackendHandle>,
92 values: &Slice<f64, DynRank>,
93 dim: usize,
94 out: &mut ViewMut<'_, f64, DynRank>,
95 ) -> bool {
96 let _ = (backend, values, dim, out);
97 false
98 }
99
100 fn fit_nd_dz_to(
102 &self,
103 backend: Option<&GemmBackendHandle>,
104 values: &Slice<f64, DynRank>,
105 dim: usize,
106 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
107 ) -> bool {
108 let _ = (backend, values, dim, out);
109 false
110 }
111
112 fn fit_nd_zd_to(
114 &self,
115 backend: Option<&GemmBackendHandle>,
116 values: &Slice<Complex<f64>, DynRank>,
117 dim: usize,
118 out: &mut ViewMut<'_, f64, DynRank>,
119 ) -> bool {
120 let _ = (backend, values, dim, out);
121 false
122 }
123
124 fn fit_nd_zz_to(
126 &self,
127 backend: Option<&GemmBackendHandle>,
128 values: &Slice<Complex<f64>, DynRank>,
129 dim: usize,
130 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
131 ) -> bool {
132 let _ = (backend, values, dim, out);
133 false
134 }
135}
136
/// Build the axis permutation that moves axis `dim` to the front while
/// keeping the relative order of all other axes.
///
/// E.g. `make_perm_to_front(4, 2)` yields `[2, 0, 1, 3]`.
pub(crate) fn make_perm_to_front(rank: usize, dim: usize) -> Vec<usize> {
    std::iter::once(dim)
        .chain((0..rank).filter(|&axis| axis != dim))
        .collect()
}
155
156pub(crate) fn copy_to_contiguous<T: Copy>(
169 src: &mdarray::Slice<T, mdarray::DynRank, mdarray::Strided>,
170 dst: &mut [T],
171) {
172 assert_eq!(dst.len(), src.len(), "Destination size mismatch");
173
174 for (d, s) in dst.iter_mut().zip(src.iter()) {
176 *d = *s;
177 }
178}
179
180pub(crate) fn copy_from_contiguous<T: Copy>(
188 src: &[T],
189 dst: &mut mdarray::Slice<T, mdarray::DynRank, mdarray::Strided>,
190) {
191 assert_eq!(src.len(), dst.len(), "Source size mismatch");
192
193 for (d, s) in dst.iter_mut().zip(src.iter()) {
195 *d = *s;
196 }
197}
198
/// Reinterpret a dense `Complex<f64>` slice as a read-only `f64` view with
/// one extra trailing axis of length 2 (index 0 = real part, 1 = imaginary).
///
/// No data is copied; the returned view borrows `coeffs` for `'a`.
pub(crate) fn complex_slice_as_real<'a>(
    coeffs: &'a Slice<Complex<f64>, DynRank>,
) -> mdarray::View<'a, f64, DynRank, mdarray::Dense> {
    // New shape = original dims with a trailing axis of length 2 appended.
    let mut new_shape: Vec<usize> = Vec::with_capacity(coeffs.rank() + 1);
    coeffs.shape().with_dims(|dims| {
        for d in dims {
            new_shape.push(*d);
        }
    });
    new_shape.push(2);

    // SAFETY: num-complex documents `Complex<f64>` as repr(C) and
    // layout-compatible with `[f64; 2]`, so n contiguous complex values are
    // exactly 2n contiguous f64 values; the element count of the new dense
    // mapping matches, and the view's lifetime is tied to `coeffs`.
    unsafe {
        let shape: DynRank = Shape::from_dims(&new_shape[..]);
        let mapping = mdarray::DenseMapping::new(shape);
        mdarray::View::new_unchecked(coeffs.as_ptr() as *const f64, mapping)
    }
}
225
/// Mutable counterpart of `complex_slice_as_real`: reinterpret a dense
/// `Complex<f64>` slice as a mutable `f64` view with one extra trailing
/// axis of length 2 (index 0 = real part, 1 = imaginary).
#[allow(dead_code)]
pub(crate) fn complex_slice_mut_as_real<'a>(
    out: &'a mut Slice<Complex<f64>, DynRank>,
) -> mdarray::ViewMut<'a, f64, DynRank, mdarray::Dense> {
    // New shape = original dims with a trailing axis of length 2 appended.
    let mut new_shape: Vec<usize> = Vec::with_capacity(out.rank() + 1);
    out.shape().with_dims(|dims| {
        for d in dims {
            new_shape.push(*d);
        }
    });
    new_shape.push(2);

    // SAFETY: num-complex documents `Complex<f64>` as repr(C) and
    // layout-compatible with `[f64; 2]`, so n contiguous complex values are
    // exactly 2n contiguous f64 values; the exclusive borrow of `out` for
    // `'a` rules out aliasing while the view is alive.
    unsafe {
        let shape: DynRank = Shape::from_dims(&new_shape[..]);
        let mapping = mdarray::DenseMapping::new(shape);
        mdarray::ViewMut::new_unchecked(out.as_mut_ptr() as *mut f64, mapping)
    }
}
249
250pub(crate) struct RealSVD {
256 pub ut: DTensor<f64, 2>, pub s: Vec<f64>, pub v: DTensor<f64, 2>, }
260
261impl RealSVD {
262 pub fn new(u: DTensor<f64, 2>, s: Vec<f64>, vt: DTensor<f64, 2>) -> Self {
263 let (_, u_cols) = *u.shape();
265 let (vt_rows, _) = *vt.shape();
266 let min_dim = s.len();
267
268 assert_eq!(
269 u_cols, min_dim,
270 "u.cols()={} must equal s.len()={}",
271 u_cols, min_dim
272 );
273 assert_eq!(
274 vt_rows, min_dim,
275 "vt.rows()={} must equal s.len()={}",
276 vt_rows, min_dim
277 );
278
279 let ut = u.transpose().to_tensor(); let v = vt.transpose().to_tensor(); assert_eq!(
285 v.shape().1,
286 min_dim,
287 "v.cols()={} must equal s.len()={}",
288 v.shape().1,
289 min_dim
290 );
291
292 Self { ut, s, v }
293 }
294}
295
296pub(crate) struct ComplexSVD {
298 pub ut: DTensor<Complex<f64>, 2>, pub s: Vec<f64>, pub v: DTensor<Complex<f64>, 2>, }
302
303impl ComplexSVD {
304 pub fn new(u: DTensor<Complex<f64>, 2>, s: Vec<f64>, vt: DTensor<Complex<f64>, 2>) -> Self {
305 let (u_rows, u_cols) = *u.shape();
307 let (vt_rows, _) = *vt.shape();
308 let min_dim = s.len();
309
310 assert_eq!(
311 u_cols, min_dim,
312 "u.cols()={} must equal s.len()={}",
313 u_cols, min_dim
314 );
315 assert_eq!(
316 vt_rows, min_dim,
317 "vt.rows()={} must equal s.len()={}",
318 vt_rows, min_dim
319 );
320
321 let ut = DTensor::<Complex<f64>, 2>::from_fn([u_cols, u_rows], |idx| {
323 u[[idx[1], idx[0]]].conj() });
325 let v = vt.transpose().to_tensor(); assert_eq!(
329 v.shape().1,
330 min_dim,
331 "v.cols()={} must equal s.len()={}",
332 v.shape().1,
333 min_dim
334 );
335
336 Self { ut, s, v }
337 }
338}
339
340pub(crate) fn compute_real_svd(matrix: &DTensor<f64, 2>) -> RealSVD {
346 use mdarray_linalg::prelude::SVD;
347 use mdarray_linalg::svd::SVDDecomp;
348 use mdarray_linalg_faer::Faer;
349
350 let _guard = FpuGuard::new_protect_computation();
352
353 let mut a = matrix.clone();
354 let SVDDecomp { u, s, vt } = Faer.svd(&mut *a).expect("SVD computation failed");
355
356 let min_dim = s.shape().0.min(s.shape().1);
358 let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]]).collect();
359
360 let u_trimmed = u.view(.., ..min_dim).to_tensor();
364 let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
365
366 RealSVD::new(u_trimmed, s_vec, vt_trimmed)
367}
368
369pub(crate) fn compute_complex_svd(matrix: &DTensor<Complex<f64>, 2>) -> ComplexSVD {
371 use mdarray_linalg::prelude::SVD;
372 use mdarray_linalg::svd::SVDDecomp;
373 use mdarray_linalg_faer::Faer;
374
375 let _guard = FpuGuard::new_protect_computation();
377
378 let mut matrix_c64 = matrix.clone();
380
381 let SVDDecomp { u, s, vt } = Faer
383 .svd(&mut *matrix_c64)
384 .expect("Complex SVD computation failed");
385
386 let min_dim = s.shape().0.min(s.shape().1);
388 let s_vec: Vec<f64> = (0..min_dim).map(|i| s[[0, i]].re).collect();
389
390 let u_trimmed = u.view(.., ..min_dim).to_tensor();
394 let vt_trimmed = vt.view(..min_dim, ..).to_tensor();
395
396 ComplexSVD::new(u_trimmed, s_vec, vt_trimmed)
397}
398
399pub(crate) fn combine_complex(
405 re: &DTensor<f64, 2>,
406 im: &DTensor<f64, 2>,
407) -> DTensor<Complex<f64>, 2> {
408 let (n_points, extra_size) = *re.shape();
409 DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
410 Complex::new(re[idx], im[idx])
411 })
412}
413
414pub(crate) fn extract_real_parts_coeffs(coeffs_2d: &DTensor<Complex<f64>, 2>) -> DTensor<f64, 2> {
416 let (basis_size, extra_size) = *coeffs_2d.shape();
417 DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| coeffs_2d[idx].re)
418}