Skip to main content

khive_runtime/
retrieval.rs

1//! Retrieval operations: local embedding generation and hybrid search with RRF fusion.
2//!
3//! See ADR-012 — Retrieval Architecture.
4
5use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::RuntimeResult;
10use crate::runtime::KhiveRuntime;
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14    VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19/// A unified search result combining vector and text signals.
20#[derive(Clone, Debug)]
21pub struct SearchHit {
22    pub entity_id: Uuid,
23    pub score: DeterministicScore,
24    pub source: SearchSource,
25    pub title: Option<String>,
26    pub snippet: Option<String>,
27}
28
/// Records which retrieval path(s) produced a search hit.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SearchSource {
    /// Returned only by vector (embedding) search.
    Vector,
    /// Returned only by full-text search.
    Text,
    /// Returned by both paths; the fused score sums both contributions.
    Both,
}
36
/// Rank offset `k` for Reciprocal Rank Fusion (60, as in Cormack et al., 2009).
/// Larger values flatten the advantage of top ranks over lower ones.
const RRF_K: usize = 60;

/// Candidates each retrieval path fetches per requested result before fusion.
/// Raising it improves recall at the price of more work per query.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44    /// Generate an embedding vector for `text` using the configured local model.
45    ///
46    /// First call lazily loads model weights (cold start cost). Subsequent calls reuse them.
47    /// Returns `Unconfigured("embedding_model")` if no model is configured.
48    pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49        let service = self.embedder().await?;
50        let model = self
51            .config()
52            .embedding_model
53            .expect("embedder() returns Unconfigured when model is None");
54        Ok(service.embed_one(text, model).await?)
55    }
56
57    /// Generate embeddings for multiple texts in one call.
58    ///
59    /// Delegates to the cached `EmbeddingService::embed`, so repeated texts within
60    /// and across calls benefit from the runtime-level LRU cache.
61    ///
62    /// Returns an empty vec for empty input without hitting the embedding service.
63    /// Returns `Unconfigured("embedding_model")` if no model is configured.
64    pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65        if texts.is_empty() {
66            return Ok(vec![]);
67        }
68        let service = self.embedder().await?;
69        let model = self
70            .config()
71            .embedding_model
72            .expect("embedder() returns Unconfigured when model is None");
73        Ok(service.embed(texts, model).await?)
74    }
75
76    /// Hybrid search: text (FTS5) + vector retrieval fused via Reciprocal Rank Fusion.
77    ///
78    /// - Always performs text search over `query_text`.
79    /// - If `query_vector` is `Some`, also performs vector search and fuses both lists.
80    /// - If `None`, returns text-only results — no vector store needed.
81    ///
82    /// `limit` caps the final returned list; internally pulls `limit * 4` candidates per path.
83    pub async fn hybrid_search(
84        &self,
85        namespace: Option<&str>,
86        query_text: &str,
87        query_vector: Option<Vec<f32>>,
88        limit: u32,
89    ) -> RuntimeResult<Vec<SearchHit>> {
90        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
91
92        let ns = self.ns(namespace).to_string();
93        let text_hits = self
94            .text(namespace)?
95            .search(TextSearchRequest {
96                query: query_text.to_string(),
97                mode: TextQueryMode::Plain,
98                filter: Some(TextFilter {
99                    namespaces: vec![ns.clone()],
100                    ..TextFilter::default()
101                }),
102                top_k: candidates,
103                snippet_chars: 200,
104            })
105            .await?;
106
107        let vector_hits = if let Some(vec) = query_vector {
108            self.vectors(namespace)?
109                .search(VectorSearchRequest {
110                    query_embedding: vec,
111                    top_k: candidates,
112                    namespace: Some(ns.clone()),
113                    kind: Some(SubstrateKind::Entity),
114                })
115                .await?
116        } else {
117            Vec::new()
118        };
119
120        let mut fused = rrf_fuse(text_hits, vector_hits, limit as usize);
121
122        // Filter out soft-deleted entities. A single query fetches all alive IDs from the
123        // fused set; any ID absent from the result has been soft-deleted (deleted_at IS NOT NULL).
124        if !fused.is_empty() {
125            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
126            let alive_page = self
127                .entities(namespace)?
128                .query_entities(
129                    self.ns(namespace),
130                    EntityFilter {
131                        ids: candidate_ids,
132                        ..EntityFilter::default()
133                    },
134                    PageRequest {
135                        offset: 0,
136                        limit: fused.len() as u32,
137                    },
138                )
139                .await?;
140            let alive: HashSet<Uuid> = alive_page.items.into_iter().map(|e| e.id).collect();
141            fused.retain(|h| alive.contains(&h.entity_id));
142        }
143
144        Ok(fused)
145    }
146
147    /// Exact KNN over the full namespace's vector store.
148    ///
149    /// sqlite-vec uses brute-force cosine — results are exact, not approximate.
150    /// Cost is O(N · D) per query. For small-to-medium namespaces (~hundreds of
151    /// thousands of vectors) this is well within latency budgets.
152    pub async fn knn(
153        &self,
154        namespace: Option<&str>,
155        query_vector: Vec<f32>,
156        top_k: u32,
157    ) -> RuntimeResult<Vec<VectorSearchHit>> {
158        let ns = self.ns(namespace).to_string();
159        Ok(self
160            .vectors(namespace)?
161            .search(VectorSearchRequest {
162                query_embedding: query_vector,
163                top_k,
164                namespace: Some(ns),
165                kind: Some(SubstrateKind::Entity),
166            })
167            .await?)
168    }
169
170    /// Exact KNN restricted to a candidate set.
171    ///
172    /// Useful for reranking the top-N results from `hybrid_search` (or any other
173    /// retrieval path) with exact cosine similarity against a query vector.
174    /// Returns hits sorted by similarity (highest first), truncated to `top_k`.
175    pub async fn rerank(
176        &self,
177        namespace: Option<&str>,
178        query_vector: &[f32],
179        candidate_ids: &[Uuid],
180        top_k: u32,
181    ) -> RuntimeResult<Vec<VectorSearchHit>> {
182        let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
183        let ns = self.ns(namespace).to_string();
184        let all_hits = self
185            .vectors(namespace)?
186            .search(VectorSearchRequest {
187                query_embedding: query_vector.to_vec(),
188                top_k: candidate_ids.len() as u32,
189                namespace: Some(ns),
190                kind: Some(SubstrateKind::Entity),
191            })
192            .await?;
193        let mut hits: Vec<VectorSearchHit> = all_hits
194            .into_iter()
195            .filter(|h| candidate_set.contains(&h.subject_id))
196            .collect();
197        hits.sort_by(|a, b| b.score.cmp(&a.score));
198        hits.truncate(top_k as usize);
199        Ok(hits)
200    }
201}
202
203/// Fuse text + vector hits with Reciprocal Rank Fusion (k=60).
204///
205/// Hits in both lists get RRF scores summed. Sort by fused score, take top-`limit`.
206fn rrf_fuse(
207    text_hits: Vec<TextSearchHit>,
208    vector_hits: Vec<VectorSearchHit>,
209    limit: usize,
210) -> Vec<SearchHit> {
211    #[derive(Default)]
212    struct Bucket {
213        score: DeterministicScore,
214        source: Option<SearchSource>,
215        title: Option<String>,
216        snippet: Option<String>,
217    }
218
219    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
220
221    for (i, hit) in text_hits.into_iter().enumerate() {
222        let rank = i + 1; // RRF is 1-indexed
223        let entry = buckets.entry(hit.subject_id).or_default();
224        entry.score = entry.score + rrf_score(rank, RRF_K);
225        entry.source = Some(match entry.source {
226            Some(SearchSource::Vector) => SearchSource::Both,
227            _ => SearchSource::Text,
228        });
229        if entry.title.is_none() {
230            entry.title = hit.title;
231        }
232        if entry.snippet.is_none() {
233            entry.snippet = hit.snippet;
234        }
235    }
236
237    for (i, hit) in vector_hits.into_iter().enumerate() {
238        let rank = i + 1;
239        let entry = buckets.entry(hit.subject_id).or_default();
240        entry.score = entry.score + rrf_score(rank, RRF_K);
241        entry.source = Some(match entry.source {
242            Some(SearchSource::Text) => SearchSource::Both,
243            _ => SearchSource::Vector,
244        });
245    }
246
247    let mut hits: Vec<SearchHit> = buckets
248        .into_iter()
249        .map(|(id, b)| SearchHit {
250            entity_id: id,
251            score: b.score,
252            source: b.source.expect("each bucket gets a source"),
253            title: b.title,
254            snippet: b.snippet,
255        })
256        .collect();
257
258    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
259    hits.truncate(limit);
260    hits
261}
262
263#[cfg(test)]
264mod tests {
265    use super::*;
266    use crate::runtime::{KhiveRuntime, RuntimeConfig};
267    use khive_storage::types::{TextSearchHit, VectorSearchHit};
268    use lattice_embed::EmbeddingModel;
269
270    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
271        TextSearchHit {
272            subject_id: id,
273            score: DeterministicScore::from_f64(1.0),
274            rank,
275            title: Some(title.to_string()),
276            snippet: Some("...".to_string()),
277        }
278    }
279
280    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
281        VectorSearchHit {
282            subject_id: id,
283            score: DeterministicScore::from_f64(0.9),
284            rank,
285        }
286    }
287
288    #[test]
289    fn rrf_fuse_text_only() {
290        let a = Uuid::new_v4();
291        let b = Uuid::new_v4();
292        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
293        let hits = rrf_fuse(text, vec![], 10);
294        assert_eq!(hits.len(), 2);
295        assert_eq!(hits[0].entity_id, a);
296        assert_eq!(hits[0].source, SearchSource::Text);
297        assert_eq!(hits[0].title.as_deref(), Some("A"));
298    }
299
300    #[test]
301    fn rrf_fuse_vector_only() {
302        let a = Uuid::new_v4();
303        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10);
304        assert_eq!(hits.len(), 1);
305        assert_eq!(hits[0].source, SearchSource::Vector);
306        assert!(hits[0].title.is_none());
307    }
308
309    #[test]
310    fn rrf_fuse_marks_both_when_in_both_lists() {
311        let id = Uuid::new_v4();
312        let text = vec![text_hit(id, 1, "A")];
313        let vec = vec![vector_hit(id, 1)];
314        let hits = rrf_fuse(text, vec, 10);
315        assert_eq!(hits.len(), 1);
316        assert_eq!(hits[0].source, SearchSource::Both);
317    }
318
319    #[test]
320    fn rrf_fuse_respects_limit() {
321        let hits: Vec<TextSearchHit> = (0..20)
322            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
323            .collect();
324        let fused = rrf_fuse(hits, vec![], 5);
325        assert_eq!(fused.len(), 5);
326    }
327
328    #[test]
329    fn rrf_fuse_orders_higher_score_first() {
330        // Same UUID in both lists at rank 1 → score 2/(60+1). Different UUIDs → 1/(60+1) each.
331        let a = Uuid::new_v4();
332        let b = Uuid::new_v4();
333        let text = vec![text_hit(a, 1, "A")];
334        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
335        let hits = rrf_fuse(text, vec, 10);
336        assert_eq!(hits[0].entity_id, a);
337        assert_eq!(hits[0].source, SearchSource::Both);
338        assert!(hits[0].score > hits[1].score);
339    }
340
341    // ---- embed_batch tests ----
342
343    #[test]
344    fn embed_batch_unconfigured_on_memory_runtime() {
345        // KhiveRuntime::memory() has no embedding model — embed_batch returns Unconfigured.
346        let rt = KhiveRuntime::memory().unwrap();
347        let result = tokio::runtime::Runtime::new()
348            .unwrap()
349            .block_on(rt.embed_batch(&[]));
350        // Empty slice short-circuits before hitting the model check.
351        assert!(result.is_ok());
352        assert!(result.unwrap().is_empty());
353    }
354
355    #[test]
356    fn embed_batch_empty_input_returns_empty_vec() {
357        // No model needed — empty slice is handled before the embedder is touched.
358        let rt = KhiveRuntime::memory().unwrap();
359        let result = tokio::runtime::Runtime::new()
360            .unwrap()
361            .block_on(rt.embed_batch(&[]));
362        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
363    }
364
365    #[test]
366    fn embed_batch_no_model_non_empty_returns_unconfigured() {
367        let rt = KhiveRuntime::memory().unwrap();
368        let texts = vec!["hello".to_string()];
369        let result = tokio::runtime::Runtime::new()
370            .unwrap()
371            .block_on(rt.embed_batch(&texts));
372        match result {
373            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
374            Err(other) => panic!("expected Unconfigured, got {:?}", other),
375            Ok(_) => panic!("expected Err, got Ok"),
376        }
377    }
378
379    #[test]
380    #[ignore = "loads ~80 MB model; run with --include-ignored"]
381    fn embed_batch_count_matches_input() {
382        let config = RuntimeConfig {
383            db_path: None,
384            default_namespace: "test".to_string(),
385            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
386        };
387        let rt = KhiveRuntime::new(config).unwrap();
388        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
389        let result = tokio::runtime::Runtime::new()
390            .unwrap()
391            .block_on(rt.embed_batch(&texts));
392        let embeddings = result.unwrap();
393        assert_eq!(embeddings.len(), texts.len());
394    }
395
396    #[test]
397    #[ignore = "loads ~80 MB model; run with --include-ignored"]
398    fn embed_batch_vectors_have_expected_dimensions() {
399        let model = EmbeddingModel::AllMiniLmL6V2;
400        let config = RuntimeConfig {
401            db_path: None,
402            default_namespace: "test".to_string(),
403            embedding_model: Some(model),
404        };
405        let rt = KhiveRuntime::new(config).unwrap();
406        let texts = vec!["hello world".to_string()];
407        let result = tokio::runtime::Runtime::new()
408            .unwrap()
409            .block_on(rt.embed_batch(&texts));
410        let embeddings = result.unwrap();
411        assert_eq!(embeddings[0].len(), model.dimensions());
412    }
413}