Skip to main content

khive_runtime/
fusion.rs

1//! Fusion strategies for combining ranked result lists.
2
3use std::collections::{HashMap, HashSet};
4
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use khive_score::{rrf_score, DeterministicScore};
9use khive_storage::types::{
10    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
11    VectorSearchRequest,
12};
13use khive_storage::EntityFilter;
14use khive_types::SubstrateKind;
15
16use crate::error::RuntimeResult;
17use crate::retrieval::{SearchHit, SearchSource};
18use crate::runtime::KhiveRuntime;
19
/// Over-fetch factor for `hybrid_search_with_strategy`: each retrieval source is queried for
/// `limit * CANDIDATE_MULTIPLIER` (saturating) candidates before fusion.
const CANDIDATE_MULTIPLIER: u32 = 4;
21
/// Strategy for fusing ranked result lists from multiple retrieval sources.
///
/// Serialized with serde's default (externally tagged) enum representation and snake_case
/// variant names, e.g. `{"rrf":{"k":60}}`, `{"weighted":{"weights":[0.7,0.3]}}`, or
/// `"vector_only"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionStrategy {
    /// Reciprocal Rank Fusion. Uses only ranks; robust to different score scales.
    ///
    /// `k` is the smoothing constant handed to `rrf_score`; a smaller `k` gives the
    /// top-ranked hits a larger share of the fused score.
    Rrf { k: usize },
    /// Weighted linear combination. Min-max normalizes each source to [0,1] first.
    /// Weights are normalized to sum to 1.0; negatives clamped to 0; all-zero falls back to equal.
    ///
    /// `weights[0]` applies to the text source and `weights[1]` to the vector source. A missing
    /// entry counts as 0, and entries past index 1 are ignored.
    Weighted { weights: Vec<f64> },
    /// Take all hits; keep the max score per entity_id.
    ///
    /// Raw text and vector scores are compared directly, without any normalization.
    Union,
    /// Drop text hits; return vector hits only.
    ///
    /// The resulting hits carry no title or snippet.
    VectorOnly,
}
36
37impl Default for FusionStrategy {
38    fn default() -> Self {
39        Self::Rrf { k: 60 }
40    }
41}
42
43/// Fuse text and vector hits using the given strategy, returning at most `limit` results.
44pub fn fuse_with_strategy(
45    text_hits: Vec<TextSearchHit>,
46    vector_hits: Vec<VectorSearchHit>,
47    strategy: &FusionStrategy,
48    limit: usize,
49) -> Vec<SearchHit> {
50    match strategy {
51        FusionStrategy::Rrf { k } => rrf_fuse_k(text_hits, vector_hits, *k, limit),
52        FusionStrategy::Weighted { weights } => {
53            weighted_fuse(text_hits, vector_hits, weights, limit)
54        }
55        FusionStrategy::Union => union_fuse(text_hits, vector_hits, limit),
56        FusionStrategy::VectorOnly => vector_only(vector_hits, limit),
57    }
58}
59
60impl KhiveRuntime {
61    /// Hybrid search with a caller-supplied fusion strategy.
62    pub async fn hybrid_search_with_strategy(
63        &self,
64        namespace: Option<&str>,
65        query_text: &str,
66        query_vector: Option<Vec<f32>>,
67        strategy: FusionStrategy,
68        limit: u32,
69    ) -> RuntimeResult<Vec<SearchHit>> {
70        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
71
72        let ns = self.ns(namespace).to_string();
73        let text_hits = self
74            .text(namespace)?
75            .search(TextSearchRequest {
76                query: query_text.to_string(),
77                mode: TextQueryMode::Plain,
78                filter: Some(TextFilter {
79                    namespaces: vec![ns.clone()],
80                    ..TextFilter::default()
81                }),
82                top_k: candidates,
83                snippet_chars: 200,
84            })
85            .await?;
86
87        let vector_hits = if let Some(vec) = query_vector {
88            self.vectors(namespace)?
89                .search(VectorSearchRequest {
90                    query_embedding: vec,
91                    top_k: candidates,
92                    namespace: Some(ns.clone()),
93                    kind: Some(SubstrateKind::Entity),
94                })
95                .await?
96        } else {
97            Vec::new()
98        };
99
100        let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize);
101
102        // Filter out soft-deleted entities. A single query fetches all alive IDs from the
103        // fused set; any ID absent from the result has been soft-deleted (deleted_at IS NOT NULL).
104        if !fused.is_empty() {
105            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
106            let alive_page = self
107                .entities(namespace)?
108                .query_entities(
109                    self.ns(namespace),
110                    EntityFilter {
111                        ids: candidate_ids,
112                        ..EntityFilter::default()
113                    },
114                    PageRequest {
115                        offset: 0,
116                        limit: fused.len() as u32,
117                    },
118                )
119                .await?;
120            let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
121            fused.retain(|h| alive.contains(&h.entity_id));
122        }
123
124        Ok(fused)
125    }
126}
127
128fn rrf_fuse_k(
129    text_hits: Vec<TextSearchHit>,
130    vector_hits: Vec<VectorSearchHit>,
131    k: usize,
132    limit: usize,
133) -> Vec<SearchHit> {
134    #[derive(Default)]
135    struct Bucket {
136        score: DeterministicScore,
137        source: Option<SearchSource>,
138        title: Option<String>,
139        snippet: Option<String>,
140    }
141
142    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
143
144    for (i, hit) in text_hits.into_iter().enumerate() {
145        let entry = buckets.entry(hit.subject_id).or_default();
146        entry.score = entry.score + rrf_score(i + 1, k);
147        entry.source = Some(match entry.source {
148            Some(SearchSource::Vector) => SearchSource::Both,
149            _ => SearchSource::Text,
150        });
151        if entry.title.is_none() {
152            entry.title = hit.title;
153        }
154        if entry.snippet.is_none() {
155            entry.snippet = hit.snippet;
156        }
157    }
158
159    for (i, hit) in vector_hits.into_iter().enumerate() {
160        let entry = buckets.entry(hit.subject_id).or_default();
161        entry.score = entry.score + rrf_score(i + 1, k);
162        entry.source = Some(match entry.source {
163            Some(SearchSource::Text) => SearchSource::Both,
164            _ => SearchSource::Vector,
165        });
166    }
167
168    let mut hits: Vec<SearchHit> = buckets
169        .into_iter()
170        .map(|(id, b)| SearchHit {
171            entity_id: id,
172            score: b.score,
173            source: b.source.expect("each bucket gets a source"),
174            title: b.title,
175            snippet: b.snippet,
176        })
177        .collect();
178
179    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
180    hits.truncate(limit);
181    hits
182}
183
184fn weighted_fuse(
185    text_hits: Vec<TextSearchHit>,
186    vector_hits: Vec<VectorSearchHit>,
187    weights: &[f64],
188    limit: usize,
189) -> Vec<SearchHit> {
190    // Normalize: clamp negatives to 0, fall back to equal if all zero.
191    let w0 = weights.first().copied().unwrap_or(0.0).max(0.0);
192    let w1 = weights.get(1).copied().unwrap_or(0.0).max(0.0);
193    let total = w0 + w1;
194    let (nw0, nw1) = if total <= 0.0 {
195        (0.5, 0.5)
196    } else {
197        (w0 / total, w1 / total)
198    };
199
200    // Collect metadata from text hits before consuming them for scores.
201    let mut meta: HashMap<Uuid, (Option<String>, Option<String>)> = HashMap::new();
202    let text_scores: Vec<(Uuid, f64)> = text_hits
203        .into_iter()
204        .map(|h| {
205            meta.entry(h.subject_id)
206                .or_insert_with(|| (h.title, h.snippet));
207            (h.subject_id, h.score.to_f64())
208        })
209        .collect();
210
211    let vector_scores: Vec<(Uuid, f64)> = vector_hits
212        .into_iter()
213        .map(|h| (h.subject_id, h.score.to_f64()))
214        .collect();
215
216    // Per-source min-max normalize to [0, 1].
217    let text_norm = min_max_normalize(&text_scores);
218    let vector_norm = min_max_normalize(&vector_scores);
219
220    let mut combined: HashMap<Uuid, f64> = HashMap::new();
221    for (id, s) in &text_norm {
222        *combined.entry(*id).or_insert(0.0) += s * nw0;
223    }
224    for (id, s) in &vector_norm {
225        *combined.entry(*id).or_insert(0.0) += s * nw1;
226    }
227
228    let mut hits: Vec<SearchHit> = combined
229        .into_iter()
230        .map(|(id, score)| {
231            let (title, snippet) = meta.get(&id).cloned().unwrap_or_default();
232            let source = match (
233                text_norm.iter().any(|(i, _)| *i == id),
234                vector_norm.iter().any(|(i, _)| *i == id),
235            ) {
236                (true, true) => SearchSource::Both,
237                (true, false) => SearchSource::Text,
238                _ => SearchSource::Vector,
239            };
240            SearchHit {
241                entity_id: id,
242                score: DeterministicScore::from_f64(score),
243                source,
244                title,
245                snippet,
246            }
247        })
248        .collect();
249
250    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
251    hits.truncate(limit);
252    hits
253}
254
255fn min_max_normalize(scores: &[(Uuid, f64)]) -> Vec<(Uuid, f64)> {
256    if scores.is_empty() {
257        return Vec::new();
258    }
259    let min = scores.iter().map(|(_, s)| *s).fold(f64::INFINITY, f64::min);
260    let max = scores
261        .iter()
262        .map(|(_, s)| *s)
263        .fold(f64::NEG_INFINITY, f64::max);
264    let span = max - min;
265    if span <= f64::EPSILON {
266        return scores.iter().map(|(id, _)| (*id, 1.0)).collect();
267    }
268    scores
269        .iter()
270        .map(|(id, s)| (*id, (s - min) / span))
271        .collect()
272}
273
274fn union_fuse(
275    text_hits: Vec<TextSearchHit>,
276    vector_hits: Vec<VectorSearchHit>,
277    limit: usize,
278) -> Vec<SearchHit> {
279    struct Bucket {
280        score: DeterministicScore,
281        source: SearchSource,
282        title: Option<String>,
283        snippet: Option<String>,
284    }
285
286    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
287
288    for hit in text_hits {
289        let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
290            score: DeterministicScore::ZERO,
291            source: SearchSource::Text,
292            title: None,
293            snippet: None,
294        });
295        if hit.score > entry.score {
296            entry.score = hit.score;
297        }
298        if entry.title.is_none() {
299            entry.title = hit.title;
300        }
301        if entry.snippet.is_none() {
302            entry.snippet = hit.snippet;
303        }
304        if entry.source == SearchSource::Vector {
305            entry.source = SearchSource::Both;
306        }
307    }
308
309    for hit in vector_hits {
310        let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
311            score: DeterministicScore::ZERO,
312            source: SearchSource::Vector,
313            title: None,
314            snippet: None,
315        });
316        if hit.score > entry.score {
317            entry.score = hit.score;
318        }
319        if entry.source == SearchSource::Text {
320            entry.source = SearchSource::Both;
321        }
322    }
323
324    let mut hits: Vec<SearchHit> = buckets
325        .into_iter()
326        .map(|(id, b)| SearchHit {
327            entity_id: id,
328            score: b.score,
329            source: b.source,
330            title: b.title,
331            snippet: b.snippet,
332        })
333        .collect();
334
335    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
336    hits.truncate(limit);
337    hits
338}
339
340fn vector_only(vector_hits: Vec<VectorSearchHit>, limit: usize) -> Vec<SearchHit> {
341    let mut hits: Vec<SearchHit> = vector_hits
342        .into_iter()
343        .map(|h| SearchHit {
344            entity_id: h.subject_id,
345            score: h.score,
346            source: SearchSource::Vector,
347            title: None,
348            snippet: None,
349        })
350        .collect();
351    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
352    hits.truncate(limit);
353    hits
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
        }
    }

    // 1. A smaller RRF k gives the top-ranked hit a larger raw score
    #[test]
    fn rrf_smaller_k_raises_top_rank_score() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // With a single source the rank order is fixed, so only the score magnitude depends
        // on k: rank 1 contributes about 1/(1+1) = 0.5 at k=1 but only 1/61 at k=60.
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let hits_k1 = fuse_with_strategy(text.clone(), vec![], &FusionStrategy::Rrf { k: 1 }, 10);
        let hits_k60 = fuse_with_strategy(text, vec![], &FusionStrategy::Rrf { k: 60 }, 10);
        // Rank 1 always wins in single-source fusion.
        assert_eq!(hits_k1[0].entity_id, a);
        assert_eq!(hits_k60[0].entity_id, a);
        assert!(hits_k1[0].score > hits_k60[0].score);
    }

    // 2. RRF labels an entity found by both sources as Both and ranks it first
    #[test]
    fn rrf_marks_overlap_as_both() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(a, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::Rrf { k: 60 }, 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert_eq!(hits[1].entity_id, b);
        assert_eq!(hits[1].source, SearchSource::Text);
    }

    // 3. Weighted [0.7, 0.3] gives different ordering than [0.3, 0.7]
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // a scores high in text, b scores high in vector
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let heavy_text = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let heavy_vec = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.3, 0.7],
            },
            10,
        );

        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    // 4. Weighted [7.0, 3.0] = Weighted [0.7, 0.3] (normalization)
    #[test]
    fn weighted_scale_invariant() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let w1 = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let w2 = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![7.0, 3.0],
            },
            10,
        );

        assert_eq!(w1[0].entity_id, w2[0].entity_id);
        assert_eq!(w1[1].entity_id, w2[1].entity_id);
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    // 5. Weighted [0.0, 0.0] falls back to equal weights
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // Both sources agree: a > b
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(a, 0.9), vector_hit(b, 0.1)];

        let hits = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.0, 0.0],
            },
            10,
        );
        assert_eq!(hits[0].entity_id, a);
    }

    // 6. Weighted with negative weight clamps to 0
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // Text ranks a first, vector ranks b first. With the vector weight clamped to 0 the
        // text source carries all the weight, so a -> 1.0 and b -> 0.0. Without clamping the
        // normalized weights would be [2.0, -1.0], and a would score 2.0 and b -1.0.
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let hits = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![1.0, -0.5],
            },
            10,
        );
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert!((hits[0].score.to_f64() - 1.0).abs() < 1e-6);
        assert_eq!(hits[1].entity_id, b);
        assert!(hits[1].score.to_f64().abs() < 1e-6);
    }

    // 7. Union returns max score per entity when same id appears in both lists
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.3, "a")];
        let vec_hits = vec![vector_hit(a, 0.9)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::Union, 10);
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    // 8. VectorOnly returns vector hits only (text hits dropped)
    #[test]
    fn vector_only_drops_text() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(b, 0.9, "b")];
        let vec_hits = vec![vector_hit(a, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::VectorOnly, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    // 9. Default strategy is Rrf{k:60}
    #[test]
    fn default_strategy_is_rrf_k60() {
        assert_eq!(FusionStrategy::default(), FusionStrategy::Rrf { k: 60 });
    }

    // 10. Roundtrip serde preserves variant
    #[test]
    fn serde_roundtrip() {
        let cases = vec![
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
        ];
        for strategy in cases {
            let json = serde_json::to_string(&strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, back, "roundtrip failed for {json}");
        }
    }
}