// graphrag_core/text/keyword_extraction.rs
//! Real TF-IDF keyword extraction
//!
//! This module implements the actual TF-IDF (Term Frequency-Inverse Document
//! Frequency) algorithm for keyword extraction, not mock/placeholder implementations.

use std::collections::HashMap;

/// TF-IDF based keyword extractor.
///
/// Scores each term by its normalized frequency in the input text (TF),
/// weighted by how rare the term is across the corpus (IDF), and returns
/// the highest-scoring terms.
pub struct TfIdfKeywordExtractor {
    /// Document frequencies: how many documents contain each term.
    document_frequencies: HashMap<String, usize>,
    /// Total number of documents in the corpus (always >= 1).
    total_documents: usize,
    /// Stop words to ignore during tokenization.
    stopwords: std::collections::HashSet<String>,
}

impl TfIdfKeywordExtractor {
    /// Create a new TF-IDF extractor from precomputed corpus statistics.
    ///
    /// `total_documents` is clamped to at least 1 so `calculate_idf` never
    /// divides by zero.
    pub fn new(document_frequencies: HashMap<String, usize>, total_documents: usize) -> Self {
        let stopwords = Self::load_stopwords();
        Self {
            document_frequencies,
            total_documents: total_documents.max(1),
            stopwords,
        }
    }

    /// Create with default stopwords and empty IDF data (for single-document use).
    ///
    /// NOTE: with an empty corpus every unseen term gets IDF = ln(1/1) = 0,
    /// so all TF-IDF scores are 0 until documents are added via
    /// `add_document_to_corpus`. Keyword *ranking* then falls back to the
    /// alphabetical tie-break in `extract_keywords`.
    pub fn new_default() -> Self {
        Self::new(HashMap::new(), 1)
    }

    /// Extract keywords using TF-IDF scoring.
    ///
    /// Returns up to `top_k` `(term, score)` pairs sorted by score (highest
    /// first). Equal scores are broken alphabetically so the output is
    /// deterministic even though term counts come from a `HashMap`.
    pub fn extract_keywords(&self, text: &str, top_k: usize) -> Vec<(String, f32)> {
        // 1. Tokenize and calculate term frequencies.
        let tokens = self.tokenize(text);
        let tf_scores = self.calculate_tf(&tokens);

        // 2. Calculate TF-IDF scores.
        let mut tfidf_scores: Vec<(String, f32)> = tf_scores
            .into_iter()
            .map(|(term, tf)| {
                let idf = self.calculate_idf(&term);
                let tfidf = tf * idf;
                (term, tfidf)
            })
            .collect();

        // 3. Sort by score descending, tie-break by term, and take top-k.
        //    Scores are finite (TF and IDF are finite), so `partial_cmp`
        //    only falls back to `Equal` defensively.
        tfidf_scores.sort_unstable_by(|a, b| {
            b.1.partial_cmp(&a.1)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.0.cmp(&b.0))
        });
        tfidf_scores.truncate(top_k);

        tfidf_scores
    }

    /// Extract just keyword strings (without scores).
    pub fn extract_keyword_strings(&self, text: &str, top_k: usize) -> Vec<String> {
        self.extract_keywords(text, top_k)
            .into_iter()
            .map(|(word, _score)| word)
            .collect()
    }

    /// Tokenize text into lowercase words.
    ///
    /// Punctuation is stripped except `-` and `_` (kept for compound terms
    /// like "test-case"). Tokens must be longer than 2 bytes, contain at
    /// least one alphanumeric character, and must not be stopwords or pure
    /// numbers.
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|word| {
                // Remove punctuation (keeping '-' and '_') and lowercase.
                word.chars()
                    .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
                    .collect::<String>()
                    .to_lowercase()
            })
            .filter(|word| {
                // Length > 2 implies non-empty. The `any(is_alphanumeric)`
                // check rejects tokens like "---" or "___" that would
                // otherwise survive the punctuation filter above.
                word.len() > 2
                    && !self.stopwords.contains(word)
                    && !word.chars().all(|c| c.is_numeric())
                    && word.chars().any(|c| c.is_alphanumeric())
            })
            .collect()
    }

    /// Calculate term frequency (TF) using normalized frequency.
    ///
    /// TF = (count of term in document) / (total terms in document)
    fn calculate_tf(&self, tokens: &[String]) -> HashMap<String, f32> {
        let mut term_counts: HashMap<String, usize> = HashMap::new();

        // Count occurrences of each token.
        for token in tokens {
            *term_counts.entry(token.clone()).or_insert(0) += 1;
        }

        // `.max(1)` guards against division by zero on empty input.
        let total_terms = tokens.len().max(1) as f32;

        // Normalize by total terms.
        term_counts
            .into_iter()
            .map(|(term, count)| (term, count as f32 / total_terms))
            .collect()
    }

    /// Calculate inverse document frequency (IDF).
    ///
    /// IDF = ln(total_documents / documents_containing_term)
    ///
    /// A term absent from the corpus is treated as appearing in exactly one
    /// document (rare term), which gives it the maximum possible IDF.
    fn calculate_idf(&self, term: &str) -> f32 {
        let doc_freq = self
            .document_frequencies
            .get(term)
            .copied()
            .unwrap_or(1); // Default to 1 if not seen (rare term).

        let idf = (self.total_documents as f32 / doc_freq as f32).ln();
        idf.max(0.0) // Clamp: stale doc_freq > total_documents would go negative.
    }

    /// Load the built-in English stopword set.
    fn load_stopwords() -> std::collections::HashSet<String> {
        // Common English stopwords.
        let stopwords_list = vec![
            "the", "be", "to", "of", "and", "a", "in", "that", "have", "i", "it", "for", "not",
            "on", "with", "he", "as", "you", "do", "at", "this", "but", "his", "by", "from",
            "they", "we", "say", "her", "she", "or", "an", "will", "my", "one", "all", "would",
            "there", "their", "what", "so", "up", "out", "if", "about", "who", "get", "which",
            "go", "me", "when", "make", "can", "like", "time", "no", "just", "him", "know",
            "take", "people", "into", "year", "your", "good", "some", "could", "them", "see",
            "other", "than", "then", "now", "look", "only", "come", "its", "over", "think",
            "also", "back", "after", "use", "two", "how", "our", "work", "first", "well",
            "way", "even", "new", "want", "because", "any", "these", "give", "day", "most",
            "us", "is", "was", "are", "been", "has", "had", "were", "said", "did",
        ];

        stopwords_list.into_iter().map(|s| s.to_string()).collect()
    }

    /// Update document frequencies with a new document (for corpus-level IDF).
    ///
    /// Each unique term in the document increments its document frequency by
    /// exactly one, and the total document count grows by one.
    pub fn add_document_to_corpus(&mut self, text: &str) {
        let tokens = self.tokenize(text);
        // Deduplicate: document frequency counts documents, not occurrences.
        let unique_terms: std::collections::HashSet<String> = tokens.into_iter().collect();

        for term in unique_terms {
            *self.document_frequencies.entry(term).or_insert(0) += 1;
        }

        self.total_documents += 1;
    }

    /// Get corpus statistics as `(total_documents, unique_terms)`.
    pub fn corpus_stats(&self) -> (usize, usize) {
        (self.total_documents, self.document_frequencies.len())
    }
}

impl Default for TfIdfKeywordExtractor {
    fn default() -> Self {
        Self::new_default()
    }
}
165
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenization() {
        let extractor = TfIdfKeywordExtractor::new_default();
        let tokens = extractor
            .tokenize("Machine learning and artificial intelligence are transforming technology.");

        // Content words survive tokenization.
        for kept in ["machine", "learning", "artificial"] {
            assert!(tokens.contains(&kept.to_string()));
        }
        // Stopwords are dropped.
        for dropped in ["and", "are"] {
            assert!(!tokens.contains(&dropped.to_string()));
        }
    }

    #[test]
    fn test_tf_calculation() {
        let extractor = TfIdfKeywordExtractor::new_default();
        let tokens: Vec<String> = ["machine", "learning", "machine", "learning", "data"]
            .iter()
            .map(|w| w.to_string())
            .collect();

        let tf = extractor.calculate_tf(&tokens);
        let approx = |value: f32, target: f32| (value - target).abs() < 0.001;

        // 2 occurrences / 5 tokens = 0.4; 1 / 5 = 0.2.
        assert!(approx(tf["machine"], 0.4));
        assert!(approx(tf["learning"], 0.4));
        assert!(approx(tf["data"], 0.2));
    }

    #[test]
    fn test_idf_calculation() {
        let mut doc_freqs = HashMap::new();
        for (term, freq) in [("common", 50usize), ("rare", 2)] {
            doc_freqs.insert(term.to_string(), freq);
        }

        let extractor = TfIdfKeywordExtractor::new(doc_freqs, 100);
        let idf_common = extractor.calculate_idf("common");
        let idf_rare = extractor.calculate_idf("rare");

        // The rarer the term, the larger the IDF.
        assert!(idf_rare > idf_common);
        assert!((idf_common - 0.69).abs() < 0.1); // ln(100/50) ≈ 0.69
        assert!((idf_rare - 3.91).abs() < 0.1); // ln(100/2) ≈ 3.91
    }

    #[test]
    fn test_keyword_extraction() {
        let mut extractor = TfIdfKeywordExtractor::new_default();

        // Seed a small background corpus so IDF values are meaningful.
        for doc in [
            "artificial intelligence is the future",
            "deep learning uses neural networks",
            "natural language processing is important",
        ] {
            extractor.add_document_to_corpus(doc);
        }

        let text = "machine learning and deep learning are important topics in artificial intelligence. \
                    neural networks and machine learning models are widely used.";
        let keywords = extractor.extract_keywords(text, 5);
        assert!(keywords.len() >= 3);

        // At least one of the high-frequency content terms must rank in the top 5.
        let keyword_terms: Vec<&str> = keywords.iter().map(|(term, _)| term.as_str()).collect();
        let frequent = ["learning", "machine", "neural"];
        assert!(
            frequent.iter().any(|term| keyword_terms.contains(term)),
            "Expected high-frequency terms not found. Got: {:?}",
            keyword_terms
        );
    }

    #[test]
    fn test_corpus_building() {
        let mut extractor = TfIdfKeywordExtractor::new_default();
        for doc in [
            "machine learning is amazing",
            "deep learning is powerful",
            "natural language processing",
        ] {
            extractor.add_document_to_corpus(doc);
        }

        let (total_docs, unique_terms) = extractor.corpus_stats();
        // One implicit initial document plus the three added above.
        assert_eq!(total_docs, 4);
        assert!(unique_terms > 0);
    }

    #[test]
    fn test_stopword_filtering() {
        let extractor = TfIdfKeywordExtractor::new_default();
        let keywords = extractor
            .extract_keyword_strings("The quick brown fox jumps over the lazy dog and the cat", 10);

        // No stopwords in the output.
        for stopword in ["the", "and", "over"] {
            assert!(!keywords.iter().any(|w| w == stopword));
        }
        // Content words are present.
        assert!(keywords.iter().any(|w| w == "quick" || w == "brown" || w == "fox"));
    }
}