use serde::{Deserialize, Serialize};
/// A curve that maps elapsed time to a multiplicative decay factor.
///
/// Implementors decide how the factor evolves; callers typically multiply a
/// quantity by the returned factor to obtain its decayed value.
pub trait DecayCurve {
    /// Returns the decay factor after `elapsed` units of time.
    ///
    /// The unit of `elapsed` is whatever the implementor's parameters are
    /// expressed in (e.g. the same unit as a half-life).
    #[must_use]
    fn decay_factor(&self, elapsed: f64) -> f64;
}
/// Exponential decay described by its half-life.
///
/// Only `half_life` is serialized. `lambda` is derived from it and is skipped
/// by serde; deserialization rebuilds it from `half_life`.
#[derive(Debug, Clone, Copy, PartialEq, Serialize)]
pub struct ExponentialDecay {
    /// Elapsed time after which the decay factor has fallen to one half.
    half_life: f64,
    /// Decay rate, `ln(2) / half_life`; stored alongside the half-life so it
    /// does not have to be recomputed on every evaluation.
    #[serde(skip)]
    lambda: f64,
}
impl<'de> serde::Deserialize<'de> for ExponentialDecay {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
#[derive(serde::Deserialize)]
struct Raw {
half_life: f64,
}
let raw = Raw::deserialize(deserializer)?;
Ok(Self::new(raw.half_life))
}
}
impl ExponentialDecay {
    /// Builds a decay curve with the given half-life.
    ///
    /// A `half_life` that is not a finite, strictly positive number (zero,
    /// negative, NaN or infinite) is silently replaced by `1.0`. The decay
    /// rate `lambda = ln(2) / half_life` is computed once, here.
    #[inline]
    #[must_use]
    pub fn new(half_life: f64) -> Self {
        let usable = half_life.is_finite() && half_life > 0.0;
        let half_life = if usable { half_life } else { 1.0 };
        let lambda = core::f64::consts::LN_2 / half_life;
        Self { half_life, lambda }
    }

    /// Returns the half-life actually in use (after any clamping done by `new`).
    #[inline]
    #[must_use]
    pub fn half_life(&self) -> f64 {
        let Self { half_life, .. } = *self;
        half_life
    }

    /// Returns the decay rate `ln(2) / half_life`.
    #[inline]
    #[must_use]
    pub fn lambda(&self) -> f64 {
        let Self { lambda, .. } = *self;
        lambda
    }
}
impl DecayCurve for ExponentialDecay {
    /// Returns `e^(-lambda * elapsed)`, i.e. `0.5` after one half-life.
    ///
    /// Zero or negative `elapsed` yields exactly `1.0` (no decay). A NaN
    /// `elapsed` is not caught by that check and propagates to a NaN result.
    #[inline]
    fn decay_factor(&self, elapsed: f64) -> f64 {
        if elapsed <= 0.0 {
            1.0
        } else {
            let exponent = -self.lambda * elapsed;
            exponent.exp()
        }
    }
}
/// Parameters of a logistic (sigmoid) curve.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct LogisticCurve {
    /// Input value at which the curve evaluates to `0.5`.
    pub midpoint: f64,
    /// Slope parameter; a larger magnitude gives a sharper transition around
    /// `midpoint`, and a negative value reverses the curve's direction.
    pub steepness: f64,
}
impl LogisticCurve {
#[inline]
#[must_use]
pub const fn new(midpoint: f64, steepness: f64) -> Self {
Self {
midpoint,
steepness,
}
}
#[inline]
#[must_use]
pub fn evaluate(&self, x: f64) -> f64 {
1.0 / (1.0 + (-self.steepness * (x - self.midpoint)).exp())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exponential_no_elapsed() {
        let d = ExponentialDecay::new(10.0);
        assert!((d.decay_factor(0.0) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn exponential_half_life() {
        let d = ExponentialDecay::new(10.0);
        assert!((d.decay_factor(10.0) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn exponential_two_half_lives() {
        let d = ExponentialDecay::new(10.0);
        assert!((d.decay_factor(20.0) - 0.25).abs() < 1e-10);
    }

    #[test]
    fn exponential_negative_elapsed() {
        let d = ExponentialDecay::new(10.0);
        assert!((d.decay_factor(-5.0) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn exponential_invalid_half_life_clamped() {
        // Every non-finite or non-positive half-life falls back to 1.0, and the
        // cached rate must follow the clamped value.
        for bad in [0.0, -5.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let d = ExponentialDecay::new(bad);
            assert_eq!(d.half_life(), 1.0, "half_life {} should clamp to 1.0", bad);
            assert_eq!(d.lambda(), core::f64::consts::LN_2);
        }
    }

    #[test]
    fn exponential_lambda() {
        let d = ExponentialDecay::new(10.0);
        assert!((d.lambda() - core::f64::consts::LN_2 / 10.0).abs() < 1e-15);
    }

    #[test]
    fn exponential_serde_roundtrip() {
        let d = ExponentialDecay::new(300.0);
        let json = serde_json::to_string(&d).unwrap();
        // `lambda` is a derived cache and must not be written out.
        assert!(!json.contains("lambda"));
        let back: ExponentialDecay = serde_json::from_str(&json).unwrap();
        // Full equality (PartialEq covers `lambda` too) proves the custom
        // Deserialize rebuilt the rate instead of leaving it unset.
        assert_eq!(back, d);
    }

    #[test]
    fn exponential_deserialize_ignores_supplied_lambda() {
        let back: ExponentialDecay =
            serde_json::from_str(r#"{"half_life":10.0,"lambda":123.0}"#).unwrap();
        assert_eq!(back, ExponentialDecay::new(10.0));
    }

    #[test]
    fn exponential_deserialize_clamps_invalid_half_life() {
        let back: ExponentialDecay = serde_json::from_str(r#"{"half_life":-1.0}"#).unwrap();
        assert_eq!(back, ExponentialDecay::new(1.0));
    }

    #[test]
    fn logistic_midpoint() {
        let c = LogisticCurve::new(0.0, 4.0);
        assert!((c.evaluate(0.0) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn logistic_far_positive() {
        let c = LogisticCurve::new(0.0, 4.0);
        assert!(c.evaluate(5.0) > 0.99);
    }

    #[test]
    fn logistic_far_negative() {
        let c = LogisticCurve::new(0.0, 4.0);
        assert!(c.evaluate(-5.0) < 0.01);
    }

    #[test]
    fn logistic_shifted_midpoint() {
        let c = LogisticCurve::new(5.0, 2.0);
        assert!((c.evaluate(5.0) - 0.5).abs() < 1e-10);
    }

    #[test]
    fn logistic_steepness_effect() {
        let shallow = LogisticCurve::new(0.0, 1.0);
        let steep = LogisticCurve::new(0.0, 10.0);
        assert!(steep.evaluate(1.0) > shallow.evaluate(1.0));
    }

    #[test]
    fn logistic_negative_steepness_reverses() {
        let c = LogisticCurve::new(0.0, -4.0);
        assert!(c.evaluate(5.0) < 0.01);
        assert!(c.evaluate(-5.0) > 0.99);
    }

    #[test]
    fn logistic_serde_roundtrip() {
        let c = LogisticCurve::new(3.0, 2.5);
        let json = serde_json::to_string(&c).unwrap();
        let back: LogisticCurve = serde_json::from_str(&json).unwrap();
        assert_eq!(back, c);
    }
}