hammer_and_sample/
lib.rs

1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3//! Simplistic MCMC ensemble sampler based on [emcee](https://emcee.readthedocs.io/), the MCMC hammer
4//!
5//! ```
6//! use hammer_and_sample::{sample, MinChainLen, Model, Serial, Stretch};
7//! use rand::{Rng, SeedableRng};
8//! use rand_pcg::Pcg64;
9//!
10//! fn estimate_bias(coin_flips: &[bool]) -> f64 {
11//!     struct CoinFlips<'a>(&'a [bool]);
12//!
13//!     impl Model for CoinFlips<'_> {
14//!         type Params = [f64; 1];
15//!
16//!         // likelihood of Bernoulli distribution and uninformative prior
17//!         fn log_prob(&self, &[p]: &Self::Params) -> f64 {
18//!             if p < 0. || p > 1. {
19//!                 return f64::NEG_INFINITY;
20//!             }
21//!
22//!             let ln_p = p.ln();
23//!             let ln_1_p = (1. - p).ln();
24//!
25//!             self.0
26//!                 .iter()
27//!                 .map(|coin_flip| if *coin_flip { ln_p } else { ln_1_p })
28//!                 .sum()
29//!         }
30//!     }
31//!
32//!     let model = CoinFlips(coin_flips);
33//!
34//!     let walkers = (0..10).map(|seed| {
35//!         let mut rng = Pcg64::seed_from_u64(seed);
36//!
37//!         let p = rng.random_range(0.0..=1.0);
38//!
39//!         ([p], rng)
40//!     });
41//!
42//!     let (chain, _accepted) = sample(&model, &Stretch::default(), walkers, MinChainLen(10 * 1000), Serial);
43//!
44//!     // 100 iterations of 10 walkers as burn-in
45//!     let chain = &chain[10 * 100..];
46//!
47//!     chain.iter().map(|&[p]| p).sum::<f64>() / chain.len() as f64
48//! }
49//! ```
50use std::ops::ControlFlow;
51use std::ptr;
52
53use rand::{
54    distr::{Distribution, StandardUniform, Uniform},
55    Rng,
56};
57use rand_distr::{
58    weighted::{AliasableWeight, WeightedAliasIndex},
59    Normal,
60};
61#[cfg(feature = "rayon")]
62use rayon::iter::{IntoParallelRefMutIterator, ParallelExtend, ParallelIterator};
63
64/// Model parameters defining the state space of the Markov chain
65pub trait Params: Send + Sync + Clone {
66    /// Dimension of the state space
67    ///
68    /// This can depend on `self` in situations where the number of parameters depends on the data itself, e.g. the number of groups in a hierarchical model.
69    fn dimension(&self) -> usize;
70
71    /// Access the individual parameter values as an iterator
72    fn values(&self) -> impl Iterator<Item = &f64>;
73
74    /// Collect new parameters from the given iterator
75    fn collect(iter: impl Iterator<Item = f64>) -> Self;
76}
77
78/// Model parameters stored as an array of length `N` considered as an element of the vector space `R^N`
79impl<const N: usize> Params for [f64; N] {
80    fn dimension(&self) -> usize {
81        N
82    }
83
84    fn values(&self) -> impl Iterator<Item = &f64> {
85        self.iter()
86    }
87
88    fn collect(iter: impl Iterator<Item = f64>) -> Self {
89        let mut new = [0.; N];
90        iter.enumerate().for_each(|(idx, value)| new[idx] = value);
91        new
92    }
93}
94
95/// Model parameters stored as a vector of length `n` considered as an element of the vector space `R^n`
96impl Params for Vec<f64> {
97    fn dimension(&self) -> usize {
98        self.len()
99    }
100
101    fn values(&self) -> impl Iterator<Item = &f64> {
102        self.iter()
103    }
104
105    fn collect(iter: impl Iterator<Item = f64>) -> Self {
106        iter.collect()
107    }
108}
109
110/// Model parameters stored as a boxed slice of length `n` considered as an element of the vector space `R^n`
111impl Params for Box<[f64]> {
112    fn dimension(&self) -> usize {
113        self.len()
114    }
115
116    fn values(&self) -> impl Iterator<Item = &f64> {
117        self.iter()
118    }
119
120    fn collect(iter: impl Iterator<Item = f64>) -> Self {
121        iter.collect()
122    }
123}
124
125/// A move defines how new estimates of the model parameters are proposed
126pub trait Move<P>
127where
128    P: Params,
129{
130    /// Propose new estimates of the model parameters
131    ///
132    /// The proposal is based on the current estimate `this` and
133    /// optionally, randomly sampled estimates of `other` walkers.
134    ///
135    /// In addition to the new estimate, a correction factor to be added
136    /// to the difference of logarithmic probabilities can be returned.
137    fn propose<'a, O, R>(&self, this: &'a P, other: O, rng: &mut R) -> (P, f64)
138    where
139        O: FnMut(&mut R) -> &'a P,
140        R: Rng;
141}
142
143/// The "stretch" move orignally used by the emcee sampler
144///
145/// Symmetric affine invariant move as described in [Goodman & Weare (2010)](https://msp.org/camcos/2010/5-1/p04.xhtml).
146pub struct Stretch {
147    scale: f64,
148}
149
150impl Stretch {
151    /// Construct a "stretch" move using the given `scale` parameter
152    pub fn new(scale: f64) -> Self {
153        Self { scale }
154    }
155}
156
157impl Default for Stretch {
158    fn default() -> Self {
159        Self::new(2.)
160    }
161}
162
163impl<P> Move<P> for Stretch
164where
165    P: Params,
166{
167    fn propose<'a, O, R>(&self, this: &'a P, mut other: O, rng: &mut R) -> (P, f64)
168    where
169        O: FnMut(&mut R) -> &'a P,
170        R: Rng,
171    {
172        let other = other(rng);
173
174        let z = ((self.scale - 1.) * gen_unit(rng) + 1.).powi(2) / self.scale;
175
176        let new_state = P::collect(
177            this.values()
178                .zip(other.values())
179                .map(|(this, other)| (this - other).mul_add(z, *other)),
180        );
181
182        let factor = (new_state.dimension() - 1) as f64 * z.ln();
183
184        (new_state, factor)
185    }
186}
187
188/// Move using differential evolution based on two other walkers
189///
190/// Using a normal distribution to scale the proposal as described in [Nelson et al. (2013)](https://iopscience.iop.org/article/10.1088/0067-0049/210/1/11).
191pub struct DifferentialEvolution {
192    gamma: Normal<f64>,
193}
194
195impl DifferentialEvolution {
196    /// Construct a differential evolution move using a normal distribution `gamma`
197    ///
198    /// A reasonable default for `gamma_mean` is `2.38 / (2 * N).sqrt()` where `N` is the dimension of the state space.
199    ///
200    /// A reasonable default for `gamma_std_dev` is `1.0e-5`.
201    pub fn new(gamma_mean: f64, gamma_std_dev: f64) -> Self {
202        Self {
203            gamma: Normal::new(gamma_mean, gamma_std_dev).unwrap(),
204        }
205    }
206}
207
208impl<P> Move<P> for DifferentialEvolution
209where
210    P: Params,
211{
212    fn propose<'a, O, R>(&self, this: &'a P, mut other: O, rng: &mut R) -> (P, f64)
213    where
214        O: FnMut(&mut R) -> &'a P,
215        R: Rng,
216    {
217        let first_other = other(rng);
218        let mut second_other = other(rng);
219
220        while ptr::eq(first_other, second_other) {
221            second_other = other(rng);
222        }
223
224        let gamma = self.gamma.sample(rng);
225
226        let new_state = P::collect(
227            this.values()
228                .zip(first_other.values())
229                .zip(second_other.values())
230                .map(|((this, first_other), second_other)| {
231                    (first_other - second_other).mul_add(gamma, *this)
232                }),
233        );
234
235        (new_state, 0.)
236    }
237}
238
239/// A Metropolis step with a Gaussian proposal function
240///
241/// For each step, a direction is choosen randomly and
242/// the displacement is sampled from a centered normal distribution.
243pub struct RandomGaussian {
244    displ: Normal<f64>,
245}
246
247impl RandomGaussian {
248    /// Construct a move using the given standard deviation of the displacement `displ`
249    pub fn new(displ: f64) -> Self {
250        Self {
251            displ: Normal::new(0., displ).unwrap(),
252        }
253    }
254}
255
256impl<P> Move<P> for RandomGaussian
257where
258    P: Params,
259{
260    fn propose<'a, O, R>(&self, this: &'a P, _other: O, rng: &mut R) -> (P, f64)
261    where
262        O: FnMut(&mut R) -> &'a P,
263        R: Rng,
264    {
265        let dir = rng.random_range(0..this.dimension());
266
267        let new_state = P::collect(this.values().enumerate().map(|(idx, value)| {
268            if idx == dir {
269                value + self.displ.sample(rng)
270            } else {
271                *value
272            }
273        }));
274
275        (new_state, 0.)
276    }
277}
278
279/// Combines multiple moves into a single mixture
280///
281/// Mixtures are constructed from tuples of `(Move, Weight)` pairs.
282///
283/// For each step, a single move is selected to determine the next proposal.
284/// The probability of selecting a given move is determined by its relative weight.
285///
286/// ```
287/// # use hammer_and_sample::{sample, MinChainLen, Mixture, Model, RandomGaussian, Serial, Stretch};
288/// # use rand::SeedableRng;
289/// # use rand_pcg::Pcg64Mcg;
290/// #
291/// # struct Dummy;
292/// #
293/// # impl Model for Dummy {
294/// #     type Params = [f64; 1];
295/// #
296/// #     fn log_prob(&self, state: &Self::Params) -> f64 {
297/// #         f64::NEG_INFINITY
298/// #     }
299/// # }
300/// #
301/// # let model = Dummy;
302/// #
303/// # let walkers = (0..100).map(|idx| {
304/// #     let mut rng = Pcg64Mcg::seed_from_u64(idx);
305/// #
306/// #     ([0.], rng)
307/// # });
308/// #
309/// let move_ = Mixture::from((
310///     (Stretch::default(), 2),
311///     (RandomGaussian::new(1.0e-3), 1),
312/// ));
313///
314/// let (chain, accepted) = sample(&model, &move_, walkers, MinChainLen(100_000), Serial);
315/// ```
316pub struct Mixture<W, M>(WeightedAliasIndex<W>, M)
317where
318    W: AliasableWeight;
319
320macro_rules! impl_mixture {
321    ( $( $types:ident @ $weights:ident ),+ ) => {
322        impl<W, $( $types ),+> From<( $( ( $types, W ) ),+ )> for Mixture<W, ( $( $types ),+ )>
323        where
324            W: AliasableWeight
325        {
326            #[allow(non_snake_case)]
327            fn from(( $( ( $types, $weights ) ),+ ): ( $( ( $types, W ) ),+ )) -> Self {
328                let index = WeightedAliasIndex::new(vec![$( $weights ),+]).unwrap();
329
330                Self(index, ( $( $types ),+ ))
331            }
332        }
333
334        impl<W, $( $types ),+, P> Move<P> for Mixture<W, ( $( $types ),+ )>
335        where
336            W: AliasableWeight,
337            P: Params,
338            $( $types: Move<P> ),+
339        {
340            #[allow(non_snake_case)]
341            fn propose<'a, O, R>(&self, this: &'a P, other: O, rng: &mut R) -> (P, f64)
342            where
343                O: FnMut(&mut R) -> &'a P,
344                R: Rng,
345            {
346                let Self(index, ( $( $types ),+ )) = self;
347
348                let chosen_index = index.sample(rng);
349
350                let mut index = 0;
351
352                $(
353
354                #[allow(unused_assignments)]
355                if chosen_index == index {
356                    return $types.propose(this, other, rng)
357                } else {
358                    index += 1;
359                }
360
361                )+
362
363                unreachable!()
364            }
365        }
366    };
367}
368
369impl_mixture!(A @ a, B @ b);
370impl_mixture!(A @ a, B @ b, C @ c);
371impl_mixture!(A @ a, B @ b, C @ c, D @ d);
372impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e);
373impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f);
374impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g);
375impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h);
376impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i);
377impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i, J @ j);
378
379/// Models are defined by the type of their parameters and their probability functions
380pub trait Model: Send + Sync {
381    /// Type used to store the model parameters, e.g. `[f64; N]` or `Vec<f64>`
382    type Params: Params;
383
384    /// The logarithm of the probability determined by the model given the parameters stored in `state`, up to an addititive constant
385    ///
386    /// The sampler will only ever consider differences of these values, i.e. any addititive constant that does _not_ depend on `state` can be omitted when computing them.
387    fn log_prob(&self, state: &Self::Params) -> f64;
388}
389
390/// Runs the sampler on the given [`model`][Model] using the chosen [`move`][Move], [`schedule`][Schedule] and [`execution`][Execution] strategy
391///
392/// A reasonable default for the `move` is [`Stretch`].
393///
394/// A reasonable default for the `schedule` is [`MinChainLen`].
395///
396/// A reasonable default for the `execution` is [`Serial`].
397///
398/// The `walkers` iterator is used to initialise the ensemble of walkers by defining their initial parameter values and providing appropriately seeded PRNG instances.
399///
400/// The number of walkers must be non-zero, even and at least twice the number of parameters.
401///
402/// A vector of samples and the number of accepted moves are returned.
403pub fn sample<MD, MV, W, R, S, E>(
404    model: &MD,
405    move_: &MV,
406    walkers: W,
407    mut schedule: S,
408    execution: E,
409) -> (Vec<MD::Params>, usize)
410where
411    MD: Model,
412    MV: Move<MD::Params> + Send + Sync,
413    W: Iterator<Item = (MD::Params, R)>,
414    R: Rng + Send + Sync,
415    S: Schedule<MD::Params>,
416    E: Execution,
417{
418    let mut walkers = walkers
419        .map(|(state, rng)| Walker::new(model, state, rng))
420        .collect::<Vec<_>>();
421
422    assert!(!walkers.is_empty() && walkers.len() % 2 == 0);
423    assert!(walkers.len() >= 2 * walkers[0].state.dimension());
424
425    let mut chain =
426        Vec::with_capacity(walkers.len() * schedule.iterations(walkers.len()).unwrap_or(0));
427
428    let half = walkers.len() / 2;
429    let (lower_half, upper_half) = walkers.split_at_mut(half);
430
431    let random_index = Uniform::new(0, half).unwrap();
432
433    let update_walker = move |walker: &mut Walker<MD, R>, other_walkers: &[Walker<MD, R>]| {
434        walker.move_(model, move_, |rng| &other_walkers[random_index.sample(rng)])
435    };
436
437    while schedule.next_step(&chain).is_continue() {
438        execution.extend_chain(&mut chain, lower_half, |walker| {
439            update_walker(walker, upper_half)
440        });
441
442        execution.extend_chain(&mut chain, upper_half, |walker| {
443            update_walker(walker, lower_half)
444        });
445    }
446
447    let accepted = walkers.iter().map(|walker| walker.accepted).sum();
448
449    (chain, accepted)
450}
451
452struct Walker<MD, R>
453where
454    MD: Model,
455{
456    state: MD::Params,
457    log_prob: f64,
458    rng: R,
459    accepted: usize,
460}
461
462impl<MD, R> Walker<MD, R>
463where
464    MD: Model,
465    R: Rng,
466{
467    fn new(model: &MD, state: MD::Params, rng: R) -> Self {
468        let log_prob = model.log_prob(&state);
469
470        Self {
471            state,
472            log_prob,
473            rng,
474            accepted: 0,
475        }
476    }
477
478    fn move_<'a, MV, O>(&'a mut self, model: &MD, move_: &MV, mut other: O) -> MD::Params
479    where
480        MV: Move<MD::Params>,
481        O: FnMut(&mut R) -> &'a Self,
482    {
483        let (mut new_state, factor) =
484            move_.propose(&self.state, |rng| &other(rng).state, &mut self.rng);
485
486        let new_log_prob = model.log_prob(&new_state);
487
488        let log_prob_diff = factor + new_log_prob - self.log_prob;
489
490        if log_prob_diff > gen_unit(&mut self.rng).ln() {
491            self.state.clone_from(&new_state);
492            self.log_prob = new_log_prob;
493            self.accepted += 1;
494        } else {
495            new_state.clone_from(&self.state);
496        }
497
498        new_state
499    }
500}
501
502fn gen_unit<R>(rng: &mut R) -> f64
503where
504    R: Rng,
505{
506    StandardUniform.sample(rng)
507}
508
509/// Estimate the integrated auto-correlation time
510///
511/// Returns `None` if the chain length is considered insufficient for a reliable estimate.
512///
513/// `min_win_size` defines the factor between the estimate and the window size up to which the auto-correlation is computed. (Default value: 5)
514///
515/// `min_chain_len` defines the factor between the estimate and the chain length above which the estimate is considered reliable. (Default value: 50)
516pub fn auto_corr_time<C>(
517    chain: C,
518    min_win_size: Option<usize>,
519    min_chain_len: Option<usize>,
520) -> Option<f64>
521where
522    C: ExactSizeIterator<Item = f64> + Clone,
523{
524    let min_win_size = min_win_size.unwrap_or(5) as f64;
525    let min_chain_len = min_chain_len.unwrap_or(50) as f64;
526
527    let mean = chain.clone().sum::<f64>() / chain.len() as f64;
528
529    let variance = chain
530        .clone()
531        .map(|sample| (sample - mean).powi(2))
532        .sum::<f64>()
533        / chain.len() as f64;
534
535    let mut estimate = 1.;
536
537    for lag in 1..chain.len() {
538        let auto_corr = chain
539            .clone()
540            .skip(lag)
541            .zip(chain.clone())
542            .map(|(lhs, rhs)| (lhs - mean) * (rhs - mean))
543            .sum::<f64>()
544            / chain.len() as f64
545            / variance;
546
547        estimate += 2. * auto_corr;
548
549        if lag as f64 >= min_win_size * estimate {
550            break;
551        }
552    }
553
554    if chain.len() as f64 >= min_chain_len * estimate {
555        Some(estimate)
556    } else {
557        None
558    }
559}
560
561/// Determines how many iterations of the sampler are executed
562///
563/// Enables running the sampler until some condition based on
564/// the samples collected so far is fulfilled, for example using
565/// the [auto-correlation time][auto_corr_time].
566///
567/// The [`MinChainLen`] implementor provides a reasonable default.
568///
569/// It can also be used for progress reporting:
570///
571/// ```
572/// use std::ops::ControlFlow;
573///
574/// use hammer_and_sample::{Params, Schedule};
575///
576/// struct FixedIterationsWithProgress {
577///     done: usize,
578///     todo: usize,
579/// }
580///
581/// impl<P> Schedule<P> for FixedIterationsWithProgress
582/// where
583///     P: Params
584/// {
585///      fn next_step(&mut self, _chain: &[P]) -> ControlFlow<()> {
586///         if self.done == self.todo {
587///             eprintln!("100%");
588///
589///             ControlFlow::Break(())
590///         } else {
591///             self.done += 1;
592///
593///             if self.done % (self.todo / 100) == 0 {
594///                 eprintln!("{}% ", self.done / (self.todo / 100));
595///             }
596///
597///             ControlFlow::Continue(())
598///         }
599///     }
600///
601///     fn iterations(&self, _walkers: usize) -> Option<usize> {
602///         Some(self.todo)
603///     }
604/// }
605/// ```
606pub trait Schedule<P>
607where
608    P: Params,
609{
610    /// The next step in the schedule given the current `chain`, either [continue][ControlFlow::Continue] or [break][ControlFlow::Break]
611    fn next_step(&mut self, chain: &[P]) -> ControlFlow<()>;
612
613    /// If possible, compute a lower bound for the number of iterations given the number of `walkers`
614    fn iterations(&self, _walkers: usize) -> Option<usize> {
615        None
616    }
617}
618
619/// Runs the sampler until the given chain length is reached
620pub struct MinChainLen(pub usize);
621
622impl<P> Schedule<P> for MinChainLen
623where
624    P: Params,
625{
626    fn next_step(&mut self, chain: &[P]) -> ControlFlow<()> {
627        if self.0 <= chain.len() {
628            ControlFlow::Break(())
629        } else {
630            ControlFlow::Continue(())
631        }
632    }
633
634    fn iterations(&self, walkers: usize) -> Option<usize> {
635        Some(self.0 / walkers)
636    }
637}
638
639/// Runs the inner `schedule` after calling the given `callback`
640///
641/// ```
642/// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, Stretch, WithProgress};
643/// # use rand::SeedableRng;
644/// # use rand_pcg::Pcg64Mcg;
645/// #
646/// # struct Dummy;
647/// #
648/// # impl Model for Dummy {
649/// #     type Params = [f64; 1];
650/// #
651/// #     fn log_prob(&self, state: &Self::Params) -> f64 {
652/// #         f64::NEG_INFINITY
653/// #     }
654/// # }
655/// #
656/// # let model = Dummy;
657/// #
658/// # let walkers = (0..100).map(|idx| {
659/// #     let mut rng = Pcg64Mcg::seed_from_u64(idx);
660/// #
661/// #     ([0.], rng)
662/// # });
663/// #
664/// let schedule = WithProgress {
665///     schedule: MinChainLen(100_000),
666///     callback: |chain: &[_]| eprintln!("{} %", 100 * chain.len() / 100_000),
667/// };
668///
669/// let (chain, accepted) = sample(&model, &Stretch::default(), walkers, schedule, Serial);
670/// ```
671pub struct WithProgress<S, C> {
672    /// The inner schedule which determines the number of iterations
673    pub schedule: S,
674    /// The callback which is executed after each iteration
675    pub callback: C,
676}
677
678impl<P, S, C> Schedule<P> for WithProgress<S, C>
679where
680    P: Params,
681    S: Schedule<P>,
682    C: FnMut(&[P]),
683{
684    fn next_step(&mut self, chain: &[P]) -> ControlFlow<()> {
685        (self.callback)(chain);
686
687        self.schedule.next_step(chain)
688    }
689
690    fn iterations(&self, walkers: usize) -> Option<usize> {
691        self.schedule.iterations(walkers)
692    }
693}
694
695/// Execution strategy for `update`ing an ensemble of `walkers` to extend the given `chain`
696pub trait Execution {
697    /// Must call `update` exactly once for all elements of `walkers` and store the results in `chain`
698    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
699    where
700        P: Send + Sync,
701        W: Send + Sync,
702        U: Fn(&mut W) -> P + Send + Sync;
703}
704
705/// Serial execution strategy which updates walkers using a single thread
706pub struct Serial;
707
708impl Execution for Serial {
709    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
710    where
711        P: Send + Sync,
712        W: Send + Sync,
713        U: Fn(&mut W) -> P + Send + Sync,
714    {
715        chain.extend(walkers.iter_mut().map(update));
716    }
717}
718
719#[cfg(feature = "rayon")]
720/// Parallel execution strategy which updates walkers using Rayon's thread pool
721pub struct Parallel;
722
723#[cfg(feature = "rayon")]
724impl Execution for Parallel {
725    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
726    where
727        P: Send + Sync,
728        W: Send + Sync,
729        U: Fn(&mut W) -> P + Send + Sync,
730    {
731        chain.par_extend(walkers.par_iter_mut().map(update));
732    }
733}