Skip to main content

anofox_ml_text/
lib.rs

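//! Text feature extraction.
//!
//! Mirrors `sklearn.feature_extraction.text.{CountVectorizer, TfidfVectorizer,
//! HashingVectorizer}` in their simplest form: ASCII-lowercased tokens matching
//! `[A-Za-z]{2,}` and no stop-word filtering. All three vectorizers produce
//! dense output; the count and TF-IDF vectorizers additionally offer a sparse
//! (CSR) variant via `fit_transform_sparse`.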
6
7use anofox_ml_core::{CsrMatrix, Result, RustMlError};
8use ndarray::Array2;
9use std::collections::HashMap;
10
/// Splits `s` into lowercase word tokens.
///
/// A token is a maximal run of ASCII letters; every other character
/// (digits, punctuation, whitespace, non-ASCII letters) acts as a separator.
/// Runs shorter than two letters are discarded, and the surviving runs are
/// ASCII-lowercased. Tokens are returned in order of appearance.
fn tokenize(s: &str) -> Vec<String> {
    s.split(|ch: char| !ch.is_ascii_alphabetic())
        // Each piece is pure ASCII, so byte length equals letter count.
        .filter(|word| word.len() >= 2)
        .map(|word| word.to_ascii_lowercase())
        .collect()
}
29
30// ---------------------------------------------------------------------------
31// CountVectorizer
32// ---------------------------------------------------------------------------
33
34#[derive(Debug, Clone)]
35pub struct CountVectorizer {
36    pub min_df: usize,
37    pub max_df_frac: f64,
38}
39
40impl CountVectorizer {
41    pub fn new() -> Self {
42        Self {
43            min_df: 1,
44            max_df_frac: 1.0,
45        }
46    }
47    pub fn with_min_df(mut self, m: usize) -> Self {
48        self.min_df = m;
49        self
50    }
51    pub fn with_max_df_frac(mut self, f: f64) -> Self {
52        self.max_df_frac = f;
53        self
54    }
55
56    pub fn fit_transform(&self, docs: &[&str]) -> Result<(Vec<String>, Array2<f64>)> {
57        let (vocab, csr) = self.fit_transform_sparse(docs)?;
58        Ok((vocab, csr.to_dense()))
59    }
60
61    /// Sparse-output counterpart. For high-vocab corpora the dense
62    /// `fit_transform` blows up memory (~n_docs × vocab × 8 bytes);
63    /// `fit_transform_sparse` stays at O(total_token_occurrences).
64    pub fn fit_transform_sparse(&self, docs: &[&str]) -> Result<(Vec<String>, CsrMatrix<f64>)> {
65        if docs.is_empty() {
66            return Err(RustMlError::EmptyInput("no documents".into()));
67        }
68        // Pass 1: document frequency per term.
69        let mut df: HashMap<String, usize> = HashMap::new();
70        let tokenised: Vec<Vec<String>> = docs.iter().map(|d| tokenize(d)).collect();
71        for tokens in &tokenised {
72            let mut seen = std::collections::HashSet::new();
73            for t in tokens {
74                if seen.insert(t.clone()) {
75                    *df.entry(t.clone()).or_default() += 1;
76                }
77            }
78        }
79        let n = docs.len();
80        let max_df = (self.max_df_frac * n as f64).floor() as usize;
81        let mut vocab: Vec<String> = df
82            .iter()
83            .filter(|(_, &c)| c >= self.min_df && c <= max_df.max(self.min_df))
84            .map(|(k, _)| k.clone())
85            .collect();
86        vocab.sort();
87        let term_to_col: HashMap<String, usize> = vocab
88            .iter()
89            .enumerate()
90            .map(|(i, w)| (w.clone(), i))
91            .collect();
92
93        // Aggregate counts per (doc, col) and emit triplets.
94        let mut triplets: Vec<(usize, usize, f64)> = Vec::new();
95        for (i, tokens) in tokenised.iter().enumerate() {
96            let mut row_counts: HashMap<usize, f64> = HashMap::new();
97            for t in tokens {
98                if let Some(&c) = term_to_col.get(t) {
99                    *row_counts.entry(c).or_default() += 1.0;
100                }
101            }
102            for (c, v) in row_counts {
103                triplets.push((i, c, v));
104            }
105        }
106        let csr = CsrMatrix::from_triplets(n, vocab.len(), triplets);
107        Ok((vocab, csr))
108    }
109}
110
111impl Default for CountVectorizer {
112    fn default() -> Self {
113        Self::new()
114    }
115}
116
117// ---------------------------------------------------------------------------
118// TfidfVectorizer (sklearn's smooth_idf=True, sublinear_tf=False, l2-norm)
119// ---------------------------------------------------------------------------
120
121#[derive(Debug, Clone)]
122pub struct TfidfVectorizer {
123    pub min_df: usize,
124    pub max_df_frac: f64,
125    pub norm_l2: bool,
126}
127
128impl TfidfVectorizer {
129    pub fn new() -> Self {
130        Self {
131            min_df: 1,
132            max_df_frac: 1.0,
133            norm_l2: true,
134        }
135    }
136
137    pub fn fit_transform(&self, docs: &[&str]) -> Result<(Vec<String>, Array2<f64>)> {
138        let (vocab, csr) = self.fit_transform_sparse(docs)?;
139        Ok((vocab, csr.to_dense()))
140    }
141
142    /// Sparse-output TF-IDF. IDF is computed once per term, then applied
143    /// element-wise to the sparse count matrix; optional L2-normalisation
144    /// runs over each row's non-zero slice.
145    pub fn fit_transform_sparse(&self, docs: &[&str]) -> Result<(Vec<String>, CsrMatrix<f64>)> {
146        let cv = CountVectorizer {
147            min_df: self.min_df,
148            max_df_frac: self.max_df_frac,
149        };
150        let (vocab, counts) = cv.fit_transform_sparse(docs)?;
151        let n = counts.n_rows;
152        let d = counts.n_cols;
153
154        // IDF: smooth, +1.
155        let mut df_t = vec![0usize; d];
156        for i in 0..n {
157            for (c, _) in counts.row_iter(i) {
158                df_t[c] += 1;
159            }
160        }
161        let idf: Vec<f64> = df_t
162            .iter()
163            .map(|&df| ((1.0 + n as f64) / (1.0 + df as f64)).ln() + 1.0)
164            .collect();
165
166        // Apply IDF to each non-zero and optionally L2-normalise the row.
167        let mut indptr = Vec::with_capacity(n + 1);
168        let mut indices = Vec::with_capacity(counts.nnz());
169        let mut data = Vec::with_capacity(counts.nnz());
170        indptr.push(0);
171        for i in 0..n {
172            let start = counts.indptr[i];
173            let end = counts.indptr[i + 1];
174            let mut row_vals: Vec<(usize, f64)> = counts.indices[start..end]
175                .iter()
176                .zip(counts.data[start..end].iter())
177                .map(|(&c, &v)| (c, v * idf[c]))
178                .collect();
179            if self.norm_l2 {
180                let s: f64 = row_vals.iter().map(|&(_, v)| v * v).sum();
181                let norm = s.sqrt().max(1e-12);
182                for entry in row_vals.iter_mut() {
183                    entry.1 /= norm;
184                }
185            }
186            for (c, v) in row_vals {
187                indices.push(c);
188                data.push(v);
189            }
190            indptr.push(indices.len());
191        }
192        let csr = CsrMatrix {
193            indptr,
194            indices,
195            data,
196            n_rows: n,
197            n_cols: d,
198        };
199        Ok((vocab, csr))
200    }
201}
202
203impl Default for TfidfVectorizer {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
209// ---------------------------------------------------------------------------
210// HashingVectorizer (fixed n_features, signed hash for stable signs)
211// ---------------------------------------------------------------------------
212
213#[derive(Debug, Clone)]
214pub struct HashingVectorizer {
215    pub n_features: usize,
216    pub alternate_sign: bool,
217    pub norm_l2: bool,
218}
219
220impl HashingVectorizer {
221    pub fn new(n_features: usize) -> Self {
222        Self {
223            n_features,
224            alternate_sign: true,
225            norm_l2: true,
226        }
227    }
228
229    pub fn transform(&self, docs: &[&str]) -> Array2<f64> {
230        let n = docs.len();
231        let mut x = Array2::<f64>::zeros((n, self.n_features));
232        for (i, d) in docs.iter().enumerate() {
233            for t in tokenize(d) {
234                let h = fxhash(&t);
235                let col = (h as usize) % self.n_features;
236                let sign = if self.alternate_sign && (h & 1) == 0 {
237                    1.0
238                } else {
239                    -1.0
240                };
241                let sign = if self.alternate_sign { sign } else { 1.0 };
242                x[[i, col]] += sign;
243            }
244            if self.norm_l2 {
245                let mut s = 0.0;
246                for j in 0..self.n_features {
247                    s += x[[i, j]] * x[[i, j]];
248                }
249                let nrm = s.sqrt().max(1e-12);
250                for j in 0..self.n_features {
251                    x[[i, j]] /= nrm;
252                }
253            }
254        }
255        x
256    }
257}
258
/// Hashes the UTF-8 bytes of `s` with 64-bit FNV-1a.
///
/// Despite the function's name, the algorithm is FNV-1a; it uses fixed
/// constants only, so the result for a given string is stable across runs.
fn fxhash(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    s.bytes().fold(OFFSET_BASIS, |state, byte| {
        (state ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-9;

    /// Every term of a tiny corpus lands in the vocabulary and the "cat"
    /// column holds the expected per-document counts.
    #[test]
    fn test_count_vectorizer_basic() {
        let docs = ["the cat sat", "the dog sat", "cat dog"];
        let (vocab, x) = CountVectorizer::new().fit_transform(&docs).unwrap();
        for term in ["cat", "dog", "sat", "the"] {
            assert!(vocab.iter().any(|w| w == term));
        }
        let cat_col = vocab.iter().position(|w| w == "cat").unwrap();
        let expected_cat_counts = [1.0, 0.0, 1.0];
        for (row, want) in expected_cat_counts.iter().enumerate() {
            assert_eq!(x[[row, cat_col]], *want);
        }
    }

    /// With default settings every dense TF-IDF row has unit L2 norm.
    #[test]
    fn test_tfidf_vectorizer_norm() {
        let docs = ["the cat sat", "the dog sat"];
        let (_, x) = TfidfVectorizer::new().fit_transform(&docs).unwrap();
        for row in 0..2 {
            let sum_sq: f64 = (0..x.ncols()).map(|col| x[[row, col]].powi(2)).sum();
            assert!((sum_sq - 1.0).abs() < EPS);
        }
    }

    /// Dense and sparse count outputs agree cell by cell.
    #[test]
    fn test_count_vectorizer_sparse_matches_dense() {
        let docs = ["the cat sat on the mat", "the dog sat", "cat dog mat"];
        let cv = CountVectorizer::new();
        let (vocab_d, dense) = cv.fit_transform(&docs).unwrap();
        let (vocab_s, sparse) = cv.fit_transform_sparse(&docs).unwrap();
        assert_eq!(vocab_d, vocab_s);

        let densified = sparse.to_dense();
        for row in 0..dense.nrows() {
            for col in 0..dense.ncols() {
                assert_eq!(dense[[row, col]], densified[[row, col]]);
            }
        }
        // "the" in doc 0 appears twice → expect a 2 somewhere on row 0.
        let has_two = sparse.row_iter(0).any(|(_, v)| (v - 2.0).abs() < EPS);
        assert!(has_two);
    }

    /// Dense and sparse TF-IDF outputs agree, and sparse rows are unit-norm.
    #[test]
    fn test_tfidf_vectorizer_sparse_matches_dense() {
        let docs = ["the cat sat", "the dog sat", "cat dog"];
        let tv = TfidfVectorizer::new();
        let (_, dense) = tv.fit_transform(&docs).unwrap();
        let (_, sparse) = tv.fit_transform_sparse(&docs).unwrap();

        let dense_from_sparse = sparse.to_dense();
        for i in 0..dense.nrows() {
            for j in 0..dense.ncols() {
                let gap = (dense[[i, j]] - dense_from_sparse[[i, j]]).abs();
                assert!(
                    gap < EPS,
                    "mismatch at [{i},{j}]: dense {} vs sparse {}",
                    dense[[i, j]],
                    dense_from_sparse[[i, j]]
                );
            }
        }
        // Sparse L2-row norms must equal 1.
        for row in 0..sparse.n_rows {
            let sum_sq: f64 = sparse.row_iter(row).map(|(_, v)| v * v).sum();
            assert!((sum_sq - 1.0).abs() < EPS);
        }
    }

    /// Hashing needs no vocabulary, so previously unseen words still yield
    /// non-zero rows.
    #[test]
    fn test_hashing_vectorizer_no_oov() {
        let docs = ["unseenword wordone", "wordone wordtwo"];
        let x = HashingVectorizer::new(8).transform(&docs);
        // Both documents produce nonzero rows.
        for row in 0..2 {
            let abs_sum: f64 = (0..x.ncols()).map(|col| x[[row, col]].abs()).sum();
            assert!(abs_sum > 0.0);
        }
    }
}