// graphrag_core/retrieval/enriched.rs
//! Enriched metadata-aware retrieval
//!
//! This module provides retrieval strategies that leverage enriched chunk metadata
//! (chapters, sections, keywords, summaries) to improve search relevance and precision.
use std::collections::{HashMap, HashSet};

use crate::{
    core::{KnowledgeGraph, TextChunk},
    retrieval::{QueryAnalysis, ResultType, SearchResult},
    Result,
};
/// Configuration for enriched metadata retrieval.
///
/// All weights are additive boost factors applied on top of a base score;
/// the final score is clamped to 1.0 by the retriever.
#[derive(Debug, Clone)]
pub struct EnrichedRetrievalConfig {
    /// Weight for keyword matching (0.0 to 1.0)
    pub keyword_match_weight: f32,
    /// Weight for chapter/section context matching (0.0 to 1.0)
    pub structure_match_weight: f32,
    /// Weight for summary relevance (0.0 to 1.0)
    pub summary_weight: f32,
    /// Minimum number of keywords to match for boosting
    pub min_keyword_matches: usize,
    /// Enable chapter/section filtering
    pub enable_structure_filtering: bool,
}
27
28impl Default for EnrichedRetrievalConfig {
29    fn default() -> Self {
30        Self {
31            keyword_match_weight: 0.3,
32            structure_match_weight: 0.2,
33            summary_weight: 0.15,
34            min_keyword_matches: 1,
35            enable_structure_filtering: true,
36        }
37    }
38}
39
40/// Metadata-enhanced retrieval strategies
41pub struct EnrichedRetriever {
42    config: EnrichedRetrievalConfig,
43}
44
45impl EnrichedRetriever {
46    /// Create a new enriched retriever
47    pub fn new() -> Self {
48        Self {
49            config: EnrichedRetrievalConfig::default(),
50        }
51    }
52
53    /// Create with custom configuration
54    pub fn with_config(config: EnrichedRetrievalConfig) -> Self {
55        Self { config }
56    }
57
58    /// Search chunks using enriched metadata
59    ///
60    /// This method boosts chunks that match:
61    /// 1. Query keywords present in chunk keywords (TF-IDF extracted)
62    /// 2. Chapter/Section mentioned in query
63    /// 3. Summary content relevant to query
64    pub fn metadata_search(
65        &self,
66        query: &str,
67        graph: &KnowledgeGraph,
68        _analysis: &QueryAnalysis,
69        base_results: &[SearchResult],
70    ) -> Result<Vec<SearchResult>> {
71        let mut enriched_results = Vec::new();
72
73        // Extract query keywords and potential structure references
74        let query_lower = query.to_lowercase();
75        let query_words: HashSet<String> = query_lower
76            .split_whitespace()
77            .filter(|w| w.len() > 3)
78            .map(|w| w.to_string())
79            .collect();
80
81        // Detect chapter/section references in query
82        let structure_refs = self.extract_structure_references(&query_lower);
83
84        // Process each chunk in the graph
85        for chunk in graph.chunks() {
86            if !chunk.entities.is_empty() || !chunk.metadata.keywords.is_empty() {
87                let mut base_score = self.find_base_score(chunk, base_results);
88                let mut metadata_boost = 0.0;
89
90                // 1. KEYWORD MATCHING BOOST
91                let keyword_matches =
92                    self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
93                if keyword_matches >= self.config.min_keyword_matches {
94                    let keyword_boost = (keyword_matches as f32 / query_words.len().max(1) as f32)
95                        * self.config.keyword_match_weight;
96                    metadata_boost += keyword_boost;
97                }
98
99                // 2. STRUCTURE MATCHING BOOST (Chapter/Section)
100                if self.config.enable_structure_filtering {
101                    if let Some(structure_boost) =
102                        self.calculate_structure_boost(chunk, &structure_refs)
103                    {
104                        metadata_boost += structure_boost * self.config.structure_match_weight;
105                    }
106                }
107
108                // 3. SUMMARY RELEVANCE BOOST
109                if let Some(summary) = &chunk.metadata.summary {
110                    if self.matches_query(summary, &query_words) {
111                        metadata_boost += self.config.summary_weight;
112                    }
113                }
114
115                // 4. COMPLETENESS BONUS
116                let completeness = chunk.metadata.completeness_score();
117                if completeness > 0.7 {
118                    metadata_boost += 0.05; // Small bonus for high-quality metadata
119                }
120
121                // Apply boost only if significant
122                if metadata_boost > 0.05 {
123                    base_score = (base_score + metadata_boost).min(1.0);
124
125                    enriched_results.push(SearchResult {
126                        id: chunk.id.to_string(),
127                        content: chunk.content.clone(),
128                        score: base_score,
129                        result_type: ResultType::Chunk,
130                        entities: chunk
131                            .entities
132                            .iter()
133                            .filter_map(|eid| graph.get_entity(eid))
134                            .map(|e| e.name.clone())
135                            .collect(),
136                        source_chunks: vec![chunk.id.to_string()],
137                    });
138                }
139            }
140        }
141
142        Ok(enriched_results)
143    }
144
145    /// Filter chunks by chapter or section
146    ///
147    /// Example: "What does Socrates say in Chapter 1?" -> filter to Chapter 1 chunks
148    pub fn filter_by_structure(
149        &self,
150        query: &str,
151        results: Vec<SearchResult>,
152        graph: &KnowledgeGraph,
153    ) -> Result<Vec<SearchResult>> {
154        let structure_refs = self.extract_structure_references(&query.to_lowercase());
155
156        if structure_refs.is_empty() {
157            return Ok(results);
158        }
159
160        let filtered: Vec<SearchResult> = results
161            .into_iter()
162            .filter(|result| {
163                // Get chunk metadata
164                if let Some(chunk_id) = result.source_chunks.first() {
165                    if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
166                        return self.matches_structure(&chunk.metadata, &structure_refs);
167                    }
168                }
169                true // Keep results without structure metadata
170            })
171            .collect();
172
173        Ok(filtered)
174    }
175
176    /// Boost results based on enriched metadata
177    pub fn boost_with_metadata(
178        &self,
179        mut results: Vec<SearchResult>,
180        query: &str,
181        graph: &KnowledgeGraph,
182    ) -> Result<Vec<SearchResult>> {
183        let query_words: HashSet<String> = query
184            .to_lowercase()
185            .split_whitespace()
186            .filter(|w| w.len() > 3)
187            .map(|w| w.to_string())
188            .collect();
189
190        for result in &mut results {
191            if let Some(chunk_id) = result.source_chunks.first() {
192                if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
193                    // Boost based on keyword matches
194                    let keyword_matches =
195                        self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
196                    if keyword_matches > 0 {
197                        let boost =
198                            (keyword_matches as f32 / query_words.len().max(1) as f32) * 0.2;
199                        result.score = (result.score + boost).min(1.0);
200                    }
201
202                    // Boost if chapter/section matches query context
203                    if let Some(chapter) = &chunk.metadata.chapter {
204                        if query.to_lowercase().contains(&chapter.to_lowercase()) {
205                            result.score = (result.score + 0.15).min(1.0);
206                        }
207                    }
208
209                    if let Some(section) = &chunk.metadata.section {
210                        if query.to_lowercase().contains(&section.to_lowercase()) {
211                            result.score = (result.score + 0.1).min(1.0);
212                        }
213                    }
214                }
215            }
216        }
217
218        // Re-sort after boosting
219        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
220
221        Ok(results)
222    }
223
224    /// Get chunks from a specific chapter
225    pub fn get_chapter_chunks<'a>(
226        &self,
227        chapter_name: &str,
228        graph: &'a KnowledgeGraph,
229    ) -> Vec<&'a TextChunk> {
230        graph
231            .chunks()
232            .filter(|chunk| {
233                if let Some(ch) = &chunk.metadata.chapter {
234                    ch.to_lowercase().contains(&chapter_name.to_lowercase())
235                } else {
236                    false
237                }
238            })
239            .collect()
240    }
241
242    /// Get chunks from a specific section
243    pub fn get_section_chunks<'a>(
244        &self,
245        section_name: &str,
246        graph: &'a KnowledgeGraph,
247    ) -> Vec<&'a TextChunk> {
248        graph
249            .chunks()
250            .filter(|chunk| {
251                if let Some(sec) = &chunk.metadata.section {
252                    sec.to_lowercase().contains(&section_name.to_lowercase())
253                } else {
254                    false
255                }
256            })
257            .collect()
258    }
259
260    /// Search by keywords extracted from chunks
261    pub fn search_by_keywords(
262        &self,
263        keywords: &[String],
264        graph: &KnowledgeGraph,
265        top_k: usize,
266    ) -> Vec<SearchResult> {
267        let mut keyword_scores: HashMap<String, (f32, &TextChunk)> = HashMap::new();
268
269        for chunk in graph.chunks() {
270            let mut score = 0.0;
271            for keyword in keywords {
272                if chunk
273                    .metadata
274                    .keywords
275                    .iter()
276                    .any(|k| k.eq_ignore_ascii_case(keyword))
277                {
278                    score += 1.0 / keywords.len() as f32;
279                }
280            }
281
282            if score > 0.0 {
283                keyword_scores.insert(chunk.id.to_string(), (score, chunk));
284            }
285        }
286
287        let mut sorted_results: Vec<_> = keyword_scores.into_iter().collect();
288        sorted_results.sort_by(|a, b| b.1 .0.partial_cmp(&a.1 .0).unwrap());
289
290        sorted_results
291            .into_iter()
292            .take(top_k)
293            .map(|(chunk_id, (score, chunk))| SearchResult {
294                id: chunk_id.clone(),
295                content: chunk.content.clone(),
296                score,
297                result_type: ResultType::Chunk,
298                entities: chunk
299                    .entities
300                    .iter()
301                    .filter_map(|eid| graph.get_entity(eid))
302                    .map(|e| e.name.clone())
303                    .collect(),
304                source_chunks: vec![chunk_id],
305            })
306            .collect()
307    }
308
309    // === HELPER METHODS ===
310
311    /// Count matching keywords between chunk and query
312    fn count_keyword_matches(
313        &self,
314        chunk_keywords: &[String],
315        query_words: &HashSet<String>,
316    ) -> usize {
317        chunk_keywords
318            .iter()
319            .filter(|k| query_words.contains(&k.to_lowercase()))
320            .count()
321    }
322
323    /// Find base score from existing results
324    fn find_base_score(&self, chunk: &TextChunk, base_results: &[SearchResult]) -> f32 {
325        base_results
326            .iter()
327            .find(|r| r.source_chunks.contains(&chunk.id.to_string()))
328            .map(|r| r.score)
329            .unwrap_or(0.5) // Default moderate score
330    }
331
332    /// Extract chapter/section references from query
333    fn extract_structure_references(&self, query_lower: &str) -> Vec<String> {
334        let mut refs = Vec::new();
335
336        // Detect "chapter X" or "section Y" patterns
337        let patterns = [
338            r"chapter\s+(\d+|[ivxlcdm]+|\w+)",
339            r"section\s+(\d+\.?\d*)",
340            r"part\s+(\d+|[ivxlcdm]+)",
341        ];
342
343        for pattern in &patterns {
344            if let Some(captures) = regex::Regex::new(pattern)
345                .ok()
346                .and_then(|re| re.captures(query_lower))
347            {
348                if let Some(matched) = captures.get(0) {
349                    refs.push(matched.as_str().to_string());
350                }
351            }
352        }
353
354        // Also check for direct mentions like "Introduction", "Conclusion"
355        for word in query_lower.split_whitespace() {
356            if word.chars().next().is_some_and(|c| c.is_uppercase()) && word.len() > 5 {
357                refs.push(word.to_string());
358            }
359        }
360
361        refs
362    }
363
364    /// Calculate structure boost for chunk
365    fn calculate_structure_boost(
366        &self,
367        chunk: &TextChunk,
368        structure_refs: &[String],
369    ) -> Option<f32> {
370        if structure_refs.is_empty() {
371            return None;
372        }
373
374        let mut boost = 0.0;
375
376        for reference in structure_refs {
377            let ref_lower = reference.to_lowercase();
378
379            if let Some(chapter) = &chunk.metadata.chapter {
380                if chapter.to_lowercase().contains(&ref_lower) {
381                    boost += 0.5;
382                }
383            }
384
385            if let Some(section) = &chunk.metadata.section {
386                if section.to_lowercase().contains(&ref_lower) {
387                    boost += 0.3;
388                }
389            }
390
391            if let Some(subsection) = &chunk.metadata.subsection {
392                if subsection.to_lowercase().contains(&ref_lower) {
393                    boost += 0.2;
394                }
395            }
396        }
397
398        if boost > 0.0 {
399            Some(boost)
400        } else {
401            None
402        }
403    }
404
405    /// Check if text matches query words
406    fn matches_query(&self, text: &str, query_words: &HashSet<String>) -> bool {
407        let text_lower = text.to_lowercase();
408        query_words
409            .iter()
410            .filter(|word| text_lower.contains(word.as_str()))
411            .count()
412            >= (query_words.len() / 2).max(1)
413    }
414
415    /// Check if chunk metadata matches structure references
416    fn matches_structure(
417        &self,
418        metadata: &crate::core::ChunkMetadata,
419        structure_refs: &[String],
420    ) -> bool {
421        for reference in structure_refs {
422            let ref_lower = reference.to_lowercase();
423
424            if let Some(chapter) = &metadata.chapter {
425                if chapter.to_lowercase().contains(&ref_lower) {
426                    return true;
427                }
428            }
429
430            if let Some(section) = &metadata.section {
431                if section.to_lowercase().contains(&ref_lower) {
432                    return true;
433                }
434            }
435
436            if let Some(subsection) = &metadata.subsection {
437                if subsection.to_lowercase().contains(&ref_lower) {
438                    return true;
439                }
440            }
441        }
442
443        false
444    }
445}
446
447impl Default for EnrichedRetriever {
448    fn default() -> Self {
449        Self::new()
450    }
451}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, ChunkMetadata, DocumentId, KnowledgeGraph, TextChunk};

    /// Build a chunk with the given id, content, keywords, and optional chapter.
    fn create_test_chunk(
        id: &str,
        content: &str,
        keywords: Vec<String>,
        chapter: Option<String>,
    ) -> TextChunk {
        let mut chunk = TextChunk::new(
            ChunkId::new(id.to_string()),
            DocumentId::new("test_doc".to_string()),
            content.to_string(),
            0,
            content.len(),
        );

        let mut metadata = ChunkMetadata::new();
        metadata.keywords = keywords;
        metadata.chapter = chapter;
        chunk.metadata = metadata;

        chunk
    }

    #[test]
    fn test_keyword_matching() {
        let retriever = EnrichedRetriever::new();
        let chunk_keywords = vec![
            "machine".to_string(),
            "learning".to_string(),
            "neural".to_string(),
        ];
        let query_words: HashSet<String> = vec!["machine".to_string(), "learning".to_string()]
            .into_iter()
            .collect();

        let matches = retriever.count_keyword_matches(&chunk_keywords, &query_words);
        assert_eq!(matches, 2);
    }

    #[test]
    fn test_structure_extraction() {
        let retriever = EnrichedRetriever::new();
        let query = "What does Socrates say in chapter 1?";
        let refs = retriever.extract_structure_references(&query.to_lowercase());

        assert!(!refs.is_empty());
    }

    #[test]
    fn test_chapter_filtering() {
        let retriever = EnrichedRetriever::new();
        let mut graph = KnowledgeGraph::new();

        let chunk1 = create_test_chunk(
            "chunk1",
            "Content from chapter 1",
            vec!["content".to_string()],
            Some("Chapter 1: Introduction".to_string()),
        );

        let chunk2 = create_test_chunk(
            "chunk2",
            "Content from chapter 2",
            vec!["content".to_string()],
            Some("Chapter 2: Methods".to_string()),
        );

        let _ = graph.add_chunk(chunk1);
        let _ = graph.add_chunk(chunk2);

        let chapter1_chunks = retriever.get_chapter_chunks("Chapter 1", &graph);
        assert_eq!(chapter1_chunks.len(), 1);
    }
}