Skip to main content

ferray_random/distributions/
uniform.rs

1// ferray-random: Uniform distribution sampling — random, uniform, integers
2
3use ferray_core::{Array, FerrayError, Ix1};
4
5use crate::bitgen::BitGenerator;
6use crate::generator::{
7    Generator, generate_vec, generate_vec_i64, vec_to_array1, vec_to_array1_i64,
8};
9
10impl<B: BitGenerator> Generator<B> {
11    /// Generate an array of uniformly distributed `f64` values in [0, 1).
12    ///
13    /// Equivalent to NumPy's `Generator.random(size)`.
14    ///
15    /// # Arguments
16    /// * `size` - Number of values to generate.
17    ///
18    /// # Errors
19    /// Returns `FerrayError::InvalidValue` if `size` is zero.
20    ///
21    /// # Example
22    /// ```
23    /// let mut rng = ferray_random::default_rng_seeded(42);
24    /// let arr = rng.random(10).unwrap();
25    /// assert_eq!(arr.shape(), &[10]);
26    /// ```
27    pub fn random(&mut self, size: usize) -> Result<Array<f64, Ix1>, FerrayError> {
28        if size == 0 {
29            return Err(FerrayError::invalid_value("size must be > 0"));
30        }
31        let data = generate_vec(self, size, |bg| bg.next_f64());
32        vec_to_array1(data)
33    }
34
35    /// Generate an array of uniformly distributed `f64` values in [low, high).
36    ///
37    /// Equivalent to NumPy's `Generator.uniform(low, high, size)`.
38    ///
39    /// # Arguments
40    /// * `low` - Lower bound (inclusive).
41    /// * `high` - Upper bound (exclusive).
42    /// * `size` - Number of values to generate.
43    ///
44    /// # Errors
45    /// Returns `FerrayError::InvalidValue` if `low >= high` or `size` is zero.
46    pub fn uniform(
47        &mut self,
48        low: f64,
49        high: f64,
50        size: usize,
51    ) -> Result<Array<f64, Ix1>, FerrayError> {
52        if size == 0 {
53            return Err(FerrayError::invalid_value("size must be > 0"));
54        }
55        if low >= high {
56            return Err(FerrayError::invalid_value(format!(
57                "low ({low}) must be less than high ({high})"
58            )));
59        }
60        let range = high - low;
61        let data = generate_vec(self, size, |bg| low + bg.next_f64() * range);
62        vec_to_array1(data)
63    }
64
65    /// Generate an array of uniformly distributed random integers in [low, high).
66    ///
67    /// Equivalent to NumPy's `Generator.integers(low, high, size)`.
68    ///
69    /// # Arguments
70    /// * `low` - Lower bound (inclusive).
71    /// * `high` - Upper bound (exclusive).
72    /// * `size` - Number of values to generate.
73    ///
74    /// # Errors
75    /// Returns `FerrayError::InvalidValue` if `low >= high` or `size` is zero.
76    pub fn integers(
77        &mut self,
78        low: i64,
79        high: i64,
80        size: usize,
81    ) -> Result<Array<i64, Ix1>, FerrayError> {
82        if size == 0 {
83            return Err(FerrayError::invalid_value("size must be > 0"));
84        }
85        if low >= high {
86            return Err(FerrayError::invalid_value(format!(
87                "low ({low}) must be less than high ({high})"
88            )));
89        }
90        let range = (high - low) as u64;
91        let data = generate_vec_i64(self, size, |bg| low + bg.next_u64_bounded(range) as i64);
92        vec_to_array1_i64(data)
93    }
94}
95
#[cfg(test)]
mod tests {
    //! Unit tests for the uniform samplers: range membership, seed
    //! determinism, argument validation, and basic moment checks.

    use crate::default_rng_seeded;

    /// Every value from `random` lies in the half-open interval [0, 1).
    #[test]
    fn random_in_range() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.random(10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!((0.0..1.0).contains(&v));
        }
    }

    /// Two generators built from the same seed produce identical output.
    #[test]
    fn random_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        let a = rng1.random(100).unwrap();
        let b = rng2.random(100).unwrap();
        assert_eq!(a.as_slice().unwrap(), b.as_slice().unwrap());
    }

    /// Every value from `uniform(5, 10)` lies in [5, 10) for this seed.
    #[test]
    fn uniform_in_range() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.uniform(5.0, 10.0, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!((5.0..10.0).contains(&v), "value {v} out of range");
        }
    }

    /// Reversed and empty float intervals are rejected.
    #[test]
    fn uniform_bad_range() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.uniform(10.0, 5.0, 100).is_err());
        assert!(rng.uniform(5.0, 5.0, 100).is_err());
    }

    /// Every value from `integers(0, 10)` lies in 0..10.
    #[test]
    fn integers_in_range() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.integers(0, 10, 10_000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!((0..10).contains(&v), "value {v} out of range");
        }
    }

    /// A range that straddles zero is handled correctly.
    #[test]
    fn integers_negative_range() {
        let mut rng = default_rng_seeded(42);
        let arr = rng.integers(-5, 5, 1000).unwrap();
        let slice = arr.as_slice().unwrap();
        for &v in slice {
            assert!((-5..5).contains(&v), "value {v} out of range");
        }
    }

    /// Reversed and empty integer intervals are rejected.
    #[test]
    fn integers_bad_range() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.integers(10, 5, 100).is_err());
        assert!(rng.integers(5, 5, 100).is_err());
    }

    /// All three samplers reject a request for zero values.
    #[test]
    fn zero_size_rejected() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.random(0).is_err());
        assert!(rng.uniform(0.0, 1.0, 0).is_err());
        assert!(rng.integers(0, 10, 0).is_err());
    }

    /// Sample mean and variance of `uniform(2, 8)` match the theoretical values.
    #[test]
    fn uniform_mean_variance() {
        let mut rng = default_rng_seeded(42);
        let n = 100_000;
        let arr = rng.uniform(2.0, 8.0, n).unwrap();
        let slice = arr.as_slice().unwrap();
        let mean: f64 = slice.iter().sum::<f64>() / n as f64;
        let var: f64 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / n as f64;
        // Uniform(a,b): mean = (a+b)/2 = 5.0, var = (b-a)^2/12 = 3.0
        let expected_mean = 5.0;
        let expected_var = 3.0;
        // Standard error of the sample mean; the check allows three of them.
        let se_mean = (expected_var / n as f64).sqrt();
        assert!(
            (mean - expected_mean).abs() < 3.0 * se_mean,
            "mean {mean} too far from {expected_mean}"
        );
        // Variance check: a fixed absolute tolerance rather than a
        // statistically derived bound.
        assert!(
            (var - expected_var).abs() < 0.1,
            "variance {var} too far from {expected_var}"
        );
    }
}