use crate::index::graph::DependencyGraph;
use std::collections::{HashMap, HashSet};
/// Computes PageRank scores over the file dependency graph.
///
/// Files imported by many (highly-ranked) files receive higher scores.
/// Rank mass held by dangling nodes (files with no outgoing edges) is
/// redistributed uniformly, as in the classic formulation. The returned
/// scores are normalised so the top-ranked file maps to exactly 1.0.
///
/// * `damping` — probability of following an edge rather than teleporting
///   (typically 0.85).
/// * `max_iterations` — upper bound on power iterations; the loop also
///   stops early once the largest per-node change drops below 1e-6.
///
/// Returns an empty map for an empty graph.
pub fn compute_pagerank(
    graph: &DependencyGraph,
    damping: f64,
    max_iterations: usize,
) -> HashMap<String, f64> {
    // Gather every node mentioned on either side of any edge.
    let mut node_set: HashSet<String> = HashSet::new();
    for (src, targets) in &graph.edges {
        node_set.insert(src.clone());
        targets.iter().for_each(|e| {
            node_set.insert(e.target.clone());
        });
    }
    for (dst, sources) in &graph.reverse_edges {
        node_set.insert(dst.clone());
        sources.iter().for_each(|e| {
            node_set.insert(e.target.clone());
        });
    }
    if node_set.is_empty() {
        return HashMap::new();
    }
    let node_count = node_set.len();

    // Sort for deterministic iteration order (and thus deterministic
    // floating-point summation order).
    let mut ordered: Vec<String> = node_set.into_iter().collect();
    ordered.sort();

    // Out-degree per node; absent from `graph.edges` means zero (dangling).
    let out_degree: HashMap<&str, usize> = ordered
        .iter()
        .map(|node| {
            let degree = graph.edges.get(node.as_str()).map_or(0, |s| s.len());
            (node.as_str(), degree)
        })
        .collect();

    // Start from the uniform distribution.
    let uniform = 1.0 / node_count as f64;
    let mut scores: HashMap<&str, f64> = ordered
        .iter()
        .map(|node| (node.as_str(), uniform))
        .collect();

    let teleport = (1.0 - damping) / node_count as f64;
    const EPSILON: f64 = 1e-6;

    for _ in 0..max_iterations {
        // Total rank currently sitting on dangling nodes — spread uniformly.
        let mut dangling_mass = 0.0_f64;
        for node in &ordered {
            if out_degree[node.as_str()] == 0 {
                dangling_mass += scores[node.as_str()];
            }
        }
        let dangling_share = damping * dangling_mass / node_count as f64;

        // Synchronous update: next iteration reads only the previous scores.
        let mut next: HashMap<&str, f64> = HashMap::with_capacity(node_count);
        for node in &ordered {
            let mut inbound = 0.0_f64;
            if let Some(importers) = graph.reverse_edges.get(node.as_str()) {
                for e in importers {
                    let importer = e.target.as_str();
                    let deg = out_degree.get(importer).copied().unwrap_or(0);
                    // Dangling importers contribute via `dangling_share` only.
                    if deg > 0 {
                        inbound += scores[importer] / deg as f64;
                    }
                }
            }
            next.insert(
                node.as_str(),
                teleport + dangling_share + damping * inbound,
            );
        }

        // Convergence check on the infinity norm of the score change.
        let mut max_delta = 0.0_f64;
        for node in &ordered {
            let delta = (next[node.as_str()] - scores[node.as_str()]).abs();
            max_delta = max_delta.max(delta);
        }
        scores = next;
        if max_delta < EPSILON {
            break;
        }
    }

    // Normalise so the best-connected file scores exactly 1.0.
    let top = scores.values().copied().fold(0.0_f64, f64::max);
    ordered
        .iter()
        .map(|node| {
            let value = if top > 0.0 {
                scores[node.as_str()] / top
            } else {
                0.0
            };
            (node.clone(), value)
        })
        .collect()
}
/// Inverts the per-file term-frequency table into a term → files index.
///
/// For every term appearing in any file's frequency map, records the set of
/// files that mention it at least once. The counts themselves are ignored —
/// only presence matters.
pub fn build_symbol_cross_refs(
    term_frequencies: &HashMap<String, HashMap<String, u32>>,
) -> HashMap<String, HashSet<String>> {
    let mut index: HashMap<String, HashSet<String>> = HashMap::new();
    for (path, term_counts) in term_frequencies {
        for term in term_counts.keys() {
            // Entry API: create the file set on first sight of the term.
            index
                .entry(term.clone())
                .or_insert_with(HashSet::new)
                .insert(path.clone());
        }
    }
    index
}
/// Scores a symbol by combining its file's PageRank with a visibility weight.
///
/// Weights: private symbols get 0.3; public symbols referenced from at least
/// one *other* file get 1.0; public but seemingly unreferenced symbols get
/// 0.7. The cross-reference lookup uses the lowercased symbol name —
/// presumably the term index is lowercased at build time; confirm against
/// the indexer if renaming terms.
pub fn symbol_importance(
    symbol: &crate::parser::language::Symbol,
    file_pagerank: f64,
    cross_refs: &HashMap<String, HashSet<String>>,
    file_path: &str,
) -> f64 {
    use crate::parser::language::Visibility;
    // Exhaustive match (no catch-all) so new visibility variants are flagged
    // by the compiler.
    let visibility_weight = match symbol.visibility {
        Visibility::Private => 0.3,
        Visibility::Public => {
            let key = symbol.name.to_lowercase();
            // "Referenced elsewhere" means some file other than the symbol's
            // own file mentions the name.
            let externally_referenced = match cross_refs.get(&key) {
                Some(files) => files.iter().any(|f| f != file_path),
                None => false,
            };
            if externally_referenced {
                1.0
            } else {
                0.7
            }
        }
    };
    file_pagerank * visibility_weight
}
// Unit tests covering the PageRank power iteration (various topologies,
// convergence, normalisation) and the cross-reference / importance helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::language::{Symbol, SymbolKind, Visibility};
    use crate::schema::EdgeType;

    // Standard damping factor from the original PageRank formulation.
    const DAMPING: f64 = 0.85;
    // Iteration cap; the 1e-6 convergence check usually stops much earlier.
    const MAX_ITER: usize = 100;

    /// Builds a `DependencyGraph` from `(importer, imported)` path pairs.
    fn make_graph(edges: &[(&str, &str)]) -> DependencyGraph {
        let mut g = DependencyGraph::new();
        for &(from, to) in edges {
            g.add_edge(from, to, EdgeType::Import);
        }
        g
    }

    /// Absolute-tolerance float comparison used throughout these tests.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // An empty graph has no nodes, so no scores should be produced.
    #[test]
    fn test_pagerank_empty_graph() {
        let graph = DependencyGraph::new();
        let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
        assert!(scores.is_empty(), "empty graph should produce empty scores");
    }

    // Single a -> b edge: the imported file must outrank its importer, and
    // normalisation pins the top score at exactly 1.0.
    #[test]
    fn test_pagerank_single_edge() {
        let graph = make_graph(&[("a.rs", "b.rs")]);
        let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
        assert!(scores.contains_key("a.rs"));
        assert!(scores.contains_key("b.rs"));
        for &v in scores.values() {
            assert!((0.0..=1.0 + 1e-9).contains(&v), "score out of range: {v}");
        }
        let b_score = scores["b.rs"];
        let a_score = scores["a.rs"];
        assert!(
            approx_eq(b_score, 1.0, 1e-6),
            "b.rs should be normalised to 1.0, got {b_score}"
        );
        assert!(
            a_score < b_score,
            "importer a.rs ({a_score}) should rank lower than imported b.rs ({b_score})"
        );
    }

    // a -> b -> c: rank should strictly increase along the chain.
    #[test]
    fn test_pagerank_linear_chain() {
        let graph = make_graph(&[("a.rs", "b.rs"), ("b.rs", "c.rs")]);
        let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
        assert_eq!(scores.len(), 3);
        let a = scores["a.rs"];
        let b = scores["b.rs"];
        let c = scores["c.rs"];
        assert!(
            approx_eq(c, 1.0, 1e-6),
            "c.rs should be top-ranked (1.0), got {c}"
        );
        assert!(b > a, "b.rs ({b}) should rank higher than a.rs ({a})");
    }

    // Four leaves all importing one hub: hub wins; leaves are symmetric and
    // therefore score identically.
    #[test]
    fn test_pagerank_star_pattern() {
        let graph = make_graph(&[
            ("a.rs", "hub.rs"),
            ("b.rs", "hub.rs"),
            ("c.rs", "hub.rs"),
            ("d.rs", "hub.rs"),
        ]);
        let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
        assert_eq!(scores.len(), 5);
        let hub = scores["hub.rs"];
        assert!(
            approx_eq(hub, 1.0, 1e-6),
            "hub.rs should be normalised to 1.0, got {hub}"
        );
        let leaves = ["a.rs", "b.rs", "c.rs", "d.rs"];
        let leaf_score = scores[leaves[0]];
        for &l in &leaves[1..] {
            assert!(
                approx_eq(scores[l], leaf_score, 1e-6),
                "leaf {l} score {} differs from first leaf {leaf_score}",
                scores[l]
            );
        }
        assert!(
            hub > leaf_score,
            "hub ({hub}) should score higher than leaves ({leaf_score})"
        );
    }

    // A 3-cycle is perfectly symmetric, so every node shares the same rank
    // (and all normalise to 1.0).
    #[test]
    fn test_pagerank_cycle() {
        let graph = make_graph(&[("a.rs", "b.rs"), ("b.rs", "c.rs"), ("c.rs", "a.rs")]);
        let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
        assert_eq!(scores.len(), 3);
        let a = scores["a.rs"];
        let b = scores["b.rs"];
        let c = scores["c.rs"];
        assert!(
            approx_eq(a, b, 1e-6) && approx_eq(b, c, 1e-6),
            "cycle nodes should have equal rank: a={a}, b={b}, c={c}"
        );
        assert!(
            approx_eq(a, 1.0, 1e-6),
            "normalised cycle rank should be 1.0, got {a}"
        );
    }

    // Two disjoint components: scoring still works, and nodes playing the
    // same role in each component (importer / imported) score identically.
    #[test]
    fn test_pagerank_disconnected() {
        let graph = make_graph(&[("a.rs", "b.rs"), ("c.rs", "d.rs")]);
        let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
        assert_eq!(scores.len(), 4);
        for &v in scores.values() {
            assert!((0.0..=1.0 + 1e-9).contains(&v), "score out of range: {v}");
        }
        let max_score = scores.values().copied().fold(0.0_f64, f64::max);
        assert!(
            approx_eq(max_score, 1.0, 1e-6),
            "normalised max should be 1.0, got {max_score}"
        );
        assert!(
            approx_eq(scores["b.rs"], scores["d.rs"], 1e-6),
            "symmetric leaves b={} d={} should be equal",
            scores["b.rs"],
            scores["d.rs"]
        );
        assert!(
            approx_eq(scores["a.rs"], scores["c.rs"], 1e-6),
            "symmetric importers a={} c={} should be equal",
            scores["a.rs"],
            scores["c.rs"]
        );
    }

    // A 50-node linear chain: all ranks stay in range within the iteration
    // budget, and the chain tail ends up top-ranked.
    #[test]
    fn test_pagerank_convergence() {
        let mut graph = DependencyGraph::new();
        for i in 0..49_usize {
            graph.add_edge(
                &format!("file_{i:02}.rs"),
                &format!("file_{:02}.rs", i + 1),
                EdgeType::Import,
            );
        }
        let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
        assert_eq!(scores.len(), 50);
        for (path, &score) in &scores {
            assert!(
                (0.0..=1.0 + 1e-9).contains(&score),
                "{path} has out-of-range score {score}"
            );
        }
        let top = scores["file_49.rs"];
        assert!(
            approx_eq(top, 1.0, 1e-6),
            "file_49.rs should be normalised to 1.0, got {top}"
        );
    }

    // Regardless of topology, every score lies in [0, 1] and the max is 1.0.
    #[test]
    fn test_pagerank_normalized() {
        let topologies: Vec<Vec<(&str, &str)>> = vec![
            vec![("x.rs", "y.rs")],
            vec![("x.rs", "y.rs"), ("y.rs", "z.rs")],
            vec![
                ("a.rs", "common.rs"),
                ("b.rs", "common.rs"),
                ("c.rs", "common.rs"),
            ],
            vec![
                ("a.rs", "b.rs"),
                ("b.rs", "c.rs"),
                ("c.rs", "a.rs"),
                ("d.rs", "a.rs"),
            ],
        ];
        for edges in &topologies {
            let graph = make_graph(edges);
            let scores = compute_pagerank(&graph, DAMPING, MAX_ITER);
            let max_score = scores.values().copied().fold(0.0_f64, f64::max);
            assert!(
                approx_eq(max_score, 1.0, 1e-6),
                "max score should be 1.0 for topology {edges:?}, got {max_score}"
            );
            for &v in scores.values() {
                assert!(
                    (0.0..=1.0 + 1e-9).contains(&v),
                    "score {v} out of [0.0, 1.0] range"
                );
            }
        }
    }

    /// Builds a minimal function `Symbol` with the given name and visibility.
    fn make_symbol(name: &str, visibility: Visibility) -> Symbol {
        Symbol {
            name: name.to_string(),
            kind: SymbolKind::Function,
            visibility,
            signature: format!("pub fn {}()", name),
            body: "{}".to_string(),
            start_line: 1,
            end_line: 1,
        }
    }

    /// Builds a term-frequency map from `(term, count)` pairs.
    fn term_freq(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        pairs.iter().map(|&(k, v)| (k.to_string(), v)).collect()
    }

    // The inverted index maps each term to exactly the files that mention it.
    #[test]
    fn test_build_symbol_cross_refs() {
        let mut tf: HashMap<String, HashMap<String, u32>> = HashMap::new();
        tf.insert(
            "a.rs".to_string(),
            term_freq(&[("connect", 3), ("query", 1)]),
        );
        tf.insert(
            "b.rs".to_string(),
            term_freq(&[("connect", 1), ("render", 2)]),
        );
        tf.insert("c.rs".to_string(), term_freq(&[("render", 1)]));
        let refs = build_symbol_cross_refs(&tf);
        let connect_files = refs
            .get("connect")
            .expect("connect should be in cross_refs");
        assert!(
            connect_files.contains("a.rs"),
            "connect should reference a.rs"
        );
        assert!(
            connect_files.contains("b.rs"),
            "connect should reference b.rs"
        );
        assert!(
            !connect_files.contains("c.rs"),
            "connect should not reference c.rs"
        );
        let render_files = refs.get("render").expect("render should be in cross_refs");
        assert!(
            render_files.contains("b.rs"),
            "render should reference b.rs"
        );
        assert!(
            render_files.contains("c.rs"),
            "render should reference c.rs"
        );
        assert!(
            !render_files.contains("a.rs"),
            "render should not reference a.rs"
        );
        let query_files = refs.get("query").expect("query should be in cross_refs");
        assert_eq!(query_files.len(), 1);
        assert!(query_files.contains("a.rs"));
    }

    // Public symbol mentioned in another file: full weight (pagerank * 1.0).
    #[test]
    fn test_symbol_importance_public_referenced() {
        let mut tf: HashMap<String, HashMap<String, u32>> = HashMap::new();
        tf.insert("a.rs".to_string(), term_freq(&[("connect", 3)]));
        tf.insert("b.rs".to_string(), term_freq(&[("connect", 1)]));
        let refs = build_symbol_cross_refs(&tf);
        let sym = make_symbol("connect", Visibility::Public);
        let importance = symbol_importance(&sym, 0.8, &refs, "a.rs");
        assert!(
            (importance - 0.8).abs() < 1e-9,
            "public+referenced: expected 0.8 * 1.0 = 0.8, got {importance}"
        );
    }

    // Public symbol with no cross-file mentions: dampened weight of 0.7.
    #[test]
    fn test_symbol_importance_public_unreferenced() {
        let mut tf: HashMap<String, HashMap<String, u32>> = HashMap::new();
        tf.insert("a.rs".to_string(), term_freq(&[("unique", 1)]));
        let refs = build_symbol_cross_refs(&tf);
        let sym = make_symbol("unique_fn", Visibility::Public);
        let importance = symbol_importance(&sym, 0.8, &refs, "a.rs");
        assert!(
            (importance - 0.56).abs() < 1e-9,
            "public+unreferenced: expected 0.8 * 0.7 = 0.56, got {importance}"
        );
    }

    // Private symbols always get the fixed 0.3 weight.
    #[test]
    fn test_symbol_importance_private() {
        let refs: HashMap<String, HashSet<String>> = HashMap::new();
        let sym = make_symbol("internal_helper", Visibility::Private);
        let importance = symbol_importance(&sym, 0.8, &refs, "a.rs");
        assert!(
            (importance - 0.24).abs() < 1e-9,
            "private: expected 0.8 * 0.3 = 0.24, got {importance}"
        );
    }
}