ferray_random/distributions/
exponential.rs1use ferray_core::{Array, FerrayError, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::{
7 Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
8};
9use crate::shape::IntoShape;
10
11pub(crate) fn standard_exponential_single<B: BitGenerator>(bg: &mut B) -> f64 {
13 loop {
14 let u = bg.next_f64();
15 if u > f64::EPSILON {
16 return -u.ln();
17 }
18 }
19}
20
21pub(crate) fn standard_exponential_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
25 standard_exponential_single(bg) as f32
26}
27
28impl<B: BitGenerator> Generator<B> {
29 pub fn standard_exponential(
36 &mut self,
37 size: impl IntoShape,
38 ) -> Result<Array<f64, IxDyn>, FerrayError> {
39 let shape = size.into_shape()?;
40 let n = shape_size(&shape);
41 let data = generate_vec(self, n, standard_exponential_single);
42 vec_to_array_f64(data, &shape)
43 }
44
45 pub fn exponential(
52 &mut self,
53 scale: f64,
54 size: impl IntoShape,
55 ) -> Result<Array<f64, IxDyn>, FerrayError> {
56 if scale <= 0.0 {
57 return Err(FerrayError::invalid_value(format!(
58 "scale must be positive, got {scale}"
59 )));
60 }
61 let shape = size.into_shape()?;
62 let n = shape_size(&shape);
63 let data = generate_vec(self, n, |bg| scale * standard_exponential_single(bg));
64 vec_to_array_f64(data, &shape)
65 }
66
67 pub fn standard_exponential_f32(
74 &mut self,
75 size: impl IntoShape,
76 ) -> Result<Array<f32, IxDyn>, FerrayError> {
77 let shape = size.into_shape()?;
78 let n = shape_size(&shape);
79 let data = generate_vec_f32(self, n, standard_exponential_single_f32);
80 vec_to_array_f32(data, &shape)
81 }
82
83 pub fn exponential_f32(
89 &mut self,
90 scale: f32,
91 size: impl IntoShape,
92 ) -> Result<Array<f32, IxDyn>, FerrayError> {
93 if scale <= 0.0 {
94 return Err(FerrayError::invalid_value(format!(
95 "scale must be positive, got {scale}"
96 )));
97 }
98 let shape = size.into_shape()?;
99 let n = shape_size(&shape);
100 let data = generate_vec_f32(self, n, |bg| scale * standard_exponential_single_f32(bg));
101 vec_to_array_f32(data, &shape)
102 }
103}
104
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of `xs`, summed left-to-right in `f64`.
    fn sample_mean(xs: &[f64]) -> f64 {
        xs.iter().sum::<f64>() / xs.len() as f64
    }

    /// Divide-by-`len` variance of `xs` about an already computed `mean`.
    fn sample_var(xs: &[f64], mean: f64) -> f64 {
        xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / xs.len() as f64
    }

    /// Mean of an `f32` sample, widening every element to `f64` before summing.
    fn sample_mean_f32(xs: &[f32]) -> f64 {
        xs.iter().map(|&x| f64::from(x)).sum::<f64>() / xs.len() as f64
    }

    #[test]
    fn standard_exponential_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.standard_exponential(10_000).unwrap();
        for &v in draws.as_slice().unwrap() {
            assert!(
                v > 0.0,
                "standard_exponential produced non-positive value: {v}"
            );
        }
    }

    #[test]
    fn standard_exponential_mean_variance() {
        let n = 100_000;
        let mut rng = default_rng_seeded(42);
        let draws = rng.standard_exponential(n).unwrap();
        let xs = draws.as_slice().unwrap();
        let mean = sample_mean(xs);
        let var = sample_var(xs, mean);
        // Unit-scale exponential has variance 1, so the SE of the mean is 1/sqrt(n).
        let se = (1.0 / n as f64).sqrt();
        assert!((mean - 1.0).abs() < 3.0 * se, "mean {mean} too far from 1");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn exponential_mean() {
        let n = 100_000;
        let scale = 3.0;
        let mut rng = default_rng_seeded(42);
        let draws = rng.exponential(scale, n).unwrap();
        let mean = sample_mean(draws.as_slice().unwrap());
        // Standard deviation equals `scale`, so the SE of the mean is scale/sqrt(n).
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - scale).abs() < 3.0 * se,
            "mean {mean} too far from {scale}"
        );
    }

    #[test]
    fn exponential_bad_scale() {
        let mut rng = default_rng_seeded(42);
        for bad in [0.0, -1.0] {
            assert!(rng.exponential(bad, 100).is_err());
        }
    }

    #[test]
    fn exponential_deterministic() {
        let (mut first, mut second) = (default_rng_seeded(42), default_rng_seeded(42));
        let a = first.exponential(2.0, 100).unwrap();
        let b = second.exponential(2.0, 100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn exponential_mean_and_variance() {
        let n = 100_000;
        let scale = 3.0;
        let mut rng = default_rng_seeded(42);
        let draws = rng.exponential(scale, n).unwrap();
        let xs = draws.as_slice().unwrap();
        let mean = sample_mean(xs);
        let var = sample_var(xs, mean);
        assert!(
            (mean - scale).abs() < 0.1,
            "exponential mean {mean} too far from {scale}"
        );
        assert!(
            (var - scale * scale).abs() < 1.0,
            "exponential variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn standard_exponential_mean() {
        let n = 100_000;
        let mut rng = default_rng_seeded(42);
        let draws = rng.standard_exponential(n).unwrap();
        let xs = draws.as_slice().unwrap();
        let mean = sample_mean(xs);
        assert!(
            (mean - 1.0).abs() < 0.02,
            "standard_exponential mean {mean} too far from 1.0"
        );
        assert!(xs.iter().all(|&x| x >= 0.0), "negative exponential value");
    }

    #[test]
    fn standard_exponential_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let draws = rng.standard_exponential_f32(10_000).unwrap();
        for &v in draws.as_slice().unwrap() {
            assert!(
                v > 0.0,
                "standard_exponential_f32 produced non-positive: {v}"
            );
        }
    }

    #[test]
    fn standard_exponential_f32_mean() {
        let n = 100_000usize;
        let mut rng = default_rng_seeded(42);
        let draws = rng.standard_exponential_f32(n).unwrap();
        let mean = sample_mean_f32(draws.as_slice().unwrap());
        assert!(
            (mean - 1.0).abs() < 0.02,
            "f32 exp mean {mean} too far from 1"
        );
    }

    #[test]
    fn exponential_f32_mean() {
        let n = 100_000usize;
        let scale = 3.0f32;
        let mut rng = default_rng_seeded(42);
        let draws = rng.exponential_f32(scale, n).unwrap();
        let mean = sample_mean_f32(draws.as_slice().unwrap());
        assert!(
            (mean - f64::from(scale)).abs() < 0.1,
            "exponential_f32 mean {mean} too far from {scale}"
        );
    }

    #[test]
    fn exponential_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        for bad in [0.0f32, -1.0] {
            assert!(rng.exponential_f32(bad, 100).is_err());
        }
    }

    #[test]
    fn exponential_f32_deterministic() {
        let (mut first, mut second) = (default_rng_seeded(42), default_rng_seeded(42));
        let a = first.exponential_f32(2.0, 100).unwrap();
        let b = second.exponential_f32(2.0, 100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }
}