1use std::collections::{hash_map::Entry, HashMap, HashSet};
4
5use uuid::Uuid;
6
7use khive_score::DeterministicScore;
8use khive_storage::types::{
9 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
10};
11use khive_storage::EntityFilter;
12use khive_types::SubstrateKind;
13
14use crate::error::RuntimeResult;
15use crate::retrieval::{SearchHit, SearchSource};
16use crate::runtime::{KhiveRuntime, NamespaceToken};
17
18pub use khive_fusion::FusionStrategy;
19
/// Over-fetch factor for hybrid search: each backend is asked for
/// `limit * CANDIDATE_MULTIPLIER` (saturating) candidates before fusion.
const CANDIDATE_MULTIPLIER: u32 = 4;
21
22pub fn fuse_with_strategy(
24 text_hits: Vec<TextSearchHit>,
25 vector_hits: Vec<VectorSearchHit>,
26 strategy: &FusionStrategy,
27 limit: usize,
28) -> Vec<SearchHit> {
29 match strategy {
30 FusionStrategy::VectorOnly => fuse_sources(Vec::new(), vector_hits, strategy, limit),
31 FusionStrategy::KeywordOnly => fuse_sources(text_hits, Vec::new(), strategy, limit),
32 FusionStrategy::Rrf { .. } | FusionStrategy::Weighted { .. } | FusionStrategy::Union => {
33 fuse_sources(text_hits, vector_hits, strategy, limit)
34 }
35 }
36}
37
38pub(crate) fn rrf_fuse_k(
40 text_hits: Vec<TextSearchHit>,
41 vector_hits: Vec<VectorSearchHit>,
42 k: usize,
43 limit: usize,
44) -> Vec<SearchHit> {
45 fuse_with_strategy(text_hits, vector_hits, &FusionStrategy::Rrf { k }, limit)
46}
47
48fn fuse_sources(
49 text_hits: Vec<TextSearchHit>,
50 vector_hits: Vec<VectorSearchHit>,
51 strategy: &FusionStrategy,
52 limit: usize,
53) -> Vec<SearchHit> {
54 let mut metadata: HashMap<Uuid, SearchHit> =
55 HashMap::with_capacity(text_hits.len() + vector_hits.len());
56
57 let text_source: Vec<(Uuid, DeterministicScore)> = text_hits
58 .into_iter()
59 .map(|h| {
60 let hit = SearchHit {
61 entity_id: h.subject_id,
62 score: h.score,
63 source: SearchSource::Text,
64 title: h.title,
65 snippet: h.snippet,
66 };
67 let id = hit.entity_id;
68 let score = hit.score;
69 merge_metadata(&mut metadata, hit);
70 (id, score)
71 })
72 .collect();
73
74 let vector_source: Vec<(Uuid, DeterministicScore)> = vector_hits
75 .into_iter()
76 .map(|h| {
77 let hit = SearchHit {
78 entity_id: h.subject_id,
79 score: h.score,
80 source: SearchSource::Vector,
81 title: None,
82 snippet: None,
83 };
84 let id = hit.entity_id;
85 let score = hit.score;
86 merge_metadata(&mut metadata, hit);
87 (id, score)
88 })
89 .collect();
90
91 khive_fusion::fuse(vec![text_source, vector_source], strategy, limit)
92 .into_iter()
93 .filter_map(|(id, score)| {
94 let mut hit = metadata.remove(&id)?;
95 hit.score = score;
96 Some(hit)
97 })
98 .collect()
99}
100
101fn merge_metadata(metadata: &mut HashMap<Uuid, SearchHit>, hit: SearchHit) {
102 match metadata.entry(hit.entity_id) {
103 Entry::Occupied(mut entry) => {
104 let existing = entry.get_mut();
105 existing.source = merge_sources(existing.source, hit.source);
106 if existing.title.is_none() {
107 existing.title = hit.title;
108 }
109 if existing.snippet.is_none() {
110 existing.snippet = hit.snippet;
111 }
112 }
113 Entry::Vacant(entry) => {
114 entry.insert(hit);
115 }
116 }
117}
118
119fn merge_sources(left: SearchSource, right: SearchSource) -> SearchSource {
120 match (left, right) {
121 (SearchSource::Both, _) | (_, SearchSource::Both) => SearchSource::Both,
122 (SearchSource::Text, SearchSource::Vector) | (SearchSource::Vector, SearchSource::Text) => {
123 SearchSource::Both
124 }
125 (SearchSource::Text, SearchSource::Text) => SearchSource::Text,
126 (SearchSource::Vector, SearchSource::Vector) => SearchSource::Vector,
127 }
128}
129
130impl KhiveRuntime {
131 pub async fn hybrid_search_with_strategy(
133 &self,
134 token: &NamespaceToken,
135 query_text: &str,
136 query_vector: Option<Vec<f32>>,
137 strategy: FusionStrategy,
138 limit: u32,
139 ) -> RuntimeResult<Vec<SearchHit>> {
140 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
141
142 let ns = token.namespace().as_str().to_owned();
143 let text_hits = self
144 .text(token)?
145 .search(TextSearchRequest {
146 query: query_text.to_string(),
147 mode: TextQueryMode::Plain,
148 filter: Some(TextFilter {
149 namespaces: vec![ns.clone()],
150 ..TextFilter::default()
151 }),
152 top_k: candidates,
153 snippet_chars: 200,
154 })
155 .await?;
156
157 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
158 self.vector_search(
159 token,
160 query_vector,
161 Some(query_text),
162 candidates,
163 Some(SubstrateKind::Entity),
164 )
165 .await?
166 } else {
167 Vec::new()
168 };
169
170 let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize);
171
172 if !fused.is_empty() {
175 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
176 let alive_page = self
177 .entities(token)?
178 .query_entities(
179 token.namespace().as_str(),
180 EntityFilter {
181 ids: candidate_ids,
182 ..EntityFilter::default()
183 },
184 PageRequest {
185 offset: 0,
186 limit: fused.len() as u32,
187 },
188 )
189 .await?;
190 let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
191 fused.retain(|h| alive.contains(&h.entity_id));
192 }
193
194 Ok(fused)
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    /// Builds a text hit with the given id, score and title, and a placeholder snippet.
    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Builds a vector hit with the given id and score.
    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
        }
    }

    /// Changing `k` keeps the ranking but changes the fused score: the top hit
    /// scores higher with `k = 1` than with `k = 60`.
    #[test]
    fn rrf_custom_k_differs_from_k60() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let hits_k1 = fuse_with_strategy(text.clone(), vec![], &FusionStrategy::Rrf { k: 1 }, 10);
        let hits_k60 = fuse_with_strategy(text, vec![], &FusionStrategy::Rrf { k: 60 }, 10);
        assert_eq!(hits_k1[0].entity_id, a);
        assert_eq!(hits_k60[0].entity_id, a);
        assert!(hits_k1[0].score > hits_k60[0].score);
    }

    /// With the two sources disagreeing on the order, the heavier-weighted
    /// source decides the winner.
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let heavy_text = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let heavy_vec = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.3, 0.7],
            },
            10,
        );

        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    /// Scaling all weights by the same factor changes neither order nor scores.
    #[test]
    fn weighted_scale_invariant() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let w1 = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let w2 = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![7.0, 3.0],
            },
            10,
        );

        assert_eq!(w1[0].entity_id, w2[0].entity_id);
        assert_eq!(w1[1].entity_id, w2[1].entity_id);
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    /// All-zero weights must still produce a ranking with `a` (top in both
    /// sources) first.
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(a, 0.9), vector_hit(b, 0.1)];

        let hits = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.0, 0.0],
            },
            10,
        );
        assert_eq!(hits[0].entity_id, a);
    }

    /// A negative weight must not drop or duplicate the single text hit.
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a")];
        let hits = fuse_with_strategy(
            text,
            vec![],
            &FusionStrategy::Weighted {
                weights: vec![1.0, -0.5],
            },
            10,
        );
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
    }

    /// `Union` reports one hit per entity, scored with the larger of the two
    /// source scores and tagged as coming from both sources.
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.3, "a")];
        let vec_hits = vec![vector_hit(a, 0.9)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::Union, 10);
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    /// `VectorOnly` ignores text hits entirely, including their metadata.
    #[test]
    fn vector_only_drops_text() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(b, 0.9, "b")];
        let vec_hits = vec![vector_hit(a, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::VectorOnly, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    /// `KeywordOnly` is the mirror image of `VectorOnly`: vector hits are
    /// ignored and the text hit keeps its title.
    #[test]
    fn keyword_only_drops_vector() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a")];
        let vec_hits = vec![vector_hit(b, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::KeywordOnly, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("a"));
    }

    /// The default strategy is `Rrf` with `k = 60`.
    #[test]
    fn default_strategy_is_rrf_k60() {
        assert_eq!(FusionStrategy::default(), FusionStrategy::Rrf { k: 60 });
    }

    /// Every strategy variant survives a JSON serialize/deserialize round trip.
    #[test]
    fn serde_roundtrip() {
        let cases = vec![
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
            FusionStrategy::KeywordOnly,
        ];
        for strategy in cases {
            let json = serde_json::to_string(&strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, back, "roundtrip failed for {json}");
        }
    }
}