use crate::path::path_components;
use std::collections::{HashMap, HashSet};
pub fn compute_personalization(
total_files: usize,
chat_rel_fnames: &HashSet<String>,
rel_fnames: &[String],
mentioned_fnames: &HashSet<String>,
mentioned_idents: &HashSet<String>,
anchor_contributions: &HashMap<String, f64>,
) -> HashMap<String, f64> {
if total_files == 0 {
return HashMap::new();
}
let personalize = 100.0 / total_files as f64;
let mut result: HashMap<String, f64> = HashMap::new();
for rel_fname in rel_fnames {
let mut current_pers = 0.0;
if chat_rel_fnames.contains(rel_fname) {
current_pers += personalize;
}
if mentioned_fnames.contains(rel_fname) {
current_pers = current_pers.max(personalize);
}
let mut path_matched = false;
for component in path_components(rel_fname) {
if mentioned_idents.contains(component) {
path_matched = true;
break;
}
}
if !path_matched
&& let Some(stem) = std::path::Path::new(rel_fname)
.file_stem()
.and_then(|s| s.to_str())
&& mentioned_idents.contains(stem)
{
path_matched = true;
}
if path_matched {
current_pers += personalize;
}
if let Some(&contrib) = anchor_contributions.get(rel_fname) {
current_pers += contrib;
}
if current_pers > 0.0 {
result.insert(rel_fname.clone(), current_pers);
}
}
for (anchor, &contrib) in anchor_contributions {
if !result.contains_key(anchor) {
result.insert(anchor.clone(), contrib);
}
}
result
}
/// Unit tests for `compute_personalization`.
///
/// Every expected weight is derived from the unit `100.0 / total_files`
/// passed to the function in that test.
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing computed weights.
    const EPS: f64 = 0.001;

    /// An empty anchor map, for tests that do not exercise anchors.
    fn no_anchors() -> HashMap<String, f64> {
        HashMap::new()
    }

    /// Builds an owned string set from string literals.
    fn set_of(items: &[&str]) -> HashSet<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Builds an owned file list from string literals.
    fn files_of(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// Asserts that `fname` is present in `result` with weight `expected`.
    fn assert_weight(result: &HashMap<String, f64>, fname: &str, expected: f64) {
        let actual = result
            .get(fname)
            .copied()
            .unwrap_or_else(|| panic!("no weight recorded for {fname}"));
        assert!(
            (actual - expected).abs() < EPS,
            "weight for {fname}: expected {expected}, got {actual}"
        );
    }

    #[test]
    fn personalization_empty() {
        let result = compute_personalization(
            0,
            &HashSet::new(),
            &[],
            &HashSet::new(),
            &HashSet::new(),
            &no_anchors(),
        );
        assert!(result.is_empty());
    }

    #[test]
    fn personalization_chat_files() {
        let chat = set_of(&["main.rs"]);
        let files = files_of(&["main.rs", "lib.rs"]);
        let result = compute_personalization(
            2,
            &chat,
            &files,
            &HashSet::new(),
            &HashSet::new(),
            &no_anchors(),
        );
        assert_weight(&result, "main.rs", 50.0);
        assert!(!result.contains_key("lib.rs"));
    }

    #[test]
    fn personalization_mentioned_files() {
        let mentioned = set_of(&["lib.rs"]);
        let files = files_of(&["main.rs", "lib.rs"]);
        let result = compute_personalization(
            2,
            &HashSet::new(),
            &files,
            &mentioned,
            &HashSet::new(),
            &no_anchors(),
        );
        assert_weight(&result, "lib.rs", 50.0);
        assert!(!result.contains_key("main.rs"));
    }

    #[test]
    fn personalization_mentioned_idents() {
        let idents = set_of(&["utils"]);
        let files = files_of(&["src/utils/mod.rs"]);
        let result = compute_personalization(
            1,
            &HashSet::new(),
            &files,
            &HashSet::new(),
            &idents,
            &no_anchors(),
        );
        // The identifier match is the only signal, so the file gets one unit.
        assert_weight(&result, "src/utils/mod.rs", 100.0);
    }

    #[test]
    fn personalization_ident_matches_file_stem() {
        let idents = set_of(&["parser"]);
        let files = files_of(&["src/parser.rs"]);
        let result = compute_personalization(
            1,
            &HashSet::new(),
            &files,
            &HashSet::new(),
            &idents,
            &no_anchors(),
        );
        // Whether the hit comes from a path component or from the stem
        // fallback, an identifier match is worth exactly one unit.
        assert_weight(&result, "src/parser.rs", 100.0);
    }

    #[test]
    fn personalization_no_double_count() {
        let chat = set_of(&["main.rs"]);
        let mentioned = set_of(&["main.rs"]);
        let files = files_of(&["main.rs"]);
        let result =
            compute_personalization(1, &chat, &files, &mentioned, &HashSet::new(), &no_anchors());
        // Chat membership plus an explicit mention still yields a single unit.
        assert_weight(&result, "main.rs", 100.0);
    }

    #[test]
    fn personalization_anchor_contributions() {
        let anchors = HashMap::from([("entry.rs".to_string(), 500.0)]);
        let files = files_of(&["entry.rs", "lib.rs"]);
        let result = compute_personalization(
            2,
            &HashSet::new(),
            &files,
            &HashSet::new(),
            &HashSet::new(),
            &anchors,
        );
        assert_weight(&result, "entry.rs", 500.0);
        assert!(!result.contains_key("lib.rs"));
    }

    #[test]
    fn personalization_anchor_adds_to_other_signals() {
        let chat = set_of(&["entry.rs"]);
        let anchors = HashMap::from([("entry.rs".to_string(), 500.0)]);
        let files = files_of(&["entry.rs", "lib.rs"]);
        let result = compute_personalization(
            2,
            &chat,
            &files,
            &HashSet::new(),
            &HashSet::new(),
            &anchors,
        );
        // One chat unit (50.0) plus the anchor contribution (500.0).
        assert_weight(&result, "entry.rs", 550.0);
        assert!(!result.contains_key("lib.rs"));
    }

    #[test]
    fn personalization_anchor_ambiguous_divided_weight() {
        let anchors = HashMap::from([
            ("app_a/tasks.py".to_string(), 125.0),
            ("app_b/tasks.py".to_string(), 125.0),
        ]);
        let files = files_of(&["app_a/tasks.py", "app_b/tasks.py", "lib.rs", "util.rs"]);
        let result = compute_personalization(
            4,
            &HashSet::new(),
            &files,
            &HashSet::new(),
            &HashSet::new(),
            &anchors,
        );
        assert_weight(&result, "app_a/tasks.py", 125.0);
        assert_weight(&result, "app_b/tasks.py", 125.0);
        assert!(!result.contains_key("lib.rs"));
        assert!(!result.contains_key("util.rs"));
    }

    #[test]
    fn personalization_anchor_not_in_rel_fnames() {
        let anchors = HashMap::from([("external.rs".to_string(), 1000.0)]);
        let files = files_of(&["main.rs"]);
        let result = compute_personalization(
            1,
            &HashSet::new(),
            &files,
            &HashSet::new(),
            &HashSet::new(),
            &anchors,
        );
        // Anchors outside `rel_fnames` are back-filled with their raw value.
        assert_weight(&result, "external.rs", 1000.0);
        assert!(!result.contains_key("main.rs"));
    }
}