1use std::sync::Arc;
7
8use crate::kernel::{CentrosymmKernel, KernelProperties, LogisticKernel};
9use crate::poly::{PiecewiseLegendrePolyVector, default_sampling_points};
10use crate::polyfourier::PiecewiseLegendreFTVector;
11use crate::sve::{SVEResult, TworkType, compute_sve};
12use crate::traits::{Bosonic, Fermionic, StatisticsType};
13
/// Particle statistics a basis can be built for.
///
/// The discriminants are spelled out explicitly; they are the same values the
/// compiler would assign implicitly (`Fermionic = 0`, `Bosonic = 1`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Statistics {
    /// Fermionic statistics.
    Fermionic = 0,
    /// Bosonic statistics.
    Bosonic = 1,
}
20
21#[derive(Clone)]
39pub struct FiniteTempBasis<K, S>
40where
41 K: KernelProperties + CentrosymmKernel + Clone + 'static,
42 S: StatisticsType,
43{
44 pub kernel: K,
46
47 pub sve_result: Arc<SVEResult>,
49
50 pub accuracy: f64,
52
53 pub beta: f64,
55
56 pub u: Arc<PiecewiseLegendrePolyVector>,
59
60 pub v: Arc<PiecewiseLegendrePolyVector>,
63
64 pub s: Vec<f64>,
66
67 pub uhat: Arc<PiecewiseLegendreFTVector<S>>,
70
71 pub uhat_full: Arc<PiecewiseLegendreFTVector<S>>,
74
75 _phantom: std::marker::PhantomData<S>,
76}
77
78impl<K, S> FiniteTempBasis<K, S>
79where
80 K: KernelProperties + CentrosymmKernel + Clone + 'static,
81 S: StatisticsType,
82{
83 pub fn wmax(&self) -> f64 {
85 self.kernel.lambda() / self.beta
86 }
87
88 pub fn default_matsubara_sampling_points_i64(&self, positive_only: bool) -> Vec<i64>
90 where
91 S: 'static,
92 {
93 let freqs = self.default_matsubara_sampling_points(positive_only);
94 freqs.into_iter().map(|f| f.n()).collect()
95 }
96
97 pub fn default_matsubara_sampling_points_i64_with_mitigate(
99 &self,
100 positive_only: bool,
101 mitigate: bool,
102 n_points: usize,
103 ) -> Vec<i64>
104 where
105 S: 'static,
106 {
107 let fence = mitigate;
108 let freqs = Self::default_matsubara_sampling_points_impl(
109 &self.uhat_full,
110 n_points,
111 fence,
112 positive_only,
113 );
114 freqs.into_iter().map(|f| f.n()).collect()
115 }
116
117 pub fn new(kernel: K, beta: f64, epsilon: Option<f64>, max_size: Option<usize>) -> Self {
130 if beta <= 0.0 {
132 panic!("Inverse temperature beta must be positive, got {}", beta);
133 }
134
135 let epsilon_value = epsilon.unwrap_or(f64::NAN);
137 let sve_result = compute_sve(
138 kernel.clone(),
139 epsilon_value,
140 None, max_size,
142 TworkType::Auto,
143 );
144
145 Self::from_sve_result(kernel, beta, sve_result, epsilon, max_size)
146 }
147
148 pub fn from_sve_result(
153 kernel: K,
154 beta: f64,
155 sve_result: SVEResult,
156 epsilon: Option<f64>,
157 max_size: Option<usize>,
158 ) -> Self {
159 let (u_sve, s_sve, v_sve) = sve_result.part(epsilon, max_size);
161
162 let accuracy = if sve_result.s.len() > s_sve.len() {
164 sve_result.s[s_sve.len()] / sve_result.s[0]
165 } else {
166 sve_result.s[sve_result.s.len() - 1] / sve_result.s[0]
167 };
168
169 let lambda = kernel.lambda();
171 let omega_max = lambda / beta;
172
173 let u_knots: Vec<f64> = u_sve.get_polys()[0]
178 .knots
179 .iter()
180 .map(|&x| beta / 2.0 * (x + 1.0))
181 .collect();
182 let u_delta_x: Vec<f64> = u_sve.get_polys()[0]
183 .delta_x
184 .iter()
185 .map(|&dx| beta / 2.0 * dx)
186 .collect();
187 let u_symm: Vec<i32> = u_sve.get_polys().iter().map(|p| p.symm).collect();
188
189 let u = u_sve.rescale_domain(u_knots, Some(u_delta_x), Some(u_symm));
190
191 let v_knots: Vec<f64> = v_sve.get_polys()[0]
193 .knots
194 .iter()
195 .map(|&y| omega_max * y)
196 .collect();
197 let v_delta_x: Vec<f64> = v_sve.get_polys()[0]
198 .delta_x
199 .iter()
200 .map(|&dy| omega_max * dy)
201 .collect();
202 let v_symm: Vec<i32> = v_sve.get_polys().iter().map(|p| p.symm).collect();
203
204 let v = v_sve.rescale_domain(v_knots, Some(v_delta_x), Some(v_symm));
205
206 let ypower = kernel.ypower();
209 let scale_factor = (beta / 2.0 * omega_max).sqrt() * omega_max.powi(-ypower);
210 let s: Vec<f64> = s_sve.iter().map(|&x| scale_factor * x).collect();
211
212 let uhat_base_full = sve_result.u.scale_data(beta.sqrt());
215 let conv_rad = kernel.conv_radius();
216
217 let stat_marker = unsafe { std::mem::zeroed::<S>() };
220
221 let uhat_full = PiecewiseLegendreFTVector::<S>::from_poly_vector(
222 &uhat_base_full,
223 stat_marker,
224 Some(conv_rad),
225 );
226
227 let uhat_polyvec: Vec<_> = uhat_full.polyvec.iter().take(s.len()).cloned().collect();
229 let uhat = PiecewiseLegendreFTVector::from_vector(uhat_polyvec);
230
231 Self {
232 kernel,
233 sve_result: Arc::new(sve_result),
234 accuracy,
235 beta,
236 u: Arc::new(u),
237 v: Arc::new(v),
238 s,
239 uhat: Arc::new(uhat),
240 uhat_full: Arc::new(uhat_full),
241 _phantom: std::marker::PhantomData,
242 }
243 }
244
245 pub fn size(&self) -> usize {
247 self.s.len()
248 }
249
250 pub fn lambda(&self) -> f64 {
252 self.kernel.lambda()
253 }
254
255 pub fn omega_max(&self) -> f64 {
257 self.lambda() / self.beta
258 }
259
260 pub fn significance(&self) -> Vec<f64> {
262 let s0 = self.s[0];
263 self.s.iter().map(|&s| s / s0).collect()
264 }
265
266 pub fn default_tau_sampling_points(&self) -> Vec<f64> {
272 self.default_tau_sampling_points_size_requested(self.size())
273 }
274
275 pub fn default_tau_sampling_points_size_requested(&self, size_requested: usize) -> Vec<f64> {
276 let x = default_sampling_points(&self.sve_result.u, size_requested);
278 let mut unique_x = Vec::new();
280 if x.len() % 2 == 0 {
281 for i in 0..(x.len() / 2) {
283 unique_x.push(x[i]);
284 }
285 } else {
286 for i in 0..(x.len() / 2) {
288 unique_x.push(x[i]);
289 }
290 let x_new = 0.5 * (unique_x.last().unwrap() + 0.5);
292 unique_x.push(x_new);
293 }
294
295 let mut smpl_taus = Vec::with_capacity(2 * unique_x.len());
302 for &ux in &unique_x {
303 smpl_taus.push((self.beta / 2.0) * (ux + 1.0));
304 }
305 for i in 0..unique_x.len() {
306 smpl_taus.push(-smpl_taus[i]);
307 }
308
309 smpl_taus.sort_by(|a, b| a.partial_cmp(b).unwrap());
311
312 if smpl_taus.len() % 2 != 0 {
314 panic!("The number of tau sampling points is odd!");
315 }
316
317 for &tau in &smpl_taus {
319 if tau.abs() < 1e-10 {
320 eprintln!(
321 "Warning: tau = 0 is in the sampling points (absolute error: {})",
322 tau.abs()
323 );
324 }
325 }
326
327 smpl_taus
330 }
331
332 pub fn default_matsubara_sampling_points(
343 &self,
344 positive_only: bool,
345 ) -> Vec<crate::freq::MatsubaraFreq<S>>
346 where
347 S: 'static,
348 {
349 let fence = false;
350 Self::default_matsubara_sampling_points_impl(
351 &self.uhat_full,
352 self.size(),
353 fence,
354 positive_only,
355 )
356 }
357
358 fn fence_matsubara_sampling(
366 omega_n: &mut Vec<crate::freq::MatsubaraFreq<S>>,
367 positive_only: bool,
368 ) where
369 S: StatisticsType + 'static,
370 {
371 use crate::freq::{BosonicFreq, MatsubaraFreq};
372
373 if omega_n.is_empty() {
374 return;
375 }
376
377 let mut outer_frequencies = Vec::new();
379 if positive_only {
380 outer_frequencies.push(omega_n[omega_n.len() - 1]);
381 } else {
382 outer_frequencies.push(omega_n[0]);
383 outer_frequencies.push(omega_n[omega_n.len() - 1]);
384 }
385
386 for wn_outer in outer_frequencies {
387 let outer_val = wn_outer.n();
388 let mut diff_val = 2 * (0.025 * outer_val as f64).round() as i64;
391
392 if diff_val == 0 {
394 diff_val = 2;
395 }
396
397 let wn_diff = BosonicFreq::new(diff_val).unwrap().n();
399
400 let sign_val = if outer_val > 0 {
403 1
404 } else if outer_val < 0 {
405 -1
406 } else {
407 0
408 };
409
410 let original_size = omega_n.len();
412 if original_size >= 20 {
413 let new_n = outer_val - sign_val * wn_diff;
416 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
417 omega_n.push(new_freq);
418 }
419 }
420 if original_size >= 42 {
421 let new_n = outer_val + sign_val * wn_diff;
422 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
423 omega_n.push(new_freq);
424 }
425 }
426 }
427
428 let omega_n_set: std::collections::BTreeSet<MatsubaraFreq<S>> = omega_n.drain(..).collect();
430 *omega_n = omega_n_set.into_iter().collect();
431 }
432
433 pub fn default_matsubara_sampling_points_impl(
434 uhat_full: &PiecewiseLegendreFTVector<S>,
435 l: usize,
436 fence: bool,
437 positive_only: bool,
438 ) -> Vec<crate::freq::MatsubaraFreq<S>>
439 where
440 S: StatisticsType + 'static,
441 {
442 use crate::freq::MatsubaraFreq;
443 use crate::polyfourier::{find_extrema, sign_changes};
444 use std::collections::BTreeSet;
445
446 let mut l_requested = l;
447
448 if S::STATISTICS == crate::traits::Statistics::Fermionic && l_requested % 2 != 0 {
450 l_requested += 1;
451 } else if S::STATISTICS == crate::traits::Statistics::Bosonic && l_requested % 2 == 0 {
452 l_requested += 1;
453 }
454
455 let mut omega_n = if l_requested < uhat_full.len() {
457 sign_changes(&uhat_full[l_requested], positive_only)
458 } else {
459 find_extrema(&uhat_full[uhat_full.len() - 1], positive_only)
460 };
461
462 if S::STATISTICS == crate::traits::Statistics::Bosonic {
464 omega_n.push(MatsubaraFreq::<S>::new(0).unwrap());
465 }
466
467 let omega_n_set: BTreeSet<MatsubaraFreq<S>> = omega_n.into_iter().collect();
469 let mut omega_n: Vec<MatsubaraFreq<S>> = omega_n_set.into_iter().collect();
470
471 let expected_size = if positive_only {
473 l_requested.div_ceil(2)
474 } else {
475 l_requested
476 };
477
478 if omega_n.len() != expected_size {
479 eprintln!(
480 "Warning: Requested {} sampling frequencies for basis size L = {}, but got {}.",
481 expected_size,
482 l,
483 omega_n.len()
484 );
485 }
486
487 if fence {
489 Self::fence_matsubara_sampling(&mut omega_n, positive_only);
490 }
491
492 omega_n
493 }
494 pub fn default_omega_sampling_points(&self) -> Vec<f64> {
505 let sz = self.size();
506
507 let y = default_sampling_points(&self.sve_result.v, sz);
510
511 let wmax = self.kernel.lambda() / self.beta;
513 let omega_points: Vec<f64> = y.into_iter().map(|yi| wmax * yi).collect();
514
515 omega_points
516 }
517}
518
519impl<K, S> crate::basis_trait::Basis<S> for FiniteTempBasis<K, S>
524where
525 K: KernelProperties + CentrosymmKernel + Clone + 'static,
526 S: StatisticsType + 'static,
527{
528 type Kernel = K;
529
530 fn kernel(&self) -> &Self::Kernel {
531 &self.kernel
532 }
533
534 fn beta(&self) -> f64 {
535 self.beta
536 }
537
538 fn wmax(&self) -> f64 {
539 self.kernel.lambda() / self.beta
540 }
541
542 fn lambda(&self) -> f64 {
543 self.kernel.lambda()
544 }
545
546 fn size(&self) -> usize {
547 self.size()
548 }
549
550 fn accuracy(&self) -> f64 {
551 self.accuracy
552 }
553
554 fn significance(&self) -> Vec<f64> {
555 if let Some(&first_s) = self.s.first() {
556 self.s.iter().map(|&s| s / first_s).collect()
557 } else {
558 vec![]
559 }
560 }
561
562 fn svals(&self) -> Vec<f64> {
563 self.s.clone()
564 }
565
566 fn default_tau_sampling_points(&self) -> Vec<f64> {
567 self.default_tau_sampling_points()
568 }
569
570 fn default_matsubara_sampling_points(
571 &self,
572 positive_only: bool,
573 ) -> Vec<crate::freq::MatsubaraFreq<S>> {
574 self.default_matsubara_sampling_points(positive_only)
575 }
576
577 fn evaluate_tau(&self, tau: &[f64]) -> mdarray::DTensor<f64, 2> {
578 use crate::taufuncs::normalize_tau;
579 use mdarray::DTensor;
580
581 let n_points = tau.len();
582 let basis_size = self.size();
583
584 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
589 let i = idx[0]; let l = idx[1]; let (tau_norm, sign) = normalize_tau::<S>(tau[i], self.beta);
594
595 sign * self.u[l].evaluate(tau_norm)
597 })
598 }
599
600 fn evaluate_matsubara(
601 &self,
602 freqs: &[crate::freq::MatsubaraFreq<S>],
603 ) -> mdarray::DTensor<num_complex::Complex<f64>, 2> {
604 use mdarray::DTensor;
605 use num_complex::Complex;
606
607 let n_points = freqs.len();
608 let basis_size = self.size();
609
610 DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
613 let i = idx[0]; let l = idx[1]; self.uhat[l].evaluate(&freqs[i])
616 })
617 }
618
619 fn evaluate_omega(&self, omega: &[f64]) -> mdarray::DTensor<f64, 2> {
620 use mdarray::DTensor;
621
622 let n_points = omega.len();
623 let basis_size = self.size();
624
625 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
628 let i = idx[0]; let l = idx[1]; self.v[l].evaluate(omega[i])
631 })
632 }
633
634 fn default_omega_sampling_points(&self) -> Vec<f64> {
635 self.default_omega_sampling_points()
636 }
637}
638
/// [`FiniteTempBasis`] with the [`LogisticKernel`] and fermionic statistics.
pub type FermionicBasis = FiniteTempBasis<LogisticKernel, Fermionic>;

/// [`FiniteTempBasis`] with the [`LogisticKernel`] and bosonic statistics.
pub type BosonicBasis = FiniteTempBasis<LogisticKernel, Bosonic>;
648
649#[cfg(test)]
650#[path = "basis_tests.rs"]
651mod basis_tests;