1use crate::fitter::{ComplexMatrixFitter, ComplexToRealFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::traits::StatisticsType;
10use mdarray::{DTensor, DynRank, Shape, Tensor};
11use num_complex::Complex;
12use std::marker::PhantomData;
13
14fn movedim<T: Clone>(arr: &Tensor<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
16 if src == dst {
17 return arr.clone();
18 }
19
20 let rank = arr.rank();
21 assert!(
22 src < rank,
23 "src axis {} out of bounds for rank {}",
24 src,
25 rank
26 );
27 assert!(
28 dst < rank,
29 "dst axis {} out of bounds for rank {}",
30 dst,
31 rank
32 );
33
34 let mut perm = Vec::with_capacity(rank);
36 let mut pos = 0;
37 for i in 0..rank {
38 if i == dst {
39 perm.push(src);
40 } else {
41 if pos == src {
42 pos += 1;
43 }
44 if pos < rank {
45 perm.push(pos);
46 pos += 1;
47 }
48 }
49 }
50
51 arr.permute(&perm[..]).to_tensor()
52}
53
54pub struct MatsubaraSampling<S: StatisticsType> {
58 sampling_points: Vec<MatsubaraFreq<S>>,
59 fitter: ComplexMatrixFitter,
60 _phantom: PhantomData<S>,
61}
62
63impl<S: StatisticsType> MatsubaraSampling<S> {
64 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
68 where
69 S: 'static,
70 {
71 let sampling_points = basis.default_matsubara_sampling_points(false);
72 Self::with_sampling_points(basis, sampling_points)
73 }
74
75 pub fn with_sampling_points(
77 basis: &impl crate::basis_trait::Basis<S>,
78 mut sampling_points: Vec<MatsubaraFreq<S>>,
79 ) -> Self
80 where
81 S: 'static,
82 {
83 sampling_points.sort();
85
86 let matrix = basis.evaluate_matsubara(&sampling_points);
89
90 let fitter = ComplexMatrixFitter::new(matrix);
92
93 Self {
94 sampling_points,
95 fitter,
96 _phantom: PhantomData,
97 }
98 }
99
100 pub fn from_matrix(
115 mut sampling_points: Vec<MatsubaraFreq<S>>,
116 matrix: DTensor<Complex<f64>, 2>,
117 ) -> Self {
118 assert!(!sampling_points.is_empty(), "No sampling points given");
119 assert_eq!(
120 matrix.shape().0,
121 sampling_points.len(),
122 "Matrix rows ({}) must match number of sampling points ({})",
123 matrix.shape().0,
124 sampling_points.len()
125 );
126
127 sampling_points.sort();
129
130 let fitter = ComplexMatrixFitter::new(matrix);
131
132 Self {
133 sampling_points,
134 fitter,
135 _phantom: PhantomData,
136 }
137 }
138
139 pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
141 &self.sampling_points
142 }
143
144 pub fn n_sampling_points(&self) -> usize {
146 self.sampling_points.len()
147 }
148
149 pub fn basis_size(&self) -> usize {
151 self.fitter.basis_size()
152 }
153
154 pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
162 self.fitter.evaluate(None, coeffs)
163 }
164
165 pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
173 self.fitter.fit(None, values)
174 }
175
176 pub fn evaluate_nd(
185 &self,
186 backend: Option<&GemmBackendHandle>,
187 coeffs: &Tensor<Complex<f64>, DynRank>,
188 dim: usize,
189 ) -> Tensor<Complex<f64>, DynRank> {
190 let rank = coeffs.rank();
191 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
192
193 let basis_size = self.basis_size();
194 let target_dim_size = coeffs.shape().dim(dim);
195
196 assert_eq!(
197 target_dim_size, basis_size,
198 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
199 dim, target_dim_size, basis_size
200 );
201
202 let coeffs_dim0 = movedim(coeffs, dim, 0);
204
205 let extra_size: usize = coeffs_dim0.len() / basis_size;
207
208 let coeffs_2d_dyn = coeffs_dim0
209 .reshape(&[basis_size, extra_size][..])
210 .to_tensor();
211
212 let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
214 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
215 });
216
217 let n_points = self.n_sampling_points();
219 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d);
220
221 let mut result_shape = vec![n_points];
223 coeffs_dim0.shape().with_dims(|dims| {
224 for i in 1..dims.len() {
225 result_shape.push(dims[i]);
226 }
227 });
228
229 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
230
231 movedim(&result_dim0, 0, dim)
233 }
234
235 pub fn evaluate_nd_real(
248 &self,
249 backend: Option<&GemmBackendHandle>,
250 coeffs: &Tensor<f64, DynRank>,
251 dim: usize,
252 ) -> Tensor<Complex<f64>, DynRank> {
253 let rank = coeffs.rank();
254 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
255
256 let basis_size = self.basis_size();
257 let target_dim_size = coeffs.shape().dim(dim);
258
259 assert_eq!(
260 target_dim_size, basis_size,
261 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
262 dim, target_dim_size, basis_size
263 );
264
265 let coeffs_dim0 = movedim(coeffs, dim, 0);
267
268 let extra_size: usize = coeffs_dim0.len() / basis_size;
270
271 let coeffs_2d_dyn = coeffs_dim0
272 .reshape(&[basis_size, extra_size][..])
273 .to_tensor();
274
275 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
277 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
278 });
279
280 let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d);
282
283 let n_points = self.n_sampling_points();
285 let mut result_shape = Vec::with_capacity(rank);
286 result_shape.push(n_points);
287 coeffs_dim0.shape().with_dims(|dims| {
288 for i in 1..dims.len() {
289 result_shape.push(dims[i]);
290 }
291 });
292
293 let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
294
295 movedim(&result_dim0, 0, dim)
297 }
298
299 pub fn fit_nd(
309 &self,
310 backend: Option<&GemmBackendHandle>,
311 values: &Tensor<Complex<f64>, DynRank>,
312 dim: usize,
313 ) -> Tensor<Complex<f64>, DynRank> {
314 let rank = values.rank();
315 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
316
317 let n_points = self.n_sampling_points();
318 let target_dim_size = values.shape().dim(dim);
319
320 assert_eq!(
321 target_dim_size, n_points,
322 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
323 dim, target_dim_size, n_points
324 );
325
326 let values_dim0 = movedim(values, dim, 0);
328
329 let extra_size: usize = values_dim0.len() / n_points;
331 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
332
333 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
335 values_2d_dyn[&[idx[0], idx[1]][..]]
336 });
337
338 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d);
340
341 let basis_size = self.basis_size();
343 let mut coeffs_shape = vec![basis_size];
344 values_dim0.shape().with_dims(|dims| {
345 for i in 1..dims.len() {
346 coeffs_shape.push(dims[i]);
347 }
348 });
349
350 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
351
352 movedim(&coeffs_dim0, 0, dim)
354 }
355
356 pub fn fit_nd_real(
369 &self,
370 backend: Option<&GemmBackendHandle>,
371 values: &Tensor<Complex<f64>, DynRank>,
372 dim: usize,
373 ) -> Tensor<f64, DynRank> {
374 let rank = values.rank();
375 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
376
377 let n_points = self.n_sampling_points();
378 let target_dim_size = values.shape().dim(dim);
379
380 assert_eq!(
381 target_dim_size, n_points,
382 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
383 dim, target_dim_size, n_points
384 );
385
386 let values_dim0 = movedim(values, dim, 0);
388
389 let extra_size: usize = values_dim0.len() / n_points;
391 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
392
393 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
395 values_2d_dyn[&[idx[0], idx[1]][..]]
396 });
397
398 let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d);
400
401 let basis_size = self.basis_size();
403 let mut coeffs_shape = vec![basis_size];
404 values_dim0.shape().with_dims(|dims| {
405 for i in 1..dims.len() {
406 coeffs_shape.push(dims[i]);
407 }
408 });
409
410 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
411
412 movedim(&coeffs_dim0, 0, dim)
414 }
415}
416
417pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
422 sampling_points: Vec<MatsubaraFreq<S>>,
423 fitter: ComplexToRealFitter,
424 _phantom: PhantomData<S>,
425}
426
427impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
428 pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
433 where
434 S: 'static,
435 {
436 let sampling_points = basis.default_matsubara_sampling_points(true);
437 Self::with_sampling_points(basis, sampling_points)
438 }
439
440 pub fn with_sampling_points(
442 basis: &impl crate::basis_trait::Basis<S>,
443 mut sampling_points: Vec<MatsubaraFreq<S>>,
444 ) -> Self
445 where
446 S: 'static,
447 {
448 sampling_points.sort();
450
451 let matrix = basis.evaluate_matsubara(&sampling_points);
456
457 let fitter = ComplexToRealFitter::new(&matrix);
459
460 Self {
461 sampling_points,
462 fitter,
463 _phantom: PhantomData,
464 }
465 }
466
467 pub fn from_matrix(
482 mut sampling_points: Vec<MatsubaraFreq<S>>,
483 matrix: DTensor<Complex<f64>, 2>,
484 ) -> Self {
485 assert!(!sampling_points.is_empty(), "No sampling points given");
486 assert_eq!(
487 matrix.shape().0,
488 sampling_points.len(),
489 "Matrix rows ({}) must match number of sampling points ({})",
490 matrix.shape().0,
491 sampling_points.len()
492 );
493
494 sampling_points.sort();
496
497 let fitter = ComplexToRealFitter::new(&matrix);
498
499 Self {
500 sampling_points,
501 fitter,
502 _phantom: PhantomData,
503 }
504 }
505
506 pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
508 &self.sampling_points
509 }
510
511 pub fn n_sampling_points(&self) -> usize {
513 self.sampling_points.len()
514 }
515
516 pub fn basis_size(&self) -> usize {
518 self.fitter.basis_size()
519 }
520
521 pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
523 self.fitter.evaluate(None, coeffs)
524 }
525
526 pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
528 self.fitter.fit(None, values)
529 }
530
531 pub fn evaluate_nd(
540 &self,
541 backend: Option<&GemmBackendHandle>,
542 coeffs: &Tensor<f64, DynRank>,
543 dim: usize,
544 ) -> Tensor<Complex<f64>, DynRank> {
545 let rank = coeffs.rank();
546 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
547
548 let basis_size = self.basis_size();
549 let target_dim_size = coeffs.shape().dim(dim);
550
551 assert_eq!(
552 target_dim_size, basis_size,
553 "coeffs.shape().dim({}) = {} must equal basis_size = {}",
554 dim, target_dim_size, basis_size
555 );
556
557 let coeffs_dim0 = movedim(coeffs, dim, 0);
559
560 let extra_size: usize = coeffs_dim0.len() / basis_size;
562
563 let coeffs_2d_dyn = coeffs_dim0
564 .reshape(&[basis_size, extra_size][..])
565 .to_tensor();
566
567 let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
569 coeffs_2d_dyn[&[idx[0], idx[1]][..]]
570 });
571
572 let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d);
574
575 let n_points = self.n_sampling_points();
577 let mut result_shape = vec![n_points];
578 coeffs_dim0.shape().with_dims(|dims| {
579 for i in 1..dims.len() {
580 result_shape.push(dims[i]);
581 }
582 });
583
584 let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
585
586 movedim(&result_dim0, 0, dim)
588 }
589
590 pub fn fit_nd(
600 &self,
601 backend: Option<&GemmBackendHandle>,
602 values: &Tensor<Complex<f64>, DynRank>,
603 dim: usize,
604 ) -> Tensor<f64, DynRank> {
605 let rank = values.rank();
606 assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
607
608 let n_points = self.n_sampling_points();
609 let target_dim_size = values.shape().dim(dim);
610
611 assert_eq!(
612 target_dim_size, n_points,
613 "values.shape().dim({}) = {} must equal n_sampling_points = {}",
614 dim, target_dim_size, n_points
615 );
616
617 let values_dim0 = movedim(values, dim, 0);
619
620 let extra_size: usize = values_dim0.len() / n_points;
622 let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
623
624 let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
626 values_2d_dyn[&[idx[0], idx[1]][..]]
627 });
628
629 let coeffs_2d = self.fitter.fit_2d(backend, &values_2d);
631
632 let basis_size = self.basis_size();
634 let mut coeffs_shape = vec![basis_size];
635 values_dim0.shape().with_dims(|dims| {
636 for i in 1..dims.len() {
637 coeffs_shape.push(dims[i]);
638 }
639 });
640
641 let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
642
643 movedim(&coeffs_dim0, 0, dim)
645 }
646}
647
648#[cfg(test)]
649#[path = "matsubara_sampling_tests.rs"]
650mod tests;