random_constructible/
lib.rs1#![allow(unused_imports)]
2
3#[macro_use] mod imports; use imports::*;
4
5x!{rand_construct}
6x!{rand_construct_enum}
7x!{rand_construct_env}
8x!{prim_traits}
9x!{sample}
10x!{impl_for_optiont}
11
12#[cfg(test)]
13mod tests {
14 use super::*;
15 use rand::rngs::StdRng;
16 use rand::SeedableRng;
17 use std::collections::HashMap;
18 use std::sync::Arc;
19
/// Hand-written three-variant enum used as the fixture for the tests in this module.
///
/// `VariantX` is the variant returned by `Default::default()`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
enum ManualTestEnum {
    #[default]
    VariantX,
    VariantY,
    VariantZ,
}
28
29 impl RandConstructEnumWithEnv for ManualTestEnum {}
30
31 impl RandConstructEnum for ManualTestEnum {
32 fn all_variants() -> Vec<Self> {
33 vec![Self::VariantX, Self::VariantY, Self::VariantZ]
34 }
35
36 fn default_weight(&self) -> f64 {
37 match self {
38 Self::VariantX => 2.0,
39 Self::VariantY => 3.0,
40 Self::VariantZ => 5.0,
41 }
42 }
43
44 fn create_default_probability_map() -> Arc<HashMap<Self, f64>> {
45 DefaultProvider::probability_map()
46 }
47 }
48
49 struct DefaultProvider;
51
52 rand_construct_env!(DefaultProvider => ManualTestEnum {
53 VariantX => 2.0,
54 VariantY => 3.0,
55 VariantZ => 5.0,
56 });
57
58 struct CustomProvider;
60
61 rand_construct_env!(CustomProvider => ManualTestEnum {
62 VariantX => 1.0,
63 VariantY => 1.0,
64 VariantZ => 8.0,
65 });
66
/// `all_variants()` must list exactly the three declared variants.
#[test]
fn test_manual_all_variants() {
    let listed = ManualTestEnum::all_variants();
    assert_eq!(listed.len(), 3);

    let expected = [
        ManualTestEnum::VariantX,
        ManualTestEnum::VariantY,
        ManualTestEnum::VariantZ,
    ];
    for variant in &expected {
        assert!(listed.contains(variant));
    }
}
75
/// Each variant must report the weight hard-coded in `default_weight`.
#[test]
fn test_manual_default_weight() {
    let expected_weights = [
        (ManualTestEnum::VariantX, 2.0),
        (ManualTestEnum::VariantY, 3.0),
        (ManualTestEnum::VariantZ, 5.0),
    ];
    for (variant, weight) in &expected_weights {
        assert_eq!(variant.default_weight(), *weight);
    }
}
82
/// Draws 10 000 samples via `random_with_rng` from a seeded `StdRng` and checks
/// the observed frequencies against 0.2 / 0.3 / 0.5 (the normalised 2 : 3 : 5
/// default weights), each within an absolute tolerance of 0.05.
#[test]
fn test_manual_random() {
    let mut rng = StdRng::seed_from_u64(42);
    let mut tally = HashMap::new();

    for _ in 0..10_000 {
        *tally
            .entry(ManualTestEnum::random_with_rng(&mut rng))
            .or_insert(0_usize) += 1;
    }

    let total = tally.values().sum::<usize>() as f64;
    // A variant that was never drawn counts as frequency 0.
    let frequency = |v: &ManualTestEnum| tally.get(v).copied().unwrap_or(0) as f64 / total;

    assert!((frequency(&ManualTestEnum::VariantX) - 0.2).abs() < 0.05);
    assert!((frequency(&ManualTestEnum::VariantY) - 0.3).abs() < 0.05);
    assert!((frequency(&ManualTestEnum::VariantZ) - 0.5).abs() < 0.05);
}
103
/// Draws 10 000 samples via `uniform()` and checks that every observed variant
/// occurs with frequency 1/3, within an absolute tolerance of 0.05.
#[test]
fn test_manual_uniform() {
    let mut tally = HashMap::new();

    for _ in 0..10_000 {
        *tally.entry(ManualTestEnum::uniform()).or_insert(0_usize) += 1;
    }

    let total = tally.values().sum::<usize>() as f64;
    let expected = 1.0 / 3.0;
    // Only variants that were actually drawn appear in the map. A missing
    // variant still fails the test, because the remaining frequencies then
    // cannot all sit near 1/3.
    assert!(tally
        .values()
        .all(|&n| (n as f64 / total - expected).abs() < 0.05));
}
119
/// Samples 10 000 times with `sample_variants_with_probabilities`, using
/// `CustomProvider`'s map (weights 1 : 1 : 8) and a seeded `StdRng`, then checks
/// the observed frequencies against 0.1 / 0.1 / 0.8.
#[test]
fn test_manual_random_with_probabilities() {
    let mut rng = StdRng::seed_from_u64(42);
    let probs = CustomProvider::probability_map();

    let mut tally = HashMap::new();
    for _ in 0..10_000 {
        let drawn = sample_variants_with_probabilities(&mut rng, &probs);
        *tally.entry(drawn).or_insert(0_usize) += 1;
    }

    let total = tally.values().sum::<usize>() as f64;
    // A variant that was never drawn counts as frequency 0.
    let frequency = |v: &ManualTestEnum| tally.get(v).copied().unwrap_or(0) as f64 / total;

    // The two rare variants get a tighter tolerance than the dominant one.
    assert!((frequency(&ManualTestEnum::VariantX) - 0.1).abs() < 0.02);
    assert!((frequency(&ManualTestEnum::VariantY) - 0.1).abs() < 0.02);
    assert!((frequency(&ManualTestEnum::VariantZ) - 0.8).abs() < 0.05);
}
142
/// Samples 10 000 times via `sample_from_provider::<CustomProvider, _>` with a
/// seeded `StdRng` and checks the observed frequencies against 0.1 / 0.1 / 0.8
/// (the normalised 1 : 1 : 8 weights registered for `CustomProvider`).
#[test]
fn test_manual_sample_from_provider() {
    let mut rng = StdRng::seed_from_u64(42);
    let mut tally = HashMap::new();

    for _ in 0..10_000 {
        let drawn = ManualTestEnum::sample_from_provider::<CustomProvider, _>(&mut rng);
        *tally.entry(drawn).or_insert(0_usize) += 1;
    }

    let total = tally.values().sum::<usize>() as f64;
    // A variant that was never drawn counts as frequency 0.
    let frequency = |v: &ManualTestEnum| tally.get(v).copied().unwrap_or(0) as f64 / total;

    // The two rare variants get a tighter tolerance than the dominant one.
    assert!((frequency(&ManualTestEnum::VariantX) - 0.1).abs() < 0.02);
    assert!((frequency(&ManualTestEnum::VariantY) - 0.1).abs() < 0.02);
    assert!((frequency(&ManualTestEnum::VariantZ) - 0.8).abs() < 0.05);
}
163}