use crate::{
algorithms::mcmc::{
validate_walker_inputs, validate_weighted_moves, ChainStorageMode, EnsembleStatus, Walker,
},
core::{
utils::{RandChoice, SampleFloat},
MCMCSummary, Point,
},
error::{GaneshError, GaneshResult},
traits::{
status::StatusType, Algorithm, LogDensity, Status, SupportsParameterNames,
SupportsTransform, Transform,
},
DVector, Float,
};
use fastrand::Rng;
/// A proposal move used by the [`AIES`] sampler.
///
/// Moves are normally constructed through [`AIESMove::stretch`], [`AIESMove::custom_stretch`]
/// or [`AIESMove::walk`], which pair the move with a relative selection weight.
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum AIESMove {
    /// Stretch move: walker `x_k` is proposed at `x_j + z (x_k - x_j)` for a randomly chosen
    /// complementary walker `x_j`, with the stretch factor `z` drawn between `1/a` and `a`
    /// with density proportional to `1/sqrt(z)`.
    Stretch {
        /// Scale parameter controlling the range of the stretch factor `z`.
        a: Float,
    },
    /// Walk move: walker `x_k` is proposed at `x_k + Σ_j Z_j (x_j - x̄)`, where the sum runs
    /// over the complementary walkers, `x̄` is their mean and each `Z_j` is standard normal.
    Walk,
}
impl AIESMove {
pub const fn stretch(weight: Float) -> WeightedAIESMove {
(Self::Stretch { a: 2.0 }, weight)
}
pub fn custom_stretch(a: Float, weight: Float) -> GaneshResult<WeightedAIESMove> {
if a <= 0.0 {
return Err(GaneshError::ConfigError(
"Scaling parameter must be greater than 0".to_string(),
));
}
Ok((Self::Stretch { a }, weight))
}
pub const fn walk(weight: Float) -> WeightedAIESMove {
(Self::Walk, weight)
}
fn step<P, U, E>(
&self,
problem: &P,
transform: &Option<Box<dyn Transform>>,
args: &U,
ensemble: &mut EnsembleStatus,
rng: &mut Rng,
) -> Result<(), E>
where
P: LogDensity<U, E>,
{
let mut positions = Vec::with_capacity(ensemble.len());
match self {
Self::Stretch { a } => {
ensemble
.set_message()
.step_with_message(&format!("Stretch Move (a = {})", &a));
}
Self::Walk => {
ensemble.set_message().step_with_message("Walk Move");
}
}
for (i, walker) in ensemble.iter().enumerate() {
let x_k = walker.get_latest();
let (proposal, r) = match self {
Self::Stretch { a } => {
let z = (a - 1.0).mul_add(rng.float(), 1.0).powi(2) / a;
let x_l =
ensemble.walkers[ensemble.get_compliment_walker_index(i, rng)].get_latest();
let mut proposal = Point::from(
transform.to_internal(&x_l.x).as_ref()
+ (transform.to_internal(&x_k.x).as_ref()
- transform.to_internal(&x_l.x).as_ref())
.scale(z),
);
proposal.log_density_transformed(problem, transform, args)?;
let n = x_l.x.len();
let r =
z.ln().mul_add((n - 1) as Float, proposal.fx_checked()) - x_k.fx_checked();
(proposal, r)
}
Self::Walk => {
let x_s = ensemble.internal_mean_compliment(i, transform);
let w = ensemble
.iter_compliment(i)
.map(|x_l| {
(transform.to_internal(&x_l.x).as_ref() - &x_s)
.scale(rng.normal(0.0, 1.0))
})
.sum::<DVector<Float>>();
let mut proposal = Point::from(transform.to_internal(&x_k.x).as_ref() + w);
proposal.log_density_transformed(problem, transform, args)?;
let r = proposal.fx_checked() - x_k.fx_checked();
(proposal, r)
}
};
if r > rng.float().ln() {
positions.push(proposal.to_external(transform))
} else {
positions.push(x_k.clone())
}
}
ensemble.n_f_evals += ensemble.walkers.len();
ensemble.push(positions);
Ok(())
}
}
/// Configuration for the [`AIES`] sampler.
///
/// Start from [`AIESConfig::new`] (or [`Default`]) and refine it with the `with_*` builder
/// methods.
#[derive(Clone)]
pub struct AIESConfig {
    /// Optional names for the sampled parameters, copied into the run summary.
    parameter_names: Option<Vec<String>>,
    /// Optional transform between external (user) and internal (sampler) coordinates.
    transform: Option<Box<dyn Transform>>,
    /// Candidate moves with their relative selection weights; one is chosen each step.
    moves: Vec<WeightedAIESMove>,
    /// How much of each walker's history is retained.
    chain_storage: ChainStorageMode,
}
impl SupportsTransform for AIESConfig {
    /// Mutable access to the optional coordinate transform used while sampling.
    fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
        let Self { transform, .. } = self;
        transform
    }
}
impl SupportsParameterNames for AIESConfig {
    /// Mutable access to the optional parameter names reported in the summary.
    fn get_parameter_names_mut(&mut self) -> &mut Option<Vec<String>> {
        let Self {
            parameter_names, ..
        } = self;
        parameter_names
    }
}
impl AIESConfig {
pub fn new() -> Self {
Self::default()
}
pub fn with_moves<T: AsRef<[WeightedAIESMove]>>(mut self, moves: T) -> GaneshResult<Self> {
validate_weighted_moves(
&moves
.as_ref()
.iter()
.map(|move_weight| move_weight.1)
.collect::<Vec<_>>(),
"AIES",
)?;
self.moves = moves.as_ref().to_vec();
Ok(self)
}
pub const fn with_chain_storage(mut self, chain_storage: ChainStorageMode) -> Self {
self.chain_storage = chain_storage;
self
}
}
impl Default for AIESConfig {
    /// The defaults are:
    /// - no parameter names;
    /// - no transform;
    /// - a single stretch move (`a = 2`) with weight 1;
    /// - the default chain-storage mode.
    fn default() -> Self {
        let moves = Vec::from([AIESMove::stretch(1.0)]);
        Self {
            moves,
            chain_storage: ChainStorageMode::default(),
            parameter_names: None,
            transform: None,
        }
    }
}
/// Starting positions for an [`AIES`] run, one vector per walker.
#[derive(Clone)]
pub struct AIESInit {
    /// Initial position of each walker (validated by [`AIESInit::new`]).
    walkers: Vec<DVector<Float>>,
}
impl AIESInit {
    /// Build the initial ensemble from explicit walker positions.
    ///
    /// # Errors
    /// Returns an error if `validate_walker_inputs` rejects the positions. This module's tests
    /// show that fewer than two walkers, or walkers of differing dimension, are rejected.
    pub fn new(walkers: Vec<DVector<Float>>) -> GaneshResult<Self> {
        let min_walkers = 2;
        validate_walker_inputs(&walkers, "AIES", min_walkers)?;
        Ok(Self { walkers })
    }
}
/// The affine-invariant ensemble sampler (AIES).
///
/// Each step picks one of the configured [`AIESMove`]s at random, with probability
/// proportional to its weight, and applies it to the whole ensemble of walkers.
#[derive(Clone)]
pub struct AIES {
    /// Source of all randomness used by the sampler: move selection, proposals and acceptance.
    rng: Rng,
}
impl Default for AIES {
fn default() -> Self {
Self::new(Some(0))
}
}
/// An [`AIESMove`] paired with its relative selection weight.
pub type WeightedAIESMove = (AIESMove, Float);
impl AIES {
    /// Create a sampler.
    ///
    /// With `Some(seed)` the random number generator is seeded deterministically. With `None`
    /// it is created by `fastrand::Rng::new`.
    pub fn new(seed: Option<u64>) -> Self {
        let rng = match seed {
            Some(seed) => fastrand::Rng::with_seed(seed),
            None => fastrand::Rng::new(),
        };
        Self { rng }
    }
}
impl<P, U, E> Algorithm<P, EnsembleStatus, U, E> for AIES
where
    P: LogDensity<U, E>,
{
    type Summary = MCMCSummary;
    type Config = AIESConfig;
    type Init = AIESInit;

    /// Seed the ensemble and prepare it for sampling.
    ///
    /// One walker is created per starting position in `init`, and the configured chain-storage
    /// mode is applied to each. The log density is then evaluated at the starting positions and
    /// the status message is marked as initialized.
    ///
    /// # Errors
    /// Propagates any error from evaluating the log density at the initial positions.
    fn initialize(
        &mut self,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        init: &Self::Init,
        config: &Self::Config,
    ) -> Result<(), E> {
        let storage = config.chain_storage;
        status.walkers = init
            .walkers
            .iter()
            .map(|start| {
                let mut walker = Walker::new(start.clone());
                walker.set_chain_storage(storage);
                walker
            })
            .collect();
        status.log_density_latest(problem, args)?;
        status.set_message().initialize();
        Ok(())
    }

    /// Pick one configured move at random, weighted by its selection weight, and apply it to
    /// every walker.
    ///
    /// # Errors
    /// Propagates any error from evaluating the log density of the proposals.
    fn step(
        &mut self,
        _current_step: usize,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        let weights: Vec<Float> = config.moves.iter().map(|&(_, weight)| weight).collect();
        let chosen = match self.rng.choice_weighted(&weights) {
            Some(index) => config.moves[index].0,
            None => {
                unreachable!("AIESConfig validates that move weights contain a positive entry")
            }
        };
        chosen.step(problem, &config.transform, args, status, &mut self.rng)
    }

    /// Package the current ensemble state into an [`MCMCSummary`].
    ///
    /// A run that stopped because it reached the maximum number of steps, which shows up as a
    /// custom status whose text mentions it, is reported as successful.
    fn summarize(
        &self,
        _current_step: usize,
        _func: &P,
        status: &EnsembleStatus,
        _args: &U,
        _init: &Self::Init,
        config: &Self::Config,
    ) -> Result<Self::Summary, E> {
        let mut message = status.message().clone();
        let reached_max_steps = matches!(message.status_type, StatusType::Custom)
            && message.text.contains("Maximum number of steps reached");
        if reached_max_steps {
            let text = message.text.clone();
            message.succeed_with_message(&text);
        }
        Ok(MCMCSummary {
            bounds: None,
            parameter_names: config.parameter_names.clone(),
            message,
            chain: status.get_chain(None, None),
            chain_storage: config.chain_storage,
            n_f_evals: status.n_f_evals,
            n_g_evals: status.n_g_evals,
            n_h_evals: 0,
            dimension: status.dimension(),
        })
    }
}
// Unit tests for the AIES sampler. They cover:
// - configuration builders and input validation;
// - individual moves applied directly to an ensemble;
// - full runs driven through `Algorithm::process`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::{Callbacks, MaxSteps},
        test_functions::Rosenbrock,
        traits::Algorithm,
    };
    use approx::assert_relative_eq;
    use std::convert::Infallible;

    // Returns `n_walkers` walkers of dimension `dim`. Every coordinate of walker `i` equals
    // `i + 1`, so all starting positions are distinct.
    fn make_walkers(n_walkers: usize, dim: usize) -> Vec<DVector<Float>> {
        (0..n_walkers)
            .map(|i| DVector::from_element(dim, i as Float + 1.0))
            .collect()
    }

    // Log density `-(x[0] - target)^2`. It depends only on the first coordinate and is maximal
    // (zero) at `x[0] == target`.
    struct CenteredLogDensity {
        target: Float,
    }
    impl crate::traits::LogDensity<(), Infallible> for CenteredLogDensity {
        fn log_density(&self, x: &DVector<Float>, _: &()) -> Result<Float, Infallible> {
            Ok(-Float::powi(x[0] - self.target, 2))
        }
    }

    // The init and config builders keep every walker and every move they are given.
    #[test]
    fn test_aies_config_builders() {
        let walkers = make_walkers(3, 2);
        let moves = vec![AIESMove::stretch(0.5), AIESMove::walk(0.5)];
        let init = AIESInit::new(walkers.clone()).unwrap();
        let config = AIESConfig::default().with_moves(moves.clone()).unwrap();
        assert_eq!(init.walkers.len(), walkers.len());
        assert_eq!(config.moves.len(), moves.len());
    }

    // Negative weights and all-zero weights are both rejected, each with its own message.
    #[test]
    fn test_aies_rejects_invalid_move_weights() {
        let err = match AIESConfig::default()
            .with_moves([AIESMove::stretch(-1.0), AIESMove::walk(1.0)])
        {
            Err(err) => err,
            Ok(_) => panic!("negative AIES move weights should be rejected"),
        };
        assert!(err.to_string().contains("finite and non-negative"));
        let err =
            match AIESConfig::default().with_moves([AIESMove::stretch(0.0), AIESMove::walk(0.0)]) {
                Err(err) => err,
                Ok(_) => panic!("zero-sum AIES move weights should be rejected"),
            };
        assert!(err.to_string().contains("sum to a positive finite value"));
    }

    // Rejected walker lists: empty, a single walker, and walkers of mixed dimension.
    #[test]
    fn test_aies_rejects_invalid_walker_inputs() {
        let err = match AIESInit::new(Vec::new()) {
            Err(err) => err,
            Ok(_) => panic!("empty AIES walker lists should be rejected"),
        };
        assert!(err.to_string().contains("at least 2 walkers"));
        let err = match AIESInit::new(vec![DVector::from_row_slice(&[1.0])]) {
            Err(err) => err,
            Ok(_) => panic!("single-walker AIES inputs should be rejected"),
        };
        assert!(err.to_string().contains("at least 2 walkers"));
        let err = match AIESInit::new(vec![
            DVector::from_row_slice(&[1.0, 2.0]),
            DVector::from_row_slice(&[3.0]),
        ]) {
            Err(err) => err,
            Ok(_) => panic!("mixed-dimension AIES walkers should be rejected"),
        };
        assert!(err.to_string().contains("same dimension"));
    }

    // Each move, applied directly to a default ensemble (presumably one with no walkers),
    // writes its name into the status message.
    #[test]
    fn test_aiesmove_updates_message() {
        let mut rng = Rng::with_seed(0);
        let problem = Rosenbrock { n: 2 };
        let mut status = EnsembleStatus::default();
        AIESMove::Stretch { a: 2.0 }
            .step(&problem, &None, &(), &mut status, &mut rng)
            .unwrap();
        assert!(status.message().to_string().contains("Stretch Move"));
        AIESMove::Walk
            .step(&problem, &None, &(), &mut status, &mut rng)
            .unwrap();
        assert!(status.message().to_string().contains("Walk Move"));
    }

    // `initialize` creates one walker per starting position, and `summarize` reports the same
    // dimension as the status.
    #[test]
    fn test_aies_initialize_and_summarize() {
        let mut aies = AIES::default();
        let walkers = make_walkers(3, 2);
        let init = AIESInit::new(walkers.clone()).unwrap();
        let config = AIESConfig::default();
        let problem = Rosenbrock { n: 2 };
        let mut status = EnsembleStatus::default();
        aies.initialize(&problem, &mut status, &(), &init, &config)
            .unwrap();
        assert_eq!(status.walkers.len(), walkers.len());
        let summary = aies
            .summarize(0, &problem, &status, &(), &init, &config)
            .unwrap();
        assert_eq!(summary.dimension, status.dimension());
    }

    // A single algorithm step with both move kinds configured completes without error.
    #[test]
    fn test_aies_step_runs() {
        let mut aies = AIES::default();
        let problem = Rosenbrock { n: 2 };
        let walkers = make_walkers(3, 2);
        let moves = vec![AIESMove::stretch(1.0), AIESMove::walk(1.0)];
        let init = AIESInit::new(walkers).unwrap();
        let config = AIESConfig::default().with_moves(moves).unwrap();
        let mut status = EnsembleStatus::default();
        aies.initialize(&problem, &mut status, &(), &init, &config)
            .unwrap();
        assert!(aies.step(0, &problem, &mut status, &(), &config).is_ok());
    }

    // Recomputes the first stretch factor `z` that a seed-0 RNG yields. The stretch move
    // draws `z` before any other random number. The target is placed exactly at the
    // proposal `x_l + z (x_k - x_l)` for walker 0 (x_k = 2) and its only possible partner,
    // walker 1 (x_l = 1). The proposal therefore sits at the density's maximum, is always
    // accepted, and walker 0 must land on it.
    #[test]
    fn stretch_move_proposes_toward_current_from_compliment() {
        let mut rng = Rng::with_seed(0);
        let a: Float = 2.0;
        let z = (a - 1.0).mul_add(rng.float(), 1.0).powi(2) / a;
        let expected = 1.0 + z * (2.0 - 1.0);
        let problem = CenteredLogDensity { target: expected };
        let mut ensemble = EnsembleStatus {
            walkers: vec![
                Walker::new(DVector::from_row_slice(&[2.0])),
                Walker::new(DVector::from_row_slice(&[1.0])),
            ],
            ..Default::default()
        };
        ensemble.log_density_latest(&problem, &()).unwrap();
        AIESMove::Stretch { a }
            .step(&problem, &None, &(), &mut ensemble, &mut Rng::with_seed(0))
            .unwrap();
        let x0 = ensemble.walkers[0].get_latest();
        assert_relative_eq!(x0.x[0], expected);
    }

    // A run stopped by `MaxSteps` is reported as successful and keeps the "Maximum number of
    // steps reached" text. It records at least one log-density evaluation per walker and
    // no gradient evaluations.
    #[test]
    fn summary_marks_max_steps_as_success_and_counts_initial_evals() {
        let mut aies = AIES::default();
        let walkers = make_walkers(4, 2);
        let init = AIESInit::new(walkers).unwrap();
        let config = AIESConfig::default();
        let result = aies
            .process(
                &Rosenbrock { n: 2 },
                &(),
                init,
                config,
                Callbacks::empty().with_terminator(MaxSteps(2)),
            )
            .unwrap();
        assert!(result.n_f_evals >= 4);
        assert_eq!(result.n_g_evals, 0);
        assert!(result.message.success());
        assert!(result
            .message
            .text
            .contains("Maximum number of steps reached"));
    }

    // Parameter names set on the config are carried through to the summary.
    #[test]
    fn summary_uses_parameter_names_from_config() {
        let mut aies = AIES::default();
        let init = AIESInit::new(make_walkers(4, 2)).unwrap();
        let config = AIESConfig::default().with_parameter_names(["alpha", "beta"]);
        let result = aies
            .process(
                &Rosenbrock { n: 2 },
                &(),
                init,
                config,
                Callbacks::empty().with_terminator(MaxSteps(2)),
            )
            .unwrap();
        assert_eq!(
            result.parameter_names,
            Some(vec!["alpha".to_string(), "beta".to_string()])
        );
    }

    // `Rolling { window: 2 }` keeps at most two samples per walker. `dimension.1` appears
    // to report the retained chain length.
    #[test]
    fn rolling_chain_storage_limits_retained_history() {
        let mut aies = AIES::default();
        let init = AIESInit::new(make_walkers(4, 2)).unwrap();
        let config =
            AIESConfig::default().with_chain_storage(ChainStorageMode::Rolling { window: 2 });
        let result = aies
            .process(
                &Rosenbrock { n: 2 },
                &(),
                init,
                config,
                Callbacks::empty().with_terminator(MaxSteps(4)),
            )
            .unwrap();
        assert_eq!(
            result.chain_storage,
            ChainStorageMode::Rolling { window: 2 }
        );
        assert!(result.chain.iter().all(|walker| walker.len() <= 2));
        assert_eq!(result.dimension.1, 2);
    }

    // `Sampled { keep_every: 2, max_samples: Some(3) }` keeps at most three samples per
    // walker after four steps.
    #[test]
    fn sampled_chain_storage_downsamples_retained_history() {
        let mut aies = AIES::default();
        let init = AIESInit::new(make_walkers(4, 2)).unwrap();
        let config = AIESConfig::default().with_chain_storage(ChainStorageMode::Sampled {
            keep_every: 2,
            max_samples: Some(3),
        });
        let result = aies
            .process(
                &Rosenbrock { n: 2 },
                &(),
                init,
                config,
                Callbacks::empty().with_terminator(MaxSteps(4)),
            )
            .unwrap();
        assert_eq!(
            result.chain_storage,
            ChainStorageMode::Sampled {
                keep_every: 2,
                max_samples: Some(3),
            }
        );
        assert!(result.chain.iter().all(|walker| walker.len() <= 3));
        assert_eq!(result.dimension.1, 3);
    }
}