1#[cfg(feature = "distributions")]
4use crate::distributions::{Distribution, Uniform};
5use crate::traits::{Rng, RngExt};
6
/// Draws a random `u32` between `low` and `high` from a `Uniform` distribution.
///
/// Presumably the same half-open `[low, high)` range as the non-`distributions` fallback;
/// how an empty or inverted range is handled is decided by `Uniform::new_u32`.
#[cfg(feature = "distributions")]
pub fn gen_range_u32<R: Rng>(rng: &mut R, low: u32, high: u32) -> u32 {
    Uniform::new_u32(low, high).sample(rng)
}
13
14#[cfg(not(feature = "distributions"))]
16pub fn gen_range_u32<R: Rng>(rng: &mut R, low: u32, high: u32) -> u32 {
17 if low >= high {
18 panic!("gen_range_u32: low must be less than high");
19 }
20 let range = high - low;
21 let random_value = rng.next_u32() % range;
22 low + random_value
23}
24
/// Draws a random `u64` between `low` and `high` from a `Uniform` distribution.
///
/// Presumably the same half-open `[low, high)` range as the non-`distributions` fallback;
/// how an empty or inverted range is handled is decided by `Uniform::new_u64`.
#[cfg(feature = "distributions")]
pub fn gen_range_u64<R: Rng>(rng: &mut R, low: u64, high: u64) -> u64 {
    Uniform::new_u64(low, high).sample(rng)
}
31
32#[cfg(not(feature = "distributions"))]
34pub fn gen_range_u64<R: Rng>(rng: &mut R, low: u64, high: u64) -> u64 {
35 if low >= high {
36 panic!("gen_range_u64: low must be less than high");
37 }
38 let range = high - low;
39 let random_value = rng.next_u64() % range;
40 low + random_value
41}
42
/// Draws a random `f32` between `low` and `high` from a `Uniform` distribution.
///
/// Range semantics and the handling of invalid bounds are decided by `Uniform::new_f32`.
#[cfg(feature = "distributions")]
pub fn gen_range_f32<R: Rng>(rng: &mut R, low: f32, high: f32) -> f32 {
    Uniform::new_f32(low, high).sample(rng)
}
49
50#[cfg(not(feature = "distributions"))]
52pub fn gen_range_f32<R: Rng>(rng: &mut R, low: f32, high: f32) -> f32 {
53 if low >= high {
54 panic!("gen_range_f32: low must be less than high");
55 }
56 let range = high - low;
57 let random_value = rng.gen_f32() * range;
58 low + random_value
59}
60
/// Draws a random `f64` between `low` and `high` from a `Uniform` distribution.
///
/// Range semantics and the handling of invalid bounds are decided by `Uniform::new_f64`.
#[cfg(feature = "distributions")]
pub fn gen_range_f64<R: Rng>(rng: &mut R, low: f64, high: f64) -> f64 {
    Uniform::new_f64(low, high).sample(rng)
}
67
68#[cfg(not(feature = "distributions"))]
70pub fn gen_range_f64<R: Rng>(rng: &mut R, low: f64, high: f64) -> f64 {
71 if low >= high {
72 panic!("gen_range_f64: low must be less than high");
73 }
74 let range = high - low;
75 let random_value = rng.gen_f64() * range;
76 low + random_value
77}
78
79#[cfg(not(feature = "distributions"))]
81pub fn gen_range<R: Rng>(rng: &mut R, low: u64, high: u64) -> u64 {
82 if low >= high {
83 panic!("gen_range: low must be less than high");
84 }
85 let range = high - low;
86 let random_value = rng.next_u64() % range;
87 low + random_value
88}
89
90pub fn fill_bytes<R: Rng>(rng: &mut R, buf: &mut [u8]) {
92 rng.fill_bytes(buf);
93}
94
95pub fn gen_f32<R: Rng>(rng: &mut R) -> f32 {
97 rng.gen_f32()
98}
99
100pub fn gen_f64<R: Rng>(rng: &mut R) -> f64 {
102 rng.gen_f64()
103}
104
105pub fn shuffle<R: Rng, T>(rng: &mut R, slice: &mut [T]) {
107 for i in (1..slice.len()).rev() {
108 let j = gen_range_u32(rng, 0, (i + 1) as u32) as usize;
109 slice.swap(i, j);
110 }
111}
112
/// Returns `n` clones of distinct elements of `slice`, chosen at random.
///
/// When `n` is smaller than the slice, a full permutation of the indices is produced with
/// `shuffle`, and the elements at its first `n` positions are cloned in that order. When `n`
/// is at least the slice length, the whole slice is cloned in its original order.
#[cfg(feature = "std")]
pub fn sample<R: Rng, T: Clone>(rng: &mut R, slice: &[T], n: usize) -> Vec<T> {
    let len = slice.len();
    if n < len {
        let mut order: Vec<usize> = (0..len).collect();
        shuffle(rng, &mut order);
        order
            .into_iter()
            .take(n)
            .map(|idx| slice[idx].clone())
            .collect()
    } else {
        slice.to_vec()
    }
}
124
125pub fn choose<'a, R: Rng, T>(rng: &mut R, slice: &'a [T]) -> Option<&'a T> {
127 if slice.is_empty() {
128 return None;
129 }
130 #[cfg(feature = "distributions")]
131 {
132 let index: usize = gen_range_u64(rng, 0, slice.len() as u64) as usize;
133 Some(&slice[index])
134 }
135 #[cfg(not(feature = "distributions"))]
136 {
137 let len = slice.len() as u64;
138 let random_value = rng.next_u64() % len;
139 let index = random_value as usize;
140 Some(&slice[index])
141 }
142}
143
144pub fn weighted_choose<'a, R: Rng, T>(
146 rng: &mut R,
147 items: &'a [T],
148 weights: &[f64],
149) -> Option<&'a T> {
150 if items.is_empty() || items.len() != weights.len() {
151 return None;
152 }
153
154 let total_weight: f64 = weights.iter().sum();
155 if total_weight <= 0.0 {
156 return None;
157 }
158
159 let mut r = gen_f64(rng) * total_weight;
160 for (item, &weight) in items.iter().zip(weights.iter()) {
161 r -= weight;
162 if r <= 0.0 {
163 return Some(item);
164 }
165 }
166
167 items.last()
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fast::xoshiro256::Xoshiro256Plus;

    #[test]
    fn test_gen_range_u64() {
        let mut rng = Xoshiro256Plus::new(42);
        for _ in 0..100 {
            let x = gen_range_u64(&mut rng, 10, 20);
            assert!((10..20).contains(&x));
        }
    }

    #[test]
    fn test_gen_range_u32() {
        let mut rng = Xoshiro256Plus::new(42);
        for _ in 0..100 {
            let x = gen_range_u32(&mut rng, 3, 7);
            assert!((3..7).contains(&x));
        }
    }

    #[test]
    fn test_gen_range_f64() {
        let mut rng = Xoshiro256Plus::new(42);
        for _ in 0..100 {
            let x = gen_range_f64(&mut rng, -1.5, 2.5);
            assert!((-1.5..2.5).contains(&x));
        }
    }

    // Arrays instead of `vec!` let the slice-based tests run without the `std` feature.
    #[test]
    fn test_shuffle() {
        let mut rng = Xoshiro256Plus::new(42);
        let original = [1, 2, 3, 4, 5];
        let mut shuffled = original;
        shuffle(&mut rng, &mut shuffled);
        // Sort and compare: proves a true permutation. A membership check alone would miss
        // duplicated or dropped elements.
        shuffled.sort_unstable();
        assert_eq!(shuffled, original);
    }

    #[test]
    fn test_shuffle_trivial_slices() {
        let mut rng = Xoshiro256Plus::new(42);
        let mut empty: [u8; 0] = [];
        shuffle(&mut rng, &mut empty);
        let mut single = [9];
        shuffle(&mut rng, &mut single);
        assert_eq!(single, [9]);
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_sample() {
        let mut rng = Xoshiro256Plus::new(42);
        let items = [1, 2, 3, 4, 5];
        let mut sampled = sample(&mut rng, &items, 3);

        assert_eq!(sampled.len(), 3);
        assert!(sampled.iter().all(|x| items.contains(x)));
        // Indices are drawn without replacement, so no element may repeat.
        sampled.sort_unstable();
        sampled.dedup();
        assert_eq!(sampled.len(), 3);

        // Asking for at least as many elements as exist returns the whole slice in order.
        assert_eq!(sample(&mut rng, &items, 10), items);
    }

    #[test]
    fn test_choose() {
        let mut rng = Xoshiro256Plus::new(42);
        let items = [1, 2, 3, 4, 5];
        let chosen = choose(&mut rng, &items).expect("non-empty slice yields an element");
        assert!(items.contains(chosen));

        let empty: [i32; 0] = [];
        assert!(choose(&mut rng, &empty).is_none());
    }

    #[test]
    fn test_weighted_choose() {
        let mut rng = Xoshiro256Plus::new(42);
        let items = ["a", "b", "c"];
        let weights = [0.5, 0.3, 0.2];
        let chosen =
            weighted_choose(&mut rng, &items, &weights).expect("valid weights yield an item");
        assert!(items.contains(chosen));

        // With all the mass on "b", it is the only possible pick.
        for _ in 0..10 {
            assert_eq!(
                weighted_choose(&mut rng, &items, &[0.0, 1.0, 0.0]),
                Some(&"b")
            );
        }
    }

    #[test]
    fn test_weighted_choose_rejects_invalid_input() {
        let mut rng = Xoshiro256Plus::new(42);
        let items = ["a", "b", "c"];
        assert!(weighted_choose(&mut rng, &items, &[1.0, 1.0]).is_none());
        assert!(weighted_choose(&mut rng, &items, &[0.0, 0.0, 0.0]).is_none());
        let empty: [&str; 0] = [];
        assert!(weighted_choose(&mut rng, &empty, &[]).is_none());
    }
}