Skip to main content

khive_runtime/
retrieval.rs

1//! Retrieval operations: local embedding generation and hybrid search with RRF fusion.
2//!
3//! See ADR-012 — Retrieval Architecture.
4
5use std::collections::{HashMap, HashSet};
6
7use uuid::Uuid;
8
9use crate::error::{RuntimeError, RuntimeResult};
10use crate::runtime::KhiveRuntime;
11use khive_score::{rrf_score, DeterministicScore};
12use khive_storage::types::{
13    PageRequest, TextFilter, TextQueryMode, TextSearchHit, TextSearchRequest, VectorSearchHit,
14    VectorSearchRequest,
15};
16use khive_storage::EntityFilter;
17use khive_types::SubstrateKind;
18
19/// A unified search result combining vector and text signals.
20#[derive(Clone, Debug)]
21pub struct SearchHit {
22    pub entity_id: Uuid,
23    pub score: DeterministicScore,
24    pub source: SearchSource,
25    pub title: Option<String>,
26    pub snippet: Option<String>,
27}
28
/// Records which retrieval path(s) produced a [`SearchHit`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchSource {
    /// Surfaced only by the vector (embedding similarity) search.
    Vector,
    /// Surfaced only by the full-text search.
    Text,
    /// Surfaced by both paths; the fused score sums both RRF contributions.
    Both,
}
36
/// Reciprocal Rank Fusion damping constant `k` (the value from the original RRF
/// paper). Each list contributes `1 / (k + rank)` per entity, so a larger `k`
/// flattens the advantage of the very top ranks.
const RRF_K: usize = 60;

/// Per-path over-fetch factor: `hybrid_search` requests
/// `limit * CANDIDATE_MULTIPLIER` hits from each retrieval path before fusion
/// and filtering. Raising it improves recall at the cost of more work.
const CANDIDATE_MULTIPLIER: u32 = 4;
42
43impl KhiveRuntime {
44    /// Generate an embedding vector for `text` using the configured local model.
45    ///
46    /// First call lazily loads model weights (cold start cost). Subsequent calls reuse them.
47    /// Returns `Unconfigured("embedding_model")` if no model is configured.
48    pub async fn embed(&self, text: &str) -> RuntimeResult<Vec<f32>> {
49        let service = self.embedder().await?;
50        let model = self
51            .config()
52            .embedding_model
53            .expect("embedder() returns Unconfigured when model is None");
54        Ok(service.embed_one(text, model).await?)
55    }
56
57    /// Generate embeddings for multiple texts in one call.
58    ///
59    /// Delegates to the cached `EmbeddingService::embed`, so repeated texts within
60    /// and across calls benefit from the runtime-level LRU cache.
61    ///
62    /// Returns an empty vec for empty input without hitting the embedding service.
63    /// Returns `Unconfigured("embedding_model")` if no model is configured.
64    pub async fn embed_batch(&self, texts: &[String]) -> RuntimeResult<Vec<Vec<f32>>> {
65        if texts.is_empty() {
66            return Ok(vec![]);
67        }
68        let service = self.embedder().await?;
69        let model = self
70            .config()
71            .embedding_model
72            .expect("embedder() returns Unconfigured when model is None");
73        Ok(service.embed(texts, model).await?)
74    }
75
76    /// Search vectors using either a caller-provided embedding or query text.
77    ///
78    /// Existing callers pass `query_embedding: Some(vec)` to avoid re-embedding.
79    /// Text callers pass `query_embedding: None, query_text: Some(...)` and the
80    /// runtime embeds internally.
81    pub async fn vector_search(
82        &self,
83        namespace: Option<&str>,
84        query_embedding: Option<Vec<f32>>,
85        query_text: Option<&str>,
86        top_k: u32,
87        kind: Option<SubstrateKind>,
88    ) -> RuntimeResult<Vec<VectorSearchHit>> {
89        let embedding = match query_embedding {
90            Some(vec) => vec,
91            None => {
92                let text = query_text.ok_or_else(|| {
93                    RuntimeError::InvalidInput(
94                        "vector search requires query_embedding or query_text".into(),
95                    )
96                })?;
97                if text.trim().is_empty() {
98                    return Err(RuntimeError::InvalidInput(
99                        "query_text must not be empty".into(),
100                    ));
101                }
102                self.embed(text).await?
103            }
104        };
105
106        let ns = self.ns(namespace).to_string();
107        Ok(self
108            .vectors(namespace)?
109            .search(VectorSearchRequest {
110                query_embedding: embedding,
111                top_k,
112                namespace: Some(ns),
113                kind,
114            })
115            .await?)
116    }
117
118    /// Hybrid search: text (FTS5) + vector retrieval fused via Reciprocal Rank Fusion.
119    ///
120    /// - Always performs text search over `query_text`.
121    /// - If `query_vector` is `Some`, also performs vector search and fuses both lists.
122    /// - If `None`, returns text-only results — no vector store needed.
123    /// - If `entity_kind` is `Some`, the alive-set query filters to that kind.
124    ///   The text/vector candidate pools are unfiltered up front; the kind
125    ///   filter applies at the alive-check stage where we already fetch each
126    ///   candidate to confirm it isn't soft-deleted.
127    ///
128    /// `limit` caps the final returned list; internally pulls `limit * 4` candidates per path.
129    /// The fused candidate set is kept untruncated until after the alive + kind filter so
130    /// that right-kind hits ranked below `limit` in the raw fusion still surface when
131    /// higher-ranked candidates are wrong-kind or soft-deleted.
132    pub async fn hybrid_search(
133        &self,
134        namespace: Option<&str>,
135        query_text: &str,
136        query_vector: Option<Vec<f32>>,
137        limit: u32,
138        entity_kind: Option<&str>,
139    ) -> RuntimeResult<Vec<SearchHit>> {
140        let candidates = limit.saturating_mul(CANDIDATE_MULTIPLIER).max(limit);
141
142        let ns = self.ns(namespace).to_string();
143        let text_hits = self
144            .text(namespace)?
145            .search(TextSearchRequest {
146                query: query_text.to_string(),
147                mode: TextQueryMode::Plain,
148                filter: Some(TextFilter {
149                    namespaces: vec![ns.clone()],
150                    ..TextFilter::default()
151                }),
152                top_k: candidates,
153                snippet_chars: 200,
154            })
155            .await?;
156
157        let vector_hits = if query_vector.is_some() || self.config().embedding_model.is_some() {
158            self.vector_search(
159                namespace,
160                query_vector,
161                Some(query_text),
162                candidates,
163                Some(SubstrateKind::Entity),
164            )
165            .await?
166        } else {
167            Vec::new()
168        };
169
170        // Fuse without truncating: keep the full candidate pool through the
171        // alive/kind filter so right-kind hits below rank `limit` aren't lost.
172        let mut fused = rrf_fuse(text_hits, vector_hits, candidates as usize);
173
174        // Filter to alive entities (and optionally to a specific kind). A single
175        // query fetches all alive IDs that match the kind constraint from the
176        // fused set; any ID absent has been soft-deleted or doesn't match.
177        if !fused.is_empty() {
178            let candidate_ids: Vec<Uuid> = fused.iter().map(|h| h.entity_id).collect();
179            let alive_page = self
180                .entities(namespace)?
181                .query_entities(
182                    self.ns(namespace),
183                    EntityFilter {
184                        ids: candidate_ids,
185                        kinds: entity_kind.map(|k| vec![k.to_string()]).unwrap_or_default(),
186                        ..EntityFilter::default()
187                    },
188                    PageRequest {
189                        offset: 0,
190                        limit: fused.len() as u32,
191                    },
192                )
193                .await?;
194            // Keep entity metadata to enrich hits that had no FTS5 title/snippet.
195            let mut entity_meta: HashMap<Uuid, (String, Option<String>)> = HashMap::new();
196            let mut alive: HashSet<Uuid> = HashSet::new();
197            for e in alive_page.items {
198                alive.insert(e.id);
199                entity_meta.insert(e.id, (e.name, e.description));
200            }
201
202            fused.retain(|h| alive.contains(&h.entity_id));
203
204            // Enrich vector-only hits (title/snippet == None) from entity record.
205            for hit in &mut fused {
206                if let Some((name, description)) = entity_meta.get(&hit.entity_id) {
207                    if hit.title.is_none() {
208                        hit.title = Some(name.clone());
209                    }
210                    if hit.snippet.is_none() {
211                        hit.snippet = description.clone();
212                    }
213                }
214            }
215        }
216
217        fused.truncate(limit as usize);
218        Ok(fused)
219    }
220
221    /// Exact KNN over the full namespace's vector store.
222    ///
223    /// sqlite-vec uses brute-force cosine — results are exact, not approximate.
224    /// Cost is O(N · D) per query. For small-to-medium namespaces (~hundreds of
225    /// thousands of vectors) this is well within latency budgets.
226    pub async fn knn(
227        &self,
228        namespace: Option<&str>,
229        query_vector: Vec<f32>,
230        top_k: u32,
231    ) -> RuntimeResult<Vec<VectorSearchHit>> {
232        let ns = self.ns(namespace).to_string();
233        Ok(self
234            .vectors(namespace)?
235            .search(VectorSearchRequest {
236                query_embedding: query_vector,
237                top_k,
238                namespace: Some(ns),
239                kind: Some(SubstrateKind::Entity),
240            })
241            .await?)
242    }
243
244    /// Exact KNN restricted to a candidate set.
245    ///
246    /// Useful for reranking the top-N results from `hybrid_search` (or any other
247    /// retrieval path) with exact cosine similarity against a query vector.
248    /// Returns hits sorted by similarity (highest first), truncated to `top_k`.
249    pub async fn rerank(
250        &self,
251        namespace: Option<&str>,
252        query_vector: &[f32],
253        candidate_ids: &[Uuid],
254        top_k: u32,
255    ) -> RuntimeResult<Vec<VectorSearchHit>> {
256        let candidate_set: HashSet<Uuid> = candidate_ids.iter().copied().collect();
257        let ns = self.ns(namespace).to_string();
258        let all_hits = self
259            .vectors(namespace)?
260            .search(VectorSearchRequest {
261                query_embedding: query_vector.to_vec(),
262                top_k: candidate_ids.len() as u32,
263                namespace: Some(ns),
264                kind: Some(SubstrateKind::Entity),
265            })
266            .await?;
267        let mut hits: Vec<VectorSearchHit> = all_hits
268            .into_iter()
269            .filter(|h| candidate_set.contains(&h.subject_id))
270            .collect();
271        hits.sort_by(|a, b| b.score.cmp(&a.score));
272        hits.truncate(top_k as usize);
273        Ok(hits)
274    }
275}
276
277/// Fuse text + vector hits with Reciprocal Rank Fusion (k=60).
278///
279/// Hits in both lists get RRF scores summed. Sort by fused score, take top-`limit`.
280fn rrf_fuse(
281    text_hits: Vec<TextSearchHit>,
282    vector_hits: Vec<VectorSearchHit>,
283    limit: usize,
284) -> Vec<SearchHit> {
285    #[derive(Default)]
286    struct Bucket {
287        score: DeterministicScore,
288        source: Option<SearchSource>,
289        title: Option<String>,
290        snippet: Option<String>,
291    }
292
293    let mut buckets: HashMap<Uuid, Bucket> = HashMap::new();
294
295    for (i, hit) in text_hits.into_iter().enumerate() {
296        let rank = i + 1; // RRF is 1-indexed
297        let entry = buckets.entry(hit.subject_id).or_default();
298        entry.score = entry.score + rrf_score(rank, RRF_K);
299        entry.source = Some(match entry.source {
300            Some(SearchSource::Vector) => SearchSource::Both,
301            _ => SearchSource::Text,
302        });
303        if entry.title.is_none() {
304            entry.title = hit.title;
305        }
306        if entry.snippet.is_none() {
307            entry.snippet = hit.snippet;
308        }
309    }
310
311    for (i, hit) in vector_hits.into_iter().enumerate() {
312        let rank = i + 1;
313        let entry = buckets.entry(hit.subject_id).or_default();
314        entry.score = entry.score + rrf_score(rank, RRF_K);
315        entry.source = Some(match entry.source {
316            Some(SearchSource::Text) => SearchSource::Both,
317            _ => SearchSource::Vector,
318        });
319    }
320
321    let mut hits: Vec<SearchHit> = buckets
322        .into_iter()
323        .map(|(id, b)| SearchHit {
324            entity_id: id,
325            score: b.score,
326            source: b.source.expect("each bucket gets a source"),
327            title: b.title,
328            snippet: b.snippet,
329        })
330        .collect();
331
332    hits.sort_by(|a, b| b.score.cmp(&a.score).then(a.entity_id.cmp(&b.entity_id)));
333    hits.truncate(limit);
334    hits
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::{KhiveRuntime, RuntimeConfig};
    use khive_storage::types::{TextSearchHit, VectorSearchHit};
    use lattice_embed::EmbeddingModel;

    /// Drive a future to completion on a fresh multi-threaded Tokio runtime.
    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new().unwrap().block_on(fut)
    }

    fn text_hit(id: Uuid, rank: u32, title: &str) -> TextSearchHit {
        TextSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(1.0),
            rank,
            title: Some(title.to_string()),
            snippet: Some("...".to_string()),
        }
    }

    fn vector_hit(id: Uuid, rank: u32) -> VectorSearchHit {
        VectorSearchHit {
            subject_id: id,
            score: DeterministicScore::from_f64(0.9),
            rank,
        }
    }

    #[test]
    fn rrf_fuse_text_only() {
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A"), text_hit(b, 2, "B")];
        let hits = rrf_fuse(text, vec![], 10);
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Text);
        assert_eq!(hits[0].title.as_deref(), Some("A"));
    }

    #[test]
    fn rrf_fuse_vector_only() {
        let a = Uuid::new_v4();
        let hits = rrf_fuse(vec![], vec![vector_hit(a, 1)], 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Vector);
        assert!(hits[0].title.is_none());
    }

    #[test]
    fn rrf_fuse_marks_both_when_in_both_lists() {
        let id = Uuid::new_v4();
        let text = vec![text_hit(id, 1, "A")];
        let vec = vec![vector_hit(id, 1)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].source, SearchSource::Both);
    }

    #[test]
    fn rrf_fuse_respects_limit() {
        let hits: Vec<TextSearchHit> = (0..20)
            .map(|i| text_hit(Uuid::new_v4(), i + 1, "x"))
            .collect();
        let fused = rrf_fuse(hits, vec![], 5);
        assert_eq!(fused.len(), 5);
    }

    #[test]
    fn rrf_fuse_orders_higher_score_first() {
        // `a` is rank 1 in both lists → 1/(60+1) + 1/(60+1).
        // `b` appears only in the vector list at rank 2 → 1/(60+2).
        let a = Uuid::new_v4();
        let b = Uuid::new_v4();
        let text = vec![text_hit(a, 1, "A")];
        let vec = vec![vector_hit(a, 1), vector_hit(b, 2)];
        let hits = rrf_fuse(text, vec, 10);
        assert_eq!(hits[0].entity_id, a);
        assert_eq!(hits[0].source, SearchSource::Both);
        assert!(hits[0].score > hits[1].score);
    }

    // ---- embed_batch tests ----

    #[test]
    fn embed_batch_empty_input_short_circuits_without_model() {
        // KhiveRuntime::memory() has no embedding model, yet an empty slice
        // returns Ok before the model check is ever reached.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn embed_batch_empty_input_returns_empty_vec() {
        // No model needed — empty slice is handled before the embedder is touched.
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.embed_batch(&[]));
        assert_eq!(result.unwrap(), Vec::<Vec<f32>>::new());
    }

    #[test]
    fn embed_batch_no_model_non_empty_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let texts = vec!["hello".to_string()];
        match block_on(rt.embed_batch(&texts)) {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            Err(other) => panic!("expected Unconfigured, got {:?}", other),
            Ok(_) => panic!("expected Err, got Ok"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_count_matches_input() {
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(EmbeddingModel::AllMiniLmL6V2),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts: Vec<String> = vec!["foo".to_string(), "bar".to_string(), "baz".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings.len(), texts.len());
    }

    #[test]
    fn vector_search_requires_embedding_or_text() {
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.vector_search(None, None, None, 10, Some(SubstrateKind::Entity)));
        match result {
            Err(crate::RuntimeError::InvalidInput(msg)) => {
                assert!(msg.contains("query_embedding or query_text"), "msg: {msg}");
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[test]
    fn vector_search_text_without_model_returns_unconfigured() {
        let rt = KhiveRuntime::memory().unwrap();
        let result = block_on(rt.vector_search(
            None,
            None,
            Some("attention"),
            10,
            Some(SubstrateKind::Entity),
        ));
        match result {
            Err(crate::RuntimeError::Unconfigured(s)) => assert_eq!(s, "embedding_model"),
            other => panic!("expected Unconfigured, got {other:?}"),
        }
    }

    #[test]
    #[ignore = "loads ~80 MB model; run with --include-ignored"]
    fn embed_batch_vectors_have_expected_dimensions() {
        let model = EmbeddingModel::AllMiniLmL6V2;
        let config = RuntimeConfig {
            db_path: None,
            default_namespace: "test".to_string(),
            embedding_model: Some(model),
            packs: vec!["kg".to_string()],
            ..RuntimeConfig::default()
        };
        let rt = KhiveRuntime::new(config).unwrap();
        let texts = vec!["hello world".to_string()];
        let embeddings = block_on(rt.embed_batch(&texts)).unwrap();
        assert_eq!(embeddings[0].len(), model.dimensions());
    }

    // ---- hybrid_search enrichment (issue #147 / #160) ----

    #[tokio::test]
    async fn hybrid_search_entity_hit_has_title() {
        let rt = KhiveRuntime::memory().unwrap();
        rt.create_entity(
            None,
            "concept",
            "FlashAttention",
            Some("IO-aware exact attention using tiling"),
            None,
            vec![],
        )
        .await
        .unwrap();

        let hits = rt
            .hybrid_search(None, "FlashAttention", None, 10, None)
            .await
            .unwrap();

        assert!(!hits.is_empty(), "should find the entity");
        let hit = &hits[0];
        assert!(hit.title.is_some(), "title must be populated");
        assert!(
            hit.title.as_deref().unwrap().contains("FlashAttention"),
            "title must contain entity name"
        );
    }
}
549            hit.title.as_deref().unwrap().contains("FlashAttention"),
550            "title must contain entity name"
551        );
552    }
553}