1use std::collections::{HashMap, HashSet};
4
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use khive_score::{rrf_score, DeterministicScore};
9use khive_storage::types::{
10 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
11 VectorSearchRequest,
12};
13use khive_storage::EntityFilter;
14use khive_types::SubstrateKind;
15
16use crate::error::RuntimeResult;
17use crate::retrieval::{SearchHit, SearchSource};
18use crate::runtime::KhiveRuntime;
19
/// Over-fetch factor for hybrid search: each source (text and vector) is asked for
/// `limit * CANDIDATE_MULTIPLIER` hits (saturating) so fusion has more candidates to rank
/// than the caller ultimately receives.
const CANDIDATE_MULTIPLIER: u32 = 4;
21
/// How a text hit list and a vector hit list are combined by [`fuse_with_strategy`].
///
/// Serialized with snake_case variant names (`rrf`, `weighted`, `union`, `vector_only`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FusionStrategy {
    /// Reciprocal rank fusion: every hit adds `rrf_score(rank, k)` to its entity's score,
    /// where `rank` is the hit's 1-based position in its own source list.
    Rrf { k: usize },
    /// Weighted sum of per-source min–max normalized scores. `weights[0]` applies to text
    /// and `weights[1]` to vector; missing entries count as 0, negative weights are clamped
    /// to 0, further entries are ignored, and the pair is rescaled to sum to 1 (an equal
    /// split is used when both are 0).
    Weighted { weights: Vec<f64> },
    /// Union of both lists scored by raw (un-normalized) hit scores, keeping the larger
    /// score when an entity is seen more than once (see `union_fuse`).
    Union,
    /// Vector hits only: text hits are discarded and no title/snippet is returned.
    VectorOnly,
}
36
37impl Default for FusionStrategy {
38 fn default() -> Self {
39 Self::Rrf { k: 60 }
40 }
41}
42
43pub fn fuse_with_strategy(
45 text_hits: Vec<TextSearchHit>,
46 vector_hits: Vec<VectorSearchHit>,
47 strategy: &FusionStrategy,
48 limit: usize,
49) -> Vec<SearchHit> {
50 match strategy {
51 FusionStrategy::Rrf { k } => rrf_fuse_k(text_hits, vector_hits, *k, limit),
52 FusionStrategy::Weighted { weights } => {
53 weighted_fuse(text_hits, vector_hits, weights, limit)
54 }
55 FusionStrategy::Union => union_fuse(text_hits, vector_hits, limit),
56 FusionStrategy::VectorOnly => vector_only(vector_hits, limit),
57 }
58}
59
60impl KhiveRuntime {
61 pub async fn hybrid_search_with_strategy(
63 &self,
64 namespace: Option<&str>,
65 query_text: &str,
66 query_vector: Option<Vec<f32>>,
67 strategy: FusionStrategy,
68 limit: u32,
69 ) -> RuntimeResult<Vec<SearchHit>> {
70 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
71
72 let ns = self.ns(namespace).to_string();
73 let text_hits = self
74 .text(namespace)?
75 .search(TextSearchRequest {
76 query: query_text.to_string(),
77 mode: TextQueryMode::Plain,
78 filter: Some(TextFilter {
79 namespaces: vec![ns.clone()],
80 ..TextFilter::default()
81 }),
82 top_k: candidates,
83 snippet_chars: 200,
84 })
85 .await?;
86
87 let vector_hits = if let Some(vec) = query_vector {
88 self.vectors(namespace)?
89 .search(VectorSearchRequest {
90 query_embedding: vec,
91 top_k: candidates,
92 namespace: Some(ns.clone()),
93 kind: Some(SubstrateKind::Entity),
94 })
95 .await?
96 } else {
97 Vec::new()
98 };
99
100 let mut fused = fuse_with_strategy(text_hits, vector_hits, &strategy, limit as usize);
101
102 if !fused.is_empty() {
105 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
106 let alive_page = self
107 .entities(namespace)?
108 .query_entities(
109 self.ns(namespace),
110 EntityFilter {
111 ids: candidate_ids,
112 ..EntityFilter::default()
113 },
114 PageRequest {
115 offset: 0,
116 limit: fused.len() as u32,
117 },
118 )
119 .await?;
120 let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
121 fused.retain(|h| alive.contains(&h.entity_id));
122 }
123
124 Ok(fused)
125 }
126}
127
128fn rrf_fuse_k(
129 text_hits: Vec<TextSearchHit>,
130 vector_hits: Vec<VectorSearchHit>,
131 k: usize,
132 limit: usize,
133) -> Vec<SearchHit> {
134 #[derive(Default)]
135 struct Bucket {
136 score: DeterministicScore,
137 source: Option<SearchSource>,
138 title: Option<String>,
139 snippet: Option<String>,
140 }
141
142 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
143
144 for (i, hit) in text_hits.into_iter().enumerate() {
145 let entry = buckets.entry(hit.subject_id).or_default();
146 entry.score = entry.score + rrf_score(i + 1, k);
147 entry.source = Some(match entry.source {
148 Some(SearchSource::Vector) => SearchSource::Both,
149 _ => SearchSource::Text,
150 });
151 if entry.title.is_none() {
152 entry.title = hit.title;
153 }
154 if entry.snippet.is_none() {
155 entry.snippet = hit.snippet;
156 }
157 }
158
159 for (i, hit) in vector_hits.into_iter().enumerate() {
160 let entry = buckets.entry(hit.subject_id).or_default();
161 entry.score = entry.score + rrf_score(i + 1, k);
162 entry.source = Some(match entry.source {
163 Some(SearchSource::Text) => SearchSource::Both,
164 _ => SearchSource::Vector,
165 });
166 }
167
168 let mut hits: Vec<SearchHit> = buckets
169 .into_iter()
170 .map(|(id, b)| SearchHit {
171 entity_id: id,
172 score: b.score,
173 source: b.source.expect("each bucket gets a source"),
174 title: b.title,
175 snippet: b.snippet,
176 })
177 .collect();
178
179 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
180 hits.truncate(limit);
181 hits
182}
183
184fn weighted_fuse(
185 text_hits: Vec<TextSearchHit>,
186 vector_hits: Vec<VectorSearchHit>,
187 weights: &[f64],
188 limit: usize,
189) -> Vec<SearchHit> {
190 let w0 = weights.first().copied().unwrap_or(0.0).max(0.0);
192 let w1 = weights.get(1).copied().unwrap_or(0.0).max(0.0);
193 let total = w0 + w1;
194 let (nw0, nw1) = if total <= 0.0 {
195 (0.5, 0.5)
196 } else {
197 (w0 / total, w1 / total)
198 };
199
200 let mut meta: HashMap<Uuid, (Option<String>, Option<String>)> = HashMap::new();
202 let text_scores: Vec<(Uuid, f64)> = text_hits
203 .into_iter()
204 .map(|h| {
205 meta.entry(h.subject_id)
206 .or_insert_with(|| (h.title, h.snippet));
207 (h.subject_id, h.score.to_f64())
208 })
209 .collect();
210
211 let vector_scores: Vec<(Uuid, f64)> = vector_hits
212 .into_iter()
213 .map(|h| (h.subject_id, h.score.to_f64()))
214 .collect();
215
216 let text_norm = min_max_normalize(&text_scores);
218 let vector_norm = min_max_normalize(&vector_scores);
219
220 let mut combined: HashMap<Uuid, f64> = HashMap::new();
221 for (id, s) in &text_norm {
222 *combined.entry(*id).or_insert(0.0) += s * nw0;
223 }
224 for (id, s) in &vector_norm {
225 *combined.entry(*id).or_insert(0.0) += s * nw1;
226 }
227
228 let mut hits: Vec<SearchHit> = combined
229 .into_iter()
230 .map(|(id, score)| {
231 let (title, snippet) = meta.get(&id).cloned().unwrap_or_default();
232 let source = match (
233 text_norm.iter().any(|(i, _)| *i == id),
234 vector_norm.iter().any(|(i, _)| *i == id),
235 ) {
236 (true, true) => SearchSource::Both,
237 (true, false) => SearchSource::Text,
238 _ => SearchSource::Vector,
239 };
240 SearchHit {
241 entity_id: id,
242 score: DeterministicScore::from_f64(score),
243 source,
244 title,
245 snippet,
246 }
247 })
248 .collect();
249
250 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
251 hits.truncate(limit);
252 hits
253}
254
/// Min–max normalizes `(key, score)` pairs, preserving input order and keys.
///
/// The lowest score maps to 0.0 and the highest to 1.0. The function is generic over the
/// key so any copyable id type works; callers in this module pass `Uuid`s.
///
/// Degenerate inputs:
/// - empty input returns an empty vector;
/// - if every score is effectively the same (spread at most `f64::EPSILON`), every entry
///   maps to 1.0 instead of dividing by (near) zero.
fn min_max_normalize<K: Copy>(scores: &[(K, f64)]) -> Vec<(K, f64)> {
    if scores.is_empty() {
        return Vec::new();
    }
    // Single pass for both extremes; `f64::min`/`f64::max` skip NaN operands.
    let (min, max) = scores
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &(_, s)| {
            (lo.min(s), hi.max(s))
        });
    let span = max - min;
    if span <= f64::EPSILON {
        return scores.iter().map(|&(id, _)| (id, 1.0)).collect();
    }
    scores
        .iter()
        .map(|&(id, s)| (id, (s - min) / span))
        .collect()
}
273
274fn union_fuse(
275 text_hits: Vec<TextSearchHit>,
276 vector_hits: Vec<VectorSearchHit>,
277 limit: usize,
278) -> Vec<SearchHit> {
279 struct Bucket {
280 score: DeterministicScore,
281 source: SearchSource,
282 title: Option<String>,
283 snippet: Option<String>,
284 }
285
286 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
287
288 for hit in text_hits {
289 let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
290 score: DeterministicScore::ZERO,
291 source: SearchSource::Text,
292 title: None,
293 snippet: None,
294 });
295 if hit.score > entry.score {
296 entry.score = hit.score;
297 }
298 if entry.title.is_none() {
299 entry.title = hit.title;
300 }
301 if entry.snippet.is_none() {
302 entry.snippet = hit.snippet;
303 }
304 if entry.source == SearchSource::Vector {
305 entry.source = SearchSource::Both;
306 }
307 }
308
309 for hit in vector_hits {
310 let entry = buckets.entry(hit.subject_id).or_insert_with(|| Bucket {
311 score: DeterministicScore::ZERO,
312 source: SearchSource::Vector,
313 title: None,
314 snippet: None,
315 });
316 if hit.score > entry.score {
317 entry.score = hit.score;
318 }
319 if entry.source == SearchSource::Text {
320 entry.source = SearchSource::Both;
321 }
322 }
323
324 let mut hits: Vec<SearchHit> = buckets
325 .into_iter()
326 .map(|(id, b)| SearchHit {
327 entity_id: id,
328 score: b.score,
329 source: b.source,
330 title: b.title,
331 snippet: b.snippet,
332 })
333 .collect();
334
335 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
336 hits.truncate(limit);
337 hits
338}
339
340fn vector_only(vector_hits: Vec<VectorSearchHit>, limit: usize) -> Vec<SearchHit> {
341 let mut hits: Vec<SearchHit> = vector_hits
342 .into_iter()
343 .map(|h| SearchHit {
344 entity_id: h.subject_id,
345 score: h.score,
346 source: SearchSource::Vector,
347 title: None,
348 snippet: None,
349 })
350 .collect();
351 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
352 hits.truncate(limit);
353 hits
354}
355
#[cfg(test)]
mod tests {
    use super::*;
    use khive_storage::types::{TextSearchHit, VectorSearchHit};

    /// Builds a text hit for `id` with the given raw score and title; rank is fixed at 1
    /// and the snippet at `"..."`.
    fn text_hit(id: Uuid, score: f64, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Builds a vector hit for `id` with the given raw score; rank is fixed at 1.
    fn vector_hit(id: Uuid, score: f64) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(score),
            rank: 1,
        }
    }

    // A smaller RRF `k` gives the top-ranked hit a larger score without changing order.
    #[test]
    fn rrf_custom_k_differs_from_k60() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let hits_k1 = fuse_with_strategy(text.clone(), vec![], &FusionStrategy::Rrf { k: 1 }, 10);
        let hits_k60 = fuse_with_strategy(text, vec![], &FusionStrategy::Rrf { k: 60 }, 10);
        assert_eq!(hits_k1[0].entity_id, a);
        assert_eq!(hits_k60[0].entity_id, a);
        assert!(hits_k1[0].score > hits_k60[0].score);
    }

    // Text and vector lists disagree on order; the more heavily weighted source wins.
    #[test]
    fn weighted_ordering_depends_on_weights() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let heavy_text = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let heavy_vec = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.3, 0.7],
            },
            10,
        );

        assert_eq!(heavy_text[0].entity_id, a);
        assert_eq!(heavy_vec[0].entity_id, b);
    }

    // Weights are rescaled to sum to 1, so proportional weight vectors must produce the
    // same ranking and (within float tolerance) the same scores.
    #[test]
    fn weighted_scale_invariant() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(b, 0.9), vector_hit(a, 0.1)];

        let w1 = fuse_with_strategy(
            text.clone(),
            vec_hits.clone(),
            &FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            10,
        );
        let w2 = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![7.0, 3.0],
            },
            10,
        );

        assert_eq!(w1[0].entity_id, w2[0].entity_id);
        assert_eq!(w1[1].entity_id, w2[1].entity_id);
        let diff = (w1[0].score.to_f64() - w2[0].score.to_f64()).abs();
        assert!(diff < 1e-9, "scores differ by {diff}");
    }

    // All-zero weights fall back to an equal split; both sources rank `a` first.
    #[test]
    fn weighted_zero_weights_equal_fallback() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a"), text_hit(b, 0.1, "b")];
        let vec_hits = vec![vector_hit(a, 0.9), vector_hit(b, 0.1)];

        let hits = fuse_with_strategy(
            text,
            vec_hits,
            &FusionStrategy::Weighted {
                weights: vec![0.0, 0.0],
            },
            10,
        );
        assert_eq!(hits[0].entity_id, a);
    }

    // A negative vector weight is clamped to 0 rather than dropping or reordering hits.
    #[test]
    fn weighted_negative_weight_clamped() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.9, "a")];
        let hits = fuse_with_strategy(
            text,
            vec![],
            &FusionStrategy::Weighted {
                weights: vec![1.0, -0.5],
            },
            10,
        );
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
    }

    // An entity present in both lists keeps its higher raw score and is labelled Both.
    #[test]
    fn union_max_score_per_entity() {
        let a = Uuid::new_v4();
        let text = vec![text_hit(a, 0.3, "a")];
        let vec_hits = vec![vector_hit(a, 0.9)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::Union, 10);
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score.to_f64() - 0.9).abs() < 1e-6);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    // VectorOnly ignores text hits entirely, so no title is carried over.
    #[test]
    fn vector_only_drops_text() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(b, 0.9, "b")];
        let vec_hits = vec![vector_hit(a, 0.8)];

        let hits = fuse_with_strategy(text, vec_hits, &FusionStrategy::VectorOnly, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    // Pins the default strategy.
    #[test]
    fn default_strategy_is_rrf_k60() {
        assert_eq!(FusionStrategy::default(), FusionStrategy::Rrf { k: 60 });
    }

    // Every variant survives a serde_json round trip unchanged.
    #[test]
    fn serde_roundtrip() {
        let cases = vec![
            FusionStrategy::Rrf { k: 60 },
            FusionStrategy::Rrf { k: 20 },
            FusionStrategy::Weighted {
                weights: vec![0.7, 0.3],
            },
            FusionStrategy::Union,
            FusionStrategy::VectorOnly,
        ];
        for strategy in cases {
            let json = serde_json::to_string(&strategy).expect("serialize");
            let back: FusionStrategy = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(strategy, back, "roundtrip failed for {json}");
        }
    }
}