use std::collections::{HashMap, HashSet};
use anyhow::{Context, Result};
use crate::core::classifier::{QueryClassifier, QueryIntent};
use crate::core::entity::EdgeKind;
use crate::core::git::{normalize_path, resolve_branch_files};
use crate::core::search::rrf::{rrf_fuse, RRF_K};
use super::{
build_compact_snippet, compute_match_reason, file_type_score_multiplier, hash_query,
raw_to_code_chunk, CodeChunk, CodeIndexer, SearchQuery, HNSW_OVERSAMPLE, KG_EXPAND_HOPS,
};
/// Lower bound for the branch-boost multiplier; `1.0` leaves scores untouched.
pub(crate) const BRANCH_BOOST_MIN: f32 = 1.0_f32;
/// Upper bound for the branch-boost multiplier applied to on-branch chunks.
pub(crate) const BRANCH_BOOST_MAX: f32 = 3.0_f32;
pub(crate) fn resolve_branch_set(
query: &SearchQuery,
root_path: &std::path::Path,
) -> (Option<HashSet<String>>, f32) {
let boost = query.branch_boost.clamp(BRANCH_BOOST_MIN, BRANCH_BOOST_MAX);
let files: Option<Vec<String>> = match &query.branch_files {
Some(v) if !v.is_empty() => Some(v.clone()),
_ => match &query.branch {
Some(name) => resolve_branch_files(root_path, name),
None => None,
},
};
let set = files.and_then(|v| {
let s: HashSet<String> = v.iter().map(|p| normalize_path(p).to_owned()).collect();
if s.is_empty() {
None
} else {
Some(s)
}
});
if (boost - 1.0).abs() < f32::EPSILON {
(None, boost)
} else {
(set, boost)
}
}
/// Sorts `(chunk_id, score)` pairs by descending score, breaking ties by
/// ascending chunk id so the resulting order is deterministic.
///
/// NaN scores are ranked last (compared as `-inf`). Comparing raw NaNs with
/// `partial_cmp(..).unwrap_or(Equal)` is not a total order, and `slice::sort_by`
/// is allowed to panic when its comparator is not a total order.
fn sort_by_score_desc(items: &mut [(String, f32)]) {
    fn rank_key(score: f32) -> f32 {
        if score.is_nan() {
            f32::NEG_INFINITY
        } else {
            score
        }
    }
    items.sort_by(|a, b| {
        rank_key(b.1)
            .partial_cmp(&rank_key(a.1))
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    });
}

impl CodeIndexer {
    /// Returns a copy of the cached embedding for `chunk_id`, if available.
    ///
    /// This lookup never waits: it uses `try_read`, so it returns `None` both
    /// when no embedding is cached for the chunk and when the embedding-cache
    /// lock cannot be acquired immediately. It reads via `peek`, which
    /// presumably leaves the cache's recency order untouched (confirm against
    /// the cache type).
    pub fn get_embedding(&self, chunk_id: &str) -> Option<Vec<f32>> {
        self.chunk_embeddings
            .try_read()
            .ok()
            .and_then(|g| g.peek(chunk_id).cloned())
    }

    /// Embeds arbitrary `text` with the configured embedder, bypassing the
    /// query cache.
    ///
    /// Returns `Ok(None)` when no embedder is configured.
    ///
    /// # Errors
    /// Propagates embedder failures, with the context "embed text".
    pub async fn embed_text(&self, text: &str) -> Result<Option<Vec<f32>>> {
        let Some(embedder) = self.embedder.clone() else {
            return Ok(None);
        };
        let vec = embedder.embed(text).await.context("embed text")?;
        Ok(Some(vec))
    }

    /// Embeds a search query, memoising results in `query_cache` keyed by
    /// `hash_query(query)`.
    ///
    /// Returns `Ok(None)` when no embedder is configured.
    ///
    /// # Errors
    /// Propagates embedder failures, with the context "embed query".
    ///
    /// # Panics
    /// Panics if the `query_cache` mutex is poisoned.
    pub(super) async fn embed_query(&self, query: &str) -> Result<Option<Vec<f32>>> {
        let Some(embedder) = self.embedder.clone() else {
            return Ok(None);
        };
        let key = hash_query(query);
        // The guard is a temporary of this `let` statement, so it is released
        // before the `.await` below.
        let cached = self
            .query_cache
            .lock()
            .expect("query_cache mutex poisoned")
            .get(&key)
            .cloned();
        if let Some(v) = cached {
            return Ok(Some(v));
        }
        let vec = embedder.embed(query).await.context("embed query")?;
        self.query_cache
            .lock()
            .expect("query_cache mutex poisoned")
            .put(key, vec.clone());
        Ok(Some(vec))
    }

    /// Scores `query` against the BM25 index.
    ///
    /// `want` is passed through to `score_query_all`. Returns an empty list
    /// when the index is empty. This currently never returns `Err`.
    async fn bm25_search(&self, query: &str, want: usize) -> Result<Vec<(String, f32)>> {
        let bm25 = self.bm25.read().await;
        if bm25.is_empty() {
            return Ok(Vec::new());
        }
        Ok(bm25.score_query_all(query, want))
    }

    /// Queries the vector store for the `want` nearest chunks to `embedding`,
    /// returning `(chunk_id, score)` pairs in store order.
    ///
    /// Returns an empty list when no vector store is configured.
    ///
    /// # Errors
    /// Propagates vector-store search failures.
    pub(super) async fn vector_search(
        &self,
        embedding: &[f32],
        want: usize,
    ) -> Result<Vec<(String, f32)>> {
        let Some(store) = &self.store else {
            return Ok(Vec::new());
        };
        let hits = store.search(embedding, want).await?;
        Ok(hits.into_iter().map(|h| (h.chunk_id, h.score)).collect())
    }

    /// Returns the symbol-graph edge kinds to follow during knowledge-graph
    /// expansion for a query of the given `intent`.
    fn edge_kinds_for_intent(intent: QueryIntent) -> Vec<EdgeKind> {
        match intent {
            QueryIntent::Definition => {
                vec![EdgeKind::Implements, EdgeKind::Aliases, EdgeKind::UsesType]
            }
            QueryIntent::Usage => vec![
                EdgeKind::CallsFunction,
                EdgeKind::CalledByFunction,
                EdgeKind::TestedBy,
                EdgeKind::CoOccursInTest,
            ],
            QueryIntent::Conceptual => {
                vec![EdgeKind::ReferencesConcept, EdgeKind::Documents]
            }
            QueryIntent::BugDebt => vec![
                EdgeKind::RaisesError,
                EdgeKind::ErrorDescribes,
                EdgeKind::Configures,
            ],
            QueryIntent::Unknown => vec![EdgeKind::CallsFunction, EdgeKind::CalledByFunction],
        }
    }

    /// Expands `seeds` through the symbol graph along the edge kinds chosen
    /// for `intent`, up to `KG_EXPAND_HOPS` hops.
    ///
    /// Each neighbour that is not itself a seed is scored as
    /// `seed_score * edge_kind.score_multiplier()`. When a neighbour is
    /// reachable from several seeds, its maximum score is kept. Seeds with no
    /// symbol in the graph are skipped.
    ///
    /// The result is sorted by descending score, with ties broken by chunk id.
    /// It is empty when the graph or `seeds` is empty.
    async fn kg_expand(&self, seeds: &[(String, f32)], intent: QueryIntent) -> Vec<(String, f32)> {
        let graph = self.symbol_graph().await;
        if graph.node_count() == 0 || seeds.is_empty() {
            return Vec::new();
        }
        let edge_kinds = Self::edge_kinds_for_intent(intent);
        // Seeds are already part of the result set; never re-add them as neighbours.
        let seed_ids: HashSet<&String> = seeds.iter().map(|(id, _)| id).collect();
        let mut best: HashMap<String, f32> = HashMap::new();
        for (seed_id, seed_score) in seeds {
            let Some(symbol) = graph.symbol_for_chunk(seed_id) else {
                continue;
            };
            for (_, neighbour_id, edge_kind) in
                graph.neighbors_by_edge(symbol, &edge_kinds, KG_EXPAND_HOPS)
            {
                if seed_ids.contains(&neighbour_id) {
                    continue;
                }
                let derived = seed_score * edge_kind.score_multiplier();
                best.entry(neighbour_id)
                    .and_modify(|s| {
                        if derived > *s {
                            *s = derived;
                        }
                    })
                    .or_insert(derived);
            }
        }
        let mut out: Vec<(String, f32)> = best.into_iter().collect();
        sort_by_score_desc(&mut out);
        out
    }

    /// Runs a hybrid search for `query` and returns at most `query.top_k` chunks.
    ///
    /// Pipeline:
    /// 1. Classify the query. The intent yields the fusion weights `alpha` and
    ///    `beta` passed to `rrf_fuse`, and whether graph expansion is allowed.
    /// 2. Retrieve up to `top_k * HNSW_OVERSAMPLE` candidates from the vector
    ///    index (only when an embedder is configured) and from BM25.
    /// 3. For Definition/Unknown intents, inject an exact entity match at the
    ///    front of the BM25 list.
    /// 4. Fuse both lists with RRF, then apply the MMR rerank.
    /// 5. Expand through the symbol graph when the intent allows it and
    ///    `query.expand_graph` is set.
    /// 6. Apply the doc-demotion and branch-boost multipliers, then re-sort.
    /// 7. Materialize the best candidates into `CodeChunk`s.
    ///
    /// # Errors
    /// Propagates failures from query embedding and the vector-store search.
    pub async fn search(&self, query: &SearchQuery) -> Result<Vec<CodeChunk>> {
        let intent = QueryClassifier::classify_with_domain(&query.text, &self.domain_terms);
        let (alpha, beta, use_kg_first) = intent.weights();
        tracing::debug!(
            "search index={} query={:?} intent={:?} alpha={} beta={}",
            self.index_id,
            query.text,
            intent,
            alpha,
            beta
        );
        let embedding = self.embed_query(&query.text).await?;
        // Oversample so fusion and reranking have headroom. The `max` keeps
        // `want >= top_k` even if the oversample factor is 0.
        let want = query.top_k.saturating_mul(HNSW_OVERSAMPLE).max(query.top_k);
        let hnsw_results = match &embedding {
            Some(v) => self.vector_search(v, want).await?,
            None => Vec::new(),
        };
        // The two retrievals run sequentially: vector search first, then BM25.
        let mut bm25_results = self.bm25_search(&query.text, want).await?;
        self.inject_entity_exact_match(&intent, &query.text, beta, &mut bm25_results)
            .await;
        let fused_raw = rrf_fuse(
            &hnsw_results,
            &bm25_results,
            alpha,
            beta,
            RRF_K,
            query.top_k,
        );
        let fused = self.apply_mmr_rerank(fused_raw, query.top_k).await;
        let (all, kg_ids) = self
            .expand_with_kg(fused, &intent, use_kg_first, query.expand_graph)
            .await;
        let (branch_set, branch_boost) = resolve_branch_set(query, &self.root_path);
        let all = self
            .apply_score_adjustments(all, &intent, branch_set.as_ref(), branch_boost)
            .await;
        let result = self
            .materialize_search_results(
                all,
                &hnsw_results,
                &bm25_results,
                &kg_ids,
                branch_set.as_ref(),
                query,
            )
            .await;
        Ok(result)
    }

    /// Applies per-chunk score multipliers and re-sorts the candidates by
    /// descending score, with ties broken by chunk id.
    ///
    /// - For `Definition` intent, scores are multiplied by
    ///   `file_type_score_multiplier(file)`.
    /// - When `branch_files` is given, chunks whose normalized file path is in
    ///   the set are multiplied by `branch_boost`.
    ///
    /// Candidates not found in the chunk corpus keep their score unchanged.
    async fn apply_score_adjustments(
        &self,
        candidates: Vec<(String, f32)>,
        intent: &QueryIntent,
        branch_files: Option<&HashSet<String>>,
        branch_boost: f32,
    ) -> Vec<(String, f32)> {
        let demote_docs = matches!(intent, QueryIntent::Definition);
        let chunks = self.chunks.read().await;
        let mut adjusted: Vec<(String, f32)> = candidates
            .into_iter()
            .map(|(id, score)| {
                let mut multiplier = 1.0_f32;
                let raw = chunks.get(&id);
                if demote_docs {
                    if let Some(r) = raw {
                        multiplier *= file_type_score_multiplier(&r.file);
                    }
                }
                if let (Some(set), Some(r)) = (branch_files, raw) {
                    if set.contains(normalize_path(&r.file)) {
                        multiplier *= branch_boost;
                    }
                }
                (id, score * multiplier)
            })
            .collect();
        sort_by_score_desc(&mut adjusted);
        adjusted
    }

    /// For `Definition` and `Unknown` intents, promotes an exact entity match
    /// for `query_text` to the front of `bm25_results`.
    ///
    /// The promoted entry gets a fixed score of `beta * 1.5`. Any existing
    /// entry for the same chunk id is removed first, so the id appears only
    /// once. Does nothing for other intents or when there is no exact match.
    async fn inject_entity_exact_match(
        &self,
        intent: &QueryIntent,
        query_text: &str,
        beta: f32,
        bm25_results: &mut Vec<(String, f32)>,
    ) {
        if !matches!(intent, QueryIntent::Definition | QueryIntent::Unknown) {
            return;
        }
        let Some(hit) = self.entity_exact_match(query_text).await else {
            return;
        };
        let injected_score = beta * 1.5;
        bm25_results.retain(|(id, _)| id != &hit);
        bm25_results.insert(0, (hit, injected_score));
    }

    /// Reranks the fused candidates with MMR, using cached chunk embeddings.
    ///
    /// Embeddings for the candidates are copied into a local snapshot, and the
    /// embedding-cache lock is released before `mmr_rerank` runs. Candidates
    /// without a cached embedding are simply absent from the snapshot. When the
    /// cache is empty, `fused_raw` is returned unchanged.
    async fn apply_mmr_rerank(
        &self,
        fused_raw: Vec<(String, f32)>,
        top_k: usize,
    ) -> Vec<(String, f32)> {
        let emb_map = self.chunk_embeddings.read().await;
        if emb_map.is_empty() {
            return fused_raw;
        }
        let snapshot: HashMap<String, Vec<f32>> = fused_raw
            .iter()
            .filter_map(|(id, _)| emb_map.peek(id).map(|v| (id.clone(), v.clone())))
            .collect();
        drop(emb_map);
        crate::core::mmr::mmr_rerank(
            fused_raw,
            &snapshot,
            crate::core::mmr::DEFAULT_LAMBDA,
            top_k,
        )
    }

    /// Appends knowledge-graph neighbours of the fused results when both
    /// `use_kg_first` (from the intent) and `expand_graph` (from the query)
    /// are set.
    ///
    /// Returns the combined candidate list (fused results first, then the
    /// expansions) and the set of chunk ids that were added by expansion.
    async fn expand_with_kg(
        &self,
        fused: Vec<(String, f32)>,
        intent: &QueryIntent,
        use_kg_first: bool,
        expand_graph: bool,
    ) -> (Vec<(String, f32)>, HashSet<String>) {
        if !(use_kg_first && expand_graph) {
            return (fused, HashSet::new());
        }
        let expanded = self.kg_expand(&fused, intent.clone()).await;
        let kg_ids: HashSet<String> = expanded.iter().map(|(id, _)| id.clone()).collect();
        let mut all = fused;
        all.extend(expanded);
        (all, kg_ids)
    }

    /// Converts the ranked candidates `all` into at most `query.top_k`
    /// `CodeChunk`s, preserving their order.
    ///
    /// Ids missing from the chunk corpus (e.g. removed concurrently) are
    /// skipped. Later candidates then fill their place, so up to `top_k` valid
    /// results are returned whenever enough exist.
    ///
    /// For each chunk:
    /// - the match reason records whether the id appeared in the HNSW results,
    ///   the BM25 results, and/or the KG expansion;
    /// - a compact snippet is built when `query.compact` is set;
    /// - `on_branch` is set only when a branch file set is supplied.
    async fn materialize_search_results(
        &self,
        all: Vec<(String, f32)>,
        hnsw_results: &[(String, f32)],
        bm25_results: &[(String, f32)],
        kg_ids: &HashSet<String>,
        branch_files: Option<&HashSet<String>>,
        query: &SearchQuery,
    ) -> Vec<CodeChunk> {
        let in_hnsw: HashSet<&String> = hnsw_results.iter().map(|(id, _)| id).collect();
        let in_bm25: HashSet<&String> = bm25_results.iter().map(|(id, _)| id).collect();
        let chunks = self.chunks.read().await;
        let mut out = Vec::with_capacity(all.len().min(query.top_k));
        for (id, score) in all {
            if out.len() >= query.top_k {
                break;
            }
            let Some(raw) = chunks.get(&id) else {
                tracing::trace!("fused id {id} not in corpus — likely race; skipping");
                continue;
            };
            let match_reason = compute_match_reason(
                in_hnsw.contains(&id),
                in_bm25.contains(&id),
                kg_ids.contains(&id),
            );
            let snippet = if query.compact {
                Some(build_compact_snippet(&raw.content))
            } else {
                None
            };
            let mut chunk = raw_to_code_chunk(raw, score, match_reason, snippet);
            if let Some(set) = branch_files {
                chunk.on_branch = set.contains(normalize_path(&raw.file));
            }
            out.push(chunk);
        }
        out
    }
}