use crate::{
    algorithms::mcmc::{EnsembleStatus, Walker},
    core::{
        utils::{generate_random_vector_in_limits, RandChoice, SampleFloat},
        MCMCSummary, Point,
    },
    traits::{Algorithm, LogDensity, Status, SupportsTransform, Transform},
    DMatrix, DVector, Float, PI,
};
use fastrand::Rng;
use nalgebra::Cholesky;
use parking_lot::RwLock;
use std::sync::Arc;

/// A move type usable by the Ensemble Slice Sampler (see [`ESS`]).
#[derive(Copy, Clone)]
pub enum ESSMove {
    /// Direction taken from the scaled difference of two complementary walkers.
    Differential,
    /// Direction drawn from a Gaussian approximation of the complementary walkers.
    Gaussian,
    /// Direction drawn from a Dirichlet-process Gaussian mixture fit to the ensemble.
    Global {
        scale: Float,
        rescale_cov: Float,
        n_components: usize,
    },
}
impl ESSMove {
    pub const fn differential(weight: Float) -> WeightedESSMove {
        (Self::Differential, weight)
    }
    pub const fn gaussian(weight: Float) -> WeightedESSMove {
        (Self::Gaussian, weight)
    }
    pub fn global(
        weight: Float,
        scale: Option<Float>,
        rescale_cov: Option<Float>,
        n_components: Option<usize>,
    ) -> WeightedESSMove {
        (
            Self::Global {
                scale: scale.unwrap_or(1.0),
                rescale_cov: rescale_cov.unwrap_or(0.001),
                n_components: n_components.unwrap_or(5),
            },
            weight,
        )
    }
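    // Note on the constructors above: unspecified `global` parameters fall back to
    // the defaults shown (scale = 1.0, rescale_cov = 0.001, n_components = 5). A
    // sketch of a weighted move set for `ESSConfig::with_moves`, where the weights
    // are the relative probabilities of selecting each move on a given step:
    //
    //     let moves = [
    //         ESSMove::differential(0.8),
    //         ESSMove::global(0.2, None, None, Some(3)),
    //     ];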
    #[allow(clippy::too_many_arguments)]
    fn step<P, U, E>(
        &self,
        step: usize,
        n_adaptive: usize,
        max_steps: usize,
        mu: &mut Float,
        problem: &P,
        transform: &Option<Box<dyn Transform>>,
        args: &U,
        ensemble: &mut EnsembleStatus,
        rng: &mut Rng,
    ) -> Result<(), E>
    where
        P: LogDensity<U, E>,
    {
        let mut positions = Vec::with_capacity(ensemble.len());
        match self {
            Self::Differential => {
                ensemble.update_message("Differential Move");
            }
            Self::Gaussian => {
                ensemble.update_message("Gaussian Move");
            }
            Self::Global {
                scale,
                rescale_cov,
                n_components,
            } => {
                ensemble.update_message(&format!(
                    "Global Move (scale = {}, rescale_cov = {}, n_components = {})",
                    scale, rescale_cov, n_components
                ));
            }
        }
        let mut n_expand = 0;
        let mut n_contract = 0;
        let mut dpgm_result = None;
        for (i, walker) in ensemble.iter().enumerate() {
            let x_k = walker.get_latest();
            let eta = match self {
                Self::Differential => {
                    let s = &ensemble.get_compliment_walkers(i, 2, rng);
                    let x_l = s[0].get_latest();
                    let x_m = s[1].get_latest();
                    (transform.to_internal(&x_l.read().x).as_ref()
                        - transform.to_internal(&x_m.read().x).as_ref())
                    .scale(*mu)
                }
                Self::Gaussian => {
                    let x_s = ensemble.internal_mean_compliment(i, transform);
                    ensemble
                        .iter_compliment(i)
                        .map(|x_l| {
                            (transform.to_internal(&x_l.read().x).as_ref() - &x_s)
                                .scale(rng.normal(0.0, 1.0))
                        })
                        .sum::<DVector<Float>>()
                        .scale(2.0 * *mu)
                }
                Self::Global {
                    scale,
                    rescale_cov,
                    n_components,
                } => {
                    let dpgm = dpgm_result
                        .get_or_insert_with(|| dpgm(*n_components, ensemble, transform, rng));
                    let labels = &dpgm.labels;
                    let means = &dpgm.means;
                    let covariances = &dpgm.covariances;
                    let indices = rng.choose_multiple(labels.iter(), 2);
                    let a = indices[0];
                    let b = indices[1];
                    if a == b {
                        rng.mv_normal(&means[*a], &covariances[*a])
                            .scale(2.0 * scale)
                    } else {
                        (rng.mv_normal(&means[*a], &covariances[*a].scale(*rescale_cov))
                            - rng.mv_normal(&means[*b], &covariances[*b].scale(*rescale_cov)))
                        .scale(2.0)
                    }
                }
            };
            // Slice height in log space: ln f(x_k) + ln(u), with u ~ U(0, 1).
            let y = x_k.read().fx_checked() + rng.float().ln();
            let x_k_internal = transform.to_internal(&x_k.read().x).into_owned();
            // Stepping-out: place a unit-length bracket [l, r] around x_k along eta and
            // expand each end until it falls outside the slice (or max_steps is hit).
            let mut l = -rng.float();
            let mut p_l = Point::from(&x_k_internal + eta.scale(l));
            p_l.log_density_transformed(problem, transform, args)?;
            let mut r = l + 1.0;
            let mut p_r = Point::from(&x_k_internal + eta.scale(r));
            p_r.log_density_transformed(problem, transform, args)?;
            while y < p_l.fx_checked() && n_expand < max_steps {
                l -= 1.0;
                p_l.set_position(&x_k_internal + eta.scale(l));
                p_l.log_density_transformed(problem, transform, args)?;
                n_expand += 1;
            }
            while y < p_r.fx_checked() && n_expand < max_steps {
                r += 1.0;
                p_r.set_position(&x_k_internal + eta.scale(r));
                p_r.log_density_transformed(problem, transform, args)?;
                n_expand += 1;
            }
            // Shrinkage: sample uniformly inside the bracket, accept if inside the slice,
            // otherwise shrink the bracket towards x_k (the origin in this 1-D
            // parameterization) and try again.
            let xprime = loop {
                let xprime = rng.range(l, r);
                let mut p_yprime = Point::from(&x_k_internal + eta.scale(xprime));
                p_yprime.log_density_transformed(problem, transform, args)?;
                if y < p_yprime.fx_checked() || n_contract >= max_steps {
                    break xprime;
                }
                if xprime < 0.0 {
                    l = xprime;
                } else {
                    r = xprime;
                }
                n_contract += 1;
            };
            let mut proposal = Point::from(x_k_internal + eta.scale(xprime));
            proposal.log_density_transformed(problem, transform, args)?;
            positions.push(Arc::new(RwLock::new(proposal.to_external(transform))))
        }
        // During the adaptive phase, tune the scale factor mu from the ratio of
        // expansions to total expansions plus contractions.
        if step <= n_adaptive {
            *mu *= 2.0 * (n_expand as Float) / (n_expand + n_contract) as Float
        }
        ensemble.push(positions);
        Ok(())
    }
}

/// Configuration for the [`ESS`] sampler.
#[derive(Clone)]
pub struct ESSConfig {
    transform: Option<Box<dyn Transform>>,
    walkers: Vec<Walker>,
    moves: Vec<WeightedESSMove>,
    n_adaptive: usize,
    max_steps: usize,
    mu: Float,
}
impl ESSConfig {
    pub fn new(x0: Vec<DVector<Float>>) -> Self {
        Self {
            transform: None,
            walkers: x0.into_iter().map(Walker::new).collect(),
            moves: vec![ESSMove::differential(1.0)],
            n_adaptive: 0,
            max_steps: 10000,
            mu: 1.0,
        }
    }
    pub fn with_moves<T: AsRef<[WeightedESSMove]>>(mut self, moves: T) -> Self {
        self.moves = moves.as_ref().to_vec();
        self
    }
    pub const fn with_n_adaptive(mut self, n_adaptive: usize) -> Self {
        self.n_adaptive = n_adaptive;
        self
    }
    pub const fn with_max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = max_steps;
        self
    }
    pub const fn with_mu(mut self, mu: Float) -> Self {
        self.mu = mu;
        self
    }
}
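
// Configuration sketch (`x0` here stands for a `Vec<DVector<Float>>` of initial
// walker positions). The defaults from `new` are a single differential move, no
// adaptive steps, `max_steps = 10000`, and `mu = 1.0`; the builders override them:
//
//     let config = ESSConfig::new(x0)
//         .with_moves([ESSMove::differential(0.7), ESSMove::gaussian(0.3)])
//         .with_n_adaptive(100)
//         .with_max_steps(5000);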

impl SupportsTransform for ESSConfig {
    fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
        &mut self.transform
    }
}

/// The Ensemble Slice Sampler.
#[derive(Clone)]
pub struct ESS {
    rng: Rng,
    mu: Float,
}
impl Default for ESS {
    fn default() -> Self {
        Self::new(Some(0))
    }
}

/// An [`ESSMove`] paired with its selection weight.
pub type WeightedESSMove = (ESSMove, Float);

impl ESS {
    pub fn new(seed: Option<u64>) -> Self {
        Self {
            rng: seed.map_or_else(fastrand::Rng::new, fastrand::Rng::with_seed),
            mu: 1.0,
        }
    }
}
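
// End-to-end sketch of driving the sampler through the `Algorithm` trait, mirroring
// the tests at the bottom of this file (`problem`, `x0`, and `n_steps` are
// placeholders for a `LogDensity` implementation, initial walker positions, and a
// step count):
//
//     let mut ess = ESS::default();
//     let cfg = ESSConfig::new(x0);
//     let mut status = EnsembleStatus::default();
//     ess.initialize(&problem, &mut status, &(), &cfg)?;
//     for step in 0..n_steps {
//         ess.step(step, &problem, &mut status, &(), &cfg)?;
//     }
//     let summary = ess.summarize(n_steps, &problem, &status, &(), &cfg)?;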
impl<P, U, E> Algorithm<P, EnsembleStatus, U, E> for ESS
where
    P: LogDensity<U, E>,
{
    type Summary = MCMCSummary;
    type Config = ESSConfig;
    fn initialize(
        &mut self,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        status.walkers = config.walkers.clone();
        self.mu = config.mu;
        status.log_density_latest(problem, args)
    }

    fn step(
        &mut self,
        current_step: usize,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        let step_type_index = self
            .rng
            .choice_weighted(&config.moves.iter().map(|s| s.1).collect::<Vec<Float>>())
            .unwrap_or(0);
        let step_type = config.moves[step_type_index].0;
        step_type.step(
            current_step,
            config.n_adaptive,
            config.max_steps,
            &mut self.mu,
            problem,
            &config.transform,
            args,
            status,
            &mut self.rng,
        )
    }

    fn summarize(
        &self,
        _current_step: usize,
        _problem: &P,
        status: &EnsembleStatus,
        _args: &U,
        _config: &Self::Config,
    ) -> Result<Self::Summary, E> {
        Ok(MCMCSummary {
            bounds: None,
            parameter_names: None,
            message: status.message().to_string(),
            chain: status.get_chain(None, None),
            cost_evals: status.n_f_evals,
            gradient_evals: status.n_g_evals,
            converged: status.converged(),
            dimension: status.dimension(),
        })
    }
}

/// K-means clustering (Lloyd's algorithm with a fixed 50 iterations), used to
/// initialize the responsibilities for the Dirichlet-process Gaussian mixture fit.
#[allow(clippy::unwrap_used)]
fn kmeans(n_clusters: usize, data: &DMatrix<Float>, rng: &mut Rng) -> Vec<usize> {
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let limits = data
        .column_iter()
        .map(|col| (col.min(), col.max()))
        .collect::<Vec<_>>();
    let mut centroids: Vec<DVector<Float>> = (0..n_clusters)
        .map(|_| generate_random_vector_in_limits(&limits, rng))
        .collect();
    let mut labels = vec![0; n_walkers];
    for _ in 0..50 {
        // Assignment step: label each walker with its nearest centroid.
        for (i, walker) in data.row_iter().enumerate() {
            labels[i] = centroids
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    (walker.transpose() - *a)
                        .norm_squared()
                        .partial_cmp(&(walker.transpose() - *b).norm_squared())
                        .unwrap()
                })
                .map(|(j, _)| j)
                .unwrap();
        }
        // Update step: move each centroid to the mean of its assigned walkers.
        for (j, centroid) in centroids.iter_mut().enumerate() {
            let mut sum = DVector::zeros(n_parameters);
            let mut count = 0;
            for (l, w) in labels.iter().zip(data.row_iter()) {
                if *l == j {
                    sum += w.transpose();
                    count += 1;
                }
            }
            if count > 0 {
                sum /= count as Float;
            }
            *centroid = sum;
        }
    }
    labels
}

/// Sample covariance of `m`, treating rows as variables and columns as observations.
fn cov(m: &DMatrix<Float>) -> DMatrix<Float> {
    let mean: DVector<Float> = m
        .row_iter()
        .map(|row| row.mean())
        .collect::<Vec<Float>>()
        .into();
    let centered = m.clone() - mean * DMatrix::from_element(1, m.ncols(), 1.0);
    &centered * centered.transpose() / (m.ncols() as Float - 1.0)
}

fn estimate_gaussian_parameters(
    data: &DMatrix<Float>,
    resp: &DMatrix<Float>,
    reg_covar: Float,
) -> (DVector<Float>, DMatrix<Float>, Vec<DMatrix<Float>>) {
    assert_eq!(data.nrows(), resp.nrows());

    let nk = resp.row_sum_tr().add_scalar(10.0 * Float::EPSILON);
    let mut means: DMatrix<Float> = resp.transpose() * data;
    means.column_iter_mut().for_each(|mut c| {
        c.component_div_assign(&nk);
    });
    let cov = (0..means.nrows())
        .map(|k| {
            let mean_k = means.row(k);
            let diff =
                DMatrix::from_rows(&data.row_iter().map(|row| row - mean_k).collect::<Vec<_>>());
            let weighted_diff_t = DMatrix::from_columns(
                &diff
                    .row_iter()
                    .zip(resp.column(k).iter())
                    .map(|(d, &r)| d.scale(r).transpose())
                    .collect::<Vec<_>>(),
            );
            let mut cov = (&weighted_diff_t * &diff).unscale(nk[k]);
            for i in 0..data.ncols() {
                cov[(i, i)] += reg_covar;
            }
            cov
        })
        .collect();
    (nk, means, cov)
}

fn estimate_weights(
    nk: &DVector<Float>,
    weight_concentration_prior: Float,
) -> (DVector<Float>, DVector<Float>) {
    let n_components = nk.len();
    (nk.map(|x| x + 1.0), {
        let reversed: Vec<Float> = nk.iter().rev().copied().collect();
        let mut cumulative_sum = vec![0.0; n_components];
        let mut sum: Float = 0.0;
        for (i, &val) in reversed.iter().enumerate() {
            sum += val;
            cumulative_sum[i] = sum;
        }
        let mut tail = cumulative_sum[..n_components - 1]
            .iter()
            .rev()
            .copied()
            .collect::<Vec<Float>>();
        tail.push(0.0);
        DVector::from_iterator(
            n_components,
            tail.into_iter().map(|x| x + weight_concentration_prior),
        )
    })
}

fn estimate_means(
    nk: &DVector<Float>,
    xk: &DMatrix<Float>,
    mean_prior: &DVector<Float>,
    mean_precision_prior: Float,
) -> (DVector<Float>, DMatrix<Float>) {
    assert_eq!(nk.len(), xk.nrows());
    assert_eq!(mean_prior.len(), xk.ncols());
    let mean_precision = nk.map(|x| x + mean_precision_prior);
    let mut means = DMatrix::zeros(xk.nrows(), xk.ncols());
    let nkxk: DMatrix<Float> = DMatrix::from_columns(
        &xk.column_iter()
            .map(|x| x.component_mul(nk))
            .collect::<Vec<_>>(),
    );
    means.row_iter_mut().for_each(|mut row| {
        row += mean_prior.transpose().scale(mean_precision_prior);
    });
    means += nkxk;
    means.column_iter_mut().for_each(|mut col| {
        col.component_div_assign(&mean_precision);
    });
    (mean_precision, means)
}

#[allow(clippy::too_many_arguments)]
fn estimate_precisions(
    nk: &DVector<Float>,
    xk: &DMatrix<Float>,
    sk: &[DMatrix<Float>],
    degrees_of_freedom_prior: Float,
    covariance_prior: &DMatrix<Float>,
    mean_prior: &DVector<Float>,
    mean_precision_prior: Float,
    mean_precision: &DVector<Float>,
) -> (DVector<Float>, Vec<DMatrix<Float>>, Vec<DMatrix<Float>>) {
    let n_components = nk.len();
    let n_parameters = mean_prior.len();

    assert_eq!(xk.nrows(), n_components);
    assert_eq!(xk.ncols(), n_parameters);
    assert_eq!(covariance_prior.nrows(), n_parameters);
    assert_eq!(covariance_prior.ncols(), n_parameters);
    assert_eq!(mean_precision.len(), n_components);

    let degrees_of_freedom = nk.map(|x| x + degrees_of_freedom_prior);

    let mut covariances = Vec::with_capacity(n_components);
    let mut precisions_cholesky = Vec::with_capacity(n_components);

    for k in 0..n_components {
        let nk_k = nk[k];
        let xk_k = xk.row(k).transpose();
        let sk_k = &sk[k];
        let mean_precision_k = mean_precision[k];
        let degrees_of_freedom_k = degrees_of_freedom[k];
        let diff = &xk_k - mean_prior;
        let outer = &diff * diff.transpose();
        let covariance = (covariance_prior
            + (sk_k * nk_k)
            + outer * (nk_k * mean_precision_prior / mean_precision_k))
            .unscale(degrees_of_freedom_k);
        covariances.push(covariance.clone());
        #[allow(clippy::expect_used)]
        let cholesky = Cholesky::new(covariance).expect("Cholesky decomposition failed");
        let l = cholesky.l();
        let id = DMatrix::identity(n_parameters, n_parameters);
        #[allow(clippy::expect_used)]
        let solved = l
            .solve_lower_triangular(&id)
            .expect("Cholesky solve_lower_triangular failed");
        precisions_cholesky.push(solved.transpose());
    }
    (degrees_of_freedom, covariances, precisions_cholesky)
}

fn log_det_cholesky(precisions_cholesky: &[DMatrix<Float>], n_parameters: usize) -> DVector<Float> {
    DVector::from_iterator(
        precisions_cholesky.len(),
        precisions_cholesky
            .iter()
            .map(|chol| (0..n_parameters).map(|i| chol[(i, i)].ln()).sum()),
    )
}

fn log_gaussian_prob(
    data: &DMatrix<Float>,
    means: &DMatrix<Float>,
    precisions_cholesky: &[DMatrix<Float>],
) -> DMatrix<Float> {
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let n_components = means.nrows();

    let log_det = log_det_cholesky(precisions_cholesky, n_parameters);
    let mut log_prob = DMatrix::zeros(n_walkers, n_components);
    for k in 0..n_components {
        let mu_k = means.row(k);
        let prec_chol_k = &precisions_cholesky[k];

        for i in 0..n_walkers {
            let x_i = data.row(i);
            let centered = x_i - mu_k;
            let y = &centered * prec_chol_k;
            let sq_sum = y.map(|val| val * val).sum();
            log_prob[(i, k)] = (-0.5 as Float).mul_add(
                (n_parameters as Float).mul_add(Float::ln(2.0 * PI), sq_sum),
                log_det[k],
            );
        }
    }
    log_prob
}

#[allow(clippy::unnecessary_cast)]
fn e_step(
    data: &DMatrix<Float>,
    means: &DMatrix<Float>,
    precisions_cholesky: &[DMatrix<Float>],
    mean_precision: &DVector<Float>,
    degrees_of_freedom: &DVector<Float>,
    weight_concentration: &(DVector<Float>, DVector<Float>),
) -> (Float, DMatrix<Float>) {
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let n_components = means.nrows();
    let estimated_log_prob = {
        let mut log_gauss = log_gaussian_prob(data, means, precisions_cholesky);
        log_gauss.row_iter_mut().for_each(|mut row| {
            row -= degrees_of_freedom
                .map(|x| 0.5 * (n_parameters as Float) * x.ln())
                .transpose()
        });
        let log_lambda = {
            let mut res: DVector<Float> = DVector::zeros(n_components);
            for j in 0..n_parameters {
                for k in 0..n_components {
                    res[k] += spec_math::Gamma::digamma(
                        &((0.5 * (degrees_of_freedom[k] - j as Float)) as f64),
                    ) as Float
                }
            }
            res.map(|r| (n_parameters as Float).mul_add(Float::ln(2.0), r))
        };
        log_gauss.row_iter_mut().for_each(|mut row| {
            row += (0.5 * (&log_lambda - mean_precision.map(|mu| n_parameters as Float / mu)))
                .transpose()
        });
        log_gauss
    };
    let estimated_log_weights = {
        let a = &weight_concentration.0;
        let b = &weight_concentration.1;
        let n = a.len();
        let digamma_sum = (a + b).map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
        let digamma_a = a.map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
        let digamma_b = b.map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
        let mut cumulative = Vec::with_capacity(n);
        let mut acc = 0.0;
        cumulative.push(0.0);
        for i in 0..n - 1 {
            acc += digamma_b[i] - digamma_sum[i];
            cumulative.push(acc);
        }
        DVector::from_iterator(
            n,
            (0..n).map(|i| digamma_a[i] - digamma_sum[i] + cumulative[i]),
        )
    };
    let mut weighted_log_prob = estimated_log_prob;
    weighted_log_prob
        .row_iter_mut()
        .for_each(|mut row| row += &estimated_log_weights.transpose());
    let log_prob_norm = DVector::from_iterator(
        n_walkers,
        weighted_log_prob
            .row_iter()
            .map(|row| logsumexp::LogSumExp::ln_sum_exp(row.iter())),
    );
    let mut log_resp = weighted_log_prob;
    log_resp
        .column_iter_mut()
        .for_each(|mut col| col -= &log_prob_norm);
    (log_prob_norm.mean(), log_resp)
}

#[derive(Clone)]
struct DPGMResult {
    /// Cluster assignment for each walker.
    labels: Vec<usize>,
    /// Mean of each mixture component.
    means: Vec<DVector<Float>>,
    /// Covariance of each mixture component.
    covariances: Vec<DMatrix<Float>>,
}

/// Fits a truncated Dirichlet-process Gaussian mixture to the latest walker positions
/// using variational updates in the style of scikit-learn's `BayesianGaussianMixture`,
/// returning hard labels, component means, and component covariances.
#[allow(clippy::unnecessary_cast)]
fn dpgm(
    n_components: usize,
    ensemble: &EnsembleStatus,
    transform: &Option<Box<dyn Transform>>,
    rng: &mut Rng,
) -> DPGMResult {
    let (n_walkers, _, n_parameters) = ensemble.dimension();
    let data = ensemble.get_latest_internal_position_matrix(transform);
    let weight_concentration_prior = 1.0 / n_components as Float;
    let mean_precision_prior = 1.0;
    let mean_prior = ensemble.internal_mean(transform);
    let degrees_of_freedom_prior = n_parameters as Float;
    let covariance_prior = cov(&data.transpose());

    let mut resp: DMatrix<Float> = DMatrix::zeros(n_walkers, n_components);
    let labels = kmeans(n_components, &data, rng);
    for (i, &cluster_id) in labels.iter().enumerate() {
        resp[(i, cluster_id)] = 1.0;
    }
    let (mut nk, mut xk, mut sk) = estimate_gaussian_parameters(&data, &resp, 1e-6);
    let mut weight_concentration = estimate_weights(&nk, weight_concentration_prior);
    let (mut mean_precision, mut means) =
        estimate_means(&nk, &xk, &mean_prior, mean_precision_prior);
    let (mut degrees_of_freedom, mut covariances, mut precisions_cholesky) = estimate_precisions(
        &nk,
        &xk,
        &sk,
        degrees_of_freedom_prior,
        &covariance_prior,
        &mean_prior,
        mean_precision_prior,
        &mean_precision,
    );
    let mut lower_bound = Float::NEG_INFINITY;
    for _ in 1..=100 {
        let prev_lower_bound = lower_bound;
        let (_, log_resp) = e_step(
            &data,
            &means,
            &precisions_cholesky,
            &mean_precision,
            &degrees_of_freedom,
            &weight_concentration,
        );
        (nk, xk, sk) = estimate_gaussian_parameters(&data, &log_resp.map(Float::exp), 1e-6);
        weight_concentration = estimate_weights(&nk, weight_concentration_prior);
        (mean_precision, means) = estimate_means(&nk, &xk, &mean_prior, mean_precision_prior);
        (degrees_of_freedom, covariances, precisions_cholesky) = estimate_precisions(
            &nk,
            &xk,
            &sk,
            degrees_of_freedom_prior,
            &covariance_prior,
            &mean_prior,
            mean_precision_prior,
            &mean_precision,
        );
        lower_bound = {
            let log_det_precisions_cholesky = log_det_cholesky(&precisions_cholesky, n_parameters)
                - degrees_of_freedom
                    .map(Float::ln)
                    .scale(0.5 * n_parameters as Float);
            let log_wishart_norm = {
                let mut log_wishart_norm =
                    degrees_of_freedom.component_mul(&log_det_precisions_cholesky);
                log_wishart_norm +=
                    degrees_of_freedom.scale(0.5 * Float::ln(2.0) * n_parameters as Float);

                let gammaln_term: DVector<Float> = degrees_of_freedom.map(|dof| {
                    (0..n_parameters)
                        .map(|i| {
                            spec_math::Gamma::lgamma(&((0.5 * (dof - i as Float)) as f64)) as Float
                        })
                        .sum()
                });
                log_wishart_norm += gammaln_term;
                -log_wishart_norm
            };
            let log_norm_weight = -((0..weight_concentration.0.len())
                .map(|i| {
                    spec_math::Beta::lbeta(
                        &(weight_concentration.0[i] as f64),
                        weight_concentration.1[i] as f64,
                    )
                })
                .sum::<f64>()) as Float;
            (0.5 * (n_parameters as Float)).mul_add(
                -mean_precision.map(|mp| mp.ln()).sum(),
                -log_resp.map(|lr| lr.exp() * lr).sum() - log_wishart_norm.sum(),
            ) - log_norm_weight
        };
        let change = lower_bound - prev_lower_bound;
        if change.abs() < 1e-3 {
            break;
        }
    }
    let weight_dirichlet_sum = &weight_concentration.0 + &weight_concentration.1;
    let tmp0 = &weight_concentration.0.component_div(&weight_dirichlet_sum);
    let tmp1 = &weight_concentration.1.component_div(&weight_dirichlet_sum);
    let mut prod_vec = Vec::with_capacity(n_components);
    prod_vec.push(1.0);
    for i in 0..(n_components - 1) {
        prod_vec.push(prod_vec[i] * tmp1[i])
    }
    let mut weights = tmp0.component_mul(&DVector::from_vec(prod_vec));
    weights /= weights.sum();
    let (_, log_resp) = e_step(
        &data,
        &means,
        &precisions_cholesky,
        &mean_precision,
        &degrees_of_freedom,
        &weight_concentration,
    );
    DPGMResult {
        labels: log_resp
            .row_iter()
            .map(|row| row.transpose().argmax().0)
            .collect(),
        means: means
            .row_iter()
            .map(|row| row.transpose())
            .collect::<Vec<DVector<Float>>>(),
        covariances,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_functions::Rosenbrock;

    fn make_walkers(n_walkers: usize, dim: usize) -> Vec<DVector<Float>> {
        (0..n_walkers)
            .map(|i| DVector::from_element(dim, i as Float + 1.0))
            .collect()
    }

    #[test]
    fn test_essmove_constructors() {
        let d = ESSMove::differential(0.5);
        assert!(matches!(d.0, ESSMove::Differential));
        assert_eq!(d.1, 0.5);

        let g = ESSMove::gaussian(1.0);
        assert!(matches!(g.0, ESSMove::Gaussian));

        let gl = ESSMove::global(2.0, None, None, None);
        if let ESSMove::Global {
            scale,
            rescale_cov,
            n_components,
        } = gl.0
        {
            assert_eq!(scale, 1.0);
            assert_eq!(rescale_cov, 0.001);
            assert_eq!(n_components, 5);
        } else {
            panic!("expected Global");
        }
        assert_eq!(gl.1, 2.0);
    }

    #[test]
    fn test_essconfig_defaults_and_builders() {
        let walkers = make_walkers(3, 2);
        let cfg = ESSConfig::new(walkers);
        assert_eq!(cfg.walkers.len(), 3);
        assert_eq!(cfg.moves.len(), 1);
        assert_eq!(cfg.n_adaptive, 0);
        assert_eq!(cfg.max_steps, 10000);
        assert_eq!(cfg.mu, 1.0);

        let moves = vec![ESSMove::gaussian(1.0), ESSMove::differential(1.0)];
        let cfg = cfg
            .with_moves(&moves)
            .with_n_adaptive(5)
            .with_max_steps(42)
            .with_mu(4.1);

        assert_eq!(cfg.moves.len(), 2);
        assert_eq!(cfg.n_adaptive, 5);
        assert_eq!(cfg.max_steps, 42);
        assert!((cfg.mu - 4.1).abs() < 1e-12);
    }

    #[test]
    fn test_ess_initialize_and_summarize() {
        let mut ess = ESS::default();
        let walkers = make_walkers(3, 2);
        let cfg = ESSConfig::new(walkers);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };

        ess.initialize(&f, &mut status, &(), &cfg).unwrap();
        assert_eq!(status.walkers.len(), 3);

        let summary = ess.summarize(0, &f, &status, &(), &cfg).unwrap();
        assert_eq!(summary.dimension, status.dimension());
    }
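
    // Sanity-check sketch for `cov` on a hand-computed example; `cov` treats rows as
    // variables and columns as observations.
    #[test]
    fn test_cov_small_example() {
        // Two variables observed three times; the second variable is twice the first.
        let m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0]);
        let c = super::cov(&m);
        // Hand-computed sample covariance: [[1, 2], [2, 4]].
        assert!((c[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((c[(0, 1)] - 2.0).abs() < 1e-12);
        assert!((c[(1, 0)] - 2.0).abs() < 1e-12);
        assert!((c[(1, 1)] - 4.0).abs() < 1e-12);
    }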

    #[test]
    fn test_differential_step_runs() {
        let mut ess = ESS::default();
        let walkers = make_walkers(3, 2);
        let cfg = ESSConfig::new(walkers);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };
        ess.initialize(&f, &mut status, &(), &cfg).unwrap();

        let result = ess.step(0, &f, &mut status, &(), &cfg);
        assert!(result.is_ok());
        assert!(status.message().contains("Differential"));
    }

    #[test]
    fn test_gaussian_step_runs() {
        let mut ess = ESS::default();
        let walkers = make_walkers(6, 2);
        let cfg = ESSConfig::new(walkers).with_moves(vec![ESSMove::gaussian(1.0)]);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };

        ess.initialize(&f, &mut status, &(), &cfg).unwrap();
        let result = ess.step(0, &f, &mut status, &(), &cfg);
        assert!(result.is_ok());
        assert!(status.message().contains("Gaussian"));
    }

    #[test]
    fn test_global_step_runs() {
        let mut ess = ESS::default();
        let walkers = make_walkers(100, 2);
        let cfg = ESSConfig::new(walkers).with_moves(vec![ESSMove::global(
            1.0,
            Some(1.0),
            Some(0.001),
            Some(3),
        )]);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };

        ess.initialize(&f, &mut status, &(), &cfg).unwrap();
        let result = ess.step(0, &f, &mut status, &(), &cfg);
        assert!(result.is_ok());
        assert!(status.message().contains("Global"));
    }
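
    // Sketch check for `log_gaussian_prob`: a standard bivariate normal evaluated at
    // its mean has log-density -ln(2*pi) (identity precision Cholesky factor).
    #[test]
    fn test_log_gaussian_prob_standard_normal_at_mean() {
        let data = DMatrix::from_row_slice(1, 2, &[0.0, 0.0]);
        let means = DMatrix::from_row_slice(1, 2, &[0.0, 0.0]);
        let prec_chol = vec![DMatrix::identity(2, 2)];
        let log_prob = super::log_gaussian_prob(&data, &means, &prec_chol);
        assert!((log_prob[(0, 0)] + (2.0 * PI).ln()).abs() < 1e-6);
    }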

    #[test]
    fn test_kmeans_two_clusters() {
        let mut rng = Rng::with_seed(0);

        let points_a = [
            DVector::from_vec(vec![0.0, 0.1]).transpose(),
            DVector::from_vec(vec![0.2, -0.1]).transpose(),
            DVector::from_vec(vec![-0.1, 0.0]).transpose(),
        ];
        let points_b = [
            DVector::from_vec(vec![10.0, 10.1]).transpose(),
            DVector::from_vec(vec![9.8, 9.9]).transpose(),
            DVector::from_vec(vec![10.2, 9.9]).transpose(),
        ];

        let mut rows = Vec::new();
        rows.extend(points_a.iter().cloned());
        rows.extend(points_b.iter().cloned());
        let data = DMatrix::from_rows(&rows);

        let labels = super::kmeans(2, &data, &mut rng);
        assert_eq!(labels.len(), 6);

        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[3]);
    }
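
    // Sketch tests for two of the variational-mixture helpers, using small
    // hand-computed values.
    #[test]
    fn test_estimate_weights_dirichlet_process() {
        // With nk = [1, 2, 3] and a prior of 0.5, the first concentration vector is
        // nk + 1 and the second is the prior plus the summed nk of later components.
        let nk = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        let (a, b) = super::estimate_weights(&nk, 0.5);
        assert_eq!(a, DVector::from_vec(vec![2.0, 3.0, 4.0]));
        assert_eq!(b, DVector::from_vec(vec![5.5, 3.5, 0.5]));
    }

    #[test]
    fn test_log_det_cholesky_diagonal() {
        // For a diagonal factor the result is the sum of the logs of the diagonal:
        // ln(2) + ln(3) = ln(6).
        let chol = DMatrix::from_row_slice(2, 2, &[2.0, 0.0, 0.0, 3.0]);
        let log_det = super::log_det_cholesky(&[chol], 2);
        assert!((log_det[0] - (6.0 as Float).ln()).abs() < 1e-6);
    }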

    #[test]
    #[allow(clippy::field_reassign_with_default)]
    fn test_dpgm_recovers_means_covariances_two_blobs() {
        use crate::core::utils::SampleFloat;

        let mu_a = DVector::from_vec(vec![0.0, 0.0]);
        let mu_b = DVector::from_vec(vec![3.0, -2.0]);
        let cov_a = DMatrix::from_row_slice(2, 2, &[0.20, 0.05, 0.05, 0.10]);
        let cov_b = DMatrix::from_row_slice(2, 2, &[0.30, -0.04, -0.04, 0.50]);

        let n_a = 80usize;
        let n_b = 70usize;
        let mut rng = Rng::with_seed(0);

        let mut positions: Vec<Walker> = Vec::with_capacity(n_a + n_b);
        for _ in 0..n_a {
            let x = rng.mv_normal(&mu_a, &cov_a);
            positions.push(Walker::new(x));
        }
        for _ in 0..n_b {
            let x = rng.mv_normal(&mu_b, &cov_b);
            positions.push(Walker::new(x));
        }

        let mut status = EnsembleStatus::default();
        status.walkers = positions;

        let mut rng2 = Rng::with_seed(0);
        let res = super::dpgm(2, &status, &None, &mut rng2);

        assert_eq!(res.labels.len(), n_a + n_b);
        assert_eq!(res.means.len(), 2);
        assert_eq!(res.covariances.len(), 2);
        assert_eq!(res.covariances[0].nrows(), 2);
        assert_eq!(res.covariances[0].ncols(), 2);

        let d0_a = (&res.means[0] - &mu_a).norm();
        let d1_a = (&res.means[1] - &mu_a).norm();
        let (idx_a, idx_b) = if d0_a <= d1_a { (0, 1) } else { (1, 0) };

        assert!((&res.means[idx_a] - &mu_a).norm() < 0.25);
        assert!((&res.means[idx_b] - &mu_b).norm() < 0.25);

        let cov_a_hat = &res.covariances[idx_a];
        let cov_b_hat = &res.covariances[idx_b];
        for i in 0..2 {
            let a_true = cov_a[(i, i)];
            let a_est = cov_a_hat[(i, i)];
            assert!((a_est - a_true).abs() / a_true < 0.35);

            let b_true = cov_b[(i, i)];
            let b_est = cov_b_hat[(i, i)];
            assert!((b_est - b_true).abs() / b_true < 0.35);
        }
        assert!((cov_a_hat[(0, 1)] - cov_a[(0, 1)]).abs() < 0.1);
        assert!((cov_b_hat[(0, 1)] - cov_b[(0, 1)]).abs() < 0.1);

        let count_a = res.labels[..n_a].iter().filter(|&&l| l == idx_a).count();
        let count_b = res.labels[n_a..].iter().filter(|&&l| l == idx_b).count();
        assert!(count_a as Float > 0.9 * n_a as Float);
        assert!(count_b as Float > 0.9 * n_b as Float);
    }
}