use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
/// Online estimator of a decay rate built on Gamma-style `alpha` / `beta` parameters.
///
/// The point estimate is `alpha / beta` (see `predict`) and the spread is
/// `alpha / beta^2` (see `uncertainty`), i.e. the mean and variance of a
/// Gamma(shape = alpha, rate = beta) distribution.
#[derive(Debug, Clone)]
pub struct DecayPredictor {
// Shape parameter given at construction; `reset` blends `alpha` back toward it.
alpha_prior: f32,
// Rate parameter given at construction; `reset` blends `beta` back toward it.
beta_prior: f32,
// Current (posterior) shape parameter, updated by `observe`.
alpha: f32,
// Current (posterior) rate parameter, updated by `observe`.
beta: f32,
/// Factor multiplied into the posterior mean by `predict`; starts at `1.0`.
pub global_multiplier: f32,
// Number of times `observe` has been called, including calls whose value was ignored.
observation_count: u64,
// Random source for `sample`, seeded via `SmallRng::from_entropy`.
rng: SmallRng,
}
impl DecayPredictor {
    /// Creates a predictor with the default prior `alpha = 1.0`, `beta = 100.0`
    /// (prior mean decay `0.01`).
    pub fn new() -> Self {
        Self::with_prior(1.0, 100.0)
    }

    /// Creates a predictor whose posterior starts at the given prior.
    ///
    /// Both `alpha` and `beta` are expected to be positive: [`predict`](Self::predict)
    /// divides by `beta`, and the update rule in [`observe`](Self::observe) assumes
    /// `alpha > 0`.
    pub fn with_prior(alpha: f32, beta: f32) -> Self {
        Self {
            alpha_prior: alpha,
            beta_prior: beta,
            alpha,
            beta,
            global_multiplier: 1.0,
            observation_count: 0,
            rng: SmallRng::from_entropy(),
        }
    }

    /// Point estimate of the decay rate: the posterior mean `alpha / beta`,
    /// scaled by `global_multiplier`.
    pub fn predict(&self) -> f32 {
        (self.alpha / self.beta) * self.global_multiplier
    }

    /// Draws a randomized decay estimate for exploration.
    ///
    /// The result is [`predict`](Self::predict) perturbed by uniform noise in
    /// `[-sd, sd)`, where `sd = sqrt(uncertainty())`. This is an approximation
    /// rather than an exact draw from the Gamma posterior. The result is floored
    /// at `0.0001`, so it is always positive.
    ///
    /// NOTE(review): the noise is not scaled by `global_multiplier` while the
    /// mean is — confirm that is intended.
    pub fn sample(&mut self) -> f32 {
        let mean = self.predict();
        let std_dev = self.uncertainty().sqrt();
        let noise: f32 = self.rng.gen_range(-1.0..1.0);
        (mean + std_dev * noise).max(0.0001)
    }

    /// Folds one measured decay rate into the posterior.
    ///
    /// Every call increments [`observations`](Self::observations). Only finite,
    /// strictly positive values update the estimate; zero, negative, NaN and
    /// infinite values are ignored. The posterior mean follows a running average
    /// of the accepted values, and `alpha` equals `alpha_prior` plus the effective
    /// number of accepted observations.
    pub fn observe(&mut self, actual_decay: f32) {
        self.observation_count += 1;
        if !(actual_decay.is_finite() && actual_decay > 0.0) {
            return;
        }
        // Effective number of accepted observations, including this one.
        // After every update `alpha == alpha_prior + n`, so `n` is recovered from
        // `alpha` rather than from `observation_count`. Recovering it this way
        // (a) keeps rejected values from diluting the running mean, and
        // (b) lets `reset`, which shrinks `alpha`, also shrink the weight of past
        // data instead of being undone on the next call.
        let n = (self.alpha - self.alpha_prior).max(0.0) + 1.0;
        let old_mean = self.alpha / self.beta;
        let new_mean = old_mean + (actual_decay - old_mean) / n;
        self.alpha = self.alpha_prior + n;
        // Floor the mean so `beta` stays finite.
        self.beta = self.alpha / new_mean.max(0.0001);
    }

    /// Partially forgets learned state by blending the posterior toward the prior.
    ///
    /// `blend_factor` is clamped to `[0, 1]`. A value of `1.0` keeps the current
    /// posterior, and `0.0` restores the prior exactly. The raw
    /// [`observations`](Self::observations) counter is not changed.
    pub fn reset(&mut self, blend_factor: f32) {
        let keep = blend_factor.clamp(0.0, 1.0);
        self.alpha = self.alpha * keep + self.alpha_prior * (1.0 - keep);
        self.beta = self.beta * keep + self.beta_prior * (1.0 - keep);
    }

    /// Posterior variance `alpha / beta^2` (not scaled by `global_multiplier`).
    pub fn uncertainty(&self) -> f32 {
        self.alpha / (self.beta * self.beta)
    }

    /// Total number of [`observe`](Self::observe) calls, including ignored values.
    pub fn observations(&self) -> u64 {
        self.observation_count
    }
}
impl Default for DecayPredictor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} +/- {tol}, got {actual}"
        );
    }

    #[test]
    fn test_predictor_convergence() {
        // The target (0.03) deliberately differs from the default prior mean (0.01),
        // so this only passes if observations actually move the estimate.
        let mut predictor = DecayPredictor::new();
        for i in 0..100 {
            predictor.observe(if i % 2 == 0 { 0.02 } else { 0.04 });
        }
        assert_close(predictor.predict(), 0.03, 1e-4);
    }

    #[test]
    fn test_uncertainty_decreases() {
        let mut predictor = DecayPredictor::new();
        let initial_uncertainty = predictor.uncertainty();
        for _ in 0..10 {
            predictor.observe(0.01);
        }
        assert!(predictor.uncertainty() < initial_uncertainty);
    }

    #[test]
    fn test_thompson_sampling() {
        // A single draw proves little; check that many draws are all usable.
        let mut predictor = DecayPredictor::new();
        for _ in 0..1_000 {
            let sample = predictor.sample();
            assert!(sample.is_finite() && sample > 0.0, "bad sample {sample}");
        }
    }

    #[test]
    fn test_non_positive_observations_leave_prediction_unchanged() {
        let mut predictor = DecayPredictor::new();
        let before = predictor.predict();
        predictor.observe(0.0);
        predictor.observe(-0.5);
        assert_close(predictor.predict(), before, 1e-9);
        assert_eq!(predictor.observations(), 2);
    }

    #[test]
    fn test_full_reset_restores_prior() {
        let fresh = DecayPredictor::new();
        let mut predictor = DecayPredictor::new();
        for _ in 0..20 {
            predictor.observe(0.05);
        }
        predictor.reset(0.0);
        assert_close(predictor.predict(), fresh.predict(), 1e-7);
        assert_close(predictor.uncertainty(), fresh.uncertainty(), 1e-9);
    }

    #[test]
    fn test_global_multiplier_scales_prediction() {
        let mut predictor = DecayPredictor::new();
        let base = predictor.predict();
        predictor.global_multiplier = 2.5;
        assert_close(predictor.predict(), base * 2.5, 1e-7);
    }
}