ferray_random/
generator.rs1use ferray_core::{Array, FerrayError, Ix1};
7
8use crate::bitgen::{BitGenerator, Xoshiro256StarStar};
9
10pub struct Generator<B: BitGenerator = Xoshiro256StarStar> {
25 pub(crate) bg: B,
27 pub(crate) seed: u64,
29}
30
31impl<B: BitGenerator> Generator<B> {
32 pub fn new(bg: B) -> Self {
34 Self { bg, seed: 0 }
35 }
36
37 pub(crate) fn new_with_seed(bg: B, seed: u64) -> Self {
39 Self { bg, seed }
40 }
41
42 #[inline]
44 pub fn bit_generator(&mut self) -> &mut B {
45 &mut self.bg
46 }
47
48 #[inline]
50 pub fn next_u64(&mut self) -> u64 {
51 self.bg.next_u64()
52 }
53
54 #[inline]
56 pub fn next_f64(&mut self) -> f64 {
57 self.bg.next_f64()
58 }
59
60 #[inline]
62 pub fn next_u64_bounded(&mut self, bound: u64) -> u64 {
63 self.bg.next_u64_bounded(bound)
64 }
65}
66
67pub fn default_rng() -> Generator<Xoshiro256StarStar> {
78 let seed = {
80 use std::time::SystemTime;
81 let dur = SystemTime::now()
82 .duration_since(SystemTime::UNIX_EPOCH)
83 .unwrap_or_default();
84 let nanos = dur.as_nanos();
85 let mut s = nanos as u64;
87 s ^= (nanos >> 64) as u64;
88 let stack_var: u8 = 0;
90 let addr = &stack_var as *const u8 as u64;
91 s ^= addr;
92 s
93 };
94 default_rng_seeded(seed)
95}
96
97pub fn default_rng_seeded(seed: u64) -> Generator<Xoshiro256StarStar> {
107 let bg = Xoshiro256StarStar::seed_from_u64(seed);
108 Generator::new_with_seed(bg, seed)
109}
110
111pub fn spawn_generators<B: BitGenerator + Clone>(
120 parent: &mut Generator<B>,
121 n: usize,
122) -> Result<Vec<Generator<B>>, FerrayError> {
123 if n == 0 {
124 return Err(FerrayError::invalid_value("spawn count must be > 0"));
125 }
126
127 let mut children = Vec::with_capacity(n);
128
129 let mut test_bg = parent.bg.clone();
131 if test_bg.jump().is_some() {
132 let mut current = parent.bg.clone();
134 for _ in 0..n {
135 children.push(Generator::new(current.clone()));
136 current.jump();
137 }
138 parent.bg = current;
140 return Ok(children);
141 }
142
143 if let Some(first) = B::stream(parent.seed, 0) {
145 drop(first);
146 for i in 0..n {
147 if let Some(bg) = B::stream(parent.seed, i as u64) {
148 children.push(Generator::new(bg));
149 }
150 }
151 if children.len() == n {
152 return Ok(children);
153 }
154 children.clear();
155 }
156
157 for _ in 0..n {
159 let child_seed = parent.bg.next_u64();
160 let bg = B::seed_from_u64(child_seed);
161 children.push(Generator::new(bg));
162 }
163 Ok(children)
164}
165
166pub(crate) fn generate_vec<B: BitGenerator>(
168 rng: &mut Generator<B>,
169 size: usize,
170 mut f: impl FnMut(&mut B) -> f64,
171) -> Vec<f64> {
172 let mut data = Vec::with_capacity(size);
173 for _ in 0..size {
174 data.push(f(&mut rng.bg));
175 }
176 data
177}
178
179pub(crate) fn generate_vec_i64<B: BitGenerator>(
181 rng: &mut Generator<B>,
182 size: usize,
183 mut f: impl FnMut(&mut B) -> i64,
184) -> Vec<i64> {
185 let mut data = Vec::with_capacity(size);
186 for _ in 0..size {
187 data.push(f(&mut rng.bg));
188 }
189 data
190}
191
192pub(crate) fn vec_to_array1(data: Vec<f64>) -> Result<Array<f64, Ix1>, FerrayError> {
194 let n = data.len();
195 Array::<f64, Ix1>::from_vec(Ix1::new([n]), data)
196}
197
198pub(crate) fn vec_to_array1_i64(data: Vec<i64>) -> Result<Array<i64, Ix1>, FerrayError> {
200 let n = data.len();
201 Array::<i64, Ix1>::from_vec(Ix1::new([n]), data)
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_rng_seeded_deterministic() {
        let mut rng1 = default_rng_seeded(42);
        let mut rng2 = default_rng_seeded(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    #[test]
    fn default_rng_seeded_records_seed() {
        let rng = default_rng_seeded(42);
        assert_eq!(rng.seed, 42);
    }

    #[test]
    fn default_rng_works() {
        let mut rng = default_rng();
        let v = rng.next_f64();
        assert!((0.0..1.0).contains(&v));
    }

    #[test]
    fn spawn_xoshiro() {
        let mut parent = default_rng_seeded(42);
        let children = spawn_generators(&mut parent, 4).unwrap();
        assert_eq!(children.len(), 4);
    }

    /// Spawned children must not start on the same stream.
    #[test]
    fn spawned_children_are_distinct() {
        let mut parent = default_rng_seeded(42);
        let mut children = spawn_generators(&mut parent, 4).unwrap();
        let firsts: Vec<u64> = children.iter_mut().map(|c| c.next_u64()).collect();
        for i in 0..firsts.len() {
            for j in (i + 1)..firsts.len() {
                assert_ne!(firsts[i], firsts[j], "children {} and {} coincide", i, j);
            }
        }
    }

    /// Spawning advances the parent, so a second spawn yields new children.
    #[test]
    fn repeated_spawns_differ() {
        let mut parent = default_rng_seeded(42);
        let mut first = spawn_generators(&mut parent, 2).unwrap();
        let mut second = spawn_generators(&mut parent, 2).unwrap();
        assert_ne!(first[0].next_u64(), second[0].next_u64());
    }

    #[test]
    fn spawn_zero_is_error() {
        let mut parent = default_rng_seeded(42);
        assert!(spawn_generators(&mut parent, 0).is_err());
    }

    /// `generate_vec` draws `size` values in the same order as direct calls would.
    #[test]
    fn generate_vec_draws_in_order() {
        let mut a = default_rng_seeded(7);
        let mut b = default_rng_seeded(7);
        let v = generate_vec(&mut a, 5, |bg| bg.next_f64());
        assert_eq!(v.len(), 5);
        for x in v {
            assert_eq!(x, b.next_f64());
        }
    }
}