use crate::error::{Result, TextError};
use crate::tokenize::Tokenizer;
use crate::vectorize::{TfidfVectorizer, Vectorizer};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, Axis};

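/// Extractive summarizer based on the TextRank algorithm (Mihalcea & Tarau, 2004):
/// sentences are nodes in a graph weighted by pairwise TF-IDF cosine similarity,
/// and a PageRank-style power iteration scores each sentence by its centrality.
///
/// A minimal usage sketch (assumes `TextRank` is in scope; module path and
/// error handling are elided):
///
/// ```ignore
/// let summarizer = TextRank::new(2);
/// let summary = summarizer.summarize("First sentence. Second one. Third one.")?;
/// ```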
pub struct TextRank {
    /// Number of sentences to keep in the summary.
    num_sentences: usize,
    /// Damping factor for the PageRank iteration.
    damping_factor: f64,
    /// Maximum number of power-iteration steps.
    max_iterations: usize,
    /// Convergence threshold on the L1 change in scores.
    threshold: f64,
    /// Tokenizer used to split input text into sentences.
    sentence_tokenizer: Box<dyn Tokenizer + Send + Sync>,
}

impl TextRank {
    /// Creates a summarizer that extracts up to `num_sentences` sentences,
    /// using a damping factor of 0.85, at most 100 iterations, and a
    /// convergence threshold of 1e-4.
    pub fn new(num_sentences: usize) -> Self {
        Self {
            num_sentences,
            damping_factor: 0.85,
            max_iterations: 100,
            threshold: 0.0001,
            sentence_tokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
        }
    }

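    /// Sets the damping factor for the PageRank iteration, returning an
    /// error if the value lies outside `[0, 1]`.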
    pub fn with_damping_factor(mut self, damping_factor: f64) -> Result<Self> {
        if !(0.0..=1.0).contains(&damping_factor) {
            return Err(TextError::InvalidInput(
                "Damping factor must be between 0 and 1".to_string(),
            ));
        }
        self.damping_factor = damping_factor;
        Ok(self)
    }

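    /// Summarizes `text` by splitting it into sentences, ranking them with
    /// PageRank over a TF-IDF cosine-similarity graph, and joining the
    /// top-ranked sentences in their original order.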
    pub fn summarize(&self, text: &str) -> Result<String> {
        let sentences: Vec<String> = self.sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(String::new());
        }

        // If the text already fits within the requested length, return it as-is.
        if sentences.len() <= self.num_sentences {
            return Ok(text.to_string());
        }

        // Build the sentence graph, rank its nodes, and keep the top sentences.
        let similarity_matrix = self.build_similarity_matrix(&sentences)?;
        let scores = self.page_rank(&similarity_matrix)?;
        let selected_indices = self.select_top_sentences(&scores);
        let summary = self.reconstruct_summary(&sentences, &selected_indices);

        Ok(summary)
    }

    fn build_similarity_matrix(&self, sentences: &[String]) -> Result<Array2<f64>> {
        let n = sentences.len();
        let mut matrix = Array2::zeros((n, n));

        // Vectorize all sentences with TF-IDF so they can be compared pairwise.
        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&sentence_refs)?;
        let vectors = vectorizer.transform_batch(&sentence_refs)?;

        for i in 0..n {
            for j in 0..n {
                if i == j {
                    // No self-loops in the sentence graph.
                    matrix[[i, j]] = 0.0;
                } else {
                    matrix[[i, j]] = self.cosine_similarity(vectors.row(i), vectors.row(j));
                }
            }
        }

        Ok(matrix)
    }

    fn cosine_similarity(&self, vec1: ArrayView1<f64>, vec2: ArrayView1<f64>) -> f64 {
        let dot_product = vec1.dot(&vec2);
        let norm1 = vec1.dot(&vec1).sqrt();
        let norm2 = vec2.dot(&vec2).sqrt();

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1 * norm2)
        }
    }

    fn page_rank(&self, matrix: &Array2<f64>) -> Result<Array1<f64>> {
        let n = matrix.nrows();
        let mut scores = Array1::from_elem(n, 1.0 / n as f64);

        // Row-normalize the similarity matrix so each row sums to 1.
        let mut normalized_matrix = matrix.clone();
        for i in 0..n {
            let row_sum: f64 = matrix.row(i).sum();
            if row_sum > 0.0 {
                normalized_matrix.row_mut(i).mapv_inplace(|x| x / row_sum);
            }
        }

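        // Power iteration with the standard PageRank update,
        //   scores = (1 - d) / n + d * M^T * scores,
        // where `d` is the damping factor and `M` is the row-normalized
        // similarity matrix; iteration stops early once the L1 change in
        // the score vector drops below `self.threshold`.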
        for _ in 0..self.max_iterations {
            let new_scores = Array1::from_elem(n, (1.0 - self.damping_factor) / n as f64)
                + self.damping_factor * normalized_matrix.t().dot(&scores);

            let diff = (&new_scores - &scores).mapv(f64::abs).sum();
            scores = new_scores;

            if diff < self.threshold {
                break;
            }
        }

        Ok(scores)
    }

    fn select_top_sentences(&self, scores: &Array1<f64>) -> Vec<usize> {
        let mut indexed_scores: Vec<(usize, f64)> = scores
            .iter()
            .enumerate()
            .map(|(i, &score)| (i, score))
            .collect();

        // Sort by score, highest first.
        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        indexed_scores
            .iter()
            .take(self.num_sentences)
            .map(|&(idx, _)| idx)
            .collect()
    }

    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
        // Emit the selected sentences in their original document order.
        let mut sorted_indices = indices.to_vec();
        sorted_indices.sort_unstable();

        sorted_indices
            .iter()
            .map(|&idx| sentences[idx].clone())
            .collect::<Vec<_>>()
            .join(" ")
    }
}

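/// Centroid-based extractive summarizer: sentences are scored by cosine
/// similarity to the TF-IDF centroid of the document, and near-duplicate
/// sentences are skipped to reduce redundancy.
///
/// A minimal usage sketch (assumes `CentroidSummarizer` is in scope; module
/// path and error handling are elided):
///
/// ```ignore
/// let summarizer = CentroidSummarizer::new(2);
/// let summary = summarizer.summarize("First sentence. Second one. Third one.")?;
/// ```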
pub struct CentroidSummarizer {
    /// Number of sentences to keep in the summary.
    num_sentences: usize,
    /// Centroid components at or below this value are zeroed out.
    topic_threshold: f64,
    /// Sentences more similar than this to an already-selected sentence are skipped.
    redundancy_threshold: f64,
    /// Tokenizer used to split input text into sentences.
    sentence_tokenizer: Box<dyn Tokenizer + Send + Sync>,
}

impl CentroidSummarizer {
    /// Creates a summarizer that extracts up to `num_sentences` sentences,
    /// with a topic threshold of 0.1 and a redundancy threshold of 0.95.
    pub fn new(num_sentences: usize) -> Self {
        Self {
            num_sentences,
            topic_threshold: 0.1,
            redundancy_threshold: 0.95,
            sentence_tokenizer: Box::new(crate::tokenize::SentenceTokenizer::new()),
        }
    }

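    /// Summarizes `text` by ranking sentences against the TF-IDF centroid of
    /// the document and greedily selecting the most central, non-redundant
    /// sentences in their original order.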
    pub fn summarize(&self, text: &str) -> Result<String> {
        let sentences: Vec<String> = self.sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(String::new());
        }

        // If the text already fits within the requested length, return it as-is.
        if sentences.len() <= self.num_sentences {
            return Ok(text.to_string());
        }

        // Vectorize all sentences with TF-IDF.
        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();
        let mut vectorizer = TfidfVectorizer::default();
        vectorizer.fit(&sentence_refs)?;
        let vectors = vectorizer.transform_batch(&sentence_refs)?;

        let centroid = self.calculate_centroid(&vectors);
        let selected_indices = self.select_sentences(&vectors, &centroid);
        let summary = self.reconstruct_summary(&sentences, &selected_indices);

        Ok(summary)
    }

    fn calculate_centroid(&self, vectors: &Array2<f64>) -> Array1<f64> {
        let mut centroid = vectors.mean_axis(Axis(0)).unwrap();

        // Zero out weak components so the centroid reflects only salient terms.
        centroid.mapv_inplace(|x| if x > self.topic_threshold { x } else { 0.0 });

        centroid
    }

    fn select_sentences(&self, vectors: &Array2<f64>, centroid: &Array1<f64>) -> Vec<usize> {
        let mut selected = Vec::new();

        // Score every sentence by its similarity to the centroid.
        let mut similarities: Vec<(usize, f64)> = Vec::new();
        for i in 0..vectors.nrows() {
            let similarity = self.cosine_similarity(vectors.row(i), centroid.view());
            similarities.push((i, similarity));
        }

        // Sort by similarity, highest first.
        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        for (idx, _similarity) in similarities {
            if selected.len() >= self.num_sentences {
                break;
            }

            // Skip sentences that are nearly identical to one already selected.
            let mut is_redundant = false;
            for &selected_idx in &selected {
                let sim = self.cosine_similarity(vectors.row(idx), vectors.row(selected_idx));
                if sim > self.redundancy_threshold {
                    is_redundant = true;
                    break;
                }
            }

            if !is_redundant {
                selected.push(idx);
            }
        }

        selected
    }

    fn cosine_similarity(&self, vec1: ArrayView1<f64>, vec2: ArrayView1<f64>) -> f64 {
        let dot_product = vec1.dot(&vec2);
        let norm1 = vec1.dot(&vec1).sqrt();
        let norm2 = vec2.dot(&vec2).sqrt();

        if norm1 == 0.0 || norm2 == 0.0 {
            0.0
        } else {
            dot_product / (norm1 * norm2)
        }
    }

    fn reconstruct_summary(&self, sentences: &[String], indices: &[usize]) -> String {
        // Emit the selected sentences in their original document order.
        let mut sorted_indices = indices.to_vec();
        sorted_indices.sort_unstable();

        sorted_indices
            .iter()
            .map(|&idx| sentences[idx].clone())
            .collect::<Vec<_>>()
            .join(" ")
    }
}

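/// TF-IDF based keyword extractor over word n-grams.
///
/// A minimal usage sketch (assumes `KeywordExtractor` is in scope; module
/// path and error handling are elided):
///
/// ```ignore
/// let extractor = KeywordExtractor::new(5).with_ngram_range(1, 2)?;
/// let keywords = extractor.extract_keywords("Some text to mine for keywords.")?;
/// ```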
pub struct KeywordExtractor {
    /// Number of keywords to return.
    num_keywords: usize,
    /// Minimum document frequency (currently unused).
    #[allow(dead_code)]
    min_df: f64,
    /// Maximum document frequency (currently unused).
    #[allow(dead_code)]
    max_df: f64,
    /// Inclusive (min, max) n-gram lengths to consider.
    ngram_range: (usize, usize),
}

impl KeywordExtractor {
    /// Creates an extractor that returns up to `num_keywords` keywords,
    /// considering n-grams of length 1 through 3 by default.
    pub fn new(num_keywords: usize) -> Self {
        Self {
            num_keywords,
            min_df: 0.01,
            max_df: 0.95,
            ngram_range: (1, 3),
        }
    }

    /// Sets the n-gram range, returning an error if `min_n` is zero or
    /// greater than `max_n`.
    pub fn with_ngram_range(mut self, min_n: usize, max_n: usize) -> Result<Self> {
        if min_n > max_n || min_n == 0 {
            return Err(TextError::InvalidInput("Invalid n-gram range".to_string()));
        }
        self.ngram_range = (min_n, max_n);
        Ok(self)
    }

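    /// Extracts the highest-scoring keywords from `text` by averaging TF-IDF
    /// weights across sentences. Feature indices are mapped back to surface
    /// terms positionally against the whitespace-split text, which is an
    /// approximation; indices past the end of the text fall back to `term_{i}`.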
    pub fn extract_keywords(&self, text: &str) -> Result<Vec<(String, f64)>> {
        let sentence_tokenizer = crate::tokenize::SentenceTokenizer::new();
        let sentences = sentence_tokenizer.tokenize(text)?;

        if sentences.is_empty() {
            return Ok(Vec::new());
        }

        let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_ref()).collect();

        let mut vectorizer = crate::enhanced_vectorize::EnhancedTfidfVectorizer::new()
            .set_ngram_range((self.ngram_range.0, self.ngram_range.1))?;

        vectorizer.fit(&sentence_refs)?;
        let tfidf_matrix = vectorizer.transform_batch(&sentence_refs)?;

        // Average the TF-IDF weights over all sentences.
        let avg_tfidf = tfidf_matrix.mean_axis(Axis(0)).unwrap();

        let all_words: Vec<String> = text.split_whitespace().map(|w| w.to_string()).collect();

        // Build candidate terms from the first `2 * num_keywords` feature indices.
        let mut keyword_scores: Vec<(String, f64)> = avg_tfidf
            .iter()
            .enumerate()
            .take(self.num_keywords * 2)
            .map(|(i, &score)| {
                let term = if i < all_words.len() {
                    all_words[i].clone()
                } else {
                    format!("term_{i}")
                };
                (term, score)
            })
            .collect();

        // Sort by score, highest first.
        keyword_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

        Ok(keyword_scores.into_iter().take(self.num_keywords).collect())
    }

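    /// Extracts keywords along with the byte offsets of each case-insensitive
    /// occurrence in `text`.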
    pub fn extract_keywords_with_positions(
        &self,
        text: &str,
    ) -> Result<Vec<(String, f64, Vec<usize>)>> {
        let keywords = self.extract_keywords(text)?;
        let mut results = Vec::new();

        for (keyword, score) in keywords {
            let positions = self.find_keyword_positions(text, &keyword);
            results.push((keyword, score, positions));
        }

        Ok(results)
    }

    fn find_keyword_positions(&self, text: &str, keyword: &str) -> Vec<usize> {
        let mut positions = Vec::new();
        let text_lower = text.to_lowercase();
        let keyword_lower = keyword.to_lowercase();

        // Scan for case-insensitive, non-overlapping matches. Advance by the
        // lowercased keyword's length, since that is what was matched.
        let mut start = 0;
        while let Some(pos) = text_lower[start..].find(&keyword_lower) {
            positions.push(start + pos);
            start += pos + keyword_lower.len();
        }

        positions
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_textrank_summarizer() {
        let summarizer = TextRank::new(2);
        let text = "Machine learning is a subset of artificial intelligence. \
                    It enables computers to learn from data. \
                    Deep learning is a subset of machine learning. \
                    Neural networks are used in deep learning. \
                    These technologies are transforming many industries.";

        let summary = summarizer.summarize(text).unwrap();
        assert!(!summary.is_empty());
        assert!(summary.len() < text.len());
    }

    #[test]
    fn test_centroid_summarizer() {
        let summarizer = CentroidSummarizer::new(2);
        let text = "Natural language processing is important. \
                    It helps computers understand human language. \
                    Many applications use NLP technology. \
                    Chatbots and translation are examples. \
                    NLP continues to evolve rapidly.";

        let summary = summarizer.summarize(text).unwrap();
        assert!(!summary.is_empty());
    }

    #[test]
    fn test_keyword_extraction() {
        let extractor = KeywordExtractor::new(5);
        let text = "Machine learning algorithms are essential for artificial intelligence. \
                    Deep learning models use neural networks. \
                    These models can process complex data patterns.";

        let keywords = extractor.extract_keywords(text).unwrap();
        assert!(!keywords.is_empty());
        assert!(keywords.len() <= 5);

        // Keywords must be sorted by descending score.
        for i in 1..keywords.len() {
            assert!(keywords[i - 1].1 >= keywords[i].1);
        }
    }

    #[test]
    fn test_keyword_positions() {
        let extractor = KeywordExtractor::new(3);
        let text = "Machine learning is great. Machine learning transforms industries.";

        let keywords_with_pos = extractor.extract_keywords_with_positions(text).unwrap();

        for (keyword, _score, positions) in keywords_with_pos {
            if keyword.to_lowercase().contains("machine learning") {
                assert!(positions.len() >= 2);
            }
        }
    }

    #[test]
    fn test_empty_text() {
        let textrank = TextRank::new(3);
        let centroid = CentroidSummarizer::new(3);
        let keywords = KeywordExtractor::new(5);

        assert_eq!(textrank.summarize("").unwrap(), "");
        assert_eq!(centroid.summarize("").unwrap(), "");
        assert_eq!(keywords.extract_keywords("").unwrap().len(), 0);
    }

    #[test]
    fn test_short_text() {
        let summarizer = TextRank::new(5);
        let short_text = "This is a short text.";

        let summary = summarizer.summarize(short_text).unwrap();
        assert_eq!(summary, short_text);
    }
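
    // `with_damping_factor` rejects values outside [0, 1] and accepts
    // values inside the range.
    #[test]
    fn test_invalid_damping_factor() {
        assert!(TextRank::new(2).with_damping_factor(1.5).is_err());
        assert!(TextRank::new(2).with_damping_factor(0.85).is_ok());
    }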
}