use crate::error::{Result, TextError};
use crate::tokenize::{Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2, ArrayViewMut1, Axis};
use scirs2_core::parallel_ops;
use std::collections::{HashMap, HashSet};
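/// Common interface for vectorizers that map text documents to fixed-width
/// numeric feature vectors over a learned vocabulary.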
pub trait Vectorizer: Clone {
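    /// Learns the vocabulary (and any corpus statistics) from `texts`.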
fn fit(&mut self, texts: &[&str]) -> Result<()>;
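    /// Encodes a single document as a vector of length `vocabulary_size()`.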
fn transform(&self, text: &str) -> Result<Array1<f64>>;
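    /// Encodes each document as one row of an `(n_documents, n_features)` matrix.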
fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>>;
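    /// Convenience method: fits on `texts`, then transforms that same corpus.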
fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
self.fit(texts)?;
self.transform_batch(texts)
}
}
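/// Bag-of-words vectorizer: each feature is the occurrence count of one
/// vocabulary token. With `binary = true`, counts are clamped to 0/1.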
pub struct CountVectorizer {
tokenizer: Box<dyn Tokenizer + Send + Sync>,
vocabulary: Vocabulary,
    binary: bool,
}
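// Manual `Clone` impl: `Box<dyn Tokenizer>` cannot be derived, so we go
// through the tokenizer's `clone_box()`.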
impl Clone for CountVectorizer {
fn clone(&self) -> Self {
Self {
tokenizer: self.tokenizer.clone_box(),
vocabulary: self.vocabulary.clone(),
binary: self.binary,
}
}
}
impl CountVectorizer {
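    /// Creates a vectorizer backed by the default `WordTokenizer`.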
pub fn new(binary: bool) -> Self {
Self {
tokenizer: Box::new(WordTokenizer::default()),
vocabulary: Vocabulary::new(),
binary,
}
}
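    /// Creates a vectorizer backed by a caller-supplied tokenizer.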
pub fn with_tokenizer(tokenizer: Box<dyn Tokenizer + Send + Sync>, binary: bool) -> Self {
Self {
tokenizer,
vocabulary: Vocabulary::new(),
binary,
}
}
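    /// Returns the learned vocabulary.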
pub fn vocabulary(&self) -> &Vocabulary {
&self.vocabulary
}
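    /// Returns the number of learned features (vocabulary entries).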
pub fn vocabulary_size(&self) -> usize {
self.vocabulary.len()
}
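    /// Returns the count at `[document_index, feature_index]` in a matrix
    /// produced by `transform_batch`, or `None` if either index is out of bounds.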
pub fn get_feature_count(
&self,
matrix: &Array2<f64>,
document_index: usize,
feature_index: usize,
) -> Option<f64> {
if document_index < matrix.nrows() && feature_index < matrix.ncols() {
Some(matrix[[document_index, feature_index]])
} else {
None
}
}
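    /// Returns a copy of the token-to-index mapping.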
pub fn vocabulary_map(&self) -> HashMap<String, usize> {
self.vocabulary.token_to_index().clone()
}
}
impl Default for CountVectorizer {
fn default() -> Self {
Self::new(false)
}
}
impl Vectorizer for CountVectorizer {
fn fit(&mut self, texts: &[&str]) -> Result<()> {
if texts.is_empty() {
return Err(TextError::InvalidInput(
"No texts provided for fitting".into(),
));
}
self.vocabulary = Vocabulary::new();
for &text in texts {
let tokens = self.tokenizer.tokenize(text)?;
for token in tokens {
self.vocabulary.add_token(&token);
}
}
Ok(())
}
fn transform(&self, text: &str) -> Result<Array1<f64>> {
if self.vocabulary.is_empty() {
return Err(TextError::VocabularyError(
"Vocabulary is empty. Call fit() first".into(),
));
}
let vocab_size = self.vocabulary.len();
let mut vector = Array1::zeros(vocab_size);
let tokens = self.tokenizer.tokenize(text)?;
for token in tokens {
if let Some(idx) = self.vocabulary.get_index(&token) {
vector[idx] += 1.0;
}
}
        if self.binary {
            // Clamp counts to presence/absence.
            vector.mapv_inplace(|x| if x > 0.0 { 1.0 } else { 0.0 });
        }
Ok(vector)
}
fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
if self.vocabulary.is_empty() {
return Err(TextError::VocabularyError(
"Vocabulary is empty. Call fit() first".into(),
));
}
if texts.is_empty() {
return Ok(Array2::zeros((0, self.vocabulary.len())));
}
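        // Own the input strings and clone `self` so the closure handed to
        // `parallel_map_result` can be moved across worker threads.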
let texts_owned: Vec<String> = texts.iter().map(|&s| s.to_string()).collect();
let self_clone = self.clone();
let vectors = parallel_ops::parallel_map_result(&texts_owned, move |text| {
self_clone.transform(text).map_err(|e| {
scirs2_core::CoreError::ComputationError(scirs2_core::error::ErrorContext::new(
format!("Text vectorization error: {e}"),
))
})
})?;
let n_samples = vectors.len();
let n_features = self.vocabulary.len();
let mut matrix = Array2::zeros((n_samples, n_features));
for (i, vec) in vectors.iter().enumerate() {
matrix.row_mut(i).assign(vec);
}
Ok(matrix)
}
}
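/// TF-IDF vectorizer: term counts from a `CountVectorizer`, reweighted by
/// inverse document frequency and optionally normalized per document.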
#[derive(Clone)]
pub struct TfidfVectorizer {
count_vectorizer: CountVectorizer,
idf: Option<Array1<f64>>,
smoothidf: bool,
    norm: Option<String>,
}
impl TfidfVectorizer {
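    /// Creates a TF-IDF vectorizer. `norm` selects per-document normalization:
    /// `Some("l1")`, `Some("l2")`, or `None` for no normalization.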
pub fn new(binary: bool, smoothidf: bool, norm: Option<String>) -> Self {
Self {
count_vectorizer: CountVectorizer::new(binary),
idf: None,
smoothidf,
norm,
}
}
pub fn with_tokenizer(
tokenizer: Box<dyn Tokenizer + Send + Sync>,
binary: bool,
smoothidf: bool,
norm: Option<String>,
) -> Self {
Self {
count_vectorizer: CountVectorizer::with_tokenizer(tokenizer, binary),
idf: None,
smoothidf,
norm,
}
}
pub fn vocabulary(&self) -> &Vocabulary {
self.count_vectorizer.vocabulary()
}
pub fn vocabulary_size(&self) -> usize {
self.count_vectorizer.vocabulary_size()
}
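    /// Returns the TF-IDF score at `[document_index, feature_index]`, or
    /// `None` if either index is out of bounds.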
pub fn get_feature_score(
&self,
matrix: &Array2<f64>,
document_index: usize,
feature_index: usize,
) -> Option<f64> {
if document_index < matrix.nrows() && feature_index < matrix.ncols() {
Some(matrix[[document_index, feature_index]])
} else {
None
}
}
pub fn vocabulary_map(&self) -> HashMap<String, usize> {
self.count_vectorizer.vocabulary_map()
}
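    /// Computes per-feature inverse document frequency from document
    /// frequencies `df`. Smoothed: `idf = ln((n + 1) / (df + 1)) + 1`, which
    /// stays positive and tolerates unseen terms; unsmoothed:
    /// `idf = ln(n / df)`, with `idf = 0` when `df == 0`.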
fn compute_idf(&mut self, df: &Array1<f64>, ndocuments: f64) -> Result<()> {
let n_features = df.len();
let mut idf = Array1::zeros(n_features);
for (i, &df_i) in df.iter().enumerate() {
if df_i > 0.0 {
if self.smoothidf {
idf[i] = ((ndocuments + 1.0) / (df_i + 1.0)).ln() + 1.0;
} else {
idf[i] = (ndocuments / df_i).ln();
}
} else if self.smoothidf {
idf[i] = ((ndocuments + 1.0) / 1.0).ln() + 1.0;
} else {
idf[i] = 0.0;
}
}
self.idf = Some(idf);
Ok(())
}
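    /// Scales `vector` in place to unit L1 or L2 norm; zero vectors are left
    /// untouched. Errors on an unrecognized norm name.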
    fn normalize_vector(&self, vector: &mut ArrayViewMut1<f64>) -> Result<()> {
if let Some(ref norm) = self.norm {
match norm.as_str() {
"l1" => {
let sum = vector.sum();
if sum > 0.0 {
vector.mapv_inplace(|x| x / sum);
}
}
"l2" => {
let squared_sum: f64 = vector.iter().map(|&x| x * x).sum();
if squared_sum > 0.0 {
let norm = squared_sum.sqrt();
vector.mapv_inplace(|x| x / norm);
}
}
_ => {
return Err(TextError::InvalidInput(format!(
"Unknown normalization: {norm}"
)))
}
}
}
Ok(())
}
}
impl Default for TfidfVectorizer {
fn default() -> Self {
Self::new(false, true, Some("l2".to_string()))
}
}
impl Vectorizer for TfidfVectorizer {
fn fit(&mut self, texts: &[&str]) -> Result<()> {
if texts.is_empty() {
return Err(TextError::InvalidInput(
"No texts provided for fitting".into(),
));
}
self.count_vectorizer.fit(texts)?;
let ndocuments = texts.len() as f64;
let n_features = self.count_vectorizer.vocabulary_size();
let mut df = Array1::zeros(n_features);
for &text in texts {
let tokens = self.count_vectorizer.tokenizer.tokenize(text)?;
            let mut seen_tokens = HashSet::new();
            for token in tokens {
                if let Some(idx) = self.count_vectorizer.vocabulary.get_index(&token) {
                    seen_tokens.insert(idx);
                }
            }
            for idx in seen_tokens {
                df[idx] += 1.0;
            }
}
self.compute_idf(&df, ndocuments)?;
Ok(())
}
fn transform(&self, text: &str) -> Result<Array1<f64>> {
        // Fail fast if fit() has not yet produced IDF weights.
        let Some(idf) = self.idf.as_ref() else {
            return Err(TextError::VocabularyError(
                "IDF values not computed. Call fit() first".into(),
            ));
        };
        let mut count_vector = self.count_vectorizer.transform(text)?;
        // Element-wise TF x IDF reweighting.
        count_vector *= idf;
        self.normalize_vector(&mut count_vector.view_mut())?;
Ok(count_vector)
}
fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
        // Fail fast if fit() has not yet produced IDF weights.
        let Some(idf) = self.idf.as_ref() else {
            return Err(TextError::VocabularyError(
                "IDF values not computed. Call fit() first".into(),
            ));
        };
        if texts.is_empty() {
            return Ok(Array2::zeros((0, self.count_vectorizer.vocabulary_size())));
        }
        let mut count_matrix = self.count_vectorizer.transform_batch(texts)?;
        for mut row in count_matrix.axis_iter_mut(Axis(0)) {
            // Element-wise TF x IDF reweighting, then the per-document norm.
            row *= idf;
            self.normalize_vector(&mut row)?;
        }
Ok(count_matrix)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_count_vectorizer() {
let mut vectorizer = CountVectorizer::default();
let corpus = [
"This is the first document.",
"This document is the second document.",
];
        vectorizer.fit(&corpus).expect("fit failed");
assert_eq!(vectorizer.vocabulary_size(), 6);
        let vec = vectorizer.transform(corpus[0]).expect("transform failed");
assert_eq!(vec.len(), 6);
let vec_sum: f64 = vec.iter().sum();
        assert_eq!(vec_sum, 5.0);
    }
#[test]
fn test_tfidf_vectorizer() {
let mut vectorizer = TfidfVectorizer::default();
let corpus = [
"This is the first document.",
"This document is the second document.",
];
        vectorizer.fit(&corpus).expect("fit failed");
        let vec = vectorizer.transform(corpus[0]).expect("transform failed");
assert_eq!(vec.len(), 6);
let norm: f64 = vec.iter().map(|&x| x * x).sum::<f64>().sqrt();
assert!((norm - 1.0).abs() < 1e-10);
}
#[test]
fn test_binary_vectorizer() {
let mut vectorizer = CountVectorizer::new(true);
let corpus = ["this this this is a document", "this is another document"];
        let matrix = vectorizer.fit_transform(&corpus).expect("fit_transform failed");
for val in matrix.row(0).iter() {
assert!(*val == 0.0 || *val == 1.0);
}
}
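    #[test]
    fn test_transform_before_fit_errors() {
        // Calling transform() before fit() should surface an error rather
        // than panic, for both vectorizers.
        let count = CountVectorizer::default();
        assert!(count.transform("some text").is_err());
        let tfidf = TfidfVectorizer::default();
        assert!(tfidf.transform("some text").is_err());
    }
    #[test]
    fn test_l1_normalization() {
        // Sketch of the "l1" branch: counts are non-negative and smoothed IDF
        // is >= 1, so a non-zero L1-normalized vector should sum to 1.
        let mut vectorizer = TfidfVectorizer::new(false, true, Some("l1".to_string()));
        let corpus = ["one two three", "two three four"];
        vectorizer.fit(&corpus).expect("fit failed");
        let vec = vectorizer.transform(corpus[0]).expect("transform failed");
        let sum: f64 = vec.iter().sum();
        assert!((sum - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_transform_batch_matches_transform() {
        // The batch matrix should be (n_documents, vocabulary_size), with each
        // row equal to the corresponding single-document transform.
        let mut vectorizer = CountVectorizer::default();
        let corpus = ["one two three", "two three four"];
        let matrix = vectorizer.fit_transform(&corpus).expect("fit_transform failed");
        assert_eq!(matrix.nrows(), 2);
        assert_eq!(matrix.ncols(), vectorizer.vocabulary_size());
        let row0 = vectorizer.transform(corpus[0]).expect("transform failed");
        assert_eq!(matrix.row(0).to_owned(), row0);
    }
    #[test]
    fn test_unknown_norm_errors() {
        // An unrecognized norm name is only inspected at transform time, so
        // fit() succeeds but transform() should return an InvalidInput error.
        let mut vectorizer = TfidfVectorizer::new(false, true, Some("l3".to_string()));
        let corpus = ["one two", "two three"];
        vectorizer.fit(&corpus).expect("fit failed");
        assert!(vectorizer.transform(corpus[0]).is_err());
    }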
}