use crate::classification::{TextClassificationMetrics, TextDataset};
use crate::error::{Result, TextError};
use crate::sentiment::{Sentiment, SentimentResult};
use crate::vectorize::{TfidfVectorizer, Vectorizer};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::SeedableRng;
use std::collections::HashMap;
/// Machine-learning sentiment analyzer: a TF-IDF vectorizer feeding a
/// logistic-regression classifier trained with mini-batch gradient descent.
#[derive(Default)]
pub struct MLSentimentAnalyzer {
    // TF-IDF feature extractor, fitted during `train`.
    vectorizer: TfidfVectorizer,
    // Learned weight vector; `None` until `train` has been called.
    weights: Option<Array1<f64>>,
    // Learned bias term; `None` until `train` has been called.
    bias: Option<f64>,
    // Maps string labels (e.g. "positive") to numeric class ids.
    label_map: HashMap<String, i32>,
    // Inverse of `label_map`: numeric class id back to its string label.
    reverse_label_map: HashMap<i32, String>,
    // Hyperparameters driving the training loop.
    config: MLSentimentConfig,
}
/// Hyperparameters for the logistic-regression training loop.
#[derive(Debug, Clone)]
pub struct MLSentimentConfig {
    /// Gradient-descent step size.
    pub learning_rate: f64,
    /// Number of full passes over the training data.
    pub epochs: usize,
    /// L2 regularization strength applied to the weight gradient.
    pub regularization: f64,
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Seed for batch shuffling; `None` falls back to a fixed seed of 0
    /// (not system entropy), so training is always reproducible.
    pub random_seed: Option<u64>,
}
impl Default for MLSentimentConfig {
fn default() -> Self {
Self {
learning_rate: 0.01,
epochs: 100,
regularization: 0.01,
batch_size: 32,
random_seed: Some(42),
}
}
}
impl MLSentimentAnalyzer {
    /// Creates an untrained analyzer with the default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style method that replaces the whole training configuration.
    pub fn with_config(mut self, config: MLSentimentConfig) -> Self {
        self.config = config;
        self
    }

    /// Fits the TF-IDF vectorizer and the logistic-regression parameters on
    /// `dataset`, then reports training-set accuracy and the per-epoch loss.
    ///
    /// # Errors
    /// Propagates vectorizer failures and unknown-label errors.
    ///
    /// NOTE(review): the model is a single binary sigmoid, so only class ids
    /// 0 and 1 are ever predicted. With three or more distinct labels the
    /// extra classes are trained against but can never be predicted —
    /// confirm whether a one-vs-rest scheme was intended.
    pub fn train(&mut self, dataset: &TextDataset) -> Result<TrainingMetrics> {
        self.create_label_mappings(&dataset.labels);
        let texts: Vec<&str> = dataset.texts.iter().map(|s| s.as_str()).collect();
        self.vectorizer.fit(&texts)?;
        let features = self.vectorizer.transform_batch(&texts)?;
        let numeric_labels = self.labels_to_numeric(&dataset.labels)?;
        let (weights, bias, history) =
            self.train_logistic_regression(&features, &numeric_labels)?;
        self.weights = Some(weights);
        self.bias = Some(bias);
        // Training accuracy is measured on the same data the model just fit.
        let predictions = self.predict_numeric(&features)?;
        let accuracy = self.calculate_accuracy(&predictions, &numeric_labels);
        Ok(TrainingMetrics {
            accuracy,
            loss_history: history,
            epochs_trained: self.config.epochs,
        })
    }

    /// Predicts the sentiment of a single text.
    ///
    /// # Errors
    /// Returns `TextError::ModelNotFitted` when called before `train`.
    pub fn predict(&self, text: &str) -> Result<SentimentResult> {
        if self.weights.is_none() {
            return Err(TextError::ModelNotFitted(
                "Sentiment analyzer not trained".to_string(),
            ));
        }
        let features_1d = self.vectorizer.transform(text)?;
        let mut features = Array2::zeros((1, features_1d.len()));
        features.row_mut(0).assign(&features_1d);
        let prediction = self.predict_single(&features)?;
        let sentiment_label = self
            .reverse_label_map
            .get(&prediction)
            .ok_or_else(|| TextError::InvalidInput("Unknown label".to_string()))?;
        let sentiment = match sentiment_label.as_str() {
            "positive" => Sentiment::Positive,
            "negative" => Sentiment::Negative,
            _ => Sentiment::Neutral,
        };
        let probabilities = self.predict_proba(&features)?;
        let p_class1 = probabilities[0];
        // Fix: confidence is the probability of the *predicted* class. The
        // previous code always reported the class-1 probability, which
        // under-reported confidence whenever class 0 was predicted.
        let confidence = if prediction == 1 {
            p_class1
        } else {
            1.0 - p_class1
        };
        Ok(SentimentResult {
            sentiment,
            // Class-1 probability mapped from [0, 1] onto [-1, 1].
            score: p_class1 * 2.0 - 1.0,
            confidence,
            word_counts: Default::default(),
        })
    }

    /// Predicts each text in turn, failing fast on the first error.
    pub fn predict_batch(&self, texts: &[&str]) -> Result<Vec<SentimentResult>> {
        texts.iter().map(|&text| self.predict(text)).collect()
    }

    /// Evaluates the trained model on a held-out dataset, producing overall
    /// and per-class precision/recall/F1 plus a confusion matrix.
    ///
    /// # Errors
    /// Returns `TextError::ModelNotFitted` when called before `train`
    /// (previously this path panicked inside `predict_numeric`).
    pub fn evaluate(&self, testdataset: &TextDataset) -> Result<EvaluationMetrics> {
        if self.weights.is_none() {
            return Err(TextError::ModelNotFitted(
                "Sentiment analyzer not trained".to_string(),
            ));
        }
        let texts: Vec<&str> = testdataset.texts.iter().map(|s| s.as_str()).collect();
        let features = self.vectorizer.transform_batch(&texts)?;
        let predictions = self.predict_numeric(&features)?;
        let true_labels = self.labels_to_numeric(&testdataset.labels)?;
        let metrics = TextClassificationMetrics::new();
        let accuracy = metrics.accuracy(&predictions, &true_labels)?;
        let precision = metrics.precision(&predictions, &true_labels, None)?;
        let recall = metrics.recall(&predictions, &true_labels, None)?;
        let f1 = metrics.f1_score(&predictions, &true_labels, None)?;
        let mut class_metrics = HashMap::new();
        for (label, idx) in &self.label_map {
            let class_precision = metrics.precision(&predictions, &true_labels, Some(*idx))?;
            let class_recall = metrics.recall(&predictions, &true_labels, Some(*idx))?;
            let class_f1 = metrics.f1_score(&predictions, &true_labels, Some(*idx))?;
            class_metrics.insert(
                label.clone(),
                ClassMetrics {
                    precision: class_precision,
                    recall: class_recall,
                    f1_score: class_f1,
                },
            );
        }
        Ok(EvaluationMetrics {
            accuracy,
            precision,
            recall,
            f1_score: f1,
            class_metrics,
            confusion_matrix: self.confusion_matrix(&predictions, &true_labels),
        })
    }

    /// Builds the label <-> class-id mappings from the training labels.
    ///
    /// Fix: unique labels are sorted before ids are assigned. The previous
    /// code enumerated a `HashSet` directly, so the label -> id assignment
    /// (and therefore which class the sigmoid treats as "1") changed between
    /// runs even with a fixed random seed.
    fn create_label_mappings(&mut self, labels: &[String]) {
        let unique: std::collections::HashSet<String> = labels.iter().cloned().collect();
        let mut sorted_labels: Vec<String> = unique.into_iter().collect();
        sorted_labels.sort();
        self.label_map.clear();
        self.reverse_label_map.clear();
        for (idx, label) in sorted_labels.into_iter().enumerate() {
            self.label_map.insert(label.clone(), idx as i32);
            self.reverse_label_map.insert(idx as i32, label);
        }
    }

    /// Converts string labels to numeric class ids via `label_map`.
    ///
    /// # Errors
    /// `TextError::InvalidInput` for any label unseen during training.
    fn labels_to_numeric(&self, labels: &[String]) -> Result<Vec<i32>> {
        labels
            .iter()
            .map(|label| {
                self.label_map
                    .get(label)
                    .copied()
                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown label: {label}")))
            })
            .collect()
    }

    /// Trains binary logistic regression with mini-batch gradient descent.
    ///
    /// Returns `(weights, bias, per-epoch mean loss history)`.
    fn train_logistic_regression(
        &self,
        features: &Array2<f64>,
        labels: &[i32],
    ) -> Result<(Array1<f64>, f64, Vec<f64>)> {
        let n_features = features.ncols();
        let n_samples = features.nrows();
        let mut weights = Array1::zeros(n_features);
        let mut bias = 0.0;
        let mut loss_history = Vec::with_capacity(self.config.epochs);
        // Seeded RNG so shuffling (and therefore training) is reproducible;
        // an unset seed falls back to a fixed default rather than entropy.
        let seed = self.config.random_seed.unwrap_or(0);
        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
        use scirs2_core::random::seq::SliceRandom;
        // Guard: `step_by(0)` panics, so a misconfigured batch_size of 0 is
        // treated as 1 instead of crashing the training loop.
        let batch_size = self.config.batch_size.max(1);
        let indices: Vec<usize> = (0..n_samples).collect();
        for _epoch in 0..self.config.epochs {
            let mut epoch_loss = 0.0;
            let mut batch_count = 0;
            let mut shuffled_indices = indices.clone();
            shuffled_indices.shuffle(&mut rng);
            for batch_start in (0..n_samples).step_by(batch_size) {
                let batch_end = (batch_start + batch_size).min(n_samples);
                let batch_indices = &shuffled_indices[batch_start..batch_end];
                let (grad_w, grad_b, batch_loss) =
                    self.calculate_gradients(features, labels, &weights, bias, batch_indices)?;
                weights = &weights - self.config.learning_rate * &grad_w;
                bias -= self.config.learning_rate * grad_b;
                epoch_loss += batch_loss;
                batch_count += 1;
            }
            // Guard against an empty dataset: avoid 0/0 -> NaN in the mean.
            if batch_count > 0 {
                epoch_loss /= batch_count as f64;
            }
            loss_history.push(epoch_loss);
        }
        Ok((weights, bias, loss_history))
    }

    /// Computes mean weight/bias gradients and mean binary cross-entropy loss
    /// over the given batch, plus the L2 penalty gradient on the weights.
    fn calculate_gradients(
        &self,
        features: &Array2<f64>,
        labels: &[i32],
        weights: &Array1<f64>,
        bias: f64,
        indices: &[usize],
    ) -> Result<(Array1<f64>, f64, f64)> {
        let batch_size = indices.len();
        let n_features = features.ncols();
        let mut grad_w = Array1::zeros(n_features);
        let mut grad_b = 0.0;
        let mut total_loss = 0.0;
        for &idx in indices {
            let x = features.row(idx);
            let y_true = labels[idx] as f64;
            let z = x.dot(weights) + bias;
            let y_pred = 1.0 / (1.0 + (-z).exp());
            // Fix: clamp before taking logs — a saturated sigmoid yields
            // exactly 0.0 or 1.0 in f64, and ln(0) turned the loss (and the
            // reported loss history) into inf/NaN.
            let p = y_pred.clamp(1e-12, 1.0 - 1e-12);
            let loss = -y_true * p.ln() - (1.0 - y_true) * (1.0 - p).ln();
            total_loss += loss;
            let error = y_pred - y_true;
            grad_w = &grad_w + error * &x;
            grad_b += error;
        }
        if batch_size > 0 {
            grad_w = &grad_w / batch_size as f64;
            grad_b /= batch_size as f64;
            total_loss /= batch_size as f64;
        }
        // L2 regularization is applied to the weights only, not the bias.
        grad_w = &grad_w + self.config.regularization * weights;
        Ok((grad_w, grad_b, total_loss))
    }

    /// Predicts numeric class ids (0 or 1) for each row of `features` by
    /// thresholding the sigmoid output at 0.5.
    ///
    /// # Errors
    /// Returns `TextError::ModelNotFitted` before training (previously this
    /// panicked via `expect`).
    fn predict_numeric(&self, features: &Array2<f64>) -> Result<Vec<i32>> {
        let weights = self.weights.as_ref().ok_or_else(|| {
            TextError::ModelNotFitted("Sentiment analyzer not trained".to_string())
        })?;
        let bias = self.bias.ok_or_else(|| {
            TextError::ModelNotFitted("Sentiment analyzer not trained".to_string())
        })?;
        let mut predictions = Vec::with_capacity(features.nrows());
        for i in 0..features.nrows() {
            let x = features.row(i);
            let z = x.dot(weights) + bias;
            let prob = 1.0 / (1.0 + (-z).exp());
            predictions.push(if prob > 0.5 { 1 } else { 0 });
        }
        Ok(predictions)
    }

    /// Convenience wrapper: numeric prediction for a single-row feature matrix.
    fn predict_single(&self, features: &Array2<f64>) -> Result<i32> {
        let predictions = self.predict_numeric(features)?;
        Ok(predictions[0])
    }

    /// Returns the class-1 sigmoid probability for each row of `features`.
    ///
    /// # Errors
    /// Returns `TextError::ModelNotFitted` before training (previously this
    /// panicked via `expect`).
    fn predict_proba(&self, features: &Array2<f64>) -> Result<Vec<f64>> {
        let weights = self.weights.as_ref().ok_or_else(|| {
            TextError::ModelNotFitted("Sentiment analyzer not trained".to_string())
        })?;
        let bias = self.bias.ok_or_else(|| {
            TextError::ModelNotFitted("Sentiment analyzer not trained".to_string())
        })?;
        let mut probabilities = Vec::with_capacity(features.nrows());
        for i in 0..features.nrows() {
            let x = features.row(i);
            let z = x.dot(weights) + bias;
            probabilities.push(1.0 / (1.0 + (-z).exp()));
        }
        Ok(probabilities)
    }

    /// Fraction of positions where prediction equals the true label.
    /// Returns 0.0 for empty input instead of dividing by zero (NaN).
    fn calculate_accuracy(&self, predictions: &[i32], truelabels: &[i32]) -> f64 {
        if predictions.is_empty() {
            return 0.0;
        }
        let correct = predictions
            .iter()
            .zip(truelabels.iter())
            .filter(|(&pred, &true_label)| pred == true_label)
            .count();
        correct as f64 / predictions.len() as f64
    }

    /// Builds an n_classes x n_classes confusion matrix; rows index the true
    /// class, columns the predicted class. Out-of-range ids are skipped.
    fn confusion_matrix(&self, predictions: &[i32], truelabels: &[i32]) -> Array2<i32> {
        let n_classes = self.label_map.len();
        let mut matrix = Array2::zeros((n_classes, n_classes));
        for (&pred, &true_label) in predictions.iter().zip(truelabels.iter()) {
            if pred >= 0
                && pred < n_classes as i32
                && true_label >= 0
                && true_label < n_classes as i32
            {
                matrix[[true_label as usize, pred as usize]] += 1;
            }
        }
        matrix
    }
}
/// Summary statistics produced by `MLSentimentAnalyzer::train`.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Accuracy measured on the training set itself (not a held-out set).
    pub accuracy: f64,
    /// Mean loss per epoch, in epoch order.
    pub loss_history: Vec<f64>,
    /// Number of epochs the training loop ran (taken from the config).
    pub epochs_trained: usize,
}
/// Metrics produced by `MLSentimentAnalyzer::evaluate` on a test dataset.
#[derive(Debug, Clone)]
pub struct EvaluationMetrics {
    /// Overall accuracy over the test set.
    pub accuracy: f64,
    /// Overall (label-averaged) precision.
    pub precision: f64,
    /// Overall (label-averaged) recall.
    pub recall: f64,
    /// Overall (label-averaged) F1 score.
    pub f1_score: f64,
    /// Per-label precision/recall/F1, keyed by label string.
    pub class_metrics: HashMap<String, ClassMetrics>,
    /// Confusion matrix; rows are true classes, columns predicted classes.
    pub confusion_matrix: Array2<i32>,
}
/// Precision/recall/F1 for one class within `EvaluationMetrics`.
#[derive(Debug, Clone)]
pub struct ClassMetrics {
    /// Precision for this class.
    pub precision: f64,
    /// Recall for this class.
    pub recall: f64,
    /// F1 score for this class.
    pub f1_score: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the shared fixture: six short reviews, two per sentiment class.
    fn create_test_dataset() -> TextDataset {
        let samples = [
            (
                "This movie is fantastic! I loved every minute of it.",
                "positive",
            ),
            ("Terrible film. Complete waste of time.", "negative"),
            ("Not bad, but nothing special either.", "neutral"),
            (
                "Absolutely brilliant! Best movie I've seen this year.",
                "positive",
            ),
            ("Horrible experience. Would not recommend.", "negative"),
            ("It was okay, I guess. Pretty average.", "neutral"),
        ];
        let texts: Vec<String> = samples.iter().map(|(text, _)| text.to_string()).collect();
        let labels: Vec<String> = samples.iter().map(|(_, label)| label.to_string()).collect();
        TextDataset::new(texts, labels).expect("Operation failed")
    }

    /// Training should succeed and record one loss entry per epoch.
    #[test]
    fn test_ml_sentiment_training() {
        let config = MLSentimentConfig {
            epochs: 10,
            learning_rate: 0.1,
            ..Default::default()
        };
        let mut analyzer = MLSentimentAnalyzer::new().with_config(config);
        let metrics = analyzer
            .train(&create_test_dataset())
            .expect("Operation failed");
        assert!(metrics.accuracy > 0.0);
        assert_eq!(metrics.loss_history.len(), 10);
    }

    /// A trained analyzer must return a result (any sentiment) for unseen text.
    #[test]
    fn test_ml_sentiment_prediction() {
        let config = MLSentimentConfig {
            epochs: 50,
            learning_rate: 0.5,
            ..Default::default()
        };
        let mut analyzer = MLSentimentAnalyzer::new().with_config(config);
        analyzer
            .train(&create_test_dataset())
            .expect("Operation failed");
        let unseen_positive = [
            "This is amazing!",
            "Absolutely wonderful experience",
            "Great product, loved it",
            "Fantastic results, highly recommend",
        ];
        for text in &unseen_positive {
            let _result = analyzer.predict(text).expect("Operation failed");
        }
    }

    /// Evaluation on a held-out split yields sane aggregate metrics.
    #[test]
    fn test_ml_sentiment_evaluation() {
        let mut analyzer = MLSentimentAnalyzer::new();
        let (train_dataset, test_dataset) = create_test_dataset()
            .train_test_split(0.3, Some(42))
            .expect("Operation failed");
        analyzer.train(&train_dataset).expect("Operation failed");
        let eval_metrics = analyzer.evaluate(&test_dataset).expect("Operation failed");
        assert!(eval_metrics.accuracy >= 0.0 && eval_metrics.accuracy <= 1.0);
        assert!(!eval_metrics.class_metrics.is_empty());
    }

    /// Batch prediction returns one result per input text.
    #[test]
    fn test_batch_prediction() {
        let mut analyzer = MLSentimentAnalyzer::new();
        analyzer
            .train(&create_test_dataset())
            .expect("Operation failed");
        let inputs = [
            "Great product!",
            "Terrible service.",
            "It's okay, nothing special.",
        ];
        let results = analyzer.predict_batch(&inputs).expect("Operation failed");
        assert_eq!(results.len(), 3);
    }

    /// Predicting before training must fail with `ModelNotFitted`.
    #[test]
    fn test_unfitted_model_error() {
        let untrained = MLSentimentAnalyzer::new();
        assert!(matches!(
            untrained.predict("Test text"),
            Err(TextError::ModelNotFitted(_))
        ));
    }
}