use rand::distributions::uniform::SampleUniform;
use crate::{int_in_range, integers::Uint};
/// Selects one element from a weighted list, driven by the given randomness.
///
/// Each entry of `list` is an `(element, weight)` pair. The weights are summed
/// into a total, a value `r` is obtained from
/// `int_in_range::<W>(randomness, W::ONE, total)`, and the first element whose
/// cumulative weight reaches `r` is cloned and returned. Assuming
/// `int_in_range` yields a uniformly distributed value in the inclusive range
/// `[W::ONE, total]` (to be confirmed against its definition), every element is
/// chosen with a probability proportional to its weight.
///
/// # Errors
///
/// - `"List must not be empty"` if `list` has no entries.
/// - `"All element weights should be >= 1"` if a weight equals `W::ZERO`.
/// - `"Total weight is greater than maximum value of u32"` if summing the
///   weights overflows `W`. The text mentions `u32` regardless of the actual
///   weight type `W`.
///
/// Weights are validated front to back, so the first offending entry decides
/// which of the last two errors is reported.
///
/// # Panics
///
/// Panics with `"No element selected"` if no cumulative weight reaches `r`,
/// which would require `r` to exceed the total weight.
pub fn select_from_weighted<T: Clone, W: Uint + SampleUniform>(
    randomness: [u8; 32],
    list: &[(T, W)],
) -> Result<T, String> {
    if list.is_empty() {
        return Err(String::from("List must not be empty"));
    }

    // Single validation pass: reject zero weights and accumulate the total with
    // overflow detection, stopping at the first problem encountered.
    let total_weight = list.iter().try_fold(W::ZERO, |acc, (_, weight)| {
        if *weight == W::ZERO {
            return Err(String::from("All element weights should be >= 1"));
        }
        acc.checked_add(*weight)
            .ok_or_else(|| String::from("Total weight is greater than maximum value of u32"))
    })?;

    debug_assert!(
        total_weight > W::ZERO,
        "we know we have a non-empty list of non-zero elements"
    );

    let r = int_in_range::<W>(randomness, W::ONE, total_weight);

    // Walk the running sum of weights; the first entry whose cumulative weight
    // is at least `r` is the selected one. This cannot overflow because the
    // full sum was already checked above.
    let mut cumulative = W::ZERO;
    for (item, weight) in list {
        cumulative += *weight;
        if cumulative >= r {
            return Ok(item.clone());
        }
    }

    // Only reachable if `r` is larger than the total weight.
    panic!("No element selected")
}
#[cfg(test)]
mod tests {
    use crate::RANDOMNESS1;

    use super::*;

    /// Pins the element selected for `RANDOMNESS1` with different element and
    /// weight types, and with a sub-slice of a list.
    #[test]
    fn select_from_weighted_works() {
        #[derive(PartialEq, Debug, Clone)]
        struct Color(String);

        let elements: Vec<(char, u32)> = vec![('a', 1), ('b', 5), ('c', 4)];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, 'c');

        // Custom element type with u32 weights
        let elements = vec![
            (Color("red".into()), 12u32),
            (Color("blue".to_string()), 15u32),
            (Color("green".to_string()), 8u32),
            (Color("orange".to_string()), 21u32),
            (Color("pink".to_string()), 11u32),
        ];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, Color("orange".to_string()));

        // Same list with u128 weights
        let elements = vec![
            (Color("red".into()), 12u128),
            (Color("blue".to_string()), 15u128),
            (Color("green".to_string()), 8u128),
            (Color("orange".to_string()), 21u128),
            (Color("pink".to_string()), 11u128),
        ];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, Color("blue".to_string()));

        // Works on slices of a larger list
        let selection = &elements[0..3];
        let picked = select_from_weighted(RANDOMNESS1, selection).unwrap();
        assert_eq!(picked, Color("red".to_string()));
    }

    /// An empty list is rejected with an error.
    #[test]
    fn select_from_weighted_fails_on_empty_list() {
        let elements: Vec<(i32, u32)> = vec![];
        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();
        assert_eq!(err, "List must not be empty");
    }

    /// A zero weight anywhere in the list is rejected with an error.
    #[test]
    fn select_from_weighted_fails_on_element_weight_less_than_1() {
        let elements: Vec<(i32, u32)> = vec![(1, 5), (2, 4), (-3, 0)];
        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();
        assert_eq!(err, "All element weights should be >= 1");
    }

    /// A weight sum that overflows the weight type is rejected with an error.
    #[test]
    fn select_from_weighted_fails_with_total_weight_too_high() {
        let elements: Vec<(i32, u128)> = vec![(1, u128::MAX), (2, 1)];
        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();
        assert_eq!(err, "Total weight is greater than maximum value of u32");
    }

    /// Samples many selections and checks that every element is picked with a
    /// frequency within `ACCURACY` of its weight share.
    #[test]
    fn select_from_weighted_distribution_is_uniform() {
        use crate::sub_randomness::sub_randomness;
        use std::collections::HashMap;

        const TEST_SAMPLE_SIZE: usize = 1_000_000;
        const ACCURACY: f32 = 0.01;

        let elements: Vec<(String, u32)> = vec![
            (String::from("a"), 100),
            (String::from("b"), 200),
            (String::from("c"), 30),
            (String::from("d"), 70),
            (String::from("e"), 600),
        ];
        let total_weight = elements.iter().map(|element| element.1).sum::<u32>();
        println!("total weight: {}", total_weight);

        // Count the picks directly instead of buffering every result first.
        let mut histogram: HashMap<String, i32> = HashMap::new();
        for subrand in sub_randomness(RANDOMNESS1).take(TEST_SAMPLE_SIZE) {
            let picked = select_from_weighted(subrand, &elements).unwrap();
            *histogram.entry(picked).or_insert(0) += 1;
        }

        // Iterate over the input elements rather than the histogram so that an
        // element which was never selected is checked with a count of 0 instead
        // of being silently skipped.
        for (bin, weight) in &elements {
            let count = histogram.get(bin).copied().unwrap_or(0);
            let probability = *weight as f32 / total_weight as f32;
            let estimated_count_for_uniform_distribution = TEST_SAMPLE_SIZE as f32 * probability;
            let estimation_min: i32 =
                (estimated_count_for_uniform_distribution * (1_f32 - ACCURACY)) as i32;
            let estimation_max: i32 =
                (estimated_count_for_uniform_distribution * (1_f32 + ACCURACY)) as i32;
            println!(
                "estimation {}, max: {}, min: {}",
                estimated_count_for_uniform_distribution, estimation_max, estimation_min
            );
            println!("{}: {}", bin, count);
            assert!(
                count >= estimation_min && count <= estimation_max,
                "count for {:?} is {}, expected it within [{}, {}]",
                bin,
                count,
                estimation_min,
                estimation_max
            );
        }
    }
}