use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use petgraph::Direction;
use petgraph::graph::{DiGraph, NodeIndex};
use crate::types::{RankingConfig, Tag};
/// Computes PageRank-style importance scores for files and functions,
/// parameterized by a [`RankingConfig`].
pub struct PageRanker {
/// Ranking parameters read by the scoring methods: the PageRank `alpha`,
/// the depth thresholds and weights, the vendor patterns and the chat multiplier.
config: RankingConfig,
}
impl PageRanker {
pub fn new(config: RankingConfig) -> Self {
Self { config }
}
/// Ranks files by running personalized PageRank over the symbol-reference graph.
///
/// # Arguments
/// * `tags_by_file` - Tags found in each file, keyed by file name.
/// * `chat_fnames` - Files whose personalization weight is multiplied by the
///   configured chat multiplier.
///
/// # Returns
/// A map from normalized file name (as produced by `extract_rel_fname`) to its
/// rank. The map is empty when the graph built from `tags_by_file` has no nodes.
pub fn compute_ranks(
    &self,
    tags_by_file: &HashMap<String, Vec<Tag>>,
    chat_fnames: &[String],
) -> HashMap<String, f64> {
    let defines = self.build_defines_index(tags_by_file);
    let (graph, node_map, index_map) = self.build_graph(tags_by_file, &defines);
    if graph.node_count() == 0 {
        return HashMap::new();
    }

    // Chat files are compared against graph nodes by their normalized names.
    let mut chat_rel_fnames: HashSet<String> = HashSet::new();
    for chat_file in chat_fnames {
        chat_rel_fnames.insert(self.extract_rel_fname(chat_file));
    }

    let personalization = self.build_personalization(&node_map, &chat_rel_fnames);

    // Translate node indices back into file names for the caller.
    self.pagerank(&graph, &personalization, &index_map)
        .into_iter()
        .filter_map(|(node_idx, rank)| {
            index_map
                .get(&node_idx)
                .map(|rel_fname| (rel_fname.clone(), rank))
        })
        .collect()
}
/// Builds an index from symbol name to the set of files that define it.
///
/// Only tags for which `is_def()` is true contribute. File names are
/// normalized with `extract_rel_fname` before being stored.
fn build_defines_index(
    &self,
    tags_by_file: &HashMap<String, Vec<Tag>>,
) -> HashMap<Arc<str>, HashSet<String>> {
    let mut index: HashMap<Arc<str>, HashSet<String>> = HashMap::new();
    for (abs_path, tags) in tags_by_file {
        let rel_path = self.extract_rel_fname(abs_path);
        for def_tag in tags.iter().filter(|tag| tag.is_def()) {
            index
                .entry(Arc::clone(&def_tag.name))
                .or_default()
                .insert(rel_path.clone());
        }
    }
    index
}
fn build_graph(
&self,
tags_by_file: &HashMap<String, Vec<Tag>>,
defines: &HashMap<Arc<str>, HashSet<String>>,
) -> (
DiGraph<(), ()>,
HashMap<String, NodeIndex>,
HashMap<NodeIndex, String>,
) {
let mut graph = DiGraph::new();
let mut node_map: HashMap<String, NodeIndex> = HashMap::new();
let mut index_map: HashMap<NodeIndex, String> = HashMap::new();
for fname in tags_by_file.keys() {
let rel_fname = self.extract_rel_fname(fname);
if !node_map.contains_key(&rel_fname) {
let idx = graph.add_node(());
node_map.insert(rel_fname.clone(), idx);
index_map.insert(idx, rel_fname);
}
}
for (fname, tags) in tags_by_file {
let ref_fname = self.extract_rel_fname(fname);
let ref_node = match node_map.get(&ref_fname) {
Some(n) => *n,
None => continue,
};
for tag in tags {
if tag.is_ref() {
if let Some(def_fnames) = defines.get(&tag.name) {
for def_fname in def_fnames {
if def_fname != &ref_fname {
if let Some(&def_node) = node_map.get(def_fname) {
graph.add_edge(ref_node, def_node, ());
}
}
}
}
}
}
}
(graph, node_map, index_map)
}
/// Computes the (unnormalized) personalization weight of every graph node.
///
/// Each node's weight comes from `personalization_weight` applied to its file
/// name and the set of chat files.
fn build_personalization(
    &self,
    node_map: &HashMap<String, NodeIndex>,
    chat_fnames: &HashSet<String>,
) -> HashMap<NodeIndex, f64> {
    node_map
        .iter()
        .map(|(rel_fname, &node_idx)| {
            (node_idx, self.personalization_weight(rel_fname, chat_fnames))
        })
        .collect()
}
/// Returns the personalization weight for a single file.
///
/// The base weight is chosen as follows:
/// 1. the vendor weight if the path contains any configured vendor pattern
///    as a substring;
/// 2. otherwise a depth-based weight, where depth is the number of `/`
///    characters in `rel_fname`, compared against the shallow and moderate
///    thresholds.
///
/// If `rel_fname` is in `chat_fnames`, the base weight is multiplied by the
/// configured chat multiplier.
fn personalization_weight(&self, rel_fname: &str, chat_fnames: &HashSet<String>) -> f64 {
    let cfg = &self.config;
    let matches_vendor = cfg
        .vendor_patterns
        .iter()
        .any(|pattern| rel_fname.contains(pattern.as_str()));

    let base_weight = match rel_fname.matches('/').count() {
        _ if matches_vendor => cfg.depth_weight_vendor,
        depth if depth <= cfg.depth_threshold_shallow => cfg.depth_weight_root,
        depth if depth <= cfg.depth_threshold_moderate => cfg.depth_weight_moderate,
        _ => cfg.depth_weight_deep,
    };

    if !chat_fnames.contains(rel_fname) {
        return base_weight;
    }
    base_weight * cfg.pagerank_chat_multiplier
}
/// Runs personalized PageRank by power iteration and returns a rank for every
/// node of `graph`.
///
/// Each iteration computes, for every node,
/// `(1 - alpha) * p + alpha * incoming + alpha * dangling * p`, where `p` is the
/// node's normalized personalization value, `incoming` is the rank received
/// from predecessors (each predecessor splits its rank evenly over its outgoing
/// edges) and `dangling` is the total rank of nodes without outgoing edges.
///
/// # Arguments
/// * `graph` - The directed graph to rank.
/// * `personalization` - Unnormalized teleport weights. They are normalized to
///   sum to 1 here; a node missing from the map uses `1 / n`.
/// * `_index_map` - Unused; kept so the signature stays unchanged.
///
/// # Returns
/// The most recently computed rank of every node, or an empty map for an
/// empty graph. Iteration stops after 100 rounds or once no rank changes by
/// `1e-8` or more.
///
/// If the personalization weights do not sum to a positive, finite number
/// (for example when every weight is zero), a uniform teleport distribution is
/// used instead of dividing by that sum, which would yield NaN ranks.
fn pagerank(
    &self,
    graph: &DiGraph<(), ()>,
    personalization: &HashMap<NodeIndex, f64>,
    _index_map: &HashMap<NodeIndex, String>,
) -> HashMap<NodeIndex, f64> {
    const EPSILON: f64 = 1e-8;
    const MAX_ITERATIONS: usize = 100;

    let n = graph.node_count();
    if n == 0 {
        return HashMap::new();
    }
    let alpha = self.config.pagerank_alpha;
    let uniform = 1.0 / n as f64;

    let total_personalization: f64 = personalization.values().sum();
    let normalized_personalization: HashMap<NodeIndex, f64> =
        if total_personalization.is_finite() && total_personalization > 0.0 {
            personalization
                .iter()
                .map(|(&idx, &weight)| (idx, weight / total_personalization))
                .collect()
        } else {
            // Degenerate weights: leave the map empty so every lookup below
            // falls back to the uniform value.
            HashMap::new()
        };

    // Out-degrees never change during iteration, so count them once.
    let out_degrees: HashMap<NodeIndex, usize> = graph
        .node_indices()
        .map(|node| {
            (
                node,
                graph.neighbors_directed(node, Direction::Outgoing).count(),
            )
        })
        .collect();

    let mut ranks: HashMap<NodeIndex, f64> =
        graph.node_indices().map(|idx| (idx, uniform)).collect();
    let mut new_ranks = ranks.clone();

    for _ in 0..MAX_ITERATIONS {
        // Rank held by nodes without outgoing edges is redistributed according
        // to the personalization vector.
        let mut dangling_sum = 0.0;
        for node in graph.node_indices() {
            if out_degrees[&node] == 0 {
                dangling_sum += ranks[&node];
            }
        }

        for node in graph.node_indices() {
            let mut incoming_sum = 0.0;
            for predecessor in graph.neighbors_directed(node, Direction::Incoming) {
                // A predecessor owns at least the edge into `node`, so its
                // out-degree is never zero here.
                incoming_sum += ranks[&predecessor] / out_degrees[&predecessor] as f64;
            }
            let personalization_value = normalized_personalization
                .get(&node)
                .copied()
                .unwrap_or(uniform);
            new_ranks.insert(
                node,
                (1.0 - alpha) * personalization_value
                    + alpha * incoming_sum
                    + alpha * dangling_sum * personalization_value,
            );
        }

        let max_change = ranks
            .iter()
            .map(|(node, &old_rank)| (new_ranks[node] - old_rank).abs())
            .fold(0.0_f64, f64::max);

        // Swap before testing for convergence so that `ranks` always holds the
        // newest iterate, whether the loop ends by convergence or by the cap.
        std::mem::swap(&mut ranks, &mut new_ranks);
        if max_change < EPSILON {
            break;
        }
    }
    ranks
}
/// Normalizes a file name by removing a single leading `/`, if present.
///
/// Names without a leading `/` are returned unchanged (as an owned copy).
fn extract_rel_fname(&self, abs_fname: &str) -> String {
    match abs_fname.strip_prefix('/') {
        Some(without_root) => without_root.to_owned(),
        None => abs_fname.to_owned(),
    }
}
pub fn compute_function_ranks(
&self,
call_graph: &crate::callgraph::CallGraph,
) -> HashMap<crate::callgraph::FunctionId, f64> {
use petgraph::visit::EdgeRef;
let inner = call_graph.inner();
let n = inner.node_count();
if n == 0 {
return HashMap::new();
}
let alpha = self.config.pagerank_alpha;
let epsilon = 1e-8;
let max_iterations = 100;
let mut personalization: HashMap<petgraph::graph::NodeIndex, f64> = HashMap::new();
for node_idx in inner.node_indices() {
if let Some(func) = inner.node_weight(node_idx) {
let weight = self.personalization_weight(func.file.as_ref(), &HashSet::new());
personalization.insert(node_idx, weight);
}
}
let total: f64 = personalization.values().sum();
if total > 0.0 {
for v in personalization.values_mut() {
*v /= total;
}
}
let init_rank = 1.0 / n as f64;
let mut ranks: HashMap<petgraph::graph::NodeIndex, f64> =
inner.node_indices().map(|idx| (idx, init_rank)).collect();
let mut new_ranks = ranks.clone();
for _iteration in 0..max_iterations {
let mut dangling_sum = 0.0;
for node in inner.node_indices() {
let out_degree = inner.edges(node).count();
if out_degree == 0 {
dangling_sum += ranks[&node];
}
}
for node in inner.node_indices() {
let mut incoming_sum = 0.0;
for edge in inner.edges_directed(node, petgraph::Direction::Incoming) {
let caller = edge.source();
let caller_rank = ranks[&caller];
let out_degree = inner.edges(caller).count();
if out_degree > 0 {
let confidence = edge.weight().confidence;
incoming_sum += (caller_rank * confidence) / out_degree as f64;
}
}
let p_value = personalization
.get(&node)
.copied()
.unwrap_or(1.0 / n as f64);
new_ranks.insert(
node,
(1.0 - alpha) * p_value + alpha * incoming_sum + alpha * dangling_sum * p_value,
);
}
let max_change = ranks
.iter()
.map(|(node, &old_rank)| (new_ranks[node] - old_rank).abs())
.fold(0.0_f64, f64::max);
if max_change < epsilon {
break;
}
std::mem::swap(&mut ranks, &mut new_ranks);
}
let mut result = HashMap::new();
for (node_idx, rank) in ranks {
if let Some(func_id) = inner.node_weight(node_idx) {
result.insert(func_id.clone(), rank);
}
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::TagKind;
/// Test helper: builds a `Tag` named `name` of the given `kind` for the file
/// `rel_fname`, with `fname` set to `rel_fname` prefixed by `/`, line 1, node
/// type `"function"` and every optional field unset.
fn make_tag(rel_fname: &str, name: &str, kind: TagKind) -> Tag {
    let abs_path = format!("/{}", rel_fname);
    Tag {
        kind,
        name: Arc::from(name),
        line: 1,
        rel_fname: Arc::from(rel_fname),
        fname: Arc::from(abs_path),
        node_type: Arc::from("function"),
        signature: None,
        fields: None,
        metadata: None,
        parent_name: None,
        parent_line: None,
    }
}
/// A file defining a symbol must outrank the files that only reference it.
#[test]
fn test_simple_pagerank() {
    let ranker = PageRanker::new(RankingConfig::default());

    // `foo` is defined in a.rs and referenced from b.rs and c.rs.
    let layout = vec![
        ("a.rs", TagKind::Def),
        ("b.rs", TagKind::Ref),
        ("c.rs", TagKind::Ref),
    ];
    let tags_by_file: HashMap<String, Vec<Tag>> = layout
        .into_iter()
        .map(|(file, kind)| (format!("/{}", file), vec![make_tag(file, "foo", kind)]))
        .collect();

    let ranks = ranker.compute_ranks(&tags_by_file, &[]);

    let definer_rank = ranks["a.rs"];
    assert!(definer_rank > ranks["b.rs"]);
    assert!(definer_rank > ranks["c.rs"]);
}
/// The base weight follows directory depth, with vendor paths handled separately.
#[test]
fn test_depth_aware_personalization() {
    let config = RankingConfig::default();
    let ranker = PageRanker::new(config.clone());
    let no_chat_files = HashSet::new();

    // (path, expected weight)
    let cases = [
        ("main.rs", config.depth_weight_root),
        ("src/lib.rs", config.depth_weight_root),
        ("src/a/b/c/d/e.rs", config.depth_weight_deep),
        ("vendor/lib.rs", config.depth_weight_vendor),
    ];
    for (path, expected) in cases.iter() {
        assert_eq!(ranker.personalization_weight(path, &no_chat_files), *expected);
    }
}
/// Files in the chat set get the chat multiplier; other files do not.
#[test]
fn test_chat_file_boost() {
    let config = RankingConfig::default();
    let ranker = PageRanker::new(config.clone());
    let chat_fnames: HashSet<String> = std::iter::once(String::from("main.rs")).collect();

    assert_eq!(
        ranker.personalization_weight("main.rs", &chat_fnames),
        config.depth_weight_root * config.pagerank_chat_multiplier
    );
    assert_eq!(
        ranker.personalization_weight("other.rs", &chat_fnames),
        config.depth_weight_root
    );
}
/// Paths containing a default vendor pattern receive the vendor weight.
#[test]
fn test_vendor_patterns() {
    let config = RankingConfig::default();
    let ranker = PageRanker::new(config.clone());
    let no_chat_files = HashSet::new();

    let vendored_paths = ["node_modules/lib.js", "src/vendor/lib.rs", "third_party/lib.c"];
    for path in vendored_paths.iter() {
        assert_eq!(
            ranker.personalization_weight(path, &no_chat_files),
            config.depth_weight_vendor
        );
    }
}
/// Ranking with no input files yields an empty result.
#[test]
fn test_empty_graph() {
    let ranker = PageRanker::new(RankingConfig::default());
    let ranks = ranker.compute_ranks(&HashMap::new(), &[]);
    assert!(ranks.is_empty());
}
/// On the chain a.rs -> b.rs -> c.rs the ranks sum to roughly 1 and grow
/// along the chain.
#[test]
fn test_pagerank_convergence() {
    let ranker = PageRanker::new(RankingConfig::default());

    // a.rs references func_b (defined in b.rs); b.rs references func_c
    // (defined in c.rs).
    let a_tags = vec![make_tag("a.rs", "func_b", TagKind::Ref)];
    let b_tags = vec![
        make_tag("b.rs", "func_b", TagKind::Def),
        make_tag("b.rs", "func_c", TagKind::Ref),
    ];
    let c_tags = vec![make_tag("c.rs", "func_c", TagKind::Def)];

    let mut tags_by_file: HashMap<String, Vec<Tag>> = HashMap::new();
    tags_by_file.insert(String::from("/a.rs"), a_tags);
    tags_by_file.insert(String::from("/b.rs"), b_tags);
    tags_by_file.insert(String::from("/c.rs"), c_tags);

    let ranks = ranker.compute_ranks(&tags_by_file, &[]);

    let total: f64 = ranks.values().sum();
    assert!(
        (total - 1.0).abs() < 0.01,
        "Total rank should be close to 1.0, got {}",
        total
    );

    let (rank_a, rank_b, rank_c) = (ranks["a.rs"], ranks["b.rs"], ranks["c.rs"]);
    assert!(
        rank_c >= rank_b,
        "c.rs rank {} should be >= b.rs rank {}",
        rank_c,
        rank_b
    );
    assert!(
        rank_b >= rank_a,
        "b.rs rank {} should be >= a.rs rank {}",
        rank_b,
        rank_a
    );
}
/// In the call graph main -> {helper, util}, helper -> util, the most-called
/// function (`util`) ranks at least as high as the others.
#[test]
fn test_function_level_pagerank() {
    use crate::callgraph::{CallEdge, CallGraph, FunctionId};

    let ranker = PageRanker::new(RankingConfig::default());

    let main = FunctionId::new("test.rs", "main", 1);
    let helper = FunctionId::new("test.rs", "helper", 10);
    let util = FunctionId::new("test.rs", "util", 20);

    // (caller, callee, call-site line)
    let calls = vec![
        (main.clone(), helper.clone(), 5),
        (main.clone(), util.clone(), 6),
        (helper.clone(), util.clone(), 15),
    ];
    let mut graph = CallGraph::new();
    for (caller, callee, line) in calls {
        graph.add_call(caller, callee, CallEdge::new(0.9, "same_file", line));
    }

    let ranks = ranker.compute_function_ranks(&graph);
    let util_rank = ranks[&util];
    let helper_rank = ranks[&helper];
    let main_rank = ranks[&main];

    assert!(
        util_rank >= helper_rank,
        "util rank {} should be >= helper rank {}",
        util_rank,
        helper_rank
    );
    assert!(
        util_rank >= main_rank,
        "util rank {} should be >= main rank {}",
        util_rank,
        main_rank
    );
}
}