1use std::sync::Arc;
7
8use crate::kernel::{CentrosymmKernel, KernelProperties, LogisticKernel};
9use crate::poly::{PiecewiseLegendrePolyVector, default_sampling_points};
10use crate::polyfourier::PiecewiseLegendreFTVector;
11use crate::sve::{SVEResult, TworkType, compute_sve};
12use crate::traits::{Bosonic, Fermionic, StatisticsType};
13
/// Tag with one variant per supported statistics kind.
///
/// NOTE(review): the code visible in this file compares against
/// `crate::traits::Statistics`, not this enum — confirm which of the two
/// callers are expected to use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Statistics {
    /// Fermionic variant (first variant, discriminant 0).
    Fermionic,
    /// Bosonic variant (second variant, discriminant 1).
    Bosonic,
}
20
/// Basis assembled from a kernel's SVE result for a given `beta`.
///
/// `K` is the kernel type and `S` the statistics marker type. All fields are
/// private and exposed through same-named accessor methods. Cloning copies the
/// kernel and `s`; every `Arc` field is shared rather than deep-copied.
#[derive(Clone)]
pub struct FiniteTempBasis<K, S>
where
    K: KernelProperties + CentrosymmKernel + Clone + 'static,
    S: StatisticsType,
{
    /// Kernel the basis was constructed from.
    kernel: K,

    /// The SVE result handed to `from_sve_result`, kept in full (not truncated).
    sve_result: Arc<SVEResult>,

    /// Ratio of a cut-off singular value to `sve_result.s[0]`: the first value
    /// dropped by the truncation, or the last value when nothing was dropped.
    accuracy: f64,

    /// Value of `beta` passed at construction.
    beta: f64,

    /// Truncated `u` part of the SVE, rescaled so a knot `x` becomes
    /// `beta / 2 * (x + 1)`.
    u: Arc<PiecewiseLegendrePolyVector>,

    /// Truncated `v` part of the SVE, with knots multiplied by
    /// `lambda / beta`.
    v: Arc<PiecewiseLegendrePolyVector>,

    /// Truncated singular values, each multiplied by
    /// `sqrt(beta / 2 * omega_max) * omega_max^(-ypower)`.
    s: Vec<f64>,

    /// The first `s.len()` entries of `uhat_full`.
    uhat: Arc<PiecewiseLegendreFTVector<S>>,

    /// Fourier-transform vector built from the full `sve_result.u` scaled by
    /// `sqrt(beta)`.
    uhat_full: Arc<PiecewiseLegendreFTVector<S>>,

    /// Zero-sized marker tying the struct to the statistics type `S`.
    _phantom: std::marker::PhantomData<S>,
}
77
78impl<K, S> FiniteTempBasis<K, S>
79where
80 K: KernelProperties + CentrosymmKernel + Clone + 'static,
81 S: StatisticsType,
82{
83 pub fn kernel(&self) -> &K {
87 &self.kernel
88 }
89
90 pub fn sve_result(&self) -> &Arc<SVEResult> {
92 &self.sve_result
93 }
94
95 pub fn accuracy(&self) -> f64 {
97 self.accuracy
98 }
99
100 pub fn beta(&self) -> f64 {
102 self.beta
103 }
104
105 pub fn u(&self) -> &Arc<PiecewiseLegendrePolyVector> {
107 &self.u
108 }
109
110 pub fn v(&self) -> &Arc<PiecewiseLegendrePolyVector> {
112 &self.v
113 }
114
115 pub fn s(&self) -> &[f64] {
117 &self.s
118 }
119
120 pub fn uhat(&self) -> &Arc<PiecewiseLegendreFTVector<S>> {
122 &self.uhat
123 }
124
125 pub fn uhat_full(&self) -> &Arc<PiecewiseLegendreFTVector<S>> {
127 &self.uhat_full
128 }
129
130 pub fn wmax(&self) -> f64 {
134 self.kernel.lambda() / self.beta
135 }
136
137 pub fn default_matsubara_sampling_points_i64(&self, positive_only: bool) -> Vec<i64>
139 where
140 S: 'static,
141 {
142 let freqs = self.default_matsubara_sampling_points(positive_only);
143 freqs.into_iter().map(|f| f.n()).collect()
144 }
145
146 pub fn default_matsubara_sampling_points_i64_with_mitigate(
152 &self,
153 positive_only: bool,
154 mitigate: bool,
155 n_points: usize,
156 ) -> Vec<i64>
157 where
158 S: 'static,
159 {
160 if !self.kernel().is_centrosymmetric() {
161 panic!(
162 "default_matsubara_sampling_points_i64_with_mitigate is not supported for non-centrosymmetric kernels. \
163 The current implementation relies on centrosymmetry to generate sampling points."
164 );
165 }
166 let fence = mitigate;
167 let freqs = Self::default_matsubara_sampling_points_impl(
168 &self.uhat_full,
169 n_points,
170 fence,
171 positive_only,
172 );
173 freqs.into_iter().map(|f| f.n()).collect()
174 }
175
176 pub fn new(kernel: K, beta: f64, epsilon: Option<f64>, max_size: Option<usize>) -> Self {
189 if beta <= 0.0 {
191 panic!("Inverse temperature beta must be positive, got {}", beta);
192 }
193
194 let epsilon_value = epsilon.unwrap_or(f64::NAN);
196 let sve_result = compute_sve(
197 kernel.clone(),
198 epsilon_value,
199 None, max_size,
201 TworkType::Auto,
202 );
203
204 Self::from_sve_result(kernel, beta, sve_result, epsilon, max_size)
205 }
206
207 pub fn from_sve_result(
212 kernel: K,
213 beta: f64,
214 sve_result: SVEResult,
215 epsilon: Option<f64>,
216 max_size: Option<usize>,
217 ) -> Self {
218 let (u_sve, s_sve, v_sve) = sve_result.part(epsilon, max_size);
220
221 let accuracy = if sve_result.s.len() > s_sve.len() {
223 sve_result.s[s_sve.len()] / sve_result.s[0]
224 } else {
225 sve_result.s[sve_result.s.len() - 1] / sve_result.s[0]
226 };
227
228 let lambda = kernel.lambda();
230 let omega_max = lambda / beta;
231
232 let u_knots: Vec<f64> = u_sve.get_polys()[0]
237 .knots
238 .iter()
239 .map(|&x| beta / 2.0 * (x + 1.0))
240 .collect();
241 let u_delta_x: Vec<f64> = u_sve.get_polys()[0]
242 .delta_x
243 .iter()
244 .map(|&dx| beta / 2.0 * dx)
245 .collect();
246 let u_symm: Vec<i32> = u_sve.get_polys().iter().map(|p| p.symm).collect();
247
248 let u = u_sve.rescale_domain(u_knots, Some(u_delta_x), Some(u_symm));
249
250 let v_knots: Vec<f64> = v_sve.get_polys()[0]
252 .knots
253 .iter()
254 .map(|&y| omega_max * y)
255 .collect();
256 let v_delta_x: Vec<f64> = v_sve.get_polys()[0]
257 .delta_x
258 .iter()
259 .map(|&dy| omega_max * dy)
260 .collect();
261 let v_symm: Vec<i32> = v_sve.get_polys().iter().map(|p| p.symm).collect();
262
263 let v = v_sve.rescale_domain(v_knots, Some(v_delta_x), Some(v_symm));
264
265 let ypower = kernel.ypower();
268 let scale_factor = (beta / 2.0 * omega_max).sqrt() * omega_max.powi(-ypower);
269 let s: Vec<f64> = s_sve.iter().map(|&x| scale_factor * x).collect();
270
271 let uhat_base_full = sve_result.u.scale_data(beta.sqrt());
274 let conv_rad = kernel.conv_radius();
275
276 let stat_marker = S::default();
279
280 let uhat_full = PiecewiseLegendreFTVector::<S>::from_poly_vector(
281 &uhat_base_full,
282 stat_marker,
283 Some(conv_rad),
284 );
285
286 let uhat_polyvec: Vec<_> = uhat_full.polyvec.iter().take(s.len()).cloned().collect();
288 let uhat = PiecewiseLegendreFTVector::from_vector(uhat_polyvec);
289
290 Self {
291 kernel,
292 sve_result: Arc::new(sve_result),
293 accuracy,
294 beta,
295 u: Arc::new(u),
296 v: Arc::new(v),
297 s,
298 uhat: Arc::new(uhat),
299 uhat_full: Arc::new(uhat_full),
300 _phantom: std::marker::PhantomData,
301 }
302 }
303
304 pub fn size(&self) -> usize {
306 self.s.len()
307 }
308
309 pub fn lambda(&self) -> f64 {
311 self.kernel.lambda()
312 }
313
314 pub fn omega_max(&self) -> f64 {
316 self.lambda() / self.beta
317 }
318
319 pub fn significance(&self) -> Vec<f64> {
321 let s0 = self.s[0];
322 self.s.iter().map(|&s| s / s0).collect()
323 }
324
325 pub fn default_tau_sampling_points(&self) -> Vec<f64> {
332 let points = self.default_tau_sampling_points_size_requested(self.size());
333 let basis_size = self.size();
334 if points.len() < basis_size {
335 eprintln!(
336 "Warning: Number of tau sampling points ({}) is less than basis size ({}). \
337 Basis parameters: beta={}, wmax={}, epsilon={:.2e}",
338 points.len(),
339 basis_size,
340 self.beta,
341 self.wmax(),
342 self.accuracy()
343 );
344 }
345 points
346 }
347
348 pub fn default_tau_sampling_points_size_requested(&self, size_requested: usize) -> Vec<f64> {
352 let x = default_sampling_points(&self.sve_result.u, size_requested);
353 let half_beta = self.beta / 2.0;
354 let mut smpl_taus: Vec<f64> = x
356 .iter()
357 .map(|&xi| {
358 let tau = half_beta * (xi + 1.0);
359 if tau <= half_beta {
360 tau
361 } else {
362 tau - self.beta
363 }
364 })
365 .collect();
366 smpl_taus.sort_by(|a, b| a.partial_cmp(b).unwrap());
367 smpl_taus
368 }
369
370 pub fn default_matsubara_sampling_points(
385 &self,
386 positive_only: bool,
387 ) -> Vec<crate::freq::MatsubaraFreq<S>>
388 where
389 S: 'static,
390 {
391 if !self.kernel().is_centrosymmetric() {
392 panic!(
393 "default_matsubara_sampling_points is not supported for non-centrosymmetric kernels. \
394 The current implementation relies on centrosymmetry to generate sampling points."
395 );
396 }
397 let fence = false;
398 let points = Self::default_matsubara_sampling_points_impl(
399 &self.uhat_full,
400 self.size(),
401 fence,
402 positive_only,
403 );
404 let basis_size = self.size();
405 let effective_points = if positive_only {
408 2 * points.len()
409 } else {
410 points.len()
411 };
412 if effective_points < basis_size {
413 eprintln!(
414 "Warning: Number of Matsubara sampling points ({}{}) is less than basis size ({}). \
415 Basis parameters: beta={}, wmax={}, epsilon={:.2e}",
416 points.len(),
417 if positive_only { " × 2" } else { "" },
418 basis_size,
419 self.beta,
420 self.wmax(),
421 self.accuracy()
422 );
423 }
424 points
425 }
426
427 fn fence_matsubara_sampling(
435 omega_n: &mut Vec<crate::freq::MatsubaraFreq<S>>,
436 positive_only: bool,
437 ) where
438 S: StatisticsType + 'static,
439 {
440 use crate::freq::{BosonicFreq, MatsubaraFreq};
441
442 if omega_n.is_empty() {
443 return;
444 }
445
446 let mut outer_frequencies = Vec::new();
448 if positive_only {
449 outer_frequencies.push(omega_n[omega_n.len() - 1]);
450 } else {
451 outer_frequencies.push(omega_n[0]);
452 outer_frequencies.push(omega_n[omega_n.len() - 1]);
453 }
454
455 for wn_outer in outer_frequencies {
456 let outer_val = wn_outer.n();
457 let mut diff_val = 2 * (0.025 * outer_val as f64).round() as i64;
460
461 if diff_val == 0 {
463 diff_val = 2;
464 }
465
466 let wn_diff = BosonicFreq::new(diff_val).unwrap().n();
468
469 let sign_val = if outer_val > 0 {
472 1
473 } else if outer_val < 0 {
474 -1
475 } else {
476 0
477 };
478
479 let original_size = omega_n.len();
481 if original_size >= 20 {
482 let new_n = outer_val - sign_val * wn_diff;
485 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
486 omega_n.push(new_freq);
487 }
488 }
489 if original_size >= 42 {
490 let new_n = outer_val + sign_val * wn_diff;
491 if let Ok(new_freq) = MatsubaraFreq::<S>::new(new_n) {
492 omega_n.push(new_freq);
493 }
494 }
495 }
496
497 let omega_n_set: std::collections::BTreeSet<MatsubaraFreq<S>> = omega_n.drain(..).collect();
499 *omega_n = omega_n_set.into_iter().collect();
500 }
501
502 pub fn default_matsubara_sampling_points_impl(
503 uhat_full: &PiecewiseLegendreFTVector<S>,
504 l: usize,
505 fence: bool,
506 positive_only: bool,
507 ) -> Vec<crate::freq::MatsubaraFreq<S>>
508 where
509 S: StatisticsType + 'static,
510 {
511 use crate::freq::MatsubaraFreq;
512 use crate::polyfourier::{find_extrema, sign_changes};
513 use std::collections::BTreeSet;
514
515 let mut l_requested = l;
516
517 if S::STATISTICS == crate::traits::Statistics::Fermionic && l_requested % 2 != 0 {
519 l_requested += 1;
520 } else if S::STATISTICS == crate::traits::Statistics::Bosonic && l_requested % 2 == 0 {
521 l_requested += 1;
522 }
523
524 let mut omega_n = if l_requested < uhat_full.len() {
526 sign_changes(&uhat_full[l_requested], positive_only)
527 } else {
528 find_extrema(&uhat_full[uhat_full.len() - 1], positive_only)
529 };
530
531 if S::STATISTICS == crate::traits::Statistics::Bosonic {
533 omega_n.push(MatsubaraFreq::<S>::new(0).unwrap());
534 }
535
536 let omega_n_set: BTreeSet<MatsubaraFreq<S>> = omega_n.into_iter().collect();
538 let mut omega_n: Vec<MatsubaraFreq<S>> = omega_n_set.into_iter().collect();
539
540 let expected_size = if positive_only {
542 l_requested.div_ceil(2)
543 } else {
544 l_requested
545 };
546
547 if omega_n.len() != expected_size {
548 eprintln!(
549 "Warning: Requested {} sampling frequencies for basis size L = {}, but got {}.",
550 expected_size,
551 l,
552 omega_n.len()
553 );
554 }
555
556 if fence {
558 Self::fence_matsubara_sampling(&mut omega_n, positive_only);
559 }
560
561 omega_n
562 }
563 pub fn default_omega_sampling_points(&self) -> Vec<f64> {
574 let sz = self.size();
575
576 let y = default_sampling_points(&self.sve_result.v, sz);
579
580 let wmax = self.kernel.lambda() / self.beta;
582 let omega_points: Vec<f64> = y.into_iter().map(|yi| wmax * yi).collect();
583
584 omega_points
585 }
586}
587
588impl<K, S> crate::basis_trait::Basis<S> for FiniteTempBasis<K, S>
593where
594 K: KernelProperties + CentrosymmKernel + Clone + 'static,
595 S: StatisticsType + 'static,
596{
597 type Kernel = K;
598
599 fn kernel(&self) -> &Self::Kernel {
600 &self.kernel
601 }
602
603 fn beta(&self) -> f64 {
604 self.beta
605 }
606
607 fn wmax(&self) -> f64 {
608 self.kernel.lambda() / self.beta
609 }
610
611 fn lambda(&self) -> f64 {
612 self.kernel.lambda()
613 }
614
615 fn size(&self) -> usize {
616 self.size()
617 }
618
619 fn accuracy(&self) -> f64 {
620 self.accuracy
621 }
622
623 fn significance(&self) -> Vec<f64> {
624 if let Some(&first_s) = self.s.first() {
625 self.s.iter().map(|&s| s / first_s).collect()
626 } else {
627 vec![]
628 }
629 }
630
631 fn svals(&self) -> Vec<f64> {
632 self.s.clone()
633 }
634
635 fn default_tau_sampling_points(&self) -> Vec<f64> {
636 self.default_tau_sampling_points()
637 }
638
639 fn default_matsubara_sampling_points(
640 &self,
641 positive_only: bool,
642 ) -> Vec<crate::freq::MatsubaraFreq<S>> {
643 self.default_matsubara_sampling_points(positive_only)
644 }
645
646 fn evaluate_tau(&self, tau: &[f64]) -> mdarray::DTensor<f64, 2> {
647 use crate::taufuncs::normalize_tau;
648 use mdarray::DTensor;
649
650 let n_points = tau.len();
651 let basis_size = self.size();
652
653 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
658 let i = idx[0]; let l = idx[1]; let (tau_norm, sign) = normalize_tau::<S>(tau[i], self.beta);
663
664 sign * self.u[l].evaluate(tau_norm)
666 })
667 }
668
669 fn evaluate_matsubara(
670 &self,
671 freqs: &[crate::freq::MatsubaraFreq<S>],
672 ) -> mdarray::DTensor<num_complex::Complex<f64>, 2> {
673 use mdarray::DTensor;
674 use num_complex::Complex;
675
676 let n_points = freqs.len();
677 let basis_size = self.size();
678
679 DTensor::<Complex<f64>, 2>::from_fn([n_points, basis_size], |idx| {
682 let i = idx[0]; let l = idx[1]; self.uhat[l].evaluate(&freqs[i])
685 })
686 }
687
688 fn evaluate_omega(&self, omega: &[f64]) -> mdarray::DTensor<f64, 2> {
689 use mdarray::DTensor;
690
691 let n_points = omega.len();
692 let basis_size = self.size();
693
694 DTensor::<f64, 2>::from_fn([n_points, basis_size], |idx| {
697 let i = idx[0]; let l = idx[1]; self.v[l].evaluate(omega[i])
700 })
701 }
702
703 fn default_omega_sampling_points(&self) -> Vec<f64> {
704 self.default_omega_sampling_points()
705 }
706}
707
/// [`FiniteTempBasis`] specialised to [`LogisticKernel`] with the [`Fermionic`] statistics marker.
pub type FermionicBasis = FiniteTempBasis<LogisticKernel, Fermionic>;
714
/// [`FiniteTempBasis`] specialised to [`LogisticKernel`] with the [`Bosonic`] statistics marker.
pub type BosonicBasis = FiniteTempBasis<LogisticKernel, Bosonic>;
717
718#[cfg(test)]
719#[path = "basis_tests.rs"]
720mod basis_tests;