Skip to main content

graphrag_core/retrieval/
enriched.rs

1//! Enriched metadata-aware retrieval
2//!
3//! This module provides retrieval strategies that leverage enriched chunk metadata
4//! (chapters, sections, keywords, summaries) to improve search relevance and precision.
5
6use crate::{
7    core::{KnowledgeGraph, TextChunk},
8    retrieval::{QueryAnalysis, ResultType, SearchResult},
9    Result,
10};
11use std::collections::{HashMap, HashSet};
12
/// Tunable weights and switches for metadata-aware retrieval
#[derive(Debug, Clone)]
pub struct EnrichedRetrievalConfig {
    /// Multiplier applied to the keyword-overlap ratio (0.0 to 1.0)
    pub keyword_match_weight: f32,
    /// Multiplier applied to the chapter/section structure boost (0.0 to 1.0)
    pub structure_match_weight: f32,
    /// Flat boost added when a chunk summary matches the query (0.0 to 1.0)
    pub summary_weight: f32,
    /// Smallest number of matching keywords that earns a keyword boost
    pub min_keyword_matches: usize,
    /// Gates the chapter/section boost applied by `metadata_search`
    pub enable_structure_filtering: bool,
}

impl Default for EnrichedRetrievalConfig {
    /// Keyword weight 0.3, structure weight 0.2, summary weight 0.15,
    /// one keyword match required, structure matching enabled.
    fn default() -> Self {
        EnrichedRetrievalConfig {
            enable_structure_filtering: true,
            min_keyword_matches: 1,
            summary_weight: 0.15,
            structure_match_weight: 0.2,
            keyword_match_weight: 0.3,
        }
    }
}
39
40/// Metadata-enhanced retrieval strategies
41pub struct EnrichedRetriever {
42    config: EnrichedRetrievalConfig,
43}
44
45impl EnrichedRetriever {
46    /// Create a new enriched retriever
47    pub fn new() -> Self {
48        Self {
49            config: EnrichedRetrievalConfig::default(),
50        }
51    }
52
53    /// Create with custom configuration
54    pub fn with_config(config: EnrichedRetrievalConfig) -> Self {
55        Self { config }
56    }
57
58    /// Search chunks using enriched metadata
59    ///
60    /// This method boosts chunks that match:
61    /// 1. Query keywords present in chunk keywords (TF-IDF extracted)
62    /// 2. Chapter/Section mentioned in query
63    /// 3. Summary content relevant to query
64    pub fn metadata_search(
65        &self,
66        query: &str,
67        graph: &KnowledgeGraph,
68        _analysis: &QueryAnalysis,
69        base_results: &[SearchResult],
70    ) -> Result<Vec<SearchResult>> {
71        let mut enriched_results = Vec::new();
72
73        // Extract query keywords and potential structure references
74        let query_lower = query.to_lowercase();
75        let query_words: HashSet<String> = query_lower
76            .split_whitespace()
77            .filter(|w| w.len() > 3)
78            .map(|w| w.to_string())
79            .collect();
80
81        // Detect chapter/section references in query
82        let structure_refs = self.extract_structure_references(&query_lower);
83
84        // Process each chunk in the graph
85        for chunk in graph.chunks() {
86            if !chunk.entities.is_empty() || !chunk.metadata.keywords.is_empty() {
87                let mut base_score = self.find_base_score(chunk, base_results);
88                let mut metadata_boost = 0.0;
89
90                // 1. KEYWORD MATCHING BOOST
91                let keyword_matches =
92                    self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
93                if keyword_matches >= self.config.min_keyword_matches {
94                    let keyword_boost = (keyword_matches as f32 / query_words.len().max(1) as f32)
95                        * self.config.keyword_match_weight;
96                    metadata_boost += keyword_boost;
97                }
98
99                // 2. STRUCTURE MATCHING BOOST (Chapter/Section)
100                if self.config.enable_structure_filtering {
101                    if let Some(structure_boost) =
102                        self.calculate_structure_boost(chunk, &structure_refs)
103                    {
104                        metadata_boost += structure_boost * self.config.structure_match_weight;
105                    }
106                }
107
108                // 3. SUMMARY RELEVANCE BOOST
109                if let Some(summary) = &chunk.metadata.summary {
110                    if self.matches_query(summary, &query_words) {
111                        metadata_boost += self.config.summary_weight;
112                    }
113                }
114
115                // 4. COMPLETENESS BONUS
116                let completeness = chunk.metadata.completeness_score();
117                if completeness > 0.7 {
118                    metadata_boost += 0.05; // Small bonus for high-quality metadata
119                }
120
121                // Apply boost only if significant
122                if metadata_boost > 0.05 {
123                    base_score = (base_score + metadata_boost).min(1.0);
124
125                    enriched_results.push(SearchResult {
126                        id: chunk.id.to_string(),
127                        content: chunk.content.clone(),
128                        score: base_score,
129                        result_type: ResultType::Chunk,
130                        entities: chunk
131                            .entities
132                            .iter()
133                            .filter_map(|eid| graph.get_entity(eid))
134                            .map(|e| e.name.clone())
135                            .collect(),
136                        source_chunks: vec![chunk.id.to_string()],
137                    });
138                }
139            }
140        }
141
142        Ok(enriched_results)
143    }
144
145    /// Filter chunks by chapter or section
146    ///
147    /// Example: "What does Socrates say in Chapter 1?" -> filter to Chapter 1 chunks
148    pub fn filter_by_structure(
149        &self,
150        query: &str,
151        results: Vec<SearchResult>,
152        graph: &KnowledgeGraph,
153    ) -> Result<Vec<SearchResult>> {
154        let structure_refs = self.extract_structure_references(&query.to_lowercase());
155
156        if structure_refs.is_empty() {
157            return Ok(results);
158        }
159
160        let filtered: Vec<SearchResult> = results
161            .into_iter()
162            .filter(|result| {
163                // Get chunk metadata
164                if let Some(chunk_id) = result.source_chunks.first() {
165                    if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
166                        return self.matches_structure(&chunk.metadata, &structure_refs);
167                    }
168                }
169                true // Keep results without structure metadata
170            })
171            .collect();
172
173        Ok(filtered)
174    }
175
176    /// Boost results based on enriched metadata
177    pub fn boost_with_metadata(
178        &self,
179        mut results: Vec<SearchResult>,
180        query: &str,
181        graph: &KnowledgeGraph,
182    ) -> Result<Vec<SearchResult>> {
183        let query_words: HashSet<String> = query
184            .to_lowercase()
185            .split_whitespace()
186            .filter(|w| w.len() > 3)
187            .map(|w| w.to_string())
188            .collect();
189
190        for result in &mut results {
191            if let Some(chunk_id) = result.source_chunks.first() {
192                if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
193                    // Boost based on keyword matches
194                    let keyword_matches =
195                        self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
196                    if keyword_matches > 0 {
197                        let boost =
198                            (keyword_matches as f32 / query_words.len().max(1) as f32) * 0.2;
199                        result.score = (result.score + boost).min(1.0);
200                    }
201
202                    // Boost if chapter/section matches query context
203                    if let Some(chapter) = &chunk.metadata.chapter {
204                        if query.to_lowercase().contains(&chapter.to_lowercase()) {
205                            result.score = (result.score + 0.15).min(1.0);
206                        }
207                    }
208
209                    if let Some(section) = &chunk.metadata.section {
210                        if query.to_lowercase().contains(&section.to_lowercase()) {
211                            result.score = (result.score + 0.1).min(1.0);
212                        }
213                    }
214                }
215            }
216        }
217
218        // Re-sort after boosting
219        results.sort_by(|a, b| {
220            b.score
221                .partial_cmp(&a.score)
222                .unwrap_or(std::cmp::Ordering::Equal)
223        });
224
225        Ok(results)
226    }
227
228    /// Get chunks from a specific chapter
229    pub fn get_chapter_chunks<'a>(
230        &self,
231        chapter_name: &str,
232        graph: &'a KnowledgeGraph,
233    ) -> Vec<&'a TextChunk> {
234        graph
235            .chunks()
236            .filter(|chunk| {
237                if let Some(ch) = &chunk.metadata.chapter {
238                    ch.to_lowercase().contains(&chapter_name.to_lowercase())
239                } else {
240                    false
241                }
242            })
243            .collect()
244    }
245
246    /// Get chunks from a specific section
247    pub fn get_section_chunks<'a>(
248        &self,
249        section_name: &str,
250        graph: &'a KnowledgeGraph,
251    ) -> Vec<&'a TextChunk> {
252        graph
253            .chunks()
254            .filter(|chunk| {
255                if let Some(sec) = &chunk.metadata.section {
256                    sec.to_lowercase().contains(&section_name.to_lowercase())
257                } else {
258                    false
259                }
260            })
261            .collect()
262    }
263
264    /// Search by keywords extracted from chunks
265    pub fn search_by_keywords(
266        &self,
267        keywords: &[String],
268        graph: &KnowledgeGraph,
269        top_k: usize,
270    ) -> Vec<SearchResult> {
271        let mut keyword_scores: HashMap<String, (f32, &TextChunk)> = HashMap::new();
272
273        for chunk in graph.chunks() {
274            let mut score = 0.0;
275            for keyword in keywords {
276                if chunk
277                    .metadata
278                    .keywords
279                    .iter()
280                    .any(|k| k.eq_ignore_ascii_case(keyword))
281                {
282                    score += 1.0 / keywords.len() as f32;
283                }
284            }
285
286            if score > 0.0 {
287                keyword_scores.insert(chunk.id.to_string(), (score, chunk));
288            }
289        }
290
291        let mut sorted_results: Vec<_> = keyword_scores.into_iter().collect();
292        sorted_results.sort_by(|a, b| {
293            b.1 .0
294                .partial_cmp(&a.1 .0)
295                .unwrap_or(std::cmp::Ordering::Equal)
296        });
297
298        sorted_results
299            .into_iter()
300            .take(top_k)
301            .map(|(chunk_id, (score, chunk))| SearchResult {
302                id: chunk_id.clone(),
303                content: chunk.content.clone(),
304                score,
305                result_type: ResultType::Chunk,
306                entities: chunk
307                    .entities
308                    .iter()
309                    .filter_map(|eid| graph.get_entity(eid))
310                    .map(|e| e.name.clone())
311                    .collect(),
312                source_chunks: vec![chunk_id],
313            })
314            .collect()
315    }
316
317    // === HELPER METHODS ===
318
319    /// Count matching keywords between chunk and query
320    fn count_keyword_matches(
321        &self,
322        chunk_keywords: &[String],
323        query_words: &HashSet<String>,
324    ) -> usize {
325        chunk_keywords
326            .iter()
327            .filter(|k| query_words.contains(&k.to_lowercase()))
328            .count()
329    }
330
331    /// Find base score from existing results
332    fn find_base_score(&self, chunk: &TextChunk, base_results: &[SearchResult]) -> f32 {
333        base_results
334            .iter()
335            .find(|r| r.source_chunks.contains(&chunk.id.to_string()))
336            .map(|r| r.score)
337            .unwrap_or(0.5) // Default moderate score
338    }
339
340    /// Extract chapter/section references from query
341    fn extract_structure_references(&self, query_lower: &str) -> Vec<String> {
342        let mut refs = Vec::new();
343
344        // Detect "chapter X" or "section Y" patterns
345        let patterns = [
346            r"chapter\s+(\d+|[ivxlcdm]+|\w+)",
347            r"section\s+(\d+\.?\d*)",
348            r"part\s+(\d+|[ivxlcdm]+)",
349        ];
350
351        for pattern in &patterns {
352            if let Some(captures) = regex::Regex::new(pattern)
353                .ok()
354                .and_then(|re| re.captures(query_lower))
355            {
356                if let Some(matched) = captures.get(0) {
357                    refs.push(matched.as_str().to_string());
358                }
359            }
360        }
361
362        // Also check for direct mentions like "Introduction", "Conclusion"
363        for word in query_lower.split_whitespace() {
364            if word.chars().next().is_some_and(|c| c.is_uppercase()) && word.len() > 5 {
365                refs.push(word.to_string());
366            }
367        }
368
369        refs
370    }
371
372    /// Calculate structure boost for chunk
373    fn calculate_structure_boost(
374        &self,
375        chunk: &TextChunk,
376        structure_refs: &[String],
377    ) -> Option<f32> {
378        if structure_refs.is_empty() {
379            return None;
380        }
381
382        let mut boost = 0.0;
383
384        for reference in structure_refs {
385            let ref_lower = reference.to_lowercase();
386
387            if let Some(chapter) = &chunk.metadata.chapter {
388                if chapter.to_lowercase().contains(&ref_lower) {
389                    boost += 0.5;
390                }
391            }
392
393            if let Some(section) = &chunk.metadata.section {
394                if section.to_lowercase().contains(&ref_lower) {
395                    boost += 0.3;
396                }
397            }
398
399            if let Some(subsection) = &chunk.metadata.subsection {
400                if subsection.to_lowercase().contains(&ref_lower) {
401                    boost += 0.2;
402                }
403            }
404        }
405
406        if boost > 0.0 {
407            Some(boost)
408        } else {
409            None
410        }
411    }
412
413    /// Check if text matches query words
414    fn matches_query(&self, text: &str, query_words: &HashSet<String>) -> bool {
415        let text_lower = text.to_lowercase();
416        query_words
417            .iter()
418            .filter(|word| text_lower.contains(word.as_str()))
419            .count()
420            >= (query_words.len() / 2).max(1)
421    }
422
423    /// Check if chunk metadata matches structure references
424    fn matches_structure(
425        &self,
426        metadata: &crate::core::ChunkMetadata,
427        structure_refs: &[String],
428    ) -> bool {
429        for reference in structure_refs {
430            let ref_lower = reference.to_lowercase();
431
432            if let Some(chapter) = &metadata.chapter {
433                if chapter.to_lowercase().contains(&ref_lower) {
434                    return true;
435                }
436            }
437
438            if let Some(section) = &metadata.section {
439                if section.to_lowercase().contains(&ref_lower) {
440                    return true;
441                }
442            }
443
444            if let Some(subsection) = &metadata.subsection {
445                if subsection.to_lowercase().contains(&ref_lower) {
446                    return true;
447                }
448            }
449        }
450
451        false
452    }
453}
454
455impl Default for EnrichedRetriever {
456    fn default() -> Self {
457        Self::new()
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, ChunkMetadata, DocumentId, KnowledgeGraph, TextChunk};

    /// Build a chunk of document `test_doc` that spans all of `content` and
    /// carries the given keywords and chapter in its metadata.
    fn create_test_chunk(
        id: &str,
        content: &str,
        keywords: Vec<String>,
        chapter: Option<String>,
    ) -> TextChunk {
        let mut metadata = ChunkMetadata::new();
        metadata.chapter = chapter;
        metadata.keywords = keywords;

        let mut chunk = TextChunk::new(
            ChunkId::new(id.to_string()),
            DocumentId::new("test_doc".to_string()),
            content.to_string(),
            0,
            content.len(),
        );
        chunk.metadata = metadata;

        chunk
    }

    /// Two of the three chunk keywords are query words
    #[test]
    fn test_keyword_matching() {
        let chunk_keywords: Vec<String> = ["machine", "learning", "neural"]
            .iter()
            .map(|k| k.to_string())
            .collect();
        let query_words: HashSet<String> = ["machine", "learning"]
            .iter()
            .map(|w| w.to_string())
            .collect();

        let retriever = EnrichedRetriever::new();
        let matches = retriever.count_keyword_matches(&chunk_keywords, &query_words);

        assert_eq!(matches, 2);
    }

    /// A query that mentions "chapter 1" yields at least one structure reference
    #[test]
    fn test_structure_extraction() {
        let query = "What does Socrates say in chapter 1?";
        let lowered = query.to_lowercase();

        let refs = EnrichedRetriever::new().extract_structure_references(&lowered);

        assert!(!refs.is_empty());
    }

    /// Only the chunk from "Chapter 1: Introduction" is returned for "Chapter 1"
    #[test]
    fn test_chapter_filtering() {
        let mut graph = KnowledgeGraph::new();

        let fixtures = [
            ("chunk1", "Content from chapter 1", "Chapter 1: Introduction"),
            ("chunk2", "Content from chapter 2", "Chapter 2: Methods"),
        ];
        for (id, content, chapter) in fixtures {
            let chunk = create_test_chunk(
                id,
                content,
                vec!["content".to_string()],
                Some(chapter.to_string()),
            );
            let _ = graph.add_chunk(chunk);
        }

        let retriever = EnrichedRetriever::new();
        let chapter1_chunks = retriever.get_chapter_chunks("Chapter 1", &graph);

        assert_eq!(chapter1_chunks.len(), 1);
    }
}