Skip to main content

evolve_core/
promotion.rs

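//! Signal aggregation + Bayesian champion-vs-challenger promotion math.
//!
//! No I/O. The functions here are pure apart from advancing the RNG that
//! callers pass to the Monte Carlo helpers, so a seeded RNG gives reproducible
//! results. Callers (CLI, adapters) translate `evolve_storage::signals::Signal`
//! rows into [`SignalInput`] before calling into this module.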
6
7use rand::Rng;
8use rand_distr::{Beta, Distribution};
9
/// Provenance of a fitness signal: either the user graded the session
/// themselves (`evolve good`/`bad`/`thumbs`), or an adapter inferred the
/// grade from the session log.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignalKind {
    /// The user graded the session explicitly.
    Explicit,
    /// An adapter inferred the grade from its session log.
    Implicit,
}
20
21/// One normalized fitness signal feeding into aggregation.
22#[derive(Debug, Clone, Copy, PartialEq)]
23pub struct SignalInput {
24    /// Source category -- controls weighting.
25    pub kind: SignalKind,
26    /// Score in `[0.0, 1.0]`. Out-of-range values are clamped before weighting.
27    pub value: f64,
28}
29
/// Weights the aggregator assigns to each kind of signal.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AggregationConfig {
    /// Multiplier for signals the user contributed explicitly (default `5.0`).
    pub explicit_weight: f64,
    /// Multiplier for signals an adapter inferred implicitly (default `1.0`).
    pub implicit_weight: f64,
}

impl Default for AggregationConfig {
    /// An explicit grade counts five times as much as an implicit one.
    fn default() -> Self {
        const EXPLICIT: f64 = 5.0;
        const IMPLICIT: f64 = 1.0;
        Self {
            implicit_weight: IMPLICIT,
            explicit_weight: EXPLICIT,
        }
    }
}
47
48impl SignalInput {
49    /// Weight this signal carries under the given aggregation config.
50    pub fn weight(&self, config: &AggregationConfig) -> f64 {
51        match self.kind {
52            SignalKind::Explicit => config.explicit_weight,
53            SignalKind::Implicit => config.implicit_weight,
54        }
55    }
56}
57
58/// Collapse a session's signals into a single fitness score in `[0.0, 1.0]`.
59///
60/// Uses the weighted arithmetic mean. Values are clamped to `[0.0, 1.0]`
61/// before weighting. Empty input returns `0.5` (neutral prior).
62pub fn aggregate(signals: &[SignalInput], config: &AggregationConfig) -> f64 {
63    if signals.is_empty() {
64        return 0.5;
65    }
66    let mut numerator = 0.0;
67    let mut denominator = 0.0;
68    for s in signals {
69        let w = s.weight(config);
70        numerator += w * s.value.clamp(0.0, 1.0);
71        denominator += w;
72    }
73    (numerator / denominator).clamp(0.0, 1.0)
74}
75
/// Binarize `scores` at 0.5 and tally them as `(wins, losses)`.
///
/// A score is a win when it is `>= 0.5`. Everything else counts as a loss,
/// including `NaN`, which fails every comparison.
fn wins_losses(scores: &[f64]) -> (u32, u32) {
    let total = scores.len() as u32;
    let wins = scores.iter().copied().filter(|score| *score >= 0.5).count() as u32;
    (wins, total - wins)
}
82
/// Tunables for [`promotion_decision`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PromotionConfig {
    /// Each arm (champion and challenger) must have at least this many
    /// sessions before a posterior is computed at all.
    pub min_sessions_per_arm: usize,
    /// The challenger is promoted once the estimated posterior reaches this
    /// value (inclusive).
    pub promote_threshold: f64,
    /// Number of Monte Carlo draws used by [`posterior_probability`].
    pub mc_samples: u32,
}

impl Default for PromotionConfig {
    /// 20 sessions per arm, a 0.95 promotion threshold, and 10 000 draws.
    fn default() -> Self {
        let min_sessions_per_arm = 20;
        let promote_threshold = 0.95;
        let mc_samples = 10_000;
        Self {
            min_sessions_per_arm,
            promote_threshold,
            mc_samples,
        }
    }
}
103
/// Result of evaluating a champion-vs-challenger experiment.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Decision {
    /// Too few sessions in at least one arm to decide yet.
    NeedMoreData {
        /// Session count of the arm with fewer sessions.
        sessions_each: usize,
        /// Minimum session count required in every arm.
        required: usize,
    },
    /// Enough data was collected, but the posterior is below the promotion
    /// threshold; keep the experiment running.
    Hold {
        /// Estimated `P(challenger > champion)`.
        posterior: f64,
    },
    /// The posterior reached the threshold; promote the challenger.
    Promote {
        /// Estimated `P(challenger > champion)`.
        posterior: f64,
    },
}
125
126/// Evaluate whether the challenger should be promoted, held, or needs more data.
127pub fn promotion_decision<R: Rng>(
128    champion_scores: &[f64],
129    challenger_scores: &[f64],
130    config: &PromotionConfig,
131    rng: &mut R,
132) -> Decision {
133    let champ_n = champion_scores.len();
134    let chall_n = challenger_scores.len();
135    if champ_n < config.min_sessions_per_arm || chall_n < config.min_sessions_per_arm {
136        return Decision::NeedMoreData {
137            sessions_each: champ_n.min(chall_n),
138            required: config.min_sessions_per_arm,
139        };
140    }
141    let posterior =
142        posterior_probability(champion_scores, challenger_scores, config.mc_samples, rng);
143    if posterior >= config.promote_threshold {
144        Decision::Promote { posterior }
145    } else {
146        Decision::Hold { posterior }
147    }
148}
149
150/// Monte Carlo estimate of `P(challenger > champion)` under beta-binomial
151/// posteriors: `Beta(1 + wins, 1 + losses)` per arm (uniform Jeffreys-like prior).
152///
153/// Each session score is binarized at 0.5 before counting.
154pub fn posterior_probability<R: Rng>(
155    champion_scores: &[f64],
156    challenger_scores: &[f64],
157    samples: u32,
158    rng: &mut R,
159) -> f64 {
160    let (cw, cl) = wins_losses(champion_scores);
161    let (hw, hl) = wins_losses(challenger_scores);
162    let champ = Beta::new(1.0 + cw as f64, 1.0 + cl as f64).expect("valid Beta params");
163    let chall = Beta::new(1.0 + hw as f64, 1.0 + hl as f64).expect("valid Beta params");
164    let mut hits: u32 = 0;
165    for _ in 0..samples {
166        let a: f64 = champ.sample(rng);
167        let b: f64 = chall.sample(rng);
168        if b > a {
169            hits += 1;
170        }
171    }
172    hits as f64 / samples as f64
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Fixed-seed RNG so every Monte Carlo assertion below is reproducible.
    fn fixed_rng() -> ChaCha8Rng {
        ChaCha8Rng::seed_from_u64(42)
    }

    /// `total` binary session scores: `wins` leading `1.0`s, then `0.0`s.
    fn binary_scores(total: usize, wins: usize) -> Vec<f64> {
        (0..total).map(|i| if i < wins { 1.0 } else { 0.0 }).collect()
    }

    /// Explicit (user-graded) signal carrying `value`.
    fn explicit(value: f64) -> SignalInput {
        SignalInput {
            kind: SignalKind::Explicit,
            value,
        }
    }

    /// Implicit (adapter-inferred) signal carrying `value`.
    fn implicit(value: f64) -> SignalInput {
        SignalInput {
            kind: SignalKind::Implicit,
            value,
        }
    }

    #[test]
    fn default_weights_are_five_to_one() {
        let AggregationConfig {
            explicit_weight,
            implicit_weight,
        } = AggregationConfig::default();
        assert_eq!(explicit_weight, 5.0);
        assert_eq!(implicit_weight, 1.0);
    }

    #[test]
    fn explicit_signal_weighs_five_times_implicit() {
        let cfg = AggregationConfig::default();
        let ratio = explicit(1.0).weight(&cfg) / implicit(1.0).weight(&cfg);
        assert_eq!(ratio, 5.0);
    }

    #[test]
    fn aggregate_empty_returns_neutral_half() {
        let none: [SignalInput; 0] = [];
        assert_eq!(aggregate(&none, &AggregationConfig::default()), 0.5);
    }

    #[test]
    fn aggregate_single_explicit_1_is_1() {
        let cfg = AggregationConfig::default();
        assert_eq!(aggregate(&[explicit(1.0)], &cfg), 1.0);
    }

    #[test]
    fn aggregate_single_implicit_0_is_0() {
        let cfg = AggregationConfig::default();
        assert_eq!(aggregate(&[implicit(0.0)], &cfg), 0.0);
    }

    #[test]
    fn aggregate_clips_out_of_range_values() {
        // 2.0 is clamped down to 1.0 before weighting.
        let cfg = AggregationConfig::default();
        assert_eq!(aggregate(&[implicit(2.0)], &cfg), 1.0);
    }

    #[test]
    fn aggregate_weighted_mean_matches_hand_calculation() {
        // One explicit 0.0 (weight 5) plus two implicit 1.0s (weight 1 each):
        // (5*0 + 1*1 + 1*1) / (5 + 1 + 1) = 2/7.
        let signals = [explicit(0.0), implicit(1.0), implicit(1.0)];
        let got = aggregate(&signals, &AggregationConfig::default());
        let expected = 2.0 / 7.0;
        assert!((got - expected).abs() < 1e-9, "got {got}");
    }

    #[test]
    fn posterior_obvious_challenger_win_exceeds_threshold() {
        // 5/20 champion wins vs 18/20 challenger wins.
        let champion = binary_scores(20, 5);
        let challenger = binary_scores(20, 18);
        let p = posterior_probability(&champion, &challenger, 10_000, &mut fixed_rng());
        assert!(p > 0.95, "expected P(chall > champ) > 0.95, got {p}");
    }

    #[test]
    fn posterior_obvious_champion_win_stays_below_threshold() {
        // Mirror image of the test above.
        let champion = binary_scores(20, 18);
        let challenger = binary_scores(20, 5);
        let p = posterior_probability(&champion, &challenger, 10_000, &mut fixed_rng());
        assert!(p < 0.05, "expected P(chall > champ) < 0.05, got {p}");
    }

    #[test]
    fn posterior_tied_evidence_stays_near_half() {
        // Both arms: 20 wins out of 40.
        let even = binary_scores(40, 20);
        let p = posterior_probability(&even, &even, 10_000, &mut fixed_rng());
        assert!((p - 0.5).abs() < 0.05, "expected P near 0.5, got {p}");
    }

    #[test]
    fn posterior_is_deterministic_under_same_seed() {
        let champion = [0.6; 10];
        let challenger = [0.7; 10];
        let run = || posterior_probability(&champion, &challenger, 5_000, &mut fixed_rng());
        assert_eq!(run(), run());
    }

    #[test]
    fn decision_needs_more_data_when_either_arm_is_thin() {
        // The champion has only 5 of the 20 required sessions.
        let champion = [1.0; 5];
        let challenger = [0.0; 20];
        let d = promotion_decision(
            &champion,
            &challenger,
            &PromotionConfig::default(),
            &mut fixed_rng(),
        );
        assert!(matches!(d, Decision::NeedMoreData { .. }));
    }

    #[test]
    fn decision_promotes_obvious_winner() {
        let cfg = PromotionConfig::default();
        let champion = binary_scores(25, 5);
        let challenger = binary_scores(25, 23);
        match promotion_decision(&champion, &challenger, &cfg, &mut fixed_rng()) {
            Decision::Promote { posterior } => assert!(
                posterior >= cfg.promote_threshold,
                "posterior {posterior} below threshold {}",
                cfg.promote_threshold,
            ),
            other => panic!("expected Promote, got {other:?}"),
        }
    }

    #[test]
    fn decision_holds_when_evidence_is_tied() {
        // Both arms: 15 wins out of 30.
        let even = binary_scores(30, 15);
        let d = promotion_decision(&even, &even, &PromotionConfig::default(), &mut fixed_rng());
        assert!(matches!(d, Decision::Hold { .. }));
    }

    #[test]
    fn decision_finishes_in_reasonable_time_for_realistic_input() {
        let champion = binary_scores(100, 60);
        let challenger = binary_scores(100, 70);
        let cfg = PromotionConfig::default();
        let mut rng = fixed_rng();

        let started = std::time::Instant::now();
        let _ = promotion_decision(&champion, &challenger, &cfg, &mut rng);
        let elapsed = started.elapsed();

        // Generous cap: the budget is ~1ms in release, and debug builds on slow
        // CI can be 5-10x slower. Failing here is a real red flag, not a flake.
        assert!(
            elapsed.as_millis() < 50,
            "promotion_decision took {elapsed:?}; expected < 50ms",
        );
    }

    #[test]
    fn aggregate_single_explicit_dominates_many_implicit() {
        let signals = [explicit(0.0), implicit(1.0), implicit(1.0), implicit(1.0)];
        let got = aggregate(&signals, &AggregationConfig::default());
        // (0*5 + 1*1 + 1*1 + 1*1) / (5+1+1+1) = 3/8 = 0.375, below the 0.5 win line.
        assert!(
            got < 0.5,
            "explicit 0.0 should pull aggregate below 0.5, got {got}",
        );
    }
}
377
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Vectors of fewer than `max_n` scores, each exactly `0.0` or `1.0`.
    fn binary_scores(max_n: usize) -> impl Strategy<Value = Vec<f64>> {
        let score = prop_oneof![Just(0.0_f64), Just(1.0_f64)];
        prop::collection::vec(score, 0..max_n)
    }

    /// Signals of either kind with values drawn from `-10.0..10.0`, well beyond
    /// the nominal `[0, 1]` range, so clamping is exercised.
    fn any_signal() -> impl Strategy<Value = SignalInput> {
        (prop::bool::ANY, -10.0_f64..10.0_f64).prop_map(|(is_explicit, value)| {
            let kind = if is_explicit {
                SignalKind::Explicit
            } else {
                SignalKind::Implicit
            };
            SignalInput { kind, value }
        })
    }

    proptest! {
        /// `aggregate` always returns a value in `[0.0, 1.0]` for any inputs.
        #[test]
        fn aggregate_is_in_unit_interval(
            signals in prop::collection::vec(any_signal(), 0..50),
        ) {
            let cfg = AggregationConfig::default();
            let out = aggregate(&signals, &cfg);
            prop_assert!((0.0..=1.0).contains(&out), "got {out}");
        }

        /// `promotion_decision` never returns `Promote` with a posterior below threshold.
        #[test]
        fn decision_never_promotes_below_threshold(
            champion in binary_scores(60),
            challenger in binary_scores(60),
        ) {
            let cfg = PromotionConfig::default();
            let mut rng = ChaCha8Rng::seed_from_u64(1);
            match promotion_decision(&champion, &challenger, &cfg, &mut rng) {
                Decision::Promote { posterior } => {
                    prop_assert!(posterior >= cfg.promote_threshold);
                }
                Decision::Hold { .. } | Decision::NeedMoreData { .. } => {}
            }
        }

        /// Posterior probability is always in `[0.0, 1.0]`.
        #[test]
        fn posterior_is_in_unit_interval(
            champion in binary_scores(50),
            challenger in binary_scores(50),
        ) {
            let mut rng = ChaCha8Rng::seed_from_u64(7);
            let p = posterior_probability(&champion, &challenger, 1_000, &mut rng);
            prop_assert!((0.0..=1.0).contains(&p), "got {p}");
        }
    }
}
431            let p = posterior_probability(&champion, &challenger, 1_000, &mut r);
432            prop_assert!((0.0..=1.0).contains(&p), "got {p}");
433        }
434    }
435}