// qmc/sse/qmc_traits/heatbath.rs

1use crate::sse::qmc_traits::{DiagonalUpdater, Hamiltonian};
2use crate::sse::Op;
3use rand::Rng;
4#[cfg(feature = "serialize")]
5use serde::{Deserialize, Serialize};
6
7/// Bond weight storage for fast lookup.
8#[derive(Clone, Debug)]
9#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
10pub struct BondWeights {
11    max_weight_and_cumulative: Vec<(usize, f64, f64)>,
12}
13
14impl BondWeights {
15    /// Make a new BondWeights using an iterator of each individual bond's weight.
16    pub fn new<It>(max_bond_weights: It) -> Self
17    where
18        It: IntoIterator<Item = f64>,
19    {
20        let max_weight_and_cumulative =
21            max_bond_weights
22                .into_iter()
23                .enumerate()
24                .fold(vec![], |mut acc, (b, w)| {
25                    if acc.is_empty() {
26                        acc.push((b, w, w));
27                    } else {
28                        acc.push((b, w, w + acc[acc.len() - 1].2));
29                    };
30                    acc
31                });
32        Self {
33            max_weight_and_cumulative,
34        }
35    }
36
37    fn get_random_bond_and_max_weight<R: Rng>(&self, mut rng: R) -> Result<(usize, f64), &str> {
38        if let Some(total) = self.total() {
39            let c = rng.gen_range(0. ..total);
40            let index = self.index_for_cumulative(c);
41            Ok((
42                self.max_weight_and_cumulative[index].0,
43                self.max_weight_and_cumulative[index].1,
44            ))
45        } else {
46            Err("No bonds provided")
47        }
48    }
49
50    fn total(&self) -> Option<f64> {
51        self.max_weight_and_cumulative
52            .last()
53            .map(|(_, _, tot)| *tot)
54    }
55
56    fn index_for_cumulative(&self, val: f64) -> usize {
57        self.max_weight_and_cumulative
58            .binary_search_by(|(_, _, c)| c.partial_cmp(&val).unwrap())
59            .unwrap_or_else(|x| x)
60    }
61}
62
63/// Heatbath updates for a diagonal updater.
64pub trait HeatBathDiagonalUpdater: DiagonalUpdater {
65    /// Perform a single heatbath update.
66    fn make_heatbath_diagonal_update<'b, H: Hamiltonian<'b>>(
67        &mut self,
68        cutoff: usize,
69        beta: f64,
70        state: &[bool],
71        hamiltonian: &H,
72        bond_weights: &BondWeights,
73    ) {
74        self.make_heatbath_diagonal_update_with_rng(
75            cutoff,
76            beta,
77            state,
78            hamiltonian,
79            bond_weights,
80            &mut rand::thread_rng(),
81        )
82    }
83
84    /// Perform a single heatbath update.
85    fn make_heatbath_diagonal_update_with_rng<'b, H: Hamiltonian<'b>, R: Rng>(
86        &mut self,
87        cutoff: usize,
88        beta: f64,
89        state: &[bool],
90        hamiltonian: &H,
91        bond_weights: &BondWeights,
92        rng: &mut R,
93    ) {
94        let mut state = state.to_vec();
95        self.make_heatbath_diagonal_update_with_rng_and_state_ref(
96            cutoff,
97            beta,
98            &mut state,
99            hamiltonian,
100            bond_weights,
101            rng,
102        )
103    }
104
105    /// Perform a single heatbath update.
106    fn make_heatbath_diagonal_update_with_rng_and_state_ref<'b, H: Hamiltonian<'b>, R: Rng>(
107        &mut self,
108        cutoff: usize,
109        beta: f64,
110        state: &mut [bool],
111        hamiltonian: &H,
112        bond_weights: &BondWeights,
113        rng: &mut R,
114    ) {
115        self.mutate_ps(0, cutoff, (state, rng), |s, op, (state, rng)| {
116            let op = Self::heat_bath_single_diagonal_update(
117                op,
118                cutoff,
119                s.get_n(),
120                beta,
121                state,
122                (hamiltonian, bond_weights),
123                rng,
124            );
125            (op, (state, rng))
126        });
127    }
128
129    /// Make the bond weights struct for this container.
130    fn make_bond_weights<'b, H, E>(hamiltonian: H, num_bonds: usize, bonds_fn: E) -> BondWeights
131    where
132        H: Fn(&[usize], usize, &[bool], &[bool]) -> f64,
133        E: Fn(usize) -> &'b [usize],
134    {
135        let max_weights = (0..num_bonds).map(|i| {
136            let vars = bonds_fn(i);
137            (0..1 << vars.len())
138                .map(|substate| {
139                    let substate =
140                        Self::Op::make_substate((0..vars.len()).map(|v| (substate >> v) & 1 == 1));
141                    hamiltonian(vars, i, substate.as_ref(), substate.as_ref())
142                })
143                .fold(0.0, |acc, w| if w > acc { w } else { acc })
144        });
145        BondWeights::new(max_weights)
146    }
147
148    /// Perform a single heatbath update.
149    fn heat_bath_single_diagonal_update<'b, H: Hamiltonian<'b>, R: Rng>(
150        op: Option<&Self::Op>,
151        cutoff: usize,
152        n: usize,
153        beta: f64,
154        state: &mut [bool],
155        hamiltonian_and_weights: (&H, &BondWeights),
156        rng: &mut R,
157    ) -> Option<Option<Self::Op>> {
158        let (hamiltonian, bond_weights) = hamiltonian_and_weights;
159        let new_op = match op {
160            None => {
161                let numerator = beta * bond_weights.total().unwrap();
162                let denominator = (cutoff - n) as f64 + numerator;
163                if rng.gen_bool(numerator / denominator) {
164                    // For usage later.
165                    let p = rng.gen_range(0. ..1.0);
166                    // Find the bond to use, weighted by their matrix element.
167                    let (b, maxweight) = bond_weights.get_random_bond_and_max_weight(rng).unwrap();
168                    let (vars, constant) = hamiltonian.edge_fn(b);
169                    let substate = Self::Op::make_substate(vars.iter().map(|v| state[*v]));
170                    let vars = Self::Op::make_vars(vars.iter().cloned());
171
172                    let weight = hamiltonian.hamiltonian(
173                        vars.as_ref(),
174                        b,
175                        substate.as_ref(),
176                        substate.as_ref(),
177                    );
178
179                    if p * maxweight < weight {
180                        let op = Self::Op::diagonal(vars, b, substate, constant);
181                        Some(Some(op))
182                    } else {
183                        None
184                    }
185                } else {
186                    None
187                }
188            }
189            Some(op) if op.is_diagonal() => {
190                let numerator = (cutoff - n + 1) as f64;
191                let denominator = numerator + beta * bond_weights.total().unwrap();
192
193                if rng.gen_bool(numerator / denominator) {
194                    Some(None)
195                } else {
196                    None
197                }
198            }
199            // Update state
200            Some(op) => {
201                op.get_vars()
202                    .iter()
203                    .zip(op.get_outputs().iter())
204                    .for_each(|(v, b)| state[*v] = *b);
205                None
206            }
207        };
208        new_op
209    }
210}