dynamic_weighted_index/
weight.rs

1use crate::numeric::FloatingPointParts;
2use rand::Rng;
3use std::fmt::Debug;
4use std::ops::{Add, AddAssign, Sub, SubAssign};
5
6pub trait Weight:
7    Copy
8    + Debug
9    + Sized
10    + SubAssign
11    + PartialEq
12    + PartialOrd
13    + Default
14    + Add<Output = Self>
15    + Sub<Output = Self>
16    + AddAssign
17    + rand::distributions::uniform::SampleUniform
18{
19    const NUM_RANGES: u32;
20    const ZERO: Self;
21
22    fn compute_range_index(&self) -> u32;
23    fn max_weight_of_range_index(index: u32) -> Self;
24    fn sample_from_ratio<R: Rng + ?Sized>(rng: &mut R, numerator: Self, denominator: Self) -> bool;
25}
26
27macro_rules! weight_impl_int {
28    ($t : ty, $test_name : ident) => {
29        impl Weight for $t {
30            const NUM_RANGES: u32 = <$t>::BITS;
31            const ZERO: Self = 0;
32
33            fn compute_range_index(&self) -> u32 {
34                debug_assert!(*self > 0);
35                Self::NUM_RANGES - 1 - self.leading_zeros()
36            }
37
38            fn max_weight_of_range_index(index: u32) -> Self {
39                debug_assert!(index < Self::NUM_RANGES);
40                ((2 as $t) << index).overflowing_sub(1).0
41            }
42
43            fn sample_from_ratio<R: Rng + ?Sized>(
44                rng: &mut R,
45                numerator: Self,
46                denominator: Self,
47            ) -> bool {
48                debug_assert!(numerator <= denominator);
49                rng.gen_range(0..denominator) < numerator
50            }
51        }
52
53        #[cfg(test)]
54        mod $test_name {
55            use super::Weight;
56            use pcg_rand::Pcg64;
57            use rand::{Rng, SeedableRng};
58
59            #[test]
60            fn compute_range_index_rand() {
61                let mut rng = Pcg64::seed_from_u64(0x1234);
62                for _ in 0..1000 {
63                    let num: $t = rng.gen();
64
65                    if num == 0 {
66                        continue;
67                    }
68
69                    let range = num.compute_range_index();
70
71                    if num > 1 {
72                        assert_eq!(range.saturating_sub(1), (num / 2).compute_range_index());
73                    }
74
75                    if num <= <$t>::MAX / 2 {
76                        assert_eq!(range + 1, (num * 2).compute_range_index());
77                    }
78                }
79            }
80
81            #[test]
82            fn compute_range_index_max() {
83                for i in 0..<$t>::NUM_RANGES {
84                    let max_weight = <$t>::max_weight_of_range_index(i);
85                    assert_eq!(i, max_weight.compute_range_index());
86                }
87            }
88
89            #[test]
90            fn compute_range_index_max2() {
91                for i in 0..(<$t>::NUM_RANGES - 1) {
92                    let max_weight = <$t>::max_weight_of_range_index(i);
93                    assert_eq!(i + 1, max_weight.compute_range_index() + 1);
94                }
95            }
96
97            #[test]
98            fn num_ranges() {
99                assert_eq!(<$t>::MAX.compute_range_index() + 1, <$t>::NUM_RANGES);
100            }
101        }
102    };
103}
104
105weight_impl_int!(u8, test_u8);
106weight_impl_int!(u16, test_u16);
107weight_impl_int!(u32, test_u32);
108weight_impl_int!(u64, test_u64);
109weight_impl_int!(u128, test_u128);
110weight_impl_int!(usize, test_usize);
111
112macro_rules! weight_impl_float {
113    ($t : ty, $m : ty, $test_name : ident) => {
114        impl Weight for $t {
115            const NUM_RANGES: u32 = (<$t>::MAX_EXP - <$t>::MIN_EXP + 2) as u32;
116            const ZERO: Self = 0.0;
117
118            fn compute_range_index(&self) -> u32 {
119                /*
120                the portable way:
121                    let log = weight.log2().floor() as i32;
122                    let result = log - f64::MIN_EXP;
123                    assert!(result >= 0);
124                    result as u32
125                */
126
127                self.get_exponent() as u32
128            }
129
130            fn max_weight_of_range_index(index: u32) -> Self {
131                Self::compose(0, 1 + index as $m, false)
132            }
133
134            fn sample_from_ratio<R: Rng + ?Sized>(
135                rng: &mut R,
136                numerator: Self,
137                denominator: Self,
138            ) -> bool {
139                rng.gen_bool((numerator / denominator) as f64)
140            }
141        }
142
143        #[cfg(test)]
144        mod $test_name {
145            use super::Weight;
146            use pcg_rand::Pcg64;
147            use rand::{Rng, SeedableRng};
148
149            #[test]
150            fn num_ranges() {
151                assert_eq!(<$t>::MAX.compute_range_index() + 1, <$t>::NUM_RANGES);
152            }
153
154            #[test]
155            fn compute_range_index_rand() {
156                let mut rng = Pcg64::seed_from_u64(0x1234);
157                for _ in 0..1000 {
158                    let num = <$t>::from_bits(rng.gen::<$m>() >> 1);
159                    if !num.is_finite() || !num.is_normal() {
160                        continue;
161                    }
162
163                    let range = num.compute_range_index();
164
165                    if range != 0 {
166                        assert_eq!(range.saturating_sub(1), (num / 2.0).compute_range_index());
167                    }
168
169                    if range + 1 != <$t>::NUM_RANGES {
170                        assert_eq!(
171                            range + 1,
172                            (num * 2.0).compute_range_index(),
173                            "num: {}, 2*num: {}",
174                            num,
175                            num * 2.0
176                        );
177                    }
178                }
179            }
180        }
181    };
182}
183
184weight_impl_float!(f32, u32, test_f32);
185weight_impl_float!(f64, u64, test_f64);