// pocket_cli/utils/summarization.rs

1use anyhow::{Result, anyhow};
2use std::collections::HashMap;
3
4#[cfg(feature = "ml-summarization")]
5use std::sync::Arc;
6#[cfg(feature = "ml-summarization")]
7use rust_bert::bart::{BartForConditionalGeneration, BartConfig, BartModelResources};
8#[cfg(feature = "ml-summarization")]
9use rust_bert::resources::{RemoteResource, Resource, LocalResource};
10#[cfg(feature = "ml-summarization")]
11use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
12#[cfg(feature = "ml-summarization")]
13use rust_tokenizers::tokenizer::{BartTokenizer, Tokenizer, TruncationStrategy};
14#[cfg(feature = "ml-summarization")]
15use tch::{nn, Device, Tensor};
16#[cfg(feature = "ml-summarization")]
17use rust_bert::RustBertError;
18
// Singleton pattern for the summarization model to avoid reloading
//
// `None` until `initialize_summarization_model` populates it on first use;
// the Mutex serializes both initialization and inference access.
// NOTE(review): std::sync::OnceLock/LazyLock could replace lazy_static here.
#[cfg(feature = "ml-summarization")]
lazy_static::lazy_static! {
    static ref SUMMARIZATION_MODEL: std::sync::Mutex<Option<Box<SummarizationModel>>> = std::sync::Mutex::new(None);
}
24
/// Initialize the summarization model (if not already initialized)
///
/// Loads a DistilBART-CNN 6-6 model into the process-wide singleton so the
/// expensive download/deserialization happens at most once. Subsequent calls
/// are cheap no-ops.
///
/// # Errors
/// Returns an error if the model resources cannot be fetched or the model
/// fails to construct.
#[cfg(feature = "ml-summarization")]
pub fn initialize_summarization_model() -> Result<()> {
    let mut model_guard = SUMMARIZATION_MODEL.lock().unwrap();

    // Already initialized by an earlier call — nothing to do.
    if model_guard.is_some() {
        return Ok(());
    }

    // NOTE(review): rust_bert's RemoteResource::from_pretrained usually takes
    // a (name, url) tuple — confirm these bare file names actually resolve.
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-config.json",
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-vocab.json",
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-merges.txt",
    ));
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-model.safetensors",
    ));

    let model_resources = BartModelResources {
        config_resource,
        vocab_resource,
        merges_resource,
        model_resource,
    };

    // Generation parameters are set directly on the summarization config.
    // (A separate GenerateConfig was previously built here but never passed
    // anywhere, so settings like `do_sample: false` silently had no effect.)
    let summarization_config = SummarizationConfig {
        model_resource: model_resources,
        min_length: 10,   // Reasonable minimum
        max_length: 100,  // Short summary
        early_stopping: true,
        num_beams: 3,     // Low beam count to conserve memory
        do_sample: false, // deterministic generation
        no_repeat_ngram_size: 3,
        temperature: 1.0,
        top_k: 50,
        top_p: 0.95,
        repetition_penalty: 1.2,
        length_penalty: 1.0,
        device: Device::cuda_if_available(), // Use CUDA if available, otherwise CPU
        ..Default::default()
    };

    let model = SummarizationModel::new(summarization_config)?;
    *model_guard = Some(Box::new(model));

    Ok(())
}
84
85/// Summarize text content - returns short summary
86/// 
87/// Optimized for minimal memory usage
88pub fn summarize_text(text: &str) -> Result<String> {
89    // If text is very short, don't summarize
90    if text.split_whitespace().count() < 20 {
91        return Ok(text.to_string());
92    }
93    
94    #[cfg(feature = "ml-summarization")]
95    {
96        // Initialize model if needed
97        initialize_summarization_model()?;
98        
99        let guard = SUMMARIZATION_MODEL.lock().unwrap();
100        
101        if let Some(model) = &*guard {
102            // Truncate text if it's very long to conserve memory
103            let truncated_text = if text.len() > 4000 {
104                text.chars().take(4000).collect::<String>()
105            } else {
106                text.to_string()
107            };
108            
109            let input_texts = vec![&truncated_text];
110            let output = model.summarize(&input_texts)?;
111            
112            // If summarization failed or returned nothing, use a fallback
113            if output.is_empty() {
114                return fallback_summarize_text(text);
115            }
116            
117            return Ok(output[0].clone());
118        }
119    }
120    
121    // If ML summarization is not available or failed, use fallback
122    fallback_summarize_text(text)
123}
124
/// Alternative lightweight summarization method that uses a rule-based approach
/// when the ML model is unavailable or fails.
///
/// Simple extractive summarization: sentences are scored by the average
/// corpus-wide frequency of their significant words, and the highest-scoring
/// sentences are stitched back together in their original order.
fn fallback_summarize_text(text: &str) -> Result<String> {
    // Split into sentences on the common terminators.
    let sentences: Vec<&str> = text
        .split(|c| c == '.' || c == '!' || c == '?')
        .filter(|s| !s.trim().is_empty())
        .collect();

    // One or two sentences: nothing worth condensing.
    if sentences.len() <= 2 {
        return Ok(text.to_string());
    }

    // Count word frequencies to calculate sentence importance.
    let mut word_freqs: HashMap<String, u32> = HashMap::new();
    for sentence in &sentences {
        for word in sentence.split_whitespace() {
            let word = word.to_lowercase();
            if word.len() > 2 {
                // Skip very short words
                *word_freqs.entry(word).or_insert(0) += 1;
            }
        }
    }

    // Score each sentence by the mean frequency of its significant words.
    let mut sentence_scores: Vec<(usize, f64)> = sentences
        .iter()
        .enumerate()
        .map(|(i, &sentence)| {
            let words: Vec<String> = sentence
                .split_whitespace()
                .map(|w| w.to_lowercase())
                .filter(|w| w.len() > 2)
                .collect();

            let score = words
                .iter()
                .map(|word| word_freqs.get(word).copied().unwrap_or(0))
                .sum::<u32>() as f64
                / words.len().max(1) as f64;

            (i, score)
        })
        .collect();

    // Sort by score (highest first). Equal scores keep original order (stable sort).
    sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    // Take the top 3 sentences or 20% of the total, whichever is GREATER,
    // capped at the sentence count. (The previous expression
    // `len.max(3).min(ceil(len * 0.2))` collapsed to just the 20% figure
    // whenever len >= 3, often selecting fewer than 3 sentences.)
    let num_sentences = ((sentences.len() as f64 * 0.2).ceil() as usize)
        .max(3)
        .min(sentences.len());

    // Get indices of top sentences and restore their original document order.
    let mut top_indices: Vec<usize> = sentence_scores
        .iter()
        .take(num_sentences)
        .map(|(idx, _)| *idx)
        .collect();

    top_indices.sort_unstable();

    // Reassemble, trimming the whitespace left behind by sentence splitting
    // so the summary doesn't contain doubled spaces.
    let summary = top_indices
        .iter()
        .map(|&idx| sentences[idx].trim())
        .collect::<Vec<_>>()
        .join(". ");

    Ok(format!("{}.", summary))
}
189
/// Metadata describing a stored summary.
///
/// Derives the standard traits expected of a public data type so callers
/// can debug-print, copy, and compare instances.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SummaryMetadata {
    /// The summary text itself.
    pub summary: String,
    /// True when the summary was produced automatically (ML or fallback)
    /// rather than written by the user.
    pub is_auto_generated: bool,
}
194
195impl SummaryMetadata {
196    pub fn new(summary: String, is_auto_generated: bool) -> Self {
197        Self {
198            summary,
199            is_auto_generated,
200        }
201    }
202    
203    pub fn to_json(&self) -> String {
204        serde_json::json!({
205            "summary": self.summary,
206            "auto_generated": self.is_auto_generated,
207        }).to_string()
208    }
209    
210    pub fn from_json(json: &str) -> Result<Self> {
211        let parsed: serde_json::Value = serde_json::from_str(json)?;
212        
213        let summary = parsed["summary"].as_str()
214            .ok_or_else(|| anyhow!("Missing 'summary' field in summary metadata"))?
215            .to_string();
216            
217        let is_auto_generated = parsed["auto_generated"].as_bool()
218            .unwrap_or(true);
219            
220        Ok(Self {
221            summary,
222            is_auto_generated,
223        })
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    // The rule-based fallback should produce a non-empty summary that is
    // strictly shorter than a multi-sentence input.
    #[test]
    fn test_fallback_summarization() {
        let input = "This is the first sentence about programming. This is a second sentence about Rust language features. 
        This third sentence covers memory safety in systems programming. The fourth sentence discusses performance implications.
        This is the fifth sentence on concurrency. The sixth sentence is about pattern matching features.
        Seventh sentence looks at the module system. Eighth sentence considers error handling approaches.";

        let summary = fallback_summarize_text(input).unwrap();
        assert!(summary.len() < input.len());
        assert!(!summary.is_empty());
    }

    // Metadata should survive a JSON round trip unchanged.
    #[test]
    fn test_summary_metadata_serialization() {
        let original = SummaryMetadata::new("Test summary".to_string(), true);
        let restored = SummaryMetadata::from_json(&original.to_json()).unwrap();

        assert_eq!(original.is_auto_generated, restored.is_auto_generated);
        assert_eq!(original.summary, restored.summary);
    }
}