use crate::classification::{TextClassificationPipeline, TextFeatureSelector};
use crate::embeddings::Word2Vec;
use crate::enhanced_vectorize::{EnhancedCountVectorizer, EnhancedTfidfVectorizer};
use crate::error::{Result, TextError};
use crate::multilingual::{Language, LanguageDetector};
use crate::sentiment::LexiconSentimentAnalyzer;
use crate::topic_modeling::LatentDirichletAllocation;
use crate::vectorize::{TfidfVectorizer, Vectorizer};
use scirs2_core::ndarray::{Array1, Array2};
use std::collections::HashMap;
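/// Strategy for converting raw text into numeric features.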
#[derive(Debug, Clone, Copy)]
pub enum FeatureExtractionMode {
BagOfWords,
TfIdf,
WordEmbeddings,
TopicModeling,
Combined,
}
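/// Extracted features plus optional feature names and free-form metadata
/// describing how the features were produced.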
#[derive(Debug, Clone)]
pub struct TextFeatures {
pub features: Array2<f64>,
pub feature_names: Option<Vec<String>>,
pub metadata: HashMap<String, String>,
}
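/// Turns raw text into numeric feature matrices using the strategy selected
/// by [`FeatureExtractionMode`], following the usual fit/transform pattern.
///
/// A minimal usage sketch (marked `ignore` since it assumes the crate's
/// public re-exports):
///
/// ```ignore
/// let mut pre = MLTextPreprocessor::new(FeatureExtractionMode::TfIdf);
/// let features = pre.fit_transform(&["first doc", "second doc"])?;
/// assert_eq!(features.features.nrows(), 2);
/// ```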
pub struct MLTextPreprocessor {
mode: FeatureExtractionMode,
tfidf_vectorizer: Option<TfidfVectorizer>,
enhanced_vectorizer: Option<EnhancedTfidfVectorizer>,
word_embeddings: Option<Word2Vec>,
topic_model: Option<LatentDirichletAllocation>,
language_detector: LanguageDetector,
sentiment_analyzer: LexiconSentimentAnalyzer,
feature_selector: Option<TextFeatureSelector>,
}
impl MLTextPreprocessor {
pub fn new(mode: FeatureExtractionMode) -> Self {
Self {
mode,
tfidf_vectorizer: None,
enhanced_vectorizer: None,
word_embeddings: None,
topic_model: None,
language_detector: LanguageDetector::new(),
sentiment_analyzer: LexiconSentimentAnalyzer::with_basiclexicon(),
feature_selector: None,
}
}
    /// Configures the enhanced TF-IDF vectorizer.
    ///
    /// NOTE: `min_df`, `max_df`, and `max_features` are not yet wired into
    /// the vectorizer; `EnhancedTfidfVectorizer::new()` is used with its
    /// defaults until a builder API exposes these parameters.
    pub fn with_tfidf_params(
        mut self,
        _min_df: f64,
        _max_df: f64,
        _max_features: Option<usize>,
    ) -> Self {
        self.enhanced_vectorizer = Some(EnhancedTfidfVectorizer::new());
        self
    }
    pub fn with_topic_modeling(mut self, n_topics: usize) -> Self {
        self.topic_model = Some(LatentDirichletAllocation::with_ntopics(n_topics));
self
}
pub fn with_word_embeddings(mut self, embeddings: Word2Vec) -> Self {
self.word_embeddings = Some(embeddings);
self
}
    pub fn with_feature_selection(mut self, max_features: usize) -> Self {
        self.feature_selector = TextFeatureSelector::new()
            .set_max_features(max_features as f64)
.ok();
self
}
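    /// Fits the models required by the current extraction mode on `texts`.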
pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
match self.mode {
FeatureExtractionMode::BagOfWords | FeatureExtractionMode::TfIdf => {
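                // NOTE: BagOfWords currently shares the TF-IDF fitting path;
                // a count-only variant would use EnhancedCountVectorizer.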
if let Some(ref mut vectorizer) = self.enhanced_vectorizer {
vectorizer.fit(texts)?;
} else {
let mut vectorizer = TfidfVectorizer::default();
vectorizer.fit(texts)?;
self.tfidf_vectorizer = Some(vectorizer);
}
if let Some(ref mut selector) = self.feature_selector {
let features = if let Some(ref vectorizer) = self.enhanced_vectorizer {
vectorizer.transform_batch(texts)?
} else if let Some(ref vectorizer) = self.tfidf_vectorizer {
vectorizer.transform_batch(texts)?
} else {
return Err(TextError::ModelNotFitted(
"Vectorizer not fitted".to_string(),
));
};
selector.fit(&features)?;
}
}
FeatureExtractionMode::TopicModeling => {
let mut vectorizer = EnhancedCountVectorizer::new();
let doc_term_matrix = vectorizer.fit_transform(texts)?;
if let Some(ref mut topic_model) = self.topic_model {
topic_model.fit(&doc_term_matrix)?;
} else {
let mut topic_model = LatentDirichletAllocation::with_ntopics(10);
topic_model.fit(&doc_term_matrix)?;
self.topic_model = Some(topic_model);
}
}
FeatureExtractionMode::WordEmbeddings => {
if self.word_embeddings.is_none() {
return Err(TextError::InvalidInput(
"Word embeddings must be provided for this mode".to_string(),
));
}
}
FeatureExtractionMode::Combined => {
self.fit_combined(texts)?;
}
}
Ok(())
}
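    /// Transforms `texts` into features using the models fitted by `fit`.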
pub fn transform(&self, texts: &[&str]) -> Result<TextFeatures> {
match self.mode {
FeatureExtractionMode::BagOfWords | FeatureExtractionMode::TfIdf => {
self.transform_vectorized(texts)
}
FeatureExtractionMode::TopicModeling => self.transform_topics(texts),
FeatureExtractionMode::WordEmbeddings => self.transform_embeddings(texts),
FeatureExtractionMode::Combined => self.transform_combined(texts),
}
}
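    /// Convenience wrapper: fits on `texts`, then transforms the same texts.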
pub fn fit_transform(&mut self, texts: &[&str]) -> Result<TextFeatures> {
self.fit(texts)?;
self.transform(texts)
}
fn fit_combined(&mut self, texts: &[&str]) -> Result<()> {
        // Mirror `fit`: reuse configured components instead of overwriting them.
        if let Some(ref mut vectorizer) = self.enhanced_vectorizer {
            vectorizer.fit(texts)?;
        } else {
            let mut tfidf = TfidfVectorizer::default();
            tfidf.fit(texts)?;
            self.tfidf_vectorizer = Some(tfidf);
        }
        let mut count_vectorizer = EnhancedCountVectorizer::new();
        let doc_term_matrix = count_vectorizer.fit_transform(texts)?;
        if let Some(ref mut topic_model) = self.topic_model {
            topic_model.fit(&doc_term_matrix)?;
        } else {
            let mut topic_model = LatentDirichletAllocation::with_ntopics(10);
            topic_model.fit(&doc_term_matrix)?;
            self.topic_model = Some(topic_model);
        }
Ok(())
}
fn transform_vectorized(&self, texts: &[&str]) -> Result<TextFeatures> {
let mut features = if let Some(ref vectorizer) = self.enhanced_vectorizer {
vectorizer.transform_batch(texts)?
} else if let Some(ref vectorizer) = self.tfidf_vectorizer {
vectorizer.transform_batch(texts)?
} else {
return Err(TextError::ModelNotFitted(
"Vectorizer not fitted".to_string(),
));
};
if let Some(ref selector) = self.feature_selector {
features = selector.transform(&features)?;
}
Ok(TextFeatures {
features,
feature_names: None,
metadata: HashMap::new(),
})
}
fn transform_topics(&self, texts: &[&str]) -> Result<TextFeatures> {
if let Some(ref topic_model) = self.topic_model {
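            // NOTE: the count vectorizer is re-fitted on the incoming texts,
            // so its vocabulary may not match the one the topic model saw
            // during `fit`; persisting that vectorizer would be more robust.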
let mut count_vectorizer = EnhancedCountVectorizer::new();
let doc_term_matrix = count_vectorizer.fit_transform(texts)?;
let features = topic_model.transform(&doc_term_matrix)?;
let mut metadata = HashMap::new();
metadata.insert(
"feature_type".to_string(),
"topic_distributions".to_string(),
);
Ok(TextFeatures {
features,
feature_names: None,
metadata,
})
} else {
Err(TextError::ModelNotFitted(
"Topic model not fitted".to_string(),
))
}
}
fn transform_embeddings(&self, texts: &[&str]) -> Result<TextFeatures> {
if let Some(ref embeddings) = self.word_embeddings {
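            // Mean-pool word vectors into one fixed-size document embedding;
            // out-of-vocabulary words are simply skipped.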
let mut doc_embeddings = Vec::new();
for text in texts {
let words: Vec<&str> = text.split_whitespace().collect();
let mut doc_embedding = Array1::zeros(embeddings.vector_size());
let mut count = 0;
for word in &words {
if let Ok(vec) = embeddings.get_word_vector(word) {
doc_embedding += &vec;
count += 1;
}
}
if count > 0 {
doc_embedding /= count as f64;
}
doc_embeddings.push(doc_embedding);
}
let n_docs = doc_embeddings.len();
let n_features = embeddings.vector_size();
let mut features = Array2::zeros((n_docs, n_features));
for (i, doc_vec) in doc_embeddings.iter().enumerate() {
features.row_mut(i).assign(doc_vec);
}
Ok(TextFeatures {
features,
feature_names: None,
metadata: HashMap::new(),
})
} else {
Err(TextError::ModelNotFitted(
"Word embeddings not loaded".to_string(),
))
}
}
fn transform_combined(&self, texts: &[&str]) -> Result<TextFeatures> {
let mut all_features = Vec::new();
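        // TF-IDF and topic blocks are optional: if either backend is not
        // fitted, its block is skipped rather than failing the transform.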
if let Ok(tfidf_features) = self.transform_vectorized(texts) {
all_features.push(tfidf_features.features);
}
if let Ok(topic_features) = self.transform_topics(texts) {
all_features.push(topic_features.features);
}
let sentiment_features = self.extract_sentiment_features(texts)?;
all_features.push(sentiment_features);
let language_features = self.extract_language_features(texts)?;
all_features.push(language_features);
let combined_features = self.concatenate_features(&all_features)?;
Ok(TextFeatures {
features: combined_features,
feature_names: None,
metadata: HashMap::new(),
})
}
fn extract_sentiment_features(&self, texts: &[&str]) -> Result<Array2<f64>> {
let mut features = Array2::zeros((texts.len(), 4));
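        // Columns: sentiment score, confidence, positive word count,
        // negative word count.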
for (i, text) in texts.iter().enumerate() {
let result = self.sentiment_analyzer.analyze(text)?;
features[[i, 0]] = result.score;
features[[i, 1]] = result.confidence;
features[[i, 2]] = result.word_counts.positive_words as f64;
features[[i, 3]] = result.word_counts.negative_words as f64;
}
Ok(features)
}
fn extract_language_features(&self, texts: &[&str]) -> Result<Array2<f64>> {
let mut features = Array2::zeros((texts.len(), 2));
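        // The language is ordinal-encoded for compactness; note this imposes
        // an arbitrary ordering that a one-hot encoding would avoid.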
for (i, text) in texts.iter().enumerate() {
let result = self.language_detector.detect(text)?;
features[[i, 0]] = match result.language {
Language::English => 1.0,
Language::Spanish => 2.0,
Language::French => 3.0,
Language::German => 4.0,
_ => 0.0,
};
features[[i, 1]] = result.confidence;
}
Ok(features)
}
    fn concatenate_features(&self, feature_arrays: &[Array2<f64>]) -> Result<Array2<f64>> {
        if feature_arrays.is_empty() {
            return Err(TextError::InvalidInput(
                "No features to concatenate".to_string(),
            ));
        }
        let n_samples = feature_arrays[0].nrows();
        // Guard against mismatched row counts, which would otherwise panic
        // during the element-wise copy below.
        if feature_arrays.iter().any(|arr| arr.nrows() != n_samples) {
            return Err(TextError::InvalidInput(
                "Feature arrays have mismatched row counts".to_string(),
            ));
        }
        let total_features: usize = feature_arrays.iter().map(|arr| arr.ncols()).sum();
        let mut combined = Array2::zeros((n_samples, total_features));
        let mut col_offset = 0;
        for array in feature_arrays {
let n_cols = array.ncols();
for i in 0..n_samples {
for j in 0..n_cols {
combined[[i, col_offset + j]] = array[[i, j]];
}
}
col_offset += n_cols;
}
Ok(combined)
}
}
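/// High-level pipeline that wraps an [`MLTextPreprocessor`] and an optional
/// classification stage.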
pub struct TextMLPipeline {
preprocessor: MLTextPreprocessor,
classification_pipeline: Option<TextClassificationPipeline>,
}
impl TextMLPipeline {
pub fn new() -> Self {
Self {
preprocessor: MLTextPreprocessor::new(FeatureExtractionMode::TfIdf),
classification_pipeline: None,
}
}
pub fn with_mode(mode: FeatureExtractionMode) -> Self {
Self {
preprocessor: MLTextPreprocessor::new(mode),
classification_pipeline: None,
}
}
pub fn with_classification(mut self) -> Self {
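        // Configured here but not yet consumed by `process`; reserved for a
        // future prediction step.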
self.classification_pipeline = Some(TextClassificationPipeline::with_tfidf());
self
}
pub fn configure_preprocessor<F>(mut self, f: F) -> Self
where
F: FnOnce(MLTextPreprocessor) -> MLTextPreprocessor,
{
self.preprocessor = f(self.preprocessor);
self
}
pub fn process(&mut self, texts: &[&str]) -> Result<TextFeatures> {
self.preprocessor.fit_transform(texts)
}
}
impl Default for TextMLPipeline {
fn default() -> Self {
Self::new()
}
}
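/// Processes a corpus in fixed-size batches: the preprocessor is fitted once
/// on the full corpus, then each batch is transformed with the shared models.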
pub struct BatchTextProcessor {
batch_size: usize,
preprocessor: MLTextPreprocessor,
}
impl BatchTextProcessor {
    pub fn new(batch_size: usize) -> Self {
        Self {
            batch_size,
preprocessor: MLTextPreprocessor::new(FeatureExtractionMode::TfIdf),
}
}
pub fn process_batches(&mut self, texts: &[&str]) -> Result<Vec<TextFeatures>> {
let mut results = Vec::new();
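        // Fit once on the full corpus so every batch shares the same
        // vocabulary and model state.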
self.preprocessor.fit(texts)?;
for chunk in texts.chunks(self.batch_size) {
let batch_features = self.preprocessor.transform(chunk)?;
results.push(batch_features);
}
Ok(results)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ml_preprocessor_tfidf() {
let mut preprocessor = MLTextPreprocessor::new(FeatureExtractionMode::TfIdf);
let texts = vec![
"This is a test document",
"Another test document here",
"Machine learning is great",
];
let features = preprocessor
.fit_transform(&texts)
.expect("Operation failed");
assert_eq!(features.features.nrows(), 3);
assert!(features.features.ncols() > 0);
}
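    // A small sketch verifying the guard in `fit`: WordEmbeddings mode
    // requires embeddings to be supplied via `with_word_embeddings`.
    #[test]
    fn test_word_embeddings_mode_requires_embeddings() {
        let mut preprocessor = MLTextPreprocessor::new(FeatureExtractionMode::WordEmbeddings);
        assert!(preprocessor.fit(&["some text"]).is_err());
    }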
#[test]
fn test_feature_extraction_modes() {
let modes = vec![
FeatureExtractionMode::BagOfWords,
FeatureExtractionMode::TfIdf,
FeatureExtractionMode::TopicModeling,
];
        for mode in modes {
            let preprocessor = MLTextPreprocessor::new(mode);
            // `matches!(x, mode)` would bind, not compare; use discriminants.
            assert_eq!(
                std::mem::discriminant(&preprocessor.mode),
                std::mem::discriminant(&mode)
            );
        }
}
#[test]
fn test_text_ml_pipeline() {
let mut pipeline =
TextMLPipeline::new().configure_preprocessor(|p| p.with_feature_selection(10));
let texts = vec![
"Text processing example",
"Machine learning pipeline",
"Feature extraction test",
];
let features = pipeline.process(&texts).expect("Operation failed");
assert_eq!(features.features.nrows(), 3);
}
#[test]
fn test_batch_processor() {
let mut processor = BatchTextProcessor::new(2);
let texts = vec![
"First batch text",
"Second batch text",
"Third batch text",
"Fourth batch text",
];
let batches = processor.process_batches(&texts).expect("Operation failed");
        assert_eq!(batches.len(), 2);
    }
}