sparse_ir/
matsubara_sampling.rs

//! Sparse sampling in Matsubara frequencies
//!
//! This module provides Matsubara frequency sampling for transforming between
//! IR basis coefficients and values at sparse Matsubara frequencies.
5
6use crate::fitter::{ComplexMatrixFitter, ComplexToRealFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::traits::StatisticsType;
10use mdarray::{DTensor, DynRank, Shape, Tensor};
11use num_complex::Complex;
12use std::marker::PhantomData;
13
14/// Move axis from position src to position dst
15fn movedim<T: Clone>(arr: &Tensor<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
16    if src == dst {
17        return arr.clone();
18    }
19
20    let rank = arr.rank();
21    assert!(
22        src < rank,
23        "src axis {} out of bounds for rank {}",
24        src,
25        rank
26    );
27    assert!(
28        dst < rank,
29        "dst axis {} out of bounds for rank {}",
30        dst,
31        rank
32    );
33
34    // Generate permutation: move src to dst position
35    let mut perm = Vec::with_capacity(rank);
36    let mut pos = 0;
37    for i in 0..rank {
38        if i == dst {
39            perm.push(src);
40        } else {
41            if pos == src {
42                pos += 1;
43            }
44            if pos < rank {
45                perm.push(pos);
46                pos += 1;
47            }
48        }
49    }
50
51    arr.permute(&perm[..]).to_tensor()
52}
53
54/// Matsubara sampling for full frequency range (positive and negative)
55///
56/// General complex problem without symmetry → complex coefficients
57pub struct MatsubaraSampling<S: StatisticsType> {
58    sampling_points: Vec<MatsubaraFreq<S>>,
59    fitter: ComplexMatrixFitter,
60    _phantom: PhantomData<S>,
61}
62
63impl<S: StatisticsType> MatsubaraSampling<S> {
64    /// Create Matsubara sampling with default sampling points
65    ///
66    /// Uses extrema-based sampling point selection (symmetric: positive and negative frequencies).
67    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
68    where
69        S: 'static,
70    {
71        let sampling_points = basis.default_matsubara_sampling_points(false);
72        Self::with_sampling_points(basis, sampling_points)
73    }
74
75    /// Create Matsubara sampling with custom sampling points
76    pub fn with_sampling_points(
77        basis: &impl crate::basis_trait::Basis<S>,
78        mut sampling_points: Vec<MatsubaraFreq<S>>,
79    ) -> Self
80    where
81        S: 'static,
82    {
83        // Sort sampling points
84        sampling_points.sort();
85
86        // Evaluate matrix at sampling points
87        // Use Basis trait's evaluate_matsubara method
88        let matrix = basis.evaluate_matsubara(&sampling_points);
89
90        // Create fitter (complex → complex, no symmetry)
91        let fitter = ComplexMatrixFitter::new(matrix);
92
93        Self {
94            sampling_points,
95            fitter,
96            _phantom: PhantomData,
97        }
98    }
99
100    /// Create Matsubara sampling with custom sampling points and pre-computed matrix
101    ///
102    /// This constructor is useful when the sampling matrix is already computed
103    /// (e.g., from external sources or for testing).
104    ///
105    /// # Arguments
106    /// * `sampling_points` - Matsubara frequency sampling points
107    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
108    ///
109    /// # Returns
110    /// A new MatsubaraSampling object
111    ///
112    /// # Panics
113    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
114    pub fn from_matrix(
115        mut sampling_points: Vec<MatsubaraFreq<S>>,
116        matrix: DTensor<Complex<f64>, 2>,
117    ) -> Self {
118        assert!(!sampling_points.is_empty(), "No sampling points given");
119        assert_eq!(
120            matrix.shape().0,
121            sampling_points.len(),
122            "Matrix rows ({}) must match number of sampling points ({})",
123            matrix.shape().0,
124            sampling_points.len()
125        );
126
127        // Sort sampling points
128        sampling_points.sort();
129
130        let fitter = ComplexMatrixFitter::new(matrix);
131
132        Self {
133            sampling_points,
134            fitter,
135            _phantom: PhantomData,
136        }
137    }
138
139    /// Get sampling points
140    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
141        &self.sampling_points
142    }
143
144    /// Number of sampling points
145    pub fn n_sampling_points(&self) -> usize {
146        self.sampling_points.len()
147    }
148
149    /// Basis size
150    pub fn basis_size(&self) -> usize {
151        self.fitter.basis_size()
152    }
153
154    /// Evaluate complex basis coefficients at sampling points
155    ///
156    /// # Arguments
157    /// * `coeffs` - Complex basis coefficients (length = basis_size)
158    ///
159    /// # Returns
160    /// Complex values at Matsubara frequencies (length = n_sampling_points)
161    pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
162        self.fitter.evaluate(None, coeffs)
163    }
164
165    /// Fit complex basis coefficients from values at sampling points
166    ///
167    /// # Arguments
168    /// * `values` - Complex values at Matsubara frequencies (length = n_sampling_points)
169    ///
170    /// # Returns
171    /// Fitted complex basis coefficients (length = basis_size)
172    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
173        self.fitter.fit(None, values)
174    }
175
176    /// Evaluate N-dimensional array of complex basis coefficients at sampling points
177    ///
178    /// # Arguments
179    /// * `coeffs` - N-dimensional tensor of complex basis coefficients
180    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
181    ///
182    /// # Returns
183    /// N-dimensional tensor of complex values at Matsubara frequencies
184    pub fn evaluate_nd(
185        &self,
186        backend: Option<&GemmBackendHandle>,
187        coeffs: &Tensor<Complex<f64>, DynRank>,
188        dim: usize,
189    ) -> Tensor<Complex<f64>, DynRank> {
190        let rank = coeffs.rank();
191        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
192
193        let basis_size = self.basis_size();
194        let target_dim_size = coeffs.shape().dim(dim);
195
196        assert_eq!(
197            target_dim_size, basis_size,
198            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
199            dim, target_dim_size, basis_size
200        );
201
202        // 1. Move target dimension to position 0
203        let coeffs_dim0 = movedim(coeffs, dim, 0);
204
205        // 2. Reshape to 2D: (basis_size, extra_size)
206        let extra_size: usize = coeffs_dim0.len() / basis_size;
207
208        let coeffs_2d_dyn = coeffs_dim0
209            .reshape(&[basis_size, extra_size][..])
210            .to_tensor();
211
212        // 3. Convert to DTensor and evaluate using GEMM
213        let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
214            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
215        });
216
217        // Use fitter's efficient 2D evaluate (GEMM-based)
218        let n_points = self.n_sampling_points();
219        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d);
220
221        // 4. Reshape back to N-D with n_points at position 0
222        let mut result_shape = vec![n_points];
223        coeffs_dim0.shape().with_dims(|dims| {
224            for i in 1..dims.len() {
225                result_shape.push(dims[i]);
226            }
227        });
228
229        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
230
231        // 5. Move dimension 0 back to original position dim
232        movedim(&result_dim0, 0, dim)
233    }
234
235    /// Evaluate real basis coefficients at Matsubara sampling points (N-dimensional)
236    ///
237    /// This method takes real coefficients and produces complex values, useful when
238    /// working with symmetry-exploiting representations or real-valued IR coefficients.
239    ///
240    /// # Arguments
241    /// * `backend` - Optional GEMM backend handle (None uses default)
242    /// * `coeffs` - N-dimensional tensor of real basis coefficients
243    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
244    ///
245    /// # Returns
246    /// N-dimensional tensor of complex values at Matsubara frequencies
247    pub fn evaluate_nd_real(
248        &self,
249        backend: Option<&GemmBackendHandle>,
250        coeffs: &Tensor<f64, DynRank>,
251        dim: usize,
252    ) -> Tensor<Complex<f64>, DynRank> {
253        let rank = coeffs.rank();
254        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
255
256        let basis_size = self.basis_size();
257        let target_dim_size = coeffs.shape().dim(dim);
258
259        assert_eq!(
260            target_dim_size, basis_size,
261            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
262            dim, target_dim_size, basis_size
263        );
264
265        // 1. Move target dimension to position 0
266        let coeffs_dim0 = movedim(coeffs, dim, 0);
267
268        // 2. Reshape to 2D: (basis_size, extra_size)
269        let extra_size: usize = coeffs_dim0.len() / basis_size;
270
271        let coeffs_2d_dyn = coeffs_dim0
272            .reshape(&[basis_size, extra_size][..])
273            .to_tensor();
274
275        // 3. Convert to DTensor and evaluate using ComplexMatrixFitter
276        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
277            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
278        });
279
280        // 4. Evaluate: values = A * coeffs (A is complex, coeffs is real)
281        let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d);
282
283        // 5. Reshape result back to N-D with first dimension = n_sampling_points
284        let n_points = self.n_sampling_points();
285        let mut result_shape = Vec::with_capacity(rank);
286        result_shape.push(n_points);
287        coeffs_dim0.shape().with_dims(|dims| {
288            for i in 1..dims.len() {
289                result_shape.push(dims[i]);
290            }
291        });
292
293        let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
294
295        // 6. Move dimension 0 back to original position dim
296        movedim(&result_dim0, 0, dim)
297    }
298
299    /// Fit N-dimensional array of complex values to complex basis coefficients
300    ///
301    /// # Arguments
302    /// * `backend` - Optional GEMM backend handle (None uses default)
303    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
304    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
305    ///
306    /// # Returns
307    /// N-dimensional tensor of complex basis coefficients
308    pub fn fit_nd(
309        &self,
310        backend: Option<&GemmBackendHandle>,
311        values: &Tensor<Complex<f64>, DynRank>,
312        dim: usize,
313    ) -> Tensor<Complex<f64>, DynRank> {
314        let rank = values.rank();
315        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
316
317        let n_points = self.n_sampling_points();
318        let target_dim_size = values.shape().dim(dim);
319
320        assert_eq!(
321            target_dim_size, n_points,
322            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
323            dim, target_dim_size, n_points
324        );
325
326        // 1. Move target dimension to position 0
327        let values_dim0 = movedim(values, dim, 0);
328
329        // 2. Reshape to 2D: (n_points, extra_size)
330        let extra_size: usize = values_dim0.len() / n_points;
331        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
332
333        // 3. Convert to DTensor and fit using GEMM
334        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
335            values_2d_dyn[&[idx[0], idx[1]][..]]
336        });
337
338        // Use fitter's efficient 2D fit (GEMM-based)
339        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d);
340
341        // 4. Reshape back to N-D with basis_size at position 0
342        let basis_size = self.basis_size();
343        let mut coeffs_shape = vec![basis_size];
344        values_dim0.shape().with_dims(|dims| {
345            for i in 1..dims.len() {
346                coeffs_shape.push(dims[i]);
347            }
348        });
349
350        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
351
352        // 5. Move dimension 0 back to original position dim
353        movedim(&coeffs_dim0, 0, dim)
354    }
355
356    /// Fit N-dimensional array of complex values to real basis coefficients
357    ///
358    /// This method fits complex Matsubara values to real IR coefficients.
359    /// Takes the real part of the least-squares solution.
360    ///
361    /// # Arguments
362    /// * `backend` - Optional GEMM backend handle (None uses default)
363    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
364    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
365    ///
366    /// # Returns
367    /// N-dimensional tensor of real basis coefficients
368    pub fn fit_nd_real(
369        &self,
370        backend: Option<&GemmBackendHandle>,
371        values: &Tensor<Complex<f64>, DynRank>,
372        dim: usize,
373    ) -> Tensor<f64, DynRank> {
374        let rank = values.rank();
375        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
376
377        let n_points = self.n_sampling_points();
378        let target_dim_size = values.shape().dim(dim);
379
380        assert_eq!(
381            target_dim_size, n_points,
382            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
383            dim, target_dim_size, n_points
384        );
385
386        // 1. Move target dimension to position 0
387        let values_dim0 = movedim(values, dim, 0);
388
389        // 2. Reshape to 2D: (n_points, extra_size)
390        let extra_size: usize = values_dim0.len() / n_points;
391        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
392
393        // 3. Convert to DTensor and fit
394        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
395            values_2d_dyn[&[idx[0], idx[1]][..]]
396        });
397
398        // Use fitter's fit_2d_real method
399        let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d);
400
401        // 4. Reshape back to N-D with basis_size at position 0
402        let basis_size = self.basis_size();
403        let mut coeffs_shape = vec![basis_size];
404        values_dim0.shape().with_dims(|dims| {
405            for i in 1..dims.len() {
406                coeffs_shape.push(dims[i]);
407            }
408        });
409
410        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
411
412        // 5. Move dimension 0 back to original position dim
413        movedim(&coeffs_dim0, 0, dim)
414    }
415}
416
417/// Matsubara sampling for positive frequencies only
418///
419/// Exploits symmetry to reconstruct real coefficients from positive frequencies only.
420/// Supports: {0, 1, 2, 3, ...} (no negative frequencies)
421pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
422    sampling_points: Vec<MatsubaraFreq<S>>,
423    fitter: ComplexToRealFitter,
424    _phantom: PhantomData<S>,
425}
426
427impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
428    /// Create Matsubara sampling with default positive-only sampling points
429    ///
430    /// Uses extrema-based sampling point selection (positive frequencies only).
431    /// Exploits symmetry to reconstruct real coefficients.
432    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
433    where
434        S: 'static,
435    {
436        let sampling_points = basis.default_matsubara_sampling_points(true);
437        Self::with_sampling_points(basis, sampling_points)
438    }
439
440    /// Create Matsubara sampling with custom positive-only sampling points
441    pub fn with_sampling_points(
442        basis: &impl crate::basis_trait::Basis<S>,
443        mut sampling_points: Vec<MatsubaraFreq<S>>,
444    ) -> Self
445    where
446        S: 'static,
447    {
448        // Sort and validate (all n >= 0)
449        sampling_points.sort();
450
451        // TODO: Validate that all points are non-negative
452
453        // Evaluate matrix at sampling points
454        // Use Basis trait's evaluate_matsubara method
455        let matrix = basis.evaluate_matsubara(&sampling_points);
456
457        // Create fitter (complex → real, exploits symmetry)
458        let fitter = ComplexToRealFitter::new(&matrix);
459
460        Self {
461            sampling_points,
462            fitter,
463            _phantom: PhantomData,
464        }
465    }
466
467    /// Create Matsubara sampling (positive-only) with custom sampling points and pre-computed matrix
468    ///
469    /// This constructor is useful when the sampling matrix is already computed.
470    /// Uses symmetry to fit real coefficients from complex values at positive frequencies.
471    ///
472    /// # Arguments
473    /// * `sampling_points` - Matsubara frequency sampling points (should be positive)
474    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
475    ///
476    /// # Returns
477    /// A new MatsubaraSamplingPositiveOnly object
478    ///
479    /// # Panics
480    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
481    pub fn from_matrix(
482        mut sampling_points: Vec<MatsubaraFreq<S>>,
483        matrix: DTensor<Complex<f64>, 2>,
484    ) -> Self {
485        assert!(!sampling_points.is_empty(), "No sampling points given");
486        assert_eq!(
487            matrix.shape().0,
488            sampling_points.len(),
489            "Matrix rows ({}) must match number of sampling points ({})",
490            matrix.shape().0,
491            sampling_points.len()
492        );
493
494        // Sort sampling points
495        sampling_points.sort();
496
497        let fitter = ComplexToRealFitter::new(&matrix);
498
499        Self {
500            sampling_points,
501            fitter,
502            _phantom: PhantomData,
503        }
504    }
505
506    /// Get sampling points
507    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
508        &self.sampling_points
509    }
510
511    /// Number of sampling points
512    pub fn n_sampling_points(&self) -> usize {
513        self.sampling_points.len()
514    }
515
516    /// Basis size
517    pub fn basis_size(&self) -> usize {
518        self.fitter.basis_size()
519    }
520
521    /// Evaluate basis coefficients at sampling points
522    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
523        self.fitter.evaluate(None, coeffs)
524    }
525
526    /// Fit basis coefficients from values at sampling points
527    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
528        self.fitter.fit(None, values)
529    }
530
531    /// Evaluate N-dimensional array of real basis coefficients at sampling points
532    ///
533    /// # Arguments
534    /// * `coeffs` - N-dimensional tensor of real basis coefficients
535    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
536    ///
537    /// # Returns
538    /// N-dimensional tensor of complex values at Matsubara frequencies
539    pub fn evaluate_nd(
540        &self,
541        backend: Option<&GemmBackendHandle>,
542        coeffs: &Tensor<f64, DynRank>,
543        dim: usize,
544    ) -> Tensor<Complex<f64>, DynRank> {
545        let rank = coeffs.rank();
546        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
547
548        let basis_size = self.basis_size();
549        let target_dim_size = coeffs.shape().dim(dim);
550
551        assert_eq!(
552            target_dim_size, basis_size,
553            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
554            dim, target_dim_size, basis_size
555        );
556
557        // 1. Move target dimension to position 0
558        let coeffs_dim0 = movedim(coeffs, dim, 0);
559
560        // 2. Reshape to 2D: (basis_size, extra_size)
561        let extra_size: usize = coeffs_dim0.len() / basis_size;
562
563        let coeffs_2d_dyn = coeffs_dim0
564            .reshape(&[basis_size, extra_size][..])
565            .to_tensor();
566
567        // 3. Convert to DTensor and evaluate using GEMM
568        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
569            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
570        });
571
572        // Use fitter's efficient 2D evaluate (GEMM-based)
573        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d);
574
575        // 4. Reshape back to N-D with n_points at position 0
576        let n_points = self.n_sampling_points();
577        let mut result_shape = vec![n_points];
578        coeffs_dim0.shape().with_dims(|dims| {
579            for i in 1..dims.len() {
580                result_shape.push(dims[i]);
581            }
582        });
583
584        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
585
586        // 5. Move dimension 0 back to original position dim
587        movedim(&result_dim0, 0, dim)
588    }
589
590    /// Fit N-dimensional array of complex values to real basis coefficients
591    ///
592    /// # Arguments
593    /// * `backend` - Optional GEMM backend handle (None uses default)
594    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
595    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
596    ///
597    /// # Returns
598    /// N-dimensional tensor of real basis coefficients
599    pub fn fit_nd(
600        &self,
601        backend: Option<&GemmBackendHandle>,
602        values: &Tensor<Complex<f64>, DynRank>,
603        dim: usize,
604    ) -> Tensor<f64, DynRank> {
605        let rank = values.rank();
606        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
607
608        let n_points = self.n_sampling_points();
609        let target_dim_size = values.shape().dim(dim);
610
611        assert_eq!(
612            target_dim_size, n_points,
613            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
614            dim, target_dim_size, n_points
615        );
616
617        // 1. Move target dimension to position 0
618        let values_dim0 = movedim(values, dim, 0);
619
620        // 2. Reshape to 2D: (n_points, extra_size)
621        let extra_size: usize = values_dim0.len() / n_points;
622        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
623
624        // 3. Convert to DTensor and fit using GEMM
625        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
626            values_2d_dyn[&[idx[0], idx[1]][..]]
627        });
628
629        // Use fitter's efficient 2D fit (GEMM-based)
630        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d);
631
632        // 4. Reshape back to N-D with basis_size at position 0
633        let basis_size = self.basis_size();
634        let mut coeffs_shape = vec![basis_size];
635        values_dim0.shape().with_dims(|dims| {
636            for i in 1..dims.len() {
637                coeffs_shape.push(dims[i]);
638            }
639        });
640
641        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
642
643        // 5. Move dimension 0 back to original position dim
644        movedim(&coeffs_dim0, 0, dim)
645    }
646}
647
648#[cfg(test)]
649#[path = "matsubara_sampling_tests.rs"]
650mod tests;