1use std::collections::{hash_map::Entry, HashMap, HashSet};
4
5use uuid::Uuid;
6
7use khive_score::DeterministicScore;
8use khive_storage::types::{
9 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
10};
11use khive_storage::EntityFilter;
12use khive_types::SubstrateKind;
13
14use crate::error::RuntimeResult;
15use crate::retrieval::{SearchHit, SearchSource};
16use crate::runtime::{KhiveRuntime, NamespaceToken};
17
18pub use khive_fusion::FusionStrategy;
19
/// Over-fetch factor for hybrid search: `hybrid_search_with_strategy` asks each
/// search backend for `limit * CANDIDATE_MULTIPLIER` candidates (saturating)
/// before fusing the rankings.
const CANDIDATE_MULTIPLIER: u32 = 4;
21
22pub fn fuse_with_strategy(
24 text_hits: Vec<TextSearchHit>,
25 vector_hits: Vec<VectorSearchHit>,
26 strategy: &FusionStrategy,
27 limit: usize,
28) -> RuntimeResult<Vec<SearchHit>> {
29 match strategy {
30 FusionStrategy::VectorOnly => fuse_sources(Vec::new(), vector_hits, strategy, limit),
31 FusionStrategy::KeywordOnly => fuse_sources(text_hits, Vec::new(), strategy, limit),
32 FusionStrategy::Rrf { .. } | FusionStrategy::Weighted { .. } | FusionStrategy::Union => {
33 fuse_sources(text_hits, vector_hits, strategy, limit)
34 }
35 FusionStrategy::Custom { ref name, .. } => {
36 Err(khive_fusion::FuseError::CustomRequiresRuntime(name.clone()).into())
37 }
38 }
39}
40
41pub(crate) fn rrf_fuse_k(
43 text_hits: Vec<TextSearchHit>,
44 vector_hits: Vec<VectorSearchHit>,
45 k: usize,
46 limit: usize,
47) -> RuntimeResult<Vec<SearchHit>> {
48 fuse_with_strategy(text_hits, vector_hits, &FusionStrategy::Rrf { k }, limit)
49}
50
51fn fuse_sources(
52 text_hits: Vec<TextSearchHit>,
53 vector_hits: Vec<VectorSearchHit>,
54 strategy: &FusionStrategy,
55 limit: usize,
56) -> RuntimeResult<Vec<SearchHit>> {
57 let mut metadata: HashMap<Uuid, SearchHit> =
58 HashMap::with_capacity(text_hits.len() + vector_hits.len());
59
60 let text_source: Vec<(Uuid, DeterministicScore)> = text_hits
61 .into_iter()
62 .map(|h| {
63 let hit = SearchHit {
64 entity_id: h.subject_id,
65 score: h.score,
66 source: SearchSource::Text,
67 title: h.title,
68 snippet: h.snippet,
69 };
70 let id = hit.entity_id;
71 let score = hit.score;
72 merge_metadata(&mut metadata, hit);
73 (id, score)
74 })
75 .collect();
76
77 let vector_source: Vec<(Uuid, DeterministicScore)> = vector_hits
78 .into_iter()
79 .map(|h| {
80 let hit = SearchHit {
81 entity_id: h.subject_id,
82 score: h.score,
83 source: SearchSource::Vector,
84 title: None,
85 snippet: None,
86 };
87 let id = hit.entity_id;
88 let score = hit.score;
89 merge_metadata(&mut metadata, hit);
90 (id, score)
91 })
92 .collect();
93
94 let sources: Vec<Vec<(Uuid, DeterministicScore)>> = vec![text_source, vector_source]
95 .into_iter()
96 .filter(|s| !s.is_empty())
97 .collect();
98
99 Ok(khive_fusion::fuse(sources, strategy, limit)?
100 .into_iter()
101 .filter_map(|(id, score)| {
102 let mut hit = metadata.remove(&id)?;
103 hit.score = score;
104 Some(hit)
105 })
106 .collect())
107}
108
109fn merge_metadata(metadata: &mut HashMap<Uuid, SearchHit>, hit: SearchHit) {
110 match metadata.entry(hit.entity_id) {
111 Entry::Occupied(mut entry) => {
112 let existing = entry.get_mut();
113 existing.source = merge_sources(existing.source, hit.source);
114 if existing.title.is_none() {
115 existing.title = hit.title;
116 }
117 if existing.snippet.is_none() {
118 existing.snippet = hit.snippet;
119 }
120 }
121 Entry::Vacant(entry) => {
122 entry.insert(hit);
123 }
124 }
125}
126
127fn merge_sources(left: SearchSource, right: SearchSource) -> SearchSource {
128 match (left, right) {
129 (SearchSource::Both, _) | (_, SearchSource::Both) => SearchSource::Both,
130 (SearchSource::Text, SearchSource::Vector) | (SearchSource::Vector, SearchSource::Text) => {
131 SearchSource::Both
132 }
133 (SearchSource::Text, SearchSource::Text) => SearchSource::Text,
134 (SearchSource::Vector, SearchSource::Vector) => SearchSource::Vector,
135 }
136}
137
138impl KhiveRuntime {
139 pub async fn hybrid_search_with_strategy(
141 &self,
142 token: &NamespaceToken,
143 query_text: &str,
144 query_vector: Option<Vec<f32>>,
145 strategy: FusionStrategy,
146 limit: u32,
147 ) -> RuntimeResult<Vec<SearchHit>> {
148 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
149
150 let ns = token.namespace().as_str().to_owned();
151 let text_hits = self
152 .text(token)?
153 .search(TextSearchRequest {
154 query: query_text.to_string(),
155 mode: TextQueryMode::Plain,
156 filter: Some(TextFilter {
157 namespaces: vec![ns.clone()],
158 ..TextFilter::default()
159 }),
160 top_k: candidates,
161 snippet_chars: 200,
162 })
163 .await?;
164
165 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
166 self.vector_search(
167 token,
168 query_vector,
169 Some(query_text),
170 candidates,
171 Some(SubstrateKind::Entity),
172 )
173 .await?
174 } else {
175 Vec::new()
176 };
177
178 let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize)?;
179
180 if !fused.is_empty() {
183 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
184 let alive_page = self
185 .entities(token)?
186 .query_entities(
187 token.namespace().as_str(),
188 EntityFilter {
189 ids: candidate_ids,
190 ..EntityFilter::default()
191 },
192 PageRequest {
193 offset: 0,
194 limit: fused.len() as u32,
195 },
196 )
197 .await?;
198 let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
199 fused.retain(|h| alive.contains(&h.entity_id));
200 }
201
202 Ok(fused)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    /// Result limit used by every fusion call below; larger than any fixture.
    const LIMIT: usize = 10;

    /// Builds a keyword hit with the given title and a placeholder snippet.
    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title: Some(title.to_owned()),
            snippet: Some(String::from("...")),
        }
    }

    /// Builds a vector hit; vector hits carry no title or snippet.
    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
        }
    }

    /// Fuses with `LIMIT` and unwraps; panics are attributed to the caller.
    #[track_caller]
    fn fuse_ok(
        text: Vec<TextSearchHit>,
        vector: Vec<VectorSearchHit>,
        strategy: &FusionStrategy,
    ) -> Vec<SearchHit> {
        fuse_with_strategy(text, vector, strategy, LIMIT).unwrap()
    }

    /// Two entities ranked in opposite order by the two sources:
    /// text prefers `a`, vector prefers `b`.
    fn opposed_rankings() -> (Uuid, Uuid, Vec<TextSearchHit>, Vec<VectorSearchHit>) {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vector = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];
        (a, b, text, vector)
    }

    /// A smaller RRF `k` keeps the same winner but yields a larger top score.
    #[test]
    fn rrf_custom_k_differs_from_k60() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];

        let with_k1 = fuse_ok(text.clone(), Vec::new(), &FusionStrategy::Rrf { k: 1 });
        let with_k60 = fuse_ok(text, Vec::new(), &FusionStrategy::Rrf { k: 60 });

        assert_eq!(with_k1[0].entity_id, a);
        assert_eq!(with_k60[0].entity_id, a);
        assert!(with_k1[0].score > with_k60[0].score);
    }

    /// Whichever source carries the larger weight decides the top hit.
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let (a, b, text, vector) = opposed_rankings();
        let favour_text = FusionStrategy::Weighted {
            weights: vec![0.7, 0.3],
        };
        let favour_vector = FusionStrategy::Weighted {
            weights: vec![0.3, 0.7],
        };

        let heavy_text = fuse_ok(text.clone(), vector.clone(), &favour_text);
        let heavy_vec = fuse_ok(text, vector, &favour_vector);

        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    /// Scaling all weights by the same factor changes neither order nor score.
    #[test]
    fn weighted_scale_invariant() {
        let (_, _, text, vector) = opposed_rankings();
        let unit_scale = FusionStrategy::Weighted {
            weights: vec![0.7, 0.3],
        };
        let ten_scale = FusionStrategy::Weighted {
            weights: vec![7.0, 3.0],
        };

        let w1 = fuse_ok(text.clone(), vector.clone(), &unit_scale);
        let w2 = fuse_ok(text, vector, &ten_scale);

        assert_eq!(w1[0].entity_id, w2[0].entity_id);
        assert_eq!(w1[1].entity_id, w2[1].entity_id);
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    /// All-zero weights still produce a ranking led by the entity both
    /// sources agree on.
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vector = vec![vector_hit(a, 0.9), vector_hit(b, 0.1)];
        let zeroed = FusionStrategy::Weighted {
            weights: vec![0.0, 0.0],
        };

        let hits = fuse_ok(text, vector, &zeroed);

        assert_eq!(hits[0].entity_id, a);
    }

    /// A negative weight must not make fusion fail or drop the only hit.
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        let with_negative = FusionStrategy::Weighted {
            weights: vec![1.0, -0.5],
        };

        let hits = fuse_ok(vec![text_hit(a, 0.9, "a")], Vec::new(), &with_negative);

        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
    }

    /// Union keeps one hit per entity, scored with the higher of the two
    /// source scores and marked as found by both sources.
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();

        let hits = fuse_ok(
            vec![text_hit(a, 0.3, "a")],
            vec![vector_hit(a, 0.9)],
            &FusionStrategy::Union,
        );

        assert_eq!(hits.len(), 1);
        assert!((hits[0].score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    /// `VectorOnly` ignores keyword hits entirely, including their titles.
    #[test]
    fn vector_only_drops_text() {
        let (a, b) = (Uuid::new_v4(), Uuid::new_v4());

        let hits = fuse_ok(
            vec![text_hit(b, 0.9, "b")],
            vec![vector_hit(a, 0.8)],
            &FusionStrategy::VectorOnly,
        );

        assert_eq!(hits.len(), 1);
        let only = &hits[0];
        assert_eq!(only.entity_id, a);
        assert_eq!(only.source, SearchSource::Vector);
        assert!(only.title.is_none());
    }

    /// The default strategy is reciprocal-rank fusion with `k = 60`.
    #[test]
    fn default_strategy_is_rrf_k60() {
        assert_eq!(FusionStrategy::default(), FusionStrategy::Rrf { k: 60 });
    }

    /// Every listed strategy survives a JSON serialize/deserialize round trip.
    #[test]
    fn serde_roundtrip() {
        let cases = [
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
        ];
        for strategy in cases {
            let json = serde_json::to_string(&strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, back, "roundtrip failed for {json}");
        }
    }
}