use std::collections::BTreeMap;
use crate::{
CandidateDebug, Decision, DecisionNote, DecisionPolicy, MabSelectionDecision, Selection,
};
/// Thresholds that make arm selection "sticky", i.e. reluctant to switch away
/// from the previously chosen arm.
///
/// The [`Default`] value sets both thresholds to zero, which disables both
/// checks: any proposed switch is accepted.
#[derive(Debug, Clone, Copy, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StickyConfig {
    /// Dwell count the current arm must reach before a switch is considered.
    ///
    /// While the dwell count is below this value a proposed switch is
    /// rejected. `0` disables the check.
    pub min_dwell: u64,
    /// Minimum amount by which the proposed arm's score must exceed the
    /// current arm's score for a switch to be accepted.
    ///
    /// Values that are not finite, or not strictly positive, disable the check.
    pub min_switch_margin: f64,
}
/// Returns `x` unchanged when it is finite; NaN and ±infinity become `0.0`.
fn f64_or0(x: f64) -> f64 {
    Some(x).filter(|value| value.is_finite()).unwrap_or(0.0)
}
/// Scalar score used to compare arms: the candidate's `score` field, with a
/// non-finite value (NaN or ±infinity) replaced by `0.0`.
pub fn mab_scalar_score(c: &CandidateDebug) -> f64 {
    let raw = c.score;
    f64_or0(raw)
}
/// Stateful wrapper that post-processes selections so the chosen arm does not
/// flap: it remembers the last arm it settled on and only lets a different arm
/// through once the thresholds in `cfg` are satisfied.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StickyMab {
/// Thresholds consulted on every apply call.
pub cfg: StickyConfig,
/// Name of the arm most recently recorded as the sticky choice; `None`
/// after construction or `reset`, until the first selection is applied.
previous: Option<String>,
/// Dwell counter for `previous`: set to 1 when an arm is adopted and
/// incremented (saturating) each time that arm is kept. It is 0 while no
/// arm is remembered.
dwell: u64,
}
impl StickyMab {
    /// Creates a sticky layer that remembers no arm yet and has a dwell count of zero.
    pub fn new(cfg: StickyConfig) -> Self {
        Self {
            previous: None,
            dwell: 0,
            cfg,
        }
    }

    /// Name of the arm currently recorded as the sticky choice, if any.
    pub fn previous(&self) -> Option<&str> {
        self.previous.as_deref()
    }

    /// Current dwell count of the remembered arm (zero when nothing is remembered).
    pub fn dwell(&self) -> u64 {
        self.dwell
    }

    /// Forgets the remembered arm and zeroes the dwell count; `cfg` is left untouched.
    pub fn reset(&mut self) {
        self.dwell = 0;
        self.previous = None;
    }

    /// Maps every candidate name in `sel` to its scalar score.
    ///
    /// Entries are inserted in candidate order, so if a name occurs more than
    /// once the score of its last occurrence is the one kept.
    fn scores_by_arm(sel: &Selection) -> BTreeMap<&str, f64> {
        sel.candidates
            .iter()
            .fold(BTreeMap::new(), |mut by_name, cand| {
                by_name.insert(cand.name.as_str(), mab_scalar_score(cand));
                by_name
            })
    }

    /// Records `arm` as the sticky choice and starts its dwell count at one.
    fn start_streak(&mut self, arm: String) {
        self.previous = Some(arm);
        self.dwell = 1;
    }

    /// Overrides the selection's chosen arm with `incumbent`, bumps the dwell
    /// count (saturating), and returns the selection together with `note`.
    fn hold(
        &mut self,
        mut sel: Selection,
        incumbent: String,
        note: DecisionNote,
    ) -> (Selection, Vec<DecisionNote>) {
        sel.chosen = incumbent;
        self.dwell = self.dwell.saturating_add(1);
        (sel, vec![note])
    }

    /// Core sticky logic shared by the public entry points.
    ///
    /// Returns the (possibly overridden) selection plus the notes explaining
    /// any sticky decision that was taken. The steps, in order:
    ///
    /// 1. If `explore_first` is set, nothing is remembered yet, or the
    ///    remembered arm is not among `sel.candidates`, the chosen arm is
    ///    adopted as-is and the dwell count restarts at one (no notes).
    /// 2. If the chosen arm is the remembered arm, the dwell count is bumped
    ///    (no notes).
    /// 3. If `min_dwell` is non-zero and not yet reached, the remembered arm
    ///    is kept (`StickyKeptPreviousDwell`).
    /// 4. If a positive finite `min_switch_margin` is configured and the score
    ///    improvement does not reach it, the remembered arm is kept
    ///    (`StickyKeptPreviousMargin`).
    /// 5. Otherwise the switch goes through (`StickySwitched`).
    fn apply_inner(
        &mut self,
        sel: Selection,
        explore_first: bool,
    ) -> (Selection, Vec<DecisionNote>) {
        // An explore-first pick is never compared against the remembered arm.
        let remembered = if explore_first {
            None
        } else {
            self.previous.clone()
        };
        let Some(prev) = remembered else {
            self.start_streak(sel.chosen.clone());
            return (sel, Vec::new());
        };

        let scores = Self::scores_by_arm(&sel);
        // The remembered arm is no longer a candidate: it cannot be kept.
        let Some(&prev_score) = scores.get(prev.as_str()) else {
            self.start_streak(sel.chosen.clone());
            return (sel, Vec::new());
        };

        if sel.chosen == prev {
            self.dwell = self.dwell.saturating_add(1);
            return (sel, Vec::new());
        }
        let candidate = sel.chosen.clone();

        let min_dwell = self.cfg.min_dwell;
        if min_dwell > 0 && self.dwell < min_dwell {
            let note = DecisionNote::StickyKeptPreviousDwell {
                previous: prev.clone(),
                candidate,
                // Reported dwell is the value before this call bumps it.
                dwell: self.dwell,
                min_dwell,
            };
            return self.hold(sel, prev, note);
        }

        // A chosen arm missing from the candidate list scores negative
        // infinity, which makes the margin non-finite.
        let cand_score = scores
            .get(candidate.as_str())
            .map_or(f64::NEG_INFINITY, |score| *score);
        let margin = cand_score - prev_score;
        let min_margin = f64_or0(self.cfg.min_switch_margin);
        let clears_margin = margin.is_finite() && margin >= min_margin;
        if min_margin > 0.0 && !clears_margin {
            let note = DecisionNote::StickyKeptPreviousMargin {
                previous: prev.clone(),
                candidate,
                previous_score: prev_score,
                candidate_score: cand_score,
                margin,
                min_margin,
            };
            return self.hold(sel, prev, note);
        }

        let note = DecisionNote::StickySwitched {
            previous: prev,
            candidate: candidate.clone(),
            previous_score: prev_score,
            candidate_score: cand_score,
            margin,
        };
        self.start_streak(candidate);
        (sel, vec![note])
    }

    /// Applies stickiness to a plain selection and returns the adjusted selection.
    ///
    /// The selection is treated as an explore-first pick when it has exactly
    /// one candidate and that candidate's summary records zero calls.
    pub fn apply(&mut self, sel: Selection) -> Selection {
        let explore_first = matches!(
            sel.candidates.as_slice(),
            [only] if only.summary.calls == 0
        );
        self.apply_inner(sel, explore_first).0
    }

    /// Applies stickiness to a MAB decision, using the decision's own
    /// `explore_first` flag, and returns the adjusted selection.
    pub fn apply_mab(&mut self, decision: MabSelectionDecision) -> Selection {
        self.apply_inner(decision.selection, decision.explore_first)
            .0
    }

    /// Applies stickiness to a MAB decision and packages the outcome as a
    /// [`Decision`].
    ///
    /// The notes are, in order: the constraints note, then either
    /// `ExploreFirst` or `DeterministicChoice`, then any sticky notes.
    /// `probs` is always `None`.
    pub fn apply_mab_decide(&mut self, decision: MabSelectionDecision) -> Decision {
        let explore_first = decision.explore_first;
        let eligible_arms = decision.eligible_arms.clone();
        let fallback_used = decision.constraints_fallback_used;
        let (sel, sticky_notes) = self.apply_inner(decision.selection, explore_first);

        let mode = if explore_first {
            DecisionNote::ExploreFirst
        } else {
            DecisionNote::DeterministicChoice
        };
        let mut notes = Vec::with_capacity(2 + sticky_notes.len());
        notes.push(DecisionNote::Constraints {
            eligible_arms,
            fallback_used,
        });
        notes.push(mode);
        notes.extend(sticky_notes);

        Decision {
            policy: DecisionPolicy::Mab,
            chosen: sel.chosen,
            probs: None,
            notes,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CandidateDebug, MabConfig, Summary};

    /// Builds a candidate called `name` with the given score and ten recorded
    /// calls; every other field is zero, empty, or `None`.
    fn mk_candidate(name: &str, score: f64) -> CandidateDebug {
        let summary = Summary {
            calls: 10,
            ..Summary::default()
        };
        CandidateDebug {
            name: String::from(name),
            summary,
            ucb: 0.0,
            objective_values: Vec::new(),
            score,
            drift_score: None,
            catkl_score: None,
            cusum_score: None,
            ok_half_width: None,
            junk_half_width: None,
            hard_junk_half_width: None,
        }
    }

    /// Builds a two-arm selection over `"a"` and `"b"` that picks `candidate`.
    ///
    /// When `previous` is `"a"`, arm `"a"` scores 1.0 and `"b"` scores 2.0;
    /// for any other `previous` the scores are swapped.
    fn mk_sel(previous: &str, candidate: &str) -> Selection {
        let (a_score, b_score) = match previous {
            "a" => (1.0, 2.0),
            _ => (2.0, 1.0),
        };
        let candidates = vec![mk_candidate("a", a_score), mk_candidate("b", b_score)];
        Selection {
            chosen: String::from(candidate),
            frontier: vec![String::from("a"), String::from("b")],
            candidates,
            config: MabConfig::default(),
        }
    }

    #[test]
    fn sticky_never_returns_arm_not_in_candidates() {
        // Thresholds high enough that stickiness would otherwise always win.
        let strict = StickyConfig {
            min_dwell: 100,
            min_switch_margin: 100.0,
        };
        let mut sticky = StickyMab::new(strict);

        // Remember arm "a".
        let _ = sticky.apply(mk_sel("a", "a"));

        // Next round offers only "x"; the remembered "a" is not a candidate.
        let only_x = Selection {
            chosen: String::from("x"),
            frontier: vec![String::from("x")],
            candidates: vec![mk_candidate("x", 0.0)],
            config: MabConfig::default(),
        };
        let out = sticky.apply(only_x);
        assert_eq!(out.chosen, "x");
    }
}