random_constructible/
lib.rs

1// ---------------- [ File: random-constructible/src/lib.rs ]
2#![cfg_attr(feature = "specialization", feature(min_specialization,specialization))]
3
4#![allow(unused_imports)]
5
6#[macro_use] mod imports; use imports::*;
7
8x!{rand_construct}
9x!{rand_construct_enum}
10x!{rand_construct_env}
11x!{prim_traits}
12x!{sample}
13x!{impl_for_optiont}
14
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// How many values every statistical test below draws.
    const SAMPLE_COUNT: usize = 10000;

    /// Fixed seed used by the tests that drive an explicit `StdRng`.
    const SEED: u64 = 42;

    /// Three-variant fixture whose `RandConstructEnum` impl is written by
    /// hand in this module.
    #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
    enum ManualTestEnum {
        #[default]
        VariantX,
        VariantY,
        VariantZ,
    }

    /// Provider declared with the same 2 : 3 : 5 weights that
    /// `default_weight` returns; it backs `create_default_probability_map`.
    struct DefaultProvider;

    rand_construct_env!(DefaultProvider => ManualTestEnum {
        VariantX => 2.0,
        VariantY => 3.0,
        VariantZ => 5.0,
    });

    /// Provider with a deliberately different 1 : 1 : 8 weighting, used by
    /// the tests that sample through an explicit provider or map.
    struct CustomProvider;

    rand_construct_env!(CustomProvider => ManualTestEnum {
        VariantX => 1.0,
        VariantY => 1.0,
        VariantZ => 8.0,
    });

    impl RandConstructEnumWithEnv for ManualTestEnum {}

    impl RandConstructEnum for ManualTestEnum {
        /// Lists every variant, in declaration order.
        fn all_variants() -> Vec<Self> {
            vec![
                ManualTestEnum::VariantX,
                ManualTestEnum::VariantY,
                ManualTestEnum::VariantZ,
            ]
        }

        /// Relative weight of each variant: X = 2, Y = 3, Z = 5.
        fn default_weight(&self) -> f64 {
            match *self {
                ManualTestEnum::VariantX => 2.0,
                ManualTestEnum::VariantY => 3.0,
                ManualTestEnum::VariantZ => 5.0,
            }
        }

        /// Hands back the map that `DefaultProvider` exposes.
        fn create_default_probability_map() -> Arc<HashMap<Self, f64>> {
            DefaultProvider::probability_map()
        }
    }

    /// Calls `draw` `SAMPLE_COUNT` times and counts how often each variant
    /// was produced. Variants that never come up are absent from the map.
    fn tally<F>(mut draw: F) -> HashMap<ManualTestEnum, usize>
    where
        F: FnMut() -> ManualTestEnum,
    {
        let mut histogram = HashMap::new();
        for _ in 0..SAMPLE_COUNT {
            *histogram.entry(draw()).or_insert(0) += 1;
        }
        histogram
    }

    /// Fraction of all tallied draws that produced `variant`; a variant that
    /// is missing from the histogram counts as zero hits.
    fn share_of(histogram: &HashMap<ManualTestEnum, usize>, variant: &ManualTestEnum) -> f64 {
        let total = histogram.values().sum::<usize>() as f64;
        let hits = match histogram.get(variant) {
            Some(&n) => n,
            None => 0,
        };
        hits as f64 / total
    }

    /// `all_variants` reports exactly the three declared variants.
    #[test]
    fn test_manual_all_variants() {
        let listed = ManualTestEnum::all_variants();
        assert_eq!(listed.len(), 3);

        let expected = [
            ManualTestEnum::VariantX,
            ManualTestEnum::VariantY,
            ManualTestEnum::VariantZ,
        ];
        for variant in &expected {
            assert!(listed.contains(variant));
        }
    }

    /// `default_weight` returns the hand-written weight of each variant.
    #[test]
    fn test_manual_default_weight() {
        let expected = [
            (ManualTestEnum::VariantX, 2.0),
            (ManualTestEnum::VariantY, 3.0),
            (ManualTestEnum::VariantZ, 5.0),
        ];
        for (variant, weight) in &expected {
            assert_eq!(variant.default_weight(), *weight);
        }
    }

    /// Seeded sampling via `random_with_rng` lands near the 0.2 / 0.3 / 0.5
    /// split implied by the default weights.
    #[test]
    fn test_manual_random() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let histogram = tally(|| ManualTestEnum::random_with_rng(&mut rng));

        assert!((share_of(&histogram, &ManualTestEnum::VariantX) - 0.2).abs() < 0.05);
        assert!((share_of(&histogram, &ManualTestEnum::VariantY) - 0.3).abs() < 0.05);
        assert!((share_of(&histogram, &ManualTestEnum::VariantZ) - 0.5).abs() < 0.05);
    }

    /// Every variant produced by `uniform` shows up about one third of the
    /// time. Only variants that were actually drawn are inspected.
    #[test]
    fn test_manual_uniform() {
        let histogram = tally(|| ManualTestEnum::uniform());

        let total = histogram.values().sum::<usize>() as f64;
        let one_third: f64 = 1.0 / 3.0;
        for share in histogram.values().map(|&hits| hits as f64 / total) {
            assert!((share - one_third).abs() < 0.05);
        }
    }

    /// Sampling straight from `CustomProvider`'s probability map lands near
    /// the expected 0.1 / 0.1 / 0.8 split.
    #[test]
    fn test_manual_random_with_probabilities() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let probs = CustomProvider::probability_map();
        let histogram = tally(|| sample_variants_with_probabilities(&mut rng, &probs));

        assert!((share_of(&histogram, &ManualTestEnum::VariantX) - 0.1).abs() < 0.02);
        assert!((share_of(&histogram, &ManualTestEnum::VariantY) - 0.1).abs() < 0.02);
        assert!((share_of(&histogram, &ManualTestEnum::VariantZ) - 0.8).abs() < 0.05);
    }

    /// `sample_from_provider` with `CustomProvider` lands near the same
    /// expected 0.1 / 0.1 / 0.8 split.
    #[test]
    fn test_manual_sample_from_provider() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let histogram =
            tally(|| ManualTestEnum::sample_from_provider::<CustomProvider, _>(&mut rng));

        assert!((share_of(&histogram, &ManualTestEnum::VariantX) - 0.1).abs() < 0.02);
        assert!((share_of(&histogram, &ManualTestEnum::VariantY) - 0.1).abs() < 0.02);
        assert!((share_of(&histogram, &ManualTestEnum::VariantZ) - 0.8).abs() < 0.05);
    }
}