1use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::{RuntimeError, RuntimeResult};
10use crate::runtime::KhiveRuntime;
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14 VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19#[derive(Clone, Debug)]
21pub struct SearchHit {
22 pub entity_id: Uuid,
23 pub score: DeterministicScore,
24 pub source: SearchSource,
25 pub title: Option<String>,
26 pub snippet: Option<String>,
27}
28
/// Which retrieval path(s) surfaced a search hit.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SearchSource {
    /// Present only in the vector (embedding) result list.
    Vector,
    /// Present only in the full-text result list.
    Text,
    /// Present in both result lists.
    Both,
}
36
/// The `k` constant handed to `rrf_score` for Reciprocal Rank Fusion.
/// NOTE(review): presumably the usual `1 / (k + rank)` damping term — confirm in khive_score.
const RRF_K: usize = 60;

/// Over-fetch factor: `hybrid_search` asks each retriever for `limit * CANDIDATE_MULTIPLIER`
/// candidates so that enough survive fusion and the entity liveness/kind filter.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44 pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49 let service = self.embedder().await?;
50 let model = self
51 .config()
52 .embedding_model
53 .expect("embedder() returns Unconfigured when model is None");
54 Ok(service.embed_one(text, model).await?)
55 }
56
57 pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65 if texts.is_empty() {
66 return Ok(vec![]);
67 }
68 let service = self.embedder().await?;
69 let model = self
70 .config()
71 .embedding_model
72 .expect("embedder() returns Unconfigured when model is None");
73 Ok(service.embed(texts, model).await?)
74 }
75
76 pub async fn vector_search(
82 &self,
83 namespace: Option<&str>,
84 query_embedding: Option<Vec<f32>>,
85 query_text: Option<&str>,
86 top_k: u32,
87 kind: Option<SubstrateKind>,
88 ) -> RuntimeResult<Vec<VectorSearchHit>> {
89 let embedding = match query_embedding {
90 Some(vec) => vec,
91 None => {
92 let text = query_text.ok_or_else(|| {
93 RuntimeError::InvalidInput(
94 "vector search requires query_embedding or query_text".into(),
95 )
96 })?;
97 if text.trim().is_empty() {
98 return Err(RuntimeError::InvalidInput(
99 "query_text must not be empty".into(),
100 ));
101 }
102 self.embed(text).await?
103 }
104 };
105
106 let ns = self.ns(namespace).to_string();
107 Ok(self
108 .vectors(namespace)?
109 .search(VectorSearchRequest {
110 query_embedding: embedding,
111 top_k,
112 namespace: Some(ns),
113 kind,
114 })
115 .await?)
116 }
117
118 pub async fn hybrid_search(
133 &self,
134 namespace: Option<&str>,
135 query_text: &str,
136 query_vector: Option<Vec<f32>>,
137 limit: u32,
138 entity_kind: Option<&str>,
139 ) -> RuntimeResult<Vec<SearchHit>> {
140 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
141
142 let ns = self.ns(namespace).to_string();
143 let text_hits = self
144 .text(namespace)?
145 .search(TextSearchRequest {
146 query: query_text.to_string(),
147 mode: TextQueryMode::Plain,
148 filter: Some(TextFilter {
149 namespaces: vec![ns.clone()],
150 ..TextFilter::default()
151 }),
152 top_k: candidates,
153 snippet_chars: 200,
154 })
155 .await?;
156
157 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
158 self.vector_search(
159 namespace,
160 query_vector,
161 Some(query_text),
162 candidates,
163 Some(SubstrateKind::Entity),
164 )
165 .await?
166 } else {
167 Vec::new()
168 };
169
170 let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize);
173
174 if !fused.is_empty() {
178 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
179 let alive_page = self
180 .entities(namespace)?
181 .query_entities(
182 self.ns(namespace),
183 EntityFilter {
184 ids: candidate_ids,
185 kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
186 ..EntityFilter::default()
187 },
188 PageRequest {
189 offset: 0,
190 limit: fused.len() as u32,
191 },
192 )
193 .await?;
194 let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
196 let mut alive: HashSet<Uuid> = HashSet::new();
197 for e in alive_page.items {
198 alive.insert(e.id);
199 entity_meta.insert(e.id, (e.name, e.description));
200 }
201
202 fused.retain(|h| alive.contains(&h.entity_id));
203
204 for hit in &mut fused {
206 if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
207 if hit.title.is_none() {
208 hit.title = Some(name.clone());
209 }
210 if hit.snippet.is_none() {
211 hit.snippet = description.clone();
212 }
213 }
214 }
215 }
216
217 fused.truncate(limit as usize);
218 Ok(fused)
219 }
220
221 pub async fn knn(
227 &self,
228 namespace: Option<&str>,
229 query_vector: Vec<f32>,
230 top_k: u32,
231 ) -> RuntimeResult<Vec<VectorSearchHit>> {
232 let ns = self.ns(namespace).to_string();
233 Ok(self
234 .vectors(namespace)?
235 .search(VectorSearchRequest {
236 query_embedding: query_vector,
237 top_k,
238 namespace: Some(ns),
239 kind: Some(SubstrateKind::Entity),
240 })
241 .await?)
242 }
243
244 pub async fn rerank(
250 &self,
251 namespace: Option<&str>,
252 query_vector: &[f32],
253 candidate_ids: &[Uuid],
254 top_k: u32,
255 ) -> RuntimeResult<Vec<VectorSearchHit>> {
256 let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
257 let ns = self.ns(namespace).to_string();
258 let all_hits = self
259 .vectors(namespace)?
260 .search(VectorSearchRequest {
261 query_embedding: query_vector.to_vec(),
262 top_k: candidate_ids.len() as u32,
263 namespace: Some(ns),
264 kind: Some(SubstrateKind::Entity),
265 })
266 .await?;
267 let mut hits: Vec<VectorSearchHit> = all_hits
268 .into_iter()
269 .filter(|h| candidate_set.contains(&h.subject_id))
270 .collect();
271 hits.sort_by(|a, b| b.score.cmp(&a.score));
272 hits.truncate(top_k as usize);
273 Ok(hits)
274 }
275}
276
277fn rrf_fuse(
281 text_hits: Vec<TextSearchHit>,
282 vector_hits: Vec<VectorSearchHit>,
283 limit: usize,
284) -> Vec<SearchHit> {
285 #[derive(Default)]
286 struct Bucket {
287 score: DeterministicScore,
288 source: Option<SearchSource>,
289 title: Option<String>,
290 snippet: Option<String>,
291 }
292
293 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
294
295 for (i, hit) in text_hits.into_iter().enumerate() {
296 let rank = i + 1; let entry = buckets.entry(hit.subject_id).or_default();
298 entry.score = entry.score + rrf_score(rank, RRF_K);
299 entry.source = Some(match entry.source {
300 Some(SearchSource::Vector) => SearchSource::Both,
301 _ => SearchSource::Text,
302 });
303 if entry.title.is_none() {
304 entry.title = hit.title;
305 }
306 if entry.snippet.is_none() {
307 entry.snippet = hit.snippet;
308 }
309 }
310
311 for (i, hit) in vector_hits.into_iter().enumerate() {
312 let rank = i + 1;
313 let entry = buckets.entry(hit.subject_id).or_default();
314 entry.score = entry.score + rrf_score(rank, RRF_K);
315 entry.source = Some(match entry.source {
316 Some(SearchSource::Text) => SearchSource::Both,
317 _ => SearchSource::Vector,
318 });
319 }
320
321 let mut hits: Vec<SearchHit> = buckets
322 .into_iter()
323 .map(|(id, b)| SearchHit {
324 entity_id: id,
325 score: b.score,
326 source: b.source.expect("each bucket gets a source"),
327 title: b.title,
328 snippet: b.snippet,
329 })
330 .collect();
331
332 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
333 hits.truncate(limit);
334 hits
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use lattice_embed::EmbeddingModel;

    /// Drives `fut` to completion on a freshly created Tokio runtime (one per call).
    fn run_blocking<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    /// Text hit for `subject` with a fixed score of 1.0 and a placeholder snippet.
    fn text_hit(subject: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: subject,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_owned()),
            snippet: Some(String::from("...")),
        }
    }

    /// Vector hit for `subject` with a fixed score of 0.9.
    fn vector_hit(subject: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: subject,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        let fused = rrf_fuse(
            vec![text_hit(first, 1, "A"), text_hit(second, 2, "B")],
            Vec::new(),
            10,
        );
        assert_eq!(fused.len(), 2);
        let top = &fused[0];
        assert_eq!(top.entity_id, first);
        assert_eq!(top.source, SearchSource::Text);
        assert_eq!(top.title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let only = Uuid::new_v4();
        let fused = rrf_fuse(Vec::new(), vec![vector_hit(only, 1)], 10);
        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].source, SearchSource::Vector);
        assert!(fused[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let shared = Uuid::new_v4();
        let fused = rrf_fuse(
            vec![text_hit(shared, 1, "A")],
            vec![vector_hit(shared, 1)],
            10,
        );
        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let many: Vec<TextSearchHit> = (1..=20)
            .map(|rank| text_hit(Uuid::new_v4(), rank, "x"))
            .collect();
        assert_eq!(rrf_fuse(many, Vec::new(), 5).len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `in_both` collects RRF credit from both lists, so it must outrank `vector_only`.
        let in_both = Uuid::new_v4();
        let vector_only = Uuid::new_v4();
        let fused = rrf_fuse(
            vec![text_hit(in_both, 1, "A")],
            vec![vector_hit(in_both, 1), vector_hit(vector_only, 2)],
            10,
        );
        assert_eq!(fused[0].entity_id, in_both);
        assert_eq!(fused[0].source, SearchSource::Both);
        assert!(fused[0].score > fused[1].score);
    }

    /// An empty batch returns before the embedder is consulted, so no model is required.
    #[test]
    fn embed_batch_unconfigured_on_memory_runtime() {
        let rt = KhiveRuntime::memory().unwrap();
        let result = run_blocking(rt.embed_batch(&[]));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        let rt = KhiveRuntime::memory().unwrap();
        let result = run_blocking(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        match run_blocking(rt.embed_batch(&texts)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let rt = KhiveRuntime::new(RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        })
        .unwrap();
        let texts: Vec<String> = ["foo", "bar", "baz"].iter().map(|s| s.to_string()).collect();
        let embeddings = run_blocking(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let outcome = run_blocking(rt.vector_search(
            None,
            None,
            None,
            10,
            Some(SubstrateKind::Entity),
        ));
        match outcome {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let outcome = run_blocking(rt.vector_search(
            None,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match outcome {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let rt = KhiveRuntime::new(RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        })
        .unwrap();
        let texts = vec!["hello world".to_string()];
        let embeddings = run_blocking(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        rt.create_entity(
            None,
            "concept",
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(None, "FlashAttention", None, 10, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let title = hits[0].title.as_deref();
        assert!(title.is_some(), "title must be populated");
        assert!(
            title.unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}