nois/
select_from_weighted.rs1use rand::distributions::uniform::SampleUniform;
2
3use crate::{int_in_range, integers::Uint};
4
5pub fn select_from_weighted<T: Clone, W: Uint + SampleUniform>(
34 randomness: [u8; 32],
35 list: &[(T, W)],
36) -> Result<T, String> {
37 if list.is_empty() {
38 return Err(String::from("List must not be empty"));
39 }
40
41 let mut total_weight = W::ZERO;
42 for (_, weight) in list {
43 if *weight == W::ZERO {
44 return Err(String::from("All element weights should be >= 1"));
45 }
46 total_weight = total_weight
47 .checked_add(*weight)
48 .ok_or_else(|| String::from("Total weight is greater than maximum value of u32"))?;
49 }
50
51 debug_assert!(
52 total_weight > W::ZERO,
53 "we know we have a non-empty list of non-zero elements"
54 );
55
56 let r = int_in_range::<W>(randomness, W::ONE, total_weight);
57 let mut weight_sum = W::ZERO;
58 for element in list {
59 weight_sum += element.1;
60 if r <= weight_sum {
61 return Ok(element.0.clone());
62 }
63 }
64 panic!("No element selected")
66}
67
#[cfg(test)]
mod tests {
    use crate::RANDOMNESS1;

    use super::*;

    /// Happy path: selection works for different element types, weight types
    /// and for sub-slices of a list.
    #[test]
    fn select_from_weighted_works() {
        // Simple elements with u32 weights
        let chars: Vec<(char, u32)> = vec![('a', 1), ('b', 5), ('c', 4)];
        assert_eq!(select_from_weighted(RANDOMNESS1, &chars).unwrap(), 'c');

        // A custom (non-Copy) element type
        #[derive(PartialEq, Debug, Clone)]
        struct Color(String);
        let color = |name: &str| Color(name.to_string());

        let colors_u32 = vec![
            (color("red"), 12u32),
            (color("blue"), 15u32),
            (color("green"), 8u32),
            (color("orange"), 21u32),
            (color("pink"), 11u32),
        ];
        assert_eq!(
            select_from_weighted(RANDOMNESS1, &colors_u32).unwrap(),
            color("orange")
        );

        // Same elements and weights, but with u128 as the weight type
        let colors_u128 = vec![
            (color("red"), 12u128),
            (color("blue"), 15u128),
            (color("green"), 8u128),
            (color("orange"), 21u128),
            (color("pink"), 11u128),
        ];
        assert_eq!(
            select_from_weighted(RANDOMNESS1, &colors_u128).unwrap(),
            color("blue")
        );

        // A slice of the list can be passed directly
        let first_three = &colors_u128[0..3];
        assert_eq!(
            select_from_weighted(RANDOMNESS1, first_three).unwrap(),
            color("red")
        );
    }

    /// An empty list is rejected with an error instead of panicking.
    #[test]
    fn select_from_weighted_fails_on_empty_list() {
        let no_elements: Vec<(i32, u32)> = Vec::new();

        let outcome = select_from_weighted(RANDOMNESS1, &no_elements);

        assert_eq!(outcome.unwrap_err(), "List must not be empty");
    }

    /// A single zero weight anywhere in the list makes the call fail.
    #[test]
    fn select_from_weighted_fails_on_element_weight_less_than_1() {
        let with_zero_weight: Vec<(i32, u32)> = vec![(1, 5), (2, 4), (-3, 0)];

        let outcome = select_from_weighted(RANDOMNESS1, &with_zero_weight);

        assert_eq!(outcome.unwrap_err(), "All element weights should be >= 1");
    }

    /// Weights whose sum does not fit into the weight type are rejected.
    #[test]
    fn select_from_weighted_fails_with_total_weight_too_high() {
        let overflowing: Vec<(i32, u128)> = vec![(1, u128::MAX), (2, 1)];

        let outcome = select_from_weighted(RANDOMNESS1, &overflowing);

        assert_eq!(
            outcome.unwrap_err(),
            "Total weight is greater than maximum value of u32"
        );
    }

    /// Draws many samples and checks that each element's pick count stays
    /// within `ACCURACY` of the count expected from its share of the total weight.
    #[test]
    fn select_from_weighted_distribution_is_uniform() {
        use crate::sub_randomness::sub_randomness;
        use std::collections::HashMap;

        const TEST_SAMPLE_SIZE: usize = 1_000_000;
        const ACCURACY: f32 = 0.01;

        let elements: Vec<(String, u32)> = vec![
            ("a".to_string(), 100),
            ("b".to_string(), 200),
            ("c".to_string(), 30),
            ("d".to_string(), 70),
            ("e".to_string(), 600),
        ];
        let total_weight: u32 = elements.iter().map(|(_, weight)| *weight).sum();
        println!("total weight: {}", total_weight);

        // Count how often each element gets picked
        let mut histogram: HashMap<String, i32> = HashMap::new();
        for subrand in sub_randomness(RANDOMNESS1).take(TEST_SAMPLE_SIZE) {
            let picked = select_from_weighted(subrand, &elements).unwrap();
            *histogram.entry(picked).or_insert(0) += 1;
        }

        // Compare every observed count against its expected window
        for (bin, count) in histogram {
            let weight = elements
                .iter()
                .find(|(name, _)| *name == bin)
                .map(|(_, weight)| *weight)
                .unwrap();
            let probability = weight as f32 / total_weight as f32;
            let estimated_count_for_uniform_distribution = TEST_SAMPLE_SIZE as f32 * probability;
            let lower_factor = 1_f32 - ACCURACY;
            let upper_factor = 1_f32 + ACCURACY;
            let estimation_min: i32 = (estimated_count_for_uniform_distribution * lower_factor) as i32;
            let estimation_max: i32 = (estimated_count_for_uniform_distribution * upper_factor) as i32;
            println!(
                "estimation {}, max: {}, min: {}",
                estimated_count_for_uniform_distribution, estimation_max, estimation_min
            );
            println!("{}: {}", bin, count);
            assert!(count >= estimation_min && count <= estimation_max);
        }
    }
}