Skip to main content

graphrag_core/summarization/
mod.rs

1#[cfg(feature = "parallel-processing")]
2use crate::parallel::ParallelProcessor;
3use crate::{
4    core::{ChunkId, DocumentId, GraphRAGError, TextChunk},
5    text::TextProcessor,
6    Result,
7};
8use indexmap::IndexMap;
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11
12/// Trait for LLM client to be used in summarization
13#[async_trait::async_trait]
14pub trait LLMClient: Send + Sync {
15    /// Generate a summary for the given text
16    async fn generate_summary(
17        &self,
18        text: &str,
19        prompt: &str,
20        max_tokens: usize,
21        temperature: f32,
22    ) -> Result<String>;
23
24    /// Generate summary in batch for multiple texts
25    async fn generate_summary_batch(
26        &self,
27        texts: &[(&str, &str)],
28        max_tokens: usize,
29        temperature: f32,
30    ) -> Result<Vec<String>> {
31        let mut results = Vec::new();
32        for (text, prompt) in texts {
33            let summary = self
34                .generate_summary(text, prompt, max_tokens, temperature)
35                .await?;
36            results.push(summary);
37        }
38        Ok(results)
39    }
40
41    /// Get model name
42    fn model_name(&self) -> &str;
43}
44
/// Unique identifier for a node in a [`DocumentTree`].
///
/// A thin wrapper around the textual id (e.g. `leaf_<chunk>` or
/// `level_<n>_<i>`); it displays as the bare inner string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeId(pub String);

impl NodeId {
    /// Wrap `id` as a node identifier.
    pub fn new(id: String) -> Self {
        NodeId(id)
    }
}

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Write the raw id; formatter width/fill options are ignored, as before.
        f.write_str(&self.0)
    }
}

impl From<String> for NodeId {
    fn from(s: String) -> Self {
        NodeId::new(s)
    }
}

impl From<NodeId> for String {
    fn from(id: NodeId) -> Self {
        let NodeId(inner) = id;
        inner
    }
}
73
74/// Configuration for hierarchical summarization
75#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
76pub struct HierarchicalConfig {
77    /// Number of nodes to merge when building tree levels
78    pub merge_size: usize,
79    /// Maximum character length for generated summaries
80    pub max_summary_length: usize,
81    /// Minimum size in characters for nodes to be considered valid
82    pub min_node_size: usize,
83    /// Number of overlapping sentences between adjacent chunks
84    pub overlap_sentences: usize,
85    /// LLM-based summarization configuration
86    pub llm_config: LLMConfig,
87}
88
89/// Configuration for LLM-based summarization
90#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
91pub struct LLMConfig {
92    /// Whether to use LLM for summarization (vs extractive)
93    pub enabled: bool,
94    /// Model to use for summarization
95    pub model_name: String,
96    /// Temperature for generation (lower = more deterministic)
97    pub temperature: f32,
98    /// Maximum tokens for LLM generation
99    pub max_tokens: usize,
100    /// Strategy for summarization
101    pub strategy: LLMStrategy,
102    /// Level-specific configurations
103    pub level_configs: HashMap<usize, LevelConfig>,
104}
105
/// Summarization strategy for different levels.
///
/// Selects the default prompt wording (and the default `use_abstractive`
/// flag) for a level that has no explicit `prompt_template`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LLMStrategy {
    /// Use same approach for all levels: one concise-summary prompt everywhere.
    Uniform,
    /// Different approaches for different levels: key-point extraction at
    /// level 0, a combining summary at levels 1-2, and an abstract thematic
    /// summary at level 3 and above.
    Adaptive,
    /// Progressive: extractive for lower levels, abstractive for higher.
    /// The prompt follows the level's `use_abstractive` flag, which defaults
    /// to `true` from level 2 upward when no level override is configured.
    Progressive,
}
116
117/// Configuration specific to tree levels
118#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
119pub struct LevelConfig {
120    /// Maximum summary length for this level
121    pub max_length: usize,
122    /// Whether to use abstractive summarization at this level
123    pub use_abstractive: bool,
124    /// Custom prompt template for this level
125    pub prompt_template: Option<String>,
126    /// Temperature override for this level
127    pub temperature: Option<f32>,
128}
129
130impl Default for HierarchicalConfig {
131    fn default() -> Self {
132        Self {
133            merge_size: 5,
134            max_summary_length: 200,
135            min_node_size: 50,
136            overlap_sentences: 2,
137            llm_config: LLMConfig::default(),
138        }
139    }
140}
141
142impl Default for LLMConfig {
143    fn default() -> Self {
144        Self {
145            enabled: false, // Disabled by default for backward compatibility
146            model_name: "llama3.1:8b".to_string(),
147            temperature: 0.3, // Lower temperature for more coherent summaries
148            max_tokens: 150,
149            strategy: LLMStrategy::Progressive,
150            level_configs: HashMap::new(),
151        }
152    }
153}
154
155impl Default for LevelConfig {
156    fn default() -> Self {
157        Self {
158            max_length: 200,
159            use_abstractive: false, // Start with extractive at lower levels
160            prompt_template: None,
161            temperature: None,
162        }
163    }
164}
165
/// A node in the hierarchical document tree.
///
/// Leaves (level 0) are built from single text chunks; higher-level nodes are
/// built by merging a run of sibling nodes from the level below.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Unique identifier for this node (`leaf_<chunk id>` for leaves,
    /// `level_<level>_<index>` for merged nodes)
    pub id: NodeId,
    /// Full text content of this node; for merged nodes, the children's
    /// content joined with blank lines
    pub content: String,
    /// Generated summary of the content
    pub summary: String,
    /// Level in the tree hierarchy (0 = leaf, higher = more abstract)
    pub level: usize,
    /// IDs of child nodes in the tree, in merge order (empty for leaves)
    pub children: Vec<NodeId>,
    /// ID of the parent node, if any (`None` for roots)
    pub parent: Option<NodeId>,
    /// IDs of text chunks represented by this node
    pub chunk_ids: Vec<ChunkId>,
    /// Extracted keywords from the content
    pub keywords: Vec<String>,
    /// Starting offset in the original document, copied from the source
    /// chunk (minimum over children for merged nodes)
    pub start_offset: usize,
    /// Ending offset in the original document, copied from the source
    /// chunk (maximum over children for merged nodes)
    pub end_offset: usize,
}
190
/// Hierarchical document tree for multi-level summarization.
///
/// Built bottom-up from text chunks: leaves at level 0, then repeated merges
/// of `config.merge_size` siblings until at most one node remains.
pub struct DocumentTree {
    // All nodes by id; IndexMap keeps insertion order (leaves first, then
    // each merged level in turn).
    nodes: IndexMap<NodeId, TreeNode>,
    // Node(s) left after bottom-up merging.
    root_nodes: Vec<NodeId>,
    // Node ids per tree level (0 = leaves).
    levels: HashMap<usize, Vec<NodeId>>,
    // Document this tree summarizes.
    document_id: DocumentId,
    config: HierarchicalConfig,
    // Sentence/keyword helper used by the sequential code paths.
    text_processor: TextProcessor,
    // Optional backend for abstractive summaries; `None` means extractive only.
    llm_client: Option<Arc<dyn LLMClient>>,
}
201
202impl DocumentTree {
203    /// Create a new document tree
204    pub fn new(document_id: DocumentId, config: HierarchicalConfig) -> Result<Self> {
205        let text_processor = TextProcessor::new(1000, 100)?;
206
207        Ok(Self {
208            nodes: IndexMap::new(),
209            root_nodes: Vec::new(),
210            levels: HashMap::new(),
211            document_id,
212            config,
213            text_processor,
214            llm_client: None,
215        })
216    }
217
218    /// Create a new document tree with LLM client
219    pub fn with_llm_client(
220        document_id: DocumentId,
221        config: HierarchicalConfig,
222        llm_client: Arc<dyn LLMClient>,
223    ) -> Result<Self> {
224        let text_processor = TextProcessor::new(1000, 100)?;
225
226        Ok(Self {
227            nodes: IndexMap::new(),
228            root_nodes: Vec::new(),
229            levels: HashMap::new(),
230            document_id,
231            config,
232            text_processor,
233            llm_client: Some(llm_client),
234        })
235    }
236
237    /// Set LLM client for the tree
238    pub fn set_llm_client(&mut self, llm_client: Option<Arc<dyn LLMClient>>) {
239        self.llm_client = llm_client;
240    }
241
/// Create an empty document tree when the `parallel-processing` feature is on.
///
/// `_parallel_processor` is currently unused, so the result is equivalent to
/// [`DocumentTree::new`]. Leaf creation is parallelised by the feature flag
/// itself, not by this argument.
///
/// # Errors
/// Propagates any error from constructing the internal [`TextProcessor`].
#[cfg(feature = "parallel-processing")]
pub fn with_parallel_processing(
    document_id: DocumentId,
    config: HierarchicalConfig,
    _parallel_processor: ParallelProcessor,
) -> Result<Self> {
    Self::new(document_id, config)
}
261
/// Create an empty document tree with an LLM client when the
/// `parallel-processing` feature is on.
///
/// `_parallel_processor` is currently unused, so the result is equivalent to
/// [`DocumentTree::with_llm_client`].
///
/// # Errors
/// Propagates any error from constructing the internal [`TextProcessor`].
#[cfg(feature = "parallel-processing")]
pub fn with_parallel_and_llm(
    document_id: DocumentId,
    config: HierarchicalConfig,
    _parallel_processor: ParallelProcessor,
    llm_client: Arc<dyn LLMClient>,
) -> Result<Self> {
    Self::with_llm_client(document_id, config, llm_client)
}
282
283    /// Build the hierarchical tree from text chunks
284    pub async fn build_from_chunks(&mut self, chunks: Vec<TextChunk>) -> Result<()> {
285        if chunks.is_empty() {
286            return Ok(());
287        }
288
289        println!("Building hierarchical tree from {} chunks", chunks.len());
290
291        // Create leaf nodes from chunks
292        let leaf_nodes = self.create_leaf_nodes(chunks)?;
293
294        // Build the tree bottom-up
295        self.build_bottom_up(leaf_nodes).await?;
296
297        println!(
298            "Tree built with {} total nodes across {} levels",
299            self.nodes.len(),
300            self.levels.len()
301        );
302
303        Ok(())
304    }
305
/// Create one level-0 leaf node per chunk and register them as level 0.
///
/// Inputs of fewer than 10 chunks always take the sequential path. With the
/// `parallel-processing` feature enabled, larger inputs are processed with
/// rayon; if any chunk fails there, the error is printed to stderr and the
/// whole input is redone sequentially. Without the feature, the sequential
/// path is always used.
///
/// NOTE(review): the two paths summarize differently. The parallel path keeps
/// the single best-scoring sentence (`generate_parallel_summary`), while the
/// sequential path combines several ranked sentences
/// (`generate_extractive_summary`). Leaf summaries therefore depend on the
/// chunk count and on feature flags; confirm this is intended.
///
/// # Errors
/// Returns errors from the sequential path; parallel-path errors are logged
/// and trigger the sequential fallback instead of being returned.
fn create_leaf_nodes(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
    if chunks.len() < 10 {
        // Use sequential processing for small numbers of chunks (hard-coded
        // threshold; presumably below it the parallel setup does not pay off)
        return self.create_leaf_nodes_sequential(chunks);
    }

    #[cfg(feature = "parallel-processing")]
    {
        use rayon::prelude::*;

        // Parallel node creation with proper error handling: any per-chunk
        // error makes the collected Result an Err
        let node_results: std::result::Result<Vec<_>, crate::GraphRAGError> = chunks
            .par_iter()
            .map(|chunk| {
                let node_id = NodeId::new(format!("leaf_{}", chunk.id));

                // A fresh TextProcessor is built for every chunk (inside this
                // per-chunk closure) so `self.text_processor` is not shared
                // across threads
                let temp_processor =
                    crate::text::TextProcessor::new(1000, 100).map_err(|e| {
                        crate::GraphRAGError::Config {
                            message: format!("Failed to create text processor: {e}"),
                        }
                    })?;

                // Extract keywords for the chunk
                let keywords = temp_processor.extract_keywords(&chunk.content, 5);

                // Generate summary using simplified approach suitable for parallel execution
                let summary =
                    self.generate_parallel_summary(&chunk.content, &temp_processor)?;

                let node = TreeNode {
                    id: node_id.clone(),
                    content: chunk.content.clone(),
                    summary,
                    level: 0,
                    children: Vec::new(),
                    parent: None,
                    chunk_ids: vec![chunk.id.clone()],
                    keywords,
                    start_offset: chunk.start_offset,
                    end_offset: chunk.end_offset,
                };

                Ok((node_id, node))
            })
            .collect();

        match node_results {
            Ok(nodes) => {
                let mut leaf_node_ids = Vec::new();

                // Insert nodes sequentially to avoid concurrent modification
                for (node_id, node) in nodes {
                    leaf_node_ids.push(node_id.clone());
                    self.nodes.insert(node_id, node);
                }

                println!("Created {} leaf nodes in parallel", leaf_node_ids.len());

                // Update levels tracking
                self.levels.insert(0, leaf_node_ids.clone());

                Ok(leaf_node_ids)
            },
            Err(e) => {
                eprintln!("Error in parallel node creation: {e}");
                // Fall back to sequential processing; nothing was inserted
                // into `self.nodes` on this path, so the retry starts clean
                self.create_leaf_nodes_sequential(chunks)
            },
        }
    }

    #[cfg(not(feature = "parallel-processing"))]
    {
        self.create_leaf_nodes_sequential(chunks)
    }
}
385
386    /// Sequential leaf node creation (fallback)
387    fn create_leaf_nodes_sequential(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
388        let mut leaf_node_ids = Vec::new();
389
390        for chunk in chunks {
391            let node_id = NodeId::new(format!("leaf_{}", chunk.id));
392
393            // Extract keywords for the chunk
394            let keywords = self.text_processor.extract_keywords(&chunk.content, 5);
395
396            let node = TreeNode {
397                id: node_id.clone(),
398                content: chunk.content.clone(),
399                summary: self.generate_extractive_summary(&chunk.content)?,
400                level: 0,
401                children: Vec::new(),
402                parent: None,
403                chunk_ids: vec![chunk.id],
404                keywords,
405                start_offset: chunk.start_offset,
406                end_offset: chunk.end_offset,
407            };
408
409            self.nodes.insert(node_id.clone(), node);
410            leaf_node_ids.push(node_id);
411        }
412
413        // Update levels tracking
414        self.levels.insert(0, leaf_node_ids.clone());
415
416        Ok(leaf_node_ids)
417    }
418
419    /// Generate summary suitable for parallel processing
420    /// Generate summary using LLM for the given text and level
421    pub async fn generate_llm_summary(
422        &self,
423        text: &str,
424        level: usize,
425        context: &str,
426    ) -> Result<String> {
427        let llm_client = self
428            .llm_client
429            .as_ref()
430            .ok_or_else(|| GraphRAGError::Config {
431                message: "LLM client not configured for summarization".to_string(),
432            })?;
433
434        // Get configuration for this level
435        let level_config = self.get_level_config(level);
436
437        // Create prompt based on strategy and level
438        let prompt = self.create_summary_prompt(text, level, context, &level_config)?;
439
440        // Generate summary
441        let summary = llm_client
442            .generate_summary(
443                text,
444                &prompt,
445                level_config.max_length,
446                level_config
447                    .temperature
448                    .unwrap_or(self.config.llm_config.temperature),
449            )
450            .await?;
451
452        // Ensure summary is within length limits
453        self.truncate_summary(&summary, level_config.max_length)
454    }
455
456    /// Generate summaries in batch for multiple texts
457    pub async fn generate_llm_summaries_batch(
458        &self,
459        texts: &[(&str, usize, &str)], // (text, level, context)
460    ) -> Result<Vec<String>> {
461        let llm_client = self
462            .llm_client
463            .as_ref()
464            .ok_or_else(|| GraphRAGError::Config {
465                message: "LLM client not configured for summarization".to_string(),
466            })?;
467
468        let mut prompts = Vec::new();
469        let mut configs = Vec::new();
470
471        for (text, level, context) in texts {
472            let level_config = self.get_level_config(*level);
473            let prompt = self.create_summary_prompt(text, *level, context, &level_config)?;
474            prompts.push(prompt);
475            configs.push(level_config);
476        }
477
478        // Generate summaries in batch
479        let text_refs: Vec<&str> = texts.iter().map(|(t, _, _)| *t).collect();
480        let prompt_refs: Vec<&str> = prompts.iter().map(|p| p.as_str()).collect();
481
482        let summaries = llm_client
483            .generate_summary_batch(
484                &text_refs
485                    .iter()
486                    .zip(prompt_refs.iter())
487                    .map(|(&t, &p)| (t, p))
488                    .collect::<Vec<_>>(),
489                self.config.llm_config.max_tokens,
490                self.config.llm_config.temperature,
491            )
492            .await?;
493
494        // Truncate summaries according to level configs
495        let mut results = Vec::new();
496        for (i, summary) in summaries.into_iter().enumerate() {
497            let truncated = self.truncate_summary(&summary, configs[i].max_length)?;
498            results.push(truncated);
499        }
500
501        Ok(results)
502    }
503
504    /// Create a summary prompt based on level and strategy
505    fn create_summary_prompt(
506        &self,
507        text: &str,
508        level: usize,
509        context: &str,
510        level_config: &LevelConfig,
511    ) -> Result<String> {
512        // Use custom template if provided
513        if let Some(template) = &level_config.prompt_template {
514            return Ok(template
515                .replace("{text}", text)
516                .replace("{context}", context)
517                .replace("{level}", &level.to_string())
518                .replace("{max_length}", &level_config.max_length.to_string()));
519        }
520
521        // Default prompts based on strategy and level
522        match self.config.llm_config.strategy {
523            LLMStrategy::Uniform => {
524                Ok(format!(
525                    "Create a concise summary of the following text. The summary should be approximately {} characters long.\n\nContext: {}\n\nText to summarize:\n{}\n\nSummary:",
526                    level_config.max_length, context, text
527                ))
528            }
529            LLMStrategy::Adaptive => {
530                if level == 0 {
531                    Ok(format!(
532                        "Extract the key information from this text segment. Keep it factual and under {} characters.\n\nContext: {}\n\nText:\n{}\n\nKey points:",
533                        level_config.max_length, context, text
534                    ))
535                } else if level <= 2 {
536                    Ok(format!(
537                        "Create a coherent summary that combines the key information from this text. Make it approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nSummary:",
538                        level_config.max_length, context, text
539                    ))
540                } else {
541                    Ok(format!(
542                        "Generate a high-level abstract summary of this content. Focus on the main themes and insights. Limit to approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstract summary:",
543                        level_config.max_length, context, text
544                    ))
545                }
546            }
547            LLMStrategy::Progressive => {
548                if level_config.use_abstractive {
549                    Ok(format!(
550                        "Generate an abstractive summary that synthesizes the key concepts and relationships in this text. The summary should be approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstractive summary:",
551                        level_config.max_length, context, text
552                    ))
553                } else {
554                    Ok(format!(
555                        "Extract and organize the most important sentences from this text to create a coherent summary. Keep it under {} characters.\n\nContext: {}\n\nText:\n{}\n\nExtractive summary:",
556                        level_config.max_length, context, text
557                    ))
558                }
559            }
560        }
561    }
562
563    /// Get configuration for a specific level
564    fn get_level_config(&self, level: usize) -> LevelConfig {
565        self.config
566            .llm_config
567            .level_configs
568            .get(&level)
569            .cloned()
570            .unwrap_or({
571                // Default configuration based on level
572                LevelConfig {
573                    max_length: self.config.max_summary_length,
574                    use_abstractive: match self.config.llm_config.strategy {
575                        LLMStrategy::Progressive => level >= 2,
576                        LLMStrategy::Adaptive => level >= 3,
577                        LLMStrategy::Uniform => level > 0,
578                    },
579                    prompt_template: None,
580                    temperature: None,
581                }
582            })
583    }
584
585    /// Truncate summary to fit within length limits
586    fn truncate_summary(&self, summary: &str, max_length: usize) -> Result<String> {
587        if summary.len() <= max_length {
588            return Ok(summary.to_string());
589        }
590
591        // Try to truncate at sentence boundaries
592        let sentences: Vec<&str> = summary
593            .split('.')
594            .filter(|s| !s.trim().is_empty())
595            .collect();
596
597        let mut result = String::new();
598        for sentence in sentences {
599            if result.len() + sentence.len() < max_length - 3 {
600                if !result.is_empty() {
601                    result.push('.');
602                }
603                result.push_str(sentence.trim());
604            } else {
605                break;
606            }
607        }
608
609        if result.is_empty() {
610            // Fallback: truncate characters
611            result = summary.chars().take(max_length - 3).collect();
612            result.push_str("...");
613        } else {
614            result.push('.');
615        }
616
617        Ok(result)
618    }
619
620    #[allow(dead_code)]
621    fn generate_parallel_summary(
622        &self,
623        text: &str,
624        processor: &crate::text::TextProcessor,
625    ) -> Result<String> {
626        let sentences = processor.extract_sentences(text);
627
628        if sentences.is_empty() {
629            return Ok(String::new());
630        }
631
632        if sentences.len() == 1 {
633            return Ok(sentences[0].clone());
634        }
635
636        // Simplified scoring for parallel execution
637        let mut best_sentence = &sentences[0];
638        let mut best_score = 0.0;
639
640        for sentence in &sentences {
641            let words: Vec<&str> = sentence.split_whitespace().collect();
642
643            // Simple scoring based on length and word density
644            let length_score = if words.len() < 5 {
645                0.1
646            } else if words.len() > 30 {
647                0.3
648            } else {
649                1.0
650            };
651
652            let word_score = words.len() as f32 * 0.1;
653            let score = length_score + word_score;
654
655            if score > best_score {
656                best_score = score;
657                best_sentence = sentence;
658            }
659        }
660
661        // Truncate if necessary
662        if best_sentence.len() > self.config.max_summary_length {
663            Ok(best_sentence
664                .chars()
665                .take(self.config.max_summary_length - 3)
666                .collect::<String>()
667                + "...")
668        } else {
669            Ok(best_sentence.clone())
670        }
671    }
672
673    /// Build the tree bottom-up by merging nodes at each level
674    async fn build_bottom_up(&mut self, leaf_nodes: Vec<NodeId>) -> Result<()> {
675        let mut current_level_nodes = leaf_nodes;
676        let mut current_level = 0;
677
678        while current_level_nodes.len() > 1 {
679            let next_level_nodes = self
680                .merge_level(&current_level_nodes, current_level + 1)
681                .await?;
682
683            current_level_nodes = next_level_nodes;
684            current_level += 1;
685        }
686
687        // Set root nodes
688        self.root_nodes = current_level_nodes;
689
690        Ok(())
691    }
692
693    /// Merge nodes at a level to create the next level up
694    async fn merge_level(
695        &mut self,
696        level_nodes: &[NodeId],
697        new_level: usize,
698    ) -> Result<Vec<NodeId>> {
699        let mut new_level_nodes = Vec::new();
700        // Group nodes by merge_size
701        for (node_counter, chunk) in level_nodes.chunks(self.config.merge_size).enumerate() {
702            let merged_node_id = NodeId::new(format!("level_{new_level}_{node_counter}"));
703            let merged_node = self
704                .merge_nodes(chunk, merged_node_id.clone(), new_level)
705                .await?;
706
707            // Update parent references for children
708            for child_id in chunk {
709                if let Some(child_node) = self.nodes.get_mut(child_id) {
710                    child_node.parent = Some(merged_node_id.clone());
711                }
712            }
713
714            self.nodes.insert(merged_node_id.clone(), merged_node);
715            new_level_nodes.push(merged_node_id);
716        }
717
718        // Update levels tracking
719        self.levels.insert(new_level, new_level_nodes.clone());
720
721        Ok(new_level_nodes)
722    }
723
724    /// Merge multiple nodes into a single parent node
725    async fn merge_nodes(
726        &self,
727        node_ids: &[NodeId],
728        merged_id: NodeId,
729        level: usize,
730    ) -> Result<TreeNode> {
731        let mut combined_content = String::new();
732        let mut all_chunk_ids = Vec::new();
733        let mut all_keywords = Vec::new();
734        let mut min_offset = usize::MAX;
735        let mut max_offset = 0;
736
737        for node_id in node_ids {
738            if let Some(node) = self.nodes.get(node_id) {
739                if !combined_content.is_empty() {
740                    combined_content.push_str("\n\n");
741                }
742                combined_content.push_str(&node.content);
743                all_chunk_ids.extend(node.chunk_ids.clone());
744                all_keywords.extend(node.keywords.clone());
745                min_offset = min_offset.min(node.start_offset);
746                max_offset = max_offset.max(node.end_offset);
747            }
748        }
749
750        // Deduplicate and limit keywords
751        all_keywords.sort();
752        all_keywords.dedup();
753        all_keywords.truncate(10);
754
755        // Generate summary for the merged content
756        let summary = if self.config.llm_config.enabled {
757            // Use LLM-based summarization if enabled and available
758            if self.llm_client.is_some() {
759                // Create context for this merge operation
760                let context = format!(
761                    "Merging {} nodes at level {}. This represents a higher-level abstraction of the document content.",
762                    node_ids.len(),
763                    level
764                );
765
766                // Try LLM-based summarization, fall back to extractive if it fails
767                match self
768                    .generate_llm_summary(&combined_content, level, &context)
769                    .await
770                {
771                    Ok(llm_summary) => {
772                        println!(
773                            "✅ Generated LLM-based summary for level {} ({} chars)",
774                            level,
775                            llm_summary.len()
776                        );
777                        llm_summary
778                    },
779                    Err(e) => {
780                        eprintln!("⚠️ LLM summarization failed for level {}: {}, falling back to extractive", level, e);
781                        self.generate_extractive_summary(&combined_content)?
782                    },
783                }
784            } else {
785                self.generate_extractive_summary(&combined_content)?
786            }
787        } else {
788            self.generate_extractive_summary(&combined_content)?
789        };
790
791        Ok(TreeNode {
792            id: merged_id,
793            content: combined_content,
794            summary,
795            level,
796            children: node_ids.to_vec(),
797            parent: None,
798            chunk_ids: all_chunk_ids,
799            keywords: all_keywords,
800            start_offset: min_offset,
801            end_offset: max_offset,
802        })
803    }
804
805    /// Generate extractive summary using simple sentence ranking
806    fn generate_extractive_summary(&self, text: &str) -> Result<String> {
807        let sentences = self.text_processor.extract_sentences(text);
808
809        if sentences.is_empty() {
810            return Ok(String::new());
811        }
812
813        if sentences.len() == 1 {
814            return Ok(sentences[0].clone());
815        }
816
817        // Score sentences based on length and keyword density
818        let mut sentence_scores: Vec<(usize, f32)> = sentences
819            .iter()
820            .enumerate()
821            .map(|(i, sentence)| {
822                let score = self.score_sentence(sentence, &sentences);
823                (i, score)
824            })
825            .collect();
826
827        // Sort by score descending
828        sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
829
830        // Select top sentences up to max_summary_length
831        let mut summary = String::new();
832        let mut selected_indices = Vec::new();
833
834        for (sentence_idx, _score) in sentence_scores {
835            let sentence = &sentences[sentence_idx];
836            if summary.len() + sentence.len() <= self.config.max_summary_length {
837                selected_indices.push(sentence_idx);
838                if !summary.is_empty() {
839                    summary.push(' ');
840                }
841                summary.push_str(sentence);
842            }
843        }
844
845        // If no sentences fit, take the first one truncated
846        if summary.is_empty() && !sentences.is_empty() {
847            let first_sentence = &sentences[0];
848            if first_sentence.len() <= self.config.max_summary_length {
849                summary = first_sentence.clone();
850            } else {
851                summary = first_sentence
852                    .chars()
853                    .take(self.config.max_summary_length - 3)
854                    .collect::<String>()
855                    + "...";
856            }
857        }
858
859        Ok(summary)
860    }
861
862    /// Score a sentence for extractive summarization
863    fn score_sentence(&self, sentence: &str, all_sentences: &[String]) -> f32 {
864        let words: Vec<&str> = sentence.split_whitespace().collect();
865
866        // Base score from length (prefer medium-length sentences)
867        let length_score = if words.len() < 5 {
868            0.1
869        } else if words.len() > 30 {
870            0.3
871        } else {
872            1.0
873        };
874
875        // Position score (prefer sentences from beginning and end)
876        let position_score = 0.5; // Simplified for now
877
878        // Word frequency score
879        let mut word_freq_score = 0.0;
880        let total_words: Vec<&str> = all_sentences
881            .iter()
882            .flat_map(|s| s.split_whitespace())
883            .collect();
884
885        for word in &words {
886            let word_lower = word.to_lowercase();
887            if word_lower.len() > 3 && !self.is_stop_word(&word_lower) {
888                let freq = total_words
889                    .iter()
890                    .filter(|&&w| w.to_lowercase() == word_lower)
891                    .count();
892                if freq > 1 {
893                    word_freq_score += freq as f32 / total_words.len() as f32;
894                }
895            }
896        }
897
898        length_score * 0.4 + position_score * 0.2 + word_freq_score * 0.4
899    }
900
901    /// Simple stop word detection (English)
902    fn is_stop_word(&self, word: &str) -> bool {
903        const STOP_WORDS: &[&str] = &[
904            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
905            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
906            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
907            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
908            "go", "me",
909        ];
910        STOP_WORDS.contains(&word)
911    }
912
913    /// Query the tree for relevant nodes at different levels
914    pub fn query(&self, query: &str, max_results: usize) -> Result<Vec<QueryResult>> {
915        let query_keywords = self.text_processor.extract_keywords(query, 5);
916        let mut results = Vec::new();
917
918        // Search all nodes for keyword matches
919        for (node_id, node) in &self.nodes {
920            let score = self.calculate_relevance_score(node, &query_keywords, query);
921
922            if score > 0.1 {
923                results.push(QueryResult {
924                    node_id: node_id.clone(),
925                    score,
926                    level: node.level,
927                    summary: node.summary.clone(),
928                    keywords: node.keywords.clone(),
929                    chunk_ids: node.chunk_ids.clone(),
930                });
931            }
932        }
933
934        // Sort by score and return top results
935        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
936        results.truncate(max_results);
937
938        Ok(results)
939    }
940
941    /// Calculate relevance score for a node given query keywords
942    fn calculate_relevance_score(
943        &self,
944        node: &TreeNode,
945        query_keywords: &[String],
946        query: &str,
947    ) -> f32 {
948        let mut score = 0.0;
949
950        // Keyword overlap score
951        let node_text = format!("{} {}", node.summary, node.keywords.join(" ")).to_lowercase();
952        for keyword in query_keywords {
953            if node_text.contains(&keyword.to_lowercase()) {
954                score += 1.0;
955            }
956        }
957
958        // Direct text similarity (simple word overlap)
959        let query_words: Vec<&str> = query.split_whitespace().collect();
960        let node_words: Vec<&str> = node_text.split_whitespace().collect();
961
962        let mut overlap_count = 0;
963        for query_word in &query_words {
964            if node_words.contains(&query_word.to_lowercase().as_str()) {
965                overlap_count += 1;
966            }
967        }
968
969        if !query_words.is_empty() {
970            score += (overlap_count as f32 / query_words.len() as f32) * 2.0;
971        }
972
973        // Level bonus (prefer higher levels for overview, lower levels for details)
974        let level_score = 1.0 / (node.level + 1) as f32;
975        score += level_score * 0.5;
976
977        score
978    }
979
980    /// Get ancestors of a node (path to root)
981    pub fn get_ancestors(&self, node_id: &NodeId) -> Vec<&TreeNode> {
982        let mut ancestors = Vec::new();
983        let mut current_id = Some(node_id.clone());
984
985        while let Some(id) = current_id {
986            if let Some(node) = self.nodes.get(&id) {
987                ancestors.push(node);
988                current_id = node.parent.clone();
989            } else {
990                break;
991            }
992        }
993
994        ancestors
995    }
996
997    /// Get descendants of a node (all children recursively)
998    pub fn get_descendants(&self, node_id: &NodeId) -> Vec<&TreeNode> {
999        let mut descendants = Vec::new();
1000        let mut queue = VecDeque::new();
1001
1002        if let Some(node) = self.nodes.get(node_id) {
1003            queue.extend(node.children.iter());
1004        }
1005
1006        while let Some(child_id) = queue.pop_front() {
1007            if let Some(child_node) = self.nodes.get(child_id) {
1008                descendants.push(child_node);
1009                queue.extend(child_node.children.iter());
1010            }
1011        }
1012
1013        descendants
1014    }
1015
1016    /// Get a node by ID
1017    pub fn get_node(&self, node_id: &NodeId) -> Option<&TreeNode> {
1018        self.nodes.get(node_id)
1019    }
1020
1021    /// Get all nodes at a specific level
1022    pub fn get_level_nodes(&self, level: usize) -> Vec<&TreeNode> {
1023        if let Some(node_ids) = self.levels.get(&level) {
1024            node_ids
1025                .iter()
1026                .filter_map(|id| self.nodes.get(id))
1027                .collect()
1028        } else {
1029            Vec::new()
1030        }
1031    }
1032
1033    /// Get the root nodes of the tree
1034    pub fn get_root_nodes(&self) -> Vec<&TreeNode> {
1035        self.root_nodes
1036            .iter()
1037            .filter_map(|id| self.nodes.get(id))
1038            .collect()
1039    }
1040
1041    /// Get the document ID
1042    pub fn document_id(&self) -> &DocumentId {
1043        &self.document_id
1044    }
1045
1046    /// Get tree statistics
1047    pub fn get_statistics(&self) -> TreeStatistics {
1048        let max_level = self.levels.keys().max().copied().unwrap_or(0);
1049        let total_nodes = self.nodes.len();
1050        let nodes_per_level: HashMap<usize, usize> = self
1051            .levels
1052            .iter()
1053            .map(|(level, nodes)| (*level, nodes.len()))
1054            .collect();
1055
1056        TreeStatistics {
1057            total_nodes,
1058            max_level,
1059            nodes_per_level,
1060            root_count: self.root_nodes.len(),
1061            document_id: self.document_id.clone(),
1062        }
1063    }
1064
    /// Serialize the tree to a compact JSON string (via `JsonValue::dump`).
    ///
    /// Top-level keys:
    /// - `document_id`: the document id as a string;
    /// - `config`: `merge_size`, `max_summary_length`, `min_node_size`,
    ///   `overlap_sentences`;
    /// - `nodes`: an object keyed by node id. Each value holds `id`, `content`,
    ///   `summary`, `level`, `children`, `parent`, `chunk_ids`, `keywords`,
    ///   `start_offset` and `end_offset`. `parent` comes from an `Option`, so
    ///   it is presumably `null` when absent.
    /// - `root_nodes`: an array of node ids;
    /// - `levels`: an object keyed by the level number rendered as a string,
    ///   each value an array of node ids.
    ///
    /// NOTE(review): `config.llm_config` is not written. A round trip through
    /// `from_json` therefore resets it to `LLMConfig::default()`.
    ///
    /// # Errors
    /// This implementation always returns `Ok`.
    pub fn to_json(&self) -> Result<String> {
        use json::JsonValue;

        // `nodes`, `root_nodes` and `levels` start as empty placeholders and
        // are filled in below.
        let mut tree_json = json::object! {
            "document_id": self.document_id.to_string(),
            "config": {
                "merge_size": self.config.merge_size,
                "max_summary_length": self.config.max_summary_length,
                "min_node_size": self.config.min_node_size,
                "overlap_sentences": self.config.overlap_sentences
            },
            "nodes": {},
            "root_nodes": [],
            "levels": {}
        };

        // Serialize nodes
        for (node_id, node) in &self.nodes {
            let node_json = json::object! {
                "id": node_id.to_string(),
                "content": node.content.clone(),
                "summary": node.summary.clone(),
                "level": node.level,
                "children": node.children.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "parent": node.parent.as_ref().map(|id| id.to_string()),
                "chunk_ids": node.chunk_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
                "keywords": node.keywords.clone(),
                "start_offset": node.start_offset,
                "end_offset": node.end_offset
            };
            tree_json["nodes"][node_id.to_string()] = node_json;
        }

        // Serialize root nodes
        tree_json["root_nodes"] = self
            .root_nodes
            .iter()
            .map(|id| JsonValue::String(id.to_string()))
            .collect::<Vec<_>>()
            .into();

        // Serialize levels (JSON object keys must be strings, hence `to_string`)
        for (level, node_ids) in &self.levels {
            tree_json["levels"][level.to_string()] = node_ids
                .iter()
                .map(|id| JsonValue::String(id.to_string()))
                .collect::<Vec<_>>()
                .into();
        }

        Ok(tree_json.dump())
    }
1118
1119    /// Load tree from JSON format
1120    pub fn from_json(json_str: &str) -> Result<Self> {
1121        let json_data = json::parse(json_str).map_err(crate::GraphRAGError::Json)?;
1122
1123        let document_id = DocumentId::new(
1124            json_data["document_id"]
1125                .as_str()
1126                .ok_or_else(|| {
1127                    crate::GraphRAGError::Json(json::Error::WrongType(
1128                        "document_id must be string".to_string(),
1129                    ))
1130                })?
1131                .to_string(),
1132        );
1133
1134        let config_json = &json_data["config"];
1135        let config = HierarchicalConfig {
1136            merge_size: config_json["merge_size"].as_usize().unwrap_or(5),
1137            max_summary_length: config_json["max_summary_length"].as_usize().unwrap_or(200),
1138            min_node_size: config_json["min_node_size"].as_usize().unwrap_or(50),
1139            overlap_sentences: config_json["overlap_sentences"].as_usize().unwrap_or(2),
1140            llm_config: LLMConfig::default(),
1141        };
1142
1143        let mut tree = Self::new(document_id, config)?;
1144
1145        // Load nodes
1146        if let json::JsonValue::Object(nodes_obj) = &json_data["nodes"] {
1147            for (node_id_str, node_json) in nodes_obj.iter() {
1148                let node_id = NodeId::new(node_id_str.to_string());
1149
1150                let children: Vec<NodeId> = node_json["children"]
1151                    .members()
1152                    .filter_map(|v| v.as_str())
1153                    .map(|s| NodeId::new(s.to_string()))
1154                    .collect();
1155
1156                let parent = node_json["parent"]
1157                    .as_str()
1158                    .map(|s| NodeId::new(s.to_string()));
1159
1160                let chunk_ids: Vec<ChunkId> = node_json["chunk_ids"]
1161                    .members()
1162                    .filter_map(|v| v.as_str())
1163                    .map(|s| ChunkId::new(s.to_string()))
1164                    .collect();
1165
1166                let keywords: Vec<String> = node_json["keywords"]
1167                    .members()
1168                    .filter_map(|v| v.as_str())
1169                    .map(|s| s.to_string())
1170                    .collect();
1171
1172                let node = TreeNode {
1173                    id: node_id.clone(),
1174                    content: node_json["content"].as_str().unwrap_or("").to_string(),
1175                    summary: node_json["summary"].as_str().unwrap_or("").to_string(),
1176                    level: node_json["level"].as_usize().unwrap_or(0),
1177                    children,
1178                    parent,
1179                    chunk_ids,
1180                    keywords,
1181                    start_offset: node_json["start_offset"].as_usize().unwrap_or(0),
1182                    end_offset: node_json["end_offset"].as_usize().unwrap_or(0),
1183                };
1184
1185                tree.nodes.insert(node_id, node);
1186            }
1187        }
1188
1189        // Load root nodes
1190        tree.root_nodes = json_data["root_nodes"]
1191            .members()
1192            .filter_map(|v| v.as_str())
1193            .map(|s| NodeId::new(s.to_string()))
1194            .collect();
1195
1196        // Load levels
1197        if let json::JsonValue::Object(levels_obj) = &json_data["levels"] {
1198            for (level_str, level_json) in levels_obj.iter() {
1199                if let Ok(level) = level_str.parse::<usize>() {
1200                    let node_ids: Vec<NodeId> = level_json
1201                        .members()
1202                        .filter_map(|v| v.as_str())
1203                        .map(|s| NodeId::new(s.to_string()))
1204                        .collect();
1205                    tree.levels.insert(level, node_ids);
1206                }
1207            }
1208        }
1209
1210        Ok(tree)
1211    }
1212}
1213
1214/// Result from querying the hierarchical tree
1215#[derive(Debug, Clone)]
1216pub struct QueryResult {
1217    /// ID of the matching node
1218    pub node_id: NodeId,
1219    /// Relevance score for the query match
1220    pub score: f32,
1221    /// Tree level of this result
1222    pub level: usize,
1223    /// Summary text of the matching node
1224    pub summary: String,
1225    /// Keywords associated with the node
1226    pub keywords: Vec<String>,
1227    /// Chunk IDs represented in this result
1228    pub chunk_ids: Vec<ChunkId>,
1229}
1230
1231/// Statistics about the hierarchical tree
1232#[derive(Debug)]
1233pub struct TreeStatistics {
1234    /// Total number of nodes in the tree
1235    pub total_nodes: usize,
1236    /// Maximum depth level of the tree
1237    pub max_level: usize,
1238    /// Count of nodes at each level
1239    pub nodes_per_level: HashMap<usize, usize>,
1240    /// Number of root nodes in the tree
1241    pub root_count: usize,
1242    /// ID of the document this tree represents
1243    pub document_id: DocumentId,
1244}
1245
1246impl TreeStatistics {
1247    /// Print tree statistics
1248    pub fn print(&self) {
1249        println!("Hierarchical Tree Statistics:");
1250        println!("  Document ID: {}", self.document_id);
1251        println!("  Total nodes: {}", self.total_nodes);
1252        println!("  Max level: {}", self.max_level);
1253        println!("  Root nodes: {}", self.root_count);
1254        println!("  Nodes per level:");
1255
1256        let mut levels: Vec<_> = self.nodes_per_level.iter().collect();
1257        levels.sort_by_key(|(level, _)| *level);
1258
1259        for (level, count) in levels {
1260            println!("    Level {level}: {count} nodes");
1261        }
1262    }
1263}
1264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    #[test]
    fn test_tree_creation() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config);
        assert!(tree.is_ok());
    }

    #[test]
    fn test_extractive_summarization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let text = "This is the first sentence. This is a second sentence with more details. This is the final sentence.";
        let summary = tree.generate_extractive_summary(text).unwrap();

        assert!(!summary.is_empty());
        assert!(summary.len() <= tree.config.max_summary_length);
    }

    #[test]
    fn test_json_serialization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let json = tree.to_json().unwrap();
        assert!(json.contains("test_doc"));

        let loaded_tree = DocumentTree::from_json(&json).unwrap();
        assert_eq!(loaded_tree.document_id.to_string(), "test_doc");
    }

    /// `to_json` writes every numeric config field and `from_json` reads them
    /// back, so a round trip must preserve them.
    #[test]
    fn test_json_round_trip_preserves_config() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let json = tree.to_json().unwrap();
        let loaded = DocumentTree::from_json(&json).unwrap();

        assert_eq!(loaded.config.merge_size, tree.config.merge_size);
        assert_eq!(loaded.config.max_summary_length, tree.config.max_summary_length);
        assert_eq!(loaded.config.min_node_size, tree.config.min_node_size);
        assert_eq!(loaded.config.overlap_sentences, tree.config.overlap_sentences);
    }

    /// Malformed JSON and a non-string `document_id` must be reported as
    /// errors rather than producing a tree.
    #[test]
    fn test_from_json_rejects_invalid_input() {
        assert!(DocumentTree::from_json("not json").is_err());
        assert!(DocumentTree::from_json(r#"{"document_id": 42}"#).is_err());
    }
}