use crate::search::hnsw::{HNSWIndex, HNSWParams};
use crate::search::quantization::int8_hnsw::{Int8HnswIndex, Int8HnswParams};
use crate::search::query::{MAX_EMBEDDING_DIMENSION, MIN_EMBEDDING_DIMENSION};
use crate::search::ranking::{HybridScorer, Score};
use crate::search::vector::VectorIndex;
use lru::LruCache;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::num::NonZeroUsize;
/// Embedding width used by `SearchEngine::new` when no explicit dimension is given.
pub const DEFAULT_EMBEDDING_DIMENSION: usize = 768;
/// Hard cap on how many nodes a single engine instance will index; exceeded -> panic.
pub const MAX_NODES: usize = 1_000_000;
/// Dispatch wrapper over the available vector-index backends. All engine code
/// goes through this enum so the backend can be swapped at runtime
/// (see `enable_hnsw` / `enable_int8_hnsw` / `disable_hnsw`).
pub enum VectorIndexImpl {
/// Exact linear-scan index (presumably O(n) per query — name-based; backend not visible here).
BruteForce(VectorIndex),
/// Approximate HNSW graph index; boxed to keep the enum small.
HNSW(Box<HNSWIndex>),
/// HNSW over int8-quantized vectors; boxed to keep the enum small.
HNSWQuantized(Box<Int8HnswIndex>),
}
impl VectorIndexImpl {
    /// Number of stored vectors, regardless of backend.
    #[must_use]
    pub fn len(&self) -> usize {
        match self {
            Self::BruteForce(idx) => idx.len(),
            Self::HNSW(idx) => idx.len(),
            Self::HNSWQuantized(idx) => idx.len(),
        }
    }

    /// True when the backend holds no vectors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        match self {
            Self::BruteForce(idx) => idx.is_empty(),
            Self::HNSW(idx) => idx.is_empty(),
            Self::HNSWQuantized(idx) => idx.is_empty(),
        }
    }

    /// Embedding dimension the backend was built with.
    #[must_use]
    pub fn dimension(&self) -> usize {
        match self {
            Self::BruteForce(idx) => idx.dimension(),
            Self::HNSW(idx) => idx.dimension(),
            Self::HNSWQuantized(idx) => idx.dimension(),
        }
    }

    /// Nearest-neighbour search: returns up to `top_k` `(node_id, score)` pairs.
    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(String, f32)> {
        match self {
            Self::BruteForce(idx) => idx.search(query, top_k),
            Self::HNSW(idx) => idx.search(query, top_k),
            Self::HNSWQuantized(idx) => idx.search(query, top_k),
        }
    }

    /// Insert a vector keyed by `node_id`, wrapping any backend error into
    /// `VectorIndexError::InsertionFailed`.
    pub fn insert(&mut self, node_id: String, vector: Vec<f32>) -> Result<(), VectorIndexError> {
        // Stringify the backend-specific error once, then wrap it uniformly.
        let outcome = match self {
            Self::BruteForce(idx) => idx.insert(node_id, vector).map_err(|e| e.to_string()),
            Self::HNSW(idx) => idx.insert(node_id, vector).map_err(|e| e.to_string()),
            Self::HNSWQuantized(idx) => idx.insert(node_id, vector).map_err(|e| e.to_string()),
        };
        outcome.map_err(VectorIndexError::InsertionFailed)
    }

    /// Drop every stored vector, keeping the backend and its dimension.
    pub fn clear(&mut self) {
        match self {
            Self::BruteForce(idx) => idx.clear(),
            Self::HNSW(idx) => idx.clear(),
            Self::HNSWQuantized(idx) => idx.clear(),
        }
    }

    /// Remove one vector; returns whatever the backend reports (true on removal).
    pub fn remove(&mut self, node_id: &str) -> bool {
        match self {
            Self::BruteForce(idx) => idx.remove(node_id),
            Self::HNSW(idx) => idx.remove(node_id),
            Self::HNSWQuantized(idx) => idx.remove(node_id),
        }
    }

    /// True for either HNSW variant (plain or quantized).
    #[must_use]
    pub fn is_hnsw_enabled(&self) -> bool {
        matches!(self, Self::HNSW(_) | Self::HNSWQuantized(_))
    }

    /// Backend-reported memory estimate in bytes.
    #[must_use]
    pub fn estimated_memory_bytes(&self) -> usize {
        match self {
            Self::BruteForce(idx) => idx.estimated_memory_bytes(),
            Self::HNSW(idx) => idx.estimated_memory_bytes(),
            Self::HNSWQuantized(idx) => idx.estimated_memory_bytes(),
        }
    }
}
/// Errors surfaced by `VectorIndexImpl` operations.
#[derive(Debug, thiserror::Error)]
pub enum VectorIndexError {
/// A backend `insert` failed; carries the stringified backend error.
#[error("Insertion failed: {0}")]
InsertionFailed(String),
/// Generic index failure. NOTE(review): not constructed anywhere in this file.
#[error("Index operation failed: {0}")]
IndexOperationFailed(String),
}
/// One indexable code node (e.g. a function or symbol) and its metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
// Unique id; used as the key in every engine-side map.
pub node_id: String,
pub file_path: String,
pub symbol_name: String,
pub language: String,
// Raw source text. `index_nodes` / `add_node_to_index` tokenize it, derive
// `signature` from it, then clear it to reduce resident memory.
pub content: String,
// (start, end) byte offsets of the node in its file.
pub byte_range: (usize, usize),
// Optional embedding; inserted into the vector index when present.
pub embedding: Option<Vec<f32>>,
// Cyclomatic-style complexity value used as the structural signal — exact
// metric is computed elsewhere; here it is only cached and scaled.
pub complexity: u32,
// First meaningful source line, filled in during indexing.
pub signature: Option<String>,
}
/// Query text preprocessed once per search: a lowercased copy for substring
/// checks plus the set of normalized tokens matched against the inverted index.
struct TextQueryPreprocessed {
    /// Whole query, ASCII-lowercased (used for symbol-name `contains` boosts).
    query_lower: String,
    /// Lowercased alphanumeric tokens of length >= 2, matching the
    /// tokenization applied to node content during indexing.
    query_tokens: HashSet<String>,
}
impl TextQueryPreprocessed {
    /// Split on non-alphanumeric characters, lowercase, keep tokens >= 2 chars.
    fn from_query(query: &str) -> Self {
        let query_lower = query.to_ascii_lowercase();
        // Tokenize the already-lowercased copy instead of lowercasing each
        // token separately (the old code lowered the query twice). ASCII
        // lowercasing never changes whether a char is alphanumeric, so the
        // resulting token set is identical.
        let query_tokens: HashSet<String> = query_lower
            .split(|c: char| !c.is_alphanumeric())
            .filter(|s| s.len() >= 2)
            .map(str::to_string)
            .collect();
        Self {
            query_lower,
            query_tokens,
        }
    }
}
/// Parameters for `SearchEngine::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchQuery {
// Free-text query; tokenized for the inverted index.
pub query: String,
// Maximum number of results returned.
pub top_k: usize,
// NOTE(review): not consulted by `SearchEngine::search` in this file.
pub token_budget: Option<usize>,
// When true, vector search contributes a semantic score.
pub semantic: bool,
// NOTE(review): not consulted by `SearchEngine::search` in this file.
pub expand_context: bool,
// Explicit query embedding; when absent and `semantic` is set, the first
// indexed node's embedding is used as a fallback.
pub query_embedding: Option<Vec<f32>>,
// Minimum overall score for a result to be kept.
pub threshold: Option<f32>,
// Selects alternative scoring-weight presets in `search`.
pub query_type: Option<crate::search::ranking::QueryType>,
}
/// A single ranked hit returned by `SearchEngine::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
// 1-based position after final sorting.
pub rank: usize,
pub node_id: String,
pub file_path: String,
pub symbol_name: String,
// Always `None` in this file; reserved for richer metadata.
#[serde(skip_serializing_if = "Option::is_none")]
pub symbol_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub signature: Option<String>,
pub complexity: u32,
// Always `None` in this file; reserved for call-graph enrichment.
#[serde(skip_serializing_if = "Option::is_none")]
pub caller_count: Option<usize>,
// Always `None` in this file; reserved for dependency enrichment.
#[serde(skip_serializing_if = "Option::is_none")]
pub dependency_count: Option<usize>,
pub language: String,
// Combined semantic/structural/text score from `HybridScorer`.
pub score: Score,
// Always `None` in this file; reserved for context expansion.
pub context: Option<String>,
pub byte_range: (usize, usize),
}
/// Hybrid code-search engine combining an inverted text index, a pluggable
/// vector index, and a structural (complexity) signal, with an LRU result cache.
pub struct SearchEngine {
// Indexed nodes; their `content` is cleared after indexing (tokens are cached).
nodes: Vec<NodeInfo>,
// Blends semantic / structural / text scores into one `Score`.
scorer: HybridScorer,
vector_index: VectorIndexImpl,
// node_id -> complexity, avoids touching `nodes` during scoring.
complexity_cache: HashMap<String, u32>,
// Inverted index: token -> set of node_ids containing it.
text_index: HashMap<String, HashSet<String>>,
// node_id -> position in `nodes` (kept consistent across swap_remove).
node_id_to_idx: HashMap<String, usize>,
// node_id -> its token set; used for scoring once `content` is cleared.
node_tokens: HashMap<String, HashSet<String>>,
// Cache of fully ranked results keyed by a string of the query parameters.
search_cache: LruCache<String, Vec<SearchResult>>,
}
impl SearchEngine {
/// Create an engine with the default embedding dimension and an exact
/// (brute-force) vector index.
#[must_use]
pub fn new() -> Self {
    let cache_capacity = NonZeroUsize::new(1000).expect("cache capacity is non-zero");
    Self {
        nodes: Vec::new(),
        scorer: HybridScorer::new(),
        vector_index: VectorIndexImpl::BruteForce(VectorIndex::new(DEFAULT_EMBEDDING_DIMENSION)),
        complexity_cache: HashMap::new(),
        text_index: HashMap::new(),
        node_id_to_idx: HashMap::new(),
        node_tokens: HashMap::new(),
        search_cache: LruCache::new(cache_capacity),
    }
}
/// Create an engine with a brute-force vector index of the given dimension.
///
/// # Panics
/// Panics when `dimension` falls outside
/// `MIN_EMBEDDING_DIMENSION..=MAX_EMBEDDING_DIMENSION`.
#[must_use]
pub fn with_dimension(dimension: usize) -> Self {
    assert!(
        (MIN_EMBEDDING_DIMENSION..=MAX_EMBEDDING_DIMENSION).contains(&dimension),
        "Invalid embedding dimension: {} (must be between {} and {})",
        dimension,
        MIN_EMBEDDING_DIMENSION,
        MAX_EMBEDDING_DIMENSION
    );
    let cache_capacity = NonZeroUsize::new(1000).expect("cache capacity is non-zero");
    Self {
        nodes: Vec::new(),
        scorer: HybridScorer::new(),
        vector_index: VectorIndexImpl::BruteForce(VectorIndex::new(dimension)),
        complexity_cache: HashMap::new(),
        text_index: HashMap::new(),
        node_id_to_idx: HashMap::new(),
        node_tokens: HashMap::new(),
        search_cache: LruCache::new(cache_capacity),
    }
}
/// Rebuild every index from scratch for `nodes`, discarding prior contents.
///
/// For each node this populates the id->index map, the complexity cache, the
/// inverted text index, the per-node token cache, and (when an embedding is
/// present) the vector index; it then derives `signature` from the content
/// and clears `content` to keep resident memory low — search relies on the
/// cached token sets, not the raw text. Embedding-insert failures are logged
/// and skipped rather than aborting the whole rebuild.
///
/// Previously this walked the node list four separate times; the passes are
/// merged into one with identical per-node effects.
///
/// # Panics
/// Panics when more than `MAX_NODES` nodes are supplied.
pub fn index_nodes(&mut self, nodes: Vec<NodeInfo>) {
    assert!(
        nodes.len() <= MAX_NODES,
        "Cannot index more than {} nodes (provided: {})",
        MAX_NODES,
        nodes.len()
    );
    self.complexity_cache.clear();
    self.text_index.clear();
    self.search_cache.clear();
    self.node_id_to_idx.clear();
    self.node_tokens.clear();
    self.vector_index.clear();
    self.nodes = nodes;
    // Single pass: build all side indexes, then drop the raw content.
    for (idx, node) in self.nodes.iter_mut().enumerate() {
        self.node_id_to_idx.insert(node.node_id.clone(), idx);
        self.complexity_cache
            .insert(node.node_id.clone(), node.complexity);
        let mut tokens = HashSet::new();
        for raw in node.content.split(|c: char| !c.is_alphanumeric()) {
            let normalized_token = raw.to_ascii_lowercase();
            // Single-character fragments are too noisy to index.
            if normalized_token.len() >= 2 {
                self.text_index
                    .entry(normalized_token.clone())
                    .or_default()
                    .insert(node.node_id.clone());
                tokens.insert(normalized_token);
            }
        }
        self.node_tokens.insert(node.node_id.clone(), tokens);
        if let Some(embedding) = &node.embedding {
            if let Err(e) = self
                .vector_index
                .insert(node.node_id.clone(), embedding.clone())
            {
                // Best-effort: a bad embedding must not block text indexing.
                tracing::warn!(
                    "Failed to insert embedding for node {}: {:?}",
                    node.node_id,
                    e
                );
            }
        }
        node.signature = Self::extract_signature_from_content(&node.content);
        node.content.clear();
    }
}
/// Extract a representative signature line from a node's source text.
///
/// Skips the first line (presumably an indexer-emitted header line — confirm
/// with the content producer) and returns the first trimmed, non-empty line
/// that is not a `// [`-prefixed marker comment, or `None` if no line qualifies.
///
/// Note: `"// [No source"` starts with `"// ["`, so the old extra check for
/// that specific prefix was redundant; a single prefix test covers both.
pub fn extract_signature_from_content(content: &str) -> Option<String> {
    content
        .lines()
        .skip(1)
        .map(str::trim)
        .find(|l| !l.is_empty() && !l.starts_with("// ["))
        .map(str::to_string)
}
/// Apply an incremental delta: removals first, then inserts/updates.
/// The whole search cache is invalidated because any change can affect ranking.
///
/// # Panics
/// Panics when the node count exceeds `MAX_NODES` after applying the delta.
pub fn incremental_reindex(&mut self, delta: TextIndexDelta) {
    self.search_cache.clear();
    for node_id in &delta.removed_node_ids {
        self.remove_node_from_index(node_id);
    }
    for node in delta.updated_nodes {
        self.add_node_to_index(node);
    }
    // Post-condition mirroring the cap enforced by `index_nodes`.
    assert!(
        self.nodes.len() <= MAX_NODES,
        "Cannot index more than {} nodes (current: {})",
        MAX_NODES,
        self.nodes.len()
    );
}
/// Remove one node from every index structure; a no-op for unknown ids.
/// Uses `swap_remove` for O(1) deletion, then patches the moved node's index.
fn remove_node_from_index(&mut self, node_id: &str) {
    let removed_idx = match self.node_id_to_idx.remove(node_id) {
        Some(idx) => idx,
        None => return,
    };
    // Drop this node's postings; delete any token whose posting set empties.
    if let Some(tokens) = self.node_tokens.remove(node_id) {
        for token in tokens {
            if let Entry::Occupied(mut posting) = self.text_index.entry(token) {
                posting.get_mut().remove(node_id);
                if posting.get().is_empty() {
                    posting.remove();
                }
            }
        }
    }
    self.complexity_cache.remove(node_id);
    self.vector_index.remove(node_id);
    if removed_idx < self.nodes.len() {
        self.nodes.swap_remove(removed_idx);
        // swap_remove moved the former last node into `removed_idx`; update
        // its mapping (no-op when the removed node was itself the last one).
        if let Some(moved_id) = self.nodes.get(removed_idx).map(|n| n.node_id.clone()) {
            self.node_id_to_idx.insert(moved_id, removed_idx);
        }
    }
}
/// Insert (or replace) a single node across every index structure.
/// An update is modeled as remove-then-append, so the node moves to the end.
fn add_node_to_index(&mut self, mut node: NodeInfo) {
    if self.node_id_to_idx.contains_key(&node.node_id) {
        self.remove_node_from_index(&node.node_id);
    }
    let node_id = node.node_id.clone();
    let new_idx = self.nodes.len();
    let mut tokens = HashSet::new();
    for raw in node.content.split(|c: char| !c.is_alphanumeric()) {
        let normalized = raw.to_ascii_lowercase();
        // Skip single-character fragments, matching `index_nodes`.
        if normalized.len() < 2 {
            continue;
        }
        self.text_index
            .entry(normalized.clone())
            .or_default()
            .insert(node_id.clone());
        tokens.insert(normalized);
    }
    self.node_tokens.insert(node_id.clone(), tokens);
    self.node_id_to_idx.insert(node_id.clone(), new_idx);
    self.complexity_cache
        .insert(node_id.clone(), node.complexity);
    if let Some(embedding) = &node.embedding {
        // Best-effort, like `index_nodes`: log and continue on failure.
        if let Err(e) = self.vector_index.insert(node_id.clone(), embedding.clone()) {
            tracing::warn!("Failed to insert embedding for node {}: {:?}", node_id, e);
        }
    }
    node.signature = Self::extract_signature_from_content(&node.content);
    node.content.clear();
    self.nodes.push(node);
}
/// Number of nodes currently indexed.
#[must_use]
pub fn node_count(&self) -> usize {
    self.nodes.len()
}
/// True when no nodes are indexed.
/// (`#[must_use]` added for consistency with `node_count` and
/// `VectorIndexImpl::is_empty`.)
#[must_use]
pub fn is_empty(&self) -> bool {
    self.nodes.is_empty()
}
pub fn search(&mut self, query: SearchQuery) -> Result<Vec<SearchResult>, Error> {
if self.nodes.is_empty() {
return Ok(Vec::new());
}
let cache_key = format!(
"{}:{}:{:?}:{}",
query.query, query.top_k, query.threshold, query.semantic
);
if let Some(cached) = self.search_cache.get(&cache_key) {
return Ok(cached.clone());
}
let mut results = Vec::new();
let vector_results: std::collections::HashMap<String, f32> = if query.semantic {
let embedding = if let Some(emb) = query.query_embedding {
Some(emb)
} else {
self.nodes
.iter()
.find_map(|n| n.embedding.as_ref())
.cloned()
};
if let Some(emb) = embedding {
self.vector_index
.search(&emb, query.top_k)
.into_iter()
.collect()
} else {
std::collections::HashMap::new()
}
} else {
std::collections::HashMap::new()
};
let text_query = TextQueryPreprocessed::from_query(&query.query);
let candidates = if text_query.query_tokens.is_empty() {
self.nodes.iter().collect::<Vec<_>>()
} else {
let mut candidate_ids: HashSet<&str> = HashSet::new();
for token in &text_query.query_tokens {
if let Some(node_ids) = self.text_index.get(token) {
for node_id in node_ids {
candidate_ids.insert(node_id.as_str());
}
}
}
if candidate_ids.is_empty() && !query.semantic {
return Ok(Vec::new());
}
if candidate_ids.is_empty() {
self.nodes.iter().collect()
} else {
if vector_results.is_empty() {
self.nodes
.iter()
.filter(|node| candidate_ids.contains(node.node_id.as_str()))
.collect()
} else {
self.nodes
.iter()
.filter(|node| {
candidate_ids.contains(node.node_id.as_str())
|| vector_results.contains_key(&node.node_id)
})
.collect()
}
}
};
for node in candidates {
let text_score = self.calculate_text_score_optimized(
&text_query,
&node.node_id,
&node.symbol_name,
&node.file_path,
);
let semantic_score = if query.semantic {
*vector_results.get(&node.node_id).unwrap_or(&0.0)
} else {
0.0
};
if text_score == 0.0 && !query.semantic && semantic_score == 0.0 {
continue;
}
let structural_score = (node.complexity as f32 / 100.0).min(1.0);
let score = if let Some(qt) = query.query_type {
match qt {
crate::search::ranking::QueryType::Text => {
self.scorer.with_weights(0.2, 0.05, 0.75).score(
semantic_score,
structural_score,
text_score,
)
}
crate::search::ranking::QueryType::Semantic => {
self.scorer.with_weights(0.7, 0.1, 0.2).score(
semantic_score,
structural_score,
text_score,
)
}
crate::search::ranking::QueryType::Structural => {
self.scorer.with_weights(0.3, 0.5, 0.2).score(
semantic_score,
structural_score,
text_score,
)
}
}
} else {
self.scorer
.score(semantic_score, structural_score, text_score)
};
if score.overall > 0.0 {
if let Some(threshold) = query.threshold {
if score.overall < threshold {
continue;
}
}
let signature = node.signature.clone();
results.push(SearchResult {
rank: 0, node_id: node.node_id.clone(),
file_path: node.file_path.clone(),
symbol_name: node.symbol_name.clone(),
symbol_type: None, signature,
complexity: node.complexity,
caller_count: None, dependency_count: None, language: node.language.clone(),
score,
context: None,
byte_range: node.byte_range,
});
}
}
results.sort_by(|a, b| {
b.score
.overall
.partial_cmp(&a.score.overall)
.unwrap_or(std::cmp::Ordering::Equal)
});
let top_k = results.into_iter().take(query.top_k).collect::<Vec<_>>();
let mut final_results = top_k;
for (i, result) in final_results.iter_mut().enumerate() {
result.rank = i + 1;
}
self.search_cache.put(cache_key, final_results.clone());
Ok(final_results)
}
/// Blend token overlap, a symbol-name substring boost, and a test-file
/// penalty into a text-relevance score clamped to [0, 1].
fn calculate_text_score_optimized(
    &self,
    precomputed: &TextQueryPreprocessed,
    node_id: &str,
    symbol_name: &str,
    file_path: &str,
) -> f32 {
    // Lowercase the symbol once (the old code did it twice).
    let symbol_lower = symbol_name.to_ascii_lowercase();
    // Base score: fraction of query tokens found in the node's cached tokens.
    let base_score = if precomputed.query_tokens.is_empty() {
        0.0
    } else {
        match self.node_tokens.get(node_id) {
            Some(node_tokens) => {
                let matching = precomputed.query_tokens.intersection(node_tokens).count();
                matching as f32 / precomputed.query_tokens.len() as f32
            }
            None => 0.0,
        }
    };
    // +0.5 when the whole query string appears inside the symbol name.
    let mut score = base_score;
    if symbol_lower.contains(&precomputed.query_lower) {
        score += 0.5;
    }
    // -0.3 for anything that looks like test code (path or symbol).
    if file_path.to_ascii_lowercase().contains("test") || symbol_lower.contains("test") {
        score -= 0.3;
    }
    score.clamp(0.0, 1.0)
}
/// Pure vector search returning `(node_id, relevance)` entries.
///
/// # Errors
/// Returns `Error::QueryFailed` when `query_embedding`'s length does not
/// match the index dimension. An empty index yields `Ok(vec![])`.
pub fn semantic_search(
    &self,
    query_embedding: &[f32],
    top_k: usize,
) -> Result<Vec<SemanticEntry>, Error> {
    if self.vector_index.is_empty() {
        return Ok(Vec::new());
    }
    if query_embedding.len() != self.vector_index.dimension() {
        return Err(Error::QueryFailed(format!(
            "Embedding dimension mismatch: expected {}, got {}",
            self.vector_index.dimension(),
            query_embedding.len()
        )));
    }
    let entries = self
        .vector_index
        .search(query_embedding, top_k)
        .into_iter()
        .map(|(node_id, score)| SemanticEntry {
            node_id,
            relevance: score,
            // The previous node_id_to_idx lookup was dead code: both its
            // `map` and `unwrap_or` arms produced `Function` unconditionally,
            // so the lookup is removed. TODO(review): derive the real entry
            // type once `NodeInfo` carries one.
            entry_type: EntryType::Function,
        })
        .collect();
    Ok(entries)
}
/// Borrow the active vector index.
#[must_use]
pub fn vector_index(&self) -> &VectorIndexImpl {
    &self.vector_index
}
/// Mutably borrow the active vector index.
pub fn vector_index_mut(&mut self) -> &mut VectorIndexImpl {
    &mut self.vector_index
}
/// Swap in an HNSW index of the current dimension. Previously inserted
/// vectors are discarded, not migrated.
pub fn enable_hnsw(&mut self, params: Option<HNSWParams>) {
    let dimension = self.vector_index.dimension();
    let index = HNSWIndex::with_params(dimension, params.unwrap_or_default());
    self.vector_index = VectorIndexImpl::HNSW(Box::new(index));
}
/// True when the active index is HNSW-based (quantized or not).
#[must_use]
pub fn is_hnsw_enabled(&self) -> bool {
    // Delegates to the index's own identical check instead of re-matching
    // the variants here.
    self.vector_index.is_hnsw_enabled()
}
/// Revert to a brute-force index of the current dimension; stored vectors
/// are discarded, not migrated.
pub fn disable_hnsw(&mut self) {
    let dimension = self.vector_index.dimension();
    self.vector_index = VectorIndexImpl::BruteForce(VectorIndex::new(dimension));
}
/// Construct an engine that starts with an HNSW index.
#[must_use]
pub fn with_hnsw(dimension: usize, params: HNSWParams) -> Self {
    let mut engine = Self::with_dimension(dimension);
    engine.enable_hnsw(Some(params));
    engine
}
/// Swap in an int8-quantized HNSW index of the current dimension; stored
/// vectors are discarded, not migrated.
pub fn enable_int8_hnsw(&mut self, params: Option<Int8HnswParams>) {
    let dimension = self.vector_index.dimension();
    let index = Int8HnswIndex::with_params(dimension, params.unwrap_or_default());
    self.vector_index = VectorIndexImpl::HNSWQuantized(Box::new(index));
}
/// True when the active index stores int8-quantized vectors.
#[must_use]
pub fn is_quantized(&self) -> bool {
    matches!(self.vector_index, VectorIndexImpl::HNSWQuantized(_))
}
/// Rough lower-bound memory estimate: counts fixed struct sizes only, not
/// the heap data behind `String`s.
#[must_use]
pub fn estimated_memory_bytes(&self) -> usize {
    let nodes_size = self.nodes.len() * std::mem::size_of::<NodeInfo>();
    let cache_size = self.complexity_cache.len()
        * (std::mem::size_of::<String>() + std::mem::size_of::<u32>());
    let text_index_size: usize = self
        .text_index
        .values()
        .map(|ids| ids.len() * std::mem::size_of::<String>())
        .sum();
    nodes_size + cache_size + text_index_size + self.vector_index.estimated_memory_bytes()
}
}
/// A change set consumed by `SearchEngine::incremental_reindex`:
/// removals are applied before `updated_nodes` are (re-)inserted.
#[derive(Debug, Default)]
pub struct TextIndexDelta {
pub removed_node_ids: Vec<String>,
pub updated_nodes: Vec<NodeInfo>,
}
/// `Default` delegates to `SearchEngine::new` (default dimension, brute-force index).
impl Default for SearchEngine {
fn default() -> Self {
Self::new()
}
}
/// Kind of code entity a semantic hit refers to.
/// NOTE(review): `semantic_search` in this file only ever produces `Function`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum EntryType {
Function,
Method,
Class,
Module,
}
/// One hit from `SearchEngine::semantic_search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticEntry {
pub node_id: String,
// Similarity score as reported by the vector index backend.
pub relevance: f32,
pub entry_type: EntryType,
}
/// Errors returned by search entry points.
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("Query failed: {0}")]
QueryFailed(String),
// NOTE(review): the two variants below are not constructed in this file;
// `semantic_search` reports dimension mismatches via `QueryFailed`.
#[error("Index is empty")]
EmptyIndex,
#[error("Dimension mismatch: expected {expected}, got {got}")]
DimensionMismatch {
expected: usize,
got: usize,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `NodeInfo` whose symbol name differs from its node id.
    /// Deduplicates the ~15 struct literals the old tests repeated verbatim.
    fn named_node(
        id: &str,
        symbol: &str,
        file: &str,
        content: &str,
        byte_range: (usize, usize),
        embedding: Option<Vec<f32>>,
        complexity: u32,
    ) -> NodeInfo {
        NodeInfo {
            node_id: id.to_string(),
            file_path: file.to_string(),
            symbol_name: symbol.to_string(),
            language: "rust".to_string(),
            content: content.to_string(),
            byte_range,
            embedding,
            complexity,
            signature: None,
        }
    }

    /// Build a `NodeInfo` whose symbol name equals its node id (common case).
    fn node(
        id: &str,
        file: &str,
        content: &str,
        byte_range: (usize, usize),
        embedding: Option<Vec<f32>>,
        complexity: u32,
    ) -> NodeInfo {
        named_node(id, id, file, content, byte_range, embedding, complexity)
    }

    /// Build a plain (non-semantic) text query with default options.
    fn text_query(q: &str, top_k: usize) -> SearchQuery {
        SearchQuery {
            query: q.to_string(),
            top_k,
            token_budget: None,
            semantic: false,
            expand_context: false,
            query_embedding: None,
            threshold: None,
            query_type: None,
        }
    }

    fn create_test_nodes() -> Vec<NodeInfo> {
        vec![
            node(
                "func1",
                "test.rs",
                "fn func1() { println!(\"hello\"); }",
                (0, 40),
                Some(vec![1.0, 0.0, 0.0]),
                2,
            ),
            node(
                "func2",
                "test.rs",
                "fn func2() { println!(\"world\"); }",
                (42, 82),
                Some(vec![0.0, 1.0, 0.0]),
                2,
            ),
        ]
    }

    #[test]
    fn test_search_engine_creation() {
        let engine = SearchEngine::new();
        assert_eq!(engine.node_count(), 0);
        assert!(engine.is_empty());
    }

    #[test]
    fn test_index_nodes() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        assert_eq!(engine.node_count(), 2);
        assert!(!engine.is_empty());
    }

    #[test]
    fn test_search_empty_index() {
        let mut engine = SearchEngine::new();
        let results = engine.search(text_query("test", 10)).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_semantic_search_empty_index() {
        let engine = SearchEngine::new();
        let results = engine.semantic_search(&[0.1, 0.2, 0.3], 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_with_results() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        let results = engine.search(text_query("func1", 10)).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].node_id, "func1");
    }

    #[test]
    fn test_semantic_search() {
        let mut engine = SearchEngine::with_dimension(3);
        engine.index_nodes(create_test_nodes());
        let results = engine.semantic_search(&[1.0, 0.0, 0.0], 1).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].node_id, "func1");
    }

    #[test]
    fn test_dimension_validation() {
        let engine = SearchEngine::with_dimension(128);
        assert_eq!(engine.vector_index().dimension(), 128);
    }

    #[test]
    fn test_dimension_mismatch_error() {
        let mut engine = SearchEngine::with_dimension(3);
        engine.index_nodes(create_test_nodes());
        assert!(engine.semantic_search(&[0.1, 0.2], 10).is_err());
    }

    #[test]
    fn test_hnsw_enable() {
        let mut engine = SearchEngine::with_dimension(128);
        engine.enable_hnsw(None);
        assert!(engine.vector_index().is_hnsw_enabled());
    }

    #[test]
    fn test_top_k_limit() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        let results = engine.search(text_query("fn", 1)).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_relevance_threshold() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        let mut query = text_query("nonexistent", 10);
        query.threshold = Some(0.5);
        let results = engine.search(query).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_node_id_to_idx_populated() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        assert_eq!(engine.node_id_to_idx.len(), 2);
        assert_eq!(engine.node_id_to_idx.get("func1"), Some(&0));
        assert_eq!(engine.node_id_to_idx.get("func2"), Some(&1));
    }

    #[test]
    fn test_node_id_to_idx_o1_lookup_in_semantic_search() {
        let mut engine = SearchEngine::with_dimension(3);
        engine.index_nodes(create_test_nodes());
        let results = engine.semantic_search(&[1.0, 0.0, 0.0], 10).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].node_id, "func1");
        for entry in &results {
            assert_eq!(entry.entry_type, EntryType::Function);
        }
    }

    #[test]
    fn test_node_id_to_idx_cleared_on_reindex() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        assert_eq!(engine.node_id_to_idx.len(), 2);
        engine.index_nodes(vec![node(
            "new_func",
            "new.rs",
            "fn new_func() {}",
            (0, 18),
            None,
            1,
        )]);
        assert_eq!(engine.node_id_to_idx.len(), 1);
        assert_eq!(engine.node_id_to_idx.get("new_func"), Some(&0));
        assert_eq!(engine.node_id_to_idx.get("func1"), None);
    }

    #[test]
    fn test_content_cleared_after_indexing() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        for n in &engine.nodes {
            assert!(
                n.content.is_empty(),
                "Node {} content should be cleared after indexing, but got: {:?}",
                n.node_id,
                n.content
            );
        }
        // The inverted index must keep working after content is dropped.
        let results = engine.search(text_query("func1", 10)).unwrap();
        assert!(
            !results.is_empty(),
            "Search should still find results via inverted index after content cleared"
        );
        assert_eq!(results[0].node_id, "func1");
        assert!(
            !engine.text_index.is_empty(),
            "text_index should be populated"
        );
        assert!(
            engine.text_index.contains_key("func1"),
            "text_index should contain 'func1' token"
        );
        assert!(
            engine.text_index.contains_key("func2"),
            "text_index should contain 'func2' token"
        );
    }

    #[test]
    fn test_node_tokens_populated() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        assert_eq!(engine.node_tokens.len(), 2);
        assert!(engine.node_tokens.contains_key("func1"));
        assert!(engine.node_tokens.contains_key("func2"));
        let func1_tokens = engine.node_tokens.get("func1").unwrap();
        assert!(
            func1_tokens.contains("func1"),
            "func1 tokens should contain 'func1', got: {:?}",
            func1_tokens
        );
        let func2_tokens = engine.node_tokens.get("func2").unwrap();
        assert!(
            func2_tokens.contains("func2"),
            "func2 tokens should contain 'func2', got: {:?}",
            func2_tokens
        );
    }

    #[test]
    fn test_node_tokens_cleared_on_reindex() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        assert_eq!(engine.node_tokens.len(), 2);
        engine.index_nodes(vec![node(
            "new_func",
            "test.rs",
            "fn new_func() {}",
            (0, 18),
            None,
            1,
        )]);
        assert_eq!(engine.node_tokens.len(), 1);
        assert!(engine.node_tokens.contains_key("new_func"));
        assert!(!engine.node_tokens.contains_key("func1"));
    }

    #[test]
    fn test_node_tokens_used_in_scoring() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        for n in &engine.nodes {
            assert!(n.content.is_empty());
        }
        let results = engine.search(text_query("println hello", 10)).unwrap();
        assert!(
            !results.is_empty(),
            "Search should find results using cached node_tokens even after content is cleared"
        );
        assert_eq!(results[0].node_id, "func1");
    }

    #[test]
    fn test_incremental_reindex_add_nodes() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        assert_eq!(engine.node_count(), 2);
        let delta = TextIndexDelta {
            removed_node_ids: vec![],
            updated_nodes: vec![node(
                "func3",
                "test.rs",
                "fn func3() { db_query(); }",
                (100, 130),
                Some(vec![0.0, 0.0, 1.0]),
                3,
            )],
        };
        engine.incremental_reindex(delta);
        assert_eq!(engine.node_count(), 3);
        assert_eq!(engine.node_id_to_idx.len(), 3);
        assert_eq!(engine.node_tokens.len(), 3);
        assert_eq!(engine.complexity_cache.len(), 3);
        let results = engine.search(text_query("func3", 10)).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].node_id, "func3");
        assert!(engine.text_index.contains_key("func3"));
        assert!(engine.text_index.contains_key("query"));
    }

    #[test]
    fn test_incremental_reindex_remove_nodes() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        assert_eq!(engine.node_count(), 2);
        let delta = TextIndexDelta {
            removed_node_ids: vec!["func1".to_string()],
            updated_nodes: vec![],
        };
        engine.incremental_reindex(delta);
        assert_eq!(engine.node_count(), 1);
        assert_eq!(engine.node_id_to_idx.len(), 1);
        assert!(!engine.node_id_to_idx.contains_key("func1"));
        assert!(engine.node_id_to_idx.contains_key("func2"));
        if let Some(ids) = engine.text_index.get("func1") {
            assert!(
                !ids.contains("func1"),
                "func1 should be removed from text_index"
            );
        }
        assert!(!engine.node_tokens.contains_key("func1"));
        assert!(engine.node_tokens.contains_key("func2"));
        let results = engine.search(text_query("func1", 10)).unwrap();
        assert!(
            results.is_empty(),
            "func1 should not be found after removal"
        );
    }

    #[test]
    fn test_incremental_reindex_update_existing_node() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        let delta = TextIndexDelta {
            removed_node_ids: vec![],
            updated_nodes: vec![named_node(
                "func1",
                "func1_renamed",
                "updated.rs",
                "fn func1_renamed() { new_logic(); }",
                (0, 35),
                Some(vec![0.5, 0.5, 0.0]),
                5,
            )],
        };
        engine.incremental_reindex(delta);
        assert_eq!(engine.node_count(), 2);
        assert_eq!(engine.complexity_cache.get("func1"), Some(&5));
        assert!(engine.node_tokens.get("func1").unwrap().contains("logic"));
        assert!(engine.text_index.contains_key("logic"));
        let results = engine.search(text_query("new_logic", 10)).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].node_id, "func1");
    }

    #[test]
    fn test_incremental_reindex_combined_add_remove() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        let delta = TextIndexDelta {
            removed_node_ids: vec!["func1".to_string()],
            updated_nodes: vec![
                node("func3", "new.rs", "fn func3() {}", (0, 14), None, 1),
                node(
                    "func4",
                    "new.rs",
                    "fn func4() { helper(); }",
                    (15, 40),
                    None,
                    2,
                ),
            ],
        };
        engine.incremental_reindex(delta);
        assert_eq!(engine.node_count(), 3);
        assert_eq!(engine.node_id_to_idx.len(), 3);
        assert!(!engine.node_id_to_idx.contains_key("func1"));
        assert!(engine.node_id_to_idx.contains_key("func2"));
        assert!(engine.node_id_to_idx.contains_key("func3"));
        assert!(engine.node_id_to_idx.contains_key("func4"));
        let results = engine.search(text_query("func2", 10)).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].node_id, "func2");
    }

    #[test]
    fn test_incremental_reindex_empty_delta() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        engine.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec![],
            updated_nodes: vec![],
        });
        assert_eq!(engine.node_count(), 2);
        assert_eq!(engine.node_id_to_idx.len(), 2);
    }

    #[test]
    fn test_incremental_reindex_removes_empty_token_sets() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(vec![
            node(
                "unique1",
                "test.rs",
                "fn unique1() { zebra(); }",
                (0, 25),
                None,
                1,
            ),
            node(
                "unique2",
                "test.rs",
                "fn unique2() { apple(); }",
                (26, 52),
                None,
                1,
            ),
        ]);
        assert!(engine.text_index.contains_key("zebra"));
        engine.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec!["unique1".to_string()],
            updated_nodes: vec![],
        });
        assert!(
            !engine.text_index.contains_key("zebra"),
            "Token with no remaining nodes should be removed from text_index"
        );
        assert!(engine.text_index.contains_key("apple"));
    }

    #[test]
    fn test_incremental_reindex_correctness_vs_full_rebuild() {
        let mut engine_inc = SearchEngine::new();
        let mut engine_full = SearchEngine::new();
        let initial = create_test_nodes();
        engine_inc.index_nodes(initial.clone());
        engine_full.index_nodes(initial);
        // Apply the same change incrementally to one engine and as a full
        // rebuild to the other; both must agree afterwards.
        let func3 = node(
            "func3",
            "new.rs",
            "fn func3() { compute(); }",
            (0, 25),
            Some(vec![1.0, 1.0, 0.0]),
            4,
        );
        engine_inc.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec!["func1".to_string()],
            updated_nodes: vec![func3.clone()],
        });
        engine_full.index_nodes(vec![
            node(
                "func2",
                "test.rs",
                "fn func2() { println!(\"world\"); }",
                (42, 82),
                Some(vec![0.0, 1.0, 0.0]),
                2,
            ),
            func3,
        ]);
        assert_eq!(engine_inc.node_count(), engine_full.node_count());
        let inc_ids: std::collections::BTreeSet<_> =
            engine_inc.nodes.iter().map(|n| n.node_id.clone()).collect();
        let full_ids: std::collections::BTreeSet<_> = engine_full
            .nodes
            .iter()
            .map(|n| n.node_id.clone())
            .collect();
        assert_eq!(inc_ids, full_ids);
        let query = text_query("func2", 10);
        let inc_results = engine_inc.search(query.clone()).unwrap();
        let full_results = engine_full.search(query).unwrap();
        assert_eq!(inc_results.len(), full_results.len());
        if !inc_results.is_empty() {
            assert_eq!(inc_results[0].node_id, full_results[0].node_id);
        }
        let inc_sem = engine_inc.semantic_search(&[1.0, 1.0, 0.0], 10).unwrap();
        let full_sem = engine_full.semantic_search(&[1.0, 1.0, 0.0], 10).unwrap();
        assert_eq!(inc_sem.len(), full_sem.len());
        if !inc_sem.is_empty() {
            assert_eq!(inc_sem[0].node_id, full_sem[0].node_id);
        }
    }

    #[test]
    fn test_incremental_reindex_semantic_search_after_update() {
        let mut engine = SearchEngine::with_dimension(3);
        engine.index_nodes(create_test_nodes());
        engine.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec![],
            updated_nodes: vec![node(
                "func3",
                "test.rs",
                "fn func3() {}",
                (0, 14),
                Some(vec![0.1, 0.1, 0.9]),
                1,
            )],
        });
        let results = engine.semantic_search(&[0.1, 0.1, 0.9], 1).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].node_id, "func3");
    }

    #[test]
    fn test_incremental_reindex_node_id_to_idx_consistency() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        engine.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec![],
            updated_nodes: vec![node("func3", "test.rs", "fn func3() {}", (0, 14), None, 1)],
        });
        engine.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec!["func1".to_string()],
            updated_nodes: vec![],
        });
        // Every node must map to its actual position (swap_remove bookkeeping).
        assert_eq!(engine.node_id_to_idx.len(), engine.nodes.len());
        for (idx, n) in engine.nodes.iter().enumerate() {
            assert_eq!(
                engine.node_id_to_idx.get(&n.node_id),
                Some(&idx),
                "node_id_to_idx mismatch for node {}",
                n.node_id
            );
        }
    }

    #[test]
    fn test_incremental_reindex_removes_nonexistent_node() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        engine.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec!["nonexistent".to_string()],
            updated_nodes: vec![],
        });
        assert_eq!(engine.node_count(), 2);
        assert_eq!(engine.node_id_to_idx.len(), 2);
    }

    #[test]
    fn test_incremental_reindex_content_cleared() {
        let mut engine = SearchEngine::new();
        engine.index_nodes(create_test_nodes());
        engine.incremental_reindex(TextIndexDelta {
            removed_node_ids: vec![],
            updated_nodes: vec![node(
                "func3",
                "test.rs",
                "fn func3() { important_content(); }",
                (0, 40),
                None,
                3,
            )],
        });
        for n in &engine.nodes {
            assert!(
                n.content.is_empty(),
                "Node {} content should be cleared, got: {:?}",
                n.node_id,
                n.content
            );
        }
        assert!(engine
            .node_tokens
            .get("func3")
            .unwrap()
            .contains("important"));
    }
}