pub mod hnsw;
mod rrf;
pub use hnsw::{HnswConfig, HnswIndex, HnswResult};
pub use rrf::{RrfConfig, reciprocal_rank_fusion, weighted_rrf};
use crate::embedding::{Embedder, cosine_similarity};
use crate::error::Result;
use crate::storage::{SqliteStorage, Storage};
pub const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.3;
pub const DEFAULT_TOP_K: usize = 10;
/// A single hit returned by the search entry points in this module.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Rowid of the matching chunk.
    pub chunk_id: i64,
    /// Rowid of the buffer that owns the chunk.
    pub buffer_id: i64,
    /// Position of the chunk within its buffer.
    pub index: usize,
    /// Overall ranking score (RRF score in hybrid mode, otherwise the
    /// single active channel's score).
    pub score: f64,
    /// Cosine similarity from the semantic channel, when it ran.
    pub semantic_score: Option<f32>,
    /// BM25 score from the full-text channel, when it ran.
    pub bm25_score: Option<f64>,
    /// Truncated chunk content, filled in by [`populate_previews`].
    pub content_preview: Option<String>,
}
/// Tunable parameters for [`hybrid_search`] and its convenience wrappers.
#[derive(Debug, Clone)]
pub struct SearchConfig {
    /// Maximum number of results to return.
    pub top_k: usize,
    /// Minimum cosine similarity for a semantic hit to be kept.
    pub similarity_threshold: f32,
    /// The `k` constant used by reciprocal rank fusion.
    pub rrf_k: u32,
    /// Whether the embedding-based (semantic) channel participates.
    pub use_semantic: bool,
    /// Whether the BM25 full-text channel participates.
    pub use_bm25: bool,
}
impl Default for SearchConfig {
    /// Both channels enabled, module-level default top-k/threshold, and the
    /// conventional RRF constant of 60.
    fn default() -> Self {
        Self {
            use_semantic: true,
            use_bm25: true,
            rrf_k: 60,
            top_k: DEFAULT_TOP_K,
            similarity_threshold: DEFAULT_SIMILARITY_THRESHOLD,
        }
    }
}
pub const DEFAULT_PREVIEW_LEN: usize = 150;
impl SearchResult {
    /// Fetch the chunk's metadata and assemble a result record.
    ///
    /// Returns `None` when the chunk cannot be retrieved — either the
    /// storage call failed or no row exists for `chunk_id`. The preview is
    /// left unset; see [`populate_previews`].
    fn from_chunk_id(
        storage: &SqliteStorage,
        chunk_id: i64,
        score: f64,
        semantic_score: Option<f32>,
        bm25_score: Option<f64>,
    ) -> Option<Self> {
        // Storage errors and missing rows are both collapsed into None.
        let chunk = storage.get_chunk(chunk_id).ok().flatten()?;
        Some(Self {
            chunk_id,
            buffer_id: chunk.buffer_id,
            index: chunk.index,
            score,
            semantic_score,
            bm25_score,
            content_preview: None,
        })
    }
}
/// Attach a content preview of at most `preview_len` bytes (cut on a char
/// boundary, with a `"..."` suffix when truncated) to each result in place.
///
/// Results whose chunk no longer exists in storage are left without a
/// preview.
///
/// # Errors
/// Propagates storage lookup failures.
pub fn populate_previews(
    storage: &SqliteStorage,
    results: &mut [SearchResult],
    preview_len: usize,
) -> Result<()> {
    for result in results.iter_mut() {
        let Some(chunk) = storage.get_chunk(result.chunk_id)? else {
            continue;
        };
        let content = &chunk.content;
        let preview = if content.len() <= preview_len {
            content.clone()
        } else {
            // Never slice mid-codepoint: back off to the nearest boundary.
            let cut = crate::io::find_char_boundary(content, preview_len);
            let mut truncated = content[..cut].to_string();
            if cut < content.len() {
                truncated.push_str("...");
            }
            truncated
        };
        result.content_preview = Some(preview);
    }
    Ok(())
}
impl SearchConfig {
    /// Equivalent to `Self::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy with the maximum result count replaced.
    #[must_use]
    pub const fn with_top_k(self, top_k: usize) -> Self {
        Self { top_k, ..self }
    }

    /// Returns a copy with the semantic similarity threshold replaced.
    #[must_use]
    pub const fn with_threshold(self, threshold: f32) -> Self {
        Self {
            similarity_threshold: threshold,
            ..self
        }
    }

    /// Returns a copy with the reciprocal-rank-fusion constant replaced.
    #[must_use]
    pub const fn with_rrf_k(self, k: u32) -> Self {
        Self { rrf_k: k, ..self }
    }

    /// Returns a copy with the semantic channel toggled.
    #[must_use]
    pub const fn with_semantic(self, enabled: bool) -> Self {
        Self {
            use_semantic: enabled,
            ..self
        }
    }

    /// Returns a copy with the BM25 channel toggled.
    #[must_use]
    pub const fn with_bm25(self, enabled: bool) -> Self {
        Self {
            use_bm25: enabled,
            ..self
        }
    }
}
/// Run the enabled search channels for `query` and merge their rankings.
///
/// With both channels enabled the two rankings are combined via reciprocal
/// rank fusion; with exactly one channel enabled, that channel's score is
/// used directly as the overall score. Disabling both channels yields an
/// empty result set.
///
/// # Errors
/// Propagates embedding and storage failures.
pub fn hybrid_search(
    storage: &SqliteStorage,
    embedder: &dyn Embedder,
    query: &str,
    config: &SearchConfig,
) -> Result<Vec<SearchResult>> {
    let semantic_results: Vec<(i64, f32)> = if config.use_semantic {
        semantic_search(storage, embedder, query, config)?
    } else {
        Vec::new()
    };
    // Over-fetch BM25 candidates so fusion has more material to work with.
    let bm25_results: Vec<(i64, f64)> = if config.use_bm25 {
        storage.search_fts(query, config.top_k * 2)?
    } else {
        Vec::new()
    };
    // Single-channel fast paths: no fusion needed.
    if !config.use_semantic {
        let mut out = Vec::new();
        for (chunk_id, score) in bm25_results.into_iter().take(config.top_k) {
            if let Some(hit) =
                SearchResult::from_chunk_id(storage, chunk_id, score, None, Some(score))
            {
                out.push(hit);
            }
        }
        return Ok(out);
    }
    if !config.use_bm25 {
        let mut out = Vec::new();
        for (chunk_id, score) in semantic_results.into_iter().take(config.top_k) {
            if let Some(hit) =
                SearchResult::from_chunk_id(storage, chunk_id, f64::from(score), Some(score), None)
            {
                out.push(hit);
            }
        }
        return Ok(out);
    }
    // Both channels ran: fuse the rank lists, then re-attach the per-channel
    // scores from lookup maps.
    let rrf_config = RrfConfig::new(config.rrf_k);
    let semantic_ranked: Vec<i64> = semantic_results.iter().map(|&(id, _)| id).collect();
    let bm25_ranked: Vec<i64> = bm25_results.iter().map(|&(id, _)| id).collect();
    let fused = reciprocal_rank_fusion(&[&semantic_ranked, &bm25_ranked], &rrf_config);
    let semantic_map: std::collections::HashMap<i64, f32> = semantic_results.into_iter().collect();
    let bm25_map: std::collections::HashMap<i64, f64> = bm25_results.into_iter().collect();
    Ok(fused
        .into_iter()
        .take(config.top_k)
        .filter_map(|(chunk_id, rrf_score)| {
            SearchResult::from_chunk_id(
                storage,
                chunk_id,
                rrf_score,
                semantic_map.get(&chunk_id).copied(),
                bm25_map.get(&chunk_id).copied(),
            )
        })
        .collect())
}
/// Rank every stored chunk embedding by cosine similarity to `query`.
///
/// Returns up to `config.top_k * 2` `(chunk_id, similarity)` pairs, best
/// first — the surplus gives the RRF fusion step in [`hybrid_search`] more
/// candidates. Pairs below `config.similarity_threshold` are dropped.
///
/// # Errors
/// Propagates embedding and storage failures.
fn semantic_search(
    storage: &SqliteStorage,
    embedder: &dyn Embedder,
    query: &str,
    config: &SearchConfig,
) -> Result<Vec<(i64, f32)>> {
    let query_embedding = embedder.embed(query)?;
    let all_embeddings = storage.get_all_embeddings()?;
    if all_embeddings.is_empty() {
        return Ok(Vec::new());
    }
    let mut similarities: Vec<(i64, f32)> = all_embeddings
        .iter()
        .map(|(chunk_id, embedding)| (*chunk_id, cosine_similarity(&query_embedding, embedding)))
        // A NaN similarity (possible for degenerate vectors) fails `>=` and
        // is filtered out here.
        .filter(|(_, sim)| *sim >= config.similarity_threshold)
        .collect();
    // `total_cmp` provides a total order over f32, so the comparator cannot
    // violate sort invariants even if a NaN ever slipped through; descending
    // by similarity, stable sort preserves insertion order on ties.
    similarities.sort_by(|a, b| b.1.total_cmp(&a.1));
    similarities.truncate(config.top_k * 2);
    Ok(similarities)
}
/// Semantic-only convenience wrapper around [`hybrid_search`].
///
/// # Errors
/// Propagates embedding and storage failures.
pub fn search_semantic(
    storage: &SqliteStorage,
    embedder: &dyn Embedder,
    query: &str,
    top_k: usize,
    threshold: f32,
) -> Result<Vec<SearchResult>> {
    let config = SearchConfig::new()
        .with_bm25(false)
        .with_semantic(true)
        .with_threshold(threshold)
        .with_top_k(top_k);
    hybrid_search(storage, embedder, query, &config)
}
/// Full-text (BM25) search over the chunk index, bypassing embeddings.
///
/// # Errors
/// Propagates storage failures.
pub fn search_bm25(
    storage: &SqliteStorage,
    query: &str,
    top_k: usize,
) -> Result<Vec<SearchResult>> {
    let hits = storage.search_fts(query, top_k)?;
    let mut out = Vec::with_capacity(hits.len());
    for (chunk_id, score) in hits {
        // Chunks that vanished between the FTS query and the lookup are
        // silently dropped.
        if let Some(hit) = SearchResult::from_chunk_id(storage, chunk_id, score, None, Some(score))
        {
            out.push(hit);
        }
    }
    Ok(out)
}
/// Embed every chunk of `buffer_id` in one batch and persist the vectors,
/// tagged with the embedder's model name.
///
/// Returns the number of embeddings stored; chunks without a rowid are
/// skipped.
///
/// # Errors
/// Propagates embedding and storage failures.
pub fn embed_buffer_chunks(
    storage: &mut SqliteStorage,
    embedder: &dyn Embedder,
    buffer_id: i64,
) -> Result<usize> {
    let chunks = storage.get_chunks(buffer_id)?;
    if chunks.is_empty() {
        return Ok(0);
    }
    let texts: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
    let embeddings = embedder.embed_batch(&texts)?;
    let mut batch: Vec<(i64, Vec<f32>)> = Vec::with_capacity(chunks.len());
    for (chunk, embedding) in chunks.iter().zip(embeddings) {
        if let Some(id) = chunk.id {
            batch.push((id, embedding));
        }
    }
    let stored = batch.len();
    storage.store_embeddings_batch(&batch, Some(embedder.model_name()))?;
    Ok(stored)
}
/// True when every chunk of `buffer_id` has a stored embedding.
/// A buffer with no chunks counts as fully embedded.
///
/// # Errors
/// Propagates storage failures.
pub fn buffer_fully_embedded(storage: &SqliteStorage, buffer_id: i64) -> Result<bool> {
    let chunk_count = storage.chunk_count(buffer_id)?;
    if chunk_count == 0 {
        return Ok(true);
    }
    let chunks = storage.get_chunks(buffer_id)?;
    let mut embedded = 0;
    for chunk in &chunks {
        if let Some(id) = chunk.id {
            if storage.has_embedding(id)? {
                embedded += 1;
            }
        }
    }
    Ok(embedded == chunk_count)
}
/// Return the first embedding model recorded for `buffer_id` that differs
/// from `current_model`, or `None` when there is no mismatch (including when
/// no embeddings exist at all).
///
/// # Errors
/// Propagates storage failures.
pub fn check_model_mismatch(
    storage: &SqliteStorage,
    buffer_id: i64,
    current_model: &str,
) -> Result<Option<String>> {
    let models = storage.get_embedding_models(buffer_id)?;
    // `find` yields None on an empty list, so no separate empty check is
    // needed; it returns the first differing model, matching the original
    // loop's behavior.
    Ok(models.into_iter().find(|model| model != current_model))
}
/// Summary of which embedding models produced a buffer's stored vectors.
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    /// `(model name, embedding count)` pairs; a `None` name means the model
    /// was not recorded for those embeddings.
    pub models: Vec<(Option<String>, i64)>,
    /// Total number of embeddings across all models.
    pub total_embeddings: i64,
    /// True when more than one distinct model name is present.
    pub has_mixed_models: bool,
}
/// Gather per-model embedding counts for `buffer_id` and flag whether the
/// buffer mixes embeddings from more than one model.
///
/// # Errors
/// Propagates storage failures.
pub fn get_embedding_model_info(
    storage: &SqliteStorage,
    buffer_id: i64,
) -> Result<EmbeddingModelInfo> {
    let models = storage.get_embedding_model_counts(buffer_id)?;
    // Single pass: accumulate the total and the set of distinct model names.
    let mut total_embeddings: i64 = 0;
    let mut distinct_models = std::collections::HashSet::new();
    for (name, count) in &models {
        total_embeddings += count;
        distinct_models.insert(name.as_deref());
    }
    Ok(EmbeddingModelInfo {
        has_mixed_models: distinct_models.len() > 1,
        total_embeddings,
        models,
    })
}
/// Outcome of one incremental (re-)embedding pass over a buffer.
#[derive(Debug, Clone)]
pub struct IncrementalEmbedResult {
    /// Chunks embedded for the first time in this pass.
    pub embedded_count: usize,
    /// Chunks left untouched because they already had an embedding.
    pub skipped_count: usize,
    /// Chunks whose existing embedding was overwritten in this pass.
    pub replaced_count: usize,
    /// Total number of chunks in the buffer.
    pub total_chunks: usize,
    /// Name of the embedding model used for this pass.
    pub model_name: String,
}
impl IncrementalEmbedResult {
    /// True when the pass wrote anything — either new or replaced
    /// embeddings.
    #[must_use]
    pub const fn had_changes(&self) -> bool {
        self.replaced_count > 0 || self.embedded_count > 0
    }

    /// How much of the buffer is embedded after this pass, as a percentage
    /// in `[0, 100]`. An empty buffer counts as 100% complete.
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn completion_percentage(&self) -> f64 {
        if self.total_chunks == 0 {
            return 100.0;
        }
        let completed = self.embedded_count + self.skipped_count + self.replaced_count;
        (completed as f64 / self.total_chunks as f64) * 100.0
    }
}
/// Embed only the chunks of `buffer_id` that still need it, leaving
/// already-embedded chunks untouched.
///
/// With `force_reembed`, the current model name is passed to storage so that
/// chunks embedded under a different model are also selected — presumably
/// `get_chunks_needing_embedding` treats a model mismatch as "needs
/// embedding"; TODO confirm against its implementation.
///
/// # Errors
/// Propagates storage and embedding failures.
pub fn embed_buffer_chunks_incremental(
    storage: &mut SqliteStorage,
    embedder: &dyn Embedder,
    buffer_id: i64,
    force_reembed: bool,
) -> Result<IncrementalEmbedResult> {
    let current_model = embedder.model_name();
    let stats = storage.get_embedding_stats(buffer_id)?;
    let total_chunks = stats.total_chunks;
    let model_to_check = if force_reembed {
        Some(current_model)
    } else {
        None
    };
    let chunk_ids_to_embed = storage.get_chunks_needing_embedding(buffer_id, model_to_check)?;
    if chunk_ids_to_embed.is_empty() {
        // Nothing pending: report every chunk as skipped.
        return Ok(IncrementalEmbedResult {
            embedded_count: 0,
            skipped_count: total_chunks,
            replaced_count: 0,
            total_chunks,
            model_name: current_model.to_string(),
        });
    }
    let all_chunks = storage.get_chunks(buffer_id)?;
    let chunks_to_embed: Vec<_> = all_chunks
        .iter()
        .filter(|c| c.id.is_some_and(|id| chunk_ids_to_embed.contains(&id)))
        .collect();
    // Count pending chunks that already carry an embedding: their vectors
    // will be overwritten rather than newly created.
    let mut replaced_count = 0;
    for chunk in &chunks_to_embed {
        if let Some(id) = chunk.id
            && storage.has_embedding(id)?
        {
            replaced_count += 1;
        }
    }
    let texts: Vec<&str> = chunks_to_embed.iter().map(|c| c.content.as_str()).collect();
    let embeddings = embedder.embed_batch(&texts)?;
    let batch: Vec<(i64, Vec<f32>)> = chunks_to_embed
        .iter()
        .zip(embeddings)
        .filter_map(|(chunk, embedding)| chunk.id.map(|id| (id, embedding)))
        .collect();
    let embedded_count = batch.len();
    storage.store_embeddings_batch(&batch, Some(current_model))?;
    // `embedded_count` covers both fresh and replacement writes; split it so
    // the reported embedded_count means "newly embedded" only.
    let new_embeddings = embedded_count - replaced_count;
    let skipped_count = total_chunks - embedded_count;
    Ok(IncrementalEmbedResult {
        embedded_count: new_embeddings,
        skipped_count,
        replaced_count,
        total_chunks,
        model_name: current_model.to_string(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Buffer, Chunk};
    use crate::embedding::{DEFAULT_DIMENSIONS, FallbackEmbedder};
    use crate::storage::Storage;

    // Fresh, initialized in-memory database.
    fn setup_storage() -> SqliteStorage {
        let mut storage = SqliteStorage::in_memory().unwrap();
        storage.init().unwrap();
        storage
    }

    // In-memory database pre-loaded with one buffer split into three chunks.
    fn setup_storage_with_chunks() -> SqliteStorage {
        let mut storage = setup_storage();
        let buffer = Buffer::from_named(
            "test.txt".to_string(),
            "Test content for searching".to_string(),
        );
        let buffer_id = storage.add_buffer(&buffer).unwrap();
        let chunks = vec![
            Chunk::new(
                buffer_id,
                "The quick brown fox jumps over the lazy dog".to_string(),
                0..44,
                0,
            ),
            Chunk::new(
                buffer_id,
                "Machine learning is a subset of artificial intelligence".to_string(),
                44..100,
                1,
            ),
            Chunk::new(
                buffer_id,
                "Rust is a systems programming language".to_string(),
                100..139,
                2,
            ),
        ];
        storage.add_chunks(buffer_id, &chunks).unwrap();
        storage
    }

    // Default config: both channels on, standard constants.
    #[test]
    fn test_search_config_default() {
        let config = SearchConfig::default();
        assert_eq!(config.top_k, DEFAULT_TOP_K);
        assert!((config.similarity_threshold - DEFAULT_SIMILARITY_THRESHOLD).abs() < f32::EPSILON);
        assert_eq!(config.rrf_k, 60);
        assert!(config.use_semantic);
        assert!(config.use_bm25);
    }

    // Every builder method sets its field.
    #[test]
    fn test_search_config_builder() {
        let config = SearchConfig::new()
            .with_top_k(20)
            .with_threshold(0.5)
            .with_rrf_k(30)
            .with_semantic(false)
            .with_bm25(true);
        assert_eq!(config.top_k, 20);
        assert!((config.similarity_threshold - 0.5).abs() < f32::EPSILON);
        assert_eq!(config.rrf_k, 30);
        assert!(!config.use_semantic);
        assert!(config.use_bm25);
    }

    // BM25-only results carry a bm25_score and no semantic_score.
    #[test]
    fn test_search_bm25() {
        let storage = setup_storage_with_chunks();
        let results = search_bm25(&storage, "fox", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].bm25_score.is_some());
        assert!(results[0].semantic_score.is_none());
    }

    // A term absent from every chunk yields nothing.
    #[test]
    fn test_search_bm25_no_results() {
        let storage = setup_storage_with_chunks();
        let results = search_bm25(&storage, "xyz123nonexistent", 10).unwrap();
        assert!(results.is_empty());
    }

    // All three chunks get embedded in one batch.
    #[test]
    fn test_embed_buffer_chunks() {
        let mut storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        let count = embed_buffer_chunks(&mut storage, &embedder, 1).unwrap();
        assert_eq!(count, 3);
    }

    // A chunkless buffer embeds zero chunks without error.
    #[test]
    fn test_embed_buffer_chunks_empty() {
        let mut storage = setup_storage();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        let buffer = Buffer::from_named("empty.txt".to_string(), String::new());
        let buffer_id = storage.add_buffer(&buffer).unwrap();
        let count = embed_buffer_chunks(&mut storage, &embedder, buffer_id).unwrap();
        assert_eq!(count, 0);
    }

    // A buffer with no chunks is vacuously fully embedded.
    #[test]
    fn test_buffer_fully_embedded_empty() {
        let mut storage = setup_storage();
        let buffer = Buffer::from_named("empty.txt".to_string(), String::new());
        let buffer_id = storage.add_buffer(&buffer).unwrap();
        let result = buffer_fully_embedded(&storage, buffer_id).unwrap();
        assert!(result);
    }

    // Flips from false to true once all chunks are embedded.
    #[test]
    fn test_buffer_fully_embedded_with_embeddings() {
        let mut storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        let result = buffer_fully_embedded(&storage, 1).unwrap();
        assert!(!result);
        embed_buffer_chunks(&mut storage, &embedder, 1).unwrap();
        let result = buffer_fully_embedded(&storage, 1).unwrap();
        assert!(result);
    }

    // BM25-only hybrid path: scores come from the BM25 channel alone.
    #[test]
    fn test_hybrid_search_bm25_only() {
        let storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        let config = SearchConfig::new().with_semantic(false).with_bm25(true);
        let results = hybrid_search(&storage, &embedder, "programming", &config).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].bm25_score.is_some());
        assert!(results[0].semantic_score.is_none());
    }

    // Semantic-only hybrid path (threshold 0 so everything qualifies).
    #[test]
    fn test_hybrid_search_semantic_only() {
        let mut storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        embed_buffer_chunks(&mut storage, &embedder, 1).unwrap();
        let config = SearchConfig::new()
            .with_semantic(true)
            .with_bm25(false)
            .with_threshold(0.0);
        let results = hybrid_search(&storage, &embedder, "programming language", &config).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].semantic_score.is_some());
        assert!(results[0].bm25_score.is_none());
    }

    // Both channels on: the RRF fusion path produces results.
    #[test]
    fn test_hybrid_search_both() {
        let mut storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        embed_buffer_chunks(&mut storage, &embedder, 1).unwrap();
        let config = SearchConfig::new()
            .with_semantic(true)
            .with_bm25(true)
            .with_threshold(0.0);
        let results = hybrid_search(&storage, &embedder, "programming", &config).unwrap();
        assert!(!results.is_empty());
    }

    // The search_semantic wrapper never attaches BM25 scores.
    #[test]
    fn test_search_semantic() {
        let mut storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        embed_buffer_chunks(&mut storage, &embedder, 1).unwrap();
        let results = search_semantic(&storage, &embedder, "test query", 10, 0.0).unwrap();
        for result in &results {
            assert!(result.semantic_score.is_some());
            assert!(result.bm25_score.is_none());
        }
    }

    // No embeddings stored yet: semantic search finds nothing.
    #[test]
    fn test_search_semantic_empty_embeddings() {
        let storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        let results = search_semantic(&storage, &embedder, "test query", 10, 0.5).unwrap();
        assert!(results.is_empty());
    }

    // First pass embeds everything; a second pass skips everything.
    #[test]
    fn test_incremental_embed_new_chunks() {
        let mut storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        let result = embed_buffer_chunks_incremental(&mut storage, &embedder, 1, false).unwrap();
        assert_eq!(result.embedded_count, 3);
        assert_eq!(result.skipped_count, 0);
        assert_eq!(result.replaced_count, 0);
        assert_eq!(result.total_chunks, 3);
        assert!(result.had_changes());
        let result2 = embed_buffer_chunks_incremental(&mut storage, &embedder, 1, false).unwrap();
        assert_eq!(result2.embedded_count, 0);
        assert_eq!(result2.skipped_count, 3);
        assert_eq!(result2.replaced_count, 0);
        assert!(!result2.had_changes());
    }

    // Forcing with the same model re-embeds nothing (no model mismatch).
    #[test]
    fn test_incremental_embed_force_reembed() {
        let mut storage = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);
        embed_buffer_chunks_incremental(&mut storage, &embedder, 1, false).unwrap();
        let result = embed_buffer_chunks_incremental(&mut storage, &embedder, 1, true).unwrap();
        assert_eq!(result.skipped_count, 3);
        assert!(!result.had_changes());
    }

    // 2 embedded + 3 skipped out of 5 total = 100% complete.
    #[test]
    fn test_incremental_embed_result_completion() {
        let result = IncrementalEmbedResult {
            embedded_count: 2,
            skipped_count: 3,
            replaced_count: 0,
            total_chunks: 5,
            model_name: "test".to_string(),
        };
        assert!(result.had_changes());
        assert!((result.completion_percentage() - 100.0).abs() < f64::EPSILON);
    }
}