1use crate::search::VectorSearchIndex;
41use crate::types::{DistanceMetric, SearchConfig, SearchResult};
42use anyhow::{anyhow, Result};
43use serde::{Deserialize, Serialize};
44use std::collections::HashMap;
45use tracing::{debug, info};
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct Bm25Config {
50 pub k1: f32,
52 pub b: f32,
54}
55
56impl Default for Bm25Config {
57 fn default() -> Self {
58 Self { k1: 1.2, b: 0.75 }
59 }
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct HybridConfig {
65 pub alpha: f32,
67 pub metric: DistanceMetric,
69 pub bm25: Bm25Config,
71 pub rrf_k: f32,
73 pub normalize: bool,
75}
76
77impl Default for HybridConfig {
78 fn default() -> Self {
79 Self {
80 alpha: 0.5,
81 metric: DistanceMetric::Cosine,
82 bm25: Bm25Config::default(),
83 rrf_k: 60.0,
84 normalize: true,
85 }
86 }
87}
88
89impl HybridConfig {
90 pub fn vector_heavy() -> Self {
92 Self {
93 alpha: 0.7,
94 ..Default::default()
95 }
96 }
97
98 pub fn keyword_heavy() -> Self {
100 Self {
101 alpha: 0.3,
102 ..Default::default()
103 }
104 }
105}
106
107struct Bm25Index {
109 config: Bm25Config,
110 documents: HashMap<String, String>,
112 inverted_index: HashMap<String, HashMap<String, usize>>,
114 doc_lengths: HashMap<String, usize>,
116 avg_doc_length: f32,
118 num_docs: usize,
120}
121
122impl Bm25Index {
123 fn new(config: Bm25Config) -> Self {
124 Self {
125 config,
126 documents: HashMap::new(),
127 inverted_index: HashMap::new(),
128 doc_lengths: HashMap::new(),
129 avg_doc_length: 0.0,
130 num_docs: 0,
131 }
132 }
133
134 fn build(&mut self, texts: &HashMap<String, String>) {
135 self.documents = texts.clone();
136 self.num_docs = texts.len();
137
138 let mut total_length = 0;
140
141 for (entity_id, text) in texts {
142 let tokens = self.tokenize(text);
143 let doc_len = tokens.len();
144 self.doc_lengths.insert(entity_id.clone(), doc_len);
145 total_length += doc_len;
146
147 let mut term_counts: HashMap<String, usize> = HashMap::new();
149 for token in tokens {
150 *term_counts.entry(token).or_insert(0) += 1;
151 }
152
153 for (term, count) in term_counts {
155 self.inverted_index
156 .entry(term)
157 .or_default()
158 .insert(entity_id.clone(), count);
159 }
160 }
161
162 self.avg_doc_length = if self.num_docs > 0 {
163 total_length as f32 / self.num_docs as f32
164 } else {
165 0.0
166 };
167 }
168
169 fn tokenize(&self, text: &str) -> Vec<String> {
170 text.to_lowercase()
171 .split(|c: char| !c.is_alphanumeric())
172 .filter(|s| !s.is_empty() && s.len() > 1)
173 .map(|s| s.to_string())
174 .collect()
175 }
176
177 fn search(&self, query: &str, k: usize) -> Vec<(String, f32)> {
178 let query_tokens = self.tokenize(query);
179 let mut scores: HashMap<String, f32> = HashMap::new();
180
181 for token in &query_tokens {
182 if let Some(postings) = self.inverted_index.get(token) {
183 let df = postings.len() as f32;
185 let idf = ((self.num_docs as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
186
187 for (entity_id, &tf) in postings {
188 let doc_len = *self.doc_lengths.get(entity_id).unwrap_or(&1) as f32;
189 let tf_f = tf as f32;
190
191 let numerator = tf_f * (self.config.k1 + 1.0);
193 let denominator = tf_f
194 + self.config.k1
195 * (1.0 - self.config.b
196 + self.config.b * (doc_len / self.avg_doc_length));
197
198 let score = idf * (numerator / denominator);
199 *scores.entry(entity_id.clone()).or_insert(0.0) += score;
200 }
201 }
202 }
203
204 let mut results: Vec<(String, f32)> = scores.into_iter().collect();
206 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
207 results.truncate(k);
208 results
209 }
210}
211
212pub struct HybridIndex {
214 config: HybridConfig,
215 vector_index: VectorSearchIndex,
216 bm25_index: Bm25Index,
217 entity_ids: Vec<String>,
218 is_built: bool,
219}
220
221impl HybridIndex {
222 pub fn new(config: HybridConfig) -> Self {
224 info!(
225 "Initialized hybrid index: alpha={}, metric={:?}",
226 config.alpha, config.metric
227 );
228
229 let vector_config = SearchConfig {
230 metric: config.metric,
231 parallel: true,
232 normalize: config.normalize,
233 };
234
235 Self {
236 vector_index: VectorSearchIndex::new(vector_config),
237 bm25_index: Bm25Index::new(config.bm25.clone()),
238 config,
239 entity_ids: Vec::new(),
240 is_built: false,
241 }
242 }
243
244 pub fn build(
246 &mut self,
247 embeddings: &HashMap<String, Vec<f32>>,
248 texts: &HashMap<String, String>,
249 ) -> Result<()> {
250 if embeddings.is_empty() {
251 return Err(anyhow!("Cannot build index from empty embeddings"));
252 }
253
254 for entity_id in embeddings.keys() {
256 if !texts.contains_key(entity_id) {
257 return Err(anyhow!(
258 "Missing text for entity '{}'. All embeddings must have corresponding texts.",
259 entity_id
260 ));
261 }
262 }
263
264 info!("Building hybrid index for {} entities", embeddings.len());
265
266 self.entity_ids = embeddings.keys().cloned().collect();
267
268 self.vector_index.build(embeddings)?;
270
271 self.bm25_index.build(texts);
273
274 self.is_built = true;
275 info!("Hybrid index built successfully");
276 Ok(())
277 }
278
279 pub fn search(
283 &self,
284 query_vector: &[f32],
285 query_text: &str,
286 k: usize,
287 ) -> Result<Vec<HybridSearchResult>> {
288 if !self.is_built {
289 return Err(anyhow!("Index not built. Call build() first"));
290 }
291
292 debug!(
293 "Hybrid search: k={}, alpha={}, query_text='{}'",
294 k, self.config.alpha, query_text
295 );
296
297 let expanded_k = (k * 3).min(self.entity_ids.len());
299
300 let vector_results = self.vector_index.search(query_vector, expanded_k)?;
302
303 let bm25_results = self.bm25_index.search(query_text, expanded_k);
305
306 let results = self.reciprocal_rank_fusion(&vector_results, &bm25_results, k);
308
309 debug!("Hybrid search returned {} results", results.len());
310 Ok(results)
311 }
312
313 fn reciprocal_rank_fusion(
315 &self,
316 vector_results: &[SearchResult],
317 bm25_results: &[(String, f32)],
318 k: usize,
319 ) -> Vec<HybridSearchResult> {
320 let mut rrf_scores: HashMap<String, f32> = HashMap::new();
321 let mut vector_scores: HashMap<String, f32> = HashMap::new();
322 let mut bm25_scores: HashMap<String, f32> = HashMap::new();
323
324 for (rank, result) in vector_results.iter().enumerate() {
326 let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
327 *rrf_scores.entry(result.entity_id.clone()).or_insert(0.0) +=
328 self.config.alpha * rrf_score;
329 vector_scores.insert(result.entity_id.clone(), result.score);
330 }
331
332 for (rank, (entity_id, score)) in bm25_results.iter().enumerate() {
334 let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
335 *rrf_scores.entry(entity_id.clone()).or_insert(0.0) +=
336 (1.0 - self.config.alpha) * rrf_score;
337 bm25_scores.insert(entity_id.clone(), *score);
338 }
339
340 let mut results: Vec<(String, f32)> = rrf_scores.into_iter().collect();
342 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
343
344 results
346 .into_iter()
347 .take(k)
348 .enumerate()
349 .map(|(rank, (entity_id, combined_score))| HybridSearchResult {
350 entity_id: entity_id.clone(),
351 combined_score,
352 vector_score: vector_scores.get(&entity_id).copied(),
353 bm25_score: bm25_scores.get(&entity_id).copied(),
354 rank: rank + 1,
355 })
356 .collect()
357 }
358
359 pub fn weighted_search(
361 &self,
362 query_vector: &[f32],
363 query_text: &str,
364 k: usize,
365 ) -> Result<Vec<HybridSearchResult>> {
366 if !self.is_built {
367 return Err(anyhow!("Index not built. Call build() first"));
368 }
369
370 let expanded_k = (k * 3).min(self.entity_ids.len());
371
372 let vector_results = self.vector_index.search(query_vector, expanded_k)?;
374
375 let bm25_results = self.bm25_index.search(query_text, expanded_k);
377
378 let mut combined_scores: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
380
381 let max_vector_score = vector_results.first().map(|r| r.score).unwrap_or(1.0);
383 for result in &vector_results {
384 let norm_score = if max_vector_score > 0.0 {
385 result.score / max_vector_score
386 } else {
387 0.0
388 };
389 combined_scores.insert(result.entity_id.clone(), (Some(norm_score), None));
390 }
391
392 let max_bm25_score = bm25_results.first().map(|(_, s)| *s).unwrap_or(1.0);
394 for (entity_id, score) in &bm25_results {
395 let norm_score = if max_bm25_score > 0.0 {
396 score / max_bm25_score
397 } else {
398 0.0
399 };
400 combined_scores
401 .entry(entity_id.clone())
402 .and_modify(|(_, b)| *b = Some(norm_score))
403 .or_insert((None, Some(norm_score)));
404 }
405
406 let mut results: Vec<HybridSearchResult> = combined_scores
408 .into_iter()
409 .map(|(entity_id, (v_score, b_score))| {
410 let v = v_score.unwrap_or(0.0);
411 let b = b_score.unwrap_or(0.0);
412 let combined = self.config.alpha * v + (1.0 - self.config.alpha) * b;
413
414 HybridSearchResult {
415 entity_id,
416 combined_score: combined,
417 vector_score: v_score,
418 bm25_score: b_score,
419 rank: 0, }
421 })
422 .collect();
423
424 results.sort_by(|a, b| {
426 b.combined_score
427 .partial_cmp(&a.combined_score)
428 .unwrap_or(std::cmp::Ordering::Equal)
429 });
430
431 for (i, result) in results.iter_mut().enumerate() {
433 result.rank = i + 1;
434 }
435
436 results.truncate(k);
437 Ok(results)
438 }
439
440 pub fn vector_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
442 self.vector_index.search(query, k)
443 }
444
445 pub fn keyword_search(&self, query: &str, k: usize) -> Result<Vec<HybridSearchResult>> {
447 if !self.is_built {
448 return Err(anyhow!("Index not built. Call build() first"));
449 }
450
451 let results = self.bm25_index.search(query, k);
452
453 Ok(results
454 .into_iter()
455 .enumerate()
456 .map(|(rank, (entity_id, score))| HybridSearchResult {
457 entity_id,
458 combined_score: score,
459 vector_score: None,
460 bm25_score: Some(score),
461 rank: rank + 1,
462 })
463 .collect())
464 }
465
466 pub fn get_stats(&self) -> HybridStats {
468 HybridStats {
469 num_documents: self.entity_ids.len(),
470 vocabulary_size: self.bm25_index.inverted_index.len(),
471 avg_doc_length: self.bm25_index.avg_doc_length,
472 alpha: self.config.alpha,
473 is_built: self.is_built,
474 }
475 }
476
477 pub fn set_alpha(&mut self, alpha: f32) {
479 self.config.alpha = alpha.clamp(0.0, 1.0);
480 }
481}
482
483#[derive(Debug, Clone, Serialize, Deserialize)]
485pub struct HybridSearchResult {
486 pub entity_id: String,
488 pub combined_score: f32,
490 pub vector_score: Option<f32>,
492 pub bm25_score: Option<f32>,
494 pub rank: usize,
496}
497
498#[derive(Debug, Clone, Serialize, Deserialize)]
500pub struct HybridStats {
501 pub num_documents: usize,
503 pub vocabulary_size: usize,
505 pub avg_doc_length: f32,
507 pub alpha: f32,
509 pub is_built: bool,
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Five-entity fixture. doc1 and doc2 are about Rust, doc3 and doc4 are about
    /// Python, and doc5 mixes both topics. Embeddings follow the same split.
    fn create_test_data() -> (HashMap<String, Vec<f32>>, HashMap<String, String>) {
        let mut embeddings = HashMap::new();
        let mut texts = HashMap::new();

        embeddings.insert("doc1".to_string(), vec![0.9, 0.1, 0.0]);
        texts.insert(
            "doc1".to_string(),
            "rust programming language systems programming".to_string(),
        );

        embeddings.insert("doc2".to_string(), vec![0.8, 0.2, 0.0]);
        texts.insert(
            "doc2".to_string(),
            "rust cargo package manager dependencies".to_string(),
        );

        embeddings.insert("doc3".to_string(), vec![0.1, 0.9, 0.0]);
        texts.insert(
            "doc3".to_string(),
            "python machine learning deep learning neural networks".to_string(),
        );

        embeddings.insert("doc4".to_string(), vec![0.0, 0.8, 0.2]);
        texts.insert(
            "doc4".to_string(),
            "python data science pandas numpy analysis".to_string(),
        );

        embeddings.insert("doc5".to_string(), vec![0.5, 0.5, 0.0]);
        texts.insert(
            "doc5".to_string(),
            "rust machine learning inference performance".to_string(),
        );

        (embeddings, texts)
    }

    /// Hybrid index with the default configuration, built over the fixture.
    fn built_index() -> HybridIndex {
        let (embeddings, texts) = create_test_data();
        let mut index = HybridIndex::new(HybridConfig::default());
        index.build(&embeddings, &texts).unwrap();
        index
    }

    /// Asserts ranks run 1, 2, 3, ... and combined scores never increase.
    fn assert_well_ranked(results: &[HybridSearchResult]) {
        for (position, result) in results.iter().enumerate() {
            assert_eq!(result.rank, position + 1);
        }
        assert!(results
            .windows(2)
            .all(|pair| pair[0].combined_score >= pair[1].combined_score));
    }

    #[test]
    fn test_hybrid_config_default() {
        let config = HybridConfig::default();
        assert_eq!(config.alpha, 0.5);
        assert_eq!(config.rrf_k, 60.0);
    }

    #[test]
    fn test_hybrid_build() {
        let (embeddings, texts) = create_test_data();
        let mut index = HybridIndex::new(HybridConfig::default());

        assert!(index.build(&embeddings, &texts).is_ok());
        assert!(index.is_built);

        let stats = index.get_stats();
        assert_eq!(stats.num_documents, 5);
        assert!(stats.vocabulary_size > 0);
    }

    #[test]
    fn test_hybrid_search() {
        let index = built_index();

        let query_vector = vec![0.85, 0.15, 0.0];
        let query_text = "rust programming";
        let results = index.search(&query_vector, query_text, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].entity_id == "doc1" || results[0].entity_id == "doc2");
    }

    #[test]
    fn test_weighted_search() {
        let index = built_index();

        let query_vector = vec![0.85, 0.15, 0.0];
        let query_text = "rust programming";
        let results = index.weighted_search(&query_vector, query_text, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].vector_score.is_some() || results[0].bm25_score.is_some());
    }

    #[test]
    fn test_vector_only_search() {
        let index = built_index();

        let query_vector = vec![0.85, 0.15, 0.0];
        let results = index.vector_search(&query_vector, 3).unwrap();

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_keyword_only_search() {
        let index = built_index();

        let results = index.keyword_search("python machine learning", 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].entity_id == "doc3" || results[0].entity_id == "doc5");
    }

    #[test]
    fn test_alpha_adjustment() {
        let mut index = built_index();

        index.set_alpha(0.8);
        let stats = index.get_stats();
        assert_eq!(stats.alpha, 0.8);

        // Out-of-range values are clamped into [0, 1].
        index.set_alpha(1.5);
        let stats = index.get_stats();
        assert_eq!(stats.alpha, 1.0);
    }

    #[test]
    fn test_bm25_scoring() {
        let index = built_index();

        let results = index.keyword_search("rust", 5).unwrap();

        // Every document containing "rust" must be returned.
        let rust_docs: HashSet<&str> = results.iter().map(|r| r.entity_id.as_str()).collect();
        assert!(rust_docs.contains("doc1"));
        assert!(rust_docs.contains("doc2"));
        assert!(rust_docs.contains("doc5"));
    }

    #[test]
    fn test_empty_query() {
        let index = built_index();

        let results = index.keyword_search("", 3).unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_missing_text_error() {
        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.9, 0.1, 0.0]);

        let texts: HashMap<String, String> = HashMap::new();
        let mut index = HybridIndex::new(HybridConfig::default());
        assert!(index.build(&embeddings, &texts).is_err());
    }

    #[test]
    fn test_queries_require_build() {
        let index = HybridIndex::new(HybridConfig::default());
        let query_vector = vec![0.85, 0.15, 0.0];

        assert!(index.search(&query_vector, "rust", 3).is_err());
        assert!(index.weighted_search(&query_vector, "rust", 3).is_err());
        assert!(index.keyword_search("rust", 3).is_err());
        assert!(!index.get_stats().is_built);
    }

    #[test]
    fn test_results_are_well_ranked() {
        let index = built_index();
        let query_vector = vec![0.85, 0.15, 0.0];

        assert_well_ranked(&index.search(&query_vector, "rust programming", 5).unwrap());
        assert_well_ranked(
            &index
                .weighted_search(&query_vector, "rust programming", 5)
                .unwrap(),
        );
        assert_well_ranked(&index.keyword_search("machine learning", 5).unwrap());
    }
}