1use std::collections::HashMap;
7
8use ferrolearn_core::error::FerroError;
9use ndarray::Array2;
10
/// Converts a collection of text documents into a matrix of token counts.
///
/// Configured builder-style; call `fit` to learn a vocabulary and obtain a
/// `FittedCountVectorizer` that can transform documents.
#[derive(Debug, Clone)]
pub struct CountVectorizer {
    /// Keep only the `n` most frequent terms (by total corpus frequency,
    /// ties broken alphabetically). `None` keeps every surviving term.
    pub max_features: Option<usize>,
    /// Minimum number of documents a term must appear in (absolute count).
    pub min_df: usize,
    /// Maximum fraction of documents a term may appear in, in `(0, 1]`;
    /// validated during `fit`.
    pub max_df: f64,
    /// If true, emit 0/1 presence indicators instead of raw counts.
    pub binary: bool,
    /// If true, documents are lowercased before tokenization.
    pub lowercase: bool,
}
50
51impl CountVectorizer {
52 #[must_use]
54 pub fn new() -> Self {
55 Self {
56 max_features: None,
57 min_df: 1,
58 max_df: 1.0,
59 binary: false,
60 lowercase: true,
61 }
62 }
63
64 #[must_use]
66 pub fn max_features(mut self, n: usize) -> Self {
67 self.max_features = Some(n);
68 self
69 }
70
71 #[must_use]
73 pub fn min_df(mut self, min_df: usize) -> Self {
74 self.min_df = min_df;
75 self
76 }
77
78 #[must_use]
80 pub fn max_df(mut self, max_df: f64) -> Self {
81 self.max_df = max_df;
82 self
83 }
84
85 #[must_use]
87 pub fn binary(mut self, binary: bool) -> Self {
88 self.binary = binary;
89 self
90 }
91
92 #[must_use]
94 pub fn lowercase(mut self, lowercase: bool) -> Self {
95 self.lowercase = lowercase;
96 self
97 }
98
99 pub fn fit(&self, docs: &[String]) -> Result<FittedCountVectorizer, FerroError> {
106 let n_docs = docs.len();
107 if n_docs == 0 {
108 return Err(FerroError::InsufficientSamples {
109 required: 1,
110 actual: 0,
111 context: "CountVectorizer::fit".into(),
112 });
113 }
114 if self.max_df <= 0.0 || self.max_df > 1.0 {
115 return Err(FerroError::InvalidParameter {
116 name: "max_df".into(),
117 reason: format!("must be in (0, 1], got {}", self.max_df),
118 });
119 }
120
121 let mut df_counts: HashMap<String, usize> = HashMap::new();
123 for doc in docs {
124 let tokens = tokenize(doc, self.lowercase);
125 let mut seen = std::collections::HashSet::new();
127 for tok in tokens {
128 if seen.insert(tok.clone()) {
129 *df_counts.entry(tok).or_insert(0) += 1;
130 }
131 }
132 }
133
134 let max_df_abs = (self.max_df * n_docs as f64).ceil() as usize;
136 let mut vocab: Vec<String> = df_counts
137 .into_iter()
138 .filter(|(_, count)| *count >= self.min_df && *count <= max_df_abs)
139 .map(|(term, _)| term)
140 .collect();
141 vocab.sort();
142
143 if let Some(max_f) = self.max_features {
145 if vocab.len() > max_f {
146 let mut total_freq: HashMap<String, usize> = HashMap::new();
148 for doc in docs {
149 let tokens = tokenize(doc, self.lowercase);
150 for tok in tokens {
151 if vocab.binary_search(&tok).is_ok() {
152 *total_freq.entry(tok).or_insert(0) += 1;
153 }
154 }
155 }
156 vocab.sort_by(|a, b| {
158 let fa = total_freq.get(a).unwrap_or(&0);
159 let fb = total_freq.get(b).unwrap_or(&0);
160 fb.cmp(fa).then_with(|| a.cmp(b))
161 });
162 vocab.truncate(max_f);
163 vocab.sort(); }
165 }
166
167 let vocabulary: HashMap<String, usize> = vocab
169 .iter()
170 .enumerate()
171 .map(|(i, t)| (t.clone(), i))
172 .collect();
173
174 Ok(FittedCountVectorizer {
175 vocabulary,
176 sorted_terms: vocab,
177 binary: self.binary,
178 lowercase: self.lowercase,
179 })
180 }
181}
182
183impl Default for CountVectorizer {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
/// A learned vocabulary produced by `CountVectorizer::fit`, able to
/// transform documents into per-term count matrices.
#[derive(Debug, Clone)]
pub struct FittedCountVectorizer {
    /// Term -> column index in the output matrix.
    vocabulary: HashMap<String, usize>,
    /// Terms in column order (kept sorted alphabetically by `fit`).
    sorted_terms: Vec<String>,
    /// If true, counts are clamped to 0/1 (presence indicator).
    binary: bool,
    /// If true, documents are lowercased before tokenization.
    lowercase: bool,
}
207
208impl FittedCountVectorizer {
209 #[must_use]
211 pub fn vocabulary(&self) -> &[String] {
212 &self.sorted_terms
213 }
214
215 #[must_use]
217 pub fn vocabulary_map(&self) -> &HashMap<String, usize> {
218 &self.vocabulary
219 }
220
221 pub fn transform(&self, docs: &[String]) -> Result<Array2<f64>, FerroError> {
227 if docs.is_empty() {
228 return Err(FerroError::InsufficientSamples {
229 required: 1,
230 actual: 0,
231 context: "FittedCountVectorizer::transform".into(),
232 });
233 }
234
235 let n_docs = docs.len();
236 let n_vocab = self.sorted_terms.len();
237 let mut matrix = Array2::<f64>::zeros((n_docs, n_vocab));
238
239 for (i, doc) in docs.iter().enumerate() {
240 let tokens = tokenize(doc, self.lowercase);
241 for tok in tokens {
242 if let Some(&col) = self.vocabulary.get(&tok) {
243 if self.binary {
244 matrix[[i, col]] = 1.0;
245 } else {
246 matrix[[i, col]] += 1.0;
247 }
248 }
249 }
250 }
251
252 Ok(matrix)
253 }
254}
255
/// Splits `doc` into alphanumeric tokens, optionally lowercasing first.
///
/// Any non-alphanumeric character (Unicode-aware) acts as a delimiter;
/// empty fragments are dropped.
fn tokenize(doc: &str, lowercase: bool) -> Vec<String> {
    use std::borrow::Cow;

    // Borrow the input unchanged when no case-folding is needed: avoids
    // copying the whole document in the `lowercase == false` path.
    let text: Cow<'_, str> = if lowercase {
        Cow::Owned(doc.to_lowercase())
    } else {
        Cow::Borrowed(doc)
    };

    text.split(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .map(str::to_string)
        .collect()
}
273
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Convert a slice of string literals into owned documents.
    fn docs_of(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_string()).collect()
    }

    #[test]
    fn test_count_vectorizer_basic() {
        let corpus = docs_of(&["the cat sat", "the cat sat on the mat"]);
        let fitted = CountVectorizer::new().fit(&corpus).unwrap();
        let counts = fitted.transform(&corpus).unwrap();

        assert_eq!(counts.nrows(), 2);
        assert!(fitted.vocabulary().iter().any(|t| t == "cat"));
        assert!(fitted.vocabulary().iter().any(|t| t == "the"));
        assert!(fitted.vocabulary().iter().any(|t| t == "sat"));

        let the_col = fitted.vocabulary_map()["the"];
        assert_abs_diff_eq!(counts[[0, the_col]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(counts[[1, the_col]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_binary() {
        let corpus = docs_of(&["the the the"]);
        let fitted = CountVectorizer::new().binary(true).fit(&corpus).unwrap();
        let counts = fitted.transform(&corpus).unwrap();
        // Binary mode caps repeated tokens at 1.
        assert_abs_diff_eq!(counts[[0, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_lowercase() {
        let corpus = docs_of(&["Hello HELLO hello"]);
        let fitted = CountVectorizer::new().fit(&corpus).unwrap();
        let counts = fitted.transform(&corpus).unwrap();
        // All case variants collapse into one vocabulary entry.
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_abs_diff_eq!(counts[[0, 0]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_no_lowercase() {
        let corpus = docs_of(&["Hello hello"]);
        let fitted = CountVectorizer::new().lowercase(false).fit(&corpus).unwrap();
        assert_eq!(fitted.vocabulary().len(), 2);
    }

    #[test]
    fn test_count_vectorizer_max_features() {
        let corpus = docs_of(&["a b c d e f"]);
        let fitted = CountVectorizer::new().max_features(3).fit(&corpus).unwrap();
        assert_eq!(fitted.vocabulary().len(), 3);
    }

    #[test]
    fn test_count_vectorizer_min_df() {
        let corpus = docs_of(&["cat dog", "cat bird", "cat fish"]);
        let fitted = CountVectorizer::new().min_df(3).fit(&corpus).unwrap();
        // Only "cat" appears in all three documents.
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_eq!(fitted.vocabulary()[0], "cat");
    }

    #[test]
    fn test_count_vectorizer_max_df() {
        let corpus = docs_of(&["the cat", "the dog", "the bird"]);
        let fitted = CountVectorizer::new().max_df(0.5).fit(&corpus).unwrap();
        // "the" occurs in every document and is filtered out.
        assert!(!fitted.vocabulary().iter().any(|t| t == "the"));
    }

    #[test]
    fn test_count_vectorizer_empty_corpus() {
        let corpus: Vec<String> = Vec::new();
        assert!(CountVectorizer::new().fit(&corpus).is_err());
    }

    #[test]
    fn test_count_vectorizer_transform_empty() {
        let corpus = docs_of(&["hello world"]);
        let fitted = CountVectorizer::new().fit(&corpus).unwrap();
        let no_docs: Vec<String> = Vec::new();
        assert!(fitted.transform(&no_docs).is_err());
    }

    #[test]
    fn test_count_vectorizer_unseen_tokens() {
        let train = docs_of(&["cat dog"]);
        let fitted = CountVectorizer::new().fit(&train).unwrap();
        let counts = fitted.transform(&docs_of(&["fish bird"])).unwrap();
        // Unseen tokens map to nothing, so the whole row stays zero.
        for &value in counts.iter() {
            assert_abs_diff_eq!(value, 0.0, epsilon = 1e-10);
        }
    }
}