1use crate::sse::qmc_traits::{DiagonalUpdater, Hamiltonian};
2use crate::sse::Op;
3use rand::Rng;
4#[cfg(feature = "serialize")]
5use serde::{Deserialize, Serialize};
6
7#[derive(Clone, Debug)]
9#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
10pub struct BondWeights {
11 max_weight_and_cumulative: Vec<(usize, f64, f64)>,
12}
13
14impl BondWeights {
15 pub fn new<It>(max_bond_weights: It) -> Self
17 where
18 It: IntoIterator<Item = f64>,
19 {
20 let max_weight_and_cumulative =
21 max_bond_weights
22 .into_iter()
23 .enumerate()
24 .fold(vec![], |mut acc, (b, w)| {
25 if acc.is_empty() {
26 acc.push((b, w, w));
27 } else {
28 acc.push((b, w, w + acc[acc.len() - 1].2));
29 };
30 acc
31 });
32 Self {
33 max_weight_and_cumulative,
34 }
35 }
36
37 fn get_random_bond_and_max_weight<R: Rng>(&self, mut rng: R) -> Result<(usize, f64), &str> {
38 if let Some(total) = self.total() {
39 let c = rng.gen_range(0. ..total);
40 let index = self.index_for_cumulative(c);
41 Ok((
42 self.max_weight_and_cumulative[index].0,
43 self.max_weight_and_cumulative[index].1,
44 ))
45 } else {
46 Err("No bonds provided")
47 }
48 }
49
50 fn total(&self) -> Option<f64> {
51 self.max_weight_and_cumulative
52 .last()
53 .map(|(_, _, tot)| *tot)
54 }
55
56 fn index_for_cumulative(&self, val: f64) -> usize {
57 self.max_weight_and_cumulative
58 .binary_search_by(|(_, _, c)| c.partial_cmp(&val).unwrap())
59 .unwrap_or_else(|x| x)
60 }
61}
62
63pub trait HeatBathDiagonalUpdater: DiagonalUpdater {
65 fn make_heatbath_diagonal_update<'b, H: Hamiltonian<'b>>(
67 &mut self,
68 cutoff: usize,
69 beta: f64,
70 state: &[bool],
71 hamiltonian: &H,
72 bond_weights: &BondWeights,
73 ) {
74 self.make_heatbath_diagonal_update_with_rng(
75 cutoff,
76 beta,
77 state,
78 hamiltonian,
79 bond_weights,
80 &mut rand::thread_rng(),
81 )
82 }
83
84 fn make_heatbath_diagonal_update_with_rng<'b, H: Hamiltonian<'b>, R: Rng>(
86 &mut self,
87 cutoff: usize,
88 beta: f64,
89 state: &[bool],
90 hamiltonian: &H,
91 bond_weights: &BondWeights,
92 rng: &mut R,
93 ) {
94 let mut state = state.to_vec();
95 self.make_heatbath_diagonal_update_with_rng_and_state_ref(
96 cutoff,
97 beta,
98 &mut state,
99 hamiltonian,
100 bond_weights,
101 rng,
102 )
103 }
104
105 fn make_heatbath_diagonal_update_with_rng_and_state_ref<'b, H: Hamiltonian<'b>, R: Rng>(
107 &mut self,
108 cutoff: usize,
109 beta: f64,
110 state: &mut [bool],
111 hamiltonian: &H,
112 bond_weights: &BondWeights,
113 rng: &mut R,
114 ) {
115 self.mutate_ps(0, cutoff, (state, rng), |s, op, (state, rng)| {
116 let op = Self::heat_bath_single_diagonal_update(
117 op,
118 cutoff,
119 s.get_n(),
120 beta,
121 state,
122 (hamiltonian, bond_weights),
123 rng,
124 );
125 (op, (state, rng))
126 });
127 }
128
129 fn make_bond_weights<'b, H, E>(hamiltonian: H, num_bonds: usize, bonds_fn: E) -> BondWeights
131 where
132 H: Fn(&[usize], usize, &[bool], &[bool]) -> f64,
133 E: Fn(usize) -> &'b [usize],
134 {
135 let max_weights = (0..num_bonds).map(|i| {
136 let vars = bonds_fn(i);
137 (0..1 << vars.len())
138 .map(|substate| {
139 let substate =
140 Self::Op::make_substate((0..vars.len()).map(|v| (substate >> v) & 1 == 1));
141 hamiltonian(vars, i, substate.as_ref(), substate.as_ref())
142 })
143 .fold(0.0, |acc, w| if w > acc { w } else { acc })
144 });
145 BondWeights::new(max_weights)
146 }
147
148 fn heat_bath_single_diagonal_update<'b, H: Hamiltonian<'b>, R: Rng>(
150 op: Option<&Self::Op>,
151 cutoff: usize,
152 n: usize,
153 beta: f64,
154 state: &mut [bool],
155 hamiltonian_and_weights: (&H, &BondWeights),
156 rng: &mut R,
157 ) -> Option<Option<Self::Op>> {
158 let (hamiltonian, bond_weights) = hamiltonian_and_weights;
159 let new_op = match op {
160 None => {
161 let numerator = beta * bond_weights.total().unwrap();
162 let denominator = (cutoff - n) as f64 + numerator;
163 if rng.gen_bool(numerator / denominator) {
164 let p = rng.gen_range(0. ..1.0);
166 let (b, maxweight) = bond_weights.get_random_bond_and_max_weight(rng).unwrap();
168 let (vars, constant) = hamiltonian.edge_fn(b);
169 let substate = Self::Op::make_substate(vars.iter().map(|v| state[*v]));
170 let vars = Self::Op::make_vars(vars.iter().cloned());
171
172 let weight = hamiltonian.hamiltonian(
173 vars.as_ref(),
174 b,
175 substate.as_ref(),
176 substate.as_ref(),
177 );
178
179 if p * maxweight < weight {
180 let op = Self::Op::diagonal(vars, b, substate, constant);
181 Some(Some(op))
182 } else {
183 None
184 }
185 } else {
186 None
187 }
188 }
189 Some(op) if op.is_diagonal() => {
190 let numerator = (cutoff - n + 1) as f64;
191 let denominator = numerator + beta * bond_weights.total().unwrap();
192
193 if rng.gen_bool(numerator / denominator) {
194 Some(None)
195 } else {
196 None
197 }
198 }
199 Some(op) => {
201 op.get_vars()
202 .iter()
203 .zip(op.get_outputs().iter())
204 .for_each(|(v, b)| state[*v] = *b);
205 None
206 }
207 };
208 new_op
209 }
210}