Skip to main content

ferray_random/
generator.rs

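// ferray-random: Generator struct — the main user-facing RNG API
//
// Wraps a BitGenerator and provides distribution sampling methods.
// Sampling takes &mut self (the generator is stateful), so one Generator
// serves one caller at a time; spawn independent generators for parallel work.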
5
6use ferray_core::{Array, FerrayError, IxDyn};
7
8use crate::bitgen::{BitGenerator, Xoshiro256StarStar};
9
10/// The main random number generator, wrapping a pluggable [`BitGenerator`].
11///
12/// `Generator` takes `&mut self` for all sampling methods — it is stateful
13/// and NOT `Sync`. Thread-safety is handled by spawning independent generators
14/// via [`spawn`](Generator::spawn) or using the parallel generation API.
15///
16/// # Example
17/// ```
18/// use ferray_random::{default_rng_seeded, Generator};
19///
20/// let mut rng = default_rng_seeded(42);
21/// let values = rng.random(10).unwrap();
22/// assert_eq!(values.shape(), &[10]);
23/// ```
24pub struct Generator<B: BitGenerator = Xoshiro256StarStar> {
25    /// The underlying bit generator.
26    pub(crate) bg: B,
27    /// The seed used to create this generator (for spawn).
28    pub(crate) seed: u64,
29}
30
31impl<B: BitGenerator> Generator<B> {
32    /// Create a new `Generator` wrapping the given `BitGenerator`.
33    pub const fn new(bg: B) -> Self {
34        Self { bg, seed: 0 }
35    }
36
37    /// Create a new `Generator` with a known seed (stored for spawn).
38    pub(crate) const fn new_with_seed(bg: B, seed: u64) -> Self {
39        Self { bg, seed }
40    }
41
42    /// Access the underlying `BitGenerator` mutably.
43    #[inline]
44    pub const fn bit_generator(&mut self) -> &mut B {
45        &mut self.bg
46    }
47
48    /// Generate the next random `u64`.
49    #[inline]
50    pub fn next_u64(&mut self) -> u64 {
51        self.bg.next_u64()
52    }
53
54    /// Generate the next random `f64` in [0, 1).
55    #[inline]
56    pub fn next_f64(&mut self) -> f64 {
57        self.bg.next_f64()
58    }
59
60    /// Generate the next random `f32` in [0, 1).
61    #[inline]
62    pub fn next_f32(&mut self) -> f32 {
63        self.bg.next_f32()
64    }
65
66    /// Generate a `u64` in [0, bound).
67    #[inline]
68    pub fn next_u64_bounded(&mut self, bound: u64) -> u64 {
69        self.bg.next_u64_bounded(bound)
70    }
71
72    /// Serialize the underlying [`BitGenerator`]'s full internal state
73    /// to a byte vector — pair with [`set_state_bytes`](Self::set_state_bytes)
74    /// to restore. Used to checkpoint reproducible experiments (#453).
75    ///
76    /// The format is the LE-byte serialization of the bit generator's
77    /// state words; it is stable per-generator-type but **not**
78    /// portable across different `BitGenerator` implementations
79    /// (Pcg64 state cannot be loaded into Xoshiro256**).
80    ///
81    /// # Errors
82    /// `FerrayError::InvalidValue` if the underlying generator does
83    /// not implement state serialization.
84    pub fn state_bytes(&self) -> Result<Vec<u8>, FerrayError> {
85        self.bg.state_bytes()
86    }
87
88    /// Restore the underlying [`BitGenerator`]'s state from previously
89    /// captured bytes.
90    ///
91    /// # Errors
92    /// `FerrayError::InvalidValue` if the byte length is wrong for
93    /// this generator type or the embedded state is invalid (e.g.
94    /// all-zero state for Xoshiro256**, even `inc` for Pcg64).
95    pub fn set_state_bytes(&mut self, bytes: &[u8]) -> Result<(), FerrayError> {
96        self.bg.set_state_bytes(bytes)
97    }
98
99    /// Generate `n` random bytes as a `Vec<u8>`.
100    ///
101    /// Equivalent to `numpy.random.Generator.bytes(n)`. Each byte is
102    /// drawn from the underlying bit generator's `u64` stream and
103    /// little-endian-decomposed; calling `bytes(n)` advances the bit
104    /// generator by `ceil(n / 8)` `u64` draws (#446).
105    pub fn bytes(&mut self, n: usize) -> Vec<u8> {
106        let mut out = Vec::with_capacity(n);
107        let full_words = n / 8;
108        for _ in 0..full_words {
109            out.extend_from_slice(&self.bg.next_u64().to_le_bytes());
110        }
111        let remainder = n % 8;
112        if remainder > 0 {
113            let bytes = self.bg.next_u64().to_le_bytes();
114            out.extend_from_slice(&bytes[..remainder]);
115        }
116        out
117    }
118}
119
120/// Create a `Generator` with the default `BitGenerator` (Xoshiro256**)
121/// seeded from a non-deterministic source (using the system time as a
122/// simple entropy source).
123///
124/// # Example
125/// ```
126/// let mut rng = ferray_random::default_rng();
127/// let val = rng.next_f64();
128/// assert!((0.0..1.0).contains(&val));
129/// ```
130#[must_use]
131pub fn default_rng() -> Generator<Xoshiro256StarStar> {
132    // Use OS entropy via getrandom for proper seeding.
133    // Falls back to time-based entropy if getrandom fails.
134    let seed = {
135        let mut buf = [0u8; 8];
136        if getrandom::fill(&mut buf).is_ok() {
137            u64::from_ne_bytes(buf)
138        } else {
139            // Fallback: time + stack address
140            use std::time::SystemTime;
141            let dur = SystemTime::now()
142                .duration_since(SystemTime::UNIX_EPOCH)
143                .unwrap_or_default();
144            let nanos = dur.as_nanos();
145            let mut s = nanos as u64;
146            s ^= (nanos >> 64) as u64;
147            let stack_var: u8 = 0;
148            s ^= &raw const stack_var as u64;
149            s
150        }
151    };
152    default_rng_seeded(seed)
153}
154
155/// Create a `Generator` with the default `BitGenerator` (Xoshiro256**)
156/// from a specific seed, ensuring deterministic output.
157///
158/// # Example
159/// ```
160/// let mut rng1 = ferray_random::default_rng_seeded(42);
161/// let mut rng2 = ferray_random::default_rng_seeded(42);
162/// assert_eq!(rng1.next_u64(), rng2.next_u64());
163/// ```
164#[must_use]
165pub fn default_rng_seeded(seed: u64) -> Generator<Xoshiro256StarStar> {
166    let bg = Xoshiro256StarStar::seed_from_u64(seed);
167    Generator::new_with_seed(bg, seed)
168}
169
170/// Spawn `n` independent child generators from this generator.
171///
172/// Uses `jump()` if available (Xoshiro256**), otherwise uses
173/// `stream()` (Philox), otherwise falls back to seeding from
174/// the parent generator's output.
175///
176/// # Errors
177/// Returns `FerrayError::InvalidValue` if `n` is zero.
178pub fn spawn_generators<B: BitGenerator + Clone>(
179    parent: &mut Generator<B>,
180    n: usize,
181) -> Result<Vec<Generator<B>>, FerrayError> {
182    if n == 0 {
183        return Err(FerrayError::invalid_value("spawn count must be > 0"));
184    }
185
186    let mut children = Vec::with_capacity(n);
187
188    // Try jump-based spawning first
189    let mut test_bg = parent.bg.clone();
190    if test_bg.jump().is_some() {
191        // Jump-based: each child starts at a 2^128 offset
192        let mut current = parent.bg.clone();
193        for _ in 0..n {
194            children.push(Generator::new(current.clone()));
195            current.jump();
196        }
197        // Advance parent past all children
198        parent.bg = current;
199        return Ok(children);
200    }
201
202    // Try stream-based spawning
203    if let Some(first) = B::stream(parent.seed, 0) {
204        drop(first);
205        for i in 0..n {
206            if let Some(bg) = B::stream(parent.seed, i as u64) {
207                children.push(Generator::new(bg));
208            }
209        }
210        if children.len() == n {
211            return Ok(children);
212        }
213        children.clear();
214    }
215
216    // Fallback: seed from parent output (less ideal but works for PCG64)
217    for _ in 0..n {
218        let child_seed = parent.bg.next_u64();
219        let bg = B::seed_from_u64(child_seed);
220        children.push(Generator::new(bg));
221    }
222    Ok(children)
223}
224
225// Helper: generate a Vec<f64> of given total size using a closure.
226pub(crate) fn generate_vec<B: BitGenerator>(
227    rng: &mut Generator<B>,
228    size: usize,
229    mut f: impl FnMut(&mut B) -> f64,
230) -> Vec<f64> {
231    let mut data = Vec::with_capacity(size);
232    for _ in 0..size {
233        data.push(f(&mut rng.bg));
234    }
235    data
236}
237
238// Helper: generate a Vec<f32> of given total size using a closure.
239pub(crate) fn generate_vec_f32<B: BitGenerator>(
240    rng: &mut Generator<B>,
241    size: usize,
242    mut f: impl FnMut(&mut B) -> f32,
243) -> Vec<f32> {
244    let mut data = Vec::with_capacity(size);
245    for _ in 0..size {
246        data.push(f(&mut rng.bg));
247    }
248    data
249}
250
251// Helper: generate a Vec<i64> of given total size using a closure.
252pub(crate) fn generate_vec_i64<B: BitGenerator>(
253    rng: &mut Generator<B>,
254    size: usize,
255    mut f: impl FnMut(&mut B) -> i64,
256) -> Vec<i64> {
257    let mut data = Vec::with_capacity(size);
258    for _ in 0..size {
259        data.push(f(&mut rng.bg));
260    }
261    data
262}
263
/// Number of elements described by `shape`.
///
/// A shape with no dimensions yields `0` (not the `1` an empty product would
/// give); any zero-length axis also yields `0`.
#[inline]
pub(crate) fn shape_size(shape: &[usize]) -> usize {
    match shape {
        [] => 0,
        dims => dims.iter().product(),
    }
}
273
274/// Wrap a `Vec<f64>` into an `Array<f64, IxDyn>` with the given shape.
275pub(crate) fn vec_to_array_f64(
276    data: Vec<f64>,
277    shape: &[usize],
278) -> Result<Array<f64, IxDyn>, FerrayError> {
279    Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data)
280}
281
282/// Wrap a `Vec<f32>` into an `Array<f32, IxDyn>` with the given shape.
283pub(crate) fn vec_to_array_f32(
284    data: Vec<f32>,
285    shape: &[usize],
286) -> Result<Array<f32, IxDyn>, FerrayError> {
287    Array::<f32, IxDyn>::from_vec(IxDyn::new(shape), data)
288}
289
290/// Wrap a `Vec<i64>` into an `Array<i64, IxDyn>` with the given shape.
291pub(crate) fn vec_to_array_i64(
292    data: Vec<i64>,
293    shape: &[usize],
294) -> Result<Array<i64, IxDyn>, FerrayError> {
295    Array::<i64, IxDyn>::from_vec(IxDyn::new(shape), data)
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn state_bytes_roundtrip_via_generator() {
        // #453: Generator::state_bytes / set_state_bytes round-trip
        // — capture state, draw a chunk, restore, draw the same chunk
        // again and verify byte-equality.
        let mut a = default_rng_seeded(2026);
        // Burn a few values so we are not at the seed boundary.
        for _ in 0..11 {
            a.next_u64();
        }
        let snap = a.state_bytes().unwrap();
        let from_a: Vec<u64> = (0..32).map(|_| a.next_u64()).collect();

        let mut b = default_rng_seeded(0); // wrong seed on purpose
        b.set_state_bytes(&snap).unwrap();
        let from_b: Vec<u64> = (0..32).map(|_| b.next_u64()).collect();
        assert_eq!(from_a, from_b);
    }

    #[test]
    fn set_state_bytes_rejects_wrong_size() {
        let mut a = default_rng_seeded(0);
        assert!(a.set_state_bytes(&[0u8; 4]).is_err());
    }

    #[test]
    fn default_rng_seeded_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn default_rng_works() {
        let mut rng = default_rng();
        let v = rng.next_f64();
        assert!((0.0..1.0).contains(&v));
    }

    #[test]
    fn spawn_xoshiro() {
        let mut parent = default_rng_seeded(42);
        let children = spawn_generators(&mut parent, 4).unwrap();
        assert_eq!(children.len(), 4);
    }

    #[test]
    fn spawn_zero_is_error() {
        let mut parent = default_rng_seeded(42);
        assert!(spawn_generators(&mut parent, 0).is_err());
    }

    #[test]
    fn spawn_xoshiro_children_are_distinct() {
        // Children must start on different streams, not copies of one state.
        let mut parent = default_rng_seeded(42);
        let mut children = spawn_generators(&mut parent, 4).unwrap();
        let firsts: Vec<u64> = children.iter_mut().map(|c| c.next_u64()).collect();
        for (i, a) in firsts.iter().enumerate() {
            for b in &firsts[i + 1..] {
                assert_ne!(a, b, "spawned children must not share a stream");
            }
        }
    }

    #[test]
    fn spawn_twice_yields_fresh_children() {
        // Spawning advances the parent, so a second spawn must not replay
        // the first batch.
        let mut parent = default_rng_seeded(42);
        let mut first = spawn_generators(&mut parent, 2).unwrap();
        let mut second = spawn_generators(&mut parent, 2).unwrap();
        assert_ne!(first[0].next_u64(), second[0].next_u64());
    }

    #[test]
    fn spawn_is_reproducible_for_same_seed() {
        let mut p1 = default_rng_seeded(9);
        let mut p2 = default_rng_seeded(9);
        let mut c1 = spawn_generators(&mut p1, 3).unwrap();
        let mut c2 = spawn_generators(&mut p2, 3).unwrap();
        for (a, b) in c1.iter_mut().zip(c2.iter_mut()) {
            assert_eq!(a.next_u64(), b.next_u64());
        }
    }

    // ----- bytes() coverage (#446) -----

    #[test]
    fn bytes_length_zero() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.bytes(0).is_empty());
    }

    #[test]
    fn bytes_length_full_word() {
        let mut rng = default_rng_seeded(42);
        let b = rng.bytes(8);
        assert_eq!(b.len(), 8);
    }

    #[test]
    fn bytes_length_partial_word() {
        let mut rng = default_rng_seeded(42);
        let b = rng.bytes(13);
        assert_eq!(b.len(), 13);
    }

    #[test]
    fn bytes_deterministic_for_same_seed() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        assert_eq!(rng1.bytes(64), rng2.bytes(64));
    }

    #[test]
    fn bytes_are_little_endian_words_and_advance_ceil_n_over_8() {
        // bytes(n) is documented as the LE decomposition of ceil(n / 8)
        // consecutive u64 draws; check both the layout and the advance.
        let mut via_bytes = default_rng_seeded(7);
        let mut via_words = default_rng_seeded(7);
        let got = via_bytes.bytes(13);
        let mut expected = via_words.next_u64().to_le_bytes().to_vec();
        expected.extend_from_slice(&via_words.next_u64().to_le_bytes()[..5]);
        assert_eq!(got, expected);
        assert_eq!(via_bytes.next_u64(), via_words.next_u64());
    }
}