linfa_preprocessing/countgrams/
mod.rs

1//! Count vectorization methods
2
3use std::collections::{HashMap, HashSet};
4use std::io::Read;
5use std::iter::IntoIterator;
6
7use encoding::types::EncodingRef;
8use encoding::DecoderTrap;
9use ndarray::{Array1, ArrayBase, ArrayViewMut1, Data, Ix1};
10use regex::Regex;
11use sprs::{CsMat, CsVec};
12use unicode_normalization::UnicodeNormalization;
13
14use crate::error::{PreprocessingError, Result};
15use crate::helpers::NGramList;
16pub use hyperparams::{CountVectorizerParams, CountVectorizerValidParams};
17use linfa::ParamGuard;
18
19#[cfg(feature = "serde")]
20use serde_crate::{Deserialize, Serialize};
21
22mod hyperparams;
23
24impl CountVectorizerValidParams {
25    /// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
26    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
27    ///
28    /// Returns an error if:
29    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
30    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
31    ///   smaller than zero
32    /// * if the regex expression for the split is invalid
33    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
34        &self,
35        x: &ArrayBase<D, Ix1>,
36    ) -> Result<CountVectorizer> {
37        // word, (integer mapping for word, document frequency for word)
38        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new();
39        for string in x.iter().map(|s| transform_string(s.to_string(), self)) {
40            self.read_document_into_vocabulary(string, &self.split_regex(), &mut vocabulary);
41        }
42
43        let mut vocabulary = self.filter_vocabulary(vocabulary, x.len());
44        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
45
46        Ok(CountVectorizer {
47            vocabulary,
48            vec_vocabulary,
49            properties: self.clone(),
50        })
51    }
52
53    /// Learns a vocabulary from the documents contained in the files in `input`, according to the specified attributes and maps each
54    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
55    ///
56    /// The files will be read using the specified `encoding`, and any sequence unrecognized by the encoding will be handled
57    /// according to `trap`.
58    ///
59    /// Returns an error if:
60    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
61    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
62    ///   smaller than zero
63    /// * if the regex expression for the split is invalid
64    /// * if one of the files couldn't be opened
65    /// * if the trap is strict and an unrecognized sequence is encountered in one of the files
66    pub fn fit_files<P: AsRef<std::path::Path>>(
67        &self,
68        input: &[P],
69        encoding: EncodingRef,
70        trap: DecoderTrap,
71    ) -> Result<CountVectorizer> {
72        // word, (integer mapping for word, document frequency for word)
73        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::new();
74        let documents_count = input.len();
75        for path in input {
76            let mut file = std::fs::File::open(path)?;
77            let mut document_bytes = Vec::new();
78            file.read_to_end(&mut document_bytes)?;
79            let document = encoding::decode(&document_bytes, trap, encoding).0;
80            // encoding error contains a cow string, can't just use ?, must go through the unwrap
81            if document.is_err() {
82                return Err(PreprocessingError::EncodingError(document.err().unwrap()));
83            }
84            // safe unwrap now that error has been handled
85            let document = transform_string(document.unwrap(), self);
86            self.read_document_into_vocabulary(document, &self.split_regex(), &mut vocabulary);
87        }
88
89        let mut vocabulary = self.filter_vocabulary(vocabulary, documents_count);
90        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
91
92        Ok(CountVectorizer {
93            vocabulary,
94            vec_vocabulary,
95            properties: self.clone(),
96        })
97    }
98
99    /// Produces a [CountVectorizer](CountVectorizer) with the input vocabulary.
100    /// All struct attributes are ignored in the fitting but will be used by the [CountVectorizer](CountVectorizer)
101    /// to transform any text to be examined. As such this will return an error in the same cases as the `fit` method.
102    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<CountVectorizer> {
103        let mut vocabulary: HashMap<String, (usize, usize)> = HashMap::with_capacity(words.len());
104        for item in words.iter().map(|w| w.to_string()) {
105            let len = vocabulary.len();
106            // do not care about frequencies/stopwords if a vocabulary is given. Always 1 frequency
107            vocabulary.entry(item).or_insert((len, 1));
108        }
109        let vec_vocabulary = hashmap_to_vocabulary(&mut vocabulary);
110        Ok(CountVectorizer {
111            vocabulary,
112            vec_vocabulary,
113            properties: self.clone(),
114        })
115    }
116
117    /// Removes vocabulary items that do not satisfy the document frequencies constraints or if they appear in the
118    /// optional stopwords test.
119    /// The total number of documents is needed to convert from relative document frequencies to
120    /// their absolute counterparts.
121    fn filter_vocabulary(
122        &self,
123        vocabulary: HashMap<String, (usize, usize)>,
124        n_documents: usize,
125    ) -> HashMap<String, (usize, usize)> {
126        let (min_df, max_df) = self.document_frequency();
127        let len_f32 = n_documents as f32;
128        let (min_abs_df, max_abs_df) = ((min_df * len_f32) as usize, (max_df * len_f32) as usize);
129
130        if min_abs_df == 0 && max_abs_df == n_documents {
131            match &self.stopwords() {
132                None => vocabulary,
133                Some(stopwords) => vocabulary
134                    .into_iter()
135                    .filter(|(entry, (_, _))| !stopwords.contains(entry))
136                    .collect(),
137            }
138        } else {
139            match &self.stopwords() {
140                None => vocabulary
141                    .into_iter()
142                    .filter(|(_, (_, abs_count))| {
143                        *abs_count >= min_abs_df && *abs_count <= max_abs_df
144                    })
145                    .collect(),
146                Some(stopwords) => vocabulary
147                    .into_iter()
148                    .filter(|(entry, (_, abs_count))| {
149                        *abs_count >= min_abs_df
150                            && *abs_count <= max_abs_df
151                            && !stopwords.contains(entry)
152                    })
153                    .collect(),
154            }
155        }
156    }
157
158    /// Inserts all vocabulary entries learned from a single document (`doc`) into the
159    /// shared `vocabulary`, setting the document frequency to one for new entries and
160    /// incrementing it by one for entries which were already present.
161    fn read_document_into_vocabulary(
162        &self,
163        doc: String,
164        regex: &Regex,
165        vocabulary: &mut HashMap<String, (usize, usize)>,
166    ) {
167        let words = regex.find_iter(&doc).map(|mat| mat.as_str()).collect();
168        let list = NGramList::new(words, self.n_gram_range());
169        let document_vocabulary: HashSet<String> = list.into_iter().flatten().collect();
170        for word in document_vocabulary {
171            let len = vocabulary.len();
172            // If vocabulary item was already present then increase its document frequency
173            if let Some((_, freq)) = vocabulary.get_mut(&word) {
174                *freq += 1;
175            // otherwise set it to one
176            } else {
177                vocabulary.insert(word, (len, 1));
178            }
179        }
180    }
181}
182
183impl CountVectorizerParams {
184    /// Learns a vocabulary from the documents in `x`, according to the specified attributes and maps each
185    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
186    ///
187    /// Returns an error if:
188    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
189    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
190    ///   smaller than zero
191    /// * if the regex expression for the split is invalid
192    pub fn fit<T: ToString + Clone, D: Data<Elem = T>>(
193        &self,
194        x: &ArrayBase<D, Ix1>,
195    ) -> Result<CountVectorizer> {
196        self.check_ref().and_then(|params| params.fit(x))
197    }
198
199    /// Learns a vocabulary from the documents contained in the files in `input`, according to the specified attributes and maps each
200    /// vocabulary entry to an integer value, producing a [CountVectorizer](CountVectorizer).
201    ///
202    /// The files will be read using the specified `encoding`, and any sequence unrecognized by the encoding will be handled
203    /// according to `trap`.
204    ///
205    /// Returns an error if:
206    /// * one of the `n_gram` boundaries is set to zero or the minimum value is greater than the maximum value
207    /// * if the minimum document frequency is greater than one or than the maximum frequency, or if the maximum frequency is  
208    ///   smaller than zero
209    /// * if the regex expression for the split is invalid
210    /// * if one of the files couldn't be opened
211    /// * if the trap is strict and an unrecognized sequence is encountered in one of the files
212    pub fn fit_files<P: AsRef<std::path::Path>>(
213        &self,
214        input: &[P],
215        encoding: EncodingRef,
216        trap: DecoderTrap,
217    ) -> Result<CountVectorizer> {
218        self.check_ref()
219            .and_then(|params| params.fit_files(input, encoding, trap))
220    }
221
222    /// Produces a [CountVectorizer](CountVectorizer) with the input vocabulary.
223    /// All struct attributes are ignored in the fitting but will be used by the [CountVectorizer](CountVectorizer)
224    /// to transform any text to be examined. As such this will return an error in the same cases as the `fit` method.
225    pub fn fit_vocabulary<T: ToString>(&self, words: &[T]) -> Result<CountVectorizer> {
226        self.check_ref()
227            .and_then(|params| params.fit_vocabulary(words))
228    }
229}
230
/// Counts the occurrences of each vocabulary entry, learned during fitting, in a sequence of documents. Each vocabulary entry is mapped
/// to an integer value that is used to index the count in the result.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone)]
pub struct CountVectorizer {
    /// Maps each vocabulary entry to `(integer index of the entry, document frequency of the entry)`.
    pub(crate) vocabulary: HashMap<String, (usize, usize)>,
    /// The vocabulary entries, stored so that the entry at position `j` is the one mapped to index `j`.
    pub(crate) vec_vocabulary: Vec<String>,
    /// The parameters used for fitting; they are reused to preprocess and split documents when transforming.
    pub(crate) properties: CountVectorizerValidParams,
}
244
245impl CountVectorizer {
246    /// Construct a new set of parameters
247    pub fn params() -> CountVectorizerParams {
248        CountVectorizerParams::default()
249    }
250
251    /// Number of vocabulary entries learned during fitting
252    pub fn nentries(&self) -> usize {
253        self.vocabulary.len()
254    }
255
256    /// Given a sequence of `n` documents, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
257    /// is the number of occurrences of vocabulary entry `j` in the document of index `i`. Vocabulary entry `j` is the string
258    /// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the relative
259    /// cell in the sparse matrix will be set to `None`.
260    pub fn transform<T: ToString, D: Data<Elem = T>>(&self, x: &ArrayBase<D, Ix1>) -> CsMat<usize> {
261        let (vectorized, _) = self.get_term_and_document_frequencies(x);
262        vectorized
263    }
264
265    /// Given a sequence of `n` file names, produces a sparse array of size `(n, vocabulary_entries)` where column `j` of row `i`
266    /// is the number of occurrences of vocabulary entry `j` in the document contained in the file of index `i`. Vocabulary entry `j` is the string
267    /// at the `j`-th position in the vocabulary. If a vocabulary entry was not encountered in a document, then the relative
268    /// cell in the sparse matrix will be set to `None`.
269    ///
270    /// The files will be read using the specified `encoding`, and any sequence unrecognized by the encoding will be handled
271    /// according to `trap`.
272    pub fn transform_files<P: AsRef<std::path::Path>>(
273        &self,
274        input: &[P],
275        encoding: EncodingRef,
276        trap: DecoderTrap,
277    ) -> CsMat<usize> {
278        let (vectorized, _) = self.get_term_and_document_frequencies_files(input, encoding, trap);
279        vectorized
280    }
281
282    /// Contains all vocabulary entries, in the same order used by the `transform` methods.
283    pub fn vocabulary(&self) -> &Vec<String> {
284        &self.vec_vocabulary
285    }
286
287    /// Counts the occurrence of each vocabulary entry in each document and keeps track of the overall
288    /// document frequency of each entry.
289    pub(crate) fn get_term_and_document_frequencies<T: ToString, D: Data<Elem = T>>(
290        &self,
291        x: &ArrayBase<D, Ix1>,
292    ) -> (CsMat<usize>, Array1<usize>) {
293        let mut document_frequencies = Array1::zeros(self.vocabulary.len());
294        let mut sprs_vectorized = CsMat::empty(sprs::CompressedStorage::CSR, self.vocabulary.len());
295        sprs_vectorized.reserve_outer_dim_exact(x.len());
296        let regex = self.properties.split_regex();
297        for string in x.into_iter().map(|s| s.to_string()) {
298            let row = self.analyze_document(string, &regex, document_frequencies.view_mut());
299            sprs_vectorized = sprs_vectorized.append_outer_csvec(row.view());
300        }
301        (sprs_vectorized, document_frequencies)
302    }
303
304    /// Counts the occurrence of each vocabulary entry in each document and keeps track of the overall
305    /// document frequency of each entry.
306    pub(crate) fn get_term_and_document_frequencies_files<P: AsRef<std::path::Path>>(
307        &self,
308        input: &[P],
309        encoding: EncodingRef,
310        trap: DecoderTrap,
311    ) -> (CsMat<usize>, Array1<usize>) {
312        let mut document_frequencies = Array1::zeros(self.vocabulary.len());
313        let mut sprs_vectorized = CsMat::empty(sprs::CompressedStorage::CSR, self.vocabulary.len());
314        sprs_vectorized.reserve_outer_dim_exact(input.len());
315        let regex = self.properties.split_regex();
316        for file_path in input.iter() {
317            let mut file = std::fs::File::open(file_path).unwrap();
318            let mut document_bytes = Vec::new();
319            file.read_to_end(&mut document_bytes).unwrap();
320            let document = encoding::decode(&document_bytes, trap, encoding).0.unwrap();
321            sprs_vectorized = sprs_vectorized.append_outer_csvec(
322                self.analyze_document(document, &regex, document_frequencies.view_mut())
323                    .view(),
324            );
325        }
326        (sprs_vectorized, document_frequencies)
327    }
328
329    /// Produces a sparse array which counts the occurrences of each vocbulary entry in the given document. Also increases
330    /// the document frequency of all entries found.
331    fn analyze_document(
332        &self,
333        document: String,
334        regex: &Regex,
335        mut doc_freqs: ArrayViewMut1<usize>,
336    ) -> CsVec<usize> {
337        // A dense array is needed to parse each document, since sparse arrays can be mutated only
338        // if all insertions are made with increasing index. Since  vocabulary entries can be
339        // encountered in any order this condition does not hold true in this case.
340        // However, keeping only one dense array at a time, greatly limits memory consumption
341        // in sparse cases.
342        let mut term_frequencies: Array1<usize> = Array1::zeros(self.vocabulary.len());
343        let string = transform_string(document, &self.properties);
344        let words = regex.find_iter(&string).map(|mat| mat.as_str()).collect();
345        let list = NGramList::new(words, self.properties.n_gram_range());
346        for ngram_items in list {
347            for item in ngram_items {
348                if let Some((item_index, _)) = self.vocabulary.get(&item) {
349                    let term_freq = term_frequencies.get_mut(*item_index).unwrap();
350                    *term_freq += 1;
351                }
352            }
353        }
354        let mut sprs_term_frequencies = CsVec::empty(self.vocabulary.len());
355
356        // only insert non-zero elements in order to keep a sparse representation
357        for (i, freq) in term_frequencies
358            .into_iter()
359            .enumerate()
360            .filter(|(_, f)| *f > 0)
361        {
362            sprs_term_frequencies.append(i, freq);
363            doc_freqs[i] += 1;
364        }
365        sprs_term_frequencies
366    }
367}
368
369fn transform_string(mut string: String, properties: &CountVectorizerValidParams) -> String {
370    if properties.normalize() {
371        string = string.nfkd().collect();
372    }
373    if properties.convert_to_lowercase() {
374        string = string.to_lowercase();
375    }
376    string
377}
378
/// Lists the keys of `map` in its iteration order and overwrites the index stored with each
/// key (first tuple element) with the key's position in the returned vector. The second
/// tuple element is left untouched.
fn hashmap_to_vocabulary(map: &mut HashMap<String, (usize, usize)>) -> Vec<String> {
    map.iter_mut()
        .enumerate()
        .map(|(position, (word, entry))| {
            entry.0 = position;
            word.clone()
        })
        .collect()
}
387
#[cfg(test)]
mod tests {

    use super::*;
    use crate::column_for_word;
    use ndarray::{array, Array2};
    use std::path::PathBuf;

    /// For every `(word, counts)` pair, asserts that the column selected by
    /// `column_for_word!` for `word` is equal to `counts`.
    macro_rules! assert_counts_for_word {

        ($voc:expr, $transf:expr, $(($word:expr, $counts:expr)),*) => {
            $ (
                assert_eq!(column_for_word!($voc, $transf, $word), $counts);
            )*
        }
    }

    /// Default parameters, then 2-grams only, then 1- and 2-grams, on in-memory documents.
    #[test]
    fn simple_count_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().fit(&texts).unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(2, 2)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
        let true_vocabulary = vec!["one two", "two three", "three four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one two", array![1, 0, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three four", array![1, 1, 1, 0])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(1, 2)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
        let true_vocabulary = vec![
            "one",
            "one two",
            "two",
            "two three",
            "three",
            "three four",
            "four",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("one two", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("three four", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );
    }

    /// Transforming with a user-provided vocabulary; the `\u{FB01}` ligature in the last
    /// document is expected to be counted as "fi".
    #[test]
    fn simple_count_test_vocabulary() {
        let texts = array![
            "apples.and.trees fi",
            "flowers,and,bees",
            "trees!here;and trees:there",
            "four bees and apples and apples again \u{FB01}"
        ];
        let vocabulary = ["apples", "bees", "flowers", "trees", "fi"];
        let vectorizer = CountVectorizer::params()
            .fit_vocabulary(&vocabulary)
            .unwrap();
        let vect_vocabulary = vectorizer.vocabulary();
        assert_vocabulary_eq(&vocabulary, vect_vocabulary);
        let transformed: Array2<usize> = vectorizer.transform(&texts).to_dense();
        assert_counts_for_word!(
            vect_vocabulary,
            transformed,
            ("apples", array![1, 0, 0, 2]),
            ("bees", array![0, 1, 0, 1]),
            ("flowers", array![0, 1, 0, 0]),
            ("trees", array![1, 0, 2, 0]),
            ("fi", array![1, 0, 0, 1])
        );
    }

    /// A custom split regex that keeps "three;four" as a single entry.
    #[test]
    fn simple_count_no_punctuation_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .split_regex(r"\b[^ ][^ ]+\b")
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four", "three;four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 0, 0]),
            ("four", array![1, 1, 0, 1]),
            ("three;four", array![0, 0, 1, 0])
        );
    }

    /// Without lowercasing, "two" and "TWO" are distinct entries.
    #[test]
    fn simple_count_no_lowercase_test() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params()
            .convert_to_lowercase(false)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
        let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("oNe", array![1, 0, 0, 0]),
            ("two", array![1, 0, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1]),
            ("TWO", array![0, 1, 0, 0])
        );
    }

    /// Neither lowercasing nor punctuation splitting.
    #[test]
    fn simple_count_no_both_test() {
        let texts = array![
            "oNe oNe two three four",
            "TWO three four",
            "three;four",
            "four"
        ];
        let vectorizer = CountVectorizer::params()
            .convert_to_lowercase(false)
            .split_regex(r"\b[^ ][^ ]+\b")
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer.transform(&texts).to_dense();
        let true_vocabulary = vec!["oNe", "two", "three", "four", "TWO", "three;four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("oNe", array![2, 0, 0, 0]),
            ("two", array![1, 0, 0, 0]),
            ("three", array![1, 1, 0, 0]),
            ("four", array![1, 1, 0, 1]),
            ("TWO", array![0, 1, 0, 0]),
            ("three;four", array![0, 0, 1, 0])
        );
    }

    /// Only entries found in two or three of the five documents are expected to be kept.
    #[test]
    fn test_min_max_df() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let vectorizer = CountVectorizer::params()
            .document_frequency(2. / 5., 3. / 5.)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let true_vocabulary = vec![
            "one", "two", "three", "four", "five", "seven", "eight", "ten", "eleven",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
    }

    /// Same scenarios as `simple_count_test`, with the documents read from files.
    #[test]
    fn test_fit_transform_files() {
        // The guard removes the fixture files when it goes out of scope, even if an
        // assertion below fails.
        let text_files = create_test_files();
        let vectorizer = CountVectorizer::params()
            .fit_files(
                text_files.paths(),
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                text_files.paths(),
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .to_dense();
        let true_vocabulary = vec!["one", "two", "three", "four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(2, 2)
            .fit_files(
                text_files.paths(),
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                text_files.paths(),
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .to_dense();
        let true_vocabulary = vec!["one two", "two three", "three four"];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one two", array![1, 0, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three four", array![1, 1, 1, 0])
        );

        let vectorizer = CountVectorizer::params()
            .n_gram_range(1, 2)
            .fit_files(
                text_files.paths(),
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let counts: Array2<usize> = vectorizer
            .transform_files(
                text_files.paths(),
                encoding::all::UTF_8,
                encoding::DecoderTrap::Strict,
            )
            .to_dense();
        let true_vocabulary = vec![
            "one",
            "one two",
            "two",
            "two three",
            "three",
            "three four",
            "four",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
        assert_counts_for_word!(
            vocabulary,
            counts,
            ("one", array![1, 0, 0, 0]),
            ("one two", array![1, 0, 0, 0]),
            ("two", array![1, 1, 0, 0]),
            ("two three", array![1, 1, 0, 0]),
            ("three", array![1, 1, 1, 0]),
            ("three four", array![1, 1, 1, 0]),
            ("four", array![1, 1, 1, 1])
        );
    }

    /// Stopwords must not appear in the learned vocabulary.
    #[test]
    fn test_stopwords() {
        let texts = array![
            "one and two and three",
            "three and four and five",
            "seven and eight",
            "maybe ten and eleven",
            "avoid singletons: one two four five seven eight ten eleven and an and"
        ];
        let stopwords = ["and", "maybe", "an"];
        let vectorizer = CountVectorizer::params()
            .stopwords(&stopwords)
            .fit(&texts)
            .unwrap();
        let vocabulary = vectorizer.vocabulary();
        let true_vocabulary = vec![
            "one",
            "two",
            "three",
            "four",
            "five",
            "seven",
            "eight",
            "ten",
            "eleven",
            "avoid",
            "singletons",
        ];
        assert_vocabulary_eq(&true_vocabulary, vocabulary);
    }

    /// Invalid n-gram ranges and invalid document frequency bounds must be rejected by `fit`.
    #[test]
    fn test_invalid_gram_boundaries() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().n_gram_range(0, 1).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params().n_gram_range(1, 0).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params().n_gram_range(2, 1).fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(1.1, 1.)
            .fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(1., -0.1)
            .fit(&texts);
        assert!(vectorizer.is_err());
        let vectorizer = CountVectorizer::params()
            .document_frequency(0.5, 0.2)
            .fit(&texts);
        assert!(vectorizer.is_err());
    }

    /// A split regex that does not compile must be rejected by `fit`.
    #[test]
    fn test_invalid_regex() {
        let texts = array!["oNe two three four", "TWO three four", "three;four", "four"];
        let vectorizer = CountVectorizer::params().split_regex(r"[").fit(&texts);
        assert!(vectorizer.is_err())
    }

    /// Asserts that `voc` contains every word of `true_voc` and has the same length;
    /// the order of the entries is not checked.
    fn assert_vocabulary_eq<T: ToString>(true_voc: &[T], voc: &[String]) {
        for word in true_voc {
            let word = word.to_string();
            assert!(
                voc.contains(&word),
                "expected word {:?} is missing from vocabulary {:?}",
                word,
                voc
            );
        }
        assert_eq!(true_voc.len(), voc.len());
    }

    /// Fixture files used by the file-based tests. The files are removed when this value is
    /// dropped, so a failing assertion does not leave them behind.
    struct TestFiles {
        paths: Vec<PathBuf>,
    }

    impl TestFiles {
        /// Paths of the fixture files, in creation order.
        fn paths(&self) -> &[PathBuf] {
            &self.paths
        }
    }

    impl Drop for TestFiles {
        fn drop(&mut self) {
            for path in &self.paths {
                // Best effort only: panicking here while a failed assertion is already
                // unwinding would abort the whole test binary.
                let _ = std::fs::remove_file(path);
            }
        }
    }

    /// Writes the four test documents into the system temporary directory. The file names
    /// include the process id so that concurrent test runs do not overwrite each other.
    fn create_test_files() -> TestFiles {
        let contents = ["oNe two three four", "TWO three four", "three;four", "four"];
        let mut files = TestFiles {
            paths: Vec::with_capacity(contents.len()),
        };
        for (i, content) in contents.iter().enumerate() {
            let path = std::env::temp_dir().join(format!(
                "count_vectorization_test_file_{}_{}",
                std::process::id(),
                i + 1
            ));
            // Register the path before writing so that a partially written file is cleaned up too.
            files.paths.push(path.clone());
            std::fs::write(&path, content.as_bytes()).unwrap();
        }
        files
    }
}