#![forbid(unsafe_code)]
#![warn(missing_docs)]
use pare::{Direction, ParetoFrontier};
use std::collections::{BTreeMap, VecDeque};
const TIEBREAK_EPS: f64 = 1e-12;
mod decision;
pub use decision::{Decision, DecisionNote, DecisionPolicy};
mod policy;
#[cfg(feature = "stochastic")]
pub use policy::BanditPolicy;
mod alloc;
pub use alloc::softmax_map;
mod utils;
pub use utils::suggested_window_cap;
mod control;
pub use control::{pick_control_arms, split_control_budget, ControlConfig};
mod router;
pub use router::{Router, RouterConfig, RouterDecision, RouterMode, RouterSnapshot};
mod guardrail;
pub use guardrail::LatencyGuardrailConfig;
pub mod monitor;
mod coverage;
pub use coverage::{coverage_pick_under_sampled, coverage_pick_under_sampled_idx, CoverageConfig};
#[cfg(feature = "stochastic")]
mod exp3ix;
#[cfg(feature = "stochastic")]
pub use exp3ix::{Exp3Ix, Exp3IxConfig, Exp3IxState};
#[cfg(feature = "stochastic")]
mod thompson;
#[cfg(feature = "stochastic")]
pub use thompson::{BetaStats, ThompsonConfig, ThompsonSampling, ThompsonState};
#[cfg(feature = "boltzmann")]
mod boltzmann;
#[cfg(feature = "boltzmann")]
pub use boltzmann::{BoltzmannConfig, BoltzmannPolicy};
#[cfg(feature = "contextual")]
mod contextual;
#[cfg(feature = "contextual")]
pub use contextual::{LinUcb, LinUcbArmState, LinUcbConfig, LinUcbScore, LinUcbState};
mod sticky;
pub use sticky::{StickyConfig, StickyMab};
mod stable_hash;
pub use stable_hash::stable_hash64;
pub(crate) use stable_hash::stable_hash64_u64;
mod novelty;
pub use novelty::novelty_pick_unseen;
pub(crate) use novelty::pick_random_subset;
mod prior;
pub use prior::apply_prior_counts_to_summary;
mod worst_first;
pub use worst_first::{
context_bin, contextual_worst_first_pick_k, contextual_worst_first_pick_one,
worst_first_pick_k, worst_first_pick_one, ContextBinConfig, ContextualCell,
ContextualCoverageTracker, WorstFirstConfig,
};
mod harness;
pub use harness::{
guardrail_filter_observed, guardrail_filter_observed_elapsed, policy_fill_generic,
policy_fill_k_observed_guardrail_first_with_coverage, policy_fill_k_observed_with_coverage,
policy_plan_generic, select_k_without_replacement_by, PipelineOrder, PolicyFill, PolicyPlan,
};
#[cfg(feature = "contextual")]
pub use harness::{policy_fill_k_contextual, ContextualPolicyFill};
mod triage;
pub use triage::{OutcomeIdx, TriageSession, TriageSessionConfig};
#[cfg(feature = "stochastic")]
pub use monitor::{calibrate_cusum_threshold, simulate_cusum_null_max_scores};
pub use monitor::{
calibrate_threshold_from_max_scores, DriftConfig, DriftMetric, MonitoredWindow, RateBoundMode,
ThresholdCalibration, UncertaintyConfig,
};
/// Result of a single call to an arm.
///
/// The constructors in this file keep `junk` set whenever `hard_junk` is set
/// and clamp `quality_score` into `[0.0, 1.0]`.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[non_exhaustive]
pub struct Outcome {
/// Whether the call succeeded.
pub ok: bool,
/// Whether the result was flagged as junk (includes hard junk).
pub junk: bool,
/// Whether the result was flagged as hard junk.
pub hard_junk: bool,
/// Cost charged for the call, in caller-defined units.
pub cost_units: u64,
/// Elapsed time of the call (milliseconds, per the field name).
pub elapsed_ms: u64,
/// Optional quality score; omitted from serialized output when `None`.
#[cfg_attr(feature = "serde", serde(skip_serializing_if = "Option::is_none"))]
pub quality_score: Option<f64>,
}
impl Outcome {
    /// Builds an outcome without a quality score.
    ///
    /// `junk` is forced to `true` whenever `hard_junk` is `true`.
    pub fn new(ok: bool, junk: bool, hard_junk: bool, cost_units: u64, elapsed_ms: u64) -> Self {
        let junk = junk || hard_junk;
        Self {
            ok,
            junk,
            hard_junk,
            cost_units,
            elapsed_ms,
            quality_score: None,
        }
    }

    /// Successful, non-junk outcome.
    pub fn success(cost_units: u64, elapsed_ms: u64) -> Self {
        Self::new(true, false, false, cost_units, elapsed_ms)
    }

    /// Failed outcome, flagged as both junk and hard junk.
    pub fn failure(cost_units: u64, elapsed_ms: u64) -> Self {
        Self::new(false, true, true, cost_units, elapsed_ms)
    }

    /// Successful outcome flagged as (soft) junk.
    pub fn degraded(cost_units: u64, elapsed_ms: u64) -> Self {
        Self::new(true, true, false, cost_units, elapsed_ms)
    }

    /// Like [`Outcome::new`], but also records `quality_score` clamped into `[0.0, 1.0]`.
    pub fn with_quality(
        ok: bool,
        junk: bool,
        hard_junk: bool,
        cost_units: u64,
        elapsed_ms: u64,
        quality_score: f64,
    ) -> Self {
        Self {
            quality_score: Some(quality_score.clamp(0.0, 1.0)),
            ..Self::new(ok, junk, hard_junk, cost_units, elapsed_ms)
        }
    }
}
/// Deserializes an [`Outcome`], re-applying the constructor invariants:
/// `junk` is set whenever `hard_junk` is, and `quality_score` is clamped to `[0.0, 1.0]`.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for Outcome {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Wire-format mirror of `Outcome`; `quality_score` may be absent.
        #[derive(serde::Deserialize)]
        struct Raw {
            ok: bool,
            junk: bool,
            hard_junk: bool,
            cost_units: u64,
            elapsed_ms: u64,
            #[serde(default)]
            quality_score: Option<f64>,
        }

        let Raw {
            ok,
            junk,
            hard_junk,
            cost_units,
            elapsed_ms,
            quality_score,
        } = Raw::deserialize(deserializer)?;

        Ok(Outcome {
            ok,
            junk: junk || hard_junk,
            hard_junk,
            cost_units,
            elapsed_ms,
            quality_score: quality_score.map(|q| q.clamp(0.0, 1.0)),
        })
    }
}
/// Bounded sliding window of the most recent [`Outcome`]s for one arm.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Window {
// Maximum number of outcomes retained; `Window::new` raises it to at least 1.
cap: usize,
// Stored outcomes, oldest at the front and newest at the back.
buf: VecDeque<Outcome>,
}
impl Window {
    /// Creates an empty window holding at most `cap` outcomes (a `cap` of 0 is raised to 1).
    pub fn new(cap: usize) -> Self {
        Self {
            cap: cap.max(1),
            buf: VecDeque::new(),
        }
    }

    /// Maximum number of outcomes retained.
    pub fn cap(&self) -> usize {
        self.cap
    }

    /// Number of outcomes currently stored.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` when no outcome is stored.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }

    /// Iterates over stored outcomes from oldest to newest.
    pub fn iter(&self) -> impl Iterator<Item = &Outcome> + '_ {
        self.buf.iter()
    }

    /// Appends `o`, evicting the oldest outcomes so that at most `cap` remain.
    ///
    /// Eviction loops rather than popping once: a window restored through
    /// serde may hold more than `cap` entries or a `cap` of 0, and a plain
    /// `len == cap` check would then never fire, letting the buffer grow
    /// without bound.
    pub fn push(&mut self, o: Outcome) {
        // `max(1)` guarantees the loop terminates even for a deserialized `cap` of 0.
        let cap = self.cap.max(1);
        while self.buf.len() >= cap {
            self.buf.pop_front();
        }
        self.buf.push_back(o);
    }

    /// Overwrites the junk flags of the newest outcome; no-op on an empty window.
    ///
    /// `hard_junk` is only recorded when `junk` is also `true`.
    pub fn set_last_junk_level(&mut self, junk: bool, hard_junk: bool) {
        if let Some(last) = self.buf.back_mut() {
            last.junk = junk;
            last.hard_junk = hard_junk && junk;
        }
    }

    /// Sets the quality score of the newest outcome, clamped to `[0.0, 1.0]`;
    /// no-op on an empty window.
    pub fn set_last_quality_score(&mut self, score: f64) {
        if let Some(last) = self.buf.back_mut() {
            last.quality_score = Some(score.clamp(0.0, 1.0));
        }
    }

    /// Mean of the quality scores present in the window, or `None` when no
    /// stored outcome carries one.
    pub fn mean_quality_score(&self) -> Option<f64> {
        let (sum, count) = self
            .buf
            .iter()
            .filter_map(|o| o.quality_score)
            .fold((0.0_f64, 0u64), |(sum, count), q| (sum + q, count + 1));
        (count > 0).then(|| sum / count as f64)
    }

    /// Aggregates the window into a [`Summary`].
    ///
    /// Counters are accumulated in a single pass (cost and latency sums
    /// saturate instead of overflowing); the quality mean is delegated to
    /// [`Window::mean_quality_score`] so both accessors always agree.
    pub fn summary(&self) -> Summary {
        let mut summary = Summary {
            calls: self.buf.len() as u64,
            ..Summary::default()
        };
        for o in &self.buf {
            summary.ok += u64::from(o.ok);
            summary.junk += u64::from(o.junk);
            summary.hard_junk += u64::from(o.hard_junk);
            summary.cost_units = summary.cost_units.saturating_add(o.cost_units);
            summary.elapsed_ms_sum = summary.elapsed_ms_sum.saturating_add(o.elapsed_ms);
        }
        summary.mean_quality_score = self.mean_quality_score();
        summary
    }
}
/// Aggregate counters over a set of outcomes (as produced by `Window::summary`).
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Summary {
/// Number of outcomes aggregated.
pub calls: u64,
/// Number of outcomes with `ok` set.
pub ok: u64,
/// Number of outcomes with `junk` set (includes hard junk).
pub junk: u64,
/// Number of outcomes with `hard_junk` set.
pub hard_junk: u64,
/// Sum of `cost_units` over all outcomes.
pub cost_units: u64,
/// Sum of `elapsed_ms` over all outcomes.
pub elapsed_ms_sum: u64,
/// Mean quality score over the outcomes that carried one; `None` if none did.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub mean_quality_score: Option<f64>,
}
impl Summary {
    /// Fraction of calls that were ok; `0.0` when there are no calls.
    pub fn ok_rate(&self) -> f64 {
        match self.calls {
            0 => 0.0,
            calls => (self.ok as f64) / (calls as f64),
        }
    }

    /// Fraction of calls flagged as junk; `0.0` when there are no calls.
    pub fn junk_rate(&self) -> f64 {
        match self.calls {
            0 => 0.0,
            calls => (self.junk as f64) / (calls as f64),
        }
    }

    /// Mean cost units per call; `0.0` when there are no calls.
    pub fn mean_cost_units(&self) -> f64 {
        match self.calls {
            0 => 0.0,
            calls => (self.cost_units as f64) / (calls as f64),
        }
    }

    /// Mean elapsed milliseconds per call; `0.0` when there are no calls.
    pub fn mean_elapsed_ms(&self) -> f64 {
        match self.calls {
            0 => 0.0,
            calls => (self.elapsed_ms_sum as f64) / (calls as f64),
        }
    }

    /// Fraction of calls flagged as hard junk; `0.0` when there are no calls.
    pub fn hard_junk_rate(&self) -> f64 {
        match self.calls {
            0 => 0.0,
            calls => (self.hard_junk as f64) / (calls as f64),
        }
    }

    /// Fraction of calls that were junk but not hard junk; `0.0` when there
    /// are no calls. The subtraction saturates at zero.
    pub fn soft_junk_rate(&self) -> f64 {
        match self.calls {
            0 => 0.0,
            calls => {
                let soft_only = self.junk.saturating_sub(self.hard_junk);
                (soft_only as f64) / (calls as f64)
            }
        }
    }
}
/// Selects which statistic of a [`Summary`] an objective reads (see `Extract::apply`).
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Extract {
/// Ok rate plus the exploration bonus passed to `apply`.
OkRateUcb,
/// Mean cost units per call.
MeanCost,
/// Mean elapsed milliseconds per call.
MeanLatency,
/// Hard-junk rate.
HardJunkRate,
/// Soft-junk rate (junk that is not hard junk).
SoftJunkRate,
/// Mean quality score; `apply` yields `0.0` when the summary has none.
MeanQuality,
/// No summary-derived statistic: `apply` yields `0.0`, so the value is
/// expected to come from an explicit override (see `Objective::custom`).
Custom,
}
impl Extract {
pub fn apply(self, s: &Summary, ucb: f64) -> f64 {
match self {
Self::OkRateUcb => s.ok_rate() + ucb,
Self::MeanCost => s.mean_cost_units(),
Self::MeanLatency => s.mean_elapsed_ms(),
Self::HardJunkRate => s.hard_junk_rate(),
Self::SoftJunkRate => s.soft_junk_rate(),
Self::MeanQuality => s.mean_quality_score.unwrap_or(0.0),
Self::Custom => 0.0,
}
}
}
/// One optimisation objective: a statistic, a direction and a scalarisation weight.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Objective {
/// Statistic to read from a summary when `value` is `None`.
pub extract: Extract,
/// Whether larger or smaller values are preferred.
pub direction: Direction,
/// Multiplier applied to the value in the scalar score.
pub weight: f64,
/// Explicit value; when set, `Objective::resolve` returns it instead of
/// applying `extract`.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub value: Option<f64>,
}
impl Objective {
pub fn maximize(extract: Extract, weight: f64) -> Self {
Self {
extract,
direction: Direction::Maximize,
weight,
value: None,
}
}
pub fn minimize(extract: Extract, weight: f64) -> Self {
Self {
extract,
direction: Direction::Minimize,
weight,
value: None,
}
}
pub fn custom(direction: Direction, weight: f64, value: f64) -> Self {
Self {
extract: Extract::Custom,
direction,
weight,
value: Some(value),
}
}
pub fn with_value(mut self, v: f64) -> Self {
self.value = Some(v);
self
}
pub fn resolve(&self, s: &Summary, ucb: f64) -> f64 {
self.value.unwrap_or_else(|| self.extract.apply(s, ucb))
}
pub fn pareto_value(&self, s: &Summary, ucb: f64) -> f64 {
let v = self.resolve(s, ucb);
match self.direction {
Direction::Maximize => v,
Direction::Minimize => -v,
}
}
pub fn scalar_contribution(&self, s: &Summary, ucb: f64) -> f64 {
let v = self.resolve(s, ucb);
match self.direction {
Direction::Maximize => self.weight * v,
Direction::Minimize => -(self.weight * v),
}
}
}
/// Default objective list: maximize ok-rate-plus-UCB with weight `1.0`, then
/// cost, latency, hard-junk rate, soft-junk rate (minimized) and mean quality
/// (maximized), all with weight `0.0`.
pub fn default_objectives() -> Vec<Objective> {
    let mut objectives = Vec::with_capacity(6);
    objectives.push(Objective::maximize(Extract::OkRateUcb, 1.0));
    for extract in [
        Extract::MeanCost,
        Extract::MeanLatency,
        Extract::HardJunkRate,
        Extract::SoftJunkRate,
    ] {
        objectives.push(Objective::minimize(extract, 0.0));
    }
    objectives.push(Objective::maximize(Extract::MeanQuality, 0.0));
    objectives
}
/// Configuration for the deterministic multi-objective arm selection.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MabConfig {
/// Multiplier of the exploration bonus `sqrt(ln(total_calls) / calls)`.
pub exploration_c: f64,
/// Objectives used for the Pareto frontier and the scalar score.
pub objectives: Vec<Objective>,
/// Arms with a junk rate above this are filtered out (when set).
pub max_junk_rate: Option<f64>,
/// Arms with a hard-junk rate above this are filtered out (when set).
pub max_hard_junk_rate: Option<f64>,
/// Arms with mean cost units above this are filtered out (when set).
pub max_mean_cost_units: Option<f64>,
}
impl Default for MabConfig {
fn default() -> Self {
Self {
exploration_c: 0.7,
objectives: default_objectives(),
max_junk_rate: None,
max_hard_junk_rate: None,
max_mean_cost_units: None,
}
}
}
impl MabConfig {
pub fn set_weight(&mut self, extract: Extract, weight: f64) {
if let Some(obj) = self.objectives.iter_mut().find(|o| o.extract == extract) {
obj.weight = weight;
}
}
pub fn with_cost_weight(mut self, w: f64) -> Self {
self.set_weight(Extract::MeanCost, w);
self
}
pub fn with_latency_weight(mut self, w: f64) -> Self {
self.set_weight(Extract::MeanLatency, w);
self
}
pub fn with_junk_weight(mut self, w: f64) -> Self {
self.set_weight(Extract::SoftJunkRate, w);
self
}
pub fn with_hard_junk_weight(mut self, w: f64) -> Self {
self.set_weight(Extract::HardJunkRate, w);
self
}
pub fn with_quality_weight(mut self, w: f64) -> Self {
self.set_weight(Extract::MeanQuality, w);
self
}
pub fn with_objectives(mut self, objectives: Vec<Objective>) -> Self {
self.objectives = objectives;
self
}
}
/// Configuration for monitored selection: a base [`MabConfig`] plus drift,
/// categorical-KL ("catkl") and CUSUM guards and weights.
///
/// In the monitored selection function of this file, each `max_*` threshold is
/// ignored unless it is finite and non-negative, negative weights are treated
/// as `0.0`, and non-finite or non-positive alphas fall back to `1e-3`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MonitoredMabConfig {
/// Base selection configuration (objectives, exploration, hard constraints).
pub base: MabConfig,
/// Arms whose drift score exceeds this are filtered out (when set).
pub max_drift: Option<f64>,
/// Metric used when computing the drift score.
pub drift_metric: DriftMetric,
/// Weight of the drift score in the scalar score (penalty).
pub drift_weight: f64,
/// Settings for the rate bounds applied to ok / junk / hard-junk rates.
pub uncertainty: UncertaintyConfig,
/// Arms whose catKL score exceeds this are filtered out (when set).
pub max_catkl: Option<f64>,
/// Alpha passed to the catKL score computation.
pub catkl_alpha: f64,
/// Minimum-baseline parameter passed to the catKL score computation.
pub catkl_min_baseline: u64,
/// Minimum-recent parameter passed to the catKL score computation.
pub catkl_min_recent: u64,
/// Weight of the catKL score in the scalar score (penalty).
pub catkl_weight: f64,
/// Arms whose CUSUM score exceeds this are filtered out (when set).
pub max_cusum: Option<f64>,
/// Alpha passed to the CUSUM score computation.
pub cusum_alpha: f64,
/// Minimum-baseline parameter passed to the CUSUM score computation.
pub cusum_min_baseline: u64,
/// Minimum-recent parameter passed to the CUSUM score computation.
pub cusum_min_recent: u64,
/// Alternative distribution passed to the CUSUM score computation;
/// `[0.05, 0.05, 0.45, 0.45]` is used when `None`.
pub cusum_alt_p: Option<[f64; 4]>,
/// Weight of the CUSUM score in the scalar score (penalty).
pub cusum_weight: f64,
}
impl Default for MonitoredMabConfig {
fn default() -> Self {
Self {
base: MabConfig::default(),
max_drift: None,
drift_metric: DriftMetric::default(),
drift_weight: 0.0,
uncertainty: UncertaintyConfig::default(),
max_catkl: None,
catkl_alpha: 1e-3,
catkl_min_baseline: 40,
catkl_min_recent: 20,
catkl_weight: 0.0,
max_cusum: None,
cusum_alpha: 1e-3,
cusum_min_baseline: 40,
cusum_min_recent: 20,
cusum_alt_p: None,
cusum_weight: 0.0,
}
}
}
impl From<MabConfig> for MonitoredMabConfig {
fn from(base: MabConfig) -> Self {
Self {
base,
..Self::default()
}
}
}
/// Evaluated objective for one candidate arm (debug / explanation output).
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ObjectiveValue {
/// Which statistic this row refers to.
pub extract: Extract,
/// Raw objective value.
pub value: f64,
/// Value oriented so that larger is better (negated for minimized objectives).
pub pareto_value: f64,
/// Signed, weighted contribution to the candidate's scalar score.
pub scalar_contribution: f64,
}
/// Per-arm explanation row produced by the selection functions.
///
/// The optional fields are only filled by the monitored selection path and are
/// skipped during serialization when `None`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CandidateDebug {
/// Arm name.
pub name: String,
/// Summary the arm was scored on.
pub summary: Summary,
/// Exploration bonus used for this arm.
pub ucb: f64,
/// One row per evaluated objective.
pub objective_values: Vec<ObjectiveValue>,
/// Scalar score: sum of the objectives' scalar contributions.
pub score: f64,
/// Drift score, when computed.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub drift_score: Option<f64>,
/// Categorical-KL score, when computed.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub catkl_score: Option<f64>,
/// CUSUM score, when computed.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub cusum_score: Option<f64>,
/// Half-width reported by the ok-rate bound, when computed.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub ok_half_width: Option<f64>,
/// Half-width reported by the soft-junk-rate bound, when computed.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub junk_half_width: Option<f64>,
/// Half-width reported by the hard-junk-rate bound, when computed.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub hard_junk_half_width: Option<f64>,
}
/// Outcome of a selection: the chosen arm plus the data that led to it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Selection {
/// Name of the chosen arm.
pub chosen: String,
/// Names of the arms on the Pareto frontier.
pub frontier: Vec<String>,
/// Explanation rows for the scored arms.
pub candidates: Vec<CandidateDebug>,
/// Base configuration the selection was made with.
pub config: MabConfig,
}
/// A [`Selection`] together with a record of the filtering steps applied before it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MabSelectionDecision {
/// The selection itself.
pub selection: Selection,
/// Arms that passed the base constraints (all arms if the fallback was used).
pub eligible_arms: Vec<String>,
/// `true` when no arm passed the base constraints and all arms were kept.
pub constraints_fallback_used: bool,
/// `true` when an arm with zero recorded calls was chosen for exploration.
pub explore_first: bool,
/// Drift guard record; `None` when the guard was not applied.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub drift_guard: Option<DriftGuardDecision>,
/// Categorical-KL guard record; `None` when the guard was not applied.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub catkl_guard: Option<CatKlGuardDecision>,
/// CUSUM guard record; `None` when the guard was not applied.
#[cfg_attr(
feature = "serde",
serde(default, skip_serializing_if = "Option::is_none")
)]
pub cusum_guard: Option<CusumGuardDecision>,
}
/// Record of the drift guard applied during monitored selection.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DriftGuardDecision {
/// Arms remaining after the guard (the unfiltered input if the fallback was used).
pub eligible_arms: Vec<String>,
/// `true` when every arm violated the threshold and none was removed.
pub fallback_used: bool,
/// Drift metric used.
pub metric: DriftMetric,
/// Threshold that was applied.
pub max_drift: f64,
}
/// Record of the categorical-KL guard applied during monitored selection.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CatKlGuardDecision {
/// Arms remaining after the guard (the unfiltered input if the fallback was used).
pub eligible_arms: Vec<String>,
/// `true` when every arm violated the threshold and none was removed.
pub fallback_used: bool,
/// Threshold that was applied.
pub max_catkl: f64,
/// Alpha actually used for the score.
pub alpha: f64,
/// Minimum-baseline parameter used for the score.
pub min_baseline: u64,
/// Minimum-recent parameter used for the score.
pub min_recent: u64,
}
/// Record of the CUSUM guard applied during monitored selection.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CusumGuardDecision {
/// Arms remaining after the guard (the unfiltered input if the fallback was used).
pub eligible_arms: Vec<String>,
/// `true` when every arm violated the threshold and none was removed.
pub fallback_used: bool,
/// Threshold that was applied.
pub max_cusum: f64,
/// Alpha actually used for the score.
pub alpha: f64,
/// Minimum-baseline parameter used for the score.
pub min_baseline: u64,
/// Minimum-recent parameter used for the score.
pub min_recent: u64,
/// Alternative distribution actually used for the score.
pub alt_p: [f64; 4],
}
/// Filters `arms_in_order` by the hard limits in `cfg` (junk rate, hard-junk
/// rate, mean cost), preserving order.
///
/// Arms without a summary are judged on an all-zero summary. Returns the
/// surviving arms and a flag that is `true` when no arm survived, in which
/// case all input arms are returned instead.
fn apply_base_constraints(
    arms_in_order: &[String],
    summaries: &BTreeMap<String, Summary>,
    cfg: &MabConfig,
) -> (Vec<String>, bool) {
    // An unset limit never rejects.
    let within = |limit: Option<f64>, observed: f64| match limit {
        Some(threshold) => observed <= threshold,
        None => true,
    };

    let passing: Vec<String> = arms_in_order
        .iter()
        .filter(|arm| {
            let stats = summaries.get(*arm).copied().unwrap_or_default();
            within(cfg.max_junk_rate, stats.junk_rate())
                && within(cfg.max_hard_junk_rate, stats.hard_junk_rate())
                && within(cfg.max_mean_cost_units, stats.mean_cost_units())
        })
        .cloned()
        .collect();

    if passing.is_empty() {
        (arms_in_order.to_vec(), true)
    } else {
        (passing, false)
    }
}
/// Builds the decision returned when an arm with no recorded calls is picked
/// for exploration.
///
/// The selection contains only `chosen` (as the sole frontier member and sole
/// candidate) with an all-zero summary, zeroed objective rows (one per
/// configured objective) and no monitoring diagnostics or guard records.
fn explore_first_decision(
    chosen: String,
    eligible_arms: Vec<String>,
    constraints_fallback_used: bool,
    cfg: MabConfig,
) -> MabSelectionDecision {
    let zeroed_objectives: Vec<ObjectiveValue> = cfg
        .objectives
        .iter()
        .map(|objective| ObjectiveValue {
            extract: objective.extract,
            value: 0.0,
            pareto_value: 0.0,
            scalar_contribution: 0.0,
        })
        .collect();

    let placeholder = CandidateDebug {
        name: chosen.clone(),
        summary: Summary::default(),
        ucb: 0.0,
        objective_values: zeroed_objectives,
        score: 0.0,
        drift_score: None,
        catkl_score: None,
        cusum_score: None,
        ok_half_width: None,
        junk_half_width: None,
        hard_junk_half_width: None,
    };

    let selection = Selection {
        frontier: vec![chosen.clone()],
        chosen,
        candidates: vec![placeholder],
        config: cfg,
    };

    MabSelectionDecision {
        selection,
        eligible_arms,
        constraints_fallback_used,
        explore_first: true,
        drift_guard: None,
        catkl_guard: None,
        cusum_guard: None,
    }
}
/// Picks the best candidate on the Pareto frontier.
///
/// Every candidate's `pareto_value`s are pushed into a maximize-only frontier
/// (dimension count taken from the first candidate). If the frontier ends up
/// empty, all candidates are considered. Among frontier members the highest
/// `score` wins; scores within `TIEBREAK_EPS` are resolved in favour of the
/// lexicographically smaller name.
///
/// Returns the chosen name and the frontier names in candidate order. With no
/// usable candidate, the chosen name falls back to `fallback_first` (or an
/// empty string).
fn choose_from_frontier(
    candidates: &[CandidateDebug],
    frontier_names_in_order: &[String],
    fallback_first: Option<&String>,
) -> (String, Vec<String>) {
    let dims = candidates
        .first()
        .map_or(0, |cand| cand.objective_values.len());
    let mut frontier = ParetoFrontier::new(vec![Direction::Maximize; dims]);
    for (idx, cand) in candidates.iter().enumerate() {
        let point: Vec<f64> = cand
            .objective_values
            .iter()
            .map(|o| o.pareto_value)
            .collect();
        frontier.push(point, idx);
    }

    let mut on_frontier: Vec<usize> = if frontier.is_empty() {
        (0..candidates.len()).collect()
    } else {
        frontier.points().iter().map(|p| p.data).collect()
    };
    // Restore candidate order so the output is deterministic.
    on_frontier.sort_unstable();

    let names: Vec<String> = on_frontier
        .iter()
        .filter_map(|&i| frontier_names_in_order.get(i).cloned())
        .collect();

    let initial_name = match names.first() {
        Some(first) => first.clone(),
        None => fallback_first.cloned().unwrap_or_default(),
    };
    let mut winner: (f64, String) = (f64::NEG_INFINITY, initial_name);
    for cand in on_frontier.iter().filter_map(|&i| candidates.get(i)) {
        let strictly_better = cand.score > winner.0;
        let tied_with_smaller_name =
            (cand.score - winner.0).abs() <= TIEBREAK_EPS && cand.name < winner.1;
        if strictly_better || tied_with_smaller_name {
            winner = (cand.score, cand.name.clone());
        }
    }

    (winner.1, names)
}
/// Selects an arm and returns only the [`Selection`] part of
/// [`select_mab_explain`]'s result.
pub fn select_mab(
    arms_in_order: &[String],
    summaries: &BTreeMap<String, Summary>,
    cfg: MabConfig,
) -> Selection {
    let decision = select_mab_explain(arms_in_order, summaries, cfg);
    decision.selection
}
pub fn select_mab_explain(
arms_in_order: &[String],
summaries: &BTreeMap<String, Summary>,
cfg: MabConfig,
) -> MabSelectionDecision {
let (eligible_arms, constraints_fallback_used) =
apply_base_constraints(arms_in_order, summaries, &cfg);
let arms_in_order: &[String] = &eligible_arms;
let explore_choice: Option<String> = arms_in_order
.iter()
.find(|a| summaries.get(*a).copied().unwrap_or_default().calls == 0)
.cloned();
if let Some(chosen) = explore_choice {
return explore_first_decision(chosen, eligible_arms, constraints_fallback_used, cfg);
}
let total_calls: f64 = arms_in_order
.iter()
.map(|a| summaries.get(a).copied().unwrap_or_default().calls as f64)
.sum::<f64>()
.max(1.0);
let mut frontier_names_in_order: Vec<String> = Vec::new();
let mut candidates = Vec::new();
for a in arms_in_order {
let s = summaries.get(a).copied().unwrap_or_default();
let n = (s.calls as f64).max(1.0);
let ucb = cfg.exploration_c * ((total_calls.ln() / n).sqrt());
let obj_values: Vec<ObjectiveValue> = cfg
.objectives
.iter()
.map(|obj| {
let value = obj.resolve(&s, ucb);
ObjectiveValue {
extract: obj.extract,
value,
pareto_value: obj.pareto_value(&s, ucb),
scalar_contribution: obj.scalar_contribution(&s, ucb),
}
})
.collect();
let score: f64 = obj_values.iter().map(|o| o.scalar_contribution).sum();
candidates.push(CandidateDebug {
name: a.clone(),
summary: s,
ucb,
objective_values: obj_values,
score,
drift_score: None,
catkl_score: None,
cusum_score: None,
ok_half_width: None,
junk_half_width: None,
hard_junk_half_width: None,
});
frontier_names_in_order.push(a.clone());
}
let (best_name, frontier_names) =
choose_from_frontier(&candidates, &frontier_names_in_order, arms_in_order.first());
let sel = Selection {
chosen: best_name,
frontier: frontier_names,
candidates,
config: cfg,
};
MabSelectionDecision {
selection: sel,
eligible_arms,
constraints_fallback_used,
explore_first: false,
drift_guard: None,
catkl_guard: None,
cusum_guard: None,
}
}
/// Monitored selection driven directly by monitored windows.
///
/// Derives each arm's summary from its window's `recent_summary()` and
/// delegates to [`select_mab_monitored_explain_with_summaries`].
pub fn select_mab_monitored_explain(
    arms_in_order: &[String],
    monitored: &BTreeMap<String, MonitoredWindow>,
    drift_cfg: DriftConfig,
    cfg: MonitoredMabConfig,
) -> MabSelectionDecision {
    let mut recent_summaries: BTreeMap<String, Summary> = BTreeMap::new();
    for (arm, window) in monitored {
        recent_summaries.insert(arm.clone(), window.recent_summary());
    }
    select_mab_monitored_explain_with_summaries(
        arms_in_order,
        &recent_summaries,
        monitored,
        drift_cfg,
        cfg,
    )
}
/// Keeps the arms whose monitor score does not exceed `threshold`, preserving order.
///
/// Arms without a monitored window, or for which `score_of` yields `None`,
/// are always kept. Returns the surviving arms and a fallback flag: when every
/// arm would be removed, the input list is returned unchanged with the flag set.
fn keep_arms_within_monitor_threshold<F>(
    arms: &[String],
    monitored: &BTreeMap<String, MonitoredWindow>,
    threshold: f64,
    score_of: F,
) -> (Vec<String>, bool)
where
    F: Fn(&MonitoredWindow) -> Option<f64>,
{
    let kept: Vec<String> = arms
        .iter()
        .filter(|arm| {
            let score = monitored.get(*arm).and_then(|w| score_of(w));
            // Written as "not strictly above" so a NaN score never rejects an arm.
            !matches!(score, Some(x) if x > threshold)
        })
        .cloned()
        .collect();
    if kept.is_empty() {
        (arms.to_vec(), true)
    } else {
        (kept, false)
    }
}

/// Monitored selection with caller-supplied summaries.
///
/// Steps:
/// 1. Apply the base hard constraints (falling back to all arms if none pass).
/// 2. If any eligible arm has zero recorded calls, return it as an
///    explore-first decision (no guards are evaluated).
/// 3. Apply the drift, catKL and CUSUM guards in that order. Each guard runs
///    only when its threshold is finite and non-negative, never removes arms
///    lacking a monitored window or a score, and falls back to its input list
///    when it would remove every arm.
/// 4. Score the surviving arms: the base objectives (with ok / hard-junk /
///    soft-junk rates replaced by their bounded estimates from
///    `monitor::apply_rate_bound`) plus three monitoring penalties.
/// 5. Pick the best arm on the Pareto frontier.
///
/// The three monitoring rows appended to each candidate's `objective_values`
/// (drift, catKL, CUSUM, in that order) are tagged [`Extract::Custom`] so they
/// cannot be mistaken for the real mean-cost objective. An explicit
/// `Objective::value` override is honoured for every base objective, matching
/// `Objective::resolve`.
///
/// The returned `eligible_arms` is the result of step 1; the per-guard lists
/// are available in the guard records.
pub fn select_mab_monitored_explain_with_summaries(
    arms_in_order: &[String],
    summaries: &BTreeMap<String, Summary>,
    monitored: &BTreeMap<String, MonitoredWindow>,
    drift_cfg: DriftConfig,
    cfg: MonitoredMabConfig,
) -> MabSelectionDecision {
    let base = &cfg.base;
    let (eligible_arms, constraints_fallback_used) =
        apply_base_constraints(arms_in_order, summaries, base);

    let explore_choice: Option<String> = eligible_arms
        .iter()
        .find(|a| summaries.get(*a).copied().unwrap_or_default().calls == 0)
        .cloned();
    if let Some(chosen) = explore_choice {
        return explore_first_decision(
            chosen,
            eligible_arms,
            constraints_fallback_used,
            base.clone(),
        );
    }

    // Sanitise the configuration once: unusable thresholds disable their
    // guard, unusable alphas fall back to 1e-3.
    let usable_threshold =
        |t: Option<f64>| -> Option<f64> { t.filter(|x| x.is_finite() && *x >= 0.0) };
    let usable_alpha = |a: f64| -> f64 {
        if a.is_finite() && a > 0.0 {
            a
        } else {
            1e-3
        }
    };
    let catkl_alpha = usable_alpha(cfg.catkl_alpha);
    let cusum_alpha = usable_alpha(cfg.cusum_alpha);
    let cusum_alt_p = cfg.cusum_alt_p.unwrap_or([0.05, 0.05, 0.45, 0.45]);

    // Per-window score functions, shared by the guards and the scoring loop.
    let drift_of = |w: &MonitoredWindow| -> Option<f64> {
        monitor::drift_between_windows(
            w.baseline(),
            w.recent(),
            DriftConfig {
                metric: cfg.drift_metric,
                ..drift_cfg
            },
        )
        .map(|d| d.score)
    };
    let catkl_of = |w: &MonitoredWindow| -> Option<f64> {
        monitor::catkl_score_between_windows(
            w.baseline(),
            w.recent(),
            catkl_alpha,
            drift_cfg.tol,
            cfg.catkl_min_baseline,
            cfg.catkl_min_recent,
        )
    };
    let cusum_of = |w: &MonitoredWindow| -> Option<f64> {
        monitor::cusum_score_between_windows(
            w.baseline(),
            w.recent(),
            cusum_alpha,
            drift_cfg.tol,
            cfg.cusum_min_baseline,
            cfg.cusum_min_recent,
            Some(cusum_alt_p),
        )
    };

    // Guards run in sequence; each one narrows `surviving`.
    let mut surviving: Vec<String> = eligible_arms.clone();

    let mut drift_guard: Option<DriftGuardDecision> = None;
    if let Some(thr) = usable_threshold(cfg.max_drift) {
        let (kept, fallback_used) =
            keep_arms_within_monitor_threshold(&surviving, monitored, thr, &drift_of);
        drift_guard = Some(DriftGuardDecision {
            eligible_arms: kept.clone(),
            fallback_used,
            metric: cfg.drift_metric,
            max_drift: thr,
        });
        surviving = kept;
    }

    let mut catkl_guard: Option<CatKlGuardDecision> = None;
    if let Some(thr) = usable_threshold(cfg.max_catkl) {
        let (kept, fallback_used) =
            keep_arms_within_monitor_threshold(&surviving, monitored, thr, &catkl_of);
        catkl_guard = Some(CatKlGuardDecision {
            eligible_arms: kept.clone(),
            fallback_used,
            max_catkl: thr,
            alpha: catkl_alpha,
            min_baseline: cfg.catkl_min_baseline,
            min_recent: cfg.catkl_min_recent,
        });
        surviving = kept;
    }

    let mut cusum_guard: Option<CusumGuardDecision> = None;
    if let Some(thr) = usable_threshold(cfg.max_cusum) {
        let (kept, fallback_used) =
            keep_arms_within_monitor_threshold(&surviving, monitored, thr, &cusum_of);
        cusum_guard = Some(CusumGuardDecision {
            eligible_arms: kept.clone(),
            fallback_used,
            max_cusum: thr,
            alpha: cusum_alpha,
            min_baseline: cfg.cusum_min_baseline,
            min_recent: cfg.cusum_min_recent,
            alt_p: cusum_alt_p,
        });
        surviving = kept;
    }

    // Clamp to 1 so the logarithm below is never negative.
    let total_calls: f64 = surviving
        .iter()
        .map(|a| summaries.get(a).copied().unwrap_or_default().calls as f64)
        .sum::<f64>()
        .max(1.0);

    // Penalty weights for drift, catKL and CUSUM; negative weights count as 0.
    let monitor_weights = [
        cfg.drift_weight.max(0.0),
        cfg.catkl_weight.max(0.0),
        cfg.cusum_weight.max(0.0),
    ];

    let mut frontier_names_in_order: Vec<String> = Vec::with_capacity(surviving.len());
    let mut candidates: Vec<CandidateDebug> = Vec::with_capacity(surviving.len());
    for a in &surviving {
        let s = summaries.get(a).copied().unwrap_or_default();
        let n = (s.calls as f64).max(1.0);
        let z = cfg.uncertainty.z;
        let soft = s.junk.saturating_sub(s.hard_junk);
        let (ok_rate_used, ok_half) =
            monitor::apply_rate_bound(s.ok, s.calls, z, cfg.uncertainty.ok_mode);
        let (hard_used, hard_half) =
            monitor::apply_rate_bound(s.hard_junk, s.calls, z, cfg.uncertainty.hard_junk_mode);
        let (soft_used, soft_half) =
            monitor::apply_rate_bound(soft, s.calls, z, cfg.uncertainty.junk_mode);

        let window = monitored.get(a);
        let drift_score = window.and_then(|w| drift_of(w));
        let catkl_score = window.and_then(|w| catkl_of(w));
        let cusum_score = window.and_then(|w| cusum_of(w));

        let ucb = base.exploration_c * ((total_calls.ln() / n).sqrt());
        let mut obj_values: Vec<ObjectiveValue> = base
            .objectives
            .iter()
            .map(|obj| {
                // An explicit override wins, as in `Objective::resolve`;
                // otherwise the rate extracts use their bounded estimates.
                let value = match (obj.value, obj.extract) {
                    (Some(explicit), _) => explicit,
                    (None, Extract::OkRateUcb) => ok_rate_used + ucb,
                    (None, Extract::HardJunkRate) => hard_used,
                    (None, Extract::SoftJunkRate) => soft_used,
                    (None, _) => obj.resolve(&s, ucb),
                };
                let (pareto_value, scalar_contribution) = match obj.direction {
                    Direction::Maximize => (value, obj.weight * value),
                    Direction::Minimize => (-value, -(obj.weight * value)),
                };
                ObjectiveValue {
                    extract: obj.extract,
                    value,
                    pareto_value,
                    scalar_contribution,
                }
            })
            .collect();

        // Monitoring penalties are always minimized; a missing score counts as 0.
        let monitor_scores = [
            drift_score.unwrap_or(0.0),
            catkl_score.unwrap_or(0.0),
            cusum_score.unwrap_or(0.0),
        ];
        for (&weight, &monitor_score) in monitor_weights.iter().zip(monitor_scores.iter()) {
            obj_values.push(ObjectiveValue {
                extract: Extract::Custom,
                value: monitor_score,
                pareto_value: -monitor_score,
                scalar_contribution: -(weight * monitor_score),
            });
        }

        let score: f64 = obj_values.iter().map(|o| o.scalar_contribution).sum();
        candidates.push(CandidateDebug {
            name: a.clone(),
            summary: s,
            ucb,
            objective_values: obj_values,
            score,
            drift_score,
            catkl_score,
            cusum_score,
            ok_half_width: Some(ok_half),
            junk_half_width: Some(soft_half),
            hard_junk_half_width: Some(hard_half),
        });
        frontier_names_in_order.push(a.clone());
    }

    let (best_name, frontier_names) =
        choose_from_frontier(&candidates, &frontier_names_in_order, surviving.first());

    let sel = Selection {
        chosen: best_name,
        frontier: frontier_names,
        candidates,
        config: base.clone(),
    };
    MabSelectionDecision {
        selection: sel,
        eligible_arms,
        constraints_fallback_used,
        explore_first: false,
        drift_guard,
        catkl_guard,
        cusum_guard,
    }
}
pub fn select_mab_decide(
arms_in_order: &[String],
summaries: &BTreeMap<String, Summary>,
cfg: MabConfig,
) -> Decision {
let d = select_mab_explain(arms_in_order, summaries, cfg);
let mut notes = vec![DecisionNote::Constraints {
eligible_arms: d.eligible_arms.clone(),
fallback_used: d.constraints_fallback_used,
}];
if d.explore_first {
notes.push(DecisionNote::ExploreFirst);
} else {
notes.push(DecisionNote::DeterministicChoice);
}
Decision {
policy: DecisionPolicy::Mab,
chosen: d.selection.chosen.clone(),
probs: None,
notes,
}
}
pub fn select_mab_monitored_decide(
arms_in_order: &[String],
monitored: &BTreeMap<String, MonitoredWindow>,
drift_cfg: DriftConfig,
cfg: MonitoredMabConfig,
) -> Decision {
let d = select_mab_monitored_explain(arms_in_order, monitored, drift_cfg, cfg);
let mut notes = vec![DecisionNote::Constraints {
eligible_arms: d.eligible_arms.clone(),
fallback_used: d.constraints_fallback_used,
}];
if let Some(ref dg) = d.drift_guard {
notes.push(DecisionNote::DriftGuard {
eligible_arms: dg.eligible_arms.clone(),
fallback_used: dg.fallback_used,
metric: dg.metric,
max_drift: dg.max_drift,
});
}
if let Some(ref cg) = d.catkl_guard {
notes.push(DecisionNote::CatKlGuard {
eligible_arms: cg.eligible_arms.clone(),
fallback_used: cg.fallback_used,
max_catkl: cg.max_catkl,
alpha: cg.alpha,
min_baseline: cg.min_baseline,
min_recent: cg.min_recent,
});
}
if let Some(ref ug) = d.cusum_guard {
notes.push(DecisionNote::CusumGuard {
eligible_arms: ug.eligible_arms.clone(),
fallback_used: ug.fallback_used,
max_cusum: ug.max_cusum,
alpha: ug.alpha,
min_baseline: ug.min_baseline,
min_recent: ug.min_recent,
alt_p: ug.alt_p,
});
}
if d.explore_first {
notes.push(DecisionNote::ExploreFirst);
} else {
notes.push(DecisionNote::DeterministicChoice);
}
let chosen_row = d
.selection
.candidates
.iter()
.find(|c| c.name == d.selection.chosen);
if let Some(c) = chosen_row {
notes.push(DecisionNote::Diagnostics {
drift_score: c.drift_score,
catkl_score: c.catkl_score,
cusum_score: c.cusum_score,
ok_half_width: c.ok_half_width,
junk_half_width: c.junk_half_width,
hard_junk_half_width: c.hard_junk_half_width,
mean_quality_score: c.summary.mean_quality_score,
});
}
Decision {
policy: DecisionPolicy::Mab,
chosen: d.selection.chosen.clone(),
probs: None,
notes,
}
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
fn mk_test_candidate(name: &str, score: f64) -> CandidateDebug {
CandidateDebug {
name: name.to_string(),
summary: Summary::default(),
ucb: 0.0,
objective_values: vec![],
score,
drift_score: None,
catkl_score: None,
cusum_score: None,
ok_half_width: None,
junk_half_width: None,
hard_junk_half_width: None,
}
}
fn mk_test_candidate_with_calls(name: &str, calls: u64, score: f64) -> CandidateDebug {
CandidateDebug {
name: name.to_string(),
summary: Summary {
calls,
..Summary::default()
},
ucb: 0.0,
objective_values: vec![],
score,
drift_score: None,
catkl_score: None,
cusum_score: None,
ok_half_width: None,
junk_half_width: None,
hard_junk_half_width: None,
}
}
fn s(
calls: u64,
ok: u64,
junk: u64,
hard_junk: u64,
cost_units: u64,
elapsed_ms_sum: u64,
) -> Summary {
Summary {
calls,
ok,
junk,
hard_junk,
cost_units,
elapsed_ms_sum,
mean_quality_score: None,
}
}
#[test]
fn select_mab_is_deterministic_and_prefers_lower_junk_all_else_equal() {
let arms = vec!["a".to_string(), "b".to_string()];
let mut m = BTreeMap::new();
m.insert("a".to_string(), s(10, 9, 5, 0, 10, 1000));
m.insert("b".to_string(), s(10, 9, 0, 0, 10, 1000));
let sel1 = select_mab(&arms, &m, MabConfig::default());
let sel2 = select_mab(&arms, &m, MabConfig::default());
assert_eq!(sel1.chosen, "b");
assert_eq!(sel1.chosen, sel2.chosen);
}
#[test]
fn constraints_filter_arms_but_never_return_empty() {
let arms = vec!["a".to_string(), "b".to_string()];
let mut m = BTreeMap::new();
m.insert("a".to_string(), s(10, 9, 9, 0, 10, 1000));
m.insert("b".to_string(), s(10, 9, 9, 0, 10, 1000));
let cfg = MabConfig {
max_junk_rate: Some(0.1),
..MabConfig::default()
};
let sel = select_mab(&arms, &m, cfg);
assert!(!sel.chosen.is_empty());
assert!(sel.frontier.iter().any(|x| x == &sel.chosen));
}
/// Arm `a` has one hard-junk result in 10 calls and arm `b` has none; with
/// `max_hard_junk_rate` set to 0.05 the selection must be `b`.
#[test]
fn constraints_can_exclude_high_hard_junk_arm() {
    let arms: Vec<String> = vec!["a".into(), "b".into()];
    let summaries = BTreeMap::from([
        ("a".to_string(), s(10, 9, 1, 1, 10, 1000)),
        ("b".to_string(), s(10, 9, 1, 0, 10, 1000)),
    ]);
    let strict = MabConfig {
        max_hard_junk_rate: Some(0.05),
        ..MabConfig::default()
    };

    let sel = select_mab(&arms, &summaries, strict);

    assert_eq!(sel.chosen, "b");
}
proptest! {
// Property: for any pair of arm summaries, `select_mab` returns one of the
// supplied arms, that arm is on the reported frontier, and a second call with
// identical inputs picks the same arm.
#[test]
fn select_mab_never_panics_and_returns_member_of_arms(
    calls_a in 0u64..50,
    calls_b in 0u64..50,
    ok_a in 0u64..50,
    ok_b in 0u64..50,
    junk_a in 0u64..50,
    junk_b in 0u64..50,
    hard_a in 0u64..50,
    hard_b in 0u64..50,
    cost_a in 0u64..500,
    cost_b in 0u64..500,
    lat_a in 0u64..50_000,
    lat_b in 0u64..50_000,
) {
    // Clamp the raw draws so that ok <= calls, junk <= calls and
    // hard_junk <= junk before building a summary.
    let clamped = |calls: u64, ok: u64, junk: u64, hard: u64, cost: u64, lat: u64| {
        let junk = junk.min(calls);
        s(calls, ok.min(calls), junk, hard.min(junk), cost, lat)
    };

    let arms: Vec<String> = vec!["a".into(), "b".into()];
    let summaries = BTreeMap::from([
        ("a".to_string(), clamped(calls_a, ok_a, junk_a, hard_a, cost_a, lat_a)),
        ("b".to_string(), clamped(calls_b, ok_b, junk_b, hard_b, cost_b, lat_b)),
    ]);
    let cfg = MabConfig {
        exploration_c: 0.7,
        ..MabConfig::default()
    };

    let first = select_mab(&arms, &summaries, cfg.clone());
    prop_assert!(first.chosen == "a" || first.chosen == "b");
    prop_assert!(first.frontier.iter().any(|arm| *arm == first.chosen));

    let second = select_mab(&arms, &summaries, cfg);
    prop_assert_eq!(first.chosen, second.chosen);
}
// Property: inserting a summary under a key that is not listed in `arms`
// ("zzz-extra") must not change which arm `select_mab` chooses.
#[test]
fn select_mab_ignores_summaries_for_unknown_arms(
    calls_a in 1u64..50,
    calls_b in 1u64..50,
    ok_a in 0u64..50,
    ok_b in 0u64..50,
    junk_a in 0u64..50,
    junk_b in 0u64..50,
    hard_a in 0u64..50,
    hard_b in 0u64..50,
    cost_a in 0u64..500,
    cost_b in 0u64..500,
    lat_a in 0u64..50_000,
    lat_b in 0u64..50_000,
    extra_calls in 0u64..50,
    extra_ok in 0u64..50,
    extra_junk in 0u64..50,
    extra_hard in 0u64..50,
    extra_cost in 0u64..500,
    extra_lat in 0u64..50_000,
) {
    // Clamp the raw draws so that ok <= calls, junk <= calls and
    // hard_junk <= junk before building a summary.
    let clamped = |calls: u64, ok: u64, junk: u64, hard: u64, cost: u64, lat: u64| {
        let junk = junk.min(calls);
        s(calls, ok.min(calls), junk, hard.min(junk), cost, lat)
    };

    let arms: Vec<String> = vec!["a".into(), "b".into()];
    let mut summaries = BTreeMap::from([
        ("a".to_string(), clamped(calls_a, ok_a, junk_a, hard_a, cost_a, lat_a)),
        ("b".to_string(), clamped(calls_b, ok_b, junk_b, hard_b, cost_b, lat_b)),
    ]);
    let cfg = MabConfig::default();

    let baseline = select_mab(&arms, &summaries, cfg.clone());

    let stray = clamped(
        extra_calls,
        extra_ok,
        extra_junk,
        extra_hard,
        extra_cost,
        extra_lat,
    );
    summaries.insert("zzz-extra".to_string(), stray);

    let with_stray = select_mab(&arms, &summaries, cfg);
    prop_assert_eq!(baseline.chosen, with_stray.chosen);
}
// Property: when at least one arm has zero calls, `select_mab` (default
// config) picks the first such arm in `arms` order. Cases where every arm
// already has calls are accepted without any assertion.
#[test]
fn select_mab_explores_first_zero_call_arm(
    calls_a in 0u64..10,
    calls_b in 0u64..10,
    calls_c in 0u64..10,
) {
    let arms: Vec<String> = ["a", "b", "c"].iter().map(|n| n.to_string()).collect();
    let call_counts = [calls_a, calls_b, calls_c];

    // Every summary field other than `calls` is zero.
    let mut summaries = BTreeMap::new();
    for (arm, &calls) in arms.iter().zip(call_counts.iter()) {
        summaries.insert(arm.clone(), s(calls, 0, 0, 0, 0, 0));
    }

    let expected = match call_counts.iter().position(|&n| n == 0) {
        Some(idx) => arms[idx].as_str(),
        // No unexplored arm: nothing to check for this case.
        None => return Ok(()),
    };

    let sel = select_mab(&arms, &summaries, MabConfig::default());
    prop_assert_eq!(sel.chosen, expected);
}
}
/// With `min_dwell: 3` and a zero switch margin, the sticky wrapper first
/// picks `a`, keeps returning `a` for the next two decisions made on
/// statistics where `b` has more ok results, and returns `b` on the decision
/// after that. `dwell()` is asserted to be 1 right after the first pick and
/// right after the switch.
#[test]
fn sticky_mab_respects_min_dwell() {
    let arms: Vec<String> = vec!["a".into(), "b".into()];
    let cfg = MabConfig::default();
    let mut sticky = StickyMab::new(StickyConfig {
        min_dwell: 3,
        min_switch_margin: 0.0,
    });

    // Round 1: `a` has 10/10 ok, `b` has 5/10.
    let a_leads = BTreeMap::from([
        ("a".to_string(), s(10, 10, 0, 0, 0, 0)),
        ("b".to_string(), s(10, 5, 0, 0, 0, 0)),
    ]);
    let first = sticky.apply_mab(select_mab_explain(&arms, &a_leads, cfg.clone()));
    assert_eq!(first.chosen, "a");
    assert_eq!(sticky.dwell(), 1);

    // From here on the ok counts are swapped between the two arms.
    let b_leads = BTreeMap::from([
        ("a".to_string(), s(10, 5, 0, 0, 0, 0)),
        ("b".to_string(), s(10, 10, 0, 0, 0, 0)),
    ]);
    for _ in 0..2 {
        let held = sticky.apply_mab(select_mab_explain(&arms, &b_leads, cfg.clone()));
        assert_eq!(held.chosen, "a");
    }

    let switched = sticky.apply_mab(select_mab_explain(&arms, &b_leads, cfg));
    assert_eq!(switched.chosen, "b");
    assert_eq!(sticky.dwell(), 1);
}
/// With `min_dwell: 0` and `min_switch_margin: 0.5`, a base decision for `b`
/// whose score exceeds `a`'s by 0.4 is overridden (the result stays `a`),
/// while one that exceeds it by 0.7 is followed (the result becomes `b`).
#[test]
fn sticky_mab_respects_min_switch_margin() {
    let cfg = MabConfig::default();
    let mut sticky = StickyMab::new(StickyConfig {
        min_dwell: 0,
        min_switch_margin: 0.5,
    });

    let arm_names = || vec!["a".to_string(), "b".to_string()];
    // Hand-built base decision naming `chosen`, with the two candidate scores given.
    let decision = |chosen: &str, a_score: f64, b_score: f64| {
        let candidates = vec![
            mk_test_candidate_with_calls("a", 10, a_score),
            mk_test_candidate_with_calls("b", 10, b_score),
        ];
        MabSelectionDecision {
            selection: Selection {
                chosen: chosen.to_string(),
                frontier: arm_names(),
                candidates,
                config: cfg.clone(),
            },
            eligible_arms: arm_names(),
            constraints_fallback_used: false,
            explore_first: false,
            drift_guard: None,
            catkl_guard: None,
            cusum_guard: None,
        }
    };

    let initial = sticky.apply_mab(decision("a", 1.0, 1.0));
    assert_eq!(initial.chosen, "a");
    assert_eq!(sticky.previous(), Some("a"));

    let small_gain = sticky.apply_mab(decision("b", 1.0, 1.4));
    assert_eq!(small_gain.chosen, "a");

    let large_gain = sticky.apply_mab(decision("b", 1.0, 1.7));
    assert_eq!(large_gain.chosen, "b");
    assert_eq!(sticky.previous(), Some("b"));
}
/// After the wrapper has recorded `old` as its previous arm, a decision whose
/// only candidate is `a` must be followed as-is, even though `min_dwell` (10)
/// and `min_switch_margin` (100.0) are both set high.
#[test]
fn sticky_mab_follows_base_choice_if_previous_missing_from_candidates() {
    let cfg = MabConfig::default();
    let mut sticky = StickyMab::new(StickyConfig {
        min_dwell: 10,
        min_switch_margin: 100.0,
    });

    // Decision in which `arm` is the chosen, frontier, eligible and only candidate arm.
    let only = |arm: &str, candidate: CandidateDebug, explore_first: bool| MabSelectionDecision {
        selection: Selection {
            chosen: arm.to_string(),
            frontier: vec![arm.to_string()],
            candidates: vec![candidate],
            config: cfg.clone(),
        },
        eligible_arms: vec![arm.to_string()],
        constraints_fallback_used: false,
        explore_first,
        drift_guard: None,
        catkl_guard: None,
        cusum_guard: None,
    };

    sticky.apply_mab(only("old", mk_test_candidate("old", 0.0), true));
    assert_eq!(sticky.previous(), Some("old"));

    let followed = sticky.apply_mab(only(
        "a",
        mk_test_candidate_with_calls("a", 10, 0.0),
        false,
    ));
    assert_eq!(followed.chosen, "a");
    assert_eq!(sticky.previous(), Some("a"));
}
/// Arm `a` reports 80 junk results in 100 calls and arm `b` reports none.
/// With `max_junk_rate` set to 0.1 the chosen arm must be `b`, and the chosen
/// arm's own `junk_rate()` must not exceed that limit.
#[test]
fn select_mab_chosen_satisfies_constraints_when_eligible_exists() {
    let arms: Vec<String> = vec!["a".into(), "b".into()];
    let summaries = BTreeMap::from([
        ("a".to_string(), s(100, 90, 80, 0, 10, 1000)),
        ("b".to_string(), s(100, 90, 0, 0, 10, 1000)),
    ]);
    let strict = MabConfig {
        max_junk_rate: Some(0.1),
        ..MabConfig::default()
    };

    let sel = select_mab(&arms, &summaries, strict);
    assert_eq!(sel.chosen, "b");

    let chosen_summary = summaries.get(&sel.chosen).copied().unwrap_or_default();
    assert!(chosen_summary.junk_rate() <= 0.1);
}
}