Skip to main content

spin_sim/simulation/
mod.rs

1pub mod realization;
2
3pub use realization::Realization;
4
5use std::sync::atomic::{AtomicBool, Ordering};
6
7use crate::config::{SimConfig, SweepMode};
8use crate::geometry::Lattice;
9use crate::statistics::{
10    sokal_tau, AutocorrAccum, ClusterSnapshot, ClusterStats, Diagnostics, EquilDiagnosticAccum,
11    OverlapAccum, Statistics, SweepResult,
12};
13use crate::{clusters, mcmc, spins};
14use rayon::prelude::*;
15use validator::Validate;
16
17/// Run the full Monte Carlo loop (warmup + measurement) for one [`Realization`].
18///
19/// Each sweep consists of:
20/// 1. A full single-spin pass (`sweep_mode`: Metropolis or Gibbs)
21/// 2. An optional cluster update (every `cluster_update.interval` sweeps)
22/// 3. Measurement (after `warmup_sweeps`)
23/// 4. Optional overlap cluster move (every `overlap_cluster.interval` sweeps, requires `n_replicas ≥ 2`)
24/// 5. Optional parallel tempering (every `pt_interval` sweeps)
25///
26/// `on_sweep` is called once per sweep (useful for progress bars).
27#[allow(clippy::too_many_arguments)]
28pub fn run_sweep_loop(
29    lattice: &Lattice,
30    real: &mut Realization,
31    n_replicas: usize,
32    n_temps: usize,
33    config: &SimConfig,
34    interrupted: &AtomicBool,
35    on_sweep: &(dyn Fn() + Sync),
36    realization_idx: usize,
37) -> Result<SweepResult, String> {
38    config.validate().map_err(|e| format!("{e}"))?;
39
40    let n_spins = lattice.n_spins;
41    let n_systems = n_replicas * n_temps;
42    let n_sweeps = config.n_sweeps;
43    let warmup_sweeps = config.warmup_sweeps;
44
45    let n_modes = config.overlap_cluster.as_ref().map_or(0, |h| h.modes.len());
46
47    if let Some(ref oc_cfg) = config.overlap_cluster {
48        let max_gs = oc_cfg.max_group_size();
49        if n_replicas < max_gs {
50            return Err(format!(
51                "overlap cluster requires n_replicas >= max group_size ({n_replicas} < {max_gs})"
52            ));
53        }
54    }
55
56    let n_pairs = n_replicas / 2;
57
58    let mut fk_csd_accum: Vec<Vec<u64>> = (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect();
59    let mut sw_csd_buf: Vec<Vec<u64>> = (0..n_systems).map(|_| vec![0u64; n_spins + 1]).collect();
60
61    let mut overlap_csd_accum: Vec<Vec<Vec<u64>>> = (0..n_modes)
62        .map(|_| (0..n_temps).map(|_| vec![0u64; n_spins + 1]).collect())
63        .collect();
64    let mut overlap_csd_buf: Vec<Vec<u64>> = (0..n_temps * n_pairs)
65        .map(|_| vec![0u64; n_spins + 1])
66        .collect();
67
68    let collect_top = config
69        .overlap_cluster
70        .as_ref()
71        .is_some_and(|h| h.collect_stats)
72        && n_pairs > 0;
73
74    let mut top4_accum: Vec<Vec<[f64; 4]>> =
75        (0..n_modes).map(|_| vec![[0.0; 4]; n_temps]).collect();
76    let mut top4_n: Vec<usize> = vec![0; n_modes];
77    let mut top4_buf: Vec<[u32; 4]> = if collect_top {
78        vec![[0u32; 4]; n_temps * n_pairs]
79    } else {
80        vec![]
81    };
82
83    let mut overlap_call_count: usize = 0;
84
85    let snapshot_interval = if realization_idx == 0 {
86        config
87            .overlap_cluster
88            .as_ref()
89            .and_then(|oc| oc.snapshot_interval)
90    } else {
91        None
92    };
93    let n_pair_slots = n_temps * n_pairs;
94    let mut snap_buf: Vec<Vec<u32>> = if snapshot_interval.is_some() {
95        (0..n_pair_slots)
96            .map(|_| Vec::with_capacity(n_spins))
97            .collect()
98    } else {
99        vec![]
100    };
101    let mut blue_snap_buf: Vec<Vec<u32>> = if snapshot_interval.is_some() {
102        (0..n_pair_slots)
103            .map(|_| Vec::with_capacity(n_spins))
104            .collect()
105    } else {
106        vec![]
107    };
108    let mut spin_snap_buf: Vec<Vec<[Vec<i8>; 2]>> = if snapshot_interval.is_some() {
109        (0..n_pair_slots).map(|_| Vec::new()).collect()
110    } else {
111        vec![]
112    };
113    let mut sid_snap_buf: Vec<Vec<[usize; 2]>> = if snapshot_interval.is_some() {
114        (0..n_pair_slots).map(|_| Vec::new()).collect()
115    } else {
116        vec![]
117    };
118    let mut cluster_snapshots: Vec<ClusterSnapshot> = Vec::new();
119
120    let mut mags_stat = Statistics::new(n_temps, 1);
121    let mut mags2_stat = Statistics::new(n_temps, 1);
122    let mut mags4_stat = Statistics::new(n_temps, 1);
123    let mut energies_stat = Statistics::new(n_temps, 1);
124    let mut energies2_stat = Statistics::new(n_temps, 2);
125    let n_measurement_sweeps = n_sweeps.saturating_sub(warmup_sweeps);
126    let ac_max_lag = config
127        .autocorrelation_max_lag
128        .map(|k| k.min(n_measurement_sweeps / 4).max(1));
129    let mut m2_accum = ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps));
130    let mut q2_accum = if ac_max_lag.is_some() && n_pairs > 0 {
131        ac_max_lag.map(|k| AutocorrAccum::new(k, n_temps))
132    } else {
133        None
134    };
135    let collect_ac = ac_max_lag.is_some();
136    let collect_q2_ac = q2_accum.is_some();
137    let mut m2_ac_buf = if collect_ac {
138        vec![0.0f64; n_temps]
139    } else {
140        vec![]
141    };
142
143    let equil_diag = config.equilibration_diagnostic;
144    let mut equil_accum = if equil_diag {
145        Some(EquilDiagnosticAccum::new(n_temps, n_sweeps))
146    } else {
147        None
148    };
149    let mut diag_e_buf = if equil_diag {
150        vec![0.0f32; n_temps]
151    } else {
152        vec![]
153    };
154
155    let mut ov_accum = OverlapAccum::new(
156        n_temps,
157        n_spins,
158        n_pairs,
159        lattice.n_neighbors,
160        equil_diag,
161        collect_q2_ac,
162    );
163
164    let mut mags_buf = vec![0.0f32; n_temps];
165    let mut mags2_buf = vec![0.0f32; n_temps];
166    let mut mags4_buf = vec![0.0f32; n_temps];
167    let mut energies_buf = vec![0.0f32; n_temps];
168
169    for sweep_id in 0..n_sweeps {
170        if interrupted.load(Ordering::Relaxed) {
171            return Err("interrupted".to_string());
172        }
173        on_sweep();
174        let record = sweep_id >= warmup_sweeps;
175
176        match config.sweep_mode {
177            SweepMode::Metropolis => mcmc::sweep::metropolis_sweep(
178                lattice,
179                &mut real.spins,
180                &real.couplings,
181                &real.temperatures,
182                &real.system_ids,
183                &mut real.rngs,
184                config.sequential,
185            ),
186            SweepMode::Gibbs => mcmc::sweep::gibbs_sweep(
187                lattice,
188                &mut real.spins,
189                &real.couplings,
190                &real.temperatures,
191                &real.system_ids,
192                &mut real.rngs,
193                config.sequential,
194            ),
195        }
196
197        let do_cluster = config
198            .cluster_update
199            .as_ref()
200            .is_some_and(|c| sweep_id % c.interval == 0);
201
202        if do_cluster {
203            let cluster_cfg = config.cluster_update.as_ref().unwrap();
204            let wolff = cluster_cfg.mode == crate::config::ClusterMode::Wolff;
205            let csd_out = if cluster_cfg.collect_stats && record {
206                for buf in sw_csd_buf.iter_mut() {
207                    buf.fill(0);
208                }
209                Some(sw_csd_buf.as_mut_slice())
210            } else {
211                None
212            };
213
214            clusters::fk_update(
215                lattice,
216                &mut real.spins,
217                &real.couplings,
218                &real.temperatures,
219                &real.system_ids,
220                &mut real.rngs,
221                wolff,
222                csd_out,
223                config.sequential,
224            );
225
226            if cluster_cfg.collect_stats && record {
227                for (slot, buf) in sw_csd_buf.iter().enumerate() {
228                    let accum = &mut fk_csd_accum[slot % n_temps];
229                    for (a, &b) in accum.iter_mut().zip(buf.iter()) {
230                        *a += b;
231                    }
232                }
233            }
234        }
235
236        let pt_this_sweep = config
237            .pt_interval
238            .is_some_and(|interval| sweep_id % interval == 0);
239
240        if record || pt_this_sweep || equil_diag {
241            (real.energies, _) = spins::energy::compute_energies(
242                lattice,
243                &real.spins,
244                &real.couplings,
245                n_systems,
246                false,
247            );
248        }
249
250        if equil_diag {
251            diag_e_buf.fill(0.0);
252            #[allow(clippy::needless_range_loop)]
253            for r in 0..n_replicas {
254                let offset = r * n_temps;
255                for t in 0..n_temps {
256                    let system_id = real.system_ids[offset + t];
257                    diag_e_buf[t] += real.energies[system_id];
258                }
259            }
260            let inv = 1.0 / n_replicas as f32;
261            for v in diag_e_buf.iter_mut() {
262                *v *= inv;
263            }
264        }
265
266        if (equil_diag || record) && n_pairs > 0 {
267            ov_accum.collect(lattice, &real.spins, &real.system_ids, record);
268        }
269
270        if equil_diag {
271            if n_pairs > 0 {
272                equil_accum
273                    .as_mut()
274                    .unwrap()
275                    .push(&diag_e_buf, &ov_accum.diag_ql_buf);
276            } else {
277                let zeros = vec![0.0f32; n_temps];
278                equil_accum.as_mut().unwrap().push(&diag_e_buf, &zeros);
279            }
280        }
281
282        if record {
283            for t in 0..n_temps {
284                mags_buf[t] = 0.0;
285                mags2_buf[t] = 0.0;
286                mags4_buf[t] = 0.0;
287                energies_buf[t] = 0.0;
288            }
289
290            if collect_ac {
291                m2_ac_buf.fill(0.0);
292            }
293
294            for r in 0..n_replicas {
295                let offset = r * n_temps;
296                for t in 0..n_temps {
297                    let system_id = real.system_ids[offset + t];
298                    let spin_base = system_id * n_spins;
299                    let mut sum = 0i64;
300                    for j in 0..n_spins {
301                        sum += real.spins[spin_base + j] as i64;
302                    }
303                    let mag = sum as f32 / n_spins as f32;
304                    let m2 = mag * mag;
305                    mags_buf[t] = mag;
306                    mags2_buf[t] = m2;
307                    mags4_buf[t] = m2 * m2;
308                    energies_buf[t] = real.energies[system_id];
309                }
310
311                if collect_ac {
312                    for t in 0..n_temps {
313                        m2_ac_buf[t] += mags2_buf[t] as f64;
314                    }
315                }
316
317                mags_stat.update(&mags_buf);
318                mags2_stat.update(&mags2_buf);
319                mags4_stat.update(&mags4_buf);
320                energies_stat.update(&energies_buf);
321                energies2_stat.update(&energies_buf);
322            }
323
324            if let Some(ref mut acc) = m2_accum {
325                let inv = 1.0 / n_replicas as f64;
326                for v in m2_ac_buf.iter_mut() {
327                    *v *= inv;
328                }
329                acc.push(&m2_ac_buf);
330            }
331
332            if let Some(ref mut acc) = q2_accum {
333                let inv = 1.0 / n_pairs as f64;
334                for v in ov_accum.q2_ac_buf.iter_mut() {
335                    *v *= inv;
336                }
337                acc.push(&ov_accum.q2_ac_buf);
338            }
339        }
340
341        if let Some(ref oc_cfg) = config.overlap_cluster {
342            if sweep_id % oc_cfg.interval == 0 {
343                let mode_idx = overlap_call_count % n_modes;
344                let mode = &oc_cfg.modes[mode_idx];
345
346                let ov_csd_out = if oc_cfg.collect_stats && record {
347                    for buf in overlap_csd_buf.iter_mut() {
348                        buf.fill(0);
349                    }
350                    Some(overlap_csd_buf.as_mut_slice())
351                } else {
352                    None
353                };
354
355                let top4_out = if collect_top && record {
356                    for slot in top4_buf.iter_mut() {
357                        *slot = [0u32; 4];
358                    }
359                    Some(top4_buf.as_mut_slice())
360                } else {
361                    None
362                };
363
364                let take_snapshot =
365                    snapshot_interval.is_some_and(|si| sweep_id % si == 0) && record;
366
367                let is_cmr = matches!(mode, crate::config::OverlapClusterBuildMode::Cmr);
368
369                let snap = if take_snapshot {
370                    for buf in spin_snap_buf.iter_mut() {
371                        buf.clear();
372                    }
373                    for buf in sid_snap_buf.iter_mut() {
374                        buf.clear();
375                    }
376                    Some(snap_buf.as_mut_slice())
377                } else {
378                    None
379                };
380                let blue_snap = if take_snapshot && is_cmr {
381                    Some(blue_snap_buf.as_mut_slice())
382                } else {
383                    None
384                };
385                let spin_snap = if take_snapshot {
386                    Some(spin_snap_buf.as_mut_slice())
387                } else {
388                    None
389                };
390                let sid_snap = if take_snapshot {
391                    Some(sid_snap_buf.as_mut_slice())
392                } else {
393                    None
394                };
395
396                clusters::overlap_update(
397                    lattice,
398                    &mut real.spins,
399                    &real.couplings,
400                    &real.temperatures,
401                    &real.system_ids,
402                    n_replicas,
403                    n_temps,
404                    &mut real.pair_rngs,
405                    mode,
406                    oc_cfg.cluster_mode,
407                    ov_csd_out,
408                    top4_out,
409                    config.sequential,
410                    snap,
411                    blue_snap,
412                    spin_snap,
413                    sid_snap,
414                );
415
416                if take_snapshot {
417                    let ids: Vec<Vec<u32>> = (0..n_temps)
418                        .map(|t| snap_buf[t * n_pairs].clone())
419                        .collect();
420                    let blue = if is_cmr {
421                        Some(
422                            (0..n_temps)
423                                .map(|t| blue_snap_buf[t * n_pairs].clone())
424                                .collect(),
425                        )
426                    } else {
427                        None
428                    };
429                    let spins: Vec<[Vec<i8>; 2]> = (0..n_temps)
430                        .map(|t| {
431                            spin_snap_buf[t * n_pairs]
432                                .first()
433                                .cloned()
434                                .unwrap_or_else(|| [vec![], vec![]])
435                        })
436                        .collect();
437                    let sids: Vec<[usize; 2]> = (0..n_temps)
438                        .map(|t| sid_snap_buf[t * n_pairs].first().copied().unwrap_or([0, 0]))
439                        .collect();
440                    cluster_snapshots.push(ClusterSnapshot {
441                        sweep_id,
442                        mode_idx,
443                        cluster_ids: ids,
444                        blue_ids: blue,
445                        spins,
446                        system_ids: sids,
447                    });
448                }
449
450                if oc_cfg.collect_stats && record {
451                    for (slot, buf) in overlap_csd_buf.iter().enumerate() {
452                        let accum = &mut overlap_csd_accum[mode_idx][slot / n_pairs];
453                        for (a, &b) in accum.iter_mut().zip(buf.iter()) {
454                            *a += b;
455                        }
456                    }
457                }
458
459                if collect_top && record {
460                    for t in 0..n_temps {
461                        for p in 0..n_pairs {
462                            let raw = top4_buf[t * n_pairs + p];
463                            for (k, &v) in raw.iter().enumerate() {
464                                top4_accum[mode_idx][t][k] += v as f64 / n_spins as f64;
465                            }
466                        }
467                    }
468                    top4_n[mode_idx] += 1;
469                }
470
471                overlap_call_count += 1;
472            }
473        }
474
475        if pt_this_sweep {
476            if config.overlap_cluster.is_some() {
477                (real.energies, _) = spins::energy::compute_energies(
478                    lattice,
479                    &real.spins,
480                    &real.couplings,
481                    n_systems,
482                    false,
483                );
484            }
485            for r in 0..n_replicas {
486                let offset = r * n_temps;
487                let sid_slice = &mut real.system_ids[offset..offset + n_temps];
488                let temp_slice = &real.temperatures[offset..offset + n_temps];
489                mcmc::tempering::parallel_tempering(
490                    &real.energies,
491                    temp_slice,
492                    sid_slice,
493                    n_spins,
494                    &mut real.rngs[offset],
495                );
496            }
497        }
498    }
499
500    let top_cluster_sizes: Vec<Vec<[f64; 4]>> = if collect_top {
501        top4_accum
502            .iter()
503            .zip(top4_n.iter())
504            .map(|(mode_accum, &count)| {
505                if count == 0 {
506                    return vec![];
507                }
508                let denom = (count * n_pairs) as f64;
509                mode_accum
510                    .iter()
511                    .map(|arr| {
512                        [
513                            arr[0] / denom,
514                            arr[1] / denom,
515                            arr[2] / denom,
516                            arr[3] / denom,
517                        ]
518                    })
519                    .collect()
520            })
521            .collect()
522    } else {
523        vec![]
524    };
525
526    let mags2_tau = m2_accum
527        .as_ref()
528        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
529        .unwrap_or_default();
530    let overlap2_tau = q2_accum
531        .as_ref()
532        .map(|acc| acc.finish().iter().map(|g| sokal_tau(g)).collect())
533        .unwrap_or_default();
534
535    let equil_checkpoints = equil_accum.map(|acc| acc.finish()).unwrap_or_default();
536
537    Ok(SweepResult {
538        mags: mags_stat.average(),
539        mags2: mags2_stat.average(),
540        mags4: mags4_stat.average(),
541        energies: energies_stat.average(),
542        energies2: energies2_stat.average(),
543        overlap_stats: ov_accum.finish(),
544        cluster_stats: ClusterStats {
545            fk_csd: fk_csd_accum,
546            overlap_csd: overlap_csd_accum,
547            top_cluster_sizes,
548        },
549        diagnostics: Diagnostics {
550            mags2_tau,
551            overlap2_tau,
552            equil_checkpoints,
553        },
554        cluster_snapshots,
555    })
556}
557
558/// Run the sweep loop in parallel over multiple disorder realizations.
559///
560/// Each realization is processed by [`run_sweep_loop`], then results are
561/// averaged via [`SweepResult::aggregate`]. For a single realization the
562/// call is made directly, skipping rayon thread-pool overhead.
563pub fn run_sweep_parallel(
564    lattice: &Lattice,
565    realizations: &mut [Realization],
566    n_replicas: usize,
567    n_temps: usize,
568    config: &SimConfig,
569    interrupted: &AtomicBool,
570    on_sweep: &(dyn Fn() + Sync),
571) -> Result<SweepResult, String> {
572    if realizations.len() == 1 {
573        return run_sweep_loop(
574            lattice,
575            &mut realizations[0],
576            n_replicas,
577            n_temps,
578            config,
579            interrupted,
580            on_sweep,
581            0,
582        );
583    }
584
585    let results: Vec<Result<SweepResult, String>> = realizations
586        .par_iter_mut()
587        .enumerate()
588        .map(|(idx, real)| {
589            run_sweep_loop(
590                lattice,
591                real,
592                n_replicas,
593                n_temps,
594                config,
595                interrupted,
596                on_sweep,
597                idx,
598            )
599        })
600        .collect();
601
602    let mut results: Vec<SweepResult> = results.into_iter().collect::<Result<Vec<_>, _>>()?;
603    let snapshots = std::mem::take(&mut results[0].cluster_snapshots);
604    let mut agg = SweepResult::aggregate(&results);
605    agg.cluster_snapshots = snapshots;
606    Ok(agg)
607}