1#[cfg(feature = "parallel-processing")]
7use crate::parallel::ParallelProcessor;
8use crate::{
9 core::{ChunkId, DocumentId, GraphRAGError, TextChunk},
10 text::TextProcessor,
11 Result,
12};
13use indexmap::IndexMap;
14use std::collections::{HashMap, VecDeque};
15use std::sync::Arc;
16
17#[async_trait::async_trait]
19pub trait LLMClient: Send + Sync {
20 async fn generate_summary(
22 &self,
23 text: &str,
24 prompt: &str,
25 max_tokens: usize,
26 temperature: f32,
27 ) -> Result<String>;
28
29 async fn generate_summary_batch(
31 &self,
32 texts: &[(&str, &str)],
33 max_tokens: usize,
34 temperature: f32,
35 ) -> Result<Vec<String>> {
36 let mut results = Vec::new();
37 for (text, prompt) in texts {
38 let summary = self
39 .generate_summary(text, prompt, max_tokens, temperature)
40 .await?;
41 results.push(summary);
42 }
43 Ok(results)
44 }
45
46 fn model_name(&self) -> &str;
48}
49
/// Identifier of a node in the hierarchical document tree (a thin wrapper around its string form).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeId(pub String);

impl NodeId {
    /// Wraps an already-built identifier string.
    pub fn new(id: String) -> Self {
        NodeId(id)
    }
}

impl std::fmt::Display for NodeId {
    /// Writes the raw identifier string; width/alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for NodeId {
    /// Converts an owned string into a node identifier.
    fn from(s: String) -> Self {
        NodeId::new(s)
    }
}

impl From<NodeId> for String {
    /// Unwraps the identifier back into its owned string.
    fn from(id: NodeId) -> Self {
        let NodeId(inner) = id;
        inner
    }
}
78
/// Settings that control how the hierarchical summary tree is built.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HierarchicalConfig {
    /// Number of sibling nodes grouped under one parent when a level is merged.
    pub merge_size: usize,
    /// Upper bound used when building extractive summaries; also the default
    /// `max_length` for levels without an explicit `LevelConfig`.
    pub max_summary_length: usize,
    /// NOTE(review): not read anywhere in this file — presumably a minimum node size; confirm units with callers.
    pub min_node_size: usize,
    /// NOTE(review): not read anywhere in this file — presumably sentence overlap between nodes; confirm.
    pub overlap_sentences: usize,
    /// Settings for optional LLM-backed summarization.
    pub llm_config: LLMConfig,
}
93
/// Settings for LLM-backed summarization of merged nodes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMConfig {
    /// When `false`, merged nodes always use the extractive summarizer.
    pub enabled: bool,
    /// Name of the model to use. NOTE(review): not read in this file; presumably consumed by the client setup.
    pub model_name: String,
    /// Sampling temperature used when a level has no temperature override.
    pub temperature: f32,
    /// Token budget passed to the client for batch summarization.
    pub max_tokens: usize,
    /// How prompts (and default abstractive/extractive choice) vary by level.
    pub strategy: LLMStrategy,
    /// Per-level overrides, keyed by tree level.
    pub level_configs: HashMap<usize, LevelConfig>,
}
110
/// Selects which built-in prompt is used when a level has no custom template.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LLMStrategy {
    /// The same "concise summary" prompt at every level.
    Uniform,
    /// Prompt depends on the level: key-point extraction at level 0, a
    /// combining summary up to level 2, and an abstract summary above that.
    Adaptive,
    /// Prompt depends on the level's `use_abstractive` flag (abstractive vs. extractive wording).
    Progressive,
}
121
/// Summarization settings for one tree level.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LevelConfig {
    /// Target/limit for summary length at this level; LLM output is truncated to it.
    pub max_length: usize,
    /// Whether the `Progressive` strategy should ask for an abstractive summary.
    pub use_abstractive: bool,
    /// Optional custom prompt. The placeholders `{text}`, `{context}`,
    /// `{level}` and `{max_length}` are substituted before use.
    pub prompt_template: Option<String>,
    /// Optional temperature override; the global LLM temperature is used when `None`.
    pub temperature: Option<f32>,
}
134
135impl Default for HierarchicalConfig {
136 fn default() -> Self {
137 Self {
138 merge_size: 5,
139 max_summary_length: 200,
140 min_node_size: 50,
141 overlap_sentences: 2,
142 llm_config: LLMConfig::default(),
143 }
144 }
145}
146
147impl Default for LLMConfig {
148 fn default() -> Self {
149 Self {
150 enabled: false, model_name: "llama3.1:8b".to_string(),
152 temperature: 0.3, max_tokens: 150,
154 strategy: LLMStrategy::Progressive,
155 level_configs: HashMap::new(),
156 }
157 }
158}
159
160impl Default for LevelConfig {
161 fn default() -> Self {
162 Self {
163 max_length: 200,
164 use_abstractive: false, prompt_template: None,
166 temperature: None,
167 }
168 }
169}
170
/// One node of the hierarchical document tree.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Unique identifier of this node within its tree.
    pub id: NodeId,
    /// Full text covered by the node: the chunk text for a leaf, or the
    /// children's content joined by blank lines for a merged node.
    pub content: String,
    /// Summary of `content` (extractive or LLM-generated).
    pub summary: String,
    /// Height in the tree; leaves are level 0.
    pub level: usize,
    /// Identifiers of the nodes merged into this one (empty for leaves).
    pub children: Vec<NodeId>,
    /// Identifier of the node this one was merged into, if any.
    pub parent: Option<NodeId>,
    /// All source chunks covered by this node.
    pub chunk_ids: Vec<ChunkId>,
    /// Keywords describing the node's content.
    pub keywords: Vec<String>,
    /// Smallest start offset covered by the node (taken from its chunk or children).
    pub start_offset: usize,
    /// Largest end offset covered by the node (taken from its chunk or children).
    pub end_offset: usize,
}
195
/// Hierarchical summary tree built bottom-up from the chunks of one document.
pub struct DocumentTree {
    // All nodes by id; iteration follows insertion order (IndexMap).
    nodes: IndexMap<NodeId, TreeNode>,
    // Ids of the top-most nodes produced by the last build.
    root_nodes: Vec<NodeId>,
    // Node ids grouped by tree level (0 = leaves).
    levels: HashMap<usize, Vec<NodeId>>,
    // Document this tree was built for.
    document_id: DocumentId,
    // Build and summarization settings.
    config: HierarchicalConfig,
    // Used for sentence splitting and keyword extraction.
    text_processor: TextProcessor,
    // Optional LLM backend; when absent only extractive summaries are produced.
    llm_client: Option<Arc<dyn LLMClient>>,
}
206
207impl DocumentTree {
208 pub fn new(document_id: DocumentId, config: HierarchicalConfig) -> Result<Self> {
210 let text_processor = TextProcessor::new(1000, 100)?;
211
212 Ok(Self {
213 nodes: IndexMap::new(),
214 root_nodes: Vec::new(),
215 levels: HashMap::new(),
216 document_id,
217 config,
218 text_processor,
219 llm_client: None,
220 })
221 }
222
223 pub fn with_llm_client(
225 document_id: DocumentId,
226 config: HierarchicalConfig,
227 llm_client: Arc<dyn LLMClient>,
228 ) -> Result<Self> {
229 let text_processor = TextProcessor::new(1000, 100)?;
230
231 Ok(Self {
232 nodes: IndexMap::new(),
233 root_nodes: Vec::new(),
234 levels: HashMap::new(),
235 document_id,
236 config,
237 text_processor,
238 llm_client: Some(llm_client),
239 })
240 }
241
    /// Replaces the LLM client; passing `None` removes it, so later merges
    /// fall back to extractive summaries.
    pub fn set_llm_client(&mut self, llm_client: Option<Arc<dyn LLMClient>>) {
        self.llm_client = llm_client;
    }
246
247 #[cfg(feature = "parallel-processing")]
249 pub fn with_parallel_processing(
250 document_id: DocumentId,
251 config: HierarchicalConfig,
252 _parallel_processor: ParallelProcessor,
253 ) -> Result<Self> {
254 let text_processor = TextProcessor::new(1000, 100)?;
255
256 Ok(Self {
257 nodes: IndexMap::new(),
258 root_nodes: Vec::new(),
259 levels: HashMap::new(),
260 document_id,
261 config,
262 text_processor,
263 llm_client: None,
264 })
265 }
266
267 #[cfg(feature = "parallel-processing")]
269 pub fn with_parallel_and_llm(
270 document_id: DocumentId,
271 config: HierarchicalConfig,
272 _parallel_processor: ParallelProcessor,
273 llm_client: Arc<dyn LLMClient>,
274 ) -> Result<Self> {
275 let text_processor = TextProcessor::new(1000, 100)?;
276
277 Ok(Self {
278 nodes: IndexMap::new(),
279 root_nodes: Vec::new(),
280 levels: HashMap::new(),
281 document_id,
282 config,
283 text_processor,
284 llm_client: Some(llm_client),
285 })
286 }
287
288 pub async fn build_from_chunks(&mut self, chunks: Vec<TextChunk>) -> Result<()> {
290 if chunks.is_empty() {
291 return Ok(());
292 }
293
294 println!("Building hierarchical tree from {} chunks", chunks.len());
295
296 let leaf_nodes = self.create_leaf_nodes(chunks)?;
298
299 self.build_bottom_up(leaf_nodes).await?;
301
302 println!(
303 "Tree built with {} total nodes across {} levels",
304 self.nodes.len(),
305 self.levels.len()
306 );
307
308 Ok(())
309 }
310
    /// Creates one level-0 node per chunk and registers them as level 0.
    ///
    /// Fewer than 10 chunks are always processed sequentially. Larger inputs
    /// use rayon when the `parallel-processing` feature is enabled and fall
    /// back to the sequential path if any parallel item fails. Note that the
    /// parallel path summarizes with `generate_parallel_summary`, while the
    /// sequential path uses `generate_extractive_summary`.
    ///
    /// Returns the ids of the created leaves.
    fn create_leaf_nodes(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
        // Small inputs are not worth the parallel set-up cost.
        if chunks.len() < 10 {
            return self.create_leaf_nodes_sequential(chunks);
        }

        #[cfg(feature = "parallel-processing")]
        {
            use rayon::prelude::*;

            // Build every node independently; the first error aborts the collect.
            let node_results: std::result::Result<Vec<_>, crate::GraphRAGError> = chunks
                .par_iter()
                .map(|chunk| {
                    let node_id = NodeId::new(format!("leaf_{}", chunk.id));

                    // A fresh processor is built per chunk instead of sharing
                    // `self.text_processor`.
                    // NOTE(review): presumably because the shared processor is
                    // not safe to use across threads — confirm.
                    let temp_processor =
                        crate::text::TextProcessor::new(1000, 100).map_err(|e| {
                            crate::GraphRAGError::Config {
                                message: format!("Failed to create text processor: {e}"),
                            }
                        })?;

                    let keywords = temp_processor.extract_keywords(&chunk.content, 5);

                    let summary =
                        self.generate_parallel_summary(&chunk.content, &temp_processor)?;

                    let node = TreeNode {
                        id: node_id.clone(),
                        content: chunk.content.clone(),
                        summary,
                        level: 0,
                        children: Vec::new(),
                        parent: None,
                        chunk_ids: vec![chunk.id.clone()],
                        keywords,
                        start_offset: chunk.start_offset,
                        end_offset: chunk.end_offset,
                    };

                    Ok((node_id, node))
                })
                .collect();

            match node_results {
                Ok(nodes) => {
                    let mut leaf_node_ids = Vec::new();

                    // Insertion into the tree happens on this thread only.
                    for (node_id, node) in nodes {
                        leaf_node_ids.push(node_id.clone());
                        self.nodes.insert(node_id, node);
                    }

                    println!("Created {} leaf nodes in parallel", leaf_node_ids.len());

                    self.levels.insert(0, leaf_node_ids.clone());

                    Ok(leaf_node_ids)
                },
                Err(e) => {
                    // Nothing was inserted yet, so retrying sequentially is safe.
                    eprintln!("Error in parallel node creation: {e}");
                    self.create_leaf_nodes_sequential(chunks)
                },
            }
        }

        #[cfg(not(feature = "parallel-processing"))]
        {
            self.create_leaf_nodes_sequential(chunks)
        }
    }
390
391 fn create_leaf_nodes_sequential(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
393 let mut leaf_node_ids = Vec::new();
394
395 for chunk in chunks {
396 let node_id = NodeId::new(format!("leaf_{}", chunk.id));
397
398 let keywords = self.text_processor.extract_keywords(&chunk.content, 5);
400
401 let node = TreeNode {
402 id: node_id.clone(),
403 content: chunk.content.clone(),
404 summary: self.generate_extractive_summary(&chunk.content)?,
405 level: 0,
406 children: Vec::new(),
407 parent: None,
408 chunk_ids: vec![chunk.id],
409 keywords,
410 start_offset: chunk.start_offset,
411 end_offset: chunk.end_offset,
412 };
413
414 self.nodes.insert(node_id.clone(), node);
415 leaf_node_ids.push(node_id);
416 }
417
418 self.levels.insert(0, leaf_node_ids.clone());
420
421 Ok(leaf_node_ids)
422 }
423
424 pub async fn generate_llm_summary(
427 &self,
428 text: &str,
429 level: usize,
430 context: &str,
431 ) -> Result<String> {
432 let llm_client = self
433 .llm_client
434 .as_ref()
435 .ok_or_else(|| GraphRAGError::Config {
436 message: "LLM client not configured for summarization".to_string(),
437 })?;
438
439 let level_config = self.get_level_config(level);
441
442 let prompt = self.create_summary_prompt(text, level, context, &level_config)?;
444
445 let summary = llm_client
447 .generate_summary(
448 text,
449 &prompt,
450 level_config.max_length,
451 level_config
452 .temperature
453 .unwrap_or(self.config.llm_config.temperature),
454 )
455 .await?;
456
457 self.truncate_summary(&summary, level_config.max_length)
459 }
460
461 pub async fn generate_llm_summaries_batch(
463 &self,
464 texts: &[(&str, usize, &str)], ) -> Result<Vec<String>> {
466 let llm_client = self
467 .llm_client
468 .as_ref()
469 .ok_or_else(|| GraphRAGError::Config {
470 message: "LLM client not configured for summarization".to_string(),
471 })?;
472
473 let mut prompts = Vec::new();
474 let mut configs = Vec::new();
475
476 for (text, level, context) in texts {
477 let level_config = self.get_level_config(*level);
478 let prompt = self.create_summary_prompt(text, *level, context, &level_config)?;
479 prompts.push(prompt);
480 configs.push(level_config);
481 }
482
483 let text_refs: Vec<&str> = texts.iter().map(|(t, _, _)| *t).collect();
485 let prompt_refs: Vec<&str> = prompts.iter().map(|p| p.as_str()).collect();
486
487 let summaries = llm_client
488 .generate_summary_batch(
489 &text_refs
490 .iter()
491 .zip(prompt_refs.iter())
492 .map(|(&t, &p)| (t, p))
493 .collect::<Vec<_>>(),
494 self.config.llm_config.max_tokens,
495 self.config.llm_config.temperature,
496 )
497 .await?;
498
499 let mut results = Vec::new();
501 for (i, summary) in summaries.into_iter().enumerate() {
502 let truncated = self.truncate_summary(&summary, configs[i].max_length)?;
503 results.push(truncated);
504 }
505
506 Ok(results)
507 }
508
509 fn create_summary_prompt(
511 &self,
512 text: &str,
513 level: usize,
514 context: &str,
515 level_config: &LevelConfig,
516 ) -> Result<String> {
517 if let Some(template) = &level_config.prompt_template {
519 return Ok(template
520 .replace("{text}", text)
521 .replace("{context}", context)
522 .replace("{level}", &level.to_string())
523 .replace("{max_length}", &level_config.max_length.to_string()));
524 }
525
526 match self.config.llm_config.strategy {
528 LLMStrategy::Uniform => {
529 Ok(format!(
530 "Create a concise summary of the following text. The summary should be approximately {} characters long.\n\nContext: {}\n\nText to summarize:\n{}\n\nSummary:",
531 level_config.max_length, context, text
532 ))
533 }
534 LLMStrategy::Adaptive => {
535 if level == 0 {
536 Ok(format!(
537 "Extract the key information from this text segment. Keep it factual and under {} characters.\n\nContext: {}\n\nText:\n{}\n\nKey points:",
538 level_config.max_length, context, text
539 ))
540 } else if level <= 2 {
541 Ok(format!(
542 "Create a coherent summary that combines the key information from this text. Make it approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nSummary:",
543 level_config.max_length, context, text
544 ))
545 } else {
546 Ok(format!(
547 "Generate a high-level abstract summary of this content. Focus on the main themes and insights. Limit to approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstract summary:",
548 level_config.max_length, context, text
549 ))
550 }
551 }
552 LLMStrategy::Progressive => {
553 if level_config.use_abstractive {
554 Ok(format!(
555 "Generate an abstractive summary that synthesizes the key concepts and relationships in this text. The summary should be approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstractive summary:",
556 level_config.max_length, context, text
557 ))
558 } else {
559 Ok(format!(
560 "Extract and organize the most important sentences from this text to create a coherent summary. Keep it under {} characters.\n\nContext: {}\n\nText:\n{}\n\nExtractive summary:",
561 level_config.max_length, context, text
562 ))
563 }
564 }
565 }
566 }
567
568 fn get_level_config(&self, level: usize) -> LevelConfig {
570 self.config
571 .llm_config
572 .level_configs
573 .get(&level)
574 .cloned()
575 .unwrap_or({
576 LevelConfig {
578 max_length: self.config.max_summary_length,
579 use_abstractive: match self.config.llm_config.strategy {
580 LLMStrategy::Progressive => level >= 2,
581 LLMStrategy::Adaptive => level >= 3,
582 LLMStrategy::Uniform => level > 0,
583 },
584 prompt_template: None,
585 temperature: None,
586 }
587 })
588 }
589
590 fn truncate_summary(&self, summary: &str, max_length: usize) -> Result<String> {
592 if summary.len() <= max_length {
593 return Ok(summary.to_string());
594 }
595
596 let sentences: Vec<&str> = summary
598 .split('.')
599 .filter(|s| !s.trim().is_empty())
600 .collect();
601
602 let mut result = String::new();
603 for sentence in sentences {
604 if result.len() + sentence.len() < max_length - 3 {
605 if !result.is_empty() {
606 result.push('.');
607 }
608 result.push_str(sentence.trim());
609 } else {
610 break;
611 }
612 }
613
614 if result.is_empty() {
615 result = summary.chars().take(max_length - 3).collect();
617 result.push_str("...");
618 } else {
619 result.push('.');
620 }
621
622 Ok(result)
623 }
624
625 #[allow(dead_code)]
626 fn generate_parallel_summary(
627 &self,
628 text: &str,
629 processor: &crate::text::TextProcessor,
630 ) -> Result<String> {
631 let sentences = processor.extract_sentences(text);
632
633 if sentences.is_empty() {
634 return Ok(String::new());
635 }
636
637 if sentences.len() == 1 {
638 return Ok(sentences[0].clone());
639 }
640
641 let mut best_sentence = &sentences[0];
643 let mut best_score = 0.0;
644
645 for sentence in &sentences {
646 let words: Vec<&str> = sentence.split_whitespace().collect();
647
648 let length_score = if words.len() < 5 {
650 0.1
651 } else if words.len() > 30 {
652 0.3
653 } else {
654 1.0
655 };
656
657 let word_score = words.len() as f32 * 0.1;
658 let score = length_score + word_score;
659
660 if score > best_score {
661 best_score = score;
662 best_sentence = sentence;
663 }
664 }
665
666 if best_sentence.len() > self.config.max_summary_length {
668 Ok(best_sentence
669 .chars()
670 .take(self.config.max_summary_length - 3)
671 .collect::<String>()
672 + "...")
673 } else {
674 Ok(best_sentence.clone())
675 }
676 }
677
678 async fn build_bottom_up(&mut self, leaf_nodes: Vec<NodeId>) -> Result<()> {
680 let mut current_level_nodes = leaf_nodes;
681 let mut current_level = 0;
682
683 while current_level_nodes.len() > 1 {
684 let next_level_nodes = self
685 .merge_level(¤t_level_nodes, current_level + 1)
686 .await?;
687
688 current_level_nodes = next_level_nodes;
689 current_level += 1;
690 }
691
692 self.root_nodes = current_level_nodes;
694
695 Ok(())
696 }
697
698 async fn merge_level(
700 &mut self,
701 level_nodes: &[NodeId],
702 new_level: usize,
703 ) -> Result<Vec<NodeId>> {
704 let mut new_level_nodes = Vec::new();
705 for (node_counter, chunk) in level_nodes.chunks(self.config.merge_size).enumerate() {
707 let merged_node_id = NodeId::new(format!("level_{new_level}_{node_counter}"));
708 let merged_node = self
709 .merge_nodes(chunk, merged_node_id.clone(), new_level)
710 .await?;
711
712 for child_id in chunk {
714 if let Some(child_node) = self.nodes.get_mut(child_id) {
715 child_node.parent = Some(merged_node_id.clone());
716 }
717 }
718
719 self.nodes.insert(merged_node_id.clone(), merged_node);
720 new_level_nodes.push(merged_node_id);
721 }
722
723 self.levels.insert(new_level, new_level_nodes.clone());
725
726 Ok(new_level_nodes)
727 }
728
729 async fn merge_nodes(
731 &self,
732 node_ids: &[NodeId],
733 merged_id: NodeId,
734 level: usize,
735 ) -> Result<TreeNode> {
736 let mut combined_content = String::new();
737 let mut all_chunk_ids = Vec::new();
738 let mut all_keywords = Vec::new();
739 let mut min_offset = usize::MAX;
740 let mut max_offset = 0;
741
742 for node_id in node_ids {
743 if let Some(node) = self.nodes.get(node_id) {
744 if !combined_content.is_empty() {
745 combined_content.push_str("\n\n");
746 }
747 combined_content.push_str(&node.content);
748 all_chunk_ids.extend(node.chunk_ids.clone());
749 all_keywords.extend(node.keywords.clone());
750 min_offset = min_offset.min(node.start_offset);
751 max_offset = max_offset.max(node.end_offset);
752 }
753 }
754
755 all_keywords.sort();
757 all_keywords.dedup();
758 all_keywords.truncate(10);
759
760 let summary = if self.config.llm_config.enabled {
762 if self.llm_client.is_some() {
764 let context = format!(
766 "Merging {} nodes at level {}. This represents a higher-level abstraction of the document content.",
767 node_ids.len(),
768 level
769 );
770
771 match self
773 .generate_llm_summary(&combined_content, level, &context)
774 .await
775 {
776 Ok(llm_summary) => {
777 println!(
778 "✅ Generated LLM-based summary for level {} ({} chars)",
779 level,
780 llm_summary.len()
781 );
782 llm_summary
783 },
784 Err(e) => {
785 eprintln!("⚠️ LLM summarization failed for level {}: {}, falling back to extractive", level, e);
786 self.generate_extractive_summary(&combined_content)?
787 },
788 }
789 } else {
790 self.generate_extractive_summary(&combined_content)?
791 }
792 } else {
793 self.generate_extractive_summary(&combined_content)?
794 };
795
796 Ok(TreeNode {
797 id: merged_id,
798 content: combined_content,
799 summary,
800 level,
801 children: node_ids.to_vec(),
802 parent: None,
803 chunk_ids: all_chunk_ids,
804 keywords: all_keywords,
805 start_offset: min_offset,
806 end_offset: max_offset,
807 })
808 }
809
810 fn generate_extractive_summary(&self, text: &str) -> Result<String> {
812 let sentences = self.text_processor.extract_sentences(text);
813
814 if sentences.is_empty() {
815 return Ok(String::new());
816 }
817
818 if sentences.len() == 1 {
819 return Ok(sentences[0].clone());
820 }
821
822 let mut sentence_scores: Vec<(usize, f32)> = sentences
824 .iter()
825 .enumerate()
826 .map(|(i, sentence)| {
827 let score = self.score_sentence(sentence, &sentences);
828 (i, score)
829 })
830 .collect();
831
832 sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
834
835 let mut summary = String::new();
837 let mut selected_indices = Vec::new();
838
839 for (sentence_idx, _score) in sentence_scores {
840 let sentence = &sentences[sentence_idx];
841 if summary.len() + sentence.len() <= self.config.max_summary_length {
842 selected_indices.push(sentence_idx);
843 if !summary.is_empty() {
844 summary.push(' ');
845 }
846 summary.push_str(sentence);
847 }
848 }
849
850 if summary.is_empty() && !sentences.is_empty() {
852 let first_sentence = &sentences[0];
853 if first_sentence.len() <= self.config.max_summary_length {
854 summary = first_sentence.clone();
855 } else {
856 summary = first_sentence
857 .chars()
858 .take(self.config.max_summary_length - 3)
859 .collect::<String>()
860 + "...";
861 }
862 }
863
864 Ok(summary)
865 }
866
867 fn score_sentence(&self, sentence: &str, all_sentences: &[String]) -> f32 {
869 let words: Vec<&str> = sentence.split_whitespace().collect();
870
871 let length_score = if words.len() < 5 {
873 0.1
874 } else if words.len() > 30 {
875 0.3
876 } else {
877 1.0
878 };
879
880 let position_score = 0.5; let mut word_freq_score = 0.0;
885 let total_words: Vec<&str> = all_sentences
886 .iter()
887 .flat_map(|s| s.split_whitespace())
888 .collect();
889
890 for word in &words {
891 let word_lower = word.to_lowercase();
892 if word_lower.len() > 3 && !self.is_stop_word(&word_lower) {
893 let freq = total_words
894 .iter()
895 .filter(|&&w| w.to_lowercase() == word_lower)
896 .count();
897 if freq > 1 {
898 word_freq_score += freq as f32 / total_words.len() as f32;
899 }
900 }
901 }
902
903 length_score * 0.4 + position_score * 0.2 + word_freq_score * 0.4
904 }
905
906 fn is_stop_word(&self, word: &str) -> bool {
908 const STOP_WORDS: &[&str] = &[
909 "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
910 "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
911 "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
912 "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
913 "go", "me",
914 ];
915 STOP_WORDS.contains(&word)
916 }
917
918 pub fn query(&self, query: &str, max_results: usize) -> Result<Vec<QueryResult>> {
920 let query_keywords = self.text_processor.extract_keywords(query, 5);
921 let mut results = Vec::new();
922
923 for (node_id, node) in &self.nodes {
925 let score = self.calculate_relevance_score(node, &query_keywords, query);
926
927 if score > 0.1 {
928 results.push(QueryResult {
929 node_id: node_id.clone(),
930 score,
931 level: node.level,
932 summary: node.summary.clone(),
933 keywords: node.keywords.clone(),
934 chunk_ids: node.chunk_ids.clone(),
935 });
936 }
937 }
938
939 results.sort_by(|a, b| {
941 b.score
942 .partial_cmp(&a.score)
943 .unwrap_or(std::cmp::Ordering::Equal)
944 });
945 results.truncate(max_results);
946
947 Ok(results)
948 }
949
950 fn calculate_relevance_score(
952 &self,
953 node: &TreeNode,
954 query_keywords: &[String],
955 query: &str,
956 ) -> f32 {
957 let mut score = 0.0;
958
959 let node_text = format!("{} {}", node.summary, node.keywords.join(" ")).to_lowercase();
961 for keyword in query_keywords {
962 if node_text.contains(&keyword.to_lowercase()) {
963 score += 1.0;
964 }
965 }
966
967 let query_words: Vec<&str> = query.split_whitespace().collect();
969 let node_words: Vec<&str> = node_text.split_whitespace().collect();
970
971 let mut overlap_count = 0;
972 for query_word in &query_words {
973 if node_words.contains(&query_word.to_lowercase().as_str()) {
974 overlap_count += 1;
975 }
976 }
977
978 if !query_words.is_empty() {
979 score += (overlap_count as f32 / query_words.len() as f32) * 2.0;
980 }
981
982 let level_score = 1.0 / (node.level + 1) as f32;
984 score += level_score * 0.5;
985
986 score
987 }
988
989 pub fn get_ancestors(&self, node_id: &NodeId) -> Vec<&TreeNode> {
991 let mut ancestors = Vec::new();
992 let mut current_id = Some(node_id.clone());
993
994 while let Some(id) = current_id {
995 if let Some(node) = self.nodes.get(&id) {
996 ancestors.push(node);
997 current_id = node.parent.clone();
998 } else {
999 break;
1000 }
1001 }
1002
1003 ancestors
1004 }
1005
1006 pub fn get_descendants(&self, node_id: &NodeId) -> Vec<&TreeNode> {
1008 let mut descendants = Vec::new();
1009 let mut queue = VecDeque::new();
1010
1011 if let Some(node) = self.nodes.get(node_id) {
1012 queue.extend(node.children.iter());
1013 }
1014
1015 while let Some(child_id) = queue.pop_front() {
1016 if let Some(child_node) = self.nodes.get(child_id) {
1017 descendants.push(child_node);
1018 queue.extend(child_node.children.iter());
1019 }
1020 }
1021
1022 descendants
1023 }
1024
    /// Looks up a node by id, returning `None` when it is not in the tree.
    pub fn get_node(&self, node_id: &NodeId) -> Option<&TreeNode> {
        self.nodes.get(node_id)
    }
1029
1030 pub fn get_level_nodes(&self, level: usize) -> Vec<&TreeNode> {
1032 if let Some(node_ids) = self.levels.get(&level) {
1033 node_ids
1034 .iter()
1035 .filter_map(|id| self.nodes.get(id))
1036 .collect()
1037 } else {
1038 Vec::new()
1039 }
1040 }
1041
1042 pub fn get_root_nodes(&self) -> Vec<&TreeNode> {
1044 self.root_nodes
1045 .iter()
1046 .filter_map(|id| self.nodes.get(id))
1047 .collect()
1048 }
1049
    /// Returns the id of the document this tree belongs to.
    pub fn document_id(&self) -> &DocumentId {
        &self.document_id
    }
1054
1055 pub fn get_statistics(&self) -> TreeStatistics {
1057 let max_level = self.levels.keys().max().copied().unwrap_or(0);
1058 let total_nodes = self.nodes.len();
1059 let nodes_per_level: HashMap<usize, usize> = self
1060 .levels
1061 .iter()
1062 .map(|(level, nodes)| (*level, nodes.len()))
1063 .collect();
1064
1065 TreeStatistics {
1066 total_nodes,
1067 max_level,
1068 nodes_per_level,
1069 root_count: self.root_nodes.len(),
1070 document_id: self.document_id.clone(),
1071 }
1072 }
1073
    /// Serializes the tree to a compact JSON string.
    ///
    /// The output contains the document id, the four numeric build settings,
    /// every node keyed by id, the root ids and the per-level id lists. The
    /// LLM configuration and the LLM client are not serialized.
    pub fn to_json(&self) -> Result<String> {
        use json::JsonValue;

        // Skeleton first; nodes, roots and levels are filled in below.
        let mut tree_json = json::object! {
            "document_id": self.document_id.to_string(),
            "config": {
                "merge_size": self.config.merge_size,
                "max_summary_length": self.config.max_summary_length,
                "min_node_size": self.config.min_node_size,
                "overlap_sentences": self.config.overlap_sentences
            },
            "nodes": {},
            "root_nodes": [],
            "levels": {}
        };

        // One object per node, keyed by the node id.
        for (node_id, node) in &self.nodes {
            let node_json = json::object! {
                "id": node_id.to_string(),
                "content": node.content.clone(),
                "summary": node.summary.clone(),
                "level": node.level,
                "children": node.children.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "parent": node.parent.as_ref().map(|id| id.to_string()),
                "chunk_ids": node.chunk_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "keywords": node.keywords.clone(),
                "start_offset": node.start_offset,
                "end_offset": node.end_offset
            };
            tree_json["nodes"][node_id.to_string()] = node_json;
        }

        tree_json["root_nodes"] = self
            .root_nodes
            .iter()
            .map(|id| JsonValue::String(id.to_string()))
            .collect::<Vec<_>>()
            .into();

        // Levels are keyed by their number rendered as a string.
        for (level, node_ids) in &self.levels {
            tree_json["levels"][level.to_string()] = node_ids
                .iter()
                .map(|id| JsonValue::String(id.to_string()))
                .collect::<Vec<_>>()
                .into();
        }

        Ok(tree_json.dump())
    }
1127
1128 pub fn from_json(json_str: &str) -> Result<Self> {
1130 let json_data = json::parse(json_str).map_err(crate::GraphRAGError::Json)?;
1131
1132 let document_id = DocumentId::new(
1133 json_data["document_id"]
1134 .as_str()
1135 .ok_or_else(|| {
1136 crate::GraphRAGError::Json(json::Error::WrongType(
1137 "document_id must be string".to_string(),
1138 ))
1139 })?
1140 .to_string(),
1141 );
1142
1143 let config_json = &json_data["config"];
1144 let config = HierarchicalConfig {
1145 merge_size: config_json["merge_size"].as_usize().unwrap_or(5),
1146 max_summary_length: config_json["max_summary_length"].as_usize().unwrap_or(200),
1147 min_node_size: config_json["min_node_size"].as_usize().unwrap_or(50),
1148 overlap_sentences: config_json["overlap_sentences"].as_usize().unwrap_or(2),
1149 llm_config: LLMConfig::default(),
1150 };
1151
1152 let mut tree = Self::new(document_id, config)?;
1153
1154 if let json::JsonValue::Object(nodes_obj) = &json_data["nodes"] {
1156 for (node_id_str, node_json) in nodes_obj.iter() {
1157 let node_id = NodeId::new(node_id_str.to_string());
1158
1159 let children: Vec<NodeId> = node_json["children"]
1160 .members()
1161 .filter_map(|v| v.as_str())
1162 .map(|s| NodeId::new(s.to_string()))
1163 .collect();
1164
1165 let parent = node_json["parent"]
1166 .as_str()
1167 .map(|s| NodeId::new(s.to_string()));
1168
1169 let chunk_ids: Vec<ChunkId> = node_json["chunk_ids"]
1170 .members()
1171 .filter_map(|v| v.as_str())
1172 .map(|s| ChunkId::new(s.to_string()))
1173 .collect();
1174
1175 let keywords: Vec<String> = node_json["keywords"]
1176 .members()
1177 .filter_map(|v| v.as_str())
1178 .map(|s| s.to_string())
1179 .collect();
1180
1181 let node = TreeNode {
1182 id: node_id.clone(),
1183 content: node_json["content"].as_str().unwrap_or("").to_string(),
1184 summary: node_json["summary"].as_str().unwrap_or("").to_string(),
1185 level: node_json["level"].as_usize().unwrap_or(0),
1186 children,
1187 parent,
1188 chunk_ids,
1189 keywords,
1190 start_offset: node_json["start_offset"].as_usize().unwrap_or(0),
1191 end_offset: node_json["end_offset"].as_usize().unwrap_or(0),
1192 };
1193
1194 tree.nodes.insert(node_id, node);
1195 }
1196 }
1197
1198 tree.root_nodes = json_data["root_nodes"]
1200 .members()
1201 .filter_map(|v| v.as_str())
1202 .map(|s| NodeId::new(s.to_string()))
1203 .collect();
1204
1205 if let json::JsonValue::Object(levels_obj) = &json_data["levels"] {
1207 for (level_str, level_json) in levels_obj.iter() {
1208 if let Ok(level) = level_str.parse::<usize>() {
1209 let node_ids: Vec<NodeId> = level_json
1210 .members()
1211 .filter_map(|v| v.as_str())
1212 .map(|s| NodeId::new(s.to_string()))
1213 .collect();
1214 tree.levels.insert(level, node_ids);
1215 }
1216 }
1217 }
1218
1219 Ok(tree)
1220 }
1221}
1222
/// One match returned by `DocumentTree::query`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Id of the matching node.
    pub node_id: NodeId,
    /// Relevance score; higher is more relevant.
    pub score: f32,
    /// Tree level of the matching node (0 = leaf).
    pub level: usize,
    /// Copy of the node's summary.
    pub summary: String,
    /// Copy of the node's keywords.
    pub keywords: Vec<String>,
    /// Source chunks covered by the matching node.
    pub chunk_ids: Vec<ChunkId>,
}
1239
/// Size statistics of a `DocumentTree`, as produced by `get_statistics`.
#[derive(Debug)]
pub struct TreeStatistics {
    /// Number of nodes stored in the tree.
    pub total_nodes: usize,
    /// Highest registered level (0 when the tree is empty).
    pub max_level: usize,
    /// Number of node ids registered at each level.
    pub nodes_per_level: HashMap<usize, usize>,
    /// Number of root nodes.
    pub root_count: usize,
    /// Document the tree belongs to.
    pub document_id: DocumentId,
}
1254
1255impl TreeStatistics {
1256 pub fn print(&self) {
1258 println!("Hierarchical Tree Statistics:");
1259 println!(" Document ID: {}", self.document_id);
1260 println!(" Total nodes: {}", self.total_nodes);
1261 println!(" Max level: {}", self.max_level);
1262 println!(" Root nodes: {}", self.root_count);
1263 println!(" Nodes per level:");
1264
1265 let mut levels: Vec<_> = self.nodes_per_level.iter().collect();
1266 levels.sort_by_key(|(level, _)| *level);
1267
1268 for (level, count) in levels {
1269 println!(" Level {level}: {count} nodes");
1270 }
1271 }
1272}
1273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Builds an empty tree with default settings for the document `test_doc`.
    fn empty_test_tree() -> DocumentTree {
        DocumentTree::new(
            DocumentId::new("test_doc".to_string()),
            HierarchicalConfig::default(),
        )
        .unwrap()
    }

    /// A multi-sentence text yields a non-empty summary within the configured limit.
    #[test]
    fn test_extractive_summarization() {
        let tree = empty_test_tree();

        let text = "This is the first sentence. This is a second sentence with more details. This is the final sentence.";
        let summary = tree.generate_extractive_summary(text).unwrap();

        assert!(!summary.is_empty());
        assert!(summary.len() <= tree.config.max_summary_length);
    }

    /// An empty tree survives a JSON round trip with its document id intact.
    #[test]
    fn test_json_serialization() {
        let tree = empty_test_tree();

        let json = tree.to_json().unwrap();
        assert!(json.contains("test_doc"));

        let loaded_tree = DocumentTree::from_json(&json).unwrap();
        assert_eq!(loaded_tree.document_id.to_string(), "test_doc");
    }
}