linfa_preprocessing/countgrams/
hyperparams.rs1use crate::PreprocessingError;
2use linfa::ParamGuard;
3use regex::Regex;
4use std::cell::{Ref, RefCell};
5use std::collections::HashSet;
6
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10#[derive(Clone, Debug)]
11#[cfg(not(feature = "serde"))]
12struct SerdeRegex(Regex);
13#[derive(Clone, Debug, Serialize, Deserialize)]
14#[serde(crate = "serde_crate")]
15#[cfg(feature = "serde")]
16struct SerdeRegex(serde_regex::Serde<Regex>);
17
18#[cfg(not(feature = "serde"))]
19impl SerdeRegex {
20 fn new(re: &str) -> Result<Self, regex::Error> {
21 Ok(Self(Regex::new(re)?))
22 }
23
24 fn as_re(&self) -> &Regex {
25 &self.0
26 }
27}
28
#[cfg(feature = "serde")]
impl SerdeRegex {
    /// Compiles `re` and wraps the result in `serde_regex::Serde` so it can be (de)serialized.
    ///
    /// # Errors
    /// Returns the `regex::Error` from `Regex::new` when `re` is not a valid pattern.
    fn new(re: &str) -> Result<Self, regex::Error> {
        Ok(Self(serde_regex::Serde(Regex::new(re)?)))
    }

    /// Borrows the compiled regex out of its `serde_regex::Serde` wrapper.
    fn as_re(&self) -> &Regex {
        // `Serde<Regex>` derefs to `Regex`; reborrow through `Deref` directly rather than
        // writing `&x.deref()`, which creates a needless `&&Regex` (clippy::needless_borrow).
        &*self.0
    }
}
40
41#[cfg_attr(
61 feature = "serde",
62 derive(Serialize, Deserialize),
63 serde(crate = "serde_crate")
64)]
65#[derive(Clone, Debug)]
66pub struct CountVectorizerValidParams {
67 convert_to_lowercase: bool,
68 split_regex_expr: String,
69 split_regex: RefCell<Option<SerdeRegex>>,
70 n_gram_range: (usize, usize),
71 normalize: bool,
72 document_frequency: (f32, f32),
73 stopwords: Option<HashSet<String>>,
74}
75
76impl CountVectorizerValidParams {
77 pub fn convert_to_lowercase(&self) -> bool {
78 self.convert_to_lowercase
79 }
80
81 pub fn split_regex(&self) -> Ref<'_, Regex> {
82 Ref::map(self.split_regex.borrow(), |x| x.as_ref().unwrap().as_re())
83 }
84
85 pub fn n_gram_range(&self) -> (usize, usize) {
86 self.n_gram_range
87 }
88
89 pub fn normalize(&self) -> bool {
90 self.normalize
91 }
92
93 pub fn document_frequency(&self) -> (f32, f32) {
94 self.document_frequency
95 }
96
97 pub fn stopwords(&self) -> &Option<HashSet<String>> {
98 &self.stopwords
99 }
100}
101
102#[cfg_attr(
103 feature = "serde",
104 derive(Serialize, Deserialize),
105 serde(crate = "serde_crate")
106)]
107#[derive(Clone, Debug)]
108pub struct CountVectorizerParams(CountVectorizerValidParams);
109
110impl std::default::Default for CountVectorizerParams {
111 fn default() -> Self {
112 Self(CountVectorizerValidParams {
113 convert_to_lowercase: true,
114 split_regex_expr: r"\b\w\w+\b".to_string(),
115 split_regex: RefCell::new(None),
116 n_gram_range: (1, 1),
117 normalize: true,
118 document_frequency: (0., 1.),
119 stopwords: None,
120 })
121 }
122}
123
124impl CountVectorizerParams {
125 pub fn convert_to_lowercase(mut self, convert_to_lowercase: bool) -> Self {
127 self.0.convert_to_lowercase = convert_to_lowercase;
128 self
129 }
130
131 pub fn split_regex(mut self, regex_str: &str) -> Self {
133 self.0.split_regex_expr = regex_str.to_string();
134 self
135 }
136
137 pub fn n_gram_range(mut self, min_n: usize, max_n: usize) -> Self {
143 self.0.n_gram_range = (min_n, max_n);
144 self
145 }
146
147 pub fn normalize(mut self, normalize: bool) -> Self {
149 self.0.normalize = normalize;
150 self
151 }
152
153 pub fn document_frequency(mut self, min_freq: f32, max_freq: f32) -> Self {
156 self.0.document_frequency = (min_freq, max_freq);
157 self
158 }
159
160 pub fn stopwords<T: ToString>(mut self, stopwords: &[T]) -> Self {
162 self.0.stopwords = Some(stopwords.iter().map(|t| t.to_string()).collect());
163 self
164 }
165}
166
167impl ParamGuard for CountVectorizerParams {
168 type Checked = CountVectorizerValidParams;
169 type Error = PreprocessingError;
170
171 fn check_ref(&self) -> Result<&Self::Checked, Self::Error> {
172 let (n_gram_min, n_gram_max) = self.0.n_gram_range;
173 let (min_freq, max_freq) = self.0.document_frequency;
174
175 if n_gram_min == 0 || n_gram_max == 0 {
176 Err(PreprocessingError::InvalidNGramBoundaries(
177 n_gram_min, n_gram_max,
178 ))
179 } else if n_gram_min > n_gram_max {
180 Err(PreprocessingError::FlippedNGramBoundaries(
181 n_gram_min, n_gram_max,
182 ))
183 } else if min_freq < 0. || max_freq < 0. {
184 Err(PreprocessingError::InvalidDocumentFrequencies(
185 min_freq, max_freq,
186 ))
187 } else if max_freq < min_freq {
188 Err(PreprocessingError::FlippedDocumentFrequencies(
189 min_freq, max_freq,
190 ))
191 } else {
192 *self.0.split_regex.borrow_mut() = Some(SerdeRegex::new(&self.0.split_regex_expr)?);
193
194 Ok(&self.0)
195 }
196 }
197
198 fn check(self) -> Result<Self::Checked, Self::Error> {
199 self.check_ref()?;
200 Ok(self.0)
201 }
202}