use {
crate::{glue_helper::*, rewl::replica_exchange, rewl::*, *},
rand::{prelude::SliceRandom, Rng, SeedableRng},
rayon::{iter::ParallelIterator, prelude::*},
std::{cmp::*, num::NonZeroUsize, sync::*},
};
#[cfg(feature = "sweep_time_optimization")]
use std::cmp::Reverse;
#[cfg(feature = "serde_support")]
use serde::{Deserialize, Serialize};
/// Shorthand for the result type returned by the glue (merging) routines.
pub type GluedResult<Hist, Energy> = Result<Glued<Hist, Energy>, HistErrors>;
/// Replica-exchange Wang-Landau (REWL) simulation.
///
/// Manages a set of walkers distributed over energy intervals, the ensembles
/// they act on, and the bookkeeping needed for replica exchanges and
/// round-trip statistics.
#[derive(Debug)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub struct ReplicaExchangeWangLandau<Ensemble, R, Hist, Energy, S, Res> {
    // How many walkers share one interval; `walker` is processed in
    // chunks of this size throughout.
    pub(crate) chunk_size: NonZeroUsize,
    // One RwLock per ensemble so walkers can read in parallel during sweeps.
    pub(crate) ensembles: Vec<RwLock<Ensemble>>,
    pub(crate) walker: Vec<RewlWalker<R, Hist, Energy, S, Res>>,
    // Simulation counts as finished once every walker's log_f is below this.
    pub(crate) log_f_threshold: f64,
    // Alternates each sweep to vary which neighboring intervals get paired.
    pub(crate) replica_exchange_mode: bool,
    // Half round trips per walker id (reported as full trips via division by 2).
    pub(crate) roundtrip_halfes: Vec<usize>,
    // Which extreme interval each walker (by id) visited last.
    pub(crate) last_extreme_interval_visited: Vec<ExtremeInterval>,
}
impl<Ensemble, R, Hist, Energy, S, Res> GlueAble<Hist>
    for ReplicaExchangeWangLandau<Ensemble, R, Hist, Energy, S, Res>
where
    Hist: Clone,
{
    /// Append one glue entry per interval to `job`, skipping the interval
    /// indices listed in `ignore_idx`, and record all round-trip counts.
    fn push_glue_entry_ignoring(&self, job: &mut GlueJob<Hist>, ignore_idx: &[usize]) {
        job.round_trips.extend(self.roundtrip_iter());

        let (hists, probs) = self.get_log_prob_and_hists();

        let interval_iter = self
            .walker
            .chunks(self.chunk_size.get())
            .zip(hists)
            .zip(probs)
            .enumerate();

        for (index, ((walker, hist), prob)) in interval_iter {
            if ignore_idx.contains(&index) {
                continue;
            }

            // Accumulate the statistics of all walkers in this interval.
            let mut progress = f64::NEG_INFINITY;
            let mut accepted = 0;
            let mut rejected = 0;
            let mut replica_exchanges = 0_u64;
            let mut proposed_replica_exchanges = 0;
            for w in walker {
                progress = w.log_f().max(progress);
                let r = w.rejected_markov_steps();
                rejected += r;
                accepted += w.step_count() - r;
                replica_exchanges += w.replica_exchanges() as u64;
                proposed_replica_exchanges += w.proposed_replica_exchanges();
            }

            let interval_stats = IntervalSimStats {
                sim_progress: SimProgress::LogF(progress),
                interval_sim_type: SimulationType::REWL,
                rejected_steps: rejected,
                accepted_steps: accepted,
                replica_exchanges: Some(replica_exchanges),
                proposed_replica_exchanges: Some(proposed_replica_exchanges),
                merged_over_walkers: self.chunk_size,
            };

            job.collection.push(GlueEntry {
                hist: hist.clone(),
                prob,
                log_base: LogBase::BaseE,
                interval_stats,
            });
        }
    }
}
/// Which extreme (left- or rightmost) interval a walker visited last;
/// used to detect completed half round trips.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub enum ExtremeInterval {
    /// The walker last visited the leftmost interval
    Left,
    /// The walker last visited the rightmost interval
    Right,
    /// The walker has not yet visited either extreme
    None,
}
/// Errors that can occur when setting a new `log_f` threshold.
///
/// NOTE: the serde derives must be gated behind the `serde_support`
/// feature (like every other type in this file) — the `use serde::…`
/// import is itself feature-gated, so an unconditional derive breaks
/// builds without that feature.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde_support", derive(Serialize, Deserialize))]
pub enum ThresholdErrors {
    /// The requested threshold was negative
    Negative,
    /// The requested threshold was NaN, infinite or subnormal
    NonNormal,
    /// The requested threshold was zero
    Zero,
}
/// Shorthand alias for [`ReplicaExchangeWangLandau`].
pub type Rewl<Ensemble, R, Hist, Energy, S, Res> =
    ReplicaExchangeWangLandau<Ensemble, R, Hist, Energy, S, Res>;
impl<Ensemble, R, Hist, Energy, S, Res> Rewl<Ensemble, R, Hist, Energy, S, Res> {
/// Read access to all walkers of the simulation.
pub fn walkers(&self) -> &Vec<RewlWalker<R, Hist, Energy, S, Res>> {
    &self.walker
}
/// Iterate over read guards of all ensembles.
///
/// # Panics
/// Panics if one of the underlying `RwLock`s is poisoned.
pub fn ensemble_iter(&'_ self) -> impl Iterator<Item = RwLockReadGuard<'_, Ensemble>> {
    self.ensembles.iter().map(|e| e.read().unwrap())
}
/// Read access to the ensemble at `index`, or `None` if out of range.
///
/// # Panics
/// Panics if the underlying `RwLock` is poisoned.
pub fn get_ensemble<'a>(&'a self, index: usize) -> Option<RwLockReadGuard<'a, Ensemble>> {
    let lock = self.ensembles.get(index)?;
    Some(lock.read().unwrap())
}
/// Iterate mutably over all ensembles.
///
/// # Safety
/// Marked `unsafe` by this library: mutating an ensemble directly can
/// invalidate state the walkers rely on — presumably the caller must not
/// perform changes that alter the ensemble's energy (NOTE(review):
/// confirm against the crate documentation).
///
/// # Panics
/// Panics if one of the `RwLock`s is poisoned.
pub unsafe fn ensemble_iter_mut(&mut self) -> impl Iterator<Item = &mut Ensemble> {
    self.ensembles
        .iter_mut()
        .map(|item| item.get_mut().unwrap())
}
/// Mutable access to the ensemble at `index`, or `None` if out of range.
///
/// # Safety
/// Marked `unsafe` by this library: mutating the ensemble directly can
/// invalidate state the walkers rely on — presumably the caller must not
/// perform changes that alter the ensemble's energy (NOTE(review):
/// confirm against the crate documentation).
///
/// # Panics
/// Panics if the `RwLock` is poisoned.
pub unsafe fn get_ensemble_mut(&mut self, index: usize) -> Option<&mut Ensemble> {
    self.ensembles.get_mut(index).map(|e| e.get_mut().unwrap())
}
/// Number of intervals of the simulation
/// (total walkers divided by walkers per interval).
pub fn num_intervals(&self) -> NonZeroUsize {
    // A Rewl always owns at least one full chunk of walkers,
    // so the division can never yield zero.
    NonZeroUsize::new(self.walker.len() / self.chunk_size.get())
        .unwrap_or_else(|| unreachable!())
}
/// How many walkers share each interval (the chunk size).
pub fn walkers_per_interval(&self) -> NonZeroUsize {
    self.chunk_size
}
/// Set the Markov step size for every walker of interval `n`.
/// Returns `Outcome::Failure` if `n` is out of range.
pub fn change_step_size_of_interval(&mut self, n: usize, step_size: usize) -> Outcome {
    let start = n * self.chunk_size.get();
    let end = start + self.chunk_size.get();
    match self.walker.get_mut(start..end) {
        None => Outcome::Failure,
        Some(interval) => {
            for walker in interval {
                walker.step_size_change(step_size);
            }
            Outcome::Success
        }
    }
}
/// Get the Markov step size of interval `n`, or `None` if `n` is out of range.
///
/// # Panics
/// Panics if the walkers of the interval disagree on their step size,
/// which would indicate a bug in this library.
pub fn get_step_size_of_interval(&self, n: usize) -> Option<usize> {
    let start = n * self.chunk_size.get();
    let end = start + self.chunk_size.get();
    if self.walker.len() < end {
        None
    } else {
        let slice = &self.walker[start..end];
        let step_size = slice[0].step_size();
        slice[1..]
            .iter()
            .for_each(|w|
                assert_eq!(
                    step_size, w.step_size(),
                    // BUGFIX: this message previously said "Sweep sizes",
                    // copy-pasted from the 0xA sweep-size variant below.
                    "Fatal Error encountered; ERRORCODE 0x9 - \
                    Step sizes of intervals do not match! \
                    This should be impossible! if you are using the latest version of the \
                    'sampling' library, please contact the library author via github by opening an \
                    issue! https://github.com/Pardoxa/sampling/issues"
                )
            );
        Some(step_size)
    }
}
/// Set the sweep size for every walker of interval `n`.
/// Returns `Outcome::Failure` if `n` is out of range.
pub fn change_sweep_size_of_interval(&mut self, n: usize, sweep_size: NonZeroUsize) -> Outcome {
    let start = n * self.chunk_size.get();
    let end = start + self.chunk_size.get();
    match self.walker.get_mut(start..end) {
        None => Outcome::Failure,
        Some(interval) => {
            for walker in interval {
                walker.sweep_size_change(sweep_size);
            }
            Outcome::Success
        }
    }
}
/// Get the sweep size of interval `n`, or `None` if `n` is out of range.
///
/// # Panics
/// Panics if the walkers of the interval disagree on their sweep size,
/// which would indicate a bug in this library.
pub fn get_sweep_size_of_interval(&self, n: usize) -> Option<NonZeroUsize> {
    let start = n * self.chunk_size.get();
    let end = start + self.chunk_size.get();
    let slice = self.walker.get(start..end)?;
    let sweep_size = slice[0].sweep_size();
    for w in &slice[1..] {
        assert_eq!(
            sweep_size, w.sweep_size(),
            "Fatal Error encountered; ERRORCODE 0xA - \
            Sweep sizes of intervals do not match! \
            This should be impossible! if you are using the latest version of the \
            'sampling' library, please contact the library author via github by opening an \
            issue! https://github.com/Pardoxa/sampling/issues"
        );
    }
    Some(sweep_size)
}
/// Collect, per interval, the merged log-probability estimate of its
/// walkers and one representative histogram (from the first walker of
/// each chunk).
fn get_log_prob_and_hists(&self) -> (Vec<&Hist>, Vec<Vec<f64>>) {
    let chunk = self.chunk_size.get();
    let log_prob: Vec<Vec<f64>> = self
        .walker
        .chunks(chunk)
        .map(get_merged_walker_prob)
        .collect();
    let hists: Vec<&Hist> = self
        .walker
        .iter()
        .step_by(chunk)
        .map(|w| w.hist())
        .collect();
    (hists, log_prob)
}
/// Minimum number of round trips performed by any walker.
pub fn min_roundtrips(&self) -> usize {
    // There is always at least one walker, so the iterator is non-empty.
    self.roundtrip_iter()
        .min()
        .unwrap_or_else(|| unreachable!())
}
/// Maximum number of round trips performed by any walker.
pub fn max_roundtrips(&self) -> usize {
    // There is always at least one walker, so the iterator is non-empty.
    self.roundtrip_iter()
        .max()
        .unwrap_or_else(|| unreachable!())
}
/// Iterate over the number of complete round trips each walker has
/// performed. Half round trips are counted internally, hence the
/// division by 2.
#[inline]
pub fn roundtrip_iter(&'_ self) -> impl Iterator<Item = usize> + '_ {
    self.roundtrip_halfes.iter().map(|&r_h| r_h / 2)
}
/// The largest `log_f` of any walker, i.e. the worst convergence
/// progress in the whole simulation.
pub fn largest_log_f(&self) -> f64 {
    let mut largest = f64::NEG_INFINITY;
    for w in &self.walker {
        largest = w.log_f().max(largest);
    }
    largest
}
/// The current `log_f` of every walker, in walker order.
pub fn log_f_vec(&self) -> Vec<f64> {
    self.walker.iter().map(|w| w.log_f()).collect()
}
/// Change the `log_f` threshold that decides when the simulation counts
/// as finished. Returns the previous threshold on success.
///
/// # Errors
/// * [`ThresholdErrors::NonNormal`] — NaN, infinite or subnormal
/// * [`ThresholdErrors::Negative`] — below zero
/// * [`ThresholdErrors::Zero`] — exactly zero
///
/// BUGFIX: previously `!is_normal()` was checked first, but `0.0` is not
/// "normal", so the `Zero` variant was unreachable (zero was misreported
/// as `NonNormal`). NaN/infinity are now rejected explicitly, then sign
/// and zero, and finally subnormals.
pub fn set_log_f_threshold(&mut self, new_threshold: f64) -> Result<f64, ThresholdErrors> {
    if new_threshold.is_nan() || new_threshold.is_infinite() {
        Err(ThresholdErrors::NonNormal)
    } else if new_threshold < 0.0 {
        Err(ThresholdErrors::Negative)
    } else if new_threshold == 0.0 {
        Err(ThresholdErrors::Zero)
    } else if !new_threshold.is_normal() {
        // positive subnormal values
        Err(ThresholdErrors::NonNormal)
    } else {
        let old_threshold = self.log_f_threshold;
        self.log_f_threshold = new_threshold;
        Ok(old_threshold)
    }
}
/// `true` once every walker's `log_f` has dropped below the threshold.
pub fn is_finished(&self) -> bool {
    self.walker.iter().all(|w| w.log_f() < self.log_f_threshold)
}
/// Merge the intervals' log-probability estimates via the derivative
/// method, align them, and attach the simulation statistics.
pub fn derivative_merged_log_prob_and_aligned(&self) -> Result<Glued<Hist, Energy>, HistErrors>
where
    Hist: HistogramCombine + Histogram,
{
    let (hists, log_probs) = self.get_log_prob_and_hists();
    derivative_merged_and_aligned(log_probs, &hists, LogBase::BaseE).map(|mut glued| {
        glued.set_stats(self.get_glue_stats());
        glued
    })
}
/// Gather per-interval simulation statistics plus the round-trip counts.
fn get_glue_stats(&self) -> GlueStats {
    let interval_stats = self
        .walker
        .chunks(self.chunk_size.get())
        .map(|chunk| {
            // Accumulate the statistics of all walkers in this interval.
            let mut progress = f64::NEG_INFINITY;
            let mut accepted = 0;
            let mut rejected = 0;
            let mut replica_exchanges = 0_u64;
            let mut proposed_replica_exchanges = 0;
            for w in chunk {
                progress = w.log_f().max(progress);
                let r = w.rejected_markov_steps();
                rejected += r;
                accepted += w.step_count() - r;
                replica_exchanges += w.replica_exchanges() as u64;
                proposed_replica_exchanges += w.proposed_replica_exchanges();
            }
            IntervalSimStats {
                sim_progress: SimProgress::LogF(progress),
                interval_sim_type: SimulationType::REWL,
                rejected_steps: rejected,
                accepted_steps: accepted,
                replica_exchanges: Some(replica_exchanges),
                proposed_replica_exchanges: Some(proposed_replica_exchanges),
                merged_over_walkers: self.chunk_size,
            }
        })
        .collect();
    GlueStats {
        roundtrips: self.roundtrip_iter().collect(),
        interval_stats,
    }
}
/// Merge the intervals' log-probability estimates via averaging,
/// align them, and attach the simulation statistics.
pub fn average_merged_log_probability_and_align(
    &self,
) -> Result<Glued<Hist, Energy>, HistErrors>
where
    Hist: HistogramCombine + Histogram,
{
    let (hists, log_probs) = self.get_log_prob_and_hists();
    average_merged_and_aligned(log_probs, &hists, LogBase::BaseE).map(|mut glued| {
        glued.set_stats(self.get_glue_stats());
        glued
    })
}
/// The id of every walker, in walker order.
pub fn get_id_vec(&self) -> Vec<usize> {
    self.walker.iter().map(|w| w.id()).collect()
}
/// The histogram of every walker, in walker order.
pub fn hists(&self) -> Vec<&Hist> {
    self.walker.iter().map(|w| w.hist()).collect()
}
/// The histogram of the walker at `index`, or `None` if out of range.
pub fn get_hist(&self, index: usize) -> Option<&Hist> {
    self.walker.get(index).map(|w| w.hist())
}
/// Convert into a replica-exchange entropic sampling (`Rees`) simulation
/// without attaching extra data to the walkers.
pub fn into_rees(self) -> Rees<(), Ensemble, R, Hist, Energy, S, Res>
where
    Hist: Histogram,
{
    self.into()
}
/// Convert into a replica-exchange entropic sampling (`Rees`) simulation,
/// attaching one `Extra` payload per walker.
///
/// # Errors
/// Returns `self` and `extra` unchanged if `extra.len()` does not match
/// the number of walkers.
#[allow(clippy::type_complexity, clippy::result_large_err)]
pub fn into_rees_with_extra<Extra>(
    self,
    extra: Vec<Extra>,
) -> Result<Rees<Extra, Ensemble, R, Hist, Energy, S, Res>, (Self, Vec<Extra>)>
where
    Hist: Histogram,
{
    if extra.len() != self.walker.len() {
        return Err((self, extra));
    }
    // Snapshot the REWL round trips before consuming the walkers.
    let rewl_roundtrips: Vec<usize> = self.roundtrip_iter().collect();
    let num = rewl_roundtrips.len();
    let walker: Vec<_> = self.walker.into_iter().map(|w| w.into()).collect();
    let mut rees = Rees {
        walker,
        ensembles: self.ensembles,
        replica_exchange_mode: self.replica_exchange_mode,
        extra,
        chunk_size: self.chunk_size,
        rewl_roundtrips,
        rees_last_extreme_interval_visited: vec![ExtremeInterval::None; num],
        rees_roundtrip_halfes: vec![0; num],
    };
    rees.update_roundtrips();
    Ok(rees)
}
}
impl<Ensemble, R, Hist, Energy, S, Res> Rewl<Ensemble, R, Hist, Energy, S, Res>
where
R: Send + Sync + Rng + SeedableRng,
Hist: Send + Sync + Histogram + HistogramVal<Energy>,
Energy: Send + Sync + Clone,
Ensemble: MarkovChain<S, Res>,
Res: Send + Sync,
S: Send + Sync,
{
/// Perform sweeps until every walker has reached the `log_f` threshold
/// (see `is_finished`). `energy_fn` computes the energy of an ensemble;
/// presumably a `None` return marks an invalid/undefined state —
/// NOTE(review): confirm against the walker's sweep implementation.
pub fn simulate_until_convergence<F>(&mut self, energy_fn: F)
where
    Ensemble: Send + Sync,
    R: Send + Sync,
    F: Fn(&mut Ensemble) -> Option<Energy> + Copy + Send + Sync,
{
    while !self.is_finished() {
        self.sweep(energy_fn);
    }
}
/// Perform sweeps until either the simulation converges or `condition`
/// returns `false`. The condition is evaluated once before each sweep,
/// e.g. to impose a time or step limit.
pub fn simulate_while<F, C>(&mut self, energy_fn: F, mut condition: C)
where
    Ensemble: Send + Sync,
    R: Send + Sync,
    F: Fn(&mut Ensemble) -> Option<Energy> + Copy + Send + Sync,
    C: FnMut(&Self) -> bool,
{
    while !self.is_finished() && condition(self) {
        self.sweep(energy_fn);
    }
}
/// Sanity check (in parallel) that every walker's stored energy matches
/// what `energy_fn` computes from its current ensemble — useful for
/// debugging a user-supplied energy function.
pub fn check_energy_fn<F>(&mut self, energy_fn: F) -> bool
where
    Energy: PartialEq,
    F: Fn(&mut Ensemble) -> Option<Energy> + Copy + Send + Sync,
    Ensemble: Sync + Send,
{
    let ensembles = self.ensembles.as_slice();
    self.walker
        .par_iter()
        .all(|w| w.check_energy_fn(ensembles, energy_fn))
}
/// Perform one parallel Wang-Landau sweep for every walker, then the
/// refinement check and the replica-exchange moves, and finally update
/// the round-trip counters.
pub fn sweep<F>(&mut self, energy_fn: F)
where
    Ensemble: Send + Sync,
    R: Send + Sync,
    F: Fn(&mut Ensemble) -> Option<Energy> + Copy + Send + Sync,
{
    let slice = self.ensembles.as_slice();

    #[cfg(not(feature = "sweep_time_optimization"))]
    let walker = &mut self.walker;

    // With the optimization feature enabled, start the historically
    // slowest walkers first to improve parallel load balancing.
    #[cfg(feature = "sweep_time_optimization")]
    let mut walker = {
        let mut walker = Vec::with_capacity(self.walker.len());
        walker.extend(self.walker.iter_mut());
        walker.par_sort_unstable_by_key(|w| Reverse(w.duration()));
        walker
    };

    // Wang-Landau sweep of every walker, in parallel.
    walker
        .par_iter_mut()
        .for_each(|w| w.wang_landau_sweep(slice, energy_fn));

    // Where all walkers of an interval have visited all bins: refine f,
    // reset the histograms and merge the walkers' probability estimates.
    self.walker
        .par_chunks_mut(self.chunk_size.get())
        .filter(|chunk| chunk.iter().all(RewlWalker::all_bins_reached))
        .for_each(|chunk| {
            chunk.iter_mut().for_each(RewlWalker::refine_f_reset_hist);
            merge_walker_prob(chunk);
        });

    // Replica exchange among walkers of the SAME interval
    // (only meaningful if an interval hosts more than one walker).
    if self.walkers_per_interval().get() > 1 {
        let exchange_m = self.replica_exchange_mode;
        self.walker
            .par_chunks_mut(self.chunk_size.get())
            .for_each(|chunk| {
                let mut shuf = Vec::with_capacity(chunk.len());
                if let Some((first, rest)) = chunk.split_first_mut() {
                    shuf.extend(rest.iter_mut());
                    shuf.shuffle(&mut first.rng);
                    shuf.push(first);
                    // Alternate the pairing offset between sweeps.
                    let s = if exchange_m {
                        &mut shuf
                    } else {
                        &mut shuf[1..]
                    };
                    s.chunks_exact_mut(2).for_each(|c| {
                        let ptr = c.as_mut_ptr();
                        // SAFETY: `c` has exactly 2 elements
                        // (chunks_exact_mut), so offsets 0 and 1 are in
                        // bounds and the two references do not alias.
                        unsafe {
                            let a = &mut *ptr;
                            let b = &mut *ptr.offset(1);
                            replica_exchange(a, b);
                        }
                    });
                }
            });
    }

    // Replica exchange between NEIGHBORING intervals. The starting
    // offset alternates between sweeps so every neighbor pair gets a
    // turn over two consecutive sweeps.
    let walker_slice = if self.replica_exchange_mode {
        &mut self.walker
    } else {
        &mut self.walker[self.chunk_size.get()..]
    };
    self.replica_exchange_mode = !self.replica_exchange_mode;
    let chunk_size = self.chunk_size;
    walker_slice
        .par_chunks_exact_mut(2 * self.chunk_size.get())
        .for_each(|walker_chunk| {
            // Split a pair of adjacent intervals and randomly match
            // their walkers for the exchange attempts.
            let (slice_a, slice_b) = walker_chunk.split_at_mut(chunk_size.get());
            let mut slice_b_shuffled: Vec<_> = slice_b.iter_mut().collect();
            slice_b_shuffled.shuffle(&mut slice_a[0].rng);
            for (walker_a, walker_b) in slice_a.iter_mut().zip(slice_b_shuffled.into_iter()) {
                replica_exchange(walker_a, walker_b);
            }
        });
    self.update_roundtrips();
}
/// Update the round-trip counters: a walker completes half a round trip
/// whenever it arrives at the opposite extreme interval from the one it
/// visited last. Counters are indexed by walker id.
pub(crate) fn update_roundtrips(&mut self) {
    // With a single interval there are no extremes to travel between.
    if self.num_intervals().get() == 1 {
        return;
    }
    let mut chunk_iter = self.walker.chunks(self.chunk_size.get());
    // Walkers currently in the LEFTMOST interval.
    let first_chunk = chunk_iter.next().unwrap();
    first_chunk.iter().for_each(|walker| {
        let id = walker.id();
        let last_visited = match self.last_extreme_interval_visited.get_mut(id) {
            Some(last) => last,
            None => unreachable!(),
        };
        match last_visited {
            // Came from the right extreme: half a round trip completed.
            ExtremeInterval::Right => {
                *last_visited = ExtremeInterval::Left;
                self.roundtrip_halfes[id] += 1;
            }
            // First extreme ever reached - nothing to count yet.
            ExtremeInterval::None => {
                *last_visited = ExtremeInterval::Left;
            }
            _ => (),
        }
    });
    // Walkers currently in the RIGHTMOST interval (mirror logic).
    let last_chunk = match chunk_iter.last() {
        Some(chunk) => chunk,
        None => unreachable!(),
    };
    last_chunk.iter().for_each(|walker| {
        let id = walker.id();
        let last_visited = match self.last_extreme_interval_visited.get_mut(id) {
            Some(last) => last,
            None => unreachable!(),
        };
        match last_visited {
            ExtremeInterval::Left => {
                *last_visited = ExtremeInterval::Right;
                self.roundtrip_halfes[id] += 1;
            }
            ExtremeInterval::None => {
                *last_visited = ExtremeInterval::Right;
            }
            _ => (),
        }
    });
}
}
/// Same as [`merged_log_prob`], but the returned probabilities are
/// converted from natural log to base 10.
pub fn merged_log10_prob<Ensemble, R, Hist, Energy, S, Res>(
    rewls: &[Rewl<Ensemble, R, Hist, Energy, S, Res>],
) -> Result<(Vec<f64>, Hist), HistErrors>
where
    Hist: Histogram + HistogramVal<Energy> + HistogramCombine + Send + Sync,
    Energy: PartialOrd,
{
    merged_log_prob(rewls).map(|(mut log_prob, hist)| {
        ln_to_log10(&mut log_prob);
        (log_prob, hist)
    })
}
/// Merge the probability estimates of several `Rewl` simulations into a
/// single aligned log-probability vector with its combined histogram.
///
/// # Errors
/// `HistErrors::EmptySlice` if `rewls` is empty; otherwise whatever the
/// alignment step reports.
pub fn merged_log_prob<Ensemble, R, Hist, Energy, S, Res>(
    rewls: &[Rewl<Ensemble, R, Hist, Energy, S, Res>],
) -> Result<(Vec<f64>, Hist), HistErrors>
where
    Hist: Histogram + HistogramVal<Energy> + HistogramCombine + Send + Sync,
    Energy: PartialOrd,
{
    if rewls.is_empty() {
        return Err(HistErrors::EmptySlice);
    }
    let merged_prob = merged_probs(rewls);
    let container = combine_container(rewls, &merged_prob, true);
    align(&container).map(|(merge_points, alignment, log_prob, e_hist)| {
        only_merged(merge_points, alignment, log_prob, e_hist)
    })
}
/// Remove the elements at the given indices from `container`.
/// Duplicate and out-of-range indices are silently ignored;
/// the relative order of the remaining elements is preserved.
pub(crate) fn ignore_fn<T>(container: &mut Vec<T>, ignore: &[usize]) {
    let mut to_remove = ignore.to_vec();
    to_remove.sort_unstable();
    to_remove.dedup();
    // Walk the indices from largest to smallest so earlier removals
    // do not shift the indices still pending.
    for &index in to_remove.iter().rev() {
        if index < container.len() {
            container.remove(index);
        }
    }
}
/// One merged log-probability estimate per interval, across all
/// given simulations, in order.
fn merged_probs<Ensemble, R, Hist, Energy, S, Res>(
    rewls: &[Rewl<Ensemble, R, Hist, Energy, S, Res>],
) -> Vec<Vec<f64>> {
    rewls
        .iter()
        .flat_map(|rewl| {
            rewl.walkers()
                .chunks(rewl.walkers_per_interval().get())
                .map(get_merged_walker_prob)
        })
        .collect()
}
/// Pair each log-probability slice with its histogram and sort the
/// resulting intervals by their left histogram border.
fn combine_container<'a, Ensemble, R, Hist, Energy, S, Res>(
    rewls: &'a [Rewl<Ensemble, R, Hist, Energy, S, Res>],
    log_probabilities: &'a [Vec<f64>],
    merged: bool,
) -> Vec<(&'a [f64], &'a Hist)>
where
    Hist: HistogramVal<Energy> + HistogramCombine,
    Energy: PartialOrd,
{
    // When `merged` is set there is one probability per interval, so only
    // the first walker of each chunk contributes its histogram; otherwise
    // every walker does.
    let hists: Vec<&Hist> = rewls
        .iter()
        .flat_map(|rewl| {
            let step = if merged {
                rewl.walkers_per_interval().get()
            } else {
                1
            };
            rewl.walkers().iter().step_by(step).map(|w| w.hist())
        })
        .collect();

    assert_eq!(hists.len(), log_probabilities.len());

    let mut container: Vec<(&[f64], &Hist)> = log_probabilities
        .iter()
        .map(Vec::as_slice)
        .zip(hists)
        .collect();

    // Order the intervals from left to right; incomparable borders
    // (e.g. NaN) are treated as equal.
    container.sort_unstable_by(|a, b| {
        a.1.first_border()
            .partial_cmp(&b.1.first_border())
            .unwrap_or(Ordering::Equal)
    });
    container
}