manifoldb_vector/ops/
hybrid.rs

1//! Hybrid dense+sparse vector search.
2//!
3//! This module provides support for combining dense and sparse vector similarity
4//! scores using weighted combinations. This is useful for hybrid retrieval systems
5//! that combine semantic (dense) and lexical (sparse) similarity.
6//!
7//! # Example
8//!
9//! ```ignore
//! use manifoldb_vector::ops::hybrid::HybridConfig;
//!
//! // Create a hybrid configuration with 0.7 weight for dense and 0.3 for sparse
//! let config = HybridConfig::new(0.7, 0.3);
//!
//! // Combine scores (lower is better for distance-based scores)
//! let combined = config.combine_distances(0.1, 0.2);
17//! ```
18
19use std::collections::HashMap;
20
21use manifoldb_core::EntityId;
22
23use super::VectorMatch;
24
/// Configuration for hybrid dense+sparse vector search.
///
/// Holds the two weights used to blend a dense (semantic) score with a sparse
/// (lexical) score, plus a flag that controls whether scores are normalized
/// before they are blended.
///
/// The weights are stored exactly as given: nothing in this type clamps them
/// or requires them to sum to 1.0.
///
/// Two configurations compare equal when all three fields are equal. Because
/// the weights are `f32`, a configuration holding a NaN weight is not equal to
/// itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HybridConfig {
    /// Weight applied to dense vector scores (intended range: 0.0 to 1.0).
    pub dense_weight: f32,
    /// Weight applied to sparse vector scores (intended range: 0.0 to 1.0).
    pub sparse_weight: f32,
    /// Whether to normalize scores before combining (recommended).
    pub normalize: bool,
}
38
39impl HybridConfig {
40    /// Create a new hybrid configuration with the given weights.
41    ///
42    /// Weights should sum to 1.0 for proper score interpretation,
43    /// but this is not enforced.
44    ///
45    /// # Example
46    ///
47    /// ```
48    /// use manifoldb_vector::ops::hybrid::HybridConfig;
49    ///
50    /// // 70% dense, 30% sparse
51    /// let config = HybridConfig::new(0.7, 0.3);
52    /// assert!((config.dense_weight - 0.7).abs() < 1e-6);
53    /// ```
54    #[must_use]
55    pub const fn new(dense_weight: f32, sparse_weight: f32) -> Self {
56        Self { dense_weight, sparse_weight, normalize: true }
57    }
58
59    /// Create a dense-only configuration.
60    #[must_use]
61    pub const fn dense_only() -> Self {
62        Self::new(1.0, 0.0)
63    }
64
65    /// Create a sparse-only configuration.
66    #[must_use]
67    pub const fn sparse_only() -> Self {
68        Self::new(0.0, 1.0)
69    }
70
71    /// Create an equal weighting configuration.
72    #[must_use]
73    pub const fn equal() -> Self {
74        Self::new(0.5, 0.5)
75    }
76
77    /// Disable score normalization.
78    #[must_use]
79    pub const fn without_normalization(mut self) -> Self {
80        self.normalize = false;
81        self
82    }
83
84    /// Combine dense and sparse distance scores.
85    ///
86    /// For distance-based metrics (where lower is better), this computes
87    /// a weighted sum: `dense_weight * dense_score + sparse_weight * sparse_score`
88    ///
89    /// # Arguments
90    ///
91    /// * `dense_score` - Distance score from dense vector search (lower is better)
92    /// * `sparse_score` - Distance score from sparse vector search (lower is better)
93    #[inline]
94    #[must_use]
95    pub fn combine_distances(&self, dense_score: f32, sparse_score: f32) -> f32 {
96        self.dense_weight * dense_score + self.sparse_weight * sparse_score
97    }
98
99    /// Combine dense and sparse similarity scores.
100    ///
101    /// For similarity-based metrics (where higher is better), this computes
102    /// a weighted sum: `dense_weight * dense_sim + sparse_weight * sparse_sim`
103    ///
104    /// The result is then converted to a distance (1 - similarity).
105    #[inline]
106    #[must_use]
107    pub fn combine_similarities(&self, dense_sim: f32, sparse_sim: f32) -> f32 {
108        let combined_sim = self.dense_weight * dense_sim + self.sparse_weight * sparse_sim;
109        1.0 - combined_sim
110    }
111}
112
113impl Default for HybridConfig {
114    fn default() -> Self {
115        Self::equal()
116    }
117}
118
119/// Result of a hybrid search operation.
120#[derive(Debug, Clone, Copy)]
121pub struct HybridMatch {
122    /// The entity ID of the matching vector.
123    pub entity_id: EntityId,
124    /// The combined distance score.
125    pub combined_distance: f32,
126    /// The dense vector distance (if available).
127    pub dense_distance: Option<f32>,
128    /// The sparse vector distance (if available).
129    pub sparse_distance: Option<f32>,
130}
131
132impl HybridMatch {
133    /// Create a new hybrid match with both dense and sparse scores.
134    #[must_use]
135    pub const fn new(
136        entity_id: EntityId,
137        combined_distance: f32,
138        dense_distance: Option<f32>,
139        sparse_distance: Option<f32>,
140    ) -> Self {
141        Self { entity_id, combined_distance, dense_distance, sparse_distance }
142    }
143
144    /// Create a dense-only match.
145    #[must_use]
146    pub const fn dense_only(entity_id: EntityId, distance: f32) -> Self {
147        Self::new(entity_id, distance, Some(distance), None)
148    }
149
150    /// Create a sparse-only match.
151    #[must_use]
152    pub const fn sparse_only(entity_id: EntityId, distance: f32) -> Self {
153        Self::new(entity_id, distance, None, Some(distance))
154    }
155}
156
157impl From<HybridMatch> for VectorMatch {
158    fn from(m: HybridMatch) -> Self {
159        VectorMatch::new(m.entity_id, m.combined_distance)
160    }
161}
162
163/// Merge and re-rank results from dense and sparse searches.
164///
165/// This function takes results from separate dense and sparse searches
166/// and combines them using the provided hybrid configuration.
167///
168/// # Algorithm
169///
170/// 1. Normalize scores if configured (min-max normalization)
171/// 2. For entities with both dense and sparse scores: compute weighted combination
172/// 3. For entities with only one score: use that score with its weight (other score = 1.0)
173/// 4. Sort by combined score
174/// 5. Return top K results
175///
176/// # Arguments
177///
178/// * `dense_results` - Results from dense vector search
179/// * `sparse_results` - Results from sparse vector search
180/// * `config` - Hybrid search configuration
181/// * `k` - Maximum number of results to return
182pub fn merge_results(
183    dense_results: &[VectorMatch],
184    sparse_results: &[VectorMatch],
185    config: &HybridConfig,
186    k: usize,
187) -> Vec<HybridMatch> {
188    // Collect scores by entity ID
189    let mut dense_scores: HashMap<EntityId, f32> = HashMap::new();
190    let mut sparse_scores: HashMap<EntityId, f32> = HashMap::new();
191
192    for m in dense_results {
193        dense_scores.insert(m.entity_id, m.distance);
194    }
195
196    for m in sparse_results {
197        sparse_scores.insert(m.entity_id, m.distance);
198    }
199
200    // Normalize scores if configured
201    let (dense_min, dense_max) = if config.normalize && !dense_results.is_empty() {
202        let min = dense_results.iter().map(|m| m.distance).fold(f32::INFINITY, f32::min);
203        let max = dense_results.iter().map(|m| m.distance).fold(f32::NEG_INFINITY, f32::max);
204        (min, max)
205    } else {
206        (0.0, 1.0)
207    };
208
209    let (sparse_min, sparse_max) = if config.normalize && !sparse_results.is_empty() {
210        let min = sparse_results.iter().map(|m| m.distance).fold(f32::INFINITY, f32::min);
211        let max = sparse_results.iter().map(|m| m.distance).fold(f32::NEG_INFINITY, f32::max);
212        (min, max)
213    } else {
214        (0.0, 1.0)
215    };
216
217    // Collect all entity IDs
218    let all_entities: Vec<EntityId> = dense_scores
219        .keys()
220        .chain(sparse_scores.keys())
221        .copied()
222        .collect::<std::collections::HashSet<_>>()
223        .into_iter()
224        .collect();
225
226    // Compute combined scores
227    let mut results: Vec<HybridMatch> = all_entities
228        .into_iter()
229        .map(|entity_id| {
230            let dense_dist = dense_scores.get(&entity_id).copied();
231            let sparse_dist = sparse_scores.get(&entity_id).copied();
232
233            // Normalize scores to [0, 1] range
234            let norm_dense = dense_dist.map(|d| {
235                if dense_max - dense_min > 0.0 {
236                    (d - dense_min) / (dense_max - dense_min)
237                } else {
238                    0.0
239                }
240            });
241
242            let norm_sparse = sparse_dist.map(|d| {
243                if sparse_max - sparse_min > 0.0 {
244                    (d - sparse_min) / (sparse_max - sparse_min)
245                } else {
246                    0.0
247                }
248            });
249
250            // Compute combined distance
251            // If one score is missing, use 1.0 (worst normalized distance)
252            let combined = match (norm_dense, norm_sparse) {
253                (Some(d), Some(s)) => config.combine_distances(d, s),
254                (Some(d), None) => config.combine_distances(d, 1.0),
255                (None, Some(s)) => config.combine_distances(1.0, s),
256                (None, None) => 1.0, // Should not happen
257            };
258
259            HybridMatch::new(entity_id, combined, dense_dist, sparse_dist)
260        })
261        .collect();
262
263    // Sort by combined distance (lower is better)
264    results.sort_by(|a, b| {
265        a.combined_distance.partial_cmp(&b.combined_distance).unwrap_or(std::cmp::Ordering::Equal)
266    });
267
268    // Return top K
269    results.truncate(k);
270    results
271}
272
273/// Reciprocal Rank Fusion (RRF) for combining ranked lists.
274///
275/// RRF is a simple but effective method for combining multiple ranked lists.
276/// Each item's score is computed as the sum of 1/(k + rank) across all lists.
277///
278/// # Arguments
279///
280/// * `dense_results` - Results from dense vector search (in rank order)
281/// * `sparse_results` - Results from sparse vector search (in rank order)
282/// * `k_param` - RRF parameter (typically 60)
283/// * `top_k` - Maximum number of results to return
284pub fn reciprocal_rank_fusion(
285    dense_results: &[VectorMatch],
286    sparse_results: &[VectorMatch],
287    k_param: u32,
288    top_k: usize,
289) -> Vec<HybridMatch> {
290    let mut rrf_scores: HashMap<EntityId, f32> = HashMap::new();
291    let mut dense_distances: HashMap<EntityId, f32> = HashMap::new();
292    let mut sparse_distances: HashMap<EntityId, f32> = HashMap::new();
293
294    // Add dense results
295    for (rank, m) in dense_results.iter().enumerate() {
296        let score = 1.0 / (k_param as f32 + rank as f32 + 1.0);
297        *rrf_scores.entry(m.entity_id).or_insert(0.0) += score;
298        dense_distances.insert(m.entity_id, m.distance);
299    }
300
301    // Add sparse results
302    for (rank, m) in sparse_results.iter().enumerate() {
303        let score = 1.0 / (k_param as f32 + rank as f32 + 1.0);
304        *rrf_scores.entry(m.entity_id).or_insert(0.0) += score;
305        sparse_distances.insert(m.entity_id, m.distance);
306    }
307
308    // Convert RRF scores to distances (higher RRF score = lower distance)
309    let max_score = rrf_scores.values().fold(0.0f32, |a, &b| a.max(b));
310
311    let mut results: Vec<HybridMatch> = rrf_scores
312        .into_iter()
313        .map(|(entity_id, score)| {
314            // Convert score to distance: higher score = lower distance
315            let combined_distance = if max_score > 0.0 { 1.0 - (score / max_score) } else { 1.0 };
316
317            HybridMatch::new(
318                entity_id,
319                combined_distance,
320                dense_distances.get(&entity_id).copied(),
321                sparse_distances.get(&entity_id).copied(),
322            )
323        })
324        .collect();
325
326    // Sort by combined distance (lower is better)
327    results.sort_by(|a, b| {
328        a.combined_distance.partial_cmp(&b.combined_distance).unwrap_or(std::cmp::Ordering::Equal)
329    });
330
331    results.truncate(top_k);
332    results
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance used for the floating-point comparisons in this module.
    const EPSILON: f32 = 1e-5;

    // Builds a `Vec<VectorMatch>` from `(id, distance)` pairs.
    macro_rules! hits {
        ($(($id:expr, $distance:expr)),* $(,)?) => {
            vec![$(VectorMatch::new(EntityId::new($id), $distance)),*]
        };
    }

    /// Panics unless `a` and `b` differ by strictly less than `epsilon`.
    fn assert_near(a: f32, b: f32, epsilon: f32) {
        let diff = (a - b).abs();
        assert!(diff < epsilon, "assertion failed: {} !~ {} (diff: {})", a, b, diff);
    }

    /// `new` stores the weights it is given.
    #[test]
    fn hybrid_config_weights() {
        let HybridConfig { dense_weight, sparse_weight, .. } = HybridConfig::new(0.7, 0.3);
        assert_near(dense_weight, 0.7, EPSILON);
        assert_near(sparse_weight, 0.3, EPSILON);
    }

    /// Each preset carries the expected (dense, sparse) weights.
    #[test]
    fn hybrid_config_presets() {
        let cases = [
            (HybridConfig::dense_only(), 1.0, 0.0),
            (HybridConfig::sparse_only(), 0.0, 1.0),
            (HybridConfig::equal(), 0.5, 0.5),
        ];

        for (config, dense, sparse) in cases.iter() {
            assert_near(config.dense_weight, *dense, EPSILON);
            assert_near(config.sparse_weight, *sparse, EPSILON);
        }
    }

    /// Distances are blended as a weighted sum.
    #[test]
    fn combine_distances() {
        // 0.7 * 0.1 + 0.3 * 0.2 = 0.07 + 0.06 = 0.13
        let blended = HybridConfig::new(0.7, 0.3).combine_distances(0.1, 0.2);
        assert_near(blended, 0.13, EPSILON);
    }

    /// Similarities are blended as a weighted sum and flipped into a distance.
    #[test]
    fn combine_similarities() {
        // Combined sim = 0.7 * 0.9 + 0.3 * 0.8 = 0.63 + 0.24 = 0.87
        // Distance = 1 - 0.87 = 0.13
        let blended = HybridConfig::new(0.7, 0.3).combine_similarities(0.9, 0.8);
        assert_near(blended, 0.13, EPSILON);
    }

    /// An entity found by both searches keeps both of its raw distances.
    #[test]
    fn merge_results_both_present() {
        let dense = hits![(1, 0.1), (2, 0.2)];
        let sparse = hits![(1, 0.3), (3, 0.1)];
        let config = HybridConfig::equal().without_normalization();

        let results = merge_results(&dense, &sparse, &config, 10);
        assert_eq!(results.len(), 3);

        // Entity 1 should have both scores
        let shared = results.iter().find(|m| m.entity_id == EntityId::new(1)).unwrap();
        assert!(shared.dense_distance.is_some());
        assert!(shared.sparse_distance.is_some());
    }

    /// Five distinct entities are cut down to the requested three.
    #[test]
    fn merge_results_respects_k() {
        let dense = hits![(1, 0.1), (2, 0.2), (3, 0.3)];
        let sparse = hits![(4, 0.1), (5, 0.2)];

        let results = merge_results(&dense, &sparse, &HybridConfig::equal(), 3);

        assert_eq!(results.len(), 3);
    }

    /// Entities present in both lists outrank entities present in only one.
    #[test]
    fn reciprocal_rank_fusion_basic() {
        let dense = hits![(1, 0.1), (2, 0.2), (3, 0.3)];
        let sparse = hits![(2, 0.1), (1, 0.2), (4, 0.3)];

        let results = reciprocal_rank_fusion(&dense, &sparse, 60, 10);

        // Entity 1 and 2 should have higher RRF scores (lower distances)
        // as they appear in both lists
        assert!(results.len() >= 2);

        // First two results should be entities 1 and 2 (in some order)
        for expected in [1, 2].iter() {
            let wanted = EntityId::new(*expected);
            assert!(results.iter().take(2).any(|m| m.entity_id == wanted));
        }
    }

    /// Overlapping inputs are cut down to the requested five results.
    #[test]
    fn reciprocal_rank_fusion_respects_top_k() {
        let dense: Vec<VectorMatch> =
            (1..=10).map(|i| VectorMatch::new(EntityId::new(i), i as f32 * 0.1)).collect();
        let sparse: Vec<VectorMatch> =
            (5..=15).map(|i| VectorMatch::new(EntityId::new(i), i as f32 * 0.1)).collect();

        let results = reciprocal_rank_fusion(&dense, &sparse, 60, 5);

        assert_eq!(results.len(), 5);
    }

    /// Single-source constructors fill only their own distance, and the
    /// conversion to `VectorMatch` keeps the entity ID.
    #[test]
    fn hybrid_match_conversions() {
        let dense_hit = HybridMatch::dense_only(EntityId::new(1), 0.5);
        assert!(dense_hit.dense_distance.is_some());
        assert!(dense_hit.sparse_distance.is_none());

        let sparse_hit = HybridMatch::sparse_only(EntityId::new(2), 0.3);
        assert!(sparse_hit.dense_distance.is_none());
        assert!(sparse_hit.sparse_distance.is_some());

        let plain = VectorMatch::from(sparse_hit);
        assert_eq!(plain.entity_id, EntityId::new(2));
    }
}