Skip to main content

ferrolearn_preprocess/
count_vectorizer.rs

1//! Count vectorizer: convert text documents to a term-count matrix.
2//!
3//! Tokenizes documents into runs of 2+ word characters (the Rust analog of
4//! scikit-learn's default `token_pattern=r"(?u)\b\w\w+\b"`,
5//! `sklearn/feature_extraction/text.py:1161`), builds an alphabetically-sorted
6//! vocabulary, and produces a term-count matrix of shape `(n_docs, n_vocab)`.
7//!
8//! Translation target: scikit-learn 1.5.2 `class CountVectorizer` (`text.py:929`).
9//! Design: `.design/preprocess/count_vectorizer.md`. Tracking: #1216.
10//!
11//! `## REQ status`
12//!
13//! | REQ | Status | Anchor |
14//! |---|---|---|
15//! | REQ-1 default fit/transform, sorted vocab, count matrix | SHIPPED (scoped: dense) | `CountVectorizer::fit` / `FittedCountVectorizer::transform`; sklearn `_count_vocab` `text.py:1242-1305` |
16//! | REQ-2 default token_pattern (drop length-1, `_` word char) | SHIPPED (#1217) | `fn tokenize`; sklearn `text.py:1161`, `build_tokenizer:350` |
17//! | REQ-3 binary count clipping | SHIPPED | `FittedCountVectorizer::transform`; sklearn `text.py:1374` |
18//! | REQ-4 lowercase toggle | SHIPPED | `fn tokenize`; sklearn `text.py:1157`,`:323` |
19//! | REQ-5 max_df/min_df int-vs-float duality + threshold errors | NOT-STARTED (#1219; ceil sub-fix shipped #1218; max_df<min_df + post-prune empty-vocab errors shipped #2337) | `fit` df-filter; sklearn `text.py:1379-1382`,`:1236-1239` |
20//! | REQ-6 ngram_range word n-grams | NOT-STARTED (#1220) | sklearn `_word_ngrams` `text.py:242` |
21//! | REQ-7 max_features top-N + tie/sort | SHIPPED (scoped) | `fit`; sklearn `_limit_features` `text.py:1222-1227` |
22//! | REQ-8 tokenizer/token_pattern/preprocessor/analyzer/strip_accents | NOT-STARTED (#1221) | sklearn `build_analyzer` `text.py:419` |
23//! | REQ-9 stop_words | NOT-STARTED (#1222) | sklearn `get_stop_words` `text.py:370` |
24//! | REQ-10 fixed vocabulary param + dtype | NOT-STARTED (#1223) | sklearn `_count_vocab` `text.py:1242-1244`,`:1147` |
25//! | REQ-11 sparse CSR output | NOT-STARTED (#1224) | sklearn `_count_vocab` `text.py:1299-1304` |
26//! | REQ-12 get_feature_names_out contract | NOT-STARTED (#1225) | sklearn `text.py:1455` |
27//! | REQ-13 HashingVectorizer | NOT-STARTED (#1226) | sklearn `class HashingVectorizer` `text.py:562` |
28//! | REQ-14 full 16-param ctor + _parameter_constraints | NOT-STARTED (#1227) | sklearn `text.py:1124-1148` |
29//! | REQ-14a empty-vocabulary ValueError parity (post-tokenize + max_df<min_df + post-prune) | SHIPPED (#2336 #2337) | `CountVectorizer::fit` empty-vocab/`max_df`/post-prune `Err(InvalidParameter)`; sklearn `text.py:1277-1279`,`:1381-1382`,`:1236-1239`. Consumer: crate re-export `pub use count_vectorizer::CountVectorizer` (`lib.rs`). |
30//! | REQ-15 PyO3 binding | NOT-STARTED (#1228) | `ferrolearn-python/src/transformers.rs` (absent) |
31
32use std::collections::HashMap;
33
34use ferrolearn_core::error::FerroError;
35use ndarray::Array2;
36
37// ---------------------------------------------------------------------------
38// CountVectorizer (unfitted)
39// ---------------------------------------------------------------------------
40
/// An unfitted count vectorizer.
///
/// Tokenizes each document into runs of two or more word characters
/// (Unicode alphanumerics or `_`; single-character tokens are dropped),
/// optionally lowercasing the text first, builds a vocabulary sorted
/// alphabetically, and transforms documents into a dense term-count matrix.
///
/// Configure it through the public fields or the chainable builder methods,
/// then call [`CountVectorizer::fit`] to learn the vocabulary.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::count_vectorizer::{CountVectorizer, FittedCountVectorizer};
///
/// let docs = vec![
///     "the cat sat".to_string(),
///     "the cat sat on the mat".to_string(),
/// ];
/// let cv = CountVectorizer::new();
/// let fitted = cv.fit(&docs).unwrap();
/// let counts = fitted.transform(&docs).unwrap();
/// assert_eq!(counts.nrows(), 2);
/// assert_eq!(counts.ncols(), fitted.vocabulary().len());
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct CountVectorizer {
    /// Maximum number of features (vocabulary size). `None` means no limit.
    ///
    /// When more terms survive the document-frequency filter, only the terms
    /// with the highest total corpus frequency are kept (ties are broken
    /// alphabetically).
    pub max_features: Option<usize>,
    /// Minimum document frequency (absolute count) for a term to be included.
    /// Terms appearing in fewer than `min_df` documents are excluded.
    pub min_df: usize,
    /// Maximum document frequency as a fraction of total documents.
    /// Terms appearing in more than `max_df * n_docs` documents are excluded.
    /// `fit` rejects values outside `(0, 1]`.
    pub max_df: f64,
    /// If `true`, all counts are clipped to 0/1 (binary occurrence).
    pub binary: bool,
    /// If `true`, lowercase all tokens before counting.
    pub lowercase: bool,
}
76
77impl CountVectorizer {
78    /// Create a new `CountVectorizer` with default settings.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            max_features: None,
83            min_df: 1,
84            max_df: 1.0,
85            binary: false,
86            lowercase: true,
87        }
88    }
89
90    /// Set the maximum number of features.
91    #[must_use]
92    pub fn max_features(mut self, n: usize) -> Self {
93        self.max_features = Some(n);
94        self
95    }
96
97    /// Set the minimum document frequency.
98    #[must_use]
99    pub fn min_df(mut self, min_df: usize) -> Self {
100        self.min_df = min_df;
101        self
102    }
103
104    /// Set the maximum document frequency as a fraction of total documents.
105    #[must_use]
106    pub fn max_df(mut self, max_df: f64) -> Self {
107        self.max_df = max_df;
108        self
109    }
110
111    /// Enable or disable binary mode.
112    #[must_use]
113    pub fn binary(mut self, binary: bool) -> Self {
114        self.binary = binary;
115        self
116    }
117
118    /// Enable or disable lowercasing.
119    #[must_use]
120    pub fn lowercase(mut self, lowercase: bool) -> Self {
121        self.lowercase = lowercase;
122        self
123    }
124
125    /// Fit the vectorizer on a corpus of documents.
126    ///
127    /// # Errors
128    ///
129    /// Returns [`FerroError::InsufficientSamples`] if the corpus is empty.
130    /// Returns [`FerroError::InvalidParameter`] if `max_df` is not in `(0, 1]`.
131    pub fn fit(&self, docs: &[String]) -> Result<FittedCountVectorizer, FerroError> {
132        let n_docs = docs.len();
133        if n_docs == 0 {
134            return Err(FerroError::InsufficientSamples {
135                required: 1,
136                actual: 0,
137                context: "CountVectorizer::fit".into(),
138            });
139        }
140        if self.max_df <= 0.0 || self.max_df > 1.0 {
141            return Err(FerroError::InvalidParameter {
142                name: "max_df".into(),
143                reason: format!("must be in (0, 1], got {}", self.max_df),
144            });
145        }
146
147        // Build document-frequency counts.
148        let mut df_counts: HashMap<String, usize> = HashMap::new();
149        for doc in docs {
150            let tokens = tokenize(doc, self.lowercase);
151            // Unique tokens per document.
152            let mut seen = std::collections::HashSet::new();
153            for tok in tokens {
154                if seen.insert(tok.clone()) {
155                    *df_counts.entry(tok).or_insert(0) += 1;
156                }
157            }
158        }
159
160        // Empty-vocabulary error (before df-pruning). sklearn's `_count_vocab`
161        // raises `ValueError("empty vocabulary; perhaps the documents only
162        // contain stop words")` when the assembled vocabulary is empty
163        // (`sklearn/feature_extraction/text.py:1277-1279`). This fires when every
164        // token is dropped by the token_pattern (e.g. all length-1 tokens).
165        if df_counts.is_empty() {
166            return Err(FerroError::InvalidParameter {
167                name: "vocabulary".into(),
168                reason: "empty vocabulary; perhaps the documents only contain stop words".into(),
169            });
170        }
171
172        // max_df-vs-min_df cross-validation. sklearn computes the document-count
173        // bounds (`text.py:1379-1380`) and raises
174        // `ValueError("max_df corresponds to < documents than min_df")` when the
175        // max_df bound is below the min_df bound (`text.py:1381-1382`). Here
176        // `max_df` is a float proportion (bound = `max_df * n_doc`) and `min_df`
177        // is an absolute document count (bound = `min_df`).
178        let max_df_count = self.max_df * n_docs as f64;
179        let min_doc_count = self.min_df as f64;
180        if max_df_count < min_doc_count {
181            return Err(FerroError::InvalidParameter {
182                name: "max_df".into(),
183                reason: "max_df corresponds to < documents than min_df".into(),
184            });
185        }
186
187        // Filter by min_df and max_df.
188        //
189        // sklearn 1.5.2 computes `max_doc_count = max_df * n_doc` as a FLOAT with
190        // NO rounding (`sklearn/feature_extraction/text.py:1379`) and keeps terms
191        // with `df <= max_doc_count` (`_limit_features`, `text.py:1219`:
192        // `mask &= dfs <= high`). We mirror that exactly: compare the integer
193        // document count against the un-rounded float threshold. (Note: sklearn
194        // also accepts an integer `max_df` as an absolute count; that int-vs-float
195        // duality is a separate gap and is intentionally not implemented here.)
196        // (`max_df_count` is computed above for the max_df-vs-min_df check.)
197        let mut vocab: Vec<String> = df_counts
198            .into_iter()
199            .filter(|(_, count)| *count >= self.min_df && (*count as f64) <= max_df_count)
200            .map(|(term, _)| term)
201            .collect();
202        vocab.sort();
203
204        // Apply max_features: keep the top-N by total corpus frequency.
205        if let Some(max_f) = self.max_features
206            && vocab.len() > max_f
207        {
208            // Re-count total frequencies for the remaining terms.
209            let mut total_freq: HashMap<String, usize> = HashMap::new();
210            for doc in docs {
211                let tokens = tokenize(doc, self.lowercase);
212                for tok in tokens {
213                    if vocab.binary_search(&tok).is_ok() {
214                        *total_freq.entry(tok).or_insert(0) += 1;
215                    }
216                }
217            }
218            // Sort by descending frequency, then alphabetically for ties.
219            vocab.sort_by(|a, b| {
220                let fa = total_freq.get(a).unwrap_or(&0);
221                let fb = total_freq.get(b).unwrap_or(&0);
222                fb.cmp(fa).then_with(|| a.cmp(b))
223            });
224            vocab.truncate(max_f);
225            vocab.sort(); // restore alphabetical order for consistent indexing
226        }
227
228        // Post-pruning empty-vocabulary error. sklearn's `_limit_features` raises
229        // `ValueError("After pruning, no terms remain. Try a lower min_df or a
230        // higher max_df.")` when the df/max_features filter removes every term
231        // (`sklearn/feature_extraction/text.py:1236-1239`).
232        if vocab.is_empty() {
233            return Err(FerroError::InvalidParameter {
234                name: "vocabulary".into(),
235                reason: "After pruning, no terms remain. Try a lower min_df or a higher max_df."
236                    .into(),
237            });
238        }
239
240        // Build vocabulary mapping.
241        let vocabulary: HashMap<String, usize> = vocab
242            .iter()
243            .enumerate()
244            .map(|(i, t)| (t.clone(), i))
245            .collect();
246
247        Ok(FittedCountVectorizer {
248            vocabulary,
249            sorted_terms: vocab,
250            binary: self.binary,
251            lowercase: self.lowercase,
252        })
253    }
254}
255
256impl Default for CountVectorizer {
257    fn default() -> Self {
258        Self::new()
259    }
260}
261
262// ---------------------------------------------------------------------------
263// FittedCountVectorizer
264// ---------------------------------------------------------------------------
265
/// A fitted count vectorizer holding the learned vocabulary.
///
/// Created by calling [`CountVectorizer::fit`].
///
/// `vocabulary` and `sorted_terms` describe the same vocabulary: `fit` builds
/// the map by enumerating the sorted terms, so `vocabulary[term]` is the
/// position of `term` in `sorted_terms`, which is also its column in the
/// matrix produced by `transform`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FittedCountVectorizer {
    /// Map from term to column index.
    vocabulary: HashMap<String, usize>,
    /// Sorted vocabulary terms (for deterministic column ordering).
    sorted_terms: Vec<String>,
    /// Whether to clip counts to binary (0/1) during `transform`.
    binary: bool,
    /// Whether to lowercase documents before tokenizing during `transform`.
    lowercase: bool,
}
280
281impl FittedCountVectorizer {
282    /// Return the vocabulary as a sorted slice of terms.
283    #[must_use]
284    pub fn vocabulary(&self) -> &[String] {
285        &self.sorted_terms
286    }
287
288    /// Return the vocabulary mapping (term -> column index).
289    #[must_use]
290    pub fn vocabulary_map(&self) -> &HashMap<String, usize> {
291        &self.vocabulary
292    }
293
294    /// Transform documents into a term-count matrix.
295    ///
296    /// # Errors
297    ///
298    /// Returns [`FerroError::InsufficientSamples`] if `docs` is empty.
299    pub fn transform(&self, docs: &[String]) -> Result<Array2<f64>, FerroError> {
300        if docs.is_empty() {
301            return Err(FerroError::InsufficientSamples {
302                required: 1,
303                actual: 0,
304                context: "FittedCountVectorizer::transform".into(),
305            });
306        }
307
308        let n_docs = docs.len();
309        let n_vocab = self.sorted_terms.len();
310        let mut matrix = Array2::<f64>::zeros((n_docs, n_vocab));
311
312        for (i, doc) in docs.iter().enumerate() {
313            let tokens = tokenize(doc, self.lowercase);
314            for tok in tokens {
315                if let Some(&col) = self.vocabulary.get(&tok) {
316                    if self.binary {
317                        matrix[[i, col]] = 1.0;
318                    } else {
319                        matrix[[i, col]] += 1.0;
320                    }
321                }
322            }
323        }
324
325        Ok(matrix)
326    }
327}
328
329// ---------------------------------------------------------------------------
330// Tokenizer
331// ---------------------------------------------------------------------------
332
/// Tokenize a document, matching scikit-learn's default `token_pattern`.
///
/// sklearn 1.5.2 defaults to `token_pattern=r"(?u)\b\w\w+\b"`
/// (`sklearn/feature_extraction/text.py:1161`), which matches maximal runs of
/// 2+ word characters; under `(?u)`, `\w` covers Unicode letters and digits
/// as well as `_`. We therefore treat a char as part of a token iff it is
/// alphanumeric or `_` (`char::is_alphanumeric` is Unicode-aware and is the
/// closest standard-library analog of `\w`), and keep only tokens of
/// length >= 2 chars, dropping single-char tokens.
///
/// When `lowercase` is `true` the whole document is lowercased before it is
/// split; otherwise the text is tokenized as-is without copying it first.
/// Tokens are returned in document order, duplicates included.
fn tokenize(doc: &str, lowercase: bool) -> Vec<String> {
    // Only allocate a new string when lowercasing actually requires one.
    let text: std::borrow::Cow<'_, str> = if lowercase {
        doc.to_lowercase().into()
    } else {
        doc.into()
    };

    text.split(|c: char| !(c.is_alphanumeric() || c == '_'))
        // A second char exists iff the run is at least two chars long; this
        // also discards the empty pieces produced by adjacent separators and
        // stops after two chars instead of walking the whole run.
        .filter(|run| run.chars().nth(1).is_some())
        .map(str::to_owned)
        .collect()
}
353
354// ---------------------------------------------------------------------------
355// Tests
356// ---------------------------------------------------------------------------
357
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Turn string literals into the owned corpus type `fit`/`transform` take.
    fn corpus(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|text| (*text).to_string()).collect()
    }

    /// Default settings: every term is learned and counted per document.
    #[test]
    fn test_count_vectorizer_basic() {
        let docs = corpus(&["the cat sat", "the cat sat on the mat"]);
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();

        assert_eq!(counts.nrows(), 2);
        for term in ["cat", "the", "sat"] {
            assert!(fitted.vocabulary().iter().any(|known| known.as_str() == term));
        }

        // "the" occurs once in the first document and twice in the second.
        let the_idx = fitted.vocabulary_map()["the"];
        assert_abs_diff_eq!(counts[[0, the_idx]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(counts[[1, the_idx]], 2.0, epsilon = 1e-10);
    }

    /// Binary mode reports presence (1.0) instead of the occurrence count.
    #[test]
    fn test_count_vectorizer_binary() {
        let docs = corpus(&["the the the"]);
        let fitted = CountVectorizer::new().binary(true).fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();
        assert_abs_diff_eq!(counts[[0, 0]], 1.0, epsilon = 1e-10);
    }

    /// With lowercasing on (the default) case variants fold into one term.
    #[test]
    fn test_count_vectorizer_lowercase() {
        let docs = corpus(&["Hello HELLO hello"]);
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();
        // Everything folds to "hello", which is therefore counted three times.
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_abs_diff_eq!(counts[[0, 0]], 3.0, epsilon = 1e-10);
    }

    /// With lowercasing off, "Hello" and "hello" stay distinct terms.
    #[test]
    fn test_count_vectorizer_no_lowercase() {
        let docs = corpus(&["Hello hello"]);
        let fitted = CountVectorizer::new().lowercase(false).fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary().len(), 2);
    }

    /// max_features keeps the top-N terms by total corpus frequency.
    ///
    /// LIVE oracle (sklearn 1.5.2):
    ///   CountVectorizer(max_features=3).fit_transform(
    ///       ['cat cat cat dog dog bird ant','cat dog bird'])
    ///   sorted(get_feature_names_out()) -> ['bird','cat','dog']
    ///   ('ant' has corpus frequency 1, the lowest, so it is dropped)
    #[test]
    fn test_count_vectorizer_max_features() {
        let docs = corpus(&["cat cat cat dog dog bird ant", "cat dog bird"]);
        let fitted = CountVectorizer::new().max_features(3).fit(&docs).unwrap();
        let mut kept = fitted.vocabulary().to_vec();
        kept.sort();
        assert_eq!(kept, vec!["bird", "cat", "dog"]);
    }

    /// min_df = 3 keeps only the term present in all three documents.
    #[test]
    fn test_count_vectorizer_min_df() {
        let docs = corpus(&["cat dog", "cat bird", "cat fish"]);
        let fitted = CountVectorizer::new().min_df(3).fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_eq!(fitted.vocabulary()[0], "cat");
    }

    /// "the" appears in 100% of the documents, so max_df = 0.5 excludes it.
    #[test]
    fn test_count_vectorizer_max_df() {
        let docs = corpus(&["the cat", "the dog", "the bird"]);
        let fitted = CountVectorizer::new().max_df(0.5).fit(&docs).unwrap();
        assert!(!fitted.vocabulary().iter().any(|known| known.as_str() == "the"));
    }

    /// Fitting on an empty corpus is an error.
    #[test]
    fn test_count_vectorizer_empty_corpus() {
        let docs = corpus(&[]);
        assert!(CountVectorizer::new().fit(&docs).is_err());
    }

    /// Transforming an empty batch is an error.
    #[test]
    fn test_count_vectorizer_transform_empty() {
        let docs = corpus(&["hello world"]);
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let nothing = corpus(&[]);
        assert!(fitted.transform(&nothing).is_err());
    }

    /// Tokens never seen during fit leave the output all zeros.
    #[test]
    fn test_count_vectorizer_unseen_tokens() {
        let train = corpus(&["cat dog"]);
        let fitted = CountVectorizer::new().fit(&train).unwrap();
        let unseen = corpus(&["fish bird"]);
        let counts = fitted.transform(&unseen).unwrap();
        for value in counts.iter() {
            assert_abs_diff_eq!(*value, 0.0, epsilon = 1e-10);
        }
    }
}