1use crate::classification::{TextClassificationPipeline, TextFeatureSelector};
7use crate::embeddings::Word2Vec;
8use crate::enhanced_vectorize::{EnhancedCountVectorizer, EnhancedTfidfVectorizer};
9use crate::error::{Result, TextError};
10use crate::multilingual::{Language, LanguageDetector};
11use crate::sentiment::LexiconSentimentAnalyzer;
12use crate::topic_modeling::LatentDirichletAllocation;
13use crate::vectorize::{TfidfVectorizer, Vectorizer};
14use scirs2_core::ndarray::{Array1, Array2};
15use std::collections::HashMap;
16
/// Strategy for turning raw text into numeric feature matrices.
#[derive(Debug, Clone, Copy)]
pub enum FeatureExtractionMode {
    /// Bag-of-words features.
    /// NOTE(review): `fit`/`transform` currently route this mode through the
    /// TF-IDF vectorizer, so it yields TF-IDF weights, not raw counts — confirm.
    BagOfWords,
    /// TF-IDF weighted term features.
    TfIdf,
    /// Per-document vectors averaged from pre-trained word embeddings.
    WordEmbeddings,
    /// Per-document topic distributions from an LDA model.
    TopicModeling,
    /// Concatenation of TF-IDF, topic, sentiment, and language features.
    Combined,
}
31
/// Extracted numeric features for a batch of documents.
#[derive(Debug, Clone)]
pub struct TextFeatures {
    /// Feature matrix, one row per input document.
    pub features: Array2<f64>,
    /// Optional human-readable name for each feature column.
    pub feature_names: Option<Vec<String>>,
    /// Free-form key/value metadata (e.g. `"feature_type"`).
    pub metadata: HashMap<String, String>,
}
42
/// Text-to-features preprocessor driving the various extraction modes.
pub struct MLTextPreprocessor {
    // Which feature-extraction strategy to apply.
    mode: FeatureExtractionMode,
    // Basic TF-IDF vectorizer, created lazily in `fit` as a fallback.
    tfidf_vectorizer: Option<TfidfVectorizer>,
    // Preferred TF-IDF vectorizer, installed via `with_tfidf_params`.
    enhanced_vectorizer: Option<EnhancedTfidfVectorizer>,
    // Pre-trained embeddings; required for `WordEmbeddings` mode.
    word_embeddings: Option<Word2Vec>,
    // LDA model used by `TopicModeling` and `Combined` modes.
    topic_model: Option<LatentDirichletAllocation>,
    // Backs the language features produced in `Combined` mode.
    language_detector: LanguageDetector,
    // Backs the sentiment features produced in `Combined` mode.
    sentiment_analyzer: LexiconSentimentAnalyzer,
    // Optional selector applied after vectorization.
    feature_selector: Option<TextFeatureSelector>,
}
62
impl MLTextPreprocessor {
    /// Create a preprocessor for the given extraction `mode`.
    ///
    /// Vectorizers, embeddings and topic models start unset; they are either
    /// supplied via the `with_*` builders or created lazily during [`fit`].
    pub fn new(mode: FeatureExtractionMode) -> Self {
        Self {
            mode,
            tfidf_vectorizer: None,
            enhanced_vectorizer: None,
            word_embeddings: None,
            topic_model: None,
            language_detector: LanguageDetector::new(),
            sentiment_analyzer: LexiconSentimentAnalyzer::with_basiclexicon(),
            feature_selector: None,
        }
    }

    /// Install an enhanced TF-IDF vectorizer.
    ///
    /// NOTE(review): `min_df`, `max_df` and `max_features` are currently
    /// ignored — the vectorizer is built with defaults. TODO: wire these
    /// parameters through to `EnhancedTfidfVectorizer`.
    pub fn with_tfidf_params(
        mut self,
        min_df: f64,
        max_df: f64,
        max_features: Option<usize>,
    ) -> Self {
        let vectorizer = EnhancedTfidfVectorizer::new();
        self.enhanced_vectorizer = Some(vectorizer);
        self
    }

    /// Configure LDA topic modeling with `ntopics` topics.
    pub fn with_topic_modeling(mut self, ntopics: usize) -> Self {
        self.topic_model = Some(LatentDirichletAllocation::with_ntopics(ntopics));
        self
    }

    /// Supply pre-trained word embeddings (required for `WordEmbeddings` mode).
    pub fn with_word_embeddings(mut self, embeddings: Word2Vec) -> Self {
        self.word_embeddings = Some(embeddings);
        self
    }

    /// Enable feature selection capped at `maxfeatures` features.
    ///
    /// NOTE(review): a failure to build the selector is silently dropped via
    /// `.ok()`, leaving selection disabled — confirm this is intended.
    pub fn with_feature_selection(mut self, maxfeatures: usize) -> Self {
        self.feature_selector = TextFeatureSelector::new()
            .set_max_features(maxfeatures as f64)
            .ok();
        self
    }

    /// Fit the models required by the configured mode on `texts`.
    ///
    /// # Errors
    /// Propagates fitting errors from the underlying models, and returns
    /// `InvalidInput` if `WordEmbeddings` mode is used without embeddings.
    pub fn fit(&mut self, texts: &[&str]) -> Result<()> {
        match self.mode {
            // NOTE(review): BagOfWords shares the TF-IDF path below, so it
            // currently produces TF-IDF weights rather than raw counts.
            FeatureExtractionMode::BagOfWords | FeatureExtractionMode::TfIdf => {
                if let Some(ref mut vectorizer) = self.enhanced_vectorizer {
                    vectorizer.fit(texts)?;
                } else {
                    // No enhanced vectorizer configured: fall back to defaults.
                    let mut vectorizer = TfidfVectorizer::default();
                    vectorizer.fit(texts)?;
                    self.tfidf_vectorizer = Some(vectorizer);
                }

                // If selection is enabled, fit the selector on the training features.
                if let Some(ref mut selector) = self.feature_selector {
                    let features = if let Some(ref vectorizer) = self.enhanced_vectorizer {
                        vectorizer.transform_batch(texts)?
                    } else if let Some(ref vectorizer) = self.tfidf_vectorizer {
                        vectorizer.transform_batch(texts)?
                    } else {
                        // Unreachable in practice: one branch above always sets a vectorizer.
                        return Err(TextError::ModelNotFitted(
                            "Vectorizer not fitted".to_string(),
                        ));
                    };
                    selector.fit(&features)?;
                }
            }
            FeatureExtractionMode::TopicModeling => {
                // Build a document-term matrix, then fit LDA on it.
                let mut vectorizer = EnhancedCountVectorizer::new();
                let doc_term_matrix = vectorizer.fit_transform(texts)?;

                // NOTE(review): this count vectorizer is discarded here, and
                // `transform_topics` fits a fresh one on the transform input —
                // the two vocabularies may not line up. Consider storing it.
                if let Some(ref mut topic_model) = self.topic_model {
                    topic_model.fit(&doc_term_matrix)?;
                } else {
                    // Default to 10 topics when none were configured.
                    let mut topic_model = LatentDirichletAllocation::with_ntopics(10);
                    topic_model.fit(&doc_term_matrix)?;
                    self.topic_model = Some(topic_model);
                }
            }
            FeatureExtractionMode::WordEmbeddings => {
                // Embeddings are pre-trained; fitting only validates presence.
                if self.word_embeddings.is_none() {
                    return Err(TextError::InvalidInput(
                        "Word embeddings must be provided for this mode".to_string(),
                    ));
                }
            }
            FeatureExtractionMode::Combined => {
                self.fit_combined(texts)?;
            }
        }

        Ok(())
    }

    /// Transform `texts` into features using the mode chosen at construction.
    ///
    /// # Errors
    /// Returns `ModelNotFitted` when the required model was never fitted/loaded.
    pub fn transform(&self, texts: &[&str]) -> Result<TextFeatures> {
        match self.mode {
            FeatureExtractionMode::BagOfWords | FeatureExtractionMode::TfIdf => {
                self.transform_vectorized(texts)
            }
            FeatureExtractionMode::TopicModeling => self.transform_topics(texts),
            FeatureExtractionMode::WordEmbeddings => self.transform_embeddings(texts),
            FeatureExtractionMode::Combined => self.transform_combined(texts),
        }
    }

    /// Convenience: [`fit`] followed by [`transform`] on the same texts.
    pub fn fit_transform(&mut self, texts: &[&str]) -> Result<TextFeatures> {
        self.fit(texts)?;
        self.transform(texts)
    }

    /// Fit the models needed for `Combined` mode: a default TF-IDF vectorizer
    /// plus a 10-topic LDA model.
    ///
    /// NOTE(review): `transform_vectorized` prefers `enhanced_vectorizer` when
    /// one was configured, in which case the TF-IDF fit here goes unused.
    fn fit_combined(&mut self, texts: &[&str]) -> Result<()> {
        let mut tfidf = TfidfVectorizer::default();
        tfidf.fit(texts)?;
        self.tfidf_vectorizer = Some(tfidf);

        // Topic model is fitted on raw counts, not TF-IDF weights.
        let mut count_vectorizer = EnhancedCountVectorizer::new();
        let doc_term_matrix = count_vectorizer.fit_transform(texts)?;
        let mut topic_model = LatentDirichletAllocation::with_ntopics(10);
        topic_model.fit(&doc_term_matrix)?;
        self.topic_model = Some(topic_model);

        Ok(())
    }

    /// Vectorize texts with the fitted TF-IDF vectorizer (the enhanced one is
    /// preferred), then apply feature selection if configured.
    fn transform_vectorized(&self, texts: &[&str]) -> Result<TextFeatures> {
        let mut features = if let Some(ref vectorizer) = self.enhanced_vectorizer {
            vectorizer.transform_batch(texts)?
        } else if let Some(ref vectorizer) = self.tfidf_vectorizer {
            vectorizer.transform_batch(texts)?
        } else {
            return Err(TextError::ModelNotFitted(
                "Vectorizer not fitted".to_string(),
            ));
        };

        // Optionally narrow the feature set with the fitted selector.
        if let Some(ref selector) = self.feature_selector {
            features = selector.transform(&features)?;
        }

        Ok(TextFeatures {
            features,
            feature_names: None,
            metadata: HashMap::new(),
        })
    }

    /// Produce per-document topic distributions from the fitted LDA model.
    ///
    /// NOTE(review): a *fresh* count vectorizer is fitted on the transform
    /// input here, so its vocabulary can differ from the one the topic model
    /// was fitted with — verify this is intended.
    fn transform_topics(&self, texts: &[&str]) -> Result<TextFeatures> {
        if let Some(ref topic_model) = self.topic_model {
            let mut count_vectorizer = EnhancedCountVectorizer::new();
            let doc_term_matrix = count_vectorizer.fit_transform(texts)?;

            let features = topic_model.transform(&doc_term_matrix)?;

            let mut metadata = HashMap::new();
            metadata.insert(
                "feature_type".to_string(),
                "topic_distributions".to_string(),
            );

            Ok(TextFeatures {
                features,
                feature_names: None,
                metadata,
            })
        } else {
            Err(TextError::ModelNotFitted(
                "Topic model not fitted".to_string(),
            ))
        }
    }

    /// Embed each document as the mean of its known words' vectors.
    ///
    /// Words missing from the embedding vocabulary are skipped; a document
    /// with no known words keeps an all-zero vector.
    fn transform_embeddings(&self, texts: &[&str]) -> Result<TextFeatures> {
        if let Some(ref embeddings) = self.word_embeddings {
            let mut doc_embeddings = Vec::new();

            for text in texts {
                // Whitespace tokenization only; no normalization is applied here.
                let words: Vec<&str> = text.split_whitespace().collect();
                let mut doc_embedding = Array1::zeros(embeddings.vector_size());
                let mut count = 0;

                // Sum the vectors of every in-vocabulary word.
                for word in &words {
                    if let Ok(vec) = embeddings.get_word_vector(word) {
                        doc_embedding += &vec;
                        count += 1;
                    }
                }

                // Average; guard against division by zero for all-OOV documents.
                if count > 0 {
                    doc_embedding /= count as f64;
                }

                doc_embeddings.push(doc_embedding);
            }

            // Stack the per-document vectors into an (n_docs, vector_size) matrix.
            let n_docs = doc_embeddings.len();
            let n_features = embeddings.vector_size();
            let mut features = Array2::zeros((n_docs, n_features));

            for (i, doc_vec) in doc_embeddings.iter().enumerate() {
                features.row_mut(i).assign(doc_vec);
            }

            Ok(TextFeatures {
                features,
                feature_names: None,
                metadata: HashMap::new(),
            })
        } else {
            Err(TextError::ModelNotFitted(
                "Word embeddings not loaded".to_string(),
            ))
        }
    }

    /// Concatenate TF-IDF, topic, sentiment, and language features column-wise.
    ///
    /// NOTE(review): the TF-IDF and topic blocks are skipped *silently* when
    /// their models are not fitted (`if let Ok`), while sentiment and language
    /// failures propagate — confirm this asymmetry is intended.
    fn transform_combined(&self, texts: &[&str]) -> Result<TextFeatures> {
        let mut all_features = Vec::new();

        if let Ok(tfidf_features) = self.transform_vectorized(texts) {
            all_features.push(tfidf_features.features);
        }

        if let Ok(topic_features) = self.transform_topics(texts) {
            all_features.push(topic_features.features);
        }

        let sentiment_features = self.extract_sentiment_features(texts)?;
        all_features.push(sentiment_features);

        let language_features = self.extract_language_features(texts)?;
        all_features.push(language_features);

        let combined_features = self.concatenate_features(&all_features)?;

        Ok(TextFeatures {
            features: combined_features,
            feature_names: None,
            metadata: HashMap::new(),
        })
    }

    /// Four sentiment features per text: score, confidence, and the counts of
    /// positive and negative words found by the lexicon analyzer.
    fn extract_sentiment_features(&self, texts: &[&str]) -> Result<Array2<f64>> {
        let mut features = Array2::zeros((texts.len(), 4));

        for (i, text) in texts.iter().enumerate() {
            let result = self.sentiment_analyzer.analyze(text)?;
            features[[i, 0]] = result.score;
            features[[i, 1]] = result.confidence;
            features[[i, 2]] = result.word_counts.positive_words as f64;
            features[[i, 3]] = result.word_counts.negative_words as f64;
        }

        Ok(features)
    }

    /// Two language features per text: a numeric language code
    /// (English=1, Spanish=2, French=3, German=4, other=0) and the
    /// detector's confidence.
    fn extract_language_features(&self, texts: &[&str]) -> Result<Array2<f64>> {
        let mut features = Array2::zeros((texts.len(), 2));

        for (i, text) in texts.iter().enumerate() {
            let result = self.language_detector.detect(text)?;
            features[[i, 0]] = match result.language {
                Language::English => 1.0,
                Language::Spanish => 2.0,
                Language::French => 3.0,
                Language::German => 4.0,
                _ => 0.0,
            };
            features[[i, 1]] = result.confidence;
        }

        Ok(features)
    }

    /// Horizontally concatenate the given feature arrays into one matrix.
    ///
    /// The row count is taken from the first array; NOTE(review): an array
    /// with fewer rows than the first would panic on indexing — callers must
    /// pass same-height arrays.
    fn concatenate_features(&self, featurearrays: &[Array2<f64>]) -> Result<Array2<f64>> {
        if featurearrays.is_empty() {
            return Err(TextError::InvalidInput(
                "No features to concatenate".to_string(),
            ));
        }

        let n_samples = featurearrays[0].nrows();
        let total_features: usize = featurearrays.iter().map(|arr| arr.ncols()).sum();

        let mut combined = Array2::zeros((n_samples, total_features));
        let mut col_offset = 0;

        // Copy each array into its column slice of the combined matrix.
        for array in featurearrays {
            let n_cols = array.ncols();
            for i in 0..n_samples {
                for j in 0..n_cols {
                    combined[[i, col_offset + j]] = array[[i, j]];
                }
            }
            col_offset += n_cols;
        }

        Ok(combined)
    }
}
389
/// High-level pipeline combining text preprocessing with optional classification.
pub struct TextMLPipeline {
    // Feature extractor applied to input texts.
    preprocessor: MLTextPreprocessor,
    // Optional downstream classification pipeline.
    classification_pipeline: Option<TextClassificationPipeline>,
}
397
398impl TextMLPipeline {
399 pub fn new() -> Self {
401 Self {
402 preprocessor: MLTextPreprocessor::new(FeatureExtractionMode::TfIdf),
403 classification_pipeline: None,
404 }
405 }
406
407 pub fn with_mode(mode: FeatureExtractionMode) -> Self {
409 Self {
410 preprocessor: MLTextPreprocessor::new(mode),
411 classification_pipeline: None,
412 }
413 }
414
415 pub fn with_classification(mut self) -> Self {
417 self.classification_pipeline = Some(TextClassificationPipeline::with_tfidf());
418 self
419 }
420
421 pub fn configure_preprocessor<F>(mut self, f: F) -> Self
423 where
424 F: FnOnce(MLTextPreprocessor) -> MLTextPreprocessor,
425 {
426 self.preprocessor = f(self.preprocessor);
427 self
428 }
429
430 pub fn process(&mut self, texts: &[&str]) -> Result<TextFeatures> {
432 self.preprocessor.fit_transform(texts)
433 }
434}
435
impl Default for TextMLPipeline {
    /// Equivalent to [`TextMLPipeline::new`]: TF-IDF preprocessing, no classifier.
    fn default() -> Self {
        Self::new()
    }
}
441
/// Processes texts in fixed-size batches using one shared preprocessor.
pub struct BatchTextProcessor {
    // Number of texts per batch.
    batch_size: usize,
    // Preprocessor fitted once on the full corpus and reused per batch.
    preprocessor: MLTextPreprocessor,
}
449
450impl BatchTextProcessor {
451 pub fn new(batchsize: usize) -> Self {
453 Self {
454 batch_size: batchsize,
455 preprocessor: MLTextPreprocessor::new(FeatureExtractionMode::TfIdf),
456 }
457 }
458
459 pub fn process_batches(&mut self, texts: &[&str]) -> Result<Vec<TextFeatures>> {
461 let mut results = Vec::new();
462
463 self.preprocessor.fit(texts)?;
465
466 for chunk in texts.chunks(self.batch_size) {
468 let batch_features = self.preprocessor.transform(chunk)?;
469 results.push(batch_features);
470 }
471
472 Ok(results)
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ml_preprocessor_tfidf() {
        let mut preprocessor = MLTextPreprocessor::new(FeatureExtractionMode::TfIdf);
        let texts = vec![
            "This is a test document",
            "Another test document here",
            "Machine learning is great",
        ];

        let features = preprocessor
            .fit_transform(&texts)
            .expect("Operation failed");
        assert_eq!(features.features.nrows(), 3);
        assert!(features.features.ncols() > 0);
    }

    #[test]
    fn test_feature_extraction_modes() {
        let modes = vec![
            FeatureExtractionMode::BagOfWords,
            FeatureExtractionMode::TfIdf,
            FeatureExtractionMode::TopicModeling,
        ];

        for mode in modes {
            let preprocessor = MLTextPreprocessor::new(mode);
            // `matches!(preprocessor.mode, mode)` would treat `mode` as an
            // identifier (catch-all) pattern and always succeed; compare
            // enum discriminants to actually check the stored variant.
            assert_eq!(
                std::mem::discriminant(&preprocessor.mode),
                std::mem::discriminant(&mode)
            );
        }
    }

    #[test]
    fn test_text_ml_pipeline() {
        let mut pipeline =
            TextMLPipeline::new().configure_preprocessor(|p| p.with_feature_selection(10));

        let texts = vec![
            "Text processing example",
            "Machine learning pipeline",
            "Feature extraction test",
        ];

        let features = pipeline.process(&texts).expect("Operation failed");
        assert_eq!(features.features.nrows(), 3);
    }

    #[test]
    fn test_batch_processor() {
        let mut processor = BatchTextProcessor::new(2);
        let texts = vec![
            "First batch text",
            "Second batch text",
            "Third batch text",
            "Fourth batch text",
        ];

        let batches = processor.process_batches(&texts).expect("Operation failed");
        assert_eq!(batches.len(), 2);
    }
}