1use rand_distr::Distribution;
45#[cfg(feature = "serde-serialize")]
46use serde::{Deserialize, Serialize};
47
48use super::{
49 super::{
50 super::{
51 error::Never,
52 field::LinkMatrix,
53 lattice::{
54 Direction, LatticeCyclic, LatticeElementToIndex, LatticeLink, LatticeLinkCanonical,
55 LatticePoint,
56 },
57 su3, Complex, Real,
58 },
59 state::{LatticeState, LatticeStateDefault, LatticeStateNew},
60 },
61 delta_s_old_new_cmp, MonteCarlo, MonteCarloDefault,
62};
63
/// Metropolis-Hastings Monte Carlo method working on any [`LatticeState`].
///
/// To propose a new state, every link matrix of the current state is copied. Then
/// `number_of_update` link matrices, drawn uniformly and with replacement, are multiplied by a
/// matrix from [`su3::random_su3_close_to_unity`] with the given `spread`. Accepting or
/// rejecting the proposal is left to the default implementation of [`MonteCarloDefault`].
///
/// Build it with [`MetropolisHastings::new`] or [`Default`] (`number_of_update = 1`,
/// `spread = 0.1`).
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
pub struct MetropolisHastings {
    /// Number of link matrices modified per proposed state (checked to be `> 0` by `new`).
    number_of_update: usize,
    /// Spread given to [`su3::random_su3_close_to_unity`] (checked to be in `(0, 1)` by `new`).
    spread: Real,
}
84
85impl MetropolisHastings {
86 pub fn new(number_of_update: usize, spread: Real) -> Option<Self> {
93 if number_of_update == 0 || !(spread > 0_f64 && spread < 1_f64) {
94 return None;
95 }
96 Some(Self {
97 number_of_update,
98 spread,
99 })
100 }
101
102 getter_copy!(
103 pub const number_of_update() -> usize
105 );
106
107 getter_copy!(
108 pub const spread() -> Real
110 );
111}
112
113impl Default for MetropolisHastings {
114 fn default() -> Self {
115 Self::new(1, 0.1_f64).unwrap()
116 }
117}
118
119impl std::fmt::Display for MetropolisHastings {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 write!(
122 f,
123 "Metropolis-Hastings method with {} update and spread {}",
124 self.number_of_update(),
125 self.spread()
126 )
127 }
128}
129
130impl<State, const D: usize> MonteCarloDefault<State, D> for MetropolisHastings
131where
132 State: LatticeState<D> + LatticeStateNew<D>,
133{
134 type Error = State::Error;
135
136 fn potential_next_element<Rng>(
137 &mut self,
138 state: &State,
139 rng: &mut Rng,
140 ) -> Result<State, Self::Error>
141 where
142 Rng: rand::Rng + ?Sized,
143 {
144 let d = rand::distributions::Uniform::new(0, state.link_matrix().len());
145 let mut link_matrix = state.link_matrix().data().clone();
146 (0..self.number_of_update).for_each(|_| {
147 let pos = d.sample(rng);
148 link_matrix[pos] *= su3::random_su3_close_to_unity(self.spread, rng);
149 });
150 State::new(
151 state.lattice().clone(),
152 state.beta(),
153 LinkMatrix::new(link_matrix),
154 )
155 }
156}
157
/// Metropolis-Hastings method that also records diagnostics about its last step.
///
/// New states are proposed in the same way as by [`MetropolisHastings`]. After each step the
/// method stores the acceptance probability it used and whether the proposal was accepted.
///
/// Build it with [`MetropolisHastingsDiagnostic::new`] or [`Default`]
/// (`number_of_update = 1`, `spread = 0.1`).
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
pub struct MetropolisHastingsDiagnostic {
    /// Number of link matrices modified per proposed state (checked to be `> 0` by `new`).
    number_of_update: usize,
    /// Spread given to [`su3::random_su3_close_to_unity`], checked by `new`.
    spread: Real,
    /// Whether the state proposed during the last step was accepted (`false` after `new`).
    has_replace_last: bool,
    /// Acceptance probability used during the last step (`0` after `new`).
    prob_replace_last: Real,
}
175
176impl MetropolisHastingsDiagnostic {
177 pub fn new(number_of_update: usize, spread: Real) -> Option<Self> {
184 if number_of_update == 0 || spread <= 0_f64 || spread >= 1_f64 {
185 return None;
186 }
187 Some(Self {
188 number_of_update,
189 spread,
190 has_replace_last: false,
191 prob_replace_last: 0_f64,
192 })
193 }
194
195 pub const fn prob_replace_last(&self) -> Real {
197 self.prob_replace_last
198 }
199
200 pub const fn has_replace_last(&self) -> bool {
202 self.has_replace_last
203 }
204
205 getter_copy!(
206 pub const number_of_update() -> usize
208 );
209
210 getter_copy!(
211 pub const spread() -> Real
213 );
214}
215
216impl Default for MetropolisHastingsDiagnostic {
217 fn default() -> Self {
218 Self::new(1, 0.1_f64).unwrap()
219 }
220}
221
222impl std::fmt::Display for MetropolisHastingsDiagnostic {
223 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
224 write!(
225 f,
226 "Metropolis-Hastings method with {} update and spread {}, with diagnostics: has accepted last step {}, probability of acceptance of last step {}",
227 self.number_of_update(),
228 self.spread(),
229 self.has_replace_last(),
230 self.prob_replace_last()
231 )
232 }
233}
234
235impl<State, const D: usize> MonteCarloDefault<State, D> for MetropolisHastingsDiagnostic
236where
237 State: LatticeState<D> + LatticeStateNew<D>,
238{
239 type Error = State::Error;
240
241 fn potential_next_element<Rng>(
242 &mut self,
243 state: &State,
244 rng: &mut Rng,
245 ) -> Result<State, Self::Error>
246 where
247 Rng: rand::Rng + ?Sized,
248 {
249 let d = rand::distributions::Uniform::new(0, state.link_matrix().len());
250 let mut link_matrix = state.link_matrix().data().clone();
251 (0..self.number_of_update).for_each(|_| {
252 let pos = d.sample(rng);
253 link_matrix[pos] *= su3::random_su3_close_to_unity(self.spread, rng);
254 });
255 State::new(
256 state.lattice().clone(),
257 state.beta(),
258 LinkMatrix::new(link_matrix),
259 )
260 }
261
262 fn next_element_default<Rng>(
263 &mut self,
264 state: State,
265 rng: &mut Rng,
266 ) -> Result<State, Self::Error>
267 where
268 Rng: rand::Rng + ?Sized,
269 {
270 let potential_next = self.potential_next_element(&state, rng)?;
271 let proba = Self::probability_of_replacement(&state, &potential_next)
272 .min(1_f64)
273 .max(0_f64);
274 self.prob_replace_last = proba;
275 let d = rand::distributions::Bernoulli::new(proba).unwrap();
276 if d.sample(rng) {
277 self.has_replace_last = true;
278 Ok(potential_next)
279 }
280 else {
281 self.has_replace_last = false;
282 Ok(state)
283 }
284 }
285}
286
/// Single-link Metropolis-Hastings method for [`LatticeStateDefault`] with diagnostics.
///
/// Each step modifies one randomly chosen link. The change of action is computed with
/// [`delta_s_old_new_cmp`] from the old and proposed matrices of that link. The method owns its
/// random number generator, so it implements [`MonteCarlo`] directly. It records the acceptance
/// probability and the outcome of its last step.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
pub struct MetropolisHastingsDeltaDiagnostic<Rng: rand::Rng> {
    /// Spread given to [`su3::random_su3_close_to_unity`], checked by `new`.
    spread: Real,
    /// Whether the change proposed during the last step was accepted (`false` after `new`).
    has_replace_last: bool,
    /// Acceptance probability used during the last step (`0` after `new`).
    prob_replace_last: Real,
    /// Random number generator used for every random draw of the method.
    rng: Rng,
}
302
303impl<Rng: rand::Rng> MetropolisHastingsDeltaDiagnostic<Rng> {
304 getter_copy!(
305 pub const,
307 prob_replace_last,
308 Real
309 );
310
311 getter_copy!(
312 pub const,
314 has_replace_last,
315 bool
316 );
317
318 getter!(
319 pub const,
321 rng,
322 Rng
323 );
324
325 getter_copy!(
326 pub const spread() -> Real
328 );
329
330 pub fn rng_mut(&mut self) -> &mut Rng {
332 &mut self.rng
333 }
334
335 pub fn new(spread: Real, rng: Rng) -> Option<Self> {
342 if spread <= 0_f64 || spread >= 1_f64 {
343 return None;
344 }
345 Some(Self {
346 spread,
347 has_replace_last: false,
348 prob_replace_last: 0_f64,
349 rng,
350 })
351 }
352
353 #[allow(clippy::missing_const_for_fn)] pub fn rng_owned(self) -> Rng {
356 self.rng
357 }
358
359 #[inline]
360 fn delta_s<const D: usize>(
361 link_matrix: &LinkMatrix,
362 lattice: &LatticeCyclic<D>,
363 link: &LatticeLinkCanonical<D>,
364 new_link: &na::Matrix3<Complex>,
365 beta: Real,
366 ) -> Real {
367 let old_matrix = link_matrix
368 .matrix(&LatticeLink::from(*link), lattice)
369 .unwrap();
370 delta_s_old_new_cmp(link_matrix, lattice, link, new_link, beta, &old_matrix)
371 }
372
373 #[inline]
374 fn potential_modif<const D: usize>(
375 &mut self,
376 state: &LatticeStateDefault<D>,
377 ) -> (LatticeLinkCanonical<D>, na::Matrix3<Complex>) {
378 let d_p = rand::distributions::Uniform::new(0, state.lattice().dim());
379 let d_d = rand::distributions::Uniform::new(0, LatticeCyclic::<D>::dim_st());
380
381 let point = LatticePoint::from_fn(|_| d_p.sample(&mut self.rng));
382 let direction = Direction::positive_directions()[d_d.sample(&mut self.rng)];
383 let link = LatticeLinkCanonical::new(point, direction).unwrap();
384 let index = link.to_index(state.lattice());
385
386 let old_link_m = state.link_matrix()[index];
387 let rand_m =
388 su3::orthonormalize_matrix(&su3::random_su3_close_to_unity(self.spread, &mut self.rng));
389 let new_link = rand_m * old_link_m;
390 (link, new_link)
391 }
392
393 #[inline]
394 fn next_element_default<const D: usize>(
395 &mut self,
396 mut state: LatticeStateDefault<D>,
397 ) -> LatticeStateDefault<D> {
398 let (link, matrix) = self.potential_modif(&state);
399 let delta_s = Self::delta_s(
400 state.link_matrix(),
401 state.lattice(),
402 &link,
403 &matrix,
404 state.beta(),
405 );
406 let proba = (-delta_s).exp().min(1_f64).max(0_f64);
407 self.prob_replace_last = proba;
408 let d = rand::distributions::Bernoulli::new(proba).unwrap();
409 if d.sample(&mut self.rng) {
410 self.has_replace_last = true;
411 *state.link_mut(&link).unwrap() = matrix;
412 }
413 else {
414 self.has_replace_last = false;
415 }
416 state
417 }
418}
419
420impl<Rng: rand::Rng + Default> Default for MetropolisHastingsDeltaDiagnostic<Rng> {
421 fn default() -> Self {
422 Self::new(0.1_f64, Rng::default()).unwrap()
423 }
424}
425
426impl<Rng: rand::Rng + std::fmt::Display> std::fmt::Display
427 for MetropolisHastingsDeltaDiagnostic<Rng>
428{
429 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430 write!(
431 f,
432 "Metropolis-Hastings delta method with rng {} and spread {}, with diagnostics: has accepted last step {}, probability of acceptance of last step {}",
433 self.rng(),
434 self.spread(),
435 self.has_replace_last(),
436 self.prob_replace_last()
437 )
438 }
439}
440
441impl<Rng: rand::Rng> AsRef<Rng> for MetropolisHastingsDeltaDiagnostic<Rng> {
442 fn as_ref(&self) -> &Rng {
443 self.rng()
444 }
445}
446
447impl<Rng: rand::Rng> AsMut<Rng> for MetropolisHastingsDeltaDiagnostic<Rng> {
448 fn as_mut(&mut self) -> &mut Rng {
449 self.rng_mut()
450 }
451}
452
453impl<Rng, const D: usize> MonteCarlo<LatticeStateDefault<D>, D>
454 for MetropolisHastingsDeltaDiagnostic<Rng>
455where
456 Rng: rand::Rng,
457{
458 type Error = Never;
459
460 #[inline]
461 fn next_element(
462 &mut self,
463 state: LatticeStateDefault<D>,
464 ) -> Result<LatticeStateDefault<D>, Self::Error> {
465 Ok(self.next_element_default(state))
466 }
467}
468
#[cfg(test)]
mod test {

    use rand::SeedableRng;

    use super::*;
    use crate::simulation::state::*;

    /// Fixed seed so that every random draw, and hence every test, is reproducible.
    const SEED: u64 = 0x45_78_93_f4_4a_b0_67_f0;

    /// Checks that the acceptance probability built from the local `delta_s` matches, within
    /// `1E-8`, the one built from the difference of the full link Hamiltonians.
    #[test]
    fn test_mh_delta() {
        let mut rng = rand::rngs::StdRng::seed_from_u64(SEED);

        let size = 1_000_f64;
        let number_of_pts = 4;
        let beta = 2_f64;
        let mut simulation =
            LatticeStateDefault::<4>::new_determinist(size, beta, number_of_pts, &mut rng).unwrap();

        // The generator is moved into the method, which drives every proposal below.
        let mut mcd = MetropolisHastingsDeltaDiagnostic::new(0.01_f64, rng).unwrap();
        for _ in 0_u32..10_u32 {
            // Apply the proposed single-link change to a copy so both states can be compared.
            let mut simulation2 = simulation.clone();
            let (link, matrix) = mcd.potential_modif(&simulation);
            *simulation2.link_mut(&link).unwrap() = matrix;
            let ds = MetropolisHastingsDeltaDiagnostic::<rand::rngs::StdRng>::delta_s(
                simulation.link_matrix(),
                simulation.lattice(),
                &link,
                &matrix,
                simulation.beta(),
            );
            println!(
                "ds {}, dh {}",
                ds,
                -simulation.hamiltonian_links() + simulation2.hamiltonian_links()
            );
            let prob_of_replacement = (simulation.hamiltonian_links()
                - simulation2.hamiltonian_links())
            .exp()
            .min(1_f64)
            .max(0_f64);
            assert!(((-ds).exp().min(1_f64).max(0_f64) - prob_of_replacement).abs() < 1E-8_f64);
            // The change is always kept (no accept/reject here), so the next iteration starts
            // from the modified configuration.
            simulation = simulation2;
        }
    }
    /// Checks `Default`, constructor validation, `Display` output and the `AsRef`/`AsMut` impls.
    #[test]
    fn methods_common_traits() {
        assert_eq!(
            MetropolisHastings::default(),
            MetropolisHastings::new(1, 0.1_f64).unwrap()
        );
        assert_eq!(
            MetropolisHastingsDiagnostic::default(),
            MetropolisHastingsDiagnostic::new(1, 0.1_f64).unwrap()
        );

        // Out-of-range parameters must be rejected.
        let rng = rand::rngs::StdRng::seed_from_u64(SEED);
        assert!(MetropolisHastingsDeltaDiagnostic::new(0_f64, rng.clone()).is_none());
        assert!(MetropolisHastings::new(0, 0.1_f64).is_none());
        assert!(MetropolisHastingsDiagnostic::new(1, 0_f64).is_none());

        assert_eq!(
            MetropolisHastings::new(2, 0.2_f64).unwrap().to_string(),
            "Metropolis-Hastings method with 2 update and spread 0.2"
        );
        assert_eq!(
            MetropolisHastingsDiagnostic::new(2, 0.2_f64).unwrap().to_string(),
            "Metropolis-Hastings method with 2 update and spread 0.2, with diagnostics: has accepted last step false, probability of acceptance of last step 0"
        );
        // Only checks that the conversions exist and type-check.
        let mut mhdd = MetropolisHastingsDeltaDiagnostic::new(0.1_f64, rng).unwrap();
        let _: &rand::rngs::StdRng = mhdd.as_ref();
        let _: &mut rand::rngs::StdRng = mhdd.as_mut();
    }
}