use rust_tokenizers::bert_tokenizer::BertTokenizer;
use std::path::PathBuf;
use tch::{Device, Tensor, Kind, no_grad};
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::distilbert::{DistilBertModelClassifier, DistilBertConfig, DistilBertModelResources, DistilBertConfigResources, DistilBertVocabResources};
use crate::Config;
use std::fs;
use serde::Deserialize;
use std::error::Error;
use crate::common::resources::{Resource, download_resource, RemoteResource};
/// Binary sentiment label produced by `SentimentModel::predict`.
///
/// The enum is fieldless, so it is cheap to copy and compare by value. `Eq` is
/// derived alongside `PartialEq` because equality is total for this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SentimentPolarity {
    /// The sentence was classified as positive.
    Positive,
    /// The sentence was classified as negative.
    Negative,
}
/// Sentiment prediction for a single input sentence, as returned by `SentimentModel::predict`.
#[derive(Debug)]
pub struct Sentiment {
/// Predicted class for the sentence.
pub polarity: SentimentPolarity,
/// Confidence in `polarity`, derived from the classifier's softmax output.
/// `SentimentModel::predict` always produces a value >= 0.5 here.
pub score: f64,
}
/// Resources and device used by `SentimentModel::new` to build a sentiment model.
///
/// `Default` points all three resources at the pretrained DistilBERT SST-2 checkpoint.
pub struct SentimentConfig {
/// Model weights; loaded into the model's `VarStore`.
pub model_resource: Resource,
/// Model configuration; read with `DistilBertConfig::from_file`.
pub config_resource: Resource,
/// Tokenizer vocabulary; read with `BertTokenizer::from_file`.
pub vocab_resource: Resource,
/// Device on which the model's variable store (and therefore its weights) is created.
pub device: Device,
}
impl Default for SentimentConfig {
    /// Configuration for the pretrained DistilBERT SST-2 checkpoint.
    ///
    /// Weights, configuration and vocabulary are all remote pretrained
    /// resources. The device is chosen with `Device::cuda_if_available()`.
    fn default() -> Self {
        let model_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertModelResources::DISTIL_BERT_SST2,
        ));
        let config_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertConfigResources::DISTIL_BERT_SST2,
        ));
        let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
            DistilBertVocabResources::DISTIL_BERT_SST2,
        ));
        Self {
            model_resource,
            config_resource,
            vocab_resource,
            device: Device::cuda_if_available(),
        }
    }
}
/// DistilBERT-based binary sentiment classifier; build it with `SentimentModel::new`.
pub struct SentimentModel {
/// Tokenizer that turns input sentences into token ids.
tokenizer: BertTokenizer,
/// Classification model whose forward pass yields per-class scores.
distil_bert_classifier: DistilBertModelClassifier,
/// Holds the model's variables; its device is also where input tensors are moved.
var_store: VarStore,
}
impl SentimentModel {
    /// Builds a sentiment model from the resources described by `sentiment_config`.
    ///
    /// Resolves the configuration, vocabulary and weight resources, builds the
    /// tokenizer, instantiates a `DistilBertModelClassifier` on the configured
    /// device and loads the weights into its variable store.
    ///
    /// # Errors
    ///
    /// Returns an error if any resource cannot be fetched, if the vocabulary
    /// path is not valid UTF-8, or if the weights cannot be loaded.
    pub fn new(sentiment_config: SentimentConfig) -> failure::Fallible<SentimentModel> {
        let config_path = download_resource(&sentiment_config.config_resource)?;
        let vocab_path = download_resource(&sentiment_config.vocab_resource)?;
        let weights_path = download_resource(&sentiment_config.model_resource)?;
        let device = sentiment_config.device;

        // The tokenizer wants a &str path; a non-UTF-8 path used to panic here,
        // so report it through the function's error channel instead.
        let vocab_path = vocab_path
            .to_str()
            .ok_or_else(|| failure::err_msg("vocabulary path is not valid UTF-8"))?;
        let tokenizer = BertTokenizer::from_file(vocab_path, true);

        let mut var_store = VarStore::new(device);
        let config = DistilBertConfig::from_file(config_path);
        // The classifier registers its variables in `var_store`; the weights are
        // loaded into those variables only after they exist.
        let distil_bert_classifier = DistilBertModelClassifier::new(&var_store.root(), &config);
        var_store.load(weights_path)?;

        Ok(SentimentModel {
            tokenizer,
            distil_bert_classifier,
            var_store,
        })
    }

    /// Tokenizes a non-empty batch into one 2-D tensor on the model's device.
    ///
    /// Each sentence is encoded (128 is passed as the tokenizer's maximum
    /// length), then right-padded with id 0 to the longest sequence in the
    /// batch. The padded rows are stacked along dimension 0.
    ///
    /// Callers must not pass an empty batch: `Tensor::stack` needs at least one row.
    fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
        let tokenized_input =
            self.tokenizer
                .encode_list(input, 128, &TruncationStrategy::LongestFirst, 0);
        let max_len = tokenized_input
            .iter()
            .map(|tokenized| tokenized.token_ids.len())
            .max()
            .unwrap_or(0);
        let rows = tokenized_input
            .into_iter()
            .map(|tokenized| {
                let mut ids = tokenized.token_ids;
                // `max_len` is the batch maximum, so this only ever appends zeros.
                ids.resize(max_len, 0);
                Tensor::of_slice(&ids)
            })
            .collect::<Vec<_>>();
        Tensor::stack(rows.as_slice(), 0).to(self.var_store.device())
    }

    /// Classifies each input sentence as positive or negative.
    ///
    /// Returns one `Sentiment` per input, in input order. `score` is the
    /// confidence in the returned polarity and is always >= 0.5. An empty
    /// input slice yields an empty vector.
    ///
    /// # Panics
    ///
    /// Panics if the classifier's forward pass returns an error, or if the
    /// output tensor cannot be iterated as `f64`.
    pub fn predict(&self, input: &[&str]) -> Vec<Sentiment> {
        // Padding and stacking cannot handle an empty batch (this used to panic).
        if input.is_empty() {
            return Vec::new();
        }
        let input_tensor = self.prepare_for_model(input.to_vec());
        let probabilities = no_grad(|| {
            // NOTE(review): the second argument (presumably the attention mask)
            // is None. Padding positions of shorter sentences may therefore
            // influence their prediction when lengths differ within a batch.
            // Consider passing a mask once the format `forward_t` expects is
            // confirmed.
            let (logits, _, _) = self
                .distil_bert_classifier
                .forward_t(Some(input_tensor), None, None, false)
                .unwrap();
            logits.softmax(-1, Kind::Float).detach().to(Device::Cpu)
        });

        // Column 0 is treated as the negative-class probability: below 0.5 the
        // sentence is positive and the reported score is the complement.
        probabilities
            .select(1, 0)
            .iter::<f64>()
            .unwrap()
            .map(|negative| {
                if negative < 0.5 {
                    Sentiment {
                        polarity: SentimentPolarity::Positive,
                        score: 1.0 - negative,
                    }
                } else {
                    Sentiment {
                        polarity: SentimentPolarity::Negative,
                        score: negative,
                    }
                }
            })
            .collect()
    }
}
/// One row of the tab-separated file read by `ss2_processor`.
/// Fields are deserialized by serde from that file's header-bearing rows.
#[derive(Debug, Deserialize)]
struct Record {
/// Raw sentence text; the only field `ss2_processor` keeps.
sentence: String,
/// Label column. It must be present and parse as an `i8`, but `ss2_processor` never reads it.
/// Presumably the SST-2 class id — confirm against the dataset.
label: i8,
}
/// Reads the `sentence` column of a tab-separated, header-bearing SST-2 style file.
///
/// Each row is deserialized as a `Record`, so a `label` column parseable as
/// `i8` is also required. Only the sentences are returned, in file order.
///
/// # Errors
///
/// Returns an error, instead of panicking as before, if the file cannot be
/// opened. Also returns an error if any row fails to parse.
pub fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
    let file = fs::File::open(&file_path)
        .map_err(|e| format!("unable to open file {}: {}", file_path.display(), e))?;
    let mut reader = csv::ReaderBuilder::new()
        .has_headers(true)
        .delimiter(b'\t')
        .from_reader(file);
    let mut sentences = Vec::new();
    for result in reader.deserialize() {
        let record: Record = result?;
        sentences.push(record.sentence);
    }
    Ok(sentences)
}