use crate::monitor::{DriftConfig, MonitoredWindow};
use crate::{
policy_fill_generic, select_mab_explain, select_mab_monitored_explain_with_summaries,
split_control_budget, worst_first_pick_k, ContextualCell, ControlConfig, CoverageConfig,
DriftMetric, LatencyGuardrailConfig, MonitoredMabConfig, Outcome, OutcomeIdx, PipelineOrder,
Summary, TriageSession, TriageSessionConfig, Window, WorstFirstConfig,
};
use std::collections::BTreeMap;
/// Configuration for a [`Router`].
///
/// Each field is forwarded to the stage of [`Router::select`] (or to the
/// constructor) that consumes it; see the per-field notes.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RouterConfig {
/// Bandit configuration used by the MAB stage. The monitored selector
/// receives the whole value; the unmonitored selector receives `mab.base`.
pub mab: MonitoredMabConfig,
/// Drift settings passed to the monitored MAB selector.
pub drift: DriftConfig,
/// Capacity of each arm's `Window`; values below 1 are raised to 1.
pub window_cap: usize,
/// When `true`, the router also keeps a `MonitoredWindow` per arm.
pub enable_monitoring: bool,
/// Baseline capacity of each `MonitoredWindow`; values below 1 are raised to 1.
pub baseline_cap: usize,
/// Recent capacity of each `MonitoredWindow`; values below 1 are raised to 1.
pub recent_cap: usize,
/// When `Some`, the router builds a `TriageSession` with this configuration.
pub triage_cfg: Option<TriageSessionConfig>,
/// Worst-first settings used when picking among alarmed arms.
pub triage_wf: WorstFirstConfig,
/// Fraction of the post-control budget offered to alarmed arms; clamped to
/// `[0.0, 1.0]` at selection time.
pub triage_fraction: f64,
/// Coverage settings forwarded to `policy_fill_generic`.
pub coverage: CoverageConfig,
/// Latency guardrail settings forwarded to `policy_fill_generic`.
pub guardrail: LatencyGuardrailConfig,
/// Pipeline ordering forwarded to `policy_fill_generic`.
pub pipeline_order: PipelineOrder,
/// Novelty flag forwarded to `policy_fill_generic`.
pub novelty_enabled: bool,
/// Control settings forwarded to `split_control_budget`.
pub control: ControlConfig,
}
impl Default for RouterConfig {
    /// Returns the baseline configuration: monitoring and triage disabled,
    /// novelty enabled, a 100-outcome window per arm, and the default
    /// coverage, guardrail, control and MAB settings.
    fn default() -> Self {
        // Drift detection defaults (only consulted when monitoring is enabled).
        let drift = DriftConfig {
            metric: DriftMetric::Hellinger,
            tol: 1e-9,
            min_baseline: 20,
            min_recent: 10,
        };
        // Worst-first weights: hard junk counts three times as much as soft junk.
        let triage_wf = WorstFirstConfig {
            exploration_c: 1.0,
            hard_weight: 3.0,
            soft_weight: 1.0,
        };

        Self {
            mab: MonitoredMabConfig::default(),
            drift,
            window_cap: 100,
            enable_monitoring: false,
            baseline_cap: 500,
            recent_cap: 50,
            triage_cfg: None,
            triage_wf,
            triage_fraction: 0.5,
            coverage: CoverageConfig::default(),
            guardrail: LatencyGuardrailConfig::default(),
            pipeline_order: PipelineOrder::NoveltyFirst,
            novelty_enabled: true,
            control: ControlConfig::default(),
        }
    }
}
impl RouterConfig {
    /// Enables per-arm monitoring with the given baseline and recent capacities.
    pub fn with_monitoring(self, baseline_cap: usize, recent_cap: usize) -> Self {
        Self {
            enable_monitoring: true,
            baseline_cap,
            recent_cap,
            ..self
        }
    }

    /// Enables triage with the default [`TriageSessionConfig`].
    pub fn with_triage(self) -> Self {
        self.with_triage_cfg(TriageSessionConfig::default())
    }

    /// Enables triage with an explicit configuration.
    pub fn with_triage_cfg(self, cfg: TriageSessionConfig) -> Self {
        Self {
            triage_cfg: Some(cfg),
            ..self
        }
    }

    /// Sets the per-arm window capacity.
    pub fn window_cap(self, cap: usize) -> Self {
        Self {
            window_cap: cap,
            ..self
        }
    }

    /// Enables coverage with the given minimum fraction and minimum call floor.
    pub fn with_coverage(self, min_fraction: f64, min_floor: u64) -> Self {
        let coverage = CoverageConfig {
            enabled: true,
            min_fraction,
            min_calls_floor: min_floor,
        };
        Self { coverage, ..self }
    }

    /// Installs a latency guardrail capped at `max_mean_ms`.
    ///
    /// The other guardrail fields are set to `require_measured: false` and
    /// `allow_fewer: true`, replacing whatever was configured before.
    pub fn with_guardrail(self, max_mean_ms: f64) -> Self {
        let guardrail = LatencyGuardrailConfig {
            max_mean_ms: Some(max_mean_ms),
            require_measured: false,
            allow_fewer: true,
        };
        Self { guardrail, ..self }
    }

    /// Replaces the control configuration with `ControlConfig::with_k(control_k)`.
    pub fn with_control(self, control_k: usize) -> Self {
        Self {
            control: ControlConfig::with_k(control_k),
            ..self
        }
    }
}
/// Operating mode of a router at a given moment.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum RouterMode {
    /// No arm is currently alarmed.
    Normal,
    /// At least one arm is alarmed.
    Triage {
        /// Names of the alarmed arms.
        alarmed_arms: Vec<String>,
    },
}

impl RouterMode {
    /// Returns `true` when the mode is [`RouterMode::Triage`].
    pub fn is_triage(&self) -> bool {
        match self {
            RouterMode::Triage { .. } => true,
            RouterMode::Normal => false,
        }
    }

    /// Returns the alarmed arm names, or an empty slice in normal mode.
    pub fn alarmed_arms(&self) -> &[String] {
        if let RouterMode::Triage { alarmed_arms } = self {
            alarmed_arms.as_slice()
        } else {
            &[]
        }
    }
}
/// Result of one [`Router::select`] call.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RouterDecision {
/// Selected arms, in the order they were added: control picks first, then
/// triage picks, then the arms chosen by the policy/MAB stage.
pub chosen: Vec<String>,
/// Router mode observed when the selection was made.
pub mode: RouterMode,
/// `prechosen` list reported by the policy plan; empty when the policy stage
/// was skipped.
pub prechosen: Vec<String>,
/// Arms returned by the control-budget split.
pub control_picks: Vec<String>,
/// `eligible` list reported by the policy plan; empty when the policy stage
/// was skipped.
pub mab_eligible: Vec<String>,
/// Alarmed contextual cells reported by the triage session; empty when no
/// triage pick was attempted or the session has no active bins.
pub triage_cells: Vec<ContextualCell>,
}
impl RouterDecision {
    /// Builds a decision that selects nothing, recording only `mode`.
    fn empty(mode: RouterMode) -> Self {
        Self {
            mode,
            chosen: vec![],
            prechosen: vec![],
            control_picks: vec![],
            mab_eligible: vec![],
            triage_cells: vec![],
        }
    }

    /// Returns the first chosen arm, or `None` when nothing was selected.
    pub fn primary(&self) -> Option<&str> {
        self.chosen.first().map(String::as_str)
    }
}
/// Stateful arm selector combining control picks, triage of alarmed arms and
/// a bandit/policy stage.
pub struct Router {
/// Registered arm names, in registration order.
arms: Vec<String>,
/// Per-arm outcome window; an arm is accepted by `observe` only if it has an
/// entry here.
windows: BTreeMap<String, Window>,
/// Per-arm monitored windows; `Some` only when monitoring is enabled.
monitored: Option<BTreeMap<String, MonitoredWindow>>,
/// Triage session; `Some` only when triage is configured and could be built.
triage: Option<TriageSession>,
/// Configuration the router was built with.
cfg: RouterConfig,
/// Number of observations accepted so far.
total_observations: u64,
}
impl Router {
pub fn new(arms: Vec<String>, cfg: RouterConfig) -> Result<Self, logp::Error> {
let windows: BTreeMap<String, Window> = arms
.iter()
.map(|a| (a.clone(), Window::new(cfg.window_cap.max(1))))
.collect();
let monitored = if cfg.enable_monitoring {
Some(
arms.iter()
.map(|a| {
(
a.clone(),
MonitoredWindow::new(cfg.baseline_cap.max(1), cfg.recent_cap.max(1)),
)
})
.collect(),
)
} else {
None
};
let triage = if let Some(ref tcfg) = cfg.triage_cfg {
Some(TriageSession::new(&arms, tcfg.clone())?)
} else {
None
};
Ok(Self {
arms,
windows,
monitored,
triage,
cfg,
total_observations: 0,
})
}
pub fn add_arm(&mut self, arm: String) -> Result<(), logp::Error> {
if self.arms.contains(&arm) {
return Ok(());
}
if let Some(ref mut t) = self.triage {
let tcfg = self.cfg.triage_cfg.clone().unwrap_or_default();
let mut all_arms = self.arms.clone();
all_arms.push(arm.clone());
*t = TriageSession::new(&all_arms, tcfg)?;
}
self.windows
.insert(arm.clone(), Window::new(self.cfg.window_cap.max(1)));
if let Some(ref mut m) = self.monitored {
m.insert(
arm.clone(),
MonitoredWindow::new(self.cfg.baseline_cap.max(1), self.cfg.recent_cap.max(1)),
);
}
self.arms.push(arm);
Ok(())
}
pub fn remove_arm(&mut self, arm: &str) -> Result<(), logp::Error> {
self.arms.retain(|a| a != arm);
self.windows.remove(arm);
if let Some(ref mut m) = self.monitored {
m.remove(arm);
}
if self.triage.is_some() {
let tcfg = self.cfg.triage_cfg.clone().unwrap_or_default();
match TriageSession::new(&self.arms, tcfg) {
Ok(t) => self.triage = Some(t),
Err(e) => {
self.triage = None;
return Err(e);
}
}
}
Ok(())
}
pub fn select(&self, k: usize, seed: u64) -> RouterDecision {
if k == 0 || self.arms.is_empty() {
return RouterDecision::empty(self.mode());
}
let mode = self.mode();
let alarmed = mode.alarmed_arms().to_vec();
let mut chosen: Vec<String> = Vec::new();
let mut triage_cells: Vec<ContextualCell> = Vec::new();
let (control_picks, remaining_k) =
split_control_budget(seed ^ 0xC0E1_1A11, &self.arms, k, self.cfg.control);
for a in &control_picks {
chosen.push(a.clone());
}
let remaining_arms: Vec<String> = self
.arms
.iter()
.filter(|a| !chosen.contains(*a))
.cloned()
.collect();
let triage_k = if !alarmed.is_empty() {
let frac = self.cfg.triage_fraction.clamp(0.0, 1.0);
((remaining_k as f64 * frac).ceil() as usize)
.max(1)
.min(alarmed.len())
.min(remaining_k.saturating_sub(1).max(1))
} else {
0
};
if triage_k > 0 {
let triage_arms: Vec<String> = alarmed
.iter()
.filter(|a| !chosen.contains(*a))
.cloned()
.collect();
let picks = worst_first_pick_k(
seed ^ 0x5452_4947,
&triage_arms,
triage_k,
self.cfg.triage_wf,
|arm| self.windows.get(arm).map(|w| w.len() as u64).unwrap_or(0),
|arm| {
let s = self
.windows
.get(arm)
.map(|w| w.summary())
.unwrap_or_default();
(s.calls, s.hard_junk_rate(), s.soft_junk_rate())
},
);
for (arm, _) in picks {
if chosen.len() < k {
chosen.push(arm);
}
}
if let Some(triage) = &self.triage {
let bins = triage.tracker().active_bins();
if !bins.is_empty() {
triage_cells = triage
.top_alarmed_cells(&bins, triage_k)
.into_iter()
.map(|(cell, _explore)| cell)
.collect();
}
}
}
let mab_k = k.saturating_sub(chosen.len());
let mab_arms: Vec<String> = remaining_arms
.iter()
.filter(|a| !chosen.contains(*a))
.cloned()
.collect();
if mab_k == 0 || mab_arms.is_empty() {
return RouterDecision {
chosen,
mode,
prechosen: Vec::new(),
control_picks,
mab_eligible: Vec::new(),
triage_cells,
};
}
let obs_snap: BTreeMap<String, (u64, u64)> = mab_arms
.iter()
.map(|a| {
let w = self.windows.get(a.as_str());
let calls = w.map(|w| w.len() as u64).unwrap_or(0);
let elapsed = w.map(|w| w.summary().elapsed_ms_sum).unwrap_or(0);
(a.clone(), (calls, elapsed))
})
.collect();
let sum_snap: BTreeMap<String, Summary> = mab_arms
.iter()
.map(|a| {
let s = self
.windows
.get(a.as_str())
.map(|w| w.summary())
.unwrap_or_default();
(a.clone(), s)
})
.collect();
let plan = policy_fill_generic(
seed ^ 0x504C_414E,
&mab_arms,
mab_k,
self.cfg.novelty_enabled,
self.cfg.coverage,
self.cfg.guardrail,
self.cfg.pipeline_order,
|arm| obs_snap.get(arm).copied().unwrap_or((0, 0)),
|eligible, need| {
self.select_mab_round(eligible, need, &sum_snap)
},
);
let prechosen = plan.plan.prechosen.clone();
let mab_eligible = plan.plan.eligible.clone();
for arm in &plan.chosen {
if chosen.len() < k {
chosen.push(arm.clone());
}
}
RouterDecision {
chosen,
mode,
prechosen,
control_picks,
mab_eligible,
triage_cells,
}
}
pub fn observe(&mut self, arm: &str, outcome: Outcome) -> bool {
self.observe_with_context(arm, outcome, &[])
}
pub fn observe_with_context(&mut self, arm: &str, outcome: Outcome, context: &[f64]) -> bool {
let known = self.windows.contains_key(arm);
if !known {
return false;
}
if let Some(w) = self.windows.get_mut(arm) {
w.push(outcome);
}
if let Some(ref mut m) = self.monitored {
if let Some(mw) = m.get_mut(arm) {
mw.push(outcome);
}
}
if let Some(ref mut t) = self.triage {
let idx = OutcomeIdx::from_outcome(outcome.ok, outcome.junk, outcome.hard_junk);
t.observe(arm, idx, context);
}
self.total_observations += 1;
true
}
pub fn set_last_quality_score(&mut self, arm: &str, score: f64) {
if let Some(w) = self.windows.get_mut(arm) {
w.set_last_quality_score(score);
}
if let Some(ref mut m) = self.monitored {
if let Some(mw) = m.get_mut(arm) {
mw.set_last_quality_score(score);
}
}
}
pub fn set_last_junk_level(&mut self, arm: &str, junk: bool, hard_junk: bool) {
if let Some(w) = self.windows.get_mut(arm) {
w.set_last_junk_level(junk, hard_junk);
}
if let Some(ref mut m) = self.monitored {
if let Some(mw) = m.get_mut(arm) {
mw.set_last_junk_level(junk, hard_junk);
}
}
}
pub fn acknowledge_change(&mut self, arm: &str) {
if let Some(ref mut t) = self.triage {
t.reset_arm(arm);
}
if let Some(ref mut m) = self.monitored {
if let Some(mw) = m.get_mut(arm) {
mw.acknowledge_change();
}
}
}
pub fn acknowledge_all_changes(&mut self) {
let alarmed = self.mode().alarmed_arms().to_vec();
for arm in &alarmed {
self.acknowledge_change(arm);
}
}
pub fn mode(&self) -> RouterMode {
let alarmed = self
.triage
.as_ref()
.map(|t| t.alarmed_arms())
.unwrap_or_default();
if alarmed.is_empty() {
RouterMode::Normal
} else {
RouterMode::Triage {
alarmed_arms: alarmed,
}
}
}
pub fn arms(&self) -> &[String] {
&self.arms
}
pub fn total_observations(&self) -> u64 {
self.total_observations
}
pub fn summary(&self, arm: &str) -> Summary {
self.windows
.get(arm)
.map(|w| w.summary())
.unwrap_or_default()
}
pub fn summaries(&self) -> BTreeMap<String, Summary> {
self.windows
.iter()
.map(|(arm, w)| (arm.clone(), w.summary()))
.collect()
}
pub fn mean_quality_score(&self, arm: &str) -> Option<f64> {
self.windows.get(arm)?.mean_quality_score()
}
pub fn window_len(&self, arm: &str) -> usize {
self.windows.get(arm).map(|w| w.len()).unwrap_or(0)
}
pub fn window(&self, arm: &str) -> Option<&Window> {
self.windows.get(arm)
}
pub fn monitored_window(&self, arm: &str) -> Option<&MonitoredWindow> {
self.monitored.as_ref()?.get(arm)
}
pub fn triage_session(&self) -> Option<&TriageSession> {
self.triage.as_ref()
}
fn select_mab_round(
&self,
eligible: &[String],
need: usize,
sum_snap: &BTreeMap<String, Summary>,
) -> Vec<String> {
if eligible.is_empty() || need == 0 {
return Vec::new();
}
let monitored = self.monitored.as_ref();
let mut remaining: Vec<String> = eligible.to_vec();
let mut picks: Vec<String> = Vec::new();
for _round in 0..need {
if remaining.is_empty() {
break;
}
let d = if let Some(mon) = monitored {
select_mab_monitored_explain_with_summaries(
&remaining,
sum_snap,
mon,
self.cfg.drift,
self.cfg.mab.clone(),
)
} else {
select_mab_explain(&remaining, sum_snap, self.cfg.mab.base.clone())
};
let pick = d.selection.chosen.clone();
remaining.retain(|a| a != &pick);
picks.push(pick);
}
picks
}
}
/// Copy of a router's persistent state, produced by `Router::snapshot` and
/// consumed by `Router::from_snapshot`.
///
/// The triage session is not part of the snapshot; `Router::from_snapshot`
/// creates a new one from `cfg.triage_cfg`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RouterSnapshot {
/// Registered arm names, in registration order.
pub arms: Vec<String>,
/// Per-arm outcome windows.
pub windows: BTreeMap<String, Window>,
/// Per-arm monitored windows, when monitoring was enabled.
pub monitored: Option<BTreeMap<String, MonitoredWindow>>,
/// Configuration of the router that produced the snapshot.
pub cfg: RouterConfig,
/// Number of observations accepted up to the snapshot.
pub total_observations: u64,
}
impl Router {
pub fn snapshot(&self) -> RouterSnapshot {
RouterSnapshot {
arms: self.arms.clone(),
windows: self.windows.clone(),
monitored: self.monitored.clone(),
cfg: self.cfg.clone(),
total_observations: self.total_observations,
}
}
pub fn from_snapshot(snap: RouterSnapshot) -> Result<Self, logp::Error> {
let triage = if let Some(ref tcfg) = snap.cfg.triage_cfg {
Some(TriageSession::new(&snap.arms, tcfg.clone())?)
} else {
None
};
Ok(Self {
arms: snap.arms,
windows: snap.windows,
monitored: snap.monitored,
triage,
cfg: snap.cfg,
total_observations: snap.total_observations,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Outcome;

    /// Builds `n` arm names: `arm0`, `arm1`, ...
    fn arms(n: usize) -> Vec<String> {
        (0..n).map(|i| format!("arm{i}")).collect()
    }

    /// A successful, non-junk outcome.
    fn clean() -> Outcome {
        Outcome {
            ok: true,
            junk: false,
            hard_junk: false,
            cost_units: 1,
            elapsed_ms: 50,
            quality_score: None,
        }
    }

    /// A failed outcome flagged as hard junk.
    fn bad() -> Outcome {
        Outcome {
            ok: false,
            junk: true,
            hard_junk: true,
            cost_units: 1,
            elapsed_ms: 50,
            quality_score: None,
        }
    }

    #[test]
    fn router_select_returns_member_of_arms() {
        let r = Router::new(arms(3), RouterConfig::default()).unwrap();
        let d = r.select(1, 42);
        assert_eq!(d.chosen.len(), 1);
        assert!(r.arms().contains(&d.chosen[0]));
    }

    #[test]
    fn router_select_multi_pick_unique() {
        let r = Router::new(arms(5), RouterConfig::default()).unwrap();
        let d = r.select(3, 7);
        assert!(d.chosen.len() <= 3);
        let mut s = d.chosen.clone();
        s.sort();
        s.dedup();
        assert_eq!(s.len(), d.chosen.len(), "picks must be unique");
    }

    #[test]
    fn router_observe_increments_total() {
        let mut r = Router::new(arms(2), RouterConfig::default()).unwrap();
        assert_eq!(r.total_observations(), 0);
        r.observe("arm0", clean());
        r.observe("arm1", clean());
        assert_eq!(r.total_observations(), 2);
    }

    #[test]
    fn router_select_never_returns_more_than_k() {
        // Fewer arms than requested: bounded by the number of arms.
        let r = Router::new(arms(2), RouterConfig::default()).unwrap();
        let d = r.select(5, 0);
        assert!(d.chosen.len() <= 2);

        // More arms than requested: bounded by k.
        let wide = Router::new(arms(5), RouterConfig::default()).unwrap();
        let d = wide.select(2, 0);
        assert!(d.chosen.len() <= 2);
    }

    #[test]
    fn router_explores_unseen_arm_before_exploitation() {
        let mut r = Router::new(arms(2), RouterConfig::default()).unwrap();
        for _ in 0..50 {
            r.observe("arm0", clean());
        }
        let d = r.select(1, 0);
        assert_eq!(d.chosen[0], "arm1", "unseen arm should be explored first");
    }

    #[test]
    fn router_prefers_better_arm_after_enough_data() {
        let mut r = Router::new(arms(2), RouterConfig::default()).unwrap();
        for _ in 0..50 {
            r.observe("arm0", clean());
        }
        for _ in 0..50 {
            r.observe(
                "arm1",
                Outcome {
                    ok: true,
                    junk: true,
                    hard_junk: false,
                    ..clean()
                },
            );
        }
        let d = r.select(1, 0);
        assert_eq!(d.chosen[0], "arm0", "arm0 has lower junk rate");
    }

    #[test]
    fn router_triage_detects_hard_failure_arm() {
        let tcfg = TriageSessionConfig {
            min_n: 10,
            threshold: 3.0,
            ..TriageSessionConfig::default()
        };
        let cfg = RouterConfig::default().with_triage_cfg(tcfg);
        let mut r = Router::new(vec!["good".to_string(), "bad".to_string()], cfg).unwrap();
        for _ in 0..20 {
            r.observe("good", clean());
            r.observe("bad", clean());
        }
        for _ in 0..30 {
            r.observe("bad", bad());
        }
        assert!(
            r.mode().is_triage(),
            "should alarm after sustained hard failures"
        );
        assert!(
            r.mode().alarmed_arms().contains(&"bad".to_string()),
            "'bad' arm should be alarmed"
        );
    }

    #[test]
    fn router_acknowledge_change_resets_triage() {
        let tcfg = TriageSessionConfig {
            min_n: 5,
            threshold: 2.0,
            ..TriageSessionConfig::default()
        };
        let cfg = RouterConfig::default().with_triage_cfg(tcfg);
        let mut r = Router::new(vec!["a".to_string()], cfg).unwrap();
        for _ in 0..10 {
            r.observe("a", clean());
        }
        for _ in 0..20 {
            r.observe("a", bad());
        }
        assert!(r.mode().is_triage());
        r.acknowledge_change("a");
        assert!(
            !r.mode().is_triage(),
            "mode should return to Normal after acknowledge"
        );
    }

    #[test]
    fn router_monitoring_windows_exist_when_enabled() {
        let cfg = RouterConfig::default().with_monitoring(200, 50);
        let r = Router::new(arms(3), cfg).unwrap();
        for a in r.arms() {
            assert!(
                r.monitored_window(a).is_some(),
                "monitored window should exist for {a}"
            );
        }
    }

    #[test]
    fn router_acknowledge_promotes_recent_to_baseline() {
        let cfg = RouterConfig::default().with_monitoring(200, 50);
        let mut r = Router::new(vec!["a".to_string()], cfg).unwrap();
        for _ in 0..20 {
            r.observe("a", clean());
        }
        let before_recent = r.monitored_window("a").unwrap().recent_len();
        assert!(before_recent > 0);
        r.acknowledge_change("a");
        let after_recent = r.monitored_window("a").unwrap().recent_len();
        assert_eq!(
            after_recent, 0,
            "recent window should be cleared after acknowledge"
        );
    }

    #[test]
    fn router_add_arm_is_explored_next() {
        let mut r = Router::new(arms(2), RouterConfig::default()).unwrap();
        for _ in 0..50 {
            r.observe("arm0", clean());
            r.observe("arm1", clean());
        }
        r.add_arm("arm2".to_string()).unwrap();
        let d = r.select(1, 0);
        assert_eq!(
            d.chosen[0], "arm2",
            "newly added arm should be explored first"
        );
    }

    #[test]
    fn router_remove_arm_not_selected() {
        let mut r = Router::new(arms(3), RouterConfig::default()).unwrap();
        r.remove_arm("arm1").unwrap();
        assert!(!r.arms().contains(&"arm1".to_string()));
        // `select` is deterministic for a fixed seed, so vary the seed to
        // exercise different selections.
        for seed in 0..100u64 {
            let d = r.select(1, seed);
            assert_ne!(
                d.primary(),
                Some("arm1"),
                "removed arm must not be selected"
            );
        }
    }

    #[test]
    fn router_large_k_covers_all_arms_with_multi_pick() {
        let n = 30;
        let cfg = RouterConfig::default();
        let mut r = Router::new(arms(n), cfg).unwrap();
        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
        for round in 0..15 {
            let d = r.select(3, round as u64);
            for a in &d.chosen {
                seen.insert(a.clone());
            }
            for a in &d.chosen {
                r.observe(a, clean());
            }
        }
        assert_eq!(
            seen.len(),
            n,
            "all {n} arms should be explored within 15 rounds (k=3)"
        );
    }

    #[test]
    fn router_large_k_with_coverage_prevents_starvation() {
        let n = 20;
        let cfg = RouterConfig::default().with_coverage(0.02, 1);
        let mut r = Router::new(arms(n), cfg).unwrap();
        for i in 0..200 {
            let d = r.select(1, i as u64);
            if let Some(arm) = d.primary() {
                r.observe(arm, clean());
            }
        }
        for a in r.arms() {
            let s = r.summary(a);
            assert!(
                s.calls > 0,
                "arm {a} should have at least 1 observation with coverage enabled"
            );
        }
    }

    #[test]
    fn router_control_picks_are_subset_of_chosen() {
        let cfg = RouterConfig::default().with_control(1);
        let r = Router::new(arms(5), cfg).unwrap();
        let d = r.select(3, 42);
        for p in &d.control_picks {
            assert!(d.chosen.contains(p), "control pick {p} must be in chosen");
        }
    }

    #[test]
    fn router_select_is_deterministic() {
        let mut r = Router::new(arms(4), RouterConfig::default()).unwrap();
        for _ in 0..20 {
            r.observe("arm0", clean());
            r.observe("arm1", bad());
        }
        let d1 = r.select(2, 99);
        let d2 = r.select(2, 99);
        assert_eq!(d1.chosen, d2.chosen, "same seed → same picks");
    }
}