ferray_random/
generator.rs1use ferray_core::{Array, FerrayError, IxDyn};
7
8use crate::bitgen::{BitGenerator, Xoshiro256StarStar};
9
/// A random-number generator that wraps a [`BitGenerator`] source of raw bits.
///
/// The type parameter `B` selects the bit generator; it defaults to
/// [`Xoshiro256StarStar`], which is what [`default_rng`] and
/// [`default_rng_seeded`] construct.
pub struct Generator<B: BitGenerator = Xoshiro256StarStar> {
    /// The underlying bit generator; every draw made by this `Generator`
    /// is delegated to it.
    pub(crate) bg: B,
    /// Seed recorded at construction time.
    ///
    /// `default_rng_seeded` stores the seed it was given here, while
    /// `Generator::new` stores 0. `spawn_generators` reads it as the key
    /// for `B::stream`.
    pub(crate) seed: u64,
}
30
31impl<B: BitGenerator> Generator<B> {
32 pub const fn new(bg: B) -> Self {
34 Self { bg, seed: 0 }
35 }
36
37 pub(crate) const fn new_with_seed(bg: B, seed: u64) -> Self {
39 Self { bg, seed }
40 }
41
42 #[inline]
44 pub const fn bit_generator(&mut self) -> &mut B {
45 &mut self.bg
46 }
47
48 #[inline]
50 pub fn next_u64(&mut self) -> u64 {
51 self.bg.next_u64()
52 }
53
54 #[inline]
56 pub fn next_f64(&mut self) -> f64 {
57 self.bg.next_f64()
58 }
59
60 #[inline]
62 pub fn next_f32(&mut self) -> f32 {
63 self.bg.next_f32()
64 }
65
66 #[inline]
68 pub fn next_u64_bounded(&mut self, bound: u64) -> u64 {
69 self.bg.next_u64_bounded(bound)
70 }
71
72 pub fn state_bytes(&self) -> Result<Vec<u8>, FerrayError> {
85 self.bg.state_bytes()
86 }
87
88 pub fn set_state_bytes(&mut self, bytes: &[u8]) -> Result<(), FerrayError> {
96 self.bg.set_state_bytes(bytes)
97 }
98
99 pub fn bytes(&mut self, n: usize) -> Vec<u8> {
106 let mut out = Vec::with_capacity(n);
107 let full_words = n / 8;
108 for _ in 0..full_words {
109 out.extend_from_slice(&self.bg.next_u64().to_le_bytes());
110 }
111 let remainder = n % 8;
112 if remainder > 0 {
113 let bytes = self.bg.next_u64().to_le_bytes();
114 out.extend_from_slice(&bytes[..remainder]);
115 }
116 out
117 }
118}
119
120#[must_use]
131pub fn default_rng() -> Generator<Xoshiro256StarStar> {
132 let seed = {
135 let mut buf = [0u8; 8];
136 if getrandom::fill(&mut buf).is_ok() {
137 u64::from_ne_bytes(buf)
138 } else {
139 use std::time::SystemTime;
141 let dur = SystemTime::now()
142 .duration_since(SystemTime::UNIX_EPOCH)
143 .unwrap_or_default();
144 let nanos = dur.as_nanos();
145 let mut s = nanos as u64;
146 s ^= (nanos >> 64) as u64;
147 let stack_var: u8 = 0;
148 s ^= &raw const stack_var as u64;
149 s
150 }
151 };
152 default_rng_seeded(seed)
153}
154
155#[must_use]
165pub fn default_rng_seeded(seed: u64) -> Generator<Xoshiro256StarStar> {
166 let bg = Xoshiro256StarStar::seed_from_u64(seed);
167 Generator::new_with_seed(bg, seed)
168}
169
170pub fn spawn_generators<B: BitGenerator + Clone>(
179 parent: &mut Generator<B>,
180 n: usize,
181) -> Result<Vec<Generator<B>>, FerrayError> {
182 if n == 0 {
183 return Err(FerrayError::invalid_value("spawn count must be > 0"));
184 }
185
186 let mut children = Vec::with_capacity(n);
187
188 let mut test_bg = parent.bg.clone();
190 if test_bg.jump().is_some() {
191 let mut current = parent.bg.clone();
193 for _ in 0..n {
194 children.push(Generator::new(current.clone()));
195 current.jump();
196 }
197 parent.bg = current;
199 return Ok(children);
200 }
201
202 if let Some(first) = B::stream(parent.seed, 0) {
204 drop(first);
205 for i in 0..n {
206 if let Some(bg) = B::stream(parent.seed, i as u64) {
207 children.push(Generator::new(bg));
208 }
209 }
210 if children.len() == n {
211 return Ok(children);
212 }
213 children.clear();
214 }
215
216 for _ in 0..n {
218 let child_seed = parent.bg.next_u64();
219 let bg = B::seed_from_u64(child_seed);
220 children.push(Generator::new(bg));
221 }
222 Ok(children)
223}
224
225pub(crate) fn generate_vec<B: BitGenerator>(
227 rng: &mut Generator<B>,
228 size: usize,
229 mut f: impl FnMut(&mut B) -> f64,
230) -> Vec<f64> {
231 let mut data = Vec::with_capacity(size);
232 for _ in 0..size {
233 data.push(f(&mut rng.bg));
234 }
235 data
236}
237
238pub(crate) fn generate_vec_f32<B: BitGenerator>(
240 rng: &mut Generator<B>,
241 size: usize,
242 mut f: impl FnMut(&mut B) -> f32,
243) -> Vec<f32> {
244 let mut data = Vec::with_capacity(size);
245 for _ in 0..size {
246 data.push(f(&mut rng.bg));
247 }
248 data
249}
250
251pub(crate) fn generate_vec_i64<B: BitGenerator>(
253 rng: &mut Generator<B>,
254 size: usize,
255 mut f: impl FnMut(&mut B) -> i64,
256) -> Vec<i64> {
257 let mut data = Vec::with_capacity(size);
258 for _ in 0..size {
259 data.push(f(&mut rng.bg));
260 }
261 data
262}
263
/// Number of elements described by `shape`: the product of its dimensions.
///
/// An empty shape yields 0, not the empty product 1. Any zero dimension
/// also yields 0. The product overflows (panicking in debug builds) for
/// very large shapes.
///
/// NOTE(review): NumPy treats shape `()` as a 0-d array holding one element.
/// Confirm callers rely on 0 here.
#[inline]
pub(crate) fn shape_size(shape: &[usize]) -> usize {
    match shape {
        [] => 0,
        dims => dims.iter().product(),
    }
}
273
274pub(crate) fn vec_to_array_f64(
276 data: Vec<f64>,
277 shape: &[usize],
278) -> Result<Array<f64, IxDyn>, FerrayError> {
279 Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data)
280}
281
282pub(crate) fn vec_to_array_f32(
284 data: Vec<f32>,
285 shape: &[usize],
286) -> Result<Array<f32, IxDyn>, FerrayError> {
287 Array::<f32, IxDyn>::from_vec(IxDyn::new(shape), data)
288}
289
290pub(crate) fn vec_to_array_i64(
292 data: Vec<i64>,
293 shape: &[usize],
294) -> Result<Array<i64, IxDyn>, FerrayError> {
295 Array::<i64, IxDyn>::from_vec(IxDyn::new(shape), data)
296}
297
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn state_bytes_roundtrip_via_generator() {
        let mut a = default_rng_seeded(2026);
        // Advance past the initial state so the snapshot is of a used generator.
        for _ in 0..11 {
            a.next_u64();
        }
        let snap = a.state_bytes().unwrap();
        let from_a: Vec<u64> = (0..32).map(|_| a.next_u64()).collect();

        let mut b = default_rng_seeded(0);
        b.set_state_bytes(&snap).unwrap();
        let from_b: Vec<u64> = (0..32).map(|_| b.next_u64()).collect();
        assert_eq!(from_a, from_b);
    }

    #[test]
    fn set_state_bytes_rejects_wrong_size() {
        let mut a = default_rng_seeded(0);
        assert!(a.set_state_bytes(&[0u8; 4]).is_err());
    }

    #[test]
    fn default_rng_seeded_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn default_rng_works() {
        let mut rng = default_rng();
        let v = rng.next_f64();
        assert!((0.0..1.0).contains(&v));
    }

    #[test]
    fn spawn_xoshiro() {
        let mut parent = default_rng_seeded(42);
        let children = spawn_generators(&mut parent, 4).unwrap();
        assert_eq!(children.len(), 4);
    }

    #[test]
    fn spawn_children_are_distinct() {
        let mut parent = default_rng_seeded(42);
        let mut children = spawn_generators(&mut parent, 4).unwrap();
        let firsts: Vec<u64> = children.iter_mut().map(|c| c.next_u64()).collect();
        for (i, a) in firsts.iter().enumerate() {
            for b in &firsts[i + 1..] {
                assert_ne!(a, b, "spawned children must not share a stream");
            }
        }
    }

    #[test]
    fn spawn_zero_is_error() {
        let mut parent = default_rng_seeded(42);
        assert!(spawn_generators(&mut parent, 0).is_err());
    }

    #[test]
    fn bytes_length_zero() {
        let mut rng = default_rng_seeded(42);
        assert!(rng.bytes(0).is_empty());
    }

    #[test]
    fn bytes_length_full_word() {
        let mut rng = default_rng_seeded(42);
        let b = rng.bytes(8);
        assert_eq!(b.len(), 8);
    }

    #[test]
    fn bytes_length_partial_word() {
        let mut rng = default_rng_seeded(42);
        let b = rng.bytes(13);
        assert_eq!(b.len(), 13);
    }

    #[test]
    fn bytes_deterministic_for_same_seed() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        assert_eq!(rng1.bytes(64), rng2.bytes(64));
    }

    #[test]
    fn bytes_partial_word_is_prefix_and_stays_aligned() {
        let mut short = default_rng_seeded(7);
        let mut long = default_rng_seeded(7);
        let s = short.bytes(13);
        let l = long.bytes(16);
        assert_eq!(&s[..], &l[..13]);
        // Both calls consumed exactly two words, so the streams remain in lockstep.
        assert_eq!(short.next_u64(), long.next_u64());
    }

    #[test]
    fn bytes_matches_little_endian_words() {
        let mut a = default_rng_seeded(9);
        let mut b = default_rng_seeded(9);
        let expected: Vec<u8> = (0..3).flat_map(|_| b.next_u64().to_le_bytes()).collect();
        assert_eq!(a.bytes(24), expected);
    }
}