impl CountVectorizer {
    /// Creates a vectorizer with the default configuration: lowercasing on,
    /// unigrams only, no stop-word filtering, no accent stripping, and an
    /// unbounded vocabulary. A tokenizer must still be supplied via
    /// [`with_tokenizer`](Self::with_tokenizer) before fitting.
    #[must_use]
    pub fn new() -> Self {
        Self {
            tokenizer: None,
            vocabulary: HashMap::new(),
            lowercase: true,
            max_features: None,
            ngram_range: (1, 1),
            min_df: 1,
            max_df: 1.0,
            stop_words: None,
            strip_accents: false,
        }
    }

    /// Enables filtering with the built-in English stop-word list.
    #[must_use]
    pub fn with_stop_words_english(mut self) -> Self {
        self.stop_words = Some(StopWordsFilter::english());
        self
    }

    /// Enables filtering with a custom stop-word list.
    #[must_use]
    pub fn with_stop_words(mut self, words: &[&str]) -> Self {
        self.stop_words = Some(StopWordsFilter::new(words));
        self
    }

    /// Enables or disables Unicode accent stripping (e.g. "café" -> "cafe").
    #[must_use]
    pub fn with_strip_accents(mut self, enable: bool) -> Self {
        self.strip_accents = enable;
        self
    }

    /// Sets the inclusive n-gram range; e.g. `(1, 2)` emits unigrams and
    /// bigrams, the latter joined with `_` ("red_apple"). Both bounds are
    /// clamped to at least 1, and `max_n` is raised to `min_n` if the range
    /// is inverted, so the range can never be empty.
    #[must_use]
    pub fn with_ngram_range(mut self, min_n: usize, max_n: usize) -> Self {
        let min_n = min_n.max(1);
        self.ngram_range = (min_n, max_n.max(min_n));
        self
    }

    /// Sets the minimum document frequency: terms appearing in fewer than
    /// `min_df` documents (an absolute count) are dropped from the vocabulary.
    #[must_use]
    pub fn with_min_df(mut self, min_df: usize) -> Self {
        self.min_df = min_df;
        self
    }

    /// Sets the maximum document frequency: terms appearing in more than this
    /// fraction of documents (clamped to `[0.0, 1.0]`) are dropped.
    #[must_use]
    pub fn with_max_df(mut self, max_df: f32) -> Self {
        self.max_df = max_df.clamp(0.0, 1.0);
        self
    }

    /// Sets the tokenizer used to split documents into tokens. Required
    /// before calling [`fit`](Self::fit) or [`transform`](Self::transform).
    #[must_use]
    pub fn with_tokenizer(mut self, tokenizer: Box<dyn Tokenizer>) -> Self {
        self.tokenizer = Some(tokenizer);
        self
    }

    /// Enables or disables lowercasing prior to counting (on by default).
    #[must_use]
    pub fn with_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    /// Caps the vocabulary at the `max_features` most frequent terms.
    #[must_use]
    pub fn with_max_features(mut self, max_features: usize) -> Self {
        self.max_features = Some(max_features);
        self
    }
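
    // Usage sketch: a typical builder chain. `WhitespaceTokenizer` below is a
    // hypothetical stand-in for whatever `Tokenizer` implementation the crate
    // provides; everything else is the API defined above.
    //
    //     let mut cv = CountVectorizer::new()
    //         .with_tokenizer(Box::new(WhitespaceTokenizer::new()))
    //         .with_stop_words_english()
    //         .with_ngram_range(1, 2)
    //         .with_min_df(2)
    //         .with_max_df(0.9)
    //         .with_max_features(10_000);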

    /// Fits the vocabulary on `documents` and returns their count matrix.
    /// Equivalent to calling [`fit`](Self::fit) followed by
    /// [`transform`](Self::transform) on the same documents.
    ///
    /// # Errors
    ///
    /// Returns an error if `documents` is empty or no tokenizer is set.
    pub fn fit_transform<S: AsRef<str>>(
        &mut self,
        documents: &[S],
    ) -> Result<Matrix<f64>, AprenderError> {
        self.fit(documents)?;
        self.transform(documents)
    }

    /// Tokenizes `text` and applies the configured normalization pipeline:
    /// lowercasing, then accent stripping, then stop-word removal.
    fn preprocess(&self, text: &str) -> Result<Vec<String>, AprenderError> {
        let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
            AprenderError::Other("Tokenizer not set. Use with_tokenizer()".to_string())
        })?;
        let tokens = tokenizer.tokenize(text)?;
        Ok(tokens
            .into_iter()
            .map(|t| {
                let mut t = if self.lowercase { t.to_lowercase() } else { t };
                if self.strip_accents {
                    t = strip_accents_unicode(&t);
                }
                t
            })
            .filter(|t| {
                self.stop_words
                    .as_ref()
                    .map_or(true, |sw| !sw.is_stop_word(t))
            })
            .collect())
    }

    /// Learns the vocabulary from `documents`: counts term and document
    /// frequencies, applies the `min_df`/`max_df` filters, and keeps the most
    /// frequent terms (up to `max_features`), assigning each a column index.
    ///
    /// # Errors
    ///
    /// Returns an error if `documents` is empty or no tokenizer is set.
    pub fn fit<S: AsRef<str>>(&mut self, documents: &[S]) -> Result<(), AprenderError> {
        if documents.is_empty() {
            return Err(AprenderError::Other(
                "Cannot fit on empty documents".to_string(),
            ));
        }
        let n_docs = documents.len();
        // Total corpus-wide occurrences per term, and the number of distinct
        // documents each term appears in.
        let mut term_freq: HashMap<String, usize> = HashMap::new();
        let mut doc_freq: HashMap<String, usize> = HashMap::new();
        for doc in documents {
            let tokens = self.preprocess(doc.as_ref())?;
            let mut doc_terms: std::collections::HashSet<String> =
                std::collections::HashSet::new();
            for n in self.ngram_range.0..=self.ngram_range.1 {
                for ngram in tokens.windows(n) {
                    let term = ngram.join("_");
                    *term_freq.entry(term.clone()).or_insert(0) += 1;
                    doc_terms.insert(term);
                }
            }
            // A term counts at most once per document toward its document
            // frequency.
            for term in doc_terms {
                *doc_freq.entry(term).or_insert(0) += 1;
            }
        }
        // Convert the max_df fraction into an absolute document count.
        let max_df_count = (self.max_df * n_docs as f32).ceil() as usize;
        let mut sorted_words: Vec<(String, usize)> = term_freq
            .into_iter()
            .filter(|(term, _)| {
                let df = doc_freq.get(term).copied().unwrap_or(0);
                df >= self.min_df && df <= max_df_count
            })
            .collect();
        // Most frequent terms first, ties broken alphabetically, so that
        // truncation by max_features is deterministic.
        sorted_words.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        if let Some(max_features) = self.max_features {
            sorted_words.truncate(max_features);
        }
        self.vocabulary = sorted_words
            .into_iter()
            .enumerate()
            .map(|(idx, (word, _))| (word, idx))
            .collect();
        Ok(())
    }
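
    // Worked example of the df filters above: with 4 documents and
    // `max_df = 0.5`, `max_df_count = ceil(0.5 * 4) = 2`, so a term seen in
    // 3 or more documents is dropped; with `min_df = 2`, a term seen in only
    // 1 document is dropped as well. Survivors are ranked by total corpus
    // frequency before any `max_features` truncation.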

    /// Maps `documents` onto the fitted vocabulary, returning an
    /// `n_docs x vocabulary_size` matrix of raw term counts. Terms absent
    /// from the vocabulary are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if `documents` is empty, no tokenizer is set, or the
    /// vocabulary is empty (i.e. `fit` has not been called).
    pub fn transform<S: AsRef<str>>(&self, documents: &[S]) -> Result<Matrix<f64>, AprenderError> {
        if documents.is_empty() {
            return Err(AprenderError::Other(
                "Cannot transform empty documents".to_string(),
            ));
        }
        if self.vocabulary.is_empty() {
            return Err(AprenderError::Other(
                "Vocabulary is empty. Call fit() first".to_string(),
            ));
        }
        let n_docs = documents.len();
        let vocab_size = self.vocabulary.len();
        // Dense row-major counts: one row per document, one column per term.
        let mut data = vec![0.0; n_docs * vocab_size];
        for (doc_idx, doc) in documents.iter().enumerate() {
            let tokens = self.preprocess(doc.as_ref())?;
            for n in self.ngram_range.0..=self.ngram_range.1 {
                for ngram in tokens.windows(n) {
                    let term = ngram.join("_");
                    if let Some(&word_idx) = self.vocabulary.get(&term) {
                        data[doc_idx * vocab_size + word_idx] += 1.0;
                    }
                }
            }
        }
        Matrix::from_vec(n_docs, vocab_size, data)
            .map_err(|e: &str| AprenderError::Other(e.to_string()))
    }
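
    // Usage sketch: `fit` then `transform` (or `fit_transform`) yields an
    // `n_docs x vocabulary_size` count matrix. The `rows()`/`cols()`
    // accessors are assumptions about the crate's `Matrix` type, shown only
    // to illustrate the expected shape.
    //
    //     let x = cv.fit_transform(&["red apple", "green apple"])?;
    //     assert_eq!((x.rows(), x.cols()), (2, cv.vocabulary_size()));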

    /// Returns the fitted term-to-column-index mapping.
    #[must_use]
    pub fn vocabulary(&self) -> &HashMap<String, usize> {
        &self.vocabulary
    }

    /// Returns the number of terms in the fitted vocabulary.
    #[must_use]
    pub fn vocabulary_size(&self) -> usize {
        self.vocabulary.len()
    }
}

impl Default for CountVectorizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Converts documents to TF-IDF features by scaling the raw counts from an
/// inner [`CountVectorizer`] with per-term inverse document frequencies.
#[allow(missing_debug_implementations)]
pub struct TfidfVectorizer {
    count_vectorizer: CountVectorizer,
    /// Per-term IDF weights, indexed by vocabulary column, populated when
    /// the vectorizer is fitted.
    idf_values: Vec<f64>,
    /// When true, raw term counts `tf` are replaced by `1 + ln(tf)`.
    sublinear_tf: bool,
}
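
// Background sketch, not this crate's code: the conventional smoothed IDF
// that a TF-IDF vectorizer caches per term, shown for reference. Whether
// `TfidfVectorizer`'s fitting step uses exactly this smoothing is defined by
// its own implementation (not shown in this section).
//
//     fn smoothed_idf(n_docs: usize, doc_freq: usize) -> f64 {
//         ((1.0 + n_docs as f64) / (1.0 + doc_freq as f64)).ln() + 1.0
//     }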