Skip to main content

khive_runtime/
fusion.rs

1//! Fusion strategies for combining ranked result lists.
2
3use std::collections::{hash_map::Entry, HashMap, HashSet};
4
5use uuid::Uuid;
6
7use khive_score::DeterministicScore;
8use khive_storage::types::{
9    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
10};
11use khive_storage::EntityFilter;
12use khive_types::SubstrateKind;
13
14use crate::error::RuntimeResult;
15use crate::retrieval::{SearchHit, SearchSource};
16use crate::runtime::{KhiveRuntime, NamespaceToken};
17
18pub use khive_fusion::FusionStrategy;
19
/// Over-fetch factor for hybrid search: the per-source candidate count is
/// `limit * CANDIDATE_MULTIPLIER` (saturating), so fusion sees more than `limit` hits per source.
const CANDIDATE_MULTIPLIER: u32 = 4;
21
22/// Fuse text and vector hits using the given strategy, returning at most `limit` results.
23pub fn fuse_with_strategy(
24    text_hits: Vec<TextSearchHit>,
25    vector_hits: Vec<VectorSearchHit>,
26    strategy: &FusionStrategy,
27    limit: usize,
28) -> Vec<SearchHit> {
29    match strategy {
30        FusionStrategy::VectorOnly => fuse_sources(Vec::new(), vector_hits, strategy, limit),
31        FusionStrategy::KeywordOnly => fuse_sources(text_hits, Vec::new(), strategy, limit),
32        FusionStrategy::Rrf { .. } | FusionStrategy::Weighted { .. } | FusionStrategy::Union => {
33            fuse_sources(text_hits, vector_hits, strategy, limit)
34        }
35    }
36}
37
38/// RRF convenience wrapper used by operations.rs (k=60 note search path).
39pub(crate) fn rrf_fuse_k(
40    text_hits: Vec<TextSearchHit>,
41    vector_hits: Vec<VectorSearchHit>,
42    k: usize,
43    limit: usize,
44) -> Vec<SearchHit> {
45    fuse_with_strategy(text_hits, vector_hits, &FusionStrategy::Rrf { k }, limit)
46}
47
48fn fuse_sources(
49    text_hits: Vec<TextSearchHit>,
50    vector_hits: Vec<VectorSearchHit>,
51    strategy: &FusionStrategy,
52    limit: usize,
53) -> Vec<SearchHit> {
54    let mut metadata: HashMap<Uuid, SearchHit> =
55        HashMap::with_capacity(text_hits.len() + vector_hits.len());
56
57    let text_source: Vec<(Uuid, DeterministicScore)> = text_hits
58        .into_iter()
59        .map(|h| {
60            let hit = SearchHit {
61                entity_id: h.subject_id,
62                score: h.score,
63                source: SearchSource::Text,
64                title: h.title,
65                snippet: h.snippet,
66            };
67            let id = hit.entity_id;
68            let score = hit.score;
69            merge_metadata(&mut metadata, hit);
70            (id, score)
71        })
72        .collect();
73
74    let vector_source: Vec<(Uuid, DeterministicScore)> = vector_hits
75        .into_iter()
76        .map(|h| {
77            let hit = SearchHit {
78                entity_id: h.subject_id,
79                score: h.score,
80                source: SearchSource::Vector,
81                title: None,
82                snippet: None,
83            };
84            let id = hit.entity_id;
85            let score = hit.score;
86            merge_metadata(&mut metadata, hit);
87            (id, score)
88        })
89        .collect();
90
91    khive_fusion::fuse(vec![text_source, vector_source], strategy, limit)
92        .into_iter()
93        .filter_map(|(id, score)| {
94            let mut hit = metadata.remove(&id)?;
95            hit.score = score;
96            Some(hit)
97        })
98        .collect()
99}
100
101fn merge_metadata(metadata: &mut HashMap<Uuid, SearchHit>, hit: SearchHit) {
102    match metadata.entry(hit.entity_id) {
103        Entry::Occupied(mut entry) => {
104            let existing = entry.get_mut();
105            existing.source = merge_sources(existing.source, hit.source);
106            if existing.title.is_none() {
107                existing.title = hit.title;
108            }
109            if existing.snippet.is_none() {
110                existing.snippet = hit.snippet;
111            }
112        }
113        Entry::Vacant(entry) => {
114            entry.insert(hit);
115        }
116    }
117}
118
119fn merge_sources(left: SearchSource, right: SearchSource) -> SearchSource {
120    match (left, right) {
121        (SearchSource::Both, _) | (_, SearchSource::Both) => SearchSource::Both,
122        (SearchSource::Text, SearchSource::Vector) | (SearchSource::Vector, SearchSource::Text) => {
123            SearchSource::Both
124        }
125        (SearchSource::Text, SearchSource::Text) => SearchSource::Text,
126        (SearchSource::Vector, SearchSource::Vector) => SearchSource::Vector,
127    }
128}
129
130impl KhiveRuntime {
131    /// Hybrid search with a caller-supplied fusion strategy.
132    pub async fn hybrid_search_with_strategy(
133        &self,
134        token: &NamespaceToken,
135        query_text: &str,
136        query_vector: Option<Vec<f32>>,
137        strategy: FusionStrategy,
138        limit: u32,
139    ) -> RuntimeResult<Vec<SearchHit>> {
140        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
141
142        let ns = token.namespace().as_str().to_owned();
143        let text_hits = self
144            .text(token)?
145            .search(TextSearchRequest {
146                query: query_text.to_string(),
147                mode: TextQueryMode::Plain,
148                filter: Some(TextFilter {
149                    namespaces: vec![ns.clone()],
150                    ..TextFilter::default()
151                }),
152                top_k: candidates,
153                snippet_chars: 200,
154            })
155            .await?;
156
157        let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
158            self.vector_search(
159                token,
160                query_vector,
161                Some(query_text),
162                candidates,
163                Some(SubstrateKind::Entity),
164            )
165            .await?
166        } else {
167            Vec::new()
168        };
169
170        let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize);
171
172        // Filter out soft-deleted entities. A single query fetches all alive IDs from the
173        // fused set; any ID absent from the result has been soft-deleted (deleted_at IS NOT NULL).
174        if !fused.is_empty() {
175            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
176            let alive_page = self
177                .entities(token)?
178                .query_entities(
179                    token.namespace().as_str(),
180                    EntityFilter {
181                        ids: candidate_ids,
182                        ..EntityFilter::default()
183                    },
184                    PageRequest {
185                        offset: 0,
186                        limit: fused.len() as u32,
187                    },
188                )
189                .await?;
190            let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
191            fused.retain(|h| alive.contains(&h.entity_id));
192        }
193
194        Ok(fused)
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    /// Text hit fixture; rank and snippet are fixed placeholders.
    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        let title = Some(title.to_string());
        let snippet = Some("...".to_string());
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title,
            snippet,
        }
    }

    /// Vector hit fixture; rank is a fixed placeholder.
    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        let score = DeterministicScore::from_f64(score);
        VectorSearchHit {
            subject_id: id,
            score,
            rank: 1,
        }
    }

    /// Text list where `first` scores 0.9 and `second` scores 0.1.
    fn text_pair(first: Uuid, second: Uuid) -> Vec<TextSearchHit> {
        vec![text_hit(first, 0.9, "a"), text_hit(second, 0.1, "b")]
    }

    /// Vector list where `first` scores 0.9 and `second` scores 0.1.
    fn vector_pair(first: Uuid, second: Uuid) -> Vec<VectorSearchHit> {
        vec![vector_hit(first, 0.9), vector_hit(second, 0.1)]
    }

    // 1. RRF: a smaller k gives the top-ranked hit a larger fused score.
    //    With only a text source, rank 1 stays first for any k; what changes is the score
    //    (rank 1 contributes 1/(1+1) at k=1 versus 1/61 at k=60).
    #[test]
    fn rrf_custom_k_differs_from_k60() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());
        let fuse_rrf = |k| fuse_with_strategy(text_pair(a, b), vec![], &FusionStrategy::Rrf { k }, 10);

        let hits_k1 = fuse_rrf(1);
        let hits_k60 = fuse_rrf(60);

        // Ordering is the same for both k values.
        assert_eq!(hits_k1[0].entity_id, a);
        assert_eq!(hits_k60[0].entity_id, a);
        // The k=1 score for rank 1 exceeds the k=60 score.
        assert!(hits_k1[0].score > hits_k60[0].score);
    }

    // 2. Weighted [0.7, 0.3] gives different ordering than [0.3, 0.7]
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());
        // `a` leads the text list, `b` leads the vector list.
        let text_favoured = FusionStrategy::Weighted {
            weights: vec![0.7, 0.3],
        };
        let vector_favoured = FusionStrategy::Weighted {
            weights: vec![0.3, 0.7],
        };

        let heavy_text = fuse_with_strategy(text_pair(a, b), vector_pair(b, a), &text_favoured, 10);
        let heavy_vec = fuse_with_strategy(text_pair(a, b), vector_pair(b, a), &vector_favoured, 10);

        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    // 3. Weighted [7.0, 3.0] = Weighted [0.7, 0.3] (normalization)
    #[test]
    fn weighted_scale_invariant() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());
        let unit_scale = FusionStrategy::Weighted {
            weights: vec![0.7, 0.3],
        };
        let ten_scale = FusionStrategy::Weighted {
            weights: vec![7.0, 3.0],
        };

        let w1 = fuse_with_strategy(text_pair(a, b), vector_pair(b, a), &unit_scale, 10);
        let w2 = fuse_with_strategy(text_pair(a, b), vector_pair(b, a), &ten_scale, 10);

        // Same order in both positions, and matching top score.
        for position in 0..2 {
            assert_eq!(w1[position].entity_id, w2[position].entity_id);
        }
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    // 4. Weighted [0.0, 0.0] falls back to equal weights
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());
        // Both sources rank `a` above `b`.
        let all_zero = FusionStrategy::Weighted {
            weights: vec![0.0, 0.0],
        };

        let hits = fuse_with_strategy(text_pair(a, b), vector_pair(a, b), &all_zero, 10);

        assert_eq!(hits[0].entity_id, a);
    }

    // 5. Weighted with negative weight clamps to 0
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        // The vector list is empty, so only the text hit can come back.
        let negative_vector = FusionStrategy::Weighted {
            weights: vec![1.0, -0.5],
        };

        let hits = fuse_with_strategy(vec![text_hit(a, 0.9, "a")], vec![], &negative_vector, 10);

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
    }

    // 6. Union returns max score per entity when same id appears in both lists
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();

        let hits = fuse_with_strategy(
            vec![text_hit(a, 0.3, "a")],
            vec![vector_hit(a, 0.9)],
            &FusionStrategy::Union,
            10,
        );

        assert_eq!(hits.len(), 1);
        let top = &hits[0];
        assert!((top.score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(top.source, SearchSource::Both);
    }

    // 7. VectorOnly returns vector hits only (text hits dropped)
    #[test]
    fn vector_only_drops_text() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());

        let hits = fuse_with_strategy(
            vec![text_hit(b, 0.9, "b")],
            vec![vector_hit(a, 0.8)],
            &FusionStrategy::VectorOnly,
            10,
        );

        assert_eq!(hits.len(), 1);
        let only = &hits[0];
        assert_eq!(only.entity_id, a);
        assert_eq!(only.source, SearchSource::Vector);
        assert!(only.title.is_none());
    }

    // 8. Default strategy is Rrf{k:60}
    #[test]
    fn default_strategy_is_rrf_k60() {
        let expected = FusionStrategy::Rrf { k: 60 };
        assert_eq!(FusionStrategy::default(), expected);
    }

    // 9. Roundtrip serde preserves variant
    #[test]
    fn serde_roundtrip() {
        let cases = vec![
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
        ];
        for strategy in &cases {
            let json = serde_json::to_string(strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, &back, "roundtrip failed for {json}");
        }
    }
}