Skip to main content

spin_sim/simulation/
mod.rs

1pub mod realization;
2
3pub use realization::Realization;
4
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use crate::config::{OverlapClusterBuildMode, OverlapUpdateMode, SimConfig, SweepMode};
8use crate::geometry::Lattice;
9use crate::statistics::{
10    sokal_tau, AutocorrAccum, ClusterStats, Diagnostics, EquilDiagnosticAccum, Statistics,
11    SweepResult,
12};
13use crate::{clusters, mcmc, spins};
14use rayon::prelude::*;
15use validator::Validate;
16
17/// Run the full Monte Carlo loop (warmup + measurement) for one [`Realization`].
18///
19/// Each sweep consists of:
20/// 1. A full single-spin pass (`sweep_mode`: Metropolis or Gibbs)
21/// 2. An optional cluster update (every `cluster_update.interval` sweeps)
22/// 3. Measurement (after `warmup_sweeps`)
23/// 4. Optional overlap cluster move (every `overlap_cluster.interval` sweeps, requires `n_replicas ≥ 2`)
24/// 5. Optional parallel tempering (every `pt_interval` sweeps)
25///
26/// `on_sweep` is called once per sweep (useful for progress bars).
27pub fn run_sweep_loop(
28    lattice: &Lattice,
29    real: &mut Realization,
30    n_replicas: usize,
31    n_temps: usize,
32    config: &SimConfig,
33    interrupted: &AtomicBool,
34    on_sweep: &(dyn Fn() + Sync),
35) -> Result<SweepResult, String> {
36    config.validate().map_err(|e| format!("{e}"))?;
37
38    let n_spins = lattice.n_spins;
39    let n_systems = n_replicas * n_temps;
40    let n_sweeps = config.n_sweeps;
41    let warmup_sweeps = config.warmup_sweeps;
42
43    let overlap_wolff = config
44        .overlap_cluster
45        .as_ref()
46        .is_some_and(|h| h.cluster_mode == crate::config::ClusterMode::Wolff);
47
48    let (stochastic, restrict_to_negative) =
49        config
50            .overlap_cluster
51            .as_ref()
52            .map_or((false, true), |h| match h.mode {
53                OverlapClusterBuildMode::Houdayer => (false, true),
54                OverlapClusterBuildMode::Jorg => (true, true),
55                OverlapClusterBuildMode::Cmr | OverlapClusterBuildMode::Cmr3 => (true, false),
56            });
57
58    let group_size = config
59        .overlap_cluster
60        .as_ref()
61        .map_or(2, |h| h.mode.group_size());
62
63    let free_assign = config
64        .overlap_cluster
65        .as_ref()
66        .is_some_and(|h| h.update_mode == OverlapUpdateMode::Free);
67
68    let n_pairs = n_replicas / 2;
69
70    let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
71    let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
72
73    let mut overlap_csd_accum: Vec<Vec<u64>> =
74        (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
75    let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
76        .map(|_| vec![0u64; n_spins + 1])
77        .collect();
78
79    let collect_top = config
80        .overlap_cluster
81        .as_ref()
82        .is_some_and(|h| h.collect_top_clusters)
83        && n_pairs > 0;
84
85    let mut top4_accum: Vec<[f64; 4]> = vec![[0.0; 4]; n_temps];
86    let mut top4_n: usize = 0;
87    let mut top4_buf: Vec<[u32; 4]> = if collect_top {
88        vec![[0u32; 4]; n_temps * n_pairs]
89    } else {
90        vec![]
91    };
92
93    let mut mags_stat = Statistics::new(n_temps, 1);
94    let mut mags2_stat = Statistics::new(n_temps, 1);
95    let mut mags4_stat = Statistics::new(n_temps, 1);
96    let mut energies_stat = Statistics::new(n_temps, 1);
97    let mut energies2_stat = Statistics::new(n_temps, 2);
98    let mut overlap_stat = Statistics::new(n_temps, 1);
99    let mut overlap2_stat = Statistics::new(n_temps, 1);
100    let mut overlap4_stat = Statistics::new(n_temps, 1);
101
102    let n_measurement_sweeps = n_sweeps.saturating_sub(warmup_sweeps);
103    let ac_max_lag = config
104        .autocorrelation_max_lag
105        .map(|k| k.min(n_measurement_sweeps / 4).max(1));
106    let mut m2_accum = ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps));
107    let mut q2_accum = if ac_max_lag.is_some() && n_pairs > 0 {
108        ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps))
109    } else {
110        None
111    };
112    let collect_ac = ac_max_lag.is_some();
113    let collect_q2_ac = q2_accum.is_some();
114    let mut m2_ac_buf = if collect_ac {
115        vec![0.0f64; n_temps]
116    } else {
117        vec![]
118    };
119    let mut q2_ac_buf = if collect_q2_ac {
120        vec![0.0f64; n_temps]
121    } else {
122        vec![]
123    };
124
125    let equil_diag = config.equilibration_diagnostic;
126    let mut equil_accum = if equil_diag {
127        Some(EquilDiagnosticAccum::new(n_temps, n_sweeps))
128    } else {
129        None
130    };
131    let mut diag_e_buf = if equil_diag {
132        vec![0.0f32; n_temps]
133    } else {
134        vec![]
135    };
136
137    let mut mags_buf = vec![0.0f32; n_temps];
138    let mut mags2_buf = vec![0.0f32; n_temps];
139    let mut mags4_buf = vec![0.0f32; n_temps];
140    let mut energies_buf = vec![0.0f32; n_temps];
141    let mut overlaps_buf = vec![0.0f32; n_temps];
142    let mut overlaps2_buf = vec![0.0f32; n_temps];
143    let mut overlaps4_buf = vec![0.0f32; n_temps];
144
145    for sweep_id in 0..n_sweeps {
146        if interrupted.load(Ordering::Relaxed) {
147            return Err("interrupted".to_string());
148        }
149        on_sweep();
150        let record = sweep_id >= warmup_sweeps;
151
152        match config.sweep_mode {
153            SweepMode::Metropolis => mcmc::sweep::metropolis_sweep(
154                lattice,
155                &mut real.spins,
156                &real.couplings,
157                &real.temperatures,
158                &real.system_ids,
159                &mut real.rngs,
160                config.sequential,
161            ),
162            SweepMode::Gibbs => mcmc::sweep::gibbs_sweep(
163                lattice,
164                &mut real.spins,
165                &real.couplings,
166                &real.temperatures,
167                &real.system_ids,
168                &mut real.rngs,
169                config.sequential,
170            ),
171        }
172
173        let do_cluster = config
174            .cluster_update
175            .as_ref()
176            .is_some_and(|c| sweep_id % c.interval == 0);
177
178        if do_cluster {
179            let cluster_cfg = config.cluster_update.as_ref().unwrap();
180            let wolff = cluster_cfg.mode == crate::config::ClusterMode::Wolff;
181            let csd_out = if cluster_cfg.collect_csd && record {
182                for buf in sw_csd_buf.iter_mut() {
183                    buf.fill(0);
184                }
185                Some(sw_csd_buf.as_mut_slice())
186            } else {
187                None
188            };
189
190            clusters::fk_update(
191                lattice,
192                &mut real.spins,
193                &real.couplings,
194                &real.temperatures,
195                &real.system_ids,
196                &mut real.rngs,
197                wolff,
198                csd_out,
199                config.sequential,
200            );
201
202            if cluster_cfg.collect_csd && record {
203                for (slot, buf) in sw_csd_buf.iter().enumerate() {
204                    let accum = &mut fk_csd_accum[slot % n_temps];
205                    for (a, &b) in accum.iter_mut().zip(buf.iter()) {
206                        *a += b;
207                    }
208                }
209            }
210        }
211
212        let pt_this_sweep = config
213            .pt_interval
214            .is_some_and(|interval| sweep_id % interval == 0);
215
216        if record || pt_this_sweep || equil_diag {
217            (real.energies, _) = spins::energy::compute_energies(
218                lattice,
219                &real.spins,
220                &real.couplings,
221                n_systems,
222                false,
223            );
224        }
225
226        if equil_diag {
227            diag_e_buf.fill(0.0);
228            #[allow(clippy::needless_range_loop)]
229            for r in 0..n_replicas {
230                let offset = r * n_temps;
231                for t in 0..n_temps {
232                    let system_id = real.system_ids[offset + t];
233                    diag_e_buf[t] += real.energies[system_id];
234                }
235            }
236            let inv = 1.0 / n_replicas as f32;
237            for v in diag_e_buf.iter_mut() {
238                *v *= inv;
239            }
240
241            let link_overlaps = if n_pairs > 0 {
242                spins::energy::compute_link_overlaps(
243                    lattice,
244                    &real.spins,
245                    &real.system_ids,
246                    n_replicas,
247                    n_temps,
248                )
249            } else {
250                vec![0.0f32; n_temps]
251            };
252
253            equil_accum
254                .as_mut()
255                .unwrap()
256                .push(&diag_e_buf, &link_overlaps);
257        }
258
259        if record {
260            for t in 0..n_temps {
261                mags_buf[t] = 0.0;
262                mags2_buf[t] = 0.0;
263                mags4_buf[t] = 0.0;
264                energies_buf[t] = 0.0;
265            }
266
267            if collect_ac {
268                m2_ac_buf.fill(0.0);
269            }
270
271            for r in 0..n_replicas {
272                let offset = r * n_temps;
273                for t in 0..n_temps {
274                    let system_id = real.system_ids[offset + t];
275                    let spin_base = system_id * n_spins;
276                    let mut sum = 0i64;
277                    for j in 0..n_spins {
278                        sum += real.spins[spin_base + j] as i64;
279                    }
280                    let mag = sum as f32 / n_spins as f32;
281                    let m2 = mag * mag;
282                    mags_buf[t] = mag;
283                    mags2_buf[t] = m2;
284                    mags4_buf[t] = m2 * m2;
285                    energies_buf[t] = real.energies[system_id];
286                }
287
288                if collect_ac {
289                    for t in 0..n_temps {
290                        m2_ac_buf[t] += mags2_buf[t] as f64;
291                    }
292                }
293
294                mags_stat.update(&mags_buf);
295                mags2_stat.update(&mags2_buf);
296                mags4_stat.update(&mags4_buf);
297                energies_stat.update(&energies_buf);
298                energies2_stat.update(&energies_buf);
299            }
300
301            if let Some(ref mut acc) = m2_accum {
302                let inv = 1.0 / n_replicas as f64;
303                for v in m2_ac_buf.iter_mut() {
304                    *v *= inv;
305                }
306                acc.push(&m2_ac_buf);
307            }
308
309            if collect_q2_ac {
310                q2_ac_buf.fill(0.0);
311            }
312
313            for pair_idx in 0..n_pairs {
314                let r_a = 2 * pair_idx;
315                let r_b = 2 * pair_idx + 1;
316                for t in 0..n_temps {
317                    overlaps_buf[t] = 0.0;
318                    overlaps2_buf[t] = 0.0;
319                    overlaps4_buf[t] = 0.0;
320                }
321
322                for t in 0..n_temps {
323                    let sys_a = real.system_ids[r_a * n_temps + t];
324                    let sys_b = real.system_ids[r_b * n_temps + t];
325                    let base_a = sys_a * n_spins;
326                    let base_b = sys_b * n_spins;
327                    let mut dot = 0i64;
328                    for j in 0..n_spins {
329                        dot += (real.spins[base_a + j] as i64) * (real.spins[base_b + j] as i64);
330                    }
331                    let q = dot as f32 / n_spins as f32;
332                    let q2 = q * q;
333                    overlaps_buf[t] = q;
334                    overlaps2_buf[t] = q2;
335                    overlaps4_buf[t] = q2 * q2;
336                }
337
338                if collect_q2_ac {
339                    for t in 0..n_temps {
340                        q2_ac_buf[t] += overlaps2_buf[t] as f64;
341                    }
342                }
343
344                overlap_stat.update(&overlaps_buf);
345                overlap2_stat.update(&overlaps2_buf);
346                overlap4_stat.update(&overlaps4_buf);
347            }
348
349            if let Some(ref mut acc) = q2_accum {
350                let inv = 1.0 / n_pairs as f64;
351                for v in q2_ac_buf.iter_mut() {
352                    *v *= inv;
353                }
354                acc.push(&q2_ac_buf);
355            }
356        }
357
358        if let Some(ref oc_cfg) = config.overlap_cluster {
359            if sweep_id % oc_cfg.interval == 0 && n_replicas >= group_size {
360                let ov_csd_out = if oc_cfg.collect_csd && record {
361                    for buf in overlap_csd_buf.iter_mut() {
362                        buf.fill(0);
363                    }
364                    Some(overlap_csd_buf.as_mut_slice())
365                } else {
366                    None
367                };
368
369                let top4_out = if collect_top && record {
370                    for slot in top4_buf.iter_mut() {
371                        *slot = [0u32; 4];
372                    }
373                    Some(top4_buf.as_mut_slice())
374                } else {
375                    None
376                };
377
378                clusters::overlap_update(
379                    lattice,
380                    &mut real.spins,
381                    &real.couplings,
382                    &real.temperatures,
383                    &real.system_ids,
384                    n_replicas,
385                    n_temps,
386                    &mut real.pair_rngs,
387                    stochastic,
388                    restrict_to_negative,
389                    overlap_wolff,
390                    free_assign,
391                    group_size,
392                    ov_csd_out,
393                    top4_out,
394                    config.sequential,
395                );
396
397                if oc_cfg.collect_csd && record {
398                    for (slot, buf) in overlap_csd_buf.iter().enumerate() {
399                        let accum = &mut overlap_csd_accum[slot / n_pairs];
400                        for (a, &b) in accum.iter_mut().zip(buf.iter()) {
401                            *a += b;
402                        }
403                    }
404                }
405
406                if collect_top && record {
407                    for t in 0..n_temps {
408                        for p in 0..n_pairs {
409                            let raw = top4_buf[t * n_pairs + p];
410                            for (k, &v) in raw.iter().enumerate() {
411                                top4_accum[t][k] += v as f64 / n_spins as f64;
412                            }
413                        }
414                    }
415                    top4_n += 1;
416                }
417            }
418        }
419
420        if pt_this_sweep {
421            if config.overlap_cluster.is_some() {
422                (real.energies, _) = spins::energy::compute_energies(
423                    lattice,
424                    &real.spins,
425                    &real.couplings,
426                    n_systems,
427                    false,
428                );
429            }
430            for r in 0..n_replicas {
431                let offset = r * n_temps;
432                let sid_slice = &mut real.system_ids[offset..offset + n_temps];
433                let temp_slice = &real.temperatures[offset..offset + n_temps];
434                mcmc::tempering::parallel_tempering(
435                    &real.energies,
436                    temp_slice,
437                    sid_slice,
438                    n_spins,
439                    &mut real.rngs[offset],
440                );
441            }
442        }
443    }
444
445    let top_cluster_sizes = if collect_top && top4_n > 0 {
446        let denom = (top4_n * n_pairs) as f64;
447        top4_accum
448            .iter()
449            .map(|arr| {
450                [
451                    arr[0] / denom,
452                    arr[1] / denom,
453                    arr[2] / denom,
454                    arr[3] / denom,
455                ]
456            })
457            .collect()
458    } else {
459        vec![]
460    };
461
462    let mags2_tau = m2_accum
463        .as_ref()
464        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
465        .unwrap_or_default();
466    let overlap2_tau = q2_accum
467        .as_ref()
468        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
469        .unwrap_or_default();
470
471    let equil_checkpoints = equil_accum.map(|acc| acc.finish()).unwrap_or_default();
472
473    Ok(SweepResult {
474        mags: mags_stat.average(),
475        mags2: mags2_stat.average(),
476        mags4: mags4_stat.average(),
477        energies: energies_stat.average(),
478        energies2: energies2_stat.average(),
479        overlap: if n_pairs > 0 {
480            overlap_stat.average()
481        } else {
482            vec![]
483        },
484        overlap2: if n_pairs > 0 {
485            overlap2_stat.average()
486        } else {
487            vec![]
488        },
489        overlap4: if n_pairs > 0 {
490            overlap4_stat.average()
491        } else {
492            vec![]
493        },
494        cluster_stats: ClusterStats {
495            fk_csd: fk_csd_accum,
496            overlap_csd: overlap_csd_accum,
497            top_cluster_sizes,
498        },
499        diagnostics: Diagnostics {
500            mags2_tau,
501            overlap2_tau,
502            equil_checkpoints,
503        },
504    })
505}
506
507/// Run the sweep loop in parallel over multiple disorder realizations.
508///
509/// Each realization is processed by [`run_sweep_loop`], then results are
510/// averaged via [`SweepResult::aggregate`]. For a single realization the
511/// call is made directly, skipping rayon thread-pool overhead.
512pub fn run_sweep_parallel(
513    lattice: &Lattice,
514    realizations: &mut [Realization],
515    n_replicas: usize,
516    n_temps: usize,
517    config: &SimConfig,
518    interrupted: &AtomicBool,
519    on_sweep: &(dyn Fn() + Sync),
520) -> Result<SweepResult, String> {
521    if realizations.len() == 1 {
522        return run_sweep_loop(
523            lattice,
524            &mut realizations[0],
525            n_replicas,
526            n_temps,
527            config,
528            interrupted,
529            on_sweep,
530        );
531    }
532
533    let results: Vec<Result<SweepResult, String>> = realizations
534        .par_iter_mut()
535        .map(|real| {
536            run_sweep_loop(
537                lattice,
538                real,
539                n_replicas,
540                n_temps,
541                config,
542                interrupted,
543                on_sweep,
544            )
545        })
546        .collect();
547
548    let results: Vec<SweepResult> = results.into_iter().collect::<Result<Vec<_>, _>>()?;
549    Ok(SweepResult::aggregate(&results))
550}