1use std::collections::HashMap;
20
21use manifoldb_core::EntityId;
22
23use super::VectorMatch;
24
/// Weights and options that control how a dense score and a sparse score are
/// blended into one hybrid score.
///
/// The weights are used as plain multipliers; nothing in this type forces them
/// to be non-negative or to sum to `1.0`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HybridConfig {
    /// Multiplier applied to the dense score.
    pub dense_weight: f32,
    /// Multiplier applied to the sparse score.
    pub sparse_weight: f32,
    /// Whether `merge_results` min-max rescales each result list before
    /// applying the weights. When `false`, raw distances are weighted as-is.
    pub normalize: bool,
}
38
39impl HybridConfig {
40 #[must_use]
55 pub const fn new(dense_weight: f32, sparse_weight: f32) -> Self {
56 Self { dense_weight, sparse_weight, normalize: true }
57 }
58
59 #[must_use]
61 pub const fn dense_only() -> Self {
62 Self::new(1.0, 0.0)
63 }
64
65 #[must_use]
67 pub const fn sparse_only() -> Self {
68 Self::new(0.0, 1.0)
69 }
70
71 #[must_use]
73 pub const fn equal() -> Self {
74 Self::new(0.5, 0.5)
75 }
76
77 #[must_use]
79 pub const fn without_normalization(mut self) -> Self {
80 self.normalize = false;
81 self
82 }
83
84 #[inline]
94 #[must_use]
95 pub fn combine_distances(&self, dense_score: f32, sparse_score: f32) -> f32 {
96 self.dense_weight * dense_score + self.sparse_weight * sparse_score
97 }
98
99 #[inline]
106 #[must_use]
107 pub fn combine_similarities(&self, dense_sim: f32, sparse_sim: f32) -> f32 {
108 let combined_sim = self.dense_weight * dense_sim + self.sparse_weight * sparse_sim;
109 1.0 - combined_sim
110 }
111}
112
113impl Default for HybridConfig {
114 fn default() -> Self {
115 Self::equal()
116 }
117}
118
119#[derive(Debug, Clone, Copy)]
121pub struct HybridMatch {
122 pub entity_id: EntityId,
124 pub combined_distance: f32,
126 pub dense_distance: Option<f32>,
128 pub sparse_distance: Option<f32>,
130}
131
132impl HybridMatch {
133 #[must_use]
135 pub const fn new(
136 entity_id: EntityId,
137 combined_distance: f32,
138 dense_distance: Option<f32>,
139 sparse_distance: Option<f32>,
140 ) -> Self {
141 Self { entity_id, combined_distance, dense_distance, sparse_distance }
142 }
143
144 #[must_use]
146 pub const fn dense_only(entity_id: EntityId, distance: f32) -> Self {
147 Self::new(entity_id, distance, Some(distance), None)
148 }
149
150 #[must_use]
152 pub const fn sparse_only(entity_id: EntityId, distance: f32) -> Self {
153 Self::new(entity_id, distance, None, Some(distance))
154 }
155}
156
157impl From<HybridMatch> for VectorMatch {
158 fn from(m: HybridMatch) -> Self {
159 VectorMatch::new(m.entity_id, m.combined_distance)
160 }
161}
162
163pub fn merge_results(
183 dense_results: &[VectorMatch],
184 sparse_results: &[VectorMatch],
185 config: &HybridConfig,
186 k: usize,
187) -> Vec<HybridMatch> {
188 let mut dense_scores: HashMap<EntityId, f32> = HashMap::new();
190 let mut sparse_scores: HashMap<EntityId, f32> = HashMap::new();
191
192 for m in dense_results {
193 dense_scores.insert(m.entity_id, m.distance);
194 }
195
196 for m in sparse_results {
197 sparse_scores.insert(m.entity_id, m.distance);
198 }
199
200 let (dense_min, dense_max) = if config.normalize && !dense_results.is_empty() {
202 let min = dense_results.iter().map(|m| m.distance).fold(f32::INFINITY, f32::min);
203 let max = dense_results.iter().map(|m| m.distance).fold(f32::NEG_INFINITY, f32::max);
204 (min, max)
205 } else {
206 (0.0, 1.0)
207 };
208
209 let (sparse_min, sparse_max) = if config.normalize && !sparse_results.is_empty() {
210 let min = sparse_results.iter().map(|m| m.distance).fold(f32::INFINITY, f32::min);
211 let max = sparse_results.iter().map(|m| m.distance).fold(f32::NEG_INFINITY, f32::max);
212 (min, max)
213 } else {
214 (0.0, 1.0)
215 };
216
217 let all_entities: Vec<EntityId> = dense_scores
219 .keys()
220 .chain(sparse_scores.keys())
221 .copied()
222 .collect::<std::collections::HashSet<_>>()
223 .into_iter()
224 .collect();
225
226 let mut results: Vec<HybridMatch> = all_entities
228 .into_iter()
229 .map(|entity_id| {
230 let dense_dist = dense_scores.get(&entity_id).copied();
231 let sparse_dist = sparse_scores.get(&entity_id).copied();
232
233 let norm_dense = dense_dist.map(|d| {
235 if dense_max - dense_min > 0.0 {
236 (d - dense_min) / (dense_max - dense_min)
237 } else {
238 0.0
239 }
240 });
241
242 let norm_sparse = sparse_dist.map(|d| {
243 if sparse_max - sparse_min > 0.0 {
244 (d - sparse_min) / (sparse_max - sparse_min)
245 } else {
246 0.0
247 }
248 });
249
250 let combined = match (norm_dense, norm_sparse) {
253 (Some(d), Some(s)) => config.combine_distances(d, s),
254 (Some(d), None) => config.combine_distances(d, 1.0),
255 (None, Some(s)) => config.combine_distances(1.0, s),
256 (None, None) => 1.0, };
258
259 HybridMatch::new(entity_id, combined, dense_dist, sparse_dist)
260 })
261 .collect();
262
263 results.sort_by(|a, b| {
265 a.combined_distance.partial_cmp(&b.combined_distance).unwrap_or(std::cmp::Ordering::Equal)
266 });
267
268 results.truncate(k);
270 results
271}
272
273pub fn reciprocal_rank_fusion(
285 dense_results: &[VectorMatch],
286 sparse_results: &[VectorMatch],
287 k_param: u32,
288 top_k: usize,
289) -> Vec<HybridMatch> {
290 let mut rrf_scores: HashMap<EntityId, f32> = HashMap::new();
291 let mut dense_distances: HashMap<EntityId, f32> = HashMap::new();
292 let mut sparse_distances: HashMap<EntityId, f32> = HashMap::new();
293
294 for (rank, m) in dense_results.iter().enumerate() {
296 let score = 1.0 / (k_param as f32 + rank as f32 + 1.0);
297 *rrf_scores.entry(m.entity_id).or_insert(0.0) += score;
298 dense_distances.insert(m.entity_id, m.distance);
299 }
300
301 for (rank, m) in sparse_results.iter().enumerate() {
303 let score = 1.0 / (k_param as f32 + rank as f32 + 1.0);
304 *rrf_scores.entry(m.entity_id).or_insert(0.0) += score;
305 sparse_distances.insert(m.entity_id, m.distance);
306 }
307
308 let max_score = rrf_scores.values().fold(0.0f32, |a, &b| a.max(b));
310
311 let mut results: Vec<HybridMatch> = rrf_scores
312 .into_iter()
313 .map(|(entity_id, score)| {
314 let combined_distance = if max_score > 0.0 { 1.0 - (score / max_score) } else { 1.0 };
316
317 HybridMatch::new(
318 entity_id,
319 combined_distance,
320 dense_distances.get(&entity_id).copied(),
321 sparse_distances.get(&entity_id).copied(),
322 )
323 })
324 .collect();
325
326 results.sort_by(|a, b| {
328 a.combined_distance.partial_cmp(&b.combined_distance).unwrap_or(std::cmp::Ordering::Equal)
329 });
330
331 results.truncate(top_k);
332 results
333}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for floating-point comparisons in this module.
    const EPSILON: f32 = 1e-5;

    /// Asserts that `a` and `b` differ by less than `epsilon`.
    fn assert_near(a: f32, b: f32, epsilon: f32) {
        assert!(
            (a - b).abs() < epsilon,
            "assertion failed: {} !~ {} (diff: {})",
            a,
            b,
            (a - b).abs()
        );
    }

    #[test]
    fn hybrid_config_weights() {
        let config = HybridConfig::new(0.7, 0.3);
        assert_near(config.dense_weight, 0.7, EPSILON);
        assert_near(config.sparse_weight, 0.3, EPSILON);
    }

    #[test]
    fn hybrid_config_presets() {
        let dense_only = HybridConfig::dense_only();
        assert_near(dense_only.dense_weight, 1.0, EPSILON);
        assert_near(dense_only.sparse_weight, 0.0, EPSILON);

        let sparse_only = HybridConfig::sparse_only();
        assert_near(sparse_only.dense_weight, 0.0, EPSILON);
        assert_near(sparse_only.sparse_weight, 1.0, EPSILON);

        let equal = HybridConfig::equal();
        assert_near(equal.dense_weight, 0.5, EPSILON);
        assert_near(equal.sparse_weight, 0.5, EPSILON);
    }

    #[test]
    fn hybrid_config_normalization_flag() {
        assert!(HybridConfig::equal().normalize);
        assert!(HybridConfig::default().normalize);

        let raw = HybridConfig::new(0.7, 0.3).without_normalization();
        assert!(!raw.normalize);
        assert_near(raw.dense_weight, 0.7, EPSILON);
        assert_near(raw.sparse_weight, 0.3, EPSILON);
    }

    #[test]
    fn combine_distances() {
        let config = HybridConfig::new(0.7, 0.3);
        let combined = config.combine_distances(0.1, 0.2);
        // 0.7 * 0.1 + 0.3 * 0.2
        assert_near(combined, 0.13, EPSILON);
    }

    #[test]
    fn combine_similarities() {
        let config = HybridConfig::new(0.7, 0.3);
        let combined = config.combine_similarities(0.9, 0.8);
        // 1.0 - (0.7 * 0.9 + 0.3 * 0.8)
        assert_near(combined, 0.13, EPSILON);
    }

    #[test]
    fn merge_results_both_present() {
        let dense =
            vec![VectorMatch::new(EntityId::new(1), 0.1), VectorMatch::new(EntityId::new(2), 0.2)];
        let sparse =
            vec![VectorMatch::new(EntityId::new(1), 0.3), VectorMatch::new(EntityId::new(3), 0.1)];

        let config = HybridConfig::equal().without_normalization();
        let results = merge_results(&dense, &sparse, &config, 10);

        assert_eq!(results.len(), 3);

        // Entity 1 is the only one present in both lists.
        let e1 = results.iter().find(|m| m.entity_id == EntityId::new(1)).unwrap();
        assert!(e1.dense_distance.is_some());
        assert!(e1.sparse_distance.is_some());
        // 0.5 * 0.1 + 0.5 * 0.3
        assert_near(e1.combined_distance, 0.2, EPSILON);

        // A list an entity is missing from counts as 1.0:
        // e3 = 0.5 * 1.0 + 0.5 * 0.1 = 0.55, e2 = 0.5 * 0.2 + 0.5 * 1.0 = 0.6.
        let order: Vec<EntityId> = results.iter().map(|m| m.entity_id).collect();
        assert_eq!(order, vec![EntityId::new(1), EntityId::new(3), EntityId::new(2)]);
        assert_near(results[1].combined_distance, 0.55, EPSILON);
        assert_near(results[2].combined_distance, 0.6, EPSILON);
        assert!(results[1].dense_distance.is_none());
        assert!(results[2].sparse_distance.is_none());
    }

    #[test]
    fn merge_results_respects_k() {
        let dense = vec![
            VectorMatch::new(EntityId::new(1), 0.1),
            VectorMatch::new(EntityId::new(2), 0.2),
            VectorMatch::new(EntityId::new(3), 0.3),
        ];
        let sparse =
            vec![VectorMatch::new(EntityId::new(4), 0.1), VectorMatch::new(EntityId::new(5), 0.2)];

        let config = HybridConfig::equal();
        let results = merge_results(&dense, &sparse, &config, 3);

        assert_eq!(results.len(), 3);
        assert!(results.windows(2).all(|w| w[0].combined_distance <= w[1].combined_distance));
    }

    #[test]
    fn merge_results_empty_inputs() {
        let config = HybridConfig::default();
        assert!(merge_results(&[], &[], &config, 5).is_empty());

        let dense = vec![VectorMatch::new(EntityId::new(1), 0.1)];
        assert!(merge_results(&dense, &[], &config, 0).is_empty());
        assert_eq!(merge_results(&dense, &[], &config, 5).len(), 1);
    }

    #[test]
    fn reciprocal_rank_fusion_basic() {
        let dense = vec![
            VectorMatch::new(EntityId::new(1), 0.1),
            VectorMatch::new(EntityId::new(2), 0.2),
            VectorMatch::new(EntityId::new(3), 0.3),
        ];
        let sparse = vec![
            VectorMatch::new(EntityId::new(2), 0.1),
            VectorMatch::new(EntityId::new(1), 0.2),
            VectorMatch::new(EntityId::new(4), 0.3),
        ];

        let results = reciprocal_rank_fusion(&dense, &sparse, 60, 10);

        assert_eq!(results.len(), 4);

        // Entities 1 and 2 appear near the top of both lists, so they lead.
        let top_two: Vec<EntityId> = results.iter().take(2).map(|m| m.entity_id).collect();
        assert!(top_two.contains(&EntityId::new(1)));
        assert!(top_two.contains(&EntityId::new(2)));

        // The best score maps to distance 0.0 and the list is sorted ascending.
        assert_near(results[0].combined_distance, 0.0, EPSILON);
        assert!(results.windows(2).all(|w| w[0].combined_distance <= w[1].combined_distance));
    }

    #[test]
    fn reciprocal_rank_fusion_respects_top_k() {
        let dense: Vec<VectorMatch> =
            (1..=10).map(|i| VectorMatch::new(EntityId::new(i), i as f32 * 0.1)).collect();
        let sparse: Vec<VectorMatch> =
            (5..=15).map(|i| VectorMatch::new(EntityId::new(i), i as f32 * 0.1)).collect();

        let results = reciprocal_rank_fusion(&dense, &sparse, 60, 5);

        assert_eq!(results.len(), 5);
    }

    #[test]
    fn reciprocal_rank_fusion_empty_inputs() {
        assert!(reciprocal_rank_fusion(&[], &[], 60, 5).is_empty());
    }

    #[test]
    fn hybrid_match_conversions() {
        let hybrid = HybridMatch::dense_only(EntityId::new(1), 0.5);
        assert!(hybrid.dense_distance.is_some());
        assert!(hybrid.sparse_distance.is_none());

        let hybrid = HybridMatch::sparse_only(EntityId::new(2), 0.3);
        assert!(hybrid.dense_distance.is_none());
        assert!(hybrid.sparse_distance.is_some());

        let vector_match: VectorMatch = hybrid.into();
        assert_eq!(vector_match.entity_id, EntityId::new(2));
        assert_near(vector_match.distance, 0.3, EPSILON);
    }
}