1use rand::prelude::*;
2use rand::Rng;
3use std::cell::RefCell;
4use std::rc::Rc;
5
/// A source of pseudo-random primitive values used in a state-passing style:
/// every method returns the generated value together with the generator to
/// use for the next draw.
///
/// Implementors supply the signed primitives (`next_i64`, `next_i32`,
/// `next_i16`, `next_i8`); the unsigned, floating-point and boolean variants
/// have default implementations built on top of them.
pub trait NextRandValue
where
    Self: Sized,
{
    /// Returns a random `i64` and the next generator.
    fn next_i64(&self) -> (i64, Self);

    /// Returns a value in `0..=i64::MAX` (widened to `u64`) and the next generator.
    ///
    /// Negative draws are folded with `-(i + 1)`, which cannot overflow even for
    /// `i64::MIN`. Values above `i64::MAX` are never produced.
    fn next_u64(&self) -> (u64, Self) {
        let (i, r) = self.next_i64();
        (if i < 0 { -(i + 1) as u64 } else { i as u64 }, r)
    }

    /// Returns a random `i32` and the next generator.
    fn next_i32(&self) -> (i32, Self);

    /// Returns a value in `0..=i32::MAX` (widened to `u32`) and the next generator.
    ///
    /// Uses the same `-(i + 1)` fold as [`NextRandValue::next_u64`].
    fn next_u32(&self) -> (u32, Self) {
        let (i, r) = self.next_i32();
        (if i < 0 { -(i + 1) as u32 } else { i as u32 }, r)
    }

    /// Returns a random `i16` and the next generator.
    fn next_i16(&self) -> (i16, Self);

    /// Returns a value in `0..=i16::MAX` (widened to `u16`) and the next generator.
    fn next_u16(&self) -> (u16, Self) {
        let (i, r) = self.next_i16();
        (if i < 0 { -(i + 1) as u16 } else { i as u16 }, r)
    }

    /// Returns a random `i8` and the next generator.
    fn next_i8(&self) -> (i8, Self);

    /// Returns a value in `0..=i8::MAX` (widened to `u8`) and the next generator.
    fn next_u8(&self) -> (u8, Self) {
        let (i, r) = self.next_i8();
        (if i < 0 { -(i + 1) as u8 } else { i as u8 }, r)
    }

    /// Returns a value in `[0.0, 1.0)` and the next generator.
    ///
    /// The `i64` draw is folded onto `0..=i64::MAX` and its top 53 bits are
    /// scaled by `2^-53`. Both steps are exact in `f64`, so the result can never
    /// round up to `1.0` (dividing by `i64::MAX as f64 + 1.0` could, because
    /// `i64::MAX as f64` already rounds to `2^63`).
    fn next_f64(&self) -> (f64, Self) {
        let (i, r) = self.next_i64();
        // `-(i + 1)` rather than `-i`: negating `i64::MIN` overflows.
        let magnitude = if i < 0 { -(i + 1) as u64 } else { i as u64 };
        // 63 significant bits -> keep the top 53 (the f64 mantissa width).
        let normalized = (magnitude >> 10) as f64 / (1u64 << 53) as f64;
        (normalized, r)
    }

    /// Returns a value in `[0.0, 1.0)` and the next generator.
    ///
    /// The `i32` draw is folded onto `0..=i32::MAX` and its top 24 bits are
    /// scaled by `2^-24`, which is exact in `f32` and therefore strictly below
    /// `1.0`.
    fn next_f32(&self) -> (f32, Self) {
        let (i, r) = self.next_i32();
        // `-(i + 1)` rather than `-i`: negating `i32::MIN` overflows.
        let magnitude = if i < 0 { -(i + 1) as u32 } else { i as u32 };
        // 31 significant bits -> keep the top 24 (the f32 mantissa width).
        let normalized = (magnitude >> 7) as f32 / (1u32 << 24) as f32;
        (normalized, r)
    }

    /// Returns `true` when the underlying `i32` draw is odd, plus the next generator.
    fn next_bool(&self) -> (bool, Self) {
        let (i, r) = self.next_i32();
        ((i % 2) != 0, r)
    }
}
86
87pub trait RandGen<T: NextRandValue>
89where
90 Self: Sized, {
91 fn rnd_gen(rng: T) -> (Self, T);
93}
94
95impl<T: NextRandValue> RandGen<T> for i64 {
96 fn rnd_gen(rng: T) -> (Self, T) {
97 rng.next_i64()
98 }
99}
100
101impl<T: NextRandValue> RandGen<T> for u32 {
102 fn rnd_gen(rng: T) -> (Self, T) {
103 rng.next_u32()
104 }
105}
106
107impl<T: NextRandValue> RandGen<T> for i32 {
108 fn rnd_gen(rng: T) -> (Self, T) {
109 rng.next_i32()
110 }
111}
112
113impl<T: NextRandValue> RandGen<T> for i16 {
114 fn rnd_gen(rng: T) -> (Self, T) {
115 rng.next_i16()
116 }
117}
118
119impl<T: NextRandValue> RandGen<T> for f32 {
120 fn rnd_gen(rng: T) -> (Self, T) {
121 rng.next_f32()
122 }
123}
124
125impl<T: NextRandValue> RandGen<T> for bool {
126 fn rnd_gen(rng: T) -> (Self, T) {
127 rng.next_bool()
128 }
129}
130
131#[derive(Clone, Debug, PartialEq)]
133pub struct RNG {
134 rng: Rc<RefCell<StdRng>>,
135}
136
137impl Default for RNG {
138 fn default() -> Self {
139 Self::new()
140 }
141}
142
143impl NextRandValue for RNG {
144 fn next_i64(&self) -> (i64, Self) {
145 let n = { self.rng.borrow_mut().random::<i64>() };
146 (
147 n,
148 Self {
149 rng: Rc::clone(&self.rng),
150 },
151 )
152 }
153
154 fn next_i32(&self) -> (i32, Self) {
155 let n = { self.rng.borrow_mut().random::<i32>() };
156 (
157 n,
158 Self {
159 rng: Rc::clone(&self.rng),
160 },
161 )
162 }
163
164 fn next_i16(&self) -> (i16, Self) {
165 let n = { self.rng.borrow_mut().random::<i16>() };
166 (
167 n,
168 Self {
169 rng: Rc::clone(&self.rng),
170 },
171 )
172 }
173
174 fn next_i8(&self) -> (i8, Self) {
175 let n = { self.rng.borrow_mut().random::<i8>() };
176 (
177 n,
178 Self {
179 rng: Rc::clone(&self.rng),
180 },
181 )
182 }
183}
184
185impl RNG {
186 pub fn new() -> Self {
188 Self {
189 rng: Rc::new(RefCell::new(StdRng::seed_from_u64(0))),
190 }
191 }
192
193 pub fn with_seed(mut self, seed: u64) -> Self {
195 self.rng = Rc::new(RefCell::new(StdRng::seed_from_u64(seed)));
196 self
197 }
198
199 pub fn i32_f32(&self) -> ((i32, f32), Self) {
201 let (i, r1) = self.next_i32();
202 let (d, r2) = r1.next_f32();
203 ((i, d), r2)
204 }
205
206 pub fn f32_i32(&self) -> ((f32, i32), Self) {
208 let ((i, d), r) = self.i32_f32();
209 ((d, i), r)
210 }
211
212 pub fn f32_3(&self) -> ((f32, f32, f32), Self) {
214 let (d1, r1) = self.next_f32();
215 let (d2, r2) = r1.next_f32();
216 let (d3, r3) = r2.next_f32();
217 ((d1, d2, d3), r3)
218 }
219
220 pub fn i32s(&self, count: u32) -> (Vec<i32>, Self) {
226 if count >= 50_000 {
228 return self.i32s_parallel(count);
229 }
230
231 self.i32s_direct(count)
233 }
234
235 pub fn i32s_direct(&self, count: u32) -> (Vec<i32>, Self) {
237 let mut result = Vec::with_capacity(count as usize);
238
239 {
240 let mut rng_inner = self.rng.borrow_mut();
241 for _ in 0..count {
242 result.push(rng_inner.random::<i32>());
243 }
244 }
245
246 (result, self.clone())
247 }
248
249 pub fn i32s_parallel(&self, count: u32) -> (Vec<i32>, Self) {
251 use rand::prelude::*;
252 use rand::SeedableRng;
253 use std::sync::{Arc, Mutex};
254 use std::thread;
255
256 let num_threads = num_cpus::get().min(8); let chunk_size = count / num_threads as u32;
258 let remainder = count % num_threads as u32;
259
260 let result = Arc::new(Mutex::new(Vec::with_capacity(count as usize)));
261
262 let mut handles = vec![];
263
264 for i in 0..num_threads {
265 let result_clone = Arc::clone(&result);
266 let mut thread_count = chunk_size;
267
268 if i == num_threads - 1 {
270 thread_count += remainder;
271 }
272
273 let handle = thread::spawn(move || {
274 let mut rng = StdRng::seed_from_u64(i as u64); let mut local_result = Vec::with_capacity(thread_count as usize);
276
277 for _ in 0..thread_count {
278 local_result.push(rng.random::<i32>());
279 }
280
281 let mut result = result_clone.lock().unwrap();
282 result.extend(local_result);
283 });
284
285 handles.push(handle);
286 }
287
288 for handle in handles {
289 handle.join().unwrap();
290 }
291
292 (Arc::try_unwrap(result).unwrap().into_inner().unwrap(), self.clone())
293 }
294
295 pub fn unit<A>(a: A) -> Box<dyn FnMut(RNG) -> (A, RNG)>
297 where
298 A: Clone + 'static, {
299 Box::new(move |rng: RNG| (a.clone(), rng))
300 }
301
302 pub fn sequence<A, F>(fs: Vec<F>) -> Box<dyn FnMut(RNG) -> (Vec<A>, RNG)>
305 where
306 A: Clone + 'static,
307 F: FnMut(RNG) -> (A, RNG) + 'static, {
308 let cap = fs.len();
309 let unit: Box<dyn FnMut(RNG) -> (Vec<A>, RNG)> = Box::new(move |rng: RNG| (Vec::<A>::with_capacity(cap), rng));
310 let result = fs.into_iter().fold(unit, |acc, e| {
311 let map_result: Box<dyn FnMut(RNG) -> (Vec<A>, RNG)> = Self::map2(acc, e, |mut a, b| {
312 a.push(b);
313 a
314 });
315 map_result
316 });
317 result
318 }
319
320 pub fn int_value() -> Box<dyn FnMut(RNG) -> (i32, RNG)> {
322 Box::new(move |rng| rng.next_i32())
323 }
324
325 pub fn double_value() -> Box<dyn FnMut(RNG) -> (f32, RNG)> {
327 Box::new(move |rng| rng.next_f32())
328 }
329
330 pub fn map<A, B, F1, F2>(mut s: F1, mut f: F2) -> Box<dyn FnMut(RNG) -> (B, RNG)>
332 where
333 F1: FnMut(RNG) -> (A, RNG) + 'static,
334 F2: FnMut(A) -> B + 'static, {
335 Box::new(move |rng| {
336 let (a, rng2) = s(rng);
337 (f(a), rng2)
338 })
339 }
340
341 pub fn map2<F1, F2, F3, A, B, C>(mut ra: F1, mut rb: F2, mut f: F3) -> Box<dyn FnMut(RNG) -> (C, RNG)>
343 where
344 F1: FnMut(RNG) -> (A, RNG) + 'static,
345 F2: FnMut(RNG) -> (B, RNG) + 'static,
346 F3: FnMut(A, B) -> C + 'static, {
347 Box::new(move |rng| {
348 let (a, r1) = ra(rng);
349 let (b, r2) = rb(r1);
350 (f(a, b), r2)
351 })
352 }
353
354 pub fn both<F1, F2, A, B>(ra: F1, rb: F2) -> Box<dyn FnMut(RNG) -> ((A, B), RNG)>
356 where
357 F1: FnMut(RNG) -> (A, RNG) + 'static,
358 F2: FnMut(RNG) -> (B, RNG) + 'static, {
359 Self::map2(ra, rb, |a, b| (a, b))
360 }
361
362 pub fn rand_int_double() -> Box<dyn FnMut(RNG) -> ((i32, f32), RNG)> {
364 Self::both(Self::int_value(), Self::double_value())
365 }
366
367 pub fn rand_double_int() -> Box<dyn FnMut(RNG) -> ((f32, i32), RNG)> {
369 Self::both(Self::double_value(), Self::int_value())
370 }
371
372 pub fn flat_map<A, B, F, GF, BF>(mut f: F, mut g: GF) -> Box<dyn FnMut(RNG) -> (B, RNG)>
374 where
375 F: FnMut(RNG) -> (A, RNG) + 'static,
376 BF: FnMut(RNG) -> (B, RNG) + 'static,
377 GF: FnMut(A) -> BF + 'static, {
378 Box::new(move |rng| {
379 let (a, r1) = f(rng);
380 (g(a))(r1)
381 })
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::{NextRandValue, RandGen, RNG};

    /// Installs `env_logger` at `info` level (repeat calls are harmless).
    ///
    /// The level is set on the builder instead of exporting `RUST_LOG`:
    /// mutating the process environment while other test threads run is racy,
    /// and `std::env::set_var` is `unsafe` in the 2024 edition.
    fn init() {
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Info)
            .try_init();
    }

    /// A generator in the default (seed 0) state.
    fn new_rng() -> RNG {
        RNG::new()
    }

    #[test]
    fn next_i32() {
        init();
        let rng = RNG::new();
        let (v1, r1) = i32::rnd_gen(rng);
        log::info!("{:?}", v1);
        let (v2, _) = u32::rnd_gen(r1);
        log::info!("{:?}", v2);
        // `next_u32` folds negatives onto the non-negative i32 range.
        assert!(v2 <= i32::MAX as u32);
    }

    #[test]
    fn test_next_i64() {
        init();
        let rng = new_rng();
        let (value, next) = rng.next_i64();
        // Same seed, same first draw.
        assert_eq!(new_rng().next_i64().0, value);

        // `next` shares state with `rng`, so these two draws are consecutive.
        let (value1, _) = rng.next_i32();
        let (value2, _) = next.next_i32();
        assert_ne!(value1, value2);
    }

    #[test]
    fn test_next_u64() {
        init();
        let (value, _) = new_rng().next_u64();
        assert!(value <= i64::MAX as u64);
    }

    #[test]
    fn test_next_i32() {
        init();
        let (value, _) = new_rng().next_i32();
        assert_eq!(new_rng().next_i32().0, value);
    }

    #[test]
    fn test_next_u32() {
        init();
        let (value, _) = new_rng().next_u32();
        assert!(value <= i32::MAX as u32);
    }

    #[test]
    fn test_next_i16() {
        init();
        let (value, _) = new_rng().next_i16();
        assert_eq!(new_rng().next_i16().0, value);
    }

    #[test]
    fn test_next_u16() {
        init();
        let (value, _) = new_rng().next_u16();
        assert!(value <= i16::MAX as u16);
    }

    #[test]
    fn test_next_i8() {
        init();
        let (value, _) = new_rng().next_i8();
        assert_eq!(new_rng().next_i8().0, value);
    }

    #[test]
    fn test_next_u8() {
        init();
        let (value, _) = new_rng().next_u8();
        assert!(value <= i8::MAX as u8);
    }

    #[test]
    fn test_next_f64() {
        init();
        let (value, _) = new_rng().next_f64();
        assert!((0.0..1.0).contains(&value));
    }

    #[test]
    fn test_next_f32() {
        init();
        let (value, _) = new_rng().next_f32();
        assert!((0.0..1.0).contains(&value));
    }

    #[test]
    fn test_next_bool() {
        init();
        let rng = new_rng();
        // Draws share state with `rng`, so repeated calls keep advancing it.
        let draws: Vec<bool> = (0..64).map(|_| rng.next_bool().0).collect();
        assert!(draws.contains(&true));
        assert!(draws.contains(&false));
    }

    #[test]
    fn test_with_seed() {
        init();
        let rng1 = new_rng().with_seed(42);
        let rng2 = new_rng().with_seed(42);

        let (v1, _) = rng1.next_i32();
        let (v2, _) = rng2.next_i32();
        assert_eq!(v1, v2);
    }

    #[test]
    fn test_i32_f32() {
        init();
        let ((i, f), _) = new_rng().i32_f32();
        assert_eq!(i, new_rng().next_i32().0);
        assert!((0.0..1.0).contains(&f));
    }

    #[test]
    fn test_f32_i32() {
        init();
        let ((f, i), _) = new_rng().f32_i32();
        // Same draw order as `i32_f32`; only the tuple is swapped.
        assert_eq!((i, f), new_rng().i32_f32().0);
        assert!((0.0..1.0).contains(&f));
    }

    #[test]
    fn test_f32_3() {
        init();
        let ((f1, f2, f3), _) = new_rng().f32_3();
        assert!((0.0..1.0).contains(&f1));
        assert!((0.0..1.0).contains(&f2));
        assert!((0.0..1.0).contains(&f3));
    }

    #[test]
    fn test_i32s() {
        init();
        let count = 100;
        let (values, _) = new_rng().i32s(count);
        assert_eq!(values.len(), count as usize);
        // Below the parallel threshold this takes the sequential path.
        assert_eq!(values, new_rng().i32s_direct(count).0);
    }

    #[test]
    fn test_i32s_direct() {
        init();
        let count = 100;
        let (values, _) = new_rng().i32s_direct(count);
        assert_eq!(values.len(), count as usize);
        assert_eq!(values[0], new_rng().next_i32().0);
    }

    #[test]
    fn test_i32s_parallel() {
        init();
        let count = 50_000;
        let (values, _) = new_rng().i32s_parallel(count);
        assert_eq!(values.len(), count as usize);
    }

    #[test]
    fn test_unit() {
        init();
        let mut unit_fn = RNG::unit(42);
        let (value, _) = unit_fn(new_rng());
        assert_eq!(value, 42);
    }

    #[test]
    fn test_sequence() {
        init();
        let fns = vec![RNG::unit(1), RNG::unit(2), RNG::unit(3)];
        let mut sequence_fn = RNG::sequence(fns);
        let (values, _) = sequence_fn(new_rng());
        assert_eq!(values, vec![1, 2, 3]);
    }

    #[test]
    fn test_int_value() {
        init();
        let mut int_fn = RNG::int_value();
        let (value, _) = int_fn(new_rng());
        assert_eq!(value, new_rng().next_i32().0);
    }

    #[test]
    fn test_double_value() {
        init();
        let mut double_fn = RNG::double_value();
        let (value, _) = double_fn(new_rng());
        assert!((0.0..1.0).contains(&value));
    }

    #[test]
    fn test_map() {
        init();
        let mut int_fn = RNG::int_value();
        let (original, _) = int_fn(new_rng());

        let mut halve = RNG::map(RNG::int_value(), |x: i32| x / 2);
        let (value, _) = halve(new_rng());

        assert_eq!(value, original / 2);
    }

    #[test]
    fn test_map2() {
        init();
        let mut sum_fn = RNG::map2(RNG::int_value(), RNG::double_value(), |i: i32, d: f32| i as f32 + d);
        let (value, _) = sum_fn(new_rng());
        let ((i, d), _) = new_rng().i32_f32();
        assert_eq!(value, i as f32 + d);
    }

    #[test]
    fn test_both() {
        init();
        let int_fn = RNG::int_value();
        let double_fn = RNG::double_value();
        let mut both_fn = RNG::both(int_fn, double_fn);
        let ((i, d), _) = both_fn(new_rng());
        assert_eq!((i, d), new_rng().i32_f32().0);
        assert!((0.0..1.0).contains(&d));
    }

    #[test]
    fn test_rand_int_double() {
        init();
        let mut fn_id = RNG::rand_int_double();
        let ((i, d), _) = fn_id(new_rng());
        assert_eq!((i, d), new_rng().i32_f32().0);
        assert!((0.0..1.0).contains(&d));
    }

    #[test]
    fn test_rand_double_int() {
        init();
        let mut fn_di = RNG::rand_double_int();
        let ((d, i), _) = fn_di(new_rng());
        // The f32 is drawn first, then the i32.
        let (d0, r) = new_rng().next_f32();
        let (i0, _) = r.next_i32();
        assert_eq!((d, i), (d0, i0));
        assert!((0.0..1.0).contains(&d));
    }

    #[test]
    fn test_flat_map() {
        init();
        // `wrapping_mul`: a plain `i * 2` panics on overflow in debug builds.
        let mut doubled = RNG::flat_map(RNG::int_value(), |i: i32| RNG::unit(i.wrapping_mul(2)));
        let (value, _) = doubled(new_rng());
        assert_eq!(value % 2, 0);
        assert_eq!(value, new_rng().next_i32().0.wrapping_mul(2));
    }
}