use serde::{Deserialize, Serialize};
use crate::error::{Result, validate_finite, validate_non_negative, validate_positive};
/// Instantaneous logistic growth rate `dN/dt = r · N · (1 − N/K)`.
///
/// * `n` – current population size.
/// * `r` – intrinsic growth rate (may be negative, must be finite).
/// * `k` – carrying capacity.
///
/// The rate is zero at `n == 0` and at `n == k`, positive in between for
/// `r > 0`, and negative once the population exceeds the carrying capacity.
///
/// # Errors
///
/// Returns an error if `n` is negative, `r` is not finite, or `k` is not
/// strictly positive (checked in that order by the `crate::error` validators).
#[inline]
#[must_use = "returns the population change without side effects"]
pub fn logistic_growth(n: f64, r: f64, k: f64) -> Result<f64> {
    validate_non_negative(n, "n")?;
    validate_finite(r, "r")?;
    validate_positive(k, "k")?;

    // Fraction of the carrying capacity that is still unused.
    let headroom = 1.0 - n / k;
    Ok(r * n * headroom)
}
/// Advances a logistic population by one forward-Euler step of length `dt`.
///
/// Computes `n + logistic_growth(n, r, k) · dt`. No clamping is applied, so a
/// large `r · dt` can overshoot the carrying capacity or drive the result
/// below zero.
///
/// # Errors
///
/// Propagates any error from [`logistic_growth`] (checked first), then
/// returns an error if `dt` is not strictly positive.
#[inline]
#[must_use = "returns the new population without side effects"]
pub fn logistic_growth_step(n: f64, r: f64, k: f64, dt: f64) -> Result<f64> {
    // Growth-rate validation must run before the `dt` check to keep error precedence.
    let rate = logistic_growth(n, r, k)?;
    validate_positive(dt, "dt")?;

    let increment = rate * dt;
    Ok(n + increment)
}
/// Snapshot of the three SIR compartments at a single point in time.
///
/// Values are in whatever units the caller chooses, for example population
/// fractions summing to 1 or head counts. The struct is `#[non_exhaustive]`,
/// so code outside this crate cannot use a struct literal and must build it
/// with [`SirState::new`].
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
#[non_exhaustive]
pub struct SirState {
    /// Susceptible compartment.
    pub s: f64,
    /// Infected compartment.
    pub i: f64,
    /// Recovered compartment.
    pub r: f64,
}

impl SirState {
    /// Creates a state from its susceptible, infected and recovered values.
    ///
    /// The values are stored as given; no validation is performed.
    #[inline]
    #[must_use]
    pub fn new(s: f64, i: f64, r: f64) -> Self {
        Self { s, i, r }
    }
}
/// Advances the SIR compartmental model by one forward-Euler step of size `dt`.
///
/// Both flows are computed from the state at the start of the step:
///
/// * new infections `β · S · I · dt` move from S to I;
/// * new recoveries `γ · I · dt` move from I to R.
///
/// Each flow is capped at the size of the compartment it drains. As a result,
/// no compartment ever becomes negative and `S + I + R` is conserved (up to
/// floating-point rounding), even when `dt` is too large for the explicit
/// scheme to be accurate. When no cap is hit, the result is the plain Euler
/// update.
///
/// # Errors
///
/// Returns an error if any of `s`, `i`, `r` is negative, or if any of `beta`,
/// `gamma`, `dt` is not strictly positive (checked in parameter order).
#[must_use = "returns the new SIR state without side effects"]
pub fn sir_step(s: f64, i: f64, r: f64, beta: f64, gamma: f64, dt: f64) -> Result<SirState> {
    validate_non_negative(s, "s")?;
    validate_non_negative(i, "i")?;
    validate_non_negative(r, "r")?;
    validate_positive(beta, "beta")?;
    validate_positive(gamma, "gamma")?;
    validate_positive(dt, "dt")?;

    // Cap each outflow at its source compartment instead of clamping the
    // updated compartments independently. Independent `max(0.0)` clamping
    // discards the overshoot from one compartment while still crediting it
    // to the next, which creates population out of nothing.
    let infections = (beta * s * i * dt).min(s);
    let recoveries = (gamma * i * dt).min(i);

    Ok(SirState {
        s: s - infections,
        // `recoveries <= i` and `infections >= 0`, so this stays non-negative.
        i: i + infections - recoveries,
        r: r + recoveries,
    })
}
/// Basic reproduction number `R0 = β / γ` of the SIR model.
///
/// `beta` is the transmission rate and `gamma` is the recovery rate.
///
/// # Errors
///
/// Returns an error if `beta` or `gamma` is not strictly positive (`beta` is
/// checked first).
#[inline]
#[must_use = "returns R0 without side effects"]
pub fn r_naught(beta: f64, gamma: f64) -> Result<f64> {
    validate_positive(beta, "beta")?;
    validate_positive(gamma, "gamma")?;

    let reproduction_number = beta / gamma;
    Ok(reproduction_number)
}
/// Herd-immunity threshold `1 − 1/R0` for a given basic reproduction number.
///
/// For `r0 < 1` the result is negative, meaning no immunity is required.
/// The value is returned unclamped; callers that want a fraction in `[0, 1)`
/// should clamp it themselves.
///
/// # Errors
///
/// Returns an error if `r0` is not strictly positive.
#[inline]
#[must_use = "returns the herd immunity threshold without side effects"]
pub fn herd_immunity_threshold(r0: f64) -> Result<f64> {
    validate_positive(r0, "r0")?;
    // `recip()` is defined as `1.0 / self`, so this is bit-for-bit `1.0 - 1.0 / r0`.
    Ok(1.0 - r0.recip())
}
#[must_use = "returns the SIR trajectory without side effects"]
pub fn sir_trajectory(
s0: f64,
i0: f64,
r0: f64,
beta: f64,
gamma: f64,
dt: f64,
steps: usize,
) -> Result<Vec<SirState>> {
let mut trajectory = Vec::with_capacity(steps + 1);
trajectory.push(SirState {
s: s0,
i: i0,
r: r0,
});
let mut state = SirState {
s: s0,
i: i0,
r: r0,
};
for _ in 0..steps {
state = sir_step(state.s, state.i, state.r, beta, gamma, dt)?;
trajectory.push(state);
}
Ok(trajectory)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_logistic_growth_at_zero() {
        let dn = logistic_growth(0.0, 0.5, 100.0).unwrap();
        assert!(dn.abs() < 1e-10);
    }

    #[test]
    fn test_logistic_growth_at_capacity() {
        let dn = logistic_growth(100.0, 0.5, 100.0).unwrap();
        assert!(dn.abs() < 1e-10);
    }

    #[test]
    fn test_logistic_growth_max_at_half_k() {
        let dn_half = logistic_growth(50.0, 1.0, 100.0).unwrap();
        let dn_quarter = logistic_growth(25.0, 1.0, 100.0).unwrap();
        let dn_three_quarter = logistic_growth(75.0, 1.0, 100.0).unwrap();
        assert!(dn_half > dn_quarter);
        assert!(dn_half > dn_three_quarter);
    }

    #[test]
    fn test_logistic_growth_rejects_zero_capacity() {
        assert!(logistic_growth(10.0, 0.5, 0.0).is_err());
    }

    #[test]
    fn test_r_naught() {
        let r0 = r_naught(0.5, 0.2).unwrap();
        assert!((r0 - 2.5).abs() < 1e-10);
    }

    #[test]
    fn test_herd_immunity_r0_3() {
        let h = herd_immunity_threshold(3.0).unwrap();
        assert!((h - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_herd_immunity_r0_2_5() {
        let h = herd_immunity_threshold(2.5).unwrap();
        assert!((h - 0.6).abs() < 1e-10);
    }

    #[test]
    fn test_herd_immunity_rejects_zero() {
        assert!(herd_immunity_threshold(0.0).is_err());
    }

    #[test]
    fn test_sir_step_conservation() {
        let state = sir_step(0.99, 0.01, 0.0, 0.5, 0.1, 0.01).unwrap();
        let total = state.s + state.i + state.r;
        assert!((total - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_sir_declining_epidemic() {
        let state = sir_step(0.99, 0.01, 0.0, 0.1, 0.5, 0.1).unwrap();
        assert!(state.i < 0.01);
    }

    #[test]
    fn test_sir_step_rejects_zero_dt() {
        assert!(sir_step(0.99, 0.01, 0.0, 0.5, 0.1, 0.0).is_err());
    }

    #[test]
    fn test_sir_state_serde_roundtrip() {
        let state = SirState {
            s: 0.9,
            i: 0.05,
            r: 0.05,
        };
        let json = serde_json::to_string(&state).unwrap();
        let back: SirState = serde_json::from_str(&json).unwrap();
        assert!((state.s - back.s).abs() < 1e-10);
        assert!((state.i - back.i).abs() < 1e-10);
        assert!((state.r - back.r).abs() < 1e-10);
    }

    #[test]
    fn test_logistic_growth_step() {
        let n = logistic_growth_step(50.0, 1.0, 100.0, 0.1).unwrap();
        assert!((n - 52.5).abs() < 1e-10);
    }

    #[test]
    fn test_logistic_growth_step_error() {
        assert!(logistic_growth_step(50.0, 1.0, 100.0, 0.0).is_err());
        assert!(logistic_growth_step(50.0, 1.0, 100.0, -1.0).is_err());
    }

    #[test]
    fn test_sir_trajectory_length() {
        let traj = sir_trajectory(0.99, 0.01, 0.0, 0.5, 0.1, 0.01, 100).unwrap();
        assert_eq!(traj.len(), 101);
    }

    #[test]
    fn test_sir_trajectory_zero_steps() {
        let traj = sir_trajectory(0.99, 0.01, 0.0, 0.5, 0.1, 0.01, 0).unwrap();
        assert_eq!(traj.len(), 1);
        assert!((traj[0].s - 0.99).abs() < 1e-10);
        assert!((traj[0].i - 0.01).abs() < 1e-10);
        assert!(traj[0].r.abs() < 1e-10);
    }

    #[test]
    fn test_sir_trajectory_conservation() {
        let traj = sir_trajectory(0.99, 0.01, 0.0, 0.5, 0.1, 0.01, 50).unwrap();
        for state in &traj {
            let total = state.s + state.i + state.r;
            assert!((total - 1.0).abs() < 0.05);
        }
    }

    #[test]
    fn test_herd_immunity_r0_below_1() {
        let h = herd_immunity_threshold(0.5).unwrap();
        assert!(h < 0.0);
    }

    #[test]
    fn test_sir_state_new() {
        let s = SirState::new(0.9, 0.05, 0.05);
        assert!((s.s - 0.9).abs() < 1e-10);
        assert!((s.i - 0.05).abs() < 1e-10);
        assert!((s.r - 0.05).abs() < 1e-10);
    }

    #[test]
    fn test_r_naught_error() {
        assert!(r_naught(0.0, 0.5).is_err());
        assert!(r_naught(0.5, 0.0).is_err());
    }
}