1use crate::classification::{TextClassificationMetrics, TextDataset};
7use crate::error::{Result, TextError};
8use crate::sentiment::{Sentiment, SentimentResult};
9use crate::vectorize::{TfidfVectorizer, Vectorizer};
10use scirs2_core::ndarray::{Array1, Array2};
11use scirs2_core::random::SeedableRng;
12use std::collections::HashMap;
13
14#[derive(Default)]
16pub struct MLSentimentAnalyzer {
17 vectorizer: TfidfVectorizer,
19 weights: Option<Array1<f64>>,
21 bias: Option<f64>,
23 label_map: HashMap<String, i32>,
25 reverse_label_map: HashMap<i32, String>,
27 config: MLSentimentConfig,
29}
30
/// Hyperparameters for training an `MLSentimentAnalyzer`.
#[derive(Clone, Debug)]
pub struct MLSentimentConfig {
    /// Step size applied to each gradient update of the weights and bias.
    pub learning_rate: f64,
    /// Number of full passes over the training data.
    pub epochs: usize,
    /// Coefficient multiplied with the weights and added to the weight gradient.
    pub regularization: f64,
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Seed for the RNG that shuffles samples each epoch; `None` falls back to seed 0.
    pub random_seed: Option<u64>,
}
45
46impl Default for MLSentimentConfig {
47 fn default() -> Self {
48 Self {
49 learning_rate: 0.01,
50 epochs: 100,
51 regularization: 0.01,
52 batch_size: 32,
53 random_seed: Some(42),
54 }
55 }
56}
57
58impl MLSentimentAnalyzer {
61 pub fn new() -> Self {
63 Self::default()
64 }
65
66 pub fn with_config(mut self, config: MLSentimentConfig) -> Self {
68 self.config = config;
69 self
70 }
71
72 pub fn train(&mut self, dataset: &TextDataset) -> Result<TrainingMetrics> {
74 self.create_label_mappings(&dataset.labels);
76
77 let texts: Vec<&str> = dataset.texts.iter().map(|s| s.as_str()).collect();
79 self.vectorizer.fit(&texts)?;
80 let features = self.vectorizer.transform_batch(&texts)?;
81
82 let numeric_labels = self.labels_to_numeric(&dataset.labels)?;
84
85 let (weights, bias, history) =
87 self.train_logistic_regression(&features, &numeric_labels)?;
88
89 self.weights = Some(weights);
90 self.bias = Some(bias);
91
92 let predictions = self.predict_numeric(&features)?;
94 let accuracy = self.calculate_accuracy(&predictions, &numeric_labels);
95
96 Ok(TrainingMetrics {
97 accuracy,
98 loss_history: history,
99 epochs_trained: self.config.epochs,
100 })
101 }
102
103 pub fn predict(&self, text: &str) -> Result<SentimentResult> {
105 if self.weights.is_none() {
106 return Err(TextError::ModelNotFitted(
107 "Sentiment analyzer not trained".to_string(),
108 ));
109 }
110
111 let features_1d = self.vectorizer.transform(text)?;
112
113 let mut features = Array2::zeros((1, features_1d.len()));
115 features.row_mut(0).assign(&features_1d);
116
117 let prediction = self.predict_single(&features)?;
118
119 let sentiment_label = self
121 .reverse_label_map
122 .get(&prediction)
123 .ok_or_else(|| TextError::InvalidInput("Unknown label".to_string()))?;
124
125 let sentiment = match sentiment_label.as_str() {
126 "positive" => Sentiment::Positive,
127 "negative" => Sentiment::Negative,
128 _ => Sentiment::Neutral,
129 };
130
131 let probabilities = self.predict_proba(&features)?;
133 let confidence = probabilities[0]; Ok(SentimentResult {
136 sentiment,
137 score: confidence * 2.0 - 1.0, confidence,
139 word_counts: Default::default(),
140 })
141 }
142
143 pub fn predict_batch(&self, texts: &[&str]) -> Result<Vec<SentimentResult>> {
145 texts.iter().map(|&text| self.predict(text)).collect()
146 }
147
148 pub fn evaluate(&self, testdataset: &TextDataset) -> Result<EvaluationMetrics> {
150 let texts: Vec<&str> = testdataset.texts.iter().map(|s| s.as_str()).collect();
151 let features = self.vectorizer.transform_batch(&texts)?;
152
153 let predictions = self.predict_numeric(&features)?;
154 let true_labels = self.labels_to_numeric(&testdataset.labels)?;
155
156 let metrics = TextClassificationMetrics::new();
158 let accuracy = metrics.accuracy(&predictions, &true_labels)?;
159 let precision = metrics.precision(&predictions, &true_labels, None)?;
160 let recall = metrics.recall(&predictions, &true_labels, None)?;
161 let f1 = metrics.f1_score(&predictions, &true_labels, None)?;
162
163 let mut class_metrics = HashMap::new();
165 for (label, idx) in &self.label_map {
166 let class_precision = metrics.precision(&predictions, &true_labels, Some(*idx))?;
167 let class_recall = metrics.recall(&predictions, &true_labels, Some(*idx))?;
168 let class_f1 = metrics.f1_score(&predictions, &true_labels, Some(*idx))?;
169
170 class_metrics.insert(
171 label.clone(),
172 ClassMetrics {
173 precision: class_precision,
174 recall: class_recall,
175 f1_score: class_f1,
176 },
177 );
178 }
179
180 Ok(EvaluationMetrics {
181 accuracy,
182 precision,
183 recall,
184 f1_score: f1,
185 class_metrics,
186 confusion_matrix: self.confusion_matrix(&predictions, &true_labels),
187 })
188 }
189
190 fn create_label_mappings(&mut self, labels: &[String]) {
193 let unique_labels: std::collections::HashSet<String> = labels.iter().cloned().collect();
194
195 self.label_map.clear();
196 self.reverse_label_map.clear();
197
198 for (idx, label) in unique_labels.iter().enumerate() {
199 self.label_map.insert(label.clone(), idx as i32);
200 self.reverse_label_map.insert(idx as i32, label.clone());
201 }
202 }
203
204 fn labels_to_numeric(&self, labels: &[String]) -> Result<Vec<i32>> {
205 labels
206 .iter()
207 .map(|label| {
208 self.label_map
209 .get(label)
210 .copied()
211 .ok_or_else(|| TextError::InvalidInput(format!("Unknown label: {label}")))
212 })
213 .collect()
214 }
215
216 fn train_logistic_regression(
217 &self,
218 features: &Array2<f64>,
219 labels: &[i32],
220 ) -> Result<(Array1<f64>, f64, Vec<f64>)> {
221 let n_features = features.ncols();
222 let n_samples = features.nrows();
223
224 let mut weights = Array1::zeros(n_features);
226 let mut bias = 0.0;
227
228 let mut loss_history = Vec::new();
230
231 let mut rng = if let Some(seed) = self.config.random_seed {
233 scirs2_core::random::rngs::StdRng::seed_from_u64(seed)
234 } else {
235 scirs2_core::random::rngs::StdRng::seed_from_u64(0)
237 };
238
239 use scirs2_core::random::seq::SliceRandom;
240 let indices: Vec<usize> = (0..n_samples).collect();
241
242 for _epoch in 0..self.config.epochs {
244 let mut epoch_loss = 0.0;
245 let mut batch_count = 0;
246
247 let mut shuffled_indices = indices.clone();
249 shuffled_indices.shuffle(&mut rng);
250
251 for batch_start in (0..n_samples).step_by(self.config.batch_size) {
253 let batch_end = (batch_start + self.config.batch_size).min(n_samples);
254 let batch_indices = &shuffled_indices[batch_start..batch_end];
255
256 let (grad_w, grad_b, batch_loss) =
258 self.calculate_gradients(features, labels, &weights, bias, batch_indices)?;
259
260 weights = &weights - self.config.learning_rate * &grad_w;
262 bias -= self.config.learning_rate * grad_b;
263
264 epoch_loss += batch_loss;
265 batch_count += 1;
266 }
267
268 epoch_loss /= batch_count as f64;
269 loss_history.push(epoch_loss);
270 }
271
272 Ok((weights, bias, loss_history))
273 }
274
275 fn calculate_gradients(
276 &self,
277 features: &Array2<f64>,
278 labels: &[i32],
279 weights: &Array1<f64>,
280 bias: f64,
281 indices: &[usize],
282 ) -> Result<(Array1<f64>, f64, f64)> {
283 let batch_size = indices.len();
284 let n_features = features.ncols();
285
286 let mut grad_w = Array1::zeros(n_features);
287 let mut grad_b = 0.0;
288 let mut total_loss = 0.0;
289
290 for &idx in indices {
291 let x = features.row(idx);
292 let y_true = labels[idx] as f64;
293
294 let z = x.dot(weights) + bias;
296 let y_pred = 1.0 / (1.0 + (-z).exp());
297
298 let loss = -y_true * y_pred.ln() - (1.0 - y_true) * (1.0 - y_pred).ln();
300 total_loss += loss;
301
302 let error = y_pred - y_true;
304 grad_w = &grad_w + error * &x;
305 grad_b += error;
306 }
307
308 grad_w = &grad_w / batch_size as f64;
310 grad_b /= batch_size as f64;
311 total_loss /= batch_size as f64;
312
313 grad_w = &grad_w + self.config.regularization * weights;
315
316 Ok((grad_w, grad_b, total_loss))
317 }
318
319 fn predict_numeric(&self, features: &Array2<f64>) -> Result<Vec<i32>> {
320 let weights = self.weights.as_ref().unwrap();
321 let bias = self.bias.unwrap();
322
323 let mut predictions = Vec::new();
324
325 for i in 0..features.nrows() {
326 let x = features.row(i);
327 let z = x.dot(weights) + bias;
328 let prob = 1.0 / (1.0 + (-z).exp());
329
330 let prediction = if prob > 0.5 { 1 } else { 0 };
332 predictions.push(prediction);
333 }
334
335 Ok(predictions)
336 }
337
338 fn predict_single(&self, features: &Array2<f64>) -> Result<i32> {
339 let predictions = self.predict_numeric(features)?;
340 Ok(predictions[0])
341 }
342
343 fn predict_proba(&self, features: &Array2<f64>) -> Result<Vec<f64>> {
344 let weights = self.weights.as_ref().unwrap();
345 let bias = self.bias.unwrap();
346
347 let mut probabilities = Vec::new();
348
349 for i in 0..features.nrows() {
350 let x = features.row(i);
351 let z = x.dot(weights) + bias;
352 let prob = 1.0 / (1.0 + (-z).exp());
353 probabilities.push(prob);
354 }
355
356 Ok(probabilities)
357 }
358
359 fn calculate_accuracy(&self, predictions: &[i32], truelabels: &[i32]) -> f64 {
360 let correct = predictions
361 .iter()
362 .zip(truelabels.iter())
363 .filter(|(&pred, &true_label)| pred == true_label)
364 .count();
365
366 correct as f64 / predictions.len() as f64
367 }
368
369 fn confusion_matrix(&self, predictions: &[i32], truelabels: &[i32]) -> Array2<i32> {
370 let n_classes = self.label_map.len();
371 let mut matrix = Array2::zeros((n_classes, n_classes));
372
373 for (&pred, &true_label) in predictions.iter().zip(truelabels.iter()) {
374 if pred >= 0
375 && pred < n_classes as i32
376 && true_label >= 0
377 && true_label < n_classes as i32
378 {
379 matrix[[true_label as usize, pred as usize]] += 1;
380 }
381 }
382
383 matrix
384 }
385}
386
/// Summary returned after training an `MLSentimentAnalyzer`.
#[derive(Clone, Debug)]
pub struct TrainingMetrics {
    /// Fraction of training samples the fitted model classifies correctly.
    pub accuracy: f64,
    /// Mean batch loss recorded for each epoch, in training order.
    pub loss_history: Vec<f64>,
    /// Number of epochs taken from the training configuration.
    pub epochs_trained: usize,
}
397
398#[derive(Debug, Clone)]
400pub struct EvaluationMetrics {
401 pub accuracy: f64,
403 pub precision: f64,
405 pub recall: f64,
407 pub f1_score: f64,
409 pub class_metrics: HashMap<String, ClassMetrics>,
411 pub confusion_matrix: Array2<i32>,
413}
414
/// Precision, recall and F1 score for a single class.
#[derive(Clone, Debug)]
pub struct ClassMetrics {
    /// Precision for this class.
    pub precision: f64,
    /// Recall for this class.
    pub recall: f64,
    /// F1 score for this class.
    pub f1_score: f64,
}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a six-sample dataset with two texts per label
    /// (`positive`, `negative`, `neutral`).
    fn create_test_dataset() -> TextDataset {
        let texts = vec![
            "This movie is fantastic! I loved every minute of it.".to_string(),
            "Terrible film. Complete waste of time.".to_string(),
            "Not bad, but nothing special either.".to_string(),
            "Absolutely brilliant! Best movie I've seen this year.".to_string(),
            "Horrible experience. Would not recommend.".to_string(),
            "It was okay, I guess. Pretty average.".to_string(),
        ];

        let labels = vec![
            "positive".to_string(),
            "negative".to_string(),
            "neutral".to_string(),
            "positive".to_string(),
            "negative".to_string(),
            "neutral".to_string(),
        ];

        TextDataset::new(texts, labels).unwrap()
    }

    /// Training reports one loss value per epoch and a non-zero accuracy.
    #[test]
    fn test_ml_sentiment_training() {
        let mut analyzer = MLSentimentAnalyzer::new().with_config(MLSentimentConfig {
            epochs: 10,
            learning_rate: 0.1,
            ..Default::default()
        });

        let dataset = create_test_dataset();
        let metrics = analyzer.train(&dataset).unwrap();

        assert!(metrics.accuracy > 0.0);
        assert_eq!(metrics.loss_history.len(), 10);
    }

    /// Predictions on unseen text succeed and stay within the documented ranges.
    #[test]
    fn test_ml_sentiment_prediction() {
        let mut analyzer = MLSentimentAnalyzer::new().with_config(MLSentimentConfig {
            epochs: 50,
            learning_rate: 0.5,
            ..Default::default()
        });
        let dataset = create_test_dataset();

        analyzer.train(&dataset).unwrap();

        for &positivetext in &[
            "This is amazing!",
            "Absolutely wonderful experience",
            "Great product, loved it",
            "Fantastic results, highly recommend",
        ] {
            let result = analyzer.predict(positivetext).unwrap();

            // The tiny training set makes the predicted class unreliable, so only
            // the value ranges guaranteed by the sigmoid output are checked.
            assert!((0.0..=1.0).contains(&result.confidence));
            assert!((-1.0..=1.0).contains(&result.score));
        }
    }

    /// Evaluation on a held-out split yields a valid accuracy and per-class metrics.
    #[test]
    fn test_ml_sentiment_evaluation() {
        let mut analyzer = MLSentimentAnalyzer::new();
        let dataset = create_test_dataset();

        let (train_dataset, test_dataset) = dataset.train_test_split(0.3, Some(42)).unwrap();

        analyzer.train(&train_dataset).unwrap();
        let eval_metrics = analyzer.evaluate(&test_dataset).unwrap();

        assert!(eval_metrics.accuracy >= 0.0 && eval_metrics.accuracy <= 1.0);
        assert!(!eval_metrics.class_metrics.is_empty());
    }

    /// Batch prediction returns one in-range result per input text.
    #[test]
    fn test_batch_prediction() {
        let mut analyzer = MLSentimentAnalyzer::new();
        let dataset = create_test_dataset();

        analyzer.train(&dataset).unwrap();

        let texts = vec![
            "Great product!",
            "Terrible service.",
            "It's okay, nothing special.",
        ];

        let results = analyzer.predict_batch(&texts).unwrap();
        assert_eq!(results.len(), 3);
        for result in &results {
            assert!((0.0..=1.0).contains(&result.confidence));
            assert!((-1.0..=1.0).contains(&result.score));
        }
    }

    /// Predicting before training reports `ModelNotFitted` instead of panicking.
    #[test]
    fn test_unfitted_model_error() {
        let analyzer = MLSentimentAnalyzer::new();
        let result = analyzer.predict("Test text");

        assert!(matches!(result, Err(TextError::ModelNotFitted(_))));
    }
}