Skip to main content

engine/
hybrid.rs

//! Hybrid search combining vector similarity and full-text search
//!
//! Provides a unified search experience by combining:
//! - Vector similarity scores (cosine, euclidean, dot product)
//! - Full-text BM25 scores
//!
//! Each list of scores is min-max normalized to the 0-1 range (higher raw scores are
//! treated as better), and the final score is computed as:
//! `score = vector_weight * vector_score + (1 - vector_weight) * text_score`

10use std::collections::HashMap;
11
12use crate::fulltext::FullTextResult;
13
14/// A vector search result row: (id, score, optional metadata, optional vector).
15type VectorResultRow = (String, f32, Option<serde_json::Value>, Option<Vec<f32>>);
16
/// Tuning knobs for [`HybridSearcher`].
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Share of the combined score contributed by the vector score, expected in
    /// `0.0..=1.0`; the text score receives the remaining `1.0 - vector_weight`.
    pub vector_weight: f32,
    /// When `true`, only IDs present in both the vector and the text results are kept.
    pub require_both: bool,
}

impl Default for HybridConfig {
    /// Equal weighting of vector and text scores, keeping IDs found by either search.
    fn default() -> Self {
        let vector_weight = 0.5;
        let require_both = false;
        Self {
            vector_weight,
            require_both,
        }
    }
}

35/// Raw score from a single search type
36#[derive(Debug, Clone)]
37struct RawScore {
38    /// Original score before normalization
39    score: f32,
40    /// Additional data (metadata, vector)
41    metadata: Option<serde_json::Value>,
42    vector: Option<Vec<f32>>,
43}
44
45/// Result of hybrid search
46#[derive(Debug, Clone)]
47pub struct HybridResult {
48    /// Document/vector ID
49    pub id: String,
50    /// Combined score (weighted average of normalized scores)
51    pub combined_score: f32,
52    /// Normalized vector similarity score (0-1)
53    pub vector_score: f32,
54    /// Normalized text search score (0-1)
55    pub text_score: f32,
56    /// Optional metadata
57    pub metadata: Option<serde_json::Value>,
58    /// Optional vector values
59    pub vector: Option<Vec<f32>>,
60}
61
62/// Hybrid search engine that combines vector and text search
63pub struct HybridSearcher {
64    config: HybridConfig,
65}
66
67impl HybridSearcher {
68    pub fn new(config: HybridConfig) -> Self {
69        Self { config }
70    }
71
72    pub fn with_vector_weight(mut self, weight: f32) -> Self {
73        self.config.vector_weight = weight.clamp(0.0, 1.0);
74        self
75    }
76
77    /// Combine vector search results with full-text search results
78    ///
79    /// # Arguments
80    /// * `vector_results` - Results from vector similarity search (id, score, metadata, vector)
81    /// * `text_results` - Results from full-text BM25 search
82    /// * `top_k` - Number of results to return
83    ///
84    /// # Returns
85    /// Combined and re-ranked results
86    pub fn search(
87        &self,
88        vector_results: Vec<VectorResultRow>,
89        text_results: Vec<FullTextResult>,
90        top_k: usize,
91    ) -> Vec<HybridResult> {
92        // Collect all unique IDs and their raw scores
93        let mut vector_scores: HashMap<String, RawScore> = HashMap::new();
94        let mut text_scores: HashMap<String, f32> = HashMap::new();
95
96        // Track min/max for normalization
97        let mut vector_min = f32::MAX;
98        let mut vector_max = f32::MIN;
99        let mut text_min = f32::MAX;
100        let mut text_max = f32::MIN;
101
102        // Collect vector scores
103        for (id, score, metadata, vector) in vector_results {
104            vector_min = vector_min.min(score);
105            vector_max = vector_max.max(score);
106            vector_scores.insert(
107                id,
108                RawScore {
109                    score,
110                    metadata,
111                    vector,
112                },
113            );
114        }
115
116        // Collect text scores
117        for result in text_results {
118            text_min = text_min.min(result.score);
119            text_max = text_max.max(result.score);
120            text_scores.insert(result.doc_id, result.score);
121        }
122
123        // Get all unique IDs
124        let mut all_ids: Vec<String> = vector_scores
125            .keys()
126            .chain(text_scores.keys())
127            .cloned()
128            .collect();
129        all_ids.sort();
130        all_ids.dedup();
131
132        // Compute combined scores
133        let mut results: Vec<HybridResult> = Vec::new();
134
135        for id in all_ids {
136            let vector_raw = vector_scores.get(&id);
137            let text_raw = text_scores.get(&id);
138
139            // Skip if require_both and missing one
140            if self.config.require_both && (vector_raw.is_none() || text_raw.is_none()) {
141                continue;
142            }
143
144            // Normalize scores to 0-1 range
145            let vector_normalized = if let Some(raw) = vector_raw {
146                normalize_score(raw.score, vector_min, vector_max)
147            } else {
148                0.0
149            };
150
151            let text_normalized = if let Some(&score) = text_raw {
152                normalize_score(score, text_min, text_max)
153            } else {
154                0.0
155            };
156
157            // Compute weighted combination
158            let combined = self.config.vector_weight * vector_normalized
159                + (1.0 - self.config.vector_weight) * text_normalized;
160
161            // Get metadata and vector from vector results if available
162            let (metadata, vector) = if let Some(raw) = vector_raw {
163                (raw.metadata.clone(), raw.vector.clone())
164            } else {
165                (None, None)
166            };
167
168            results.push(HybridResult {
169                id,
170                combined_score: combined,
171                vector_score: vector_normalized,
172                text_score: text_normalized,
173                metadata,
174                vector,
175            });
176        }
177
178        // Sort by combined score (descending)
179        results.sort_by(|a, b| {
180            b.combined_score
181                .partial_cmp(&a.combined_score)
182                .unwrap_or(std::cmp::Ordering::Equal)
183        });
184
185        // Return top-k
186        results.truncate(top_k);
187        results
188    }
189}
190
191impl Default for HybridSearcher {
192    fn default() -> Self {
193        Self::new(HybridConfig::default())
194    }
195}
196
/// Maps `score` into `0.0..=1.0` using min-max normalization over `[min, max]`.
///
/// * When `min` and `max` are (nearly) equal, e.g. a single result or all results
///   tied, every score is treated as the best possible and `1.0` is returned.
/// * The result is clamped to `0.0..=1.0`, so a score outside `[min, max]` cannot
///   push the combined score outside its documented range.
/// * A NaN result (a NaN `score`, or a non-finite range such as `inf - inf`) is
///   mapped to `0.0`, i.e. treated as "no signal". This keeps one bad score from
///   poisoning the combined score and the ranking.
fn normalize_score(score: f32, min: f32, max: f32) -> f32 {
    let range = max - min;
    if range.abs() < f32::EPSILON {
        // All scores are the same
        return 1.0;
    }
    let normalized = (score - min) / range;
    if normalized.is_nan() {
        0.0
    } else {
        normalized.clamp(0.0, 1.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a full-text hit; keeps the fixtures below short.
    fn text_hit(doc_id: &str, score: f32) -> FullTextResult {
        FullTextResult {
            doc_id: doc_id.to_string(),
            score,
        }
    }

    #[test]
    fn test_hybrid_search_basic() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
            ("doc3".to_string(), 0.5, None, None),
        ];

        let text_results = vec![
            text_hit("doc1", 3.0),
            text_hit("doc2", 4.0),
            text_hit("doc4", 2.0),
        ];

        let results = searcher.search(vector_results, text_results, 10);

        // The union of both lists: doc1..doc4.
        assert_eq!(results.len(), 4);

        // doc1 and doc2 appear in both lists above each list's minimum, so all of
        // their normalized scores are positive.
        let doc1 = results.iter().find(|r| r.id == "doc1").unwrap();
        assert!(doc1.vector_score > 0.0);
        assert_eq!(doc1.text_score, 0.5); // (3 - 2) / (4 - 2)
        assert!(doc1.combined_score > 0.0);

        let doc2 = results.iter().find(|r| r.id == "doc2").unwrap();
        assert!(doc2.vector_score > 0.0);
        assert!(doc2.combined_score > 0.0);

        // doc2 has the highest text score, so it normalizes to exactly 1.0.
        assert_eq!(doc2.text_score, 1.0);

        // doc4 is text-only and holds the lowest text score, so min-max normalization
        // gives it 0.0 in both lists.
        let doc4 = results.iter().find(|r| r.id == "doc4").unwrap();
        assert_eq!(doc4.vector_score, 0.0);
        assert_eq!(doc4.text_score, 0.0);
        assert_eq!(doc4.combined_score, 0.0);
    }

    #[test]
    fn test_hybrid_search_vector_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![text_hit("doc1", 1.0)];

        let results = searcher.search(vector_results, text_results, 10);

        // With all weight on vectors, doc1 (highest vector score) ranks first and
        // its combined score is exactly its vector score.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "doc1");
        assert_eq!(results[0].combined_score, results[0].vector_score);
    }

    #[test]
    fn test_hybrid_search_text_only() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.0,
            require_both: false,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.5, None, None),
        ];

        let text_results = vec![text_hit("doc1", 1.0), text_hit("doc2", 3.0)];

        let results = searcher.search(vector_results, text_results, 10);

        // With all weight on text, doc2 (highest text score) ranks first and its
        // combined score is exactly its text score.
        assert_eq!(results[0].id, "doc2");
        assert_eq!(results[0].combined_score, results[0].text_score);
    }

    #[test]
    fn test_hybrid_search_require_both() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 0.5,
            require_both: true,
        });

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.7, None, None),
        ];

        let text_results = vec![text_hit("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);

        // Only doc1 appears in both lists.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "doc1");
    }

    #[test]
    fn test_hybrid_search_top_k() {
        let searcher = HybridSearcher::default();

        let vector_results = vec![
            ("doc1".to_string(), 0.9, None, None),
            ("doc2".to_string(), 0.8, None, None),
            ("doc3".to_string(), 0.7, None, None),
            ("doc4".to_string(), 0.6, None, None),
            ("doc5".to_string(), 0.5, None, None),
        ];

        let text_results = vec![];

        let results = searcher.search(vector_results, text_results, 3);

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_hybrid_search_top_k_zero() {
        let searcher = HybridSearcher::default();
        let vector_results = vec![("doc1".to_string(), 0.9, None, None)];

        let results = searcher.search(vector_results, vec![text_hit("doc1", 1.0)], 0);

        assert!(results.is_empty());
    }

    #[test]
    fn test_hybrid_search_empty_inputs() {
        let searcher = HybridSearcher::default();

        assert!(searcher.search(Vec::new(), Vec::new(), 10).is_empty());
    }

    #[test]
    fn test_hybrid_search_ties_broken_by_id() {
        let searcher = HybridSearcher::new(HybridConfig {
            vector_weight: 1.0,
            require_both: false,
        });

        // Identical scores give identical combined scores, so order falls back to ID.
        let vector_results = vec![
            ("b".to_string(), 0.5, None, None),
            ("a".to_string(), 0.5, None, None),
        ];

        let results = searcher.search(vector_results, Vec::new(), 10);
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();

        assert_eq!(ids, ["a", "b"]);
    }

    #[test]
    fn test_hybrid_search_with_metadata() {
        let searcher = HybridSearcher::default();

        let metadata = serde_json::json!({"title": "Test Document"});
        let vector = vec![1.0, 0.0, 0.0];

        let vector_results = vec![(
            "doc1".to_string(),
            0.9,
            Some(metadata.clone()),
            Some(vector.clone()),
        )];

        let text_results = vec![text_hit("doc1", 2.0)];

        let results = searcher.search(vector_results, text_results, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].metadata, Some(metadata));
        assert_eq!(results[0].vector, Some(vector));
    }

    #[test]
    fn test_hybrid_search_text_only_hit_has_no_payload() {
        let searcher = HybridSearcher::default();

        let results = searcher.search(Vec::new(), vec![text_hit("doc1", 2.0)], 10);

        // A single text hit normalizes to 1.0; metadata and vector only come from
        // vector results, so they are absent.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].vector_score, 0.0);
        assert_eq!(results[0].text_score, 1.0);
        assert!(results[0].metadata.is_none());
        assert!(results[0].vector.is_none());
    }

    #[test]
    fn test_normalize_score() {
        // Normal case
        assert_eq!(normalize_score(5.0, 0.0, 10.0), 0.5);
        assert_eq!(normalize_score(0.0, 0.0, 10.0), 0.0);
        assert_eq!(normalize_score(10.0, 0.0, 10.0), 1.0);

        // All same scores
        assert_eq!(normalize_score(5.0, 5.0, 5.0), 1.0);
    }

    #[test]
    fn test_hybrid_searcher_builder() {
        let searcher = HybridSearcher::default().with_vector_weight(0.7);

        assert_eq!(searcher.config.vector_weight, 0.7);
    }

    #[test]
    fn test_vector_weight_clamping() {
        let searcher1 = HybridSearcher::default().with_vector_weight(1.5);
        assert_eq!(searcher1.config.vector_weight, 1.0);

        let searcher2 = HybridSearcher::default().with_vector_weight(-0.5);
        assert_eq!(searcher2.config.vector_weight, 0.0);
    }
}