use std::cmp::Reverse;
use std::collections::{BinaryHeap, HashMap};
use crate::language::ChunkType;
use crate::math::cosine_similarity;
use crate::store::helpers::{SearchFilter, SearchResult};
use super::config::ScoringConfig;
use super::name_match::NameMatcher;
use super::note_boost::NoteBoostIndex;
pub(crate) fn chunk_importance(name: &str, file_path: &str) -> f32 {
let cfg = &ScoringConfig::DEFAULT;
if crate::is_test_chunk(name, file_path) {
return cfg.importance_test;
}
if name.starts_with('_') && !name.starts_with("__") {
return cfg.importance_private;
}
1.0
}
pub(crate) fn apply_parent_boost(results: &mut [SearchResult]) {
if results.len() < 3 {
return; }
let mut parent_counts: HashMap<String, usize> = HashMap::new();
for r in results.iter() {
if let Some(ref ptn) = r.chunk.parent_type_name {
*parent_counts.entry(ptn.clone()).or_insert(0) += 1;
}
}
if !parent_counts.values().any(|&c| c >= 2) {
return;
}
let cfg = &ScoringConfig::DEFAULT;
let max_children = (cfg.parent_boost_cap - 1.0) / cfg.parent_boost_per_child;
let mut boosted = false;
for r in results.iter_mut() {
let is_container = matches!(
r.chunk.chunk_type,
ChunkType::Class | ChunkType::Struct | ChunkType::Interface
);
if !is_container {
continue;
}
if let Some(&count) = parent_counts.get(&r.chunk.name) {
if count >= 2 {
let boost =
1.0 + cfg.parent_boost_per_child * (count as f32 - 1.0).min(max_children);
tracing::debug!(
name = %r.chunk.name,
child_count = count,
boost = %boost,
"parent_boost: boosting container"
);
r.score *= boost;
boosted = true;
}
}
}
if boosted {
results.sort_by(|a, b| b.score.total_cmp(&a.score));
}
}
pub(crate) struct BoundedScoreHeap {
heap: BinaryHeap<Reverse<(OrderedFloat, String)>>,
capacity: usize,
}
#[derive(Clone, Copy, PartialEq)]
struct OrderedFloat(f32);
impl Eq for OrderedFloat {}
impl PartialOrd for OrderedFloat {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrderedFloat {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.0.total_cmp(&other.0)
}
}
impl BoundedScoreHeap {
pub fn new(capacity: usize) -> Self {
Self {
heap: BinaryHeap::with_capacity(capacity + 1),
capacity,
}
}
pub fn push(&mut self, id: String, score: f32) {
if !score.is_finite() {
tracing::warn!("BoundedScoreHeap: ignoring non-finite score");
return;
}
if self.heap.len() < self.capacity {
self.heap.push(Reverse((OrderedFloat(score), id)));
return;
}
if let Some(Reverse((OrderedFloat(min_score), _))) = self.heap.peek() {
if score > *min_score {
self.heap.pop();
self.heap.push(Reverse((OrderedFloat(score), id)));
}
}
}
pub fn into_sorted_vec(self) -> Vec<(String, f32)> {
let mut results: Vec<_> = self
.heap
.into_iter()
.map(|Reverse((OrderedFloat(score), id))| (id, score))
.collect();
results.sort_by(|a, b| b.1.total_cmp(&a.1));
results
}
}
pub(crate) struct ScoringContext<'a> {
pub query: &'a [f32],
pub filter: &'a SearchFilter,
pub name_matcher: Option<&'a NameMatcher>,
pub glob_matcher: Option<&'a globset::GlobMatcher>,
pub note_index: &'a NoteBoostIndex<'a>,
pub threshold: f32,
}
pub(crate) fn score_candidate(
embedding: &[f32],
name: Option<&str>,
file_part: &str,
ctx: &ScoringContext<'_>,
) -> Option<f32> {
let embedding_score = cosine_similarity(ctx.query, embedding)?;
let base_score = if let Some(matcher) = ctx.name_matcher {
let n = name.unwrap_or("");
let name_score = matcher.score(n);
(1.0 - ctx.filter.name_boost) * embedding_score + ctx.filter.name_boost * name_score
} else {
embedding_score
};
if let Some(matcher) = ctx.glob_matcher {
if !matcher.is_match(file_part) {
return None;
}
}
let chunk_name = name.unwrap_or("");
let mut score = base_score.max(0.0) * ctx.note_index.boost(file_part, chunk_name);
if ctx.filter.enable_demotion {
score *= chunk_importance(chunk_name, file_part);
}
if score >= ctx.threshold {
Some(score)
} else {
None
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::store::helpers::{ChunkSummary, NoteSummary, SearchFilter};
#[test]
fn test_bounded_heap_equal_scores() {
let mut heap = BoundedScoreHeap::new(2);
heap.push("a".to_string(), 0.5);
heap.push("b".to_string(), 0.5);
heap.push("c".to_string(), 0.5);
let results = heap.into_sorted_vec();
assert_eq!(results.len(), 2);
assert!(results.iter().any(|(id, _)| id == "a"));
assert!(results.iter().any(|(id, _)| id == "b"));
}
#[test]
fn test_bounded_heap_evicts_lowest() {
let mut heap = BoundedScoreHeap::new(2);
heap.push("low".to_string(), 0.1);
heap.push("mid".to_string(), 0.5);
heap.push("high".to_string(), 0.9);
let results = heap.into_sorted_vec();
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, "high");
assert_eq!(results[1].0, "mid");
}
#[test]
fn test_bounded_heap_ignores_non_finite() {
let mut heap = BoundedScoreHeap::new(5);
heap.push("nan".to_string(), f32::NAN);
heap.push("inf".to_string(), f32::INFINITY);
heap.push("neginf".to_string(), f32::NEG_INFINITY);
heap.push("ok".to_string(), 0.5);
let results = heap.into_sorted_vec();
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, "ok");
}
#[test]
fn test_bounded_heap_empty() {
let heap = BoundedScoreHeap::new(5);
let results = heap.into_sorted_vec();
assert!(results.is_empty());
}
fn make_result(
name: &str,
chunk_type: ChunkType,
parent_type_name: Option<&str>,
score: f32,
) -> SearchResult {
SearchResult {
chunk: ChunkSummary {
id: name.to_string(),
file: std::path::PathBuf::from("test.ts"),
language: crate::parser::Language::TypeScript,
chunk_type,
name: name.to_string(),
signature: String::new(),
content: String::new(),
doc: None,
line_start: 1,
line_end: 10,
parent_id: None,
parent_type_name: parent_type_name.map(|s| s.to_string()),
content_hash: String::new(),
window_idx: None,
},
score,
}
}
#[test]
fn test_parent_boost_circuit_breaker() {
let mut results = vec![
make_result(
"recordFailure",
ChunkType::Method,
Some("CircuitBreaker"),
0.88,
),
make_result(
"retryWithBackoff",
ChunkType::Method,
Some("CircuitBreaker"),
0.86,
),
make_result(
"shouldAllow",
ChunkType::Method,
Some("CircuitBreaker"),
0.85,
),
make_result("CircuitBreaker", ChunkType::Class, None, 0.82),
];
apply_parent_boost(&mut results);
assert_eq!(results[0].chunk.name, "CircuitBreaker");
assert!(results[0].score > 0.90);
}
#[test]
fn test_parent_boost_no_effect_on_standalone_functions() {
let mut results = vec![
make_result("_insertionSortSmall", ChunkType::Function, None, 0.88),
make_result("insertionSort", ChunkType::Function, None, 0.85),
make_result("mergeSort", ChunkType::Function, None, 0.80),
];
let scores_before: Vec<f32> = results.iter().map(|r| r.score).collect();
apply_parent_boost(&mut results);
let scores_after: Vec<f32> = results.iter().map(|r| r.score).collect();
assert_eq!(scores_before, scores_after);
}
#[test]
fn test_parent_boost_needs_minimum_two_children() {
let mut results = vec![
make_result(
"recordFailure",
ChunkType::Method,
Some("CircuitBreaker"),
0.88,
),
make_result("CircuitBreaker", ChunkType::Class, None, 0.82),
make_result("unrelatedFn", ChunkType::Function, None, 0.80),
];
apply_parent_boost(&mut results);
assert_eq!(results[0].chunk.name, "recordFailure");
assert_eq!(results[1].chunk.name, "CircuitBreaker");
}
#[test]
fn test_parent_boost_caps_at_1_15() {
let mut results = vec![
make_result("m1", ChunkType::Method, Some("BigClass"), 0.88),
make_result("m2", ChunkType::Method, Some("BigClass"), 0.87),
make_result("m3", ChunkType::Method, Some("BigClass"), 0.86),
make_result("m4", ChunkType::Method, Some("BigClass"), 0.85),
make_result("m5", ChunkType::Method, Some("BigClass"), 0.84),
make_result("BigClass", ChunkType::Class, None, 0.78),
];
apply_parent_boost(&mut results);
let class_score = results
.iter()
.find(|r| r.chunk.name == "BigClass")
.unwrap()
.score;
assert!(
(class_score - 0.897).abs() < 0.001,
"Expected ~0.897, got {class_score}"
);
}
#[test]
fn test_parent_boost_too_few_results() {
let mut results = vec![
make_result("foo", ChunkType::Method, Some("Bar"), 0.88),
make_result("Bar", ChunkType::Class, None, 0.82),
];
let score_before = results[1].score;
apply_parent_boost(&mut results);
assert_eq!(results[1].score, score_before);
}
#[test]
fn test_chunk_importance_normal() {
assert_eq!(chunk_importance("parse_config", "src/lib.rs"), 1.0);
}
#[test]
fn test_chunk_importance_test_prefix() {
assert_eq!(chunk_importance("test_parse_config", "src/lib.rs"), 0.70);
}
#[test]
fn test_chunk_importance_test_upper() {
assert_eq!(
chunk_importance("TestParseConfig", "src/lib.go"),
ScoringConfig::DEFAULT.importance_test
);
}
#[test]
fn test_chunk_importance_underscore() {
assert_eq!(
chunk_importance("_helper", "src/lib.rs"),
ScoringConfig::DEFAULT.importance_private
);
}
#[test]
fn test_chunk_importance_dunder_not_demoted() {
assert_eq!(chunk_importance("__init__", "src/lib.py"), 1.0);
}
#[test]
fn test_chunk_importance_test_file() {
assert_eq!(
chunk_importance("helper_fn", "src/foo_test.rs"),
ScoringConfig::DEFAULT.importance_test
);
}
#[test]
fn test_chunk_importance_test_dir_demoted() {
assert_eq!(
chunk_importance("real_fn", "tests/fixtures/eval.rs"),
ScoringConfig::DEFAULT.importance_test
);
}
#[test]
fn test_chunk_importance_test_name_beats_path() {
assert_eq!(
chunk_importance("test_foo", "src/lib.rs"),
ScoringConfig::DEFAULT.importance_test
);
}
fn test_embedding(seed: f32) -> Vec<f32> {
let mut v = vec![seed; crate::EMBEDDING_DIM];
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for x in &mut v {
*x /= norm;
}
}
v
}
fn make_note(sentiment: f32, mentions: &[&str]) -> NoteSummary {
NoteSummary {
id: "note:test".to_string(),
text: "test note".to_string(),
sentiment,
mentions: mentions.iter().map(|s| s.to_string()).collect(),
}
}
#[test]
fn test_score_candidate_basic() {
let emb = test_embedding(1.0);
let query = test_embedding(1.0);
let filter = SearchFilter::default();
let note_index = NoteBoostIndex::new(&[]);
let ctx = ScoringContext {
query: &query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let score = score_candidate(&emb, None, "src/lib.rs", &ctx);
assert!(score.is_some());
assert!(
score.unwrap() > 0.9,
"Self-similarity should be ~1.0, got {}",
score.unwrap()
);
}
#[test]
fn test_score_candidate_below_threshold() {
let emb = test_embedding(1.0);
let query = test_embedding(-1.0);
let filter = SearchFilter::default();
let note_index = NoteBoostIndex::new(&[]);
let ctx = ScoringContext {
query: &query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.5,
};
let score = score_candidate(&emb, None, "src/lib.rs", &ctx);
assert!(
score.is_none(),
"Opposite vectors should be below 0.5 threshold"
);
}
#[test]
fn test_score_candidate_glob_filters() {
let emb = test_embedding(1.0);
let query = test_embedding(1.0);
let filter = SearchFilter::default();
let note_index = NoteBoostIndex::new(&[]);
let glob = globset::Glob::new("src/**/*.rs").unwrap().compile_matcher();
let ctx = ScoringContext {
query: &query,
filter: &filter,
name_matcher: None,
glob_matcher: Some(&glob),
note_index: ¬e_index,
threshold: 0.0,
};
let score = score_candidate(&emb, None, "src/lib.rs", &ctx);
assert!(score.is_some());
let score = score_candidate(&emb, None, "tests/foo.py", &ctx);
assert!(score.is_none());
}
#[test]
fn test_score_candidate_name_boost() {
let emb = test_embedding(1.0);
let query = test_embedding(1.0);
let filter_no_boost = SearchFilter::default();
let filter_with_boost = SearchFilter {
name_boost: 0.3,
query_text: "parseConfig".to_string(),
..Default::default()
};
let note_index = NoteBoostIndex::new(&[]);
let matcher = NameMatcher::new("parseConfig");
let ctx_no = ScoringContext {
query: &query,
filter: &filter_no_boost,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let score_no = score_candidate(&emb, Some("parseConfig"), "src/a.rs", &ctx_no).unwrap();
let ctx_yes = ScoringContext {
query: &query,
filter: &filter_with_boost,
name_matcher: Some(&matcher),
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let score_yes = score_candidate(&emb, Some("parseConfig"), "src/a.rs", &ctx_yes).unwrap();
assert!(score_yes > 0.0);
assert!(score_no > 0.0);
}
#[test]
fn test_score_candidate_demotion() {
let emb = test_embedding(1.0);
let query = test_embedding(1.0);
let note_index = NoteBoostIndex::new(&[]);
let filter_no_demote = SearchFilter {
enable_demotion: false,
..Default::default()
};
let filter_demote = SearchFilter {
enable_demotion: true,
..Default::default()
};
let ctx_demote = ScoringContext {
query: &query,
filter: &filter_demote,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let score_normal =
score_candidate(&emb, Some("real_fn"), "src/lib.rs", &ctx_demote).unwrap();
let score_test =
score_candidate(&emb, Some("test_foo"), "src/lib.rs", &ctx_demote).unwrap();
let ctx_no_demote = ScoringContext {
query: &query,
filter: &filter_no_demote,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let score_no_demote =
score_candidate(&emb, Some("test_foo"), "src/lib.rs", &ctx_no_demote).unwrap();
assert!(score_test < score_normal, "test_ should be demoted");
assert!(
(score_no_demote - score_normal).abs() < 0.001,
"No demotion without flag"
);
}
#[test]
fn test_score_candidate_note_boost() {
let emb = test_embedding(1.0);
let query = test_embedding(1.0);
let filter = SearchFilter::default();
let notes = vec![make_note(1.0, &["lib.rs"])];
let note_index_boosted = NoteBoostIndex::new(¬es);
let note_index_empty = NoteBoostIndex::new(&[]);
let ctx_boosted = ScoringContext {
query: &query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index_boosted,
threshold: 0.0,
};
let score_boosted =
score_candidate(&emb, Some("my_fn"), "src/lib.rs", &ctx_boosted).unwrap();
let ctx_plain = ScoringContext {
query: &query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index_empty,
threshold: 0.0,
};
let score_plain = score_candidate(&emb, Some("my_fn"), "src/lib.rs", &ctx_plain).unwrap();
assert!(
score_boosted > score_plain,
"Positive note should boost score"
);
}
#[test]
fn heap_all_nan_scores() {
let mut heap = BoundedScoreHeap::new(5);
heap.push("a".to_string(), f32::NAN);
heap.push("b".to_string(), f32::NAN);
heap.push("c".to_string(), f32::NAN);
let results = heap.into_sorted_vec();
assert!(
results.is_empty(),
"All NaN scores should produce empty results, got {} items",
results.len()
);
}
#[test]
fn heap_mixed_valid_and_nan() {
let mut heap = BoundedScoreHeap::new(10);
heap.push("nan1".to_string(), f32::NAN);
heap.push("ok1".to_string(), 0.7);
heap.push("inf".to_string(), f32::INFINITY);
heap.push("ok2".to_string(), 0.9);
heap.push("nan2".to_string(), f32::NAN);
heap.push("neginf".to_string(), f32::NEG_INFINITY);
heap.push("ok3".to_string(), 0.5);
let results = heap.into_sorted_vec();
assert_eq!(results.len(), 3, "Only finite scores should be kept");
for (id, score) in &results {
assert!(
score.is_finite(),
"Result '{id}' has non-finite score {score}"
);
}
assert_eq!(results[0].0, "ok2");
assert_eq!(results[1].0, "ok1");
assert_eq!(results[2].0, "ok3");
}
#[test]
fn heap_negative_scores() {
let mut heap = BoundedScoreHeap::new(5);
heap.push("a".to_string(), -0.1);
heap.push("b".to_string(), -0.5);
heap.push("c".to_string(), -0.3);
let results = heap.into_sorted_vec();
assert_eq!(results.len(), 3, "All negative scores should be kept");
assert_eq!(results[0].0, "a", "Least negative should be first");
assert_eq!(results[1].0, "c");
assert_eq!(results[2].0, "b", "Most negative should be last");
}
#[test]
fn heap_capacity_zero() {
let mut heap = BoundedScoreHeap::new(0);
heap.push("a".to_string(), 0.9);
heap.push("b".to_string(), 0.8);
let results = heap.into_sorted_vec();
assert!(
results.is_empty(),
"Capacity-0 heap should always be empty, got {} items",
results.len()
);
}
#[test]
fn score_candidate_nan_embedding_filtered() {
let query = test_embedding(1.0);
let mut nan_emb = vec![f32::NAN; crate::EMBEDDING_DIM];
nan_emb[0] = 0.5;
nan_emb[1] = 0.3;
let filter = SearchFilter::default();
let note_index = NoteBoostIndex::new(&[]);
let ctx = ScoringContext {
query: &query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let result = score_candidate(&nan_emb, Some("nan_fn"), "src/lib.rs", &ctx);
assert!(
result.is_none(),
"NaN embedding should be filtered out (return None), got {:?}",
result
);
}
#[test]
fn score_candidate_nan_query_filtered() {
let nan_query = vec![f32::NAN; crate::EMBEDDING_DIM];
let normal_emb = test_embedding(1.0);
let filter = SearchFilter::default();
let note_index = NoteBoostIndex::new(&[]);
let ctx = ScoringContext {
query: &nan_query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let result = score_candidate(&normal_emb, Some("my_fn"), "src/lib.rs", &ctx);
assert!(
result.is_none(),
"NaN query should be filtered out (return None), got {:?}",
result
);
}
#[test]
fn score_candidate_nan_both_filtered() {
let nan_query = vec![f32::NAN; crate::EMBEDDING_DIM];
let nan_emb = vec![f32::NAN; crate::EMBEDDING_DIM];
let filter = SearchFilter::default();
let note_index = NoteBoostIndex::new(&[]);
let ctx = ScoringContext {
query: &nan_query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let result = score_candidate(&nan_emb, Some("fn"), "src/lib.rs", &ctx);
assert!(
result.is_none(),
"All-NaN inputs should be filtered out, got {:?}",
result
);
}
#[test]
fn score_candidate_zero_embedding() {
let zero_query = vec![0.0f32; crate::EMBEDDING_DIM];
let normal_emb = test_embedding(1.0);
let filter = SearchFilter {
query_text: "test".into(),
..Default::default()
};
let notes: Vec<NoteSummary> = vec![];
let note_index = NoteBoostIndex::new(¬es);
let ctx = ScoringContext {
query: &zero_query,
filter: &filter,
name_matcher: None,
glob_matcher: None,
note_index: ¬e_index,
threshold: 0.0,
};
let result = score_candidate(&normal_emb, None, "src/lib.rs", &ctx);
match result {
None => {}
Some(v) => assert!(
v.is_finite(),
"score_candidate with zero query must return finite score, got {v}"
),
}
}
}