1use std::io;
7
8const RAND_BUFF_SIZE: usize = 8;
10
11#[derive(Debug)]
13pub struct RandomState {
14 buffer: [u32; RAND_BUFF_SIZE],
15 buf_cnt: usize,
16}
17
18impl Default for RandomState {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl RandomState {
25 pub fn new() -> Self {
26 Self {
27 buffer: [0; RAND_BUFF_SIZE],
28 buf_cnt: 0,
29 }
30 }
31
32 pub fn get_srandom(&mut self) -> u32 {
34 if self.buf_cnt == 0 {
35 let mut bytes = [0u8; RAND_BUFF_SIZE * 4];
36 if fill_random_bytes(&mut bytes).is_ok() {
37 for (i, chunk) in bytes.chunks_exact(4).enumerate() {
38 self.buffer[i] = u32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
39 }
40 }
41 self.buf_cnt = RAND_BUFF_SIZE;
42 }
43 self.buf_cnt -= 1;
44 self.buffer[self.buf_cnt]
45 }
46}
47
/// Fills `buf` with random bytes using the platform's `arc4random_buf`.
///
/// This implementation never reports an error; it always returns `Ok(())`.
#[cfg(target_os = "macos")]
pub fn fill_random_bytes(buf: &mut [u8]) -> io::Result<()> {
    let (dst, len) = (buf.as_mut_ptr(), buf.len());
    // SAFETY: `dst` points to `len` writable bytes, exclusively borrowed from
    // `buf` for the whole call; `arc4random_buf` writes exactly `len` bytes.
    unsafe { libc::arc4random_buf(dst as *mut libc::c_void, len) };
    Ok(())
}
56
57#[cfg(target_os = "linux")]
58pub fn fill_random_bytes(buf: &mut [u8]) -> io::Result<()> {
59 let mut filled = 0;
60
61 while filled < buf.len() {
62 let ret = unsafe {
63 libc::getrandom(
64 buf[filled..].as_mut_ptr() as *mut libc::c_void,
65 buf.len() - filled,
66 0,
67 )
68 };
69
70 if ret < 0 {
71 let err = io::Error::last_os_error();
72 if err.kind() == io::ErrorKind::Interrupted {
73 continue;
74 }
75 return Err(err);
76 }
77
78 filled += ret as usize;
79 }
80
81 Ok(())
82}
83
/// Fills `buf` with random bytes read from `/dev/urandom`.
///
/// Used on every target other than macOS and Linux.
///
/// # Errors
/// Returns any error from opening the device or from reading exactly
/// `buf.len()` bytes out of it.
#[cfg(not(any(target_os = "macos", target_os = "linux")))]
pub fn fill_random_bytes(buf: &mut [u8]) -> io::Result<()> {
    use std::fs::File;
    use std::io::Read;

    File::open("/dev/urandom")?.read_exact(buf)
}
93
94pub fn get_random_u32() -> u32 {
96 let mut buf = [0u8; 4];
97 let _ = fill_random_bytes(&mut buf);
98 u32::from_ne_bytes(buf)
99}
100
101pub fn get_random_u64() -> u64 {
103 let mut buf = [0u8; 8];
104 let _ = fill_random_bytes(&mut buf);
105 u64::from_ne_bytes(buf)
106}
107
108pub fn get_bounded_random(max: u32) -> u32 {
111 if max == 0 {
112 return 0;
113 }
114
115 if max == u32::MAX {
116 return get_random_u32();
117 }
118
119 let mut x = get_random_u32();
120 let mut m = (x as u64) * (max as u64);
121 let mut l = m as u32;
122
123 if l < max {
124 let threshold = (-(max as i64) as u64 % max as u64) as u32;
125 while l < threshold {
126 x = get_random_u32();
127 m = (x as u64) * (max as u64);
128 l = m as u32;
129 }
130 }
131
132 (m >> 32) as u32
133}
134
135pub fn get_bounded_random_buffer(buffer: &mut [u32], max: u32) {
137 for item in buffer.iter_mut() {
138 *item = get_bounded_random(max);
139 }
140}
141
142pub fn zrand_int(upper: Option<i64>, lower: Option<i64>, inclusive: bool) -> Result<i64, String> {
145 let lower = lower.unwrap_or(0);
146 let upper = upper.unwrap_or(u32::MAX as i64);
147
148 if lower < 0 || lower > u32::MAX as i64 {
149 return Err(format!(
150 "Lower bound ({}) out of range: 0-4294967295",
151 lower
152 ));
153 }
154
155 if upper < lower {
156 return Err(format!(
157 "Upper bound ({}) must be greater than Lower Bound ({})",
158 upper, lower
159 ));
160 }
161
162 if upper < 0 || upper > u32::MAX as i64 {
163 return Err(format!(
164 "Upper bound ({}) out of range: 0-4294967295",
165 upper
166 ));
167 }
168
169 let incl = if inclusive { 1 } else { 0 };
170 let diff = (upper - lower + incl) as u32;
171
172 if diff == 0 {
173 return Ok(upper);
174 }
175
176 let r = get_bounded_random(diff);
177 Ok(r as i64 + lower)
178}
179
180pub fn zrand_float() -> f64 {
183 random_real()
184}
185
186pub fn random_real() -> f64 {
188 let x = get_random_u64();
189 (x >> 11) as f64 * (1.0 / (1u64 << 53) as f64)
190}
191
192pub fn random_real_exclusive_zero() -> f64 {
194 let x = get_random_u64();
195 ((x >> 11) as f64 + 0.5) * (1.0 / (1u64 << 53) as f64)
196}
197
198pub fn random_real_inclusive() -> f64 {
200 let x = get_random_u64();
201 (x >> 11) as f64 * (1.0 / ((1u64 << 53) - 1) as f64)
202}
203
204pub fn random_range(min: i64, max: i64) -> i64 {
206 if min >= max {
207 return min;
208 }
209
210 let range = (max - min + 1) as u64;
211
212 if range <= u32::MAX as u64 {
213 min + get_bounded_random(range as u32) as i64
214 } else {
215 let r = get_random_u64() % range;
216 min + r as i64
217 }
218}
219
220pub fn shuffle<T>(slice: &mut [T]) {
222 let n = slice.len();
223 if n <= 1 {
224 return;
225 }
226
227 for i in (1..n).rev() {
228 let j = get_bounded_random((i + 1) as u32) as usize;
229 slice.swap(i, j);
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_random_state() {
        let mut state = RandomState::new();
        let draws = [state.get_srandom(), state.get_srandom(), state.get_srandom()];
        assert!(draws[0] != draws[1] || draws[1] != draws[2]);
    }

    #[test]
    fn test_get_random_u32() {
        let draws = [get_random_u32(), get_random_u32(), get_random_u32()];
        assert!(draws[0] != draws[1] || draws[1] != draws[2]);
    }

    #[test]
    fn test_get_random_u64() {
        let first = get_random_u64();
        let second = get_random_u64();
        assert_ne!(first, second);
    }

    #[test]
    fn test_bounded_random() {
        for r in (0..100).map(|_| get_bounded_random(10)) {
            assert!(r < 10);
        }
    }

    #[test]
    fn test_bounded_random_one() {
        assert!((0..10).all(|_| get_bounded_random(1) == 0));
    }

    #[test]
    fn test_zrand_int() {
        let exclusive = zrand_int(Some(100), Some(50), false).unwrap();
        assert!((50..100).contains(&exclusive));

        let inclusive = zrand_int(Some(100), Some(50), true).unwrap();
        assert!((50..=100).contains(&inclusive));
    }

    #[test]
    fn test_zrand_int_no_args() {
        let r = zrand_int(None, None, false).unwrap();
        assert!(r >= 0);
    }

    #[test]
    fn test_zrand_int_errors() {
        assert!(zrand_int(Some(50), Some(100), false).is_err());
        assert!(zrand_int(Some(-1), None, false).is_err());
    }

    #[test]
    fn test_zrand_float() {
        assert!((0..100).map(|_| zrand_float()).all(|r| (0.0..1.0).contains(&r)));
    }

    #[test]
    fn test_random_real() {
        assert!((0..100).map(|_| random_real()).all(|r| (0.0..1.0).contains(&r)));
    }

    #[test]
    fn test_random_range() {
        for r in (0..100).map(|_| random_range(10, 20)) {
            assert!((10..=20).contains(&r));
        }
    }

    #[test]
    fn test_shuffle() {
        let original: Vec<i32> = (1..=10).collect();
        let mut arr = original.clone();
        shuffle(&mut arr);
        arr.sort();
        assert_eq!(arr, original);
    }

    #[test]
    fn test_fill_random_bytes() {
        let mut buf = [0u8; 32];
        fill_random_bytes(&mut buf).unwrap();
        assert!(buf.iter().any(|&b| b != 0));
    }
}