lattice_qcd_rs/simulation/monte_carlo/
metropolis_hastings.rs

1//! Metropolis Hastings method
2//!
//! I recommend not using the methods in this module, although they may have niche uses.
//! Look at [`super::metropolis_hastings_sweep`] for a more common algorithm.
5//!
6//! # Example
7//! ```rust
8//! use lattice_qcd_rs::{
9//!     error::ImplementationError,
10//!     simulation::monte_carlo::MetropolisHastingsDeltaDiagnostic,
11//!     simulation::state::{LatticeState, LatticeStateDefault},
12//!     ComplexField,
13//! };
14//!
15//! # use std::error::Error;
16//! # fn main() -> Result<(), Box<dyn Error>> {
17//! let mut rng = rand::thread_rng();
18//!
19//! let size = 1_000_f64;
20//! let number_of_pts = 4;
21//! let beta = 2_f64;
22//! let mut simulation =
23//!     LatticeStateDefault::<4>::new_determinist(size, beta, number_of_pts, &mut rng)?;
24//!
25//! let spread_parameter = 1E-5_f64;
26//! let mut mc = MetropolisHastingsDeltaDiagnostic::new(spread_parameter, rng)
27//!     .ok_or(ImplementationError::OptionWithUnexpectedNone)?;
28//!
29//! let number_of_sims = 100;
30//! for _ in 0..number_of_sims / 10 {
31//!     for _ in 0..10 {
32//!         simulation = simulation.monte_carlo_step(&mut mc)?;
33//!     }
34//!     simulation.normalize_link_matrices(); // we renormalize all matrices back to SU(3);
35//! }
36//! let average = simulation
37//!     .average_trace_plaquette()
38//!     .ok_or(ImplementationError::OptionWithUnexpectedNone)?
39//!     .real();
40//! # Ok(())
41//! # }
42//! ```
43
44use rand_distr::Distribution;
45#[cfg(feature = "serde-serialize")]
46use serde::{Deserialize, Serialize};
47
48use super::{
49    super::{
50        super::{
51            error::Never,
52            field::LinkMatrix,
53            lattice::{
54                Direction, LatticeCyclic, LatticeElementToIndex, LatticeLink, LatticeLinkCanonical,
55                LatticePoint,
56            },
57            su3, Complex, Real,
58        },
59        state::{LatticeState, LatticeStateDefault, LatticeStateNew},
60    },
61    delta_s_old_new_cmp, MonteCarlo, MonteCarloDefault,
62};
63
64/// Metropolis Hastings algorithm. Very slow, use [`MetropolisHastingsDeltaDiagnostic`]
65/// instead when applicable.
66///
67/// This a very general method that can manage every [`LatticeState`] but the tread off
68/// is that it is much slower than
69/// a dedicated algorithm knowing the from of the hamiltonian. If you want to use your own
70/// hamiltonian I advice to implement
71/// you own method too.
72///
73/// Note that this method does not do a sweep but change random link matrix,
74/// for a sweep there is [`super::MetropolisHastingsSweep`].
75///
76/// # Example
77/// See the example of [`super::McWrapper`]
78#[derive(Debug, Clone, Copy, PartialEq)]
79#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
80pub struct MetropolisHastings {
81    number_of_update: usize,
82    spread: Real,
83}
84
85impl MetropolisHastings {
86    /// `spread` should be between 0 and 1 both not included and number_of_update should be greater
87    /// than 0. `0.1_f64` is a good choice for this parameter.
88    ///
89    /// `number_of_update` is the number of times a link matrix is randomly changed.
90    /// `spread` is the spread factor for the random matrix change
91    /// ( used in [`su3::random_su3_close_to_unity`]).
92    pub fn new(number_of_update: usize, spread: Real) -> Option<Self> {
93        if number_of_update == 0 || !(spread > 0_f64 && spread < 1_f64) {
94            return None;
95        }
96        Some(Self {
97            number_of_update,
98            spread,
99        })
100    }
101
102    getter_copy!(
103        /// Get the number of attempted updates per steps.
104        pub const number_of_update() -> usize
105    );
106
107    getter_copy!(
108        /// Get the spread parameter.
109        pub const spread() -> Real
110    );
111}
112
113impl Default for MetropolisHastings {
114    fn default() -> Self {
115        Self::new(1, 0.1_f64).unwrap()
116    }
117}
118
119impl std::fmt::Display for MetropolisHastings {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        write!(
122            f,
123            "Metropolis-Hastings method with {} update and spread {}",
124            self.number_of_update(),
125            self.spread()
126        )
127    }
128}
129
130impl<State, const D: usize> MonteCarloDefault<State, D> for MetropolisHastings
131where
132    State: LatticeState<D> + LatticeStateNew<D>,
133{
134    type Error = State::Error;
135
136    fn potential_next_element<Rng>(
137        &mut self,
138        state: &State,
139        rng: &mut Rng,
140    ) -> Result<State, Self::Error>
141    where
142        Rng: rand::Rng + ?Sized,
143    {
144        let d = rand::distributions::Uniform::new(0, state.link_matrix().len());
145        let mut link_matrix = state.link_matrix().data().clone();
146        (0..self.number_of_update).for_each(|_| {
147            let pos = d.sample(rng);
148            link_matrix[pos] *= su3::random_su3_close_to_unity(self.spread, rng);
149        });
150        State::new(
151            state.lattice().clone(),
152            state.beta(),
153            LinkMatrix::new(link_matrix),
154        )
155    }
156}
157
158/// Metropolis Hastings algorithm with diagnostics. Very slow, use [`MetropolisHastingsDeltaDiagnostic`] instead.
159///
160/// Similar to [`MetropolisHastingsDiagnostic`] but with diagnostic information.
161///
162/// Note that this method does not do a sweep but change random link matrix,
163/// for a sweep there is [`super::MetropolisHastingsSweep`].
164///
165/// # Example
166/// see example of [`super::McWrapper`]
167#[derive(Debug, Clone, Copy, PartialEq)]
168#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
169pub struct MetropolisHastingsDiagnostic {
170    number_of_update: usize,
171    spread: Real,
172    has_replace_last: bool,
173    prob_replace_last: Real,
174}
175
176impl MetropolisHastingsDiagnostic {
177    /// `spread` should be between 0 and 1 both not included and number_of_update should be greater
178    /// than 0. `0.1_f64` is a good choice for this parameter.
179    ///
180    /// `number_of_update` is the number of times a link matrix is randomly changed.
181    /// `spread` is the spread factor for the random matrix change
182    /// ( used in [`su3::random_su3_close_to_unity`]).
183    pub fn new(number_of_update: usize, spread: Real) -> Option<Self> {
184        if number_of_update == 0 || spread <= 0_f64 || spread >= 1_f64 {
185            return None;
186        }
187        Some(Self {
188            number_of_update,
189            spread,
190            has_replace_last: false,
191            prob_replace_last: 0_f64,
192        })
193    }
194
195    /// Get the last probably of acceptance of the random change.
196    pub const fn prob_replace_last(&self) -> Real {
197        self.prob_replace_last
198    }
199
200    /// Get if last step has accepted the replacement.
201    pub const fn has_replace_last(&self) -> bool {
202        self.has_replace_last
203    }
204
205    getter_copy!(
206        /// Get the number of updates per steps.
207        pub const number_of_update() -> usize
208    );
209
210    getter_copy!(
211        /// Get the spread parameter.
212        pub const spread() -> Real
213    );
214}
215
216impl Default for MetropolisHastingsDiagnostic {
217    fn default() -> Self {
218        Self::new(1, 0.1_f64).unwrap()
219    }
220}
221
222impl std::fmt::Display for MetropolisHastingsDiagnostic {
223    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224        write!(
225            f,
226            "Metropolis-Hastings method with {} update and spread {}, with diagnostics: has accepted last step {}, probability of acceptance of last step {}",
227            self.number_of_update(),
228            self.spread(),
229            self.has_replace_last(),
230            self.prob_replace_last()
231        )
232    }
233}
234
235impl<State, const D: usize> MonteCarloDefault<State, D> for MetropolisHastingsDiagnostic
236where
237    State: LatticeState<D> + LatticeStateNew<D>,
238{
239    type Error = State::Error;
240
241    fn potential_next_element<Rng>(
242        &mut self,
243        state: &State,
244        rng: &mut Rng,
245    ) -> Result<State, Self::Error>
246    where
247        Rng: rand::Rng + ?Sized,
248    {
249        let d = rand::distributions::Uniform::new(0, state.link_matrix().len());
250        let mut link_matrix = state.link_matrix().data().clone();
251        (0..self.number_of_update).for_each(|_| {
252            let pos = d.sample(rng);
253            link_matrix[pos] *= su3::random_su3_close_to_unity(self.spread, rng);
254        });
255        State::new(
256            state.lattice().clone(),
257            state.beta(),
258            LinkMatrix::new(link_matrix),
259        )
260    }
261
262    fn next_element_default<Rng>(
263        &mut self,
264        state: State,
265        rng: &mut Rng,
266    ) -> Result<State, Self::Error>
267    where
268        Rng: rand::Rng + ?Sized,
269    {
270        let potential_next = self.potential_next_element(&state, rng)?;
271        let proba = Self::probability_of_replacement(&state, &potential_next)
272            .min(1_f64)
273            .max(0_f64);
274        self.prob_replace_last = proba;
275        let d = rand::distributions::Bernoulli::new(proba).unwrap();
276        if d.sample(rng) {
277            self.has_replace_last = true;
278            Ok(potential_next)
279        }
280        else {
281            self.has_replace_last = false;
282            Ok(state)
283        }
284    }
285}
286
287/// Metropolis Hastings algorithm with diagnostics.
288///
289/// Note that this method does not do a sweep but change random link matrix,
290/// for a sweep there is [`super::MetropolisHastingsSweep`].
291///
292/// # Example
293/// see example of [`super`]
294#[derive(Clone, Debug, PartialEq)]
295#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
296pub struct MetropolisHastingsDeltaDiagnostic<Rng: rand::Rng> {
297    spread: Real,
298    has_replace_last: bool,
299    prob_replace_last: Real,
300    rng: Rng,
301}
302
303impl<Rng: rand::Rng> MetropolisHastingsDeltaDiagnostic<Rng> {
304    getter_copy!(
305        /// Get the last probably of acceptance of the random change.
306        pub const,
307        prob_replace_last,
308        Real
309    );
310
311    getter_copy!(
312        /// Get if last step has accepted the replacement.
313        pub const,
314        has_replace_last,
315        bool
316    );
317
318    getter!(
319        /// Get a ref to the rng.
320        pub const,
321        rng,
322        Rng
323    );
324
325    getter_copy!(
326        /// Get the spread parameter.
327        pub const spread() -> Real
328    );
329
330    /// Get a mutable reference to the rng.
331    pub fn rng_mut(&mut self) -> &mut Rng {
332        &mut self.rng
333    }
334
335    /// `spread` should be between 0 and 1 both not included and number_of_update should be greater
336    /// than 0.
337    ///
338    /// `number_of_update` is the number of times a link matrix is randomly changed.
339    /// `spread` is the spread factor for the random matrix change
340    /// ( used in [`su3::random_su3_close_to_unity`]).
341    pub fn new(spread: Real, rng: Rng) -> Option<Self> {
342        if spread <= 0_f64 || spread >= 1_f64 {
343            return None;
344        }
345        Some(Self {
346            spread,
347            has_replace_last: false,
348            prob_replace_last: 0_f64,
349            rng,
350        })
351    }
352
353    /// Absorbs self and return the RNG as owned. It essentially deconstruct the structure.
354    #[allow(clippy::missing_const_for_fn)] // false positive
355    pub fn rng_owned(self) -> Rng {
356        self.rng
357    }
358
359    #[inline]
360    fn delta_s<const D: usize>(
361        link_matrix: &LinkMatrix,
362        lattice: &LatticeCyclic<D>,
363        link: &LatticeLinkCanonical<D>,
364        new_link: &na::Matrix3<Complex>,
365        beta: Real,
366    ) -> Real {
367        let old_matrix = link_matrix
368            .matrix(&LatticeLink::from(*link), lattice)
369            .unwrap();
370        delta_s_old_new_cmp(link_matrix, lattice, link, new_link, beta, &old_matrix)
371    }
372
373    #[inline]
374    fn potential_modif<const D: usize>(
375        &mut self,
376        state: &LatticeStateDefault<D>,
377    ) -> (LatticeLinkCanonical<D>, na::Matrix3<Complex>) {
378        let d_p = rand::distributions::Uniform::new(0, state.lattice().dim());
379        let d_d = rand::distributions::Uniform::new(0, LatticeCyclic::<D>::dim_st());
380
381        let point = LatticePoint::from_fn(|_| d_p.sample(&mut self.rng));
382        let direction = Direction::positive_directions()[d_d.sample(&mut self.rng)];
383        let link = LatticeLinkCanonical::new(point, direction).unwrap();
384        let index = link.to_index(state.lattice());
385
386        let old_link_m = state.link_matrix()[index];
387        let rand_m =
388            su3::orthonormalize_matrix(&su3::random_su3_close_to_unity(self.spread, &mut self.rng));
389        let new_link = rand_m * old_link_m;
390        (link, new_link)
391    }
392
393    #[inline]
394    fn next_element_default<const D: usize>(
395        &mut self,
396        mut state: LatticeStateDefault<D>,
397    ) -> LatticeStateDefault<D> {
398        let (link, matrix) = self.potential_modif(&state);
399        let delta_s = Self::delta_s(
400            state.link_matrix(),
401            state.lattice(),
402            &link,
403            &matrix,
404            state.beta(),
405        );
406        let proba = (-delta_s).exp().min(1_f64).max(0_f64);
407        self.prob_replace_last = proba;
408        let d = rand::distributions::Bernoulli::new(proba).unwrap();
409        if d.sample(&mut self.rng) {
410            self.has_replace_last = true;
411            *state.link_mut(&link).unwrap() = matrix;
412        }
413        else {
414            self.has_replace_last = false;
415        }
416        state
417    }
418}
419
420impl<Rng: rand::Rng + Default> Default for MetropolisHastingsDeltaDiagnostic<Rng> {
421    fn default() -> Self {
422        Self::new(0.1_f64, Rng::default()).unwrap()
423    }
424}
425
426impl<Rng: rand::Rng + std::fmt::Display> std::fmt::Display
427    for MetropolisHastingsDeltaDiagnostic<Rng>
428{
429    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430        write!(
431            f,
432            "Metropolis-Hastings delta method with rng {} and spread {}, with diagnostics: has accepted last step {}, probability of acceptance of last step {}",
433            self.rng(),
434            self.spread(),
435            self.has_replace_last(),
436            self.prob_replace_last()
437        )
438    }
439}
440
441impl<Rng: rand::Rng> AsRef<Rng> for MetropolisHastingsDeltaDiagnostic<Rng> {
442    fn as_ref(&self) -> &Rng {
443        self.rng()
444    }
445}
446
447impl<Rng: rand::Rng> AsMut<Rng> for MetropolisHastingsDeltaDiagnostic<Rng> {
448    fn as_mut(&mut self) -> &mut Rng {
449        self.rng_mut()
450    }
451}
452
453impl<Rng, const D: usize> MonteCarlo<LatticeStateDefault<D>, D>
454    for MetropolisHastingsDeltaDiagnostic<Rng>
455where
456    Rng: rand::Rng,
457{
458    type Error = Never;
459
460    #[inline]
461    fn next_element(
462        &mut self,
463        state: LatticeStateDefault<D>,
464    ) -> Result<LatticeStateDefault<D>, Self::Error> {
465        Ok(self.next_element_default(state))
466    }
467}
468
#[cfg(test)]
mod test {

    use rand::SeedableRng;

    use super::*;
    use crate::simulation::state::*;

    /// Fixed seed so that the tests are reproducible.
    const SEED: u64 = 0x45_78_93_f4_4a_b0_67_f0;

    /// Checks that the acceptance probability computed from `delta_s` matches the one
    /// computed from the difference of `hamiltonian_links` before and after the change.
    #[test]
    fn test_mh_delta() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(SEED);

        let size = 1_000_f64;
        let number_of_pts = 4;
        let beta = 2_f64;
        let mut current =
            LatticeStateDefault::<4>::new_determinist(size, beta, number_of_pts, &mut rng).unwrap();

        let mut method = MetropolisHastingsDeltaDiagnostic::new(0.01_f64, rng).unwrap();
        for _ in 0_u32..10_u32 {
            // Propose a change and apply it unconditionally to a copy of the state.
            let (link, matrix) = method.potential_modif(&current);
            let mut modified = current.clone();
            *modified.link_mut(&link).unwrap() = matrix;

            let ds = MetropolisHastingsDeltaDiagnostic::<rand::rngs::StdRng>::delta_s(
                current.link_matrix(),
                current.lattice(),
                &link,
                &matrix,
                current.beta(),
            );
            let h_current = current.hamiltonian_links();
            let h_modified = modified.hamiltonian_links();
            println!("ds {}, dh {}", ds, -h_current + h_modified);

            let prob_from_hamiltonian = (h_current - h_modified).exp().min(1_f64).max(0_f64);
            let prob_from_delta = (-ds).exp().min(1_f64).max(0_f64);
            assert!((prob_from_delta - prob_from_hamiltonian).abs() < 1E-8_f64);
            current = modified;
        }
    }

    /// Checks `Default`, constructor validation, `Display`, `AsRef` and `AsMut`.
    #[test]
    fn methods_common_traits() {
        // `Default` agrees with the documented default parameters.
        let default_mh = MetropolisHastings::default();
        assert_eq!(default_mh, MetropolisHastings::new(1, 0.1_f64).unwrap());
        let default_mhd = MetropolisHastingsDiagnostic::default();
        assert_eq!(
            default_mhd,
            MetropolisHastingsDiagnostic::new(1, 0.1_f64).unwrap()
        );

        // Invalid parameters are rejected.
        let rng = rand::rngs::StdRng::seed_from_u64(SEED);
        assert!(MetropolisHastingsDeltaDiagnostic::new(0_f64, rng.clone()).is_none());
        assert!(MetropolisHastings::new(0, 0.1_f64).is_none());
        assert!(MetropolisHastingsDiagnostic::new(1, 0_f64).is_none());

        // `Display` output.
        let mh_text = MetropolisHastings::new(2, 0.2_f64).unwrap().to_string();
        assert_eq!(
            mh_text,
            "Metropolis-Hastings method with 2 update and spread 0.2"
        );
        let mhd_text = MetropolisHastingsDiagnostic::new(2, 0.2_f64)
            .unwrap()
            .to_string();
        assert_eq!(
            mhd_text,
            "Metropolis-Hastings method with 2 update and spread 0.2, with diagnostics: has accepted last step false, probability of acceptance of last step 0"
        );

        // `AsRef` / `AsMut` expose the rng.
        let mut mhdd = MetropolisHastingsDeltaDiagnostic::new(0.1_f64, rng).unwrap();
        let _: &rand::rngs::StdRng = mhdd.as_ref();
        let _: &mut rand::rngs::StdRng = mhdd.as_mut();
    }
}