1use anofox_ml_core::{CsrMatrix, Result, RustMlError};
8use ndarray::Array2;
9use std::collections::HashMap;
10
/// Splits `s` into lowercase ASCII word tokens.
///
/// Every character that is not an ASCII letter (whitespace, digits,
/// punctuation, non-ASCII characters) acts as a separator. Runs of letters
/// shorter than two characters are discarded.
fn tokenize(s: &str) -> Vec<String> {
    s.split(|ch: char| !ch.is_ascii_alphabetic())
        // Empty pieces (adjacent separators) and single letters are dropped.
        .filter(|word| word.len() >= 2)
        .map(|word| word.to_ascii_lowercase())
        .collect()
}
29
/// Bag-of-words vectorizer: turns documents into per-term occurrence counts.
#[derive(Clone, Debug)]
pub struct CountVectorizer {
    /// Minimum number of documents a term must occur in to be kept in the
    /// vocabulary.
    pub min_df: usize,
    /// Upper document-frequency limit, expressed as a fraction of the number
    /// of documents; the resulting limit is never lower than `min_df`.
    pub max_df_frac: f64,
}
39
40impl CountVectorizer {
41 pub fn new() -> Self {
42 Self {
43 min_df: 1,
44 max_df_frac: 1.0,
45 }
46 }
47 pub fn with_min_df(mut self, m: usize) -> Self {
48 self.min_df = m;
49 self
50 }
51 pub fn with_max_df_frac(mut self, f: f64) -> Self {
52 self.max_df_frac = f;
53 self
54 }
55
56 pub fn fit_transform(&self, docs: &[&str]) -> Result<(Vec<String>, Array2<f64>)> {
57 let (vocab, csr) = self.fit_transform_sparse(docs)?;
58 Ok((vocab, csr.to_dense()))
59 }
60
61 pub fn fit_transform_sparse(&self, docs: &[&str]) -> Result<(Vec<String>, CsrMatrix<f64>)> {
65 if docs.is_empty() {
66 return Err(RustMlError::EmptyInput("no documents".into()));
67 }
68 let mut df: HashMap<String, usize> = HashMap::new();
70 let tokenised: Vec<Vec<String>> = docs.iter().map(|d| tokenize(d)).collect();
71 for tokens in &tokenised {
72 let mut seen = std::collections::HashSet::new();
73 for t in tokens {
74 if seen.insert(t.clone()) {
75 *df.entry(t.clone()).or_default() += 1;
76 }
77 }
78 }
79 let n = docs.len();
80 let max_df = (self.max_df_frac * n as f64).floor() as usize;
81 let mut vocab: Vec<String> = df
82 .iter()
83 .filter(|(_, &c)| c >= self.min_df && c <= max_df.max(self.min_df))
84 .map(|(k, _)| k.clone())
85 .collect();
86 vocab.sort();
87 let term_to_col: HashMap<String, usize> = vocab
88 .iter()
89 .enumerate()
90 .map(|(i, w)| (w.clone(), i))
91 .collect();
92
93 let mut triplets: Vec<(usize, usize, f64)> = Vec::new();
95 for (i, tokens) in tokenised.iter().enumerate() {
96 let mut row_counts: HashMap<usize, f64> = HashMap::new();
97 for t in tokens {
98 if let Some(&c) = term_to_col.get(t) {
99 *row_counts.entry(c).or_default() += 1.0;
100 }
101 }
102 for (c, v) in row_counts {
103 triplets.push((i, c, v));
104 }
105 }
106 let csr = CsrMatrix::from_triplets(n, vocab.len(), triplets);
107 Ok((vocab, csr))
108 }
109}
110
111impl Default for CountVectorizer {
112 fn default() -> Self {
113 Self::new()
114 }
115}
116
/// TF-IDF vectorizer: term counts re-weighted by inverse document frequency.
#[derive(Clone, Debug)]
pub struct TfidfVectorizer {
    /// Minimum number of documents a term must occur in; forwarded to the
    /// underlying `CountVectorizer`.
    pub min_df: usize,
    /// Upper document-frequency limit as a fraction of the number of
    /// documents; forwarded to the underlying `CountVectorizer`.
    pub max_df_frac: f64,
    /// When `true`, every document row is scaled to unit L2 norm.
    pub norm_l2: bool,
}
127
128impl TfidfVectorizer {
129 pub fn new() -> Self {
130 Self {
131 min_df: 1,
132 max_df_frac: 1.0,
133 norm_l2: true,
134 }
135 }
136
137 pub fn fit_transform(&self, docs: &[&str]) -> Result<(Vec<String>, Array2<f64>)> {
138 let (vocab, csr) = self.fit_transform_sparse(docs)?;
139 Ok((vocab, csr.to_dense()))
140 }
141
142 pub fn fit_transform_sparse(&self, docs: &[&str]) -> Result<(Vec<String>, CsrMatrix<f64>)> {
146 let cv = CountVectorizer {
147 min_df: self.min_df,
148 max_df_frac: self.max_df_frac,
149 };
150 let (vocab, counts) = cv.fit_transform_sparse(docs)?;
151 let n = counts.n_rows;
152 let d = counts.n_cols;
153
154 let mut df_t = vec![0usize; d];
156 for i in 0..n {
157 for (c, _) in counts.row_iter(i) {
158 df_t[c] += 1;
159 }
160 }
161 let idf: Vec<f64> = df_t
162 .iter()
163 .map(|&df| ((1.0 + n as f64) / (1.0 + df as f64)).ln() + 1.0)
164 .collect();
165
166 let mut indptr = Vec::with_capacity(n + 1);
168 let mut indices = Vec::with_capacity(counts.nnz());
169 let mut data = Vec::with_capacity(counts.nnz());
170 indptr.push(0);
171 for i in 0..n {
172 let start = counts.indptr[i];
173 let end = counts.indptr[i + 1];
174 let mut row_vals: Vec<(usize, f64)> = counts.indices[start..end]
175 .iter()
176 .zip(counts.data[start..end].iter())
177 .map(|(&c, &v)| (c, v * idf[c]))
178 .collect();
179 if self.norm_l2 {
180 let s: f64 = row_vals.iter().map(|&(_, v)| v * v).sum();
181 let norm = s.sqrt().max(1e-12);
182 for entry in row_vals.iter_mut() {
183 entry.1 /= norm;
184 }
185 }
186 for (c, v) in row_vals {
187 indices.push(c);
188 data.push(v);
189 }
190 indptr.push(indices.len());
191 }
192 let csr = CsrMatrix {
193 indptr,
194 indices,
195 data,
196 n_rows: n,
197 n_cols: d,
198 };
199 Ok((vocab, csr))
200 }
201}
202
203impl Default for TfidfVectorizer {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
/// Stateless vectorizer that maps tokens to columns by hashing them.
#[derive(Clone, Debug)]
pub struct HashingVectorizer {
    /// Number of output columns (hash buckets).
    pub n_features: usize,
    /// When `true`, a token contributes `+1` or `-1` depending on the lowest
    /// bit of its hash; when `false`, every token contributes `+1`.
    pub alternate_sign: bool,
    /// When `true`, every document row is scaled to unit L2 norm.
    pub norm_l2: bool,
}
219
220impl HashingVectorizer {
221 pub fn new(n_features: usize) -> Self {
222 Self {
223 n_features,
224 alternate_sign: true,
225 norm_l2: true,
226 }
227 }
228
229 pub fn transform(&self, docs: &[&str]) -> Array2<f64> {
230 let n = docs.len();
231 let mut x = Array2::<f64>::zeros((n, self.n_features));
232 for (i, d) in docs.iter().enumerate() {
233 for t in tokenize(d) {
234 let h = fxhash(&t);
235 let col = (h as usize) % self.n_features;
236 let sign = if self.alternate_sign && (h & 1) == 0 {
237 1.0
238 } else {
239 -1.0
240 };
241 let sign = if self.alternate_sign { sign } else { 1.0 };
242 x[[i, col]] += sign;
243 }
244 if self.norm_l2 {
245 let mut s = 0.0;
246 for j in 0..self.n_features {
247 s += x[[i, j]] * x[[i, j]];
248 }
249 let nrm = s.sqrt().max(1e-12);
250 for j in 0..self.n_features {
251 x[[i, j]] /= nrm;
252 }
253 }
254 }
255 x
256 }
257}
258
/// Hashes the UTF-8 bytes of `s` to a 64-bit value.
///
/// Despite the name, the constants and the xor-then-multiply step are those of
/// the 64-bit FNV-1a hash.
fn fxhash(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0100_0000_01b3;
    s.bytes().fold(OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Every token enters the vocabulary and the "cat" column holds the
    /// per-document occurrence counts.
    #[test]
    fn test_count_vectorizer_basic() {
        let corpus = ["the cat sat", "the dog sat", "cat dog"];
        let (vocab, counts) = CountVectorizer::new().fit_transform(&corpus).unwrap();
        for term in ["cat", "dog", "sat", "the"] {
            assert!(vocab.contains(&term.to_string()));
        }
        let cat_col = vocab.iter().position(|w| w == "cat").unwrap();
        let expected_cat_counts = [1.0, 0.0, 1.0];
        for (row, expected) in expected_cat_counts.iter().enumerate() {
            assert_eq!(counts[[row, cat_col]], *expected);
        }
    }

    /// With default settings every TF-IDF row has unit L2 norm.
    #[test]
    fn test_tfidf_vectorizer_norm() {
        let corpus = ["the cat sat", "the dog sat"];
        let (_, weights) = TfidfVectorizer::new().fit_transform(&corpus).unwrap();
        for row in 0..2 {
            let squared_norm: f64 = (0..weights.ncols())
                .map(|col| weights[[row, col]].powi(2))
                .sum();
            assert!((squared_norm - 1.0).abs() < 1e-9);
        }
    }

    /// Dense and sparse count paths agree on vocabulary and on every cell.
    #[test]
    fn test_count_vectorizer_sparse_matches_dense() {
        let corpus = ["the cat sat on the mat", "the dog sat", "cat dog mat"];
        let vectorizer = CountVectorizer::new();
        let (dense_vocab, dense) = vectorizer.fit_transform(&corpus).unwrap();
        let (sparse_vocab, sparse) = vectorizer.fit_transform_sparse(&corpus).unwrap();
        assert_eq!(dense_vocab, sparse_vocab);
        let round_trip = sparse.to_dense();
        for row in 0..dense.nrows() {
            for col in 0..dense.ncols() {
                assert_eq!(dense[[row, col]], round_trip[[row, col]]);
            }
        }
        // "the" occurs twice in the first document, so some entry must be 2.
        let has_count_of_two = sparse.row_iter(0).any(|(_, v)| (v - 2.0).abs() < 1e-9);
        assert!(has_count_of_two);
    }

    /// Dense and sparse TF-IDF paths agree, and sparse rows are unit-norm.
    #[test]
    fn test_tfidf_vectorizer_sparse_matches_dense() {
        let corpus = ["the cat sat", "the dog sat", "cat dog"];
        let vectorizer = TfidfVectorizer::new();
        let (_, dense) = vectorizer.fit_transform(&corpus).unwrap();
        let (_, sparse) = vectorizer.fit_transform_sparse(&corpus).unwrap();
        let dense_from_sparse = sparse.to_dense();
        for i in 0..dense.nrows() {
            for j in 0..dense.ncols() {
                let gap = (dense[[i, j]] - dense_from_sparse[[i, j]]).abs();
                assert!(
                    gap < 1e-9,
                    "mismatch at [{i},{j}]: dense {} vs sparse {}",
                    dense[[i, j]],
                    dense_from_sparse[[i, j]]
                );
            }
        }
        for row in 0..sparse.n_rows {
            let squared_norm: f64 = sparse.row_iter(row).map(|(_, v)| v * v).sum();
            assert!((squared_norm - 1.0).abs() < 1e-9);
        }
    }

    /// Hashing needs no vocabulary: every document, including one containing
    /// a word found nowhere else, gets a non-zero row.
    #[test]
    fn test_hashing_vectorizer_no_oov() {
        let corpus = ["unseenword wordone", "wordone wordtwo"];
        let features = HashingVectorizer::new(8).transform(&corpus);
        for row in 0..2 {
            let abs_total: f64 = (0..features.ncols())
                .map(|col| features[[row, col]].abs())
                .sum();
            assert!(abs_total > 0.0);
        }
    }
}