1use std::ops::{Shl, Shr, Range, RangeInclusive};
2use num_traits::{
3 cast::{AsPrimitive, FromPrimitive, NumCast},
4 PrimInt, Unsigned, ops::overflowing::{OverflowingMul, OverflowingAdd}, Bounded, One, WrappingSub,
5};
6use rand::{CryptoRng, RngCore};
7
8use crate::{magnitude::HasMagnitude, widening::Widening};
9
10pub trait RandomNumberGenerator: RngCore + CryptoRng {
11 fn random_data(&mut self, size: usize) -> Vec<u8> {
13 let mut data = vec![0; size];
14 self.fill_bytes(&mut data);
15 data
16 }
17
18 fn fill_random_data(&mut self, data: &mut [u8]) {
19 self.fill_bytes(data);
20 }
21}
22
23pub fn rng_random_data(rng: &mut impl RandomNumberGenerator, size: usize) -> Vec<u8> {
25 let mut data = vec![0; size];
26 rng.fill_random_data(&mut data);
27 data
28}
29
30pub fn rng_fill_random_data(rng: &mut impl RandomNumberGenerator, data: &mut [u8]) {
32 rng.fill_random_data(data);
33}
34
35pub fn rng_next_with_upper_bound<T>(rng: &mut impl RandomNumberGenerator, upper_bound: T) -> T
47 where
48 T: PrimInt
49 + Unsigned
50 + NumCast
51 + FromPrimitive
52 + AsPrimitive<u128>
53 + OverflowingMul
54 + Shl<usize, Output = T>
55 + Shr<usize, Output = T>
56 + WrappingSub
57 + OverflowingAdd
58 + Widening
59{
60 assert!(upper_bound != T::zero());
61
62 let bitmask: u64 = T::max_value().to_u64().unwrap();
67 let mut random: T = NumCast::from(rng.next_u64() & bitmask).unwrap();
68 let mut m = random.wide_mul(upper_bound);
70 if m.0 < upper_bound {
71 let t = (T::zero().wrapping_sub(&upper_bound)) % upper_bound;
72 while m.0 < t {
73 random = NumCast::from(rng.next_u64() & bitmask).unwrap();
74 m = random.wide_mul(upper_bound);
75 }
76 }
77 m.1
78}
79
80pub fn rng_next_in_range<T>(rng: &mut impl RandomNumberGenerator, range: &Range<T>) -> T
93 where T: PrimInt
94 + FromPrimitive
95 + AsPrimitive<u128>
96 + OverflowingMul
97 + Shl<usize, Output = T>
98 + Shr<usize, Output = T>
99 + HasMagnitude
100 + OverflowingAdd
101{
102 let lower_bound = range.start;
103 let upper_bound = range.end;
104
105 assert!(lower_bound < upper_bound);
106
107 let delta = (upper_bound - lower_bound).to_magnitude();
108
109 if delta == T::Magnitude::max_value() {
110 return T::from_u64(rng.next_u64()).unwrap();
111 }
112
113 let random = rng_next_with_upper_bound(rng, delta);
114 lower_bound + T::from_magnitude(random)
115}
116
117pub fn rng_next_in_closed_range<T>(rng: &mut impl RandomNumberGenerator, range: &RangeInclusive<T>) -> T
118 where T: PrimInt
119 + FromPrimitive
120 + AsPrimitive<u128>
121 + OverflowingMul
122 + Shl<usize, Output = T>
123 + Shr<usize, Output = T>
124 + HasMagnitude
125{
126 let lower_bound = *range.start();
127 let upper_bound = *range.end();
128
129 assert!(lower_bound <= upper_bound);
130
131 let delta = (upper_bound - lower_bound).to_magnitude();
132
133 if delta == T::Magnitude::max_value() {
134 return T::from_u64(rng.next_u64()).unwrap();
135 }
136
137 let random = rng_next_with_upper_bound(rng, delta + T::Magnitude::one());
138 lower_bound + T::from_magnitude(random)
139}
140
141pub fn rng_random_array<const N: usize>(rng: &mut impl RandomNumberGenerator) -> [u8; N] {
142 let mut data = [0u8; N];
143 rng.fill_random_data(&mut data);
144 data
145}
146
147pub fn rng_random_bool(rng: &mut impl RandomNumberGenerator) -> bool {
148 rng.next_u32() % 2 == 0
149}
150
151pub fn rng_random_u32(rng: &mut impl RandomNumberGenerator) -> u32 {
152 rng.next_u32()
153}
154
#[cfg(test)]
mod tests {
    use crate::{make_fake_random_number_generator, rng_next_in_closed_range};

    /// Draws 100 values from `-50..=50` with the fake generator and checks
    /// them against a fixed, previously recorded sequence.
    #[test]
    fn test_fake_numbers() {
        let mut rng = make_fake_random_number_generator();
        let range = -50..=50;
        let mut values = Vec::with_capacity(100);
        for _ in 0..100 {
            values.push(rng_next_in_closed_range(&mut rng, &range));
        }
        let expected = "[-43, -6, 43, -34, -34, 17, -9, 24, 17, -29, -32, -44, 12, -15, -46, 20, 50, -31, -50, 36, -28, -23, 6, -27, -31, -45, -27, 26, 31, -23, 24, 19, -32, 43, -18, -17, 6, -13, -1, -27, 4, -48, -4, -44, -6, 17, -15, 22, 15, 20, -25, -35, -33, -27, -17, -44, -27, 15, -14, -38, -29, -12, 8, 43, 49, -42, -11, -1, -42, -26, -25, 22, -13, 14, 42, -29, -38, 17, 2, 5, 5, -31, 27, -3, 39, -12, 42, 46, -17, -25, -46, -19, 16, 2, -45, 41, 12, -22, 43, -11]";
        assert_eq!(format!("{:?}", values), expected);
    }
}