use crate::error::{Result, TextError};
use crate::tokenize::{NgramTokenizer, Tokenizer, WordTokenizer};
use crate::vocabulary::Vocabulary;
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::{HashMap, HashSet};
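/// A count vectorizer with word n-gram support, document-frequency filtering
/// (`min_df`/`max_df`), an optional cap on vocabulary size (`max_features`),
/// lowercasing, and binary (presence/absence) output.
///
/// # Examples
///
/// A minimal sketch of the builder-style configuration, assuming the type is in
/// scope (the input strings are purely illustrative):
///
/// ```ignore
/// let mut vectorizer = EnhancedCountVectorizer::new()
///     .set_ngram_range((1, 2))?
///     .set_max_features(Some(1_000));
/// let counts = vectorizer.fit_transform(&["first document", "second document"])?;
/// ```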
pub struct EnhancedCountVectorizer {
vocabulary: Vocabulary,
binary: bool,
ngram_range: (usize, usize),
max_features: Option<usize>,
min_df: f64,
max_df: f64,
lowercase: bool,
}
impl EnhancedCountVectorizer {
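    /// Creates a vectorizer with default settings: unigrams only, counts (not binary),
    /// no feature cap, `min_df = 0.0`, `max_df = 1.0`, and lowercasing enabled.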
pub fn new() -> Self {
Self {
vocabulary: Vocabulary::new(),
binary: false,
ngram_range: (1, 1),
max_features: None,
min_df: 0.0,
max_df: 1.0,
lowercase: true,
}
}
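    /// When enabled, non-zero counts are clipped to `1.0` in the output vectors.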
pub fn set_binary(mut self, binary: bool) -> Self {
self.binary = binary;
self
}
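    /// Sets the n-gram range as `(min_n, max_n)`.
    ///
    /// Returns an error if `min_n` is zero or `max_n < min_n`.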
pub fn set_ngram_range(mut self, range: (usize, usize)) -> Result<Self> {
if range.0 == 0 || range.1 < range.0 {
return Err(TextError::InvalidInput(
"Invalid n-gram range. Must have min_n > 0 and max_n >= min_n".to_string(),
));
}
self.ngram_range = range;
Ok(self)
}
pub fn set_max_features(mut self, maxfeatures: Option<usize>) -> Self {
self.max_features = maxfeatures;
self
}
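    /// Sets the minimum document frequency, as a fraction of the corpus in `[0.0, 1.0]`.
    /// Tokens appearing in fewer documents are dropped during `fit`.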
pub fn set_min_df(mut self, mindf: f64) -> Result<Self> {
if !(0.0..=1.0).contains(&mindf) {
return Err(TextError::InvalidInput(
"min_df must be between 0.0 and 1.0".to_string(),
));
}
self.min_df = mindf;
Ok(self)
}
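    /// Sets the maximum document frequency, as a fraction of the corpus in `[0.0, 1.0]`.
    /// Tokens appearing in more documents are dropped during `fit`.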
pub fn set_max_df(mut self, maxdf: f64) -> Result<Self> {
if !(0.0..=1.0).contains(&maxdf) {
return Err(TextError::InvalidInput(
"max_df must be between 0.0 and 1.0".to_string(),
));
}
self.max_df = maxdf;
Ok(self)
}
pub fn set_lowercase(mut self, lowercase: bool) -> Self {
self.lowercase = lowercase;
self
}
pub fn vocabulary(&self) -> &Vocabulary {
&self.vocabulary
}
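    /// Learns the vocabulary from `texts`: extracts n-grams, computes per-token document
    /// frequencies, applies the `min_df`/`max_df` filters, and keeps at most
    /// `max_features` tokens (most frequent first).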
pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
if texts.is_empty() {
return Err(TextError::InvalidInput(
"No texts provided for fitting".to_string(),
));
}
self.vocabulary = Vocabulary::new();
let mut doc_frequencies: HashMap<String, usize> = HashMap::new();
let total_docs = texts.len();
        for text in texts {
            // Count each token at most once per document for document-frequency statistics.
            let mut seen_in_doc: HashSet<String> = HashSet::new();
            let all_tokens = self.extract_ngrams(text)?;
            for token in all_tokens {
                if seen_in_doc.insert(token.clone()) {
                    *doc_frequencies.entry(token.clone()).or_insert(0) += 1;
                }
                self.vocabulary.add_token(&token);
            }
        }
let min_count = (self.min_df * total_docs as f64).ceil() as usize;
let max_count = (self.max_df * total_docs as f64).floor() as usize;
let mut filtered_tokens: Vec<(String, usize)> = doc_frequencies
.into_iter()
.filter(|(_, count)| *count >= min_count && *count <= max_count)
.collect();
filtered_tokens.sort_by_key(|item| std::cmp::Reverse(item.1));
if let Some(max_features) = self.max_features {
filtered_tokens.truncate(max_features);
}
self.vocabulary = Vocabulary::with_maxsize(self.max_features.unwrap_or(usize::MAX));
for (token, _) in filtered_tokens {
self.vocabulary.add_token(&token);
}
Ok(())
}
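    /// Tokenizes `text` (lowercasing it first if configured), producing plain word tokens
    /// for a `(1, 1)` range and n-grams otherwise.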
fn extract_ngrams(&self, text: &str) -> Result<Vec<String>> {
let text = if self.lowercase {
text.to_lowercase()
} else {
text.to_string()
};
let all_ngrams = if self.ngram_range == (1, 1) {
let tokenizer = WordTokenizer::new(false);
tokenizer.tokenize(&text)?
} else {
let ngram_tokenizer =
NgramTokenizer::with_range(self.ngram_range.0, self.ngram_range.1)?;
ngram_tokenizer.tokenize(&text)?
};
Ok(all_ngrams)
}
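    /// Maps a single text to a dense count vector over the fitted vocabulary.
    /// Tokens outside the vocabulary are ignored; counts are clipped to `1.0` if
    /// `binary` is set. Fails if the vocabulary is empty (i.e. `fit` was not called).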
pub fn transform(&self, text: &str) -> Result<Array1<f64>> {
if self.vocabulary.is_empty() {
return Err(TextError::VocabularyError(
"Vocabulary is empty. Call fit() first".to_string(),
));
}
let vocab_size = self.vocabulary.len();
let mut vector = Array1::zeros(vocab_size);
let tokens = self.extract_ngrams(text)?;
for token in tokens {
if let Some(idx) = self.vocabulary.get_index(&token) {
vector[idx] += 1.0;
}
}
if self.binary {
for val in vector.iter_mut() {
if *val > 0.0 {
*val = 1.0;
}
}
}
Ok(vector)
}
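    /// Transforms a batch of texts into an `(n_samples, vocab_size)` count matrix,
    /// one row per input text.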
pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
if self.vocabulary.is_empty() {
return Err(TextError::VocabularyError(
"Vocabulary is empty. Call fit() first".to_string(),
));
}
let n_samples = texts.len();
let vocab_size = self.vocabulary.len();
let mut matrix = Array2::zeros((n_samples, vocab_size));
for (i, text) in texts.iter().enumerate() {
let vector = self.transform(text)?;
matrix.row_mut(i).assign(&vector);
}
Ok(matrix)
}
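    /// Convenience method: fits the vocabulary on `texts` and returns their count matrix.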
pub fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
self.fit(texts)?;
self.transform_batch(texts)
}
}
impl Default for EnhancedCountVectorizer {
fn default() -> Self {
Self::new()
}
}
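/// A TF-IDF vectorizer built on top of [`EnhancedCountVectorizer`], with optional
/// sublinear TF scaling, smoothed IDF weights, and L1/L2 normalization.
///
/// With `smooth_idf`, the weight of a term `t` over `n` documents is
/// `idf(t) = ln((1 + n) / (1 + df(t))) + 1`; otherwise it is `idf(t) = ln(n / df(t)) + 1`.
///
/// # Examples
///
/// A minimal sketch, assuming the type is in scope (the input strings are illustrative):
///
/// ```ignore
/// let mut vectorizer = EnhancedTfidfVectorizer::new()
///     .set_sublinear_tf(true)
///     .set_norm(Some("l2".to_string()))?;
/// let tfidf = vectorizer.fit_transform(&["first document", "second document"])?;
/// ```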
pub struct EnhancedTfidfVectorizer {
count_vectorizer: EnhancedCountVectorizer,
useidf: bool,
smoothidf: bool,
sublinear_tf: bool,
norm: Option<String>,
idf_: Option<Array1<f64>>,
}
impl EnhancedTfidfVectorizer {
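    /// Creates a TF-IDF vectorizer with default settings: IDF weighting and smoothing
    /// enabled, linear TF, and L2 normalization.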
pub fn new() -> Self {
Self {
count_vectorizer: EnhancedCountVectorizer::new(),
useidf: true,
smoothidf: true,
sublinear_tf: false,
norm: Some("l2".to_string()),
idf_: None,
}
}
pub fn set_use_idf(mut self, useidf: bool) -> Self {
self.useidf = useidf;
self
}
pub fn set_smooth_idf(mut self, smoothidf: bool) -> Self {
self.smoothidf = smoothidf;
self
}
pub fn set_sublinear_tf(mut self, sublineartf: bool) -> Self {
self.sublinear_tf = sublineartf;
self
}
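    /// Sets the output normalization: `Some("l1")`, `Some("l2")`, or `None` to disable.
    /// Any other value is rejected.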
pub fn set_norm(mut self, norm: Option<String>) -> Result<Self> {
if let Some(ref n) = norm {
if n != "l1" && n != "l2" {
return Err(TextError::InvalidInput(
"Norm must be 'l1', 'l2', or None".to_string(),
));
}
}
self.norm = norm;
Ok(self)
}
pub fn set_ngram_range(mut self, range: (usize, usize)) -> Result<Self> {
self.count_vectorizer = self.count_vectorizer.set_ngram_range(range)?;
Ok(self)
}
pub fn set_max_features(mut self, maxfeatures: Option<usize>) -> Self {
self.count_vectorizer = self.count_vectorizer.set_max_features(maxfeatures);
self
}
pub fn vocabulary(&self) -> &Vocabulary {
self.count_vectorizer.vocabulary()
}
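    /// Fits the underlying count vectorizer on `texts` and, if IDF weighting is enabled,
    /// computes the per-term IDF weights.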
pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
self.count_vectorizer.fit(texts)?;
if self.useidf {
self.calculate_idf(texts)?;
}
Ok(())
}
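    /// Computes per-term document frequencies over `texts` and derives the IDF weights,
    /// using the smoothed formula when `smoothidf` is set (the unsmoothed path guards
    /// against zero document frequencies).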
fn calculate_idf(&mut self, texts: &[&str]) -> Result<()> {
let vocab_size = self.count_vectorizer.vocabulary().len();
let mut df: Array1<f64> = Array1::zeros(vocab_size);
let n_samples = texts.len() as f64;
for text in texts {
let count_vec = self.count_vectorizer.transform(text)?;
for (idx, &count) in count_vec.iter().enumerate() {
if count > 0.0 {
df[idx] += 1.0;
}
}
}
let mut idf = Array1::zeros(vocab_size);
for (idx, &doc_freq) in df.iter().enumerate() {
if self.smoothidf {
idf[idx] = (1.0 + n_samples) / (1.0 + doc_freq);
} else {
idf[idx] = n_samples / doc_freq.max(1.0);
}
idf[idx] = idf[idx].ln() + 1.0;
}
self.idf_ = Some(idf);
Ok(())
}
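    /// Maps a single text to a TF-IDF vector: raw counts, optional sublinear TF
    /// (`1 + ln(tf)`), IDF weighting, then L1/L2 normalization if configured.
    /// Fails if IDF weighting is enabled but `fit` has not been called.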
pub fn transform(&self, text: &str) -> Result<Array1<f64>> {
let mut vector = self.count_vectorizer.transform(text)?;
if self.sublinear_tf {
for val in vector.iter_mut() {
if *val > 0.0 {
*val = 1.0 + (*val).ln();
}
}
}
if self.useidf {
if let Some(ref idf) = self.idf_ {
vector *= idf;
} else {
return Err(TextError::VocabularyError(
"IDF weights not calculated. Call fit() first".to_string(),
));
}
}
if let Some(ref norm) = self.norm {
match norm.as_str() {
"l1" => {
let norm_val = vector.iter().map(|x| x.abs()).sum::<f64>();
if norm_val > 0.0 {
vector /= norm_val;
}
}
"l2" => {
let norm_val = vector.dot(&vector).sqrt();
if norm_val > 0.0 {
vector /= norm_val;
}
}
_ => {}
}
}
Ok(vector)
}
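    /// Transforms a batch of texts into an `(n_samples, vocab_size)` TF-IDF matrix,
    /// one row per input text.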
pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
let n_samples = texts.len();
let vocab_size = self.count_vectorizer.vocabulary().len();
let mut matrix = Array2::zeros((n_samples, vocab_size));
for (i, text) in texts.iter().enumerate() {
let vector = self.transform(text)?;
matrix.row_mut(i).assign(&vector);
}
Ok(matrix)
}
pub fn fit_transform(&mut self, texts: &[&str]) -> Result<Array2<f64>> {
self.fit(texts)?;
self.transform_batch(texts)
}
}
impl Default for EnhancedTfidfVectorizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_enhanced_count_vectorizer_unigrams() {
let mut vectorizer = EnhancedCountVectorizer::new();
let documents = vec![
"this is a test",
"this is another test",
"something different here",
];
vectorizer.fit(&documents).expect("Operation failed");
let vector = vectorizer
.transform("this is a test")
.expect("Operation failed");
assert!(!vector.is_empty());
}
#[test]
fn test_enhanced_count_vectorizer_ngrams() {
let mut vectorizer = EnhancedCountVectorizer::new()
.set_ngram_range((1, 2))
.expect("Operation failed");
let documents = vec!["hello world", "hello there", "world peace"];
vectorizer.fit(&documents).expect("Operation failed");
let vocab = vectorizer.vocabulary();
        assert!(vocab.len() > 3);
    }
#[test]
fn test_enhanced_tfidf_vectorizer() {
let mut vectorizer = EnhancedTfidfVectorizer::new()
.set_smooth_idf(true)
.set_norm(Some("l2".to_string()))
.expect("Operation failed");
let documents = vec![
"this is a test",
"this is another test",
"something different here",
];
vectorizer.fit(&documents).expect("Operation failed");
let vector = vectorizer
.transform("this is a test")
.expect("Operation failed");
let norm = vector.dot(&vector).sqrt();
assert!((norm - 1.0).abs() < 1e-6);
}
#[test]
fn test_max_features() {
let mut vectorizer = EnhancedCountVectorizer::new().set_max_features(Some(5));
let documents = vec![
"one two three four five six seven eight",
"one two three four five six seven eight nine ten",
];
vectorizer.fit(&documents).expect("Operation failed");
assert_eq!(vectorizer.vocabulary().len(), 5);
}
#[test]
fn test_document_frequency_filtering() {
let mut vectorizer = EnhancedCountVectorizer::new()
.set_min_df(0.5)
.expect("Operation failed");
let documents = vec![
"common word rare",
"common word unique",
"common another distinct",
];
vectorizer.fit(&documents).expect("Operation failed");
let vocab = vectorizer.vocabulary();
assert!(vocab.contains("common"));
assert!(!vocab.contains("rare"));
assert!(!vocab.contains("unique"));
}
}