oxihuman_core/
ab_test_config.rs1#![allow(dead_code)]
4
5use std::collections::HashMap;
8
/// One arm of an A/B test: a named variant and its relative selection weight.
///
/// Weights are relative: during selection each variant owns a slice of
/// `[0, total_weight)` whose width is its `weight`. Weights are not validated.
#[derive(Debug, Clone, PartialEq)]
pub struct AbVariant {
    /// Identifier returned by [`AbTestConfig::select_variant`].
    pub name: String,
    /// Relative weight; only its size relative to the test's total matters.
    pub weight: f32,
}

/// Registry of named A/B tests, each holding an ordered list of variants.
#[derive(Debug, Default)]
pub struct AbTestConfig {
    /// Test name -> variants in the order supplied to [`AbTestConfig::add_test`].
    /// The order matters because selection scans variants front to back.
    tests: HashMap<String, Vec<AbVariant>>,
}

impl AbTestConfig {
    /// Creates an empty configuration with no tests registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `variants` under `test_name`, replacing any variants that
    /// were previously registered under the same name.
    pub fn add_test(&mut self, test_name: &str, variants: Vec<AbVariant>) {
        self.tests.insert(test_name.to_string(), variants);
    }

    /// Number of variants registered for `test_name`, or `0` if it is unknown.
    pub fn variant_count(&self, test_name: &str) -> usize {
        self.tests.get(test_name).map_or(0, Vec::len)
    }

    /// Sum of the variant weights of `test_name`, or `0.0` if it is unknown.
    pub fn total_weight(&self, test_name: &str) -> f32 {
        self.tests
            .get(test_name)
            .map_or(0.0, |variants| Self::sum_weights(variants))
    }

    /// Sums weights front to back. This is shared by `total_weight` and
    /// `select_variant` so both use the same summation order, and the
    /// cumulative scan therefore ends exactly at the total.
    fn sum_weights(variants: &[AbVariant]) -> f32 {
        variants.iter().map(|v| v.weight).sum()
    }

    /// Deterministically picks a variant of `test_name` from `seed`.
    ///
    /// `seed` is wrapped into `[0, total)` with [`f32::rem_euclid`], so
    /// negative and oversized seeds are accepted. The first variant (in
    /// registration order) whose cumulative-weight upper bound exceeds the
    /// wrapped value is returned; interval boundaries belong to the next
    /// variant.
    ///
    /// Returns `None` if the test is unknown, or if its total weight is zero,
    /// negative, or NaN.
    pub fn select_variant(&self, test_name: &str, seed: f32) -> Option<&str> {
        let variants = self.tests.get(test_name)?;
        let total = Self::sum_weights(variants);
        // A NaN total (from a NaN weight) must be rejected explicitly:
        // `NaN <= 0.0` is false and would let selection proceed with garbage.
        if total.is_nan() || total <= 0.0 {
            return None;
        }
        let target = seed.rem_euclid(total);
        let mut cumulative = 0.0;
        for v in variants {
            cumulative += v.weight;
            if target < cumulative {
                return Some(v.name.as_str());
            }
        }
        // Reached only if `target` fell in no interval. This happens when
        // `rem_euclid` rounds a tiny negative seed up to exactly `total`, or
        // when `seed` is NaN. Return the last variant that can actually win,
        // never a trailing zero-weight one.
        variants
            .iter()
            .rev()
            .find(|v| v.weight > 0.0)
            .map(|v| v.name.as_str())
    }

    /// Number of registered tests.
    pub fn test_count(&self) -> usize {
        self.tests.len()
    }

    /// Names of all registered tests, sorted ascending.
    pub fn test_names(&self) -> Vec<String> {
        let mut names: Vec<String> = self.tests.keys().cloned().collect();
        // Map keys are unique, so a stable sort buys nothing.
        names.sort_unstable();
        names
    }
}
69
70pub fn new_ab_test_config() -> AbTestConfig {
71 AbTestConfig::new()
72}
73
74pub fn ab_add_test(cfg: &mut AbTestConfig, name: &str, variants: Vec<AbVariant>) {
75 cfg.add_test(name, variants);
76}
77
78pub fn ab_select_variant<'a>(cfg: &'a AbTestConfig, test: &str, seed: f32) -> Option<&'a str> {
79 cfg.select_variant(test, seed)
80}
81
82pub fn ab_variant_count(cfg: &AbTestConfig, test: &str) -> usize {
83 cfg.variant_count(test)
84}
85
86pub fn ab_total_weight(cfg: &AbTestConfig, test: &str) -> f32 {
87 cfg.total_weight(test)
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Two equally weighted arms: "control" owns [0, 50), "treatment" [50, 100).
    fn make_variants() -> Vec<AbVariant> {
        vec![
            AbVariant {
                name: "control".to_string(),
                weight: 50.0,
            },
            AbVariant {
                name: "treatment".to_string(),
                weight: 50.0,
            },
        ]
    }

    #[test]
    fn test_add_and_count() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "button_color", make_variants());
        assert_eq!(ab_variant_count(&cfg, "button_color"), 2);
    }

    #[test]
    fn test_total_weight() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "test1", make_variants());
        assert!((ab_total_weight(&cfg, "test1") - 100.0).abs() < 1e-5);
    }

    #[test]
    fn test_select_control() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        assert_eq!(ab_select_variant(&cfg, "t", 25.0), Some("control"));
    }

    #[test]
    fn test_select_treatment() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        assert_eq!(ab_select_variant(&cfg, "t", 75.0), Some("treatment"));
    }

    #[test]
    fn test_interval_boundaries() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        // Lower bounds are inclusive and upper bounds exclusive, so 50.0
        // belongs to the second arm.
        assert_eq!(ab_select_variant(&cfg, "t", 0.0), Some("control"));
        assert_eq!(ab_select_variant(&cfg, "t", 50.0), Some("treatment"));
    }

    #[test]
    fn test_unknown_test_returns_none() {
        let cfg = new_ab_test_config();
        assert_eq!(ab_select_variant(&cfg, "missing", 0.5), None);
    }

    #[test]
    fn test_test_count() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "a", make_variants());
        ab_add_test(&mut cfg, "b", make_variants());
        assert_eq!(cfg.test_count(), 2);
    }

    #[test]
    fn test_test_names_sorted() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "z_test", make_variants());
        ab_add_test(&mut cfg, "a_test", make_variants());
        assert_eq!(cfg.test_names(), ["a_test", "z_test"]);
    }

    #[test]
    fn test_zero_weight_returns_none() {
        let mut cfg = new_ab_test_config();
        cfg.add_test(
            "empty_w",
            vec![AbVariant {
                name: "v".to_string(),
                weight: 0.0,
            }],
        );
        assert_eq!(ab_select_variant(&cfg, "empty_w", 0.5), None);
    }

    #[test]
    fn test_uneven_weights() {
        let mut cfg = new_ab_test_config();
        cfg.add_test(
            "skewed",
            vec![
                AbVariant {
                    name: "ctrl".to_string(),
                    weight: 90.0,
                },
                AbVariant {
                    name: "treat".to_string(),
                    weight: 10.0,
                },
            ],
        );
        assert_eq!(ab_select_variant(&cfg, "skewed", 95.0), Some("treat"));
    }

    #[test]
    fn test_seed_wraps_via_rem_euclid() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        // 150 wraps to 50 modulo the total weight (100), which is the first
        // value of the "treatment" interval.
        assert_eq!(ab_select_variant(&cfg, "t", 150.0), Some("treatment"));
    }

    #[test]
    fn test_negative_seed_wraps() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        // rem_euclid maps -25 to 75, not -25.
        assert_eq!(ab_select_variant(&cfg, "t", -25.0), Some("treatment"));
    }
}
202}