use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::types::{RankedTag, RankingConfig, Tag, TagKind};
/// Applies the multipliers held in a [`RankingConfig`] to base ranks and
/// produces the final, sorted list of [`RankedTag`]s.
pub struct BoostCalculator {
/// Boost multipliers and damping values read while scoring tags.
config: RankingConfig,
}
impl BoostCalculator {
    /// Creates a calculator that reads its multipliers from `config`.
    pub fn new(config: RankingConfig) -> Self {
        Self { config }
    }

    /// Scores every definition tag and returns the scored tags, sorted.
    ///
    /// Only tags whose kind is `TagKind::Def` are emitted. For each of them
    /// the final rank is `base_rank * boost * git_weight * caller_weight`:
    ///
    /// * `base_rank` is the entry in `symbol_ranks` for
    ///   `(tag.rel_fname, tag.name)` when present, otherwise the file's entry
    ///   in `file_ranks` (or `0.0` if the file has none).
    /// * `boost` starts at `1.0` and is multiplied, in this order, by the
    ///   configured boosts for: a mentioned identifier, a mentioned file, a
    ///   chat file, a temporally coupled file, and a focus-expansion entry
    ///   (`boost_focus_expansion * expansion_weight`).
    /// * `git_weight` is the file's entry in `git_weights`, defaulting to `1.0`.
    /// * `caller_weight` is
    ///   `1 + (raw - 1) * boost_caller_weight * (1 - hub_damping)`, floored at
    ///   `0.01`, where `raw` is the file's entry in `caller_weights`
    ///   (default `1.0`).
    ///
    /// Keys of `tags_by_file` and entries of `chat_fnames` have one leading
    /// `/` stripped before being compared against the other, relative-path
    /// keyed inputs.
    ///
    /// The result is ordered by `RankedTag`'s `Ord` implementation; the tests
    /// in this file expect the highest rank first.
    pub fn apply_boosts(
        &self,
        tags_by_file: &HashMap<String, Vec<Tag>>,
        file_ranks: &HashMap<String, f64>,
        symbol_ranks: Option<&HashMap<(Arc<str>, Arc<str>), f64>>,
        chat_fnames: &HashSet<String>,
        mentioned_fnames: &HashSet<String>,
        mentioned_idents: &HashSet<String>,
        temporal_boost_files: &HashSet<String>,
        git_weights: Option<&HashMap<String, f64>>,
        caller_weights: Option<&HashMap<String, f64>>,
        focus_expansion_weights: Option<&HashMap<(Arc<str>, Arc<str>), f64>>,
    ) -> Vec<RankedTag> {
        let config = &self.config;

        // Loop-invariant: depends only on the configuration, not on the file.
        let effective_boost = config.boost_caller_weight * (1.0 - config.hub_damping);

        // The (file, symbol) key is only needed when a per-symbol map was supplied.
        let needs_symbol_key = symbol_ranks.is_some() || focus_expansion_weights.is_some();

        let chat_rel_fnames: HashSet<String> =
            chat_fnames.iter().map(|f| extract_rel_fname(f)).collect();

        let mut result = Vec::new();

        for (fname, tags) in tags_by_file {
            let rel_fname = extract_rel_fname(fname);

            let file_rank = file_ranks.get(&rel_fname).copied().unwrap_or(0.0);
            let git_weight = git_weights
                .and_then(|w| w.get(&rel_fname))
                .copied()
                .unwrap_or(1.0);
            let raw_caller_weight = caller_weights
                .and_then(|w| w.get(&rel_fname))
                .copied()
                .unwrap_or(1.0);
            // Floor keeps the multiplier strictly positive for raw weights far below 1.
            let caller_weight = (1.0 + (raw_caller_weight - 1.0) * effective_boost).max(0.01);

            // File-level membership is the same for every tag in this file.
            let file_is_mentioned = mentioned_fnames.contains(&rel_fname);
            let file_is_in_chat = chat_rel_fnames.contains(&rel_fname);
            let file_is_coupled = temporal_boost_files.contains(&rel_fname);

            for tag in tags.iter().filter(|t| t.kind == TagKind::Def) {
                // Built at most once per tag and shared by both per-symbol lookups.
                let symbol_key = needs_symbol_key
                    .then(|| (Arc::clone(&tag.rel_fname), Arc::clone(&tag.name)));

                let base_rank = symbol_ranks
                    .zip(symbol_key.as_ref())
                    .and_then(|(ranks, key)| ranks.get(key).copied())
                    .unwrap_or(file_rank);

                // Multiplication order is kept fixed so results stay bit-identical.
                let mut boost = 1.0;
                if mentioned_idents.contains(tag.name.as_ref()) {
                    boost *= config.boost_mentioned_ident;
                }
                if file_is_mentioned {
                    boost *= config.boost_mentioned_file;
                }
                if file_is_in_chat {
                    boost *= config.boost_chat_file;
                }
                if file_is_coupled {
                    boost *= config.boost_temporal_coupling;
                }
                if let Some(&expansion_weight) = focus_expansion_weights
                    .zip(symbol_key.as_ref())
                    .and_then(|(weights, key)| weights.get(key))
                {
                    boost *= config.boost_focus_expansion * expansion_weight;
                }

                let final_rank = base_rank * boost * git_weight * caller_weight;
                result.push(RankedTag::new(final_rank, tag.clone()));
            }
        }

        result.sort();
        result
    }
}
/// Returns `abs_fname` as an owned string with a single leading `/` removed.
///
/// Only one slash is stripped; a path without a leading slash is returned
/// unchanged.
fn extract_rel_fname(abs_fname: &str) -> String {
    match abs_fname.strip_prefix('/') {
        Some(without_slash) => without_slash.to_owned(),
        None => abs_fname.to_owned(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::TagKind;
/// Builds a minimal `Tag` fixture on line 1 with node type `"function"`.
///
/// The absolute `fname` is `rel_fname` prefixed with `/`; every optional
/// field is left as `None`.
fn make_tag(rel_fname: &str, name: &str, kind: TagKind) -> Tag {
    let absolute = format!("/{}", rel_fname);
    Tag {
        kind,
        line: 1,
        name: Arc::from(name),
        node_type: Arc::from("function"),
        rel_fname: Arc::from(rel_fname),
        fname: Arc::from(absolute),
        parent_name: None,
        parent_line: None,
        signature: None,
        fields: None,
        metadata: None,
    }
}
/// With no boost inputs at all, a definition keeps its file rank unchanged.
#[test]
fn test_no_boosts() {
    let tags_by_file: HashMap<String, Vec<Tag>> = std::iter::once((
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    ))
    .collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 0.5)).collect();
    let empty: HashSet<String> = HashSet::new();

    let calculator = BoostCalculator::new(RankingConfig::default());
    let ranked = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &empty,
        &empty,
        &empty,
        &empty,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    let only = &ranked[0];
    assert_eq!(only.rank, 0.5);
    assert_eq!(only.tag.name.as_ref(), "foo");
}
/// A tag whose name is in `mentioned_idents` is scaled by `boost_mentioned_ident`.
#[test]
fn test_mentioned_ident_boost() {
    let config = RankingConfig::default();

    let tags_by_file: HashMap<String, Vec<Tag>> = std::iter::once((
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    ))
    .collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 0.1)).collect();
    let mentioned_idents: HashSet<String> = std::iter::once("foo".to_string()).collect();
    let empty: HashSet<String> = HashSet::new();

    let ranked = BoostCalculator::new(config.clone()).apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &empty,
        &empty,
        &mentioned_idents,
        &empty,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].rank, 0.1 * config.boost_mentioned_ident);
}
/// A tag in a file listed in `mentioned_fnames` is scaled by `boost_mentioned_file`.
#[test]
fn test_mentioned_file_boost() {
    let config = RankingConfig::default();

    let tags_by_file: HashMap<String, Vec<Tag>> = std::iter::once((
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    ))
    .collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 0.2)).collect();
    let mentioned_fnames: HashSet<String> = std::iter::once("a.rs".to_string()).collect();
    let empty: HashSet<String> = HashSet::new();

    let ranked = BoostCalculator::new(config.clone()).apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &empty,
        &mentioned_fnames,
        &empty,
        &empty,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].rank, 0.2 * config.boost_mentioned_file);
}
/// A tag in a chat file (given by absolute path) is scaled by `boost_chat_file`.
#[test]
fn test_chat_file_boost() {
    let config = RankingConfig::default();

    let tags_by_file: HashMap<String, Vec<Tag>> = std::iter::once((
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    ))
    .collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 0.05)).collect();
    let chat_fnames: HashSet<String> = std::iter::once("/a.rs".to_string()).collect();
    let empty: HashSet<String> = HashSet::new();

    let ranked = BoostCalculator::new(config.clone()).apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &chat_fnames,
        &empty,
        &empty,
        &empty,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].rank, 0.05 * config.boost_chat_file);
}
/// A tag in a file listed in `temporal_boost_files` is scaled by
/// `boost_temporal_coupling`.
#[test]
fn test_temporal_coupling_boost() {
    let config = RankingConfig::default();

    let tags_by_file: HashMap<String, Vec<Tag>> = std::iter::once((
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    ))
    .collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 0.5)).collect();
    let temporal_boost_files: HashSet<String> =
        std::iter::once("a.rs".to_string()).collect();
    let empty: HashSet<String> = HashSet::new();

    let ranked = BoostCalculator::new(config.clone()).apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &empty,
        &empty,
        &empty,
        &temporal_boost_files,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].rank, 0.5 * config.boost_temporal_coupling);
}
/// Several applicable boosts combine multiplicatively.
///
/// The comparison uses a relative tolerance: the implementation multiplies
/// the boosts together before applying the base rank, so its rounding can
/// differ in the last bits from the expected value computed here.
#[test]
fn test_multiple_boosts_multiply() {
    let config = RankingConfig::default();
    let calculator = BoostCalculator::new(config.clone());
    let mut tags_by_file = HashMap::new();
    tags_by_file.insert(
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    );
    let mut file_ranks = HashMap::new();
    file_ranks.insert("a.rs".to_string(), 0.01);
    let mut mentioned_idents = HashSet::new();
    mentioned_idents.insert("foo".to_string());
    let mut chat_fnames = HashSet::new();
    chat_fnames.insert("/a.rs".to_string());
    let result = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &chat_fnames,
        &HashSet::new(),
        &mentioned_idents,
        &HashSet::new(),
        None,
        None,
        None,
    );
    assert_eq!(result.len(), 1);
    let expected: f64 = 0.01 * config.boost_mentioned_ident * config.boost_chat_file;
    let actual: f64 = result[0].rank;
    let tolerance = 1e-9 * expected.abs().max(1.0);
    assert!(
        (actual - expected).abs() <= tolerance,
        "rank {} differs from expected {}",
        actual,
        expected
    );
}
/// A file's entry in `git_weights` multiplies the rank of its definitions.
#[test]
fn test_git_weight_multiplier() {
    let tags_by_file: HashMap<String, Vec<Tag>> = std::iter::once((
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    ))
    .collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 1.0)).collect();
    let git_weights: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 2.0)).collect();
    let empty: HashSet<String> = HashSet::new();

    let calculator = BoostCalculator::new(RankingConfig::default());
    let ranked = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &empty,
        &empty,
        &empty,
        &empty,
        Some(&git_weights),
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].rank, 1.0 * 2.0);
}
/// A file's raw caller weight is rescaled by the configured caller boost and
/// hub damping before multiplying the rank.
///
/// The expected value is derived from the configuration the same way
/// `apply_boosts` derives it, so the test does not depend on particular
/// default values.
#[test]
fn test_caller_weight_multiplier() {
    let config = RankingConfig::default();
    let calculator = BoostCalculator::new(config.clone());
    let mut tags_by_file = HashMap::new();
    tags_by_file.insert(
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    );
    let mut file_ranks = HashMap::new();
    file_ranks.insert("a.rs".to_string(), 1.0);
    let raw_caller: f64 = 1.5;
    let mut caller_weights = HashMap::new();
    caller_weights.insert("a.rs".to_string(), raw_caller);
    let result = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &HashSet::new(),
        &HashSet::new(),
        &HashSet::new(),
        &HashSet::new(),
        None,
        Some(&caller_weights),
        None,
    );
    assert_eq!(result.len(), 1);
    // Mirrors the scaling and the 0.01 floor applied by `apply_boosts`.
    let effective_boost: f64 = config.boost_caller_weight * (1.0 - config.hub_damping);
    let scaled_caller: f64 = 1.0 + (raw_caller - 1.0) * effective_boost;
    let expected: f64 = 1.0 * scaled_caller.max(0.01);
    let actual: f64 = result[0].rank;
    let tolerance = 1e-9 * expected.abs().max(1.0);
    assert!(
        (actual - expected).abs() <= tolerance,
        "rank {} differs from expected {}",
        actual,
        expected
    );
}
/// A focus-expansion entry for the symbol scales its rank by
/// `boost_focus_expansion * expansion_weight`.
///
/// The comparison uses a relative tolerance: the implementation forms
/// `boost_focus_expansion * weight` first and applies the base rank last, so
/// rounding can differ in the last bits from the expected value computed here.
#[test]
fn test_focus_expansion_weight() {
    let config = RankingConfig::default();
    let calculator = BoostCalculator::new(config.clone());
    let mut tags_by_file = HashMap::new();
    tags_by_file.insert(
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    );
    let mut file_ranks = HashMap::new();
    file_ranks.insert("a.rs".to_string(), 0.1);
    let mut focus_expansion_weights = HashMap::new();
    focus_expansion_weights.insert((Arc::from("a.rs"), Arc::from("foo")), 0.8);
    let result = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &HashSet::new(),
        &HashSet::new(),
        &HashSet::new(),
        &HashSet::new(),
        None,
        None,
        Some(&focus_expansion_weights),
    );
    assert_eq!(result.len(), 1);
    let expected: f64 = 0.1 * config.boost_focus_expansion * 0.8;
    let actual: f64 = result[0].rank;
    let tolerance = 1e-9 * expected.abs().max(1.0);
    assert!(
        (actual - expected).abs() <= tolerance,
        "rank {} differs from expected {}",
        actual,
        expected
    );
}
/// When a symbol has its own entry in `symbol_ranks`, that value is used as
/// the base rank instead of the file rank.
#[test]
fn test_symbol_rank_overrides_file_rank() {
    let tags_by_file: HashMap<String, Vec<Tag>> = std::iter::once((
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    ))
    .collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 0.1)).collect();
    let symbol_ranks: HashMap<(Arc<str>, Arc<str>), f64> =
        std::iter::once(((Arc::from("a.rs"), Arc::from("foo")), 0.9)).collect();
    let empty: HashSet<String> = HashSet::new();

    let calculator = BoostCalculator::new(RankingConfig::default());
    let ranked = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        Some(&symbol_ranks),
        &empty,
        &empty,
        &empty,
        &empty,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].rank, 0.9);
}
/// Reference tags are dropped; only definition tags appear in the output.
#[test]
fn test_only_definitions_included() {
    let file_tags = vec![
        make_tag("a.rs", "foo", TagKind::Def),
        make_tag("a.rs", "bar", TagKind::Ref),
    ];
    let tags_by_file: HashMap<String, Vec<Tag>> =
        std::iter::once(("/a.rs".to_string(), file_tags)).collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 1.0)).collect();
    let empty: HashSet<String> = HashSet::new();

    let calculator = BoostCalculator::new(RankingConfig::default());
    let ranked = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &empty,
        &empty,
        &empty,
        &empty,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 1);
    assert_eq!(ranked[0].tag.name.as_ref(), "foo");
}
/// The returned tags are ordered from highest to lowest rank.
#[test]
fn test_sorting_descending() {
    // Listed lowest-first so the output order cannot come from insertion order.
    let ranked_symbols = [("low", 0.1), ("medium", 0.5), ("high", 0.9)];

    let file_tags: Vec<Tag> = ranked_symbols
        .iter()
        .map(|&(name, _)| make_tag("a.rs", name, TagKind::Def))
        .collect();
    let tags_by_file: HashMap<String, Vec<Tag>> =
        std::iter::once(("/a.rs".to_string(), file_tags)).collect();
    let file_ranks: HashMap<String, f64> =
        std::iter::once(("a.rs".to_string(), 1.0)).collect();
    let symbol_ranks: HashMap<(Arc<str>, Arc<str>), f64> = ranked_symbols
        .iter()
        .map(|&(name, rank)| ((Arc::from("a.rs"), Arc::from(name)), rank))
        .collect();
    let empty: HashSet<String> = HashSet::new();

    let calculator = BoostCalculator::new(RankingConfig::default());
    let ranked = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        Some(&symbol_ranks),
        &empty,
        &empty,
        &empty,
        &empty,
        None,
        None,
        None,
    );

    assert_eq!(ranked.len(), 3);
    let order: Vec<&str> = ranked.iter().map(|r| r.tag.name.as_ref()).collect();
    assert_eq!(order, vec!["high", "medium", "low"]);
}
/// Every boost and weight applied at once yields the product of all factors.
///
/// The expected caller factor mirrors `apply_boosts`: it includes the
/// `hub_damping` term and the 0.01 floor, so the test stays valid if the
/// default damping is ever non-zero.
#[test]
fn test_combined_all_weights() {
    let config = RankingConfig::default();
    let calculator = BoostCalculator::new(config.clone());
    let mut tags_by_file = HashMap::new();
    tags_by_file.insert(
        "/a.rs".to_string(),
        vec![make_tag("a.rs", "foo", TagKind::Def)],
    );
    let mut file_ranks = HashMap::new();
    file_ranks.insert("a.rs".to_string(), 0.1);
    let mut mentioned_idents = HashSet::new();
    mentioned_idents.insert("foo".to_string());
    let mut mentioned_fnames = HashSet::new();
    mentioned_fnames.insert("a.rs".to_string());
    let mut chat_fnames = HashSet::new();
    chat_fnames.insert("/a.rs".to_string());
    let mut temporal_boost_files = HashSet::new();
    temporal_boost_files.insert("a.rs".to_string());
    let mut git_weights = HashMap::new();
    git_weights.insert("a.rs".to_string(), 2.0);
    let raw_caller: f64 = 1.5;
    let mut caller_weights = HashMap::new();
    caller_weights.insert("a.rs".to_string(), raw_caller);
    let mut focus_expansion_weights = HashMap::new();
    focus_expansion_weights.insert((Arc::from("a.rs"), Arc::from("foo")), 0.5);
    let result = calculator.apply_boosts(
        &tags_by_file,
        &file_ranks,
        None,
        &chat_fnames,
        &mentioned_fnames,
        &mentioned_idents,
        &temporal_boost_files,
        Some(&git_weights),
        Some(&caller_weights),
        Some(&focus_expansion_weights),
    );
    assert_eq!(result.len(), 1);
    let effective_boost: f64 = config.boost_caller_weight * (1.0 - config.hub_damping);
    let unclamped_caller: f64 = 1.0 + (raw_caller - 1.0) * effective_boost;
    let scaled_caller: f64 = unclamped_caller.max(0.01);
    let expected: f64 = 0.1
        * config.boost_mentioned_ident
        * config.boost_mentioned_file
        * config.boost_chat_file
        * config.boost_temporal_coupling
        * (config.boost_focus_expansion * 0.5)
        * 2.0
        * scaled_caller;
    assert!((result[0].rank - expected).abs() < 1e-6);
}
/// `extract_rel_fname` drops one leading slash and leaves other paths alone.
#[test]
fn test_extract_rel_fname() {
    let cases = [
        ("/a.rs", "a.rs"),
        ("/src/lib.rs", "src/lib.rs"),
        ("no_slash.rs", "no_slash.rs"),
    ];
    for (input, expected) in cases.iter() {
        assert_eq!(extract_rel_fname(input), *expected);
    }
}
}