use std::fmt;
use serde::{Deserialize, Serialize};
/// A single calibration knob value, tagged with its numeric kind.
///
/// Floating-point knobs hold [`KnobValue::F64`]; integer-valued knobs
/// (e.g. a pool size) hold [`KnobValue::Usize`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Serialize, Deserialize)]
pub enum KnobValue {
    /// A floating-point knob value.
    F64(f64),
    /// A non-negative integer knob value.
    Usize(usize),
}
impl KnobValue {
    /// Renders the value as a YAML scalar that a YAML loader reads back as
    /// the same kind of number.
    ///
    /// - `Usize` values are written as plain integers (`12`).
    /// - Finite `F64` values always carry a decimal point (`1.0`, not `1`),
    ///   so they are not reloaded as integers.
    /// - Non-finite `F64` values use YAML's `.nan`, `.inf` and `-.inf`
    ///   spellings; Rust's own `NaN` / `inf` / `-inf` would load as strings.
    pub fn to_yaml_string(&self) -> String {
        match *self {
            Self::Usize(v) => v.to_string(),
            Self::F64(v) if v.is_nan() => ".nan".to_string(),
            Self::F64(v) if v.is_infinite() => {
                let spelled = if v > 0.0 { ".inf" } else { "-.inf" };
                spelled.to_string()
            }
            Self::F64(v) => {
                // `Display` for f64 never switches to exponent notation, so a
                // rendering without '.' is an integral value such as `3`.
                let digits = v.to_string();
                if digits.contains('.') {
                    digits
                } else {
                    digits + ".0"
                }
            }
        }
    }

    /// Returns the value widened to `f64`.
    ///
    /// `Usize` values above 2^53 are rounded to the nearest representable
    /// `f64` by the `as` conversion.
    pub fn as_f64(&self) -> f64 {
        match *self {
            Self::F64(v) => v,
            Self::Usize(v) => v as f64,
        }
    }
}
/// Displays the value exactly as `to_yaml_string` renders it; width and
/// alignment flags on the formatter are not applied.
impl fmt::Display for KnobValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.to_yaml_string();
        write!(f, "{rendered}")
    }
}
/// Inclusive lower/upper limits for a knob, one variant per value kind.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum KnobBounds {
    /// Bounds for a `KnobValue::F64` knob.
    F64Range {
        /// Smallest allowed value.
        min: f64,
        /// Largest allowed value.
        max: f64,
    },
    /// Bounds for a `KnobValue::Usize` knob.
    UsizeRange {
        /// Smallest allowed value.
        min: usize,
        /// Largest allowed value.
        max: usize,
    },
}
impl KnobBounds {
    /// Returns `max - min` as an `f64`.
    ///
    /// Both variants produce a signed difference. Inverted bounds
    /// (`min > max`) therefore yield a negative width, instead of a `usize`
    /// underflow that panics in debug builds or wraps in release builds.
    pub fn width(&self) -> f64 {
        match *self {
            Self::F64Range { min, max } => max - min,
            Self::UsizeRange { min, max } => {
                // Subtract in the non-overflowing direction, then apply the sign,
                // so the result stays exact for large but valid ranges.
                if max >= min {
                    (max - min) as f64
                } else {
                    -((min - max) as f64)
                }
            }
        }
    }

    /// Returns `true` when `value` is the variant these bounds apply to:
    /// `F64Range` pairs with `KnobValue::F64`, `UsizeRange` with `KnobValue::Usize`.
    pub fn matches_value(&self, value: &KnobValue) -> bool {
        matches!(
            (self, value),
            (Self::F64Range { .. }, KnobValue::F64(_))
                | (Self::UsizeRange { .. }, KnobValue::Usize(_))
        )
    }
}
/// Outcome reported when a proposed value is clipped for a `CalibrationKnob`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum KnobClipResult {
    /// The (possibly step-limited) value already lay within the bounds.
    InRange,
    /// The value fell below the lower bound and was raised to it.
    ClippedLow,
    /// The value exceeded the upper bound and was lowered to it.
    ClippedHigh,
    /// The proposal was rejected and the knob's current value kept, e.g.
    /// because its variant does not match the knob's bounds.
    TypeMismatch,
}
/// A tunable configuration value, its allowed range, and the largest change
/// permitted in a single adjustment.
#[derive(Clone, Debug)]
pub struct CalibrationKnob {
    /// Config path of the knob, e.g. `"fraud.fraud_rate"`.
    pub path: String,
    /// Value the knob currently holds.
    pub current: KnobValue,
    /// Inclusive range that proposed values are clamped into.
    pub bounds: KnobBounds,
    /// Largest absolute change applied in one step, in the knob's own units.
    pub max_step: f64,
}
impl CalibrationKnob {
    /// Creates a floating-point knob at `current`, bounded to the inclusive
    /// range `[min, max]`, that moves by at most `max_step` per adjustment.
    pub fn new_f64(
        path: impl Into<String>,
        current: f64,
        min: f64,
        max: f64,
        max_step: f64,
    ) -> Self {
        Self {
            path: path.into(),
            current: KnobValue::F64(current),
            bounds: KnobBounds::F64Range { min, max },
            max_step,
        }
    }

    /// Creates an integer knob at `current`, bounded to the inclusive range
    /// `[min, max]`. `max_step` is still an `f64`; step limiting happens in
    /// floating point before the result is rounded (see [`clip`](Self::clip)).
    pub fn new_usize(
        path: impl Into<String>,
        current: usize,
        min: usize,
        max: usize,
        max_step: f64,
    ) -> Self {
        Self {
            path: path.into(),
            current: KnobValue::Usize(current),
            bounds: KnobBounds::UsizeRange { min, max },
            max_step,
        }
    }

    /// Computes the value this knob would take for `proposed` without
    /// modifying the knob, and reports how it was constrained.
    ///
    /// The proposal is first step-limited: it is moved at most `max_step`
    /// away from `current`. The result is then clamped into the inclusive
    /// bounds. Step limiting alone still reports `InRange`.
    ///
    /// For `Usize` knobs, the step-limited value is rounded to the nearest
    /// integer before clamping. A fractional `max_step` can therefore be
    /// overshot by up to 0.5.
    ///
    /// Returns `(current, TypeMismatch)`, a no-op, in two cases:
    /// - the proposal's variant does not match the bounds;
    /// - the proposal is `F64(NaN)`.
    ///
    /// A NaN proposal would otherwise fail every comparison, skip both the
    /// step limit and the bounds, and be reported as `InRange`.
    ///
    /// `max_step` is assumed non-negative; a negative value would push the
    /// knob away from the proposal.
    pub fn clip(&self, proposed: KnobValue) -> (KnobValue, KnobClipResult) {
        let prop_f = proposed.as_f64();
        if !self.bounds.matches_value(&proposed) || prop_f.is_nan() {
            return (self.current, KnobClipResult::TypeMismatch);
        }
        let cur_f = self.current.as_f64();
        let delta = prop_f - cur_f;
        // Limit the per-call movement before applying the hard bounds.
        let stepped_f = if delta.abs() > self.max_step {
            cur_f + delta.signum() * self.max_step
        } else {
            prop_f
        };
        let (clipped_f, bound_result) = match self.bounds {
            KnobBounds::F64Range { min, max } => {
                if stepped_f < min {
                    (min, KnobClipResult::ClippedLow)
                } else if stepped_f > max {
                    (max, KnobClipResult::ClippedHigh)
                } else {
                    (stepped_f, KnobClipResult::InRange)
                }
            }
            KnobBounds::UsizeRange { min, max } => {
                // Round to an integer (never below zero) before comparing
                // against the integer bounds.
                let i = stepped_f.round().max(0.0) as usize;
                if i < min {
                    (min as f64, KnobClipResult::ClippedLow)
                } else if i > max {
                    (max as f64, KnobClipResult::ClippedHigh)
                } else {
                    (i as f64, KnobClipResult::InRange)
                }
            }
        };
        // `proposed` is known to match `bounds` here, so rebuild the same variant.
        let final_v = match proposed {
            KnobValue::F64(_) => KnobValue::F64(clipped_f),
            KnobValue::Usize(_) => KnobValue::Usize(clipped_f as usize),
        };
        (final_v, bound_result)
    }

    /// Clips `proposed` via [`clip`](Self::clip), stores the resulting value
    /// as `current`, and returns the clip outcome. On `TypeMismatch` the
    /// stored value is the unchanged current value.
    pub fn apply(&mut self, proposed: KnobValue) -> KnobClipResult {
        let (final_v, result) = self.clip(proposed);
        self.current = final_v;
        result
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The fraud-rate knob shared by most `f64` tests; it starts at 0.02.
    fn fraud_rate_knob(min: f64, max: f64, max_step: f64) -> CalibrationKnob {
        CalibrationKnob::new_f64("fraud.fraud_rate", 0.02, min, max, max_step)
    }

    /// The pool-size knob shared by the `usize` tests; it starts at 12 within [5, 20].
    fn pool_size_knob(max_step: f64) -> CalibrationKnob {
        CalibrationKnob::new_usize("pool.target_size", 12, 5, 20, max_step)
    }

    #[test]
    fn f64_clip_in_range_is_noop() {
        let mut knob = fraud_rate_knob(0.0, 0.1, 0.01);
        let outcome = knob.apply(KnobValue::F64(0.025));
        assert_eq!(outcome, KnobClipResult::InRange);
        assert_eq!(knob.current, KnobValue::F64(0.025));
    }

    #[test]
    fn f64_clip_low_clamps_to_min() {
        let mut knob = fraud_rate_knob(0.01, 0.1, 1.0);
        assert_eq!(knob.apply(KnobValue::F64(0.001)), KnobClipResult::ClippedLow);
        assert_eq!(knob.current, KnobValue::F64(0.01));
    }

    #[test]
    fn f64_clip_high_clamps_to_max() {
        let mut knob = fraud_rate_knob(0.0, 0.05, 1.0);
        assert_eq!(knob.apply(KnobValue::F64(0.1)), KnobClipResult::ClippedHigh);
        assert_eq!(knob.current, KnobValue::F64(0.05));
    }

    #[test]
    fn f64_step_size_clamps_large_jumps() {
        let mut knob = fraud_rate_knob(0.0, 0.1, 0.01);
        assert_eq!(knob.apply(KnobValue::F64(0.05)), KnobClipResult::InRange);
        let landed = knob.current.as_f64();
        assert!(
            (landed - 0.03).abs() < 1e-12,
            "value should land at 0.02 + 0.01 step = 0.03, got {}",
            knob.current
        );
    }

    #[test]
    fn f64_step_size_clamps_negative_jumps() {
        let mut knob = CalibrationKnob::new_f64("rate", 0.10, 0.0, 1.0, 0.01);
        assert_eq!(knob.apply(KnobValue::F64(0.02)), KnobClipResult::InRange);
        let landed = knob.current.as_f64();
        assert!((landed - 0.09).abs() < 1e-12);
    }

    #[test]
    fn usize_clip_rounds_and_clamps() {
        let mut knob = pool_size_knob(4.0);
        assert_eq!(knob.apply(KnobValue::Usize(17)), KnobClipResult::InRange);
        assert_eq!(knob.current, KnobValue::Usize(16));
    }

    #[test]
    fn usize_clip_low_clamps_to_min() {
        let mut knob = pool_size_knob(100.0);
        assert_eq!(knob.apply(KnobValue::Usize(2)), KnobClipResult::ClippedLow);
        assert_eq!(knob.current, KnobValue::Usize(5));
    }

    #[test]
    fn type_mismatch_is_noop() {
        let mut knob = fraud_rate_knob(0.0, 0.1, 0.01);
        assert_eq!(knob.apply(KnobValue::Usize(12)), KnobClipResult::TypeMismatch);
        assert_eq!(knob.current, KnobValue::F64(0.02));
    }

    #[test]
    fn yaml_string_round_trips() {
        let cases = [
            (KnobValue::F64(0.025), "0.025"),
            (KnobValue::Usize(12), "12"),
        ];
        for (value, expected) in cases {
            assert_eq!(value.to_yaml_string(), expected);
        }
    }

    #[test]
    fn bounds_width_is_max_minus_min() {
        let f64_width = KnobBounds::F64Range { min: 0.0, max: 0.1 }.width();
        assert!((f64_width - 0.1).abs() < 1e-12);
        let usize_width = KnobBounds::UsizeRange { min: 5, max: 20 }.width();
        assert_eq!(usize_width, 15.0);
    }
}