Skip to main content

scirs2_text/
summarization.rs

1//! Text summarization module
2//!
3//! This module provides various algorithms for automatic text summarization.
4
5use crate::error::{Result, TextError};
6use crate::tokenize::Tokenizer;
7use crate::vectorize::{TfidfVectorizer, Vectorizer};
8use scirs2_core::ndarray::{Array1, Array2};
9use std::collections::HashSet;
10
11/// TextRank algorithm for extractive summarization
12pub struct TextRank {
13    /// Number of sentences to extract
14    num_sentences: usize,
15    /// Damping factor (usually 0.85)
16    damping_factor: f64,
17    /// Maximum iterations
18    max_iterations: usize,
19    /// Convergence threshold
20    threshold: f64,
21    /// Tokenizer for sentence splitting
22    sentencetokenizer: Box<dyn Tokenizer + Send + Sync>,
23}
24
25impl TextRank {
26    /// Create a new TextRank summarizer
27    pub fn new(_numsentences: usize) -> Self {
28        Self {
29            num_sentences: _numsentences,
30            damping_factor: 0.85,
31            max_iterations: 100,
32            threshold: 0.0001,
33            sentencetokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
34        }
35    }
36
37    /// Set the damping factor
38    pub fn with_damping_factor(mut self, dampingfactor: f64) -> Result<Self> {
39        if !(0.0..=1.0).contains(&dampingfactor) {
40            return Err(TextError::InvalidInput(
41                "Damping _factor must be between 0 and 1".to_string(),
42            ));
43        }
44        self.damping_factor = dampingfactor;
45        Ok(self)
46    }
47
48    /// Extract summary from text
49    pub fn summarize(&self, text: &str) -> Result<String> {
50        let sentences: Vec<String> = self.sentencetokenizer.tokenize(text)?;
51
52        if sentences.is_empty() {
53            return Ok(String::new());
54        }
55
56        if sentences.len() <= self.num_sentences {
57            return Ok(text.to_string());
58        }
59
60        // Build similarity matrix
61        let similarity_matrix = self.build_similarity_matrix(&sentences)?;
62
63        // Apply PageRank algorithm
64        let scores = self.page_rank(&similarity_matrix)?;
65
66        // Select top sentences
67        let selected_indices = self.select_top_sentences(&scores);
68
69        // Reconstruct summary maintaining original order
70        let summary = self.reconstruct_summary(&sentences, &selected_indices);
71
72        Ok(summary)
73    }
74
75    /// Build similarity matrix between sentences
76    fn build_similarity_matrix(&self, sentences: &[String]) -> Result<Array2<f64>> {
77        let n = sentences.len();
78        let mut matrix = Array2::zeros((n, n));
79
80        // Use TF-IDF for sentence representation
81        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
82        let mut vectorizer = TfidfVectorizer::default();
83        vectorizer.fit(&sentence_refs)?;
84        let vectors = vectorizer.transform_batch(&sentence_refs)?;
85
86        // Calculate cosine similarity between all pairs
87        for i in 0..n {
88            for j in 0..n {
89                if i == j {
90                    matrix[[i, j]] = 0.0; // No self-loops
91                } else {
92                    let similarity = self
93                        .cosine_similarity(vectors.row(i).to_owned(), vectors.row(j).to_owned());
94                    matrix[[i, j]] = similarity;
95                }
96            }
97        }
98
99        Ok(matrix)
100    }
101
102    /// Calculate cosine similarity between two vectors
103    fn cosine_similarity(&self, vec1: Array1<f64>, vec2: Array1<f64>) -> f64 {
104        let dot_product = vec1.dot(&vec2);
105        let norm1 = vec1.dot(&vec1).sqrt();
106        let norm2 = vec2.dot(&vec2).sqrt();
107
108        if norm1 == 0.0 || norm2 == 0.0 {
109            0.0
110        } else {
111            dot_product / (norm1 * norm2)
112        }
113    }
114
115    /// Apply PageRank algorithm
116    fn page_rank(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
117        let n = matrix.nrows();
118        let mut scores = Array1::from_elem(n, 1.0 / n as f64);
119
120        // Normalize rows of similarity matrix
121        let mut normalized_matrix = matrix.clone();
122        for i in 0..n {
123            let row_sum: f64 = matrix.row(i).sum();
124            if row_sum > 0.0 {
125                normalized_matrix.row_mut(i).mapv_inplace(|x| x / row_sum);
126            }
127        }
128
129        // Iterate until convergence
130        for _ in 0..self.max_iterations {
131            let new_scores = Array1::from_elem(n, (1.0 - self.damping_factor) / n as f64)
132                + self.damping_factor * normalized_matrix.t().dot(&scores);
133
134            // Check convergence
135            let diff = (&new_scores - &scores).mapv(f64::abs).sum();
136            scores = new_scores;
137
138            if diff < self.threshold {
139                break;
140            }
141        }
142
143        Ok(scores)
144    }
145
146    /// Select top scoring sentences
147    fn select_top_sentences(&self, scores: &Array1<f64>) -> Vec<usize> {
148        let mut indexed_scores: Vec<(usize, f64)> = scores
149            .iter()
150            .enumerate()
151            .map(|(i, &score)| (i, score))
152            .collect();
153
154        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Operation failed"));
155
156        indexed_scores
157            .iter()
158            .take(self.num_sentences)
159            .map(|&(idx_, _)| idx_)
160            .collect()
161    }
162
163    /// Reconstruct summary maintaining original order
164    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
165        let mut sorted_indices = indices.to_vec();
166        sorted_indices.sort_unstable();
167
168        sorted_indices
169            .iter()
170            .map(|&idx| sentences[idx].clone())
171            .collect::<Vec<_>>()
172            .join(" ")
173    }
174}
175
176/// Centroid-based summarization
177pub struct CentroidSummarizer {
178    /// Number of sentences to extract
179    num_sentences: usize,
180    /// Topic threshold
181    topic_threshold: f64,
182    /// Redundancy threshold
183    redundancy_threshold: f64,
184    /// Sentence tokenizer
185    sentencetokenizer: Box<dyn Tokenizer + Send + Sync>,
186}
187
188impl CentroidSummarizer {
189    /// Create a new centroid summarizer
190    pub fn new(_numsentences: usize) -> Self {
191        Self {
192            num_sentences: _numsentences,
193            topic_threshold: 0.1,
194            redundancy_threshold: 0.95,
195            sentencetokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
196        }
197    }
198
199    /// Summarize text using centroid method
200    pub fn summarize(&self, text: &str) -> Result<String> {
201        let sentences: Vec<String> = self.sentencetokenizer.tokenize(text)?;
202
203        if sentences.is_empty() {
204            return Ok(String::new());
205        }
206
207        if sentences.len() <= self.num_sentences {
208            return Ok(text.to_string());
209        }
210
211        // Create TF-IDF vectors
212        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
213        let mut vectorizer = TfidfVectorizer::default();
214        vectorizer.fit(&sentence_refs)?;
215        let vectors = vectorizer.transform_batch(&sentence_refs)?;
216
217        // Calculate centroid
218        let centroid = self.calculate_centroid(&vectors);
219
220        // Select sentences closest to centroid
221        let selected_indices = self.select_sentences(&vectors, &centroid);
222
223        // Reconstruct summary
224        let summary = self.reconstruct_summary(&sentences, &selected_indices);
225
226        Ok(summary)
227    }
228
229    /// Calculate document centroid
230    fn calculate_centroid(&self, vectors: &Array2<f64>) -> Array1<f64> {
231        let _n_docs = vectors.nrows();
232        let mut centroid = vectors
233            .mean_axis(scirs2_core::ndarray::Axis(0))
234            .expect("Operation failed");
235
236        // Apply topic threshold
237        centroid.mapv_inplace(|x| if x > self.topic_threshold { x } else { 0.0 });
238
239        centroid
240    }
241
242    /// Select sentences based on centroid similarity
243    fn select_sentences(&self, vectors: &Array2<f64>, centroid: &Array1<f64>) -> Vec<usize> {
244        let mut selected = Vec::new();
245        let mut used_sentences = HashSet::new();
246
247        // Calculate similarities to centroid
248        let mut similarities: Vec<(usize, f64)> = Vec::new();
249        for i in 0..vectors.nrows() {
250            let similarity = self.cosine_similarity(vectors.row(i).to_owned(), centroid.clone());
251            similarities.push((i, similarity));
252        }
253
254        // Sort by similarity
255        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Operation failed"));
256
257        // Select sentences avoiding redundancy
258        for (idx_, _similarity) in similarities {
259            if selected.len() >= self.num_sentences {
260                break;
261            }
262
263            // Check redundancy with already selected sentences
264            let mut is_redundant = false;
265            for &selected_idx in &selected {
266                let sim = self.cosine_similarity(
267                    vectors.row(idx_).to_owned(),
268                    vectors.row(selected_idx).to_owned(),
269                );
270                if sim > self.redundancy_threshold {
271                    is_redundant = true;
272                    break;
273                }
274            }
275
276            if !is_redundant {
277                selected.push(idx_);
278                used_sentences.insert(idx_);
279            }
280        }
281
282        selected
283    }
284
285    /// Calculate cosine similarity
286    fn cosine_similarity(&self, vec1: Array1<f64>, vec2: Array1<f64>) -> f64 {
287        let dot_product = vec1.dot(&vec2);
288        let norm1 = vec1.dot(&vec1).sqrt();
289        let norm2 = vec2.dot(&vec2).sqrt();
290
291        if norm1 == 0.0 || norm2 == 0.0 {
292            0.0
293        } else {
294            dot_product / (norm1 * norm2)
295        }
296    }
297
298    /// Reconstruct summary maintaining original order
299    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
300        let mut sorted_indices = indices.to_vec();
301        sorted_indices.sort_unstable();
302
303        sorted_indices
304            .iter()
305            .map(|&idx| sentences[idx].clone())
306            .collect::<Vec<_>>()
307            .join(" ")
308    }
309}
310
311/// Keyword extraction using TF-IDF
312pub struct KeywordExtractor {
313    /// Number of keywords to extract
314    _numkeywords: usize,
315    /// Minimum document frequency
316    #[allow(dead_code)]
317    min_df: f64,
318    /// Maximum document frequency
319    #[allow(dead_code)]
320    max_df: f64,
321    /// N-gram range
322    ngram_range: (usize, usize),
323}
324
325impl KeywordExtractor {
326    /// Create a new keyword extractor
327    pub fn new(_numkeywords: usize) -> Self {
328        Self {
329            _numkeywords,
330            min_df: 0.01, // Unused but kept for API compatibility
331            max_df: 0.95, // Unused but kept for API compatibility
332            ngram_range: (1, 3),
333        }
334    }
335
336    /// Configure n-gram range
337    pub fn with_ngram_range(mut self, min_n: usize, maxn: usize) -> Result<Self> {
338        if min_n > maxn || min_n == 0 {
339            return Err(TextError::InvalidInput("Invalid _n-gram range".to_string()));
340        }
341        self.ngram_range = (min_n, maxn);
342        Ok(self)
343    }
344
345    /// Extract keywords from text
346    pub fn extract_keywords(&self, text: &str) -> Result<Vec<(String, f64)>> {
347        // Split into sentences for better TF-IDF
348        let sentence_tokenizer = crate::tokenize::SentenceTokenizer::new();
349        let sentences = sentence_tokenizer.tokenize(text)?;
350
351        if sentences.is_empty() {
352            return Ok(Vec::new());
353        }
354
355        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
356
357        // Create enhanced TF-IDF vectorizer with n-grams
358        // Create vectorizer with ngram range configuration
359        let mut vectorizer = crate::enhanced_vectorize::EnhancedTfidfVectorizer::new()
360            .set_ngram_range((self.ngram_range.0, self.ngram_range.1))?;
361
362        vectorizer.fit(&sentence_refs)?;
363        let tfidf_matrix = vectorizer.transform_batch(&sentence_refs)?;
364
365        // Calculate average TF-IDF scores across documents
366        let avg_tfidf = tfidf_matrix
367            .mean_axis(scirs2_core::ndarray::Axis(0))
368            .expect("Operation failed");
369
370        // Get terms from the tokenizer directly
371        let all_words: Vec<String> = text.split_whitespace().map(|w| w.to_string()).collect();
372
373        // Create keyword-score pairs (use top scoring features)
374        let mut keyword_scores: Vec<(String, f64)> = avg_tfidf
375            .iter()
376            .enumerate()
377            .take(self._numkeywords * 2) // Get more than needed to filter
378            .map(|(i, &score)| {
379                let term = if i < all_words.len() {
380                    all_words[i].clone()
381                } else {
382                    format!("term_{i}")
383                };
384                (term, score)
385            })
386            .collect();
387
388        // Sort by score
389        keyword_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Operation failed"));
390
391        // Return top keywords
392        Ok(keyword_scores.into_iter().take(self._numkeywords).collect())
393    }
394
395    /// Extract keywords with position information
396    pub fn extract_keywords_with_positions(
397        &self,
398        text: &str,
399    ) -> Result<Vec<(String, f64, Vec<usize>)>> {
400        let keywords = self.extract_keywords(text)?;
401        let mut results = Vec::new();
402
403        for (keyword, score) in keywords {
404            let positions = self.find_keyword_positions(text, &keyword);
405            results.push((keyword, score, positions));
406        }
407
408        Ok(results)
409    }
410
411    /// Find positions of a keyword in text
412    fn find_keyword_positions(&self, text: &str, keyword: &str) -> Vec<usize> {
413        let mut positions = Vec::new();
414        let text_lower = text.to_lowercase();
415        let keyword_lower = keyword.to_lowercase();
416
417        let mut start = 0;
418        while let Some(pos) = text_lower[start..].find(&keyword_lower) {
419            positions.push(start + pos);
420            start += pos + keyword.len();
421        }
422
423        positions
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_textrank_summarizer() {
        let summarizer = TextRank::new(2);
        let text = "Machine learning is a subset of artificial intelligence. \
                    It enables computers to learn from data. \
                    Deep learning is a subset of machine learning. \
                    Neural networks are used in deep learning. \
                    These technologies are transforming many industries.";

        let summary = summarizer.summarize(text).expect("Operation failed");
        assert!(!summary.is_empty());
        assert!(summary.len() < text.len());
    }

    #[test]
    fn test_centroid_summarizer() {
        let summarizer = CentroidSummarizer::new(2);
        let text = "Natural language processing is important. \
                    It helps computers understand human language. \
                    Many applications use NLP technology. \
                    Chatbots and translation are examples. \
                    NLP continues to evolve rapidly.";

        let summary = summarizer.summarize(text).expect("Operation failed");
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_keyword_extraction() {
        let extractor = KeywordExtractor::new(5);
        let text = "Machine learning algorithms are essential for artificial intelligence. \
                    Deep learning models use neural networks. \
                    These models can process complex data patterns.";

        let keywords = extractor.extract_keywords(text).expect("Operation failed");
        assert!(!keywords.is_empty());
        assert!(keywords.len() <= 5);

        // Scores must be in descending order.
        for pair in keywords.windows(2) {
            assert!(pair[0].1 >= pair[1].1);
        }
    }

    #[test]
    fn test_keyword_positions() {
        let extractor = KeywordExtractor::new(3);
        let text = "Machine learning is great. Machine learning transforms industries.";

        let keywords_with_pos = extractor
            .extract_keywords_with_positions(text)
            .expect("Operation failed");

        // Repeated keywords should be found at every occurrence.
        for (keyword, _score, positions) in keywords_with_pos {
            if keyword.to_lowercase().contains("machine learning") {
                assert!(positions.len() >= 2);
            }
        }
    }

    #[test]
    fn test_find_keyword_positions_case_insensitive() {
        let extractor = KeywordExtractor::new(3);
        let positions = extractor
            .find_keyword_positions("Machine learning. machine LEARNING.", "machine learning");
        assert_eq!(positions, vec![0, 18]);
    }

    #[test]
    fn test_empty_text() {
        let textrank = TextRank::new(3);
        let centroid = CentroidSummarizer::new(3);
        let keywords = KeywordExtractor::new(5);

        assert_eq!(textrank.summarize("").expect("Operation failed"), "");
        assert_eq!(centroid.summarize("").expect("Operation failed"), "");
        assert!(keywords
            .extract_keywords("")
            .expect("Operation failed")
            .is_empty());
    }

    #[test]
    fn test_short_text() {
        let summarizer = TextRank::new(5);
        let short_text = "This is a short text.";

        let summary = summarizer.summarize(short_text).expect("Operation failed");
        assert_eq!(summary, short_text);
    }

    #[test]
    fn test_invalid_damping_factor() {
        assert!(TextRank::new(2).with_damping_factor(1.5).is_err());
        assert!(TextRank::new(2).with_damping_factor(-0.1).is_err());
        assert!(TextRank::new(2).with_damping_factor(f64::NAN).is_err());
        assert!(TextRank::new(2).with_damping_factor(0.5).is_ok());
    }

    #[test]
    fn test_invalid_ngram_range() {
        assert!(KeywordExtractor::new(3).with_ngram_range(0, 2).is_err());
        assert!(KeywordExtractor::new(3).with_ngram_range(3, 1).is_err());
        assert!(KeywordExtractor::new(3).with_ngram_range(1, 2).is_ok());
    }
}