1use crate::embedding::EmbeddingProvider;
4use crate::fact::{Fact, FactId};
5use crate::graph::GraphStore;
6use crate::scope::Scope;
7use crate::store::{FactStore, MemoryError};
8use crate::vector::{VectorFilter, VectorStore};
9use std::collections::HashMap;
10use std::sync::Arc;
11
12#[derive(Debug, Clone)]
14pub struct ScoredFact {
15 pub fact: Fact,
16 pub score: f32,
17 pub vector_score: f32,
18 pub keyword_score: f32,
19 pub graph_score: f32,
20}
21
/// Fusion weights used by [`HybridRetriever`] to combine the vector, keyword
/// and graph component scores into one ranking score.
///
/// The weights are multiplied into the component scores as given; they are
/// not normalised, so callers choose their own scale. The defaults sum to 1.0.
#[derive(Clone, Debug)]
pub struct RetrievalConfig {
    /// Multiplier for the vector-similarity component.
    pub vector_weight: f32,
    /// Multiplier for the keyword-rank component.
    pub keyword_weight: f32,
    /// Multiplier for the knowledge-graph component.
    pub graph_weight: f32,
}

impl Default for RetrievalConfig {
    /// Favors semantic similarity (0.5), then keyword rank (0.3), then graph proximity (0.2).
    fn default() -> Self {
        let (vector_weight, keyword_weight, graph_weight) = (0.5, 0.3, 0.2);
        RetrievalConfig {
            vector_weight,
            keyword_weight,
            graph_weight,
        }
    }
}
42
43pub struct HybridRetriever {
44 fact_store: Arc<dyn FactStore>,
45 vector_store: Arc<dyn VectorStore>,
46 graph_store: Arc<dyn GraphStore>,
47 embedding: Arc<dyn EmbeddingProvider>,
48 config: RetrievalConfig,
49}
50
51impl HybridRetriever {
52 pub fn new(
53 fact_store: Arc<dyn FactStore>,
54 vector_store: Arc<dyn VectorStore>,
55 graph_store: Arc<dyn GraphStore>,
56 embedding: Arc<dyn EmbeddingProvider>,
57 config: RetrievalConfig,
58 ) -> Self {
59 Self {
60 fact_store,
61 vector_store,
62 graph_store,
63 embedding,
64 config,
65 }
66 }
67
68 pub async fn search(
70 &self,
71 query: &str,
72 scope: &Scope,
73 top_k: usize,
74 ) -> Result<Vec<ScoredFact>, MemoryError> {
75 let candidate_k = top_k * 3;
77
78 let embeddings = self.embedding.embed(&[query]).await?;
80 let query_vec = embeddings
81 .into_iter()
82 .next()
83 .ok_or_else(|| MemoryError::Embedding("empty embedding".to_string()))?;
84 let vector_filter = VectorFilter {
85 scope: Some(scope.clone()),
86 min_score: None,
87 };
88 let vector_matches = self
89 .vector_store
90 .search(&query_vec, &vector_filter, candidate_k)
91 .await?;
92
93 let keyword_matches = self
95 .fact_store
96 .keyword_search(query, scope, candidate_k)
97 .await?;
98
99 let graph_entity_ids = self.graph_store.search_entities(query, 5).await?;
101 let mut graph_fact_ids: HashMap<FactId, f32> = HashMap::new();
102 for entity in &graph_entity_ids {
103 let subgraph = self.graph_store.neighbors(entity.id, 1, None).await?;
104 for _rel in &subgraph.relationships {
106 let entity_facts = self
107 .fact_store
108 .keyword_search(&entity.name, scope, 5)
109 .await?;
110 for f in &entity_facts {
111 let entry = graph_fact_ids.entry(f.id).or_insert(0.0);
112 *entry = (*entry + 0.5).min(1.0);
113 }
114 }
115 }
116
117 let mut scored: HashMap<FactId, (f32, f32, f32)> = HashMap::new(); for vm in &vector_matches {
122 scored.entry(vm.id).or_insert((0.0, 0.0, 0.0)).0 = vm.score;
123 }
124
125 for (i, fact) in keyword_matches.iter().enumerate() {
127 let kw_score = 1.0 - (i as f32 / candidate_k.max(1) as f32);
128 scored.entry(fact.id).or_insert((0.0, 0.0, 0.0)).1 = kw_score;
129 }
130
131 for (id, score) in &graph_fact_ids {
133 scored.entry(*id).or_insert((0.0, 0.0, 0.0)).2 = *score;
134 }
135
136 let mut results: Vec<ScoredFact> = Vec::new();
138 for (id, (vs, ks, gs)) in &scored {
139 if let Ok(fact) = self.fact_store.get_fact(*id).await {
140 if !fact.is_valid() {
141 continue;
142 }
143 let final_score = vs * self.config.vector_weight
144 + ks * self.config.keyword_weight
145 + gs * self.config.graph_weight;
146 results.push(ScoredFact {
147 fact,
148 score: final_score,
149 vector_score: *vs,
150 keyword_score: *ks,
151 graph_score: *gs,
152 });
153 }
154 }
155
156 results.sort_by(|a, b| {
158 b.score
159 .partial_cmp(&a.score)
160 .unwrap_or(std::cmp::Ordering::Equal)
161 });
162 results.truncate(top_k);
163
164 for sf in &results {
166 let _ = self.fact_store.record_access(sf.fact.id).await;
167 }
168
169 Ok(results)
170 }
171}