1use std::collections::HashMap;
7
/// Extracts keywords from text by TF-IDF (term frequency × inverse
/// document frequency) scoring against a background corpus.
pub struct TfIdfKeywordExtractor {
    /// Number of corpus documents each term appears in.
    document_frequencies: HashMap<String, usize>,
    /// Total number of documents in the corpus; kept >= 1 so IDF never
    /// divides by zero.
    total_documents: usize,
    /// Lowercase stopwords dropped during tokenization.
    stopwords: std::collections::HashSet<String>,
}

impl TfIdfKeywordExtractor {
    /// Creates an extractor from precomputed document frequencies.
    ///
    /// `total_documents` is clamped to at least 1 to keep IDF well-defined.
    pub fn new(document_frequencies: HashMap<String, usize>, total_documents: usize) -> Self {
        let stopwords = Self::load_stopwords();
        Self {
            document_frequencies,
            total_documents: total_documents.max(1),
            stopwords,
        }
    }

    /// Creates an extractor with an empty corpus. With no corpus, every
    /// term shares the same IDF, so ranking degenerates to term frequency.
    pub fn new_default() -> Self {
        Self::new(HashMap::new(), 1)
    }

    /// Returns up to `top_k` `(term, tf-idf score)` pairs for `text`,
    /// highest score first. Empty or all-stopword input yields an empty vec.
    pub fn extract_keywords(&self, text: &str, top_k: usize) -> Vec<(String, f32)> {
        let tokens = self.tokenize(text);
        let tf_scores = self.calculate_tf(&tokens);

        let mut tfidf_scores: Vec<(String, f32)> = tf_scores
            .into_iter()
            .map(|(term, tf)| {
                let score = tf * self.calculate_idf(&term);
                (term, score)
            })
            .collect();

        // Descending by score. Scores are finite (tf and idf are both
        // finite), so the NaN fallback is just defensive.
        tfidf_scores
            .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        tfidf_scores.truncate(top_k);

        tfidf_scores
    }

    /// Like [`Self::extract_keywords`], but discards the scores.
    pub fn extract_keyword_strings(&self, text: &str, top_k: usize) -> Vec<String> {
        self.extract_keywords(text, top_k)
            .into_iter()
            .map(|(word, _score)| word)
            .collect()
    }

    /// Splits on whitespace, strips punctuation (keeping `-` and `_`),
    /// lowercases, then drops stopwords, purely numeric tokens, and tokens
    /// shorter than three characters.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|word| {
                word.chars()
                    .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
                    .collect::<String>()
                    .to_lowercase()
            })
            .filter(|word| {
                // Count chars, not bytes, so multi-byte scripts are judged
                // by visible length. A count > 2 also implies non-empty.
                word.chars().count() > 2
                    && !self.stopwords.contains(word)
                    && !word.chars().all(|c| c.is_numeric())
            })
            .collect()
    }

    /// Term frequency: each term's raw count divided by the total token
    /// count (`max(1)` guards the empty-input case).
    fn calculate_tf(&self, tokens: &[String]) -> HashMap<String, f32> {
        let mut term_counts: HashMap<String, usize> = HashMap::new();

        for token in tokens {
            *term_counts.entry(token.clone()).or_insert(0) += 1;
        }

        let total_terms = tokens.len().max(1) as f32;

        term_counts
            .into_iter()
            .map(|(term, count)| (term, count as f32 / total_terms))
            .collect()
    }

    /// Inverse document frequency: `ln(N / df)`, clamped to be
    /// non-negative. Terms absent from the corpus are treated as appearing
    /// in exactly one document, giving them the maximum IDF.
    fn calculate_idf(&self, term: &str) -> f32 {
        let doc_freq = self
            .document_frequencies
            .get(term)
            .copied()
            .unwrap_or(1);
        let idf = (self.total_documents as f32 / doc_freq as f32).ln();
        // df may exceed N if frequencies were supplied externally; never
        // let IDF go negative.
        idf.max(0.0)
    }

    /// Builds the built-in English stopword set.
    fn load_stopwords() -> std::collections::HashSet<String> {
        const STOPWORDS: &[&str] = &[
            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
            "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know",
            "take", "people", "into", "year", "your", "good", "some", "could", "them", "see",
            "other", "than", "then", "now", "look", "only", "come", "its", "over", "think",
            "also", "back", "after", "use", "two", "how", "our", "work", "first", "well",
            "way", "even", "new", "want", "because", "any", "these", "give", "day", "most",
            "us", "is", "was", "are", "been", "has", "had", "were", "said", "did",
        ];

        STOPWORDS.iter().map(|s| s.to_string()).collect()
    }

    /// Adds one document to the corpus: each distinct token in `text`
    /// bumps its document frequency by one, and the document count grows
    /// by one.
    pub fn add_document_to_corpus(&mut self, text: &str) {
        let tokens = self.tokenize(text);
        let unique_terms: std::collections::HashSet<String> = tokens.into_iter().collect();

        for term in unique_terms {
            *self.document_frequencies.entry(term).or_insert(0) += 1;
        }

        self.total_documents += 1;
    }

    /// Returns `(total_documents, number_of_distinct_terms_tracked)`.
    pub fn corpus_stats(&self) -> (usize, usize) {
        (self.total_documents, self.document_frequencies.len())
    }
}
159
160impl Default for TfIdfKeywordExtractor {
161 fn default() -> Self {
162 Self::new_default()
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    // Tokenization lowercases content words and strips stopwords.
    #[test]
    fn test_tokenization() {
        let extractor = TfIdfKeywordExtractor::new_default();
        let tokens = extractor
            .tokenize("Machine learning and artificial intelligence are transforming technology.");

        for kept in ["machine", "learning", "artificial"].iter() {
            assert!(tokens.contains(&kept.to_string()));
        }
        for dropped in ["and", "are"].iter() {
            assert!(!tokens.contains(&dropped.to_string()));
        }
    }

    // TF is raw count over total token count.
    #[test]
    fn test_tf_calculation() {
        let extractor = TfIdfKeywordExtractor::new_default();
        let tokens: Vec<String> = ["machine", "learning", "machine", "learning", "data"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        let tf_scores = extractor.calculate_tf(&tokens);

        for (term, expected) in [("machine", 0.4), ("learning", 0.4), ("data", 0.2)].iter() {
            assert!((tf_scores[*term] - expected).abs() < 0.001);
        }
    }

    // Rarer terms must receive a higher IDF: ln(100/2) > ln(100/50).
    #[test]
    fn test_idf_calculation() {
        let mut doc_freqs = HashMap::new();
        doc_freqs.insert("common".to_string(), 50);
        doc_freqs.insert("rare".to_string(), 2);
        let extractor = TfIdfKeywordExtractor::new(doc_freqs, 100);

        let idf_common = extractor.calculate_idf("common");
        let idf_rare = extractor.calculate_idf("rare");

        assert!(idf_rare > idf_common);
        assert!((idf_common - 0.69).abs() < 0.1);
        assert!((idf_rare - 3.91).abs() < 0.1);
    }

    // End-to-end extraction after building a small corpus.
    #[test]
    fn test_keyword_extraction() {
        let mut extractor = TfIdfKeywordExtractor::new_default();
        for doc in [
            "artificial intelligence is the future",
            "deep learning uses neural networks",
            "natural language processing is important",
        ]
        .iter()
        {
            extractor.add_document_to_corpus(doc);
        }

        let text = "machine learning and deep learning are important topics in artificial intelligence. neural networks and machine learning models are widely used.";

        let keywords = extractor.extract_keywords(text, 5);
        assert!(keywords.len() >= 3);

        let keyword_terms: Vec<&str> = keywords.iter().map(|(w, _)| w.as_str()).collect();
        let found = ["learning", "machine", "neural"]
            .iter()
            .any(|t| keyword_terms.contains(t));
        assert!(found, "Expected high-frequency terms not found. Got: {:?}", keyword_terms);
    }

    // new_default starts the corpus at 1 document, so adding three yields 4.
    #[test]
    fn test_corpus_building() {
        let mut extractor = TfIdfKeywordExtractor::new_default();
        extractor.add_document_to_corpus("machine learning is amazing");
        extractor.add_document_to_corpus("deep learning is powerful");
        extractor.add_document_to_corpus("natural language processing");

        let (total_docs, unique_terms) = extractor.corpus_stats();
        assert_eq!(total_docs, 4);
        assert!(unique_terms > 0);
    }

    // Stopwords never appear in extracted keywords; content words do.
    #[test]
    fn test_stopword_filtering() {
        let extractor = TfIdfKeywordExtractor::new_default();
        let keywords = extractor
            .extract_keyword_strings("The quick brown fox jumps over the lazy dog and the cat", 10);

        for stopword in ["the", "and", "over"].iter() {
            assert!(!keywords.iter().any(|w| w == stopword));
        }
        assert!(keywords.iter().any(|w| w == "quick" || w == "brown" || w == "fox"));
    }
}