Skip to main content

distr_combinators/
lib.rs

//! Combinators for [`rand::distr::Distribution`] objects
//!
//! This crate provides tools to compose and transform probability distributions
//! in a functional style. It is designed to be generic and independent of any
//! specific data structures.
6
7use rand::Rng;
8use rand_distr::Distribution;
9use rand::distr::{Iter, Uniform};
10use std::marker::PhantomData;
11
12/// A convenience to analyze the distribution of generated values
13pub mod histogram;
14pub use histogram::Histogram;
15
16/// A distribution that samples from `d` but rejects values that do not satisfy predicate `p`.
17#[derive(Clone)]
18pub struct Filtered<T, D : Distribution<T> + Clone, P : Fn(&T) -> bool> { pub d: D, pub p: P, pub pd: PhantomData<T> }
19impl <T, D : Distribution<T> + Clone, P : Fn(&T) -> bool> Distribution<T> for Filtered<T, D, P> {
20  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
21    loop {
22      let s = self.d.sample(rng);
23      if (self.p)(&s) { return s }
24    }
25  }
26}
27
28/// A distribution that maps values sampled from `d` using function `f`.
29#[derive(Clone)]
30pub struct Mapped<T, S, D : Distribution<T> + Clone, F : Fn(T) -> S + Clone> { pub d: D, pub f: F, pub pd: PhantomData<(T, S)> }
31impl <T, S, D : Distribution<T> + Clone, F : Fn(T) -> S + Clone> Distribution<S> for Mapped<T, S, D, F> {
32  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S {
33    (self.f)(self.d.sample(rng))
34  }
35}
36
37/// A distribution that samples from `d`, maps using `pf`, and retries if `pf` returns `None`.
38#[derive(Clone)]
39pub struct Collected<T, S, D : Distribution<T> + Clone, P : Fn(T) -> Option<S> + Clone> { pub d: D, pub pf: P, pub pd: PhantomData<(T, S)> }
40impl <T, S, D : Distribution<T> + Clone, P : Fn(T) -> Option<S> + Clone> Distribution<S> for Collected<T, S, D, P> {
41  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S {
42    loop {
43      let t = self.d.sample(rng);
44      match (self.pf)(t) {
45        None => {}
46        Some(s) => { return s }
47      }
48    }
49  }
50}
51
52/// A distribution that combines two independent distributions `dx` and `dy` using a function `f`.
53#[derive(Clone)]
54pub struct Product2<X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, Z, F : Fn(X, Y) -> Z + Clone> { pub dx: DX, pub dy: DY, pub f: F,
55  pub pd: PhantomData<(X, Y, Z)> }
56impl <X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, Z, F : Fn(X, Y) -> Z + Clone> Distribution<Z> for Product2<X, DX, Y, DY, Z, F> {
57  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Z {
58    (self.f)(self.dx.sample(rng), self.dy.sample(rng))
59  }
60}
61
62/// A distribution that chooses between `dx` and `dy` based on a boolean distribution `db`.
63#[derive(Clone)]
64pub struct Choice2<X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, DB : Distribution<bool> + Clone> { pub dx: DX, pub dy: DY, pub db: DB,
65  pub pd: PhantomData<(X, Y)> }
66impl <X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, DB : Distribution<bool> + Clone> Distribution<Result<X, Y>> for Choice2<X, DX, Y, DY, DB> {
67  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<X, Y> {
68    if self.db.sample(rng) { Ok(self.dx.sample(rng)) }
69    else { Err(self.dy.sample(rng)) }
70  }
71}
72
73/// A distribution where the choice of the second distribution depends on the value sampled from the first.
74#[derive(Clone)]
75pub struct Dependent2<X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, FDY : Fn(X) -> DY + Clone> { pub dx: DX, pub fdy: FDY,
76  pub pd: PhantomData<(X, Y)> }
77impl <X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, FDY : Fn(X) -> DY + Clone> Distribution<Y> for Dependent2<X, DX, Y, DY, FDY> {
78  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Y {
79    (self.fdy)(self.dx.sample(rng)).sample(rng)
80  }
81}
82
83/// A stateful distribution that accumulates samples from `dx` into state `z` until `fa` returns a value.
84#[derive(Clone)]
85pub struct Concentrated<X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(&mut A, X) -> Option<Y>> { pub dx: DX, pub z: A, pub fa: FA,
86  pub pd: PhantomData<(X, Y)> }
87impl <X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(&mut A, X) -> Option<Y>> Distribution<Y> for Concentrated<X, DX, A, Y, FA> {
88  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Y {
89    let mut a = self.z.clone();
90    loop {
91      match (self.fa)(&mut a, self.dx.sample(rng)) {
92        None => {}
93        Some(y) => { return y }
94      }
95    }
96  }
97}
98
99/// A distribution that expands a single sample from `dx` into a sequence of values.
100#[derive(Clone)]
101pub struct Diluted<X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(X) -> A, FAY : Fn(&mut A) -> Option<Y>> { pub dx: DX, pub fa: FA, pub fay: FAY,
102  pub pd: PhantomData<(X, A, Y)> }
103impl <X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(X) -> A, FAY : Fn(&mut A) -> Option<Y>> Distribution<Y> for Diluted<X, DX, A, Y, FA, FAY> {
104  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Y {
105    let mut a = (self.fa)(self.dx.sample(rng));
106    (self.fay)(&mut a).expect("fay returns at least once per fa call")
107  }
108
109  fn sample_iter<R>(self, _rng: R) -> Iter<Self, R, Y> where R : Rng, Self : Sized {
110    panic!("This function returning a concrete object makes it impossible to override the iterator behavior")
111  }
112}
113
114/// A constant distribution that always returns `element`.
115#[derive(Clone)]
116pub struct Degenerate<T : Clone> { pub element: T }
117impl <T : Clone> Distribution<T> for Degenerate<T> {
118  fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> T {
119    self.element.clone()
120  }
121}
122
123/// A categorical distribution that selects from `elements` based on an index distribution `ed`.
124#[derive(Clone)]
125pub struct Categorical<T : Clone, ElemD : Distribution<usize> + Clone> { pub elements: Vec<T>, pub ed: ElemD }
126impl <T : Clone, ElemD : Distribution<usize> + Clone> Distribution<T> for Categorical<T, ElemD> {
127  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
128    self.elements[self.ed.sample(rng)].clone()
129  }
130}
131
132/// Creates a categorical distribution where elements are chosen with probability proportional to their weights.
133pub fn ratios<T : Clone>(ep: impl IntoIterator<Item=(T, usize)>) -> Categorical<T, Mapped<usize, usize, Uniform<usize>, impl Fn(usize) -> usize + Clone>> {
134  let mut elements = vec![];
135  let mut cdf = vec![];
136  let mut sum = 0;
137  for (e, r) in ep.into_iter() {
138    elements.push(e);
139    cdf.push(sum);
140    sum += r;
141  }
142  let us = Uniform::try_from(0..sum).unwrap();
143  Categorical {
144    elements,
145    // it's much cheaper to draw many samples at once, but the current Distribution API is broken
146    ed: Mapped{ d: us, f: move |x| { match cdf.binary_search(&x) {
147      Ok(i) => { i }
148      Err(i) => { i - 1 }
149    }}, pd: PhantomData::default() }
150  }
151}
152
153/// A distribution that generates a vector of items with length sampled from `lengthd` and items from `itemd`.
154#[derive(Clone)]
155pub struct Repeated<T, LengthD : Distribution<usize>, ItemD : Distribution<T>> { pub lengthd: LengthD, pub itemd: ItemD, pub pd: PhantomData<T> }
156impl <T, LengthD : Distribution<usize>, ItemD : Distribution<T>> Distribution<Vec<T>> for Repeated<T, LengthD, ItemD> {
157  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T> {
158    let l = self.lengthd.sample(rng);
159    Vec::from_iter(std::iter::repeat_with(|| self.itemd.sample(rng)).take(l))
160  }
161}
162
163/// A distribution that generates a vector of bytes by sampling from `mbd` until `None` is returned.
164#[derive(Clone)]
165pub struct Sentinel<MByteD : Distribution<Option<u8>> + Clone> { pub mbd: MByteD }
166impl <MByteD : Distribution<Option<u8>> + Clone> Distribution<Vec<u8>> for Sentinel<MByteD> {
167  fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<u8> {
168    let mut v = vec![];
169    while let Some(e) = self.mbd.sample(rng) {
170      v.push(e)
171    }
172    v
173  }
174}
175
#[cfg(test)]
mod tests {
  use rand::SeedableRng;
  use rand::distr::Uniform;
  use rand::rngs::StdRng;
  use crate::*;

  /// Estimates pi from uniform points in the unit square.
  ///
  /// Each estimate counts how many of just over `SAMPLES` points fall inside the unit
  /// quarter-circle. The test checks that ten successive estimates lie within
  /// `3.5 / sqrt(SAMPLES)` of pi.
  #[test]
  fn monte_carlo_pi() {
    #[cfg(not(miri))]
    const SAMPLES: usize = 100000;
    #[cfg(miri)]
    const SAMPLES: usize = 100;

    let seeded = StdRng::from_seed([0; 32]);
    let unit_x = Uniform::new(0.0, 1.0).unwrap();
    let unit_y = Uniform::new(0.0, 1.0).unwrap();
    let points = Product2 { dx: unit_x, dy: unit_y, f: |x, y| (x, y), pd: PhantomData::default() };
    let estimates = Concentrated {
      dx: points,
      z: (0, 0),
      fa: |counts, (x, y)| {
        if x * x + y * y < 1.0 { counts.0 += 1 } else { counts.1 += 1 }
        let total = counts.0 + counts.1;
        (total > SAMPLES).then(|| 4f64 * (counts.0 as f64 / total as f64))
      },
      pd: Default::default(),
    };

    let tolerance = 3.5f64 / (SAMPLES as f64).sqrt();
    let accepted = (std::f64::consts::PI - tolerance)..=(std::f64::consts::PI + tolerance);
    for estimate in estimates.sample_iter(seeded).take(10) {
      assert!(accepted.contains(&estimate));
    }
  }

  /// Draws `SAMPLES` times the total weight from [`ratios`].
  ///
  /// It checks that the histogram lists exactly the expected `(element, weight)` pairs, in
  /// order, once each count is divided by `SAMPLES` and rounded.
  #[test]
  fn categorical_samples() {
    #[cfg(not(miri))]
    const SAMPLES: usize = 1000;
    #[cfg(miri)]
    const SAMPLES: usize = 141;

    let seeded = StdRng::from_seed([0; 32]);
    let expected = [('b', 2usize), ('a', 10), ('c', 29), ('d', 100)];
    let total_weight: usize = expected.iter().map(|&(_, weight)| weight).sum();
    let categorical = ratios(expected.into_iter());
    let draws = categorical.sample_iter(seeded).take(SAMPLES * total_weight);
    let hist = Histogram::from_iter(draws);
    let achieved: Vec<(char, usize)> = hist
      .iter()
      .map(|(key, count)| (*key, ((count as f64) / (SAMPLES as f64)).round() as usize))
      .collect();
    assert_eq!(&expected[..], &achieved[..]);
  }
}