1use crate::{
7 core::{KnowledgeGraph, TextChunk},
8 retrieval::{QueryAnalysis, ResultType, SearchResult},
9 Result,
10};
11use std::collections::{HashMap, HashSet};
12
/// Weights and thresholds that control how chunk metadata influences retrieval scores.
#[derive(Debug, Clone, PartialEq)]
pub struct EnrichedRetrievalConfig {
    /// Multiplier applied to the fraction of query terms covered by a chunk's keywords.
    pub keyword_match_weight: f32,
    /// Multiplier applied to the chapter/section/subsection match boost.
    pub structure_match_weight: f32,
    /// Flat bonus added when a chunk's summary mentions enough of the query terms.
    pub summary_weight: f32,
    /// Minimum number of keyword matches before the keyword bonus is applied at all.
    pub min_keyword_matches: usize,
    /// When `true`, structural references in the query (e.g. "chapter 3") add a score boost.
    /// Despite the name, this only gates that boost; it does not remove results.
    pub enable_structure_filtering: bool,
}
27
28impl Default for EnrichedRetrievalConfig {
29 fn default() -> Self {
30 Self {
31 keyword_match_weight: 0.3,
32 structure_match_weight: 0.2,
33 summary_weight: 0.15,
34 min_keyword_matches: 1,
35 enable_structure_filtering: true,
36 }
37 }
38}
39
40pub struct EnrichedRetriever {
42 config: EnrichedRetrievalConfig,
43}
44
45impl EnrichedRetriever {
46 pub fn new() -> Self {
48 Self {
49 config: EnrichedRetrievalConfig::default(),
50 }
51 }
52
53 pub fn with_config(config: EnrichedRetrievalConfig) -> Self {
55 Self { config }
56 }
57
58 pub fn metadata_search(
65 &self,
66 query: &str,
67 graph: &KnowledgeGraph,
68 _analysis: &QueryAnalysis,
69 base_results: &[SearchResult],
70 ) -> Result<Vec<SearchResult>> {
71 let mut enriched_results = Vec::new();
72
73 let query_lower = query.to_lowercase();
75 let query_words: HashSet<String> = query_lower
76 .split_whitespace()
77 .filter(|w| w.len() > 3)
78 .map(|w| w.to_string())
79 .collect();
80
81 let structure_refs = self.extract_structure_references(&query_lower);
83
84 for chunk in graph.chunks() {
86 if !chunk.entities.is_empty() || !chunk.metadata.keywords.is_empty() {
87 let mut base_score = self.find_base_score(chunk, base_results);
88 let mut metadata_boost = 0.0;
89
90 let keyword_matches =
92 self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
93 if keyword_matches >= self.config.min_keyword_matches {
94 let keyword_boost = (keyword_matches as f32 / query_words.len().max(1) as f32)
95 * self.config.keyword_match_weight;
96 metadata_boost += keyword_boost;
97 }
98
99 if self.config.enable_structure_filtering {
101 if let Some(structure_boost) =
102 self.calculate_structure_boost(chunk, &structure_refs)
103 {
104 metadata_boost += structure_boost * self.config.structure_match_weight;
105 }
106 }
107
108 if let Some(summary) = &chunk.metadata.summary {
110 if self.matches_query(summary, &query_words) {
111 metadata_boost += self.config.summary_weight;
112 }
113 }
114
115 let completeness = chunk.metadata.completeness_score();
117 if completeness > 0.7 {
118 metadata_boost += 0.05; }
120
121 if metadata_boost > 0.05 {
123 base_score = (base_score + metadata_boost).min(1.0);
124
125 enriched_results.push(SearchResult {
126 id: chunk.id.to_string(),
127 content: chunk.content.clone(),
128 score: base_score,
129 result_type: ResultType::Chunk,
130 entities: chunk
131 .entities
132 .iter()
133 .filter_map(|eid| graph.get_entity(eid))
134 .map(|e| e.name.clone())
135 .collect(),
136 source_chunks: vec![chunk.id.to_string()],
137 });
138 }
139 }
140 }
141
142 Ok(enriched_results)
143 }
144
145 pub fn filter_by_structure(
149 &self,
150 query: &str,
151 results: Vec<SearchResult>,
152 graph: &KnowledgeGraph,
153 ) -> Result<Vec<SearchResult>> {
154 let structure_refs = self.extract_structure_references(&query.to_lowercase());
155
156 if structure_refs.is_empty() {
157 return Ok(results);
158 }
159
160 let filtered: Vec<SearchResult> = results
161 .into_iter()
162 .filter(|result| {
163 if let Some(chunk_id) = result.source_chunks.first() {
165 if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
166 return self.matches_structure(&chunk.metadata, &structure_refs);
167 }
168 }
169 true })
171 .collect();
172
173 Ok(filtered)
174 }
175
176 pub fn boost_with_metadata(
178 &self,
179 mut results: Vec<SearchResult>,
180 query: &str,
181 graph: &KnowledgeGraph,
182 ) -> Result<Vec<SearchResult>> {
183 let query_words: HashSet<String> = query
184 .to_lowercase()
185 .split_whitespace()
186 .filter(|w| w.len() > 3)
187 .map(|w| w.to_string())
188 .collect();
189
190 for result in &mut results {
191 if let Some(chunk_id) = result.source_chunks.first() {
192 if let Some(chunk) = graph.chunks().find(|c| c.id.to_string() == *chunk_id) {
193 let keyword_matches =
195 self.count_keyword_matches(&chunk.metadata.keywords, &query_words);
196 if keyword_matches > 0 {
197 let boost =
198 (keyword_matches as f32 / query_words.len().max(1) as f32) * 0.2;
199 result.score = (result.score + boost).min(1.0);
200 }
201
202 if let Some(chapter) = &chunk.metadata.chapter {
204 if query.to_lowercase().contains(&chapter.to_lowercase()) {
205 result.score = (result.score + 0.15).min(1.0);
206 }
207 }
208
209 if let Some(section) = &chunk.metadata.section {
210 if query.to_lowercase().contains(§ion.to_lowercase()) {
211 result.score = (result.score + 0.1).min(1.0);
212 }
213 }
214 }
215 }
216 }
217
218 results.sort_by(|a, b| {
220 b.score
221 .partial_cmp(&a.score)
222 .unwrap_or(std::cmp::Ordering::Equal)
223 });
224
225 Ok(results)
226 }
227
228 pub fn get_chapter_chunks<'a>(
230 &self,
231 chapter_name: &str,
232 graph: &'a KnowledgeGraph,
233 ) -> Vec<&'a TextChunk> {
234 graph
235 .chunks()
236 .filter(|chunk| {
237 if let Some(ch) = &chunk.metadata.chapter {
238 ch.to_lowercase().contains(&chapter_name.to_lowercase())
239 } else {
240 false
241 }
242 })
243 .collect()
244 }
245
246 pub fn get_section_chunks<'a>(
248 &self,
249 section_name: &str,
250 graph: &'a KnowledgeGraph,
251 ) -> Vec<&'a TextChunk> {
252 graph
253 .chunks()
254 .filter(|chunk| {
255 if let Some(sec) = &chunk.metadata.section {
256 sec.to_lowercase().contains(§ion_name.to_lowercase())
257 } else {
258 false
259 }
260 })
261 .collect()
262 }
263
264 pub fn search_by_keywords(
266 &self,
267 keywords: &[String],
268 graph: &KnowledgeGraph,
269 top_k: usize,
270 ) -> Vec<SearchResult> {
271 let mut keyword_scores: HashMap<String, (f32, &TextChunk)> = HashMap::new();
272
273 for chunk in graph.chunks() {
274 let mut score = 0.0;
275 for keyword in keywords {
276 if chunk
277 .metadata
278 .keywords
279 .iter()
280 .any(|k| k.eq_ignore_ascii_case(keyword))
281 {
282 score += 1.0 / keywords.len() as f32;
283 }
284 }
285
286 if score > 0.0 {
287 keyword_scores.insert(chunk.id.to_string(), (score, chunk));
288 }
289 }
290
291 let mut sorted_results: Vec<_> = keyword_scores.into_iter().collect();
292 sorted_results.sort_by(|a, b| {
293 b.1 .0
294 .partial_cmp(&a.1 .0)
295 .unwrap_or(std::cmp::Ordering::Equal)
296 });
297
298 sorted_results
299 .into_iter()
300 .take(top_k)
301 .map(|(chunk_id, (score, chunk))| SearchResult {
302 id: chunk_id.clone(),
303 content: chunk.content.clone(),
304 score,
305 result_type: ResultType::Chunk,
306 entities: chunk
307 .entities
308 .iter()
309 .filter_map(|eid| graph.get_entity(eid))
310 .map(|e| e.name.clone())
311 .collect(),
312 source_chunks: vec![chunk_id],
313 })
314 .collect()
315 }
316
317 fn count_keyword_matches(
321 &self,
322 chunk_keywords: &[String],
323 query_words: &HashSet<String>,
324 ) -> usize {
325 chunk_keywords
326 .iter()
327 .filter(|k| query_words.contains(&k.to_lowercase()))
328 .count()
329 }
330
331 fn find_base_score(&self, chunk: &TextChunk, base_results: &[SearchResult]) -> f32 {
333 base_results
334 .iter()
335 .find(|r| r.source_chunks.contains(&chunk.id.to_string()))
336 .map(|r| r.score)
337 .unwrap_or(0.5) }
339
340 fn extract_structure_references(&self, query_lower: &str) -> Vec<String> {
342 let mut refs = Vec::new();
343
344 let patterns = [
346 r"chapter\s+(\d+|[ivxlcdm]+|\w+)",
347 r"section\s+(\d+\.?\d*)",
348 r"part\s+(\d+|[ivxlcdm]+)",
349 ];
350
351 for pattern in &patterns {
352 if let Some(captures) = regex::Regex::new(pattern)
353 .ok()
354 .and_then(|re| re.captures(query_lower))
355 {
356 if let Some(matched) = captures.get(0) {
357 refs.push(matched.as_str().to_string());
358 }
359 }
360 }
361
362 for word in query_lower.split_whitespace() {
364 if word.chars().next().is_some_and(|c| c.is_uppercase()) && word.len() > 5 {
365 refs.push(word.to_string());
366 }
367 }
368
369 refs
370 }
371
372 fn calculate_structure_boost(
374 &self,
375 chunk: &TextChunk,
376 structure_refs: &[String],
377 ) -> Option<f32> {
378 if structure_refs.is_empty() {
379 return None;
380 }
381
382 let mut boost = 0.0;
383
384 for reference in structure_refs {
385 let ref_lower = reference.to_lowercase();
386
387 if let Some(chapter) = &chunk.metadata.chapter {
388 if chapter.to_lowercase().contains(&ref_lower) {
389 boost += 0.5;
390 }
391 }
392
393 if let Some(section) = &chunk.metadata.section {
394 if section.to_lowercase().contains(&ref_lower) {
395 boost += 0.3;
396 }
397 }
398
399 if let Some(subsection) = &chunk.metadata.subsection {
400 if subsection.to_lowercase().contains(&ref_lower) {
401 boost += 0.2;
402 }
403 }
404 }
405
406 if boost > 0.0 {
407 Some(boost)
408 } else {
409 None
410 }
411 }
412
413 fn matches_query(&self, text: &str, query_words: &HashSet<String>) -> bool {
415 let text_lower = text.to_lowercase();
416 query_words
417 .iter()
418 .filter(|word| text_lower.contains(word.as_str()))
419 .count()
420 >= (query_words.len() / 2).max(1)
421 }
422
423 fn matches_structure(
425 &self,
426 metadata: &crate::core::ChunkMetadata,
427 structure_refs: &[String],
428 ) -> bool {
429 for reference in structure_refs {
430 let ref_lower = reference.to_lowercase();
431
432 if let Some(chapter) = &metadata.chapter {
433 if chapter.to_lowercase().contains(&ref_lower) {
434 return true;
435 }
436 }
437
438 if let Some(section) = &metadata.section {
439 if section.to_lowercase().contains(&ref_lower) {
440 return true;
441 }
442 }
443
444 if let Some(subsection) = &metadata.subsection {
445 if subsection.to_lowercase().contains(&ref_lower) {
446 return true;
447 }
448 }
449 }
450
451 false
452 }
453}
454
455impl Default for EnrichedRetriever {
456 fn default() -> Self {
457 Self::new()
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ChunkId, ChunkMetadata, DocumentId, KnowledgeGraph, TextChunk};

    /// Builds a `test_doc` chunk spanning all of `content`, whose metadata carries only the
    /// given keywords and optional chapter label.
    fn create_test_chunk(
        id: &str,
        content: &str,
        keywords: Vec<String>,
        chapter: Option<String>,
    ) -> TextChunk {
        let chunk_id = ChunkId::new(id.to_string());
        let doc_id = DocumentId::new("test_doc".to_string());
        let mut built = TextChunk::new(chunk_id, doc_id, content.to_string(), 0, content.len());

        let mut meta = ChunkMetadata::new();
        meta.keywords = keywords;
        meta.chapter = chapter;
        built.metadata = meta;

        built
    }

    /// Converts string literals into owned `String`s.
    fn owned(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| w.to_string()).collect()
    }

    #[test]
    fn test_keyword_matching() {
        let retriever = EnrichedRetriever::new();
        let chunk_keywords = owned(&["machine", "learning", "neural"]);
        let query_words: HashSet<String> = owned(&["machine", "learning"]).into_iter().collect();

        assert_eq!(
            retriever.count_keyword_matches(&chunk_keywords, &query_words),
            2
        );
    }

    #[test]
    fn test_structure_extraction() {
        let query = "What does Socrates say in chapter 1?".to_lowercase();
        let refs = EnrichedRetriever::new().extract_structure_references(&query);

        assert!(!refs.is_empty(), "expected a structure reference in {query:?}");
    }

    #[test]
    fn test_chapter_filtering() {
        let retriever = EnrichedRetriever::new();
        let mut graph = KnowledgeGraph::new();

        let fixtures = [
            ("chunk1", "Content from chapter 1", "Chapter 1: Introduction"),
            ("chunk2", "Content from chapter 2", "Chapter 2: Methods"),
        ];
        for &(id, content, chapter) in fixtures.iter() {
            let chunk =
                create_test_chunk(id, content, owned(&["content"]), Some(chapter.to_string()));
            let _ = graph.add_chunk(chunk);
        }

        let chapter_one = retriever.get_chapter_chunks("Chapter 1", &graph);
        assert_eq!(chapter_one.len(), 1);
    }
}