sparse_ir/
matsubara_sampling.rs

1//! Sparse sampling in Matsubara frequencies
2//!
3//! This module provides Matsubara frequency sampling for transforming between
4//! IR basis coefficients and values at sparse Matsubara frequencies.
5
6use crate::fitters::{ComplexMatrixFitter, ComplexToRealFitter, InplaceFitter};
7use crate::freq::MatsubaraFreq;
8use crate::gemm::GemmBackendHandle;
9use crate::traits::StatisticsType;
10use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
11use num_complex::Complex;
12use std::marker::PhantomData;
13
14/// Move axis from position src to position dst
15fn movedim<T: Clone>(arr: &Slice<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
16    if src == dst {
17        return arr.to_tensor();
18    }
19
20    let rank = arr.rank();
21    assert!(
22        src < rank,
23        "src axis {} out of bounds for rank {}",
24        src,
25        rank
26    );
27    assert!(
28        dst < rank,
29        "dst axis {} out of bounds for rank {}",
30        dst,
31        rank
32    );
33
34    // Generate permutation: move src to dst position
35    let mut perm = Vec::with_capacity(rank);
36    let mut pos = 0;
37    for i in 0..rank {
38        if i == dst {
39            perm.push(src);
40        } else {
41            if pos == src {
42                pos += 1;
43            }
44            if pos < rank {
45                perm.push(pos);
46                pos += 1;
47            }
48        }
49    }
50
51    arr.permute(&perm[..]).to_tensor()
52}
53
54/// Matsubara sampling for full frequency range (positive and negative)
55///
56/// General complex problem without symmetry → complex coefficients
57pub struct MatsubaraSampling<S: StatisticsType> {
58    sampling_points: Vec<MatsubaraFreq<S>>,
59    fitter: ComplexMatrixFitter,
60    _phantom: PhantomData<S>,
61}
62
63impl<S: StatisticsType> MatsubaraSampling<S> {
64    /// Create Matsubara sampling with default sampling points
65    ///
66    /// Uses extrema-based sampling point selection (symmetric: positive and negative frequencies).
67    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
68    where
69        S: 'static,
70    {
71        let sampling_points = basis.default_matsubara_sampling_points(false);
72        Self::with_sampling_points(basis, sampling_points)
73    }
74
75    /// Create Matsubara sampling with custom sampling points
76    pub fn with_sampling_points(
77        basis: &impl crate::basis_trait::Basis<S>,
78        mut sampling_points: Vec<MatsubaraFreq<S>>,
79    ) -> Self
80    where
81        S: 'static,
82    {
83        // Sort sampling points
84        sampling_points.sort();
85
86        // Evaluate matrix at sampling points
87        // Use Basis trait's evaluate_matsubara method
88        let matrix = basis.evaluate_matsubara(&sampling_points);
89
90        // Create fitter (complex → complex, no symmetry)
91        let fitter = ComplexMatrixFitter::new(matrix);
92
93        Self {
94            sampling_points,
95            fitter,
96            _phantom: PhantomData,
97        }
98    }
99
100    /// Create Matsubara sampling with custom sampling points and pre-computed matrix
101    ///
102    /// This constructor is useful when the sampling matrix is already computed
103    /// (e.g., from external sources or for testing).
104    ///
105    /// # Arguments
106    /// * `sampling_points` - Matsubara frequency sampling points
107    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
108    ///
109    /// # Returns
110    /// A new MatsubaraSampling object
111    ///
112    /// # Panics
113    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
114    pub fn from_matrix(
115        mut sampling_points: Vec<MatsubaraFreq<S>>,
116        matrix: DTensor<Complex<f64>, 2>,
117    ) -> Self {
118        assert!(!sampling_points.is_empty(), "No sampling points given");
119        assert_eq!(
120            matrix.shape().0,
121            sampling_points.len(),
122            "Matrix rows ({}) must match number of sampling points ({})",
123            matrix.shape().0,
124            sampling_points.len()
125        );
126
127        // Sort sampling points
128        sampling_points.sort();
129
130        let fitter = ComplexMatrixFitter::new(matrix);
131
132        Self {
133            sampling_points,
134            fitter,
135            _phantom: PhantomData,
136        }
137    }
138
139    /// Get sampling points
140    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
141        &self.sampling_points
142    }
143
144    /// Number of sampling points
145    pub fn n_sampling_points(&self) -> usize {
146        self.sampling_points.len()
147    }
148
149    /// Basis size
150    pub fn basis_size(&self) -> usize {
151        self.fitter.basis_size()
152    }
153
154    /// Get the sampling matrix
155    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
156        &self.fitter.matrix
157    }
158
159    /// Evaluate complex basis coefficients at sampling points
160    ///
161    /// # Arguments
162    /// * `coeffs` - Complex basis coefficients (length = basis_size)
163    ///
164    /// # Returns
165    /// Complex values at Matsubara frequencies (length = n_sampling_points)
166    pub fn evaluate(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
167        self.fitter.evaluate(None, coeffs)
168    }
169
170    /// Fit complex basis coefficients from values at sampling points
171    ///
172    /// # Arguments
173    /// * `values` - Complex values at Matsubara frequencies (length = n_sampling_points)
174    ///
175    /// # Returns
176    /// Fitted complex basis coefficients (length = basis_size)
177    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
178        self.fitter.fit(None, values)
179    }
180
181    /// Evaluate N-dimensional array of basis coefficients at sampling points
182    ///
183    /// Supports both real (`f64`) and complex (`Complex<f64>`) coefficients.
184    /// Always returns complex values at Matsubara frequencies.
185    ///
186    /// # Type Parameters
187    /// * `T` - Element type (f64 or Complex<f64>)
188    ///
189    /// # Arguments
190    /// * `backend` - Optional GEMM backend handle (None uses default)
191    /// * `coeffs` - N-dimensional tensor of basis coefficients
192    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
193    ///
194    /// # Returns
195    /// N-dimensional tensor of complex values at Matsubara frequencies
196    ///
197    /// # Example
198    /// ```ignore
199    /// use num_complex::Complex;
200    ///
201    /// // Real coefficients
202    /// let values = matsubara_sampling.evaluate_nd::<f64>(None, &coeffs_real, 0);
203    ///
204    /// // Complex coefficients
205    /// let values = matsubara_sampling.evaluate_nd::<Complex<f64>>(None, &coeffs_complex, 0);
206    /// ```
207    /// Evaluate N-D coefficients for the real case `T = f64`
208    fn evaluate_nd_impl_real(
209        &self,
210        backend: Option<&GemmBackendHandle>,
211        coeffs: &Slice<f64, DynRank>,
212        dim: usize,
213    ) -> Tensor<Complex<f64>, DynRank> {
214        let rank = coeffs.rank();
215        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
216
217        let basis_size = self.basis_size();
218        let target_dim_size = coeffs.shape().dim(dim);
219
220        assert_eq!(
221            target_dim_size, basis_size,
222            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
223            dim, target_dim_size, basis_size
224        );
225
226        // 1. Move target dimension to position 0
227        let coeffs_dim0 = movedim(coeffs, dim, 0);
228
229        // 2. Reshape to 2D: (basis_size, extra_size)
230        let extra_size: usize = coeffs_dim0.len() / basis_size;
231
232        let coeffs_2d_dyn = coeffs_dim0
233            .reshape(&[basis_size, extra_size][..])
234            .to_tensor();
235
236        // 3. Convert to DTensor and evaluate using evaluate_2d_real
237        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
238            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
239        });
240        let coeffs_2d_view = coeffs_2d.view(.., ..);
241        let result_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
242
243        // 4. Reshape back to N-D with n_points at position 0
244        let n_points = self.n_sampling_points();
245        let mut result_shape = vec![n_points];
246        coeffs_dim0.shape().with_dims(|dims| {
247            for i in 1..dims.len() {
248                result_shape.push(dims[i]);
249            }
250        });
251
252        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
253
254        // 5. Move dimension 0 back to original position dim
255        movedim(&result_dim0, 0, dim)
256    }
257
258    /// Evaluate N-D coefficients for the complex case `T = Complex<f64>`
259    fn evaluate_nd_impl_complex(
260        &self,
261        backend: Option<&GemmBackendHandle>,
262        coeffs: &Slice<Complex<f64>, DynRank>,
263        dim: usize,
264    ) -> Tensor<Complex<f64>, DynRank> {
265        let rank = coeffs.rank();
266        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
267
268        let basis_size = self.basis_size();
269        let target_dim_size = coeffs.shape().dim(dim);
270
271        assert_eq!(
272            target_dim_size, basis_size,
273            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
274            dim, target_dim_size, basis_size
275        );
276
277        // 1. Move target dimension to position 0
278        let coeffs_dim0 = movedim(coeffs, dim, 0);
279
280        // 2. Reshape to 2D: (basis_size, extra_size)
281        let extra_size: usize = coeffs_dim0.len() / basis_size;
282
283        let coeffs_2d_dyn = coeffs_dim0
284            .reshape(&[basis_size, extra_size][..])
285            .to_tensor();
286
287        // 3. Convert to DTensor and evaluate using evaluate_2d
288        let coeffs_2d = DTensor::<Complex<f64>, 2>::from_fn([basis_size, extra_size], |idx| {
289            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
290        });
291        let coeffs_2d_view = coeffs_2d.view(.., ..);
292        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
293
294        // 4. Reshape back to N-D with n_points at position 0
295        let n_points = self.n_sampling_points();
296        let mut result_shape = vec![n_points];
297        coeffs_dim0.shape().with_dims(|dims| {
298            for i in 1..dims.len() {
299                result_shape.push(dims[i]);
300            }
301        });
302
303        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
304
305        // 5. Move dimension 0 back to original position dim
306        movedim(&result_dim0, 0, dim)
307    }
308
309    pub fn evaluate_nd<T>(
310        &self,
311        backend: Option<&GemmBackendHandle>,
312        coeffs: &Slice<T, DynRank>,
313        dim: usize,
314    ) -> Tensor<Complex<f64>, DynRank>
315    where
316        T: Copy + 'static,
317    {
318        use std::any::TypeId;
319
320        if TypeId::of::<T>() == TypeId::of::<f64>() {
321            // Safe: TypeId check ensures T == f64 at runtime
322            // We need unsafe because Rust can't statically prove this
323            let coeffs_f64 =
324                unsafe { &*(coeffs as *const Slice<T, DynRank> as *const Slice<f64, DynRank>) };
325            self.evaluate_nd_impl_real(backend, coeffs_f64, dim)
326        } else if TypeId::of::<T>() == TypeId::of::<Complex<f64>>() {
327            // Safe: TypeId check ensures T == Complex<f64> at runtime
328            // We need unsafe because Rust can't statically prove this
329            let coeffs_complex = unsafe {
330                &*(coeffs as *const Slice<T, DynRank> as *const Slice<Complex<f64>, DynRank>)
331            };
332            self.evaluate_nd_impl_complex(backend, coeffs_complex, dim)
333        } else {
334            panic!("Unsupported type for evaluate_nd: must be f64 or Complex<f64>");
335        }
336    }
337
338    /// Evaluate real basis coefficients at Matsubara sampling points (N-dimensional)
339    ///
340    /// This method takes real coefficients and produces complex values, useful when
341    /// working with symmetry-exploiting representations or real-valued IR coefficients.
342    ///
343    /// # Arguments
344    /// * `backend` - Optional GEMM backend handle (None uses default)
345    /// * `coeffs` - N-dimensional tensor of real basis coefficients
346    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
347    ///
348    /// # Returns
349    /// N-dimensional tensor of complex values at Matsubara frequencies
350    pub fn evaluate_nd_real(
351        &self,
352        backend: Option<&GemmBackendHandle>,
353        coeffs: &Tensor<f64, DynRank>,
354        dim: usize,
355    ) -> Tensor<Complex<f64>, DynRank> {
356        let rank = coeffs.rank();
357        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
358
359        let basis_size = self.basis_size();
360        let target_dim_size = coeffs.shape().dim(dim);
361
362        assert_eq!(
363            target_dim_size, basis_size,
364            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
365            dim, target_dim_size, basis_size
366        );
367
368        // 1. Move target dimension to position 0
369        let coeffs_dim0 = movedim(coeffs, dim, 0);
370
371        // 2. Reshape to 2D: (basis_size, extra_size)
372        let extra_size: usize = coeffs_dim0.len() / basis_size;
373
374        let coeffs_2d_dyn = coeffs_dim0
375            .reshape(&[basis_size, extra_size][..])
376            .to_tensor();
377
378        // 3. Convert to DTensor and evaluate using ComplexMatrixFitter
379        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
380            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
381        });
382
383        // 4. Evaluate: values = A * coeffs (A is complex, coeffs is real)
384        let coeffs_2d_view = coeffs_2d.view(.., ..);
385        let values_2d = self.fitter.evaluate_2d_real(backend, &coeffs_2d_view);
386
387        // 5. Reshape result back to N-D with first dimension = n_sampling_points
388        let n_points = self.n_sampling_points();
389        let mut result_shape = Vec::with_capacity(rank);
390        result_shape.push(n_points);
391        coeffs_dim0.shape().with_dims(|dims| {
392            for i in 1..dims.len() {
393                result_shape.push(dims[i]);
394            }
395        });
396
397        let result_dim0 = values_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
398
399        // 6. Move dimension 0 back to original position dim
400        movedim(&result_dim0, 0, dim)
401    }
402
403    /// Fit N-dimensional array of complex values to complex basis coefficients
404    ///
405    /// # Arguments
406    /// * `backend` - Optional GEMM backend handle (None uses default)
407    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
408    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
409    ///
410    /// # Returns
411    /// N-dimensional tensor of complex basis coefficients
412    pub fn fit_nd(
413        &self,
414        backend: Option<&GemmBackendHandle>,
415        values: &Tensor<Complex<f64>, DynRank>,
416        dim: usize,
417    ) -> Tensor<Complex<f64>, DynRank> {
418        let rank = values.rank();
419        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
420
421        let n_points = self.n_sampling_points();
422        let target_dim_size = values.shape().dim(dim);
423
424        assert_eq!(
425            target_dim_size, n_points,
426            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
427            dim, target_dim_size, n_points
428        );
429
430        // 1. Move target dimension to position 0
431        let values_dim0 = movedim(values, dim, 0);
432
433        // 2. Reshape to 2D: (n_points, extra_size)
434        let extra_size: usize = values_dim0.len() / n_points;
435        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
436
437        // 3. Convert to DTensor and fit using GEMM
438        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
439            values_2d_dyn[&[idx[0], idx[1]][..]]
440        });
441
442        // Use fitter's efficient 2D fit (GEMM-based)
443        let values_2d_view = values_2d.view(.., ..);
444        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
445
446        // 4. Reshape back to N-D with basis_size at position 0
447        let basis_size = self.basis_size();
448        let mut coeffs_shape = vec![basis_size];
449        values_dim0.shape().with_dims(|dims| {
450            for i in 1..dims.len() {
451                coeffs_shape.push(dims[i]);
452            }
453        });
454
455        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
456
457        // 5. Move dimension 0 back to original position dim
458        movedim(&coeffs_dim0, 0, dim)
459    }
460
461    /// Fit N-dimensional array of complex values to real basis coefficients
462    ///
463    /// This method fits complex Matsubara values to real IR coefficients.
464    /// Takes the real part of the least-squares solution.
465    ///
466    /// # Arguments
467    /// * `backend` - Optional GEMM backend handle (None uses default)
468    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
469    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
470    ///
471    /// # Returns
472    /// N-dimensional tensor of real basis coefficients
473    pub fn fit_nd_real(
474        &self,
475        backend: Option<&GemmBackendHandle>,
476        values: &Tensor<Complex<f64>, DynRank>,
477        dim: usize,
478    ) -> Tensor<f64, DynRank> {
479        let rank = values.rank();
480        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
481
482        let n_points = self.n_sampling_points();
483        let target_dim_size = values.shape().dim(dim);
484
485        assert_eq!(
486            target_dim_size, n_points,
487            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
488            dim, target_dim_size, n_points
489        );
490
491        // 1. Move target dimension to position 0
492        let values_dim0 = movedim(values, dim, 0);
493
494        // 2. Reshape to 2D: (n_points, extra_size)
495        let extra_size: usize = values_dim0.len() / n_points;
496        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
497
498        // 3. Convert to DTensor and fit
499        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
500            values_2d_dyn[&[idx[0], idx[1]][..]]
501        });
502
503        // Use fitter's fit_2d_real method
504        let values_2d_view = values_2d.view(.., ..);
505        let coeffs_2d = self.fitter.fit_2d_real(backend, &values_2d_view);
506
507        // 4. Reshape back to N-D with basis_size at position 0
508        let basis_size = self.basis_size();
509        let mut coeffs_shape = vec![basis_size];
510        values_dim0.shape().with_dims(|dims| {
511            for i in 1..dims.len() {
512                coeffs_shape.push(dims[i]);
513            }
514        });
515
516        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
517
518        // 5. Move dimension 0 back to original position dim
519        movedim(&coeffs_dim0, 0, dim)
520    }
521
522    /// Evaluate basis coefficients at Matsubara sampling points (N-dimensional) with in-place output
523    ///
524    /// # Type Parameters
525    /// * `T` - Coefficient type (f64 or Complex<f64>)
526    ///
527    /// # Arguments
528    /// * `coeffs` - N-dimensional tensor with `coeffs.shape().dim(dim) == basis_size`
529    /// * `dim` - Dimension along which to evaluate (0-indexed)
530    /// * `out` - Output tensor with `out.shape().dim(dim) == n_sampling_points` (Complex<f64>)
531    pub fn evaluate_nd_to<T>(
532        &self,
533        backend: Option<&GemmBackendHandle>,
534        coeffs: &Slice<T, DynRank>,
535        dim: usize,
536        out: &mut Tensor<Complex<f64>, DynRank>,
537    ) where
538        T: Copy + 'static,
539    {
540        // Validate output shape
541        let rank = coeffs.rank();
542        assert_eq!(
543            out.rank(),
544            rank,
545            "out.rank()={} must equal coeffs.rank()={}",
546            out.rank(),
547            rank
548        );
549
550        let n_points = self.n_sampling_points();
551        let out_dim_size = out.shape().dim(dim);
552        assert_eq!(
553            out_dim_size, n_points,
554            "out.shape().dim({}) = {} must equal n_sampling_points = {}",
555            dim, out_dim_size, n_points
556        );
557
558        // Validate other dimensions match
559        for d in 0..rank {
560            if d != dim {
561                let coeffs_d = coeffs.shape().dim(d);
562                let out_d = out.shape().dim(d);
563                assert_eq!(
564                    coeffs_d, out_d,
565                    "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
566                    d, coeffs_d, d, out_d
567                );
568            }
569        }
570
571        // Compute result and copy to out
572        let result = self.evaluate_nd(backend, coeffs, dim);
573
574        // Copy result to out
575        let total = out.len();
576        for i in 0..total {
577            let mut idx = vec![0usize; rank];
578            let mut remaining = i;
579            for d in (0..rank).rev() {
580                let dim_size = out.shape().dim(d);
581                idx[d] = remaining % dim_size;
582                remaining /= dim_size;
583            }
584            out[&idx[..]] = result[&idx[..]];
585        }
586    }
587
588    /// Fit N-dimensional complex values to complex coefficients with in-place output
589    ///
590    /// # Arguments
591    /// * `values` - N-dimensional tensor with `values.shape().dim(dim) == n_sampling_points`
592    /// * `dim` - Dimension along which to fit (0-indexed)
593    /// * `out` - Output tensor with `out.shape().dim(dim) == basis_size` (Complex<f64>)
594    pub fn fit_nd_to(
595        &self,
596        backend: Option<&GemmBackendHandle>,
597        values: &Tensor<Complex<f64>, DynRank>,
598        dim: usize,
599        out: &mut Tensor<Complex<f64>, DynRank>,
600    ) {
601        // Validate output shape
602        let rank = values.rank();
603        assert_eq!(
604            out.rank(),
605            rank,
606            "out.rank()={} must equal values.rank()={}",
607            out.rank(),
608            rank
609        );
610
611        let basis_size = self.basis_size();
612        let out_dim_size = out.shape().dim(dim);
613        assert_eq!(
614            out_dim_size, basis_size,
615            "out.shape().dim({}) = {} must equal basis_size = {}",
616            dim, out_dim_size, basis_size
617        );
618
619        // Validate other dimensions match
620        for d in 0..rank {
621            if d != dim {
622                let values_d = values.shape().dim(d);
623                let out_d = out.shape().dim(d);
624                assert_eq!(
625                    values_d, out_d,
626                    "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
627                    d, values_d, d, out_d
628                );
629            }
630        }
631
632        // Compute result and copy to out
633        let result = self.fit_nd(backend, values, dim);
634
635        // Copy result to out
636        let total = out.len();
637        for i in 0..total {
638            let mut idx = vec![0usize; rank];
639            let mut remaining = i;
640            for d in (0..rank).rev() {
641                let dim_size = out.shape().dim(d);
642                idx[d] = remaining % dim_size;
643                remaining /= dim_size;
644            }
645            out[&idx[..]] = result[&idx[..]];
646        }
647    }
648}
649
650/// InplaceFitter implementation for MatsubaraSampling
651///
652/// Delegates to ComplexMatrixFitter which supports:
653/// - zz: Complex input → Complex output (full support)
654/// - dz: Real input → Complex output (evaluate only)
655/// - zd: Complex input → Real output (fit only, takes real part)
656impl<S: StatisticsType> InplaceFitter for MatsubaraSampling<S> {
657    fn n_points(&self) -> usize {
658        self.n_sampling_points()
659    }
660
661    fn basis_size(&self) -> usize {
662        self.basis_size()
663    }
664
665    fn evaluate_nd_dz_to(
666        &self,
667        backend: Option<&GemmBackendHandle>,
668        coeffs: &Slice<f64, DynRank>,
669        dim: usize,
670        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
671    ) -> bool {
672        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
673    }
674
675    fn evaluate_nd_zz_to(
676        &self,
677        backend: Option<&GemmBackendHandle>,
678        coeffs: &Slice<Complex<f64>, DynRank>,
679        dim: usize,
680        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
681    ) -> bool {
682        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
683    }
684
685    fn fit_nd_zd_to(
686        &self,
687        backend: Option<&GemmBackendHandle>,
688        values: &Slice<Complex<f64>, DynRank>,
689        dim: usize,
690        out: &mut ViewMut<'_, f64, DynRank>,
691    ) -> bool {
692        self.fitter.fit_nd_zd_to(backend, values, dim, out)
693    }
694
695    fn fit_nd_zz_to(
696        &self,
697        backend: Option<&GemmBackendHandle>,
698        values: &Slice<Complex<f64>, DynRank>,
699        dim: usize,
700        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
701    ) -> bool {
702        self.fitter.fit_nd_zz_to(backend, values, dim, out)
703    }
704}
705
706/// Matsubara sampling for positive frequencies only
707///
708/// Exploits symmetry to reconstruct real coefficients from positive frequencies only.
709/// Supports: {0, 1, 2, 3, ...} (no negative frequencies)
710pub struct MatsubaraSamplingPositiveOnly<S: StatisticsType> {
711    sampling_points: Vec<MatsubaraFreq<S>>,
712    fitter: ComplexToRealFitter,
713    _phantom: PhantomData<S>,
714}
715
716impl<S: StatisticsType> MatsubaraSamplingPositiveOnly<S> {
717    /// Create Matsubara sampling with default positive-only sampling points
718    ///
719    /// Uses extrema-based sampling point selection (positive frequencies only).
720    /// Exploits symmetry to reconstruct real coefficients.
721    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
722    where
723        S: 'static,
724    {
725        let sampling_points = basis.default_matsubara_sampling_points(true);
726        Self::with_sampling_points(basis, sampling_points)
727    }
728
729    /// Create Matsubara sampling with custom positive-only sampling points
730    pub fn with_sampling_points(
731        basis: &impl crate::basis_trait::Basis<S>,
732        mut sampling_points: Vec<MatsubaraFreq<S>>,
733    ) -> Self
734    where
735        S: 'static,
736    {
737        // Sort and validate (all n >= 0)
738        sampling_points.sort();
739
740        // TODO: Validate that all points are non-negative
741
742        // Evaluate matrix at sampling points
743        // Use Basis trait's evaluate_matsubara method
744        let matrix = basis.evaluate_matsubara(&sampling_points);
745
746        // Create fitter (complex → real, exploits symmetry)
747        let fitter = ComplexToRealFitter::new(&matrix);
748
749        Self {
750            sampling_points,
751            fitter,
752            _phantom: PhantomData,
753        }
754    }
755
756    /// Create Matsubara sampling (positive-only) with custom sampling points and pre-computed matrix
757    ///
758    /// This constructor is useful when the sampling matrix is already computed.
759    /// Uses symmetry to fit real coefficients from complex values at positive frequencies.
760    ///
761    /// # Arguments
762    /// * `sampling_points` - Matsubara frequency sampling points (should be positive)
763    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
764    ///
765    /// # Returns
766    /// A new MatsubaraSamplingPositiveOnly object
767    ///
768    /// # Panics
769    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
770    pub fn from_matrix(
771        mut sampling_points: Vec<MatsubaraFreq<S>>,
772        matrix: DTensor<Complex<f64>, 2>,
773    ) -> Self {
774        assert!(!sampling_points.is_empty(), "No sampling points given");
775        assert_eq!(
776            matrix.shape().0,
777            sampling_points.len(),
778            "Matrix rows ({}) must match number of sampling points ({})",
779            matrix.shape().0,
780            sampling_points.len()
781        );
782
783        // Sort sampling points
784        sampling_points.sort();
785
786        let fitter = ComplexToRealFitter::new(&matrix);
787
788        Self {
789            sampling_points,
790            fitter,
791            _phantom: PhantomData,
792        }
793    }
794
795    /// Get sampling points
796    pub fn sampling_points(&self) -> &[MatsubaraFreq<S>] {
797        &self.sampling_points
798    }
799
800    /// Number of sampling points
801    pub fn n_sampling_points(&self) -> usize {
802        self.sampling_points.len()
803    }
804
805    /// Basis size
806    pub fn basis_size(&self) -> usize {
807        self.fitter.basis_size()
808    }
809
810    /// Get the original complex sampling matrix
811    pub fn matrix(&self) -> &DTensor<Complex<f64>, 2> {
812        &self.fitter.matrix
813    }
814
815    /// Evaluate basis coefficients at sampling points
816    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<Complex<f64>> {
817        self.fitter.evaluate(None, coeffs)
818    }
819
820    /// Fit basis coefficients from values at sampling points
821    pub fn fit(&self, values: &[Complex<f64>]) -> Vec<f64> {
822        self.fitter.fit(None, values)
823    }
824
825    /// Evaluate N-dimensional array of real basis coefficients at sampling points
826    ///
827    /// # Arguments
828    /// * `coeffs` - N-dimensional tensor of real basis coefficients
829    /// * `dim` - Dimension along which to evaluate (must have size = basis_size)
830    ///
831    /// # Returns
832    /// N-dimensional tensor of complex values at Matsubara frequencies
833    pub fn evaluate_nd(
834        &self,
835        backend: Option<&GemmBackendHandle>,
836        coeffs: &Tensor<f64, DynRank>,
837        dim: usize,
838    ) -> Tensor<Complex<f64>, DynRank> {
839        let rank = coeffs.rank();
840        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
841
842        let basis_size = self.basis_size();
843        let target_dim_size = coeffs.shape().dim(dim);
844
845        assert_eq!(
846            target_dim_size, basis_size,
847            "coeffs.shape().dim({}) = {} must equal basis_size = {}",
848            dim, target_dim_size, basis_size
849        );
850
851        // 1. Move target dimension to position 0
852        let coeffs_dim0 = movedim(coeffs, dim, 0);
853
854        // 2. Reshape to 2D: (basis_size, extra_size)
855        let extra_size: usize = coeffs_dim0.len() / basis_size;
856
857        let coeffs_2d_dyn = coeffs_dim0
858            .reshape(&[basis_size, extra_size][..])
859            .to_tensor();
860
861        // 3. Convert to DTensor and evaluate using GEMM
862        let coeffs_2d = DTensor::<f64, 2>::from_fn([basis_size, extra_size], |idx| {
863            coeffs_2d_dyn[&[idx[0], idx[1]][..]]
864        });
865
866        // Use fitter's efficient 2D evaluate (GEMM-based)
867        let coeffs_2d_view = coeffs_2d.view(.., ..);
868        let result_2d = self.fitter.evaluate_2d(backend, &coeffs_2d_view);
869
870        // 4. Reshape back to N-D with n_points at position 0
871        let n_points = self.n_sampling_points();
872        let mut result_shape = vec![n_points];
873        coeffs_dim0.shape().with_dims(|dims| {
874            for i in 1..dims.len() {
875                result_shape.push(dims[i]);
876            }
877        });
878
879        let result_dim0 = result_2d.into_dyn().reshape(&result_shape[..]).to_tensor();
880
881        // 5. Move dimension 0 back to original position dim
882        movedim(&result_dim0, 0, dim)
883    }
884
885    /// Fit N-dimensional array of complex values to real basis coefficients
886    ///
887    /// # Arguments
888    /// * `backend` - Optional GEMM backend handle (None uses default)
889    /// * `values` - N-dimensional tensor of complex values at Matsubara frequencies
890    /// * `dim` - Dimension along which to fit (must have size = n_sampling_points)
891    ///
892    /// # Returns
893    /// N-dimensional tensor of real basis coefficients
894    pub fn fit_nd(
895        &self,
896        backend: Option<&GemmBackendHandle>,
897        values: &Tensor<Complex<f64>, DynRank>,
898        dim: usize,
899    ) -> Tensor<f64, DynRank> {
900        let rank = values.rank();
901        assert!(dim < rank, "dim={} must be < rank={}", dim, rank);
902
903        let n_points = self.n_sampling_points();
904        let target_dim_size = values.shape().dim(dim);
905
906        assert_eq!(
907            target_dim_size, n_points,
908            "values.shape().dim({}) = {} must equal n_sampling_points = {}",
909            dim, target_dim_size, n_points
910        );
911
912        // 1. Move target dimension to position 0
913        let values_dim0 = movedim(values, dim, 0);
914
915        // 2. Reshape to 2D: (n_points, extra_size)
916        let extra_size: usize = values_dim0.len() / n_points;
917        let values_2d_dyn = values_dim0.reshape(&[n_points, extra_size][..]).to_tensor();
918
919        // 3. Convert to DTensor and fit using GEMM
920        let values_2d = DTensor::<Complex<f64>, 2>::from_fn([n_points, extra_size], |idx| {
921            values_2d_dyn[&[idx[0], idx[1]][..]]
922        });
923
924        // Use fitter's efficient 2D fit (GEMM-based)
925        let values_2d_view = values_2d.view(.., ..);
926        let coeffs_2d = self.fitter.fit_2d(backend, &values_2d_view);
927
928        // 4. Reshape back to N-D with basis_size at position 0
929        let basis_size = self.basis_size();
930        let mut coeffs_shape = vec![basis_size];
931        values_dim0.shape().with_dims(|dims| {
932            for i in 1..dims.len() {
933                coeffs_shape.push(dims[i]);
934            }
935        });
936
937        let coeffs_dim0 = coeffs_2d.into_dyn().reshape(&coeffs_shape[..]).to_tensor();
938
939        // 5. Move dimension 0 back to original position dim
940        movedim(&coeffs_dim0, 0, dim)
941    }
942
943    /// Evaluate real basis coefficients at Matsubara sampling points (N-dimensional) with in-place output
944    ///
945    /// # Arguments
946    /// * `coeffs` - N-dimensional tensor of real coefficients with `coeffs.shape().dim(dim) == basis_size`
947    /// * `dim` - Dimension along which to evaluate (0-indexed)
948    /// * `out` - Output tensor with `out.shape().dim(dim) == n_sampling_points` (Complex<f64>)
949    pub fn evaluate_nd_to(
950        &self,
951        backend: Option<&GemmBackendHandle>,
952        coeffs: &Tensor<f64, DynRank>,
953        dim: usize,
954        out: &mut Tensor<Complex<f64>, DynRank>,
955    ) {
956        // Validate output shape
957        let rank = coeffs.rank();
958        assert_eq!(
959            out.rank(),
960            rank,
961            "out.rank()={} must equal coeffs.rank()={}",
962            out.rank(),
963            rank
964        );
965
966        let n_points = self.n_sampling_points();
967        let out_dim_size = out.shape().dim(dim);
968        assert_eq!(
969            out_dim_size, n_points,
970            "out.shape().dim({}) = {} must equal n_sampling_points = {}",
971            dim, out_dim_size, n_points
972        );
973
974        // Validate other dimensions match
975        for d in 0..rank {
976            if d != dim {
977                let coeffs_d = coeffs.shape().dim(d);
978                let out_d = out.shape().dim(d);
979                assert_eq!(
980                    coeffs_d, out_d,
981                    "coeffs.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
982                    d, coeffs_d, d, out_d
983                );
984            }
985        }
986
987        // Compute result and copy to out
988        let result = self.evaluate_nd(backend, coeffs, dim);
989
990        // Copy result to out
991        let total = out.len();
992        for i in 0..total {
993            let mut idx = vec![0usize; rank];
994            let mut remaining = i;
995            for d in (0..rank).rev() {
996                let dim_size = out.shape().dim(d);
997                idx[d] = remaining % dim_size;
998                remaining /= dim_size;
999            }
1000            out[&idx[..]] = result[&idx[..]];
1001        }
1002    }
1003
1004    /// Fit N-dimensional complex values to real coefficients with in-place output
1005    ///
1006    /// # Arguments
1007    /// * `values` - N-dimensional tensor with `values.shape().dim(dim) == n_sampling_points`
1008    /// * `dim` - Dimension along which to fit (0-indexed)
1009    /// * `out` - Output tensor with `out.shape().dim(dim) == basis_size` (f64)
1010    pub fn fit_nd_to(
1011        &self,
1012        backend: Option<&GemmBackendHandle>,
1013        values: &Tensor<Complex<f64>, DynRank>,
1014        dim: usize,
1015        out: &mut Tensor<f64, DynRank>,
1016    ) {
1017        // Validate output shape
1018        let rank = values.rank();
1019        assert_eq!(
1020            out.rank(),
1021            rank,
1022            "out.rank()={} must equal values.rank()={}",
1023            out.rank(),
1024            rank
1025        );
1026
1027        let basis_size = self.basis_size();
1028        let out_dim_size = out.shape().dim(dim);
1029        assert_eq!(
1030            out_dim_size, basis_size,
1031            "out.shape().dim({}) = {} must equal basis_size = {}",
1032            dim, out_dim_size, basis_size
1033        );
1034
1035        // Validate other dimensions match
1036        for d in 0..rank {
1037            if d != dim {
1038                let values_d = values.shape().dim(d);
1039                let out_d = out.shape().dim(d);
1040                assert_eq!(
1041                    values_d, out_d,
1042                    "values.shape().dim({}) = {} must equal out.shape().dim({}) = {}",
1043                    d, values_d, d, out_d
1044                );
1045            }
1046        }
1047
1048        // Compute result and copy to out
1049        let result = self.fit_nd(backend, values, dim);
1050
1051        // Copy result to out
1052        let total = out.len();
1053        for i in 0..total {
1054            let mut idx = vec![0usize; rank];
1055            let mut remaining = i;
1056            for d in (0..rank).rev() {
1057                let dim_size = out.shape().dim(d);
1058                idx[d] = remaining % dim_size;
1059                remaining /= dim_size;
1060            }
1061            out[&idx[..]] = result[&idx[..]];
1062        }
1063    }
1064}
1065
/// InplaceFitter implementation for MatsubaraSamplingPositiveOnly
///
/// Every transform method forwards its arguments unchanged to `self.fitter`
/// (documented as a `ComplexToRealFitter`). The four variants are:
/// - `evaluate_nd_dz_to`: real coefficients → complex values
/// - `evaluate_nd_zz_to`: complex coefficients → complex values (the fitter
///   extracts real parts — see `ComplexToRealFitter` for the exact rule)
/// - `fit_nd_zd_to`: complex values → real coefficients
/// - `fit_nd_zz_to`: complex values → complex coefficients (the fitter fills in
///   zero imaginary parts — see `ComplexToRealFitter`)
///
/// NOTE(review): each method returns the fitter's `bool` unchanged. This is presumably
/// a success flag; confirm against the `InplaceFitter` trait documentation.
impl<S: StatisticsType> InplaceFitter for MatsubaraSamplingPositiveOnly<S> {
    /// Number of sampling points; forwards to `n_sampling_points`.
    fn n_points(&self) -> usize {
        self.n_sampling_points()
    }

    /// Number of basis functions.
    fn basis_size(&self) -> usize {
        // Method-call syntax prefers inherent methods over trait methods, so this is
        // meant to reach the inherent `basis_size` on this type, not to recurse into
        // this trait method. Assumes that inherent method exists elsewhere in the
        // file — confirm if this impl is moved or the inherent method is renamed.
        self.basis_size()
    }

    /// Forwards to `self.fitter.evaluate_nd_dz_to` (real coeffs → complex values).
    fn evaluate_nd_dz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<f64, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.evaluate_nd_dz_to(backend, coeffs, dim, out)
    }

    /// Forwards to `self.fitter.evaluate_nd_zz_to` (complex coeffs → complex values).
    fn evaluate_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        coeffs: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
    }

    /// Forwards to `self.fitter.fit_nd_zd_to` (complex values → real coeffs).
    fn fit_nd_zd_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, f64, DynRank>,
    ) -> bool {
        self.fitter.fit_nd_zd_to(backend, values, dim, out)
    }

    /// Forwards to `self.fitter.fit_nd_zz_to` (complex values → complex coeffs).
    fn fit_nd_zz_to(
        &self,
        backend: Option<&GemmBackendHandle>,
        values: &Slice<Complex<f64>, DynRank>,
        dim: usize,
        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
    ) -> bool {
        self.fitter.fit_nd_zz_to(backend, values, dim, out)
    }
}
1122
// Unit tests for this module are compiled only in test builds and are loaded
// from the separate file `matsubara_sampling_tests.rs` via `#[path]`.
#[cfg(test)]
#[path = "matsubara_sampling_tests.rs"]
mod tests;