use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use anyhow::{Context, Result};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::core::bm25::Bm25Index;
use crate::core::chunker::{chunk_ast, ChunkType, RawChunk};
use crate::core::classifier::{QueryClassifier, QueryIntent};
use crate::core::embed::Embedder;
use crate::core::entity::{EdgeKind, EntityType, RawEntity};
use crate::core::search::rrf::{rrf_fuse, RRF_K};
use crate::core::store::VectorStore;
use crate::core::symbol_graph::{ChunkTuple, SymbolGraph};
/// Capacity of the LRU cache mapping query-text hashes to query embeddings.
const QUERY_CACHE_CAPACITY: usize = 256;
/// Multiplier applied to `top_k` to decide how many candidates each retriever
/// (vector search and BM25) returns before rank fusion.
const HNSW_OVERSAMPLE: usize = 4;
/// Fallback capacity of the per-chunk embedding LRU cache.
const DEFAULT_EMBEDDING_CACHE_CAP: usize = 1_000;
/// Capacity of the per-chunk embedding LRU cache.
///
/// Read from `TRUSTY_EMBEDDING_CACHE`. Unset, non-UTF-8, unparsable or zero
/// values fall back to [`DEFAULT_EMBEDDING_CACHE_CAP`], so the result is never 0.
fn embedding_cache_cap() -> usize {
    match std::env::var("TRUSTY_EMBEDDING_CACHE").map(|raw| raw.parse::<usize>()) {
        Ok(Ok(n)) if n > 0 => n,
        _ => DEFAULT_EMBEDDING_CACHE_CAP,
    }
}
/// Fallback upper bound on the number of chunks stored per index.
const DEFAULT_MAX_CHUNKS_PER_INDEX: usize = 200_000;
/// Maximum number of distinct chunks a single index will hold.
///
/// Read from `TRUSTY_MAX_CHUNKS` on every call. Unset, unparsable or zero
/// values fall back to [`DEFAULT_MAX_CHUNKS_PER_INDEX`].
fn max_chunks_per_index() -> usize {
    if let Some(n) = std::env::var("TRUSTY_MAX_CHUNKS")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
    {
        if n != 0 {
            return n;
        }
    }
    DEFAULT_MAX_CHUNKS_PER_INDEX
}
/// Batch size used when `TRUSTY_MAX_BATCH_SIZE` is unset or invalid.
const DEFAULT_EMBED_BATCH_SIZE: usize = 128;
/// Smallest batch size an override is clamped to.
const EMBED_BATCH_MIN: usize = 32;
/// Largest batch size an override is clamped to.
const EMBED_BATCH_MAX: usize = 2048;
/// Number of chunk texts sent to the embedder per `embed_batch` call.
///
/// A positive `TRUSTY_MAX_BATCH_SIZE` is clamped into
/// `[EMBED_BATCH_MIN, EMBED_BATCH_MAX]`; anything else (unset, unparsable,
/// zero) yields [`DEFAULT_EMBED_BATCH_SIZE`]. The result is never 0.
fn embed_batch_size() -> usize {
    let requested = std::env::var("TRUSTY_MAX_BATCH_SIZE")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());
    match requested {
        Some(n) if n > 0 => n.clamp(EMBED_BATCH_MIN, EMBED_BATCH_MAX),
        _ => DEFAULT_EMBED_BATCH_SIZE,
    }
}
/// Reference score ratio between a KG-expanded neighbour and its seed chunk.
///
/// Production scoring in `kg_expand` uses `EdgeKind::score_multiplier()`. Within
/// this module the constant is referenced only by the test suite, so the
/// dead-code lint is silenced for non-test builds only.
#[cfg_attr(not(test), allow(dead_code))]
const KG_EXPAND_SCORE_FACTOR: f32 = 0.7;
/// Number of graph hops followed from each seed chunk during KG expansion.
const KG_EXPAND_HOPS: usize = 1;
/// One indexed chunk as returned to callers (search hits, similarity hits,
/// full listings), built from a `RawChunk` by `raw_to_code_chunk`.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` value when
/// absent during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
/// Chunk identifier, copied from `RawChunk::id`.
pub id: String,
/// File the chunk was extracted from.
pub file: String,
/// Source language, when the chunker recorded one.
#[serde(default)]
pub language: Option<String>,
/// First line of the chunk.
pub start_line: usize,
/// Last line of the chunk.
pub end_line: usize,
/// Full chunk text.
pub content: String,
/// Enclosing function name, when the chunker recorded one.
pub function_name: Option<String>,
/// Retrieval score; `0.0` for plain listings (`all_chunks`, `enumerate_chunks`).
pub score: f32,
/// First seven lines of `content` (all of it when shorter) when a compact
/// view was produced; `None` otherwise.
pub compact_snippet: Option<String>,
/// Why the chunk was returned: "hybrid", "vector", "bm25", "hybrid+kg",
/// "fallback" for search, or "all"/"enumerate" for listings.
pub match_reason: String,
/// Kind of syntactic unit the chunk represents.
#[serde(default)]
pub chunk_type: ChunkType,
/// Names the chunk calls, copied from the raw chunk.
#[serde(default)]
pub calls: Vec<String>,
/// Names the chunk inherits from, copied from the raw chunk.
#[serde(default)]
pub inherits_from: Vec<String>,
/// Nesting depth, saturated to `u8::MAX`.
#[serde(default)]
pub chunk_depth: u8,
/// Always `None` when produced in this module; omitted from serialized
/// output while `None`. NOTE(review): presumably set by callers that merge
/// results across indexes — confirm.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub index_id: Option<String>,
}
/// A search request handled by `CodeIndexer::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQuery {
/// Free-text query; also classified into a `QueryIntent`.
pub text: String,
/// Maximum number of results (defaults to 10 when deserialized without it).
#[serde(default = "default_top_k")]
pub top_k: usize,
/// Allow symbol-graph expansion for intents that request it (default `true`).
#[serde(default = "default_true")]
pub expand_graph: bool,
/// Populate `CodeChunk::compact_snippet` on results (default `true`).
#[serde(default = "default_true")]
pub compact: bool,
}
/// Serde default for [`SearchQuery::top_k`].
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
/// Serde default for opt-out boolean flags on [`SearchQuery`].
fn default_true() -> bool {
    let enabled = true;
    enabled
}
/// Hashes a query string into the key used by the query-embedding cache.
///
/// Uses `DefaultHasher`, which is deterministic for the lifetime of a process.
/// NOTE(review): distinct queries that collide on 64 bits would share a cache entry.
fn hash_query(query: &str) -> u64 {
    let mut hasher = DefaultHasher::default();
    Hash::hash(query, &mut hasher);
    hasher.finish()
}
/// Returns a short preview of `content`.
///
/// Text with at most seven lines is returned verbatim (including any trailing
/// newline or CRLF endings). Longer text is cut to its first seven lines,
/// re-joined with `\n`.
fn build_compact_snippet(content: &str) -> String {
    const PREVIEW_LINES: usize = 7;
    let mut lines = content.lines();
    let head: Vec<&str> = lines.by_ref().take(PREVIEW_LINES).collect();
    match lines.next() {
        None => content.to_string(),
        Some(_) => head.join("\n"),
    }
}
/// Converts an internal `RawChunk` into the public [`CodeChunk`] shape.
///
/// `score`, `match_reason` and `compact_snippet` are supplied by the caller;
/// `index_id` is always left as `None`. A nesting depth above 255 is saturated
/// to `u8::MAX` rather than wrapped.
fn raw_to_code_chunk(
    raw: &RawChunk,
    score: f32,
    match_reason: &str,
    compact_snippet: Option<String>,
) -> CodeChunk {
    let chunk_depth = if raw.chunk_depth > u8::MAX as usize {
        u8::MAX
    } else {
        raw.chunk_depth as u8
    };
    CodeChunk {
        id: raw.id.clone(),
        file: raw.file.clone(),
        language: raw.language.clone(),
        start_line: raw.start_line,
        end_line: raw.end_line,
        content: raw.content.clone(),
        function_name: raw.function_name.clone(),
        chunk_type: raw.chunk_type.clone(),
        calls: raw.calls.clone(),
        inherits_from: raw.inherits_from.clone(),
        match_reason: match_reason.to_owned(),
        score,
        compact_snippet,
        chunk_depth,
        index_id: None,
    }
}
/// Sets each chunk's `virtual_terms` to the texts of the entities whose line
/// falls inside the chunk's `[start_line, end_line]` span.
///
/// Duplicates are removed while keeping first-seen order. Any previous
/// `virtual_terms` are overwritten. Cost is O(chunks × entities).
fn populate_virtual_terms(chunks: &mut [RawChunk], entities: &[RawEntity]) {
    for chunk in chunks.iter_mut() {
        let span = chunk.start_line..=chunk.end_line;
        let mut seen = std::collections::HashSet::new();
        chunk.virtual_terms = entities
            .iter()
            .filter(|ent| span.contains(&ent.line))
            .filter(|ent| seen.insert(ent.text.as_str()))
            .map(|ent| ent.text.clone())
            .collect();
    }
}
/// Labels a search hit by the retrievers that produced it.
///
/// Vector/BM25 membership takes precedence. The graph flag only matters when
/// neither retriever returned the chunk.
fn compute_match_reason(in_v: bool, in_b: bool, in_kg: bool) -> &'static str {
    if in_v && in_b {
        return "hybrid";
    }
    if in_v {
        return "vector";
    }
    if in_b {
        return "bm25";
    }
    if in_kg {
        "hybrid+kg"
    } else {
        "fallback"
    }
}
/// Output of `CodeIndexer::parse_and_embed_files`, consumed by
/// `CodeIndexer::commit_parsed_batch`. Nothing in it has been committed yet.
#[derive(Default)]
pub struct ParsedBatch {
/// Chunks from every parsed file, concatenated in input-file order.
pub chunks: Vec<RawChunk>,
/// One entry per element of `chunks`; all `None` when the indexer has no
/// embedder or no vector store.
pub embeddings: Vec<Option<Vec<f32>>>,
/// `(file path, entities)` for every parsed file.
pub entities_by_file: Vec<(String, Vec<RawEntity>)>,
/// Wall-clock milliseconds spent parsing.
pub parse_ms: u64,
/// Wall-clock milliseconds spent embedding.
pub embed_ms: u64,
/// Number of `Some` entries in `embeddings`.
pub vector_count: usize,
}
/// Statistics reported by `CodeIndexer::commit_parsed_batch`.
#[derive(Debug, Default, Clone, Copy)]
pub struct CommitTimings {
/// Number of chunks processed by the commit.
pub chunks: usize,
/// Wall-clock milliseconds spent updating the BM25 index.
pub bm25_ms: u64,
/// Wall-clock milliseconds spent upserting vectors into the store.
pub vector_upsert_ms: u64,
/// Wall-clock milliseconds spent rebuilding the symbol graph (0 when deferred).
pub kg_ms: u64,
}
/// In-memory hybrid code index: BM25 lexical search, optional vector search,
/// and a symbol graph used for query expansion.
pub struct CodeIndexer {
/// Index name, used in log messages.
pub index_id: String,
/// Root path supplied to `CodeIndexer::new`.
pub root_path: std::path::PathBuf,
/// Optional embedder; without it (or without `store`) vector features are off.
embedder: Option<Arc<dyn Embedder>>,
/// Optional vector store backing similarity search.
store: Option<Arc<dyn VectorStore>>,
/// Corpus of chunks keyed by chunk id.
chunks: Arc<RwLock<HashMap<String, RawChunk>>>,
/// Entities extracted per file, keyed by file path.
entities: Arc<RwLock<HashMap<String, Vec<RawEntity>>>>,
/// Bounded LRU of chunk embeddings (capacity from `embedding_cache_cap`),
/// read during MMR reranking.
chunk_embeddings: Arc<RwLock<LruCache<String, Vec<f32>>>>,
/// Lexical BM25 index over chunk text plus virtual terms.
bm25: Arc<RwLock<Bm25Index>>,
/// Query-embedding LRU keyed by `hash_query`. Uses a std `Mutex`, which this
/// module never holds across an `.await`.
query_cache: Arc<Mutex<LruCache<u64, Vec<f32>>>>,
/// Current symbol graph; replaced wholesale on every rebuild.
symbol_graph: Arc<RwLock<Arc<SymbolGraph>>>,
/// Extractor applied to doc comments during `index_file`.
ner: crate::core::ner::NerExtractor,
}
impl CodeIndexer {
/// Creates an empty index with no embedder or vector store attached.
///
/// Without components the index only answers BM25 queries; use
/// [`CodeIndexer::with_components`] to enable vector search. Cache sizes come
/// from `QUERY_CACHE_CAPACITY` and `embedding_cache_cap()`.
pub fn new(index_id: impl Into<String>, root_path: impl Into<std::path::PathBuf>) -> Self {
    let query_cache_cap =
        NonZeroUsize::new(QUERY_CACHE_CAPACITY).expect("QUERY_CACHE_CAPACITY must be non-zero");
    // `embedding_cache_cap` filters out zero, so this cannot fail.
    let embedding_cap = NonZeroUsize::new(embedding_cache_cap())
        .expect("embedding_cache_cap must be non-zero (env var filtered)");
    let chunks = Arc::new(RwLock::new(HashMap::new()));
    let entities = Arc::new(RwLock::new(HashMap::new()));
    let chunk_embeddings = Arc::new(RwLock::new(LruCache::new(embedding_cap)));
    let bm25 = Arc::new(RwLock::new(Bm25Index::new()));
    let query_cache = Arc::new(Mutex::new(LruCache::new(query_cache_cap)));
    let symbol_graph = Arc::new(RwLock::new(Arc::new(SymbolGraph::new())));
    Self {
        index_id: index_id.into(),
        root_path: root_path.into(),
        embedder: None,
        store: None,
        chunks,
        entities,
        chunk_embeddings,
        bm25,
        query_cache,
        symbol_graph,
        ner: crate::core::ner::NerExtractor::try_load(),
    }
}
/// Returns a shared handle to the current symbol graph.
///
/// The handle keeps referring to this snapshot even if a later rebuild swaps
/// in a new graph.
pub async fn symbol_graph(&self) -> Arc<SymbolGraph> {
    let current = self.symbol_graph.read().await;
    Arc::clone(&*current)
}
/// Rebuilds the symbol graph from the whole corpus and swaps it in.
///
/// The corpus read lock is released before the (potentially expensive) build,
/// and the graph lock is held only for the pointer swap.
async fn rebuild_symbol_graph(&self) {
    let tuples: Vec<ChunkTuple> = {
        let corpus = self.chunks.read().await;
        corpus
            .values()
            .map(|chunk| {
                (
                    chunk.id.clone(),
                    chunk.file.clone(),
                    chunk.function_name.clone(),
                    chunk.calls.clone(),
                    chunk.inherits_from.clone(),
                    chunk.chunk_type.clone(),
                )
            })
            .collect()
    };
    let rebuilt = Arc::new(SymbolGraph::build_from_chunks(&tuples));
    let mut slot = self.symbol_graph.write().await;
    *slot = rebuilt;
}
/// Builder-style setter that attaches an embedder and a vector store,
/// enabling chunk embedding and vector search.
pub fn with_components(
    mut self,
    embedder: Arc<dyn Embedder>,
    store: Arc<dyn VectorStore>,
) -> Self {
    self.store = Some(store);
    self.embedder = Some(embedder);
    self
}
/// Returns the cached embedding for `chunk_id`, if present.
///
/// Non-blocking: uses `try_read`, so it returns `None` while a writer holds the
/// cache lock, even if the embedding exists. Uses `peek`, so LRU recency is not
/// updated.
pub fn get_embedding(&self, chunk_id: &str) -> Option<Vec<f32>> {
    let cache = self.chunk_embeddings.try_read().ok()?;
    cache.peek(chunk_id).cloned()
}
pub async fn find_chunk_id(&self, file_suffix: &str, function: Option<&str>) -> Option<String> {
let chunks = self.chunks.read().await;
let matching: Vec<&RawChunk> = chunks
.values()
.filter(|c| c.file.ends_with(file_suffix))
.filter(|c| match function {
Some(f) => c.function_name.as_deref() == Some(f),
None => true,
})
.collect();
matching
.into_iter()
.min_by_key(|c| c.start_line)
.map(|c| c.id.clone())
}
/// Returns up to `top_k` chunks whose vectors are nearest to `embedding`.
///
/// The hit whose id equals `exclude_id` (typically the query chunk itself) is
/// skipped. One extra candidate is requested so that excluding it can still
/// leave `top_k` results. Hits missing from the corpus are skipped. Returns
/// an empty list when `top_k == 0` or no vector store is configured.
///
/// # Errors
/// Propagates vector-store search failures.
pub async fn similar_by_embedding(
    &self,
    embedding: &[f32],
    top_k: usize,
    exclude_id: Option<&str>,
) -> Result<Vec<CodeChunk>> {
    // Without this guard the loop below would push one hit before checking
    // the length, returning a result for `top_k == 0`.
    if top_k == 0 {
        return Ok(Vec::new());
    }
    let want = top_k.saturating_add(1);
    let hits = self.vector_search(embedding, want).await?;
    let chunks = self.chunks.read().await;
    let mut out = Vec::with_capacity(top_k);
    for (id, score) in hits {
        if Some(id.as_str()) == exclude_id {
            continue;
        }
        let Some(raw) = chunks.get(&id) else { continue };
        let snippet = Some(build_compact_snippet(&raw.content));
        out.push(raw_to_code_chunk(raw, score, "vector", snippet));
        if out.len() >= top_k {
            break;
        }
    }
    Ok(out)
}
/// Returns every chunk in the corpus (unordered) with score `0.0`,
/// match reason "all" and no compact snippet.
pub async fn all_chunks(&self) -> Vec<CodeChunk> {
    let corpus = self.chunks.read().await;
    let mut listing = Vec::with_capacity(corpus.len());
    for raw in corpus.values() {
        listing.push(raw_to_code_chunk(raw, 0.0, "all", None));
    }
    listing
}
/// Returns `(total_chunk_count, page)` for stable pagination over the corpus.
///
/// Chunks are ordered by `(file, start_line, end_line, id)`. The id tie-break
/// makes the order total, so chunks sharing a span cannot swap between pages.
/// `offset + limit` saturates instead of overflowing, so `limit = usize::MAX`
/// means "everything from `offset`". An empty page is returned when
/// `limit == 0` or `offset >= total`.
pub async fn enumerate_chunks(&self, offset: usize, limit: usize) -> (usize, Vec<CodeChunk>) {
    let chunks = self.chunks.read().await;
    let total = chunks.len();
    if limit == 0 || offset >= total {
        return (total, Vec::new());
    }
    let mut ordered: Vec<&RawChunk> = chunks.values().collect();
    ordered.sort_unstable_by(|a, b| {
        a.file
            .cmp(&b.file)
            .then(a.start_line.cmp(&b.start_line))
            .then(a.end_line.cmp(&b.end_line))
            .then_with(|| a.id.cmp(&b.id))
    });
    let end = offset.saturating_add(limit).min(total);
    let page: Vec<CodeChunk> = ordered[offset..end]
        .iter()
        .map(|raw| raw_to_code_chunk(raw, 0.0, "enumerate", None))
        .collect();
    (total, page)
}
/// Number of chunks in the corpus.
///
/// Non-blocking: returns 0 when the corpus lock is currently write-held, so
/// treat the value as a best-effort snapshot.
pub fn chunk_count(&self) -> usize {
    match self.chunks.try_read() {
        Ok(corpus) => corpus.len(),
        Err(_) => 0,
    }
}
/// Text indexed by BM25 for a chunk: its content followed by each virtual
/// term, all separated by single spaces.
fn bm25_doc_text(chunk: &RawChunk) -> String {
    if chunk.virtual_terms.is_empty() {
        return chunk.content.clone();
    }
    let mut parts: Vec<&str> = Vec::with_capacity(chunk.virtual_terms.len() + 1);
    parts.push(&chunk.content);
    parts.extend(chunk.virtual_terms.iter().map(String::as_str));
    parts.join(" ")
}
/// Embeds, indexes and stores a single chunk, then rebuilds the symbol graph.
///
/// If the index already holds `max_chunks_per_index()` chunks and `chunk.id`
/// is new, the chunk is skipped with a warning and `Ok(())` is returned.
/// Embedding and vector upsert only happen when both an embedder and a store
/// are configured; BM25 and the corpus are always updated.
///
/// # Errors
/// Propagates embedder or vector-store upsert failures.
pub async fn add_chunk(&self, chunk: RawChunk) -> Result<()> {
    if self.insert_chunk_without_graph(chunk).await? {
        self.rebuild_symbol_graph().await;
    }
    Ok(())
}

/// Performs all of `add_chunk`'s work except the symbol-graph rebuild.
///
/// Returns `Ok(true)` when the chunk was stored and `Ok(false)` when it was
/// skipped because of the chunk cap. Multi-chunk callers use this and rebuild
/// the graph once instead of once per chunk.
///
/// NOTE(review): the cap check and the final insert take separate locks, so
/// concurrent writers can overshoot the cap slightly.
async fn insert_chunk_without_graph(&self, chunk: RawChunk) -> Result<bool> {
    let id = chunk.id.clone();
    {
        let chunks = self.chunks.read().await;
        let cap = max_chunks_per_index();
        if !chunks.contains_key(&id) && chunks.len() >= cap {
            tracing::warn!(
                "index '{}' chunk cap ({}) reached — skipping chunk {}",
                self.index_id,
                cap,
                id
            );
            return Ok(false);
        }
    }
    if let (Some(embedder), Some(store)) = (&self.embedder, &self.store) {
        let vec = embedder
            .embed(&chunk.content)
            .await
            .context("embed chunk content")?;
        store
            .upsert(&id, vec.clone())
            .await
            .context("upsert chunk vector")?;
        self.chunk_embeddings.write().await.put(id.clone(), vec);
    }
    let bm25_text = Self::bm25_doc_text(&chunk);
    self.bm25.write().await.upsert_document(&id, &bm25_text);
    self.chunks.write().await.insert(id, chunk);
    Ok(true)
}

/// Parses `content` as `file_path`, indexes every resulting chunk, records the
/// file's entities (AST entities plus NER / concept-cluster enrichment), and
/// rebuilds the symbol graph once at the end.
///
/// Chunks are inserted without per-chunk graph rebuilds. The original per-chunk
/// rebuild made indexing a k-chunk file cost k+1 full-graph rebuilds.
///
/// # Errors
/// Propagates the first chunk-insertion failure. The graph is still rebuilt
/// so it matches the chunks committed before the failure; entities are not
/// recorded in that case.
pub async fn index_file(&self, file_path: &str, content: &str) -> Result<()> {
    let (mut chunks, entities) = chunk_ast(file_path, content);
    populate_virtual_terms(&mut chunks, &entities);
    let chunk_contents: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
    for chunk in chunks {
        if let Err(err) = self.insert_chunk_without_graph(chunk).await {
            self.rebuild_symbol_graph().await;
            return Err(err);
        }
    }
    let all_entities = self
        .enrich_with_nlp_entities(file_path, content, &chunk_contents, entities)
        .await;
    self.entities
        .write()
        .await
        .insert(file_path.to_string(), all_entities);
    self.rebuild_symbol_graph().await;
    Ok(())
}
/// Extends `base_entities` with NER phrases from the file's doc comments and,
/// when an embedder is configured, concept-cluster entities computed over
/// `chunk_contents`.
///
/// Returns the combined list in order: base entities, NER phrases, clusters.
async fn enrich_with_nlp_entities(
    &self,
    file_path: &str,
    content: &str,
    chunk_contents: &[String],
    base_entities: Vec<RawEntity>,
) -> Vec<RawEntity> {
    let doc_comments = crate::core::ner::extract_doc_comments(content);
    let phrases = self.ner.extract(&doc_comments, file_path);
    if !phrases.is_empty() {
        tracing::debug!(
            "ner: {} NaturalLanguagePhrase entities for {}",
            phrases.len(),
            file_path
        );
    }
    let mut merged = base_entities;
    merged.extend(phrases);

    let Some(embedder) = &self.embedder else {
        return merged;
    };
    let texts: Vec<&str> = chunk_contents.iter().map(String::as_str).collect();
    let clusters = crate::core::concept_cluster::cluster_concepts_from_contents(
        &texts,
        embedder.as_ref(),
        file_path,
    )
    .await;
    if !clusters.is_empty() {
        tracing::debug!(
            "concept_cluster: {} ConceptCluster entities for {}",
            clusters.len(),
            file_path
        );
        merged.extend(clusters);
    }
    merged
}
/// Parses, embeds and commits `files` (`(path, content)` pairs) in one batch,
/// rebuilding the symbol graph afterwards. Returns the commit's chunk count.
pub async fn index_files_batch(&self, files: &[(String, String)]) -> Result<usize> {
    let defer_graph_rebuild = false;
    self.index_files_batch_inner(files, defer_graph_rebuild).await
}
/// Like [`CodeIndexer::index_files_batch`] but skips the symbol-graph rebuild;
/// call [`CodeIndexer::rebuild_symbol_graph_now`] after the last batch.
pub async fn index_files_batch_no_rebuild(&self, files: &[(String, String)]) -> Result<usize> {
    let defer_graph_rebuild = true;
    self.index_files_batch_inner(files, defer_graph_rebuild).await
}
/// Rebuilds the symbol graph from the current corpus.
pub async fn rebuild_symbol_graph_now(&self) {
    self.rebuild_symbol_graph().await
}
/// Shared implementation of the batch indexing entry points.
///
/// Returns 0 immediately for an empty input. Otherwise the files are copied
/// (the parse stage needs owned data), parsed and embedded, then committed.
async fn index_files_batch_inner(
    &self,
    files: &[(String, String)],
    defer_graph_rebuild: bool,
) -> Result<usize> {
    if files.is_empty() {
        return Ok(0);
    }
    let owned: Vec<(String, String)> = files.iter().cloned().collect();
    let batch = self.parse_and_embed_files(owned).await?;
    self.commit_parsed_batch(batch, defer_graph_rebuild)
        .await
        .map(|timings| timings.chunks)
}
/// Parses `files` in parallel and embeds every resulting chunk, without
/// committing anything to the index.
///
/// Pass the returned [`ParsedBatch`] to [`CodeIndexer::commit_parsed_batch`].
///
/// # Errors
/// Fails if the parse task panics or batch embedding fails.
pub async fn parse_and_embed_files(&self, files: Vec<(String, String)>) -> Result<ParsedBatch> {
    if files.is_empty() {
        return Ok(ParsedBatch::default());
    }
    let parse_clock = std::time::Instant::now();
    let per_file = Self::parse_files_parallel(files).await?;
    let mut entities_by_file: Vec<(String, Vec<RawEntity>)> = Vec::with_capacity(per_file.len());
    let mut chunks: Vec<RawChunk> = Vec::new();
    for (path, file_chunks, file_entities) in per_file {
        chunks.extend(file_chunks);
        entities_by_file.push((path, file_entities));
    }
    let parse_ms = parse_clock.elapsed().as_millis() as u64;

    let embed_clock = std::time::Instant::now();
    let embeddings = self.embed_chunks_in_batches(&chunks).await?;
    let embed_ms = embed_clock.elapsed().as_millis() as u64;

    let vector_count = embeddings.iter().flatten().count();
    Ok(ParsedBatch {
        chunks,
        embeddings,
        entities_by_file,
        parse_ms,
        embed_ms,
        vector_count,
    })
}
async fn parse_files_parallel(
files: Vec<(String, String)>,
) -> Result<Vec<(String, Vec<RawChunk>, Vec<RawEntity>)>> {
use rayon::prelude::*;
tokio::task::spawn_blocking(move || {
files
.par_iter()
.map(|(path, content)| {
let (mut chunks, entities) = chunk_ast(path, content);
populate_virtual_terms(&mut chunks, &entities);
(path.clone(), chunks, entities)
})
.collect()
})
.await
.context("batch parse task panicked")
}
/// Embeds chunk contents in batches of `embed_batch_size()`.
///
/// Returns one entry per chunk, in order. Every entry is `None` when either
/// the embedder or the vector store is missing.
///
/// # Errors
/// Fails if an `embed_batch` call errors or returns a different number of
/// vectors than texts sent.
async fn embed_chunks_in_batches(&self, chunks: &[RawChunk]) -> Result<Vec<Option<Vec<f32>>>> {
    let embedder = match (&self.embedder, &self.store) {
        (Some(embedder), Some(_)) => embedder,
        _ => return Ok(vec![None; chunks.len()]),
    };
    let mut embeddings: Vec<Option<Vec<f32>>> = Vec::with_capacity(chunks.len());
    for group in chunks.chunks(embed_batch_size()) {
        let texts: Vec<&str> = group.iter().map(|c| c.content.as_str()).collect();
        let vectors = embedder
            .embed_batch(&texts)
            .await
            .context("batch embed_batch failed")?;
        if vectors.len() != texts.len() {
            anyhow::bail!(
                "embed_batch returned {} vectors, expected {}",
                vectors.len(),
                texts.len()
            );
        }
        embeddings.extend(vectors.into_iter().map(Some));
    }
    Ok(embeddings)
}
pub async fn commit_parsed_batch(
&self,
parsed: ParsedBatch,
defer_graph_rebuild: bool,
) -> Result<CommitTimings> {
let ParsedBatch {
chunks: mut all_chunks,
embeddings,
entities_by_file,
parse_ms: _,
embed_ms: _,
vector_count: _,
} = parsed;
let chunk_total = all_chunks.len();
if chunk_total == 0 {
self.commit_entities(entities_by_file).await;
return Ok(CommitTimings::default());
}
let vec_start = std::time::Instant::now();
self.commit_vectors_batch(&all_chunks, &embeddings).await?;
let vector_upsert_ms = vec_start.elapsed().as_millis() as u64;
let bm25_start = std::time::Instant::now();
self.commit_bm25_batch(&all_chunks).await;
let bm25_ms = bm25_start.elapsed().as_millis() as u64;
self.commit_embeddings_cache(&all_chunks, embeddings).await;
self.commit_corpus(&mut all_chunks).await;
self.commit_entities(entities_by_file).await;
let kg_ms = if defer_graph_rebuild {
0
} else {
let kg_start = std::time::Instant::now();
self.rebuild_symbol_graph().await;
kg_start.elapsed().as_millis() as u64
};
Ok(CommitTimings {
chunks: chunk_total,
bm25_ms,
vector_upsert_ms,
kg_ms,
})
}
/// Upserts every chunk that has an embedding into the vector store in one call.
///
/// No-op when no store is configured or no chunk has an embedding. `chunks`
/// and `embeddings` are paired positionally.
///
/// # Errors
/// Propagates the store's batch-upsert failure.
async fn commit_vectors_batch(
    &self,
    chunks: &[RawChunk],
    embeddings: &[Option<Vec<f32>>],
) -> Result<()> {
    let Some(store) = &self.store else {
        return Ok(());
    };
    let mut items: Vec<(String, Vec<f32>)> = Vec::new();
    for (chunk, maybe_vec) in chunks.iter().zip(embeddings) {
        if let Some(vector) = maybe_vec {
            items.push((chunk.id.clone(), vector.clone()));
        }
    }
    if items.is_empty() {
        return Ok(());
    }
    store
        .upsert_batch(&items)
        .await
        .context("batch upsert chunk vectors")
}
/// Inserts or replaces each chunk's BM25 document under a single write lock.
async fn commit_bm25_batch(&self, chunks: &[RawChunk]) {
    let mut index = self.bm25.write().await;
    chunks
        .iter()
        .for_each(|chunk| index.upsert_document(&chunk.id, &Self::bm25_doc_text(chunk)));
}
/// Stores each available embedding in the chunk-embedding LRU, keyed by chunk id.
///
/// Skipped entirely when no embedder is configured.
async fn commit_embeddings_cache(
    &self,
    chunks: &[RawChunk],
    embeddings: Vec<Option<Vec<f32>>>,
) {
    if self.embedder.is_none() {
        return;
    }
    let mut cache = self.chunk_embeddings.write().await;
    let present = chunks
        .iter()
        .zip(embeddings)
        .filter_map(|(chunk, maybe_vec)| maybe_vec.map(|vector| (chunk, vector)));
    for (chunk, vector) in present {
        cache.put(chunk.id.clone(), vector);
    }
}
/// Moves chunks out of `chunks` into the corpus, enforcing the chunk cap.
///
/// Existing ids are always replaced. New ids are dropped once the corpus
/// reaches `max_chunks_per_index()`, with a single warning summarising how
/// many were dropped. `chunks` is left empty.
async fn commit_corpus(&self, chunks: &mut Vec<RawChunk>) {
    let cap = max_chunks_per_index();
    let mut corpus = self.chunks.write().await;
    let mut skipped = 0usize;
    for chunk in chunks.drain(..) {
        let is_new = !corpus.contains_key(&chunk.id);
        if is_new && corpus.len() >= cap {
            skipped += 1;
            continue;
        }
        corpus.insert(chunk.id.clone(), chunk);
    }
    if skipped != 0 {
        tracing::warn!(
            "index '{}' chunk cap ({}) reached — dropped {} new chunks in batch",
            self.index_id,
            cap,
            skipped
        );
    }
}
/// Records each file's entity list, replacing any previous list for that path.
async fn commit_entities(&self, entities_by_file: Vec<(String, Vec<RawEntity>)>) {
    self.entities.write().await.extend(entities_by_file);
}
/// Returns a copy of the entities recorded for `file_path`, if any.
pub async fn entities_for(&self, file_path: &str) -> Option<Vec<RawEntity>> {
    let by_file = self.entities.read().await;
    by_file.get(file_path).cloned()
}
/// Finds a chunk containing a named-type or module-path entity whose text
/// equals `query` (ASCII case-insensitive).
///
/// Only single-token queries are considered: after trimming, any internal
/// whitespace (spaces, tabs, newlines) disables the lookup. When several
/// chunks qualify, the choice is deterministic. Preference order is exact-case
/// matches, then file path, then the innermost (shortest-span) chunk covering
/// the entity line, then `start_line`, then chunk id. Previously the first
/// hit in hash-map order won.
async fn entity_exact_match(&self, query: &str) -> Option<String> {
    let needle = query.trim();
    if needle.is_empty() || needle.contains(char::is_whitespace) {
        return None;
    }
    let entities = self.entities.read().await;
    let chunks = self.chunks.read().await;
    // (case mismatch, file, span length, start line, id) — smaller is better.
    let mut best: Option<(bool, &str, usize, usize, &str)> = None;
    for (file, ents) in entities.iter() {
        for ent in ents {
            let is_symbol = matches!(
                ent.entity_type,
                EntityType::NamedType | EntityType::ModulePath
            );
            if !is_symbol || !ent.text.eq_ignore_ascii_case(needle) {
                continue;
            }
            let case_mismatch = ent.text != needle;
            for c in chunks.values() {
                if c.file != *file || ent.line < c.start_line || ent.line > c.end_line {
                    continue;
                }
                let key = (
                    case_mismatch,
                    c.file.as_str(),
                    c.end_line.saturating_sub(c.start_line),
                    c.start_line,
                    c.id.as_str(),
                );
                let better = match best {
                    None => true,
                    Some(current) => key < current,
                };
                if better {
                    best = Some(key);
                }
            }
        }
    }
    best.map(|(.., id)| id.to_string())
}
/// Removes every chunk belonging to `file_path` from all stores, drops the
/// file's entities, and rebuilds the symbol graph.
///
/// Returns the number of chunks that belonged to the file.
pub async fn remove_file(&self, file_path: &str) -> Result<usize> {
    let doomed: Vec<String> = self
        .chunks
        .read()
        .await
        .values()
        .filter(|c| c.file == file_path)
        .map(|c| c.id.clone())
        .collect();
    self.remove_chunks_from_stores(&doomed).await;
    self.entities.write().await.remove(file_path);
    self.rebuild_symbol_graph().await;
    Ok(doomed.len())
}
async fn remove_chunks_from_stores(&self, ids: &[String]) {
if let Some(store) = &self.store {
for id in ids {
store.remove(id).await.ok();
}
}
{
let mut chunks = self.chunks.write().await;
for id in ids {
chunks.remove(id);
}
}
{
let mut emb = self.chunk_embeddings.write().await;
for id in ids {
emb.pop(id);
}
}
{
let mut bm25 = self.bm25.write().await;
for id in ids {
bm25.remove_document(id);
}
}
}
/// Removes a single chunk from every store and rebuilds the symbol graph.
///
/// Delegates to `remove_chunks_from_stores` instead of duplicating its
/// four-store removal sequence, so single- and multi-chunk removal cannot
/// drift apart. Removing an unknown id is not an error.
pub async fn remove_chunk(&self, chunk_id: &str) -> Result<()> {
    self.remove_chunks_from_stores(&[chunk_id.to_string()]).await;
    self.rebuild_symbol_graph().await;
    Ok(())
}
async fn embed_query(&self, query: &str) -> Result<Option<Vec<f32>>> {
let Some(embedder) = self.embedder.clone() else {
return Ok(None);
};
let key = hash_query(query);
if let Some(v) = self
.query_cache
.lock()
.expect("query_cache mutex poisoned")
.get(&key)
{
return Ok(Some(v.clone()));
}
let vec = embedder.embed(query).await.context("embed query")?;
self.query_cache
.lock()
.expect("query_cache mutex poisoned")
.put(key, vec.clone());
Ok(Some(vec))
}
/// Scores `query` against the BM25 index, returning up to `want` `(id, score)`
/// pairs; empty when the index holds no documents.
async fn bm25_search(&self, query: &str, want: usize) -> Result<Vec<(String, f32)>> {
    let index = self.bm25.read().await;
    Ok(if index.is_empty() {
        Vec::new()
    } else {
        index.score_query_all(query, want)
    })
}
/// Nearest-neighbour lookup in the vector store, returning up to `want`
/// `(chunk id, score)` pairs; empty when no store is configured.
///
/// # Errors
/// Propagates store search failures.
async fn vector_search(&self, embedding: &[f32], want: usize) -> Result<Vec<(String, f32)>> {
    match &self.store {
        None => Ok(Vec::new()),
        Some(store) => {
            let hits = store.search(embedding, want).await?;
            Ok(hits
                .into_iter()
                .map(|hit| (hit.chunk_id, hit.score))
                .collect())
        }
    }
}
/// Edge kinds that `kg_expand` follows for each query intent.
///
/// Definition queries follow type relationships, usage queries follow call
/// and test edges, conceptual queries follow concept/documentation edges,
/// bug/debt queries follow error and configuration edges, and unknown intents
/// fall back to call edges in both directions.
fn edge_kinds_for_intent(intent: QueryIntent) -> Vec<EdgeKind> {
match intent {
QueryIntent::Definition => {
vec![EdgeKind::Implements, EdgeKind::Aliases, EdgeKind::UsesType]
}
QueryIntent::Usage => vec![
EdgeKind::CallsFunction,
EdgeKind::CalledByFunction,
EdgeKind::TestedBy,
EdgeKind::CoOccursInTest,
],
QueryIntent::Conceptual => {
vec![EdgeKind::ReferencesConcept, EdgeKind::Documents]
}
QueryIntent::BugDebt => vec![
EdgeKind::RaisesError,
EdgeKind::ErrorDescribes,
EdgeKind::Configures,
],
QueryIntent::Unknown => vec![EdgeKind::CallsFunction, EdgeKind::CalledByFunction],
}
}
/// Expands search seeds through the symbol graph.
///
/// For each seed mapped to a graph symbol, neighbours reachable within
/// `KG_EXPAND_HOPS` over the intent's edge kinds are scored as
/// `seed_score * edge_kind.score_multiplier()`. Seeds themselves are never
/// returned. When a neighbour is reached several times, its highest score is
/// kept. Output is sorted by descending score, ties broken by id ascending.
async fn kg_expand(&self, seeds: &[(String, f32)], intent: QueryIntent) -> Vec<(String, f32)> {
    let graph = self.symbol_graph().await;
    if graph.node_count() == 0 || seeds.is_empty() {
        return Vec::new();
    }
    let kinds = Self::edge_kinds_for_intent(intent);
    let seed_set: std::collections::HashSet<&str> =
        seeds.iter().map(|(id, _)| id.as_str()).collect();
    let mut strongest: HashMap<String, f32> = HashMap::new();
    for (seed_id, seed_score) in seeds {
        let symbol = match graph.symbol_for_chunk(seed_id) {
            Some(symbol) => symbol,
            None => continue,
        };
        for (_, neighbour, kind) in graph.neighbors_by_edge(symbol, &kinds, KG_EXPAND_HOPS) {
            if seed_set.contains(neighbour.as_str()) {
                continue;
            }
            let derived = seed_score * kind.score_multiplier();
            let slot = strongest.entry(neighbour).or_insert(derived);
            if derived > *slot {
                *slot = derived;
            }
        }
    }
    let mut ranked: Vec<(String, f32)> = strongest.into_iter().collect();
    ranked.sort_by(|left, right| {
        right
            .1
            .partial_cmp(&left.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| left.0.cmp(&right.0))
    });
    ranked
}
/// Hybrid search: vector + BM25 retrieval, RRF fusion, MMR rerank, optional
/// symbol-graph expansion, then materialization into [`CodeChunk`]s.
///
/// The query intent decides the fusion weights and whether graph expansion is
/// attempted (expansion also requires `query.expand_graph`). Each retriever is
/// asked for `top_k * HNSW_OVERSAMPLE` candidates. The vector leg is awaited
/// before the BM25 leg; the two run sequentially.
///
/// # Errors
/// Propagates query-embedding, vector-search and BM25 failures.
pub async fn search(&self, query: &SearchQuery) -> Result<Vec<CodeChunk>> {
    let intent = QueryClassifier::classify(&query.text);
    let (alpha, beta, use_kg_first) = intent.weights();
    tracing::debug!(
        "search index={} query={:?} intent={:?} alpha={} beta={}",
        self.index_id,
        query.text,
        intent,
        alpha,
        beta
    );
    let query_vec = self.embed_query(&query.text).await?;
    let want = query.top_k.saturating_mul(HNSW_OVERSAMPLE).max(query.top_k);

    let hnsw_results = if let Some(v) = &query_vec {
        self.vector_search(v, want).await?
    } else {
        Vec::new()
    };
    let mut bm25_results = self.bm25_search(&query.text, want).await?;
    self.inject_entity_exact_match(&intent, &query.text, beta, &mut bm25_results)
        .await;

    let fused_raw = rrf_fuse(
        &hnsw_results,
        &bm25_results,
        alpha,
        beta,
        RRF_K,
        query.top_k,
    );
    let reranked = self.apply_mmr_rerank(fused_raw, query.top_k).await;
    let (candidates, kg_ids) = self
        .expand_with_kg(reranked, &intent, use_kg_first, query.expand_graph)
        .await;
    Ok(self
        .materialize_search_results(candidates, &hnsw_results, &bm25_results, &kg_ids, query)
        .await)
}
/// For definition-style or unclassified queries, moves an exact entity match
/// to the top of the BM25 list with score `beta * 1.5`.
///
/// Any existing BM25 entry for the same chunk is removed first, so the chunk
/// appears exactly once.
async fn inject_entity_exact_match(
    &self,
    intent: &QueryIntent,
    query_text: &str,
    beta: f32,
    bm25_results: &mut Vec<(String, f32)>,
) {
    let eligible = matches!(intent, QueryIntent::Definition | QueryIntent::Unknown);
    if !eligible {
        return;
    }
    if let Some(hit) = self.entity_exact_match(query_text).await {
        bm25_results.retain(|(id, _)| *id != hit);
        bm25_results.insert(0, (hit, beta * 1.5));
    }
}
/// Reranks fused results with MMR using cached chunk embeddings.
///
/// When the embedding cache is empty, the input is returned unchanged. The
/// cache lock is released before reranking; only the embeddings of the fused
/// candidates are copied out.
async fn apply_mmr_rerank(
    &self,
    fused_raw: Vec<(String, f32)>,
    top_k: usize,
) -> Vec<(String, f32)> {
    let snapshot: HashMap<String, Vec<f32>> = {
        let emb_map = self.chunk_embeddings.read().await;
        if emb_map.is_empty() {
            return fused_raw;
        }
        let mut picked = HashMap::with_capacity(fused_raw.len());
        for (id, _) in &fused_raw {
            if let Some(vector) = emb_map.peek(id) {
                picked.insert(id.clone(), vector.clone());
            }
        }
        picked
    };
    crate::core::mmr::mmr_rerank(
        fused_raw,
        &snapshot,
        crate::core::mmr::DEFAULT_LAMBDA,
        top_k,
    )
}
/// Appends symbol-graph neighbours after the fused results when both the
/// intent (`use_kg_first`) and the caller (`expand_graph`) allow it.
///
/// Returns the combined list and the set of ids contributed by expansion
/// (empty when expansion is skipped).
async fn expand_with_kg(
    &self,
    fused: Vec<(String, f32)>,
    intent: &QueryIntent,
    use_kg_first: bool,
    expand_graph: bool,
) -> (Vec<(String, f32)>, std::collections::HashSet<String>) {
    let kg_enabled = use_kg_first && expand_graph;
    if !kg_enabled {
        return (fused, std::collections::HashSet::new());
    }
    let neighbours = self.kg_expand(&fused, intent.clone()).await;
    let kg_ids = neighbours
        .iter()
        .map(|(id, _)| id.clone())
        .collect::<std::collections::HashSet<String>>();
    let mut combined = fused;
    combined.extend(neighbours);
    (combined, kg_ids)
}
/// Turns the first `query.top_k` ranked ids into [`CodeChunk`]s.
///
/// Ids no longer in the corpus are skipped with a trace log. Each hit is
/// labelled by `compute_match_reason` from its membership in the vector, BM25
/// and graph-expansion result sets. A compact snippet is attached when
/// `query.compact` is set.
async fn materialize_search_results(
    &self,
    all: Vec<(String, f32)>,
    hnsw_results: &[(String, f32)],
    bm25_results: &[(String, f32)],
    kg_ids: &std::collections::HashSet<String>,
    query: &SearchQuery,
) -> Vec<CodeChunk> {
    let in_hnsw: std::collections::HashSet<&String> =
        hnsw_results.iter().map(|(id, _)| id).collect();
    let in_bm25: std::collections::HashSet<&String> =
        bm25_results.iter().map(|(id, _)| id).collect();
    let chunks = self.chunks.read().await;
    all.into_iter()
        .take(query.top_k)
        .filter_map(|(id, score)| {
            let Some(raw) = chunks.get(&id) else {
                tracing::trace!("fused id {id} not in corpus — likely race; skipping");
                return None;
            };
            let reason = compute_match_reason(
                in_hnsw.contains(&id),
                in_bm25.contains(&id),
                kg_ids.contains(&id),
            );
            let snippet = query.compact.then(|| build_compact_snippet(&raw.content));
            Some(raw_to_code_chunk(raw, score, reason, snippet))
        })
        .collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::embed::MockEmbedder;
use crate::core::store::UsearchStore;
fn raw(id: &str, file: &str, content: &str) -> RawChunk {
RawChunk {
id: id.to_string(),
file: file.to_string(),
start_line: 1,
end_line: 1 + content.lines().count(),
content: content.to_string(),
function_name: None,
language: Some("rust".to_string()),
chunk_type: crate::core::chunker::ChunkType::Code,
calls: Vec::new(),
inherits_from: Vec::new(),
chunk_depth: 0,
parent_chunk_id: None,
child_chunk_ids: Vec::new(),
nlp_keywords: Vec::new(),
nlp_code_refs: Vec::new(),
virtual_terms: Vec::new(),
}
}
fn make_indexer() -> CodeIndexer {
let dim = 32;
let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder::new(dim));
let store: Arc<dyn VectorStore> = Arc::new(UsearchStore::new(dim).expect("usearch new"));
CodeIndexer::new("test", "/tmp/test").with_components(embedder, store)
}
#[tokio::test]
async fn test_search_integration_returns_relevant_chunk_first() {
let idx = make_indexer();
idx.add_chunk(raw(
"src/auth.rs:1:5",
"src/auth.rs",
"fn authenticate(user: &str, password: &str) -> bool { true }",
))
.await
.unwrap();
idx.add_chunk(raw(
"src/render.rs:1:3",
"src/render.rs",
"fn render_ui_components() { /* svelte */ }",
))
.await
.unwrap();
idx.add_chunk(raw(
"src/db.rs:1:4",
"src/db.rs",
"struct Database { conn: String }",
))
.await
.unwrap();
let q = SearchQuery {
text: "fn authenticate".to_string(),
top_k: 3,
expand_graph: false,
compact: true,
};
let results = idx.search(&q).await.expect("search");
assert!(!results.is_empty(), "search should return at least one hit");
assert_eq!(
results[0].id,
"src/auth.rs:1:5",
"auth chunk must rank first; got {:?}",
results.iter().map(|r| &r.id).collect::<Vec<_>>()
);
assert!(
results[0].compact_snippet.is_some(),
"compact_snippet should be populated when compact=true"
);
assert!(
results[0].match_reason == "hybrid" || results[0].match_reason == "bm25",
"expected hybrid or bm25 match_reason, got {}",
results[0].match_reason
);
}
#[tokio::test]
async fn test_query_cache_skips_embedder_on_repeat() {
let idx = make_indexer();
let q = "find user authentication logic";
let v1 = idx.embed_query(q).await.unwrap().unwrap();
let key = hash_query(q);
let cached = {
let mut g = idx.query_cache.lock().unwrap();
g.get(&key).cloned()
};
assert_eq!(cached.as_ref(), Some(&v1), "cache must be populated");
let v2 = idx.embed_query(q).await.unwrap().unwrap();
assert_eq!(v1, v2, "second call must return identical vector via cache");
}
#[tokio::test]
async fn test_search_with_no_embedder_falls_back_to_bm25() {
let idx = CodeIndexer::new("bm25-only", "/tmp/test");
idx.add_chunk(raw("f.rs:1:1", "f.rs", "fn authenticate() {}"))
.await
.unwrap();
idx.add_chunk(raw("g.rs:1:1", "g.rs", "fn unrelated() {}"))
.await
.unwrap();
let q = SearchQuery {
text: "authenticate".to_string(),
top_k: 5,
expand_graph: false,
compact: false,
};
let r = idx.search(&q).await.unwrap();
assert_eq!(r[0].id, "f.rs:1:1");
assert_eq!(r[0].match_reason, "bm25");
}
#[tokio::test]
async fn test_remove_chunk_removes_from_results() {
let idx = make_indexer();
idx.add_chunk(raw("a:1:1", "a.rs", "fn authenticate() {}"))
.await
.unwrap();
idx.add_chunk(raw("b:1:1", "b.rs", "fn other_thing() {}"))
.await
.unwrap();
idx.remove_chunk("a:1:1").await.unwrap();
let q = SearchQuery {
text: "authenticate".to_string(),
top_k: 5,
expand_graph: false,
compact: false,
};
let r = idx.search(&q).await.unwrap();
assert!(!r.iter().any(|c| c.id == "a:1:1"));
}
#[tokio::test]
async fn test_kg_expansion_marks_neighbours_with_hybrid_kg() {
let idx = CodeIndexer::new("kg-test", "/tmp/test");
idx.add_chunk(RawChunk {
id: "h:1".to_string(),
file: "h.rs".to_string(),
start_line: 1,
end_line: 3,
content: "fn login_handler() { /* dispatch to verifier */ }".to_string(),
function_name: Some("login_handler".to_string()),
language: Some("rust".to_string()),
chunk_type: crate::core::chunker::ChunkType::Function,
calls: vec!["authenticate".to_string()],
inherits_from: Vec::new(),
chunk_depth: 0,
parent_chunk_id: None,
child_chunk_ids: Vec::new(),
nlp_keywords: Vec::new(),
nlp_code_refs: Vec::new(),
virtual_terms: Vec::new(),
})
.await
.unwrap();
idx.add_chunk(RawChunk {
id: "a:1".to_string(),
file: "a.rs".to_string(),
start_line: 1,
end_line: 1,
content: "fn authenticate() {}".to_string(),
function_name: Some("authenticate".to_string()),
language: Some("rust".to_string()),
chunk_type: crate::core::chunker::ChunkType::Function,
calls: Vec::new(),
inherits_from: Vec::new(),
chunk_depth: 0,
parent_chunk_id: None,
child_chunk_ids: Vec::new(),
nlp_keywords: Vec::new(),
nlp_code_refs: Vec::new(),
virtual_terms: Vec::new(),
})
.await
.unwrap();
let q = SearchQuery {
text: "callers of authenticate".to_string(),
top_k: 10,
expand_graph: true,
compact: false,
};
let results = idx.search(&q).await.unwrap();
let login = results
.iter()
.find(|c| c.id == "h:1")
.expect("login_handler should surface via KG expansion");
assert_eq!(
login.match_reason, "hybrid+kg",
"KG-expanded chunks must carry hybrid+kg marker, got {}",
login.match_reason
);
let trigger = results
.iter()
.find(|c| c.id == "a:1")
.expect("authenticate must appear directly");
let expected = trigger.score * KG_EXPAND_SCORE_FACTOR;
assert!(
(login.score - expected).abs() < 1e-5,
"expected KG score = 0.7 * {} = {}, got {}",
trigger.score,
expected,
login.score
);
}
#[tokio::test]
async fn test_kg_expansion_disabled_by_expand_graph_false() {
let idx = make_indexer();
idx.add_chunk(RawChunk {
id: "h:1".to_string(),
file: "h.rs".to_string(),
start_line: 1,
end_line: 1,
content: "fn caller() { target(); }".to_string(),
function_name: Some("caller".to_string()),
language: Some("rust".to_string()),
chunk_type: crate::core::chunker::ChunkType::Function,
calls: vec!["target".to_string()],
inherits_from: Vec::new(),
chunk_depth: 0,
parent_chunk_id: None,
child_chunk_ids: Vec::new(),
nlp_keywords: Vec::new(),
nlp_code_refs: Vec::new(),
virtual_terms: Vec::new(),
})
.await
.unwrap();
idx.add_chunk(RawChunk {
id: "t:1".to_string(),
file: "t.rs".to_string(),
start_line: 1,
end_line: 1,
content: "fn target() {}".to_string(),
function_name: Some("target".to_string()),
language: Some("rust".to_string()),
chunk_type: crate::core::chunker::ChunkType::Function,
calls: Vec::new(),
inherits_from: Vec::new(),
chunk_depth: 0,
parent_chunk_id: None,
child_chunk_ids: Vec::new(),
nlp_keywords: Vec::new(),
nlp_code_refs: Vec::new(),
virtual_terms: Vec::new(),
})
.await
.unwrap();
let q = SearchQuery {
text: "callers of target".to_string(),
top_k: 10,
expand_graph: false,
compact: false,
};
let results = idx.search(&q).await.unwrap();
assert!(
!results.iter().any(|c| c.match_reason.contains("kg")),
"expand_graph=false must suppress KG expansion, got {results:#?}"
);
}
#[tokio::test]
async fn test_symbol_graph_rebuilds_after_indexing() {
let idx = make_indexer();
assert_eq!(idx.symbol_graph().await.node_count(), 0);
idx.index_file("a.rs", "fn alpha() { beta(); }\nfn beta() {}\n")
.await
.unwrap();
let g = idx.symbol_graph().await;
assert!(g.node_count() >= 2, "graph should hold alpha + beta");
assert!(
!g.callees_of("alpha", 1).is_empty(),
"alpha should have a callee edge to beta"
);
}
/// `entity_exact_match` must resolve a struct name to a chunk id, and that id
/// must point at a chunk from the file that was indexed.
#[tokio::test]
async fn test_entity_exact_match_finds_chunk() {
    let idx = make_indexer();
    idx.index_file("e.rs", "pub struct MyType { x: u32 }\nfn f() {}\n")
        .await
        .unwrap();
    // `expect` reports a miss in one step instead of `is_some()` followed by `unwrap()`.
    let hit_id = idx
        .entity_exact_match("MyType")
        .await
        .expect("expected entity_exact_match to find MyType");
    let chunks = idx.chunks.read().await;
    // Check "id resolves" and "file matches" separately, so a failure names the
    // broken property and shows the actual file instead of just `false`.
    let chunk = chunks
        .get(&hit_id)
        .expect("matched chunk id must exist in the chunk map");
    assert_eq!(chunk.file, "e.rs", "matched chunk should live in e.rs");
}
/// A bare struct-name query must rank the chunk defining that struct first,
/// ahead of an unrelated file.
#[tokio::test]
async fn test_entity_exact_match_struct_ranks_first() {
    let idx = CodeIndexer::new("ent-rank-1", "/tmp/test");
    // One file defines `FooBar`; the other is unrelated noise.
    let fixtures = [
        (
            "src/types.rs",
            "pub struct FooBar { pub x: u32 }\n\nfn unrelated() { let _ = 1; }\n",
        ),
        ("src/other.rs", "fn other_thing() {}\n"),
    ];
    for (path, body) in fixtures {
        idx.index_file(path, body).await.unwrap();
    }

    let query = SearchQuery {
        text: "FooBar".to_string(),
        top_k: 5,
        expand_graph: false,
        compact: false,
    };
    let results = idx.search(&query).await.expect("search");
    let top = results
        .first()
        .expect("search must return at least one hit");
    assert_eq!(
        top.file,
        "src/types.rs",
        "FooBar's defining file must rank first; got {:?}",
        results.iter().map(|r| &r.file).collect::<Vec<_>>(),
    );
    assert!(
        top.content.contains("FooBar"),
        "rank-1 chunk must contain the FooBar definition; got {:?}",
        top.content,
    );
}
/// A word that appears only inside a string literal must not count as an
/// exact entity match.
#[tokio::test]
async fn test_entity_exact_match_skips_non_symbol_entities() {
    let idx = make_indexer();
    // "literal" appears only inside the quoted string, never as an identifier.
    let source = "fn f() { let _ = \"this is a long literal\"; }\n";
    idx.index_file("lit.rs", source).await.unwrap();
    let hit = idx.entity_exact_match("literal").await;
    assert!(
        hit.is_none(),
        "non-symbol entity types must not satisfy entity_exact_match"
    );
}
/// A query containing more than one word must not produce an exact entity match.
#[tokio::test]
async fn test_entity_exact_match_skips_multiword_query() {
    let idx = make_indexer();
    let source = "use std::sync::Arc;\nfn f() {}\n";
    idx.index_file("e.rs", source).await.unwrap();
    // The query contains whitespace, so it is not a single symbol name.
    let hit = idx.entity_exact_match("Arc thing").await;
    assert!(hit.is_none());
}
#[tokio::test]
async fn test_virtual_terms_populated_from_entities() {
let idx = make_indexer();
idx.index_file(
"v.rs",
"use std::sync::Arc;\nfn f() { let _x: Arc<String> = Arc::new(String::new()); }\n",
)
.await
.unwrap();
let chunks = idx.chunks.read().await;
let f_chunk = chunks
.values()
.find(|c| c.function_name.as_deref() == Some("f"))
.expect("f chunk");
assert!(
f_chunk.virtual_terms.iter().any(|t| t == "Arc"),
"expected 'Arc' in virtual_terms, got {:?}",
f_chunk.virtual_terms
);
}
/// After `add_chunk`, the non-async `get_embedding` lookup must find that
/// chunk's embedding; an id that was never added must yield `None`.
#[tokio::test]
async fn test_get_embedding_returns_some_after_indexing() {
    let idx = make_indexer();
    let chunk = raw("a:1:1", "a.rs", "fn alpha() {}");
    idx.add_chunk(chunk).await.unwrap();

    let cached = idx.get_embedding("a:1:1");
    assert!(
        cached.is_some(),
        "expected embedding cached after add_chunk"
    );
    // Unknown ids have no embedding.
    assert!(idx.get_embedding("nope").is_none());
}
/// Nearest-neighbour lookup by embedding must never return the excluded seed
/// chunk, and every hit must be tagged as coming from the vector channel.
#[tokio::test]
async fn test_similar_by_embedding_excludes_seed() {
    let idx = make_indexer();
    let fixtures = [
        ("a:1:1", "a.rs", "fn alpha() {}"),
        ("b:1:1", "b.rs", "fn beta() {}"),
    ];
    for (id, file, body) in fixtures {
        idx.add_chunk(raw(id, file, body)).await.unwrap();
    }

    let seed = "a:1:1";
    let seed_embedding = idx.get_embedding(seed).unwrap();
    let neighbours = idx
        .similar_by_embedding(&seed_embedding, 5, Some(seed))
        .await
        .unwrap();
    // The seed must not come back as its own neighbour...
    assert!(neighbours.iter().all(|c| c.id != seed));
    // ...and every hit is attributed to vector similarity.
    assert!(neighbours.iter().all(|c| c.match_reason == "vector"));
}
/// Batch-indexing two files must add all their function chunks, populate the
/// symbol graph, and make the chunks searchable.
#[tokio::test]
async fn test_index_files_batch_indexes_all_chunks_once() {
    let idx = make_indexer();
    let files: Vec<(String, String)> = [
        ("src/a.rs", "fn alpha() { beta(); }\nfn beta() {}\n"),
        ("src/b.rs", "fn gamma() {}\nfn delta() { gamma(); }\n"),
    ]
    .iter()
    .map(|&(path, body)| (path.to_string(), body.to_string()))
    .collect();

    let added = idx.index_files_batch(&files).await.unwrap();
    assert!(added >= 4, "expected at least 4 chunks, got {added}");
    {
        // Every function from both files should have a graph node.
        let graph = idx.symbol_graph().await;
        assert!(graph.node_count() >= 4);
    }

    let query = SearchQuery {
        text: "fn alpha".to_string(),
        top_k: 5,
        expand_graph: false,
        compact: false,
    };
    let hits = idx.search(&query).await.unwrap();
    assert!(hits.iter().any(|c| c.file == "src/a.rs"));
}
/// An empty batch must report zero added chunks and leave the index empty.
#[tokio::test]
async fn test_index_files_batch_empty_input_is_noop() {
    let idx = make_indexer();
    let no_files: Vec<(String, String)> = Vec::new();
    let added = idx.index_files_batch(&no_files).await.unwrap();
    assert_eq!(added, 0);
    assert_eq!(idx.chunk_count(), 0);
}
/// An indexer built with `CodeIndexer::new` must batch-index a file and then
/// find its function by keyword search.
#[tokio::test]
async fn test_index_files_batch_bm25_only_mode() {
    let idx = CodeIndexer::new("bm25-batch", "/tmp/test");
    let source = "fn authenticate() {}\nfn other() {}\n";
    let files = vec![("x.rs".to_string(), source.to_string())];
    let added = idx.index_files_batch(&files).await.unwrap();
    assert!(added >= 2);

    let query = SearchQuery {
        text: "authenticate".to_string(),
        top_k: 5,
        expand_graph: false,
        compact: false,
    };
    let hits = idx.search(&query).await.unwrap();
    assert!(hits.iter().any(|c| c.content.contains("authenticate")));
}
/// Pins the `(a, b, kg)` triple returned by `QueryIntent::weights` for the
/// `Definition` and `Usage` intents.
///
/// Each component is asserted separately. A single `&&`-chained assert would
/// report only "assertion failed" without saying which intent or which field
/// was wrong.
#[test]
fn test_intent_routing_definitions() {
    use crate::core::classifier::QueryIntent;

    let (a, b, kg) = QueryIntent::Definition.weights();
    assert!((a - 0.3).abs() < 1e-6, "Definition: a should be 0.3, got {a}");
    assert!((b - 0.7).abs() < 1e-6, "Definition: b should be 0.7, got {b}");
    assert!(!kg, "Definition: kg flag should be false");

    let (a, b, kg) = QueryIntent::Usage.weights();
    assert!((a - 0.5).abs() < 1e-6, "Usage: a should be 0.5, got {a}");
    assert!((b - 0.5).abs() < 1e-6, "Usage: b should be 0.5, got {b}");
    assert!(kg, "Usage: kg flag should be true");
}
#[tokio::test]
async fn test_enumerate_chunks_paginates_stable_order() {
let idx = make_indexer();
fn raw_lines(id: &str, file: &str, start: usize, end: usize, content: &str) -> RawChunk {
let mut r = raw(id, file, content);
r.start_line = start;
r.end_line = end;
r
}
idx.add_chunk(raw_lines("b.rs:10:20", "b.rs", 10, 20, "fn b_two() {}"))
.await
.unwrap();
idx.add_chunk(raw_lines("a.rs:1:5", "a.rs", 1, 5, "fn a_one() {}"))
.await
.unwrap();
idx.add_chunk(raw_lines("b.rs:1:5", "b.rs", 1, 5, "fn b_one() {}"))
.await
.unwrap();
idx.add_chunk(raw_lines("a.rs:30:40", "a.rs", 30, 40, "fn a_two() {}"))
.await
.unwrap();
let (total_all, all) = idx.enumerate_chunks(0, 100).await;
assert_eq!(total_all, 4);
let ids: Vec<_> = all.iter().map(|c| c.id.as_str()).collect();
assert_eq!(
ids,
vec!["a.rs:1:5", "a.rs:30:40", "b.rs:1:5", "b.rs:10:20"]
);
let (total_p1, page1) = idx.enumerate_chunks(0, 2).await;
let (total_p2, page2) = idx.enumerate_chunks(2, 2).await;
assert_eq!(total_p1, 4);
assert_eq!(total_p2, 4);
assert_eq!(page1.len(), 2);
assert_eq!(page2.len(), 2);
let combined: Vec<_> = page1
.iter()
.chain(page2.iter())
.map(|c| c.id.as_str())
.collect();
assert_eq!(combined, ids);
let (total_end, end) = idx.enumerate_chunks(10, 5).await;
assert_eq!(total_end, 4);
assert!(end.is_empty());
let (total_z, z) = idx.enumerate_chunks(0, 0).await;
assert_eq!(total_z, 4);
assert!(z.is_empty());
}
}