1use crate::{
2 core::{ChunkId, DocumentId, TextChunk, GraphRAGError},
3 text::TextProcessor,
4 Result,
5};
6#[cfg(feature = "parallel-processing")]
7use crate::parallel::ParallelProcessor;
8use indexmap::IndexMap;
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11
/// Abstraction over a language model used for abstractive summarization.
///
/// Implementors must provide `generate_summary`; `generate_summary_batch`
/// has a default sequential implementation built on top of it.
#[async_trait::async_trait]
pub trait LLMClient: Send + Sync {
    /// Generate a summary of `text` guided by `prompt`.
    /// `max_tokens` bounds the output length; `temperature` controls sampling randomness.
    async fn generate_summary(&self, text: &str, prompt: &str, max_tokens: usize, temperature: f32) -> Result<String>;

    /// Summarize several `(text, prompt)` pairs.
    ///
    /// The default implementation issues the requests one at a time and
    /// short-circuits on the first error; implementors may override it with
    /// a true batched call.
    async fn generate_summary_batch(&self, texts: &[(&str, &str)], max_tokens: usize, temperature: f32) -> Result<Vec<String>> {
        let mut results = Vec::new();
        for (text, prompt) in texts {
            let summary = self.generate_summary(text, prompt, max_tokens, temperature).await?;
            results.push(summary);
        }
        Ok(results)
    }

    /// Name of the underlying model (for logging/diagnostics).
    fn model_name(&self) -> &str;
}
31
/// Newtype identifier for a node in a `DocumentTree`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeId(pub String);
35
36impl NodeId {
37 pub fn new(id: String) -> Self {
39 Self(id)
40 }
41}
42
43impl std::fmt::Display for NodeId {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(f, "{}", self.0)
46 }
47}
48
49impl From<String> for NodeId {
50 fn from(s: String) -> Self {
51 Self(s)
52 }
53}
54
55impl From<NodeId> for String {
56 fn from(id: NodeId) -> Self {
57 id.0
58 }
59}
60
/// Configuration for building a hierarchical (RAPTOR-style) document tree.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HierarchicalConfig {
    /// Number of child nodes merged into one parent at each level.
    pub merge_size: usize,
    /// Maximum summary length in characters.
    pub max_summary_length: usize,
    /// Minimum content size for a node.
    /// NOTE(review): not enforced anywhere in this file — confirm usage elsewhere.
    pub min_node_size: usize,
    /// Sentence overlap between adjacent nodes.
    /// NOTE(review): not used in this file — confirm usage elsewhere.
    pub overlap_sentences: usize,
    /// LLM-based summarization settings.
    pub llm_config: LLMConfig,
}
75
/// Settings controlling LLM-based (abstractive) summarization.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMConfig {
    /// When false, only the extractive summarizer is used.
    pub enabled: bool,
    /// Model identifier passed to the client (e.g. "llama3.1:8b").
    pub model_name: String,
    /// Sampling temperature for generation.
    pub temperature: f32,
    /// Token budget per summary request.
    pub max_tokens: usize,
    /// How prompt/summary style varies across tree levels.
    pub strategy: LLMStrategy,
    /// Per-level overrides, keyed by tree level.
    pub level_configs: HashMap<usize, LevelConfig>,
}
92
/// Strategy for varying summarization behavior across tree levels.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LLMStrategy {
    /// Same prompt at every level.
    Uniform,
    /// Prompt wording adapts to the level depth.
    Adaptive,
    /// Extractive at lower levels, abstractive once the level config opts in.
    Progressive,
}
103
/// Per-level summarization overrides.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LevelConfig {
    /// Maximum summary length in characters for this level.
    pub max_length: usize,
    /// Prefer abstractive (LLM) over extractive summarization.
    pub use_abstractive: bool,
    /// Optional custom prompt template; placeholders:
    /// `{text}`, `{context}`, `{level}`, `{max_length}`.
    pub prompt_template: Option<String>,
    /// Optional temperature override; `None` falls back to the global setting.
    pub temperature: Option<f32>,
}
116
impl Default for HierarchicalConfig {
    /// Defaults: 5 children per parent, 200-char summaries, 50-char minimum
    /// nodes, 2 overlapping sentences, LLM summarization disabled.
    fn default() -> Self {
        Self {
            merge_size: 5,
            max_summary_length: 200,
            min_node_size: 50,
            overlap_sentences: 2,
            llm_config: LLMConfig::default(),
        }
    }
}
128
129impl Default for LLMConfig {
130 fn default() -> Self {
131 Self {
132 enabled: false, model_name: "llama3.1:8b".to_string(),
134 temperature: 0.3, max_tokens: 150,
136 strategy: LLMStrategy::Progressive,
137 level_configs: HashMap::new(),
138 }
139 }
140}
141
142impl Default for LevelConfig {
143 fn default() -> Self {
144 Self {
145 max_length: 200,
146 use_abstractive: false, prompt_template: None,
148 temperature: None,
149 }
150 }
151}
152
153#[derive(Debug, Clone)]
155pub struct TreeNode {
156 pub id: NodeId,
158 pub content: String,
160 pub summary: String,
162 pub level: usize,
164 pub children: Vec<NodeId>,
166 pub parent: Option<NodeId>,
168 pub chunk_ids: Vec<ChunkId>,
170 pub keywords: Vec<String>,
172 pub start_offset: usize,
174 pub end_offset: usize,
176}
177
/// Hierarchical summary tree over a single document: leaves are text chunks,
/// higher levels are merged groups with progressively more abstract summaries.
pub struct DocumentTree {
    /// All nodes keyed by id; insertion order preserved.
    nodes: IndexMap<NodeId, TreeNode>,
    /// Ids of the top-level nodes.
    root_nodes: Vec<NodeId>,
    /// Node ids grouped by tree level (0 = leaves).
    levels: HashMap<usize, Vec<NodeId>>,
    /// Document this tree was built from.
    document_id: DocumentId,
    /// Build and summarization configuration.
    config: HierarchicalConfig,
    /// Text utilities (sentence/keyword extraction).
    text_processor: TextProcessor,
    /// Optional LLM used for abstractive summaries.
    llm_client: Option<Arc<dyn LLMClient>>,
}
188
189impl DocumentTree {
190 pub fn new(document_id: DocumentId, config: HierarchicalConfig) -> Result<Self> {
192 let text_processor = TextProcessor::new(1000, 100)?;
193
194 Ok(Self {
195 nodes: IndexMap::new(),
196 root_nodes: Vec::new(),
197 levels: HashMap::new(),
198 document_id,
199 config,
200 text_processor,
201 llm_client: None,
202 })
203 }
204
205 pub fn with_llm_client(
207 document_id: DocumentId,
208 config: HierarchicalConfig,
209 llm_client: Arc<dyn LLMClient>,
210 ) -> Result<Self> {
211 let text_processor = TextProcessor::new(1000, 100)?;
212
213 Ok(Self {
214 nodes: IndexMap::new(),
215 root_nodes: Vec::new(),
216 levels: HashMap::new(),
217 document_id,
218 config,
219 text_processor,
220 llm_client: Some(llm_client),
221 })
222 }
223
    /// Attach or detach (`None`) the LLM client used for summarization.
    pub fn set_llm_client(&mut self, llm_client: Option<Arc<dyn LLMClient>>) {
        self.llm_client = llm_client;
    }
228
229 #[cfg(feature = "parallel-processing")]
231 pub fn with_parallel_processing(
232 document_id: DocumentId,
233 config: HierarchicalConfig,
234 _parallel_processor: ParallelProcessor,
235 ) -> Result<Self> {
236 let text_processor = TextProcessor::new(1000, 100)?;
237
238 Ok(Self {
239 nodes: IndexMap::new(),
240 root_nodes: Vec::new(),
241 levels: HashMap::new(),
242 document_id,
243 config,
244 text_processor,
245 llm_client: None,
246 })
247 }
248
249 #[cfg(feature = "parallel-processing")]
251 pub fn with_parallel_and_llm(
252 document_id: DocumentId,
253 config: HierarchicalConfig,
254 _parallel_processor: ParallelProcessor,
255 llm_client: Arc<dyn LLMClient>,
256 ) -> Result<Self> {
257 let text_processor = TextProcessor::new(1000, 100)?;
258
259 Ok(Self {
260 nodes: IndexMap::new(),
261 root_nodes: Vec::new(),
262 levels: HashMap::new(),
263 document_id,
264 config,
265 text_processor,
266 llm_client: Some(llm_client),
267 })
268 }
269
    /// Build the complete hierarchy from a flat list of text chunks.
    ///
    /// Creates one leaf node per chunk (level 0), then repeatedly merges
    /// groups of nodes into parents until a single root level remains.
    /// Empty input is a no-op.
    pub async fn build_from_chunks(&mut self, chunks: Vec<TextChunk>) -> Result<()> {
        if chunks.is_empty() {
            return Ok(());
        }

        println!("Building hierarchical tree from {} chunks", chunks.len());

        // Level 0: one node per chunk.
        let leaf_nodes = self.create_leaf_nodes(chunks)?;

        // Levels 1..: merge upward until a root level remains.
        self.build_bottom_up(leaf_nodes).await?;

        println!(
            "Tree built with {} total nodes across {} levels",
            self.nodes.len(),
            self.levels.len()
        );

        Ok(())
    }
292
    /// Create level-0 nodes from `chunks`.
    ///
    /// Small inputs (< 10 chunks) always take the sequential path to avoid
    /// parallel overhead. With the `parallel-processing` feature, larger
    /// inputs are processed with rayon; on any parallel error the method
    /// falls back to the sequential path. Without the feature it is always
    /// sequential.
    fn create_leaf_nodes(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
        if chunks.len() < 10 {
            return self.create_leaf_nodes_sequential(chunks);
        }

        #[cfg(feature = "parallel-processing")]
        {
            use rayon::prelude::*;

            let node_results: std::result::Result<Vec<_>, crate::GraphRAGError> = chunks
                .par_iter()
                .map(|chunk| {
                    let node_id = NodeId::new(format!("leaf_{}", chunk.id));

                    // Each worker builds its own processor rather than sharing
                    // `self.text_processor` across threads.
                    let temp_processor = crate::text::TextProcessor::new(1000, 100)
                        .map_err(|e| crate::GraphRAGError::Config {
                            message: format!("Failed to create text processor: {e}"),
                        })?;

                    let keywords = temp_processor.extract_keywords(&chunk.content, 5);

                    let summary = self.generate_parallel_summary(&chunk.content, &temp_processor)?;

                    let node = TreeNode {
                        id: node_id.clone(),
                        content: chunk.content.clone(),
                        summary,
                        level: 0,
                        children: Vec::new(),
                        parent: None,
                        chunk_ids: vec![chunk.id.clone()],
                        keywords,
                        start_offset: chunk.start_offset,
                        end_offset: chunk.end_offset,
                    };

                    Ok((node_id, node))
                })
                .collect();

            match node_results {
                Ok(nodes) => {
                    let mut leaf_node_ids = Vec::new();

                    // Insert in chunk order (rayon's collect preserves order).
                    for (node_id, node) in nodes {
                        leaf_node_ids.push(node_id.clone());
                        self.nodes.insert(node_id, node);
                    }

                    println!("Created {} leaf nodes in parallel", leaf_node_ids.len());

                    self.levels.insert(0, leaf_node_ids.clone());

                    Ok(leaf_node_ids)
                }
                Err(e) => {
                    // Best-effort: report the parallel failure and retry
                    // sequentially with the still-owned chunks.
                    eprintln!("Error in parallel node creation: {e}");
                    self.create_leaf_nodes_sequential(chunks)
                }
            }
        }

        #[cfg(not(feature = "parallel-processing"))]
        {
            self.create_leaf_nodes_sequential(chunks)
        }
    }
369
370 fn create_leaf_nodes_sequential(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
372 let mut leaf_node_ids = Vec::new();
373
374 for chunk in chunks {
375 let node_id = NodeId::new(format!("leaf_{}", chunk.id));
376
377 let keywords = self.text_processor.extract_keywords(&chunk.content, 5);
379
380 let node = TreeNode {
381 id: node_id.clone(),
382 content: chunk.content.clone(),
383 summary: self.generate_extractive_summary(&chunk.content)?,
384 level: 0,
385 children: Vec::new(),
386 parent: None,
387 chunk_ids: vec![chunk.id],
388 keywords,
389 start_offset: chunk.start_offset,
390 end_offset: chunk.end_offset,
391 };
392
393 self.nodes.insert(node_id.clone(), node);
394 leaf_node_ids.push(node_id);
395 }
396
397 self.levels.insert(0, leaf_node_ids.clone());
399
400 Ok(leaf_node_ids)
401 }
402
    /// Summarize `text` with the configured LLM client, using the per-level
    /// prompt and temperature, then truncate to the level's max length.
    ///
    /// # Errors
    /// Returns a `Config` error when no LLM client is attached; otherwise
    /// propagates client errors.
    pub async fn generate_llm_summary(
        &self,
        text: &str,
        level: usize,
        context: &str,
    ) -> Result<String> {
        let llm_client = self.llm_client.as_ref()
            .ok_or_else(|| GraphRAGError::Config {
                message: "LLM client not configured for summarization".to_string(),
            })?;

        let level_config = self.get_level_config(level);

        let prompt = self.create_summary_prompt(text, level, context, &level_config)?;

        let summary = llm_client.generate_summary(
            text,
            &prompt,
            level_config.max_length,
            // A per-level temperature override wins over the global default.
            level_config.temperature.unwrap_or(self.config.llm_config.temperature),
        ).await?;

        // Enforce the length budget even if the model overruns it.
        self.truncate_summary(&summary, level_config.max_length)
    }
433
434 pub async fn generate_llm_summaries_batch(
436 &self,
437 texts: &[(&str, usize, &str)], ) -> Result<Vec<String>> {
439 let llm_client = self.llm_client.as_ref()
440 .ok_or_else(|| GraphRAGError::Config {
441 message: "LLM client not configured for summarization".to_string(),
442 })?;
443
444 let mut prompts = Vec::new();
445 let mut configs = Vec::new();
446
447 for (text, level, context) in texts {
448 let level_config = self.get_level_config(*level);
449 let prompt = self.create_summary_prompt(text, *level, context, &level_config)?;
450 prompts.push(prompt);
451 configs.push(level_config);
452 }
453
454 let text_refs: Vec<&str> = texts.iter().map(|(t, _, _)| *t).collect();
456 let prompt_refs: Vec<&str> = prompts.iter().map(|p| p.as_str()).collect();
457
458 let summaries = llm_client.generate_summary_batch(
459 &text_refs.iter().zip(prompt_refs.iter()).map(|(&t, &p)| (t, p)).collect::<Vec<_>>(),
460 self.config.llm_config.max_tokens,
461 self.config.llm_config.temperature,
462 ).await?;
463
464 let mut results = Vec::new();
466 for (i, summary) in summaries.into_iter().enumerate() {
467 let truncated = self.truncate_summary(&summary, configs[i].max_length)?;
468 results.push(truncated);
469 }
470
471 Ok(results)
472 }
473
    /// Render the prompt used for summarizing `text` at `level`.
    ///
    /// A custom template (placeholders `{text}`, `{context}`, `{level}`,
    /// `{max_length}`) takes precedence; otherwise the prompt is selected by
    /// the configured `LLMStrategy`.
    fn create_summary_prompt(
        &self,
        text: &str,
        level: usize,
        context: &str,
        level_config: &LevelConfig,
    ) -> Result<String> {
        if let Some(template) = &level_config.prompt_template {
            return Ok(template
                .replace("{text}", text)
                .replace("{context}", context)
                .replace("{level}", &level.to_string())
                .replace("{max_length}", &level_config.max_length.to_string()));
        }

        match self.config.llm_config.strategy {
            // One wording regardless of depth.
            LLMStrategy::Uniform => {
                Ok(format!(
                    "Create a concise summary of the following text. The summary should be approximately {} characters long.\n\nContext: {}\n\nText to summarize:\n{}\n\nSummary:",
                    level_config.max_length, context, text
                ))
            }
            // Wording adapts to depth: extraction at leaves, synthesis in the
            // middle, abstraction near the root.
            LLMStrategy::Adaptive => {
                if level == 0 {
                    Ok(format!(
                        "Extract the key information from this text segment. Keep it factual and under {} characters.\n\nContext: {}\n\nText:\n{}\n\nKey points:",
                        level_config.max_length, context, text
                    ))
                } else if level <= 2 {
                    Ok(format!(
                        "Create a coherent summary that combines the key information from this text. Make it approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nSummary:",
                        level_config.max_length, context, text
                    ))
                } else {
                    Ok(format!(
                        "Generate a high-level abstract summary of this content. Focus on the main themes and insights. Limit to approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstract summary:",
                        level_config.max_length, context, text
                    ))
                }
            }
            // Abstractive vs extractive decided by the level's own flag.
            LLMStrategy::Progressive => {
                if level_config.use_abstractive {
                    Ok(format!(
                        "Generate an abstractive summary that synthesizes the key concepts and relationships in this text. The summary should be approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstractive summary:",
                        level_config.max_length, context, text
                    ))
                } else {
                    Ok(format!(
                        "Extract and organize the most important sentences from this text to create a coherent summary. Keep it under {} characters.\n\nContext: {}\n\nText:\n{}\n\nExtractive summary:",
                        level_config.max_length, context, text
                    ))
                }
            }
        }
    }
532
533 fn get_level_config(&self, level: usize) -> LevelConfig {
535 self.config.llm_config.level_configs
536 .get(&level)
537 .cloned()
538 .unwrap_or_else(|| {
539 LevelConfig {
541 max_length: self.config.max_summary_length,
542 use_abstractive: match self.config.llm_config.strategy {
543 LLMStrategy::Progressive => level >= 2,
544 LLMStrategy::Adaptive => level >= 3,
545 LLMStrategy::Uniform => level > 0,
546 },
547 prompt_template: None,
548 temperature: None,
549 }
550 })
551 }
552
553 fn truncate_summary(&self, summary: &str, max_length: usize) -> Result<String> {
555 if summary.len() <= max_length {
556 return Ok(summary.to_string());
557 }
558
559 let sentences: Vec<&str> = summary
561 .split('.')
562 .filter(|s| !s.trim().is_empty())
563 .collect();
564
565 let mut result = String::new();
566 for sentence in sentences {
567 if result.len() + sentence.len() + 1 <= max_length - 3 {
568 if !result.is_empty() {
569 result.push('.');
570 }
571 result.push_str(sentence.trim());
572 } else {
573 break;
574 }
575 }
576
577 if result.is_empty() {
578 result = summary.chars().take(max_length - 3).collect();
580 result.push_str("...");
581 } else {
582 result.push('.');
583 }
584
585 Ok(result)
586 }
587
588 #[allow(dead_code)]
589 fn generate_parallel_summary(
590 &self,
591 text: &str,
592 processor: &crate::text::TextProcessor,
593 ) -> Result<String> {
594 let sentences = processor.extract_sentences(text);
595
596 if sentences.is_empty() {
597 return Ok(String::new());
598 }
599
600 if sentences.len() == 1 {
601 return Ok(sentences[0].clone());
602 }
603
604 let mut best_sentence = &sentences[0];
606 let mut best_score = 0.0;
607
608 for sentence in &sentences {
609 let words: Vec<&str> = sentence.split_whitespace().collect();
610
611 let length_score = if words.len() < 5 {
613 0.1
614 } else if words.len() > 30 {
615 0.3
616 } else {
617 1.0
618 };
619
620 let word_score = words.len() as f32 * 0.1;
621 let score = length_score + word_score;
622
623 if score > best_score {
624 best_score = score;
625 best_sentence = sentence;
626 }
627 }
628
629 if best_sentence.len() > self.config.max_summary_length {
631 Ok(best_sentence
632 .chars()
633 .take(self.config.max_summary_length - 3)
634 .collect::<String>()
635 + "...")
636 } else {
637 Ok(best_sentence.clone())
638 }
639 }
640
641 async fn build_bottom_up(&mut self, leaf_nodes: Vec<NodeId>) -> Result<()> {
643 let mut current_level_nodes = leaf_nodes;
644 let mut current_level = 0;
645
646 while current_level_nodes.len() > 1 {
647 let next_level_nodes = self.merge_level(¤t_level_nodes, current_level + 1).await?;
648
649 current_level_nodes = next_level_nodes;
650 current_level += 1;
651 }
652
653 self.root_nodes = current_level_nodes;
655
656 Ok(())
657 }
658
659 async fn merge_level(&mut self, level_nodes: &[NodeId], new_level: usize) -> Result<Vec<NodeId>> {
661 let mut new_level_nodes = Vec::new();
662 for (node_counter, chunk) in level_nodes.chunks(self.config.merge_size).enumerate() {
664 let merged_node_id = NodeId::new(format!("level_{new_level}_{node_counter}"));
665 let merged_node = self.merge_nodes(chunk, merged_node_id.clone(), new_level).await?;
666
667 for child_id in chunk {
669 if let Some(child_node) = self.nodes.get_mut(child_id) {
670 child_node.parent = Some(merged_node_id.clone());
671 }
672 }
673
674 self.nodes.insert(merged_node_id.clone(), merged_node);
675 new_level_nodes.push(merged_node_id);
676 }
677
678 self.levels.insert(new_level, new_level_nodes.clone());
680
681 Ok(new_level_nodes)
682 }
683
    /// Combine `node_ids` into a single parent node at `level`.
    ///
    /// Concatenates child contents (blank-line separated), unions chunk ids
    /// and keywords (deduped, capped at 10), spans the children's offset
    /// range, and summarizes with the LLM when enabled — falling back to the
    /// extractive summarizer on LLM failure or when no client is attached.
    async fn merge_nodes(
        &self,
        node_ids: &[NodeId],
        merged_id: NodeId,
        level: usize,
    ) -> Result<TreeNode> {
        let mut combined_content = String::new();
        let mut all_chunk_ids = Vec::new();
        let mut all_keywords = Vec::new();
        let mut min_offset = usize::MAX;
        let mut max_offset = 0;

        for node_id in node_ids {
            // Unknown ids are silently skipped.
            if let Some(node) = self.nodes.get(node_id) {
                if !combined_content.is_empty() {
                    combined_content.push_str("\n\n");
                }
                combined_content.push_str(&node.content);
                all_chunk_ids.extend(node.chunk_ids.clone());
                all_keywords.extend(node.keywords.clone());
                min_offset = min_offset.min(node.start_offset);
                max_offset = max_offset.max(node.end_offset);
            }
        }

        // dedup requires sorted input; keep at most 10 keywords.
        all_keywords.sort();
        all_keywords.dedup();
        all_keywords.truncate(10);

        let summary = if self.config.llm_config.enabled {
            if self.llm_client.is_some() {
                let context = format!(
                    "Merging {} nodes at level {}. This represents a higher-level abstraction of the document content.",
                    node_ids.len(),
                    level
                );

                match self.generate_llm_summary(&combined_content, level, &context).await {
                    Ok(llm_summary) => {
                        println!("✅ Generated LLM-based summary for level {} ({} chars)", level, llm_summary.len());
                        llm_summary
                    }
                    Err(e) => {
                        // LLM failure is non-fatal: degrade to extractive.
                        eprintln!("⚠️ LLM summarization failed for level {}: {}, falling back to extractive", level, e);
                        self.generate_extractive_summary(&combined_content)?
                    }
                }
            } else {
                self.generate_extractive_summary(&combined_content)?
            }
        } else {
            self.generate_extractive_summary(&combined_content)?
        };

        Ok(TreeNode {
            id: merged_id,
            content: combined_content,
            summary,
            level,
            children: node_ids.to_vec(),
            // The parent link is set later by merge_level for the next level up.
            parent: None,
            chunk_ids: all_chunk_ids,
            keywords: all_keywords,
            start_offset: min_offset,
            end_offset: max_offset,
        })
    }
757
758 fn generate_extractive_summary(&self, text: &str) -> Result<String> {
760 let sentences = self.text_processor.extract_sentences(text);
761
762 if sentences.is_empty() {
763 return Ok(String::new());
764 }
765
766 if sentences.len() == 1 {
767 return Ok(sentences[0].clone());
768 }
769
770 let mut sentence_scores: Vec<(usize, f32)> = sentences
772 .iter()
773 .enumerate()
774 .map(|(i, sentence)| {
775 let score = self.score_sentence(sentence, &sentences);
776 (i, score)
777 })
778 .collect();
779
780 sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
782
783 let mut summary = String::new();
785 let mut selected_indices = Vec::new();
786
787 for (sentence_idx, _score) in sentence_scores {
788 let sentence = &sentences[sentence_idx];
789 if summary.len() + sentence.len() <= self.config.max_summary_length {
790 selected_indices.push(sentence_idx);
791 if !summary.is_empty() {
792 summary.push(' ');
793 }
794 summary.push_str(sentence);
795 }
796 }
797
798 if summary.is_empty() && !sentences.is_empty() {
800 let first_sentence = &sentences[0];
801 if first_sentence.len() <= self.config.max_summary_length {
802 summary = first_sentence.clone();
803 } else {
804 summary = first_sentence
805 .chars()
806 .take(self.config.max_summary_length - 3)
807 .collect::<String>()
808 + "...";
809 }
810 }
811
812 Ok(summary)
813 }
814
815 fn score_sentence(&self, sentence: &str, all_sentences: &[String]) -> f32 {
817 let words: Vec<&str> = sentence.split_whitespace().collect();
818
819 let length_score = if words.len() < 5 {
821 0.1
822 } else if words.len() > 30 {
823 0.3
824 } else {
825 1.0
826 };
827
828 let position_score = 0.5; let mut word_freq_score = 0.0;
833 let total_words: Vec<&str> = all_sentences
834 .iter()
835 .flat_map(|s| s.split_whitespace())
836 .collect();
837
838 for word in &words {
839 let word_lower = word.to_lowercase();
840 if word_lower.len() > 3 && !self.is_stop_word(&word_lower) {
841 let freq = total_words
842 .iter()
843 .filter(|&&w| w.to_lowercase() == word_lower)
844 .count();
845 if freq > 1 {
846 word_freq_score += freq as f32 / total_words.len() as f32;
847 }
848 }
849 }
850
851 length_score * 0.4 + position_score * 0.2 + word_freq_score * 0.4
852 }
853
854 fn is_stop_word(&self, word: &str) -> bool {
856 const STOP_WORDS: &[&str] = &[
857 "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
858 "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
859 "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
860 "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
861 "go", "me",
862 ];
863 STOP_WORDS.contains(&word)
864 }
865
866 pub fn query(&self, query: &str, max_results: usize) -> Result<Vec<QueryResult>> {
868 let query_keywords = self.text_processor.extract_keywords(query, 5);
869 let mut results = Vec::new();
870
871 for (node_id, node) in &self.nodes {
873 let score = self.calculate_relevance_score(node, &query_keywords, query);
874
875 if score > 0.1 {
876 results.push(QueryResult {
877 node_id: node_id.clone(),
878 score,
879 level: node.level,
880 summary: node.summary.clone(),
881 keywords: node.keywords.clone(),
882 chunk_ids: node.chunk_ids.clone(),
883 });
884 }
885 }
886
887 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
889 results.truncate(max_results);
890
891 Ok(results)
892 }
893
894 fn calculate_relevance_score(
896 &self,
897 node: &TreeNode,
898 query_keywords: &[String],
899 query: &str,
900 ) -> f32 {
901 let mut score = 0.0;
902
903 let node_text = format!("{} {}", node.summary, node.keywords.join(" ")).to_lowercase();
905 for keyword in query_keywords {
906 if node_text.contains(&keyword.to_lowercase()) {
907 score += 1.0;
908 }
909 }
910
911 let query_words: Vec<&str> = query.split_whitespace().collect();
913 let node_words: Vec<&str> = node_text.split_whitespace().collect();
914
915 let mut overlap_count = 0;
916 for query_word in &query_words {
917 if node_words.contains(&query_word.to_lowercase().as_str()) {
918 overlap_count += 1;
919 }
920 }
921
922 if !query_words.is_empty() {
923 score += (overlap_count as f32 / query_words.len() as f32) * 2.0;
924 }
925
926 let level_score = 1.0 / (node.level + 1) as f32;
928 score += level_score * 0.5;
929
930 score
931 }
932
    /// Walk the parent chain starting at `node_id`.
    ///
    /// NOTE(review): despite the name, the returned list INCLUDES the
    /// starting node itself as its first element — confirm callers rely on
    /// this before changing it.
    pub fn get_ancestors(&self, node_id: &NodeId) -> Vec<&TreeNode> {
        let mut ancestors = Vec::new();
        let mut current_id = Some(node_id.clone());

        while let Some(id) = current_id {
            if let Some(node) = self.nodes.get(&id) {
                ancestors.push(node);
                current_id = node.parent.clone();
            } else {
                // Dangling parent reference: stop the walk.
                break;
            }
        }

        ancestors
    }
949
950 pub fn get_descendants(&self, node_id: &NodeId) -> Vec<&TreeNode> {
952 let mut descendants = Vec::new();
953 let mut queue = VecDeque::new();
954
955 if let Some(node) = self.nodes.get(node_id) {
956 queue.extend(node.children.iter());
957 }
958
959 while let Some(child_id) = queue.pop_front() {
960 if let Some(child_node) = self.nodes.get(child_id) {
961 descendants.push(child_node);
962 queue.extend(child_node.children.iter());
963 }
964 }
965
966 descendants
967 }
968
    /// Look up a node by id.
    pub fn get_node(&self, node_id: &NodeId) -> Option<&TreeNode> {
        self.nodes.get(node_id)
    }
973
974 pub fn get_level_nodes(&self, level: usize) -> Vec<&TreeNode> {
976 if let Some(node_ids) = self.levels.get(&level) {
977 node_ids
978 .iter()
979 .filter_map(|id| self.nodes.get(id))
980 .collect()
981 } else {
982 Vec::new()
983 }
984 }
985
    /// The top-level (root) nodes of the tree, skipping any dangling ids.
    pub fn get_root_nodes(&self) -> Vec<&TreeNode> {
        self.root_nodes
            .iter()
            .filter_map(|id| self.nodes.get(id))
            .collect()
    }
993
    /// Id of the document this tree was built from.
    pub fn document_id(&self) -> &DocumentId {
        &self.document_id
    }
998
999 pub fn get_statistics(&self) -> TreeStatistics {
1001 let max_level = self.levels.keys().max().copied().unwrap_or(0);
1002 let total_nodes = self.nodes.len();
1003 let nodes_per_level: HashMap<usize, usize> = self
1004 .levels
1005 .iter()
1006 .map(|(level, nodes)| (*level, nodes.len()))
1007 .collect();
1008
1009 TreeStatistics {
1010 total_nodes,
1011 max_level,
1012 nodes_per_level,
1013 root_count: self.root_nodes.len(),
1014 document_id: self.document_id.clone(),
1015 }
1016 }
1017
    /// Serialize the tree (config, nodes, roots, levels) to a JSON string
    /// via the `json` crate. The LLM config is intentionally not persisted
    /// (see `from_json`, which restores it to the default).
    pub fn to_json(&self) -> Result<String> {
        use json::JsonValue;

        let mut tree_json = json::object! {
            "document_id": self.document_id.to_string(),
            "config": {
                "merge_size": self.config.merge_size,
                "max_summary_length": self.config.max_summary_length,
                "min_node_size": self.config.min_node_size,
                "overlap_sentences": self.config.overlap_sentences
            },
            "nodes": {},
            "root_nodes": [],
            "levels": {}
        };

        // One object entry per node, keyed by its string id.
        for (node_id, node) in &self.nodes {
            let node_json = json::object! {
                "id": node_id.to_string(),
                "content": node.content.clone(),
                "summary": node.summary.clone(),
                "level": node.level,
                "children": node.children.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "parent": node.parent.as_ref().map(|id| id.to_string()),
                "chunk_ids": node.chunk_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "keywords": node.keywords.clone(),
                "start_offset": node.start_offset,
                "end_offset": node.end_offset
            };
            tree_json["nodes"][node_id.to_string()] = node_json;
        }

        tree_json["root_nodes"] = self
            .root_nodes
            .iter()
            .map(|id| JsonValue::String(id.to_string()))
            .collect::<Vec<_>>()
            .into();

        // Levels are keyed by their stringified level number.
        for (level, node_ids) in &self.levels {
            tree_json["levels"][level.to_string()] = node_ids
                .iter()
                .map(|id| JsonValue::String(id.to_string()))
                .collect::<Vec<_>>()
                .into();
        }

        Ok(tree_json.dump())
    }
1071
    /// Deserialize a tree previously produced by `to_json`.
    ///
    /// Missing or mistyped fields fall back to defaults; the LLM config is
    /// always reset to `LLMConfig::default()` because it is not serialized.
    ///
    /// # Errors
    /// Fails on invalid JSON or a missing/non-string `document_id`.
    pub fn from_json(json_str: &str) -> Result<Self> {
        let json_data = json::parse(json_str).map_err(crate::GraphRAGError::Json)?;

        let document_id = DocumentId::new(
            json_data["document_id"]
                .as_str()
                .ok_or_else(|| {
                    crate::GraphRAGError::Json(json::Error::WrongType(
                        "document_id must be string".to_string(),
                    ))
                })?
                .to_string(),
        );

        // Config values default to the same constants as HierarchicalConfig::default().
        let config_json = &json_data["config"];
        let config = HierarchicalConfig {
            merge_size: config_json["merge_size"].as_usize().unwrap_or(5),
            max_summary_length: config_json["max_summary_length"].as_usize().unwrap_or(200),
            min_node_size: config_json["min_node_size"].as_usize().unwrap_or(50),
            overlap_sentences: config_json["overlap_sentences"].as_usize().unwrap_or(2),
            llm_config: LLMConfig::default(),
        };

        let mut tree = Self::new(document_id, config)?;

        // Rebuild every node; non-string array members are skipped silently.
        if let json::JsonValue::Object(nodes_obj) = &json_data["nodes"] {
            for (node_id_str, node_json) in nodes_obj.iter() {
                let node_id = NodeId::new(node_id_str.to_string());

                let children: Vec<NodeId> = node_json["children"]
                    .members()
                    .filter_map(|v| v.as_str())
                    .map(|s| NodeId::new(s.to_string()))
                    .collect();

                let parent = node_json["parent"]
                    .as_str()
                    .map(|s| NodeId::new(s.to_string()));

                let chunk_ids: Vec<ChunkId> = node_json["chunk_ids"]
                    .members()
                    .filter_map(|v| v.as_str())
                    .map(|s| ChunkId::new(s.to_string()))
                    .collect();

                let keywords: Vec<String> = node_json["keywords"]
                    .members()
                    .filter_map(|v| v.as_str())
                    .map(|s| s.to_string())
                    .collect();

                let node = TreeNode {
                    id: node_id.clone(),
                    content: node_json["content"].as_str().unwrap_or("").to_string(),
                    summary: node_json["summary"].as_str().unwrap_or("").to_string(),
                    level: node_json["level"].as_usize().unwrap_or(0),
                    children,
                    parent,
                    chunk_ids,
                    keywords,
                    start_offset: node_json["start_offset"].as_usize().unwrap_or(0),
                    end_offset: node_json["end_offset"].as_usize().unwrap_or(0),
                };

                tree.nodes.insert(node_id, node);
            }
        }

        tree.root_nodes = json_data["root_nodes"]
            .members()
            .filter_map(|v| v.as_str())
            .map(|s| NodeId::new(s.to_string()))
            .collect();

        // Level keys are stringified numbers; unparsable keys are ignored.
        if let json::JsonValue::Object(levels_obj) = &json_data["levels"] {
            for (level_str, level_json) in levels_obj.iter() {
                if let Ok(level) = level_str.parse::<usize>() {
                    let node_ids: Vec<NodeId> = level_json
                        .members()
                        .filter_map(|v| v.as_str())
                        .map(|s| NodeId::new(s.to_string()))
                        .collect();
                    tree.levels.insert(level, node_ids);
                }
            }
        }

        Ok(tree)
    }
1165}
1166
/// A single hit returned by `DocumentTree::query`.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Id of the matching node.
    pub node_id: NodeId,
    /// Relevance score (higher is better).
    pub score: f32,
    /// Tree level of the matching node.
    pub level: usize,
    /// The node's summary text.
    pub summary: String,
    /// The node's keywords.
    pub keywords: Vec<String>,
    /// Source chunks covered by the node.
    pub chunk_ids: Vec<ChunkId>,
}
1183
/// Aggregate shape information about a `DocumentTree`.
#[derive(Debug)]
pub struct TreeStatistics {
    /// Total node count across all levels.
    pub total_nodes: usize,
    /// Deepest level present (0 when the tree is empty).
    pub max_level: usize,
    /// Node count per level.
    pub nodes_per_level: HashMap<usize, usize>,
    /// Number of root nodes.
    pub root_count: usize,
    /// Document the tree belongs to.
    pub document_id: DocumentId,
}
1198
1199impl TreeStatistics {
1200 pub fn print(&self) {
1202 println!("Hierarchical Tree Statistics:");
1203 println!(" Document ID: {}", self.document_id);
1204 println!(" Total nodes: {}", self.total_nodes);
1205 println!(" Max level: {}", self.max_level);
1206 println!(" Root nodes: {}", self.root_count);
1207 println!(" Nodes per level:");
1208
1209 let mut levels: Vec<_> = self.nodes_per_level.iter().collect();
1210 levels.sort_by_key(|(level, _)| *level);
1211
1212 for (level, count) in levels {
1213 println!(" Level {level}: {count} nodes");
1214 }
1215 }
1216}
1217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// An empty tree can be constructed with the default config.
    #[test]
    fn test_tree_creation() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config);
        assert!(tree.is_ok());
    }

    /// The extractive summarizer yields a non-empty summary within the
    /// configured length budget.
    #[test]
    fn test_extractive_summarization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let text = "This is the first sentence. This is a second sentence with more details. This is the final sentence.";
        let summary = tree.generate_extractive_summary(text).unwrap();

        assert!(!summary.is_empty());
        assert!(summary.len() <= tree.config.max_summary_length);
    }

    /// A tree survives a JSON round-trip with its document id intact.
    #[test]
    fn test_json_serialization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let json = tree.to_json().unwrap();
        assert!(json.contains("test_doc"));

        let loaded_tree = DocumentTree::from_json(&json).unwrap();
        assert_eq!(loaded_tree.document_id.to_string(), "test_doc");
    }
}