use irithyll_core::learner::StreamingLearner;
/// Capacity of a `ChampionCohort`.
///
/// Once this many members are present, `try_enter` only admits a candidate that beats a
/// strict majority (`COHORT_K / 2 + 1`) of them.
pub const COHORT_K: usize = 3;
/// How `ChampionCohort::predict` combines the predictions of its members.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CohortWeight {
    /// Arithmetic mean of all member predictions. This is the default policy.
    #[default]
    Uniform,
    /// Each member is weighted by how far its `ci_lo` lies above the smallest `ci_lo` in
    /// the cohort (plus a tiny epsilon), normalised to sum to one.
    CiLower,
}
/// One learner held by a `ChampionCohort`, together with its ranking statistics.
pub struct CohortMember {
    /// The learner itself; trained by `train_all` and queried by `predict`.
    pub model: Box<dyn StreamingLearner>,
    /// Ranking score. Lower is better: the cohort keeps members sorted ascending on it.
    pub metric: f64,
    /// Only read by the `CohortWeight::CiLower` blending policy.
    /// NOTE(review): presumably the lower bound of a confidence interval — confirm with caller.
    pub ci_lo: f64,
    /// Human-readable label; copied verbatim into snapshots.
    pub factory_name: String,
    /// Counter bumped by `train_all` and by `update_member_metric`; starts at 0 on entry.
    pub samples_in_cohort: u64,
}
impl std::fmt::Debug for CohortMember {
    /// Prints the bookkeeping fields of the member.
    ///
    /// The boxed `model` is not printed. Output ends with `..` to signal that a field was omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `CohortMember` forces a revisit here.
        let Self {
            model: _,
            metric,
            ci_lo,
            factory_name,
            samples_in_cohort,
        } = self;
        let mut builder = formatter.debug_struct("CohortMember");
        builder.field("factory_name", factory_name);
        builder.field("metric", metric);
        builder.field("ci_lo", ci_lo);
        builder.field("samples_in_cohort", samples_in_cohort);
        builder.finish_non_exhaustive()
    }
}
/// Owned, model-free copy of a cohort member's statistics, as returned by
/// `ChampionCohort::snapshots`.
#[derive(Clone, Debug)]
pub struct CohortMemberSnapshot {
    /// Label of the factory that produced the member.
    pub factory_name: String,
    /// The member's ranking metric (lower is better).
    pub metric: f64,
    /// The member's `ci_lo` value at snapshot time.
    pub ci_lo: f64,
    /// The member's `samples_in_cohort` counter at snapshot time.
    pub samples_in_cohort: u64,
}
/// A small, ranked ensemble of at most `COHORT_K` learners whose predictions are blended.
pub struct ChampionCohort {
    /// Members, kept sorted ascending by `metric` (index 0 is the best).
    members: Vec<CohortMember>,
    /// Blending policy used by `predict`.
    weight_policy: CohortWeight,
    /// Number of `try_enter` calls, accepted or not.
    pub challenges: u64,
    /// Number of times a full cohort evicted its worst member to admit a candidate.
    pub demotions: u64,
}
impl std::fmt::Debug for ChampionCohort {
    /// Summarises the cohort.
    ///
    /// Only the member count is printed, not the members themselves.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            members,
            weight_policy,
            challenges,
            demotions,
        } = self;
        let n_members = members.len();
        out.debug_struct("ChampionCohort")
            .field("n_members", &n_members)
            .field("weight_policy", weight_policy)
            .field("challenges", challenges)
            .field("demotions", demotions)
            .finish()
    }
}
impl ChampionCohort {
pub fn new(weight_policy: CohortWeight) -> Self {
Self {
members: Vec::with_capacity(COHORT_K),
weight_policy,
challenges: 0,
demotions: 0,
}
}
pub fn with_uniform_weights() -> Self {
Self::new(CohortWeight::Uniform)
}
pub fn len(&self) -> usize {
self.members.len()
}
pub fn is_empty(&self) -> bool {
self.members.is_empty()
}
pub fn is_full(&self) -> bool {
self.members.len() >= COHORT_K
}
pub fn predict(&self, features: &[f64]) -> f64 {
if self.members.is_empty() {
return f64::NAN;
}
match self.weight_policy {
CohortWeight::Uniform => {
let sum: f64 = self.members.iter().map(|m| m.model.predict(features)).sum();
sum / self.members.len() as f64
}
CohortWeight::CiLower => {
let min_ci = self
.members
.iter()
.map(|m| m.ci_lo)
.fold(f64::INFINITY, f64::min);
let eps = 1e-12;
let weights: Vec<f64> = self
.members
.iter()
.map(|m| m.ci_lo - min_ci + eps)
.collect();
let w_sum: f64 = weights.iter().sum();
if w_sum < 1e-30 {
let sum: f64 = self.members.iter().map(|m| m.model.predict(features)).sum();
return sum / self.members.len() as f64;
}
self.members
.iter()
.zip(weights.iter())
.map(|(m, &w)| m.model.predict(features) * w / w_sum)
.sum()
}
}
}
pub fn snapshots(&self) -> Vec<CohortMemberSnapshot> {
self.members
.iter()
.map(|m| CohortMemberSnapshot {
factory_name: m.factory_name.clone(),
metric: m.metric,
ci_lo: m.ci_lo,
samples_in_cohort: m.samples_in_cohort,
})
.collect()
}
pub fn best_metric(&self) -> f64 {
self.members
.first()
.map(|m| m.metric)
.unwrap_or(f64::INFINITY)
}
pub fn worst_metric(&self) -> f64 {
self.members
.last()
.map(|m| m.metric)
.unwrap_or(f64::INFINITY)
}
pub fn try_enter(
&mut self,
candidate: Box<dyn StreamingLearner>,
candidate_metric: f64,
candidate_ci_lo: f64,
factory_name: String,
) -> bool {
self.challenges += 1;
if !self.is_full() {
self.members.push(CohortMember {
model: candidate,
metric: candidate_metric,
ci_lo: candidate_ci_lo,
factory_name,
samples_in_cohort: 0,
});
self.sort_members();
return true;
}
let beats: usize = self
.members
.iter()
.filter(|m| candidate_metric < m.metric)
.count();
let quorum = COHORT_K / 2 + 1;
if beats < quorum {
return false;
}
let worst_idx = self.members.len() - 1;
self.members.remove(worst_idx); self.members.push(CohortMember {
model: candidate,
metric: candidate_metric,
ci_lo: candidate_ci_lo,
factory_name,
samples_in_cohort: 0,
});
self.sort_members();
self.demotions += 1;
true
}
pub fn update_member_metric(&mut self, idx: usize, new_metric: f64) {
if let Some(m) = self.members.get_mut(idx) {
m.metric = new_metric;
m.samples_in_cohort += 1;
}
self.sort_members();
}
pub fn train_all(&mut self, features: &[f64], target: f64, weight: f64) {
for m in &mut self.members {
m.model.train_one(features, target, weight);
m.samples_in_cohort += 1;
}
}
pub fn reset(&mut self) {
self.members.clear();
self.challenges = 0;
self.demotions = 0;
}
fn sort_members(&mut self) {
self.members.sort_by(|a, b| {
a.metric
.partial_cmp(&b.metric)
.unwrap_or(std::cmp::Ordering::Equal)
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for every metric comparison below.
    const TOL: f64 = 1e-12;

    /// `true` when `actual` is within `TOL` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Stub learner: always predicts `value` and merely counts training calls.
    struct ConstantPredictor {
        value: f64,
        n: u64,
    }

    impl ConstantPredictor {
        fn boxed(value: f64) -> Box<dyn StreamingLearner> {
            Box::new(Self { value, n: 0 })
        }
    }

    impl StreamingLearner for ConstantPredictor {
        fn train_one(&mut self, _features: &[f64], _target: f64, _weight: f64) {
            self.n += 1;
        }
        fn predict(&self, _features: &[f64]) -> f64 {
            self.value
        }
        fn n_samples_seen(&self) -> u64 {
            self.n
        }
        fn reset(&mut self) {
            self.n = 0;
        }
    }

    /// Builds a uniform-weight cohort and offers it one constant-1.0 model per
    /// `(metric, name)` pair, in order, with `ci_lo = 0.0`.
    fn seeded_cohort(entries: &[(f64, &str)]) -> ChampionCohort {
        let mut cohort = ChampionCohort::with_uniform_weights();
        for &(metric, name) in entries {
            cohort.try_enter(ConstantPredictor::boxed(1.0), metric, 0.0, name.into());
        }
        cohort
    }

    #[test]
    fn cohort_demotes_lowest_when_new_candidate_enters() {
        let mut cohort = seeded_cohort(&[(0.5, "A"), (0.3, "B"), (0.8, "C")]);
        assert_eq!(cohort.len(), 3, "cohort should be full");
        let worst = cohort.worst_metric();
        assert!(close(worst, 0.8), "worst should be 0.8, got {}", worst);

        let entered = cohort.try_enter(ConstantPredictor::boxed(1.0), 0.1, 0.0, "D".into());
        assert!(
            entered,
            "candidate with metric 0.1 should beat all 3 and enter"
        );
        assert_eq!(cohort.len(), 3, "cohort size must remain COHORT_K=3");
        assert_eq!(cohort.demotions, 1, "exactly one demotion should occur");

        let (best, worst) = (cohort.best_metric(), cohort.worst_metric());
        assert!(
            close(worst, 0.5),
            "after demotion worst should be 0.5, got {}",
            worst
        );
        assert!(
            close(best, 0.1),
            "after entry best should be 0.1, got {}",
            best
        );
    }

    #[test]
    fn cohort_quorum_vote_demotes_majority() {
        let mut cohort = seeded_cohort(&[(0.2, "A"), (0.4, "B"), (0.6, "C")]);

        let entered_quorum =
            cohort.try_enter(ConstantPredictor::boxed(2.0), 0.35, 0.0, "D".into());
        assert!(
            entered_quorum,
            "candidate 0.35 beats 2 of 3 members — quorum (≥2) is met, must enter"
        );
        assert_eq!(cohort.demotions, 1, "one demotion for quorum entry");

        let metrics: Vec<f64> = cohort.snapshots().into_iter().map(|s| s.metric).collect();
        let holds = |target: f64| metrics.iter().any(|&m| close(m, target));
        assert!(holds(0.35), "0.35 must be in the cohort: {:?}", metrics);
        assert!(!holds(0.6), "0.6 must have been evicted: {:?}", metrics);

        let entered_below_quorum =
            cohort.try_enter(ConstantPredictor::boxed(3.0), 0.45, 0.0, "E".into());
        assert!(
            !entered_below_quorum,
            "candidate 0.45 beats 0 of 3 members — below quorum, must be rejected"
        );
        assert_eq!(
            cohort.demotions, 1,
            "no additional demotion when candidate is rejected"
        );
        assert_eq!(cohort.len(), 3, "cohort must remain at COHORT_K=3");
    }

    /// Despite its name, this checks that the `samples_in_cohort` counter keeps
    /// accumulating across two consecutive rounds of `train_all` calls.
    #[test]
    fn ewma_does_not_reset_at_round_boundary() {
        let mut cohort = ChampionCohort::with_uniform_weights();
        cohort.try_enter(ConstantPredictor::boxed(0.5), 0.3, 0.0, "A".into());
        let features = [1.0_f64];

        let mut per_round = [0_u64; 2];
        for slot in per_round.iter_mut() {
            (0..50).for_each(|_| cohort.train_all(&features, 1.0, 1.0));
            *slot = cohort.members[0].samples_in_cohort;
        }
        let [after_round1, after_round2] = per_round;

        assert_eq!(
            after_round1, 50,
            "samples_in_cohort after round 1 should be 50, got {}",
            after_round1
        );
        assert_eq!(
            after_round2, 100,
            "samples_in_cohort after round 2 should be 100 (no reset), got {}",
            after_round2
        );
    }
}