1use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::{RuntimeError, RuntimeResult};
10use crate::runtime::{KhiveRuntime, NamespaceToken};
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14 VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
/// One ranked result produced by `KhiveRuntime::hybrid_search`.
#[derive(Clone, Debug)]
pub struct SearchHit {
    /// Id of the matched entity.
    pub entity_id: Uuid,
    /// Reciprocal Rank Fusion score summed over the text and vector lists;
    /// results are ordered by this value, highest first.
    pub score: DeterministicScore,
    /// Which retrieval path(s) returned this entity.
    pub source: SearchSource,
    /// Title from the text index when it supplied one; otherwise `hybrid_search`
    /// fills in the entity's name.
    pub title: Option<String>,
    /// Snippet from the text index when it supplied one; otherwise `hybrid_search`
    /// fills in the entity's description (which may itself be `None`).
    pub snippet: Option<String>,
}
28
/// Records which retrieval path(s) contributed a [`SearchHit`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSource {
    /// Returned only by the vector (embedding) search.
    Vector,
    /// Returned only by the full-text search.
    Text,
    /// Returned by both the text and the vector search.
    Both,
}
36
/// Smoothing constant `k` passed to `rrf_score` for Reciprocal Rank Fusion.
/// A larger `k` flattens the gap between top and lower ranks. NOTE(review):
/// 60 is the value used in the original RRF paper; presumably chosen for that reason.
const RRF_K: usize = 60;

/// Over-fetch factor: each backend is asked for `limit * CANDIDATE_MULTIPLIER`
/// hits, so results dropped after fusion (missing entities, kind/type filters)
/// still leave enough candidates to fill `limit`.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44 pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49 let service = self.embedder().await?;
50 let model = self
51 .config()
52 .embedding_model
53 .expect("embedder() returns Unconfigured when model is None");
54 Ok(service.embed_one(text, model).await?)
55 }
56
57 pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65 if texts.is_empty() {
66 return Ok(vec![]);
67 }
68 let service = self.embedder().await?;
69 let model = self
70 .config()
71 .embedding_model
72 .expect("embedder() returns Unconfigured when model is None");
73 Ok(service.embed(texts, model).await?)
74 }
75
76 pub async fn vector_search(
82 &self,
83 token: &NamespaceToken,
84 query_embedding: Option<Vec<f32>>,
85 query_text: Option<&str>,
86 top_k: u32,
87 kind: Option<SubstrateKind>,
88 ) -> RuntimeResult<Vec<VectorSearchHit>> {
89 let embedding = match query_embedding {
90 Some(vec) => vec,
91 None => {
92 let text = query_text.ok_or_else(|| {
93 RuntimeError::InvalidInput(
94 "vector search requires query_embedding or query_text".into(),
95 )
96 })?;
97 if text.trim().is_empty() {
98 return Err(RuntimeError::InvalidInput(
99 "query_text must not be empty".into(),
100 ));
101 }
102 self.embed(text).await?
103 }
104 };
105
106 let ns = token.namespace().as_str().to_owned();
107 Ok(self
108 .vectors(token)?
109 .search(VectorSearchRequest {
110 query_vectors: vec![embedding],
111 top_k,
112 namespace: Some(ns),
113 kind,
114 filter: None,
115 backend_hints: None,
116 })
117 .await?)
118 }
119
120 #[allow(clippy::too_many_arguments)]
135 pub async fn hybrid_search(
136 &self,
137 token: &NamespaceToken,
138 query_text: &str,
139 query_vector: Option<Vec<f32>>,
140 limit: u32,
141 entity_kind: Option<&str>,
142 entity_type: Option<&str>,
143 ) -> RuntimeResult<Vec<SearchHit>> {
144 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
145
146 let ns = token.namespace().as_str().to_owned();
147 let text_hits = self
148 .text(token)?
149 .search(TextSearchRequest {
150 query: query_text.to_string(),
151 mode: TextQueryMode::Plain,
152 filter: Some(TextFilter {
153 namespaces: vec![ns.clone()],
154 ..TextFilter::default()
155 }),
156 top_k: candidates,
157 snippet_chars: 200,
158 })
159 .await?;
160
161 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
162 self.vector_search(
163 token,
164 query_vector,
165 Some(query_text),
166 candidates,
167 Some(SubstrateKind::Entity),
168 )
169 .await?
170 } else {
171 Vec::new()
172 };
173
174 let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize);
177
178 if !fused.is_empty() {
182 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
183 let alive_page = self
184 .entities(token)?
185 .query_entities(
186 token.namespace().as_str(),
187 EntityFilter {
188 ids: candidate_ids,
189 kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
190 entity_types: entity_type.map(|t| vec![t.to_string()]).unwrap_or_default(),
191 ..EntityFilter::default()
192 },
193 PageRequest {
194 offset: 0,
195 limit: fused.len() as u32,
196 },
197 )
198 .await?;
199 let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
201 let mut alive: HashSet<Uuid> = HashSet::new();
202 for e in alive_page.items {
203 alive.insert(e.id);
204 entity_meta.insert(e.id, (e.name, e.description));
205 }
206
207 fused.retain(|h| alive.contains(&h.entity_id));
208
209 for hit in &mut fused {
211 if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
212 if hit.title.is_none() {
213 hit.title = Some(name.clone());
214 }
215 if hit.snippet.is_none() {
216 hit.snippet = description.clone();
217 }
218 }
219 }
220 }
221
222 fused.truncate(limit as usize);
223 Ok(fused)
224 }
225
226 pub async fn knn(
232 &self,
233 token: &NamespaceToken,
234 query_vector: Vec<f32>,
235 top_k: u32,
236 ) -> RuntimeResult<Vec<VectorSearchHit>> {
237 let ns = token.namespace().as_str().to_owned();
238 Ok(self
239 .vectors(token)?
240 .search(VectorSearchRequest {
241 query_vectors: vec![query_vector],
242 top_k,
243 namespace: Some(ns),
244 kind: Some(SubstrateKind::Entity),
245 filter: None,
246 backend_hints: None,
247 })
248 .await?)
249 }
250
251 pub async fn rerank(
257 &self,
258 token: &NamespaceToken,
259 query_vector: &[f32],
260 candidate_ids: &[Uuid],
261 top_k: u32,
262 ) -> RuntimeResult<Vec<VectorSearchHit>> {
263 let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
264 let ns = token.namespace().as_str().to_owned();
265 let all_hits = self
266 .vectors(token)?
267 .search(VectorSearchRequest {
268 query_vectors: vec![query_vector.to_vec()],
269 top_k: candidate_ids.len() as u32,
270 namespace: Some(ns),
271 kind: Some(SubstrateKind::Entity),
272 filter: None,
273 backend_hints: None,
274 })
275 .await?;
276 let mut hits: Vec<VectorSearchHit> = all_hits
277 .into_iter()
278 .filter(|h| candidate_set.contains(&h.subject_id))
279 .collect();
280 hits.sort_by(|a, b| b.score.cmp(&a.score));
281 hits.truncate(top_k as usize);
282 Ok(hits)
283 }
284}
285
286fn rrf_fuse(
290 text_hits: Vec<TextSearchHit>,
291 vector_hits: Vec<VectorSearchHit>,
292 limit: usize,
293) -> Vec<SearchHit> {
294 #[derive(Default)]
295 struct Bucket {
296 score: DeterministicScore,
297 source: Option<SearchSource>,
298 title: Option<String>,
299 snippet: Option<String>,
300 }
301
302 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
303
304 for (i, hit) in text_hits.into_iter().enumerate() {
305 let rank = i + 1; let entry = buckets.entry(hit.subject_id).or_default();
307 entry.score = entry.score + rrf_score(rank, RRF_K);
308 entry.source = Some(match entry.source {
309 Some(SearchSource::Vector) => SearchSource::Both,
310 _ => SearchSource::Text,
311 });
312 if entry.title.is_none() {
313 entry.title = hit.title;
314 }
315 if entry.snippet.is_none() {
316 entry.snippet = hit.snippet;
317 }
318 }
319
320 for (i, hit) in vector_hits.into_iter().enumerate() {
321 let rank = i + 1;
322 let entry = buckets.entry(hit.subject_id).or_default();
323 entry.score = entry.score + rrf_score(rank, RRF_K);
324 entry.source = Some(match entry.source {
325 Some(SearchSource::Text) => SearchSource::Both,
326 _ => SearchSource::Vector,
327 });
328 }
329
330 let mut hits: Vec<SearchHit> = buckets
331 .into_iter()
332 .map(|(id, b)| SearchHit {
333 entity_id: id,
334 score: b.score,
335 source: b.source.expect("each bucket gets a source"),
336 title: b.title,
337 snippet: b.snippet,
338 })
339 .collect();
340
341 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
342 hits.truncate(limit);
343 hits
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use khive_types::namespace::Namespace;
    use lattice_embed::EmbeddingModel;

    /// Drives `fut` to completion on a freshly built Tokio runtime, dropped afterwards.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    /// Runtime config with no `db_path`, namespace "test", the `kg` pack, and `model`.
    fn config_with_model(model: EmbeddingModel) -> RuntimeConfig {
        RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        }
    }

    /// Text-index hit fixture with a fixed score and snippet.
    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_owned()),
            snippet: Some("...".to_owned()),
        }
    }

    /// Vector-index hit fixture with a fixed score.
    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        let fused = rrf_fuse(
            vec![text_hit(first, 1, "A"), text_hit(second, 2, "B")],
            vec![],
            10,
        );
        assert_eq!(fused.len(), 2);
        let top = &fused[0];
        assert_eq!(top.entity_id, first);
        assert_eq!(top.source, SearchSource::Text);
        assert_eq!(top.title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let only = Uuid::new_v4();
        let fused = rrf_fuse(vec![], vec![vector_hit(only, 1)], 10);
        assert_eq!(fused.len(), 1);
        let top = &fused[0];
        assert_eq!(top.source, SearchSource::Vector);
        assert!(top.title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let shared = Uuid::new_v4();
        let fused = rrf_fuse(vec![text_hit(shared, 1, "A")], vec![vector_hit(shared, 1)], 10);
        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let many: Vec<TextSearchHit> = (1..=20)
            .map(|rank| text_hit(Uuid::new_v4(), rank, "x"))
            .collect();
        assert_eq!(rrf_fuse(many, vec![], 5).len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        let shared = Uuid::new_v4();
        let vector_only = Uuid::new_v4();
        let fused = rrf_fuse(
            vec![text_hit(shared, 1, "A")],
            vec![vector_hit(shared, 1), vector_hit(vector_only, 2)],
            10,
        );
        assert_eq!(fused[0].entity_id, shared);
        assert_eq!(fused[0].source, SearchSource::Both);
        assert!(fused[0].score > fused[1].score);
    }

    #[test]
    fn embed_batch_unconfigured_on_memory_runtime() {
        // The memory runtime has no embedding model, yet an empty batch still
        // succeeds because `embed_batch` returns before asking for the embedder.
        let rt = KhiveRuntime::memory().unwrap();
        let outcome = block_on(rt.embed_batch(&[]));
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_empty());
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        let rt = KhiveRuntime::memory().unwrap();
        let outcome = block_on(rt.embed_batch(&[]));
        assert_eq!(outcome.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let inputs = vec!["hello".to_string()];
        match block_on(rt.embed_batch(&inputs)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let rt = KhiveRuntime::new(config_with_model(EmbeddingModel::AllMiniLmL6V2)).unwrap();
        let inputs: Vec<String> = ["foo", "bar", "baz"].iter().map(|s| s.to_string()).collect();
        let embeddings = block_on(rt.embed_batch(&inputs)).unwrap();
        assert_eq!(embeddings.len(), inputs.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        match block_on(rt.vector_search(&tok, None, None, 10, Some(SubstrateKind::Entity))) {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let outcome = block_on(rt.vector_search(
            &tok,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match outcome {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let rt = KhiveRuntime::new(config_with_model(model)).unwrap();
        let inputs = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&inputs)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        rt.create_entity(
            &tok,
            "concept",
            None,
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(&tok, "FlashAttention", None, 10, None, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let title = hits[0].title.as_deref();
        assert!(title.is_some(), "title must be populated");
        assert!(
            title.unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}