1use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::{RuntimeError, RuntimeResult};
10use crate::runtime::{parse_embedding_model_alias, sanitize_key, KhiveRuntime, NamespaceToken};
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14 VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19#[derive(Clone, Debug)]
21pub struct SearchHit {
22 pub entity_id: Uuid,
23 pub score: DeterministicScore,
24 pub source: SearchSource,
25 pub title: Option<String>,
26 pub snippet: Option<String>,
27}
28
/// Retriever(s) that produced a [`SearchHit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSource {
    /// Only the vector (embedding) search returned the subject.
    Vector,
    /// Only the full-text search returned the subject.
    Text,
    /// Both retrievers returned the subject.
    Both,
}
36
/// Smoothing constant `k` that `rrf_fuse` passes to `rrf_score` for every ranked hit.
/// Assuming `rrf_score` follows the usual `1 / (k + rank)` form (confirm in
/// `khive_score`), a small `k` widens the score gap between top and lower ranks; the
/// test `rrf_fuse_k10_score_spread_exceeds_threshold` pins a rank-1 vs rank-10 spread
/// of at least 0.03 for this value.
const RRF_K: usize = 10;
45
/// Over-fetch factor for `hybrid_search`: each retriever is asked for
/// `limit * CANDIDATE_MULTIPLIER` hits so that, after fusion and the entity
/// kind/type filter, up to `limit` results can still be returned.
const CANDIDATE_MULTIPLIER: u32 = 4;
48
49impl KhiveRuntime {
50 pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
55 let model_name = self.default_embedder_name();
56 if model_name.is_empty() {
57 return Err(RuntimeError::Unconfigured("embedding_model".into()));
58 }
59 self.embed_with_model(model_name, text).await
60 }
61
62 pub async fn embed_with_model(&self, model_name: &str, text: &str) -> RuntimeResult<Vec<f32>> {
74 let model = parse_embedding_model_alias(model_name);
78 let service = self.embedder(model_name).await?;
79 let emb_model = model.unwrap_or_default();
80 Ok(service.embed_one(text, emb_model).await?)
81 }
82
83 pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
91 if texts.is_empty() {
92 return Ok(vec![]);
93 }
94 let model_name = self.default_embedder_name();
95 if model_name.is_empty() {
96 return Err(RuntimeError::Unconfigured("embedding_model".into()));
97 }
98 self.embed_batch_with_model(model_name, texts).await
99 }
100
101 pub async fn embed_batch_with_model(
106 &self,
107 model_name: &str,
108 texts: &[String],
109 ) -> RuntimeResult<Vec<Vec<f32>>> {
110 if texts.is_empty() {
111 return Ok(vec![]);
112 }
113 let model = parse_embedding_model_alias(model_name);
114 let service = self.embedder(model_name).await?;
115 let emb_model = model.unwrap_or_default();
116 Ok(service.embed(texts, emb_model).await?)
117 }
118
119 pub async fn vector_search(
125 &self,
126 token: &NamespaceToken,
127 query_embedding: Option<Vec<f32>>,
128 query_text: Option<&str>,
129 top_k: u32,
130 kind: Option<SubstrateKind>,
131 ) -> RuntimeResult<Vec<VectorSearchHit>> {
132 let embedding = match query_embedding {
133 Some(vec) => vec,
134 None => {
135 let text = query_text.ok_or_else(|| {
136 RuntimeError::InvalidInput(
137 "vector search requires query_embedding or query_text".into(),
138 )
139 })?;
140 if text.trim().is_empty() {
141 return Err(RuntimeError::InvalidInput(
142 "query_text must not be empty".into(),
143 ));
144 }
145 self.embed(text).await?
146 }
147 };
148
149 let ns = token.namespace().as_str().to_owned();
150 Ok(self
151 .vectors(token)?
152 .search(VectorSearchRequest {
153 query_vectors: vec![embedding],
154 top_k,
155 namespace: Some(ns),
156 kind,
157 embedding_model: None,
158 filter: None,
159 backend_hints: None,
160 })
161 .await?)
162 }
163
164 #[allow(clippy::too_many_arguments)]
179 pub async fn hybrid_search(
180 &self,
181 token: &NamespaceToken,
182 query_text: &str,
183 query_vector: Option<Vec<f32>>,
184 limit: u32,
185 entity_kind: Option<&str>,
186 entity_type: Option<&str>,
187 ) -> RuntimeResult<Vec<SearchHit>> {
188 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
189
190 let ns = token.namespace().as_str().to_owned();
191 let text_hits = self
192 .text(token)?
193 .search(TextSearchRequest {
194 query: query_text.to_string(),
195 mode: TextQueryMode::Plain,
196 filter: Some(TextFilter {
197 namespaces: vec![ns.clone()],
198 ..TextFilter::default()
199 }),
200 top_k: candidates,
201 snippet_chars: 200,
202 })
203 .await?;
204
205 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
206 self.vector_search(
207 token,
208 query_vector,
209 Some(query_text),
210 candidates,
211 Some(SubstrateKind::Entity),
212 )
213 .await?
214 } else {
215 Vec::new()
216 };
217
218 let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize, query_text);
221
222 if !fused.is_empty() {
226 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
227 let alive_page = self
228 .entities(token)?
229 .query_entities(
230 token.namespace().as_str(),
231 EntityFilter {
232 ids: candidate_ids,
233 kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
234 entity_types: entity_type.map(|t| vec![t.to_string()]).unwrap_or_default(),
235 ..EntityFilter::default()
236 },
237 PageRequest {
238 offset: 0,
239 limit: fused.len() as u32,
240 },
241 )
242 .await?;
243 let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
245 let mut alive: HashSet<Uuid> = HashSet::new();
246 for e in alive_page.items {
247 alive.insert(e.id);
248 entity_meta.insert(e.id, (e.name, e.description));
249 }
250
251 fused.retain(|h| alive.contains(&h.entity_id));
252
253 for hit in &mut fused {
255 if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
256 if hit.title.is_none() {
257 hit.title = Some(name.clone());
258 }
259 if hit.snippet.is_none() {
260 hit.snippet = description.clone();
261 }
262 }
263 }
264 }
265
266 fused.truncate(limit as usize);
267 Ok(fused)
268 }
269
270 pub async fn knn(
276 &self,
277 token: &NamespaceToken,
278 query_vector: Vec<f32>,
279 top_k: u32,
280 ) -> RuntimeResult<Vec<VectorSearchHit>> {
281 let ns = token.namespace().as_str().to_owned();
282 Ok(self
283 .vectors(token)?
284 .search(VectorSearchRequest {
285 query_vectors: vec![query_vector],
286 top_k,
287 namespace: Some(ns),
288 kind: Some(SubstrateKind::Entity),
289 embedding_model: None,
290 filter: None,
291 backend_hints: None,
292 })
293 .await?)
294 }
295
296 pub async fn rerank(
302 &self,
303 token: &NamespaceToken,
304 query_vector: &[f32],
305 candidate_ids: &[Uuid],
306 top_k: u32,
307 ) -> RuntimeResult<Vec<VectorSearchHit>> {
308 let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
309 let ns = token.namespace().as_str().to_owned();
310 let all_hits = self
311 .vectors(token)?
312 .search(VectorSearchRequest {
313 query_vectors: vec![query_vector.to_vec()],
314 top_k: candidate_ids.len() as u32,
315 namespace: Some(ns),
316 kind: Some(SubstrateKind::Entity),
317 embedding_model: None,
318 filter: None,
319 backend_hints: None,
320 })
321 .await?;
322 let mut hits: Vec<VectorSearchHit> = all_hits
323 .into_iter()
324 .filter(|h| candidate_set.contains(&h.subject_id))
325 .collect();
326 hits.sort_by(|a, b| b.score.cmp(&a.score));
327 hits.truncate(top_k as usize);
328 Ok(hits)
329 }
330
331 pub async fn backfill_missing_embeddings(&self, token: &NamespaceToken) -> RuntimeResult<u64> {
344 use khive_storage::types::{SqlRow, SqlStatement, SqlValue, TextDocument};
345
346 let model_names = self.registered_embedding_model_names();
347 if model_names.is_empty() {
348 tracing::debug!(
349 "backfill_missing_embeddings: no embedding models registered, skipping"
350 );
351 return Ok(0);
352 }
353
354 let ns = token.namespace().as_str().to_string();
355 let mut total_backfilled = 0u64;
356
357 for model_name in &model_names {
358 let vec_table = format!("vec_{}", sanitize_key(model_name));
360
361 const PAGE_SIZE: usize = 500;
366 let mut entity_total = 0usize;
367 loop {
368 let entity_sql = SqlStatement {
369 sql: format!(
370 "SELECT id, name, description FROM entities \
371 WHERE namespace = ?1 AND deleted_at IS NULL \
372 AND id NOT IN (\
373 SELECT subject_id FROM {vec_table} \
374 WHERE namespace = ?1 AND embedding_model = ?2 \
375 ) LIMIT {PAGE_SIZE}"
376 ),
377 params: vec![
378 SqlValue::Text(ns.clone()),
379 SqlValue::Text(model_name.clone()),
380 ],
381 label: Some("backfill_entities".into()),
382 };
383
384 let entity_rows: Vec<SqlRow> = {
385 let sql = self.sql();
386 match sql.reader().await {
387 Ok(mut reader) => reader.query_all(entity_sql).await.unwrap_or_default(),
388 Err(_) => vec![],
389 }
390 };
391
392 let batch_len = entity_rows.len();
393 entity_total += batch_len;
394
395 for row in &entity_rows {
396 let id_str = row.columns.first().and_then(|c| {
397 if let SqlValue::Text(s) = &c.value {
398 Some(s.clone())
399 } else {
400 None
401 }
402 });
403 let description = row.columns.get(2).and_then(|c| {
404 if let SqlValue::Text(s) = &c.value {
405 Some(s.clone())
406 } else if let SqlValue::Null = &c.value {
407 None
408 } else {
409 None
410 }
411 });
412
413 let (Some(id_str), Some(desc)) = (id_str, description) else {
414 continue;
415 };
416 let Ok(id) = id_str.parse::<Uuid>() else {
417 continue;
418 };
419 if desc.trim().is_empty() {
420 continue;
421 }
422
423 match self.embed_with_model(model_name, &desc).await {
424 Ok(vector) => {
425 if let Ok(vs) = self.vectors_for_model(token, model_name) {
426 match vs
427 .insert(
428 id,
429 SubstrateKind::Entity,
430 &ns,
431 "entity.description",
432 vec![vector],
433 )
434 .await
435 {
436 Ok(()) => {
437 total_backfilled += 1;
438 }
439 Err(e) => {
440 tracing::warn!(
441 id = %id, model = %model_name,
442 error = %e,
443 "backfill_missing_embeddings: entity vector insert failed"
444 );
445 }
446 }
447 }
448 }
449 Err(e) => {
450 tracing::warn!(
451 id = %id, model = %model_name,
452 error = %e,
453 "backfill_missing_embeddings: entity embed failed"
454 );
455 }
456 }
457 }
458
459 if batch_len < PAGE_SIZE {
460 break;
461 }
462 }
463
464 let text_store = self.text_for_notes(token).ok();
466 let mut note_total = 0usize;
467 loop {
468 let note_sql = SqlStatement {
469 sql: format!(
470 "SELECT id, content FROM notes \
471 WHERE namespace = ?1 AND deleted_at IS NULL \
472 AND id NOT IN (\
473 SELECT subject_id FROM {vec_table} \
474 WHERE namespace = ?1 AND embedding_model = ?2 \
475 ) LIMIT {PAGE_SIZE}"
476 ),
477 params: vec![
478 SqlValue::Text(ns.clone()),
479 SqlValue::Text(model_name.clone()),
480 ],
481 label: Some("backfill_notes".into()),
482 };
483
484 let note_rows: Vec<SqlRow> = {
485 let sql = self.sql();
486 match sql.reader().await {
487 Ok(mut reader) => reader.query_all(note_sql).await.unwrap_or_default(),
488 Err(_) => vec![],
489 }
490 };
491
492 let batch_len = note_rows.len();
493 note_total += batch_len;
494
495 for row in ¬e_rows {
496 let id_str = row.columns.first().and_then(|c| {
497 if let SqlValue::Text(s) = &c.value {
498 Some(s.clone())
499 } else {
500 None
501 }
502 });
503 let content = row.columns.get(1).and_then(|c| {
504 if let SqlValue::Text(s) = &c.value {
505 Some(s.clone())
506 } else {
507 None
508 }
509 });
510
511 let (Some(id_str), Some(content)) = (id_str, content) else {
512 continue;
513 };
514 let Ok(id) = id_str.parse::<Uuid>() else {
515 continue;
516 };
517 if content.trim().is_empty() {
518 continue;
519 }
520
521 if model_names.first().map(|n| n.as_str()) == Some(model_name.as_str()) {
523 if let Some(ref ts) = text_store {
524 let _ = ts
525 .upsert_document(TextDocument {
526 subject_id: id,
527 namespace: ns.clone(),
528 kind: SubstrateKind::Note,
529 title: None,
530 body: content.clone(),
531 tags: vec![],
532 metadata: None,
533 updated_at: chrono::Utc::now(),
534 })
535 .await;
536 }
537 }
538
539 match self.embed_with_model(model_name, &content).await {
540 Ok(vector) => {
541 if let Ok(vs) = self.vectors_for_model(token, model_name) {
542 match vs
543 .insert(
544 id,
545 SubstrateKind::Note,
546 &ns,
547 "note.content",
548 vec![vector],
549 )
550 .await
551 {
552 Ok(()) => {
553 total_backfilled += 1;
554 }
555 Err(e) => {
556 tracing::warn!(
557 id = %id, model = %model_name,
558 error = %e,
559 "backfill_missing_embeddings: note vector insert failed"
560 );
561 }
562 }
563 }
564 }
565 Err(e) => {
566 tracing::warn!(
567 id = %id, model = %model_name,
568 error = %e,
569 "backfill_missing_embeddings: note embed failed"
570 );
571 }
572 }
573 }
574
575 if batch_len < PAGE_SIZE {
576 break;
577 }
578 }
579
580 tracing::info!(
581 model = %model_name,
582 namespace = %ns,
583 entities = entity_total,
584 notes = note_total,
585 "backfill_missing_embeddings: model pass complete"
586 );
587 }
588
589 tracing::info!(
590 namespace = %ns,
591 total_backfilled = total_backfilled,
592 "backfill_missing_embeddings: finished"
593 );
594
595 Ok(total_backfilled)
596 }
597
598 pub async fn sweep_orphan_vectors(
614 &self,
615 token: &NamespaceToken,
616 max_delete_per_model: u32,
617 dry_run: bool,
618 ) -> RuntimeResult<u64> {
619 use khive_storage::types::OrphanSweepConfig;
620 use khive_storage::StorageError;
621
622 let model_names = self.registered_embedding_model_names();
623 if model_names.is_empty() {
624 tracing::debug!("sweep_orphan_vectors: no embedding models registered, skipping");
625 return Ok(0);
626 }
627
628 let ns = token.namespace().as_str().to_string();
629 let mut total_deleted = 0u64;
630
631 for model_name in &model_names {
632 let store = match self.vectors_for_model(token, model_name) {
633 Ok(s) => s,
634 Err(e) => {
635 tracing::warn!(
636 model = %model_name,
637 error = %e,
638 "sweep_orphan_vectors: failed to get vector store, skipping model"
639 );
640 continue;
641 }
642 };
643
644 let caps = store.capabilities();
645 if !caps.supports_orphan_sweep {
646 tracing::debug!(
647 model = %model_name,
648 "sweep_orphan_vectors: backend does not support orphan sweep, skipping"
649 );
650 continue;
651 }
652
653 let config = OrphanSweepConfig {
654 subject_id_allowlist: None,
655 namespaces: vec![ns.clone()],
656 substrate_kinds: vec![],
657 max_delete: max_delete_per_model,
658 dry_run,
659 };
660
661 match store.orphan_sweep(&config).await {
662 Ok(result) => {
663 tracing::info!(
664 model = %model_name,
665 namespace = %ns,
666 scanned = result.scanned,
667 deleted = result.deleted,
668 would_delete = result.would_delete,
669 dry_run = dry_run,
670 "sweep_orphan_vectors: sweep complete"
671 );
672 total_deleted += result.deleted;
673 }
674 Err(StorageError::Unsupported { .. }) => {
675 tracing::debug!(
676 model = %model_name,
677 "sweep_orphan_vectors: backend returned Unsupported, skipping"
678 );
679 }
680 Err(e) => {
681 tracing::warn!(
682 model = %model_name,
683 error = %e,
684 "sweep_orphan_vectors: sweep failed, continuing with other models"
685 );
686 }
687 }
688 }
689
690 tracing::info!(
691 namespace = %ns,
692 total_deleted = total_deleted,
693 dry_run = dry_run,
694 "sweep_orphan_vectors: finished"
695 );
696
697 Ok(total_deleted)
698 }
699}
700
/// Score that `rrf_fuse` adds once per subject when a text hit's title equals the
/// query text case-insensitively, so exact title matches outrank fuzzier hits.
const EXACT_MATCH_BOOST: f64 = 0.5;
705
706fn rrf_fuse(
714 text_hits: Vec<TextSearchHit>,
715 vector_hits: Vec<VectorSearchHit>,
716 limit: usize,
717 query_text: &str,
718) -> Vec<SearchHit> {
719 #[derive(Default)]
720 struct Bucket {
721 score: DeterministicScore,
722 source: Option<SearchSource>,
723 title: Option<String>,
724 snippet: Option<String>,
725 }
726
727 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
728
729 let query_lower = query_text.to_lowercase();
730 for (i, hit) in text_hits.into_iter().enumerate() {
731 let rank = i + 1; let entry = buckets.entry(hit.subject_id).or_default();
733 entry.score = entry.score + rrf_score(rank, RRF_K);
734 entry.source = Some(match entry.source {
735 Some(SearchSource::Vector) => SearchSource::Both,
736 _ => SearchSource::Text,
737 });
738 if entry.title.is_none() {
739 if let Some(ref title) = hit.title {
741 if title.to_lowercase() == query_lower {
742 entry.score = entry.score + DeterministicScore::from_f64(EXACT_MATCH_BOOST);
743 }
744 }
745 entry.title = hit.title;
746 }
747 if entry.snippet.is_none() {
748 entry.snippet = hit.snippet;
749 }
750 }
751
752 for (i, hit) in vector_hits.into_iter().enumerate() {
753 let rank = i + 1;
754 let entry = buckets.entry(hit.subject_id).or_default();
755 entry.score = entry.score + rrf_score(rank, RRF_K);
756 entry.source = Some(match entry.source {
757 Some(SearchSource::Text) => SearchSource::Both,
758 _ => SearchSource::Vector,
759 });
760 }
761
762 let mut hits: Vec<SearchHit> = buckets
763 .into_iter()
764 .map(|(id, b)| SearchHit {
765 entity_id: id,
766 score: b.score,
767 source: b.source.expect("each bucket gets a source"),
768 title: b.title,
769 snippet: b.snippet,
770 })
771 .collect();
772
773 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
774 hits.truncate(limit);
775 hits
776}
777
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use khive_types::namespace::Namespace;
    use lattice_embed::EmbeddingModel;

    /// Drives `fut` to completion on a freshly built Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    /// Text hit for `id` at `rank` with the given title and a placeholder snippet.
    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Vector hit for `id` at `rank` with a fixed score.
    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        let fused = rrf_fuse(
            vec![text_hit(first, 1, "A"), text_hit(second, 2, "B")],
            vec![],
            10,
            "query",
        );
        assert_eq!(fused.len(), 2);
        let top = &fused[0];
        assert_eq!(top.entity_id, first);
        assert_eq!(top.source, SearchSource::Text);
        assert_eq!(top.title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let only = Uuid::new_v4();
        let fused = rrf_fuse(vec![], vec![vector_hit(only, 1)], 10, "query");
        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].source, SearchSource::Vector);
        assert!(fused[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let shared = Uuid::new_v4();
        let fused = rrf_fuse(
            vec![text_hit(shared, 1, "A")],
            vec![vector_hit(shared, 1)],
            10,
            "query",
        );
        assert_eq!(fused.len(), 1);
        assert_eq!(fused[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let many: Vec<TextSearchHit> = (1..=20)
            .map(|rank| text_hit(Uuid::new_v4(), rank, "x"))
            .collect();
        assert_eq!(rrf_fuse(many, vec![], 5, "query").len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        let in_both = Uuid::new_v4();
        let vector_only = Uuid::new_v4();
        let fused = rrf_fuse(
            vec![text_hit(in_both, 1, "A")],
            vec![vector_hit(in_both, 1), vector_hit(vector_only, 2)],
            10,
            "query",
        );
        assert_eq!(fused[0].entity_id, in_both);
        assert_eq!(fused[0].source, SearchSource::Both);
        assert!(fused[0].score > fused[1].score);
    }

    #[test]
    fn rrf_fuse_k10_score_spread_exceeds_threshold() {
        let ranked: Vec<TextSearchHit> = (1..=10u32)
            .map(|rank| text_hit(Uuid::new_v4(), rank, "x"))
            .collect();
        let fused = rrf_fuse(ranked, vec![], 10, "query");
        assert_eq!(fused.len(), 10);
        let spread = fused[0].score.to_f64() - fused[9].score.to_f64();
        assert!(
            spread >= 0.03,
            "score spread {spread:.4} between rank 1 and rank 10 must be ≥ 0.03 (was {spread:.4})"
        );
    }

    #[test]
    fn rrf_fuse_exact_match_boost_elevates_score() {
        let exact_id = Uuid::new_v4();
        let other_id = Uuid::new_v4();
        // The exact title match sits at raw text rank 2, behind an unrelated hit.
        let ranked = vec![
            text_hit(other_id, 1, "something else"),
            text_hit(exact_id, 2, "FlashAttention"),
        ];
        let fused = rrf_fuse(ranked, vec![], 10, "flashattention");
        assert_eq!(fused.len(), 2);
        assert_eq!(
            fused[0].entity_id, exact_id,
            "exact match must rank first despite being rank-2 in raw text search"
        );
    }

    #[test]
    fn embed_batch_unconfigured_on_memory_runtime() {
        // Empty input short-circuits before the missing-model check.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        match block_on(rt.embed_batch(&texts)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let rt = KhiveRuntime::new(RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        })
        .unwrap();
        let texts: Vec<String> = ["foo", "bar", "baz"].iter().map(|s| s.to_string()).collect();
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let outcome = block_on(rt.vector_search(&tok, None, None, 10, Some(SubstrateKind::Entity)));
        match outcome {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let outcome = block_on(rt.vector_search(
            &tok,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match outcome {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let rt = KhiveRuntime::new(RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        })
        .unwrap();
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        rt.create_entity(
            &tok,
            "concept",
            None,
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(&tok, "FlashAttention", None, 10, None, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let title = hits[0].title.as_deref();
        assert!(title.is_some(), "title must be populated");
        assert!(
            title.unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}