use crate::epidemic::{EpidemicConfig, EpidemicDirection, EpidemicEngine, EpidemicResult};
use crate::error::{M1ndError, M1ndResult};
use crate::flow::{FlowConfig, FlowEngine, FlowSimulationResult};
use crate::graph::Graph;
use crate::types::{FiniteF32, NodeId};
use serde::Serialize;
use std::collections::HashSet;
/// Category of taint being tracked. It selects which boundary (sanitizer/guard)
/// name patterns `TaintEngine::analyze` looks for in node labels.
#[derive(Clone, Debug)]
pub enum TaintType {
/// Untrusted input. Boundaries are validation/sanitization-style names such as
/// `validate`, `sanitize` or `escape`.
UserInput,
/// Sensitive data. Boundaries are auth/crypto/redaction-style names such as
/// `auth`, `encrypt` or `redact`.
SensitiveData,
/// Caller-supplied label substrings, matched case-insensitively. No built-in
/// patterns are used for this variant.
Custom { boundary_patterns: Vec<String> },
}
/// Tuning knobs for `TaintEngine::analyze`.
#[derive(Clone, Debug)]
pub struct TaintConfig {
/// Maximum flow-simulation depth; values above 255 are capped to 255.
pub max_depth: u32,
/// Particle count passed to the flow simulation.
pub num_particles: u32,
/// Number of epidemic-simulation iterations.
pub epidemic_iterations: u32,
/// Minimum infection probability for a node to count as reached by taint
/// (boundary hit) and to be considered for leak reporting.
pub min_probability: f32,
/// Which set of boundary patterns is matched against node labels.
pub taint_type: TaintType,
}
impl Default for TaintConfig {
fn default() -> Self {
Self {
max_depth: 15,
num_particles: 4,
epidemic_iterations: 50,
min_probability: 0.01,
taint_type: TaintType::UserInput,
}
}
}
/// A graph node whose label matched a boundary pattern, plus whether taint reached it.
#[derive(Clone, Debug, Serialize)]
pub struct BoundaryCheck {
/// External (string) id of the node.
pub node_id: String,
/// Node label with its original casing.
pub label: String,
/// The pattern that matched the lowercased label.
pub boundary_type: String,
/// `true` when the node's epidemic infection probability is at least `min_probability`.
pub taint_reached: bool,
/// Epidemic infection probability for the node, or 0.0 if it was not among the predictions.
pub infection_probability: f32,
}
/// A tainted node reached along a transmission path that contains no boundary node.
#[derive(Clone, Debug, Serialize)]
pub struct TaintLeak {
/// First node of the transmission path.
pub entry_node: String,
/// The tainted, non-boundary node at which the leak is reported.
pub exit_node: String,
/// Epidemic infection probability of `exit_node`.
pub probability: f32,
/// External ids along the transmission path, as reported by the epidemic engine.
pub path: Vec<String>,
}
/// Full output of `TaintEngine::analyze`.
#[derive(Clone, Debug, Serialize)]
pub struct TaintResult {
/// Boundaries reached by taint (`taint_reached == true`).
pub boundary_hits: Vec<BoundaryCheck>,
/// Boundaries not reached by taint.
pub boundary_misses: Vec<BoundaryCheck>,
/// Unguarded taint paths, sorted by descending probability.
pub leaks: Vec<TaintLeak>,
/// Raw particle-flow simulation output.
pub flow_result: FlowSimulationResult,
/// Raw epidemic simulation output.
pub epidemic_result: EpidemicResult,
/// Aggregate risk clamped to `[0, 1]`; higher means riskier.
pub risk_score: f32,
/// Counts and timing for quick reporting.
pub summary: TaintSummary,
}
/// Headline numbers from a taint analysis run.
#[derive(Clone, Debug, Serialize)]
pub struct TaintSummary {
/// Number of entry node ids supplied (duplicates are counted).
pub entry_points: usize,
/// Number of epidemic predictions plus the number of entry points.
/// NOTE(review): every prediction is counted, including those below `min_probability`.
pub total_nodes_reached: usize,
/// Length of `TaintResult::boundary_hits`.
pub boundary_hits: usize,
/// Length of `TaintResult::boundary_misses`.
pub boundary_misses: usize,
/// Length of `TaintResult::leaks`.
pub leaks_found: usize,
/// Peak epidemic infection probability (0.0 when there are no predictions).
pub max_infection_probability: f32,
/// Elapsed time of the analysis in milliseconds.
pub elapsed_ms: f64,
}
/// Built-in boundary name patterns for `taint_type`, in match-priority order.
///
/// `TaintType::Custom` contributes no built-in patterns. Its own list lives in
/// the config and is matched separately by the caller.
fn default_boundary_patterns(taint_type: &TaintType) -> Vec<&'static str> {
    // Validation / sanitization / injection-defence names.
    const USER_INPUT: &[&str] = &[
        "validate",
        "sanitize",
        "escape",
        "encode",
        "clean",
        "filter",
        "check_input",
        "verify",
        "parse_input",
        "whitelist",
        "allowlist",
        "blocklist",
        "blacklist",
        "csrf",
        "xss",
        "sql_injection",
        "injection",
    ];
    // Authentication / authorization / crypto / redaction names.
    const SENSITIVE_DATA: &[&str] = &[
        "auth",
        "authenticate",
        "authorize",
        "verify_token",
        "check_permission",
        "check_role",
        "is_admin",
        "encrypt",
        "decrypt",
        "hash",
        "hmac",
        "sign",
        "mask",
        "redact",
        "anonymize",
        "obfuscate",
        "access_control",
        "permission",
        "credential",
    ];

    let table: &[&'static str] = match taint_type {
        TaintType::UserInput => USER_INPUT,
        TaintType::SensitiveData => SENSITIVE_DATA,
        TaintType::Custom { .. } => &[],
    };
    table.to_vec()
}
/// Taint analysis over a [`Graph`]. It combines a particle-flow simulation with a
/// forward epidemic simulation.
pub struct TaintEngine;

impl TaintEngine {
    /// Traces taint from `entry_node_ids` and reports two things:
    /// - which boundary nodes the taint reached;
    /// - which tainted nodes it reached without crossing a boundary.
    ///
    /// Steps:
    /// 1. Run `FlowEngine::simulate` from the entry points, with `config.max_depth`
    ///    capped at 255. Its result is returned as-is and does not feed the score.
    /// 2. Run a forward `EpidemicEngine::simulate` from the same entry points.
    /// 3. Every node whose lowercased label contains a boundary pattern becomes a
    ///    [`BoundaryCheck`]. It is a *hit* if its infection probability is at least
    ///    `config.min_probability`, otherwise a *miss*.
    /// 4. Some predictions become a [`TaintLeak`]. Qualifying predictions:
    ///    - are not boundaries themselves;
    ///    - are at or above `config.min_probability`;
    ///    - have a non-empty transmission path that contains no boundary node.
    ///
    ///    Leaks are sorted by descending probability.
    /// 5. Hits, misses, leaks and epidemic spread are folded into `risk_score`.
    ///
    /// # Errors
    ///
    /// Returns [`M1ndError::NoEntryPoints`] if `entry_node_ids` is empty. Errors
    /// from the flow and epidemic simulations are propagated.
    pub fn analyze(
        graph: &Graph,
        entry_node_ids: &[NodeId],
        config: &TaintConfig,
    ) -> M1ndResult<TaintResult> {
        let start = std::time::Instant::now();
        if entry_node_ids.is_empty() {
            return Err(M1ndError::NoEntryPoints);
        }

        // FlowConfig stores the depth as a u8; saturate instead of wrapping.
        let flow_config = FlowConfig {
            max_depth: config.max_depth.min(255) as u8,
            ..FlowConfig::with_defaults()
        };
        let flow_engine = FlowEngine::new();
        let flow_result =
            flow_engine.simulate(graph, entry_node_ids, config.num_particles, &flow_config)?;

        // NOTE(review): recovery_rate 0.0 presumably keeps tainted nodes infected
        // for the whole run; confirm against EpidemicEngine.
        let epidemic_config = EpidemicConfig {
            iterations: config.epidemic_iterations,
            infection_rate: None,
            recovery_rate: 0.0,
            top_k: 200,
            direction: EpidemicDirection::Forward,
            burnout_threshold: 1.0,
            promotion_threshold: 0.5,
        };
        let epidemic_engine = EpidemicEngine::new();
        let epidemic_result =
            epidemic_engine.simulate(graph, entry_node_ids, &[], &epidemic_config)?;

        let boundary_patterns = default_boundary_patterns(&config.taint_type);
        let custom_patterns: &[String] = match &config.taint_type {
            TaintType::Custom {
                boundary_patterns: custom,
            } => custom.as_slice(),
            TaintType::UserInput | TaintType::SensitiveData => &[],
        };

        // Dense internal index -> external string id. This is empty for indices
        // without a mapping.
        let n = graph.num_nodes() as usize;
        let mut node_to_ext: Vec<String> = vec![String::new(); n];
        for (interned, node_id) in &graph.id_to_node {
            if let Some(slot) = node_to_ext.get_mut(node_id.as_usize()) {
                *slot = graph.strings.resolve(*interned).to_string();
            }
        }

        // External id -> infection probability. `or_insert` keeps the first entry
        // per id, matching a front-to-back scan of the predictions.
        let mut probability_by_id: std::collections::HashMap<&str, f32> =
            std::collections::HashMap::with_capacity(epidemic_result.predictions.len());
        for pred in &epidemic_result.predictions {
            probability_by_id
                .entry(pred.node_id.as_str())
                .or_insert(pred.infection_probability);
        }
        let infected_nodes: HashSet<&str> = epidemic_result
            .predictions
            .iter()
            .filter(|p| p.infection_probability >= config.min_probability)
            .map(|p| p.node_id.as_str())
            .collect();

        let mut boundary_hits = Vec::new();
        let mut boundary_misses = Vec::new();
        for (i, ext_id) in node_to_ext.iter().enumerate() {
            let label = graph.strings.resolve(graph.nodes.label[i]);
            if let Some(boundary_type) =
                detect_boundary_type(&label.to_lowercase(), &boundary_patterns, custom_patterns)
            {
                let check = BoundaryCheck {
                    node_id: ext_id.clone(),
                    label: label.to_string(),
                    boundary_type,
                    taint_reached: infected_nodes.contains(ext_id.as_str()),
                    infection_probability: probability_by_id
                        .get(ext_id.as_str())
                        .copied()
                        .unwrap_or(0.0),
                };
                if check.taint_reached {
                    boundary_hits.push(check);
                } else {
                    boundary_misses.push(check);
                }
            }
        }

        let boundary_ext_ids: HashSet<&str> = boundary_hits
            .iter()
            .chain(boundary_misses.iter())
            .map(|b| b.node_id.as_str())
            .collect();

        // A leak is a sufficiently tainted, non-boundary node whose transmission
        // path never passes through a boundary.
        let mut leaks = Vec::new();
        for pred in &epidemic_result.predictions {
            if pred.infection_probability < config.min_probability
                || boundary_ext_ids.contains(pred.node_id.as_str())
            {
                continue;
            }
            let path = &pred.transmission_path;
            let crosses_boundary = path
                .iter()
                .any(|node| boundary_ext_ids.contains(node.as_str()));
            if crosses_boundary {
                continue;
            }
            if let Some(entry) = path.first() {
                leaks.push(TaintLeak {
                    entry_node: entry.clone(),
                    exit_node: pred.node_id.clone(),
                    probability: pred.infection_probability,
                    path: path.clone(),
                });
            }
        }
        leaks.sort_by(|a, b| {
            b.probability
                .partial_cmp(&a.probability)
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        let risk_score =
            compute_risk_score(&boundary_hits, &boundary_misses, &leaks, &epidemic_result);

        // Compute the peak explicitly rather than relying on `predictions` being
        // sorted by probability.
        let max_infection_probability = epidemic_result
            .predictions
            .iter()
            .map(|p| p.infection_probability)
            .fold(0.0_f32, f32::max);

        let elapsed_ms = start.elapsed().as_secs_f64() * 1000.0;
        let summary = TaintSummary {
            entry_points: entry_node_ids.len(),
            total_nodes_reached: epidemic_result.predictions.len() + entry_node_ids.len(),
            boundary_hits: boundary_hits.len(),
            boundary_misses: boundary_misses.len(),
            leaks_found: leaks.len(),
            max_infection_probability,
            elapsed_ms,
        };

        Ok(TaintResult {
            boundary_hits,
            boundary_misses,
            leaks,
            flow_result,
            epidemic_result,
            risk_score,
            summary,
        })
    }
}
/// Returns the boundary pattern contained in `label_lower`, if any.
///
/// `label_lower` must already be lowercased. Matching rules:
/// - `default_patterns` are matched as-is (the built-in lists are all lowercase)
///   and are tried first, in order.
/// - Each entry of `custom_patterns` is lowercased before matching, but the entry
///   is returned with its original casing.
/// - Empty patterns are ignored. `str::contains("")` is always true, so an empty
///   pattern would otherwise turn every node into a boundary.
fn detect_boundary_type(
    label_lower: &str,
    default_patterns: &[&str],
    custom_patterns: &[String],
) -> Option<String> {
    if let Some(hit) = default_patterns
        .iter()
        .find(|p| !p.is_empty() && label_lower.contains(**p))
    {
        return Some((*hit).to_string());
    }
    custom_patterns
        .iter()
        .filter(|p| !p.is_empty())
        .find(|p| label_lower.contains(p.to_lowercase().as_str()))
        .cloned()
}
/// Combines boundary coverage, leak count and epidemic spread into a risk score
/// in `[0, 1]`.
///
/// `score = 0.4 * miss_ratio + 0.35 * leak_factor + 0.25 * spread_factor`, where
/// each component lies in `[0, 1]`:
/// - `miss_ratio`: misses / (hits + misses). It is 0.5 when there are no
///   boundaries at all.
/// - `leak_factor`: `ln(1 + leaks) / ln(1 + 100)`, saturating at 1.0 from 100
///   leaks on.
/// - `spread_factor`: infected / (susceptible + infected + recovered) from the
///   epidemic summary. It is 0.0 when nothing is infected.
fn compute_risk_score(
    hits: &[BoundaryCheck],
    misses: &[BoundaryCheck],
    leaks: &[TaintLeak],
    epidemic: &EpidemicResult,
) -> f32 {
    const MISS_WEIGHT: f32 = 0.4;
    const LEAK_WEIGHT: f32 = 0.35;
    const SPREAD_WEIGHT: f32 = 0.25;
    // Leak count at which the leak component reaches its maximum of 1.0.
    const LEAK_SATURATION: f32 = 100.0;
    // Miss ratio used when no boundary was detected at all.
    const NO_BOUNDARY_MISS_RATIO: f32 = 0.5;

    let total_boundaries = hits.len() + misses.len();
    let miss_ratio = if total_boundaries == 0 {
        NO_BOUNDARY_MISS_RATIO
    } else {
        misses.len() as f32 / total_boundaries as f32
    };

    // Log-scaled so the first few leaks weigh the most. The value is capped at
    // 1.0 so the weights keep their meaning; without the cap, more than
    // LEAK_SATURATION leaks would push this component past 1. No leaks gives
    // ln(1) = 0.
    let leak_factor =
        ((1.0 + leaks.len() as f32).ln() / (1.0 + LEAK_SATURATION).ln()).min(1.0);

    let summary = &epidemic.summary;
    let spread_factor = if summary.total_infected > 0 {
        let population =
            summary.total_susceptible + summary.total_infected + summary.total_recovered;
        if population > 0 {
            summary.total_infected as f32 / population as f32
        } else {
            0.0
        }
    } else {
        0.0
    };

    (MISS_WEIGHT * miss_ratio + LEAK_WEIGHT * leak_factor + SPREAD_WEIGHT * spread_factor)
        .clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::*;
    use crate::types::{EdgeDirection, FiniteF32, NodeId, NodeType};

    /// Linear call chain `handle_request -> process_data -> send_response` in
    /// which no label matches a boundary pattern.
    fn build_no_boundary_graph() -> Graph {
        let mut g = Graph::new();
        g.add_node(
            "entry",
            "handle_request",
            NodeType::Function,
            &["handler"],
            0.0,
            0.5,
        )
        .unwrap();
        g.add_node(
            "proc",
            "process_data",
            NodeType::Function,
            &["data"],
            0.0,
            0.3,
        )
        .unwrap();
        g.add_node(
            "out",
            "send_response",
            NodeType::Function,
            &["output"],
            0.0,
            0.2,
        )
        .unwrap();
        g.add_edge(
            NodeId::new(0),
            NodeId::new(1),
            "calls",
            FiniteF32::new(0.8),
            EdgeDirection::Forward,
            false,
            FiniteF32::new(0.5),
        )
        .unwrap();
        g.add_edge(
            NodeId::new(1),
            NodeId::new(2),
            "calls",
            FiniteF32::new(0.7),
            EdgeDirection::Forward,
            false,
            FiniteF32::new(0.4),
        )
        .unwrap();
        g.finalize().unwrap();
        g
    }

    /// Call chain `handle_request -> validate_input -> process_data -> send_response`.
    /// `validate_input` matches the `validate` user-input boundary pattern.
    fn build_with_boundary_graph() -> Graph {
        let mut g = Graph::new();
        g.add_node(
            "entry",
            "handle_request",
            NodeType::Function,
            &["handler"],
            0.0,
            0.5,
        )
        .unwrap();
        g.add_node(
            "val",
            "validate_input",
            NodeType::Function,
            &["security"],
            0.0,
            0.4,
        )
        .unwrap();
        g.add_node(
            "proc",
            "process_data",
            NodeType::Function,
            &["data"],
            0.0,
            0.3,
        )
        .unwrap();
        g.add_node(
            "out",
            "send_response",
            NodeType::Function,
            &["output"],
            0.0,
            0.2,
        )
        .unwrap();
        g.add_edge(
            NodeId::new(0),
            NodeId::new(1),
            "calls",
            FiniteF32::new(0.9),
            EdgeDirection::Forward,
            false,
            FiniteF32::new(0.5),
        )
        .unwrap();
        g.add_edge(
            NodeId::new(1),
            NodeId::new(2),
            "calls",
            FiniteF32::new(0.8),
            EdgeDirection::Forward,
            false,
            FiniteF32::new(0.4),
        )
        .unwrap();
        g.add_edge(
            NodeId::new(2),
            NodeId::new(3),
            "calls",
            FiniteF32::new(0.7),
            EdgeDirection::Forward,
            false,
            FiniteF32::new(0.4),
        )
        .unwrap();
        g.finalize().unwrap();
        g
    }

    /// Wraps `summary` in an otherwise empty `EpidemicResult` for risk-score tests.
    fn epidemic_with_summary(summary: crate::epidemic::EpidemicSummary) -> EpidemicResult {
        EpidemicResult {
            predictions: vec![],
            summary,
            unreachable_components: vec![],
            warnings: vec![],
            unresolved_nodes: vec![],
            elapsed_ms: 1.0,
        }
    }

    #[test]
    fn detect_boundary_validates() {
        let patterns = ["validate", "sanitize", "auth"];
        let custom: &[String] = &[];
        assert!(detect_boundary_type("validate_input", &patterns, custom).is_some());
        assert!(detect_boundary_type("process_data", &patterns, custom).is_none());
        assert!(detect_boundary_type("sanitize_html", &patterns, custom).is_some());
    }

    #[test]
    fn risk_score_no_boundaries_is_medium() {
        let epidemic = epidemic_with_summary(crate::epidemic::EpidemicSummary {
            total_susceptible: 10,
            total_infected: 0,
            total_recovered: 0,
            peak_infection_iteration: 0,
            r0_estimate: 0.0,
            epidemic_extinct: true,
        });
        let score = compute_risk_score(&[], &[], &[], &epidemic);
        assert!(
            (score - 0.2).abs() < 0.1,
            "No boundaries = medium risk, got {score}"
        );
    }

    #[test]
    fn taint_no_boundary_graph_analysis() {
        let g = build_no_boundary_graph();
        let config = TaintConfig::default();
        let result = TaintEngine::analyze(&g, &[NodeId::new(0)], &config).unwrap();
        assert_eq!(result.summary.entry_points, 1);
        assert!(result.summary.total_nodes_reached > 0);
    }

    #[test]
    fn taint_with_boundary_detects_validation() {
        let g = build_with_boundary_graph();
        let config = TaintConfig::default();
        let result = TaintEngine::analyze(&g, &[NodeId::new(0)], &config).unwrap();
        let all_boundaries: Vec<_> = result
            .boundary_hits
            .iter()
            .chain(result.boundary_misses.iter())
            .collect();
        assert!(
            all_boundaries.iter().any(|b| b.label.contains("validate")),
            "Should detect validate_input as boundary, found: {:?}",
            all_boundaries
        );
    }

    #[test]
    fn taint_empty_entry_returns_error() {
        let g = build_no_boundary_graph();
        let config = TaintConfig::default();
        let result = TaintEngine::analyze(&g, &[], &config);
        assert!(result.is_err());
    }

    #[test]
    fn risk_score_bounded_zero_to_one() {
        let leak = TaintLeak {
            entry_node: "a".into(),
            exit_node: "b".into(),
            probability: 0.9,
            path: vec!["a".into(), "b".into()],
        };
        let miss = BoundaryCheck {
            node_id: "x".into(),
            label: "validate".into(),
            boundary_type: "validate".into(),
            taint_reached: false,
            infection_probability: 0.0,
        };
        let epidemic = epidemic_with_summary(crate::epidemic::EpidemicSummary {
            total_susceptible: 5,
            total_infected: 5,
            total_recovered: 0,
            peak_infection_iteration: 10,
            r0_estimate: 2.0,
            epidemic_extinct: false,
        });
        let score = compute_risk_score(&[], &[miss], &[leak], &epidemic);
        assert!(
            (0.0..=1.0).contains(&score),
            "Risk score out of range: {score}"
        );
        assert!(
            score > 0.5,
            "All misses + leaks + high spread should be high risk, got {score}"
        );
    }
}