use crate::AprenderError;
use std::collections::{HashMap, HashSet};
/// Strategy used by `TextSummarizer` to score sentences before selection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SummarizationMethod {
/// Graph-based ranking: PageRank-style power iteration over a sentence
/// similarity graph (see `textrank_scores`).
TextRank,
/// Each sentence scored by the sum of its terms' TF-IDF weights
/// (see `tfidf_scores`).
TfIdf,
/// Average of the min-max-normalized TextRank and TF-IDF scores
/// (see `hybrid_scores`).
Hybrid,
}
/// Extractive summarizer: scores sentences with the configured method and
/// returns the top-scoring ones in original document order.
#[derive(Debug)]
pub struct TextSummarizer {
// Scoring strategy applied in `summarize`.
method: SummarizationMethod,
// Maximum number of sentences returned in a summary.
max_sentences: usize,
// PageRank damping factor used by TextRank (0.85 by default, set in `new`).
damping_factor: f64,
// Upper bound on TextRank power-iteration sweeps (100 by default).
max_iterations: usize,
// TextRank stops early once every per-sentence score delta is at or
// below this threshold (1e-4 by default).
convergence_threshold: f64,
}
impl TextSummarizer {
    /// Creates a summarizer with default TextRank hyper-parameters:
    /// damping factor 0.85, at most 100 iterations, convergence
    /// threshold 1e-4.
    #[must_use]
    pub fn new(method: SummarizationMethod, max_sentences: usize) -> Self {
        Self {
            method,
            max_sentences,
            damping_factor: 0.85,
            max_iterations: 100,
            convergence_threshold: 0.0001,
        }
    }

    /// Overrides the PageRank damping factor (conventionally in `(0, 1)`;
    /// the value is not validated here).
    #[must_use]
    pub fn with_damping_factor(mut self, factor: f64) -> Self {
        self.damping_factor = factor;
        self
    }

    /// Overrides the maximum number of TextRank power-iteration sweeps.
    #[must_use]
    pub fn with_max_iterations(mut self, max_iter: usize) -> Self {
        self.max_iterations = max_iter;
        self
    }

    /// Produces an extractive summary of `text`.
    ///
    /// Sentences are split on `.`, `!` and `?`, scored with the configured
    /// method, and the top `max_sentences` are returned in their original
    /// document order. Texts with at most `max_sentences` sentences are
    /// returned unchanged.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept so future scoring
    /// backends can report errors without breaking callers.
    pub fn summarize(&self, text: &str) -> Result<Vec<String>, AprenderError> {
        let sentences = Self::split_sentences(text);
        if sentences.is_empty() {
            return Ok(Vec::new());
        }
        // Nothing to rank: the whole text already fits in the summary.
        if sentences.len() <= self.max_sentences {
            return Ok(sentences);
        }
        let scores = match self.method {
            SummarizationMethod::TextRank => self.textrank_scores(&sentences),
            SummarizationMethod::TfIdf => self.tfidf_scores(&sentences),
            SummarizationMethod::Hybrid => self.hybrid_scores(&sentences),
        };
        let mut selected: Vec<(usize, String)> =
            Self::select_top_sentences(&scores, self.max_sentences)
                .into_iter()
                .map(|idx| (idx, sentences[idx].clone()))
                .collect();
        // Restore document order so the summary reads naturally.
        selected.sort_by_key(|(idx, _)| *idx);
        Ok(selected.into_iter().map(|(_, sent)| sent).collect())
    }

    /// Splits `text` into trimmed, non-empty sentences on `.`, `!` and `?`.
    /// Note: abbreviations ("e.g.") also split; acceptable for ranking.
    fn split_sentences(text: &str) -> Vec<String> {
        text.split(['.', '!', '?'])
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect()
    }

    /// Scores sentences with weighted PageRank over the sentence similarity
    /// graph, using Jacobi iteration until convergence or `max_iterations`.
    fn textrank_scores(&self, sentences: &[String]) -> Vec<f64> {
        let n = sentences.len();
        if n == 0 {
            return Vec::new();
        }
        let similarity = self.build_similarity_matrix(sentences);
        // Hoisted loop-invariant: total outbound edge weight per node.
        // Recomputing this inside the (i, j) loop would cost O(n^3) per
        // sweep instead of O(n^2).
        let outbound_sums: Vec<f64> = (0..n)
            .map(|j| (0..n).filter(|&k| k != j).map(|k| similarity[j][k]).sum())
            .collect();
        // Uniform initial distribution.
        let mut scores = vec![1.0 / n as f64; n];
        let mut new_scores = vec![0.0; n];
        for _ in 0..self.max_iterations {
            let mut converged = true;
            for i in 0..n {
                // Teleportation term of PageRank.
                let mut score = (1.0 - self.damping_factor) / n as f64;
                for j in 0..n {
                    // Skip self-loops and dangling nodes (no outbound weight).
                    if i != j && outbound_sums[j] > 1e-10 {
                        score += self.damping_factor
                            * (similarity[j][i] / outbound_sums[j])
                            * scores[j];
                    }
                }
                new_scores[i] = score;
                if (new_scores[i] - scores[i]).abs() > self.convergence_threshold {
                    converged = false;
                }
            }
            scores.clone_from_slice(&new_scores);
            if converged {
                break;
            }
        }
        scores
    }

    /// Scores each sentence by the sum of its terms' TF-IDF weights, with
    /// TF normalized by the sentence's most frequent term.
    #[allow(clippy::unused_self)]
    fn tfidf_scores(&self, sentences: &[String]) -> Vec<f64> {
        if sentences.is_empty() {
            return Vec::new();
        }
        let tokenized: Vec<Vec<String>> = sentences.iter().map(|s| Self::tokenize(s)).collect();
        let idf = Self::compute_idf(&tokenized);
        tokenized
            .iter()
            .map(|tokens| {
                if tokens.is_empty() {
                    return 0.0;
                }
                // Raw term counts within this sentence.
                let mut tf: HashMap<&str, f64> = HashMap::new();
                for token in tokens {
                    *tf.entry(token.as_str()).or_insert(0.0) += 1.0;
                }
                // Max-normalized TF so long sentences are not automatically
                // favored by sheer term count.
                let max_tf = tf.values().copied().fold(0.0, f64::max);
                if max_tf > 0.0 {
                    for value in tf.values_mut() {
                        *value /= max_tf;
                    }
                }
                tf.iter()
                    .map(|(term, &tf_val)| tf_val * idf.get(*term).copied().unwrap_or(0.0))
                    .sum()
            })
            .collect()
    }

    /// Averages the min-max-normalized TextRank and TF-IDF scores.
    fn hybrid_scores(&self, sentences: &[String]) -> Vec<f64> {
        let textrank_norm = Self::normalize(&self.textrank_scores(sentences));
        let tfidf_norm = Self::normalize(&self.tfidf_scores(sentences));
        textrank_norm
            .iter()
            .zip(tfidf_norm.iter())
            .map(|(tr, tf)| (tr + tf) / 2.0)
            .collect()
    }

    /// Builds the pairwise sentence-similarity matrix using the Dice
    /// coefficient over token sets: `2|A ∩ B| / (|A| + |B|)`.
    #[allow(clippy::unused_self)]
    fn build_similarity_matrix(&self, sentences: &[String]) -> Vec<Vec<f64>> {
        let n = sentences.len();
        let mut similarity = vec![vec![0.0; n]; n];
        let tokenized: Vec<HashSet<String>> = sentences
            .iter()
            .map(|s| Self::tokenize(s).into_iter().collect())
            .collect();
        for i in 0..n {
            // The measure is symmetric, so one pass over the upper triangle
            // fills both halves; the diagonal stays 0.0 (no self-loops).
            for j in (i + 1)..n {
                let intersection = tokenized[i].intersection(&tokenized[j]).count() as f64;
                let union_size = tokenized[i].len() + tokenized[j].len();
                if union_size > 0 {
                    let sim = (2.0 * intersection) / union_size as f64;
                    similarity[i][j] = sim;
                    similarity[j][i] = sim;
                }
            }
        }
        similarity
    }

    /// Lowercases `text` and splits it into non-empty alphanumeric tokens
    /// (whitespace and ASCII punctuation are separators).
    fn tokenize(text: &str) -> Vec<String> {
        text.to_lowercase()
            .split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
            .filter(|s| !s.is_empty())
            .map(String::from)
            .collect()
    }

    /// Computes smoothed inverse document frequency per term:
    /// `ln((N + 1) / (df + 1)) + 1`, which keeps every weight positive and
    /// avoids division by zero for terms present in all documents.
    fn compute_idf(documents: &[Vec<String>]) -> HashMap<String, f64> {
        let n = documents.len() as f64;
        let mut document_freq: HashMap<String, usize> = HashMap::new();
        for doc in documents {
            // Count each term at most once per document.
            let unique_terms: HashSet<&str> = doc.iter().map(String::as_str).collect();
            for term in unique_terms {
                *document_freq.entry(term.to_string()).or_insert(0) += 1;
            }
        }
        document_freq
            .into_iter()
            .map(|(term, df)| {
                let idf = ((n + 1.0) / (df as f64 + 1.0)).ln() + 1.0;
                (term, idf)
            })
            .collect()
    }

    /// Min-max normalizes `scores` into `[0, 1]`; a (near-)constant input
    /// maps every entry to 0.5 to avoid dividing by a zero range.
    fn normalize(scores: &[f64]) -> Vec<f64> {
        if scores.is_empty() {
            return Vec::new();
        }
        let min_score = scores.iter().copied().fold(f64::INFINITY, f64::min);
        let max_score = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let range = max_score - min_score;
        if range < 1e-10 {
            return vec![0.5; scores.len()];
        }
        scores.iter().map(|&s| (s - min_score) / range).collect()
    }

    /// Returns the indices of the `k` highest-scoring sentences, highest
    /// first. The stable sort keeps earlier sentences ahead on ties; NaN
    /// scores compare as equal rather than panicking.
    fn select_top_sentences(scores: &[f64], k: usize) -> Vec<usize> {
        let mut indexed: Vec<(usize, f64)> = scores
            .iter()
            .enumerate()
            .map(|(idx, &score)| (idx, score))
            .collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        indexed.into_iter().take(k).map(|(idx, _)| idx).collect()
    }
}
#[cfg(test)]
#[path = "summarize_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "summarize_contract_falsify.rs"]
mod summarize_contract_falsify;