1use crate::{
7 core::{KnowledgeGraph, TextChunk},
8 retrieval::{QueryAnalysis, SearchResult, ResultType},
9 Result,
10};
11use std::collections::{HashMap, HashSet};
12
/// Tunable weights and thresholds for metadata-driven retrieval.
#[derive(Debug, Clone)]
pub struct EnrichedRetrievalConfig {
    // Weight applied to the fraction of query words matched by chunk keywords.
    pub keyword_match_weight: f32,
    // Weight applied to chapter/section/subsection structure matches.
    pub structure_match_weight: f32,
    // Flat boost granted when a chunk's summary matches the query.
    pub summary_weight: f32,
    // Minimum keyword hits before the keyword boost is applied at all.
    pub min_keyword_matches: usize,
    // When false, structure-based boosting is skipped entirely.
    pub enable_structure_filtering: bool,
}
27
impl Default for EnrichedRetrievalConfig {
    /// Conservative defaults: keywords dominate (0.3), structure next (0.2),
    /// summaries least (0.15); a single keyword hit is enough to qualify and
    /// structure filtering is on.
    fn default() -> Self {
        Self {
            keyword_match_weight: 0.3,
            structure_match_weight: 0.2,
            summary_weight: 0.15,
            min_keyword_matches: 1,
            enable_structure_filtering: true,
        }
    }
}
39
/// Retriever that boosts and filters search results using chunk enrichment
/// metadata (keywords, chapter/section structure, summaries).
pub struct EnrichedRetriever {
    // Weights and thresholds governing all metadata boosts.
    config: EnrichedRetrievalConfig,
}
44
45impl EnrichedRetriever {
46 pub fn new() -> Self {
48 Self {
49 config: EnrichedRetrievalConfig::default(),
50 }
51 }
52
53 pub fn with_config(config: EnrichedRetrievalConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn metadata_search(
65 &self,
66 query: &str,
67 graph: &KnowledgeGraph,
68 _analysis: &QueryAnalysis,
69 base_results: &[SearchResult],
70 ) -> Result<Vec<SearchResult>> {
71 let mut enriched_results = Vec::new();
72
73 let query_lower = query.to_lowercase();
75 let query_words: HashSet<String> = query_lower
76 .split_whitespace()
77 .filter(|w| w.len() > 3)
78 .map(|w| w.to_string())
79 .collect();
80
81 let structure_refs = self.extract_structure_references(&query_lower);
83
84 for chunk in graph.chunks() {
86 if !chunk.entities.is_empty() || !chunk.metadata.keywords.is_empty() {
87 let mut base_score = self.find_base_score(chunk, base_results);
88 let mut metadata_boost = 0.0;
89
90 let keyword_matches = self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
92 if keyword_matches >= self.config.min_keyword_matches {
93 let keyword_boost = (keyword_matches as f32 / query_words.len().max(1) as f32)
94 * self.config.keyword_match_weight;
95 metadata_boost += keyword_boost;
96 }
97
98 if self.config.enable_structure_filtering {
100 if let Some(structure_boost) = self.calculate_structure_boost(chunk, &structure_refs) {
101 metadata_boost += structure_boost * self.config.structure_match_weight;
102 }
103 }
104
105 if let Some(summary) = &chunk.metadata.summary {
107 if self.matches_query(summary, &query_words) {
108 metadata_boost += self.config.summary_weight;
109 }
110 }
111
112 let completeness = chunk.metadata.completeness_score();
114 if completeness > 0.7 {
115 metadata_boost += 0.05; }
117
118 if metadata_boost > 0.05 {
120 base_score = (base_score + metadata_boost).min(1.0);
121
122 enriched_results.push(SearchResult {
123 id: chunk.id.to_string(),
124 content: chunk.content.clone(),
125 score: base_score,
126 result_type: ResultType::Chunk,
127 entities: chunk
128 .entities
129 .iter()
130 .filter_map(|eid| graph.get_entity(eid))
131 .map(|e| e.name.clone())
132 .collect(),
133 source_chunks: vec![chunk.id.to_string()],
134 });
135 }
136 }
137 }
138
139 Ok(enriched_results)
140 }
141
142 pub fn filter_by_structure(
146 &self,
147 query: &str,
148 results: Vec<SearchResult>,
149 graph: &KnowledgeGraph,
150 ) -> Result<Vec<SearchResult>> {
151 let structure_refs = self.extract_structure_references(&query.to_lowercase());
152
153 if structure_refs.is_empty() {
154 return Ok(results);
155 }
156
157 let filtered: Vec<SearchResult> = results
158 .into_iter()
159 .filter(|result| {
160 if let Some(chunk_id) = result.source_chunks.first() {
162 if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
163 return self.matches_structure(&chunk.metadata, &structure_refs);
164 }
165 }
166 true })
168 .collect();
169
170 Ok(filtered)
171 }
172
173 pub fn boost_with_metadata(
175 &self,
176 mut results: Vec<SearchResult>,
177 query: &str,
178 graph: &KnowledgeGraph,
179 ) -> Result<Vec<SearchResult>> {
180 let query_words: HashSet<String> = query
181 .to_lowercase()
182 .split_whitespace()
183 .filter(|w| w.len() > 3)
184 .map(|w| w.to_string())
185 .collect();
186
187 for result in &mut results {
188 if let Some(chunk_id) = result.source_chunks.first() {
189 if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
190 let keyword_matches = self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
192 if keyword_matches > 0 {
193 let boost = (keyword_matches as f32 / query_words.len().max(1) as f32) * 0.2;
194 result.score = (result.score + boost).min(1.0);
195 }
196
197 if let Some(chapter) = &chunk.metadata.chapter {
199 if query.to_lowercase().contains(&chapter.to_lowercase()) {
200 result.score = (result.score + 0.15).min(1.0);
201 }
202 }
203
204 if let Some(section) = &chunk.metadata.section {
205 if query.to_lowercase().contains(§ion.to_lowercase()) {
206 result.score = (result.score + 0.1).min(1.0);
207 }
208 }
209 }
210 }
211 }
212
213 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
215
216 Ok(results)
217 }
218
219 pub fn get_chapter_chunks<'a>(&self, chapter_name: &str, graph: &'a KnowledgeGraph) -> Vec<&'a TextChunk> {
221 graph
222 .chunks()
223 .filter(|chunk| {
224 if let Some(ch) = &chunk.metadata.chapter {
225 ch.to_lowercase().contains(&chapter_name.to_lowercase())
226 } else {
227 false
228 }
229 })
230 .collect()
231 }
232
233 pub fn get_section_chunks<'a>(&self, section_name: &str, graph: &'a KnowledgeGraph) -> Vec<&'a TextChunk> {
235 graph
236 .chunks()
237 .filter(|chunk| {
238 if let Some(sec) = &chunk.metadata.section {
239 sec.to_lowercase().contains(§ion_name.to_lowercase())
240 } else {
241 false
242 }
243 })
244 .collect()
245 }
246
247 pub fn search_by_keywords(
249 &self,
250 keywords: &[String],
251 graph: &KnowledgeGraph,
252 top_k: usize,
253 ) -> Vec<SearchResult> {
254 let mut keyword_scores: HashMap<String, (f32, &TextChunk)> = HashMap::new();
255
256 for chunk in graph.chunks() {
257 let mut score = 0.0;
258 for keyword in keywords {
259 if chunk.metadata.keywords.iter().any(|k| k.eq_ignore_ascii_case(keyword)) {
260 score += 1.0 / keywords.len() as f32;
261 }
262 }
263
264 if score > 0.0 {
265 keyword_scores.insert(chunk.id.to_string(), (score, chunk));
266 }
267 }
268
269 let mut sorted_results: Vec<_> = keyword_scores.into_iter().collect();
270 sorted_results.sort_by(|a, b| b.1 .0.partial_cmp(&a.1 .0).unwrap());
271
272 sorted_results
273 .into_iter()
274 .take(top_k)
275 .map(|(chunk_id, (score, chunk))| SearchResult {
276 id: chunk_id.clone(),
277 content: chunk.content.clone(),
278 score,
279 result_type: ResultType::Chunk,
280 entities: chunk
281 .entities
282 .iter()
283 .filter_map(|eid| graph.get_entity(eid))
284 .map(|e| e.name.clone())
285 .collect(),
286 source_chunks: vec![chunk_id],
287 })
288 .collect()
289 }
290
291 fn count_keyword_matches(&self, chunk_keywords: &[String], query_words: &HashSet<String>) -> usize {
295 chunk_keywords
296 .iter()
297 .filter(|k| query_words.contains(&k.to_lowercase()))
298 .count()
299 }
300
301 fn find_base_score(&self, chunk: &TextChunk, base_results: &[SearchResult]) -> f32 {
303 base_results
304 .iter()
305 .find(|r| r.source_chunks.contains(&chunk.id.to_string()))
306 .map(|r| r.score)
307 .unwrap_or(0.5) }
309
310 fn extract_structure_references(&self, query_lower: &str) -> Vec<String> {
312 let mut refs = Vec::new();
313
314 let patterns = [
316 r"chapter\s+(\d+|[ivxlcdm]+|\w+)",
317 r"section\s+(\d+\.?\d*)",
318 r"part\s+(\d+|[ivxlcdm]+)",
319 ];
320
321 for pattern in &patterns {
322 if let Some(captures) = regex::Regex::new(pattern).ok().and_then(|re| re.captures(query_lower)) {
323 if let Some(matched) = captures.get(0) {
324 refs.push(matched.as_str().to_string());
325 }
326 }
327 }
328
329 for word in query_lower.split_whitespace() {
331 if word.chars().next().map_or(false, |c| c.is_uppercase()) && word.len() > 5 {
332 refs.push(word.to_string());
333 }
334 }
335
336 refs
337 }
338
339 fn calculate_structure_boost(
341 &self,
342 chunk: &TextChunk,
343 structure_refs: &[String],
344 ) -> Option<f32> {
345 if structure_refs.is_empty() {
346 return None;
347 }
348
349 let mut boost = 0.0;
350
351 for reference in structure_refs {
352 let ref_lower = reference.to_lowercase();
353
354 if let Some(chapter) = &chunk.metadata.chapter {
355 if chapter.to_lowercase().contains(&ref_lower) {
356 boost += 0.5;
357 }
358 }
359
360 if let Some(section) = &chunk.metadata.section {
361 if section.to_lowercase().contains(&ref_lower) {
362 boost += 0.3;
363 }
364 }
365
366 if let Some(subsection) = &chunk.metadata.subsection {
367 if subsection.to_lowercase().contains(&ref_lower) {
368 boost += 0.2;
369 }
370 }
371 }
372
373 if boost > 0.0 {
374 Some(boost)
375 } else {
376 None
377 }
378 }
379
380 fn matches_query(&self, text: &str, query_words: &HashSet<String>) -> bool {
382 let text_lower = text.to_lowercase();
383 query_words
384 .iter()
385 .filter(|word| text_lower.contains(word.as_str()))
386 .count()
387 >= (query_words.len() / 2).max(1)
388 }
389
390 fn matches_structure(
392 &self,
393 metadata: &crate::core::ChunkMetadata,
394 structure_refs: &[String],
395 ) -> bool {
396 for reference in structure_refs {
397 let ref_lower = reference.to_lowercase();
398
399 if let Some(chapter) = &metadata.chapter {
400 if chapter.to_lowercase().contains(&ref_lower) {
401 return true;
402 }
403 }
404
405 if let Some(section) = &metadata.section {
406 if section.to_lowercase().contains(&ref_lower) {
407 return true;
408 }
409 }
410
411 if let Some(subsection) = &metadata.subsection {
412 if subsection.to_lowercase().contains(&ref_lower) {
413 return true;
414 }
415 }
416 }
417
418 false
419 }
420}
421
impl Default for EnrichedRetriever {
    /// Returns a retriever built with the default configuration, delegating
    /// to [`EnrichedRetriever::new`].
    fn default() -> Self {
        Self::new()
    }
}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, ChunkMetadata, DocumentId, KnowledgeGraph, TextChunk};

    /// Builds a chunk for the given document span and attaches keyword and
    /// chapter enrichment metadata to it.
    fn create_test_chunk(
        id: &str,
        content: &str,
        keywords: Vec<String>,
        chapter: Option<String>,
    ) -> TextChunk {
        let mut metadata = ChunkMetadata::new();
        metadata.keywords = keywords;
        metadata.chapter = chapter;

        let mut chunk = TextChunk::new(
            ChunkId::new(id.to_string()),
            DocumentId::new("test_doc".to_string()),
            content.to_string(),
            0,
            content.len(),
        );
        chunk.metadata = metadata;
        chunk
    }

    #[test]
    fn test_keyword_matching() {
        let retriever = EnrichedRetriever::new();
        let chunk_keywords = vec![
            "machine".to_string(),
            "learning".to_string(),
            "neural".to_string(),
        ];
        let query_words: HashSet<String> = ["machine", "learning"]
            .iter()
            .map(|w| w.to_string())
            .collect();

        // Both query words are present in the chunk's keyword list.
        assert_eq!(
            retriever.count_keyword_matches(&chunk_keywords, &query_words),
            2
        );
    }

    #[test]
    fn test_structure_extraction() {
        let retriever = EnrichedRetriever::new();
        let query = "What does Socrates say in chapter 1?".to_lowercase();

        // "chapter 1" should be picked up as a structural reference.
        assert!(!retriever.extract_structure_references(&query).is_empty());
    }

    #[test]
    fn test_chapter_filtering() {
        let retriever = EnrichedRetriever::new();
        let mut graph = KnowledgeGraph::new();

        for (id, text, title) in [
            ("chunk1", "Content from chapter 1", "Chapter 1: Introduction"),
            ("chunk2", "Content from chapter 2", "Chapter 2: Methods"),
        ] {
            let chunk = create_test_chunk(
                id,
                text,
                vec!["content".to_string()],
                Some(title.to_string()),
            );
            let _ = graph.add_chunk(chunk);
        }

        // Only the chunk tagged "Chapter 1: Introduction" should match.
        assert_eq!(retriever.get_chapter_chunks("Chapter 1", &graph).len(), 1);
    }
}
496}