use crate::{
core::Point,
error::{GaneshError, GaneshResult},
traits::{Algorithm, LogDensity, Terminator},
DVector, Float,
};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::{ops::ControlFlow, sync::Arc};
pub mod aies;
pub use aies::{AIESConfig, AIESMove, AIES};
pub mod ess;
pub use ess::{ESSConfig, ESSMove, ESS};
pub mod ensemble_status;
pub use crate::core::mcmc_diagnostics::integrated_autocorrelation_times;
pub use ensemble_status::EnsembleStatus;
/// Policy describing how much of a walker's chain is kept in memory.
///
/// [`ChainStorageMode::Full`] is the default variant.
#[derive(Clone, Copy, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub enum ChainStorageMode {
/// Store every sample; no limit is placed on the history length.
#[default]
Full,
/// Store every sample, but cap the stored history at the most recent `window` entries.
Rolling {
/// Maximum number of history entries kept; older entries are dropped first.
window: usize,
},
/// Store only a subset of the samples, optionally capped in number.
Sampled {
/// Sampling stride: every `keep_every`-th sample, counted from the initial position, is
/// stored. A value of `0` is treated as "store every sample".
keep_every: usize,
/// Optional cap on the number of stored history entries; `None` means no cap.
max_samples: Option<usize>,
},
}
impl ChainStorageMode {
    /// Upper bound on the number of stored history entries implied by this mode.
    ///
    /// Returns `None` when the mode places no cap on the history length.
    pub(crate) const fn history_limit(self) -> Option<usize> {
        match self {
            Self::Rolling { window: cap } => Some(cap),
            Self::Sampled {
                max_samples: cap, ..
            } => cap,
            Self::Full => None,
        }
    }
}
/// Checks that a list of move weights can be used as sampling weights.
///
/// The checks run in this order and the first failure is reported:
/// 1. the list is non-empty,
/// 2. every weight is finite and non-negative,
/// 3. the weights sum to a finite, strictly positive value.
///
/// `family` is only used as a prefix in the error message.
///
/// # Errors
///
/// Returns [`GaneshError::ConfigError`] describing the first violated condition.
pub(crate) fn validate_weighted_moves(weights: &[Float], family: &str) -> GaneshResult<()> {
    let problem = if weights.is_empty() {
        Some(format!("{family} move weights must not be empty"))
    } else if !weights.iter().all(|w| w.is_finite() && *w >= 0.0) {
        Some(format!(
            "{family} move weights must be finite and non-negative"
        ))
    } else {
        // Individually finite weights can still overflow when added, so the sum is re-checked.
        let sum: Float = weights.iter().sum();
        if sum.is_finite() && sum > 0.0 {
            None
        } else {
            Some(format!(
                "{family} move weights must sum to a positive finite value"
            ))
        }
    };
    match problem {
        Some(message) => Err(GaneshError::ConfigError(message)),
        None => Ok(()),
    }
}
/// Checks that a set of walker start positions is usable by a sampler.
///
/// The checks run in this order and the first failure is reported:
/// 1. there are at least `min_walkers` walkers,
/// 2. the list is non-empty,
/// 3. the first walker has at least one coordinate,
/// 4. every walker has the same number of coordinates as the first.
///
/// `family` is only used as a prefix in the error message.
///
/// # Errors
///
/// Returns [`GaneshError::ConfigError`] describing the first violated condition.
pub(crate) fn validate_walker_inputs(
    walkers: &[DVector<Float>],
    family: &str,
    min_walkers: usize,
) -> GaneshResult<()> {
    let reject = |message: String| -> GaneshResult<()> { Err(GaneshError::ConfigError(message)) };
    if walkers.len() < min_walkers {
        return reject(format!("{family} requires at least {min_walkers} walkers"));
    }
    match walkers.split_first() {
        None => reject(format!("{family} walker list must not be empty")),
        Some((head, _)) if head.is_empty() => {
            reject(format!("{family} walker dimension must be at least 1"))
        }
        Some((head, rest)) => {
            let dim = head.len();
            if rest.iter().all(|walker| walker.len() == dim) {
                Ok(())
            } else {
                reject(format!("{family} walkers must all have the same dimension"))
            }
        }
    }
}
/// A single chain of an ensemble sampler: its starting point, its latest position, and the
/// part of its past that the configured [`ChainStorageMode`] keeps in memory.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Walker {
/// Starting position; restored by `reset` and used to re-seed the history.
initial: Point<DVector<Float>>,
/// Most recent position (what `get_latest` returns), whether or not it is stored in `history`.
current: Point<DVector<Float>>,
/// Stored positions, oldest first; new entries are appended and trimmed from the front.
history: Vec<Point<DVector<Float>>>,
/// Retention policy applied when positions are pushed.
chain_storage: ChainStorageMode,
/// `true` when the last update also appended `current` to `history`.
current_retained: bool,
/// Number of positions seen so far, counting the initial one (starts at 1).
total_samples_seen: usize,
}
impl Walker {
    /// Creates a walker whose initial and current positions are both `x0`.
    ///
    /// The history starts with that single point, the storage mode is
    /// [`ChainStorageMode::Full`], and the sample counter starts at one.
    pub fn new(x0: DVector<Float>) -> Self {
        let start = Point::from(x0);
        Self {
            current: start.clone(),
            history: vec![start.clone()],
            initial: start,
            chain_storage: ChainStorageMode::Full,
            current_retained: true,
            total_samples_seen: 1,
        }
    }

    /// Returns `(n_steps, n_variables)`: the number of retained positions and the length of
    /// the current position vector.
    pub fn dimension(&self) -> (usize, usize) {
        (self.retained_len(), self.current.x.len())
    }

    /// Returns the walker to its initial position and discards everything seen since.
    ///
    /// The configured storage mode is kept and its history limit is re-applied.
    pub fn reset(&mut self) {
        self.total_samples_seen = 1;
        self.current_retained = true;
        self.history = vec![self.initial.clone()];
        self.current = self.initial.clone();
        self.enforce_history_limit();
    }

    /// Borrows the most recent position.
    pub const fn get_latest(&self) -> &Point<DVector<Float>> {
        &self.current
    }

    /// Mutably borrows the most recent position.
    pub fn get_latest_mut(&mut self) -> &mut Point<DVector<Float>> {
        &mut self.current
    }

    /// Calls `log_density(func, args)` on the most recent position and forwards its result.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the point's `log_density` call returns.
    pub fn log_density_latest<U, E>(
        &mut self,
        func: &dyn LogDensity<U, E>,
        args: &U,
    ) -> Result<(), E> {
        let latest = self.get_latest_mut();
        latest.log_density(func, args)
    }

    /// Records `position` as the new current position.
    ///
    /// The position is appended to the stored history only when the storage mode says this
    /// sample should be kept; afterwards the history limit is applied.
    pub fn push(&mut self, position: Point<DVector<Float>>) {
        self.total_samples_seen += 1;
        self.current = position;
        let keep = self.should_retain_current();
        self.current_retained = keep;
        if keep {
            let snapshot = self.current.clone();
            self.history.push(snapshot);
        }
        self.enforce_history_limit();
    }

    /// Switches the storage mode, rebuilding the stored history for the new mode and then
    /// applying its limit.
    pub(crate) fn set_chain_storage(&mut self, chain_storage: ChainStorageMode) {
        self.chain_storage = chain_storage;
        self.rebuild_retained_history();
        self.enforce_history_limit();
    }

    /// Returns the stored history, followed by the current position when the current
    /// position is not already part of that history.
    pub(crate) fn retained_positions(&self) -> Vec<&Point<DVector<Float>>> {
        let trailing = (!self.current_retained).then_some(&self.current);
        self.history.iter().chain(trailing).collect()
    }

    /// Number of entries `retained_positions` would return.
    fn retained_len(&self) -> usize {
        if self.current_retained {
            self.history.len()
        } else {
            self.history.len() + 1
        }
    }

    /// Decides whether the current sample belongs in the stored history.
    ///
    /// Only `Sampled` with a non-zero stride ever skips a sample; it keeps those whose
    /// zero-based index (initial position = 0) is a multiple of `keep_every`.
    const fn should_retain_current(&self) -> bool {
        if let ChainStorageMode::Sampled { keep_every, .. } = self.chain_storage {
            if keep_every != 0 {
                return (self.total_samples_seen - 1) % keep_every == 0;
            }
        }
        true
    }

    /// Re-seeds the history with the initial position and, if any sample has been pushed,
    /// re-evaluates whether the current position is stored under the active mode.
    fn rebuild_retained_history(&mut self) {
        self.history = vec![self.initial.clone()];
        if self.total_samples_seen == 1 {
            // Nothing has been pushed yet: the current position is the initial one.
            self.current = self.initial.clone();
            self.current_retained = true;
            return;
        }
        let keep = self.should_retain_current();
        if keep {
            self.history.push(self.current.clone());
        }
        self.current_retained = keep;
    }

    /// Drops the oldest history entries until the history fits the mode's limit, if any.
    fn enforce_history_limit(&mut self) {
        let Some(limit) = self.chain_storage.history_limit() else {
            return;
        };
        let overflow = self.history.len().saturating_sub(limit);
        if overflow > 0 {
            self.history.drain(..overflow);
        }
    }
}
/// Terminator that monitors the mean integrated autocorrelation time (τ) of an ensemble and
/// can stop sampling once τ has stabilised and enough steps have been taken.
#[derive(Clone)]
pub struct AutocorrelationTerminator {
/// The check runs on steps that are a multiple of this value.
n_check: usize,
/// The step count must exceed `τ * n_taus_threshold` for convergence.
n_taus_threshold: usize,
/// Convergence requires the relative change in τ between checks to be below this value.
dtau_threshold: Float,
/// Fraction of the current step count passed on as the discard count when estimating τ
/// (presumably the number of leading steps to ignore; confirm against `EnsembleStatus`).
discard: Float,
/// Whether reaching convergence actually stops the run.
terminate: bool,
/// Optional value forwarded to the τ estimator (set via `with_sokal_window`).
c: Option<Float>,
/// Whether each check prints a diagnostic report to stdout.
verbose: bool,
/// Mean τ recorded at each check, in order.
pub taus: Vec<Float>,
}
impl AutocorrelationTerminator {
    /// Sets the step interval at which the autocorrelation check runs.
    pub const fn with_n_check(self, n_check: usize) -> Self {
        let mut updated = self;
        updated.n_check = n_check;
        updated
    }

    /// Sets the multiple of τ that the step count must exceed for convergence.
    pub const fn with_n_taus_threshold(self, n_taus_threshold: usize) -> Self {
        let mut updated = self;
        updated.n_taus_threshold = n_taus_threshold;
        updated
    }

    /// Sets the relative change in τ below which the estimate counts as stable.
    pub const fn with_dtau_threshold(self, dtau_threshold: Float) -> Self {
        let mut updated = self;
        updated.dtau_threshold = dtau_threshold;
        updated
    }

    /// Sets the fraction of the current step count used as the discard count when
    /// estimating τ.
    pub const fn with_discard(self, discard: Float) -> Self {
        let mut updated = self;
        updated.discard = discard;
        updated
    }

    /// Sets whether reaching convergence stops the run (`true`) or is only recorded (`false`).
    pub const fn with_terminate(self, terminate: bool) -> Self {
        let mut updated = self;
        updated.terminate = terminate;
        updated
    }

    /// Sets the window constant `c` forwarded to the τ estimator.
    pub const fn with_sokal_window(self, c: Float) -> Self {
        let mut updated = self;
        updated.c = Some(c);
        updated
    }

    /// Sets whether each check prints a diagnostic report to stdout.
    pub const fn with_verbose(self, verbose: bool) -> Self {
        let mut updated = self;
        updated.verbose = verbose;
        updated
    }

    /// Wraps the configured terminator in `Arc<Mutex<_>>` so a clone of the handle can be
    /// inspected (e.g. its `taus`) after the run.
    pub fn build(self) -> Arc<Mutex<Self>> {
        let guarded = Mutex::new(self);
        Arc::new(guarded)
    }
}
impl Default for AutocorrelationTerminator {
    /// Default configuration: check every 50 steps, require more than 50 τ worth of steps
    /// and a relative τ change below 0.01, use a discard fraction of 0.5, terminate on
    /// convergence, leave the window constant unset, and stay silent.
    fn default() -> Self {
        Self {
            taus: Vec::new(),
            verbose: false,
            c: None,
            terminate: true,
            discard: 0.5,
            dtau_threshold: 0.01,
            n_taus_threshold: 50,
            n_check: 50,
        }
    }
}
impl<A, P, U, E, C> Terminator<A, P, EnsembleStatus, U, E, C> for AutocorrelationTerminator
where
    A: Algorithm<P, EnsembleStatus, U, E, Config = C>,
{
    /// Runs the autocorrelation convergence check on steps that are a multiple of `n_check`.
    ///
    /// On a check step the mean integrated autocorrelation time τ is obtained from `status`
    /// and appended to `taus`. The chain counts as converged when both hold:
    ///
    /// * `τ * n_taus_threshold < current_step`, and
    /// * the relative change `|τ_prev - τ| / τ` with respect to the previous check is below
    ///   `dtau_threshold` (the first check has no previous value and so never converges).
    ///
    /// Returns [`ControlFlow::Break`] only when converged and `terminate` is set; otherwise
    /// [`ControlFlow::Continue`]. An `n_check` of zero disables the check entirely instead
    /// of panicking on a zero divisor.
    fn check_for_termination(
        &mut self,
        current_step: usize,
        _algorithm: &mut A,
        _problem: &P,
        status: &mut EnsembleStatus,
        _args: &U,
        _config: &C,
    ) -> ControlFlow<()> {
        // `checked_rem` yields `None` for a zero divisor, so `n_check == 0` simply never
        // matches rather than triggering a remainder-by-zero panic.
        if current_step.checked_rem(self.n_check) != Some(0) {
            return ControlFlow::Continue(());
        }
        let taus = status.get_integrated_autocorrelation_times(
            self.c,
            Some((current_step as Float * self.discard) as usize),
            None,
        );
        let tau = taus.mean();
        let min_steps = tau * (self.n_taus_threshold as Float);
        let enough_steps = min_steps < current_step as Float;
        // A NaN relative change (no previous estimate, or τ == 0) never satisfies the
        // threshold comparison.
        let (dtau, dtau_met) = match self.taus.last() {
            Some(&previous) => {
                let dtau = Float::abs(previous - tau) / tau;
                (dtau, dtau < self.dtau_threshold)
            }
            None => (Float::NAN, false),
        };
        let converged = enough_steps && dtau_met;
        if self.verbose {
            println!("Integrated Autocorrelation Analysis:");
            println!("τ = \n{}", taus);
            println!("Minimum steps to converge = {}", min_steps as usize);
            println!("Steps completed = {}", current_step);
            println!("Δτ/τ = {} (converges if < {})", dtau, self.dtau_threshold);
            println!("Converged: {}\n", converged);
        }
        self.taus.push(tau);
        if converged && self.terminate {
            ControlFlow::Break(())
        } else {
            ControlFlow::Continue(())
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
core::{utils::SampleFloat, Callbacks},
test_functions::Rosenbrock,
DVector,
};
use fastrand::Rng;
/// Smoke test: runs an ESS sampler on a 2-D Rosenbrock problem with an
/// `AutocorrelationTerminator` attached and checks that the run completes and that at
/// least one τ value was recorded (the final `unwrap` panics otherwise).
#[test]
fn test_autocorrelation_terminator() {
let problem = Rosenbrock { n: 2 };
// Fixed seed so the walker start positions are the same on every run.
let mut rng = Rng::new();
rng.seed(0);
// Five walkers in two dimensions, drawn via `rng.normal(1.0, 4.0)`.
let x0: Vec<DVector<Float>> = (0..5)
.map(|_| DVector::from_fn(2, |_, _| rng.normal(1.0, 4.0)))
.collect();
// Exercise every builder setter with a non-default value where possible.
let aco = AutocorrelationTerminator::default()
.with_n_check(20)
.with_discard(0.55)
.with_sokal_window(7.1)
.with_terminate(true)
.with_dtau_threshold(0.05)
.with_n_taus_threshold(51)
.with_verbose(false)
.build();
// NOTE(review): `Some(1)` is presumably the sampler's RNG seed — confirm against `ESS::new`.
let mut sampler = ESS::new(Some(1));
let init = crate::algorithms::mcmc::ess::ESSInit::new(x0).unwrap();
let config = ESSConfig::default()
.with_moves([ESSMove::gaussian(0.1), ESSMove::differential(0.9)])
.unwrap();
// A clone of the shared handle goes to the sampler; `aco` is kept to read `taus` afterwards.
let result = sampler
.process(
&problem,
&(),
init,
config,
Callbacks::empty().with_terminator(aco.clone()),
)
.unwrap();
println!(
"Walker 0 Final Position: {}",
result.chain[0].last().unwrap()
);
println!(
"Autocorrelation Time at Termination: {}",
aco.lock().taus.last().unwrap()
)
}
}