/// Online (streaming) temperature scaling for post-hoc calibration of
/// classifier logits.
///
/// Maintains a single scalar temperature `T` and refines it one labeled
/// observation at a time with SGD on the negative log-likelihood of the
/// temperature-scaled softmax (see `update`). `T = 1` is the identity
/// calibration; `T > 1` flattens the output distribution, `T < 1`
/// sharpens it.
#[derive(Debug, Clone)]
pub struct OnlineTemperatureScaling {
/// Current temperature `T`; starts at 1.0 and is kept in [0.01, 100.0].
temperature: f64,
/// SGD step size applied on every `update` call.
lr: f64,
/// Number of `update` calls since construction or the last `reset`.
n_updates: u64,
}
impl OnlineTemperatureScaling {
    /// Creates a calibrator with the given SGD learning rate and an
    /// initial temperature of 1.0 (identity calibration).
    ///
    /// # Panics
    /// Panics if `lr` is not a finite value strictly greater than zero.
    pub fn new(lr: f64) -> Self {
        // Also reject +inf (the old `lr > 0.0` check let it through): a
        // non-finite step would turn `temperature` into NaN on the first
        // update, and NaN survives `clamp`, poisoning the state forever.
        assert!(
            lr.is_finite() && lr > 0.0,
            "learning rate must be > 0 and finite, got {lr}"
        );
        Self {
            temperature: 1.0,
            lr,
            n_updates: 0,
        }
    }

    /// Returns calibrated probabilities `softmax(logits / T)`.
    ///
    /// With `T == 1` this is a plain softmax. Returns an empty vector
    /// when `logits` is empty.
    pub fn calibrate(&self, logits: &[f64]) -> Vec<f64> {
        // T is clamped to [0.01, 100.0], so this division is always safe.
        let inv_t = 1.0 / self.temperature;
        let scaled: Vec<f64> = logits.iter().map(|&z| z * inv_t).collect();
        stable_softmax(&scaled)
    }

    /// Performs one SGD step on the temperature using the negative
    /// log-likelihood of the observed `true_class`.
    ///
    /// For L(T) = -log softmax(z/T)[c] the gradient is
    /// dL/dT = (z_c - E_p[z]) / T^2 with p = softmax(z/T), so confident
    /// wrong predictions (z_c below the probability-weighted mean logit)
    /// push T up, and confident correct ones pull it down.
    ///
    /// # Panics
    /// Panics if `true_class >= logits.len()` (including empty `logits`).
    pub fn update(&mut self, logits: &[f64], true_class: usize) {
        // A hard assert (was `debug_assert!`): release builds previously
        // fell through to a bare index-out-of-bounds panic with no
        // context; now they fail with the same descriptive message.
        assert!(
            true_class < logits.len(),
            "true_class {} out of range for {} logits",
            true_class,
            logits.len(),
        );
        let proba = self.calibrate(logits);
        let z_c = logits[true_class];
        // E_p[z]: mean logit under the calibrated distribution.
        let weighted_mean: f64 = logits.iter().zip(proba.iter()).map(|(&z, &p)| z * p).sum();
        let t_sq = self.temperature * self.temperature;
        let grad = (z_c - weighted_mean) / t_sq;
        self.temperature -= self.lr * grad;
        // Keep T in a sane positive range: `calibrate` never divides by
        // zero and a burst of extreme gradients cannot run away.
        self.temperature = self.temperature.clamp(0.01, 100.0);
        self.n_updates += 1;
    }

    /// Current temperature `T`.
    #[inline]
    pub fn temperature(&self) -> f64 {
        self.temperature
    }

    /// Number of `update` calls since construction or the last `reset`.
    #[inline]
    pub fn n_updates(&self) -> u64 {
        self.n_updates
    }

    /// Restores the initial state (`T = 1.0`, counter zeroed).
    /// The learning rate is kept.
    pub fn reset(&mut self) {
        self.temperature = 1.0;
        self.n_updates = 0;
    }
}
impl Default for OnlineTemperatureScaling {
fn default() -> Self {
Self::new(0.01)
}
}
/// Numerically stable softmax over `logits`.
///
/// Shifts every logit by the maximum before exponentiating so large
/// positive values cannot overflow `exp`. Returns an empty vector for
/// empty input.
fn stable_softmax(logits: &[f64]) -> Vec<f64> {
    let mut max_logit = f64::NEG_INFINITY;
    for &z in logits {
        if z > max_logit {
            max_logit = z;
        }
    }
    let mut probs = Vec::with_capacity(logits.len());
    let mut total = 0.0;
    for &z in logits {
        let e = (z - max_logit).exp();
        total += e;
        probs.push(e);
    }
    for p in &mut probs {
        *p /= total;
    }
    probs
}
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-10;

    #[test]
    fn initial_temperature_is_one() {
        let ts = OnlineTemperatureScaling::new(0.01);
        assert!(
            (ts.temperature() - 1.0).abs() < EPS,
            "initial temperature should be 1.0, got {}",
            ts.temperature()
        );
        assert_eq!(ts.n_updates(), 0);
    }

    #[test]
    fn calibrate_with_t_equals_one_is_softmax() {
        let ts = OnlineTemperatureScaling::new(0.01);
        let logits = vec![2.0, 1.0, 0.5];
        let calibrated = ts.calibrate(&logits);
        let expected = stable_softmax(&logits);
        for (c, e) in calibrated.iter().zip(expected.iter()) {
            assert!(
                (c - e).abs() < EPS,
                "calibrate with T=1 should equal softmax: {c} vs {e}"
            );
        }
    }

    #[test]
    fn calibrate_sums_to_one() {
        let ts = OnlineTemperatureScaling::new(0.01);
        let logits = vec![3.0, 1.0, -2.0, 0.5];
        let calibrated = ts.calibrate(&logits);
        let sum: f64 = calibrated.iter().sum();
        assert!(
            (sum - 1.0).abs() < EPS,
            "calibrated probabilities should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn higher_temperature_flattens_distribution() {
        let mut ts = OnlineTemperatureScaling::new(0.01);
        let logits = vec![5.0, 1.0, 0.0];
        let sharp = ts.calibrate(&logits);
        ts.temperature = 5.0;
        let flat = ts.calibrate(&logits);
        let max_sharp = sharp.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let max_flat = flat.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_sharp > max_flat,
            "higher T should flatten distribution: max_sharp={max_sharp} > max_flat={max_flat}"
        );
    }

    #[test]
    fn lower_temperature_sharpens_distribution() {
        let mut ts = OnlineTemperatureScaling::new(0.01);
        let logits = vec![2.0, 1.0, 0.5];
        let normal = ts.calibrate(&logits);
        ts.temperature = 0.1;
        let sharp = ts.calibrate(&logits);
        let max_normal = normal.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let max_sharp = sharp.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!(
            max_sharp > max_normal,
            "lower T should sharpen distribution: max_sharp={max_sharp} > max_normal={max_normal}"
        );
    }

    #[test]
    fn update_adjusts_temperature() {
        let mut ts = OnlineTemperatureScaling::new(0.1);
        let t_before = ts.temperature();
        let logits = vec![5.0, 1.0, 0.0];
        ts.update(&logits, 2);
        assert_ne!(
            ts.temperature(),
            t_before,
            "temperature should change after update"
        );
        assert_eq!(ts.n_updates(), 1);
    }

    #[test]
    fn overconfident_model_increases_temperature() {
        let mut ts = OnlineTemperatureScaling::new(0.05);
        // Class 0 is confidently predicted but the true class alternates
        // between 1 and 2 — consistently wrong, so T should grow.
        for _ in 0..100 {
            ts.update(&[10.0, 0.0, 0.0], 1);
            ts.update(&[10.0, 0.0, 0.0], 2);
        }
        assert!(
            ts.temperature() > 1.0,
            "overconfident wrong predictions should increase T, got {}",
            ts.temperature()
        );
    }

    #[test]
    fn temperature_stays_positive_after_many_updates() {
        let mut ts = OnlineTemperatureScaling::new(0.1);
        for i in 0..1000 {
            let true_class = i % 3;
            ts.update(&[1.0, 2.0, 3.0], true_class);
        }
        assert!(
            ts.temperature() > 0.0,
            "temperature must stay positive, got {}",
            ts.temperature()
        );
        assert!(
            ts.temperature().is_finite(),
            "temperature must be finite, got {}",
            ts.temperature()
        );
    }

    #[test]
    fn reset_restores_default_state() {
        let mut ts = OnlineTemperatureScaling::new(0.01);
        ts.update(&[1.0, 0.0], 0);
        ts.update(&[0.0, 1.0], 1);
        assert!(ts.n_updates() > 0);
        ts.reset();
        assert!(
            (ts.temperature() - 1.0).abs() < EPS,
            "temperature should be 1.0 after reset, got {}",
            ts.temperature()
        );
        assert_eq!(ts.n_updates(), 0, "n_updates should be 0 after reset");
    }

    #[test]
    fn default_uses_lr_0_01() {
        let ts = OnlineTemperatureScaling::default();
        assert!(
            (ts.temperature() - 1.0).abs() < EPS,
            "default temperature should be 1.0"
        );
    }

    #[test]
    #[should_panic(expected = "learning rate must be > 0")]
    fn panics_on_zero_lr() {
        let _ = OnlineTemperatureScaling::new(0.0);
    }

    #[test]
    #[should_panic(expected = "learning rate must be > 0")]
    fn panics_on_negative_lr() {
        let _ = OnlineTemperatureScaling::new(-0.01);
    }

    #[test]
    fn extreme_logits_dont_cause_nan() {
        let mut ts = OnlineTemperatureScaling::new(0.01);
        let logits = vec![1000.0, -1000.0, 0.0];
        let calibrated = ts.calibrate(&logits);
        assert!(
            calibrated.iter().all(|p| p.is_finite()),
            "calibrate should be finite for extreme logits"
        );
        ts.update(&logits, 0);
        assert!(
            ts.temperature().is_finite(),
            "temperature should be finite after extreme update"
        );
    }
}