1use std::collections::{HashMap, HashSet};
4
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use khive_score::{rrf_score, DeterministicScore};
9use khive_storage::types::{
10 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
11};
12use khive_storage::EntityFilter;
13use khive_types::SubstrateKind;
14
15use crate::error::RuntimeResult;
16use crate::retrieval::{SearchHit, SearchSource};
17use crate::runtime::KhiveRuntime;
18
/// Over-fetch factor applied to the caller's `limit`.
///
/// `hybrid_search_with_strategy` asks the text and vector retrievers for
/// `limit * CANDIDATE_MULTIPLIER` candidates (saturating) so that fusion has
/// more than `limit` hits to rank.
const CANDIDATE_MULTIPLIER: u32 = 4;
20
/// Selects how text-search hits and vector-search hits are merged into a
/// single ranked list by [`fuse_with_strategy`].
///
/// Variant names are serialized in `snake_case`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionStrategy {
    /// Rank-based fusion: each hit contributes `rrf_score(rank, k)` where
    /// `rank` is its 1-based position in its source list.
    Rrf { k: usize },
    /// Weighted sum of min-max-normalized scores. `weights[0]` applies to text
    /// hits and `weights[1]` to vector hits; missing or negative entries count
    /// as zero, the pair is rescaled to sum to one, and an all-zero pair falls
    /// back to equal weights.
    Weighted { weights: Vec<f64> },
    /// Keeps every entity seen in either list, keeping the larger of its
    /// per-source scores.
    Union,
    /// Ignores text hits and ranks by vector score alone.
    VectorOnly,
}
35
36impl Default for FusionStrategy {
37 fn default() -> Self {
38 Self::Rrf { k: 60 }
39 }
40}
41
42pub fn fuse_with_strategy(
44 text_hits: Vec<TextSearchHit>,
45 vector_hits: Vec<VectorSearchHit>,
46 strategy: &FusionStrategy,
47 limit: usize,
48) -> Vec<SearchHit> {
49 match strategy {
50 FusionStrategy::Rrf { k } => rrf_fuse_k(text_hits, vector_hits, *k, limit),
51 FusionStrategy::Weighted { weights } => {
52 weighted_fuse(text_hits, vector_hits, weights, limit)
53 }
54 FusionStrategy::Union => union_fuse(text_hits, vector_hits, limit),
55 FusionStrategy::VectorOnly => vector_only(vector_hits, limit),
56 }
57}
58
59impl KhiveRuntime {
60 pub async fn hybrid_search_with_strategy(
62 &self,
63 namespace: Option<&str>,
64 query_text: &str,
65 query_vector: Option<Vec<f32>>,
66 strategy: FusionStrategy,
67 limit: u32,
68 ) -> RuntimeResult<Vec<SearchHit>> {
69 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
70
71 let ns = self.ns(namespace).to_string();
72 let text_hits = self
73 .text(namespace)?
74 .search(TextSearchRequest {
75 query: query_text.to_string(),
76 mode: TextQueryMode::Plain,
77 filter: Some(TextFilter {
78 namespaces: vec![ns.clone()],
79 ..TextFilter::default()
80 }),
81 top_k: candidates,
82 snippet_chars: 200,
83 })
84 .await?;
85
86 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
87 self.vector_search(
88 namespace,
89 query_vector,
90 Some(query_text),
91 candidates,
92 Some(SubstrateKind::Entity),
93 )
94 .await?
95 } else {
96 Vec::new()
97 };
98
99 let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize);
100
101 if !fused.is_empty() {
104 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
105 let alive_page = self
106 .entities(namespace)?
107 .query_entities(
108 self.ns(namespace),
109 EntityFilter {
110 ids: candidate_ids,
111 ..EntityFilter::default()
112 },
113 PageRequest {
114 offset: 0,
115 limit: fused.len() as u32,
116 },
117 )
118 .await?;
119 let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
120 fused.retain(|h| alive.contains(&h.entity_id));
121 }
122
123 Ok(fused)
124 }
125}
126
127fn rrf_fuse_k(
128 text_hits: Vec<TextSearchHit>,
129 vector_hits: Vec<VectorSearchHit>,
130 k: usize,
131 limit: usize,
132) -> Vec<SearchHit> {
133 #[derive(Default)]
134 struct Bucket {
135 score: DeterministicScore,
136 source: Option<SearchSource>,
137 title: Option<String>,
138 snippet: Option<String>,
139 }
140
141 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
142
143 for (i, hit) in text_hits.into_iter().enumerate() {
144 let entry = buckets.entry(hit.subject_id).or_default();
145 entry.score = entry.score + rrf_score(i + 1, k);
146 entry.source = Some(match entry.source {
147 Some(SearchSource::Vector) => SearchSource::Both,
148 _ => SearchSource::Text,
149 });
150 if entry.title.is_none() {
151 entry.title = hit.title;
152 }
153 if entry.snippet.is_none() {
154 entry.snippet = hit.snippet;
155 }
156 }
157
158 for (i, hit) in vector_hits.into_iter().enumerate() {
159 let entry = buckets.entry(hit.subject_id).or_default();
160 entry.score = entry.score + rrf_score(i + 1, k);
161 entry.source = Some(match entry.source {
162 Some(SearchSource::Text) => SearchSource::Both,
163 _ => SearchSource::Vector,
164 });
165 }
166
167 let mut hits: Vec<SearchHit> = buckets
168 .into_iter()
169 .map(|(id, b)| SearchHit {
170 entity_id: id,
171 score: b.score,
172 source: b.source.expect("each bucket gets a source"),
173 title: b.title,
174 snippet: b.snippet,
175 })
176 .collect();
177
178 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
179 hits.truncate(limit);
180 hits
181}
182
183fn weighted_fuse(
184 text_hits: Vec<TextSearchHit>,
185 vector_hits: Vec<VectorSearchHit>,
186 weights: &[f64],
187 limit: usize,
188) -> Vec<SearchHit> {
189 let w0 = weights.first().copied().unwrap_or(0.0).max(0.0);
191 let w1 = weights.get(1).copied().unwrap_or(0.0).max(0.0);
192 let total = w0 + w1;
193 let (nw0, nw1) = if total <= 0.0 {
194 (0.5, 0.5)
195 } else {
196 (w0 / total, w1 / total)
197 };
198
199 let mut meta: HashMap<Uuid, (Option<String>, Option<String>)> = HashMap::new();
201 let text_scores: Vec<(Uuid, f64)> = text_hits
202 .into_iter()
203 .map(|h| {
204 meta.entry(h.subject_id)
205 .or_insert_with(|| (h.title, h.snippet));
206 (h.subject_id, h.score.to_f64())
207 })
208 .collect();
209
210 let vector_scores: Vec<(Uuid, f64)> = vector_hits
211 .into_iter()
212 .map(|h| (h.subject_id, h.score.to_f64()))
213 .collect();
214
215 let text_norm = min_max_normalize(&text_scores);
217 let vector_norm = min_max_normalize(&vector_scores);
218
219 let mut combined: HashMap<Uuid, f64> = HashMap::new();
220 for (id, s) in &text_norm {
221 *combined.entry(*id).or_insert(0.0) += s * nw0;
222 }
223 for (id, s) in &vector_norm {
224 *combined.entry(*id).or_insert(0.0) += s * nw1;
225 }
226
227 let mut hits: Vec<SearchHit> = combined
228 .into_iter()
229 .map(|(id, score)| {
230 let (title, snippet) = meta.get(&id).cloned().unwrap_or_default();
231 let source = match (
232 text_norm.iter().any(|(i, _)| *i == id),
233 vector_norm.iter().any(|(i, _)| *i == id),
234 ) {
235 (true, true) => SearchSource::Both,
236 (true, false) => SearchSource::Text,
237 _ => SearchSource::Vector,
238 };
239 SearchHit {
240 entity_id: id,
241 score: DeterministicScore::from_f64(score),
242 source,
243 title,
244 snippet,
245 }
246 })
247 .collect();
248
249 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
250 hits.truncate(limit);
251 hits
252}
253
254fn min_max_normalize(scores: &[(Uuid, f64)]) -> Vec<(Uuid, f64)> {
255 if scores.is_empty() {
256 return Vec::new();
257 }
258 let min = scores.iter().map(|(_, s)| *s).fold(f64::INFINITY, f64::min);
259 let max = scores
260 .iter()
261 .map(|(_, s)| *s)
262 .fold(f64::NEG_INFINITY, f64::max);
263 let span = max - min;
264 if span <= f64::EPSILON {
265 return scores.iter().map(|(id, _)| (*id, 1.0)).collect();
266 }
267 scores
268 .iter()
269 .map(|(id, s)| (*id, (s - min) / span))
270 .collect()
271}
272
273fn union_fuse(
274 text_hits: Vec<TextSearchHit>,
275 vector_hits: Vec<VectorSearchHit>,
276 limit: usize,
277) -> Vec<SearchHit> {
278 struct Bucket {
279 score: DeterministicScore,
280 source: SearchSource,
281 title: Option<String>,
282 snippet: Option<String>,
283 }
284
285 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
286
287 for hit in text_hits {
288 let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
289 score: DeterministicScore::ZERO,
290 source: SearchSource::Text,
291 title: None,
292 snippet: None,
293 });
294 if hit.score > entry.score {
295 entry.score = hit.score;
296 }
297 if entry.title.is_none() {
298 entry.title = hit.title;
299 }
300 if entry.snippet.is_none() {
301 entry.snippet = hit.snippet;
302 }
303 if entry.source == SearchSource::Vector {
304 entry.source = SearchSource::Both;
305 }
306 }
307
308 for hit in vector_hits {
309 let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
310 score: DeterministicScore::ZERO,
311 source: SearchSource::Vector,
312 title: None,
313 snippet: None,
314 });
315 if hit.score > entry.score {
316 entry.score = hit.score;
317 }
318 if entry.source == SearchSource::Text {
319 entry.source = SearchSource::Both;
320 }
321 }
322
323 let mut hits: Vec<SearchHit> = buckets
324 .into_iter()
325 .map(|(id, b)| SearchHit {
326 entity_id: id,
327 score: b.score,
328 source: b.source,
329 title: b.title,
330 snippet: b.snippet,
331 })
332 .collect();
333
334 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
335 hits.truncate(limit);
336 hits
337}
338
339fn vector_only(vector_hits: Vec<VectorSearchHit>, limit: usize) -> Vec<SearchHit> {
340 let mut hits: Vec<SearchHit> = vector_hits
341 .into_iter()
342 .map(|h| SearchHit {
343 entity_id: h.subject_id,
344 score: h.score,
345 source: SearchSource::Vector,
346 title: None,
347 snippet: None,
348 })
349 .collect();
350 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
351 hits.truncate(limit);
352 hits
353}
354
// Unit tests for the pure fusion routines (driven through
// `fuse_with_strategy`) and for `FusionStrategy` defaults and serialization.
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    // Builds a text hit for `id` with the given score and title. The snippet
    // is a fixed placeholder and the rank is always 1.
    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    // Builds a vector hit for `id` with the given score; the rank is always 1.
    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
        }
    }

    // For the same text-only input, k = 1 and k = 60 both rank `a` first, but
    // the top score under k = 1 is strictly larger.
    #[test]
    fn rrf_custom_k_differs_from_k60() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let hits_k1 = fuse_with_strategy(text.clone(), vec![], &FusionStrategy::Rrf { k: 1 }, 10);
        let hits_k60 = fuse_with_strategy(text, vec![], &FusionStrategy::Rrf { k: 60 }, 10);
        assert_eq!(hits_k1[0].entity_id, a);
        assert_eq!(hits_k60[0].entity_id, a);
        assert!(hits_k1[0].score > hits_k60[0].score);
    }

    // Text favors `a` and vector favors `b`; whichever side carries the larger
    // weight decides the top result.
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let heavy_text = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let heavy_vec = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.3, 0.7],
            },
            10,
        );

        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    // Weights [0.7, 0.3] and [7.0, 3.0] have the same ratio, so they must give
    // the same ordering and (within 1e-9) the same top score.
    #[test]
    fn weighted_scale_invariant() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let w1 = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let w2 = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![7.0, 3.0],
            },
            10,
        );

        assert_eq!(w1[0].entity_id, w2[0].entity_id);
        assert_eq!(w1[1].entity_id, w2[1].entity_id);
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    // All-zero weights must still produce a ranking; both sources favor `a`,
    // so `a` comes first.
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(a, 0.9), vector_hit(b, 0.1)];

        let hits = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.0, 0.0],
            },
            10,
        );
        assert_eq!(hits[0].entity_id, a);
    }

    // A negative vector weight must not prevent the single text hit from
    // being returned.
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a")];
        let hits = fuse_with_strategy(
            text,
            vec![],
            &FusionStrategy::Weighted {
                weights: vec![1.0, -0.5],
            },
            10,
        );
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
    }

    // An entity present in both lists appears once, carries the larger of its
    // two scores, and is tagged `Both`.
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.3, "a")];
        let vec_hits = vec![vector_hit(a, 0.9)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::Union, 10);
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    // `VectorOnly` discards text hits entirely: only the vector entity is
    // returned, tagged `Vector`, with no title.
    #[test]
    fn vector_only_drops_text() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(b, 0.9, "b")];
        let vec_hits = vec![vector_hit(a, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::VectorOnly, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    // The `Default` impl yields `Rrf { k: 60 }`.
    #[test]
    fn default_strategy_is_rrf_k60() {
        assert_eq!(FusionStrategy::default(), FusionStrategy::Rrf { k: 60 });
    }

    // Every variant survives a JSON serialize/deserialize round trip unchanged.
    #[test]
    fn serde_roundtrip() {
        let cases = vec![
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
        ];
        for strategy in cases {
            let json = serde_json::to_string(&strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, back, "roundtrip failed for {json}");
        }
    }
}