use crate::{
core::KnowledgeGraph,
retrieval::{
bm25::{BM25Result, BM25Retriever},
ResultType,
},
vector::{EmbeddingGenerator, VectorIndex},
GraphRAGError, Result,
};
use std::collections::HashMap;
/// One fused hit returned by [`HybridRetriever::search`].
#[derive(Debug, Clone)]
pub struct HybridSearchResult {
    /// Prefixed document id, e.g. `entity:<id>` or `chunk:<id>`.
    pub id: String,
    /// Indexed text for the hit; falls back to `id` when only the vector index returned it.
    pub content: String,
    /// Fused score produced by the configured [`FusionMethod`]; results are sorted by it,
    /// highest first.
    pub score: f32,
    /// Raw score reported by the vector index, or `0.0` if the hit came only from keyword search.
    pub semantic_score: f32,
    /// Raw BM25 score, or `0.0` if the hit came only from semantic search.
    pub keyword_score: f32,
    /// Kind of hit, derived from the `id` prefix (`entity:` / `chunk:`; anything else is hybrid).
    pub result_type: ResultType,
    /// For entity hits, a single element holding `content`; empty otherwise.
    pub entities: Vec<String>,
    /// Currently holds only this hit's own `id`.
    pub source_chunks: Vec<String>,
    /// Fusion method that produced `score`.
    pub fusion_method: FusionMethod,
}
/// Strategy used to merge the semantic and keyword result lists into one ranking.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// `RRF` is the established, public variant name; renaming it to `Rrf` would break callers.
#[allow(clippy::upper_case_acronyms)]
pub enum FusionMethod {
    /// Reciprocal rank fusion: each list contributes `weight / (rrf_k + rank + 1)`, so only
    /// ranks matter, not raw score scales.
    RRF,
    /// Each list's scores are divided by that list's maximum, then summed using the
    /// configured semantic/keyword weights.
    Weighted,
    /// Unweighted sum of the raw semantic and BM25 scores.
    CombSum,
    /// The larger of the raw semantic and BM25 scores.
    MaxScore,
}
/// Tuning knobs for [`HybridRetriever`].
#[derive(Debug, Clone)]
pub struct HybridConfig {
    /// Weight of the semantic (vector) list; used by `RRF` and `Weighted` fusion.
    pub semantic_weight: f32,
    /// Weight of the keyword (BM25) list; used by `RRF` and `Weighted` fusion.
    pub keyword_weight: f32,
    /// How the two result lists are merged.
    pub fusion_method: FusionMethod,
    /// Rank-damping constant `k` in the RRF term `1 / (k + rank + 1)`.
    pub rrf_k: f32,
    /// Number of candidates requested from each retriever before fusion.
    pub max_candidates: usize,
    /// Fused results scoring below this value are dropped.
    pub min_score_threshold: f32,
}
impl Default for HybridConfig {
fn default() -> Self {
Self {
semantic_weight: 0.7,
keyword_weight: 0.3,
fusion_method: FusionMethod::RRF,
rrf_k: 60.0,
max_candidates: 100,
min_score_threshold: 0.1,
}
}
}
/// Retriever that fuses vector-similarity search with BM25 keyword search.
///
/// Must be populated with [`HybridRetriever::initialize_with_graph`] before
/// [`HybridRetriever::search`] can be used.
pub struct HybridRetriever {
    /// Vector index over entity and chunk embeddings.
    vector_index: VectorIndex,
    /// Embeds incoming queries for lookup in `vector_index`.
    embedding_generator: EmbeddingGenerator,
    /// Keyword index over entity names/types and chunk text.
    bm25_retriever: BM25Retriever,
    /// Fusion and filtering settings.
    config: HybridConfig,
    /// Set once `initialize_with_graph` has finished without error; `search` refuses to run
    /// until then.
    initialized: bool,
}
impl HybridRetriever {
pub fn new() -> Self {
Self {
vector_index: VectorIndex::new(),
embedding_generator: EmbeddingGenerator::new(128),
bm25_retriever: BM25Retriever::new(),
config: HybridConfig::default(),
initialized: false,
}
}
pub fn with_config(config: HybridConfig) -> Self {
Self {
vector_index: VectorIndex::new(),
embedding_generator: EmbeddingGenerator::new(128),
bm25_retriever: BM25Retriever::new(),
config,
initialized: false,
}
}
pub fn initialize_with_graph(&mut self, graph: &KnowledgeGraph) -> Result<()> {
for entity in graph.entities() {
if let Some(embedding) = &entity.embedding {
let id = format!("entity:{}", entity.id);
self.vector_index.add_vector(id, embedding.clone())?;
}
}
for chunk in graph.chunks() {
if let Some(embedding) = &chunk.embedding {
let id = format!("chunk:{}", chunk.id);
self.vector_index.add_vector(id, embedding.clone())?;
}
}
if !self.vector_index.is_empty() {
self.vector_index.build_index()?;
}
let mut bm25_documents = Vec::new();
for entity in graph.entities() {
let doc = crate::retrieval::bm25::Document {
id: format!("entity:{}", entity.id),
content: format!("{} {}", entity.name, entity.entity_type),
metadata: HashMap::new(),
};
bm25_documents.push(doc);
}
for chunk in graph.chunks() {
let doc = crate::retrieval::bm25::Document {
id: format!("chunk:{}", chunk.id),
content: chunk.content.clone(),
metadata: HashMap::new(),
};
bm25_documents.push(doc);
}
self.bm25_retriever.index_documents(&bm25_documents)?;
self.initialized = true;
Ok(())
}
pub fn search(&mut self, query: &str, limit: usize) -> Result<Vec<HybridSearchResult>> {
if !self.initialized {
return Err(GraphRAGError::Retrieval {
message: "Hybrid retriever not initialized. Call initialize_with_graph() first."
.to_string(),
});
}
let semantic_results = self.semantic_search(query, self.config.max_candidates)?;
let keyword_results = self.keyword_search(query, self.config.max_candidates);
let combined_results = self.combine_results(semantic_results, keyword_results, limit)?;
Ok(combined_results)
}
fn semantic_search(&mut self, query: &str, limit: usize) -> Result<Vec<(String, f32, String)>> {
let query_embedding = self.embedding_generator.generate_embedding(query);
let similar_vectors = self.vector_index.search(&query_embedding, limit)?;
let mut results = Vec::new();
for (id, score) in similar_vectors {
results.push((id.clone(), score, id));
}
Ok(results)
}
fn keyword_search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
self.bm25_retriever.search(query, limit)
}
fn combine_results(
&mut self,
semantic_results: Vec<(String, f32, String)>,
keyword_results: Vec<BM25Result>,
limit: usize,
) -> Result<Vec<HybridSearchResult>> {
match self.config.fusion_method {
FusionMethod::RRF => {
self.reciprocal_rank_fusion(semantic_results, keyword_results, limit)
},
FusionMethod::Weighted => {
self.weighted_combination(semantic_results, keyword_results, limit)
},
FusionMethod::CombSum => self.comb_sum_fusion(semantic_results, keyword_results, limit),
FusionMethod::MaxScore => {
self.max_score_fusion(semantic_results, keyword_results, limit)
},
}
}
fn reciprocal_rank_fusion(
&mut self,
semantic_results: Vec<(String, f32, String)>,
keyword_results: Vec<BM25Result>,
limit: usize,
) -> Result<Vec<HybridSearchResult>> {
let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
let mut content_map: HashMap<String, String> = HashMap::new();
for (rank, (id, score, content)) in semantic_results.iter().enumerate() {
let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
combined_scores.insert(
id.clone(),
(rrf_score * self.config.semantic_weight, *score, 0.0),
);
content_map.insert(id.clone(), content.clone());
}
for (rank, result) in keyword_results.iter().enumerate() {
let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
let entry = combined_scores
.entry(result.doc_id.clone())
.or_insert((0.0, 0.0, 0.0));
entry.0 += rrf_score * self.config.keyword_weight;
entry.2 = result.score;
content_map.insert(result.doc_id.clone(), result.content.clone());
}
self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::RRF)
}
fn weighted_combination(
&mut self,
semantic_results: Vec<(String, f32, String)>,
keyword_results: Vec<BM25Result>,
limit: usize,
) -> Result<Vec<HybridSearchResult>> {
let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
let mut content_map: HashMap<String, String> = HashMap::new();
let max_semantic = semantic_results
.iter()
.map(|(_, score, _)| *score)
.fold(f32::NEG_INFINITY, f32::max);
for (id, score, content) in semantic_results {
let normalized_score = if max_semantic > 0.0 {
score / max_semantic
} else {
0.0
};
combined_scores.insert(
id.clone(),
(normalized_score * self.config.semantic_weight, score, 0.0),
);
content_map.insert(id, content);
}
let max_keyword = keyword_results
.iter()
.map(|r| r.score)
.fold(f32::NEG_INFINITY, f32::max);
for result in keyword_results {
let normalized_score = if max_keyword > 0.0 {
result.score / max_keyword
} else {
0.0
};
let entry = combined_scores
.entry(result.doc_id.clone())
.or_insert((0.0, 0.0, 0.0));
entry.0 += normalized_score * self.config.keyword_weight;
entry.2 = result.score;
content_map.insert(result.doc_id.clone(), result.content.clone());
}
self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::Weighted)
}
fn comb_sum_fusion(
&mut self,
semantic_results: Vec<(String, f32, String)>,
keyword_results: Vec<BM25Result>,
limit: usize,
) -> Result<Vec<HybridSearchResult>> {
let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
let mut content_map: HashMap<String, String> = HashMap::new();
for (id, score, content) in semantic_results {
combined_scores.insert(id.clone(), (score, score, 0.0));
content_map.insert(id, content);
}
for result in keyword_results {
let entry = combined_scores
.entry(result.doc_id.clone())
.or_insert((0.0, 0.0, 0.0));
entry.0 += result.score;
entry.2 = result.score;
content_map.insert(result.doc_id.clone(), result.content.clone());
}
self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::CombSum)
}
fn max_score_fusion(
&mut self,
semantic_results: Vec<(String, f32, String)>,
keyword_results: Vec<BM25Result>,
limit: usize,
) -> Result<Vec<HybridSearchResult>> {
let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
let mut content_map: HashMap<String, String> = HashMap::new();
for (id, score, content) in semantic_results {
combined_scores.insert(id.clone(), (score, score, 0.0));
content_map.insert(id, content);
}
for result in keyword_results {
let entry = combined_scores
.entry(result.doc_id.clone())
.or_insert((0.0, 0.0, 0.0));
entry.0 = entry.0.max(result.score);
entry.2 = result.score;
content_map.insert(result.doc_id.clone(), result.content.clone());
}
self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::MaxScore)
}
fn create_hybrid_results(
&self,
combined_scores: HashMap<String, (f32, f32, f32)>,
content_map: HashMap<String, String>,
limit: usize,
fusion_method: FusionMethod,
) -> Result<Vec<HybridSearchResult>> {
let mut results: Vec<HybridSearchResult> = combined_scores
.into_iter()
.filter_map(|(id, (combined_score, semantic_score, keyword_score))| {
if combined_score >= self.config.min_score_threshold {
let content = content_map.get(&id).cloned().unwrap_or_else(|| id.clone());
let result_type = if id.starts_with("entity:") {
ResultType::Entity
} else if id.starts_with("chunk:") {
ResultType::Chunk
} else {
ResultType::Hybrid
};
let entities = if result_type == ResultType::Entity {
vec![content.clone()]
} else {
Vec::new()
};
Some(HybridSearchResult {
id: id.clone(),
content,
score: combined_score,
semantic_score,
keyword_score,
result_type,
entities,
source_chunks: vec![id],
fusion_method: fusion_method.clone(),
})
} else {
None
}
})
.collect();
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(limit);
Ok(results)
}
pub fn get_config(&self) -> &HybridConfig {
&self.config
}
pub fn set_config(&mut self, config: HybridConfig) {
self.config = config;
}
pub fn is_initialized(&self) -> bool {
self.initialized
}
pub fn get_statistics(&self) -> HybridStatistics {
let vector_stats = self.vector_index.statistics();
let bm25_stats = self.bm25_retriever.get_statistics();
HybridStatistics {
vector_count: vector_stats.vector_count,
bm25_document_count: bm25_stats.total_documents,
bm25_term_count: bm25_stats.total_terms,
config: self.config.clone(),
initialized: self.initialized,
}
}
pub fn clear(&mut self) {
self.vector_index = VectorIndex::new();
self.bm25_retriever.clear();
self.initialized = false;
}
}
impl Default for HybridRetriever {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time summary of a [`HybridRetriever`], produced by `get_statistics`.
#[derive(Debug, Clone)]
pub struct HybridStatistics {
    /// Number of vectors in the vector index.
    pub vector_count: usize,
    /// Number of documents in the BM25 index.
    pub bm25_document_count: usize,
    /// Number of terms in the BM25 index.
    pub bm25_term_count: usize,
    /// Copy of the retriever's configuration at snapshot time.
    pub config: HybridConfig,
    /// Whether the retriever had been initialized at snapshot time.
    pub initialized: bool,
}
impl HybridStatistics {
pub fn print(&self) {
println!("Hybrid Retriever Statistics:");
println!(" Initialized: {}", self.initialized);
println!(" Vector index: {} vectors", self.vector_count);
println!(
" BM25 index: {} documents, {} terms",
self.bm25_document_count, self.bm25_term_count
);
println!(" Fusion method: {:?}", self.config.fusion_method);
println!(
" Weights: semantic={:.2}, keyword={:.2}",
self.config.semantic_weight, self.config.keyword_weight
);
println!(" Score threshold: {:.3}", self.config.min_score_threshold);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeGraph;

    /// A freshly constructed retriever must not claim to be ready.
    #[test]
    fn test_hybrid_retriever_creation() {
        assert!(!HybridRetriever::new().is_initialized());
    }

    /// Pins the documented default weights and fusion method.
    #[test]
    fn test_hybrid_config_default() {
        let HybridConfig {
            semantic_weight,
            keyword_weight,
            fusion_method,
            ..
        } = HybridConfig::default();
        assert_eq!(semantic_weight, 0.7);
        assert_eq!(keyword_weight, 0.3);
        assert_eq!(fusion_method, FusionMethod::RRF);
    }

    /// Variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_fusion_method_variants() {
        let rrf = FusionMethod::RRF;
        assert_eq!(rrf, FusionMethod::RRF);
        assert_ne!(rrf, FusionMethod::Weighted);
    }

    /// Initializing from an empty graph succeeds and flips the ready flag.
    #[test]
    fn test_hybrid_retriever_with_empty_graph() {
        let graph = KnowledgeGraph::new();
        let mut retriever = HybridRetriever::new();
        assert!(retriever.initialize_with_graph(&graph).is_ok());
        assert!(retriever.is_initialized());
    }

    /// Searching before initialization is rejected with an error.
    #[test]
    fn test_search_without_initialization() {
        let mut retriever = HybridRetriever::new();
        assert!(retriever.search("test", 10).is_err());
    }
}