use std::collections::{HashMap, HashSet, VecDeque};
use scirs2_core::random::{Rng, RngExt};
use crate::error::{GraphError, Result};
/// Directed adjacency list: source node id -> `(target node id, edge weight)` pairs.
///
/// The weight is interpreted per model: activation probability for the
/// Independent Cascade simulator, influence weight for the Linear Threshold
/// simulator; the SIR/SIS simulators ignore it.
pub type AdjList = HashMap<usize, Vec<(usize, f64)>>;
/// Compartment of a single node in the SIR epidemic model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SirState {
/// Not yet infected; can be infected by an infected neighbour.
Susceptible,
/// Currently infected; can transmit to susceptible neighbours and may recover.
Infected,
/// Recovered; `simulate_sir` never moves a node out of this state.
Recovered,
}
/// Outcome of a single diffusion or epidemic simulation run.
#[derive(Debug, Clone)]
pub struct SimulationResult {
/// Nodes that were seeded or became active/infected at some point during the run.
pub activated: HashSet<usize>,
/// Per-step `(susceptible, infected, recovered)` counts. Filled by `simulate_sir`
/// and `simulate_sis` (the latter always reports `0` recovered); left empty by
/// the IC and LT simulators.
pub time_series: Vec<(usize, usize, usize)>,
/// Size of the outbreak. For IC, LT and SIS this is `activated.len()`; for SIR it
/// is the number of nodes in the Recovered or Infected state when the run ends.
pub spread: usize,
}
/// Independent Cascade diffusion model over a directed graph.
#[derive(Debug, Clone)]
pub struct IndependentCascade {
/// Source node -> outgoing `(target, activation probability)` edges.
pub adjacency: AdjList,
/// Declared node count.
/// NOTE(review): not forwarded to `simulate_ic`, so it does not bound which
/// node ids can become active.
pub num_nodes: usize,
}
impl IndependentCascade {
pub fn new(adjacency: AdjList, num_nodes: usize) -> Self {
IndependentCascade {
adjacency,
num_nodes,
}
}
pub fn from_edges(edges: &[(usize, usize, f64)], num_nodes: usize) -> Self {
let mut adjacency: AdjList = HashMap::new();
for &(src, tgt, prob) in edges {
adjacency.entry(src).or_default().push((tgt, prob));
}
IndependentCascade::new(adjacency, num_nodes)
}
pub fn simulate(&self, seeds: &[usize]) -> Result<SimulationResult> {
simulate_ic(&self.adjacency, seeds)
}
pub fn expected_spread(&self, seeds: &[usize], num_simulations: usize) -> Result<f64> {
expected_spread_ic(&self.adjacency, seeds, num_simulations)
}
}
/// Linear Threshold diffusion model over a directed, weighted graph.
#[derive(Debug, Clone)]
pub struct LinearThreshold {
/// Source node -> outgoing `(target, influence weight)` edges.
pub adjacency: AdjList,
/// Number of nodes; thresholds are drawn for (or required to cover) ids `0..num_nodes`.
pub num_nodes: usize,
/// Optional fixed per-node thresholds indexed by node id. `None` means fresh
/// random thresholds are drawn for every simulation run.
pub thresholds: Option<Vec<f64>>,
}
impl LinearThreshold {
pub fn new(adjacency: AdjList, num_nodes: usize) -> Self {
LinearThreshold {
adjacency,
num_nodes,
thresholds: None,
}
}
pub fn with_thresholds(adjacency: AdjList, thresholds: Vec<f64>) -> Result<Self> {
let num_nodes = thresholds.len();
for (i, &t) in thresholds.iter().enumerate() {
if !(0.0..=1.0).contains(&t) {
return Err(GraphError::InvalidParameter {
param: format!("thresholds[{i}]"),
value: t.to_string(),
expected: "value in [0, 1]".to_string(),
context: "LinearThreshold::with_thresholds".to_string(),
});
}
}
Ok(LinearThreshold {
adjacency,
num_nodes,
thresholds: Some(thresholds),
})
}
pub fn from_edges(edges: &[(usize, usize, f64)], num_nodes: usize) -> Self {
let mut adjacency: AdjList = HashMap::new();
for &(src, tgt, w) in edges {
adjacency.entry(src).or_default().push((tgt, w));
}
LinearThreshold::new(adjacency, num_nodes)
}
pub fn simulate(&self, seeds: &[usize]) -> Result<SimulationResult> {
simulate_lt(&self.adjacency, self.num_nodes, seeds, self.thresholds.as_deref())
}
pub fn expected_spread(&self, seeds: &[usize], num_simulations: usize) -> Result<f64> {
expected_spread_lt(
&self.adjacency,
self.num_nodes,
seeds,
self.thresholds.as_deref(),
num_simulations,
)
}
}
/// Discrete-time SIR epidemic model: contact graph plus transition probabilities.
#[derive(Debug, Clone)]
pub struct SIRModel {
/// Contact graph; edge weights are ignored by `simulate_sir`.
pub adjacency: AdjList,
/// Per-step probability that an infected node infects each susceptible neighbour.
pub beta: f64,
/// Per-step probability that an infected node recovers.
pub gamma: f64,
/// Number of nodes; ids `>= num_nodes` are not tracked in the simulated state.
pub num_nodes: usize,
}
impl SIRModel {
pub fn new(adjacency: AdjList, num_nodes: usize, beta: f64, gamma: f64) -> Result<Self> {
if !(0.0..=1.0).contains(&beta) {
return Err(GraphError::InvalidParameter {
param: "beta".to_string(),
value: beta.to_string(),
expected: "[0, 1]".to_string(),
context: "SIRModel::new".to_string(),
});
}
if !(0.0..=1.0).contains(&gamma) {
return Err(GraphError::InvalidParameter {
param: "gamma".to_string(),
value: gamma.to_string(),
expected: "[0, 1]".to_string(),
context: "SIRModel::new".to_string(),
});
}
Ok(SIRModel {
adjacency,
beta,
gamma,
num_nodes,
})
}
pub fn from_edges(edges: &[(usize, usize)], num_nodes: usize, beta: f64, gamma: f64) -> Result<Self> {
let mut adjacency: AdjList = HashMap::new();
for &(src, tgt) in edges {
adjacency.entry(src).or_default().push((tgt, 1.0));
adjacency.entry(tgt).or_default().push((src, 1.0));
}
SIRModel::new(adjacency, num_nodes, beta, gamma)
}
pub fn simulate(&self, initial_infected: &[usize]) -> Result<SimulationResult> {
simulate_sir(&self.adjacency, self.num_nodes, initial_infected, self.beta, self.gamma)
}
}
/// Discrete-time SIS epidemic model: contact graph, transition probabilities and step cap.
#[derive(Debug, Clone)]
pub struct SISModel {
/// Contact graph; edge weights are ignored by `simulate_sis`.
pub adjacency: AdjList,
/// Per-step probability that an infected node infects each non-infected neighbour.
pub beta: f64,
/// Per-step probability that an infected node becomes susceptible again.
pub gamma: f64,
/// Number of nodes in the population.
pub num_nodes: usize,
/// Maximum number of simulated steps; a run stops earlier once nobody is infected.
pub max_steps: usize,
}
impl SISModel {
pub fn new(
adjacency: AdjList,
num_nodes: usize,
beta: f64,
gamma: f64,
max_steps: usize,
) -> Result<Self> {
if !(0.0..=1.0).contains(&beta) {
return Err(GraphError::InvalidParameter {
param: "beta".to_string(),
value: beta.to_string(),
expected: "[0, 1]".to_string(),
context: "SISModel::new".to_string(),
});
}
if !(0.0..=1.0).contains(&gamma) {
return Err(GraphError::InvalidParameter {
param: "gamma".to_string(),
value: gamma.to_string(),
expected: "[0, 1]".to_string(),
context: "SISModel::new".to_string(),
});
}
Ok(SISModel {
adjacency,
beta,
gamma,
num_nodes,
max_steps,
})
}
pub fn from_edges(
edges: &[(usize, usize)],
num_nodes: usize,
beta: f64,
gamma: f64,
max_steps: usize,
) -> Result<Self> {
let mut adjacency: AdjList = HashMap::new();
for &(src, tgt) in edges {
adjacency.entry(src).or_default().push((tgt, 1.0));
adjacency.entry(tgt).or_default().push((src, 1.0));
}
SISModel::new(adjacency, num_nodes, beta, gamma, max_steps)
}
pub fn simulate(&self, initial_infected: &[usize]) -> Result<SimulationResult> {
simulate_sis(
&self.adjacency,
self.num_nodes,
initial_infected,
self.beta,
self.gamma,
self.max_steps,
)
}
}
/// Runs one Independent Cascade (IC) simulation from `seeds`.
///
/// Nodes are processed breadth-first. When a node becomes active, it gets exactly
/// one attempt per outgoing edge `(target, prob)` to activate a still-inactive
/// target. An attempt succeeds when `rng.random::<f64>() < prob`. Parallel edges
/// therefore each give a separate attempt.
///
/// Duplicate entries in `seeds` are collapsed, so a repeated seed does not get
/// extra activation attempts.
///
/// The returned `time_series` is empty and `spread` equals `activated.len()`.
///
/// # Errors
/// This function currently never returns an error.
pub fn simulate_ic(adjacency: &AdjList, seeds: &[usize]) -> Result<SimulationResult> {
    let mut rng = scirs2_core::random::rng();
    let mut active: HashSet<usize> = HashSet::with_capacity(seeds.len());
    let mut queue: VecDeque<usize> = VecDeque::with_capacity(seeds.len());
    for &seed in seeds {
        // `insert` returns false for a repeated seed; queueing it again would
        // grant it a second round of activation attempts.
        if active.insert(seed) {
            queue.push_back(seed);
        }
    }
    while let Some(node) = queue.pop_front() {
        let Some(neighbors) = adjacency.get(&node) else {
            continue;
        };
        for &(nbr, prob) in neighbors {
            if !active.contains(&nbr) && rng.random::<f64>() < prob {
                active.insert(nbr);
                queue.push_back(nbr);
            }
        }
    }
    let spread = active.len();
    Ok(SimulationResult {
        activated: active,
        time_series: Vec::new(),
        spread,
    })
}
/// Estimates the expected Independent Cascade spread of `seeds`.
///
/// Averages the `spread` of `num_simulations` separate `simulate_ic` runs.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` when `num_simulations` is zero, and
/// propagates any error from the individual simulations.
pub fn expected_spread(adjacency: &AdjList, seeds: &[usize], num_simulations: usize) -> Result<f64> {
expected_spread_ic(adjacency, seeds, num_simulations)
}
/// Mean IC spread of `seeds` over `num_simulations` runs of `simulate_ic`.
///
/// # Errors
/// `GraphError::InvalidParameter` if `num_simulations` is zero; otherwise the
/// first error produced by a simulation, if any.
fn expected_spread_ic(
    adjacency: &AdjList,
    seeds: &[usize],
    num_simulations: usize,
) -> Result<f64> {
    if num_simulations == 0 {
        return Err(GraphError::InvalidParameter {
            param: "num_simulations".to_string(),
            value: "0".to_string(),
            expected: ">= 1".to_string(),
            context: "expected_spread_ic".to_string(),
        });
    }
    // Accumulate in f64, one run at a time, stopping at the first failure.
    let total = (0..num_simulations).try_fold(0.0_f64, |acc, _| {
        simulate_ic(adjacency, seeds).map(|run| acc + run.spread as f64)
    })?;
    Ok(total / num_simulations as f64)
}
/// Runs one Linear Threshold (LT) simulation from `seeds`.
///
/// A node's influence is the sum of the weights of its in-edges coming from
/// currently active nodes. A node activates once that sum reaches its threshold.
/// Passes over the not-yet-active nodes that have in-edges repeat until a full
/// pass activates nothing. Activations take effect immediately, so later nodes
/// in the same pass already see them.
///
/// Thresholds come from `fixed_thresholds` when given. Otherwise, one value per
/// node in `0..num_nodes` is drawn with `rng.random::<f64>()` for this run. Nodes
/// without a threshold entry use `1.0`.
///
/// The returned `time_series` is empty and `spread` equals `activated.len()`.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` if `fixed_thresholds` has fewer than
/// `num_nodes` entries.
pub fn simulate_lt(
    adjacency: &AdjList,
    num_nodes: usize,
    seeds: &[usize],
    fixed_thresholds: Option<&[f64]>,
) -> Result<SimulationResult> {
    // Invert the graph: for each target, record who points at it and how strongly.
    let mut incoming: HashMap<usize, Vec<(usize, f64)>> = HashMap::new();
    for (&from, out_edges) in adjacency {
        for &(to, weight) in out_edges {
            incoming.entry(to).or_default().push((from, weight));
        }
    }

    let mut rng = scirs2_core::random::rng();
    let thresholds: Vec<f64> = if let Some(given) = fixed_thresholds {
        if given.len() < num_nodes {
            return Err(GraphError::InvalidParameter {
                param: "fixed_thresholds".to_string(),
                value: format!("len={}", given.len()),
                expected: format!(">= num_nodes={num_nodes}"),
                context: "simulate_lt".to_string(),
            });
        }
        given.to_vec()
    } else {
        (0..num_nodes).map(|_| rng.random::<f64>()).collect()
    };

    let mut active: HashSet<usize> = seeds.iter().copied().collect();
    loop {
        // Snapshot the inactive nodes that have at least one in-edge.
        let pending: Vec<usize> = incoming
            .keys()
            .copied()
            .filter(|node| !active.contains(node))
            .collect();
        let mut progressed = false;
        for candidate in pending {
            let influence = incoming.get(&candidate).map_or(0.0, |edges| {
                edges
                    .iter()
                    .filter(|(from, _)| active.contains(from))
                    .map(|&(_, weight)| weight)
                    .sum::<f64>()
            });
            let bar = thresholds.get(candidate).copied().unwrap_or(1.0);
            if influence >= bar {
                active.insert(candidate);
                progressed = true;
            }
        }
        if !progressed {
            break;
        }
    }

    let spread = active.len();
    Ok(SimulationResult {
        activated: active,
        time_series: Vec::new(),
        spread,
    })
}
/// Estimates the expected Linear Threshold spread of `seeds` over `num_simulations` runs.
///
/// When `fixed_thresholds` is `Some`, `simulate_lt` draws no random numbers.
/// For non-negative edge weights, the order in which nodes are examined then
/// cannot change the final active set, so every run yields the same spread. In
/// that case a single run is performed instead of repeating identical work
/// `num_simulations` times.
///
/// # Errors
/// - `GraphError::InvalidParameter` if `num_simulations` is zero.
/// - Any error from `simulate_lt` (e.g. fewer fixed thresholds than `num_nodes`).
fn expected_spread_lt(
    adjacency: &AdjList,
    num_nodes: usize,
    seeds: &[usize],
    fixed_thresholds: Option<&[f64]>,
    num_simulations: usize,
) -> Result<f64> {
    if num_simulations == 0 {
        return Err(GraphError::InvalidParameter {
            param: "num_simulations".to_string(),
            value: "0".to_string(),
            expected: ">= 1".to_string(),
            context: "expected_spread_lt".to_string(),
        });
    }
    if fixed_thresholds.is_some() {
        // Deterministic case: the mean of identical runs is just one run's spread.
        let run = simulate_lt(adjacency, num_nodes, seeds, fixed_thresholds)?;
        return Ok(run.spread as f64);
    }
    let mut total = 0.0_f64;
    for _ in 0..num_simulations {
        let result = simulate_lt(adjacency, num_nodes, seeds, fixed_thresholds)?;
        total += result.spread as f64;
    }
    Ok(total / num_simulations as f64)
}
/// Runs one discrete-time SIR epidemic from `initial_infected`.
///
/// Each step has two phases, both evaluated against the state at the start of
/// the step:
/// 1. Every infected node tries once per susceptible neighbour to infect it,
///    succeeding with probability `beta`.
/// 2. Every infected node recovers with probability `gamma`.
///
/// The `(S, I, R)` counts are recorded before each step, including the final
/// state. The run ends when no node is infected. It also ends when `gamma == 0`
/// and no infected node can still reach a susceptible neighbour (or
/// `beta == 0`): without recovery the infected count can never drop to zero,
/// so that state is absorbing.
///
/// Node ids `>= num_nodes`, whether in `initial_infected` or as neighbours,
/// are ignored.
///
/// `activated` holds every node that was ever infected. `spread` counts nodes
/// in the Recovered or Infected state at the end.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` if `beta` or `gamma` is outside `[0, 1]`.
pub fn simulate_sir(
    adjacency: &AdjList,
    num_nodes: usize,
    initial_infected: &[usize],
    beta: f64,
    gamma: f64,
) -> Result<SimulationResult> {
    if !(0.0..=1.0).contains(&beta) || !(0.0..=1.0).contains(&gamma) {
        return Err(GraphError::InvalidParameter {
            param: "beta/gamma".to_string(),
            value: format!("beta={beta}, gamma={gamma}"),
            expected: "both in [0, 1]".to_string(),
            context: "simulate_sir".to_string(),
        });
    }
    let mut rng = scirs2_core::random::rng();
    let mut states: Vec<SirState> = vec![SirState::Susceptible; num_nodes];
    // Only seeds that exist in the population are infected and reported.
    let mut ever_infected: HashSet<usize> = HashSet::new();
    for &node in initial_infected {
        if node < num_nodes {
            states[node] = SirState::Infected;
            ever_infected.insert(node);
        }
    }
    let mut time_series: Vec<(usize, usize, usize)> = Vec::new();
    loop {
        let n_infected = states.iter().filter(|&&s| s == SirState::Infected).count();
        let n_recovered = states.iter().filter(|&&s| s == SirState::Recovered).count();
        let n_susceptible = num_nodes - n_infected - n_recovered;
        time_series.push((n_susceptible, n_infected, n_recovered));
        if n_infected == 0 {
            break;
        }
        if gamma == 0.0 {
            // No recovery is possible, so stop once transmission is impossible too;
            // otherwise this loop would never terminate.
            let can_transmit = beta > 0.0
                && (0..num_nodes).any(|node| {
                    states[node] == SirState::Infected
                        && adjacency.get(&node).is_some_and(|nbrs| {
                            nbrs.iter().any(|&(nbr, _)| {
                                nbr < num_nodes && states[nbr] == SirState::Susceptible
                            })
                        })
                });
            if !can_transmit {
                break;
            }
        }
        let mut next_states = states.clone();
        for node in 0..num_nodes {
            if states[node] == SirState::Infected {
                if let Some(neighbors) = adjacency.get(&node) {
                    for &(nbr, _) in neighbors {
                        if nbr < num_nodes
                            && states[nbr] == SirState::Susceptible
                            && rng.random::<f64>() < beta
                        {
                            next_states[nbr] = SirState::Infected;
                            ever_infected.insert(nbr);
                        }
                    }
                }
            }
        }
        for node in 0..num_nodes {
            if states[node] == SirState::Infected && rng.random::<f64>() < gamma {
                next_states[node] = SirState::Recovered;
            }
        }
        states = next_states;
    }
    Ok(SimulationResult {
        activated: ever_infected,
        time_series,
        spread: states.iter().filter(|&&s| s == SirState::Recovered).count()
            + states.iter().filter(|&&s| s == SirState::Infected).count(),
    })
}
/// Runs one discrete-time SIS epidemic from `initial_infected` for at most `max_steps` steps.
///
/// In each step, every infected node tries to infect each non-infected neighbour
/// with probability `beta`, and independently becomes susceptible again with
/// probability `gamma`. Recoveries are applied before new infections, so a node
/// that recovers and is re-infected in the same step stays infected.
///
/// `(susceptible, infected, 0)` is recorded at the start of each step. The run
/// stops early once nobody is infected.
///
/// Node ids `>= num_nodes` are ignored, both in `initial_infected` and as
/// neighbours. This keeps the infected set inside the population, so the
/// susceptible count cannot underflow.
///
/// `activated` holds every node that was ever infected and `spread` is its size.
///
/// # Errors
/// Returns `GraphError::InvalidParameter` if `beta` or `gamma` is outside `[0, 1]`.
pub fn simulate_sis(
    adjacency: &AdjList,
    num_nodes: usize,
    initial_infected: &[usize],
    beta: f64,
    gamma: f64,
    max_steps: usize,
) -> Result<SimulationResult> {
    if !(0.0..=1.0).contains(&beta) || !(0.0..=1.0).contains(&gamma) {
        return Err(GraphError::InvalidParameter {
            param: "beta/gamma".to_string(),
            value: format!("beta={beta}, gamma={gamma}"),
            expected: "both in [0, 1]".to_string(),
            context: "simulate_sis".to_string(),
        });
    }
    let mut rng = scirs2_core::random::rng();
    let mut infected: HashSet<usize> = initial_infected
        .iter()
        .copied()
        .filter(|&node| node < num_nodes)
        .collect();
    let mut ever_infected = infected.clone();
    let mut time_series: Vec<(usize, usize, usize)> = Vec::new();
    for _step in 0..max_steps {
        let n_infected = infected.len();
        time_series.push((num_nodes - n_infected, n_infected, 0));
        if n_infected == 0 {
            break;
        }
        let mut new_infections: HashSet<usize> = HashSet::new();
        let mut new_recoveries: HashSet<usize> = HashSet::new();
        for &node in &infected {
            if let Some(neighbors) = adjacency.get(&node) {
                for &(nbr, _) in neighbors {
                    if nbr < num_nodes
                        && !infected.contains(&nbr)
                        && rng.random::<f64>() < beta
                    {
                        new_infections.insert(nbr);
                        ever_infected.insert(nbr);
                    }
                }
            }
            if rng.random::<f64>() < gamma {
                new_recoveries.insert(node);
            }
        }
        for node in new_recoveries {
            infected.remove(&node);
        }
        infected.extend(new_infections);
    }
    let spread = ever_infected.len();
    Ok(SimulationResult {
        activated: ever_infected,
        time_series,
        spread,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Star graph: hub `0` has an edge with probability/weight `prob` to each of `1..n`.
    fn star_adjacency(n: usize, prob: f64) -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 1..n {
            adj.entry(0).or_default().push((i, prob));
        }
        adj
    }

    /// Undirected path 0-1-2-3-4 with unit weights.
    fn path_adjacency() -> AdjList {
        let mut adj: AdjList = HashMap::new();
        for i in 0..4_usize {
            adj.entry(i).or_default().push((i + 1, 1.0));
            adj.entry(i + 1).or_default().push((i, 1.0));
        }
        adj
    }

    #[test]
    fn test_simulate_ic_full_spread() {
        let adj = star_adjacency(6, 1.0);
        let result = simulate_ic(&adj, &[0]).expect("ic simulation");
        assert_eq!(result.spread, 6);
    }

    #[test]
    fn test_simulate_ic_no_spread() {
        let adj = star_adjacency(6, 0.0);
        let result = simulate_ic(&adj, &[0]).expect("ic simulation");
        assert_eq!(result.spread, 1);
    }

    #[test]
    fn test_simulate_ic_duplicate_seeds() {
        let adj = star_adjacency(4, 1.0);
        let result = simulate_ic(&adj, &[0, 0]).expect("ic simulation");
        assert_eq!(result.spread, 4);
        assert_eq!(result.activated, (0..4).collect::<HashSet<usize>>());
    }

    #[test]
    fn test_simulate_lt_deterministic_threshold() {
        let adj = star_adjacency(4, 1.0);
        let thresholds = vec![0.5_f64; 4];
        let result =
            simulate_lt(&adj, 4, &[0], Some(&thresholds)).expect("lt simulation");
        // Every leaf receives weight 1.0 >= 0.5 from the active hub.
        assert_eq!(result.spread, 4);
        assert_eq!(result.activated, (0..4).collect::<HashSet<usize>>());
    }

    #[test]
    fn test_simulate_lt_rejects_short_thresholds() {
        let adj = star_adjacency(4, 1.0);
        let thresholds = [0.5_f64, 0.5];
        assert!(simulate_lt(&adj, 4, &[0], Some(&thresholds)).is_err());
    }

    #[test]
    fn test_lt_fixed_thresholds_expected_spread() {
        let lt = LinearThreshold::with_thresholds(star_adjacency(4, 1.0), vec![0.5; 4])
            .expect("valid thresholds");
        let spread = lt.expected_spread(&[0], 20).expect("expected spread");
        assert!((spread - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_simulate_sir_terminates() {
        let adj = path_adjacency();
        let result = simulate_sir(&adj, 5, &[0], 0.8, 0.5).expect("sir");
        assert!(result.spread >= 1);
        assert!(!result.time_series.is_empty());
    }

    #[test]
    fn test_simulate_sir_deterministic_recovery() {
        // beta = 0 never transmits and gamma = 1 always recovers.
        let adj = star_adjacency(3, 1.0);
        let result = simulate_sir(&adj, 3, &[0], 0.0, 1.0).expect("sir");
        assert_eq!(result.time_series, vec![(2, 1, 0), (2, 0, 1)]);
        assert_eq!(result.spread, 1);
        assert_eq!(result.activated, [0_usize].into_iter().collect::<HashSet<_>>());
    }

    #[test]
    fn test_simulate_sis_terminates() {
        let adj = path_adjacency();
        let result = simulate_sis(&adj, 5, &[0], 0.5, 0.9, 1000).expect("sis");
        assert!(result.spread >= 1);
    }

    #[test]
    fn test_simulate_sis_zero_steps() {
        let adj = path_adjacency();
        let result = simulate_sis(&adj, 5, &[0], 0.5, 0.5, 0).expect("sis");
        assert!(result.time_series.is_empty());
        assert_eq!(result.spread, 1);
    }

    #[test]
    fn test_expected_spread_ic() {
        let adj = star_adjacency(5, 1.0);
        let spread = expected_spread(&adj, &[0], 50).expect("expected spread");
        assert!((spread - 5.0).abs() < 0.01);
    }

    #[test]
    fn test_expected_spread_zero_simulations() {
        let adj = star_adjacency(5, 1.0);
        assert!(expected_spread(&adj, &[0], 0).is_err());
    }

    #[test]
    fn test_sir_bad_params() {
        let adj: AdjList = HashMap::new();
        let err = simulate_sir(&adj, 1, &[], 2.0, 0.5);
        assert!(err.is_err());
    }

    #[test]
    fn test_sir_model_rejects_bad_gamma() {
        assert!(SIRModel::new(HashMap::new(), 1, 0.5, -0.1).is_err());
    }

    #[test]
    fn test_ic_struct() {
        let edges = vec![(0_usize, 1_usize, 1.0_f64), (0, 2, 1.0), (0, 3, 1.0)];
        let ic = IndependentCascade::from_edges(&edges, 4);
        let res = ic.simulate(&[0]).expect("simulate");
        assert_eq!(res.spread, 4);
    }

    #[test]
    fn test_lt_bad_threshold() {
        let adj: AdjList = HashMap::new();
        let err = LinearThreshold::with_thresholds(adj, vec![0.5, 1.5]);
        assert!(err.is_err());
    }
}