1use crate::{
7 core::{KnowledgeGraph, TextChunk},
8 retrieval::{QueryAnalysis, ResultType, SearchResult},
9 Result,
10};
11use std::collections::{HashMap, HashSet};
12
/// Tunable weights and switches used by `EnrichedRetriever` when it boosts
/// search results with chunk metadata.
#[derive(Clone, Debug)]
pub struct EnrichedRetrievalConfig {
    /// Multiplier applied to the ratio of keyword matches to query terms.
    pub keyword_match_weight: f32,
    /// Multiplier applied to the chapter/section/subsection match boost.
    pub structure_match_weight: f32,
    /// Flat boost added when a chunk summary matches the query.
    pub summary_weight: f32,
    /// Minimum number of keyword matches before the keyword boost applies.
    pub min_keyword_matches: usize,
    /// Whether structural references in the query add a structure boost in
    /// `metadata_search` (note: `filter_by_structure` does not consult this flag).
    pub enable_structure_filtering: bool,
}

impl Default for EnrichedRetrievalConfig {
    /// Keyword weight 0.3, structure weight 0.2, summary weight 0.15, at least
    /// one keyword match required, structure matching enabled.
    fn default() -> Self {
        EnrichedRetrievalConfig {
            min_keyword_matches: 1,
            keyword_match_weight: 0.3,
            structure_match_weight: 0.2,
            summary_weight: 0.15,
            enable_structure_filtering: true,
        }
    }
}
39
40pub struct EnrichedRetriever {
42 config: EnrichedRetrievalConfig,
43}
44
45impl EnrichedRetriever {
46 pub fn new() -> Self {
48 Self {
49 config: EnrichedRetrievalConfig::default(),
50 }
51 }
52
53 pub fn with_config(config: EnrichedRetrievalConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn metadata_search(
65 &self,
66 query: &str,
67 graph: &KnowledgeGraph,
68 _analysis: &QueryAnalysis,
69 base_results: &[SearchResult],
70 ) -> Result<Vec<SearchResult>> {
71 let mut enriched_results = Vec::new();
72
73 let query_lower = query.to_lowercase();
75 let query_words: HashSet<String> = query_lower
76 .split_whitespace()
77 .filter(|w| w.len() > 3)
78 .map(|w| w.to_string())
79 .collect();
80
81 let structure_refs = self.extract_structure_references(&query_lower);
83
84 for chunk in graph.chunks() {
86 if !chunk.entities.is_empty() || !chunk.metadata.keywords.is_empty() {
87 let mut base_score = self.find_base_score(chunk, base_results);
88 let mut metadata_boost = 0.0;
89
90 let keyword_matches =
92 self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
93 if keyword_matches >= self.config.min_keyword_matches {
94 let keyword_boost = (keyword_matches as f32 / query_words.len().max(1) as f32)
95 * self.config.keyword_match_weight;
96 metadata_boost += keyword_boost;
97 }
98
99 if self.config.enable_structure_filtering {
101 if let Some(structure_boost) =
102 self.calculate_structure_boost(chunk, &structure_refs)
103 {
104 metadata_boost += structure_boost * self.config.structure_match_weight;
105 }
106 }
107
108 if let Some(summary) = &chunk.metadata.summary {
110 if self.matches_query(summary, &query_words) {
111 metadata_boost += self.config.summary_weight;
112 }
113 }
114
115 let completeness = chunk.metadata.completeness_score();
117 if completeness > 0.7 {
118 metadata_boost += 0.05; }
120
121 if metadata_boost > 0.05 {
123 base_score = (base_score + metadata_boost).min(1.0);
124
125 enriched_results.push(SearchResult {
126 id: chunk.id.to_string(),
127 content: chunk.content.clone(),
128 score: base_score,
129 result_type: ResultType::Chunk,
130 entities: chunk
131 .entities
132 .iter()
133 .filter_map(|eid| graph.get_entity(eid))
134 .map(|e| e.name.clone())
135 .collect(),
136 source_chunks: vec![chunk.id.to_string()],
137 });
138 }
139 }
140 }
141
142 Ok(enriched_results)
143 }
144
145 pub fn filter_by_structure(
149 &self,
150 query: &str,
151 results: Vec<SearchResult>,
152 graph: &KnowledgeGraph,
153 ) -> Result<Vec<SearchResult>> {
154 let structure_refs = self.extract_structure_references(&query.to_lowercase());
155
156 if structure_refs.is_empty() {
157 return Ok(results);
158 }
159
160 let filtered: Vec<SearchResult> = results
161 .into_iter()
162 .filter(|result| {
163 if let Some(chunk_id) = result.source_chunks.first() {
165 if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
166 return self.matches_structure(&chunk.metadata, &structure_refs);
167 }
168 }
169 true })
171 .collect();
172
173 Ok(filtered)
174 }
175
176 pub fn boost_with_metadata(
178 &self,
179 mut results: Vec<SearchResult>,
180 query: &str,
181 graph: &KnowledgeGraph,
182 ) -> Result<Vec<SearchResult>> {
183 let query_words: HashSet<String> = query
184 .to_lowercase()
185 .split_whitespace()
186 .filter(|w| w.len() > 3)
187 .map(|w| w.to_string())
188 .collect();
189
190 for result in &mut results {
191 if let Some(chunk_id) = result.source_chunks.first() {
192 if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
193 let keyword_matches =
195 self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
196 if keyword_matches > 0 {
197 let boost =
198 (keyword_matches as f32 / query_words.len().max(1) as f32) * 0.2;
199 result.score = (result.score + boost).min(1.0);
200 }
201
202 if let Some(chapter) = &chunk.metadata.chapter {
204 if query.to_lowercase().contains(&chapter.to_lowercase()) {
205 result.score = (result.score + 0.15).min(1.0);
206 }
207 }
208
209 if let Some(section) = &chunk.metadata.section {
210 if query.to_lowercase().contains(§ion.to_lowercase()) {
211 result.score = (result.score + 0.1).min(1.0);
212 }
213 }
214 }
215 }
216 }
217
218 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
220
221 Ok(results)
222 }
223
224 pub fn get_chapter_chunks<'a>(
226 &self,
227 chapter_name: &str,
228 graph: &'a KnowledgeGraph,
229 ) -> Vec<&'a TextChunk> {
230 graph
231 .chunks()
232 .filter(|chunk| {
233 if let Some(ch) = &chunk.metadata.chapter {
234 ch.to_lowercase().contains(&chapter_name.to_lowercase())
235 } else {
236 false
237 }
238 })
239 .collect()
240 }
241
242 pub fn get_section_chunks<'a>(
244 &self,
245 section_name: &str,
246 graph: &'a KnowledgeGraph,
247 ) -> Vec<&'a TextChunk> {
248 graph
249 .chunks()
250 .filter(|chunk| {
251 if let Some(sec) = &chunk.metadata.section {
252 sec.to_lowercase().contains(§ion_name.to_lowercase())
253 } else {
254 false
255 }
256 })
257 .collect()
258 }
259
260 pub fn search_by_keywords(
262 &self,
263 keywords: &[String],
264 graph: &KnowledgeGraph,
265 top_k: usize,
266 ) -> Vec<SearchResult> {
267 let mut keyword_scores: HashMap<String, (f32, &TextChunk)> = HashMap::new();
268
269 for chunk in graph.chunks() {
270 let mut score = 0.0;
271 for keyword in keywords {
272 if chunk
273 .metadata
274 .keywords
275 .iter()
276 .any(|k| k.eq_ignore_ascii_case(keyword))
277 {
278 score += 1.0 / keywords.len() as f32;
279 }
280 }
281
282 if score > 0.0 {
283 keyword_scores.insert(chunk.id.to_string(), (score, chunk));
284 }
285 }
286
287 let mut sorted_results: Vec<_> = keyword_scores.into_iter().collect();
288 sorted_results.sort_by(|a, b| b.1 .0.partial_cmp(&a.1 .0).unwrap());
289
290 sorted_results
291 .into_iter()
292 .take(top_k)
293 .map(|(chunk_id, (score, chunk))| SearchResult {
294 id: chunk_id.clone(),
295 content: chunk.content.clone(),
296 score,
297 result_type: ResultType::Chunk,
298 entities: chunk
299 .entities
300 .iter()
301 .filter_map(|eid| graph.get_entity(eid))
302 .map(|e| e.name.clone())
303 .collect(),
304 source_chunks: vec![chunk_id],
305 })
306 .collect()
307 }
308
309 fn count_keyword_matches(
313 &self,
314 chunk_keywords: &[String],
315 query_words: &HashSet<String>,
316 ) -> usize {
317 chunk_keywords
318 .iter()
319 .filter(|k| query_words.contains(&k.to_lowercase()))
320 .count()
321 }
322
323 fn find_base_score(&self, chunk: &TextChunk, base_results: &[SearchResult]) -> f32 {
325 base_results
326 .iter()
327 .find(|r| r.source_chunks.contains(&chunk.id.to_string()))
328 .map(|r| r.score)
329 .unwrap_or(0.5) }
331
332 fn extract_structure_references(&self, query_lower: &str) -> Vec<String> {
334 let mut refs = Vec::new();
335
336 let patterns = [
338 r"chapter\s+(\d+|[ivxlcdm]+|\w+)",
339 r"section\s+(\d+\.?\d*)",
340 r"part\s+(\d+|[ivxlcdm]+)",
341 ];
342
343 for pattern in &patterns {
344 if let Some(captures) = regex::Regex::new(pattern)
345 .ok()
346 .and_then(|re| re.captures(query_lower))
347 {
348 if let Some(matched) = captures.get(0) {
349 refs.push(matched.as_str().to_string());
350 }
351 }
352 }
353
354 for word in query_lower.split_whitespace() {
356 if word.chars().next().is_some_and(|c| c.is_uppercase()) && word.len() > 5 {
357 refs.push(word.to_string());
358 }
359 }
360
361 refs
362 }
363
364 fn calculate_structure_boost(
366 &self,
367 chunk: &TextChunk,
368 structure_refs: &[String],
369 ) -> Option<f32> {
370 if structure_refs.is_empty() {
371 return None;
372 }
373
374 let mut boost = 0.0;
375
376 for reference in structure_refs {
377 let ref_lower = reference.to_lowercase();
378
379 if let Some(chapter) = &chunk.metadata.chapter {
380 if chapter.to_lowercase().contains(&ref_lower) {
381 boost += 0.5;
382 }
383 }
384
385 if let Some(section) = &chunk.metadata.section {
386 if section.to_lowercase().contains(&ref_lower) {
387 boost += 0.3;
388 }
389 }
390
391 if let Some(subsection) = &chunk.metadata.subsection {
392 if subsection.to_lowercase().contains(&ref_lower) {
393 boost += 0.2;
394 }
395 }
396 }
397
398 if boost > 0.0 {
399 Some(boost)
400 } else {
401 None
402 }
403 }
404
405 fn matches_query(&self, text: &str, query_words: &HashSet<String>) -> bool {
407 let text_lower = text.to_lowercase();
408 query_words
409 .iter()
410 .filter(|word| text_lower.contains(word.as_str()))
411 .count()
412 >= (query_words.len() / 2).max(1)
413 }
414
415 fn matches_structure(
417 &self,
418 metadata: &crate::core::ChunkMetadata,
419 structure_refs: &[String],
420 ) -> bool {
421 for reference in structure_refs {
422 let ref_lower = reference.to_lowercase();
423
424 if let Some(chapter) = &metadata.chapter {
425 if chapter.to_lowercase().contains(&ref_lower) {
426 return true;
427 }
428 }
429
430 if let Some(section) = &metadata.section {
431 if section.to_lowercase().contains(&ref_lower) {
432 return true;
433 }
434 }
435
436 if let Some(subsection) = &metadata.subsection {
437 if subsection.to_lowercase().contains(&ref_lower) {
438 return true;
439 }
440 }
441 }
442
443 false
444 }
445}
446
447impl Default for EnrichedRetriever {
448 fn default() -> Self {
449 Self::new()
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, ChunkMetadata, DocumentId, KnowledgeGraph, TextChunk};

    /// Builds a chunk in document `test_doc` that spans all of `content`, with
    /// metadata carrying only the given keywords and chapter label.
    fn create_test_chunk(
        id: &str,
        content: &str,
        keywords: Vec<String>,
        chapter: Option<String>,
    ) -> TextChunk {
        let mut chunk = TextChunk::new(
            ChunkId::new(id.to_string()),
            DocumentId::new("test_doc".to_string()),
            content.to_string(),
            0,
            content.len(),
        );

        let mut metadata = ChunkMetadata::new();
        metadata.keywords = keywords;
        metadata.chapter = chapter;
        chunk.metadata = metadata;

        chunk
    }

    #[test]
    fn test_keyword_matching() {
        let retriever = EnrichedRetriever::new();
        let chunk_keywords = vec![
            "machine".to_string(),
            "learning".to_string(),
            "neural".to_string(),
        ];
        let query_words: HashSet<String> = vec!["machine".to_string(), "learning".to_string()]
            .into_iter()
            .collect();

        let matches = retriever.count_keyword_matches(&chunk_keywords, &query_words);
        assert_eq!(matches, 2);
    }

    #[test]
    fn test_structure_extraction() {
        let retriever = EnrichedRetriever::new();
        let query = "What does Socrates say in chapter 1?";
        let refs = retriever.extract_structure_references(&query.to_lowercase());

        // Pin the specific reference rather than just "something was found".
        assert!(refs.iter().any(|r| r == "chapter 1"), "refs were {refs:?}");
    }

    #[test]
    fn test_chapter_filtering() {
        let retriever = EnrichedRetriever::new();
        let mut graph = KnowledgeGraph::new();

        let chunk1 = create_test_chunk(
            "chunk1",
            "Content from chapter 1",
            vec!["content".to_string()],
            Some("Chapter 1: Introduction".to_string()),
        );

        let chunk2 = create_test_chunk(
            "chunk2",
            "Content from chapter 2",
            vec!["content".to_string()],
            Some("Chapter 2: Methods".to_string()),
        );

        // Fail loudly on setup errors instead of masking them as a count mismatch.
        graph.add_chunk(chunk1).expect("adding chunk1 should succeed");
        graph.add_chunk(chunk2).expect("adding chunk2 should succeed");

        let chapter1_chunks = retriever.get_chapter_chunks("Chapter 1", &graph);
        assert_eq!(chapter1_chunks.len(), 1);
    }
}
530}