ferrolearn_preprocess/count_vectorizer.rs
1//! Count vectorizer: convert text documents to a term-count matrix.
2//!
3//! Tokenizes documents into runs of 2+ word characters (the Rust analog of
4//! scikit-learn's default `token_pattern=r"(?u)\b\w\w+\b"`,
5//! `sklearn/feature_extraction/text.py:1161`), builds an alphabetically-sorted
6//! vocabulary, and produces a term-count matrix of shape `(n_docs, n_vocab)`.
7//!
8//! Translation target: scikit-learn 1.5.2 `class CountVectorizer` (`text.py:929`).
9//! Design: `.design/preprocess/count_vectorizer.md`. Tracking: #1216.
10//!
11//! `## REQ status`
12//!
13//! | REQ | Status | Anchor |
14//! |---|---|---|
15//! | REQ-1 default fit/transform, sorted vocab, count matrix | SHIPPED (scoped: dense) | `CountVectorizer::fit` / `FittedCountVectorizer::transform`; sklearn `_count_vocab` `text.py:1242-1305` |
16//! | REQ-2 default token_pattern (drop length-1, `_` word char) | SHIPPED (#1217) | `fn tokenize`; sklearn `text.py:1161`, `build_tokenizer:350` |
17//! | REQ-3 binary count clipping | SHIPPED | `FittedCountVectorizer::transform`; sklearn `text.py:1374` |
18//! | REQ-4 lowercase toggle | SHIPPED | `fn tokenize`; sklearn `text.py:1157`,`:323` |
19//! | REQ-5 max_df/min_df int-vs-float duality + threshold errors | NOT-STARTED (#1219; ceil sub-fix shipped #1218; max_df<min_df + post-prune empty-vocab errors shipped #2337) | `fit` df-filter; sklearn `text.py:1379-1382`,`:1236-1239` |
20//! | REQ-6 ngram_range word n-grams | NOT-STARTED (#1220) | sklearn `_word_ngrams` `text.py:242` |
21//! | REQ-7 max_features top-N + tie/sort | SHIPPED (scoped) | `fit`; sklearn `_limit_features` `text.py:1222-1227` |
22//! | REQ-8 tokenizer/token_pattern/preprocessor/analyzer/strip_accents | NOT-STARTED (#1221) | sklearn `build_analyzer` `text.py:419` |
23//! | REQ-9 stop_words | NOT-STARTED (#1222) | sklearn `get_stop_words` `text.py:370` |
24//! | REQ-10 fixed vocabulary param + dtype | NOT-STARTED (#1223) | sklearn `_count_vocab` `text.py:1242-1244`,`:1147` |
25//! | REQ-11 sparse CSR output | NOT-STARTED (#1224) | sklearn `_count_vocab` `text.py:1299-1304` |
26//! | REQ-12 get_feature_names_out contract | NOT-STARTED (#1225) | sklearn `text.py:1455` |
27//! | REQ-13 HashingVectorizer | NOT-STARTED (#1226) | sklearn `class HashingVectorizer` `text.py:562` |
28//! | REQ-14 full 16-param ctor + _parameter_constraints | NOT-STARTED (#1227) | sklearn `text.py:1124-1148` |
29//! | REQ-14a empty-vocabulary ValueError parity (post-tokenize + max_df<min_df + post-prune) | SHIPPED (#2336 #2337) | `CountVectorizer::fit` empty-vocab/`max_df`/post-prune `Err(InvalidParameter)`; sklearn `text.py:1277-1279`,`:1381-1382`,`:1236-1239`. Consumer: crate re-export `pub use count_vectorizer::CountVectorizer` (`lib.rs`). |
30//! | REQ-15 PyO3 binding | NOT-STARTED (#1228) | `ferrolearn-python/src/transformers.rs` (absent) |
31
32use std::collections::HashMap;
33
34use ferrolearn_core::error::FerroError;
35use ndarray::Array2;
36
37// ---------------------------------------------------------------------------
38// CountVectorizer (unfitted)
39// ---------------------------------------------------------------------------
40
/// An unfitted count vectorizer.
///
/// Tokenizes documents into maximal runs of word characters (Unicode
/// alphanumerics plus `_`), discards runs shorter than two characters, builds
/// an alphabetically sorted vocabulary filtered by document frequency, and
/// transforms documents into a term-count matrix of shape `(n_docs, n_vocab)`.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::count_vectorizer::{CountVectorizer, FittedCountVectorizer};
///
/// let docs = vec![
///     "the cat sat".to_string(),
///     "the cat sat on the mat".to_string(),
/// ];
/// let cv = CountVectorizer::new();
/// let fitted = cv.fit(&docs).unwrap();
/// let counts = fitted.transform(&docs).unwrap();
/// assert_eq!(counts.nrows(), 2);
/// assert_eq!(counts.ncols(), fitted.vocabulary().len());
/// ```
#[derive(Debug, Clone)]
pub struct CountVectorizer {
    /// Maximum number of features (vocabulary size). `None` means no limit;
    /// otherwise only the top-ranked `max_features` terms are kept by `fit`.
    pub max_features: Option<usize>,
    /// Minimum document frequency (absolute count): terms appearing in fewer
    /// than `min_df` documents are excluded.
    pub min_df: usize,
    /// Maximum document frequency as a fraction of total documents; `fit`
    /// requires it to lie in `(0, 1]`.
    /// Terms appearing in more than `max_df * n_docs` documents are excluded.
    pub max_df: f64,
    /// If `true`, all counts are clipped to 0/1 (binary occurrence).
    pub binary: bool,
    /// If `true`, each document is lowercased before it is tokenized.
    pub lowercase: bool,
}
76
77impl CountVectorizer {
78 /// Create a new `CountVectorizer` with default settings.
79 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 max_features: None,
83 min_df: 1,
84 max_df: 1.0,
85 binary: false,
86 lowercase: true,
87 }
88 }
89
90 /// Set the maximum number of features.
91 #[must_use]
92 pub fn max_features(mut self, n: usize) -> Self {
93 self.max_features = Some(n);
94 self
95 }
96
97 /// Set the minimum document frequency.
98 #[must_use]
99 pub fn min_df(mut self, min_df: usize) -> Self {
100 self.min_df = min_df;
101 self
102 }
103
104 /// Set the maximum document frequency as a fraction of total documents.
105 #[must_use]
106 pub fn max_df(mut self, max_df: f64) -> Self {
107 self.max_df = max_df;
108 self
109 }
110
111 /// Enable or disable binary mode.
112 #[must_use]
113 pub fn binary(mut self, binary: bool) -> Self {
114 self.binary = binary;
115 self
116 }
117
118 /// Enable or disable lowercasing.
119 #[must_use]
120 pub fn lowercase(mut self, lowercase: bool) -> Self {
121 self.lowercase = lowercase;
122 self
123 }
124
125 /// Fit the vectorizer on a corpus of documents.
126 ///
127 /// # Errors
128 ///
129 /// Returns [`FerroError::InsufficientSamples`] if the corpus is empty.
130 /// Returns [`FerroError::InvalidParameter`] if `max_df` is not in `(0, 1]`.
131 pub fn fit(&self, docs: &[String]) -> Result<FittedCountVectorizer, FerroError> {
132 let n_docs = docs.len();
133 if n_docs == 0 {
134 return Err(FerroError::InsufficientSamples {
135 required: 1,
136 actual: 0,
137 context: "CountVectorizer::fit".into(),
138 });
139 }
140 if self.max_df <= 0.0 || self.max_df > 1.0 {
141 return Err(FerroError::InvalidParameter {
142 name: "max_df".into(),
143 reason: format!("must be in (0, 1], got {}", self.max_df),
144 });
145 }
146
147 // Build document-frequency counts.
148 let mut df_counts: HashMap<String, usize> = HashMap::new();
149 for doc in docs {
150 let tokens = tokenize(doc, self.lowercase);
151 // Unique tokens per document.
152 let mut seen = std::collections::HashSet::new();
153 for tok in tokens {
154 if seen.insert(tok.clone()) {
155 *df_counts.entry(tok).or_insert(0) += 1;
156 }
157 }
158 }
159
160 // Empty-vocabulary error (before df-pruning). sklearn's `_count_vocab`
161 // raises `ValueError("empty vocabulary; perhaps the documents only
162 // contain stop words")` when the assembled vocabulary is empty
163 // (`sklearn/feature_extraction/text.py:1277-1279`). This fires when every
164 // token is dropped by the token_pattern (e.g. all length-1 tokens).
165 if df_counts.is_empty() {
166 return Err(FerroError::InvalidParameter {
167 name: "vocabulary".into(),
168 reason: "empty vocabulary; perhaps the documents only contain stop words".into(),
169 });
170 }
171
172 // max_df-vs-min_df cross-validation. sklearn computes the document-count
173 // bounds (`text.py:1379-1380`) and raises
174 // `ValueError("max_df corresponds to < documents than min_df")` when the
175 // max_df bound is below the min_df bound (`text.py:1381-1382`). Here
176 // `max_df` is a float proportion (bound = `max_df * n_doc`) and `min_df`
177 // is an absolute document count (bound = `min_df`).
178 let max_df_count = self.max_df * n_docs as f64;
179 let min_doc_count = self.min_df as f64;
180 if max_df_count < min_doc_count {
181 return Err(FerroError::InvalidParameter {
182 name: "max_df".into(),
183 reason: "max_df corresponds to < documents than min_df".into(),
184 });
185 }
186
187 // Filter by min_df and max_df.
188 //
189 // sklearn 1.5.2 computes `max_doc_count = max_df * n_doc` as a FLOAT with
190 // NO rounding (`sklearn/feature_extraction/text.py:1379`) and keeps terms
191 // with `df <= max_doc_count` (`_limit_features`, `text.py:1219`:
192 // `mask &= dfs <= high`). We mirror that exactly: compare the integer
193 // document count against the un-rounded float threshold. (Note: sklearn
194 // also accepts an integer `max_df` as an absolute count; that int-vs-float
195 // duality is a separate gap and is intentionally not implemented here.)
196 // (`max_df_count` is computed above for the max_df-vs-min_df check.)
197 let mut vocab: Vec<String> = df_counts
198 .into_iter()
199 .filter(|(_, count)| *count >= self.min_df && (*count as f64) <= max_df_count)
200 .map(|(term, _)| term)
201 .collect();
202 vocab.sort();
203
204 // Apply max_features: keep the top-N by total corpus frequency.
205 if let Some(max_f) = self.max_features
206 && vocab.len() > max_f
207 {
208 // Re-count total frequencies for the remaining terms.
209 let mut total_freq: HashMap<String, usize> = HashMap::new();
210 for doc in docs {
211 let tokens = tokenize(doc, self.lowercase);
212 for tok in tokens {
213 if vocab.binary_search(&tok).is_ok() {
214 *total_freq.entry(tok).or_insert(0) += 1;
215 }
216 }
217 }
218 // Sort by descending frequency, then alphabetically for ties.
219 vocab.sort_by(|a, b| {
220 let fa = total_freq.get(a).unwrap_or(&0);
221 let fb = total_freq.get(b).unwrap_or(&0);
222 fb.cmp(fa).then_with(|| a.cmp(b))
223 });
224 vocab.truncate(max_f);
225 vocab.sort(); // restore alphabetical order for consistent indexing
226 }
227
228 // Post-pruning empty-vocabulary error. sklearn's `_limit_features` raises
229 // `ValueError("After pruning, no terms remain. Try a lower min_df or a
230 // higher max_df.")` when the df/max_features filter removes every term
231 // (`sklearn/feature_extraction/text.py:1236-1239`).
232 if vocab.is_empty() {
233 return Err(FerroError::InvalidParameter {
234 name: "vocabulary".into(),
235 reason: "After pruning, no terms remain. Try a lower min_df or a higher max_df."
236 .into(),
237 });
238 }
239
240 // Build vocabulary mapping.
241 let vocabulary: HashMap<String, usize> = vocab
242 .iter()
243 .enumerate()
244 .map(|(i, t)| (t.clone(), i))
245 .collect();
246
247 Ok(FittedCountVectorizer {
248 vocabulary,
249 sorted_terms: vocab,
250 binary: self.binary,
251 lowercase: self.lowercase,
252 })
253 }
254}
255
256impl Default for CountVectorizer {
257 fn default() -> Self {
258 Self::new()
259 }
260}
261
262// ---------------------------------------------------------------------------
263// FittedCountVectorizer
264// ---------------------------------------------------------------------------
265
/// The learned state of a count vectorizer, produced by
/// [`CountVectorizer::fit`].
///
/// Holds the term → column mapping together with the column-ordered term
/// list, plus the `binary` and `lowercase` settings that `transform` reapplies.
#[derive(Clone, Debug)]
pub struct FittedCountVectorizer {
    /// Lookup table from each vocabulary term to its output column.
    vocabulary: HashMap<String, usize>,
    /// Vocabulary terms in column order (sorted, for deterministic columns).
    sorted_terms: Vec<String>,
    /// When `true`, transformed counts are clipped to presence (0/1).
    binary: bool,
    /// When `true`, documents are lowercased before tokenizing.
    lowercase: bool,
}
280
281impl FittedCountVectorizer {
282 /// Return the vocabulary as a sorted slice of terms.
283 #[must_use]
284 pub fn vocabulary(&self) -> &[String] {
285 &self.sorted_terms
286 }
287
288 /// Return the vocabulary mapping (term -> column index).
289 #[must_use]
290 pub fn vocabulary_map(&self) -> &HashMap<String, usize> {
291 &self.vocabulary
292 }
293
294 /// Transform documents into a term-count matrix.
295 ///
296 /// # Errors
297 ///
298 /// Returns [`FerroError::InsufficientSamples`] if `docs` is empty.
299 pub fn transform(&self, docs: &[String]) -> Result<Array2<f64>, FerroError> {
300 if docs.is_empty() {
301 return Err(FerroError::InsufficientSamples {
302 required: 1,
303 actual: 0,
304 context: "FittedCountVectorizer::transform".into(),
305 });
306 }
307
308 let n_docs = docs.len();
309 let n_vocab = self.sorted_terms.len();
310 let mut matrix = Array2::<f64>::zeros((n_docs, n_vocab));
311
312 for (i, doc) in docs.iter().enumerate() {
313 let tokens = tokenize(doc, self.lowercase);
314 for tok in tokens {
315 if let Some(&col) = self.vocabulary.get(&tok) {
316 if self.binary {
317 matrix[[i, col]] = 1.0;
318 } else {
319 matrix[[i, col]] += 1.0;
320 }
321 }
322 }
323 }
324
325 Ok(matrix)
326 }
327}
328
329// ---------------------------------------------------------------------------
330// Tokenizer
331// ---------------------------------------------------------------------------
332
/// Split a document into tokens the way scikit-learn's default `token_pattern`
/// does.
///
/// The sklearn 1.5.2 default is `token_pattern=r"(?u)\b\w\w+\b"`
/// (`sklearn/feature_extraction/text.py:1161`). It matches maximal runs of
/// two or more word characters. Here a word character is any Unicode
/// alphanumeric (`char::is_alphanumeric`) or `_`, so one-character runs are
/// discarded. When `lowercase` is set, the whole document is lowercased before
/// splitting.
///
/// NOTE(review): `char::is_alphanumeric` accepts some combining marks
/// (Unicode `Other_Alphabetic`, e.g. Indic vowel signs) that Python's `\w`
/// may not; confirm parity for such scripts.
fn tokenize(doc: &str, lowercase: bool) -> Vec<String> {
    let is_word_char = |c: char| c == '_' || c.is_alphanumeric();

    let folded = if lowercase {
        doc.to_lowercase()
    } else {
        doc.to_owned()
    };

    folded
        .split(|c: char| !is_word_char(c))
        // `\w\w+` needs at least two word characters. This also drops the
        // empty fragments produced by adjacent separators.
        .filter(|run| run.chars().nth(1).is_some())
        .map(String::from)
        .collect()
}
353
354// ---------------------------------------------------------------------------
355// Tests
356// ---------------------------------------------------------------------------
357
#[cfg(test)]
mod tests {
    //! Unit tests for `CountVectorizer` / `FittedCountVectorizer`, including
    //! the sklearn-parity error paths (REQ-14a) and tokenizer edge cases.

    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_count_vectorizer_basic() {
        let docs = vec![
            "the cat sat".to_string(),
            "the cat sat on the mat".to_string(),
        ];
        let cv = CountVectorizer::new();
        let fitted = cv.fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();

        assert_eq!(counts.nrows(), 2);
        let vocab = fitted.vocabulary();
        assert!(vocab.contains(&"cat".to_string()));
        assert!(vocab.contains(&"the".to_string()));
        assert!(vocab.contains(&"sat".to_string()));

        // "the" appears once in doc 0, twice in doc 1
        let the_idx = fitted.vocabulary_map()["the"];
        assert_abs_diff_eq!(counts[[0, the_idx]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(counts[[1, the_idx]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_binary() {
        let docs = vec!["the the the".to_string()];
        let cv = CountVectorizer::new().binary(true);
        let fitted = cv.fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();
        // "the" count should be 1 (binary mode)
        assert_abs_diff_eq!(counts[[0, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_lowercase() {
        let docs = vec!["Hello HELLO hello".to_string()];
        let cv = CountVectorizer::new();
        let fitted = cv.fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();
        // All should fold to "hello", count = 3
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_abs_diff_eq!(counts[[0, 0]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_no_lowercase() {
        let docs = vec!["Hello hello".to_string()];
        let cv = CountVectorizer::new().lowercase(false);
        let fitted = cv.fit(&docs).unwrap();
        // "Hello" and "hello" are different tokens
        assert_eq!(fitted.vocabulary().len(), 2);
    }

    /// max_features keeps the top-N terms by total corpus frequency.
    ///
    /// LIVE oracle (sklearn 1.5.2):
    ///   CountVectorizer(max_features=3).fit_transform(
    ///       ['cat cat cat dog dog bird ant','cat dog bird'])
    ///   sorted(get_feature_names_out()) -> ['bird','cat','dog']
    ///   ('ant' has corpus frequency 1, the lowest, so it is dropped)
    #[test]
    fn test_count_vectorizer_max_features() {
        let docs = vec![
            "cat cat cat dog dog bird ant".to_string(),
            "cat dog bird".to_string(),
        ];
        let cv = CountVectorizer::new().max_features(3);
        let fitted = cv.fit(&docs).unwrap();
        let mut vocab = fitted.vocabulary().to_vec();
        vocab.sort();
        assert_eq!(vocab, vec!["bird", "cat", "dog"]);
    }

    /// Equal-frequency candidates for max_features are kept alphabetically.
    #[test]
    fn test_count_vectorizer_max_features_tie_is_alphabetical() {
        let docs = vec!["bb aa cc".to_string()];
        let fitted = CountVectorizer::new().max_features(2).fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary(), ["aa".to_string(), "bb".to_string()]);
    }

    #[test]
    fn test_count_vectorizer_max_features_zero_is_error() {
        let docs = vec!["cat dog".to_string()];
        let result = CountVectorizer::new().max_features(0).fit(&docs);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    #[test]
    fn test_count_vectorizer_min_df() {
        let docs = vec![
            "cat dog".to_string(),
            "cat bird".to_string(),
            "cat fish".to_string(),
        ];
        // Only "cat" appears in all 3 docs
        let cv = CountVectorizer::new().min_df(3);
        let fitted = cv.fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_eq!(fitted.vocabulary()[0], "cat");
    }

    #[test]
    fn test_count_vectorizer_max_df() {
        let docs = vec![
            "the cat".to_string(),
            "the dog".to_string(),
            "the bird".to_string(),
        ];
        // "the" appears in 100% of docs. max_df=0.5 should exclude it.
        let cv = CountVectorizer::new().max_df(0.5);
        let fitted = cv.fit(&docs).unwrap();
        assert!(!fitted.vocabulary().contains(&"the".to_string()));
    }

    /// `max_df` outside `(0, 1]`, including NaN, is rejected.
    #[test]
    fn test_count_vectorizer_max_df_out_of_range() {
        let docs = vec!["cat dog".to_string()];
        for bad in [0.0, -0.5, 1.5, f64::NAN, f64::INFINITY] {
            let result = CountVectorizer::new().max_df(bad).fit(&docs);
            assert!(
                matches!(result, Err(FerroError::InvalidParameter { .. })),
                "max_df = {bad} should be rejected"
            );
        }
    }

    /// sklearn `text.py:1381-1382`: max_df bound below the min_df bound.
    #[test]
    fn test_count_vectorizer_max_df_below_min_df_is_error() {
        let docs = vec!["cat dog".to_string(), "cat bird".to_string()];
        // max_df bound = 0.5 * 2 = 1.0 < min_df = 2
        let result = CountVectorizer::new().max_df(0.5).min_df(2).fit(&docs);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    /// sklearn `text.py:1277-1279`: every token dropped by the token pattern.
    #[test]
    fn test_count_vectorizer_empty_vocabulary_is_error() {
        let docs = vec!["a b c".to_string(), "x y".to_string()];
        let result = CountVectorizer::new().fit(&docs);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    /// sklearn `text.py:1236-1239`: df pruning removes every term.
    #[test]
    fn test_count_vectorizer_post_prune_empty_is_error() {
        let docs = vec!["cat dog".to_string(), "bird fish".to_string()];
        let result = CountVectorizer::new().min_df(2).fit(&docs);
        assert!(matches!(result, Err(FerroError::InvalidParameter { .. })));
    }

    /// `_` is a word character and length-1 runs are dropped (`\w\w+`).
    #[test]
    fn test_count_vectorizer_token_pattern_edges() {
        let docs = vec!["a_b x __ y2".to_string()];
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary(), ["__", "a_b", "y2"].map(String::from));
    }

    /// Columns follow alphabetical order and agree with the mapping.
    #[test]
    fn test_count_vectorizer_vocabulary_sorted_and_consistent() {
        let docs = vec!["zebra apple mango".to_string()];
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary(), ["apple", "mango", "zebra"].map(String::from));
        for (i, term) in fitted.vocabulary().iter().enumerate() {
            assert_eq!(fitted.vocabulary_map()[term], i);
        }
    }

    #[test]
    fn test_count_vectorizer_empty_corpus() {
        let docs: Vec<String> = vec![];
        let cv = CountVectorizer::new();
        assert!(cv.fit(&docs).is_err());
    }

    #[test]
    fn test_count_vectorizer_transform_empty() {
        let docs = vec!["hello world".to_string()];
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let empty: Vec<String> = vec![];
        assert!(fitted.transform(&empty).is_err());
    }

    #[test]
    fn test_count_vectorizer_unseen_tokens() {
        let train = vec!["cat dog".to_string()];
        let fitted = CountVectorizer::new().fit(&train).unwrap();
        let test = vec!["fish bird".to_string()];
        let counts = fitted.transform(&test).unwrap();
        // All zeros since no tokens match
        for &v in &counts {
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
        }
    }
}