1use std::{fmt, marker::PhantomData, ops::AddAssign};
4
5use rand::Rng;
6use rand_distr::{
7 uniform::{SampleBorrow, SampleUniform},
8 weighted::{WeightedError, WeightedIndex},
9 Distribution,
10};
11
12pub struct Mix<T, U, X>
54where
55 T: Distribution<U>,
56 X: SampleUniform + PartialOrd,
57{
58 distributions: Vec<T>,
59 weights: WeightedIndex<X>,
60 _marker: PhantomData<U>,
61}
62
63impl<T, U, X> Mix<T, U, X>
64where
65 T: Distribution<U>,
66 X: SampleUniform + PartialOrd,
67{
68 pub fn new<I, J>(dists: I, weights: J) -> Result<Self, WeightedError>
73 where
74 I: IntoIterator<Item = T>,
75 J: IntoIterator,
76 J::Item: SampleBorrow<X>,
77 X: for<'a> AddAssign<&'a X> + Clone + Default,
78 {
79 Ok(Self {
80 distributions: dists.into_iter().collect(),
81 weights: WeightedIndex::new(weights)?,
82 _marker: PhantomData,
83 })
84 }
85
86 pub fn with_zip<W>(
90 dists_weights: impl IntoIterator<Item = (T, W)>,
91 ) -> Result<Self, WeightedError>
92 where
93 W: SampleBorrow<X>,
94 X: for<'a> AddAssign<&'a X> + Clone + Default,
95 {
96 let (distributions, weights): (Vec<_>, Vec<_>) = dists_weights.into_iter().unzip();
97 Ok(Self {
98 distributions,
99 weights: WeightedIndex::new(weights)?,
100 _marker: PhantomData,
101 })
102 }
103}
104
105impl<T, U, X> Distribution<U> for Mix<T, U, X>
106where
107 T: Distribution<U>,
108 X: SampleUniform + PartialOrd,
109{
110 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> U {
111 let idx = self.weights.sample(rng);
112 self.distributions[idx].sample(rng)
113 }
114}
115
116impl<T, U, X> Clone for Mix<T, U, X>
117where
118 T: Distribution<U> + Clone,
119 X: SampleUniform + PartialOrd + Clone,
120 X::Sampler: Clone,
121{
122 fn clone(&self) -> Self {
123 Self {
124 distributions: self.distributions.clone(),
125 weights: self.weights.clone(),
126 _marker: PhantomData,
127 }
128 }
129}
130
131impl<T, U, X> fmt::Debug for Mix<T, U, X>
132where
133 T: Distribution<U> + fmt::Debug,
134 X: SampleUniform + PartialOrd + fmt::Debug,
135 X::Sampler: fmt::Debug,
136{
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_struct("Mix")
139 .field("distributions", &self.distributions)
140 .field("weights", &self.weights)
141 .finish()
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{rngs::StdRng, SeedableRng};
    use rand_distr::{Normal, Uniform};

    /// Number of draws taken by each statistical test.
    const SAMPLES: usize = 10_000;

    /// Maximum allowed absolute gap between an observed and an expected frequency.
    ///
    /// With `SAMPLES` draws the standard error of any frequency is at most
    /// `sqrt(0.25 / 10_000) = 0.005`, so this is roughly a 10-sigma margin: the
    /// tests still pin the mixing proportions without being flaky.
    const TOLERANCE: f64 = 0.05;

    /// Fixed-seed RNG so that any failure is reproducible.
    fn seeded_rng() -> StdRng {
        StdRng::seed_from_u64(0x5EED)
    }

    /// Draws `SAMPLES` values from `dist` and returns the observed frequency of
    /// each integer in `0..categories`.
    ///
    /// Panics if a value outside `0..categories` is drawn.
    fn frequencies<D: Distribution<i32>>(dist: &D, categories: usize) -> Vec<f64> {
        let mut rng = seeded_rng();
        let mut counts = vec![0_usize; categories];
        for _ in 0..SAMPLES {
            let value = dist.sample(&mut rng);
            // A negative value wraps to a huge index, which `get_mut` rejects.
            match counts.get_mut(value as usize) {
                Some(count) => *count += 1,
                None => panic!("sampled unexpected value {}", value),
            }
        }
        counts
            .into_iter()
            .map(|count| count as f64 / SAMPLES as f64)
            .collect()
    }

    /// Asserts that every observed frequency is within `TOLERANCE` of its expectation.
    fn assert_frequencies(observed: &[f64], expected: &[f64]) {
        assert_eq!(observed.len(), expected.len());
        for (value, (obs, exp)) in observed.iter().zip(expected).enumerate() {
            assert!(
                (obs - exp).abs() < TOLERANCE,
                "value {}: observed frequency {}, expected {}",
                value,
                obs,
                exp
            );
        }
    }

    /// Prints samples of a two-component normal mixture, one per line, for plotting.
    #[test]
    #[ignore]
    fn test_mix_plot() {
        let mut rng = seeded_rng();

        let mix = {
            let dists = vec![
                Normal::new(0.0, 1.0).unwrap(),
                Normal::new(5.0, 2.0).unwrap(),
            ];
            let weights = &[2, 1];
            Mix::new(dists, weights).unwrap()
        };

        for _ in 0..30000 {
            println!("{} # mix", mix.sample(&mut rng));
        }
    }

    #[test]
    fn test_mix_2() {
        let mix = {
            let dists = vec![Uniform::new_inclusive(0, 0), Uniform::new_inclusive(1, 1)];
            let weights = &[2, 1];
            Mix::new(dists, weights).unwrap()
        };

        assert_frequencies(&frequencies(&mix, 2), &[2.0 / 3.0, 1.0 / 3.0]);
    }

    #[test]
    fn test_mix_3() {
        let mix = {
            let dists = vec![
                Uniform::new_inclusive(0, 0),
                Uniform::new_inclusive(1, 1),
                Uniform::new_inclusive(2, 2),
            ];
            let weights = &[3, 2, 1];
            Mix::new(dists, weights).unwrap()
        };

        assert_frequencies(&frequencies(&mix, 3), &[0.5, 1.0 / 3.0, 1.0 / 6.0]);
    }

    #[test]
    fn test_weight_f64() {
        let mix = {
            let dists = vec![Uniform::new_inclusive(0, 0), Uniform::new_inclusive(1, 1)];
            let weights = &[0.4, 0.6];
            Mix::new(dists, weights).unwrap()
        };

        assert_frequencies(&frequencies(&mix, 2), &[0.4, 0.6]);
    }

    #[test]
    fn test_zip() {
        let mix = Mix::with_zip(vec![
            (Uniform::new_inclusive(0, 0), 2),
            (Uniform::new_inclusive(1, 1), 1),
        ])
        .unwrap();

        assert_frequencies(&frequencies(&mix, 2), &[2.0 / 3.0, 1.0 / 3.0]);
    }

    #[test]
    fn error_invalid_weights() {
        let dists = vec![Uniform::new_inclusive(0, 0), Uniform::new_inclusive(1, 1)];

        let weights = &[2, 1][0..0];
        assert_eq!(
            Mix::new(dists.clone(), weights).unwrap_err(),
            WeightedError::NoItem,
        );

        let weights = &[2, -1];
        assert_eq!(
            Mix::new(dists.clone(), weights).unwrap_err(),
            WeightedError::InvalidWeight,
        );

        let weights = &[0, 0];
        assert_eq!(
            Mix::new(dists, weights).unwrap_err(),
            WeightedError::AllWeightsZero,
        );
    }
}