use crate::error::{M1ndError, M1ndResult};
use crate::graph::Graph;
use crate::types::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
use std::time::Instant;
/// Iteration count used by this file's test fixture.
/// NOTE(review): presumably also the default for `EpidemicConfig::iterations` elsewhere — confirm against callers.
pub const DEFAULT_ITERATIONS: u32 = 50;
/// Hard cap on simulation rounds; `EpidemicEngine::simulate` takes the minimum of this and `EpidemicConfig::iterations`.
pub const MAX_ITERATIONS: u32 = 500;
/// Recovery rate used by this file's test fixture. `simulate` does not read `recovery_rate` in the code visible here.
pub const DEFAULT_RECOVERY_RATE: f32 = 0.0;
/// Upper clamp for a configured uniform infection rate, and the ceiling for per-edge transmission probabilities.
pub const MAX_INFECTION_RATE: f32 = 0.95;
/// Lower clamp for a configured uniform infection rate.
pub const MIN_INFECTION_RATE: f32 = 0.001;
/// `top_k` value used by this file's test fixture.
pub const DEFAULT_TOP_K: usize = 20;
/// Not referenced by the code visible here (`simulate` reads `EpidemicConfig::burnout_threshold` instead).
/// NOTE(review): presumably the default for that field — confirm.
pub const BURNOUT_THRESHOLD: f32 = 0.8;
/// The burnout check in `simulate` only runs while the zero-based iteration index is below this value.
pub const BURNOUT_MIN_ITERATIONS: u32 = 10;
/// Number of consecutive iterations without newly exposed nodes after which `simulate` stops and reports extinction.
pub const EXTINCTION_PLATEAU_ITERATIONS: u32 = 5;
/// Multiplier applied to the transmission probability when the transmitting node is `Compartment::Recovered`.
pub const RECOVERED_FIREWALL_FACTOR: f32 = 0.5;
/// Per-relation multipliers applied to an edge's weight when no uniform infection rate is configured.
/// Looked up by exact (case-sensitive) relation name.
pub const COUPLING_FACTORS: &[(&str, f32)] = &[
("imports", 0.8),
("calls", 0.7),
("inherits", 0.6),
("references", 0.4),
("contains", 0.3),
("related_to", 0.2),
];
/// Multiplier used for any relation name that is not listed in `COUPLING_FACTORS`.
pub const DEFAULT_COUPLING_FACTOR: f32 = 0.1;
/// Compartment a node occupies during the SIR-style simulation.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum Compartment {
/// Initial state of every node that is not a seed; may accumulate infection probability and be promoted.
Susceptible = 0,
/// Seeded as infected or promoted during the run; the only compartment `simulate` transmits from.
Infected = 1,
/// Seeded as recovered; `simulate` never exposes such a node and never transmits from it.
Recovered = 2,
}
/// Which edges the simulation follows when spreading from an infected node.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum EpidemicDirection {
/// Follow outgoing edges only (infected source -> edge target).
Forward,
/// Follow incoming edges only (infected node -> the sources of edges pointing at it).
Backward,
/// Follow outgoing edges first, then incoming edges, for every infected node.
Both,
}
/// Tunable parameters for `EpidemicEngine::simulate`.
#[derive(Clone, Debug)]
pub struct EpidemicConfig {
/// When `Some`, clamped to `[MIN_INFECTION_RATE, MAX_INFECTION_RATE]` and used as the base probability for every edge.
/// When `None`, the base probability is derived per edge from its weight and relation.
pub infection_rate: Option<f32>,
/// Not read by `simulate` in the code visible here.
pub recovery_rate: f32,
/// Requested number of rounds; `simulate` caps it at `MAX_ITERATIONS`.
pub iterations: u32,
/// Edge direction(s) along which infection spreads.
pub direction: EpidemicDirection,
/// Maximum number of predictions returned (highest probability first).
pub top_k: usize,
/// Fraction of all nodes (0.0..=1.0 scale) that may be infected during the first `BURNOUT_MIN_ITERATIONS`
/// rounds before `simulate` fails with `M1ndError::EpidemicBurnout`.
pub burnout_threshold: f32,
/// A susceptible node becomes infected once its accumulated probability is strictly greater than this value.
pub promotion_threshold: f32,
}
/// One non-seed node that ended the simulation with a positive infection probability.
#[derive(Clone, Debug, Serialize)]
pub struct EpidemicPrediction {
/// External id of the node, resolved from the graph's interned id table.
pub node_id: String,
/// Node label resolved from the graph's string table.
pub label: String,
/// Lower-case node type name (e.g. "file", "module", "custom").
pub node_type: String,
/// Accumulated infection probability at the end of the run.
pub infection_probability: f32,
/// Number of hops from a seed along the recorded parent chain (seeds are generation 0).
pub generation: u32,
/// External ids inherited from the parent that set this node's generation; seeds start with themselves.
pub contributing_infected: Vec<String>,
/// External ids along the recorded parent chain, ordered from the seed to this node.
pub transmission_path: Vec<String>,
/// Weight of the edge connecting this node to its recorded parent.
pub edge_weight_to_nearest: f32,
}
/// Aggregate statistics describing one simulation run.
#[derive(Clone, Debug, Serialize)]
pub struct EpidemicSummary {
/// Nodes that are neither infected nor recovered when the run ends.
pub total_susceptible: u32,
/// Nodes in the infected compartment when the run ends (seeds included).
pub total_infected: u32,
/// Nodes in the recovered compartment when the run ends.
pub total_recovered: u32,
/// One-based iteration at which the infected count last set a new maximum; 0 if it never exceeded the initial count.
pub peak_infection_iteration: u32,
/// Total number of nodes that were newly exposed during the run divided by the number of initially infected nodes.
pub r0_estimate: f32,
/// True when the run stopped after `EXTINCTION_PLATEAU_ITERATIONS` consecutive rounds without newly exposed nodes,
/// or when the graph has no nodes.
pub epidemic_extinct: bool,
}
/// A connected group of nodes (edges treated as undirected) that the infection never touched.
#[derive(Clone, Debug, Serialize)]
pub struct UnreachableComponent {
/// External id of the lowest-indexed node in the component.
pub representative_node: String,
/// Number of nodes in the component.
pub node_count: u32,
}
/// Full output of `EpidemicEngine::simulate`.
#[derive(Clone, Debug, Serialize)]
pub struct EpidemicResult {
/// At most `top_k` non-seed nodes, sorted by descending infection probability.
pub predictions: Vec<EpidemicPrediction>,
/// Aggregate counters for the run.
pub summary: EpidemicSummary,
/// Groups of untouched nodes, one entry per connected group.
pub unreachable_components: Vec<UnreachableComponent>,
/// Human-readable notices, e.g. that the configured infection rate was clamped.
pub warnings: Vec<String>,
/// Always empty as returned by `simulate`.
/// NOTE(review): presumably filled in by callers that resolve external ids — confirm.
pub unresolved_nodes: Vec<String>,
/// Wall-clock duration of the call in milliseconds.
pub elapsed_ms: f64,
}
/// State serialized to and from JSON by `save_epidemic_state` / `load_epidemic_state`.
///
/// NOTE(review): nothing in this file populates these fields; the per-field notes below are
/// inferred from names and types only and should be confirmed against the callers.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct EpidemicPersistentState {
/// Keyed by node id string; the `f64` is presumably a timestamp or score — TODO confirm.
pub infected_nodes: HashMap<String, f64>,
/// Keyed by node id string; the `f64` is presumably a timestamp or score — TODO confirm.
pub recovered_nodes: HashMap<String, f64>,
/// Per-node counter, presumably how often a node has been marked infected — TODO confirm.
pub cumulative_infections: HashMap<String, u32>,
/// `None` in the default state; presumably the time of the most recent run — TODO confirm units.
pub last_run_timestamp: Option<f64>,
}
/// Returns the coupling multiplier for `relation`.
///
/// The lookup is an exact, case-sensitive match against `COUPLING_FACTORS`;
/// the first matching entry wins and unknown relations fall back to
/// `DEFAULT_COUPLING_FACTOR`.
fn sir_coupling_factor(relation: &str) -> f32 {
    COUPLING_FACTORS
        .iter()
        .find(|(name, _)| *name == relation)
        .map_or(DEFAULT_COUPLING_FACTOR, |&(_, factor)| factor)
}
/// Restricts `v` to the interval `[lo, hi]`.
///
/// The lower bound is tested first, so when `lo > hi` a value below `lo`
/// yields `lo`. A NaN input fails both comparisons and is returned unchanged.
#[inline]
fn sir_clamp(v: f32, lo: f32, hi: f32) -> f32 {
    match v {
        below if below < lo => lo,
        above if above > hi => hi,
        within => within,
    }
}
/// Maps a node type to the lower-case name used in serialized predictions.
///
/// Every `Custom` variant collapses to the literal `"custom"`; its payload is
/// not inspected.
fn sir_node_type_str(nt: &NodeType) -> &'static str {
    let name: &'static str = match *nt {
        // Source-code oriented node kinds.
        NodeType::File => "file",
        NodeType::Directory => "directory",
        NodeType::Module => "module",
        NodeType::Function => "function",
        NodeType::Class => "class",
        NodeType::Struct => "struct",
        NodeType::Enum => "enum",
        NodeType::Type => "type",
        NodeType::Reference => "reference",
        NodeType::Concept => "concept",
        // Remaining (non-code) node kinds.
        NodeType::Material => "material",
        NodeType::Process => "process",
        NodeType::Product => "product",
        NodeType::Supplier => "supplier",
        NodeType::Regulatory => "regulatory",
        NodeType::System => "system",
        NodeType::Cost => "cost",
        NodeType::Custom(_) => "custom",
    };
    name
}
/// Builds a lookup table from dense node index to external id string.
///
/// The table has one slot per graph node. Slots for which `graph.id_to_node`
/// has no entry stay empty, and entries whose node index is outside the node
/// range are ignored.
fn sir_build_node_to_ext(graph: &Graph) -> Vec<String> {
    let node_count = graph.num_nodes() as usize;
    let mut ext_ids = vec![String::new(); node_count];
    for (interned, &nid) in &graph.id_to_node {
        // `get_mut` returns `None` for out-of-range indices, which are skipped.
        if let Some(slot) = ext_ids.get_mut(nid.as_usize()) {
            *slot = graph.strings.resolve(*interned).to_string();
        }
    }
    ext_ids
}
/// Walks the `parent` chain upward from `target` and returns the external ids
/// on that chain, ordered from the root down to `target`.
///
/// The walk stops at a node that is its own parent or whose parent is
/// `u32::MAX`. It takes at most `parent.len()` steps, so a cyclic parent table
/// cannot loop forever (and an empty table yields an empty path).
///
/// # Panics
/// Panics if a visited index is out of bounds for `parent` or `node_to_ext`.
fn sir_reconstruct_path(target: usize, parent: &[u32], node_to_ext: &[String]) -> Vec<String> {
    let mut chain: Vec<String> = Vec::new();
    let mut node = target;
    let mut steps_left = parent.len();
    while steps_left > 0 {
        steps_left -= 1;
        chain.push(node_to_ext[node].clone());
        let up = parent[node];
        if up == u32::MAX || up == node as u32 {
            break;
        }
        node = up as usize;
    }
    // Collected leaf-to-root; callers expect root-to-leaf.
    chain.reverse();
    chain
}
/// Stateless entry point for the epidemic simulation; a unit struct with no fields.
pub struct EpidemicEngine;
// `Default` simply delegates to `EpidemicEngine::new`.
impl Default for EpidemicEngine {
fn default() -> Self {
Self::new()
}
}
impl EpidemicEngine {
/// Creates an engine. The type is a unit struct, so the value carries no state.
pub fn new() -> Self {
Self
}
/// Runs a discrete-time, SIR-style propagation over `graph`.
///
/// Nodes listed in `infected_ids` start `Infected` with probability 1.0, and
/// nodes listed in `recovered_ids` start `Recovered`. Ids outside the graph's
/// node range are ignored, and an id present in both lists ends up `Recovered`
/// because the recovered pass runs second. Each round, every infected node
/// exposes its susceptible neighbours along the configured direction(s).
/// Exposures combine as `1 - (1 - old) * (1 - p)`. A susceptible node is
/// promoted to `Infected` once its probability is strictly greater than
/// `config.promotion_threshold`.
///
/// The loop ends after `min(config.iterations, MAX_ITERATIONS)` rounds, or
/// earlier after `EXTINCTION_PLATEAU_ITERATIONS` consecutive rounds without a
/// newly exposed node. `config.recovery_rate` is not used.
///
/// Seed statistics (`r0_estimate` denominator, initial peak) are based on the
/// number of *distinct* nodes that actually start infected, so duplicate ids
/// or ids overridden by `recovered_ids` do not skew them.
///
/// # Errors
/// * `M1ndError::NoValidInfectedNodes` when `infected_ids` is empty.
/// * `M1ndError::EpidemicBurnout` when, during the first
///   `BURNOUT_MIN_ITERATIONS` rounds, the infected fraction of all nodes
///   exceeds `config.burnout_threshold`.
pub fn simulate(
    &self,
    graph: &Graph,
    infected_ids: &[NodeId],
    recovered_ids: &[NodeId],
    config: &EpidemicConfig,
) -> M1ndResult<EpidemicResult> {
    let start = Instant::now();
    let n = graph.num_nodes() as usize;
    let mut warnings: Vec<String> = Vec::new();

    // A configured uniform rate is clamped into the supported interval; the
    // caller is told when clamping changed the value.
    let uniform_rate = config.infection_rate.map(|r| {
        let clamped = sir_clamp(r, MIN_INFECTION_RATE, MAX_INFECTION_RATE);
        if (clamped - r).abs() > f32::EPSILON {
            warnings.push(format!(
                "infection_rate clamped from {:.2} to {:.2}",
                r, clamped
            ));
        }
        clamped
    });
    let iterations = config.iterations.min(MAX_ITERATIONS);

    if infected_ids.is_empty() {
        return Err(M1ndError::NoValidInfectedNodes);
    }
    if n == 0 {
        // Nothing to simulate on an empty graph: report an extinct epidemic.
        return Ok(EpidemicResult {
            predictions: Vec::new(),
            summary: EpidemicSummary {
                total_susceptible: 0,
                total_infected: 0,
                total_recovered: 0,
                peak_infection_iteration: 0,
                r0_estimate: 0.0,
                epidemic_extinct: true,
            },
            unreachable_components: Vec::new(),
            warnings,
            unresolved_nodes: Vec::new(),
            elapsed_ms: start.elapsed().as_secs_f64() * 1000.0,
        });
    }

    let node_to_ext = sir_build_node_to_ext(graph);
    let mut compartment = vec![Compartment::Susceptible; n];
    let mut probability = vec![0.0f32; n];
    let mut generation = vec![u32::MAX; n];
    let mut parent = vec![u32::MAX; n];
    let mut contributing: Vec<Vec<usize>> = vec![Vec::new(); n];
    let mut edge_weight_nearest = vec![0.0f32; n];
    // Marks every node supplied as a seed (infected or recovered) so the
    // prediction filter is a constant-time lookup per node.
    let mut is_seed = vec![false; n];

    for &nid in infected_ids {
        let idx = nid.as_usize();
        if idx < n {
            compartment[idx] = Compartment::Infected;
            probability[idx] = 1.0;
            generation[idx] = 0;
            parent[idx] = idx as u32;
            contributing[idx] = vec![idx];
            is_seed[idx] = true;
        }
    }
    for &nid in recovered_ids {
        let idx = nid.as_usize();
        if idx < n {
            compartment[idx] = Compartment::Recovered;
            probability[idx] = 0.0;
            generation[idx] = 0;
            parent[idx] = idx as u32;
            is_seed[idx] = true;
        }
    }

    // Count the nodes that really start infected. Counting raw ids would
    // double-count duplicates and include ids that `recovered_ids` overrode,
    // which would inflate the initial peak and deflate the R0 estimate.
    let initial_infected_count = compartment
        .iter()
        .filter(|&&c| c == Compartment::Infected)
        .count() as u32;

    let spreads_forward = matches!(
        config.direction,
        EpidemicDirection::Forward | EpidemicDirection::Both
    );
    let spreads_backward = matches!(
        config.direction,
        EpidemicDirection::Backward | EpidemicDirection::Both
    );

    let mut peak_infected: u32 = initial_infected_count;
    let mut peak_iteration: u32 = 0;
    let mut total_new_infections_sum: u32 = 0;
    let mut consecutive_zero_new = 0u32;
    let mut epidemic_extinct = false;
    let mut new_probability = vec![0.0f32; n];

    for t in 0..iterations {
        // Exposures of this round are accumulated in `new_probability`, so
        // every source transmits with its probability from the previous round.
        new_probability.copy_from_slice(&probability);

        for src in 0..n {
            if compartment[src] != Compartment::Infected {
                continue;
            }
            let src_prob = probability[src];
            if src_prob <= 0.0 {
                continue;
            }
            // Stable for the whole edge scan: only susceptible targets are
            // updated below and `src` is infected.
            let src_gen = generation[src];

            if spreads_forward {
                for edge_pos in graph.csr.out_range(NodeId::new(src as u32)) {
                    let tgt = graph.csr.targets[edge_pos].as_usize();
                    if tgt >= n || compartment[tgt] != Compartment::Susceptible {
                        continue;
                    }
                    let p_transmit = self.sir_compute_edge_transmission(
                        graph,
                        edge_pos,
                        uniform_rate,
                        src,
                        &compartment,
                    );
                    let p_new = p_transmit * src_prob;
                    let old_p = new_probability[tgt];
                    new_probability[tgt] = 1.0 - (1.0 - old_p) * (1.0 - p_new);
                    // Keep the shortest known hop distance and its provenance.
                    if src_gen != u32::MAX && generation[tgt] > src_gen + 1 {
                        generation[tgt] = src_gen + 1;
                        parent[tgt] = src as u32;
                        edge_weight_nearest[tgt] =
                            graph.csr.read_weight(EdgeIdx::new(edge_pos as u32)).get();
                        contributing[tgt] = contributing[src].clone();
                    }
                }
            }

            if spreads_backward {
                for rev_pos in graph.csr.in_range(NodeId::new(src as u32)) {
                    let tgt = graph.csr.rev_sources[rev_pos].as_usize();
                    if tgt >= n || compartment[tgt] != Compartment::Susceptible {
                        continue;
                    }
                    // Edge attributes live on the forward edge record.
                    let fwd_edge_idx = graph.csr.rev_edge_idx[rev_pos].as_usize();
                    let p_transmit = self.sir_compute_edge_transmission(
                        graph,
                        fwd_edge_idx,
                        uniform_rate,
                        src,
                        &compartment,
                    );
                    let p_new = p_transmit * src_prob;
                    let old_p = new_probability[tgt];
                    new_probability[tgt] = 1.0 - (1.0 - old_p) * (1.0 - p_new);
                    if src_gen != u32::MAX && generation[tgt] > src_gen + 1 {
                        generation[tgt] = src_gen + 1;
                        parent[tgt] = src as u32;
                        edge_weight_nearest[tgt] = graph
                            .csr
                            .read_weight(EdgeIdx::new(fwd_edge_idx as u32))
                            .get();
                        contributing[tgt] = contributing[src].clone();
                    }
                }
            }
        }

        // Single pass: count first-time exposures, promote nodes above the
        // threshold, and tally the infected compartment.
        let mut new_infections_this_iter = 0u32;
        let mut current_infected_count = 0u32;
        for i in 0..n {
            if compartment[i] == Compartment::Susceptible {
                if new_probability[i] > 0.0 && probability[i] == 0.0 {
                    new_infections_this_iter += 1;
                }
                if new_probability[i] > config.promotion_threshold {
                    compartment[i] = Compartment::Infected;
                }
            }
            if compartment[i] == Compartment::Infected {
                current_infected_count += 1;
            }
        }
        probability.copy_from_slice(&new_probability);
        total_new_infections_sum += new_infections_this_iter;

        if current_infected_count > peak_infected {
            peak_infected = current_infected_count;
            peak_iteration = t + 1;
        }

        // Burnout is only checked during the early rounds.
        if t < BURNOUT_MIN_ITERATIONS {
            let infected_pct = current_infected_count as f32 / n.max(1) as f32;
            if infected_pct > config.burnout_threshold {
                return Err(M1ndError::EpidemicBurnout {
                    infected_pct: infected_pct * 100.0,
                    iteration: t + 1,
                });
            }
        }

        if new_infections_this_iter == 0 {
            consecutive_zero_new += 1;
            if consecutive_zero_new >= EXTINCTION_PLATEAU_ITERATIONS {
                epidemic_extinct = true;
                break;
            }
        } else {
            consecutive_zero_new = 0;
        }
    }

    // Predictions: non-seed nodes with positive probability, best first.
    let mut prediction_indices: Vec<usize> = (0..n)
        .filter(|&i| !is_seed[i] && probability[i] > 0.0)
        .collect();
    prediction_indices.sort_by(|&a, &b| {
        probability[b]
            .partial_cmp(&probability[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    prediction_indices.truncate(config.top_k);

    let predictions: Vec<EpidemicPrediction> = prediction_indices
        .iter()
        .map(|&i| {
            let label = graph.strings.resolve(graph.nodes.label[i]).to_string();
            let nt = sir_node_type_str(&graph.nodes.node_type[i]);
            let path = sir_reconstruct_path(i, &parent, &node_to_ext);
            let contrib: Vec<String> = contributing[i]
                .iter()
                .filter(|&&c| c < n)
                .map(|&c| node_to_ext[c].clone())
                .collect();
            EpidemicPrediction {
                node_id: node_to_ext[i].clone(),
                label,
                node_type: nt.to_string(),
                infection_probability: probability[i],
                generation: generation[i],
                contributing_infected: contrib,
                transmission_path: path,
                edge_weight_to_nearest: edge_weight_nearest[i],
            }
        })
        .collect();

    let total_infected = compartment
        .iter()
        .filter(|&&c| c == Compartment::Infected)
        .count() as u32;
    let total_recovered = compartment
        .iter()
        .filter(|&&c| c == Compartment::Recovered)
        .count() as u32;
    let total_susceptible = n as u32 - total_infected - total_recovered;

    let r0_estimate = if initial_infected_count > 0 {
        total_new_infections_sum as f32 / initial_infected_count as f32
    } else {
        0.0
    };

    // A node counts as reached when it was exposed at all or is a seed.
    let reachable: Vec<bool> = (0..n)
        .map(|i| probability[i] > 0.0 || compartment[i] != Compartment::Susceptible)
        .collect();
    let unreachable_components = self.find_unreachable_components(graph, &reachable);

    let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
    Ok(EpidemicResult {
        predictions,
        summary: EpidemicSummary {
            total_susceptible,
            total_infected,
            total_recovered,
            peak_infection_iteration: peak_iteration,
            r0_estimate,
            epidemic_extinct,
        },
        unreachable_components,
        warnings,
        unresolved_nodes: Vec::new(),
        elapsed_ms,
    })
}
/// Computes the probability that `src` transmits across the edge `edge_idx`.
///
/// The base value is `uniform_rate` when given, otherwise the per-edge
/// probability derived from weight and relation. It is multiplied by
/// `RECOVERED_FIREWALL_FACTOR` when `src` is in the recovered compartment,
/// and divided by the square root of `src`'s out-degree when that degree
/// exceeds one.
fn sir_compute_edge_transmission(
    &self,
    graph: &Graph,
    edge_idx: usize,
    uniform_rate: Option<f32>,
    src: usize,
    compartment: &[Compartment],
) -> f32 {
    let base = uniform_rate
        .unwrap_or_else(|| self.edge_transmission_probability(graph, edge_idx, None));

    let damped = if compartment[src] == Compartment::Recovered {
        base * RECOVERED_FIREWALL_FACTOR
    } else {
        base
    };

    // High fan-out sources spread their influence across more edges.
    let fan_out = graph.csr.out_range(NodeId::new(src as u32)).len() as f32;
    if fan_out > 1.0 {
        damped / fan_out.sqrt()
    } else {
        damped
    }
}
/// Returns the base transmission probability for the edge `edge_idx`.
///
/// A supplied `uniform_rate` is returned as-is. Otherwise the edge weight is
/// multiplied by the coupling factor of the edge's relation and the product
/// is clamped to `[0.0, MAX_INFECTION_RATE]`.
fn edge_transmission_probability(
    &self,
    graph: &Graph,
    edge_idx: usize,
    uniform_rate: Option<f32>,
) -> f32 {
    match uniform_rate {
        Some(rate) => rate,
        None => {
            let weight = graph.csr.read_weight(EdgeIdx::new(edge_idx as u32)).get();
            let relation = graph.strings.resolve(graph.csr.relations[edge_idx]);
            sir_clamp(
                weight * sir_coupling_factor(relation),
                0.0,
                MAX_INFECTION_RATE,
            )
        }
    }
}
/// Groups the nodes not marked in `reachable` into connected components.
///
/// Edges are followed in both directions, and only through other unreached
/// nodes. Each component is reported with the external id of its
/// lowest-indexed node and its size. Components appear in ascending order of
/// that node index.
///
/// # Panics
/// Panics if `reachable` has fewer entries than the graph has nodes.
fn find_unreachable_components(
    &self,
    graph: &Graph,
    reachable: &[bool],
) -> Vec<UnreachableComponent> {
    let total = graph.num_nodes() as usize;
    if total == 0 {
        return Vec::new();
    }

    let ext_ids = sir_build_node_to_ext(graph);
    let mut seen = vec![false; total];
    let mut found: Vec<UnreachableComponent> = Vec::new();

    for seed in 0..total {
        if reachable[seed] || seen[seed] {
            continue;
        }

        // Breadth-first flood fill restricted to unreached nodes.
        let mut frontier = std::collections::VecDeque::new();
        frontier.push_back(seed);
        seen[seed] = true;
        let mut members = 0u32;

        while let Some(node) = frontier.pop_front() {
            members += 1;

            let downstream = graph
                .csr
                .out_range(NodeId::new(node as u32))
                .into_iter()
                .map(|pos| graph.csr.targets[pos].as_usize());
            let upstream = graph
                .csr
                .in_range(NodeId::new(node as u32))
                .into_iter()
                .map(|pos| graph.csr.rev_sources[pos].as_usize());

            for neighbour in downstream.chain(upstream) {
                if neighbour < total && !seen[neighbour] && !reachable[neighbour] {
                    seen[neighbour] = true;
                    frontier.push_back(neighbour);
                }
            }
        }

        if members > 0 {
            found.push(UnreachableComponent {
                representative_node: ext_ids[seed].clone(),
                node_count: members,
            });
        }
    }

    found
}
}
/// Writes `state` to `path` as pretty-printed JSON.
///
/// The JSON is first written to a sibling file (`path` with its extension
/// replaced by `json.tmp`), which is then renamed over `path`.
///
/// # Errors
/// Returns an error if serialization, the write, or the rename fails.
pub fn save_epidemic_state(state: &EpidemicPersistentState, path: &Path) -> M1ndResult<()> {
    let staging = path.with_extension("json.tmp");
    let payload = serde_json::to_string_pretty(state)?;
    std::fs::write(&staging, payload)?;
    std::fs::rename(&staging, path)?;
    Ok(())
}
/// Loads persisted epidemic state from the JSON file at `path`.
///
/// A missing file is not an error: the default (empty) state is returned.
/// The file is read directly rather than probed with `Path::exists` first.
/// This avoids a check-then-use race, and it stops other I/O failures (for
/// example permission denied) from being mistaken for "no state saved yet".
///
/// # Errors
/// Returns an error if the file exists but cannot be read, or if its contents
/// are not valid JSON for `EpidemicPersistentState`.
pub fn load_epidemic_state(path: &Path) -> M1ndResult<EpidemicPersistentState> {
    match std::fs::read_to_string(path) {
        Ok(data) => {
            let state: EpidemicPersistentState = serde_json::from_str(&data)?;
            Ok(state)
        }
        // Only a genuinely absent file maps to the default state.
        Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
            Ok(EpidemicPersistentState::default())
        }
        Err(err) => Err(err.into()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;
    use crate::types::*;

    /// Baseline configuration: uniform rate 0.2, forward spreading, burnout
    /// effectively disabled (threshold above 100%), promotion on any exposure.
    fn default_config() -> EpidemicConfig {
        EpidemicConfig {
            infection_rate: Some(0.2),
            recovery_rate: DEFAULT_RECOVERY_RATE,
            iterations: DEFAULT_ITERATIONS,
            direction: EpidemicDirection::Forward,
            top_k: DEFAULT_TOP_K,
            burnout_threshold: 1.1,
            promotion_threshold: 0.0,
        }
    }

    /// Builds the four-node chain n0 -> n1 -> n2 -> n3 with labels
    /// `module_0` .. `module_3`.
    fn chain_graph() -> Graph {
        let mut g = Graph::new();
        for i in 0..4u32 {
            g.add_node(
                &format!("n{}", i),
                &format!("module_{}", i),
                NodeType::Module,
                &[],
                1.0,
                0.5,
            )
            .unwrap();
        }
        for i in 0..3u32 {
            g.add_edge(
                NodeId::new(i),
                NodeId::new(i + 1),
                "related_to",
                FiniteF32::new(1.0),
                EdgeDirection::Forward,
                false,
                FiniteF32::new(0.8),
            )
            .unwrap();
        }
        g.finalize().unwrap();
        g
    }

    /// Infecting the head of the chain must put the next node at risk.
    #[test]
    fn propagation_spreads_along_chain() {
        let g = chain_graph();
        let engine = EpidemicEngine::new();
        let config = default_config();
        let infected = vec![NodeId::new(0)];
        let result = engine.simulate(&g, &infected, &[], &config).unwrap();
        assert!(
            !result.predictions.is_empty(),
            "expect predictions for downstream nodes"
        );
        let has_module1 = result
            .predictions
            .iter()
            .any(|p| p.label.contains("module_1"));
        assert!(has_module1, "module_1 should be predicted as at-risk");
    }

    /// A node seeded as recovered is never reported as a prediction.
    #[test]
    fn recovered_node_blocks_as_firewall() {
        let g = chain_graph();
        let engine = EpidemicEngine::new();
        let config = default_config();
        let infected = vec![NodeId::new(0)];
        let recovered = vec![NodeId::new(1)];
        let result = engine.simulate(&g, &infected, &recovered, &config).unwrap();
        let has_module1 = result.predictions.iter().any(|p| p.label == "module_1");
        assert!(
            !has_module1,
            "recovered node should not appear in predictions"
        );
    }

    /// On a fully connected five-node graph with promotion on any exposure,
    /// the single seed exposes every other node in the first round. The
    /// infected fraction then exceeds the 50% threshold and the run must
    /// abort with `EpidemicBurnout`.
    #[test]
    fn burnout_fires_on_dense_fully_connected_graph() {
        let mut g = Graph::new();
        for i in 0..5u32 {
            g.add_node(
                &format!("n{}", i),
                &format!("mod_{}", i),
                NodeType::Module,
                &[],
                1.0,
                0.5,
            )
            .unwrap();
        }
        for i in 0..5u32 {
            for j in 0..5u32 {
                if i != j {
                    // Fail loudly if the fixture cannot be built; a silently
                    // sparse graph would invalidate the assertion below.
                    g.add_edge(
                        NodeId::new(i),
                        NodeId::new(j),
                        "imports",
                        FiniteF32::new(1.0),
                        EdgeDirection::Forward,
                        false,
                        FiniteF32::new(1.0),
                    )
                    .unwrap();
                }
            }
        }
        g.finalize().unwrap();
        let engine = EpidemicEngine::new();
        let mut config = default_config();
        config.infection_rate = Some(0.95);
        config.burnout_threshold = 0.5;
        config.promotion_threshold = 0.0;
        let infected = vec![NodeId::new(0)];
        let result = engine.simulate(&g, &infected, &[], &config);
        assert!(
            matches!(
                result,
                Err(crate::error::M1ndError::EpidemicBurnout { .. })
            ),
            "expected EpidemicBurnout on a fully connected graph, got: {:?}",
            result
        );
    }

    /// Raising the promotion threshold can only shrink the prediction set.
    #[test]
    fn high_promotion_threshold_reduces_predictions() {
        let g = chain_graph();
        let engine = EpidemicEngine::new();
        let mut config_low = default_config();
        config_low.promotion_threshold = 0.0;
        let mut config_high = default_config();
        config_high.promotion_threshold = 0.99;
        let infected = vec![NodeId::new(0)];
        let result_low = engine.simulate(&g, &infected, &[], &config_low).unwrap();
        let result_high = engine.simulate(&g, &infected, &[], &config_high).unwrap();
        assert!(
            result_high.predictions.len() <= result_low.predictions.len(),
            "high promotion_threshold should yield <= predictions than low threshold"
        );
    }

    /// An empty seed list is rejected up front.
    #[test]
    fn empty_infected_returns_error() {
        let g = chain_graph();
        let engine = EpidemicEngine::new();
        let config = default_config();
        let result = engine.simulate(&g, &[], &[], &config);
        assert!(matches!(
            result,
            Err(crate::error::M1ndError::NoValidInfectedNodes)
        ));
    }

    /// Every reported prediction carries a strictly positive probability.
    #[test]
    fn all_predictions_have_positive_probability() {
        let g = chain_graph();
        let engine = EpidemicEngine::new();
        let config = default_config();
        let infected = vec![NodeId::new(0)];
        let result = engine.simulate(&g, &infected, &[], &config).unwrap();
        for pred in &result.predictions {
            assert!(
                pred.infection_probability > 0.0,
                "prediction {} has zero probability",
                pred.label
            );
        }
    }

    /// With `Both`, a seed in the middle of the chain reaches its successor
    /// and its predecessor.
    #[test]
    fn bidirectional_direction_reaches_both_ends() {
        let g = chain_graph();
        let engine = EpidemicEngine::new();
        let mut config = default_config();
        config.direction = EpidemicDirection::Both;
        config.infection_rate = Some(0.9);
        let infected = vec![NodeId::new(2)];
        let result = engine.simulate(&g, &infected, &[], &config).unwrap();
        let has_downstream = result.predictions.iter().any(|p| p.label == "module_3");
        let has_upstream = result.predictions.iter().any(|p| p.label == "module_1");
        assert!(has_downstream, "module_3 (downstream) should be at risk");
        assert!(
            has_upstream,
            "module_1 (upstream) should be at risk in bidirectional mode"
        );
    }

    /// An empty graph yields an empty, extinct result even when seed ids are
    /// supplied.
    #[test]
    fn zero_node_graph_with_infected_returns_empty_result() {
        let mut g = Graph::new();
        g.finalize().unwrap();
        let engine = EpidemicEngine::new();
        let config = default_config();
        let infected = vec![NodeId::new(0)];
        let result = engine.simulate(&g, &infected, &[], &config).unwrap();
        assert_eq!(result.predictions.len(), 0);
        assert_eq!(result.summary.total_infected, 0);
        assert!(result.summary.epidemic_extinct);
    }
}