random_constructible/
lib.rs1#![cfg_attr(feature = "specialization", feature(min_specialization,specialization))]
3
4#![allow(unused_imports)]
5
6#[macro_use] mod imports; use imports::*;
7
8x!{rand_construct}
9x!{rand_construct_enum}
10x!{rand_construct_env}
11x!{prim_traits}
12x!{sample}
13x!{impl_for_optiont}
14
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Number of draws taken by every sampling test.
    const SAMPLES: usize = 10_000;

    /// Seed for the tests that sample through an explicit RNG, so their
    /// outcomes are reproducible from run to run.
    const SEED: u64 = 42;

    /// Three-variant enum whose `RandConstructEnum` impl is written by hand.
    #[derive(Default, Clone, Debug, PartialEq, Eq, Hash)]
    enum ManualTestEnum {
        #[default]
        VariantX,
        VariantY,
        VariantZ,
    }

    impl RandConstructEnumWithEnv for ManualTestEnum {}

    impl RandConstructEnum for ManualTestEnum {
        fn all_variants() -> Vec<Self> {
            vec![Self::VariantX, Self::VariantY, Self::VariantZ]
        }

        // Mirrors the `DefaultProvider` weights declared below; keep the two
        // tables in sync.
        fn default_weight(&self) -> f64 {
            match self {
                Self::VariantX => 2.0,
                Self::VariantY => 3.0,
                Self::VariantZ => 5.0,
            }
        }

        fn create_default_probability_map() -> Arc<HashMap<Self, f64>> {
            DefaultProvider::probability_map()
        }
    }

    /// Provider carrying the same 2:3:5 weights as `default_weight`.
    struct DefaultProvider;

    rand_construct_env!(DefaultProvider => ManualTestEnum {
        VariantX => 2.0,
        VariantY => 3.0,
        VariantZ => 5.0,
    });

    /// Provider with a 1:1:8 weighting that differs from the defaults.
    struct CustomProvider;

    rand_construct_env!(CustomProvider => ManualTestEnum {
        VariantX => 1.0,
        VariantY => 1.0,
        VariantZ => 8.0,
    });

    /// Calls `draw` `SAMPLES` times and returns the observed relative frequency
    /// of every variant that was produced at least once.
    fn observed_frequencies(
        mut draw: impl FnMut() -> ManualTestEnum,
    ) -> HashMap<ManualTestEnum, f64> {
        let mut counts: HashMap<ManualTestEnum, usize> = HashMap::new();
        for _ in 0..SAMPLES {
            *counts.entry(draw()).or_insert(0) += 1;
        }
        // Every draw is counted exactly once, so the counts sum to SAMPLES.
        let total = SAMPLES as f64;
        counts
            .into_iter()
            .map(|(variant, count)| (variant, count as f64 / total))
            .collect()
    }

    /// Asserts that `variant`'s observed frequency lies strictly within
    /// `tolerance` of `expected`. A variant that was never drawn counts as 0.
    #[track_caller]
    fn assert_frequency(
        freqs: &HashMap<ManualTestEnum, f64>,
        variant: ManualTestEnum,
        expected: f64,
        tolerance: f64,
    ) {
        let observed = freqs.get(&variant).copied().unwrap_or(0.0);
        assert!(
            (observed - expected).abs() < tolerance,
            "{variant:?}: observed frequency {observed:.4}, expected within {tolerance} of {expected}"
        );
    }

    /// `all_variants` lists each variant exactly once.
    #[test]
    fn test_manual_all_variants() {
        let variants = ManualTestEnum::all_variants();
        assert_eq!(variants.len(), 3);
        assert!(variants.contains(&ManualTestEnum::VariantX));
        assert!(variants.contains(&ManualTestEnum::VariantY));
        assert!(variants.contains(&ManualTestEnum::VariantZ));
    }

    /// `default_weight` returns the hand-written weights.
    #[test]
    fn test_manual_default_weight() {
        assert_eq!(ManualTestEnum::VariantX.default_weight(), 2.0);
        assert_eq!(ManualTestEnum::VariantY.default_weight(), 3.0);
        assert_eq!(ManualTestEnum::VariantZ.default_weight(), 5.0);
    }

    /// Seeded `random_with_rng` follows the 2:3:5 default weighting.
    #[test]
    fn test_manual_random() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let freqs = observed_frequencies(|| ManualTestEnum::random_with_rng(&mut rng));

        assert_frequency(&freqs, ManualTestEnum::VariantX, 0.2, 0.05);
        assert_frequency(&freqs, ManualTestEnum::VariantY, 0.3, 0.05);
        assert_frequency(&freqs, ManualTestEnum::VariantZ, 0.5, 0.05);
    }

    /// `uniform` draws every variant with roughly equal frequency.
    #[test]
    fn test_manual_uniform() {
        // `uniform()` is called without an RNG, so unlike the other sampling
        // tests this one is not seeded.
        let freqs = observed_frequencies(ManualTestEnum::uniform);

        // Check every declared variant so one that is never drawn fails
        // explicitly rather than only skewing the others.
        for variant in ManualTestEnum::all_variants() {
            assert_frequency(&freqs, variant, 1.0 / 3.0, 0.05);
        }
    }

    /// Sampling from an explicit probability map follows that map's weights.
    #[test]
    fn test_manual_random_with_probabilities() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let probs = CustomProvider::probability_map();
        let freqs = observed_frequencies(|| {
            sample_variants_with_probabilities(&mut rng, &probs)
        });

        assert_frequency(&freqs, ManualTestEnum::VariantX, 0.1, 0.02);
        assert_frequency(&freqs, ManualTestEnum::VariantY, 0.1, 0.02);
        assert_frequency(&freqs, ManualTestEnum::VariantZ, 0.8, 0.05);
    }

    /// `sample_from_provider` uses the named provider's weights, not the defaults.
    #[test]
    fn test_manual_sample_from_provider() {
        let mut rng = StdRng::seed_from_u64(SEED);
        let freqs = observed_frequencies(|| {
            ManualTestEnum::sample_from_provider::<CustomProvider, _>(&mut rng)
        });

        assert_frequency(&freqs, ManualTestEnum::VariantX, 0.1, 0.02);
        assert_frequency(&freqs, ManualTestEnum::VariantY, 0.1, 0.02);
        assert_frequency(&freqs, ManualTestEnum::VariantZ, 0.8, 0.05);
    }
}