Skip to main content

graphrag_core/summarization/
mod.rs

1//! Summarization of chunks and communities, and aggregation of query results.
2//!
3//! Condenses text spans (and graph communities) into compact summaries, feeding the
4//! `QueryResult` aggregation used by generation.
5
6#[cfg(feature = "parallel-processing")]
7use crate::parallel::ParallelProcessor;
8use crate::{
9    core::{ChunkId, DocumentId, GraphRAGError, TextChunk},
10    text::TextProcessor,
11    Result,
12};
13use indexmap::IndexMap;
14use std::collections::{HashMap, VecDeque};
15use std::sync::Arc;
16
17/// Trait for LLM client to be used in summarization
18#[async_trait::async_trait]
19pub trait LLMClient: Send + Sync {
20    /// Generate a summary for the given text
21    async fn generate_summary(
22        &self,
23        text: &str,
24        prompt: &str,
25        max_tokens: usize,
26        temperature: f32,
27    ) -> Result<String>;
28
29    /// Generate summary in batch for multiple texts
30    async fn generate_summary_batch(
31        &self,
32        texts: &[(&str, &str)],
33        max_tokens: usize,
34        temperature: f32,
35    ) -> Result<Vec<String>> {
36        let mut results = Vec::new();
37        for (text, prompt) in texts {
38            let summary = self
39                .generate_summary(text, prompt, max_tokens, temperature)
40                .await?;
41            results.push(summary);
42        }
43        Ok(results)
44    }
45
46    /// Get model name
47    fn model_name(&self) -> &str;
48}
49
/// Identifier of a node in a [`DocumentTree`], wrapping an owned string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NodeId(pub String);

impl NodeId {
    /// Wraps `id` as a `NodeId`; no validation is performed.
    pub fn new(id: String) -> Self {
        NodeId(id)
    }
}

impl std::fmt::Display for NodeId {
    /// Writes the raw identifier string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl From<String> for NodeId {
    fn from(s: String) -> Self {
        NodeId::new(s)
    }
}

impl From<NodeId> for String {
    fn from(id: NodeId) -> Self {
        let NodeId(inner) = id;
        inner
    }
}
78
/// Configuration for hierarchical summarization.
///
/// Controls how a `DocumentTree` groups nodes into higher levels and how long the
/// generated summaries may be.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HierarchicalConfig {
    /// Number of consecutive sibling nodes grouped under one parent when building the
    /// next tree level.
    pub merge_size: usize,
    /// Maximum length for generated summaries, intended as a character count. Also the
    /// default `LevelConfig::max_length` for levels without an explicit override.
    pub max_summary_length: usize,
    /// Minimum size in characters for nodes to be considered valid.
    /// NOTE(review): not read by the tree-building code in this section — confirm where it
    /// is enforced.
    pub min_node_size: usize,
    /// Number of overlapping sentences between adjacent chunks.
    /// NOTE(review): not read by the tree-building code in this section — confirm usage.
    pub overlap_sentences: usize,
    /// LLM-based summarization configuration.
    pub llm_config: LLMConfig,
}
93
/// Configuration for LLM-based summarization.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LLMConfig {
    /// Whether to use the LLM for summarization (vs extractive). Only takes effect when
    /// an `LLMClient` is attached to the tree; when merging nodes, failures fall back to
    /// extractive summaries.
    pub enabled: bool,
    /// Model to use for summarization.
    /// NOTE(review): not forwarded to the `LLMClient` in this section; the client reports
    /// its own `model_name()`.
    pub model_name: String,
    /// Default sampling temperature (lower = more deterministic).
    pub temperature: f32,
    /// Maximum tokens for LLM generation.
    pub max_tokens: usize,
    /// How default prompts are chosen per level.
    pub strategy: LLMStrategy,
    /// Per-level overrides keyed by tree level (0 = leaves).
    pub level_configs: HashMap<usize, LevelConfig>,
}
110
/// Summarization strategy: decides which default (non-template) prompt a level gets.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum LLMStrategy {
    /// Same prompt for every level.
    Uniform,
    /// Prompt depends on the level: key-point extraction at level 0, a combined summary at
    /// levels 1-2, and a high-level abstract summary above that.
    Adaptive,
    /// Extractive-style prompt unless the level's `use_abstractive` is set; without a
    /// per-level override, levels 2 and up are treated as abstractive.
    Progressive,
}
121
/// Per-level summarization settings, looked up by tree level in `LLMConfig::level_configs`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LevelConfig {
    /// Maximum summary length for this level; LLM output is truncated to it.
    pub max_length: usize,
    /// Whether to use abstractive summarization at this level (selects the abstractive
    /// prompt under `LLMStrategy::Progressive`).
    pub use_abstractive: bool,
    /// Custom prompt template for this level, replacing the strategy's default prompt.
    /// Recognized placeholders: `{text}`, `{context}`, `{level}`, `{max_length}`.
    pub prompt_template: Option<String>,
    /// Temperature override for this level; `None` means use `LLMConfig::temperature`.
    pub temperature: Option<f32>,
}
134
135impl Default for HierarchicalConfig {
136    fn default() -> Self {
137        Self {
138            merge_size: 5,
139            max_summary_length: 200,
140            min_node_size: 50,
141            overlap_sentences: 2,
142            llm_config: LLMConfig::default(),
143        }
144    }
145}
146
147impl Default for LLMConfig {
148    fn default() -> Self {
149        Self {
150            enabled: false, // Disabled by default for backward compatibility
151            model_name: "llama3.1:8b".to_string(),
152            temperature: 0.3, // Lower temperature for more coherent summaries
153            max_tokens: 150,
154            strategy: LLMStrategy::Progressive,
155            level_configs: HashMap::new(),
156        }
157    }
158}
159
160impl Default for LevelConfig {
161    fn default() -> Self {
162        Self {
163            max_length: 200,
164            use_abstractive: false, // Start with extractive at lower levels
165            prompt_template: None,
166            temperature: None,
167        }
168    }
169}
170
/// A node in the hierarchical document tree.
///
/// Leaves (`level == 0`) are built one per `TextChunk`; nodes at higher levels are built
/// by merging consecutive groups of nodes from the level below.
#[derive(Debug, Clone)]
pub struct TreeNode {
    /// Unique identifier: `leaf_<chunk id>` for leaves, `level_<level>_<index>` for
    /// merged nodes.
    pub id: NodeId,
    /// Full text content: the chunk text for leaves, or the children's content joined by
    /// blank lines for merged nodes.
    pub content: String,
    /// Generated summary of the content.
    pub summary: String,
    /// Level in the tree hierarchy (0 = leaf, higher = more abstract).
    pub level: usize,
    /// IDs of child nodes (empty for leaves).
    pub children: Vec<NodeId>,
    /// ID of the parent node; `None` for root nodes.
    pub parent: Option<NodeId>,
    /// IDs of all text chunks covered by this node.
    pub chunk_ids: Vec<ChunkId>,
    /// Keywords extracted from the content; for merged nodes, the children's keywords
    /// sorted, deduplicated and capped at 10.
    pub keywords: Vec<String>,
    /// Start offset in the original document, copied from the chunk's `start_offset`
    /// (minimum over children for merged nodes). Presumably characters — TODO confirm units.
    pub start_offset: usize,
    /// End offset in the original document, copied from the chunk's `end_offset`
    /// (maximum over children for merged nodes).
    pub end_offset: usize,
}
195
/// Hierarchical document tree for multi-level summarization.
///
/// Built bottom-up by `build_from_chunks`: chunks become level-0 leaves, and each level
/// above is formed by merging groups of `config.merge_size` nodes until at most one node
/// remains.
pub struct DocumentTree {
    /// All nodes keyed by id, in insertion order (leaves first, then each merged level).
    nodes: IndexMap<NodeId, TreeNode>,
    /// Node(s) left after the last merge pass.
    root_nodes: Vec<NodeId>,
    /// Node ids per tree level (0 = leaves).
    levels: HashMap<usize, Vec<NodeId>>,
    /// Document this tree summarizes.
    document_id: DocumentId,
    /// Merge, summary-length and LLM settings.
    config: HierarchicalConfig,
    /// Sentence splitting and keyword extraction helper.
    text_processor: TextProcessor,
    /// Optional LLM used for abstractive summaries of merged nodes.
    llm_client: Option<Arc<dyn LLMClient>>,
}
206
207impl DocumentTree {
208    /// Create a new document tree
209    pub fn new(document_id: DocumentId, config: HierarchicalConfig) -> Result<Self> {
210        let text_processor = TextProcessor::new(1000, 100)?;
211
212        Ok(Self {
213            nodes: IndexMap::new(),
214            root_nodes: Vec::new(),
215            levels: HashMap::new(),
216            document_id,
217            config,
218            text_processor,
219            llm_client: None,
220        })
221    }
222
223    /// Create a new document tree with LLM client
224    pub fn with_llm_client(
225        document_id: DocumentId,
226        config: HierarchicalConfig,
227        llm_client: Arc<dyn LLMClient>,
228    ) -> Result<Self> {
229        let text_processor = TextProcessor::new(1000, 100)?;
230
231        Ok(Self {
232            nodes: IndexMap::new(),
233            root_nodes: Vec::new(),
234            levels: HashMap::new(),
235            document_id,
236            config,
237            text_processor,
238            llm_client: Some(llm_client),
239        })
240    }
241
242    /// Set LLM client for the tree
243    pub fn set_llm_client(&mut self, llm_client: Option<Arc<dyn LLMClient>>) {
244        self.llm_client = llm_client;
245    }
246
247    /// Create a new document tree with parallel processing support
248    #[cfg(feature = "parallel-processing")]
249    pub fn with_parallel_processing(
250        document_id: DocumentId,
251        config: HierarchicalConfig,
252        _parallel_processor: ParallelProcessor,
253    ) -> Result<Self> {
254        let text_processor = TextProcessor::new(1000, 100)?;
255
256        Ok(Self {
257            nodes: IndexMap::new(),
258            root_nodes: Vec::new(),
259            levels: HashMap::new(),
260            document_id,
261            config,
262            text_processor,
263            llm_client: None,
264        })
265    }
266
267    /// Create a new document tree with both parallel processing and LLM client
268    #[cfg(feature = "parallel-processing")]
269    pub fn with_parallel_and_llm(
270        document_id: DocumentId,
271        config: HierarchicalConfig,
272        _parallel_processor: ParallelProcessor,
273        llm_client: Arc<dyn LLMClient>,
274    ) -> Result<Self> {
275        let text_processor = TextProcessor::new(1000, 100)?;
276
277        Ok(Self {
278            nodes: IndexMap::new(),
279            root_nodes: Vec::new(),
280            levels: HashMap::new(),
281            document_id,
282            config,
283            text_processor,
284            llm_client: Some(llm_client),
285        })
286    }
287
288    /// Build the hierarchical tree from text chunks
289    pub async fn build_from_chunks(&mut self, chunks: Vec<TextChunk>) -> Result<()> {
290        if chunks.is_empty() {
291            return Ok(());
292        }
293
294        println!("Building hierarchical tree from {} chunks", chunks.len());
295
296        // Create leaf nodes from chunks
297        let leaf_nodes = self.create_leaf_nodes(chunks)?;
298
299        // Build the tree bottom-up
300        self.build_bottom_up(leaf_nodes).await?;
301
302        println!(
303            "Tree built with {} total nodes across {} levels",
304            self.nodes.len(),
305            self.levels.len()
306        );
307
308        Ok(())
309    }
310
311    /// Create leaf nodes from text chunks with parallel processing
312    fn create_leaf_nodes(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
313        if chunks.len() < 10 {
314            // Use sequential processing for small numbers of chunks
315            return self.create_leaf_nodes_sequential(chunks);
316        }
317
318        #[cfg(feature = "parallel-processing")]
319        {
320            use rayon::prelude::*;
321
322            // Parallel node creation with proper error handling
323            let node_results: std::result::Result<Vec<_>, crate::GraphRAGError> = chunks
324                .par_iter()
325                .map(|chunk| {
326                    let node_id = NodeId::new(format!("leaf_{}", chunk.id));
327
328                    // Create a temporary text processor for each thread to avoid borrowing issues
329                    let temp_processor =
330                        crate::text::TextProcessor::new(1000, 100).map_err(|e| {
331                            crate::GraphRAGError::Config {
332                                message: format!("Failed to create text processor: {e}"),
333                            }
334                        })?;
335
336                    // Extract keywords for the chunk
337                    let keywords = temp_processor.extract_keywords(&chunk.content, 5);
338
339                    // Generate summary using simplified approach suitable for parallel execution
340                    let summary =
341                        self.generate_parallel_summary(&chunk.content, &temp_processor)?;
342
343                    let node = TreeNode {
344                        id: node_id.clone(),
345                        content: chunk.content.clone(),
346                        summary,
347                        level: 0,
348                        children: Vec::new(),
349                        parent: None,
350                        chunk_ids: vec![chunk.id.clone()],
351                        keywords,
352                        start_offset: chunk.start_offset,
353                        end_offset: chunk.end_offset,
354                    };
355
356                    Ok((node_id, node))
357                })
358                .collect();
359
360            match node_results {
361                Ok(nodes) => {
362                    let mut leaf_node_ids = Vec::new();
363
364                    // Insert nodes sequentially to avoid concurrent modification
365                    for (node_id, node) in nodes {
366                        leaf_node_ids.push(node_id.clone());
367                        self.nodes.insert(node_id, node);
368                    }
369
370                    println!("Created {} leaf nodes in parallel", leaf_node_ids.len());
371
372                    // Update levels tracking
373                    self.levels.insert(0, leaf_node_ids.clone());
374
375                    Ok(leaf_node_ids)
376                },
377                Err(e) => {
378                    eprintln!("Error in parallel node creation: {e}");
379                    // Fall back to sequential processing
380                    self.create_leaf_nodes_sequential(chunks)
381                },
382            }
383        }
384
385        #[cfg(not(feature = "parallel-processing"))]
386        {
387            self.create_leaf_nodes_sequential(chunks)
388        }
389    }
390
391    /// Sequential leaf node creation (fallback)
392    fn create_leaf_nodes_sequential(&mut self, chunks: Vec<TextChunk>) -> Result<Vec<NodeId>> {
393        let mut leaf_node_ids = Vec::new();
394
395        for chunk in chunks {
396            let node_id = NodeId::new(format!("leaf_{}", chunk.id));
397
398            // Extract keywords for the chunk
399            let keywords = self.text_processor.extract_keywords(&chunk.content, 5);
400
401            let node = TreeNode {
402                id: node_id.clone(),
403                content: chunk.content.clone(),
404                summary: self.generate_extractive_summary(&chunk.content)?,
405                level: 0,
406                children: Vec::new(),
407                parent: None,
408                chunk_ids: vec![chunk.id],
409                keywords,
410                start_offset: chunk.start_offset,
411                end_offset: chunk.end_offset,
412            };
413
414            self.nodes.insert(node_id.clone(), node);
415            leaf_node_ids.push(node_id);
416        }
417
418        // Update levels tracking
419        self.levels.insert(0, leaf_node_ids.clone());
420
421        Ok(leaf_node_ids)
422    }
423
424    /// Generate summary suitable for parallel processing
425    /// Generate summary using LLM for the given text and level
426    pub async fn generate_llm_summary(
427        &self,
428        text: &str,
429        level: usize,
430        context: &str,
431    ) -> Result<String> {
432        let llm_client = self
433            .llm_client
434            .as_ref()
435            .ok_or_else(|| GraphRAGError::Config {
436                message: "LLM client not configured for summarization".to_string(),
437            })?;
438
439        // Get configuration for this level
440        let level_config = self.get_level_config(level);
441
442        // Create prompt based on strategy and level
443        let prompt = self.create_summary_prompt(text, level, context, &level_config)?;
444
445        // Generate summary
446        let summary = llm_client
447            .generate_summary(
448                text,
449                &prompt,
450                level_config.max_length,
451                level_config
452                    .temperature
453                    .unwrap_or(self.config.llm_config.temperature),
454            )
455            .await?;
456
457        // Ensure summary is within length limits
458        self.truncate_summary(&summary, level_config.max_length)
459    }
460
461    /// Generate summaries in batch for multiple texts
462    pub async fn generate_llm_summaries_batch(
463        &self,
464        texts: &[(&str, usize, &str)], // (text, level, context)
465    ) -> Result<Vec<String>> {
466        let llm_client = self
467            .llm_client
468            .as_ref()
469            .ok_or_else(|| GraphRAGError::Config {
470                message: "LLM client not configured for summarization".to_string(),
471            })?;
472
473        let mut prompts = Vec::new();
474        let mut configs = Vec::new();
475
476        for (text, level, context) in texts {
477            let level_config = self.get_level_config(*level);
478            let prompt = self.create_summary_prompt(text, *level, context, &level_config)?;
479            prompts.push(prompt);
480            configs.push(level_config);
481        }
482
483        // Generate summaries in batch
484        let text_refs: Vec<&str> = texts.iter().map(|(t, _, _)| *t).collect();
485        let prompt_refs: Vec<&str> = prompts.iter().map(|p| p.as_str()).collect();
486
487        let summaries = llm_client
488            .generate_summary_batch(
489                &text_refs
490                    .iter()
491                    .zip(prompt_refs.iter())
492                    .map(|(&t, &p)| (t, p))
493                    .collect::<Vec<_>>(),
494                self.config.llm_config.max_tokens,
495                self.config.llm_config.temperature,
496            )
497            .await?;
498
499        // Truncate summaries according to level configs
500        let mut results = Vec::new();
501        for (i, summary) in summaries.into_iter().enumerate() {
502            let truncated = self.truncate_summary(&summary, configs[i].max_length)?;
503            results.push(truncated);
504        }
505
506        Ok(results)
507    }
508
509    /// Create a summary prompt based on level and strategy
510    fn create_summary_prompt(
511        &self,
512        text: &str,
513        level: usize,
514        context: &str,
515        level_config: &LevelConfig,
516    ) -> Result<String> {
517        // Use custom template if provided
518        if let Some(template) = &level_config.prompt_template {
519            return Ok(template
520                .replace("{text}", text)
521                .replace("{context}", context)
522                .replace("{level}", &level.to_string())
523                .replace("{max_length}", &level_config.max_length.to_string()));
524        }
525
526        // Default prompts based on strategy and level
527        match self.config.llm_config.strategy {
528            LLMStrategy::Uniform => {
529                Ok(format!(
530                    "Create a concise summary of the following text. The summary should be approximately {} characters long.\n\nContext: {}\n\nText to summarize:\n{}\n\nSummary:",
531                    level_config.max_length, context, text
532                ))
533            }
534            LLMStrategy::Adaptive => {
535                if level == 0 {
536                    Ok(format!(
537                        "Extract the key information from this text segment. Keep it factual and under {} characters.\n\nContext: {}\n\nText:\n{}\n\nKey points:",
538                        level_config.max_length, context, text
539                    ))
540                } else if level <= 2 {
541                    Ok(format!(
542                        "Create a coherent summary that combines the key information from this text. Make it approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nSummary:",
543                        level_config.max_length, context, text
544                    ))
545                } else {
546                    Ok(format!(
547                        "Generate a high-level abstract summary of this content. Focus on the main themes and insights. Limit to approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstract summary:",
548                        level_config.max_length, context, text
549                    ))
550                }
551            }
552            LLMStrategy::Progressive => {
553                if level_config.use_abstractive {
554                    Ok(format!(
555                        "Generate an abstractive summary that synthesizes the key concepts and relationships in this text. The summary should be approximately {} characters.\n\nContext: {}\n\nText:\n{}\n\nAbstractive summary:",
556                        level_config.max_length, context, text
557                    ))
558                } else {
559                    Ok(format!(
560                        "Extract and organize the most important sentences from this text to create a coherent summary. Keep it under {} characters.\n\nContext: {}\n\nText:\n{}\n\nExtractive summary:",
561                        level_config.max_length, context, text
562                    ))
563                }
564            }
565        }
566    }
567
568    /// Get configuration for a specific level
569    fn get_level_config(&self, level: usize) -> LevelConfig {
570        self.config
571            .llm_config
572            .level_configs
573            .get(&level)
574            .cloned()
575            .unwrap_or({
576                // Default configuration based on level
577                LevelConfig {
578                    max_length: self.config.max_summary_length,
579                    use_abstractive: match self.config.llm_config.strategy {
580                        LLMStrategy::Progressive => level >= 2,
581                        LLMStrategy::Adaptive => level >= 3,
582                        LLMStrategy::Uniform => level > 0,
583                    },
584                    prompt_template: None,
585                    temperature: None,
586                }
587            })
588    }
589
590    /// Truncate summary to fit within length limits
591    fn truncate_summary(&self, summary: &str, max_length: usize) -> Result<String> {
592        if summary.len() <= max_length {
593            return Ok(summary.to_string());
594        }
595
596        // Try to truncate at sentence boundaries
597        let sentences: Vec<&str> = summary
598            .split('.')
599            .filter(|s| !s.trim().is_empty())
600            .collect();
601
602        let mut result = String::new();
603        for sentence in sentences {
604            if result.len() + sentence.len() < max_length - 3 {
605                if !result.is_empty() {
606                    result.push('.');
607                }
608                result.push_str(sentence.trim());
609            } else {
610                break;
611            }
612        }
613
614        if result.is_empty() {
615            // Fallback: truncate characters
616            result = summary.chars().take(max_length - 3).collect();
617            result.push_str("...");
618        } else {
619            result.push('.');
620        }
621
622        Ok(result)
623    }
624
625    #[allow(dead_code)]
626    fn generate_parallel_summary(
627        &self,
628        text: &str,
629        processor: &crate::text::TextProcessor,
630    ) -> Result<String> {
631        let sentences = processor.extract_sentences(text);
632
633        if sentences.is_empty() {
634            return Ok(String::new());
635        }
636
637        if sentences.len() == 1 {
638            return Ok(sentences[0].clone());
639        }
640
641        // Simplified scoring for parallel execution
642        let mut best_sentence = &sentences[0];
643        let mut best_score = 0.0;
644
645        for sentence in &sentences {
646            let words: Vec<&str> = sentence.split_whitespace().collect();
647
648            // Simple scoring based on length and word density
649            let length_score = if words.len() < 5 {
650                0.1
651            } else if words.len() > 30 {
652                0.3
653            } else {
654                1.0
655            };
656
657            let word_score = words.len() as f32 * 0.1;
658            let score = length_score + word_score;
659
660            if score > best_score {
661                best_score = score;
662                best_sentence = sentence;
663            }
664        }
665
666        // Truncate if necessary
667        if best_sentence.len() > self.config.max_summary_length {
668            Ok(best_sentence
669                .chars()
670                .take(self.config.max_summary_length - 3)
671                .collect::<String>()
672                + "...")
673        } else {
674            Ok(best_sentence.clone())
675        }
676    }
677
678    /// Build the tree bottom-up by merging nodes at each level
679    async fn build_bottom_up(&mut self, leaf_nodes: Vec<NodeId>) -> Result<()> {
680        let mut current_level_nodes = leaf_nodes;
681        let mut current_level = 0;
682
683        while current_level_nodes.len() > 1 {
684            let next_level_nodes = self
685                .merge_level(&current_level_nodes, current_level + 1)
686                .await?;
687
688            current_level_nodes = next_level_nodes;
689            current_level += 1;
690        }
691
692        // Set root nodes
693        self.root_nodes = current_level_nodes;
694
695        Ok(())
696    }
697
698    /// Merge nodes at a level to create the next level up
699    async fn merge_level(
700        &mut self,
701        level_nodes: &[NodeId],
702        new_level: usize,
703    ) -> Result<Vec<NodeId>> {
704        let mut new_level_nodes = Vec::new();
705        // Group nodes by merge_size
706        for (node_counter, chunk) in level_nodes.chunks(self.config.merge_size).enumerate() {
707            let merged_node_id = NodeId::new(format!("level_{new_level}_{node_counter}"));
708            let merged_node = self
709                .merge_nodes(chunk, merged_node_id.clone(), new_level)
710                .await?;
711
712            // Update parent references for children
713            for child_id in chunk {
714                if let Some(child_node) = self.nodes.get_mut(child_id) {
715                    child_node.parent = Some(merged_node_id.clone());
716                }
717            }
718
719            self.nodes.insert(merged_node_id.clone(), merged_node);
720            new_level_nodes.push(merged_node_id);
721        }
722
723        // Update levels tracking
724        self.levels.insert(new_level, new_level_nodes.clone());
725
726        Ok(new_level_nodes)
727    }
728
729    /// Merge multiple nodes into a single parent node
730    async fn merge_nodes(
731        &self,
732        node_ids: &[NodeId],
733        merged_id: NodeId,
734        level: usize,
735    ) -> Result<TreeNode> {
736        let mut combined_content = String::new();
737        let mut all_chunk_ids = Vec::new();
738        let mut all_keywords = Vec::new();
739        let mut min_offset = usize::MAX;
740        let mut max_offset = 0;
741
742        for node_id in node_ids {
743            if let Some(node) = self.nodes.get(node_id) {
744                if !combined_content.is_empty() {
745                    combined_content.push_str("\n\n");
746                }
747                combined_content.push_str(&node.content);
748                all_chunk_ids.extend(node.chunk_ids.clone());
749                all_keywords.extend(node.keywords.clone());
750                min_offset = min_offset.min(node.start_offset);
751                max_offset = max_offset.max(node.end_offset);
752            }
753        }
754
755        // Deduplicate and limit keywords
756        all_keywords.sort();
757        all_keywords.dedup();
758        all_keywords.truncate(10);
759
760        // Generate summary for the merged content
761        let summary = if self.config.llm_config.enabled {
762            // Use LLM-based summarization if enabled and available
763            if self.llm_client.is_some() {
764                // Create context for this merge operation
765                let context = format!(
766                    "Merging {} nodes at level {}. This represents a higher-level abstraction of the document content.",
767                    node_ids.len(),
768                    level
769                );
770
771                // Try LLM-based summarization, fall back to extractive if it fails
772                match self
773                    .generate_llm_summary(&combined_content, level, &context)
774                    .await
775                {
776                    Ok(llm_summary) => {
777                        println!(
778                            "✅ Generated LLM-based summary for level {} ({} chars)",
779                            level,
780                            llm_summary.len()
781                        );
782                        llm_summary
783                    },
784                    Err(e) => {
785                        eprintln!("⚠️ LLM summarization failed for level {}: {}, falling back to extractive", level, e);
786                        self.generate_extractive_summary(&combined_content)?
787                    },
788                }
789            } else {
790                self.generate_extractive_summary(&combined_content)?
791            }
792        } else {
793            self.generate_extractive_summary(&combined_content)?
794        };
795
796        Ok(TreeNode {
797            id: merged_id,
798            content: combined_content,
799            summary,
800            level,
801            children: node_ids.to_vec(),
802            parent: None,
803            chunk_ids: all_chunk_ids,
804            keywords: all_keywords,
805            start_offset: min_offset,
806            end_offset: max_offset,
807        })
808    }
809
810    /// Generate extractive summary using simple sentence ranking
811    fn generate_extractive_summary(&self, text: &str) -> Result<String> {
812        let sentences = self.text_processor.extract_sentences(text);
813
814        if sentences.is_empty() {
815            return Ok(String::new());
816        }
817
818        if sentences.len() == 1 {
819            return Ok(sentences[0].clone());
820        }
821
822        // Score sentences based on length and keyword density
823        let mut sentence_scores: Vec<(usize, f32)> = sentences
824            .iter()
825            .enumerate()
826            .map(|(i, sentence)| {
827                let score = self.score_sentence(sentence, &sentences);
828                (i, score)
829            })
830            .collect();
831
832        // Sort by score descending
833        sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
834
835        // Select top sentences up to max_summary_length
836        let mut summary = String::new();
837        let mut selected_indices = Vec::new();
838
839        for (sentence_idx, _score) in sentence_scores {
840            let sentence = &sentences[sentence_idx];
841            if summary.len() + sentence.len() <= self.config.max_summary_length {
842                selected_indices.push(sentence_idx);
843                if !summary.is_empty() {
844                    summary.push(' ');
845                }
846                summary.push_str(sentence);
847            }
848        }
849
850        // If no sentences fit, take the first one truncated
851        if summary.is_empty() && !sentences.is_empty() {
852            let first_sentence = &sentences[0];
853            if first_sentence.len() <= self.config.max_summary_length {
854                summary = first_sentence.clone();
855            } else {
856                summary = first_sentence
857                    .chars()
858                    .take(self.config.max_summary_length - 3)
859                    .collect::<String>()
860                    + "...";
861            }
862        }
863
864        Ok(summary)
865    }
866
867    /// Score a sentence for extractive summarization
868    fn score_sentence(&self, sentence: &str, all_sentences: &[String]) -> f32 {
869        let words: Vec<&str> = sentence.split_whitespace().collect();
870
871        // Base score from length (prefer medium-length sentences)
872        let length_score = if words.len() < 5 {
873            0.1
874        } else if words.len() > 30 {
875            0.3
876        } else {
877            1.0
878        };
879
880        // Position score (prefer sentences from beginning and end)
881        let position_score = 0.5; // Simplified for now
882
883        // Word frequency score
884        let mut word_freq_score = 0.0;
885        let total_words: Vec<&str> = all_sentences
886            .iter()
887            .flat_map(|s| s.split_whitespace())
888            .collect();
889
890        for word in &words {
891            let word_lower = word.to_lowercase();
892            if word_lower.len() > 3 && !self.is_stop_word(&word_lower) {
893                let freq = total_words
894                    .iter()
895                    .filter(|&&w| w.to_lowercase() == word_lower)
896                    .count();
897                if freq > 1 {
898                    word_freq_score += freq as f32 / total_words.len() as f32;
899                }
900            }
901        }
902
903        length_score * 0.4 + position_score * 0.2 + word_freq_score * 0.4
904    }
905
906    /// Simple stop word detection (English)
907    fn is_stop_word(&self, word: &str) -> bool {
908        const STOP_WORDS: &[&str] = &[
909            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
910            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
911            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
912            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
913            "go", "me",
914        ];
915        STOP_WORDS.contains(&word)
916    }
917
918    /// Query the tree for relevant nodes at different levels
919    pub fn query(&self, query: &str, max_results: usize) -> Result<Vec<QueryResult>> {
920        let query_keywords = self.text_processor.extract_keywords(query, 5);
921        let mut results = Vec::new();
922
923        // Search all nodes for keyword matches
924        for (node_id, node) in &self.nodes {
925            let score = self.calculate_relevance_score(node, &query_keywords, query);
926
927            if score > 0.1 {
928                results.push(QueryResult {
929                    node_id: node_id.clone(),
930                    score,
931                    level: node.level,
932                    summary: node.summary.clone(),
933                    keywords: node.keywords.clone(),
934                    chunk_ids: node.chunk_ids.clone(),
935                });
936            }
937        }
938
939        // Sort by score and return top results
940        results.sort_by(|a, b| {
941            b.score
942                .partial_cmp(&a.score)
943                .unwrap_or(std::cmp::Ordering::Equal)
944        });
945        results.truncate(max_results);
946
947        Ok(results)
948    }
949
950    /// Calculate relevance score for a node given query keywords
951    fn calculate_relevance_score(
952        &self,
953        node: &TreeNode,
954        query_keywords: &[String],
955        query: &str,
956    ) -> f32 {
957        let mut score = 0.0;
958
959        // Keyword overlap score
960        let node_text = format!("{} {}", node.summary, node.keywords.join(" ")).to_lowercase();
961        for keyword in query_keywords {
962            if node_text.contains(&keyword.to_lowercase()) {
963                score += 1.0;
964            }
965        }
966
967        // Direct text similarity (simple word overlap)
968        let query_words: Vec<&str> = query.split_whitespace().collect();
969        let node_words: Vec<&str> = node_text.split_whitespace().collect();
970
971        let mut overlap_count = 0;
972        for query_word in &query_words {
973            if node_words.contains(&query_word.to_lowercase().as_str()) {
974                overlap_count += 1;
975            }
976        }
977
978        if !query_words.is_empty() {
979            score += (overlap_count as f32 / query_words.len() as f32) * 2.0;
980        }
981
982        // Level bonus (prefer higher levels for overview, lower levels for details)
983        let level_score = 1.0 / (node.level + 1) as f32;
984        score += level_score * 0.5;
985
986        score
987    }
988
989    /// Get ancestors of a node (path to root)
990    pub fn get_ancestors(&self, node_id: &NodeId) -> Vec<&TreeNode> {
991        let mut ancestors = Vec::new();
992        let mut current_id = Some(node_id.clone());
993
994        while let Some(id) = current_id {
995            if let Some(node) = self.nodes.get(&id) {
996                ancestors.push(node);
997                current_id = node.parent.clone();
998            } else {
999                break;
1000            }
1001        }
1002
1003        ancestors
1004    }
1005
1006    /// Get descendants of a node (all children recursively)
1007    pub fn get_descendants(&self, node_id: &NodeId) -> Vec<&TreeNode> {
1008        let mut descendants = Vec::new();
1009        let mut queue = VecDeque::new();
1010
1011        if let Some(node) = self.nodes.get(node_id) {
1012            queue.extend(node.children.iter());
1013        }
1014
1015        while let Some(child_id) = queue.pop_front() {
1016            if let Some(child_node) = self.nodes.get(child_id) {
1017                descendants.push(child_node);
1018                queue.extend(child_node.children.iter());
1019            }
1020        }
1021
1022        descendants
1023    }
1024
1025    /// Get a node by ID
1026    pub fn get_node(&self, node_id: &NodeId) -> Option<&TreeNode> {
1027        self.nodes.get(node_id)
1028    }
1029
1030    /// Get all nodes at a specific level
1031    pub fn get_level_nodes(&self, level: usize) -> Vec<&TreeNode> {
1032        if let Some(node_ids) = self.levels.get(&level) {
1033            node_ids
1034                .iter()
1035                .filter_map(|id| self.nodes.get(id))
1036                .collect()
1037        } else {
1038            Vec::new()
1039        }
1040    }
1041
1042    /// Get the root nodes of the tree
1043    pub fn get_root_nodes(&self) -> Vec<&TreeNode> {
1044        self.root_nodes
1045            .iter()
1046            .filter_map(|id| self.nodes.get(id))
1047            .collect()
1048    }
1049
1050    /// Get the document ID
1051    pub fn document_id(&self) -> &DocumentId {
1052        &self.document_id
1053    }
1054
1055    /// Get tree statistics
1056    pub fn get_statistics(&self) -> TreeStatistics {
1057        let max_level = self.levels.keys().max().copied().unwrap_or(0);
1058        let total_nodes = self.nodes.len();
1059        let nodes_per_level: HashMap<usize, usize> = self
1060            .levels
1061            .iter()
1062            .map(|(level, nodes)| (*level, nodes.len()))
1063            .collect();
1064
1065        TreeStatistics {
1066            total_nodes,
1067            max_level,
1068            nodes_per_level,
1069            root_count: self.root_nodes.len(),
1070            document_id: self.document_id.clone(),
1071        }
1072    }
1073
1074    /// Serialize the tree to JSON format
1075    pub fn to_json(&self) -> Result<String> {
1076        use json::JsonValue;
1077
1078        let mut tree_json = json::object! {
1079            "document_id": self.document_id.to_string(),
1080            "config": {
1081                "merge_size": self.config.merge_size,
1082                "max_summary_length": self.config.max_summary_length,
1083                "min_node_size": self.config.min_node_size,
1084                "overlap_sentences": self.config.overlap_sentences
1085            },
1086            "nodes": {},
1087            "root_nodes": [],
1088            "levels": {}
1089        };
1090
1091        // Serialize nodes
1092        for (node_id, node) in &self.nodes {
1093            let node_json = json::object! {
1094                "id": node_id.to_string(),
1095                "content": node.content.clone(),
1096                "summary": node.summary.clone(),
1097                "level": node.level,
1098                "children": node.children.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
1099                "parent": node.parent.as_ref().map(|id| id.to_string()),
1100                "chunk_ids": node.chunk_ids.iter().map(|id| id.to_string()).collect::<Vec<_>>(),
1101                "keywords": node.keywords.clone(),
1102                "start_offset": node.start_offset,
1103                "end_offset": node.end_offset
1104            };
1105            tree_json["nodes"][node_id.to_string()] = node_json;
1106        }
1107
1108        // Serialize root nodes
1109        tree_json["root_nodes"] = self
1110            .root_nodes
1111            .iter()
1112            .map(|id| JsonValue::String(id.to_string()))
1113            .collect::<Vec<_>>()
1114            .into();
1115
1116        // Serialize levels
1117        for (level, node_ids) in &self.levels {
1118            tree_json["levels"][level.to_string()] = node_ids
1119                .iter()
1120                .map(|id| JsonValue::String(id.to_string()))
1121                .collect::<Vec<_>>()
1122                .into();
1123        }
1124
1125        Ok(tree_json.dump())
1126    }
1127
1128    /// Load tree from JSON format
1129    pub fn from_json(json_str: &str) -> Result<Self> {
1130        let json_data = json::parse(json_str).map_err(crate::GraphRAGError::Json)?;
1131
1132        let document_id = DocumentId::new(
1133            json_data["document_id"]
1134                .as_str()
1135                .ok_or_else(|| {
1136                    crate::GraphRAGError::Json(json::Error::WrongType(
1137                        "document_id must be string".to_string(),
1138                    ))
1139                })?
1140                .to_string(),
1141        );
1142
1143        let config_json = &json_data["config"];
1144        let config = HierarchicalConfig {
1145            merge_size: config_json["merge_size"].as_usize().unwrap_or(5),
1146            max_summary_length: config_json["max_summary_length"].as_usize().unwrap_or(200),
1147            min_node_size: config_json["min_node_size"].as_usize().unwrap_or(50),
1148            overlap_sentences: config_json["overlap_sentences"].as_usize().unwrap_or(2),
1149            llm_config: LLMConfig::default(),
1150        };
1151
1152        let mut tree = Self::new(document_id, config)?;
1153
1154        // Load nodes
1155        if let json::JsonValue::Object(nodes_obj) = &json_data["nodes"] {
1156            for (node_id_str, node_json) in nodes_obj.iter() {
1157                let node_id = NodeId::new(node_id_str.to_string());
1158
1159                let children: Vec<NodeId> = node_json["children"]
1160                    .members()
1161                    .filter_map(|v| v.as_str())
1162                    .map(|s| NodeId::new(s.to_string()))
1163                    .collect();
1164
1165                let parent = node_json["parent"]
1166                    .as_str()
1167                    .map(|s| NodeId::new(s.to_string()));
1168
1169                let chunk_ids: Vec<ChunkId> = node_json["chunk_ids"]
1170                    .members()
1171                    .filter_map(|v| v.as_str())
1172                    .map(|s| ChunkId::new(s.to_string()))
1173                    .collect();
1174
1175                let keywords: Vec<String> = node_json["keywords"]
1176                    .members()
1177                    .filter_map(|v| v.as_str())
1178                    .map(|s| s.to_string())
1179                    .collect();
1180
1181                let node = TreeNode {
1182                    id: node_id.clone(),
1183                    content: node_json["content"].as_str().unwrap_or("").to_string(),
1184                    summary: node_json["summary"].as_str().unwrap_or("").to_string(),
1185                    level: node_json["level"].as_usize().unwrap_or(0),
1186                    children,
1187                    parent,
1188                    chunk_ids,
1189                    keywords,
1190                    start_offset: node_json["start_offset"].as_usize().unwrap_or(0),
1191                    end_offset: node_json["end_offset"].as_usize().unwrap_or(0),
1192                };
1193
1194                tree.nodes.insert(node_id, node);
1195            }
1196        }
1197
1198        // Load root nodes
1199        tree.root_nodes = json_data["root_nodes"]
1200            .members()
1201            .filter_map(|v| v.as_str())
1202            .map(|s| NodeId::new(s.to_string()))
1203            .collect();
1204
1205        // Load levels
1206        if let json::JsonValue::Object(levels_obj) = &json_data["levels"] {
1207            for (level_str, level_json) in levels_obj.iter() {
1208                if let Ok(level) = level_str.parse::<usize>() {
1209                    let node_ids: Vec<NodeId> = level_json
1210                        .members()
1211                        .filter_map(|v| v.as_str())
1212                        .map(|s| NodeId::new(s.to_string()))
1213                        .collect();
1214                    tree.levels.insert(level, node_ids);
1215                }
1216            }
1217        }
1218
1219        Ok(tree)
1220    }
1221}
1222
/// One match returned by `DocumentTree::query`.
///
/// Holds owned copies of the matching node's summary, keywords and chunk IDs, so callers
/// can use the result without borrowing the tree.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// ID of the matching tree node
    pub node_id: NodeId,
    /// Relevance score computed for the query; `query` sorts results by it, highest first
    pub score: f32,
    /// Tree level of the matching node
    pub level: usize,
    /// Summary text copied from the matching node
    pub summary: String,
    /// Keywords copied from the matching node
    pub keywords: Vec<String>,
    /// Chunk IDs copied from the matching node
    pub chunk_ids: Vec<ChunkId>,
}
1239
1240/// Statistics about the hierarchical tree
1241#[derive(Debug)]
1242pub struct TreeStatistics {
1243    /// Total number of nodes in the tree
1244    pub total_nodes: usize,
1245    /// Maximum depth level of the tree
1246    pub max_level: usize,
1247    /// Count of nodes at each level
1248    pub nodes_per_level: HashMap<usize, usize>,
1249    /// Number of root nodes in the tree
1250    pub root_count: usize,
1251    /// ID of the document this tree represents
1252    pub document_id: DocumentId,
1253}
1254
1255impl TreeStatistics {
1256    /// Print tree statistics
1257    pub fn print(&self) {
1258        println!("Hierarchical Tree Statistics:");
1259        println!("  Document ID: {}", self.document_id);
1260        println!("  Total nodes: {}", self.total_nodes);
1261        println!("  Max level: {}", self.max_level);
1262        println!("  Root nodes: {}", self.root_count);
1263        println!("  Nodes per level:");
1264
1265        let mut levels: Vec<_> = self.nodes_per_level.iter().collect();
1266        levels.sort_by_key(|(level, _)| *level);
1267
1268        for (level, count) in levels {
1269            println!("    Level {level}: {count} nodes");
1270        }
1271    }
1272}
1273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::DocumentId;

    /// Builds an empty tree for document "test_doc" using the default configuration.
    fn empty_test_tree() -> DocumentTree {
        let doc_id = DocumentId::new("test_doc".to_string());
        DocumentTree::new(doc_id, HierarchicalConfig::default()).unwrap()
    }

    #[test]
    fn test_extractive_summarization() {
        const TEXT: &str = "This is the first sentence. This is a second sentence with more details. This is the final sentence.";

        let tree = empty_test_tree();
        let summary = tree.generate_extractive_summary(TEXT).unwrap();

        assert!(!summary.is_empty());
        assert!(summary.len() <= tree.config.max_summary_length);
    }

    #[test]
    fn test_json_serialization() {
        let serialized = empty_test_tree().to_json().unwrap();
        assert!(serialized.contains("test_doc"));

        let restored = DocumentTree::from_json(&serialized).unwrap();
        assert_eq!(restored.document_id.to_string(), "test_doc");
    }
}