use ahash::{AHashMap, AHashSet};
#[cfg(test)]
use crate::GraphEdge;
use crate::progress::ProgressCallback;
use crate::{GraphEntity, errors::SqliteGraphError, graph::SqliteGraph};
use super::reachability::{can_reach, reachable_from, reverse_reachable_from};
/// Outcome of a taint analysis run.
///
/// Produced by the forward (`propagate_taint_forward*`) and backward
/// (`propagate_taint_backward*`) propagation functions in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct TaintResult {
    /// Source node ids. Forward propagation stores every requested source;
    /// backward propagation stores only the sources that can reach the sink.
    pub sources: AHashSet<i64>,
    /// Sink node ids that taint was found to reach.
    pub sinks_reached: AHashSet<i64>,
    /// Every node visited by the propagation, as reported by the underlying
    /// reachability helpers.
    pub tainted_nodes: AHashSet<i64>,
    /// `(source, sink)` pairs for which the sink is reachable from the source.
    pub source_sink_paths: Vec<(i64, i64)>,
    /// Number of tainted nodes (`tainted_nodes.len()` for results built by
    /// the propagation functions in this module).
    pub size: usize,
}

impl TaintResult {
    /// Creates an empty result: no sources, sinks, tainted nodes or paths.
    pub fn new() -> Self {
        Self {
            sources: AHashSet::new(),
            sinks_reached: AHashSet::new(),
            tainted_nodes: AHashSet::new(),
            source_sink_paths: Vec::new(),
            size: 0,
        }
    }

    /// Returns `true` if `node` is contained in `tainted_nodes`.
    pub fn is_tainted(&self, node: i64) -> bool {
        self.tainted_nodes.contains(&node)
    }

    /// Returns `true` if at least one source→sink pair was recorded.
    pub fn has_vulnerability(&self) -> bool {
        !self.source_sink_paths.is_empty()
    }

    /// Returns the tainted node ids in ascending order, for deterministic output.
    pub fn sorted_tainted_nodes(&self) -> Vec<i64> {
        let mut nodes: Vec<i64> = self.tainted_nodes.iter().copied().collect();
        // Plain integers: stability is irrelevant, so the unstable sort is fine.
        nodes.sort_unstable();
        nodes
    }

    /// Returns a copy of `source_sink_paths` sorted lexicographically
    /// (by source, then sink).
    pub fn sorted_vulnerabilities(&self) -> Vec<(i64, i64)> {
        let mut paths = self.source_sink_paths.clone();
        paths.sort_unstable();
        paths
    }
}

impl Default for TaintResult {
    /// Equivalent to [`TaintResult::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Decides whether a graph node is a taint *source*.
///
/// Implementations receive both the node id and its loaded [`GraphEntity`];
/// [`discover_sources_and_sinks`] calls this once per entity in the graph.
pub trait SourceCallback {
    /// Returns `true` if `node` (whose entity is `entity`) should be treated as a source.
    fn is_source(&self, node: i64, entity: &GraphEntity) -> bool;
}
/// Decides whether a graph node is a taint *sink*.
///
/// Implementations receive both the node id and its loaded [`GraphEntity`];
/// [`discover_sources_and_sinks`] calls this once per entity in the graph.
pub trait SinkCallback {
    /// Returns `true` if `node` (whose entity is `entity`) should be treated as a sink.
    fn is_sink(&self, node: i64, entity: &GraphEntity) -> bool;
}
/// Default [`SourceCallback`] that classifies nodes from their JSON metadata.
///
/// A node is a source when `data["kind"]` is `"source"`, `"untrusted"` or
/// `"user_input"`, or when `data["taint"]` is `"source"`. Missing or
/// non-string fields never match.
pub struct MetadataSourceDetector;

impl SourceCallback for MetadataSourceDetector {
    fn is_source(&self, _node: i64, entity: &GraphEntity) -> bool {
        // String value of a metadata field; `None` if absent or not a string.
        let text_field = |key: &str| entity.data.get(key).and_then(|value| value.as_str());

        matches!(text_field("kind"), Some("source" | "untrusted" | "user_input"))
            || text_field("taint") == Some("source")
    }
}
/// Default [`SinkCallback`] that classifies nodes from their JSON metadata.
///
/// A node is a sink when `data["kind"]` is one of `"sink"`, `"sql_query"`,
/// `"html_output"`, `"command"`, `"file_operation"`, or when
/// `data["operation"]` is one of `"execute"`, `"query"`, `"render"`,
/// `"write"`. Missing or non-string fields never match.
pub struct MetadataSinkDetector;

impl SinkCallback for MetadataSinkDetector {
    fn is_sink(&self, _node: i64, entity: &GraphEntity) -> bool {
        // String value of a metadata field; `None` if absent or not a string.
        let text_field = |key: &str| entity.data.get(key).and_then(|value| value.as_str());

        let sink_kind = matches!(
            text_field("kind"),
            Some("sink" | "sql_query" | "html_output" | "command" | "file_operation")
        );
        sink_kind
            || matches!(
                text_field("operation"),
                Some("execute" | "query" | "render" | "write")
            )
    }
}
/// Propagates taint forward from `sources` and reports which `sinks` it reaches.
///
/// Every node returned by [`reachable_from`] for any source is marked tainted.
/// A `(source, sink)` pair is recorded for each sink contained in that
/// source's reachable set, so each source is traversed exactly once.
///
/// Repeated ids in `sources` or `sinks` are processed once. Pairs in
/// `source_sink_paths` are ordered by the first occurrence of the source in
/// `sources`, then by the first occurrence of the sink in `sinks`.
///
/// # Errors
///
/// Returns any [`SqliteGraphError`] raised while traversing the graph.
pub fn propagate_taint_forward(
    graph: &SqliteGraph,
    sources: &[i64],
    sinks: &[i64],
) -> Result<TaintResult, SqliteGraphError> {
    forward_taint(graph, sources, sinks, |_, _, _| {})
}

/// Same as [`propagate_taint_forward`], reporting progress through `progress`.
///
/// `progress.on_progress` is called once per distinct source after it has
/// been traversed. `progress.on_complete` is called once the analysis
/// succeeds; it is not called if an error is returned.
///
/// # Errors
///
/// Returns any [`SqliteGraphError`] raised while traversing the graph.
pub fn propagate_taint_forward_with_progress<F>(
    graph: &SqliteGraph,
    sources: &[i64],
    sinks: &[i64],
    progress: &F,
) -> Result<TaintResult, SqliteGraphError>
where
    F: ProgressCallback,
{
    let result = forward_taint(graph, sources, sinks, |processed, total, tainted| {
        progress.on_progress(
            processed,
            Some(total),
            &format!(
                "Taint propagation: {}/{} sources processed, {} tainted nodes",
                processed, total, tainted
            ),
        );
    })?;
    progress.on_complete();
    Ok(result)
}

/// Shared forward-propagation core.
///
/// `on_source_done(processed, total, tainted_so_far)` is invoked after each
/// distinct source has been traversed.
fn forward_taint(
    graph: &SqliteGraph,
    sources: &[i64],
    sinks: &[i64],
    mut on_source_done: impl FnMut(usize, usize, usize),
) -> Result<TaintResult, SqliteGraphError> {
    // Deduplicate while keeping first-occurrence order: `insert` returns
    // `false` for ids that were already seen.
    let mut source_set: AHashSet<i64> = AHashSet::new();
    let unique_sources: Vec<i64> = sources
        .iter()
        .copied()
        .filter(|&source| source_set.insert(source))
        .collect();
    let mut sink_set: AHashSet<i64> = AHashSet::new();
    let unique_sinks: Vec<i64> = sinks
        .iter()
        .copied()
        .filter(|&sink| sink_set.insert(sink))
        .collect();

    let total = unique_sources.len();
    let mut tainted_nodes: AHashSet<i64> = AHashSet::new();
    let mut sinks_reached: AHashSet<i64> = AHashSet::new();
    let mut source_sink_paths: Vec<(i64, i64)> = Vec::new();

    for (idx, &source) in unique_sources.iter().enumerate() {
        // One traversal per source answers both "which nodes are tainted" and
        // "which sinks does this source reach"; no per-pair traversal needed.
        let reachable = reachable_from(graph, source)?;
        for &sink in &unique_sinks {
            if reachable.contains(&sink) {
                sinks_reached.insert(sink);
                source_sink_paths.push((source, sink));
            }
        }
        tainted_nodes.extend(reachable);
        on_source_done(idx + 1, total, tainted_nodes.len());
    }

    let size = tainted_nodes.len();
    Ok(TaintResult {
        sources: source_set,
        sinks_reached,
        tainted_nodes,
        source_sink_paths,
        size,
    })
}
/// Traces taint backward from `sink` to find which of `sources` can reach it.
///
/// All nodes returned by [`reverse_reachable_from`] for `sink` are reported as
/// tainted. In the result, `sources` holds only the given sources found among
/// those nodes, and `sinks_reached` is `{sink}`. `source_sink_paths` has one
/// `(source, sink)` pair per affecting source, ordered by the source's first
/// occurrence in `sources`.
///
/// # Errors
///
/// Returns any [`SqliteGraphError`] raised while traversing the graph.
pub fn propagate_taint_backward(
    graph: &SqliteGraph,
    sink: i64,
    sources: &[i64],
) -> Result<TaintResult, SqliteGraphError> {
    backward_taint(graph, sink, sources)
}

/// Same as [`propagate_taint_backward`], reporting progress through `progress`.
///
/// Emits a single `on_progress(1, Some(1), ..)` event after the traversal,
/// followed by `on_complete`. Neither is called if an error is returned.
///
/// # Errors
///
/// Returns any [`SqliteGraphError`] raised while traversing the graph.
pub fn propagate_taint_backward_with_progress<F>(
    graph: &SqliteGraph,
    sink: i64,
    sources: &[i64],
    progress: &F,
) -> Result<TaintResult, SqliteGraphError>
where
    F: ProgressCallback,
{
    let result = backward_taint(graph, sink, sources)?;
    progress.on_progress(
        1,
        Some(1),
        &format!(
            "Backward taint propagation: from sink {}, {} sources found",
            sink,
            result.sources.len()
        ),
    );
    progress.on_complete();
    Ok(result)
}

/// Shared backward-propagation core.
fn backward_taint(
    graph: &SqliteGraph,
    sink: i64,
    sources: &[i64],
) -> Result<TaintResult, SqliteGraphError> {
    let ancestors = reverse_reachable_from(graph, sink)?;

    // Keep the sources that can reach the sink, deduplicated in input order
    // so the reported paths do not depend on hash-set iteration order.
    let mut affecting_sources: AHashSet<i64> = AHashSet::new();
    let source_sink_paths: Vec<(i64, i64)> = sources
        .iter()
        .copied()
        .filter(|source| ancestors.contains(source) && affecting_sources.insert(*source))
        .map(|source| (source, sink))
        .collect();

    let mut sinks_reached = AHashSet::new();
    sinks_reached.insert(sink);
    let size = ancestors.len();
    Ok(TaintResult {
        sources: affecting_sources,
        sinks_reached,
        tainted_nodes: ancestors,
        source_sink_paths,
        size,
    })
}
pub fn sink_reachability_analysis(
graph: &SqliteGraph,
sources: &[i64],
sinks: &[i64],
) -> Result<AHashMap<i64, Vec<i64>>, SqliteGraphError> {
let mut result: AHashMap<i64, Vec<i64>> = AHashMap::new();
for &sink in sinks {
let taint_result = propagate_taint_backward(graph, sink, sources)?;
if !taint_result.sources.is_empty() {
let affecting_sources: Vec<i64> = taint_result.sources.iter().copied().collect();
result.insert(sink, affecting_sources);
}
}
Ok(result)
}
pub fn sink_reachability_analysis_with_progress<F>(
graph: &SqliteGraph,
sources: &[i64],
sinks: &[i64],
progress: &F,
) -> Result<AHashMap<i64, Vec<i64>>, SqliteGraphError>
where
F: ProgressCallback,
{
let mut result: AHashMap<i64, Vec<i64>> = AHashMap::new();
let total = sinks.len();
for (idx, &sink) in sinks.iter().enumerate() {
let taint_result = propagate_taint_backward(graph, sink, sources)?;
if !taint_result.sources.is_empty() {
let affecting_sources: Vec<i64> = taint_result.sources.iter().copied().collect();
result.insert(sink, affecting_sources);
}
progress.on_progress(
idx + 1,
Some(total),
&format!(
"Sink reachability: {}/{} sinks analyzed, {} vulnerabilities found",
idx + 1,
total,
result.len()
),
);
}
progress.on_complete();
Ok(result)
}
/// Scans every entity in `graph` and classifies it with the given detectors.
///
/// Returns `(sources, sinks)` as node ids, in the order yielded by
/// `graph.all_entity_ids()`. Both detectors are consulted for every node,
/// so a node may appear in both lists.
///
/// # Errors
///
/// Returns any [`SqliteGraphError`] from listing or loading entities.
pub fn discover_sources_and_sinks(
    graph: &SqliteGraph,
    source_detector: &impl SourceCallback,
    sink_detector: &impl SinkCallback,
) -> Result<(Vec<i64>, Vec<i64>), SqliteGraphError> {
    let mut found_sources: Vec<i64> = Vec::new();
    let mut found_sinks: Vec<i64> = Vec::new();

    for id in graph.all_entity_ids()? {
        let entity = graph.get_entity(id)?;
        if source_detector.is_source(id, &entity) {
            found_sources.push(id);
        }
        if sink_detector.is_sink(id, &entity) {
            found_sinks.push(id);
        }
    }

    Ok((found_sources, found_sinks))
}
/// Runs [`discover_sources_and_sinks`] with the metadata-based default
/// detectors, [`MetadataSourceDetector`] and [`MetadataSinkDetector`].
///
/// # Errors
///
/// Returns any [`SqliteGraphError`] from listing or loading entities.
pub fn discover_sources_and_sinks_default(
    graph: &SqliteGraph,
) -> Result<(Vec<i64>, Vec<i64>), SqliteGraphError> {
    let source_detector = MetadataSourceDetector;
    let sink_detector = MetadataSinkDetector;
    discover_sources_and_sinks(graph, &source_detector, &sink_detector)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GraphEntity;
    use serde_json::json;

    /// Builds a standalone entity for detector tests (not inserted into a graph).
    fn create_test_entity(id: i64, kind: &str, data: serde_json::Value) -> GraphEntity {
        GraphEntity {
            id,
            kind: kind.to_string(),
            name: format!("node_{}", id),
            file_path: None,
            data,
        }
    }

    /// Inserts an entity into `graph`, panicking on failure.
    fn add_entity(graph: &SqliteGraph, id: i64, kind: &str, name: &str, data: serde_json::Value) {
        graph
            .insert_entity(&GraphEntity {
                id,
                kind: kind.to_string(),
                name: name.to_string(),
                file_path: None,
                data,
            })
            .unwrap();
    }

    /// Inserts a `data_flow` edge `from_id -> to_id` into `graph`, panicking on failure.
    fn add_flow(graph: &SqliteGraph, from_id: i64, to_id: i64) {
        graph
            .insert_edge(&GraphEdge {
                id: 0,
                from_id,
                to_id,
                edge_type: "data_flow".to_string(),
                data: json!({}),
            })
            .unwrap();
    }

    #[test]
    fn test_metadata_source_detector_kind_source() {
        let detector = MetadataSourceDetector;
        let entity = create_test_entity(1, "variable", json!({"kind": "source"}));
        assert!(detector.is_source(1, &entity));
    }

    #[test]
    fn test_metadata_source_detector_kind_untrusted() {
        let detector = MetadataSourceDetector;
        let entity = create_test_entity(1, "variable", json!({"kind": "untrusted"}));
        assert!(detector.is_source(1, &entity));
    }

    #[test]
    fn test_metadata_source_detector_kind_user_input() {
        let detector = MetadataSourceDetector;
        let entity = create_test_entity(1, "variable", json!({"kind": "user_input"}));
        assert!(detector.is_source(1, &entity));
    }

    #[test]
    fn test_metadata_source_detector_taint_field() {
        let detector = MetadataSourceDetector;
        let entity = create_test_entity(1, "variable", json!({"taint": "source"}));
        assert!(detector.is_source(1, &entity));
    }

    #[test]
    fn test_metadata_source_detector_not_a_source() {
        let detector = MetadataSourceDetector;
        let entity = create_test_entity(1, "variable", json!({"kind": "sanitized"}));
        assert!(!detector.is_source(1, &entity));
    }

    #[test]
    fn test_metadata_source_detector_ignores_non_string_values() {
        let detector = MetadataSourceDetector;
        let entity = create_test_entity(1, "variable", json!({"kind": 1, "taint": true}));
        assert!(!detector.is_source(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_kind_sink() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"kind": "sink"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_kind_sql_query() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"kind": "sql_query"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_kind_html_output() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"kind": "html_output"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_kind_command() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"kind": "command"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_kind_file_operation() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"kind": "file_operation"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_operation_execute() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"operation": "execute"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_operation_query() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"operation": "query"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_operation_render() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"operation": "render"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_operation_write() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"operation": "write"}));
        assert!(detector.is_sink(1, &entity));
    }

    #[test]
    fn test_metadata_sink_detector_not_a_sink() {
        let detector = MetadataSinkDetector;
        let entity = create_test_entity(1, "operation", json!({"operation": "validate"}));
        assert!(!detector.is_sink(1, &entity));
    }

    #[test]
    fn test_taint_result_new() {
        let result = TaintResult::new();
        assert!(result.sources.is_empty());
        assert!(result.sinks_reached.is_empty());
        assert!(result.tainted_nodes.is_empty());
        assert!(result.source_sink_paths.is_empty());
        assert_eq!(result.size, 0);
    }

    #[test]
    fn test_taint_result_default() {
        let result = TaintResult::default();
        assert!(result.sources.is_empty());
        assert!(result.sinks_reached.is_empty());
        assert!(result.tainted_nodes.is_empty());
        assert!(result.source_sink_paths.is_empty());
        assert_eq!(result.size, 0);
    }

    #[test]
    fn test_taint_result_is_tainted() {
        let mut result = TaintResult::new();
        result.tainted_nodes.insert(1);
        result.tainted_nodes.insert(5);
        result.tainted_nodes.insert(10);
        assert!(result.is_tainted(1));
        assert!(result.is_tainted(5));
        assert!(result.is_tainted(10));
        assert!(!result.is_tainted(99));
    }

    #[test]
    fn test_taint_result_has_vulnerability() {
        let mut result = TaintResult::new();
        assert!(!result.has_vulnerability());
        result.source_sink_paths.push((1, 5));
        assert!(result.has_vulnerability());
    }

    #[test]
    fn test_taint_result_sorted_tainted_nodes() {
        let mut result = TaintResult::new();
        result.tainted_nodes.insert(10);
        result.tainted_nodes.insert(1);
        result.tainted_nodes.insert(5);
        result.tainted_nodes.insert(3);
        let sorted = result.sorted_tainted_nodes();
        assert_eq!(sorted, vec![1, 3, 5, 10]);
    }

    #[test]
    fn test_taint_result_sorted_vulnerabilities() {
        let mut result = TaintResult::new();
        result.source_sink_paths = vec![(5, 10), (1, 3), (3, 5), (1, 10)];
        let sorted = result.sorted_vulnerabilities();
        assert_eq!(sorted, vec![(1, 3), (1, 10), (3, 5), (5, 10)]);
    }

    /// Linear chain `1 -> 2 -> 3 -> 4 -> 5`; 1 is a source, 5 is a sink.
    fn create_linear_flow_graph() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().unwrap();
        add_entity(&graph, 1, "variable", "source", json!({"kind": "source"}));
        for i in 2..=4 {
            add_entity(&graph, i, "variable", &format!("node_{}", i), json!({}));
        }
        add_entity(&graph, 5, "operation", "sink", json!({"kind": "sink"}));
        for i in 1..5 {
            add_flow(&graph, i, i + 1);
        }
        graph
    }

    /// `1 (source) -> 2 -> 3 (SQL sink)`.
    fn create_vulnerable_flow_graph() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().unwrap();
        add_entity(&graph, 1, "variable", "user_input", json!({"kind": "source"}));
        add_entity(&graph, 2, "variable", "intermediate", json!({}));
        add_entity(
            &graph,
            3,
            "operation",
            "sql_execute",
            json!({"kind": "sql_query", "operation": "execute"}),
        );
        add_flow(&graph, 1, 2);
        add_flow(&graph, 2, 3);
        graph
    }

    /// `1 (source) -> 2 (sanitize)`; the SQL sink 3 has no incoming flow.
    fn create_safe_flow_graph() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().unwrap();
        add_entity(&graph, 1, "variable", "user_input", json!({"kind": "source"}));
        add_entity(&graph, 2, "operation", "sanitize", json!({"operation": "sanitize"}));
        add_entity(&graph, 3, "operation", "sql_execute", json!({"kind": "sql_query"}));
        add_flow(&graph, 1, 2);
        graph
    }

    /// Sources 1 and 2 both flow into 3, which flows into sink 4.
    fn create_multi_source_sink_graph() -> SqliteGraph {
        let graph = SqliteGraph::open_in_memory().unwrap();
        add_entity(&graph, 1, "variable", "source1", json!({"kind": "source"}));
        add_entity(&graph, 2, "variable", "source2", json!({"kind": "untrusted"}));
        add_entity(&graph, 3, "variable", "intermediate", json!({}));
        add_entity(&graph, 4, "operation", "sink", json!({"kind": "sink"}));
        add_flow(&graph, 1, 3);
        add_flow(&graph, 2, 3);
        add_flow(&graph, 3, 4);
        graph
    }

    #[test]
    fn test_propagate_taint_forward_vulnerable() {
        let graph = create_vulnerable_flow_graph();
        let result = propagate_taint_forward(&graph, &[1], &[3]).unwrap();
        assert!(result.has_vulnerability());
        assert_eq!(result.sinks_reached.len(), 1);
        assert!(result.sinks_reached.contains(&3));
        assert_eq!(result.source_sink_paths.len(), 1);
        assert_eq!(result.source_sink_paths[0], (1, 3));
        assert!(result.is_tainted(1));
        assert!(result.is_tainted(2));
        assert!(result.is_tainted(3));
    }

    #[test]
    fn test_propagate_taint_forward_safe() {
        let graph = create_safe_flow_graph();
        let result = propagate_taint_forward(&graph, &[1], &[3]).unwrap();
        assert!(!result.has_vulnerability());
        assert_eq!(result.sinks_reached.len(), 0);
        assert_eq!(result.source_sink_paths.len(), 0);
        assert!(result.is_tainted(1));
        assert!(result.is_tainted(2));
        assert!(!result.is_tainted(3));
    }

    #[test]
    fn test_propagate_taint_forward_multi_source() {
        let graph = create_multi_source_sink_graph();
        let result = propagate_taint_forward(&graph, &[1, 2], &[4]).unwrap();
        assert!(result.has_vulnerability());
        assert_eq!(result.sinks_reached.len(), 1);
        assert!(result.sinks_reached.contains(&4));
        assert_eq!(result.source_sink_paths.len(), 2);
        assert_eq!(result.source_sink_paths, vec![(1, 4), (2, 4)]);
    }

    #[test]
    fn test_propagate_taint_forward_multi_sink() {
        let graph = create_linear_flow_graph();
        let result = propagate_taint_forward(&graph, &[1], &[3, 5]).unwrap();
        assert!(result.has_vulnerability());
        assert_eq!(result.sinks_reached.len(), 2);
        assert_eq!(result.source_sink_paths.len(), 2);
        assert_eq!(result.sorted_vulnerabilities(), vec![(1, 3), (1, 5)]);
    }

    #[test]
    fn test_propagate_taint_forward_empty_sources() {
        let graph = create_vulnerable_flow_graph();
        let sources: [i64; 0] = [];
        let result = propagate_taint_forward(&graph, &sources, &[3]).unwrap();
        assert!(!result.has_vulnerability());
        assert_eq!(result.tainted_nodes.len(), 0);
        assert_eq!(result.size, 0);
    }

    #[test]
    fn test_propagate_taint_forward_empty_sinks() {
        let graph = create_vulnerable_flow_graph();
        let sinks: [i64; 0] = [];
        let result = propagate_taint_forward(&graph, &[1], &sinks).unwrap();
        assert!(!result.has_vulnerability());
        assert!(result.is_tainted(1));
        assert!(result.is_tainted(2));
        assert!(result.is_tainted(3));
    }

    #[test]
    fn test_propagate_taint_backward_vulnerable() {
        let graph = create_vulnerable_flow_graph();
        let result = propagate_taint_backward(&graph, 3, &[1]).unwrap();
        assert_eq!(result.sources.len(), 1);
        assert!(result.sources.contains(&1));
        assert!(result.sinks_reached.contains(&3));
        assert_eq!(result.source_sink_paths.len(), 1);
        assert_eq!(result.source_sink_paths[0], (1, 3));
        assert!(result.is_tainted(1));
        assert!(result.is_tainted(2));
        assert!(result.is_tainted(3));
    }

    #[test]
    fn test_propagate_taint_backward_safe() {
        let graph = create_safe_flow_graph();
        let result = propagate_taint_backward(&graph, 3, &[1]).unwrap();
        assert_eq!(result.sources.len(), 0);
        assert!(!result.has_vulnerability());
    }

    #[test]
    fn test_propagate_taint_backward_multi_source() {
        let graph = create_multi_source_sink_graph();
        let result = propagate_taint_backward(&graph, 4, &[1, 2]).unwrap();
        assert_eq!(result.sources.len(), 2);
        assert!(result.sources.contains(&1));
        assert!(result.sources.contains(&2));
        assert_eq!(result.source_sink_paths.len(), 2);
    }

    #[test]
    fn test_propagate_taint_backward_self() {
        let graph = create_vulnerable_flow_graph();
        let result = propagate_taint_backward(&graph, 1, &[1]).unwrap();
        assert_eq!(result.sources.len(), 1);
        assert!(result.sources.contains(&1));
    }

    #[test]
    fn test_forward_and_backward_agree_on_paths() {
        let graph = create_multi_source_sink_graph();
        let sources = [1, 2];
        let forward = propagate_taint_forward(&graph, &sources, &[4]).unwrap();
        let backward = propagate_taint_backward(&graph, 4, &sources).unwrap();
        assert_eq!(
            forward.sorted_vulnerabilities(),
            backward.sorted_vulnerabilities()
        );
    }

    #[test]
    fn test_sink_reachability_vulnerability_found() {
        let graph = create_vulnerable_flow_graph();
        let vulnerabilities = sink_reachability_analysis(&graph, &[1], &[3]).unwrap();
        assert_eq!(vulnerabilities.len(), 1);
        assert!(vulnerabilities.contains_key(&3));
        let affecting_sources = vulnerabilities.get(&3).unwrap();
        assert_eq!(affecting_sources.len(), 1);
        assert!(affecting_sources.contains(&1));
    }

    #[test]
    fn test_sink_reachability_no_vulnerability() {
        let graph = create_safe_flow_graph();
        let vulnerabilities = sink_reachability_analysis(&graph, &[1], &[3]).unwrap();
        assert_eq!(vulnerabilities.len(), 0);
    }

    #[test]
    fn test_sink_reachability_multi_vulnerabilities() {
        let graph = create_multi_source_sink_graph();
        let vulnerabilities = sink_reachability_analysis(&graph, &[1, 2], &[4]).unwrap();
        assert_eq!(vulnerabilities.len(), 1);
        let affecting_sources = vulnerabilities.get(&4).unwrap();
        assert_eq!(affecting_sources.len(), 2);
        assert!(affecting_sources.contains(&1));
        assert!(affecting_sources.contains(&2));
    }

    #[test]
    fn test_discover_sources_and_sinks_metadata() {
        let graph = create_vulnerable_flow_graph();
        let (sources, sinks) = discover_sources_and_sinks_default(&graph).unwrap();
        assert_eq!(sources, vec![1]);
        assert_eq!(sinks, vec![3]);
    }

    #[test]
    fn test_discover_sources_and_sinks_multi_source() {
        let graph = create_multi_source_sink_graph();
        let (mut sources, sinks) = discover_sources_and_sinks_default(&graph).unwrap();
        sources.sort_unstable();
        assert_eq!(sources, vec![1, 2]);
        assert_eq!(sinks, vec![4]);
    }

    #[test]
    fn test_discover_sources_and_sinks_custom() {
        struct EvenSourceDetector;
        struct OddSinkDetector;
        impl SourceCallback for EvenSourceDetector {
            fn is_source(&self, node: i64, _entity: &GraphEntity) -> bool {
                node % 2 == 0
            }
        }
        impl SinkCallback for OddSinkDetector {
            fn is_sink(&self, node: i64, _entity: &GraphEntity) -> bool {
                node % 2 == 1
            }
        }
        let graph = create_linear_flow_graph();
        let (sources, sinks) =
            discover_sources_and_sinks(&graph, &EvenSourceDetector, &OddSinkDetector).unwrap();
        assert_eq!(sources.len(), 2);
        assert!(sources.contains(&2));
        assert!(sources.contains(&4));
        assert_eq!(sinks.len(), 3);
        assert!(sinks.contains(&1));
        assert!(sinks.contains(&3));
        assert!(sinks.contains(&5));
    }

    #[test]
    fn test_discover_empty_graph() {
        let graph = SqliteGraph::open_in_memory().unwrap();
        let (sources, sinks) = discover_sources_and_sinks_default(&graph).unwrap();
        assert_eq!(sources.len(), 0);
        assert_eq!(sinks.len(), 0);
    }

    #[test]
    fn test_propagate_taint_forward_with_progress_matches() {
        use crate::progress::NoProgress;
        let graph = create_vulnerable_flow_graph();
        let sources = [1];
        let sinks = [3];
        let base_result = propagate_taint_forward(&graph, &sources, &sinks).unwrap();
        let progress_result =
            propagate_taint_forward_with_progress(&graph, &sources, &sinks, &NoProgress).unwrap();
        assert_eq!(base_result.sources, progress_result.sources);
        assert_eq!(base_result.sinks_reached, progress_result.sinks_reached);
        assert_eq!(base_result.tainted_nodes, progress_result.tainted_nodes);
        assert_eq!(
            base_result.source_sink_paths,
            progress_result.source_sink_paths
        );
        assert_eq!(base_result.size, progress_result.size);
    }

    #[test]
    fn test_propagate_taint_backward_with_progress_matches() {
        use crate::progress::NoProgress;
        let graph = create_vulnerable_flow_graph();
        let sources = [1];
        let sink = 3;
        let base_result = propagate_taint_backward(&graph, sink, &sources).unwrap();
        let progress_result =
            propagate_taint_backward_with_progress(&graph, sink, &sources, &NoProgress).unwrap();
        assert_eq!(base_result.sources, progress_result.sources);
        assert_eq!(base_result.sinks_reached, progress_result.sinks_reached);
        assert_eq!(base_result.tainted_nodes, progress_result.tainted_nodes);
        assert_eq!(
            base_result.source_sink_paths,
            progress_result.source_sink_paths
        );
        assert_eq!(base_result.size, progress_result.size);
    }

    #[test]
    fn test_sink_reachability_with_progress_matches() {
        use crate::progress::NoProgress;
        let graph = create_multi_source_sink_graph();
        let sources = [1, 2];
        let sinks = [4];
        let base_result = sink_reachability_analysis(&graph, &sources, &sinks).unwrap();
        let progress_result =
            sink_reachability_analysis_with_progress(&graph, &sources, &sinks, &NoProgress)
                .unwrap();
        assert_eq!(base_result.len(), progress_result.len());
        for (sink, mut expected) in base_result {
            let mut actual: Vec<i64> = progress_result
                .get(&sink)
                .unwrap_or_else(|| panic!("sink {} missing from progress result", sink))
                .clone();
            expected.sort_unstable();
            actual.sort_unstable();
            assert_eq!(expected, actual, "Sources for sink {} don't match", sink);
        }
    }
}