use rand::prelude::*;
/// Chooses `amount` indices into `weights` using stochastic universal sampling (SUS).
///
/// The running sums of `weights` split `[0, total)` into one half-open interval per
/// index. `amount` evenly spaced "arms" with a single shared random offset are laid
/// over that range, and each arm selects the index whose interval it lands in. As a
/// result, an index with weight `w` is returned either `floor(amount * w / total)` or
/// `ceil(amount * w / total)` times (up to floating-point rounding). When the weighted
/// path is taken, an index with zero weight is never returned.
///
/// If the total weight is (nearly) zero, that is at most `f64::EPSILON * weights.len()`,
/// the weights are ignored and indices are spread as evenly as possible. Every index
/// then appears `amount / weights.len()` or `amount / weights.len() + 1` times.
///
/// The returned vector has length `amount`, may contain repeated indices, and is
/// shuffled, so its order carries no information. `amount == 0` returns an empty vector.
///
/// # Panics
///
/// * if `amount > 0` and `weights` is empty;
/// * if the sum of `weights` is not finite (NaN or infinite);
/// * in debug builds, if any weight is negative or NaN.
pub fn choose_multiple_weighted<R>(rng: &mut R, amount: usize, weights: &[f64]) -> Vec<usize>
where
    R: Rng + ?Sized,
{
    if amount == 0 {
        return vec![];
    }
    assert!(!weights.is_empty());

    // Inclusive prefix sums: index `i` owns the half-open interval
    // [cumulative[i - 1], cumulative[i]), with cumulative[-1] taken as 0.
    let cumulative: Vec<f64> = weights
        .iter()
        .scan(0.0, |sum, &w| {
            debug_assert!(w >= 0.0);
            *sum += w;
            Some(*sum)
        })
        .collect();
    let total_weight = *cumulative
        .last()
        .expect("weights is non-empty, so its prefix sums are too");

    // Degenerate case: no usable weight, so fall back to an even spread of indices.
    if total_weight <= f64::EPSILON * weights.len() as f64 {
        let len = weights.len();
        let mut results = Vec::with_capacity(amount);
        while results.len() < amount {
            let remaining = amount - results.len();
            if remaining >= len {
                results.extend(0..len);
            } else {
                // Distinct indices, so no index gets more than one extra copy.
                results.extend(rand::seq::index::sample(rng, len, remaining));
            }
        }
        results.shuffle(rng);
        return results;
    }
    assert!(total_weight.is_finite());

    // The first prefix sum that reaches the total belongs to the last index with positive
    // weight. Arms must never walk past it, even when rounding in `k * spacing + offset`
    // lands an arm at or beyond `total_weight`. Without this bound the walk could run to
    // `weights.len()` and return an out-of-range index.
    let last_positive = cumulative.partition_point(|&c| c < total_weight);

    let arm_spacing = total_weight / (amount as f64);
    // `gen::<f64>()` is in [0, 1), so the offset is in [0, arm_spacing).
    let arm_offset = rng.gen::<f64>() * arm_spacing;
    let mut samples = Vec::with_capacity(amount);
    let mut idx = 0;
    for arm in 0..amount {
        let position = (arm as f64) * arm_spacing + arm_offset;
        // Skip every index whose interval ends at or before `position`. `<=` (not `<`)
        // matches the half-open intervals, so a zero-weight index (an empty interval) is
        // never chosen, even when `arm_offset` is exactly 0.
        while idx < last_positive && cumulative[idx] <= position {
            idx += 1;
        }
        samples.push(idx);
    }
    samples.shuffle(rng);
    samples
}
#[cfg(test)]
mod tests {
    use super::choose_multiple_weighted as sus;
    use std::collections::HashSet;

    /// Asserts that `a` and `b` hold the same multiset of indices, ignoring order.
    fn assert_data_eq(a: &mut [usize], b: &mut [usize]) {
        a.sort_unstable();
        b.sort_unstable();
        assert_eq!(a, b);
    }

    #[test]
    fn no_data() {
        let mut rng = rand::thread_rng();
        assert_data_eq(&mut sus(&mut rng, 0, &[]), &mut []);
        assert_data_eq(&mut sus(&mut rng, 0, &[1.0, 2.0, 3.0]), &mut []);
    }

    #[test]
    #[should_panic]
    fn no_data_panic() {
        let mut rng = rand::thread_rng();
        sus(&mut rng, 100, &[]);
    }

    #[test]
    fn not_enough_data() {
        let mut rng = rand::thread_rng();
        assert_data_eq(&mut sus(&mut rng, 2, &[1.0]), &mut [0, 0]);
    }

    #[test]
    fn zero_data() {
        let mut rng = rand::thread_rng();
        assert_data_eq(&mut sus(&mut rng, 1, &[0.0]), &mut [0]);
        assert_data_eq(
            &mut sus(&mut rng, 10, &[0.0; 10]),
            &mut [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        );
        assert_data_eq(&mut sus(&mut rng, 6, &[0.0; 3]), &mut [0, 0, 1, 1, 2, 2]);
        // 7 picks over 3 zero-weight items: every index appears 2 or 3 times.
        let uneven = sus(&mut rng, 7, &[0.0; 3]);
        assert_eq!(uneven.len(), 7);
        for idx in 0..3 {
            let count = uneven.iter().filter(|&&x| x == idx).count();
            assert!(count == 2 || count == 3, "index {idx} appeared {count} times");
        }
    }

    #[test]
    fn round_robin() {
        let mut rng = rand::thread_rng();
        assert_data_eq(&mut sus(&mut rng, 3, &[1.0; 3]), &mut [0, 1, 2]);
        assert_data_eq(&mut sus(&mut rng, 6, &[1.0; 3]), &mut [0, 1, 2, 0, 1, 2]);
    }

    #[test]
    fn it_works() {
        let mut rng = rand::thread_rng();
        assert_data_eq(&mut sus(&mut rng, 2, &[1.0, 0.0, 1.0]), &mut [0, 2]);
        assert_data_eq(&mut sus(&mut rng, 3, &[2.0, 0.0, 1.0]), &mut [0, 0, 2]);
        assert_data_eq(&mut sus(&mut rng, 3, &[1.0, 0.0, 0.5]), &mut [0, 0, 2]);
        assert_data_eq(
            &mut sus(&mut rng, 6, &[1.0, 2.0, 3.0]),
            &mut [0, 1, 1, 2, 2, 2],
        );
    }

    #[test]
    fn never_selects_zero_weight() {
        let mut rng = rand::thread_rng();
        let weights = [0.0, 1.0, 0.0, 2.5, 0.0];
        for amount in 1..50 {
            let picks = sus(&mut rng, amount, &weights);
            assert_eq!(picks.len(), amount);
            assert!(picks.iter().all(|&i| i == 1 || i == 3), "{picks:?}");
        }
    }

    #[test]
    fn sample_one() {
        let mut rng = rand::thread_rng();
        let mut data = [0.0; 10000];
        data[1234] = 0.0000001;
        assert_data_eq(&mut sus(&mut rng, 1, &data), &mut [1234]);
    }

    #[test]
    fn random_data() {
        let mut rng = rand::thread_rng();
        // Two single draws from 10_000 equal weights coincide with probability 1e-4, which
        // makes a pairwise `!=` check flaky. Many draws all coinciding is negligible.
        let uniform = vec![1.0; 10000];
        let picks: HashSet<usize> = (0..20).map(|_| sus(&mut rng, 1, &uniform)[0]).collect();
        assert!(picks.len() > 1);
        assert!(sus(&mut rng, 40, &[1.0; 2000]) != sus(&mut rng, 40, &[1.0; 2000]));
    }

    #[test]
    fn random_order() {
        let mut rng = rand::thread_rng();
        let mut a = sus(&mut rng, 2000, &[1.0; 2000]);
        let mut b = sus(&mut rng, 2000, &[1.0; 2000]);
        assert!(a != b);
        assert_data_eq(&mut a, &mut b);
    }

    #[test]
    fn random_order_repeats() {
        let mut rng = rand::thread_rng();
        for _ in 0..100 {
            let mut a = sus(&mut rng, 100, &[1.0; 2]);
            let mut b = sus(&mut rng, 100, &[1.0; 2]);
            assert!(a != b);
            assert_data_eq(&mut a, &mut b);
        }
    }

    #[test]
    #[ignore]
    fn benchmark() {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let amount = 1000;
        let num_weights = 1_000_000;
        let weights: Vec<f64> = (0..num_weights).map(|_| rng.gen()).collect();
        println!("Running SUS(amount: {amount}, num_weights: {num_weights}) ...",);
        std::thread::yield_now();
        let start_time = std::time::Instant::now();
        std::hint::black_box(sus(&mut rng, amount, &weights));
        let elapsed_time = start_time.elapsed();
        println!("Elapsed time: {elapsed_time:?}");
    }
}