use std::collections::{HashSet, VecDeque};
use petgraph::Direction;
use petgraph::stable_graph::NodeIndex;
use petgraph::visit::EdgeRef;
use crate::graph::{CodeGraph, edge::EdgeKind, node::GraphNode};
/// One path found by [`trace_flow`].
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub struct FlowPath {
    /// Display labels of the nodes on the path, from the entry node to the
    /// target node inclusive (symbols by name, files by file name).
    pub hops: Vec<String>,
    /// Number of edges on the path, i.e. `hops.len() - 1`.
    pub depth: usize,
}
/// Result of [`trace_flow`].
#[derive(Debug, Clone, PartialEq, serde::Serialize)]
pub struct FlowResult {
    /// Paths found from the entry symbol to the target symbol. Empty when no
    /// path exists within the depth limit or when either symbol name is unknown.
    pub paths: Vec<FlowPath>,
    /// Only computed when `paths` is empty: the label of a node reachable from
    /// both the entry and the target, if one exists. Always `None` when at
    /// least one path was found.
    pub shared_dependency: Option<String>,
}
/// Finds call/import paths from the symbol named `entry` to any symbol named
/// `target`.
///
/// Every node registered under `entry` in `graph.symbol_index` is used as a
/// search root, in index order. Every node registered under `target` counts as
/// a destination. Only `Calls` and `ResolvedImport` edges are followed (see
/// `dfs_trace`).
///
/// At most `max_paths` paths are returned, each at most `max_depth` hops long.
/// When no path is found, `shared_dependency` is filled via
/// `find_shared_dependency`, using the first entry node that yields one.
/// If either name is unknown, the result has no paths and no shared dependency.
pub fn trace_flow(
    graph: &CodeGraph,
    entry: &str,
    target: &str,
    max_paths: usize,
    max_depth: usize,
) -> FlowResult {
    let empty = || FlowResult {
        paths: Vec::new(),
        shared_dependency: None,
    };

    let entry_indices = match graph.symbol_index.get(entry) {
        Some(v) if !v.is_empty() => v,
        _ => return empty(),
    };
    let target_indices: HashSet<NodeIndex> = match graph.symbol_index.get(target) {
        Some(v) if !v.is_empty() => v.iter().copied().collect(),
        _ => return empty(),
    };

    let mut found_paths: Vec<FlowPath> = Vec::new();
    // Several nodes can share one name (e.g. same-named functions in different
    // files). Search from all of them, mirroring how every same-named target
    // node is accepted, until the path budget is spent.
    for &entry_idx in entry_indices.iter() {
        if found_paths.len() >= max_paths {
            break;
        }
        let mut path: Vec<NodeIndex> = vec![entry_idx];
        let mut visited: HashSet<NodeIndex> = HashSet::new();
        visited.insert(entry_idx);
        dfs_trace(
            graph,
            &target_indices,
            max_paths,
            max_depth,
            &mut path,
            &mut visited,
            &mut found_paths,
        );
    }

    // A shared dependency is only a fallback explanation for "no direct path".
    let shared_dependency = if found_paths.is_empty() {
        entry_indices.iter().find_map(|&entry_idx| {
            find_shared_dependency(graph, entry_idx, &target_indices, max_depth)
        })
    } else {
        None
    };

    FlowResult {
        paths: found_paths,
        shared_dependency,
    }
}
/// Depth-first search recording simple paths from `path[0]` to any node in
/// `target_indices`.
///
/// `path` holds the nodes of the current branch, with the first element as the
/// search root. `visited` mirrors it, so no node repeats within one path; both
/// are restored on backtrack.
///
/// Traversal rules:
/// - Only `Calls` and `ResolvedImport` edges are followed, and only into
///   `Symbol` or `File` nodes.
/// - A branch ends at the first target node it reaches; targets are not
///   expanded further.
/// - Branches longer than `max_depth` hops are not explored.
/// - The search stops once `found_paths` holds `max_paths` entries.
///
/// # Panics
/// Panics if `path` is empty; callers must seed it with the root node.
fn dfs_trace(
    graph: &CodeGraph,
    target_indices: &HashSet<NodeIndex>,
    max_paths: usize,
    max_depth: usize,
    path: &mut Vec<NodeIndex>,
    visited: &mut HashSet<NodeIndex>,
    found_paths: &mut Vec<FlowPath>,
) {
    // Makes the function safe to call with an already-full result buffer; the
    // per-edge check below handles the budget filling up mid-search.
    if found_paths.len() >= max_paths {
        return;
    }
    let current = *path
        .last()
        .expect("dfs_trace: path must be seeded with the root node");

    // `path.len() > 1` keeps the root itself from counting as a hit.
    if path.len() > 1 && target_indices.contains(&current) {
        let hops: Vec<String> = path
            .iter()
            .map(|&idx| node_symbol_name(graph, idx))
            .collect();
        let depth = hops.len() - 1;
        found_paths.push(FlowPath { hops, depth });
        return;
    }

    // `path.len() - 1` hops taken so far. Expanding adds one more, so stop
    // once that would exceed `max_depth`.
    if path.len() > max_depth {
        return;
    }

    for edge_ref in graph.graph.edges_directed(current, Direction::Outgoing) {
        if found_paths.len() >= max_paths {
            return;
        }
        if !matches!(
            edge_ref.weight(),
            EdgeKind::Calls | EdgeKind::ResolvedImport { .. }
        ) {
            continue;
        }
        let neighbor = edge_ref.target();
        // Skip nodes already on this path (cycle safety) and node kinds that
        // cannot appear in a flow path.
        if visited.contains(&neighbor)
            || !matches!(
                &graph.graph[neighbor],
                GraphNode::Symbol(_) | GraphNode::File(_)
            )
        {
            continue;
        }

        path.push(neighbor);
        visited.insert(neighbor);
        dfs_trace(
            graph,
            target_indices,
            max_paths,
            max_depth,
            path,
            visited,
            found_paths,
        );
        path.pop();
        visited.remove(&neighbor);
    }
}
/// Returns the label of a node reachable from both `entry_idx` and at least
/// one node of `target_indices`.
///
/// Reachability uses `reachable_set`, bounded by `max_depth` hops.
///
/// Selection rules:
/// - The candidate closest to the entry wins.
/// - Equal distances are broken by node index, so the answer is deterministic
///   rather than depending on hash-map iteration order.
/// - The entry node and the target nodes are never reported, since an
///   endpoint is not a dependency *shared* by the two endpoints.
///
/// Returns `None` when no such node exists.
fn find_shared_dependency(
    graph: &CodeGraph,
    entry_idx: NodeIndex,
    target_indices: &HashSet<NodeIndex>,
    max_depth: usize,
) -> Option<String> {
    let from_entry = reachable_set(graph, entry_idx, max_depth);

    // Only membership matters on the target side, so union the sets once
    // instead of rescanning `from_entry` per target.
    let from_targets: HashSet<NodeIndex> = target_indices
        .iter()
        .flat_map(|&target_idx| {
            reachable_set(graph, target_idx, max_depth)
                .into_iter()
                .map(|(node, _)| node)
        })
        .collect();

    from_entry
        .into_iter()
        .filter(|&(node, _)| {
            node != entry_idx && !target_indices.contains(&node) && from_targets.contains(&node)
        })
        .min_by_key(|&(node, depth)| (depth, node))
        .map(|(node, _)| node_symbol_name(graph, node))
}
fn reachable_set(
graph: &CodeGraph,
start: NodeIndex,
max_depth: usize,
) -> std::collections::HashMap<NodeIndex, usize> {
let mut visited: std::collections::HashMap<NodeIndex, usize> = std::collections::HashMap::new();
let mut queue: VecDeque<(NodeIndex, usize)> = VecDeque::new();
queue.push_back((start, 0));
visited.insert(start, 0);
while let Some((current, depth)) = queue.pop_front() {
if depth >= max_depth {
continue;
}
for edge_ref in graph.graph.edges_directed(current, Direction::Outgoing) {
if !matches!(
edge_ref.weight(),
EdgeKind::Calls | EdgeKind::ResolvedImport { .. }
) {
continue;
}
let neighbor = edge_ref.target();
if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(neighbor) {
e.insert(depth + 1);
queue.push_back((neighbor, depth + 1));
}
}
}
visited
}
/// Human-readable label for a graph node.
///
/// A symbol is labelled with its name and a file with the last component of
/// its path. Any other node kind is labelled `"?"`, as is a file whose name is
/// absent or not valid UTF-8.
fn node_symbol_name(graph: &CodeGraph, idx: NodeIndex) -> String {
    let label: Option<&str> = match &graph.graph[idx] {
        GraphNode::Symbol(symbol) => return symbol.name.clone(),
        GraphNode::File(file) => file.path.file_name().and_then(|os| os.to_str()),
        _ => None,
    };
    label.unwrap_or("?").to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::node::{SymbolInfo, SymbolKind};
    use std::path::PathBuf;

    /// Adds a `SymbolKind::Function` symbol named `$name` at line `$line` to the
    /// file node `$file` of graph `$g`, evaluating to whatever `add_symbol`
    /// returns.
    ///
    /// This is a macro rather than a helper fn so the tests never have to spell
    /// out the node-handle or line-number types used by `CodeGraph`.
    macro_rules! add_fn {
        ($g:expr, $file:expr, $name:expr, $line:expr) => {
            $g.add_symbol(
                $file,
                SymbolInfo {
                    name: $name.into(),
                    kind: SymbolKind::Function,
                    line: $line,
                    ..Default::default()
                },
            )
        };
    }

    fn root() -> PathBuf {
        PathBuf::from("/proj")
    }

    /// Builds `A -> B -> C` (all `Calls` edges) inside `src/main.rs`.
    fn graph_linear_chain() -> crate::graph::CodeGraph {
        let mut g = crate::graph::CodeGraph::new();
        let f = g.add_file(root().join("src/main.rs"), "rust");
        let a = add_fn!(g, f, "A", 1);
        let b = add_fn!(g, f, "B", 5);
        let c = add_fn!(g, f, "C", 10);
        g.add_calls_edge(a, b);
        g.add_calls_edge(b, c);
        g
    }

    #[test]
    fn test_trace_flow_basic() {
        let graph = graph_linear_chain();
        let result = trace_flow(&graph, "A", "C", 3, 20);
        assert_eq!(result.paths.len(), 1, "expected exactly 1 path A->B->C");
        assert_eq!(result.paths[0].hops, vec!["A", "B", "C"]);
        assert_eq!(result.paths[0].depth, 2);
        assert_eq!(
            result.shared_dependency, None,
            "shared dependency is only reported when no path exists"
        );
    }

    #[test]
    fn test_trace_flow_multiple_paths() {
        let mut g = crate::graph::CodeGraph::new();
        let f = g.add_file(root().join("src/main.rs"), "rust");
        let a = add_fn!(g, f, "A", 1);
        let b = add_fn!(g, f, "B", 5);
        let d = add_fn!(g, f, "D", 8);
        let c = add_fn!(g, f, "C", 10);
        g.add_calls_edge(a, b);
        g.add_calls_edge(b, c);
        g.add_calls_edge(a, d);
        g.add_calls_edge(d, c);
        let result = trace_flow(&g, "A", "C", 3, 20);
        assert_eq!(result.paths.len(), 2, "expected 2 paths (A->B->C, A->D->C)");
        // Edge iteration order is an implementation detail, so compare sorted.
        let mut hops: Vec<Vec<String>> = result.paths.iter().map(|p| p.hops.clone()).collect();
        hops.sort();
        assert_eq!(hops, vec![vec!["A", "B", "C"], vec!["A", "D", "C"]]);
    }

    #[test]
    fn test_trace_flow_max_paths() {
        let mut g = crate::graph::CodeGraph::new();
        let f = g.add_file(root().join("src/main.rs"), "rust");
        let a = add_fn!(g, f, "A", 1);
        let c = add_fn!(g, f, "C", 10);
        for i in 0..4 {
            let b = add_fn!(g, f, format!("B{i}"), 20 + i);
            g.add_calls_edge(a, b);
            g.add_calls_edge(b, c);
        }
        let result = trace_flow(&g, "A", "C", 3, 20);
        assert_eq!(
            result.paths.len(),
            3,
            "max_paths=3 should return at most 3 paths"
        );
    }

    #[test]
    fn test_trace_flow_cycle_safety() {
        let mut g = crate::graph::CodeGraph::new();
        let f = g.add_file(root().join("src/main.rs"), "rust");
        let a = add_fn!(g, f, "A", 1);
        let b = add_fn!(g, f, "B", 5);
        let c = add_fn!(g, f, "C", 10);
        g.add_calls_edge(a, b);
        g.add_calls_edge(b, c);
        g.add_calls_edge(c, a);
        let result = trace_flow(&g, "A", "C", 3, 20);
        assert_eq!(
            result.paths.len(),
            1,
            "cycle graph: should find path A->B->C"
        );
        assert_eq!(result.paths[0].hops, vec!["A", "B", "C"]);
    }

    #[test]
    fn test_trace_flow_depth_cap() {
        let mut g = crate::graph::CodeGraph::new();
        let f = g.add_file(root().join("src/main.rs"), "rust");
        let mut prev = add_fn!(g, f, "S0", 1);
        for i in 1..=25 {
            let next = add_fn!(g, f, format!("S{i}"), i + 1);
            g.add_calls_edge(prev, next);
            prev = next;
        }
        let result = trace_flow(&g, "S0", "S25", 3, 20);
        assert!(
            result.paths.is_empty(),
            "depth-25 chain with max_depth=20 should return no paths"
        );
        // The cap is inclusive: a 25-hop path is found when max_depth is 25.
        let result = trace_flow(&g, "S0", "S25", 3, 25);
        assert_eq!(result.paths.len(), 1);
        assert_eq!(result.paths[0].depth, 25);
    }

    #[test]
    fn test_trace_flow_no_path_against_edge_direction() {
        let graph = graph_linear_chain();
        let result = trace_flow(&graph, "C", "A", 3, 20);
        assert!(
            result.paths.is_empty(),
            "edges only run A->B->C, so C cannot reach A"
        );
    }

    #[test]
    fn test_trace_flow_unknown_target() {
        let graph = graph_linear_chain();
        let result = trace_flow(&graph, "A", "NonExistent", 3, 20);
        assert!(
            result.paths.is_empty(),
            "no path to unknown symbol expected"
        );
        assert_eq!(result.shared_dependency, None);
    }

    #[test]
    fn test_trace_flow_unknown_entry() {
        let graph = graph_linear_chain();
        let result = trace_flow(&graph, "NonExistent", "C", 3, 20);
        assert!(result.paths.is_empty());
        assert_eq!(result.shared_dependency, None);
    }

    #[test]
    fn test_trace_flow_reports_shared_dependency() {
        let mut g = crate::graph::CodeGraph::new();
        let f = g.add_file(root().join("src/main.rs"), "rust");
        let a = add_fn!(g, f, "A", 1);
        let c = add_fn!(g, f, "C", 5);
        let x = add_fn!(g, f, "X", 10);
        // A and C both call X, but neither reaches the other.
        g.add_calls_edge(a, x);
        g.add_calls_edge(c, x);
        let result = trace_flow(&g, "A", "C", 3, 20);
        assert!(result.paths.is_empty(), "A cannot reach C");
        assert_eq!(result.shared_dependency.as_deref(), Some("X"));
    }

    /// A file-level import between unrelated files must not disturb tracing of
    /// a direct `Calls` edge.
    #[test]
    fn test_trace_flow_follows_calls_and_imports() {
        let mut g = crate::graph::CodeGraph::new();
        let fa = g.add_file(root().join("src/a.rs"), "rust");
        let fb = g.add_file(root().join("src/b.rs"), "rust");
        let a = add_fn!(g, fa, "funcA", 1);
        let b = add_fn!(g, fa, "funcB", 5);
        g.add_calls_edge(a, b);
        g.add_resolved_import(fa, fb, "./b");
        let result = trace_flow(&g, "funcA", "funcB", 3, 20);
        assert!(
            !result.paths.is_empty(),
            "Calls edge from funcA to funcB should be traceable"
        );
        assert_eq!(result.paths[0].hops, vec!["funcA", "funcB"]);
    }
}