Skip to main content

ferray_random/distributions/
exponential.rs

// ferray-random: Exponential distribution sampling — standard_exponential, exponential, and their f32 variants
2
3use ferray_core::{Array, FerrayError, IxDyn};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::{
7    Generator, generate_vec, generate_vec_f32, shape_size, vec_to_array_f32, vec_to_array_f64,
8};
9use crate::shape::IntoShape;
10
11/// Generate a single standard exponential variate (rate=1) via inverse CDF.
12pub(crate) fn standard_exponential_single<B: BitGenerator>(bg: &mut B) -> f64 {
13    loop {
14        let u = bg.next_f64();
15        if u > f64::EPSILON {
16            return -u.ln();
17        }
18    }
19}
20
21/// Generate a single standard exponential f32 variate.
22///
23/// Performed in f64 then cast to avoid f32 log precision loss near 0.
24pub(crate) fn standard_exponential_single_f32<B: BitGenerator>(bg: &mut B) -> f32 {
25    standard_exponential_single(bg) as f32
26}
27
28impl<B: BitGenerator> Generator<B> {
29    /// Generate an array of standard exponential (rate=1, scale=1) variates.
30    ///
31    /// Uses the inverse CDF method: -ln(U) where U ~ Uniform(0,1).
32    ///
33    /// # Errors
34    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
35    pub fn standard_exponential(
36        &mut self,
37        size: impl IntoShape,
38    ) -> Result<Array<f64, IxDyn>, FerrayError> {
39        let shape = size.into_shape()?;
40        let n = shape_size(&shape);
41        let data = generate_vec(self, n, standard_exponential_single);
42        vec_to_array_f64(data, &shape)
43    }
44
45    /// Generate an array of exponential variates with the given scale.
46    ///
47    /// The exponential distribution has PDF: f(x) = (1/scale) * exp(-x/scale).
48    ///
49    /// # Errors
50    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
51    pub fn exponential(
52        &mut self,
53        scale: f64,
54        size: impl IntoShape,
55    ) -> Result<Array<f64, IxDyn>, FerrayError> {
56        if scale <= 0.0 {
57            return Err(FerrayError::invalid_value(format!(
58                "scale must be positive, got {scale}"
59            )));
60        }
61        let shape = size.into_shape()?;
62        let n = shape_size(&shape);
63        let data = generate_vec(self, n, |bg| scale * standard_exponential_single(bg));
64        vec_to_array_f64(data, &shape)
65    }
66
67    /// Generate an array of standard exponential `f32` variates.
68    ///
69    /// The f32 analogue of [`standard_exponential`](Self::standard_exponential).
70    ///
71    /// # Errors
72    /// Returns `FerrayError::InvalidValue` if `shape` is invalid.
73    pub fn standard_exponential_f32(
74        &mut self,
75        size: impl IntoShape,
76    ) -> Result<Array<f32, IxDyn>, FerrayError> {
77        let shape = size.into_shape()?;
78        let n = shape_size(&shape);
79        let data = generate_vec_f32(self, n, standard_exponential_single_f32);
80        vec_to_array_f32(data, &shape)
81    }
82
83    /// Generate an array of `f32` exponential variates with the given scale.
84    /// The f32 analogue of [`exponential`](Self::exponential).
85    ///
86    /// # Errors
87    /// Returns `FerrayError::InvalidValue` if `scale <= 0` or `shape` is invalid.
88    pub fn exponential_f32(
89        &mut self,
90        scale: f32,
91        size: impl IntoShape,
92    ) -> Result<Array<f32, IxDyn>, FerrayError> {
93        if scale <= 0.0 {
94            return Err(FerrayError::invalid_value(format!(
95                "scale must be positive, got {scale}"
96            )));
97        }
98        let shape = size.into_shape()?;
99        let n = shape_size(&shape);
100        let data = generate_vec_f32(self, n, |bg| scale * standard_exponential_single_f32(bg));
101        vec_to_array_f32(data, &shape)
102    }
103}
104
#[cfg(test)]
mod tests {
    use crate::default_rng_seeded;

    /// Arithmetic mean of `xs`, accumulated left-to-right in `f64`.
    fn mean_of(xs: &[f64]) -> f64 {
        xs.iter().sum::<f64>() / xs.len() as f64
    }

    /// Population variance of `xs` about a precomputed `mean`.
    fn variance_about(xs: &[f64], mean: f64) -> f64 {
        xs.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / xs.len() as f64
    }

    /// Losslessly widen `f32` samples to `f64` so statistics accumulate in
    /// double precision.
    fn widen(xs: &[f32]) -> Vec<f64> {
        xs.iter().map(|&x| f64::from(x)).collect()
    }

    #[test]
    fn standard_exponential_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_exponential(10_000).unwrap();
        samples.as_slice().unwrap().iter().for_each(|&v| {
            assert!(
                v > 0.0,
                "standard_exponential produced non-positive value: {v}"
            );
        });
    }

    #[test]
    fn standard_exponential_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let samples = rng.standard_exponential(n).unwrap();
        let xs = samples.as_slice().unwrap();
        let mean = mean_of(xs);
        let var = variance_about(xs, mean);
        // Exp(1): mean=1, var=1
        let se = (1.0 / n as f64).sqrt();
        assert!((mean - 1.0).abs() < 3.0 * se, "mean {mean} too far from 1");
        assert!((var - 1.0).abs() < 0.05, "variance {var} too far from 1");
    }

    #[test]
    fn exponential_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let scale = 3.0;
        let samples = rng.exponential(scale, n).unwrap();
        let mean = mean_of(samples.as_slice().unwrap());
        let se = (scale * scale / n as f64).sqrt();
        assert!(
            (mean - scale).abs() < 3.0 * se,
            "mean {mean} too far from {scale}"
        );
    }

    #[test]
    fn exponential_bad_scale() {
        let mut rng = default_rng_seeded(42);
        for bad in [0.0, -1.0] {
            assert!(rng.exponential(bad, 100).is_err());
        }
    }

    #[test]
    fn exponential_deterministic() {
        let draw = || {
            let mut rng = default_rng_seeded(42);
            rng.exponential(2.0, 100).unwrap()
        };
        let (a, b) = (draw(), draw());
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    #[test]
    fn exponential_mean_and_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let scale = 3.0;
        let samples = rng.exponential(scale, n).unwrap();
        let xs = samples.as_slice().unwrap();
        let mean = mean_of(xs);
        let var = variance_about(xs, mean);
        // Exponential(scale): mean=scale, var=scale^2
        assert!(
            (mean - scale).abs() < 0.1,
            "exponential mean {mean} too far from {scale}"
        );
        assert!(
            (var - scale * scale).abs() < 1.0,
            "exponential variance {var} too far from {}",
            scale * scale
        );
    }

    #[test]
    fn standard_exponential_mean() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let samples = rng.standard_exponential(n).unwrap();
        let xs = samples.as_slice().unwrap();
        let mean = mean_of(xs);
        assert!(
            (mean - 1.0).abs() < 0.02,
            "standard_exponential mean {mean} too far from 1.0"
        );
        // Every sample must be non-negative.
        assert!(xs.iter().all(|&x| x >= 0.0), "negative exponential value");
    }

    // ---------------------------------------------------------------
    // f32 variants (issue #441)
    // ---------------------------------------------------------------

    #[test]
    fn standard_exponential_f32_positive() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_exponential_f32(10_000).unwrap();
        samples.as_slice().unwrap().iter().for_each(|&v| {
            assert!(
                v > 0.0,
                "standard_exponential_f32 produced non-positive: {v}"
            );
        });
    }

    #[test]
    fn standard_exponential_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let samples = rng.standard_exponential_f32(100_000usize).unwrap();
        let mean = mean_of(&widen(samples.as_slice().unwrap()));
        assert!(
            (mean - 1.0).abs() < 0.02,
            "f32 exp mean {mean} too far from 1"
        );
    }

    #[test]
    fn exponential_f32_mean() {
        let mut rng = default_rng_seeded(42);
        let scale = 3.0f32;
        let samples = rng.exponential_f32(scale, 100_000usize).unwrap();
        let mean = mean_of(&widen(samples.as_slice().unwrap()));
        assert!(
            (mean - f64::from(scale)).abs() < 0.1,
            "exponential_f32 mean {mean} too far from {scale}"
        );
    }

    #[test]
    fn exponential_f32_bad_scale() {
        let mut rng = default_rng_seeded(42);
        for bad in [0.0f32, -1.0] {
            assert!(rng.exponential_f32(bad, 100).is_err());
        }
    }

    #[test]
    fn exponential_f32_deterministic() {
        let draw = || {
            let mut rng = default_rng_seeded(42);
            rng.exponential_f32(2.0, 100).unwrap()
        };
        let (a, b) = (draw(), draw());
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }
}