Skip to main content

khive_runtime/
retrieval.rs

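//! Retrieval operations: local embedding generation, hybrid (text + vector) search with
//! RRF fusion, exact KNN and reranking, and vector-index maintenance (embedding backfill
//! and orphan sweep).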
2
3use std::collections::{HashMap, HashSet};
4
5use uuid::Uuid;
6
7use crate::error::{RuntimeError, RuntimeResult};
8use crate::runtime::{parse_embedding_model_alias, sanitize_key, KhiveRuntime, NamespaceToken};
9use khive_score::{rrf_score, DeterministicScore};
10use khive_storage::types::{
11    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
12    VectorSearchRequest,
13};
14use khive_storage::EntityFilter;
15use khive_types::SubstrateKind;
16
17/// A unified search result combining vector and text signals.
18#[derive(Clone, Debug)]
19pub struct SearchHit {
20    pub entity_id: Uuid,
21    pub score: DeterministicScore,
22    pub source: SearchSource,
23    pub title: Option<String>,
24    pub snippet: Option<String>,
25}
26
/// Identifies which retrieval path(s) surfaced a search hit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSource {
    /// Surfaced only by vector (embedding) search.
    Vector,
    /// Surfaced only by full-text search.
    Text,
    /// Surfaced by both retrieval paths.
    Both,
}
34
/// Reciprocal Rank Fusion smoothing constant (`k` in `1 / (k + rank)`).
///
/// Smaller values let top ranks dominate more strongly. The original RRF paper uses
/// k=60, which suits large document collections. For graphs of tens to thousands of
/// entities, k=60 squeezes scores into a thin band: rank 1 ≈ 0.016 and rank 10 ≈ 0.014,
/// a spread of ≈ 0.002.
///
/// With k=10, rank 1 ≈ 0.091 and rank 10 ≈ 0.050, a spread of ≈ 0.041. That is about
/// 20× more discrimination, which keeps dedup-before-create reliable at graph sizes
/// of 50–2700 entities.
const RRF_K: usize = 10;
43
/// Per-path over-fetch factor applied to the requested limit before fusion.
/// Raising it improves recall at the cost of more retrieval and filtering work.
const CANDIDATE_MULTIPLIER: u32 = 4;
46
47impl KhiveRuntime {
48    /// Generate an embedding vector for `text` using the configured default model.
49    ///
50    /// First call lazily loads model weights (cold start cost). Subsequent calls reuse them.
51    /// Returns `Unconfigured("embedding_model")` if no model is configured.
52    pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
53        let model_name = self.default_embedder_name();
54        if model_name.is_empty() {
55            return Err(RuntimeError::Unconfigured("embedding_model".into()));
56        }
57        self.embed_with_model(model_name, text).await
58    }
59
60    /// Generate an embedding vector for `text` using the named model.
61    ///
62    /// Accepts both built-in lattice model names/aliases and custom provider
63    /// names registered via [`KhiveRuntime::register_embedder`]. For lattice
64    /// models the resolved `EmbeddingModel` enum is forwarded to `embed_one`
65    /// so the service can select the correct model variant. For custom
66    /// providers, `embed_one` is called with `EmbeddingModel::default()`
67    /// because custom services are expected to ignore the enum argument (they
68    /// own a single model implicitly).
69    ///
70    /// Returns `UnknownModel` if `model_name` is not in the embedder registry.
71    pub async fn embed_with_model(&self, model_name: &str, text: &str) -> RuntimeResult<Vec<f32>> {
72        // Try to resolve as a lattice alias. If that succeeds, use the enum to
73        // inform the service which model to run. If not, fall through to the
74        // custom-provider path — custom services ignore the EmbeddingModel arg.
75        let model = parse_embedding_model_alias(model_name);
76        let service = self.embedder(model_name).await?;
77        let emb_model = model.unwrap_or_default();
78        Ok(service.embed_one(text, emb_model).await?)
79    }
80
81    /// Generate embeddings for multiple texts in one call using the configured default model.
82    ///
83    /// Delegates to the cached `EmbeddingService::embed`, so repeated texts within
84    /// and across calls benefit from the runtime-level LRU cache.
85    ///
86    /// Returns an empty vec for empty input without hitting the embedding service.
87    /// Returns `Unconfigured("embedding_model")` if no model is configured.
88    pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
89        if texts.is_empty() {
90            return Ok(vec![]);
91        }
92        let model_name = self.default_embedder_name();
93        if model_name.is_empty() {
94            return Err(RuntimeError::Unconfigured("embedding_model".into()));
95        }
96        self.embed_batch_with_model(model_name, texts).await
97    }
98
99    /// Generate embeddings for multiple texts using the named model.
100    ///
101    /// Accepts lattice model names/aliases and custom provider names.
102    /// Returns `UnknownModel` if `model_name` is not in the embedder registry.
103    pub async fn embed_batch_with_model(
104        &self,
105        model_name: &str,
106        texts: &[String],
107    ) -> RuntimeResult<Vec<Vec<f32>>> {
108        if texts.is_empty() {
109            return Ok(vec![]);
110        }
111        let model = parse_embedding_model_alias(model_name);
112        let service = self.embedder(model_name).await?;
113        let emb_model = model.unwrap_or_default();
114        Ok(service.embed(texts, emb_model).await?)
115    }
116
117    /// Search vectors using either a caller-provided embedding or query text.
118    ///
119    /// Existing callers pass `query_embedding: Some(vec)` to avoid re-embedding.
120    /// Text callers pass `query_embedding: None, query_text: Some(...)` and the
121    /// runtime embeds internally.
122    pub async fn vector_search(
123        &self,
124        token: &NamespaceToken,
125        query_embedding: Option<Vec<f32>>,
126        query_text: Option<&str>,
127        top_k: u32,
128        kind: Option<SubstrateKind>,
129    ) -> RuntimeResult<Vec<VectorSearchHit>> {
130        let embedding = match query_embedding {
131            Some(vec) => vec,
132            None => {
133                let text = query_text.ok_or_else(|| {
134                    RuntimeError::InvalidInput(
135                        "vector search requires query_embedding or query_text".into(),
136                    )
137                })?;
138                if text.trim().is_empty() {
139                    return Err(RuntimeError::InvalidInput(
140                        "query_text must not be empty".into(),
141                    ));
142                }
143                self.embed(text).await?
144            }
145        };
146
147        let ns = token.namespace().as_str().to_owned();
148        Ok(self
149            .vectors(token)?
150            .search(VectorSearchRequest {
151                query_vectors: vec![embedding],
152                top_k,
153                namespace: Some(ns),
154                kind,
155                embedding_model: None,
156                filter: None,
157                backend_hints: None,
158            })
159            .await?)
160    }
161
162    /// Hybrid search: text (FTS5) + vector retrieval fused via Reciprocal Rank Fusion.
163    ///
164    /// - Always performs text search over `query_text`.
165    /// - If `query_vector` is `Some`, also performs vector search and fuses both lists.
166    /// - If `None`, returns text-only results — no vector store needed.
167    /// - If `entity_kind` is `Some`, the alive-set query filters to that kind.
168    ///   The text/vector candidate pools are unfiltered up front; the kind
169    ///   filter applies at the alive-check stage where we already fetch each
170    ///   candidate to confirm it isn't soft-deleted.
171    ///
172    /// `limit` caps the final returned list; internally pulls `limit * 4` candidates per path.
173    /// The fused candidate set is kept untruncated until after the alive + kind filter so
174    /// that right-kind hits ranked below `limit` in the raw fusion still surface when
175    /// higher-ranked candidates are wrong-kind or soft-deleted.
176    #[allow(clippy::too_many_arguments)]
177    pub async fn hybrid_search(
178        &self,
179        token: &NamespaceToken,
180        query_text: &str,
181        query_vector: Option<Vec<f32>>,
182        limit: u32,
183        entity_kind: Option<&str>,
184        entity_type: Option<&str>,
185    ) -> RuntimeResult<Vec<SearchHit>> {
186        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
187
188        let ns = token.namespace().as_str().to_owned();
189        let text_hits = self
190            .text(token)?
191            .search(TextSearchRequest {
192                query: query_text.to_string(),
193                mode: TextQueryMode::Plain,
194                filter: Some(TextFilter {
195                    namespaces: vec![ns.clone()],
196                    ..TextFilter::default()
197                }),
198                top_k: candidates,
199                snippet_chars: 200,
200            })
201            .await?;
202
203        let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
204            self.vector_search(
205                token,
206                query_vector,
207                Some(query_text),
208                candidates,
209                Some(SubstrateKind::Entity),
210            )
211            .await?
212        } else {
213            Vec::new()
214        };
215
216        // Fuse without truncating: keep the full candidate pool through the
217        // alive/kind filter so right-kind hits below rank `limit` aren't lost.
218        let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize, query_text);
219
220        // Filter to alive entities (and optionally to a specific kind). A single
221        // query fetches all alive IDs that match the kind constraint from the
222        // fused set; any ID absent has been soft-deleted or doesn't match.
223        if !fused.is_empty() {
224            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
225            let alive_page = self
226                .entities(token)?
227                .query_entities(
228                    token.namespace().as_str(),
229                    EntityFilter {
230                        ids: candidate_ids,
231                        kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
232                        entity_types: entity_type.map(|t| vec![t.to_string()]).unwrap_or_default(),
233                        ..EntityFilter::default()
234                    },
235                    PageRequest {
236                        offset: 0,
237                        limit: fused.len() as u32,
238                    },
239                )
240                .await?;
241            // Keep entity metadata to enrich hits that had no FTS5 title/snippet.
242            let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
243            let mut alive: HashSet<Uuid> = HashSet::new();
244            for e in alive_page.items {
245                alive.insert(e.id);
246                entity_meta.insert(e.id, (e.name, e.description));
247            }
248
249            fused.retain(|h| alive.contains(&h.entity_id));
250
251            // Enrich vector-only hits (title/snippet == None) from entity record.
252            for hit in &mut fused {
253                if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
254                    if hit.title.is_none() {
255                        hit.title = Some(name.clone());
256                    }
257                    if hit.snippet.is_none() {
258                        hit.snippet = description.clone();
259                    }
260                }
261            }
262        }
263
264        fused.truncate(limit as usize);
265        Ok(fused)
266    }
267
268    /// Exact KNN over the full namespace's vector store.
269    ///
270    /// sqlite-vec uses brute-force cosine — results are exact, not approximate.
271    /// Cost is O(N · D) per query. For small-to-medium namespaces (~hundreds of
272    /// thousands of vectors) this is well within latency budgets.
273    pub async fn knn(
274        &self,
275        token: &NamespaceToken,
276        query_vector: Vec<f32>,
277        top_k: u32,
278    ) -> RuntimeResult<Vec<VectorSearchHit>> {
279        let ns = token.namespace().as_str().to_owned();
280        Ok(self
281            .vectors(token)?
282            .search(VectorSearchRequest {
283                query_vectors: vec![query_vector],
284                top_k,
285                namespace: Some(ns),
286                kind: Some(SubstrateKind::Entity),
287                embedding_model: None,
288                filter: None,
289                backend_hints: None,
290            })
291            .await?)
292    }
293
294    /// Exact KNN restricted to a candidate set.
295    ///
296    /// Useful for reranking the top-N results from `hybrid_search` (or any other
297    /// retrieval path) with exact cosine similarity against a query vector.
298    /// Returns hits sorted by similarity (highest first), truncated to `top_k`.
299    pub async fn rerank(
300        &self,
301        token: &NamespaceToken,
302        query_vector: &[f32],
303        candidate_ids: &[Uuid],
304        top_k: u32,
305    ) -> RuntimeResult<Vec<VectorSearchHit>> {
306        let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
307        let ns = token.namespace().as_str().to_owned();
308        let all_hits = self
309            .vectors(token)?
310            .search(VectorSearchRequest {
311                query_vectors: vec![query_vector.to_vec()],
312                top_k: candidate_ids.len() as u32,
313                namespace: Some(ns),
314                kind: Some(SubstrateKind::Entity),
315                embedding_model: None,
316                filter: None,
317                backend_hints: None,
318            })
319            .await?;
320        let mut hits: Vec<VectorSearchHit> = all_hits
321            .into_iter()
322            .filter(|h| candidate_set.contains(&h.subject_id))
323            .collect();
324        hits.sort_by(|a, b| b.score.cmp(&a.score));
325        hits.truncate(top_k as usize);
326        Ok(hits)
327    }
328
329    /// Backfill vector and FTS index entries for entities and notes that are missing them.
330    ///
331    /// Intended to run once at startup as a background task (warm-up sequence steps 2–4).
332    /// Queries the SQL substrate for entity descriptions and note contents that have no
333    /// corresponding entry in the vector store for any registered embedding model, then
334    /// embeds and inserts them. FTS entries missing for notes are also repopulated.
335    ///
336    /// The operation is best-effort: individual embed/insert failures are logged and
337    /// skipped rather than aborting the whole backfill. If no embedding models are
338    /// registered, returns immediately with 0.
339    ///
340    /// Returns the total number of records backfilled across all models.
341    pub async fn backfill_missing_embeddings(&self, token: &NamespaceToken) -> RuntimeResult<u64> {
342        use khive_storage::types::{SqlRow, SqlStatement, SqlValue, TextDocument};
343
344        let model_names = self.registered_embedding_model_names();
345        if model_names.is_empty() {
346            tracing::debug!(
347                "backfill_missing_embeddings: no embedding models registered, skipping"
348            );
349            return Ok(0);
350        }
351
352        let ns = token.namespace().as_str().to_string();
353        let mut total_backfilled = 0u64;
354
355        for model_name in &model_names {
356            // Derive the vec table name from the model name (must match vec_model_key logic).
357            let vec_table = format!("vec_{}", sanitize_key(model_name));
358
359            // --- Entities: embed description where no vector entry exists ---
360            // Loop until a batch returns fewer than PAGE_SIZE rows. Because the query uses
361            // NOT IN (SELECT subject_id FROM vec_table ...), each successfully inserted row is
362            // excluded from subsequent pages — no OFFSET needed.
363            const PAGE_SIZE: usize = 500;
364            let mut entity_total = 0usize;
365            loop {
366                let entity_sql = SqlStatement {
367                    sql: format!(
368                        "SELECT id, name, description FROM entities \
369                         WHERE namespace = ?1 AND deleted_at IS NULL \
370                         AND id NOT IN (\
371                             SELECT subject_id FROM {vec_table} \
372                             WHERE namespace = ?1 AND embedding_model = ?2 \
373                         ) LIMIT {PAGE_SIZE}"
374                    ),
375                    params: vec![
376                        SqlValue::Text(ns.clone()),
377                        SqlValue::Text(model_name.clone()),
378                    ],
379                    label: Some("backfill_entities".into()),
380                };
381
382                let entity_rows: Vec<SqlRow> = {
383                    let sql = self.sql();
384                    match sql.reader().await {
385                        Ok(mut reader) => reader.query_all(entity_sql).await.unwrap_or_default(),
386                        Err(_) => vec![],
387                    }
388                };
389
390                let batch_len = entity_rows.len();
391                entity_total += batch_len;
392
393                for row in &entity_rows {
394                    let id_str = row.columns.first().and_then(|c| {
395                        if let SqlValue::Text(s) = &c.value {
396                            Some(s.clone())
397                        } else {
398                            None
399                        }
400                    });
401                    let description = row.columns.get(2).and_then(|c| {
402                        if let SqlValue::Text(s) = &c.value {
403                            Some(s.clone())
404                        } else if let SqlValue::Null = &c.value {
405                            None
406                        } else {
407                            None
408                        }
409                    });
410
411                    let (Some(id_str), Some(desc)) = (id_str, description) else {
412                        continue;
413                    };
414                    let Ok(id) = id_str.parse::<Uuid>() else {
415                        continue;
416                    };
417                    if desc.trim().is_empty() {
418                        continue;
419                    }
420
421                    match self.embed_with_model(model_name, &desc).await {
422                        Ok(vector) => {
423                            if let Ok(vs) = self.vectors_for_model(token, model_name) {
424                                match vs
425                                    .insert(
426                                        id,
427                                        SubstrateKind::Entity,
428                                        &ns,
429                                        "entity.description",
430                                        vec![vector],
431                                    )
432                                    .await
433                                {
434                                    Ok(()) => {
435                                        total_backfilled += 1;
436                                    }
437                                    Err(e) => {
438                                        tracing::warn!(
439                                            id = %id, model = %model_name,
440                                            error = %e,
441                                            "backfill_missing_embeddings: entity vector insert failed"
442                                        );
443                                    }
444                                }
445                            }
446                        }
447                        Err(e) => {
448                            tracing::warn!(
449                                id = %id, model = %model_name,
450                                error = %e,
451                                "backfill_missing_embeddings: entity embed failed"
452                            );
453                        }
454                    }
455                }
456
457                if batch_len < PAGE_SIZE {
458                    break;
459                }
460            }
461
462            // --- Notes: embed content where no vector entry exists ---
463            let text_store = self.text_for_notes(token).ok();
464            let mut note_total = 0usize;
465            loop {
466                let note_sql = SqlStatement {
467                    sql: format!(
468                        "SELECT id, content FROM notes \
469                         WHERE namespace = ?1 AND deleted_at IS NULL \
470                         AND id NOT IN (\
471                             SELECT subject_id FROM {vec_table} \
472                             WHERE namespace = ?1 AND embedding_model = ?2 \
473                         ) LIMIT {PAGE_SIZE}"
474                    ),
475                    params: vec![
476                        SqlValue::Text(ns.clone()),
477                        SqlValue::Text(model_name.clone()),
478                    ],
479                    label: Some("backfill_notes".into()),
480                };
481
482                let note_rows: Vec<SqlRow> = {
483                    let sql = self.sql();
484                    match sql.reader().await {
485                        Ok(mut reader) => reader.query_all(note_sql).await.unwrap_or_default(),
486                        Err(_) => vec![],
487                    }
488                };
489
490                let batch_len = note_rows.len();
491                note_total += batch_len;
492
493                for row in &note_rows {
494                    let id_str = row.columns.first().and_then(|c| {
495                        if let SqlValue::Text(s) = &c.value {
496                            Some(s.clone())
497                        } else {
498                            None
499                        }
500                    });
501                    let content = row.columns.get(1).and_then(|c| {
502                        if let SqlValue::Text(s) = &c.value {
503                            Some(s.clone())
504                        } else {
505                            None
506                        }
507                    });
508
509                    let (Some(id_str), Some(content)) = (id_str, content) else {
510                        continue;
511                    };
512                    let Ok(id) = id_str.parse::<Uuid>() else {
513                        continue;
514                    };
515                    if content.trim().is_empty() {
516                        continue;
517                    }
518
519                    // Repopulate FTS entry if missing (best-effort, first model only to avoid N duplicates).
520                    if model_names.first().map(|n| n.as_str()) == Some(model_name.as_str()) {
521                        if let Some(ref ts) = text_store {
522                            let _ = ts
523                                .upsert_document(TextDocument {
524                                    subject_id: id,
525                                    namespace: ns.clone(),
526                                    kind: SubstrateKind::Note,
527                                    title: None,
528                                    body: content.clone(),
529                                    tags: vec![],
530                                    metadata: None,
531                                    updated_at: chrono::Utc::now(),
532                                })
533                                .await;
534                        }
535                    }
536
537                    match self.embed_with_model(model_name, &content).await {
538                        Ok(vector) => {
539                            if let Ok(vs) = self.vectors_for_model(token, model_name) {
540                                match vs
541                                    .insert(
542                                        id,
543                                        SubstrateKind::Note,
544                                        &ns,
545                                        "note.content",
546                                        vec![vector],
547                                    )
548                                    .await
549                                {
550                                    Ok(()) => {
551                                        total_backfilled += 1;
552                                    }
553                                    Err(e) => {
554                                        tracing::warn!(
555                                            id = %id, model = %model_name,
556                                            error = %e,
557                                            "backfill_missing_embeddings: note vector insert failed"
558                                        );
559                                    }
560                                }
561                            }
562                        }
563                        Err(e) => {
564                            tracing::warn!(
565                                id = %id, model = %model_name,
566                                error = %e,
567                                "backfill_missing_embeddings: note embed failed"
568                            );
569                        }
570                    }
571                }
572
573                if batch_len < PAGE_SIZE {
574                    break;
575                }
576            }
577
578            tracing::info!(
579                model = %model_name,
580                namespace = %ns,
581                entities = entity_total,
582                notes = note_total,
583                "backfill_missing_embeddings: model pass complete"
584            );
585        }
586
587        tracing::info!(
588            namespace = %ns,
589            total_backfilled = total_backfilled,
590            "backfill_missing_embeddings: finished"
591        );
592
593        Ok(total_backfilled)
594    }
595
596    /// Sweep orphaned vector entries for all registered embedding models.
597    ///
598    /// A vector entry is orphaned when its `subject_id` no longer exists as a
599    /// live row in the entity or note tables (i.e. either the row is absent or
600    /// has `deleted_at IS NOT NULL`). Orphaned entries accumulate after
601    /// hard-deletes because the vector store and SQL substrate are decoupled.
602    ///
603    /// Iterates over every registered embedding model and calls
604    /// [`VectorStore::orphan_sweep`] for the token's namespace. Models whose
605    /// backend returns [`StorageError::Unsupported`] are skipped without error —
606    /// this preserves forward-compat when a newly registered model does not yet
607    /// implement sweep.
608    ///
609    /// Returns the total number of vector rows deleted across all models.
610    pub async fn sweep_orphan_vectors(
611        &self,
612        token: &NamespaceToken,
613        max_delete_per_model: u32,
614        dry_run: bool,
615    ) -> RuntimeResult<u64> {
616        use khive_storage::types::OrphanSweepConfig;
617        use khive_storage::StorageError;
618
619        let model_names = self.registered_embedding_model_names();
620        if model_names.is_empty() {
621            tracing::debug!("sweep_orphan_vectors: no embedding models registered, skipping");
622            return Ok(0);
623        }
624
625        let ns = token.namespace().as_str().to_string();
626        let mut total_deleted = 0u64;
627
628        for model_name in &model_names {
629            let store = match self.vectors_for_model(token, model_name) {
630                Ok(s) => s,
631                Err(e) => {
632                    tracing::warn!(
633                        model = %model_name,
634                        error = %e,
635                        "sweep_orphan_vectors: failed to get vector store, skipping model"
636                    );
637                    continue;
638                }
639            };
640
641            let caps = store.capabilities();
642            if !caps.supports_orphan_sweep {
643                tracing::debug!(
644                    model = %model_name,
645                    "sweep_orphan_vectors: backend does not support orphan sweep, skipping"
646                );
647                continue;
648            }
649
650            let config = OrphanSweepConfig {
651                subject_id_allowlist: None,
652                namespaces: vec![ns.clone()],
653                substrate_kinds: vec![],
654                max_delete: max_delete_per_model,
655                dry_run,
656            };
657
658            match store.orphan_sweep(&config).await {
659                Ok(result) => {
660                    tracing::info!(
661                        model = %model_name,
662                        namespace = %ns,
663                        scanned = result.scanned,
664                        deleted = result.deleted,
665                        would_delete = result.would_delete,
666                        dry_run = dry_run,
667                        "sweep_orphan_vectors: sweep complete"
668                    );
669                    total_deleted += result.deleted;
670                }
671                Err(StorageError::Unsupported { .. }) => {
672                    tracing::debug!(
673                        model = %model_name,
674                        "sweep_orphan_vectors: backend returned Unsupported, skipping"
675                    );
676                }
677                Err(e) => {
678                    tracing::warn!(
679                        model = %model_name,
680                        error = %e,
681                        "sweep_orphan_vectors: sweep failed, continuing with other models"
682                    );
683                }
684            }
685        }
686
687        tracing::info!(
688            namespace = %ns,
689            total_deleted = total_deleted,
690            dry_run = dry_run,
691            "sweep_orphan_vectors: finished"
692        );
693
694        Ok(total_deleted)
695    }
696}
697
/// Additive bonus for a text hit whose title equals the query, case-insensitively.
///
/// With `RRF_K = 10`, fused RRF scores sit roughly in the 0.09–0.18 range (rank 1 in
/// one or both lists). A bonus of 0.5 therefore dwarfs them, so an exact name match
/// always outranks any partial or purely semantic match.
const EXACT_MATCH_BOOST: f64 = 0.5;
702
703/// Fuse text + vector hits with Reciprocal Rank Fusion (k=10).
704///
705/// Entity search stays local because it uses k=10 plus exact-match boosting.
706/// Hits in both lists get RRF scores summed. If `query_text` exactly matches
707/// (case-insensitive) an entity's title from the text hits, a bonus of
708/// `EXACT_MATCH_BOOST` is added to ensure exact-name matches dominate.
709/// Sort by fused score, take top-`limit`.
710fn rrf_fuse(
711    text_hits: Vec<TextSearchHit>,
712    vector_hits: Vec<VectorSearchHit>,
713    limit: usize,
714    query_text: &str,
715) -> Vec<SearchHit> {
716    #[derive(Default)]
717    struct Bucket {
718        score: DeterministicScore,
719        source: Option<SearchSource>,
720        title: Option<String>,
721        snippet: Option<String>,
722    }
723
724    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
725
726    let query_lower = query_text.to_lowercase();
727    for (i, hit) in text_hits.into_iter().enumerate() {
728        let rank = i + 1; // RRF is 1-indexed
729        let entry = buckets.entry(hit.subject_id).or_default();
730        entry.score = entry.score + rrf_score(rank, RRF_K);
731        entry.source = Some(match entry.source {
732            Some(SearchSource::Vector) => SearchSource::Both,
733            _ => SearchSource::Text,
734        });
735        if entry.title.is_none() {
736            // Apply exact-match boost before storing the title so we only check once.
737            if let Some(ref title) = hit.title {
738                if title.to_lowercase() == query_lower {
739                    entry.score = entry.score + DeterministicScore::from_f64(EXACT_MATCH_BOOST);
740                }
741            }
742            entry.title = hit.title;
743        }
744        if entry.snippet.is_none() {
745            entry.snippet = hit.snippet;
746        }
747    }
748
749    for (i, hit) in vector_hits.into_iter().enumerate() {
750        let rank = i + 1;
751        let entry = buckets.entry(hit.subject_id).or_default();
752        entry.score = entry.score + rrf_score(rank, RRF_K);
753        entry.source = Some(match entry.source {
754            Some(SearchSource::Text) => SearchSource::Both,
755            _ => SearchSource::Vector,
756        });
757    }
758
759    let mut hits: Vec<SearchHit> = buckets
760        .into_iter()
761        .map(|(id, b)| SearchHit {
762            entity_id: id,
763            score: b.score,
764            source: b.source.expect("each bucket gets a source"),
765            title: b.title,
766            snippet: b.snippet,
767        })
768        .collect();
769
770    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
771    hits.truncate(limit);
772    hits
773}
774
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use khive_types::namespace::Namespace;
    use lattice_embed::EmbeddingModel;

    /// Drive `fut` to completion on a fresh `tokio::runtime::Runtime::new()` runtime.
    ///
    /// The synchronous tests below use this instead of `#[tokio::test]` so they keep
    /// running on the same runtime flavour they were written against.
    fn block_on<T>(fut: impl std::future::Future<Output = T>) -> T {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    /// Build a runtime with no `db_path`, namespace `test`, the `kg` pack and `model`.
    fn runtime_with_model(model: EmbeddingModel) -> KhiveRuntime {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        KhiveRuntime::new(config).unwrap()
    }

    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
        // Title and snippet from the text path survive the merge with the vector hit.
        assert_eq!(hits[0].title.as_deref(), Some("A"));
        assert_eq!(hits[0].snippet.as_deref(), Some("..."));
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5, "query");
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_zero_limit_returns_empty() {
        let id = Uuid::new_v4();
        let hits = rrf_fuse(vec![text_hit(id, 1, "A")], vec![vector_hit(id, 1)], 0, "a");
        assert!(hits.is_empty());
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `a` is rank 1 in both lists → 1/(10+1) + 1/(10+1) = 2/11.
        // `b` is only rank 2 in the vector list → 1/(10+2) = 1/12.
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10, "query");
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn rrf_fuse_breaks_score_ties_by_entity_id() {
        // One id only in text at rank 1, another only in vector at rank 1: equal RRF
        // contributions, so ordering falls back to ascending entity_id.
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let hits = rrf_fuse(vec![text_hit(a, 1, "A")], vec![vector_hit(b, 1)], 10, "query");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].score, hits[1].score);
        assert_eq!(hits[0].entity_id, a.min(b));
    }

    #[test]
    fn rrf_fuse_k10_score_spread_exceeds_threshold() {
        // With k=10: rank 1 → 1/11 ≈ 0.0909, rank 10 → 1/20 = 0.0500.
        // Spread ≈ 0.041, well above the 0.03 minimum required for reliable dedup.
        let ids: Vec<Uuid> = (0..10).map(|_| Uuid::new_v4()).collect();
        let text: Vec<TextSearchHit> = ids
            .iter()
            .enumerate()
            .map(|(i, &id)| text_hit(id, (i + 1) as u32, "x"))
            .collect();
        let hits = rrf_fuse(text, vec![], 10, "query");
        assert_eq!(hits.len(), 10);
        let top_score = hits[0].score.to_f64();
        let bottom_score = hits[9].score.to_f64();
        let spread = top_score - bottom_score;
        assert!(
            spread >= 0.03,
            "score spread {spread:.4} between rank 1 and rank 10 must be ≥ 0.03"
        );
    }

    #[test]
    fn rrf_fuse_exact_match_boost_elevates_score() {
        // An entity whose title exactly matches the query should receive a score
        // significantly above a non-matching entity ranked first by text search.
        let exact_id = Uuid::new_v4();
        let other_id = Uuid::new_v4();
        // other_id ranks 1 in text, exact_id ranks 2 — but exact_id matches query.
        let text = vec![
            text_hit(other_id, 1, "something else"),
            text_hit(exact_id, 2, "FlashAttention"),
        ];
        let hits = rrf_fuse(text, vec![], 10, "flashattention");
        assert_eq!(hits.len(), 2);
        assert_eq!(
            hits[0].entity_id, exact_id,
            "exact match must rank first despite being rank-2 in raw text search"
        );
    }

    #[test]
    fn rrf_fuse_partial_title_match_gets_no_boost() {
        // A title that merely contains the query is not an exact match, so raw rank wins.
        let first = Uuid::new_v4();
        let partial = Uuid::new_v4();
        let text = vec![
            text_hit(first, 1, "unrelated"),
            text_hit(partial, 2, "FlashAttention v2"),
        ];
        let hits = rrf_fuse(text, vec![], 10, "flashattention");
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, first);
    }

    // ---- embed_batch tests ----

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // KhiveRuntime::memory() has no embedding model, but an empty slice is
        // short-circuited before the embedder is touched, so this still succeeds.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        let result = block_on(rt.embed_batch(&texts));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let rt = runtime_with_model(EmbeddingModel::AllMiniLmL6V2);
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(
            &tok,
            None,
            None,
            10,
            Some(SubstrateKind::Entity),
        ));
        match result {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(
            &tok,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let rt = runtime_with_model(model);
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    // ---- hybrid_search enrichment (issue #147 / #160) ----

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        rt.create_entity(
            &tok,
            "concept",
            None,
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(&tok, "FlashAttention", None, 10, None, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let hit = &hits[0];
        assert!(hit.title.is_some(), "title must be populated");
        assert!(
            hit.title.as_deref().unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}