use crate::vectordb::{
cache::EmbeddingCache,
db::VectorDB,
embedding::{EmbeddingModel, EmbeddingModelType},
hnsw::{HNSWConfig, HNSWIndex},
search::{
bm25::{build_bm25_index, search_bm25_top_k},
vector::search_with_limit,
},
snippet_extractor::SnippetExtractor,
};
use std::collections::HashMap;
use std::fs;
use tempfile::tempdir;
use anyhow;
/// In-memory stand-in for an embedding backend, used by the tests in this file.
///
/// Nothing is computed: a vector is returned only for text that was registered
/// beforehand through [`MockEmbeddingModel::add_embedding`].
#[derive(Clone)]
struct MockEmbeddingModel {
    /// Canned vectors, keyed by the exact text they answer for.
    embeddings: HashMap<String, Vec<f32>>,
    /// Length that every registered vector must have.
    dimension: usize,
}

impl MockEmbeddingModel {
    /// Creates a mock with no registered vectors and the given vector length.
    fn new(dimension: usize) -> Self {
        let embeddings = HashMap::new();
        Self {
            embeddings,
            dimension,
        }
    }

    /// Registers `embedding` as the canned answer for `text`.
    ///
    /// A later registration for the same text replaces the earlier one.
    ///
    /// # Panics
    /// Panics if `embedding.len()` differs from the configured dimension.
    fn add_embedding(&mut self, text: &str, embedding: Vec<f32>) {
        assert_eq!(
            embedding.len(),
            self.dimension,
            "Mock embedding dimension mismatch"
        );
        let key = text.to_owned();
        self.embeddings.insert(key, embedding);
    }
}
/// Lets the mock be plugged in wherever an `EmbeddingProvider` is expected.
impl crate::vectordb::provider::EmbeddingProvider for MockEmbeddingModel {
    /// Returns a copy of the vector registered for exactly `text`.
    ///
    /// # Errors
    /// Fails when no vector was registered for `text`.
    fn embed(&self, text: &str) -> anyhow::Result<Vec<f32>> {
        match self.embeddings.get(text) {
            Some(vector) => Ok(vector.clone()),
            None => Err(anyhow::anyhow!(
                "Mock embedding not found for query: {}",
                text
            )),
        }
    }

    /// Embeds each text in order via [`Self::embed`].
    ///
    /// # Errors
    /// Stops at, and returns, the first lookup failure.
    fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.embed(text)?);
        }
        Ok(vectors)
    }

    /// Reports the vector length this mock was constructed with.
    fn dimension(&self) -> usize {
        self.dimension
    }
}
/// Builds a `VectorDB` test fixture inside a fresh temporary directory.
///
/// Every `(file name, content)` pair in `content_map` is written to a file in
/// that directory. For each file that also has an entry in `embeddings`
/// (keyed by the same file name), the vector is stored in the database under
/// the file's full path; files without a vector only produce a warning on
/// stdout. When `use_hnsw` is set and at least one vector was stored, an HNSW
/// index of the given `dimension` is built as well.
///
/// Returns the database, the temporary-directory guard and the database path.
/// NOTE(review): `tempfile::TempDir` removes the directory when dropped, so
/// callers presumably need to keep the guard alive while the files are used.
///
/// # Panics
/// Panics if the temp dir, the cache, a content file or an index insert fails.
fn setup_mock_db_with_content(
    content_map: HashMap<String, String>,
    embeddings: HashMap<String, Vec<f32>>,
    dimension: usize,
    use_hnsw: bool,
) -> (VectorDB, tempfile::TempDir, String) {
    let temp_dir = tempdir().expect("Failed to create temp dir for mock db");
    let root = temp_dir.path();

    let db_path = root.join("mock_db.json").to_string_lossy().to_string();
    let cache_path = root.join("cache.json").to_string_lossy().to_string();
    let cache = EmbeddingCache::new(cache_path).unwrap();

    // Materialise the content on disk and re-key the vectors by full path.
    let mut final_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
    for (file_name, content) in &content_map {
        let file_path = root.join(file_name);
        fs::write(&file_path, content).expect("Failed to write mock content file");

        match embeddings.get(file_name) {
            Some(vector) => {
                let full_path = file_path.to_string_lossy().into_owned();
                final_embeddings.insert(full_path, vector.clone());
            }
            None => println!("Warning: No embedding provided for file: {}", file_name),
        }
    }

    let hnsw_index_opt = if !use_hnsw || final_embeddings.is_empty() {
        None
    } else {
        let mut index = HNSWIndex::new(HNSWConfig::new(dimension));
        // Insert in sorted-path order so the insertion sequence does not
        // depend on HashMap iteration order.
        let mut ordered_paths: Vec<&String> = final_embeddings.keys().collect();
        ordered_paths.sort();
        for path in ordered_paths {
            index.insert(final_embeddings[path].clone()).unwrap();
        }
        Some(index)
    };

    let db = VectorDB::new_test(
        db_path.clone(),
        final_embeddings,
        cache,
        hnsw_index_opt,
        EmbeddingModelType::Onnx,
    );
    (db, temp_dir, db_path)
}
/// Searching an empty database with an empty query string is expected to
/// succeed and return no results.
#[test]
fn test_vector_search_empty_query() {
    let dim = 4;
    let (db, _temp_dir, _) =
        setup_mock_db_with_content(HashMap::new(), HashMap::new(), dim, true);

    let provider = MockEmbeddingModel::new(dim);
    let mut model = EmbeddingModel::new_mock(Box::new(provider));
    let mut snippet_extractor = SnippetExtractor::new();

    let hits = search_with_limit(&db, &mut model, &mut snippet_extractor, "", 10).unwrap();
    assert!(hits.is_empty(), "Empty query should return empty results");
}
/// Exercises `search_with_limit` on a database that carries an HNSW index and
/// checks the ranking for two queries pointing along different axes.
#[test]
fn test_vector_search_hnsw_path() {
    let dim = 4;
    let mut mock_provider = MockEmbeddingModel::new(dim);
    mock_provider.add_embedding("query1", vec![1.0, 0.0, 0.0, 0.0]);
    mock_provider.add_embedding("query2", vec![0.0, 0.0, 1.0, 0.0]);

    // (file name, file content, embedding)
    let fixtures: Vec<(&str, &str, Vec<f32>)> = vec![
        ("file1.txt", "content1", vec![0.9, 0.1, 0.0, 0.0]),
        ("file2.txt", "content2", vec![0.0, 0.0, 0.8, 0.2]),
        ("file3.txt", "content3", vec![0.6, 0.1, 0.3, 0.0]),
        ("file4.txt", "content4", vec![0.1, 0.2, 0.1, 0.6]),
    ];
    let mut content_map = HashMap::new();
    let mut embeddings = HashMap::new();
    for (name, body, vector) in fixtures {
        content_map.insert(name.to_string(), body.to_string());
        embeddings.insert(name.to_string(), vector);
    }

    let (db, _temp_dir, _) = setup_mock_db_with_content(content_map, embeddings, dim, true);
    let mut model = EmbeddingModel::new_mock(Box::new(mock_provider));
    let mut snippet_extractor = SnippetExtractor::new();

    // query1 points along the first axis: file1 first, then file3.
    let results1 =
        search_with_limit(&db, &mut model, &mut snippet_extractor, "query1", 10).unwrap();
    assert!(!results1.is_empty(), "Should find results for query1");
    let best1 = &results1[0];
    assert!(
        best1.file_path.ends_with("file1.txt"),
        "Top result for query1 should be file1.txt"
    );
    assert!(
        best1.similarity > 0.8,
        "Similarity for file1.txt should be high"
    );
    if let Some(runner_up) = results1.get(1) {
        assert!(
            runner_up.file_path.ends_with("file3.txt"),
            "Second result for query1 should be file3.txt"
        );
        assert!(
            runner_up.similarity < best1.similarity,
            "Result 2 sim should be < Result 1 sim"
        );
    }
    let file4_scored_high = results1
        .iter()
        .any(|r| r.file_path.ends_with("file4.txt") && r.similarity > 0.3);
    assert!(!file4_scored_high, "file4.txt should have low similarity");

    // query2 points along the third axis: file2 is the closest match.
    let results2 =
        search_with_limit(&db, &mut model, &mut snippet_extractor, "query2", 10).unwrap();
    assert!(!results2.is_empty(), "Should find results for query2");
    let best2 = &results2[0];
    assert!(
        best2.file_path.ends_with("file2.txt"),
        "Top result for query2 should be file2.txt"
    );
    assert!(
        best2.similarity > 0.7,
        "Similarity for file2.txt should be high"
    );
}
/// Exercises `search_with_limit` on a database built without an HNSW index.
#[test]
fn test_vector_search_brute_force_path() {
    let dim = 4;
    let mut mock_provider = MockEmbeddingModel::new(dim);
    mock_provider.add_embedding("query1", vec![1.0, 0.0, 0.0, 0.0]);

    // (file name, file content, embedding)
    let fixtures: Vec<(&str, &str, Vec<f32>)> = vec![
        ("bf_file1.txt", "content bf1", vec![0.9, 0.1, 0.0, 0.0]),
        ("bf_file2.txt", "content bf2", vec![0.1, 0.1, 0.9, 0.0]),
    ];
    let mut content_map = HashMap::new();
    let mut embeddings = HashMap::new();
    for (name, body, vector) in fixtures {
        content_map.insert(name.to_string(), body.to_string());
        embeddings.insert(name.to_string(), vector);
    }

    let (db, _temp_dir, _) = setup_mock_db_with_content(content_map, embeddings, dim, false);
    assert!(
        db.hnsw_index.is_none(),
        "HNSW index should be None for brute force test"
    );

    let mut model = EmbeddingModel::new_mock(Box::new(mock_provider));
    let mut snippet_extractor = SnippetExtractor::new();
    let results =
        search_with_limit(&db, &mut model, &mut snippet_extractor, "query1", 10).unwrap();

    assert!(!results.is_empty(), "Brute force should find results");
    let best = &results[0];
    assert!(
        best.file_path.ends_with("bf_file1.txt"),
        "Top result should be bf_file1.txt"
    );
    assert!(
        best.similarity > 0.8,
        "Similarity for bf_file1.txt should be high"
    );

    // The second file is only checked when more than one result came back.
    if results.len() > 1 {
        let second_file = results
            .iter()
            .find(|r| r.file_path.ends_with("bf_file2.txt"));
        assert!(
            second_file.is_some(),
            "bf_file2.txt should be present if not filtered"
        );
        assert!(
            second_file.unwrap().similarity < 0.3,
            "bf_file2.txt should have low similarity"
        );
    }
}
/// With five indexed files and a limit of three, exactly three results are
/// expected, in strictly decreasing similarity order.
#[test]
fn test_vector_search_max_results_limit() {
    let dim = 2;
    let mut mock_provider = MockEmbeddingModel::new(dim);
    mock_provider.add_embedding("query", vec![1.0, 0.0]);

    // File i gets the vector [sim, sqrt(1 - sim^2)] with sim = 0.9 - 0.1 * i.
    let (content_map, embeddings): (HashMap<String, String>, HashMap<String, Vec<f32>>) = (0..5)
        .map(|i| {
            let filename = format!("limit_file_{}.txt", i);
            let sim = 0.9 - (i as f32 * 0.1);
            let vector = vec![sim, (1.0 - sim * sim).sqrt()];
            (
                (filename.clone(), format!("content {}", i)),
                (filename, vector),
            )
        })
        .unzip();

    let (db, _temp_dir, _) = setup_mock_db_with_content(content_map, embeddings, dim, true);
    let mut model = EmbeddingModel::new_mock(Box::new(mock_provider));
    let mut snippet_extractor = SnippetExtractor::new();

    let limit = 3;
    let results =
        search_with_limit(&db, &mut model, &mut snippet_extractor, "query", limit).unwrap();

    assert_eq!(
        results.len(),
        limit,
        "Number of results should be equal to the limit"
    );
    assert!(results[0].similarity > results[1].similarity);
    assert!(results[1].similarity > results[2].similarity);
}
/// Checks that the search keeps the high- and medium-similarity entries and
/// drops the low-similarity one.
///
/// The database is built directly with `VectorDB::new_test`, without an HNSW
/// index, and the embedding keys are bare file names (no files are written).
#[test]
fn test_vector_search_similarity_threshold() {
    let dim = 2;
    let mut mock_provider = MockEmbeddingModel::new(dim);
    mock_provider.add_embedding("query", vec![1.0, 0.0]);

    let temp_dir = tempdir().unwrap();
    let in_temp = |name: &str| temp_dir.path().join(name).to_string_lossy().to_string();
    let db_path = in_temp("thresh_db.json");
    let cache = EmbeddingCache::new(in_temp("thresh_cache.json")).unwrap();

    let embeddings: HashMap<String, Vec<f32>> = vec![
        ("high_sim.txt", vec![0.9, 0.435]),
        ("low_sim.txt", vec![0.1, 0.995]),
        ("medium_sim.txt", vec![0.5, 0.866]),
    ]
    .into_iter()
    .map(|(name, vector)| (name.to_string(), vector))
    .collect();

    let db = VectorDB::new_test(db_path, embeddings, cache, None, EmbeddingModelType::Onnx);
    let mut model = EmbeddingModel::new_mock(Box::new(mock_provider));
    let mut snippet_extractor = SnippetExtractor::new();
    let results =
        search_with_limit(&db, &mut model, &mut snippet_extractor, "query", 10).unwrap();

    let returned = |name: &str| results.iter().any(|r| r.file_path == name);
    assert!(
        returned("high_sim.txt"),
        "High similarity file should be present"
    );
    assert!(
        returned("medium_sim.txt"),
        "Medium similarity file should be present"
    );
    assert!(
        !returned("low_sim.txt"),
        "Low similarity file should be filtered out by threshold"
    );
    assert_eq!(
        results.len(),
        2,
        "Only results above threshold should remain"
    );
}
/// Builds a BM25 index over three small documents and checks the document
/// statistics and the relative ordering of a few IDF values.
#[test]
fn test_bm25_index_building() {
    let content_map: HashMap<String, String> = vec![
        ("doc1.txt", "the quick brown fox"),
        ("doc2.txt", "jumps over the lazy fox"),
        ("doc3.txt", "the lazy dog"),
    ]
    .into_iter()
    .map(|(name, text)| (name.to_string(), text.to_string()))
    .collect();
    let file_names: Vec<String> = content_map.keys().cloned().collect();

    // The first database only serves to write the documents to disk; it is
    // given no embeddings at all.
    let (_db, fixture_dir, db_path) =
        setup_mock_db_with_content(content_map, HashMap::new(), 4, false);
    let temp_path = fixture_dir.path();

    // Second database: one zero vector per document, keyed by full path.
    let mut adjusted_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
    for fname in &file_names {
        let full_path = temp_path.join(fname).to_string_lossy().into_owned();
        adjusted_embeddings.insert(full_path, vec![0.0; 4]);
    }
    let cache_path = temp_path
        .join("bm25_cache.json")
        .to_string_lossy()
        .to_string();
    let cache = EmbeddingCache::new(cache_path).unwrap();
    let db_for_bm25 = VectorDB::new_test(
        db_path,
        adjusted_embeddings,
        cache,
        None,
        EmbeddingModelType::Onnx,
    );

    let bm25_result = build_bm25_index(&db_for_bm25);
    assert!(
        bm25_result.is_ok(),
        "BM25 index build failed: {:?}",
        bm25_result.err()
    );
    let bm25_index = bm25_result.unwrap();

    // Corpus-level statistics.
    assert_eq!(bm25_index.total_docs, 3);
    assert_eq!(bm25_index.doc_data.len(), 3);
    assert!(bm25_index.avg_doc_length > 0.0);

    // Terms that occur in the documents must have an IDF entry.
    assert!(bm25_index.idf.contains_key("fox"));
    assert!(bm25_index.idf.contains_key("lazy"));
    assert!(bm25_index.idf.contains_key("the"));

    // Terms found in fewer documents must score a higher IDF.
    assert!(bm25_index.idf["quick"] > bm25_index.idf["fox"]);
    assert!(bm25_index.idf["dog"] > bm25_index.idf["lazy"]);
    assert!(bm25_index.idf["lazy"] > bm25_index.idf["the"]);
}
/// Runs several BM25 queries against a three-document index and checks which
/// documents come back and that scores are returned in non-increasing order.
#[test]
fn test_bm25_search() {
    let content_map: HashMap<String, String> = vec![
        ("bm_doc1.txt", "search algorithms are fun"),
        ("bm_doc2.txt", "fun search index test"),
        ("bm_doc3.txt", "another test document"),
    ]
    .into_iter()
    .map(|(name, text)| (name.to_string(), text.to_string()))
    .collect();
    let file_names: Vec<String> = content_map.keys().cloned().collect();

    // The first database only serves to write the documents to disk; it is
    // given no embeddings at all.
    let (_db, fixture_dir, db_path) =
        setup_mock_db_with_content(content_map, HashMap::new(), 4, false);
    let temp_path = fixture_dir.path();

    // Second database: one zero vector per document, keyed by full path.
    let mut adjusted_embeddings: HashMap<String, Vec<f32>> = HashMap::new();
    for fname in &file_names {
        let full_path = temp_path.join(fname).to_string_lossy().into_owned();
        adjusted_embeddings.insert(full_path, vec![0.0; 4]);
    }
    let cache_path = temp_path
        .join("bm25_s_cache.json")
        .to_string_lossy()
        .to_string();
    let cache = EmbeddingCache::new(cache_path).unwrap();
    let db_for_bm25 = VectorDB::new_test(
        db_path,
        adjusted_embeddings,
        cache,
        None,
        EmbeddingModelType::Onnx,
    );
    let bm25_index = build_bm25_index(&db_for_bm25).unwrap();

    // Two-term query: both documents containing "fun" and "search" match.
    let results1 = search_bm25_top_k("fun search", &bm25_index, 10).unwrap();
    assert_eq!(results1.len(), 2);
    assert!(
        results1.iter().any(|hit| hit.0.ends_with("bm_doc1.txt")),
        "bm_doc1 should be present for 'fun search'"
    );
    assert!(
        results1.iter().any(|hit| hit.0.ends_with("bm_doc2.txt")),
        "bm_doc2 should be present for 'fun search'"
    );
    assert!(
        results1[0].1 >= results1[1].1,
        "Scores should be non-increasing for 'fun search'"
    );

    // Term that occurs in a single document.
    let results2 = search_bm25_top_k("algorithms", &bm25_index, 10).unwrap();
    assert_eq!(results2.len(), 1);
    assert!(results2[0].0.ends_with("bm_doc1.txt"));

    // Term shared by the second and third documents.
    let results3 = search_bm25_top_k("test", &bm25_index, 10).unwrap();
    assert_eq!(results3.len(), 2);
    assert!(
        results3.iter().any(|hit| hit.0.ends_with("bm_doc2.txt")),
        "bm_doc2 should be present for 'test'"
    );
    assert!(
        results3.iter().any(|hit| hit.0.ends_with("bm_doc3.txt")),
        "bm_doc3 should be present for 'test'"
    );
    assert!(
        results3[0].1 >= results3[1].1,
        "Scores should be non-increasing for 'test'"
    );

    // Empty and unknown queries yield nothing.
    let results4 = search_bm25_top_k("", &bm25_index, 10).unwrap();
    assert!(results4.is_empty());
    let results5 = search_bm25_top_k("nonexistent term", &bm25_index, 10).unwrap();
    assert!(results5.is_empty());
}