1use rand::Rng;
8use rand_distr::{Beta, Distribution};
9
/// Category of a signal.
///
/// The kind decides which weight from [`AggregationConfig`] is applied to the
/// signal; see [`SignalInput::weight`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SignalKind {
    /// Weighted with [`AggregationConfig::explicit_weight`].
    Explicit,
    /// Weighted with [`AggregationConfig::implicit_weight`].
    Implicit,
}
20
/// A single observation fed into [`aggregate`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SignalInput {
    /// Which weight class this signal belongs to.
    pub kind: SignalKind,
    /// Raw score of the signal; [`aggregate`] clamps it to `[0.0, 1.0]` before use.
    pub value: f64,
}
29
/// Per-kind weights used by [`aggregate`] when computing the weighted mean.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AggregationConfig {
    /// Weight of every [`SignalKind::Explicit`] signal (5.0 by default).
    pub explicit_weight: f64,
    /// Weight of every [`SignalKind::Implicit`] signal (1.0 by default).
    pub implicit_weight: f64,
}
38
39impl Default for AggregationConfig {
40 fn default() -> Self {
41 Self {
42 explicit_weight: 5.0,
43 implicit_weight: 1.0,
44 }
45 }
46}
47
48impl SignalInput {
49 pub fn weight(&self, config: &AggregationConfig) -> f64 {
51 match self.kind {
52 SignalKind::Explicit => config.explicit_weight,
53 SignalKind::Implicit => config.implicit_weight,
54 }
55 }
56}
57
58pub fn aggregate(signals: &[SignalInput], config: &AggregationConfig) -> f64 {
63 if signals.is_empty() {
64 return 0.5;
65 }
66 let mut numerator = 0.0;
67 let mut denominator = 0.0;
68 for s in signals {
69 let w = s.weight(config);
70 numerator += w * s.value.clamp(0.0, 1.0);
71 denominator += w;
72 }
73 (numerator / denominator).clamp(0.0, 1.0)
74}
75
/// Splits `scores` into `(wins, losses)` counts.
///
/// A score counts as a win when it is `>= 0.5`; everything else (including
/// NaN, which fails the comparison) counts as a loss.
fn wins_losses(scores: &[f64]) -> (u32, u32) {
    let mut win_count = 0usize;
    for &score in scores {
        if score >= 0.5 {
            win_count += 1;
        }
    }
    let wins = win_count as u32;
    let total = scores.len() as u32;
    (wins, total - wins)
}
82
/// Tunables for [`promotion_decision`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PromotionConfig {
    /// Minimum number of scores each arm must have before a posterior is
    /// computed; below this [`promotion_decision`] returns
    /// [`Decision::NeedMoreData`].
    pub min_sessions_per_arm: usize,
    /// The challenger is promoted when the estimated posterior is `>=` this value.
    pub promote_threshold: f64,
    /// Number of Monte Carlo draws passed to [`posterior_probability`].
    pub mc_samples: u32,
}
93
94impl Default for PromotionConfig {
95 fn default() -> Self {
96 Self {
97 min_sessions_per_arm: 20,
98 promote_threshold: 0.95,
99 mc_samples: 10_000,
100 }
101 }
102}
103
/// Outcome of [`promotion_decision`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Decision {
    /// At least one arm has fewer scores than the configured minimum, so no
    /// posterior was computed.
    NeedMoreData {
        /// The smaller of the two arms' score counts.
        sessions_each: usize,
        /// The configured `min_sessions_per_arm`.
        required: usize,
    },
    /// The posterior was below the promotion threshold.
    Hold {
        /// Estimated probability returned by [`posterior_probability`].
        posterior: f64,
    },
    /// The posterior reached the promotion threshold (`>=`).
    Promote {
        /// Estimated probability returned by [`posterior_probability`].
        posterior: f64,
    },
}
125
126pub fn promotion_decision<R: Rng>(
128 champion_scores: &[f64],
129 challenger_scores: &[f64],
130 config: &PromotionConfig,
131 rng: &mut R,
132) -> Decision {
133 let champ_n = champion_scores.len();
134 let chall_n = challenger_scores.len();
135 if champ_n < config.min_sessions_per_arm || chall_n < config.min_sessions_per_arm {
136 return Decision::NeedMoreData {
137 sessions_each: champ_n.min(chall_n),
138 required: config.min_sessions_per_arm,
139 };
140 }
141 let posterior =
142 posterior_probability(champion_scores, challenger_scores, config.mc_samples, rng);
143 if posterior >= config.promote_threshold {
144 Decision::Promote { posterior }
145 } else {
146 Decision::Hold { posterior }
147 }
148}
149
150pub fn posterior_probability<R: Rng>(
155 champion_scores: &[f64],
156 challenger_scores: &[f64],
157 samples: u32,
158 rng: &mut R,
159) -> f64 {
160 let (cw, cl) = wins_losses(champion_scores);
161 let (hw, hl) = wins_losses(challenger_scores);
162 let champ = Beta::new(1.0 + cw as f64, 1.0 + cl as f64).expect("valid Beta params");
163 let chall = Beta::new(1.0 + hw as f64, 1.0 + hl as f64).expect("valid Beta params");
164 let mut hits: u32 = 0;
165 for _ in 0..samples {
166 let a: f64 = champ.sample(rng);
167 let b: f64 = chall.sample(rng);
168 if b > a {
169 hits += 1;
170 }
171 }
172 hits as f64 / samples as f64
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // ---- fixtures -------------------------------------------------------

    /// Fixed-seed generator so the Monte Carlo based tests are reproducible.
    fn seeded_rng() -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(42)
    }

    /// An explicit signal carrying `value`.
    fn explicit(value: f64) -> SignalInput {
        SignalInput {
            kind: SignalKind::Explicit,
            value,
        }
    }

    /// An implicit signal carrying `value`.
    fn implicit(value: f64) -> SignalInput {
        SignalInput {
            kind: SignalKind::Implicit,
            value,
        }
    }

    /// Runs `aggregate` with the default weights.
    fn aggregate_with_defaults(signals: &[SignalInput]) -> f64 {
        aggregate(signals, &AggregationConfig::default())
    }

    /// `total` scores of which the first `wins` are 1.0 and the rest 0.0.
    fn binary_scores(total: usize, wins: usize) -> Vec<f64> {
        (0..total)
            .map(|i| if i < wins { 1.0 } else { 0.0 })
            .collect()
    }

    // ---- weights --------------------------------------------------------

    #[test]
    fn default_weights_are_five_to_one() {
        let AggregationConfig {
            explicit_weight,
            implicit_weight,
        } = AggregationConfig::default();
        assert_eq!(explicit_weight, 5.0);
        assert_eq!(implicit_weight, 1.0);
    }

    #[test]
    fn explicit_signal_weighs_five_times_implicit() {
        let cfg = AggregationConfig::default();
        let ratio = explicit(1.0).weight(&cfg) / implicit(1.0).weight(&cfg);
        assert_eq!(ratio, 5.0);
    }

    // ---- aggregate ------------------------------------------------------

    #[test]
    fn aggregate_empty_returns_neutral_half() {
        assert_eq!(aggregate_with_defaults(&[]), 0.5);
    }

    #[test]
    fn aggregate_single_explicit_1_is_1() {
        assert_eq!(aggregate_with_defaults(&[explicit(1.0)]), 1.0);
    }

    #[test]
    fn aggregate_single_implicit_0_is_0() {
        assert_eq!(aggregate_with_defaults(&[implicit(0.0)]), 0.0);
    }

    #[test]
    fn aggregate_clips_out_of_range_values() {
        assert_eq!(aggregate_with_defaults(&[implicit(2.0)]), 1.0);
    }

    #[test]
    fn aggregate_weighted_mean_matches_hand_calculation() {
        // (5 * 0 + 1 * 1 + 1 * 1) / (5 + 1 + 1) = 2 / 7
        let got = aggregate_with_defaults(&[explicit(0.0), implicit(1.0), implicit(1.0)]);
        assert!((got - 2.0 / 7.0).abs() < 1e-9, "got {got}");
    }

    #[test]
    fn aggregate_single_explicit_dominates_many_implicit() {
        // (5 * 0 + 3 * 1) / 8 = 0.375
        let got = aggregate_with_defaults(&[
            explicit(0.0),
            implicit(1.0),
            implicit(1.0),
            implicit(1.0),
        ]);
        assert!(
            got < 0.5,
            "explicit 0.0 should pull aggregate below 0.5, got {got}",
        );
    }

    // ---- posterior_probability -----------------------------------------

    #[test]
    fn posterior_obvious_challenger_win_exceeds_threshold() {
        let champion = binary_scores(20, 5);
        let challenger = binary_scores(20, 18);
        let p = posterior_probability(&champion, &challenger, 10_000, &mut seeded_rng());
        assert!(p > 0.95, "expected P(chall > champ) > 0.95, got {p}");
    }

    #[test]
    fn posterior_obvious_champion_win_stays_below_threshold() {
        let champion = binary_scores(20, 18);
        let challenger = binary_scores(20, 5);
        let p = posterior_probability(&champion, &challenger, 10_000, &mut seeded_rng());
        assert!(p < 0.05, "expected P(chall > champ) < 0.05, got {p}");
    }

    #[test]
    fn posterior_tied_evidence_stays_near_half() {
        let champion = binary_scores(40, 20);
        let challenger = binary_scores(40, 20);
        let p = posterior_probability(&champion, &challenger, 10_000, &mut seeded_rng());
        assert!((p - 0.5).abs() < 0.05, "expected P near 0.5, got {p}");
    }

    #[test]
    fn posterior_is_deterministic_under_same_seed() {
        let champion = vec![0.6; 10];
        let challenger = vec![0.7; 10];
        // Each run gets its own freshly seeded generator.
        let run = || posterior_probability(&champion, &challenger, 5_000, &mut seeded_rng());
        let p1 = run();
        let p2 = run();
        assert_eq!(p1, p2);
    }

    // ---- promotion_decision --------------------------------------------

    #[test]
    fn decision_needs_more_data_when_either_arm_is_thin() {
        let cfg = PromotionConfig::default();
        let champion = [1.0; 5];
        let challenger = [0.0; 20];
        let d = promotion_decision(&champion, &challenger, &cfg, &mut seeded_rng());
        assert!(matches!(d, Decision::NeedMoreData { .. }));
    }

    #[test]
    fn decision_promotes_obvious_winner() {
        let cfg = PromotionConfig::default();
        let champion = binary_scores(25, 5);
        let challenger = binary_scores(25, 23);
        let d = promotion_decision(&champion, &challenger, &cfg, &mut seeded_rng());
        let posterior = match d {
            Decision::Promote { posterior } => posterior,
            other => panic!("expected Promote, got {other:?}"),
        };
        assert!(
            posterior >= cfg.promote_threshold,
            "posterior {posterior} below threshold {}",
            cfg.promote_threshold,
        );
    }

    #[test]
    fn decision_holds_when_evidence_is_tied() {
        let cfg = PromotionConfig::default();
        let champion = binary_scores(30, 15);
        let challenger = binary_scores(30, 15);
        let d = promotion_decision(&champion, &challenger, &cfg, &mut seeded_rng());
        assert!(matches!(d, Decision::Hold { .. }));
    }

    #[test]
    fn decision_finishes_in_reasonable_time_for_realistic_input() {
        // All setup happens before the clock starts; only the call is timed.
        let cfg = PromotionConfig::default();
        let champion = binary_scores(100, 60);
        let challenger = binary_scores(100, 70);
        let mut r = seeded_rng();

        let start = std::time::Instant::now();
        let _ = promotion_decision(&champion, &challenger, &cfg, &mut r);
        let elapsed = start.elapsed();

        assert!(
            elapsed.as_millis() < 50,
            "promotion_decision took {elapsed:?}; expected < 50ms",
        );
    }
}
377
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Strategy for a score vector of length `0..max_n` whose entries are
    /// each either `0.0` or `1.0`.
    fn arb_scores(max_n: usize) -> impl Strategy<Value = Vec<f64>> {
        let binary_score = prop_oneof![Just(0.0_f64), Just(1.0_f64)];
        prop::collection::vec(binary_score, 0..max_n)
    }

    /// Strategy for a signal of arbitrary kind with a value in `-10.0..10.0`,
    /// i.e. deliberately including values outside `[0.0, 1.0]`.
    fn arb_signal() -> impl Strategy<Value = SignalInput> {
        let kind_flag = prop::bool::ANY;
        let raw_value = -10.0_f64..10.0_f64;
        (kind_flag, raw_value).prop_map(|(is_explicit, value)| {
            let kind = if is_explicit {
                SignalKind::Explicit
            } else {
                SignalKind::Implicit
            };
            SignalInput { kind, value }
        })
    }

    proptest! {
        // `aggregate` must stay inside [0, 1] for any mix of signals.
        #[test]
        fn aggregate_is_in_unit_interval(
            signals in prop::collection::vec(arb_signal(), 0..50),
        ) {
            let defaults = AggregationConfig::default();
            let out = aggregate(&signals, &defaults);
            prop_assert!((0.0..=1.0).contains(&out), "got {out}");
        }

        // A `Promote` decision always carries a posterior at or above the threshold.
        #[test]
        fn decision_never_promotes_below_threshold(
            champion in arb_scores(60),
            challenger in arb_scores(60),
        ) {
            let cfg = PromotionConfig::default();
            let mut rng = ChaCha8Rng::seed_from_u64(1);
            if let Decision::Promote { posterior } =
                promotion_decision(&champion, &challenger, &cfg, &mut rng)
            {
                prop_assert!(posterior >= cfg.promote_threshold);
            }
        }

        // The estimated posterior is a probability.
        #[test]
        fn posterior_is_in_unit_interval(
            champion in arb_scores(50),
            challenger in arb_scores(50),
        ) {
            let mut rng = ChaCha8Rng::seed_from_u64(7);
            let p = posterior_probability(&champion, &challenger, 1_000, &mut rng);
            prop_assert!((0.0..=1.0).contains(&p), "got {p}");
        }
    }
}