use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use lru::LruCache;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use crate::core::bm25::Bm25Index;
use crate::core::chunker::{ChunkType, RawChunk};
use crate::core::embed::Embedder;
use crate::core::entity::RawEntity;
use crate::core::store::VectorStore;
use crate::core::symbol_graph::SymbolGraph;
mod files;
mod ingest;
mod persist;
mod search;
#[cfg(test)]
mod tests;
/// Number of entries kept in `CodeIndexer::query_cache` (cached query vectors).
const QUERY_CACHE_CAPACITY: usize = 256;
/// Oversampling multiplier for HNSW vector searches.
///
/// NOTE(review): presumably the store is asked for `top_k * HNSW_OVERSAMPLE` candidates before
/// re-ranking — confirm in the `search` module.
pub(crate) const HNSW_OVERSAMPLE: usize = 4;
/// Fallback capacity of the chunk-embedding LRU when `TRUSTY_EMBEDDING_CACHE` is unusable.
const DEFAULT_EMBEDDING_CACHE_CAP: usize = 1_000;

/// Capacity of the chunk-embedding LRU cache.
///
/// Reads the `TRUSTY_EMBEDDING_CACHE` environment variable. An unset, non-UTF-8, unparseable
/// or zero value falls back to `DEFAULT_EMBEDDING_CACHE_CAP`, so the result is never zero.
fn embedding_cache_cap() -> usize {
    let configured = std::env::var("TRUSTY_EMBEDDING_CACHE")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());
    match configured {
        Some(n) if n != 0 => n,
        _ => DEFAULT_EMBEDDING_CACHE_CAP,
    }
}
/// Fallback chunk limit used when `TRUSTY_MAX_CHUNKS` is unset or unusable.
const DEFAULT_MAX_CHUNKS_PER_INDEX: usize = 200_000;

/// Maximum number of chunks allowed in one index.
///
/// Reads the `TRUSTY_MAX_CHUNKS` environment variable. An unset, non-UTF-8, unparseable or
/// zero value yields `DEFAULT_MAX_CHUNKS_PER_INDEX`. NOTE(review): the limit is presumably
/// enforced by the ingest code, which is not visible here.
pub(crate) fn max_chunks_per_index() -> usize {
    match std::env::var("TRUSTY_MAX_CHUNKS") {
        Ok(raw) => match raw.parse::<usize>() {
            Ok(n) if n > 0 => n,
            _ => DEFAULT_MAX_CHUNKS_PER_INDEX,
        },
        Err(_) => DEFAULT_MAX_CHUNKS_PER_INDEX,
    }
}
/// Batch size used when `TRUSTY_MAX_BATCH_SIZE` is unset, unparseable or zero.
const DEFAULT_EMBED_BATCH_SIZE: usize = 64;
/// Smallest batch size accepted from the environment.
const EMBED_BATCH_MIN: usize = 32;
/// Largest batch size accepted from the environment.
const EMBED_BATCH_MAX: usize = 512;

/// Number of chunks sent to the embedder per batch.
///
/// A positive `TRUSTY_MAX_BATCH_SIZE` is clamped into `[EMBED_BATCH_MIN, EMBED_BATCH_MAX]`.
/// A value that is absent, unparseable or zero yields `DEFAULT_EMBED_BATCH_SIZE`; such values
/// are not clamped.
pub(crate) fn embed_batch_size() -> usize {
    let requested = std::env::var("TRUSTY_MAX_BATCH_SIZE")
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok());
    match requested {
        None | Some(0) => DEFAULT_EMBED_BATCH_SIZE,
        Some(n) => n.clamp(EMBED_BATCH_MIN, EMBED_BATCH_MAX),
    }
}
/// Score multiplier for knowledge-graph expansion.
///
/// NOTE(review): presumably applied to chunks reached through symbol-graph expansion so they
/// rank below direct hits — confirm in the `search` module. The `dead_code` allowance suggests
/// this constant is not referenced in every build; TODO confirm whether it is still needed.
#[allow(dead_code)]
pub(crate) const KG_EXPAND_SCORE_FACTOR: f32 = 0.70_f32;
/// Number of symbol-graph hops used for expansion (presumably when `SearchQuery::expand_graph`
/// is set — confirm).
pub(crate) const KG_EXPAND_HOPS: usize = 1_usize;
/// A ranked chunk of source code as handed back to callers.
///
/// Carries a relevance `score` and a `match_reason`. Built from a [`RawChunk`] by
/// `raw_to_code_chunk`. Fields marked `#[serde(default)]` may be missing from serialized input
/// and then take their `Default` value. Field order is kept stable because it determines
/// serialized/Debug output order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    /// Chunk identifier, copied from `RawChunk::id`.
    pub id: String,
    /// Source file of the chunk, copied from `RawChunk::file`
    /// (NOTE(review): relative vs absolute path not visible here — confirm).
    pub file: String,
    /// Detected language, if any; defaults to `None` when absent from input.
    #[serde(default)]
    pub language: Option<String>,
    /// First line of the chunk (inclusive, same convention as `RawChunk`).
    pub start_line: usize,
    /// Last line of the chunk (inclusive, same convention as `RawChunk`).
    pub end_line: usize,
    /// Full chunk text.
    pub content: String,
    /// Name of the enclosing function, when the chunker recorded one.
    pub function_name: Option<String>,
    /// Relevance score assigned by the search path.
    pub score: f32,
    /// Optional shortened preview (see `build_compact_snippet`, which keeps the first 7 lines).
    pub compact_snippet: Option<String>,
    /// Why the chunk matched, e.g. a value from `compute_match_reason`
    /// ("hybrid", "vector", "bm25", "hybrid+kg", "fallback").
    pub match_reason: String,
    /// Kind of chunk as classified by the chunker.
    #[serde(default)]
    pub chunk_type: ChunkType,
    /// Names copied from `RawChunk::calls` (presumably callees referenced by this chunk).
    #[serde(default)]
    pub calls: Vec<String>,
    /// Names copied from `RawChunk::inherits_from` (presumably parent types).
    #[serde(default)]
    pub inherits_from: Vec<String>,
    /// Nesting depth, saturated to `u8::MAX` by `raw_to_code_chunk`.
    #[serde(default)]
    pub chunk_depth: u8,
    /// Owning index id; `None` from `raw_to_code_chunk` and omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub index_id: Option<String>,
    /// Branch flag; `raw_to_code_chunk` sets it to `false`.
    /// NOTE(review): presumably set when the file is in `SearchQuery::branch_files` — confirm.
    #[serde(default)]
    pub on_branch: bool,
}
/// A search request.
///
/// Only `text` is required when deserializing; every other field has a serde default that
/// matches `SearchQuery::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQuery {
    /// Query text.
    pub text: String,
    /// Number of results requested; defaults to 10 (`default_top_k`).
    #[serde(default = "default_top_k")]
    pub top_k: usize,
    /// Defaults to `true`. NOTE(review): presumably enables symbol-graph expansion
    /// (`KG_EXPAND_HOPS`) — confirm.
    #[serde(default = "default_true")]
    pub expand_graph: bool,
    /// Defaults to `true`. NOTE(review): presumably controls whether `compact_snippet` is
    /// filled in on results — confirm.
    #[serde(default = "default_true")]
    pub compact: bool,
    /// Optional list of files associated with the current branch; omitted from output when
    /// `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub branch_files: Option<Vec<String>>,
    /// Boost factor for branch files; defaults to `SearchQuery::default_branch_boost()` (1.5).
    #[serde(default = "SearchQuery::default_branch_boost")]
    pub branch_boost: f32,
    /// Optional branch name; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub branch: Option<String>,
}
impl SearchQuery {
    /// Default value of [`SearchQuery::branch_boost`]: `1.5`.
    ///
    /// Used both as the serde field default and by `SearchQuery::default()`.
    pub fn default_branch_boost() -> f32 {
        const BRANCH_BOOST: f32 = 1.5;
        BRANCH_BOOST
    }
}
impl Default for SearchQuery {
fn default() -> Self {
Self {
text: String::new(),
top_k: default_top_k(),
expand_graph: true,
compact: true,
branch_files: None,
branch_boost: SearchQuery::default_branch_boost(),
branch: None,
}
}
}
/// Serde default for `SearchQuery::top_k`: ten results.
fn default_top_k() -> usize {
    const TOP_K: usize = 10;
    TOP_K
}
/// Serde default for opt-out boolean flags (`SearchQuery::expand_graph`, `SearchQuery::compact`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Hashes query text into a `u64` key (presumably the key type of `CodeIndexer::query_cache`).
///
/// Uses `DefaultHasher::new()`, which is deterministic within a single build. Its algorithm is
/// unspecified across Rust releases, so the value must not be persisted.
pub(crate) fn hash_query(query: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(query, &mut state);
    state.finish()
}
/// Returns a short preview of `content`: at most its first seven lines.
///
/// Content with seven or fewer lines is returned unchanged, including any trailing newline and
/// `\r\n` endings. Longer content is cut to its first seven lines joined with `\n`. Line
/// terminators are normalised here because `str::lines` strips both `\n` and `\r\n`.
pub(crate) fn build_compact_snippet(content: &str) -> String {
    /// Number of leading lines kept in a compact snippet.
    const COMPACT_SNIPPET_LINES: usize = 7;

    let mut lines = content.lines();
    // Stop after the preview window rather than splitting the whole chunk into a Vec of lines.
    let head: Vec<&str> = lines.by_ref().take(COMPACT_SNIPPET_LINES).collect();
    if lines.next().is_none() {
        // Nothing beyond the window: keep the content byte-for-byte.
        return content.to_string();
    }
    head.join("\n")
}
/// Converts a [`RawChunk`] into a [`CodeChunk`] carrying the given `score`, `match_reason` and
/// optional `compact_snippet`.
///
/// All content fields are cloned from `raw`. `chunk_depth` saturates at `u8::MAX` instead of
/// wrapping. `index_id` is left `None` and `on_branch` `false` (presumably filled in later by
/// the caller — confirm).
pub(crate) fn raw_to_code_chunk(
    raw: &RawChunk,
    score: f32,
    match_reason: &str,
    compact_snippet: Option<String>,
) -> CodeChunk {
    // Depths beyond 255 do not fit the `u8` field; clamp rather than truncate.
    let depth: u8 = std::convert::TryFrom::try_from(raw.chunk_depth).unwrap_or(u8::MAX);
    let reason = match_reason.to_owned();
    CodeChunk {
        id: raw.id.clone(),
        file: raw.file.clone(),
        language: raw.language.clone(),
        start_line: raw.start_line,
        end_line: raw.end_line,
        content: raw.content.clone(),
        function_name: raw.function_name.clone(),
        score,
        compact_snippet,
        match_reason: reason,
        chunk_type: raw.chunk_type.clone(),
        calls: raw.calls.clone(),
        inherits_from: raw.inherits_from.clone(),
        chunk_depth: depth,
        index_id: None,
        on_branch: false,
    }
}
/// Fills each chunk's `virtual_terms` with the texts of entities whose `line` lies within the
/// chunk's inclusive `[start_line, end_line]` range.
///
/// Terms follow the order of `entities`. Duplicate texts are dropped, keeping the first
/// occurrence. Any previous `virtual_terms` are replaced. Cost is O(chunks × entities).
pub(crate) fn populate_virtual_terms(chunks: &mut [RawChunk], entities: &[RawEntity]) {
    for chunk in chunks.iter_mut() {
        let (first, last) = (chunk.start_line, chunk.end_line);
        let mut already: std::collections::HashSet<&str> = std::collections::HashSet::new();
        chunk.virtual_terms = entities
            .iter()
            .filter(|ent| ent.line >= first && ent.line <= last)
            // Only entities inside the range reach the de-duplication set.
            .filter(|ent| already.insert(ent.text.as_str()))
            .map(|ent| ent.text.clone())
            .collect();
    }
}
/// Score multiplier based on file type.
///
/// Paths ending (ASCII case-insensitively) in a documentation/config extension get `0.5`;
/// everything else gets `1.0`. The suffix is compared in place, without building a lowercase
/// copy of the path.
pub(crate) fn file_type_score_multiplier(path: &str) -> f32 {
    const DOC_EXTENSIONS: &[&str] = &[".md", ".txt", ".toml", ".yaml", ".yml", ".json"];
    let bytes = path.as_bytes();
    let is_doc = DOC_EXTENSIONS.iter().any(|ext| {
        bytes.len() >= ext.len()
            && bytes[bytes.len() - ext.len()..].eq_ignore_ascii_case(ext.as_bytes())
    });
    if is_doc {
        0.5
    } else {
        1.0
    }
}
/// Labels how a result was found.
///
/// The inputs say whether the result appeared in the vector hits (`in_v`), the BM25 hits
/// (`in_b`) and/or the knowledge-graph expansion (`in_kg`). `in_kg` only matters when neither
/// of the other two matched.
pub(crate) fn compute_match_reason(in_v: bool, in_b: bool, in_kg: bool) -> &'static str {
    if in_v {
        if in_b {
            "hybrid"
        } else {
            "vector"
        }
    } else if in_b {
        "bm25"
    } else if in_kg {
        "hybrid+kg"
    } else {
        "fallback"
    }
}
/// Serializable snapshot of an index's chunks and entities.
///
/// NOTE(review): presumably written and read by the `persist` module, with `version` guarding
/// format changes — confirm. Do not reorder fields: order can be significant for positional
/// serialization formats.
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct ChunkSnapshot {
    /// Snapshot format version (semantics defined by the reader/writer, not visible here).
    pub(crate) version: u32,
    /// All chunks in the snapshot.
    pub(crate) chunks: Vec<RawChunk>,
    /// Entities grouped by a string key (presumably the file path, as in
    /// `ParsedBatch::entities_by_file` — confirm).
    pub(crate) entities: Vec<(String, Vec<RawEntity>)>,
}
/// Output of the parse/embed stage for a batch of files.
///
/// NOTE(review): presumably consumed by a commit step that reports [`CommitTimings`], with
/// `embeddings` parallel to `chunks` (one slot per chunk, `None` when no vector was produced).
/// Confirm in the `ingest` module.
///
/// `Debug` is derived like on the other public types here. `RawChunk` and `RawEntity` are
/// already `Debug`, because `ChunkSnapshot` derives it over them.
#[derive(Debug, Default)]
pub struct ParsedBatch {
    /// Chunks produced by parsing.
    pub chunks: Vec<RawChunk>,
    /// Optional embedding per entry (see the alignment note above).
    pub embeddings: Vec<Option<Vec<f32>>>,
    /// Extracted entities grouped by file.
    pub entities_by_file: Vec<(String, Vec<RawEntity>)>,
    /// Time spent parsing (milliseconds, per the `_ms` suffix).
    pub parse_ms: u64,
    /// Time spent embedding (milliseconds, per the `_ms` suffix).
    pub embed_ms: u64,
    /// Number of vectors produced (presumably the count of `Some` entries in `embeddings`).
    pub vector_count: usize,
}
/// Counters and per-stage timings from committing a batch into the index.
///
/// NOTE(review): stage meanings are inferred from field names — confirm in the `ingest` module.
#[derive(Debug, Default, Clone, Copy)]
pub struct CommitTimings {
    /// Number of chunks handled by the commit (presumably).
    pub chunks: usize,
    /// Time spent updating the BM25 index (milliseconds, per the `_ms` suffix).
    pub bm25_ms: u64,
    /// Time spent upserting vectors into the vector store (milliseconds).
    pub vector_upsert_ms: u64,
    /// Time spent updating the knowledge/symbol graph (milliseconds).
    pub kg_ms: u64,
}
/// Index over one source root.
///
/// Holds chunks, extracted entities, BM25 and vector-search state, and a symbol graph. Mutable
/// state sits behind `Arc` plus a lock (tokio `RwLock`, or std `Mutex` for `query_cache`),
/// presumably so concurrent tasks can share it.
pub struct CodeIndexer {
    /// Identifier of this index.
    pub index_id: String,
    /// Root directory being indexed.
    pub root_path: std::path::PathBuf,
    /// Embedding backend; `None` until `with_components` is called.
    pub(super) embedder: Option<Arc<dyn Embedder>>,
    /// Vector store backend; `None` until `with_components` is called.
    pub(super) store: Option<Arc<dyn VectorStore>>,
    /// All chunks, keyed by a string (presumably the chunk id — confirm).
    pub(super) chunks: Arc<RwLock<HashMap<String, RawChunk>>>,
    /// Extracted entities, keyed by a string (presumably the file path — confirm).
    pub(super) entities: Arc<RwLock<HashMap<String, Vec<RawEntity>>>>,
    /// LRU of chunk embeddings, sized by `embedding_cache_cap()` in `CodeIndexer::new`.
    pub(super) chunk_embeddings: Arc<RwLock<LruCache<String, Vec<f32>>>>,
    /// Lexical (BM25) index.
    pub(super) bm25: Arc<RwLock<Bm25Index>>,
    /// LRU of query vectors keyed by a `u64` (presumably `hash_query` output). This is a std
    /// `Mutex`, so its guard must not be held across an `.await`.
    pub(super) query_cache: Arc<Mutex<LruCache<u64, Vec<f32>>>>,
    /// Current symbol graph. Readers clone the inner `Arc` (see `snapshot_symbol_graph`), so the
    /// lock is held only briefly; presumably writers swap in a new `Arc`.
    pub(super) symbol_graph: Arc<RwLock<Arc<SymbolGraph>>>,
    /// Named-entity extractor, obtained via `NerExtractor::try_load()` in `new`.
    pub(super) ner: crate::core::ner::NerExtractor,
    /// Persistence coordination flags, shared via `Arc`.
    pub(super) persist_state: Arc<PersistState>,
    /// Domain-specific terms set via `with_domain_terms` / `set_domain_terms`.
    pub(super) domain_terms: Vec<String>,
}
/// Atomic flags shared (via `Arc`) between the indexer and its persistence logic.
///
/// Both flags start as `false` (derived `Default`). NOTE(review): presumably `in_flight` marks
/// a save in progress and `dirty` records changes made since, or during, the last save —
/// confirm in the `persist` module.
#[derive(Debug, Default)]
pub(crate) struct PersistState {
    /// Presumably set while a persist operation is running.
    pub(crate) in_flight: AtomicBool,
    /// Presumably set when in-memory state has changes not yet persisted.
    pub(crate) dirty: AtomicBool,
}
impl CodeIndexer {
    /// Creates an indexer for `index_id` rooted at `root_path`, with empty in-memory state.
    ///
    /// No embedder or vector store is attached; see [`CodeIndexer::with_components`].
    /// - The chunk-embedding LRU is sized by `embedding_cache_cap()` (env `TRUSTY_EMBEDDING_CACHE`).
    /// - The query cache holds `QUERY_CACHE_CAPACITY` entries.
    /// - The NER extractor comes from `NerExtractor::try_load()`.
    pub fn new(index_id: impl Into<String>, root_path: impl Into<std::path::PathBuf>) -> Self {
        // Evaluated at compile time, so a zero capacity is a build error rather than a panic.
        const QUERY_CAP: NonZeroUsize = match NonZeroUsize::new(QUERY_CACHE_CAPACITY) {
            Some(n) => n,
            None => panic!("QUERY_CACHE_CAPACITY must be non-zero"),
        };
        // `embedding_cache_cap` filters out zero, so this cannot fail.
        let embedding_cap = NonZeroUsize::new(embedding_cache_cap())
            .expect("embedding_cache_cap must be non-zero (env var filtered)");

        Self {
            index_id: index_id.into(),
            root_path: root_path.into(),
            embedder: None,
            store: None,
            chunks: Arc::new(RwLock::new(HashMap::new())),
            entities: Arc::new(RwLock::new(HashMap::new())),
            chunk_embeddings: Arc::new(RwLock::new(LruCache::new(embedding_cap))),
            bm25: Arc::new(RwLock::new(Bm25Index::new())),
            query_cache: Arc::new(Mutex::new(LruCache::new(QUERY_CAP))),
            symbol_graph: Arc::new(RwLock::new(Arc::new(SymbolGraph::new()))),
            ner: crate::core::ner::NerExtractor::try_load(),
            persist_state: Arc::new(PersistState::default()),
            domain_terms: Vec::new(),
        }
    }

    /// Builder-style variant of [`CodeIndexer::set_domain_terms`]: replaces the domain terms and
    /// returns the indexer.
    pub fn with_domain_terms(mut self, terms: Vec<String>) -> Self {
        self.set_domain_terms(terms);
        self
    }

    /// Replaces the domain-term list.
    pub fn set_domain_terms(&mut self, terms: Vec<String>) {
        self.domain_terms = terms;
    }

    /// Returns the current symbol graph as a cheap `Arc` clone.
    ///
    /// The read lock is only needed to copy the pointer. The returned snapshot stays valid even
    /// if the graph is later replaced.
    pub async fn snapshot_symbol_graph(&self) -> Arc<SymbolGraph> {
        let guard = self.symbol_graph.read().await;
        Arc::clone(&*guard)
    }

    /// Attaches the embedding backend and vector store, returning the indexer (builder style).
    pub fn with_components(
        mut self,
        embedder: Arc<dyn Embedder>,
        store: Arc<dyn VectorStore>,
    ) -> Self {
        self.embedder = Some(embedder);
        self.store = Some(store);
        self
    }
}