ferray_random/distributions/
exponential.rs1use ferray_core::{Array, FerrayError, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::{
7 Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
8};
9use crate::shape::IntoShape;
10
11pub(crate) fn standard_exponential_single<B: BitGenerator>(bg: &mut B) -> f64 {
13 loop {
14 let u = bg.next_f64();
15 if u > f64::EPSILON {
16 return -u.ln();
17 }
18 }
19}
20
21pub(crate) fn standard_exponential_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
25 standard_exponential_single(bg) as f32
26}
27
28impl<B: BitGenerator> Generator<B> {
29 pub fn standard_exponential(
36 &mut self,
37 size: impl IntoShape,
38 ) -> Result<Array<f64, IxDyn>, FerrayError> {
39 let shape = size.into_shape()?;
40 let n = shape_size(&shape);
41 let data = generate_vec(self, n, standard_exponential_single);
42 vec_to_array_f64(data, &shape)
43 }
44
45 pub fn exponential(
52 &mut self,
53 scale: f64,
54 size: impl IntoShape,
55 ) -> Result<Array<f64, IxDyn>, FerrayError> {
56 if scale <= 0.0 {
57 return Err(FerrayError::invalid_value(format!(
58 "scale must be positive, got {scale}"
59 )));
60 }
61 let shape = size.into_shape()?;
62 let n = shape_size(&shape);
63 let data = generate_vec(self, n, |bg| scale * standard_exponential_single(bg));
64 vec_to_array_f64(data, &shape)
65 }
66
67 pub fn standard_exponential_into(
75 &mut self,
76 out: &mut Array<f64, IxDyn>,
77 ) -> Result<(), FerrayError> {
78 let slice = out.as_slice_mut().ok_or_else(|| {
79 FerrayError::invalid_value("standard_exponential_into requires a contiguous out buffer")
80 })?;
81 for v in slice.iter_mut() {
82 *v = standard_exponential_single(&mut self.bg);
83 }
84 Ok(())
85 }
86
87 pub fn exponential_array(
97 &mut self,
98 scale: &Array<f64, IxDyn>,
99 ) -> Result<Array<f64, IxDyn>, FerrayError> {
100 let shape = scale.shape().to_vec();
101 let total: usize = shape.iter().product();
102 let mut out: Vec<f64> = Vec::with_capacity(total);
103 for &s in scale.iter() {
104 if s <= 0.0 {
105 return Err(FerrayError::invalid_value(format!(
106 "scale must be positive, got {s}"
107 )));
108 }
109 out.push(s * standard_exponential_single(&mut self.bg));
110 }
111 Array::<f64, IxDyn>::from_vec(IxDyn::new(&shape), out)
112 }
113
114 pub fn standard_exponential_f32(
121 &mut self,
122 size: impl IntoShape,
123 ) -> Result<Array<f32, IxDyn>, FerrayError> {
124 let shape = size.into_shape()?;
125 let n = shape_size(&shape);
126 let data = generate_vec_f32(self, n, standard_exponential_single_f32);
127 vec_to_array_f32(data, &shape)
128 }
129
130 pub fn exponential_f32(
136 &mut self,
137 scale: f32,
138 size: impl IntoShape,
139 ) -> Result<Array<f32, IxDyn>, FerrayError> {
140 if scale <= 0.0 {
141 return Err(FerrayError::invalid_value(format!(
142 "scale must be positive, got {scale}"
143 )));
144 }
145 let shape = size.into_shape()?;
146 let n = shape_size(&shape);
147 let data = generate_vec_f32(self, n, |bg| scale * standard_exponential_single_f32(bg));
148 vec_to_array_f32(data, &shape)
149 }
150}
151
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;
    use ferray_core::{Array, IxDyn};

    /// Arithmetic mean of `xs`, summed front to back.
    fn mean_of(xs: &[f64]) -> f64 {
        xs.iter().sum::<f64>() / xs.len() as f64
    }

    /// Mean squared deviation of `xs` from a precomputed `mean`.
    fn variance_about(xs: &[f64], mean: f64) -> f64 {
        xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / xs.len() as f64
    }

    /// Arithmetic mean of `f32` samples, accumulated in `f64`.
    fn mean_of_f32(xs: &[f32]) -> f64 {
        xs.iter().map(|&x| x as f64).sum::<f64>() / xs.len() as f64
    }

    // ---- in-place and per-element-scale variants -------------------------

    #[test]
    fn standard_exponential_into_matches_allocating_version() {
        // Identically seeded generators must produce the same values whether
        // the samples land in a fresh array or in a caller-supplied buffer.
        let mut alloc_rng = default_rng_seeded(42);
        let mut fill_rng = default_rng_seeded(42);

        let expected = alloc_rng.standard_exponential([7, 3]).unwrap();

        let mut target =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[7, 3]), vec![0.0; 21]).unwrap();
        fill_rng.standard_exponential_into(&mut target).unwrap();

        assert_eq!(expected.as_slice().unwrap(), target.as_slice().unwrap());
    }

    #[test]
    fn standard_exponential_into_values_positive() {
        let mut rng = default_rng_seeded(123);
        // Start from negative values so an element left unwritten would fail.
        let mut target =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[1000]), vec![-1.0; 1000]).unwrap();
        rng.standard_exponential_into(&mut target).unwrap();
        assert!(target.as_slice().unwrap().iter().all(|&v| v > 0.0));
    }

    #[test]
    fn exponential_array_shape_matches_scale() {
        let mut rng = default_rng_seeded(42);
        let per_element = vec![1.0, 2.0, 0.5, 4.0, 1.5, 3.0];
        let scale = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), per_element).unwrap();

        let drawn = rng.exponential_array(&scale).unwrap();

        assert_eq!(drawn.shape(), &[2, 3]);
        assert!(drawn.as_slice().unwrap().iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn exponential_array_per_element_mean() {
        let mut rng = default_rng_seeded(7);
        let scales = [0.5_f64, 2.0, 5.0];
        let scale = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), scales.to_vec()).unwrap();
        let n_trials = 20_000;

        // Accumulate one running total per output position across all trials.
        let mut totals = [0.0_f64; 3];
        for _ in 0..n_trials {
            let draw = rng.exponential_array(&scale).unwrap();
            for (total, &x) in totals.iter_mut().zip(draw.as_slice().unwrap()) {
                *total += x;
            }
        }

        // Each position's sample mean must sit within four standard errors of
        // its own scale.
        for (j, (&total, &expected)) in totals.iter().zip(&scales).enumerate() {
            let m = total / n_trials as f64;
            let se = expected / (n_trials as f64).sqrt();
            assert!(
                (m - expected).abs() < 4.0 * se,
                "element {j}: mean {m} too far from {}",
                expected
            );
        }
    }

    #[test]
    fn exponential_array_bad_scale_errors() {
        let mut rng = default_rng_seeded(0);
        let scale = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, -1.0]).unwrap();
        let outcome = rng.exponential_array(&scale);
        assert!(outcome.is_err());
    }

    // ---- f64 scalar-scale variants ---------------------------------------

    #[test]
    fn standard_exponential_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_exponential(10_000).unwrap();
        samples.as_slice().unwrap().iter().for_each(|&v| {
            assert!(
                v > 0.0,
                "standard_exponential produced non-positive value: {v}"
            );
        });
    }

    #[test]
    fn standard_exponential_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let samples = rng.standard_exponential(n).unwrap();
        let xs = samples.as_slice().unwrap();

        let mean = mean_of(xs);
        let var = variance_about(xs, mean);
        let se = (1.0 / n as f64).sqrt();

        assert!((mean - 1.0).abs() < 3.0 * se, "mean {mean} too far from 1");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn exponential_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let scale = 3.0;
        let samples = rng.exponential(scale, n).unwrap();

        let mean = mean_of(samples.as_slice().unwrap());
        let se = (scale * scale / n as f64).sqrt();

        assert!(
            (mean - scale).abs() < 3.0 * se,
            "mean {mean} too far from {scale}"
        );
    }

    #[test]
    fn exponential_bad_scale() {
        let mut rng = default_rng_seeded(42);
        for bad in [0.0, -1.0] {
            assert!(rng.exponential(bad, 100).is_err());
        }
    }

    #[test]
    fn exponential_deterministic() {
        let mut first_rng = default_rng_seeded(42);
        let mut second_rng = default_rng_seeded(42);
        let first = first_rng.exponential(2.0, 100).unwrap();
        let second = second_rng.exponential(2.0, 100).unwrap();
        assert_eq!(first.as_slice().unwrap(), second.as_slice().unwrap());
    }

    #[test]
    fn exponential_mean_and_variance() {
        let mut rng = default_rng_seeded(42);
        let scale = 3.0;
        let samples = rng.exponential(scale, 100_000).unwrap();
        let xs = samples.as_slice().unwrap();

        let mean = mean_of(xs);
        let var = variance_about(xs, mean);

        assert!(
            (mean - scale).abs() < 0.1,
            "exponential mean {mean} too far from {scale}"
        );
        assert!(
            (var - scale * scale).abs() < 1.0,
            "exponential variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn standard_exponential_mean() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_exponential(100_000).unwrap();
        let xs = samples.as_slice().unwrap();

        let mean = mean_of(xs);

        assert!(
            (mean - 1.0).abs() < 0.02,
            "standard_exponential mean {mean} too far from 1.0"
        );
        assert!(xs.iter().all(|&x| x >= 0.0), "negative exponential value");
    }

    // ---- f32 variants ----------------------------------------------------

    #[test]
    fn standard_exponential_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_exponential_f32(10_000).unwrap();
        samples.as_slice().unwrap().iter().for_each(|&v| {
            assert!(
                v > 0.0,
                "standard_exponential_f32 produced non-positive: {v}"
            );
        });
    }

    #[test]
    fn standard_exponential_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_exponential_f32(100_000usize).unwrap();

        let mean = mean_of_f32(samples.as_slice().unwrap());

        assert!(
            (mean - 1.0).abs() < 0.02,
            "f32 exp mean {mean} too far from 1"
        );
    }

    #[test]
    fn exponential_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let scale = 3.0f32;
        let samples = rng.exponential_f32(scale, 100_000usize).unwrap();

        let mean = mean_of_f32(samples.as_slice().unwrap());

        assert!(
            (mean - scale as f64).abs() < 0.1,
            "exponential_f32 mean {mean} too far from {scale}"
        );
    }

    #[test]
    fn exponential_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        for bad in [0.0, -1.0] {
            assert!(rng.exponential_f32(bad, 100).is_err());
        }
    }

    #[test]
    fn exponential_f32_deterministic() {
        let mut first_rng = default_rng_seeded(42);
        let mut second_rng = default_rng_seeded(42);
        let first = first_rng.exponential_f32(2.0, 100).unwrap();
        let second = second_rng.exponential_f32(2.0, 100).unwrap();
        assert_eq!(first.as_slice().unwrap(), second.as_slice().unwrap());
    }
}