use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vectorize::{TfidfVectorizer, Vectorizer};
use std::collections::{HashMap, HashSet};
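/// A keyword or key phrase extracted from a text, together with its relevance score.
///
/// Scores are method-specific (average TF-IDF weight, TextRank centrality, or RAKE
/// degree-to-frequency ratio), so they are only comparable within a single extraction.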
#[derive(Debug, Clone)]
pub struct Keyword {
pub text: String,
pub score: f64,
}
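/// Selects the algorithm used by [`extract_keywords`].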
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeywordMethod {
TfIdf,
TextRank,
Rake,
}
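/// Extracts up to `top_k` keywords from `text` using the chosen method.
///
/// Empty or whitespace-only input yields an empty `Vec`. Results are sorted by
/// descending score. A minimal usage sketch (marked `ignore` so it is not run
/// as a doctest):
///
/// ```ignore
/// let keywords = extract_keywords("Rust is fast. Rust is safe.", KeywordMethod::Rake, 3)
///     .expect("extraction should succeed");
/// for kw in &keywords {
///     println!("{}: {:.3}", kw.text, kw.score);
/// }
/// ```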
pub fn extract_keywords(text: &str, method: KeywordMethod, top_k: usize) -> Result<Vec<Keyword>> {
match method {
KeywordMethod::TfIdf => {
let extractor = TfIdfKeywordExtractor::new();
extractor.extract(text, top_k)
}
KeywordMethod::TextRank => {
let extractor = TextRankKeywordExtractor::new();
extractor.extract(text, top_k)
}
KeywordMethod::Rake => {
let extractor = RakeKeywordExtractor::new();
extractor.extract(text, top_k)
}
}
}
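/// TF-IDF based keyword extraction: the text is split into sentences, a
/// [`TfidfVectorizer`] is fitted on them, and each term is scored by its TF-IDF
/// weight averaged over all sentences.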
pub struct TfIdfKeywordExtractor {
tokenizer: Box<dyn Tokenizer + Send + Sync>,
min_token_len: usize,
}
impl TfIdfKeywordExtractor {
pub fn new() -> Self {
Self {
tokenizer: Box::new(WordTokenizer::default()),
min_token_len: 2,
}
}
pub fn with_min_token_len(mut self, len: usize) -> Self {
self.min_token_len = len;
self
}
pub fn extract(&self, text: &str, top_k: usize) -> Result<Vec<Keyword>> {
if text.trim().is_empty() {
return Ok(Vec::new());
}
let sentences = split_sentences(text);
if sentences.is_empty() {
return Ok(Vec::new());
}
let sentence_refs: Vec<&str> = sentences.iter().map(|s| s.as_str()).collect();
let mut vectorizer = TfidfVectorizer::default();
vectorizer.fit(&sentence_refs)?;
let tfidf_matrix = vectorizer.transform_batch(&sentence_refs)?;
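// Assumption: the column order of `tfidf_matrix` matches the first-seen token
// order that `build_vocabulary` reproduces below. If `TfidfVectorizer` orders
// its vocabulary differently, keyword texts would not line up with their scores.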
let vocab = build_vocabulary(&sentences, &*self.tokenizer)?;
let n_terms = tfidf_matrix.ncols();
let n_docs = tfidf_matrix.nrows();
if n_terms == 0 || n_docs == 0 {
return Ok(Vec::new());
}
let mut avg_scores: Vec<f64> = Vec::with_capacity(n_terms);
for col_idx in 0..n_terms {
let col_sum: f64 = (0..n_docs).map(|row| tfidf_matrix[[row, col_idx]]).sum();
avg_scores.push(col_sum / n_docs as f64);
}
let mut keyword_scores: Vec<Keyword> = Vec::new();
for (idx, &score) in avg_scores.iter().enumerate() {
if score <= 0.0 {
continue;
}
if let Some(term) = vocab.get(&idx) {
if term.len() >= self.min_token_len {
keyword_scores.push(Keyword {
text: term.clone(),
score,
});
}
}
}
keyword_scores.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
keyword_scores.truncate(top_k);
Ok(keyword_scores)
}
}
impl Default for TfIdfKeywordExtractor {
fn default() -> Self {
Self::new()
}
}
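/// TextRank based keyword extraction: content words are connected when they
/// co-occur within a sliding window, a PageRank-style iteration ranks the
/// resulting graph, and adjacent high-scoring words are merged into phrases.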
pub struct TextRankKeywordExtractor {
tokenizer: Box<dyn Tokenizer + Send + Sync>,
window_size: usize,
damping: f64,
max_iterations: usize,
convergence_threshold: f64,
min_token_len: usize,
}
impl TextRankKeywordExtractor {
pub fn new() -> Self {
Self {
tokenizer: Box::new(WordTokenizer::default()),
window_size: 4,
damping: 0.85,
max_iterations: 100,
convergence_threshold: 1e-5,
min_token_len: 2,
}
}
pub fn with_window_size(mut self, size: usize) -> Result<Self> {
if size < 2 {
return Err(TextError::InvalidInput(
"Window size must be at least 2".to_string(),
));
}
self.window_size = size;
Ok(self)
}
pub fn with_damping(mut self, d: f64) -> Result<Self> {
if !(0.0..=1.0).contains(&d) {
return Err(TextError::InvalidInput(
"Damping factor must be between 0 and 1".to_string(),
));
}
self.damping = d;
Ok(self)
}
pub fn extract(&self, text: &str, top_k: usize) -> Result<Vec<Keyword>> {
if text.trim().is_empty() {
return Ok(Vec::new());
}
let tokens = self.tokenizer.tokenize(text)?;
let filtered: Vec<String> = tokens
.into_iter()
.filter(|t| t.len() >= self.min_token_len && !is_stopword(t))
.collect();
if filtered.is_empty() {
return Ok(Vec::new());
}
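// Build an undirected, weighted co-occurrence graph over the filtered tokens.
// Note that inputs with fewer tokens than `window_size` produce no windows and
// therefore no keywords.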
let mut graph: HashMap<String, HashMap<String, f64>> = HashMap::new();
for window in filtered.windows(self.window_size) {
for i in 0..window.len() {
for j in (i + 1)..window.len() {
let a = &window[i];
let b = &window[j];
*graph
.entry(a.clone())
.or_default()
.entry(b.clone())
.or_insert(0.0) += 1.0;
*graph
.entry(b.clone())
.or_default()
.entry(a.clone())
.or_insert(0.0) += 1.0;
}
}
}
let nodes: Vec<String> = graph.keys().cloned().collect();
let n = nodes.len();
if n == 0 {
return Ok(Vec::new());
}
let node_idx: HashMap<&str, usize> = nodes
.iter()
.enumerate()
.map(|(i, w)| (w.as_str(), i))
.collect();
let mut scores = vec![1.0 / n as f64; n];
let out_sums: Vec<f64> = nodes
.iter()
.map(|node| {
graph
.get(node)
.map(|neighbors| neighbors.values().sum::<f64>())
.unwrap_or(0.0)
})
.collect();
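// Weighted PageRank iteration:
// new_score[i] = (1 - d) / n + d * sum_j (w(j, i) / out(j)) * score[j]
// where d is the damping factor and out(j) is the total outgoing edge weight of node j.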
for _ in 0..self.max_iterations {
let mut new_scores = vec![(1.0 - self.damping) / n as f64; n];
for (j, node_j) in nodes.iter().enumerate() {
if out_sums[j] <= 0.0 {
continue;
}
if let Some(neighbors) = graph.get(node_j) {
for (neighbor, weight) in neighbors {
if let Some(&i) = node_idx.get(neighbor.as_str()) {
new_scores[i] += self.damping * (weight / out_sums[j]) * scores[j];
}
}
}
}
let diff: f64 = scores
.iter()
.zip(new_scores.iter())
.map(|(a, b)| (a - b).abs())
.sum();
scores = new_scores;
if diff < self.convergence_threshold {
break;
}
}
let mut word_scores: HashMap<String, f64> = HashMap::new();
for (i, node) in nodes.iter().enumerate() {
word_scores.insert(node.clone(), scores[i]);
}
let all_tokens = self.tokenizer.tokenize(text)?;
let keywords = merge_adjacent_keywords(&all_tokens, &word_scores, top_k);
Ok(keywords)
}
}
impl Default for TextRankKeywordExtractor {
fn default() -> Self {
Self::new()
}
}
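/// Merges consecutive scored words back into phrases, summing their word scores,
/// and returns the top `top_k` unique phrases by score.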
fn merge_adjacent_keywords(
tokens: &[String],
word_scores: &HashMap<String, f64>,
top_k: usize,
) -> Vec<Keyword> {
let mut phrases: Vec<(Vec<String>, f64)> = Vec::new();
let mut current_phrase: Vec<String> = Vec::new();
let mut current_score: f64 = 0.0;
for token in tokens {
if let Some(&score) = word_scores.get(token) {
current_phrase.push(token.clone());
current_score += score;
} else if !current_phrase.is_empty() {
phrases.push((std::mem::take(&mut current_phrase), current_score));
current_score = 0.0;
}
}
if !current_phrase.is_empty() {
phrases.push((current_phrase, current_score));
}
let mut seen: HashSet<String> = HashSet::new();
let mut keywords: Vec<Keyword> = Vec::new();
for (words, score) in phrases {
let phrase_text = words.join(" ");
if seen.contains(&phrase_text) {
continue;
}
seen.insert(phrase_text.clone());
keywords.push(Keyword {
text: phrase_text,
score,
});
}
keywords.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
keywords.truncate(top_k);
keywords
}
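/// RAKE (Rapid Automatic Keyword Extraction): candidate phrases are maximal runs
/// of non-stopwords, each word is scored as degree(word) / frequency(word), and a
/// phrase scores the sum of its word scores.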
pub struct RakeKeywordExtractor {
min_phrase_len: usize,
max_phrase_len: usize,
min_word_len: usize,
}
impl RakeKeywordExtractor {
pub fn new() -> Self {
Self {
min_phrase_len: 1,
max_phrase_len: 4,
min_word_len: 2,
}
}
pub fn with_min_phrase_len(mut self, len: usize) -> Self {
self.min_phrase_len = len;
self
}
pub fn with_max_phrase_len(mut self, len: usize) -> Self {
self.max_phrase_len = len;
self
}
pub fn extract(&self, text: &str, top_k: usize) -> Result<Vec<Keyword>> {
if text.trim().is_empty() {
return Ok(Vec::new());
}
let candidates = self.generate_candidates(text);
if candidates.is_empty() {
return Ok(Vec::new());
}
let mut word_freq: HashMap<String, f64> = HashMap::new();
let mut word_degree: HashMap<String, f64> = HashMap::new();
for phrase in &candidates {
let words: Vec<&str> = phrase
.split_whitespace()
.filter(|w| w.len() >= self.min_word_len)
.collect();
let degree = words.len() as f64;
for word in &words {
let w = word.to_lowercase();
*word_freq.entry(w.clone()).or_insert(0.0) += 1.0;
*word_degree.entry(w).or_insert(0.0) += degree;
}
}
let mut word_scores: HashMap<String, f64> = HashMap::new();
for (word, freq) in &word_freq {
let degree = word_degree.get(word).copied().unwrap_or(0.0);
if *freq > 0.0 {
word_scores.insert(word.clone(), degree / freq);
}
}
let mut phrase_scores: Vec<Keyword> = Vec::new();
let mut seen: HashSet<String> = HashSet::new();
for phrase in &candidates {
let normalized = phrase.to_lowercase();
if seen.contains(&normalized) {
continue;
}
seen.insert(normalized.clone());
let words: Vec<&str> = normalized
.split_whitespace()
.filter(|w| w.len() >= self.min_word_len)
.collect();
if words.is_empty() {
continue;
}
let score: f64 = words
.iter()
.map(|w| word_scores.get(*w).copied().unwrap_or(0.0))
.sum();
phrase_scores.push(Keyword {
text: normalized,
score,
});
}
phrase_scores.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
phrase_scores.truncate(top_k);
Ok(phrase_scores)
}
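/// Lowercases the text, splits it on non-alphanumeric delimiters (apostrophes are
/// kept), and collects maximal runs of non-stopwords as candidate phrases.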
fn generate_candidates(&self, text: &str) -> Vec<String> {
let lower = text.to_lowercase();
let mut candidates: Vec<String> = Vec::new();
let mut current_phrase: Vec<String> = Vec::new();
for word in lower.split(|c: char| !c.is_alphanumeric() && c != '\'') {
let trimmed = word.trim();
if trimmed.is_empty() {
if !current_phrase.is_empty() {
self.add_candidate(&mut candidates, &current_phrase);
current_phrase.clear();
}
continue;
}
if is_stopword(trimmed) {
if !current_phrase.is_empty() {
self.add_candidate(&mut candidates, &current_phrase);
current_phrase.clear();
}
} else {
current_phrase.push(trimmed.to_string());
}
}
if !current_phrase.is_empty() {
self.add_candidate(&mut candidates, &current_phrase);
}
candidates
}
fn add_candidate(&self, candidates: &mut Vec<String>, phrase_words: &[String]) {
if phrase_words.len() < self.min_phrase_len || phrase_words.len() > self.max_phrase_len {
return;
}
let phrase = phrase_words.join(" ");
if phrase
.split_whitespace()
.any(|w| w.len() >= self.min_word_len)
{
candidates.push(phrase);
}
}
}
impl Default for RakeKeywordExtractor {
fn default() -> Self {
Self::new()
}
}
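/// Naive sentence splitter that breaks on '.', '!' and '?'. Abbreviations and
/// decimal points are not handled.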
fn split_sentences(text: &str) -> Vec<String> {
let mut sentences = Vec::new();
let mut current = String::new();
for ch in text.chars() {
current.push(ch);
if ch == '.' || ch == '!' || ch == '?' {
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
sentences.push(trimmed);
}
current.clear();
}
}
let trimmed = current.trim().to_string();
if !trimmed.is_empty() {
sentences.push(trimmed);
}
sentences
}
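/// Builds an index-to-term map by assigning indices to tokens in first-seen order
/// across the given sentences.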
fn build_vocabulary(
sentences: &[String],
tokenizer: &dyn Tokenizer,
) -> Result<HashMap<usize, String>> {
let mut term_to_idx: HashMap<String, usize> = HashMap::new();
let mut next_idx: usize = 0;
for sentence in sentences {
let tokens = tokenizer.tokenize(sentence)?;
for token in tokens {
if let std::collections::hash_map::Entry::Vacant(e) = term_to_idx.entry(token) {
e.insert(next_idx);
next_idx += 1;
}
}
}
let idx_to_term: HashMap<usize, String> =
term_to_idx.into_iter().map(|(t, i)| (i, t)).collect();
Ok(idx_to_term)
}
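/// Returns `true` if `word` (compared case-insensitively) is in the built-in
/// English stopword list.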
fn is_stopword(word: &str) -> bool {
const STOPWORDS: &[&str] = &[
"a", "an", "the", "and", "or", "but", "if", "in", "on", "at", "to", "for", "of", "with",
"by", "from", "as", "is", "was", "are", "were", "been", "be", "have", "has", "had", "do",
"does", "did", "will", "would", "shall", "should", "may", "might", "must", "can", "could",
"not", "no", "nor", "so", "than", "that", "this", "these", "those", "it", "its", "i", "me",
"my", "we", "us", "our", "you", "your", "he", "him", "his", "she", "her", "they", "them",
"their", "what", "which", "who", "whom", "when", "where", "why", "how", "all", "each",
"every", "both", "few", "more", "most", "other", "some", "such", "only", "own", "same",
"also", "just", "about", "above", "after", "again", "against", "any", "because", "before",
"below", "between", "during", "further", "here", "into", "once", "out", "over", "then",
"there", "through", "under", "until", "up", "very", "while",
];
STOPWORDS.contains(&word.to_lowercase().as_str())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tfidf_extracts_keywords() {
let text = "Machine learning is a powerful tool. \
Machine learning algorithms process data efficiently. \
Deep learning extends machine learning with neural networks.";
let keywords = extract_keywords(text, KeywordMethod::TfIdf, 5)
.expect("TF-IDF extraction should succeed");
assert!(!keywords.is_empty());
assert!(keywords.len() <= 5);
for pair in keywords.windows(2) {
assert!(pair[0].score >= pair[1].score);
}
}
#[test]
fn test_tfidf_empty_text() {
let result =
extract_keywords("", KeywordMethod::TfIdf, 5).expect("Empty text should not error");
assert!(result.is_empty());
}
#[test]
fn test_tfidf_single_sentence() {
let result = extract_keywords(
"Rust programming language is fast and safe.",
KeywordMethod::TfIdf,
3,
)
.expect("Single sentence should succeed");
assert!(!result.is_empty());
}
#[test]
fn test_tfidf_respects_top_k() {
let text = "Alpha beta gamma delta epsilon zeta eta theta iota kappa. \
Alpha beta gamma delta epsilon zeta eta theta iota kappa.";
let result =
extract_keywords(text, KeywordMethod::TfIdf, 3).expect("Extraction should succeed");
assert!(result.len() <= 3);
}
#[test]
fn test_tfidf_min_token_len() {
let extractor = TfIdfKeywordExtractor::new().with_min_token_len(5);
let text = "AI and ML are big. Artificial intelligence is growing.";
let result = extractor
.extract(text, 10)
.expect("Extraction should succeed");
for kw in &result {
for word in kw.text.split_whitespace() {
assert!(word.len() >= 5, "Word '{}' is too short", word);
}
}
}
#[test]
fn test_textrank_extracts_keywords() {
let text = "Natural language processing enables computers to understand human language. \
Text mining and information retrieval are subfields of natural language processing. \
Sentiment analysis determines the emotional tone of text.";
let keywords = extract_keywords(text, KeywordMethod::TextRank, 5)
.expect("TextRank extraction should succeed");
assert!(!keywords.is_empty());
assert!(keywords.len() <= 5);
}
#[test]
fn test_textrank_empty_text() {
let result =
extract_keywords("", KeywordMethod::TextRank, 5).expect("Empty text should not error");
assert!(result.is_empty());
}
#[test]
fn test_textrank_scores_descending() {
let text = "Graph algorithms are fundamental in computer science. \
PageRank is a famous graph algorithm. \
Many applications use graph-based methods.";
let keywords = extract_keywords(text, KeywordMethod::TextRank, 10)
.expect("TextRank extraction should succeed");
for pair in keywords.windows(2) {
assert!(pair[0].score >= pair[1].score);
}
}
#[test]
fn test_textrank_window_size() {
let extractor = TextRankKeywordExtractor::new()
.with_window_size(2)
.expect("Window size 2 should be valid");
let text = "Alpha beta gamma delta epsilon. Alpha beta gamma delta.";
let result = extractor
.extract(text, 5)
.expect("Extraction should succeed");
assert!(!result.is_empty());
}
#[test]
fn test_textrank_invalid_window() {
let result = TextRankKeywordExtractor::new().with_window_size(0);
assert!(result.is_err());
}
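#[test]
fn test_textrank_invalid_damping() {
// `with_damping` rejects values outside the closed interval [0, 1].
assert!(TextRankKeywordExtractor::new().with_damping(1.5).is_err());
assert!(TextRankKeywordExtractor::new().with_damping(-0.1).is_err());
}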
#[test]
fn test_rake_extracts_keywords() {
let text =
"Compatibility of systems of linear constraints over the set of natural numbers. \
Criteria of compatibility of a system of linear Diophantine equations.";
let keywords =
extract_keywords(text, KeywordMethod::Rake, 5).expect("RAKE extraction should succeed");
assert!(!keywords.is_empty());
assert!(keywords.len() <= 5);
}
#[test]
fn test_rake_empty_text() {
let result =
extract_keywords("", KeywordMethod::Rake, 5).expect("Empty text should not error");
assert!(result.is_empty());
}
#[test]
fn test_rake_phrase_scoring() {
let text = "Machine learning algorithms are important. \
Deep learning algorithms are even more powerful. \
Algorithms drive modern artificial intelligence.";
let keywords = extract_keywords(text, KeywordMethod::Rake, 10)
.expect("RAKE extraction should succeed");
assert!(!keywords.is_empty());
for pair in keywords.windows(2) {
assert!(pair[0].score >= pair[1].score);
}
}
#[test]
fn test_rake_stopword_splitting() {
let text = "The quick brown fox and the lazy dog.";
let extractor = RakeKeywordExtractor::new();
let candidates = extractor.generate_candidates(text);
for candidate in &candidates {
for word in candidate.split_whitespace() {
assert!(!is_stopword(word), "'{}' is a stopword", word);
}
}
}
#[test]
fn test_rake_max_phrase_len() {
let extractor = RakeKeywordExtractor::new().with_max_phrase_len(2);
let text = "Advanced machine learning algorithms improve natural language processing.";
let result = extractor
.extract(text, 10)
.expect("Extraction should succeed");
for kw in &result {
let word_count = kw.text.split_whitespace().count();
assert!(word_count <= 2, "Phrase '{}' exceeds max length", kw.text);
}
}
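#[test]
fn test_rake_min_phrase_len() {
// Single-word candidates should be dropped when the minimum phrase length is 2.
let extractor = RakeKeywordExtractor::new().with_min_phrase_len(2);
let text = "Rust is a fast language for systems programming.";
let result = extractor
.extract(text, 10)
.expect("Extraction should succeed");
assert!(!result.is_empty());
for kw in &result {
assert!(
kw.text.split_whitespace().count() >= 2,
"Phrase '{}' is shorter than the minimum length",
kw.text
);
}
}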
#[test]
fn test_all_methods_non_empty_for_real_text() {
let text = "Rust is a systems programming language focused on safety and performance. \
The Rust compiler prevents data races and memory errors at compile time. \
Many developers choose Rust for building reliable software.";
for method in &[
KeywordMethod::TfIdf,
KeywordMethod::TextRank,
KeywordMethod::Rake,
] {
let keywords = extract_keywords(text, *method, 5).expect("Extraction should succeed");
assert!(
!keywords.is_empty(),
"Method {:?} returned empty for real text",
method
);
}
}
#[test]
fn test_all_methods_handle_whitespace_only() {
for method in &[
KeywordMethod::TfIdf,
KeywordMethod::TextRank,
KeywordMethod::Rake,
] {
let result =
extract_keywords(" \t\n ", *method, 5).expect("Whitespace should not error");
assert!(result.is_empty());
}
}
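#[test]
fn test_split_sentences_basic() {
let sentences = split_sentences("First sentence. Second one! Third?");
assert_eq!(sentences.len(), 3);
assert_eq!(sentences[0], "First sentence.");
}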
}