/// Whether an arm may keep receiving samples under the ledger's fair-share rule.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BudgetStatus {
    /// The arm is within (or exactly at) its share of the sampling budget.
    Active,
    /// The arm has received more than its share of the sampling budget.
    Exhausted,
}
/// Book-keeping for a single arm tracked by the budget ledger.
#[derive(Clone, Debug)]
pub struct ArmBudget {
    /// Relative weight of this arm; its fair share of samples is proportional to it.
    pub complexity: usize,
    /// Number of samples recorded against this arm so far.
    pub samples: u64,
    /// Status as of the most recent recomputation (not refreshed automatically).
    pub status: BudgetStatus,
}
impl ArmBudget {
pub fn new(complexity: usize) -> Self {
Self {
complexity,
samples: 0,
status: BudgetStatus::Active,
}
}
}
/// Tracks sample counts for a set of arms and classifies each arm against its
/// complexity-proportional share of the total sampling budget.
///
/// Arms are addressed by the `usize` index they were registered at.
#[derive(Clone, Debug, Default)]
pub struct BudgetLedger {
    /// One entry per registered arm, in registration order.
    entries: Vec<ArmBudget>,
}
impl BudgetLedger {
    /// Creates an empty ledger with no registered arms.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a new arm with the given `complexity` weight and returns its index.
    ///
    /// Indices are assigned sequentially from 0 in registration order and are
    /// invalidated by [`BudgetLedger::reset`].
    pub fn add_arm(&mut self, complexity: usize) -> usize {
        let idx = self.entries.len();
        self.entries.push(ArmBudget::new(complexity));
        idx
    }

    /// Records one sample against arm `idx`. Unknown indices are silently ignored.
    ///
    /// Statuses are not refreshed here; call [`BudgetLedger::recompute`] afterwards.
    pub fn record_sample(&mut self, idx: usize) {
        if let Some(entry) = self.entries.get_mut(idx) {
            entry.samples += 1;
        }
    }

    /// Recomputes every arm's [`BudgetStatus`].
    ///
    /// An arm becomes `Exhausted` when its share of all samples strictly exceeds its
    /// share of the total complexity, i.e.
    /// `samples / total_samples > complexity / total_complexity`; otherwise it is
    /// `Active`. When no samples have been recorded, or the total complexity is zero,
    /// every arm is marked `Active`.
    pub fn recompute(&mut self) {
        let (total_samples, total_complexity) = self.totals();
        if total_samples == 0 || total_complexity == 0 {
            for entry in &mut self.entries {
                entry.status = BudgetStatus::Active;
            }
            return;
        }
        // Compare the two ratios by cross-multiplying. Widening to u128 keeps the
        // comparison exact (no float rounding) and makes overflow impossible: the
        // product of two u64 values always fits in a u128. The previous u64 products
        // panicked in debug builds and wrapped (giving wrong verdicts) in release.
        let total_samples = u128::from(total_samples);
        let total_complexity = u128::from(total_complexity);
        for entry in &mut self.entries {
            let lhs = u128::from(entry.samples) * total_complexity;
            let rhs = entry.complexity as u128 * total_samples;
            entry.status = if lhs > rhs {
                BudgetStatus::Exhausted
            } else {
                BudgetStatus::Active
            };
        }
    }

    /// Returns the status of arm `idx` as of the last [`BudgetLedger::recompute`].
    ///
    /// Unknown indices report `Active`.
    pub fn status(&self, idx: usize) -> BudgetStatus {
        self.entries
            .get(idx)
            .map(|e| e.status)
            .unwrap_or(BudgetStatus::Active)
    }

    /// Returns the book-keeping entry for arm `idx`, or `None` if it does not exist.
    pub fn arm(&self, idx: usize) -> Option<&ArmBudget> {
        self.entries.get(idx)
    }

    /// Returns the number of registered arms.
    pub fn n_arms(&self) -> usize {
        self.entries.len()
    }

    /// Removes every arm; indices handed out earlier are no longer valid.
    pub fn reset(&mut self) {
        self.entries.clear();
    }

    /// Returns the sum over all arms of `samples * complexity`.
    ///
    /// Both the per-arm products and the running sum saturate at `u64::MAX`
    /// instead of panicking (debug) or silently wrapping (release) on overflow.
    pub fn total_cost(&self) -> u64 {
        self.entries
            .iter()
            .map(|e| e.samples.saturating_mul(e.complexity as u64))
            .fold(0u64, u64::saturating_add)
    }

    /// Returns `base_metric + scale * (sample_share - complexity_share)` for arm `idx`.
    ///
    /// `sample_share` is the arm's fraction of all recorded samples and
    /// `complexity_share` its fraction of the total complexity, so for a positive
    /// `scale` the result rises for over-sampled arms and falls for under-sampled
    /// ones. Returns `base_metric` unchanged when `idx` is unknown, no samples have
    /// been recorded, or the total complexity is zero.
    pub fn adjusted_metric(&self, idx: usize, base_metric: f64, scale: f64) -> f64 {
        let (total_samples, total_complexity) = self.totals();
        if total_samples == 0 || total_complexity == 0 {
            return base_metric;
        }
        let Some(entry) = self.entries.get(idx) else {
            return base_metric;
        };
        let sample_share = entry.samples as f64 / total_samples as f64;
        let complexity_share = entry.complexity as f64 / total_complexity as f64;
        base_metric + scale * (sample_share - complexity_share)
    }

    /// Returns `(total_samples, total_complexity)` summed over all arms.
    fn totals(&self) -> (u64, u64) {
        let total_samples: u64 = self.entries.iter().map(|e| e.samples).sum();
        let total_complexity: u64 = self.entries.iter().map(|e| e.complexity as u64).sum();
        (total_samples, total_complexity)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `n` samples against arm `idx`.
    fn record_n(ledger: &mut BudgetLedger, idx: usize, n: usize) {
        for _ in 0..n {
            ledger.record_sample(idx);
        }
    }

    /// Asserts two floats agree within a tight absolute tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-12,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn budget_exhausts_arm_when_overrun() {
        let mut ledger = BudgetLedger::new();
        let simple = ledger.add_arm(10);
        let complex = ledger.add_arm(90);
        record_n(&mut ledger, simple, 90);
        record_n(&mut ledger, complex, 10);
        ledger.recompute();
        assert_eq!(
            ledger.status(simple),
            BudgetStatus::Exhausted,
            "simple arm got 90% of samples but only 10% of complexity share — must be Exhausted"
        );
        assert_eq!(
            ledger.status(complex),
            BudgetStatus::Active,
            "complex arm got 10% of samples and 90% of complexity share — must remain Active"
        );
    }

    #[test]
    fn budget_normalizes_across_arms_with_different_complexity() {
        let mut ledger = BudgetLedger::new();
        let a = ledger.add_arm(100);
        let b = ledger.add_arm(200);
        let c = ledger.add_arm(300);
        record_n(&mut ledger, a, 100);
        record_n(&mut ledger, b, 200);
        record_n(&mut ledger, c, 300);
        ledger.recompute();
        assert_eq!(
            ledger.status(a),
            BudgetStatus::Active,
            "arm a: 100 samples, complexity 100 — at fair share, must be Active"
        );
        assert_eq!(
            ledger.status(b),
            BudgetStatus::Active,
            "arm b: 200 samples, complexity 200 — at fair share, must be Active"
        );
        assert_eq!(
            ledger.status(c),
            BudgetStatus::Active,
            "arm c: 300 samples, complexity 300 — at fair share, must be Active"
        );
        record_n(&mut ledger, a, 500);
        ledger.recompute();
        assert_eq!(
            ledger.status(a),
            BudgetStatus::Exhausted,
            "arm a overrun: now 600 samples vs 100 complexity — must be Exhausted"
        );
        // The other arms are now under their share and must stay Active.
        assert_eq!(ledger.status(b), BudgetStatus::Active);
        assert_eq!(ledger.status(c), BudgetStatus::Active);
    }

    #[test]
    fn budget_total_cost_reflects_actual_usage() {
        let mut ledger = BudgetLedger::new();
        let idx = ledger.add_arm(50);
        record_n(&mut ledger, idx, 20);
        assert_eq!(
            ledger.total_cost(),
            1000,
            "total cost should be 20 × 50 = 1000"
        );
    }

    #[test]
    fn adjusted_metric_shifts_by_share_difference() {
        let mut ledger = BudgetLedger::new();
        let light = ledger.add_arm(1);
        let heavy = ledger.add_arm(3);
        record_n(&mut ledger, light, 3);
        record_n(&mut ledger, heavy, 1);
        // light: 3/4 of samples vs 1/4 of complexity → +0.5 * scale
        assert_close(ledger.adjusted_metric(light, 1.0, 2.0), 2.0);
        // heavy: 1/4 of samples vs 3/4 of complexity → -0.5 * scale
        assert_close(ledger.adjusted_metric(heavy, 1.0, 2.0), 0.0);
        // unknown arm → base metric unchanged
        assert_close(ledger.adjusted_metric(99, 1.5, 2.0), 1.5);
    }

    #[test]
    fn unknown_index_and_empty_ledger_are_handled() {
        let mut ledger = BudgetLedger::new();
        ledger.record_sample(3);
        ledger.recompute();
        assert_eq!(ledger.n_arms(), 0);
        assert_eq!(ledger.status(3), BudgetStatus::Active);
        assert!(ledger.arm(3).is_none());
        assert_eq!(ledger.total_cost(), 0);
        assert_close(ledger.adjusted_metric(0, 0.25, 10.0), 0.25);
    }

    #[test]
    fn reset_removes_all_arms() {
        let mut ledger = BudgetLedger::new();
        let a = ledger.add_arm(5);
        ledger.record_sample(a);
        ledger.reset();
        assert_eq!(ledger.n_arms(), 0);
        assert_eq!(ledger.add_arm(7), 0, "indices restart at zero after reset");
        assert_eq!(ledger.arm(0).map(|e| e.samples), Some(0));
    }
}