ganesh/algorithms/mcmc/ess.rs

use crate::{
    algorithms::mcmc::{EnsembleStatus, Walker},
    core::{
        utils::{generate_random_vector_in_limits, RandChoice, SampleFloat},
        MCMCSummary, Point,
    },
    traits::{Algorithm, LogDensity, Status, SupportsTransform, Transform},
    DMatrix, DVector, Float, PI,
};
use fastrand::Rng;
use nalgebra::Cholesky;
use parking_lot::RwLock;
use std::sync::Arc;

/// A move used by the [`ESS`] algorithm
///
/// See Karamanis & Beutler[^1] for the step implementation algorithms
///
/// [^1]: Karamanis, M., & Beutler, F. (2020). Ensemble slice sampling: Parallel, black-box and gradient-free inference for correlated & multimodal distributions. arXiv preprint arXiv:2002.06212.
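///
/// # Example
///
/// A sketch of building a weighted move list (illustrative only, so the doctest is `ignore`d):
///
/// ```ignore
/// // 80% Differential moves, 20% Global moves with default settings
/// let moves = [
///     ESSMove::differential(0.8),
///     ESSMove::global(0.2, None, None, None),
/// ];
/// ```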
#[derive(Copy, Clone)]
pub enum ESSMove {
    /// The Differential move described in Algorithm 2 of Karamanis & Beutler
    Differential,
    /// The Gaussian move described in Algorithm 3 of Karamanis & Beutler
    Gaussian,
    /// The Global move described in Algorithm 4 of Karamanis & Beutler
    Global {
        /// A scale factor that is applied if the walker jumps within its own cluster
        scale: Float,
        /// A rescaling factor applied to the covariance that promotes mode jumping
        rescale_cov: Float,
        /// The number of mixture coefficients
        n_components: usize,
    },
}
impl ESSMove {
    /// Create a new [`ESSMove::Differential`] with a usage weight
    pub const fn differential(weight: Float) -> WeightedESSMove {
        (Self::Differential, weight)
    }
    /// Create a new [`ESSMove::Gaussian`] with a usage weight
    pub const fn gaussian(weight: Float) -> WeightedESSMove {
        (Self::Gaussian, weight)
    }
    /// Create a new [`ESSMove::Global`] with a usage weight
    pub fn global(
        weight: Float,
        scale: Option<Float>,
        rescale_cov: Option<Float>,
        n_components: Option<usize>,
    ) -> WeightedESSMove {
        (
            Self::Global {
                scale: scale.unwrap_or(1.0),
                rescale_cov: rescale_cov.unwrap_or(0.001),
                n_components: n_components.unwrap_or(5),
            },
            weight,
        )
    }
    #[allow(clippy::too_many_arguments)]
    fn step<P, U, E>(
        &self,
        step: usize,
        n_adaptive: usize,
        max_steps: usize,
        mu: &mut Float,
        problem: &P,
        transform: &Option<Box<dyn Transform>>,
        args: &U,
        ensemble: &mut EnsembleStatus,
        rng: &mut Rng,
    ) -> Result<(), E>
    where
        P: LogDensity<U, E>,
    {
        let mut positions = Vec::with_capacity(ensemble.len());
        match self {
            Self::Differential => {
                ensemble.update_message("Differential Move");
            }
            Self::Gaussian => {
                ensemble.update_message("Gaussian Move");
            }
            Self::Global {
                scale,
                rescale_cov,
                n_components,
            } => {
                ensemble.update_message(&format!(
                    "Global Move (scale = {}, rescale_cov = {}, n_components = {})",
                    scale, rescale_cov, n_components
                ));
            }
        }
        let mut n_expand = 0;
        let mut n_contract = 0;
        let mut dpgm_result = None;
        for (i, walker) in ensemble.iter().enumerate() {
            let x_k = walker.get_latest();
            let eta = match self {
                Self::Differential => {
                    // Given a walker Xₖ and complementary set of walkers S, pick two walkers Xₗ and Xₘ from S (without
                    // replacement) and compute direction vector ηₖ = μ(Xₗ - Xₘ)
                    let s = &ensemble.get_compliment_walkers(i, 2, rng);
                    let x_l = s[0].get_latest();
                    let x_m = s[1].get_latest();
                    (transform.to_internal(&x_l.read().x).as_ref()
                        - transform.to_internal(&x_m.read().x).as_ref())
                    .scale(*mu)
                }
                Self::Gaussian => {
                    // Cₛ = 1/|S|   ⅀ (Xₗ - X̅ₛ)(Xₗ - X̅ₛ)†
                    //            Xₗ∈S
                    // sample ηₖ/(2μ) ∝ Norm(0, Cₛ)
                    //
                    // We can do this faster by selecting Zₗ ~ Norm(μ=0, σ=1) and
                    //
                    // W = ⅀ Zₗ(Xₗ - X̅ₛ)
                    //   Xₗ∈S
                    let x_s = ensemble.internal_mean_compliment(i, transform);
                    ensemble
                        .iter_compliment(i)
                        .map(|x_l| {
                            (transform.to_internal(&x_l.read().x).as_ref() - &x_s)
                                .scale(rng.normal(0.0, 1.0))
                        })
                        .sum::<DVector<Float>>()
                        .scale(2.0 * *mu)
                }
                Self::Global {
                    scale,
                    rescale_cov,
                    n_components,
                } => {
                    // Fit the Dirichlet process Gaussian mixture once per step and reuse it for every walker
                    let dpgm = dpgm_result
                        .get_or_insert_with(|| dpgm(*n_components, ensemble, transform, rng));
                    let labels = &dpgm.labels;
                    let means = &dpgm.means;
                    let covariances = &dpgm.covariances;
                    let indices = rng.choose_multiple(labels.iter(), 2);
                    let a = indices[0];
                    let b = indices[1];
                    // TODO: the multivariate sampling could be faster if the input was the
                    // Cholesky decomposition of the covariance matrix
                    if a == b {
                        // Both draws fall in the same cluster: sample within it, scaled by `scale`
                        rng.mv_normal(&means[*a], &covariances[*a])
                            .scale(2.0 * scale)
                    } else {
                        // Different clusters: a direction between modes, with covariances shrunk by `rescale_cov`
                        (rng.mv_normal(&means[*a], &covariances[*a].scale(*rescale_cov))
                            - rng.mv_normal(&means[*b], &covariances[*b].scale(*rescale_cov)))
                        .scale(2.0)
                    }
                }
            };
            // Y ~ U(0, f(Xₖ(t))), computed in log space as Y = ln f(Xₖ(t)) + ln U
            let y = x_k.read().fx_checked() + rng.float().ln();
            let x_k_internal = transform.to_internal(&x_k.read().x).into_owned();
            // U ~ U(0, 1)
            // L <- -U
            let mut l = -rng.float();
            let mut p_l = Point::from(&x_k_internal + eta.scale(l));
            p_l.log_density_transformed(problem, transform, args)?;
            // R <- L + 1
            let mut r = l + 1.0;
            let mut p_r = Point::from(&x_k_internal + eta.scale(r));
            p_r.log_density_transformed(problem, transform, args)?;
            // while Y < f(L) do
            while y < p_l.fx_checked() && n_expand < max_steps {
                // L <- L - 1
                l -= 1.0;
                p_l.set_position(&x_k_internal + eta.scale(l));
                p_l.log_density_transformed(problem, transform, args)?;
                // N₊(t) <- N₊(t) + 1
                n_expand += 1;
            }
            // while Y < f(R) do
            while y < p_r.fx_checked() && n_expand < max_steps {
                // R <- R + 1
                r += 1.0;
                p_r.set_position(&x_k_internal + eta.scale(r));
                p_r.log_density_transformed(problem, transform, args)?;
                // N₊(t) <- N₊(t) + 1
                n_expand += 1;
            }
            // while True do
            let xprime = loop {
                // X' ~ U(L, R)
                let xprime = rng.range(l, r);
                // Y' <- f(X'ηₖ + Xₖ(t))
                let mut p_yprime = Point::from(&x_k_internal + eta.scale(xprime));
                p_yprime.log_density_transformed(problem, transform, args)?;
                if y < p_yprime.fx_checked() || n_contract >= max_steps {
                    // if Y < Y' then break
                    break xprime;
                }
                if xprime < 0.0 {
                    // if X' < 0 then L <- X'
                    l = xprime;
                } else {
                    // else R <- X'
                    r = xprime;
                }
                // N₋(t) <- N₋(t) + 1
                n_contract += 1;
            };
            // Xₖ(t+1) <- X'ηₖ + Xₖ(t)
            let mut proposal = Point::from(x_k_internal + eta.scale(xprime));
            proposal.log_density_transformed(problem, transform, args)?;
            positions.push(Arc::new(RwLock::new(proposal.to_external(transform))));
        }
        // μ(t+1) <- TuneLengthScale(t, μ(t), N₊(t), N₋(t), M[adapt])
        if step <= n_adaptive {
            *mu *= 2.0 * (n_expand as Float) / (n_expand + n_contract) as Float
        }
        ensemble.push(positions);
        Ok(())
    }
}

/// The internal configuration struct for the [`ESS`] algorithm.
#[derive(Clone)]
pub struct ESSConfig {
    transform: Option<Box<dyn Transform>>,
    walkers: Vec<Walker>,
    moves: Vec<WeightedESSMove>,
    n_adaptive: usize,
    max_steps: usize,
    mu: Float,
}
impl ESSConfig {
    /// Create a new configuration with the initial positions of the walkers.
    ///
    /// This sets the default move list to use an [`ESSMove::Differential`] move 100% of the time.
    ///
    /// # See Also
    /// [`Walker::new`]
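    ///
    /// # Example
    ///
    /// A minimal sketch (illustrative only, so the doctest is `ignore`d; `Float` is the crate's float alias):
    ///
    /// ```ignore
    /// // Five walkers placed randomly in a 2D parameter space
    /// let mut rng = fastrand::Rng::with_seed(0);
    /// let x0: Vec<DVector<Float>> = (0..5)
    ///     .map(|_| DVector::from_fn(2, |_, _| rng.f64() * 2.0 - 1.0))
    ///     .collect();
    /// let config = ESSConfig::new(x0)
    ///     .with_moves([ESSMove::differential(0.9), ESSMove::gaussian(0.1)])
    ///     .with_n_adaptive(100);
    /// ```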
    pub fn new(x0: Vec<DVector<Float>>) -> Self {
        Self {
            transform: None,
            walkers: x0.into_iter().map(Walker::new).collect(),
            moves: vec![ESSMove::differential(1.0)],
            n_adaptive: 0,
            max_steps: 10000,
            mu: 1.0,
        }
    }
    /// Set the moves for the [`ESS`] algorithm to use.
    pub fn with_moves<T: AsRef<[WeightedESSMove]>>(mut self, moves: T) -> Self {
        self.moves = moves.as_ref().to_vec();
        self
    }
    /// Set the number of adaptive moves to perform at the start of sampling (default: `0`)
    pub const fn with_n_adaptive(mut self, n_adaptive: usize) -> Self {
        self.n_adaptive = n_adaptive;
        self
    }
    /// Set the maximum number of expansions/contractions to perform at each step (default: `10000`)
    pub const fn with_max_steps(mut self, max_steps: usize) -> Self {
        self.max_steps = max_steps;
        self
    }
    /// Set the adaptive scaling parameter, $`\mu`$ (default: `1.0`)
    pub const fn with_mu(mut self, mu: Float) -> Self {
        self.mu = mu;
        self
    }
}

impl SupportsTransform for ESSConfig {
    fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
        &mut self.transform
    }
}

/// The Ensemble Slice Sampler
///
/// This sampler follows Algorithm 5 in Karamanis & Beutler[^1].
///
/// [^1]: Karamanis, M., & Beutler, F. (2020). Ensemble slice sampling: Parallel, black-box and gradient-free inference for correlated & multimodal distributions. arXiv preprint arXiv:2002.06212.
#[derive(Clone)]
pub struct ESS {
    rng: Rng,
    mu: Float,
}
impl Default for ESS {
    fn default() -> Self {
        Self::new(Some(0))
    }
}

/// An [`ESSMove`] coupled with a weight
pub type WeightedESSMove = (ESSMove, Float);

impl ESS {
    /// Create a new Ensemble Slice Sampler with the given seed.
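    ///
    /// Passing `Some(seed)` makes sampling reproducible; `None` uses a non-deterministic seed. Sketch:
    ///
    /// ```ignore
    /// let reproducible = ESS::new(Some(42));
    /// let fresh = ESS::new(None);
    /// ```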
    pub fn new(seed: Option<u64>) -> Self {
        Self {
            rng: seed.map_or_else(fastrand::Rng::new, fastrand::Rng::with_seed),
            mu: 1.0,
        }
    }
}
impl<P, U, E> Algorithm<P, EnsembleStatus, U, E> for ESS
where
    P: LogDensity<U, E>,
{
    type Summary = MCMCSummary;
    type Config = ESSConfig;
    fn initialize(
        &mut self,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        status.walkers = config.walkers.clone();
        self.mu = config.mu;
        status.log_density_latest(problem, args)
    }

    fn step(
        &mut self,
        current_step: usize,
        problem: &P,
        status: &mut EnsembleStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        // Pick a move at random, weighted by the move list's usage weights
        let step_type_index = self
            .rng
            .choice_weighted(&config.moves.iter().map(|s| s.1).collect::<Vec<Float>>())
            .unwrap_or(0);
        let step_type = config.moves[step_type_index].0;
        step_type.step(
            current_step,
            config.n_adaptive,
            config.max_steps,
            &mut self.mu,
            problem,
            &config.transform,
            args,
            status,
            &mut self.rng,
        )
    }

    fn summarize(
        &self,
        _current_step: usize,
        _problem: &P,
        status: &EnsembleStatus,
        _args: &U,
        _config: &Self::Config,
    ) -> Result<Self::Summary, E> {
        Ok(MCMCSummary {
            bounds: None,
            parameter_names: None,
            message: status.message().to_string(),
            chain: status.get_chain(None, None),
            cost_evals: status.n_f_evals,
            gradient_evals: status.n_g_evals,
            converged: status.converged(),
            dimension: status.dimension(),
        })
    }
}

// Compute the k-means cluster labels for a set of points
//
// n_clusters: number of clusters
// data: (n_walkers, n_parameters)
//
// # Returns
//
// labels: Vec<usize> (n_walkers,)
#[allow(clippy::unwrap_used)]
fn kmeans(n_clusters: usize, data: &DMatrix<Float>, rng: &mut Rng) -> Vec<usize> {
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let limits = data
        .column_iter()
        .map(|col| (col.min(), col.max()))
        .collect::<Vec<_>>();
    let mut centroids: Vec<DVector<Float>> = (0..n_clusters)
        .map(|_| generate_random_vector_in_limits(&limits, rng))
        .collect();
    let mut labels = vec![0; n_walkers];
    // A fixed budget of 50 Lloyd iterations: assign each point to its nearest
    // centroid, then move each centroid to the mean of its assigned points
    for _ in 0..50 {
        for (i, walker) in data.row_iter().enumerate() {
            labels[i] = centroids
                .iter()
                .enumerate()
                .min_by(|(_, a), (_, b)| {
                    (walker.transpose() - *a)
                        .norm_squared()
                        .partial_cmp(&(walker.transpose() - *b).norm_squared())
                        .unwrap()
                })
                .map(|(j, _)| j)
                .unwrap();
        }
        for (j, centroid) in centroids.iter_mut().enumerate() {
            let mut sum = DVector::zeros(n_parameters);
            let mut count = 0;
            for (l, w) in labels.iter().zip(data.row_iter()) {
                if *l == j {
                    sum += w.transpose();
                    count += 1;
                }
            }
            if count > 0 {
                sum /= count as Float;
            }
            *centroid = sum;
        }
    }
    labels
}

// Computes the sample covariance matrix of a given matrix, where rows are
// variables and columns are observations
//
// m: (N, M)
//
// # Returns
//
// cov: (N, N)
fn cov(m: &DMatrix<Float>) -> DMatrix<Float> {
    let mean: DVector<Float> = m
        .row_iter()
        .map(|row| row.mean())
        .collect::<Vec<Float>>()
        .into();
    let centered = m.clone() - mean * DMatrix::from_element(1, m.ncols(), 1.0);
    &centered * centered.transpose() / (m.ncols() as Float - 1.0)
}

// data: (n_walkers, n_parameters)
// resp: (n_walkers, n_components)
// reg_covar: Float
//
// # Returns
//
// nk: (n_components,)
// means: (n_components, n_parameters)
// covariances: (n_components, (n_parameters, n_parameters))
fn estimate_gaussian_parameters(
    data: &DMatrix<Float>,
    resp: &DMatrix<Float>,
    reg_covar: Float,
) -> (DVector<Float>, DMatrix<Float>, Vec<DMatrix<Float>>) {
    assert_eq!(data.nrows(), resp.nrows());

    let nk = resp.row_sum_tr().add_scalar(10.0 * Float::EPSILON);
    let mut means: DMatrix<Float> = resp.transpose() * data;
    means.column_iter_mut().for_each(|mut c| {
        c.component_div_assign(&nk);
    });
    let cov = (0..means.nrows())
        .map(|k| {
            let mean_k = means.row(k);
            let diff =
                DMatrix::from_rows(&data.row_iter().map(|row| row - mean_k).collect::<Vec<_>>());
            let weighted_diff_t = DMatrix::from_columns(
                &diff
                    .row_iter()
                    .zip(resp.column(k).iter())
                    .map(|(d, &r)| d.scale(r).transpose())
                    .collect::<Vec<_>>(),
            );
            let mut cov = (&weighted_diff_t * &diff).unscale(nk[k]);
            // Regularize the diagonal to keep the covariance positive-definite
            for i in 0..data.ncols() {
                cov[(i, i)] += reg_covar;
            }
            cov
        })
        .collect();
    (nk, means, cov)
}

// nk: (n_components,)
//
// # Returns
//
// dirichlet_0: (n_components,)
// dirichlet_1: (n_components,)
fn estimate_weights(
    nk: &DVector<Float>,
    weight_concentration_prior: Float,
) -> (DVector<Float>, DVector<Float>) {
    let n_components = nk.len();
    (nk.map(|x| x + 1.0), {
        // Reversed cumulative tail sums of nk: entry i holds the total weight
        // of all components after i (stick-breaking construction)
        let reversed: Vec<Float> = nk.iter().rev().copied().collect();
        let mut cumulative_sum = vec![0.0; n_components];
        let mut sum: Float = 0.0;
        for (i, &val) in reversed.iter().enumerate() {
            sum += val;
            cumulative_sum[i] = sum;
        }
        let mut tail = cumulative_sum[..n_components - 1]
            .iter()
            .rev()
            .copied()
            .collect::<Vec<Float>>();
        tail.push(0.0);
        DVector::from_iterator(
            n_components,
            tail.into_iter().map(|x| x + weight_concentration_prior),
        )
    })
}

// nk: (n_components,)
// xk: (n_components, n_parameters)
// mean_prior: (n_parameters,)
//
// # Returns
//
// mean_precision: (n_components,)
// means: (n_components, n_parameters)
fn estimate_means(
    nk: &DVector<Float>,
    xk: &DMatrix<Float>,
    mean_prior: &DVector<Float>,
    mean_precision_prior: Float,
) -> (DVector<Float>, DMatrix<Float>) {
    assert_eq!(nk.len(), xk.nrows());
    assert_eq!(mean_prior.len(), xk.ncols());
    let mean_precision = nk.map(|x| x + mean_precision_prior);
    let mut means = DMatrix::zeros(xk.nrows(), xk.ncols());
    let nkxk: DMatrix<Float> = DMatrix::from_columns(
        &xk.column_iter()
            .map(|x| x.component_mul(nk))
            .collect::<Vec<_>>(),
    );
    means.row_iter_mut().for_each(|mut row| {
        row += mean_prior.transpose().scale(mean_precision_prior);
    });
    means += nkxk;
    means.column_iter_mut().for_each(|mut col| {
        col.component_div_assign(&mean_precision);
    });
    (mean_precision, means)
}

// nk: (n_components,)
// xk: (n_components, n_parameters)
// sk: (n_components, (n_parameters, n_parameters))
// degrees_of_freedom_prior: Float
// covariance_prior: (n_parameters, n_parameters)
// mean_prior: (n_parameters,)
// mean_precision: (n_components,)
//
// # Returns
//
// degrees_of_freedom: (n_components,)
// covariances: (n_components, (n_parameters, n_parameters))
// precisions_cholesky: (n_components, (n_parameters, n_parameters))
#[allow(clippy::too_many_arguments)]
fn estimate_precisions(
    nk: &DVector<Float>,
    xk: &DMatrix<Float>,
    sk: &[DMatrix<Float>],
    degrees_of_freedom_prior: Float,
    covariance_prior: &DMatrix<Float>,
    mean_prior: &DVector<Float>,
    mean_precision_prior: Float,
    mean_precision: &DVector<Float>,
) -> (DVector<Float>, Vec<DMatrix<Float>>, Vec<DMatrix<Float>>) {
    let n_components = nk.len();
    let n_parameters = mean_prior.len();

    assert_eq!(xk.nrows(), n_components);
    assert_eq!(xk.ncols(), n_parameters);
    assert_eq!(covariance_prior.nrows(), n_parameters);
    assert_eq!(covariance_prior.ncols(), n_parameters);
    assert_eq!(mean_precision.len(), n_components);

    let degrees_of_freedom = nk.map(|x| x + degrees_of_freedom_prior);

    let mut covariances = Vec::with_capacity(n_components);
    let mut precisions_cholesky = Vec::with_capacity(n_components);

    for k in 0..n_components {
        let nk_k = nk[k];
        let xk_k = xk.row(k).transpose();
        let sk_k = &sk[k];
        let mean_precision_k = mean_precision[k];
        let degrees_of_freedom_k = degrees_of_freedom[k];
        let diff = &xk_k - mean_prior;
        let outer = &diff * diff.transpose();
        let covariance = (covariance_prior
            + (sk_k * nk_k)
            + outer * (nk_k * mean_precision_prior / mean_precision_k))
            .unscale(degrees_of_freedom_k);
        covariances.push(covariance.clone());
        #[allow(clippy::expect_used)]
        let cholesky = Cholesky::new(covariance).expect("Cholesky decomposition failed");
        let l = cholesky.l();
        let id = DMatrix::identity(n_parameters, n_parameters);
        #[allow(clippy::expect_used)]
        let solved = l
            .solve_lower_triangular(&id)
            .expect("Cholesky solve_lower_triangular failed");
        precisions_cholesky.push(solved.transpose());
    }
    (degrees_of_freedom, covariances, precisions_cholesky)
}

// precisions_cholesky: (n_components, (n_parameters, n_parameters))
//
// # Returns
//
// log_det_cholesky: (n_components,)
fn log_det_cholesky(precisions_cholesky: &[DMatrix<Float>], n_parameters: usize) -> DVector<Float> {
    DVector::from_iterator(
        precisions_cholesky.len(),
        precisions_cholesky
            .iter()
            .map(|chol| (0..n_parameters).map(|i| chol[(i, i)].ln()).sum()),
    )
}

// data: (n_walkers, n_parameters)
// means: (n_components, n_parameters)
// precisions_cholesky: (n_components, (n_parameters, n_parameters))
//
// # Returns
//
// log_prob: (n_walkers, n_components)
fn log_gaussian_prob(
    data: &DMatrix<Float>,
    means: &DMatrix<Float>,
    precisions_cholesky: &[DMatrix<Float>],
) -> DMatrix<Float> {
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let n_components = means.nrows();

    let log_det = log_det_cholesky(precisions_cholesky, n_parameters);
    let mut log_prob = DMatrix::zeros(n_walkers, n_components);
    for k in 0..n_components {
        let mu_k = means.row(k);
        let prec_chol_k = &precisions_cholesky[k];

        for i in 0..n_walkers {
            let x_i = data.row(i);
            let centered = x_i - mu_k;
            let y = &centered * prec_chol_k;
            let sq_sum = y.map(|val| val * val).sum();
            // ln N(x | μ, Σ) = -0.5 (d ln(2π) + |L⁻¹(x - μ)|²) + ln det(L⁻¹)
            log_prob[(i, k)] = (-0.5 as Float).mul_add(
                (n_parameters as Float).mul_add(Float::ln(2.0 * PI), sq_sum),
                log_det[k],
            );
        }
    }
    log_prob
}

// data: (n_walkers, n_parameters)
// means: (n_components, n_parameters)
// precisions_cholesky: (n_components, (n_parameters, n_parameters))
//
// # Returns
//
// log_prob_norm: Float
// log_resp: (n_walkers, n_components)
#[allow(clippy::unnecessary_cast)]
fn e_step(
    data: &DMatrix<Float>,
    means: &DMatrix<Float>,
    precisions_cholesky: &[DMatrix<Float>],
    mean_precision: &DVector<Float>,
    degrees_of_freedom: &DVector<Float>,
    weight_concentration: &(DVector<Float>, DVector<Float>),
) -> (Float, DMatrix<Float>) {
    let n_walkers = data.nrows();
    let n_parameters = data.ncols();
    let n_components = means.nrows();
    let estimated_log_prob = {
        let mut log_gauss = log_gaussian_prob(data, means, precisions_cholesky);
        log_gauss.row_iter_mut().for_each(|mut row| {
            row -= degrees_of_freedom
                .map(|x| 0.5 * (n_parameters as Float) * x.ln())
                .transpose()
        });
        let log_lambda = {
            let mut res: DVector<Float> = DVector::zeros(n_components);
            for j in 0..n_parameters {
                for k in 0..n_components {
                    res[k] += spec_math::Gamma::digamma(
                        &((0.5 * (degrees_of_freedom[k] - j as Float)) as f64),
                    ) as Float
                }
            }
            res.map(|r| (n_parameters as Float).mul_add(Float::ln(2.0), r))
        };
        log_gauss.row_iter_mut().for_each(|mut row| {
            row += (0.5 * (&log_lambda - mean_precision.map(|mu| n_parameters as Float / mu)))
                .transpose()
        });
        log_gauss
    };
    let estimated_log_weights = {
        let a = &weight_concentration.0;
        let b = &weight_concentration.1;
        let n = a.len();
        let digamma_sum = (a + b).map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
        let digamma_a = a.map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
        let digamma_b = b.map(|v| spec_math::Gamma::digamma(&(v as f64)) as Float);
        let mut cumulative = Vec::with_capacity(n);
        let mut acc = 0.0;
        cumulative.push(0.0);
        for i in 0..n - 1 {
            acc += digamma_b[i] - digamma_sum[i];
            cumulative.push(acc);
        }
        DVector::from_iterator(
            n,
            (0..n).map(|i| digamma_a[i] - digamma_sum[i] + cumulative[i]),
        )
    };
    let mut weighted_log_prob = estimated_log_prob;
    weighted_log_prob
        .row_iter_mut()
        .for_each(|mut row| row += &estimated_log_weights.transpose());
    let log_prob_norm = DVector::from_iterator(
        n_walkers,
        weighted_log_prob
            .row_iter()
            .map(|row| logsumexp::LogSumExp::ln_sum_exp(row.iter())),
    );
    let mut log_resp = weighted_log_prob;
    log_resp
        .column_iter_mut()
        .for_each(|mut col| col -= &log_prob_norm);
    (log_prob_norm.mean(), log_resp)
}

#[derive(Clone)]
struct DPGMResult {
    // labels: (n_walkers,)
    labels: Vec<usize>,
    // means: (n_components, (n_parameters,))
    means: Vec<DVector<Float>>,
    // covariances: (n_components, (n_parameters, n_parameters))
    covariances: Vec<DMatrix<Float>>,
}

// Dirichlet Process Gaussian Mixture
//
// Code is taken almost verbatim (converting numpy to nalgebra) from
// <https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/mixture/_bayesian_mixture.py#L74>
// with some modifications to only use the "full" covariance mode, the "kmeans" initialization
// method, and the "dirichlet_process" weight concentration prior. See the readme/crate
// documentation for the proper citation.
//
// n_components: usize, the number of Gaussian mixture components
// ensemble: &EnsembleStatus
//
// # Returns
//
// DPGMResult
#[allow(clippy::unnecessary_cast)]
fn dpgm(
    n_components: usize,
    ensemble: &EnsembleStatus,
    transform: &Option<Box<dyn Transform>>,
    rng: &mut Rng,
) -> DPGMResult {
    let (n_walkers, _, n_parameters) = ensemble.dimension();
    let data = ensemble.get_latest_internal_position_matrix(transform);
    let weight_concentration_prior = 1.0 / n_components as Float;
    let mean_precision_prior = 1.0;
    let mean_prior = ensemble.internal_mean(transform);
    let degrees_of_freedom_prior = n_parameters as Float;
    let covariance_prior = cov(&data.transpose());

    // Initialize responsibilities from hard k-means labels
    let mut resp: DMatrix<Float> = DMatrix::zeros(n_walkers, n_components);
    let labels = kmeans(n_components, &data, rng);
    for (i, &cluster_id) in labels.iter().enumerate() {
        resp[(i, cluster_id)] = 1.0;
    }
    let (mut nk, mut xk, mut sk) = estimate_gaussian_parameters(&data, &resp, 1e-6);
    let mut weight_concentration = estimate_weights(&nk, weight_concentration_prior);
    let (mut mean_precision, mut means) =
        estimate_means(&nk, &xk, &mean_prior, mean_precision_prior);
    let (mut degrees_of_freedom, mut covariances, mut precisions_cholesky) = estimate_precisions(
        &nk,
        &xk,
        &sk,
        degrees_of_freedom_prior,
        &covariance_prior,
        &mean_prior,
        mean_precision_prior,
        &mean_precision,
    );
    let mut lower_bound = Float::NEG_INFINITY;
    // Variational EM: alternate the E-step and M-step until the lower bound
    // changes by less than 1e-3, up to 100 iterations
    for _ in 1..=100 {
        let prev_lower_bound = lower_bound;
        let (_, log_resp) = e_step(
            &data,
            &means,
            &precisions_cholesky,
            &mean_precision,
            &degrees_of_freedom,
            &weight_concentration,
        );
        (nk, xk, sk) = estimate_gaussian_parameters(&data, &log_resp.map(Float::exp), 1e-6);
        weight_concentration = estimate_weights(&nk, weight_concentration_prior);
        (mean_precision, means) = estimate_means(&nk, &xk, &mean_prior, mean_precision_prior);
        (degrees_of_freedom, covariances, precisions_cholesky) = estimate_precisions(
            &nk,
            &xk,
            &sk,
            degrees_of_freedom_prior,
            &covariance_prior,
            &mean_prior,
            mean_precision_prior,
            &mean_precision,
        );
        lower_bound = {
            let log_det_precisions_cholesky = log_det_cholesky(&precisions_cholesky, n_parameters)
                - degrees_of_freedom
                    .map(Float::ln)
                    .scale(0.5 * n_parameters as Float);
            let log_wishart_norm = {
                let mut log_wishart_norm =
                    degrees_of_freedom.component_mul(&log_det_precisions_cholesky);
                log_wishart_norm +=
                    degrees_of_freedom.scale(0.5 * Float::ln(2.0) * n_parameters as Float);

                let gammaln_term: DVector<Float> = degrees_of_freedom.map(|dof| {
                    (0..n_parameters)
                        .map(|i| {
                            spec_math::Gamma::lgamma(&((0.5 * (dof - i as Float)) as f64)) as Float
                        })
                        .sum()
                });
                log_wishart_norm += gammaln_term;
                -log_wishart_norm
            };
            let log_norm_weight = -((0..weight_concentration.0.len())
                .map(|i| {
                    spec_math::Beta::lbeta(
                        &(weight_concentration.0[i] as f64),
                        weight_concentration.1[i] as f64,
                    )
                })
                .sum::<f64>()) as Float;
            (0.5 * (n_parameters as Float)).mul_add(
                -mean_precision.map(|mp| mp.ln()).sum(),
                -log_resp.map(|lr| lr.exp() * lr).sum() - log_wishart_norm.sum(),
            ) - log_norm_weight
        };
        let change = lower_bound - prev_lower_bound;
        if change.abs() < 1e-3 {
            break;
        }
    }
    let weight_dirichlet_sum = &weight_concentration.0 + &weight_concentration.1;
    let tmp0 = &weight_concentration.0.component_div(&weight_dirichlet_sum);
    let tmp1 = &weight_concentration.1.component_div(&weight_dirichlet_sum);
    let mut prod_vec = Vec::with_capacity(n_components);
    prod_vec.push(1.0);
    for i in 0..(n_components - 1) {
        prod_vec.push(prod_vec[i] * tmp1[i])
    }
    let mut weights = tmp0.component_mul(&DVector::from_vec(prod_vec));
    weights /= weights.sum();
    // let precisions: Vec<DMatrix<Float>> = (0..n_components)
    //     .map(|k| &precisions_cholesky[k] * precisions_cholesky[k].transpose())
    //     .collect();
    let (_, log_resp) = e_step(
        &data,
        &means,
        &precisions_cholesky,
        &mean_precision,
        &degrees_of_freedom,
        &weight_concentration,
    );
    DPGMResult {
        labels: log_resp
            .row_iter()
            .map(|row| row.transpose().argmax().0)
            .collect(),
        means: means
            .row_iter()
            .map(|row| row.transpose())
            .collect::<Vec<DVector<Float>>>(),
        covariances,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_functions::Rosenbrock;

    fn make_walkers(n_walkers: usize, dim: usize) -> Vec<DVector<Float>> {
        (0..n_walkers)
            .map(|i| DVector::from_element(dim, i as Float + 1.0))
            .collect()
    }

    #[test]
    fn test_essmove_constructors() {
        let d = ESSMove::differential(0.5);
        assert!(matches!(d.0, ESSMove::Differential));
        assert_eq!(d.1, 0.5);

        let g = ESSMove::gaussian(1.0);
        assert!(matches!(g.0, ESSMove::Gaussian));

        let gl = ESSMove::global(2.0, None, None, None);
        if let ESSMove::Global {
            scale,
            rescale_cov,
            n_components,
        } = gl.0
        {
            assert_eq!(scale, 1.0);
            assert_eq!(rescale_cov, 0.001);
            assert_eq!(n_components, 5);
        } else {
            panic!("expected Global");
        }
        assert_eq!(gl.1, 2.0);
    }
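
    // A hand-computed check of the internal `cov` helper: rows are variables
    // and columns are observations, so for m = [[1, 2, 3], [2, 4, 6]] the row
    // means are 2 and 4 and the sample covariance (with the M - 1 denominator)
    // is [[1, 2], [2, 4]].
    #[test]
    fn test_cov_hand_computed() {
        let m = DMatrix::from_row_slice(2, 3, &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0]);
        let c = super::cov(&m);
        assert!((c[(0, 0)] - 1.0).abs() < 1e-12);
        assert!((c[(0, 1)] - 2.0).abs() < 1e-12);
        assert!((c[(1, 0)] - 2.0).abs() < 1e-12);
        assert!((c[(1, 1)] - 4.0).abs() < 1e-12);
    }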

    #[test]
    fn test_essconfig_defaults_and_builders() {
        let walkers = make_walkers(3, 2);
        let cfg = ESSConfig::new(walkers);
        assert_eq!(cfg.walkers.len(), 3);
        assert_eq!(cfg.moves.len(), 1);
        assert_eq!(cfg.n_adaptive, 0);
        assert_eq!(cfg.max_steps, 10000);
        assert_eq!(cfg.mu, 1.0);

        let moves = vec![ESSMove::gaussian(1.0), ESSMove::differential(1.0)];
        let cfg = cfg
            .with_moves(&moves)
            .with_n_adaptive(5)
            .with_max_steps(42)
            .with_mu(4.1);

        assert_eq!(cfg.moves.len(), 2);
        assert_eq!(cfg.n_adaptive, 5);
        assert_eq!(cfg.max_steps, 42);
        assert!((cfg.mu - 4.1).abs() < 1e-12);
    }

    #[test]
    fn test_ess_initialize_and_summarize() {
        let mut ess = ESS::default();
        let walkers = make_walkers(3, 2);
        let cfg = ESSConfig::new(walkers);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };

        ess.initialize(&f, &mut status, &(), &cfg).unwrap();
        assert_eq!(status.walkers.len(), 3);

        let summary = ess.summarize(0, &f, &status, &(), &cfg).unwrap();
        assert_eq!(summary.dimension, status.dimension());
    }

    #[test]
    fn test_differential_step_runs() {
        let mut ess = ESS::default();
        let walkers = make_walkers(3, 2);
        let cfg = ESSConfig::new(walkers);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };
        ess.initialize(&f, &mut status, &(), &cfg).unwrap();

        let result = ess.step(0, &f, &mut status, &(), &cfg);
        assert!(result.is_ok());
        assert!(status.message().contains("Differential"));
    }

    #[test]
    fn test_gaussian_step_runs() {
        let mut ess = ESS::default();
        let walkers = make_walkers(6, 2);
        let cfg = ESSConfig::new(walkers).with_moves(vec![ESSMove::gaussian(1.0)]);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };

        ess.initialize(&f, &mut status, &(), &cfg).unwrap();
        let result = ess.step(0, &f, &mut status, &(), &cfg);
        assert!(result.is_ok());
        assert!(status.message().contains("Gaussian"));
    }
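
    // A hand-computed check of the stick-breaking weight update in
    // `estimate_weights`: the first Dirichlet parameter is nk + 1 and the
    // second is the total weight of all later components plus the
    // concentration prior.
    #[test]
    fn test_estimate_weights_stick_breaking() {
        let nk = DVector::from_vec(vec![3.0, 2.0, 1.0]);
        let (a, b) = super::estimate_weights(&nk, 0.5);
        assert_eq!(a, DVector::from_vec(vec![4.0, 3.0, 2.0]));
        // Tail sums of nk are [2 + 1, 1, 0]; adding the prior gives [3.5, 1.5, 0.5]
        assert_eq!(b, DVector::from_vec(vec![3.5, 1.5, 0.5]));
    }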

    #[test]
    fn test_global_step_runs() {
        let mut ess = ESS::default();
        let walkers = make_walkers(100, 2);
        let cfg = ESSConfig::new(walkers).with_moves(vec![ESSMove::global(
            1.0,
            Some(1.0),
            Some(0.001),
            Some(3),
        )]);
        let mut status = EnsembleStatus::default();
        let f = Rosenbrock { n: 2 };

        ess.initialize(&f, &mut status, &(), &cfg).unwrap();
        let result = ess.step(0, &f, &mut status, &(), &cfg);
        assert!(result.is_ok());
        assert!(status.message().contains("Global"));
    }
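
    // `log_det_cholesky` sums the logs of the diagonal entries of each factor:
    // an identity factor gives 0 and a factor with diagonal 2 gives
    // n_parameters * ln(2).
    #[test]
    fn test_log_det_cholesky_diagonal() {
        let chols = vec![
            DMatrix::identity(3, 3),
            DMatrix::identity(3, 3).scale(2.0),
        ];
        let ld = super::log_det_cholesky(&chols, 3);
        assert!(ld[0].abs() < 1e-12);
        assert!((ld[1] - 3.0 * Float::ln(2.0)).abs() < 1e-12);
    }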

    #[test]
    fn test_kmeans_two_clusters() {
        let mut rng = Rng::with_seed(0);

        let points_a = [
            DVector::from_vec(vec![0.0, 0.1]).transpose(),
            DVector::from_vec(vec![0.2, -0.1]).transpose(),
            DVector::from_vec(vec![-0.1, 0.0]).transpose(),
        ];
        let points_b = [
            DVector::from_vec(vec![10.0, 10.1]).transpose(),
            DVector::from_vec(vec![9.8, 9.9]).transpose(),
            DVector::from_vec(vec![10.2, 9.9]).transpose(),
        ];

        let mut rows = Vec::new();
        rows.extend(points_a.iter().cloned());
        rows.extend(points_b.iter().cloned());
        let data = DMatrix::from_rows(&rows);

        let labels = super::kmeans(2, &data, &mut rng);
        assert_eq!(labels.len(), 6);

        assert_eq!(labels[0], labels[1]);
        assert_eq!(labels[1], labels[2]);
        assert_eq!(labels[3], labels[4]);
        assert_eq!(labels[4], labels[5]);
        assert_ne!(labels[0], labels[3]);
    }
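
    // A self-contained sketch (test-only, not part of the library API) of the
    // univariate slice-sampling update that `ESSMove::step` performs along
    // each direction vector: draw a log-height Y under ln f(x), step the
    // interval [L, R] out until it brackets the slice, then shrink the
    // interval around rejected proposals. The target here is a standard
    // normal log-density; `rng.float()` and `rng.range` are the same
    // `SampleFloat` helpers used by the sampler above.
    #[test]
    fn test_slice_sampling_update_1d() {
        let log_f = |x: Float| -0.5 * x * x;
        let mut rng = Rng::with_seed(0);
        let mut x: Float = 0.5;
        for _ in 0..100 {
            // Y = ln f(x) + ln U, U ~ U(0, 1)
            let y = log_f(x) + rng.float().ln();
            // Step out with unit width until [L, R] brackets the slice
            let mut l = x - rng.float();
            let mut r = l + 1.0;
            while y < log_f(l) {
                l -= 1.0;
            }
            while y < log_f(r) {
                r += 1.0;
            }
            // Shrink: sample uniformly until the proposal lands in the slice
            let xprime = loop {
                let xprime = rng.range(l, r);
                if y < log_f(xprime) {
                    break xprime;
                }
                if xprime < x {
                    l = xprime;
                } else {
                    r = xprime;
                }
            };
            // Any accepted point must lie in the slice {x : ln f(x) > y}
            assert!(log_f(xprime) > y);
            x = xprime;
        }
    }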

    #[test]
    #[allow(clippy::field_reassign_with_default)]
    fn test_dpgm_recovers_means_covariances_two_blobs() {
        use crate::core::utils::SampleFloat;

        let mu_a = DVector::from_vec(vec![0.0, 0.0]);
        let mu_b = DVector::from_vec(vec![3.0, -2.0]);
        let cov_a = DMatrix::from_row_slice(2, 2, &[0.20, 0.05, 0.05, 0.10]);
        let cov_b = DMatrix::from_row_slice(2, 2, &[0.30, -0.04, -0.04, 0.50]);

        let n_a = 80usize;
        let n_b = 70usize;
        let mut rng = Rng::with_seed(0);

        let mut positions: Vec<Walker> = Vec::with_capacity(n_a + n_b);
        for _ in 0..n_a {
            let x = rng.mv_normal(&mu_a, &cov_a);
            positions.push(Walker::new(x));
        }
        for _ in 0..n_b {
            let x = rng.mv_normal(&mu_b, &cov_b);
            positions.push(Walker::new(x));
        }

        let mut status = EnsembleStatus::default();
        status.walkers = positions;

        let mut rng2 = Rng::with_seed(0);
        let res = super::dpgm(2, &status, &None, &mut rng2);

        assert_eq!(res.labels.len(), n_a + n_b);
        assert_eq!(res.means.len(), 2);
        assert_eq!(res.covariances.len(), 2);
        assert_eq!(res.covariances[0].nrows(), 2);
        assert_eq!(res.covariances[0].ncols(), 2);

        // Match each recovered component to the nearer true mean
        let d0_a = (&res.means[0] - &mu_a).norm();
        let d1_a = (&res.means[1] - &mu_a).norm();
        let (idx_a, idx_b) = if d0_a <= d1_a { (0, 1) } else { (1, 0) };

        assert!((&res.means[idx_a] - &mu_a).norm() < 0.25);
        assert!((&res.means[idx_b] - &mu_b).norm() < 0.25);

        let cov_a_hat = &res.covariances[idx_a];
        let cov_b_hat = &res.covariances[idx_b];
        for i in 0..2 {
            let a_true = cov_a[(i, i)];
            let a_est = cov_a_hat[(i, i)];
            assert!((a_est - a_true).abs() / a_true < 0.35);

            let b_true = cov_b[(i, i)];
            let b_est = cov_b_hat[(i, i)];
            assert!((b_est - b_true).abs() / b_true < 0.35);
        }
        assert!((cov_a_hat[(0, 1)] - cov_a[(0, 1)]).abs() < 0.1);
        assert!((cov_b_hat[(0, 1)] - cov_b[(0, 1)]).abs() < 0.1);

        let count_a = res.labels[..n_a].iter().filter(|&&l| l == idx_a).count();
        let count_b = res.labels[n_a..].iter().filter(|&&l| l == idx_b).count();
        assert!(count_a as Float > 0.9 * n_a as Float);
        assert!(count_b as Float > 0.9 * n_b as Float);
    }
}