linfa_preprocessing/tf_idf_vectorization.rs

//! Term frequency - inverse document frequency vectorization methods

use crate::countgrams::{CountVectorizer, CountVectorizerParams};
use crate::error::Result;
use encoding::types::EncodingRef;
use encoding::DecoderTrap;
use ndarray::{Array1, ArrayBase, Data, Ix1};
use sprs::CsMat;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

/// Methods for computing the inverse document frequency of a vocabulary entry
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum TfIdfMethod {
    /// Computes the idf as `ln((1 + n) / (1 + document_frequency)) + 1`. The "plus ones" inside the
    /// logarithm add an artificial document containing every vocabulary entry, preventing divisions
    /// by zero. The "plus one" after the logarithm lets vocabulary entries that appear in every
    /// document keep a weight of one instead of being discarded entirely.
    Smooth,
    /// Computes the idf as `ln(n / document_frequency) + 1`. The "plus one" after the logarithm lets
    /// vocabulary entries that appear in every document keep a weight of one instead of being
    /// discarded entirely. If a vocabulary entry has a document frequency of zero, this produces a
    /// division by zero.
    NonSmooth,
    /// Textbook definition of idf, computed as `ln(n / (1 + document_frequency))`, which prevents
    /// divisions by zero and discards entries that appear in every document.
    Textbook,
}

impl TfIdfMethod {
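    /// Computes the inverse document frequency of an entry that appears in `df` out of `n`
    /// documents, according to the chosen method.
    ///
    /// A minimal worked example of the three formulas (the import path below is an
    /// assumption about the crate's public module layout):
    ///
    /// ```
    /// // Import path assumed; adjust to match the crate's re-exports.
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfMethod;
    ///
    /// // 4 documents, entry present in 1 of them.
    /// let smooth = TfIdfMethod::Smooth.compute_idf(4, 1); // ln(5/2) + 1
    /// let non_smooth = TfIdfMethod::NonSmooth.compute_idf(4, 1); // ln(4) + 1
    /// let textbook = TfIdfMethod::Textbook.compute_idf(4, 1); // ln(4/2)
    /// assert!((smooth - 1.916).abs() < 1e-3);
    /// assert!((non_smooth - 2.386).abs() < 1e-3);
    /// assert!((textbook - 0.693).abs() < 1e-3);
    /// ```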
    pub fn compute_idf(&self, n: usize, df: usize) -> f64 {
        match self {
            TfIdfMethod::Smooth => ((1. + n as f64) / (1. + df as f64)).ln() + 1.,
            TfIdfMethod::NonSmooth => (n as f64 / df as f64).ln() + 1.,
            TfIdfMethod::Textbook => (n as f64 / (1. + df as f64)).ln(),
        }
    }
}

/// Similar to [`CountVectorizer`] but, instead of
/// just counting the term frequency of each vocabulary entry in each given document,
/// it computes the term frequency times the inverse document frequency, thus giving more importance
/// to entries that appear many times but only in some documents. The weight function can be adjusted
/// by setting the appropriate [method](TfIdfMethod). This struct provides the same string
/// processing customizations described in [`CountVectorizer`].
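///
/// # Examples
///
/// A minimal usage sketch (the import paths below are assumptions about the crate's
/// public module layout):
///
/// ```
/// // Import paths assumed; adjust to match the crate's re-exports.
/// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
/// use ndarray::array;
///
/// let texts = array!["one and two", "two and three", "three and four"];
/// let vectorizer = TfIdfVectorizer::default().fit(&texts).unwrap();
/// // One row per document, one column per vocabulary entry.
/// let tf_idf = vectorizer.transform(&texts);
/// assert_eq!(tf_idf.rows(), 3);
/// ```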
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct TfIdfVectorizer {
    count_vectorizer: CountVectorizerParams,
    method: TfIdfMethod,
}

impl std::default::Default for TfIdfVectorizer {
    fn default() -> Self {
        Self {
            count_vectorizer: CountVectorizerParams::default(),
            method: TfIdfMethod::Smooth,
        }
    }
}

impl TfIdfVectorizer {
    /// If true, all documents used for fitting will be converted to lowercase.
    pub fn convert_to_lowercase(self, convert_to_lowercase: bool) -> Self {
        Self {
            count_vectorizer: self
                .count_vectorizer
                .convert_to_lowercase(convert_to_lowercase),
            method: self.method,
        }
    }

    /// Sets the regular expression used to split documents into tokens
    pub fn split_regex(self, regex_str: &str) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.split_regex(regex_str),
            method: self.method,
        }
    }

    /// If set to `(1,1)` single tokens will be candidate vocabulary entries; if `(2,2)` then adjacent token pairs will be considered;
    /// if `(1,2)` then both single tokens and adjacent token pairs will be considered; and so on. The definition of a token depends on the
    /// regex used for splitting the documents.
    ///
    /// `min_n` should not be greater than `max_n`.
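    ///
    /// A minimal sketch of bigram extraction (the import paths below are assumptions
    /// about the crate's public module layout):
    ///
    /// ```
    /// // Import paths assumed; adjust to match the crate's re-exports.
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
    /// use ndarray::array;
    ///
    /// let texts = array!["fast brown fox"];
    /// let fitted = TfIdfVectorizer::default()
    ///     .n_gram_range(2, 2)
    ///     .fit(&texts)
    ///     .unwrap();
    /// // Adjacent token pairs become the vocabulary entries:
    /// // "fast brown" and "brown fox".
    /// assert_eq!(fitted.vocabulary().len(), 2);
    /// ```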
    pub fn n_gram_range(self, min_n: usize, max_n: usize) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.n_gram_range(min_n, max_n),
            method: self.method,
        }
    }

    /// If true, all characters in the documents used for fitting will be normalized according to Unicode's NFKD normalization.
    pub fn normalize(self, normalize: bool) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.normalize(normalize),
            method: self.method,
        }
    }

    /// Specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy.
    /// `min_freq` and `max_freq` must lie in `0..=1` and `min_freq` should not be greater than `max_freq`.
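    ///
    /// For example, with three documents an entry that appears in all of them has a relative
    /// document frequency of `1.0` and can be filtered out by setting `max_freq` below one.
    /// A sketch (the import paths below are assumptions about the crate's public module layout):
    ///
    /// ```
    /// // Import paths assumed; adjust to match the crate's re-exports.
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
    /// use ndarray::array;
    ///
    /// let texts = array!["cats and dogs", "dogs and mice", "mice and cats"];
    /// let fitted = TfIdfVectorizer::default()
    ///     .document_frequency(0.0, 0.9)
    ///     .fit(&texts)
    ///     .unwrap();
    /// // "and" appears in every document (frequency 1.0 > 0.9) and is excluded.
    /// assert!(!fitted.vocabulary().contains(&"and".to_string()));
    /// ```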
    pub fn document_frequency(self, min_freq: f32, max_freq: f32) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.document_frequency(min_freq, max_freq),
            method: self.method,
        }
    }

    /// List of entries to be excluded from the generated vocabulary.
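    ///
    /// A sketch (the import paths below are assumptions about the crate's public module layout):
    ///
    /// ```
    /// // Import paths assumed; adjust to match the crate's re-exports.
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
    /// use ndarray::array;
    ///
    /// let texts = array!["the cat", "the dog"];
    /// let fitted = TfIdfVectorizer::default()
    ///     .stopwords(&["the"])
    ///     .fit(&texts)
    ///     .unwrap();
    /// // "the" is excluded, leaving "cat" and "dog".
    /// assert_eq!(fitted.vocabulary().len(), 2);
    /// ```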
    pub fn stopwords<T: ToString>(self, stopwords: &[T]) -> Self {
        Self {
            count_vectorizer: self.count_vectorizer.stopwords(stopwords),
            method: self.method,
        }
    }

    /// Learns a vocabulary from the texts in `x`, according to the specified attributes, and maps each
    /// vocabulary entry to an integer value, producing a [FittedTfIdfVectorizer].
    ///
    /// Returns an error if:
    /// * one of the `n_gram` boundaries is set to zero, or the minimum value is greater than the maximum value
    /// * the minimum document frequency is greater than one or greater than the maximum frequency, or the
    ///   maximum frequency is smaller than zero
    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
        &self,
        x: &ArrayBase<D, Ix1>,
    ) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit(x)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }

    /// Produces a [FittedTfIdfVectorizer] with the input vocabulary.
    /// All struct attributes are ignored in the fitting but will be used by the [FittedTfIdfVectorizer]
    /// to transform any text to be examined. As such, this returns an error in the same cases as the `fit` method.
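    ///
    /// A sketch (the import path below is an assumption about the crate's public module layout):
    ///
    /// ```
    /// // Import path assumed; adjust to match the crate's re-exports.
    /// use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
    ///
    /// let fitted = TfIdfVectorizer::default()
    ///     .fit_vocabulary(&["cat", "dog"])
    ///     .unwrap();
    /// assert_eq!(fitted.nentries(), 2);
    /// ```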
    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit_vocabulary(words)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }

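    /// Learns a vocabulary from the contents of the given files, decoded with the provided
    /// encoding and [`DecoderTrap`], producing a [FittedTfIdfVectorizer]. Returns an error in
    /// the same cases as the `fit` method, or if a file cannot be read or decoded.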
    pub fn fit_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> Result<FittedTfIdfVectorizer> {
        let fitted_vectorizer = self.count_vectorizer.fit_files(input, encoding, trap)?;
        Ok(FittedTfIdfVectorizer {
            fitted_vectorizer,
            method: self.method.clone(),
        })
    }
}

/// Counts the occurrences of each vocabulary entry, learned during fitting, in a sequence of texts
/// and scales them by the inverse document frequency defined by the [method](TfIdfMethod). Each
/// vocabulary entry is mapped to an integer value that is used to index the count in the result.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct FittedTfIdfVectorizer {
    fitted_vectorizer: CountVectorizer,
    method: TfIdfMethod,
}

impl FittedTfIdfVectorizer {
    /// Number of vocabulary entries learned during fitting
    pub fn nentries(&self) -> usize {
        self.fitted_vectorizer.vocabulary.len()
    }

    /// Returns all vocabulary entries, in the same order used by the `transform` method.
    pub fn vocabulary(&self) -> &Vec<String> {
        self.fitted_vectorizer.vocabulary()
    }

    /// Returns the inverse document frequency method used in the transform method
    pub fn method(&self) -> &TfIdfMethod {
        &self.method
    }

    /// Given a sequence of `n` documents, produces an array of size `(n, vocabulary_entries)` where column `j` of row `i`
    /// is the number of occurrences of vocabulary entry `j` in the text of index `i`, scaled by the inverse document frequency.
    /// Vocabulary entry `j` is the string at the `j`-th position in the vocabulary.
    pub fn transform<T: ToString, D: Data<Elem = T>>(&self, x: &ArrayBase<D, Ix1>) -> CsMat<f64> {
        let (term_freqs, doc_freqs) = self.fitted_vectorizer.get_term_and_document_frequencies(x);
        self.apply_tf_idf(term_freqs, doc_freqs)
    }

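    /// Like `transform`, but obtains the documents from the given files, decoded with the
    /// provided encoding and [`DecoderTrap`].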
    pub fn transform_files<P: AsRef<std::path::Path>>(
        &self,
        input: &[P],
        encoding: EncodingRef,
        trap: DecoderTrap,
    ) -> CsMat<f64> {
        let (term_freqs, doc_freqs) = self
            .fitted_vectorizer
            .get_term_and_document_frequencies_files(input, encoding, trap);
        self.apply_tf_idf(term_freqs, doc_freqs)
    }

    fn apply_tf_idf(&self, term_freqs: CsMat<usize>, doc_freqs: Array1<usize>) -> CsMat<f64> {
        let mut term_freqs: CsMat<f64> = term_freqs.map(|x| *x as f64);
        // One idf value per vocabulary entry; the number of rows of the term
        // frequency matrix is the number of documents `n`.
        let inv_doc_freqs =
            doc_freqs.mapv(|doc_freq| self.method.compute_idf(term_freqs.rows(), doc_freq));
        // Scale each term frequency by the idf of the corresponding vocabulary entry.
        for mut row_vec in term_freqs.outer_iterator_mut() {
            for (col_i, val) in row_vec.iter_mut() {
                *val *= inv_doc_freqs[col_i];
            }
        }
        term_freqs
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::column_for_word;
    use approx::assert_abs_diff_eq;
    use ndarray::array;
    use std::fs::File;
    use std::io::Write;

    macro_rules! assert_tf_idfs_for_word {
        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $(
                assert_abs_diff_eq!(column_for_word!($voc, $transf, $word), $counts, epsilon = 1e-3);
            )*
        };
    }

    #[test]
    fn autotraits() {
        fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
        has_autotraits::<TfIdfMethod>();
    }

    #[test]
    fn test_tf_idf() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = TfIdfVectorizer::default().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer.transform(&texts).to_dense();
        assert_eq!(transformed.dim(), (texts.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
    }

    #[test]
    fn test_tf_idf_files() {
        let text_files = create_test_files();
        let vectorizer = TfIdfVectorizer::default()
            .fit_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let transformed = vectorizer
            .transform_files(
                &text_files,
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .to_dense();
        assert_eq!(transformed.dim(), (text_files.len(), vocabulary.len()));
        assert_tf_idfs_for_word!(
            vocabulary,
            transformed,
            ("one", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("two", array![1.693, 0.0, 0.0, 0.0, 1.693]),
            ("three", array![1.693, 1.693, 0.0, 0.0, 0.0]),
            ("four", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("and", array![2.0, 2.0, 1.0, 1.0, 2.0]),
            ("five", array![0.0, 1.693, 0.0, 0.0, 1.693]),
            ("seven", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("eight", array![0.0, 0.0, 1.693, 0.0, 1.693]),
            ("ten", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("eleven", array![0.0, 0.0, 0.0, 1.693, 1.693]),
            ("an", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("avoid", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("singletons", array![0.0, 0.0, 0.0, 0.0, 2.098]),
            ("maybe", array![0.0, 0.0, 0.0, 2.098, 0.0])
        );
        delete_test_files(&text_files)
    }

    fn create_test_files() -> Vec<&'static str> {
        let file_names = vec![
            "./tf_idf_vectorization_test_file_1",
            "./tf_idf_vectorization_test_file_2",
            "./tf_idf_vectorization_test_file_3",
            "./tf_idf_vectorization_test_file_4",
            "./tf_idf_vectorization_test_file_5",
        ];
        let contents = &[
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and",
        ];
        // Create the files and write their contents
        for (f_name, f_content) in file_names.iter().zip(contents.iter()) {
            let mut file = File::create(f_name).unwrap();
            file.write_all(f_content.as_bytes()).unwrap();
        }
        file_names
    }

    fn delete_test_files(file_names: &[&'static str]) {
        for f_name in file_names.iter() {
            std::fs::remove_file(f_name).unwrap();
        }
    }
}