use crate::algorithms::properties::is_simple::{SimpleMode, is_simple_with_mode};
use crate::core::rng::SplitMix64;
use crate::core::{Graph, IgraphError, IgraphResult};
/// Trajectory of a single SIR simulation run.
///
/// The four vectors always have the same length. Entry `0` is the initial
/// state at time `0.0`; every later entry is the state right after one event.
#[derive(Clone, Debug, PartialEq)]
pub struct Sir {
    /// Time at which each recorded state begins.
    pub times: Vec<f64>,
    /// Number of susceptible vertices in each recorded state.
    pub no_s: Vec<usize>,
    /// Number of infected vertices in each recorded state.
    pub no_i: Vec<usize>,
    /// Number of recovered vertices in each recorded state.
    pub no_r: Vec<usize>,
}
/// Partial-sum structure over `n` weighted slots, backed by a Fenwick
/// (binary indexed) tree. Supports point updates and picking a slot by
/// cumulative weight.
struct PsumTree {
    /// Number of slots.
    n: usize,
    /// 1-based Fenwick array: `bit[k]` holds the sum of the `k & k.wrapping_neg()`
    /// slot weights ending at slot `k - 1`.
    bit: Vec<f64>,
    /// Current weight of each slot.
    values: Vec<f64>,
    /// Sum of all weights, maintained incrementally by `set`.
    total: f64,
}

impl PsumTree {
    /// Creates a tree with `n` slots, all of weight zero.
    fn new(n: usize) -> Self {
        PsumTree {
            total: 0.0,
            values: vec![0.0; n],
            bit: vec![0.0; n + 1],
            n,
        }
    }

    /// Returns the weight currently stored in slot `i`.
    fn get(&self, i: usize) -> f64 {
        self.values[i]
    }

    /// Returns the incrementally maintained sum of all weights.
    fn total(&self) -> f64 {
        self.total
    }

    /// Stores weight `v` in slot `i` and propagates the difference to every
    /// Fenwick node covering that slot.
    fn set(&mut self, i: usize, v: f64) {
        let previous = self.values[i];
        self.values[i] = v;
        let delta = v - previous;
        self.total += delta;
        let mut node = i + 1;
        while node <= self.n {
            self.bit[node] += delta;
            node += node & node.wrapping_neg();
        }
    }

    /// Sets every weight, every Fenwick node and the total back to zero.
    fn reset(&mut self) {
        self.bit.fill(0.0);
        self.values.fill(0.0);
        self.total = 0.0;
    }

    /// Returns the slot whose cumulative-weight interval contains `target`.
    ///
    /// The search descends from the largest power of two not exceeding `n`. It
    /// moves past every Fenwick node whose sum is `<= target`, subtracting that
    /// sum as it goes. With non-negative weights and exact arithmetic, a target
    /// in `[0, total)` therefore never lands on a zero-weight slot.
    ///
    /// Targets at or beyond the tree's sum are clamped to the last slot `n - 1`.
    /// An empty tree returns `0`.
    fn search(&self, target: f64) -> usize {
        if self.n == 0 {
            return 0;
        }
        // Exponent of the highest power of two that is <= n.
        let top = usize::BITS - 1 - self.n.leading_zeros();
        let mut pos = 0usize;
        let mut left = target;
        for shift in (0..=top).rev() {
            let candidate = pos + (1usize << shift);
            if candidate <= self.n && self.bit[candidate] <= left {
                left -= self.bit[candidate];
                pos = candidate;
            }
        }
        pos.min(self.n - 1)
    }
}
/// Vertex state: susceptible (not yet infected).
const S_S: u8 = 0;
/// Vertex state: infected (recovers at rate `gamma`, infects susceptible neighbours).
const S_I: u8 = S_S + 1;
/// Vertex state: recovered (carries no further rate).
const S_R: u8 = S_I + 1;
/// Simulates the SIR (susceptible-infected-recovered) epidemic model on `graph`.
///
/// Each of the `no_sim` runs starts with one randomly chosen infected vertex.
/// It proceeds event by event until no infected vertex remains. Edge directions
/// are ignored.
///
/// # Arguments
///
/// * `beta` - infection rate per infected neighbour; must be finite and non-negative.
/// * `gamma` - recovery rate; must be finite and strictly positive.
/// * `no_sim` - number of independent runs; must be positive.
/// * `seed` - seed of the random number generator shared by all runs.
///
/// # Errors
///
/// Returns [`IgraphError::InvalidArgument`] in these cases:
///
/// * the graph has no vertices;
/// * `beta` or `gamma` is out of range, including NaN or infinite values;
/// * `no_sim` is zero;
/// * the graph, viewed as undirected, is not simple.
///
/// Errors from the simplicity check and from building the adjacency list are
/// propagated.
pub fn sir(
    graph: &Graph,
    beta: f64,
    gamma: f64,
    no_sim: usize,
    seed: u64,
) -> IgraphResult<Vec<Sir>> {
    let n = graph.vcount() as usize;
    if n == 0 {
        return Err(IgraphError::InvalidArgument(
            "Cannot run SIR model on empty graph.".to_string(),
        ));
    }
    // Plain `< 0.0` / `<= 0.0` comparisons let NaN through, and infinite rates
    // are accepted too. Either one turns the partial sums into NaN (e.g.
    // `inf - inf`), after which event selection breaks and a run may never end.
    if !beta.is_finite() || beta < 0.0 {
        return Err(IgraphError::InvalidArgument(format!(
            "The infection rate beta must be finite and non-negative (got {beta})."
        )));
    }
    if !gamma.is_finite() || gamma <= 0.0 {
        return Err(IgraphError::InvalidArgument(format!(
            "The recovery rate gamma must be finite and positive (got {gamma})."
        )));
    }
    if no_sim == 0 {
        return Err(IgraphError::InvalidArgument(
            "Number of SIR simulations must be positive.".to_string(),
        ));
    }
    if !is_simple_with_mode(graph, SimpleMode::DirectedAsUndirected)? {
        return Err(IgraphError::InvalidArgument(
            "SIR model only works with simple graphs.".to_string(),
        ));
    }
    let adj = build_undirected_adj(graph)?;
    let mut rng = SplitMix64::new(seed);
    // Scratch buffers shared by all runs; `run_one` reinitialises them.
    let mut tree = PsumTree::new(n);
    let mut status = vec![S_S; n];
    let runs: Vec<Sir> = (0..no_sim)
        .map(|_| run_one(&adj, beta, gamma, n, &mut rng, &mut tree, &mut status))
        .collect();
    Ok(runs)
}
/// Builds an adjacency list of `graph` that ignores edge direction.
///
/// Every edge `(u, v)` appends `v` to the list of `u` and `u` to the list of `v`.
///
/// # Errors
///
/// Returns [`IgraphError::Internal`] if an edge index does not fit in `u32`.
/// Propagates any error from `Graph::edge`.
fn build_undirected_adj(graph: &Graph) -> IgraphResult<Vec<Vec<usize>>> {
    let mut neighbours: Vec<Vec<usize>> = vec![Vec::new(); graph.vcount() as usize];
    for edge_index in 0..graph.ecount() {
        let id = u32::try_from(edge_index)
            .map_err(|_| IgraphError::Internal("ecount exceeds u32::MAX"))?;
        let (from, to) = graph.edge(id)?;
        let (from, to) = (from as usize, to as usize);
        neighbours[from].push(to);
        neighbours[to].push(from);
    }
    Ok(neighbours)
}
/// Runs one stochastic SIR simulation and returns its trajectory.
///
/// A single vertex, drawn with `rng.gen_index(n)`, starts infected and all
/// others start susceptible. Each step picks one event from `tree` with
/// probability proportional to its rate:
///
/// * an infected vertex recovers at rate `gamma`;
/// * a susceptible vertex is infected at rate `beta` times its number of
///   infected neighbours in `adj`.
///
/// The waiting time before the event is `-ln(1 - u) / total_rate`, where
/// `u = rng.gen_unit()`. This is exponential when `u` is uniform on `[0, 1)`.
/// The run stops once no vertex is infected.
///
/// `tree` and `status` are scratch buffers of length `n`. Both are fully
/// reinitialised here, so they can be reused across runs.
fn run_one(
    adj: &[Vec<usize>],
    beta: f64,
    gamma: f64,
    n: usize,
    rng: &mut SplitMix64,
    tree: &mut PsumTree,
    status: &mut [u8],
) -> Sir {
    let infected = rng.gen_index(n);
    status.fill(S_S);
    status[infected] = S_I;

    // Exact count of currently infected neighbours of every vertex.
    // A susceptible vertex's rate is always recomputed as `count * beta`
    // rather than by repeatedly adding and subtracting `beta`. The incremental
    // form drifts: 0.1 + 0.1 + 0.1 - 0.1 - 0.1 - 0.1 leaves about 2.8e-17
    // instead of 0.0, which gives a vertex with no infected neighbour a small
    // positive chance of being infected.
    let mut infected_neighbours = vec![0usize; n];
    let rate_of = |count: usize| count as f64 * beta;

    let mut ns = n - 1;
    let mut ni = 1usize;
    let mut nr = 0usize;
    let mut now = 0.0_f64;
    let mut times = vec![now];
    let mut no_s = vec![ns];
    let mut no_i = vec![ni];
    let mut no_r = vec![nr];

    tree.reset();
    tree.set(infected, gamma);
    for &nei in &adj[infected] {
        infected_neighbours[nei] += 1;
        tree.set(nei, rate_of(infected_neighbours[nei]));
    }

    while ni > 0 {
        let psum = tree.total();
        let tt = -(1.0 - rng.gen_unit()).ln() / psum;
        let r = rng.gen_unit() * psum;
        let vchange = tree.search(r);
        match status[vchange] {
            S_I => {
                // Recovery: the vertex stops infecting its neighbours.
                status[vchange] = S_R;
                ni -= 1;
                nr += 1;
                tree.set(vchange, 0.0);
                for &nei in &adj[vchange] {
                    infected_neighbours[nei] -= 1;
                    if status[nei] == S_S {
                        tree.set(nei, rate_of(infected_neighbours[nei]));
                    }
                }
            }
            S_S if tree.get(vchange) > 0.0 => {
                // Infection: the vertex starts recovering and infecting its neighbours.
                status[vchange] = S_I;
                ns -= 1;
                ni += 1;
                tree.set(vchange, gamma);
                for &nei in &adj[vchange] {
                    infected_neighbours[nei] += 1;
                    if status[nei] == S_S {
                        tree.set(nei, rate_of(infected_neighbours[nei]));
                    }
                }
            }
            // A recovered vertex or a zero-rate susceptible vertex carries no
            // weight. `tree.search` can only return one through rounding in the
            // accumulated sums or through its clamp to the last slot. Treating
            // it as an infection would re-infect a recovered vertex and could
            // underflow `ns`, so discard the draw and sample again.
            _ => continue,
        }
        now += tt;
        times.push(now);
        no_s.push(ns);
        no_i.push(ni);
        no_r.push(nr);
    }

    Sir {
        times,
        no_s,
        no_i,
        no_r,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cycle 0 - 1 - ... - (n-1) - 0 on `n` vertices.
    fn ring(n: u32) -> Graph {
        let mut g = Graph::with_vertices(n);
        for i in 0..n {
            g.add_edge(i, (i + 1) % n).unwrap();
        }
        g
    }

    /// Complete graph on `n` vertices (one edge per unordered pair).
    fn complete(n: u32) -> Graph {
        let mut g = Graph::with_vertices(n);
        for i in 0..n {
            for j in (i + 1)..n {
                g.add_edge(i, j).unwrap();
            }
        }
        g
    }

    #[test]
    fn empty_graph_errors() {
        let g = Graph::with_vertices(0);
        assert!(sir(&g, 1.0, 1.0, 1, 0).is_err());
    }

    #[test]
    fn parameter_errors() {
        let g = ring(5);
        assert!(sir(&g, -0.1, 1.0, 1, 0).is_err());
        assert!(sir(&g, 1.0, 0.0, 1, 0).is_err());
        assert!(sir(&g, 1.0, -1.0, 1, 0).is_err());
        assert!(sir(&g, 1.0, 1.0, 0, 0).is_err());
    }

    #[test]
    fn non_simple_graph_errors() {
        // Multi-edge.
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 1).unwrap();
        assert!(sir(&g, 1.0, 1.0, 1, 0).is_err());

        // Self-loop.
        let mut g2 = Graph::with_vertices(3);
        g2.add_edge(0, 0).unwrap();
        g2.add_edge(1, 2).unwrap();
        assert!(sir(&g2, 1.0, 1.0, 1, 0).is_err());
    }

    #[test]
    fn produces_requested_number_of_runs() {
        let g = ring(10);
        let runs = sir(&g, 2.0, 1.0, 7, 0xABCD).unwrap();
        assert_eq!(runs.len(), 7);
    }

    #[test]
    fn initial_state_is_consistent() {
        let g = complete(6);
        let runs = sir(&g, 1.0, 1.0, 5, 42).unwrap();
        for run in &runs {
            // The initial time is the literal 0.0, so exact comparison is intended.
            #[allow(clippy::float_cmp)]
            {
                assert_eq!(run.times[0], 0.0);
            }
            assert_eq!(run.no_i[0], 1);
            assert_eq!(run.no_s[0], 5);
            assert_eq!(run.no_r[0], 0);
        }
    }

    #[test]
    fn population_conserved_and_terminates() {
        let g = complete(8);
        let runs = sir(&g, 3.0, 1.0, 10, 0x1234_5678).unwrap();
        for run in &runs {
            let len = run.times.len();
            assert_eq!(run.no_s.len(), len);
            assert_eq!(run.no_i.len(), len);
            assert_eq!(run.no_r.len(), len);
            for k in 0..len {
                assert_eq!(run.no_s[k] + run.no_i[k] + run.no_r[k], 8);
            }
            assert_eq!(*run.no_i.last().unwrap(), 0);
            for k in 1..len {
                assert!(run.no_s[k] <= run.no_s[k - 1]);
                assert!(run.no_r[k] >= run.no_r[k - 1]);
            }
        }
    }

    #[test]
    fn population_conserved_with_inexact_beta() {
        // 0.1 is not exactly representable in binary, so rate bookkeeping is
        // exposed to rounding; counts must still be conserved and monotone.
        let g = ring(30);
        let runs = sir(&g, 0.1, 0.3, 20, 0x0BAD_5EED).unwrap();
        for run in &runs {
            for k in 0..run.times.len() {
                assert_eq!(run.no_s[k] + run.no_i[k] + run.no_r[k], 30);
            }
            for k in 1..run.times.len() {
                assert!(run.no_s[k] <= run.no_s[k - 1]);
                assert!(run.no_r[k] >= run.no_r[k - 1]);
            }
            assert_eq!(run.no_i.last().copied(), Some(0));
        }
    }

    #[test]
    fn times_strictly_increasing() {
        let g = complete(7);
        let runs = sir(&g, 2.0, 1.0, 4, 0x9999).unwrap();
        for run in &runs {
            for k in 1..run.times.len() {
                assert!(run.times[k] > run.times[k - 1]);
            }
        }
    }

    #[test]
    fn deterministic_with_seed() {
        let g = complete(6);
        let a = sir(&g, 1.5, 0.7, 5, 0xDEAD_BEEF).unwrap();
        let b = sir(&g, 1.5, 0.7, 5, 0xDEAD_BEEF).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn different_seeds_differ() {
        let g = complete(20);
        let a = sir(&g, 2.0, 0.5, 1, 1).unwrap();
        let b = sir(&g, 2.0, 0.5, 1, 2).unwrap();
        assert!(a != b);
    }

    #[test]
    fn zero_beta_recovers_immediately() {
        let g = complete(5);
        let runs = sir(&g, 0.0, 1.0, 6, 0x2468).unwrap();
        for run in &runs {
            assert_eq!(run.times.len(), 2);
            assert_eq!(run.no_r.last().copied(), Some(1));
            assert_eq!(run.no_s.last().copied(), Some(4));
        }
    }

    #[test]
    fn directed_graph_ignores_direction() {
        let mut g = Graph::new(5, true).unwrap();
        for i in 0..5u32 {
            g.add_edge(i, (i + 1) % 5).unwrap();
        }
        let runs = sir(&g, 2.0, 1.0, 3, 0x55).unwrap();
        assert_eq!(runs.len(), 3);
        for run in &runs {
            assert_eq!(*run.no_i.last().unwrap(), 0);
        }
    }

    #[test]
    fn psum_tree_search_follows_cumulative_weights() {
        let mut t = PsumTree::new(5);
        t.set(0, 1.0);
        t.set(2, 2.0);
        t.set(4, 3.0);
        assert!((t.total() - 6.0).abs() < 1e-12);
        // Cumulative intervals: [0, 1) -> 0, [1, 3) -> 2, [3, 6) -> 4.
        assert_eq!(t.search(0.0), 0);
        assert_eq!(t.search(0.5), 0);
        assert_eq!(t.search(1.0), 2);
        assert_eq!(t.search(2.9), 2);
        assert_eq!(t.search(3.0), 4);
        assert_eq!(t.search(5.9), 4);
        // Targets past the total are clamped to the last slot.
        assert_eq!(t.search(100.0), 4);
        // A slot set back to zero weight is skipped.
        t.set(2, 0.0);
        assert_eq!(t.search(1.5), 4);
        assert_eq!(PsumTree::new(0).search(1.0), 0);
    }
}