use super::count::CountVectorizer;
use crate::sparse::CsrMatrix;
9
/// Row-normalization scheme applied to each document vector after
/// TF-IDF weighting.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TfidfNorm {
    /// Scale each row so the sum of absolute values equals 1.
    L1,
    /// Scale each row to unit Euclidean length (the constructor default).
    L2,
    /// Leave rows unnormalized.
    None,
}
20
/// TF-IDF vectorizer: wraps a [`CountVectorizer`] and re-weights raw term
/// counts by inverse document frequency, optionally applying sublinear TF
/// damping and per-row normalization.
#[derive(Debug, Clone)]
pub struct TfidfVectorizer {
    /// Underlying count vectorizer; owns tokenization, n-gram extraction,
    /// and the fitted vocabulary.
    count: CountVectorizer,
    /// Per-feature IDF weights, indexed by vocabulary index; filled by `fit`.
    idf_values: Vec<f64>,
    /// Row-normalization scheme applied after TF-IDF weighting.
    norm: TfidfNorm,
    /// When true, use `1 + ln(tf)` instead of the raw count as term frequency.
    sublinear_tf: bool,
    /// When true, add 1 to document count and document frequencies before
    /// taking the log (prevents zero/infinite weights).
    smooth_idf: bool,
    /// Set by `fit`; `transform` asserts this before doing any work.
    fitted: bool,
}
50
51impl TfidfVectorizer {
52 pub fn new() -> Self {
54 Self {
55 count: CountVectorizer::new(),
56 idf_values: Vec::new(),
57 norm: TfidfNorm::L2,
58 sublinear_tf: false,
59 smooth_idf: true,
60 fitted: false,
61 }
62 }
63
64 pub fn min_df(mut self, n: usize) -> Self {
66 self.count = self.count.min_df(n);
67 self
68 }
69
70 pub fn max_df(mut self, frac: f64) -> Self {
72 self.count = self.count.max_df(frac);
73 self
74 }
75
76 pub fn ngram_range(mut self, min_n: usize, max_n: usize) -> Self {
78 self.count = self.count.ngram_range(min_n, max_n);
79 self
80 }
81
82 pub fn max_features(mut self, n: usize) -> Self {
84 self.count = self.count.max_features(n);
85 self
86 }
87
88 pub fn norm(mut self, norm: TfidfNorm) -> Self {
90 self.norm = norm;
91 self
92 }
93
94 pub fn sublinear_tf(mut self, enable: bool) -> Self {
96 self.sublinear_tf = enable;
97 self
98 }
99
100 pub fn smooth_idf(mut self, enable: bool) -> Self {
103 self.smooth_idf = enable;
104 self
105 }
106
107 pub fn fit<S: AsRef<str>>(&mut self, documents: &[S]) {
109 self.count.fit(documents);
110
111 let n_docs = documents.len();
112 let n_features = self.count.n_features();
113
114 let mut doc_freq = vec![0usize; n_features];
116 let vocab = self.count.vocabulary();
117
118 for doc in documents {
119 let grams = self.count.tokenize_doc(doc.as_ref());
120 let mut seen = std::collections::HashSet::new();
121 for gram in &grams {
122 if let Some(&idx) = vocab.get(gram) {
123 if seen.insert(idx) {
124 doc_freq[idx] += 1;
125 }
126 }
127 }
128 }
129
130 self.idf_values = vec![0.0; n_features];
132 let smooth = if self.smooth_idf { 1.0 } else { 0.0 };
133 let n = n_docs as f64 + smooth;
134
135 for (i, &df) in doc_freq.iter().enumerate() {
136 let df_smooth = df as f64 + smooth;
137 self.idf_values[i] = (n / df_smooth).ln() + 1.0;
138 }
139
140 self.fitted = true;
141 }
142
143 pub fn transform<S: AsRef<str>>(&self, documents: &[S]) -> CsrMatrix {
145 assert!(
146 self.fitted,
147 "TfidfVectorizer: must call fit() before transform()"
148 );
149
150 let counts = self.count.transform(documents);
151 let n_rows = counts.n_rows();
152 let n_cols = counts.n_cols();
153
154 if n_rows == 0 || n_cols == 0 {
155 return CsrMatrix::from_dense(&[]);
156 }
157
158 let count_dense = counts.to_dense();
160
161 let mut triplet_rows = Vec::new();
162 let mut triplet_cols = Vec::new();
163 let mut triplet_vals = Vec::new();
164
165 for (row_idx, row) in count_dense.iter().enumerate() {
166 let mut row_entries: Vec<(usize, f64)> = Vec::new();
167
168 for (col, &count) in row.iter().enumerate() {
169 if count == 0.0 {
170 continue;
171 }
172
173 let tf = if self.sublinear_tf {
174 1.0 + count.ln()
175 } else {
176 count
177 };
178
179 let idf = self.idf_values.get(col).copied().unwrap_or(1.0);
180 let tfidf = tf * idf;
181 row_entries.push((col, tfidf));
182 }
183
184 if !row_entries.is_empty() {
186 match self.norm {
187 TfidfNorm::L2 => {
188 let norm: f64 = row_entries.iter().map(|(_, v)| v * v).sum::<f64>().sqrt();
189 if norm > 0.0 {
190 for entry in &mut row_entries {
191 entry.1 /= norm;
192 }
193 }
194 }
195 TfidfNorm::L1 => {
196 let norm: f64 = row_entries.iter().map(|(_, v)| v.abs()).sum();
197 if norm > 0.0 {
198 for entry in &mut row_entries {
199 entry.1 /= norm;
200 }
201 }
202 }
203 TfidfNorm::None => {}
204 }
205 }
206
207 for (col, val) in row_entries {
208 triplet_rows.push(row_idx);
209 triplet_cols.push(col);
210 triplet_vals.push(val);
211 }
212 }
213
214 CsrMatrix::from_triplets(&triplet_rows, &triplet_cols, &triplet_vals, n_rows, n_cols)
215 .expect("TfidfVectorizer: internal CSR construction error")
216 }
217
218 pub fn fit_transform<S: AsRef<str>>(&mut self, documents: &[S]) -> CsrMatrix {
220 self.fit(documents);
221 self.transform(documents)
222 }
223
224 pub fn idf(&self) -> &[f64] {
226 &self.idf_values
227 }
228
229 pub fn vocabulary(&self) -> &std::collections::HashMap<String, usize> {
231 self.count.vocabulary()
232 }
233
234 pub fn get_feature_names(&self) -> Vec<String> {
236 self.count.get_feature_names()
237 }
238
239 pub fn n_features(&self) -> usize {
241 self.count.n_features()
242 }
243}
244
245impl Default for TfidfVectorizer {
246 fn default() -> Self {
247 Self::new()
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_fit_transform() {
        let docs = ["the cat sat", "the dog sat", "the cat played"];
        let mut vectorizer = TfidfVectorizer::new();
        let tfidf_matrix = vectorizer.fit_transform(&docs);

        assert_eq!(tfidf_matrix.n_rows(), 3);
        assert_eq!(tfidf_matrix.n_cols(), vectorizer.n_features());
        // Distinct unigrams: cat, dog, played, sat, the.
        assert_eq!(vectorizer.n_features(), 5);
    }

    #[test]
    fn idf_values_are_positive() {
        let docs = ["hello world", "hello test"];
        let mut vectorizer = TfidfVectorizer::new();
        vectorizer.fit(&docs);

        for &idf in vectorizer.idf() {
            assert!(idf > 0.0, "IDF should be positive, got {idf}");
        }
    }

    #[test]
    fn l2_normalization() {
        let docs = ["a b c", "a b b"];
        let mut vectorizer = TfidfVectorizer::new().norm(TfidfNorm::L2);
        let dense = vectorizer.fit_transform(&docs).to_dense();

        for row in &dense {
            let norm = row.iter().map(|v| v * v).sum::<f64>().sqrt();
            if norm > 0.0 {
                assert!(
                    (norm - 1.0).abs() < 1e-10,
                    "L2 norm should be 1.0, got {norm}"
                );
            }
        }
    }

    #[test]
    fn l1_normalization() {
        let docs = ["a b c"];
        let mut vectorizer = TfidfVectorizer::new().norm(TfidfNorm::L1);
        let dense = vectorizer.fit_transform(&docs).to_dense();

        let norm: f64 = dense[0].iter().map(|v| v.abs()).sum();
        assert!(
            (norm - 1.0).abs() < 1e-10,
            "L1 norm should be 1.0, got {norm}"
        );
    }

    #[test]
    fn no_normalization() {
        let docs = ["a a"];
        let mut vectorizer = TfidfVectorizer::new().norm(TfidfNorm::None);
        let dense = vectorizer.fit_transform(&docs).to_dense();

        // tf = 2 times a positive IDF stays above 1 when no norm is applied.
        assert!(
            dense[0].iter().any(|&v| v > 1.0),
            "Expected unnormalized values"
        );
    }

    #[test]
    fn smooth_idf_default() {
        let docs = ["a", "b"];
        let mut vectorizer = TfidfVectorizer::new();
        vectorizer.fit(&docs);

        for &idf in vectorizer.idf() {
            assert!(idf > 1.0, "Smooth IDF should be > 1.0, got {idf}");
        }
    }

    #[test]
    fn sublinear_tf() {
        let docs = ["a a a a a"];
        let mut vectorizer = TfidfVectorizer::new()
            .sublinear_tf(true)
            .norm(TfidfNorm::None);
        let dense = vectorizer.fit_transform(&docs).to_dense();

        // Raw tf would be 5; sublinear scaling must bring it below that.
        let value = dense[0].iter().copied().find(|&v| v > 0.0).unwrap();
        assert!(value < 5.0, "Sublinear TF should reduce high counts");
    }

    #[test]
    fn unseen_terms_ignored() {
        let train = ["cat dog"];
        let test = ["cat bird"];
        let mut vectorizer = TfidfVectorizer::new();
        vectorizer.fit(&train);

        let dense = vectorizer.transform(&test).to_dense();

        let nonzero = dense[0].iter().filter(|&&v| v > 0.0).count();
        assert_eq!(nonzero, 1, "Only 'cat' should have a non-zero value");
    }

    #[test]
    fn bigrams_tfidf() {
        let docs = ["the cat sat"];
        let mut vectorizer = TfidfVectorizer::new().ngram_range(2, 2);
        let tfidf_matrix = vectorizer.fit_transform(&docs);

        // Bigrams: "the cat" and "cat sat".
        assert_eq!(tfidf_matrix.n_cols(), 2);
    }

    #[test]
    fn empty_documents() {
        let docs: [&str; 0] = [];
        let mut vectorizer = TfidfVectorizer::new();
        let tfidf_matrix = vectorizer.fit_transform(&docs);

        assert_eq!(tfidf_matrix.n_rows(), 0);
    }
}