1use crate::{
2 core::KnowledgeGraph,
3 retrieval::{
4 bm25::{BM25Result, BM25Retriever},
5 ResultType,
6 },
7 vector::{EmbeddingGenerator, VectorIndex},
8 GraphRAGError, Result,
9};
10use std::collections::HashMap;
11
12#[derive(Debug, Clone)]
14pub struct HybridSearchResult {
15 pub id: String,
17 pub content: String,
19 pub score: f32,
21 pub semantic_score: f32,
23 pub keyword_score: f32,
25 pub result_type: ResultType,
27 pub entities: Vec<String>,
29 pub source_chunks: Vec<String>,
31 pub fusion_method: FusionMethod,
33}
34
/// Strategy used to merge the semantic ranking and the keyword ranking into a
/// single fused score.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the method can be
/// used as a map or set key (for example to bucket metrics per strategy).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
// `RRF` is the established acronym for reciprocal rank fusion, so the
// acronym-casing lint is silenced rather than renaming a public variant.
#[allow(clippy::upper_case_acronyms)]
pub enum FusionMethod {
    /// Reciprocal rank fusion: every list contributes
    /// `weight / (rrf_k + rank + 1)` for a zero-based `rank`.
    RRF,
    /// Weighted sum of scores that were first divided by the best score of
    /// their own list.
    Weighted,
    /// Plain sum of the raw semantic and keyword scores.
    CombSum,
    /// Larger of the raw semantic and keyword scores.
    MaxScore,
}
48
49#[derive(Debug, Clone)]
51pub struct HybridConfig {
52 pub semantic_weight: f32,
54 pub keyword_weight: f32,
56 pub fusion_method: FusionMethod,
58 pub rrf_k: f32,
60 pub max_candidates: usize,
62 pub min_score_threshold: f32,
64}
65
66impl Default for HybridConfig {
67 fn default() -> Self {
68 Self {
69 semantic_weight: 0.7,
70 keyword_weight: 0.3,
71 fusion_method: FusionMethod::RRF,
72 rrf_k: 60.0,
73 max_candidates: 100,
74 min_score_threshold: 0.1,
75 }
76 }
77}
78
79pub struct HybridRetriever {
81 vector_index: VectorIndex,
83 embedding_generator: EmbeddingGenerator,
85 bm25_retriever: BM25Retriever,
87 config: HybridConfig,
89 initialized: bool,
91}
92
93impl HybridRetriever {
94 pub fn new() -> Self {
96 Self {
97 vector_index: VectorIndex::new(),
98 embedding_generator: EmbeddingGenerator::new(128),
99 bm25_retriever: BM25Retriever::new(),
100 config: HybridConfig::default(),
101 initialized: false,
102 }
103 }
104
105 pub fn with_config(config: HybridConfig) -> Self {
107 Self {
108 vector_index: VectorIndex::new(),
109 embedding_generator: EmbeddingGenerator::new(128),
110 bm25_retriever: BM25Retriever::new(),
111 config,
112 initialized: false,
113 }
114 }
115
116 pub fn initialize_with_graph(&mut self, graph: &KnowledgeGraph) -> Result<()> {
118 for entity in graph.entities() {
120 if let Some(embedding) = &entity.embedding {
121 let id = format!("entity:{}", entity.id);
122 self.vector_index.add_vector(id, embedding.clone())?;
123 }
124 }
125
126 for chunk in graph.chunks() {
127 if let Some(embedding) = &chunk.embedding {
128 let id = format!("chunk:{}", chunk.id);
129 self.vector_index.add_vector(id, embedding.clone())?;
130 }
131 }
132
133 if !self.vector_index.is_empty() {
135 self.vector_index.build_index()?;
136 }
137
138 let mut bm25_documents = Vec::new();
140
141 for entity in graph.entities() {
143 let doc = crate::retrieval::bm25::Document {
144 id: format!("entity:{}", entity.id),
145 content: format!("{} {}", entity.name, entity.entity_type),
146 metadata: HashMap::new(),
147 };
148 bm25_documents.push(doc);
149 }
150
151 for chunk in graph.chunks() {
153 let doc = crate::retrieval::bm25::Document {
154 id: format!("chunk:{}", chunk.id),
155 content: chunk.content.clone(),
156 metadata: HashMap::new(),
157 };
158 bm25_documents.push(doc);
159 }
160
161 self.bm25_retriever.index_documents(&bm25_documents)?;
162 self.initialized = true;
163
164 Ok(())
165 }
166
167 pub fn search(&mut self, query: &str, limit: usize) -> Result<Vec<HybridSearchResult>> {
169 if !self.initialized {
170 return Err(GraphRAGError::Retrieval {
171 message: "Hybrid retriever not initialized. Call initialize_with_graph() first."
172 .to_string(),
173 });
174 }
175
176 let semantic_results = self.semantic_search(query, self.config.max_candidates)?;
178
179 let keyword_results = self.keyword_search(query, self.config.max_candidates);
181
182 let combined_results = self.combine_results(semantic_results, keyword_results, limit)?;
184
185 Ok(combined_results)
186 }
187
188 fn semantic_search(&mut self, query: &str, limit: usize) -> Result<Vec<(String, f32, String)>> {
190 let query_embedding = self.embedding_generator.generate_embedding(query);
191 let similar_vectors = self.vector_index.search(&query_embedding, limit)?;
192
193 let mut results = Vec::new();
194 for (id, score) in similar_vectors {
195 results.push((id.clone(), score, id));
198 }
199
200 Ok(results)
201 }
202
203 fn keyword_search(&self, query: &str, limit: usize) -> Vec<BM25Result> {
205 self.bm25_retriever.search(query, limit)
206 }
207
208 fn combine_results(
210 &mut self,
211 semantic_results: Vec<(String, f32, String)>,
212 keyword_results: Vec<BM25Result>,
213 limit: usize,
214 ) -> Result<Vec<HybridSearchResult>> {
215 match self.config.fusion_method {
216 FusionMethod::RRF => {
217 self.reciprocal_rank_fusion(semantic_results, keyword_results, limit)
218 }
219 FusionMethod::Weighted => {
220 self.weighted_combination(semantic_results, keyword_results, limit)
221 }
222 FusionMethod::CombSum => self.comb_sum_fusion(semantic_results, keyword_results, limit),
223 FusionMethod::MaxScore => {
224 self.max_score_fusion(semantic_results, keyword_results, limit)
225 }
226 }
227 }
228
229 fn reciprocal_rank_fusion(
231 &mut self,
232 semantic_results: Vec<(String, f32, String)>,
233 keyword_results: Vec<BM25Result>,
234 limit: usize,
235 ) -> Result<Vec<HybridSearchResult>> {
236 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
237 let mut content_map: HashMap<String, String> = HashMap::new();
238
239 for (rank, (id, score, content)) in semantic_results.iter().enumerate() {
241 let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
242 combined_scores.insert(
243 id.clone(),
244 (rrf_score * self.config.semantic_weight, *score, 0.0),
245 );
246 content_map.insert(id.clone(), content.clone());
247 }
248
249 for (rank, result) in keyword_results.iter().enumerate() {
251 let rrf_score = 1.0 / (self.config.rrf_k + rank as f32 + 1.0);
252 let entry = combined_scores
253 .entry(result.doc_id.clone())
254 .or_insert((0.0, 0.0, 0.0));
255 entry.0 += rrf_score * self.config.keyword_weight;
256 entry.2 = result.score;
257 content_map.insert(result.doc_id.clone(), result.content.clone());
258 }
259
260 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::RRF)
261 }
262
263 fn weighted_combination(
265 &mut self,
266 semantic_results: Vec<(String, f32, String)>,
267 keyword_results: Vec<BM25Result>,
268 limit: usize,
269 ) -> Result<Vec<HybridSearchResult>> {
270 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
271 let mut content_map: HashMap<String, String> = HashMap::new();
272
273 let max_semantic = semantic_results
275 .iter()
276 .map(|(_, score, _)| *score)
277 .fold(f32::NEG_INFINITY, f32::max);
278
279 for (id, score, content) in semantic_results {
280 let normalized_score = if max_semantic > 0.0 {
281 score / max_semantic
282 } else {
283 0.0
284 };
285 combined_scores.insert(
286 id.clone(),
287 (normalized_score * self.config.semantic_weight, score, 0.0),
288 );
289 content_map.insert(id, content);
290 }
291
292 let max_keyword = keyword_results
294 .iter()
295 .map(|r| r.score)
296 .fold(f32::NEG_INFINITY, f32::max);
297
298 for result in keyword_results {
299 let normalized_score = if max_keyword > 0.0 {
300 result.score / max_keyword
301 } else {
302 0.0
303 };
304 let entry = combined_scores
305 .entry(result.doc_id.clone())
306 .or_insert((0.0, 0.0, 0.0));
307 entry.0 += normalized_score * self.config.keyword_weight;
308 entry.2 = result.score;
309 content_map.insert(result.doc_id.clone(), result.content.clone());
310 }
311
312 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::Weighted)
313 }
314
315 fn comb_sum_fusion(
317 &mut self,
318 semantic_results: Vec<(String, f32, String)>,
319 keyword_results: Vec<BM25Result>,
320 limit: usize,
321 ) -> Result<Vec<HybridSearchResult>> {
322 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
323 let mut content_map: HashMap<String, String> = HashMap::new();
324
325 for (id, score, content) in semantic_results {
327 combined_scores.insert(id.clone(), (score, score, 0.0));
328 content_map.insert(id, content);
329 }
330
331 for result in keyword_results {
333 let entry = combined_scores
334 .entry(result.doc_id.clone())
335 .or_insert((0.0, 0.0, 0.0));
336 entry.0 += result.score;
337 entry.2 = result.score;
338 content_map.insert(result.doc_id.clone(), result.content.clone());
339 }
340
341 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::CombSum)
342 }
343
344 fn max_score_fusion(
346 &mut self,
347 semantic_results: Vec<(String, f32, String)>,
348 keyword_results: Vec<BM25Result>,
349 limit: usize,
350 ) -> Result<Vec<HybridSearchResult>> {
351 let mut combined_scores: HashMap<String, (f32, f32, f32)> = HashMap::new();
352 let mut content_map: HashMap<String, String> = HashMap::new();
353
354 for (id, score, content) in semantic_results {
356 combined_scores.insert(id.clone(), (score, score, 0.0));
357 content_map.insert(id, content);
358 }
359
360 for result in keyword_results {
362 let entry = combined_scores
363 .entry(result.doc_id.clone())
364 .or_insert((0.0, 0.0, 0.0));
365 entry.0 = entry.0.max(result.score);
366 entry.2 = result.score;
367 content_map.insert(result.doc_id.clone(), result.content.clone());
368 }
369
370 self.create_hybrid_results(combined_scores, content_map, limit, FusionMethod::MaxScore)
371 }
372
373 fn create_hybrid_results(
375 &self,
376 combined_scores: HashMap<String, (f32, f32, f32)>,
377 content_map: HashMap<String, String>,
378 limit: usize,
379 fusion_method: FusionMethod,
380 ) -> Result<Vec<HybridSearchResult>> {
381 let mut results: Vec<HybridSearchResult> = combined_scores
382 .into_iter()
383 .filter_map(|(id, (combined_score, semantic_score, keyword_score))| {
384 if combined_score >= self.config.min_score_threshold {
385 let content = content_map.get(&id).cloned().unwrap_or_else(|| id.clone());
386
387 let result_type = if id.starts_with("entity:") {
389 ResultType::Entity
390 } else if id.starts_with("chunk:") {
391 ResultType::Chunk
392 } else {
393 ResultType::Hybrid
394 };
395
396 let entities = if result_type == ResultType::Entity {
398 vec![content.clone()]
399 } else {
400 Vec::new()
401 };
402
403 Some(HybridSearchResult {
404 id: id.clone(),
405 content,
406 score: combined_score,
407 semantic_score,
408 keyword_score,
409 result_type,
410 entities,
411 source_chunks: vec![id],
412 fusion_method: fusion_method.clone(),
413 })
414 } else {
415 None
416 }
417 })
418 .collect();
419
420 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
422 results.truncate(limit);
423
424 Ok(results)
425 }
426
427 pub fn get_config(&self) -> &HybridConfig {
429 &self.config
430 }
431
432 pub fn set_config(&mut self, config: HybridConfig) {
434 self.config = config;
435 }
436
437 pub fn is_initialized(&self) -> bool {
439 self.initialized
440 }
441
442 pub fn get_statistics(&self) -> HybridStatistics {
444 let vector_stats = self.vector_index.statistics();
445 let bm25_stats = self.bm25_retriever.get_statistics();
446
447 HybridStatistics {
448 vector_count: vector_stats.vector_count,
449 bm25_document_count: bm25_stats.total_documents,
450 bm25_term_count: bm25_stats.total_terms,
451 config: self.config.clone(),
452 initialized: self.initialized,
453 }
454 }
455
456 pub fn clear(&mut self) {
458 self.vector_index = VectorIndex::new();
459 self.bm25_retriever.clear();
460 self.initialized = false;
461 }
462}
463
464impl Default for HybridRetriever {
465 fn default() -> Self {
466 Self::new()
467 }
468}
469
470#[derive(Debug, Clone)]
472pub struct HybridStatistics {
473 pub vector_count: usize,
475 pub bm25_document_count: usize,
477 pub bm25_term_count: usize,
479 pub config: HybridConfig,
481 pub initialized: bool,
483}
484
485impl HybridStatistics {
486 pub fn print(&self) {
488 println!("Hybrid Retriever Statistics:");
489 println!(" Initialized: {}", self.initialized);
490 println!(" Vector index: {} vectors", self.vector_count);
491 println!(
492 " BM25 index: {} documents, {} terms",
493 self.bm25_document_count, self.bm25_term_count
494 );
495 println!(" Fusion method: {:?}", self.config.fusion_method);
496 println!(
497 " Weights: semantic={:.2}, keyword={:.2}",
498 self.config.semantic_weight, self.config.keyword_weight
499 );
500 println!(" Score threshold: {:.3}", self.config.min_score_threshold);
501 }
502}
503
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::KnowledgeGraph;

    /// A freshly built retriever must not claim to be initialized.
    #[test]
    fn test_hybrid_retriever_creation() {
        assert!(!HybridRetriever::new().is_initialized());
    }

    /// The default configuration uses RRF with a 0.7 / 0.3 weighting.
    #[test]
    fn test_hybrid_config_default() {
        let HybridConfig {
            semantic_weight,
            keyword_weight,
            fusion_method,
            ..
        } = HybridConfig::default();

        assert_eq!(semantic_weight, 0.7);
        assert_eq!(keyword_weight, 0.3);
        assert_eq!(fusion_method, FusionMethod::RRF);
    }

    /// Fusion methods compare equal to themselves and differ from each other.
    #[test]
    fn test_fusion_method_variants() {
        let rrf = FusionMethod::RRF;
        assert_eq!(rrf, FusionMethod::RRF);
        assert_ne!(rrf, FusionMethod::Weighted);
    }

    /// Initializing from an empty graph succeeds and flips the flag.
    #[test]
    fn test_hybrid_retriever_with_empty_graph() {
        let graph = KnowledgeGraph::new();
        let mut retriever = HybridRetriever::new();

        assert!(retriever.initialize_with_graph(&graph).is_ok());
        assert!(retriever.is_initialized());
    }

    /// Searching before initialization is reported as an error.
    #[test]
    fn test_search_without_initialization() {
        let mut retriever = HybridRetriever::new();
        assert!(retriever.search("test", 10).is_err());
    }
}