use crate::search::hnsw::{HNSWIndex, HNSWParams};
use crate::search::quantization::int8_hnsw::{Int8HnswIndex, Int8HnswParams};
use crate::search::query::{MAX_EMBEDDING_DIMENSION, MIN_EMBEDDING_DIMENSION};
use crate::search::ranking::{HybridScorer, Score};
use crate::search::vector::VectorIndex;
use lru::LruCache;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::num::NonZeroUsize;
/// Embedding width used by `SearchEngine::new` for its brute-force vector index.
pub const DEFAULT_EMBEDDING_DIMENSION: usize = 768;
/// Maximum number of nodes the engine will hold; `index_nodes` and
/// `incremental_reindex` panic when this is exceeded.
pub const MAX_NODES: usize = 1_000_000;
/// Node cap used by `IndexingAdmissionGate::new`.
pub const INDEXING_BATCH_NODE_CAP: usize = 50_000;
/// Byte cap (512 MiB) used by `IndexingAdmissionGate::new`.
pub const INDEXING_BATCH_BYTE_CAP: usize = 512 * 1024 * 1024;
/// Entry capacity used by `WorkHoister::new`.
pub const WORK_HOISTER_MAX_ENTRIES: usize = 4_096;
/// Byte budget (8 MiB) used by `WorkHoister::new`.
pub const WORK_HOISTER_MAX_BYTES: usize = 8 * 1024 * 1024;
/// Outcome of `ContentPruner::evaluate` for a single node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PruningDecision {
/// The node passed every check.
Keep,
/// The file path matched a generated-code suffix or path pattern; the
/// payload is a human-readable reason naming the matched entry.
GeneratedCode(String),
/// Both the content and the symbol name were too short; the payload is a
/// human-readable reason with the measured sizes.
LowInformation(String),
}
/// Path- and size-based filter that flags nodes not worth indexing.
///
/// Paths are compared after ASCII lower-casing, so all configured suffixes
/// and patterns are stored in lower case.
#[derive(Debug, Clone)]
pub struct ContentPruner {
    /// File-name endings that mark a file as generated (e.g. `.min.js`).
    generated_suffixes: Vec<String>,
    /// Substrings that mark a path as generated (e.g. `node_modules`).
    generated_path_patterns: Vec<String>,
    /// Content shorter than this many bytes counts as "small".
    min_content_bytes: usize,
}

impl Default for ContentPruner {
    /// Same as [`ContentPruner::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl ContentPruner {
    /// Builds a pruner with the built-in suffix/pattern lists and a 16-byte
    /// minimum content size.
    pub fn new() -> Self {
        let suffixes = [
            ".min.js",
            ".min.css",
            ".pb.go",
            ".generated.rs",
            ".bundle.js",
            ".chunk.js",
            "_pb2.py",
            ".d.ts",
        ];
        let patterns = [
            "node_modules",
            "/vendor/",
            "/generated/",
            "/autogenerated/",
            "\\vendor\\",
        ];
        Self {
            generated_suffixes: suffixes.iter().map(|s| s.to_string()).collect(),
            generated_path_patterns: patterns.iter().map(|p| p.to_string()).collect(),
            min_content_bytes: 16,
        }
    }

    /// Classifies one node.
    ///
    /// Checks run in this order: generated suffix, generated path pattern,
    /// then the low-information test (content below the byte minimum *and*
    /// symbol name shorter than three bytes). The first hit wins; otherwise
    /// the node is kept.
    pub fn evaluate(&self, file_path: &str, content: &str, symbol_name: &str) -> PruningDecision {
        let lowered = file_path.to_ascii_lowercase();

        if let Some(suffix) = self.matching_suffix(&lowered) {
            return PruningDecision::GeneratedCode(format!(
                "file ends with generated suffix '{}'",
                suffix
            ));
        }
        if let Some(pattern) = self.matching_pattern(&lowered) {
            return PruningDecision::GeneratedCode(format!(
                "file path contains generated pattern '{}'",
                pattern
            ));
        }

        let tiny_content = content.len() < self.min_content_bytes;
        let tiny_symbol = symbol_name.len() < 3;
        if tiny_content && tiny_symbol {
            return PruningDecision::LowInformation(format!(
                "content {} bytes < min {}, symbol '{}' < 3 chars",
                content.len(),
                self.min_content_bytes,
                symbol_name
            ));
        }

        PruningDecision::Keep
    }

    /// Returns `true` when the path matches any generated suffix or pattern.
    pub fn is_generated_path(&self, file_path: &str) -> bool {
        let lowered = file_path.to_ascii_lowercase();
        self.matching_suffix(&lowered).is_some() || self.matching_pattern(&lowered).is_some()
    }

    /// First configured suffix that the (already lower-cased) path ends with.
    fn matching_suffix(&self, lowered_path: &str) -> Option<&str> {
        self.generated_suffixes
            .iter()
            .map(String::as_str)
            .find(|suffix| lowered_path.ends_with(*suffix))
    }

    /// First configured pattern contained in the (already lower-cased) path.
    fn matching_pattern(&self, lowered_path: &str) -> Option<&str> {
        self.generated_path_patterns
            .iter()
            .map(String::as_str)
            .find(|pattern| lowered_path.contains(*pattern))
    }
}
/// Per-batch admission control that caps how many nodes and how many content
/// bytes are accepted; everything beyond the caps is counted as shed.
#[derive(Debug)]
pub struct IndexingAdmissionGate {
    /// Maximum number of nodes that may be admitted.
    node_cap: usize,
    /// Maximum total content bytes that may be admitted.
    byte_cap: usize,
    /// Nodes admitted so far.
    nodes_admitted: usize,
    /// Content bytes admitted so far.
    bytes_admitted: usize,
    /// Nodes rejected so far.
    nodes_shed: usize,
}

impl IndexingAdmissionGate {
    /// Creates a gate with the default caps (`INDEXING_BATCH_NODE_CAP`,
    /// `INDEXING_BATCH_BYTE_CAP`).
    pub fn new() -> Self {
        Self::with_caps(INDEXING_BATCH_NODE_CAP, INDEXING_BATCH_BYTE_CAP)
    }

    /// Creates a gate with explicit node and byte caps and zeroed counters.
    pub fn with_caps(node_cap: usize, byte_cap: usize) -> Self {
        Self {
            node_cap,
            byte_cap,
            nodes_admitted: 0,
            bytes_admitted: 0,
            nodes_shed: 0,
        }
    }

    /// Tries to admit one node of `content_bytes` bytes.
    ///
    /// Returns `true` and updates the admitted counters when both caps still
    /// hold after admission. Otherwise increments the shed counter and
    /// returns `false`.
    pub fn try_admit(&mut self, content_bytes: usize) -> bool {
        // `checked_add` keeps a huge `content_bytes` from overflowing the
        // running total; an overflowing request can never fit, so it is shed.
        let projected_bytes = match self.bytes_admitted.checked_add(content_bytes) {
            Some(total) if self.nodes_admitted < self.node_cap && total <= self.byte_cap => total,
            _ => {
                self.nodes_shed += 1;
                return false;
            }
        };
        self.nodes_admitted += 1;
        self.bytes_admitted = projected_bytes;
        true
    }

    /// Number of nodes admitted since construction or the last `reset`.
    pub fn nodes_admitted(&self) -> usize {
        self.nodes_admitted
    }

    /// Number of nodes rejected since construction or the last `reset`.
    pub fn nodes_shed(&self) -> usize {
        self.nodes_shed
    }

    /// Total content bytes admitted since construction or the last `reset`.
    pub fn bytes_admitted(&self) -> usize {
        self.bytes_admitted
    }

    /// Zeroes all counters; the caps are left unchanged.
    pub fn reset(&mut self) {
        self.nodes_admitted = 0;
        self.bytes_admitted = 0;
        self.nodes_shed = 0;
    }
}

impl Default for IndexingAdmissionGate {
    /// Same as [`IndexingAdmissionGate::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Cached embeddings for one piece of content plus their accounted size.
struct HoistedWork {
    /// TF-IDF embedding stored for the content.
    embedding: Vec<f32>,
    /// Optional neural embedding stored for the content.
    neural_embedding: Option<Vec<f32>>,
    /// Bytes charged against the hoister's budget for this entry.
    byte_size: usize,
}

/// LRU cache of embeddings keyed by the BLAKE3 hash of the content, bounded
/// both by entry count and by an approximate byte budget.
pub struct WorkHoister {
    cache: LruCache<blake3::Hash, HoistedWork>,
    /// Sum of `byte_size` over all entries currently in `cache`.
    tracked_bytes: usize,
    /// Byte budget; `tracked_bytes` never exceeds it after `store` returns.
    max_bytes: usize,
}

impl WorkHoister {
    /// Creates a hoister with the default bounds (`WORK_HOISTER_MAX_ENTRIES`,
    /// `WORK_HOISTER_MAX_BYTES`).
    pub fn new() -> Self {
        Self {
            cache: LruCache::new(NonZeroUsize::new(WORK_HOISTER_MAX_ENTRIES).unwrap()),
            tracked_bytes: 0,
            max_bytes: WORK_HOISTER_MAX_BYTES,
        }
    }

    /// Creates a hoister with explicit bounds. `max_entries` is raised to at
    /// least one because the underlying cache needs a non-zero capacity.
    pub fn with_bounds(max_entries: usize, max_bytes: usize) -> Self {
        Self {
            cache: LruCache::new(NonZeroUsize::new(max_entries.max(1)).unwrap()),
            tracked_bytes: 0,
            max_bytes,
        }
    }

    /// Returns clones of the embeddings cached for `content`, if any, and
    /// marks the entry as most recently used.
    pub fn lookup(&mut self, content: &str) -> Option<(Vec<f32>, Option<Vec<f32>>)> {
        let hash = blake3::hash(content.as_bytes());
        self.cache.get(&hash).map(|entry| {
            (entry.embedding.clone(), entry.neural_embedding.clone())
        })
    }

    /// Caches the embeddings for `content`, replacing any previous entry for
    /// the same content and evicting least-recently-used entries until the
    /// byte budget holds.
    ///
    /// An entry that is larger than the whole byte budget is not stored; any
    /// older entry for the same content is dropped and the rest of the cache
    /// is left untouched.
    pub fn store(&mut self, content: &str, embedding: Vec<f32>, neural_embedding: Option<Vec<f32>>) {
        let hash = blake3::hash(content.as_bytes());
        let float_bytes = std::mem::size_of::<f32>();
        let neural_byte_size = neural_embedding
            .as_ref()
            .map_or(0, |v| v.len() * float_bytes);
        // 32 bytes are charged on top of the vector payloads (the size of a
        // BLAKE3 hash key).
        let byte_size = 32 + embedding.len() * float_bytes + neural_byte_size;

        // The entry for this key is superseded either way, so release it first.
        if let Some(previous) = self.cache.pop(&hash) {
            self.tracked_bytes = self.tracked_bytes.saturating_sub(previous.byte_size);
        }

        // An entry that can never fit must not flush everything else.
        if byte_size > self.max_bytes {
            return;
        }

        // Make room within the byte budget.
        while self.tracked_bytes.saturating_add(byte_size) > self.max_bytes {
            match self.cache.pop_lru() {
                Some((_, evicted)) => {
                    self.tracked_bytes = self.tracked_bytes.saturating_sub(evicted.byte_size);
                }
                None => {
                    // Nothing left to evict: an empty cache holds zero bytes.
                    self.tracked_bytes = 0;
                    break;
                }
            }
        }

        let entry = HoistedWork {
            embedding,
            neural_embedding,
            byte_size,
        };
        // `push` hands back the entry displaced by the entry-count capacity,
        // so its bytes can be released too (`put` would drop it silently).
        if let Some((_, displaced)) = self.cache.push(hash, entry) {
            self.tracked_bytes = self.tracked_bytes.saturating_sub(displaced.byte_size);
        }
        self.tracked_bytes += byte_size;
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// `true` when nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Bytes currently charged against the budget.
    pub fn bytes_used(&self) -> usize {
        self.tracked_bytes
    }

    /// Drops every entry and resets the byte accounting.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.tracked_bytes = 0;
    }
}

impl Default for WorkHoister {
    /// Same as [`WorkHoister::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Vector index backend selected at runtime; every method forwards to the
/// wrapped index.
pub enum VectorIndexImpl {
    /// Wraps a `VectorIndex`.
    BruteForce(VectorIndex),
    /// Wraps a boxed `HNSWIndex`.
    HNSW(Box<HNSWIndex>),
    /// Wraps a boxed `Int8HnswIndex`.
    HNSWQuantized(Box<Int8HnswIndex>),
}

// Expands `$body` once per backend variant with `$idx` bound to the wrapped
// index, so each forwarding method is written a single time.
macro_rules! dispatch_vector_backend {
    ($target:expr, $idx:ident => $body:expr) => {
        match $target {
            VectorIndexImpl::BruteForce($idx) => $body,
            VectorIndexImpl::HNSW($idx) => $body,
            VectorIndexImpl::HNSWQuantized($idx) => $body,
        }
    };
}

impl VectorIndexImpl {
    /// Number of vectors reported by the active backend.
    #[must_use]
    pub fn len(&self) -> usize {
        dispatch_vector_backend!(self, idx => idx.len())
    }

    /// `true` when the active backend reports no vectors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        dispatch_vector_backend!(self, idx => idx.is_empty())
    }

    /// Dimension reported by the active backend.
    #[must_use]
    pub fn dimension(&self) -> usize {
        dispatch_vector_backend!(self, idx => idx.dimension())
    }

    /// Forwards a search for up to `top_k` hits and returns the backend's
    /// `(node_id, score)` pairs unchanged.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
        dispatch_vector_backend!(self, idx => idx.search(query, top_k))
    }

    /// Inserts `vector` under `node_id`.
    ///
    /// # Errors
    /// Any backend error is converted to
    /// [`VectorIndexError::InsertionFailed`] carrying its display text.
    pub fn insert(&mut self, node_id: String, vector: Vec<f32>) -> Result<(), VectorIndexError> {
        dispatch_vector_backend!(self, idx => idx
            .insert(node_id, vector)
            .map_err(|e| VectorIndexError::InsertionFailed(e.to_string())))
    }

    /// Clears the active backend.
    pub fn clear(&mut self) {
        dispatch_vector_backend!(self, idx => idx.clear())
    }

    /// Removes `node_id` from the active backend and returns the backend's
    /// result flag.
    pub fn remove(&mut self, node_id: &str) -> bool {
        dispatch_vector_backend!(self, idx => idx.remove(node_id))
    }

    /// `true` for both HNSW variants, `false` for brute force.
    #[must_use]
    pub fn is_hnsw_enabled(&self) -> bool {
        match self {
            Self::BruteForce(_) => false,
            Self::HNSW(_) | Self::HNSWQuantized(_) => true,
        }
    }

    /// Memory estimate reported by the active backend, in bytes.
    #[must_use]
    pub fn estimated_memory_bytes(&self) -> usize {
        dispatch_vector_backend!(self, idx => idx.estimated_memory_bytes())
    }
}
/// Errors surfaced by `VectorIndexImpl`.
#[derive(Debug, thiserror::Error)]
pub enum VectorIndexError {
/// A backend rejected an insertion; the payload is the backend error's
/// display text.
#[error("Insertion failed: {0}")]
InsertionFailed(String),
/// Another index operation failed. Not constructed in the code visible in
/// this part of the file.
#[error("Index operation failed: {0}")]
IndexOperationFailed(String),
}
/// One indexable code node as handed to `SearchEngine`.
///
/// Serialization and deserialization are implemented by hand so that the
/// legacy `embedding` key is still accepted on input.
#[derive(Debug, Clone)]
pub struct NodeInfo {
/// Unique key used by every engine-side map.
pub node_id: String,
/// Path of the file containing the node.
pub file_path: String,
/// Name of the symbol the node represents.
pub symbol_name: String,
/// Language label copied into search results.
pub language: String,
/// Source text; tokenized at indexing time (when `pre_tokenized` is `None`)
/// and cleared by the engine once indexed.
pub content: String,
/// Range copied verbatim into search results.
pub byte_range: (usize, usize),
/// TF-IDF embedding; inserted into the vector index when non-empty.
pub tfidf_embedding: Vec<f32>,
/// Optional neural embedding.
pub neural_embedding: Option<Vec<f32>>,
/// Complexity value; feeds the structural score as `complexity / 100`, capped at 1.
pub complexity: u32,
/// Signature line; overwritten by the engine from `content` at indexing time.
pub signature: Option<String>,
/// Tokens to index as-is instead of tokenizing `content`.
pub pre_tokenized: Option<Vec<String>>,
}
/// Wire-format mirror of `NodeInfo` used only during deserialization.
///
/// Everything from `tfidf_embedding` downward is optional on input and falls
/// back to its `Default` value.
#[derive(Deserialize)]
struct NodeInfoRepr {
node_id: String,
file_path: String,
symbol_name: String,
language: String,
content: String,
byte_range: (usize, usize),
#[serde(default)]
tfidf_embedding: Vec<f32>,
#[serde(default)]
neural_embedding: Option<Vec<f32>>,
// Also read from the older `embedding` key; used as the TF-IDF embedding
// when `tfidf_embedding` is absent or empty.
#[serde(default, alias = "embedding")]
legacy_embedding: Option<Vec<f32>>,
#[serde(default)]
complexity: u32,
#[serde(default)]
signature: Option<String>,
#[serde(default)]
pre_tokenized: Option<Vec<String>>,
}
impl<'de> Deserialize<'de> for NodeInfo {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let repr = NodeInfoRepr::deserialize(deserializer)?;
let tfidf_embedding = if !repr.tfidf_embedding.is_empty() {
repr.tfidf_embedding
} else if let Some(legacy) = repr.legacy_embedding {
if !legacy.is_empty() {
legacy
} else {
Vec::new()
}
} else {
Vec::new()
};
Ok(Self {
node_id: repr.node_id,
file_path: repr.file_path,
symbol_name: repr.symbol_name,
language: repr.language,
content: repr.content,
byte_range: repr.byte_range,
tfidf_embedding,
neural_embedding: repr.neural_embedding,
complexity: repr.complexity,
signature: repr.signature,
pre_tokenized: repr.pre_tokenized,
})
}
}
impl Serialize for NodeInfo {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
#[derive(Serialize)]
struct NodeInfoNew<'a> {
node_id: &'a str,
file_path: &'a str,
symbol_name: &'a str,
language: &'a str,
content: &'a str,
byte_range: (usize, usize),
tfidf_embedding: &'a [f32],
neural_embedding: &'a Option<Vec<f32>>,
complexity: u32,
signature: &'a Option<String>,
pre_tokenized: &'a Option<Vec<String>>,
}
NodeInfoNew {
node_id: &self.node_id,
file_path: &self.file_path,
symbol_name: &self.symbol_name,
language: &self.language,
content: &self.content,
byte_range: self.byte_range,
tfidf_embedding: &self.tfidf_embedding,
neural_embedding: &self.neural_embedding,
complexity: self.complexity,
signature: &self.signature,
pre_tokenized: &self.pre_tokenized,
}
.serialize(serializer)
}
}
/// Query text prepared once per search so scoring does not re-tokenize it.
struct TextQueryPreprocessed {
    /// The whole query, ASCII-lower-cased.
    query_lower: String,
    /// Distinct ASCII-lower-cased tokens of at least two bytes.
    query_tokens: HashSet<String>,
}

impl TextQueryPreprocessed {
    /// Splits `query` on every non-alphanumeric character, lower-cases the
    /// pieces (ASCII only) and keeps those that are at least two bytes long.
    fn from_query(query: &str) -> Self {
        let mut query_tokens = HashSet::new();
        for piece in query.split(|c: char| !c.is_alphanumeric()) {
            // ASCII lower-casing never changes the byte length, so the length
            // filter can run before the allocation.
            if piece.len() >= 2 {
                query_tokens.insert(piece.to_ascii_lowercase());
            }
        }
        Self {
            query_lower: query.to_ascii_lowercase(),
            query_tokens,
        }
    }
}
/// Parameters of one search request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQuery {
/// Free-text query; tokenized for text matching.
pub query: String,
/// Maximum number of results to return.
pub top_k: usize,
/// Not read by the code visible in this part of the file.
pub token_budget: Option<usize>,
/// Enables the vector-similarity component of the score.
pub semantic: bool,
/// Not read by the code visible in this part of the file.
pub expand_context: bool,
/// Query vector for the vector index when `semantic` is set.
pub query_embedding: Option<Vec<f32>>,
/// Results whose overall score is below this value are dropped.
pub threshold: Option<f32>,
/// Selects a fixed weight preset for the hybrid scorer; `None` uses the
/// scorer's own weights.
pub query_type: Option<crate::search::ranking::QueryType>,
}
/// One ranked hit returned by the engine.
///
/// The optional fields marked `skip_serializing_if` are left out of the
/// serialized form when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// 1-based position in the returned list.
pub rank: usize,
/// Id of the matched node.
pub node_id: String,
/// File path copied from the node.
pub file_path: String,
/// Symbol name copied from the node.
pub symbol_name: String,
/// Left as `None` by the search code visible in this part of the file.
#[serde(skip_serializing_if = "Option::is_none")]
pub symbol_type: Option<String>,
/// Signature copied from the node.
#[serde(skip_serializing_if = "Option::is_none")]
pub signature: Option<String>,
/// Complexity copied from the node.
pub complexity: u32,
/// Left as `None` by the search code visible in this part of the file.
#[serde(skip_serializing_if = "Option::is_none")]
pub caller_count: Option<usize>,
/// Left as `None` by the search code visible in this part of the file.
#[serde(skip_serializing_if = "Option::is_none")]
pub dependency_count: Option<usize>,
/// Language label copied from the node.
pub language: String,
/// Hybrid score; results are ordered by its `overall` component.
pub score: Score,
/// Left as `None` by the search code visible in this part of the file.
pub context: Option<String>,
/// Byte range copied from the node.
pub byte_range: (usize, usize),
}
/// Inverted index from token to the row numbers of the nodes containing it.
#[derive(Debug, Clone)]
pub struct CompactTokenIndex {
    token_rows: HashMap<String, HashSet<u32>>,
}

impl CompactTokenIndex {
    /// Rows indexed under `token`; an empty set when the token is unknown.
    pub fn nodes_for_token(&self, token: &str) -> &HashSet<u32> {
        // One shared empty set lets unknown tokens return a reference without
        // allocating per call.
        static EMPTY: std::sync::OnceLock<HashSet<u32>> = std::sync::OnceLock::new();
        match self.token_rows.get(token) {
            Some(rows) => rows,
            None => EMPTY.get_or_init(HashSet::new),
        }
    }

    /// Number of distinct tokens in the index.
    pub fn token_count(&self) -> usize {
        self.token_rows.len()
    }
}
/// Row-oriented snapshot of node metadata.
#[derive(Debug, Clone)]
pub struct CompactNodeMetadata {
    /// `(node_id, row)` pairs.
    row_map: Vec<(String, u32)>,
    /// Complexity value stored per row.
    complexity_by_row: Vec<u32>,
    /// Token index expressed in row numbers.
    token_index: CompactTokenIndex,
}

impl CompactNodeMetadata {
    /// Row recorded for `node_id`, found by a linear scan of the row map.
    pub fn row_index(&self, node_id: &str) -> Option<u32> {
        self.row_map
            .iter()
            .find_map(|(id, row)| if id == node_id { Some(*row) } else { None })
    }

    /// Complexity stored at `row`, or `None` when the row is out of range.
    pub fn complexity_by_row(&self, row: u32) -> Option<u32> {
        let slot = row as usize;
        self.complexity_by_row.get(slot).copied()
    }

    /// Token index belonging to this snapshot.
    pub fn token_index(&self) -> &CompactTokenIndex {
        &self.token_index
    }

    /// Number of entries in the row map.
    pub fn node_count(&self) -> usize {
        self.row_map.len()
    }
}
/// In-memory hybrid search engine over `NodeInfo` records.
pub struct SearchEngine {
// Indexed nodes; their `content` is cleared once indexing has finished.
nodes: Vec<NodeInfo>,
// Combines the vector, structural and text components into one score.
scorer: HybridScorer,
// Vector index holding the nodes' TF-IDF embeddings.
vector_index: VectorIndexImpl,
// node id -> complexity.
complexity_cache: HashMap<String, u32>,
// token -> ids of the nodes containing it.
text_index: HashMap<String, HashSet<String>>,
// node id -> position in `nodes`.
node_id_to_idx: HashMap<String, usize>,
// node id -> tokens indexed for that node (used to undo indexing on removal).
node_tokens: HashMap<String, HashSet<String>>,
// Result cache keyed by a string built from the query parameters.
search_cache: LruCache<String, Vec<SearchResult>>,
// Estimated bytes currently held by `search_cache`.
search_cache_bytes: usize,
}
/// Entry capacity of the engine's search-result cache.
pub const SEARCH_CACHE_MAX_ENTRIES: usize = 256;
/// Byte budget (16 MiB) of the search-result cache; a result set whose
/// estimated size reaches this value is not cached at all.
pub const SEARCH_CACHE_MAX_BYTES: usize = 16 * 1024 * 1024;
/// Switches and tuning for `SearchEngine::search_staged`.
#[derive(Debug, Clone)]
pub struct StagedRetrievalConfig {
    /// When `false`, staged search falls back to the plain search path.
    pub enabled: bool,
    /// The coarse stage asks the vector index for `top_k * coarse_multiplier` hits.
    pub coarse_multiplier: usize,
}

impl Default for StagedRetrievalConfig {
    /// Disabled, with a coarse multiplier of 5.
    fn default() -> Self {
        let coarse_multiplier = 5;
        Self {
            enabled: false,
            coarse_multiplier,
        }
    }
}

impl StagedRetrievalConfig {
    /// Enabled configuration; a multiplier of zero is raised to one.
    pub fn enabled_with_multiplier(coarse_multiplier: usize) -> Self {
        let coarse_multiplier = if coarse_multiplier == 0 {
            1
        } else {
            coarse_multiplier
        };
        Self {
            enabled: true,
            coarse_multiplier,
        }
    }

    /// Disabled configuration with the default multiplier.
    pub fn disabled() -> Self {
        let mut config = Self::default();
        config.enabled = false;
        config
    }
}
/// Counters reported alongside the results of `SearchEngine::search_staged`.
#[derive(Debug, Clone, Default)]
pub struct StagedRetrievalMetrics {
/// Distinct node ids gathered by the coarse stage (0 on cache hits and when
/// staging is disabled).
pub coarse_candidates: usize,
/// Number of results that survived exact scoring.
pub exact_scored: usize,
/// Number of results actually returned.
pub results_returned: usize,
/// `true` when the staged path was taken.
pub staged_used: bool,
}
/// Limits that an int8 candidate must stay within, relative to the baseline.
#[derive(Debug, Clone)]
pub struct Int8QualityThresholds {
/// Largest allowed absolute drop in NDCG@10.
pub ndcg10_max_drop: f64,
/// Largest allowed relative increase in p50 latency (0.05 = 5%).
pub p50_max_increase: f64,
/// Largest allowed relative increase in p99 latency (0.10 = 10%).
pub p99_max_increase: f64,
}
impl Default for Int8QualityThresholds {
/// 0.01 NDCG@10 drop, 5% p50 increase, 10% p99 increase.
fn default() -> Self {
Self {
ndcg10_max_drop: 0.01,
p50_max_increase: 0.05,
p99_max_increase: 0.10,
}
}
}
/// Result of comparing an int8 candidate against the baseline.
#[derive(Debug, Clone)]
pub struct Int8QualityReport {
/// NDCG@10 of the baseline, as supplied by the caller.
pub baseline_ndcg10: f64,
/// NDCG@10 of the int8 candidate, as supplied by the caller.
pub int8_ndcg10: f64,
/// `baseline_ndcg10 - int8_ndcg10`.
pub ndcg10_drop: f64,
/// `true` when the drop is within the threshold.
pub ndcg10_passed: bool,
/// Baseline p50 latency in nanoseconds.
pub baseline_p50_ns: u64,
/// Int8 p50 latency in nanoseconds.
pub int8_p50_ns: u64,
/// Relative p50 increase; 0.0 when the baseline p50 is zero.
pub p50_increase: f64,
/// `true` when the p50 increase is within the threshold.
pub p50_passed: bool,
/// Baseline p99 latency in nanoseconds.
pub baseline_p99_ns: u64,
/// Int8 p99 latency in nanoseconds.
pub int8_p99_ns: u64,
/// Relative p99 increase; 0.0 when the baseline p99 is zero.
pub p99_increase: f64,
/// `true` when the p99 increase is within the threshold.
pub p99_passed: bool,
/// `true` only when all three checks passed.
pub overall_passed: bool,
}
/// Compares int8 search quality and latency against a baseline using a set
/// of thresholds.
#[derive(Debug, Clone)]
pub struct Int8QualityGate {
    thresholds: Int8QualityThresholds,
}

impl Int8QualityGate {
    /// Creates a gate with explicit thresholds.
    pub fn new(thresholds: Int8QualityThresholds) -> Self {
        Self { thresholds }
    }

    /// Creates a gate with `Int8QualityThresholds::default()`.
    pub fn with_default_thresholds() -> Self {
        Self::new(Int8QualityThresholds::default())
    }

    /// Thresholds this gate evaluates against.
    pub fn thresholds(&self) -> &Int8QualityThresholds {
        &self.thresholds
    }

    /// Binary-relevance NDCG over the first ten returned ids.
    ///
    /// Each relevant hit at 0-based position `p` contributes
    /// `1 / log2(p + 2)`; the ideal ranking places `min(10, |relevant|)`
    /// relevant ids first. Returns 0.0 when there are no relevant ids.
    pub fn ndcg_at_10(returned_ids: &[String], relevant_ids: &HashSet<String>) -> f64 {
        let discount = |position: usize| 1.0 / ((position + 2) as f64).log2();

        let dcg = returned_ids
            .iter()
            .take(10)
            .enumerate()
            .filter(|(_, id)| relevant_ids.contains(*id))
            .fold(0.0_f64, |acc, (position, _)| acc + discount(position));

        let ideal_hits = relevant_ids.len().min(10);
        let idcg = (0..ideal_hits).fold(0.0_f64, |acc, position| acc + discount(position));

        if idcg == 0.0 {
            return 0.0;
        }
        dcg / idcg
    }

    /// Picks the sample at index `floor(len * percentile)`, clamped to the
    /// last element. `sorted_samples` must already be sorted ascending;
    /// returns 0 for an empty slice.
    pub fn latency_percentile(sorted_samples: &[u64], percentile: f64) -> u64 {
        match sorted_samples.len() {
            0 => 0,
            len => {
                let raw_index = (len as f64 * percentile) as usize;
                sorted_samples[raw_index.min(len - 1)]
            }
        }
    }

    /// Builds a full report from the two NDCG@10 values and the raw latency
    /// samples (which need not be sorted).
    ///
    /// A latency check whose baseline percentile is zero counts as a 0.0
    /// increase and therefore passes.
    pub fn evaluate(
        &self,
        baseline_ndcg10: f64,
        int8_ndcg10: f64,
        baseline_latencies_ns: &[u64],
        int8_latencies_ns: &[u64],
    ) -> Int8QualityReport {
        let sorted_copy = |samples: &[u64]| -> Vec<u64> {
            let mut copy = samples.to_vec();
            copy.sort_unstable();
            copy
        };
        let relative_increase = |baseline: u64, candidate: u64| -> f64 {
            if baseline == 0 {
                0.0
            } else {
                (candidate as f64 - baseline as f64) / baseline as f64
            }
        };

        let limits = &self.thresholds;
        let baseline_sorted = sorted_copy(baseline_latencies_ns);
        let int8_sorted = sorted_copy(int8_latencies_ns);

        let ndcg10_drop = baseline_ndcg10 - int8_ndcg10;
        let ndcg10_passed = ndcg10_drop <= limits.ndcg10_max_drop;

        let baseline_p50_ns = Self::latency_percentile(&baseline_sorted, 0.50);
        let int8_p50_ns = Self::latency_percentile(&int8_sorted, 0.50);
        let p50_increase = relative_increase(baseline_p50_ns, int8_p50_ns);
        let p50_passed = p50_increase <= limits.p50_max_increase;

        let baseline_p99_ns = Self::latency_percentile(&baseline_sorted, 0.99);
        let int8_p99_ns = Self::latency_percentile(&int8_sorted, 0.99);
        let p99_increase = relative_increase(baseline_p99_ns, int8_p99_ns);
        let p99_passed = p99_increase <= limits.p99_max_increase;

        Int8QualityReport {
            baseline_ndcg10,
            int8_ndcg10,
            ndcg10_drop,
            ndcg10_passed,
            baseline_p50_ns,
            int8_p50_ns,
            p50_increase,
            p50_passed,
            baseline_p99_ns,
            int8_p99_ns,
            p99_increase,
            p99_passed,
            overall_passed: ndcg10_passed && p50_passed && p99_passed,
        }
    }
}

impl Default for Int8QualityGate {
    /// Same as [`Int8QualityGate::with_default_thresholds`].
    fn default() -> Self {
        Self::with_default_thresholds()
    }
}
/// Verdict derived from an `Int8QualityReport`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Int8PromotionDecision {
    /// Every check passed.
    Promote,
    /// At least one check failed; the payload joins the failure reasons with `"; "`.
    Block(String),
}

impl Int8QualityReport {
    /// Turns the report into a promote/block verdict.
    ///
    /// Failure reasons are listed in the order NDCG@10, p50, p99.
    pub fn promotion_decision(&self) -> Int8PromotionDecision {
        if self.overall_passed {
            return Int8PromotionDecision::Promote;
        }

        // NOTE(review): the limits printed below are literals (0.01, 5%, 10%),
        // equal to the values in `Int8QualityThresholds::default()`. The report
        // does not carry the thresholds it was evaluated with, so reports
        // produced with custom thresholds still print these numbers.
        let checks = [
            (
                self.ndcg10_passed,
                format!("NDCG@10 drop {:.4} > {:.4}", self.ndcg10_drop, 0.01),
            ),
            (
                self.p50_passed,
                format!(
                    "p50 latency increase {:.2}% > {:.2}%",
                    self.p50_increase * 100.0,
                    5.0
                ),
            ),
            (
                self.p99_passed,
                format!(
                    "p99 latency increase {:.2}% > {:.2}%",
                    self.p99_increase * 100.0,
                    10.0
                ),
            ),
        ];

        let reasons: Vec<&str> = checks
            .iter()
            .filter(|(passed, _)| !*passed)
            .map(|(_, message)| message.as_str())
            .collect();
        Int8PromotionDecision::Block(reasons.join("; "))
    }
}
impl SearchEngine {
/// Creates an empty engine: no nodes, a brute-force vector index of
/// `DEFAULT_EMBEDDING_DIMENSION` dimensions, and an empty result cache
/// holding up to `SEARCH_CACHE_MAX_ENTRIES` entries.
#[must_use]
pub fn new() -> Self {
Self {
nodes: Vec::new(),
scorer: HybridScorer::new(),
vector_index: VectorIndexImpl::BruteForce(VectorIndex::new(
DEFAULT_EMBEDDING_DIMENSION,
)),
complexity_cache: HashMap::new(),
text_index: HashMap::new(),
node_id_to_idx: HashMap::new(),
node_tokens: HashMap::new(),
search_cache: LruCache::new(NonZeroUsize::new(SEARCH_CACHE_MAX_ENTRIES).unwrap()),
search_cache_bytes: 0,
}
}
/// Creates an empty engine whose brute-force vector index uses `dimension`
/// components.
///
/// # Panics
/// Panics when `dimension` lies outside
/// `MIN_EMBEDDING_DIMENSION..=MAX_EMBEDDING_DIMENSION`.
#[must_use]
pub fn with_dimension(dimension: usize) -> Self {
    assert!(
        (MIN_EMBEDDING_DIMENSION..=MAX_EMBEDDING_DIMENSION).contains(&dimension),
        "Invalid embedding dimension: {} (must be between {} and {})",
        dimension,
        MIN_EMBEDDING_DIMENSION,
        MAX_EMBEDDING_DIMENSION
    );

    let cache_capacity = NonZeroUsize::new(SEARCH_CACHE_MAX_ENTRIES).unwrap();
    let vector_index = VectorIndexImpl::BruteForce(VectorIndex::new(dimension));
    Self {
        nodes: Vec::new(),
        scorer: HybridScorer::new(),
        vector_index,
        complexity_cache: HashMap::new(),
        text_index: HashMap::new(),
        node_id_to_idx: HashMap::new(),
        node_tokens: HashMap::new(),
        search_cache: LruCache::new(cache_capacity),
        search_cache_bytes: 0,
    }
}
/// Replaces the whole index with `nodes`.
///
/// All derived state (text index, id map, token sets, complexity cache,
/// vector index, result cache) is rebuilt from scratch. For each node the
/// tokens come from `pre_tokenized` when present (used as-is); otherwise
/// `content` is split on non-alphanumeric characters, ASCII-lower-cased, and
/// tokens shorter than two bytes are dropped. After indexing, each node's
/// `signature` is recomputed from its content and the content is cleared.
///
/// Vector-index insertion failures are logged and otherwise ignored.
///
/// # Panics
/// Panics when more than `MAX_NODES` nodes are supplied.
pub fn index_nodes(&mut self, mut nodes: Vec<NodeInfo>) {
    assert!(
        nodes.len() <= MAX_NODES,
        "Cannot index more than {} nodes (provided: {})",
        MAX_NODES,
        nodes.len()
    );

    // Start from a clean slate.
    self.complexity_cache.clear();
    self.text_index.clear();
    self.search_cache.clear();
    self.search_cache_bytes = 0;
    self.node_id_to_idx.clear();
    self.node_tokens.clear();
    self.vector_index.clear();

    // Pass 1: id map, complexity cache and text index.
    for (position, node) in nodes.iter().enumerate() {
        self.node_id_to_idx.insert(node.node_id.clone(), position);
        self.complexity_cache
            .insert(node.node_id.clone(), node.complexity);

        let token_set: HashSet<String> = match &node.pre_tokenized {
            Some(pre_tok) => pre_tok.iter().cloned().collect(),
            None => node
                .content
                .split(|c: char| !c.is_alphanumeric())
                .map(|piece| piece.to_ascii_lowercase())
                .filter(|token| token.len() >= 2)
                .collect(),
        };
        for token in &token_set {
            self.text_index
                .entry(token.clone())
                .or_default()
                .insert(node.node_id.clone());
        }
        self.node_tokens.insert(node.node_id.clone(), token_set);
    }

    // Pass 2: vector index (only nodes that carry an embedding).
    for node in nodes.iter().filter(|n| !n.tfidf_embedding.is_empty()) {
        let inserted = self
            .vector_index
            .insert(node.node_id.clone(), node.tfidf_embedding.clone());
        if let Err(e) = inserted {
            tracing::warn!(
                "Failed to insert TF-IDF embedding for node {}: {:?}",
                node.node_id,
                e
            );
        }
    }

    // Pass 3: keep only the signature line; the full content is no longer needed.
    for node in &mut nodes {
        node.signature = Self::extract_signature_from_content(&node.content);
        node.content.clear();
    }
    self.nodes = nodes;
}
/// Returns the first meaningful line after the first line of `content`.
///
/// The first line is always skipped. Of the remaining lines, the first one
/// that is non-empty after trimming and does not start with `// [` is
/// returned in trimmed form; `None` when no such line exists.
pub fn extract_signature_from_content(content: &str) -> Option<String> {
    const SKIPPED_PREFIXES: [&str; 2] = ["// [No source", "// ["];

    for raw_line in content.lines().skip(1) {
        let candidate = raw_line.trim();
        let is_marker = SKIPPED_PREFIXES
            .iter()
            .any(|prefix| candidate.starts_with(*prefix));
        if candidate.is_empty() || is_marker {
            continue;
        }
        return Some(candidate.to_string());
    }
    None
}
/// Applies a delta to the existing index: removals first, then
/// additions/updates. The result cache is emptied up front because any
/// cached result may be stale afterwards.
///
/// # Panics
/// Panics when the index holds more than `MAX_NODES` nodes after the delta
/// has been applied.
pub fn incremental_reindex(&mut self, delta: TextIndexDelta) {
    self.search_cache_bytes = 0;
    self.search_cache.clear();

    for stale_id in &delta.removed_node_ids {
        self.remove_node_from_index(stale_id);
    }
    for fresh_node in delta.updated_nodes {
        self.add_node_to_index(fresh_node);
    }

    let total = self.nodes.len();
    assert!(
        total <= MAX_NODES,
        "Cannot index more than {} nodes (current: {})",
        MAX_NODES,
        total
    );
}
fn remove_node_from_index(&mut self, node_id: &str) {
let Some(removed_idx) = self.node_id_to_idx.remove(node_id) else {
return; };
if let Some(tokens) = self.node_tokens.remove(node_id) {
for token in tokens {
if let Entry::Occupied(mut entry) = self.text_index.entry(token) {
entry.get_mut().remove(node_id);
if entry.get().is_empty() {
entry.remove();
}
}
}
}
self.complexity_cache.remove(node_id);
self.vector_index.remove(node_id);
if removed_idx < self.nodes.len() {
self.nodes.swap_remove(removed_idx);
if removed_idx < self.nodes.len() {
let swapped_id = self.nodes[removed_idx].node_id.clone();
self.node_id_to_idx.insert(swapped_id, removed_idx);
}
}
}
/// Adds `node` to the index, replacing any node with the same id.
///
/// Tokenization follows the same rules as bulk indexing: `pre_tokenized` is
/// used as-is when present, otherwise `content` is split on non-alphanumeric
/// characters, ASCII-lower-cased and filtered to tokens of at least two
/// bytes. The node's `signature` is recomputed from its content and the
/// content is cleared before the node is stored. Vector-index insertion
/// failures are logged and otherwise ignored.
fn add_node_to_index(&mut self, mut node: NodeInfo) {
    // Replace semantics: removal is a no-op when the id is not indexed yet.
    self.remove_node_from_index(&node.node_id);

    let node_id = node.node_id.clone();

    let token_set: HashSet<String> = match node.pre_tokenized.as_ref() {
        Some(pre_tok) => pre_tok.iter().cloned().collect(),
        None => node
            .content
            .split(|c: char| !c.is_alphanumeric())
            .map(|piece| piece.to_ascii_lowercase())
            .filter(|token| token.len() >= 2)
            .collect(),
    };
    for token in &token_set {
        self.text_index
            .entry(token.clone())
            .or_default()
            .insert(node_id.clone());
    }
    self.node_tokens.insert(node_id.clone(), token_set);

    // The node is appended, so its slot is the current length.
    let slot = self.nodes.len();
    self.node_id_to_idx.insert(node_id.clone(), slot);
    self.complexity_cache
        .insert(node_id.clone(), node.complexity);

    if !node.tfidf_embedding.is_empty() {
        let inserted = self
            .vector_index
            .insert(node_id.clone(), node.tfidf_embedding.clone());
        if let Err(e) = inserted {
            tracing::warn!(
                "Failed to insert TF-IDF embedding for node {}: {:?}",
                node_id,
                e
            );
        }
    }

    node.signature = Self::extract_signature_from_content(&node.content);
    node.content.clear();
    self.nodes.push(node);
}
/// Number of nodes currently stored in the engine.
#[must_use]
pub fn node_count(&self) -> usize {
self.nodes.len()
}
/// Returns `(node_id, tfidf_embedding)` clones for every stored node that
/// has a non-empty TF-IDF embedding, in storage order.
pub fn collect_embeddings(&self) -> Vec<(String, Vec<f32>)> {
    self.nodes
        .iter()
        .filter_map(|node| {
            if node.tfidf_embedding.is_empty() {
                None
            } else {
                Some((node.node_id.clone(), node.tfidf_embedding.clone()))
            }
        })
        .collect()
}
/// `true` when the engine stores no nodes.
pub fn is_empty(&self) -> bool {
self.nodes.is_empty()
}
/// Position of `node_id` in the node list, or `None` when it is not indexed.
pub fn node_index(&self, node_id: &str) -> Option<usize> {
self.node_id_to_idx.get(node_id).copied()
}
/// Number of ids in the id-to-position map.
pub fn live_node_count(&self) -> usize {
self.node_id_to_idx.len()
}
/// `true` when `node_id` is present in the id-to-position map.
pub fn contains_node(&self, node_id: &str) -> bool {
self.node_id_to_idx.contains_key(node_id)
}
/// Clones of all indexed node ids, in the map's (unspecified) iteration order.
pub fn live_node_ids(&self) -> Vec<String> {
self.node_id_to_idx.keys().cloned().collect()
}
/// Complexity cached for `node_id`, or `None` when it is not indexed.
pub fn node_complexity(&self, node_id: &str) -> Option<u32> {
self.complexity_cache.get(node_id).copied()
}
/// Token set recorded for `node_id` at indexing time, if any.
pub fn node_tokens(&self, node_id: &str) -> Option<&HashSet<String>> {
self.node_tokens.get(node_id)
}
/// Ids of the nodes indexed under `token`, if the token is known. The lookup
/// is an exact match; no normalization is applied to `token` here.
pub fn token_lookup(&self, token: &str) -> Option<&HashSet<String>> {
self.text_index.get(token)
}
/// Number of result sets currently held in the search cache.
pub fn search_cache_len(&self) -> usize {
self.search_cache.len()
}
/// Estimated bytes currently accounted to the search cache.
pub fn search_cache_bytes(&self) -> usize {
self.search_cache_bytes
}
/// Builds a row-oriented snapshot of the current index.
///
/// A node's row is its position in the node list. Token postings are
/// translated from node ids to rows; ids missing from the id map are
/// skipped, and every token of the text index is kept even if its row set
/// ends up empty.
pub fn compact_metadata(&self) -> CompactNodeMetadata {
    let row_map: Vec<(String, u32)> = self
        .nodes
        .iter()
        .enumerate()
        .map(|(row, node)| (node.node_id.clone(), row as u32))
        .collect();

    let complexity_by_row: Vec<u32> = self.nodes.iter().map(|node| node.complexity).collect();

    let token_rows: HashMap<String, HashSet<u32>> = self
        .text_index
        .iter()
        .map(|(token, node_ids)| {
            let rows: HashSet<u32> = node_ids
                .iter()
                .filter_map(|node_id| self.node_id_to_idx.get(node_id))
                .map(|&row| row as u32)
                .collect();
            (token.clone(), rows)
        })
        .collect();

    CompactNodeMetadata {
        row_map,
        complexity_by_row,
        token_index: CompactTokenIndex { token_rows },
    }
}
/// Checks that the derived maps agree with the node list.
///
/// Verified, in order:
/// 1. the id map and the node list have the same length;
/// 2. every id-map entry points at an in-range node with that id;
/// 3. every node has a matching complexity-cache entry;
/// 4. every node has a token-set entry;
/// 5. every text-index posting refers to an id in the id map.
///
/// # Errors
/// Returns a description of the first inconsistency found.
pub fn validate_coherence(&self) -> Result<(), String> {
    let node_total = self.nodes.len();
    let mapped_total = self.node_id_to_idx.len();
    if mapped_total != node_total {
        return Err(format!(
            "node_id_to_idx len ({}) != nodes len ({})",
            mapped_total, node_total
        ));
    }

    for (id, &idx) in &self.node_id_to_idx {
        let Some(node) = self.nodes.get(idx) else {
            return Err(format!(
                "node_id_to_idx[{}] = {} >= nodes.len() = {}",
                id, idx, node_total
            ));
        };
        if node.node_id != *id {
            return Err(format!(
                "nodes[{}].node_id = '{}' != node_id_to_idx key '{}'",
                idx, node.node_id, id
            ));
        }
    }

    for node in &self.nodes {
        let cached = self
            .complexity_cache
            .get(&node.node_id)
            .copied()
            .ok_or_else(|| format!("complexity_cache missing entry for {}", node.node_id))?;
        if cached != node.complexity {
            return Err(format!(
                "complexity_cache[{}] = {} != node.complexity = {}",
                node.node_id, cached, node.complexity
            ));
        }
    }

    let untokenized = self
        .nodes
        .iter()
        .find(|node| !self.node_tokens.contains_key(&node.node_id));
    if let Some(node) = untokenized {
        return Err(format!("node_tokens missing entry for {}", node.node_id));
    }

    for (token, node_ids) in &self.text_index {
        let dangling = node_ids
            .iter()
            .find(|id| !self.node_id_to_idx.contains_key(*id));
        if let Some(id) = dangling {
            return Err(format!(
                "text_index token '{}' references non-live node '{}'",
                token, id
            ));
        }
    }

    Ok(())
}
pub fn search(&mut self, query: SearchQuery) -> Result<Vec<SearchResult>, Error> {
if self.nodes.is_empty() {
return Ok(Vec::new());
}
let cache_key = format!(
"{}:{}:{:?}:{}",
query.query, query.top_k, query.threshold, query.semantic
);
if let Some(cached) = self.search_cache.get(&cache_key) {
return Ok(cached.clone());
}
let mut results = Vec::new();
let vector_results: std::collections::HashMap<String, f32> = if query.semantic {
let embedding = if let Some(emb) = query.query_embedding {
Some(emb)
} else {
self.nodes
.iter()
.find_map(|n| {
if n.tfidf_embedding.is_empty() {
None
} else {
Some(&n.tfidf_embedding)
}
})
.cloned()
};
if let Some(emb) = embedding {
self.vector_index
.search(&emb, query.top_k)
.into_iter()
.collect()
} else {
std::collections::HashMap::new()
}
} else {
std::collections::HashMap::new()
};
let text_query = TextQueryPreprocessed::from_query(&query.query);
let candidates = if text_query.query_tokens.is_empty() {
self.nodes.iter().collect::<Vec<_>>()
} else {
let mut candidate_ids: HashSet<&str> = HashSet::new();
for token in &text_query.query_tokens {
if let Some(node_ids) = self.text_index.get(token) {
for node_id in node_ids {
candidate_ids.insert(node_id.as_str());
}
}
}
if candidate_ids.is_empty() && !query.semantic {
return Ok(Vec::new());
}
if candidate_ids.is_empty() {
self.nodes.iter().collect()
} else {
if vector_results.is_empty() {
self.nodes
.iter()
.filter(|node| candidate_ids.contains(node.node_id.as_str()))
.collect()
} else {
self.nodes
.iter()
.filter(|node| {
candidate_ids.contains(node.node_id.as_str())
|| vector_results.contains_key(&node.node_id)
})
.collect()
}
}
};
for node in candidates {
let text_score = self.calculate_text_score_optimized(
&text_query,
&node.node_id,
&node.symbol_name,
&node.file_path,
);
let tfidf_score = if query.semantic {
*vector_results.get(&node.node_id).unwrap_or(&0.0)
} else {
0.0
};
if text_score == 0.0 && !query.semantic && tfidf_score == 0.0 {
continue;
}
let structural_score = (node.complexity as f32 / 100.0).min(1.0);
let neural_score = 0.0;
let score = if let Some(qt) = query.query_type {
match qt {
crate::search::ranking::QueryType::Text => {
self.scorer
.with_weights_hybrid(0.2, 0.05, 0.05, 0.7)
.score_hybrid(tfidf_score, neural_score, structural_score, text_score)
}
crate::search::ranking::QueryType::Semantic => {
self.scorer
.with_weights_hybrid(0.7, 0.1, 0.1, 0.1)
.score_hybrid(tfidf_score, neural_score, structural_score, text_score)
}
crate::search::ranking::QueryType::Structural => {
self.scorer
.with_weights_hybrid(0.3, 0.0, 0.5, 0.2)
.score_hybrid(tfidf_score, neural_score, structural_score, text_score)
}
}
} else {
self.scorer
.score_hybrid(tfidf_score, neural_score, structural_score, text_score)
};
if score.overall > 0.0 {
if let Some(threshold) = query.threshold {
if score.overall < threshold {
continue;
}
}
let signature = node.signature.clone();
results.push(SearchResult {
rank: 0, node_id: node.node_id.clone(),
file_path: node.file_path.clone(),
symbol_name: node.symbol_name.clone(),
symbol_type: None, signature,
complexity: node.complexity,
caller_count: None, dependency_count: None, language: node.language.clone(),
score,
context: None,
byte_range: node.byte_range,
});
}
}
results.sort_by(|a, b| {
b.score
.overall
.partial_cmp(&a.score.overall)
.unwrap_or(std::cmp::Ordering::Equal)
});
let top_k = results.into_iter().take(query.top_k).collect::<Vec<_>>();
let mut final_results = top_k;
for (i, result) in final_results.iter_mut().enumerate() {
result.rank = i + 1;
}
{
let results_bytes = Self::estimate_search_results_bytes(&final_results);
if results_bytes < SEARCH_CACHE_MAX_BYTES {
if let Some(existing) = self.search_cache.get(&cache_key) {
self.search_cache_bytes = self
.search_cache_bytes
.saturating_sub(Self::estimate_search_results_bytes(existing));
}
while self.search_cache_bytes + results_bytes > SEARCH_CACHE_MAX_BYTES
&& !self.search_cache.is_empty()
{
if let Some((_, evicted)) = self.search_cache.pop_lru() {
self.search_cache_bytes = self
.search_cache_bytes
.saturating_sub(Self::estimate_search_results_bytes(&evicted));
}
}
self.search_cache_bytes += results_bytes;
self.search_cache.put(cache_key, final_results.clone());
}
}
Ok(final_results)
}
/// Two-stage ("staged") retrieval: collect a coarse candidate set, then score only
/// those candidates.
///
/// Stage 1 (coarse) builds the candidate set as the union of
/// * every node id stored in `text_index` under any token of the query, and
/// * when `query.semantic` is set and an embedding is supplied, the ids returned by the
///   vector index for `query.top_k * config.coarse_multiplier` (saturating) hits.
///
/// Stage 2 scores each candidate node with the hybrid scorer (text score, vector
/// similarity, complexity-based structural score), applies `query.threshold` when present,
/// sorts by descending overall score, truncates to `query.top_k` and assigns 1-based ranks.
///
/// # Returns
/// The ranked results plus metrics. When `config.enabled` is false the call is delegated
/// to [`Self::search`] and `staged_used` is `false`. On a cache hit `coarse_candidates`
/// is reported as 0 and `exact_scored` equals the number of cached results.
///
/// # Errors
/// Propagates errors from [`Self::search`] on the disabled path. The staged path in this
/// function has no failure return of its own.
pub fn search_staged(
    &mut self,
    query: SearchQuery,
    config: &StagedRetrievalConfig,
) -> Result<(Vec<SearchResult>, StagedRetrievalMetrics), Error> {
    // Staging switched off: fall back to the regular search path.
    if !config.enabled {
        let results = self.search(query)?;
        let count = results.len();
        return Ok((
            results,
            StagedRetrievalMetrics {
                coarse_candidates: 0,
                exact_scored: count,
                results_returned: count,
                staged_used: false,
            },
        ));
    }
    if self.nodes.is_empty() {
        return Ok((
            Vec::new(),
            StagedRetrievalMetrics {
                staged_used: true,
                ..Default::default()
            },
        ));
    }
    // NOTE(review): the key does not include `query.query_embedding`, so two semantic
    // queries with the same text but different embeddings share a cache entry — confirm
    // that callers always derive the embedding from the query text.
    let cache_key = format!(
        "staged:{}:{}:{:?}:{}:{:?}:{:?}",
        query.query, query.top_k, query.threshold, query.semantic, config.coarse_multiplier, query.query_type
    );
    if let Some(cached) = self.search_cache.get(&cache_key) {
        let count = cached.len();
        return Ok((
            cached.clone(),
            StagedRetrievalMetrics {
                coarse_candidates: 0,
                exact_scored: count,
                results_returned: count,
                staged_used: true,
            },
        ));
    }
    let mut metrics = StagedRetrievalMetrics {
        staged_used: true,
        ..Default::default()
    };

    // ---- Stage 1: coarse candidate collection -------------------------------------
    let coarse_top_k = query.top_k.saturating_mul(config.coarse_multiplier);
    let text_query = TextQueryPreprocessed::from_query(&query.query);
    let mut coarse_candidate_ids: HashSet<String> = HashSet::new();
    for token in &text_query.query_tokens {
        if let Some(node_ids) = self.text_index.get(token) {
            coarse_candidate_ids.extend(node_ids.iter().cloned());
        }
    }
    // Vector hits are only consulted for semantic queries that carry an embedding.
    let vector_results: HashMap<String, f32> = match query.query_embedding.as_ref() {
        Some(emb) if query.semantic => {
            let vec_hits = self.vector_index.search(emb, coarse_top_k);
            coarse_candidate_ids.extend(vec_hits.iter().map(|(id, _)| id.clone()));
            vec_hits.into_iter().collect()
        }
        _ => HashMap::new(),
    };
    metrics.coarse_candidates = coarse_candidate_ids.len();
    if coarse_candidate_ids.is_empty() {
        return Ok((Vec::new(), metrics));
    }

    // ---- Stage 2: exact scoring of the candidates ---------------------------------
    // The weight profile depends only on the query, so build the weighted scorer once
    // instead of once per candidate node.
    let weighted_scorer = match query.query_type {
        Some(crate::search::ranking::QueryType::Text) => {
            Some(self.scorer.with_weights_hybrid(0.2, 0.05, 0.05, 0.7))
        }
        Some(crate::search::ranking::QueryType::Semantic) => {
            Some(self.scorer.with_weights_hybrid(0.7, 0.1, 0.1, 0.1))
        }
        Some(crate::search::ranking::QueryType::Structural) => {
            Some(self.scorer.with_weights_hybrid(0.3, 0.0, 0.5, 0.2))
        }
        None => None,
    };
    let mut results = Vec::new();
    for node in &self.nodes {
        if !coarse_candidate_ids.contains(&node.node_id) {
            continue;
        }
        let text_score = self.calculate_text_score_optimized(
            &text_query,
            &node.node_id,
            &node.symbol_name,
            &node.file_path,
        );
        let tfidf_score = if query.semantic {
            *vector_results.get(&node.node_id).unwrap_or(&0.0)
        } else {
            0.0
        };
        // Non-semantic queries need at least some textual signal to be considered.
        if text_score == 0.0 && !query.semantic && tfidf_score == 0.0 {
            continue;
        }
        let structural_score = (node.complexity as f32 / 100.0).min(1.0);
        // No neural signal is available on this path.
        let neural_score = 0.0;
        let score = match &weighted_scorer {
            Some(scorer) => {
                scorer.score_hybrid(tfidf_score, neural_score, structural_score, text_score)
            }
            None => self
                .scorer
                .score_hybrid(tfidf_score, neural_score, structural_score, text_score),
        };
        if score.overall > 0.0 {
            if let Some(threshold) = query.threshold {
                if score.overall < threshold {
                    continue;
                }
            }
            let signature = node.signature.clone();
            results.push(SearchResult {
                rank: 0,
                node_id: node.node_id.clone(),
                file_path: node.file_path.clone(),
                symbol_name: node.symbol_name.clone(),
                symbol_type: None,
                signature,
                complexity: node.complexity,
                caller_count: None,
                dependency_count: None,
                language: node.language.clone(),
                score,
                context: None,
                byte_range: node.byte_range,
            });
        }
    }
    metrics.exact_scored = results.len();
    // Stable sort: ties keep the order in which nodes are stored.
    results.sort_by(|a, b| {
        b.score
            .overall
            .partial_cmp(&a.score.overall)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let mut final_results: Vec<SearchResult> = results.into_iter().take(query.top_k).collect();
    for (i, result) in final_results.iter_mut().enumerate() {
        result.rank = i + 1;
    }
    metrics.results_returned = final_results.len();

    // ---- Cache the outcome, keeping the byte counter in sync with the cache -------
    let results_bytes = Self::estimate_search_results_bytes(&final_results);
    if results_bytes < SEARCH_CACHE_MAX_BYTES {
        // Remove any previous entry for this key up front so its bytes are released
        // exactly once and a later `push` can only report capacity evictions.
        if let Some(previous) = self.search_cache.pop(&cache_key) {
            self.search_cache_bytes = self
                .search_cache_bytes
                .saturating_sub(Self::estimate_search_results_bytes(&previous));
        }
        while self.search_cache_bytes + results_bytes > SEARCH_CACHE_MAX_BYTES
            && !self.search_cache.is_empty()
        {
            if let Some((_, evicted)) = self.search_cache.pop_lru() {
                self.search_cache_bytes = self
                    .search_cache_bytes
                    .saturating_sub(Self::estimate_search_results_bytes(&evicted));
            }
        }
        // An empty cache holds no bytes; this also heals any drift left in the counter.
        if self.search_cache.is_empty() {
            self.search_cache_bytes = 0;
        }
        self.search_cache_bytes += results_bytes;
        // `push` (unlike `put`) hands back the entry displaced by the entry-count cap,
        // so its bytes can be subtracted instead of leaking into the counter.
        if let Some((_, displaced)) = self.search_cache.push(cache_key, final_results.clone()) {
            self.search_cache_bytes = self
                .search_cache_bytes
                .saturating_sub(Self::estimate_search_results_bytes(&displaced));
        }
    }
    Ok((final_results, metrics))
}
/// Computes the text-relevance score of one node for a preprocessed query.
///
/// The score is the fraction of query tokens present in the node's cached token set,
/// plus 0.5 when the lowercased symbol name contains the lowercased query, minus 0.3
/// when either the file path or the symbol name contains "test" (ASCII
/// case-insensitive). The result is clamped to `[0.0, 1.0]`.
///
/// Nodes without an entry in `node_tokens`, and queries without tokens, contribute a
/// token-overlap of 0.0.
fn calculate_text_score_optimized(
    &self,
    precomputed: &TextQueryPreprocessed,
    node_id: &str,
    symbol_name: &str,
    file_path: &str,
) -> f32 {
    let symbol_lower = symbol_name.to_ascii_lowercase();

    // Fraction of query tokens that also occur in this node's token set.
    let overlap = match self.node_tokens.get(node_id) {
        Some(node_tokens) if !precomputed.query_tokens.is_empty() => {
            let shared = precomputed
                .query_tokens
                .iter()
                .filter(|token| node_tokens.contains(*token))
                .count();
            shared as f32 / precomputed.query_tokens.len() as f32
        }
        _ => 0.0,
    };

    let name_bonus = if symbol_lower.contains(&precomputed.query_lower) {
        0.5
    } else {
        0.0
    };

    let looks_like_test =
        symbol_lower.contains("test") || file_path.to_ascii_lowercase().contains("test");
    let test_malus = if looks_like_test { 0.3 } else { 0.0 };

    ((overlap + name_bonus) - test_malus).clamp(0.0, 1.0)
}
/// Returns the `top_k` nearest entries of the vector index for `query_embedding`.
///
/// An empty vector index yields `Ok` with an empty list before any dimension check, so
/// a query of any length is accepted while nothing is indexed.
///
/// Every returned entry is tagged [`EntryType::Function`]: `NodeInfo` carries no symbol
/// kind, so no other variant can be derived here.
///
/// # Errors
/// Returns [`Error::QueryFailed`] when the index is non-empty and the embedding length
/// differs from the index dimension.
pub fn semantic_search(
    &self,
    query_embedding: &[f32],
    top_k: usize,
) -> Result<Vec<SemanticEntry>, Error> {
    if self.vector_index.is_empty() {
        return Ok(Vec::new());
    }
    let expected = self.vector_index.dimension();
    if query_embedding.len() != expected {
        return Err(Error::QueryFailed(format!(
            "Embedding dimension mismatch: expected {}, got {}",
            expected,
            query_embedding.len()
        )));
    }
    // The previous per-hit lookup through `node_id_to_idx`/`nodes` produced
    // `EntryType::Function` on both the found and the not-found path, so it is omitted.
    let entries = self
        .vector_index
        .search(query_embedding, top_k)
        .into_iter()
        .map(|(node_id, score)| SemanticEntry {
            node_id,
            relevance: score,
            entry_type: EntryType::Function,
        })
        .collect();
    Ok(entries)
}
/// Returns a shared reference to the vector index currently used by the engine.
#[must_use]
pub fn vector_index(&self) -> &VectorIndexImpl {
&self.vector_index
}
/// Returns a mutable reference to the vector index currently used by the engine.
pub fn vector_index_mut(&mut self) -> &mut VectorIndexImpl {
&mut self.vector_index
}
/// Replaces the current vector index with an HNSW index of the same dimension.
///
/// `params` falls back to `HNSWParams::default()` when `None`.
///
/// NOTE(review): the new index is built from the dimension and parameters only; nothing
/// in this method copies vectors from the index being replaced — confirm callers
/// re-index afterwards.
pub fn enable_hnsw(&mut self, params: Option<HNSWParams>) {
    let hnsw = HNSWIndex::with_params(self.vector_index.dimension(), params.unwrap_or_default());
    self.vector_index = VectorIndexImpl::HNSW(Box::new(hnsw));
}
/// Reports whether the active vector index is an HNSW variant, either the plain
/// `HNSW` index or the quantized `HNSWQuantized` one.
#[must_use]
pub fn is_hnsw_enabled(&self) -> bool {
matches!(
self.vector_index,
VectorIndexImpl::HNSW(_) | VectorIndexImpl::HNSWQuantized(_)
)
}
/// Replaces the current vector index with a brute-force index of the same dimension.
///
/// NOTE(review): the replacement is constructed from the dimension alone; nothing in
/// this method carries over previously indexed vectors — confirm callers re-index.
pub fn disable_hnsw(&mut self) {
    let brute_force = VectorIndex::new(self.vector_index.dimension());
    self.vector_index = VectorIndexImpl::BruteForce(brute_force);
}
/// Builds an engine for embeddings of `dimension` and immediately switches its vector
/// index to HNSW configured with `params`.
#[must_use]
pub fn with_hnsw(dimension: usize, params: HNSWParams) -> Self {
let mut engine = Self::with_dimension(dimension);
engine.enable_hnsw(Some(params));
engine
}
/// Replaces the current vector index with an int8-quantized HNSW index of the same
/// dimension.
///
/// `params` falls back to `Int8HnswParams::default()` when `None`.
///
/// NOTE(review): as with the other index switches, the new index is built from the
/// dimension and parameters only; no existing vectors are transferred in this method.
pub fn enable_int8_hnsw(&mut self, params: Option<Int8HnswParams>) {
    let quantized =
        Int8HnswIndex::with_params(self.vector_index.dimension(), params.unwrap_or_default());
    self.vector_index = VectorIndexImpl::HNSWQuantized(Box::new(quantized));
}
/// Reports whether the active vector index is the quantized HNSW variant.
#[must_use]
pub fn is_quantized(&self) -> bool {
matches!(self.vector_index, VectorIndexImpl::HNSWQuantized(_))
}
/// Rough estimate of the engine's memory footprint in bytes.
///
/// Sums four parts: the inline size of every stored `NodeInfo`, one `String` + `u32`
/// per complexity-cache entry, one `String` header per id in every text-index posting
/// set, and whatever the vector index reports for itself. Only inline (`size_of`)
/// sizes are counted for the first three parts; string contents on the heap are not
/// included.
#[must_use]
pub fn estimated_memory_bytes(&self) -> usize {
    let string_header = std::mem::size_of::<String>();

    let nodes_bytes = std::mem::size_of::<NodeInfo>() * self.nodes.len();
    let complexity_bytes =
        (string_header + std::mem::size_of::<u32>()) * self.complexity_cache.len();

    let mut postings_bytes = 0usize;
    for posting in self.text_index.values() {
        postings_bytes += string_header * posting.len();
    }

    nodes_bytes + complexity_bytes + postings_bytes + self.vector_index.estimated_memory_bytes()
}
/// Estimates how many bytes a list of search results occupies, for cache accounting.
///
/// Per result this adds the byte lengths of its string fields (optional ones count as
/// 0 when absent) plus a flat 128 bytes — presumably an allowance for the fixed-size
/// part of a result; the constant is not derived from `size_of` here.
fn estimate_search_results_bytes(results: &[SearchResult]) -> usize {
    const PER_RESULT_OVERHEAD: usize = 128;

    let mut total = 0usize;
    for result in results {
        let required = result.node_id.len()
            + result.file_path.len()
            + result.symbol_name.len()
            + result.language.len();
        let optional = result.symbol_type.as_ref().map_or(0, |s| s.len())
            + result.signature.as_ref().map_or(0, |s| s.len())
            + result.context.as_ref().map_or(0, |c| c.len());
        total += required + optional + PER_RESULT_OVERHEAD;
    }
    total
}
}
/// A batch of index changes, passed to `SearchEngine::incremental_reindex`.
///
/// As exercised by the tests in this file, ids in `removed_node_ids` are dropped from
/// the engine and nodes in `updated_nodes` are added, or replace an existing node with
/// the same id.
#[derive(Debug, Default)]
pub struct TextIndexDelta {
/// Ids of nodes to remove; the tests show unknown ids being tolerated.
pub removed_node_ids: Vec<String>,
/// Nodes to insert or update, keyed by their `node_id`.
pub updated_nodes: Vec<NodeInfo>,
}
/// `Default` simply forwards to [`SearchEngine::new`].
impl Default for SearchEngine {
fn default() -> Self {
Self::new()
}
}
/// Kind of code entity a [`SemanticEntry`] refers to.
///
/// In the visible part of this file `semantic_search` only ever produces
/// [`EntryType::Function`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EntryType {
/// A free function.
Function,
/// A method.
Method,
/// A class.
Class,
/// A module.
Module,
}
/// One hit returned by `SearchEngine::semantic_search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticEntry {
/// Id of the matching node, as returned by the vector index.
pub node_id: String,
/// Score reported by the vector index for this hit.
pub relevance: f32,
/// Kind of entity the hit refers to.
pub entry_type: EntryType,
}
/// Errors reported by the search engine.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// A query could not be executed; the payload is a human-readable reason.
/// `semantic_search` uses this variant for an embedding length mismatch.
#[error("Query failed: {0}")]
QueryFailed(String),
/// The index contains nothing to search.
#[error("Index is empty")]
EmptyIndex,
/// A vector had a different dimension than expected.
#[error("Dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch {
/// Dimension that was required.
expected: usize,
/// Dimension that was actually supplied.
got: usize,
},
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_nodes() -> Vec<NodeInfo> {
vec![
NodeInfo {
node_id: "func1".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "func1".to_string(),
language: "rust".to_string(),
content: "fn func1() { println!(\"hello\"); }".to_string(),
byte_range: (0, 40),
tfidf_embedding: vec![1.0, 0.0, 0.0],
neural_embedding: None,
complexity: 2,
signature: None,
pre_tokenized: None,
},
NodeInfo {
node_id: "func2".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "func2".to_string(),
language: "rust".to_string(),
content: "fn func2() { println!(\"world\"); }".to_string(),
byte_range: (42, 82),
tfidf_embedding: vec![0.0, 1.0, 0.0],
neural_embedding: None,
complexity: 2,
signature: None,
pre_tokenized: None,
},
]
}
#[test]
fn test_search_engine_creation() {
let engine = SearchEngine::new();
assert_eq!(engine.node_count(), 0);
assert!(engine.is_empty());
}
#[test]
fn test_index_nodes() {
let mut engine = SearchEngine::new();
let nodes = create_test_nodes();
engine.index_nodes(nodes);
assert_eq!(engine.node_count(), 2);
assert!(!engine.is_empty());
}
#[test]
fn test_search_empty_index() {
let mut engine = SearchEngine::new();
let query = SearchQuery {
query: "test".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_semantic_search_empty_index() {
let engine = SearchEngine::new();
let results = engine.semantic_search(&[0.1, 0.2, 0.3], 10).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_search_with_results() {
let mut engine = SearchEngine::new();
let nodes = create_test_nodes();
engine.index_nodes(nodes);
let query = SearchQuery {
query: "func1".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "func1");
}
#[test]
fn test_semantic_search() {
let mut engine = SearchEngine::with_dimension(3);
let nodes = create_test_nodes();
engine.index_nodes(nodes);
let results = engine.semantic_search(&[1.0, 0.0, 0.0], 1).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "func1");
}
#[test]
fn test_dimension_validation() {
let engine = SearchEngine::with_dimension(128);
assert_eq!(engine.vector_index().dimension(), 128);
}
#[test]
fn test_dimension_mismatch_error() {
let mut engine = SearchEngine::with_dimension(3);
let nodes = create_test_nodes();
engine.index_nodes(nodes);
let result = engine.semantic_search(&[0.1, 0.2], 10);
assert!(result.is_err());
}
#[test]
fn test_hnsw_enable() {
let mut engine = SearchEngine::with_dimension(128);
engine.enable_hnsw(None);
assert!(engine.vector_index().is_hnsw_enabled());
}
#[test]
fn test_top_k_limit() {
let mut engine = SearchEngine::new();
let nodes = create_test_nodes();
engine.index_nodes(nodes);
let query = SearchQuery {
query: "fn".to_string(),
top_k: 1,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert_eq!(results.len(), 1);
}
#[test]
fn test_relevance_threshold() {
let mut engine = SearchEngine::new();
let nodes = create_test_nodes();
engine.index_nodes(nodes);
let query = SearchQuery {
query: "nonexistent".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: Some(0.5),
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(results.is_empty());
}
#[test]
fn test_node_id_to_idx_populated() {
let mut engine = SearchEngine::new();
let nodes = create_test_nodes();
engine.index_nodes(nodes);
assert_eq!(engine.node_id_to_idx.len(), 2);
assert_eq!(engine.node_id_to_idx.get("func1"), Some(&0));
assert_eq!(engine.node_id_to_idx.get("func2"), Some(&1));
}
#[test]
fn test_node_id_to_idx_o1_lookup_in_semantic_search() {
let mut engine = SearchEngine::with_dimension(3);
let nodes = create_test_nodes();
engine.index_nodes(nodes);
let results = engine.semantic_search(&[1.0, 0.0, 0.0], 10).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "func1");
assert_eq!(results[0].entry_type, EntryType::Function);
for entry in &results {
assert_eq!(entry.entry_type, EntryType::Function);
}
}
#[test]
fn test_node_id_to_idx_cleared_on_reindex() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
assert_eq!(engine.node_id_to_idx.len(), 2);
engine.index_nodes(vec![NodeInfo {
node_id: "new_func".to_string(),
file_path: "new.rs".to_string(),
symbol_name: "new_func".to_string(),
language: "rust".to_string(),
content: "fn new_func() {}".to_string(),
byte_range: (0, 18),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
}]);
assert_eq!(engine.node_id_to_idx.len(), 1);
assert_eq!(engine.node_id_to_idx.get("new_func"), Some(&0));
assert_eq!(engine.node_id_to_idx.get("func1"), None);
}
#[test]
fn test_content_cleared_after_indexing() {
let mut engine = SearchEngine::new();
let nodes = create_test_nodes();
engine.index_nodes(nodes);
for node in &engine.nodes {
assert!(
node.content.is_empty(),
"Node {} content should be cleared after indexing, but got: {:?}",
node.node_id,
node.content
);
}
let query = SearchQuery {
query: "func1".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(
!results.is_empty(),
"Search should still find results via inverted index after content cleared"
);
assert_eq!(results[0].node_id, "func1");
assert!(
!engine.text_index.is_empty(),
"text_index should be populated"
);
assert!(
engine.text_index.contains_key("func1"),
"text_index should contain 'func1' token"
);
assert!(
engine.text_index.contains_key("func2"),
"text_index should contain 'func2' token"
);
}
#[test]
fn test_node_tokens_populated() {
let mut engine = SearchEngine::new();
let nodes = create_test_nodes();
engine.index_nodes(nodes);
assert_eq!(engine.node_tokens.len(), 2);
assert!(engine.node_tokens.contains_key("func1"));
assert!(engine.node_tokens.contains_key("func2"));
let func1_tokens = engine.node_tokens.get("func1").unwrap();
assert!(
func1_tokens.contains("func1"),
"func1 tokens should contain 'func1', got: {:?}",
func1_tokens
);
let func2_tokens = engine.node_tokens.get("func2").unwrap();
assert!(
func2_tokens.contains("func2"),
"func2 tokens should contain 'func2', got: {:?}",
func2_tokens
);
}
#[test]
fn test_node_tokens_cleared_on_reindex() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
assert_eq!(engine.node_tokens.len(), 2);
engine.index_nodes(vec![NodeInfo {
node_id: "new_func".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "new_func".to_string(),
language: "rust".to_string(),
content: "fn new_func() {}".to_string(),
byte_range: (0, 18),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
}]);
assert_eq!(engine.node_tokens.len(), 1);
assert!(engine.node_tokens.contains_key("new_func"));
assert!(!engine.node_tokens.contains_key("func1"));
}
#[test]
fn test_node_tokens_used_in_scoring() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
for node in &engine.nodes {
assert!(node.content.is_empty());
}
let query = SearchQuery {
query: "println hello".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(
!results.is_empty(),
"Search should find results using cached node_tokens even after content is cleared"
);
assert_eq!(results[0].node_id, "func1");
}
#[test]
fn test_incremental_reindex_add_nodes() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
assert_eq!(engine.node_count(), 2);
let delta = TextIndexDelta {
removed_node_ids: vec![],
updated_nodes: vec![NodeInfo {
node_id: "func3".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "func3".to_string(),
language: "rust".to_string(),
content: "fn func3() { db_query(); }".to_string(),
byte_range: (100, 130),
tfidf_embedding: vec![0.0, 0.0, 1.0],
neural_embedding: None,
complexity: 3,
signature: None,
pre_tokenized: None,
}],
};
engine.incremental_reindex(delta);
assert_eq!(engine.node_count(), 3);
assert_eq!(engine.node_id_to_idx.len(), 3);
assert_eq!(engine.node_tokens.len(), 3);
assert_eq!(engine.complexity_cache.len(), 3);
let query = SearchQuery {
query: "func3".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "func3");
assert!(engine.text_index.contains_key("func3"));
assert!(engine.text_index.contains_key("query"));
}
#[test]
fn test_incremental_reindex_remove_nodes() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
assert_eq!(engine.node_count(), 2);
let delta = TextIndexDelta {
removed_node_ids: vec!["func1".to_string()],
updated_nodes: vec![],
};
engine.incremental_reindex(delta);
assert_eq!(engine.node_count(), 1);
assert_eq!(engine.node_id_to_idx.len(), 1);
assert!(!engine.node_id_to_idx.contains_key("func1"));
assert!(engine.node_id_to_idx.contains_key("func2"));
if let Some(ids) = engine.text_index.get("func1") {
assert!(
!ids.contains("func1"),
"func1 should be removed from text_index"
);
}
assert!(!engine.node_tokens.contains_key("func1"));
assert!(engine.node_tokens.contains_key("func2"));
let query = SearchQuery {
query: "func1".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(
results.is_empty(),
"func1 should not be found after removal"
);
}
#[test]
fn test_incremental_reindex_update_existing_node() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
let delta = TextIndexDelta {
removed_node_ids: vec![],
updated_nodes: vec![NodeInfo {
node_id: "func1".to_string(),
file_path: "updated.rs".to_string(),
symbol_name: "func1_renamed".to_string(),
language: "rust".to_string(),
content: "fn func1_renamed() { new_logic(); }".to_string(),
byte_range: (0, 35),
tfidf_embedding: vec![0.5, 0.5, 0.0],
neural_embedding: None,
complexity: 5,
signature: None,
pre_tokenized: None,
}],
};
engine.incremental_reindex(delta);
assert_eq!(engine.node_count(), 2);
assert_eq!(engine.complexity_cache.get("func1"), Some(&5));
assert!(engine.node_tokens.get("func1").unwrap().contains("logic"));
assert!(engine.text_index.contains_key("logic"));
let query = SearchQuery {
query: "new_logic".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "func1");
}
#[test]
fn test_incremental_reindex_combined_add_remove() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
let delta = TextIndexDelta {
removed_node_ids: vec!["func1".to_string()],
updated_nodes: vec![
NodeInfo {
node_id: "func3".to_string(),
file_path: "new.rs".to_string(),
symbol_name: "func3".to_string(),
language: "rust".to_string(),
content: "fn func3() {}".to_string(),
byte_range: (0, 14),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
},
NodeInfo {
node_id: "func4".to_string(),
file_path: "new.rs".to_string(),
symbol_name: "func4".to_string(),
language: "rust".to_string(),
content: "fn func4() { helper(); }".to_string(),
byte_range: (15, 40),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 2,
signature: None,
pre_tokenized: None,
},
],
};
engine.incremental_reindex(delta);
assert_eq!(engine.node_count(), 3);
assert_eq!(engine.node_id_to_idx.len(), 3);
assert!(!engine.node_id_to_idx.contains_key("func1"));
assert!(engine.node_id_to_idx.contains_key("func2"));
assert!(engine.node_id_to_idx.contains_key("func3"));
assert!(engine.node_id_to_idx.contains_key("func4"));
let query = SearchQuery {
query: "func2".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "func2");
}
#[test]
fn test_incremental_reindex_empty_delta() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
let delta = TextIndexDelta {
removed_node_ids: vec![],
updated_nodes: vec![],
};
engine.incremental_reindex(delta);
assert_eq!(engine.node_count(), 2);
assert_eq!(engine.node_id_to_idx.len(), 2);
}
#[test]
fn test_incremental_reindex_removes_empty_token_sets() {
let mut engine = SearchEngine::new();
engine.index_nodes(vec![
NodeInfo {
node_id: "unique1".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "unique1".to_string(),
language: "rust".to_string(),
content: "fn unique1() { zebra(); }".to_string(),
byte_range: (0, 25),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
},
NodeInfo {
node_id: "unique2".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "unique2".to_string(),
language: "rust".to_string(),
content: "fn unique2() { apple(); }".to_string(),
byte_range: (26, 52),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
},
]);
assert!(engine.text_index.contains_key("zebra"));
let delta = TextIndexDelta {
removed_node_ids: vec!["unique1".to_string()],
updated_nodes: vec![],
};
engine.incremental_reindex(delta);
assert!(
!engine.text_index.contains_key("zebra"),
"Token with no remaining nodes should be removed from text_index"
);
assert!(engine.text_index.contains_key("apple"));
}
#[test]
fn test_incremental_reindex_correctness_vs_full_rebuild() {
let mut engine_inc = SearchEngine::new();
let mut engine_full = SearchEngine::new();
let initial = create_test_nodes();
engine_inc.index_nodes(initial.clone());
engine_full.index_nodes(initial);
let delta = TextIndexDelta {
removed_node_ids: vec!["func1".to_string()],
updated_nodes: vec![NodeInfo {
node_id: "func3".to_string(),
file_path: "new.rs".to_string(),
symbol_name: "func3".to_string(),
language: "rust".to_string(),
content: "fn func3() { compute(); }".to_string(),
byte_range: (0, 25),
tfidf_embedding: vec![1.0, 1.0, 0.0],
neural_embedding: None,
complexity: 4,
signature: None,
pre_tokenized: None,
}],
};
engine_inc.incremental_reindex(delta);
engine_full.index_nodes(vec![
NodeInfo {
node_id: "func2".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "func2".to_string(),
language: "rust".to_string(),
content: "fn func2() { println!(\"world\"); }".to_string(),
byte_range: (42, 82),
tfidf_embedding: vec![0.0, 1.0, 0.0],
neural_embedding: None,
complexity: 2,
signature: None,
pre_tokenized: None,
},
NodeInfo {
node_id: "func3".to_string(),
file_path: "new.rs".to_string(),
symbol_name: "func3".to_string(),
language: "rust".to_string(),
content: "fn func3() { compute(); }".to_string(),
byte_range: (0, 25),
tfidf_embedding: vec![1.0, 1.0, 0.0],
neural_embedding: None,
complexity: 4,
signature: None,
pre_tokenized: None,
},
]);
assert_eq!(engine_inc.node_count(), engine_full.node_count());
let inc_ids: std::collections::BTreeSet<_> =
engine_inc.nodes.iter().map(|n| n.node_id.clone()).collect();
let full_ids: std::collections::BTreeSet<_> = engine_full
.nodes
.iter()
.map(|n| n.node_id.clone())
.collect();
assert_eq!(inc_ids, full_ids);
let query = SearchQuery {
query: "func2".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let inc_results = engine_inc.search(query.clone()).unwrap();
let full_results = engine_full.search(query).unwrap();
assert_eq!(inc_results.len(), full_results.len());
if !inc_results.is_empty() {
assert_eq!(inc_results[0].node_id, full_results[0].node_id);
}
let inc_sem = engine_inc.semantic_search(&[1.0, 1.0, 0.0], 10).unwrap();
let full_sem = engine_full.semantic_search(&[1.0, 1.0, 0.0], 10).unwrap();
assert_eq!(inc_sem.len(), full_sem.len());
if !inc_sem.is_empty() {
assert_eq!(inc_sem[0].node_id, full_sem[0].node_id);
}
}
#[test]
fn test_incremental_reindex_semantic_search_after_update() {
let mut engine = SearchEngine::with_dimension(3);
engine.index_nodes(create_test_nodes());
let delta = TextIndexDelta {
removed_node_ids: vec![],
updated_nodes: vec![NodeInfo {
node_id: "func3".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "func3".to_string(),
language: "rust".to_string(),
content: "fn func3() {}".to_string(),
byte_range: (0, 14),
tfidf_embedding: vec![0.1, 0.1, 0.9],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
}],
};
engine.incremental_reindex(delta);
let results = engine.semantic_search(&[0.1, 0.1, 0.9], 1).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "func3");
}
#[test]
fn test_incremental_reindex_node_id_to_idx_consistency() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
engine.incremental_reindex(TextIndexDelta {
removed_node_ids: vec![],
updated_nodes: vec![NodeInfo {
node_id: "func3".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "func3".to_string(),
language: "rust".to_string(),
content: "fn func3() {}".to_string(),
byte_range: (0, 14),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
}],
});
engine.incremental_reindex(TextIndexDelta {
removed_node_ids: vec!["func1".to_string()],
updated_nodes: vec![],
});
assert_eq!(engine.node_id_to_idx.len(), engine.nodes.len());
for (idx, node) in engine.nodes.iter().enumerate() {
assert_eq!(
engine.node_id_to_idx.get(&node.node_id),
Some(&idx),
"node_id_to_idx mismatch for node {}",
node.node_id
);
}
}
#[test]
fn test_incremental_reindex_removes_nonexistent_node() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
let delta = TextIndexDelta {
removed_node_ids: vec!["nonexistent".to_string()],
updated_nodes: vec![],
};
engine.incremental_reindex(delta);
assert_eq!(engine.node_count(), 2);
assert_eq!(engine.node_id_to_idx.len(), 2);
}
#[test]
fn test_incremental_reindex_content_cleared() {
let mut engine = SearchEngine::new();
engine.index_nodes(create_test_nodes());
engine.incremental_reindex(TextIndexDelta {
removed_node_ids: vec![],
updated_nodes: vec![NodeInfo {
node_id: "func3".to_string(),
file_path: "test.rs".to_string(),
symbol_name: "func3".to_string(),
language: "rust".to_string(),
content: "fn func3() { important_content(); }".to_string(),
byte_range: (0, 40),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 3,
signature: None,
pre_tokenized: None,
}],
});
for node in &engine.nodes {
assert!(
node.content.is_empty(),
"Node {} content should be cleared, got: {:?}",
node.node_id,
node.content
);
}
assert!(engine
.node_tokens
.get("func3")
.unwrap()
.contains("important"));
}
#[test]
fn test_pre_tokenized_produces_identical_search_results() {
let content = "fn calculate_total(price: f64, tax: f64) -> f64 { price + tax }";
let search_tokens: Vec<String> = content
.split(|c: char| !c.is_alphanumeric())
.map(|s| s.to_ascii_lowercase())
.filter(|s| s.len() >= 2)
.collect();
let mut engine_pre = SearchEngine::new();
engine_pre.index_nodes(vec![NodeInfo {
node_id: "calc_total".to_string(),
file_path: "math.rs".to_string(),
symbol_name: "calculate_total".to_string(),
language: "rust".to_string(),
content: content.to_string(),
byte_range: (0, content.len()),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 3,
signature: None,
pre_tokenized: Some(search_tokens),
}]);
let mut engine_fallback = SearchEngine::new();
engine_fallback.index_nodes(vec![NodeInfo {
node_id: "calc_total".to_string(),
file_path: "math.rs".to_string(),
symbol_name: "calculate_total".to_string(),
language: "rust".to_string(),
content: content.to_string(),
byte_range: (0, content.len()),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 3,
signature: None,
pre_tokenized: None,
}]);
assert_eq!(
engine_pre.text_index, engine_fallback.text_index,
"Pre-tokenized and fallback should produce identical text_index"
);
assert_eq!(
engine_pre.node_tokens, engine_fallback.node_tokens,
"Pre-tokenized and fallback should produce identical node_tokens"
);
let query = SearchQuery {
query: "calculate".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results_pre = engine_pre.search(query.clone()).unwrap();
let results_fallback = engine_fallback.search(query).unwrap();
assert_eq!(results_pre.len(), results_fallback.len());
assert!(!results_pre.is_empty());
assert_eq!(results_pre[0].node_id, results_fallback[0].node_id);
}
#[test]
fn test_pre_tokenized_none_falls_back_to_content() {
let mut engine = SearchEngine::new();
engine.index_nodes(vec![NodeInfo {
node_id: "backward_compat".to_string(),
file_path: "compat.rs".to_string(),
symbol_name: "legacy_func".to_string(),
language: "rust".to_string(),
content: "fn legacy_func() { return 42; }".to_string(),
byte_range: (0, 30),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 1,
signature: None,
pre_tokenized: None,
}]);
assert!(engine.text_index.contains_key("legacy"));
assert!(engine.text_index.contains_key("func"));
assert!(engine.node_tokens.contains_key("backward_compat"));
let query = SearchQuery {
query: "legacy".to_string(),
top_k: 10,
token_budget: None,
semantic: false,
expand_context: false,
query_embedding: None,
threshold: None,
query_type: None,
};
let results = engine.search(query).unwrap();
assert!(!results.is_empty());
assert_eq!(results[0].node_id, "backward_compat");
}
#[test]
fn test_pre_tokenized_and_content_produce_same_inverted_index() {
let content = "pub async fn handle_http_request(req: Request) -> Response { ... }";
let tokens: Vec<String> = content
.split(|c: char| !c.is_alphanumeric())
.map(|s| s.to_ascii_lowercase())
.filter(|s| s.len() >= 2)
.collect();
assert!(tokens.contains(&"handle".to_string()));
assert!(tokens.contains(&"http".to_string()));
assert!(tokens.contains(&"request".to_string()));
assert!(tokens.contains(&"response".to_string()));
let mut engine_a = SearchEngine::new();
engine_a.index_nodes(vec![NodeInfo {
node_id: "handler".to_string(),
file_path: "server.rs".to_string(),
symbol_name: "handle_http_request".to_string(),
language: "rust".to_string(),
content: content.to_string(),
byte_range: (0, content.len()),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 5,
signature: None,
pre_tokenized: Some(tokens),
}]);
let mut engine_b = SearchEngine::new();
engine_b.index_nodes(vec![NodeInfo {
node_id: "handler".to_string(),
file_path: "server.rs".to_string(),
symbol_name: "handle_http_request".to_string(),
language: "rust".to_string(),
content: content.to_string(),
byte_range: (0, content.len()),
tfidf_embedding: vec![],
neural_embedding: None,
complexity: 5,
signature: None,
pre_tokenized: None,
}]);
for token in &["handle", "http", "request", "response", "pub", "async"] {
assert_eq!(
engine_a.text_index.get(*token),
engine_b.text_index.get(*token),
"Mismatch for token '{}': pre_tokenized={:?}, content={:?}",
token,
engine_a.text_index.get(*token),
engine_b.text_index.get(*token)
);
}
}
/// Incrementally reindexing a node that ships its own `pre_tokenized` terms must
/// leave the engine with three nodes, expose those terms in the text index, and
/// rank the node first for a matching text query.
#[test]
fn test_pre_tokenized_incremental_reindex() {
    let source = "fn compute_metrics(data: &[f64]) -> Metrics { ... }";

    // Split on non-alphanumerics, lowercase, and drop fragments shorter than two bytes.
    let mut pre_tokens: Vec<String> = Vec::new();
    for fragment in source.split(|ch: char| !ch.is_alphanumeric()) {
        let lowered = fragment.to_ascii_lowercase();
        if lowered.len() >= 2 {
            pre_tokens.push(lowered);
        }
    }

    let metrics_node = NodeInfo {
        node_id: "metrics".to_string(),
        file_path: "metrics.rs".to_string(),
        symbol_name: "compute_metrics".to_string(),
        language: "rust".to_string(),
        content: source.to_string(),
        byte_range: (0, source.len()),
        tfidf_embedding: vec![],
        neural_embedding: None,
        complexity: 4,
        signature: None,
        pre_tokenized: Some(pre_tokens),
    };

    let mut engine = SearchEngine::new();
    engine.index_nodes(create_test_nodes());
    engine.incremental_reindex(TextIndexDelta {
        removed_node_ids: vec![],
        updated_nodes: vec![metrics_node],
    });

    assert_eq!(engine.node_count(), 3);
    assert!(engine.text_index.contains_key("compute"));
    assert!(engine.text_index.contains_key("metrics"));
    assert!(engine.text_index.contains_key("data"));

    // Plain text search (no embedding) should surface the freshly indexed node first.
    let hits = engine
        .search(SearchQuery {
            query: "compute metrics".to_string(),
            top_k: 10,
            token_budget: None,
            semantic: false,
            expand_context: false,
            query_embedding: None,
            threshold: None,
            query_type: None,
        })
        .unwrap();
    assert!(!hits.is_empty());
    assert_eq!(hits[0].node_id, "metrics");
}
/// Running many searches with distinct query strings must never push the search
/// cache past `SEARCH_CACHE_MAX_ENTRIES` or `SEARCH_CACHE_MAX_BYTES`, and the
/// engine must still answer queries afterwards.
#[test]
fn test_search_cache_hard_capped() {
    let mut corpus: Vec<NodeInfo> = Vec::with_capacity(50);
    for idx in 0..50 {
        corpus.push(NodeInfo {
            node_id: format!("node_{}", idx),
            file_path: format!("file_{}.rs", idx),
            symbol_name: format!("symbol_{}", idx),
            language: "rust".to_string(),
            content: format!("fn symbol_{}() {{}}", idx),
            byte_range: (0, 16),
            tfidf_embedding: vec![0.0; 768],
            neural_embedding: None,
            complexity: 1,
            signature: None,
            pre_tokenized: None,
        });
    }

    let mut engine = SearchEngine::new();
    engine.index_nodes(corpus);

    // 300 searches with distinct query text; the outcome of each is deliberately ignored.
    for round in 0..300 {
        let _ = engine.search(SearchQuery {
            query: format!("query_{}", round),
            top_k: 10,
            token_budget: None,
            semantic: false,
            expand_context: false,
            query_embedding: None,
            threshold: None,
            query_type: None,
        });
    }

    let cached_entries = engine.search_cache.len();
    assert!(
        cached_entries <= SEARCH_CACHE_MAX_ENTRIES,
        "search cache entries ({}) should not exceed max ({})",
        cached_entries,
        SEARCH_CACHE_MAX_ENTRIES
    );
    let cached_bytes = engine.search_cache_bytes;
    assert!(
        cached_bytes <= SEARCH_CACHE_MAX_BYTES,
        "search cache bytes ({}) should not exceed max ({})",
        cached_bytes,
        SEARCH_CACHE_MAX_BYTES
    );

    // A query naming a real symbol must still produce hits after the cache churn.
    let follow_up = SearchQuery {
        query: "symbol_0".to_string(),
        top_k: 5,
        token_budget: None,
        semantic: false,
        expand_context: false,
        query_embedding: None,
        threshold: None,
        query_type: None,
    };
    let hits = engine.search(follow_up).unwrap();
    assert!(!hits.is_empty(), "search should still return results");
}
/// Two small admissions against caps of 100 nodes / 1024 bytes are both
/// expected to succeed, with nothing shed.
#[test]
fn test_admission_gate_admits_within_bounds() {
    let mut gate = IndexingAdmissionGate::with_caps(100, 1024);
    for _ in 0..2 {
        assert!(gate.try_admit(10));
    }
    assert_eq!(gate.nodes_admitted(), 2);
    assert_eq!(gate.nodes_shed(), 0);
}
/// With a node cap of three, the first three admissions are expected to pass
/// and every later one to be shed and counted.
#[test]
fn test_admission_gate_sheds_over_node_cap() {
    let mut gate = IndexingAdmissionGate::with_caps(3, 1_000_000);
    // Fill the gate up to its node cap.
    for _ in 0..3 {
        assert!(gate.try_admit(10));
    }
    // Anything beyond the cap must be refused.
    for _ in 0..2 {
        assert!(!gate.try_admit(10));
    }
    assert_eq!(gate.nodes_admitted(), 3);
    assert_eq!(gate.nodes_shed(), 2);
}
/// With a 50-byte cap, two 20-byte admissions are expected to fit and a third
/// to be shed, leaving 40 bytes admitted.
#[test]
fn test_admission_gate_sheds_over_byte_cap() {
    let mut gate = IndexingAdmissionGate::with_caps(100, 50);
    for _ in 0..2 {
        assert!(gate.try_admit(20));
    }
    // A third 20-byte request would total 60 bytes, above the 50-byte cap.
    assert!(!gate.try_admit(20));
    assert_eq!(gate.nodes_admitted(), 2);
    assert_eq!(gate.bytes_admitted(), 40);
    assert_eq!(gate.nodes_shed(), 1);
}
/// After `reset()`, a gate that had been saturated is expected to admit again
/// and report counters as if freshly started.
#[test]
fn test_admission_gate_resets() {
    let mut gate = IndexingAdmissionGate::with_caps(2, 100);
    // Saturate the two-node cap, then confirm the next request is refused.
    for _ in 0..2 {
        assert!(gate.try_admit(10));
    }
    assert!(!gate.try_admit(10));

    gate.reset();

    assert!(gate.try_admit(10));
    assert_eq!(gate.nodes_admitted(), 1);
    assert_eq!(gate.nodes_shed(), 0);
}
/// A gate built with `new()` must start with all three counters at zero.
#[test]
fn test_admission_gate_default_caps() {
    let fresh_gate = IndexingAdmissionGate::new();
    assert_eq!(fresh_gate.nodes_admitted(), 0);
    assert_eq!(fresh_gate.nodes_shed(), 0);
    assert_eq!(fresh_gate.bytes_admitted(), 0);
}
/// A single 100-byte request against a 50-byte cap is expected to be shed
/// outright, even though the gate is otherwise empty.
#[test]
fn test_admission_gate_oversized_single_node() {
    let mut gate = IndexingAdmissionGate::with_caps(100, 50);
    let admitted = gate.try_admit(100);
    assert!(!admitted);
    assert_eq!(gate.nodes_shed(), 1);
    assert_eq!(gate.nodes_admitted(), 0);
}
/// Fifty back-to-back requests against a ten-node cap are expected to yield
/// exactly ten admissions and forty sheds, both as observed by the caller and
/// as reported by the gate's own counters.
#[test]
fn test_admission_gate_bursty_workload() {
    let mut gate = IndexingAdmissionGate::with_caps(10, 10_000);

    // Record the outcome of every request, then tally the two kinds.
    let outcomes: Vec<bool> = (0..50).map(|_| gate.try_admit(100)).collect();
    let admitted = outcomes.iter().filter(|&&accepted| accepted).count();
    let shed = outcomes.len() - admitted;

    assert_eq!(admitted, 10);
    assert_eq!(shed, 40);
    assert_eq!(gate.nodes_admitted(), 10);
    assert_eq!(gate.nodes_shed(), 40);
}
/// An ordinary hand-written source file must be classified as `Keep`.
#[test]
fn test_pruner_keeps_user_authored_code() {
    let path = "src/main.rs";
    let body = "fn main() { println!(\"hello\"); }";
    let verdict = ContentPruner::new().evaluate(path, body, "main");
    assert_eq!(verdict, PruningDecision::Keep);
}
/// A path ending in `.min.js` (one of the pruner's generated suffixes) must be
/// classified as `GeneratedCode`.
#[test]
fn test_pruner_prunes_minified_js() {
    let path = "static/app.min.js";
    let body = "var a=1,b=2;function c(){return a+b}";
    let verdict = ContentPruner::new().evaluate(path, body, "c");
    assert!(matches!(verdict, PruningDecision::GeneratedCode(_)));
}
/// A path ending in `.pb.go` (one of the pruner's generated suffixes) must be
/// classified as `GeneratedCode`.
#[test]
fn test_pruner_prunes_generated_protobuf() {
    let path = "proto/user.pb.go";
    let body = "func (m *User) GetName() string { return m.Name }";
    let verdict = ContentPruner::new().evaluate(path, body, "GetName");
    assert!(matches!(verdict, PruningDecision::GeneratedCode(_)));
}
/// A path ending in `.generated.rs` (one of the pruner's generated suffixes)
/// must be classified as `GeneratedCode`.
#[test]
fn test_pruner_prunes_generated_rust() {
    let path = "src/types.generated.rs";
    let body = "pub fn generated_fn() -> i32 { 42 }";
    let verdict = ContentPruner::new().evaluate(path, body, "generated_fn");
    assert!(matches!(verdict, PruningDecision::GeneratedCode(_)));
}
/// A path ending in `.bundle.js` (one of the pruner's generated suffixes) must
/// be classified as `GeneratedCode`.
#[test]
fn test_pruner_prunes_bundle_js() {
    let path = "dist/app.bundle.js";
    let body = "module.exports=function(n){return n+1}";
    let verdict = ContentPruner::new().evaluate(path, body, "anonymous");
    assert!(matches!(verdict, PruningDecision::GeneratedCode(_)));
}
/// A path containing `node_modules` (one of the pruner's generated path
/// patterns) must be classified as `GeneratedCode`.
#[test]
fn test_pruner_prunes_node_modules() {
    let path = "node_modules/lodash/index.js";
    let body = "function debounce(fn, ms) { /* ... */ }";
    let verdict = ContentPruner::new().evaluate(path, body, "debounce");
    assert!(matches!(verdict, PruningDecision::GeneratedCode(_)));
}
/// Nine bytes of content under a one-character symbol name is expected to be
/// classified as `LowInformation`.
#[test]
fn test_pruner_prunes_low_information() {
    let verdict = ContentPruner::new().evaluate("src/x.rs", "fn x() {}", "x");
    assert!(matches!(verdict, PruningDecision::LowInformation(_)));
}
/// Short content is still expected to be kept when its symbol name is a
/// descriptive one such as `compute`.
#[test]
fn test_pruner_keeps_short_content_with_meaningful_name() {
    let verdict = ContentPruner::new().evaluate("src/lib.rs", "fn compute() {}", "compute");
    assert_eq!(verdict, PruningDecision::Keep);
}
/// `is_generated_path` must accept paths matching a generated suffix or path
/// pattern and reject ordinary source paths.
#[test]
fn test_pruner_is_generated_path() {
    let checker = ContentPruner::new();

    // Paths that match a generated suffix or the `node_modules` pattern.
    assert!(checker.is_generated_path("static/app.min.js"));
    assert!(checker.is_generated_path("proto/user.pb.go"));
    assert!(checker.is_generated_path("src/types.generated.rs"));
    assert!(checker.is_generated_path("node_modules/react/index.js"));

    // Ordinary source files must not be flagged.
    assert!(!checker.is_generated_path("src/main.rs"));
    assert!(!checker.is_generated_path("lib/parser.py"));
}
/// Each pruning outcome must be distinguishable by callers, and the non-`Keep`
/// variants must carry a reason string that names the cause.
#[test]
fn test_pruner_decision_is_observable() {
    let pruner = ContentPruner::new();
    let kept = pruner.evaluate("src/main.rs", "fn main() { /* ... */ }", "main");
    let generated = pruner.evaluate("src/types.generated.rs", "fn gen() {}", "gen");
    let low_info = pruner.evaluate("src/x.rs", "fn x() {}", "x");

    assert_eq!(kept, PruningDecision::Keep);

    // The generated-code reason is expected to contain the word "generated".
    if let PruningDecision::GeneratedCode(reason) = &generated {
        assert!(
            reason.contains("generated"),
            "reason should mention generated: {}",
            reason
        );
    } else {
        panic!("expected GeneratedCode, got {:?}", generated);
    }

    // The low-information reason is expected to contain the word "bytes".
    if let PruningDecision::LowInformation(reason) = &low_info {
        assert!(
            reason.contains("bytes"),
            "reason should mention bytes: {}",
            reason
        );
    } else {
        panic!("expected LowInformation, got {:?}", low_info);
    }
}
/// Representative hand-written files across several layouts must all be
/// classified as `Keep`.
#[test]
fn test_pruner_does_not_remove_high_signal_files() {
    // (path, content, symbol name) triples that the pruner must leave alone.
    let high_signal: [(&str, &str, &str); 4] = [
        (
            "src/lib.rs",
            "pub fn connect(db: &Database) -> Result<Connection> { /* ... */ }",
            "connect",
        ),
        (
            "src/api/handlers.rs",
            "async fn handle_request(req: Request) -> Response { /* ... */ }",
            "handle_request",
        ),
        (
            "src/models/user.rs",
            "struct User { name: String, email: String, created_at: DateTime }",
            "User",
        ),
        (
            "app/controllers/application_controller.rb",
            "def index; @items = Item.all; end",
            "index",
        ),
    ];

    let pruner = ContentPruner::new();
    for &(path, body, symbol) in &high_signal {
        let verdict = pruner.evaluate(path, body, symbol);
        assert_eq!(
            verdict,
            PruningDecision::Keep,
            "high-signal file {} should be kept, got {:?}",
            path,
            verdict
        );
    }
}
/// An embedding stored for a piece of content must come back unchanged from a
/// lookup on the same content.
#[test]
fn test_work_hoister_stores_and_retrieves() {
    let source = "fn compute(x: i32) -> i32 { x + 1 }";
    let vector = vec![0.1, 0.2, 0.3];

    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    hoister.store(source, vector.clone(), None);

    let cached = hoister.lookup(source);
    assert!(cached.is_some());
    let cached_vector = cached.unwrap().0;
    assert_eq!(cached_vector, vector);
}
/// Looking up content that was never stored must yield `None`.
#[test]
fn test_work_hoister_miss_for_unseen_content() {
    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    let outcome = hoister.lookup("unseen content");
    assert!(outcome.is_none());
}
/// A lookup with byte-identical content must return the embedding stored for it.
#[test]
fn test_work_hoister_reuses_identical_content() {
    let source = "fn identical() {}";
    let vector = vec![0.5, 0.5, 0.5];

    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    hoister.store(source, vector.clone(), None);

    let cached = hoister.lookup(source);
    assert_eq!(cached.unwrap().0, vector);
}
/// Two different contents must map to their own embeddings without interfering
/// with each other.
#[test]
fn test_work_hoister_distinguishes_different_content() {
    let (alpha_src, alpha_vec) = ("fn alpha() {}", vec![1.0, 0.0, 0.0]);
    let (beta_src, beta_vec) = ("fn beta() {}", vec![0.0, 1.0, 0.0]);

    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    hoister.store(alpha_src, alpha_vec.clone(), None);
    hoister.store(beta_src, beta_vec.clone(), None);

    assert_eq!(hoister.lookup(alpha_src).unwrap().0, alpha_vec);
    assert_eq!(hoister.lookup(beta_src).unwrap().0, beta_vec);
}
/// With an entry cap of three, storing a fourth item is expected to keep the
/// size at three by dropping the earliest stored entry.
#[test]
fn test_work_hoister_evicts_on_entry_cap() {
    let mut hoister = WorkHoister::with_bounds(3, 1_000_000);

    // Fill the hoister exactly to its entry cap.
    for &(text, value) in &[("content_1", 1.0), ("content_2", 2.0), ("content_3", 3.0)] {
        hoister.store(text, vec![value], None);
    }
    assert_eq!(hoister.len(), 3);

    // One more entry must not grow the hoister.
    hoister.store("content_4", vec![4.0], None);
    assert_eq!(hoister.len(), 3);
    assert!(hoister.lookup("content_1").is_none());
    assert!(hoister.lookup("content_4").is_some());
}
/// With a 50-byte budget, storing a second entry is expected to push out the
/// first one while the newcomer stays retrievable.
#[test]
fn test_work_hoister_evicts_on_byte_cap() {
    let mut hoister = WorkHoister::with_bounds(100, 50);

    hoister.store("short", vec![1.0, 2.0, 3.0], None);
    assert_eq!(hoister.len(), 1);

    hoister.store("another", vec![4.0, 5.0, 6.0], None);
    let evicted = hoister.lookup("short");
    assert!(evicted.is_none());
    let survivor = hoister.lookup("another");
    assert!(survivor.is_some());
}
/// `clear()` must leave the hoister empty with zero bytes accounted for.
#[test]
fn test_work_hoister_clear() {
    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    hoister.store("content", vec![1.0], None);
    // Sanity check: the store above actually landed before we clear.
    assert!(!hoister.is_empty());

    hoister.clear();

    assert!(hoister.is_empty());
    assert_eq!(hoister.bytes_used(), 0);
}
/// Indexing a node with an embedding retrieved from the hoister must produce
/// the same semantic-search outcome as indexing it with the original embedding.
#[test]
fn test_work_hoister_preserves_search_results() {
    let source = "fn compute(data: &[f64]) -> f64 { data.iter().sum() }";
    let original = vec![0.1, 0.2, 0.3, 0.4, 0.5];

    // Round-trip the embedding through the hoister.
    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    hoister.store(source, original.clone(), None);
    let hoisted = hoister.lookup(source).unwrap().0;
    assert_eq!(hoisted, original);

    // Same node in both engines; only the provenance of the embedding differs.
    let node_with = |vector| NodeInfo {
        node_id: "compute".to_string(),
        file_path: "math.rs".to_string(),
        symbol_name: "compute".to_string(),
        language: "rust".to_string(),
        content: source.to_string(),
        byte_range: (0, source.len()),
        tfidf_embedding: vector,
        neural_embedding: None,
        complexity: 3,
        signature: None,
        pre_tokenized: None,
    };
    let mut engine_a = SearchEngine::with_dimension(5);
    let mut engine_b = SearchEngine::with_dimension(5);
    engine_a.index_nodes(vec![node_with(original)]);
    engine_b.index_nodes(vec![node_with(hoisted)]);

    let results_a = engine_a
        .semantic_search(&[0.1, 0.2, 0.3, 0.4, 0.5], 5)
        .unwrap();
    let results_b = engine_b
        .semantic_search(&[0.1, 0.2, 0.3, 0.4, 0.5], 5)
        .unwrap();

    assert_eq!(results_a.len(), results_b.len());
    // Lengths are equal at this point, so a first hit in A implies one in B.
    if let Some(top_a) = results_a.first() {
        let top_b = &results_b[0];
        assert_eq!(top_a.node_id, top_b.node_id);
        assert!((top_a.relevance - top_b.relevance).abs() < 1e-6);
    }
}
/// A stored 768-element embedding must be served back by lookup, with exactly
/// one entry held by the hoister.
#[test]
fn test_work_hoister_duplicate_work_suppressed() {
    let source = "fn expensive() { /* lots of work */ }";
    let vector = vec![0.42; 768];

    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    hoister.store(source, vector.clone(), None);

    let cached = hoister.lookup(source);
    assert!(cached.is_some());
    assert_eq!(cached.unwrap().0, vector);
    assert_eq!(hoister.len(), 1);
}
/// Storing a new embedding under already-known content must replace the old
/// one in place: same entry count, and byte accounting adjusted by the
/// difference in embedding sizes.
#[test]
fn test_work_hoister_updates_existing_embedding() {
    let source = "fn update_me() {}";
    let before = vec![1.0, 2.0, 3.0];
    let after = vec![4.0, 5.0];

    let mut hoister = WorkHoister::with_bounds(100, 1_000_000);
    hoister.store(source, before.clone(), None);
    let bytes_before = hoister.bytes_used();
    hoister.store(source, after.clone(), None);

    assert_eq!(hoister.lookup(source).unwrap().0, after);
    assert_eq!(hoister.len(), 1);

    // Expected accounting: drop the old vector's bytes, add the new vector's.
    let f32_bytes = std::mem::size_of::<f32>();
    let expected_bytes = bytes_before - before.len() * f32_bytes + after.len() * f32_bytes;
    assert_eq!(hoister.bytes_used(), expected_bytes);
}
/// Builds a fixed test `NodeInfo` whose only varying part is its TF-IDF
/// embedding, supplied by the caller.
fn make_node_info(tfidf: Vec<f32>) -> NodeInfo {
    let node_id = String::from("test_node");
    let file_path = String::from("test.rs");
    let symbol_name = String::from("test_fn");
    let language = String::from("rust");
    let content = String::from("fn test_fn() {}");
    NodeInfo {
        node_id,
        file_path,
        symbol_name,
        language,
        content,
        byte_range: (0, 16),
        tfidf_embedding: tfidf,
        neural_embedding: None,
        complexity: 1,
        signature: None,
        pre_tokenized: None,
    }
}
/// A payload that only carries the legacy `embedding` key must still
/// deserialize, with its vector surfacing as `tfidf_embedding`.
#[test]
fn test_legacy_payload_remains_readable() {
    let payload = r#"{
"node_id": "legacy_node",
"file_path": "legacy.rs",
"symbol_name": "legacy_fn",
"language": "rust",
"content": "fn legacy_fn() {}",
"byte_range": [0, 16],
"embedding": [0.1, 0.2, 0.3, 0.4],
"neural_embedding": null,
"complexity": 5,
"signature": null,
"pre_tokenized": null
}"#;
    let parsed = serde_json::from_str::<NodeInfo>(payload)
        .expect("Legacy payload must deserialize during compatibility window");
    assert_eq!(parsed.node_id, "legacy_node");
    assert_eq!(parsed.tfidf_embedding, vec![0.1, 0.2, 0.3, 0.4]);
}
/// A freshly serialized `NodeInfo` must expose its vector under the
/// `tfidf_embedding` key and must not also emit the legacy `embedding` key.
#[test]
fn test_new_payload_serializes_only_new_shape() {
    let node = make_node_info(vec![1.0, 2.0, 3.0]);
    let json = serde_json::to_string(&node).expect("Serialization must succeed");
    assert!(
        json.contains("\"tfidf_embedding\""),
        "Serialized output must contain tfidf_embedding field"
    );
    // The opening quote in the needle stops it from matching the `_embedding":`
    // tail of `"tfidf_embedding":` or `"neural_embedding":`, so a hit can only
    // be the legacy key itself. The earlier form of this check also required
    // `"tfidf_embedding":` to be absent, which the assertion above already
    // rules out, so it could never fail.
    let has_legacy_embedding = json.contains("\"embedding\":");
    assert!(
        !has_legacy_embedding,
        "Serialized output must not contain legacy 'embedding' field. JSON: {}",
        json
    );
}
/// When a payload carries both `tfidf_embedding` and the legacy `embedding`
/// key, the deserialized node must take its vector from `tfidf_embedding`.
#[test]
fn test_compat_prefers_tfidf_embedding_over_legacy() {
    let payload = r#"{
"node_id": "dual_node",
"file_path": "dual.rs",
"symbol_name": "dual_fn",
"language": "rust",
"content": "fn dual_fn() {}",
"byte_range": [0, 14],
"tfidf_embedding": [0.5, 0.6, 0.7],
"embedding": [0.1, 0.2, 0.3],
"neural_embedding": null,
"complexity": 3,
"signature": null,
"pre_tokenized": null
}"#;
    let parsed = serde_json::from_str::<NodeInfo>(payload)
        .expect("Dual-shape payload must deserialize");
    assert_eq!(
        parsed.tfidf_embedding,
        vec![0.5, 0.6, 0.7],
        "Must prefer tfidf_embedding when both fields are present"
    );
}
/// The legacy `embedding` vector must be promoted into `tfidf_embedding` both
/// when `tfidf_embedding` is missing and when it is present but empty.
#[test]
fn test_compat_fallback_promotes_legacy_when_needed() {
    // Case 1: only the legacy key is present.
    let legacy_only = r#"{
"node_id": "fallback_node",
"file_path": "fallback.rs",
"symbol_name": "fallback_fn",
"language": "rust",
"content": "fn fallback_fn() {}",
"byte_range": [0, 18],
"embedding": [0.9, 0.8, 0.7],
"neural_embedding": null,
"complexity": 2,
"signature": null,
"pre_tokenized": null
}"#;
    let from_legacy = serde_json::from_str::<NodeInfo>(legacy_only)
        .expect("Legacy-only payload must deserialize");
    assert_eq!(
        from_legacy.tfidf_embedding,
        vec![0.9, 0.8, 0.7],
        "Must promote legacy embedding when tfidf_embedding is absent"
    );

    // Case 2: the new key exists but holds an empty vector.
    let empty_new = r#"{
"node_id": "empty_new_node",
"file_path": "empty.rs",
"symbol_name": "empty_fn",
"language": "rust",
"content": "fn empty_fn() {}",
"byte_range": [0, 14],
"tfidf_embedding": [],
"embedding": [0.4, 0.5, 0.6],
"neural_embedding": null,
"complexity": 1,
"signature": null,
"pre_tokenized": null
}"#;
    let from_empty_new = serde_json::from_str::<NodeInfo>(empty_new)
        .expect("Empty-new + legacy payload must deserialize");
    assert_eq!(
        from_empty_new.tfidf_embedding,
        vec![0.4, 0.5, 0.6],
        "Must promote legacy embedding when tfidf_embedding is empty"
    );
}
/// A payload with neither `tfidf_embedding` nor `embedding` must still
/// deserialize, yielding an empty `tfidf_embedding`.
#[test]
fn test_empty_embeddings_degrade_safely() {
    let payload = r#"{
"node_id": "empty_node",
"file_path": "empty.rs",
"symbol_name": "empty_fn",
"language": "rust",
"content": "fn empty_fn() {}",
"byte_range": [0, 14],
"neural_embedding": null,
"complexity": 0,
"signature": null,
"pre_tokenized": null
}"#;
    let parsed = serde_json::from_str::<NodeInfo>(payload)
        .expect("Payload with no embeddings must deserialize successfully");
    assert!(
        parsed.tfidf_embedding.is_empty(),
        "Must degrade to empty tfidf_embedding, got {:?}",
        parsed.tfidf_embedding
    );
    assert_eq!(parsed.node_id, "empty_node");
}
/// Semantic search must find an indexed node both when the node is built
/// directly and when it has been round-tripped through JSON first.
#[test]
fn test_search_semantics_unchanged_after_dedup() {
    // Part 1: a directly constructed node is indexed and searched.
    let mut direct_engine = SearchEngine::with_dimension(3);
    direct_engine.index_nodes(vec![NodeInfo {
        node_id: "dedup_node".into(),
        file_path: "dedup.rs".into(),
        symbol_name: "dedup_fn".into(),
        language: "rust".into(),
        content: "fn dedup_fn() { compute_value(); }".into(),
        byte_range: (0, 32),
        tfidf_embedding: vec![1.0, 0.0, 0.0],
        neural_embedding: None,
        complexity: 3,
        signature: None,
        pre_tokenized: None,
    }]);
    let direct_hits = direct_engine
        .search(SearchQuery {
            query: "dedup_fn".into(),
            top_k: 10,
            token_budget: None,
            semantic: true,
            expand_context: false,
            query_embedding: Some(vec![0.9, 0.1, 0.0]),
            threshold: None,
            query_type: None,
        })
        .unwrap();
    assert!(
        !direct_hits.is_empty(),
        "Search must return results for indexed node"
    );
    assert_eq!(
        direct_hits[0].node_id, "dedup_node",
        "Search must find the correct node"
    );

    // Part 2: a second node goes through serialize -> deserialize before indexing.
    let original = NodeInfo {
        node_id: "dedup_node_v2".into(),
        file_path: "dedup.rs".into(),
        symbol_name: "dedup_fn_v2".into(),
        language: "rust".into(),
        content: "fn dedup_fn_v2() { compute_other(); }".into(),
        byte_range: (0, 36),
        tfidf_embedding: vec![0.0, 1.0, 0.0],
        neural_embedding: None,
        complexity: 4,
        signature: None,
        pre_tokenized: None,
    };
    let wire = serde_json::to_string(&original).unwrap();
    let revived = serde_json::from_str::<NodeInfo>(&wire).unwrap();
    assert_eq!(revived.tfidf_embedding, original.tfidf_embedding);

    let mut revived_engine = SearchEngine::with_dimension(3);
    revived_engine.index_nodes(vec![revived]);
    let revived_hits = revived_engine
        .search(SearchQuery {
            query: "dedup_fn_v2".into(),
            top_k: 10,
            token_budget: None,
            semantic: true,
            expand_context: false,
            query_embedding: Some(vec![0.0, 0.9, 0.1]),
            threshold: None,
            query_type: None,
        })
        .unwrap();
    assert!(
        !revived_hits.is_empty(),
        "Search must return results after round-trip serialization"
    );
    assert_eq!(
        revived_hits[0].node_id, "dedup_node_v2",
        "Search must find the correct node after round-trip"
    );
}
/// After indexing, every stored node must have empty `content`, yet a text
/// search must still return results.
#[test]
fn test_post_index_content_clearing_preserved() {
    let first = NodeInfo {
        node_id: "clear_node_1".into(),
        file_path: "clear.rs".into(),
        symbol_name: "clear_fn_1".into(),
        language: "rust".into(),
        content: "fn clear_fn_1() { /* some content that should be cleared */ }".into(),
        byte_range: (0, 60),
        tfidf_embedding: vec![1.0, 0.0, 0.0],
        neural_embedding: None,
        complexity: 2,
        signature: Some("fn clear_fn_1()".into()),
        pre_tokenized: None,
    };
    let second = NodeInfo {
        node_id: "clear_node_2".into(),
        file_path: "clear.rs".into(),
        symbol_name: "clear_fn_2".into(),
        language: "rust".into(),
        content: "fn clear_fn_2() { /* more content to be cleared */ }".into(),
        byte_range: (60, 110),
        tfidf_embedding: vec![0.0, 1.0, 0.0],
        neural_embedding: None,
        complexity: 3,
        signature: Some("fn clear_fn_2()".into()),
        pre_tokenized: None,
    };

    let mut engine = SearchEngine::with_dimension(3);
    engine.index_nodes(vec![first, second]);

    // Every node retained by the engine must have had its content emptied.
    for stored in &engine.nodes {
        assert!(
            stored.content.is_empty(),
            "Node {} content should be cleared after indexing, but got: {:?}",
            stored.node_id,
            stored.content
        );
    }

    // Text search must keep working even though the raw content is gone.
    let hits = engine
        .search(SearchQuery {
            query: "clear_fn".into(),
            top_k: 10,
            token_budget: None,
            semantic: false,
            expand_context: false,
            query_embedding: None,
            threshold: None,
            query_type: None,
        })
        .unwrap();
    assert!(
        !hits.is_empty(),
        "Search should still find results via inverted index after content cleared"
    );
}
}