use std::collections::HashMap;
use crate::bm25::Bm25Index;
use crate::encoder::{SemanticIndex, StaticEncoder};
use crate::graph::DependencyGraph;
use crate::model::{Chunk, MatchLine, SearchResult};
use crate::ranking::{
apply_query_boost, boost_multi_chunk_files, rerank_topk, rerank_topk_for_bm25_code,
resolve_alpha,
};
use crate::tokens::tokenize;
/// Constant added to the (1-based) rank in the reciprocal-rank score
/// `1 / (RRF_K + rank + 1)` computed by `rrf_scores`.
const RRF_K: f64 = 60.0;
/// Fraction of the top result's score used as the cutoff in
/// `filter_low_scores`: results scoring below `top * MIN_SCORE_RATIO` are dropped.
const MIN_SCORE_RATIO: f64 = 0.12;
/// Returns the `k` highest strictly-positive entries of `raw`, best first.
///
/// Each element of the result is `(position in raw, score widened to f64)`.
/// Entries that are zero, negative or NaN are skipped. The sort is stable,
/// so equal scores keep their original (ascending position) order.
fn top_scored(raw: &[f32], k: usize) -> Vec<(usize, f64)> {
    let mut positives: Vec<(usize, f64)> = Vec::new();
    for (position, &value) in raw.iter().enumerate() {
        // `value > 0.0` is false for NaN, so NaN never reaches the sort below.
        if value > 0.0 {
            positives.push((position, f64::from(value)));
        }
    }

    // Descending by score.
    positives.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    positives.truncate(k);
    positives
}
/// Converts raw scores into reciprocal-rank scores.
///
/// Entries are ranked by descending score and each index is mapped to
/// `1 / (RRF_K + rank + 1)` where `rank` starts at 0 for the best entry.
///
/// The ranking is deterministic:
/// * equal scores are ordered by ascending index (rather than by the
///   arbitrary iteration order of the input map);
/// * NaN scores are treated as the worst possible score and rank last.
///
/// Returns an empty map when `scores` is empty.
fn rrf_scores(scores: &HashMap<usize, f64>) -> HashMap<usize, f64> {
    if scores.is_empty() {
        return HashMap::new();
    }

    // NaN has no ordering; mapping it to -inf keeps the comparator below a
    // genuine total order and pushes such entries to the bottom of the ranking.
    let sort_key = |score: f64| if score.is_nan() { f64::NEG_INFINITY } else { score };

    let mut ranked: Vec<(usize, f64)> = scores
        .iter()
        .map(|(&idx, &score)| (idx, sort_key(score)))
        .collect();

    // Indices are unique map keys, so the tie-break makes every pair strictly
    // ordered and an unstable sort yields a single well-defined result.
    ranked.sort_unstable_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });

    ranked
        .into_iter()
        .enumerate()
        .map(|(rank, (idx, _))| (idx, 1.0 / (RRF_K + rank as f64 + 1.0)))
        .collect()
}
/// Expands an optional list of selected indices into a boolean mask.
///
/// Returns `None` when no selector is given. Otherwise returns a vector of
/// length `size` whose entry `i` is `true` exactly when `i` appears in the
/// selector; indices `>= size` are ignored.
fn selector_to_mask(selector: Option<&[usize]>, size: usize) -> Option<Vec<bool>> {
    selector.map(|wanted| {
        let mut flags = vec![false; size];
        for &position in wanted {
            // `get_mut` yields `None` for out-of-range positions, which are skipped.
            if let Some(slot) = flags.get_mut(position) {
                *slot = true;
            }
        }
        flags
    })
}
/// Collects the lines of `chunk` that contain any keyword from `query`.
///
/// Keywords are the whitespace-separated words of the lower-cased query that
/// are at least two bytes long. Matching is a case-insensitive substring
/// test. Each hit records the line number (`chunk.start_line` plus the
/// zero-based offset inside the chunk) and the trimmed line text.
/// Returns an empty vector when the query yields no keywords.
fn find_match_lines(chunk: &Chunk, query: &str) -> Vec<MatchLine> {
    let lowered_query = query.to_lowercase();
    let needles: Vec<&str> = lowered_query
        .split_whitespace()
        .filter(|term| term.len() > 1)
        .collect();
    if needles.is_empty() {
        return Vec::new();
    }

    chunk
        .content
        .lines()
        .enumerate()
        .filter_map(|(offset, text)| {
            let haystack = text.to_lowercase();
            let hit = needles.iter().any(|needle| haystack.contains(*needle));
            hit.then(|| MatchLine {
                line: chunk.start_line + offset,
                content: text.trim().to_string(),
            })
        })
        .collect()
}
/// Boosts chunks whose content contains keywords from `query`.
///
/// Keywords are the whitespace-separated words of the lower-cased query that
/// are at least three bytes long; nothing happens if there are none.
///
/// For every chunk containing at least one keyword (case-insensitive
/// substring test, counted once per distinct keyword entry):
/// * if the chunk already has a score, it is increased by
///   `score * 0.3 * matched_keywords`;
/// * otherwise it is inserted with
///   `best_file_score * (0.8 + 0.2 * matched_keywords)`, where
///   `best_file_score` is the highest pre-boost score among the matching
///   chunks of the same file (unscored chunks count as `0.001`, floor `0.0`).
fn boost_sibling_chunks(scores: &mut HashMap<usize, f64>, chunks: &[Chunk], query: &str) {
    let lowered = query.to_lowercase();
    let terms: Vec<&str> = lowered
        .split_whitespace()
        .filter(|term| term.len() > 2)
        .collect();
    if terms.is_empty() {
        return;
    }

    // Number of query terms found in each chunk, indexed like `chunks`.
    let hits_per_chunk: Vec<usize> = chunks
        .iter()
        .map(|chunk| {
            let body = chunk.content.to_lowercase();
            terms.iter().filter(|term| body.contains(**term)).count()
        })
        .collect();

    // Best pre-boost score of any matching chunk, per file. This is computed
    // in full before any score is modified below.
    let mut best_in_file: HashMap<&str, f64> = HashMap::new();
    for (position, chunk) in chunks.iter().enumerate() {
        if hits_per_chunk[position] == 0 {
            continue;
        }
        let current = scores.get(&position).copied().unwrap_or(0.001);
        let best = best_in_file.entry(chunk.file_path.as_str()).or_insert(0.0);
        if current > *best {
            *best = current;
        }
    }

    for (position, (chunk, &hits)) in chunks.iter().zip(&hits_per_chunk).enumerate() {
        if hits == 0 {
            continue;
        }
        let factor = hits as f64;
        match scores.get_mut(&position) {
            Some(current) => *current += *current * 0.3 * factor,
            None => {
                if let Some(&file_best) = best_in_file.get(chunk.file_path.as_str()) {
                    scores.insert(position, file_best * (0.8 + 0.2 * factor));
                }
            }
        }
    }
}
/// Drops results whose score is far below the first result's score.
///
/// The first element is taken as the top score (the input is assumed to be
/// ordered best-first — confirm against callers). Lists with zero or one
/// element are returned untouched. If the top score is not positive, an
/// empty list is returned. Otherwise only results with
/// `score >= top_score * MIN_SCORE_RATIO` are kept, in their original order.
fn filter_low_scores(results: Vec<SearchResult>) -> Vec<SearchResult> {
    if results.len() < 2 {
        return results;
    }

    let best = results[0].score;
    if best <= 0.0 {
        return Vec::new();
    }

    let threshold = best * MIN_SCORE_RATIO;
    let mut kept = results;
    kept.retain(|candidate| candidate.score >= threshold);
    kept
}
/// Returns `true` when the chunk has a language tag other than
/// `"markdown"` or `"text"`; chunks without a language are not code.
fn is_code_chunk(chunk: &Chunk) -> bool {
    match chunk.language.as_deref() {
        None | Some("markdown") | Some("text") => false,
        Some(_) => true,
    }
}
/// Seeds scores for chunks in files linked to the top-scoring files through
/// the dependency graph.
///
/// "Top files" are the files of all chunks scoring at least half of the
/// maximum score. With `boost = max_score * 0.3`:
/// * chunks of files listed in `graph.deps(top).depends_on` are inserted with
///   `boost * 0.5`;
/// * chunks of files returned by `graph.dependents(top)` are inserted with
///   `boost * 0.3`.
///
/// Existing scores are never overwritten. Does nothing when `scores` is empty
/// or the maximum score is not positive.
///
/// All dependency boosts are applied before any dependent boost. A chunk that
/// is reachable both ways through different top files therefore always gets
/// the larger `boost * 0.5`, instead of whichever value happened to be
/// inserted first under the arbitrary iteration order of `scores`.
///
/// # Parameters
/// * `scores` - chunk index -> score map, updated in place.
/// * `chunks` - all chunks; every key of `scores` must be a valid index.
/// * `graph` - dependency graph queried by file path.
/// * `file_mapping` - file path -> indices of that file's chunks.
fn boost_from_graph(
    scores: &mut HashMap<usize, f64>,
    chunks: &[Chunk],
    graph: &DependencyGraph,
    file_mapping: &HashMap<String, Vec<usize>>,
) {
    if scores.is_empty() {
        return;
    }
    let max_score = scores.values().copied().fold(f64::NEG_INFINITY, f64::max);
    if max_score <= 0.0 {
        return;
    }

    let threshold = max_score * 0.5;
    let mut top_files: Vec<&str> = scores
        .iter()
        .filter(|&(_, &score)| score >= threshold)
        .map(|(&idx, _)| chunks[idx].file_path.as_str())
        .collect();
    // Several top chunks often share a file; query the graph once per file.
    top_files.sort_unstable();
    top_files.dedup();

    let boost = max_score * 0.3;

    // Pass 1: files the top files depend on (stronger boost).
    for top_fp in &top_files {
        if let Some(node) = graph.deps(top_fp) {
            for dep in &node.depends_on {
                if let Some(indices) = file_mapping.get(dep.as_str()) {
                    for &idx in indices {
                        scores.entry(idx).or_insert(boost * 0.5);
                    }
                }
            }
        }
    }

    // Pass 2: files that depend on the top files (weaker boost). Runs after
    // pass 1 so it can never pre-empt a stronger dependency boost.
    for top_fp in &top_files {
        for dep_fp in graph.dependents(top_fp) {
            if let Some(indices) = file_mapping.get(dep_fp) {
                for &idx in indices {
                    scores.entry(idx).or_insert(boost * 0.3);
                }
            }
        }
    }
}
/// Keyword (BM25) search over `chunks`.
///
/// The query is tokenized; an empty token list yields no results. BM25 scores
/// are computed (restricted to `selector` when given) and the best
/// `top_k * 8` positive-scoring chunks form the candidate pool.
///
/// * If no candidate is a code chunk, the pool is cut to `top_k` in BM25
///   order.
/// * Otherwise the candidates go through `apply_query_boost` and
///   `rerank_topk_for_bm25_code`.
///
/// Either way each hit is annotated with its matching lines and the list is
/// passed through `filter_low_scores`.
pub fn search_bm25(
    query: &str,
    bm25_index: &Bm25Index,
    chunks: &[Chunk],
    top_k: usize,
    selector: Option<&[usize]>,
) -> Vec<SearchResult> {
    let query_terms = tokenize(query);
    if query_terms.is_empty() {
        return Vec::new();
    }

    // Shared conversion from a ranked `(chunk index, score)` pair to a result.
    let to_result = |(chunk_idx, score): (usize, f64)| {
        let match_lines = find_match_lines(&chunks[chunk_idx], query);
        SearchResult {
            chunk: chunks[chunk_idx].clone(),
            score,
            match_lines,
        }
    };

    let allowed = selector_to_mask(selector, chunks.len());
    let raw_scores = bm25_index.get_scores(&query_terms, allowed.as_deref());
    let pool_size = top_k.saturating_mul(8).max(top_k);
    let mut candidates = top_scored(&raw_scores, pool_size);

    let any_code = candidates
        .iter()
        .any(|&(chunk_idx, _)| is_code_chunk(&chunks[chunk_idx]));

    let hits: Vec<SearchResult> = if any_code {
        let mut boosted: HashMap<usize, f64> = candidates.into_iter().collect();
        apply_query_boost(&mut boosted, query, chunks);
        rerank_topk_for_bm25_code(&boosted, chunks, top_k)
            .into_iter()
            .map(to_result)
            .collect()
    } else {
        candidates.truncate(top_k);
        candidates.into_iter().map(to_result).collect()
    };

    filter_low_scores(hits)
}
/// Borrowed inputs that `search_hybrid` needs besides the query itself.
pub(crate) struct HybridSearchContext<'a> {
/// Encoder used to embed the query text (`encode_single`).
pub(crate) encoder: &'a StaticEncoder,
/// Index queried with the query embedding for semantic candidates.
pub(crate) semantic_index: &'a SemanticIndex,
/// BM25 index used for keyword scores and for the BM25-only fallback.
pub(crate) bm25_index: &'a Bm25Index,
/// All chunks; scores throughout the search are keyed by index into this slice.
pub(crate) chunks: &'a [Chunk],
/// Optional dependency graph; when present, `boost_from_graph` is applied.
pub(crate) graph: Option<&'a DependencyGraph>,
/// File path -> indices of that file's chunks; consulted by `boost_from_graph`.
pub(crate) file_mapping: &'a HashMap<String, Vec<usize>>,
}
/// Hybrid search combining semantic similarity and BM25 keyword scores.
///
/// Steps:
/// 1. Resolve the semantic weight via `resolve_alpha(query, alpha)`.
/// 2. Encode the query; if encoding fails, log a warning and fall back to
///    [`search_bm25`].
/// 3. Fetch up to `top_k * 5` semantic candidates (score `1 - distance`) and
///    up to `top_k * 5` positive BM25 candidates, both restricted to
///    `selector` when given.
/// 4. Convert each list to reciprocal-rank scores and blend them as
///    `alpha * semantic + (1 - alpha) * bm25` (a missing side counts as 0).
/// 5. Apply the multi-chunk-file, query, sibling and (if a graph is present)
///    dependency-graph boosts, in that order.
/// 6. Rerank to `top_k`, attach matching lines, and drop low-scoring results.
///
/// # Parameters
/// * `query` - raw query text.
/// * `context` - indexes, chunks and optional graph to search.
/// * `top_k` - maximum number of results requested from the reranker.
/// * `alpha` - optional semantic-weight override forwarded to `resolve_alpha`.
/// * `selector` - optional subset of chunk indices to restrict the search to.
pub(crate) fn search_hybrid(
    query: &str,
    context: HybridSearchContext<'_>,
    top_k: usize,
    alpha: Option<f64>,
    selector: Option<&[usize]>,
) -> Vec<SearchResult> {
    let alpha_weight = resolve_alpha(query, alpha);
    // Saturate instead of overflowing on huge `top_k` (a plain `*` panics in
    // debug builds and wraps in release); mirrors `search_bm25`.
    let candidate_count = top_k.saturating_mul(5);
    let chunks = context.chunks;

    let query_embedding = match context.encoder.encode_single(query) {
        Ok(e) => e,
        Err(err) => {
            log::warn!("semantic query encoding failed; falling back to BM25-only search: {err:#}");
            return search_bm25(query, context.bm25_index, chunks, top_k, selector);
        }
    };

    let semantic_results =
        context
            .semantic_index
            .query(&query_embedding, candidate_count, selector);
    // The index reports distances; turn them into similarities.
    let semantic_scores: HashMap<usize, f64> = semantic_results
        .iter()
        .map(|&(idx, dist)| (idx, (1.0 - dist) as f64))
        .collect();

    let tokens = tokenize(query);
    let bm25_scores: HashMap<usize, f64> = if tokens.is_empty() {
        HashMap::new()
    } else {
        let mask = selector_to_mask(selector, chunks.len());
        let raw_scores = context.bm25_index.get_scores(&tokens, mask.as_deref());
        top_scored(&raw_scores, candidate_count)
            .into_iter()
            .collect()
    };

    // Rank-based normalisation makes the two score scales comparable.
    let norm_semantic = rrf_scores(&semantic_scores);
    let norm_bm25 = rrf_scores(&bm25_scores);

    let mut combined: HashMap<usize, f64> =
        HashMap::with_capacity(norm_semantic.len() + norm_bm25.len());
    for &idx in norm_semantic.keys().chain(norm_bm25.keys()) {
        // An index present in both maps is visited twice; `or_insert_with`
        // computes the blended score only the first time.
        combined.entry(idx).or_insert_with(|| {
            let sem = norm_semantic.get(&idx).copied().unwrap_or(0.0);
            let bm = norm_bm25.get(&idx).copied().unwrap_or(0.0);
            alpha_weight * sem + (1.0 - alpha_weight) * bm
        });
    }

    boost_multi_chunk_files(&mut combined, chunks);
    apply_query_boost(&mut combined, query, chunks);
    boost_sibling_chunks(&mut combined, chunks, query);
    if let Some(graph) = context.graph {
        boost_from_graph(&mut combined, chunks, graph, context.file_mapping);
    }

    let ranked = rerank_topk(&combined, chunks, top_k, alpha_weight < 1.0);
    let results: Vec<SearchResult> = ranked
        .into_iter()
        .map(|(idx, score)| {
            let match_lines = find_match_lines(&chunks[idx], query);
            SearchResult {
                chunk: chunks[idx].clone(),
                score,
                match_lines,
            }
        })
        .collect();
    filter_low_scores(results)
}