Skip to main content

engram/
retrieve.rs

1//! Hybrid retrieval — merges vector, keyword, and graph search results.
2
3use crate::embedding::EmbeddingProvider;
4use crate::fact::{Fact, FactId};
5use crate::graph::GraphStore;
6use crate::scope::Scope;
7use crate::store::{FactStore, MemoryError};
8use crate::vector::{VectorFilter, VectorStore};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12/// A fact with a combined retrieval score.
13#[derive(Debug, Clone)]
14pub struct ScoredFact {
15    pub fact: Fact,
16    pub score: f32,
17    pub vector_score: f32,
18    pub keyword_score: f32,
19    pub graph_score: f32,
20}
21
/// Weights that control how the three retrieval sources are blended into a
/// single ranking score.
///
/// Each weight multiplies the matching component score before the three
/// products are summed. The weights are used as given; nothing in this type
/// normalizes them or checks their range.
#[derive(Clone, Debug)]
pub struct RetrievalConfig {
    /// Multiplier for the vector/semantic similarity score (intended range 0.0-1.0).
    pub vector_weight: f32,
    /// Multiplier for the keyword/BM25 score (intended range 0.0-1.0).
    pub keyword_weight: f32,
    /// Multiplier for the graph proximity score (intended range 0.0-1.0).
    pub graph_weight: f32,
}
32
33impl Default for RetrievalConfig {
34    fn default() -> Self {
35        Self {
36            vector_weight: 0.5,
37            keyword_weight: 0.3,
38            graph_weight: 0.2,
39        }
40    }
41}
42
43pub struct HybridRetriever {
44    fact_store: Arc<dyn FactStore>,
45    vector_store: Arc<dyn VectorStore>,
46    graph_store: Arc<dyn GraphStore>,
47    embedding: Arc<dyn EmbeddingProvider>,
48    config: RetrievalConfig,
49}
50
51impl HybridRetriever {
52    pub fn new(
53        fact_store: Arc<dyn FactStore>,
54        vector_store: Arc<dyn VectorStore>,
55        graph_store: Arc<dyn GraphStore>,
56        embedding: Arc<dyn EmbeddingProvider>,
57        config: RetrievalConfig,
58    ) -> Self {
59        Self {
60            fact_store,
61            vector_store,
62            graph_store,
63            embedding,
64            config,
65        }
66    }
67
68    /// Hybrid search: vector + keyword + graph, merged with configurable weights.
69    pub async fn search(
70        &self,
71        query: &str,
72        scope: &Scope,
73        top_k: usize,
74    ) -> Result<Vec<ScoredFact>, MemoryError> {
75        // Fetch more candidates from each source, then merge
76        let candidate_k = top_k * 3;
77
78        // 1. Vector search
79        let embeddings = self.embedding.embed(&[query]).await?;
80        let query_vec = embeddings
81            .into_iter()
82            .next()
83            .ok_or_else(|| MemoryError::Embedding("empty embedding".to_string()))?;
84        let vector_filter = VectorFilter {
85            scope: Some(scope.clone()),
86            min_score: None,
87        };
88        let vector_matches = self
89            .vector_store
90            .search(&query_vec, &vector_filter, candidate_k)
91            .await?;
92
93        // 2. Keyword search
94        let keyword_matches = self
95            .fact_store
96            .keyword_search(query, scope, candidate_k)
97            .await?;
98
99        // 3. Graph walk — find entities matching query terms, walk 1 hop
100        let graph_entity_ids = self.graph_store.search_entities(query, 5).await?;
101        let mut graph_fact_ids: HashMap<FactId, f32> = HashMap::new();
102        for entity in &graph_entity_ids {
103            let subgraph = self.graph_store.neighbors(entity.id, 1, None).await?;
104            // Facts that reference entities in the subgraph get a graph score
105            for _rel in &subgraph.relationships {
106                let entity_facts = self
107                    .fact_store
108                    .keyword_search(&entity.name, scope, 5)
109                    .await?;
110                for f in &entity_facts {
111                    let entry = graph_fact_ids.entry(f.id).or_insert(0.0);
112                    *entry = (*entry + 0.5).min(1.0);
113                }
114            }
115        }
116
117        // Merge scores
118        let mut scored: HashMap<FactId, (f32, f32, f32)> = HashMap::new(); // (vec, kw, graph)
119
120        // Vector scores (already cosine similarity in [0,1])
121        for vm in &vector_matches {
122            scored.entry(vm.id).or_insert((0.0, 0.0, 0.0)).0 = vm.score;
123        }
124
125        // Keyword scores (normalize by rank position)
126        for (i, fact) in keyword_matches.iter().enumerate() {
127            let kw_score = 1.0 - (i as f32 / candidate_k.max(1) as f32);
128            scored.entry(fact.id).or_insert((0.0, 0.0, 0.0)).1 = kw_score;
129        }
130
131        // Graph scores
132        for (id, score) in &graph_fact_ids {
133            scored.entry(*id).or_insert((0.0, 0.0, 0.0)).2 = *score;
134        }
135
136        // Fetch full facts and compute final scores
137        let mut results: Vec<ScoredFact> = Vec::new();
138        for (id, (vs, ks, gs)) in &scored {
139            if let Ok(fact) = self.fact_store.get_fact(*id).await {
140                if !fact.is_valid() {
141                    continue;
142                }
143                let final_score = vs * self.config.vector_weight
144                    + ks * self.config.keyword_weight
145                    + gs * self.config.graph_weight;
146                results.push(ScoredFact {
147                    fact,
148                    score: final_score,
149                    vector_score: *vs,
150                    keyword_score: *ks,
151                    graph_score: *gs,
152                });
153            }
154        }
155
156        // Sort descending by score
157        results.sort_by(|a, b| {
158            b.score
159                .partial_cmp(&a.score)
160                .unwrap_or(std::cmp::Ordering::Equal)
161        });
162        results.truncate(top_k);
163
164        // Record access
165        for sf in &results {
166            let _ = self.fact_store.record_access(sf.fact.id).await;
167        }
168
169        Ok(results)
170    }
171}