use crate::config::Config;
use crate::index::Index;
use crate::text::{match_tokens, norm_token, tokenize};
use std::collections::{BTreeSet, HashSet};
/// One skill's ranking result: each stage-1 signal kept separately, plus their total.
#[derive(Clone, Debug)]
pub struct Hit {
    /// Skill id, copied from the index entry.
    pub id: String,
    /// Skill name, copied from the index entry.
    pub name: String,
    /// Cosine similarity between the prompt embedding and the skill embedding.
    pub cosine: f32,
    /// Context-similarity bonus (0 when context weighting is inactive).
    pub context: f32,
    /// Boost applied when the skill id is in the caller's file-id set.
    pub file: f32,
    /// Boost applied when the skill id is in the caller's project-id set (cosine-gated).
    pub project: f32,
    /// Keyword-match boost.
    pub keyword: f32,
    /// Trigger-phrase boost.
    pub phrase: f32,
    /// Final score used for ordering and selection.
    pub score: f32,
}

impl Hit {
    /// Total of all stage-1 signals, added left to right in [`Hit::breakdown`] order.
    pub fn stage1_score(&self) -> f32 {
        let parts = self.breakdown();
        parts[1..]
            .iter()
            .fold(parts[0].1, |total, &(_, value)| total + value)
    }

    /// The stage-1 signals as `(label, value)` pairs, in a fixed order.
    pub fn breakdown(&self) -> [(&'static str, f32); 6] {
        let Hit {
            cosine,
            context,
            file,
            project,
            keyword,
            phrase,
            ..
        } = *self;
        [
            ("cos", cosine),
            ("ctx", context),
            ("file", file),
            ("project", project),
            ("kw", keyword),
            ("ph", phrase),
        ]
    }
}
/// Cosine similarity of two vectors.
///
/// Returns 0.0 rather than a meaningless value when the slices differ in length, when
/// either vector has zero squared norm, or when the quotient is not finite (e.g. an
/// embedding containing infinities or NaN).
pub fn cosine(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0f32, 0f32, 0f32), |(dot, sa, sb), (&x, &y)| {
            (dot + x * y, sa + x * x, sb + y * y)
        });
    if norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        return 0.0;
    }
    match dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt()) {
        sim if sim.is_finite() => sim,
        _ => 0.0,
    }
}
/// Comparator ordering scores from highest to lowest, with every NaN (of either sign)
/// placed after all non-NaN values and NaNs tied with each other.
///
/// Non-NaN values compare via `partial_cmp`, so `0.0` and `-0.0` are treated as equal.
pub fn cmp_score_desc(a: f32, b: f32) -> std::cmp::Ordering {
    // `false < true`: a NaN on the left sorts after a real number on the right.
    a.is_nan()
        .cmp(&b.is_nan())
        .then_with(|| b.partial_cmp(&a).unwrap_or(std::cmp::Ordering::Equal))
}
/// Keyword boost for one skill: `boost` times the number of distinct keywords that
/// occur in `prompt`.
///
/// The prompt is split with `tokenize`, and every prompt token, like every keyword,
/// is passed through `norm_token` before comparison. A keyword therefore matches an
/// inflected prompt word (e.g. "chart" vs "charts") whenever `norm_token` maps both
/// to the same form.
///
/// Keywords are de-duplicated by their normalized form. Two keywords that normalize
/// identically always match together, so counting both would let a single prompt
/// word earn the boost twice.
///
/// Returns 0.0 without tokenizing the prompt when `keywords` is empty. This matters
/// because the function runs once per skill during ranking.
pub fn keyword_score(prompt: &str, keywords: &[String], boost: f32) -> f32 {
    if keywords.is_empty() {
        return 0.0;
    }
    let toks: HashSet<String> = tokenize(prompt).iter().map(|t| norm_token(t)).collect();
    // Collect matched *normalized* keywords so equivalent keywords count once.
    let matched: HashSet<String> = keywords
        .iter()
        .map(|k| norm_token(k))
        .filter(|k| toks.contains(k))
        .collect();
    matched.len() as f32 * boost
}
/// Phrase boost for one skill: `boost` times the number of trigger phrases that fire.
///
/// A phrase fires when it has at least one whitespace-separated word and every word,
/// after `norm_token`, is among the prompt's `match_tokens`. Word order and adjacency
/// are not checked, so "support for a screen reader" fires "screen reader support".
/// Each listed phrase that fires is counted separately.
pub fn phrase_score(prompt: &str, phrases: &[String], boost: f32) -> f32 {
    if phrases.is_empty() {
        return 0.0;
    }
    let prompt_tokens: HashSet<String> = match_tokens(prompt).into_iter().collect();
    let fires = |phrase: &str| -> bool {
        let words: Vec<&str> = phrase.split_whitespace().collect();
        !words.is_empty()
            && words
                .iter()
                .all(|word| prompt_tokens.contains(&norm_token(word)))
    };
    let fired = phrases
        .iter()
        .filter(|phrase| fires(phrase.as_str()))
        .count();
    fired as f32 * boost
}
/// How far below `Config::min_similarity` a skill's prompt cosine may fall and still
/// receive the project boost in `rank_all_ctx`. This lets the boost lift a project-listed
/// skill that sits just under the similarity floor.
pub const PROJECT_GATE_SLACK: f32 = 0.06;
/// Weight given to context similarity for a prompt whose best skill cosine is `prompt_top`.
///
/// Returns 0.0 when context is disabled (`cfg.context_weight <= 0` or
/// `cfg.context_depth == 0`). Otherwise returns `cfg.context_weight` scaled by a
/// vagueness factor in [0, 1]:
/// - 1 at or below `cfg.vague_lo`;
/// - 0 at or above `cfg.vague_hi`;
/// - linear in between.
///
/// If `vague_hi <= vague_lo`, the ramp collapses to a step at `vague_hi`.
pub fn context_weight(prompt_top: f32, cfg: &Config) -> f32 {
    // Note: a NaN weight does not take this early return (NaN <= 0.0 is false).
    let disabled = cfg.context_weight <= 0.0 || cfg.context_depth == 0;
    if disabled {
        return 0.0;
    }
    let (lo, hi) = (cfg.vague_lo, cfg.vague_hi);
    let vagueness = match (hi <= lo, prompt_top >= hi) {
        // Degenerate band: hard step at `hi`.
        (true, true) => 0.0,
        (true, false) => 1.0,
        // Linear ramp from 1.0 at `lo` down to 0.0 at `hi`, clamped outside the band.
        (false, _) => ((hi - prompt_top) / (hi - lo)).clamp(0.0, 1.0),
    };
    cfg.context_weight * vagueness
}
/// Ranks every skill using only the prompt: no context embedding and empty
/// file/project id sets. Equivalent to `rank_all_ctx` with those inputs.
pub fn rank_all(query: &[f32], prompt: &str, index: &Index, cfg: &Config) -> Vec<Hit> {
    let none = BTreeSet::new();
    rank_all_ctx(query, None, &none, &none, prompt, index, cfg)
}
/// Scores every skill in `index` and returns them best-first.
///
/// Each [`Hit`] records the individual stage-1 signals:
/// - `cosine`: similarity between `query` and the skill embedding.
/// - `context`: nonzero only when `context` is `Some` and [`context_weight`] is
///   positive for this prompt. It equals that weight times how far the skill's
///   context similarity sits above the mean context similarity across all skills
///   (never negative).
/// - `file`: `cfg.file_boost` when the skill id is in `file_ids`. Not gated on cosine.
/// - `project`: `cfg.project_boost` when the id is in `project_ids` and the prompt
///   cosine is at least `cfg.min_similarity - PROJECT_GATE_SLACK`.
/// - `keyword` / `phrase`: lexical boosts from [`keyword_score`] and [`phrase_score`].
///
/// `score` is the sum of those signals. Hits are stably sorted by it, descending,
/// with NaN last; ties therefore keep index order.
pub fn rank_all_ctx(
    query: &[f32],
    context: Option<&[f32]>,
    file_ids: &BTreeSet<String>,
    project_ids: &BTreeSet<String>,
    prompt: &str,
    index: &Index,
    cfg: &Config,
) -> Vec<Hit> {
    let skills = &index.skills;

    let prompt_cos: Vec<f32> = skills
        .iter()
        .map(|skill| cosine(query, &skill.embedding))
        .collect();
    // Best prompt similarity, floored at 0; drives how "vague" the prompt is judged.
    let prompt_top = prompt_cos.iter().fold(0.0_f32, |best, &c| best.max(c));

    let lambda = context.map_or(0.0, |_| context_weight(prompt_top, cfg));
    // Context similarities are only computed when they will actually be weighted.
    let ctx_cos: Vec<f32> = match context {
        Some(ctx) if lambda > 0.0 => skills
            .iter()
            .map(|skill| cosine(ctx, &skill.embedding))
            .collect(),
        _ => Vec::new(),
    };
    let ctx_mean = match ctx_cos.len() {
        0 => 0.0,
        n => ctx_cos.iter().sum::<f32>() / n as f32,
    };

    let boost_if = |on: bool, amount: f32| if on { amount } else { 0.0 };

    let mut hits: Vec<Hit> = skills
        .iter()
        .zip(prompt_cos.iter().copied())
        .enumerate()
        .map(|(i, (skill, cos))| {
            // Only the above-average part of context similarity earns a bonus.
            let context = ctx_cos
                .get(i)
                .map_or(0.0, |&c| lambda * (c - ctx_mean).max(0.0));
            let file = boost_if(
                cfg.file_boost > 0.0 && file_ids.contains(&skill.id),
                cfg.file_boost,
            );
            let project = boost_if(
                cfg.project_boost > 0.0
                    && cos >= cfg.min_similarity - PROJECT_GATE_SLACK
                    && project_ids.contains(&skill.id),
                cfg.project_boost,
            );
            let unscored = Hit {
                id: skill.id.clone(),
                name: skill.name.clone(),
                cosine: cos,
                context,
                file,
                project,
                keyword: keyword_score(prompt, &skill.keywords, cfg.keyword_boost),
                phrase: phrase_score(prompt, &skill.trigger_phrases, cfg.phrase_boost),
                score: 0.0,
            };
            let score = unscored.stage1_score();
            Hit { score, ..unscored }
        })
        .collect();

    hits.sort_by(|a, b| cmp_score_desc(a.score, b.score));
    hits
}
/// Keeps the hits whose `score` is at least `cfg.min_similarity`, capped at
/// `cfg.max_skills`.
///
/// Input order is preserved. Given the sorted output of `rank_all_ctx`, this yields
/// the best qualifying hits.
pub fn select(hits: Vec<Hit>, cfg: &Config) -> Vec<Hit> {
    let mut kept = hits;
    kept.retain(|hit| hit.score >= cfg.min_similarity);
    kept.truncate(cfg.max_skills);
    kept
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::index::{Entry, Index};

    /// Empty id set, for whichever of the file/project id sets a test leaves unused.
    fn no_files() -> BTreeSet<String> {
        BTreeSet::new()
    }

    /// Id set built from string literals.
    fn ids(names: &[&str]) -> BTreeSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    /// Index entry whose id and name are both `id`; every other text field is empty.
    /// Shared by every index-building test so a new `Entry` field is added in one place.
    fn entry(id: &str, embedding: Vec<f32>) -> Entry {
        Entry {
            id: id.to_string(),
            name: id.to_string(),
            description: String::new(),
            path: String::new(),
            keywords: Vec::new(),
            trigger_phrases: Vec::new(),
            body_head: String::new(),
            hash: String::new(),
            embedding,
        }
    }

    /// Context enabled (depth 1, weight 0.3, vagueness band 0.55..0.65); file and
    /// project boosts off.
    fn ctx_cfg() -> Config {
        Config {
            context_depth: 1,
            context_weight: 0.3,
            vague_lo: 0.55,
            vague_hi: 0.65,
            file_boost: 0.0,
            project_boost: 0.0,
            ..Default::default()
        }
    }

    /// Two-skill index: "a" along the x axis, "b" along the y axis.
    fn idx2() -> Index {
        Index {
            model: "m".into(),
            dim: 2,
            skills: vec![entry("a", vec![1.0, 0.0]), entry("b", vec![0.0, 1.0])],
        }
    }

    #[test]
    fn context_weight_scales_with_vagueness() {
        let cfg = ctx_cfg();
        assert!((context_weight(0.50, &cfg) - 0.30).abs() < 1e-6);
        assert_eq!(context_weight(0.65, &cfg), 0.0);
        assert!((context_weight(0.60, &cfg) - 0.15).abs() < 1e-6);
    }

    #[test]
    fn context_weight_zero_when_disabled() {
        let off_weight = Config {
            context_depth: 1,
            context_weight: 0.0,
            ..Default::default()
        };
        let off_depth = Config {
            context_depth: 0,
            context_weight: 0.3,
            ..Default::default()
        };
        assert_eq!(context_weight(0.10, &off_weight), 0.0);
        assert_eq!(context_weight(0.10, &off_depth), 0.0);
    }

    #[test]
    fn context_none_matches_plain_rank() {
        let q = [0.5, 0.5];
        let hits = rank_all_ctx(&q, None, &no_files(), &no_files(), "", &idx2(), &ctx_cfg());
        for h in &hits {
            assert_eq!(h.context, 0.0);
            assert!((h.score - h.cosine).abs() < 1e-6);
        }
    }

    #[test]
    fn vague_prompt_lets_context_break_a_tie() {
        let cfg = Config {
            vague_lo: 0.80,
            vague_hi: 0.90,
            ..ctx_cfg()
        };
        let q = [0.5, 0.5];
        let ctx = [1.0, 0.0];
        let hits = rank_all_ctx(&q, Some(&ctx), &no_files(), &no_files(), "", &idx2(), &cfg);
        assert_eq!(hits[0].id, "a");
        assert!(hits[0].context > 0.0);
        let b = hits.iter().find(|h| h.id == "b").unwrap();
        assert_eq!(b.context, 0.0);
    }

    #[test]
    fn confident_prompt_suppresses_context() {
        let q = [1.0, 0.0];
        let ctx = [0.0, 1.0];
        let hits = rank_all_ctx(
            &q,
            Some(&ctx),
            &no_files(),
            &no_files(),
            "",
            &idx2(),
            &ctx_cfg(),
        );
        assert!(hits.iter().all(|h| h.context == 0.0));
        assert_eq!(hits[0].id, "a");
    }

    #[test]
    fn file_boost_lifts_named_skill_ungated() {
        let cfg = Config {
            file_boost: 0.2,
            ..ctx_cfg()
        };
        let q = [1.0, 0.0];
        let files = ids(&["b"]);
        let hits = rank_all_ctx(&q, None, &files, &no_files(), "", &idx2(), &cfg);
        let b = hits.iter().find(|h| h.id == "b").unwrap();
        assert!((b.file - 0.2).abs() < 1e-6);
        let a = hits.iter().find(|h| h.id == "a").unwrap();
        assert_eq!(a.file, 0.0);
    }

    #[test]
    fn file_boost_off_when_zero() {
        let q = [1.0, 0.0];
        let files = ids(&["b"]);
        let hits = rank_all_ctx(&q, None, &files, &no_files(), "", &idx2(), &ctx_cfg());
        assert!(hits.iter().all(|h| h.file == 0.0));
    }

    #[test]
    fn project_boost_gated_on_cosine_floor() {
        let cfg = Config {
            project_boost: 0.2,
            ..ctx_cfg()
        };
        let proj = ids(&["b"]);
        let hits = rank_all_ctx(&[0.0, 1.0], None, &no_files(), &proj, "", &idx2(), &cfg);
        let b = hits.iter().find(|h| h.id == "b").unwrap();
        assert!((b.project - 0.2).abs() < 1e-6);
        let hits = rank_all_ctx(&[1.0, 0.0], None, &no_files(), &proj, "", &idx2(), &cfg);
        let b = hits.iter().find(|h| h.id == "b").unwrap();
        assert_eq!(b.project, 0.0);
    }

    #[test]
    fn project_boost_lifts_near_floor_skill_over_the_line() {
        let cfg = Config {
            project_boost: 0.2,
            min_similarity: 0.30,
            ..ctx_cfg()
        };
        let proj = ids(&["b"]);
        let q = [0.9578, 0.2873];
        let hits = rank_all_ctx(&q, None, &no_files(), &proj, "", &idx2(), &cfg);
        let b = hits.iter().find(|h| h.id == "b").unwrap();
        assert!(b.cosine < cfg.min_similarity, "cosine {}", b.cosine);
        assert!(b.cosine >= cfg.min_similarity - PROJECT_GATE_SLACK);
        assert!((b.project - 0.2).abs() < 1e-6);
        assert!(b.score >= cfg.min_similarity);
    }

    #[test]
    fn project_boost_off_when_zero() {
        let proj = ids(&["b"]);
        let hits = rank_all_ctx(
            &[0.0, 1.0],
            None,
            &no_files(),
            &proj,
            "",
            &idx2(),
            &ctx_cfg(),
        );
        assert!(hits.iter().all(|h| h.project == 0.0));
    }

    #[test]
    fn corrupt_infinite_embedding_cannot_claim_rank() {
        let idx = Index {
            model: "m".into(),
            dim: 2,
            skills: vec![
                entry("corrupt", vec![f32::INFINITY, 0.0]),
                entry("real", vec![1.0, 0.0]),
            ],
        };
        let hits = rank_all(&[1.0, 0.0], "", &idx, &Config::default());
        assert_eq!(hits[0].id, "real");
        assert_eq!(hits.iter().find(|h| h.id == "corrupt").unwrap().score, 0.0);
        assert!(hits.iter().all(|h| h.score.is_finite()));
    }

    #[test]
    fn cosine_bounds() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0, 0.0];
        let c = [0.0, 1.0, 0.0];
        assert!((cosine(&a, &b) - 1.0).abs() < 1e-6);
        assert!(cosine(&a, &c).abs() < 1e-6);
    }

    #[test]
    fn cosine_rejects_dimension_mismatch() {
        let a = [1.0, 0.0, 0.0];
        let b = [1.0, 0.0];
        assert_eq!(cosine(&a, &b), 0.0);
    }

    #[test]
    fn cmp_score_desc_sorts_nan_last_either_side() {
        let mut v = [f32::NAN, 0.5, 2.0, -1.0];
        v.sort_by(|a, b| cmp_score_desc(*a, *b));
        assert_eq!(&v[..3], &[2.0, 0.5, -1.0]);
        assert!(v[3].is_nan());
    }

    #[test]
    fn cmp_score_desc_regular_values_descend() {
        let mut v = [1.0, 3.0, 2.0];
        v.sort_by(|a, b| cmp_score_desc(*a, *b));
        assert_eq!(v, [3.0, 2.0, 1.0]);
    }

    #[test]
    fn keyword_boost_counts_matches() {
        let kw = vec!["uv".to_string(), "setup".to_string()];
        assert!((keyword_score("set up with uv", &kw, 0.1) - 0.1).abs() < 1e-6);
        assert!((keyword_score("uv setup now", &kw, 0.1) - 0.2).abs() < 1e-6);
    }

    #[test]
    fn keyword_boost_matches_across_plural_inflection() {
        let kw = vec!["chart".to_string(), "dependencies".to_string()];
        assert!((keyword_score("make some charts", &kw, 0.1) - 0.1).abs() < 1e-6);
        assert!((keyword_score("add a dependency", &kw, 0.1) - 0.1).abs() < 1e-6);
    }

    #[test]
    fn phrase_matches_across_plural_inflection() {
        let ph = vec!["merge pdf files".to_string()];
        assert!((phrase_score("merge these pdf file chunks", &ph, 0.2) - 0.2).abs() < 1e-6);
        assert!((phrase_score("merging is off topic here", &ph, 0.2) - 0.0).abs() < 1e-6);
    }

    #[test]
    fn phrase_fires_only_when_all_tokens_present() {
        let ph = vec!["screen reader support".to_string()];
        assert!(
            (phrase_score("does my form have screen reader support today", &ph, 0.2) - 0.2).abs()
                < 1e-6
        );
        assert!((phrase_score("support for a screen reader", &ph, 0.2) - 0.2).abs() < 1e-6);
    }

    #[test]
    fn phrase_does_not_fire_on_partial_overlap() {
        let ph = vec!["screen reader support".to_string()];
        assert_eq!(
            phrase_score("split this screen into two panes", &ph, 0.2),
            0.0
        );
        assert_eq!(
            phrase_score(
                "implement a debounce function in vanilla javascript",
                &ph,
                0.2
            ),
            0.0
        );
    }

    #[test]
    fn phrase_score_sums_distinct_phrases() {
        let ph = vec![
            "convert markdown pdf".to_string(),
            "merge two pdf files".to_string(),
        ];
        assert!(
            (phrase_score(
                "convert this markdown to pdf and merge two pdf files",
                &ph,
                0.2
            ) - 0.4)
                .abs()
                < 1e-6
        );
    }
}