#![forbid(unsafe_code)]
#![deny(missing_docs)]
use std::ops::ControlFlow;
use std::ptr;
use rand::{
distr::{Distribution, StandardUniform, Uniform},
Rng, RngExt,
};
use rand_distr::{
weighted::{AliasableWeight, WeightedAliasIndex},
Normal,
};
#[cfg(feature = "rayon")]
use rayon::iter::{IntoParallelRefMutIterator, ParallelExtend, ParallelIterator};
/// A point in parameter space, represented as an ordered collection of `f64` coordinates.
///
/// Implementations are provided for `[f64; N]`, `Vec<f64>` and `Box<[f64]>`.
pub trait Params: Send + Sync + Clone {
    /// Number of coordinates held by this value.
    fn dimension(&self) -> usize;

    /// Iterates over the coordinates by reference, in order.
    fn values(&self) -> impl Iterator<Item = &f64>;

    /// Builds a new value from the coordinates yielded by `iter`.
    fn collect(iter: impl Iterator<Item = f64>) -> Self;
}

impl<const N: usize> Params for [f64; N] {
    fn dimension(&self) -> usize {
        N
    }

    fn values(&self) -> impl Iterator<Item = &f64> {
        self.as_slice().iter()
    }

    /// Fills the array front to back.
    ///
    /// Slots for which `iter` yields no value stay at `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `iter` yields more than `N` values.
    fn collect(iter: impl Iterator<Item = f64>) -> Self {
        let mut coords = [0.; N];
        for (slot, value) in iter.enumerate() {
            // Indexing (rather than zipping) keeps the out-of-bounds panic on overlong input.
            coords[slot] = value;
        }
        coords
    }
}

impl Params for Vec<f64> {
    fn dimension(&self) -> usize {
        self.as_slice().len()
    }

    fn values(&self) -> impl Iterator<Item = &f64> {
        self.as_slice().iter()
    }

    fn collect(iter: impl Iterator<Item = f64>) -> Self {
        let mut coords = Vec::new();
        coords.extend(iter);
        coords
    }
}

impl Params for Box<[f64]> {
    fn dimension(&self) -> usize {
        let coords: &[f64] = self;
        coords.len()
    }

    fn values(&self) -> impl Iterator<Item = &f64> {
        let coords: &[f64] = self;
        coords.iter()
    }

    fn collect(iter: impl Iterator<Item = f64>) -> Self {
        iter.collect::<Vec<f64>>().into_boxed_slice()
    }
}
/// A proposal rule that generates a candidate state for one walker.
pub trait Move<P>
where
P: Params,
{
/// Proposes a new state for the walker currently at `this`.
///
/// `other` yields a reference to the state of another walker each time it is called with
/// the random number generator; a move may call it zero or more times.
///
/// Returns the proposed state together with a factor that the sampler adds to the
/// difference of log-probabilities when deciding whether to accept the proposal.
fn propose<'a, O, R>(&self, this: &'a P, other: O, rng: &mut R) -> (P, f64)
where
O: FnMut(&mut R) -> &'a P,
R: Rng;
}
/// Move that proposes a point on the line through the walker and one other walker,
/// at a randomly stretched distance controlled by `scale`.
pub struct Stretch {
    scale: f64,
}

impl Stretch {
    /// Creates a stretch move with the given `scale`.
    ///
    /// The value is stored as is; no validation is performed.
    pub fn new(scale: f64) -> Self {
        Stretch { scale }
    }
}

impl Default for Stretch {
    /// Uses a scale of `2`.
    fn default() -> Self {
        Self { scale: 2. }
    }
}
impl<P> Move<P> for Stretch
where
    P: Params,
{
    /// Proposes `other + z * (this - other)` for one walker drawn via `other`.
    ///
    /// With `u` drawn by `gen_unit`, the stretch is `z = ((scale - 1) * u + 1)^2 / scale`.
    /// The returned factor is `(dimension - 1) * ln(z)`.
    fn propose<'a, O, R>(&self, this: &'a P, mut other: O, rng: &mut R) -> (P, f64)
    where
        O: FnMut(&mut R) -> &'a P,
        R: Rng,
    {
        // Draw the partner first, then the stretch, so the RNG is consumed in a fixed order.
        let other = other(rng);
        let z = ((self.scale - 1.) * gen_unit(rng) + 1.).powi(2) / self.scale;

        let new_state = P::collect(
            this.values()
                .zip(other.values())
                .map(|(this, other)| (this - other).mul_add(z, *other)),
        );

        // Subtract in floating point: `dimension() - 1` on `usize` would underflow for a
        // zero-dimensional state (panic in debug builds, wrap-around in release builds).
        let factor = (new_state.dimension() as f64 - 1.) * z.ln();

        (new_state, factor)
    }
}
/// Move that displaces a walker along the difference of two other walkers,
/// scaled by a normally distributed factor.
pub struct DifferentialEvolution {
    gamma: Normal<f64>,
}

impl DifferentialEvolution {
    /// Creates the move with a scale factor drawn from a normal distribution having the
    /// given mean and standard deviation.
    ///
    /// # Panics
    ///
    /// Panics if `Normal::new` rejects the given parameters.
    pub fn new(gamma_mean: f64, gamma_std_dev: f64) -> Self {
        let gamma = Normal::new(gamma_mean, gamma_std_dev).unwrap();
        Self { gamma }
    }
}
impl<P> Move<P> for DifferentialEvolution
where
    P: Params,
{
    /// Proposes `this + gamma * (first - second)` for two distinct other walkers.
    ///
    /// `other` is called repeatedly until the second reference differs by address from the
    /// first, so this never returns if `other` can only ever yield a single walker.
    /// The returned factor is always `0`.
    fn propose<'a, O, R>(&self, this: &'a P, mut other: O, rng: &mut R) -> (P, f64)
    where
        O: FnMut(&mut R) -> &'a P,
        R: Rng,
    {
        let first = other(rng);
        // Redraw until the second partner is a different walker than the first.
        let second = loop {
            let candidate = other(rng);
            if !ptr::eq(first, candidate) {
                break candidate;
            }
        };

        let gamma = self.gamma.sample(rng);

        let displaced = this
            .values()
            .zip(first.values().zip(second.values()))
            .map(|(base, (lhs, rhs))| (lhs - rhs).mul_add(gamma, *base));

        (P::collect(displaced), 0.)
    }
}
/// Move that perturbs a single, randomly chosen coordinate by a zero-mean normal displacement.
pub struct RandomGaussian {
    displ: Normal<f64>,
}

impl RandomGaussian {
    /// Creates the move with displacements drawn from a normal distribution with mean `0`
    /// and standard deviation `displ`.
    ///
    /// # Panics
    ///
    /// Panics if `Normal::new` rejects `displ` as a standard deviation.
    pub fn new(displ: f64) -> Self {
        let displ = Normal::new(0., displ).unwrap();
        Self { displ }
    }
}
impl<P> Move<P> for RandomGaussian
where
    P: Params,
{
    /// Copies `this`, adding one normal displacement to a uniformly chosen coordinate.
    ///
    /// The other walkers are not consulted and the returned factor is always `0`.
    fn propose<'a, O, R>(&self, this: &'a P, _other: O, rng: &mut R) -> (P, f64)
    where
        O: FnMut(&mut R) -> &'a P,
        R: Rng,
    {
        // Index of the single coordinate to perturb.
        let dir = rng.random_range(0..this.dimension());

        let perturbed = this.values().enumerate().map(|(idx, &value)| {
            if idx != dir {
                return value;
            }
            // The displacement is drawn lazily, only when the chosen coordinate is reached.
            value + self.displ.sample(rng)
        });

        (P::collect(perturbed), 0.)
    }
}
/// A weighted mixture of moves.
///
/// Holds a weighted alias index over weights of type `W` and the moves `M` it selects from.
/// Values are built through the `From` impls for tuples of `(move, weight)` pairs
/// defined in this file.
pub struct Mixture<W, M>(WeightedAliasIndex<W>, M)
where
W: AliasableWeight;
// Generates, for one tuple arity, the `From<((A, W), (B, W), ...)>` constructor and the
// `Move<P>` impl of `Mixture<W, (A, B, ...)>`.
//
// Each `$types @ $weights` pair names one move: `$types` is used both as the generic type
// parameter and as the binding holding the move value (hence `non_snake_case`), while
// `$weights` is the binding holding that move's weight.
macro_rules! impl_mixture {
( $( $types:ident @ $weights:ident ),+ ) => {
impl<W, $( $types ),+> From<( $( ( $types, W ) ),+ )> for Mixture<W, ( $( $types ),+ )>
where
W: AliasableWeight
{
#[allow(non_snake_case)]
fn from(( $( ( $types, $weights ) ),+ ): ( $( ( $types, W ) ),+ )) -> Self {
// Panics if `WeightedAliasIndex::new` rejects the given weights.
let index = WeightedAliasIndex::new(vec![$( $weights ),+]).unwrap();
Self(index, ( $( $types ),+ ))
}
}
impl<W, $( $types ),+, P> Move<P> for Mixture<W, ( $( $types ),+ )>
where
W: AliasableWeight,
P: Params,
$( $types: Move<P> ),+
{
#[allow(non_snake_case)]
fn propose<'a, O, R>(&self, this: &'a P, other: O, rng: &mut R) -> (P, f64)
where
O: FnMut(&mut R) -> &'a P,
R: Rng,
{
let Self(index, ( $( $types ),+ )) = self;
// Pick one move according to the weights, then delegate the proposal to it.
let chosen_index = index.sample(rng);
// Running position in the tuple; the unrolled chain below compares it against the
// chosen index and returns from the first match.
let mut index = 0;
$(
// The increment after the last move is never read, hence the allow.
#[allow(unused_assignments)]
if chosen_index == index {
return $types.propose(this, other, rng)
} else {
index += 1;
}
)+
// Only reached if the sampled index is not a tuple position; presumably impossible
// because the alias index is built from exactly one weight per move.
unreachable!()
}
}
};
}
// Instantiate `Mixture` for tuples of two up to ten moves.
impl_mixture!(A @ a, B @ b);
impl_mixture!(A @ a, B @ b, C @ c);
impl_mixture!(A @ a, B @ b, C @ c, D @ d);
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e);
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f);
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g);
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h);
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i);
impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i, J @ j);
/// A target distribution to sample from.
pub trait Model: Send + Sync {
/// Type used to represent a state in the parameter space of this model.
type Params: Params;
/// Log-probability of `state`.
///
/// The sampler only uses differences between values returned by this method when
/// deciding whether to accept a proposal.
fn log_prob(&self, state: &Self::Params) -> f64;
}
/// Runs the ensemble sampler and returns the chain together with the number of accepted moves.
///
/// `walkers` yields the initial state and the random number generator of each walker; the
/// model's log-probability is evaluated once for every initial state. The ensemble is split
/// into two halves. In every iteration, each walker of the lower half is updated using
/// `move_` with partners drawn uniformly from the upper half, then each walker of the upper
/// half is updated with partners drawn from the lower half. After every update the walker's
/// current state (the proposal if accepted, the previous state otherwise) is appended to the
/// chain via `execution`.
///
/// Before every iteration `schedule` is asked whether to continue, given the chain so far.
/// If the schedule reports an iteration count, it is used to preallocate the chain.
///
/// # Panics
///
/// Panics if the number of walkers is zero or odd, or if it is smaller than twice the
/// dimension of the first walker's state.
pub fn sample<MD, MV, W, R, S, E>(
model: &MD,
move_: &MV,
walkers: W,
mut schedule: S,
execution: E,
) -> (Vec<MD::Params>, usize)
where
MD: Model,
MV: Move<MD::Params> + Send + Sync,
W: Iterator<Item = (MD::Params, R)>,
R: Rng + Send + Sync,
S: Schedule<MD::Params>,
E: Execution,
{
let mut walkers = walkers
.map(|(state, rng)| Walker::new(model, state, rng))
.collect::<Vec<_>>();
assert!(!walkers.is_empty() && walkers.len() % 2 == 0);
assert!(walkers.len() >= 2 * walkers[0].state.dimension());
// Every iteration appends one state per walker, so reserve that much up front when the
// schedule can tell how many iterations it will run.
let mut chain =
Vec::with_capacity(walkers.len() * schedule.iterations(walkers.len()).unwrap_or(0));
let half = walkers.len() / 2;
// Two disjoint halves: one is mutated while the other is only read as partner pool.
let (lower_half, upper_half) = walkers.split_at_mut(half);
// Both halves have `half` elements, so one index distribution serves both directions.
let random_index = Uniform::new(0, half).unwrap();
let update_walker = move |walker: &mut Walker<MD, R>, other_walkers: &[Walker<MD, R>]| {
walker.move_(model, move_, |rng| &other_walkers[random_index.sample(rng)])
};
while schedule.next_step(&chain).is_continue() {
execution.extend_chain(&mut chain, lower_half, |walker| {
update_walker(walker, upper_half)
});
execution.extend_chain(&mut chain, upper_half, |walker| {
update_walker(walker, lower_half)
});
}
let accepted = walkers.iter().map(|walker| walker.accepted).sum();
(chain, accepted)
}
// One member of the ensemble together with its private random number generator.
struct Walker<MD, R>
where
MD: Model,
{
// Current position in parameter space.
state: MD::Params,
// Cached model log-probability of `state`.
log_prob: f64,
// Generator used for this walker's proposals and acceptance draws.
rng: R,
// Number of proposals accepted so far.
accepted: usize,
}
impl<MD, R> Walker<MD, R>
where
    MD: Model,
    R: Rng,
{
    // Wraps `state` and `rng`, evaluating the model's log-probability of `state` once.
    fn new(model: &MD, state: MD::Params, rng: R) -> Self {
        Self {
            log_prob: model.log_prob(&state),
            state,
            rng,
            accepted: 0,
        }
    }

    // Performs one proposal/acceptance step and returns the walker's state afterwards:
    // the proposal if it was accepted, a copy of the unchanged state otherwise.
    //
    // `other` supplies partner walkers for the move.
    fn move_<'a, MV, O>(&'a mut self, model: &MD, move_: &MV, mut other: O) -> MD::Params
    where
        MV: Move<MD::Params>,
        O: FnMut(&mut R) -> &'a Self,
    {
        let (mut proposal, factor) =
            move_.propose(&self.state, |rng| &other(rng).state, &mut self.rng);

        let proposal_log_prob = model.log_prob(&proposal);
        let log_ratio = factor + proposal_log_prob - self.log_prob;
        // Acceptance test in log space against a fresh uniform draw.
        let threshold = gen_unit(&mut self.rng).ln();

        if log_ratio > threshold {
            self.state.clone_from(&proposal);
            self.log_prob = proposal_log_prob;
            self.accepted += 1;
        } else {
            // Rejected: reuse the proposal's storage to hand back the unchanged state.
            proposal.clone_from(&self.state);
        }

        proposal
    }
}
fn gen_unit<R>(rng: &mut R) -> f64
where
R: Rng,
{
StandardUniform.sample(rng)
}
/// Estimates the integrated auto-correlation time of a scalar `chain`.
///
/// The estimate starts at `1` and adds twice the normalised auto-correlation for each lag
/// `1, 2, ...`, stopping at the first lag that is at least `min_win_size` times the running
/// estimate (or when all lags are exhausted).
///
/// Returns `None` if the chain length is smaller than `min_chain_len` times the estimate,
/// i.e. if the chain is too short for the estimate to be trusted. Chains with more than one
/// sample whose variance is zero or not a number (constant chains, chains containing
/// non-finite values) also yield `None`.
///
/// `min_win_size` defaults to `5` and `min_chain_len` defaults to `50`.
pub fn auto_corr_time<C>(
    chain: C,
    min_win_size: Option<usize>,
    min_chain_len: Option<usize>,
) -> Option<f64>
where
    C: ExactSizeIterator<Item = f64> + Clone,
{
    let min_win_size = min_win_size.unwrap_or(5) as f64;
    let min_chain_len = min_chain_len.unwrap_or(50) as f64;

    let samples = chain.len();
    let len = samples as f64;

    let mean = chain.clone().sum::<f64>() / len;
    let variance = chain
        .clone()
        .map(|sample| (sample - mean).powi(2))
        .sum::<f64>()
        / len;

    // With a degenerate variance every auto-correlation below would be NaN and the final
    // comparison would fail anyway; bail out early instead of scanning all lags (quadratic
    // in the chain length) just to arrive at `None`.
    if samples > 1 && (variance.is_nan() || variance <= 0.) {
        return None;
    }

    let mut estimate = 1.;

    for lag in 1..samples {
        let auto_corr = chain
            .clone()
            .skip(lag)
            .zip(chain.clone())
            .map(|(lhs, rhs)| (lhs - mean) * (rhs - mean))
            .sum::<f64>()
            / len
            / variance;

        estimate += 2. * auto_corr;

        // Self-consistent window: stop once the lag covers enough auto-correlation times.
        if lag as f64 >= min_win_size * estimate {
            break;
        }
    }

    (len >= min_chain_len * estimate).then_some(estimate)
}
/// Decides for how long the sampler keeps running.
pub trait Schedule<P>
where
P: Params,
{
/// Called with the chain collected so far before every iteration of the sampler.
///
/// Return `ControlFlow::Continue(())` to run another iteration and
/// `ControlFlow::Break(())` to stop sampling.
fn next_step(&mut self, chain: &[P]) -> ControlFlow<()>;
/// Number of iterations this schedule will run for the given number of walkers, if known.
///
/// The sampler uses this to preallocate the chain. The default implementation returns `None`.
fn iterations(&self, _walkers: usize) -> Option<usize> {
None
}
}
/// Schedule that keeps sampling until the chain holds at least the given number of states.
pub struct MinChainLen(pub usize);

impl<P> Schedule<P> for MinChainLen
where
    P: Params,
{
    fn next_step(&mut self, chain: &[P]) -> ControlFlow<()> {
        if chain.len() < self.0 {
            ControlFlow::Continue(())
        } else {
            ControlFlow::Break(())
        }
    }

    /// Number of iterations needed to reach the minimum chain length.
    ///
    /// Every iteration appends one state per walker, so the count is rounded up: a partial
    /// last batch still takes a whole iteration. Returns `None` for zero walkers, where no
    /// iteration count exists.
    fn iterations(&self, walkers: usize) -> Option<usize> {
        (walkers != 0).then(|| self.0.div_ceil(walkers))
    }
}
/// Schedule adaptor that reports the chain to a callback before consulting the inner schedule.
pub struct WithProgress<S, C> {
    /// The wrapped schedule that decides when sampling stops.
    pub schedule: S,
    /// Invoked with the chain collected so far on every call to `next_step`.
    pub callback: C,
}

impl<P, S, C> Schedule<P> for WithProgress<S, C>
where
    P: Params,
    S: Schedule<P>,
    C: FnMut(&[P]),
{
    fn next_step(&mut self, chain: &[P]) -> ControlFlow<()> {
        let Self { schedule, callback } = self;
        // Report first, then let the inner schedule decide.
        callback(chain);
        schedule.next_step(chain)
    }

    fn iterations(&self, walkers: usize) -> Option<usize> {
        <S as Schedule<P>>::iterations(&self.schedule, walkers)
    }
}
/// Strategy for applying an update to a group of walkers and recording the results.
pub trait Execution {
    /// Applies `update` to every element of `walkers` and appends the returned values to `chain`.
    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
    where
        P: Send + Sync,
        W: Send + Sync,
        U: Fn(&mut W) -> P + Send + Sync;
}

/// Updates the walkers one after another on the calling thread.
pub struct Serial;

impl Execution for Serial {
    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
    where
        P: Send + Sync,
        W: Send + Sync,
        U: Fn(&mut W) -> P + Send + Sync,
    {
        chain.reserve(walkers.len());
        for walker in walkers {
            chain.push(update(walker));
        }
    }
}
/// Updates the walkers through Rayon's parallel iterators.
#[cfg(feature = "rayon")]
pub struct Parallel;

#[cfg(feature = "rayon")]
impl Execution for Parallel {
    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
    where
        P: Send + Sync,
        W: Send + Sync,
        U: Fn(&mut W) -> P + Send + Sync,
    {
        let updated = walkers.par_iter_mut().map(update);
        chain.par_extend(updated);
    }
}