use crate::{
algorithms::mcmc::{
validate_walker_inputs, validate_weighted_moves, ChainStorageMode, EnsembleStatus, Walker,
},
core::{
utils::{generate_random_vector_in_limits, RandChoice, SampleFloat},
MCMCSummary, Point,
},
error::{GaneshError, GaneshResult},
traits::{
status::StatusType, Algorithm, LogDensity, Status, SupportsParameterNames,
SupportsTransform, Transform,
},
DMatrix, DVector, Float, PI,
};
use fastrand::Rng;
use nalgebra::Cholesky;
/// A strategy for choosing the direction along which each walker is slice-sampled.
///
/// Every move yields a direction vector in the internal (transformed) parameter space; the
/// sampler then performs a one-dimensional stepping-out/shrinkage slice update of the walker
/// along that direction. Use the constructor functions on this type to pair a move with a
/// selection weight (see [`WeightedESSMove`]).
#[derive(Copy, Clone, Debug)]
pub enum ESSMove {
    /// Direction is the difference between the positions of two walkers drawn from the
    /// complement of the current walker, scaled by the sampler's adaptive scale `mu`.
    Differential,
    /// Direction is a sum of the complementary walkers' deviations from their mean, each
    /// weighted by a standard-normal draw, scaled by `2 * mu`.
    Gaussian,
    /// Direction is derived from a Dirichlet-process Gaussian mixture fitted to the current
    /// ensemble, allowing jumps between separated regions of the ensemble.
    Global {
        /// Multiplier applied to the direction when both sampled labels belong to the same
        /// mixture component.
        scale: Float,
        /// Factor applied to the component covariances when the two sampled labels differ.
        rescale_cov: Float,
        /// Number of components in the fitted mixture.
        n_components: usize,
    },
}
impl ESSMove {
pub const fn differential(weight: Float) -> WeightedESSMove {
(Self::Differential, weight)
}
pub const fn gaussian(weight: Float) -> WeightedESSMove {
(Self::Gaussian, weight)
}
pub const fn global(weight: Float) -> WeightedESSMove {
(
Self::Global {
scale: 1.0,
rescale_cov: 0.001,
n_components: 5,
},
weight,
)
}
pub fn custom_global(
weight: Float,
scale: Option<Float>,
rescale_cov: Option<Float>,
n_components: Option<usize>,
) -> GaneshResult<WeightedESSMove> {
if let Some(scale) = scale {
if scale <= 0.0 {
return Err(GaneshError::ConfigError(
"scale must be greater than 0".to_string(),
));
}
}
if let Some(rescale_cov) = rescale_cov {
if rescale_cov <= 0.0 {
return Err(GaneshError::ConfigError(
"rescale_cov must be greater than 0".to_string(),
));
}
}
if let Some(n_components) = n_components {
if n_components < 2 {
return Err(GaneshError::ConfigError(
"n_components must be greater than 1".to_string(),
));
}
}
Ok((
Self::Global {
scale: scale.unwrap_or(1.0),
rescale_cov: rescale_cov.unwrap_or(0.001),
n_components: n_components.unwrap_or(5),
},
weight,
))
}
#[allow(clippy::too_many_arguments)]
fn step<P, U, E>(
&self,
step: usize,
n_adaptive: usize,
max_steps: usize,
mu: &mut Float,
problem: &P,
transform: &Option<Box<dyn Transform>>,
args: &U,
ensemble: &mut EnsembleStatus,
rng: &mut Rng,
) -> Result<(), E>
where
P: LogDensity<U, E>,
{
let mut positions = Vec::with_capacity(ensemble.len());
match self {
Self::Differential => {
ensemble
.set_message()
.step_with_message("Differential Move");
}
Self::Gaussian => {
ensemble.set_message().step_with_message("Gaussian Move");
}
Self::Global {
scale,
rescale_cov,
n_components,
} => {
ensemble.set_message().step_with_message(&format!(
"Global Move (scale = {}, rescale_cov = {}, n_components = {})",
scale, rescale_cov, n_components
));
}
}
let mut n_expand = 0;
let mut n_contract = 0;
let mut dpgm_result = None;
let mut n_f_evals: usize = 0;
for (i, walker) in ensemble.iter().enumerate() {
let x_k = walker.get_latest();
let eta = match self {
Self::Differential => {
let s = ensemble.get_compliment_walker_indices(i, 2, rng);
let x_l = ensemble.walkers[s[0]].get_latest();
let x_m = ensemble.walkers[s[1]].get_latest();
let eta = (transform.to_internal(&x_l.x).as_ref()
- transform.to_internal(&x_m.x).as_ref())
.scale(*mu);
eta
}
Self::Gaussian => {
let x_s = ensemble.internal_mean_compliment(i, transform);
ensemble
.iter_compliment(i)
.map(|x_l| {
(transform.to_internal(&x_l.x).as_ref() - &x_s)
.scale(rng.normal(0.0, 1.0))
})
.sum::<DVector<Float>>()
.scale(2.0 * *mu)
}
Self::Global {
scale,
rescale_cov,
n_components,
} => {
let dpgm = dpgm_result
.get_or_insert_with(|| dpgm(*n_components, ensemble, transform, rng));
let labels = &dpgm.labels;
let means = &dpgm.means;
let covariances = &dpgm.covariances;
let indices = rng.choose_multiple(labels.iter(), 2);
let a = indices[0];
let b = indices[1];
if a == b {
rng.mv_normal(&means[*a], &covariances[*a])
.scale(2.0 * scale)
} else {
(rng.mv_normal(&means[*a], &covariances[*a].scale(*rescale_cov))
- rng.mv_normal(&means[*b], &covariances[*b].scale(*rescale_cov)))
.scale(2.0)
}
}
};
let y = x_k.fx_checked() + rng.float().ln();
let x_k_internal = transform.to_internal(&x_k.x).into_owned();
let mut l = -rng.float();
let mut p_l = Point::from(&x_k_internal + eta.scale(l));
p_l.log_density_transformed(problem, transform, args)?;
n_f_evals += 1;
let mut r = l + 1.0;
let mut p_r = Point::from(&x_k_internal + eta.scale(r));
p_r.log_density_transformed(problem, transform, args)?;
n_f_evals += 1;
while y < p_l.fx_checked() && n_expand < max_steps {
l -= 1.0;
p_l.set_position(&x_k_internal + eta.scale(l));
p_l.log_density_transformed(problem, transform, args)?;
n_f_evals += 1;
n_expand += 1;
}
while y < p_r.fx_checked() && n_expand < max_steps {
r += 1.0;
p_r.set_position(&x_k_internal + eta.scale(r));
p_r.log_density_transformed(problem, transform, args)?;
n_f_evals += 1;
n_expand += 1;
}
let xprime = loop {
let xprime = rng.range(l, r);
let mut p_yprime = Point::from(&x_k_internal + eta.scale(xprime));
p_yprime.log_density_transformed(problem, transform, args)?;
n_f_evals += 1;
if y < p_yprime.fx_checked() || n_contract >= max_steps {
break xprime;
}
if xprime < 0.0 {
l = xprime;
} else {
r = xprime;
}
n_contract += 1;
};
let mut proposal = Point::from(x_k_internal + eta.scale(xprime));
proposal.log_density_transformed(problem, transform, args)?;
n_f_evals += 1;
positions.push(proposal.to_external(transform))
}
ensemble.n_f_evals += n_f_evals;
if step <= n_adaptive {
let total_updates = n_expand + n_contract;
if total_updates > 0 {
*mu *= 2.0 * (n_expand as Float) / (total_updates as Float);
}
}
ensemble.push(positions);
Ok(())
}
}
/// Configuration for the [`ESS`] sampler.
#[derive(Clone)]
pub struct ESSConfig {
    /// Optional parameter names, copied into the run summary.
    parameter_names: Option<Vec<String>>,
    /// Optional transform between external and internal parameter coordinates.
    transform: Option<Box<dyn Transform>>,
    /// Candidate moves with relative selection weights; one is drawn per step.
    moves: Vec<WeightedESSMove>,
    /// Steps with `current_step <= n_adaptive` adapt the direction scale `mu`.
    n_adaptive: usize,
    /// Upper bound used to cap slice-interval expansions and contractions.
    max_steps: usize,
    /// Initial value of the adaptive direction scale.
    mu: Float,
    /// How each walker's chain history is retained.
    chain_storage: ChainStorageMode,
}
impl ESSConfig {
    /// Creates a configuration with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the list of weighted moves; each step draws one move with probability
    /// proportional to its weight.
    ///
    /// # Errors
    ///
    /// Returns an error if the weights fail validation, e.g. the list is empty, a weight is
    /// negative or non-finite, or the weights do not sum to a positive finite value.
    pub fn with_moves<T: AsRef<[WeightedESSMove]>>(mut self, moves: T) -> GaneshResult<Self> {
        let moves = moves.as_ref();
        let weights: Vec<Float> = moves.iter().map(|&(_, weight)| weight).collect();
        validate_weighted_moves(&weights, "ESS")?;
        self.moves = moves.to_vec();
        Ok(self)
    }

    /// Sets how many initial steps (those with `current_step <= n_adaptive`) adapt `mu`.
    pub const fn with_n_adaptive(mut self, n_adaptive: usize) -> Self {
        self.n_adaptive = n_adaptive;
        self
    }

    /// Sets the cap on slice-interval expansions and contractions.
    pub const fn with_max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = max_steps;
        self
    }

    /// Sets the initial adaptive direction scale `mu`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] if `mu` is not greater than zero (NaN included).
    pub fn with_mu(mut self, mu: Float) -> GaneshResult<Self> {
        // `NaN <= 0.0` is false, so NaN must be rejected explicitly.
        if mu.is_nan() || mu <= 0.0 {
            return Err(GaneshError::ConfigError(
                "Adaptive scaling parameter must be greater than 0".to_string(),
            ));
        }
        self.mu = mu;
        Ok(self)
    }

    /// Sets how each walker's chain history is stored.
    pub const fn with_chain_storage(mut self, chain_storage: ChainStorageMode) -> Self {
        self.chain_storage = chain_storage;
        self
    }
}
impl Default for ESSConfig {
    /// A single differential move with weight `1.0`, no adaptation steps, `max_steps = 10000`,
    /// `mu = 1.0`, no transform or parameter names, and the default chain storage mode.
    fn default() -> Self {
        let moves = vec![ESSMove::differential(1.0)];
        Self {
            moves,
            mu: 1.0,
            max_steps: 10000,
            n_adaptive: 0,
            transform: None,
            parameter_names: None,
            chain_storage: ChainStorageMode::default(),
        }
    }
}
/// Initial walker positions for the [`ESS`] sampler.
#[derive(Clone)]
pub struct ESSInit {
    /// Starting position of each walker.
    walkers: Vec<DVector<Float>>,
}
impl ESSInit {
    /// Validates and stores the starting positions.
    ///
    /// # Errors
    ///
    /// Returns an error if fewer than three walkers are given or their dimensions differ.
    pub fn new(walkers: Vec<DVector<Float>>) -> GaneshResult<Self> {
        validate_walker_inputs(&walkers, "ESS", 3).map(|_| Self { walkers })
    }
}
impl SupportsTransform for ESSConfig {
fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
&mut self.transform
}
}
impl SupportsParameterNames for ESSConfig {
fn get_parameter_names_mut(&mut self) -> &mut Option<Vec<String>> {
&mut self.parameter_names
}
}
#[derive(Clone)]
pub struct ESS {
rng: Rng,
mu: Float,
}
impl Default for ESS {
fn default() -> Self {
Self::new(Some(0))
}
}
pub type WeightedESSMove = (ESSMove, Float);
impl ESS {
pub fn new(seed: Option<u64>) -> Self {
Self {
rng: seed.map_or_else(fastrand::Rng::new, fastrand::Rng::with_seed),
mu: 1.0,
}
}
}
impl<P, U, E> Algorithm<P, EnsembleStatus, U, E> for ESS
where
    P: LogDensity<U, E>,
{
    type Summary = MCMCSummary;
    type Config = ESSConfig;
    type Init = ESSInit;

    /// Places one walker at each starting position using the configured chain-storage mode,
    /// resets `mu` to the configured value, and evaluates the log-density of every walker.
    fn initialize(
        &mut self,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        init: &Self::Init,
        config: &Self::Config,
    ) -> Result<(), E> {
        let storage = config.chain_storage;
        status.walkers = init
            .walkers
            .iter()
            .map(|position| {
                let mut walker = Walker::new(position.clone());
                walker.set_chain_storage(storage);
                walker
            })
            .collect();
        self.mu = config.mu;
        status.log_density_latest(problem, args)?;
        status.set_message().initialize();
        Ok(())
    }

    /// Draws one move according to the configured weights and applies it to every walker.
    fn step(
        &mut self,
        current_step: usize,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        let weights: Vec<Float> = config.moves.iter().map(|&(_, weight)| weight).collect();
        let index = match self.rng.choice_weighted(&weights) {
            Some(index) => index,
            None => {
                unreachable!("ESSConfig validates that move weights contain a positive entry")
            }
        };
        let (chosen_move, _) = config.moves[index];
        chosen_move.step(
            current_step,
            config.n_adaptive,
            config.max_steps,
            &mut self.mu,
            problem,
            &config.transform,
            args,
            status,
            &mut self.rng,
        )
    }

    /// Builds the run summary. A custom status whose text reports that the maximum number
    /// of steps was reached is converted into a success, since exhausting the step budget
    /// is the normal way for a sampling run to end.
    fn summarize(
        &self,
        _current_step: usize,
        _problem: &P,
        status: &EnsembleStatus,
        _args: &U,
        _init: &Self::Init,
        config: &Self::Config,
    ) -> Result<Self::Summary, E> {
        let mut message = status.message().clone();
        let reached_step_limit = matches!(message.status_type, StatusType::Custom)
            && message.text.contains("Maximum number of steps reached");
        if reached_step_limit {
            let text = message.text.clone();
            message.succeed_with_message(&text);
        }
        Ok(MCMCSummary {
            bounds: None,
            parameter_names: config.parameter_names.clone(),
            message,
            chain: status.get_chain(None, None),
            chain_storage: config.chain_storage,
            n_f_evals: status.n_f_evals,
            n_g_evals: status.n_g_evals,
            n_h_evals: 0,
            dimension: status.dimension(),
        })
    }
}
/// Partitions the rows of `data` into `n_clusters` groups with Lloyd's k-means algorithm
/// and returns the cluster index of each row.
///
/// Centroids start at random points inside the per-column bounding box of `data` and are
/// refined for a fixed number of iterations.
///
/// A centroid that receives no rows is moved onto the row farthest from its currently
/// assigned centroid. Each row is used at most once per iteration. This gives the cluster
/// a real chance of capturing data on the next assignment pass, instead of parking the
/// centroid at an arbitrary location.
///
/// # Panics
///
/// Panics if `n_clusters` is zero (no centroid to assign rows to).
#[allow(clippy::unwrap_used)]
fn kmeans(n_clusters: usize, data: &DMatrix<Float>, rng: &mut Rng) -> Vec<usize> {
    const N_ITERATIONS: usize = 50;
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let limits = data
        .column_iter()
        .map(|col| (col.min(), col.max()))
        .collect::<Vec<_>>();
    let mut centroids: Vec<DVector<Float>> = (0..n_clusters)
        .map(|_| generate_random_vector_in_limits(&limits, rng))
        .collect();
    let mut labels = vec![0; n_walkers];
    // Squared distance from each row to the centroid it was assigned to.
    let mut distances: Vec<Float> = vec![0.0; n_walkers];
    for _ in 0..N_ITERATIONS {
        // Assignment: each row goes to its nearest centroid (the first one on ties).
        for (i, row) in data.row_iter().enumerate() {
            let point = row.transpose();
            let (label, distance) = centroids
                .iter()
                .map(|centroid| (&point - centroid).norm_squared())
                .enumerate()
                // `total_cmp` cannot fail, so a NaN coordinate no longer panics here.
                .min_by(|(_, a), (_, b)| a.total_cmp(b))
                // There is at least one centroid whenever `n_clusters >= 1`.
                .unwrap();
            labels[i] = label;
            distances[i] = distance;
        }
        // Update: move each centroid to the mean of its rows; relocate empty ones.
        let mut used = vec![false; n_walkers];
        for (j, centroid) in centroids.iter_mut().enumerate() {
            let mut sum = DVector::zeros(n_parameters);
            let mut count = 0;
            for (&label, row) in labels.iter().zip(data.row_iter()) {
                if label == j {
                    sum += row.transpose();
                    count += 1;
                }
            }
            if count > 0 {
                sum /= count as Float;
                *centroid = sum;
            } else if let Some(farthest) = (0..n_walkers)
                .filter(|&i| !used[i])
                .max_by(|&a, &b| distances[a].total_cmp(&distances[b]))
            {
                used[farthest] = true;
                *centroid = data.row(farthest).transpose();
            }
        }
    }
    labels
}
/// Sample covariance matrix of the rows of `m`, treating each column as one observation.
///
/// Returns an `m.nrows() x m.nrows()` matrix normalised by `m.ncols() - 1`. With a single
/// column the result is NaN (0 / 0).
fn cov(m: &DMatrix<Float>) -> DMatrix<Float> {
    let mean: DVector<Float> = m
        .row_iter()
        .map(|row| row.mean())
        .collect::<Vec<Float>>()
        .into();
    // Subtract each row's mean from every column (outer product with a row of ones).
    let centered = m.clone() - mean * DMatrix::from_element(1, m.ncols(), 1.0);
    &centered * centered.transpose() / (m.ncols() as Float - 1.0)
}
/// Responsibility-weighted sufficient statistics of a Gaussian mixture.
///
/// `data` is `n_walkers x n_parameters` and `resp` is `n_walkers x n_components`. Returns:
/// - `nk`: the soft count of each component, plus `10 * EPSILON` to avoid division by zero.
/// - `means`: an `n_components x n_parameters` matrix of weighted means.
/// - one full weighted covariance per component, with `reg_covar` added to its diagonal.
///
/// # Panics
///
/// Panics if `data` and `resp` have different numbers of rows.
fn estimate_gaussian_parameters(
    data: &DMatrix<Float>,
    resp: &DMatrix<Float>,
    reg_covar: Float,
) -> (DVector<Float>, DMatrix<Float>, Vec<DMatrix<Float>>) {
    assert_eq!(data.nrows(), resp.nrows());
    let n_parameters = data.ncols();
    let nk = resp.row_sum_tr().add_scalar(10.0 * Float::EPSILON);
    let mut means: DMatrix<Float> = resp.transpose() * data;
    for mut column in means.column_iter_mut() {
        column.component_div_assign(&nk);
    }
    let n_components = means.nrows();
    let mut covariances = Vec::with_capacity(n_components);
    for k in 0..n_components {
        let mean_k = means.row(k);
        let centered_rows: Vec<_> = data.row_iter().map(|row| row - mean_k).collect();
        let centered = DMatrix::from_rows(&centered_rows);
        let weighted_columns: Vec<_> = centered
            .row_iter()
            .zip(resp.column(k).iter())
            .map(|(row, &weight)| row.scale(weight).transpose())
            .collect();
        let weighted_t = DMatrix::from_columns(&weighted_columns);
        let mut covariance = (&weighted_t * &centered).unscale(nk[k]);
        for d in 0..n_parameters {
            covariance[(d, d)] += reg_covar;
        }
        covariances.push(covariance);
    }
    (nk, means, covariances)
}
/// Variational update of the stick-breaking Beta parameters of the Dirichlet-process
/// mixture weights.
///
/// Returns `(a, b)` with `a[k] = 1 + nk[k]` and
/// `b[k] = weight_concentration_prior + sum_{j > k} nk[j]`. An empty `nk` yields two empty
/// vectors; the previous implementation underflowed `n_components - 1` in that case.
fn estimate_weights(
    nk: &DVector<Float>,
    weight_concentration_prior: Float,
) -> (DVector<Float>, DVector<Float>) {
    let alpha = nk.map(|count| count + 1.0);
    let mut beta = DVector::zeros(nk.len());
    // Walk backwards so `tail` is the sum of the counts strictly after `k`, accumulated from
    // the last component forwards (same summation order as a reversed cumulative sum).
    let mut tail: Float = 0.0;
    for k in (0..nk.len()).rev() {
        beta[k] = tail + weight_concentration_prior;
        tail += nk[k];
    }
    (alpha, beta)
}
/// Variational update of the Gaussian-mean factors.
///
/// Returns `(mean_precision, means)`:
/// - `mean_precision[k] = mean_precision_prior + nk[k]`;
/// - row `k` of `means` is
///   `(mean_precision_prior * mean_prior + nk[k] * xk.row(k)) / mean_precision[k]`.
///
/// # Panics
///
/// Panics if `nk` and `xk` disagree on the number of components, or `mean_prior` and `xk`
/// disagree on the dimension.
fn estimate_means(
    nk: &DVector<Float>,
    xk: &DMatrix<Float>,
    mean_prior: &DVector<Float>,
    mean_precision_prior: Float,
) -> (DVector<Float>, DMatrix<Float>) {
    assert_eq!(nk.len(), xk.nrows());
    assert_eq!(mean_prior.len(), xk.ncols());
    let mean_precision = nk.map(|count| count + mean_precision_prior);
    let scaled_prior = mean_prior.scale(mean_precision_prior);
    let means = DMatrix::from_fn(xk.nrows(), xk.ncols(), |k, d| {
        (scaled_prior[d] + xk[(k, d)] * nk[k]) / mean_precision[k]
    });
    (mean_precision, means)
}
/// Variational update of the Wishart precision factors of each component.
///
/// For component `k`, let `diff = xk.row(k) - mean_prior` and
/// `dof[k] = degrees_of_freedom_prior + nk[k]`. The covariance is
/// `(covariance_prior + nk[k] * sk[k] + nk[k] * mean_precision_prior / mean_precision[k] * diff diff^T) / dof[k]`.
///
/// Returns `(degrees_of_freedom, covariances, precisions_cholesky)`, where
/// `precisions_cholesky[k]` is `(L^{-1})^T` for the lower Cholesky factor `L` of
/// `covariances[k]`.
///
/// # Panics
///
/// Panics if the input shapes disagree, or if a covariance is not positive definite.
#[allow(clippy::too_many_arguments)]
fn estimate_precisions(
    nk: &DVector<Float>,
    xk: &DMatrix<Float>,
    sk: &[DMatrix<Float>],
    degrees_of_freedom_prior: Float,
    covariance_prior: &DMatrix<Float>,
    mean_prior: &DVector<Float>,
    mean_precision_prior: Float,
    mean_precision: &DVector<Float>,
) -> (DVector<Float>, Vec<DMatrix<Float>>, Vec<DMatrix<Float>>) {
    let n_components = nk.len();
    let n_parameters = mean_prior.len();
    assert_eq!(xk.nrows(), n_components);
    assert_eq!(xk.ncols(), n_parameters);
    assert_eq!(covariance_prior.nrows(), n_parameters);
    assert_eq!(covariance_prior.ncols(), n_parameters);
    assert_eq!(mean_precision.len(), n_components);
    let degrees_of_freedom = nk.map(|x| x + degrees_of_freedom_prior);
    let mut covariances = Vec::with_capacity(n_components);
    let mut precisions_cholesky = Vec::with_capacity(n_components);
    for k in 0..n_components {
        let nk_k = nk[k];
        let xk_k = xk.row(k).transpose();
        let sk_k = &sk[k];
        let mean_precision_k = mean_precision[k];
        let degrees_of_freedom_k = degrees_of_freedom[k];
        let diff = &xk_k - mean_prior;
        let outer = &diff * diff.transpose();
        let covariance = (covariance_prior
            + (sk_k * nk_k)
            + outer * (nk_k * mean_precision_prior / mean_precision_k))
            .unscale(degrees_of_freedom_k);
        covariances.push(covariance.clone());
        // A non-positive-definite covariance means the mixture fit has broken down; there is
        // no recoverable path here.
        #[allow(clippy::expect_used)]
        let cholesky = Cholesky::new(covariance).expect("Cholesky decomposition failed");
        let l = cholesky.l();
        let id = DMatrix::identity(n_parameters, n_parameters);
        #[allow(clippy::expect_used)]
        let solved = l
            .solve_lower_triangular(&id)
            .expect("Cholesky solve_lower_triangular failed");
        precisions_cholesky.push(solved.transpose());
    }
    (degrees_of_freedom, covariances, precisions_cholesky)
}
/// Log-determinant of each (triangular) precision Cholesky factor, computed as the sum of
/// the logs of its first `n_parameters` diagonal entries.
fn log_det_cholesky(precisions_cholesky: &[DMatrix<Float>], n_parameters: usize) -> DVector<Float> {
    let mut log_dets = DVector::zeros(precisions_cholesky.len());
    for (k, factor) in precisions_cholesky.iter().enumerate() {
        log_dets[k] = (0..n_parameters)
            .map(|d| factor[(d, d)].ln())
            .sum::<Float>();
    }
    log_dets
}
/// Log-density of each row of `data` under each Gaussian component.
///
/// Component `k` has mean `means.row(k)` and precision Cholesky factor
/// `precisions_cholesky[k]`, so `y = (x - mu_k) * P_k` is the whitened residual. Returns an
/// `n_walkers x n_components` matrix with entry
/// `(i, k) = -0.5 * (d * ln(2 pi) + |y|^2) + ln det P_k`.
fn log_gaussian_prob(
    data: &DMatrix<Float>,
    means: &DMatrix<Float>,
    precisions_cholesky: &[DMatrix<Float>],
) -> DMatrix<Float> {
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let n_components = means.nrows();
    let log_det = log_det_cholesky(precisions_cholesky, n_parameters);
    let mut log_prob = DMatrix::zeros(n_walkers, n_components);
    for k in 0..n_components {
        let mu_k = means.row(k);
        let prec_chol_k = &precisions_cholesky[k];
        for i in 0..n_walkers {
            let centered = data.row(i) - mu_k;
            let y = &centered * prec_chol_k;
            let sq_sum = y.map(|val| val * val).sum();
            log_prob[(i, k)] = (-0.5 as Float).mul_add(
                (n_parameters as Float).mul_add(Float::ln(2.0 * PI), sq_sum),
                log_det[k],
            );
        }
    }
    log_prob
}
/// Variational E-step of the Dirichlet-process Gaussian mixture.
///
/// Adds the expected log mixture weights to the expected log Gaussian densities, then
/// normalises each row with log-sum-exp.
///
/// Returns `(mean over rows of the per-row log normaliser, log responsibilities)`. The log
/// responsibilities form an `n_walkers x n_components` matrix whose rows log-sum-exp to zero.
///
/// The `spec_math` special functions are called on `f64` arguments. The casts are redundant
/// when `Float` is `f64`, hence the `unnecessary_cast` allowance.
#[allow(clippy::unnecessary_cast)]
fn e_step(
data: &DMatrix<Float>,
means: &DMatrix<Float>,
precisions_cholesky: &[DMatrix<Float>],
mean_precision: &DVector<Float>,
degrees_of_freedom: &DVector<Float>,
weight_concentration: &(DVector<Float>, DVector<Float>),
) -> (Float, DMatrix<Float>) {
let n_walkers = data.nrows();
let n_parameters = data.ncols();
let n_components = means.nrows();
let estimated_log_prob = {
// Gaussian log-density per (walker, component), minus 0.5 * d * ln(dof_k).
// NOTE(review): presumably this compensates for the covariances being divided by the
// degrees of freedom in `estimate_precisions` — confirm.
let mut log_gauss = log_gaussian_prob(data, means, precisions_cholesky);
log_gauss.row_iter_mut().for_each(|mut row| {
row -= degrees_of_freedom
.map(|x| 0.5 * (n_parameters as Float) * x.ln())
.transpose()
});
// log_lambda[k] = d * ln 2 + sum_{j=0}^{d-1} digamma((dof_k - j) / 2),
// accumulated with `j` in the outer loop.
let log_lambda = {
let mut res: DVector<Float> = DVector::zeros(n_components);
for j in 0..n_parameters {
for k in 0..n_components {
res[k] += spec_math::Gamma::digamma(
&((0.5 * (degrees_of_freedom[k] - j as Float)) as f64),
) as Float
}
}
res.map(|r| (n_parameters as Float).mul_add(Float::ln(2.0), r))
};
// Add 0.5 * (log_lambda[k] - d / mean_precision[k]) to every row.
log_gauss.row_iter_mut().for_each(|mut row| {
row += (0.5 * (&log_lambda - mean_precision.map(|mu| n_parameters as Float / mu)))
.transpose()
});
log_gauss
};
// Expected log stick-breaking weights with Beta(a_k, b_k) sticks:
// digamma(a_k) - digamma(a_k + b_k) + sum_{j<k} [digamma(b_j) - digamma(a_j + b_j)].
let estimated_log_weights = {
let a = &weight_concentration.0;
let b = &weight_concentration.1;
let n = a.len();
let digamma_sum = (a + b).map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
let digamma_a = a.map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
let digamma_b = b.map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
// Exclusive prefix sum: cumulative[k] only covers components before k.
let mut cumulative = Vec::with_capacity(n);
let mut acc = 0.0;
cumulative.push(0.0);
for i in 0..n - 1 {
acc += digamma_b[i] - digamma_sum[i];
cumulative.push(acc);
}
DVector::from_iterator(
n,
(0..n).map(|i| digamma_a[i] - digamma_sum[i] + cumulative[i]),
)
};
let mut weighted_log_prob = estimated_log_prob;
weighted_log_prob
.row_iter_mut()
.for_each(|mut row| row += &estimated_log_weights.transpose());
// Per-walker log normaliser via log-sum-exp over components.
let log_prob_norm = DVector::from_iterator(
n_walkers,
weighted_log_prob
.row_iter()
.map(|row| logsumexp::LogSumExp::ln_sum_exp(row.iter())),
);
// Log responsibilities: subtract each walker's normaliser from its row.
let mut log_resp = weighted_log_prob;
log_resp
.column_iter_mut()
.for_each(|mut col| col -= &log_prob_norm);
(log_prob_norm.mean(), log_resp)
}
/// Output of [`dpgm`]: hard component assignments plus the fitted component parameters.
#[derive(Clone)]
struct DPGMResult {
    /// Most probable component for each walker (row of the position matrix).
    labels: Vec<usize>,
    /// Fitted mean of each component, in internal coordinates.
    means: Vec<DVector<Float>>,
    /// Fitted covariance of each component, already divided by its degrees of freedom.
    covariances: Vec<DMatrix<Float>>,
}
/// Fits a variational Dirichlet-process Gaussian mixture with full covariances to the latest
/// walker positions (internal coordinates).
///
/// Returns hard labels plus the fitted component means and covariances.
///
/// Priors:
/// - weight concentration `1 / n_components`;
/// - mean precision `1`;
/// - mean prior equal to the ensemble mean;
/// - `n_parameters` degrees of freedom;
/// - the sample covariance of the positions as the covariance prior.
///
/// Responsibilities are initialised from [`kmeans`]. The fit then alternates E- and M-steps
/// for at most 100 iterations, stopping once the lower bound changes by less than `1e-3`. A
/// final E-step yields the labels, each the argmax of that walker's responsibilities.
///
/// # Panics
///
/// Panics via [`estimate_precisions`] if a component covariance is not positive definite.
fn dpgm(
    n_components: usize,
    ensemble: &EnsembleStatus,
    transform: &Option<Box<dyn Transform>>,
    rng: &mut Rng,
) -> DPGMResult {
    const REG_COVAR: Float = 1e-6;
    const MAX_ITERATIONS: usize = 100;
    const TOLERANCE: Float = 1e-3;
    let (n_walkers, _, n_parameters) = ensemble.dimension();
    let data = ensemble.get_latest_internal_position_matrix(transform);
    let weight_concentration_prior = 1.0 / n_components as Float;
    let mean_precision_prior: Float = 1.0;
    let mean_prior = ensemble.internal_mean(transform);
    let degrees_of_freedom_prior = n_parameters as Float;
    let covariance_prior = cov(&data.transpose());
    // Hard initial responsibilities from k-means.
    let mut resp: DMatrix<Float> = DMatrix::zeros(n_walkers, n_components);
    let labels = kmeans(n_components, &data, rng);
    for (i, &cluster_id) in labels.iter().enumerate() {
        resp[(i, cluster_id)] = 1.0;
    }
    // Initial M-step from the k-means responsibilities.
    let (mut nk, mut xk, mut sk) = estimate_gaussian_parameters(&data, &resp, REG_COVAR);
    let mut weight_concentration = estimate_weights(&nk, weight_concentration_prior);
    let (mut mean_precision, mut means) =
        estimate_means(&nk, &xk, &mean_prior, mean_precision_prior);
    let (mut degrees_of_freedom, mut covariances, mut precisions_cholesky) = estimate_precisions(
        &nk,
        &xk,
        &sk,
        degrees_of_freedom_prior,
        &covariance_prior,
        &mean_prior,
        mean_precision_prior,
        &mean_precision,
    );
    let mut lower_bound = Float::NEG_INFINITY;
    for _ in 0..MAX_ITERATIONS {
        let prev_lower_bound = lower_bound;
        let (_, log_resp) = e_step(
            &data,
            &means,
            &precisions_cholesky,
            &mean_precision,
            &degrees_of_freedom,
            &weight_concentration,
        );
        (nk, xk, sk) = estimate_gaussian_parameters(&data, &log_resp.map(Float::exp), REG_COVAR);
        weight_concentration = estimate_weights(&nk, weight_concentration_prior);
        (mean_precision, means) = estimate_means(&nk, &xk, &mean_prior, mean_precision_prior);
        (degrees_of_freedom, covariances, precisions_cholesky) = estimate_precisions(
            &nk,
            &xk,
            &sk,
            degrees_of_freedom_prior,
            &covariance_prior,
            &mean_prior,
            mean_precision_prior,
            &mean_precision,
        );
        // The bound uses the responsibilities from the E-step that preceded this M-step.
        lower_bound = dpgm_lower_bound(
            &log_resp,
            &precisions_cholesky,
            &degrees_of_freedom,
            &weight_concentration,
            &mean_precision,
            n_parameters,
        );
        if (lower_bound - prev_lower_bound).abs() < TOLERANCE {
            break;
        }
    }
    // Final E-step so the labels are consistent with the last fitted parameters.
    let (_, log_resp) = e_step(
        &data,
        &means,
        &precisions_cholesky,
        &mean_precision,
        &degrees_of_freedom,
        &weight_concentration,
    );
    DPGMResult {
        labels: log_resp
            .row_iter()
            .map(|row| row.transpose().argmax().0)
            .collect(),
        means: means
            .row_iter()
            .map(|row| row.transpose())
            .collect::<Vec<DVector<Float>>>(),
        covariances,
    }
}

/// Variational lower bound (constant terms dropped) used by [`dpgm`] as its convergence
/// criterion.
///
/// The bound is the sum of three terms:
/// - the responsibility entropy `-sum r ln r`;
/// - minus the summed Wishart log-normalisers;
/// - minus the summed negative log-Beta normalisers of the stick weights.
///
/// From that sum it subtracts `0.5 * d * sum ln(mean_precision)`.
///
/// The `spec_math` special functions are called on `f64` arguments; the casts are redundant
/// when `Float` is `f64`.
#[allow(clippy::unnecessary_cast)]
fn dpgm_lower_bound(
    log_resp: &DMatrix<Float>,
    precisions_cholesky: &[DMatrix<Float>],
    degrees_of_freedom: &DVector<Float>,
    weight_concentration: &(DVector<Float>, DVector<Float>),
    mean_precision: &DVector<Float>,
    n_parameters: usize,
) -> Float {
    let log_det_precisions_cholesky = log_det_cholesky(precisions_cholesky, n_parameters)
        - degrees_of_freedom
            .map(Float::ln)
            .scale(0.5 * n_parameters as Float);
    let log_wishart_norm = {
        let mut log_wishart_norm = degrees_of_freedom.component_mul(&log_det_precisions_cholesky);
        log_wishart_norm +=
            degrees_of_freedom.scale(0.5 * Float::ln(2.0) * n_parameters as Float);
        let gammaln_term: DVector<Float> = degrees_of_freedom.map(|dof| {
            (0..n_parameters)
                .map(|i| spec_math::Gamma::lgamma(&((0.5 * (dof - i as Float)) as f64)) as Float)
                .sum()
        });
        log_wishart_norm += gammaln_term;
        -log_wishart_norm
    };
    let log_norm_weight = -((0..weight_concentration.0.len())
        .map(|i| {
            spec_math::Beta::lbeta(
                &(weight_concentration.0[i] as f64),
                weight_concentration.1[i] as f64,
            )
        })
        .sum::<f64>()) as Float;
    (0.5 * (n_parameters as Float)).mul_add(
        -mean_precision.map(|mp| mp.ln()).sum(),
        -log_resp.map(|lr| lr.exp() * lr).sum() - log_wishart_norm.sum(),
    ) - log_norm_weight
}
// Unit tests for the ESS moves and configuration, the sampler loop and chain storage, and
// the k-means / DPGM helpers used by the Global move.
#[cfg(test)]
mod tests {
use super::*;
use crate::{
core::{Callbacks, MaxSteps},
test_functions::Rosenbrock,
traits::Algorithm,
};
/// Builds `n_walkers` positions in `dim` dimensions; walker `i` sits at `(i + 1, ..., i + 1)`.
fn make_walkers(n_walkers: usize, dim: usize) -> Vec<DVector<Float>> {
(0..n_walkers)
.map(|i| DVector::from_element(dim, i as Float + 1.0))
.collect()
}
// The move constructors pair the expected variant (and Global defaults) with the weight.
#[test]
fn test_essmove_constructors() {
let d = ESSMove::differential(0.5);
assert!(matches!(d.0, ESSMove::Differential));
assert_eq!(d.1, 0.5);
let g = ESSMove::gaussian(1.0);
assert!(matches!(g.0, ESSMove::Gaussian));
let gl = ESSMove::global(2.0);
if let ESSMove::Global {
scale,
rescale_cov,
n_components,
} = gl.0
{
assert_eq!(scale, 1.0);
assert_eq!(rescale_cov, 0.001);
assert_eq!(n_components, 5);
} else {
panic!("expected Global");
}
assert_eq!(gl.1, 2.0);
}
// Default configuration values, and that the builder methods overwrite them.
#[test]
fn test_essconfig_defaults_and_builders() {
let walkers = make_walkers(3, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default();
assert_eq!(init.walkers.len(), 3);
assert_eq!(cfg.moves.len(), 1);
assert_eq!(cfg.n_adaptive, 0);
assert_eq!(cfg.max_steps, 10000);
assert_eq!(cfg.mu, 1.0);
let moves = vec![ESSMove::gaussian(1.0), ESSMove::differential(1.0)];
let cfg = cfg
.with_moves(&moves)
.unwrap()
.with_n_adaptive(5)
.with_max_steps(42)
.with_mu(4.1)
.unwrap();
assert_eq!(cfg.moves.len(), 2);
assert_eq!(cfg.n_adaptive, 5);
assert_eq!(cfg.max_steps, 42);
assert!((cfg.mu - 4.1).abs() < 1e-12);
}
// Negative, empty, and all-zero move weights are rejected with descriptive messages.
#[test]
fn test_ess_rejects_invalid_move_weights() {
let err = match ESSConfig::default()
.with_moves([ESSMove::gaussian(-1.0), ESSMove::differential(1.0)])
{
Err(err) => err,
Ok(_) => panic!("negative ESS move weights should be rejected"),
};
assert!(err.to_string().contains("finite and non-negative"));
let err = match ESSConfig::default().with_moves(Vec::<WeightedESSMove>::new()) {
Err(err) => err,
Ok(_) => panic!("empty ESS move lists should be rejected"),
};
assert!(err.to_string().contains("must not be empty"));
let err = match ESSConfig::default()
.with_moves([ESSMove::gaussian(0.0), ESSMove::differential(0.0)])
{
Err(err) => err,
Ok(_) => panic!("zero-sum ESS move weights should be rejected"),
};
assert!(err.to_string().contains("sum to a positive finite value"));
}
// Too few walkers or walkers of mixed dimension are rejected by `ESSInit::new`.
#[test]
fn test_ess_rejects_invalid_walker_inputs() {
let err = match ESSInit::new(Vec::new()) {
Err(err) => err,
Ok(_) => panic!("empty ESS walker lists should be rejected"),
};
assert!(err.to_string().contains("at least 3 walkers"));
let err = match ESSInit::new(vec![
DVector::from_row_slice(&[1.0, 2.0]),
DVector::from_row_slice(&[3.0, 4.0]),
]) {
Err(err) => err,
Ok(_) => panic!("too-few ESS walkers should be rejected"),
};
assert!(err.to_string().contains("at least 3 walkers"));
let err = match ESSInit::new(vec![
DVector::from_row_slice(&[1.0, 2.0]),
DVector::from_row_slice(&[3.0]),
DVector::from_row_slice(&[4.0, 5.0]),
]) {
Err(err) => err,
Ok(_) => panic!("mixed-dimension ESS walkers should be rejected"),
};
assert!(err.to_string().contains("same dimension"));
}
// Initialisation evaluates each walker once, and the summary reports that count.
#[test]
fn test_ess_initialize_and_summarize() {
let mut ess = ESS::default();
let walkers = make_walkers(3, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default();
let mut status = EnsembleStatus::default();
let f = Rosenbrock { n: 2 };
ess.initialize(&f, &mut status, &(), &init, &cfg).unwrap();
assert_eq!(status.walkers.len(), 3);
assert_eq!(status.n_f_evals, 3);
let summary = ess.summarize(0, &f, &status, &(), &init, &cfg).unwrap();
assert_eq!(summary.dimension, status.dimension());
assert_eq!(summary.n_f_evals, 3);
}
// Smoke test: one Differential step succeeds and sets the step message.
#[test]
fn test_differential_step_runs() {
let mut ess = ESS::default();
let walkers = make_walkers(3, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default();
let mut status = EnsembleStatus::default();
let f = Rosenbrock { n: 2 };
ess.initialize(&f, &mut status, &(), &init, &cfg).unwrap();
let result = ess.step(0, &f, &mut status, &(), &cfg);
assert!(result.is_ok());
assert!(status.message().to_string().contains("Differential"));
}
// Smoke test: one Gaussian step succeeds and sets the step message.
#[test]
fn test_gaussian_step_runs() {
let mut ess = ESS::default();
let walkers = make_walkers(6, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default()
.with_moves(vec![ESSMove::gaussian(1.0)])
.unwrap();
let mut status = EnsembleStatus::default();
let f = Rosenbrock { n: 2 };
ess.initialize(&f, &mut status, &(), &init, &cfg).unwrap();
let result = ess.step(0, &f, &mut status, &(), &cfg);
assert!(result.is_ok());
assert!(status.message().to_string().contains("Gaussian"));
}
// Smoke test: one Global step (with a 3-component mixture fit) succeeds.
#[test]
fn test_global_step_runs() {
let mut ess = ESS::default();
let walkers = make_walkers(100, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default()
.with_moves(vec![ESSMove::custom_global(
1.0,
Some(1.0),
Some(0.001),
Some(3),
)
.unwrap()])
.unwrap();
let mut status = EnsembleStatus::default();
let f = Rosenbrock { n: 2 };
ess.initialize(&f, &mut status, &(), &init, &cfg).unwrap();
let result = ess.step(0, &f, &mut status, &(), &cfg);
assert!(result.is_ok());
assert!(status.message().to_string().contains("Global"));
}
// With `max_steps = 0` no expansion or contraction can occur, so the adaptation guard must
// leave `mu` untouched (no 0/0 division).
#[test]
fn adaptive_mu_stays_finite_when_no_expand_or_contract_updates_occur() {
let mut rng = Rng::with_seed(0);
let mut status = EnsembleStatus::default();
let problem = Rosenbrock { n: 2 };
status.walkers = ESSInit::new(make_walkers(3, 2))
.unwrap()
.walkers
.into_iter()
.map(Walker::new)
.collect();
status.log_density_latest(&problem, &()).unwrap();
let mut mu = 1.5;
ESSMove::Differential
.step(
0,
1,
0,
&mut mu,
&problem,
&None,
&(),
&mut status,
&mut rng,
)
.unwrap();
assert!(mu.is_finite());
assert_eq!(mu, 1.5);
}
// Hitting the step limit is reported as a successful run.
#[test]
fn summary_marks_max_steps_as_success_and_counts_evals() {
let mut ess = ESS::default();
let walkers = make_walkers(4, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default();
let result = ess
.process(
&Rosenbrock { n: 2 },
&(),
init,
cfg,
Callbacks::empty().with_terminator(MaxSteps(2)),
)
.unwrap();
assert!(result.n_f_evals >= 4);
assert_eq!(result.n_g_evals, 0);
assert!(result.message.success());
assert!(result
.message
.text
.contains("Maximum number of steps reached"));
}
// Rolling storage keeps at most `window` samples per walker.
#[test]
fn rolling_chain_storage_limits_retained_history() {
let walkers = make_walkers(4, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default().with_chain_storage(ChainStorageMode::Rolling { window: 3 });
let mut ess = ESS::default();
let result = ess
.process(
&Rosenbrock { n: 2 },
&(),
init,
cfg,
Callbacks::empty().with_terminator(MaxSteps(4)),
)
.unwrap();
assert_eq!(
result.chain_storage,
ChainStorageMode::Rolling { window: 3 }
);
assert!(result.chain.iter().all(|walker| walker.len() <= 3));
assert_eq!(result.dimension.1, 3);
}
// Sampled storage keeps every `keep_every`-th sample, capped at `max_samples`.
#[test]
fn sampled_chain_storage_downsamples_retained_history() {
let walkers = make_walkers(4, 2);
let init = ESSInit::new(walkers).unwrap();
let cfg = ESSConfig::default().with_chain_storage(ChainStorageMode::Sampled {
keep_every: 2,
max_samples: Some(3),
});
let mut ess = ESS::default();
let result = ess
.process(
&Rosenbrock { n: 2 },
&(),
init,
cfg,
Callbacks::empty().with_terminator(MaxSteps(4)),
)
.unwrap();
assert_eq!(
result.chain_storage,
ChainStorageMode::Sampled {
keep_every: 2,
max_samples: Some(3),
}
);
assert!(result.chain.iter().all(|walker| walker.len() <= 3));
assert_eq!(result.dimension.1, 3);
}
// k-means separates two well-separated 3-point blobs.
#[test]
fn test_kmeans_two_clusters() {
let mut rng = Rng::with_seed(0);
let points_a = [
DVector::from_vec(vec![0.0, 0.1]).transpose(),
DVector::from_vec(vec![0.2, -0.1]).transpose(),
DVector::from_vec(vec![-0.1, 0.0]).transpose(),
];
let points_b = [
DVector::from_vec(vec![10.0, 10.1]).transpose(),
DVector::from_vec(vec![9.8, 9.9]).transpose(),
DVector::from_vec(vec![10.2, 9.9]).transpose(),
];
let mut rows = Vec::new();
rows.extend(points_a.iter().cloned());
rows.extend(points_b.iter().cloned());
let data = DMatrix::from_rows(&rows);
let labels = super::kmeans(2, &data, &mut rng);
assert_eq!(labels.len(), 6);
assert_eq!(labels[0], labels[1]);
assert_eq!(labels[1], labels[2]);
assert_eq!(labels[3], labels[4]);
assert_eq!(labels[4], labels[5]);
assert_ne!(labels[0], labels[3]);
}
// The DPGM fit on samples from two known Gaussians recovers their means, covariances
// (within tolerance), and most of the labels.
#[test]
#[allow(clippy::field_reassign_with_default)]
fn test_dpgm_recovers_means_covariances_two_blobs() {
use crate::core::utils::SampleFloat;
let mu_a = DVector::from_vec(vec![0.0, 0.0]);
let mu_b = DVector::from_vec(vec![3.0, -2.0]);
let cov_a = DMatrix::from_row_slice(2, 2, &[0.20, 0.05, 0.05, 0.10]);
let cov_b = DMatrix::from_row_slice(2, 2, &[0.30, -0.04, -0.04, 0.50]);
let n_a = 80usize;
let n_b = 70usize;
let mut rng = Rng::with_seed(0);
let mut positions: Vec<Walker> = Vec::with_capacity(n_a + n_b);
for _ in 0..n_a {
let x = rng.mv_normal(&mu_a, &cov_a);
positions.push(Walker::new(x));
}
for _ in 0..n_b {
let x = rng.mv_normal(&mu_b, &cov_b);
positions.push(Walker::new(x));
}
let mut status = EnsembleStatus::default();
status.walkers = positions;
let mut rng2 = Rng::with_seed(0);
let res = super::dpgm(2, &status, &None, &mut rng2);
assert_eq!(res.labels.len(), n_a + n_b);
assert_eq!(res.means.len(), 2);
assert_eq!(res.covariances.len(), 2);
assert_eq!(res.covariances[0].nrows(), 2);
assert_eq!(res.covariances[0].ncols(), 2);
// Component order is arbitrary; match fitted components to blobs by mean distance.
let d0_a = (&res.means[0] - &mu_a).norm();
let d1_a = (&res.means[1] - &mu_a).norm();
let (idx_a, idx_b) = if d0_a <= d1_a { (0, 1) } else { (1, 0) };
assert!((&res.means[idx_a] - &mu_a).norm() < 0.25);
assert!((&res.means[idx_b] - &mu_b).norm() < 0.25);
let cov_a_hat = &res.covariances[idx_a];
let cov_b_hat = &res.covariances[idx_b];
for i in 0..2 {
let a_true = cov_a[(i, i)];
let a_est = cov_a_hat[(i, i)];
assert!((a_est - a_true).abs() / a_true < 0.35);
let b_true = cov_b[(i, i)];
let b_est = cov_b_hat[(i, i)];
assert!((b_est - b_true).abs() / b_true < 0.35);
}
assert!((cov_a_hat[(0, 1)] - cov_a[(0, 1)]).abs() < 0.1);
assert!((cov_b_hat[(0, 1)] - cov_b[(0, 1)]).abs() < 0.1);
let count_a = res.labels[..n_a].iter().filter(|&&l| l == idx_a).count();
let count_b = res.labels[n_a..].iter().filter(|&&l| l == idx_b).count();
assert!(count_a as Float > 0.9 * n_a as Float);
assert!(count_b as Float > 0.9 * n_b as Float);
}
}