use crate::embedding::EmbeddingProvider;
use crate::fact::{Fact, FactId};
use crate::graph::GraphStore;
use crate::scope::Scope;
use crate::store::{FactStore, MemoryError};
use crate::vector::{VectorFilter, VectorStore};
use std::collections::HashMap;
use std::sync::Arc;
/// A fact returned by [`HybridRetriever::search`] together with its score breakdown.
#[derive(Debug, Clone)]
pub struct ScoredFact {
/// The retrieved fact, as loaded from the fact store.
pub fact: Fact,
/// Combined ranking score: each component score below multiplied by its
/// [`RetrievalConfig`] weight, then summed.
pub score: f32,
/// Score reported by the vector store for this fact; `0.0` when the fact
/// was not among the vector matches.
pub vector_score: f32,
/// Rank-derived keyword score (`1.0` for the first keyword hit, decreasing
/// with rank); `0.0` when the fact was not among the keyword matches.
pub keyword_score: f32,
/// Graph-derived score, accumulated in steps of `0.5` and capped at `1.0`;
/// `0.0` when the fact was not reached through a graph entity.
pub graph_score: f32,
}
/// Weights used by [`HybridRetriever::search`] to blend the three component
/// scores into a single ranking score.
///
/// The weights are applied as given; the retriever does not normalise them.
#[derive(Debug, Clone)]
pub struct RetrievalConfig {
/// Multiplier for the vector-similarity score (default `0.5`).
pub vector_weight: f32,
/// Multiplier for the keyword rank score (default `0.3`).
pub keyword_weight: f32,
/// Multiplier for the graph score (default `0.2`).
pub graph_weight: f32,
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
vector_weight: 0.5,
keyword_weight: 0.3,
graph_weight: 0.2,
}
}
}
/// Retrieves facts by combining vector similarity, keyword search and graph
/// lookups into one weighted ranking.
pub struct HybridRetriever {
// Used for keyword search, for loading facts by id, and for recording access
// to the facts that are returned.
fact_store: Arc<dyn FactStore>,
// Queried with the embedded query vector to obtain similarity matches.
vector_store: Arc<dyn VectorStore>,
// Queried for entities matching the query text and for their neighbours.
graph_store: Arc<dyn GraphStore>,
// Turns the query text into the vector passed to the vector store.
embedding: Arc<dyn EmbeddingProvider>,
// Weights applied to the vector, keyword and graph scores.
config: RetrievalConfig,
}
impl HybridRetriever {
    /// Builds a retriever from its backing stores, embedding provider and
    /// score weights.
    pub fn new(
        fact_store: Arc<dyn FactStore>,
        vector_store: Arc<dyn VectorStore>,
        graph_store: Arc<dyn GraphStore>,
        embedding: Arc<dyn EmbeddingProvider>,
        config: RetrievalConfig,
    ) -> Self {
        Self {
            fact_store,
            vector_store,
            graph_store,
            embedding,
            config,
        }
    }

    /// Returns up to `top_k` valid facts for `query` within `scope`, ordered by
    /// descending combined score.
    ///
    /// Three signals are gathered and blended with the configured weights:
    ///
    /// * **vector** – the score reported by the vector store for the embedded
    ///   query;
    /// * **keyword** – a rank-derived score, `1.0` for the first keyword hit and
    ///   decreasing by `1 / candidate_k` per rank;
    /// * **graph** – for each of up to five entities matching the query, facts
    ///   found by keyword-searching the entity name receive `0.5` per
    ///   relationship of that entity, capped at `1.0`.
    ///
    /// Each store is asked for `top_k * 3` candidates (saturating) so that the
    /// blended ranking has more than `top_k` facts to choose from. When `top_k`
    /// is `0` an empty list is returned without querying any backend.
    ///
    /// Access is recorded for every returned fact; failures while recording are
    /// ignored.
    ///
    /// # Errors
    ///
    /// Propagates errors from the embedding provider, the vector search, the
    /// keyword searches and the graph queries. Returns
    /// `MemoryError::Embedding` when the provider yields no vector for the
    /// query. Candidates whose fact cannot be loaded are skipped rather than
    /// reported.
    pub async fn search(
        &self,
        query: &str,
        scope: &Scope,
        top_k: usize,
    ) -> Result<Vec<ScoredFact>, MemoryError> {
        if top_k == 0 {
            return Ok(Vec::new());
        }
        // Saturate instead of overflowing for very large `top_k`.
        let candidate_k = top_k.saturating_mul(3);

        let embeddings = self.embedding.embed(&[query]).await?;
        let query_vec = embeddings
            .into_iter()
            .next()
            .ok_or_else(|| MemoryError::Embedding("empty embedding".to_string()))?;

        let vector_filter = VectorFilter {
            scope: Some(scope.clone()),
            min_score: None,
        };
        let vector_matches = self
            .vector_store
            .search(&query_vec, &vector_filter, candidate_k)
            .await?;
        let keyword_matches = self
            .fact_store
            .keyword_search(query, scope, candidate_k)
            .await?;

        let graph_entities = self.graph_store.search_entities(query, 5).await?;
        let mut graph_scores: HashMap<FactId, f32> = HashMap::new();
        for entity in &graph_entities {
            let subgraph = self.graph_store.neighbors(entity.id, 1, None).await?;
            // Each relationship contributes 0.5 and the total is capped at 1.0,
            // so only "none / one / several" matters. Entities without any
            // relationship contribute nothing and need no lookup.
            let boost: f32 = match (&subgraph.relationships).into_iter().count() {
                0 => continue,
                1 => 0.5,
                _ => 1.0,
            };
            // The lookup depends only on the entity, so it is issued once per
            // entity rather than once per relationship.
            let entity_facts = self
                .fact_store
                .keyword_search(&entity.name, scope, 5)
                .await?;
            for fact in &entity_facts {
                let entry = graph_scores.entry(fact.id).or_insert(0.0);
                *entry = (*entry + boost).min(1.0);
            }
        }

        // Per fact: (vector score, keyword score, graph score).
        let mut scored: HashMap<FactId, (f32, f32, f32)> = HashMap::new();
        for vm in &vector_matches {
            scored.entry(vm.id).or_default().0 = vm.score;
        }
        // `candidate_k` is non-zero here because `top_k == 0` returned early.
        let rank_scale = candidate_k as f32;
        for (rank, fact) in keyword_matches.iter().enumerate() {
            let kw_score = 1.0 - (rank as f32 / rank_scale);
            scored.entry(fact.id).or_default().1 = kw_score;
        }
        for (id, score) in graph_scores {
            scored.entry(id).or_default().2 = score;
        }

        let mut results: Vec<ScoredFact> = Vec::with_capacity(scored.len());
        for (id, (vector_score, keyword_score, graph_score)) in scored {
            // Best effort: a candidate that cannot be loaded or is no longer
            // valid is dropped instead of failing the whole search.
            let fact = match self.fact_store.get_fact(id).await {
                Ok(fact) if fact.is_valid() => fact,
                _ => continue,
            };
            let score = vector_score * self.config.vector_weight
                + keyword_score * self.config.keyword_weight
                + graph_score * self.config.graph_weight;
            results.push(ScoredFact {
                fact,
                score,
                vector_score,
                keyword_score,
                graph_score,
            });
        }

        // Highest score first; incomparable (NaN) scores are treated as equal.
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(top_k);

        for scored_fact in &results {
            // Access tracking is advisory; a failure must not fail the search.
            let _ = self.fact_store.record_access(scored_fact.fact.id).await;
        }
        Ok(results)
    }
}