Skip to main content

graphrag_core/summarization/
mod.rs

1use crate::{
2    core::{ChunkId, DocumentId, TextChunk, GraphRAGError},
3    text::TextProcessor,
4    Result,
5};
6#[cfg(feature = "parallel-processing")]
7use crate::parallel::ParallelProcessor;
8use indexmap::IndexMap;
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11
12/// Trait for LLM client to be used in summarization
13#[async_trait::async_trait]
14pub trait LLMClient: Send + Sync {
15    /// Generate a summary for the given text
16    async fn generate_summary(&self, text: &str, prompt: &str, max_tokens: usize, temperature: f32) -> Result<String>;
17
18    /// Generate summary in batch for multiple texts
19    async fn generate_summary_batch(&self, texts: &[(&str, &str)], max_tokens: usize, temperature: f32) -> Result<Vec<String>> {
20        let mut results = Vec::new();
21        for (text, prompt) in texts {
22            let summary = self.generate_summary(text, prompt, max_tokens, temperature).await?;
23            results.push(summary);
24        }
25        Ok(results)
26    }
27
28    /// Get model name
29    fn model_name(&self) -> &str;
30}
31
/// Opaque, string-backed identifier of a node in a `DocumentTree`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeId(pub String);

impl NodeId {
    /// Wraps `id` as a node identifier without validating it.
    pub fn new(id: String) -> Self {
        NodeId(id)
    }
}

impl std::fmt::Display for NodeId {
    /// Writes the raw identifier (formatter width/fill options are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for NodeId {
    fn from(raw: String) -> Self {
        NodeId::new(raw)
    }
}

impl From<NodeId> for String {
    fn from(NodeId(inner): NodeId) -> Self {
        inner
    }
}
60
/// Configuration for hierarchical summarization.
///
/// Controls how a `DocumentTree` groups nodes into levels and how long its
/// summaries may be. Derives serde `Serialize`/`Deserialize`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HierarchicalConfig {
    /// Number of consecutive nodes merged into one parent when building the next
    /// tree level. Must be at least 2 for each level to be smaller than the one
    /// below it.
    pub merge_size: usize,
    /// Maximum length for generated summaries, intended as a character count.
    /// Also used as the default `LevelConfig::max_length` for levels without an
    /// explicit override.
    pub max_summary_length: usize,
    /// Minimum size in characters for nodes to be considered valid.
    ///
    /// NOTE(review): not read anywhere in the visible code — confirm whether it is
    /// enforced elsewhere.
    pub min_node_size: usize,
    /// Number of overlapping sentences between adjacent chunks.
    ///
    /// NOTE(review): likewise not read in the visible code.
    pub overlap_sentences: usize,
    /// LLM-based summarization configuration.
    pub llm_config: LLMConfig,
}
75
/// Configuration for LLM-based summarization.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMConfig {
    /// Whether merged (non-leaf) nodes are summarized by the LLM. When `false`, or
    /// when the tree has no LLM client, extractive summarization is used instead.
    /// Leaf nodes are always summarized without the LLM.
    pub enabled: bool,
    /// Name of the model to use for summarization.
    ///
    /// NOTE(review): not read by the visible code; clients report their own model
    /// through `LLMClient::model_name` — confirm where this is consumed.
    pub model_name: String,
    /// Default sampling temperature (lower = more deterministic).
    pub temperature: f32,
    /// Maximum number of tokens to request from the LLM per generation.
    pub max_tokens: usize,
    /// Prompt-selection strategy; see [`LLMStrategy`].
    pub strategy: LLMStrategy,
    /// Per-level overrides keyed by tree level (0 = leaves). Levels without an
    /// entry get a default derived from `max_summary_length` and `strategy`.
    pub level_configs: HashMap<usize, LevelConfig>,
}
92
/// How built-in LLM prompts are chosen for the different tree levels.
///
/// Ignored for a level whose `LevelConfig::prompt_template` is set.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LLMStrategy {
    /// Same prompt wording for every level.
    Uniform,
    /// Prompt chosen by level: key-point extraction at level 0, a combined summary
    /// at levels 1-2, and a high-level abstract above that.
    Adaptive,
    /// Prompt chosen by `LevelConfig::use_abstractive`: "extract and organize"
    /// wording when false, abstractive wording when true. Without explicit level
    /// overrides, levels 2 and up are abstractive.
    Progressive,
}
103
/// Per-level settings for LLM summarization.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LevelConfig {
    /// Maximum summary length for this level; LLM output is truncated to it.
    pub max_length: usize,
    /// Whether to ask for an abstractive summary at this level. Only consulted by
    /// the built-in prompts of the `Progressive` strategy.
    pub use_abstractive: bool,
    /// Custom prompt template for this level, replacing the built-in prompts.
    /// Supports the placeholders `{text}`, `{context}`, `{level}` and
    /// `{max_length}`.
    pub prompt_template: Option<String>,
    /// Temperature override for this level; `None` falls back to
    /// `LLMConfig::temperature`.
    pub temperature: Option<f32>,
}
116
117impl Default for HierarchicalConfig {
118    fn default() -> Self {
119        Self {
120            merge_size: 5,
121            max_summary_length: 200,
122            min_node_size: 50,
123            overlap_sentences: 2,
124            llm_config: LLMConfig::default(),
125        }
126    }
127}
128
129impl Default for LLMConfig {
130    fn default() -> Self {
131        Self {
132            enabled: false, // Disabled by default for backward compatibility
133            model_name: "llama3.1:8b".to_string(),
134            temperature: 0.3, // Lower temperature for more coherent summaries
135            max_tokens: 150,
136            strategy: LLMStrategy::Progressive,
137            level_configs: HashMap::new(),
138        }
139    }
140}
141
142impl Default for LevelConfig {
143    fn default() -> Self {
144        Self {
145            max_length: 200,
146            use_abstractive: false, // Start with extractive at lower levels
147            prompt_template: None,
148            temperature: None,
149        }
150    }
151}
152
/// A node in the hierarchical document tree.
///
/// Leaves (level 0) wrap a single text chunk. Each higher-level node merges a run
/// of consecutive nodes from the level below.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Unique identifier: `leaf_{chunk_id}` for leaves, `level_{level}_{index}`
    /// for merged nodes.
    pub id: NodeId,
    /// Full text of this node. For merged nodes this is the children's content
    /// joined with blank lines, i.e. all text beneath the node.
    pub content: String,
    /// Generated summary of the content.
    pub summary: String,
    /// Level in the tree hierarchy (0 = leaf, higher = more abstract).
    pub level: usize,
    /// IDs of child nodes in the tree (empty for leaves).
    pub children: Vec<NodeId>,
    /// ID of the parent node; `None` for root nodes.
    pub parent: Option<NodeId>,
    /// IDs of all text chunks covered by this node.
    pub chunk_ids: Vec<ChunkId>,
    /// Keywords from the content. Leaves request up to 5 from the text processor;
    /// merged nodes keep their children's keywords sorted, de-duplicated and
    /// capped at 10.
    pub keywords: Vec<String>,
    /// Start offset in the original document, copied from the chunk. For merged
    /// nodes this is the smallest child start offset.
    ///
    /// NOTE(review): the unit (bytes vs. chars) is whatever `TextChunk` uses — confirm.
    pub start_offset: usize,
    /// End offset in the original document. For merged nodes this is the largest
    /// child end offset.
    pub end_offset: usize,
}
177
/// Hierarchical document tree for multi-level summarization.
///
/// Built bottom-up from text chunks by `build_from_chunks` and searched with `query`.
pub struct DocumentTree {
    /// All nodes keyed by id, in insertion order (leaves first, then each level).
    nodes: IndexMap<NodeId, TreeNode>,
    /// Topmost node(s) left after the bottom-up build.
    root_nodes: Vec<NodeId>,
    /// Node ids per level (0 = leaves).
    levels: HashMap<usize, Vec<NodeId>>,
    /// Document this tree summarizes.
    ///
    /// NOTE(review): not read by the visible methods.
    document_id: DocumentId,
    /// Summarization settings.
    config: HierarchicalConfig,
    /// Sentence and keyword helper; every constructor builds it with
    /// `TextProcessor::new(1000, 100)`.
    text_processor: TextProcessor,
    /// Optional LLM backend used for summaries of merged nodes.
    llm_client: Option<Arc<dyn LLMClient>>,
}
188
189impl DocumentTree {
190    /// Create a new document tree
191    pub fn new(document_id: DocumentId, config: HierarchicalConfig) -> Result<Self> {
192        let text_processor = TextProcessor::new(1000, 100)?;
193
194        Ok(Self {
195            nodes: IndexMap::new(),
196            root_nodes: Vec::new(),
197            levels: HashMap::new(),
198            document_id,
199            config,
200            text_processor,
201            llm_client: None,
202        })
203    }
204
205    /// Create a new document tree with LLM client
206    pub fn with_llm_client(
207        document_id: DocumentId,
208        config: HierarchicalConfig,
209        llm_client: Arc<dyn LLMClient>,
210    ) -> Result<Self> {
211        let text_processor = TextProcessor::new(1000, 100)?;
212
213        Ok(Self {
214            nodes: IndexMap::new(),
215            root_nodes: Vec::new(),
216            levels: HashMap::new(),
217            document_id,
218            config,
219            text_processor,
220            llm_client: Some(llm_client),
221        })
222    }
223
224    /// Set LLM client for the tree
225    pub fn set_llm_client(&mut self, llm_client: Option<Arc<dyn LLMClient>>) {
226        self.llm_client = llm_client;
227    }
228
229    /// Create a new document tree with parallel processing support
230    #[cfg(feature = "parallel-processing")]
231    pub fn with_parallel_processing(
232        document_id: DocumentId,
233        config: HierarchicalConfig,
234        _parallel_processor: ParallelProcessor,
235    ) -> Result<Self> {
236        let text_processor = TextProcessor::new(1000, 100)?;
237
238        Ok(Self {
239            nodes: IndexMap::new(),
240            root_nodes: Vec::new(),
241            levels: HashMap::new(),
242            document_id,
243            config,
244            text_processor,
245            llm_client: None,
246        })
247    }
248
249    /// Create a new document tree with both parallel processing and LLM client
250    #[cfg(feature = "parallel-processing")]
251    pub fn with_parallel_and_llm(
252        document_id: DocumentId,
253        config: HierarchicalConfig,
254        _parallel_processor: ParallelProcessor,
255        llm_client: Arc<dyn LLMClient>,
256    ) -> Result<Self> {
257        let text_processor = TextProcessor::new(1000, 100)?;
258
259        Ok(Self {
260            nodes: IndexMap::new(),
261            root_nodes: Vec::new(),
262            levels: HashMap::new(),
263            document_id,
264            config,
265            text_processor,
266            llm_client: Some(llm_client),
267        })
268    }
269
270    /// Build the hierarchical tree from text chunks
271    pub async fn build_from_chunks(&mut self, chunks: Vec<TextChunk>) -> Result<()> {
272        if chunks.is_empty() {
273            return Ok(());
274        }
275
276        println!("Building hierarchical tree from {} chunks", chunks.len());
277
278        // Create leaf nodes from chunks
279        let leaf_nodes = self.create_leaf_nodes(chunks)?;
280
281        // Build the tree bottom-up
282        self.build_bottom_up(leaf_nodes).await?;
283
284        println!(
285            "Tree built with {} total nodes across {} levels",
286            self.nodes.len(),
287            self.levels.len()
288        );
289
290        Ok(())
291    }
292
293    /// Create leaf nodes from text chunks with parallel processing
294    fn create_leaf_nodes(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
295        if chunks.len() < 10 {
296            // Use sequential processing for small numbers of chunks
297            return self.create_leaf_nodes_sequential(chunks);
298        }
299
300        #[cfg(feature = "parallel-processing")]
301        {
302            use rayon::prelude::*;
303
304            // Parallel node creation with proper error handling
305            let node_results: std::result::Result<Vec<_>, crate::GraphRAGError> = chunks
306                .par_iter()
307                .map(|chunk| {
308                    let node_id = NodeId::new(format!("leaf_{}", chunk.id));
309
310                    // Create a temporary text processor for each thread to avoid borrowing issues
311                    let temp_processor = crate::text::TextProcessor::new(1000, 100)
312                        .map_err(|e| crate::GraphRAGError::Config {
313                            message: format!("Failed to create text processor: {e}"),
314                        })?;
315
316                    // Extract keywords for the chunk
317                    let keywords = temp_processor.extract_keywords(&chunk.content, 5);
318
319                    // Generate summary using simplified approach suitable for parallel execution
320                    let summary = self.generate_parallel_summary(&chunk.content, &temp_processor)?;
321
322                    let node = TreeNode {
323                        id: node_id.clone(),
324                        content: chunk.content.clone(),
325                        summary,
326                        level: 0,
327                        children: Vec::new(),
328                        parent: None,
329                        chunk_ids: vec![chunk.id.clone()],
330                        keywords,
331                        start_offset: chunk.start_offset,
332                        end_offset: chunk.end_offset,
333                    };
334
335                    Ok((node_id, node))
336                })
337                .collect();
338
339            match node_results {
340                Ok(nodes) => {
341                    let mut leaf_node_ids = Vec::new();
342
343                    // Insert nodes sequentially to avoid concurrent modification
344                    for (node_id, node) in nodes {
345                        leaf_node_ids.push(node_id.clone());
346                        self.nodes.insert(node_id, node);
347                    }
348
349                    println!("Created {} leaf nodes in parallel", leaf_node_ids.len());
350
351                    // Update levels tracking
352                    self.levels.insert(0, leaf_node_ids.clone());
353
354                    Ok(leaf_node_ids)
355                }
356                Err(e) => {
357                    eprintln!("Error in parallel node creation: {e}");
358                    // Fall back to sequential processing
359                    self.create_leaf_nodes_sequential(chunks)
360                }
361            }
362        }
363
364        #[cfg(not(feature = "parallel-processing"))]
365        {
366            self.create_leaf_nodes_sequential(chunks)
367        }
368    }
369
370    /// Sequential leaf node creation (fallback)
371    fn create_leaf_nodes_sequential(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
372        let mut leaf_node_ids = Vec::new();
373
374        for chunk in chunks {
375            let node_id = NodeId::new(format!("leaf_{}", chunk.id));
376
377            // Extract keywords for the chunk
378            let keywords = self.text_processor.extract_keywords(&chunk.content, 5);
379
380            let node = TreeNode {
381                id: node_id.clone(),
382                content: chunk.content.clone(),
383                summary: self.generate_extractive_summary(&chunk.content)?,
384                level: 0,
385                children: Vec::new(),
386                parent: None,
387                chunk_ids: vec![chunk.id],
388                keywords,
389                start_offset: chunk.start_offset,
390                end_offset: chunk.end_offset,
391            };
392
393            self.nodes.insert(node_id.clone(), node);
394            leaf_node_ids.push(node_id);
395        }
396
397        // Update levels tracking
398        self.levels.insert(0, leaf_node_ids.clone());
399
400        Ok(leaf_node_ids)
401    }
402
403    /// Generate summary suitable for parallel processing
404    /// Generate summary using LLM for the given text and level
405    pub async fn generate_llm_summary(
406        &self,
407        text: &str,
408        level: usize,
409        context: &str,
410    ) -> Result<String> {
411        let llm_client = self.llm_client.as_ref()
412            .ok_or_else(|| GraphRAGError::Config {
413                message: "LLM client not configured for summarization".to_string(),
414            })?;
415
416        // Get configuration for this level
417        let level_config = self.get_level_config(level);
418
419        // Create prompt based on strategy and level
420        let prompt = self.create_summary_prompt(text, level, context, &level_config)?;
421
422        // Generate summary
423        let summary = llm_client.generate_summary(
424            text,
425            &prompt,
426            level_config.max_length,
427            level_config.temperature.unwrap_or(self.config.llm_config.temperature),
428        ).await?;
429
430        // Ensure summary is within length limits
431        self.truncate_summary(&summary, level_config.max_length)
432    }
433
434    /// Generate summaries in batch for multiple texts
435    pub async fn generate_llm_summaries_batch(
436        &self,
437        texts: &[(&str, usize, &str)], // (text, level, context)
438    ) -> Result<Vec<String>> {
439        let llm_client = self.llm_client.as_ref()
440            .ok_or_else(|| GraphRAGError::Config {
441                message: "LLM client not configured for summarization".to_string(),
442            })?;
443
444        let mut prompts = Vec::new();
445        let mut configs = Vec::new();
446
447        for (text, level, context) in texts {
448            let level_config = self.get_level_config(*level);
449            let prompt = self.create_summary_prompt(text, *level, context, &level_config)?;
450            prompts.push(prompt);
451            configs.push(level_config);
452        }
453
454        // Generate summaries in batch
455        let text_refs: Vec<&str> = texts.iter().map(|(t, _, _)| *t).collect();
456        let prompt_refs: Vec<&str> = prompts.iter().map(|p| p.as_str()).collect();
457
458        let summaries = llm_client.generate_summary_batch(
459            &text_refs.iter().zip(prompt_refs.iter()).map(|(&t, &p)| (t, p)).collect::<Vec<_>>(),
460            self.config.llm_config.max_tokens,
461            self.config.llm_config.temperature,
462        ).await?;
463
464        // Truncate summaries according to level configs
465        let mut results = Vec::new();
466        for (i, summary) in summaries.into_iter().enumerate() {
467            let truncated = self.truncate_summary(&summary, configs[i].max_length)?;
468            results.push(truncated);
469        }
470
471        Ok(results)
472    }
473
474    /// Create a summary prompt based on level and strategy
475    fn create_summary_prompt(
476        &self,
477        text: &str,
478        level: usize,
479        context: &str,
480        level_config: &LevelConfig,
481    ) -> Result<String> {
482        // Use custom template if provided
483        if let Some(template) = &level_config.prompt_template {
484            return Ok(template
485                .replace("{text}", text)
486                .replace("{context}", context)
487                .replace("{level}", &level.to_string())
488                .replace("{max_length}", &level_config.max_length.to_string()));
489        }
490
491        // Default prompts based on strategy and level
492        match self.config.llm_config.strategy {
493            LLMStrategy::Uniform => {
494                Ok(format!(
495                    "Create a concise summary of the following text. The summary should be approximately {} characters long.\n\nContext: {}\n\nText to summarize:\n{}\n\nSummary:",
496                    level_config.max_length, context, text
497                ))
498            }
499            LLMStrategy::Adaptive => {
500                if level == 0 {
501                    Ok(format!(
502                        "Extract the key information from this text segment. Keep it factual and under {} characters.\n\nContext: {}\n\nText:\n{}\n\nKey points:",
503                        level_config.max_length, context, text
504                    ))
505                } else if level <= 2 {
506                    Ok(format!(
507                        "Create a coherent summary that combines the key information from this text. Make it approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nSummary:",
508                        level_config.max_length, context, text
509                    ))
510                } else {
511                    Ok(format!(
512                        "Generate a high-level abstract summary of this content. Focus on the main themes and insights. Limit to approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstract summary:",
513                        level_config.max_length, context, text
514                    ))
515                }
516            }
517            LLMStrategy::Progressive => {
518                if level_config.use_abstractive {
519                    Ok(format!(
520                        "Generate an abstractive summary that synthesizes the key concepts and relationships in this text. The summary should be approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstractive summary:",
521                        level_config.max_length, context, text
522                    ))
523                } else {
524                    Ok(format!(
525                        "Extract and organize the most important sentences from this text to create a coherent summary. Keep it under {} characters.\n\nContext: {}\n\nText:\n{}\n\nExtractive summary:",
526                        level_config.max_length, context, text
527                    ))
528                }
529            }
530        }
531    }
532
533    /// Get configuration for a specific level
534    fn get_level_config(&self, level: usize) -> LevelConfig {
535        self.config.llm_config.level_configs
536            .get(&level)
537            .cloned()
538            .unwrap_or_else(|| {
539                // Default configuration based on level
540                LevelConfig {
541                    max_length: self.config.max_summary_length,
542                    use_abstractive: match self.config.llm_config.strategy {
543                        LLMStrategy::Progressive => level >= 2,
544                        LLMStrategy::Adaptive => level >= 3,
545                        LLMStrategy::Uniform => level > 0,
546                    },
547                    prompt_template: None,
548                    temperature: None,
549                }
550            })
551    }
552
553    /// Truncate summary to fit within length limits
554    fn truncate_summary(&self, summary: &str, max_length: usize) -> Result<String> {
555        if summary.len() <= max_length {
556            return Ok(summary.to_string());
557        }
558
559        // Try to truncate at sentence boundaries
560        let sentences: Vec<&str> = summary
561            .split('.')
562            .filter(|s| !s.trim().is_empty())
563            .collect();
564
565        let mut result = String::new();
566        for sentence in sentences {
567            if result.len() + sentence.len() + 1 <= max_length - 3 {
568                if !result.is_empty() {
569                    result.push('.');
570                }
571                result.push_str(sentence.trim());
572            } else {
573                break;
574            }
575        }
576
577        if result.is_empty() {
578            // Fallback: truncate characters
579            result = summary.chars().take(max_length - 3).collect();
580            result.push_str("...");
581        } else {
582            result.push('.');
583        }
584
585        Ok(result)
586    }
587
588    #[allow(dead_code)]
589    fn generate_parallel_summary(
590        &self,
591        text: &str,
592        processor: &crate::text::TextProcessor,
593    ) -> Result<String> {
594        let sentences = processor.extract_sentences(text);
595
596        if sentences.is_empty() {
597            return Ok(String::new());
598        }
599
600        if sentences.len() == 1 {
601            return Ok(sentences[0].clone());
602        }
603
604        // Simplified scoring for parallel execution
605        let mut best_sentence = &sentences[0];
606        let mut best_score = 0.0;
607
608        for sentence in &sentences {
609            let words: Vec<&str> = sentence.split_whitespace().collect();
610
611            // Simple scoring based on length and word density
612            let length_score = if words.len() < 5 {
613                0.1
614            } else if words.len() > 30 {
615                0.3
616            } else {
617                1.0
618            };
619
620            let word_score = words.len() as f32 * 0.1;
621            let score = length_score + word_score;
622
623            if score > best_score {
624                best_score = score;
625                best_sentence = sentence;
626            }
627        }
628
629        // Truncate if necessary
630        if best_sentence.len() > self.config.max_summary_length {
631            Ok(best_sentence
632                .chars()
633                .take(self.config.max_summary_length - 3)
634                .collect::<String>()
635                + "...")
636        } else {
637            Ok(best_sentence.clone())
638        }
639    }
640
641    /// Build the tree bottom-up by merging nodes at each level
642    async fn build_bottom_up(&mut self, leaf_nodes: Vec<NodeId>) -> Result<()> {
643        let mut current_level_nodes = leaf_nodes;
644        let mut current_level = 0;
645
646        while current_level_nodes.len() > 1 {
647            let next_level_nodes = self.merge_level(&current_level_nodes, current_level + 1).await?;
648
649            current_level_nodes = next_level_nodes;
650            current_level += 1;
651        }
652
653        // Set root nodes
654        self.root_nodes = current_level_nodes;
655
656        Ok(())
657    }
658
659    /// Merge nodes at a level to create the next level up
660    async fn merge_level(&mut self, level_nodes: &[NodeId], new_level: usize) -> Result<Vec<NodeId>> {
661        let mut new_level_nodes = Vec::new();
662        // Group nodes by merge_size
663        for (node_counter, chunk) in level_nodes.chunks(self.config.merge_size).enumerate() {
664            let merged_node_id = NodeId::new(format!("level_{new_level}_{node_counter}"));
665            let merged_node = self.merge_nodes(chunk, merged_node_id.clone(), new_level).await?;
666
667            // Update parent references for children
668            for child_id in chunk {
669                if let Some(child_node) = self.nodes.get_mut(child_id) {
670                    child_node.parent = Some(merged_node_id.clone());
671                }
672            }
673
674            self.nodes.insert(merged_node_id.clone(), merged_node);
675            new_level_nodes.push(merged_node_id);
676        }
677
678        // Update levels tracking
679        self.levels.insert(new_level, new_level_nodes.clone());
680
681        Ok(new_level_nodes)
682    }
683
684    /// Merge multiple nodes into a single parent node
685    async fn merge_nodes(
686        &self,
687        node_ids: &[NodeId],
688        merged_id: NodeId,
689        level: usize,
690    ) -> Result<TreeNode> {
691        let mut combined_content = String::new();
692        let mut all_chunk_ids = Vec::new();
693        let mut all_keywords = Vec::new();
694        let mut min_offset = usize::MAX;
695        let mut max_offset = 0;
696
697        for node_id in node_ids {
698            if let Some(node) = self.nodes.get(node_id) {
699                if !combined_content.is_empty() {
700                    combined_content.push_str("\n\n");
701                }
702                combined_content.push_str(&node.content);
703                all_chunk_ids.extend(node.chunk_ids.clone());
704                all_keywords.extend(node.keywords.clone());
705                min_offset = min_offset.min(node.start_offset);
706                max_offset = max_offset.max(node.end_offset);
707            }
708        }
709
710        // Deduplicate and limit keywords
711        all_keywords.sort();
712        all_keywords.dedup();
713        all_keywords.truncate(10);
714
715        // Generate summary for the merged content
716        let summary = if self.config.llm_config.enabled {
717            // Use LLM-based summarization if enabled and available
718            if self.llm_client.is_some() {
719                // Create context for this merge operation
720                let context = format!(
721                    "Merging {} nodes at level {}. This represents a higher-level abstraction of the document content.",
722                    node_ids.len(),
723                    level
724                );
725
726                // Try LLM-based summarization, fall back to extractive if it fails
727                match self.generate_llm_summary(&combined_content, level, &context).await {
728                    Ok(llm_summary) => {
729                        println!("✅ Generated LLM-based summary for level {} ({} chars)", level, llm_summary.len());
730                        llm_summary
731                    }
732                    Err(e) => {
733                        eprintln!("⚠️ LLM summarization failed for level {}: {}, falling back to extractive", level, e);
734                        self.generate_extractive_summary(&combined_content)?
735                    }
736                }
737            } else {
738                self.generate_extractive_summary(&combined_content)?
739            }
740        } else {
741            self.generate_extractive_summary(&combined_content)?
742        };
743
744        Ok(TreeNode {
745            id: merged_id,
746            content: combined_content,
747            summary,
748            level,
749            children: node_ids.to_vec(),
750            parent: None,
751            chunk_ids: all_chunk_ids,
752            keywords: all_keywords,
753            start_offset: min_offset,
754            end_offset: max_offset,
755        })
756    }
757
758    /// Generate extractive summary using simple sentence ranking
759    fn generate_extractive_summary(&self, text: &str) -> Result<String> {
760        let sentences = self.text_processor.extract_sentences(text);
761
762        if sentences.is_empty() {
763            return Ok(String::new());
764        }
765
766        if sentences.len() == 1 {
767            return Ok(sentences[0].clone());
768        }
769
770        // Score sentences based on length and keyword density
771        let mut sentence_scores: Vec<(usize, f32)> = sentences
772            .iter()
773            .enumerate()
774            .map(|(i, sentence)| {
775                let score = self.score_sentence(sentence, &sentences);
776                (i, score)
777            })
778            .collect();
779
780        // Sort by score descending
781        sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
782
783        // Select top sentences up to max_summary_length
784        let mut summary = String::new();
785        let mut selected_indices = Vec::new();
786
787        for (sentence_idx, _score) in sentence_scores {
788            let sentence = &sentences[sentence_idx];
789            if summary.len() + sentence.len() <= self.config.max_summary_length {
790                selected_indices.push(sentence_idx);
791                if !summary.is_empty() {
792                    summary.push(' ');
793                }
794                summary.push_str(sentence);
795            }
796        }
797
798        // If no sentences fit, take the first one truncated
799        if summary.is_empty() && !sentences.is_empty() {
800            let first_sentence = &sentences[0];
801            if first_sentence.len() <= self.config.max_summary_length {
802                summary = first_sentence.clone();
803            } else {
804                summary = first_sentence
805                    .chars()
806                    .take(self.config.max_summary_length - 3)
807                    .collect::<String>()
808                    + "...";
809            }
810        }
811
812        Ok(summary)
813    }
814
815    /// Score a sentence for extractive summarization
816    fn score_sentence(&self, sentence: &str, all_sentences: &[String]) -> f32 {
817        let words: Vec<&str> = sentence.split_whitespace().collect();
818
819        // Base score from length (prefer medium-length sentences)
820        let length_score = if words.len() < 5 {
821            0.1
822        } else if words.len() > 30 {
823            0.3
824        } else {
825            1.0
826        };
827
828        // Position score (prefer sentences from beginning and end)
829        let position_score = 0.5; // Simplified for now
830
831        // Word frequency score
832        let mut word_freq_score = 0.0;
833        let total_words: Vec<&str> = all_sentences
834            .iter()
835            .flat_map(|s| s.split_whitespace())
836            .collect();
837
838        for word in &words {
839            let word_lower = word.to_lowercase();
840            if word_lower.len() > 3 && !self.is_stop_word(&word_lower) {
841                let freq = total_words
842                    .iter()
843                    .filter(|&&w| w.to_lowercase() == word_lower)
844                    .count();
845                if freq > 1 {
846                    word_freq_score += freq as f32 / total_words.len() as f32;
847                }
848            }
849        }
850
851        length_score * 0.4 + position_score * 0.2 + word_freq_score * 0.4
852    }
853
854    /// Simple stop word detection (English)
855    fn is_stop_word(&self, word: &str) -> bool {
856        const STOP_WORDS: &[&str] = &[
857            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
858            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
859            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
860            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
861            "go", "me",
862        ];
863        STOP_WORDS.contains(&word)
864    }
865
866    /// Query the tree for relevant nodes at different levels
867    pub fn query(&self, query: &str, max_results: usize) -> Result<Vec<QueryResult>> {
868        let query_keywords = self.text_processor.extract_keywords(query, 5);
869        let mut results = Vec::new();
870
871        // Search all nodes for keyword matches
872        for (node_id, node) in &self.nodes {
873            let score = self.calculate_relevance_score(node, &query_keywords, query);
874
875            if score > 0.1 {
876                results.push(QueryResult {
877                    node_id: node_id.clone(),
878                    score,
879                    level: node.level,
880                    summary: node.summary.clone(),
881                    keywords: node.keywords.clone(),
882                    chunk_ids: node.chunk_ids.clone(),
883                });
884            }
885        }
886
887        // Sort by score and return top results
888        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
889        results.truncate(max_results);
890
891        Ok(results)
892    }
893
894    /// Calculate relevance score for a node given query keywords
895    fn calculate_relevance_score(
896        &self,
897        node: &TreeNode,
898        query_keywords: &[String],
899        query: &str,
900    ) -> f32 {
901        let mut score = 0.0;
902
903        // Keyword overlap score
904        let node_text = format!("{} {}", node.summary, node.keywords.join(" ")).to_lowercase();
905        for keyword in query_keywords {
906            if node_text.contains(&keyword.to_lowercase()) {
907                score += 1.0;
908            }
909        }
910
911        // Direct text similarity (simple word overlap)
912        let query_words: Vec<&str> = query.split_whitespace().collect();
913        let node_words: Vec<&str> = node_text.split_whitespace().collect();
914
915        let mut overlap_count = 0;
916        for query_word in &query_words {
917            if node_words.contains(&query_word.to_lowercase().as_str()) {
918                overlap_count += 1;
919            }
920        }
921
922        if !query_words.is_empty() {
923            score += (overlap_count as f32 / query_words.len() as f32) * 2.0;
924        }
925
926        // Level bonus (prefer higher levels for overview, lower levels for details)
927        let level_score = 1.0 / (node.level + 1) as f32;
928        score += level_score * 0.5;
929
930        score
931    }
932
933    /// Get ancestors of a node (path to root)
934    pub fn get_ancestors(&self, node_id: &NodeId) -> Vec<&TreeNode> {
935        let mut ancestors = Vec::new();
936        let mut current_id = Some(node_id.clone());
937
938        while let Some(id) = current_id {
939            if let Some(node) = self.nodes.get(&id) {
940                ancestors.push(node);
941                current_id = node.parent.clone();
942            } else {
943                break;
944            }
945        }
946
947        ancestors
948    }
949
950    /// Get descendants of a node (all children recursively)
951    pub fn get_descendants(&self, node_id: &NodeId) -> Vec<&TreeNode> {
952        let mut descendants = Vec::new();
953        let mut queue = VecDeque::new();
954
955        if let Some(node) = self.nodes.get(node_id) {
956            queue.extend(node.children.iter());
957        }
958
959        while let Some(child_id) = queue.pop_front() {
960            if let Some(child_node) = self.nodes.get(child_id) {
961                descendants.push(child_node);
962                queue.extend(child_node.children.iter());
963            }
964        }
965
966        descendants
967    }
968
969    /// Get a node by ID
970    pub fn get_node(&self, node_id: &NodeId) -> Option<&TreeNode> {
971        self.nodes.get(node_id)
972    }
973
974    /// Get all nodes at a specific level
975    pub fn get_level_nodes(&self, level: usize) -> Vec<&TreeNode> {
976        if let Some(node_ids) = self.levels.get(&level) {
977            node_ids
978                .iter()
979                .filter_map(|id| self.nodes.get(id))
980                .collect()
981        } else {
982            Vec::new()
983        }
984    }
985
986    /// Get the root nodes of the tree
987    pub fn get_root_nodes(&self) -> Vec<&TreeNode> {
988        self.root_nodes
989            .iter()
990            .filter_map(|id| self.nodes.get(id))
991            .collect()
992    }
993
994    /// Get the document ID
995    pub fn document_id(&self) -> &DocumentId {
996        &self.document_id
997    }
998
999    /// Get tree statistics
1000    pub fn get_statistics(&self) -> TreeStatistics {
1001        let max_level = self.levels.keys().max().copied().unwrap_or(0);
1002        let total_nodes = self.nodes.len();
1003        let nodes_per_level: HashMap<usize, usize> = self
1004            .levels
1005            .iter()
1006            .map(|(level, nodes)| (*level, nodes.len()))
1007            .collect();
1008
1009        TreeStatistics {
1010            total_nodes,
1011            max_level,
1012            nodes_per_level,
1013            root_count: self.root_nodes.len(),
1014            document_id: self.document_id.clone(),
1015        }
1016    }
1017
1018    /// Serialize the tree to JSON format
1019    pub fn to_json(&self) -> Result<String> {
1020        use json::JsonValue;
1021
1022        let mut tree_json = json::object! {
1023            "document_id": self.document_id.to_string(),
1024            "config": {
1025                "merge_size": self.config.merge_size,
1026                "max_summary_length": self.config.max_summary_length,
1027                "min_node_size": self.config.min_node_size,
1028                "overlap_sentences": self.config.overlap_sentences
1029            },
1030            "nodes": {},
1031            "root_nodes": [],
1032            "levels": {}
1033        };
1034
1035        // Serialize nodes
1036        for (node_id, node) in &self.nodes {
1037            let node_json = json::object! {
1038                "id": node_id.to_string(),
1039                "content": node.content.clone(),
1040                "summary": node.summary.clone(),
1041                "level": node.level,
1042                "children": node.children.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
1043                "parent": node.parent.as_ref().map(|id| id.to_string()),
1044                "chunk_ids": node.chunk_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
1045                "keywords": node.keywords.clone(),
1046                "start_offset": node.start_offset,
1047                "end_offset": node.end_offset
1048            };
1049            tree_json["nodes"][node_id.to_string()] = node_json;
1050        }
1051
1052        // Serialize root nodes
1053        tree_json["root_nodes"] = self
1054            .root_nodes
1055            .iter()
1056            .map(|id| JsonValue::String(id.to_string()))
1057            .collect::<Vec<_>>()
1058            .into();
1059
1060        // Serialize levels
1061        for (level, node_ids) in &self.levels {
1062            tree_json["levels"][level.to_string()] = node_ids
1063                .iter()
1064                .map(|id| JsonValue::String(id.to_string()))
1065                .collect::<Vec<_>>()
1066                .into();
1067        }
1068
1069        Ok(tree_json.dump())
1070    }
1071
1072    /// Load tree from JSON format
1073    pub fn from_json(json_str: &str) -> Result<Self> {
1074        let json_data = json::parse(json_str).map_err(crate::GraphRAGError::Json)?;
1075
1076        let document_id = DocumentId::new(
1077            json_data["document_id"]
1078                .as_str()
1079                .ok_or_else(|| {
1080                    crate::GraphRAGError::Json(json::Error::WrongType(
1081                        "document_id must be string".to_string(),
1082                    ))
1083                })?
1084                .to_string(),
1085        );
1086
1087        let config_json = &json_data["config"];
1088        let config = HierarchicalConfig {
1089            merge_size: config_json["merge_size"].as_usize().unwrap_or(5),
1090            max_summary_length: config_json["max_summary_length"].as_usize().unwrap_or(200),
1091            min_node_size: config_json["min_node_size"].as_usize().unwrap_or(50),
1092            overlap_sentences: config_json["overlap_sentences"].as_usize().unwrap_or(2),
1093            llm_config: LLMConfig::default(),
1094        };
1095
1096        let mut tree = Self::new(document_id, config)?;
1097
1098        // Load nodes
1099        if let json::JsonValue::Object(nodes_obj) = &json_data["nodes"] {
1100            for (node_id_str, node_json) in nodes_obj.iter() {
1101                let node_id = NodeId::new(node_id_str.to_string());
1102
1103                let children: Vec<NodeId> = node_json["children"]
1104                    .members()
1105                    .filter_map(|v| v.as_str())
1106                    .map(|s| NodeId::new(s.to_string()))
1107                    .collect();
1108
1109                let parent = node_json["parent"]
1110                    .as_str()
1111                    .map(|s| NodeId::new(s.to_string()));
1112
1113                let chunk_ids: Vec<ChunkId> = node_json["chunk_ids"]
1114                    .members()
1115                    .filter_map(|v| v.as_str())
1116                    .map(|s| ChunkId::new(s.to_string()))
1117                    .collect();
1118
1119                let keywords: Vec<String> = node_json["keywords"]
1120                    .members()
1121                    .filter_map(|v| v.as_str())
1122                    .map(|s| s.to_string())
1123                    .collect();
1124
1125                let node = TreeNode {
1126                    id: node_id.clone(),
1127                    content: node_json["content"].as_str().unwrap_or("").to_string(),
1128                    summary: node_json["summary"].as_str().unwrap_or("").to_string(),
1129                    level: node_json["level"].as_usize().unwrap_or(0),
1130                    children,
1131                    parent,
1132                    chunk_ids,
1133                    keywords,
1134                    start_offset: node_json["start_offset"].as_usize().unwrap_or(0),
1135                    end_offset: node_json["end_offset"].as_usize().unwrap_or(0),
1136                };
1137
1138                tree.nodes.insert(node_id, node);
1139            }
1140        }
1141
1142        // Load root nodes
1143        tree.root_nodes = json_data["root_nodes"]
1144            .members()
1145            .filter_map(|v| v.as_str())
1146            .map(|s| NodeId::new(s.to_string()))
1147            .collect();
1148
1149        // Load levels
1150        if let json::JsonValue::Object(levels_obj) = &json_data["levels"] {
1151            for (level_str, level_json) in levels_obj.iter() {
1152                if let Ok(level) = level_str.parse::<usize>() {
1153                    let node_ids: Vec<NodeId> = level_json
1154                        .members()
1155                        .filter_map(|v| v.as_str())
1156                        .map(|s| NodeId::new(s.to_string()))
1157                        .collect();
1158                    tree.levels.insert(level, node_ids);
1159                }
1160            }
1161        }
1162
1163        Ok(tree)
1164    }
1165}
1166
/// A single match returned by `DocumentTree::query`.
///
/// Apart from `score`, every field is copied from the matching tree node.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// ID of the matching node
    pub node_id: NodeId,
    /// Relevance score for the query match. `query` keeps only results scoring above `0.1`
    /// and returns them sorted by descending score.
    pub score: f32,
    /// Tree level of the matching node
    pub level: usize,
    /// Summary text of the matching node
    pub summary: String,
    /// Keywords associated with the matching node
    pub keywords: Vec<String>,
    /// Chunk IDs covered by the matching node
    pub chunk_ids: Vec<ChunkId>,
}
1183
1184/// Statistics about the hierarchical tree
1185#[derive(Debug)]
1186pub struct TreeStatistics {
1187    /// Total number of nodes in the tree
1188    pub total_nodes: usize,
1189    /// Maximum depth level of the tree
1190    pub max_level: usize,
1191    /// Count of nodes at each level
1192    pub nodes_per_level: HashMap<usize, usize>,
1193    /// Number of root nodes in the tree
1194    pub root_count: usize,
1195    /// ID of the document this tree represents
1196    pub document_id: DocumentId,
1197}
1198
1199impl TreeStatistics {
1200    /// Print tree statistics
1201    pub fn print(&self) {
1202        println!("Hierarchical Tree Statistics:");
1203        println!("  Document ID: {}", self.document_id);
1204        println!("  Total nodes: {}", self.total_nodes);
1205        println!("  Max level: {}", self.max_level);
1206        println!("  Root nodes: {}", self.root_count);
1207        println!("  Nodes per level:");
1208
1209        let mut levels: Vec<_> = self.nodes_per_level.iter().collect();
1210        levels.sort_by_key(|(level, _)| *level);
1211
1212        for (level, count) in levels {
1213            println!("    Level {level}: {count} nodes");
1214        }
1215    }
1216}
1217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    #[test]
    fn test_tree_creation() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config);
        assert!(tree.is_ok());
    }

    #[test]
    fn test_extractive_summarization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let text = "This is the first sentence. This is a second sentence with more details. This is the final sentence.";
        let summary = tree.generate_extractive_summary(text).unwrap();

        assert!(!summary.is_empty());
        assert!(summary.len() <= tree.config.max_summary_length);
    }

    #[test]
    fn test_json_serialization() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let json = tree.to_json().unwrap();
        assert!(json.contains("test_doc"));

        let loaded_tree = DocumentTree::from_json(&json).unwrap();
        assert_eq!(loaded_tree.document_id.to_string(), "test_doc");
    }

    /// The scalar configuration fields must survive a `to_json` / `from_json` round trip.
    #[test]
    fn test_json_round_trip_preserves_config() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let loaded = DocumentTree::from_json(&tree.to_json().unwrap()).unwrap();
        assert_eq!(loaded.config.merge_size, tree.config.merge_size);
        assert_eq!(loaded.config.max_summary_length, tree.config.max_summary_length);
        assert_eq!(loaded.config.min_node_size, tree.config.min_node_size);
        assert_eq!(loaded.config.overlap_sentences, tree.config.overlap_sentences);
    }

    /// Malformed input and a missing or non-string `document_id` must be rejected.
    #[test]
    fn test_from_json_rejects_invalid_input() {
        assert!(DocumentTree::from_json("not json").is_err());
        assert!(DocumentTree::from_json(r#"{"document_id": 42}"#).is_err());
        assert!(DocumentTree::from_json("{}").is_err());
    }

    /// Accessors on a freshly created tree, and lookups of unknown IDs, return nothing.
    #[test]
    fn test_empty_tree_accessors() {
        let config = HierarchicalConfig::default();
        let doc_id = DocumentId::new("test_doc".to_string());
        let tree = DocumentTree::new(doc_id, config).unwrap();

        let stats = tree.get_statistics();
        assert_eq!(stats.total_nodes, 0);
        assert_eq!(stats.root_count, 0);
        assert!(tree.get_root_nodes().is_empty());
        assert!(tree.get_level_nodes(0).is_empty());

        let missing = NodeId::new("missing".to_string());
        assert!(tree.get_node(&missing).is_none());
        assert!(tree.get_ancestors(&missing).is_empty());
        assert!(tree.get_descendants(&missing).is_empty());
    }
}