use crate::Outcome;
use {
crate::{HasRng, MarkovChain},
num_traits::AsPrimitive,
rand::Rng,
std::marker::PhantomData,
};
#[cfg(feature = "serde_support")]
use serde::{Deserialize, Serialize};
/// Errors reported when constructing a Metropolis sampler.
///
/// The variant order is kept stable because it may be part of a serialized
/// representation when the `serde_support` feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub enum MetropolisError {
    /// Invalid state.
    ///
    /// NOTE(review): not produced by the constructors visible in this file —
    /// confirm where (and whether) this variant is used.
    InvalidState,
    /// A NaN was encountered: the initial energy (cast to `f64`), `m_beta`,
    /// or the temperature was NaN.
    NAN,
    /// `m_beta` was infinite, e.g. because a temperature of (±)0 was given.
    InfinitBeta,
}
/// Metropolis Monte Carlo sampler wrapping a Markov-chain ensemble.
///
/// Each Metropolis step proposes `step_size` Markov steps on `ensemble`,
/// evaluates the new energy through a user-supplied closure, and accepts
/// or undoes the steps based on `m_beta` and the energy difference.
#[derive(Clone)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct Metropolis<E, R, S, Res, T> {
// The ensemble (Markov chain) being sampled.
ensemble: E,
// Random number generator used for the acceptance test.
rng: R,
// Energy of `ensemble` in its current state; replaced when a step is accepted.
energy: T,
// Minus beta, i.e. `-1 / temperature`.
m_beta: f64,
// Number of Markov steps proposed per Metropolis step.
step_size: usize,
// Number of Metropolis steps attempted so far (wraps on overflow).
counter: usize,
// Reusable buffer holding the Markov steps of the current proposal,
// so they can be undone if the proposal is rejected.
steps: Vec<S>,
// Ties the result type `Res` of the Markov chain to this struct.
marker_res: PhantomData<Res>,
}
impl<R, E, S, Res, T> Metropolis<E, R, S, Res, T>
where
    T: Copy + AsPrimitive<f64>,
{
    /// Minus beta, i.e. `-1 / temperature`.
    pub fn m_beta(&self) -> f64 {
        self.m_beta
    }

    /// Stores `m_beta` (minus beta) as given.
    ///
    /// Unlike the constructors, this performs no NaN or infinity check.
    pub fn set_m_beta(&mut self, m_beta: f64) {
        self.m_beta = m_beta;
    }

    /// Sets the temperature by storing `m_beta = -1 / temperature`.
    ///
    /// No validation is performed; a temperature of zero yields an infinite `m_beta`.
    pub fn set_temperature(&mut self, temperature: f64) {
        let minus_beta = -1.0 / temperature;
        self.set_m_beta(minus_beta);
    }

    /// Energy currently associated with the ensemble.
    pub fn energy(&self) -> T {
        self.energy
    }

    /// Overwrites the stored energy.
    ///
    /// # Safety
    /// The caller must ensure `energy` is the energy of the ensemble's current
    /// state. The stored value is used as the "old" energy in every following
    /// acceptance test.
    pub unsafe fn set_energy(&mut self, energy: T) {
        self.energy = energy;
    }

    /// Shared access to the sampled ensemble.
    pub fn ensemble(&self) -> &E {
        &self.ensemble
    }

    /// Mutable access to the sampled ensemble.
    ///
    /// # Safety
    /// Changes made through the returned reference are not reflected in the
    /// stored energy. The caller must keep [`energy`](Self::energy) consistent
    /// with the ensemble, e.g. via [`set_energy`](Self::set_energy).
    pub unsafe fn ensemble_mut(&mut self) -> &mut E {
        &mut self.ensemble
    }

    /// Number of Metropolis steps attempted so far.
    pub fn counter(&self) -> usize {
        self.counter
    }

    /// Sets the step counter back to zero.
    pub fn reset_counter(&mut self) {
        self.counter = 0;
    }

    /// Number of Markov steps proposed per Metropolis step.
    pub fn step_size(&self) -> usize {
        self.step_size
    }

    /// Changes the number of Markov steps proposed per Metropolis step.
    ///
    /// A `step_size` of zero is refused: the stored value is left untouched
    /// and `Outcome::Failure` is returned. Otherwise `Outcome::Success`.
    pub fn set_step_size(&mut self, step_size: usize) -> Outcome {
        if step_size == 0 {
            return Outcome::Failure;
        }
        self.step_size = step_size;
        Outcome::Success
    }
}
impl<E, R, S, Res, T> Metropolis<E, R, S, Res, T>
where
    R: Rng,
    E: MarkovChain<S, Res>,
    T: Copy + AsPrimitive<f64>,
{
    /// Creates a sampler from `m_beta` (minus beta, i.e. `-1 / temperature`).
    ///
    /// * `rng` – random number generator used for the acceptance test
    /// * `ensemble` – the Markov chain to sample
    /// * `energy` – energy of `ensemble` in its current state (assumed correct)
    /// * `m_beta` – minus the inverse temperature
    /// * `step_size` – number of Markov steps proposed per Metropolis step
    ///
    /// # Errors
    /// * [`MetropolisError::NAN`] if `energy` (cast to `f64`) or `m_beta` is NaN
    /// * [`MetropolisError::InfinitBeta`] if `m_beta` is infinite
    ///
    /// NOTE(review): unlike `set_step_size`, a `step_size` of 0 is not rejected here.
    pub fn new_from_m_beta(
        rng: R,
        ensemble: E,
        energy: T,
        m_beta: f64,
        step_size: usize,
    ) -> Result<Self, MetropolisError> {
        if (energy.as_()).is_nan() || m_beta.is_nan() {
            return Err(MetropolisError::NAN);
        }
        if !m_beta.is_finite() {
            return Err(MetropolisError::InfinitBeta);
        }
        let steps = Vec::with_capacity(step_size);
        Ok(Self {
            ensemble,
            rng,
            energy,
            m_beta,
            steps,
            marker_res: PhantomData::<Res>,
            counter: 0,
            step_size,
        })
    }

    /// Creates a sampler from a temperature, using `m_beta = -1 / temperature`.
    ///
    /// See [`new_from_m_beta`](Self::new_from_m_beta) for the other parameters.
    ///
    /// # Errors
    /// * [`MetropolisError::NAN`] if `temperature` or `energy` is NaN
    /// * [`MetropolisError::InfinitBeta`] if `temperature` is ±0 (infinite `m_beta`)
    pub fn new_from_temperature(
        rng: R,
        ensemble: E,
        energy: T,
        temperature: f64,
        step_size: usize,
    ) -> Result<Self, MetropolisError> {
        if temperature.is_nan() {
            return Err(MetropolisError::NAN);
        }
        Self::new_from_m_beta(rng, ensemble, energy, -1.0 / temperature, step_size)
    }

    /// Reinterprets the ensemble as a Markov chain with step type `S2` and result type `Res2`.
    ///
    /// Ensemble, rng, energy, step size, `m_beta` and counter are carried over.
    /// The step buffer is freshly allocated with capacity `step_size`.
    pub fn change_markov_chain<S2, Res2>(self) -> Metropolis<E, R, S2, Res2, T>
    where
        E: MarkovChain<S2, Res2>,
    {
        Metropolis::<E, R, S2, Res2, T> {
            ensemble: self.ensemble,
            rng: self.rng,
            energy: self.energy,
            step_size: self.step_size,
            m_beta: self.m_beta,
            counter: self.counter,
            steps: Vec::with_capacity(self.step_size),
            marker_res: PhantomData::<Res2>,
        }
    }

    /// One Metropolis step where `energy_fn` only sees the (mutable) ensemble.
    ///
    /// # Safety
    /// See `metropolis_step_efficient_unsafe`.
    #[inline(always)]
    unsafe fn metropolis_step_unsafe<Energy>(&mut self, mut energy_fn: Energy)
    where
        Energy: FnMut(&mut E) -> Option<T>,
    {
        self.metropolis_step_efficient_unsafe(|ensemble, _, _| energy_fn(ensemble))
    }

    /// One Metropolis step where `energy_fn` only gets shared access to the ensemble.
    #[inline(always)]
    fn metropolis_step<Energy>(&mut self, mut energy_fn: Energy)
    where
        Energy: FnMut(&E) -> Option<T>,
    {
        // SAFETY: `energy_fn` only receives `&E`, so it cannot modify the ensemble.
        unsafe { self.metropolis_step_unsafe(|ensemble| energy_fn(ensemble)) }
    }

    /// Core Metropolis step.
    ///
    /// 1. Increments the counter (wrapping) and proposes `step_size` Markov steps.
    /// 2. Calls `energy_fn(ensemble_after_steps, old_energy, steps)`.
    /// 3. `None` rejects the proposal and undoes the steps.
    /// 4. Otherwise the proposal is accepted with probability
    ///    `min(1, exp(m_beta * (new - old)))`. On rejection the steps are undone.
    ///
    /// # Safety
    /// `energy_fn` gets mutable access to the ensemble. On rejection the
    /// proposed steps are undone with `undo_steps_quiet`. The caller must make
    /// sure `energy_fn` does not modify the ensemble in a way that prevents
    /// that undo, and that it returns the energy of the resulting state.
    /// NOTE(review): nothing visible here is memory-unsafe; `unsafe` appears to
    /// mark this logical contract.
    #[inline(always)]
    unsafe fn metropolis_step_efficient_unsafe<Energy>(&mut self, mut energy_fn: Energy)
    where
        Energy: FnMut(&mut E, T, &[S]) -> Option<T>,
    {
        self.counter = self.counter.wrapping_add(1);
        self.ensemble.m_steps(self.step_size, &mut self.steps);
        let new_energy = match energy_fn(&mut self.ensemble, self.energy, &self.steps) {
            // Invalid proposal: reject without drawing a random number.
            None => {
                self.ensemble.undo_steps_quiet(&self.steps);
                return;
            }
            Some(e) => e,
        };
        let a_prob = (self.m_beta * (new_energy.as_() - self.energy.as_())).exp();
        // `random::<f64>()` is uniform in [0, 1), so `draw < a_prob` holds with
        // probability min(1, a_prob). Phrased this way, a NaN `a_prob` (NaN
        // energy, or `inf - inf`) is rejected. The previous `draw > a_prob`
        // rejection test accepted it, and the NaN energy made every later step
        // accept too. Exactly one draw per evaluated proposal, as before.
        let draw = self.rng.random::<f64>();
        if draw < a_prob {
            self.energy = new_energy;
        } else {
            self.ensemble.undo_steps_quiet(&self.steps);
        }
    }

    /// Core Metropolis step with shared access to the ensemble in `energy_fn`.
    #[inline(always)]
    fn metropolis_step_efficient<Energy>(&mut self, mut energy_fn: Energy)
    where
        Energy: FnMut(&E, T, &[S]) -> Option<T>,
    {
        // SAFETY: `energy_fn` only receives `&E`, so it cannot modify the ensemble.
        unsafe {
            self.metropolis_step_efficient_unsafe(|ensemble, energy, steps| {
                energy_fn(ensemble, energy, steps)
            })
        }
    }

    /// Metropolis Monte Carlo: performs one step for each value in
    /// `self.counter()..=step_target` (range taken at the start of the call).
    ///
    /// * `energy_fn` is called on the ensemble after the proposed steps were
    ///   applied and returns its energy, or `None` to reject the proposal.
    /// * `measure` is called after every step, whether accepted or rejected.
    /// * `self.energy()` is assumed to be the energy of the current ensemble.
    /// * If the counter already exceeds `step_target`, no step is performed.
    pub fn metropolis<Energy, Mes>(
        &mut self,
        step_target: usize,
        mut energy_fn: Energy,
        mut measure: Mes,
    ) where
        Energy: FnMut(&E) -> Option<T>,
        Mes: FnMut(&Self),
    {
        for _ in self.counter..=step_target {
            self.metropolis_step(&mut energy_fn);
            measure(self);
        }
    }

    /// Like [`metropolis`](Self::metropolis), but `energy_fn` and `measure`
    /// get mutable access.
    ///
    /// # Safety
    /// `energy_fn` must not modify the ensemble in a way that prevents the
    /// proposed steps from being undone. `measure` must keep the stored energy
    /// consistent with the ensemble.
    pub unsafe fn metropolis_unsafe<Energy, Mes>(
        &mut self,
        step_target: usize,
        mut energy_fn: Energy,
        mut measure: Mes,
    ) where
        Energy: FnMut(&mut E) -> Option<T>,
        Mes: FnMut(&mut Self),
    {
        for _ in self.counter..=step_target {
            self.metropolis_step_unsafe(&mut energy_fn);
            measure(self);
        }
    }

    /// Like [`metropolis`](Self::metropolis), but `energy_fn` also receives the
    /// old energy and the proposed steps, so it can compute the new energy
    /// incrementally.
    pub fn metropolis_efficient<Energy, Mes>(
        &mut self,
        step_target: usize,
        mut energy_fn: Energy,
        mut measure: Mes,
    ) where
        Energy: FnMut(&E, T, &[S]) -> Option<T>,
        Mes: FnMut(&Self),
    {
        for _ in self.counter..=step_target {
            self.metropolis_step_efficient(&mut energy_fn);
            measure(self);
        }
    }

    /// Like [`metropolis_efficient`](Self::metropolis_efficient), but with
    /// mutable access for `energy_fn` and `measure`.
    ///
    /// The `energy_fn` bound is `FnMut` (previously `Fn`), matching the sibling
    /// methods. Every `Fn` closure still qualifies.
    ///
    /// # Safety
    /// Same contract as [`metropolis_unsafe`](Self::metropolis_unsafe).
    pub unsafe fn metropolis_efficient_unsafe<Energy, Mes>(
        &mut self,
        step_target: usize,
        mut energy_fn: Energy,
        mut measure: Mes,
    ) where
        Energy: FnMut(&mut E, T, &[S]) -> Option<T>,
        Mes: FnMut(&mut Self),
    {
        for _ in self.counter..=step_target {
            self.metropolis_step_efficient_unsafe(&mut energy_fn);
            measure(self);
        }
    }

    /// Performs Metropolis steps while `condition(self)` is true.
    ///
    /// The condition is checked before each step, and `measure` is called
    /// after each step.
    pub fn metropolis_while<Energy, Mes, Cond>(
        &mut self,
        mut energy_fn: Energy,
        mut measure: Mes,
        mut condition: Cond,
    ) where
        Energy: FnMut(&E) -> Option<T>,
        Mes: FnMut(&Self),
        Cond: FnMut(&Self) -> bool,
    {
        while condition(self) {
            self.metropolis_step(&mut energy_fn);
            measure(self);
        }
    }

    /// Mutable-access variant of [`metropolis_while`](Self::metropolis_while).
    ///
    /// # Safety
    /// Same contract as [`metropolis_unsafe`](Self::metropolis_unsafe); this
    /// also applies to `condition`, which gets `&mut Self`.
    pub unsafe fn metropolis_while_unsafe<Energy, Mes, Cond>(
        &mut self,
        mut energy_fn: Energy,
        mut measure: Mes,
        mut condition: Cond,
    ) where
        Energy: FnMut(&mut E) -> Option<T>,
        Mes: FnMut(&mut Self),
        Cond: FnMut(&mut Self) -> bool,
    {
        while condition(self) {
            self.metropolis_step_unsafe(&mut energy_fn);
            measure(self);
        }
    }

    /// Incremental-energy variant of [`metropolis_while`](Self::metropolis_while).
    ///
    /// `energy_fn` receives the ensemble, the old energy and the proposed steps.
    pub fn metropolis_efficient_while<Energy, Mes, Cond>(
        &mut self,
        mut energy_fn: Energy,
        mut measure: Mes,
        mut condition: Cond,
    ) where
        Energy: FnMut(&E, T, &[S]) -> Option<T>,
        Mes: FnMut(&Self),
        Cond: FnMut(&Self) -> bool,
    {
        while condition(self) {
            self.metropolis_step_efficient(&mut energy_fn);
            measure(self);
        }
    }

    /// Mutable-access variant of
    /// [`metropolis_efficient_while`](Self::metropolis_efficient_while).
    ///
    /// # Safety
    /// Same contract as [`metropolis_unsafe`](Self::metropolis_unsafe); this
    /// also applies to `condition`, which gets `&mut Self`.
    pub unsafe fn metropolis_efficient_while_unsafe<Energy, Mes, Cond>(
        &mut self,
        mut energy_fn: Energy,
        mut measure: Mes,
        mut condition: Cond,
    ) where
        Energy: FnMut(&mut E, T, &[S]) -> Option<T>,
        Mes: FnMut(&mut Self),
        Cond: FnMut(&mut Self) -> bool,
    {
        while condition(self) {
            self.metropolis_step_efficient_unsafe(&mut energy_fn);
            measure(self);
        }
    }
}
impl<E, R, S, Res, T> HasRng<R> for Metropolis<E, R, S, Res, T>
where
    R: Rng,
{
    /// Mutable access to the random number generator used for acceptance tests.
    fn rng(&mut self) -> &mut R {
        &mut self.rng
    }

    /// Exchanges the sampler's random number generator with `rng`.
    fn swap_rng(&mut self, rng: &mut R) {
        use std::mem::swap;
        let own_rng = &mut self.rng;
        swap(rng, own_rng);
    }
}