Skip to main content

oxihuman_core/
ab_test_config.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3#![allow(dead_code)]
4
5//! A/B testing configuration.
6
7use std::collections::HashMap;
8
/// An A/B test variant definition.
///
/// A variant is one named arm of a test together with its weight. Weights are
/// relative to the other variants of the same test (they are summed and used as
/// cumulative buckets during selection), so they need not add up to 100.
#[derive(Debug, Clone, PartialEq)]
pub struct AbVariant {
    /// Identifier returned to callers when this variant is selected.
    pub name: String,
    /// Relative weight of this variant; intended to be a non-negative number.
    pub weight: f32,
}
15
16/// A/B test configuration.
17#[derive(Debug, Default)]
18pub struct AbTestConfig {
19    tests: HashMap<String, Vec<AbVariant>>,
20}
21
22impl AbTestConfig {
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    pub fn add_test(&mut self, test_name: &str, variants: Vec<AbVariant>) {
28        self.tests.insert(test_name.to_string(), variants);
29    }
30
31    pub fn variant_count(&self, test_name: &str) -> usize {
32        self.tests.get(test_name).map(|v| v.len()).unwrap_or(0)
33    }
34
35    pub fn total_weight(&self, test_name: &str) -> f32 {
36        self.tests
37            .get(test_name)
38            .map(|v| v.iter().map(|vr| vr.weight).sum())
39            .unwrap_or(0.0)
40    }
41
42    pub fn select_variant(&self, test_name: &str, seed: f32) -> Option<&str> {
43        let variants = self.tests.get(test_name)?;
44        let total = self.total_weight(test_name);
45        if total <= 0.0 {
46            return None;
47        }
48        let target = seed.rem_euclid(total);
49        let mut cumulative = 0.0;
50        for v in variants {
51            cumulative += v.weight;
52            if target < cumulative {
53                return Some(&v.name);
54            }
55        }
56        variants.last().map(|v| v.name.as_str())
57    }
58
59    pub fn test_count(&self) -> usize {
60        self.tests.len()
61    }
62
63    pub fn test_names(&self) -> Vec<String> {
64        let mut names: Vec<String> = self.tests.keys().cloned().collect();
65        names.sort();
66        names
67    }
68}
69
70pub fn new_ab_test_config() -> AbTestConfig {
71    AbTestConfig::new()
72}
73
74pub fn ab_add_test(cfg: &mut AbTestConfig, name: &str, variants: Vec<AbVariant>) {
75    cfg.add_test(name, variants);
76}
77
78pub fn ab_select_variant<'a>(cfg: &'a AbTestConfig, test: &str, seed: f32) -> Option<&'a str> {
79    cfg.select_variant(test, seed)
80}
81
82pub fn ab_variant_count(cfg: &AbTestConfig, test: &str) -> usize {
83    cfg.variant_count(test)
84}
85
86pub fn ab_total_weight(cfg: &AbTestConfig, test: &str) -> f32 {
87    cfg.total_weight(test)
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variant from a name and weight.
    fn variant(name: &str, weight: f32) -> AbVariant {
        AbVariant {
            name: name.to_string(),
            weight,
        }
    }

    /// Standard 50/50 control/treatment split used by most tests.
    fn make_variants() -> Vec<AbVariant> {
        vec![variant("control", 50.0), variant("treatment", 50.0)]
    }

    #[test]
    fn test_add_and_count() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "button_color", make_variants());
        assert_eq!(ab_variant_count(&cfg, "button_color"), 2);
    }

    #[test]
    fn test_total_weight() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "test1", make_variants());
        assert!((ab_total_weight(&cfg, "test1") - 100.0).abs() < 1e-5);
    }

    #[test]
    fn test_unknown_test_count_and_weight() {
        let cfg = new_ab_test_config();
        assert_eq!(ab_variant_count(&cfg, "missing"), 0);
        assert_eq!(ab_total_weight(&cfg, "missing"), 0.0);
    }

    #[test]
    fn test_select_control() {
        /* seed=25 lands in first bucket (0..50) */
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        assert_eq!(ab_select_variant(&cfg, "t", 25.0), Some("control"));
    }

    #[test]
    fn test_select_treatment() {
        /* seed=75 lands in second bucket (50..100) */
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        assert_eq!(ab_select_variant(&cfg, "t", 75.0), Some("treatment"));
    }

    #[test]
    fn test_unknown_test_returns_none() {
        let cfg = new_ab_test_config();
        assert_eq!(ab_select_variant(&cfg, "missing", 0.5), None);
    }

    #[test]
    fn test_test_count() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "a", make_variants());
        ab_add_test(&mut cfg, "b", make_variants());
        assert_eq!(cfg.test_count(), 2);
    }

    #[test]
    fn test_add_test_replaces_existing() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        ab_add_test(&mut cfg, "t", vec![variant("only", 1.0)]);
        assert_eq!(cfg.test_count(), 1);
        assert_eq!(ab_variant_count(&cfg, "t"), 1);
        assert_eq!(ab_select_variant(&cfg, "t", 0.3), Some("only"));
    }

    #[test]
    fn test_test_names_sorted() {
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "z_test", make_variants());
        ab_add_test(&mut cfg, "a_test", make_variants());
        assert_eq!(
            cfg.test_names(),
            vec!["a_test".to_string(), "z_test".to_string()]
        );
    }

    #[test]
    fn test_zero_weight_returns_none() {
        let mut cfg = new_ab_test_config();
        cfg.add_test("empty_w", vec![variant("v", 0.0)]);
        assert_eq!(ab_select_variant(&cfg, "empty_w", 0.5), None);
    }

    #[test]
    fn test_uneven_weights() {
        /* 90/10 split: seed=95 should land in treatment */
        let mut cfg = new_ab_test_config();
        cfg.add_test("skewed", vec![variant("ctrl", 90.0), variant("treat", 10.0)]);
        assert_eq!(ab_select_variant(&cfg, "skewed", 95.0), Some("treat"));
    }

    #[test]
    fn test_seed_wraps_via_rem_euclid() {
        /* seed=150 mod 100 = 50, which is the first point of the treatment bucket */
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        assert_eq!(ab_select_variant(&cfg, "t", 150.0), Some("treatment"));
    }

    #[test]
    fn test_negative_seed_wraps() {
        /* rem_euclid maps -25 to 75, i.e. the treatment bucket */
        let mut cfg = new_ab_test_config();
        ab_add_test(&mut cfg, "t", make_variants());
        assert_eq!(ab_select_variant(&cfg, "t", -25.0), Some("treatment"));
    }
}