1use std::collections::{HashMap, HashSet};
4
5use uuid::Uuid;
6
7use crate::error::{RuntimeError, RuntimeResult};
8use crate::runtime::{parse_embedding_model_alias, sanitize_key, KhiveRuntime, NamespaceToken};
9use khive_score::{rrf_score, DeterministicScore};
10use khive_storage::types::{
11 PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
12 VectorSearchRequest,
13};
14use khive_storage::EntityFilter;
15use khive_types::SubstrateKind;
16
/// One fused result produced by `KhiveRuntime::hybrid_search`.
#[derive(Clone, Debug)]
pub struct SearchHit {
    /// Id of the matched subject. After `hybrid_search` filtering, this is an
    /// entity id returned by the entity query.
    pub entity_id: Uuid,
    /// Fused Reciprocal Rank Fusion score (plus any exact-title boost); higher
    /// ranks first.
    pub score: DeterministicScore,
    /// Which retrieval list(s) contributed this hit.
    pub source: SearchSource,
    /// Title taken from the text-search hit, or else from the entity name.
    pub title: Option<String>,
    /// Snippet taken from the text-search hit, or else from the entity
    /// description.
    pub snippet: Option<String>,
}
26
/// Which retrieval path(s) produced a [`SearchHit`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSource {
    /// Found only by vector (embedding) search.
    Vector,
    /// Found only by full-text search.
    Text,
    /// Found by both full-text and vector search.
    Both,
}
34
/// `k` constant passed to `rrf_score` when fusing ranked lists. Assuming the
/// standard `1 / (k + rank)` form (TODO confirm in `khive_score`), a small `k`
/// widens the score gap between top and lower ranks.
const RRF_K: usize = 10;

/// Over-fetch factor for `hybrid_search`: each sub-search asks for
/// `limit * CANDIDATE_MULTIPLIER` hits because fused results are later dropped
/// (dead entities, kind/type filters) before being truncated to `limit`.
const CANDIDATE_MULTIPLIER: u32 = 4;
46
47impl KhiveRuntime {
48 pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
53 let model_name = self.default_embedder_name();
54 if model_name.is_empty() {
55 return Err(RuntimeError::Unconfigured("embedding_model".into()));
56 }
57 self.embed_with_model(model_name, text).await
58 }
59
60 pub async fn embed_with_model(&self, model_name: &str, text: &str) -> RuntimeResult<Vec<f32>> {
72 let model = parse_embedding_model_alias(model_name);
76 let service = self.embedder(model_name).await?;
77 let emb_model = model.unwrap_or_default();
78 Ok(service.embed_one(text, emb_model).await?)
79 }
80
81 pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
89 if texts.is_empty() {
90 return Ok(vec![]);
91 }
92 let model_name = self.default_embedder_name();
93 if model_name.is_empty() {
94 return Err(RuntimeError::Unconfigured("embedding_model".into()));
95 }
96 self.embed_batch_with_model(model_name, texts).await
97 }
98
99 pub async fn embed_batch_with_model(
104 &self,
105 model_name: &str,
106 texts: &[String],
107 ) -> RuntimeResult<Vec<Vec<f32>>> {
108 if texts.is_empty() {
109 return Ok(vec![]);
110 }
111 let model = parse_embedding_model_alias(model_name);
112 let service = self.embedder(model_name).await?;
113 let emb_model = model.unwrap_or_default();
114 Ok(service.embed(texts, emb_model).await?)
115 }
116
117 pub async fn vector_search(
123 &self,
124 token: &NamespaceToken,
125 query_embedding: Option<Vec<f32>>,
126 query_text: Option<&str>,
127 top_k: u32,
128 kind: Option<SubstrateKind>,
129 ) -> RuntimeResult<Vec<VectorSearchHit>> {
130 let embedding = match query_embedding {
131 Some(vec) => vec,
132 None => {
133 let text = query_text.ok_or_else(|| {
134 RuntimeError::InvalidInput(
135 "vector search requires query_embedding or query_text".into(),
136 )
137 })?;
138 if text.trim().is_empty() {
139 return Err(RuntimeError::InvalidInput(
140 "query_text must not be empty".into(),
141 ));
142 }
143 self.embed(text).await?
144 }
145 };
146
147 let ns = token.namespace().as_str().to_owned();
148 Ok(self
149 .vectors(token)?
150 .search(VectorSearchRequest {
151 query_vectors: vec![embedding],
152 top_k,
153 namespace: Some(ns),
154 kind,
155 embedding_model: None,
156 filter: None,
157 backend_hints: None,
158 })
159 .await?)
160 }
161
162 #[allow(clippy::too_many_arguments)]
177 pub async fn hybrid_search(
178 &self,
179 token: &NamespaceToken,
180 query_text: &str,
181 query_vector: Option<Vec<f32>>,
182 limit: u32,
183 entity_kind: Option<&str>,
184 entity_type: Option<&str>,
185 ) -> RuntimeResult<Vec<SearchHit>> {
186 let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
187
188 let ns = token.namespace().as_str().to_owned();
189 let text_hits = self
190 .text(token)?
191 .search(TextSearchRequest {
192 query: query_text.to_string(),
193 mode: TextQueryMode::Plain,
194 filter: Some(TextFilter {
195 namespaces: vec![ns.clone()],
196 ..TextFilter::default()
197 }),
198 top_k: candidates,
199 snippet_chars: 200,
200 })
201 .await?;
202
203 let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
204 self.vector_search(
205 token,
206 query_vector,
207 Some(query_text),
208 candidates,
209 Some(SubstrateKind::Entity),
210 )
211 .await?
212 } else {
213 Vec::new()
214 };
215
216 let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize, query_text);
219
220 if !fused.is_empty() {
224 let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
225 let alive_page = self
226 .entities(token)?
227 .query_entities(
228 token.namespace().as_str(),
229 EntityFilter {
230 ids: candidate_ids,
231 kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
232 entity_types: entity_type.map(|t| vec![t.to_string()]).unwrap_or_default(),
233 ..EntityFilter::default()
234 },
235 PageRequest {
236 offset: 0,
237 limit: fused.len() as u32,
238 },
239 )
240 .await?;
241 let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
243 let mut alive: HashSet<Uuid> = HashSet::new();
244 for e in alive_page.items {
245 alive.insert(e.id);
246 entity_meta.insert(e.id, (e.name, e.description));
247 }
248
249 fused.retain(|h| alive.contains(&h.entity_id));
250
251 for hit in &mut fused {
253 if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
254 if hit.title.is_none() {
255 hit.title = Some(name.clone());
256 }
257 if hit.snippet.is_none() {
258 hit.snippet = description.clone();
259 }
260 }
261 }
262 }
263
264 fused.truncate(limit as usize);
265 Ok(fused)
266 }
267
268 pub async fn knn(
274 &self,
275 token: &NamespaceToken,
276 query_vector: Vec<f32>,
277 top_k: u32,
278 ) -> RuntimeResult<Vec<VectorSearchHit>> {
279 let ns = token.namespace().as_str().to_owned();
280 Ok(self
281 .vectors(token)?
282 .search(VectorSearchRequest {
283 query_vectors: vec![query_vector],
284 top_k,
285 namespace: Some(ns),
286 kind: Some(SubstrateKind::Entity),
287 embedding_model: None,
288 filter: None,
289 backend_hints: None,
290 })
291 .await?)
292 }
293
294 pub async fn rerank(
300 &self,
301 token: &NamespaceToken,
302 query_vector: &[f32],
303 candidate_ids: &[Uuid],
304 top_k: u32,
305 ) -> RuntimeResult<Vec<VectorSearchHit>> {
306 let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
307 let ns = token.namespace().as_str().to_owned();
308 let all_hits = self
309 .vectors(token)?
310 .search(VectorSearchRequest {
311 query_vectors: vec![query_vector.to_vec()],
312 top_k: candidate_ids.len() as u32,
313 namespace: Some(ns),
314 kind: Some(SubstrateKind::Entity),
315 embedding_model: None,
316 filter: None,
317 backend_hints: None,
318 })
319 .await?;
320 let mut hits: Vec<VectorSearchHit> = all_hits
321 .into_iter()
322 .filter(|h| candidate_set.contains(&h.subject_id))
323 .collect();
324 hits.sort_by(|a, b| b.score.cmp(&a.score));
325 hits.truncate(top_k as usize);
326 Ok(hits)
327 }
328
329 pub async fn backfill_missing_embeddings(&self, token: &NamespaceToken) -> RuntimeResult<u64> {
342 use khive_storage::types::{SqlRow, SqlStatement, SqlValue, TextDocument};
343
344 let model_names = self.registered_embedding_model_names();
345 if model_names.is_empty() {
346 tracing::debug!(
347 "backfill_missing_embeddings: no embedding models registered, skipping"
348 );
349 return Ok(0);
350 }
351
352 let ns = token.namespace().as_str().to_string();
353 let mut total_backfilled = 0u64;
354
355 for model_name in &model_names {
356 let vec_table = format!("vec_{}", sanitize_key(model_name));
358
359 const PAGE_SIZE: usize = 500;
364 let mut entity_total = 0usize;
365 loop {
366 let entity_sql = SqlStatement {
367 sql: format!(
368 "SELECT id, name, description FROM entities \
369 WHERE namespace = ?1 AND deleted_at IS NULL \
370 AND id NOT IN (\
371 SELECT subject_id FROM {vec_table} \
372 WHERE namespace = ?1 AND embedding_model = ?2 \
373 ) LIMIT {PAGE_SIZE}"
374 ),
375 params: vec![
376 SqlValue::Text(ns.clone()),
377 SqlValue::Text(model_name.clone()),
378 ],
379 label: Some("backfill_entities".into()),
380 };
381
382 let entity_rows: Vec<SqlRow> = {
383 let sql = self.sql();
384 match sql.reader().await {
385 Ok(mut reader) => reader.query_all(entity_sql).await.unwrap_or_default(),
386 Err(_) => vec![],
387 }
388 };
389
390 let batch_len = entity_rows.len();
391 entity_total += batch_len;
392
393 for row in &entity_rows {
394 let id_str = row.columns.first().and_then(|c| {
395 if let SqlValue::Text(s) = &c.value {
396 Some(s.clone())
397 } else {
398 None
399 }
400 });
401 let description = row.columns.get(2).and_then(|c| {
402 if let SqlValue::Text(s) = &c.value {
403 Some(s.clone())
404 } else if let SqlValue::Null = &c.value {
405 None
406 } else {
407 None
408 }
409 });
410
411 let (Some(id_str), Some(desc)) = (id_str, description) else {
412 continue;
413 };
414 let Ok(id) = id_str.parse::<Uuid>() else {
415 continue;
416 };
417 if desc.trim().is_empty() {
418 continue;
419 }
420
421 match self.embed_with_model(model_name, &desc).await {
422 Ok(vector) => {
423 if let Ok(vs) = self.vectors_for_model(token, model_name) {
424 match vs
425 .insert(
426 id,
427 SubstrateKind::Entity,
428 &ns,
429 "entity.description",
430 vec![vector],
431 )
432 .await
433 {
434 Ok(()) => {
435 total_backfilled += 1;
436 }
437 Err(e) => {
438 tracing::warn!(
439 id = %id, model = %model_name,
440 error = %e,
441 "backfill_missing_embeddings: entity vector insert failed"
442 );
443 }
444 }
445 }
446 }
447 Err(e) => {
448 tracing::warn!(
449 id = %id, model = %model_name,
450 error = %e,
451 "backfill_missing_embeddings: entity embed failed"
452 );
453 }
454 }
455 }
456
457 if batch_len < PAGE_SIZE {
458 break;
459 }
460 }
461
462 let text_store = self.text_for_notes(token).ok();
464 let mut note_total = 0usize;
465 loop {
466 let note_sql = SqlStatement {
467 sql: format!(
468 "SELECT id, content FROM notes \
469 WHERE namespace = ?1 AND deleted_at IS NULL \
470 AND id NOT IN (\
471 SELECT subject_id FROM {vec_table} \
472 WHERE namespace = ?1 AND embedding_model = ?2 \
473 ) LIMIT {PAGE_SIZE}"
474 ),
475 params: vec![
476 SqlValue::Text(ns.clone()),
477 SqlValue::Text(model_name.clone()),
478 ],
479 label: Some("backfill_notes".into()),
480 };
481
482 let note_rows: Vec<SqlRow> = {
483 let sql = self.sql();
484 match sql.reader().await {
485 Ok(mut reader) => reader.query_all(note_sql).await.unwrap_or_default(),
486 Err(_) => vec![],
487 }
488 };
489
490 let batch_len = note_rows.len();
491 note_total += batch_len;
492
493 for row in ¬e_rows {
494 let id_str = row.columns.first().and_then(|c| {
495 if let SqlValue::Text(s) = &c.value {
496 Some(s.clone())
497 } else {
498 None
499 }
500 });
501 let content = row.columns.get(1).and_then(|c| {
502 if let SqlValue::Text(s) = &c.value {
503 Some(s.clone())
504 } else {
505 None
506 }
507 });
508
509 let (Some(id_str), Some(content)) = (id_str, content) else {
510 continue;
511 };
512 let Ok(id) = id_str.parse::<Uuid>() else {
513 continue;
514 };
515 if content.trim().is_empty() {
516 continue;
517 }
518
519 if model_names.first().map(|n| n.as_str()) == Some(model_name.as_str()) {
521 if let Some(ref ts) = text_store {
522 let _ = ts
523 .upsert_document(TextDocument {
524 subject_id: id,
525 namespace: ns.clone(),
526 kind: SubstrateKind::Note,
527 title: None,
528 body: content.clone(),
529 tags: vec![],
530 metadata: None,
531 updated_at: chrono::Utc::now(),
532 })
533 .await;
534 }
535 }
536
537 match self.embed_with_model(model_name, &content).await {
538 Ok(vector) => {
539 if let Ok(vs) = self.vectors_for_model(token, model_name) {
540 match vs
541 .insert(
542 id,
543 SubstrateKind::Note,
544 &ns,
545 "note.content",
546 vec![vector],
547 )
548 .await
549 {
550 Ok(()) => {
551 total_backfilled += 1;
552 }
553 Err(e) => {
554 tracing::warn!(
555 id = %id, model = %model_name,
556 error = %e,
557 "backfill_missing_embeddings: note vector insert failed"
558 );
559 }
560 }
561 }
562 }
563 Err(e) => {
564 tracing::warn!(
565 id = %id, model = %model_name,
566 error = %e,
567 "backfill_missing_embeddings: note embed failed"
568 );
569 }
570 }
571 }
572
573 if batch_len < PAGE_SIZE {
574 break;
575 }
576 }
577
578 tracing::info!(
579 model = %model_name,
580 namespace = %ns,
581 entities = entity_total,
582 notes = note_total,
583 "backfill_missing_embeddings: model pass complete"
584 );
585 }
586
587 tracing::info!(
588 namespace = %ns,
589 total_backfilled = total_backfilled,
590 "backfill_missing_embeddings: finished"
591 );
592
593 Ok(total_backfilled)
594 }
595
596 pub async fn sweep_orphan_vectors(
611 &self,
612 token: &NamespaceToken,
613 max_delete_per_model: u32,
614 dry_run: bool,
615 ) -> RuntimeResult<u64> {
616 use khive_storage::types::OrphanSweepConfig;
617 use khive_storage::StorageError;
618
619 let model_names = self.registered_embedding_model_names();
620 if model_names.is_empty() {
621 tracing::debug!("sweep_orphan_vectors: no embedding models registered, skipping");
622 return Ok(0);
623 }
624
625 let ns = token.namespace().as_str().to_string();
626 let mut total_deleted = 0u64;
627
628 for model_name in &model_names {
629 let store = match self.vectors_for_model(token, model_name) {
630 Ok(s) => s,
631 Err(e) => {
632 tracing::warn!(
633 model = %model_name,
634 error = %e,
635 "sweep_orphan_vectors: failed to get vector store, skipping model"
636 );
637 continue;
638 }
639 };
640
641 let caps = store.capabilities();
642 if !caps.supports_orphan_sweep {
643 tracing::debug!(
644 model = %model_name,
645 "sweep_orphan_vectors: backend does not support orphan sweep, skipping"
646 );
647 continue;
648 }
649
650 let config = OrphanSweepConfig {
651 subject_id_allowlist: None,
652 namespaces: vec![ns.clone()],
653 substrate_kinds: vec![],
654 max_delete: max_delete_per_model,
655 dry_run,
656 };
657
658 match store.orphan_sweep(&config).await {
659 Ok(result) => {
660 tracing::info!(
661 model = %model_name,
662 namespace = %ns,
663 scanned = result.scanned,
664 deleted = result.deleted,
665 would_delete = result.would_delete,
666 dry_run = dry_run,
667 "sweep_orphan_vectors: sweep complete"
668 );
669 total_deleted += result.deleted;
670 }
671 Err(StorageError::Unsupported { .. }) => {
672 tracing::debug!(
673 model = %model_name,
674 "sweep_orphan_vectors: backend returned Unsupported, skipping"
675 );
676 }
677 Err(e) => {
678 tracing::warn!(
679 model = %model_name,
680 error = %e,
681 "sweep_orphan_vectors: sweep failed, continuing with other models"
682 );
683 }
684 }
685 }
686
687 tracing::info!(
688 namespace = %ns,
689 total_deleted = total_deleted,
690 dry_run = dry_run,
691 "sweep_orphan_vectors: finished"
692 );
693
694 Ok(total_deleted)
695 }
696}
697
698const EXACT_MATCH_BOOST: f64 = 0.5;
702
703fn rrf_fuse(
711 text_hits: Vec<TextSearchHit>,
712 vector_hits: Vec<VectorSearchHit>,
713 limit: usize,
714 query_text: &str,
715) -> Vec<SearchHit> {
716 #[derive(Default)]
717 struct Bucket {
718 score: DeterministicScore,
719 source: Option<SearchSource>,
720 title: Option<String>,
721 snippet: Option<String>,
722 }
723
724 let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
725
726 let query_lower = query_text.to_lowercase();
727 for (i, hit) in text_hits.into_iter().enumerate() {
728 let rank = i + 1; let entry = buckets.entry(hit.subject_id).or_default();
730 entry.score = entry.score + rrf_score(rank, RRF_K);
731 entry.source = Some(match entry.source {
732 Some(SearchSource::Vector) => SearchSource::Both,
733 _ => SearchSource::Text,
734 });
735 if entry.title.is_none() {
736 if let Some(ref title) = hit.title {
738 if title.to_lowercase() == query_lower {
739 entry.score = entry.score + DeterministicScore::from_f64(EXACT_MATCH_BOOST);
740 }
741 }
742 entry.title = hit.title;
743 }
744 if entry.snippet.is_none() {
745 entry.snippet = hit.snippet;
746 }
747 }
748
749 for (i, hit) in vector_hits.into_iter().enumerate() {
750 let rank = i + 1;
751 let entry = buckets.entry(hit.subject_id).or_default();
752 entry.score = entry.score + rrf_score(rank, RRF_K);
753 entry.source = Some(match entry.source {
754 Some(SearchSource::Text) => SearchSource::Both,
755 _ => SearchSource::Vector,
756 });
757 }
758
759 let mut hits: Vec<SearchHit> = buckets
760 .into_iter()
761 .map(|(id, b)| SearchHit {
762 entity_id: id,
763 score: b.score,
764 source: b.source.expect("each bucket gets a source"),
765 title: b.title,
766 snippet: b.snippet,
767 })
768 .collect();
769
770 hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
771 hits.truncate(limit);
772 hits
773}
774
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use khive_types::namespace::Namespace;
    use lattice_embed::EmbeddingModel;

    // Builds a text hit with a fixed raw score; only the list position matters
    // to `rrf_fuse`.
    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5, "query");
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `a` appears in both lists, `b` only in the vector list.
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn rrf_fuse_k10_score_spread_exceeds_threshold() {
        // With a small RRF k, rank 1 and rank 10 must be clearly separated.
        let ids: Vec<Uuid> = (0..10).map(|_| Uuid::new_v4()).collect();
        let text: Vec<TextSearchHit> = ids
            .iter()
            .enumerate()
            .map(|(i, &id)| text_hit(id, (i + 1) as u32, "x"))
            .collect();
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 10);
        let top_score = hits[0].score.to_f64();
        let bottom_score = hits[9].score.to_f64();
        let spread = top_score - bottom_score;
        assert!(
            spread >= 0.03,
            "score spread between rank 1 and rank 10 must be ≥ 0.03 (was {spread:.4})"
        );
    }

    #[test]
    fn rrf_fuse_exact_match_boost_elevates_score() {
        // The exact (case-insensitive) title match starts at rank 2.
        let exact_id = Uuid::new_v4();
        let other_id = Uuid::new_v4();
        let text = vec![
            text_hit(other_id, 1, "something else"),
            text_hit(exact_id, 2, "FlashAttention"),
        ];
        let hits = rrf_fuse(text, vec![], 10, "flashattention");
        assert_eq!(hits.len(), 2);
        assert_eq!(
            hits[0].entity_id, exact_id,
            "exact match must rank first despite being rank-2 in raw text search"
        );
    }

    #[test]
    fn embed_unconfigured_on_memory_runtime() {
        // The in-memory runtime has no default embedder, so single-text
        // embedding must report `Unconfigured` instead of failing elsewhere.
        let rt = KhiveRuntime::memory().unwrap();
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed("hello"));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // Empty input short-circuits before the missing-model check.
        let rt = KhiveRuntime::memory().unwrap();
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&texts));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&texts));
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.vector_search(&tok, None, None, 10, Some(SubstrateKind::Entity)));
        match result {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.vector_search(
                &tok,
                None,
                Some("attention"),
                10,
                Some(SubstrateKind::Entity),
            ));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts = vec!["hello world".to_string()];
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&texts));
        let embeddings = result.unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        // Text-only hybrid search (no embedding model) must still surface the
        // entity with a populated title.
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        rt.create_entity(
            &tok,
            "concept",
            None,
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(&tok, "FlashAttention", None, 10, None, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let hit = &hits[0];
        assert!(hit.title.is_some(), "title must be populated");
        assert!(
            hit.title.as_deref().unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}