1use rand::Rng;
8use rand_distr::Distribution;
9use rand::distr::{Iter, Uniform};
10use std::marker::PhantomData;
11
12pub mod histogram;
14pub use histogram::Histogram;
15
16#[derive(Clone)]
18pub struct Filtered<T, D : Distribution<T> + Clone, P : Fn(&T) -> bool> { pub d: D, pub p: P, pub pd: PhantomData<T> }
19impl <T, D : Distribution<T> + Clone, P : Fn(&T) -> bool> Distribution<T> for Filtered<T, D, P> {
20 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
21 loop {
22 let s = self.d.sample(rng);
23 if (self.p)(&s) { return s }
24 }
25 }
26}
27
28#[derive(Clone)]
30pub struct Mapped<T, S, D : Distribution<T> + Clone, F : Fn(T) -> S + Clone> { pub d: D, pub f: F, pub pd: PhantomData<(T, S)> }
31impl <T, S, D : Distribution<T> + Clone, F : Fn(T) -> S + Clone> Distribution<S> for Mapped<T, S, D, F> {
32 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S {
33 (self.f)(self.d.sample(rng))
34 }
35}
36
37#[derive(Clone)]
39pub struct Collected<T, S, D : Distribution<T> + Clone, P : Fn(T) -> Option<S> + Clone> { pub d: D, pub pf: P, pub pd: PhantomData<(T, S)> }
40impl <T, S, D : Distribution<T> + Clone, P : Fn(T) -> Option<S> + Clone> Distribution<S> for Collected<T, S, D, P> {
41 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> S {
42 loop {
43 let t = self.d.sample(rng);
44 match (self.pf)(t) {
45 None => {}
46 Some(s) => { return s }
47 }
48 }
49 }
50}
51
52#[derive(Clone)]
54pub struct Product2<X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, Z, F : Fn(X, Y) -> Z + Clone> { pub dx: DX, pub dy: DY, pub f: F,
55 pub pd: PhantomData<(X, Y, Z)> }
56impl <X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, Z, F : Fn(X, Y) -> Z + Clone> Distribution<Z> for Product2<X, DX, Y, DY, Z, F> {
57 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Z {
58 (self.f)(self.dx.sample(rng), self.dy.sample(rng))
59 }
60}
61
62#[derive(Clone)]
64pub struct Choice2<X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, DB : Distribution<bool> + Clone> { pub dx: DX, pub dy: DY, pub db: DB,
65 pub pd: PhantomData<(X, Y)> }
66impl <X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, DB : Distribution<bool> + Clone> Distribution<Result<X, Y>> for Choice2<X, DX, Y, DY, DB> {
67 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<X, Y> {
68 if self.db.sample(rng) { Ok(self.dx.sample(rng)) }
69 else { Err(self.dy.sample(rng)) }
70 }
71}
72
73#[derive(Clone)]
75pub struct Dependent2<X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, FDY : Fn(X) -> DY + Clone> { pub dx: DX, pub fdy: FDY,
76 pub pd: PhantomData<(X, Y)> }
77impl <X, DX : Distribution<X> + Clone, Y, DY : Distribution<Y> + Clone, FDY : Fn(X) -> DY + Clone> Distribution<Y> for Dependent2<X, DX, Y, DY, FDY> {
78 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Y {
79 (self.fdy)(self.dx.sample(rng)).sample(rng)
80 }
81}
82
83#[derive(Clone)]
85pub struct Concentrated<X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(&mut A, X) -> Option<Y>> { pub dx: DX, pub z: A, pub fa: FA,
86 pub pd: PhantomData<(X, Y)> }
87impl <X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(&mut A, X) -> Option<Y>> Distribution<Y> for Concentrated<X, DX, A, Y, FA> {
88 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Y {
89 let mut a = self.z.clone();
90 loop {
91 match (self.fa)(&mut a, self.dx.sample(rng)) {
92 None => {}
93 Some(y) => { return y }
94 }
95 }
96 }
97}
98
99#[derive(Clone)]
101pub struct Diluted<X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(X) -> A, FAY : Fn(&mut A) -> Option<Y>> { pub dx: DX, pub fa: FA, pub fay: FAY,
102 pub pd: PhantomData<(X, A, Y)> }
103impl <X, DX : Distribution<X> + Clone, A : Clone, Y, FA : Fn(X) -> A, FAY : Fn(&mut A) -> Option<Y>> Distribution<Y> for Diluted<X, DX, A, Y, FA, FAY> {
104 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Y {
105 let mut a = (self.fa)(self.dx.sample(rng));
106 (self.fay)(&mut a).expect("fay returns at least once per fa call")
107 }
108
109 fn sample_iter<R>(self, _rng: R) -> Iter<Self, R, Y> where R : Rng, Self : Sized {
110 panic!("This function returning a concrete object makes it impossible to override the iterator behavior")
111 }
112}
113
114#[derive(Clone)]
116pub struct Degenerate<T : Clone> { pub element: T }
117impl <T : Clone> Distribution<T> for Degenerate<T> {
118 fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> T {
119 self.element.clone()
120 }
121}
122
123#[derive(Clone)]
125pub struct Categorical<T : Clone, ElemD : Distribution<usize> + Clone> { pub elements: Vec<T>, pub ed: ElemD }
126impl <T : Clone, ElemD : Distribution<usize> + Clone> Distribution<T> for Categorical<T, ElemD> {
127 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
128 self.elements[self.ed.sample(rng)].clone()
129 }
130}
131
132pub fn ratios<T : Clone>(ep: impl IntoIterator<Item=(T, usize)>) -> Categorical<T, Mapped<usize, usize, Uniform<usize>, impl Fn(usize) -> usize + Clone>> {
134 let mut elements = vec![];
135 let mut cdf = vec![];
136 let mut sum = 0;
137 for (e, r) in ep.into_iter() {
138 elements.push(e);
139 cdf.push(sum);
140 sum += r;
141 }
142 let us = Uniform::try_from(0..sum).unwrap();
143 Categorical {
144 elements,
145 ed: Mapped{ d: us, f: move |x| { match cdf.binary_search(&x) {
147 Ok(i) => { i }
148 Err(i) => { i - 1 }
149 }}, pd: PhantomData::default() }
150 }
151}
152
153#[derive(Clone)]
155pub struct Repeated<T, LengthD : Distribution<usize>, ItemD : Distribution<T>> { pub lengthd: LengthD, pub itemd: ItemD, pub pd: PhantomData<T> }
156impl <T, LengthD : Distribution<usize>, ItemD : Distribution<T>> Distribution<Vec<T>> for Repeated<T, LengthD, ItemD> {
157 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<T> {
158 let l = self.lengthd.sample(rng);
159 Vec::from_iter(std::iter::repeat_with(|| self.itemd.sample(rng)).take(l))
160 }
161}
162
163#[derive(Clone)]
165pub struct Sentinel<MByteD : Distribution<Option<u8>> + Clone> { pub mbd: MByteD }
166impl <MByteD : Distribution<Option<u8>> + Clone> Distribution<Vec<u8>> for Sentinel<MByteD> {
167 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<u8> {
168 let mut v = vec![];
169 while let Some(e) = self.mbd.sample(rng) {
170 v.push(e)
171 }
172 v
173 }
174}
175
// Statistical tests for the distribution combinators. Both tests use a fixed seed so their
// outcomes are deterministic.
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use rand::distr::Uniform;
    use crate::*;

    // Monte Carlo estimate of pi. Uniform points in the unit square are classified as inside
    // or outside the unit circle; 4 * (inside fraction) approximates pi.
    #[test]
    fn monte_carlo_pi() {
        // Reduced sample count when running under Miri.
        #[cfg(not(miri))]
        const SAMPLES: usize = 100000;
        #[cfg(miri)]
        const SAMPLES: usize = 100;

        let rng = StdRng::from_seed([0; 32]);
        let sx = Uniform::new(0.0, 1.0).unwrap();
        let sy = Uniform::new(0.0, 1.0).unwrap();
        // Points (x, y) in the unit square.
        let sxy = Product2 { dx: sx, dy: sy, f: |x, y| (x, y), pd: PhantomData::default() };
        // The accumulator is (inside, outside) counts. An estimate is produced once more than
        // SAMPLES points have been classified.
        let spi = Concentrated { dx: sxy, z: (0, 0), fa: |i_o, (x, y)| {
            if x*x + y*y < 1.0 { i_o.0 += 1 } else { i_o.1 += 1 }
            if i_o.0 + i_o.1 > SAMPLES { Some(4f64*(i_o.0 as f64/(i_o.0 + i_o.1) as f64)) } else { None }
        }, pd: Default::default() };

        // Ten estimates. Concentrated clones `z` per sample, so each estimate starts from (0, 0).
        spi.sample_iter(rng).take(10).for_each(|api| {
            // Tolerance shrinks as 1/sqrt(SAMPLES).
            let err_bar = 3.5f64 / (SAMPLES as f64).sqrt();
            assert!(std::f64::consts::PI-err_bar <= api && std::f64::consts::PI+err_bar >= api)
        });
    }

    // `ratios` should reproduce its weights. Drawing SAMPLES * (total weight) values should
    // give roughly SAMPLES * weight occurrences of each element.
    #[test]
    fn categorical_samples() {
        #[cfg(not(miri))]
        const SAMPLES: usize = 1000;
        #[cfg(miri)]
        const SAMPLES: usize = 141;

        let rng = StdRng::from_seed([0; 32]);
        // Listed in ascending weight order. NOTE(review): the assert below assumes
        // `Histogram::iter` yields entries in that same order (ascending count); confirm in
        // the histogram module.
        let expected = [('b', 2usize), ('a', 10), ('c', 29), ('d', 100)];
        let cd = ratios(expected.into_iter());
        // 10+2+29+100 is the total weight of `expected`.
        let hist = Histogram::from_iter(cd.sample_iter(rng).take(SAMPLES*(10+2+29+100)));
        // Divide each count by SAMPLES and round, which should recover the original weight.
        let achieved: Vec<(char, usize)> = hist.iter().map(|(k, c)|
            (*k, ((c as f64)/(SAMPLES as f64)).round() as usize)).collect();
        assert_eq!(&expected[..], &achieved[..]);
    }
}