mix_distribution/
lib.rs

//! Mixture distributions.
2
3use std::{fmt, marker::PhantomData, ops::AddAssign};
4
5use rand::Rng;
6use rand_distr::{
7    uniform::{SampleBorrow, SampleUniform},
8    weighted::{WeightedError, WeightedIndex},
9    Distribution,
10};
11
12/// Mixture distributions.
13///
14/// # Examples
15///
16/// ```rust
17/// use rand_distr::{Distribution, Normal, Uniform};
18/// use mix_distribution::Mix;
19///
20/// let mut rng = rand::thread_rng();
21///
22/// // Mixture of two distributions
23/// let mix = {
24///     let dists = vec![
25///         Normal::new(0.0, 1.0).unwrap(),
26///         Normal::new(1.0, 2.0).unwrap(),
27///     ];
28///     let weights = &[2, 1];
29///     Mix::new(dists, weights).unwrap()
30/// };
31/// mix.sample(&mut rng);
32///
33/// // Mixture of three distributions
34/// let mix = {
35///     let dists = vec![
36///         Uniform::new_inclusive(0.0, 2.0),
37///         Uniform::new_inclusive(1.0, 3.0),
38///         Uniform::new_inclusive(2.0, 4.0),
39///     ];
40///     let weights = &[2, 1, 3];
41///     Mix::new(dists, weights).unwrap()
42/// };
43/// mix.sample(&mut rng);
44///
45/// // From iterator over (distribution, weight) pairs
46/// let mix = Mix::with_zip(vec![
47///     (Uniform::new_inclusive(0, 2), 2),
48///     (Uniform::new_inclusive(1, 3), 1),
49/// ])
50/// .unwrap();
51/// mix.sample(&mut rng);
52/// ```
53pub struct Mix<T, U, X>
54where
55    T: Distribution<U>,
56    X: SampleUniform + PartialOrd,
57{
58    distributions: Vec<T>,
59    weights: WeightedIndex<X>,
60    _marker: PhantomData<U>,
61}
62
63impl<T, U, X> Mix<T, U, X>
64where
65    T: Distribution<U>,
66    X: SampleUniform + PartialOrd,
67{
68    /// Creates a new `Mix`.
69    /// `dists` and `weights` must have the same length.
70    ///
71    /// Propagates errors from `rand_dist::weighted::WeightedIndex::new()`.
72    pub fn new<I, J>(dists: I, weights: J) -> Result<Self, WeightedError>
73    where
74        I: IntoIterator<Item = T>,
75        J: IntoIterator,
76        J::Item: SampleBorrow<X>,
77        X: for<'a> AddAssign<&'a X> + Clone + Default,
78    {
79        Ok(Self {
80            distributions: dists.into_iter().collect(),
81            weights: WeightedIndex::new(weights)?,
82            _marker: PhantomData,
83        })
84    }
85
86    /// Creats a new `Mix` with the given iterator over (distribution, weight) pairs.
87    ///
88    /// Propagates errors from `rand_dist::weighted::WeightedIndex::new()`.
89    pub fn with_zip<W>(
90        dists_weights: impl IntoIterator<Item = (T, W)>,
91    ) -> Result<Self, WeightedError>
92    where
93        W: SampleBorrow<X>,
94        X: for<'a> AddAssign<&'a X> + Clone + Default,
95    {
96        let (distributions, weights): (Vec<_>, Vec<_>) = dists_weights.into_iter().unzip();
97        Ok(Self {
98            distributions,
99            weights: WeightedIndex::new(weights)?,
100            _marker: PhantomData,
101        })
102    }
103}
104
105impl<T, U, X> Distribution<U> for Mix<T, U, X>
106where
107    T: Distribution<U>,
108    X: SampleUniform + PartialOrd,
109{
110    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> U {
111        let idx = self.weights.sample(rng);
112        self.distributions[idx].sample(rng)
113    }
114}
115
116impl<T, U, X> Clone for Mix<T, U, X>
117where
118    T: Distribution<U> + Clone,
119    X: SampleUniform + PartialOrd + Clone,
120    X::Sampler: Clone,
121{
122    fn clone(&self) -> Self {
123        Self {
124            distributions: self.distributions.clone(),
125            weights: self.weights.clone(),
126            _marker: PhantomData,
127        }
128    }
129}
130
131impl<T, U, X> fmt::Debug for Mix<T, U, X>
132where
133    T: Distribution<U> + fmt::Debug,
134    X: SampleUniform + PartialOrd + fmt::Debug,
135    X::Sampler: fmt::Debug,
136{
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        f.debug_struct("Mix")
139            .field("distributions", &self.distributions)
140            .field("weights", &self.weights)
141            .finish()
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use rand_distr::{Normal, Uniform};

    /// Number of draws used by each statistical test.
    ///
    /// With this many draws the standard deviation of any observed proportion
    /// is at most sqrt(0.25 / 10_000) = 0.005, so `TOLERANCE` spans at least
    /// ten standard deviations and the tests do not fail spuriously.
    const SAMPLES: usize = 10_000;

    /// Largest accepted absolute difference between an observed and an
    /// expected proportion.
    const TOLERANCE: f64 = 0.05;

    /// Draws `SAMPLES` values from `dist` and returns, for every `v` in
    /// `0..n_values`, the fraction of draws that were equal to `v`.
    ///
    /// Panics if any draw lies outside `0..n_values`.
    fn proportions<D: Distribution<i32>>(dist: &D, n_values: usize) -> Vec<f64> {
        let mut rng = rand::thread_rng();
        let mut counts = vec![0usize; n_values];
        for _ in 0..SAMPLES {
            let x = dist.sample(&mut rng);
            assert!(
                x >= 0 && (x as usize) < n_values,
                "sample {} is outside 0..{}",
                x,
                n_values
            );
            counts[x as usize] += 1;
        }
        counts
            .into_iter()
            .map(|count| count as f64 / SAMPLES as f64)
            .collect()
    }

    /// Asserts that every observed proportion is within `TOLERANCE` of the
    /// corresponding expected proportion.
    fn assert_proportions(observed: &[f64], expected: &[f64]) {
        assert_eq!(observed.len(), expected.len());
        for (value, (&obs, &exp)) in observed.iter().zip(expected).enumerate() {
            assert!(
                (obs - exp).abs() < TOLERANCE,
                "value {}: observed proportion {:.4}, expected {:.4}",
                value,
                obs,
                exp
            );
        }
    }

    #[test]
    #[ignore]
    fn test_mix_plot() {
        let mut rng = rand::thread_rng();

        let mix = {
            let dists = vec![
                Normal::new(0.0, 1.0).unwrap(),
                Normal::new(5.0, 2.0).unwrap(),
            ];
            let weights = &[2, 1];
            Mix::new(dists, weights).unwrap()
        };

        for _ in 0..30000 {
            println!("{} # mix", mix.sample(&mut rng));
        }

        // # cargo test test_mix_plot -- --ignored --nocapture | python3 plot.py
        //
        // from sys import stdin
        //
        // import numpy as np
        // from numpy.random import normal
        // import matplotlib.pyplot as plt
        //
        // BINS = 128
        // ALPHA = 0.5
        //
        // actual = np.array([float(l.split()[0])
        //                      for l in stdin.readlines() if "# mix" in l])
        // plt.hist(actual, bins=BINS, alpha=ALPHA, label="Actual")
        //
        // expected = np.concatenate(
        //     (normal(0.0, 1.0, 20000), normal(5.0, 2.0, 10000)), axis=0)
        // plt.hist(expected, bins=BINS, alpha=ALPHA, label="Expected")
        //
        // plt.legend()
        // plt.grid()
        //
        // plt.show()
    }

    #[test]
    fn test_mix_2() {
        let mix = {
            let dists = vec![Uniform::new_inclusive(0, 0), Uniform::new_inclusive(1, 1)];
            let weights = &[2, 1];
            Mix::new(dists, weights).unwrap()
        };

        assert_proportions(&proportions(&mix, 2), &[2.0 / 3.0, 1.0 / 3.0]);
    }

    #[test]
    fn test_mix_3() {
        let mix = {
            let dists = vec![
                Uniform::new_inclusive(0, 0),
                Uniform::new_inclusive(1, 1),
                Uniform::new_inclusive(2, 2),
            ];
            let weights = &[3, 2, 1];
            Mix::new(dists, weights).unwrap()
        };

        assert_proportions(
            &proportions(&mix, 3),
            &[3.0 / 6.0, 2.0 / 6.0, 1.0 / 6.0],
        );
    }

    #[test]
    fn test_weight_f64() {
        let mix = {
            let dists = vec![Uniform::new_inclusive(0, 0), Uniform::new_inclusive(1, 1)];
            let weights = &[0.4, 0.6];
            Mix::new(dists, weights).unwrap()
        };

        assert_proportions(&proportions(&mix, 2), &[0.4, 0.6]);
    }

    #[test]
    fn test_zip() {
        let mix = Mix::with_zip(vec![
            (Uniform::new_inclusive(0, 0), 2),
            (Uniform::new_inclusive(1, 1), 1),
        ])
        .unwrap();

        assert_proportions(&proportions(&mix, 2), &[2.0 / 3.0, 1.0 / 3.0]);
    }

    #[test]
    fn test_clone_and_debug() {
        let mix = Mix::new(vec![Uniform::new_inclusive(0, 0)], &[1]).unwrap();
        let cloned = mix.clone();

        let shown = format!("{:?}", cloned);
        assert!(shown.starts_with("Mix {"));
        assert_eq!(format!("{:?}", mix), shown);
        assert_eq!(cloned.sample(&mut rand::thread_rng()), 0);
    }

    #[test]
    fn error_invalid_weights() {
        let dists = vec![Uniform::new_inclusive(0, 0), Uniform::new_inclusive(1, 1)];

        let weights = &[2, 1][0..0];
        assert_eq!(
            Mix::new(dists.clone(), weights).unwrap_err(),
            WeightedError::NoItem,
        );

        let weights = &[2, -1];
        assert_eq!(
            Mix::new(dists.clone(), weights).unwrap_err(),
            WeightedError::InvalidWeight,
        );

        let weights = &[0, 0];
        assert_eq!(
            Mix::new(dists, weights).unwrap_err(),
            WeightedError::AllWeightsZero,
        );
    }
}