1use std::sync::Arc;
7
8use crate::kernel::{CentrosymmKernel, KernelProperties, LogisticKernel};
9use crate::poly::{PiecewiseLegendrePolyVector, default_sampling_points};
10use crate::polyfourier::PiecewiseLegendreFTVector;
11use crate::sve::{SVEResult, TworkType, compute_sve};
12use crate::traits::{Bosonic, Fermionic, StatisticsType};
13
/// Tag distinguishing the two kinds of particle statistics used in this module.
///
/// The type is a plain two-variant enum; it is `Copy` and comparable for equality.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Statistics {
    /// Fermionic statistics.
    Fermionic,
    /// Bosonic statistics.
    Bosonic,
}
20
21#[derive(Clone)]
39pub struct FiniteTempBasis<K, S>
40where
41 K: KernelProperties + CentrosymmKernel + Clone + 'static,
42 S: StatisticsType,
43{
44 pub kernel: K,
46
47 pub sve_result: Arc<SVEResult>,
49
50 pub accuracy: f64,
52
53 pub beta: f64,
55
56 pub u: Arc<PiecewiseLegendrePolyVector>,
59
60 pub v: Arc<PiecewiseLegendrePolyVector>,
63
64 pub s: Vec<f64>,
66
67 pub uhat: Arc<PiecewiseLegendreFTVector<S>>,
70
71 pub uhat_full: Arc<PiecewiseLegendreFTVector<S>>,
74
75 _phantom: std::marker::PhantomData<S>,
76}
77
78impl<K, S> FiniteTempBasis<K, S>
79where
80 K: KernelProperties + CentrosymmKernel + Clone + 'static,
81 S: StatisticsType,
82{
83 pub fn wmax(&self) -> f64 {
85 self.kernel.lambda() / self.beta
86 }
87
88 pub fn default_matsubara_sampling_points_i64(&self, positive_only: bool) -> Vec<i64>
90 where
91 S: 'static,
92 {
93 let freqs = self.default_matsubara_sampling_points(positive_only);
94 freqs.into_iter().map(|f| f.n()).collect()
95 }
96
97 pub fn default_matsubara_sampling_points_i64_with_mitigate(
99 &self,
100 positive_only: bool,
101 mitigate: bool,
102 n_points: usize,
103 ) -> Vec<i64>
104 where
105 S: 'static,
106 {
107 let fence = mitigate;
108 let freqs = Self::default_matsubara_sampling_points_impl(
109 &self.uhat_full,
110 n_points,
111 fence,
112 positive_only,
113 );
114 freqs.into_iter().map(|f| f.n()).collect()
115 }
116
117 pub fn new(kernel: K, beta: f64, epsilon: Option<f64>, max_size: Option<usize>) -> Self {
130 if beta <= 0.0 {
132 panic!("Inverse temperature beta must be positive, got {}", beta);
133 }
134
135 let epsilon_value = epsilon.unwrap_or(f64::NAN);
137 let sve_result = compute_sve(
138 kernel.clone(),
139 epsilon_value,
140 None, max_size,
142 TworkType::Auto,
143 );
144
145 Self::from_sve_result(kernel, beta, sve_result, epsilon, max_size)
146 }
147
148 pub fn from_sve_result(
153 kernel: K,
154 beta: f64,
155 sve_result: SVEResult,
156 epsilon: Option<f64>,
157 max_size: Option<usize>,
158 ) -> Self {
159 let (u_sve, s_sve, v_sve) = sve_result.part(epsilon, max_size);
161
162 let accuracy = if sve_result.s.len() > s_sve.len() {
164 sve_result.s[s_sve.len()] / sve_result.s[0]
165 } else {
166 sve_result.s[sve_result.s.len() - 1] / sve_result.s[0]
167 };
168
169 let lambda = kernel.lambda();
171 let omega_max = lambda / beta;
172
173 let u_knots: Vec<f64> = u_sve.get_polys()[0]
178 .knots
179 .iter()
180 .map(|&x| beta / 2.0 * (x + 1.0))
181 .collect();
182 let u_delta_x: Vec<f64> = u_sve.get_polys()[0]
183 .delta_x
184 .iter()
185 .map(|&dx| beta / 2.0 * dx)
186 .collect();
187 let u_symm: Vec<i32> = u_sve.get_polys().iter().map(|p| p.symm).collect();
188
189 let u = u_sve.rescale_domain(u_knots, Some(u_delta_x), Some(u_symm));
190
191 let v_knots: Vec<f64> = v_sve.get_polys()[0]
193 .knots
194 .iter()
195 .map(|&y| omega_max * y)
196 .collect();
197 let v_delta_x: Vec<f64> = v_sve.get_polys()[0]
198 .delta_x
199 .iter()
200 .map(|&dy| omega_max * dy)
201 .collect();
202 let v_symm: Vec<i32> = v_sve.get_polys().iter().map(|p| p.symm).collect();
203
204 let v = v_sve.rescale_domain(v_knots, Some(v_delta_x), Some(v_symm));
205
206 let ypower = kernel.ypower();
209 let scale_factor = (beta / 2.0 * omega_max).sqrt() * omega_max.powi(-ypower);
210 let s: Vec<f64> = s_sve.iter().map(|&x| scale_factor * x).collect();
211
212 let uhat_base_full = sve_result.u.scale_data(beta.sqrt());
215 let conv_rad = kernel.conv_radius();
216
217 let stat_marker = unsafe { std::mem::zeroed::<S>() };
220
221 let uhat_full = PiecewiseLegendreFTVector::<S>::from_poly_vector(
222 &uhat_base_full,
223 stat_marker,
224 Some(conv_rad),
225 );
226
227 let uhat_polyvec: Vec<_> = uhat_full.polyvec.iter().take(s.len()).cloned().collect();
229 let uhat = PiecewiseLegendreFTVector::from_vector(uhat_polyvec);
230
231 Self {
232 kernel,
233 sve_result: Arc::new(sve_result),
234 accuracy,
235 beta,
236 u: Arc::new(u),
237 v: Arc::new(v),
238 s,
239 uhat: Arc::new(uhat),
240 uhat_full: Arc::new(uhat_full),
241 _phantom: std::marker::PhantomData,
242 }
243 }
244
245 pub fn size(&self) -> usize {
247 self.s.len()
248 }
249
250 pub fn lambda(&self) -> f64 {
252 self.kernel.lambda()
253 }
254
255 pub fn omega_max(&self) -> f64 {
257 self.lambda() / self.beta
258 }
259
260 pub fn significance(&self) -> Vec<f64> {
262 let s0 = self.s[0];
263 self.s.iter().map(|&s| s / s0).collect()
264 }
265
266 pub fn default_tau_sampling_points(&self) -> Vec<f64> {
272 let sz = self.size();
273
274 let x = default_sampling_points(&self.sve_result.u, sz);
276
277 let mut unique_x = Vec::new();
279 if x.len() % 2 == 0 {
280 for i in 0..(x.len() / 2) {
282 unique_x.push(x[i]);
283 }
284 } else {
285 for i in 0..(x.len() / 2) {
287 unique_x.push(x[i]);
288 }
289 let x_new = 0.5 * (unique_x.last().unwrap() + 0.5);
291 unique_x.push(x_new);
292 }
293
294 let mut smpl_taus = Vec::with_capacity(2 * unique_x.len());
301 for &ux in &unique_x {
302 smpl_taus.push((self.beta / 2.0) * (ux + 1.0));
303 }
304 for i in 0..unique_x.len() {
305 smpl_taus.push(-smpl_taus[i]);
306 }
307
308 smpl_taus.sort_by(|a, b| a.partial_cmp(b).unwrap());
310
311 if smpl_taus.len() % 2 != 0 {
313 panic!("The number of tau sampling points is odd!");
314 }
315
316 for &tau in &smpl_taus {
318 if tau.abs() < 1e-10 {
319 eprintln!(
320 "Warning: tau = 0 is in the sampling points (absolute error: {})",
321 tau.abs()
322 );
323 }
324 }
325
326 smpl_taus
329 }
330
331 pub fn default_matsubara_sampling_points(
342 &self,
343 positive_only: bool,
344 ) -> Vec<crate::freq::MatsubaraFreq<S>>
345 where
346 S: 'static,
347 {
348 let fence = false;
349 Self::default_matsubara_sampling_points_impl(
350 &self.uhat_full,
351 self.size(),
352 fence,
353 positive_only,
354 )
355 }
356
357 fn fence_matsubara_sampling(
365 omega_n: &mut Vec<crate::freq::MatsubaraFreq<S>>,
366 positive_only: bool,
367 ) where
368 S: StatisticsType + 'static,
369 {
370 use crate::freq::{BosonicFreq, MatsubaraFreq};
371
372 if omega_n.is_empty() {
373 return;
374 }
375
376 let mut outer_frequencies = Vec::new();
378 if positive_only {
379 outer_frequencies.push(omega_n[omega_n.len() - 1]);
380 } else {
381 outer_frequencies.push(omega_n[0]);
382 outer_frequencies.push(omega_n[omega_n.len() - 1]);
383 }
384
385 for wn_outer in outer_frequencies {
386 let outer_val = wn_outer.n();
387 let mut diff_val = 2 * (0.025 * outer_val as f64).round() as i64;
390
391 if diff_val == 0 {
393 diff_val = 2;
394 }
395
396 let wn_diff = BosonicFreq::new(diff_val).unwrap().n();
398
399 let sign_val = if outer_val > 0 {
402 1
403 } else if outer_val < 0 {
404 -1
405 } else {
406 0
407 };
408
409 let original_size = omega_n.len();
411 if original_size >= 20 {
412 let new_n = outer_val - sign_val * wn_diff;
415 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
416 omega_n.push(new_freq);
417 }
418 }
419 if original_size >= 42 {
420 let new_n = outer_val + sign_val * wn_diff;
421 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
422 omega_n.push(new_freq);
423 }
424 }
425 }
426
427 let omega_n_set: std::collections::BTreeSet<MatsubaraFreq<S>> = omega_n.drain(..).collect();
429 *omega_n = omega_n_set.into_iter().collect();
430 }
431
432 pub fn default_matsubara_sampling_points_impl(
433 uhat_full: &PiecewiseLegendreFTVector<S>,
434 l: usize,
435 fence: bool,
436 positive_only: bool,
437 ) -> Vec<crate::freq::MatsubaraFreq<S>>
438 where
439 S: StatisticsType + 'static,
440 {
441 use crate::freq::MatsubaraFreq;
442 use crate::polyfourier::{find_extrema, sign_changes};
443 use std::collections::BTreeSet;
444
445 let mut l_requested = l;
446
447 if S::STATISTICS == crate::traits::Statistics::Fermionic && l_requested % 2 != 0 {
449 l_requested += 1;
450 } else if S::STATISTICS == crate::traits::Statistics::Bosonic && l_requested % 2 == 0 {
451 l_requested += 1;
452 }
453
454 let mut omega_n = if l_requested < uhat_full.len() {
456 sign_changes(&uhat_full[l_requested], positive_only)
457 } else {
458 find_extrema(&uhat_full[uhat_full.len() - 1], positive_only)
459 };
460
461 if S::STATISTICS == crate::traits::Statistics::Bosonic {
463 omega_n.push(MatsubaraFreq::<S>::new(0).unwrap());
464 }
465
466 let omega_n_set: BTreeSet<MatsubaraFreq<S>> = omega_n.into_iter().collect();
468 let mut omega_n: Vec<MatsubaraFreq<S>> = omega_n_set.into_iter().collect();
469
470 let expected_size = if positive_only {
472 l_requested.div_ceil(2)
473 } else {
474 l_requested
475 };
476
477 if omega_n.len() != expected_size {
478 eprintln!(
479 "Warning: Requested {} sampling frequencies for basis size L = {}, but got {}.",
480 expected_size,
481 l,
482 omega_n.len()
483 );
484 }
485
486 if fence {
488 Self::fence_matsubara_sampling(&mut omega_n, positive_only);
489 }
490
491 omega_n
492 }
493 pub fn default_omega_sampling_points(&self) -> Vec<f64> {
504 let sz = self.size();
505
506 let y = default_sampling_points(&self.sve_result.v, sz);
509
510 let wmax = self.kernel.lambda() / self.beta;
512 let omega_points: Vec<f64> = y.into_iter().map(|yi| wmax * yi).collect();
513
514 omega_points
515 }
516}
517
518impl<K, S> crate::basis_trait::Basis<S> for FiniteTempBasis<K, S>
523where
524 K: KernelProperties + CentrosymmKernel + Clone + 'static,
525 S: StatisticsType + 'static,
526{
527 type Kernel = K;
528
529 fn kernel(&self) -> &Self::Kernel {
530 &self.kernel
531 }
532
533 fn beta(&self) -> f64 {
534 self.beta
535 }
536
537 fn wmax(&self) -> f64 {
538 self.kernel.lambda() / self.beta
539 }
540
541 fn lambda(&self) -> f64 {
542 self.kernel.lambda()
543 }
544
545 fn size(&self) -> usize {
546 self.size()
547 }
548
549 fn accuracy(&self) -> f64 {
550 self.accuracy
551 }
552
553 fn significance(&self) -> Vec<f64> {
554 if let Some(&first_s) = self.s.first() {
555 self.s.iter().map(|&s| s / first_s).collect()
556 } else {
557 vec![]
558 }
559 }
560
561 fn svals(&self) -> Vec<f64> {
562 self.s.clone()
563 }
564
565 fn default_tau_sampling_points(&self) -> Vec<f64> {
566 self.default_tau_sampling_points()
567 }
568
569 fn default_matsubara_sampling_points(
570 &self,
571 positive_only: bool,
572 ) -> Vec<crate::freq::MatsubaraFreq<S>> {
573 self.default_matsubara_sampling_points(positive_only)
574 }
575
576 fn evaluate_tau(&self, tau: &[f64]) -> mdarray::DTensor<f64, 2> {
577 use crate::taufuncs::normalize_tau;
578 use mdarray::DTensor;
579
580 let n_points = tau.len();
581 let basis_size = self.size();
582
583 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
588 let i = idx[0]; let l = idx[1]; let (tau_norm, sign) = normalize_tau::<S>(tau[i], self.beta);
593
594 sign * self.u[l].evaluate(tau_norm)
596 })
597 }
598
599 fn evaluate_matsubara(
600 &self,
601 freqs: &[crate::freq::MatsubaraFreq<S>],
602 ) -> mdarray::DTensor<num_complex::Complex<f64>, 2> {
603 use mdarray::DTensor;
604 use num_complex::Complex;
605
606 let n_points = freqs.len();
607 let basis_size = self.size();
608
609 DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
612 let i = idx[0]; let l = idx[1]; self.uhat[l].evaluate(&freqs[i])
615 })
616 }
617
618 fn evaluate_omega(&self, omega: &[f64]) -> mdarray::DTensor<f64, 2> {
619 use mdarray::DTensor;
620
621 let n_points = omega.len();
622 let basis_size = self.size();
623
624 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
627 let i = idx[0]; let l = idx[1]; self.v[l].evaluate(omega[i])
630 })
631 }
632
633 fn default_omega_sampling_points(&self) -> Vec<f64> {
634 self.default_omega_sampling_points()
635 }
636}
637
638pub type FermionicBasis = FiniteTempBasis<LogisticKernel, Fermionic>;
644
645pub type BosonicBasis = FiniteTempBasis<LogisticKernel, Bosonic>;
647
648#[cfg(test)]
649#[path = "basis_tests.rs"]
650mod basis_tests;