use crate::config::Config;
use crate::index::Index;
use crate::rank::cmp_score_desc;
use crate::text::match_tokens;
use std::collections::HashMap;
/// BM25 term-frequency saturation. A term's contribution to a description's
/// score approaches `K1 + 1` times its IDF as its count grows; larger values
/// let repeated occurrences keep adding weight for longer.
const K1: f32 = 1.2;
/// BM25 length-normalization strength. At `0.0` description length is ignored;
/// at `1.0` the saturation term scales fully with length relative to the average.
const B: f32 = 0.75;
/// Minimum number of distinct prompt terms that must also occur in the winning
/// skill's description before `dominant` will report it.
const MIN_TERM_OVERLAP: usize = 2;
/// A lexical relevance score for one skill, as produced by `scores`.
#[derive(Clone, Debug, PartialEq)]
pub struct Lex {
    /// The `id` of the scored index entry.
    pub id: String,
    /// BM25 relevance of the entry's description to the prompt. It is `0.0`
    /// when no prompt term occurs in the description; higher is more relevant.
    pub score: f32,
}
/// Ranks every skill in `idx` by Okapi BM25 relevance of its description to
/// `prompt`.
///
/// Tokenization:
/// - Both the prompt and each description are tokenized with `match_tokens`.
/// - The prompt is reduced to its distinct terms, so repeating a word in the
///   prompt adds no extra weight.
///
/// Weighting:
/// - Each query term is weighted by the smoothed IDF
///   `ln(1 + (N - df + 0.5) / (df + 0.5))`. Here `N` is the number of skills
///   and `df` is the number of descriptions containing the term.
/// - Its per-description contribution saturates via `K1`.
/// - It is length-normalized via `B` against the average description length,
///   which is floored at `1.0`.
///
/// Returns one `Lex` per skill, including skills that scored `0.0`, ordered
/// with `cmp_score_desc`. The sort is stable, so entries that compare equal
/// keep their index order. Returns an empty vector when the prompt produces
/// no tokens or the index has no skills.
pub fn scores(prompt: &str, idx: &Index) -> Vec<Lex> {
    let mut terms: Vec<String> = match_tokens(prompt);
    terms.sort();
    terms.dedup();
    if terms.is_empty() || idx.skills.is_empty() {
        return Vec::new();
    }

    // Term frequencies of each skill description, in index order.
    let mut term_freqs: Vec<HashMap<String, u32>> = Vec::with_capacity(idx.skills.len());
    for skill in &idx.skills {
        let mut tf: HashMap<String, u32> = HashMap::new();
        for token in match_tokens(&skill.description) {
            *tf.entry(token).or_default() += 1;
        }
        term_freqs.push(tf);
    }

    // Description lengths (token counts) and their floored average.
    let doc_lens: Vec<f32> = term_freqs
        .iter()
        .map(|tf| tf.values().sum::<u32>() as f32)
        .collect();
    let doc_count = term_freqs.len() as f32;
    let total_len: f32 = doc_lens.iter().sum();
    let avg_len = (total_len / doc_count).max(1.0);

    // IDF weight of each query term, aligned index-for-index with `terms`.
    let weights: Vec<f32> = terms
        .iter()
        .map(|term| {
            let df = term_freqs
                .iter()
                .filter(|tf| tf.contains_key(term))
                .count() as f32;
            (1.0 + (doc_count - df + 0.5) / (df + 0.5)).ln()
        })
        .collect();

    let mut ranked: Vec<Lex> = idx
        .skills
        .iter()
        .zip(term_freqs.iter().zip(&doc_lens))
        .map(|(skill, (tf, &len))| {
            // Length-normalization term; the same for every query term of this skill.
            let norm = K1 * (1.0 - B + B * len / avg_len);
            let score = terms
                .iter()
                .zip(&weights)
                .fold(0.0f32, |acc, (term, &weight)| match tf.get(term) {
                    Some(&count) if count > 0 => {
                        let f = count as f32;
                        acc + weight * (f * (K1 + 1.0)) / (f + norm)
                    }
                    _ => acc,
                });
            Lex {
                id: skill.id.clone(),
                score,
            }
        })
        .collect();
    ranked.sort_by(|x, y| cmp_score_desc(x.score, y.score));
    ranked
}
/// Returns the best lexical match for `prompt` only when it clearly dominates.
///
/// Skills are ranked with `scores`. The top entry is returned only if every
/// gate below passes; otherwise the result is `None`:
///
/// 1. `cfg.lexical_min` is positive. A non-positive or NaN value turns this
///    shortcut off entirely.
/// 2. The top score is at least `cfg.lexical_min`.
/// 3. A runner-up exists (a single-skill library is never dominant), and the
///    top score exceeds it by at least `cfg.lexical_margin`. A NaN margin
///    rejects the match.
/// 4. At least `MIN_TERM_OVERLAP` distinct prompt terms also occur in the
///    winner's description, so a single shared word cannot decide the match.
pub fn dominant(prompt: &str, idx: &Index, cfg: &Config) -> Option<Lex> {
    // `NaN <= 0.0` is false, so without the explicit check a NaN floor would
    // leave the shortcut on, and `score < NaN` below would never reject.
    if cfg.lexical_min.is_nan() || cfg.lexical_min <= 0.0 {
        return None;
    }

    let mut ranked = scores(prompt, idx).into_iter();
    let top = ranked.next()?;
    if top.score < cfg.lexical_min {
        return None;
    }

    // With no runner-up there is no margin to measure, so never dominant.
    let runner_up = ranked.next()?.score;
    // Same NaN pitfall: `gap < NaN` is always false, so fail closed instead.
    if cfg.lexical_margin.is_nan() || top.score - runner_up < cfg.lexical_margin {
        return None;
    }

    // Guard against a high score earned from one rare shared term: require
    // several distinct prompt terms to appear in the winner's description.
    let prompt_terms: std::collections::HashSet<String> =
        match_tokens(prompt).into_iter().collect();
    let winner_terms: std::collections::HashSet<String> = idx
        .skills
        .iter()
        .find(|e| e.id == top.id)
        .map(|e| match_tokens(&e.description).into_iter().collect())
        .unwrap_or_default();
    if prompt_terms.intersection(&winner_terms).count() < MIN_TERM_OVERLAP {
        return None;
    }

    Some(top)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::Entry;

    /// Builds a test entry with the given id and description.
    ///
    /// The id is also used as the name; every other field is left empty.
    fn entry(id: &str, description: &str) -> Entry {
        Entry {
            id: id.to_string(),
            name: id.to_string(),
            description: description.to_string(),
            path: String::new(),
            keywords: Vec::new(),
            trigger_phrases: Vec::new(),
            body_head: String::new(),
            hash: String::new(),
            embedding: Vec::new(),
        }
    }

    /// Wraps `entries` in an index with placeholder model metadata.
    fn index_of(entries: Vec<Entry>) -> Index {
        Index {
            model: "test".into(),
            dim: 0,
            skills: entries,
        }
    }

    /// Default config with only the lexical floor and margin overridden.
    fn cfg(min: f32, margin: f32) -> Config {
        Config {
            lexical_min: min,
            lexical_margin: margin,
            ..Default::default()
        }
    }

    #[test]
    fn ranks_description_vocabulary_match_first() {
        let idx = index_of(vec![
            entry(
                "xlsx",
                "create and edit spreadsheets, compute formulas, build charts",
            ),
            entry("pdf", "merge split and extract text from pdf documents"),
            entry(
                "docx",
                "create and edit word documents with headings and tables",
            ),
        ]);
        let ranked = scores(
            "turn this sales spreadsheet into a chart with formulas",
            &idx,
        );
        assert_eq!(ranked[0].id, "xlsx");
        assert!(ranked[0].score > ranked[1].score);
    }

    #[test]
    fn matches_across_plural_inflection() {
        let idx = index_of(vec![
            entry("xlsx", "edit a spreadsheet, charts and formulas"),
            entry("pdf", "merge split and extract text from pdf documents"),
        ]);
        let ranked = scores("compute formulas across my spreadsheets", &idx);
        assert_eq!(ranked[0].id, "xlsx");
        assert!(ranked[0].score > 0.0);
    }

    #[test]
    fn dominant_requires_absolute_floor() {
        let idx = index_of(vec![
            entry("xlsx", "edit a spreadsheet, charts and formulas"),
            entry("pdf", "pdf documents"),
        ]);
        assert!(dominant("edit my spreadsheet", &idx, &cfg(100.0, 0.5)).is_none());
        let win = dominant("edit my spreadsheet", &idx, &cfg(0.5, 0.5)).unwrap();
        assert_eq!(win.id, "xlsx");
    }

    #[test]
    fn dominant_requires_margin_over_runner_up() {
        let idx = index_of(vec![
            entry("a", "process the report data"),
            entry("b", "process the report data"),
        ]);
        assert!(dominant("process the report", &idx, &cfg(0.1, 0.5)).is_none());
    }

    #[test]
    fn dominant_rejects_single_term_match() {
        let idx = index_of(vec![
            entry(
                "brand-guidelines",
                "apply anthropic brand colors to artifacts",
            ),
            entry("pdf", "merge and split pdf documents"),
        ]);
        assert!(dominant(
            "who founded anthropic and in what year",
            &idx,
            &cfg(0.1, 0.1)
        )
        .is_none());
        let idx2 = index_of(vec![
            entry("xlsx", "edit a spreadsheet, charts and formulas"),
            entry("pdf", "pdf documents"),
        ]);
        let win = dominant("edit the spreadsheet formulas", &idx2, &cfg(0.1, 0.1)).unwrap();
        assert_eq!(win.id, "xlsx");
    }

    #[test]
    fn single_skill_library_is_never_dominant() {
        let idx = index_of(vec![entry(
            "xlsx",
            "edit a spreadsheet, charts and formulas",
        )]);
        assert!(dominant("edit the spreadsheet formulas", &idx, &cfg(0.1, 0.1)).is_none());
    }

    #[test]
    fn dominant_off_when_min_non_positive() {
        // Two skills, so the single-skill rule cannot be what rejects the match.
        let idx = index_of(vec![
            entry("xlsx", "edit a spreadsheet, charts and formulas"),
            entry("pdf", "pdf documents"),
        ]);
        let prompt = "edit the spreadsheet formulas";
        // Sanity check: with a positive floor this prompt is dominant.
        assert!(dominant(prompt, &idx, &cfg(0.1, 0.1)).is_some());
        assert!(dominant(prompt, &idx, &cfg(0.0, 0.1)).is_none());
        assert!(dominant(prompt, &idx, &cfg(-1.0, 0.1)).is_none());
    }

    #[test]
    fn empty_prompt_or_index_yields_no_scores() {
        let idx = index_of(vec![entry("xlsx", "spreadsheets")]);
        assert!(scores("", &idx).is_empty());
        assert!(scores("the an of to", &idx).is_empty());
        assert!(scores("spreadsheet", &index_of(vec![])).is_empty());
    }
}