nois/
select_from_weighted.rs

1use rand::distributions::uniform::SampleUniform;
2
3use crate::{int_in_range, integers::Uint};
4
5/// Selects one element from a given weighted list.
6///
7/// In contrast to [`pick`] this does not move the selected element from the input list
8/// but requires elements to be `Clone`able. This is because only one element is needed.
9/// It could be implemented differently though.
10///
11/// The list must not be empty. Each element must have a non-zeo weight.
12/// The total weight must not exceed the u128 range.
13///
14/// ## Examples
15///
16/// Pick 1 hat out of 3 hats with different rarity:
17///
18/// ```
19/// use nois::{randomness_from_str, select_from_weighted};
20///
21/// let randomness = randomness_from_str("9e8e26615f51552aa3b18b6f0bcf0dae5afbe30321e8d7ea7fa51ebeb1d8fe62").unwrap();
22///
23/// let list = vec![
24///     ("green hat", 40u32),
25///     ("viking helmet", 55u32),
26///     ("rare golden crown", 5u32)
27/// ];
28///
29/// let selected = select_from_weighted(randomness, &list).unwrap();
30///
31/// assert_eq!(selected, "viking helmet");
32/// ```
33pub fn select_from_weighted<T: Clone, W: Uint + SampleUniform>(
34    randomness: [u8; 32],
35    list: &[(T, W)],
36) -> Result<T, String> {
37    if list.is_empty() {
38        return Err(String::from("List must not be empty"));
39    }
40
41    let mut total_weight = W::ZERO;
42    for (_, weight) in list {
43        if *weight == W::ZERO {
44            return Err(String::from("All element weights should be >= 1"));
45        }
46        total_weight = total_weight
47            .checked_add(*weight)
48            .ok_or_else(|| String::from("Total weight is greater than maximum value of u32"))?;
49    }
50
51    debug_assert!(
52        total_weight > W::ZERO,
53        "we know we have a non-empty list of non-zero elements"
54    );
55
56    let r = int_in_range::<W>(randomness, W::ONE, total_weight);
57    let mut weight_sum = W::ZERO;
58    for element in list {
59        weight_sum += element.1;
60        if r <= weight_sum {
61            return Ok(element.0.clone());
62        }
63    }
64    // This point should never be reached
65    panic!("No element selected")
66}
67
#[cfg(test)]
mod tests {
    use crate::RANDOMNESS1;

    use super::*;

    /// Selection works for different element types, weight types and for slices.
    /// The expected values are the results for the fixed `RANDOMNESS1` input.
    #[test]
    fn select_from_weighted_works() {
        let elements: Vec<(char, u32)> = vec![('a', 1), ('b', 5), ('c', 4)];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, 'c');

        // Element type is Clone but not Copy
        #[derive(PartialEq, Debug, Clone)]
        struct Color(String);
        let elements = vec![
            (Color("red".into()), 12u32),
            (Color("blue".to_string()), 15u32),
            (Color("green".to_string()), 8u32),
            (Color("orange".to_string()), 21u32),
            (Color("pink".to_string()), 11u32),
        ];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, Color("orange".to_string()));

        // Test for u128
        let elements = vec![
            (Color("red".into()), 12u128),
            (Color("blue".to_string()), 15u128),
            (Color("green".to_string()), 8u128),
            (Color("orange".to_string()), 21u128),
            (Color("pink".to_string()), 11u128),
        ];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, Color("blue".to_string()));

        // Pick from slice
        let selection = &elements[0..3];
        let picked = select_from_weighted(RANDOMNESS1, selection).unwrap();
        assert_eq!(picked, Color("red".to_string()));
    }

    /// An empty list is rejected with an error instead of panicking.
    #[test]
    fn select_from_weighted_fails_on_empty_list() {
        let elements: Vec<(i32, u32)> = vec![];

        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();

        assert_eq!(err, "List must not be empty");
    }

    /// A single zero weight anywhere in the list makes the whole call fail.
    #[test]
    fn select_from_weighted_fails_on_element_weight_less_than_1() {
        // The last element carries the invalid weight 0
        let elements: Vec<(i32, u32)> = vec![(1, 5), (2, 4), (-3, 0)];

        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();

        assert_eq!(err, "All element weights should be >= 1");
    }

    /// A total weight that overflows the weight type is reported as an error.
    #[test]
    fn select_from_weighted_fails_with_total_weight_too_high() {
        // u128::MAX + 1 does not fit into u128
        let elements: Vec<(i32, u128)> = vec![(1, u128::MAX), (2, 1)];

        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();

        // The message names u32 although the weights here are u128; this is the literal
        // text the function returns.
        assert_eq!(err, "Total weight is greater than maximum value of u32");
    }

    /// Draws a large number of sub-randomness values and checks that every element is
    /// selected with a frequency within 1% of the share given by its weight.
    #[test]
    fn select_from_weighted_distribution_is_uniform() {
        use crate::sub_randomness::sub_randomness;
        use std::collections::HashMap;

        const TEST_SAMPLE_SIZE: usize = 1_000_000;
        const ACCURACY: f32 = 0.01;
        // The weights do not need to add up to any particular total. The expected
        // probability of each element is its weight divided by the sum of all weights.
        let elements: Vec<(String, u32)> = vec![
            (String::from("a"), 100),
            (String::from("b"), 200),
            (String::from("c"), 30),
            (String::from("d"), 70),
            (String::from("e"), 600),
        ];
        let total_weight = elements.iter().map(|element| element.1).sum::<u32>();
        println!("total weight: {}", total_weight);

        // Count selections directly instead of buffering all results first
        let mut histogram: HashMap<String, i32> = HashMap::new();
        for subrand in sub_randomness(RANDOMNESS1).take(TEST_SAMPLE_SIZE) {
            let picked = select_from_weighted(subrand, &elements).unwrap();
            *histogram.entry(picked).or_insert(0) += 1;
        }

        // Iterate over the input elements rather than over the histogram: an element that
        // was never selected has no histogram entry and would otherwise be skipped silently.
        for (bin, weight) in &elements {
            let count = histogram.get(bin).copied().unwrap_or(0);
            let probability = *weight as f32 / total_weight as f32;
            let estimated_count_for_uniform_distribution = TEST_SAMPLE_SIZE as f32 * probability;
            let estimation_min: i32 =
                (estimated_count_for_uniform_distribution * (1_f32 - ACCURACY)) as i32;
            let estimation_max: i32 =
                (estimated_count_for_uniform_distribution * (1_f32 + ACCURACY)) as i32;
            println!(
                "estimation {}, max: {}, min: {}",
                estimated_count_for_uniform_distribution, estimation_max, estimation_min
            );
            println!("{}: {}", bin, count);
            assert!(
                (estimation_min..=estimation_max).contains(&count),
                "count {} for element {} is outside of the expected range {}..={}",
                count,
                bin,
                estimation_min,
                estimation_max
            );
        }
    }
}