pocket_cli/utils/
summarization.rs1use anyhow::{Result, anyhow};
2use std::collections::HashMap;
3
4#[cfg(feature = "ml-summarization")]
5use std::sync::Arc;
6#[cfg(feature = "ml-summarization")]
7use rust_bert::bart::{BartForConditionalGeneration, BartConfig, BartModelResources};
8#[cfg(feature = "ml-summarization")]
9use rust_bert::resources::{RemoteResource, Resource, LocalResource};
10#[cfg(feature = "ml-summarization")]
11use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
12#[cfg(feature = "ml-summarization")]
13use rust_tokenizers::tokenizer::{BartTokenizer, Tokenizer, TruncationStrategy};
14#[cfg(feature = "ml-summarization")]
15use tch::{nn, Device, Tensor};
16#[cfg(feature = "ml-summarization")]
17use rust_bert::RustBertError;
18
/// Process-wide cache for the ML summarization model.
///
/// Holds `None` until `initialize_summarization_model` stores a loaded model.
/// `std::sync::Mutex::new` is a `const fn`, so a plain `static` is enough and
/// no lazy-initialization macro is required.
#[cfg(feature = "ml-summarization")]
static SUMMARIZATION_MODEL: std::sync::Mutex<Option<Box<SummarizationModel>>> =
    std::sync::Mutex::new(None);
24
/// Loads the `distilbart-cnn-6-6` BART summarization model into the
/// process-wide cache if it has not been loaded yet.
///
/// Calling this again after a successful load is a cheap no-op. The model
/// files are `RemoteResource`s, so fetching (and, presumably, on-disk caching)
/// is handled by rust-bert — TODO confirm caching behaviour for offline use.
/// The model runs on CUDA when available, otherwise on the CPU.
///
/// # Errors
///
/// Returns an error if the model cannot be constructed (e.g. its resources
/// cannot be obtained).
#[cfg(feature = "ml-summarization")]
pub fn initialize_summarization_model() -> Result<()> {
    // The lock is held for the whole load so concurrent callers cannot build
    // the model twice. A poisoned lock only means another thread panicked
    // while holding it; the `Option` inside is still either a complete model
    // or `None`, so recovering the guard is safe and avoids a permanent panic.
    let mut model_guard = SUMMARIZATION_MODEL
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    if model_guard.is_some() {
        return Ok(());
    }

    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-config.json",
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-vocab.json",
    ));
    let merges_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-merges.txt",
    ));
    let model_resource = Resource::Remote(RemoteResource::from_pretrained(
        "distilbart-cnn-6-6-model.safetensors",
    ));

    let model_resources = BartModelResources {
        config_resource,
        vocab_resource,
        merges_resource,
        model_resource,
    };

    // Generation settings not listed here (sampling, penalties, n-gram
    // blocking) come from `SummarizationConfig::default()`.
    let summarization_config = SummarizationConfig {
        model_resource: model_resources,
        min_length: 10,
        max_length: 100,
        early_stopping: true,
        num_beams: 3,
        device: Device::cuda_if_available(),
        ..Default::default()
    };

    let model = SummarizationModel::new(summarization_config)?;
    *model_guard = Some(Box::new(model));

    Ok(())
}
84
85pub fn summarize_text(text: &str) -> Result<String> {
89 if text.split_whitespace().count() < 20 {
91 return Ok(text.to_string());
92 }
93
94 #[cfg(feature = "ml-summarization")]
95 {
96 initialize_summarization_model()?;
98
99 let guard = SUMMARIZATION_MODEL.lock().unwrap();
100
101 if let Some(model) = &*guard {
102 let truncated_text = if text.len() > 4000 {
104 text.chars().take(4000).collect::<String>()
105 } else {
106 text.to_string()
107 };
108
109 let input_texts = vec![&truncated_text];
110 let output = model.summarize(&input_texts)?;
111
112 if output.is_empty() {
114 return fallback_summarize_text(text);
115 }
116
117 return Ok(output[0].clone());
118 }
119 }
120
121 fallback_summarize_text(text)
123}
124
/// Extractive, word-frequency based summarizer used when no ML model is used.
///
/// The text is split into sentences on `.`, `!` and `?`, and each sentence is
/// trimmed. A sentence's score is the average corpus frequency of its
/// significant words: words longer than two characters, compared
/// case-insensitively, with surrounding punctuation removed.
///
/// About 20% of the sentences are kept (rounded up), but never fewer than
/// three, or all of them if there are fewer. The highest-scoring sentences are
/// kept; on ties the earlier sentence wins. They are emitted in their original
/// order, joined with `". "` and terminated with `.`, so original `!`/`?`
/// terminators become periods. Texts with two or fewer sentences are returned
/// unchanged.
fn fallback_summarize_text(text: &str) -> Result<String> {
    // Lower-cased words of `sentence`, stripped of surrounding punctuation,
    // keeping only those longer than two characters.
    fn significant_words<'a>(sentence: &'a str) -> impl Iterator<Item = String> + 'a {
        sentence
            .split_whitespace()
            .map(|w| w.trim_matches(|c: char| !c.is_alphanumeric()).to_lowercase())
            .filter(|w| w.chars().count() > 2)
    }

    // Trim each fragment so rejoined sentences don't inherit the spaces or
    // newlines that followed the previous terminator.
    let sentences: Vec<&str> = text
        .split(|c: char| matches!(c, '.' | '!' | '?'))
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .collect();

    if sentences.len() <= 2 {
        return Ok(text.to_string());
    }

    let mut word_freqs: HashMap<String, u32> = HashMap::new();
    for word in sentences.iter().flat_map(|&s| significant_words(s)) {
        *word_freqs.entry(word).or_insert(0) += 1;
    }

    let mut sentence_scores: Vec<(usize, f64)> = sentences
        .iter()
        .enumerate()
        .map(|(i, &sentence)| {
            let (freq_sum, word_count) = significant_words(sentence).fold(
                (0u32, 0usize),
                |(sum, count), word| {
                    (sum + word_freqs.get(&word).copied().unwrap_or(0), count + 1)
                },
            );
            (i, f64::from(freq_sum) / word_count.max(1) as f64)
        })
        .collect();

    // Highest score first; `sort_by` is stable, so earlier sentences win ties.
    sentence_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    // ~20% of the sentences, but at least three (capped at the total).
    let sentence_count = sentences.len();
    let num_sentences = ((sentence_count as f64 * 0.2).ceil() as usize)
        .max(3)
        .min(sentence_count);

    let mut top_indices: Vec<usize> = sentence_scores
        .iter()
        .take(num_sentences)
        .map(|&(idx, _)| idx)
        .collect();
    // Restore document order so the summary reads naturally.
    top_indices.sort_unstable();

    let summary = top_indices
        .iter()
        .map(|&idx| sentences[idx])
        .collect::<Vec<_>>()
        .join(". ");

    Ok(format!("{}.", summary))
}
189
/// A summary string plus a flag recording whether it was machine-generated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SummaryMetadata {
    /// The summary text.
    pub summary: String,
    /// `true` when the summary was produced automatically (presumably as
    /// opposed to being written by a user — confirm with callers).
    pub is_auto_generated: bool,
}
194
195impl SummaryMetadata {
196 pub fn new(summary: String, is_auto_generated: bool) -> Self {
197 Self {
198 summary,
199 is_auto_generated,
200 }
201 }
202
203 pub fn to_json(&self) -> String {
204 serde_json::json!({
205 "summary": self.summary,
206 "auto_generated": self.is_auto_generated,
207 }).to_string()
208 }
209
210 pub fn from_json(json: &str) -> Result<Self> {
211 let parsed: serde_json::Value = serde_json::from_str(json)?;
212
213 let summary = parsed["summary"].as_str()
214 .ok_or_else(|| anyhow!("Missing 'summary' field in summary metadata"))?
215 .to_string();
216
217 let is_auto_generated = parsed["auto_generated"].as_bool()
218 .unwrap_or(true);
219
220 Ok(Self {
221 summary,
222 is_auto_generated,
223 })
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fallback_summarization() {
        let long_text = "This is the first sentence about programming. This is a second sentence about Rust language features.
    This third sentence covers memory safety in systems programming. The fourth sentence discusses performance implications.
    This is the fifth sentence on concurrency. The sixth sentence is about pattern matching features.
    Seventh sentence looks at the module system. Eighth sentence considers error handling approaches.";

        let result = fallback_summarize_text(long_text).unwrap();
        assert!(!result.is_empty());
        assert!(result.len() < long_text.len());
    }

    #[test]
    fn fallback_returns_two_sentence_text_unchanged() {
        let text = "Only one sentence here. And a second one!";
        assert_eq!(fallback_summarize_text(text).unwrap(), text);
    }

    #[test]
    fn summarize_text_returns_short_input_unchanged() {
        // Fewer than 20 words never reaches the model or the heuristic.
        let text = "Too short to be worth summarizing.";
        assert_eq!(summarize_text(text).unwrap(), text);
    }

    #[test]
    fn test_summary_metadata_serialization() {
        let metadata = SummaryMetadata::new("Test summary".to_string(), true);
        let json = metadata.to_json();
        let deserialized = SummaryMetadata::from_json(&json).unwrap();

        assert_eq!(metadata.summary, deserialized.summary);
        assert_eq!(metadata.is_auto_generated, deserialized.is_auto_generated);
    }

    #[test]
    fn from_json_defaults_auto_generated_to_true() {
        let metadata = SummaryMetadata::from_json(r#"{"summary": "hi"}"#).unwrap();
        assert_eq!(metadata.summary, "hi");
        assert!(metadata.is_auto_generated);
    }

    #[test]
    fn from_json_preserves_false_flag() {
        let metadata =
            SummaryMetadata::from_json(r#"{"summary": "hi", "auto_generated": false}"#).unwrap();
        assert!(!metadata.is_auto_generated);
    }

    #[test]
    fn from_json_rejects_invalid_input() {
        assert!(SummaryMetadata::from_json(r#"{"auto_generated": false}"#).is_err());
        assert!(SummaryMetadata::from_json(r#"{"summary": 42}"#).is_err());
        assert!(SummaryMetadata::from_json("not json").is_err());
    }
}