linfa_preprocessing/countgrams/hyperparams.rs

use crate::PreprocessingError;
use linfa::ParamGuard;
use regex::Regex;
use std::cell::{Ref, RefCell};
use std::collections::HashSet;

#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

#[derive(Clone, Debug)]
#[cfg(not(feature = "serde"))]
struct SerdeRegex(Regex);
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(crate = "serde_crate")]
#[cfg(feature = "serde")]
struct SerdeRegex(serde_regex::Serde<Regex>);

#[cfg(not(feature = "serde"))]
impl SerdeRegex {
    fn new(re: &str) -> Result<Self, regex::Error> {
        Ok(Self(Regex::new(re)?))
    }

    fn as_re(&self) -> &Regex {
        &self.0
    }
}

#[cfg(feature = "serde")]
impl SerdeRegex {
    fn new(re: &str) -> Result<Self, regex::Error> {
        Ok(Self(serde_regex::Serde(Regex::new(re)?)))
    }

    fn as_re(&self) -> &Regex {
        use std::ops::Deref;
        &self.0.deref()
    }
}
40
/// Count vectorizer: learns a vocabulary from a sequence of documents (or file paths) and maps each
/// vocabulary entry to an integer value, producing a [CountVectorizer](crate::CountVectorizer) that can
/// be used to count the occurrences of each vocabulary entry in any sequence of documents. Alternatively, a user-specified vocabulary can
/// be used for fitting.
///
/// ### Attributes
///
/// If a user-defined vocabulary is used for fitting, then the following attributes will not be considered during the fitting phase, but
/// they will still be used by the [CountVectorizer](crate::CountVectorizer) to transform any text to be examined.
///
/// * `split_regex`: the regular expression used to split documents into tokens. Defaults to r"\\b\\w\\w+\\b", which selects "words", using whitespace and
///     punctuation symbols as separators.
/// * `convert_to_lowercase`: if true, all documents used for fitting will be converted to lowercase. Defaults to `true`.
/// * `n_gram_range`: if set to `(1,1)` single tokens will be candidate vocabulary entries, if `(2,2)` then adjacent token pairs will be considered,
///    if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of a token depends on the
///    regex used for splitting the documents. The default value is `(1,1)`.
/// * `normalize`: if true, all characters in the documents used for fitting will be normalized according to Unicode's NFKD normalization. Defaults to `true`.
/// * `document_frequency`: specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy. Defaults to `(0., 1.)` (i.e. 0% minimum and 100% maximum).
/// * `stopwords`: optional list of entries to be excluded from the generated vocabulary. Defaults to `None`.
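///
/// ### Example
///
/// A minimal configuration sketch using only the builder methods defined below; the
/// `linfa_preprocessing::CountVectorizerParams` import path is an assumption, so the snippet is
/// not compiled as a doctest.
///
/// ```ignore
/// use linfa::ParamGuard;
/// use linfa_preprocessing::CountVectorizerParams;
///
/// // Consider unigrams and bigrams, drop entries appearing in fewer than 5% or more
/// // than 90% of the documents, and exclude a couple of stopwords.
/// let valid_params = CountVectorizerParams::default()
///     .n_gram_range(1, 2)
///     .document_frequency(0.05, 0.9)
///     .stopwords(&["the", "and"])
///     .check()
///     .unwrap();
/// ```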
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct CountVectorizerValidParams {
    convert_to_lowercase: bool,
    split_regex_expr: String,
    split_regex: RefCell<Option<SerdeRegex>>,
    n_gram_range: (usize, usize),
    normalize: bool,
    document_frequency: (f32, f32),
    stopwords: Option<HashSet<String>>,
}

impl CountVectorizerValidParams {
    pub fn convert_to_lowercase(&self) -> bool {
        self.convert_to_lowercase
    }

    pub fn split_regex(&self) -> Ref<'_, Regex> {
        Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re())
    }

    pub fn n_gram_range(&self) -> (usize, usize) {
        self.n_gram_range
    }

    pub fn normalize(&self) -> bool {
        self.normalize
    }

    pub fn document_frequency(&self) -> (f32, f32) {
        self.document_frequency
    }

    pub fn stopwords(&self) -> &Option<HashSet<String>> {
        &self.stopwords
    }
}

#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Clone, Debug)]
pub struct CountVectorizerParams(CountVectorizerValidParams);

impl std::default::Default for CountVectorizerParams {
    fn default() -> Self {
        Self(CountVectorizerValidParams {
            convert_to_lowercase: true,
            split_regex_expr: r"\b\w\w+\b".to_string(),
            split_regex: RefCell::new(None),
            n_gram_range: (1, 1),
            normalize: true,
            document_frequency: (0., 1.),
            stopwords: None,
        })
    }
}

impl CountVectorizerParams {
    /// If true, all documents used for fitting will be converted to lowercase.
    pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
        self.0.convert_to_lowercase = convert_to_lowercase;
        self
    }

    /// Sets the regular expression used to split documents into tokens.
    pub fn split_regex(mut self, regex_str: &str) -> Self {
        self.0.split_regex_expr = regex_str.to_string();
        self
    }

    /// If set to `(1,1)` single tokens will be candidate vocabulary entries, if `(2,2)` then adjacent token pairs will be considered,
    /// if `(1,2)` then both single tokens and adjacent token pairs will be considered, and so on. The definition of a token depends on the
    /// regex used for splitting the documents.
    ///
    /// `min_n` should not be greater than `max_n`.
    pub fn n_gram_range(mut self, min_n: usize, max_n: usize) -> Self {
        self.0.n_gram_range = (min_n, max_n);
        self
    }

    /// If true, all characters in the documents used for fitting will be normalized according to Unicode's NFKD normalization.
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.0.normalize = normalize;
        self
    }

    /// Specifies the minimum and maximum (relative) document frequencies that each vocabulary entry must satisfy.
    /// `min_freq` and `max_freq` must lie in `0..=1`, and `min_freq` should not be greater than `max_freq`.
    pub fn document_frequency(mut self, min_freq: f32, max_freq: f32) -> Self {
        self.0.document_frequency = (min_freq, max_freq);
        self
    }

    /// List of entries to be excluded from the generated vocabulary.
    pub fn stopwords<T: ToString>(mut self, stopwords: &[T]) -> Self {
        self.0.stopwords = Some(stopwords.iter().map(|t| t.to_string()).collect());
        self
    }
}

impl ParamGuard for CountVectorizerParams {
    type Checked = CountVectorizerValidParams;
    type Error = PreprocessingError;

    fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
        let (n_gram_min, n_gram_max) = self.0.n_gram_range;
        let (min_freq, max_freq) = self.0.document_frequency;

        if n_gram_min == 0 || n_gram_max == 0 {
            Err(PreprocessingError::InvalidNGramBoundaries(
                n_gram_min, n_gram_max,
            ))
        } else if n_gram_min > n_gram_max {
            Err(PreprocessingError::FlippedNGramBoundaries(
                n_gram_min, n_gram_max,
            ))
        } else if min_freq < 0. || max_freq < 0. {
            Err(PreprocessingError::InvalidDocumentFrequencies(
                min_freq, max_freq,
            ))
        } else if max_freq < min_freq {
            Err(PreprocessingError::FlippedDocumentFrequencies(
                min_freq, max_freq,
            ))
        } else {
            // All boundaries are valid: compile the split regex now so the checked
            // params can hand out `&Regex` through `split_regex()` later.
            *self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?);

            Ok(&self.0)
        }
    }

    fn check(self) -> Result<Self::Checked, Self::Error> {
        self.check_ref()?;
        Ok(self.0)
    }
}
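
// An illustrative test sketch of the validation rules in `check_ref` above: zero or flipped
// n-gram boundaries and negative or flipped document frequencies are rejected, while the
// defaults pass. Only items defined in this module plus `linfa::ParamGuard` are used.
#[cfg(test)]
mod validation_sketch {
    use super::*;
    use linfa::ParamGuard;

    #[test]
    fn default_params_pass_validation() {
        // The defaults satisfy every check, and the default split regex compiles.
        assert!(CountVectorizerParams::default().check().is_ok());
    }

    #[test]
    fn invalid_boundaries_are_rejected() {
        // A zero n-gram boundary is invalid.
        assert!(CountVectorizerParams::default().n_gram_range(0, 1).check().is_err());
        // `min_n` greater than `max_n` is flipped.
        assert!(CountVectorizerParams::default().n_gram_range(2, 1).check().is_err());
        // Negative document frequencies are invalid.
        assert!(CountVectorizerParams::default()
            .document_frequency(-0.1, 0.5)
            .check()
            .is_err());
        // `max_freq` smaller than `min_freq` is flipped.
        assert!(CountVectorizerParams::default()
            .document_frequency(0.5, 0.1)
            .check()
            .is_err());
    }
}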