1use std::collections::{HashMap, HashSet};
4
5use uuid::Uuid;
6
7use crate::config::{parse_embedding_model_alias, sanitize_key};
8use crate::curation::note_fts_document;
9use crate::error::{RuntimeError, RuntimeResult};
10use crate::runtime::{KhiveRuntime, NamespaceToken};
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14 VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19#[derive(Clone, Debug)]
21pub struct SearchHit {
22 pub entity_id: Uuid,
23 pub score: DeterministicScore,
24 pub source: SearchSource,
25 pub title: Option<String>,
26 pub snippet: Option<String>,
27}
28
/// Which retriever contributed a [`SearchHit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSource {
    /// Only the vector (embedding) search returned the subject.
    Vector,
    /// Only the full-text search returned the subject.
    Text,
    /// Both the full-text and the vector search returned the subject.
    Both,
}
36
/// Smoothing constant `k` handed to `rrf_score` when fusing ranked lists.
///
/// Kept small so that neighbouring ranks stay clearly separated after fusion
/// (presumably `rrf_score` is the usual `1 / (k + rank)` — confirm in `khive_score`).
const RRF_K: usize = 10;

/// Over-fetch factor applied to the caller's `limit` before fusion, so that
/// post-fusion filtering still has enough candidates left to fill `limit`.
const CANDIDATE_MULTIPLIER: u32 = 4;
48
49impl KhiveRuntime {
50 pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
55 let model_name = self.default_embedder_name();
56 if model_name.is_empty() {
57 return Err(RuntimeError::Unconfigured("embedding_model".into()));
58 }
59 self.embed_with_model(model_name, text).await
60 }
61
62 pub async fn embed_with_model(&self, model_name: &str, text: &str) -> RuntimeResult<Vec<f32>> {
78 let model = parse_embedding_model_alias(model_name);
82 let service = self.embedder(model_name).await?;
83 let emb_model = model.unwrap_or_default();
84 Ok(service.embed_one(text, emb_model).await?)
85 }
86
87 pub async fn embed_document_with_model(
106 &self,
107 model_name: &str,
108 text: &str,
109 ) -> RuntimeResult<Vec<f32>> {
110 let model = parse_embedding_model_alias(model_name);
111 let service = self.embedder(model_name).await?;
112 let emb_model = model.unwrap_or_default();
113 service
114 .embed_passage(&[text.to_string()], emb_model)
115 .await?
116 .into_iter()
117 .next()
118 .ok_or_else(|| RuntimeError::Internal("embed_passage returned empty vec".into()))
119 }
120
121 pub async fn embed_query_with_model(
134 &self,
135 model_name: &str,
136 text: &str,
137 ) -> RuntimeResult<Vec<f32>> {
138 let model = parse_embedding_model_alias(model_name);
139 let service = self.embedder(model_name).await?;
140 let emb_model = model.unwrap_or_default();
141 service
142 .embed_query(&[text.to_string()], emb_model)
143 .await?
144 .into_iter()
145 .next()
146 .ok_or_else(|| RuntimeError::Internal("embed_query returned empty vec".into()))
147 }
148
149 pub async fn embed_document(&self, text: &str) -> RuntimeResult<Vec<f32>> {
156 let model_name = self.default_embedder_name();
157 if model_name.is_empty() {
158 return Err(RuntimeError::Unconfigured("embedding_model".into()));
159 }
160 self.embed_document_with_model(model_name, text).await
161 }
162
163 pub async fn embed_query(&self, text: &str) -> RuntimeResult<Vec<f32>> {
170 let model_name = self.default_embedder_name();
171 if model_name.is_empty() {
172 return Err(RuntimeError::Unconfigured("embedding_model".into()));
173 }
174 self.embed_query_with_model(model_name, text).await
175 }
176
177 pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
185 if texts.is_empty() {
186 return Ok(vec![]);
187 }
188 let model_name = self.default_embedder_name();
189 if model_name.is_empty() {
190 return Err(RuntimeError::Unconfigured("embedding_model".into()));
191 }
192 self.embed_batch_with_model(model_name, texts).await
193 }
194
195 pub async fn embed_batch_with_model(
200 &self,
201 model_name: &str,
202 texts: &[String],
203 ) -> RuntimeResult<Vec<Vec<f32>>> {
204 if texts.is_empty() {
205 return Ok(vec![]);
206 }
207 let model = parse_embedding_model_alias(model_name);
208 let service = self.embedder(model_name).await?;
209 let emb_model = model.unwrap_or_default();
210 Ok(service.embed(texts, emb_model).await?)
211 }
212
213 pub async fn embed_document_batch_with_model(
223 &self,
224 model_name: &str,
225 texts: &[String],
226 ) -> RuntimeResult<Vec<Vec<f32>>> {
227 if texts.is_empty() {
228 return Ok(vec![]);
229 }
230 let model = parse_embedding_model_alias(model_name);
231 let service = self.embedder(model_name).await?;
232 let emb_model = model.unwrap_or_default();
233 Ok(service.embed_passage(texts, emb_model).await?)
234 }
235
236 pub async fn embed_document_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
243 if texts.is_empty() {
244 return Ok(vec![]);
245 }
246 let model_name = self.default_embedder_name();
247 if model_name.is_empty() {
248 return Err(RuntimeError::Unconfigured("embedding_model".into()));
249 }
250 self.embed_document_batch_with_model(model_name, texts)
251 .await
252 }
253
254 pub async fn embed_query_batch_with_model(
261 &self,
262 model_name: &str,
263 texts: &[String],
264 ) -> RuntimeResult<Vec<Vec<f32>>> {
265 if texts.is_empty() {
266 return Ok(vec![]);
267 }
268 let model = parse_embedding_model_alias(model_name);
269 let service = self.embedder(model_name).await?;
270 let emb_model = model.unwrap_or_default();
271 Ok(service.embed_query(texts, emb_model).await?)
272 }
273
274 pub async fn vector_search(
280 &self,
281 token: &NamespaceToken,
282 query_embedding: Option<Vec<f32>>,
283 query_text: Option<&str>,
284 top_k: u32,
285 kind: Option<SubstrateKind>,
286 ) -> RuntimeResult<Vec<VectorSearchHit>> {
287 let embedding = match query_embedding {
288 Some(vec) => vec,
289 None => {
290 let text = query_text.ok_or_else(|| {
291 RuntimeError::InvalidInput(
292 "vector search requires query_embedding or query_text".into(),
293 )
294 })?;
295 if text.trim().is_empty() {
296 return Err(RuntimeError::InvalidInput(
297 "query_text must not be empty".into(),
298 ));
299 }
300 self.embed_query(text).await?
301 }
302 };
303
304 let ns = token.namespace().as_str().to_owned();
305 Ok(self
306 .vectors(token)?
307 .search(VectorSearchRequest {
308 query_vectors: vec![embedding],
309 top_k,
310 namespace: Some(ns),
311 kind,
312 embedding_model: None,
313 filter: None,
314 backend_hints: None,
315 })
316 .await?)
317 }
318
319 #[allow(clippy::too_many_arguments)]
334 pub async fn hybrid_search(
335 &self,
336 token: &NamespaceToken,
337 query_text: &str,
338 query_vector: Option<Vec<f32>>,
339 limit: u32,
340 entity_kind: Option<&str>,
341 entity_type: Option<&str>,
342 ) -> RuntimeResult<Vec<SearchHit>> {
343 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
344
345 let ns = token.namespace().as_str().to_owned();
346 let text_hits = self
347 .text(token)?
348 .search(TextSearchRequest {
349 query: query_text.to_string(),
350 mode: TextQueryMode::Plain,
351 filter: Some(TextFilter {
352 namespaces: vec![ns.clone()],
353 ..TextFilter::default()
354 }),
355 top_k: candidates,
356 snippet_chars: 200,
357 })
358 .await?;
359
360 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
361 self.vector_search(
362 token,
363 query_vector,
364 Some(query_text),
365 candidates,
366 Some(SubstrateKind::Entity),
367 )
368 .await?
369 } else {
370 Vec::new()
371 };
372
373 let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize, query_text);
376
377 if !fused.is_empty() {
381 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
382 let alive_page = self
383 .entities(token)?
384 .query_entities(
385 token.namespace().as_str(),
386 EntityFilter {
387 ids: candidate_ids,
388 kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
389 entity_types: entity_type.map(|t| vec![t.to_string()]).unwrap_or_default(),
390 ..EntityFilter::default()
391 },
392 PageRequest {
393 offset: 0,
394 limit: fused.len() as u32,
395 },
396 )
397 .await?;
398 let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
400 let mut alive: HashSet<Uuid> = HashSet::new();
401 for e in alive_page.items {
402 alive.insert(e.id);
403 entity_meta.insert(e.id, (e.name, e.description));
404 }
405
406 fused.retain(|h| alive.contains(&h.entity_id));
407
408 for hit in &mut fused {
410 if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
411 if hit.title.is_none() {
412 hit.title = Some(name.clone());
413 }
414 if hit.snippet.is_none() {
415 hit.snippet = description.clone();
416 }
417 }
418 }
419 }
420
421 fused.truncate(limit as usize);
422 Ok(fused)
423 }
424
425 pub async fn knn(
431 &self,
432 token: &NamespaceToken,
433 query_vector: Vec<f32>,
434 top_k: u32,
435 ) -> RuntimeResult<Vec<VectorSearchHit>> {
436 let ns = token.namespace().as_str().to_owned();
437 Ok(self
438 .vectors(token)?
439 .search(VectorSearchRequest {
440 query_vectors: vec![query_vector],
441 top_k,
442 namespace: Some(ns),
443 kind: Some(SubstrateKind::Entity),
444 embedding_model: None,
445 filter: None,
446 backend_hints: None,
447 })
448 .await?)
449 }
450
451 pub async fn rerank(
457 &self,
458 token: &NamespaceToken,
459 query_vector: &[f32],
460 candidate_ids: &[Uuid],
461 top_k: u32,
462 ) -> RuntimeResult<Vec<VectorSearchHit>> {
463 let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
464 let ns = token.namespace().as_str().to_owned();
465 let all_hits = self
466 .vectors(token)?
467 .search(VectorSearchRequest {
468 query_vectors: vec![query_vector.to_vec()],
469 top_k: candidate_ids.len() as u32,
470 namespace: Some(ns),
471 kind: Some(SubstrateKind::Entity),
472 embedding_model: None,
473 filter: None,
474 backend_hints: None,
475 })
476 .await?;
477 let mut hits: Vec<VectorSearchHit> = all_hits
478 .into_iter()
479 .filter(|h| candidate_set.contains(&h.subject_id))
480 .collect();
481 hits.sort_by(|a, b| b.score.cmp(&a.score));
482 hits.truncate(top_k as usize);
483 Ok(hits)
484 }
485
486 pub async fn backfill_missing_embeddings(&self, token: &NamespaceToken) -> RuntimeResult<u64> {
499 use khive_storage::types::{SqlRow, SqlStatement, SqlValue};
500
501 let model_names = self.registered_embedding_model_names();
502 if model_names.is_empty() {
503 tracing::debug!(
504 "backfill_missing_embeddings: no embedding models registered, skipping"
505 );
506 return Ok(0);
507 }
508
509 let ns = token.namespace().as_str().to_string();
510 let mut total_backfilled = 0u64;
511
512 for model_name in &model_names {
513 let vec_table = format!("vec_{}", sanitize_key(model_name));
515
516 const PAGE_SIZE: usize = 500;
521 let mut entity_total = 0usize;
522 loop {
523 let entity_sql = SqlStatement {
524 sql: format!(
525 "SELECT id, name, description FROM entities \
526 WHERE namespace = ?1 AND deleted_at IS NULL \
527 AND id NOT IN (\
528 SELECT subject_id FROM {vec_table} \
529 WHERE namespace = ?1 AND embedding_model = ?2 \
530 ) LIMIT {PAGE_SIZE}"
531 ),
532 params: vec![
533 SqlValue::Text(ns.clone()),
534 SqlValue::Text(model_name.clone()),
535 ],
536 label: Some("backfill_entities".into()),
537 };
538
539 let entity_rows: Vec<SqlRow> = {
540 let sql = self.sql();
541 match sql.reader().await {
542 Ok(mut reader) => reader.query_all(entity_sql).await.unwrap_or_default(),
543 Err(_) => vec![],
544 }
545 };
546
547 let batch_len = entity_rows.len();
548 entity_total += batch_len;
549
550 for row in &entity_rows {
551 let id_str = row.columns.first().and_then(|c| {
552 if let SqlValue::Text(s) = &c.value {
553 Some(s.clone())
554 } else {
555 None
556 }
557 });
558 let description = row.columns.get(2).and_then(|c| {
559 if let SqlValue::Text(s) = &c.value {
560 Some(s.clone())
561 } else if let SqlValue::Null = &c.value {
562 None
563 } else {
564 None
565 }
566 });
567
568 let (Some(id_str), Some(desc)) = (id_str, description) else {
569 continue;
570 };
571 let Ok(id) = id_str.parse::<Uuid>() else {
572 continue;
573 };
574 if desc.trim().is_empty() {
575 continue;
576 }
577
578 match self.embed_document_with_model(model_name, &desc).await {
579 Ok(vector) => {
580 if let Ok(vs) = self.vectors_for_model(token, model_name) {
581 match vs
582 .insert(
583 id,
584 SubstrateKind::Entity,
585 &ns,
586 "entity.description",
587 vec![vector],
588 )
589 .await
590 {
591 Ok(()) => {
592 total_backfilled += 1;
593 }
594 Err(e) => {
595 tracing::warn!(
596 id = %id, model = %model_name,
597 error = %e,
598 "backfill_missing_embeddings: entity vector insert failed"
599 );
600 }
601 }
602 }
603 }
604 Err(e) => {
605 tracing::warn!(
606 id = %id, model = %model_name,
607 error = %e,
608 "backfill_missing_embeddings: entity embed failed"
609 );
610 }
611 }
612 }
613
614 if batch_len < PAGE_SIZE {
615 break;
616 }
617 }
618
619 let text_store = self.text_for_notes(token).ok();
621 let note_store = self.notes(token).ok();
622 let mut note_total = 0usize;
623 loop {
624 let note_sql = SqlStatement {
628 sql: format!(
629 "SELECT id FROM notes \
630 WHERE namespace = ?1 AND deleted_at IS NULL \
631 AND id NOT IN (\
632 SELECT subject_id FROM {vec_table} \
633 WHERE namespace = ?1 AND embedding_model = ?2 \
634 ) LIMIT {PAGE_SIZE}"
635 ),
636 params: vec![
637 SqlValue::Text(ns.clone()),
638 SqlValue::Text(model_name.clone()),
639 ],
640 label: Some("backfill_notes".into()),
641 };
642
643 let note_rows: Vec<SqlRow> = {
644 let sql = self.sql();
645 match sql.reader().await {
646 Ok(mut reader) => reader.query_all(note_sql).await.unwrap_or_default(),
647 Err(_) => vec![],
648 }
649 };
650
651 let batch_len = note_rows.len();
652 note_total += batch_len;
653
654 for row in ¬e_rows {
655 let id_str = row.columns.first().and_then(|c| {
656 if let SqlValue::Text(s) = &c.value {
657 Some(s.clone())
658 } else {
659 None
660 }
661 });
662
663 let Some(id_str) = id_str else {
664 continue;
665 };
666 let Ok(id) = id_str.parse::<Uuid>() else {
667 continue;
668 };
669
670 let note = match ¬e_store {
674 Some(store) => match store.get_note(id).await {
675 Ok(Some(n)) => n,
676 _ => continue,
677 },
678 None => continue,
679 };
680
681 if note.content.trim().is_empty() {
682 continue;
683 }
684
685 if model_names.first().map(|n| n.as_str()) == Some(model_name.as_str()) {
688 if let Some(ref ts) = text_store {
689 let _ = ts.upsert_document(note_fts_document(¬e)).await;
690 }
691 }
692
693 let content = note.content.clone();
694 match self.embed_document_with_model(model_name, &content).await {
695 Ok(vector) => {
696 if let Ok(vs) = self.vectors_for_model(token, model_name) {
697 match vs
698 .insert(
699 id,
700 SubstrateKind::Note,
701 &ns,
702 "note.content",
703 vec![vector],
704 )
705 .await
706 {
707 Ok(()) => {
708 total_backfilled += 1;
709 }
710 Err(e) => {
711 tracing::warn!(
712 id = %id, model = %model_name,
713 error = %e,
714 "backfill_missing_embeddings: note vector insert failed"
715 );
716 }
717 }
718 }
719 }
720 Err(e) => {
721 tracing::warn!(
722 id = %id, model = %model_name,
723 error = %e,
724 "backfill_missing_embeddings: note embed failed"
725 );
726 }
727 }
728 }
729
730 if batch_len < PAGE_SIZE {
731 break;
732 }
733 }
734
735 tracing::info!(
736 model = %model_name,
737 namespace = %ns,
738 entities = entity_total,
739 notes = note_total,
740 "backfill_missing_embeddings: model pass complete"
741 );
742 }
743
744 tracing::info!(
745 namespace = %ns,
746 total_backfilled = total_backfilled,
747 "backfill_missing_embeddings: finished"
748 );
749
750 Ok(total_backfilled)
751 }
752
753 pub async fn sweep_orphan_vectors(
768 &self,
769 token: &NamespaceToken,
770 max_delete_per_model: u32,
771 dry_run: bool,
772 ) -> RuntimeResult<u64> {
773 use khive_storage::types::OrphanSweepConfig;
774 use khive_storage::StorageError;
775
776 let model_names = self.registered_embedding_model_names();
777 if model_names.is_empty() {
778 tracing::debug!("sweep_orphan_vectors: no embedding models registered, skipping");
779 return Ok(0);
780 }
781
782 let ns = token.namespace().as_str().to_string();
783 let mut total_deleted = 0u64;
784
785 for model_name in &model_names {
786 let store = match self.vectors_for_model(token, model_name) {
787 Ok(s) => s,
788 Err(e) => {
789 tracing::warn!(
790 model = %model_name,
791 error = %e,
792 "sweep_orphan_vectors: failed to get vector store, skipping model"
793 );
794 continue;
795 }
796 };
797
798 let caps = store.capabilities();
799 if !caps.supports_orphan_sweep {
800 tracing::debug!(
801 model = %model_name,
802 "sweep_orphan_vectors: backend does not support orphan sweep, skipping"
803 );
804 continue;
805 }
806
807 let config = OrphanSweepConfig {
808 subject_id_allowlist: None,
809 namespaces: vec![ns.clone()],
810 substrate_kinds: vec![],
811 max_delete: max_delete_per_model,
812 dry_run,
813 };
814
815 match store.orphan_sweep(&config).await {
816 Ok(result) => {
817 tracing::info!(
818 model = %model_name,
819 namespace = %ns,
820 scanned = result.scanned,
821 deleted = result.deleted,
822 would_delete = result.would_delete,
823 dry_run = dry_run,
824 "sweep_orphan_vectors: sweep complete"
825 );
826 total_deleted += result.deleted;
827 }
828 Err(StorageError::Unsupported { .. }) => {
829 tracing::debug!(
830 model = %model_name,
831 "sweep_orphan_vectors: backend returned Unsupported, skipping"
832 );
833 }
834 Err(e) => {
835 tracing::warn!(
836 model = %model_name,
837 error = %e,
838 "sweep_orphan_vectors: sweep failed, continuing with other models"
839 );
840 }
841 }
842 }
843
844 tracing::info!(
845 namespace = %ns,
846 total_deleted = total_deleted,
847 dry_run = dry_run,
848 "sweep_orphan_vectors: finished"
849 );
850
851 Ok(total_deleted)
852 }
853}
854
/// Score bonus added by `rrf_fuse` when a text hit's title equals the query,
/// compared case-insensitively. It is meant to lift an exact title match above
/// hits that merely rank higher in the raw lists.
const EXACT_MATCH_BOOST: f64 = 0.5;
859
860fn rrf_fuse(
868 text_hits: Vec<TextSearchHit>,
869 vector_hits: Vec<VectorSearchHit>,
870 limit: usize,
871 query_text: &str,
872) -> Vec<SearchHit> {
873 #[derive(Default)]
874 struct Bucket {
875 score: DeterministicScore,
876 source: Option<SearchSource>,
877 title: Option<String>,
878 snippet: Option<String>,
879 }
880
881 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
882
883 let query_lower = query_text.to_lowercase();
884 for (i, hit) in text_hits.into_iter().enumerate() {
885 let rank = i + 1; let entry = buckets.entry(hit.subject_id).or_default();
887 entry.score = entry.score + rrf_score(rank, RRF_K);
888 entry.source = Some(match entry.source {
889 Some(SearchSource::Vector) => SearchSource::Both,
890 _ => SearchSource::Text,
891 });
892 if entry.title.is_none() {
893 if let Some(ref title) = hit.title {
895 if title.to_lowercase() == query_lower {
896 entry.score = entry.score + DeterministicScore::from_f64(EXACT_MATCH_BOOST);
897 }
898 }
899 entry.title = hit.title;
900 }
901 if entry.snippet.is_none() {
902 entry.snippet = hit.snippet;
903 }
904 }
905
906 for (i, hit) in vector_hits.into_iter().enumerate() {
907 let rank = i + 1;
908 let entry = buckets.entry(hit.subject_id).or_default();
909 entry.score = entry.score + rrf_score(rank, RRF_K);
910 entry.source = Some(match entry.source {
911 Some(SearchSource::Text) => SearchSource::Both,
912 _ => SearchSource::Vector,
913 });
914 }
915
916 let mut hits: Vec<SearchHit> = buckets
917 .into_iter()
918 .map(|(id, b)| SearchHit {
919 entity_id: id,
920 score: b.score,
921 source: b.source.expect("each bucket gets a source"),
922 title: b.title,
923 snippet: b.snippet,
924 })
925 .collect();
926
927 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
928 hits.truncate(limit);
929 hits
930}
931
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use khive_types::namespace::Namespace;
    use lattice_embed::EmbeddingModel;

    /// Drives `fut` to completion on a fresh multi-threaded Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    /// Builds a text hit with a fixed raw score and a placeholder snippet.
    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Builds a vector hit with a fixed raw score.
    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5, "query");
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn rrf_fuse_k10_score_spread_exceeds_threshold() {
        let ids: Vec<Uuid> = (0..10).map(|_| Uuid::new_v4()).collect();
        let text: Vec<TextSearchHit> = ids
            .iter()
            .enumerate()
            .map(|(i, &id)| text_hit(id, (i + 1) as u32, "x"))
            .collect();
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 10);
        let spread = hits[0].score.to_f64() - hits[9].score.to_f64();
        assert!(
            spread >= 0.03,
            "score spread between rank 1 and rank 10 must be ≥ 0.03 (was {spread:.4})"
        );
    }

    #[test]
    fn rrf_fuse_exact_match_boost_elevates_score() {
        let exact_id = Uuid::new_v4();
        let other_id = Uuid::new_v4();
        let text = vec![
            text_hit(other_id, 1, "something else"),
            text_hit(exact_id, 2, "FlashAttention"),
        ];
        let hits = rrf_fuse(text, vec![], 10, "flashattention");
        assert_eq!(hits.len(), 2);
        assert_eq!(
            hits[0].entity_id, exact_id,
            "exact match must rank first despite being rank-2 in raw text search"
        );
    }

    /// Equal fused scores are ordered by ascending entity id, so output is deterministic.
    #[test]
    fn rrf_fuse_breaks_score_ties_by_entity_id() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        // Rank 1 in each list gives both subjects the same fused score.
        let hits = rrf_fuse(vec![text_hit(a, 1, "A")], vec![vector_hit(b, 1)], 10, "query");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].score, hits[1].score);
        let (lo, hi) = if a < b { (a, b) } else { (b, a) };
        assert_eq!(hits[0].entity_id, lo);
        assert_eq!(hits[1].entity_id, hi);
    }

    /// Empty input short-circuits before the model check, so every batch entry
    /// point succeeds on a runtime with no embedding model configured.
    #[test]
    fn embed_batch_unconfigured_on_memory_runtime() {
        let rt = KhiveRuntime::memory().unwrap();
        let plain = block_on(rt.embed_batch(&[]));
        assert!(plain.unwrap().is_empty());
        let documents = block_on(rt.embed_document_batch(&[]));
        assert!(documents.unwrap().is_empty());
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        let result = block_on(rt.embed_batch(&texts));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    /// Single-text entry points report the missing default model the same way.
    #[test]
    fn single_text_embeds_without_model_return_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let results = [
            block_on(rt.embed("hello")),
            block_on(rt.embed_query("hello")),
            block_on(rt.embed_document("hello")),
        ];
        for result in results {
            match result {
                Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
                other => panic!("expected Unconfigured, got {other:?}"),
            }
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(&tok, None, None, 10, Some(SubstrateKind::Entity)));
        match result {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(
            &tok,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        rt.create_entity(
            &tok,
            "concept",
            None,
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(&tok, "FlashAttention", None, 10, None, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let hit = &hits[0];
        assert!(hit.title.is_some(), "title must be populated");
        assert!(
            hit.title.as_deref().unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn minilm_document_and_query_embed_are_identical_no_prefix_model() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let text = "attention is all you need".to_string();
        let (doc_emb, query_emb) = block_on(async {
            let d = rt
                .embed_document_with_model(&model.to_string(), &text)
                .await
                .unwrap();
            let q = rt
                .embed_query_with_model(&model.to_string(), &text)
                .await
                .unwrap();
            (d, q)
        });
        assert_eq!(
            doc_emb, query_emb,
            "MiniLM has no instruction prefix: document and query embeds must be identical"
        );
    }

    #[test]
    #[ignore = "loads multilingual-e5-small (~90 MB); run with --include-ignored"]
    fn e5_document_and_query_embed_differ_instruction_tuned_model() {
        let model = EmbeddingModel::MultilingualE5Small;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let text = "attention is all you need".to_string();
        let (doc_emb, query_emb) = block_on(async {
            let d = rt
                .embed_document_with_model(&model.to_string(), &text)
                .await
                .unwrap();
            let q = rt
                .embed_query_with_model(&model.to_string(), &text)
                .await
                .unwrap();
            (d, q)
        });
        assert_ne!(
            doc_emb, query_emb,
            "multilingual-e5-small uses asymmetric prefixes: document ('passage: ') \
             and query ('query: ') embeds of the same text must differ"
        );
    }
}