oxify_vector/
hybrid.rs

1//! Hybrid Search: Vector + Keyword (BM25)
2//!
3//! Combines semantic vector search with lexical keyword search for improved results.
4//!
5//! ## Algorithm
6//!
7//! - **Vector Search:** Semantic similarity using embeddings
8//! - **Keyword Search:** BM25 scoring for lexical matching
9//! - **Fusion:** Reciprocal Rank Fusion (RRF) or weighted combination
10//!
11//! ## Example
12//!
13//! ```rust
14//! use oxify_vector::hybrid::{HybridIndex, HybridConfig};
15//! use std::collections::HashMap;
16//!
17//! # fn example() -> anyhow::Result<()> {
18//! // Create documents with text and embeddings
19//! let mut embeddings = HashMap::new();
20//! embeddings.insert("doc1".to_string(), vec![0.1, 0.2, 0.3]);
21//! embeddings.insert("doc2".to_string(), vec![0.2, 0.3, 0.4]);
22//!
23//! let mut texts = HashMap::new();
24//! texts.insert("doc1".to_string(), "rust programming language".to_string());
25//! texts.insert("doc2".to_string(), "python machine learning".to_string());
26//!
27//! // Build hybrid index
28//! let config = HybridConfig::default();
29//! let mut index = HybridIndex::new(config);
30//! index.build(&embeddings, &texts)?;
31//!
32//! // Hybrid search
33//! let query_vector = vec![0.15, 0.25, 0.35];
34//! let query_text = "rust programming";
35//! let results = index.search(&query_vector, query_text, 2)?;
36//! # Ok(())
37//! # }
38//! ```
39
40use crate::search::VectorSearchIndex;
41use crate::types::{DistanceMetric, SearchConfig, SearchResult};
42use anyhow::{anyhow, Result};
43use serde::{Deserialize, Serialize};
44use std::collections::HashMap;
45use tracing::{debug, info};
46
47/// BM25 parameters
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct Bm25Config {
50    /// Term frequency saturation parameter (default: 1.2)
51    pub k1: f32,
52    /// Document length normalization parameter (default: 0.75)
53    pub b: f32,
54}
55
56impl Default for Bm25Config {
57    fn default() -> Self {
58        Self { k1: 1.2, b: 0.75 }
59    }
60}
61
62/// Hybrid search configuration
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct HybridConfig {
65    /// Vector search weight (0.0 to 1.0)
66    pub alpha: f32,
67    /// Distance metric for vector search
68    pub metric: DistanceMetric,
69    /// BM25 parameters
70    pub bm25: Bm25Config,
71    /// RRF constant (typically 60)
72    pub rrf_k: f32,
73    /// Normalize vectors
74    pub normalize: bool,
75}
76
77impl Default for HybridConfig {
78    fn default() -> Self {
79        Self {
80            alpha: 0.5,
81            metric: DistanceMetric::Cosine,
82            bm25: Bm25Config::default(),
83            rrf_k: 60.0,
84            normalize: true,
85        }
86    }
87}
88
89impl HybridConfig {
90    /// Create config favoring vector search
91    pub fn vector_heavy() -> Self {
92        Self {
93            alpha: 0.7,
94            ..Default::default()
95        }
96    }
97
98    /// Create config favoring keyword search
99    pub fn keyword_heavy() -> Self {
100        Self {
101            alpha: 0.3,
102            ..Default::default()
103        }
104    }
105}
106
107/// BM25 index for keyword search
108struct Bm25Index {
109    config: Bm25Config,
110    /// Document texts (entity_id -> text)
111    documents: HashMap<String, String>,
112    /// Inverted index (term -> entity_ids with term frequency)
113    inverted_index: HashMap<String, HashMap<String, usize>>,
114    /// Document lengths (entity_id -> word count)
115    doc_lengths: HashMap<String, usize>,
116    /// Average document length
117    avg_doc_length: f32,
118    /// Total number of documents
119    num_docs: usize,
120}
121
122impl Bm25Index {
123    fn new(config: Bm25Config) -> Self {
124        Self {
125            config,
126            documents: HashMap::new(),
127            inverted_index: HashMap::new(),
128            doc_lengths: HashMap::new(),
129            avg_doc_length: 0.0,
130            num_docs: 0,
131        }
132    }
133
134    fn build(&mut self, texts: &HashMap<String, String>) {
135        self.documents = texts.clone();
136        self.num_docs = texts.len();
137
138        // Tokenize and build inverted index
139        let mut total_length = 0;
140
141        for (entity_id, text) in texts {
142            let tokens = self.tokenize(text);
143            let doc_len = tokens.len();
144            self.doc_lengths.insert(entity_id.clone(), doc_len);
145            total_length += doc_len;
146
147            // Count term frequencies
148            let mut term_counts: HashMap<String, usize> = HashMap::new();
149            for token in tokens {
150                *term_counts.entry(token).or_insert(0) += 1;
151            }
152
153            // Update inverted index
154            for (term, count) in term_counts {
155                self.inverted_index
156                    .entry(term)
157                    .or_default()
158                    .insert(entity_id.clone(), count);
159            }
160        }
161
162        self.avg_doc_length = if self.num_docs > 0 {
163            total_length as f32 / self.num_docs as f32
164        } else {
165            0.0
166        };
167    }
168
169    fn tokenize(&self, text: &str) -> Vec<String> {
170        text.to_lowercase()
171            .split(|c: char| !c.is_alphanumeric())
172            .filter(|s| !s.is_empty() && s.len() > 1)
173            .map(|s| s.to_string())
174            .collect()
175    }
176
177    fn search(&self, query: &str, k: usize) -> Vec<(String, f32)> {
178        let query_tokens = self.tokenize(query);
179        let mut scores: HashMap<String, f32> = HashMap::new();
180
181        for token in &query_tokens {
182            if let Some(postings) = self.inverted_index.get(token) {
183                // Calculate IDF
184                let df = postings.len() as f32;
185                let idf = ((self.num_docs as f32 - df + 0.5) / (df + 0.5) + 1.0).ln();
186
187                for (entity_id, &tf) in postings {
188                    let doc_len = *self.doc_lengths.get(entity_id).unwrap_or(&1) as f32;
189                    let tf_f = tf as f32;
190
191                    // BM25 formula
192                    let numerator = tf_f * (self.config.k1 + 1.0);
193                    let denominator = tf_f
194                        + self.config.k1
195                            * (1.0 - self.config.b
196                                + self.config.b * (doc_len / self.avg_doc_length));
197
198                    let score = idf * (numerator / denominator);
199                    *scores.entry(entity_id.clone()).or_insert(0.0) += score;
200                }
201            }
202        }
203
204        // Sort by score descending
205        let mut results: Vec<(String, f32)> = scores.into_iter().collect();
206        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
207        results.truncate(k);
208        results
209    }
210}
211
212/// Hybrid search index combining vector and keyword search
213pub struct HybridIndex {
214    config: HybridConfig,
215    vector_index: VectorSearchIndex,
216    bm25_index: Bm25Index,
217    entity_ids: Vec<String>,
218    is_built: bool,
219}
220
221impl HybridIndex {
222    /// Create a new hybrid search index
223    pub fn new(config: HybridConfig) -> Self {
224        info!(
225            "Initialized hybrid index: alpha={}, metric={:?}",
226            config.alpha, config.metric
227        );
228
229        let vector_config = SearchConfig {
230            metric: config.metric,
231            parallel: true,
232            normalize: config.normalize,
233        };
234
235        Self {
236            vector_index: VectorSearchIndex::new(vector_config),
237            bm25_index: Bm25Index::new(config.bm25.clone()),
238            config,
239            entity_ids: Vec::new(),
240            is_built: false,
241        }
242    }
243
244    /// Build hybrid index from embeddings and texts
245    pub fn build(
246        &mut self,
247        embeddings: &HashMap<String, Vec<f32>>,
248        texts: &HashMap<String, String>,
249    ) -> Result<()> {
250        if embeddings.is_empty() {
251            return Err(anyhow!("Cannot build index from empty embeddings"));
252        }
253
254        // Verify all embeddings have corresponding texts
255        for entity_id in embeddings.keys() {
256            if !texts.contains_key(entity_id) {
257                return Err(anyhow!(
258                    "Missing text for entity '{}'. All embeddings must have corresponding texts.",
259                    entity_id
260                ));
261            }
262        }
263
264        info!("Building hybrid index for {} entities", embeddings.len());
265
266        self.entity_ids = embeddings.keys().cloned().collect();
267
268        // Build vector index
269        self.vector_index.build(embeddings)?;
270
271        // Build BM25 index
272        self.bm25_index.build(texts);
273
274        self.is_built = true;
275        info!("Hybrid index built successfully");
276        Ok(())
277    }
278
279    /// Hybrid search combining vector and keyword results
280    ///
281    /// Uses Reciprocal Rank Fusion (RRF) to combine results.
282    pub fn search(
283        &self,
284        query_vector: &[f32],
285        query_text: &str,
286        k: usize,
287    ) -> Result<Vec<HybridSearchResult>> {
288        if !self.is_built {
289            return Err(anyhow!("Index not built. Call build() first"));
290        }
291
292        debug!(
293            "Hybrid search: k={}, alpha={}, query_text='{}'",
294            k, self.config.alpha, query_text
295        );
296
297        // Get more candidates for fusion
298        let expanded_k = (k * 3).min(self.entity_ids.len());
299
300        // Vector search
301        let vector_results = self.vector_index.search(query_vector, expanded_k)?;
302
303        // BM25 search
304        let bm25_results = self.bm25_index.search(query_text, expanded_k);
305
306        // Combine using RRF
307        let results = self.reciprocal_rank_fusion(&vector_results, &bm25_results, k);
308
309        debug!("Hybrid search returned {} results", results.len());
310        Ok(results)
311    }
312
313    /// Combine results using Reciprocal Rank Fusion
314    fn reciprocal_rank_fusion(
315        &self,
316        vector_results: &[SearchResult],
317        bm25_results: &[(String, f32)],
318        k: usize,
319    ) -> Vec<HybridSearchResult> {
320        let mut rrf_scores: HashMap<String, f32> = HashMap::new();
321        let mut vector_scores: HashMap<String, f32> = HashMap::new();
322        let mut bm25_scores: HashMap<String, f32> = HashMap::new();
323
324        // Calculate RRF scores from vector results
325        for (rank, result) in vector_results.iter().enumerate() {
326            let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
327            *rrf_scores.entry(result.entity_id.clone()).or_insert(0.0) +=
328                self.config.alpha * rrf_score;
329            vector_scores.insert(result.entity_id.clone(), result.score);
330        }
331
332        // Calculate RRF scores from BM25 results
333        for (rank, (entity_id, score)) in bm25_results.iter().enumerate() {
334            let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
335            *rrf_scores.entry(entity_id.clone()).or_insert(0.0) +=
336                (1.0 - self.config.alpha) * rrf_score;
337            bm25_scores.insert(entity_id.clone(), *score);
338        }
339
340        // Sort by combined RRF score
341        let mut results: Vec<(String, f32)> = rrf_scores.into_iter().collect();
342        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
343
344        // Convert to HybridSearchResult
345        results
346            .into_iter()
347            .take(k)
348            .enumerate()
349            .map(|(rank, (entity_id, combined_score))| HybridSearchResult {
350                entity_id: entity_id.clone(),
351                combined_score,
352                vector_score: vector_scores.get(&entity_id).copied(),
353                bm25_score: bm25_scores.get(&entity_id).copied(),
354                rank: rank + 1,
355            })
356            .collect()
357    }
358
359    /// Search using weighted linear combination instead of RRF
360    pub fn weighted_search(
361        &self,
362        query_vector: &[f32],
363        query_text: &str,
364        k: usize,
365    ) -> Result<Vec<HybridSearchResult>> {
366        if !self.is_built {
367            return Err(anyhow!("Index not built. Call build() first"));
368        }
369
370        let expanded_k = (k * 3).min(self.entity_ids.len());
371
372        // Vector search
373        let vector_results = self.vector_index.search(query_vector, expanded_k)?;
374
375        // BM25 search
376        let bm25_results = self.bm25_index.search(query_text, expanded_k);
377
378        // Normalize and combine scores
379        let mut combined_scores: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();
380
381        // Normalize vector scores (already in 0-1 for cosine)
382        let max_vector_score = vector_results.first().map(|r| r.score).unwrap_or(1.0);
383        for result in &vector_results {
384            let norm_score = if max_vector_score > 0.0 {
385                result.score / max_vector_score
386            } else {
387                0.0
388            };
389            combined_scores.insert(result.entity_id.clone(), (Some(norm_score), None));
390        }
391
392        // Normalize BM25 scores
393        let max_bm25_score = bm25_results.first().map(|(_, s)| *s).unwrap_or(1.0);
394        for (entity_id, score) in &bm25_results {
395            let norm_score = if max_bm25_score > 0.0 {
396                score / max_bm25_score
397            } else {
398                0.0
399            };
400            combined_scores
401                .entry(entity_id.clone())
402                .and_modify(|(_, b)| *b = Some(norm_score))
403                .or_insert((None, Some(norm_score)));
404        }
405
406        // Calculate weighted combination
407        let mut results: Vec<HybridSearchResult> = combined_scores
408            .into_iter()
409            .map(|(entity_id, (v_score, b_score))| {
410                let v = v_score.unwrap_or(0.0);
411                let b = b_score.unwrap_or(0.0);
412                let combined = self.config.alpha * v + (1.0 - self.config.alpha) * b;
413
414                HybridSearchResult {
415                    entity_id,
416                    combined_score: combined,
417                    vector_score: v_score,
418                    bm25_score: b_score,
419                    rank: 0, // Will be set below
420                }
421            })
422            .collect();
423
424        // Sort by combined score
425        results.sort_by(|a, b| {
426            b.combined_score
427                .partial_cmp(&a.combined_score)
428                .unwrap_or(std::cmp::Ordering::Equal)
429        });
430
431        // Set ranks
432        for (i, result) in results.iter_mut().enumerate() {
433            result.rank = i + 1;
434        }
435
436        results.truncate(k);
437        Ok(results)
438    }
439
440    /// Vector-only search (alpha = 1.0)
441    pub fn vector_search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
442        self.vector_index.search(query, k)
443    }
444
445    /// Keyword-only search (alpha = 0.0)
446    pub fn keyword_search(&self, query: &str, k: usize) -> Result<Vec<HybridSearchResult>> {
447        if !self.is_built {
448            return Err(anyhow!("Index not built. Call build() first"));
449        }
450
451        let results = self.bm25_index.search(query, k);
452
453        Ok(results
454            .into_iter()
455            .enumerate()
456            .map(|(rank, (entity_id, score))| HybridSearchResult {
457                entity_id,
458                combined_score: score,
459                vector_score: None,
460                bm25_score: Some(score),
461                rank: rank + 1,
462            })
463            .collect())
464    }
465
466    /// Get index statistics
467    pub fn get_stats(&self) -> HybridStats {
468        HybridStats {
469            num_documents: self.entity_ids.len(),
470            vocabulary_size: self.bm25_index.inverted_index.len(),
471            avg_doc_length: self.bm25_index.avg_doc_length,
472            alpha: self.config.alpha,
473            is_built: self.is_built,
474        }
475    }
476
477    /// Set alpha parameter (vector vs keyword weight)
478    pub fn set_alpha(&mut self, alpha: f32) {
479        self.config.alpha = alpha.clamp(0.0, 1.0);
480    }
481}
482
483/// Hybrid search result
484#[derive(Debug, Clone, Serialize, Deserialize)]
485pub struct HybridSearchResult {
486    /// Entity ID
487    pub entity_id: String,
488    /// Combined score from RRF or weighted combination
489    pub combined_score: f32,
490    /// Vector similarity score (if available)
491    pub vector_score: Option<f32>,
492    /// BM25 score (if available)
493    pub bm25_score: Option<f32>,
494    /// Rank in results (1-indexed)
495    pub rank: usize,
496}
497
498/// Hybrid index statistics
499#[derive(Debug, Clone, Serialize, Deserialize)]
500pub struct HybridStats {
501    /// Number of documents in the index
502    pub num_documents: usize,
503    /// Number of unique terms in vocabulary
504    pub vocabulary_size: usize,
505    /// Average document length (in tokens)
506    pub avg_doc_length: f32,
507    /// Current alpha setting
508    pub alpha: f32,
509    /// Whether index is built
510    pub is_built: bool,
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Five small documents: two about Rust, two about Python/ML, one mixed.
    /// Embeddings are chosen so the Rust documents lie close together and the
    /// Python documents lie close together.
    fn create_test_data() -> (HashMap<String, Vec<f32>>, HashMap<String, String>) {
        let mut embeddings = HashMap::new();
        let mut texts = HashMap::new();

        // Tech document - similar vectors
        embeddings.insert("doc1".to_string(), vec![0.9, 0.1, 0.0]);
        texts.insert(
            "doc1".to_string(),
            "rust programming language systems programming".to_string(),
        );

        embeddings.insert("doc2".to_string(), vec![0.8, 0.2, 0.0]);
        texts.insert(
            "doc2".to_string(),
            "rust cargo package manager dependencies".to_string(),
        );

        // ML document
        embeddings.insert("doc3".to_string(), vec![0.1, 0.9, 0.0]);
        texts.insert(
            "doc3".to_string(),
            "python machine learning deep learning neural networks".to_string(),
        );

        embeddings.insert("doc4".to_string(), vec![0.0, 0.8, 0.2]);
        texts.insert(
            "doc4".to_string(),
            "python data science pandas numpy analysis".to_string(),
        );

        // Mixed
        embeddings.insert("doc5".to_string(), vec![0.5, 0.5, 0.0]);
        texts.insert(
            "doc5".to_string(),
            "rust machine learning inference performance".to_string(),
        );

        (embeddings, texts)
    }

    /// Builds an index over the shared test data with the default config.
    fn build_test_index() -> HybridIndex {
        let (embeddings, texts) = create_test_data();
        let mut index = HybridIndex::new(HybridConfig::default());
        index.build(&embeddings, &texts).unwrap();
        index
    }

    #[test]
    fn test_hybrid_config_default() {
        let config = HybridConfig::default();
        assert_eq!(config.alpha, 0.5);
        assert_eq!(config.rrf_k, 60.0);
    }

    #[test]
    fn test_hybrid_config_presets() {
        assert_eq!(HybridConfig::vector_heavy().alpha, 0.7);
        assert_eq!(HybridConfig::keyword_heavy().alpha, 0.3);
        // Presets only change alpha.
        assert_eq!(HybridConfig::vector_heavy().rrf_k, 60.0);
        assert_eq!(HybridConfig::keyword_heavy().rrf_k, 60.0);
    }

    #[test]
    fn test_hybrid_build() {
        let (embeddings, texts) = create_test_data();
        let mut index = HybridIndex::new(HybridConfig::default());

        assert!(index.build(&embeddings, &texts).is_ok());
        assert!(index.is_built);

        let stats = index.get_stats();
        assert_eq!(stats.num_documents, 5);
        assert!(stats.vocabulary_size > 0);
    }

    #[test]
    fn test_hybrid_search() {
        let index = build_test_index();

        // Search for rust programming
        let query_vector = vec![0.85, 0.15, 0.0];
        let query_text = "rust programming";
        let results = index.search(&query_vector, query_text, 3).unwrap();

        assert_eq!(results.len(), 3);
        // doc1 or doc2 should be top results (both vector and keyword match)
        assert!(results[0].entity_id == "doc1" || results[0].entity_id == "doc2");

        // Ranks are 1-indexed and follow the result order.
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.rank, i + 1);
        }
    }

    #[test]
    fn test_weighted_search() {
        let index = build_test_index();

        let query_vector = vec![0.85, 0.15, 0.0];
        let query_text = "rust programming";
        let results = index.weighted_search(&query_vector, query_text, 3).unwrap();

        assert_eq!(results.len(), 3);
        // The top result must carry a score from at least one of the two sides.
        assert!(results[0].vector_score.is_some() || results[0].bm25_score.is_some());

        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.rank, i + 1);
        }
    }

    #[test]
    fn test_vector_only_search() {
        let index = build_test_index();

        let query_vector = vec![0.85, 0.15, 0.0];
        let results = index.vector_search(&query_vector, 3).unwrap();

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_keyword_only_search() {
        let index = build_test_index();

        let results = index.keyword_search("python machine learning", 3).unwrap();

        assert_eq!(results.len(), 3);
        // doc3 is the only document containing all three query terms, so it
        // must rank first.
        assert_eq!(results[0].entity_id, "doc3");
        assert!(results[0].vector_score.is_none());
        assert_eq!(results[0].bm25_score, Some(results[0].combined_score));
    }

    #[test]
    fn test_alpha_adjustment() {
        let mut index = build_test_index();

        index.set_alpha(0.8);
        let stats = index.get_stats();
        assert_eq!(stats.alpha, 0.8);

        // Clamp to valid range
        index.set_alpha(1.5);
        let stats = index.get_stats();
        assert_eq!(stats.alpha, 1.0);

        index.set_alpha(-0.5);
        let stats = index.get_stats();
        assert_eq!(stats.alpha, 0.0);
    }

    #[test]
    fn test_bm25_scoring() {
        let index = build_test_index();

        // Search for specific term
        let results = index.keyword_search("rust", 5).unwrap();

        // Exactly doc1, doc2 and doc5 contain "rust"
        let rust_docs: HashSet<&str> = results.iter().map(|r| r.entity_id.as_str()).collect();
        assert_eq!(rust_docs.len(), 3);
        assert!(rust_docs.contains("doc1"));
        assert!(rust_docs.contains("doc2"));
        assert!(rust_docs.contains("doc5"));
    }

    #[test]
    fn test_empty_query() {
        let index = build_test_index();

        // Empty keyword query
        let results = index.keyword_search("", 3).unwrap();
        assert_eq!(results.len(), 0);
    }

    #[test]
    fn test_missing_text_error() {
        let mut embeddings = HashMap::new();
        embeddings.insert("doc1".to_string(), vec![0.9, 0.1, 0.0]);

        let texts: HashMap<String, String> = HashMap::new(); // Empty

        let mut index = HybridIndex::new(HybridConfig::default());
        assert!(index.build(&embeddings, &texts).is_err());
        assert!(!index.get_stats().is_built);
    }

    #[test]
    fn test_empty_embeddings_error() {
        let embeddings: HashMap<String, Vec<f32>> = HashMap::new();
        let texts: HashMap<String, String> = HashMap::new();

        let mut index = HybridIndex::new(HybridConfig::default());
        assert!(index.build(&embeddings, &texts).is_err());
        assert!(!index.get_stats().is_built);
    }

    #[test]
    fn test_search_before_build_errors() {
        let index = HybridIndex::new(HybridConfig::default());
        let query_vector = vec![0.85, 0.15, 0.0];

        assert!(index.search(&query_vector, "rust", 3).is_err());
        assert!(index.weighted_search(&query_vector, "rust", 3).is_err());
        assert!(index.keyword_search("rust", 3).is_err());
    }
}
687}