dynamic_weighted_index/
weight.rs1use crate::numeric::FloatingPointParts;
2use rand::Rng;
3use std::fmt::Debug;
4use std::ops::{Add, AddAssign, Sub, SubAssign};
5
6pub trait Weight:
7 Copy
8 + Debug
9 + Sized
10 + SubAssign
11 + PartialEq
12 + PartialOrd
13 + Default
14 + Add<Output = Self>
15 + Sub<Output = Self>
16 + AddAssign
17 + rand::distributions::uniform::SampleUniform
18{
19 const NUM_RANGES: u32;
20 const ZERO: Self;
21
22 fn compute_range_index(&self) -> u32;
23 fn max_weight_of_range_index(index: u32) -> Self;
24 fn sample_from_ratio<R: Rng + ?Sized>(rng: &mut R, numerator: Self, denominator: Self) -> bool;
25}
26
// Implements `Weight` for the unsigned integer type `$t` and emits a unit-test
// module named `$test_name` covering that implementation.
//
// Range `i` holds all values whose highest set bit is bit `i`, i.e.
// `2^i ..= 2^(i + 1) - 1`; hence there is exactly one range per bit of `$t`.
macro_rules! weight_impl_int {
    ($t:ty, $test_name:ident) => {
        impl Weight for $t {
            const NUM_RANGES: u32 = <$t>::BITS;
            const ZERO: Self = 0;

            // Position of the highest set bit. Zero has no set bit and is
            // rejected in debug builds.
            fn compute_range_index(&self) -> u32 {
                debug_assert!(*self > 0);
                Self::NUM_RANGES - 1 - self.leading_zeros()
            }

            // Largest value of range `index`, i.e. `2^(index + 1) - 1`.
            fn max_weight_of_range_index(index: u32) -> Self {
                debug_assert!(index < Self::NUM_RANGES);
                // For the last range the shift pushes the only set bit out and
                // yields 0; the wrapping subtraction then produces `MAX`.
                ((2 as $t) << index).wrapping_sub(1)
            }

            // Returns `true` with probability `numerator / denominator` by
            // drawing uniformly from `0..denominator`. `denominator` must be
            // non-zero, since the sampled range would otherwise be empty.
            fn sample_from_ratio<R: Rng + ?Sized>(
                rng: &mut R,
                numerator: Self,
                denominator: Self,
            ) -> bool {
                debug_assert!(numerator <= denominator);
                rng.gen_range(0..denominator) < numerator
            }
        }

        #[cfg(test)]
        mod $test_name {
            use super::Weight;
            use pcg_rand::Pcg64;
            use rand::{Rng, SeedableRng};

            // Halving / doubling a value must move it exactly one range
            // down / up.
            #[test]
            fn compute_range_index_rand() {
                let mut rng = Pcg64::seed_from_u64(0x1234);
                for _ in 0..1000 {
                    let num: $t = rng.gen();

                    if num == 0 {
                        continue;
                    }

                    let range = num.compute_range_index();

                    if num > 1 {
                        assert_eq!(range.saturating_sub(1), (num / 2).compute_range_index());
                    }

                    // Only double when the result is still representable.
                    if num <= <$t>::MAX / 2 {
                        assert_eq!(range + 1, (num * 2).compute_range_index());
                    }
                }
            }

            // The maximum weight of a range belongs to that range.
            #[test]
            fn compute_range_index_max() {
                for i in 0..<$t>::NUM_RANGES {
                    let max_weight = <$t>::max_weight_of_range_index(i);
                    assert_eq!(i, max_weight.compute_range_index());
                }
            }

            // The successor of a range's maximum weight is the first value of
            // the next range. The last range is skipped because its maximum is
            // `MAX` and has no successor.
            #[test]
            fn compute_range_index_max2() {
                for i in 0..(<$t>::NUM_RANGES - 1) {
                    let max_weight = <$t>::max_weight_of_range_index(i);
                    assert_eq!(i + 1, (max_weight + 1).compute_range_index());
                }
            }

            // `MAX` must fall into the last range.
            #[test]
            fn num_ranges() {
                assert_eq!(<$t>::MAX.compute_range_index() + 1, <$t>::NUM_RANGES);
            }
        }
    };
}
104
105weight_impl_int!(u8, test_u8);
106weight_impl_int!(u16, test_u16);
107weight_impl_int!(u32, test_u32);
108weight_impl_int!(u64, test_u64);
109weight_impl_int!(u128, test_u128);
110weight_impl_int!(usize, test_usize);
111
// Implements `Weight` for the floating-point type `$t` and emits a unit-test
// module named `$test_name` covering that implementation.
//
// `$m` is the unsigned integer type used for the float's raw bit pattern
// (it is fed to `from_bits` in the tests and used for the exponent argument
// of `compose`).
macro_rules! weight_impl_float {
    ($t:ty, $m:ty, $test_name:ident) => {
        impl Weight for $t {
            // 255 for `f32` and 2047 for `f64`.
            const NUM_RANGES: u32 = (<$t>::MAX_EXP - <$t>::MIN_EXP + 2) as u32;
            const ZERO: Self = 0.0;

            // The range index is the value reported by `get_exponent()`.
            // NOTE(review): presumably the raw (biased) exponent field, since
            // the `num_ranges` test expects `MAX` to map to `NUM_RANGES - 1`
            // -- confirm against `FloatingPointParts`.
            fn compute_range_index(&self) -> u32 {
                self.get_exponent() as u32
            }

            // Builds the bound of range `index` via `compose` with exponent
            // `index + 1`.
            // NOTE(review): presumably `compose(mantissa, exponent, sign)` --
            // confirm against `FloatingPointParts`.
            fn max_weight_of_range_index(index: u32) -> Self {
                // Same precondition as the integer implementations.
                debug_assert!(index < Self::NUM_RANGES);
                Self::compose(0, 1 + index as $m, false)
            }

            // Returns `true` with probability `numerator / denominator`.
            fn sample_from_ratio<R: Rng + ?Sized>(
                rng: &mut R,
                numerator: Self,
                denominator: Self,
            ) -> bool {
                // Same precondition as the integer implementations; a ratio
                // above one is not a valid probability for `gen_bool`.
                debug_assert!(numerator <= denominator);
                rng.gen_bool((numerator / denominator) as f64)
            }
        }

        #[cfg(test)]
        mod $test_name {
            use super::Weight;
            use pcg_rand::Pcg64;
            use rand::{Rng, SeedableRng};

            // `MAX` must fall into the last range.
            #[test]
            fn num_ranges() {
                assert_eq!(<$t>::MAX.compute_range_index() + 1, <$t>::NUM_RANGES);
            }

            // Halving / doubling a value must move it exactly one range
            // down / up.
            #[test]
            fn compute_range_index_rand() {
                let mut rng = Pcg64::seed_from_u64(0x1234);
                for _ in 0..1000 {
                    // Shifting right by one clears the top (sign) bit, so only
                    // non-negative bit patterns are generated.
                    let num = <$t>::from_bits(rng.gen::<$m>() >> 1);
                    if !num.is_finite() || !num.is_normal() {
                        continue;
                    }

                    let range = num.compute_range_index();

                    if range != 0 {
                        assert_eq!(range.saturating_sub(1), (num / 2.0).compute_range_index());
                    }

                    // Skip doubling for values that are already in the last range.
                    if range + 1 != <$t>::NUM_RANGES {
                        assert_eq!(
                            range + 1,
                            (num * 2.0).compute_range_index(),
                            "num: {}, 2*num: {}",
                            num,
                            num * 2.0
                        );
                    }
                }
            }
        }
    };
}
183
184weight_impl_float!(f32, u32, test_f32);
185weight_impl_float!(f64, u64, test_f64);