random_constructible/
lib.rs

1#![allow(unused_imports)]
2
3#[macro_use] mod imports; use imports::*;
4
5x!{rand_construct}
6x!{rand_construct_enum}
7x!{rand_construct_env}
8x!{prim_traits}
9x!{sample}
10x!{impl_for_optiont}
11
/// Tests for a hand-written `RandConstructEnum` implementation and for the
/// probability providers generated by `rand_construct_env!`.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Number of draws used by every sampling test.
    const SAMPLES: usize = 10_000;

    // A test enum whose `RandConstructEnum` implementation is written by hand.
    #[derive(Default, Clone, Debug, PartialEq, Eq, Hash)]
    enum ManualTestEnum {
        #[default]
        VariantX,
        VariantY,
        VariantZ,
    }

    impl RandConstructEnumWithEnv for ManualTestEnum {}

    impl RandConstructEnum for ManualTestEnum {
        fn all_variants() -> Vec<Self> {
            vec![Self::VariantX, Self::VariantY, Self::VariantZ]
        }

        // Keep these in sync with the weights registered for `DefaultProvider` below.
        fn default_weight(&self) -> f64 {
            match self {
                Self::VariantX => 2.0,
                Self::VariantY => 3.0,
                Self::VariantZ => 5.0,
            }
        }

        fn create_default_probability_map() -> Arc<HashMap<Self, f64>> {
            DefaultProvider::probability_map()
        }
    }

    // Default provider: weights 2:3:5, i.e. expected frequencies 0.2 / 0.3 / 0.5.
    struct DefaultProvider;

    rand_construct_env!(DefaultProvider => ManualTestEnum {
        VariantX => 2.0,
        VariantY => 3.0,
        VariantZ => 5.0,
    });

    // Custom provider: weights 1:1:8, i.e. expected frequencies 0.1 / 0.1 / 0.8.
    struct CustomProvider;

    rand_construct_env!(CustomProvider => ManualTestEnum {
        VariantX => 1.0,
        VariantY => 1.0,
        VariantZ => 8.0,
    });

    /// Calls `draw` `SAMPLES` times and returns the observed frequency of each
    /// variant that was drawn at least once. Variants that never appear are
    /// absent from the map.
    fn empirical_probabilities(
        mut draw: impl FnMut() -> ManualTestEnum,
    ) -> HashMap<ManualTestEnum, f64> {
        let mut counts: HashMap<ManualTestEnum, usize> = HashMap::new();
        for _ in 0..SAMPLES {
            *counts.entry(draw()).or_insert(0) += 1;
        }
        counts
            .into_iter()
            .map(|(variant, count)| (variant, count as f64 / SAMPLES as f64))
            .collect()
    }

    /// Asserts that, for every `(variant, expected, tolerance)` entry, the
    /// observed frequency is strictly within `tolerance` of `expected`.
    /// A variant missing from `observed` counts as frequency 0.0.
    fn assert_probabilities(
        observed: &HashMap<ManualTestEnum, f64>,
        expected: &[(ManualTestEnum, f64, f64)],
    ) {
        for (variant, expected_p, tolerance) in expected {
            let observed_p = observed.get(variant).copied().unwrap_or(0.0);
            let deviation = (observed_p - expected_p).abs();
            assert!(
                deviation < *tolerance,
                "{variant:?}: observed probability {observed_p:.4}, expected {expected_p} ± {tolerance}"
            );
        }
    }

    #[test]
    fn test_manual_all_variants() {
        let variants = ManualTestEnum::all_variants();
        assert_eq!(variants.len(), 3);
        assert!(variants.contains(&ManualTestEnum::VariantX));
        assert!(variants.contains(&ManualTestEnum::VariantY));
        assert!(variants.contains(&ManualTestEnum::VariantZ));
    }

    #[test]
    fn test_manual_default_weight() {
        assert_eq!(ManualTestEnum::VariantX.default_weight(), 2.0);
        assert_eq!(ManualTestEnum::VariantY.default_weight(), 3.0);
        assert_eq!(ManualTestEnum::VariantZ.default_weight(), 5.0);
    }

    #[test]
    fn test_manual_random() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = empirical_probabilities(|| ManualTestEnum::random_with_rng(&mut rng));

        // Expected probabilities from the default weights 2:3:5.
        assert_probabilities(
            &probs,
            &[
                (ManualTestEnum::VariantX, 0.2, 0.05),
                (ManualTestEnum::VariantY, 0.3, 0.05),
                (ManualTestEnum::VariantZ, 0.5, 0.05),
            ],
        );
    }

    #[test]
    fn test_manual_uniform() {
        // NOTE(review): `uniform()` takes no RNG argument, so this test is
        // statistical rather than seeded/reproducible. Confirm whether a
        // seeded variant exists.
        let probs = empirical_probabilities(ManualTestEnum::uniform);

        let third = 1.0 / 3.0;
        assert_probabilities(
            &probs,
            &[
                (ManualTestEnum::VariantX, third, 0.05),
                (ManualTestEnum::VariantY, third, 0.05),
                (ManualTestEnum::VariantZ, third, 0.05),
            ],
        );
    }

    #[test]
    fn test_manual_random_with_probabilities() {
        let mut rng = StdRng::seed_from_u64(42);
        let provider_probs = CustomProvider::probability_map();

        let probs =
            empirical_probabilities(|| sample_variants_with_probabilities(&mut rng, &provider_probs));

        // Expected probabilities from the custom weights 1:1:8.
        assert_probabilities(
            &probs,
            &[
                (ManualTestEnum::VariantX, 0.1, 0.02),
                (ManualTestEnum::VariantY, 0.1, 0.02),
                (ManualTestEnum::VariantZ, 0.8, 0.05),
            ],
        );
    }

    #[test]
    fn test_manual_sample_from_provider() {
        let mut rng = StdRng::seed_from_u64(42);
        let probs = empirical_probabilities(|| {
            ManualTestEnum::sample_from_provider::<CustomProvider, _>(&mut rng)
        });

        // Expected probabilities from the custom weights 1:1:8.
        assert_probabilities(
            &probs,
            &[
                (ManualTestEnum::VariantX, 0.1, 0.02),
                (ManualTestEnum::VariantY, 0.1, 0.02),
                (ManualTestEnum::VariantZ, 0.8, 0.05),
            ],
        );
    }
}