use crate::monitor::CusumCatBank;
use crate::{ContextBinConfig, ContextualCell, ContextualCoverageTracker, WorstFirstConfig};
use std::collections::BTreeMap;
/// Namespace for the integer outcome categories fed to the per-arm CUSUM banks.
///
/// The categories are mutually exclusive. When flags overlap, precedence is
/// failure > hard junk > soft junk > ok.
pub struct OutcomeIdx;

impl OutcomeIdx {
    /// Successful outcome with no junk flag.
    pub const OK: usize = 0;
    /// Successful outcome flagged as junk, but not hard junk.
    pub const SOFT_JUNK: usize = 1;
    /// Successful outcome flagged as hard junk (regardless of the soft `junk` flag).
    pub const HARD_JUNK: usize = 2;
    /// Unsuccessful outcome; overrides both junk flags.
    pub const FAIL: usize = 3;

    /// Collapses the three outcome flags into a single category index.
    ///
    /// `ok == false` always maps to [`Self::FAIL`]. Otherwise `hard_junk` takes
    /// priority over `junk`, and a clean success maps to [`Self::OK`].
    pub fn from_outcome(ok: bool, junk: bool, hard_junk: bool) -> usize {
        match (ok, hard_junk, junk) {
            (false, _, _) => Self::FAIL,
            (true, true, _) => Self::HARD_JUNK,
            (true, false, true) => Self::SOFT_JUNK,
            (true, false, false) => Self::OK,
        }
    }
}
/// Configuration consumed by `TriageSession::new`.
///
/// The four-element arrays are categorical distributions indexed by the
/// `OutcomeIdx` constants: `[OK, SOFT_JUNK, HARD_JUNK, FAIL]`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TriageSessionConfig {
/// Reference outcome distribution, passed as the first argument to every arm's
/// `CusumCatBank::new` (presumably the null / "in control" distribution — confirm there).
pub p0: [f64; 4],
/// Alternative outcome distributions handed to every arm's `CusumCatBank::new`
/// (presumably the shifted hypotheses tested against `p0` — confirm there).
pub alts: Vec<[f64; 4]>,
/// Forwarded unchanged to `CusumCatBank::new`; semantics are defined there.
pub cusum_alpha: f64,
/// Forwarded unchanged to `CusumCatBank::new`. Presumably the minimum number of
/// samples before an alarm can fire; lowering it makes alarms fire sooner — confirm.
pub min_n: u64,
/// Forwarded unchanged to `CusumCatBank::new`. Presumably the CUSUM alarm
/// threshold — confirm.
pub threshold: f64,
/// Forwarded unchanged to `CusumCatBank::new`. Presumably a numeric tolerance —
/// confirm there.
pub tol: f64,
/// Binning config used to map each observation's context vector to a bin.
pub bin_cfg: ContextBinConfig,
/// Worst-first selection config passed to the coverage tracker's `pick_k`.
pub wf_cfg: WorstFirstConfig,
/// Seed passed to the coverage tracker's `pick_k`.
pub seed: u64,
}
impl Default for TriageSessionConfig {
    /// Stock triage configuration.
    ///
    /// The reference distribution puts 85% of its mass on clean successes and splits
    /// the rest evenly over soft junk, hard junk and failures. The two alternatives
    /// shift mass onto hard junk and onto failures respectively.
    fn default() -> Self {
        // All distributions are indexed [OK, SOFT_JUNK, HARD_JUNK, FAIL].
        let baseline = [0.85, 0.05, 0.05, 0.05];
        let hard_junk_shift = [0.40, 0.10, 0.40, 0.10];
        let failure_shift = [0.40, 0.10, 0.10, 0.40];

        Self {
            p0: baseline,
            alts: vec![hard_junk_shift, failure_shift],
            cusum_alpha: 1e-3,
            min_n: 20,
            threshold: 5.0,
            tol: 1e-6,
            bin_cfg: ContextBinConfig::default(),
            // Hard junk is weighted 3x soft junk in worst-first cell selection.
            wf_cfg: WorstFirstConfig {
                exploration_c: 1.0,
                hard_weight: 3.0,
                soft_weight: 1.0,
            },
            seed: 0xCA_FE_BA_BE,
        }
    }
}
/// Point-in-time snapshot of one arm's change-detection state, as returned by
/// `TriageSession::arm_state`.
///
/// Derives `PartialEq` so snapshots can be compared directly, e.g. before and after a
/// reset. It does not derive `Eq` because `score_max` is an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ArmTriageState {
    /// Sample count reported by the arm's CUSUM bank (`CusumCatBank::n`); returns to 0
    /// after `TriageSession::reset_arm`.
    pub n: u64,
    /// Value of `CusumCatBank::score_max` at snapshot time; 0.0 right after a reset.
    pub score_max: f64,
    /// Latched alarm flag. Set once any bank update reports an alarm; cleared only by
    /// `TriageSession::reset_arm`.
    pub alarmed: bool,
}
/// Triage state for a fixed set of arms.
///
/// Each arm owns a `CusumCatBank` change detector plus a latched alarm flag. Every
/// observation is also recorded, per context bin, in a shared
/// `ContextualCoverageTracker`, which the cell-picking methods query.
#[derive(Debug, Clone)]
pub struct TriageSession {
// `banks`: arm name -> (CUSUM bank, alarm latch). BTreeMap keeps iteration sorted by arm name.
// `tracker`: coverage tracker fed by every `observe` call, including calls for unknown arms.
banks: BTreeMap<String, (CusumCatBank, bool)>, tracker: ContextualCoverageTracker,
// Binning config applied to each observation's context vector.
bin_cfg: ContextBinConfig,
// Worst-first selection config passed to the tracker's `pick_k`.
wf_cfg: WorstFirstConfig,
// Seed passed to the tracker's `pick_k`.
seed: u64,
}
impl TriageSession {
    /// Builds a session with one independent CUSUM bank per entry in `arms`.
    ///
    /// Every bank is created from the same `cfg.p0`, `cfg.alts`, `cfg.cusum_alpha`,
    /// `cfg.min_n`, `cfg.threshold` and `cfg.tol`, and starts with its alarm latch
    /// cleared. `bin_cfg`, `wf_cfg` and `seed` are kept for [`observe`](Self::observe)
    /// and the cell-picking methods. Duplicate names in `arms` collapse to one bank.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `CusumCatBank::new`, converted into
    /// `logp::Error`.
    pub fn new(arms: &[String], cfg: TriageSessionConfig) -> Result<Self, logp::Error> {
        let p0: Vec<f64> = cfg.p0.to_vec();
        let alts: Vec<Vec<f64>> = cfg.alts.iter().map(|a| a.to_vec()).collect();
        let mut banks = BTreeMap::new();
        for arm in arms {
            let bank = CusumCatBank::new(
                &p0,
                &alts,
                cfg.cusum_alpha,
                cfg.min_n,
                cfg.threshold,
                cfg.tol,
            )?;
            banks.insert(arm.clone(), (bank, false));
        }
        Ok(Self {
            banks,
            tracker: ContextualCoverageTracker::new(),
            bin_cfg: cfg.bin_cfg,
            wf_cfg: cfg.wf_cfg,
            seed: cfg.seed,
        })
    }

    /// Records one outcome for `arm` observed in `context`.
    ///
    /// `outcome_idx` is expected to be one of the [`OutcomeIdx`] constants. It is passed
    /// to `CusumCatBank::update` without validation.
    ///
    /// If `arm` has a bank, the bank is updated. If that update reports an alarm, the
    /// arm's alarm latch is set and stays set until [`reset_arm`](Self::reset_arm).
    ///
    /// Independently of that, the context is binned with the session's `bin_cfg` and
    /// the observation is recorded in the coverage tracker. This happens even when
    /// `arm` is unknown to the session. `OK` and `FAIL` outcomes are recorded with both
    /// junk flags cleared.
    ///
    /// Returns the bank's update for known arms and `None` for unknown arms.
    pub fn observe(
        &mut self,
        arm: &str,
        outcome_idx: usize,
        context: &[f64],
    ) -> Option<crate::monitor::CusumCatBankUpdate> {
        let bin = crate::context_bin(context, self.bin_cfg);
        let update = self.banks.get_mut(arm).map(|(bank, alarmed)| {
            let upd = bank.update(outcome_idx);
            // Latch: an alarm is never cleared by later non-alarming updates.
            *alarmed |= upd.alarmed;
            upd
        });
        let hard_junk = outcome_idx == OutcomeIdx::HARD_JUNK;
        let soft_junk = outcome_idx == OutcomeIdx::SOFT_JUNK;
        self.tracker.record(arm, bin, hard_junk, soft_junk);
        update
    }

    /// Returns the names of all arms whose alarm latch is set, in ascending name order.
    pub fn alarmed_arms(&self) -> Vec<String> {
        self.banks
            .iter()
            .filter(|(_, (_, alarmed))| *alarmed)
            .map(|(arm, _)| arm.clone())
            .collect()
    }

    /// Returns `true` if at least one arm's alarm latch is set.
    ///
    /// Unlike `!self.alarmed_arms().is_empty()`, this does not allocate.
    pub fn any_alarmed(&self) -> bool {
        self.banks.values().any(|(_, alarmed)| *alarmed)
    }

    /// Returns a snapshot of `arm`'s sample count, maximum CUSUM score and alarm latch.
    ///
    /// Returns `None` if `arm` is unknown to the session.
    pub fn arm_state(&self, arm: &str) -> Option<ArmTriageState> {
        self.banks.get(arm).map(|(bank, alarmed)| ArmTriageState {
            n: bank.n(),
            score_max: bank.score_max(),
            alarmed: *alarmed,
        })
    }

    /// Asks the coverage tracker for `k` cells over `arms` x `active_bins`.
    ///
    /// Delegates to `ContextualCoverageTracker::pick_k` with the session's seed and
    /// worst-first config. The meaning of the `bool` in each returned pair is defined
    /// by `pick_k`.
    pub fn top_cells(
        &self,
        arms: &[String],
        active_bins: &[u64],
        k: usize,
    ) -> Vec<(ContextualCell, bool)> {
        self.tracker
            .pick_k(self.seed, arms, active_bins, k, self.wf_cfg)
    }

    /// Like [`top_cells`](Self::top_cells), but restricted to the arms whose alarm
    /// latch is set (see [`alarmed_arms`](Self::alarmed_arms)).
    ///
    /// Returns an empty vector without consulting the tracker when no arm is alarmed.
    pub fn top_alarmed_cells(&self, active_bins: &[u64], k: usize) -> Vec<(ContextualCell, bool)> {
        let arms = self.alarmed_arms();
        if arms.is_empty() {
            return Vec::new();
        }
        self.top_cells(&arms, active_bins, k)
    }

    /// Resets `arm`'s CUSUM bank and clears its alarm latch.
    ///
    /// Does nothing for unknown arms. Only the change detector is reset; observations
    /// already recorded in the coverage tracker are kept.
    pub fn reset_arm(&mut self, arm: &str) {
        if let Some((bank, alarmed)) = self.banks.get_mut(arm) {
            bank.reset();
            *alarmed = false;
        }
    }

    /// Returns read-only access to the coverage tracker that [`observe`](Self::observe) feeds.
    pub fn tracker(&self) -> &ContextualCoverageTracker {
        &self.tracker
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The arm names every test session is built over.
    fn two_arms() -> Vec<String> {
        ["arm_a", "arm_b"].iter().map(|s| s.to_string()).collect()
    }

    /// Session over `two_arms()` with the stock configuration.
    fn default_session() -> TriageSession {
        TriageSession::new(&two_arms(), TriageSessionConfig::default()).unwrap()
    }

    /// Session over `two_arms()` with a lowered warm-up and threshold, so short
    /// synthetic floods are enough to trip the detector.
    fn sensitive_session(min_n: u64, threshold: f64) -> TriageSession {
        let cfg = TriageSessionConfig {
            min_n,
            threshold,
            ..TriageSessionConfig::default()
        };
        TriageSession::new(&two_arms(), cfg).unwrap()
    }

    /// Feeds `times` identical observations for `arm` into `session`.
    fn feed(session: &mut TriageSession, arm: &str, outcome: usize, ctx: &[f64], times: usize) {
        for _ in 0..times {
            session.observe(arm, outcome, ctx);
        }
    }

    /// Builds a sensitive session in which `arm_a` saw 30 clean successes and `arm_b`
    /// saw 30 hard-junk outcomes in a different context.
    fn arm_b_flooded() -> TriageSession {
        let mut session = sensitive_session(10, 3.0);
        feed(&mut session, "arm_a", OutcomeIdx::OK, &[0.2, 0.3], 30);
        feed(&mut session, "arm_b", OutcomeIdx::HARD_JUNK, &[0.7, 0.8], 30);
        session
    }

    #[test]
    fn new_session_no_alarms() {
        let session = default_session();
        assert!(session.alarmed_arms().is_empty());
        assert!(!session.any_alarmed());
    }

    #[test]
    fn observe_returns_none_for_unknown_arm() {
        let mut session = default_session();
        assert!(session.observe("unknown", OutcomeIdx::OK, &[0.5]).is_none());
    }

    #[test]
    fn outcome_idx_mapping() {
        let cases = [
            ((true, false, false), OutcomeIdx::OK),
            ((true, true, false), OutcomeIdx::SOFT_JUNK),
            ((true, true, true), OutcomeIdx::HARD_JUNK),
            ((false, false, false), OutcomeIdx::FAIL),
        ];
        for &((ok, junk, hard_junk), expected) in cases.iter() {
            assert_eq!(OutcomeIdx::from_outcome(ok, junk, hard_junk), expected);
        }
    }

    #[test]
    fn hard_junk_flood_triggers_alarm() {
        let alarmed = arm_b_flooded().alarmed_arms();
        assert!(alarmed.iter().any(|a| a == "arm_b"), "arm_b should alarm");
        assert!(!alarmed.iter().any(|a| a == "arm_a"), "arm_a should not alarm");
    }

    #[test]
    fn top_alarmed_cells_targets_bad_arm() {
        let session = arm_b_flooded();
        let bins = session.tracker().active_bins();
        let picks = session.top_alarmed_cells(&bins, 2);
        assert!(!picks.is_empty(), "should have triage picks");
        let (top, _) = &picks[0];
        assert_eq!(top.arm, "arm_b", "top cell should be the alarmed arm");
    }

    #[test]
    fn reset_arm_clears_alarm() {
        let mut session = sensitive_session(5, 2.0);
        feed(&mut session, "arm_b", OutcomeIdx::HARD_JUNK, &[0.5], 20);
        assert!(session.any_alarmed());

        session.reset_arm("arm_b");
        assert!(!session.any_alarmed(), "alarm should clear after reset");

        let state = session.arm_state("arm_b").unwrap();
        assert_eq!(state.n, 0, "CUSUM n should reset to 0");
        assert_eq!(state.score_max, 0.0, "CUSUM score should reset to 0");
    }

    #[test]
    fn top_cells_empty_when_no_bins() {
        let session = default_session();
        assert!(session.top_cells(&two_arms(), &[], 5).is_empty());
    }

    #[test]
    fn cell_tracker_populated_after_observe() {
        let mut session = default_session();
        session.observe("arm_a", OutcomeIdx::OK, &[0.1, 0.2]);
        assert_eq!(session.tracker().total_calls(), 1);
    }
}