1#[cfg(feature = "parallel-processing")]
2use crate::parallel::ParallelProcessor;
3use crate::{
4 core::{ChunkId, DocumentId, GraphRAGError, TextChunk},
5 text::TextProcessor,
6 Result,
7};
8use indexmap::IndexMap;
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11
/// Abstraction over a language-model backend used to generate summaries.
///
/// Implementors must be `Send + Sync` so a client can be shared across
/// threads behind an `Arc` (see `DocumentTree::llm_client`).
#[async_trait::async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate a single summary of `text`, steered by `prompt`.
    ///
    /// `max_tokens` caps generation length and `temperature` controls
    /// sampling randomness; both are forwarded to the backend.
    async fn generate_summary(
        &self,
        text: &str,
        prompt: &str,
        max_tokens: usize,
        temperature: f32,
    ) -> Result<String>;

    /// Summarize several `(text, prompt)` pairs.
    ///
    /// Default implementation awaits `generate_summary` for each pair in
    /// order; backends with a native batch API can override this.
    async fn generate_summary_batch(
        &self,
        texts: &[(&str, &str)],
        max_tokens: usize,
        temperature: f32,
    ) -> Result<Vec<String>> {
        let mut results = Vec::new();
        for (text, prompt) in texts {
            let summary = self
                .generate_summary(text, prompt, max_tokens, temperature)
                .await?;
            results.push(summary);
        }
        Ok(results)
    }

    /// Identifier of the underlying model.
    fn model_name(&self) -> &str;
}
44
/// Newtype wrapper around the string identifier of a tree node.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeId(pub String);

impl NodeId {
    /// Wrap an owned string as a `NodeId`.
    pub fn new(id: String) -> Self {
        NodeId(id)
    }
}

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for NodeId {
    fn from(s: String) -> Self {
        NodeId(s)
    }
}

impl From<NodeId> for String {
    fn from(id: NodeId) -> Self {
        let NodeId(inner) = id;
        inner
    }
}
73
/// Tuning knobs for building a hierarchical document tree.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HierarchicalConfig {
    /// How many child nodes are merged into one parent per level.
    pub merge_size: usize,
    /// Upper bound (in characters) for generated node summaries.
    pub max_summary_length: usize,
    /// Minimum node content size; serialized but not enforced in this file.
    pub min_node_size: usize,
    /// Sentence overlap between neighbors; serialized but unused here.
    pub overlap_sentences: usize,
    /// Settings controlling optional LLM-based summarization.
    pub llm_config: LLMConfig,
}
88
/// Settings for optional LLM-backed summarization.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMConfig {
    /// Master switch; when false, extractive summaries are used everywhere.
    pub enabled: bool,
    /// Backend model identifier.
    pub model_name: String,
    /// Default sampling temperature (can be overridden per level).
    pub temperature: f32,
    /// Token budget passed to batch generation.
    pub max_tokens: usize,
    /// How prompts vary across tree levels.
    pub strategy: LLMStrategy,
    /// Per-level overrides, keyed by tree level.
    pub level_configs: HashMap<usize, LevelConfig>,
}
105
/// Strategy for choosing summarization prompts across tree levels.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LLMStrategy {
    /// Same concise-summary prompt at every level.
    Uniform,
    /// Prompt escalates with level: extract, then combine, then abstract.
    Adaptive,
    /// Extractive vs. abstractive is decided by each level's config.
    Progressive,
}
116
/// Per-level summarization overrides.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LevelConfig {
    /// Character cap for summaries at this level.
    pub max_length: usize,
    /// Whether prompts should ask for abstractive (vs. extractive) output.
    pub use_abstractive: bool,
    /// Optional custom prompt; supports {text}/{context}/{level}/{max_length}.
    pub prompt_template: Option<String>,
    /// Optional temperature override; falls back to the global value.
    pub temperature: Option<f32>,
}
129
130impl Default for HierarchicalConfig {
131 fn default() -> Self {
132 Self {
133 merge_size: 5,
134 max_summary_length: 200,
135 min_node_size: 50,
136 overlap_sentences: 2,
137 llm_config: LLMConfig::default(),
138 }
139 }
140}
141
142impl Default for LLMConfig {
143 fn default() -> Self {
144 Self {
145 enabled: false, model_name: "llama3.1:8b".to_string(),
147 temperature: 0.3, max_tokens: 150,
149 strategy: LLMStrategy::Progressive,
150 level_configs: HashMap::new(),
151 }
152 }
153}
154
155impl Default for LevelConfig {
156 fn default() -> Self {
157 Self {
158 max_length: 200,
159 use_abstractive: false, prompt_template: None,
161 temperature: None,
162 }
163 }
164}
165
/// One node in the hierarchical document tree.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Unique identifier ("leaf_<chunk>" or "level_<n>_<i>").
    pub id: NodeId,
    /// Full concatenated text this node covers.
    pub content: String,
    /// Condensed (extractive or LLM-generated) version of `content`.
    pub summary: String,
    /// Tree depth: 0 for leaves, increasing toward the root.
    pub level: usize,
    /// Direct children (empty for leaves).
    pub children: Vec<NodeId>,
    /// Direct parent; `None` for roots and for nodes not yet merged.
    pub parent: Option<NodeId>,
    /// Source chunks aggregated under this node.
    pub chunk_ids: Vec<ChunkId>,
    /// Deduplicated keywords extracted from the covered text.
    pub keywords: Vec<String>,
    /// Start offset of the covered span (taken from the source chunks).
    pub start_offset: usize,
    /// End offset of the covered span (taken from the source chunks).
    pub end_offset: usize,
}
190
/// Bottom-up summary tree over a single document's chunks.
pub struct DocumentTree {
    // All nodes keyed by id; IndexMap preserves insertion order.
    nodes: IndexMap<NodeId, TreeNode>,
    // Ids of the top-most node(s) after building.
    root_nodes: Vec<NodeId>,
    // Node ids grouped by tree level (0 = leaves).
    levels: HashMap<usize, Vec<NodeId>>,
    // Document these nodes were built from.
    document_id: DocumentId,
    // Build and summarization settings.
    config: HierarchicalConfig,
    // Shared processor for sentence/keyword extraction.
    text_processor: TextProcessor,
    // Optional LLM backend for abstractive summaries.
    llm_client: Option<Arc<dyn LLMClient>>,
}
201
202impl DocumentTree {
203 pub fn new(document_id: DocumentId, config: HierarchicalConfig) -> Result<Self> {
205 let text_processor = TextProcessor::new(1000, 100)?;
206
207 Ok(Self {
208 nodes: IndexMap::new(),
209 root_nodes: Vec::new(),
210 levels: HashMap::new(),
211 document_id,
212 config,
213 text_processor,
214 llm_client: None,
215 })
216 }
217
218 pub fn with_llm_client(
220 document_id: DocumentId,
221 config: HierarchicalConfig,
222 llm_client: Arc<dyn LLMClient>,
223 ) -> Result<Self> {
224 let text_processor = TextProcessor::new(1000, 100)?;
225
226 Ok(Self {
227 nodes: IndexMap::new(),
228 root_nodes: Vec::new(),
229 levels: HashMap::new(),
230 document_id,
231 config,
232 text_processor,
233 llm_client: Some(llm_client),
234 })
235 }
236
    /// Attach or detach the LLM client after construction.
    pub fn set_llm_client(&mut self, llm_client: Option<Arc<dyn LLMClient>>) {
        self.llm_client = llm_client;
    }
241
242 #[cfg(feature = "parallel-processing")]
244 pub fn with_parallel_processing(
245 document_id: DocumentId,
246 config: HierarchicalConfig,
247 _parallel_processor: ParallelProcessor,
248 ) -> Result<Self> {
249 let text_processor = TextProcessor::new(1000, 100)?;
250
251 Ok(Self {
252 nodes: IndexMap::new(),
253 root_nodes: Vec::new(),
254 levels: HashMap::new(),
255 document_id,
256 config,
257 text_processor,
258 llm_client: None,
259 })
260 }
261
262 #[cfg(feature = "parallel-processing")]
264 pub fn with_parallel_and_llm(
265 document_id: DocumentId,
266 config: HierarchicalConfig,
267 _parallel_processor: ParallelProcessor,
268 llm_client: Arc<dyn LLMClient>,
269 ) -> Result<Self> {
270 let text_processor = TextProcessor::new(1000, 100)?;
271
272 Ok(Self {
273 nodes: IndexMap::new(),
274 root_nodes: Vec::new(),
275 levels: HashMap::new(),
276 document_id,
277 config,
278 text_processor,
279 llm_client: Some(llm_client),
280 })
281 }
282
    /// Entry point: build the whole tree from pre-chunked text.
    ///
    /// Creates the level-0 leaf nodes first, then merges bottom-up until a
    /// single root level remains. An empty chunk list is a no-op.
    pub async fn build_from_chunks(&mut self, chunks: Vec<TextChunk>) -> Result<()> {
        if chunks.is_empty() {
            return Ok(());
        }

        println!("Building hierarchical tree from {} chunks", chunks.len());

        let leaf_nodes = self.create_leaf_nodes(chunks)?;

        self.build_bottom_up(leaf_nodes).await?;

        println!(
            "Tree built with {} total nodes across {} levels",
            self.nodes.len(),
            self.levels.len()
        );

        Ok(())
    }
305
    /// Create level-0 leaf nodes, one per input chunk.
    ///
    /// Small inputs (< 10 chunks) always take the sequential path; larger
    /// inputs use rayon when the `parallel-processing` feature is enabled,
    /// falling back to the sequential path if the parallel pass errors.
    fn create_leaf_nodes(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
        // Parallel setup isn't worth it for tiny inputs.
        if chunks.len() < 10 {
            return self.create_leaf_nodes_sequential(chunks);
        }

        #[cfg(feature = "parallel-processing")]
        {
            use rayon::prelude::*;

            let node_results: std::result::Result<Vec<_>, crate::GraphRAGError> = chunks
                .par_iter()
                .map(|chunk| {
                    let node_id = NodeId::new(format!("leaf_{}", chunk.id));

                    // Each worker builds its own processor rather than
                    // sharing `self.text_processor` across threads.
                    let temp_processor =
                        crate::text::TextProcessor::new(1000, 100).map_err(|e| {
                            crate::GraphRAGError::Config {
                                message: format!("Failed to create text processor: {e}"),
                            }
                        })?;

                    let keywords = temp_processor.extract_keywords(&chunk.content, 5);

                    let summary =
                        self.generate_parallel_summary(&chunk.content, &temp_processor)?;

                    let node = TreeNode {
                        id: node_id.clone(),
                        content: chunk.content.clone(),
                        summary,
                        level: 0,
                        children: Vec::new(),
                        parent: None,
                        chunk_ids: vec![chunk.id.clone()],
                        keywords,
                        start_offset: chunk.start_offset,
                        end_offset: chunk.end_offset,
                    };

                    Ok((node_id, node))
                })
                .collect();

            match node_results {
                Ok(nodes) => {
                    let mut leaf_node_ids = Vec::new();

                    // Insert sequentially so `self.nodes` stays single-writer.
                    for (node_id, node) in nodes {
                        leaf_node_ids.push(node_id.clone());
                        self.nodes.insert(node_id, node);
                    }

                    println!("Created {} leaf nodes in parallel", leaf_node_ids.len());

                    self.levels.insert(0, leaf_node_ids.clone());

                    Ok(leaf_node_ids)
                },
                Err(e) => {
                    // Best-effort fallback: retry everything sequentially.
                    eprintln!("Error in parallel node creation: {e}");
                    self.create_leaf_nodes_sequential(chunks)
                },
            }
        }

        #[cfg(not(feature = "parallel-processing"))]
        {
            self.create_leaf_nodes_sequential(chunks)
        }
    }
385
386 fn create_leaf_nodes_sequential(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
388 let mut leaf_node_ids = Vec::new();
389
390 for chunk in chunks {
391 let node_id = NodeId::new(format!("leaf_{}", chunk.id));
392
393 let keywords = self.text_processor.extract_keywords(&chunk.content, 5);
395
396 let node = TreeNode {
397 id: node_id.clone(),
398 content: chunk.content.clone(),
399 summary: self.generate_extractive_summary(&chunk.content)?,
400 level: 0,
401 children: Vec::new(),
402 parent: None,
403 chunk_ids: vec![chunk.id],
404 keywords,
405 start_offset: chunk.start_offset,
406 end_offset: chunk.end_offset,
407 };
408
409 self.nodes.insert(node_id.clone(), node);
410 leaf_node_ids.push(node_id);
411 }
412
413 self.levels.insert(0, leaf_node_ids.clone());
415
416 Ok(leaf_node_ids)
417 }
418
    /// Summarize `text` with the configured LLM client.
    ///
    /// The prompt and length/temperature come from the per-level config
    /// (falling back to the global LLM settings); the model output is then
    /// clamped to the level's `max_length`.
    ///
    /// # Errors
    /// Returns a `Config` error when no LLM client is attached; otherwise
    /// propagates backend errors.
    pub async fn generate_llm_summary(
        &self,
        text: &str,
        level: usize,
        context: &str,
    ) -> Result<String> {
        let llm_client = self
            .llm_client
            .as_ref()
            .ok_or_else(|| GraphRAGError::Config {
                message: "LLM client not configured for summarization".to_string(),
            })?;

        let level_config = self.get_level_config(level);

        let prompt = self.create_summary_prompt(text, level, context, &level_config)?;

        let summary = llm_client
            .generate_summary(
                text,
                &prompt,
                level_config.max_length,
                // Per-level temperature overrides the global one when set.
                level_config
                    .temperature
                    .unwrap_or(self.config.llm_config.temperature),
            )
            .await?;

        // The model may overshoot; enforce the length cap ourselves.
        self.truncate_summary(&summary, level_config.max_length)
    }
455
456 pub async fn generate_llm_summaries_batch(
458 &self,
459 texts: &[(&str, usize, &str)], ) -> Result<Vec<String>> {
461 let llm_client = self
462 .llm_client
463 .as_ref()
464 .ok_or_else(|| GraphRAGError::Config {
465 message: "LLM client not configured for summarization".to_string(),
466 })?;
467
468 let mut prompts = Vec::new();
469 let mut configs = Vec::new();
470
471 for (text, level, context) in texts {
472 let level_config = self.get_level_config(*level);
473 let prompt = self.create_summary_prompt(text, *level, context, &level_config)?;
474 prompts.push(prompt);
475 configs.push(level_config);
476 }
477
478 let text_refs: Vec<&str> = texts.iter().map(|(t, _, _)| *t).collect();
480 let prompt_refs: Vec<&str> = prompts.iter().map(|p| p.as_str()).collect();
481
482 let summaries = llm_client
483 .generate_summary_batch(
484 &text_refs
485 .iter()
486 .zip(prompt_refs.iter())
487 .map(|(&t, &p)| (t, p))
488 .collect::<Vec<_>>(),
489 self.config.llm_config.max_tokens,
490 self.config.llm_config.temperature,
491 )
492 .await?;
493
494 let mut results = Vec::new();
496 for (i, summary) in summaries.into_iter().enumerate() {
497 let truncated = self.truncate_summary(&summary, configs[i].max_length)?;
498 results.push(truncated);
499 }
500
501 Ok(results)
502 }
503
    /// Build the summarization prompt for one text at a given tree level.
    ///
    /// A per-level `prompt_template` (with {text}/{context}/{level}/
    /// {max_length} placeholders) takes precedence; otherwise the prompt is
    /// chosen by the configured strategy.
    fn create_summary_prompt(
        &self,
        text: &str,
        level: usize,
        context: &str,
        level_config: &LevelConfig,
    ) -> Result<String> {
        // Custom template wins over any strategy-derived prompt.
        if let Some(template) = &level_config.prompt_template {
            return Ok(template
                .replace("{text}", text)
                .replace("{context}", context)
                .replace("{level}", &level.to_string())
                .replace("{max_length}", &level_config.max_length.to_string()));
        }

        match self.config.llm_config.strategy {
            // Same concise-summary wording regardless of level.
            LLMStrategy::Uniform => {
                Ok(format!(
                    "Create a concise summary of the following text. The summary should be approximately {} characters long.\n\nContext: {}\n\nText to summarize:\n{}\n\nSummary:",
                    level_config.max_length, context, text
                ))
            }
            // Wording escalates with the level: extract -> combine -> abstract.
            LLMStrategy::Adaptive => {
                if level == 0 {
                    Ok(format!(
                        "Extract the key information from this text segment. Keep it factual and under {} characters.\n\nContext: {}\n\nText:\n{}\n\nKey points:",
                        level_config.max_length, context, text
                    ))
                } else if level <= 2 {
                    Ok(format!(
                        "Create a coherent summary that combines the key information from this text. Make it approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nSummary:",
                        level_config.max_length, context, text
                    ))
                } else {
                    Ok(format!(
                        "Generate a high-level abstract summary of this content. Focus on the main themes and insights. Limit to approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstract summary:",
                        level_config.max_length, context, text
                    ))
                }
            }
            // Abstractive vs. extractive is decided by the level config.
            LLMStrategy::Progressive => {
                if level_config.use_abstractive {
                    Ok(format!(
                        "Generate an abstractive summary that synthesizes the key concepts and relationships in this text. The summary should be approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstractive summary:",
                        level_config.max_length, context, text
                    ))
                } else {
                    Ok(format!(
                        "Extract and organize the most important sentences from this text to create a coherent summary. Keep it under {} characters.\n\nContext: {}\n\nText:\n{}\n\nExtractive summary:",
                        level_config.max_length, context, text
                    ))
                }
            }
        }
    }
562
563 fn get_level_config(&self, level: usize) -> LevelConfig {
565 self.config
566 .llm_config
567 .level_configs
568 .get(&level)
569 .cloned()
570 .unwrap_or({
571 LevelConfig {
573 max_length: self.config.max_summary_length,
574 use_abstractive: match self.config.llm_config.strategy {
575 LLMStrategy::Progressive => level >= 2,
576 LLMStrategy::Adaptive => level >= 3,
577 LLMStrategy::Uniform => level > 0,
578 },
579 prompt_template: None,
580 temperature: None,
581 }
582 })
583 }
584
585 fn truncate_summary(&self, summary: &str, max_length: usize) -> Result<String> {
587 if summary.len() <= max_length {
588 return Ok(summary.to_string());
589 }
590
591 let sentences: Vec<&str> = summary
593 .split('.')
594 .filter(|s| !s.trim().is_empty())
595 .collect();
596
597 let mut result = String::new();
598 for sentence in sentences {
599 if result.len() + sentence.len() < max_length - 3 {
600 if !result.is_empty() {
601 result.push('.');
602 }
603 result.push_str(sentence.trim());
604 } else {
605 break;
606 }
607 }
608
609 if result.is_empty() {
610 result = summary.chars().take(max_length - 3).collect();
612 result.push_str("...");
613 } else {
614 result.push('.');
615 }
616
617 Ok(result)
618 }
619
    /// Cheap single-sentence "summary" used by the parallel leaf path.
    ///
    /// Picks the highest-scoring sentence by a simple length heuristic and
    /// truncates it to the configured max length. Takes an explicit
    /// `processor` so rayon workers don't share `self.text_processor`.
    #[allow(dead_code)]
    fn generate_parallel_summary(
        &self,
        text: &str,
        processor: &crate::text::TextProcessor,
    ) -> Result<String> {
        let sentences = processor.extract_sentences(text);

        if sentences.is_empty() {
            return Ok(String::new());
        }

        if sentences.len() == 1 {
            return Ok(sentences[0].clone());
        }

        let mut best_sentence = &sentences[0];
        let mut best_score = 0.0;

        for sentence in &sentences {
            let words: Vec<&str> = sentence.split_whitespace().collect();

            // Penalize very short (<5 words) and very long (>30 words) sentences.
            let length_score = if words.len() < 5 {
                0.1
            } else if words.len() > 30 {
                0.3
            } else {
                1.0
            };

            // Mild bias toward longer sentences within the accepted band.
            let word_score = words.len() as f32 * 0.1;
            let score = length_score + word_score;

            if score > best_score {
                best_score = score;
                best_sentence = sentence;
            }
        }

        // NOTE(review): assumes max_summary_length >= 3; smaller values
        // would underflow the subtraction below — confirm config invariants.
        if best_sentence.len() > self.config.max_summary_length {
            Ok(best_sentence
                .chars()
                .take(self.config.max_summary_length - 3)
                .collect::<String>()
                + "...")
        } else {
            Ok(best_sentence.clone())
        }
    }
672
673 async fn build_bottom_up(&mut self, leaf_nodes: Vec<NodeId>) -> Result<()> {
675 let mut current_level_nodes = leaf_nodes;
676 let mut current_level = 0;
677
678 while current_level_nodes.len() > 1 {
679 let next_level_nodes = self
680 .merge_level(¤t_level_nodes, current_level + 1)
681 .await?;
682
683 current_level_nodes = next_level_nodes;
684 current_level += 1;
685 }
686
687 self.root_nodes = current_level_nodes;
689
690 Ok(())
691 }
692
    /// Merge `level_nodes` into parents of up to `merge_size` children each,
    /// registering the parents as `new_level` and wiring child->parent links.
    async fn merge_level(
        &mut self,
        level_nodes: &[NodeId],
        new_level: usize,
    ) -> Result<Vec<NodeId>> {
        let mut new_level_nodes = Vec::new();
        for (node_counter, chunk) in level_nodes.chunks(self.config.merge_size).enumerate() {
            let merged_node_id = NodeId::new(format!("level_{new_level}_{node_counter}"));
            let merged_node = self
                .merge_nodes(chunk, merged_node_id.clone(), new_level)
                .await?;

            // Point every child at its freshly created parent.
            for child_id in chunk {
                if let Some(child_node) = self.nodes.get_mut(child_id) {
                    child_node.parent = Some(merged_node_id.clone());
                }
            }

            self.nodes.insert(merged_node_id.clone(), merged_node);
            new_level_nodes.push(merged_node_id);
        }

        self.levels.insert(new_level, new_level_nodes.clone());

        Ok(new_level_nodes)
    }
723
    /// Combine `node_ids` into a single parent `TreeNode` at `level`.
    ///
    /// Content is concatenated ("\n\n"-separated), chunk ids and keywords
    /// are aggregated (keywords deduped and capped at 10), and the summary
    /// is produced by the LLM when enabled and available, otherwise
    /// extractively. LLM failures degrade gracefully to the extractive path.
    async fn merge_nodes(
        &self,
        node_ids: &[NodeId],
        merged_id: NodeId,
        level: usize,
    ) -> Result<TreeNode> {
        let mut combined_content = String::new();
        let mut all_chunk_ids = Vec::new();
        let mut all_keywords = Vec::new();
        let mut min_offset = usize::MAX;
        let mut max_offset = 0;

        // Unknown ids are silently skipped.
        for node_id in node_ids {
            if let Some(node) = self.nodes.get(node_id) {
                if !combined_content.is_empty() {
                    combined_content.push_str("\n\n");
                }
                combined_content.push_str(&node.content);
                all_chunk_ids.extend(node.chunk_ids.clone());
                all_keywords.extend(node.keywords.clone());
                min_offset = min_offset.min(node.start_offset);
                max_offset = max_offset.max(node.end_offset);
            }
        }

        // Sort so `dedup` removes all duplicates, then keep the top 10.
        all_keywords.sort();
        all_keywords.dedup();
        all_keywords.truncate(10);

        let summary = if self.config.llm_config.enabled {
            if self.llm_client.is_some() {
                let context = format!(
                    "Merging {} nodes at level {}. This represents a higher-level abstraction of the document content.",
                    node_ids.len(),
                    level
                );

                match self
                    .generate_llm_summary(&combined_content, level, &context)
                    .await
                {
                    Ok(llm_summary) => {
                        println!(
                            "✅ Generated LLM-based summary for level {} ({} chars)",
                            level,
                            llm_summary.len()
                        );
                        llm_summary
                    },
                    Err(e) => {
                        // Fall back rather than failing the whole merge.
                        eprintln!("⚠️ LLM summarization failed for level {}: {}, falling back to extractive", level, e);
                        self.generate_extractive_summary(&combined_content)?
                    },
                }
            } else {
                // Enabled but no client attached: extractive fallback.
                self.generate_extractive_summary(&combined_content)?
            }
        } else {
            self.generate_extractive_summary(&combined_content)?
        };

        // Parent link is wired up later by `merge_level`.
        Ok(TreeNode {
            id: merged_id,
            content: combined_content,
            summary,
            level,
            children: node_ids.to_vec(),
            parent: None,
            chunk_ids: all_chunk_ids,
            keywords: all_keywords,
            start_offset: min_offset,
            end_offset: max_offset,
        })
    }
804
805 fn generate_extractive_summary(&self, text: &str) -> Result<String> {
807 let sentences = self.text_processor.extract_sentences(text);
808
809 if sentences.is_empty() {
810 return Ok(String::new());
811 }
812
813 if sentences.len() == 1 {
814 return Ok(sentences[0].clone());
815 }
816
817 let mut sentence_scores: Vec<(usize, f32)> = sentences
819 .iter()
820 .enumerate()
821 .map(|(i, sentence)| {
822 let score = self.score_sentence(sentence, &sentences);
823 (i, score)
824 })
825 .collect();
826
827 sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
829
830 let mut summary = String::new();
832 let mut selected_indices = Vec::new();
833
834 for (sentence_idx, _score) in sentence_scores {
835 let sentence = &sentences[sentence_idx];
836 if summary.len() + sentence.len() <= self.config.max_summary_length {
837 selected_indices.push(sentence_idx);
838 if !summary.is_empty() {
839 summary.push(' ');
840 }
841 summary.push_str(sentence);
842 }
843 }
844
845 if summary.is_empty() && !sentences.is_empty() {
847 let first_sentence = &sentences[0];
848 if first_sentence.len() <= self.config.max_summary_length {
849 summary = first_sentence.clone();
850 } else {
851 summary = first_sentence
852 .chars()
853 .take(self.config.max_summary_length - 3)
854 .collect::<String>()
855 + "...";
856 }
857 }
858
859 Ok(summary)
860 }
861
862 fn score_sentence(&self, sentence: &str, all_sentences: &[String]) -> f32 {
864 let words: Vec<&str> = sentence.split_whitespace().collect();
865
866 let length_score = if words.len() < 5 {
868 0.1
869 } else if words.len() > 30 {
870 0.3
871 } else {
872 1.0
873 };
874
875 let position_score = 0.5; let mut word_freq_score = 0.0;
880 let total_words: Vec<&str> = all_sentences
881 .iter()
882 .flat_map(|s| s.split_whitespace())
883 .collect();
884
885 for word in &words {
886 let word_lower = word.to_lowercase();
887 if word_lower.len() > 3 && !self.is_stop_word(&word_lower) {
888 let freq = total_words
889 .iter()
890 .filter(|&&w| w.to_lowercase() == word_lower)
891 .count();
892 if freq > 1 {
893 word_freq_score += freq as f32 / total_words.len() as f32;
894 }
895 }
896 }
897
898 length_score * 0.4 + position_score * 0.2 + word_freq_score * 0.4
899 }
900
901 fn is_stop_word(&self, word: &str) -> bool {
903 const STOP_WORDS: &[&str] = &[
904 "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
905 "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
906 "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
907 "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
908 "go", "me",
909 ];
910 STOP_WORDS.contains(&word)
911 }
912
913 pub fn query(&self, query: &str, max_results: usize) -> Result<Vec<QueryResult>> {
915 let query_keywords = self.text_processor.extract_keywords(query, 5);
916 let mut results = Vec::new();
917
918 for (node_id, node) in &self.nodes {
920 let score = self.calculate_relevance_score(node, &query_keywords, query);
921
922 if score > 0.1 {
923 results.push(QueryResult {
924 node_id: node_id.clone(),
925 score,
926 level: node.level,
927 summary: node.summary.clone(),
928 keywords: node.keywords.clone(),
929 chunk_ids: node.chunk_ids.clone(),
930 });
931 }
932 }
933
934 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
936 results.truncate(max_results);
937
938 Ok(results)
939 }
940
    /// Heuristic relevance of a node to a query.
    ///
    /// Score = (#query keywords found in summary+keywords)
    ///       + 2 * (fraction of query words appearing in that text)
    ///       + 0.5 / (level + 1), biasing toward leaf-level detail.
    fn calculate_relevance_score(
        &self,
        node: &TreeNode,
        query_keywords: &[String],
        query: &str,
    ) -> f32 {
        let mut score = 0.0;

        // Match against the summary and keyword list, case-insensitively.
        let node_text = format!("{} {}", node.summary, node.keywords.join(" ")).to_lowercase();

        for keyword in query_keywords {
            if node_text.contains(&keyword.to_lowercase()) {
                score += 1.0;
            }
        }

        let query_words: Vec<&str> = query.split_whitespace().collect();
        let node_words: Vec<&str> = node_text.split_whitespace().collect();

        let mut overlap_count = 0;
        for query_word in &query_words {
            if node_words.contains(&query_word.to_lowercase().as_str()) {
                overlap_count += 1;
            }
        }

        if !query_words.is_empty() {
            score += (overlap_count as f32 / query_words.len() as f32) * 2.0;
        }

        // Lower levels (finer granularity) get a small boost.
        let level_score = 1.0 / (node.level + 1) as f32;
        score += level_score * 0.5;

        score
    }
979
    /// Walk parent links from `node_id` toward the root.
    ///
    /// NOTE: the returned list starts with the node itself (the walk begins
    /// at `node_id`), followed by its parent chain. Missing ids end the walk.
    pub fn get_ancestors(&self, node_id: &NodeId) -> Vec<&TreeNode> {
        let mut ancestors = Vec::new();
        let mut current_id = Some(node_id.clone());

        while let Some(id) = current_id {
            if let Some(node) = self.nodes.get(&id) {
                ancestors.push(node);
                current_id = node.parent.clone();
            } else {
                break;
            }
        }

        ancestors
    }
996
997 pub fn get_descendants(&self, node_id: &NodeId) -> Vec<&TreeNode> {
999 let mut descendants = Vec::new();
1000 let mut queue = VecDeque::new();
1001
1002 if let Some(node) = self.nodes.get(node_id) {
1003 queue.extend(node.children.iter());
1004 }
1005
1006 while let Some(child_id) = queue.pop_front() {
1007 if let Some(child_node) = self.nodes.get(child_id) {
1008 descendants.push(child_node);
1009 queue.extend(child_node.children.iter());
1010 }
1011 }
1012
1013 descendants
1014 }
1015
    /// Look up a node by id.
    pub fn get_node(&self, node_id: &NodeId) -> Option<&TreeNode> {
        self.nodes.get(node_id)
    }
1020
1021 pub fn get_level_nodes(&self, level: usize) -> Vec<&TreeNode> {
1023 if let Some(node_ids) = self.levels.get(&level) {
1024 node_ids
1025 .iter()
1026 .filter_map(|id| self.nodes.get(id))
1027 .collect()
1028 } else {
1029 Vec::new()
1030 }
1031 }
1032
    /// The top-level node(s) produced by `build_bottom_up`.
    pub fn get_root_nodes(&self) -> Vec<&TreeNode> {
        self.root_nodes
            .iter()
            .filter_map(|id| self.nodes.get(id))
            .collect()
    }
1040
    /// Id of the document this tree was built from.
    pub fn document_id(&self) -> &DocumentId {
        &self.document_id
    }
1045
1046 pub fn get_statistics(&self) -> TreeStatistics {
1048 let max_level = self.levels.keys().max().copied().unwrap_or(0);
1049 let total_nodes = self.nodes.len();
1050 let nodes_per_level: HashMap<usize, usize> = self
1051 .levels
1052 .iter()
1053 .map(|(level, nodes)| (*level, nodes.len()))
1054 .collect();
1055
1056 TreeStatistics {
1057 total_nodes,
1058 max_level,
1059 nodes_per_level,
1060 root_count: self.root_nodes.len(),
1061 document_id: self.document_id.clone(),
1062 }
1063 }
1064
    /// Serialize the tree to a JSON string (via the `json` crate).
    ///
    /// Note: only the four scalar config fields are serialized — not
    /// `llm_config` — so a round-trip through `from_json` resets LLM
    /// settings to defaults.
    pub fn to_json(&self) -> Result<String> {
        use json::JsonValue;

        // Skeleton object; the maps below are filled in afterwards.
        let mut tree_json = json::object! {
            "document_id": self.document_id.to_string(),
            "config": {
                "merge_size": self.config.merge_size,
                "max_summary_length": self.config.max_summary_length,
                "min_node_size": self.config.min_node_size,
                "overlap_sentences": self.config.overlap_sentences
            },
            "nodes": {},
            "root_nodes": [],
            "levels": {}
        };

        // Nodes are keyed by their stringified id.
        for (node_id, node) in &self.nodes {
            let node_json = json::object! {
                "id": node_id.to_string(),
                "content": node.content.clone(),
                "summary": node.summary.clone(),
                "level": node.level,
                "children": node.children.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "parent": node.parent.as_ref().map(|id| id.to_string()),
                "chunk_ids": node.chunk_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "keywords": node.keywords.clone(),
                "start_offset": node.start_offset,
                "end_offset": node.end_offset
            };
            tree_json["nodes"][node_id.to_string()] = node_json;
        }

        tree_json["root_nodes"] = self
            .root_nodes
            .iter()
            .map(|id| JsonValue::String(id.to_string()))
            .collect::<Vec<_>>()
            .into();

        // Level keys are stringified usizes.
        for (level, node_ids) in &self.levels {
            tree_json["levels"][level.to_string()] = node_ids
                .iter()
                .map(|id| JsonValue::String(id.to_string()))
                .collect::<Vec<_>>()
                .into();
        }

        Ok(tree_json.dump())
    }
1118
    /// Rebuild a `DocumentTree` from the JSON produced by `to_json`.
    ///
    /// Scalar config fields fall back to the usual defaults when missing;
    /// `llm_config` is always `LLMConfig::default()` because `to_json` does
    /// not serialize it. Malformed node fields degrade to empty/zero values
    /// rather than erroring.
    ///
    /// # Errors
    /// Fails on unparsable JSON or a missing/non-string `document_id`.
    pub fn from_json(json_str: &str) -> Result<Self> {
        let json_data = json::parse(json_str).map_err(crate::GraphRAGError::Json)?;

        let document_id = DocumentId::new(
            json_data["document_id"]
                .as_str()
                .ok_or_else(|| {
                    crate::GraphRAGError::Json(json::Error::WrongType(
                        "document_id must be string".to_string(),
                    ))
                })?
                .to_string(),
        );

        let config_json = &json_data["config"];
        let config = HierarchicalConfig {
            merge_size: config_json["merge_size"].as_usize().unwrap_or(5),
            max_summary_length: config_json["max_summary_length"].as_usize().unwrap_or(200),
            min_node_size: config_json["min_node_size"].as_usize().unwrap_or(50),
            overlap_sentences: config_json["overlap_sentences"].as_usize().unwrap_or(2),
            llm_config: LLMConfig::default(),
        };

        let mut tree = Self::new(document_id, config)?;

        // Restore every node; ids are taken from the object keys.
        if let json::JsonValue::Object(nodes_obj) = &json_data["nodes"] {
            for (node_id_str, node_json) in nodes_obj.iter() {
                let node_id = NodeId::new(node_id_str.to_string());

                let children: Vec<NodeId> = node_json["children"]
                    .members()
                    .filter_map(|v| v.as_str())
                    .map(|s| NodeId::new(s.to_string()))
                    .collect();

                let parent = node_json["parent"]
                    .as_str()
                    .map(|s| NodeId::new(s.to_string()));

                let chunk_ids: Vec<ChunkId> = node_json["chunk_ids"]
                    .members()
                    .filter_map(|v| v.as_str())
                    .map(|s| ChunkId::new(s.to_string()))
                    .collect();

                let keywords: Vec<String> = node_json["keywords"]
                    .members()
                    .filter_map(|v| v.as_str())
                    .map(|s| s.to_string())
                    .collect();

                let node = TreeNode {
                    id: node_id.clone(),
                    content: node_json["content"].as_str().unwrap_or("").to_string(),
                    summary: node_json["summary"].as_str().unwrap_or("").to_string(),
                    level: node_json["level"].as_usize().unwrap_or(0),
                    children,
                    parent,
                    chunk_ids,
                    keywords,
                    start_offset: node_json["start_offset"].as_usize().unwrap_or(0),
                    end_offset: node_json["end_offset"].as_usize().unwrap_or(0),
                };

                tree.nodes.insert(node_id, node);
            }
        }

        tree.root_nodes = json_data["root_nodes"]
            .members()
            .filter_map(|v| v.as_str())
            .map(|s| NodeId::new(s.to_string()))
            .collect();

        // Level keys are stringified usizes; unparsable keys are skipped.
        if let json::JsonValue::Object(levels_obj) = &json_data["levels"] {
            for (level_str, level_json) in levels_obj.iter() {
                if let Ok(level) = level_str.parse::<usize>() {
                    let node_ids: Vec<NodeId> = level_json
                        .members()
                        .filter_map(|v| v.as_str())
                        .map(|s| NodeId::new(s.to_string()))
                        .collect();
                    tree.levels.insert(level, node_ids);
                }
            }
        }

        Ok(tree)
    }
1212}
1213
/// A single scored hit returned by `DocumentTree::query`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Matching node.
    pub node_id: NodeId,
    /// Heuristic relevance score (higher is better).
    pub score: f32,
    /// Tree level of the matching node (0 = leaf).
    pub level: usize,
    /// Node summary, copied so callers need no second lookup.
    pub summary: String,
    /// Node keywords.
    pub keywords: Vec<String>,
    /// Source chunks behind the node.
    pub chunk_ids: Vec<ChunkId>,
}
1230
/// Shape summary of a `DocumentTree`, produced by `get_statistics`.
#[derive(Debug)]
pub struct TreeStatistics {
    /// Total node count across all levels.
    pub total_nodes: usize,
    /// Deepest populated level (0 when the tree is empty).
    pub max_level: usize,
    /// Node count per level.
    pub nodes_per_level: HashMap<usize, usize>,
    /// Number of root nodes.
    pub root_count: usize,
    /// Source document id.
    pub document_id: DocumentId,
}
1245
1246impl TreeStatistics {
1247 pub fn print(&self) {
1249 println!("Hierarchical Tree Statistics:");
1250 println!(" Document ID: {}", self.document_id);
1251 println!(" Total nodes: {}", self.total_nodes);
1252 println!(" Max level: {}", self.max_level);
1253 println!(" Root nodes: {}", self.root_count);
1254 println!(" Nodes per level:");
1255
1256 let mut levels: Vec<_> = self.nodes_per_level.iter().collect();
1257 levels.sort_by_key(|(level, _)| *level);
1258
1259 for (level, count) in levels {
1260 println!(" Level {level}: {count} nodes");
1261 }
1262 }
1263}
1264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Default-config tree construction should succeed.
    #[test]
    fn test_tree_creation() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config);
        assert!(tree.is_ok());
    }

    /// Extractive summaries are non-empty and respect the length cap.
    #[test]
    fn test_extractive_summarization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let text = "This is the first sentence. This is a second sentence with more details. This is the final sentence.";
        let summary = tree.generate_extractive_summary(text).unwrap();

        assert!(!summary.is_empty());
        assert!(summary.len() <= tree.config.max_summary_length);
    }

    /// JSON round-trip preserves the document id.
    #[test]
    fn test_json_serialization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let json = tree.to_json().unwrap();
        assert!(json.contains("test_doc"));

        let loaded_tree = DocumentTree::from_json(&json).unwrap();
        assert_eq!(loaded_tree.document_id.to_string(), "test_doc");
    }
}