1use crate::fitters::InplaceFitter;
7use crate::gemm::GemmBackendHandle;
8use crate::traits::StatisticsType;
9use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
10use num_complex::Complex;
11
12fn build_output_shape<S: Shape>(input_shape: &S, dim: usize, new_size: usize) -> Vec<usize> {
14 let mut out_shape: Vec<usize> = Vec::with_capacity(input_shape.rank());
15 input_shape.with_dims(|dims| {
16 for (i, d) in dims.iter().enumerate() {
17 if i == dim {
18 out_shape.push(new_size);
19 } else {
20 out_shape.push(*d);
21 }
22 }
23 });
24 out_shape
25}
26
27pub fn movedim<T: Clone>(arr: &Slice<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
47 if src == dst {
48 return arr.to_tensor();
49 }
50
51 let rank = arr.rank();
52 assert!(
53 src < rank,
54 "src axis {} out of bounds for rank {}",
55 src,
56 rank
57 );
58 assert!(
59 dst < rank,
60 "dst axis {} out of bounds for rank {}",
61 dst,
62 rank
63 );
64
65 let mut perm = Vec::with_capacity(rank);
67 let mut pos = 0;
68 for i in 0..rank {
69 if i == dst {
70 perm.push(src);
71 } else {
72 if pos == src {
74 pos += 1;
75 }
76 perm.push(pos);
77 pos += 1;
78 }
79 }
80
81 arr.permute(&perm[..]).to_tensor()
82}
83
84pub struct TauSampling<S>
89where
90 S: StatisticsType,
91{
92 sampling_points: Vec<f64>,
94
95 fitter: crate::fitters::RealMatrixFitter,
97
98 _phantom: std::marker::PhantomData<S>,
100}
101
102impl<S> TauSampling<S>
103where
104 S: StatisticsType,
105{
106 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
118 where
119 S: 'static,
120 {
121 let sampling_points = basis.default_tau_sampling_points();
122 Self::with_sampling_points(basis, sampling_points)
123 }
124
125 pub fn with_sampling_points(
139 basis: &impl crate::basis_trait::Basis<S>,
140 sampling_points: Vec<f64>,
141 ) -> Self
142 where
143 S: 'static,
144 {
145 assert!(!sampling_points.is_empty(), "No sampling points given");
146
147 let beta = basis.beta();
148 for &tau in &sampling_points {
149 assert!(
150 tau >= -beta && tau <= beta,
151 "Sampling point τ={} is outside [-β, β]",
152 tau
153 );
154 }
155
156 let matrix = basis.evaluate_tau(&sampling_points);
159
160 let fitter = crate::fitters::RealMatrixFitter::new(matrix);
162
163 Self {
164 sampling_points,
165 fitter,
166 _phantom: std::marker::PhantomData,
167 }
168 }
169
170 pub fn from_matrix(sampling_points: Vec<f64>, matrix: DTensor<f64, 2>) -> Self {
185 assert!(!sampling_points.is_empty(), "No sampling points given");
186 assert_eq!(
187 matrix.shape().0,
188 sampling_points.len(),
189 "Matrix rows ({}) must match number of sampling points ({})",
190 matrix.shape().0,
191 sampling_points.len()
192 );
193
194 let fitter = crate::fitters::RealMatrixFitter::new(matrix);
195
196 Self {
197 sampling_points,
198 fitter,
199 _phantom: std::marker::PhantomData,
200 }
201 }
202
203 pub fn sampling_points(&self) -> &[f64] {
205 &self.sampling_points
206 }
207
208 pub fn n_sampling_points(&self) -> usize {
210 self.fitter.n_points()
211 }
212
213 pub fn basis_size(&self) -> usize {
215 self.fitter.basis_size()
216 }
217
218 pub fn matrix(&self) -> &DTensor<f64, 2> {
220 &self.fitter.matrix
221 }
222
223 pub fn evaluate(&self, coeffs: &[f64]) -> Vec<f64> {
237 self.fitter.evaluate(None, coeffs)
238 }
239
240 pub fn evaluate_to(&self, coeffs: &[f64], out: &mut [f64]) {
242 self.fitter.evaluate_to(None, coeffs, out)
243 }
244
245 pub fn fit(&self, values: &[f64]) -> Vec<f64> {
247 self.fitter.fit(None, values)
248 }
249
250 pub fn fit_to(&self, values: &[f64], out: &mut [f64]) {
252 self.fitter.fit_to(None, values, out)
253 }
254
255 pub fn evaluate_zz(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
257 self.fitter.evaluate_zz(None, coeffs)
258 }
259
260 pub fn evaluate_zz_to(&self, coeffs: &[Complex<f64>], out: &mut [Complex<f64>]) {
262 self.fitter.evaluate_zz_to(None, coeffs, out)
263 }
264
265 pub fn fit_zz(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
267 self.fitter.fit_zz(None, values)
268 }
269
270 pub fn fit_zz_to(&self, values: &[Complex<f64>], out: &mut [Complex<f64>]) {
272 self.fitter.fit_zz_to(None, values, out)
273 }
274
275 pub fn evaluate_nd(
288 &self,
289 backend: Option<&GemmBackendHandle>,
290 coeffs: &Slice<f64, DynRank>,
291 dim: usize,
292 ) -> Tensor<f64, DynRank> {
293 let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
294 let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
295 self.evaluate_nd_to(backend, coeffs, dim, &mut out.expr_mut());
296 out
297 }
298
299 pub fn evaluate_nd_to(
301 &self,
302 backend: Option<&GemmBackendHandle>,
303 coeffs: &Slice<f64, DynRank>,
304 dim: usize,
305 out: &mut ViewMut<'_, f64, DynRank>,
306 ) {
307 InplaceFitter::evaluate_nd_dd_to(self, backend, coeffs, dim, out);
308 }
309
310 pub fn fit_nd(
319 &self,
320 backend: Option<&GemmBackendHandle>,
321 values: &Slice<f64, DynRank>,
322 dim: usize,
323 ) -> Tensor<f64, DynRank> {
324 let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
325 let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
326 self.fit_nd_to(backend, values, dim, &mut out.expr_mut());
327 out
328 }
329
330 pub fn fit_nd_to(
332 &self,
333 backend: Option<&GemmBackendHandle>,
334 values: &Slice<f64, DynRank>,
335 dim: usize,
336 out: &mut ViewMut<'_, f64, DynRank>,
337 ) {
338 InplaceFitter::fit_nd_dd_to(self, backend, values, dim, out);
339 }
340
341 pub fn evaluate_nd_zz(
354 &self,
355 backend: Option<&GemmBackendHandle>,
356 coeffs: &Slice<Complex<f64>, DynRank>,
357 dim: usize,
358 ) -> Tensor<Complex<f64>, DynRank> {
359 let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
360 let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
361 self.evaluate_nd_zz_to(backend, coeffs, dim, &mut out.expr_mut());
362 out
363 }
364
365 pub fn evaluate_nd_zz_to(
367 &self,
368 backend: Option<&GemmBackendHandle>,
369 coeffs: &Slice<Complex<f64>, DynRank>,
370 dim: usize,
371 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
372 ) {
373 InplaceFitter::evaluate_nd_zz_to(self, backend, coeffs, dim, out);
374 }
375
376 pub fn fit_nd_zz(
385 &self,
386 backend: Option<&GemmBackendHandle>,
387 values: &Slice<Complex<f64>, DynRank>,
388 dim: usize,
389 ) -> Tensor<Complex<f64>, DynRank> {
390 let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
391 let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
392 self.fit_nd_zz_to(backend, values, dim, &mut out.expr_mut());
393 out
394 }
395
396 pub fn fit_nd_zz_to(
398 &self,
399 backend: Option<&GemmBackendHandle>,
400 values: &Slice<Complex<f64>, DynRank>,
401 dim: usize,
402 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
403 ) {
404 InplaceFitter::fit_nd_zz_to(self, backend, values, dim, out);
405 }
406}
407
408impl<S: StatisticsType> InplaceFitter for TauSampling<S> {
412 fn n_points(&self) -> usize {
413 self.n_sampling_points()
414 }
415
416 fn basis_size(&self) -> usize {
417 self.basis_size()
418 }
419
420 fn evaluate_nd_dd_to(
421 &self,
422 backend: Option<&GemmBackendHandle>,
423 coeffs: &Slice<f64, DynRank>,
424 dim: usize,
425 out: &mut ViewMut<'_, f64, DynRank>,
426 ) -> bool {
427 self.fitter.evaluate_nd_dd_to(backend, coeffs, dim, out)
428 }
429
430 fn evaluate_nd_zz_to(
431 &self,
432 backend: Option<&GemmBackendHandle>,
433 coeffs: &Slice<Complex<f64>, DynRank>,
434 dim: usize,
435 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
436 ) -> bool {
437 self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
438 }
439
440 fn fit_nd_dd_to(
441 &self,
442 backend: Option<&GemmBackendHandle>,
443 values: &Slice<f64, DynRank>,
444 dim: usize,
445 out: &mut ViewMut<'_, f64, DynRank>,
446 ) -> bool {
447 self.fitter.fit_nd_dd_to(backend, values, dim, out)
448 }
449
450 fn fit_nd_zz_to(
451 &self,
452 backend: Option<&GemmBackendHandle>,
453 values: &Slice<Complex<f64>, DynRank>,
454 dim: usize,
455 out: &mut ViewMut<'_, Complex<f64>, DynRank>,
456 ) -> bool {
457 self.fitter.fit_nd_zz_to(backend, values, dim, out)
458 }
459}
460
461#[cfg(test)]
462#[path = "tau_sampling_tests.rs"]
463mod tests;