1use crate::{
2 core::KnowledgeGraph,
3 retrieval::{
4 bm25::{BM25Result, BM25Retriever},
5 ResultType,
6 },
7 vector::{EmbeddingGenerator, VectorIndex},
8 GraphRAGError, Result,
9};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone)]
14pub struct HybridSearchResult {
15 pub id: String,
17 pub content: String,
19 pub score: f32,
21 pub semantic_score: f32,
23 pub keyword_score: f32,
25 pub result_type: ResultType,
27 pub entities: Vec<String>,
29 pub source_chunks: Vec<String>,
31 pub fusion_method: FusionMethod,
33}
34
/// Strategy used to merge the semantic and keyword result lists into one ranking.
// `RRF` is the established abbreviation and part of the public API, so the
// acronym-casing lint is silenced rather than renaming the variant.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion: scores depend on each hit's rank in the two lists.
    RRF,
    /// Weighted sum of the per-list scores after dividing by each list's maximum.
    Weighted,
    /// Plain sum of the raw semantic and keyword scores.
    CombSum,
    /// Larger of the raw semantic and keyword scores.
    MaxScore,
}
48
49#[derive(Debug, Clone)]
51pub struct HybridConfig {
52 pub semantic_weight: f32,
54 pub keyword_weight: f32,
56 pub fusion_method: FusionMethod,
58 pub rrf_k: f32,
60 pub max_candidates: usize,
62 pub min_score_threshold: f32,
64}
65
66impl Default for HybridConfig {
67 fn default() -> Self {
68 Self {
69 semantic_weight: 0.7,
70 keyword_weight: 0.3,
71 fusion_method: FusionMethod::RRF,
72 rrf_k: 60.0,
73 max_candidates: 100,
74 min_score_threshold: 0.1,
75 }
76 }
77}
78
79pub struct HybridRetriever {
81 vector_index: VectorIndex,
83 embedding_generator: EmbeddingGenerator,
85 bm25_retriever: BM25Retriever,
87 config: HybridConfig,
89 initialized: bool,
91}
92
93impl HybridRetriever {
94 pub fn new() -> Self {
96 Self {
97 vector_index: VectorIndex::new(),
98 embedding_generator: EmbeddingGenerator::new(128),
99 bm25_retriever: BM25Retriever::new(),
100 config: HybridConfig::default(),
101 initialized: false,
102 }
103 }
104
105 pub fn with_config(config: HybridConfig) -> Self {
107 Self {
108 vector_index: VectorIndex::new(),
109 embedding_generator: EmbeddingGenerator::new(128),
110 bm25_retriever: BM25Retriever::new(),
111 config,
112 initialized: false,
113 }
114 }
115
116 pub fn initialize_with_graph(&mut self, graph: &KnowledgeGraph) -> Result<()> {
118 for entity in graph.entities() {
120 if let Some(embedding) = &entity.embedding {
121 let id = format!("entity:{}", entity.id);
122 self.vector_index.add_vector(id, embedding.clone())?;
123 }
124 }
125
126 for chunk in graph.chunks() {
127 if let Some(embedding) = &chunk.embedding {
128 let id = format!("chunk:{}", chunk.id);
129 self.vector_index.add_vector(id, embedding.clone())?;
130 }
131 }
132
133 if !self.vector_index.is_empty() {
135 self.vector_index.build_index()?;
136 }
137
138 let mut bm25_documents = Vec::new();
140
141 for entity in graph.entities() {
143 let doc = crate::retrieval::bm25::Document {
144 id: format!("entity:{}", entity.id),
145 content: format!("{} {}", entity.name, entity.entity_type),
146 metadata: HashMap::new(),
147 };
148 bm25_documents.push(doc);
149 }
150
151 for chunk in graph.chunks() {
153 let doc = crate::retrieval::bm25::Document {
154 id: format!("chunk:{}", chunk.id),
155 content: chunk.content.clone(),
156 metadata: HashMap::new(),
157 };
158 bm25_documents.push(doc);
159 }
160
161 self.bm25_retriever.index_documents(&bm25_documents)?;
162 self.initialized = true;
163
164 Ok(())
165 }
166
167 pub fn search(&mut self, query: &str, limit: usize) -> Result<Vec<HybridSearchResult>> {
169 if !self.initialized {
170 return Err(GraphRAGError::Retrieval {
171 message: "Hybrid retriever not initialized. Call initialize_with_graph() first."
172 .to_string(),
173 });
174 }
175
176 let semantic_results = self.semantic_search(query, self.config.max_candidates)?;
178
179 let keyword_results = self.keyword_search(query, self.config.max_candidates);
181
182 let combined_results = self.combine_results(semantic_results, keyword_results, limit)?;
184
185 Ok(combined_results)
186 }
187
188 fn semantic_search(&mut self, query: &str, limit: usize) -> Result<Vec<(String, f32, String)>> {
190 let query_embedding = self.embedding_generator.generate_embedding(query);
191 let similar_vectors = self.vector_index.search(&query_embedding, limit)?;
192
193 let mut results = Vec::new();
194 for (id, score) in similar_vectors {
195 results.push((id.clone(), score, id));
198 }
199
200 Ok(results)
201 }
202
203 fn keyword_search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
205 self.bm25_retriever.search(query, limit)
206 }
207
208 fn combine_results(
210 &mut self,
211 semantic_results: Vec<(String, f32, String)>,
212 keyword_results: Vec<BM25Result>,
213 limit: usize,
214 ) -> Result<Vec<HybridSearchResult>> {
215 match self.config.fusion_method {
216 FusionMethod::RRF => {
217 self.reciprocal_rank_fusion(semantic_results, keyword_results, limit)
218 },
219 FusionMethod::Weighted => {
220 self.weighted_combination(semantic_results, keyword_results, limit)
221 },
222 FusionMethod::CombSum => self.comb_sum_fusion(semantic_results, keyword_results, limit),
223 FusionMethod::MaxScore => {
224 self.max_score_fusion(semantic_results, keyword_results, limit)
225 },
226 }
227 }
228
229 fn reciprocal_rank_fusion(
231 &mut self,
232 semantic_results: Vec<(String, f32, String)>,
233 keyword_results: Vec<BM25Result>,
234 limit: usize,
235 ) -> Result<Vec<HybridSearchResult>> {
236 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
237 let mut content_map: HashMap<String, String> = HashMap::new();
238
239 for (rank, (id, score, content)) in semantic_results.iter().enumerate() {
241 let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
242 combined_scores.insert(
243 id.clone(),
244 (rrf_score * self.config.semantic_weight, *score, 0.0),
245 );
246 content_map.insert(id.clone(), content.clone());
247 }
248
249 for (rank, result) in keyword_results.iter().enumerate() {
251 let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
252 let entry = combined_scores
253 .entry(result.doc_id.clone())
254 .or_insert((0.0, 0.0, 0.0));
255 entry.0 += rrf_score * self.config.keyword_weight;
256 entry.2 = result.score;
257 content_map.insert(result.doc_id.clone(), result.content.clone());
258 }
259
260 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::RRF)
261 }
262
263 fn weighted_combination(
265 &mut self,
266 semantic_results: Vec<(String, f32, String)>,
267 keyword_results: Vec<BM25Result>,
268 limit: usize,
269 ) -> Result<Vec<HybridSearchResult>> {
270 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
271 let mut content_map: HashMap<String, String> = HashMap::new();
272
273 let max_semantic = semantic_results
275 .iter()
276 .map(|(_, score, _)| *score)
277 .fold(f32::NEG_INFINITY, f32::max);
278
279 for (id, score, content) in semantic_results {
280 let normalized_score = if max_semantic > 0.0 {
281 score / max_semantic
282 } else {
283 0.0
284 };
285 combined_scores.insert(
286 id.clone(),
287 (normalized_score * self.config.semantic_weight, score, 0.0),
288 );
289 content_map.insert(id, content);
290 }
291
292 let max_keyword = keyword_results
294 .iter()
295 .map(|r| r.score)
296 .fold(f32::NEG_INFINITY, f32::max);
297
298 for result in keyword_results {
299 let normalized_score = if max_keyword > 0.0 {
300 result.score / max_keyword
301 } else {
302 0.0
303 };
304 let entry = combined_scores
305 .entry(result.doc_id.clone())
306 .or_insert((0.0, 0.0, 0.0));
307 entry.0 += normalized_score * self.config.keyword_weight;
308 entry.2 = result.score;
309 content_map.insert(result.doc_id.clone(), result.content.clone());
310 }
311
312 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::Weighted)
313 }
314
315 fn comb_sum_fusion(
317 &mut self,
318 semantic_results: Vec<(String, f32, String)>,
319 keyword_results: Vec<BM25Result>,
320 limit: usize,
321 ) -> Result<Vec<HybridSearchResult>> {
322 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
323 let mut content_map: HashMap<String, String> = HashMap::new();
324
325 for (id, score, content) in semantic_results {
327 combined_scores.insert(id.clone(), (score, score, 0.0));
328 content_map.insert(id, content);
329 }
330
331 for result in keyword_results {
333 let entry = combined_scores
334 .entry(result.doc_id.clone())
335 .or_insert((0.0, 0.0, 0.0));
336 entry.0 += result.score;
337 entry.2 = result.score;
338 content_map.insert(result.doc_id.clone(), result.content.clone());
339 }
340
341 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::CombSum)
342 }
343
344 fn max_score_fusion(
346 &mut self,
347 semantic_results: Vec<(String, f32, String)>,
348 keyword_results: Vec<BM25Result>,
349 limit: usize,
350 ) -> Result<Vec<HybridSearchResult>> {
351 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
352 let mut content_map: HashMap<String, String> = HashMap::new();
353
354 for (id, score, content) in semantic_results {
356 combined_scores.insert(id.clone(), (score, score, 0.0));
357 content_map.insert(id, content);
358 }
359
360 for result in keyword_results {
362 let entry = combined_scores
363 .entry(result.doc_id.clone())
364 .or_insert((0.0, 0.0, 0.0));
365 entry.0 = entry.0.max(result.score);
366 entry.2 = result.score;
367 content_map.insert(result.doc_id.clone(), result.content.clone());
368 }
369
370 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::MaxScore)
371 }
372
373 fn create_hybrid_results(
375 &self,
376 combined_scores: HashMap<String, (f32, f32, f32)>,
377 content_map: HashMap<String, String>,
378 limit: usize,
379 fusion_method: FusionMethod,
380 ) -> Result<Vec<HybridSearchResult>> {
381 let mut results: Vec<HybridSearchResult> = combined_scores
382 .into_iter()
383 .filter_map(|(id, (combined_score, semantic_score, keyword_score))| {
384 if combined_score >= self.config.min_score_threshold {
385 let content = content_map.get(&id).cloned().unwrap_or_else(|| id.clone());
386
387 let result_type = if id.starts_with("entity:") {
389 ResultType::Entity
390 } else if id.starts_with("chunk:") {
391 ResultType::Chunk
392 } else {
393 ResultType::Hybrid
394 };
395
396 let entities = if result_type == ResultType::Entity {
398 vec![content.clone()]
399 } else {
400 Vec::new()
401 };
402
403 Some(HybridSearchResult {
404 id: id.clone(),
405 content,
406 score: combined_score,
407 semantic_score,
408 keyword_score,
409 result_type,
410 entities,
411 source_chunks: vec![id],
412 fusion_method: fusion_method.clone(),
413 })
414 } else {
415 None
416 }
417 })
418 .collect();
419
420 results.sort_by(|a, b| {
422 b.score
423 .partial_cmp(&a.score)
424 .unwrap_or(std::cmp::Ordering::Equal)
425 });
426 results.truncate(limit);
427
428 Ok(results)
429 }
430
431 pub fn get_config(&self) -> &HybridConfig {
433 &self.config
434 }
435
436 pub fn set_config(&mut self, config: HybridConfig) {
438 self.config = config;
439 }
440
441 pub fn is_initialized(&self) -> bool {
443 self.initialized
444 }
445
446 pub fn get_statistics(&self) -> HybridStatistics {
448 let vector_stats = self.vector_index.statistics();
449 let bm25_stats = self.bm25_retriever.get_statistics();
450
451 HybridStatistics {
452 vector_count: vector_stats.vector_count,
453 bm25_document_count: bm25_stats.total_documents,
454 bm25_term_count: bm25_stats.total_terms,
455 config: self.config.clone(),
456 initialized: self.initialized,
457 }
458 }
459
460 pub fn clear(&mut self) {
462 self.vector_index = VectorIndex::new();
463 self.bm25_retriever.clear();
464 self.initialized = false;
465 }
466}
467
468impl Default for HybridRetriever {
469 fn default() -> Self {
470 Self::new()
471 }
472}
473
474#[derive(Debug, Clone)]
476pub struct HybridStatistics {
477 pub vector_count: usize,
479 pub bm25_document_count: usize,
481 pub bm25_term_count: usize,
483 pub config: HybridConfig,
485 pub initialized: bool,
487}
488
489impl HybridStatistics {
490 pub fn print(&self) {
492 println!("Hybrid Retriever Statistics:");
493 println!(" Initialized: {}", self.initialized);
494 println!(" Vector index: {} vectors", self.vector_count);
495 println!(
496 " BM25 index: {} documents, {} terms",
497 self.bm25_document_count, self.bm25_term_count
498 );
499 println!(" Fusion method: {:?}", self.config.fusion_method);
500 println!(
501 " Weights: semantic={:.2}, keyword={:.2}",
502 self.config.semantic_weight, self.config.keyword_weight
503 );
504 println!(" Score threshold: {:.3}", self.config.min_score_threshold);
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeGraph;

    /// A freshly constructed retriever reports itself as uninitialized.
    #[test]
    fn test_hybrid_retriever_creation() {
        assert!(!HybridRetriever::new().is_initialized());
    }

    /// The default configuration uses RRF with a 0.7 / 0.3 weighting.
    #[test]
    fn test_hybrid_config_default() {
        let HybridConfig {
            semantic_weight,
            keyword_weight,
            fusion_method,
            ..
        } = HybridConfig::default();

        assert_eq!(semantic_weight, 0.7);
        assert_eq!(keyword_weight, 0.3);
        assert_eq!(fusion_method, FusionMethod::RRF);
    }

    /// Variants compare equal to themselves and unequal to each other.
    #[test]
    fn test_fusion_method_variants() {
        let rrf = FusionMethod::RRF;

        assert_eq!(rrf, FusionMethod::RRF);
        assert_ne!(rrf, FusionMethod::Weighted);
    }

    /// Initializing from an empty graph succeeds and flips the flag.
    #[test]
    fn test_hybrid_retriever_with_empty_graph() {
        let empty_graph = KnowledgeGraph::new();
        let mut retriever = HybridRetriever::new();

        assert!(retriever.initialize_with_graph(&empty_graph).is_ok());
        assert!(retriever.is_initialized());
    }

    /// Searching before initialization is reported as an error.
    #[test]
    fn test_search_without_initialization() {
        let outcome = HybridRetriever::new().search("test", 10);

        assert!(outcome.is_err());
    }
}