hammer_and_sample/lib.rs
1#![forbid(unsafe_code)]
2#![deny(missing_docs)]
3//! Simplistic MCMC ensemble sampler based on [emcee](https://emcee.readthedocs.io/), the MCMC hammer
4//!
5//! ```
6//! use hammer_and_sample::{sample, MinChainLen, Model, Serial, Stretch};
7//! use rand::{Rng, SeedableRng};
8//! use rand_pcg::Pcg64;
9//!
10//! fn estimate_bias(coin_flips: &[bool]) -> f64 {
11//! struct CoinFlips<'a>(&'a [bool]);
12//!
13//! impl Model for CoinFlips<'_> {
14//! type Params = [f64; 1];
15//!
16//! // likelihood of Bernoulli distribution and uninformative prior
17//! fn log_prob(&self, &[p]: &Self::Params) -> f64 {
18//! if p < 0. || p > 1. {
19//! return f64::NEG_INFINITY;
20//! }
21//!
22//! let ln_p = p.ln();
23//! let ln_1_p = (1. - p).ln();
24//!
25//! self.0
26//! .iter()
27//! .map(|coin_flip| if *coin_flip { ln_p } else { ln_1_p })
28//! .sum()
29//! }
30//! }
31//!
32//! let model = CoinFlips(coin_flips);
33//!
34//! let walkers = (0..10).map(|seed| {
35//! let mut rng = Pcg64::seed_from_u64(seed);
36//!
37//! let p = rng.random_range(0.0..=1.0);
38//!
39//! ([p], rng)
40//! });
41//!
42//! let (chain, _accepted) = sample(&model, &Stretch::default(), walkers, MinChainLen(10 * 1000), Serial);
43//!
44//! // 100 iterations of 10 walkers as burn-in
45//! let chain = &chain[10 * 100..];
46//!
47//! chain.iter().map(|&[p]| p).sum::<f64>() / chain.len() as f64
48//! }
49//! ```
50use std::ops::ControlFlow;
51use std::ptr;
52
53use rand::{
54 distr::{Distribution, StandardUniform, Uniform},
55 Rng,
56};
57use rand_distr::{
58 weighted::{AliasableWeight, WeightedAliasIndex},
59 Normal,
60};
61#[cfg(feature = "rayon")]
62use rayon::iter::{IntoParallelRefMutIterator, ParallelExtend, ParallelIterator};
63
/// Parameters of a model, i.e. a point in the state space explored by the Markov chain
pub trait Params: Send + Sync + Clone {
    /// Number of individual parameter values, i.e. the dimension of the state space
    ///
    /// This is a method taking `self` because the number of parameters can depend on the data itself,
    /// e.g. on the number of groups in a hierarchical model.
    fn dimension(&self) -> usize;

    /// Iterate over references to the individual parameter values
    fn values(&self) -> impl Iterator<Item = &f64>;

    /// Build new parameters from the values yielded by the given iterator
    fn collect(iter: impl Iterator<Item = f64>) -> Self;
}
77
78/// Model parameters stored as an array of length `N` considered as an element of the vector space `R^N`
79impl<const N: usize> Params for [f64; N] {
80 fn dimension(&self) -> usize {
81 N
82 }
83
84 fn values(&self) -> impl Iterator<Item = &f64> {
85 self.iter()
86 }
87
88 fn collect(iter: impl Iterator<Item = f64>) -> Self {
89 let mut new = [0.; N];
90 iter.enumerate().for_each(|(idx, value)| new[idx] = value);
91 new
92 }
93}
94
95/// Model parameters stored as a vector of length `n` considered as an element of the vector space `R^n`
96impl Params for Vec<f64> {
97 fn dimension(&self) -> usize {
98 self.len()
99 }
100
101 fn values(&self) -> impl Iterator<Item = &f64> {
102 self.iter()
103 }
104
105 fn collect(iter: impl Iterator<Item = f64>) -> Self {
106 iter.collect()
107 }
108}
109
110/// Model parameters stored as a boxed slice of length `n` considered as an element of the vector space `R^n`
111impl Params for Box<[f64]> {
112 fn dimension(&self) -> usize {
113 self.len()
114 }
115
116 fn values(&self) -> impl Iterator<Item = &f64> {
117 self.iter()
118 }
119
120 fn collect(iter: impl Iterator<Item = f64>) -> Self {
121 iter.collect()
122 }
123}
124
125/// A move defines how new estimates of the model parameters are proposed
126pub trait Move<P>
127where
128 P: Params,
129{
130 /// Propose new estimates of the model parameters
131 ///
132 /// The proposal is based on the current estimate `this` and
133 /// optionally, randomly sampled estimates of `other` walkers.
134 ///
135 /// In addition to the new estimate, a correction factor to be added
136 /// to the difference of logarithmic probabilities can be returned.
137 fn propose<'a, O, R>(&self, this: &'a P, other: O, rng: &mut R) -> (P, f64)
138 where
139 O: FnMut(&mut R) -> &'a P,
140 R: Rng;
141}
142
/// The "stretch" move originally used by the emcee sampler
///
/// Symmetric affine invariant move as described in [Goodman & Weare (2010)](https://msp.org/camcos/2010/5-1/p04.xhtml).
pub struct Stretch {
    scale: f64,
}

impl Stretch {
    /// Construct a "stretch" move using the given `scale` parameter
    pub fn new(scale: f64) -> Self {
        Stretch { scale }
    }
}

impl Default for Stretch {
    /// A "stretch" move with a `scale` of two
    fn default() -> Self {
        Stretch { scale: 2. }
    }
}
162
163impl<P> Move<P> for Stretch
164where
165 P: Params,
166{
167 fn propose<'a, O, R>(&self, this: &'a P, mut other: O, rng: &mut R) -> (P, f64)
168 where
169 O: FnMut(&mut R) -> &'a P,
170 R: Rng,
171 {
172 let other = other(rng);
173
174 let z = ((self.scale - 1.) * gen_unit(rng) + 1.).powi(2) / self.scale;
175
176 let new_state = P::collect(
177 this.values()
178 .zip(other.values())
179 .map(|(this, other)| (this - other).mul_add(z, *other)),
180 );
181
182 let factor = (new_state.dimension() - 1) as f64 * z.ln();
183
184 (new_state, factor)
185 }
186}
187
188/// Move using differential evolution based on two other walkers
189///
190/// Using a normal distribution to scale the proposal as described in [Nelson et al. (2013)](https://iopscience.iop.org/article/10.1088/0067-0049/210/1/11).
191pub struct DifferentialEvolution {
192 gamma: Normal<f64>,
193}
194
195impl DifferentialEvolution {
196 /// Construct a differential evolution move using a normal distribution `gamma`
197 ///
198 /// A reasonable default for `gamma_mean` is `2.38 / (2 * N).sqrt()` where `N` is the dimension of the state space.
199 ///
200 /// A reasonable default for `gamma_std_dev` is `1.0e-5`.
201 pub fn new(gamma_mean: f64, gamma_std_dev: f64) -> Self {
202 Self {
203 gamma: Normal::new(gamma_mean, gamma_std_dev).unwrap(),
204 }
205 }
206}
207
208impl<P> Move<P> for DifferentialEvolution
209where
210 P: Params,
211{
212 fn propose<'a, O, R>(&self, this: &'a P, mut other: O, rng: &mut R) -> (P, f64)
213 where
214 O: FnMut(&mut R) -> &'a P,
215 R: Rng,
216 {
217 let first_other = other(rng);
218 let mut second_other = other(rng);
219
220 while ptr::eq(first_other, second_other) {
221 second_other = other(rng);
222 }
223
224 let gamma = self.gamma.sample(rng);
225
226 let new_state = P::collect(
227 this.values()
228 .zip(first_other.values())
229 .zip(second_other.values())
230 .map(|((this, first_other), second_other)| {
231 (first_other - second_other).mul_add(gamma, *this)
232 }),
233 );
234
235 (new_state, 0.)
236 }
237}
238
239/// A Metropolis step with a Gaussian proposal function
240///
241/// For each step, a direction is choosen randomly and
242/// the displacement is sampled from a centered normal distribution.
243pub struct RandomGaussian {
244 displ: Normal<f64>,
245}
246
247impl RandomGaussian {
248 /// Construct a move using the given standard deviation of the displacement `displ`
249 pub fn new(displ: f64) -> Self {
250 Self {
251 displ: Normal::new(0., displ).unwrap(),
252 }
253 }
254}
255
256impl<P> Move<P> for RandomGaussian
257where
258 P: Params,
259{
260 fn propose<'a, O, R>(&self, this: &'a P, _other: O, rng: &mut R) -> (P, f64)
261 where
262 O: FnMut(&mut R) -> &'a P,
263 R: Rng,
264 {
265 let dir = rng.random_range(0..this.dimension());
266
267 let new_state = P::collect(this.values().enumerate().map(|(idx, value)| {
268 if idx == dir {
269 value + self.displ.sample(rng)
270 } else {
271 *value
272 }
273 }));
274
275 (new_state, 0.)
276 }
277}
278
279/// Combines multiple moves into a single mixture
280///
281/// Mixtures are constructed from tuples of `(Move, Weight)` pairs.
282///
283/// For each step, a single move is selected to determine the next proposal.
284/// The probability of selecting a given move is determined by its relative weight.
285///
286/// ```
287/// # use hammer_and_sample::{sample, MinChainLen, Mixture, Model, RandomGaussian, Serial, Stretch};
288/// # use rand::SeedableRng;
289/// # use rand_pcg::Pcg64Mcg;
290/// #
291/// # struct Dummy;
292/// #
293/// # impl Model for Dummy {
294/// # type Params = [f64; 1];
295/// #
296/// # fn log_prob(&self, state: &Self::Params) -> f64 {
297/// # f64::NEG_INFINITY
298/// # }
299/// # }
300/// #
301/// # let model = Dummy;
302/// #
303/// # let walkers = (0..100).map(|idx| {
304/// # let mut rng = Pcg64Mcg::seed_from_u64(idx);
305/// #
306/// # ([0.], rng)
307/// # });
308/// #
309/// let move_ = Mixture::from((
310/// (Stretch::default(), 2),
311/// (RandomGaussian::new(1.0e-3), 1),
312/// ));
313///
314/// let (chain, accepted) = sample(&model, &move_, walkers, MinChainLen(100_000), Serial);
315/// ```
316pub struct Mixture<W, M>(WeightedAliasIndex<W>, M)
317where
318 W: AliasableWeight;
319
320macro_rules! impl_mixture {
321 ( $( $types:ident @ $weights:ident ),+ ) => {
322 impl<W, $( $types ),+> From<( $( ( $types, W ) ),+ )> for Mixture<W, ( $( $types ),+ )>
323 where
324 W: AliasableWeight
325 {
326 #[allow(non_snake_case)]
327 fn from(( $( ( $types, $weights ) ),+ ): ( $( ( $types, W ) ),+ )) -> Self {
328 let index = WeightedAliasIndex::new(vec![$( $weights ),+]).unwrap();
329
330 Self(index, ( $( $types ),+ ))
331 }
332 }
333
334 impl<W, $( $types ),+, P> Move<P> for Mixture<W, ( $( $types ),+ )>
335 where
336 W: AliasableWeight,
337 P: Params,
338 $( $types: Move<P> ),+
339 {
340 #[allow(non_snake_case)]
341 fn propose<'a, O, R>(&self, this: &'a P, other: O, rng: &mut R) -> (P, f64)
342 where
343 O: FnMut(&mut R) -> &'a P,
344 R: Rng,
345 {
346 let Self(index, ( $( $types ),+ )) = self;
347
348 let chosen_index = index.sample(rng);
349
350 let mut index = 0;
351
352 $(
353
354 #[allow(unused_assignments)]
355 if chosen_index == index {
356 return $types.propose(this, other, rng)
357 } else {
358 index += 1;
359 }
360
361 )+
362
363 unreachable!()
364 }
365 }
366 };
367}
368
369impl_mixture!(A @ a, B @ b);
370impl_mixture!(A @ a, B @ b, C @ c);
371impl_mixture!(A @ a, B @ b, C @ c, D @ d);
372impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e);
373impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f);
374impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g);
375impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h);
376impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i);
377impl_mixture!(A @ a, B @ b, C @ c, D @ d, E @ e, F @ f, G @ g, H @ h, I @ i, J @ j);
378
379/// Models are defined by the type of their parameters and their probability functions
380pub trait Model: Send + Sync {
381 /// Type used to store the model parameters, e.g. `[f64; N]` or `Vec<f64>`
382 type Params: Params;
383
384 /// The logarithm of the probability determined by the model given the parameters stored in `state`, up to an addititive constant
385 ///
386 /// The sampler will only ever consider differences of these values, i.e. any addititive constant that does _not_ depend on `state` can be omitted when computing them.
387 fn log_prob(&self, state: &Self::Params) -> f64;
388}
389
390/// Runs the sampler on the given [`model`][Model] using the chosen [`move`][Move], [`schedule`][Schedule] and [`execution`][Execution] strategy
391///
392/// A reasonable default for the `move` is [`Stretch`].
393///
394/// A reasonable default for the `schedule` is [`MinChainLen`].
395///
396/// A reasonable default for the `execution` is [`Serial`].
397///
398/// The `walkers` iterator is used to initialise the ensemble of walkers by defining their initial parameter values and providing appropriately seeded PRNG instances.
399///
400/// The number of walkers must be non-zero, even and at least twice the number of parameters.
401///
402/// A vector of samples and the number of accepted moves are returned.
403pub fn sample<MD, MV, W, R, S, E>(
404 model: &MD,
405 move_: &MV,
406 walkers: W,
407 mut schedule: S,
408 execution: E,
409) -> (Vec<MD::Params>, usize)
410where
411 MD: Model,
412 MV: Move<MD::Params> + Send + Sync,
413 W: Iterator<Item = (MD::Params, R)>,
414 R: Rng + Send + Sync,
415 S: Schedule<MD::Params>,
416 E: Execution,
417{
418 let mut walkers = walkers
419 .map(|(state, rng)| Walker::new(model, state, rng))
420 .collect::<Vec<_>>();
421
422 assert!(!walkers.is_empty() && walkers.len() % 2 == 0);
423 assert!(walkers.len() >= 2 * walkers[0].state.dimension());
424
425 let mut chain =
426 Vec::with_capacity(walkers.len() * schedule.iterations(walkers.len()).unwrap_or(0));
427
428 let half = walkers.len() / 2;
429 let (lower_half, upper_half) = walkers.split_at_mut(half);
430
431 let random_index = Uniform::new(0, half).unwrap();
432
433 let update_walker = move |walker: &mut Walker<MD, R>, other_walkers: &[Walker<MD, R>]| {
434 walker.move_(model, move_, |rng| &other_walkers[random_index.sample(rng)])
435 };
436
437 while schedule.next_step(&chain).is_continue() {
438 execution.extend_chain(&mut chain, lower_half, |walker| {
439 update_walker(walker, upper_half)
440 });
441
442 execution.extend_chain(&mut chain, upper_half, |walker| {
443 update_walker(walker, lower_half)
444 });
445 }
446
447 let accepted = walkers.iter().map(|walker| walker.accepted).sum();
448
449 (chain, accepted)
450}
451
452struct Walker<MD, R>
453where
454 MD: Model,
455{
456 state: MD::Params,
457 log_prob: f64,
458 rng: R,
459 accepted: usize,
460}
461
462impl<MD, R> Walker<MD, R>
463where
464 MD: Model,
465 R: Rng,
466{
467 fn new(model: &MD, state: MD::Params, rng: R) -> Self {
468 let log_prob = model.log_prob(&state);
469
470 Self {
471 state,
472 log_prob,
473 rng,
474 accepted: 0,
475 }
476 }
477
478 fn move_<'a, MV, O>(&'a mut self, model: &MD, move_: &MV, mut other: O) -> MD::Params
479 where
480 MV: Move<MD::Params>,
481 O: FnMut(&mut R) -> &'a Self,
482 {
483 let (mut new_state, factor) =
484 move_.propose(&self.state, |rng| &other(rng).state, &mut self.rng);
485
486 let new_log_prob = model.log_prob(&new_state);
487
488 let log_prob_diff = factor + new_log_prob - self.log_prob;
489
490 if log_prob_diff > gen_unit(&mut self.rng).ln() {
491 self.state.clone_from(&new_state);
492 self.log_prob = new_log_prob;
493 self.accepted += 1;
494 } else {
495 new_state.clone_from(&self.state);
496 }
497
498 new_state
499 }
500}
501
502fn gen_unit<R>(rng: &mut R) -> f64
503where
504 R: Rng,
505{
506 StandardUniform.sample(rng)
507}
508
/// Estimate the integrated auto-correlation time
///
/// Returns `None` if the chain length is considered insufficient for a reliable estimate.
/// This includes chains of two or more samples without a positive variance,
/// e.g. constant chains or chains containing NaN, for which no estimate can be computed.
///
/// `min_win_size` defines the factor between the estimate and the window size up to which the auto-correlation is computed. (Default value: 5)
///
/// `min_chain_len` defines the factor between the estimate and the chain length above which the estimate is considered reliable. (Default value: 50)
pub fn auto_corr_time<C>(
    chain: C,
    min_win_size: Option<usize>,
    min_chain_len: Option<usize>,
) -> Option<f64>
where
    C: ExactSizeIterator<Item = f64> + Clone,
{
    let min_win_size = min_win_size.unwrap_or(5) as f64;
    let min_chain_len = min_chain_len.unwrap_or(50) as f64;

    let len = chain.len() as f64;

    let mean = chain.clone().sum::<f64>() / len;

    let variance = chain
        .clone()
        .map(|sample| (sample - mean).powi(2))
        .sum::<f64>()
        / len;

    // Without a positive variance, every auto-correlation computed below is NaN,
    // so the window criterion would never trigger and the loop would run over all
    // lags, i.e. perform quadratic work, only to return `None` in the end.
    // The comparison is negated so that a NaN variance is caught as well.
    if chain.len() > 1 && !(variance > 0.) {
        return None;
    }

    let mut estimate = 1.;

    for lag in 1..chain.len() {
        // Pairs each sample with the one `lag` steps earlier in the chain.
        let auto_corr = chain
            .clone()
            .skip(lag)
            .zip(chain.clone())
            .map(|(lhs, rhs)| (lhs - mean) * (rhs - mean))
            .sum::<f64>()
            / len
            / variance;

        estimate += 2. * auto_corr;

        // Stop summing once the window is large enough compared to the current estimate.
        if lag as f64 >= min_win_size * estimate {
            break;
        }
    }

    if len >= min_chain_len * estimate {
        Some(estimate)
    } else {
        None
    }
}
560
561/// Determines how many iterations of the sampler are executed
562///
563/// Enables running the sampler until some condition based on
564/// the samples collected so far is fulfilled, for example using
565/// the [auto-correlation time][auto_corr_time].
566///
567/// The [`MinChainLen`] implementor provides a reasonable default.
568///
569/// It can also be used for progress reporting:
570///
571/// ```
572/// use std::ops::ControlFlow;
573///
574/// use hammer_and_sample::{Params, Schedule};
575///
576/// struct FixedIterationsWithProgress {
577/// done: usize,
578/// todo: usize,
579/// }
580///
581/// impl<P> Schedule<P> for FixedIterationsWithProgress
582/// where
583/// P: Params
584/// {
585/// fn next_step(&mut self, _chain: &[P]) -> ControlFlow<()> {
586/// if self.done == self.todo {
587/// eprintln!("100%");
588///
589/// ControlFlow::Break(())
590/// } else {
591/// self.done += 1;
592///
593/// if self.done % (self.todo / 100) == 0 {
594/// eprintln!("{}% ", self.done / (self.todo / 100));
595/// }
596///
597/// ControlFlow::Continue(())
598/// }
599/// }
600///
601/// fn iterations(&self, _walkers: usize) -> Option<usize> {
602/// Some(self.todo)
603/// }
604/// }
605/// ```
606pub trait Schedule<P>
607where
608 P: Params,
609{
610 /// The next step in the schedule given the current `chain`, either [continue][ControlFlow::Continue] or [break][ControlFlow::Break]
611 fn next_step(&mut self, chain: &[P]) -> ControlFlow<()>;
612
613 /// If possible, compute a lower bound for the number of iterations given the number of `walkers`
614 fn iterations(&self, _walkers: usize) -> Option<usize> {
615 None
616 }
617}
618
619/// Runs the sampler until the given chain length is reached
620pub struct MinChainLen(pub usize);
621
622impl<P> Schedule<P> for MinChainLen
623where
624 P: Params,
625{
626 fn next_step(&mut self, chain: &[P]) -> ControlFlow<()> {
627 if self.0 <= chain.len() {
628 ControlFlow::Break(())
629 } else {
630 ControlFlow::Continue(())
631 }
632 }
633
634 fn iterations(&self, walkers: usize) -> Option<usize> {
635 Some(self.0 / walkers)
636 }
637}
638
639/// Runs the inner `schedule` after calling the given `callback`
640///
641/// ```
642/// # use hammer_and_sample::{sample, MinChainLen, Model, Schedule, Serial, Stretch, WithProgress};
643/// # use rand::SeedableRng;
644/// # use rand_pcg::Pcg64Mcg;
645/// #
646/// # struct Dummy;
647/// #
648/// # impl Model for Dummy {
649/// # type Params = [f64; 1];
650/// #
651/// # fn log_prob(&self, state: &Self::Params) -> f64 {
652/// # f64::NEG_INFINITY
653/// # }
654/// # }
655/// #
656/// # let model = Dummy;
657/// #
658/// # let walkers = (0..100).map(|idx| {
659/// # let mut rng = Pcg64Mcg::seed_from_u64(idx);
660/// #
661/// # ([0.], rng)
662/// # });
663/// #
664/// let schedule = WithProgress {
665/// schedule: MinChainLen(100_000),
666/// callback: |chain: &[_]| eprintln!("{} %", 100 * chain.len() / 100_000),
667/// };
668///
669/// let (chain, accepted) = sample(&model, &Stretch::default(), walkers, schedule, Serial);
670/// ```
671pub struct WithProgress<S, C> {
672 /// The inner schedule which determines the number of iterations
673 pub schedule: S,
674 /// The callback which is executed after each iteration
675 pub callback: C,
676}
677
678impl<P, S, C> Schedule<P> for WithProgress<S, C>
679where
680 P: Params,
681 S: Schedule<P>,
682 C: FnMut(&[P]),
683{
684 fn next_step(&mut self, chain: &[P]) -> ControlFlow<()> {
685 (self.callback)(chain);
686
687 self.schedule.next_step(chain)
688 }
689
690 fn iterations(&self, walkers: usize) -> Option<usize> {
691 self.schedule.iterations(walkers)
692 }
693}
694
/// Execution strategy used to `update` an ensemble of `walkers`, thereby extending the given `chain`
pub trait Execution {
    /// Implementations must call `update` exactly once for each element of `walkers` and store the results in `chain`
    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
    where
        P: Send + Sync,
        W: Send + Sync,
        U: Fn(&mut W) -> P + Send + Sync;
}
704
705/// Serial execution strategy which updates walkers using a single thread
706pub struct Serial;
707
708impl Execution for Serial {
709 fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
710 where
711 P: Send + Sync,
712 W: Send + Sync,
713 U: Fn(&mut W) -> P + Send + Sync,
714 {
715 chain.extend(walkers.iter_mut().map(update));
716 }
717}
718
#[cfg(feature = "rayon")]
/// Parallel execution strategy which updates the walkers using Rayon's parallel iterators
pub struct Parallel;

#[cfg(feature = "rayon")]
impl Execution for Parallel {
    fn extend_chain<P, W, U>(&self, chain: &mut Vec<P>, walkers: &mut [W], update: U)
    where
        P: Send + Sync,
        W: Send + Sync,
        U: Fn(&mut W) -> P + Send + Sync,
    {
        let updated = walkers.par_iter_mut().map(update);

        chain.par_extend(updated);
    }
}