Skip to main content

khive_runtime/
retrieval.rs

1//! Retrieval operations: local embedding generation and hybrid search with RRF fusion.
2//!
3//! See ADR-012 — Retrieval Architecture.
4
5use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::{RuntimeError, RuntimeResult};
10use crate::runtime::{KhiveRuntime, NamespaceToken};
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14    VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19/// A unified search result combining vector and text signals.
20#[derive(Clone, Debug)]
21pub struct SearchHit {
22    pub entity_id: Uuid,
23    pub score: DeterministicScore,
24    pub source: SearchSource,
25    pub title: Option<String>,
26    pub snippet: Option<String>,
27}
28
/// Identifies the retrieval path(s) that contributed to a [`SearchHit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSource {
    /// The hit appeared only in the vector search results.
    Vector,
    /// The hit appeared only in the text search results.
    Text,
    /// The hit appeared in both the text and the vector results.
    Both,
}
36
/// The `k` constant of Reciprocal Rank Fusion, as given in the original paper.
/// It controls how strongly top ranks dominate the fused score.
const RRF_K: usize = 60;

/// Over-fetch factor: each retrieval path pulls `limit * CANDIDATE_MULTIPLIER`
/// candidates before fusion. Higher = better recall, more work.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44    /// Generate an embedding vector for `text` using the configured local model.
45    ///
46    /// First call lazily loads model weights (cold start cost). Subsequent calls reuse them.
47    /// Returns `Unconfigured("embedding_model")` if no model is configured.
48    pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49        let service = self.embedder().await?;
50        let model = self
51            .config()
52            .embedding_model
53            .expect("embedder() returns Unconfigured when model is None");
54        Ok(service.embed_one(text, model).await?)
55    }
56
57    /// Generate embeddings for multiple texts in one call.
58    ///
59    /// Delegates to the cached `EmbeddingService::embed`, so repeated texts within
60    /// and across calls benefit from the runtime-level LRU cache.
61    ///
62    /// Returns an empty vec for empty input without hitting the embedding service.
63    /// Returns `Unconfigured("embedding_model")` if no model is configured.
64    pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65        if texts.is_empty() {
66            return Ok(vec![]);
67        }
68        let service = self.embedder().await?;
69        let model = self
70            .config()
71            .embedding_model
72            .expect("embedder() returns Unconfigured when model is None");
73        Ok(service.embed(texts, model).await?)
74    }
75
76    /// Search vectors using either a caller-provided embedding or query text.
77    ///
78    /// Existing callers pass `query_embedding: Some(vec)` to avoid re-embedding.
79    /// Text callers pass `query_embedding: None, query_text: Some(...)` and the
80    /// runtime embeds internally.
81    pub async fn vector_search(
82        &self,
83        token: &NamespaceToken,
84        query_embedding: Option<Vec<f32>>,
85        query_text: Option<&str>,
86        top_k: u32,
87        kind: Option<SubstrateKind>,
88    ) -> RuntimeResult<Vec<VectorSearchHit>> {
89        let embedding = match query_embedding {
90            Some(vec) => vec,
91            None => {
92                let text = query_text.ok_or_else(|| {
93                    RuntimeError::InvalidInput(
94                        "vector search requires query_embedding or query_text".into(),
95                    )
96                })?;
97                if text.trim().is_empty() {
98                    return Err(RuntimeError::InvalidInput(
99                        "query_text must not be empty".into(),
100                    ));
101                }
102                self.embed(text).await?
103            }
104        };
105
106        let ns = token.namespace().as_str().to_owned();
107        Ok(self
108            .vectors(token)?
109            .search(VectorSearchRequest {
110                query_vectors: vec![embedding],
111                top_k,
112                namespace: Some(ns),
113                kind,
114                filter: None,
115                backend_hints: None,
116            })
117            .await?)
118    }
119
120    /// Hybrid search: text (FTS5) + vector retrieval fused via Reciprocal Rank Fusion.
121    ///
122    /// - Always performs text search over `query_text`.
123    /// - If `query_vector` is `Some`, also performs vector search and fuses both lists.
124    /// - If `None`, returns text-only results — no vector store needed.
125    /// - If `entity_kind` is `Some`, the alive-set query filters to that kind.
126    ///   The text/vector candidate pools are unfiltered up front; the kind
127    ///   filter applies at the alive-check stage where we already fetch each
128    ///   candidate to confirm it isn't soft-deleted.
129    ///
130    /// `limit` caps the final returned list; internally pulls `limit * 4` candidates per path.
131    /// The fused candidate set is kept untruncated until after the alive + kind filter so
132    /// that right-kind hits ranked below `limit` in the raw fusion still surface when
133    /// higher-ranked candidates are wrong-kind or soft-deleted.
134    #[allow(clippy::too_many_arguments)]
135    pub async fn hybrid_search(
136        &self,
137        token: &NamespaceToken,
138        query_text: &str,
139        query_vector: Option<Vec<f32>>,
140        limit: u32,
141        entity_kind: Option<&str>,
142        entity_type: Option<&str>,
143    ) -> RuntimeResult<Vec<SearchHit>> {
144        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
145
146        let ns = token.namespace().as_str().to_owned();
147        let text_hits = self
148            .text(token)?
149            .search(TextSearchRequest {
150                query: query_text.to_string(),
151                mode: TextQueryMode::Plain,
152                filter: Some(TextFilter {
153                    namespaces: vec![ns.clone()],
154                    ..TextFilter::default()
155                }),
156                top_k: candidates,
157                snippet_chars: 200,
158            })
159            .await?;
160
161        let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
162            self.vector_search(
163                token,
164                query_vector,
165                Some(query_text),
166                candidates,
167                Some(SubstrateKind::Entity),
168            )
169            .await?
170        } else {
171            Vec::new()
172        };
173
174        // Fuse without truncating: keep the full candidate pool through the
175        // alive/kind filter so right-kind hits below rank `limit` aren't lost.
176        let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize);
177
178        // Filter to alive entities (and optionally to a specific kind). A single
179        // query fetches all alive IDs that match the kind constraint from the
180        // fused set; any ID absent has been soft-deleted or doesn't match.
181        if !fused.is_empty() {
182            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
183            let alive_page = self
184                .entities(token)?
185                .query_entities(
186                    token.namespace().as_str(),
187                    EntityFilter {
188                        ids: candidate_ids,
189                        kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
190                        entity_types: entity_type.map(|t| vec![t.to_string()]).unwrap_or_default(),
191                        ..EntityFilter::default()
192                    },
193                    PageRequest {
194                        offset: 0,
195                        limit: fused.len() as u32,
196                    },
197                )
198                .await?;
199            // Keep entity metadata to enrich hits that had no FTS5 title/snippet.
200            let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
201            let mut alive: HashSet<Uuid> = HashSet::new();
202            for e in alive_page.items {
203                alive.insert(e.id);
204                entity_meta.insert(e.id, (e.name, e.description));
205            }
206
207            fused.retain(|h| alive.contains(&h.entity_id));
208
209            // Enrich vector-only hits (title/snippet == None) from entity record.
210            for hit in &mut fused {
211                if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
212                    if hit.title.is_none() {
213                        hit.title = Some(name.clone());
214                    }
215                    if hit.snippet.is_none() {
216                        hit.snippet = description.clone();
217                    }
218                }
219            }
220        }
221
222        fused.truncate(limit as usize);
223        Ok(fused)
224    }
225
226    /// Exact KNN over the full namespace's vector store.
227    ///
228    /// sqlite-vec uses brute-force cosine — results are exact, not approximate.
229    /// Cost is O(N · D) per query. For small-to-medium namespaces (~hundreds of
230    /// thousands of vectors) this is well within latency budgets.
231    pub async fn knn(
232        &self,
233        token: &NamespaceToken,
234        query_vector: Vec<f32>,
235        top_k: u32,
236    ) -> RuntimeResult<Vec<VectorSearchHit>> {
237        let ns = token.namespace().as_str().to_owned();
238        Ok(self
239            .vectors(token)?
240            .search(VectorSearchRequest {
241                query_vectors: vec![query_vector],
242                top_k,
243                namespace: Some(ns),
244                kind: Some(SubstrateKind::Entity),
245                filter: None,
246                backend_hints: None,
247            })
248            .await?)
249    }
250
251    /// Exact KNN restricted to a candidate set.
252    ///
253    /// Useful for reranking the top-N results from `hybrid_search` (or any other
254    /// retrieval path) with exact cosine similarity against a query vector.
255    /// Returns hits sorted by similarity (highest first), truncated to `top_k`.
256    pub async fn rerank(
257        &self,
258        token: &NamespaceToken,
259        query_vector: &[f32],
260        candidate_ids: &[Uuid],
261        top_k: u32,
262    ) -> RuntimeResult<Vec<VectorSearchHit>> {
263        let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
264        let ns = token.namespace().as_str().to_owned();
265        let all_hits = self
266            .vectors(token)?
267            .search(VectorSearchRequest {
268                query_vectors: vec![query_vector.to_vec()],
269                top_k: candidate_ids.len() as u32,
270                namespace: Some(ns),
271                kind: Some(SubstrateKind::Entity),
272                filter: None,
273                backend_hints: None,
274            })
275            .await?;
276        let mut hits: Vec<VectorSearchHit> = all_hits
277            .into_iter()
278            .filter(|h| candidate_set.contains(&h.subject_id))
279            .collect();
280        hits.sort_by(|a, b| b.score.cmp(&a.score));
281        hits.truncate(top_k as usize);
282        Ok(hits)
283    }
284}
285
286/// Fuse text + vector hits with Reciprocal Rank Fusion (k=60).
287///
288/// Hits in both lists get RRF scores summed. Sort by fused score, take top-`limit`.
289fn rrf_fuse(
290    text_hits: Vec<TextSearchHit>,
291    vector_hits: Vec<VectorSearchHit>,
292    limit: usize,
293) -> Vec<SearchHit> {
294    #[derive(Default)]
295    struct Bucket {
296        score: DeterministicScore,
297        source: Option<SearchSource>,
298        title: Option<String>,
299        snippet: Option<String>,
300    }
301
302    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
303
304    for (i, hit) in text_hits.into_iter().enumerate() {
305        let rank = i + 1; // RRF is 1-indexed
306        let entry = buckets.entry(hit.subject_id).or_default();
307        entry.score = entry.score + rrf_score(rank, RRF_K);
308        entry.source = Some(match entry.source {
309            Some(SearchSource::Vector) => SearchSource::Both,
310            _ => SearchSource::Text,
311        });
312        if entry.title.is_none() {
313            entry.title = hit.title;
314        }
315        if entry.snippet.is_none() {
316            entry.snippet = hit.snippet;
317        }
318    }
319
320    for (i, hit) in vector_hits.into_iter().enumerate() {
321        let rank = i + 1;
322        let entry = buckets.entry(hit.subject_id).or_default();
323        entry.score = entry.score + rrf_score(rank, RRF_K);
324        entry.source = Some(match entry.source {
325            Some(SearchSource::Text) => SearchSource::Both,
326            _ => SearchSource::Vector,
327        });
328    }
329
330    let mut hits: Vec<SearchHit> = buckets
331        .into_iter()
332        .map(|(id, b)| SearchHit {
333            entity_id: id,
334            score: b.score,
335            source: b.source.expect("each bucket gets a source"),
336            title: b.title,
337            snippet: b.snippet,
338        })
339        .collect();
340
341    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
342    hits.truncate(limit);
343    hits
344}
345
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, NamespaceToken, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use khive_types::namespace::Namespace;
    use lattice_embed::EmbeddingModel;

    /// Drive a future to completion on a fresh tokio runtime (for sync `#[test]`s).
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(future)
    }

    /// Build a runtime (no `db_path`) with `model` configured for embedding.
    fn runtime_with_model(model: EmbeddingModel) -> KhiveRuntime {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: Namespace::parse("test").unwrap(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        KhiveRuntime::new(config).unwrap()
    }

    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    // ---- rrf_fuse tests ----

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5);
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `a` is rank 1 in both lists, so its two RRF contributions are summed;
        // `b` appears only once (vector rank 2) and must score strictly lower.
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    // ---- embed_batch tests ----

    #[test]
    fn embed_batch_empty_input_short_circuits_without_model() {
        // KhiveRuntime::memory() has no embedding model, yet an empty slice must
        // succeed: it short-circuits before the model check is ever reached.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // No model needed — empty slice is handled before the embedder is touched.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        let result = block_on(rt.embed_batch(&texts));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let rt = runtime_with_model(EmbeddingModel::AllMiniLmL6V2);
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let rt = runtime_with_model(model);
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    // ---- vector_search input validation ----

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(&tok, None, None, 10, Some(SubstrateKind::Entity)));
        match result {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        let result = block_on(rt.vector_search(
            &tok,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    // ---- hybrid_search enrichment (issue #147 / #160) ----

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        let tok = NamespaceToken::local();
        rt.create_entity(
            &tok,
            "concept",
            None,
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(&tok, "FlashAttention", None, 10, None, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let hit = &hits[0];
        assert!(hit.title.is_some(), "title must be populated");
        assert!(
            hit.title.as_deref().unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}