// spin_sim/simulation.rs

1use crate::geometry::Lattice;
2use crate::statistics::{Statistics, SweepResult};
3use crate::{clusters, mcmc, spins};
4use rand::{Rng, SeedableRng};
5use rand_xoshiro::Xoshiro256StarStar;
6
7/// Mutable state for one disorder realization.
8///
9/// Holds the coupling array (fixed after construction), spin configurations for
10/// every replica at every temperature, and bookkeeping for parallel tempering.
11///
12/// With `n_replicas` replicas and `n_temps` temperatures there are
13/// `n_systems = n_replicas * n_temps` independent spin configurations.
14/// Spins are stored in a single flat `Vec` of length `n_systems * n_spins`,
15/// where system `i` occupies `spins[i*n_spins .. (i+1)*n_spins]`.
16pub struct Realization {
17    /// Forward couplings, length `n_spins * n_neighbors`.
18    pub couplings: Vec<f32>,
19    /// All spin configurations, length `n_systems * n_spins` (+1/−1).
20    pub spins: Vec<i8>,
21    /// Temperature assigned to each system slot, length `n_systems`.
22    pub temperatures: Vec<f32>,
23    /// Parallel-tempering permutation: `system_ids[slot]` is the system index
24    /// currently occupying temperature slot `slot`.
25    pub system_ids: Vec<usize>,
26    /// One PRNG per system.
27    pub rngs: Vec<Xoshiro256StarStar>,
28    /// One PRNG per overlap-update pair slot, length `n_temps * (n_replicas / 2)`.
29    pub pair_rngs: Vec<Xoshiro256StarStar>,
30    /// Cached total energy per system (E / N), length `n_systems`.
31    pub energies: Vec<f32>,
32}
33
34impl Realization {
35    /// Initialize a realization with random ±1 spins.
36    ///
37    /// Seeds replica RNGs deterministically as `base_seed, base_seed+1, …`.
38    pub fn new(
39        lattice: &Lattice,
40        couplings: Vec<f32>,
41        temps: &[f32],
42        n_replicas: usize,
43        base_seed: u64,
44    ) -> Self {
45        let n_spins = lattice.n_spins;
46        let n_temps = temps.len();
47        let n_systems = n_replicas * n_temps;
48
49        let temperatures = temps.repeat(n_replicas);
50
51        let mut rngs = Vec::with_capacity(n_systems);
52        for i in 0..n_systems {
53            rngs.push(Xoshiro256StarStar::seed_from_u64(base_seed + i as u64));
54        }
55
56        let mut spins = vec![0i8; n_systems * n_spins];
57        for (i, rng) in rngs.iter_mut().enumerate() {
58            for j in 0..n_spins {
59                spins[i * n_spins + j] = if rng.gen::<f32>() < 0.5 { -1 } else { 1 };
60            }
61        }
62
63        let system_ids: Vec<usize> = (0..n_systems).collect();
64
65        let n_pairs = n_replicas / 2;
66        let mut pair_rngs = Vec::with_capacity(n_temps * n_pairs);
67        for i in 0..n_temps * n_pairs {
68            pair_rngs.push(Xoshiro256StarStar::seed_from_u64(
69                base_seed + n_systems as u64 + i as u64,
70            ));
71        }
72
73        let (energies, _) =
74            spins::energy::compute_energies(lattice, &spins, &couplings, n_systems, false);
75
76        Self {
77            couplings,
78            spins,
79            temperatures,
80            system_ids,
81            rngs,
82            pair_rngs,
83            energies,
84        }
85    }
86
87    /// Re-randomize all spins and reset the tempering permutation.
88    pub fn reset(&mut self, lattice: &Lattice, n_replicas: usize, n_temps: usize, base_seed: u64) {
89        let n_spins = lattice.n_spins;
90        let n_systems = n_replicas * n_temps;
91
92        for i in 0..n_systems {
93            self.rngs[i] = Xoshiro256StarStar::seed_from_u64(base_seed + i as u64);
94            for j in 0..n_spins {
95                self.spins[i * n_spins + j] = if self.rngs[i].gen::<f32>() < 0.5 {
96                    -1
97                } else {
98                    1
99                };
100            }
101        }
102
103        self.system_ids = (0..n_systems).collect();
104
105        let n_pairs = n_replicas / 2;
106        for i in 0..n_temps * n_pairs {
107            self.pair_rngs[i] =
108                Xoshiro256StarStar::seed_from_u64(base_seed + n_systems as u64 + i as u64);
109        }
110
111        let (energies, _) = spins::energy::compute_energies(
112            lattice,
113            &self.spins,
114            &self.couplings,
115            n_systems,
116            false,
117        );
118        self.energies = energies;
119    }
120}
121
122/// Run the full Monte Carlo loop (warmup + measurement) for one [`Realization`].
123///
124/// Each sweep consists of:
125/// 1. A full single-spin pass (`sweep_mode`: `"metropolis"` or `"gibbs"`)
126/// 2. An optional cluster update (`cluster_mode`: `"wolff"` or `"sw"`,
127///    every `cluster_update_interval` sweeps)
128/// 3. Measurement (after `warmup_sweeps`)
129/// 4. Optional Houdayer ICM (every `houdayer_interval` sweeps, requires `n_replicas ≥ 2`)
130/// 5. Optional parallel tempering (every `pt_interval` sweeps)
131///
132/// `on_sweep` is called once per sweep (useful for progress bars).
133#[allow(clippy::too_many_arguments)]
134pub fn run_sweep_loop(
135    lattice: &Lattice,
136    real: &mut Realization,
137    n_replicas: usize,
138    n_temps: usize,
139    n_sweeps: usize,
140    warmup_sweeps: usize,
141    sweep_mode: &str,
142    cluster_update_interval: Option<usize>,
143    cluster_mode: &str,
144    pt_interval: Option<usize>,
145    houdayer_interval: Option<usize>,
146    houdayer_mode: &str,
147    overlap_cluster_mode: &str,
148    collect_csd: bool,
149    on_sweep: &(dyn Fn() + Sync),
150) -> SweepResult {
151    let n_spins = lattice.n_spins;
152    let n_systems = n_replicas * n_temps;
153    let overlap_wolff = overlap_cluster_mode == "wolff";
154
155    let (stochastic, restrict_to_negative) = match houdayer_mode {
156        "houdayer" => (false, true),
157        "jorg" => (true, true),
158        "cmr" => (true, false),
159        _ => unreachable!(),
160    };
161
162    let n_pairs = n_replicas / 2;
163
164    let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
165    let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
166
167    let mut overlap_csd_accum: Vec<Vec<u64>> =
168        (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
169    let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
170        .map(|_| vec![0u64; n_spins + 1])
171        .collect();
172
173    let mut mags_stat = Statistics::new(n_temps, 1);
174    let mut mags2_stat = Statistics::new(n_temps, 1);
175    let mut mags4_stat = Statistics::new(n_temps, 1);
176    let mut energies_stat = Statistics::new(n_temps, 1);
177    let mut energies2_stat = Statistics::new(n_temps, 2);
178    let mut overlap_stat = Statistics::new(n_temps, 1);
179    let mut overlap2_stat = Statistics::new(n_temps, 1);
180    let mut overlap4_stat = Statistics::new(n_temps, 1);
181
182    for sweep_id in 0..n_sweeps {
183        on_sweep();
184        let record = sweep_id >= warmup_sweeps;
185
186        match sweep_mode {
187            "metropolis" => mcmc::sweep::metropolis_sweep(
188                lattice,
189                &mut real.spins,
190                &real.couplings,
191                &real.temperatures,
192                &real.system_ids,
193                &mut real.rngs,
194            ),
195            "gibbs" => mcmc::sweep::gibbs_sweep(
196                lattice,
197                &mut real.spins,
198                &real.couplings,
199                &real.temperatures,
200                &real.system_ids,
201                &mut real.rngs,
202            ),
203            _ => unreachable!(),
204        }
205
206        let do_cluster = cluster_update_interval.is_some_and(|interval| sweep_id % interval == 0);
207
208        if do_cluster {
209            let wolff = cluster_mode == "wolff";
210            let csd_out = if collect_csd && record {
211                for buf in sw_csd_buf.iter_mut() {
212                    buf.fill(0);
213                }
214                Some(sw_csd_buf.as_mut_slice())
215            } else {
216                None
217            };
218
219            clusters::fk_update(
220                lattice,
221                &mut real.spins,
222                &real.couplings,
223                &real.temperatures,
224                &real.system_ids,
225                &mut real.rngs,
226                wolff,
227                csd_out,
228            );
229
230            if collect_csd && record {
231                for (slot, buf) in sw_csd_buf.iter().enumerate() {
232                    let accum = &mut fk_csd_accum[slot % n_temps];
233                    for (a, &b) in accum.iter_mut().zip(buf.iter()) {
234                        *a += b;
235                    }
236                }
237            }
238
239            (real.energies, _) = spins::energy::compute_energies(
240                lattice,
241                &real.spins,
242                &real.couplings,
243                n_systems,
244                false,
245            );
246        } else {
247            (real.energies, _) = spins::energy::compute_energies(
248                lattice,
249                &real.spins,
250                &real.couplings,
251                n_systems,
252                false,
253            );
254        }
255
256        if record {
257            let mut mags = vec![0.0f32; n_temps];
258            let mut mags2 = vec![0.0f32; n_temps];
259            let mut mags4 = vec![0.0f32; n_temps];
260            let mut energies_ordered = vec![0.0f32; n_temps];
261
262            for r in 0..n_replicas {
263                let offset = r * n_temps;
264                for t in 0..n_temps {
265                    let system_id = real.system_ids[offset + t];
266                    let spin_base = system_id * n_spins;
267                    let mut sum = 0i64;
268                    for j in 0..n_spins {
269                        sum += real.spins[spin_base + j] as i64;
270                    }
271                    let mag = sum as f32 / n_spins as f32;
272                    let m2 = mag * mag;
273                    mags[t] = mag;
274                    mags2[t] = m2;
275                    mags4[t] = m2 * m2;
276                    energies_ordered[t] = real.energies[system_id];
277                }
278
279                mags_stat.update(&mags);
280                mags2_stat.update(&mags2);
281                mags4_stat.update(&mags4);
282                energies_stat.update(&energies_ordered);
283                energies2_stat.update(&energies_ordered);
284            }
285
286            for pair_idx in 0..n_pairs {
287                let r_a = 2 * pair_idx;
288                let r_b = 2 * pair_idx + 1;
289                let mut overlaps = vec![0.0f32; n_temps];
290                let mut overlaps2 = vec![0.0f32; n_temps];
291                let mut overlaps4 = vec![0.0f32; n_temps];
292
293                for t in 0..n_temps {
294                    let sys_a = real.system_ids[r_a * n_temps + t];
295                    let sys_b = real.system_ids[r_b * n_temps + t];
296                    let base_a = sys_a * n_spins;
297                    let base_b = sys_b * n_spins;
298                    let mut dot = 0i64;
299                    for j in 0..n_spins {
300                        dot += (real.spins[base_a + j] as i64) * (real.spins[base_b + j] as i64);
301                    }
302                    let q = dot as f32 / n_spins as f32;
303                    let q2 = q * q;
304                    overlaps[t] = q;
305                    overlaps2[t] = q2;
306                    overlaps4[t] = q2 * q2;
307                }
308
309                overlap_stat.update(&overlaps);
310                overlap2_stat.update(&overlaps2);
311                overlap4_stat.update(&overlaps4);
312            }
313        }
314
315        if let Some(interval) = houdayer_interval {
316            if sweep_id % interval == 0 && n_replicas >= 2 {
317                let ov_csd_out = if collect_csd && record {
318                    for buf in overlap_csd_buf.iter_mut() {
319                        buf.fill(0);
320                    }
321                    Some(overlap_csd_buf.as_mut_slice())
322                } else {
323                    None
324                };
325
326                clusters::overlap_update(
327                    lattice,
328                    &mut real.spins,
329                    &real.couplings,
330                    &real.temperatures,
331                    &real.system_ids,
332                    n_replicas,
333                    n_temps,
334                    &mut real.pair_rngs,
335                    stochastic,
336                    restrict_to_negative,
337                    overlap_wolff,
338                    ov_csd_out,
339                );
340
341                if collect_csd && record {
342                    for (slot, buf) in overlap_csd_buf.iter().enumerate() {
343                        let accum = &mut overlap_csd_accum[slot / n_pairs];
344                        for (a, &b) in accum.iter_mut().zip(buf.iter()) {
345                            *a += b;
346                        }
347                    }
348                }
349
350                (real.energies, _) = spins::energy::compute_energies(
351                    lattice,
352                    &real.spins,
353                    &real.couplings,
354                    n_systems,
355                    false,
356                );
357            }
358        }
359
360        if let Some(interval) = pt_interval {
361            if sweep_id % interval == 0 {
362                for r in 0..n_replicas {
363                    let offset = r * n_temps;
364                    let sid_slice = &mut real.system_ids[offset..offset + n_temps];
365                    let temp_slice = &real.temperatures[offset..offset + n_temps];
366                    mcmc::tempering::parallel_tempering(
367                        &real.energies,
368                        temp_slice,
369                        sid_slice,
370                        n_spins,
371                        &mut real.rngs[offset],
372                    );
373                }
374            }
375        }
376    }
377
378    SweepResult {
379        mags: mags_stat.average(),
380        mags2: mags2_stat.average(),
381        mags4: mags4_stat.average(),
382        energies: energies_stat.average(),
383        energies2: energies2_stat.average(),
384        overlap: if n_pairs > 0 {
385            overlap_stat.average()
386        } else {
387            vec![]
388        },
389        overlap2: if n_pairs > 0 {
390            overlap2_stat.average()
391        } else {
392            vec![]
393        },
394        overlap4: if n_pairs > 0 {
395            overlap4_stat.average()
396        } else {
397            vec![]
398        },
399        fk_csd: fk_csd_accum,
400        overlap_csd: overlap_csd_accum,
401    }
402}