// ferrolearn_preprocess/count_vectorizer.rs
1//! Count vectorizer: convert text documents to a term-count matrix.
2//!
3//! Tokenizes documents by splitting on non-alphanumeric characters, builds a
4//! vocabulary, and produces a term-count matrix of shape `(n_docs, n_vocab)`.
5
6use std::collections::HashMap;
7
8use ferrolearn_core::error::FerroError;
9use ndarray::Array2;
10
11// ---------------------------------------------------------------------------
12// CountVectorizer (unfitted)
13// ---------------------------------------------------------------------------
14
/// Configuration for a count vectorizer that has not yet been fitted.
///
/// Documents are tokenized by splitting on non-alphanumeric boundaries; the
/// learned vocabulary is kept in alphabetical order so column indices are
/// deterministic. Calling [`CountVectorizer::fit`] produces a
/// [`FittedCountVectorizer`] whose `transform` yields a `(n_docs, n_vocab)`
/// term-count matrix.
///
/// # Examples
///
/// ```
/// use ferrolearn_preprocess::count_vectorizer::{CountVectorizer, FittedCountVectorizer};
///
/// let docs = vec![
///     "the cat sat".to_string(),
///     "the cat sat on the mat".to_string(),
/// ];
/// let cv = CountVectorizer::new();
/// let fitted = cv.fit(&docs).unwrap();
/// let counts = fitted.transform(&docs).unwrap();
/// assert_eq!(counts.nrows(), 2);
/// assert_eq!(counts.ncols(), fitted.vocabulary().len());
/// ```
#[derive(Debug, Clone)]
pub struct CountVectorizer {
    /// Maximum number of features (vocabulary size). `None` means no limit.
    pub max_features: Option<usize>,
    /// Minimum document frequency (absolute count) for a term to be included.
    pub min_df: usize,
    /// Maximum document frequency as a fraction of total documents.
    /// Terms appearing in more than `max_df * n_docs` documents are excluded.
    pub max_df: f64,
    /// If `true`, all counts are clipped to 0/1 (binary occurrence).
    pub binary: bool,
    /// If `true`, lowercase all tokens before counting.
    pub lowercase: bool,
}
50
51impl CountVectorizer {
52    /// Create a new `CountVectorizer` with default settings.
53    #[must_use]
54    pub fn new() -> Self {
55        Self {
56            max_features: None,
57            min_df: 1,
58            max_df: 1.0,
59            binary: false,
60            lowercase: true,
61        }
62    }
63
64    /// Set the maximum number of features.
65    #[must_use]
66    pub fn max_features(mut self, n: usize) -> Self {
67        self.max_features = Some(n);
68        self
69    }
70
71    /// Set the minimum document frequency.
72    #[must_use]
73    pub fn min_df(mut self, min_df: usize) -> Self {
74        self.min_df = min_df;
75        self
76    }
77
78    /// Set the maximum document frequency as a fraction of total documents.
79    #[must_use]
80    pub fn max_df(mut self, max_df: f64) -> Self {
81        self.max_df = max_df;
82        self
83    }
84
85    /// Enable or disable binary mode.
86    #[must_use]
87    pub fn binary(mut self, binary: bool) -> Self {
88        self.binary = binary;
89        self
90    }
91
92    /// Enable or disable lowercasing.
93    #[must_use]
94    pub fn lowercase(mut self, lowercase: bool) -> Self {
95        self.lowercase = lowercase;
96        self
97    }
98
99    /// Fit the vectorizer on a corpus of documents.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`FerroError::InsufficientSamples`] if the corpus is empty.
104    /// Returns [`FerroError::InvalidParameter`] if `max_df` is not in `(0, 1]`.
105    pub fn fit(&self, docs: &[String]) -> Result<FittedCountVectorizer, FerroError> {
106        let n_docs = docs.len();
107        if n_docs == 0 {
108            return Err(FerroError::InsufficientSamples {
109                required: 1,
110                actual: 0,
111                context: "CountVectorizer::fit".into(),
112            });
113        }
114        if self.max_df <= 0.0 || self.max_df > 1.0 {
115            return Err(FerroError::InvalidParameter {
116                name: "max_df".into(),
117                reason: format!("must be in (0, 1], got {}", self.max_df),
118            });
119        }
120
121        // Build document-frequency counts.
122        let mut df_counts: HashMap<String, usize> = HashMap::new();
123        for doc in docs {
124            let tokens = tokenize(doc, self.lowercase);
125            // Unique tokens per document.
126            let mut seen = std::collections::HashSet::new();
127            for tok in tokens {
128                if seen.insert(tok.clone()) {
129                    *df_counts.entry(tok).or_insert(0) += 1;
130                }
131            }
132        }
133
134        // Filter by min_df and max_df.
135        let max_df_abs = (self.max_df * n_docs as f64).ceil() as usize;
136        let mut vocab: Vec<String> = df_counts
137            .into_iter()
138            .filter(|(_, count)| *count >= self.min_df && *count <= max_df_abs)
139            .map(|(term, _)| term)
140            .collect();
141        vocab.sort();
142
143        // Apply max_features: keep the top-N by total corpus frequency.
144        if let Some(max_f) = self.max_features {
145            if vocab.len() > max_f {
146                // Re-count total frequencies for the remaining terms.
147                let mut total_freq: HashMap<String, usize> = HashMap::new();
148                for doc in docs {
149                    let tokens = tokenize(doc, self.lowercase);
150                    for tok in tokens {
151                        if vocab.binary_search(&tok).is_ok() {
152                            *total_freq.entry(tok).or_insert(0) += 1;
153                        }
154                    }
155                }
156                // Sort by descending frequency, then alphabetically for ties.
157                vocab.sort_by(|a, b| {
158                    let fa = total_freq.get(a).unwrap_or(&0);
159                    let fb = total_freq.get(b).unwrap_or(&0);
160                    fb.cmp(fa).then_with(|| a.cmp(b))
161                });
162                vocab.truncate(max_f);
163                vocab.sort(); // restore alphabetical order for consistent indexing
164            }
165        }
166
167        // Build vocabulary mapping.
168        let vocabulary: HashMap<String, usize> = vocab
169            .iter()
170            .enumerate()
171            .map(|(i, t)| (t.clone(), i))
172            .collect();
173
174        Ok(FittedCountVectorizer {
175            vocabulary,
176            sorted_terms: vocab,
177            binary: self.binary,
178            lowercase: self.lowercase,
179        })
180    }
181}
182
183impl Default for CountVectorizer {
184    fn default() -> Self {
185        Self::new()
186    }
187}
188
189// ---------------------------------------------------------------------------
190// FittedCountVectorizer
191// ---------------------------------------------------------------------------
192
/// The trained state of a count vectorizer: a learned vocabulary plus the
/// tokenization options captured at fit time.
///
/// Obtained from [`CountVectorizer::fit`]; use [`FittedCountVectorizer::transform`]
/// to turn documents into a term-count matrix.
#[derive(Debug, Clone)]
pub struct FittedCountVectorizer {
    /// Map from term to column index.
    vocabulary: HashMap<String, usize>,
    /// Vocabulary terms in sorted order (deterministic column ordering).
    sorted_terms: Vec<String>,
    /// Whether counts are clipped to 0/1.
    binary: bool,
    /// Whether tokens are lowercased before lookup.
    lowercase: bool,
}
207
208impl FittedCountVectorizer {
209    /// Return the vocabulary as a sorted slice of terms.
210    #[must_use]
211    pub fn vocabulary(&self) -> &[String] {
212        &self.sorted_terms
213    }
214
215    /// Return the vocabulary mapping (term -> column index).
216    #[must_use]
217    pub fn vocabulary_map(&self) -> &HashMap<String, usize> {
218        &self.vocabulary
219    }
220
221    /// Transform documents into a term-count matrix.
222    ///
223    /// # Errors
224    ///
225    /// Returns [`FerroError::InsufficientSamples`] if `docs` is empty.
226    pub fn transform(&self, docs: &[String]) -> Result<Array2<f64>, FerroError> {
227        if docs.is_empty() {
228            return Err(FerroError::InsufficientSamples {
229                required: 1,
230                actual: 0,
231                context: "FittedCountVectorizer::transform".into(),
232            });
233        }
234
235        let n_docs = docs.len();
236        let n_vocab = self.sorted_terms.len();
237        let mut matrix = Array2::<f64>::zeros((n_docs, n_vocab));
238
239        for (i, doc) in docs.iter().enumerate() {
240            let tokens = tokenize(doc, self.lowercase);
241            for tok in tokens {
242                if let Some(&col) = self.vocabulary.get(&tok) {
243                    if self.binary {
244                        matrix[[i, col]] = 1.0;
245                    } else {
246                        matrix[[i, col]] += 1.0;
247                    }
248                }
249            }
250        }
251
252        Ok(matrix)
253    }
254}
255
256// ---------------------------------------------------------------------------
257// Tokenizer
258// ---------------------------------------------------------------------------
259
/// Tokenize a document by splitting on non-alphanumeric boundaries.
///
/// Empty fragments (produced by consecutive separators or leading/trailing
/// punctuation) are dropped. When `lowercase` is `false` the input is
/// borrowed rather than copied, avoiding a full-document allocation.
fn tokenize(doc: &str, lowercase: bool) -> Vec<String> {
    // Cow lets the non-lowercasing path work on the borrowed input directly.
    let text: std::borrow::Cow<'_, str> = if lowercase {
        std::borrow::Cow::Owned(doc.to_lowercase())
    } else {
        std::borrow::Cow::Borrowed(doc)
    };

    text.split(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .map(std::string::ToString::to_string)
        .collect()
}
273
274// ---------------------------------------------------------------------------
275// Tests
276// ---------------------------------------------------------------------------
277
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Basic fit/transform round-trip: vocabulary contents and exact counts.
    #[test]
    fn test_count_vectorizer_basic() {
        let docs = vec![
            "the cat sat".to_string(),
            "the cat sat on the mat".to_string(),
        ];
        let cv = CountVectorizer::new();
        let fitted = cv.fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();

        assert_eq!(counts.nrows(), 2);
        let vocab = fitted.vocabulary();
        assert!(vocab.contains(&"cat".to_string()));
        assert!(vocab.contains(&"the".to_string()));
        assert!(vocab.contains(&"sat".to_string()));

        // "the" appears once in doc 0, twice in doc 1
        let the_idx = fitted.vocabulary_map()["the"];
        assert_abs_diff_eq!(counts[[0, the_idx]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(counts[[1, the_idx]], 2.0, epsilon = 1e-10);
    }

    // Binary mode clips repeated occurrences to 1.
    #[test]
    fn test_count_vectorizer_binary() {
        let docs = vec!["the the the".to_string()];
        let cv = CountVectorizer::new().binary(true);
        let fitted = cv.fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();
        // "the" count should be 1 (binary mode)
        assert_abs_diff_eq!(counts[[0, 0]], 1.0, epsilon = 1e-10);
    }

    // Default lowercasing folds case variants into one vocabulary entry.
    #[test]
    fn test_count_vectorizer_lowercase() {
        let docs = vec!["Hello HELLO hello".to_string()];
        let cv = CountVectorizer::new();
        let fitted = cv.fit(&docs).unwrap();
        let counts = fitted.transform(&docs).unwrap();
        // All should fold to "hello", count = 3
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_abs_diff_eq!(counts[[0, 0]], 3.0, epsilon = 1e-10);
    }

    // With lowercasing disabled, case variants stay distinct tokens.
    #[test]
    fn test_count_vectorizer_no_lowercase() {
        let docs = vec!["Hello hello".to_string()];
        let cv = CountVectorizer::new().lowercase(false);
        let fitted = cv.fit(&docs).unwrap();
        // "Hello" and "hello" are different tokens
        assert_eq!(fitted.vocabulary().len(), 2);
    }

    // max_features caps the vocabulary size.
    #[test]
    fn test_count_vectorizer_max_features() {
        let docs = vec!["a b c d e f".to_string()];
        let cv = CountVectorizer::new().max_features(3);
        let fitted = cv.fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary().len(), 3);
    }

    // min_df excludes terms appearing in too few documents.
    #[test]
    fn test_count_vectorizer_min_df() {
        let docs = vec![
            "cat dog".to_string(),
            "cat bird".to_string(),
            "cat fish".to_string(),
        ];
        // Only "cat" appears in all 3 docs
        let cv = CountVectorizer::new().min_df(3);
        let fitted = cv.fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_eq!(fitted.vocabulary()[0], "cat");
    }

    // max_df (as a fraction) excludes terms appearing in too many documents.
    #[test]
    fn test_count_vectorizer_max_df() {
        let docs = vec![
            "the cat".to_string(),
            "the dog".to_string(),
            "the bird".to_string(),
        ];
        // "the" appears in 100% of docs. max_df=0.5 should exclude it.
        let cv = CountVectorizer::new().max_df(0.5);
        let fitted = cv.fit(&docs).unwrap();
        assert!(!fitted.vocabulary().contains(&"the".to_string()));
    }

    // Fitting on an empty corpus is an error.
    #[test]
    fn test_count_vectorizer_empty_corpus() {
        let docs: Vec<String> = vec![];
        let cv = CountVectorizer::new();
        assert!(cv.fit(&docs).is_err());
    }

    // Transforming an empty document list is an error.
    #[test]
    fn test_count_vectorizer_transform_empty() {
        let docs = vec!["hello world".to_string()];
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let empty: Vec<String> = vec![];
        assert!(fitted.transform(&empty).is_err());
    }

    // Tokens unseen during fit contribute nothing at transform time.
    #[test]
    fn test_count_vectorizer_unseen_tokens() {
        let train = vec!["cat dog".to_string()];
        let fitted = CountVectorizer::new().fit(&train).unwrap();
        let test = vec!["fish bird".to_string()];
        let counts = fitted.transform(&test).unwrap();
        // All zeros since no tokens match
        for &v in &counts {
            assert_abs_diff_eq!(v, 0.0, epsilon = 1e-10);
        }
    }
}