Skip to main content

khive_runtime/
fusion.rs

1//! Fusion strategies for combining ranked result lists.
2
3use std::collections::{hash_map::Entry, HashMap, HashSet};
4
5use uuid::Uuid;
6
7use khive_score::DeterministicScore;
8use khive_storage::types::{
9    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
10};
11use khive_storage::EntityFilter;
12use khive_types::SubstrateKind;
13
14use crate::error::RuntimeResult;
15use crate::retrieval::{SearchHit, SearchSource};
16use crate::runtime::{KhiveRuntime, NamespaceToken};
17
18pub use khive_fusion::FusionStrategy;
19
/// Over-fetch factor for hybrid search: each backend is asked for
/// `limit * CANDIDATE_MULTIPLIER` candidates (saturating, and never fewer than `limit`)
/// before the result lists are fused.
const CANDIDATE_MULTIPLIER: u32 = 4;
21
22/// Fuse text and vector hits using the given strategy, returning at most `limit` results.
23pub fn fuse_with_strategy(
24    text_hits: Vec<TextSearchHit>,
25    vector_hits: Vec<VectorSearchHit>,
26    strategy: &FusionStrategy,
27    limit: usize,
28) -> RuntimeResult<Vec<SearchHit>> {
29    match strategy {
30        FusionStrategy::VectorOnly => fuse_sources(Vec::new(), vector_hits, strategy, limit),
31        FusionStrategy::KeywordOnly => fuse_sources(text_hits, Vec::new(), strategy, limit),
32        FusionStrategy::Rrf { .. } | FusionStrategy::Weighted { .. } | FusionStrategy::Union => {
33            fuse_sources(text_hits, vector_hits, strategy, limit)
34        }
35        FusionStrategy::Custom { ref name, .. } => {
36            Err(khive_fusion::FuseError::CustomRequiresRuntime(name.clone()).into())
37        }
38    }
39}
40
41/// RRF convenience wrapper used by operations.rs (k=60 note search path).
42pub(crate) fn rrf_fuse_k(
43    text_hits: Vec<TextSearchHit>,
44    vector_hits: Vec<VectorSearchHit>,
45    k: usize,
46    limit: usize,
47) -> RuntimeResult<Vec<SearchHit>> {
48    fuse_with_strategy(text_hits, vector_hits, &FusionStrategy::Rrf { k }, limit)
49}
50
51fn fuse_sources(
52    text_hits: Vec<TextSearchHit>,
53    vector_hits: Vec<VectorSearchHit>,
54    strategy: &FusionStrategy,
55    limit: usize,
56) -> RuntimeResult<Vec<SearchHit>> {
57    let mut metadata: HashMap<Uuid, SearchHit> =
58        HashMap::with_capacity(text_hits.len() + vector_hits.len());
59
60    let text_source: Vec<(Uuid, DeterministicScore)> = text_hits
61        .into_iter()
62        .map(|h| {
63            let hit = SearchHit {
64                entity_id: h.subject_id,
65                score: h.score,
66                source: SearchSource::Text,
67                title: h.title,
68                snippet: h.snippet,
69            };
70            let id = hit.entity_id;
71            let score = hit.score;
72            merge_metadata(&mut metadata, hit);
73            (id, score)
74        })
75        .collect();
76
77    let vector_source: Vec<(Uuid, DeterministicScore)> = vector_hits
78        .into_iter()
79        .map(|h| {
80            let hit = SearchHit {
81                entity_id: h.subject_id,
82                score: h.score,
83                source: SearchSource::Vector,
84                title: None,
85                snippet: None,
86            };
87            let id = hit.entity_id;
88            let score = hit.score;
89            merge_metadata(&mut metadata, hit);
90            (id, score)
91        })
92        .collect();
93
94    let sources: Vec<Vec<(Uuid, DeterministicScore)>> = vec![text_source, vector_source]
95        .into_iter()
96        .filter(|s| !s.is_empty())
97        .collect();
98
99    Ok(khive_fusion::fuse(sources, strategy, limit)?
100        .into_iter()
101        .filter_map(|(id, score)| {
102            let mut hit = metadata.remove(&id)?;
103            hit.score = score;
104            Some(hit)
105        })
106        .collect())
107}
108
109fn merge_metadata(metadata: &mut HashMap<Uuid, SearchHit>, hit: SearchHit) {
110    match metadata.entry(hit.entity_id) {
111        Entry::Occupied(mut entry) => {
112            let existing = entry.get_mut();
113            existing.source = merge_sources(existing.source, hit.source);
114            if existing.title.is_none() {
115                existing.title = hit.title;
116            }
117            if existing.snippet.is_none() {
118                existing.snippet = hit.snippet;
119            }
120        }
121        Entry::Vacant(entry) => {
122            entry.insert(hit);
123        }
124    }
125}
126
127fn merge_sources(left: SearchSource, right: SearchSource) -> SearchSource {
128    match (left, right) {
129        (SearchSource::Both, _) | (_, SearchSource::Both) => SearchSource::Both,
130        (SearchSource::Text, SearchSource::Vector) | (SearchSource::Vector, SearchSource::Text) => {
131            SearchSource::Both
132        }
133        (SearchSource::Text, SearchSource::Text) => SearchSource::Text,
134        (SearchSource::Vector, SearchSource::Vector) => SearchSource::Vector,
135    }
136}
137
138impl KhiveRuntime {
139    /// Hybrid search with a caller-supplied fusion strategy.
140    pub async fn hybrid_search_with_strategy(
141        &self,
142        token: &NamespaceToken,
143        query_text: &str,
144        query_vector: Option<Vec<f32>>,
145        strategy: FusionStrategy,
146        limit: u32,
147    ) -> RuntimeResult<Vec<SearchHit>> {
148        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
149
150        let ns = token.namespace().as_str().to_owned();
151        let text_hits = self
152            .text(token)?
153            .search(TextSearchRequest {
154                query: query_text.to_string(),
155                mode: TextQueryMode::Plain,
156                filter: Some(TextFilter {
157                    namespaces: vec![ns.clone()],
158                    ..TextFilter::default()
159                }),
160                top_k: candidates,
161                snippet_chars: 200,
162            })
163            .await?;
164
165        let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
166            self.vector_search(
167                token,
168                query_vector,
169                Some(query_text),
170                candidates,
171                Some(SubstrateKind::Entity),
172            )
173            .await?
174        } else {
175            Vec::new()
176        };
177
178        let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize)?;
179
180        // Filter out soft-deleted entities. A single query fetches all alive IDs from the
181        // fused set; any ID absent from the result has been soft-deleted (deleted_at IS NOT NULL).
182        if !fused.is_empty() {
183            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
184            let alive_page = self
185                .entities(token)?
186                .query_entities(
187                    token.namespace().as_str(),
188                    EntityFilter {
189                        ids: candidate_ids,
190                        ..EntityFilter::default()
191                    },
192                    PageRequest {
193                        offset: 0,
194                        limit: fused.len() as u32,
195                    },
196                )
197                .await?;
198            let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
199            fused.retain(|h| alive.contains(&h.entity_id));
200        }
201
202        Ok(fused)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    /// Builds a text hit with the given id, score and title. Rank and snippet are filler.
    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Builds a vector hit with the given id and score. Rank is filler.
    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
        }
    }

    /// Position of `id` in a fused result list; panics if the id is absent.
    fn position(hits: &[SearchHit], id: Uuid) -> usize {
        hits.iter()
            .position(|h| h.entity_id == id)
            .expect("id present in fused output")
    }

    // 1. A smaller RRF `k` gives a larger raw score for the same rank. With a single source
    //    the ordering cannot change, so only score magnitude is compared:
    //    with k=1, rank 1 scores 1/(1+1) = 0.5; with k=60 it scores 1/(60+1) ≈ 0.016.
    #[test]
    fn rrf_custom_k_differs_from_k60() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let hits_k1 =
            fuse_with_strategy(text.clone(), vec![], &FusionStrategy::Rrf { k: 1 }, 10).unwrap();
        let hits_k60 =
            fuse_with_strategy(text, vec![], &FusionStrategy::Rrf { k: 60 }, 10).unwrap();
        // In a single source, rank 1 always wins.
        assert_eq!(hits_k1[0].entity_id, a);
        assert_eq!(hits_k60[0].entity_id, a);
        // k=1 produces a higher raw score for rank 1 than k=60.
        assert!(hits_k1[0].score > hits_k60[0].score);
    }

    // 1b. A different RRF `k` can change the fused ordering.
    //     `a` is rank 1 in the text list only; `b` is rank 4 in both lists.
    //     - Small k: the single top rank dominates (k=1: a = 1/2 > b = 2/5).
    //     - Large k: agreement across sources dominates (k=60: a = 1/61 < b = 2/64).
    //     The conclusion also holds if ranks are counted from 0.
    #[test]
    fn rrf_k_changes_ordering() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![
            text_hit(a, 0.9, "a"),
            text_hit(Uuid::new_v4(), 0.8, "x1"),
            text_hit(Uuid::new_v4(), 0.7, "x2"),
            text_hit(b, 0.6, "b"),
        ];
        let vec_hits = vec![
            vector_hit(Uuid::new_v4(), 0.9),
            vector_hit(Uuid::new_v4(), 0.8),
            vector_hit(Uuid::new_v4(), 0.7),
            vector_hit(b, 0.6),
        ];

        let k1 = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Rrf { k: 1 },
            10,
        )
        .unwrap();
        let k60 =
            fuse_with_strategy(text, vec_hits, &FusionStrategy::Rrf { k: 60 }, 10).unwrap();

        assert!(position(&k1, a) < position(&k1, b));
        assert!(position(&k60, b) < position(&k60, a));
    }

    // 2. Weighted [0.7, 0.3] gives a different ordering than [0.3, 0.7].
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // a scores high in text, b scores high in vector.
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let heavy_text = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        )
        .unwrap();
        let heavy_vec = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.3, 0.7],
            },
            10,
        )
        .unwrap();

        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    // 3. Weighted [7.0, 3.0] ranks and scores like [0.7, 0.3] (weights are normalized).
    #[test]
    fn weighted_scale_invariant() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let w1 = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        )
        .unwrap();
        let w2 = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![7.0, 3.0],
            },
            10,
        )
        .unwrap();

        assert_eq!(w1[0].entity_id, w2[0].entity_id);
        assert_eq!(w1[1].entity_id, w2[1].entity_id);
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    // 4. Weighted [0.0, 0.0] does not error and still ranks sensibly.
    //    khive_fusion is expected to fall back to equal weights. Both sources agree a > b,
    //    so a must come first.
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(a, 0.9), vector_hit(b, 0.1)];

        let hits = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.0, 0.0],
            },
            10,
        )
        .unwrap();
        assert_eq!(hits[0].entity_id, a);
    }

    // 5. Weighted accepts a negative weight without erroring. khive_fusion is expected to
    //    clamp it to 0. The vector list is empty here, so only the text hit can be returned.
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a")];
        let hits = fuse_with_strategy(
            text,
            vec![],
            &FusionStrategy::Weighted {
                weights: vec![1.0, -0.5],
            },
            10,
        )
        .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
    }

    // 6. Union keeps the max score per entity when the same id appears in both lists.
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.3, "a")];
        let vec_hits = vec![vector_hit(a, 0.9)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::Union, 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    // 7. VectorOnly returns vector hits only (text hits are dropped).
    #[test]
    fn vector_only_drops_text() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(b, 0.9, "b")];
        let vec_hits = vec![vector_hit(a, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::VectorOnly, 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    // 8. The default strategy is Rrf { k: 60 }.
    #[test]
    fn default_strategy_is_rrf_k60() {
        assert_eq!(FusionStrategy::default(), FusionStrategy::Rrf { k: 60 });
    }

    // 9. A serde roundtrip preserves the variant and its parameters.
    #[test]
    fn serde_roundtrip() {
        let cases = vec![
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
        ];
        for strategy in cases {
            let json = serde_json::to_string(&strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, back, "roundtrip failed for {json}");
        }
    }
}