use serde::{Deserialize, Serialize};
use crate::error::{BodhError, Result, validate_finite, validate_non_negative, validate_positive};
/// Ebbinghaus forgetting curve: the fraction of material retained after
/// `time` has elapsed, `R = exp(-time / stability)`.
///
/// `time` and `stability` must be expressed in the same unit; only their
/// ratio matters. At `time == 0` the result is exactly `1.0`, and it decays
/// toward `0.0` as `time` grows.
///
/// # Errors
///
/// Propagates the error from `validate_non_negative` for `time` and from
/// `validate_positive` for `stability` (presumably a negative `time` or a
/// non-positive `stability` — confirm in `crate::error`). `time` is checked
/// first.
#[inline]
#[must_use = "returns the retention value without side effects"]
pub fn ebbinghaus_forgetting(time: f64, stability: f64) -> Result<f64> {
    validate_non_negative(time, "time")?;
    validate_positive(stability, "stability")?;
    let decay_exponent = -(time / stability);
    Ok(decay_exponent.exp())
}
/// Ebbinghaus forgetting curve scaled by an initial memory `strength`:
/// `strength * exp(-time / stability)`.
///
/// # Errors
///
/// Propagates the error from `validate_finite` for `strength` first, then
/// any error from [`ebbinghaus_forgetting`] for `time` / `stability`.
#[inline]
#[must_use = "returns the retention value without side effects"]
pub fn ebbinghaus_forgetting_full(strength: f64, time: f64, stability: f64) -> Result<f64> {
    validate_finite(strength, "strength")?;
    ebbinghaus_forgetting(time, stability).map(|fraction_retained| strength * fraction_retained)
}
/// Review interval for the given repetition in a simple geometric
/// spaced-repetition schedule.
///
/// The interval is `1.0 * ease^(repetition - 1)`. The first repetition
/// yields the base interval of `1.0`, and each later repetition multiplies
/// it by `ease`. Very large repetition counts saturate naturally: toward
/// `inf` for `ease > 1`, stay at `1.0` for `ease == 1`, or decay toward `0.0`.
///
/// # Errors
///
/// - Propagates the error from `validate_positive` for `ease` (presumably
///   returned when `ease` is not strictly positive — confirm in
///   `crate::error`). This check runs first.
/// - Returns [`BodhError::InvalidParameter`] if `repetition` is `0`.
#[inline]
#[must_use = "returns the interval without side effects"]
pub fn spaced_repetition_interval(repetition: u32, ease: f64) -> Result<f64> {
    // Interval of the first repetition; later ones scale it geometrically.
    const BASE_INTERVAL: f64 = 1.0;

    validate_positive(ease, "ease")?;
    if repetition == 0 {
        return Err(BodhError::InvalidParameter(
            "repetition must be at least 1".into(),
        ));
    }
    // Cannot underflow: `repetition >= 1` was checked above.
    let exponent = repetition - 1;
    // `f64::powi` takes an `i32`. Casting a `u32` above `i32::MAX` with `as`
    // would wrap to a negative exponent, and the old `as i32 - 1` form could
    // even overflow at `i32::MIN`. Exponents that do not fit in `i32` use
    // `powf` instead; every `u32` converts to `f64` exactly.
    let growth = if exponent <= i32::MAX as u32 {
        ease.powi(exponent as i32)
    } else {
        ease.powf(f64::from(exponent))
    };
    Ok(BASE_INTERVAL * growth)
}
/// Power law of practice: the time needed to perform a task on the given
/// `trial`, `T(n) = initial_time * n^(-learning_rate)`.
///
/// Trial `1` returns `initial_time` unchanged. Later trials get faster at a
/// rate set by `learning_rate`.
///
/// # Errors
///
/// - Returns `BodhError::InvalidParameter` if `trial` is `0`. This check
///   runs first.
/// - Propagates the error from `validate_positive` for `initial_time` and
///   then `learning_rate` (presumably returned for non-positive values —
///   confirm in `crate::error`).
#[inline]
#[must_use = "returns the performance time without side effects"]
pub fn power_law_of_practice(trial: u32, initial_time: f64, learning_rate: f64) -> Result<f64> {
    if trial < 1 {
        return Err(BodhError::InvalidParameter("trial must be at least 1".into()));
    }
    validate_positive(initial_time, "initial_time")?;
    validate_positive(learning_rate, "learning_rate")?;
    // Lossless widening: every u32 is exactly representable as f64.
    let trial_count = f64::from(trial);
    let speedup_factor = trial_count.powf(-learning_rate);
    Ok(initial_time * speedup_factor)
}
/// Schedules by which a response can be reinforced, in the
/// operant-conditioning sense.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include
/// a wildcard arm when matching on it.
///
/// NOTE(review): the variant order is deliberately left unchanged. Serde
/// formats that encode unit variants by index would break if it moved.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ReinforcementSchedule {
    /// Fixed-ratio schedule (conventionally: reinforcement after a set number of responses).
    FixedRatio,
    /// Variable-ratio schedule (conventionally: after a varying number of responses).
    VariableRatio,
    /// Fixed-interval schedule (conventionally: the first response after a set time).
    FixedInterval,
    /// Variable-interval schedule (conventionally: the first response after a varying time).
    VariableInterval,
    /// Continuous schedule (conventionally: every response is reinforced).
    Continuous,
}

impl ReinforcementSchedule {
    /// Returns a relative score for how strongly behavior learned under this
    /// schedule resists extinction; larger means more resistant.
    ///
    /// The scores are fixed constants in `[0.2, 0.9]`. Variable schedules
    /// score highest and the continuous schedule scores lowest. They are
    /// relative rankings rather than values fitted to data here.
    #[inline]
    #[must_use]
    pub fn extinction_resistance(self) -> f64 {
        // Arms listed from most to least resistant.
        let score = match self {
            Self::VariableRatio => 0.9,
            Self::VariableInterval => 0.8,
            Self::FixedRatio => 0.5,
            Self::FixedInterval => 0.4,
            Self::Continuous => 0.2,
        };
        score
    }
}
/// Rescorla–Wagner learning rule: the change in associative strength on one
/// trial, `ΔV = alpha * beta * (lambda - current_v)`.
///
/// The result is positive while `current_v` is below `lambda` (acquisition),
/// zero at `lambda`, and negative above it.
///
/// # Errors
///
/// Propagates the error from `validate_finite` for the first non-finite
/// argument, checked in the order `alpha`, `beta`, `lambda`, `current_v`.
#[inline]
#[must_use = "returns the change in associative strength without side effects"]
pub fn rescorla_wagner(alpha: f64, beta: f64, lambda: f64, current_v: f64) -> Result<f64> {
    let checked = [
        (alpha, "alpha"),
        (beta, "beta"),
        (lambda, "lambda"),
        (current_v, "current_v"),
    ];
    for (value, label) in checked {
        validate_finite(value, label)?;
    }
    let prediction_error = lambda - current_v;
    Ok(alpha * beta * prediction_error)
}
// Unit tests for the learning-theory formulas in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ebbinghaus_at_t0() {
        let r = ebbinghaus_forgetting(0.0, 1.0).unwrap();
        assert!((r - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_ebbinghaus_decay() {
        let r = ebbinghaus_forgetting(1.0, 1.0).unwrap();
        assert!((r - (-1.0_f64).exp()).abs() < 1e-10);
    }

    #[test]
    fn test_ebbinghaus_full_with_strength() {
        let r = ebbinghaus_forgetting_full(0.8, 0.0, 1.0).unwrap();
        assert!((r - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_ebbinghaus_full_scales_retention() {
        let base = ebbinghaus_forgetting(1.0, 2.0).unwrap();
        let scaled = ebbinghaus_forgetting_full(0.5, 1.0, 2.0).unwrap();
        assert!((scaled - 0.5 * base).abs() < 1e-12);
    }

    #[test]
    fn test_ebbinghaus_monotonic_decay() {
        let r1 = ebbinghaus_forgetting(1.0, 2.0).unwrap();
        let r2 = ebbinghaus_forgetting(2.0, 2.0).unwrap();
        let r3 = ebbinghaus_forgetting(5.0, 2.0).unwrap();
        assert!(r1 > r2);
        assert!(r2 > r3);
    }

    #[test]
    fn test_spaced_repetition_first() {
        let interval = spaced_repetition_interval(1, 2.5).unwrap();
        assert!((interval - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_spaced_repetition_second_equals_ease() {
        let interval = spaced_repetition_interval(2, 2.5).unwrap();
        assert!((interval - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_spaced_repetition_grows() {
        let i1 = spaced_repetition_interval(1, 2.5).unwrap();
        let i2 = spaced_repetition_interval(2, 2.5).unwrap();
        let i3 = spaced_repetition_interval(3, 2.5).unwrap();
        assert!(i2 > i1);
        assert!(i3 > i2);
    }

    #[test]
    fn test_spaced_repetition_rejects_zero() {
        assert!(spaced_repetition_interval(0, 2.5).is_err());
    }

    #[test]
    fn test_power_law_of_practice() {
        let t = power_law_of_practice(1, 10.0, 0.3).unwrap();
        assert!((t - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_power_law_exact_value() {
        // 10 * 4^-0.5 = 10 / 2 = 5
        let t = power_law_of_practice(4, 10.0, 0.5).unwrap();
        assert!((t - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_power_law_improvement() {
        let t1 = power_law_of_practice(1, 10.0, 0.3).unwrap();
        let t10 = power_law_of_practice(10, 10.0, 0.3).unwrap();
        assert!(t10 < t1);
    }

    #[test]
    fn test_power_law_rejects_zero_trial() {
        assert!(power_law_of_practice(0, 10.0, 0.3).is_err());
    }

    #[test]
    fn test_reinforcement_schedule_extinction() {
        let vr = ReinforcementSchedule::VariableRatio;
        let cont = ReinforcementSchedule::Continuous;
        assert!(vr.extinction_resistance() > cont.extinction_resistance());
    }

    #[test]
    fn test_rescorla_wagner_acquisition() {
        let dv = rescorla_wagner(0.5, 0.5, 1.0, 0.0).unwrap();
        assert!((dv - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_rescorla_wagner_asymptote() {
        let dv = rescorla_wagner(0.5, 0.5, 1.0, 1.0).unwrap();
        assert!((dv - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_rescorla_wagner_extinction() {
        // Strength above the asymptote produces a negative update.
        let dv = rescorla_wagner(0.5, 0.5, 0.0, 1.0).unwrap();
        assert!((dv + 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_reinforcement_schedule_serde_roundtrip() {
        let sched = ReinforcementSchedule::VariableRatio;
        let json = serde_json::to_string(&sched).unwrap();
        let back: ReinforcementSchedule = serde_json::from_str(&json).unwrap();
        assert_eq!(sched, back);
    }
}