use crate::{stable_hash64, stable_hash64_u64};
/// Weights that turn per-candidate statistics into a "worst-first" priority score.
///
/// The selectors in this module score a candidate as
/// `hard_weight * hard + soft_weight * soft + exploration_c * sqrt(ln(total_calls) / calls)`
/// and pick the highest score.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct WorstFirstConfig {
    /// Multiplier on the exploration bonus `sqrt(ln(total_calls) / calls)`.
    pub exploration_c: f64,
    /// Multiplier on the `hard` component reported by the summary callback.
    pub hard_weight: f64,
    /// Multiplier on the `soft` component reported by the summary callback.
    pub soft_weight: f64,
}
/// Picks the single arm most in need of attention ("worst first").
///
/// Selection runs in two phases:
///
/// 1. **Coverage** – every arm for which `observed_calls` returns `0` is a candidate; the one
///    with the smallest `stable_hash64(seed ^ 0x574F_5253, arm)` wins and is returned with
///    the flag `true` ("explore first").
/// 2. **Scoring** – otherwise each arm is scored from `summary(arm) = (calls, hard, soft)` as
///    `hard_weight * hard + soft_weight * soft + exploration_c * sqrt(ln(total) / calls)`,
///    where `calls` is clamped to at least 1 and `total` is the sum of the clamped calls
///    (itself at least 1). The highest score wins; ties are broken by the same seeded hash,
///    then by arm name. The result carries the flag `false`.
///
/// `observed_calls` is invoked once per arm. `summary` is invoked once per arm, and only when
/// no arm is unseen.
///
/// Returns `None` only when `remaining` is empty.
#[must_use]
pub fn worst_first_pick_one<FObs, FSum>(
    seed: u64,
    remaining: &[String],
    cfg: WorstFirstConfig,
    mut observed_calls: FObs,
    mut summary: FSum,
) -> Option<(String, bool)>
where
    FObs: FnMut(&str) -> u64,
    FSum: FnMut(&str) -> (u64, f64, f64),
{
    // Tag mixed into the caller's seed for every tie-breaking hash in this function.
    const TIE_TAG: u64 = 0x574F_5253;

    if remaining.is_empty() {
        return None;
    }
    let tie_seed = seed ^ TIE_TAG;

    // Phase 1: a never-observed arm always wins. `min_by_key` returns the first of equal
    // minima, i.e. the same arm a stable sort by this key would put first, in O(n) without
    // cloning every candidate.
    if let Some(arm) = remaining
        .iter()
        .filter(|arm| observed_calls(arm.as_str()) == 0)
        .min_by_key(|arm| stable_hash64(tie_seed, *arm))
    {
        return Some((arm.clone(), true));
    }

    // Phase 2: the exploration bonus needs the total over all arms, so gather every summary
    // before scoring.
    let summaries: Vec<(&String, u64, f64, f64)> = remaining
        .iter()
        .map(|arm| {
            let (calls, hard, soft) = summary(arm.as_str());
            (arm, calls, hard, soft)
        })
        .collect();

    // Each arm contributes at least one call, so ln(total) >= 0 and the bonus stays finite.
    let total_calls = summaries
        .iter()
        .fold(0.0_f64, |acc, &(_, calls, _, _)| acc + (calls as f64).max(1.0))
        .max(1.0);

    let score_of = |calls: u64, hard: f64, soft: f64| -> f64 {
        let calls = (calls as f64).max(1.0);
        let exploration = cfg.exploration_c * ((total_calls.ln() / calls).sqrt());
        cfg.hard_weight * hard + cfg.soft_weight * soft + exploration
    };

    summaries
        .into_iter()
        .map(|(arm, calls, hard, soft)| (score_of(calls, hard, soft), arm))
        // Highest score first, then smallest seeded hash, then arm name. `min_by` keeps the
        // first of equal elements, matching the head of the previous stable sort; the hash is
        // only computed when scores tie.
        .min_by(|&(score_a, arm_a), &(score_b, arm_b)| {
            score_b
                .total_cmp(&score_a)
                .then_with(|| stable_hash64(tie_seed, arm_a).cmp(&stable_hash64(tie_seed, arm_b)))
                .then_with(|| arm_a.cmp(arm_b))
        })
        .map(|(_, arm)| (arm.clone(), false))
}
/// Picks up to `k` distinct arms by running [`worst_first_pick_one`] once per round.
///
/// Round `i` (0-based) uses the seed `seed ^ (i + 1)` and only considers arms that have not
/// been picked yet; every copy of a picked name is removed from the pool. Each returned entry
/// keeps the explore-first flag of the round that produced it.
///
/// Returns an empty vector when `k == 0` or `arms` is empty, and stops early once the pool is
/// exhausted.
#[must_use]
pub fn worst_first_pick_k<FObs, FSum>(
    seed: u64,
    arms: &[String],
    k: usize,
    cfg: WorstFirstConfig,
    mut observed_calls: FObs,
    mut summary: FSum,
) -> Vec<(String, bool)>
where
    FObs: FnMut(&str) -> u64,
    FSum: FnMut(&str) -> (u64, f64, f64),
{
    let mut picks: Vec<(String, bool)> = Vec::new();
    if k == 0 || arms.is_empty() {
        return picks;
    }

    let mut pool = arms.to_vec();
    for round in 0..k {
        if pool.is_empty() {
            break;
        }
        // Each round perturbs the seed with its 1-based index.
        let round_seed = seed ^ (round as u64 + 1);
        let outcome = worst_first_pick_one(
            round_seed,
            &pool,
            cfg,
            &mut observed_calls,
            &mut summary,
        );
        let (arm, explore_first) = match outcome {
            Some(pick) => pick,
            None => break,
        };
        // Drop every copy of the chosen name so it cannot be picked again.
        pool.retain(|candidate| candidate != &arm);
        picks.push((arm, explore_first));
    }
    picks
}
/// Parameters for [`context_bin`]: how finely each feature is quantised and which seed keys
/// the final hash.
#[derive(Debug, Clone, Copy)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ContextBinConfig {
    /// Equal-width buckets per feature over `[0, 1]`; `context_bin` treats `0` as `1`.
    pub levels: u8,
    /// Seed handed to the hash of the bucket key.
    pub seed: u64,
}

impl Default for ContextBinConfig {
    /// Four buckets per feature and the fixed seed `0xC0B1_C0B1`.
    fn default() -> Self {
        const DEFAULT_LEVELS: u8 = 4;
        const DEFAULT_SEED: u64 = 0xC0B1_C0B1;
        ContextBinConfig {
            levels: DEFAULT_LEVELS,
            seed: DEFAULT_SEED,
        }
    }
}
/// Quantises a feature vector into a single 64-bit context-bin id.
///
/// Each feature is clamped to `[0, 1]`; non-finite values (NaN, ±∞) become `0.0`. Each value
/// is then mapped to one of `levels` equal-width buckets. `levels == 0` is treated as `1`,
/// and exactly `1.0` lands in the top bucket.
///
/// The bucket indices are joined with `':'` (e.g. `"0:1:3"`) and the resulting key is hashed
/// with `stable_hash64(config.seed, ..)`. An empty `context` hashes the empty key.
pub fn context_bin(context: &[f64], config: ContextBinConfig) -> u64 {
    let levels = u64::from(config.levels.max(1));
    let top_bucket = levels - 1;

    let bucket_of = |value: f64| -> u64 {
        let unit = if value.is_finite() {
            value.clamp(0.0, 1.0)
        } else {
            0.0
        };
        // `unit == 1.0` would land one past the last bucket; fold it into the top one.
        ((unit * levels as f64).floor() as u64).min(top_bucket)
    };

    let key: String = context
        .iter()
        .map(|&value| bucket_of(value).to_string())
        .collect::<Vec<String>>()
        .join(":");
    stable_hash64(config.seed, &key)
}
/// One `(arm, context_bin)` pair: the unit the contextual worst-first selectors pick.
///
/// The derived ordering compares `arm` first, then `context_bin`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ContextualCell {
    /// Arm name.
    pub arm: String,
    /// Context-bin id this cell is keyed by.
    pub context_bin: u64,
}

impl ContextualCell {
    /// Builds a cell from anything convertible into an owned arm name.
    pub fn new(arm: impl Into<String>, context_bin: u64) -> Self {
        let arm: String = arm.into();
        ContextualCell { arm, context_bin }
    }
}
/// Picks the single `(arm, bin)` cell most in need of attention ("worst first").
///
/// Cells are every pairing of `arms` with `active_bins`, in arm-major order.
///
/// 1. **Coverage** – cells for which `cell_calls` returns `0` are candidates. The winner has
///    the smallest
///    `stable_hash64(seed ^ TIE_SEED, arm) ^ stable_hash64_u64(seed ^ TIE_SEED ^ 1, bin).rotate_left(32)`,
///    with exact collisions resolved by `(arm, bin)` order. It is returned with the flag
///    `true`.
/// 2. **Scoring** – otherwise each cell is scored from `cell_summary = (calls, hard, soft)` as
///    `hard_weight * hard + soft_weight * soft + exploration_c * sqrt(ln(total) / calls)`,
///    with `calls` clamped to at least 1 and `total` the sum of clamped calls (at least 1).
///    The highest score wins. Ties go to the smaller
///    `stable_hash64(seed ^ TIE_SEED, arm) ^ stable_hash64_u64(seed ^ TIE_SEED ^ 2, bin)`,
///    then to `(arm, bin)` order. The result carries the flag `false`.
///
/// `cell_calls` is invoked once per cell. `cell_summary` is invoked once per cell, and only
/// when no cell is unseen.
///
/// Returns `None` when `arms` or `active_bins` is empty.
#[must_use]
pub fn contextual_worst_first_pick_one<FCalls, FSummary>(
    seed: u64,
    arms: &[String],
    active_bins: &[u64],
    cfg: WorstFirstConfig,
    mut cell_calls: FCalls,
    mut cell_summary: FSummary,
) -> Option<(ContextualCell, bool)>
where
    FCalls: FnMut(&str, u64) -> u64,
    FSummary: FnMut(&str, u64) -> (u64, f64, f64),
{
    const TIE_SEED: u64 = 0xCE11_C0B1;

    if arms.is_empty() || active_bins.is_empty() {
        return None;
    }
    let tie_seed = seed ^ TIE_SEED;

    // Every (arm, bin) pair, arm-major, generated lazily so nothing is cloned until a winner
    // is known.
    let cells = || {
        arms.iter()
            .flat_map(|arm| active_bins.iter().map(move |&bin| (arm, bin)))
    };

    // Phase 1: prefer a cell that has never been called. Tuple ordering compares the hash key
    // first, then (arm, bin), so the result never depends on input order, and `min` keeps
    // the first of fully equal elements.
    let unseen = cells()
        .filter(|&(arm, bin)| cell_calls(arm.as_str(), bin) == 0)
        .map(|(arm, bin)| {
            let key = stable_hash64(tie_seed, arm)
                ^ stable_hash64_u64(tie_seed ^ 1, bin).rotate_left(32);
            (key, arm, bin)
        })
        .min();
    if let Some((_, arm, bin)) = unseen {
        return Some((ContextualCell::new(arm.clone(), bin), true));
    }

    // Phase 2: the exploration bonus needs the total over all cells, so gather every
    // summary before scoring.
    let rows: Vec<(&String, u64, u64, f64, f64)> = cells()
        .map(|(arm, bin)| {
            let (calls, hard, soft) = cell_summary(arm.as_str(), bin);
            (arm, bin, calls, hard, soft)
        })
        .collect();

    // Each cell contributes at least one call, so ln(total) >= 0.
    let total_calls = rows
        .iter()
        .fold(0.0_f64, |acc, &(_, _, calls, _, _)| acc + (calls as f64).max(1.0))
        .max(1.0);

    rows.into_iter()
        .map(|(arm, bin, calls, hard, soft)| {
            let calls = (calls as f64).max(1.0);
            let exploration = cfg.exploration_c * ((total_calls.ln() / calls).sqrt());
            let score = cfg.hard_weight * hard + cfg.soft_weight * soft + exploration;
            // Hash once per row rather than on every comparison.
            let tie = stable_hash64(tie_seed, arm) ^ stable_hash64_u64(tie_seed ^ 2, bin);
            (score, tie, arm, bin)
        })
        // Highest score, then smallest tie hash, then (arm, bin). `min_by` keeps the first of
        // fully equal rows, like the head of a stable sort.
        .min_by(|a, b| {
            b.0.total_cmp(&a.0)
                .then_with(|| a.1.cmp(&b.1))
                .then_with(|| (a.2, a.3).cmp(&(b.2, b.3)))
        })
        .map(|(_, _, arm, bin)| (ContextualCell::new(arm.clone(), bin), false))
}
/// Picks up to `k` cells with [`contextual_worst_first_pick_one`], never reusing a context bin.
///
/// Every round considers all `arms` but only the bins not yet used. Once a cell is chosen,
/// every copy of its bin is dropped from the candidate bins. Round `i` (1-based) uses the
/// seed `seed ^ i.wrapping_mul(0x9E37_79B9)`.
///
/// Returns an empty vector when `k == 0` or either input slice is empty. Otherwise it returns
/// at most `min(k, distinct bins)` entries, each with the explore-first flag of its round.
#[must_use]
pub fn contextual_worst_first_pick_k<FCalls, FSummary>(
    seed: u64,
    arms: &[String],
    active_bins: &[u64],
    k: usize,
    cfg: WorstFirstConfig,
    mut cell_calls: FCalls,
    mut cell_summary: FSummary,
) -> Vec<(ContextualCell, bool)>
where
    FCalls: FnMut(&str, u64) -> u64,
    FSummary: FnMut(&str, u64) -> (u64, f64, f64),
{
    // 0x9E37_79B9 ≈ 2^32 / golden ratio; spreads consecutive round indices across the seed.
    const ROUND_MIX: u64 = 0x9E37_79B9;

    if k == 0 || arms.is_empty() || active_bins.is_empty() {
        return Vec::new();
    }

    // Each pick consumes at least one bin, which bounds the result size.
    let mut chosen: Vec<(ContextualCell, bool)> = Vec::with_capacity(k.min(active_bins.len()));
    let mut remaining_bins: Vec<u64> = active_bins.to_vec();

    // `arms` never shrinks (only bins are consumed), so it is passed through as-is instead of
    // being cloned.
    while chosen.len() < k && !remaining_bins.is_empty() {
        let round = (chosen.len() as u64) + 1;
        let pick_seed = seed ^ round.wrapping_mul(ROUND_MIX);
        match contextual_worst_first_pick_one(
            pick_seed,
            arms,
            &remaining_bins,
            cfg,
            &mut cell_calls,
            &mut cell_summary,
        ) {
            None => break,
            Some((cell, explore)) => {
                remaining_bins.retain(|&b| b != cell.context_bin);
                chosen.push((cell, explore));
            }
        }
    }
    chosen
}
/// Counters for one tracked cell: total calls and how many were flagged hard / soft junk.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CellStats {
    /// Number of recorded calls.
    pub calls: u64,
    /// Calls flagged as hard junk.
    pub hard_junk: u64,
    /// Calls flagged as soft junk.
    pub soft_junk: u64,
}

impl CellStats {
    /// Fraction of calls flagged hard junk; `0.0` when nothing was recorded.
    pub fn hard_junk_rate(&self) -> f64 {
        Self::ratio(self.hard_junk, self.calls)
    }

    /// Fraction of calls flagged soft junk; `0.0` when nothing was recorded.
    pub fn soft_junk_rate(&self) -> f64 {
        Self::ratio(self.soft_junk, self.calls)
    }

    /// `part / whole` as `f64`, defined as `0.0` for `whole == 0` so callers never see NaN.
    fn ratio(part: u64, whole: u64) -> f64 {
        match whole {
            0 => 0.0,
            total => part as f64 / total as f64,
        }
    }
}
/// Per-`(arm, bin)` call and junk counters, plus helpers that feed them into the contextual
/// worst-first selectors.
///
/// Cells live in a `BTreeMap` keyed by `(arm, bin)`, so iteration is ordered by arm name and
/// then bin.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ContextualCoverageTracker {
    cells: std::collections::BTreeMap<(String, u64), CellStats>,
}

impl ContextualCoverageTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one call for `(arm, bin)`, bumping each junk counter whose flag is set.
    pub fn record(&mut self, arm: &str, bin: u64, hard_junk: bool, soft_junk: bool) {
        let entry = self.cells.entry((arm.to_owned(), bin)).or_default();
        entry.calls += 1;
        // `u64::from(bool)` is 1 for `true` and 0 for `false`.
        entry.hard_junk += u64::from(hard_junk);
        entry.soft_junk += u64::from(soft_junk);
    }

    /// Calls recorded for `(arm, bin)`; `0` for a cell the tracker does not hold.
    pub fn cell_calls(&self, arm: &str, bin: u64) -> u64 {
        self.get(arm, bin).map_or(0, |stats| stats.calls)
    }

    /// `(calls, hard_junk_rate, soft_junk_rate)` for `(arm, bin)`; all zeros for a cell the
    /// tracker does not hold.
    pub fn cell_summary(&self, arm: &str, bin: u64) -> (u64, f64, f64) {
        self.get(arm, bin).map_or((0, 0.0, 0.0), |stats| {
            (stats.calls, stats.hard_junk_rate(), stats.soft_junk_rate())
        })
    }

    /// Distinct bins that appear in any tracked cell, in ascending order.
    pub fn active_bins(&self) -> Vec<u64> {
        self.cells
            .keys()
            .map(|&(_, bin)| bin)
            .collect::<std::collections::BTreeSet<u64>>()
            .into_iter()
            .collect()
    }

    /// Stats for `(arm, bin)`, if the tracker holds that cell.
    pub fn get(&self, arm: &str, bin: u64) -> Option<&CellStats> {
        let key = (arm.to_owned(), bin);
        self.cells.get(&key)
    }

    /// Sum of `calls` over every tracked cell.
    pub fn total_calls(&self) -> u64 {
        self.cells.values().fold(0, |acc, stats| acc + stats.calls)
    }

    /// Runs `contextual_worst_first_pick_one` against this tracker's counts.
    pub fn pick_one(
        &self,
        seed: u64,
        arms: &[String],
        active_bins: &[u64],
        cfg: WorstFirstConfig,
    ) -> Option<(ContextualCell, bool)> {
        let calls = |arm: &str, bin: u64| self.cell_calls(arm, bin);
        let summary = |arm: &str, bin: u64| self.cell_summary(arm, bin);
        contextual_worst_first_pick_one(seed, arms, active_bins, cfg, calls, summary)
    }

    /// Runs `contextual_worst_first_pick_k` against this tracker's counts.
    pub fn pick_k(
        &self,
        seed: u64,
        arms: &[String],
        active_bins: &[u64],
        k: usize,
        cfg: WorstFirstConfig,
    ) -> Vec<(ContextualCell, bool)> {
        let calls = |arm: &str, bin: u64| self.cell_calls(arm, bin);
        let summary = |arm: &str, bin: u64| self.cell_summary(arm, bin);
        contextual_worst_first_pick_k(seed, arms, active_bins, k, cfg, calls, summary)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Three arms shared by the non-contextual tests.
    fn arms() -> Vec<String> {
        vec!["a".into(), "b".into(), "c".into()]
    }

    /// Hard junk weighted twice as heavily as soft junk, unit exploration.
    fn default_cfg() -> WorstFirstConfig {
        WorstFirstConfig {
            exploration_c: 1.0,
            hard_weight: 2.0,
            soft_weight: 1.0,
        }
    }

    #[test]
    fn picks_unseen_first() {
        let (pick, explore) =
            worst_first_pick_one(42, &arms(), default_cfg(), |_| 0, |_| (0, 0.0, 0.0)).unwrap();
        assert!(explore, "should flag explore_first for unseen arms");
        assert!(arms().contains(&pick));
    }

    #[test]
    fn pick_one_empty_returns_none() {
        let none: Vec<String> = Vec::new();
        assert!(worst_first_pick_one(1, &none, default_cfg(), |_| 0, |_| (0, 0.0, 0.0)).is_none());
    }

    #[test]
    fn pick_one_picks_only_unseen_arm() {
        let picked = worst_first_pick_one(
            7,
            &arms(),
            default_cfg(),
            |name| if name == "c" { 0 } else { 3 },
            |_| (3, 1.0, 1.0),
        );
        assert_eq!(picked, Some(("c".to_string(), true)));
    }

    #[test]
    fn prefers_highest_badness() {
        let (pick, explore) = worst_first_pick_one(
            42,
            &arms(),
            default_cfg(),
            |_| 10,
            |name| match name {
                "a" => (10, 0.1, 0.1),
                "b" => (10, 0.9, 0.9),
                "c" => (10, 0.5, 0.5),
                _ => (10, 0.0, 0.0),
            },
        )
        .unwrap();
        assert!(!explore);
        assert_eq!(pick, "b", "should pick the arm with highest badness");
    }

    #[test]
    fn pick_one_prefers_rarely_called_arm_on_equal_rates() {
        let picked = worst_first_pick_one(
            3,
            &arms(),
            default_cfg(),
            |_| 1,
            |name| if name == "a" { (1, 0.5, 0.5) } else { (100, 0.5, 0.5) },
        );
        assert_eq!(picked, Some(("a".to_string(), false)));
    }

    #[test]
    fn pick_k_returns_at_most_k() {
        let result = worst_first_pick_k(42, &arms(), 2, default_cfg(), |_| 0, |_| (0, 0.0, 0.0));
        assert_eq!(result.len(), 2);
        assert_ne!(result[0].0, result[1].0);
    }

    #[test]
    fn pick_k_empty_arms() {
        let result = worst_first_pick_k(42, &[], 5, default_cfg(), |_| 0, |_| (0, 0.0, 0.0));
        assert!(result.is_empty());
    }

    #[test]
    fn deterministic_given_seed() {
        let a = worst_first_pick_k(99, &arms(), 3, default_cfg(), |_| 0, |_| (0, 0.0, 0.0));
        let b = worst_first_pick_k(99, &arms(), 3, default_cfg(), |_| 0, |_| (0, 0.0, 0.0));
        assert_eq!(a, b);
    }

    #[test]
    fn context_bin_is_deterministic() {
        let cfg = ContextBinConfig::default();
        assert_eq!(
            context_bin(&[0.1, 0.5, 0.9], cfg),
            context_bin(&[0.1, 0.5, 0.9], cfg)
        );
    }

    #[test]
    fn context_bin_differs_for_different_contexts() {
        let cfg = ContextBinConfig::default();
        let b1 = context_bin(&[0.1, 0.9], cfg);
        let b2 = context_bin(&[0.9, 0.1], cfg);
        assert_ne!(b1, b2);
    }

    #[test]
    fn context_bin_single_level_all_same() {
        let cfg = ContextBinConfig { levels: 1, seed: 0 };
        assert_eq!(context_bin(&[0.0], cfg), context_bin(&[1.0], cfg));
    }

    #[test]
    fn context_bin_clamps_nonfinite() {
        let cfg = ContextBinConfig::default();
        let nan_bin = context_bin(&[f64::NAN, 0.5], cfg);
        let zero_bin = context_bin(&[0.0, 0.5], cfg);
        assert_eq!(nan_bin, zero_bin, "NaN clamped to 0 → same bin as 0.0");
    }

    /// Two bin ids shared by the contextual tests.
    fn bins_ab() -> Vec<u64> {
        vec![1, 2]
    }

    #[test]
    fn contextual_picks_unseen_cell_first() {
        let (cell, explore) = contextual_worst_first_pick_one(
            42,
            &arms(),
            &bins_ab(),
            default_cfg(),
            |_, _| 0,
            |_, _| (0, 0.0, 0.0),
        )
        .unwrap();
        assert!(explore);
        assert!(arms().contains(&cell.arm));
        assert!(bins_ab().contains(&cell.context_bin));
    }

    #[test]
    fn contextual_prefers_worst_cell() {
        let arms2 = vec!["x".to_string(), "y".to_string()];
        let bins = vec![10u64, 20u64];
        let (cell, explore) = contextual_worst_first_pick_one(
            42,
            &arms2,
            &bins,
            default_cfg(),
            |_, _| 10,
            |arm, bin| match (arm, bin) {
                ("x", 10) => (10, 0.1, 0.1),
                ("x", 20) => (10, 0.2, 0.2),
                ("y", 10) => (10, 0.1, 0.1),
                ("y", 20) => (10, 0.9, 0.9),
                _ => (10, 0.0, 0.0),
            },
        )
        .unwrap();
        assert!(!explore);
        assert_eq!(cell.arm, "y");
        assert_eq!(cell.context_bin, 20);
    }

    #[test]
    fn contextual_empty_bins_returns_none() {
        assert!(contextual_worst_first_pick_one(
            42,
            &arms(),
            &[],
            default_cfg(),
            |_, _| 0,
            |_, _| (0, 0.0, 0.0),
        )
        .is_none());
    }

    #[test]
    fn contextual_pick_k_returns_unique_bins() {
        let bins = vec![1u64, 2, 3];
        let result = contextual_worst_first_pick_k(
            42,
            &arms(),
            &bins,
            3,
            default_cfg(),
            |_, _| 0,
            |_, _| (0, 0.0, 0.0),
        );
        assert_eq!(result.len(), 3);
        let picked_bins: Vec<u64> = result.iter().map(|(c, _)| c.context_bin).collect();
        let mut dedup = picked_bins.clone();
        dedup.sort();
        dedup.dedup();
        assert_eq!(dedup.len(), picked_bins.len(), "bins should be unique");
    }

    #[test]
    fn contextual_pick_k_deterministic() {
        let bins = vec![1u64, 2, 3];
        let a = contextual_worst_first_pick_k(
            7,
            &arms(),
            &bins,
            3,
            default_cfg(),
            |_, _| 0,
            |_, _| (0, 0.0, 0.0),
        );
        let b = contextual_worst_first_pick_k(
            7,
            &arms(),
            &bins,
            3,
            default_cfg(),
            |_, _| 0,
            |_, _| (0, 0.0, 0.0),
        );
        assert_eq!(a, b);
    }

    #[test]
    fn cell_stats_rates() {
        let stats = CellStats {
            calls: 4,
            hard_junk: 1,
            soft_junk: 2,
        };
        assert_eq!(stats.hard_junk_rate(), 0.25);
        assert_eq!(stats.soft_junk_rate(), 0.5);
        assert_eq!(CellStats::default().hard_junk_rate(), 0.0);
        assert_eq!(CellStats::default().soft_junk_rate(), 0.0);
    }

    #[test]
    fn tracker_records_and_summarizes() {
        let mut tracker = ContextualCoverageTracker::new();
        tracker.record("a", 1, true, false);
        tracker.record("a", 1, false, true);
        tracker.record("b", 2, false, false);

        assert_eq!(tracker.cell_calls("a", 1), 2);
        assert_eq!(tracker.cell_calls("a", 2), 0);
        assert_eq!(tracker.cell_summary("a", 1), (2, 0.5, 0.5));
        assert_eq!(tracker.cell_summary("c", 9), (0, 0.0, 0.0));
        assert_eq!(tracker.active_bins(), vec![1, 2]);
        assert_eq!(tracker.total_calls(), 3);
        assert_eq!(tracker.get("b", 2).map(|s| s.calls), Some(1));
        assert!(tracker.get("b", 1).is_none());
    }

    #[test]
    fn tracker_pick_prefers_unseen_then_worst_cell() {
        let arms2 = vec!["x".to_string(), "y".to_string()];
        let bins = vec![1u64, 2];
        let mut tracker = ContextualCoverageTracker::new();
        tracker.record("x", 1, false, false);
        tracker.record("x", 2, false, false);
        tracker.record("y", 1, false, false);

        // (y, 2) is the only cell never recorded, so it must be explored first.
        let (cell, explore) = tracker.pick_one(3, &arms2, &bins, default_cfg()).unwrap();
        assert!(explore);
        assert_eq!(cell, ContextualCell::new("y", 2));

        // Once (y, 2) is all junk it dominates the score despite having more calls.
        for _ in 0..3 {
            tracker.record("y", 2, true, true);
        }
        let (cell, explore) = tracker.pick_one(3, &arms2, &bins, default_cfg()).unwrap();
        assert!(!explore);
        assert_eq!(cell, ContextualCell::new("y", 2));

        // pick_k is bounded by the number of distinct bins and starts with the worst cell.
        let picks = tracker.pick_k(3, &arms2, &bins, 5, default_cfg());
        assert_eq!(picks.len(), 2);
        assert_eq!(picks[0], (ContextualCell::new("y", 2), false));
        assert_eq!(picks[1].0.context_bin, 1);
    }
}