Skip to main content

khive_runtime/
fusion.rs

1//! Fusion strategies for combining ranked result lists.
2
3use std::collections::{HashMap, HashSet};
4
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use khive_score::{rrf_score, DeterministicScore};
9use khive_storage::types::{
10    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
11};
12use khive_storage::EntityFilter;
13use khive_types::SubstrateKind;
14
15use crate::error::RuntimeResult;
16use crate::retrieval::{SearchHit, SearchSource};
17use crate::runtime::KhiveRuntime;
18
/// Over-fetch factor for hybrid search: each retrieval source is asked for
/// `limit * CANDIDATE_MULTIPLIER` candidates (saturating) before the lists are fused.
const CANDIDATE_MULTIPLIER: u32 = 4;
20
/// Strategy for fusing ranked result lists from multiple retrieval sources.
///
/// Variants are (de)serialized with `snake_case` names.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion. Uses only ranks; robust to different score scales.
    ///
    /// `k` is forwarded to `rrf_score` together with each hit's 1-based position in its list.
    Rrf { k: usize },
    /// Weighted linear combination. Min-max normalizes each source to [0,1] first.
    /// Weights are normalized to sum to 1.0; negatives clamped to 0; all-zero falls back to equal.
    ///
    /// `weights[0]` applies to text hits and `weights[1]` to vector hits. A missing entry
    /// counts as 0 and entries past the second are ignored.
    Weighted { weights: Vec<f64> },
    /// Take all hits; keep the max score per entity_id.
    Union,
    /// Drop text hits; return vector hits only.
    ///
    /// Resulting hits carry no title or snippet, since those come from text hits.
    VectorOnly,
}
35
36impl Default for FusionStrategy {
37    fn default() -> Self {
38        Self::Rrf { k: 60 }
39    }
40}
41
42/// Fuse text and vector hits using the given strategy, returning at most `limit` results.
43pub fn fuse_with_strategy(
44    text_hits: Vec<TextSearchHit>,
45    vector_hits: Vec<VectorSearchHit>,
46    strategy: &FusionStrategy,
47    limit: usize,
48) -> Vec<SearchHit> {
49    match strategy {
50        FusionStrategy::Rrf { k } => rrf_fuse_k(text_hits, vector_hits, *k, limit),
51        FusionStrategy::Weighted { weights } => {
52            weighted_fuse(text_hits, vector_hits, weights, limit)
53        }
54        FusionStrategy::Union => union_fuse(text_hits, vector_hits, limit),
55        FusionStrategy::VectorOnly => vector_only(vector_hits, limit),
56    }
57}
58
59impl KhiveRuntime {
60    /// Hybrid search with a caller-supplied fusion strategy.
61    pub async fn hybrid_search_with_strategy(
62        &self,
63        namespace: Option<&str>,
64        query_text: &str,
65        query_vector: Option<Vec<f32>>,
66        strategy: FusionStrategy,
67        limit: u32,
68    ) -> RuntimeResult<Vec<SearchHit>> {
69        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
70
71        let ns = self.ns(namespace).to_string();
72        let text_hits = self
73            .text(namespace)?
74            .search(TextSearchRequest {
75                query: query_text.to_string(),
76                mode: TextQueryMode::Plain,
77                filter: Some(TextFilter {
78                    namespaces: vec![ns.clone()],
79                    ..TextFilter::default()
80                }),
81                top_k: candidates,
82                snippet_chars: 200,
83            })
84            .await?;
85
86        let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
87            self.vector_search(
88                namespace,
89                query_vector,
90                Some(query_text),
91                candidates,
92                Some(SubstrateKind::Entity),
93            )
94            .await?
95        } else {
96            Vec::new()
97        };
98
99        let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize);
100
101        // Filter out soft-deleted entities. A single query fetches all alive IDs from the
102        // fused set; any ID absent from the result has been soft-deleted (deleted_at IS NOT NULL).
103        if !fused.is_empty() {
104            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
105            let alive_page = self
106                .entities(namespace)?
107                .query_entities(
108                    self.ns(namespace),
109                    EntityFilter {
110                        ids: candidate_ids,
111                        ..EntityFilter::default()
112                    },
113                    PageRequest {
114                        offset: 0,
115                        limit: fused.len() as u32,
116                    },
117                )
118                .await?;
119            let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
120            fused.retain(|h| alive.contains(&h.entity_id));
121        }
122
123        Ok(fused)
124    }
125}
126
127fn rrf_fuse_k(
128    text_hits: Vec<TextSearchHit>,
129    vector_hits: Vec<VectorSearchHit>,
130    k: usize,
131    limit: usize,
132) -> Vec<SearchHit> {
133    #[derive(Default)]
134    struct Bucket {
135        score: DeterministicScore,
136        source: Option<SearchSource>,
137        title: Option<String>,
138        snippet: Option<String>,
139    }
140
141    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
142
143    for (i, hit) in text_hits.into_iter().enumerate() {
144        let entry = buckets.entry(hit.subject_id).or_default();
145        entry.score = entry.score + rrf_score(i + 1, k);
146        entry.source = Some(match entry.source {
147            Some(SearchSource::Vector) => SearchSource::Both,
148            _ => SearchSource::Text,
149        });
150        if entry.title.is_none() {
151            entry.title = hit.title;
152        }
153        if entry.snippet.is_none() {
154            entry.snippet = hit.snippet;
155        }
156    }
157
158    for (i, hit) in vector_hits.into_iter().enumerate() {
159        let entry = buckets.entry(hit.subject_id).or_default();
160        entry.score = entry.score + rrf_score(i + 1, k);
161        entry.source = Some(match entry.source {
162            Some(SearchSource::Text) => SearchSource::Both,
163            _ => SearchSource::Vector,
164        });
165    }
166
167    let mut hits: Vec<SearchHit> = buckets
168        .into_iter()
169        .map(|(id, b)| SearchHit {
170            entity_id: id,
171            score: b.score,
172            source: b.source.expect("each bucket gets a source"),
173            title: b.title,
174            snippet: b.snippet,
175        })
176        .collect();
177
178    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
179    hits.truncate(limit);
180    hits
181}
182
183fn weighted_fuse(
184    text_hits: Vec<TextSearchHit>,
185    vector_hits: Vec<VectorSearchHit>,
186    weights: &[f64],
187    limit: usize,
188) -> Vec<SearchHit> {
189    // Normalize: clamp negatives to 0, fall back to equal if all zero.
190    let w0 = weights.first().copied().unwrap_or(0.0).max(0.0);
191    let w1 = weights.get(1).copied().unwrap_or(0.0).max(0.0);
192    let total = w0 + w1;
193    let (nw0, nw1) = if total <= 0.0 {
194        (0.5, 0.5)
195    } else {
196        (w0 / total, w1 / total)
197    };
198
199    // Collect metadata from text hits before consuming them for scores.
200    let mut meta: HashMap<Uuid, (Option<String>, Option<String>)> = HashMap::new();
201    let text_scores: Vec<(Uuid, f64)> = text_hits
202        .into_iter()
203        .map(|h| {
204            meta.entry(h.subject_id)
205                .or_insert_with(|| (h.title, h.snippet));
206            (h.subject_id, h.score.to_f64())
207        })
208        .collect();
209
210    let vector_scores: Vec<(Uuid, f64)> = vector_hits
211        .into_iter()
212        .map(|h| (h.subject_id, h.score.to_f64()))
213        .collect();
214
215    // Per-source min-max normalize to [0, 1].
216    let text_norm = min_max_normalize(&text_scores);
217    let vector_norm = min_max_normalize(&vector_scores);
218
219    let mut combined: HashMap<Uuid, f64> = HashMap::new();
220    for (id, s) in &text_norm {
221        *combined.entry(*id).or_insert(0.0) += s * nw0;
222    }
223    for (id, s) in &vector_norm {
224        *combined.entry(*id).or_insert(0.0) += s * nw1;
225    }
226
227    let mut hits: Vec<SearchHit> = combined
228        .into_iter()
229        .map(|(id, score)| {
230            let (title, snippet) = meta.get(&id).cloned().unwrap_or_default();
231            let source = match (
232                text_norm.iter().any(|(i, _)| *i == id),
233                vector_norm.iter().any(|(i, _)| *i == id),
234            ) {
235                (true, true) => SearchSource::Both,
236                (true, false) => SearchSource::Text,
237                _ => SearchSource::Vector,
238            };
239            SearchHit {
240                entity_id: id,
241                score: DeterministicScore::from_f64(score),
242                source,
243                title,
244                snippet,
245            }
246        })
247        .collect();
248
249    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
250    hits.truncate(limit);
251    hits
252}
253
254fn min_max_normalize(scores: &[(Uuid, f64)]) -> Vec<(Uuid, f64)> {
255    if scores.is_empty() {
256        return Vec::new();
257    }
258    let min = scores.iter().map(|(_, s)| *s).fold(f64::INFINITY, f64::min);
259    let max = scores
260        .iter()
261        .map(|(_, s)| *s)
262        .fold(f64::NEG_INFINITY, f64::max);
263    let span = max - min;
264    if span <= f64::EPSILON {
265        return scores.iter().map(|(id, _)| (*id, 1.0)).collect();
266    }
267    scores
268        .iter()
269        .map(|(id, s)| (*id, (s - min) / span))
270        .collect()
271}
272
273fn union_fuse(
274    text_hits: Vec<TextSearchHit>,
275    vector_hits: Vec<VectorSearchHit>,
276    limit: usize,
277) -> Vec<SearchHit> {
278    struct Bucket {
279        score: DeterministicScore,
280        source: SearchSource,
281        title: Option<String>,
282        snippet: Option<String>,
283    }
284
285    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
286
287    for hit in text_hits {
288        let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
289            score: DeterministicScore::ZERO,
290            source: SearchSource::Text,
291            title: None,
292            snippet: None,
293        });
294        if hit.score > entry.score {
295            entry.score = hit.score;
296        }
297        if entry.title.is_none() {
298            entry.title = hit.title;
299        }
300        if entry.snippet.is_none() {
301            entry.snippet = hit.snippet;
302        }
303        if entry.source == SearchSource::Vector {
304            entry.source = SearchSource::Both;
305        }
306    }
307
308    for hit in vector_hits {
309        let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
310            score: DeterministicScore::ZERO,
311            source: SearchSource::Vector,
312            title: None,
313            snippet: None,
314        });
315        if hit.score > entry.score {
316            entry.score = hit.score;
317        }
318        if entry.source == SearchSource::Text {
319            entry.source = SearchSource::Both;
320        }
321    }
322
323    let mut hits: Vec<SearchHit> = buckets
324        .into_iter()
325        .map(|(id, b)| SearchHit {
326            entity_id: id,
327            score: b.score,
328            source: b.source,
329            title: b.title,
330            snippet: b.snippet,
331        })
332        .collect();
333
334    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
335    hits.truncate(limit);
336    hits
337}
338
339fn vector_only(vector_hits: Vec<VectorSearchHit>, limit: usize) -> Vec<SearchHit> {
340    let mut hits: Vec<SearchHit> = vector_hits
341        .into_iter()
342        .map(|h| SearchHit {
343            entity_id: h.subject_id,
344            score: h.score,
345            source: SearchSource::Vector,
346            title: None,
347            snippet: None,
348        })
349        .collect();
350    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
351    hits.truncate(limit);
352    hits
353}
354
/// Unit tests for the pure fusion routines, driven through `fuse_with_strategy`.
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    /// Build a text hit with the given id, score and title; rank and snippet are fixed.
    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Build a vector hit with the given id and score; rank is fixed.
    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
        }
    }

    // 1. RRF with a smaller k yields a larger top-rank score than k=60.
    #[test]
    fn rrf_custom_k_differs_from_k60() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // Single-source input: `a` sits at position 1 and `b` at position 2 under either k,
        // so the ordering cannot differ between the two runs. What differs is the size of
        // the rank-1 contribution (per the original note: 1/(1+1) for k=1 versus 1/61 for
        // k=60), and that is what the final assertion checks.
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let hits_k1 = fuse_with_strategy(text.clone(), vec![], &FusionStrategy::Rrf { k: 1 }, 10);
        let hits_k60 = fuse_with_strategy(text, vec![], &FusionStrategy::Rrf { k: 60 }, 10);
        // Both should have a first (rank 1 always wins in single-source)
        assert_eq!(hits_k1[0].entity_id, a);
        assert_eq!(hits_k60[0].entity_id, a);
        // k=1 produces higher raw score for rank 1 than k=60
        assert!(hits_k1[0].score > hits_k60[0].score);
    }

    // 2. Weighted [0.7, 0.3] gives different ordering than [0.3, 0.7]
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // a scores high in text, b scores high in vector
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let heavy_text = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let heavy_vec = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.3, 0.7],
            },
            10,
        );

        // The source with the larger weight decides which entity ranks first.
        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    // 3. Weighted [7.0, 3.0] = Weighted [0.7, 0.3] (normalization)
    #[test]
    fn weighted_scale_invariant() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let w1 = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let w2 = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![7.0, 3.0],
            },
            10,
        );

        // Same ordering and (within tolerance) the same top score for both weight scales.
        assert_eq!(w1[0].entity_id, w2[0].entity_id);
        assert_eq!(w1[1].entity_id, w2[1].entity_id);
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    // 4. Weighted [0.0, 0.0] falls back to equal weights
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // Both sources agree: a > b
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(a, 0.9), vector_hit(b, 0.1)];

        let hits = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.0, 0.0],
            },
            10,
        );
        assert_eq!(hits[0].entity_id, a);
    }

    // 5. Weighted with negative weight clamps to 0
    // NOTE(review): the vector list is empty here, so the negative vector weight never
    // multiplies anything and this test would also pass without clamping. Consider adding
    // vector hits to make it exercise the clamp.
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a")];
        // Negative vector weight → only text contributes
        let hits = fuse_with_strategy(
            text,
            vec![],
            &FusionStrategy::Weighted {
                weights: vec![1.0, -0.5],
            },
            10,
        );
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
    }

    // 6. Union returns max score per entity when same id appears in both lists
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.3, "a")];
        let vec_hits = vec![vector_hit(a, 0.9)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::Union, 10);
        // One merged hit, carrying the larger of the two scores and tagged as seen in both.
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    // 7. VectorOnly returns vector hits only (text hits dropped)
    #[test]
    fn vector_only_drops_text() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(b, 0.9, "b")];
        let vec_hits = vec![vector_hit(a, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::VectorOnly, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    // 8. Default strategy is Rrf{k:60}
    #[test]
    fn default_strategy_is_rrf_k60() {
        assert_eq!(FusionStrategy::default(), FusionStrategy::Rrf { k: 60 });
    }

    // 9. Roundtrip serde preserves variant
    #[test]
    fn serde_roundtrip() {
        let cases = vec![
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
        ];
        for strategy in cases {
            let json = serde_json::to_string(&strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, back, "roundtrip failed for {json}");
        }
    }
}