sparse_ir/
sampling.rs

//! Sparse sampling in imaginary time.
//!
//! This module provides `TauSampling` for transforming between IR basis coefficients
//! and values at sparse sampling points in imaginary time τ, together with the
//! `movedim` helper for moving a tensor axis (the analogue of `numpy.moveaxis`).
5
6use crate::fitters::InplaceFitter;
7use crate::gemm::GemmBackendHandle;
8use crate::traits::StatisticsType;
9use mdarray::{DTensor, DynRank, Shape, Slice, Tensor, ViewMut};
10use num_complex::Complex;
11
12/// Build output shape by replacing dimension `dim` with `new_size`
13fn build_output_shape<S: Shape>(input_shape: &S, dim: usize, new_size: usize) -> Vec<usize> {
14    let mut out_shape: Vec<usize> = Vec::with_capacity(input_shape.rank());
15    input_shape.with_dims(|dims| {
16        for (i, d) in dims.iter().enumerate() {
17            if i == dim {
18                out_shape.push(new_size);
19            } else {
20                out_shape.push(*d);
21            }
22        }
23    });
24    out_shape
25}
26
27/// Move axis from position `src` to position `dst`
28///
29/// This is equivalent to numpy.moveaxis or libsparseir's movedim.
30/// It creates a permutation array that moves the specified axis.
31///
32/// # Arguments
33/// * `arr` - Input array slice (Tensor or View)
34/// * `src` - Source axis position
35/// * `dst` - Destination axis position
36///
37/// # Returns
38/// Tensor with axes permuted
39///
40/// # Example
41/// ```ignore
42/// // For a 4D tensor with shape (2, 3, 4, 5)
43/// // movedim(arr, 0, 2) moves axis 0 to position 2
44/// // Result shape: (3, 4, 2, 5) with axes permuted as [1, 2, 0, 3]
45/// ```
46pub fn movedim<T: Clone>(arr: &Slice<T, DynRank>, src: usize, dst: usize) -> Tensor<T, DynRank> {
47    if src == dst {
48        return arr.to_tensor();
49    }
50
51    let rank = arr.rank();
52    assert!(
53        src < rank,
54        "src axis {} out of bounds for rank {}",
55        src,
56        rank
57    );
58    assert!(
59        dst < rank,
60        "dst axis {} out of bounds for rank {}",
61        dst,
62        rank
63    );
64
65    // Generate permutation: move src to dst position
66    let mut perm = Vec::with_capacity(rank);
67    let mut pos = 0;
68    for i in 0..rank {
69        if i == dst {
70            perm.push(src);
71        } else {
72            // Skip src position
73            if pos == src {
74                pos += 1;
75            }
76            perm.push(pos);
77            pos += 1;
78        }
79    }
80
81    arr.permute(&perm[..]).to_tensor()
82}
83
84/// Sparse sampling in imaginary time
85///
86/// Allows transformation between the IR basis and a set of sampling points
87/// in imaginary time (τ).
88pub struct TauSampling<S>
89where
90    S: StatisticsType,
91{
92    /// Sampling points in imaginary time τ ∈ [-β/2, β/2]
93    sampling_points: Vec<f64>,
94
95    /// Real matrix fitter for least-squares fitting
96    fitter: crate::fitters::RealMatrixFitter,
97
98    /// Marker for statistics type
99    _phantom: std::marker::PhantomData<S>,
100}
101
102impl<S> TauSampling<S>
103where
104    S: StatisticsType,
105{
106    /// Create a new TauSampling with default sampling points
107    ///
108    /// The default sampling points are chosen as the extrema of the highest-order
109    /// basis function, which gives near-optimal conditioning.
110    /// SVD is computed lazily on first call to `fit` or `fit_nd`.
111    ///
112    /// # Arguments
113    /// * `basis` - Any basis implementing the `Basis` trait
114    ///
115    /// # Returns
116    /// A new TauSampling object
117    pub fn new(basis: &impl crate::basis_trait::Basis<S>) -> Self
118    where
119        S: 'static,
120    {
121        let sampling_points = basis.default_tau_sampling_points();
122        Self::with_sampling_points(basis, sampling_points)
123    }
124
125    /// Create a new TauSampling with custom sampling points
126    ///
127    /// SVD is computed lazily on first call to `fit` or `fit_nd`.
128    ///
129    /// # Arguments
130    /// * `basis` - Any basis implementing the `Basis` trait
131    /// * `sampling_points` - Custom sampling points in τ ∈ [-β, β]
132    ///
133    /// # Returns
134    /// A new TauSampling object
135    ///
136    /// # Panics
137    /// Panics if `sampling_points` is empty or if any point is outside [-β, β]
138    pub fn with_sampling_points(
139        basis: &impl crate::basis_trait::Basis<S>,
140        sampling_points: Vec<f64>,
141    ) -> Self
142    where
143        S: 'static,
144    {
145        assert!(!sampling_points.is_empty(), "No sampling points given");
146
147        let beta = basis.beta();
148        for &tau in &sampling_points {
149            assert!(
150                tau >= -beta && tau <= beta,
151                "Sampling point τ={} is outside [-β, β]",
152                tau
153            );
154        }
155
156        // Compute sampling matrix: A[i, l] = u_l(τ_i)
157        // Use Basis trait's evaluate_tau method
158        let matrix = basis.evaluate_tau(&sampling_points);
159
160        // Create fitter
161        let fitter = crate::fitters::RealMatrixFitter::new(matrix);
162
163        Self {
164            sampling_points,
165            fitter,
166            _phantom: std::marker::PhantomData,
167        }
168    }
169
170    /// Create a new TauSampling with custom sampling points and pre-computed matrix
171    ///
172    /// This constructor is useful when the sampling matrix is already computed
173    /// (e.g., from external sources or for testing).
174    ///
175    /// # Arguments
176    /// * `sampling_points` - Sampling points in τ ∈ [-β, β]
177    /// * `matrix` - Pre-computed sampling matrix (n_points × basis_size)
178    ///
179    /// # Returns
180    /// A new TauSampling object
181    ///
182    /// # Panics
183    /// Panics if `sampling_points` is empty or if matrix dimensions don't match
184    pub fn from_matrix(sampling_points: Vec<f64>, matrix: DTensor<f64, 2>) -> Self {
185        assert!(!sampling_points.is_empty(), "No sampling points given");
186        assert_eq!(
187            matrix.shape().0,
188            sampling_points.len(),
189            "Matrix rows ({}) must match number of sampling points ({})",
190            matrix.shape().0,
191            sampling_points.len()
192        );
193
194        let fitter = crate::fitters::RealMatrixFitter::new(matrix);
195
196        Self {
197            sampling_points,
198            fitter,
199            _phantom: std::marker::PhantomData,
200        }
201    }
202
203    /// Get the sampling points
204    pub fn sampling_points(&self) -> &[f64] {
205        &self.sampling_points
206    }
207
208    /// Get the number of sampling points
209    pub fn n_sampling_points(&self) -> usize {
210        self.fitter.n_points()
211    }
212
213    /// Get the basis size
214    pub fn basis_size(&self) -> usize {
215        self.fitter.basis_size()
216    }
217
218    /// Get the sampling matrix
219    pub fn matrix(&self) -> &DTensor<f64, 2> {
220        &self.fitter.matrix
221    }
222
223    // ========================================================================
224    // 1D functions (real and complex)
225    // ========================================================================
226
227    /// Evaluate basis coefficients at sampling points
228    ///
229    /// Computes g(τ_i) = Σ_l a_l * u_l(τ_i) for all sampling points
230    ///
231    /// # Arguments
232    /// * `coeffs` - Basis coefficients (length = basis_size)
233    ///
234    /// # Returns
235    /// Values at sampling points (length = n_sampling_points)
236    pub fn evaluate(&self, coeffs: &[f64]) -> Vec<f64> {
237        self.fitter.evaluate(None, coeffs)
238    }
239
240    /// Evaluate basis coefficients at sampling points, writing to output slice
241    pub fn evaluate_to(&self, coeffs: &[f64], out: &mut [f64]) {
242        self.fitter.evaluate_to(None, coeffs, out)
243    }
244
245    /// Fit values at sampling points to basis coefficients
246    pub fn fit(&self, values: &[f64]) -> Vec<f64> {
247        self.fitter.fit(None, values)
248    }
249
250    /// Fit values at sampling points to basis coefficients, writing to output slice
251    pub fn fit_to(&self, values: &[f64], out: &mut [f64]) {
252        self.fitter.fit_to(None, values, out)
253    }
254
255    /// Evaluate complex basis coefficients at sampling points
256    pub fn evaluate_zz(&self, coeffs: &[Complex<f64>]) -> Vec<Complex<f64>> {
257        self.fitter.evaluate_zz(None, coeffs)
258    }
259
260    /// Evaluate complex basis coefficients, writing to output slice
261    pub fn evaluate_zz_to(&self, coeffs: &[Complex<f64>], out: &mut [Complex<f64>]) {
262        self.fitter.evaluate_zz_to(None, coeffs, out)
263    }
264
265    /// Fit complex values at sampling points to basis coefficients
266    pub fn fit_zz(&self, values: &[Complex<f64>]) -> Vec<Complex<f64>> {
267        self.fitter.fit_zz(None, values)
268    }
269
270    /// Fit complex values, writing to output slice
271    pub fn fit_zz_to(&self, values: &[Complex<f64>], out: &mut [Complex<f64>]) {
272        self.fitter.fit_zz_to(None, values, out)
273    }
274
275    // ========================================================================
276    // N-D functions (real)
277    // ========================================================================
278
279    /// Evaluate N-D real coefficients at sampling points
280    ///
281    /// # Arguments
282    /// * `coeffs` - N-dimensional array with `coeffs.shape().dim(dim) == basis_size`
283    /// * `dim` - Dimension along which to evaluate (0-indexed)
284    ///
285    /// # Returns
286    /// N-dimensional array with `result.shape().dim(dim) == n_sampling_points`
287    pub fn evaluate_nd(
288        &self,
289        backend: Option<&GemmBackendHandle>,
290        coeffs: &Slice<f64, DynRank>,
291        dim: usize,
292    ) -> Tensor<f64, DynRank> {
293        let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
294        let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
295        self.evaluate_nd_to(backend, coeffs, dim, &mut out.expr_mut());
296        out
297    }
298
299    /// Evaluate N-D real coefficients, writing to a mutable view
300    pub fn evaluate_nd_to(
301        &self,
302        backend: Option<&GemmBackendHandle>,
303        coeffs: &Slice<f64, DynRank>,
304        dim: usize,
305        out: &mut ViewMut<'_, f64, DynRank>,
306    ) {
307        InplaceFitter::evaluate_nd_dd_to(self, backend, coeffs, dim, out);
308    }
309
310    /// Fit N-D real values at sampling points to basis coefficients
311    ///
312    /// # Arguments
313    /// * `values` - N-dimensional array with `values.shape().dim(dim) == n_sampling_points`
314    /// * `dim` - Dimension along which to fit (0-indexed)
315    ///
316    /// # Returns
317    /// N-dimensional array with `result.shape().dim(dim) == basis_size`
318    pub fn fit_nd(
319        &self,
320        backend: Option<&GemmBackendHandle>,
321        values: &Slice<f64, DynRank>,
322        dim: usize,
323    ) -> Tensor<f64, DynRank> {
324        let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
325        let mut out = Tensor::<f64, DynRank>::zeros(&out_shape[..]);
326        self.fit_nd_to(backend, values, dim, &mut out.expr_mut());
327        out
328    }
329
330    /// Fit N-D real values, writing to a mutable view
331    pub fn fit_nd_to(
332        &self,
333        backend: Option<&GemmBackendHandle>,
334        values: &Slice<f64, DynRank>,
335        dim: usize,
336        out: &mut ViewMut<'_, f64, DynRank>,
337    ) {
338        InplaceFitter::fit_nd_dd_to(self, backend, values, dim, out);
339    }
340
341    // ========================================================================
342    // N-D functions (complex)
343    // ========================================================================
344
345    /// Evaluate N-D complex coefficients at sampling points
346    ///
347    /// # Arguments
348    /// * `coeffs` - N-dimensional complex array with `coeffs.shape().dim(dim) == basis_size`
349    /// * `dim` - Dimension along which to evaluate (0-indexed)
350    ///
351    /// # Returns
352    /// N-dimensional complex array with `result.shape().dim(dim) == n_sampling_points`
353    pub fn evaluate_nd_zz(
354        &self,
355        backend: Option<&GemmBackendHandle>,
356        coeffs: &Slice<Complex<f64>, DynRank>,
357        dim: usize,
358    ) -> Tensor<Complex<f64>, DynRank> {
359        let out_shape = build_output_shape(coeffs.shape(), dim, self.n_sampling_points());
360        let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
361        self.evaluate_nd_zz_to(backend, coeffs, dim, &mut out.expr_mut());
362        out
363    }
364
365    /// Evaluate N-D complex coefficients, writing to a mutable view
366    pub fn evaluate_nd_zz_to(
367        &self,
368        backend: Option<&GemmBackendHandle>,
369        coeffs: &Slice<Complex<f64>, DynRank>,
370        dim: usize,
371        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
372    ) {
373        InplaceFitter::evaluate_nd_zz_to(self, backend, coeffs, dim, out);
374    }
375
376    /// Fit N-D complex values at sampling points to basis coefficients
377    ///
378    /// # Arguments
379    /// * `values` - N-dimensional complex array with `values.shape().dim(dim) == n_sampling_points`
380    /// * `dim` - Dimension along which to fit (0-indexed)
381    ///
382    /// # Returns
383    /// N-dimensional complex array with `result.shape().dim(dim) == basis_size`
384    pub fn fit_nd_zz(
385        &self,
386        backend: Option<&GemmBackendHandle>,
387        values: &Slice<Complex<f64>, DynRank>,
388        dim: usize,
389    ) -> Tensor<Complex<f64>, DynRank> {
390        let out_shape = build_output_shape(values.shape(), dim, self.basis_size());
391        let mut out = Tensor::<Complex<f64>, DynRank>::zeros(&out_shape[..]);
392        self.fit_nd_zz_to(backend, values, dim, &mut out.expr_mut());
393        out
394    }
395
396    /// Fit N-D complex values, writing to a mutable view
397    pub fn fit_nd_zz_to(
398        &self,
399        backend: Option<&GemmBackendHandle>,
400        values: &Slice<Complex<f64>, DynRank>,
401        dim: usize,
402        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
403    ) {
404        InplaceFitter::fit_nd_zz_to(self, backend, values, dim, out);
405    }
406}
407
408/// InplaceFitter implementation for TauSampling
409///
410/// Delegates to RealMatrixFitter which supports dd and zz operations.
411impl<S: StatisticsType> InplaceFitter for TauSampling<S> {
412    fn n_points(&self) -> usize {
413        self.n_sampling_points()
414    }
415
416    fn basis_size(&self) -> usize {
417        self.basis_size()
418    }
419
420    fn evaluate_nd_dd_to(
421        &self,
422        backend: Option<&GemmBackendHandle>,
423        coeffs: &Slice<f64, DynRank>,
424        dim: usize,
425        out: &mut ViewMut<'_, f64, DynRank>,
426    ) -> bool {
427        self.fitter.evaluate_nd_dd_to(backend, coeffs, dim, out)
428    }
429
430    fn evaluate_nd_zz_to(
431        &self,
432        backend: Option<&GemmBackendHandle>,
433        coeffs: &Slice<Complex<f64>, DynRank>,
434        dim: usize,
435        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
436    ) -> bool {
437        self.fitter.evaluate_nd_zz_to(backend, coeffs, dim, out)
438    }
439
440    fn fit_nd_dd_to(
441        &self,
442        backend: Option<&GemmBackendHandle>,
443        values: &Slice<f64, DynRank>,
444        dim: usize,
445        out: &mut ViewMut<'_, f64, DynRank>,
446    ) -> bool {
447        self.fitter.fit_nd_dd_to(backend, values, dim, out)
448    }
449
450    fn fit_nd_zz_to(
451        &self,
452        backend: Option<&GemmBackendHandle>,
453        values: &Slice<Complex<f64>, DynRank>,
454        dim: usize,
455        out: &mut ViewMut<'_, Complex<f64>, DynRank>,
456    ) -> bool {
457        self.fitter.fit_nd_zz_to(backend, values, dim, out)
458    }
459}
460
// Unit tests live in a separate file instead of an inline module: `#[path]`
// points the `tests` module at `tau_sampling_tests.rs` (resolved relative to
// this source file's directory), and `#[cfg(test)]` compiles it only for
// test builds.
#[cfg(test)]
#[path = "tau_sampling_tests.rs"]
mod tests;