//! Text vectorization: bag-of-words counts and TF-IDF weighting built on
//! `scirs2_core` ndarray types.

use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::parallel_ops;
use std::collections::{HashMap, HashSet};

/// Common interface for converting text documents into numeric feature
/// vectors.
pub trait Vectorizer: Clone {
    /// Learns the vocabulary (and any corpus statistics) from `texts`.
    fn fit(&mut self, texts: &[&str]) -> Result<()>;

    /// Transforms a single text into a feature vector.
    fn transform(&self, text: &str) -> Result<Array1<f64>>;

    /// Transforms a batch of texts into a matrix with one row per text.
    fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>>;

    /// Fits on `texts`, then transforms them in one step.
    fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
        self.fit(texts)?;
        self.transform_batch(texts)
    }
}

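/// Bag-of-words vectorizer: maps each document to a vector of token
/// counts over the vocabulary learned during `fit`. When `binary` is
/// true, counts are clipped to 0/1 presence indicators.
///
/// A minimal usage sketch; the `scirs2_text::vectorize` import path is an
/// assumption, not confirmed by this file:
///
/// ```ignore
/// use scirs2_text::vectorize::{CountVectorizer, Vectorizer};
///
/// let mut vectorizer = CountVectorizer::default();
/// vectorizer.fit(&["first document", "second document"]).unwrap();
/// let vec = vectorizer.transform("first document").unwrap();
/// assert_eq!(vec.len(), vectorizer.vocabulary_size());
/// ```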
pub struct CountVectorizer {
    tokenizer: Box<dyn Tokenizer + Send + Sync>,
    vocabulary: Vocabulary,
    binary: bool,
}

// `Box<dyn Tokenizer + Send + Sync>` cannot derive `Clone`, so cloning is
// routed through the tokenizer's `clone_box` method.
impl Clone for CountVectorizer {
    fn clone(&self) -> Self {
        Self {
            tokenizer: self.tokenizer.clone_box(),
            vocabulary: self.vocabulary.clone(),
            binary: self.binary,
        }
    }
}

impl CountVectorizer {
    /// Creates a count vectorizer with the default `WordTokenizer`.
    pub fn new(binary: bool) -> Self {
        Self {
            tokenizer: Box::new(WordTokenizer::default()),
            vocabulary: Vocabulary::new(),
            binary,
        }
    }

    /// Creates a count vectorizer with a custom tokenizer.
    pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>, binary: bool) -> Self {
        Self {
            tokenizer,
            vocabulary: Vocabulary::new(),
            binary,
        }
    }

    /// Returns the fitted vocabulary.
    pub fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    /// Returns the number of distinct tokens in the vocabulary.
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }

    /// Returns the count at `[document_index, feature_index]`, or `None`
    /// if either index is out of bounds.
    pub fn get_feature_count(
        &self,
        matrix: &Array2<f64>,
        document_index: usize,
        feature_index: usize,
    ) -> Option<f64> {
        if document_index < matrix.nrows() && feature_index < matrix.ncols() {
            Some(matrix[[document_index, feature_index]])
        } else {
            None
        }
    }

    /// Returns a copy of the token-to-index mapping.
    pub fn vocabulary_map(&self) -> HashMap<String, usize> {
        self.vocabulary.token_to_index().clone()
    }
}

impl Default for CountVectorizer {
    fn default() -> Self {
        Self::new(false)
    }
}

impl Vectorizer for CountVectorizer {
    fn fit(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for fitting".into(),
            ));
        }

        // Rebuild the vocabulary from scratch so repeated fits don't
        // accumulate tokens from earlier corpora.
        self.vocabulary = Vocabulary::new();

        for &text in texts {
            let tokens = self.tokenizer.tokenize(text)?;
            for token in tokens {
                self.vocabulary.add_token(&token);
            }
        }

        Ok(())
    }

    fn transform(&self, text: &str) -> Result<Array1<f64>> {
        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "Vocabulary is empty. Call fit() first".into(),
            ));
        }

        let vocab_size = self.vocabulary.len();
        let mut vector = Array1::zeros(vocab_size);

        let tokens = self.tokenizer.tokenize(text)?;

        // Count occurrences of each in-vocabulary token; out-of-vocabulary
        // tokens are silently ignored.
        for token in tokens {
            if let Some(idx) = self.vocabulary.get_index(&token) {
                vector[idx] += 1.0;
            }
        }

        // In binary mode, clip counts to 0/1 presence indicators.
        if self.binary {
            for val in vector.iter_mut() {
                if *val > 0.0 {
                    *val = 1.0;
                }
            }
        }

        Ok(vector)
    }

    fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
        if self.vocabulary.is_empty() {
            return Err(TextError::VocabularyError(
                "Vocabulary is empty. Call fit() first".into(),
            ));
        }

        if texts.is_empty() {
            return Ok(Array2::zeros((0, self.vocabulary.len())));
        }

        // Own the inputs and clone `self` so the closure handed to the
        // parallel map is `Send` and independent of borrowed data.
        let texts_owned: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
        let self_clone = self.clone();

        let vectors = parallel_ops::parallel_map_result(&texts_owned, move |text| {
            self_clone.transform(text).map_err(|e| {
                scirs2_core::CoreError::ComputationError(scirs2_core::error::ErrorContext::new(
                    format!("Text vectorization error: {e}"),
                ))
            })
        })?;

        // Stack the per-document vectors into a dense matrix.
        let n_samples = vectors.len();
        let n_features = self.vocabulary.len();

        let mut matrix = Array2::zeros((n_samples, n_features));
        for (i, vec) in vectors.iter().enumerate() {
            matrix.row_mut(i).assign(vec);
        }

        Ok(matrix)
    }
}

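/// Term frequency-inverse document frequency (TF-IDF) vectorizer layered
/// on top of `CountVectorizer`: term counts are scaled by per-term IDF
/// weights learned during `fit`, then optionally normalized.
///
/// A minimal usage sketch; the `scirs2_text::vectorize` import path is an
/// assumption, not confirmed by this file:
///
/// ```ignore
/// use scirs2_text::vectorize::{TfidfVectorizer, Vectorizer};
///
/// // Defaults: raw counts, smoothed IDF, L2 normalization.
/// let mut vectorizer = TfidfVectorizer::default();
/// let matrix = vectorizer
///     .fit_transform(&["first document", "second document"])
///     .unwrap();
/// assert_eq!(matrix.nrows(), 2);
/// ```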
#[derive(Clone)]
pub struct TfidfVectorizer {
    count_vectorizer: CountVectorizer,
    idf: Option<Array1<f64>>,
    smoothidf: bool,
    norm: Option<String>,
}
203
204impl TfidfVectorizer {
205 pub fn new(binary: bool, smoothidf: bool, norm: Option<String>) -> Self {
207 Self {
208 count_vectorizer: CountVectorizer::new(binary),
209 idf: None,
210 smoothidf,
211 norm,
212 }
213 }
214
215 pub fn with_tokenizer(
217 tokenizer: Box<dyn Tokenizer + Send + Sync>,
218 binary: bool,
219 smoothidf: bool,
220 norm: Option<String>,
221 ) -> Self {
222 Self {
223 count_vectorizer: CountVectorizer::with_tokenizer(tokenizer, binary),
224 idf: None,
225 smoothidf,
226 norm,
227 }
228 }
229
230 pub fn vocabulary(&self) -> &Vocabulary {
232 self.count_vectorizer.vocabulary()
233 }
234
235 pub fn vocabulary_size(&self) -> usize {
237 self.count_vectorizer.vocabulary_size()
238 }
239
240 pub fn get_feature_score(
242 &self,
243 matrix: &Array2<f64>,
244 document_index: usize,
245 feature_index: usize,
246 ) -> Option<f64> {
247 if document_index < matrix.nrows() && feature_index < matrix.ncols() {
248 Some(matrix[[document_index, feature_index]])
249 } else {
250 None
251 }
252 }
253
254 pub fn vocabulary_map(&self) -> HashMap<String, usize> {
256 self.count_vectorizer.vocabulary_map()
257 }
258
    /// Computes per-feature IDF weights from document frequencies.
    ///
    /// With smoothing: `idf = ln((n + 1) / (df + 1)) + 1`, which acts as if
    /// one extra document contained every term and keeps every weight
    /// strictly positive. Without smoothing: `idf = ln(n / df)`, with
    /// zero-frequency terms assigned 0. For example, with `ndocuments = 2`
    /// and `df = 1`, the smoothed IDF is `ln(3/2) + 1 ≈ 1.405`.
    fn compute_idf(&mut self, df: &Array1<f64>, ndocuments: f64) -> Result<()> {
        let n_features = df.len();

        let mut idf = Array1::zeros(n_features);

        for (i, &df_i) in df.iter().enumerate() {
            if df_i > 0.0 {
                if self.smoothidf {
                    idf[i] = ((ndocuments + 1.0) / (df_i + 1.0)).ln() + 1.0;
                } else {
                    idf[i] = (ndocuments / df_i).ln();
                }
            } else if self.smoothidf {
                // A fitted vocabulary term always has df > 0, but the
                // smoothed formula still yields a finite value here.
                idf[i] = ((ndocuments + 1.0) / 1.0).ln() + 1.0;
            } else {
                idf[i] = 0.0;
            }
        }

        self.idf = Some(idf);
        Ok(())
    }

    /// Normalizes a vector in place according to `self.norm` (`"l1"` or
    /// `"l2"`); zero vectors are left untouched.
    fn normalize_vector(&self, vector: &mut Array1<f64>) -> Result<()> {
        if let Some(ref norm) = self.norm {
            match norm.as_str() {
                "l1" => {
                    let sum = vector.sum();
                    if sum > 0.0 {
                        vector.mapv_inplace(|x| x / sum);
                    }
                }
                "l2" => {
                    let squared_sum: f64 = vector.iter().map(|&x| x * x).sum();
                    if squared_sum > 0.0 {
                        let norm = squared_sum.sqrt();
                        vector.mapv_inplace(|x| x / norm);
                    }
                }
                _ => {
                    return Err(TextError::InvalidInput(format!(
                        "Unknown normalization: {norm}"
                    )))
                }
            }
        }

        Ok(())
    }
}

impl Default for TfidfVectorizer {
    fn default() -> Self {
        // Raw counts, smoothed IDF, L2 normalization.
        Self::new(false, true, Some("l2".to_string()))
    }
}

impl Vectorizer for TfidfVectorizer {
    fn fit(&mut self, texts: &[&str]) -> Result<()> {
        if texts.is_empty() {
            return Err(TextError::InvalidInput(
                "No texts provided for fitting".into(),
            ));
        }

        self.count_vectorizer.fit(texts)?;

        let ndocuments = texts.len() as f64;
        let n_features = self.count_vectorizer.vocabulary_size();

        // Document frequency: the number of documents containing each term.
        let mut df = Array1::zeros(n_features);

        for &text in texts {
            let tokens = self.count_vectorizer.tokenizer.tokenize(text)?;

            // Count each term at most once per document.
            let mut seen_tokens = HashSet::new();
            for token in tokens {
                if let Some(idx) = self.count_vectorizer.vocabulary.get_index(&token) {
                    seen_tokens.insert(idx);
                }
            }

            for idx in seen_tokens {
                df[idx] += 1.0;
            }
        }

        self.compute_idf(&df, ndocuments)?;

        Ok(())
    }

    fn transform(&self, text: &str) -> Result<Array1<f64>> {
        if self.idf.is_none() {
            return Err(TextError::VocabularyError(
                "IDF values not computed. Call fit() first".into(),
            ));
        }

        // TF-IDF: scale each term count by its IDF weight.
        let mut count_vector = self.count_vectorizer.transform(text)?;

        let idf = self.idf.as_ref().unwrap();
        for i in 0..count_vector.len() {
            count_vector[i] *= idf[i];
        }

        self.normalize_vector(&mut count_vector)?;

        Ok(count_vector)
    }

    fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
        if self.idf.is_none() {
            return Err(TextError::VocabularyError(
                "IDF values not computed. Call fit() first".into(),
            ));
        }

        if texts.is_empty() {
            return Ok(Array2::zeros((0, self.count_vectorizer.vocabulary_size())));
        }

        let mut count_matrix = self.count_vectorizer.transform_batch(texts)?;

        // Scale each row by the IDF weights, then normalize in place. This
        // mirrors `normalize_vector`, re-implemented here because the rows
        // are mutable views rather than owned arrays.
        let idf = self.idf.as_ref().unwrap();
        for mut row in count_matrix.axis_iter_mut(Axis(0)) {
            for i in 0..row.len() {
                row[i] *= idf[i];
            }

            if let Some(ref norm) = self.norm {
                match norm.as_str() {
                    "l1" => {
                        let sum = row.sum();
                        if sum > 0.0 {
                            row.mapv_inplace(|x| x / sum);
                        }
                    }
                    "l2" => {
                        let squared_sum: f64 = row.iter().map(|&x| x * x).sum();
                        if squared_sum > 0.0 {
                            let norm = squared_sum.sqrt();
                            row.mapv_inplace(|x| x / norm);
                        }
                    }
                    _ => {
                        return Err(TextError::InvalidInput(format!(
                            "Unknown normalization: {norm}"
                        )))
                    }
                }
            }
        }

        Ok(count_matrix)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_count_vectorizer() {
        let mut vectorizer = CountVectorizer::default();
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];

        vectorizer.fit(&corpus).unwrap();

        // Six distinct tokens appear across the two documents.
        assert_eq!(vectorizer.vocabulary_size(), 6);

        let vec = vectorizer.transform(corpus[0]).unwrap();
        assert_eq!(vec.len(), 6);

        // The first document contains five tokens, each counted once.
        let vec_sum: f64 = vec.iter().sum();
        assert_eq!(vec_sum, 5.0);
    }

    #[test]
    fn test_tfidf_vectorizer() {
        let mut vectorizer = TfidfVectorizer::default();
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];

        vectorizer.fit(&corpus).unwrap();

        let vec = vectorizer.transform(corpus[0]).unwrap();
        assert_eq!(vec.len(), 6);

        // The default L2 normalization should yield a unit-length vector.
        let norm: f64 = vec.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_binary_vectorizer() {
        let mut vectorizer = CountVectorizer::new(true);
        let corpus = ["this this this is a document", "this is another document"];

        let matrix = vectorizer.fit_transform(&corpus).unwrap();

        // In binary mode every entry is a 0/1 presence indicator.
        for val in matrix.row(0).iter() {
            assert!(*val == 0.0 || *val == 1.0);
        }
    }
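
    // Two extra checks, added here as hedged examples rather than taken
    // from the original suite: L1 normalization should make a nonzero
    // vector sum to 1, and an unrecognized norm name should surface as an
    // error at transform time.
    #[test]
    fn test_tfidf_l1_norm() {
        let mut vectorizer = TfidfVectorizer::new(false, true, Some("l1".to_string()));
        let corpus = [
            "This is the first document.",
            "This document is the second document.",
        ];

        vectorizer.fit(&corpus).unwrap();

        let vec = vectorizer.transform(corpus[0]).unwrap();
        let sum: f64 = vec.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_unknown_norm_is_error() {
        let mut vectorizer = TfidfVectorizer::new(false, true, Some("l3".to_string()));
        let corpus = ["this is a document", "this is another document"];

        vectorizer.fit(&corpus).unwrap();
        assert!(vectorizer.transform(corpus[0]).is_err());
    }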
}