use std::collections::HashMap;
use ferrolearn_core::error::FerroError;
use ndarray::Array2;
/// Builder that converts a corpus of text documents into a matrix of token counts.
#[derive(Debug, Clone)]
pub struct CountVectorizer {
    /// If set, keep only this many terms (ranked by total corpus frequency).
    pub max_features: Option<usize>,
    /// Minimum number of documents a term must appear in (absolute count).
    pub min_df: usize,
    /// Maximum fraction of documents a term may appear in; must lie in (0, 1].
    pub max_df: f64,
    /// If true, output counts are clamped to 0/1 (term presence only).
    pub binary: bool,
    /// If true, documents are lowercased before tokenization.
    pub lowercase: bool,
}
impl CountVectorizer {
    /// Creates a vectorizer with defaults: no feature cap, `min_df = 1`,
    /// `max_df = 1.0`, raw counts, and lowercasing enabled.
    #[must_use]
    pub fn new() -> Self {
        Self {
            max_features: None,
            min_df: 1,
            max_df: 1.0,
            binary: false,
            lowercase: true,
        }
    }

    /// Caps the vocabulary at the `n` terms with the highest total frequency.
    #[must_use]
    pub fn max_features(mut self, n: usize) -> Self {
        self.max_features = Some(n);
        self
    }

    /// Sets the minimum number of documents a term must appear in.
    #[must_use]
    pub fn min_df(mut self, min_df: usize) -> Self {
        self.min_df = min_df;
        self
    }

    /// Sets the maximum fraction of documents a term may appear in (in `(0, 1]`).
    #[must_use]
    pub fn max_df(mut self, max_df: f64) -> Self {
        self.max_df = max_df;
        self
    }

    /// When `true`, output counts are clamped to 0/1 (term presence).
    #[must_use]
    pub fn binary(mut self, binary: bool) -> Self {
        self.binary = binary;
        self
    }

    /// When `true` (the default), documents are lowercased before tokenization.
    #[must_use]
    pub fn lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }

    /// Learns the vocabulary from `docs` and returns a fitted vectorizer.
    ///
    /// A term is kept when its document frequency is at least `min_df` and at
    /// most `ceil(max_df * n_docs)`. If `max_features` is set, the surviving
    /// terms with the highest total corpus frequency are retained (ties broken
    /// alphabetically). Columns are ordered alphabetically by term.
    ///
    /// # Errors
    ///
    /// Returns `FerroError::InsufficientSamples` when `docs` is empty, and
    /// `FerroError::InvalidParameter` when `max_df` lies outside `(0, 1]`.
    pub fn fit(&self, docs: &[String]) -> Result<FittedCountVectorizer, FerroError> {
        let n_docs = docs.len();
        if n_docs == 0 {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "CountVectorizer::fit".into(),
            });
        }
        if self.max_df <= 0.0 || self.max_df > 1.0 {
            return Err(FerroError::InvalidParameter {
                name: "max_df".into(),
                reason: format!("must be in (0, 1], got {}", self.max_df),
            });
        }
        // Document frequency: how many documents contain each term at least once.
        let mut df_counts: HashMap<String, usize> = HashMap::new();
        for doc in docs {
            let mut seen = std::collections::HashSet::new();
            for tok in tokenize(doc, self.lowercase) {
                // `seen` guarantees each term is counted once per document.
                if seen.insert(tok.clone()) {
                    *df_counts.entry(tok).or_default() += 1;
                }
            }
        }
        // Translate the fractional max_df into an absolute document count.
        let max_df_abs = (self.max_df * n_docs as f64).ceil() as usize;
        let mut vocab: Vec<String> = df_counts
            .into_iter()
            .filter(|&(_, count)| count >= self.min_df && count <= max_df_abs)
            .map(|(term, _)| term)
            .collect();
        vocab.sort_unstable();
        if let Some(max_f) = self.max_features {
            if vocab.len() > max_f {
                // Total term frequency over the whole corpus; terms that were
                // filtered out above are simply never looked up when sorting.
                let mut total_freq: HashMap<String, usize> = HashMap::new();
                for doc in docs {
                    for tok in tokenize(doc, self.lowercase) {
                        *total_freq.entry(tok).or_default() += 1;
                    }
                }
                // Most frequent first, alphabetical on ties; then restore
                // alphabetical column order after truncation.
                vocab.sort_by(|a, b| {
                    let fa = total_freq.get(a).copied().unwrap_or(0);
                    let fb = total_freq.get(b).copied().unwrap_or(0);
                    fb.cmp(&fa).then_with(|| a.cmp(b))
                });
                vocab.truncate(max_f);
                vocab.sort_unstable();
            }
        }
        let vocabulary: HashMap<String, usize> = vocab
            .iter()
            .enumerate()
            .map(|(i, t)| (t.clone(), i))
            .collect();
        Ok(FittedCountVectorizer {
            vocabulary,
            sorted_terms: vocab,
            binary: self.binary,
            lowercase: self.lowercase,
        })
    }
}
impl Default for CountVectorizer {
fn default() -> Self {
Self::new()
}
}
/// A vectorizer whose vocabulary has been learned from a training corpus.
#[derive(Debug, Clone)]
pub struct FittedCountVectorizer {
    // Term -> column index; mirrors `sorted_terms`.
    vocabulary: HashMap<String, usize>,
    // Vocabulary terms in column order (sorted alphabetically during fit).
    sorted_terms: Vec<String>,
    // If true, transform emits 0/1 presence flags instead of raw counts.
    binary: bool,
    // Tokenization setting captured from the builder at fit time.
    lowercase: bool,
}
impl FittedCountVectorizer {
    /// Returns the learned vocabulary terms in column (alphabetical) order.
    #[must_use]
    pub fn vocabulary(&self) -> &[String] {
        self.sorted_terms.as_slice()
    }

    /// Returns the mapping from term to column index.
    #[must_use]
    pub fn vocabulary_map(&self) -> &HashMap<String, usize> {
        &self.vocabulary
    }

    /// Converts `docs` into an `(n_docs, n_vocab)` count matrix.
    ///
    /// Tokens absent from the learned vocabulary are ignored. When the
    /// vectorizer was fitted with `binary = true`, cells hold 0/1 presence
    /// flags instead of raw counts.
    ///
    /// # Errors
    ///
    /// Returns `FerroError::InsufficientSamples` when `docs` is empty.
    pub fn transform(&self, docs: &[String]) -> Result<Array2<f64>, FerroError> {
        if docs.is_empty() {
            return Err(FerroError::InsufficientSamples {
                required: 1,
                actual: 0,
                context: "FittedCountVectorizer::transform".into(),
            });
        }
        let shape = (docs.len(), self.sorted_terms.len());
        let mut counts = Array2::<f64>::zeros(shape);
        for (row, doc) in docs.iter().enumerate() {
            for token in tokenize(doc, self.lowercase) {
                if let Some(&col) = self.vocabulary.get(&token) {
                    let cell = &mut counts[[row, col]];
                    *cell = if self.binary { 1.0 } else { *cell + 1.0 };
                }
            }
        }
        Ok(counts)
    }
}
/// Splits `doc` into tokens: maximal runs of alphanumeric characters
/// (Unicode-aware); every other character is a separator and empty
/// fragments are dropped. Lowercases the document first when requested.
fn tokenize(doc: &str, lowercase: bool) -> Vec<String> {
    // Borrow the input directly when no case-folding is needed; the original
    // allocated a full copy via `to_string()` even in that case.
    let lowered;
    let text: &str = if lowercase {
        lowered = doc.to_lowercase();
        &lowered
    } else {
        doc
    };
    text.split(|c: char| !c.is_alphanumeric())
        .filter(|s| !s.is_empty())
        .map(str::to_owned)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Helper: build the owned `Vec<String>` corpus the API expects.
    fn corpus(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|t| (*t).to_string()).collect()
    }

    #[test]
    fn test_count_vectorizer_basic() {
        let docs = corpus(&["the cat sat", "the cat sat on the mat"]);
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let matrix = fitted.transform(&docs).unwrap();
        assert_eq!(matrix.nrows(), 2);
        let terms = fitted.vocabulary();
        for word in ["cat", "the", "sat"].iter() {
            assert!(terms.contains(&(*word).to_string()));
        }
        let col = fitted.vocabulary_map()["the"];
        assert_abs_diff_eq!(matrix[[0, col]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(matrix[[1, col]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_binary() {
        let docs = corpus(&["the the the"]);
        let fitted = CountVectorizer::new().binary(true).fit(&docs).unwrap();
        let matrix = fitted.transform(&docs).unwrap();
        // Repeated occurrences collapse to a presence flag.
        assert_abs_diff_eq!(matrix[[0, 0]], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_lowercase() {
        let docs = corpus(&["Hello HELLO hello"]);
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let matrix = fitted.transform(&docs).unwrap();
        // All three case variants fold into a single vocabulary entry.
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_abs_diff_eq!(matrix[[0, 0]], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_vectorizer_no_lowercase() {
        let docs = corpus(&["Hello hello"]);
        let fitted = CountVectorizer::new().lowercase(false).fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary().len(), 2);
    }

    #[test]
    fn test_count_vectorizer_max_features() {
        let docs = corpus(&["a b c d e f"]);
        let fitted = CountVectorizer::new().max_features(3).fit(&docs).unwrap();
        assert_eq!(fitted.vocabulary().len(), 3);
    }

    #[test]
    fn test_count_vectorizer_min_df() {
        let docs = corpus(&["cat dog", "cat bird", "cat fish"]);
        let fitted = CountVectorizer::new().min_df(3).fit(&docs).unwrap();
        // Only "cat" appears in all three documents.
        assert_eq!(fitted.vocabulary().len(), 1);
        assert_eq!(fitted.vocabulary()[0], "cat");
    }

    #[test]
    fn test_count_vectorizer_max_df() {
        let docs = corpus(&["the cat", "the dog", "the bird"]);
        let fitted = CountVectorizer::new().max_df(0.5).fit(&docs).unwrap();
        // "the" is in every document, above the 50% ceiling.
        assert!(!fitted.vocabulary().contains(&"the".to_string()));
    }

    #[test]
    fn test_count_vectorizer_empty_corpus() {
        let docs: Vec<String> = Vec::new();
        assert!(CountVectorizer::new().fit(&docs).is_err());
    }

    #[test]
    fn test_count_vectorizer_transform_empty() {
        let docs = corpus(&["hello world"]);
        let fitted = CountVectorizer::new().fit(&docs).unwrap();
        let nothing: Vec<String> = Vec::new();
        assert!(fitted.transform(&nothing).is_err());
    }

    #[test]
    fn test_count_vectorizer_unseen_tokens() {
        let train = corpus(&["cat dog"]);
        let fitted = CountVectorizer::new().fit(&train).unwrap();
        let matrix = fitted.transform(&corpus(&["fish bird"])).unwrap();
        // Tokens outside the training vocabulary contribute nothing.
        for &value in &matrix {
            assert_abs_diff_eq!(value, 0.0, epsilon = 1e-10);
        }
    }
}