use crate::countgrams::{CountVectorizer, CountVectorizerParams};
use crate::error::Result;
use encoding::types::EncodingRef;
use encoding::DecoderTrap;
use ndarray::{Array1, ArrayBase, Data, Ix1};
use sprs::CsMat;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

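/// Selects how the inverse document frequency (IDF) of a term is computed. The IDF is
/// multiplied by the raw term count to obtain the tf-idf weight of a term in a document.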
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TfIdfMethod {
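    /// IDF computed as `ln((1 + n) / (1 + df)) + 1`, where `n` is the total number of
    /// documents and `df` the number of documents containing the term. The `+ 1` terms
    /// act as if one extra document contained every term, avoiding division by zero.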
    Smooth,
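    /// IDF computed as `ln(n / df) + 1`, the same as [`TfIdfMethod::Smooth`] but without
    /// the smoothing terms.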
    NonSmooth,
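    /// The classic textbook definition, `ln(n / (1 + df))`, with `+ 1` in the denominator
    /// to avoid division by zero for terms that never occur.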
    Textbook,
}

impl TfIdfMethod {
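    /// Computes the inverse document frequency of a term that appears in `df` out of `n`
    /// documents, according to the selected method.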
    pub fn compute_idf(&self, n: usize, df: usize) -> f64 {
        match self {
            TfIdfMethod::Smooth => ((1. + n as f64) / (1. + df as f64)).ln() + 1.,
            TfIdfMethod::NonSmooth => (n as f64 / df as f64).ln() + 1.,
            TfIdfMethod::Textbook => (n as f64 / (1. + df as f64)).ln(),
        }
    }
}

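/// Configuration for a tf-idf vectorizer: the parameters of the underlying count
/// vectorizer plus the [`TfIdfMethod`] used to weight the counts. Use one of the `fit*`
/// methods to obtain a [`FittedTfIdfVectorizer`].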
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct TfIdfVectorizer {
    count_vectorizer: CountVectorizerParams,
    method: TfIdfMethod,
}

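/// The default vectorizer uses the default [`CountVectorizerParams`] and the
/// [`TfIdfMethod::Smooth`] weighting.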
impl std::default::Default for TfIdfVectorizer {
    fn default() -> Self {
        Self {
            count_vectorizer: CountVectorizerParams::default(),
            method: TfIdfMethod::Smooth,
        }
    }
}

impl TfIdfVectorizer {
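    /// Sets whether the documents are converted to lowercase by the underlying count
    /// vectorizer before terms are counted.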
    pub fn convert_to_lowercase(self, convert_to_lowercase: bool) -> Self {
        Self {
            count_vectorizer: self
                .count_vectorizer
                .convert_to_lowercase(convert_to_lowercase),
            method: self.method,
        }
    }

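    /// Sets the regular expression used by the underlying count vectorizer to split
    /// documents into tokens.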
    pub fn split_regex(self, regex_str: &str) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.split_regex(regex_str),
            method: self.method,
        }
    }

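    /// Sets the range of n-gram lengths (`min_n..=max_n`) extracted from each document by
    /// the underlying count vectorizer.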
    pub fn n_gram_range(self, min_n: usize, max_n: usize) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.n_gram_range(min_n, max_n),
            method: self.method,
        }
    }

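    /// Forwards the `normalize` flag to the underlying count vectorizer.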
    pub fn normalize(self, normalize: bool) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.normalize(normalize),
            method: self.method,
        }
    }

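    /// Sets the document-frequency bounds (`min_freq`, `max_freq`) on the underlying count
    /// vectorizer; terms whose document frequency falls outside this range are excluded
    /// from the fitted vocabulary.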
    pub fn document_frequency(self, min_freq: f32, max_freq: f32) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.document_frequency(min_freq, max_freq),
            method: self.method,
        }
    }

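    /// Sets the list of stopwords that the underlying count vectorizer leaves out of the
    /// vocabulary.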
    pub fn stopwords<T: ToString>(self, stopwords: &[T]) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.stopwords(stopwords),
            method: self.method,
        }
    }

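    /// Learns a vocabulary from the documents in `x` using the configured count-vectorizer
    /// settings and returns a [`FittedTfIdfVectorizer`] that can transform documents into
    /// tf-idf weighted sparse matrices.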
    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit(x)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }

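    /// Builds the fitted vectorizer from a user-supplied list of words instead of learning
    /// the vocabulary from documents.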
    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit_vocabulary(words)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }

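    /// Like [`TfIdfVectorizer::fit`], but reads each document from a file, decoding it with
    /// `encoding` and handling decoding errors according to `trap`.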
    pub fn fit_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit_files(input, encoding, trap)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }
}

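/// A tf-idf vectorizer that has been fitted to a vocabulary and can transform documents
/// into sparse matrices of tf-idf weights.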
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct FittedTfIdfVectorizer {
    fitted_vectorizer: CountVectorizer,
    method: TfIdfMethod,
}

impl FittedTfIdfVectorizer {
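    /// Returns the number of entries in the fitted vocabulary.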
    pub fn nentries(&self) -> usize {
        self.fitted_vectorizer.vocabulary.len()
    }

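    /// Returns the vocabulary learned (or supplied) during fitting, in the order used for
    /// the columns of the transformed matrices.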
    pub fn vocabulary(&self) -> &Vec<String> {
        self.fitted_vectorizer.vocabulary()
    }

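    /// Returns the [`TfIdfMethod`] used to weight term counts.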
    pub fn method(&self) -> &TfIdfMethod {
        &self.method
    }

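    /// Transforms the documents in `x` into a sparse matrix of tf-idf weights, with one row
    /// per document and one column per vocabulary entry.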
    pub fn transform<T: ToString, D: Data<Elem = T>>(&self, x: &ArrayBase<D, Ix1>) -> CsMat<f64> {
        let (term_freqs, doc_freqs) = self.fitted_vectorizer.get_term_and_document_frequencies(x);
        self.apply_tf_idf(term_freqs, doc_freqs)
    }

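    /// Like [`FittedTfIdfVectorizer::transform`], but reads each document from a file,
    /// decoding it with `encoding` and handling decoding errors according to `trap`.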
    pub fn transform_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> CsMat<f64> {
        let (term_freqs, doc_freqs) = self
            .fitted_vectorizer
            .get_term_and_document_frequencies_files(input, encoding, trap);
        self.apply_tf_idf(term_freqs, doc_freqs)
    }

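    /// Scales each term count by the IDF of its column, computed from the per-term document
    /// frequencies with the configured [`TfIdfMethod`]. The number of documents is taken
    /// from the number of rows of the term-frequency matrix.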
    fn apply_tf_idf(&self, term_freqs: CsMat<usize>, doc_freqs: Array1<usize>) -> CsMat<f64> {
        let mut term_freqs: CsMat<f64> = term_freqs.map(|x| *x as f64);
        let inv_doc_freqs =
            doc_freqs.mapv(|doc_freq| self.method.compute_idf(term_freqs.rows(), doc_freq));
        for mut row_vec in term_freqs.outer_iterator_mut() {
            for (col_i, val) in row_vec.iter_mut() {
                *val *= inv_doc_freqs[col_i];
            }
        }
        term_freqs
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::column_for_word;
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use std::fs::File;
    use std::io::Write;

    macro_rules! assert_tf_idfs_for_word {
        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $(
                assert_abs_diff_eq!(column_for_word!($voc, $transf, $word), $counts, epsilon=1e-3);
            )*
        }
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TfIdfMethod>();
    }

    #[test]
    fn test_tf_idf() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = TfIdfVectorizer::default().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer.transform(&texts).to_dense();
        assert_eq!(transformed.dim(), (texts.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
    }

    #[test]
    fn test_tf_idf_files() {
        let text_files = create_test_files();
        let vectorizer = TfIdfVectorizer::default()
            .fit_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer
            .transform_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .to_dense();
        assert_eq!(transformed.dim(), (text_files.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
        delete_test_files(&text_files)
    }

    fn create_test_files() -> Vec<&'static str> {
        let file_names = vec![
            "./tf_idf_vectorization_test_file_1",
            "./tf_idf_vectorization_test_file_2",
            "./tf_idf_vectorization_test_file_3",
            "./tf_idf_vectorization_test_file_4",
            "./tf_idf_vectorization_test_file_5",
        ];
        let contents = &[
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and",
        ];
        for (f_name, f_content) in file_names.iter().zip(contents.iter()) {
            let mut file = File::create(f_name).unwrap();
            file.write_all(f_content.as_bytes()).unwrap();
        }
        file_names
    }

    fn delete_test_files(file_names: &[&'static str]) {
        for f_name in file_names.iter() {
            std::fs::remove_file(f_name).unwrap();
        }
    }
}