ferray_random/distributions/
gamma.rs1use ferray_core::{Array, FerrayError, IxDyn};
7
8use crate::bitgen::BitGenerator;
9use crate::distributions::normal::standard_normal_single;
10use crate::generator::{Generator, generate_vec, shape_size, vec_to_array_f64};
11use crate::shape::IntoShape;
12
13pub(crate) fn standard_gamma_single<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
18 if alpha < 1.0 {
19 if alpha <= 0.0 {
21 return 0.0;
22 }
23 loop {
24 let u = bg.next_f64();
25 if u > f64::EPSILON {
26 let x = standard_gamma_ge1(bg, alpha + 1.0);
27 return x * u.powf(1.0 / alpha);
28 }
29 }
30 } else {
31 standard_gamma_ge1(bg, alpha)
32 }
33}
34
35fn standard_gamma_ge1<B: BitGenerator>(bg: &mut B, alpha: f64) -> f64 {
37 let d = alpha - 1.0 / 3.0;
38 let c = 1.0 / (9.0 * d).sqrt();
39
40 loop {
41 let x = standard_normal_single(bg);
42 let v_base = 1.0 + c * x;
43 if v_base <= 0.0 {
44 continue;
45 }
46 let v = v_base * v_base * v_base;
47 let u = bg.next_f64();
48 if u < 1.0 - 0.0331 * (x * x) * (x * x) {
50 return d * v;
51 }
52 if u.ln() < 0.5 * x * x + d * (1.0 - v + v.ln()) {
53 return d * v;
54 }
55 }
56}
57
58impl<B: BitGenerator> Generator<B> {
59 pub fn standard_gamma(
64 &mut self,
65 alpha: f64,
66 size: impl IntoShape,
67 ) -> Result<Array<f64, IxDyn>, FerrayError> {
68 if alpha <= 0.0 {
69 return Err(FerrayError::invalid_value(format!(
70 "alpha must be positive, got {alpha}"
71 )));
72 }
73 let shape_vec = size.into_shape()?;
74 let n = shape_size(&shape_vec);
75 let data = generate_vec(self, n, |bg| standard_gamma_single(bg, alpha));
76 vec_to_array_f64(data, &shape_vec)
77 }
78
79 pub fn gamma(
87 &mut self,
88 alpha: f64,
89 scale: f64,
90 size: impl IntoShape,
91 ) -> Result<Array<f64, IxDyn>, FerrayError> {
92 if alpha <= 0.0 {
93 return Err(FerrayError::invalid_value(format!(
94 "alpha must be positive, got {alpha}"
95 )));
96 }
97 if scale <= 0.0 {
98 return Err(FerrayError::invalid_value(format!(
99 "scale must be positive, got {scale}"
100 )));
101 }
102 let shape_vec = size.into_shape()?;
103 let n = shape_size(&shape_vec);
104 let data = generate_vec(self, n, |bg| scale * standard_gamma_single(bg, alpha));
105 vec_to_array_f64(data, &shape_vec)
106 }
107
108 pub fn beta(
115 &mut self,
116 a: f64,
117 b: f64,
118 size: impl IntoShape,
119 ) -> Result<Array<f64, IxDyn>, FerrayError> {
120 if a <= 0.0 {
121 return Err(FerrayError::invalid_value(format!(
122 "a must be positive, got {a}"
123 )));
124 }
125 if b <= 0.0 {
126 return Err(FerrayError::invalid_value(format!(
127 "b must be positive, got {b}"
128 )));
129 }
130 let shape_vec = size.into_shape()?;
131 let n = shape_size(&shape_vec);
132 let data = generate_vec(self, n, |bg| {
133 let x = standard_gamma_single(bg, a);
134 let y = standard_gamma_single(bg, b);
135 if x + y == 0.0 {
136 0.5 } else {
138 x / (x + y)
139 }
140 });
141 vec_to_array_f64(data, &shape_vec)
142 }
143
144 pub fn chisquare(
151 &mut self,
152 df: f64,
153 size: impl IntoShape,
154 ) -> Result<Array<f64, IxDyn>, FerrayError> {
155 if df <= 0.0 {
156 return Err(FerrayError::invalid_value(format!(
157 "df must be positive, got {df}"
158 )));
159 }
160 let shape_vec = size.into_shape()?;
161 let n = shape_size(&shape_vec);
162 let data = generate_vec(self, n, |bg| 2.0 * standard_gamma_single(bg, df / 2.0));
163 vec_to_array_f64(data, &shape_vec)
164 }
165
166 pub fn f(
173 &mut self,
174 dfnum: f64,
175 dfden: f64,
176 size: impl IntoShape,
177 ) -> Result<Array<f64, IxDyn>, FerrayError> {
178 if dfnum <= 0.0 {
179 return Err(FerrayError::invalid_value(format!(
180 "dfnum must be positive, got {dfnum}"
181 )));
182 }
183 if dfden <= 0.0 {
184 return Err(FerrayError::invalid_value(format!(
185 "dfden must be positive, got {dfden}"
186 )));
187 }
188 let shape_vec = size.into_shape()?;
189 let n = shape_size(&shape_vec);
190 let data = generate_vec(self, n, |bg| {
191 let x1 = standard_gamma_single(bg, dfnum / 2.0);
192 let x2 = standard_gamma_single(bg, dfden / 2.0);
193 if x2 == 0.0 {
194 f64::INFINITY
195 } else {
196 (x1 / dfnum) / (x2 / dfden)
197 }
198 });
199 vec_to_array_f64(data, &shape_vec)
200 }
201
202 pub fn student_t(
209 &mut self,
210 df: f64,
211 size: impl IntoShape,
212 ) -> Result<Array<f64, IxDyn>, FerrayError> {
213 if df <= 0.0 {
214 return Err(FerrayError::invalid_value(format!(
215 "df must be positive, got {df}"
216 )));
217 }
218 let shape_vec = size.into_shape()?;
219 let n = shape_size(&shape_vec);
220 let data = generate_vec(self, n, |bg| {
221 let z = standard_normal_single(bg);
222 let chi2 = 2.0 * standard_gamma_single(bg, df / 2.0);
223 z / (chi2 / df).sqrt()
224 });
225 vec_to_array_f64(data, &shape_vec)
226 }
227}
228
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of a sample.
    fn sample_mean(xs: &[f64]) -> f64 {
        xs.iter().sum::<f64>() / xs.len() as f64
    }

    /// Population (divide-by-n) variance of a sample around `mean`.
    fn sample_variance(xs: &[f64], mean: f64) -> f64 {
        xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / xs.len() as f64
    }

    #[test]
    fn gamma_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.gamma(2.0, 1.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn gamma_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let shape = 3.0;
        let scale = 2.0;
        let arr = rng.gamma(shape, scale, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        let var = sample_variance(slice, mean);
        let expected_mean = shape * scale;
        let expected_var = shape * scale * scale;
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "gamma mean {mean} too far from {expected_mean}"
        );
        assert!(
            (var - expected_var).abs() / expected_var < 0.05,
            "gamma variance {var} too far from {expected_var}"
        );
    }

    #[test]
    fn gamma_small_shape() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.gamma(0.5, 1.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn beta_in_range() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.beta(2.0, 5.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0 && v < 1.0, "beta value {v} out of (0,1)");
        }
    }

    #[test]
    fn beta_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let a = 2.0;
        let b = 5.0;
        let arr = rng.beta(a, b, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        let expected_mean = a / (a + b);
        let expected_var = (a * b) / ((a + b).powi(2) * (a + b + 1.0));
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se,
            "beta mean {mean} too far from {expected_mean}"
        );
    }

    #[test]
    fn chisquare_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.chisquare(5.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn chisquare_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let df = 10.0;
        let arr = rng.chisquare(df, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        let expected_var = 2.0 * df;
        let se = (expected_var / n as f64).sqrt();
        assert!(
            (mean - df).abs() < 3.0 * se,
            "chisquare mean {mean} too far from {df}"
        );
    }

    #[test]
    fn f_positive() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.f(5.0, 10.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!(v > 0.0);
        }
    }

    #[test]
    fn student_t_symmetric() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let df = 10.0;
        let arr = rng.student_t(df, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        assert!(mean.abs() < 0.05, "student_t mean {mean} too far from 0");
    }

    #[test]
    fn standard_gamma_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let shape = 5.0;
        let arr = rng.standard_gamma(shape, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean = sample_mean(slice);
        let se = (shape / n as f64).sqrt();
        assert!(
            (mean - shape).abs() < 3.0 * se,
            "standard_gamma mean {mean} too far from {shape}"
        );
    }

    #[test]
    fn gamma_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.gamma(0.0, 1.0, 100).is_err());
        assert!(rng.gamma(1.0, 0.0, 100).is_err());
        assert!(rng.gamma(-1.0, 1.0, 100).is_err());
    }

    #[test]
    fn standard_gamma_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.standard_gamma(0.0, 100).is_err());
        assert!(rng.standard_gamma(-1.0, 100).is_err());
    }

    #[test]
    fn beta_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.beta(0.0, 1.0, 100).is_err());
        assert!(rng.beta(1.0, 0.0, 100).is_err());
        assert!(rng.beta(-1.0, 1.0, 100).is_err());
        assert!(rng.beta(1.0, -1.0, 100).is_err());
    }

    #[test]
    fn chisquare_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.chisquare(0.0, 100).is_err());
        assert!(rng.chisquare(-1.0, 100).is_err());
    }

    #[test]
    fn f_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.f(0.0, 1.0, 100).is_err());
        assert!(rng.f(1.0, 0.0, 100).is_err());
        assert!(rng.f(-1.0, 1.0, 100).is_err());
        assert!(rng.f(1.0, -1.0, 100).is_err());
    }

    #[test]
    fn student_t_bad_params() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.student_t(0.0, 100).is_err());
        assert!(rng.student_t(-1.0, 100).is_err());
    }
}