Skip to main content

khive_runtime/
retrieval.rs

1//! Retrieval operations: local embedding generation and hybrid search with RRF fusion.
2//!
3//! See ADR-012 — Retrieval Architecture.
4
5use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::RuntimeResult;
10use crate::runtime::KhiveRuntime;
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14    VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19/// A unified search result combining vector and text signals.
20#[derive(Clone, Debug)]
21pub struct SearchHit {
22    pub entity_id: Uuid,
23    pub score: DeterministicScore,
24    pub source: SearchSource,
25    pub title: Option<String>,
26    pub snippet: Option<String>,
27}
28
/// Provenance of a fused search hit: which retrieval path(s) produced it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchSource {
    /// Only the vector (embedding KNN) path returned this entity.
    Vector,
    /// Only the full-text path returned this entity.
    Text,
    /// Both paths returned this entity, so their RRF contributions were summed.
    Both,
}
36
/// RRF smoothing constant `k`, using the value from the original Reciprocal Rank
/// Fusion paper. A hit at 1-based rank `r` contributes `1 / (k + r)`. A larger
/// `k` flattens the gap between top-ranked and lower-ranked hits.
const RRF_K: usize = 60;

/// Over-fetch factor for hybrid search. Each retrieval path is asked for
/// `limit * CANDIDATE_MULTIPLIER` candidates before fusion and alive/kind
/// filtering. Higher values improve recall at the cost of more work per query.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44    /// Generate an embedding vector for `text` using the configured local model.
45    ///
46    /// First call lazily loads model weights (cold start cost). Subsequent calls reuse them.
47    /// Returns `Unconfigured("embedding_model")` if no model is configured.
48    pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49        let service = self.embedder().await?;
50        let model = self
51            .config()
52            .embedding_model
53            .expect("embedder() returns Unconfigured when model is None");
54        Ok(service.embed_one(text, model).await?)
55    }
56
57    /// Generate embeddings for multiple texts in one call.
58    ///
59    /// Delegates to the cached `EmbeddingService::embed`, so repeated texts within
60    /// and across calls benefit from the runtime-level LRU cache.
61    ///
62    /// Returns an empty vec for empty input without hitting the embedding service.
63    /// Returns `Unconfigured("embedding_model")` if no model is configured.
64    pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65        if texts.is_empty() {
66            return Ok(vec![]);
67        }
68        let service = self.embedder().await?;
69        let model = self
70            .config()
71            .embedding_model
72            .expect("embedder() returns Unconfigured when model is None");
73        Ok(service.embed(texts, model).await?)
74    }
75
76    /// Hybrid search: text (FTS5) + vector retrieval fused via Reciprocal Rank Fusion.
77    ///
78    /// - Always performs text search over `query_text`.
79    /// - If `query_vector` is `Some`, also performs vector search and fuses both lists.
80    /// - If `None`, returns text-only results — no vector store needed.
81    /// - If `entity_kind` is `Some`, the alive-set query filters to that kind.
82    ///   The text/vector candidate pools are unfiltered up front; the kind
83    ///   filter applies at the alive-check stage where we already fetch each
84    ///   candidate to confirm it isn't soft-deleted.
85    ///
86    /// `limit` caps the final returned list; internally pulls `limit * 4` candidates per path.
87    /// The fused candidate set is kept untruncated until after the alive + kind filter so
88    /// that right-kind hits ranked below `limit` in the raw fusion still surface when
89    /// higher-ranked candidates are wrong-kind or soft-deleted.
90    pub async fn hybrid_search(
91        &self,
92        namespace: Option<&str>,
93        query_text: &str,
94        query_vector: Option<Vec<f32>>,
95        limit: u32,
96        entity_kind: Option<&str>,
97    ) -> RuntimeResult<Vec<SearchHit>> {
98        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
99
100        let ns = self.ns(namespace).to_string();
101        let text_hits = self
102            .text(namespace)?
103            .search(TextSearchRequest {
104                query: query_text.to_string(),
105                mode: TextQueryMode::Plain,
106                filter: Some(TextFilter {
107                    namespaces: vec![ns.clone()],
108                    ..TextFilter::default()
109                }),
110                top_k: candidates,
111                snippet_chars: 200,
112            })
113            .await?;
114
115        let vector_hits = if let Some(vec) = query_vector {
116            self.vectors(namespace)?
117                .search(VectorSearchRequest {
118                    query_embedding: vec,
119                    top_k: candidates,
120                    namespace: Some(ns.clone()),
121                    kind: Some(SubstrateKind::Entity),
122                })
123                .await?
124        } else {
125            Vec::new()
126        };
127
128        // Fuse without truncating: keep the full candidate pool through the
129        // alive/kind filter so right-kind hits below rank `limit` aren't lost.
130        let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize);
131
132        // Filter to alive entities (and optionally to a specific kind). A single
133        // query fetches all alive IDs that match the kind constraint from the
134        // fused set; any ID absent has been soft-deleted or doesn't match.
135        if !fused.is_empty() {
136            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
137            let alive_page = self
138                .entities(namespace)?
139                .query_entities(
140                    self.ns(namespace),
141                    EntityFilter {
142                        ids: candidate_ids,
143                        kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
144                        ..EntityFilter::default()
145                    },
146                    PageRequest {
147                        offset: 0,
148                        limit: fused.len() as u32,
149                    },
150                )
151                .await?;
152            let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
153            fused.retain(|h| alive.contains(&h.entity_id));
154        }
155
156        fused.truncate(limit as usize);
157        Ok(fused)
158    }
159
160    /// Exact KNN over the full namespace's vector store.
161    ///
162    /// sqlite-vec uses brute-force cosine — results are exact, not approximate.
163    /// Cost is O(N · D) per query. For small-to-medium namespaces (~hundreds of
164    /// thousands of vectors) this is well within latency budgets.
165    pub async fn knn(
166        &self,
167        namespace: Option<&str>,
168        query_vector: Vec<f32>,
169        top_k: u32,
170    ) -> RuntimeResult<Vec<VectorSearchHit>> {
171        let ns = self.ns(namespace).to_string();
172        Ok(self
173            .vectors(namespace)?
174            .search(VectorSearchRequest {
175                query_embedding: query_vector,
176                top_k,
177                namespace: Some(ns),
178                kind: Some(SubstrateKind::Entity),
179            })
180            .await?)
181    }
182
183    /// Exact KNN restricted to a candidate set.
184    ///
185    /// Useful for reranking the top-N results from `hybrid_search` (or any other
186    /// retrieval path) with exact cosine similarity against a query vector.
187    /// Returns hits sorted by similarity (highest first), truncated to `top_k`.
188    pub async fn rerank(
189        &self,
190        namespace: Option<&str>,
191        query_vector: &[f32],
192        candidate_ids: &[Uuid],
193        top_k: u32,
194    ) -> RuntimeResult<Vec<VectorSearchHit>> {
195        let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
196        let ns = self.ns(namespace).to_string();
197        let all_hits = self
198            .vectors(namespace)?
199            .search(VectorSearchRequest {
200                query_embedding: query_vector.to_vec(),
201                top_k: candidate_ids.len() as u32,
202                namespace: Some(ns),
203                kind: Some(SubstrateKind::Entity),
204            })
205            .await?;
206        let mut hits: Vec<VectorSearchHit> = all_hits
207            .into_iter()
208            .filter(|h| candidate_set.contains(&h.subject_id))
209            .collect();
210        hits.sort_by(|a, b| b.score.cmp(&a.score));
211        hits.truncate(top_k as usize);
212        Ok(hits)
213    }
214}
215
216/// Fuse text + vector hits with Reciprocal Rank Fusion (k=60).
217///
218/// Hits in both lists get RRF scores summed. Sort by fused score, take top-`limit`.
219fn rrf_fuse(
220    text_hits: Vec<TextSearchHit>,
221    vector_hits: Vec<VectorSearchHit>,
222    limit: usize,
223) -> Vec<SearchHit> {
224    #[derive(Default)]
225    struct Bucket {
226        score: DeterministicScore,
227        source: Option<SearchSource>,
228        title: Option<String>,
229        snippet: Option<String>,
230    }
231
232    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
233
234    for (i, hit) in text_hits.into_iter().enumerate() {
235        let rank = i + 1; // RRF is 1-indexed
236        let entry = buckets.entry(hit.subject_id).or_default();
237        entry.score = entry.score + rrf_score(rank, RRF_K);
238        entry.source = Some(match entry.source {
239            Some(SearchSource::Vector) => SearchSource::Both,
240            _ => SearchSource::Text,
241        });
242        if entry.title.is_none() {
243            entry.title = hit.title;
244        }
245        if entry.snippet.is_none() {
246            entry.snippet = hit.snippet;
247        }
248    }
249
250    for (i, hit) in vector_hits.into_iter().enumerate() {
251        let rank = i + 1;
252        let entry = buckets.entry(hit.subject_id).or_default();
253        entry.score = entry.score + rrf_score(rank, RRF_K);
254        entry.source = Some(match entry.source {
255            Some(SearchSource::Text) => SearchSource::Both,
256            _ => SearchSource::Vector,
257        });
258    }
259
260    let mut hits: Vec<SearchHit> = buckets
261        .into_iter()
262        .map(|(id, b)| SearchHit {
263            entity_id: id,
264            score: b.score,
265            source: b.source.expect("each bucket gets a source"),
266            title: b.title,
267            snippet: b.snippet,
268        })
269        .collect();
270
271    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
272    hits.truncate(limit);
273    hits
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use lattice_embed::EmbeddingModel;

    /// Build a text hit with a fixed score and snippet; only id, rank and title vary.
    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    /// Build a vector hit with a fixed score; only id and rank vary.
    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5);
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `a` is rank 1 in both lists → 1/(60+1) + 1/(60+1).
        // `b` is rank 2 in the vector list only → 1/(60+2).
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn rrf_fuse_empty_inputs_yield_no_hits() {
        assert!(rrf_fuse(vec![], vec![], 10).is_empty());
    }

    #[test]
    fn rrf_fuse_breaks_score_ties_by_entity_id() {
        // Rank 1 in either list yields the same RRF contribution, so the two hits
        // tie on score and must come out in ascending entity-id order.
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let hits = rrf_fuse(vec![text_hit(a, 1, "A")], vec![vector_hit(b, 1)], 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].score, hits[1].score);
        assert!(hits[0].entity_id < hits[1].entity_id);
    }

    // ---- embed / embed_batch tests ----

    #[test]
    fn embed_unconfigured_on_memory_runtime() {
        // KhiveRuntime::memory() has no embedding model, so single-text embedding
        // fails with Unconfigured.
        let rt = KhiveRuntime::memory().unwrap();
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed("hello"));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // No model needed — empty slice is handled before the embedder is touched.
        let rt = KhiveRuntime::memory().unwrap();
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&texts));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&texts));
        let embeddings = result.unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts = vec!["hello world".to_string()];
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(rt.embed_batch(&texts));
        let embeddings = result.unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }
}