1use crate::error::{Result, TextError};
20use crate::tokenize::{SentenceTokenizer, Tokenizer, WordTokenizer};
21use crate::vectorize::{TfidfVectorizer, Vectorizer};
22use scirs2_core::ndarray::{Array1, Array2};
23use std::collections::{HashMap, HashSet};
24
/// Splits `sentence` into lowercase word tokens.
///
/// Every non-alphanumeric character is treated as a separator. Empty
/// fragments (produced by adjacent, leading or trailing separators) are
/// discarded, and each remaining fragment is lowercased.
fn word_tokens(sentence: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for fragment in sentence.split(|ch: char| !ch.is_alphanumeric()) {
        if fragment.is_empty() {
            continue;
        }
        tokens.push(fragment.to_lowercase());
    }
    tokens
}
37
38fn cosine_sim_rows(matrix: &Array2<f64>, i: usize, j: usize) -> f64 {
40 let cols = matrix.ncols();
41 let mut dot = 0.0_f64;
42 let mut n1 = 0.0_f64;
43 let mut n2 = 0.0_f64;
44 for c in 0..cols {
45 let a = matrix[[i, c]];
46 let b = matrix[[j, c]];
47 dot += a * b;
48 n1 += a * a;
49 n2 += b * b;
50 }
51 let denom = n1.sqrt() * n2.sqrt();
52 if denom == 0.0 {
53 0.0
54 } else {
55 dot / denom
56 }
57}
58
59fn cosine_sim_vec(row: &Array1<f64>, centroid: &Array1<f64>) -> f64 {
61 let dot = row.dot(centroid);
62 let n1 = row.dot(row).sqrt();
63 let n2 = centroid.dot(centroid).sqrt();
64 if n1 == 0.0 || n2 == 0.0 {
65 0.0
66 } else {
67 dot / (n1 * n2)
68 }
69}
70
71fn build_tfidf_matrix(sentences: &[String]) -> Result<(Array2<f64>, TfidfVectorizer)> {
74 if sentences.is_empty() {
75 return Err(TextError::InvalidInput(
76 "Cannot build TF-IDF matrix from empty sentence list".to_string(),
77 ));
78 }
79 let refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
80 let mut vectorizer = TfidfVectorizer::default();
81 let matrix = vectorizer.fit_transform(&refs)?;
82 Ok((matrix, vectorizer))
83}
84
/// A sentence together with its provenance and a relevance score.
///
/// `PartialEq` is derived so values can be compared directly (for example
/// with `assert_eq!`); because `score` is an `f64`, equality follows the
/// usual floating-point rules (`NaN != NaN`).
#[derive(Debug, Clone, PartialEq)]
pub struct ScoredSentence {
    /// The sentence text.
    pub text: String,
    /// Position of the sentence in the flattened, cross-document sentence list.
    pub index: usize,
    /// Index of the document the sentence came from.
    pub doc_index: usize,
    /// Relevance score assigned to the sentence.
    pub score: f64,
}
101
/// Multi-document summarizer that scores sentences, groups them into
/// clusters and emits one representative sentence per cluster.
#[derive(Debug, Clone)]
pub struct FusionSummarizer {
    /// Number of clusters requested by `summarize` (always at least 1).
    n_clusters: usize,
    /// Upper bound on the number of clustering refinement iterations.
    max_iter: usize,
    /// Threshold in `[0, 1]` set through `with_cluster_threshold`.
    /// NOTE(review): no method visible in this file reads this value.
    cluster_threshold: f64,
}
135
136impl FusionSummarizer {
137 pub fn new(n_clusters: usize) -> Self {
139 Self {
140 n_clusters: n_clusters.max(1),
141 max_iter: 30,
142 cluster_threshold: 0.1,
143 }
144 }
145
146 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
148 self.max_iter = max_iter;
149 self
150 }
151
152 pub fn with_cluster_threshold(mut self, threshold: f64) -> Self {
154 self.cluster_threshold = threshold.clamp(0.0, 1.0);
155 self
156 }
157
158 pub fn extract_sentences(&self, documents: &[&str]) -> Vec<ScoredSentence> {
162 if documents.is_empty() {
163 return Vec::new();
164 }
165
166 let sentence_tokenizer = SentenceTokenizer::new();
167 let mut all_sentences: Vec<ScoredSentence> = Vec::new();
168 let mut global_index = 0usize;
169
170 let mut raw_per_doc: Vec<Vec<String>> = Vec::new();
172 for doc in documents {
173 let sents = sentence_tokenizer
174 .tokenize(doc)
175 .unwrap_or_else(|_| vec![doc.to_string()]);
176 raw_per_doc.push(sents);
177 }
178
179 let flat: Vec<String> = raw_per_doc.iter().flatten().cloned().collect();
181 if flat.is_empty() {
182 return Vec::new();
183 }
184
185 let flat_refs: Vec<&str> = flat.iter().map(|s| s.as_str()).collect();
187 let mut vectorizer = TfidfVectorizer::default();
188 let tfidf = match vectorizer.fit_transform(&flat_refs) {
189 Ok(m) => m,
190 Err(_) => return Vec::new(),
191 };
192
193 let cols = tfidf.ncols();
194 let n = flat.len();
195
196 for (flat_idx, sentence) in flat.iter().enumerate() {
197 let score = if cols > 0 {
199 let row_sum: f64 = (0..cols).map(|c| tfidf[[flat_idx, c]]).sum();
200 row_sum / cols as f64
201 } else {
202 0.0
203 };
204
205 let mut doc_index = 0usize;
207 let mut cumulative = 0usize;
208 for (di, sents) in raw_per_doc.iter().enumerate() {
209 if flat_idx < cumulative + sents.len() {
210 doc_index = di;
211 break;
212 }
213 cumulative += sents.len();
214 }
215
216 all_sentences.push(ScoredSentence {
217 text: sentence.clone(),
218 index: global_index,
219 doc_index,
220 score,
221 });
222 global_index += 1;
223 }
224
225 let max_score = all_sentences
227 .iter()
228 .map(|s| s.score)
229 .fold(0.0_f64, f64::max);
230 if max_score > 0.0 {
231 for s in &mut all_sentences {
232 s.score /= max_score;
233 }
234 }
235
236 all_sentences
237 }
238
239 pub fn cluster_sentences(
248 &self,
249 sentences: &[ScoredSentence],
250 n_clusters: usize,
251 ) -> Vec<Vec<ScoredSentence>> {
252 let k = n_clusters.min(sentences.len()).max(1);
253 if sentences.is_empty() {
254 return Vec::new();
255 }
256 if sentences.len() <= k {
257 return sentences.iter().map(|s| vec![s.clone()]).collect();
259 }
260
261 let texts: Vec<String> = sentences.iter().map(|s| s.text.clone()).collect();
263 let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
264 let mut vectorizer = TfidfVectorizer::default();
265 let matrix = match vectorizer.fit_transform(&refs) {
266 Ok(m) => m,
267 Err(_) => {
268 return vec![sentences.to_vec()];
270 }
271 };
272
273 let n = sentences.len();
274 let cols = matrix.ncols();
275
276 let mut sorted_indices: Vec<usize> = (0..n).collect();
278 sorted_indices.sort_by(|&a, &b| {
279 sentences[b]
280 .score
281 .partial_cmp(&sentences[a].score)
282 .unwrap_or(std::cmp::Ordering::Equal)
283 });
284 let centroid_indices: Vec<usize> = sorted_indices.into_iter().take(k).collect();
285
286 let mut centroids: Vec<Array1<f64>> = centroid_indices
288 .iter()
289 .map(|&ci| matrix.row(ci).to_owned())
290 .collect();
291
292 let mut assignments: Vec<usize> = vec![0; n];
293
294 for _iter in 0..self.max_iter {
295 let mut changed = false;
296
297 for i in 0..n {
299 let row = matrix.row(i).to_owned();
300 let best_cluster = centroids
301 .iter()
302 .enumerate()
303 .map(|(ci, centroid)| (ci, cosine_sim_vec(&row, centroid)))
304 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
305 .map(|(ci, _)| ci)
306 .unwrap_or(0);
307
308 if assignments[i] != best_cluster {
309 assignments[i] = best_cluster;
310 changed = true;
311 }
312 }
313
314 if !changed {
315 break;
316 }
317
318 for ci in 0..k {
320 let members: Vec<usize> = (0..n).filter(|&i| assignments[i] == ci).collect();
321 if members.is_empty() {
322 continue;
324 }
325 let mut new_centroid = Array1::zeros(cols);
326 for &mi in &members {
327 new_centroid = new_centroid + matrix.row(mi).to_owned();
328 }
329 let count = members.len() as f64;
330 new_centroid.mapv_inplace(|v| v / count);
331 centroids[ci] = new_centroid;
332 }
333 }
334
335 let mut clusters: Vec<Vec<ScoredSentence>> = vec![Vec::new(); k];
337 for (i, &ci) in assignments.iter().enumerate() {
338 clusters[ci].push(sentences[i].clone());
339 }
340
341 clusters.retain(|c| !c.is_empty());
343 clusters
344 }
345
346 pub fn generate_summary(&self, clusters: &[Vec<ScoredSentence>], max_words: usize) -> String {
351 if clusters.is_empty() {
352 return String::new();
353 }
354
355 let mut representatives: Vec<&ScoredSentence> = clusters
357 .iter()
358 .filter_map(|cluster| {
359 cluster.iter().max_by(|a, b| {
360 a.score
361 .partial_cmp(&b.score)
362 .unwrap_or(std::cmp::Ordering::Equal)
363 })
364 })
365 .collect();
366
367 representatives.sort_by_key(|s| s.index);
369
370 let mut words_used = 0usize;
372 let mut collected_words: Vec<&str> = Vec::new();
373
374 'outer: for rep in representatives {
375 let sentence_words: Vec<&str> = rep.text.split_whitespace().collect();
376 for word in &sentence_words {
377 if words_used >= max_words {
378 break 'outer;
379 }
380 collected_words.push(word);
381 words_used += 1;
382 }
383 }
384
385 collected_words.join(" ")
386 }
387
388 pub fn summarize(&self, documents: &[&str], max_words: usize) -> Result<String> {
390 if documents.is_empty() {
391 return Ok(String::new());
392 }
393 let sentences = self.extract_sentences(documents);
394 if sentences.is_empty() {
395 return Ok(String::new());
396 }
397 let clusters = self.cluster_sentences(&sentences, self.n_clusters);
398 Ok(self.generate_summary(&clusters, max_words))
399 }
400}
401
/// Sentence compressor that keeps the highest-scoring words of a sentence.
#[derive(Debug, Clone)]
pub struct CompressionSummarizer {
    /// Lowercase words whose term-frequency contribution is down-weighted.
    stop_words: HashSet<String>,
}

impl CompressionSummarizer {
    /// Creates a compressor with a built-in English stop-word list.
    pub fn new() -> Self {
        let raw = [
            "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might",
            "shall", "can", "and", "but", "or", "nor", "for", "yet", "so", "in", "on", "at", "to",
            "from", "by", "with", "of", "about", "as", "into", "through", "during", "before",
            "after", "above", "below", "between", "each", "all", "both", "very", "just", "too",
            "also", "then", "than", "that", "this", "these", "those", "i", "me", "my", "we", "our",
            "you", "your", "he", "she", "it", "its", "they", "them", "their", "what", "which",
            "who", "whom", "not", "no",
        ];
        Self {
            stop_words: raw.iter().map(|w| w.to_string()).collect(),
        }
    }

    /// Creates a compressor with a caller-supplied stop-word set.
    ///
    /// Entries are lowercased on the way in, because `importance_score`
    /// looks tokens up in lowercase; an entry containing uppercase letters
    /// would otherwise never match.
    pub fn with_stop_words(stop_words: HashSet<String>) -> Self {
        Self {
            stop_words: stop_words
                .into_iter()
                .map(|word| word.to_lowercase())
                .collect(),
        }
    }

    /// Scores how important `token` is within `sentence_tokens`.
    ///
    /// The score is the sum of:
    /// * the case-insensitive term frequency of `token` in the sentence,
    ///   multiplied by 0.1 when the token is a stop word;
    /// * a length bonus in `[0.5, 1.0]` that grows with the number of
    ///   characters and saturates at 10 characters;
    /// * a bonus of 0.3 when the first character of `token` is uppercase.
    ///
    /// Returns `0.0` when `sentence_tokens` is empty.
    pub fn importance_score(&self, token: &str, sentence_tokens: &[String]) -> f64 {
        if sentence_tokens.is_empty() {
            return 0.0;
        }
        let token_lower = token.to_lowercase();

        let occurrences = sentence_tokens
            .iter()
            .filter(|t| t.to_lowercase() == token_lower)
            .count();
        let tf = occurrences as f64 / sentence_tokens.len() as f64;

        let stop_penalty = if self.stop_words.contains(&token_lower) {
            0.1
        } else {
            1.0
        };

        // Count characters rather than bytes so that non-ASCII words are not
        // credited with extra length for their multi-byte UTF-8 encoding.
        let char_len = token.chars().count();
        let len_bonus = (1.0 + (char_len as f64 / 10.0).min(1.0)) * 0.5;

        let cap_bonus = if token.starts_with(char::is_uppercase) {
            0.3
        } else {
            0.0
        };

        (tf * stop_penalty + len_bonus + cap_bonus).max(0.0)
    }

    /// Compresses `sentence` to roughly `ratio` of its words.
    ///
    /// `ratio` is clamped to `[0.01, 1.0]`. The number of words kept is
    /// `ceil(word_count * ratio)`, but never fewer than one. The kept words
    /// are those with the highest `importance_score` (ties favour earlier
    /// words) and are emitted in their original order, separated by single
    /// spaces. An input without any words yields an empty string.
    pub fn compress_sentence(&self, sentence: &str, ratio: f64) -> String {
        let ratio = ratio.clamp(0.01, 1.0);

        let original_tokens: Vec<&str> = sentence.split_whitespace().collect();
        if original_tokens.is_empty() {
            return String::new();
        }

        // Scoring works on punctuation-trimmed, lowercased tokens.
        // NOTE(review): because the tokens are lowercased here, the
        // capitalisation bonus of `importance_score` never applies on this path.
        let norm_tokens: Vec<String> = original_tokens
            .iter()
            .map(|t| {
                t.trim_matches(|c: char| !c.is_alphanumeric())
                    .to_lowercase()
            })
            .collect();

        let n = original_tokens.len();
        let keep_count = ((n as f64 * ratio).ceil() as usize).clamp(1, n);

        let mut scored: Vec<(usize, f64)> = norm_tokens
            .iter()
            .enumerate()
            .map(|(i, t)| (i, self.importance_score(t, &norm_tokens)))
            .collect();

        // Stable sort, highest score first: equal scores keep sentence order.
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let mut keep_indices: Vec<usize> =
            scored.iter().take(keep_count).map(|&(i, _)| i).collect();
        keep_indices.sort_unstable();

        keep_indices
            .iter()
            .map(|&i| original_tokens[i])
            .collect::<Vec<_>>()
            .join(" ")
    }
}

impl Default for CompressionSummarizer {
    /// Equivalent to [`CompressionSummarizer::new`].
    fn default() -> Self {
        Self::new()
    }
}
552
/// Single-document extractive summarizer based on similarity to the
/// TF-IDF centroid, with a position bonus and redundancy filtering.
#[derive(Debug, Clone)]
pub struct EnhancedCentroidSummarizer {
    /// Maximum number of sentences in a summary (always at least 1).
    num_sentences: usize,
    /// Centroid components not exceeding this value are zeroed.
    topic_threshold: f64,
    /// A candidate whose cosine similarity to an already selected sentence
    /// exceeds this value is skipped.
    redundancy_threshold: f64,
    /// Weight of the position bonus that favours earlier sentences.
    position_bias: f64,
}
586
587impl EnhancedCentroidSummarizer {
588 pub fn new(num_sentences: usize) -> Self {
590 Self {
591 num_sentences: num_sentences.max(1),
592 topic_threshold: 0.1,
593 redundancy_threshold: 0.95,
594 position_bias: 0.2,
595 }
596 }
597
598 pub fn with_position_bias(mut self, bias: f64) -> Self {
600 self.position_bias = bias.clamp(0.0, 1.0);
601 self
602 }
603
604 pub fn with_topic_threshold(mut self, threshold: f64) -> Self {
606 self.topic_threshold = threshold.clamp(0.0, 1.0);
607 self
608 }
609
610 pub fn with_redundancy_threshold(mut self, threshold: f64) -> Self {
612 self.redundancy_threshold = threshold.clamp(0.0, 1.0);
613 self
614 }
615
616 pub fn summarize(&self, text: &str) -> Result<String> {
618 self.summarize_internal(text, None)
619 }
620
621 pub fn summarize_query_focused(
632 &self,
633 document: &str,
634 query: &str,
635 max_sentences: usize,
636 ) -> Result<String> {
637 let override_self = EnhancedCentroidSummarizer {
638 num_sentences: max_sentences.max(1),
639 ..*self
640 };
641 override_self.summarize_internal(document, Some(query))
642 }
643
644 fn summarize_internal(&self, text: &str, query: Option<&str>) -> Result<String> {
645 let sentence_tokenizer = SentenceTokenizer::new();
646 let sentences: Vec<String> = sentence_tokenizer.tokenize(text)?;
647
648 if sentences.is_empty() {
649 return Ok(String::new());
650 }
651 if sentences.len() <= self.num_sentences {
652 return Ok(text.to_string());
653 }
654
655 let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
657 let mut vectorizer = TfidfVectorizer::default();
658 let tfidf = vectorizer.fit_transform(&sentence_refs)?;
659
660 let centroid = self.compute_centroid(&tfidf);
662
663 let query_vec: Option<Array1<f64>> = if let Some(q) = query {
665 vectorizer.transform_batch(&[q]).ok().map(|m| {
666 m.row(0).to_owned()
668 })
669 } else {
670 None
671 };
672
673 let n = sentences.len();
675 let mut scored: Vec<(usize, f64)> = (0..n)
676 .map(|i| {
677 let row = tfidf.row(i).to_owned();
678 let centroid_sim = cosine_sim_vec(&row, ¢roid);
679 let query_sim = query_vec
680 .as_ref()
681 .map(|qv| cosine_sim_vec(&row, qv))
682 .unwrap_or(0.0);
683 let content_score = if query_vec.is_some() {
685 0.5 * centroid_sim + 0.5 * query_sim
686 } else {
687 centroid_sim
688 };
689 let pos_bonus = (-0.5 * i as f64 / n as f64).exp() * self.position_bias;
691 (i, content_score + pos_bonus)
692 })
693 .collect();
694
695 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
697
698 let mut selected: Vec<usize> = Vec::new();
700 for (idx, _score) in &scored {
701 if selected.len() >= self.num_sentences {
702 break;
703 }
704 let redundant = selected
705 .iter()
706 .any(|&si| cosine_sim_rows(&tfidf, *idx, si) > self.redundancy_threshold);
707 if !redundant {
708 selected.push(*idx);
709 }
710 }
711
712 selected.sort_unstable();
714
715 let summary = selected
716 .iter()
717 .map(|&i| sentences[i].as_str())
718 .collect::<Vec<_>>()
719 .join(" ");
720
721 Ok(summary)
722 }
723
724 fn compute_centroid(&self, tfidf: &Array2<f64>) -> Array1<f64> {
727 let mean = tfidf
728 .mean_axis(scirs2_core::ndarray::Axis(0))
729 .unwrap_or_else(|| Array1::zeros(tfidf.ncols()));
730
731 mean.mapv(|v| if v > self.topic_threshold { v } else { 0.0 })
732 }
733}
734
735pub fn rouge_n(hypothesis: &str, reference: &str, n: usize) -> f64 {
763 if n == 0 {
764 return 0.0;
765 }
766 let hyp_tokens = word_tokens(hypothesis);
767 let ref_tokens = word_tokens(reference);
768
769 if ref_tokens.len() < n {
770 return 0.0;
771 }
772
773 let ref_ngrams = count_ngrams(&ref_tokens, n);
775 if ref_ngrams.is_empty() {
776 return 0.0;
777 }
778 let ref_total: usize = ref_ngrams.values().sum();
779
780 let hyp_ngrams = count_ngrams(&hyp_tokens, n);
782
783 let overlap: usize = ref_ngrams
785 .iter()
786 .map(|(gram, &ref_count)| {
787 let hyp_count = hyp_ngrams.get(gram).copied().unwrap_or(0);
788 hyp_count.min(ref_count)
789 })
790 .sum();
791
792 overlap as f64 / ref_total as f64
793}
794
/// Counts every contiguous `n`-gram in `tokens`.
///
/// Returns an empty map when `n` is zero or when `tokens` holds fewer than
/// `n` items.
fn count_ngrams(tokens: &[String], n: usize) -> HashMap<Vec<String>, usize> {
    let mut counts: HashMap<Vec<String>, usize> = HashMap::new();
    // `windows` panics on a size of zero, and a 0-gram carries no
    // information, so that case is answered with an empty map.
    if n == 0 {
        return counts;
    }
    // `windows` yields nothing when the slice is shorter than `n`.
    for gram in tokens.windows(n) {
        *counts.entry(gram.to_vec()).or_default() += 1;
    }
    counts
}
807
808pub fn rouge_l(hypothesis: &str, reference: &str) -> f64 {
831 let hyp_tokens = word_tokens(hypothesis);
832 let ref_tokens = word_tokens(reference);
833
834 if ref_tokens.is_empty() {
835 return 0.0;
836 }
837
838 let lcs_len = lcs_length(&hyp_tokens, &ref_tokens);
839 lcs_len as f64 / ref_tokens.len() as f64
840}
841
/// Length of the longest common subsequence of `a` and `b`.
///
/// Uses the standard dynamic-programming recurrence with two rolling rows,
/// so memory use is proportional to `b.len()`.
fn lcs_length(a: &[String], b: &[String]) -> usize {
    if a.is_empty() || b.is_empty() {
        return 0;
    }

    let width = b.len() + 1;
    // `above` holds the finished previous row, `row` the one being filled.
    // Column 0 is never written and therefore stays 0 in both buffers.
    let mut above = vec![0usize; width];
    let mut row = vec![0usize; width];

    for left in a {
        for (col, right) in b.iter().enumerate() {
            row[col + 1] = if left == right {
                above[col] + 1
            } else {
                above[col + 1].max(row[col])
            };
        }
        // Every cell from column 1 onwards is overwritten before it is read
        // on the next pass, so the recycled buffer needs no explicit reset.
        std::mem::swap(&mut above, &mut row);
    }

    above[b.len()]
}
870
// Unit tests for the summarizers and the ROUGE metrics defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Two short documents used by the multi-document (fusion) tests.
    const MULTI_DOC_A: &str =
        "Rust is a systems programming language. It focuses on safety and performance.";
    const MULTI_DOC_B: &str =
        "Memory safety without a garbage collector is a key goal of Rust. The language also \
         emphasises zero-cost abstractions.";
    // A five-sentence passage used by the single-document tests.
    const LONG_TEXT: &str = "Natural language processing is a field of artificial intelligence. \
        It allows computers to understand and generate human language. \
        Applications include machine translation, chatbots, and sentiment analysis. \
        Deep learning has greatly advanced NLP in recent years. \
        Transformer models such as BERT and GPT are state-of-the-art.";

    // ---- FusionSummarizer ----

    /// Extraction yields sentences whose scores lie in the normalised range.
    #[test]
    fn test_fusion_extract_sentences_nonempty() {
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&docs);
        assert!(!sents.is_empty());
        for s in &sents {
            // Upper bound carries a small tolerance above 1.0.
            assert!(
                (0.0..=1.001).contains(&s.score),
                "score out of range: {}",
                s.score
            );
        }
    }

    /// No documents means no sentences.
    #[test]
    fn test_fusion_extract_empty_docs() {
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&[]);
        assert!(sents.is_empty());
    }

    /// Clustering neither loses nor duplicates sentences.
    #[test]
    fn test_fusion_cluster_basic() {
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&docs);
        let clusters = fs.cluster_sentences(&sents, 2);
        assert!(!clusters.is_empty());
        // The cluster sizes must add up to the number of input sentences.
        let total: usize = clusters.iter().map(|c| c.len()).sum();
        assert_eq!(total, sents.len());
    }

    /// The generated summary never exceeds the requested word budget.
    #[test]
    fn test_fusion_generate_summary_respects_max_words() {
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let sents = fs.extract_sentences(&docs);
        let clusters = fs.cluster_sentences(&sents, 2);
        let summary = fs.generate_summary(&clusters, 10);
        let words: usize = summary.split_whitespace().count();
        assert!(words <= 10, "Expected ≤10 words, got {}", words);
    }

    /// The full pipeline produces a non-empty summary for two documents.
    #[test]
    fn test_fusion_summarize_end_to_end() {
        let docs = vec![MULTI_DOC_A, MULTI_DOC_B];
        let fs = FusionSummarizer::new(2);
        let summary = fs.summarize(&docs, 60).expect("summarize should succeed");
        assert!(!summary.is_empty());
    }

    /// The full pipeline also works with a single input document.
    #[test]
    fn test_fusion_single_document() {
        let docs = vec![LONG_TEXT];
        let fs = FusionSummarizer::new(2);
        let summary = fs.summarize(&docs, 80).expect("should succeed");
        assert!(!summary.is_empty());
    }

    // ---- CompressionSummarizer ----

    /// Compression yields a non-empty result no longer than the input.
    #[test]
    fn test_compression_basic() {
        let cs = CompressionSummarizer::new();
        let sentence = "The very quick brown fox jumped lazily over the fence";
        let compressed = cs.compress_sentence(sentence, 0.5);
        let orig_words: usize = sentence.split_whitespace().count();
        let comp_words: usize = compressed.split_whitespace().count();
        assert!(comp_words <= orig_words);
        assert!(!compressed.is_empty());
    }

    /// A ratio of 1.0 keeps every word.
    #[test]
    fn test_compression_ratio_one_keeps_all() {
        let cs = CompressionSummarizer::new();
        let sentence = "Hello world this is a test sentence";
        let compressed = cs.compress_sentence(sentence, 1.0);
        let orig_words = sentence.split_whitespace().count();
        let comp_words = compressed.split_whitespace().count();
        assert_eq!(comp_words, orig_words);
    }

    /// An empty sentence compresses to an empty string.
    #[test]
    fn test_compression_empty_sentence() {
        let cs = CompressionSummarizer::new();
        let result = cs.compress_sentence("", 0.5);
        assert!(result.is_empty());
    }

    /// A content word of the same length outranks a stop word.
    #[test]
    fn test_compression_importance_stop_word_lower() {
        let cs = CompressionSummarizer::new();
        let tokens: Vec<String> = vec!["the".to_string(), "quick".to_string(), "fox".to_string()];
        let stop_score = cs.importance_score("the", &tokens);
        let content_score = cs.importance_score("fox", &tokens);
        assert!(
            content_score > stop_score,
            "Content word should score higher than stop word"
        );
    }

    // ---- EnhancedCentroidSummarizer ----

    /// A two-sentence summary of the long text is non-empty and shorter than the input.
    #[test]
    fn test_enhanced_centroid_basic() {
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s.summarize(LONG_TEXT).expect("should succeed");
        assert!(!summary.is_empty());
        assert!(summary.len() < LONG_TEXT.len());
    }

    /// Text with fewer sentences than requested is returned unchanged.
    #[test]
    fn test_enhanced_centroid_short_text() {
        let s = EnhancedCentroidSummarizer::new(5);
        let text = "One sentence only.";
        let summary = s.summarize(text).expect("should succeed");
        assert_eq!(summary, text);
    }

    /// Empty input yields an empty summary.
    #[test]
    fn test_enhanced_centroid_empty() {
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s.summarize("").expect("should succeed");
        assert!(summary.is_empty());
    }

    /// Query-focused summarisation produces a non-empty summary.
    #[test]
    fn test_enhanced_centroid_query_focused() {
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s
            .summarize_query_focused(LONG_TEXT, "transformer models BERT GPT", 2)
            .expect("should succeed");
        assert!(!summary.is_empty());
    }

    /// `max_sentences` overrides the sentence count configured at construction.
    #[test]
    fn test_enhanced_centroid_query_focused_max_sentences() {
        let s = EnhancedCentroidSummarizer::new(2);
        let summary = s
            .summarize_query_focused(LONG_TEXT, "deep learning", 1)
            .expect("should succeed");
        let sent_tok = SentenceTokenizer::new();
        // Re-tokenise the summary to count the sentences it contains.
        let sents = sent_tok.tokenize(&summary).expect("ok");
        assert!(sents.len() <= 1);
    }

    // ---- ROUGE-N ----

    /// Identical texts give a unigram recall of 1.0.
    #[test]
    fn test_rouge1_perfect_match() {
        let recall = rouge_n("the cat sat", "the cat sat", 1);
        assert!((recall - 1.0).abs() < 1e-9, "Expected 1.0, got {recall}");
    }

    /// Two of the six reference unigrams are matched.
    #[test]
    fn test_rouge1_partial_overlap() {
        let recall = rouge_n("cat sat", "the cat sat on the mat", 1);
        assert!((recall - 2.0 / 6.0).abs() < 1e-9, "Got {recall}");
    }

    /// Identical texts give a bigram recall of 1.0.
    #[test]
    fn test_rouge2_basic() {
        let recall = rouge_n("the cat sat on the mat", "the cat sat on the mat", 2);
        assert!((recall - 1.0).abs() < 1e-9);
    }

    /// An n-gram order of zero gives 0.0.
    #[test]
    fn test_rouge_n_zero_n() {
        assert_eq!(rouge_n("anything", "reference", 0), 0.0);
    }

    /// An empty reference gives 0.0.
    #[test]
    fn test_rouge_n_empty_reference() {
        assert_eq!(rouge_n("hypothesis", "", 1), 0.0);
    }

    /// An empty hypothesis gives 0.0.
    #[test]
    fn test_rouge_n_empty_hypothesis() {
        assert_eq!(rouge_n("", "the cat sat", 1), 0.0);
    }

    // ---- ROUGE-L ----

    /// Identical texts give an LCS recall of 1.0.
    #[test]
    fn test_rouge_l_perfect_match() {
        let score = rouge_l("the cat sat", "the cat sat");
        assert!((score - 1.0).abs() < 1e-9);
    }

    /// The LCS "cat sat" covers two of the six reference tokens.
    #[test]
    fn test_rouge_l_partial() {
        let score = rouge_l("cat sat", "the cat sat on the mat");
        assert!((score - 2.0 / 6.0).abs() < 1e-9, "Got {score}");
    }

    /// An empty reference gives 0.0.
    #[test]
    fn test_rouge_l_empty_reference() {
        assert_eq!(rouge_l("hypothesis", ""), 0.0);
    }

    /// An empty hypothesis gives 0.0.
    #[test]
    fn test_rouge_l_empty_hypothesis() {
        assert_eq!(rouge_l("", "reference text"), 0.0);
    }

    /// The LCS length does not depend on argument order.
    #[test]
    fn test_lcs_symmetric() {
        let a = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let b = vec!["b".to_string(), "c".to_string(), "d".to_string()];
        let lcs_ab = lcs_length(&a, &b);
        let lcs_ba = lcs_length(&b, &a);
        assert_eq!(lcs_ab, lcs_ba);
    }
}