/// Converts a raw "lower is better" metric into a reward in `[0.0, 1.0]`,
/// measured relative to an exponential moving average (EMA) of earlier metrics.
///
/// A metric equal to the running baseline maps to the neutral reward `0.5`;
/// lower metrics push the reward toward `1.0`, higher metrics toward `0.0`.
#[derive(Debug, Clone)]
pub struct RewardNormalizer {
    /// EMA of the finite metric values seen so far (`0.0` until the first one).
    baseline: f64,
    /// EMA smoothing factor in `(0, 1]`; larger values follow recent metrics faster.
    alpha: f64,
    /// Whether `baseline` has been seeded by a first finite metric value.
    initialized: bool,
}

impl RewardNormalizer {
    /// Reward returned when no meaningful comparison is possible, and for a
    /// metric exactly equal to the baseline.
    const NEUTRAL_REWARD: f64 = 0.5;

    /// Creates a normalizer with EMA smoothing factor `alpha`.
    ///
    /// # Panics
    ///
    /// Panics if `alpha` is not in `(0, 1]` (NaN is rejected too).
    pub fn new(alpha: f64) -> Self {
        assert!(
            alpha > 0.0 && alpha <= 1.0,
            "RewardNormalizer alpha must be in (0, 1], got {alpha}"
        );
        Self {
            baseline: 0.0,
            alpha,
            initialized: false,
        }
    }

    /// Creates a normalizer from an EMA span, using `alpha = 2 / (span + 1)`.
    ///
    /// `span == 1` yields `alpha == 1.0`, i.e. the baseline is just the last finite metric.
    ///
    /// # Panics
    ///
    /// Panics if `span` is zero.
    pub fn with_span(span: usize) -> Self {
        assert!(span > 0, "RewardNormalizer span must be > 0, got {span}");
        let alpha = 2.0 / (span as f64 + 1.0);
        Self::new(alpha)
    }

    /// Scores `metric_value` against the current baseline, then folds it into the baseline.
    ///
    /// The reward is `0.5 * (1 + (baseline - metric_value) / |baseline|)`, clamped to
    /// `[0.0, 1.0]`, where `baseline` is the value *before* this call. A metric equal to
    /// the baseline scores `0.5`; half a positive baseline scores `0.75`; twice it scores
    /// `0.0`. Dividing by `|baseline|` keeps "lower is better" correct for negative baselines.
    ///
    /// Returns the neutral `0.5` when there is nothing to compare against: for the first
    /// finite value (which seeds the baseline), when the baseline is exactly `0.0`, or when
    /// `metric_value` is NaN. Non-finite values never update the baseline, so one NaN or
    /// infinity cannot poison later rewards; `+inf` / `-inf` still score `0.0` / `1.0`.
    pub fn normalize(&mut self, metric_value: f64) -> f64 {
        // NaN carries no ordering information: ignore it entirely.
        if metric_value.is_nan() {
            return Self::NEUTRAL_REWARD;
        }

        if !self.initialized {
            // Only a finite value may seed the baseline.
            if metric_value.is_finite() {
                self.baseline = metric_value;
                self.initialized = true;
            }
            return Self::NEUTRAL_REWARD;
        }

        // Compare against the baseline *before* this sample is folded in; otherwise the
        // sample drags the baseline toward itself and, with alpha == 1, every reward
        // would be identical regardless of the metric.
        let previous = self.baseline;
        if metric_value.is_finite() {
            self.baseline = self.alpha * metric_value + (1.0 - self.alpha) * previous;
        }

        if previous == 0.0 {
            // Relative change is undefined against a zero baseline.
            return Self::NEUTRAL_REWARD;
        }

        let relative_improvement = (previous - metric_value) / previous.abs();
        // Map relative improvement in [-1, 1] linearly onto [0, 1]; larger values saturate.
        (0.5 * (1.0 + relative_improvement)).clamp(0.0, 1.0)
    }

    /// Returns the current EMA baseline (`0.0` before the first finite metric or after
    /// [`reset`](Self::reset)).
    pub fn baseline(&self) -> f64 {
        self.baseline
    }

    /// Clears the baseline so the next finite metric seeds it afresh.
    pub fn reset(&mut self) {
        self.baseline = 0.0;
        self.initialized = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons throughout these tests.
    const EPS: f64 = 1e-12;

    #[test]
    fn normalizer_first_value_returns_half() {
        let mut norm = RewardNormalizer::new(0.1);
        let reward = norm.normalize(10.0);
        assert!(
            (reward - 0.5).abs() < EPS,
            "first value should return 0.5, got {reward}"
        );
    }

    #[test]
    fn better_performance_higher_reward() {
        let mut norm = RewardNormalizer::new(0.01);
        norm.normalize(10.0);
        let reward = norm.normalize(1.0);
        assert!(
            reward > 0.5,
            "lower metric should produce reward > 0.5, got {reward}"
        );
    }

    #[test]
    fn worse_performance_lower_reward() {
        let mut norm = RewardNormalizer::new(0.01);
        norm.normalize(10.0);
        let reward = norm.normalize(20.0);
        assert!(
            reward < 0.5,
            "higher metric should produce reward < 0.5, got {reward}"
        );
    }

    #[test]
    fn reward_always_clamped() {
        let mut norm = RewardNormalizer::new(0.1);
        norm.normalize(1.0);

        let reward_bad = norm.normalize(1000.0);
        assert!(
            (0.0..=1.0).contains(&reward_bad),
            "reward should be in [0, 1], got {reward_bad}"
        );

        let reward_good = norm.normalize(0.001);
        assert!(
            (0.0..=1.0).contains(&reward_good),
            "reward should be in [0, 1], got {reward_good}"
        );

        let reward_neg = norm.normalize(-5.0);
        assert!(
            (0.0..=1.0).contains(&reward_neg),
            "reward should be in [0, 1] even for negative metric, got {reward_neg}"
        );
    }

    #[test]
    fn baseline_tracks_values() {
        let mut norm = RewardNormalizer::new(0.5);
        norm.normalize(10.0);
        assert!(
            (norm.baseline() - 10.0).abs() < EPS,
            "initial baseline should be 10.0, got {}",
            norm.baseline()
        );

        for _ in 0..20 {
            norm.normalize(2.0);
        }
        assert!(
            norm.baseline() < 5.0,
            "baseline should have moved toward 2.0, got {}",
            norm.baseline()
        );

        for _ in 0..20 {
            norm.normalize(100.0);
        }
        assert!(
            norm.baseline() > 50.0,
            "baseline should have moved toward 100.0, got {}",
            norm.baseline()
        );
    }

    #[test]
    fn reset_clears_state() {
        let mut norm = RewardNormalizer::new(0.1);
        norm.normalize(10.0);
        norm.normalize(5.0);
        assert!(
            norm.baseline() > 0.0,
            "baseline should be non-zero before reset"
        );

        norm.reset();
        assert!(
            norm.baseline().abs() < EPS,
            "baseline should be 0 after reset, got {}",
            norm.baseline()
        );

        let reward = norm.normalize(42.0);
        assert!(
            (reward - 0.5).abs() < EPS,
            "after reset, first value should return 0.5, got {reward}"
        );
    }

    #[test]
    fn zero_baseline_returns_half() {
        // Relative change against a zero baseline is undefined, so the reward is neutral.
        let mut norm = RewardNormalizer::new(0.5);
        norm.normalize(0.0);
        let reward = norm.normalize(0.0);
        assert!(
            (reward - 0.5).abs() < EPS,
            "zero baseline should return 0.5, got {reward}"
        );
    }

    #[test]
    fn first_zero_value_returns_half() {
        let mut norm = RewardNormalizer::new(0.1);
        let reward = norm.normalize(0.0);
        assert!(
            (reward - 0.5).abs() < EPS,
            "first value of 0.0 should return 0.5, got {reward}"
        );
    }

    #[test]
    fn nonzero_after_zero_baseline_is_finite() {
        // A zero baseline must never lead to a division-by-zero reward.
        let mut norm = RewardNormalizer::new(1.0);
        norm.normalize(0.0);
        let reward = norm.normalize(5.0);
        assert!(reward.is_finite(), "reward should be finite, got {reward}");
    }

    #[test]
    fn with_span_alpha_computation() {
        let norm = RewardNormalizer::with_span(19);
        assert!(
            (norm.alpha - 0.1).abs() < EPS,
            "span=19 should give alpha=0.1, got {}",
            norm.alpha
        );

        let norm2 = RewardNormalizer::with_span(1);
        assert!(
            (norm2.alpha - 1.0).abs() < EPS,
            "span=1 should give alpha=1.0, got {}",
            norm2.alpha
        );

        let norm3 = RewardNormalizer::with_span(3);
        assert!(
            (norm3.alpha - 0.5).abs() < EPS,
            "span=3 should give alpha=0.5, got {}",
            norm3.alpha
        );
    }

    #[test]
    #[should_panic(expected = "alpha must be in (0, 1]")]
    fn new_panics_on_zero_alpha() {
        RewardNormalizer::new(0.0);
    }

    #[test]
    #[should_panic(expected = "alpha must be in (0, 1]")]
    fn new_panics_on_alpha_over_one() {
        RewardNormalizer::new(1.5);
    }

    #[test]
    #[should_panic(expected = "span must be > 0")]
    fn with_span_panics_on_zero() {
        RewardNormalizer::with_span(0);
    }
}