use std::collections::BTreeMap;
use super::knob::{CalibrationKnob, KnobValue};
use super::loop_runner::{ProposedPatch, Proposer, StepOutcome, StepReport};
/// Per-knob memory kept by the greedy proposer between calls.
#[derive(Clone, Debug, Default)]
struct KnobState {
    /// Sign of the step most recently proposed for this knob.
    ///
    /// The derived default is `0`, which the proposer treats as
    /// "no step proposed yet".
    last_direction: i32,
    /// Whether the latest outcome recorded for this knob was an improvement.
    ///
    /// Reset to `false` every time a new step is proposed for the knob.
    last_improved: bool,
}
/// Proposer that walks the knobs one at a time and remembers, per knob path,
/// which direction it last stepped in and how often a step failed to improve.
#[derive(Default)]
pub struct GreedyKnobProposer {
    /// Direction/outcome memory, keyed by knob path.
    state: BTreeMap<String, KnobState>,
    /// Index in the knob slice at which the next scan starts.
    next_knob: usize,
    /// Count of non-improving steps recorded so far, keyed by knob path.
    fails_per_knob: BTreeMap<String, usize>,
}

impl GreedyKnobProposer {
    /// A knob with at least this many recorded non-improving steps is skipped.
    const MAX_FAILS: usize = 2;

    /// Creates a proposer with empty bookkeeping and the scan cursor at index 0.
    pub fn new() -> Self {
        Self::default()
    }
}
impl Proposer for GreedyKnobProposer {
    /// Proposes a single-knob step, or `None` when there is nothing left to try.
    ///
    /// Each call does two things:
    ///
    /// 1. Records the outcome of the most recent entry in `history` (if it
    ///    carries a patch): the patched knob's `last_improved` flag is updated
    ///    and, when the outcome is anything other than
    ///    `StepOutcome::Improved`, that knob's failure counter is incremented.
    /// 2. Scans the knobs cyclically starting at the internal cursor and picks
    ///    the first one whose failure counter is below `MAX_FAILS`.
    ///
    /// The step direction for the chosen knob is `+1` on its first proposal,
    /// the previous direction if the last recorded outcome improved, and the
    /// opposite direction otherwise. The proposed value is
    /// `current + max_step * direction`; for `KnobValue::Usize` knobs it is
    /// rounded and floored at zero before conversion.
    ///
    /// Returns `None` when `knobs` is empty or every knob has reached
    /// `MAX_FAILS`.
    fn propose(
        &mut self,
        knobs: &[CalibrationKnob],
        _current_loss: (f64, f64),
        history: &[StepReport],
    ) -> Option<ProposedPatch> {
        if knobs.is_empty() {
            return None;
        }

        // Record the outcome of the latest step. Only the last report is
        // inspected; earlier ones were seen on previous calls.
        if let Some(last) = history.last() {
            if let Some(patch) = &last.proposed_patch {
                // Checked lookup: a report whose `knob_index` does not fit the
                // current `knobs` slice is ignored instead of panicking.
                if let Some(knob) = knobs.get(patch.knob_index) {
                    let improved = matches!(last.outcome, StepOutcome::Improved);
                    self.state
                        .entry(knob.path.clone())
                        .or_default()
                        .last_improved = improved;
                    if !improved {
                        *self.fails_per_knob.entry(knob.path.clone()).or_default() += 1;
                    }
                }
            }
        }

        // First knob, in cyclic order from the cursor, that is not exhausted.
        let count = knobs.len();
        let start = self.next_knob;
        let fails = &self.fails_per_knob;
        let idx = (0..count)
            .map(|offset| (start + offset) % count)
            .find(|&i| fails.get(&knobs[i].path).copied().unwrap_or(0) < Self::MAX_FAILS)?;
        let knob = &knobs[idx];

        let entry = self.state.entry(knob.path.clone()).or_default();
        let dir = match (entry.last_direction, entry.last_improved) {
            // Never stepped before: start in the positive direction.
            (0, _) => 1,
            // Last step helped: keep going the same way.
            (d, true) => d,
            // Last step did not help: try the other way.
            (d, false) => -d,
        };
        entry.last_direction = dir;
        // Cleared until the outcome of this new step is reported back.
        entry.last_improved = false;

        let step = knob.max_step * dir as f64;
        let proposed_f = knob.current.as_f64() + step;
        let proposed = match knob.current {
            KnobValue::F64(_) => KnobValue::F64(proposed_f),
            // Floor at zero before the cast so a negative result becomes 0.
            KnobValue::Usize(_) => KnobValue::Usize(proposed_f.round().max(0.0) as usize),
        };

        // The next call starts scanning at the knob after this one.
        self.next_knob = (idx + 1) % count;
        Some(ProposedPatch {
            knob_index: idx,
            proposed_value: proposed,
            rationale: format!(
                "greedy: knob `{}` direction {dir} step {step}",
                knob.path
            ),
        })
    }
}
/// Proposer that visits the knobs in cyclic order, one per call.
#[derive(Default)]
pub struct RoundRobinProposer {
    /// Cursor into the knob slice; reduced modulo the slice length on use.
    pub next_knob: usize,
}

impl RoundRobinProposer {
    /// Creates a proposer whose cursor starts at index 0.
    pub fn new() -> Self {
        Self::default()
    }
}
impl Proposer for RoundRobinProposer {
    /// Proposes `current + max_step` for the knob under the cursor, then
    /// advances the cursor by one (wrapping at the end of `knobs`).
    ///
    /// The loss and the history are ignored. For `KnobValue::Usize` knobs the
    /// stepped value is rounded and floored at zero before conversion.
    /// Returns `None` only when `knobs` is empty.
    fn propose(
        &mut self,
        knobs: &[CalibrationKnob],
        _current_loss: (f64, f64),
        _history: &[StepReport],
    ) -> Option<ProposedPatch> {
        let count = knobs.len();
        if count == 0 {
            return None;
        }

        let idx = self.next_knob % count;
        self.next_knob = (self.next_knob + 1) % count;

        let knob = &knobs[idx];
        let stepped = knob.current.as_f64() + knob.max_step;
        let proposed_value = match knob.current {
            KnobValue::Usize(_) => KnobValue::Usize(stepped.round().max(0.0) as usize),
            KnobValue::F64(_) => KnobValue::F64(stepped),
        };

        Some(ProposedPatch {
            knob_index: idx,
            proposed_value,
            rationale: format!("round-robin step on `{}`", knob.path),
        })
    }
}
// Unit tests for `RoundRobinProposer` and `GreedyKnobProposer`.
// The proposers are stateful, so the order of `propose` calls in each test matters.
#[cfg(test)]
mod tests {
use super::*;
use crate::calibration::loop_runner::{ProposedPatch, StepReport};
use std::collections::BTreeMap;
// Two `f64` knobs, "k.a" and "k.b", shared by most tests.
// NOTE(review): the numeric arguments are presumably (current, min, max, max_step);
// the assertions below are consistent with current = 0.05 / step = 0.01 for "k.a".
fn knobs() -> Vec<CalibrationKnob> {
vec![
CalibrationKnob::new_f64("k.a", 0.05, 0.0, 1.0, 0.01),
CalibrationKnob::new_f64("k.b", 0.10, 0.0, 1.0, 0.02),
]
}
// Builds a `StepReport` that carries a patch for `knob_index` with the given
// `proposed_value` and `outcome`. The loss fields are fixed placeholders and
// `knob_values` is left empty.
fn step_with_outcome(
iter: usize,
knob_index: usize,
proposed_value: KnobValue,
outcome: StepOutcome,
) -> StepReport {
StepReport {
iter,
loss_before_mean: 1.0,
loss_before_std: 0.0,
proposed_patch: Some(ProposedPatch {
knob_index,
proposed_value,
rationale: "test".into(),
}),
loss_after_mean: Some(1.0),
loss_after_std: Some(0.0),
knob_values: BTreeMap::new(),
outcome,
}
}
// With two knobs, three consecutive proposals target indices 0, 1, 0.
#[test]
fn round_robin_proposer_cycles_through_knobs() {
let mut prop = RoundRobinProposer::new();
let ks = knobs();
let p1 = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
assert_eq!(p1.knob_index, 0);
let p2 = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
assert_eq!(p2.knob_index, 1);
let p3 = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
assert_eq!(p3.knob_index, 0);
}
// The first proposal for "k.a" is expected to be 0.06.
#[test]
fn round_robin_proposer_steps_by_max_step() {
let mut prop = RoundRobinProposer::new();
let ks = knobs();
let p = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
assert!(
(p.proposed_value.as_f64() - 0.06).abs() < 1e-12,
"expected 0.06, got {}",
p.proposed_value
);
}
// An empty knob slice yields no proposal.
#[test]
fn round_robin_proposer_empty_knobs_returns_none() {
let mut prop = RoundRobinProposer::new();
assert!(prop.propose(&[], (0.0, 0.0), &[]).is_none());
}
// With no history, the greedy proposer picks knob 0 and steps upward to 0.06.
#[test]
fn greedy_proposer_first_call_picks_positive_direction() {
let mut prop = GreedyKnobProposer::new();
let ks = knobs();
let p = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
assert_eq!(p.knob_index, 0);
assert!((p.proposed_value.as_f64() - 0.06).abs() < 1e-12);
}
// After an `Improved` outcome on knob 0, the proposer moves on to knob 1 and,
// when it returns to knob 0, keeps the same direction. The expected value is
// again 0.06 because `ks` is never updated between calls.
#[test]
fn greedy_proposer_continues_same_direction_after_improvement() {
let mut prop = GreedyKnobProposer::new();
let ks = knobs();
let _p1 = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
let h1 = vec![step_with_outcome(
0,
0,
KnobValue::F64(0.06),
StepOutcome::Improved,
)];
let p2 = prop.propose(&ks, (0.0, 0.0), &h1).unwrap();
assert_eq!(p2.knob_index, 1);
let h2 = vec![
step_with_outcome(0, 0, KnobValue::F64(0.06), StepOutcome::Improved),
step_with_outcome(1, 1, KnobValue::F64(0.12), StepOutcome::Improved),
];
let p3 = prop.propose(&ks, (0.0, 0.0), &h2).unwrap();
assert_eq!(p3.knob_index, 0);
assert!((p3.proposed_value.as_f64() - 0.06).abs() < 1e-12);
}
// After a `Reverted` outcome on knob 0, the next proposal for knob 0 steps the
// other way (0.04 instead of 0.06).
#[test]
fn greedy_proposer_flips_direction_after_failure() {
let mut prop = GreedyKnobProposer::new();
let ks = knobs();
let _p1 = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
let h1 = vec![step_with_outcome(
0,
0,
KnobValue::F64(0.06),
StepOutcome::Reverted,
)];
let _p2 = prop.propose(&ks, (0.0, 0.0), &h1).unwrap();
let h2 = vec![
step_with_outcome(0, 0, KnobValue::F64(0.06), StepOutcome::Reverted),
step_with_outcome(1, 1, KnobValue::F64(0.12), StepOutcome::Improved),
];
let p3 = prop.propose(&ks, (0.0, 0.0), &h2).unwrap();
assert_eq!(p3.knob_index, 0);
assert!(
(p3.proposed_value.as_f64() - 0.04).abs() < 1e-12,
"expected 0.04 (flipped direction), got {}",
p3.proposed_value
);
}
// With a single knob, two `Reverted` outcomes (one per direction) leave the
// proposer with nothing to propose.
#[test]
fn greedy_proposer_exhausts_after_both_directions_fail() {
let mut prop = GreedyKnobProposer::new();
let ks = vec![CalibrationKnob::new_f64("k", 0.05, 0.0, 1.0, 0.01)];
let _p1 = prop.propose(&ks, (0.0, 0.0), &[]).unwrap();
let h1 = vec![step_with_outcome(
0,
0,
KnobValue::F64(0.06),
StepOutcome::Reverted,
)];
let _p2 = prop.propose(&ks, (0.0, 0.0), &h1).unwrap();
let h2 = vec![
step_with_outcome(0, 0, KnobValue::F64(0.06), StepOutcome::Reverted),
step_with_outcome(1, 0, KnobValue::F64(0.04), StepOutcome::Reverted),
];
let p3 = prop.propose(&ks, (0.0, 0.0), &h2);
assert!(
p3.is_none(),
"proposer should exhaust after both directions fail; got {:?}",
p3.map(|p| p.proposed_value)
);
}
}