1use std::ops::{Range, RangeInclusive, Shl, Shr};
2
3use num_traits::{
4 Bounded, One, PrimInt, Unsigned, WrappingSub,
5 cast::{AsPrimitive, FromPrimitive, NumCast},
6 ops::overflowing::{OverflowingAdd, OverflowingMul},
7};
8use rand::{CryptoRng, RngCore};
9
10use crate::{magnitude::HasMagnitude, widening::Widening};
11
12pub trait RandomNumberGenerator: RngCore + CryptoRng {
13 fn random_data(&mut self, size: usize) -> Vec<u8> {
15 let mut data = vec![0; size];
16 self.fill_bytes(&mut data);
17 data
18 }
19
20 fn fill_random_data(&mut self, data: &mut [u8]) { self.fill_bytes(data); }
21}
22
23pub fn rng_random_data(
25 rng: &mut impl RandomNumberGenerator,
26 size: usize,
27) -> Vec<u8> {
28 let mut data = vec![0; size];
29 rng.fill_random_data(&mut data);
30 data
31}
32
33pub fn rng_fill_random_data(
35 rng: &mut impl RandomNumberGenerator,
36 data: &mut [u8],
37) {
38 rng.fill_random_data(data);
39}
40
41pub fn rng_next_with_upper_bound<T>(
53 rng: &mut impl RandomNumberGenerator,
54 upper_bound: T,
55) -> T
56where
57 T: PrimInt
58 + Unsigned
59 + NumCast
60 + FromPrimitive
61 + AsPrimitive<u128>
62 + OverflowingMul
63 + Shl<usize, Output = T>
64 + Shr<usize, Output = T>
65 + WrappingSub
66 + OverflowingAdd
67 + Widening,
68{
69 assert!(upper_bound != T::zero());
70
71 let bitmask: u64 = T::max_value().to_u64().unwrap();
76 let mut random: T = NumCast::from(rng.next_u64() & bitmask).unwrap();
77 let mut m = random.wide_mul(upper_bound);
79 if m.0 < upper_bound {
80 let t = (T::zero().wrapping_sub(&upper_bound)) % upper_bound;
81 while m.0 < t {
82 random = NumCast::from(rng.next_u64() & bitmask).unwrap();
83 m = random.wide_mul(upper_bound);
84 }
85 }
86 m.1
87}
88
89pub fn rng_next_in_range<T>(
102 rng: &mut impl RandomNumberGenerator,
103 range: &Range<T>,
104) -> T
105where
106 T: PrimInt
107 + FromPrimitive
108 + AsPrimitive<u128>
109 + OverflowingMul
110 + Shl<usize, Output = T>
111 + Shr<usize, Output = T>
112 + HasMagnitude
113 + OverflowingAdd,
114{
115 let lower_bound = range.start;
116 let upper_bound = range.end;
117
118 assert!(lower_bound < upper_bound);
119
120 let delta = (upper_bound - lower_bound).to_magnitude();
121
122 if delta == T::Magnitude::max_value() {
123 return T::from_u64(rng.next_u64()).unwrap();
124 }
125
126 let random = rng_next_with_upper_bound(rng, delta);
127 lower_bound + T::from_magnitude(random)
128}
129
130pub fn rng_next_in_closed_range<T>(
131 rng: &mut impl RandomNumberGenerator,
132 range: &RangeInclusive<T>,
133) -> T
134where
135 T: PrimInt
136 + FromPrimitive
137 + AsPrimitive<u128>
138 + OverflowingMul
139 + Shl<usize, Output = T>
140 + Shr<usize, Output = T>
141 + HasMagnitude,
142{
143 let lower_bound = *range.start();
144 let upper_bound = *range.end();
145
146 assert!(lower_bound <= upper_bound);
147
148 let delta = (upper_bound - lower_bound).to_magnitude();
149
150 if delta == T::Magnitude::max_value() {
151 return T::from_u64(rng.next_u64()).unwrap();
152 }
153
154 let random = rng_next_with_upper_bound(rng, delta + T::Magnitude::one());
155 lower_bound + T::from_magnitude(random)
156}
157
158pub fn rng_random_array<const N: usize>(
159 rng: &mut impl RandomNumberGenerator,
160) -> [u8; N] {
161 let mut data = [0u8; N];
162 rng.fill_random_data(&mut data);
163 data
164}
165
166pub fn rng_random_bool(rng: &mut impl RandomNumberGenerator) -> bool {
167 rng.next_u32().is_multiple_of(2)
168}
169
170pub fn rng_random_u32(rng: &mut impl RandomNumberGenerator) -> u32 {
171 rng.next_u32()
172}
173
#[cfg(test)]
mod tests {
    use crate::{make_fake_random_number_generator, rng_next_in_closed_range};

    /// Pins the first 100 values drawn from `-50..=50` using the generator
    /// returned by `make_fake_random_number_generator`, so any change to the
    /// sampling sequence is caught.
    #[test]
    fn test_fake_numbers() {
        let mut rng = make_fake_random_number_generator();

        let mut samples = Vec::with_capacity(100);
        for _ in 0..100 {
            samples.push(rng_next_in_closed_range(&mut rng, &(-50..=50)));
        }

        let rendered = format!("{:?}", samples);
        assert_eq!(
            rendered,
            "[-43, -6, 43, -34, -34, 17, -9, 24, 17, -29, -32, -44, 12, -15, -46, 20, 50, -31, -50, 36, -28, -23, 6, -27, -31, -45, -27, 26, 31, -23, 24, 19, -32, 43, -18, -17, 6, -13, -1, -27, 4, -48, -4, -44, -6, 17, -15, 22, 15, 20, -25, -35, -33, -27, -17, -44, -27, 15, -14, -38, -29, -12, 8, 43, 49, -42, -11, -1, -42, -26, -25, 22, -13, 14, 42, -29, -38, 17, 2, 5, 5, -31, 27, -3, 39, -12, 42, 46, -17, -25, -46, -19, 16, 2, -45, 41, 12, -22, 43, -11]"
        );
    }
}