use crate::common::error::RustBertError;
use crate::pipelines::common::TokenizerOption;
use crate::pipelines::sequence_classification::{
SequenceClassificationConfig, SequenceClassificationModel,
};
use serde::{Deserialize, Serialize};
/// Polarity of a sentiment prediction.
///
/// Only two classes exist; there is no neutral variant. The type is a plain
/// fieldless enum, so it is cheap to copy and can be used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SentimentPolarity {
    /// Assigned by [`SentimentModel::predict`] to labels whose id is `1`.
    Positive,
    /// Assigned by [`SentimentModel::predict`] to every other label id.
    Negative,
}

/// A single sentiment prediction returned by [`SentimentModel::predict`].
///
/// `PartialEq` (but not `Eq`) is derived because `score` is an `f64`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Sentiment {
    /// Predicted polarity of the input sentence.
    pub polarity: SentimentPolarity,
    /// Score reported by the underlying sequence classifier for the predicted label.
    pub score: f64,
}
/// Configuration for [`SentimentModel`].
///
/// This is an alias of [`SequenceClassificationConfig`]: the sentiment pipeline is
/// configured exactly like a sequence-classification pipeline.
pub type SentimentConfig = SequenceClassificationConfig;
/// Sentiment analysis pipeline.
///
/// Wraps a [`SequenceClassificationModel`] and converts its predicted labels into
/// [`Sentiment`] values (see [`SentimentModel::predict`]).
pub struct SentimentModel {
    // Wrapped classifier; tokenizer access and prediction are delegated to it.
    sequence_classification_model: SequenceClassificationModel,
}
impl SentimentModel {
    /// Builds a sentiment pipeline from `sentiment_config`.
    ///
    /// # Errors
    /// Returns any [`RustBertError`] produced while constructing the underlying
    /// [`SequenceClassificationModel`].
    pub fn new(sentiment_config: SentimentConfig) -> Result<SentimentModel, RustBertError> {
        Ok(Self {
            sequence_classification_model: SequenceClassificationModel::new(sentiment_config)?,
        })
    }

    /// Builds a sentiment pipeline from `sentiment_config`, using the caller-supplied
    /// `tokenizer`.
    ///
    /// # Errors
    /// Returns any [`RustBertError`] produced while constructing the underlying
    /// [`SequenceClassificationModel`].
    pub fn new_with_tokenizer(
        sentiment_config: SentimentConfig,
        tokenizer: TokenizerOption,
    ) -> Result<SentimentModel, RustBertError> {
        Ok(Self {
            sequence_classification_model: SequenceClassificationModel::new_with_tokenizer(
                sentiment_config,
                tokenizer,
            )?,
        })
    }

    /// Borrows the tokenizer of the wrapped classifier.
    pub fn get_tokenizer(&self) -> &TokenizerOption {
        let classifier = &self.sequence_classification_model;
        classifier.get_tokenizer()
    }

    /// Mutably borrows the tokenizer of the wrapped classifier.
    pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
        let classifier = &mut self.sequence_classification_model;
        classifier.get_tokenizer_mut()
    }

    /// Classifies the sentiment of each input sentence.
    ///
    /// Runs the wrapped classifier on `input` and produces one [`Sentiment`] per
    /// returned label, in the same order. A label whose `id` is `1` becomes
    /// [`SentimentPolarity::Positive`]; any other id becomes
    /// [`SentimentPolarity::Negative`]. The label's score is copied through unchanged.
    ///
    /// NOTE(review): treating id `1` as the positive class presumably matches the
    /// default (binary) model's label order — confirm when using other models.
    pub fn predict<'a, S>(&self, input: S) -> Vec<Sentiment>
    where
        S: AsRef<[&'a str]>,
    {
        self.sequence_classification_model
            .predict(input)
            .into_iter()
            .map(|label| Sentiment {
                polarity: if label.id == 1 {
                    SentimentPolarity::Positive
                } else {
                    SentimentPolarity::Negative
                },
                score: label.score,
            })
            .collect()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Compiles only when `T: Send`; calling it does nothing at runtime.
    fn assert_send<T: Send>() {}

    /// Checks that the pipeline (and the `Result` its constructors return) is `Send`,
    /// so it can be moved to another thread.
    ///
    /// The check is purely compile-time, so no model has to be constructed. That lets
    /// the test run normally instead of being `#[ignore]`d.
    #[test]
    fn test() {
        assert_send::<SentimentModel>();
        assert_send::<Result<SentimentModel, RustBertError>>();
    }
}