Skip to main content

khive_runtime/
retrieval.rs

1//! Retrieval operations: local embedding generation and hybrid search with RRF fusion.
2//!
3//! See ADR-012 — Retrieval Architecture.
4
5use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::{RuntimeError, RuntimeResult};
10use crate::runtime::{parse_embedding_model_alias, sanitize_key, KhiveRuntime, NamespaceToken};
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14    VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19/// A unified search result combining vector and text signals.
20#[derive(Clone, Debug)]
21pub struct SearchHit {
22    pub entity_id: Uuid,
23    pub score: DeterministicScore,
24    pub source: SearchSource,
25    pub title: Option<String>,
26    pub snippet: Option<String>,
27}
28
/// Identifies the retrieval path(s) that produced a [`SearchHit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSource {
    /// Found only by vector (embedding) search.
    Vector,
    /// Found only by full-text search.
    Text,
    /// Found by both paths; their RRF contributions were summed.
    Both,
}
36
/// Reciprocal Rank Fusion smoothing constant `k`: smaller values let the top ranks
/// dominate the fused score more strongly.
///
/// The original RRF paper uses k=60, which suits large-scale document retrieval. For
/// a knowledge graph of tens to thousands of entities, k=60 squeezes scores into a
/// narrow band (rank 1 ≈ 0.016, rank 10 ≈ 0.014, spread ≈ 0.002). With k=10, rank 1
/// ≈ 0.091 and rank 10 ≈ 0.050 (spread ≈ 0.041), roughly 20× the discrimination.
/// That keeps dedup-before-create reliable for graphs of 50–2700 entities.
const RRF_K: usize = 10;

/// How many candidates each retrieval path fetches, per requested result, before
/// fusion. Raising it improves recall at the cost of more work.
const CANDIDATE_MULTIPLIER: u32 = 4;
48
49impl KhiveRuntime {
50    /// Generate an embedding vector for `text` using the configured default model.
51    ///
52    /// First call lazily loads model weights (cold start cost). Subsequent calls reuse them.
53    /// Returns `Unconfigured("embedding_model")` if no model is configured.
54    pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
55        let model_name = self.default_embedder_name();
56        if model_name.is_empty() {
57            return Err(RuntimeError::Unconfigured("embedding_model".into()));
58        }
59        self.embed_with_model(model_name, text).await
60    }
61
62    /// Generate an embedding vector for `text` using the named model.
63    ///
64    /// Accepts both built-in lattice model names/aliases and custom provider
65    /// names registered via [`KhiveRuntime::register_embedder`]. For lattice
66    /// models the resolved `EmbeddingModel` enum is forwarded to `embed_one`
67    /// so the service can select the correct model variant. For custom
68    /// providers, `embed_one` is called with `EmbeddingModel::default()`
69    /// because custom services are expected to ignore the enum argument (they
70    /// own a single model implicitly).
71    ///
72    /// Returns `UnknownModel` if `model_name` is not in the embedder registry.
73    pub async fn embed_with_model(&self, model_name: &str, text: &str) -> RuntimeResult<Vec<f32>> {
74        // Try to resolve as a lattice alias. If that succeeds, use the enum to
75        // inform the service which model to run. If not, fall through to the
76        // custom-provider path — custom services ignore the EmbeddingModel arg.
77        let model = parse_embedding_model_alias(model_name);
78        let service = self.embedder(model_name).await?;
79        let emb_model = model.unwrap_or_default();
80        Ok(service.embed_one(text, emb_model).await?)
81    }
82
83    /// Generate embeddings for multiple texts in one call using the configured default model.
84    ///
85    /// Delegates to the cached `EmbeddingService::embed`, so repeated texts within
86    /// and across calls benefit from the runtime-level LRU cache.
87    ///
88    /// Returns an empty vec for empty input without hitting the embedding service.
89    /// Returns `Unconfigured("embedding_model")` if no model is configured.
90    pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
91        if texts.is_empty() {
92            return Ok(vec![]);
93        }
94        let model_name = self.default_embedder_name();
95        if model_name.is_empty() {
96            return Err(RuntimeError::Unconfigured("embedding_model".into()));
97        }
98        self.embed_batch_with_model(model_name, texts).await
99    }
100
101    /// Generate embeddings for multiple texts using the named model.
102    ///
103    /// Accepts lattice model names/aliases and custom provider names.
104    /// Returns `UnknownModel` if `model_name` is not in the embedder registry.
105    pub async fn embed_batch_with_model(
106        &self,
107        model_name: &str,
108        texts: &[String],
109    ) -> RuntimeResult<Vec<Vec<f32>>> {
110        if texts.is_empty() {
111            return Ok(vec![]);
112        }
113        let model = parse_embedding_model_alias(model_name);
114        let service = self.embedder(model_name).await?;
115        let emb_model = model.unwrap_or_default();
116        Ok(service.embed(texts, emb_model).await?)
117    }
118
119    /// Search vectors using either a caller-provided embedding or query text.
120    ///
121    /// Existing callers pass `query_embedding: Some(vec)` to avoid re-embedding.
122    /// Text callers pass `query_embedding: None, query_text: Some(...)` and the
123    /// runtime embeds internally.
124    pub async fn vector_search(
125        &self,
126        token: &NamespaceToken,
127        query_embedding: Option<Vec<f32>>,
128        query_text: Option<&str>,
129        top_k: u32,
130        kind: Option<SubstrateKind>,
131    ) -> RuntimeResult<Vec<VectorSearchHit>> {
132        let embedding = match query_embedding {
133            Some(vec) => vec,
134            None => {
135                let text = query_text.ok_or_else(|| {
136                    RuntimeError::InvalidInput(
137                        "vector search requires query_embedding or query_text".into(),
138                    )
139                })?;
140                if text.trim().is_empty() {
141                    return Err(RuntimeError::InvalidInput(
142                        "query_text must not be empty".into(),
143                    ));
144                }
145                self.embed(text).await?
146            }
147        };
148
149        let ns = token.namespace().as_str().to_owned();
150        Ok(self
151            .vectors(token)?
152            .search(VectorSearchRequest {
153                query_vectors: vec![embedding],
154                top_k,
155                namespace: Some(ns),
156                kind,
157                embedding_model: None,
158                filter: None,
159                backend_hints: None,
160            })
161            .await?)
162    }
163
164    /// Hybrid search: text (FTS5) + vector retrieval fused via Reciprocal Rank Fusion.
165    ///
166    /// - Always performs text search over `query_text`.
167    /// - If `query_vector` is `Some`, also performs vector search and fuses both lists.
168    /// - If `None`, returns text-only results — no vector store needed.
169    /// - If `entity_kind` is `Some`, the alive-set query filters to that kind.
170    ///   The text/vector candidate pools are unfiltered up front; the kind
171    ///   filter applies at the alive-check stage where we already fetch each
172    ///   candidate to confirm it isn't soft-deleted.
173    ///
174    /// `limit` caps the final returned list; internally pulls `limit * 4` candidates per path.
175    /// The fused candidate set is kept untruncated until after the alive + kind filter so
176    /// that right-kind hits ranked below `limit` in the raw fusion still surface when
177    /// higher-ranked candidates are wrong-kind or soft-deleted.
178    #[allow(clippy::too_many_arguments)]
179    pub async fn hybrid_search(
180        &self,
181        token: &NamespaceToken,
182        query_text: &str,
183        query_vector: Option<Vec<f32>>,
184        limit: u32,
185        entity_kind: Option<&str>,
186        entity_type: Option<&str>,
187    ) -> RuntimeResult<Vec<SearchHit>> {
188        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
189
190        let ns = token.namespace().as_str().to_owned();
191        let text_hits = self
192            .text(token)?
193            .search(TextSearchRequest {
194                query: query_text.to_string(),
195                mode: TextQueryMode::Plain,
196                filter: Some(TextFilter {
197                    namespaces: vec![ns.clone()],
198                    ..TextFilter::default()
199                }),
200                top_k: candidates,
201                snippet_chars: 200,
202            })
203            .await?;
204
205        let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
206            self.vector_search(
207                token,
208                query_vector,
209                Some(query_text),
210                candidates,
211                Some(SubstrateKind::Entity),
212            )
213            .await?
214        } else {
215            Vec::new()
216        };
217
218        // Fuse without truncating: keep the full candidate pool through the
219        // alive/kind filter so right-kind hits below rank `limit` aren't lost.
220        let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize, query_text);
221
222        // Filter to alive entities (and optionally to a specific kind). A single
223        // query fetches all alive IDs that match the kind constraint from the
224        // fused set; any ID absent has been soft-deleted or doesn't match.
225        if !fused.is_empty() {
226            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
227            let alive_page = self
228                .entities(token)?
229                .query_entities(
230                    token.namespace().as_str(),
231                    EntityFilter {
232                        ids: candidate_ids,
233                        kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
234                        entity_types: entity_type.map(|t| vec![t.to_string()]).unwrap_or_default(),
235                        ..EntityFilter::default()
236                    },
237                    PageRequest {
238                        offset: 0,
239                        limit: fused.len() as u32,
240                    },
241                )
242                .await?;
243            // Keep entity metadata to enrich hits that had no FTS5 title/snippet.
244            let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
245            let mut alive: HashSet<Uuid> = HashSet::new();
246            for e in alive_page.items {
247                alive.insert(e.id);
248                entity_meta.insert(e.id, (e.name, e.description));
249            }
250
251            fused.retain(|h| alive.contains(&h.entity_id));
252
253            // Enrich vector-only hits (title/snippet == None) from entity record.
254            for hit in &mut fused {
255                if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
256                    if hit.title.is_none() {
257                        hit.title = Some(name.clone());
258                    }
259                    if hit.snippet.is_none() {
260                        hit.snippet = description.clone();
261                    }
262                }
263            }
264        }
265
266        fused.truncate(limit as usize);
267        Ok(fused)
268    }
269
270    /// Exact KNN over the full namespace's vector store.
271    ///
272    /// sqlite-vec uses brute-force cosine — results are exact, not approximate.
273    /// Cost is O(N · D) per query. For small-to-medium namespaces (~hundreds of
274    /// thousands of vectors) this is well within latency budgets.
275    pub async fn knn(
276        &self,
277        token: &NamespaceToken,
278        query_vector: Vec<f32>,
279        top_k: u32,
280    ) -> RuntimeResult<Vec<VectorSearchHit>> {
281        let ns = token.namespace().as_str().to_owned();
282        Ok(self
283            .vectors(token)?
284            .search(VectorSearchRequest {
285                query_vectors: vec![query_vector],
286                top_k,
287                namespace: Some(ns),
288                kind: Some(SubstrateKind::Entity),
289                embedding_model: None,
290                filter: None,
291                backend_hints: None,
292            })
293            .await?)
294    }
295
296    /// Exact KNN restricted to a candidate set.
297    ///
298    /// Useful for reranking the top-N results from `hybrid_search` (or any other
299    /// retrieval path) with exact cosine similarity against a query vector.
300    /// Returns hits sorted by similarity (highest first), truncated to `top_k`.
301    pub async fn rerank(
302        &self,
303        token: &NamespaceToken,
304        query_vector: &[f32],
305        candidate_ids: &[Uuid],
306        top_k: u32,
307    ) -> RuntimeResult<Vec<VectorSearchHit>> {
308        let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
309        let ns = token.namespace().as_str().to_owned();
310        let all_hits = self
311            .vectors(token)?
312            .search(VectorSearchRequest {
313                query_vectors: vec![query_vector.to_vec()],
314                top_k: candidate_ids.len() as u32,
315                namespace: Some(ns),
316                kind: Some(SubstrateKind::Entity),
317                embedding_model: None,
318                filter: None,
319                backend_hints: None,
320            })
321            .await?;
322        let mut hits: Vec<VectorSearchHit> = all_hits
323            .into_iter()
324            .filter(|h| candidate_set.contains(&h.subject_id))
325            .collect();
326        hits.sort_by(|a, b| b.score.cmp(&a.score));
327        hits.truncate(top_k as usize);
328        Ok(hits)
329    }
330
331    /// Backfill vector and FTS index entries for entities and notes that are missing them.
332    ///
333    /// Intended to run once at startup as a background task (ADR-043 §8, steps 2–4).
334    /// Queries the SQL substrate for entity descriptions and note contents that have no
335    /// corresponding entry in the vector store for any registered embedding model, then
336    /// embeds and inserts them. FTS entries missing for notes are also repopulated.
337    ///
338    /// The operation is best-effort: individual embed/insert failures are logged and
339    /// skipped rather than aborting the whole backfill. If no embedding models are
340    /// registered, returns immediately with 0.
341    ///
342    /// Returns the total number of records backfilled across all models.
343    pub async fn backfill_missing_embeddings(&self, token: &NamespaceToken) -> RuntimeResult<u64> {
344        use khive_storage::types::{SqlRow, SqlStatement, SqlValue, TextDocument};
345
346        let model_names = self.registered_embedding_model_names();
347        if model_names.is_empty() {
348            tracing::debug!(
349                "backfill_missing_embeddings: no embedding models registered, skipping"
350            );
351            return Ok(0);
352        }
353
354        let ns = token.namespace().as_str().to_string();
355        let mut total_backfilled = 0u64;
356
357        for model_name in &model_names {
358            // Derive the vec table name from the model name (must match vec_model_key logic).
359            let vec_table = format!("vec_{}", sanitize_key(model_name));
360
361            // --- Entities: embed description where no vector entry exists ---
362            // Loop until a batch returns fewer than PAGE_SIZE rows. Because the query uses
363            // NOT IN (SELECT subject_id FROM vec_table ...), each successfully inserted row is
364            // excluded from subsequent pages — no OFFSET needed.
365            const PAGE_SIZE: usize = 500;
366            let mut entity_total = 0usize;
367            loop {
368                let entity_sql = SqlStatement {
369                    sql: format!(
370                        "SELECT id, name, description FROM entities \
371                         WHERE namespace = ?1 AND deleted_at IS NULL \
372                         AND id NOT IN (\
373                             SELECT subject_id FROM {vec_table} \
374                             WHERE namespace = ?1 AND embedding_model = ?2 \
375                         ) LIMIT {PAGE_SIZE}"
376                    ),
377                    params: vec![
378                        SqlValue::Text(ns.clone()),
379                        SqlValue::Text(model_name.clone()),
380                    ],
381                    label: Some("backfill_entities".into()),
382                };
383
384                let entity_rows: Vec<SqlRow> = {
385                    let sql = self.sql();
386                    match sql.reader().await {
387                        Ok(mut reader) => reader.query_all(entity_sql).await.unwrap_or_default(),
388                        Err(_) => vec![],
389                    }
390                };
391
392                let batch_len = entity_rows.len();
393                entity_total += batch_len;
394
395                for row in &entity_rows {
396                    let id_str = row.columns.first().and_then(|c| {
397                        if let SqlValue::Text(s) = &c.value {
398                            Some(s.clone())
399                        } else {
400                            None
401                        }
402                    });
403                    let description = row.columns.get(2).and_then(|c| {
404                        if let SqlValue::Text(s) = &c.value {
405                            Some(s.clone())
406                        } else if let SqlValue::Null = &c.value {
407                            None
408                        } else {
409                            None
410                        }
411                    });
412
413                    let (Some(id_str), Some(desc)) = (id_str, description) else {
414                        continue;
415                    };
416                    let Ok(id) = id_str.parse::<Uuid>() else {
417                        continue;
418                    };
419                    if desc.trim().is_empty() {
420                        continue;
421                    }
422
423                    match self.embed_with_model(model_name, &desc).await {
424                        Ok(vector) => {
425                            if let Ok(vs) = self.vectors_for_model(token, model_name) {
426                                match vs
427                                    .insert(
428                                        id,
429                                        SubstrateKind::Entity,
430                                        &ns,
431                                        "entity.description",
432                                        vec![vector],
433                                    )
434                                    .await
435                                {
436                                    Ok(()) => {
437                                        total_backfilled += 1;
438                                    }
439                                    Err(e) => {
440                                        tracing::warn!(
441                                            id = %id, model = %model_name,
442                                            error = %e,
443                                            "backfill_missing_embeddings: entity vector insert failed"
444                                        );
445                                    }
446                                }
447                            }
448                        }
449                        Err(e) => {
450                            tracing::warn!(
451                                id = %id, model = %model_name,
452                                error = %e,
453                                "backfill_missing_embeddings: entity embed failed"
454                            );
455                        }
456                    }
457                }
458
459                if batch_len < PAGE_SIZE {
460                    break;
461                }
462            }
463
464            // --- Notes: embed content where no vector entry exists ---
465            let text_store = self.text_for_notes(token).ok();
466            let mut note_total = 0usize;
467            loop {
468                let note_sql = SqlStatement {
469                    sql: format!(
470                        "SELECT id, content FROM notes \
471                         WHERE namespace = ?1 AND deleted_at IS NULL \
472                         AND id NOT IN (\
473                             SELECT subject_id FROM {vec_table} \
474                             WHERE namespace = ?1 AND embedding_model = ?2 \
475                         ) LIMIT {PAGE_SIZE}"
476                    ),
477                    params: vec![
478                        SqlValue::Text(ns.clone()),
479                        SqlValue::Text(model_name.clone()),
480                    ],
481                    label: Some("backfill_notes".into()),
482                };
483
484                let note_rows: Vec<SqlRow> = {
485                    let sql = self.sql();
486                    match sql.reader().await {
487                        Ok(mut reader) => reader.query_all(note_sql).await.unwrap_or_default(),
488                        Err(_) => vec![],
489                    }
490                };
491
492                let batch_len = note_rows.len();
493                note_total += batch_len;
494
495                for row in &note_rows {
496                    let id_str = row.columns.first().and_then(|c| {
497                        if let SqlValue::Text(s) = &c.value {
498                            Some(s.clone())
499                        } else {
500                            None
501                        }
502                    });
503                    let content = row.columns.get(1).and_then(|c| {
504                        if let SqlValue::Text(s) = &c.value {
505                            Some(s.clone())
506                        } else {
507                            None
508                        }
509                    });
510
511                    let (Some(id_str), Some(content)) = (id_str, content) else {
512                        continue;
513                    };
514                    let Ok(id) = id_str.parse::<Uuid>() else {
515                        continue;
516                    };
517                    if content.trim().is_empty() {
518                        continue;
519                    }
520
521                    // Repopulate FTS entry if missing (best-effort, first model only to avoid N duplicates).
522                    if model_names.first().map(|n| n.as_str()) == Some(model_name.as_str()) {
523                        if let Some(ref ts) = text_store {
524                            let _ = ts
525                                .upsert_document(TextDocument {
526                                    subject_id: id,
527                                    namespace: ns.clone(),
528                                    kind: SubstrateKind::Note,
529                                    title: None,
530                                    body: content.clone(),
531                                    tags: vec![],
532                                    metadata: None,
533                                    updated_at: chrono::Utc::now(),
534                                })
535                                .await;
536                        }
537                    }
538
539                    match self.embed_with_model(model_name, &content).await {
540                        Ok(vector) => {
541                            if let Ok(vs) = self.vectors_for_model(token, model_name) {
542                                match vs
543                                    .insert(
544                                        id,
545                                        SubstrateKind::Note,
546                                        &ns,
547                                        "note.content",
548                                        vec![vector],
549                                    )
550                                    .await
551                                {
552                                    Ok(()) => {
553                                        total_backfilled += 1;
554                                    }
555                                    Err(e) => {
556                                        tracing::warn!(
557                                            id = %id, model = %model_name,
558                                            error = %e,
559                                            "backfill_missing_embeddings: note vector insert failed"
560                                        );
561                                    }
562                                }
563                            }
564                        }
565                        Err(e) => {
566                            tracing::warn!(
567                                id = %id, model = %model_name,
568                                error = %e,
569                                "backfill_missing_embeddings: note embed failed"
570                            );
571                        }
572                    }
573                }
574
575                if batch_len < PAGE_SIZE {
576                    break;
577                }
578            }
579
580            tracing::info!(
581                model = %model_name,
582                namespace = %ns,
583                entities = entity_total,
584                notes = note_total,
585                "backfill_missing_embeddings: model pass complete"
586            );
587        }
588
589        tracing::info!(
590            namespace = %ns,
591            total_backfilled = total_backfilled,
592            "backfill_missing_embeddings: finished"
593        );
594
595        Ok(total_backfilled)
596    }
597
598    /// Sweep orphaned vector entries for all registered embedding models.
599    ///
600    /// A vector entry is orphaned when its `subject_id` no longer exists as a
601    /// live row in the entity or note tables (i.e. either the row is absent or
602    /// has `deleted_at IS NOT NULL`). Orphaned entries accumulate after
603    /// hard-deletes because the vector store and SQL substrate are decoupled
604    /// (ADR-044 §rationale).
605    ///
606    /// Iterates over every registered embedding model and calls
607    /// [`VectorStore::orphan_sweep`] for the token's namespace. Models whose
608    /// backend returns [`StorageError::Unsupported`] are skipped without error —
609    /// this preserves forward-compat when a newly registered model does not yet
610    /// implement sweep.
611    ///
612    /// Returns the total number of vector rows deleted across all models.
613    pub async fn sweep_orphan_vectors(
614        &self,
615        token: &NamespaceToken,
616        max_delete_per_model: u32,
617        dry_run: bool,
618    ) -> RuntimeResult<u64> {
619        use khive_storage::types::OrphanSweepConfig;
620        use khive_storage::StorageError;
621
622        let model_names = self.registered_embedding_model_names();
623        if model_names.is_empty() {
624            tracing::debug!("sweep_orphan_vectors: no embedding models registered, skipping");
625            return Ok(0);
626        }
627
628        let ns = token.namespace().as_str().to_string();
629        let mut total_deleted = 0u64;
630
631        for model_name in &model_names {
632            let store = match self.vectors_for_model(token, model_name) {
633                Ok(s) => s,
634                Err(e) => {
635                    tracing::warn!(
636                        model = %model_name,
637                        error = %e,
638                        "sweep_orphan_vectors: failed to get vector store, skipping model"
639                    );
640                    continue;
641                }
642            };
643
644            let caps = store.capabilities();
645            if !caps.supports_orphan_sweep {
646                tracing::debug!(
647                    model = %model_name,
648                    "sweep_orphan_vectors: backend does not support orphan sweep, skipping"
649                );
650                continue;
651            }
652
653            let config = OrphanSweepConfig {
654                subject_id_allowlist: None,
655                namespaces: vec![ns.clone()],
656                substrate_kinds: vec![],
657                max_delete: max_delete_per_model,
658                dry_run,
659            };
660
661            match store.orphan_sweep(&config).await {
662                Ok(result) => {
663                    tracing::info!(
664                        model = %model_name,
665                        namespace = %ns,
666                        scanned = result.scanned,
667                        deleted = result.deleted,
668                        would_delete = result.would_delete,
669                        dry_run = dry_run,
670                        "sweep_orphan_vectors: sweep complete"
671                    );
672                    total_deleted += result.deleted;
673                }
674                Err(StorageError::Unsupported { .. }) => {
675                    tracing::debug!(
676                        model = %model_name,
677                        "sweep_orphan_vectors: backend returned Unsupported, skipping"
678                    );
679                }
680                Err(e) => {
681                    tracing::warn!(
682                        model = %model_name,
683                        error = %e,
684                        "sweep_orphan_vectors: sweep failed, continuing with other models"
685                    );
686                }
687            }
688        }
689
690        tracing::info!(
691            namespace = %ns,
692            total_deleted = total_deleted,
693            dry_run = dry_run,
694            "sweep_orphan_vectors: finished"
695        );
696
697        Ok(total_deleted)
698    }
699}
700
/// Additive score bonus for a text hit whose title equals the query, ignoring case.
///
/// With `RRF_K = 10`, a fused RRF score tops out near 0.18 (rank 1 in both lists,
/// 2/11). Top hits therefore sit roughly in 0.09–0.18, and a 0.5 bonus keeps an exact
/// name match above every partial or purely semantic match.
const EXACT_MATCH_BOOST: f64 = 0.5;
705
706/// Fuse text + vector hits with Reciprocal Rank Fusion (k=10).
707///
708/// Entity search stays local because it uses k=10 plus exact-match boosting.
709/// Hits in both lists get RRF scores summed. If `query_text` exactly matches
710/// (case-insensitive) an entity's title from the text hits, a bonus of
711/// `EXACT_MATCH_BOOST` is added to ensure exact-name matches dominate.
712/// Sort by fused score, take top-`limit`.
713fn rrf_fuse(
714    text_hits: Vec<TextSearchHit>,
715    vector_hits: Vec<VectorSearchHit>,
716    limit: usize,
717    query_text: &str,
718) -> Vec<SearchHit> {
719    #[derive(Default)]
720    struct Bucket {
721        score: DeterministicScore,
722        source: Option<SearchSource>,
723        title: Option<String>,
724        snippet: Option<String>,
725    }
726
727    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
728
729    let query_lower = query_text.to_lowercase();
730    for (i, hit) in text_hits.into_iter().enumerate() {
731        let rank = i + 1; // RRF is 1-indexed
732        let entry = buckets.entry(hit.subject_id).or_default();
733        entry.score = entry.score + rrf_score(rank, RRF_K);
734        entry.source = Some(match entry.source {
735            Some(SearchSource::Vector) => SearchSource::Both,
736            _ => SearchSource::Text,
737        });
738        if entry.title.is_none() {
739            // Apply exact-match boost before storing the title so we only check once.
740            if let Some(ref title) = hit.title {
741                if title.to_lowercase() == query_lower {
742                    entry.score = entry.score + DeterministicScore::from_f64(EXACT_MATCH_BOOST);
743                }
744            }
745            entry.title = hit.title;
746        }
747        if entry.snippet.is_none() {
748            entry.snippet = hit.snippet;
749        }
750    }
751
752    for (i, hit) in vector_hits.into_iter().enumerate() {
753        let rank = i + 1;
754        let entry = buckets.entry(hit.subject_id).or_default();
755        entry.score = entry.score + rrf_score(rank, RRF_K);
756        entry.source = Some(match entry.source {
757            Some(SearchSource::Text) => SearchSource::Both,
758            _ => SearchSource::Vector,
759        });
760    }
761
762    let mut hits: Vec<SearchHit> = buckets
763        .into_iter()
764        .map(|(id, b)| SearchHit {
765            entity_id: id,
766            score: b.score,
767            source: b.source.expect("each bucket gets a source"),
768            title: b.title,
769            snippet: b.snippet,
770        })
771        .collect();
772
773    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
774    hits.truncate(limit);
775    hits
776}
777
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use khive_types::namespace::Namespace;
    use lattice_embed::EmbeddingModel;
    use std::future::Future;

    /// Drive `fut` to completion on a freshly built Tokio runtime (`Runtime::new()`).
    ///
    /// Shared by the synchronous `#[test]`s below, so each builds its runtime the same
    /// way as before without repeating the boilerplate.
    fn block_on<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new()
            .expect("failed to build Tokio runtime")
            .block_on(fut)
    }

    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5, "query");
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `a` is rank 1 in both lists → 1/(10+1) + 1/(10+1) = 2/11.
        // `b` appears only at vector rank 2 → 1/(10+2).
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn rrf_fuse_k10_score_spread_exceeds_threshold() {
        // With k=10: rank 1 → 1/11 ≈ 0.0909, rank 10 → 1/20 = 0.0500.
        // Spread ≈ 0.041, well above the 0.03 minimum required for reliable dedup.
        let ids: Vec<Uuid> = (0..10).map(|_| Uuid::new_v4()).collect();
        let text: Vec<TextSearchHit> = ids
            .iter()
            .enumerate()
            .map(|(i, &id)| text_hit(id, (i + 1) as u32, "x"))
            .collect();
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 10);
        let top_score = hits[0].score.to_f64();
        let bottom_score = hits[9].score.to_f64();
        let spread = top_score - bottom_score;
        assert!(
            spread >= 0.03,
            "score spread {spread:.4} between rank 1 ({top_score:.4}) and rank 10 ({bottom_score:.4}) must be ≥ 0.03"
        );
    }

    #[test]
    fn rrf_fuse_exact_match_boost_elevates_score() {
        // An entity whose title exactly matches the query should receive a score
        // significantly above a non-matching entity ranked first by text search.
        let exact_id = Uuid::new_v4();
        let other_id = Uuid::new_v4();
        // other_id ranks 1 in text, exact_id ranks 2 — but exact_id matches query.
        let text = vec![
            text_hit(other_id, 1, "something else"),
            text_hit(exact_id, 2, "FlashAttention"),
        ];
        let hits = rrf_fuse(text, vec![], 10, "flashattention");
        assert_eq!(hits.len(), 2);
        assert_eq!(
            hits[0].entity_id, exact_id,
            "exact match must rank first despite being rank-2 in raw text search"
        );
    }

    // ---- embed_batch tests ----

    #[test]
    fn embed_batch_unconfigured_on_memory_runtime() {
        // KhiveRuntime::memory() has no embedding model, so a non-empty, multi-item
        // batch must fail with Unconfigured. The empty-slice short circuit is covered
        // separately by `embed_batch_empty_input_returns_empty_vec`.
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["alpha".to_string(), "beta".to_string()];
        match block_on(rt.embed_batch(&texts)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // No model needed — empty slice is handled before the embedder is touched.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        match block_on(rt.embed_batch(&texts)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(&tok, None, None, 10, Some(SubstrateKind::Entity)));
        match result {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(
            &tok,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    // ---- hybrid_search enrichment (issue #147 / #160) ----

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        rt.create_entity(
            &tok,
            "concept",
            None,
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(&tok, "FlashAttention", None, 10, None, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let hit = &hits[0];
        assert!(hit.title.is_some(), "title must be populated");
        assert!(
            hit.title.as_deref().unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}