use std::collections::{HashMap, HashSet};
use crate::core::community::CommunityRecord;
use crate::core::symbol_graph::SymbolGraph;
/// Upper bound of the value returned by `GraphScorer::bonus`; the raw bonus is
/// clamped into `0.0..=BONUS_MAX`.
pub const BONUS_MAX: f32 = 0.15;
/// Factor applied to a symbol's normalized centrality inside `GraphScorer::bonus`.
const CENTRALITY_WEIGHT: f32 = 0.10;
/// Amount contributed to `GraphScorer::bonus` when the symbol is a community centroid.
const CENTROID_WEIGHT: f32 = 0.05;
/// Lookup tables derived from a `SymbolGraph` and a list of `CommunityRecord`s,
/// used to compute per-symbol score bonuses and to answer community queries.
pub struct GraphScorer {
/// Symbol -> its degree divided by the largest degree in the graph
/// (every entry is `0.0` when the largest degree is not positive).
centrality: HashMap<String, f32>,
/// Every non-empty `centroid_symbol` found in the community records.
centroids: HashSet<String>,
/// Member symbol -> id of the community record listing it; when a symbol is
/// listed by several records, the last record processed wins.
community_map: HashMap<String, u64>,
}
impl GraphScorer {
    /// Precomputes the lookup tables used by the scoring and community queries.
    ///
    /// * `graph` supplies the per-symbol degrees (via `SymbolGraph::degrees`),
    ///   which are normalized by the largest degree found.
    /// * `communities` supplies the centroid symbols and the member lists.
    ///
    /// Returns a scorer whose tables are independent of both arguments.
    pub fn build(graph: &SymbolGraph, communities: &[CommunityRecord]) -> Self {
        let degree_by_symbol = graph.degrees();
        let peak = degree_by_symbol.values().copied().max().unwrap_or(0) as f32;

        // Without a positive peak there is nothing to normalize against, so
        // every known symbol gets a centrality of zero.
        let centrality: HashMap<String, f32> = degree_by_symbol
            .into_iter()
            .map(|(symbol, degree)| {
                let share = if peak > 0.0 {
                    degree as f32 / peak
                } else {
                    0.0_f32
                };
                (symbol, share)
            })
            .collect();

        // Records with an empty centroid symbol contribute no centroid.
        let centroids: HashSet<String> = communities
            .iter()
            .filter(|record| !record.centroid_symbol.is_empty())
            .map(|record| record.centroid_symbol.clone())
            .collect();

        // Pairs are produced in record order, so a symbol listed by several
        // records ends up mapped to the id of the last one.
        let community_map: HashMap<String, u64> = communities
            .iter()
            .flat_map(|record| {
                let community_id = record.id as u64;
                record
                    .members
                    .iter()
                    .map(move |member| (member.clone(), community_id))
            })
            .collect();

        Self {
            centrality,
            centroids,
            community_map,
        }
    }

    /// Returns the score bonus for `symbol`, always within `0.0..=BONUS_MAX`.
    ///
    /// The bonus is the symbol's normalized centrality scaled by
    /// `CENTRALITY_WEIGHT`, plus `CENTROID_WEIGHT` when the symbol is a
    /// community centroid. Symbols unknown to the scorer get `0.0`.
    pub fn bonus(&self, symbol: &str) -> f32 {
        let centrality_part = match self.centrality.get(symbol) {
            Some(&share) => CENTRALITY_WEIGHT * share,
            None => 0.0,
        };
        let centroid_part = if self.centroids.contains(symbol) {
            CENTROID_WEIGHT
        } else {
            0.0
        };
        (centrality_part + centroid_part).clamp(0.0, BONUS_MAX)
    }

    /// Reports whether `a` and `b` are both assigned to a community and those
    /// communities have the same id. Returns `false` if either is unassigned.
    pub fn same_community(&self, a: &str, b: &str) -> bool {
        let community_a = self.community_of(a);
        community_a.is_some() && community_a == self.community_of(b)
    }

    /// Returns the id of the community that lists `symbol` as a member, or
    /// `None` when no community record mentions it.
    pub fn community_of(&self, symbol: &str) -> Option<u64> {
        let community_id = *self.community_map.get(symbol)?;
        Some(community_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::chunker::ChunkType;
    use crate::core::symbol_graph::{ChunkTuple, SymbolGraph};

    /// Assembles a chunk tuple with the given id, file, name, outgoing calls
    /// and chunk type; the fifth tuple slot is left as an empty vector.
    fn mk_chunk(
        id: &str,
        file: &str,
        name: &str,
        calls: &[&str],
        chunk_type: ChunkType,
    ) -> ChunkTuple {
        (
            id.to_owned(),
            file.to_owned(),
            Some(name.to_owned()),
            calls.iter().map(|callee| (*callee).to_owned()).collect(),
            Vec::new(),
            chunk_type,
        )
    }

    /// Shorthand for a `ChunkType::Function` chunk located in `a.rs`.
    fn func(id: &str, name: &str, calls: &[&str]) -> ChunkTuple {
        mk_chunk(id, "a.rs", name, calls, ChunkType::Function)
    }

    /// Creates a community record with the given id, centroid and members;
    /// the remaining fields get neutral values.
    fn synthetic_record(id: usize, centroid: &str, members: &[&str]) -> CommunityRecord {
        let member_count = members.len();
        let members = members.iter().map(|m| (*m).to_owned()).collect();
        CommunityRecord {
            id,
            members,
            member_count,
            modularity_contribution: 0.0,
            centroid_symbol: centroid.to_owned(),
            dominant_files: Vec::new(),
        }
    }

    #[test]
    fn high_degree_node_scores_higher() {
        // One hub called by five callers, plus a leaf with a single caller.
        let mut chunks = vec![func("a.rs:1:5", "hub", &[])];
        // caller1..caller5 occupy lines 7-9, 11-13, 15-17, 19-21 and 23-25.
        chunks.extend((0..5).map(|i| {
            let first_line = 7 + 4 * i;
            let id = format!("a.rs:{}:{}", first_line, first_line + 2);
            let name = format!("caller{}", i + 1);
            func(&id, &name, &["hub"])
        }));
        chunks.push(func("a.rs:27:29", "leaf", &[]));
        chunks.push(func("a.rs:31:33", "isolated", &["leaf"]));

        let graph = SymbolGraph::build_from_chunks(&chunks);
        let scorer = GraphScorer::build(&graph, &[]);

        let hub_bonus = scorer.bonus("hub");
        let leaf_bonus = scorer.bonus("leaf");
        assert!(
            hub_bonus > leaf_bonus,
            "hub bonus {hub_bonus} should exceed leaf bonus {leaf_bonus}"
        );
        assert!(hub_bonus <= BONUS_MAX);
    }

    #[test]
    fn centroid_gets_extra_bonus() {
        // `a` and `b` call each other; only `a` is marked as the centroid.
        let chunks = vec![
            func("a.rs:1:5", "a", &["b"]),
            func("a.rs:7:9", "b", &["a"]),
        ];
        let graph = SymbolGraph::build_from_chunks(&chunks);
        let scorer = GraphScorer::build(&graph, &[synthetic_record(0, "a", &["a", "b"])]);

        let a_bonus = scorer.bonus("a");
        let b_bonus = scorer.bonus("b");
        assert!(
            a_bonus > b_bonus,
            "centroid bonus {a_bonus} should exceed non-centroid {b_bonus}"
        );
    }

    #[test]
    fn same_community_detection() {
        let records = [
            synthetic_record(0, "alpha", &["alpha", "beta", "gamma"]),
            synthetic_record(1, "delta", &["delta", "epsilon"]),
        ];
        let scorer = GraphScorer::build(&SymbolGraph::new(), &records);

        // Pairs that share a community.
        for &(left, right) in [("alpha", "beta"), ("beta", "gamma")].iter() {
            assert!(scorer.same_community(left, right));
        }
        // Pairs in different communities, or with an unassigned symbol.
        for &(left, right) in [("alpha", "delta"), ("alpha", "unknown")].iter() {
            assert!(!scorer.same_community(left, right));
        }
    }

    #[test]
    fn bonus_within_bounds() {
        // `hub` has the top degree and is also the centroid: the largest
        // bonus this graph can produce.
        let chunks = vec![
            func("a.rs:1:5", "hub", &["leaf"]),
            func("a.rs:7:9", "leaf", &[]),
        ];
        let graph = SymbolGraph::build_from_chunks(&chunks);
        let scorer = GraphScorer::build(&graph, &[synthetic_record(0, "hub", &["hub", "leaf"])]);

        let b = scorer.bonus("hub");
        assert!(b >= 0.0);
        assert!(b <= BONUS_MAX, "bonus {b} exceeded cap {BONUS_MAX}");
    }
}