1pub mod hnsw;
13mod rrf;
14
15pub use hnsw::{HnswConfig, HnswIndex, HnswResult};
16pub use rrf::{RrfConfig, reciprocal_rank_fusion, weighted_rrf};
17
18use crate::embedding::{Embedder, cosine_similarity};
19use crate::error::Result;
20use crate::storage::{SqliteStorage, Storage};
21
/// Default minimum cosine similarity a chunk must reach to be kept as a
/// semantic hit.
///
/// This is the value `SearchConfig::default()` assigns to
/// `similarity_threshold`.
pub const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.3;

/// Default maximum number of results a search returns.
///
/// This is the value `SearchConfig::default()` assigns to `top_k`.
pub const DEFAULT_TOP_K: usize = 10;
27
/// A single hit produced by one of the search functions in this module.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Identifier of the matching chunk.
    pub chunk_id: i64,
    /// Identifier of the buffer the chunk belongs to.
    pub buffer_id: i64,
    /// The chunk's `index` field as stored alongside the chunk.
    pub index: usize,
    /// Score the results are ranked by: the fused RRF score in hybrid mode,
    /// otherwise the raw score of the single search that ran.
    pub score: f64,
    /// Cosine similarity from the semantic search, when that search
    /// contributed this chunk.
    pub semantic_score: Option<f32>,
    /// Score from the full-text (BM25) search, when that search contributed
    /// this chunk.
    pub bm25_score: Option<f64>,
    /// Optional content preview; `None` until `populate_previews` fills it.
    pub content_preview: Option<String>,
}
46
47#[derive(Debug, Clone)]
49pub struct SearchConfig {
50 pub top_k: usize,
52 pub similarity_threshold: f32,
54 pub rrf_k: u32,
56 pub use_semantic: bool,
58 pub use_bm25: bool,
60}
61
62impl Default for SearchConfig {
63 fn default() -> Self {
64 Self {
65 top_k: DEFAULT_TOP_K,
66 similarity_threshold: DEFAULT_SIMILARITY_THRESHOLD,
67 rrf_k: 60,
68 use_semantic: true,
69 use_bm25: true,
70 }
71 }
72}
73
/// Default preview length.
///
/// NOTE(review): nothing in this file reads this constant; it is presumably
/// meant as the `preview_len` argument of [`populate_previews`], which
/// compares it against the content's byte length — confirm against callers.
pub const DEFAULT_PREVIEW_LEN: usize = 150;
76
77impl SearchResult {
78 fn from_chunk_id(
82 storage: &SqliteStorage,
83 chunk_id: i64,
84 score: f64,
85 semantic_score: Option<f32>,
86 bm25_score: Option<f64>,
87 ) -> Option<Self> {
88 storage
89 .get_chunk(chunk_id)
90 .ok()
91 .flatten()
92 .map(|chunk| Self {
93 chunk_id,
94 buffer_id: chunk.buffer_id,
95 index: chunk.index,
96 score,
97 semantic_score,
98 bm25_score,
99 content_preview: None,
100 })
101 }
102}
103
104pub fn populate_previews(
116 storage: &SqliteStorage,
117 results: &mut [SearchResult],
118 preview_len: usize,
119) -> Result<()> {
120 for result in results.iter_mut() {
121 if let Some(chunk) = storage.get_chunk(result.chunk_id)? {
122 let content = &chunk.content;
123 let preview = if content.len() <= preview_len {
124 content.clone()
125 } else {
126 let end = crate::io::find_char_boundary(content, preview_len);
128 let mut preview = content[..end].to_string();
129 if end < content.len() {
130 preview.push_str("...");
131 }
132 preview
133 };
134 result.content_preview = Some(preview);
135 }
136 }
137 Ok(())
138}
139
140impl SearchConfig {
141 #[must_use]
143 pub fn new() -> Self {
144 Self::default()
145 }
146
147 #[must_use]
149 pub const fn with_top_k(mut self, top_k: usize) -> Self {
150 self.top_k = top_k;
151 self
152 }
153
154 #[must_use]
156 pub const fn with_threshold(mut self, threshold: f32) -> Self {
157 self.similarity_threshold = threshold;
158 self
159 }
160
161 #[must_use]
163 pub const fn with_rrf_k(mut self, k: u32) -> Self {
164 self.rrf_k = k;
165 self
166 }
167
168 #[must_use]
170 pub const fn with_semantic(mut self, enabled: bool) -> Self {
171 self.use_semantic = enabled;
172 self
173 }
174
175 #[must_use]
177 pub const fn with_bm25(mut self, enabled: bool) -> Self {
178 self.use_bm25 = enabled;
179 self
180 }
181}
182
183pub fn hybrid_search(
196 storage: &SqliteStorage,
197 embedder: &dyn Embedder,
198 query: &str,
199 config: &SearchConfig,
200) -> Result<Vec<SearchResult>> {
201 let mut semantic_results: Vec<(i64, f32)> = Vec::new();
202 let mut bm25_results: Vec<(i64, f64)> = Vec::new();
203
204 if config.use_semantic {
206 semantic_results = semantic_search(storage, embedder, query, config)?;
207 }
208
209 if config.use_bm25 {
211 bm25_results = storage.search_fts(query, config.top_k * 2)?;
212 }
213
214 if !config.use_semantic {
216 return Ok(bm25_results
217 .into_iter()
218 .take(config.top_k)
219 .filter_map(|(chunk_id, score)| {
220 SearchResult::from_chunk_id(storage, chunk_id, score, None, Some(score))
221 })
222 .collect());
223 }
224
225 if !config.use_bm25 {
226 return Ok(semantic_results
227 .into_iter()
228 .take(config.top_k)
229 .filter_map(|(chunk_id, score)| {
230 SearchResult::from_chunk_id(storage, chunk_id, f64::from(score), Some(score), None)
231 })
232 .collect());
233 }
234
235 let rrf_config = RrfConfig::new(config.rrf_k);
237
238 let semantic_ranked: Vec<i64> = semantic_results.iter().map(|(id, _)| *id).collect();
240 let bm25_ranked: Vec<i64> = bm25_results.iter().map(|(id, _)| *id).collect();
241
242 let fused = reciprocal_rank_fusion(&[&semantic_ranked, &bm25_ranked], &rrf_config);
243
244 let semantic_map: std::collections::HashMap<i64, f32> = semantic_results.into_iter().collect();
246 let bm25_map: std::collections::HashMap<i64, f64> = bm25_results.into_iter().collect();
247
248 let results: Vec<SearchResult> = fused
249 .into_iter()
250 .take(config.top_k)
251 .filter_map(|(chunk_id, rrf_score)| {
252 SearchResult::from_chunk_id(
253 storage,
254 chunk_id,
255 rrf_score,
256 semantic_map.get(&chunk_id).copied(),
257 bm25_map.get(&chunk_id).copied(),
258 )
259 })
260 .collect();
261
262 Ok(results)
263}
264
265fn semantic_search(
269 storage: &SqliteStorage,
270 embedder: &dyn Embedder,
271 query: &str,
272 config: &SearchConfig,
273) -> Result<Vec<(i64, f32)>> {
274 let query_embedding = embedder.embed(query)?;
276
277 let all_embeddings = storage.get_all_embeddings()?;
279
280 if all_embeddings.is_empty() {
281 return Ok(Vec::new());
282 }
283
284 let mut similarities: Vec<(i64, f32)> = all_embeddings
286 .iter()
287 .map(|(chunk_id, embedding)| {
288 let sim = cosine_similarity(&query_embedding, embedding);
289 (*chunk_id, sim)
290 })
291 .filter(|(_, sim)| *sim >= config.similarity_threshold)
292 .collect();
293
294 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
296
297 similarities.truncate(config.top_k * 2);
299
300 Ok(similarities)
301}
302
303pub fn search_semantic(
317 storage: &SqliteStorage,
318 embedder: &dyn Embedder,
319 query: &str,
320 top_k: usize,
321 threshold: f32,
322) -> Result<Vec<SearchResult>> {
323 let config = SearchConfig::new()
324 .with_top_k(top_k)
325 .with_threshold(threshold)
326 .with_semantic(true)
327 .with_bm25(false);
328
329 hybrid_search(storage, embedder, query, &config)
330}
331
332pub fn search_bm25(
344 storage: &SqliteStorage,
345 query: &str,
346 top_k: usize,
347) -> Result<Vec<SearchResult>> {
348 let results = storage.search_fts(query, top_k)?;
349
350 Ok(results
351 .into_iter()
352 .filter_map(|(chunk_id, score)| {
353 SearchResult::from_chunk_id(storage, chunk_id, score, None, Some(score))
354 })
355 .collect())
356}
357
358pub fn embed_buffer_chunks(
374 storage: &mut SqliteStorage,
375 embedder: &dyn Embedder,
376 buffer_id: i64,
377) -> Result<usize> {
378 let chunks = storage.get_chunks(buffer_id)?;
379
380 if chunks.is_empty() {
381 return Ok(0);
382 }
383
384 let texts: Vec<&str> = chunks.iter().map(|c| c.content.as_str()).collect();
386
387 let embeddings = embedder.embed_batch(&texts)?;
389
390 let batch: Vec<(i64, Vec<f32>)> = chunks
392 .iter()
393 .zip(embeddings)
394 .filter_map(|(chunk, embedding)| chunk.id.map(|id| (id, embedding)))
395 .collect();
396
397 let count = batch.len();
398
399 storage.store_embeddings_batch(&batch, Some(embedder.model_name()))?;
401
402 Ok(count)
403}
404
405pub fn buffer_fully_embedded(storage: &SqliteStorage, buffer_id: i64) -> Result<bool> {
411 let chunk_count = storage.chunk_count(buffer_id)?;
412 if chunk_count == 0 {
413 return Ok(true);
414 }
415
416 let chunks = storage.get_chunks(buffer_id)?;
418 let mut embedded_count = 0;
419
420 for chunk in &chunks {
421 if let Some(id) = chunk.id
422 && storage.has_embedding(id)?
423 {
424 embedded_count += 1;
425 }
426 }
427
428 Ok(embedded_count == chunk_count)
429}
430
431pub fn check_model_mismatch(
439 storage: &SqliteStorage,
440 buffer_id: i64,
441 current_model: &str,
442) -> Result<Option<String>> {
443 let models = storage.get_embedding_models(buffer_id)?;
444
445 if models.is_empty() {
447 return Ok(None);
448 }
449
450 for model in models {
452 if model != current_model {
453 return Ok(Some(model));
454 }
455 }
456
457 Ok(None)
458}
459
/// Summary of which embedding models produced a buffer's embeddings.
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    /// `(model name, embedding count)` pairs as reported by storage; the name
    /// is `None` for embeddings stored without one.
    pub models: Vec<(Option<String>, i64)>,
    /// Sum of the counts in `models`.
    pub total_embeddings: i64,
    /// `true` when `models` contains more than one distinct name (with `None`
    /// counting as a name of its own).
    pub has_mixed_models: bool,
}
470
471pub fn get_embedding_model_info(
477 storage: &SqliteStorage,
478 buffer_id: i64,
479) -> Result<EmbeddingModelInfo> {
480 let models = storage.get_embedding_model_counts(buffer_id)?;
481 let total_embeddings: i64 = models.iter().map(|(_, count)| count).sum();
482 let distinct_models: std::collections::HashSet<_> =
483 models.iter().map(|(name, _)| name.as_deref()).collect();
484 let has_mixed_models = distinct_models.len() > 1;
485
486 Ok(EmbeddingModelInfo {
487 models,
488 total_embeddings,
489 has_mixed_models,
490 })
491}
492
/// Outcome of one [`embed_buffer_chunks_incremental`] run.
#[derive(Debug, Clone)]
pub struct IncrementalEmbedResult {
    /// Chunks that received an embedding and had none before.
    pub embedded_count: usize,
    /// Chunks that were left untouched.
    pub skipped_count: usize,
    /// Chunks whose existing embedding was overwritten.
    pub replaced_count: usize,
    /// Total number of chunks in the buffer.
    pub total_chunks: usize,
    /// Name of the model used for this run.
    pub model_name: String,
}

impl IncrementalEmbedResult {
    /// Returns `true` when the run wrote at least one embedding, new or
    /// replacement.
    #[must_use]
    pub const fn had_changes(&self) -> bool {
        self.embedded_count != 0 || self.replaced_count != 0
    }

    /// Percentage of the buffer's chunks accounted for by this run
    /// (embedded + skipped + replaced, relative to `total_chunks`).
    ///
    /// A buffer without chunks reports `100.0`.
    #[must_use]
    // Chunk counts are far below 2^52, so the usize -> f64 casts are exact.
    #[allow(clippy::cast_precision_loss)]
    pub fn completion_percentage(&self) -> f64 {
        if self.total_chunks == 0 {
            return 100.0;
        }

        let accounted_for = self.embedded_count + self.skipped_count + self.replaced_count;
        (accounted_for as f64 / self.total_chunks as f64) * 100.0
    }
}
527
528pub fn embed_buffer_chunks_incremental(
552 storage: &mut SqliteStorage,
553 embedder: &dyn Embedder,
554 buffer_id: i64,
555 force_reembed: bool,
556) -> Result<IncrementalEmbedResult> {
557 let current_model = embedder.model_name();
558 let stats = storage.get_embedding_stats(buffer_id)?;
559 let total_chunks = stats.total_chunks;
560
561 let model_to_check = if force_reembed {
563 Some(current_model)
564 } else {
565 None
566 };
567
568 let chunk_ids_to_embed = storage.get_chunks_needing_embedding(buffer_id, model_to_check)?;
569
570 if chunk_ids_to_embed.is_empty() {
571 return Ok(IncrementalEmbedResult {
572 embedded_count: 0,
573 skipped_count: total_chunks,
574 replaced_count: 0,
575 total_chunks,
576 model_name: current_model.to_string(),
577 });
578 }
579
580 let all_chunks = storage.get_chunks(buffer_id)?;
582 let chunks_to_embed: Vec<_> = all_chunks
583 .iter()
584 .filter(|c| c.id.is_some_and(|id| chunk_ids_to_embed.contains(&id)))
585 .collect();
586
587 let mut replaced_count = 0;
589 for chunk in &chunks_to_embed {
590 if let Some(id) = chunk.id
591 && storage.has_embedding(id)?
592 {
593 replaced_count += 1;
594 }
595 }
596
597 let texts: Vec<&str> = chunks_to_embed.iter().map(|c| c.content.as_str()).collect();
599 let embeddings = embedder.embed_batch(&texts)?;
600
601 let batch: Vec<(i64, Vec<f32>)> = chunks_to_embed
603 .iter()
604 .zip(embeddings)
605 .filter_map(|(chunk, embedding)| chunk.id.map(|id| (id, embedding)))
606 .collect();
607
608 let embedded_count = batch.len();
609 storage.store_embeddings_batch(&batch, Some(current_model))?;
610
611 let new_embeddings = embedded_count - replaced_count;
612 let skipped_count = total_chunks - embedded_count;
613
614 Ok(IncrementalEmbedResult {
615 embedded_count: new_embeddings,
616 skipped_count,
617 replaced_count,
618 total_chunks,
619 model_name: current_model.to_string(),
620 })
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Buffer, Chunk};
    use crate::embedding::{DEFAULT_DIMENSIONS, FallbackEmbedder};
    use crate::storage::Storage;

    /// Id used for the buffer created by `setup_storage_with_chunks`; the
    /// tests rely on it being the first buffer added to a fresh database.
    const FIXTURE_BUFFER_ID: i64 = 1;

    /// Opens and initialises an empty in-memory database.
    fn setup_storage() -> SqliteStorage {
        let mut db = SqliteStorage::in_memory().unwrap();
        db.init().unwrap();
        db
    }

    /// Returns a database holding one buffer with three chunks and no
    /// embeddings.
    fn setup_storage_with_chunks() -> SqliteStorage {
        let mut db = setup_storage();

        let fixture = Buffer::from_named(
            "test.txt".to_string(),
            "Test content for searching".to_string(),
        );
        let buffer_id = db.add_buffer(&fixture).unwrap();

        let fox = Chunk::new(
            buffer_id,
            "The quick brown fox jumps over the lazy dog".to_string(),
            0..44,
            0,
        );
        let machine_learning = Chunk::new(
            buffer_id,
            "Machine learning is a subset of artificial intelligence".to_string(),
            44..100,
            1,
        );
        let rust = Chunk::new(
            buffer_id,
            "Rust is a systems programming language".to_string(),
            100..139,
            2,
        );
        let chunks = vec![fox, machine_learning, rust];

        db.add_chunks(buffer_id, &chunks).unwrap();

        db
    }

    #[test]
    fn test_search_config_default() {
        let defaults = SearchConfig::default();
        assert_eq!(defaults.top_k, DEFAULT_TOP_K);
        assert!(
            (defaults.similarity_threshold - DEFAULT_SIMILARITY_THRESHOLD).abs() < f32::EPSILON
        );
        assert_eq!(defaults.rrf_k, 60);
        assert!(defaults.use_semantic);
        assert!(defaults.use_bm25);
    }

    #[test]
    fn test_search_config_builder() {
        let built = SearchConfig::new()
            .with_top_k(20)
            .with_threshold(0.5)
            .with_rrf_k(30)
            .with_semantic(false)
            .with_bm25(true);

        assert_eq!(built.top_k, 20);
        assert!((built.similarity_threshold - 0.5).abs() < f32::EPSILON);
        assert_eq!(built.rrf_k, 30);
        assert!(!built.use_semantic);
        assert!(built.use_bm25);
    }

    #[test]
    fn test_search_bm25() {
        let db = setup_storage_with_chunks();

        let hits = search_bm25(&db, "fox", 10).unwrap();

        assert!(!hits.is_empty());
        assert!(hits[0].bm25_score.is_some());
        assert!(hits[0].semantic_score.is_none());
    }

    #[test]
    fn test_search_bm25_no_results() {
        let db = setup_storage_with_chunks();

        let hits = search_bm25(&db, "xyz123nonexistent", 10).unwrap();

        assert!(hits.is_empty());
    }

    #[test]
    fn test_embed_buffer_chunks() {
        let mut db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        let stored = embed_buffer_chunks(&mut db, &embedder, FIXTURE_BUFFER_ID).unwrap();

        assert_eq!(stored, 3);
    }

    #[test]
    fn test_embed_buffer_chunks_empty() {
        let mut db = setup_storage();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        let empty = Buffer::from_named("empty.txt".to_string(), String::new());
        let buffer_id = db.add_buffer(&empty).unwrap();

        let stored = embed_buffer_chunks(&mut db, &embedder, buffer_id).unwrap();
        assert_eq!(stored, 0);
    }

    #[test]
    fn test_buffer_fully_embedded_empty() {
        let mut db = setup_storage();

        let empty = Buffer::from_named("empty.txt".to_string(), String::new());
        let buffer_id = db.add_buffer(&empty).unwrap();

        let fully_embedded = buffer_fully_embedded(&db, buffer_id).unwrap();

        assert!(fully_embedded);
    }

    #[test]
    fn test_buffer_fully_embedded_with_embeddings() {
        let mut db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        // Nothing has been embedded yet.
        let before = buffer_fully_embedded(&db, FIXTURE_BUFFER_ID).unwrap();
        assert!(!before);

        embed_buffer_chunks(&mut db, &embedder, FIXTURE_BUFFER_ID).unwrap();

        let after = buffer_fully_embedded(&db, FIXTURE_BUFFER_ID).unwrap();
        assert!(after);
    }

    #[test]
    fn test_hybrid_search_bm25_only() {
        let db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        let bm25_only = SearchConfig::new().with_semantic(false).with_bm25(true);

        let hits = hybrid_search(&db, &embedder, "programming", &bm25_only).unwrap();

        assert!(!hits.is_empty());
        assert!(hits[0].bm25_score.is_some());
        assert!(hits[0].semantic_score.is_none());
    }

    #[test]
    fn test_hybrid_search_semantic_only() {
        let mut db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        embed_buffer_chunks(&mut db, &embedder, FIXTURE_BUFFER_ID).unwrap();

        let semantic_only = SearchConfig::new()
            .with_semantic(true)
            .with_bm25(false)
            .with_threshold(0.0);

        let hits = hybrid_search(&db, &embedder, "programming language", &semantic_only).unwrap();

        assert!(!hits.is_empty());
        assert!(hits[0].semantic_score.is_some());
        assert!(hits[0].bm25_score.is_none());
    }

    #[test]
    fn test_hybrid_search_both() {
        let mut db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        embed_buffer_chunks(&mut db, &embedder, FIXTURE_BUFFER_ID).unwrap();

        let both = SearchConfig::new()
            .with_semantic(true)
            .with_bm25(true)
            .with_threshold(0.0);

        let hits = hybrid_search(&db, &embedder, "programming", &both).unwrap();

        assert!(!hits.is_empty());
    }

    #[test]
    fn test_search_semantic() {
        let mut db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        embed_buffer_chunks(&mut db, &embedder, FIXTURE_BUFFER_ID).unwrap();

        let hits = search_semantic(&db, &embedder, "test query", 10, 0.0).unwrap();

        // Only the shape of each hit is checked; the list may be empty.
        for hit in &hits {
            assert!(hit.semantic_score.is_some());
            assert!(hit.bm25_score.is_none());
        }
    }

    #[test]
    fn test_search_semantic_empty_embeddings() {
        let db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        let hits = search_semantic(&db, &embedder, "test query", 10, 0.5).unwrap();

        assert!(hits.is_empty());
    }

    #[test]
    fn test_incremental_embed_new_chunks() {
        let mut db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        let first_run =
            embed_buffer_chunks_incremental(&mut db, &embedder, FIXTURE_BUFFER_ID, false).unwrap();

        assert_eq!(first_run.embedded_count, 3);
        assert_eq!(first_run.skipped_count, 0);
        assert_eq!(first_run.replaced_count, 0);
        assert_eq!(first_run.total_chunks, 3);
        assert!(first_run.had_changes());

        // A second run finds nothing left to do.
        let second_run =
            embed_buffer_chunks_incremental(&mut db, &embedder, FIXTURE_BUFFER_ID, false).unwrap();

        assert_eq!(second_run.embedded_count, 0);
        assert_eq!(second_run.skipped_count, 3);
        assert_eq!(second_run.replaced_count, 0);
        assert!(!second_run.had_changes());
    }

    #[test]
    fn test_incremental_embed_force_reembed() {
        let mut db = setup_storage_with_chunks();
        let embedder = FallbackEmbedder::new(DEFAULT_DIMENSIONS);

        embed_buffer_chunks_incremental(&mut db, &embedder, FIXTURE_BUFFER_ID, false).unwrap();

        // Forcing with the same model still leaves every chunk untouched.
        let forced =
            embed_buffer_chunks_incremental(&mut db, &embedder, FIXTURE_BUFFER_ID, true).unwrap();

        assert_eq!(forced.skipped_count, 3);
        assert!(!forced.had_changes());
    }

    #[test]
    fn test_incremental_embed_result_completion() {
        let outcome = IncrementalEmbedResult {
            embedded_count: 2,
            skipped_count: 3,
            replaced_count: 0,
            total_chunks: 5,
            model_name: "test".to_string(),
        };

        assert!(outcome.had_changes());
        assert!((outcome.completion_percentage() - 100.0).abs() < f64::EPSILON);
    }
}